package org.deeplearning4j.nn.graph;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.berkeley.Triple;
import org.deeplearning4j.datasets.iterator.AsyncDataSetIterator;
import org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.api.layers.RecurrentLayer;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.util.ComputationGraphUtil;
import org.deeplearning4j.nn.graph.vertex.GraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.deeplearning4j.nn.graph.vertex.impl.InputVertex;
import org.deeplearning4j.nn.graph.vertex.impl.LayerVertex;
import org.deeplearning4j.nn.layers.FrozenLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.deeplearning4j.optimize.Solver;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.heartbeat.Heartbeat;
import org.nd4j.linalg.heartbeat.reports.Event;
import org.nd4j.linalg.heartbeat.reports.Task;
import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils;
import org.nd4j.linalg.heartbeat.utils.TaskUtils;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/graph/ComputationGraph.class */
public class ComputationGraph implements Serializable, Model {
    private static final Logger log = LoggerFactory.getLogger(ComputationGraph.class);
    // Graph configuration this network was constructed from.
    protected ComputationGraphConfiguration configuration;
    // Optimizer driver; created lazily by the fit(...) methods. Not serialized.
    protected transient Solver solver;
    // Single row vector holding every vertex's parameters; init(...) carves per-vertex sub-arrays out of it.
    protected INDArray flattenedParams;
    // Single row vector holding every vertex's gradients; initGradientsView() hands per-vertex sub-arrays to vertices.
    protected transient INDArray flattenedGradients;
    protected Gradient gradient;
    // Most recent score; recomputed by computeGradientAndScore().
    protected double score;
    // All vertices indexed by vertex index: network-input vertices first (0..numInputArrays-1), then configured vertices.
    protected GraphVertex[] vertices;
    // Vertex name -> vertex instance; built by init(...).
    protected Map<String, GraphVertex> verticesMap;
    // Vertex indices in topological order (every vertex appears after all of its inputs).
    protected int[] topologicalOrder;
    // Layers of all vertices that have one, in vertex-index order; built by init(...).
    protected Layer[] layers;
    // Number of network inputs / outputs, taken from the configuration in the constructor.
    private int numInputArrays;
    private int numOutputArrays;
    // Current network inputs and labels (one array per network input / output). Not serialized.
    private transient INDArray[] inputs;
    private transient INDArray[] labels;
    // Optional mask arrays for inputs / labels. Not serialized.
    private transient INDArray[] inputMaskArrays;
    private transient INDArray[] labelMaskArrays;
    // Network-wide default configuration; used (e.g.) to configure the Solver.
    private NeuralNetConfiguration defaultConfiguration;
    // Set at the end of init(...) so the graph is only built once.
    protected boolean initCalled = false;
    // NOTE(review): not read or written anywhere in the visible code — confirm whether it is still needed.
    private boolean initDone = false;
    // Iteration listeners handed to the Solver.
    private Collection<IterationListener> listeners = new ArrayList();
    // Listeners notified on epoch start/end and on forward/backward passes.
    private Collection<TrainingListener> trainingListeners = new ArrayList();

    /**
     * Creates a graph network from the given configuration. The network is not usable until {@link #init()} (or
     * {@link #init(INDArray, boolean)}) has been called.
     *
     * @param computationGraphConfiguration the graph configuration; its network input/output counts determine the
     *                                      sizes of the input and label arrays
     */
    public ComputationGraph(ComputationGraphConfiguration computationGraphConfiguration) {
        configuration = computationGraphConfiguration;
        numInputArrays = computationGraphConfiguration.getNetworkInputs().size();
        numOutputArrays = computationGraphConfiguration.getNetworkOutputs().size();
        defaultConfiguration = computationGraphConfiguration.getDefaultConfiguration();
        inputs = new INDArray[numInputArrays];
        labels = new INDArray[numOutputArrays];
    }

    /**
     * @return the configuration this graph was constructed from
     */
    public ComputationGraphConfiguration getConfiguration() {
        ComputationGraphConfiguration conf = configuration;
        return conf;
    }

    /**
     * @return the number of layers in the graph, or 0 if the layer array has not been built yet (before init)
     */
    public int getNumLayers() {
        return layers == null ? 0 : layers.length;
    }

    /**
     * @param i position in the layer array (layers are stored in vertex-index order)
     * @return the layer at that position
     */
    public Layer getLayer(int i) {
        Layer[] all = layers;
        return all[i];
    }

    /**
     * @return the internal layer array (not a copy); null before init
     */
    public Layer[] getLayers() {
        Layer[] all = layers;
        return all;
    }

    /**
     * Looks up the layer held by the vertex with the given name.
     *
     * @param str vertex name
     * @return that vertex's layer; a missing vertex name results in a NullPointerException
     */
    public Layer getLayer(String str) {
        GraphVertex vertex = verticesMap.get(str);
        return vertex.getLayer();
    }

    /**
     * @return the internal vertex array indexed by vertex index (network inputs first); null before init
     */
    public GraphVertex[] getVertices() {
        GraphVertex[] all = vertices;
        return all;
    }

    /**
     * @param str vertex name
     * @return the vertex with that name, or null if there is none
     */
    public GraphVertex getVertex(String str) {
        GraphVertex found = verticesMap.get(str);
        return found;
    }

    /**
     * @return the number of network inputs declared in the configuration
     */
    public int getNumInputArrays() {
        int count = numInputArrays;
        return count;
    }

    /**
     * @return the number of network outputs declared in the configuration
     */
    public int getNumOutputArrays() {
        int count = numOutputArrays;
        return count;
    }

    /**
     * Replaces a single network input.
     *
     * @param i        index of the network input
     * @param iNDArray the new input array
     */
    public void setInput(int i, INDArray iNDArray) {
        INDArray[] current = inputs;
        current[i] = iNDArray;
    }

    /**
     * Replaces all network inputs at once. The array is stored directly, not copied.
     *
     * @param iNDArrayArr one array per network input, or null to clear the inputs
     * @throws IllegalArgumentException if a non-null array has the wrong number of elements
     */
    public void setInputs(INDArray... iNDArrayArr) {
        boolean lengthMismatch = iNDArrayArr != null && iNDArrayArr.length != numInputArrays;
        if (lengthMismatch) {
            throw new IllegalArgumentException("Invalid input array: network has " + this.numInputArrays + " inputs, but array is of length " + iNDArrayArr.length);
        }
        inputs = iNDArrayArr;
    }

    /**
     * @param i index of the network input
     * @return that input, or null if no input array has been set at all
     */
    public INDArray getInput(int i) {
        return inputs == null ? null : inputs[i];
    }

    /**
     * @return the internal network-input array (not a copy); may be null
     */
    public INDArray[] getInputs() {
        INDArray[] current = inputs;
        return current;
    }

    /**
     * @return the current input mask arrays (not a copy); may be null
     */
    public INDArray[] getInputMaskArrays() {
        INDArray[] masks = inputMaskArrays;
        return masks;
    }

    /**
     * @return the current label mask arrays (not a copy); may be null
     */
    public INDArray[] getLabelMaskArrays() {
        INDArray[] masks = labelMaskArrays;
        return masks;
    }

    /**
     * Replaces the labels for a single network output.
     *
     * @param i        index of the network output
     * @param iNDArray the new label array
     */
    public void setLabel(int i, INDArray iNDArray) {
        INDArray[] current = labels;
        current[i] = iNDArray;
    }

    /**
     * Replaces the labels for all network outputs at once. The array is stored directly, not copied.
     *
     * @param iNDArrayArr one array per network output, or null to clear the labels
     * @throws IllegalArgumentException if a non-null array has the wrong number of elements
     */
    public void setLabels(INDArray... iNDArrayArr) {
        boolean lengthMismatch = iNDArrayArr != null && iNDArrayArr.length != numOutputArrays;
        if (lengthMismatch) {
            throw new IllegalArgumentException("Invalid output array: network has " + this.numOutputArrays + " outputs, but array is of length " + iNDArrayArr.length);
        }
        labels = iNDArrayArr;
    }

    /**
     * Initializes the graph with freshly allocated parameters. Equivalent to {@code init(null, false)}.
     */
    public void init() {
        INDArray noParams = null;
        this.init(noParams, false);
    }

    /**
     * Builds the runtime graph from the configuration. This involves:
     * computing the topological order, instantiating all vertices (network-input vertices first, then
     * configured vertices), taking per-vertex parameter sub-arrays from one flattened parameter row vector, and wiring
     * the input/output edges between vertices.
     * Does nothing if it has already run to completion once.
     *
     * @param iNDArray parameters to use, as a row vector whose length equals the total parameter count of all
     *                 vertices; if null, a new array is allocated with Nd4j.create(1, total)
     * @param z        only consulted when iNDArray is non-null: if true the parameters are duplicated, otherwise the
     *                 given array itself becomes the network's flattened parameter array
     * @throws IllegalArgumentException if iNDArray is not a row vector or its length does not match the parameter count
     */
    public void init(INDArray iNDArray, boolean z) {
        // z2: passed to each vertex's instantiate(...); true only when the parameter array was allocated here.
        // NOTE(review): presumably tells the vertex to initialize its parameter values — confirm in instantiate(...).
        boolean z2;
        if (this.initCalled) {
            return;
        }
        this.topologicalOrder = topologicalSortOrder();
        Map<String, org.deeplearning4j.nn.conf.graph.GraphVertex> vertices = this.configuration.getVertices();
        List<String> networkInputs = this.configuration.getNetworkInputs();
        Map<String, List<String>> vertexInputs = this.configuration.getVertexInputs();
        this.vertices = new GraphVertex[networkInputs.size() + this.configuration.getVertices().size()];
        // hashMap: vertex name -> vertex index. Network inputs occupy indices 0..numInputs-1.
        HashMap hashMap = new HashMap();
        int i = 0;
        for (String str : networkInputs) {
            InputVertex inputVertex = new InputVertex(this, str, i, null);
            hashMap.put(str, Integer.valueOf(i));
            int i2 = i;
            i++;
            this.vertices[i2] = inputVertex;
        }
        // iArr: number of parameters per vertex index (input vertices have none); i3: total parameter count.
        int i3 = 0;
        int[] iArr = new int[this.topologicalOrder.length];
        int i4 = 0;
        while (i4 < this.configuration.getNetworkInputs().size()) {
            iArr[i4] = 0;
            i4++;
        }
        // Configured vertices get indices following the inputs, in the configuration map's iteration order.
        Iterator<Map.Entry<String, org.deeplearning4j.nn.conf.graph.GraphVertex>> it = vertices.entrySet().iterator();
        while (it.hasNext()) {
            iArr[i4] = it.next().getValue().numParams(true);
            i3 += iArr[i4];
            i4++;
        }
        if (iNDArray == null) {
            this.flattenedParams = Nd4j.create(1, i3);
            z2 = true;
        } else {
            if (!iNDArray.isRowVector()) {
                throw new IllegalArgumentException("Invalid parameters: should be a row vector");
            }
            if (iNDArray.length() != i3) {
                throw new IllegalArgumentException("Invalid parameters: expected length " + i3 + ", got length " + iNDArray.length());
            }
            if (z) {
                this.flattenedParams = iNDArray.dup();
            } else {
                this.flattenedParams = iNDArray;
            }
            z2 = false;
        }
        // iNDArrayArr[vertexIndex] = that vertex's slice of flattenedParams (null if it has no parameters).
        // The slices are laid out in TOPOLOGICAL order within flattenedParams, not in vertex-index order.
        INDArray[] iNDArrayArr = new INDArray[this.topologicalOrder.length];
        int i5 = 0;
        // i6 is incremented but never read.
        int i6 = 0;
        for (int i7 : this.topologicalOrder) {
            int i8 = iArr[i7];
            if (i8 != 0) {
                iNDArrayArr[i7] = this.flattenedParams.get(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(i5, i5 + i8)});
            }
            i6++;
            i5 += i8;
        }
        // Instantiate configured vertices (i continues from the number of network inputs), collect their layers, and
        // register each layer variable in the default configuration as "<vertexName>_<variable>".
        int i9 = 0;
        ArrayList arrayList = new ArrayList();
        this.defaultConfiguration.clearVariables();
        // NOTE(review): entries are appended to this list below; this assumes variables(false) returns the live
        // backing list rather than a copy — confirm.
        List<String> variables = this.defaultConfiguration.variables(false);
        for (Map.Entry<String, org.deeplearning4j.nn.conf.graph.GraphVertex> entry : vertices.entrySet()) {
            org.deeplearning4j.nn.conf.graph.GraphVertex value = entry.getValue();
            String key = entry.getKey();
            GraphVertex instantiate = value.instantiate(this, key, i, iNDArrayArr[i], z2);
            if (instantiate.hasLayer()) {
                i9++;
                Layer layer = instantiate.getLayer();
                arrayList.add(layer);
                List<String> variables2 = layer.conf().variables();
                if (variables2 != null) {
                    Iterator<String> it2 = variables2.iterator();
                    while (it2.hasNext()) {
                        variables.add(instantiate.getVertexName() + "_" + it2.next());
                    }
                }
            }
            hashMap.put(key, Integer.valueOf(i));
            int i10 = i;
            i++;
            this.vertices[i10] = instantiate;
        }
        this.layers = (Layer[]) arrayList.toArray(new Layer[i9]);
        this.verticesMap = new HashMap();
        for (GraphVertex graphVertex : this.vertices) {
            this.verticesMap.put(graphVertex.getVertexName(), graphVertex);
        }
        // hashMap2: vertex name -> names of the vertices that consume its output, in vertex-index order.
        HashMap hashMap2 = new HashMap();
        for (GraphVertex graphVertex2 : this.vertices) {
            String vertexName = graphVertex2.getVertexName();
            List<String> list = vertexInputs.get(vertexName);
            if (list != null) {
                for (String str2 : list) {
                    List list2 = (List) hashMap2.get(str2);
                    if (list2 == null) {
                        list2 = new ArrayList();
                        hashMap2.put(str2, list2);
                    }
                    list2.add(vertexName);
                }
            }
        }
        // Input edges: for each input of a vertex, record (source vertex index, position of this vertex in the
        // source's consumer list).
        for (GraphVertex graphVertex3 : this.vertices) {
            String vertexName2 = graphVertex3.getVertexName();
            int vertexIndex = graphVertex3.getVertexIndex();
            List<String> list3 = vertexInputs.get(vertexName2);
            if (list3 != null) {
                VertexIndices[] vertexIndicesArr = new VertexIndices[list3.size()];
                for (int i11 = 0; i11 < list3.size(); i11++) {
                    String str3 = list3.get(i11);
                    int intValue = ((Integer) hashMap.get(str3)).intValue();
                    GraphVertex graphVertex4 = this.vertices[intValue];
                    int indexOf = ((List) hashMap2.get(str3)).indexOf(vertexName2);
                    if (indexOf == -1) {
                        throw new IllegalStateException("Could not find vertex " + vertexIndex + " in the list of outputs for vertex " + graphVertex4 + "; error in graph structure?");
                    }
                    vertexIndicesArr[i11] = new VertexIndices(intValue, indexOf);
                }
                graphVertex3.setInputVertices(vertexIndicesArr);
            }
        }
        // Output edges: for each consumer of a vertex, record (consumer vertex index, position of this vertex in
        // the consumer's input list). Vertices with no consumers never get setOutputVertices(...) called.
        for (GraphVertex graphVertex5 : this.vertices) {
            String vertexName3 = graphVertex5.getVertexName();
            List<String> list4 = (List) hashMap2.get(vertexName3);
            if (list4 != null && !list4.isEmpty()) {
                VertexIndices[] vertexIndicesArr2 = new VertexIndices[list4.size()];
                int i12 = 0;
                for (String str4 : list4) {
                    int i13 = i12;
                    i12++;
                    vertexIndicesArr2[i13] = new VertexIndices(((Integer) hashMap.get(str4)).intValue(), vertexInputs.get(str4).indexOf(vertexName3));
                }
                graphVertex5.setOutputVertices(vertexIndicesArr2);
            }
        }
        this.initCalled = true;
    }

    /**
     * Allocates the flattened gradient row vector and gives every vertex that has parameters its own slice of it.
     * Slices are laid out in topological order, mirroring the parameter layout produced by init(...).
     * Calls {@link #init()} first if the graph has not been initialized yet.
     */
    public void initGradientsView() {
        if (!initCalled) {
            init();
        }
        // Per-vertex-index parameter counts: network inputs come first and own no parameters (left at 0).
        int[] paramCounts = new int[topologicalOrder.length];
        int vertexIdx = configuration.getNetworkInputs().size();
        int totalParams = 0;
        for (Map.Entry<String, org.deeplearning4j.nn.conf.graph.GraphVertex> entry : configuration.getVertices().entrySet()) {
            int count = entry.getValue().numParams(true);
            paramCounts[vertexIdx] = count;
            totalParams += count;
            vertexIdx++;
        }
        flattenedGradients = Nd4j.create(1, totalParams);
        int offset = 0;
        for (int idx : topologicalOrder) {
            int count = paramCounts[idx];
            if (count != 0) {
                INDArray slice = flattenedGradients.get(NDArrayIndex.point(0), NDArrayIndex.interval(offset, offset + count));
                vertices[idx].setBackpropGradientsViewArray(slice);
            }
            offset += count;
        }
    }

    /**
     * Layer-wise pretraining from a DataSetIterator. This is only valid for single-input networks; the iterator is
     * adapted to a MultiDataSetIterator.
     *
     * @param dataSetIterator training data
     * @throws UnsupportedOperationException if the network does not have exactly one input
     */
    public void pretrain(DataSetIterator dataSetIterator) {
        if (numInputArrays == 1) {
            MultiDataSetIterator adapted = ComputationGraphUtil.toMultiDataSetIterator(dataSetIterator);
            pretrain(adapted);
            return;
        }
        throw new UnsupportedOperationException("Cannot train ComputationGraph network with  multiple inputs using a DataSetIterator");
    }

    /**
     * Layer-wise unsupervised pretraining of every pretrainable, non-output layer, one layer at a time.
     * Layers are visited in topological order, so a layer is pretrained only after every layer that feeds into it.
     * Does nothing unless pretraining is enabled in the configuration.
     *
     * @param multiDataSetIterator training data; only the features are used
     */
    public void pretrain(MultiDataSetIterator multiDataSetIterator) {
        if (!this.configuration.isPretrain()) {
            return;
        }
        if (this.flattenedGradients == null) {
            initGradientsView();
        }
        // topologicalOrder holds vertex indices; index vertices through it (indexing vertices[] by the loop position
        // would visit vertices in vertex-index order instead of topological order).
        for (int vertexIdx : this.topologicalOrder) {
            GraphVertex vertex = this.vertices[vertexIdx];
            if (!vertex.hasLayer()) {
                continue;
            }
            Layer layer = vertex.getLayer();
            if (layer instanceof IOutputLayer || !layer.isPretrainLayer()) {
                continue; // output layers and non-pretrainable layers are skipped
            }
            pretrainLayer(vertex.getVertexName(), multiDataSetIterator);
        }
    }

    /**
     * Pretrains one named layer from a DataSetIterator. This is only valid for single-input networks; the iterator
     * is adapted to a MultiDataSetIterator.
     *
     * @param str             name of the vertex whose layer should be pretrained
     * @param dataSetIterator training data
     * @throws UnsupportedOperationException if the network does not have exactly one input
     */
    public void pretrainLayer(String str, DataSetIterator dataSetIterator) {
        if (numInputArrays == 1) {
            MultiDataSetIterator adapted = ComputationGraphUtil.toMultiDataSetIterator(dataSetIterator);
            pretrainLayer(str, adapted);
            return;
        }
        throw new UnsupportedOperationException("Cannot train ComputationGraph network with  multiple inputs using a DataSetIterator");
    }

    /**
     * Pretrains a single layer on the features of the given iterator.
     * For each minibatch, a forward pass is run through only those vertices that (transitively) feed the named
     * vertex, and the layer is then fit on the resulting input. Does nothing if pretraining is disabled in the
     * configuration or the named vertex has no layer. If the iterator is exhausted and resettable, it is reset first.
     *
     * @param str                  name of the vertex whose layer should be pretrained
     * @param multiDataSetIterator training data; only the features are used
     * @throws IllegalStateException if no vertex with the given name exists
     */
    public void pretrainLayer(String str, MultiDataSetIterator multiDataSetIterator) {
        if (!this.configuration.isPretrain()) {
            return;
        }
        if (this.flattenedGradients == null) {
            initGradientsView();
        }
        if (!this.verticesMap.containsKey(str)) {
            throw new IllegalStateException("Invalid vertex name: " + str);
        }
        GraphVertex target = this.verticesMap.get(str);
        if (!target.hasLayer()) {
            return;
        }
        // Vertex indices to run, in topological order; the target vertex itself is the last element.
        int[] fwdPassOrder = partialForwardPassOrder(target.getVertexIndex());
        Layer layer = target.getLayer();
        if (!multiDataSetIterator.hasNext() && multiDataSetIterator.resetSupported()) {
            multiDataSetIterator.reset();
        }
        while (multiDataSetIterator.hasNext()) {
            setInputs(((MultiDataSet) multiDataSetIterator.next()).getFeatures());
            // Forward pass through every required vertex except the target, pushing activations to consumers.
            for (int j = 0; j < fwdPassOrder.length - 1; j++) {
                GraphVertex current = this.vertices[fwdPassOrder[j]];
                if (current.isInputVertex()) {
                    INDArray input = this.inputs[current.getVertexIndex()];
                    for (VertexIndices out : current.getOutputVertices()) {
                        this.vertices[out.getVertexIndex()].setInput(out.getVertexEdgeNumber(), input.dup());
                    }
                } else {
                    INDArray activations = current.doForward(true);
                    VertexIndices[] outs = current.getOutputVertices();
                    if (outs != null) {
                        for (VertexIndices out : outs) {
                            this.vertices[out.getVertexIndex()].setInput(out.getVertexEdgeNumber(), activations);
                        }
                    }
                }
            }
            layer.fit(target.getInputs()[0]);
            // NOTE(review): this runs after every minibatch rather than once after the loop — confirm whether
            // Layer.fit(...) consults the pretrain flag.
            layer.conf().setPretrain(false);
        }
    }

    /**
     * Computes the subset of vertices, in topological order, whose forward pass is needed to produce the inputs
     * of the given vertex. The returned array ends with the given vertex itself.
     *
     * @param vertexIndex vertex index (not topological position) of the target vertex
     * @return vertex indices in topological order, ending with vertexIndex
     */
    private int[] partialForwardPassOrder(int vertexIndex) {
        // Locate the target's position in the topological order; vertex index and topological position differ in
        // general, so they must not be used interchangeably.
        int position = -1;
        for (int p = 0; p < this.topologicalOrder.length; p++) {
            if (this.topologicalOrder[p] == vertexIndex) {
                position = p;
                break;
            }
        }
        if (position < 0) {
            throw new IllegalStateException("Vertex index " + vertexIndex + " not found in topological order");
        }
        LinkedList<Integer> order = new LinkedList<>();
        Set<Integer> needed = new HashSet<>();
        order.add(vertexIndex);
        needed.add(vertexIndex);
        // Walk backwards: a vertex is needed iff it feeds some vertex already known to be needed.
        for (int p = position - 1; p >= 0; p--) {
            int candidate = this.topologicalOrder[p];
            VertexIndices[] outs = this.vertices[candidate].getOutputVertices();
            if (outs == null) {
                continue; // no consumers (e.g. a network output): cannot feed the target
            }
            for (VertexIndices out : outs) {
                if (needed.contains(out.getVertexIndex())) {
                    order.addFirst(candidate);
                    needed.add(candidate);
                    break;
                }
            }
        }
        int[] result = new int[order.size()];
        int k = 0;
        for (Integer idx : order) {
            result[k++] = idx;
        }
        return result;
    }

    /**
     * Fits the network on a single DataSet. This is only valid for networks with exactly one input and one output.
     * Feature/label mask arrays, if present, are applied for this call and cleared afterwards.
     *
     * @param dataSet features, labels and optional masks
     * @throws UnsupportedOperationException if the network has more than one input or output
     */
    public void fit(DataSet dataSet) {
        if (numInputArrays != 1 || numOutputArrays != 1) {
            throw new UnsupportedOperationException("Cannot train ComputationGraph network with  multiple inputs or outputs using a DataSet");
        }
        boolean masked = dataSet.hasMaskArrays();
        INDArray[] featureArr = {dataSet.getFeatures()};
        INDArray[] labelArr = {dataSet.getLabels()};
        if (!masked) {
            fit(featureArr, labelArr);
            return;
        }
        INDArray featureMask = dataSet.getFeaturesMaskArray();
        INDArray labelMask = dataSet.getLabelsMaskArray();
        INDArray[] featureMasks = featureMask == null ? null : new INDArray[]{featureMask};
        INDArray[] labelMasks = labelMask == null ? null : new INDArray[]{labelMask};
        fit(featureArr, labelArr, featureMasks, labelMasks);
        clearLayerMaskArrays();
    }

    /**
     * Trains the network for one pass over a single-input, single-output DataSetIterator.
     * Runs layer-wise pretraining first if enabled, then backprop (standard or truncated BPTT) if enabled.
     * Training listeners are notified at the start and end of the epoch.
     *
     * @param dataSetIterator training data; wrapped in an AsyncDataSetIterator when it supports async prefetch
     * @throws UnsupportedOperationException if the network does not have exactly one input and one output
     */
    public void fit(DataSetIterator dataSetIterator) {
        if (this.flattenedGradients == null) {
            initGradientsView();
        }
        if (this.numInputArrays != 1 || this.numOutputArrays != 1) {
            throw new UnsupportedOperationException("Cannot train ComputationGraph network with  multiple inputs or outputs using a DataSetIterator");
        }
        DataSetIterator asyncDataSetIterator = dataSetIterator.asyncSupported() ? new AsyncDataSetIterator(dataSetIterator, 2) : dataSetIterator;
        for (TrainingListener listener : this.trainingListeners) {
            listener.onEpochStart(this);
        }
        if (this.configuration.isPretrain()) {
            pretrain(asyncDataSetIterator);
            // Pretraining walks the iterator to its end; rewind it so the backprop loop below actually sees data.
            if (this.configuration.isBackprop() && asyncDataSetIterator.resetSupported()) {
                asyncDataSetIterator.reset();
            }
        }
        if (this.configuration.isBackprop()) {
            update(TaskUtils.buildTask(asyncDataSetIterator));
            while (asyncDataSetIterator.hasNext()) {
                DataSet dataSet = (DataSet) asyncDataSetIterator.next();
                if (dataSet.getFeatures() == null || dataSet.getLabels() == null) {
                    break;
                }
                boolean hasMaskArrays = dataSet.hasMaskArrays();
                if (hasMaskArrays) {
                    setLayerMaskArrays(dataSet.getFeaturesMaskArray() != null ? new INDArray[]{dataSet.getFeaturesMaskArray()} : null, dataSet.getLabelsMaskArray() != null ? new INDArray[]{dataSet.getLabelsMaskArray()} : null);
                }
                if (this.configuration.getBackpropType() == BackpropType.TruncatedBPTT) {
                    doTruncatedBPTT(new INDArray[]{dataSet.getFeatures()}, new INDArray[]{dataSet.getLabels()}, hasMaskArrays ? new INDArray[]{dataSet.getFeaturesMaskArray()} : null, hasMaskArrays ? new INDArray[]{dataSet.getLabelsMaskArray()} : null);
                } else {
                    setInput(0, dataSet.getFeatures());
                    setLabel(0, dataSet.getLabels());
                    if (this.solver == null) {
                        this.solver = new Solver.Builder().configure(this.defaultConfiguration).listeners(this.listeners).model(this).build();
                    }
                    this.solver.optimize();
                }
                if (hasMaskArrays) {
                    clearLayerMaskArrays();
                }
                Nd4j.getMemoryManager().invokeGcOccasionally();
            }
        }
        for (TrainingListener listener : this.trainingListeners) {
            listener.onEpochEnd(this);
        }
    }

    /**
     * Fits the network on a single MultiDataSet (features, labels and optional masks for every input/output).
     * If the MultiDataSet has mask arrays, the layer masks are cleared afterwards.
     *
     * @param multiDataSet training data
     */
    public void fit(MultiDataSet multiDataSet) {
        INDArray[] features = multiDataSet.getFeatures();
        INDArray[] labelArr = multiDataSet.getLabels();
        INDArray[] featureMasks = multiDataSet.getFeaturesMaskArrays();
        INDArray[] labelMasks = multiDataSet.getLabelsMaskArrays();
        fit(features, labelArr, featureMasks, labelMasks);
        if (!multiDataSet.hasMaskArrays()) {
            return;
        }
        clearLayerMaskArrays();
    }

    /**
     * Trains the network for one pass over a MultiDataSetIterator.
     * Runs layer-wise pretraining first if enabled, then backprop (standard or truncated BPTT) if enabled.
     * Stops early if a minibatch has null features or labels.
     *
     * NOTE(review): unlike fit(DataSetIterator), this overload does not notify training listeners of epoch
     * start/end and does not call update(Task) — confirm whether that is intended.
     *
     * @param multiDataSetIterator training data; wrapped in an AsyncMultiDataSetIterator when it supports async
     *                             prefetch
     */
    public void fit(MultiDataSetIterator multiDataSetIterator) {
        if (this.flattenedGradients == null) {
            initGradientsView();
        }
        MultiDataSetIterator asyncMultiDataSetIterator = multiDataSetIterator.asyncSupported() ? new AsyncMultiDataSetIterator(multiDataSetIterator, 2) : multiDataSetIterator;
        if (this.configuration.isPretrain()) {
            pretrain(asyncMultiDataSetIterator);
            // Pretraining walks the iterator to its end; rewind it so the backprop loop below actually sees data.
            if (this.configuration.isBackprop() && asyncMultiDataSetIterator.resetSupported()) {
                asyncMultiDataSetIterator.reset();
            }
        }
        if (this.configuration.isBackprop()) {
            while (asyncMultiDataSetIterator.hasNext()) {
                MultiDataSet multiDataSet = (MultiDataSet) asyncMultiDataSetIterator.next();
                if (multiDataSet.getFeatures() == null || multiDataSet.getLabels() == null) {
                    return;
                }
                if (this.configuration.getBackpropType() == BackpropType.TruncatedBPTT) {
                    doTruncatedBPTT(multiDataSet.getFeatures(), multiDataSet.getLabels(), multiDataSet.getFeaturesMaskArrays(), multiDataSet.getLabelsMaskArrays());
                } else {
                    boolean hasMaskArrays = multiDataSet.hasMaskArrays();
                    if (hasMaskArrays) {
                        setLayerMaskArrays(multiDataSet.getFeaturesMaskArrays(), multiDataSet.getLabelsMaskArrays());
                    }
                    setInputs(multiDataSet.getFeatures());
                    setLabels(multiDataSet.getLabels());
                    if (this.solver == null) {
                        this.solver = new Solver.Builder().configure(this.defaultConfiguration).listeners(this.listeners).model(this).build();
                    }
                    this.solver.optimize();
                    if (hasMaskArrays) {
                        clearLayerMaskArrays();
                    }
                }
                Nd4j.getMemoryManager().invokeGcOccasionally();
            }
        }
    }

    /**
     * Fits the network on the given inputs and labels, without mask arrays.
     *
     * @param iNDArrayArr  one feature array per network input
     * @param iNDArrayArr2 one label array per network output
     */
    public void fit(INDArray[] iNDArrayArr, INDArray[] iNDArrayArr2) {
        INDArray[] noFeatureMasks = null;
        INDArray[] noLabelMasks = null;
        fit(iNDArrayArr, iNDArrayArr2, noFeatureMasks, noLabelMasks);
    }

    /**
     * Fits the network on one minibatch given as raw arrays.
     * Sets inputs, labels and layer masks, runs pretraining (on this single minibatch) if enabled, then backprop
     * (standard or truncated BPTT) if enabled. Layer masks are cleared afterwards if any mask array was given.
     *
     * @param iNDArrayArr  one feature array per network input
     * @param iNDArrayArr2 one label array per network output
     * @param iNDArrayArr3 feature mask arrays, or null
     * @param iNDArrayArr4 label mask arrays, or null
     */
    public void fit(INDArray[] iNDArrayArr, INDArray[] iNDArrayArr2, INDArray[] iNDArrayArr3, INDArray[] iNDArrayArr4) {
        if (flattenedGradients == null) {
            initGradientsView();
        }
        setInputs(iNDArrayArr);
        setLabels(iNDArrayArr2);
        setLayerMaskArrays(iNDArrayArr3, iNDArrayArr4);
        update(TaskUtils.buildTask(iNDArrayArr, iNDArrayArr2));
        if (configuration.isPretrain()) {
            pretrain(new SingletonMultiDataSetIterator(new org.nd4j.linalg.dataset.MultiDataSet(iNDArrayArr, iNDArrayArr2, iNDArrayArr3, iNDArrayArr4)));
        }
        if (configuration.isBackprop()) {
            boolean truncated = configuration.getBackpropType() == BackpropType.TruncatedBPTT;
            if (truncated) {
                doTruncatedBPTT(iNDArrayArr, iNDArrayArr2, iNDArrayArr3, iNDArrayArr4);
            } else {
                if (solver == null) {
                    solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
                }
                solver.optimize();
            }
        }
        boolean anyMasks = iNDArrayArr3 != null || iNDArrayArr4 != null;
        if (anyMasks) {
            clearLayerMaskArrays();
        }
    }

    /**
     * Computes a topological ordering of all vertices, network inputs included, using Kahn's algorithm.
     * Vertex indices follow the same layout as init(...): network inputs first, then configured vertices in the
     * configuration map's iteration order. If an order has already been stored in this.topologicalOrder, it is
     * returned as-is.
     *
     * @return vertex indices such that every vertex appears after all of its inputs
     * @throws IllegalStateException if a vertex lists an input that is neither a network input nor a vertex, or if
     *                               the graph contains a cycle
     */
    public int[] topologicalSortOrder() {
        if (this.topologicalOrder != null) {
            return this.topologicalOrder;
        }
        Map<String, org.deeplearning4j.nn.conf.graph.GraphVertex> vertexConfs = this.configuration.getVertices();
        List<String> networkInputs = this.configuration.getNetworkInputs();
        Map<String, List<String>> vertexInputs = this.configuration.getVertexInputs();
        int[] order = new int[networkInputs.size() + vertexConfs.size()];

        // Assign vertex indices: network inputs first, then configured vertices.
        Map<Integer, String> indexToName = new HashMap<>();
        Map<String, Integer> nameToIndex = new HashMap<>();
        int nextIndex = 0;
        for (String name : networkInputs) {
            indexToName.put(nextIndex, name);
            nameToIndex.put(name, nextIndex);
            nextIndex++;
        }
        for (String name : vertexConfs.keySet()) {
            indexToName.put(nextIndex, name);
            nameToIndex.put(name, nextIndex);
            nextIndex++;
        }

        // pendingInputs: vertex -> input vertices not yet emitted (null/empty = ready).
        // consumers:     vertex -> vertices that take its output as input.
        Map<Integer, Set<Integer>> pendingInputs = new HashMap<>();
        Map<Integer, Set<Integer>> consumers = new HashMap<>();
        for (String name : networkInputs) {
            pendingInputs.put(nameToIndex.get(name), null);
        }
        for (String name : vertexConfs.keySet()) {
            int vertexIdx = nameToIndex.get(name);
            List<String> inputNames = vertexInputs.get(name);
            if (inputNames == null || inputNames.isEmpty()) {
                pendingInputs.put(vertexIdx, null);
                continue;
            }
            Set<Integer> deps = new HashSet<>();
            for (String inputName : inputNames) {
                Integer inputIdx = nameToIndex.get(inputName);
                if (inputIdx == null) {
                    // Previously this fell through with a null key and was later misreported as a cycle.
                    throw new IllegalStateException("Invalid configuration: vertex \"" + name + "\" has input \"" + inputName + "\", which is neither a network input nor a vertex in the graph");
                }
                deps.add(inputIdx);
                Set<Integer> outs = consumers.get(inputIdx);
                if (outs == null) {
                    outs = new HashSet<>();
                    consumers.put(inputIdx, outs);
                }
                outs.add(vertexIdx);
            }
            pendingInputs.put(vertexIdx, deps);
        }

        // Kahn's algorithm: repeatedly emit a vertex with no pending inputs and release its consumers.
        LinkedList<Integer> ready = new LinkedList<>();
        for (Map.Entry<Integer, Set<Integer>> entry : pendingInputs.entrySet()) {
            Set<Integer> deps = entry.getValue();
            if (deps == null || deps.isEmpty()) {
                ready.add(entry.getKey());
            }
        }
        int pos = 0;
        while (!ready.isEmpty()) {
            int current = ready.removeFirst();
            order[pos++] = current;
            Set<Integer> outs = consumers.get(current);
            if (outs == null) {
                continue;
            }
            for (Integer consumer : outs) {
                Set<Integer> remaining = pendingInputs.get(consumer);
                remaining.remove(Integer.valueOf(current));
                if (remaining.isEmpty()) {
                    ready.add(consumer);
                }
            }
        }

        // Any vertex still waiting on inputs is part of (or downstream of) a cycle.
        for (Map.Entry<Integer, Set<Integer>> entry : pendingInputs.entrySet()) {
            Set<Integer> remaining = entry.getValue();
            if (remaining != null && !remaining.isEmpty()) {
                throw new IllegalStateException("Invalid configuration: cycle detected in graph. Cannot calculate topological ordering with graph cycle (cycle includes vertex \"" + indexToName.get(entry.getKey()) + "\")");
            }
        }
        return order;
    }

    /**
     * Runs a forward pass and backpropagation, then recomputes {@code score}.
     * For truncated BPTT the forward pass uses the stored RNN state.
     * The score is the sum of all network output layers' scores. L1/L2 regularization terms are passed only to
     * the first output layer, so they are counted once.
     * Training listeners are notified after the forward pass and after the score is computed.
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void computeGradientAndScore() {
        boolean truncated = configuration.getBackpropType() == BackpropType.TruncatedBPTT;
        Map<String, INDArray> activations = truncated
                ? rnnActivateUsingStoredState(inputs, true, true)
                : feedForward(true, true);
        for (TrainingListener listener : trainingListeners) {
            listener.onForwardPass(this, activations);
        }
        calcBackpropGradients(truncated, new INDArray[0]);

        double l1 = calcL1();
        double l2 = calcL2();
        score = 0.0d;
        for (String outputName : configuration.getNetworkOutputs()) {
            IOutputLayer outputLayer = (IOutputLayer) verticesMap.get(outputName).getLayer();
            score += outputLayer.computeScore(l1, l2, true);
            // regularization is included only in the first output's score
            l1 = 0.0d;
            l2 = 0.0d;
        }
        for (TrainingListener listener : trainingListeners) {
            listener.onBackwardPass(this);
        }
    }

    /**
     * Feeds a single input array through a single-input graph.
     *
     * @param input the network's only input
     * @param train whether to run the forward pass in training mode
     * @return map of vertex name to activation
     * @throws UnsupportedOperationException if the graph does not have exactly one input
     */
    public Map<String, INDArray> feedForward(INDArray input, boolean train) {
        if (this.numInputArrays == 1) {
            setInput(0, input);
            return feedForward(train);
        }
        throw new UnsupportedOperationException("Cannot feedForward with single input for graph network with " + this.numInputArrays + " expected inputs");
    }

    /**
     * Sets every network input and feeds them through the graph.
     *
     * @param input one array per network input, in input order
     * @param train whether to run the forward pass in training mode
     * @return map of vertex name to activation
     * @throws UnsupportedOperationException if the number of arrays differs from the number of inputs
     */
    public Map<String, INDArray> feedForward(INDArray[] input, boolean train) {
        int provided = input.length;
        if (provided != this.numInputArrays) {
            throw new UnsupportedOperationException("Cannot feedForward with " + provided + " inputs for graph network with " + this.numInputArrays + " expected inputs");
        }
        int position = 0;
        for (INDArray array : input) {
            setInput(position++, array);
        }
        return feedForward(train);
    }

    /**
     * Feeds the currently set inputs through the graph in inference (non-training) mode.
     *
     * @return map of vertex name to activation
     */
    public Map<String, INDArray> feedForward() {
        final boolean train = false;
        return feedForward(train);
    }

    /**
     * Feeds the currently set inputs through every vertex, including output layers.
     *
     * @param train whether to run the forward pass in training mode
     * @return map of vertex name to activation
     */
    public Map<String, INDArray> feedForward(boolean train) {
        final boolean excludeOutputLayers = false;
        return feedForward(train, excludeOutputLayers);
    }

    /**
     * Core forward pass over the vertices in topological order.
     * <p>
     * Input vertices publish {@code this.inputs[vertexIndex]} and hand a {@code dup()} of it to each
     * consumer. Every other vertex runs {@code doForward} and passes its output to its consumers. Only
     * outputs of vertices that have a layer are recorded in the returned map.
     *
     * @param train               passed to {@code doForward}
     * @param excludeOutputLayers when true, output vertices backed by an {@link IOutputLayer} are skipped
     * @return map of vertex name to activation (input vertices and layer vertices only)
     */
    private Map<String, INDArray> feedForward(boolean train, boolean excludeOutputLayers) {
        Map<String, INDArray> activations = new HashMap<>();
        for (int vertexIdx : this.topologicalOrder) {
            GraphVertex current = this.vertices[vertexIdx];
            if (current.isInputVertex()) {
                VertexIndices[] consumers = current.getOutputVertices();
                INDArray input = this.inputs[current.getVertexIndex()];
                activations.put(current.getVertexName(), input);
                for (VertexIndices consumer : consumers) {
                    this.vertices[consumer.getVertexIndex()].setInput(consumer.getVertexEdgeNumber(), input.dup());
                }
                continue;
            }
            boolean skippedOutputLayer = excludeOutputLayers
                    && current.isOutputVertex()
                    && current.hasLayer()
                    && current.getLayer() instanceof IOutputLayer;
            if (skippedOutputLayer) {
                continue;
            }
            INDArray out = current.doForward(train);
            if (current.hasLayer()) {
                activations.put(current.getVertexName(), out);
            }
            VertexIndices[] consumers = current.getOutputVertices();
            if (consumers == null) {
                continue;
            }
            for (VertexIndices consumer : consumers) {
                this.vertices[consumer.getVertexIndex()].setInput(consumer.getVertexEdgeNumber(), out);
            }
        }
        return activations;
    }

    /**
     * Runs inference on the given inputs.
     *
     * @param inputArrays one array per network input
     * @return one activation array per network output, in output order
     */
    public INDArray[] output(INDArray... inputArrays) {
        final boolean train = false;
        return output(train, inputArrays);
    }

    /**
     * Runs inference on a graph with exactly one output.
     *
     * @param inputArrays one array per network input
     * @return the activation of the single network output
     */
    public INDArray outputSingle(INDArray... inputArrays) {
        final boolean train = false;
        return outputSingle(train, inputArrays);
    }

    /**
     * Sets the inputs, runs a full forward pass and collects the network outputs.
     *
     * @param train       whether to run the forward pass in training mode
     * @param inputArrays one array per network input
     * @return activations looked up by network-output name, in output order. An entry is
     *         {@code null} if the forward-pass map has no entry for that output name.
     */
    public INDArray[] output(boolean train, INDArray... inputArrays) {
        setInputs(inputArrays);
        Map<String, INDArray> activations = feedForward(train);
        INDArray[] results = new INDArray[this.numOutputArrays];
        int slot = 0;
        for (String outputName : this.configuration.getNetworkOutputs()) {
            results[slot++] = activations.get(outputName);
        }
        return results;
    }

    /**
     * Returns the single output of a one-output graph.
     *
     * @param train       whether to run the forward pass in training mode
     * @param inputArrays one array per network input
     * @return the only element of {@link #output(boolean, INDArray...)}
     * @throws IllegalStateException if the graph does not have exactly one output
     */
    public INDArray outputSingle(boolean train, INDArray... inputArrays) {
        if (this.numOutputArrays == 1) {
            return output(train, inputArrays)[0];
        }
        throw new IllegalStateException("Cannot use outputSingle with ComputationGraph that does not have exactly 1 output. nOutputs: " + this.numOutputArrays);
    }

    /**
     * Backpropagates externally supplied output epsilons through the graph.
     *
     * @param epsilons one epsilon array per network output
     * @return the gradient computed by this call
     * @throws IllegalArgumentException if {@code epsilons} is null or its length differs from the number of outputs
     */
    public Gradient backpropGradient(INDArray... epsilons) {
        boolean validLength = epsilons != null && epsilons.length == this.numOutputArrays;
        if (!validLength) {
            throw new IllegalArgumentException("Invalid input: must have epsilons length equal to number of output arrays");
        }
        boolean truncatedBptt = this.configuration.getBackpropType() == BackpropType.TruncatedBPTT;
        calcBackpropGradients(truncatedBptt, epsilons);
        return this.gradient;
    }

    /**
     * Backpropagates through the graph and stores the result in {@code this.gradient}.
     * <p>
     * Vertices are visited in reverse topological order. Output vertices backed by an
     * {@link IOutputLayer} get their labels from {@code this.labels}. Other output vertices take their
     * epsilon from {@code iNDArrayArr}. Epsilons arriving at a vertex from several consumers are summed.
     * The per-variable gradients are then registered on a {@link DefaultGradient} backed by
     * {@code this.flattenedGradients}, under keys of the form {@code <vertexName>_<variableName>}.
     *
     * @param z           passed to {@code doBackward}; callers pass {@code true} when the backprop
     *                    type is truncated BPTT
     * @param iNDArrayArr externally supplied epsilons, indexed by position in the network-output list;
     *                    only read for output vertices whose layer is not an {@link IOutputLayer}
     */
    protected void calcBackpropGradients(boolean z, INDArray... iNDArrayArr) {
        // Lazily allocate the flattened gradient view on first use.
        if (this.flattenedGradients == null) {
            initGradientsView();
        }
        // (name, gradient, flattening order) triples accumulated over all vertices.
        LinkedList linkedList = new LinkedList();
        // zArr[v] becomes true once vertex v holds an epsilon, so later contributions are added to it.
        boolean[] zArr = new boolean[this.topologicalOrder.length];
        for (int length = this.topologicalOrder.length - 1; length >= 0; length--) {
            GraphVertex graphVertex = this.vertices[this.topologicalOrder[length]];
            if (!graphVertex.isInputVertex()) {
                // The first frozen layer met in reverse order ends the whole backward pass. Vertices
                // earlier in the topological order get no gradient.
                // NOTE(review): relies on frozen layers preceding every trainable layer -- confirm.
                if (graphVertex.hasLayer() && (graphVertex.getLayer() instanceof FrozenLayer)) {
                    break;
                }
                if (graphVertex.isOutputVertex()) {
                    int indexOf = this.configuration.getNetworkOutputs().indexOf(graphVertex.getVertexName());
                    if (graphVertex.getLayer() instanceof IOutputLayer) {
                        ((IOutputLayer) graphVertex.getLayer()).setLabels(this.labels[indexOf]);
                    } else {
                        // NOTE(review): this replaces (rather than adds to) any epsilon already set on
                        // this vertex by downstream consumers processed earlier in this loop.
                        graphVertex.setEpsilon(iNDArrayArr[indexOf]);
                        zArr[this.topologicalOrder[length]] = true;
                    }
                }
                Pair<Gradient, INDArray[]> doBackward = graphVertex.doBackward(z);
                // second[k] is the epsilon for this vertex's k-th input edge.
                INDArray[] second = doBackward.getSecond();
                VertexIndices[] inputVertices = graphVertex.getInputVertices();
                if (inputVertices != null) {
                    int i = 0;
                    for (VertexIndices vertexIndices : inputVertices) {
                        GraphVertex graphVertex2 = this.vertices[vertexIndices.getVertexIndex()];
                        if (zArr[graphVertex2.getVertexIndex()]) {
                            int i2 = i;
                            i++;
                            graphVertex2.setEpsilon(graphVertex2.getEpsilon().add(second[i2]));
                        } else {
                            int i3 = i;
                            i++;
                            graphVertex2.setEpsilon(second[i3]);
                        }
                        zArr[graphVertex2.getVertexIndex()] = true;
                    }
                }
                if (doBackward.getFirst() != null) {
                    Gradient first = doBackward.getFirst();
                    Map<String, INDArray> gradientForVariable = first.gradientForVariable();
                    // Two things happen here. Prepending into linkedList2 and then prepending each entry
                    // into linkedList keeps this vertex's variables in gradientForVariable() iteration
                    // order. Because vertices are visited in reverse, they end up in forward
                    // topological order.
                    LinkedList linkedList2 = new LinkedList();
                    for (Map.Entry<String, INDArray> entry : gradientForVariable.entrySet()) {
                        String key = entry.getKey();
                        linkedList2.addFirst(new Triple(graphVertex.getVertexName() + "_" + key, entry.getValue(), first.flatteningOrderForVariable(key)));
                    }
                    Iterator it = linkedList2.iterator();
                    while (it.hasNext()) {
                        linkedList.addFirst((Triple) it.next());
                    }
                }
            }
        }
        // Register every collected gradient on a view over the flattened gradient array.
        DefaultGradient defaultGradient = new DefaultGradient(this.flattenedGradients);
        Iterator it2 = linkedList.iterator();
        while (it2.hasNext()) {
            Triple triple = (Triple) it2.next();
            defaultGradient.setGradientFor((String) triple.getFirst(), (INDArray) triple.getSecond(), (Character) triple.getThird());
        }
        this.gradient = defaultGradient;
    }

    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    /**
     * Creates a copy of this graph.
     * <p>
     * The copy gets a cloned configuration and a {@code dup()} of the parameters. If a solver exists,
     * it also gets a {@code dup()} of the updater state. The listener collection is shared with the
     * copy, not copied. Layers that are frozen here are frozen again in the copy by vertex name.
     *
     * @return the new graph
     */
    public ComputationGraph m79clone() {
        ComputationGraph copy = new ComputationGraph(this.configuration.m31clone());
        copy.init(params().dup(), false);
        if (this.solver != null) {
            INDArray updaterState = getUpdater().getStateViewArray();
            if (updaterState != null) {
                copy.getUpdater().setStateViewArray(updaterState.dup());
            }
        }
        copy.listeners = this.listeners;
        for (int vertexIdx : this.topologicalOrder) {
            GraphVertex vertex = this.vertices[vertexIdx];
            if (!vertex.hasLayer()) {
                continue;
            }
            String name = vertex.getVertexName();
            if (getLayer(name) instanceof FrozenLayer) {
                copy.getVertex(name).setLayerAsFrozen();
            }
        }
        return copy;
    }

    /**
     * Sums the L2 regularization term reported by every layer.
     *
     * @return total of {@code layer.calcL2(true)} over all layers
     */
    public double calcL2() {
        Layer[] all = this.layers;
        double total = 0.0d;
        for (int idx = 0; idx < all.length; idx++) {
            total += all[idx].calcL2(true);
        }
        return total;
    }

    /**
     * Sums the L1 regularization term reported by every layer.
     *
     * @return total of {@code layer.calcL1(true)} over all layers
     */
    public double calcL1() {
        Layer[] all = this.layers;
        double total = 0.0d;
        for (int idx = 0; idx < all.length; idx++) {
            total += all[idx].calcL1(true);
        }
        return total;
    }

    /**
     * Installs the listeners on this graph, on every layer, and on the solver if one exists.
     * <p>
     * Calls {@code init()} first when the layers have not been created yet. Listeners that are also
     * {@link TrainingListener}s are remembered separately for forward/backward-pass callbacks.
     *
     * @param collection the listeners (may be null)
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setListeners(Collection<IterationListener> collection) {
        this.listeners = collection;
        if (this.layers == null) {
            init();
        }
        for (Layer layer : this.layers) {
            layer.setListeners(collection);
        }
        if (this.solver != null) {
            this.solver.setListeners(collection);
        }
        this.trainingListeners.clear();
        if (collection == null) {
            return;
        }
        for (IterationListener candidate : collection) {
            if (candidate instanceof TrainingListener) {
                this.trainingListeners.add((TrainingListener) candidate);
            }
        }
    }

    /**
     * Varargs form of {@link #setListeners(Collection)}. Null entries are dropped, and a null or empty
     * array installs an empty listener list.
     *
     * @param iterationListenerArr the listeners (may be null or contain nulls)
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setListeners(IterationListener... iterationListenerArr) {
        List<IterationListener> nonNull = new ArrayList<>();
        if (iterationListenerArr != null) {
            for (IterationListener candidate : iterationListenerArr) {
                if (candidate == null) {
                    continue;
                }
                nonNull.add(candidate);
            }
        }
        setListeners(nonNull);
    }

    /**
     * Returns the listener collection last passed to {@code setListeners} (may be null).
     *
     * @return the installed listeners
     */
    public Collection<IterationListener> getListeners() {
        Collection<IterationListener> installed = this.listeners;
        return installed;
    }

    /**
     * Returns the graph updater held by the solver's optimizer.
     * <p>
     * If no solver exists yet, one is built from {@link #conf()} and the current listeners, and it is
     * given a fresh {@link ComputationGraphUpdater} for this graph.
     *
     * @return the updater
     */
    public ComputationGraphUpdater getUpdater() {
        if (this.solver == null) {
            this.solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
            ComputationGraphUpdater fresh = new ComputationGraphUpdater(this);
            this.solver.getOptimizer().setUpdaterComputationGraph(fresh);
        }
        ConvexOptimizer optimizer = this.solver.getOptimizer();
        return optimizer.getComputationGraphUpdater();
    }

    /**
     * Installs a graph updater on the solver's optimizer, building the solver first if needed.
     *
     * @param computationGraphUpdater the updater to use
     */
    public void setUpdater(ComputationGraphUpdater computationGraphUpdater) {
        if (this.solver == null) {
            this.solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
        }
        ConvexOptimizer optimizer = this.solver.getOptimizer();
        optimizer.setUpdaterComputationGraph(computationGraphUpdater);
    }

    /**
     * Returns the layer of the {@code i}-th network output.
     *
     * @param i zero-based index into the network-output list
     * @return the layer registered under that output's vertex name
     * @throws IllegalArgumentException if {@code i} is negative or not less than the number of outputs
     */
    public Layer getOutputLayer(int i) {
        // Reject negative indices here as well, so every bad index gets the same exception and
        // message instead of an IndexOutOfBoundsException from List.get.
        if (i < 0 || i >= this.numOutputArrays) {
            throw new IllegalArgumentException("Invalid index: cannot get output layer " + i + ", total number of network outputs = " + this.numOutputArrays);
        }
        return getLayer(this.configuration.getNetworkOutputs().get(i));
    }

    /**
     * Returns the network parameters.
     *
     * @param z when true, return the existing flattened parameter array. When false, build a new
     *          'f'-order flattened array from each layer's {@code params()}, taken in topological
     *          order; layers returning null are skipped.
     * @return the parameters
     */
    public INDArray params(boolean z) {
        if (z) {
            return this.flattenedParams;
        }
        List<INDArray> perLayer = new ArrayList<>(this.layers.length);
        for (int vertexIdx : this.topologicalOrder) {
            GraphVertex vertex = this.vertices[vertexIdx];
            if (!vertex.hasLayer()) {
                continue;
            }
            INDArray layerParams = vertex.getLayer().params();
            if (layerParams != null) {
                perLayer.add(layerParams);
            }
        }
        return Nd4j.toFlattened('f', perLayer);
    }

    /**
     * Scores a single-input/single-output data set in inference mode.
     *
     * @param dataSet the data
     * @return the score
     */
    public double score(DataSet dataSet) {
        final boolean training = false;
        return score(dataSet, training);
    }

    /**
     * Scores a {@link DataSet} by converting it to a {@link MultiDataSet}.
     *
     * @param dataSet the data
     * @param z       training flag forwarded to {@link #score(MultiDataSet, boolean)}
     * @return the score
     * @throws UnsupportedOperationException unless the graph has exactly one input and one output
     */
    public double score(DataSet dataSet, boolean z) {
        boolean singleInputAndOutput = this.numInputArrays == 1 && this.numOutputArrays == 1;
        if (!singleInputAndOutput) {
            throw new UnsupportedOperationException("Cannot score ComputationGraph network with  DataSet: network does not have 1 input and 1 output arrays");
        }
        return score(ComputationGraphUtil.toMultiDataSet(dataSet), z);
    }

    /**
     * Scores a multi-data-set in inference mode.
     *
     * @param multiDataSet the data
     * @return the score
     */
    public double score(MultiDataSet multiDataSet) {
        final boolean training = false;
        return score(multiDataSet, training);
    }

    /**
     * Computes the total score of the network on a multi-data-set.
     * <p>
     * The features are fed forward, the labels are set, and each network output's
     * {@code computeScore} is summed. The L1/L2 penalty is added only with the first output. Mask
     * arrays from the data set are installed for the duration of the call. They are always cleared
     * afterwards, including on the early-return and exception paths.
     *
     * @param multiDataSet features, labels and optional masks
     * @param z            training flag passed to the forward pass and to {@code computeScore}
     * @return the summed score, or {@code 0.0} (with a warning logged) if a network output is not an
     *         {@link IOutputLayer}
     */
    public double score(MultiDataSet multiDataSet, boolean z) {
        boolean hasMaskArrays = multiDataSet.hasMaskArrays();
        if (hasMaskArrays) {
            setLayerMaskArrays(multiDataSet.getFeaturesMaskArrays(), multiDataSet.getLabelsMaskArrays());
        }
        try {
            feedForward(multiDataSet.getFeatures(), z);
            INDArray[] labelArrays = multiDataSet.getLabels();
            setLabels(labelArrays);
            double l1Penalty = calcL1();
            double l2Penalty = calcL2();
            double total = 0.0d;
            int labelIdx = 0;
            for (String outputName : this.configuration.getNetworkOutputs()) {
                Layer layer = this.verticesMap.get(outputName).getLayer();
                // instanceof is false for null, so this also covers vertices without a layer.
                if (!(layer instanceof IOutputLayer)) {
                    log.warn("Cannot calculate score: vertex \"" + outputName + "\" is not an output layer");
                    return 0.0d;
                }
                IOutputLayer outputLayer = (IOutputLayer) layer;
                outputLayer.setLabels(labelArrays[labelIdx++]);
                total += outputLayer.computeScore(l1Penalty, l2Penalty, z);
                // Regularization is counted once, with the first output only.
                l1Penalty = 0.0d;
                l2Penalty = 0.0d;
            }
            return total;
        } finally {
            // Previously the early return above skipped this, leaving stale masks on the layers.
            if (hasMaskArrays) {
                clearLayerMaskArrays();
            }
        }
    }

    /**
     * Per-example scores for a {@link DataSet}, computed by converting it to a {@link MultiDataSet}.
     *
     * @param dataSet the data
     * @param z       whether to include the L1/L2 regularization terms
     * @return one score per example
     * @throws UnsupportedOperationException unless the graph has exactly one input and one output
     */
    public INDArray scoreExamples(DataSet dataSet, boolean z) {
        boolean singleInputAndOutput = this.numInputArrays == 1 && this.numOutputArrays == 1;
        if (!singleInputAndOutput) {
            throw new UnsupportedOperationException("Cannot score ComputationGraph network with  DataSet: network does not have 1 input and 1 output arrays");
        }
        return scoreExamples(ComputationGraphUtil.toMultiDataSet(dataSet), z);
    }

    /**
     * Computes a per-example score, summed over all network outputs.
     * <p>
     * The features are fed forward in inference mode and the labels are set. Each output's
     * {@code computeScoreForExamples} result is then accumulated in place into the first output's
     * array. Mask arrays from the data set are installed for the duration of the call. They are always
     * cleared afterwards, including when an exception is thrown.
     *
     * @param multiDataSet features, labels and optional masks
     * @param z            whether to include the L1/L2 regularization terms (added with the first output only)
     * @return per-example scores, or {@code null} if the graph has no outputs
     * @throws UnsupportedOperationException if a network output is not an {@link IOutputLayer}
     */
    public INDArray scoreExamples(MultiDataSet multiDataSet, boolean z) {
        boolean hasMaskArrays = multiDataSet.hasMaskArrays();
        if (hasMaskArrays) {
            setLayerMaskArrays(multiDataSet.getFeaturesMaskArrays(), multiDataSet.getLabelsMaskArrays());
        }
        try {
            feedForward(multiDataSet.getFeatures(), false);
            setLabels(multiDataSet.getLabels());
            INDArray total = null;
            double l1Penalty = z ? calcL1() : 0.0d;
            double l2Penalty = z ? calcL2() : 0.0d;
            int labelIdx = 0;
            for (String outputName : this.configuration.getNetworkOutputs()) {
                Layer layer = this.verticesMap.get(outputName).getLayer();
                // instanceof is false for null, so this also covers vertices without a layer.
                if (!(layer instanceof IOutputLayer)) {
                    throw new UnsupportedOperationException("Cannot calculate score: vertex \"" + outputName + "\" is not an output layer");
                }
                IOutputLayer outputLayer = (IOutputLayer) layer;
                outputLayer.setLabels(this.labels[labelIdx++]);
                INDArray perExample = outputLayer.computeScoreForExamples(l1Penalty, l2Penalty);
                if (total == null) {
                    total = perExample;
                } else {
                    total.addi(perExample);
                }
                l1Penalty = 0.0d;
                l2Penalty = 0.0d;
            }
            return total;
        } finally {
            // Previously a thrown exception skipped this, leaving stale masks on the layers.
            if (hasMaskArrays) {
                clearLayerMaskArrays();
            }
        }
    }

    /**
     * Fits the graph on the inputs, labels and mask arrays currently stored on it.
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void fit() {
        INDArray[] currentInputs = this.inputs;
        INDArray[] currentLabels = this.labels;
        fit(currentInputs, currentLabels, this.inputMaskArrays, this.labelMaskArrays);
    }

    /**
     * Not supported for computation graphs.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void update(INDArray gradient, String paramType) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Applies an externally supplied gradient to the graph.
     * <p>
     * Each variable key must have the form {@code <layerName>_<paramName>}. The entry is stored in
     * this graph's gradient map and forwarded to the named layer's {@code update}. Finally, the
     * flattened gradient is installed as the layers' backprop gradient view.
     *
     * @param gradient gradient whose flattened array must have {@code numParams(true)} elements
     * @throws IllegalArgumentException if the flattened gradient length does not match
     * @throws IllegalStateException    if a key contains no '_' separator
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void update(Gradient gradient) {
        int expectedLength = numParams(true);
        if (gradient.gradient().length() != expectedLength) {
            throw new IllegalArgumentException("Invalid input: expect gradients array of length " + expectedLength);
        }
        for (Map.Entry<String, INDArray> entry : gradient.gradientForVariable().entrySet()) {
            String key = entry.getKey();
            INDArray value = entry.getValue();
            int separator = key.indexOf('_');
            if (separator == -1) {
                throw new IllegalStateException("Invalid param key: not have layer separator: \"" + key + "\"");
            }
            String layerName = key.substring(0, separator);
            // Everything after the first separator is the parameter name, as in getParam/setParam.
            // key.split("_")[1] would silently drop any suffix after a second '_'.
            String paramName = key.substring(separator + 1);
            this.gradient.gradientForVariable().put(key, value);
            getLayer(layerName).update(value, paramName);
        }
        setBackpropGradientsViewArray(gradient.gradient());
    }

    /**
     * Reports a STANDALONE heartbeat event for this model, at most once per instance.
     * <p>
     * The {@code initDone} flag is set before reporting, so later calls do nothing.
     * NOTE(review): the {@code task} argument is ignored. The reported task is rebuilt via
     * {@code ModelSerializer.taskByModel(this)}.
     *
     * @param task unused
     */
    private void update(Task task) {
        if (!this.initDone) {
            this.initDone = true;
            Heartbeat heartbeat = Heartbeat.getInstance();
            Task modelTask = ModelSerializer.taskByModel(this);
            heartbeat.reportEvent(Event.STANDALONE, EnvironmentUtils.buildEnvironment(), modelTask);
        }
    }

    /**
     * Returns the score most recently stored on this graph (by {@code computeGradientAndScore} or {@code setScore}).
     *
     * @return the stored score
     */
    @Override // org.deeplearning4j.nn.api.Model
    public double score() {
        double current = this.score;
        return current;
    }

    /**
     * Overwrites the stored score.
     *
     * @param value the new score
     */
    public void setScore(double value) {
        score = value;
    }

    /**
     * Not supported for computation graphs.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void accumulateScore(double accum) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Returns the flattened parameter array (the backing view, not a copy).
     *
     * @return same as {@code params(true)}
     */
    @Override // org.deeplearning4j.nn.api.Model
    public INDArray params() {
        final boolean backingView = true;
        return params(backingView);
    }

    /**
     * Returns the total parameter count across all layers.
     *
     * @return same as {@code numParams(true)}
     */
    @Override // org.deeplearning4j.nn.api.Model
    public int numParams() {
        final boolean backwards = true;
        return numParams(backwards);
    }

    /**
     * Sums {@code numParams(z)} over every layer.
     *
     * @param z forwarded unchanged to each layer's {@code numParams}
     * @return the total parameter count
     */
    @Override // org.deeplearning4j.nn.api.Model
    public int numParams(boolean z) {
        Layer[] all = this.layers;
        int total = 0;
        for (int idx = 0; idx < all.length; idx++) {
            total += all[idx].numParams(z);
        }
        return total;
    }

    /**
     * Sets all parameters from a flattened array.
     * <p>
     * If {@code iNDArray} is the current flattened parameter view, nothing happens. If a flattened view
     * of the same length exists, the values are copied into it. Otherwise each layer with parameters
     * (in topological order) receives its consecutive slice of row 0.
     *
     * @param iNDArray the new parameters
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setParams(INDArray iNDArray) {
        if (iNDArray == this.flattenedParams) {
            return;
        }
        boolean sameSizedView = this.flattenedParams != null && this.flattenedParams.length() == iNDArray.length();
        if (sameSizedView) {
            this.flattenedParams.assign(iNDArray);
            return;
        }
        int offset = 0;
        for (int vertexIdx : this.topologicalOrder) {
            GraphVertex vertex = this.vertices[vertexIdx];
            if (!vertex.hasLayer()) {
                continue;
            }
            Layer layer = vertex.getLayer();
            int count = layer.numParams();
            if (count <= 0) {
                continue;
            }
            layer.setParams(iNDArray.get(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(offset, offset + count)}));
            offset += count;
        }
    }

    /**
     * Not supported for computation graphs.
     *
     * @throws RuntimeException always
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setParamsViewArray(INDArray params) {
        throw new RuntimeException("Not yet implemented");
    }

    /**
     * Hands each layer with parameters its consecutive slice of row 0 of the flattened gradient array,
     * in topological order.
     *
     * @param iNDArray the flattened gradient array
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setBackpropGradientsViewArray(INDArray iNDArray) {
        int offset = 0;
        for (int vertexIdx : this.topologicalOrder) {
            GraphVertex vertex = this.vertices[vertexIdx];
            if (!vertex.hasLayer()) {
                continue;
            }
            Layer layer = vertex.getLayer();
            int count = layer.numParams();
            if (count <= 0) {
                continue;
            }
            layer.setBackpropGradientsViewArray(iNDArray.get(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(offset, offset + count)}));
            offset += count;
        }
    }

    /**
     * Not supported for computation graphs.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void applyLearningRateScoreDecay() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Not supported: a computation graph cannot be pretrained from a single array.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void fit(INDArray data) {
        throw new UnsupportedOperationException("Cannot pretrain ComputationGraph with single INDArray");
    }

    /**
     * Not supported for computation graphs.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void iterate(INDArray input) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Returns the gradient from the most recent backward pass or {@code update(Gradient)}.
     *
     * @return the stored gradient (may be null before any backward pass)
     */
    @Override // org.deeplearning4j.nn.api.Model
    public Gradient gradient() {
        Gradient current = this.gradient;
        return current;
    }

    /**
     * Pairs the stored gradient with the stored score.
     *
     * @return {@code (gradient(), score())}
     */
    @Override // org.deeplearning4j.nn.api.Model
    public Pair<Gradient, Double> gradientAndScore() {
        Gradient currentGradient = gradient();
        Double currentScore = Double.valueOf(score());
        return new Pair<>(currentGradient, currentScore);
    }

    /**
     * Returns {@code size(0)} of the first input array (presumably the minibatch dimension).
     *
     * @return the first input's leading dimension
     */
    @Override // org.deeplearning4j.nn.api.Model
    public int batchSize() {
        INDArray firstInput = this.inputs[0];
        return firstInput.size(0);
    }

    /**
     * Returns the graph-wide default layer configuration.
     *
     * @return the default configuration
     */
    @Override // org.deeplearning4j.nn.api.Model
    public NeuralNetConfiguration conf() {
        NeuralNetConfiguration defaults = this.defaultConfiguration;
        return defaults;
    }

    /**
     * Not supported: the configuration of a computation graph cannot be replaced.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setConf(NeuralNetConfiguration conf) {
        throw new UnsupportedOperationException();
    }

    /**
     * Returns the only input of a single-input graph.
     *
     * @return the first input array, or {@code null} if no inputs are set
     * @throws UnsupportedOperationException if the graph does not have exactly one input
     */
    @Override // org.deeplearning4j.nn.api.Model
    public INDArray input() {
        if (this.numInputArrays != 1) {
            throw new UnsupportedOperationException("Cannot return single input: ComputationGraph  has multiple inputs");
        }
        return this.inputs == null ? null : this.inputs[0];
    }

    /**
     * No validation is performed for computation graphs.
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void validateInput() {
        // Intentionally empty.
    }

    /**
     * Returns the solver's optimizer.
     * NOTE(review): unlike {@code getUpdater()}, this does not create the solver, so it throws a
     * NullPointerException if no solver has been built yet.
     *
     * @return the optimizer
     */
    @Override // org.deeplearning4j.nn.api.Model
    public ConvexOptimizer getOptimizer() {
        Solver current = this.solver;
        return current.getOptimizer();
    }

    /**
     * Looks up a parameter by a key of the form {@code <layerName>_<paramName>}, split at the first '_'.
     *
     * @param str the parameter key
     * @return the layer's parameter array
     * @throws IllegalStateException if the key contains no '_'
     */
    @Override // org.deeplearning4j.nn.api.Model
    public INDArray getParam(String str) {
        int separator = str.indexOf('_');
        if (separator < 0) {
            throw new IllegalStateException("Invalid param key: not have layer separator: \"" + str + "\"");
        }
        String layerName = str.substring(0, separator);
        String paramName = str.substring(separator + 1);
        return getLayer(layerName).getParam(paramName);
    }

    /**
     * Not supported for computation graphs.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void initParams() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Returns all parameters keyed {@code <layerName>_<paramName>}.
     *
     * @return same as {@code paramTable(false)}
     */
    @Override // org.deeplearning4j.nn.api.Model
    public Map<String, INDArray> paramTable() {
        final boolean backpropParamsOnly = false;
        return paramTable(backpropParamsOnly);
    }

    /**
     * Merges every layer's parameter table into one insertion-ordered map, prefixing each key with
     * {@code <layerName>_}.
     *
     * @param z forwarded unchanged to each layer's {@code paramTable}
     * @return the combined parameter table
     */
    @Override // org.deeplearning4j.nn.api.Model
    public Map<String, INDArray> paramTable(boolean z) {
        Map<String, INDArray> combined = new LinkedHashMap<>();
        for (Layer layer : this.layers) {
            Map<String, INDArray> layerTable = layer.paramTable(z);
            for (Map.Entry<String, INDArray> param : layerTable.entrySet()) {
                String qualifiedKey = layer.conf().getLayer().getLayerName() + "_" + param.getKey();
                combined.put(qualifiedKey, param.getValue());
            }
        }
        return combined;
    }

    /**
     * Not supported for computation graphs.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setParamTable(Map<String, INDArray> paramTable) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Sets a parameter addressed by {@code <layerName>_<paramName>}, split at the first '_'.
     *
     * @param str      the parameter key
     * @param iNDArray the new value
     * @throws IllegalStateException if the key contains no '_'
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setParam(String str, INDArray iNDArray) {
        int separator = str.indexOf('_');
        if (separator < 0) {
            throw new IllegalStateException("Invalid param key: not have layer separator: \"" + str + "\"");
        }
        String layerName = str.substring(0, separator);
        String paramName = str.substring(separator + 1);
        getLayer(layerName).setParam(paramName, iNDArray);
    }

    /**
     * Drops references to the stored inputs, labels and mask arrays.
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void clear() {
        inputs = null;
        labels = null;
        inputMaskArrays = null;
        labelMaskArrays = null;
    }

    /**
     * Performs one forward step through the graph for RNN inference.
     * <p>
     * Vertices are visited in topological order. Recurrent layers and nested
     * {@link MultiLayerNetwork}s are driven through their own {@code rnnTimeStep}, using only the
     * vertex's first input. All other vertices use {@code doForward(false)}. If every input is rank 2,
     * any rank-3 output whose size along dimension 2 is 1 is reduced with
     * {@code tensorAlongDimension(0, 1, 0)}, presumably to a 2d array -- confirm.
     * <p>
     * NOTE(review): {@code this.inputs} is set to the given arrays during the pass and reset to
     * {@code null} afterwards, discarding any inputs that were previously set on the graph.
     *
     * @param iNDArrayArr one array per network input
     * @return one array per network output, in network-output order
     */
    public INDArray[] rnnTimeStep(INDArray... iNDArrayArr) {
        INDArray doForward;
        this.inputs = iNDArrayArr;
        // z: true iff every input is rank 2.
        boolean z = true;
        int length = iNDArrayArr.length;
        int i = 0;
        while (true) {
            if (i >= length) {
                break;
            }
            if (iNDArrayArr[i].rank() != 2) {
                z = false;
                break;
            }
            i++;
        }
        INDArray[] iNDArrayArr2 = new INDArray[this.numOutputArrays];
        for (int i2 : this.topologicalOrder) {
            GraphVertex graphVertex = this.vertices[i2];
            if (graphVertex.isInputVertex()) {
                VertexIndices[] outputVertices = graphVertex.getOutputVertices();
                INDArray iNDArray = iNDArrayArr[graphVertex.getVertexIndex()];
                for (VertexIndices vertexIndices : outputVertices) {
                    this.vertices[vertexIndices.getVertexIndex()].setInput(vertexIndices.getVertexEdgeNumber(), iNDArray.dup());
                }
            } else {
                if (graphVertex.hasLayer()) {
                    Layer layer = graphVertex.getLayer();
                    doForward = layer instanceof RecurrentLayer ? ((RecurrentLayer) layer).rnnTimeStep(graphVertex.getInputs()[0]) : layer instanceof MultiLayerNetwork ? ((MultiLayerNetwork) layer).rnnTimeStep(graphVertex.getInputs()[0]) : graphVertex.doForward(false);
                } else {
                    doForward = graphVertex.doForward(false);
                }
                // Record network outputs at their position in the network-output list.
                if (graphVertex.isOutputVertex()) {
                    iNDArrayArr2[this.configuration.getNetworkOutputs().indexOf(graphVertex.getVertexName())] = doForward;
                }
                VertexIndices[] outputVertices2 = graphVertex.getOutputVertices();
                if (outputVertices2 != null) {
                    for (VertexIndices vertexIndices2 : outputVertices2) {
                        this.vertices[vertexIndices2.getVertexIndex()].setInput(vertexIndices2.getVertexEdgeNumber(), doForward);
                    }
                }
            }
        }
        // Rank-2 inputs in, so collapse single-step rank-3 outputs.
        if (z) {
            for (int i3 = 0; i3 < iNDArrayArr2.length; i3++) {
                if (iNDArrayArr2[i3].rank() == 3 && iNDArrayArr2[i3].size(2) == 1) {
                    iNDArrayArr2[i3] = iNDArrayArr2[i3].tensorAlongDimension(0, new int[]{1, 0});
                }
            }
        }
        this.inputs = null;
        return iNDArrayArr2;
    }

    /**
     * Returns the stored RNN state of the layer at position {@code i} in the layer array.
     *
     * @param i index into the layer array
     * @return the layer's state, or {@code null} if it is not a recurrent layer
     */
    public Map<String, INDArray> rnnGetPreviousState(int i) {
        String layerName = this.layers[i].conf().getLayer().getLayerName();
        return rnnGetPreviousState(layerName);
    }

    /**
     * Returns the stored RNN state of the named layer vertex.
     *
     * @param str vertex name (must exist in the graph)
     * @return the layer's state, or {@code null} if the vertex has no recurrent layer
     */
    public Map<String, INDArray> rnnGetPreviousState(String str) {
        Layer layer = this.verticesMap.get(str).getLayer();
        if (layer instanceof RecurrentLayer) {
            return ((RecurrentLayer) layer).rnnGetPreviousState();
        }
        return null;
    }

    /**
     * Collects the stored state of every {@link RecurrentLayer}, keyed by layer name.
     *
     * @return layer name mapped to that layer's state
     */
    public Map<String, Map<String, INDArray>> rnnGetPreviousStates() {
        Map<String, Map<String, INDArray>> states = new HashMap<>();
        for (Layer candidate : this.layers) {
            if (!(candidate instanceof RecurrentLayer)) {
                continue;
            }
            String layerName = candidate.conf().getLayer().getLayerName();
            states.put(layerName, ((RecurrentLayer) candidate).rnnGetPreviousState());
        }
        return states;
    }

    /**
     * Sets the stored RNN state of the layer at position {@code i} in the layer array.
     *
     * @param i   index into the layer array
     * @param map the state to install
     */
    public void rnnSetPreviousState(int i, Map<String, INDArray> map) {
        String layerName = this.layers[i].conf().getLayer().getLayerName();
        rnnSetPreviousState(layerName, map);
    }

    /**
     * Sets the stored RNN state of the named layer vertex.
     *
     * @param str vertex name (must exist in the graph)
     * @param map the state to install
     * @throws UnsupportedOperationException if the vertex has no recurrent layer
     */
    public void rnnSetPreviousState(String str, Map<String, INDArray> map) {
        Layer layer = this.verticesMap.get(str).getLayer();
        if (!(layer instanceof RecurrentLayer)) {
            throw new UnsupportedOperationException("Layer \"" + str + "\" is not a recurrent layer. Cannot set state");
        }
        RecurrentLayer recurrent = (RecurrentLayer) layer;
        recurrent.rnnSetPreviousState(map);
    }

    /**
     * Installs the stored state of several recurrent layers at once.
     *
     * @param map layer name mapped to the state for that layer
     */
    public void rnnSetPreviousStates(Map<String, Map<String, INDArray>> map) {
        for (Map.Entry<String, Map<String, INDArray>> layerState : map.entrySet()) {
            String layerName = layerState.getKey();
            rnnSetPreviousState(layerName, layerState.getValue());
        }
    }

    /**
     * Clears the stored state of every recurrent layer and nested {@link MultiLayerNetwork}.
     * Does nothing if the layers have not been created yet.
     */
    public void rnnClearPreviousState() {
        if (this.layers != null) {
            for (Layer candidate : this.layers) {
                if (candidate instanceof RecurrentLayer) {
                    ((RecurrentLayer) candidate).rnnClearPreviousState();
                } else if (candidate instanceof MultiLayerNetwork) {
                    ((MultiLayerNetwork) candidate).rnnClearPreviousState();
                }
            }
        }
    }

    /**
     * Fits the graph with truncated backpropagation through time.
     * <p>
     * All rank-3 inputs and labels must share the same size along dimension 2. Otherwise a warning is
     * logged and nothing is trained. The series is cut into windows of
     * {@code configuration.getTbpttFwdLength()} steps, and the last window may be shorter. For each
     * window, the inputs, labels and masks are sliced, {@code solver.optimize()} is run once, and
     * {@code rnnUpdateStateWithTBPTTState()} is called. RNN state is cleared before the first window
     * and after the last one.
     * <p>
     * NOTE(review): if no input or label is rank 3, {@code i} stays -1. A single iteration then runs
     * with end index -1: non-rank-3 arrays are passed whole, but any mask slice would use
     * {@code interval(0, -1)}. Confirm this path is intended.
     *
     * @param iNDArrayArr  network inputs; rank-3 arrays are sliced along dimension 2
     * @param iNDArrayArr2 labels; rank-3 arrays are sliced along dimension 2
     * @param iNDArrayArr3 feature mask arrays (may be null, entries may be null); sliced along dimension 1
     * @param iNDArrayArr4 label mask arrays (may be null, entries may be null); sliced along dimension 1
     */
    protected void doTruncatedBPTT(INDArray[] iNDArrayArr, INDArray[] iNDArrayArr2, INDArray[] iNDArrayArr3, INDArray[] iNDArrayArr4) {
        if (this.flattenedGradients == null) {
            initGradientsView();
        }
        // i: common time-series length (size along dim 2) of all rank-3 inputs and labels; -1 if none.
        int i = -1;
        for (INDArray iNDArray : iNDArrayArr) {
            if (iNDArray.rank() == 3) {
                if (i == -1) {
                    i = iNDArray.size(2);
                } else if (i != iNDArray.size(2)) {
                    log.warn("Cannot do TBPTT with time series of different lengths");
                    return;
                }
            }
        }
        for (INDArray iNDArray2 : iNDArrayArr2) {
            if (iNDArray2.rank() == 3) {
                if (i == -1) {
                    i = iNDArray2.size(2);
                } else if (i != iNDArray2.size(2)) {
                    log.warn("Cannot do TBPTT with time series of different lengths");
                    return;
                }
            }
        }
        // i2: number of windows, i.e. ceil(i / tbpttFwdLength) for positive i.
        int tbpttFwdLength = this.configuration.getTbpttFwdLength();
        int i2 = i / tbpttFwdLength;
        if (i % tbpttFwdLength != 0) {
            i2++;
        }
        rnnClearPreviousState();
        // Per-window buffers, reused across iterations.
        INDArray[] iNDArrayArr5 = new INDArray[iNDArrayArr.length];
        INDArray[] iNDArrayArr6 = new INDArray[iNDArrayArr2.length];
        INDArray[] iNDArrayArr7 = iNDArrayArr3 != null ? new INDArray[iNDArrayArr3.length] : null;
        INDArray[] iNDArrayArr8 = iNDArrayArr4 != null ? new INDArray[iNDArrayArr4.length] : null;
        for (int i3 = 0; i3 < i2; i3++) {
            // Window start i4 and end i5; i5 is clamped to the series length.
            int i4 = i3 * tbpttFwdLength;
            int i5 = i4 + tbpttFwdLength;
            if (i5 > i) {
                i5 = i;
            }
            for (int i6 = 0; i6 < iNDArrayArr.length; i6++) {
                if (iNDArrayArr[i6].rank() != 3) {
                    iNDArrayArr5[i6] = iNDArrayArr[i6];
                } else {
                    iNDArrayArr5[i6] = iNDArrayArr[i6].get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(i4, i5)});
                }
            }
            for (int i7 = 0; i7 < iNDArrayArr2.length; i7++) {
                if (iNDArrayArr2[i7].rank() != 3) {
                    iNDArrayArr6[i7] = iNDArrayArr2[i7];
                } else {
                    iNDArrayArr6[i7] = iNDArrayArr2[i7].get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(i4, i5)});
                }
            }
            if (iNDArrayArr3 != null) {
                for (int i8 = 0; i8 < iNDArrayArr3.length; i8++) {
                    if (iNDArrayArr3[i8] != null) {
                        iNDArrayArr7[i8] = iNDArrayArr3[i8].get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(i4, i5)});
                    }
                }
            }
            if (iNDArrayArr4 != null) {
                for (int i9 = 0; i9 < iNDArrayArr4.length; i9++) {
                    if (iNDArrayArr4[i9] != null) {
                        iNDArrayArr8[i9] = iNDArrayArr4[i9].get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(i4, i5)});
                    }
                }
            }
            setInputs(iNDArrayArr5);
            setLabels(iNDArrayArr6);
            setLayerMaskArrays(iNDArrayArr7, iNDArrayArr8);
            // Lazily build the solver on first use.
            if (this.solver == null) {
                this.solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
            }
            this.solver.optimize();
            rnnUpdateStateWithTBPTTState();
        }
        rnnClearPreviousState();
        // Masks only need clearing when some were supplied.
        if (iNDArrayArr3 == null && iNDArrayArr4 == null) {
            return;
        }
        clearLayerMaskArrays();
    }

    /**
     * Forward pass that lets recurrent layers start from their stored state.
     * <p>
     * Vertices are visited in topological order. A {@link RecurrentLayer} or nested
     * {@link MultiLayerNetwork} is activated through its own {@code rnnActivateUsingStoredState}, using
     * only the vertex's first input; for a nested network the last element of the returned list is
     * used. Every other vertex uses {@code doForward(train)}.
     *
     * @param inputArrays       one array per network input
     * @param train             passed to the recurrent layers and to {@code doForward}
     * @param storeLastForTbptt passed through to the recurrent layers (presumably controls whether
     *                          their final state is kept for TBPTT -- confirm)
     * @return map of vertex name to activation (input vertices and layer vertices only)
     */
    public Map<String, INDArray> rnnActivateUsingStoredState(INDArray[] inputArrays, boolean train, boolean storeLastForTbptt) {
        Map<String, INDArray> activations = new HashMap<>();
        for (int vertexIdx : this.topologicalOrder) {
            GraphVertex current = this.vertices[vertexIdx];
            if (current.isInputVertex()) {
                VertexIndices[] consumers = current.getOutputVertices();
                INDArray input = inputArrays[current.getVertexIndex()];
                activations.put(current.getVertexName(), input);
                for (VertexIndices consumer : consumers) {
                    this.vertices[consumer.getVertexIndex()].setInput(consumer.getVertexEdgeNumber(), input.dup());
                }
                continue;
            }
            INDArray out;
            if (!current.hasLayer()) {
                out = current.doForward(train);
            } else {
                Layer layer = current.getLayer();
                if (layer instanceof RecurrentLayer) {
                    out = ((RecurrentLayer) layer).rnnActivateUsingStoredState(current.getInputs()[0], train, storeLastForTbptt);
                } else if (layer instanceof MultiLayerNetwork) {
                    List<INDArray> perLayer = ((MultiLayerNetwork) layer).rnnActivateUsingStoredState(current.getInputs()[0], train, storeLastForTbptt);
                    out = perLayer.get(perLayer.size() - 1);
                } else {
                    out = current.doForward(train);
                }
                activations.put(current.getVertexName(), out);
            }
            VertexIndices[] consumers = current.getOutputVertices();
            if (consumers == null) {
                continue;
            }
            for (VertexIndices consumer : consumers) {
                this.vertices[consumer.getVertexIndex()].setInput(consumer.getVertexEdgeNumber(), out);
            }
        }
        return activations;
    }

    /**
     * Sets the feature (input) and label (output) mask arrays for the network.
     *
     * <p>Feature masks are propagated through the graph in topological order via each vertex's
     * {@code feedForwardMaskArrays}. Label masks are set directly on the layer of the corresponding network
     * output vertex. Any previously set masks are cleared first.
     *
     * @param iNDArrayArr  feature masks, one entry per network input (entries may be null); null for none
     * @param iNDArrayArr2 label masks, one entry per network output (entries may be null); null for none
     * @throws IllegalArgumentException if a non-null argument has the wrong number of entries; in that case
     *                                  the network's current mask state is left unchanged
     */
    public void setLayerMaskArrays(INDArray[] iNDArrayArr, INDArray[] iNDArrayArr2) {
        // Validate both arguments before touching any state, so a bad call cannot leave masks half-applied
        if (iNDArrayArr != null && iNDArrayArr.length != this.numInputArrays) {
            throw new IllegalArgumentException("Invalid number of feature mask arrays: expected "
                    + this.numInputArrays + ", got " + iNDArrayArr.length);
        }
        if (iNDArrayArr2 != null && iNDArrayArr2.length != this.numOutputArrays) {
            throw new IllegalArgumentException("Invalid number of label mask arrays: expected "
                    + this.numOutputArrays + ", got " + iNDArrayArr2.length);
        }

        clearLayerMaskArrays();
        this.inputMaskArrays = iNDArrayArr;
        this.labelMaskArrays = iNDArrayArr2;

        if (iNDArrayArr != null) {
            // Minibatch size taken from the last non-null feature mask (masks presumably agree - confirm)
            int minibatchSize = -1;
            for (INDArray mask : iNDArrayArr) {
                if (mask != null) {
                    minibatchSize = mask.size(0);
                }
            }

            // Mask (and mask state) produced by each vertex, keyed by vertex index
            Map<Integer, Pair<INDArray, MaskState>> masksByVertex = new HashMap<>();
            for (int vertexIdx : this.topologicalOrder) {
                GraphVertex vertex = this.vertices[vertexIdx];
                if (vertex.isInputVertex()) {
                    masksByVertex.put(vertex.getVertexIndex(),
                            new Pair<>(iNDArrayArr[vertex.getVertexIndex()], MaskState.Active));
                    continue;
                }

                VertexIndices[] inputVertices = vertex.getInputVertices();
                INDArray[] inputMasks = null; // allocated lazily: stays null if no input carries a mask entry
                MaskState maskState = null;
                for (int j = 0; j < inputVertices.length; j++) {
                    Pair<INDArray, MaskState> upstream = masksByVertex.get(inputVertices[j].getVertexIndex());
                    if (upstream == null) {
                        continue;
                    }
                    if (inputMasks == null) {
                        inputMasks = new INDArray[inputVertices.length];
                    }
                    inputMasks[j] = upstream.getFirst();
                    // Any non-Passthrough state from an earlier input wins over later ones
                    if (maskState == null || maskState == MaskState.Passthrough) {
                        maskState = upstream.getSecond();
                    }
                }
                masksByVertex.put(vertexIdx, vertex.feedForwardMaskArrays(inputMasks, maskState, minibatchSize));
            }
        }

        if (iNDArrayArr2 != null) {
            List<String> outputNames = this.configuration.getNetworkOutputs();
            for (int j = 0; j < iNDArrayArr2.length; j++) {
                if (iNDArrayArr2[j] != null) {
                    this.verticesMap.get(outputNames.get(j)).getLayer().setMaskArray(iNDArrayArr2[j]);
                }
            }
        }
    }

    /**
     * Removes the mask array from every layer of the network (sets it to null) and forgets the stored
     * feature and label mask arrays.
     */
    public void clearLayerMaskArrays() {
        final Layer[] allLayers = this.layers;
        for (int idx = 0; idx < allLayers.length; idx++) {
            allLayers[idx].setMaskArray(null);
        }
        this.inputMaskArrays = null;
        this.labelMaskArrays = null;
    }

    /**
     * Copies each recurrent layer's TBPTT state into its "previous state", and asks nested
     * MultiLayerNetworks to do the same. Non-recurrent layers are left untouched.
     */
    protected void rnnUpdateStateWithTBPTTState() {
        for (Layer current : this.layers) {
            if (current instanceof RecurrentLayer) {
                RecurrentLayer rnn = (RecurrentLayer) current;
                rnn.rnnSetPreviousState(rnn.rnnGetTBPTTState());
                continue;
            }
            if (current instanceof MultiLayerNetwork) {
                ((MultiLayerNetwork) current).updateRnnStateWithTBPTTState();
            }
        }
    }

    /**
     * Evaluates the network on the given iterator, taking class label names from the iterator.
     *
     * @param dataSetIterator source of evaluation data
     * @return the accumulated evaluation
     */
    public Evaluation evaluate(DataSetIterator dataSetIterator) {
        final List<String> labelNames = null; // null => use dataSetIterator.getLabels()
        return evaluate(dataSetIterator, labelNames);
    }

    /**
     * Evaluates the network on the given iterator using top-1 accuracy.
     *
     * @param dataSetIterator source of evaluation data
     * @param list            class label names, or null to use the iterator's labels
     * @return the accumulated evaluation
     */
    public Evaluation evaluate(DataSetIterator dataSetIterator, List<String> list) {
        final int topN = 1;
        return evaluate(dataSetIterator, list, topN);
    }

    /**
     * Evaluates the network's first output against the labels of every DataSet produced by the iterator.
     *
     * <p>If a DataSet carries a feature mask, it is applied to the network for that DataSet's forward pass
     * and cleared afterwards. If it carries a label mask and the labels are rank 3, the mask is passed to
     * {@code evalTimeSeries} so masked (padded) time steps are not counted. Iteration stops at the first
     * DataSet with null features or labels. The iterator is consumed from its current position and not reset.
     *
     * @param dataSetIterator source of evaluation data
     * @param list            class label names; if null, {@code dataSetIterator.getLabels()} is used
     * @param i               top-N value passed to the {@link Evaluation} constructor
     * @return the accumulated evaluation
     * @throws IllegalStateException if the network has no layers or its first output is not an output layer
     */
    public Evaluation evaluate(DataSetIterator dataSetIterator, List<String> list, int i) {
        if (this.layers == null || !(getOutputLayer(0) instanceof IOutputLayer)) {
            throw new IllegalStateException("Cannot evaluate network with no output layer");
        }
        List<String> labelNames = (list != null) ? list : dataSetIterator.getLabels();
        Evaluation evaluation = new Evaluation(labelNames, i);
        while (dataSetIterator.hasNext()) {
            org.nd4j.linalg.dataset.DataSet dataSet = (org.nd4j.linalg.dataset.DataSet) dataSetIterator.next();
            if (dataSet.getFeatureMatrix() == null || dataSet.getLabels() == null) {
                break;
            }
            INDArray features = dataSet.getFeatures();
            INDArray labels = dataSet.getLabels();
            INDArray featuresMask = dataSet.getFeaturesMaskArray();
            INDArray labelsMask = dataSet.getLabelsMaskArray();

            INDArray[] output;
            if (featuresMask != null) {
                // Variable-length input: the forward pass must see the feature mask. Always clear it again
                // so it cannot leak into the next (possibly unmasked) minibatch.
                setLayerMaskArrays(new INDArray[]{featuresMask}, null);
                try {
                    output = output(false, features);
                } finally {
                    clearLayerMaskArrays();
                }
            } else {
                output = output(false, features);
            }

            if (labels.rank() == 3) {
                if (labelsMask != null) {
                    // Exclude masked (padded) time steps from the statistics
                    evaluation.evalTimeSeries(labels, output[0], labelsMask);
                } else {
                    evaluation.evalTimeSeries(labels, output[0]);
                }
            } else {
                evaluation.eval(labels, output[0]);
            }
        }
        return evaluation;
    }

    /**
     * Builds a human-readable, fixed-width table describing every vertex of the graph in topological order:
     * name and type, nIn/nOut, parameter count, parameter shapes and vertex inputs. The table is followed by
     * total, trainable and frozen parameter counts. A parameter count is counted as frozen when its layer is a
     * {@link FrozenLayer}.
     *
     * <p>nIn/nOut are shown only for layers whose configuration is a {@link FeedForwardLayer}; other layers
     * show "-" rather than failing.
     *
     * @return the formatted summary
     */
    public String summary() {
        String doubleRule = StringUtils.repeat("=", 140);
        StringBuilder sb = new StringBuilder();
        sb.append("\n").append(doubleRule).append("\n");
        sb.append(String.format("%-40s%-15s%-15s%-30s %s\n",
                "VertexName (VertexType)", "nIn,nOut", "TotalParams", "ParamsShape", "Vertex Inputs"));
        sb.append(doubleRule).append("\n");

        int frozenParams = 0;
        for (int vertexIdx : this.topologicalOrder) {
            GraphVertex vertex = this.vertices[vertexIdx];
            String vertexName = vertex.getVertexName();
            String typeName = classNameTail(vertex.getClass());
            String inputs = vertex.isInputVertex()
                    ? "-" : this.configuration.getVertexInputs().get(vertexName).toString();
            String paramCount = "-";
            String nIn = "-";
            String nOut = "-";
            String paramShapes = "-";

            if (vertex.hasLayer()) {
                Layer layer = ((LayerVertex) vertex).getLayer();
                typeName = classNameTail(layer.getClass());
                int numParams = layer.numParams();
                paramCount = String.valueOf(numParams);
                if (numParams > 0) {
                    // Only feed-forward style configurations expose nIn/nOut; avoid a ClassCastException otherwise
                    if (layer.conf().getLayer() instanceof FeedForwardLayer) {
                        FeedForwardLayer ffConf = (FeedForwardLayer) layer.conf().getLayer();
                        nIn = String.valueOf(ffConf.getNIn());
                        nOut = String.valueOf(ffConf.getNOut());
                    }
                    StringBuilder shapes = new StringBuilder();
                    for (String paramName : layer.conf().getLearningRateByParam().keySet()) {
                        INDArray param = layer.paramTable().get(paramName);
                        if (param == null) {
                            continue;
                        }
                        if (shapes.length() > 0) {
                            shapes.append(", ");
                        }
                        shapes.append(paramName).append(":").append(ArrayUtils.toString(param.shape()));
                    }
                    if (shapes.length() > 0) {
                        paramShapes = shapes.toString();
                    }
                }
                if (layer instanceof FrozenLayer) {
                    frozenParams += numParams;
                    typeName = "Frozen " + classNameTail(((FrozenLayer) layer).getInsideLayer().getClass());
                }
            }
            sb.append(String.format("%-40s%-15s%-15s%-30s %s",
                    vertexName + " (" + typeName + ")", nIn + "," + nOut, paramCount, paramShapes, inputs));
            sb.append("\n");
        }

        int totalParams = params().length();
        sb.append(StringUtils.repeat("-", 140));
        sb.append(String.format("\n%30s %d", "Total Parameters: ", totalParams));
        sb.append(String.format("\n%30s %d", "Trainable Parameters: ", totalParams - frozenParams));
        sb.append(String.format("\n%30s %d", "Frozen Parameters: ", frozenParams));
        sb.append("\n").append(doubleRule).append("\n");
        return sb.toString();
    }

    /**
     * Returns the part of a class's binary name after the last '.', e.g. {@code Outer$Inner} for a nested class.
     */
    private static String classNameTail(Class<?> type) {
        String name = type.getName();
        return name.substring(name.lastIndexOf('.') + 1);
    }

    /**
     * Sets the flag recording whether network initialisation has been done.
     *
     * @param z new value for the init-done flag
     */
    public void setInitDone(boolean z) {
        initDone = z;
    }
}
