package org.nd4j.linalg.api.ops;

import java.util.Arrays;
import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;

/* loaded from: input_file:org/nd4j/linalg/api/ops/TadCollapseAccumulation.class */
/**
 * Wraps another {@link Op} and evaluates it in two stages.
 *
 * <ol>
 *   <li>Optionally execute the wrapped op along {@link #smallerDimension}. This yields one partial
 *       result per tensor-along-dimension (TAD) of the smaller dimension set.</li>
 *   <li>Collapse those partials into an array shaped like the wrapped op's {@code x()} with
 *       {@link #originalDimension} removed, then install it as the wrapped op's {@code z}.</li>
 * </ol>
 *
 * <p>{@code x()}, {@code y()}, {@code z()}, {@code name()} and all {@code op(...)} /
 * {@code opForDimension(...)} overloads delegate to the wrapped op.
 */
public class TadCollapseAccumulation extends BaseOp {
    /** The op being collapsed; {@code exec()} handles {@link Accumulation} and {@link IndexAccumulation}. */
    protected Op accum;
    /** When true, {@code exec()} first runs the wrapped op along {@link #smallerDimension}. */
    protected boolean performSmallerDimension;
    /** Dimensions of the first-stage reduction; defaults to the last entry of {@link #originalDimension}. */
    protected int[] smallerDimension;
    /** Dimensions the caller ultimately wants to reduce along. */
    protected int[] originalDimension;
    /** TAD count of {@code accum.x()} along {@link #smallerDimension}, as computed at construction. */
    protected int tadsForSmallerDimension;
    /** TAD count of {@code accum.x()} along {@link #originalDimension}, as computed at construction. */
    protected int tadsForLargerDimension;
    /** Name reported by {@link #name()} when no op is wrapped. */
    public static final String DEFAULT_NAME = "collapseTad";

    public TadCollapseAccumulation() {
    }

    /**
     * Fully configured constructor; also precomputes the TAD counts of {@code accum.x()}.
     *
     * @param accum the op to collapse
     * @param originalDimension dimensions of the final reduction
     * @param smallerDimension dimensions of the first-stage reduction
     * @param performSmallerDimension whether {@code exec()} should run the first stage itself
     */
    public TadCollapseAccumulation(Op accum, int[] originalDimension, int[] smallerDimension, boolean performSmallerDimension) {
        this.accum = accum;
        this.performSmallerDimension = performSmallerDimension;
        this.originalDimension = originalDimension;
        this.smallerDimension = smallerDimension;
        this.tadsForSmallerDimension = accum.x().tensorssAlongDimension(smallerDimension);
        this.tadsForLargerDimension = accum.x().tensorssAlongDimension(originalDimension);
    }

    /** Same as the four-argument constructor with {@code performSmallerDimension = true}. */
    public TadCollapseAccumulation(Op accum, int[] originalDimension, int[] smallerDimension) {
        this(accum, originalDimension, smallerDimension, true);
    }

    /**
     * Wraps {@code accum} with only the final dimensions set.
     *
     * <p>{@code performSmallerDimension} stays {@code false} and the TAD counts are not computed.
     */
    public TadCollapseAccumulation(Op accum, int[] originalDimension) {
        this.accum = accum;
        this.originalDimension = originalDimension;
    }

    /** Wraps {@code accum}; the dimensions must be supplied later (setters or {@link #exec(int...)}). */
    public TadCollapseAccumulation(Op accum) {
        this.accum = accum;
    }

    /** Forwards {@code x} to the matching {@code BaseOp} constructor and wraps {@code accum}. */
    public TadCollapseAccumulation(INDArray x, Op accum) {
        super(x);
        this.accum = accum;
    }

    /** Forwards the arrays and {@code n} to the matching {@code BaseOp} constructor and wraps {@code accum}. */
    public TadCollapseAccumulation(INDArray x, INDArray y, INDArray z, long n, Op accum) {
        super(x, y, z, n);
        this.accum = accum;
    }

    /** Forwards both arrays to the matching {@code BaseOp} constructor and wraps {@code accum}. */
    public TadCollapseAccumulation(INDArray x, INDArray z, Op accum) {
        super(x, z);
        this.accum = accum;
    }

    /** Forwards both arrays and {@code n} to the matching {@code BaseOp} constructor and wraps {@code accum}. */
    public TadCollapseAccumulation(INDArray x, INDArray z, long n, Op accum) {
        super(x, z, n);
        this.accum = accum;
    }

    /** @return the wrapped op (may be {@code null}) */
    public Op getAccum() {
        return this.accum;
    }

    @Override
    public boolean isPassThrough() {
        return true;
    }

    /**
     * Runs the (optional) first-stage reduction, then collapses its partial results.
     *
     * <p>Steps:
     * <ol>
     *   <li>If {@link #smallerDimension} is unset, it defaults to the last entry of
     *       {@link #originalDimension}.</li>
     *   <li>If {@link #performSmallerDimension} is set, the wrapped op is executed along
     *       {@link #smallerDimension} through {@code Nd4j.getExecutioner()}. For an
     *       {@link Accumulation}, the final transform is disabled first so raw partials are
     *       produced.</li>
     *   <li>A new array shaped like {@code x()} with {@link #originalDimension} removed is filled
     *       from those partials.</li>
     *   <li>That array is installed as the wrapped op's {@code z}. If the wrapped op is neither
     *       kind of accumulation, the array is installed unfilled.</li>
     * </ol>
     *
     * @throws IllegalStateException if no op is wrapped, {@link #originalDimension} is unset, or
     *     {@link #smallerDimension} must be derived from an empty {@link #originalDimension}
     */
    @Override
    public void exec() {
        if (this.accum == null) {
            throw new IllegalStateException("No op to collapse: accum has not been set");
        }
        if (this.originalDimension == null) {
            throw new IllegalStateException("originalDimension must be set before exec()");
        }
        if (this.smallerDimension == null) {
            if (this.originalDimension.length == 0) {
                throw new IllegalStateException("Cannot derive smallerDimension from an empty originalDimension");
            }
            // Default first stage: reduce along only the last requested dimension.
            this.smallerDimension = new int[] {this.originalDimension[this.originalDimension.length - 1]};
        }

        if (this.performSmallerDimension) {
            if (this.accum instanceof Accumulation) {
                Accumulation accumulation = (Accumulation) this.accum;
                // Keep raw partials; the final transform is applied once, after collapsing.
                accumulation.setApplyFinalTransform(false);
                Nd4j.getExecutioner().exec(accumulation, this.smallerDimension);
            } else if (this.accum instanceof IndexAccumulation) {
                Nd4j.getExecutioner().exec((IndexAccumulation) this.accum, this.smallerDimension);
            }
        }

        INDArray input = this.accum.x();
        INDArray collapsed = Nd4j.create(ArrayUtil.removeIndex(input.shape(), this.originalDimension));
        int tadsForSmaller = input.tensorssAlongDimension(this.smallerDimension);
        int tadsForOriginal = input.tensorssAlongDimension(this.originalDimension);

        if (this.accum instanceof Accumulation) {
            // Number of input elements folded into each output element.
            int reducedLength = input.tensorAlongDimension(0, this.originalDimension).length();
            combineAccumulation((Accumulation) this.accum, collapsed, tadsForSmaller, tadsForOriginal, reducedLength);
        } else if (this.accum instanceof IndexAccumulation) {
            combineIndexAccumulation((IndexAccumulation) this.accum, input, collapsed, tadsForSmaller, tadsForOriginal);
        }
        this.accum.setZ(collapsed);
    }

    /**
     * Merges the per-TAD partials in {@code accumulation.z()} into {@code collapsed}.
     *
     * <p>After merging, the wrapped op's {@code n} is set to {@code reducedLength}, its final
     * transform is re-enabled, and {@code calculateFinalResult} is applied to every output element.
     */
    private static void combineAccumulation(Accumulation accumulation, INDArray collapsed,
            int tadsForSmaller, int tadsForOriginal, int reducedLength) {
        INDArray partials = accumulation.z();
        // The first partial for an output slot is stored as-is; later ones are merged into it.
        // Merging into the fresh array's initial value would inject a spurious 0 into the
        // reduction, which breaks e.g. max over negatives, min over positives, and products.
        boolean[] seeded = new boolean[collapsed.length()];
        for (int tad = 0; tad < tadsForSmaller; tad++) {
            int target = reductionIndexForTad(tad, tadsForOriginal, tadsForSmaller);
            double partial = partials.getDouble(tad);
            if (seeded[target]) {
                collapsed.putScalar(target, accumulation.combineSubResults(collapsed.getDouble(target), partial));
            } else {
                collapsed.putScalar(target, partial);
                seeded[target] = true;
            }
        }
        accumulation.setN(reducedLength);
        accumulation.setApplyFinalTransform(true);
        for (int i = 0; i < collapsed.length(); i++) {
            collapsed.putScalar(i, accumulation.calculateFinalResult(collapsed.getDouble(i), reducedLength));
        }
    }

    /**
     * Folds index-accumulation results into {@code collapsed} via {@code combineSubResults}.
     *
     * <p>NOTE(review): this reads the raw input value {@code input.getDouble(tad)} rather than the
     * first-stage result in {@code z()}, and combines with the freshly created array's values as
     * the starting point. Both look suspicious; confirm the intended semantics before relying on it.
     */
    private static void combineIndexAccumulation(IndexAccumulation indexAccumulation, INDArray input,
            INDArray collapsed, int tadsForSmaller, int tadsForOriginal) {
        for (int tad = 0; tad < tadsForSmaller; tad++) {
            int target = reductionIndexForTad(tad, tadsForOriginal, tadsForSmaller);
            collapsed.putScalar(target,
                    indexAccumulation.combineSubResults(input.getDouble(tad), tad, collapsed.getDouble(target), target));
        }
    }

    @Override
    public INDArray x() {
        return this.accum.x();
    }

    @Override
    public INDArray y() {
        return this.accum.y();
    }

    @Override
    public INDArray z() {
        return this.accum.z();
    }

    /**
     * Sets {@link #originalDimension} to {@code dimensions} and runs {@link #exec()}.
     *
     * <p>NOTE(review): a {@link #smallerDimension} derived during an earlier call is not reset here.
     */
    @Override
    public void exec(int... dimensions) {
        this.originalDimension = dimensions;
        exec();
    }

    @Override
    public int opNum() {
        return 0;
    }

    /** @return the wrapped op's name, or {@link #DEFAULT_NAME} if no op is wrapped */
    @Override
    public String name() {
        return this.accum == null ? DEFAULT_NAME : this.accum.name();
    }

    @Override
    public IComplexNumber op(IComplexNumber origin, double other) {
        return this.accum.op(origin, other);
    }

    @Override
    public IComplexNumber op(IComplexNumber origin, float other) {
        return this.accum.op(origin, other);
    }

    @Override
    public IComplexNumber op(IComplexNumber origin, IComplexNumber other) {
        return this.accum.op(origin, other);
    }

    @Override
    public float op(float origin, float other) {
        return this.accum.op(origin, other);
    }

    @Override
    public double op(double origin, double other) {
        return this.accum.op(origin, other);
    }

    @Override
    public double op(double origin) {
        return this.accum.op(origin);
    }

    @Override
    public float op(float origin) {
        return this.accum.op(origin);
    }

    @Override
    public IComplexNumber op(IComplexNumber origin) {
        return this.accum.op(origin);
    }

    @Override
    public Op opForDimension(int index, int dimension) {
        return this.accum.opForDimension(index, dimension);
    }

    @Override
    public Op opForDimension(int index, int... dimension) {
        return this.accum.opForDimension(index, dimension);
    }

    /**
     * Returns {@code index / (b * a)}.
     *
     * <p>Presumably maps a linear element index to a TAD index when each TAD spans {@code a * b}
     * elements; verify against callers.
     */
    public static int tadIndex(int index, int a, int b) {
        return index / (b * a);
    }

    /**
     * Maps the index of a smaller-dimension TAD to the output element it contributes to.
     *
     * <p>The mapping assumes each output element is fed by {@code tadsForSmaller / tadsForOriginal}
     * consecutive smaller-dimension TADs. TAD 0 always maps to 0.
     *
     * @param tad index of the TAD along the smaller dimension set
     * @param tadsForOriginal TAD count along the original dimension set (number of outputs)
     * @param tadsForSmaller TAD count along the smaller dimension set
     * @throws ArithmeticException if {@code tad != 0} and {@code tadsForSmaller / tadsForOriginal} is 0
     */
    public static int reductionIndexForTad(int tad, int tadsForOriginal, int tadsForSmaller) {
        if (tad == 0) {
            return 0;
        }
        return tad / tadsPerReduceIndex(tadsForOriginal, tadsForSmaller);
    }

    /**
     * Returns {@code tadsForSmaller / tadsForOriginal} using integer division.
     *
     * <p>This is the number of smaller-dimension TADs assumed to feed each output element.
     *
     * @throws ArithmeticException if {@code tadsForOriginal} is 0
     */
    public static int tadsPerReduceIndex(int tadsForOriginal, int tadsForSmaller) {
        return tadsForSmaller / tadsForOriginal;
    }

    /**
     * Composes {@link #tadIndex} and {@link #reductionIndexForTad}.
     *
     * <p>Equivalent to {@code reductionIndexForTad(tadIndex(index, a, b), tadsForOriginal, tadsForSmaller)}.
     */
    public static int reductionIndexForLinear(int index, int a, int b, int tadsForOriginal, int tadsForSmaller) {
        return reductionIndexForTad(tadIndex(index, a, b), tadsForOriginal, tadsForSmaller);
    }

    public boolean isPerformSmallerDimension() {
        return this.performSmallerDimension;
    }

    public int[] getSmallerDimension() {
        return this.smallerDimension;
    }

    public int[] getOriginalDimension() {
        return this.originalDimension;
    }

    public int getTadsForSmallerDimension() {
        return this.tadsForSmallerDimension;
    }

    public int getTadsForLargerDimension() {
        return this.tadsForLargerDimension;
    }

    public void setAccum(Op accum) {
        this.accum = accum;
    }

    public void setPerformSmallerDimension(boolean performSmallerDimension) {
        this.performSmallerDimension = performSmallerDimension;
    }

    public void setSmallerDimension(int[] smallerDimension) {
        this.smallerDimension = smallerDimension;
    }

    public void setOriginalDimension(int[] originalDimension) {
        this.originalDimension = originalDimension;
    }

    public void setTadsForSmallerDimension(int tadsForSmallerDimension) {
        this.tadsForSmallerDimension = tadsForSmallerDimension;
    }

    public void setTadsForLargerDimension(int tadsForLargerDimension) {
        this.tadsForLargerDimension = tadsForLargerDimension;
    }

    /** Value equality over the wrapped op, the flag, both dimension arrays and both TAD counts. */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof TadCollapseAccumulation)) {
            return false;
        }
        TadCollapseAccumulation other = (TadCollapseAccumulation) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        Op mine = getAccum();
        Op theirs = other.getAccum();
        boolean sameAccum = mine == null ? theirs == null : mine.equals(theirs);
        return sameAccum
                && isPerformSmallerDimension() == other.isPerformSmallerDimension()
                && Arrays.equals(getSmallerDimension(), other.getSmallerDimension())
                && Arrays.equals(getOriginalDimension(), other.getOriginalDimension())
                && getTadsForSmallerDimension() == other.getTadsForSmallerDimension()
                && getTadsForLargerDimension() == other.getTadsForLargerDimension();
    }

    /** Symmetry hook for {@link #equals(Object)}; subclasses may narrow it. */
    protected boolean canEqual(Object obj) {
        return obj instanceof TadCollapseAccumulation;
    }

    /** Hash over the same state as {@link #equals(Object)} (multiplier 59, Lombok-style constants). */
    @Override
    public int hashCode() {
        final int prime = 59;
        Op wrapped = getAccum();
        int result = 1;
        result = result * prime + (wrapped == null ? 43 : wrapped.hashCode());
        result = result * prime + (isPerformSmallerDimension() ? 79 : 97);
        result = result * prime + Arrays.hashCode(getSmallerDimension());
        result = result * prime + Arrays.hashCode(getOriginalDimension());
        result = result * prime + getTadsForSmallerDimension();
        result = result * prime + getTadsForLargerDimension();
        return result;
    }

    @Override
    public String toString() {
        return "TadCollapseAccumulation(accum=" + getAccum() + ", performSmallerDimension=" + isPerformSmallerDimension() + ", smallerDimension=" + Arrays.toString(getSmallerDimension()) + ", originalDimension=" + Arrays.toString(getOriginalDimension()) + ", tadsForSmallerDimension=" + getTadsForSmallerDimension() + ", tadsForLargerDimension=" + getTadsForLargerDimension() + ")";
    }
}
