package org.deeplearning4j.nn.graph.vertex.impl;

import java.util.Arrays;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/deeplearning4j/nn/graph/vertex/impl/UnstackVertex.class */
/**
 * Graph vertex that extracts one slice from an input that is stacked along dimension 0.
 *
 * <p>The input's dimension-0 size is split into {@code stackSize} equal slices of size
 * {@code step = size(0) / stackSize} (integer division), and this vertex outputs slice number
 * {@code from}, i.e. rows {@code [from * step, (from + 1) * step)} with all other dimensions kept
 * whole. Activations of rank 2, 3 and 4 are supported.
 *
 * <p>In the backward pass, the incoming epsilon is written into that same slice of an otherwise
 * all-zero array shaped like the forward-pass input.
 */
public class UnstackVertex extends BaseGraphVertex {
    /** Index of the slice (along dimension 0) that this vertex outputs. */
    private final int from;
    /** Number of equal slices dimension 0 of the input is divided into. */
    private final int stackSize;
    /** Shape of the input seen by the most recent forward pass; {@code null} until then. */
    private int[] forwardShape;
    /** Slice size along dimension 0 computed by the most recent forward pass. */
    private int step;

    /**
     * Creates the vertex without pre-set input/output vertex indices.
     *
     * @param graph       the owning computation graph
     * @param name        vertex name
     * @param vertexIndex index of this vertex in the graph
     * @param from        index of the slice to output
     * @param stackSize   number of slices the input is stacked from
     */
    public UnstackVertex(ComputationGraph graph, String name, int vertexIndex, int from, int stackSize) {
        this(graph, name, vertexIndex, null, null, from, stackSize);
    }

    /**
     * Creates the vertex.
     *
     * @param graph          the owning computation graph
     * @param name           vertex name
     * @param vertexIndex    index of this vertex in the graph
     * @param inputVertices  vertex indices passed through to {@link BaseGraphVertex}
     *                       (presumably the input connections — confirm against BaseGraphVertex)
     * @param outputVertices vertex indices passed through to {@link BaseGraphVertex}
     *                       (presumably the output connections — confirm against BaseGraphVertex)
     * @param from           index of the slice to output
     * @param stackSize      number of slices the input is stacked from
     */
    public UnstackVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices,
                    VertexIndices[] outputVertices, int from, int stackSize) {
        super(graph, name, vertexIndex, inputVertices, outputVertices);
        this.from = from;
        this.stackSize = stackSize;
    }

    @Override
    public boolean hasLayer() {
        return false;
    }

    @Override
    public boolean isOutputVertex() {
        return false;
    }

    @Override
    public Layer getLayer() {
        return null;
    }

    /**
     * Returns a copy of slice {@code from} of the first input along dimension 0, and records the
     * input shape and slice size for the backward pass.
     *
     * @param training interface flag; ignored by this vertex
     * @return a detached copy ({@code dup()}) of the selected slice
     * @throws IllegalStateException         if the input has not been set
     * @throws UnsupportedOperationException if the input rank is not 2, 3 or 4
     */
    @Override
    public INDArray doForward(boolean training) {
        if (!canDoForward()) {
            throw new IllegalStateException("Cannot do forward pass: input not set");
        }
        INDArray input = this.inputs[0];
        int rank = input.rank();
        // Recorded before the rank check, matching the original ordering of side effects.
        this.forwardShape = Arrays.copyOf(input.shape(), rank);
        this.step = input.size(0) / this.stackSize;
        if (!isSupportedRank(rank)) {
            throw new UnsupportedOperationException("Cannot get subset for activations of rank " + rank);
        }
        return input.get(sliceIndices(rank, this.step)).dup();
    }

    /**
     * Scatters the incoming epsilon back into the slice this vertex extracted, producing an
     * epsilon for the stacked input that is zero everywhere else.
     *
     * @param tbptt interface flag; ignored by this vertex
     * @return a pair of (no gradient, single-element array holding the input epsilon)
     * @throws IllegalStateException if epsilon is not set, or no forward pass has been performed
     * @throws RuntimeException      if the recorded activation rank is not 2, 3 or 4
     */
    @Override
    public Pair<Gradient, INDArray[]> doBackward(boolean tbptt) {
        if (!canDoBackward()) {
            throw new IllegalStateException("Cannot do backward pass: error not set");
        }
        if (this.forwardShape == null) {
            // Without a forward pass there is no input shape to build the epsilon from.
            throw new IllegalStateException("Cannot do backward pass: forward pass has not been performed");
        }
        int rank = this.forwardShape.length;
        if (!isSupportedRank(rank)) {
            throw new RuntimeException("Invalid activation rank");
        }
        INDArray inputEpsilon = Nd4j.zeros(this.forwardShape);
        inputEpsilon.put(sliceIndices(rank, this.step), this.epsilon);
        return new Pair<>(null, new INDArray[] {inputEpsilon});
    }

    @Override
    public void setBackpropGradientsViewArray(INDArray backpropGradientsViewArray) {
        if (backpropGradientsViewArray != null) {
            throw new RuntimeException("Vertex does not have gradients; gradients view array cannot be set here");
        }
    }

    /**
     * Selects the rows of the first mask array that correspond to slice {@code from}.
     *
     * @param maskArrays       incoming mask arrays; only the first is used
     * @param currentMaskState mask state, passed through unchanged
     * @param minibatchSize    unused by this vertex
     * @return the sliced mask (a view) and the unchanged mask state, or {@code null} mask if
     *         there is no first mask
     */
    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState,
                    int minibatchSize) {
        if (maskArrays == null || maskArrays.length == 0 || maskArrays[0] == null) {
            return new Pair<>(null, currentMaskState);
        }
        INDArray mask = maskArrays[0];
        // Derive the slice size from the mask itself rather than this.step: step is state from the
        // last doForward() call and is 0 before any forward pass (or stale if the minibatch changed).
        int maskStep = mask.size(0) / this.stackSize;
        return new Pair<>(mask.get(sliceIndices(2, maskStep)), currentMaskState);
    }

    @Override
    public String toString() {
        // Arrays.toString renders the shape contents (and "null" before the first forward pass)
        // instead of the array's identity hash.
        return "UnstackVertex(id=" + getVertexIndex() + ",name=\"" + getVertexName() + "\",fromIdx=" + this.from
                        + ",forwardShape=" + Arrays.toString(this.forwardShape) + ")";
    }

    /** Returns whether activations of the given rank can be sliced by this vertex. */
    private static boolean isSupportedRank(int rank) {
        return rank >= 2 && rank <= 4;
    }

    /**
     * Builds indices selecting rows {@code [from * sliceSize, (from + 1) * sliceSize)} along
     * dimension 0 and everything along the remaining {@code rank - 1} dimensions.
     */
    private INDArrayIndex[] sliceIndices(int rank, int sliceSize) {
        INDArrayIndex[] indices = new INDArrayIndex[rank];
        indices[0] = NDArrayIndex.interval(this.from * sliceSize, (this.from + 1) * sliceSize);
        for (int d = 1; d < rank; d++) {
            // A fresh all() per dimension, as the original did; NOTE(review): ND4J index objects may
            // carry per-use state, so a single shared instance is deliberately avoided.
            indices[d] = NDArrayIndex.all();
        }
        return indices;
    }
}
