package org.nd4j.linalg.jcublas;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.Loader;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.bytedeco.javacpp.ShortPointer;
import org.nd4j.jita.allocator.enums.CudaConstants;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.enums.MemoryKind;
import org.nd4j.linalg.api.memory.pointers.PagedPointer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper;
import org.nd4j.linalg.api.shape.options.ArrayType;
import org.nd4j.linalg.cache.TADManager;
import org.nd4j.linalg.compression.CompressedDataBuffer;
import org.nd4j.linalg.compression.CompressionDescriptor;
import org.nd4j.linalg.compression.CompressionType;
import org.nd4j.linalg.compression.CompressionUtils;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.blas.CudaBlas;
import org.nd4j.linalg.jcublas.blas.JcublasLapack;
import org.nd4j.linalg.jcublas.blas.JcublasLevel1;
import org.nd4j.linalg.jcublas.blas.JcublasLevel2;
import org.nd4j.linalg.jcublas.blas.JcublasLevel3;
import org.nd4j.linalg.jcublas.buffer.AddressRetriever;
import org.nd4j.linalg.jcublas.buffer.CudaDoubleDataBuffer;
import org.nd4j.linalg.jcublas.buffer.CudaIntDataBuffer;
import org.nd4j.linalg.jcublas.buffer.CudaLongDataBuffer;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.memory.MemcpyDirection;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.nativeblas.BaseNativeNDArrayFactory;
import org.nd4j.nativeblas.LongPointerWrapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/jcublas/JCublasNDArrayFactory.class */
public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
    private static final Logger log = LoggerFactory.getLogger(JCublasNDArrayFactory.class);

    /**
     * No-argument constructor. It performs no work of its own and relies on the
     * implicit call to the superclass no-argument constructor.
     */
    public JCublasNDArrayFactory() {
    }

    /**
     * Forwards both arguments unchanged to the superclass constructor.
     *
     * @param type data type handed to the superclass
     * @param ch   boxed character handed to the superclass (presumably the default array ordering — TODO confirm)
     */
    public JCublasNDArrayFactory(DataBuffer.Type type, Character ch) {
        super(type, ch);
    }

    /**
     * Forwards both arguments to the superclass constructor and then touches the
     * {@link AtomicAllocator} singleton accessor.
     *
     * @param type data type handed to the superclass
     * @param c    character handed to the superclass (presumably the default array ordering — TODO confirm)
     */
    public JCublasNDArrayFactory(DataBuffer.Type type, char c) {
        super(type, c);
        // Return value is intentionally discarded; NOTE(review): this looks like an eager
        // initialisation of the allocator — confirm. The Character overload does not do this.
        AtomicAllocator.getInstance();
    }

    /**
     * Installs a {@link CudaBlas} instance as {@code blas} and passes the native layer a
     * table holding the addresses of thirteen cuBLAS / cuSOLVER symbols.
     *
     * <p>Slot {@code n} of the table receives the address of the n-th name in the list
     * below, so the list order must not be changed (presumably the native side reads the
     * table by index — TODO confirm).
     */
    public void createBlas() {
        this.blas = new CudaBlas();

        final String[] symbolNames = {
            "cublasSgemv_v2",
            "cublasDgemv_v2",
            "cublasHgemm",
            "cublasSgemm_v2",
            "cublasDgemm_v2",
            "cublasSgemmEx",
            "cublasHgemmBatched",
            "cublasSgemmBatched",
            "cublasDgemmBatched",
            "cusolverDnSgesvd_bufferSize",
            "cusolverDnDgesvd_bufferSize",
            "cusolverDnSgesvd",
            "cusolverDnDgesvd"
        };

        PointerPointer functionTable = new PointerPointer((long) symbolNames.length);
        for (int slot = 0; slot < symbolNames.length; slot++) {
            functionTable.put((long) slot, Loader.addressof(symbolNames[slot]));
        }
        this.nativeOps.initializeFunctions(functionTable);
    }

    /** Assigns a fresh {@link JcublasLevel1} instance to the {@code level1} field. */
    public void createLevel1() {
        final JcublasLevel1 impl = new JcublasLevel1();
        this.level1 = impl;
    }

    /** Assigns a fresh {@link JcublasLevel2} instance to the {@code level2} field. */
    public void createLevel2() {
        final JcublasLevel2 impl = new JcublasLevel2();
        this.level2 = impl;
    }

    /** Assigns a fresh {@link JcublasLevel3} instance to the {@code level3} field. */
    public void createLevel3() {
        final JcublasLevel3 impl = new JcublasLevel3();
        this.level3 = impl;
    }

    /** Assigns a fresh {@link JcublasLapack} instance to the {@code lapack} field. */
    public void createLapack() {
        final JcublasLapack impl = new JcublasLapack();
        this.lapack = impl;
    }

    /**
     * Builds a {@link JCublasNDArray} from an int array and an existing data buffer.
     *
     * @param iArr       first constructor argument (presumably the shape — TODO confirm)
     * @param dataBuffer second constructor argument
     * @return the newly constructed array
     */
    public INDArray create(int[] iArr, DataBuffer dataBuffer) {
        final INDArray array = new JCublasNDArray(iArr, dataBuffer);
        return array;
    }

    /**
     * Builds a {@link JCublasNDArray} directly from a two-dimensional double array.
     *
     * @param dArr values handed to the array constructor
     * @return the newly constructed array
     */
    public INDArray create(double[][] dArr) {
        final INDArray array = new JCublasNDArray(dArr);
        return array;
    }

    /**
     * Builds a {@link JCublasNDArray} from a two-dimensional double array and a character argument.
     *
     * @param dArr values handed to the array constructor
     * @param c    second constructor argument (presumably the ordering — TODO confirm)
     * @return the newly constructed array
     */
    public INDArray create(double[][] dArr, char c) {
        final INDArray array = new JCublasNDArray(dArr, c);
        return array;
    }

    /**
     * Wraps an existing data buffer in a {@link JCublasNDArray}.
     *
     * @param dataBuffer buffer handed to the array constructor
     * @return the newly constructed array
     */
    public INDArray create(DataBuffer dataBuffer) {
        final INDArray array = new JCublasNDArray(dataBuffer);
        return array;
    }

    /**
     * Builds a {@link JCublasNDArray} over an existing buffer. {@code j} and {@code j2} are
     * packed into a two-element {@code long[]}, {@code iArr} is widened to {@code long[]},
     * and the global default order from {@link Nd4j#order()} is used.
     *
     * @param dataBuffer backing buffer
     * @param j          first element of the two-element long array (presumably rows — TODO confirm)
     * @param j2         second element of the two-element long array (presumably columns — TODO confirm)
     * @param iArr       int array widened and passed as the third constructor argument
     * @param j3         fourth constructor argument (presumably the offset — TODO confirm)
     * @return the newly constructed array
     */
    public INDArray create(DataBuffer dataBuffer, long j, long j2, int[] iArr, long j3) {
        final long[] pair = {j, j2};
        final INDArray array =
                new JCublasNDArray(dataBuffer, pair, ArrayUtil.toLongArray(iArr), j3, Nd4j.order().charValue());
        return array;
    }

    /**
     * Builds a {@link JCublasNDArray} from an int array and a character argument.
     *
     * @param iArr first constructor argument (presumably the shape — TODO confirm)
     * @param c    second constructor argument (presumably the ordering — TODO confirm)
     * @return the newly constructed array
     */
    public INDArray create(int[] iArr, char c) {
        final INDArray array = new JCublasNDArray(iArr, c);
        return array;
    }

    /**
     * Builds a {@link JCublasNDArray} with strides derived from {@code iArr} and {@code c}
     * via {@link Nd4j#getStrides}, a zero third argument and {@code false} as the final
     * constructor flag (presumably "do not initialise the memory" — TODO confirm).
     *
     * @param iArr array dimensions passed to both the stride helper and the constructor
     * @param c    ordering character passed to both the stride helper and the constructor
     * @return the newly constructed array
     */
    public INDArray createUninitialized(int[] iArr, char c) {
        final INDArray array = new JCublasNDArray(iArr, Nd4j.getStrides(iArr, c), 0L, c, false);
        return array;
    }

    /**
     * Builds a {@link JCublasNDArray} while the memory manager's current workspace is
     * temporarily set to {@code null}, then restores the previous workspace.
     *
     * <p>The restore happens in a {@code finally} block, so the caller's workspace is
     * reinstated even when array construction throws.
     *
     * @param iArr array dimensions passed to the stride helper and the constructor
     * @param c    ordering character passed to the stride helper and the constructor
     * @return the newly constructed array
     */
    public INDArray createUninitializedDetached(int[] iArr, char c) {
        MemoryWorkspace previousWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace();
        Nd4j.getMemoryManager().setCurrentWorkspace((MemoryWorkspace) null);
        try {
            return new JCublasNDArray(iArr, Nd4j.getStrides(iArr, c), 0L, c, false);
        } finally {
            // Always put the caller's workspace back, including on failure.
            Nd4j.getMemoryManager().setCurrentWorkspace(previousWorkspace);
        }
    }

    /**
     * Builds a {@link JCublasNDArray} over an existing buffer, forwarding all arguments
     * unchanged and in order to the five-argument constructor.
     *
     * @param dataBuffer backing buffer
     * @param iArr       second constructor argument (presumably the shape — TODO confirm)
     * @param iArr2      third constructor argument (presumably the strides — TODO confirm)
     * @param j          fourth constructor argument (presumably the offset — TODO confirm)
     * @param c          fifth constructor argument (presumably the ordering — TODO confirm)
     * @return the newly constructed array
     */
    public INDArray create(DataBuffer dataBuffer, int[] iArr, int[] iArr2, long j, char c) {
        final INDArray array = new JCublasNDArray(dataBuffer, iArr, iArr2, j, c);
        return array;
    }

    /**
     * Builds a {@link JCublasNDArray} from float data, unboxing {@code ch} before
     * forwarding it to the constructor.
     *
     * @param fArr values
     * @param iArr second constructor argument (presumably the shape — TODO confirm)
     * @param j    third constructor argument (presumably the offset — TODO confirm)
     * @param ch   boxed character; must not be {@code null} because it is unboxed here
     * @return the newly constructed array
     */
    public INDArray create(float[] fArr, int[] iArr, long j, Character ch) {
        final char unboxed = ch.charValue();
        return new JCublasNDArray(fArr, iArr, j, unboxed);
    }

    /**
     * Builds a {@link JCublasNDArray} from float data. {@code j} and {@code j2} are packed
     * into a two-element {@code long[]} and {@code iArr} is widened to {@code long[]}.
     *
     * @param fArr values
     * @param j    first element of the two-element long array (presumably rows — TODO confirm)
     * @param j2   second element of the two-element long array (presumably columns — TODO confirm)
     * @param iArr int array widened and passed as the third constructor argument
     * @param j3   fourth constructor argument (presumably the offset — TODO confirm)
     * @param c    fifth constructor argument (presumably the ordering — TODO confirm)
     * @return the newly constructed array
     */
    public INDArray create(float[] fArr, long j, long j2, int[] iArr, long j3, char c) {
        final long[] pair = {j, j2};
        final INDArray array = new JCublasNDArray(fArr, pair, ArrayUtil.toLongArray(iArr), j3, c);
        return array;
    }

    /**
     * Builds a {@link JCublasNDArray} from double data, an int array and a character.
     *
     * @param dArr values
     * @param iArr second constructor argument (presumably the shape — TODO confirm)
     * @param c    third constructor argument (presumably the ordering — TODO confirm)
     * @return the newly constructed array
     */
    public INDArray create(double[] dArr, int[] iArr, char c) {
        final INDArray array = new JCublasNDArray(dArr, iArr, c);
        return array;
    }

    /**
     * Builds a {@link JCublasNDArray} from double data, a long array and a character.
     *
     * @param dArr values
     * @param jArr second constructor argument (presumably the shape — TODO confirm)
     * @param c    third constructor argument (presumably the ordering — TODO confirm)
     * @return the newly constructed array
     */
    public INDArray create(double[] dArr, long[] jArr, char c) {
        final INDArray array = new JCublasNDArray(dArr, jArr, c);
        return array;
    }

    /**
     * Builds a {@link JCublasNDArray} from a list of arrays, an int array and a character.
     *
     * @param list arrays handed to the constructor
     * @param iArr second constructor argument (presumably the shape — TODO confirm)
     * @param c    third constructor argument (presumably the ordering — TODO confirm)
     * @return the newly constructed array
     */
    public INDArray create(List<INDArray> list, int[] iArr, char c) {
        final INDArray array = new JCublasNDArray(list, iArr, c);
        return array;
    }

    /**
     * Builds a {@link JCublasNDArray} from double data, honouring the supplied offset.
     *
     * <p>The previous implementation narrowed {@code j} to {@code char} and passed it as
     * the third constructor argument, which sibling overloads use for the ordering
     * character. The offset was therefore lost and the ordering became an arbitrary
     * code point. This version passes {@code j} in the long (offset) position of the
     * five-argument constructor, with strides computed for the global default order.
     *
     * @param dArr values
     * @param iArr array dimensions, also used to derive the strides
     * @param j    offset into {@code dArr}
     * @return the newly constructed array
     */
    public INDArray create(double[] dArr, int[] iArr, long j) {
        final char defaultOrder = Nd4j.order().charValue();
        return new JCublasNDArray(dArr, iArr, Nd4j.getStrides(iArr, defaultOrder), j, defaultOrder);
    }

    /**
     * Builds a {@link JCublasNDArray} from double data, forwarding all arguments unchanged
     * and in order to the five-argument constructor.
     *
     * @param dArr  values
     * @param iArr  second constructor argument (presumably the shape — TODO confirm)
     * @param iArr2 third constructor argument (presumably the strides — TODO confirm)
     * @param j     fourth constructor argument (presumably the offset — TODO confirm)
     * @param c     fifth constructor argument (presumably the ordering — TODO confirm)
     * @return the newly constructed array
     */
    public INDArray create(double[] dArr, int[] iArr, int[] iArr2, long j, char c) {
        final INDArray array = new JCublasNDArray(dArr, iArr, iArr2, j, c);
        return array;
    }

    /**
     * Builds a {@link JCublasNDArray} from float data, forwarding all arguments unchanged
     * and in order to the four-argument constructor.
     *
     * @param fArr  values
     * @param iArr  second constructor argument (presumably the shape — TODO confirm)
     * @param iArr2 third constructor argument (presumably the strides — TODO confirm)
     * @param j     fourth constructor argument (presumably the offset — TODO confirm)
     * @return the newly constructed array
     */
    public INDArray create(float[] fArr, int[] iArr, int[] iArr2, long j) {
        final INDArray array = new JCublasNDArray(fArr, iArr, iArr2, j);
        return array;
    }

    /**
     * Builds a {@link JCublasNDArray} from double data, forwarding all arguments unchanged
     * and in order to the four-argument constructor.
     *
     * @param dArr  values
     * @param iArr  second constructor argument (presumably the shape — TODO confirm)
     * @param iArr2 third constructor argument (presumably the strides — TODO confirm)
     * @param j     fourth constructor argument (presumably the offset — TODO confirm)
     * @return the newly constructed array
     */
    public INDArray create(double[] dArr, int[] iArr, int[] iArr2, long j) {
        final INDArray array = new JCublasNDArray(dArr, iArr, iArr2, j);
        return array;
    }

    /**
     * Builds a {@link JCublasNDArray} over an existing buffer and an int array.
     *
     * @param dataBuffer backing buffer
     * @param iArr       second constructor argument (presumably the shape — TODO confirm)
     * @return the newly constructed array
     */
    public INDArray create(DataBuffer dataBuffer, int[] iArr) {
        final INDArray array = new JCublasNDArray(dataBuffer, iArr);
        return array;
    }

    /**
     * Builds a {@link JCublasNDArray} over an existing buffer, forwarding all arguments
     * unchanged and in order to the four-argument constructor.
     *
     * @param dataBuffer backing buffer
     * @param iArr       second constructor argument (presumably the shape — TODO confirm)
     * @param iArr2      third constructor argument (presumably the strides — TODO confirm)
     * @param j          fourth constructor argument (presumably the offset — TODO confirm)
     * @return the newly constructed array
     */
    public INDArray create(DataBuffer dataBuffer, int[] iArr, int[] iArr2, long j) {
        final INDArray array = new JCublasNDArray(dataBuffer, iArr, iArr2, j);
        return array;
    }

    /**
     * Builds a {@link JCublasNDArray} from a list of arrays. When this factory's
     * {@code order} field is {@code 'f'}, strides from
     * {@code ArrayUtil.calcStridesFortran(iArr)} are passed explicitly; otherwise the
     * two-argument constructor is used.
     *
     * @param list arrays handed to the constructor
     * @param iArr second constructor argument (presumably the shape — TODO confirm)
     * @return the newly constructed array
     */
    public INDArray create(List<INDArray> list, int[] iArr) {
        if (this.order == 'f') {
            return new JCublasNDArray(list, iArr, ArrayUtil.calcStridesFortran(iArr));
        }
        return new JCublasNDArray(list, iArr);
    }

    /**
     * Builds a {@link JCublasNDArray} from float data, an int array and a long value.
     *
     * @param fArr values
     * @param iArr second constructor argument (presumably the shape — TODO confirm)
     * @param j    third constructor argument (presumably the offset — TODO confirm)
     * @return the newly constructed array
     */
    public INDArray create(float[] fArr, int[] iArr, long j) {
        final INDArray array = new JCublasNDArray(fArr, iArr, j);
        return array;
    }

    /**
     * Builds a {@link JCublasNDArray} directly from a two-dimensional float array.
     *
     * @param fArr values handed to the array constructor
     * @return the newly constructed array
     */
    public INDArray create(float[][] fArr) {
        final INDArray array = new JCublasNDArray(fArr);
        return array;
    }

    /**
     * Builds a {@link JCublasNDArray} from a two-dimensional float array and a character argument.
     *
     * @param fArr values handed to the array constructor
     * @param c    second constructor argument (presumably the ordering — TODO confirm)
     * @return the newly constructed array
     */
    public INDArray create(float[][] fArr, char c) {
        final INDArray array = new JCublasNDArray(fArr, c);
        return array;
    }

    /**
     * Builds a {@link JCublasNDArray} from float data, forwarding all arguments unchanged
     * and in order to the five-argument constructor.
     *
     * @param fArr  values
     * @param iArr  second constructor argument (presumably the shape — TODO confirm)
     * @param iArr2 third constructor argument (presumably the strides — TODO confirm)
     * @param j     fourth constructor argument (presumably the offset — TODO confirm)
     * @param c     fifth constructor argument (presumably the ordering — TODO confirm)
     * @return the newly constructed array
     */
    public INDArray create(float[] fArr, int[] iArr, int[] iArr2, long j, char c) {
        final INDArray array = new JCublasNDArray(fArr, iArr, iArr2, j, c);
        return array;
    }

    /**
     * Builds a {@link JCublasNDArray} over an existing buffer, an int array and a long value.
     *
     * @param dataBuffer backing buffer
     * @param iArr       second constructor argument (presumably the shape — TODO confirm)
     * @param j          third constructor argument (presumably the offset — TODO confirm)
     * @return the newly constructed array
     */
    public INDArray create(DataBuffer dataBuffer, int[] iArr, long j) {
        final INDArray array = new JCublasNDArray(dataBuffer, iArr, j);
        return array;
    }

    /**
     * Flattens the given arrays using the ordering returned by {@code order()}.
     *
     * @param collection arrays to flatten
     * @return result of {@link #toFlattened(char, Collection)} for that ordering
     */
    public INDArray toFlattened(Collection<INDArray> collection) {
        final char ordering = order();
        return toFlattened(ordering, collection);
    }

    /**
     * Copies every array of {@code collection}, in iteration order, into one newly created
     * {@code 1 x N} array, where {@code N} is the sum of the input lengths.
     *
     * <p>For each input one of two paths is taken:
     * <ul>
     *   <li>if the input's ordering equals {@code c} and both the input and the result
     *       have element-wise stride 1, the input's host data is copied with
     *       {@code memcpyAsync} at the running byte offset;</li>
     *   <li>otherwise the native {@code flattenDouble}/{@code flattenFloat}/{@code flattenHalf}
     *       routine is invoked, chosen by the input's data type (anything that is neither
     *       DOUBLE nor FLOAT goes to the half variant).</li>
     * </ul>
     *
     * <p>Changes from the previous revision: the byte offset is computed in {@code long}
     * arithmetic (it used to overflow {@code int} for large results), a total length above
     * {@code Integer.MAX_VALUE} now fails fast instead of wrapping, and a null check on the
     * result that could never be false was removed.
     *
     * @param c          ordering for the result
     * @param collection arrays to flatten
     * @return the flattened array
     * @throws ArithmeticException if the combined length does not fit in an {@code int}
     */
    public INDArray toFlattened(char c, Collection<INDArray> collection) {
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            Nd4j.getExecutioner().flushQueue();
        }

        long totalLength = 0L;
        for (INDArray source : collection) {
            totalLength += source.length();
        }
        // Fail loudly rather than letting the total wrap into a bogus (possibly negative) extent.
        INDArray flattened = Nd4j.create(new int[]{1, Math.toIntExact(totalLength)}, c);

        int linearIndex = 0;
        AtomicAllocator atomicAllocator = AtomicAllocator.getInstance();
        for (INDArray iNDArray : collection) {
            CudaContext prepareAction = atomicAllocator.getFlowController().prepareAction(flattened, iNDArray);
            DataBuffer.Type sourceType = iNDArray.data().dataType();

            if (iNDArray.ordering() == c && flattened.elementWiseStride() == iNDArray.elementWiseStride() && flattened.elementWiseStride() == 1) {
                int elementSize = sourceType == DataBuffer.Type.DOUBLE ? 8 : sourceType == DataBuffer.Type.FLOAT ? 4 : 2;
                // Widen before multiplying: index * elementSize can exceed Integer.MAX_VALUE.
                long byteOffset = (long) linearIndex * elementSize;
                atomicAllocator.memcpyAsync(flattened.data(), new CudaPointer(atomicAllocator.getHostPointer(iNDArray).address()), AllocationUtils.getRequiredMemory(AllocationUtils.buildAllocationShape(iNDArray)), byteOffset);
            } else {
                PointerPointer pointerPointer = new PointerPointer(new Pointer[]{AddressRetriever.retrieveHostPointer(flattened.shapeInfoDataBuffer()), prepareAction.getOldStream(), atomicAllocator.getDeviceIdPointer(), prepareAction.getBufferAllocation(), prepareAction.getBufferReduction(), prepareAction.getBufferScalar(), prepareAction.getBufferSpecial(), AddressRetriever.retrieveHostPointer(iNDArray.shapeInfoDataBuffer()), AddressRetriever.retrieveHostPointer(flattened.shapeInfoDataBuffer())});
                if (sourceType == DataBuffer.Type.DOUBLE) {
                    this.nativeOps.flattenDouble(pointerPointer, linearIndex, c, atomicAllocator.getPointer(flattened, prepareAction), atomicAllocator.getPointer(flattened.shapeInfoDataBuffer(), prepareAction), atomicAllocator.getPointer(iNDArray, prepareAction), atomicAllocator.getPointer(iNDArray.shapeInfoDataBuffer(), prepareAction));
                } else if (sourceType == DataBuffer.Type.FLOAT) {
                    this.nativeOps.flattenFloat(pointerPointer, linearIndex, c, atomicAllocator.getPointer(flattened, prepareAction), atomicAllocator.getPointer(flattened.shapeInfoDataBuffer(), prepareAction), atomicAllocator.getPointer(iNDArray, prepareAction), atomicAllocator.getPointer(iNDArray.shapeInfoDataBuffer(), prepareAction));
                } else {
                    this.nativeOps.flattenHalf(pointerPointer, linearIndex, c, atomicAllocator.getPointer(flattened, prepareAction), atomicAllocator.getPointer(flattened.shapeInfoDataBuffer(), prepareAction), atomicAllocator.getPointer(iNDArray, prepareAction), atomicAllocator.getPointer(iNDArray.shapeInfoDataBuffer(), prepareAction));
                }
            }

            // Both paths advance the element index by the length of the array just written.
            linearIndex = (int) (linearIndex + iNDArray.length());
            atomicAllocator.registerAction(prepareAction, flattened, iNDArray);
        }
        return flattened;
    }

    /**
     * Concatenates the given arrays along dimension {@code i} into a newly created array
     * by calling the native {@code concatDouble}/{@code concatFloat}/{@code concatHalf}
     * routine, chosen by the result's data type.
     *
     * <p>Steps: flush a pending grid queue; return the sole input unchanged when only one
     * array is given; decompress compressed inputs in place; check that every input agrees
     * with the first on all dimensions other than {@code i}; allocate the result; copy
     * per-input address tables (data, shape info, TAD shape info, TAD offsets) into
     * temporary buffers; invoke the native routine; register the action.
     *
     * <p>Changes from the previous revision: the result pointer is held as a plain
     * {@link Pointer} and cast once per data type (it was declared {@code DoublePointer}
     * and then cast to the unrelated {@code FloatPointer}/{@code ShortPointer} types);
     * shape validation now happens before the result is allocated and the action is
     * prepared; a second, unused accumulation of the dimension total was removed.
     *
     * @param i           dimension to concatenate along
     * @param iNDArrayArr arrays to concatenate
     * @return the concatenated array, or {@code iNDArrayArr[0]} itself when only one array is given
     * @throws IllegalArgumentException if an input differs from the first input on a dimension other than {@code i}
     */
    public INDArray concat(int i, INDArray... iNDArrayArr) {
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            Nd4j.getExecutioner().flushQueue();
        }
        if (iNDArrayArr.length == 1) {
            return iNDArrayArr[0];
        }

        // Decompress where needed and total the extent along the concat dimension.
        int sumAlongDimension = 0;
        for (int idx = 0; idx < iNDArrayArr.length; idx++) {
            if (iNDArrayArr[idx].isCompressed()) {
                Nd4j.getCompressor().decompressi(iNDArrayArr[idx]);
            }
            sumAlongDimension = (int) (sumAlongDimension + iNDArrayArr[idx].size(i));
        }
        long[] outputShape = ArrayUtil.copy(iNDArrayArr[0].shape());

        // Validate up front so that a bad input does not leave a prepared-but-unregistered action behind.
        for (int idx = 0; idx < iNDArrayArr.length; idx++) {
            for (int dim = 0; dim < iNDArrayArr[idx].rank(); dim++) {
                if (dim != i && iNDArrayArr[idx].size(dim) != outputShape[dim]) {
                    throw new IllegalArgumentException("Illegal concatenation at array " + idx + " and shape element " + dim);
                }
            }
        }
        outputShape[i] = sumAlongDimension;

        INDArray result = Nd4j.createUninitialized(outputShape, Nd4j.order().charValue());
        AtomicAllocator atomicAllocator = AtomicAllocator.getInstance();
        CudaContext prepareAction = atomicAllocator.getFlowController().prepareAction(result, iNDArrayArr);

        // Per-input address tables, filled on the host and then copied into temporary buffers.
        long[] shapeInfoAddresses = new long[iNDArrayArr.length];
        long[] dataAddresses = new long[iNDArrayArr.length];
        long[] tadShapeAddresses = new long[iNDArrayArr.length];
        long[] tadOffsetAddresses = new long[iNDArrayArr.length];
        long[] hostShapeInfoAddresses = new long[iNDArrayArr.length];
        TADManager tADManager = Nd4j.getExecutioner().getTADManager();
        for (int idx = 0; idx < iNDArrayArr.length; idx++) {
            shapeInfoAddresses[idx] = AddressRetriever.retrieveDeviceAddress(iNDArrayArr[idx].shapeInfoDataBuffer(), prepareAction);
            dataAddresses[idx] = atomicAllocator.getPointer(iNDArrayArr[idx], prepareAction).address();
            hostShapeInfoAddresses[idx] = atomicAllocator.getHostPointer(iNDArrayArr[idx].shapeInfoDataBuffer()).address();

            Pair tadBuffers = tADManager.getTADOnlyShapeInfo(iNDArrayArr[idx], new int[]{i});
            tadShapeAddresses[idx] = atomicAllocator.getPointer((DataBuffer) tadBuffers.getFirst(), prepareAction).address();
            tadOffsetAddresses[idx] = atomicAllocator.getPointer((DataBuffer) tadBuffers.getSecond(), prepareAction).address();
        }

        Pair resultTadBuffers = tADManager.getTADOnlyShapeInfo(result, new int[]{i});
        // Kept as a plain Pointer: the concrete type is selected by the data-type branch below.
        Pointer resultPointer = atomicAllocator.getPointer(result, prepareAction);
        LongPointer retrieveDevicePointer = AddressRetriever.retrieveDevicePointer(result.shapeInfoDataBuffer(), prepareAction);

        // Double buffers are used purely as 8-byte-per-slot storage for the address tables.
        CudaDoubleDataBuffer dataTable = new CudaDoubleDataBuffer(iNDArrayArr.length);
        CudaDoubleDataBuffer shapeInfoTable = new CudaDoubleDataBuffer(iNDArrayArr.length);
        CudaDoubleDataBuffer tadShapeTable = new CudaDoubleDataBuffer(iNDArrayArr.length);
        CudaDoubleDataBuffer tadOffsetTable = new CudaDoubleDataBuffer(iNDArrayArr.length);
        atomicAllocator.memcpyBlocking(dataTable, new LongPointer(dataAddresses), dataAddresses.length * 8, 0L);
        atomicAllocator.memcpyBlocking(shapeInfoTable, new LongPointer(shapeInfoAddresses), shapeInfoAddresses.length * 8, 0L);
        atomicAllocator.memcpyBlocking(tadShapeTable, new LongPointer(tadShapeAddresses), tadShapeAddresses.length * 8, 0L);
        atomicAllocator.memcpyBlocking(tadOffsetTable, new LongPointer(tadOffsetAddresses), tadOffsetAddresses.length * 8, 0L);
        Pointer dataTablePointer = atomicAllocator.getPointer(dataTable, prepareAction);
        Pointer shapeInfoTablePointer = atomicAllocator.getPointer(shapeInfoTable, prepareAction);
        Pointer tadShapeTablePointer = atomicAllocator.getPointer(tadShapeTable, prepareAction);
        Pointer tadOffsetTablePointer = atomicAllocator.getPointer(tadOffsetTable, prepareAction);

        PointerPointer pointerPointer = new PointerPointer(new Pointer[]{AddressRetriever.retrieveHostPointer(result.shapeInfoDataBuffer()), prepareAction.getOldStream(), atomicAllocator.getDeviceIdPointer(), prepareAction.getBufferAllocation(), prepareAction.getBufferReduction(), prepareAction.getBufferScalar(), prepareAction.getBufferSpecial(), AddressRetriever.retrieveHostPointer(iNDArrayArr[0].shapeInfoDataBuffer()), AddressRetriever.retrieveHostPointer(result.shapeInfoDataBuffer()), new LongPointer(hostShapeInfoAddresses), atomicAllocator.getPointer((DataBuffer) resultTadBuffers.getFirst(), prepareAction), atomicAllocator.getPointer((DataBuffer) resultTadBuffers.getSecond(), prepareAction)});

        if (result.data().dataType() == DataBuffer.Type.DOUBLE) {
            this.nativeOps.concatDouble(pointerPointer, i, iNDArrayArr.length, new PointerPointer(new Pointer[]{dataTablePointer}), new PointerPointer(new Pointer[]{shapeInfoTablePointer}), (DoublePointer) resultPointer, retrieveDevicePointer, new PointerPointer(new Pointer[]{tadShapeTablePointer}), new PointerPointer(new Pointer[]{tadOffsetTablePointer}));
        } else if (result.data().dataType() == DataBuffer.Type.FLOAT) {
            this.nativeOps.concatFloat(pointerPointer, i, iNDArrayArr.length, new PointerPointer(new Pointer[]{dataTablePointer}), new PointerPointer(new Pointer[]{shapeInfoTablePointer}), (FloatPointer) resultPointer, retrieveDevicePointer, new PointerPointer(new Pointer[]{tadShapeTablePointer}), new PointerPointer(new Pointer[]{tadOffsetTablePointer}));
        } else {
            this.nativeOps.concatHalf(pointerPointer, i, iNDArrayArr.length, new PointerPointer(new Pointer[]{dataTablePointer}), new PointerPointer(new Pointer[]{shapeInfoTablePointer}), (ShortPointer) resultPointer, retrieveDevicePointer, new PointerPointer(new Pointer[]{tadShapeTablePointer}), new PointerPointer(new Pointer[]{tadOffsetTablePointer}));
        }
        atomicAllocator.registerAction(prepareAction, result, iNDArrayArr);
        return result;
    }

    /**
     * Concatenates the given arrays along dimension {@code i} through the native
     * {@code specialConcat*} routines, which are fed host pointers, and then copies the
     * result from its host pointer to its device pointer on the context's special stream.
     *
     * <p>A single input is returned unchanged. Compressed inputs are decompressed in place
     * and every input has its host data synchronised before its pointers are collected.
     *
     * @param i           dimension to concatenate along
     * @param iNDArrayArr arrays to concatenate
     * @return the concatenated array, or {@code iNDArrayArr[0]} itself when only one array is given
     * @throws IllegalArgumentException  if an input differs from the first input on a dimension other than {@code i}
     * @throws ND4JIllegalStateException if the result's data type is not DOUBLE, FLOAT or HALF
     */
    public INDArray specialConcat(int i, INDArray... iNDArrayArr) {
        if (iNDArrayArr.length == 1) {
            return iNDArrayArr[0];
        }
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            Nd4j.getExecutioner().flushQueue();
        }

        final int arrayCount = iNDArrayArr.length;
        PointerPointer hostShapePointers = new PointerPointer(arrayCount);
        PointerPointer hostDataPointers = new PointerPointer(arrayCount);
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = (CudaContext) allocator.getDeviceContext().getContext();
        long[] outputShape = ArrayUtil.copy(iNDArrayArr[0].shape());

        int sumAlongDimension = 0;
        for (int idx = 0; idx < arrayCount; idx++) {
            INDArray current = iNDArrayArr[idx];
            if (current.isCompressed()) {
                Nd4j.getCompressor().decompressi(current);
            }
            allocator.synchronizeHostData(current);
            hostShapePointers.put(idx, allocator.getHostPointer(current.shapeInfoDataBuffer()));
            hostDataPointers.put(idx, allocator.getHostPointer(current.data()));
            sumAlongDimension = (int) (sumAlongDimension + current.size(i));

            // Every dimension except the concat dimension has to match the first input.
            for (int dim = 0; dim < current.rank(); dim++) {
                if (dim != i && current.size(dim) != outputShape[dim]) {
                    throw new IllegalArgumentException("Illegal concatenation at array " + idx + " and shape element " + dim);
                }
            }
        }
        outputShape[i] = sumAlongDimension;

        PointerPointer extras = new PointerPointer(new Pointer[]{null});
        INDArray result = Nd4j.createUninitialized(outputShape, Nd4j.order().charValue());
        DataBuffer.Type resultType = result.data().dataType();

        if (resultType == DataBuffer.Type.DOUBLE) {
            this.nativeOps.specialConcatDouble(extras, i, arrayCount, hostDataPointers, hostShapePointers, result.data().addressPointer(), result.shapeInfoDataBuffer().addressPointer(), new PointerPointer(new Pointer[]{null}), new PointerPointer(new Pointer[]{null}));
        } else if (resultType == DataBuffer.Type.FLOAT) {
            this.nativeOps.specialConcatFloat(extras, i, arrayCount, hostDataPointers, hostShapePointers, result.data().addressPointer(), result.shapeInfoDataBuffer().addressPointer(), new PointerPointer(new Pointer[]{null}), new PointerPointer(new Pointer[]{null}));
        } else if (resultType == DataBuffer.Type.HALF) {
            this.nativeOps.specialConcatHalf(extras, i, arrayCount, hostDataPointers, hostShapePointers, result.data().addressPointer(), result.shapeInfoDataBuffer().addressPointer(), new PointerPointer(new Pointer[]{null}), new PointerPointer(new Pointer[]{null}));
        } else {
            throw new ND4JIllegalStateException("Unknown dataType: " + resultType);
        }

        // Copy the host-side result to the device and wait for the copy to finish.
        AllocationPoint resultPoint = allocator.getAllocationPoint(result);
        long transactionStart = PerformanceTracker.getInstance().helperStartTransaction();
        long resultBytes = result.lengthLong() * Nd4j.sizeOfDataType(result.data().dataType());
        this.nativeOps.memcpyAsync(resultPoint.getDevicePointer(), resultPoint.getHostPointer(), resultBytes, CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream());
        context.getSpecialStream().synchronize();
        PerformanceTracker.getInstance().helperRegisterTransaction(resultPoint.getDeviceId(), transactionStart, resultPoint.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE);

        resultPoint.tickHostRead();
        resultPoint.tickDeviceWrite();
        return result;
    }

    /**
     * Pulls the indexed slices of {@code iNDArray} into a new array laid out in the global
     * default order from {@link Nd4j#order()}.
     *
     * @param iNDArray source array
     * @param i        dimension argument forwarded to the four-argument overload
     * @param iArr     indexes to pull
     * @return result of {@link #pullRows(INDArray, int, int[], char)}
     */
    public INDArray pullRows(INDArray iNDArray, int i, int[] iArr) {
        final char defaultOrder = Nd4j.order().charValue();
        return pullRows(iNDArray, i, iArr, defaultOrder);
    }

    /**
     * Long-index variant: converts {@code jArr} with {@code ArrayUtil.toInts} and delegates
     * to {@link #pullRows(INDArray, int, int[])}.
     *
     * @param iNDArray source array
     * @param i        dimension argument forwarded to the int-index overload
     * @param jArr     indexes to pull
     * @return result of the int-index overload
     */
    public INDArray pullRows(INDArray iNDArray, int i, long[] jArr) {
        final int[] intIndexes = ArrayUtil.toInts(jArr);
        return pullRows(iNDArray, i, intIndexes);
    }

    /**
     * Allocates an uninitialised destination with ordering {@code c} and delegates to
     * {@link #pullRows(INDArray, INDArray, int, int[])}.
     *
     * <p>The destination shape is {@code {iArr.length, shape[1]}} when {@code i == 1} and
     * {@code {shape[0], iArr.length}} when {@code i == 0}.
     *
     * @param iNDArray source array
     * @param i        dimension; must be 0 or 1
     * @param iArr     indexes to pull; must be non-null and non-empty
     * @param c        ordering of the destination array
     * @return the destination array returned by the delegate
     * @throws IllegalStateException         if {@code iArr} is null or empty
     * @throws UnsupportedOperationException if {@code i} is neither 0 nor 1
     */
    public INDArray pullRows(INDArray iNDArray, int i, int[] iArr, char c) {
        if (!(iArr != null && iArr.length >= 1)) {
            throw new IllegalStateException("Indexes can't be null or zero-length");
        }
        if (i != 0 && i != 1) {
            throw new UnsupportedOperationException("2D input is expected");
        }

        final long[] destinationShape = (i == 1)
                ? new long[]{iArr.length, iNDArray.shape()[i]}
                : new long[]{iNDArray.shape()[i], iArr.length};
        INDArray destination = Nd4j.createUninitialized(destinationShape, c);
        return pullRows(iNDArray, destination, i, iArr);
    }

    /**
     * Pulls the slices of {@code iNDArray} selected by {@code iArr} into a destination
     * array through the native {@code pullRowsDouble}/{@code pullRowsFloat}/{@code pullRowsHalf}
     * routine, chosen by the destination's data type.
     *
     * <p>The expected destination shape is {@code {iArr.length, shape[1]}} when
     * {@code i == 1} and {@code {shape[0], iArr.length}} when {@code i == 0}. When
     * {@code iNDArray2} is {@code null} an uninitialised array of that shape is created in
     * this factory's order; otherwise its shape must match exactly.
     *
     * <p>Change from the previous revision: the source and destination data pointers are
     * held as plain {@link Pointer}s and cast once per data type. They were declared
     * {@code DoublePointer} and then cast to the unrelated {@code FloatPointer}/
     * {@code ShortPointer} types.
     *
     * @param iNDArray  source array
     * @param iNDArray2 destination array, or {@code null} to have one allocated
     * @param i         dimension; must be 0 or 1
     * @param iArr      indexes to pull; must be non-null and non-empty
     * @return the destination array
     * @throws IllegalStateException         if {@code iArr} is null or empty, or the destination shape does not match
     * @throws UnsupportedOperationException if {@code i} is neither 0 nor 1
     */
    public INDArray pullRows(INDArray iNDArray, INDArray iNDArray2, int i, int[] iArr) {
        long[] expectedShape;
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            Nd4j.getExecutioner().flushQueue();
        }
        if (iArr == null || iArr.length < 1) {
            throw new IllegalStateException("Indexes can't be null or zero-length");
        }
        if (i == 1) {
            expectedShape = new long[]{iArr.length, iNDArray.shape()[i]};
        } else {
            if (i != 0) {
                throw new UnsupportedOperationException("2D input is expected");
            }
            expectedShape = new long[]{iNDArray.shape()[i], iArr.length};
        }

        INDArray destination = iNDArray2;
        if (destination == null) {
            destination = Nd4j.createUninitialized(expectedShape, this.order);
        } else if (!Arrays.equals(expectedShape, iNDArray2.shape())) {
            throw new IllegalStateException("Cannot pull rows into destination array: expected destination array of shape " + Arrays.toString(expectedShape) + " but got destination array of shape " + Arrays.toString(iNDArray2.shape()));
        }

        AtomicAllocator atomicAllocator = AtomicAllocator.getInstance();
        CudaContext prepareAction = atomicAllocator.getFlowController().prepareAction(destination, iNDArray);

        // Data pointers stay untyped here; the concrete pointer type depends on the branch below.
        Pointer sourcePointer = atomicAllocator.getPointer(iNDArray, prepareAction);
        LongPointer sourceShapePointer = atomicAllocator.getPointer(iNDArray.shapeInfoDataBuffer(), prepareAction);
        Pointer destinationPointer = atomicAllocator.getPointer(destination, prepareAction);
        LongPointer destinationShapePointer = atomicAllocator.getPointer(destination.shapeInfoDataBuffer(), prepareAction);
        PointerPointer pointerPointer = new PointerPointer(new Pointer[]{AddressRetriever.retrieveHostPointer(destination.shapeInfoDataBuffer()), prepareAction.getOldStream(), atomicAllocator.getDeviceIdPointer()});

        // Indexes are widened to long and copied into a temporary buffer for the native call.
        CudaLongDataBuffer indexBuffer = new CudaLongDataBuffer(iArr.length);
        atomicAllocator.memcpyBlocking(indexBuffer, new LongPointer(ArrayUtil.toLongArray(iArr)), iArr.length * 8, 0L);
        LongPointer indexPointer = atomicAllocator.getPointer(indexBuffer, prepareAction);

        TADManager tADManager = Nd4j.getExecutioner().getTADManager();
        Pair sourceTad = tADManager.getTADOnlyShapeInfo(iNDArray, new int[]{i});
        Pair destinationTad = tADManager.getTADOnlyShapeInfo(destination, new int[]{i});
        LongPointer sourceTadShape = atomicAllocator.getPointer((DataBuffer) sourceTad.getFirst(), prepareAction);
        LongPointer destinationTadShape = atomicAllocator.getPointer((DataBuffer) destinationTad.getFirst(), prepareAction);
        Pointer sourceTadOffsets = atomicAllocator.getPointer((DataBuffer) sourceTad.getSecond(), prepareAction);
        Pointer destinationTadOffsets = atomicAllocator.getPointer((DataBuffer) destinationTad.getSecond(), prepareAction);

        if (destination.data().dataType() == DataBuffer.Type.DOUBLE) {
            this.nativeOps.pullRowsDouble(pointerPointer, (DoublePointer) sourcePointer, sourceShapePointer, (DoublePointer) destinationPointer, destinationShapePointer, iArr.length, indexPointer, sourceTadShape, new LongPointerWrapper(sourceTadOffsets), destinationTadShape, new LongPointerWrapper(destinationTadOffsets));
        } else if (destination.data().dataType() == DataBuffer.Type.FLOAT) {
            this.nativeOps.pullRowsFloat(pointerPointer, (FloatPointer) sourcePointer, sourceShapePointer, (FloatPointer) destinationPointer, destinationShapePointer, iArr.length, indexPointer, sourceTadShape, new LongPointerWrapper(sourceTadOffsets), destinationTadShape, new LongPointerWrapper(destinationTadOffsets));
        } else {
            this.nativeOps.pullRowsHalf(pointerPointer, (ShortPointer) sourcePointer, sourceShapePointer, (ShortPointer) destinationPointer, destinationShapePointer, iArr.length, indexPointer, sourceTadShape, new LongPointerWrapper(sourceTadOffsets), destinationTadShape, new LongPointerWrapper(destinationTadOffsets));
        }
        atomicAllocator.registerAction(prepareAction, destination, iNDArray);
        return destination;
    }

    /**
     * Combines {@code iNDArrayArr} into {@code iNDArray} through the native
     * {@code accumulateDouble}/{@code accumulateFloat}/{@code accumulateHalf} routine,
     * chosen by the target's data type.
     *
     * <p>With a single input the target is simply assigned that input. Otherwise one of
     * two paths is taken:
     * <ul>
     *   <li><b>host path</b> (cross-device access disallowed or P2P unavailable): inputs are
     *       auto-decompressed, host pointers are passed to the native routine, and the
     *       target is ticked as host-written;</li>
     *   <li><b>device path</b>: device addresses of the inputs are copied into a temporary
     *       buffer whose pointer is passed to the native routine.</li>
     * </ul>
     *
     * <p>Change from the previous revision: on the device path the target pointer is held
     * as a plain {@link Pointer} and cast once per data type. It was declared
     * {@code DoublePointer} and then cast to the unrelated {@code FloatPointer}/
     * {@code ShortPointer} types.
     *
     * @param iNDArray    target array that receives the result
     * @param iNDArrayArr input arrays; each must have element-wise stride 1 and the target's length
     * @return {@code iNDArray}
     * @throws RuntimeException          if {@code iNDArrayArr} is null or empty
     * @throws ND4JIllegalStateException if an input is not contiguous or has a different length
     */
    public INDArray accumulate(INDArray iNDArray, INDArray... iNDArrayArr) {
        if (iNDArrayArr == null || iNDArrayArr.length == 0) {
            throw new RuntimeException("Input arrays are missing");
        }
        if (iNDArrayArr.length == 1) {
            return iNDArray.assign(iNDArrayArr[0]);
        }

        if (!CudaEnvironment.getInstance().getConfiguration().isCrossDeviceAccessAllowed() || !this.nativeOps.isP2PAvailable()) {
            // Host path: operate on host pointers only.
            long hostLength = iNDArray.lengthLong();
            Nd4j.getExecutioner().commit();
            CudaContext hostContext = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
            PointerPointer hostInputs = new PointerPointer(iNDArrayArr.length);
            PointerPointer hostExtras = new PointerPointer(new Pointer[]{null, hostContext.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), new CudaPointer(1L)});
            for (int idx = 0; idx < iNDArrayArr.length; idx++) {
                Nd4j.getCompressor().autoDecompress(iNDArrayArr[idx]);
                if (iNDArrayArr[idx].elementWiseStride() != 1) {
                    throw new ND4JIllegalStateException("Native averaging is applicable only to continuous INDArrays");
                }
                if (iNDArrayArr[idx].lengthLong() != hostLength) {
                    throw new ND4JIllegalStateException("All arrays should have equal length for averaging");
                }
                hostInputs.put(idx, AtomicAllocator.getInstance().getHostPointer(iNDArrayArr[idx]));
            }
            if (iNDArray.data().dataType() == DataBuffer.Type.DOUBLE) {
                this.nativeOps.accumulateDouble(hostExtras, hostInputs, AtomicAllocator.getInstance().getHostPointer(iNDArray), iNDArrayArr.length, hostLength);
            } else if (iNDArray.data().dataType() == DataBuffer.Type.FLOAT) {
                this.nativeOps.accumulateFloat(hostExtras, hostInputs, AtomicAllocator.getInstance().getHostPointer(iNDArray), iNDArrayArr.length, hostLength);
            } else {
                this.nativeOps.accumulateHalf(hostExtras, hostInputs, AtomicAllocator.getInstance().getHostPointer(iNDArray), iNDArrayArr.length, hostLength);
            }
            AtomicAllocator.getInstance().getAllocationPoint(iNDArray).tickHostWrite();
            return iNDArray;
        }

        // Device path: pass device addresses of the inputs through a temporary buffer.
        Nd4j.getExecutioner().push();
        long deviceLength = iNDArray.lengthLong();
        AtomicAllocator atomicAllocator = AtomicAllocator.getInstance();
        CudaContext prepareAction = atomicAllocator.getFlowController().prepareAction(iNDArray, iNDArrayArr);
        PointerPointer deviceExtras = new PointerPointer(new Pointer[]{null, prepareAction.getOldStream(), atomicAllocator.getDeviceIdPointer(), new CudaPointer(0L)});
        // Untyped on purpose: the concrete pointer type depends on the data-type branch below.
        Pointer targetPointer = atomicAllocator.getPointer(iNDArray, prepareAction);

        long[] inputAddresses = new long[iNDArrayArr.length];
        for (int idx = 0; idx < iNDArrayArr.length; idx++) {
            if (iNDArrayArr[idx].elementWiseStride() != 1) {
                throw new ND4JIllegalStateException("Native averaging is applicable only to continuous INDArrays");
            }
            if (iNDArrayArr[idx].lengthLong() != deviceLength) {
                throw new ND4JIllegalStateException("All arrays should have equal length for averaging");
            }
            AllocationPoint inputPoint = atomicAllocator.getAllocationPoint(iNDArrayArr[idx]);
            inputAddresses[idx] = inputPoint.getPointers().getDevicePointer().address();
            inputPoint.tickDeviceWrite();
        }

        // A double buffer is used purely as 8-byte-per-slot storage for the address table.
        CudaDoubleDataBuffer addressTable = new CudaDoubleDataBuffer(iNDArrayArr.length);
        atomicAllocator.memcpyBlocking(addressTable, new LongPointer(inputAddresses), inputAddresses.length * 8, 0L);
        PointerPointer deviceInputs = new PointerPointer(atomicAllocator.getPointer(addressTable, prepareAction));

        if (iNDArray.data().dataType() == DataBuffer.Type.DOUBLE) {
            this.nativeOps.accumulateDouble(deviceExtras, deviceInputs, (DoublePointer) targetPointer, iNDArrayArr.length, deviceLength);
        } else if (iNDArray.data().dataType() == DataBuffer.Type.FLOAT) {
            this.nativeOps.accumulateFloat(deviceExtras, deviceInputs, (FloatPointer) targetPointer, iNDArrayArr.length, deviceLength);
        } else {
            this.nativeOps.accumulateHalf(deviceExtras, deviceInputs, (ShortPointer) targetPointer, iNDArrayArr.length, deviceLength);
        }
        atomicAllocator.getFlowController().registerAction(prepareAction, iNDArray, iNDArrayArr);

        // Result is discarded. NOTE(review): presumably this call only keeps the temporary
        // buffer reachable until the native call has been issued — confirm before removing.
        addressTable.address();
        return iNDArray;
    }

    /**
     * Averages the given source arrays through the native averaging routines and, when a target is
     * supplied, stores the result in it.
     *
     * <p>Two execution paths exist. When peer-to-peer access is unavailable or cross-device access
     * is not allowed by the configuration, host pointers of the arrays are handed to the native
     * call. Otherwise device pointers are collected and handed to the native call instead.
     *
     * @param iNDArray    target array; may be {@code null}, in which case {@code null} is passed to
     *                    the native call as the target and {@code null} is returned
     * @param iNDArrayArr source arrays; must be non-null and non-empty
     * @return {@code iNDArray} (for a single source array: the result of {@code assign}, or
     *         {@code null} when the target is {@code null})
     * @throws RuntimeException          if {@code iNDArrayArr} is {@code null} or empty
     * @throws ND4JIllegalStateException if any source array has an element-wise stride other than 1
     *                                   or a length different from the reference length
     */
    public INDArray average(INDArray iNDArray, INDArray[] iNDArrayArr) {
        if (iNDArrayArr == null || iNDArrayArr.length == 0) {
            throw new RuntimeException("Input arrays are missing");
        }
        // Single input: nothing to average, just copy it into the target (if there is one).
        if (iNDArrayArr.length == 1) {
            if (iNDArray == null) {
                return null;
            }
            return iNDArray.assign(iNDArrayArr[0]);
        }
        // Host-pointer path: taken when P2P is unavailable or cross-device access is disallowed.
        if (!this.nativeOps.isP2PAvailable() || !CudaEnvironment.getInstance().getConfiguration().isCrossDeviceAccessAllowed()) {
            // Reference length comes from the target when present, otherwise from the first source.
            long lengthLong = iNDArray == null ? iNDArrayArr[0].lengthLong() : iNDArray.lengthLong();
            CudaContext cudaContext = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
            PointerPointer pointerPointer = new PointerPointer(iNDArrayArr.length);
            // NOTE(review): the trailing CudaPointer(1L) differs from the CudaPointer(0L) used on the
            // device path below; presumably a mode flag read by the native side — confirm.
            PointerPointer pointerPointer2 = new PointerPointer(new Pointer[]{null, cudaContext.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), new CudaPointer(1L)});
            for (int i = 0; i < iNDArrayArr.length; i++) {
                Nd4j.getCompressor().autoDecompress(iNDArrayArr[i]);
                if (iNDArrayArr[i].elementWiseStride() != 1) {
                    throw new ND4JIllegalStateException("Native averaging is applicable only to continuous INDArrays");
                }
                if (iNDArrayArr[i].lengthLong() != lengthLong) {
                    throw new ND4JIllegalStateException("All arrays should have equal length for averaging");
                }
                pointerPointer.put(i, AtomicAllocator.getInstance().getHostPointer(iNDArrayArr[i]));
            }
            // The native routine is selected by the data type of the first source array.
            if (iNDArrayArr[0].data().dataType() == DataBuffer.Type.DOUBLE) {
                this.nativeOps.averageDouble(pointerPointer2, pointerPointer, iNDArray == null ? null : AtomicAllocator.getInstance().getHostPointer(iNDArray), iNDArrayArr.length, lengthLong, true);
            } else if (iNDArrayArr[0].data().dataType() == DataBuffer.Type.FLOAT) {
                this.nativeOps.averageFloat(pointerPointer2, pointerPointer, iNDArray == null ? null : AtomicAllocator.getInstance().getHostPointer(iNDArray), iNDArrayArr.length, lengthLong, true);
            } else {
                this.nativeOps.averageHalf(pointerPointer2, pointerPointer, iNDArray == null ? null : AtomicAllocator.getInstance().getHostPointer(iNDArray), iNDArrayArr.length, lengthLong, true);
            }
            // Mark target and every source as written on the host side.
            // NOTE(review): sources are ticked as well, which suggests the native call (last
            // argument true) also writes into them — confirm against the native implementation.
            if (iNDArray != null) {
                AtomicAllocator.getInstance().getAllocationPoint(iNDArray).tickHostWrite();
            }
            for (INDArray iNDArray2 : iNDArrayArr) {
                AtomicAllocator.getInstance().getAllocationPoint(iNDArray2).tickHostWrite();
            }
            return iNDArray;
        }
        // Device-pointer path.
        Nd4j.getExecutioner().push();
        long lengthLong2 = iNDArray != null ? iNDArray.lengthLong() : iNDArrayArr[0].lengthLong();
        AtomicAllocator atomicAllocator = AtomicAllocator.getInstance();
        CudaContext prepareAction = atomicAllocator.getFlowController().prepareAction(iNDArray, iNDArrayArr);
        PointerPointer pointerPointer3 = new PointerPointer(new Pointer[]{null, prepareAction.getOldStream(), atomicAllocator.getDeviceIdPointer(), new CudaPointer(0L)});
        Pointer pointer = iNDArray == null ? null : AtomicAllocator.getInstance().getPointer(iNDArray, prepareAction);
        // Raw device addresses of the sources, later copied into a buffer the native call reads.
        long[] jArr = new long[iNDArrayArr.length];
        for (int i2 = 0; i2 < iNDArrayArr.length; i2++) {
            if (iNDArrayArr[i2].elementWiseStride() != 1) {
                throw new ND4JIllegalStateException("Native averaging is applicable only to continuous INDArrays");
            }
            if (iNDArrayArr[i2].lengthLong() != lengthLong2) {
                throw new ND4JIllegalStateException("All arrays should have equal length for averaging");
            }
            AllocationPoint allocationPoint = atomicAllocator.getAllocationPoint(iNDArrayArr[i2]);
            jArr[i2] = allocationPoint.getPointers().getDevicePointer().address();
            allocationPoint.tickDeviceWrite();
        }
        // A double buffer is used purely as 8-byte-per-element storage for the address table.
        CudaDoubleDataBuffer cudaDoubleDataBuffer = new CudaDoubleDataBuffer(iNDArrayArr.length);
        atomicAllocator.memcpyBlocking(cudaDoubleDataBuffer, new LongPointer(jArr), jArr.length * 8, 0L);
        PointerPointer pointerPointer4 = new PointerPointer(AtomicAllocator.getInstance().getPointer(cudaDoubleDataBuffer, prepareAction));
        if (iNDArrayArr[0].data().dataType() == DataBuffer.Type.DOUBLE) {
            this.nativeOps.averageDouble(pointerPointer3, pointerPointer4, iNDArray == null ? null : (DoublePointer) pointer, iNDArrayArr.length, lengthLong2, true);
        } else if (iNDArrayArr[0].data().dataType() == DataBuffer.Type.FLOAT) {
            this.nativeOps.averageFloat(pointerPointer3, pointerPointer4, iNDArray == null ? null : (FloatPointer) pointer, iNDArrayArr.length, lengthLong2, true);
        } else {
            this.nativeOps.averageHalf(pointerPointer3, pointerPointer4, iNDArray == null ? null : (ShortPointer) pointer, iNDArrayArr.length, lengthLong2, true);
        }
        atomicAllocator.getFlowController().registerAction(prepareAction, iNDArray, iNDArrayArr);
        // Result is discarded; presumably this only keeps the address-table buffer reachable until
        // after the native call — TODO confirm.
        cudaDoubleDataBuffer.address();
        return iNDArray;
    }

    /**
     * Averages every array in the collection by delegating to the array-based overload.
     *
     * @param collection arrays to average
     * @return the value returned by {@code average(INDArray[])}
     */
    public INDArray average(Collection<INDArray> collection) {
        INDArray[] inputs = collection.toArray(new INDArray[0]);
        return average(inputs);
    }

    /**
     * Averages the given arrays into a newly created, uninitialized array that has the shape and
     * ordering of the first input.
     *
     * @param iNDArrayArr arrays to average; must be non-null and non-empty
     * @return the value returned by {@code average(INDArray, INDArray[])}
     * @throws RuntimeException if {@code iNDArrayArr} is {@code null} or empty
     */
    public INDArray average(INDArray[] iNDArrayArr) {
        if (iNDArrayArr != null && iNDArrayArr.length != 0) {
            INDArray first = iNDArrayArr[0];
            INDArray target = Nd4j.createUninitialized(first.shape(), first.ordering());
            return average(target, iNDArrayArr);
        }
        throw new RuntimeException("Input arrays are missing");
    }

    /**
     * Averages every array in the collection into the given target by delegating to the
     * array-based overload.
     *
     * @param iNDArray   target array
     * @param collection arrays to average
     * @return the value returned by {@code average(INDArray, INDArray[])}
     */
    public INDArray average(INDArray iNDArray, Collection<INDArray> collection) {
        INDArray[] inputs = collection.toArray(new INDArray[0]);
        return average(iNDArray, inputs);
    }

    /**
     * Shuffles a single array along the given dimensions by wrapping it in a one-element list and
     * delegating to the collection-based overload.
     *
     * @param iNDArray array to shuffle
     * @param random   source of randomness
     * @param iArr     dimensions forwarded to the delegate
     */
    public void shuffle(INDArray iNDArray, Random random, int... iArr) {
        List<INDArray> single = Collections.singletonList(iNDArray);
        shuffle(single, random, iArr);
    }

    /**
     * Shuffles all arrays in {@code list} with one shared permutation, using the native shuffle
     * routines.
     *
     * <p>{@code list2} supplies the dimensions: either a single entry used for every array, or one
     * entry per array.
     *
     * @param list   arrays to shuffle; must be non-null and non-empty
     * @param random source of randomness used to build the permutation
     * @param list2  dimensions; must be non-null and non-empty, and when it has more than one entry
     *               its size must equal {@code list.size()}
     * @throws RuntimeException          if {@code list2} or {@code list} is {@code null} or empty
     * @throws IllegalStateException     if several dimension entries are given but their count
     *                                   differs from the number of arrays
     * @throws ND4JIllegalStateException if an array's TAD offsets buffer length differs from the
     *                                   TAD count derived from the first array
     */
    public void shuffle(List<INDArray> list, Random random, List<int[]> list2) {
        if (list2 == null || list2.size() == 0) {
            throw new RuntimeException("Dimension can't be null or 0-length");
        }
        if (list == null || list.size() == 0) {
            throw new RuntimeException("No input arrays provided");
        }
        if (list2.size() > 1 && list.size() != list2.size()) {
            throw new IllegalStateException("Number of dimensions do not match number of arrays to shuffle");
        }
        Nd4j.getExecutioner().push();
        AtomicAllocator atomicAllocator = AtomicAllocator.getInstance();
        // prepareAction is invoked for every array; only the context from the last call is kept.
        CudaContext cudaContext = null;
        for (int i = 0; i < list.size(); i++) {
            cudaContext = atomicAllocator.getFlowController().prepareAction(list.get(i), new INDArray[0]);
        }
        // i2 = product of the first array's extents along the first dimension entry.
        int i2 = 1;
        for (int i3 = 0; i3 < list2.get(0).length; i3++) {
            i2 = (int) (i2 * list.get(0).shape()[list2.get(0)[i3]]);
        }
        // length = element count of the first array divided by that product (the TAD count that
        // every array is checked against below).
        long length = list.get(0).length() / i2;
        CudaIntDataBuffer cudaIntDataBuffer = new CudaIntDataBuffer(ArrayUtil.buildInterleavedVector(random, (int) length));
        Pointer pointer = atomicAllocator.getPointer(cudaIntDataBuffer, cudaContext);
        PointerPointer pointerPointer = new PointerPointer(new Pointer[]{null, cudaContext.getOldStream(), atomicAllocator.getDeviceIdPointer()});
        // Per-array address tables: data, shape info, TAD shape info, TAD offsets.
        long[] jArr = new long[list.size()];
        long[] jArr2 = new long[list.size()];
        long[] jArr3 = new long[list.size()];
        long[] jArr4 = new long[list.size()];
        for (int i4 = 0; i4 < list.size(); i4++) {
            INDArray iNDArray = list.get(i4);
            Pointer pointer2 = AtomicAllocator.getInstance().getPointer(iNDArray, cudaContext);
            Pointer pointer3 = AtomicAllocator.getInstance().getPointer(iNDArray.shapeInfoDataBuffer(), cudaContext);
            Pair tADOnlyShapeInfo = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(iNDArray, list2.size() > 1 ? list2.get(i4) : list2.get(0));
            Pointer pointer4 = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getFirst(), cudaContext);
            DataBuffer dataBuffer = (DataBuffer) tADOnlyShapeInfo.getSecond();
            if (dataBuffer.length() != length) {
                throw new ND4JIllegalStateException("Can't symmetrically shuffle arrays with non-equal number of TADs");
            }
            Pointer pointer5 = AtomicAllocator.getInstance().getPointer(dataBuffer, cudaContext);
            jArr[i4] = pointer2.address();
            jArr2[i4] = pointer3.address();
            jArr3[i4] = pointer4.address();
            jArr4[i4] = pointer5.address();
        }
        // Double buffers serve as 8-byte-per-element storage for the address tables above.
        CudaDoubleDataBuffer cudaDoubleDataBuffer = new CudaDoubleDataBuffer(list.size());
        CudaDoubleDataBuffer cudaDoubleDataBuffer2 = new CudaDoubleDataBuffer(list.size());
        CudaDoubleDataBuffer cudaDoubleDataBuffer3 = new CudaDoubleDataBuffer(list.size());
        CudaDoubleDataBuffer cudaDoubleDataBuffer4 = new CudaDoubleDataBuffer(list.size());
        // All four tables have list.size() entries, so jArr.length is the right size for each copy.
        AtomicAllocator.getInstance().memcpyBlocking(cudaDoubleDataBuffer, new LongPointer(jArr), jArr.length * 8, 0L);
        AtomicAllocator.getInstance().memcpyBlocking(cudaDoubleDataBuffer2, new LongPointer(jArr2), jArr.length * 8, 0L);
        AtomicAllocator.getInstance().memcpyBlocking(cudaDoubleDataBuffer3, new LongPointer(jArr3), jArr.length * 8, 0L);
        AtomicAllocator.getInstance().memcpyBlocking(cudaDoubleDataBuffer4, new LongPointer(jArr4), jArr.length * 8, 0L);
        // The data and shape-info tables are each passed twice to the native call.
        // NOTE(review): presumably as input and output, i.e. an in-place shuffle — confirm.
        // The routine is chosen from the global Nd4j.dataType(), not from the arrays themselves.
        if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
            this.nativeOps.shuffleDouble(pointerPointer, new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer, cudaContext)), new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer2, cudaContext)), new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer, cudaContext)), new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer2, cudaContext)), list.size(), (IntPointer) pointer, new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer3, cudaContext)), new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer4, cudaContext)));
        } else if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
            this.nativeOps.shuffleFloat(pointerPointer, new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer, cudaContext)), new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer2, cudaContext)), new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer, cudaContext)), new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer2, cudaContext)), list.size(), (IntPointer) pointer, new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer3, cudaContext)), new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer4, cudaContext)));
        } else {
            this.nativeOps.shuffleHalf(pointerPointer, new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer, cudaContext)), new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer2, cudaContext)), new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer, cudaContext)), new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer2, cudaContext)), list.size(), (IntPointer) pointer, new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer3, cudaContext)), new PointerPointer(atomicAllocator.getPointer(cudaDoubleDataBuffer4, cudaContext)));
        }
        for (int i5 = 0; i5 < list.size(); i5++) {
            atomicAllocator.getFlowController().registerAction(cudaContext, list.get(i5), new INDArray[0]);
        }
        // Results of the calls below are discarded; presumably they only keep the temporary
        // buffers reachable until after the native call — TODO confirm.
        cudaIntDataBuffer.address();
        cudaDoubleDataBuffer.dataType();
        cudaDoubleDataBuffer2.dataType();
        cudaDoubleDataBuffer4.dataType();
        cudaDoubleDataBuffer3.dataType();
    }

    /**
     * Shuffles all arrays in the collection along the same dimensions.
     *
     * <p>The collection is copied into a typed list (previously a raw {@code ArrayList}, which
     * produced an unchecked conversion) and the dimensions are wrapped into a single-entry list
     * before delegating to the list-based overload.
     *
     * @param collection arrays to shuffle
     * @param random     source of randomness
     * @param iArr       dimensions applied to every array
     */
    public void shuffle(Collection<INDArray> collection, Random random, int... iArr) {
        List<INDArray> arrays = new ArrayList<INDArray>(collection);
        shuffle(arrays, random, Collections.singletonList(iArr));
    }

    /**
     * Converts the data of the given array between extended types and installs the converted
     * buffer into the array in place.
     *
     * @param typeEx   source extended type
     * @param iNDArray array whose buffer is converted and replaced; must not be a view
     * @param typeEx2  target extended type
     * @return {@code iNDArray}, flagged as compressed exactly when the new buffer is a
     *         {@link CompressedDataBuffer}
     * @throws UnsupportedOperationException if {@code iNDArray} is a view
     */
    public INDArray convertDataEx(DataBuffer.TypeEx typeEx, INDArray iNDArray, DataBuffer.TypeEx typeEx2) {
        if (iNDArray.isView()) {
            throw new UnsupportedOperationException("Impossible to compress View. Consider using dup() before. ");
        }
        DataBuffer converted = convertDataEx(typeEx, iNDArray.data(), typeEx2);
        iNDArray.setData(converted);
        // The compressed flag simply mirrors the runtime type of the converted buffer.
        iNDArray.markAsCompressed(converted instanceof CompressedDataBuffer);
        return iNDArray;
    }

    /**
     * Invokes the native type conversion on raw pointers, using the old stream of the current
     * device context.
     *
     * @param typeEx   source extended type; its ordinal is passed to the native call
     * @param pointer  source pointer
     * @param typeEx2  target extended type; its ordinal is passed to the native call
     * @param pointer2 target pointer
     * @param j        length argument forwarded to the native call
     */
    public void convertDataEx(DataBuffer.TypeEx typeEx, Pointer pointer, DataBuffer.TypeEx typeEx2, Pointer pointer2, long j) {
        CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
        Pointer[] extras = {null, context.getOldStream()};
        PointerPointer extraPointers = new PointerPointer(extras);
        this.nativeOps.convertTypes(extraPointers, typeEx.ordinal(), pointer, j, typeEx2.ordinal(), pointer2);
    }

    /**
     * Converts data behind a raw pointer into a compressed target buffer, staging both source and
     * result in temporary device allocations.
     *
     * @param typeEx     source extended type
     * @param pointer    source pointer; {@code originalLength} bytes are copied from it
     * @param typeEx2    target extended type
     * @param dataBuffer target buffer; must be a {@link CompressedDataBuffer}
     * @throws UnsupportedOperationException if {@code dataBuffer} is not a
     *                                       {@link CompressedDataBuffer}
     */
    public void convertDataEx(DataBuffer.TypeEx typeEx, Pointer pointer, DataBuffer.TypeEx typeEx2, DataBuffer dataBuffer) {
        CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
        cudaStream_t stream = context.getOldStream();
        if (!(dataBuffer instanceof CompressedDataBuffer)) {
            throw new UnsupportedOperationException();
        }
        CompressedDataBuffer compressed = (CompressedDataBuffer) dataBuffer;
        long targetBytes = compressed.getCompressionDescriptor().getCompressedLength();
        long sourceBytes = compressed.getCompressionDescriptor().getOriginalLength();

        // Temporary device-side staging areas for the input and the converted output.
        Pointer deviceSource = this.nativeOps.mallocDevice(sourceBytes, (Pointer) null, 0);
        Pointer deviceTarget = this.nativeOps.mallocDevice(targetBytes, (Pointer) null, 0);

        this.nativeOps.memcpyAsync(deviceSource, pointer, sourceBytes, CudaConstants.cudaMemcpyHostToDevice, stream);
        convertDataEx(typeEx, deviceSource, typeEx2, deviceTarget, dataBuffer.length());
        // NOTE(review): the copy back uses the HostToHost constant although the source is a device
        // allocation; kept exactly as in the original — confirm this is intended.
        this.nativeOps.memcpyAsync(dataBuffer.addressPointer(), deviceTarget, targetBytes, CudaConstants.cudaMemcpyHostToHost, stream);
        stream.synchronize();

        // The buffer is known to be compressed at this point, so both allocations are always freed.
        this.nativeOps.freeDevice(deviceSource, (Pointer) null);
        this.nativeOps.freeDevice(deviceTarget, (Pointer) null);
    }

    /**
     * Converts the contents of {@code dataBuffer} into {@code dataBuffer2}, staging compressed
     * buffers in temporary device memory.
     *
     * <p>A compressed source is copied into a temporary allocation before the conversion, and a
     * compressed target is converted into a temporary allocation and then copied back. Temporary
     * memory comes from the current workspace when one is active for this thread, otherwise from
     * {@code mallocDevice} (and is freed again at the end).
     *
     * <p>Fix over the previous version: with an active workspace, a non-compressed source or
     * target left its pointer {@code null}, and that {@code null} was handed to the native
     * conversion. Both now fall back to the allocator-managed pointer, exactly as the
     * non-workspace branch already did. The target pointer local is also declared as plain
     * {@link Pointer}, because two of the values assigned to it are not {@code PagedPointer}s.
     *
     * @param typeEx      source extended type
     * @param dataBuffer  source buffer
     * @param typeEx2     target extended type
     * @param dataBuffer2 target buffer; its length is forwarded to the native conversion
     */
    public void convertDataEx(DataBuffer.TypeEx typeEx, DataBuffer dataBuffer, DataBuffer.TypeEx typeEx2, DataBuffer dataBuffer2) {
        cudaStream_t oldStream = ((CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext()).getOldStream();
        Pointer srcPtr;
        Pointer dstPtr;
        if (Nd4j.getWorkspaceManager().anyWorkspaceActiveForCurrentThread()) {
            MemoryWorkspace currentWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace();
            if (dataBuffer instanceof CompressedDataBuffer) {
                // Compressed source: stage it in workspace memory.
                long compressedLength = ((CompressedDataBuffer) dataBuffer).getCompressionDescriptor().getCompressedLength();
                srcPtr = currentWorkspace.alloc(compressedLength, MemoryKind.DEVICE, DataBuffer.Type.HALF, false);
                this.nativeOps.memcpyAsync(srcPtr, dataBuffer.addressPointer(), compressedLength, CudaConstants.cudaMemcpyHostToHost, oldStream);
            } else {
                // Regular source: use the allocator-managed pointer (previously left null).
                srcPtr = AtomicAllocator.getInstance().getPointer(dataBuffer);
            }
            if (dataBuffer2 instanceof CompressedDataBuffer) {
                // Compressed target: convert into workspace memory, copied back below.
                dstPtr = currentWorkspace.alloc(((CompressedDataBuffer) dataBuffer2).getCompressionDescriptor().getCompressedLength(), MemoryKind.DEVICE, DataBuffer.Type.HALF, false);
            } else {
                // Regular target: use the allocator-managed pointer (previously left null).
                dstPtr = AtomicAllocator.getInstance().getPointer(dataBuffer2);
            }
        } else {
            if (dataBuffer instanceof CompressedDataBuffer) {
                log.info("Replacing source ptr");
                long compressedLength2 = ((CompressedDataBuffer) dataBuffer).getCompressionDescriptor().getCompressedLength();
                srcPtr = this.nativeOps.mallocDevice(compressedLength2, (Pointer) null, 0);
                this.nativeOps.memcpyAsync(srcPtr, dataBuffer.addressPointer(), compressedLength2, CudaConstants.cudaMemcpyHostToHost, oldStream);
                oldStream.synchronize();
            } else {
                srcPtr = AtomicAllocator.getInstance().getPointer(dataBuffer);
            }
            if (dataBuffer2 instanceof CompressedDataBuffer) {
                log.info("Replacing target ptr");
                dstPtr = this.nativeOps.mallocDevice(((CompressedDataBuffer) dataBuffer2).getCompressionDescriptor().getCompressedLength(), (Pointer) null, 0);
            } else {
                dstPtr = AtomicAllocator.getInstance().getPointer(dataBuffer2);
            }
        }
        convertDataEx(typeEx, srcPtr, typeEx2, dstPtr, dataBuffer2.length());
        Nd4j.getExecutioner().commit();
        if (dataBuffer2 instanceof CompressedDataBuffer) {
            // Copy the converted data from the staging area back into the compressed buffer.
            this.nativeOps.memcpyAsync(dataBuffer2.addressPointer(), dstPtr, dataBuffer2.capacity(), CudaConstants.cudaMemcpyHostToHost, oldStream);
            // Workspace memory is not freed here; only mallocDevice allocations are.
            if (!Nd4j.getWorkspaceManager().anyWorkspaceActiveForCurrentThread()) {
                this.nativeOps.freeDevice(dstPtr, (Pointer) null);
            }
        }
        if ((dataBuffer instanceof CompressedDataBuffer) && !Nd4j.getWorkspaceManager().anyWorkspaceActiveForCurrentThread()) {
            this.nativeOps.freeDevice(srcPtr, (Pointer) null);
        }
        Nd4j.getExecutioner().commit();
    }

    /**
     * Converts {@code dataBuffer} from {@code typeEx} to {@code typeEx2} and returns the newly
     * created result buffer.
     *
     * <p>The element size of the target is chosen from the ordinal of {@code typeEx2}: ordinals up
     * to 2 use 1 byte, up to 5 use 2 bytes, 6 uses 4 bytes and 7 uses 8 bytes. When
     * {@code CompressionUtils.goingToCompress} reports a compression, the result is a
     * {@link CompressedDataBuffer}; otherwise the source is cast to {@link CompressedDataBuffer}
     * and a regular buffer sized from its descriptor is created.
     *
     * @param typeEx     source extended type
     * @param dataBuffer source buffer
     * @param typeEx2    target extended type
     * @return the buffer holding the converted data
     * @throws UnsupportedOperationException if the ordinal of {@code typeEx2} is greater than 7
     */
    public DataBuffer convertDataEx(DataBuffer.TypeEx typeEx, DataBuffer dataBuffer, DataBuffer.TypeEx typeEx2) {
        final int targetOrdinal = typeEx2.ordinal();
        final int elementSize;
        if (targetOrdinal <= 2) {
            elementSize = 1;
        } else if (targetOrdinal <= 5) {
            elementSize = 2;
        } else if (targetOrdinal == 6) {
            elementSize = 4;
        } else if (targetOrdinal == 7) {
            elementSize = 8;
        } else {
            throw new UnsupportedOperationException("Unknown target TypeEx: " + typeEx2.name());
        }

        Nd4j.getExecutioner().commit();
        boolean sourceCompressed = dataBuffer instanceof CompressedDataBuffer;
        if (!sourceCompressed) {
            AtomicAllocator.getInstance().synchronizeHostData(dataBuffer);
        }

        final DataBuffer result;
        if (!CompressionUtils.goingToCompress(typeEx, typeEx2)) {
            // Decompression: size the plain buffer from the source's compression descriptor.
            long elements = ((CompressedDataBuffer) dataBuffer).getCompressionDescriptor().getNumberOfElements();
            result = Nd4j.createBuffer(elements, false);
            AtomicAllocator.getInstance().getAllocationPoint(result).tickDeviceWrite();
        } else {
            // Compression: allocate raw bytes for the target and describe them.
            long compressedBytes = dataBuffer.length() * elementSize;
            BytePointer storage = new BytePointer(compressedBytes);
            CompressionDescriptor descriptor = new CompressionDescriptor(dataBuffer, typeEx2.name());
            descriptor.setCompressionType(CompressionType.LOSSY);
            descriptor.setCompressedLength(compressedBytes);
            result = new CompressedDataBuffer(storage, descriptor);
        }
        convertDataEx(typeEx, dataBuffer, typeEx2, result);
        return result;
    }

    /**
     * Splits the given array into separate arrays along the given dimensions using the native
     * tear routines.
     *
     * <p>Each resulting array has the shape made of the source's extents along {@code iArr}; the
     * number of results is the source length divided by the product of those extents.
     *
     * @param iNDArray array to split; it is decompressed in place first if it is compressed
     * @param iArr     dimensions that make up each piece; this array is sorted in place
     * @return the newly created pieces
     */
    public INDArray[] tear(INDArray iNDArray, int... iArr) {
        if (iNDArray.isCompressed()) {
            Nd4j.getCompressor().decompressi(iNDArray);
        }
        // Sorts the caller's array in place.
        Arrays.sort(iArr);
        Pair tADOnlyShapeInfo = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(iNDArray, iArr);
        // j = number of elements per piece, jArr = shape of each piece.
        long j = 1;
        long[] jArr = new long[iArr.length];
        for (int i = 0; i < iArr.length; i++) {
            j *= iNDArray.shape()[iArr[i]];
            jArr[i] = iNDArray.shape()[iArr[i]];
        }
        // Number of pieces.
        int lengthLong = (int) (iNDArray.lengthLong() / j);
        INDArray[] iNDArrayArr = new INDArray[lengthLong];
        long[] jArr2 = new long[lengthLong];
        CudaContext prepareAction = AtomicAllocator.getInstance().getFlowController().prepareAction((INDArray) null, iNDArray);
        // Allocate every piece and record its pointer address; the context is re-assigned on each
        // iteration, so the one from the last piece is used for the native call.
        for (int i2 = 0; i2 < lengthLong; i2++) {
            iNDArrayArr[i2] = Nd4j.createUninitialized(jArr);
            prepareAction = AtomicAllocator.getInstance().getFlowController().prepareAction(iNDArrayArr[i2], new INDArray[0]);
            jArr2[i2] = AtomicAllocator.getInstance().getPointer(iNDArrayArr[i2], prepareAction).address();
        }
        // A double buffer serves as 8-byte-per-element storage for the table of target addresses.
        CudaDoubleDataBuffer cudaDoubleDataBuffer = new CudaDoubleDataBuffer(lengthLong);
        AtomicAllocator.getInstance().memcpyBlocking(cudaDoubleDataBuffer, new LongPointer(jArr2), jArr2.length * 8, 0L);
        PointerPointer pointerPointer = new PointerPointer(new Pointer[]{null, prepareAction.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer()});
        // The routine is chosen from the global Nd4j.dataType(); for any other type no native call
        // is made and the pieces are returned as created.
        if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
            this.nativeOps.tearDouble(pointerPointer, AtomicAllocator.getInstance().getPointer(iNDArray, prepareAction), AtomicAllocator.getInstance().getPointer(iNDArray.shapeInfoDataBuffer(), prepareAction), new PointerPointer(AtomicAllocator.getInstance().getPointer(cudaDoubleDataBuffer, prepareAction)), AtomicAllocator.getInstance().getPointer(iNDArrayArr[0].shapeInfoDataBuffer(), prepareAction), AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getFirst(), prepareAction), new LongPointerWrapper(AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getSecond(), prepareAction)));
        } else if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
            this.nativeOps.tearFloat(pointerPointer, AtomicAllocator.getInstance().getPointer(iNDArray, prepareAction), AtomicAllocator.getInstance().getPointer(iNDArray.shapeInfoDataBuffer(), prepareAction), new PointerPointer(AtomicAllocator.getInstance().getPointer(cudaDoubleDataBuffer, prepareAction)), AtomicAllocator.getInstance().getPointer(iNDArrayArr[0].shapeInfoDataBuffer(), prepareAction), AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getFirst(), prepareAction), new LongPointerWrapper(AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getSecond(), prepareAction)));
        } else if (Nd4j.dataType() == DataBuffer.Type.HALF) {
            this.nativeOps.tearHalf(pointerPointer, AtomicAllocator.getInstance().getPointer(iNDArray, prepareAction), AtomicAllocator.getInstance().getPointer(iNDArray.shapeInfoDataBuffer(), prepareAction), new PointerPointer(AtomicAllocator.getInstance().getPointer(cudaDoubleDataBuffer, prepareAction)), AtomicAllocator.getInstance().getPointer(iNDArrayArr[0].shapeInfoDataBuffer(), prepareAction), AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getFirst(), prepareAction), new LongPointerWrapper(AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getSecond(), prepareAction)));
        }
        AtomicAllocator.getInstance().getFlowController().registerActionAllWrite(prepareAction, iNDArrayArr);
        AtomicAllocator.getInstance().getFlowController().registerAction(prepareAction, (INDArray) null, iNDArrayArr);
        return iNDArrayArr;
    }

    /**
     * Sorts the whole array in place using the native sort routine matching its data type.
     *
     * @param iNDArray array to sort; returned unchanged when it is a scalar
     * @param z        flag forwarded to the native sort (presumably "descending" — confirm)
     * @return {@code iNDArray}
     * @throws UnsupportedOperationException if the data type is not FLOAT, DOUBLE or HALF
     */
    public INDArray sort(INDArray iNDArray, boolean z) {
        if (iNDArray.isScalar()) {
            return iNDArray;
        }
        Nd4j.getExecutioner().push();
        CudaContext prepareAction = AtomicAllocator.getInstance().getFlowController().prepareAction(iNDArray, new INDArray[0]);
        Pointer hostPointer = AtomicAllocator.getInstance().getHostPointer(iNDArray.shapeInfoDataBuffer());
        // Extras array: the host shape-info pointer fills every slot that is not a stream, device
        // id or context buffer. NOTE(review): presumably placeholders for slots this op does not
        // read — confirm against the native side.
        PointerPointer pointerPointer = new PointerPointer(new Pointer[]{hostPointer, prepareAction.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), prepareAction.getBufferAllocation(), prepareAction.getBufferReduction(), prepareAction.getBufferScalar(), prepareAction.getBufferSpecial(), hostPointer, AtomicAllocator.getInstance().getHostPointer(iNDArray.shapeInfoDataBuffer()), hostPointer, hostPointer, hostPointer, hostPointer, hostPointer, hostPointer, hostPointer, hostPointer, hostPointer, new CudaPointer(0L)});
        // Non-view arrays longer than 10485760 elements trigger a commit before sorting.
        if (!iNDArray.isView() && iNDArray.lengthLong() > 10485760) {
            Nd4j.getExecutioner().commit();
        }
        if (iNDArray.data().dataType() == DataBuffer.Type.FLOAT) {
            this.nativeOps.sortFloat(pointerPointer, AtomicAllocator.getInstance().getPointer(iNDArray, prepareAction), AtomicAllocator.getInstance().getPointer(iNDArray.shapeInfoDataBuffer(), prepareAction), z);
        } else if (iNDArray.data().dataType() == DataBuffer.Type.DOUBLE) {
            this.nativeOps.sortDouble(pointerPointer, AtomicAllocator.getInstance().getPointer(iNDArray, prepareAction), AtomicAllocator.getInstance().getPointer(iNDArray.shapeInfoDataBuffer(), prepareAction), z);
        } else {
            if (iNDArray.data().dataType() != DataBuffer.Type.HALF) {
                throw new UnsupportedOperationException("Unknown dataType " + iNDArray.data().dataType());
            }
            this.nativeOps.sortHalf(pointerPointer, AtomicAllocator.getInstance().getPointer(iNDArray, prepareAction), AtomicAllocator.getInstance().getPointer(iNDArray.shapeInfoDataBuffer(), prepareAction), z);
        }
        AtomicAllocator.getInstance().getFlowController().registerAction(prepareAction, iNDArray, new INDArray[0]);
        return iNDArray;
    }

    /**
     * Creates an array without a data buffer whose shape information carries the EMPTY array-type
     * option bit together with the option bit for the given data type.
     *
     * @param type data type encoded into the shape-information options
     * @return a new {@code JCublasNDArray} built from the generated shape information
     */
    public INDArray empty(DataBuffer.Type type) {
        long options = ArrayOptionsHelper.setOptionBit(0L, ArrayType.EMPTY);
        options = ArrayOptionsHelper.setOptionBit(options, type);
        Pair shapeInformation = Nd4j.getShapeInfoProvider().createShapeInformation(new int[0], new int[0], 0L, 1, 'c', options);
        CudaLongDataBuffer shapeBuffer = (CudaLongDataBuffer) shapeInformation.getFirst();
        long[] shapeArray = (long[]) shapeInformation.getSecond();
        return new JCublasNDArray((DataBuffer) null, shapeBuffer, shapeArray);
    }

    /**
     * Sorts the array in place along the given dimensions using the native TAD sort routine
     * matching its data type.
     *
     * @param iNDArray array to sort; returned unchanged when it is a scalar
     * @param z        flag forwarded to the native sort (presumably "descending" — confirm)
     * @param iArr     dimensions to sort along; this array is sorted in place
     * @return {@code iNDArray}
     * @throws UnsupportedOperationException if the data type is not FLOAT, DOUBLE or HALF
     */
    public INDArray sort(INDArray iNDArray, boolean z, int... iArr) {
        if (iNDArray.isScalar()) {
            return iNDArray;
        }
        // Sorts the caller's dimension array in place.
        Arrays.sort(iArr);
        Nd4j.getExecutioner().push();
        Pair tADOnlyShapeInfo = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(iNDArray, iArr);
        CudaContext prepareAction = AtomicAllocator.getInstance().getFlowController().prepareAction(iNDArray, new INDArray[0]);
        PointerPointer pointerPointer = new PointerPointer(new Pointer[]{AtomicAllocator.getInstance().getHostPointer(iNDArray.shapeInfoDataBuffer()), prepareAction.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer()});
        // The dimensions are handed to the native call through a constant buffer.
        IntPointer pointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(iArr), prepareAction);
        if (iNDArray.data().dataType() == DataBuffer.Type.FLOAT) {
            this.nativeOps.sortTadFloat(pointerPointer, AtomicAllocator.getInstance().getPointer(iNDArray, prepareAction), AtomicAllocator.getInstance().getPointer(iNDArray.shapeInfoDataBuffer(), prepareAction), pointer, iArr.length, AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getFirst(), prepareAction), new LongPointerWrapper(AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getSecond(), prepareAction)), z);
        } else if (iNDArray.data().dataType() == DataBuffer.Type.DOUBLE) {
            this.nativeOps.sortTadDouble(pointerPointer, AtomicAllocator.getInstance().getPointer(iNDArray, prepareAction), AtomicAllocator.getInstance().getPointer(iNDArray.shapeInfoDataBuffer(), prepareAction), pointer, iArr.length, AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getFirst(), prepareAction), new LongPointerWrapper(AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getSecond(), prepareAction)), z);
        } else {
            if (iNDArray.data().dataType() != DataBuffer.Type.HALF) {
                throw new UnsupportedOperationException("Unknown dataType " + iNDArray.data().dataType());
            }
            this.nativeOps.sortTadHalf(pointerPointer, AtomicAllocator.getInstance().getPointer(iNDArray, prepareAction), AtomicAllocator.getInstance().getPointer(iNDArray.shapeInfoDataBuffer(), prepareAction), pointer, iArr.length, AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getFirst(), prepareAction), new LongPointerWrapper(AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getSecond(), prepareAction)), z);
        }
        AtomicAllocator.getInstance().getFlowController().registerAction(prepareAction, iNDArray, new INDArray[0]);
        return iNDArray;
    }

    /**
     * Creates an array from float data with the given shape, strides and offset, using the
     * ordering reported by {@code Nd4j.order()}.
     *
     * @param fArr  element data
     * @param jArr  shape
     * @param jArr2 strides
     * @param j     offset
     * @return the new array
     */
    public INDArray create(float[] fArr, long[] jArr, long[] jArr2, long j) {
        char defaultOrdering = Nd4j.order().charValue();
        return new JCublasNDArray(fArr, jArr, jArr2, j, defaultOrdering);
    }

    /**
     * Creates an array from double data with the given shape, strides and offset, using the
     * ordering reported by {@code Nd4j.order()}.
     *
     * @param dArr  element data
     * @param jArr  shape
     * @param jArr2 strides
     * @param j     offset
     * @return the new array
     */
    public INDArray create(double[] dArr, long[] jArr, long[] jArr2, long j) {
        char defaultOrdering = Nd4j.order().charValue();
        return new JCublasNDArray(dArr, jArr, jArr2, j, defaultOrdering);
    }

    /**
     * Creates an array backed by the given buffer with the given shape.
     *
     * @param dataBuffer backing buffer
     * @param jArr       shape
     * @return the new array
     */
    public INDArray create(DataBuffer dataBuffer, long[] jArr) {
        INDArray created = new JCublasNDArray(dataBuffer, jArr);
        return created;
    }

    /**
     * Creates an array backed by the given buffer with the given shape, strides and offset, using
     * the ordering reported by {@code Nd4j.order()}.
     *
     * @param dataBuffer backing buffer
     * @param jArr       shape
     * @param jArr2      strides
     * @param j          offset
     * @return the new array
     */
    public INDArray create(DataBuffer dataBuffer, long[] jArr, long[] jArr2, long j) {
        char defaultOrdering = Nd4j.order().charValue();
        return new JCublasNDArray(dataBuffer, jArr, jArr2, j, defaultOrdering);
    }

    /**
     * Creates an array with the given shape from a list of arrays.
     *
     * @param list source arrays
     * @param jArr shape
     * @return the new array
     */
    public INDArray create(List<INDArray> list, long[] jArr) {
        INDArray created = new JCublasNDArray(list, jArr);
        return created;
    }

    /**
     * Creates a two-dimensional array with shape {@code {j, j2}} by delegating to the
     * shape-array overload with the ordering reported by {@code Nd4j.order()}.
     *
     * @param j    first extent of the shape
     * @param j2   second extent of the shape
     * @param jArr strides
     * @param j3   offset
     * @return the new array
     */
    public INDArray create(long j, long j2, long[] jArr, long j3) {
        long[] shape = {j, j2};
        char defaultOrdering = Nd4j.order().charValue();
        return create(shape, jArr, j3, defaultOrdering);
    }

    /**
     * Creates an array with the given shape and ordering at offset zero.
     *
     * @param jArr shape
     * @param c    ordering
     * @return the new array
     */
    public INDArray create(long[] jArr, char c) {
        final long offset = 0L;
        return new JCublasNDArray(jArr, offset, c);
    }

    /**
     * Creates an array with the given shape and ordering, with strides derived from them, passing
     * {@code false} as the constructor's last argument (presumably "do not initialize" — confirm).
     *
     * @param jArr shape
     * @param c    ordering
     * @return the new array
     */
    public INDArray createUninitialized(long[] jArr, char c) {
        final long offset = 0L;
        final boolean initialize = false;
        return new JCublasNDArray(jArr, Nd4j.getStrides(jArr, c), offset, c, initialize);
    }

    /**
     * Creates an array with the given shape and ordering while no workspace is set as current, so
     * the array is constructed outside the caller's workspace.
     *
     * <p>The current workspace is cleared for the duration of the construction and restored
     * afterwards. The restore now runs in a {@code finally} block: previously an exception from the
     * constructor left the thread with its current workspace set to {@code null}.
     *
     * @param jArr shape
     * @param c    ordering
     * @return the new array
     */
    public INDArray createUninitializedDetached(long[] jArr, char c) {
        MemoryWorkspace currentWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace();
        Nd4j.getMemoryManager().setCurrentWorkspace((MemoryWorkspace) null);
        try {
            return new JCublasNDArray(jArr, Nd4j.getStrides(jArr, c), 0L, c, false);
        } finally {
            // Always hand the caller's workspace back, even when construction fails.
            Nd4j.getMemoryManager().setCurrentWorkspace(currentWorkspace);
        }
    }

    /**
     * Creates an array backed by the given buffer with explicit shape, strides, offset and
     * ordering.
     *
     * @param dataBuffer backing buffer
     * @param jArr       shape
     * @param jArr2      strides
     * @param j          offset
     * @param c          ordering
     * @return the new array
     */
    public INDArray create(DataBuffer dataBuffer, long[] jArr, long[] jArr2, long j, char c) {
        INDArray created = new JCublasNDArray(dataBuffer, jArr, jArr2, j, c);
        return created;
    }

    /**
     * Creates an array with the given shape and ordering from a list of arrays.
     *
     * @param list source arrays
     * @param jArr shape
     * @param c    ordering
     * @return the new array
     */
    public INDArray create(List<INDArray> list, long[] jArr, char c) {
        INDArray created = new JCublasNDArray(list, jArr, c);
        return created;
    }

    /**
     * Creates an array from float data with explicit shape, strides, ordering and offset.
     *
     * <p>The ordering precedes the offset in this signature; the constructor is called with the
     * offset first.
     *
     * @param fArr  element data
     * @param jArr  shape
     * @param jArr2 strides
     * @param c     ordering
     * @param j     offset
     * @return the new array
     */
    public INDArray create(float[] fArr, long[] jArr, long[] jArr2, char c, long j) {
        INDArray created = new JCublasNDArray(fArr, jArr, jArr2, j, c);
        return created;
    }

    /**
     * Creates an array from float data with explicit shape, strides, offset and ordering.
     *
     * @param fArr  element data
     * @param jArr  shape
     * @param jArr2 strides
     * @param j     offset
     * @param c     ordering
     * @return the new array
     */
    public INDArray create(float[] fArr, long[] jArr, long[] jArr2, long j, char c) {
        INDArray created = new JCublasNDArray(fArr, jArr, jArr2, j, c);
        return created;
    }

    /**
     * Creates an array from double data with explicit shape, strides, offset and ordering.
     *
     * @param dArr  element data
     * @param jArr  shape
     * @param jArr2 strides
     * @param j     offset
     * @param c     ordering
     * @return the new array
     */
    public INDArray create(double[] dArr, long[] jArr, long[] jArr2, long j, char c) {
        INDArray created = new JCublasNDArray(dArr, jArr, jArr2, j, c);
        return created;
    }

    /**
     * Creates an array from float data with the given shape, offset and ordering; strides are
     * derived from the shape and ordering.
     *
     * @param fArr element data
     * @param jArr shape
     * @param j    offset
     * @param ch   ordering; must not be {@code null} (it is unboxed)
     * @return the new array
     */
    public INDArray create(float[] fArr, long[] jArr, long j, Character ch) {
        char ordering = ch.charValue();
        return new JCublasNDArray(fArr, jArr, Nd4j.getStrides(jArr, ordering), j, ordering);
    }

    /**
     * Creates an array from double data with the given shape, offset and ordering; strides are
     * derived from the shape and ordering.
     *
     * @param dArr element data
     * @param jArr shape
     * @param j    offset
     * @param ch   ordering; must not be {@code null} (it is unboxed)
     * @return the new array
     */
    public INDArray create(double[] dArr, long[] jArr, long j, Character ch) {
        char ordering = ch.charValue();
        return new JCublasNDArray(dArr, jArr, Nd4j.getStrides(jArr, ordering), j, ordering);
    }

    /**
     * Creates an array from float data with the given shape and ordering at offset zero.
     *
     * <p>Strides are derived from the requested ordering {@code c}. The previous version derived
     * them from the factory's own {@code order} field while still tagging the array with
     * {@code c}, so the strides and the declared ordering disagreed whenever the two differed.
     * This now matches the sibling overloads that take an ordering.
     *
     * @param fArr element data
     * @param jArr shape
     * @param c    ordering used both for the strides and for the array itself
     * @return the new array
     */
    public INDArray create(float[] fArr, long[] jArr, char c) {
        return new JCublasNDArray(fArr, jArr, Nd4j.getStrides(jArr, c), 0L, c);
    }

    /**
     * Not supported by this factory.
     *
     * @throws UnsupportedOperationException always
     */
    public INDArray createSparseCSR(double[] dArr, int[] iArr, int[] iArr2, int[] iArr3, long[] jArr) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this factory.
     *
     * @throws UnsupportedOperationException always
     */
    public INDArray createSparseCSR(float[] fArr, int[] iArr, int[] iArr2, int[] iArr3, long[] jArr) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this factory.
     *
     * @throws UnsupportedOperationException always
     */
    public INDArray createSparseCSR(DataBuffer dataBuffer, int[] iArr, int[] iArr2, int[] iArr3, long[] jArr) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this factory for {@code long[][]} indices.
     *
     * @throws UnsupportedOperationException always
     */
    public INDArray createSparseCOO(double[] dArr, long[][] jArr, long[] jArr2) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this factory for {@code long[][]} indices.
     *
     * @throws UnsupportedOperationException always
     */
    public INDArray createSparseCOO(float[] fArr, long[][] jArr, long[] jArr2) {
        throw new UnsupportedOperationException();
    }

    /**
     * Creates a sparse COO array from double values, their indices and the overall shape.
     *
     * @param dArr values
     * @param iArr indices
     * @param jArr shape
     * @return a new {@code JCusparseNDArrayCOO}
     */
    public INDArray createSparseCOO(double[] dArr, int[][] iArr, long[] jArr) {
        INDArray sparse = new JCusparseNDArrayCOO(dArr, iArr, jArr);
        return sparse;
    }

    /**
     * Creates a sparse COO array from float values, their indices and the overall shape.
     *
     * @param fArr values
     * @param iArr indices
     * @param jArr shape
     * @return a new {@code JCusparseNDArrayCOO}
     */
    public INDArray createSparseCOO(float[] fArr, int[][] iArr, long[] jArr) {
        INDArray sparse = new JCusparseNDArrayCOO(fArr, iArr, jArr);
        return sparse;
    }

    /**
     * Not supported by this factory for buffer-based input.
     *
     * @throws UnsupportedOperationException always
     */
    public INDArray createSparseCOO(DataBuffer dataBuffer, DataBuffer dataBuffer2, long[] jArr) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this factory for buffer-based input.
     *
     * @throws UnsupportedOperationException always
     */
    public INDArray createSparseCOO(DataBuffer dataBuffer, DataBuffer dataBuffer2, DataBuffer dataBuffer3, long[] jArr) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this factory for buffer-based input.
     *
     * @throws UnsupportedOperationException always
     */
    public INDArray createSparseCOO(DataBuffer dataBuffer, DataBuffer dataBuffer2, long[] jArr, int[] iArr, int[] iArr2, int i, long[] jArr2) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this factory.
     *
     * @throws UnsupportedOperationException always
     */
    public INDArray sortCooIndices(INDArray iNDArray) {
        throw new UnsupportedOperationException();
    }
}
