package org.deeplearning4j.parallelism;

import java.util.ArrayList;
import java.util.List;
import java.util.Observer;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import lombok.NonNull;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.parallelism.inference.InferenceMode;
import org.deeplearning4j.parallelism.inference.InferenceObservable;
import org.deeplearning4j.parallelism.inference.LoadBalanceMode;
import org.deeplearning4j.parallelism.inference.observers.BasicInferenceObservable;
import org.deeplearning4j.parallelism.inference.observers.BasicInferenceObserver;
import org.deeplearning4j.parallelism.inference.observers.BatchedInferenceObservable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/parallelism/ParallelInference.class */
/**
 * Runs inference for a {@link Model} on a pool of daemon worker threads.
 *
 * <p>{@link #init()} starts {@code workers} threads, each attached to device {@code i % numberOfDevices}.
 * The first worker placed on the caller's current device uses the original model directly. Every
 * other worker builds a replica from the model's JSON configuration and copies the parameters into it.
 *
 * <p>Requests are queued on a bounded queue of size {@code queueLimit}:
 * <ul>
 *   <li>In {@link InferenceMode#SEQUENTIAL} mode each call is enqueued on its own.</li>
 *   <li>In {@link InferenceMode#BATCHED} mode an {@link ObservablesProvider} coalesces concurrent calls
 *       into batches of up to {@code batchLimit} inputs.</li>
 * </ul>
 *
 * <p>The {@code output(...)} methods block until a worker has produced the result.
 */
public class ParallelInference {
    protected Model model;
    /** Passed to the {@link ObservablesProvider}; not set by the {@link Builder}. */
    protected long nanos;
    protected int workers;
    protected int batchLimit;
    protected InferenceMode inferenceMode;
    protected int queueLimit;
    private BlockingQueue<InferenceObservable> observables;
    private InferenceWorker[] zoo;
    private ObservablesProvider provider;
    public static final int DEFAULT_BATCH_LIMIT = 32;
    public static final int DEFAULT_QUEUE_LIMIT = 64;
    private static final Logger log = LoggerFactory.getLogger(ParallelInference.class);
    public static final int DEFAULT_NUM_WORKERS = Nd4j.getAffinityManager().getNumberOfDevices();
    public static final InferenceMode DEFAULT_INFERENCE_MODE = InferenceMode.BATCHED;
    /** How long a stopping worker is joined per attempt while waiting for it to leave its loop. */
    private static final long SHUTDOWN_POLL_MILLIS = 10L;
    protected LoadBalanceMode loadBalanceMode = LoadBalanceMode.FIFO;
    /** Serializes parameter copies from the shared prototype model into worker replicas. */
    private final Object locker = new Object();

    /**
     * Fluent builder for {@link ParallelInference}.
     *
     * <p>{@link InferenceMode#INPLACE} produces an {@link InplaceParallelInference}, which receives only
     * the mode, model, worker count and load-balance mode. Every other mode produces a plain
     * {@link ParallelInference}, which additionally receives the batch and queue limits.
     */
    public static class Builder {
        private final Model model;
        private int workers = DEFAULT_NUM_WORKERS;
        private int batchLimit = DEFAULT_BATCH_LIMIT;
        private InferenceMode inferenceMode = DEFAULT_INFERENCE_MODE;
        private int queueLimit = DEFAULT_QUEUE_LIMIT;
        protected LoadBalanceMode loadBalanceMode = LoadBalanceMode.FIFO;

        /**
         * @param model the model to serve; must not be null
         * @throws NullPointerException if {@code model} is null
         */
        public Builder(@NonNull Model model) {
            if (model == null) {
                throw new NullPointerException("model is marked @NonNull but is null");
            }
            this.model = model;
        }

        /**
         * Sets the inference mode.
         *
         * @throws NullPointerException if {@code inferenceMode} is null
         */
        public Builder inferenceMode(@NonNull InferenceMode inferenceMode) {
            if (inferenceMode == null) {
                throw new NullPointerException("inferenceMode is marked @NonNull but is null");
            }
            this.inferenceMode = inferenceMode;
            return this;
        }

        /**
         * Sets the load-balance mode.
         *
         * @throws NullPointerException if {@code loadBalanceMode} is null
         */
        public Builder loadBalanceMode(@NonNull LoadBalanceMode loadBalanceMode) {
            if (loadBalanceMode == null) {
                throw new NullPointerException("loadBalanceMode is marked @NonNull but is null");
            }
            this.loadBalanceMode = loadBalanceMode;
            return this;
        }

        /**
         * Sets the number of worker threads.
         *
         * @throws IllegalStateException if {@code i < 1}
         */
        public Builder workers(int i) {
            if (i < 1) {
                throw new IllegalStateException("Workers should be positive value");
            }
            this.workers = i;
            return this;
        }

        /**
         * Sets the maximum number of inputs coalesced into one batch in BATCHED mode.
         *
         * @throws IllegalStateException if {@code i < 1}
         */
        public Builder batchLimit(int i) {
            if (i < 1) {
                throw new IllegalStateException("Batch limit should be positive value");
            }
            this.batchLimit = i;
            return this;
        }

        /**
         * Sets the capacity of the request queue shared by the workers.
         *
         * @throws IllegalStateException if {@code i < 1}
         */
        public Builder queueLimit(int i) {
            if (i < 1) {
                throw new IllegalStateException("Queue limit should be positive value");
            }
            this.queueLimit = i;
            return this;
        }

        /** Creates, initializes (starting its workers) and returns the configured instance. */
        public ParallelInference build() {
            if (inferenceMode == InferenceMode.INPLACE) {
                InplaceParallelInference inplace = new InplaceParallelInference();
                applyCommonSettings(inplace);
                inplace.init();
                return inplace;
            }
            ParallelInference inference = new ParallelInference();
            inference.batchLimit = batchLimit;
            inference.queueLimit = queueLimit;
            applyCommonSettings(inference);
            inference.init();
            return inference;
        }

        /** Copies the settings shared by every inference flavour onto {@code target}. */
        private void applyCommonSettings(ParallelInference target) {
            target.inferenceMode = inferenceMode;
            target.model = model;
            target.workers = workers;
            target.loadBalanceMode = loadBalanceMode;
        }
    }

    /**
     * Daemon thread that takes requests from the shared queue and runs them on its own model replica.
     */
    public class InferenceWorker extends Thread {
        private final BlockingQueue<InferenceObservable> inputQueue;
        private final AtomicBoolean shouldWork = new AtomicBoolean(true);
        private final AtomicBoolean isStopped = new AtomicBoolean(false);
        private Model protoModel;
        /** Volatile: {@link ParallelInference#getCurrentModelsFromWorkers()} reads it from other threads. */
        private volatile Model replicatedModel;
        private final AtomicLong counter = new AtomicLong(0L);
        /** When true this worker uses the prototype model itself instead of a copy. */
        private final boolean rootDevice;
        /** Readers run inference; the writer swaps the replica in {@link #updateModel(Model)}. */
        private final ReentrantReadWriteLock modelLock = new ReentrantReadWriteLock();

        private InferenceWorker(int id, @NonNull Model model, @NonNull BlockingQueue<InferenceObservable> inputQueue,
                        boolean rootDevice) {
            if (model == null) {
                throw new NullPointerException("model is marked @NonNull but is null");
            }
            if (inputQueue == null) {
                throw new NullPointerException("inputQueue is marked @NonNull but is null");
            }
            this.inputQueue = inputQueue;
            this.protoModel = model;
            this.rootDevice = rootDevice;
            setDaemon(true);
            setName("InferenceThread-" + id);
        }

        /** Returns how many requests this worker has taken from the queue so far. */
        protected long getCounterValue() {
            return counter.get();
        }

        /**
         * Replaces the prototype model and rebuilds this worker's replica.
         *
         * <p>Runs under the write lock, so no inference on this worker overlaps the swap.
         *
         * @throws NullPointerException if {@code model} is null
         */
        protected void updateModel(@NonNull Model model) {
            if (model == null) {
                throw new NullPointerException("model is marked @NonNull but is null");
            }
            // Acquire before entering try so a failed lock() never triggers an unlock of an unheld lock.
            modelLock.writeLock().lock();
            try {
                protoModel = model;
                initializeReplicaModel();
            } finally {
                modelLock.writeLock().unlock();
            }
        }

        /**
         * Builds {@code replicatedModel} from {@code protoModel}.
         *
         * <p>The root worker shares the prototype itself. Other workers get a fresh network built from the
         * prototype's JSON configuration, with the prototype's parameters copied in. Model types other than
         * {@link ComputationGraph} and {@link MultiLayerNetwork} leave the replica unchanged.
         */
        protected void initializeReplicaModel() {
            Model source = protoModel;
            boolean isGraph = source instanceof ComputationGraph;
            if (!isGraph && !(source instanceof MultiLayerNetwork)) {
                // Unsupported type: nothing to replicate. infer() reports this to each request.
                return;
            }
            if (rootDevice) {
                replicatedModel = source;
                return;
            }
            Model replica;
            if (isGraph) {
                ComputationGraph graph = new ComputationGraph(
                                ComputationGraphConfiguration.fromJson(((ComputationGraph) source).getConfiguration().toJson()));
                graph.init();
                replica = graph;
            } else {
                MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration
                                .fromJson(((MultiLayerNetwork) source).getLayerWiseConfigurations().toJson()));
                network.init();
                replica = network;
            }
            synchronized (ParallelInference.this.locker) {
                replica.setParams(source.params().unsafeDuplication(true));
                Nd4j.getExecutioner().commit();
            }
            // Publish only once fully initialized so other threads never observe a half-built replica.
            replicatedModel = replica;
        }

        /**
         * Builds the replica, then serves requests until {@link #shutdown()} clears {@code shouldWork}
         * or the thread is interrupted. {@code isStopped} is set on every exit path.
         */
        @Override
        public void run() {
            try {
                modelLock.writeLock().lock();
                try {
                    // Under the write lock so a concurrent updateModel() cannot interleave with this build.
                    initializeReplicaModel();
                } finally {
                    modelLock.writeLock().unlock();
                }
                while (shouldWork.get()) {
                    InferenceObservable request = inputQueue.take();
                    counter.incrementAndGet();
                    process(request);
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            } finally {
                isStopped.set(true);
            }
        }

        /**
         * Runs every input batch of {@code request}.
         *
         * <p>Either all outputs or the first failure is delivered back to the request.
         */
        private void process(InferenceObservable request) {
            List<Pair<INDArray[], INDArray[]>> batches = request.getInputBatches();
            List<INDArray[]> outputs = new ArrayList<>(batches.size());
            try {
                for (Pair<INDArray[], INDArray[]> batch : batches) {
                    outputs.add(infer(batch));
                }
                request.setOutputBatches(outputs);
            } catch (Exception e) {
                request.setOutputException(e);
            }
        }

        /**
         * Runs one (features, feature masks) batch on the current replica under the read lock.
         *
         * <p>The replica type is checked per call, so a model swapped in by {@link #updateModel(Model)}
         * is dispatched correctly even if its type differs.
         *
         * @throws IllegalStateException if no supported replica is available
         */
        private INDArray[] infer(Pair<INDArray[], INDArray[]> batch) {
            modelLock.readLock().lock();
            try {
                Model current = replicatedModel;
                if (current instanceof ComputationGraph) {
                    return ((ComputationGraph) current).output(false, batch.getFirst(), batch.getSecond());
                }
                if (current instanceof MultiLayerNetwork) {
                    // A MultiLayerNetwork has exactly one input; only the first array/mask is used.
                    INDArray features = batch.getFirst()[0];
                    INDArray featuresMask = batch.getSecond() == null ? null : batch.getSecond()[0];
                    return new INDArray[] {
                                    ((MultiLayerNetwork) current).output(features, false, featuresMask, (INDArray) null)};
                }
                throw new IllegalStateException("Unsupported model type: " + protoModel.getClass().getName());
            } finally {
                modelLock.readLock().unlock();
            }
        }

        /**
         * Asks the worker to stop and waits until {@code run()} reports it has left its loop.
         *
         * <p>This does not interrupt the worker. A worker blocked on an empty queue only stops once it is
         * interrupted elsewhere, as {@link ParallelInference#shutdown()} does. The caller's interrupt
         * status is preserved.
         */
        protected void shutdown() {
            shouldWork.set(false);
            boolean interrupted = false;
            // Timed join instead of a busy spin: the waiting thread sleeps rather than burning a core.
            while (!isStopped.get()) {
                try {
                    join(SHUTDOWN_POLL_MILLIS);
                } catch (InterruptedException e) {
                    interrupted = true;
                }
            }
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Groups concurrent inputs into {@link BatchedInferenceObservable}s of at most {@code batchLimit}
     * entries. Each new batch is enqueued on the target queue.
     */
    public static class ObservablesProvider {
        private final BlockingQueue<InferenceObservable> targetQueue;
        private final long nanos;
        private final int batchLimit;
        private volatile BatchedInferenceObservable currentObservable;
        private final Object locker = new Object();

        /**
         * @param nanos      stored but not used by this class
         * @param batchLimit maximum number of inputs per batch
         * @param queue      queue the workers consume from
         * @throws NullPointerException if {@code queue} is null
         */
        protected ObservablesProvider(long nanos, int batchLimit, @NonNull BlockingQueue<InferenceObservable> queue) {
            if (queue == null) {
                throw new NullPointerException("queue is marked @NonNull but is null");
            }
            this.targetQueue = queue;
            this.nanos = nanos;
            this.batchLimit = batchLimit;
        }

        /** Submits a single input array without a mask. */
        protected InferenceObservable setInput(@NonNull Observer observer, INDArray input) {
            if (observer == null) {
                throw new NullPointerException("observer is marked @NonNull but is null");
            }
            return setInput(observer, new INDArray[] {input}, null);
        }

        /** Submits several input arrays without masks. */
        protected InferenceObservable setInput(@NonNull Observer observer, INDArray... input) {
            if (observer == null) {
                throw new NullPointerException("observer is marked @NonNull but is null");
            }
            return setInput(observer, input, null);
        }

        /**
         * Adds the input to the current batch, or starts a new batch if there is none, it is full, or it
         * is locked.
         *
         * <p>A newly started batch is put on the target queue while the provider lock is held, so callers
         * block here when the queue is full.
         *
         * @return the batch the input was added to; {@code observer} is registered on it
         * @throws NullPointerException if {@code observer} is null
         * @throws RuntimeException     wrapping an {@link InterruptedException} while enqueueing
         */
        protected InferenceObservable setInput(@NonNull Observer observer, INDArray[] input, INDArray[] inputMasks) {
            if (observer == null) {
                throw new NullPointerException("observer is marked @NonNull but is null");
            }
            synchronized (locker) {
                BatchedInferenceObservable target = currentObservable;
                boolean startNewBatch = target == null || target.getCounter() >= batchLimit || target.isLocked();
                if (startNewBatch) {
                    target = new BatchedInferenceObservable();
                    currentObservable = target;
                }
                target.addInput(input, inputMasks);
                target.addObserver(observer);
                if (startNewBatch) {
                    try {
                        targetQueue.put(target);
                    } catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                        throw new RuntimeException(e);
                    }
                }
                return target;
            }
        }
    }

    /**
     * Replaces the served model.
     *
     * <p>Before {@link #init()} this only swaps the stored model. Afterwards every worker rebuilds its
     * replica from {@code model}.
     *
     * @throws NullPointerException if {@code model} is null
     */
    public void updateModel(@NonNull Model model) {
        if (model == null) {
            throw new NullPointerException("model is marked @NonNull but is null");
        }
        if (zoo == null) {
            this.model = model;
            return;
        }
        for (InferenceWorker worker : zoo) {
            worker.updateModel(model);
        }
    }

    /** Returns each worker's current replica, or an empty array when no workers exist. */
    protected Model[] getCurrentModelsFromWorkers() {
        InferenceWorker[] current = zoo;
        if (current == null) {
            return new Model[0];
        }
        Model[] models = new Model[current.length];
        for (int i = 0; i < current.length; i++) {
            models[i] = current[i].replicatedModel;
        }
        return models;
    }

    /**
     * Creates the request queue and starts the workers, spreading them round-robin across devices.
     *
     * <p>In BATCHED mode it also creates the {@link ObservablesProvider}.
     */
    protected void init() {
        observables = new LinkedBlockingQueue<>(queueLimit);
        int numberOfDevices = Nd4j.getAffinityManager().getNumberOfDevices();
        int currentDevice = Nd4j.getAffinityManager().getDeviceForCurrentThread().intValue();
        // Only the first worker placed on the caller's device may reuse the original model.
        boolean rootAssigned = false;
        zoo = new InferenceWorker[workers];
        for (int i = 0; i < workers; i++) {
            int device = i % numberOfDevices;
            boolean isRoot = !rootAssigned && device == currentDevice;
            rootAssigned |= isRoot;
            zoo[i] = new InferenceWorker(i, model, observables, isRoot);
            Nd4j.getAffinityManager().attachThreadToDevice(zoo[i], Integer.valueOf(device));
            zoo[i].start();
        }
        if (inferenceMode == InferenceMode.BATCHED) {
            log.info("Initializing ObservablesProvider...");
            provider = new ObservablesProvider(nanos, batchLimit, observables);
        }
    }

    /** Returns how many requests worker {@code i} has taken so far. */
    protected long getWorkerCounter(int i) {
        return zoo[i].getCounterValue();
    }

    /**
     * Stops all workers and releases them. Does nothing if there are no workers.
     *
     * <p>Every worker is signalled first (stop flag, then interrupt), so they wind down in parallel.
     * Each one is then awaited.
     */
    public synchronized void shutdown() {
        if (zoo == null) {
            return;
        }
        for (InferenceWorker worker : zoo) {
            if (worker != null) {
                // Clear the flag before interrupting so a worker that finishes a request exits its loop.
                worker.shouldWork.set(false);
                worker.interrupt();
            }
        }
        for (int i = 0; i < zoo.length; i++) {
            if (zoo[i] != null) {
                zoo[i].shutdown();
                zoo[i] = null;
            }
        }
        zoo = null;
        System.gc();
    }

    /** Runs inference on a single row built from {@code input}. */
    public INDArray output(double[] input) {
        return output(Nd4j.create(input));
    }

    /** Runs inference on a single row built from {@code input}. */
    public INDArray output(float[] input) {
        return output(Nd4j.create(input));
    }

    /** Runs inference on a single unmasked input and returns the single output. */
    public INDArray output(INDArray input) {
        return output(input, (INDArray) null);
    }

    /**
     * Runs inference on a single input with an optional mask.
     *
     * @param input     network input
     * @param inputMask feature mask, or null
     * @return the single output array
     * @throws IllegalArgumentException if the network produced more than one output array
     */
    public INDArray output(INDArray input, INDArray inputMask) {
        INDArray[] outputs = output(new INDArray[] {input}, inputMask == null ? null : new INDArray[] {inputMask});
        if (outputs.length != 1) {
            throw new IllegalArgumentException("Network has multiple (" + outputs.length + ") output arrays, but only a single output can be returned using this method. Use for output(INDArray[] input, INDArray[] inputMasks) for multi-output nets");
        }
        return outputs[0];
    }

    /** Runs inference on the features (and feature mask) of {@code dataSet}. */
    public INDArray output(DataSet dataSet) {
        return output(dataSet.getFeatures(), dataSet.getFeaturesMaskArray());
    }

    /** Runs inference on one or more unmasked inputs. */
    public INDArray[] output(INDArray... input) {
        return output(input, (INDArray[]) null);
    }

    /**
     * Submits the inputs to the workers and blocks until the result is available.
     *
     * <p>In SEQUENTIAL mode the request is enqueued on its own. Otherwise it is merged into the
     * provider's current batch.
     *
     * @param input      network inputs
     * @param inputMasks feature masks, or null
     * @return the network outputs
     * @throws RuntimeException wrapping an interruption while enqueueing, or any failure while waiting
     *                          for or retrieving the output
     */
    public INDArray[] output(INDArray[] input, INDArray[] inputMasks) {
        BasicInferenceObserver observer = new BasicInferenceObserver();
        InferenceObservable observable;
        if (inferenceMode == InferenceMode.SEQUENTIAL) {
            observable = new BasicInferenceObservable(input, inputMasks);
            observable.addObserver(observer);
            try {
                observables.put(observable);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
        } else {
            observable = provider.setInput(observer, input, inputMasks);
        }
        try {
            observer.waitTillDone();
            return observable.getOutput();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
}
