package org.deeplearning4j.eval;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.deeplearning4j.eval.curves.PrecisionRecallCurve;
import org.deeplearning4j.eval.curves.RocCurve;
import org.deeplearning4j.eval.serde.ROCArraySerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;

/* loaded from: input_file:org/deeplearning4j/eval/ROCBinary.class */
/**
 * ROC evaluation for several independent binary outputs: one {@link ROC} instance is kept per
 * column of the labels array, and every query method takes the index of the output of interest.
 *
 * <p>State is created lazily on the first call to {@code eval}; until then {@link #numLabels()}
 * returns {@code -1} and the per-output accessors throw {@link UnsupportedOperationException}.
 */
public class ROCBinary extends BaseEvaluation<ROCBinary> {
    /** Number of decimal places used by {@link #stats()}. */
    public static final int DEFAULT_STATS_PRECISION = 4;

    /** One ROC accumulator per output column; {@code null} until the first eval call (or after reset). */
    @JsonSerialize(using = ROCArraySerializer.class)
    private ROC[] underlying;
    /** Passed to every {@link ROC} created by this instance; values above 0 add a "thresholded" note to stats. */
    private int thresholdSteps;
    /** Passed to every {@link ROC} created by this instance. */
    private boolean rocRemoveRedundantPts;
    /** Optional display names for the outputs, used only by {@link #stats(int)}. */
    private List<String> labels;

    /** Creates an instance with {@code thresholdSteps = 0} and redundant-point removal enabled. */
    public ROCBinary() {
        this(0);
    }

    /**
     * Creates an instance with redundant-point removal enabled.
     *
     * @param thresholdSteps number of threshold steps handed to each underlying {@link ROC}
     */
    public ROCBinary(int thresholdSteps) {
        this(thresholdSteps, true);
    }

    /**
     * @param thresholdSteps        number of threshold steps handed to each underlying {@link ROC}
     * @param rocRemoveRedundantPts flag handed to each underlying {@link ROC}
     */
    public ROCBinary(int thresholdSteps, boolean rocRemoveRedundantPts) {
        this.thresholdSteps = thresholdSteps;
        this.rocRemoveRedundantPts = rocRemoveRedundantPts;
    }

    /** Discards all accumulated per-output state; configuration and label names are kept. */
    @Override // org.deeplearning4j.eval.IEvaluation
    public void reset() {
        this.underlying = null;
    }

    /**
     * Evaluates without a mask; equivalent to {@code eval(labels, predictions, null)}.
     *
     * @param labels      labels array, one column per output
     * @param predictions array whose columns are paired with the label columns
     */
    @Override // org.deeplearning4j.eval.IEvaluation
    public void eval(INDArray labels, INDArray predictions) {
        eval(labels, predictions, (INDArray) null);
    }

    /**
     * Accumulates one batch. Rank-3 labels are delegated to {@code evalTimeSeries}; otherwise each
     * column {@code i} of {@code labels} and {@code predictions} is fed to the i-th underlying ROC.
     *
     * <p>When a mask is given, only the rows whose mask value differs from
     * {@link EvaluationBinary#DEFAULT_EDGE_VALUE} are evaluated. A column-vector (or scalar) mask is
     * shared by all outputs; any other mask is read column by column.
     *
     * @param labels      labels array; {@code size(1)} defines the number of outputs
     * @param predictions array whose columns are paired with the label columns
     * @param mask        optional mask, may be {@code null}
     * @throws IllegalStateException if earlier calls established a different number of outputs
     */
    @Override // org.deeplearning4j.eval.BaseEvaluation, org.deeplearning4j.eval.IEvaluation
    public void eval(INDArray labels, INDArray predictions, INDArray mask) {
        if (this.underlying != null && this.underlying.length != labels.size(1)) {
            throw new IllegalStateException("Labels array does not match stored state size. Expected labels array with size " + this.underlying.length + ", got labels array with size " + labels.size(1));
        }
        if (labels.rank() == 3) {
            evalTimeSeries(labels, predictions, mask);
            return;
        }

        int numOutputs = labels.size(1);
        if (this.underlying == null) {
            this.underlying = new ROC[numOutputs];
            for (int i = 0; i < numOutputs; i++) {
                this.underlying[i] = new ROC(this.thresholdSteps, this.rocRemoveRedundantPts);
            }
        }

        // A per-example mask is identical for every output, so its row indices are computed once.
        boolean sharedMask = mask != null && mask.isColumnVectorOrScalar();
        int[] sharedRows = null;

        for (int i = 0; i < numOutputs; i++) {
            INDArray predictionColumn = predictions.getColumn(i);
            INDArray labelColumn = labels.getColumn(i);
            if (mask != null) {
                int[] rows;
                if (sharedRows != null) {
                    rows = sharedRows;
                } else {
                    rows = unmaskedRowIndices(sharedMask ? mask : mask.getColumn(i));
                    if (sharedMask) {
                        sharedRows = rows;
                    }
                }
                predictionColumn = Nd4j.pullRows(predictionColumn, 1, rows);
                labelColumn = Nd4j.pullRows(labelColumn, 1, rows);
            }
            this.underlying[i].eval(labelColumn, predictionColumn);
        }
    }

    /**
     * Collects the row indices whose mask entry differs from {@link EvaluationBinary#DEFAULT_EDGE_VALUE}.
     * The result array is sized from the sum of the mask.
     * NOTE(review): that sizing presumably relies on the mask holding only 0/1 values - confirm.
     */
    private static int[] unmaskedRowIndices(INDArray maskColumn) {
        int[] rows = new int[maskColumn.sumNumber().intValue()];
        int numRows = maskColumn.size(0);
        int next = 0;
        for (int row = 0; row < numRows; row++) {
            if (maskColumn.getDouble(row) != EvaluationBinary.DEFAULT_EDGE_VALUE) {
                rows[next++] = row;
            }
        }
        return rows;
    }

    /**
     * Merges the state of {@code other} into this instance, output by output.
     *
     * <p>If this instance holds no state yet, it adopts {@code other}'s array by reference (no copy).
     * If {@code other} holds no state, nothing happens.
     *
     * @param other instance to merge in
     * @throws UnsupportedOperationException if both hold state for a different number of outputs
     */
    @Override // org.deeplearning4j.eval.IEvaluation
    public void merge(ROCBinary other) {
        if (this.underlying == null) {
            this.underlying = other.underlying;
            return;
        }
        if (other.underlying == null) {
            return;
        }
        if (this.underlying.length != other.underlying.length) {
            throw new UnsupportedOperationException("Cannot merge ROCBinary: this expects " + this.underlying.length + " outputs, other expects " + other.underlying.length + " outputs");
        }
        for (int i = 0; i < this.underlying.length; i++) {
            this.underlying[i].merge(other.underlying[i]);
        }
    }

    /**
     * Validates an output index before it is used to address {@link #underlying}.
     *
     * @throws UnsupportedOperationException if nothing has been evaluated yet
     * @throws IllegalArgumentException      if {@code outputNum} is outside {@code [0, numLabels() - 1]}
     */
    private void assertIndex(int outputNum) {
        if (this.underlying == null) {
            throw new UnsupportedOperationException("ROCBinary does not have any stats: eval must be called first");
        }
        if (outputNum < 0 || outputNum >= this.underlying.length) {
            // Report the real upper bound, not a value derived from the rejected argument.
            throw new IllegalArgumentException("Invalid input: output number must be between 0 and " + (this.underlying.length - 1));
        }
    }

    /** @return the number of outputs seen so far, or {@code -1} if nothing has been evaluated */
    public int numLabels() {
        if (this.underlying == null) {
            return -1;
        }
        return this.underlying.length;
    }

    /**
     * @param outputNum index of the output
     * @return the actual-positive count reported by that output's {@link ROC}
     */
    public long getCountActualPositive(int outputNum) {
        assertIndex(outputNum);
        return this.underlying[outputNum].getCountActualPositive();
    }

    /**
     * @param outputNum index of the output
     * @return the actual-negative count reported by that output's {@link ROC}
     */
    public long getCountActualNegative(int outputNum) {
        assertIndex(outputNum);
        return this.underlying[outputNum].getCountActualNegative();
    }

    /**
     * @param outputNum index of the output
     * @return the ROC curve of that output
     */
    public RocCurve getRocCurve(int outputNum) {
        assertIndex(outputNum);
        return this.underlying[outputNum].getRocCurve();
    }

    /**
     * @param outputNum index of the output
     * @return the precision-recall curve of that output
     */
    public PrecisionRecallCurve getPrecisionRecallCurve(int outputNum) {
        assertIndex(outputNum);
        return this.underlying[outputNum].getPrecisionRecallCurve();
    }

    /**
     * @return the unweighted mean of {@link #calculateAUC(int)} over all outputs. With no evaluated
     *         data {@link #numLabels()} is {@code -1}, so the result is {@code -0.0}.
     */
    public double calculateAverageAuc() {
        int count = numLabels();
        double total = 0.0d;
        for (int i = 0; i < count; i++) {
            total += calculateAUC(i);
        }
        return total / count;
    }

    /**
     * @return the unweighted mean of {@link #calculateAUCPR(int)} over all outputs. With no evaluated
     *         data {@link #numLabels()} is {@code -1}, so the result is {@code -0.0}.
     */
    public double calculateAverageAUCPR() {
        int count = numLabels();
        double total = 0.0d;
        for (int i = 0; i < count; i++) {
            total += calculateAUCPR(i);
        }
        return total / count;
    }

    /**
     * @param outputNum index of the output
     * @return the AUC computed by that output's {@link ROC}
     */
    public double calculateAUC(int outputNum) {
        assertIndex(outputNum);
        return this.underlying[outputNum].calculateAUC();
    }

    /**
     * @param outputNum index of the output
     * @return the area under the precision-recall curve computed by that output's {@link ROC}
     */
    public double calculateAUCPR(int outputNum) {
        assertIndex(outputNum);
        return this.underlying[outputNum].calculateAUCPR();
    }

    /**
     * Sets the display names used by {@link #stats(int)}. The list is copied, so later changes to
     * the argument are not seen; {@code null} clears the names.
     *
     * @param labels names for the outputs, in output order, or {@code null}
     */
    public void setLabelNames(List<String> labels) {
        if (labels == null) {
            this.labels = null;
        } else {
            this.labels = new ArrayList<>(labels);
        }
    }

    /** @return {@link #stats(int)} with {@link #DEFAULT_STATS_PRECISION} decimal places */
    @Override // org.deeplearning4j.eval.IEvaluation
    public String stats() {
        return stats(DEFAULT_STATS_PRECISION);
    }

    /**
     * Builds a text table with one row per output: label, AUC, positive count, negative count.
     * Outputs are named by index unless label names were set.
     *
     * @param printPrecision number of decimal places for the AUC column
     * @return the formatted table, or a header followed by "-- No Data --" if nothing was evaluated
     */
    public String stats(int printPrecision) {
        // The label column is at least 15 characters wide, or as wide as the longest label name.
        int labelWidth = 15;
        if (this.labels != null) {
            for (String name : this.labels) {
                labelWidth = Math.max(name.length(), labelWidth);
            }
        }
        String labelColumn = "%-" + (labelWidth + 5);
        String headerFormat = labelColumn + "s%-12s%-10s%-10s";
        String rowFormat = labelColumn + "s%-12." + printPrecision + "f%-10d%-10d";

        StringBuilder sb = new StringBuilder();
        sb.append(String.format(headerFormat, "Label", "AUC", "# Pos", "# Neg"));

        if (this.underlying == null) {
            sb.append("\n-- No Data --\n");
            return sb.toString();
        }

        for (int i = 0; i < this.underlying.length; i++) {
            String name = this.labels == null ? String.valueOf(i) : this.labels.get(i);
            sb.append("\n").append(String.format(rowFormat, name, calculateAUC(i), getCountActualPositive(i), getCountActualNegative(i)));
        }
        if (this.thresholdSteps > 0) {
            sb.append("\n");
            sb.append("[Note: Thresholded AUC/AUPRC calculation used with ").append(this.thresholdSteps).append(" steps); accuracy may reduced compared to exact mode]");
        }
        return sb.toString();
    }

    /**
     * Two instances are equal when the superclass considers them equal and their underlying ROC
     * arrays, threshold steps, redundant-point flag and label names all match.
     */
    @Override // org.deeplearning4j.eval.BaseEvaluation
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof ROCBinary)) {
            return false;
        }
        ROCBinary other = (ROCBinary) obj;
        if (!other.canEqual(this) || !super.equals(obj)) {
            return false;
        }
        if (!Arrays.deepEquals(getUnderlying(), other.getUnderlying())) {
            return false;
        }
        if (getThresholdSteps() != other.getThresholdSteps()) {
            return false;
        }
        if (isRocRemoveRedundantPts() != other.isRocRemoveRedundantPts()) {
            return false;
        }
        List<String> mine = getLabels();
        List<String> theirs = other.getLabels();
        return mine == null ? theirs == null : mine.equals(theirs);
    }

    /** Lets subclasses veto equality with a plain {@code ROCBinary}; see {@link #equals(Object)}. */
    @Override // org.deeplearning4j.eval.BaseEvaluation
    protected boolean canEqual(Object obj) {
        return obj instanceof ROCBinary;
    }

    /** Hash over the same fields that {@link #equals(Object)} compares. */
    @Override // org.deeplearning4j.eval.BaseEvaluation
    public int hashCode() {
        int result = super.hashCode();
        result = result * 59 + Arrays.deepHashCode(getUnderlying());
        result = result * 59 + getThresholdSteps();
        result = result * 59 + (isRocRemoveRedundantPts() ? 79 : 97);
        List<String> names = getLabels();
        result = result * 59 + (names == null ? 43 : names.hashCode());
        return result;
    }

    /** @return the internal per-output ROC array (not a copy), or {@code null} if nothing was evaluated */
    public ROC[] getUnderlying() {
        return this.underlying;
    }

    /** @return the threshold-steps setting given to newly created {@link ROC} instances */
    public int getThresholdSteps() {
        return this.thresholdSteps;
    }

    /** @return the redundant-point-removal flag given to newly created {@link ROC} instances */
    public boolean isRocRemoveRedundantPts() {
        return this.rocRemoveRedundantPts;
    }

    /** @return the internal label-name list (not a copy), or {@code null} if none was set */
    public List<String> getLabels() {
        return this.labels;
    }

    /** Replaces the per-output ROC array; the array is stored by reference. */
    public void setUnderlying(ROC[] underlying) {
        this.underlying = underlying;
    }

    /** Changes the threshold-steps setting; only ROC instances created afterwards receive it. */
    public void setThresholdSteps(int thresholdSteps) {
        this.thresholdSteps = thresholdSteps;
    }

    /** Changes the redundant-point-removal flag; only ROC instances created afterwards receive it. */
    public void setRocRemoveRedundantPts(boolean rocRemoveRedundantPts) {
        this.rocRemoveRedundantPts = rocRemoveRedundantPts;
    }

    /** Stores the given list by reference; use {@link #setLabelNames(List)} for a defensive copy. */
    public void setLabels(List<String> labels) {
        this.labels = labels;
    }

    /** @return a debug representation listing all four fields */
    @Override // org.deeplearning4j.eval.BaseEvaluation
    public String toString() {
        return "ROCBinary(underlying=" + Arrays.deepToString(getUnderlying()) + ", thresholdSteps=" + getThresholdSteps() + ", rocRemoveRedundantPts=" + isRocRemoveRedundantPts() + ", labels=" + getLabels() + ")";
    }
}
