package ai.libs.jaicore.ml.classification.singlelabel;

import ai.libs.jaicore.basic.ArrayUtil;
import ai.libs.jaicore.ml.core.evaluation.Prediction;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.IntStream;
import org.api4.java.ai.ml.classification.singlelabel.evaluation.ISingleLabelClassification;

/* loaded from: input_file:ai/libs/jaicore/ml/classification/singlelabel/SingleLabelClassification.class */
/**
 * Result of a single-label classification: one predicted integer label plus a probability for
 * every label, stored in an array indexed by label.
 *
 * <p>The predicted label is handed to the {@link Prediction} super class at construction time;
 * the per-label probabilities are kept in this class.
 */
public class SingleLabelClassification extends Prediction implements ISingleLabelClassification {
    /** Probability of each label; the array index is the integer label. */
    private final double[] labelProbabilities;

    /**
     * Creates a classification that puts probability 1.0 on a single label and 0.0 on all others.
     *
     * @param i number of labels, i.e. the length of the probability vector
     * @param i2 the predicted label; must be a valid index into a vector of length {@code i}
     * @throws NegativeArraySizeException if {@code i} is negative
     * @throws ArrayIndexOutOfBoundsException if {@code i2} is not in {@code [0, i)}
     */
    public SingleLabelClassification(int i, int i2) {
        super(Integer.valueOf(i2));
        this.labelProbabilities = new double[i];
        this.labelProbabilities[i2] = 1.0d;
    }

    /**
     * Creates a classification from a label-to-probability map. The prediction is the label with
     * the highest probability (the first such label in the map's iteration order on ties).
     *
     * <p>The probability vector is sized to the highest label in the map plus one, so the map may
     * omit labels; omitted labels get probability 0.0.
     *
     * @param map probability per label; must contain at least one entry and no negative labels
     * @throws IllegalArgumentException if {@code map} is empty
     * @throws ArrayIndexOutOfBoundsException if {@code map} contains a negative label
     */
    public SingleLabelClassification(Map<Integer, Double> map) {
        super(Integer.valueOf(labelWithHighestProbability(map)));
        // Size by the largest label rather than by map.size(): the array is indexed by label, so
        // a map that skips labels (e.g. {2=1.0}) would otherwise index past the end.
        int highestLabel = -1;
        for (Integer label : map.keySet()) {
            highestLabel = Math.max(highestLabel, label.intValue());
        }
        this.labelProbabilities = new double[highestLabel + 1];
        for (Map.Entry<Integer, Double> entry : map.entrySet()) {
            this.labelProbabilities[entry.getKey().intValue()] = entry.getValue().doubleValue();
        }
    }

    /**
     * Creates a classification from a probability vector indexed by label. The prediction is the
     * first element of {@code ArrayUtil.argMax(dArr)} (presumably the first index holding the
     * maximum value - confirm against {@code ArrayUtil}).
     *
     * <p>The vector is copied, so later changes to {@code dArr} do not affect this object.
     *
     * @param dArr probability per label, indexed by label
     */
    public SingleLabelClassification(double[] dArr) {
        super(ArrayUtil.argMax(dArr).get(0));
        // Defensive copy: the prediction was fixed from the current contents, so the stored
        // distribution must not change if the caller later mutates its array.
        this.labelProbabilities = dArr.clone();
    }

    /**
     * Returns the predicted label as a primitive.
     *
     * @return the prediction held by the super class, unboxed
     */
    public int getIntPrediction() {
        return ((Integer) super.getPrediction()).intValue();
    }

    /**
     * Returns the predicted label.
     *
     * @return the same value as {@link #getIntPrediction()}, boxed
     */
    @Override
    public Integer getPrediction() {
        return Integer.valueOf(getIntPrediction());
    }

    /**
     * Returns the label with the highest probability, which is the predicted label.
     *
     * @return the same value as {@link #getIntPrediction()}, boxed
     */
    @Override
    public Integer getLabelWithHighestProbability() {
        return Integer.valueOf(getIntPrediction());
    }

    /**
     * Returns the probability of every label.
     *
     * @return a new, mutable map from each label index to its stored probability
     */
    @Override
    public Map<Integer, Double> getClassDistribution() {
        return probabilitiesAsMap();
    }

    /**
     * Returns the stored probability of one label.
     *
     * @param i the label
     * @return the probability stored for label {@code i}
     * @throws ArrayIndexOutOfBoundsException if {@code i} is not a valid label index
     */
    public double getProbabilityOfLabel(int i) {
        return this.labelProbabilities[i];
    }

    /**
     * Returns the confidence of every label; this implementation reports the stored label
     * probabilities, exactly as {@link #getClassDistribution()} does.
     *
     * @return a new, mutable map from each label index to its stored probability
     */
    @Override
    public Map<Integer, Double> getClassConfidence() {
        return probabilitiesAsMap();
    }

    /** Copies the probability vector into a fresh map keyed by label index. */
    private Map<Integer, Double> probabilitiesAsMap() {
        Map<Integer, Double> distribution = new HashMap<>();
        for (int label = 0; label < this.labelProbabilities.length; label++) {
            distribution.put(Integer.valueOf(label), Double.valueOf(this.labelProbabilities[label]));
        }
        return distribution;
    }

    /**
     * Finds the label whose probability is highest. On ties the label encountered first in the
     * map's iteration order wins, because a later entry only replaces the current best when it is
     * strictly greater.
     *
     * @param map probability per label
     * @return the label with the highest probability
     * @throws IllegalArgumentException if {@code map} has no entries
     */
    private static int labelWithHighestProbability(Map<Integer, Double> map) {
        Map.Entry<Integer, Double> best = null;
        for (Map.Entry<Integer, Double> candidate : map.entrySet()) {
            if (best == null || best.getValue().doubleValue() < candidate.getValue().doubleValue()) {
                best = candidate;
            }
        }
        if (best == null) {
            throw new IllegalArgumentException("No prediction contained");
        }
        return best.getKey().intValue();
    }
}
