package org.nd4j.jita.allocator.context.impl;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.nd4j.jita.allocator.context.ContextPack;
import org.nd4j.jita.allocator.context.ContextPool;
import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t;
import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t;
import org.nd4j.jita.allocator.pointers.cuda.cusolverDnHandle_t;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Context pool that hands every calling thread its own {@link ContextPack} of
 * {@link #LANES_PER_THREAD} CUDA contexts ("lanes"), built lazily on the thread's first request.
 *
 * <p>Packs are keyed only by {@link Thread#getId()}. Once a thread has been served, it receives the
 * same pack on every later call, whatever device id it passes.
 *
 * <p>This class is marked {@code @Deprecated}. NOTE(review): the intended replacement is not
 * visible from this file; presumably the non-packed {@link BasicContextPool}. Confirm before
 * documenting it.
 */
@Deprecated
public class PackedContextPool extends BasicContextPool implements ContextPool {
    private static final Logger log = LoggerFactory.getLogger(PackedContextPool.class);

    /** Number of contexts (lanes) created per thread, read once from the CUDA configuration. */
    protected static final int LANES_PER_THREAD =
            CudaEnvironment.getInstance().getConfiguration().getCommandLanesNumber();

    /** Per-thread context packs, keyed by {@link Thread#getId()}. */
    private final Map<Long, ContextPack> contextsPool = new ConcurrentHashMap<>();

    /**
     * Returns lane 0 of the calling thread's context pack, creating the pack if needed.
     *
     * @param num device id used when the calling thread's pack must be created
     * @return the first context of the calling thread's pack
     */
    @Override
    public CudaContext acquireContextForDevice(Integer num) {
        return acquireContextPackForDevice(num).getContextForLane(0);
    }

    /**
     * Returns the calling thread's context pack, creating it on first use.
     *
     * <p>Creating a pack builds {@link #LANES_PER_THREAD} contexts for device {@code num} while
     * holding {@code lock}. The first context ever created for a device also creates that device's
     * cuBLAS and cuSolver handles and stores them in {@code cublasPool} / {@code solverPool}. Every
     * later context reuses the stored handles.
     *
     * @param num device id for the contexts; only consulted when the pack has to be created
     * @return the calling thread's pack
     * @throws RuntimeException wrapping any exception raised while acquiring the lock or building
     *     the pack
     */
    @Override
    public ContextPack acquireContextPackForDevice(Integer num) {
        Long threadId = Long.valueOf(Thread.currentThread().getId());

        // ConcurrentHashMap never stores null values, so null means "no pack yet". Only this
        // thread ever inserts under its own id, so no re-check is needed once the lock is held.
        ContextPack existing = this.contextsPool.get(threadId);
        if (existing != null) {
            return existing;
        }

        // Acquire OUTSIDE the guarded region. If acquire() fails, the lock is not held and must
        // not be released; releasing a semaphore-like lock here would add a spurious permit.
        try {
            this.lock.acquire();
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // acquire() cleared the interrupt flag; restore it before propagating.
                Thread.currentThread().interrupt();
            }
            throw new RuntimeException(e);
        }

        try {
            ContextPack contextPack = new ContextPack(LANES_PER_THREAD);
            for (int lane = 0; lane < LANES_PER_THREAD; lane++) {
                CudaContext context = createNewStream(num);
                getDeviceBuffers(context, num.intValue());

                if (this.cublasPool.get(num) == null) {
                    log.debug("Creating new cuBLAS handle for device [{}]", num);
                    // The cuBLAS handle is bound to the stream of a separately created context,
                    // not to this lane's own context.
                    cudaStream_t cublasStream = createNewStream(num).getOldStream();
                    cublasHandle_t cublasHandle = createNewCublasHandle(cublasStream);
                    context.setHandle(cublasHandle);
                    context.setCublasStream(cublasStream);
                    this.cublasPool.put(num, cublasHandle);

                    log.debug("Creating new cuSolver handle for device [{}]...", num);
                    // Likewise, the cuSolver handle gets a stream from another fresh context.
                    cudaStream_t solverStream = createNewStream(num).getOldStream();
                    cusolverDnHandle_t solverHandle = createNewSolverHandle(solverStream);
                    context.setSolverHandle(solverHandle);
                    context.setSolverStream(solverStream);
                    this.solverPool.put(num, solverHandle);
                } else {
                    log.debug("Reusing cuBLAS handle for device [{}]", num);
                    context.setHandle(this.cublasPool.get(num));
                    log.debug("Reusing solver here...");
                    context.setSolverHandle(this.solverPool.get(num));
                }

                contextPack.addLane(Integer.valueOf(lane), context);
            }

            this.contextsPool.put(threadId, contextPack);
            return contextPack;
        } catch (Exception e) {
            throw new RuntimeException(e);
        } finally {
            // Reached only after a successful acquire(), so exactly one release per acquire.
            this.lock.release();
        }
    }
}
