package org.deeplearning4j.eval;

import java.io.Serializable;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.eval.meta.Prediction;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.accum.MatchCondition;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/eval/Evaluation.class */
/**
 * Evaluates classification performance by accumulating a confusion matrix together with
 * per-class true/false positive/negative counts, from which accuracy, precision, recall, F1,
 * false positive/negative rates and top-N accuracy are derived.
 *
 * <p>Classes are identified by the integer indices {@code 0..numClasses-1}. Confusion-matrix
 * entries are keyed as (actual class, predicted class). {@link #labelsList} optionally maps
 * class indices to human-readable names for reporting.
 */
public class Evaluation extends BaseEvaluation<Evaluation> {
    private static final Logger log = LoggerFactory.getLogger(Evaluation.class);

    /** Per-class metric value returned when the metric is undefined (both of its counts are zero). */
    protected static final double DEFAULT_EDGE_VALUE = 0.0d;

    /** An example is top-N correct when fewer than {@code topN} classes score above its true class. */
    protected final int topN;
    protected int topNCorrectCount = 0;
    protected int topNTotalCount = 0;
    protected Counter<Integer> truePositives = new Counter<>();
    protected Counter<Integer> falsePositives = new Counter<>();
    protected Counter<Integer> trueNegatives = new Counter<>();
    protected Counter<Integer> falseNegatives = new Counter<>();
    /** Created lazily on the first {@code eval(INDArray...)} call when no labels were supplied. */
    protected ConfusionMatrix<Integer> confusion;
    /** Total number of examples evaluated so far. */
    protected int numRowCounter = 0;
    protected List<String> labelsList = new ArrayList<>();
    /** Per-example metadata keyed by (actual, predicted) class pair; null until first recorded. */
    protected Map<Pair<Integer, Integer>, List<Object>> confusionMatrixMetaData;

    /** Creates an evaluation whose number of classes is inferred from the first batch evaluated. */
    public Evaluation() {
        this.topN = 1;
    }

    /**
     * Creates an evaluation for a fixed number of classes, labelled "0".."n-1".
     *
     * @param numClasses number of classes; 1 is treated as 2 (binary, single output column)
     */
    public Evaluation(int numClasses) {
        this(createLabels(numClasses), 1);
    }

    /**
     * @param labels class names, indexed by class number
     */
    public Evaluation(List<String> labels) {
        this(labels, 1);
    }

    /**
     * @param labels class names keyed by class index; keys must be exactly 0..size-1
     * @throws IllegalArgumentException if a key in 0..size-1 is missing
     */
    public Evaluation(Map<Integer, String> labels) {
        this(createLabelsFromMap(labels), 1);
    }

    /**
     * @param labels class names indexed by class number; if null, the number of classes is
     *               inferred from the first batch evaluated
     * @param topN   N for top-N accuracy (values &lt;= 1 disable top-N tracking)
     */
    public Evaluation(List<String> labels, int topN) {
        this.labelsList = labels;
        if (labels != null) {
            createConfusion(labels.size());
        }
        this.topN = topN;
    }

    /** Builds default labels "0".."n-1"; a single class is widened to two (binary case). */
    private static List<String> createLabels(int numClasses) {
        if (numClasses == 1) {
            numClasses = 2;
        }
        List<String> labels = new ArrayList<>(numClasses);
        for (int i = 0; i < numClasses; i++) {
            labels.add(String.valueOf(i));
        }
        return labels;
    }

    /** Converts an index-to-name map into a list, requiring every index 0..size-1 to be present. */
    private static List<String> createLabelsFromMap(Map<Integer, String> labels) {
        int size = labels.size();
        List<String> result = new ArrayList<>(size);
        for (int i = 0; i < size; i++) {
            String label = labels.get(i);
            if (label == null) {
                throw new IllegalArgumentException("Invalid labels map: missing key for class " + i + " (expect integers 0 to " + (size - 1) + ")");
            }
            result.add(label);
        }
        return result;
    }

    /** Initialises an empty confusion matrix over the classes 0..numClasses-1. */
    private void createConfusion(int numClasses) {
        List<Integer> classes = new ArrayList<>(numClasses);
        for (int i = 0; i < numClasses; i++) {
            classes.add(i);
        }
        this.confusion = new ConfusionMatrix<>(classes);
    }

    /**
     * Runs the graph on {@code input} (inference mode) and evaluates its first output against
     * {@code realOutcomes}.
     */
    public void eval(INDArray realOutcomes, INDArray input, ComputationGraph graph) {
        eval(realOutcomes, graph.output(false, input)[0]);
    }

    /** Runs the network on {@code input} in TEST mode and evaluates against {@code realOutcomes}. */
    public void eval(INDArray realOutcomes, INDArray input, MultiLayerNetwork network) {
        eval(realOutcomes, network.output(input, Layer.TrainingMode.TEST));
    }

    /**
     * Evaluates a batch of predictions without per-example metadata.
     *
     * @param realOutcomes labels (one row per example)
     * @param guesses      network outputs, same shape as {@code realOutcomes}
     */
    @Override
    public void eval(INDArray realOutcomes, INDArray guesses) {
        eval(realOutcomes, guesses, (List<? extends Serializable>) null);
    }

    /**
     * Evaluates a batch of predictions.
     *
     * <p>A single label column is treated as binary classification: the prediction is class 1 when
     * the output exceeds 0.5, and labels are expected to be 0 or 1. Otherwise the class is the
     * row-wise argmax of labels and predictions, and labels are expected to be one-hot.
     *
     * @param realOutcomes   labels (one row per example)
     * @param guesses        network outputs, same length as {@code realOutcomes}
     * @param recordMetaData optional per-example metadata, aligned with the rows; may be null
     * @throws IllegalArgumentException if the two arrays differ in length
     */
    @Override
    public void eval(INDArray realOutcomes, INDArray guesses, List<? extends Serializable> recordMetaData) {
        // Validate before touching any state so a rejected batch leaves this evaluation unchanged.
        if (realOutcomes.length() != guesses.length()) {
            throw new IllegalArgumentException("Unable to evaluate. Outcome matrices not same length");
        }
        if (this.confusion == null) {
            int nClasses = realOutcomes.columns();
            if (nClasses == 1) {
                nClasses = 2; // single output column => binary classification
            }
            this.labelsList = new ArrayList<>(nClasses);
            for (int i = 0; i < nClasses; i++) {
                this.labelsList.add(String.valueOf(i));
            }
            createConfusion(nClasses);
        }
        this.numRowCounter += realOutcomes.shape()[0];

        int nColumns = realOutcomes.columns();
        int nRows = realOutcomes.rows();
        if (nColumns == 1) {
            evalBinary(realOutcomes, guesses, nRows, recordMetaData);
        } else {
            evalMultiClass(realOutcomes, guesses, nColumns, nRows, recordMetaData);
        }
    }

    /** Accumulates counts for the single-output-column (binary) case. */
    private void evalBinary(INDArray realOutcomes, INDArray guesses, int nRows, List<? extends Serializable> recordMetaData) {
        INDArray predictedPositive = guesses.gt(0.5);
        INDArray predictedNegative = predictedPositive.mul(-1.0).addi(1.0);
        INDArray actualNegative = realOutcomes.mul(-1.0).addi(1.0);

        int tp = predictedPositive.mul(realOutcomes).sumNumber().intValue();   // predicted 1, actual 1
        int fn = predictedNegative.mul(realOutcomes).sumNumber().intValue();   // predicted 0, actual 1
        int fp = predictedPositive.mul(actualNegative).sumNumber().intValue(); // predicted 1, actual 0
        int tn = nRows - tp - fn - fp;                                          // predicted 0, actual 0

        // Confusion-matrix entries are (actual, predicted).
        this.confusion.add(1, 1, tp);
        this.confusion.add(1, 0, fn);
        this.confusion.add(0, 1, fp);
        this.confusion.add(0, 0, tn);

        // Counts from the perspective of class 1.
        this.truePositives.incrementCount(1, tp);
        this.falsePositives.incrementCount(1, fp);
        this.falseNegatives.incrementCount(1, fn);
        this.trueNegatives.incrementCount(1, tn);

        // Class 0 is the mirror image: its positives are class 1's negatives.
        this.truePositives.incrementCount(0, tn);
        this.falsePositives.incrementCount(0, fn);
        this.falseNegatives.incrementCount(0, fp);
        this.trueNegatives.incrementCount(0, tp);

        if (recordMetaData != null) {
            for (int i = 0; i < predictedPositive.size(0) && i < recordMetaData.size(); i++) {
                int actual = realOutcomes.getDouble(i) == 0.0 ? 0 : 1;
                int predicted = predictedPositive.getDouble(i) == 0.0 ? 0 : 1;
                addToMetaConfusionMatrix(actual, predicted, recordMetaData.get(i));
            }
        }
    }

    /** Accumulates counts for the multi-column case using row-wise argmax. */
    private void evalMultiClass(INDArray realOutcomes, INDArray guesses, int nColumns, int nRows,
                                List<? extends Serializable> recordMetaData) {
        INDArray predictedIdx = Nd4j.argMax(guesses, new int[] {1});
        INDArray actualIdx = Nd4j.argMax(realOutcomes, new int[] {1});

        int nExamples = predictedIdx.length();
        for (int i = 0; i < nExamples; i++) {
            int actual = (int) actualIdx.getDouble(i);
            int predicted = (int) predictedIdx.getDouble(i);
            this.confusion.add(actual, predicted);
            if (recordMetaData != null && recordMetaData.size() > i) {
                addToMetaConfusionMatrix(actual, predicted, recordMetaData.get(i));
            }
        }

        for (int c = 0; c < nColumns; c++) {
            INDArray isPredicted = predictedIdx.eps(c);      // 1 where class c was predicted
            INDArray isActual = realOutcomes.getColumn(c);   // 1 where class c is the label
            int tp = isPredicted.mul(isActual).sumNumber().intValue();
            int fp = isPredicted.mul(isActual.mul(-1.0).addi(1.0)).sumNumber().intValue();
            int fn = isPredicted.mul(-1.0).addi(1.0).muli(isActual).sumNumber().intValue();
            this.truePositives.incrementCount(c, tp);
            this.falsePositives.incrementCount(c, fp);
            this.falseNegatives.incrementCount(c, fn);
            this.trueNegatives.incrementCount(c, nRows - tp - fp - fn);
        }

        if (this.topN > 1) {
            evalTopN(guesses, actualIdx);
        }
    }

    /** Updates top-N counters: correct when fewer than topN outputs exceed the true class's output. */
    private void evalTopN(INDArray guesses, INDArray actualIdx) {
        int nExamples = actualIdx.length();
        for (int i = 0; i < nExamples; i++) {
            double scoreOfActual = guesses.getDouble(i, (int) actualIdx.getDouble(i));
            int nHigher = (int) Nd4j.getExecutioner()
                    .exec(new MatchCondition(guesses.getRow(i), Conditions.greaterThan(scoreOfActual)),
                            new int[] {Integer.MAX_VALUE})
                    .getDouble(0);
            if (nHigher < this.topN) {
                this.topNCorrectCount++;
            }
            this.topNTotalCount++;
        }
    }

    /**
     * Evaluates a single example.
     *
     * @param predictedIdx predicted class index
     * @param actualIdx    actual class index
     * @throws UnsupportedOperationException if the confusion matrix has not been initialised
     */
    public void eval(int predictedIdx, int actualIdx) {
        if (this.confusion == null) {
            throw new UnsupportedOperationException("Cannot evaluate single example without initializing confusion matrix first");
        }
        this.numRowCounter++;
        addToConfusion(actualIdx, predictedIdx);
        if (predictedIdx == actualIdx) {
            incrementTruePositives(predictedIdx);
            for (Integer c : this.confusion.getClasses()) {
                if (c.intValue() != predictedIdx) {
                    this.trueNegatives.incrementCount(c, 1.0d);
                }
            }
            return;
        }
        incrementFalseNegatives(actualIdx);
        incrementFalsePositives(predictedIdx);
        for (Integer c : this.confusion.getClasses()) {
            if (c.intValue() != predictedIdx && c.intValue() != actualIdx) {
                this.trueNegatives.incrementCount(c, 1.0d);
            }
        }
    }

    /** Returns a human-readable summary including warnings for classes excluded from averages. */
    public String stats() {
        return stats(false);
    }

    /**
     * Returns a human-readable summary: non-zero confusion counts, warnings, and overall scores.
     *
     * @param suppressWarnings if true, omit warnings about classes excluded from the averages
     */
    public String stats(boolean suppressWarnings) {
        if (this.confusion == null) {
            return "Evaluation: No data available (no evaluation has been performed)";
        }
        StringBuilder out = new StringBuilder().append("\n");
        StringBuilder warnings = new StringBuilder();
        List<Integer> classes = this.confusion.getClasses();
        for (Integer actual : classes) {
            String actualLabel = resolveLabelForClass(actual);
            for (Integer predicted : classes) {
                int count = this.confusion.getCount(actual, predicted);
                if (count != 0) {
                    out.append(String.format("Examples labeled as %s classified by model as %s: %d times%n", actualLabel, resolveLabelForClass(predicted), count));
                }
            }
            if (!suppressWarnings && this.truePositives.getCount(actual) == 0.0) {
                if (this.falsePositives.getCount(actual) == 0.0) {
                    warnings.append(String.format("Warning: class %s was never predicted by the model. This class was excluded from the average precision%n", actualLabel));
                }
                if (this.falseNegatives.getCount(actual) == 0.0) {
                    warnings.append(String.format("Warning: class %s has never appeared as a true label. This class was excluded from the average recall%n", actualLabel));
                }
            }
        }
        out.append("\n");
        out.append(warnings);

        DecimalFormat df = new DecimalFormat("#.####");
        out.append("\n==========================Scores========================================");
        out.append("\n Accuracy:        ").append(format(df, accuracy()));
        if (this.topN > 1) {
            out.append("\n Top ").append(this.topN).append(" Accuracy:  ").append(format(df, topNAccuracy()));
        }
        out.append("\n Precision:       ").append(format(df, precision()));
        out.append("\n Recall:          ").append(format(df, recall()));
        out.append("\n F1 Score:        ").append(format(df, f1()));
        out.append("\n========================================================================");
        return out.toString();
    }

    /** Formats a finite value with {@code df}; NaN and infinities are printed verbatim. */
    private static String format(DecimalFormat df, double value) {
        return (Double.isNaN(value) || Double.isInfinite(value)) ? String.valueOf(value) : df.format(value);
    }

    /** Returns the label for a class index, falling back to the index itself when unnamed. */
    private String resolveLabelForClass(Integer classIdx) {
        if (this.labelsList == null || this.labelsList.size() <= classIdx.intValue()) {
            return classIdx.toString();
        }
        return this.labelsList.get(classIdx.intValue());
    }

    /** Precision TP/(TP+FP) for one class; {@link #DEFAULT_EDGE_VALUE} when undefined. */
    public double precision(Integer classLabel) {
        return precision(classLabel, DEFAULT_EDGE_VALUE);
    }

    /** Precision TP/(TP+FP) for one class; {@code edgeCase} when TP and FP are both zero. */
    public double precision(Integer classLabel, double edgeCase) {
        double tp = this.truePositives.getCount(classLabel);
        double fp = this.falsePositives.getCount(classLabel);
        return (tp == 0.0 && fp == 0.0) ? edgeCase : tp / (tp + fp);
    }

    /** Macro-averaged precision over classes where it is defined; NaN if defined for none. */
    public double precision() {
        double sum = 0.0;
        int count = 0;
        for (Integer c : this.confusion.getClasses()) {
            double p = precision(c, -1.0);
            if (p != -1.0) {
                sum += p;
                count++;
            }
        }
        return sum / count;
    }

    /** Recall TP/(TP+FN) for one class; {@link #DEFAULT_EDGE_VALUE} when undefined. */
    public double recall(Integer classLabel) {
        return recall(classLabel, DEFAULT_EDGE_VALUE);
    }

    /** Recall TP/(TP+FN) for one class; {@code edgeCase} when TP and FN are both zero. */
    public double recall(Integer classLabel, double edgeCase) {
        double tp = this.truePositives.getCount(classLabel);
        double fn = this.falseNegatives.getCount(classLabel);
        return (tp == 0.0 && fn == 0.0) ? edgeCase : tp / (tp + fn);
    }

    /** Macro-averaged recall over classes where it is defined; NaN if defined for none. */
    public double recall() {
        double sum = 0.0;
        int count = 0;
        for (Integer c : this.confusion.getClasses()) {
            double r = recall(c, -1.0);
            if (r != -1.0) {
                sum += r;
                count++;
            }
        }
        return sum / count;
    }

    /** False positive rate FP/(FP+TN) for one class; {@link #DEFAULT_EDGE_VALUE} when undefined. */
    public double falsePositiveRate(Integer classLabel) {
        return falsePositiveRate(classLabel, DEFAULT_EDGE_VALUE);
    }

    /** False positive rate FP/(FP+TN) for one class; {@code edgeCase} when FP and TN are both zero. */
    public double falsePositiveRate(Integer classLabel, double edgeCase) {
        double fp = this.falsePositives.getCount(classLabel);
        double tn = this.trueNegatives.getCount(classLabel);
        return (fp == 0.0 && tn == 0.0) ? edgeCase : fp / (fp + tn);
    }

    /** Macro-averaged false positive rate over classes where it is defined. */
    public double falsePositiveRate() {
        double sum = 0.0;
        int count = 0;
        for (Integer c : this.confusion.getClasses()) {
            double rate = falsePositiveRate(c, -1.0);
            if (rate != -1.0) {
                sum += rate;
                count++;
            }
        }
        return sum / count;
    }

    /** False negative rate FN/(FN+TP) for one class; {@link #DEFAULT_EDGE_VALUE} when undefined. */
    public double falseNegativeRate(Integer classLabel) {
        return falseNegativeRate(classLabel, DEFAULT_EDGE_VALUE);
    }

    /** False negative rate FN/(FN+TP) for one class; {@code edgeCase} when FN and TP are both zero. */
    public double falseNegativeRate(Integer classLabel, double edgeCase) {
        double fn = this.falseNegatives.getCount(classLabel);
        double tp = this.truePositives.getCount(classLabel);
        return (fn == 0.0 && tp == 0.0) ? edgeCase : fn / (fn + tp);
    }

    /** Macro-averaged false negative rate over classes where it is defined. */
    public double falseNegativeRate() {
        double sum = 0.0;
        int count = 0;
        for (Integer c : this.confusion.getClasses()) {
            double rate = falseNegativeRate(c, -1.0);
            if (rate != -1.0) {
                sum += rate;
                count++;
            }
        }
        return sum / count;
    }

    /** Mean of the macro-averaged false positive and false negative rates. */
    public double falseAlarmRate() {
        return (falsePositiveRate() + falseNegativeRate()) / 2.0d;
    }

    /** F1 score for one class; 0 when its precision or recall is 0 (or undefined). */
    public double f1(Integer classLabel) {
        return harmonicMean(precision(classLabel), recall(classLabel));
    }

    /** F1 score computed from the macro-averaged precision and recall. */
    public double f1() {
        return harmonicMean(precision(), recall());
    }

    /** 2PR/(P+R), returning {@link #DEFAULT_EDGE_VALUE} when either input is exactly zero. */
    private static double harmonicMean(double precision, double recall) {
        if (precision == 0.0 || recall == 0.0) {
            return DEFAULT_EDGE_VALUE;
        }
        return 2.0d * ((precision * recall) / (precision + recall));
    }

    /** Fraction of evaluated examples on the confusion-matrix diagonal; NaN before any evaluation. */
    public double accuracy() {
        // Floating-point division: integer division would truncate every result to 0 or 1.
        return countCorrect() / (double) getNumRowCounter();
    }

    /** Sum of the confusion-matrix diagonal, i.e. the number of correctly classified examples. */
    private int countCorrect() {
        int nClasses = this.confusion.getClasses().size();
        int correct = 0;
        for (int c = 0; c < nClasses; c++) {
            correct += this.confusion.getCount(c, c);
        }
        return correct;
    }

    /** Top-N accuracy; equals {@link #accuracy()} when topN &lt;= 1, and 0 before any top-N data. */
    public double topNAccuracy() {
        if (this.topN <= 1) {
            return accuracy();
        }
        if (this.topNTotalCount == 0) {
            return DEFAULT_EDGE_VALUE;
        }
        return this.topNCorrectCount / (double) this.topNTotalCount;
    }

    /** True-positive counts keyed by class index. */
    public Map<Integer, Integer> truePositives() {
        return convertToMap(this.truePositives, this.confusion.getClasses().size());
    }

    /** True-negative counts keyed by class index. */
    public Map<Integer, Integer> trueNegatives() {
        return convertToMap(this.trueNegatives, this.confusion.getClasses().size());
    }

    /** False-positive counts keyed by class index. */
    public Map<Integer, Integer> falsePositives() {
        return convertToMap(this.falsePositives, this.confusion.getClasses().size());
    }

    /** False-negative counts keyed by class index. */
    public Map<Integer, Integer> falseNegatives() {
        return convertToMap(this.falseNegatives, this.confusion.getClasses().size());
    }

    /** TN + FP per class. */
    public Map<Integer, Integer> negative() {
        return addMapsByKey(trueNegatives(), falsePositives());
    }

    /** TP + FN per class. */
    public Map<Integer, Integer> positive() {
        return addMapsByKey(truePositives(), falseNegatives());
    }

    /** Copies counts for keys 0..nClasses-1 into a map, truncating to int. */
    private Map<Integer, Integer> convertToMap(Counter<Integer> counter, int nClasses) {
        Map<Integer, Integer> map = new HashMap<>();
        for (int c = 0; c < nClasses; c++) {
            map.put(c, (int) counter.getCount(c));
        }
        return map;
    }

    /** Sums two maps key-wise; a key missing from one map contributes 0. */
    private Map<Integer, Integer> addMapsByKey(Map<Integer, Integer> first, Map<Integer, Integer> second) {
        Map<Integer, Integer> sum = new HashMap<>();
        HashSet<Integer> keys = new HashSet<>(first.keySet());
        keys.addAll(second.keySet());
        for (Integer key : keys) {
            Integer a = first.get(key);
            Integer b = second.get(key);
            sum.put(key, (a == null ? 0 : a) + (b == null ? 0 : b));
        }
        return sum;
    }

    public void incrementTruePositives(Integer classLabel) {
        this.truePositives.incrementCount(classLabel, 1.0d);
    }

    public void incrementTrueNegatives(Integer classLabel) {
        this.trueNegatives.incrementCount(classLabel, 1.0d);
    }

    public void incrementFalseNegatives(Integer classLabel) {
        this.falseNegatives.incrementCount(classLabel, 1.0d);
    }

    public void incrementFalsePositives(Integer classLabel) {
        this.falsePositives.incrementCount(classLabel, 1.0d);
    }

    /** Adds one example to the confusion matrix at (actual, predicted). */
    public void addToConfusion(Integer actual, Integer predicted) {
        this.confusion.add(actual, predicted);
    }

    /** Number of examples whose actual class is {@code classLabel}. */
    public int classCount(Integer classLabel) {
        return this.confusion.getActualTotal(classLabel);
    }

    public int getNumRowCounter() {
        return this.numRowCounter;
    }

    /** Top-N correct count; when topN &lt;= 1 this is the number of correctly classified examples. */
    public int getTopNCorrectCount() {
        return this.topN > 1 ? this.topNCorrectCount : countCorrect();
    }

    /** Top-N total count; when topN &lt;= 1 this is the number of evaluated examples. */
    public int getTopNTotalCount() {
        return this.topN <= 1 ? getNumRowCounter() : this.topNTotalCount;
    }

    public String getClassLabel(Integer classIdx) {
        return resolveLabelForClass(classIdx);
    }

    public ConfusionMatrix<Integer> getConfusionMatrix() {
        return this.confusion;
    }

    /**
     * Adds all counts, the confusion matrix and per-example metadata from {@code other} into this
     * evaluation. Labels are copied only if this evaluation has none. A null argument is ignored.
     */
    @Override
    public void merge(Evaluation other) {
        if (other == null) {
            return;
        }
        this.truePositives.incrementAll(other.truePositives);
        this.falsePositives.incrementAll(other.falsePositives);
        this.trueNegatives.incrementAll(other.trueNegatives);
        this.falseNegatives.incrementAll(other.falseNegatives);

        if (this.confusion == null) {
            if (other.confusion != null) {
                this.confusion = new ConfusionMatrix<>(other.confusion);
            }
        } else if (other.confusion != null) {
            this.confusion.add(other.confusion);
        }
        this.numRowCounter += other.numRowCounter;

        if (this.labelsList == null) {
            this.labelsList = new ArrayList<>();
        }
        if (this.labelsList.isEmpty() && other.labelsList != null) {
            this.labelsList.addAll(other.labelsList);
        }

        // Keep per-example metadata so prediction queries still work after merging.
        if (other.confusionMatrixMetaData != null) {
            for (Map.Entry<Pair<Integer, Integer>, List<Object>> entry : other.confusionMatrixMetaData.entrySet()) {
                int actual = entry.getKey().getFirst();
                int predicted = entry.getKey().getSecond();
                for (Object meta : entry.getValue()) {
                    addToMetaConfusionMatrix(actual, predicted, meta);
                }
            }
        }

        if (this.topN != other.topN) {
            log.warn("Different topN values ({} vs {}) detected during Evaluation merging. Top N accuracy may not be accurate.", this.topN, other.topN);
        }
        this.topNCorrectCount += other.topNCorrectCount;
        this.topNTotalCount += other.topNTotalCount;
    }

    /** Renders the confusion matrix as a text table (rows = actual, columns = predicted). */
    public String confusionToString() {
        int nClasses = this.confusion.getClasses().size();

        int maxLabelLength = 0;
        for (int c = 0; c < nClasses; c++) {
            maxLabelLength = Math.max(maxLabelLength, resolveLabelForClass(c).length());
        }
        int labelWidth = Math.max(maxLabelLength + 5, 10);

        StringBuilder rowFormat = new StringBuilder();
        rowFormat.append("%-3d").append("%-").append(labelWidth).append("s | ");
        StringBuilder headerFormat = new StringBuilder();
        headerFormat.append("   %-").append(labelWidth).append("s   ");
        for (int c = 0; c < nClasses; c++) {
            rowFormat.append("%7d");
            headerFormat.append("%7d");
        }

        StringBuilder out = new StringBuilder();
        Object[] headerArgs = new Object[nClasses + 1];
        headerArgs[0] = "Predicted:";
        for (int c = 0; c < nClasses; c++) {
            headerArgs[c + 1] = c;
        }
        out.append(String.format(headerFormat.toString(), headerArgs)).append("\n");
        out.append("   Actual:\n");

        String rowFormatString = rowFormat.toString();
        for (int actual = 0; actual < nClasses; actual++) {
            Object[] rowArgs = new Object[nClasses + 2];
            rowArgs[0] = actual;
            rowArgs[1] = resolveLabelForClass(actual);
            for (int predicted = 0; predicted < nClasses; predicted++) {
                rowArgs[predicted + 2] = this.confusion.getCount(actual, predicted);
            }
            out.append(String.format(rowFormatString, rowArgs));
            out.append("\n");
        }
        return out.toString();
    }

    /** Records one example's metadata under its (actual, predicted) class pair. */
    private void addToMetaConfusionMatrix(int actual, int predicted, Object metaData) {
        if (this.confusionMatrixMetaData == null) {
            this.confusionMatrixMetaData = new HashMap<>();
        }
        Pair<Integer, Integer> key = new Pair<>(actual, predicted);
        List<Object> list = this.confusionMatrixMetaData.get(key);
        if (list == null) {
            list = new ArrayList<>();
            this.confusionMatrixMetaData.put(key, list);
        }
        list.add(metaData);
    }

    /**
     * Returns every misclassified example (actual != predicted), ordered by actual then predicted
     * class; null if no metadata was ever recorded.
     */
    public List<Prediction> getPredictionErrors() {
        if (this.confusionMatrixMetaData == null) {
            return null;
        }
        List<Map.Entry<Pair<Integer, Integer>, List<Object>>> entries =
                new ArrayList<>(this.confusionMatrixMetaData.entrySet());
        Collections.sort(entries, new Comparator<Map.Entry<Pair<Integer, Integer>, List<Object>>>() {
            @Override
            public int compare(Map.Entry<Pair<Integer, Integer>, List<Object>> a,
                               Map.Entry<Pair<Integer, Integer>, List<Object>> b) {
                int byActual = Integer.compare(a.getKey().getFirst(), b.getKey().getFirst());
                return byActual != 0 ? byActual : Integer.compare(a.getKey().getSecond(), b.getKey().getSecond());
            }
        });

        List<Prediction> errors = new ArrayList<>();
        for (Map.Entry<Pair<Integer, Integer>, List<Object>> entry : entries) {
            int actual = entry.getKey().getFirst();
            int predicted = entry.getKey().getSecond();
            if (actual != predicted) {
                for (Object meta : entry.getValue()) {
                    errors.add(new Prediction(actual, predicted, meta));
                }
            }
        }
        return errors;
    }

    /** All recorded predictions whose actual class is {@code actualClass}; null if no metadata. */
    public List<Prediction> getPredictionsByActualClass(int actualClass) {
        if (this.confusionMatrixMetaData == null) {
            return null;
        }
        List<Prediction> result = new ArrayList<>();
        for (Map.Entry<Pair<Integer, Integer>, List<Object>> entry : this.confusionMatrixMetaData.entrySet()) {
            int actual = entry.getKey().getFirst();
            if (actual == actualClass) {
                int predicted = entry.getKey().getSecond();
                for (Object meta : entry.getValue()) {
                    result.add(new Prediction(actual, predicted, meta));
                }
            }
        }
        return result;
    }

    /** All recorded predictions whose predicted class is {@code predictedClass}; null if no metadata. */
    public List<Prediction> getPredictionByPredictedClass(int predictedClass) {
        if (this.confusionMatrixMetaData == null) {
            return null;
        }
        List<Prediction> result = new ArrayList<>();
        for (Map.Entry<Pair<Integer, Integer>, List<Object>> entry : this.confusionMatrixMetaData.entrySet()) {
            int predicted = entry.getKey().getSecond();
            if (predicted == predictedClass) {
                int actual = entry.getKey().getFirst();
                for (Object meta : entry.getValue()) {
                    result.add(new Prediction(actual, predicted, meta));
                }
            }
        }
        return result;
    }

    /** Recorded predictions for one (actual, predicted) cell; empty if none, null if no metadata. */
    public List<Prediction> getPredictions(int actualClass, int predictedClass) {
        if (this.confusionMatrixMetaData == null) {
            return null;
        }
        List<Prediction> result = new ArrayList<>();
        List<Object> metas = this.confusionMatrixMetaData.get(new Pair<>(actualClass, predictedClass));
        if (metas == null) {
            return result;
        }
        for (Object meta : metas) {
            result.add(new Prediction(actualClass, predictedClass, meta));
        }
        return result;
    }

    public List<String> getLabelsList() {
        return this.labelsList;
    }

    public void setLabelsList(List<String> labels) {
        this.labelsList = labels;
    }
}
