package org.deeplearning4j.nn.multilayer;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.bytedeco.javacpp.Pointer;
import org.deeplearning4j.datasets.iterator.AsyncDataSetIterator;
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
import org.deeplearning4j.eval.EvaluationBinary;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.api.Classifier;
import org.deeplearning4j.nn.api.FwdPassType;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.ModelAdapter;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.api.OutputAdapter;
import org.deeplearning4j.nn.api.TrainingConfig;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.api.layers.RecurrentLayer;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.FrozenLayer;
import org.deeplearning4j.nn.layers.FrozenLayerWithBackprop;
import org.deeplearning4j.nn.layers.LayerHelper;
import org.deeplearning4j.nn.layers.recurrent.BidirectionalLayer;
import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.updater.UpdaterCreator;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.Solver;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator;
import org.deeplearning4j.util.CrashReportingUtil;
import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.util.NetworkUtils;
import org.deeplearning4j.util.OutputLayerUtil;
import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.ROC;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.heartbeat.Heartbeat;
import org.nd4j.linalg.heartbeat.reports.Event;
import org.nd4j.linalg.heartbeat.reports.Task;
import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils;
import org.nd4j.linalg.heartbeat.utils.TaskUtils;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.memory.abstracts.DummyWorkspace;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.primitives.Triple;
import org.nd4j.linalg.schedule.ISchedule;
import org.nd4j.linalg.util.FeatureUtil;
import org.nd4j.linalg.workspace.ND4JWorkspaceException;
import org.nd4j.linalg.workspace.WorkspaceUtils;
import org.nd4j.util.OneTimeLogger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/multilayer/MultiLayerNetwork.class */
public class MultiLayerNetwork implements Serializable, Classifier, Layer, NeuralNetwork {
    protected Layer[] layers;
    protected LinkedHashMap<String, Layer> layerMap;
    protected INDArray input;
    protected INDArray labels;
    protected boolean initCalled;
    protected Collection<TrainingListener> trainingListeners;
    protected NeuralNetConfiguration defaultConfiguration;
    protected MultiLayerConfiguration layerWiseConfigurations;
    protected Gradient gradient;
    protected double score;
    protected boolean initDone;
    protected INDArray flattenedParams;
    protected transient INDArray flattenedGradients;
    protected boolean clearTbpttState;
    protected transient ThreadLocal<Long> lastEtlTime;
    protected INDArray mask;
    protected int layerIndex;
    protected transient Solver solver;
    protected transient Map<String, Pointer> helperWorkspaces;
    protected static final String WS_LAYER_WORKING_MEM = "WS_LAYER_WORKING_MEM";
    protected static final String WS_ALL_LAYERS_ACT = "WS_ALL_LAYERS_ACT";
    protected static final String WS_LAYER_ACT_1 = "WS_LAYER_ACT_1";
    protected static final String WS_LAYER_ACT_2 = "WS_LAYER_ACT_2";
    protected static final String WS_OUTPUT_MEM = "WS_OUTPUT_MEM";
    protected static final String WS_RNN_LOOP_WORKING_MEM = "WS_RNN_LOOP_WORKING_MEM";
    protected WorkspaceConfiguration WS_LAYER_WORKING_MEM_CONFIG;
    protected WorkspaceConfiguration WS_LAYER_ACT_X_CONFIG;
    private static final Logger log = LoggerFactory.getLogger(MultiLayerNetwork.class);
    protected static final WorkspaceConfiguration WS_ALL_LAYERS_ACT_CONFIG = WorkspaceConfiguration.builder().initialSize(0).overallocationLimit(0.05d).policyLearning(LearningPolicy.FIRST_LOOP).policyReset(ResetPolicy.BLOCK_LEFT).policySpill(SpillPolicy.REALLOCATE).policyAllocation(AllocationPolicy.OVERALLOCATE).build();
    protected static final WorkspaceConfiguration WS_RNN_LOOP_WORKING_MEM_CONFIG = WorkspaceConfiguration.builder().initialSize(0).overallocationLimit(0.05d).policyReset(ResetPolicy.BLOCK_LEFT).policyAllocation(AllocationPolicy.OVERALLOCATE).policySpill(SpillPolicy.REALLOCATE).policyLearning(LearningPolicy.FIRST_LOOP).build();

    public MultiLayerNetwork(MultiLayerConfiguration multiLayerConfiguration) {
        this.layerMap = new LinkedHashMap<>();
        this.initCalled = false;
        this.trainingListeners = new ArrayList();
        this.initDone = false;
        this.clearTbpttState = true;
        this.lastEtlTime = new ThreadLocal<>();
        this.helperWorkspaces = new HashMap();
        this.layerWiseConfigurations = multiLayerConfiguration;
        this.defaultConfiguration = multiLayerConfiguration.getConf(0).m33clone();
        this.WS_LAYER_WORKING_MEM_CONFIG = getLayerWorkingMemWSConfig(2 * (this.layerWiseConfigurations.getConfs().size() + this.layerWiseConfigurations.getInputPreProcessors().size()));
        this.WS_LAYER_ACT_X_CONFIG = getLayerActivationWSConfig(this.layerWiseConfigurations.getConfs().size());
    }

    /**
     * Builds the workspace configuration used for per-layer working memory: zero initial size,
     * 2% overallocation, OVER_TIME learning, BLOCK_LEFT reset, REALLOCATE spill, OVERALLOCATE allocation.
     *
     * @param i value passed to {@code cyclesBeforeInitialization}
     * @return a newly built workspace configuration
     */
    protected static WorkspaceConfiguration getLayerWorkingMemWSConfig(int i) {
        return WorkspaceConfiguration.builder()
                .initialSize(0L)
                .overallocationLimit(0.02d)
                .policyLearning(LearningPolicy.OVER_TIME)
                .cyclesBeforeInitialization(i)
                .policyReset(ResetPolicy.BLOCK_LEFT)
                .policySpill(SpillPolicy.REALLOCATE)
                .policyAllocation(AllocationPolicy.OVERALLOCATE)
                .build();
    }

    /**
     * Builds the workspace configuration used for layer activations. It currently uses the same
     * settings as the working-memory configuration: zero initial size, 2% overallocation, OVER_TIME
     * learning, BLOCK_LEFT reset, REALLOCATE spill, OVERALLOCATE allocation.
     *
     * @param i value passed to {@code cyclesBeforeInitialization}
     * @return a newly built workspace configuration
     */
    protected static WorkspaceConfiguration getLayerActivationWSConfig(int i) {
        return WorkspaceConfiguration.builder()
                .initialSize(0L)
                .overallocationLimit(0.02d)
                .policyLearning(LearningPolicy.OVER_TIME)
                .cyclesBeforeInitialization(i)
                .policyReset(ResetPolicy.BLOCK_LEFT)
                .policySpill(SpillPolicy.REALLOCATE)
                .policyAllocation(AllocationPolicy.OVERALLOCATE)
                .build();
    }

    /**
     * Applies a cache mode to every layer of the network.
     *
     * @param cacheMode the cache mode; {@code null} is treated as {@link CacheMode#NONE}
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public void setCacheMode(CacheMode cacheMode) {
        CacheMode effectiveMode = (cacheMode != null) ? cacheMode : CacheMode.NONE;
        Layer[] allLayers = this.layers;
        for (int idx = 0; idx < allLayers.length; idx++) {
            allLayers[idx].setCacheMode(effectiveMode);
        }
    }

    /**
     * Records the most recent ETL (data loading) time for the calling thread.
     *
     * @param j the ETL time value to store (thread-local)
     */
    public void setLastEtlTime(long j) {
        Long boxed = j;
        this.lastEtlTime.set(boxed);
    }

    /**
     * Returns the ETL time last recorded for the calling thread.
     *
     * @return the stored value, or 0 if nothing has been recorded on this thread
     */
    public long getLastEtlTime() {
        Long stored = this.lastEtlTime.get();
        return stored == null ? 0L : stored;
    }

    /**
     * Creates and initializes a network from a JSON configuration, then sets its parameters.
     *
     * @param str      the configuration as JSON, parsed by {@code MultiLayerConfiguration.fromJson}
     * @param iNDArray the parameters to assign after initialization
     */
    public MultiLayerNetwork(String str, INDArray iNDArray) {
        this(MultiLayerConfiguration.fromJson(str));
        // Build layers first so that the parameter array has somewhere to go
        init();
        setParameters(iNDArray);
    }

    /**
     * Creates and initializes a network from a configuration, then sets its parameters.
     *
     * @param multiLayerConfiguration the network configuration
     * @param iNDArray                the parameters to assign after initialization
     */
    public MultiLayerNetwork(MultiLayerConfiguration multiLayerConfiguration, INDArray iNDArray) {
        this(multiLayerConfiguration);
        // Build layers first so that the parameter array has somewhere to go
        init();
        setParameters(iNDArray);
    }

    /**
     * Fills in any missing core state with defaults: an empty {@link MultiLayerConfiguration},
     * a layer array sized by {@code getnLayers()}, and a default {@link NeuralNetConfiguration}.
     * Existing values are left as they are.
     */
    protected void intializeConfigurations() {
        boolean needsConfigs = this.layerWiseConfigurations == null;
        if (needsConfigs) {
            this.layerWiseConfigurations = new MultiLayerConfiguration.Builder().build();
        }

        boolean needsLayerArray = this.layers == null;
        if (needsLayerArray) {
            // Sized after the configuration is guaranteed to exist
            this.layers = new Layer[getnLayers()];
        }

        boolean needsDefaultConf = this.defaultConfiguration == null;
        if (needsDefaultConf) {
            this.defaultConfiguration = new NeuralNetConfiguration.Builder().build();
        }
    }

    /**
     * Runs layer-wise unsupervised pretraining for a single epoch.
     *
     * @param dataSetIterator source of training features
     */
    public void pretrain(DataSetIterator dataSetIterator) {
        final int singleEpoch = 1;
        pretrain(dataSetIterator, singleEpoch);
    }

    /**
     * Runs layer-wise unsupervised pretraining on every layer in order, with the given number of
     * epochs per layer. Layers that are not pretrainable are skipped by {@code pretrainLayer}.
     *
     * @param dataSetIterator source of training features
     * @param i               number of epochs for each layer
     */
    public void pretrain(DataSetIterator dataSetIterator, int i) {
        if (this.flattenedGradients == null) {
            initGradientsView();
        }
        int layerIdx = 0;
        while (layerIdx < getnLayers()) {
            pretrainLayer(layerIdx, dataSetIterator, i);
            layerIdx++;
        }
    }

    /**
     * Pretrains one layer for a single epoch.
     *
     * @param i               index of the layer to pretrain
     * @param dataSetIterator source of training features
     */
    public void pretrainLayer(int i, DataSetIterator dataSetIterator) {
        final int singleEpoch = 1;
        pretrainLayer(i, dataSetIterator, singleEpoch);
    }

    /**
     * Pretrains one layer for a number of epochs over an iterator. Does nothing (beyond the
     * argument checks and gradient-view setup) when the layer is not a pretrain layer.
     * <p>
     * NOTE(review): the layer's epoch count is incremented once per call, not once per epoch —
     * confirm this is intended.
     *
     * @param i               index of the layer to pretrain
     * @param dataSetIterator source of training features; must support reset when {@code i2 > 1}
     * @param i2              number of epochs; must be positive
     * @throws IllegalArgumentException if {@code i} is not less than the number of layers
     * @throws IllegalStateException    if {@code i2} is not positive, or multiple epochs are
     *                                  requested on a non-resettable iterator
     */
    public void pretrainLayer(int i, DataSetIterator dataSetIterator, int i2) {
        Preconditions.checkState(i2 > 0, "Number of epochs (%s) must be a positive number", i2);
        if (this.flattenedGradients == null) {
            initGradientsView();
        }
        if (i >= this.layers.length) {
            throw new IllegalArgumentException("Cannot pretrain layer: layerIdx (" + i + ") >= numLayers (" + this.layers.length + ")");
        }
        if (!this.layers[i].isPretrainLayer()) {
            return;
        }

        if (i2 > 1 && !dataSetIterator.resetSupported()) {
            throw new IllegalStateException("Cannot fit multiple epochs (" + i2 + ") on an iterator that doesn't support resetting");
        }
        // Rewind an already-exhausted iterator so the first epoch has data
        if (!dataSetIterator.hasNext() && dataSetIterator.resetSupported()) {
            dataSetIterator.reset();
        }

        log.info("Starting unsupervised training on layer " + i + " for " + i2 + " epochs");
        for (int epoch = 0; epoch < i2; epoch++) {
            if (epoch != 0) {
                dataSetIterator.reset();
            }
            while (dataSetIterator.hasNext()) {
                DataSet batch = (DataSet) dataSetIterator.next();
                this.input = batch.getFeatures();
                pretrainLayer(i, this.input);
            }
        }

        int nextEpochCount = getLayer(i).conf().getEpochCount() + 1;
        getLayer(i).conf().setEpochCount(nextEpochCount);
    }

    /**
     * Performs one unsupervised fitting step on a single layer using the given network input.
     * <p>
     * The input to layer {@code i} is computed as follows:
     * <ul>
     *   <li>for layer 0 it is the raw input;</li>
     *   <li>otherwise it is the detached output of layer {@code i - 1}, computed with
     *       {@code train = false}.</li>
     * </ul>
     * That input then goes through the layer's input preprocessor, if one is configured, and is
     * passed to {@link Layer#fit}. If the layer is not a pretrain layer, no fitting is done.
     *
     * @param i        index of the layer to pretrain
     * @param iNDArray network input features
     * @throws IllegalArgumentException if {@code i} is not less than the number of layers
     */
    public void pretrainLayer(int i, INDArray iNDArray) {
        setInput(iNDArray);
        setLayerMaskArrays(null, null);
        if (this.flattenedGradients == null) {
            initGradientsView();
        }
        if (i >= this.layers.length) {
            throw new IllegalArgumentException("Cannot pretrain layer: layerIdx (" + i + ") >= numLayers (" + this.layers.length + ")");
        }

        LayerWorkspaceMgr workspaceMgr;
        if (this.layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE) {
            workspaceMgr = LayerWorkspaceMgr.noWorkspaces();
        } else {
            workspaceMgr = LayerWorkspaceMgr.builder()
                    .defaultWorkspace(WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .build();
        }
        workspaceMgr.setHelperWorkspacePointers(this.helperWorkspaces);

        Layer layer = this.layers[i];
        if (!layer.isPretrainLayer()) {
            return;
        }

        // Input for layer i is the output of its predecessor (i - 1). Note: this must be the target
        // layer's index, not this.layerIndex (the network's own index when used as a Layer).
        INDArray layerInput = (i == 0)
                ? this.input
                : outputOfLayerDetached(false, FwdPassType.STANDARD, i - 1, iNDArray, null, null, null);

        // try-with-resources closes the scope exactly once, even if fit() or close() throws
        try (MemoryWorkspace ignored = workspaceMgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) {
            InputPreProcessor preProcessor = this.layerWiseConfigurations.getInputPreProcess(i);
            if (preProcessor != null) {
                layerInput = preProcessor.preProcess(layerInput, (int) this.input.size(0), LayerWorkspaceMgr.noWorkspaces(this.helperWorkspaces));
            }
            layer.fit(layerInput, workspaceMgr);
        }
    }

    /**
     * Returns the size of dimension 0 of the current input, narrowed to {@code int}.
     *
     * @return the current minibatch size
     */
    @Override // org.deeplearning4j.nn.api.Model
    public int batchSize() {
        long examples = this.input.size(0);
        return (int) examples;
    }

    /**
     * Returns the network-wide default configuration.
     *
     * @return the default {@link NeuralNetConfiguration} held by this network
     */
    @Override // org.deeplearning4j.nn.api.Model
    public NeuralNetConfiguration conf() {
        NeuralNetConfiguration current = this.defaultConfiguration;
        return current;
    }

    /**
     * Not supported: the configuration of a {@code MultiLayerNetwork} cannot be replaced this way.
     *
     * @param neuralNetConfiguration ignored
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setConf(NeuralNetConfiguration neuralNetConfiguration) {
        UnsupportedOperationException unsupported = new UnsupportedOperationException();
        throw unsupported;
    }

    /**
     * Returns the input array most recently set on this network.
     *
     * @return the current input, possibly {@code null}
     */
    @Override // org.deeplearning4j.nn.api.Model
    public INDArray input() {
        INDArray current = this.input;
        return current;
    }

    /**
     * Returns the optimizer of this network's solver. The solver is created by {@code init()};
     * before then this throws {@link NullPointerException}.
     *
     * @return the solver's optimizer
     */
    @Override // org.deeplearning4j.nn.api.Model, org.deeplearning4j.nn.api.NeuralNetwork
    public ConvexOptimizer getOptimizer() {
        Solver currentSolver = this.solver;
        return currentSolver.getOptimizer();
    }

    /**
     * Looks up a single parameter by its network-level key {@code "<layerIndex>_<paramName>"}.
     * The key is split at the first underscore.
     *
     * @param str the parameter key
     * @return the parameter array returned by the layer
     * @throws IllegalStateException if the key contains no underscore
     */
    @Override // org.deeplearning4j.nn.api.Model
    public INDArray getParam(String str) {
        int separator = str.indexOf('_');
        if (separator < 0) {
            throw new IllegalStateException("Invalid param key: does not have layer separator: \"" + str + "\"");
        }
        int layerIdx = Integer.parseInt(str.substring(0, separator));
        String layerParamName = str.substring(separator + 1);
        return this.layers[layerIdx].getParam(layerParamName);
    }

    /**
     * Returns all network parameters keyed by {@code "<layerIndex>_<paramName>"}.
     * Equivalent to {@code paramTable(false)}.
     *
     * @return an insertion-ordered map of parameters
     */
    @Override // org.deeplearning4j.nn.api.Model
    public Map<String, INDArray> paramTable() {
        final boolean allParams = false;
        return paramTable(allParams);
    }

    /**
     * Returns network parameters keyed by {@code "<layerIndex>_<paramName>"}, in layer order.
     *
     * @param z forwarded unchanged to each layer's {@code paramTable(boolean)} (presumably
     *          "backprop parameters only" — confirm against the Layer contract)
     * @return an insertion-ordered map of parameter views
     */
    @Override // org.deeplearning4j.nn.api.Model
    public Map<String, INDArray> paramTable(boolean z) {
        Map<String, INDArray> combined = new LinkedHashMap<>();
        for (int layerIdx = 0; layerIdx < this.layers.length; layerIdx++) {
            String prefix = layerIdx + "_";
            this.layers[layerIdx].paramTable(z).forEach((paramName, paramArr) -> combined.put(prefix + paramName, paramArr));
        }
        return combined;
    }

    /**
     * Tells whether the updater should divide the gradient of the given parameter by the
     * minibatch size, by delegating to the owning layer.
     *
     * @param str network-level parameter key {@code "<layerIndex>_<paramName>"}; it is split at
     *            the first underscore
     * @return the owning layer's answer for its local parameter name
     * @throws IllegalStateException if the key contains no underscore (consistent with
     *                               {@code getParam}/{@code setParam})
     */
    @Override // org.deeplearning4j.nn.api.Trainable
    public boolean updaterDivideByMinibatch(String str) {
        int indexOf = str.indexOf('_');
        if (indexOf == -1) {
            // Without this check substring(0, -1) would fail with an unhelpful StringIndexOutOfBoundsException
            throw new IllegalStateException("Invalid param key: does not have layer separator: \"" + str + "\"");
        }
        int layerIdx = Integer.parseInt(str.substring(0, indexOf));
        return getLayer(layerIdx).updaterDivideByMinibatch(str.substring(indexOf + 1));
    }

    /**
     * Copies parameter values from {@code map} into the network's existing parameter arrays.
     * The key sets must match exactly and every shape must match. All checks run before any value
     * is assigned.
     *
     * @param map parameters keyed by {@code "<layerIndex>_<paramName>"}
     * @throws IllegalArgumentException if the key sets differ or any shape differs
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setParamTable(Map<String, INDArray> map) {
        Map<String, INDArray> current = paramTable();
        if (!current.keySet().equals(map.keySet())) {
            throw new IllegalArgumentException("Cannot set param table: parameter keys do not match.\nCurrent: " + current.keySet() + "\nTo set: " + map.keySet());
        }

        // Validation pass: fail before mutating anything
        for (Map.Entry<String, INDArray> entry : map.entrySet()) {
            String key = entry.getKey();
            INDArray existing = current.get(key);
            INDArray replacement = entry.getValue();
            if (!Arrays.equals(existing.shape(), replacement.shape())) {
                throw new IllegalArgumentException("Cannot set parameter table: parameter \"" + key + "\" shapes do not match. Current = " + Arrays.toString(existing.shape()) + ", to set = " + Arrays.toString(replacement.shape()));
            }
        }

        // Assignment pass: write into the existing (view) arrays in place
        for (Map.Entry<String, INDArray> entry : map.entrySet()) {
            current.get(entry.getKey()).assign(entry.getValue());
        }
    }

    /**
     * Sets a single parameter identified by its network-level key {@code "<layerIndex>_<paramName>"},
     * by delegating to the owning layer. The key is split at the first underscore.
     *
     * @param str      the parameter key
     * @param iNDArray the new value, passed to the layer's {@code setParam}
     * @throws IllegalStateException if the key contains no underscore
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setParam(String str, INDArray iNDArray) {
        int indexOf = str.indexOf('_');
        if (indexOf == -1) {
            // Message aligned with getParam (previously read "not have layer separator")
            throw new IllegalStateException("Invalid param key: does not have layer separator: \"" + str + "\"");
        }
        this.layers[Integer.parseInt(str.substring(0, indexOf))].setParam(str.substring(indexOf + 1), iNDArray);
    }

    /**
     * Returns the full multi-layer configuration of this network.
     *
     * @return the configuration, possibly {@code null} before initialization
     */
    public MultiLayerConfiguration getLayerWiseConfigurations() {
        MultiLayerConfiguration current = this.layerWiseConfigurations;
        return current;
    }

    /**
     * Replaces the multi-layer configuration reference. No other state (layers, parameters,
     * workspace configs) is rebuilt.
     *
     * @param multiLayerConfiguration the new configuration
     */
    public void setLayerWiseConfigurations(MultiLayerConfiguration multiLayerConfiguration) {
        MultiLayerConfiguration replacement = multiLayerConfiguration;
        this.layerWiseConfigurations = replacement;
    }

    /**
     * Initializes the network with freshly created and initialized parameters.
     * Equivalent to {@code init(null, false)}.
     */
    @Override // org.deeplearning4j.nn.api.Model, org.deeplearning4j.nn.api.NeuralNetwork
    public void init() {
        final INDArray noParams = null;
        init(noParams, false);
    }

    /**
     * Initializes the network. This creates the layers (each with a view into one flattened
     * parameter row vector), registers layer variables in the default configuration, creates the
     * solver/optimizer if needed, and synchronizes iteration/epoch counts.
     * <p>
     * Returns immediately if a previous call already created the layers ({@code initCalled}).
     *
     * @param iNDArray parameters to use, as a row vector whose length equals the total parameter
     *                 count. May be {@code null}, in which case a new array is allocated and the
     *                 layers initialize it, after the RNG is seeded from the default configuration.
     * @param z        if {@code true}, {@code iNDArray} is copied (and first cast when its data type
     *                 differs from the network data type); if {@code false} it is used directly
     * @throws IllegalStateException    if {@code z} is false and the array's data type differs from
     *                                  the network data type; if the network has fewer than one layer;
     *                                  or if any layer instantiates to {@code null}
     * @throws IllegalArgumentException if {@code iNDArray} is not a row vector or scalar, or has the
     *                                  wrong length
     */
    public void init(INDArray iNDArray, boolean z) {
        if (this.layerWiseConfigurations == null || this.layers == null) {
            intializeConfigurations();
        }
        if (this.initCalled) {
            return;
        }

        DataType dataType = getLayerWiseConfigurations().getDataType();
        if (iNDArray != null && iNDArray.dataType() != dataType) {
            if (!z) {
                throw new IllegalStateException("Error initializing network: Network datatype is set to " + dataType + " but provided array has datatype " + iNDArray.dataType() + " with cloneParametersArray argument set to false. Cannot initialize net with specified datatype array if that array does not match network datatype");
            }
            // Cast while scoped out of workspaces (presumably so the result is not workspace-attached)
            try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                iNDArray = iNDArray.castTo(dataType);
            }
        }

        if (this.layerMap == null) {
            this.layerMap = new LinkedHashMap<>();
        }
        if (this.layerWiseConfigurations.getTrainingWorkspaceMode() == null) {
            this.layerWiseConfigurations.setTrainingWorkspaceMode(WorkspaceMode.NONE);
        }
        if (this.layerWiseConfigurations.getInferenceWorkspaceMode() == null) {
            this.layerWiseConfigurations.setInferenceWorkspaceMode(WorkspaceMode.NONE);
        }
        if (this.layerWiseConfigurations.getCacheMode() == null) {
            this.layerWiseConfigurations.setCacheMode(CacheMode.NONE);
        }
        OneTimeLogger.info(log, "Starting MultiLayerNetwork with WorkspaceModes set to [training: {}; inference: {}], cacheMode set to [{}]", new Object[]{this.layerWiseConfigurations.getTrainingWorkspaceMode(), this.layerWiseConfigurations.getInferenceWorkspaceMode(), this.layerWiseConfigurations.getCacheMode()});

        int nLayers = getnLayers();
        if (nLayers < 1) {
            throw new IllegalStateException("Unable to create network: number of layers is less than 1");
        }

        if (this.layers == null || this.layers[0] == null) {
            if (this.layers == null) {
                this.layers = new Layer[nLayers];
            }

            // Total parameter count, and the count for each layer
            long totalParams = 0;
            long[] paramsPerLayer = new long[nLayers];
            for (int idx = 0; idx < nLayers; idx++) {
                NeuralNetConfiguration conf = this.layerWiseConfigurations.getConf(idx);
                paramsPerLayer[idx] = conf.getLayer().initializer().numParams(conf);
                totalParams += paramsPerLayer[idx];
            }

            boolean initializeParams;
            if (iNDArray != null) {
                if (!iNDArray.isRowVectorOrScalar()) {
                    throw new IllegalArgumentException("Invalid parameters: should be a row vector");
                }
                if (iNDArray.length() != totalParams) {
                    throw new IllegalArgumentException("Invalid parameters: expected length " + totalParams + ", got length " + iNDArray.length());
                }
                this.flattenedParams = z ? iNDArray.dup() : iNDArray;
                initializeParams = false;
            } else if (totalParams > 0) {
                this.flattenedParams = Nd4j.create(dataType, new long[]{1, totalParams});
                initializeParams = true;
            } else {
                // Edge case: the network has no parameters at all
                this.flattenedParams = null;
                initializeParams = false;
            }

            if (initializeParams) {
                Nd4j.getRandom().setSeed(getDefaultConfiguration().getSeed());
            }

            // Hand each layer its contiguous slice of the flattened parameter row vector
            long offset = 0;
            for (int idx = 0; idx < nLayers; idx++) {
                INDArray paramsView = paramsPerLayer[idx] > 0
                        ? this.flattenedParams.get(new INDArrayIndex[]{NDArrayIndex.interval(0L, 0L, true), NDArrayIndex.interval(offset, offset + paramsPerLayer[idx])})
                        : null;
                offset += paramsPerLayer[idx];
                NeuralNetConfiguration layerConf = this.layerWiseConfigurations.getConf(idx);
                this.layers[idx] = layerConf.getLayer().instantiate(layerConf, this.trainingListeners, idx, paramsView, initializeParams, dataType);
                this.layerMap.put(layerConf.getLayer().getLayerName(), this.layers[idx]);
            }
            this.initCalled = true;
        }

        // Register every layer's variables in the default configuration, prefixed with the layer index
        this.defaultConfiguration.clearVariables();
        List<String> variables = this.defaultConfiguration.variables(false);
        for (int idx = 0; idx < this.layers.length; idx++) {
            if (this.layers[idx] == null) {
                throw new IllegalStateException("Encountered null layer during initialization for layer " + idx + ": " + this.layerWiseConfigurations.getConf(idx).getLayer().getClass().getSimpleName() + " initialization returned null layer?");
            }
            for (String variable : this.layers[idx].conf().variables()) {
                variables.add(idx + "_" + variable);
            }
        }

        if (this.solver == null) {
            try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                this.solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
                this.solver.initOptimizer();
            }
        }

        // Every layer except the first may modify its input
        for (int idx = 1; idx < this.layers.length; idx++) {
            this.layers[idx].allowInputModification(true);
        }
        synchronizeIterEpochCounts();
    }

    /**
     * Attaches a gradients accumulator to this network's optimizer. Initializes the network first
     * if {@code init()} has not yet created the layers.
     *
     * @param gradientsAccumulator the accumulator to use
     */
    public void setGradientsAccumulator(GradientsAccumulator gradientsAccumulator) {
        boolean needsInit = !isInitCalled();
        if (needsInit) {
            init();
        }
        ConvexOptimizer optimizer = this.solver.getOptimizer();
        optimizer.setGradientsAccumulator(gradientsAccumulator);
    }

    /**
     * Tells whether {@code init} has created this network's layers.
     *
     * @return the {@code initCalled} flag
     */
    public boolean isInitCalled() {
        boolean created = this.initCalled;
        return created;
    }

    /**
     * Allocates the flattened gradient row vector (in 'f' order, with the parameter data type) and
     * gives each layer that has parameters a view into its contiguous slice. Calls {@code init()}
     * first if the layers do not exist yet. When the network has no parameters, no gradient array
     * is allocated.
     */
    public void initGradientsView() {
        // try-with-resources closes the scope exactly once, even if the body or close() throws
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            if (this.layers == null) {
                init();
            }

            int nLayers = this.layers.length;
            long totalParams = 0;
            long[] paramsPerLayer = new long[nLayers];
            for (int idx = 0; idx < nLayers; idx++) {
                NeuralNetConfiguration conf = this.layerWiseConfigurations.getConf(idx);
                paramsPerLayer[idx] = conf.getLayer().initializer().numParams(conf);
                totalParams += paramsPerLayer[idx];
            }

            if (totalParams > 0) {
                this.flattenedGradients = Nd4j.create(this.flattenedParams.dataType(), new long[]{1, totalParams}, 'f');
            }

            long offset = 0;
            for (int idx = 0; idx < nLayers; idx++) {
                if (paramsPerLayer[idx] == 0) {
                    continue; // layer has no parameters, so no gradient view
                }
                this.layers[idx].setBackpropGradientsViewArray(this.flattenedGradients.get(new INDArrayIndex[]{NDArrayIndex.interval(0L, 0L, true), NDArrayIndex.interval(offset, offset + paramsPerLayer[idx])}));
                offset += paramsPerLayer[idx];
            }
        }
    }

    /**
     * Runs layer {@code i}'s input preprocessor (if any) on the incoming activations, then
     * activates layer {@code i}.
     *
     * @param i                 index of the layer to activate
     * @param iNDArray          activations coming from the previous layer (or the network input)
     * @param z                 training flag, passed to {@code Layer.activate}
     * @param layerWorkspaceMgr workspace manager for the preprocessor and the layer
     * @return the layer's activations
     */
    protected INDArray activationFromPrevLayer(int i, INDArray iNDArray, boolean z, LayerWorkspaceMgr layerWorkspaceMgr) {
        InputPreProcessor preProcessor = getLayerWiseConfigurations().getInputPreProcess(i);
        INDArray layerInput = (preProcessor == null)
                ? iNDArray
                : preProcessor.preProcess(iNDArray, getInputMiniBatchSize(), layerWorkspaceMgr);
        return this.layers[i].activate(layerInput, z, layerWorkspaceMgr);
    }

    /**
     * Activates layers {@code i} through {@code i2} (inclusive) in inference mode, without
     * workspaces. The given array is the input to layer {@code i}.
     *
     * @param i        first layer index; must satisfy {@code 0 <= i < i2}
     * @param i2       last layer index; must satisfy {@code 1 <= i2 < numLayers}
     * @param iNDArray input to layer {@code i}
     * @return the activations of layer {@code i2}
     * @throws IllegalStateException if the input is null or the range is invalid
     */
    public INDArray activateSelectedLayers(int i, int i2, INDArray iNDArray) {
        if (iNDArray == null) {
            throw new IllegalStateException("Unable to perform activation; no input found");
        }
        int layerCount = this.layers.length;
        boolean fromOutOfRange = i < 0 || i >= layerCount || i >= i2;
        if (fromOutOfRange) {
            throw new IllegalStateException("Unable to perform activation; FROM is out of layer space");
        }
        boolean toOutOfRange = i2 < 1 || i2 >= layerCount;
        if (toOutOfRange) {
            throw new IllegalStateException("Unable to perform activation; TO is out of layer space");
        }
        try {
            LayerWorkspaceMgr mgr = LayerWorkspaceMgr.noWorkspaces(this.helperWorkspaces);
            INDArray current = iNDArray;
            int layerIdx = i;
            while (layerIdx <= i2) {
                current = activationFromPrevLayer(layerIdx, current, false, mgr);
                layerIdx++;
            }
            return current;
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
    }

    /**
     * Sets the input, then returns the activations of all layers.
     *
     * @param iNDArray the network input
     * @param z        training flag
     * @return the list produced by {@link #feedForward(boolean)}
     */
    public List<INDArray> feedForward(INDArray iNDArray, boolean z) {
        setInput(iNDArray);
        List<INDArray> activations = feedForward(z);
        return activations;
    }

    /**
     * Feeds the current input (and mask) forward through every layer and returns detached
     * activations, with the input included ({@code z3 = true}).
     *
     * @param z training flag
     * @return activations as produced by {@code ffToLayerActivationsDetached}
     */
    public List<INDArray> feedForward(boolean z) {
        int lastLayerIdx = this.layers.length - 1;
        try {
            return ffToLayerActivationsDetached(z, FwdPassType.STANDARD, false, lastLayerIdx, this.input, this.mask, null, true);
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
    }

    /**
     * Feeds the current input (and mask) forward through every layer and returns detached
     * activations.
     *
     * @param z  training flag
     * @param z2 forwarded as the last argument of {@code ffToLayerActivationsDetached}
     * @return activations as produced by {@code ffToLayerActivationsDetached}
     */
    public List<INDArray> feedForward(boolean z, boolean z2) {
        int lastLayerIdx = this.layers.length - 1;
        try {
            return ffToLayerActivationsDetached(z, FwdPassType.STANDARD, false, lastLayerIdx, this.input, this.mask, null, z2);
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
    }

    /**
     * Feeds the given input forward in inference mode up to layer {@code i}.
     *
     * @param i        index of the last layer to compute
     * @param iNDArray network input
     * @return activations as produced by {@code ffToLayerActivationsDetached}
     */
    public List<INDArray> feedForwardToLayer(int i, INDArray iNDArray) {
        final boolean inference = false;
        try {
            return ffToLayerActivationsDetached(inference, FwdPassType.STANDARD, false, i, iNDArray, this.mask, null, true);
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
    }

    /**
     * Feeds the given input forward up to layer {@code i}. The stop index is taken from that
     * layer's {@code getIndex()}.
     *
     * @param i        index into the layer array of the last layer to compute
     * @param iNDArray network input
     * @param z        training flag
     * @return activations as produced by {@code ffToLayerActivationsDetached}
     */
    public List<INDArray> feedForwardToLayer(int i, INDArray iNDArray, boolean z) {
        try {
            int stopIdx = this.layers[i].getIndex();
            return ffToLayerActivationsDetached(z, FwdPassType.STANDARD, false, stopIdx, iNDArray, this.mask, null, true);
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
    }

    /**
     * Feeds the current input forward up to layer {@code i}.
     *
     * @param i index of the last layer to compute
     * @param z training flag
     * @return activations as produced by {@code ffToLayerActivationsDetached}
     */
    public List<INDArray> feedForwardToLayer(int i, boolean z) {
        INDArray currentInput = this.input;
        try {
            return ffToLayerActivationsDetached(z, FwdPassType.STANDARD, false, i, currentInput, this.mask, null, true);
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
    }

    /**
     * Checks that an array is in the workspace expected for its {@link ArrayType}. If not, throws
     * with a message naming the layer or preprocessor involved.
     *
     * @param layerWorkspaceMgr the workspace manager doing the validation
     * @param iNDArray          the array to check
     * @param arrayType         the array's role
     * @param i                 index of the layer (or of the layer whose preprocessor produced the array);
     *                          {@code i > 0} is passed as the last argument to {@code validateArrayLocation}
     * @param z                 true if the array came from the input preprocessor of layer {@code i}
     * @param str               location prefix for the error message
     * @throws IllegalStateException wrapping the {@link ND4JWorkspaceException} on validation failure
     */
    protected void validateArrayWorkspaces(LayerWorkspaceMgr layerWorkspaceMgr, INDArray iNDArray, ArrayType arrayType, int i, boolean z, String str) {
        try {
            layerWorkspaceMgr.validateArrayLocation(arrayType, iNDArray, false, i > 0);
        } catch (ND4JWorkspaceException e) {
            String layerName = this.layers[i].conf().getLayer().getLayerName();
            // Trailing space on both labels so the index is separated ("preprocessor 3", "layer 3")
            String origin = z ? "preprocessor " : "layer ";
            String nameInfo = layerName != null ? " - layer name \"" + layerName + "\"" : "";
            String className = z ? this.layerWiseConfigurations.getInputPreProcess(i).getClass().getName() : this.layers[i].getClass().getName();
            throw new IllegalStateException(str + ": array (" + arrayType + ") workspace validation failed (" + origin + i + nameInfo + " - class: " + className + ") - array is defined in incorrect workspace", e);
        }
    }

    /**
     * Runs a forward pass through layers {@code 0..i} and returns every activation, detached from workspaces
     * (ACTIVATIONS are configured with no workspace).
     *
     * <p>Requires that no workspace is open when called.
     *
     * @param fwdPassType STANDARD or RNN_ACTIVATE_WITH_STORED_STATE; anything else raises IllegalStateException
     * @param iNDArray    network input
     * @param iNDArray2   first mask argument passed to {@code setLayerMaskArrays}
     *                    (presumably the feature mask; confirm)
     * @param iNDArray3   second mask argument passed to {@code setLayerMaskArrays}
     *                    (presumably the label mask; confirm)
     * @return the (leveraged) input followed by one activation array per evaluated layer
     */
    protected synchronized List<INDArray> ffToLayerActivationsDetached(boolean z, @NonNull FwdPassType fwdPassType, boolean z2, int i, @NonNull INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, boolean z3) {
        if (fwdPassType == null) {
            throw new NullPointerException("fwdPassType is marked @NonNull but is null");
        }
        if (iNDArray == null) {
            throw new NullPointerException("input is marked @NonNull but is null");
        }
        setInput(iNDArray);
        setLayerMaskArrays(iNDArray2, iNDArray3);
        WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active in ffToLayerActivationsDetached");

        WorkspaceMode wsMode = z
                ? this.layerWiseConfigurations.getTrainingWorkspaceMode()
                : this.layerWiseConfigurations.getInferenceWorkspaceMode();
        LayerWorkspaceMgr mgr;
        if (wsMode == WorkspaceMode.NONE) {
            mgr = LayerWorkspaceMgr.noWorkspaces();
        } else {
            mgr = LayerWorkspaceMgr.builder()
                    .noWorkspaceFor(ArrayType.ACTIVATIONS)
                    .with(ArrayType.INPUT, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .build();
            if (iNDArray.isAttached()) {
                // Input already lives in a workspace (e.g. an iterator's): do not leverage out of it
                mgr.setNoLeverageOverride(iNDArray.data().getParentWorkspace().getId());
            }
            if (!z3) {
                // Inputs are kept after the call, so they must not live in the per-layer working memory
                mgr.setScopedOutFor(ArrayType.INPUT);
            }
        }
        mgr.setHelperWorkspacePointers(this.helperWorkspaces);

        List<INDArray> activations = new ArrayList<>();
        activations.add(mgr.leverageTo(ArrayType.INPUT, iNDArray));
        for (int i2 = 0; i2 <= i; i2++) {
            // try-with-resources closes the working-memory scope exactly once, also on failure
            try (MemoryWorkspace wsFFWorking = mgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) {
                InputPreProcessor preProcessor = getLayerWiseConfigurations().getInputPreProcess(i2);
                if (preProcessor != null) {
                    iNDArray = preProcessor.preProcess(iNDArray, getInputMiniBatchSize(), mgr);
                    validateArrayWorkspaces(mgr, iNDArray, ArrayType.ACTIVATIONS, i2, true, "Feed forward to layer (inference)");
                }
                Layer layer = this.layers[i2];
                if (fwdPassType == FwdPassType.STANDARD) {
                    iNDArray = layer.activate(iNDArray, z, mgr);
                } else if (fwdPassType == FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE) {
                    if (layer instanceof RecurrentLayer) {
                        iNDArray = ((RecurrentLayer) layer).rnnActivateUsingStoredState(iNDArray, z, z2, mgr);
                    } else if ((layer instanceof BaseWrapperLayer) && (((BaseWrapperLayer) layer).getUnderlying() instanceof RecurrentLayer)) {
                        iNDArray = ((RecurrentLayer) ((BaseWrapperLayer) layer).getUnderlying()).rnnActivateUsingStoredState(iNDArray, z, z2, mgr);
                    } else if (layer instanceof MultiLayerNetwork) {
                        List<INDArray> nested = ((MultiLayerNetwork) layer).rnnActivateUsingStoredState(iNDArray, z, z2);
                        iNDArray = nested.get(nested.size() - 1);
                    } else {
                        iNDArray = layer.activate(iNDArray, z, mgr);
                    }
                } else {
                    throw new IllegalStateException("Forward pass type not supported for this method: " + fwdPassType);
                }
                validateArrayWorkspaces(mgr, iNDArray, ArrayType.ACTIVATIONS, i2, false, "Feed forward to layer (inference)");
                activations.add(iNDArray);
            }
            if (z3) {
                this.layers[i2].clear();
            }
        }
        return activations;
    }

    /**
     * Runs a training-mode forward pass through layers {@code 0..i}, keeping inputs and activations in the
     * WS_ALL_LAYERS_ACT workspace.
     *
     * <p>Unless the training workspace mode is NONE, WS_ALL_LAYERS_ACT must already be open and active.
     *
     * @param i           index of the last layer to evaluate (inclusive)
     * @param fwdPassType STANDARD or RNN_ACTIVATE_WITH_STORED_STATE; anything else raises IllegalStateException
     * @param z           forwarded as the third argument of {@code rnnActivateUsingStoredState}
     *                    (presumably "store last state for TBPTT"; confirm)
     * @param iNDArray    network input
     * @param iNDArray2   first mask argument passed to {@code setLayerMaskArrays}
     * @param iNDArray3   second mask argument passed to {@code setLayerMaskArrays}
     * @return the (leveraged) input followed by one activation array per evaluated layer
     * @throws IllegalStateException if a layer returns null activations
     */
    protected synchronized List<INDArray> ffToLayerActivationsInWs(int i, @NonNull FwdPassType fwdPassType, boolean z, @NonNull INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3) {
        if (fwdPassType == null) {
            throw new NullPointerException("fwdPassType is marked @NonNull but is null");
        }
        if (iNDArray == null) {
            throw new NullPointerException("input is marked @NonNull but is null");
        }
        setInput(iNDArray);
        setLayerMaskArrays(iNDArray2, iNDArray3);

        LayerWorkspaceMgr mgr;
        if (this.layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE) {
            WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active in ffToLayerActivationsInWs when training workspace is set to NONE");
            mgr = LayerWorkspaceMgr.noWorkspaces();
        } else {
            mgr = LayerWorkspaceMgr.builder()
                    .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .build();
            if (iNDArray.isAttached()) {
                // Input already lives in a workspace (e.g. an iterator's): do not leverage out of it
                mgr.setNoLeverageOverride(iNDArray.data().getParentWorkspace().getId());
            }
            if (this.layerWiseConfigurations.getCacheMode() != CacheMode.NONE) {
                mgr.setWorkspace(ArrayType.FF_CACHE, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG);
                mgr.setWorkspace(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG);
            }
            WorkspaceUtils.assertOpenAndActive(WS_ALL_LAYERS_ACT, "ffToLayerActivationsInWs method requires workspace WS_ALL_LAYERS_ACT to be open");
        }
        mgr.setHelperWorkspacePointers(this.helperWorkspaces);

        List<INDArray> activations = new ArrayList<>();
        activations.add(mgr.leverageTo(ArrayType.INPUT, iNDArray));
        for (int i2 = 0; i2 <= i; i2++) {
            // try-with-resources closes the working-memory scope exactly once, also on failure
            try (MemoryWorkspace wsFFWorking = mgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) {
                InputPreProcessor preProcessor = getLayerWiseConfigurations().getInputPreProcess(i2);
                if (preProcessor != null) {
                    iNDArray = preProcessor.preProcess(iNDArray, getInputMiniBatchSize(), mgr);
                    validateArrayWorkspaces(mgr, iNDArray, ArrayType.ACTIVATIONS, i2, true, "Feed forward to layer (training)");
                }
                Layer layer = this.layers[i2];
                if (fwdPassType == FwdPassType.STANDARD) {
                    iNDArray = layer.activate(iNDArray, true, mgr);
                } else if (fwdPassType == FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE) {
                    if (layer instanceof RecurrentLayer) {
                        iNDArray = ((RecurrentLayer) layer).rnnActivateUsingStoredState(iNDArray, true, z, mgr);
                    } else if ((layer instanceof BaseWrapperLayer) && (((BaseWrapperLayer) layer).getUnderlying() instanceof RecurrentLayer)) {
                        iNDArray = ((RecurrentLayer) ((BaseWrapperLayer) layer).getUnderlying()).rnnActivateUsingStoredState(iNDArray, true, z, mgr);
                    } else if (layer instanceof MultiLayerNetwork) {
                        List<INDArray> nested = ((MultiLayerNetwork) layer).rnnActivateUsingStoredState(iNDArray, true, z);
                        iNDArray = nested.get(nested.size() - 1);
                    } else {
                        iNDArray = layer.activate(iNDArray, true, mgr);
                    }
                } else {
                    throw new IllegalStateException("FwdPassType not supported for this method: " + fwdPassType);
                }
                if (iNDArray == null) {
                    throw new IllegalStateException("Layer " + i2 + " returned null activations");
                }
                validateArrayWorkspaces(mgr, iNDArray, ArrayType.ACTIVATIONS, i2, false, "Feed forward to layer (training)");
                validateArrayWorkspaces(mgr, layer.input(), ArrayType.INPUT, i2, false, "Feed forward to layer (training)");
                activations.add(iNDArray);
            }
        }
        return activations;
    }

    /**
     * Runs a forward pass through layers {@code 0..i} and returns only the output of layer {@code i}.
     *
     * <p>Memory layout:
     * <ul>
     *   <li>Even and odd layers use two different activation workspaces (WS_LAYER_ACT_1 / WS_LAYER_ACT_2). Each
     *       layer reads its input from the other one.</li>
     *   <li>A layer's activation workspace is therefore closed only after the following layer has run.</li>
     *   <li>The final layer's activations are either scoped out of all workspaces or placed in the
     *       caller-supplied {@code memoryWorkspace}.</li>
     * </ul>
     *
     * @param z               true to use the training workspace mode and run layers in training mode
     * @param fwdPassType     STANDARD or RNN_TIMESTEP; anything else raises IllegalArgumentException
     * @param i               index of the last layer to evaluate (inclusive)
     * @param iNDArray        network input
     * @param iNDArray2       first mask argument passed to {@code setLayerMaskArrays}
     * @param iNDArray3       second mask argument passed to {@code setLayerMaskArrays}
     * @param memoryWorkspace optional output workspace. When non-null and not a DummyWorkspace, it must already
     *                        be open, and the caller remains responsible for closing it.
     * @return output activations of layer {@code i}
     * @throws IllegalStateException if an output workspace is supplied while the workspace mode is NONE
     */
    protected INDArray outputOfLayerDetached(boolean z, @NonNull FwdPassType fwdPassType, int i, @NonNull INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, MemoryWorkspace memoryWorkspace) {
        if (fwdPassType == null) {
            throw new NullPointerException("fwdPassType is marked @NonNull but is null");
        }
        if (iNDArray == null) {
            throw new NullPointerException("input is marked @NonNull but is null");
        }
        setInput(iNDArray);
        setLayerMaskArrays(iNDArray2, iNDArray3);

        // A DummyWorkspace is treated exactly like "no output workspace supplied"
        final boolean hasOutputWs = memoryWorkspace != null && !(memoryWorkspace instanceof DummyWorkspace);
        if (hasOutputWs) {
            Preconditions.checkState(memoryWorkspace.isScopeActive(), "Workspace \"" + memoryWorkspace.getId() + "\" was provided for the network/layer outputs. When provided, this workspace must be opened before calling the output method; furthermore, closing the workspace is the responsibility of the user");
        } else {
            WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active in outputOfLayerDetached", true);
        }

        WorkspaceMode wsMode = z
                ? this.layerWiseConfigurations.getTrainingWorkspaceMode()
                : this.layerWiseConfigurations.getInferenceWorkspaceMode();
        LayerWorkspaceMgr mgrEven;
        LayerWorkspaceMgr mgrOdd;
        if (wsMode == WorkspaceMode.NONE) {
            mgrEven = LayerWorkspaceMgr.noWorkspaces();
            mgrOdd = mgrEven;
            if (hasOutputWs) {
                throw new IllegalStateException("Workspace \"" + memoryWorkspace.getId() + "\" was provided for the network/layer outputs, however " + (z ? "training" : "inference") + " workspace mode is set to NONE. Cannot put output activations into the specified workspace if workspaces are disabled for the network. use getConfiguration().setTraining/InferenceWorkspaceMode(WorkspaceMode.ENABLED)");
            }
        } else {
            // Even layers write to ACT_1 and read from ACT_2; odd layers do the reverse
            mgrEven = LayerWorkspaceMgr.builder()
                    .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.ACTIVATIONS, WS_LAYER_ACT_1, this.WS_LAYER_ACT_X_CONFIG)
                    .with(ArrayType.INPUT, WS_LAYER_ACT_2, this.WS_LAYER_ACT_X_CONFIG)
                    .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .build();
            mgrOdd = LayerWorkspaceMgr.builder()
                    .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.ACTIVATIONS, WS_LAYER_ACT_2, this.WS_LAYER_ACT_X_CONFIG)
                    .with(ArrayType.INPUT, WS_LAYER_ACT_1, this.WS_LAYER_ACT_X_CONFIG)
                    .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .build();
        }
        mgrEven.setHelperWorkspacePointers(this.helperWorkspaces);
        mgrOdd.setHelperWorkspacePointers(this.helperWorkspaces);

        // Previous layer's activation workspace: closed once the current layer has consumed its output
        MemoryWorkspace wsActCloseNext = null;
        // Activation workspace opened for the current layer; still non-null after the loop only on failure
        MemoryWorkspace wsActCurrent = null;
        MemoryWorkspace initialWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace();
        // Cleanup wraps the WHOLE loop: closing workspaces per iteration would free the next layer's input
        try {
            for (int i2 = 0; i2 <= i; i2++) {
                LayerWorkspaceMgr mgr = i2 % 2 == 0 ? mgrEven : mgrOdd;
                // No previous activation workspace exists yet for layer 0, so its input goes to working memory
                if (i2 == 0 && wsMode != WorkspaceMode.NONE) {
                    mgr.setWorkspace(ArrayType.INPUT, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG);
                }
                try (MemoryWorkspace wsFFWorking = mgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) {
                    wsActCurrent = mgr.notifyScopeEntered(ArrayType.ACTIVATIONS);
                    // Activation workspaces are not closed in nested order; pin the "previous" workspace explicitly
                    wsActCurrent.setPreviousWorkspace(initialWorkspace);
                    if (i2 == 0 && iNDArray.isAttached()) {
                        mgr.setNoLeverageOverride(iNDArray.data().getParentWorkspace().getId());
                    }
                    InputPreProcessor preProcessor = getLayerWiseConfigurations().getInputPreProcess(i2);
                    if (preProcessor != null) {
                        iNDArray = preProcessor.preProcess(iNDArray, getInputMiniBatchSize(), mgr);
                        validateArrayWorkspaces(mgr, iNDArray, ArrayType.ACTIVATIONS, i2, true, "Output of layer (inference)");
                    }
                    if (i2 == i) {
                        if (hasOutputWs) {
                            mgr.setWorkspace(ArrayType.ACTIVATIONS, memoryWorkspace.getId(), memoryWorkspace.getWorkspaceConfiguration());
                        } else {
                            mgr.setScopedOutFor(ArrayType.ACTIVATIONS);
                        }
                    }
                    Layer layer = this.layers[i2];
                    if (fwdPassType == FwdPassType.STANDARD) {
                        iNDArray = layer.activate(iNDArray, z, mgr);
                    } else if (fwdPassType == FwdPassType.RNN_TIMESTEP) {
                        if (layer instanceof RecurrentLayer) {
                            iNDArray = ((RecurrentLayer) layer).rnnTimeStep(reshapeTimeStepInput(iNDArray), mgr);
                        } else if ((layer instanceof BaseWrapperLayer) && (((BaseWrapperLayer) layer).getUnderlying() instanceof RecurrentLayer)) {
                            iNDArray = ((RecurrentLayer) ((BaseWrapperLayer) layer).getUnderlying()).rnnTimeStep(reshapeTimeStepInput(iNDArray), mgr);
                        } else if (layer instanceof MultiLayerNetwork) {
                            iNDArray = ((MultiLayerNetwork) layer).rnnTimeStep(reshapeTimeStepInput(iNDArray));
                        } else {
                            iNDArray = layer.activate(iNDArray, false, mgr);
                        }
                    } else {
                        throw new IllegalArgumentException("Unsupported forward pass type for this method: " + fwdPassType);
                    }
                    layer.clear();
                    validateArrayWorkspaces(mgr, iNDArray, ArrayType.ACTIVATIONS, i2, false, "Output of layer (inference)");
                    // The previous layer's activations (this layer's input) are no longer needed
                    if (wsActCloseNext != null) {
                        wsActCloseNext.close();
                    }
                    wsActCloseNext = wsActCurrent;
                    wsActCurrent = null;
                }
                // Restore the default input workspace for the next use of this manager
                if (i2 == 0 && wsMode != WorkspaceMode.NONE) {
                    mgr.setWorkspace(ArrayType.INPUT, WS_LAYER_ACT_2, this.WS_LAYER_ACT_X_CONFIG);
                }
            }
        } finally {
            if (wsActCloseNext != null) {
                wsActCloseNext.close();
            }
            if (wsActCurrent != null) {
                // Only reachable on failure; a single close() may not suffice if the scope was re-entered
                while (wsActCurrent.isScopeActive()) {
                    wsActCurrent.close();
                }
            }
            Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
        }
        // Post-conditions are checked only on success, so they never mask an exception from the loop
        if (hasOutputWs) {
            Preconditions.checkState(memoryWorkspace.isScopeActive(), "Expected output workspace to still be open at end of outputOfLayerDetached, but it is closed. This suggests an implementation or layer workspace problem");
        } else {
            WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active at the end of outputOfLayerDetached", true);
        }
        return iNDArray;
    }

    /**
     * Turns a rank-2 array of shape [a, b] into a rank-3 array of shape [a, b, 1].
     *
     * <p>Arrays of any other rank are returned unchanged.
     */
    private INDArray reshapeTimeStepInput(INDArray iNDArray) {
        if (iNDArray.rank() != 2) {
            return iNDArray;
        }
        long[] dims = iNDArray.shape();
        return iNDArray.reshape(new long[]{dims[0], dims[1], 1});
    }

    /**
     * Inference-mode forward pass over the stored input; equivalent to {@code feedForward(false)}.
     */
    public List<INDArray> feedForward() {
        final boolean training = false;
        return feedForward(training);
    }

    /**
     * Stores {@code iNDArray} as the network input, then runs an inference-mode forward pass.
     *
     * @throws IllegalStateException if {@code iNDArray} is null
     */
    public List<INDArray> feedForward(INDArray iNDArray) {
        if (iNDArray != null) {
            setInput(iNDArray);
            return feedForward();
        }
        throw new IllegalStateException("Unable to perform feed forward; no input found");
    }

    /**
     * Inference-mode forward pass with temporary mask arrays.
     *
     * <p>The masks are applied for the duration of this call and are cleared afterwards, including when the
     * forward pass throws, so they do not leak into later calls.
     *
     * @param iNDArray  network input
     * @param iNDArray2 first mask argument passed to {@code setLayerMaskArrays}
     * @param iNDArray3 second mask argument passed to {@code setLayerMaskArrays}
     * @return the input followed by one activation array per layer
     */
    public List<INDArray> feedForward(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3) {
        setLayerMaskArrays(iNDArray2, iNDArray3);
        try {
            return feedForward(iNDArray);
        } finally {
            clearLayerMaskArrays();
        }
    }

    /**
     * Returns the gradient computed by the most recent backward pass (the stored {@code gradient} field).
     */
    @Override
    public Gradient gradient() {
        final Gradient current = this.gradient;
        return current;
    }

    /**
     * Pairs the current gradient with the current score.
     */
    @Override
    public Pair<Gradient, Double> gradientAndScore() {
        Gradient currentGradient = gradient();
        double currentScore = score();
        return new Pair<>(currentGradient, Double.valueOf(currentScore));
    }

    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    /**
     * Creates an independent copy of this network.
     *
     * <p>The copy:
     * <ul>
     *   <li>is built from a cloned configuration and initialised with duplicated parameters;</li>
     *   <li>receives a duplicate of the updater state when a solver exists;</li>
     *   <li>gets every layer that is a FrozenLayer here re-wrapped as a FrozenLayer.</li>
     * </ul>
     */
    public MultiLayerNetwork m161clone() {
        if (!this.initCalled) {
            init();
        }
        MultiLayerNetwork copy = new MultiLayerNetwork(this.layerWiseConfigurations.m30clone());
        copy.init(params().dup(), false);
        if (this.solver != null) {
            INDArray updaterState = getUpdater().getStateViewArray();
            if (updaterState != null) {
                copy.getUpdater().setStateViewArray(copy, updaterState.dup(), false);
            }
        }
        if (!hasAFrozenLayer()) {
            return copy;
        }
        Layer[] copiedLayers = copy.getLayers();
        for (int idx = 0; idx < this.layers.length; idx++) {
            if (this.layers[idx] instanceof FrozenLayer) {
                copiedLayers[idx] = new FrozenLayer(copy.getLayer(idx));
            }
        }
        copy.setLayers(copiedLayers);
        return copy;
    }

    /**
     * Reports whether any layer except the last one is a FrozenLayer.
     *
     * <p>NOTE(review): the final layer (index {@code layers.length - 1}) is never inspected, whereas
     * {@code clone} re-wraps frozen layers at every index. Confirm whether excluding the output layer is
     * intentional.
     */
    protected boolean hasAFrozenLayer() {
        final int limit = this.layers.length - 1;
        int idx = 0;
        while (idx < limit) {
            if (this.layers[idx] instanceof FrozenLayer) {
                return true;
            }
            idx++;
        }
        return false;
    }

    /**
     * Returns the network parameters.
     *
     * @param z when true, returns the shared flattened parameter array. When false, builds a new 'f'-order
     *          flattened array from each layer's non-null {@code params()}.
     */
    public INDArray params(boolean z) {
        if (!z) {
            List<INDArray> perLayer = new ArrayList<>();
            for (Layer layer : getLayers()) {
                INDArray layerParams = layer.params();
                if (layerParams == null) {
                    continue;
                }
                perLayer.add(layerParams);
            }
            return Nd4j.toFlattened('f', perLayer);
        }
        return params();
    }

    /**
     * Returns the flattened parameter array shared by all layers (not a copy).
     */
    @Override
    public INDArray params() {
        final INDArray shared = this.flattenedParams;
        return shared;
    }

    /**
     * Replaces the network parameters with {@code iNDArray}.
     *
     * <p>Cases:
     * <ul>
     *   <li>Same array as the current one: nothing to do.</li>
     *   <li>Same length as the current flattened parameters: values are copied in place, so existing layer
     *       views stay valid.</li>
     *   <li>Otherwise: each layer with parameters receives a view of the matching slice of {@code iNDArray}.</li>
     * </ul>
     */
    @Override
    public void setParams(INDArray iNDArray) {
        if (this.flattenedParams == iNDArray) {
            return;
        }
        if (this.flattenedParams != null && iNDArray.length() == this.flattenedParams.length()) {
            this.flattenedParams.assign(iNDArray);
            return;
        }
        if (this.flattenedParams == null) {
            this.flattenedParams = iNDArray.dup();
        }
        // long offset: an int offset wraps negative once the total parameter count exceeds Integer.MAX_VALUE
        long offset = 0L;
        for (int i2 = 0; i2 < getLayers().length; i2++) {
            Layer layer = getLayer(i2);
            long numParams = layer.numParams();
            if (numParams <= 0) {
                continue; // layers without parameters (e.g. subsampling) get no slice
            }
            layer.setParams(iNDArray.get(NDArrayIndex.interval(0L, 0L, true), NDArrayIndex.interval(offset, offset + numParams)));
            offset += numParams;
        }
    }

    /**
     * Not supported by this implementation.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void setParamsViewArray(INDArray iNDArray) {
        final String reason = "Not yet implemented";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Returns the flattened gradient view array (the stored {@code flattenedGradients} field, not a copy).
     */
    @Override
    public INDArray getGradientsViewArray() {
        final INDArray gradientsView = this.flattenedGradients;
        return gradientsView;
    }

    /**
     * Hands each layer that has parameters a view of its slice of the flattened gradient array.
     *
     * <p>Slices are taken in layer order and laid out contiguously.
     */
    @Override
    public void setBackpropGradientsViewArray(INDArray iNDArray) {
        // long offset: an int offset wraps negative once the total parameter count exceeds Integer.MAX_VALUE
        long offset = 0L;
        for (Layer layer : this.layers) {
            long numParams = layer.numParams();
            if (numParams == 0) {
                continue;
            }
            layer.setBackpropGradientsViewArray(iNDArray.get(NDArrayIndex.interval(0L, 0L, true), NDArrayIndex.interval(offset, offset + numParams)));
            offset += numParams;
        }
    }

    /**
     * Not supported for a whole network.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public TrainingConfig getConfig() {
        final String reason = "Not supported";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Returns the length of the flattened parameter array, initialising the network first if needed.
     *
     * <p>Returns 0 when there is no flattened parameter array.
     */
    @Override
    public long numParams() {
        if (!isInitCalled()) {
            init();
        }
        return this.flattenedParams != null ? this.flattenedParams.length() : 0L;
    }

    /**
     * Sums {@code numParams(z)} over all layers.
     *
     * @param z forwarded unchanged to each layer's {@code numParams(boolean)}
     * @return total parameter count as a long
     */
    @Override
    public long numParams(boolean z) {
        // Accumulate in a long: the previous int accumulator silently truncated counts above Integer.MAX_VALUE
        long total = 0L;
        for (Layer layer : this.layers) {
            total += layer.numParams(z);
        }
        return total;
    }

    /**
     * Computes the F1 score on the features and labels of {@code dataSet}.
     */
    @Override
    public double f1Score(org.nd4j.linalg.dataset.api.DataSet dataSet) {
        INDArray features = dataSet.getFeatures();
        INDArray labels = dataSet.getLabels();
        return f1Score(features, labels);
    }

    /**
     * Trains for {@code i} epochs over {@code dataSetIterator}.
     *
     * @param dataSetIterator training data; must support reset when {@code i > 1}
     * @param i               number of epochs, must be positive
     * @throws IllegalArgumentException if {@code i <= 0}, or if {@code i > 1} and the iterator cannot be reset
     */
    public void fit(@NonNull DataSetIterator dataSetIterator, int i) {
        if (dataSetIterator == null) {
            throw new NullPointerException("iterator is marked @NonNull but is null");
        }
        // Message typos fixed ("much be", "usingiterator thas")
        Preconditions.checkArgument(i > 0, "Number of epochs must be > 0. Got numEpochs = %s", i);
        Preconditions.checkArgument(i == 1 || dataSetIterator.resetSupported(), "Cannot perform multiple epochs training using iterator that does not support resetting (iterator.resetSupported() returned false)");
        for (int epoch = 0; epoch < i; epoch++) {
            fit(dataSetIterator);
        }
    }

    /**
     * Trains for a single epoch over {@code dataSetIterator}.
     *
     * <p>On OutOfMemoryError a memory crash dump is written before the error is rethrown.
     */
    @Override
    public void fit(DataSetIterator dataSetIterator) {
        try {
            fitHelper(dataSetIterator);
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
    }

    /**
     * Runs one training epoch over {@code dataSetIterator}.
     *
     * <p>Behaviour:
     * <ul>
     *   <li>When the iterator supports async use, it is wrapped in an AsyncDataSetIterator. That wrapper is
     *       shut down when the epoch ends, including when training throws, so its prefetch is not leaked.</li>
     *   <li>Training stops early at the first DataSet whose features or labels are null.</li>
     *   <li>TruncatedBPTT configurations go through {@code doTruncatedBPTT}; all others go through the solver,
     *       which is created lazily.</li>
     * </ul>
     */
    private synchronized void fitHelper(DataSetIterator dataSetIterator) {
        final DataSetIterator iter;
        final boolean isAsync;
        if (dataSetIterator.asyncSupported()) {
            iter = new AsyncDataSetIterator(dataSetIterator, Math.min(Nd4j.getAffinityManager().getNumberOfDevices() * 2, 2), true);
            isAsync = true;
        } else {
            iter = dataSetIterator;
            isAsync = false;
        }
        try {
            for (TrainingListener listener : this.trainingListeners) {
                listener.onEpochStart(this);
            }
            LayerWorkspaceMgr workspaceMgr = getLayerWiseConfigurations().getTrainingWorkspaceMode() == WorkspaceMode.NONE
                    ? LayerWorkspaceMgr.noWorkspaces()
                    : LayerWorkspaceMgr.builder()
                            .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                            .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                            .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                            .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                            .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                            .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                            .with(ArrayType.UPDATER_WORKING_MEM, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                            .build();
            workspaceMgr.setHelperWorkspacePointers(this.helperWorkspaces);
            update(TaskUtils.buildTask(iter));
            if (!iter.hasNext() && iter.resetSupported()) {
                iter.reset();
            }
            long etlStart = System.currentTimeMillis();
            while (iter.hasNext()) {
                DataSet next = (DataSet) iter.next();
                // Time spent waiting for the iterator since the previous minibatch finished
                this.lastEtlTime.set(Long.valueOf(System.currentTimeMillis() - etlStart));
                if (next.getFeatures() == null || next.getLabels() == null) {
                    break;
                }
                boolean hasMaskArrays = next.hasMaskArrays();
                if (this.layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT) {
                    doTruncatedBPTT(next.getFeatures(), next.getLabels(), next.getFeaturesMaskArray(), next.getLabelsMaskArray(), workspaceMgr);
                } else {
                    if (hasMaskArrays) {
                        setLayerMaskArrays(next.getFeaturesMaskArray(), next.getLabelsMaskArray());
                    }
                    setInput(next.getFeatures());
                    setLabels(next.getLabels());
                    if (this.solver == null) {
                        // The solver must not be allocated inside any active workspace
                        try (MemoryWorkspace wsOut = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                            this.solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
                        }
                    }
                    this.solver.optimize(workspaceMgr);
                }
                if (hasMaskArrays) {
                    clearLayerMaskArrays();
                }
                etlStart = System.currentTimeMillis();
                synchronizeIterEpochCounts();
            }
            for (TrainingListener listener : this.trainingListeners) {
                listener.onEpochEnd(this);
            }
            clearLayersStates();
        } finally {
            // Always stop the background prefetch, also when training fails
            if (isAsync) {
                ((AsyncDataSetIterator) iter).shutdown();
            }
        }
        incrementEpochCount();
    }

    /**
     * Computes gradients for the given features and labels without applying an update.
     *
     * <p>On OutOfMemoryError a memory crash dump is written before the error is rethrown.
     *
     * @param iNDArray  features (non-null)
     * @param iNDArray2 labels (non-null)
     * @param iNDArray3 first mask argument forwarded to {@code calculateGradientsHelper}
     * @param iNDArray4 second mask argument forwarded to {@code calculateGradientsHelper}
     * @return the pair produced by {@code calculateGradientsHelper}
     */
    public Pair<Gradient, INDArray> calculateGradients(@NonNull INDArray iNDArray, @NonNull INDArray iNDArray2, INDArray iNDArray3, INDArray iNDArray4) {
        if (iNDArray == null) {
            throw new NullPointerException("features is marked @NonNull but is null");
        }
        if (iNDArray2 == null) {
            throw new NullPointerException("label is marked @NonNull but is null");
        }
        final Pair<Gradient, INDArray> result;
        try {
            result = calculateGradientsHelper(iNDArray, iNDArray2, iNDArray3, iNDArray4);
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
        return result;
    }

    /**
     * Sets input, labels and masks, runs the forward pass up to (but excluding) the output layer, feeds the
     * output layer its (pre-processed) input, and backpropagates.
     *
     * @param iNDArray  features
     * @param iNDArray2 labels
     * @param iNDArray3 features mask, may be null
     * @param iNDArray4 labels mask, may be null
     * @return gradients paired with the detached gradient with respect to the network input (may be null)
     */
    private Pair<Gradient, INDArray> calculateGradientsHelper(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, INDArray iNDArray4) {
        setInput(iNDArray);
        setLabels(iNDArray2);
        setLayerMaskArrays(iNDArray3, iNDArray4);

        LayerWorkspaceMgr workspaceMgr;
        if (this.layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE) {
            workspaceMgr = LayerWorkspaceMgr.noWorkspaces();
        } else {
            workspaceMgr = LayerWorkspaceMgr.builder()
                    .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .build();
            if (this.layerWiseConfigurations.getCacheMode() != null) {
                // Cached forward-pass arrays share the all-layers activations workspace
                workspaceMgr.setWorkspace(ArrayType.FF_CACHE, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG);
            }
        }
        workspaceMgr.setHelperWorkspacePointers(this.helperWorkspaces);

        // try-with-resources: the activations workspace is always closed, and a failure while closing it is
        // recorded as a suppressed exception instead of masking an exception from the forward/backward pass.
        try (MemoryWorkspace wsActivations = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATIONS)) {
            // Forward pass stops before the output layer; backprop only needs the output layer's input set
            List<INDArray> activations = ffToLayerActivationsInWs(this.layers.length - 2, FwdPassType.STANDARD, false, this.input, this.mask, iNDArray3);
            for (TrainingListener listener : this.trainingListeners) {
                listener.onForwardPass(this, activations);
            }
            INDArray outputLayerInput = activations.get(activations.size() - 1);
            InputPreProcessor preProcessor = this.layerWiseConfigurations.getInputPreProcess(this.layers.length - 1);
            if (preProcessor != null) {
                outputLayerInput = preProcessor.preProcess(outputLayerInput, getInputMiniBatchSize(), workspaceMgr);
            }
            getOutputLayer().setInput(outputLayerInput, workspaceMgr);

            Pair<Gradient, INDArray> gradients = calcBackpropGradients(null, true, false, true);
            if (gradients.getSecond() != null) {
                // Detach so the input gradient stays valid after the workspace closes
                gradients.setSecond(gradients.getSecond().detach());
            }
            return gradients;
        }
    }

    /**
     * Backpropagates from the last layer towards layer 0 and collects per-parameter gradients. Iteration stops
     * at (and excludes) the first {@link FrozenLayer} encountered.
     *
     * @param iNDArray external epsilon for the last layer; {@code null} when the last layer computes its own error
     *                 from the labels. When null and workspaces are enabled, WS_ALL_LAYERS_ACT must be the open,
     *                 current workspace.
     * @param z        if true, the last layer must be an {@link IOutputLayer} and is given the network labels
     * @param z2       if true, {@link RecurrentLayer}s use truncated-BPTT backprop
     * @param z3       if true, return the detached gradient w.r.t. the network input as the second element;
     *                 otherwise it is null once layer 0 has been processed
     * @return the gradients paired with the input gradient, or {@code null} if {@code z} is true but the last
     *         layer is not an {@link IOutputLayer}
     * @throws IllegalStateException if the output layer needs labels and none are set
     */
    protected Pair<Gradient, INDArray> calcBackpropGradients(INDArray iNDArray, boolean z, boolean z2, boolean z3) {
        if (this.flattenedGradients == null) {
            initGradientsView();
        }
        DefaultGradient defaultGradient = new DefaultGradient(this.flattenedGradients);

        // The two managers differ only in the ACTIVATION_GRAD workspace (WS_LAYER_ACT_1 vs WS_LAYER_ACT_2) and are
        // alternated by layer parity, so a layer's output epsilon stays alive while the layer below consumes it.
        LayerWorkspaceMgr mgrEven;
        LayerWorkspaceMgr mgrOdd;
        if (this.layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE) {
            mgrEven = LayerWorkspaceMgr.noWorkspaces();
            mgrOdd = mgrEven;
            WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active in calcBackpropGradients when training workspace is set to none");
        } else {
            mgrEven = LayerWorkspaceMgr.builder()
                    .with(ArrayType.ACTIVATIONS, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.ACTIVATION_GRAD, WS_LAYER_ACT_1, this.WS_LAYER_ACT_X_CONFIG)
                    .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .build();
            mgrOdd = LayerWorkspaceMgr.builder()
                    .with(ArrayType.ACTIVATIONS, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.ACTIVATION_GRAD, WS_LAYER_ACT_2, this.WS_LAYER_ACT_X_CONFIG)
                    .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .build();
            if (iNDArray == null) {
                WorkspaceUtils.assertOpenActiveAndCurrent(WS_ALL_LAYERS_ACT, "calcBackpropGradients method requires workspace WS_ALL_LAYERS_ACT to be open when workspaces are used");
            }
        }
        mgrEven.setHelperWorkspacePointers(this.helperWorkspaces);
        mgrOdd.setHelperWorkspacePointers(this.helperWorkspaces);

        // Gradient entries in the order they are produced; applied to defaultGradient after the loop
        LinkedList<Triple<String, INDArray, Character>> gradientList = new LinkedList<>();
        Pair<Gradient, INDArray> currPair = null;
        MemoryWorkspace wsActGradCloseNext = null; // ACTIVATION_GRAD ws of the layer processed last iteration
        MemoryWorkspace wsActGradTemp = null;      // ACTIVATION_GRAD ws of the layer being processed
        MemoryWorkspace initialWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace();
        try {
            for (int i = this.layers.length - 1; i >= 0 && !(this.layers[i] instanceof FrozenLayer); i--) {
                LayerWorkspaceMgr workspaceMgr = i % 2 == 0 ? mgrEven : mgrOdd;
                boolean isLastLayer = i == this.layers.length - 1;
                if (z && isLastLayer) {
                    if (!(getOutputLayer() instanceof IOutputLayer)) {
                        log.warn("Warning: final layer isn't output layer. You cannot use backprop without an output layer.");
                        return null;
                    }
                    IOutputLayer outputLayer = (IOutputLayer) getOutputLayer();
                    if (this.labels == null && outputLayer.needsLabels()) {
                        throw new IllegalStateException("No labels found");
                    }
                    outputLayer.setLabels(this.labels);
                }

                wsActGradTemp = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATION_GRAD);
                // try-with-resources guarantees BP working memory is closed even if this layer's backprop throws
                try (MemoryWorkspace wsBPWorking = workspaceMgr.notifyScopeEntered(ArrayType.BP_WORKING_MEM)) {
                    // These workspaces are not closed in strict nesting order; pointing both back at the initial
                    // workspace presumably keeps the "current" workspace correct when they close.
                    wsActGradTemp.setPreviousWorkspace(initialWorkspace);
                    wsBPWorking.setPreviousWorkspace(initialWorkspace);

                    // Epsilon: the external one for the last layer, otherwise the layer above's input gradient.
                    // Used by both the standard and TBPTT paths (currPair is still null on the last layer).
                    INDArray eps = isLastLayer ? iNDArray : currPair.getSecond();
                    if (!z2) {
                        currPair = this.layers[i].backpropGradient(eps, workspaceMgr);
                    } else if (this.layers[i] instanceof RecurrentLayer) {
                        currPair = ((RecurrentLayer) this.layers[i]).tbpttBackpropGradient(eps, this.layerWiseConfigurations.getTbpttBackLength(), workspaceMgr);
                    } else {
                        currPair = this.layers[i].backpropGradient(eps, workspaceMgr);
                    }
                    if (currPair.getSecond() != null) {
                        validateArrayWorkspaces(workspaceMgr, currPair.getSecond(), ArrayType.ACTIVATION_GRAD, i, false, "Backprop");
                    }

                    Gradient layerGradient = currPair.getFirst();
                    for (Map.Entry<String, INDArray> entry : layerGradient.gradientForVariable().entrySet()) {
                        String key = entry.getKey();
                        gradientList.addLast(new Triple<>(i + "_" + key, entry.getValue(), layerGradient.flatteningOrderForVariable(key)));
                    }

                    if (getLayerWiseConfigurations().getInputPreProcess(i) != null) {
                        currPair = new Pair<>(currPair.getFirst(), this.layerWiseConfigurations.getInputPreProcess(i).backprop(currPair.getSecond(), getInputMiniBatchSize(), workspaceMgr));
                        if (i > 0 && currPair.getSecond() != null) {
                            validateArrayWorkspaces(workspaceMgr, currPair.getSecond(), ArrayType.ACTIVATION_GRAD, i, true, "Backprop");
                        }
                    }

                    if (i == 0) {
                        if (z3 && currPair.getSecond() != null) {
                            currPair.setSecond(currPair.getSecond().detach());
                        } else {
                            currPair.setSecond(null);
                        }
                    }

                    // The previous layer's epsilon has now been consumed; this layer's epsilon must survive into
                    // the next iteration, so it becomes the one to close next time.
                    if (wsActGradCloseNext != null) {
                        wsActGradCloseNext.close();
                    }
                    wsActGradCloseNext = wsActGradTemp;
                    wsActGradTemp = null;
                }
            }

            if (this.layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE) {
                WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active in calcBackpropGradients when training workspace is set to none");
            } else if (iNDArray == null) {
                WorkspaceUtils.assertOpenActiveAndCurrent(WS_ALL_LAYERS_ACT, "calcBackpropGradients: WS_ALL_LAYERS_ACT is no longer the currently open/active workspace");
            }

            for (Triple<String, INDArray, Character> triple : gradientList) {
                defaultGradient.setGradientFor(triple.getFirst(), triple.getSecond(), triple.getThird());
            }
            // currPair is null when no layer was processed (e.g. the last layer is frozen)
            return new Pair<>(defaultGradient, currPair == null ? null : currPair.getSecond());
        } finally {
            if (wsActGradCloseNext != null) {
                wsActGradCloseNext.close();
            }
            if (wsActGradTemp != null) {
                wsActGradTemp.close();
            }
            Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
        }
    }

    /**
     * Fits one minibatch using truncated BPTT. The time axis (dimension 2) is split into windows of
     * {@code tbpttFwdLength} steps (the last window may be shorter). For each window, input/labels/masks are set,
     * the solver is optimized, and the RNN state is carried over via {@link #updateRnnStateWithTBPTTState()}.
     *
     * <p>Non-3d inputs/labels, or inputs and labels with different time-series lengths, are skipped with a warning.
     *
     * @param iNDArray           features
     * @param iNDArray2          labels
     * @param iNDArray3          features mask, may be null
     * @param iNDArray4          labels mask, may be null
     * @param layerWorkspaceMgr  workspace manager passed to the solver
     * @throws ArithmeticException if a window boundary does not fit in an int
     */
    protected void doTruncatedBPTT(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, INDArray iNDArray4, LayerWorkspaceMgr layerWorkspaceMgr) {
        if (iNDArray.rank() != 3 || iNDArray2.rank() != 3) {
            log.warn("Cannot do truncated BPTT with non-3d inputs or labels. Expect input with shape [miniBatchSize,nIn,timeSeriesLength], got " + Arrays.toString(iNDArray.shape()) + "\tand labels with shape " + Arrays.toString(iNDArray2.shape()));
            return;
        }
        if (iNDArray.size(2) != iNDArray2.size(2)) {
            log.warn("Input and label time series have different lengths: {} input length, {} label length", Long.valueOf(iNDArray.size(2)), Long.valueOf(iNDArray2.size(2)));
            return;
        }
        int fwdLength = this.layerWiseConfigurations.getTbpttFwdLength();
        update(TaskUtils.buildTask(iNDArray, iNDArray2));
        long timeSeriesLength = iNDArray.size(2);
        // Ceiling division: a trailing partial window gets its own subset
        long nSubsets = timeSeriesLength / fwdLength;
        if (timeSeriesLength % fwdLength != 0) {
            nSubsets++;
        }
        rnnClearPreviousState();
        // long counter and long arithmetic: an int index times fwdLength could overflow before widening
        for (long subset = 0; subset < nSubsets; subset++) {
            long start = subset * fwdLength;
            long end = Math.min(start + fwdLength, timeSeriesLength);
            INDArray[] subsets = getSubsetsForTbptt(Math.toIntExact(start), Math.toIntExact(end), iNDArray, iNDArray2, iNDArray3, iNDArray4);
            setInput(subsets[0]);
            setLabels(subsets[1]);
            setLayerMaskArrays(subsets[2], subsets[3]);
            if (this.solver == null) {
                // Build the solver outside any workspace so its state is not workspace-scoped
                try (MemoryWorkspace wsOutside = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                    this.solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
                }
            }
            this.solver.optimize(layerWorkspaceMgr);
            updateRnnStateWithTBPTTState();
        }
        rnnClearPreviousState();
        clearLayerMaskArrays();
    }

    /**
     * Slices features, labels and (optional) masks to the time range {@code [i, i2)}.
     *
     * @return a 4-element array: features, labels, features mask, labels mask (masks null when not provided)
     */
    private INDArray[] getSubsetsForTbptt(int i, int i2, INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, INDArray iNDArray4) {
        // Features/labels are sliced on dimension 2, masks on dimension 1
        INDArray featureWindow = iNDArray.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(i, i2));
        INDArray labelWindow = iNDArray2.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(i, i2));
        INDArray featureMaskWindow = iNDArray3 == null ? null : iNDArray3.get(NDArrayIndex.all(), NDArrayIndex.interval(i, i2));
        INDArray labelMaskWindow = iNDArray4 == null ? null : iNDArray4.get(NDArrayIndex.all(), NDArrayIndex.interval(i, i2));
        return new INDArray[]{featureWindow, labelWindow, featureMaskWindow, labelMaskWindow};
    }

    /**
     * Copies each recurrent layer's TBPTT state into its "previous state", recursing into nested
     * {@link MultiLayerNetwork} layers.
     */
    public void updateRnnStateWithTBPTTState() {
        for (Layer layer : this.layers) {
            if (layer instanceof RecurrentLayer) {
                RecurrentLayer rnn = (RecurrentLayer) layer;
                rnn.rnnSetPreviousState(rnn.rnnGetTBPTTState());
            } else if (layer instanceof MultiLayerNetwork) {
                ((MultiLayerNetwork) layer).updateRnnStateWithTBPTTState();
            }
        }
    }

    /**
     * Returns the training listeners registered on this network.
     *
     * @return the network's own listener collection (not a copy)
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public Collection<TrainingListener> getListeners() {
        final Collection<TrainingListener> listeners = trainingListeners;
        return listeners;
    }

    /**
     * Returns the training listeners registered on this network.
     *
     * @return the network's own listener collection (not a copy)
     * @deprecated use {@link #getListeners()} instead
     */
    @Deprecated
    public Collection<TrainingListener> getTrainingListeners() {
        return trainingListeners;
    }

    /**
     * Replaces the registered training listeners. The collection is passed to every layer and to the solver
     * (if one exists), then copied into this network's listener collection. Initializes the network first if
     * its layers have not been created yet.
     *
     * @param collection the new listeners; null clears all listeners on this network
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setListeners(Collection<TrainingListener> collection) {
        if (this.layers == null) {
            init();
        }
        for (Layer layer : this.layers) {
            layer.setListeners(collection);
        }
        if (this.solver != null) {
            this.solver.setListeners(collection);
        }
        // Snapshot before clearing: the argument may be this.trainingListeners itself
        // (e.g. setListeners(getListeners())), in which case clear() would also empty it.
        List<TrainingListener> snapshot = collection == null ? null : new ArrayList<>(collection);
        this.trainingListeners.clear();
        if (snapshot != null) {
            this.trainingListeners.addAll(snapshot);
        }
    }

    /**
     * Appends listeners to this network and refreshes the solver's listeners (if a solver exists).
     * Null entries, and a null array, are ignored, matching {@code setListeners(TrainingListener...)}.
     *
     * @param trainingListenerArr listeners to add
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void addListeners(TrainingListener... trainingListenerArr) {
        if (trainingListenerArr != null) {
            for (TrainingListener listener : trainingListenerArr) {
                // A stored null would cause a NullPointerException later when listeners are notified
                if (listener != null) {
                    this.trainingListeners.add(listener);
                }
            }
        }
        if (this.solver != null) {
            this.solver.setListeners(this.trainingListeners);
        }
    }

    /**
     * Varargs form of {@link #setListeners(Collection)}; null entries are dropped.
     *
     * @param trainingListenerArr the new listeners; null or empty clears all listeners
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setListeners(TrainingListener... trainingListenerArr) {
        List<TrainingListener> nonNullListeners = new ArrayList<>();
        if (trainingListenerArr != null) {
            for (TrainingListener candidate : trainingListenerArr) {
                if (candidate == null) {
                    continue;
                }
                nonNullListeners.add(candidate);
            }
        }
        setListeners(nonNullListeners);
    }

    /**
     * Predicts a class index for each example: the index of the largest output value, computed in test mode.
     *
     * @param iNDArray features, one example per row
     * @return one predicted class index per example
     */
    @Override // org.deeplearning4j.nn.api.Classifier
    public int[] predict(INDArray iNDArray) {
        INDArray output = output(iNDArray, Layer.TrainingMode.TEST);
        int[] iArr = new int[(int) iNDArray.size(0)];
        // argMax (largest value) instead of BLAS iamax (largest *absolute* value): the two disagree whenever
        // outputs can be negative, e.g. with an identity or tanh output activation.
        // Assumes a rank-2 output of shape [minibatch, nOut] -- TODO confirm for non-standard output layers.
        INDArray predictedClasses = output.argMax(1);
        if (iNDArray.isRowVectorOrScalar()) {
            iArr[0] = predictedClasses.getInt(0);
        } else {
            for (int i = 0; i < iArr.length; i++) {
                iArr[i] = predictedClasses.getInt(i);
            }
        }
        return iArr;
    }

    /**
     * Predicts a label name for each example, using the DataSet's label-name list.
     *
     * @param dataSet data whose features are classified; must carry a label name list
     * @return predicted label names, one per example
     */
    @Override // org.deeplearning4j.nn.api.Classifier
    public List<String> predict(org.nd4j.linalg.dataset.api.DataSet dataSet) {
        Preconditions.checkState(dataSet.getLabelNamesList() != null, "This method can only be used when the DataSet contains a label name list");
        int[] classIndices = predict(dataSet.getFeatures());
        List<String> labelNames = new ArrayList<>(classIndices.length);
        for (int classIndex : classIndices) {
            labelNames.add(dataSet.getLabelName(classIndex));
        }
        return labelNames;
    }

    /**
     * Fits the network on one minibatch without mask arrays.
     *
     * @param iNDArray  features
     * @param iNDArray2 labels
     */
    @Override // org.deeplearning4j.nn.api.Classifier
    public void fit(INDArray iNDArray, INDArray iNDArray2) {
        INDArray noFeaturesMask = null;
        INDArray noLabelsMask = null;
        fit(iNDArray, iNDArray2, noFeaturesMask, noLabelsMask);
    }

    /**
     * Fits the network on one minibatch with optional masks. An {@link OutOfMemoryError} is reported via
     * {@code CrashReportingUtil.writeMemoryCrashDump} and rethrown.
     *
     * @param iNDArray  features
     * @param iNDArray2 labels
     * @param iNDArray3 features mask, may be null
     * @param iNDArray4 labels mask, may be null
     */
    public synchronized void fit(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, INDArray iNDArray4) {
        try {
            fitHelper(iNDArray, iNDArray2, iNDArray3, iNDArray4);
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
    }

    /**
     * Single-minibatch fit: sets input/labels/masks, then either runs truncated BPTT or one solver
     * optimization, and finally clears masks and layer state. Does nothing if the network has no parameters.
     *
     * @param iNDArray  features
     * @param iNDArray2 labels
     * @param iNDArray3 features mask, may be null
     * @param iNDArray4 labels mask, may be null
     */
    private void fitHelper(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, INDArray iNDArray4) {
        if (numParams() == 0) {
            return;
        }
        setInput(iNDArray);
        setLabels(iNDArray2);
        setLayerMaskArrays(iNDArray3, iNDArray4);
        update(TaskUtils.buildTask(iNDArray, iNDArray2));

        // Treat NONE like every other training path in this class (and keep null meaning "no workspaces")
        WorkspaceMode trainingMode = this.layerWiseConfigurations.getTrainingWorkspaceMode();
        LayerWorkspaceMgr workspaceMgr;
        if (trainingMode == null || trainingMode == WorkspaceMode.NONE) {
            workspaceMgr = LayerWorkspaceMgr.noWorkspaces();
        } else {
            workspaceMgr = LayerWorkspaceMgr.builder()
                    .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.UPDATER_WORKING_MEM, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .build();
        }
        workspaceMgr.setHelperWorkspacePointers(this.helperWorkspaces);

        if (this.layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT) {
            doTruncatedBPTT(iNDArray, iNDArray2, iNDArray3, iNDArray4, workspaceMgr);
        } else {
            if (this.solver == null) {
                // Build the solver outside any workspace so its state is not workspace-scoped
                try (MemoryWorkspace wsOutside = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                    this.solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
                }
            }
            this.solver.optimize(workspaceMgr);
        }
        clearLayerMaskArrays();
        clearLayersStates();
        synchronizeIterEpochCounts();
    }

    /**
     * Not supported on this network type; the exception message directs callers to {@code pretrainLayer}.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void fit(INDArray iNDArray, LayerWorkspaceMgr layerWorkspaceMgr) {
        final String reason = "Not supported: use pretrainLayer";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Fits the network on a single DataSet, including its mask arrays (if any).
     *
     * @param dataSet the minibatch to fit
     */
    @Override // org.deeplearning4j.nn.api.Classifier
    public void fit(org.nd4j.linalg.dataset.api.DataSet dataSet) {
        INDArray features = dataSet.getFeatures();
        INDArray targets = dataSet.getLabels();
        INDArray featuresMask = dataSet.getFeaturesMaskArray();
        INDArray targetsMask = dataSet.getLabelsMaskArray();
        fit(features, targets, featuresMask, targetsMask);
    }

    /**
     * Fits the network using integer class labels, which are converted to a one-hot outcome matrix sized by
     * the output layer's {@code nOut}.
     *
     * @param iNDArray features
     * @param iArr     class index per example
     * @throws IllegalStateException if the output layer's configuration does not define {@code nOut}
     */
    @Override // org.deeplearning4j.nn.api.Classifier
    public void fit(INDArray iNDArray, int[] iArr) {
        // Any FeedForwardLayer configuration defines nOut; casting to the concrete OutputLayer type
        // failed with a ClassCastException for other output-layer configurations.
        Object outputLayerConf = getOutputLayer().conf().getLayer();
        if (!(outputLayerConf instanceof FeedForwardLayer)) {
            throw new IllegalStateException("fit(INDArray, int[]) requires the output layer configuration to be a FeedForwardLayer to determine the number of classes; got: " + (outputLayerConf == null ? null : outputLayerConf.getClass().getName()));
        }
        int numClasses = Math.toIntExact(((FeedForwardLayer) outputLayerConf).getNOut());
        fit(iNDArray, FeatureUtil.toOutcomeMatrix(iArr, numClasses));
    }

    /**
     * Network output for the given input; {@code TRAIN} mode runs a training-mode forward pass, anything else
     * runs in inference mode.
     */
    public INDArray output(INDArray iNDArray, Layer.TrainingMode trainingMode) {
        boolean training = Layer.TrainingMode.TRAIN == trainingMode;
        return output(iNDArray, training);
    }

    /**
     * Network output for the given input with no mask arrays.
     *
     * @param z true for a training-mode forward pass
     */
    public INDArray output(INDArray iNDArray, boolean z) {
        INDArray noFeaturesMask = null;
        INDArray noLabelsMask = null;
        return output(iNDArray, z, noFeaturesMask, noLabelsMask);
    }

    /**
     * Network output with optional masks; no output workspace is supplied.
     *
     * @param iNDArray2 features mask, may be null
     * @param iNDArray3 labels mask, may be null
     */
    public INDArray output(INDArray iNDArray, boolean z, INDArray iNDArray2, INDArray iNDArray3) {
        MemoryWorkspace noOutputWorkspace = null;
        return output(iNDArray, z, iNDArray2, iNDArray3, noOutputWorkspace);
    }

    /**
     * Network output placed in the given workspace, with no mask arrays.
     */
    public INDArray output(INDArray iNDArray, boolean z, MemoryWorkspace memoryWorkspace) {
        INDArray featuresMask = null;
        INDArray labelsMask = null;
        return output(iNDArray, z, featuresMask, labelsMask, memoryWorkspace);
    }

    /**
     * Computes the output of the last layer via {@code outputOfLayerDetached}. An {@link OutOfMemoryError} is
     * reported via {@code CrashReportingUtil.writeMemoryCrashDump} and rethrown.
     *
     * @param iNDArray        input features
     * @param z               true for a training-mode forward pass
     * @param iNDArray2       features mask, may be null
     * @param iNDArray3       labels mask, may be null
     * @param memoryWorkspace workspace for the output array, may be null
     */
    public synchronized INDArray output(INDArray iNDArray, boolean z, INDArray iNDArray2, INDArray iNDArray3, MemoryWorkspace memoryWorkspace) {
        INDArray result;
        try {
            int lastLayerIndex = this.layers.length - 1;
            result = outputOfLayerDetached(z, FwdPassType.STANDARD, lastLayerIndex, iNDArray, iNDArray2, iNDArray3, memoryWorkspace);
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
        return result;
    }

    /**
     * Runs inference inside the WS_OUTPUT_MEM workspace and converts the result with {@code outputAdapter}.
     * A {@link ModelAdapter} receives the model and the raw (single-element) input/mask arrays instead of
     * the computed output.
     *
     * @param iNDArray      input features; must not be null
     * @param iNDArray2     features mask, may be null
     * @param iNDArray3     labels mask, may be null
     * @param outputAdapter converts the output (or the model) into the result; must not be null
     * @return the adapter's result
     * @throws NullPointerException if inputs or adapter are null
     */
    public synchronized <T> T output(@NonNull INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, @NonNull OutputAdapter<T> outputAdapter) {
        if (iNDArray == null) {
            throw new NullPointerException("inputs is marked @NonNull but is null");
        }
        if (outputAdapter == null) {
            throw new NullPointerException("outputAdapter is marked @NonNull but is null");
        }
        // try-with-resources closes the workspace exactly once; a close failure is attached as a suppressed
        // exception rather than masking (or repeating after) an exception from the adapter.
        try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(WS_ALL_LAYERS_ACT_CONFIG, WS_OUTPUT_MEM)) {
            if (outputAdapter instanceof ModelAdapter) {
                return ((ModelAdapter<T>) outputAdapter).apply(this, new INDArray[]{iNDArray}, new INDArray[]{iNDArray2}, new INDArray[]{iNDArray3});
            }
            return outputAdapter.apply(output(iNDArray, false, iNDArray2, iNDArray3, ws));
        }
    }

    /**
     * Network output for the given input in test (inference) mode.
     */
    public INDArray output(INDArray iNDArray) {
        Layer.TrainingMode inference = Layer.TrainingMode.TEST;
        return output(iNDArray, inference);
    }

    /**
     * Computes the network output for every remaining minibatch of the iterator and concatenates the results
     * along dimension 0. The iterator is consumed from its current position and is not reset here. Minibatches
     * with null features are skipped.
     *
     * @param dataSetIterator source of minibatches
     * @param z               true for a training-mode forward pass
     * @return all outputs concatenated along dimension 0
     * @throws IllegalStateException if output shapes differ beyond dimension 0, or no minibatch produced output
     */
    public INDArray output(DataSetIterator dataSetIterator, boolean z) {
        List<INDArray> outputs = new ArrayList<>();
        long[] firstShape = null;
        while (dataSetIterator.hasNext()) {
            DataSet dataSet = (DataSet) dataSetIterator.next();
            INDArray features = dataSet.getFeatures();
            if (features == null) {
                continue;
            }
            INDArray output = output(features, z, dataSet.getFeaturesMaskArray(), dataSet.getLabelsMaskArray());
            outputs.add(output);
            if (firstShape == null) {
                firstShape = output.shape();
                continue;
            }
            long[] shape = output.shape();
            Preconditions.checkState(firstShape.length == shape.length, "Error during forward pass:different minibatches have different output array ranks - first minibatch shape %s, last minibatch shape %s", firstShape, shape);
            for (int i = 1; i < shape.length; i++) {
                Preconditions.checkState(firstShape[i] == shape[i], "Current output shape does not match first output array shape at position %s: all dimensions must match other than the first dimension.\n For variable length output size/length use cases such as for RNNs with multiple sequence lengths, use one of the other (non iterator) output methods. First batch output shape: %s, current batch output shape: %s", Integer.valueOf(i), firstShape, shape);
            }
        }
        // Fail clearly rather than asking Nd4j.concat to concatenate zero arrays
        Preconditions.checkState(!outputs.isEmpty(), "Cannot compute output: the iterator did not provide any minibatch with non-null features");
        return Nd4j.concat(0, outputs.toArray(new INDArray[0]));
    }

    /**
     * Inference-mode output for every remaining minibatch of the iterator, concatenated along dimension 0.
     */
    public INDArray output(DataSetIterator dataSetIterator) {
        final boolean training = false;
        return output(dataSetIterator, training);
    }

    /**
     * F1 score of the network's inference-mode predictions against the given labels.
     *
     * <p>Also runs {@code feedForward(iNDArray)} and stores the labels on the network, as side effects.
     *
     * @param iNDArray  features
     * @param iNDArray2 labels
     * @return the F1 score computed by {@link Evaluation}
     */
    @Override // org.deeplearning4j.nn.api.Classifier
    public double f1Score(INDArray iNDArray, INDArray iNDArray2) {
        feedForward(iNDArray);
        setLabels(iNDArray2);
        INDArray predictions = output(iNDArray);
        Evaluation evaluation = new Evaluation();
        evaluation.eval(iNDArray2, predictions);
        return evaluation.f1();
    }

    /**
     * Number of label columns (size of dimension 1) of the labels currently set on the network.
     *
     * @deprecated reads from the stored labels; fails if none are set
     */
    @Override // org.deeplearning4j.nn.api.Classifier
    @Deprecated
    public int numLabels() {
        long labelColumns = this.labels.size(1);
        return (int) labelColumns;
    }

    /**
     * Score (loss plus regularization) of the DataSet, computed in inference mode.
     */
    public double score(DataSet dataSet) {
        final boolean training = false;
        return score(dataSet, training);
    }

    /**
     * Score (loss plus regularization) of the DataSet. An {@link OutOfMemoryError} is reported via
     * {@code CrashReportingUtil.writeMemoryCrashDump} and rethrown.
     *
     * @param z true to compute in training mode
     */
    public double score(DataSet dataSet, boolean z) {
        double result;
        try {
            result = scoreHelper(dataSet, z);
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
        return result;
    }

    /**
     * Forward pass up to the output layer, then the output layer's score including regularization.
     *
     * @param dataSet data to score
     * @param z       true for training-mode forward pass and training workspace mode
     * @return the score
     * @throws IllegalStateException if the final layer is not an {@link IOutputLayer}
     * @throws ArithmeticException   if the minibatch size does not fit in an int
     */
    private double scoreHelper(DataSet dataSet, boolean z) {
        boolean hasMaskArrays = dataSet.hasMaskArrays();
        if (hasMaskArrays) {
            setLayerMaskArrays(dataSet.getFeaturesMaskArray(), dataSet.getLabelsMaskArray());
        }
        if (!(getOutputLayer() instanceof IOutputLayer)) {
            throw new IllegalStateException("Cannot calculate score if final layer is not an instance of IOutputLayer. Final layer is of type: " + getOutputLayer().getClass());
        }
        WorkspaceMode workspaceMode = z ? this.layerWiseConfigurations.getTrainingWorkspaceMode() : this.layerWiseConfigurations.getInferenceWorkspaceMode();
        LayerWorkspaceMgr workspaceMgr;
        if (workspaceMode == WorkspaceMode.NONE) {
            workspaceMgr = LayerWorkspaceMgr.noWorkspaces();
        } else {
            workspaceMgr = LayerWorkspaceMgr.builder()
                    .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .noWorkspaceFor(ArrayType.ACTIVATIONS)
                    .noWorkspaceFor(ArrayType.INPUT)
                    .build();
        }
        workspaceMgr.setHelperWorkspacePointers(this.helperWorkspaces);

        INDArray outputLayerInput = outputOfLayerDetached(z, FwdPassType.STANDARD, this.layers.length - 2, dataSet.getFeatures(), dataSet.getFeaturesMaskArray(), dataSet.getLabelsMaskArray(), null);
        IOutputLayer outputLayer = (IOutputLayer) getOutputLayer();
        InputPreProcessor preProcessor = getLayerWiseConfigurations().getInputPreProcess(this.layers.length - 1);
        if (preProcessor != null) {
            outputLayerInput = preProcessor.preProcess(outputLayerInput, Math.toIntExact(dataSet.getFeatures().size(0)), workspaceMgr);
        }
        outputLayer.setInput(outputLayerInput, workspaceMgr);
        outputLayer.setLabels(dataSet.getLabels());

        // try-with-resources closes the FF working-memory workspace exactly once, even if scoring throws
        double score;
        try (MemoryWorkspace wsFF = workspaceMgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) {
            score = outputLayer.computeScore(calcRegularizationScore(true), z, workspaceMgr);
        }
        if (hasMaskArrays) {
            clearLayerMaskArrays();
        }
        clearLayersStates();
        return score;
    }

    /**
     * Per-example scores for every remaining minibatch of the iterator, flattened in 'f' order.
     *
     * @param z whether to add the regularization term to each example's score
     */
    public INDArray scoreExamples(DataSetIterator dataSetIterator, boolean z) {
        List<INDArray> perBatchScores = new ArrayList<>();
        while (dataSetIterator.hasNext()) {
            DataSet batch = (DataSet) dataSetIterator.next();
            perBatchScores.add(scoreExamples(batch, z));
        }
        return Nd4j.toFlattened('f', perBatchScores);
    }

    /**
     * Per-example scores for a single DataSet. An {@link OutOfMemoryError} is reported via
     * {@code CrashReportingUtil.writeMemoryCrashDump} and rethrown.
     *
     * @param z whether to add the regularization term to each example's score
     */
    public INDArray scoreExamples(DataSet dataSet, boolean z) {
        INDArray perExample;
        try {
            perExample = scoreExamplesHelper(dataSet, z);
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
        return perExample;
    }

    /**
     * Performs the per-example scoring for {@link #scoreExamples(DataSet, boolean)}.
     *
     * <p>Runs a detached forward pass up to the layer before the output layer, applies the output
     * layer's input preprocessor (if configured), then asks the output layer for per-example scores.
     * Layer states and mask arrays are cleared afterwards, even when scoring fails.
     *
     * @param dataSet data to score
     * @param z whether to add the regularization score; when false, 0.0 is used
     * @return the per-example scores
     * @throws UnsupportedOperationException if the final layer is not an {@link IOutputLayer}
     */
    private INDArray scoreExamplesHelper(DataSet dataSet, boolean z) {
        INDArray features = dataSet.getFeatures();
        INDArray featuresMask = dataSet.getFeaturesMaskArray();
        INDArray labelsMask = dataSet.getLabelsMaskArray();
        int outputLayerIdx = this.layers.length - 1;

        INDArray outputLayerInput = outputOfLayerDetached(false, FwdPassType.STANDARD, outputLayerIdx - 1, features, featuresMask, labelsMask, null);
        try {
            setLabels(dataSet.getLabels());
            setLayerMaskArrays(featuresMask, labelsMask);
            LayerWorkspaceMgr noWorkspaces = LayerWorkspaceMgr.noWorkspaces();
            if (!(getOutputLayer() instanceof IOutputLayer)) {
                throw new UnsupportedOperationException("Cannot calculate score with respect to labels without an OutputLayer");
            }
            IOutputLayer outputLayer = (IOutputLayer) getOutputLayer();
            InputPreProcessor preProcessor = this.layerWiseConfigurations.getInputPreProcess(outputLayerIdx);
            if (preProcessor != null) {
                outputLayerInput = preProcessor.preProcess(outputLayerInput, (int) features.size(0), noWorkspaces);
            }
            outputLayer.setLabels(dataSet.getLabels());
            outputLayer.setInput(outputLayerInput, noWorkspaces);
            // A plain 0.0 when regularization is excluded; this value must not depend on any
            // unrelated (evaluation) constant.
            double regularization = z ? calcRegularizationScore(true) : 0.0;
            return outputLayer.computeScoreForExamples(regularization, noWorkspaces);
        } finally {
            // Always release per-call state so a failure does not leak labels/masks into later calls.
            clearLayersStates();
            clearLayerMaskArrays();
        }
    }

    /** Fits the network using the input and labels currently stored on this network. */
    @Override // org.deeplearning4j.nn.api.Model
    public void fit() {
        INDArray storedInput = this.input;
        INDArray storedLabels = this.labels;
        fit(storedInput, storedLabels);
    }

    /**
     * Not implemented for this network.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void update(INDArray iNDArray, String str) {
        String reason = "Not implemented";
        throw new UnsupportedOperationException(reason);
    }

    /** @return the most recently stored score value of this network */
    @Override // org.deeplearning4j.nn.api.Model
    public double score() {
        final double current = score;
        return current;
    }

    /**
     * Overwrites the stored score.
     *
     * @param d the new score value
     */
    public void setScore(double d) {
        score = d;
    }

    /**
     * Computes gradient and score; delegates to {@link #computeGradientAndScore()}.
     *
     * <p>The supplied workspace manager is not used by this implementation.
     *
     * @param layerWorkspaceMgr ignored
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void computeGradientAndScore(LayerWorkspaceMgr layerWorkspaceMgr) {
        this.computeGradientAndScore();
    }

    /**
     * Computes the gradient and the score of the network for the currently stored input and labels.
     *
     * <p>The steps are:
     * <ol>
     * <li>Run a forward pass up to the layer before the output layer, then notify listeners via
     * {@code onForwardPass}.</li>
     * <li>Feed the (optionally preprocessed) activations to the output layer.</li>
     * <li>Backpropagate to set {@link #gradient}.</li>
     * <li>Compute {@link #score}, including the regularization term.</li>
     * <li>Notify listeners via {@code onBackwardPass} outside of any workspace.</li>
     * </ol>
     *
     * <p>Every workspace scope is opened with try-with-resources, so each scope is closed exactly
     * once. This holds even when the score computation or a listener throws.
     *
     * @throws DL4JException if the final layer is not an {@link IOutputLayer}
     */
    public void computeGradientAndScore() {
        if (!(getOutputLayer() instanceof IOutputLayer)) {
            throw new DL4JException("Cannot calculate gradient and score with respect to labels: final layer is not an IOutputLayer. Final layer class: " + getOutputLayer().getClass() + ". To calculate gradients and fit a network using backpropagation, the final layer must be an output layer");
        }

        LayerWorkspaceMgr mgr;
        if (this.layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE) {
            mgr = LayerWorkspaceMgr.noWorkspaces();
        } else {
            mgr = LayerWorkspaceMgr.builder()
                    .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .build();
            if (this.layerWiseConfigurations.getCacheMode() != null) {
                // Cached feed-forward arrays are placed in the all-layers activations workspace.
                mgr.setWorkspace(ArrayType.FF_CACHE, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG);
            }
        }

        boolean tbptt = this.layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT;
        FwdPassType fwdPassType = tbptt ? FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE : FwdPassType.STANDARD;
        synchronizeIterEpochCounts();

        try (MemoryWorkspace activationsWs = mgr.notifyScopeEntered(ArrayType.ACTIVATIONS)) {
            // Forward pass stops before the output layer; its input just needs to be set for backprop.
            List<INDArray> activations = ffToLayerActivationsInWs(this.layers.length - 2, fwdPassType, tbptt, this.input, this.mask, null);
            for (TrainingListener listener : this.trainingListeners) {
                listener.onForwardPass(this, activations);
            }

            INDArray outputLayerInput = activations.get(activations.size() - 1);
            InputPreProcessor outputPreProcessor = this.layerWiseConfigurations.getInputPreProcess(this.layers.length - 1);
            if (outputPreProcessor != null) {
                outputLayerInput = outputPreProcessor.preProcess(outputLayerInput, getInputMiniBatchSize(), mgr);
            }
            getOutputLayer().setInput(outputLayerInput, mgr);

            Pair<Gradient, INDArray> backprop = calcBackpropGradients(null, true, false, false);
            this.gradient = backprop == null ? null : (Gradient) backprop.getFirst();

            try (MemoryWorkspace ffWorkingWs = mgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) {
                this.score = ((IOutputLayer) getOutputLayer()).computeScore(calcRegularizationScore(true), true, mgr);
            }

            if (!this.trainingListeners.isEmpty()) {
                // Listeners run detached from all workspaces.
                try (MemoryWorkspace noWorkspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                    for (TrainingListener listener : this.trainingListeners) {
                        listener.onBackwardPass(this);
                    }
                }
            }

            getOutputLayer().clearNoiseWeightParams();
        }
    }

    /** Clears every layer, then drops the stored input, labels and solver. */
    @Override // org.deeplearning4j.nn.api.Model
    public void clear() {
        Layer[] all = this.layers;
        for (int idx = 0; idx < all.length; idx++) {
            all[idx].clear();
        }
        input = null;
        labels = null;
        solver = null;
    }

    /**
     * Applies parameter constraints on every layer, in layer order.
     *
     * @param i iteration count forwarded to each layer
     * @param i2 epoch count forwarded to each layer
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void applyConstraints(int i, int i2) {
        Layer[] all = this.layers;
        for (int idx = 0; idx < all.length; idx++) {
            all[idx].applyConstraints(i, i2);
        }
    }

    /**
     * Stores the network input, initializing the network first if it has no layers yet.
     *
     * <p>A non-null input also sets the input minibatch size on all layers from its first
     * dimension.
     *
     * @param iNDArray the input; may be null
     * @throws IllegalArgumentException if the input is non-null but has length 0
     */
    public void setInput(INDArray iNDArray) {
        this.input = iNDArray;
        if (this.layers == null) {
            init();
        }
        if (iNDArray == null) {
            return;
        }
        if (iNDArray.length() == 0) {
            throw new IllegalArgumentException("Invalid input: length 0 (shape: " + Arrays.toString(iNDArray.shape()) + ")");
        }
        int miniBatch = (int) iNDArray.size(0);
        setInputMiniBatchSize(miniBatch);
    }

    /**
     * Not supported; use {@link #setInput(INDArray)} instead.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public void setInput(INDArray iNDArray, LayerWorkspaceMgr layerWorkspaceMgr) {
        final String reason = "Not supported";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Returns the last layer of the network.
     *
     * <p>If that layer is a {@link FrozenLayerWithBackprop}, the wrapped inner layer is returned
     * instead.
     *
     * @return the effective output layer
     */
    public Layer getOutputLayer() {
        Layer[] all = getLayers();
        Layer last = all[all.length - 1];
        return (last instanceof FrozenLayerWithBackprop) ? ((FrozenLayerWithBackprop) last).getInsideLayer() : last;
    }

    /**
     * Alias for {@code setParams(INDArray)}.
     *
     * @param iNDArray the parameters to set
     */
    public void setParameters(INDArray iNDArray) {
        this.setParams(iNDArray);
    }

    /** @return the stored default {@link NeuralNetConfiguration} */
    public NeuralNetConfiguration getDefaultConfiguration() {
        final NeuralNetConfiguration cfg = defaultConfiguration;
        return cfg;
    }

    /** @return the labels currently stored on this network (may be null) */
    public INDArray getLabels() {
        final INDArray current = labels;
        return current;
    }

    /** @return the input currently stored on this network (may be null) */
    public INDArray getInput() {
        final INDArray current = input;
        return current;
    }

    /**
     * Stores the labels on this network.
     *
     * @param iNDArray the labels; may be null
     */
    public void setLabels(INDArray iNDArray) {
        labels = iNDArray;
    }

    /** @return the number of per-layer configurations in the multi-layer configuration */
    public int getnLayers() {
        return layerWiseConfigurations.getConfs().size();
    }

    /**
     * Returns the network's layer array (not a copy). Synchronized on this network instance.
     *
     * @return the layers
     */
    public synchronized Layer[] getLayers() {
        final Layer[] current = layers;
        return current;
    }

    /**
     * Returns the layer at the given index.
     *
     * @param i zero-based layer index
     * @return the layer at {@code i}
     * @throws IllegalArgumentException if {@code i} is out of range
     */
    public Layer getLayer(int i) {
        int count = this.layers.length;
        boolean inRange = i >= 0 && i < count;
        Preconditions.checkArgument(inRange, "Invalid layer index: layer index must be 0 to %s (inclusive), got index %s", count - 1, i);
        return this.layers[i];
    }

    /**
     * Looks up a layer by name.
     *
     * @param str the layer name
     * @return the layer, or null if no layer has that name
     */
    public Layer getLayer(String str) {
        final Layer found = layerMap.get(str);
        return found;
    }

    /**
     * Returns the layer names known to this network.
     *
     * @return a new, mutable list of layer names, in the iteration order of the internal layer map
     */
    public List<String> getLayerNames() {
        // Diamond instead of the raw ArrayList constructor keeps the result type-checked.
        return new ArrayList<>(this.layerMap.keySet());
    }

    /**
     * Replaces the layer array (stored as-is, not copied).
     *
     * @param layerArr the new layers
     */
    public void setLayers(Layer[] layerArr) {
        layers = layerArr;
    }

    /** @return the stored mask array (may be null) */
    public INDArray getMask() {
        final INDArray current = mask;
        return current;
    }

    /**
     * Stores the mask array.
     *
     * @param iNDArray the mask; may be null
     */
    public void setMask(INDArray iNDArray) {
        mask = iNDArray;
    }

    /** @return the stored mask array, same value as {@link #getMask()} */
    @Override // org.deeplearning4j.nn.api.Layer
    public INDArray getMaskArray() {
        final INDArray current = mask;
        return current;
    }

    /** @return always {@code false} for a multi-layer network */
    @Override // org.deeplearning4j.nn.api.Layer
    public boolean isPretrainLayer() {
        final boolean pretrain = false;
        return pretrain;
    }

    /** Calls {@code clearNoiseWeightParams()} on every layer, in layer order. */
    @Override // org.deeplearning4j.nn.api.Layer
    public void clearNoiseWeightParams() {
        Layer[] all = this.layers;
        for (int idx = 0; idx < all.length; idx++) {
            all[idx].clearNoiseWeightParams();
        }
    }

    /**
     * Not supported by this network.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public void allowInputModification(boolean z) {
        final String reason = "Not supported";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Propagates a mask array through every input preprocessor and layer, in order.
     *
     * <p>With a null mask, every layer is called with a null mask and null state, and the result
     * pairs null with the supplied state. Otherwise each preprocessor (if configured for that
     * index) and then each layer transforms the current mask and state. A null result from either
     * resets both to null.
     *
     * @param iNDArray the incoming mask; may be null
     * @param maskState the incoming mask state
     * @param i minibatch size forwarded to each preprocessor/layer
     * @return the mask and state after the last layer
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray iNDArray, MaskState maskState, int i) {
        if (iNDArray == null) {
            for (Layer layer : this.layers) {
                layer.feedForwardMaskArray(null, null, i);
            }
            return new Pair<>(null, maskState);
        }

        INDArray currentMask = iNDArray;
        MaskState currentState = maskState;
        for (int idx = 0; idx < this.layers.length; idx++) {
            InputPreProcessor pre = getLayerWiseConfigurations().getInputPreProcess(idx);
            if (pre != null) {
                Pair<INDArray, MaskState> afterPre = pre.feedForwardMaskArray(currentMask, currentState, i);
                currentMask = afterPre == null ? null : (INDArray) afterPre.getFirst();
                currentState = afterPre == null ? null : (MaskState) afterPre.getSecond();
            }
            Pair<INDArray, MaskState> afterLayer = this.layers[idx].feedForwardMaskArray(currentMask, currentState, i);
            currentMask = afterLayer == null ? null : (INDArray) afterLayer.getFirst();
            currentState = afterLayer == null ? null : (MaskState) afterLayer.getSecond();
        }
        return new Pair<>(currentMask, currentState);
    }

    /**
     * Not supported by this network.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public LayerHelper getHelper() {
        final String reason = "Not supported";
        throw new UnsupportedOperationException(reason);
    }

    /** @return {@link Layer.Type#MULTILAYER} */
    @Override // org.deeplearning4j.nn.api.Layer
    public Layer.Type type() {
        final Layer.Type kind = Layer.Type.MULTILAYER;
        return kind;
    }

    /**
     * Computes the network output for the stored input.
     *
     * @param trainingMode {@code TRAIN} for training-mode output, anything else for inference
     * @return the network output
     */
    public INDArray activate(Layer.TrainingMode trainingMode) {
        boolean train = trainingMode == Layer.TrainingMode.TRAIN;
        return output(this.input, train);
    }

    /**
     * Computes the network output for the given input.
     *
     * @param iNDArray the input
     * @param trainingMode {@code TRAIN} for training-mode output, anything else for inference
     * @return the network output
     */
    public INDArray activate(INDArray iNDArray, Layer.TrainingMode trainingMode) {
        boolean train = trainingMode == Layer.TrainingMode.TRAIN;
        return output(iNDArray, train);
    }

    /**
     * Backpropagates an external epsilon through the network.
     *
     * <p>The workspace manager argument is not used by this implementation.
     *
     * @param iNDArray the epsilon to backpropagate
     * @param layerWorkspaceMgr ignored
     * @return the gradient/epsilon pair from the backward pass
     * @throws UnsupportedOperationException if the final layer is an {@link IOutputLayer}
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public Pair<Gradient, INDArray> backpropGradient(INDArray iNDArray, LayerWorkspaceMgr layerWorkspaceMgr) {
        boolean endsWithOutputLayer = getOutputLayer() instanceof IOutputLayer;
        if (endsWithOutputLayer) {
            throw new UnsupportedOperationException("Cannot calculate gradients based on epsilon with OutputLayer");
        }
        return calcBackpropGradients(iNDArray, false, false, true);
    }

    /**
     * Sets this network's index when it is used as a layer.
     *
     * @param i the index
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public void setIndex(int i) {
        layerIndex = i;
    }

    /** @return the index previously set via {@link #setIndex(int)} */
    @Override // org.deeplearning4j.nn.api.Layer
    public int getIndex() {
        final int idx = layerIndex;
        return idx;
    }

    /** @return the iteration count stored in the multi-layer configuration */
    @Override // org.deeplearning4j.nn.api.Layer
    public int getIterationCount() {
        MultiLayerConfiguration cfg = getLayerWiseConfigurations();
        return cfg.getIterationCount();
    }

    /** @return the epoch count stored in the multi-layer configuration */
    @Override // org.deeplearning4j.nn.api.Layer
    public int getEpochCount() {
        MultiLayerConfiguration cfg = getLayerWiseConfigurations();
        return cfg.getEpochCount();
    }

    /**
     * Writes the iteration count into the multi-layer configuration.
     *
     * @param i the iteration count
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public void setIterationCount(int i) {
        MultiLayerConfiguration cfg = getLayerWiseConfigurations();
        cfg.setIterationCount(i);
    }

    /**
     * Writes the epoch count into the multi-layer configuration.
     *
     * @param i the epoch count
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public void setEpochCount(int i) {
        MultiLayerConfiguration cfg = getLayerWiseConfigurations();
        cfg.setEpochCount(i);
    }

    /**
     * Sums the regularization score of every layer, in layer order.
     *
     * @param z forwarded unchanged to each layer's {@code calcRegularizationScore}
     * @return the total regularization score
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public double calcRegularizationScore(boolean z) {
        double total = 0.0d;
        for (Layer layer : this.layers) {
            total += layer.calcRegularizationScore(z);
        }
        return total;
    }

    /**
     * Applies a full-network gradient.
     *
     * <p>Each variable key must have the form {@code <layerIndex>_<paramName>}. The entry is stored
     * in this network's gradient and forwarded to the matching layer's {@code update}. Finally, the
     * flattened gradient becomes the backprop gradients view array.
     *
     * @param gradient gradient covering all parameters
     * @throws IllegalArgumentException if the flattened gradient length differs from the parameter count
     * @throws IllegalStateException if a variable key has no {@code '_'} separator
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void update(Gradient gradient) {
        if (gradient.gradient().length() != numParams(true)) {
            throw new IllegalArgumentException("Invalid input: expect gradients array of length " + numParams(true));
        }
        for (Map.Entry<String, INDArray> entry : gradient.gradientForVariable().entrySet()) {
            String paramKey = entry.getKey();
            INDArray paramGradient = entry.getValue();
            int separator = paramKey.indexOf('_');
            if (separator < 0) {
                throw new IllegalStateException("Invalid param key: not have layer separator: \"" + paramKey + "\"");
            }
            int layerIdx = Integer.parseInt(paramKey.substring(0, separator));
            String paramName = paramKey.substring(separator + 1);
            this.gradient.gradientForVariable().put(paramKey, paramGradient);
            this.layers[layerIdx].update(paramGradient, paramName);
        }
        setBackpropGradientsViewArray(gradient.gradient());
    }

    /**
     * Not supported by this network.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public INDArray activate(boolean z, LayerWorkspaceMgr layerWorkspaceMgr) {
        UnsupportedOperationException unsupported = new UnsupportedOperationException();
        throw unsupported;
    }

    /**
     * Not supported by this network.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public INDArray activate(INDArray iNDArray, boolean z, LayerWorkspaceMgr layerWorkspaceMgr) {
        UnsupportedOperationException unsupported = new UnsupportedOperationException();
        throw unsupported;
    }

    /**
     * Forwards the input minibatch size to every layer; does nothing before layers exist.
     *
     * @param i the minibatch size
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public void setInputMiniBatchSize(int i) {
        Layer[] all = this.layers;
        if (all == null) {
            return;
        }
        for (int idx = 0; idx < all.length; idx++) {
            all[idx].setInputMiniBatchSize(i);
        }
    }

    /**
     * Returns the minibatch size of the stored input.
     *
     * @return the input's first dimension when the configuration is in minibatch mode, otherwise 1
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public int getInputMiniBatchSize() {
        if (!conf().isMiniBatch()) {
            return 1;
        }
        return (int) this.input.size(0);
    }

    /**
     * Not supported; see {@link #setLayerMaskArrays(INDArray, INDArray)} for per-layer masking.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public void setMaskArray(INDArray iNDArray) {
        UnsupportedOperationException unsupported = new UnsupportedOperationException();
        throw unsupported;
    }

    /**
     * Runs one RNN time step without an explicit output workspace.
     *
     * @param iNDArray the input for this step
     * @return the network output for this step
     */
    public INDArray rnnTimeStep(INDArray iNDArray) {
        final INDArray stepOutput = rnnTimeStep(iNDArray, null);
        return stepOutput;
    }

    /**
     * Runs one RNN time step using the stored RNN state, optionally into a given workspace.
     *
     * <p>If the input is rank 2, the output is rank 3 and the last layer is recurrent, the output
     * is reduced with {@code tensorAlongDimension(0, {1, 0})}. Out-of-memory errors trigger a
     * crash dump before being rethrown.
     *
     * @param iNDArray the input for this step
     * @param memoryWorkspace output workspace; may be null
     * @return the network output for this step
     */
    public INDArray rnnTimeStep(INDArray iNDArray, MemoryWorkspace memoryWorkspace) {
        try {
            boolean rank2Input = iNDArray.rank() == 2;
            int lastIdx = this.layers.length - 1;
            INDArray out = outputOfLayerDetached(false, FwdPassType.RNN_TIMESTEP, lastIdx, iNDArray, null, null, memoryWorkspace);
            boolean reduce = rank2Input && out.rank() == 3 && this.layers[this.layers.length - 1].type() == Layer.Type.RECURRENT;
            if (!reduce) {
                return out;
            }
            return out.tensorAlongDimension(0L, new int[]{1, 0});
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
    }

    /**
     * Returns the stored RNN state of a layer, unwrapping a {@link BaseWrapperLayer} if present.
     *
     * @param i layer index
     * @return the layer's previous RNN state
     * @throws IllegalArgumentException if the index is out of range or the layer is not recurrent
     */
    public Map<String, INDArray> rnnGetPreviousState(int i) {
        if (i < 0 || i >= this.layers.length) {
            throw new IllegalArgumentException("Invalid layer number");
        }
        Layer candidate = this.layers[i];
        Layer target = (candidate instanceof BaseWrapperLayer) ? ((BaseWrapperLayer) candidate).getUnderlying() : candidate;
        if (!(target instanceof RecurrentLayer)) {
            throw new IllegalArgumentException("Layer is not an RNN layer");
        }
        return ((RecurrentLayer) target).rnnGetPreviousState();
    }

    /**
     * Sets the stored RNN state of a layer, unwrapping a {@link BaseWrapperLayer} if present.
     *
     * @param i layer index
     * @param map the state to set
     * @throws IllegalArgumentException if the index is out of range or the layer is not recurrent
     */
    public void rnnSetPreviousState(int i, Map<String, INDArray> map) {
        if (i < 0 || i >= this.layers.length) {
            throw new IllegalArgumentException("Invalid layer number");
        }
        Layer candidate = this.layers[i];
        Layer target = (candidate instanceof BaseWrapperLayer) ? ((BaseWrapperLayer) candidate).getUnderlying() : candidate;
        if (target instanceof RecurrentLayer) {
            ((RecurrentLayer) target).rnnSetPreviousState(map);
            return;
        }
        throw new IllegalArgumentException("Layer is not an RNN layer");
    }

    /**
     * Clears stored RNN state in all layers.
     *
     * <p>This covers recurrent layers, nested {@link MultiLayerNetwork}s, and
     * {@link BaseWrapperLayer}s that wrap a recurrent layer. It does nothing when layers are not
     * initialized.
     */
    public void rnnClearPreviousState() {
        if (this.layers == null) {
            return;
        }
        for (Layer layer : this.layers) {
            if (layer instanceof RecurrentLayer) {
                ((RecurrentLayer) layer).rnnClearPreviousState();
            } else if (layer instanceof MultiLayerNetwork) {
                ((MultiLayerNetwork) layer).rnnClearPreviousState();
            } else if (layer instanceof BaseWrapperLayer) {
                Layer underlying = ((BaseWrapperLayer) layer).getUnderlying();
                if (underlying instanceof RecurrentLayer) {
                    ((RecurrentLayer) underlying).rnnClearPreviousState();
                }
            }
        }
    }

    /**
     * Runs a detached forward pass through all layers using stored RNN state and the stored mask.
     *
     * @param iNDArray the input
     * @param z training flag forwarded to the forward pass
     * @param z2 forwarded as the store-state flag of the forward pass
     * @return the activations of every layer
     */
    public List<INDArray> rnnActivateUsingStoredState(INDArray iNDArray, boolean z, boolean z2) {
        int lastIdx = this.layers.length - 1;
        return ffToLayerActivationsDetached(z, FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE, z2, lastIdx, iNDArray, this.mask, null, false);
    }

    /** @return the updater, creating the solver lazily if needed */
    public Updater getUpdater() {
        final boolean initializeIfMissing = true;
        return getUpdater(initializeIfMissing);
    }

    /**
     * Returns the optimizer's updater.
     *
     * <p>When {@code z} is true and no solver exists yet, one is created under a lock on this
     * network and given a fresh updater from {@code UpdaterCreator}.
     *
     * @param z whether to create the solver/updater if it is missing
     * @return the updater, or null if no solver exists and {@code z} is false
     */
    public Updater getUpdater(boolean z) {
        if (z && this.solver == null) {
            synchronized (this) {
                if (this.solver == null) {
                    this.solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
                    this.solver.getOptimizer().setUpdater(UpdaterCreator.getUpdater(this));
                }
            }
        }
        return this.solver == null ? null : this.solver.getOptimizer().getUpdater();
    }

    /**
     * Installs an updater on the optimizer, creating the solver first if none exists.
     *
     * @param updater the updater to install
     */
    public void setUpdater(Updater updater) {
        if (this.solver != null) {
            this.solver.getOptimizer().setUpdater(updater);
            return;
        }
        this.solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
        this.solver.getOptimizer().setUpdater(updater);
    }

    /**
     * Applies feature and label masks to the layers.
     *
     * <p>A non-null feature mask is propagated via
     * {@link #feedForwardMaskArray(INDArray, MaskState, int)} with {@code MaskState.Active}. A
     * non-null label mask is set on the last layer, but only when the output layer is an
     * {@link IOutputLayer}.
     *
     * @param iNDArray feature mask; may be null
     * @param iNDArray2 label mask; may be null
     */
    public void setLayerMaskArrays(INDArray iNDArray, INDArray iNDArray2) {
        if (iNDArray != null) {
            int miniBatch = (int) iNDArray.size(0);
            feedForwardMaskArray(iNDArray, MaskState.Active, miniBatch);
        }
        if (iNDArray2 != null && getOutputLayer() instanceof IOutputLayer) {
            this.layers[this.layers.length - 1].setMaskArray(iNDArray2);
        }
    }

    /** Sets every layer's mask array to null. */
    public void clearLayerMaskArrays() {
        Layer[] all = this.layers;
        for (int idx = 0; idx < all.length; idx++) {
            all[idx].setMaskArray(null);
        }
    }

    /**
     * Classification evaluation without explicit label names.
     *
     * @param dataSetIterator data to evaluate on
     * @return the evaluation result
     */
    public <T extends Evaluation> T evaluate(DataSetIterator dataSetIterator) {
        final T result = (T) evaluate(dataSetIterator, null);
        return result;
    }

    /**
     * Runs a regression evaluation with one column per iterator outcome.
     *
     * @param dataSetIterator data to evaluate on
     * @return the regression evaluation
     */
    public <T extends RegressionEvaluation> T evaluateRegression(DataSetIterator dataSetIterator) {
        RegressionEvaluation regression = new RegressionEvaluation(dataSetIterator.totalOutcomes());
        RegressionEvaluation[] results = (RegressionEvaluation[]) doEvaluation(dataSetIterator, regression);
        return (T) results[0];
    }

    /**
     * ROC evaluation with {@code 0} passed as the threshold-step argument.
     *
     * @param dataSetIterator data to evaluate on
     * @return the ROC evaluation
     * @deprecated use {@link #evaluateROC(DataSetIterator, int)} and pass the argument explicitly
     */
    @Deprecated
    public <T extends ROC> T evaluateROC(DataSetIterator dataSetIterator) {
        final int thresholdSteps = 0;
        return (T) evaluateROC(dataSetIterator, thresholdSteps);
    }

    /**
     * Runs a binary ROC evaluation.
     *
     * <p>When output-layer validation is enabled, the output layer's configuration is first
     * checked for compatibility with ROC evaluation.
     *
     * @param dataSetIterator data to evaluate on
     * @param i passed directly to the {@code org.deeplearning4j.eval.ROC} constructor
     * @return the ROC evaluation
     */
    @SuppressWarnings("unchecked")
    public <T extends ROC> T evaluateROC(DataSetIterator dataSetIterator, int i) {
        Layer outputLayer = getOutputLayer();
        if (getLayerWiseConfigurations().isValidateOutputLayerConfig()) {
            OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), ROC.class);
        }
        org.deeplearning4j.eval.ROC[] results = doEvaluation(dataSetIterator, new org.deeplearning4j.eval.ROC(i));
        // The concrete evaluation type is not statically a T; an explicit cast is required.
        return (T) results[0];
    }

    /**
     * Multi-class ROC evaluation with {@code 0} passed as the threshold-step argument.
     *
     * @param dataSetIterator data to evaluate on
     * @return the multi-class ROC evaluation
     * @deprecated use {@link #evaluateROCMultiClass(DataSetIterator, int)} and pass the argument explicitly
     */
    @Deprecated
    public <T extends ROCMultiClass> T evaluateROCMultiClass(DataSetIterator dataSetIterator) {
        final int thresholdSteps = 0;
        return (T) evaluateROCMultiClass(dataSetIterator, thresholdSteps);
    }

    /**
     * Runs a multi-class (one-vs-all) ROC evaluation.
     *
     * <p>When output-layer validation is enabled, the output layer's configuration is first
     * checked for compatibility with multi-class ROC evaluation.
     *
     * @param dataSetIterator data to evaluate on
     * @param i passed directly to the {@code org.deeplearning4j.eval.ROCMultiClass} constructor
     * @return the multi-class ROC evaluation
     */
    @SuppressWarnings("unchecked")
    public <T extends ROCMultiClass> T evaluateROCMultiClass(DataSetIterator dataSetIterator, int i) {
        Layer outputLayer = getOutputLayer();
        if (getLayerWiseConfigurations().isValidateOutputLayerConfig()) {
            OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), ROCMultiClass.class);
        }
        org.deeplearning4j.eval.ROCMultiClass[] results = doEvaluation(dataSetIterator, new org.deeplearning4j.eval.ROCMultiClass(i));
        // The concrete evaluation type is not statically a T; an explicit cast is required.
        return (T) results[0];
    }

    /**
     * Evaluates the network on every data set from the iterator, updating each evaluation in place.
     *
     * <p>Out-of-memory errors trigger a crash dump before being rethrown.
     *
     * @param dataSetIterator data to evaluate on
     * @param tArr evaluations to update
     * @return the same evaluation array that was passed in
     */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public <T extends IEvaluation> T[] doEvaluation(DataSetIterator dataSetIterator, T... tArr) {
        T[] evaluated;
        try {
            evaluated = (T[]) doEvaluationHelper(dataSetIterator, tArr);
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
        return evaluated;
    }

    /**
     * Evaluates the network on every data set from the iterator, updating each evaluation in place.
     *
     * <p>The iterator is reset first if it is exhausted and supports reset. It is wrapped in an
     * {@link AsyncDataSetIterator} when async is supported. During evaluation the training
     * workspace mode is temporarily switched to the inference workspace mode.
     *
     * <p>Data sets without features or labels are skipped. With truncated BPTT, each time series is
     * fed through {@link #rnnTimeStep(INDArray, MemoryWorkspace)} in segments of the configured
     * forward length. Otherwise a single detached forward pass is used. Evaluations always run
     * outside of any workspace.
     *
     * <p>Workspace scopes are closed exactly once via try-with-resources. The async iterator is
     * shut down, and the training workspace mode restored, even if evaluation throws.
     *
     * @param dataSetIterator data to evaluate on
     * @param tArr evaluations to update
     * @return the same evaluation array that was passed in
     */
    public <T extends IEvaluation> T[] doEvaluationHelper(DataSetIterator dataSetIterator, T... tArr) {
        if (!dataSetIterator.hasNext() && dataSetIterator.resetSupported()) {
            dataSetIterator.reset();
        }
        DataSetIterator iter = dataSetIterator.asyncSupported() ? new AsyncDataSetIterator(dataSetIterator, 2, true) : dataSetIterator;
        WorkspaceMode originalTrainingMode = this.layerWiseConfigurations.getTrainingWorkspaceMode();
        this.layerWiseConfigurations.setTrainingWorkspaceMode(this.layerWiseConfigurations.getInferenceWorkspaceMode());
        try {
            boolean useRnnSegments = this.layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT;
            MemoryWorkspace outputWs = getLayerWiseConfigurations().getInferenceWorkspaceMode() == WorkspaceMode.ENABLED
                    ? Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(WS_ALL_LAYERS_ACT_CONFIG, WS_OUTPUT_MEM)
                    : new DummyWorkspace();

            while (iter.hasNext()) {
                DataSet next = (DataSet) iter.next();
                if (next.getFeatures() == null || next.getLabels() == null) {
                    continue;
                }
                INDArray features = next.getFeatures();
                INDArray labels = next.getLabels();
                INDArray featuresMask = next.getFeaturesMaskArray();
                INDArray labelsMask = next.getLabelsMaskArray();

                if (useRnnSegments) {
                    rnnClearPreviousState();
                    int fwdLen = this.layerWiseConfigurations.getTbpttFwdLength();
                    long tsLength = features.size(2);
                    long nSubsets = tsLength / fwdLen;
                    if (tsLength % fwdLen != 0) {
                        // A trailing partial segment still needs to be evaluated.
                        nSubsets++;
                    }
                    for (int s = 0; s < nSubsets; s++) {
                        int startIdx = s * fwdLen;
                        int endIdx = (int) Math.min(startIdx + fwdLen, tsLength);
                        INDArray[] subsets = getSubsetsForTbptt(startIdx, endIdx, features, labels, featuresMask, labelsMask);
                        setLayerMaskArrays(subsets[2], subsets[3]);
                        try (MemoryWorkspace ws = outputWs.notifyScopeEntered()) {
                            INDArray outSub = rnnTimeStep(subsets[0], ws);
                            try (MemoryWorkspace noWs = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                                for (T evaluation : tArr) {
                                    evaluation.eval(subsets[1], outSub, subsets[3]);
                                }
                            }
                        }
                    }
                } else {
                    try (MemoryWorkspace ws = outputWs.notifyScopeEntered()) {
                        INDArray out = outputOfLayerDetached(false, FwdPassType.STANDARD, this.layers.length - 1, features, featuresMask, labelsMask, ws);
                        try (MemoryWorkspace noWs = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                            for (T evaluation : tArr) {
                                evaluation.eval(labels, out, labelsMask);
                            }
                        }
                    }
                }
                // Drop inputs/masks so invalidated arrays do not leak into the next iteration.
                clearLayersStates();
            }
        } finally {
            if (dataSetIterator.asyncSupported()) {
                ((AsyncDataSetIterator) iter).shutdown();
            }
            this.layerWiseConfigurations.setTrainingWorkspaceMode(originalTrainingMode);
        }
        return tArr;
    }

    /**
     * Classification evaluation with the given label names and a top-N value of 1.
     *
     * @param dataSetIterator data to evaluate on
     * @param list label names; may be null
     * @return the evaluation result
     */
    public Evaluation evaluate(DataSetIterator dataSetIterator, List<String> list) {
        final int topN = 1;
        return evaluate(dataSetIterator, list, topN);
    }

    /**
     * Returns the updater's state view array, creating the updater lazily if needed.
     *
     * @return the state view array, or null if no updater is available
     */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public INDArray updaterState() {
        Updater updater = getUpdater();
        return updater == null ? null : updater.getStateViewArray();
    }

    /**
     * Fits the network on a {@link MultiDataSet} holding exactly one features and one labels array.
     *
     * <p>The first feature mask and first label mask are used when their respective mask arrays
     * are present. Each is checked independently.
     *
     * @param multiDataSet the data to fit on
     * @throws DL4JInvalidInputException if there is not exactly one features and one labels array
     */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public void fit(MultiDataSet multiDataSet) {
        if (multiDataSet.getFeatures().length != 1 || multiDataSet.getLabels().length != 1) {
            throw new DL4JInvalidInputException("MultiLayerNetwork can't handle MultiDataSet with more than 1 features or labels array. Please consider use of ComputationGraph");
        }
        INDArray features = multiDataSet.getFeatures(0);
        INDArray labels = multiDataSet.getLabels(0);
        INDArray[] featuresMasks = multiDataSet.getFeaturesMaskArrays();
        // Checking the label masks themselves (not the feature masks) avoids an NPE when only
        // feature masks exist and avoids dropping label masks when only label masks exist.
        INDArray[] labelsMasks = multiDataSet.getLabelsMaskArrays();
        INDArray featuresMask = featuresMasks == null ? null : featuresMasks[0];
        INDArray labelsMask = labelsMasks == null ? null : labelsMasks[0];
        fit((org.nd4j.linalg.dataset.api.DataSet) new DataSet(features, labels, featuresMask, labelsMask));
    }

    /**
     * Trains the network for a fixed number of epochs on a {@link MultiDataSetIterator}.
     *
     * @param multiDataSetIterator training data; must support reset when more than one epoch is requested
     * @param i                    number of epochs, must be &gt; 0
     * @throws NullPointerException     if the iterator is null
     * @throws IllegalArgumentException if {@code i <= 0}, or if {@code i > 1} and the iterator cannot be reset
     */
    public void fit(@NonNull MultiDataSetIterator multiDataSetIterator, int i) {
        if (multiDataSetIterator == null) {
            throw new NullPointerException("iterator is marked @NonNull but is null");
        }
        Preconditions.checkArgument(i > 0, "Number of epochs must be > 0. Got numEpochs = %s", i);
        Preconditions.checkArgument(i == 1 || multiDataSetIterator.resetSupported(), "Cannot perform multiple epochs training using an iterator that does not support resetting (iterator.resetSupported() returned false)");
        for (int epoch = 0; epoch < i; epoch++) {
            fit(multiDataSetIterator);
        }
    }

    /**
     * Fits the network on a {@link MultiDataSetIterator} by adapting it to a {@link DataSetIterator}.
     */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public void fit(MultiDataSetIterator multiDataSetIterator) {
        DataSetIterator adapted = new MultiDataSetWrapperIterator(multiDataSetIterator);
        fit(adapted);
    }

    /**
     * Runs the given evaluations over a {@link MultiDataSetIterator}. The iterator is wrapped
     * as a {@link DataSetIterator}, and the work is delegated to the DataSetIterator overload.
     *
     * @return the same evaluation instances, populated
     */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public <T extends IEvaluation> T[] doEvaluation(MultiDataSetIterator multiDataSetIterator, T[] tArr) {
        DataSetIterator adapted = new MultiDataSetWrapperIterator(multiDataSetIterator);
        IEvaluation[] evaluations = tArr;
        return (T[]) doEvaluation(adapted, evaluations);
    }

    /**
     * Evaluates this network as a classifier over the given iterator.
     *
     * @param dataSetIterator data to evaluate on
     * @param list            class label names; when null, an attempt is made to read them from the iterator
     * @param i               forwarded to the {@code org.deeplearning4j.eval.Evaluation} constructor
     *                        (presumably top-N; confirm against that class)
     * @return the populated evaluation
     * @throws IllegalStateException if the network has no {@link IOutputLayer} output layer
     */
    public Evaluation evaluate(DataSetIterator dataSetIterator, List<String> list, int i) {
        boolean hasOutputLayer = this.layers != null && getOutputLayer() instanceof IOutputLayer;
        if (!hasOutputLayer) {
            throw new IllegalStateException("Cannot evaluate network with no output layer");
        }
        List<String> labelNames = list;
        if (labelNames == null) {
            try {
                labelNames = dataSetIterator.getLabels();
            } catch (Throwable ignored) {
                // Best effort: some iterators cannot supply label names; evaluate without them.
            }
        }
        Layer outputLayer = getOutputLayer();
        if (getLayerWiseConfigurations().isValidateOutputLayerConfig()) {
            OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), Evaluation.class);
        }
        org.deeplearning4j.eval.Evaluation result = new org.deeplearning4j.eval.Evaluation(labelNames, i);
        doEvaluation(dataSetIterator, (IEvaluation[]) new Evaluation[]{result});
        return result;
    }

    /**
     * Reports a one-time STANDALONE heartbeat event for this model. Calls after the first one
     * (while {@code initDone} is set) do nothing. The {@code task} argument is not used here.
     */
    protected void update(Task task) {
        if (!this.initDone) {
            this.initDone = true;
            Heartbeat hb = Heartbeat.getInstance();
            Task modelTask = ModelSerializer.taskByModel(this);
            hb.reportEvent(Event.STANDALONE, EnvironmentUtils.buildEnvironment(), modelTask);
        }
    }

    /**
     * Returns the layer summary table without input/output shape columns.
     */
    public String summary() {
        return summary((InputType) null);
    }

    /**
     * Builds a text table describing every layer of this network. Each row has the layer name
     * and type, nIn/nOut, parameter count and parameter shapes. The table is followed by the
     * total, trainable and frozen parameter counts.
     *
     * @param inputType network input type, or null. When non-null, two extra columns (InputShape,
     *                  OutputShape) are added. They are computed by propagating the type through
     *                  each layer's input preprocessor and then the layer's output type.
     * @return the formatted summary
     */
    public String summary(InputType inputType) {
        final boolean withShapes = inputType != null;
        List<String[]> rows = new ArrayList<>();
        if (withShapes) {
            rows.add(new String[]{"LayerName (LayerType)", "nIn,nOut", "TotalParams", "ParamsShape", "InputShape", "OutputShape"});
        } else {
            rows.add(new String[]{"LayerName (LayerType)", "nIn,nOut", "TotalParams", "ParamsShape"});
        }
        // Column widths start at the header widths and grow to fit the widest cell.
        String[] header = rows.get(0);
        int[] widths = new int[header.length];
        for (int c = 0; c < header.length; c++) {
            widths[c] = header[c].length();
        }

        // long rather than int: frozen parameter counts can exceed Integer.MAX_VALUE.
        long frozenParams = 0;
        InputType currentType = inputType;
        for (Layer layer : getLayers()) {
            String layerName = layer.conf().getLayer().getLayerName();
            if (layerName == null) {
                layerName = String.valueOf(layer.getIndex());
            }
            String layerType = summaryClassName(layer);
            String paramCount = String.valueOf(layer.numParams());
            String nIn = "-";
            String nOut = "-";
            String paramShapes = "-";
            String inShape = "";
            String outShape = "";

            if (withShapes) {
                InputPreProcessor preProcessor = getLayerWiseConfigurations().getInputPreProcess(layer.getIndex());
                inShape = currentType.toString();
                if (preProcessor != null) {
                    currentType = preProcessor.getOutputType(currentType);
                    inShape = inShape + "--> " + currentType.toString();
                }
                currentType = layer.conf().getLayer().getOutputType(layer.getIndex(), currentType);
                outShape = currentType.toString();
            }

            if (layer.numParams() > 0) {
                org.deeplearning4j.nn.conf.layers.Layer layerConf = layer.conf().getLayer();
                if (layer instanceof BidirectionalLayer) {
                    Bidirectional bidirectional = (Bidirectional) layerConf;
                    nIn = String.valueOf(bidirectional.getNIn());
                    nOut = String.valueOf(bidirectional.getNOut());
                } else if (layerConf instanceof FeedForwardLayer) {
                    // Parameterized layers whose configuration is not a FeedForwardLayer keep "-".
                    FeedForwardLayer ffConf = (FeedForwardLayer) layerConf;
                    nIn = String.valueOf(ffConf.getNIn());
                    nOut = String.valueOf(ffConf.getNOut());
                }
                List<String> shapes = new ArrayList<>();
                for (Map.Entry<String, INDArray> param : layer.paramTable().entrySet()) {
                    shapes.add(param.getKey() + ":" + ArrayUtils.toString(param.getValue().shape()));
                }
                // An empty param table used to throw StringIndexOutOfBoundsException here.
                if (!shapes.isEmpty()) {
                    paramShapes = String.join(", ", shapes);
                }
            }

            if (layer instanceof FrozenLayer) {
                frozenParams += layer.numParams();
                layerType = "Frozen " + summaryClassName(((FrozenLayer) layer).getInsideLayer());
            }

            String nameAndType = layerName + " (" + layerType + ")";
            String[] row = withShapes
                    ? new String[]{nameAndType, nIn + "," + nOut, paramCount, paramShapes, inShape, outShape}
                    : new String[]{nameAndType, nIn + "," + nOut, paramCount, paramShapes};
            for (int c = 0; c < row.length; c++) {
                widths[c] = Math.max(widths[c], row[c] == null ? 0 : row[c].length());
            }
            rows.add(row);
        }

        // Left-aligned columns separated by 3 spaces. The last column gets no trailing padding.
        StringBuilder fmt = new StringBuilder();
        int totalWidth = 0;
        for (int c = 0; c < widths.length; c++) {
            int w = c == widths.length - 1 ? widths[c] : widths[c] + 3;
            fmt.append("%-").append(w).append("s");
            totalWidth += w;
        }
        fmt.append("\n");
        String rowFormat = fmt.toString();
        String thickRule = StringUtils.repeat("=", totalWidth);

        StringBuilder sb = new StringBuilder("\n");
        sb.append(thickRule).append("\n");
        for (int r = 0; r < rows.size(); r++) {
            sb.append(String.format(rowFormat, (Object[]) rows.get(r)));
            if (r == 0) {
                // Separate the header row from the layer rows.
                sb.append(thickRule).append("\n");
            }
        }
        long totalParams = params().length();
        sb.append(StringUtils.repeat("-", totalWidth));
        sb.append(String.format("\n%30s %d", "Total Parameters: ", totalParams));
        sb.append(String.format("\n%30s %d", "Trainable Parameters: ", totalParams - frozenParams));
        sb.append(String.format("\n%30s %d", "Frozen Parameters: ", frozenParams));
        sb.append("\n");
        sb.append(thickRule);
        sb.append("\n");
        return sb.toString();
    }

    /** Returns the runtime class name of {@code obj} with its package prefix stripped. */
    private static String summaryClassName(Object obj) {
        String name = obj.getClass().getName();
        return name.substring(name.lastIndexOf('.') + 1);
    }

    /**
     * Returns a memory status report for this network, generated by
     * {@code CrashReportingUtil.generateMemoryStatus}.
     *
     * @param i         forwarded to the report generator (presumably the minibatch size — confirm)
     * @param inputType forwarded to the report generator
     */
    public String memoryInfo(int i, InputType inputType) {
        String status = CrashReportingUtil.generateMemoryStatus(this, i, inputType);
        return status;
    }

    /**
     * Calls {@code clear()} and then {@code clearNoiseWeightParams()} on every layer, in order.
     */
    public void clearLayersStates() {
        for (int idx = 0; idx < this.layers.length; idx++) {
            Layer current = this.layers[idx];
            current.clear();
            current.clearNoiseWeightParams();
        }
    }

    /**
     * Increments the configuration's epoch counter by one. The new iteration/epoch counts are
     * then pushed down to every layer.
     */
    public void incrementEpochCount() {
        int nextEpoch = this.layerWiseConfigurations.getEpochCount() + 1;
        this.layerWiseConfigurations.setEpochCount(nextEpoch);
        synchronizeIterEpochCounts();
    }

    /**
     * Copies the network-level iteration and epoch counts into each layer.
     */
    protected void synchronizeIterEpochCounts() {
        final int iter = getIterationCount();
        final int epoch = getEpochCount();
        for (int idx = 0; idx < this.layers.length; idx++) {
            this.layers[idx].setIterationCount(iter);
            this.layers[idx].setEpochCount(epoch);
        }
    }

    /**
     * Saves this network to {@code file}, including updater state.
     *
     * @throws IOException on write failure
     */
    public void save(File file) throws IOException {
        final boolean saveUpdater = true;
        save(file, saveUpdater);
    }

    /**
     * Saves this network to {@code file} via {@code ModelSerializer.writeModel}.
     *
     * @param z whether to also save the updater state
     * @throws IOException on write failure
     */
    public void save(File file, boolean z) throws IOException {
        ModelSerializer.writeModel(this, file, z); // delegate entirely to the serializer
    }

    /**
     * Restores a network previously written with {@link #save(File, boolean)}.
     *
     * @param z whether to also load the updater state
     * @throws IOException on read failure
     */
    public static MultiLayerNetwork load(File file, boolean z) throws IOException {
        MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(file, z);
        return restored;
    }

    /**
     * Converts this network to an equivalent {@link ComputationGraph} using
     * {@code NetworkUtils.toComputationGraph}.
     */
    public ComputationGraph toComputationGraph() {
        ComputationGraph graph = NetworkUtils.toComputationGraph(this);
        return graph;
    }

    /**
     * Returns a copy of this network whose parameters (and updater state, if any) are cast to
     * the given floating-point data type. Returns {@code this} if the parameters already have
     * that type.
     *
     * @param dataType target type; must be a floating-point type
     * @return a new network with converted parameters, or this network if no conversion is needed
     * @throws NullPointerException if {@code dataType} is null
     * @throws IllegalStateException if {@code dataType} is not a floating-point type
     */
    public MultiLayerNetwork convertDataType(@NonNull DataType dataType) {
        if (dataType == null) {
            throw new NullPointerException("dataType is marked @NonNull but is null");
        }
        Preconditions.checkState(dataType.isFPType(), "Invalid DataType: %s. Can only convert network to a floating point type", dataType);
        if (dataType == params().dataType()) {
            return this;
        }
        // try-with-resources: the hand-rolled close logic this replaces let an exception thrown by
        // close() replace the original failure. Here it is attached as a suppressed exception instead.
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            INDArray newParams = params().castTo(dataType);
            MultiLayerConfiguration newConf = MultiLayerConfiguration.fromJson(getLayerWiseConfigurations().toJson());
            newConf.setDataType(dataType);
            MultiLayerNetwork newNet = new MultiLayerNetwork(newConf);
            newNet.init(newParams, false);
            Updater updater = getUpdater(false);
            if (updater != null && updater.getStateViewArray() != null) {
                newNet.getUpdater(true).getStateViewArray().assign(updater.getStateViewArray());
            }
            return newNet;
        }
    }

    /**
     * Sets a fixed learning rate for all layers, delegating to {@code NetworkUtils.setLearningRate}.
     *
     * @param d new learning rate
     */
    public void setLearningRate(double d) {
        NetworkUtils.setLearningRate(this, d); // all layers
    }

    /**
     * Sets a learning-rate schedule for all layers, delegating to {@code NetworkUtils.setLearningRate}.
     *
     * @param iSchedule new learning-rate schedule
     */
    public void setLearningRate(ISchedule iSchedule) {
        NetworkUtils.setLearningRate(this, iSchedule); // all layers
    }

    /**
     * Sets a fixed learning rate for the layer at index {@code i}.
     *
     * @param i layer index
     * @param d new learning rate
     */
    public void setLearningRate(int i, double d) {
        NetworkUtils.setLearningRate(this, i, d); // single layer
    }

    /**
     * Sets a learning-rate schedule for the layer at index {@code i}.
     *
     * @param i         layer index
     * @param iSchedule new learning-rate schedule
     */
    public void setLearningRate(int i, ISchedule iSchedule) {
        NetworkUtils.setLearningRate(this, i, iSchedule); // single layer
    }

    /**
     * Returns the learning rate of the layer at index {@code i}, as reported by
     * {@code NetworkUtils.getLearningRate} (the boxed return type allows null).
     *
     * @param i layer index
     */
    public Double getLearningRate(int i) {
        Double lr = NetworkUtils.getLearningRate(this, i);
        return lr;
    }

    /**
     * Returns the output size (nOut) of the layer at index {@code i}.
     *
     * @param i layer index in {@code [0, numLayers - 1]}
     * @return nOut if the layer's configuration is a {@link FeedForwardLayer}, otherwise 0
     * @throws IllegalArgumentException if {@code i} is out of range
     */
    public int layerSize(int i) {
        // ">=": index == layers.length is out of range. The previous ">" let it through to an
        // ArrayIndexOutOfBoundsException instead of the intended IllegalArgumentException.
        if (i < 0 || i >= this.layers.length) {
            throw new IllegalArgumentException("Invalid layer index: " + i + ". Layer index must be between 0 and " + (this.layers.length - 1) + " inclusive");
        }
        org.deeplearning4j.nn.conf.layers.Layer layer = this.layers[i].conf().getLayer();
        if (!(layer instanceof FeedForwardLayer)) { // also covers null
            return 0;
        }
        return (int) ((FeedForwardLayer) layer).getNOut();
    }

    /**
     * Returns the input size (nIn) of the layer at index {@code i}.
     *
     * @param i layer index in {@code [0, numLayers - 1]}
     * @return nIn if the layer's configuration is a {@link FeedForwardLayer}, otherwise 0
     * @throws IllegalArgumentException if {@code i} is out of range
     */
    public int layerInputSize(int i) {
        // ">=": index == layers.length is out of range. The previous ">" let it through to an
        // ArrayIndexOutOfBoundsException instead of the intended IllegalArgumentException.
        if (i < 0 || i >= this.layers.length) {
            throw new IllegalArgumentException("Invalid layer index: " + i + ". Layer index must be between 0 and " + (this.layers.length - 1) + " inclusive");
        }
        org.deeplearning4j.nn.conf.layers.Layer layer = this.layers[i].conf().getLayer();
        if (!(layer instanceof FeedForwardLayer)) { // also covers null
            return 0;
        }
        return (int) ((FeedForwardLayer) layer).getNIn();
    }

    /**
     * Two networks are equal when their parameters, layer-wise configurations and updaters are
     * all equal. Comparison short-circuits in that order.
     */
    @Override
    public boolean equals(Object obj) {
        if (!(obj instanceof MultiLayerNetwork)) { // instanceof is false for null
            return false;
        }
        MultiLayerNetwork other = (MultiLayerNetwork) obj;
        if (!other.params().equals(params())) {
            return false;
        }
        if (!getLayerWiseConfigurations().equals(other.getLayerWiseConfigurations())) {
            return false;
        }
        return getUpdater().equals(other.getUpdater());
    }

    /**
     * Java serialization hook: writes the model (including updater state) through
     * {@code ModelSerializer} instead of default field serialization.
     */
    private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
        OutputStream target = objectOutputStream;
        ModelSerializer.writeModel((Model) this, target, true);
    }

    /**
     * Java serialization hook, the counterpart of {@code writeObject}. It restores a network via
     * {@code ModelSerializer} and copies its configurations into this instance. It then
     * re-initializes this instance and copies in the parameters, workspace configs and (when
     * present) updater state.
     */
    private void readObject(ObjectInputStream objectInputStream) throws ClassNotFoundException, IOException {
        MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork((InputStream) objectInputStream, true);
        this.defaultConfiguration = restored.defaultConfiguration.m33clone();
        this.layerWiseConfigurations = restored.layerWiseConfigurations.m30clone();
        init();
        this.flattenedParams.assign(restored.flattenedParams);
        int numConfs = this.layerWiseConfigurations.getConfs().size();
        int numPreProcessors = this.layerWiseConfigurations.getInputPreProcessors().size();
        this.WS_LAYER_WORKING_MEM_CONFIG = getLayerWorkingMemWSConfig(2 * (numConfs + numPreProcessors));
        this.WS_LAYER_ACT_X_CONFIG = getLayerActivationWSConfig(numConfs);
        boolean hasUpdaterState = restored.getUpdater() != null
                && restored.getUpdater(false).getStateViewArray() != null;
        if (hasUpdaterState) {
            getUpdater(true).getStateViewArray().assign(restored.getUpdater(false).getStateViewArray());
        }
    }

    /** Sets the {@code initDone} flag that {@code update(Task)} checks before reporting. */
    public void setInitDone(boolean z) {
        initDone = z;
    }

    /** Returns the {@code flattenedGradients} field (not a copy). */
    public INDArray getFlattenedGradients() {
        return flattenedGradients;
    }

    /** Returns the {@code helperWorkspaces} map field itself (not a copy). */
    public Map<String, Pointer> getHelperWorkspaces() {
        return helperWorkspaces;
    }
}
