package ai.libs.jaicore.ml.regression.learner;

import ai.libs.jaicore.ml.core.learner.ASupervisedLearner;
import ai.libs.jaicore.ml.regression.singlelabel.SingleTargetRegressionPrediction;
import ai.libs.jaicore.ml.regression.singlelabel.SingleTargetRegressionPredictionBatch;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import org.api4.java.ai.ml.core.evaluation.IPrediction;
import org.api4.java.ai.ml.core.evaluation.IPredictionBatch;
import org.api4.java.ai.ml.core.exception.PredictionException;
import org.api4.java.ai.ml.core.exception.TrainingException;
import org.api4.java.ai.ml.regression.evaluation.IRegressionPrediction;
import org.api4.java.ai.ml.regression.evaluation.IRegressionResultBatch;

/* loaded from: input_file:ai/libs/jaicore/ml/regression/learner/ConstantRegressor.class */
/**
 * Baseline regressor that ignores instance features and always predicts the same value: the arithmetic mean of the
 * non-null labels seen during {@link #fit(ILabeledDataset)}.
 */
public class ConstantRegressor extends ASupervisedLearner<ILabeledInstance, ILabeledDataset<? extends ILabeledInstance>, IPrediction, IPredictionBatch> {

    /** Mean of the training labels; {@code null} until {@link #fit(ILabeledDataset)} has completed successfully. */
    private Double constantValue;

    /**
     * Learns the constant to predict as the mean of all non-null labels in the given dataset.
     *
     * @param iLabeledDataset the training data; must be non-null, non-empty and contain at least one non-null label
     * @throws NullPointerException if {@code iLabeledDataset} is {@code null}
     * @throws IllegalArgumentException if the dataset is empty or none of its instances has a non-null label
     * @throws ClassCastException if a non-null label is not a {@link Number}
     */
    @Override
    public void fit(final ILabeledDataset<? extends ILabeledInstance> iLabeledDataset) throws TrainingException, InterruptedException {
        Objects.requireNonNull(iLabeledDataset, "iLabeledDataset");
        if (iLabeledDataset.isEmpty()) {
            throw new IllegalArgumentException("Cannot train constant regressor with empty training set.");
        }
        // Unlabeled instances carry no target information and are skipped. Labels are read through Number so that
        // Integer/Long/Float targets work as well as Double ones.
        this.constantValue = iLabeledDataset.stream()
                .map(ILabeledInstance::getLabel)
                .filter(Objects::nonNull)
                .mapToDouble(label -> ((Number) label).doubleValue())
                .average()
                // An all-null label column leaves nothing to average; report it like the empty-set case.
                .orElseThrow(() -> new IllegalArgumentException("Cannot train constant regressor: no instance in the training set has a label."));
    }

    /**
     * Returns the learned constant; the given instance is not inspected.
     *
     * @param iLabeledInstance the instance to predict for (ignored)
     * @return a prediction holding the mean label observed during training
     * @throws IllegalStateException if {@link #fit(ILabeledDataset)} has not completed successfully yet
     */
    @Override // ai.libs.jaicore.ml.core.learner.ASupervisedLearner
    public IRegressionPrediction predict(final ILabeledInstance iLabeledInstance) throws PredictionException, InterruptedException {
        if (this.constantValue == null) {
            // Fail fast instead of handing out a prediction that wraps a null value.
            throw new IllegalStateException("ConstantRegressor must be fitted before it can make predictions.");
        }
        return new SingleTargetRegressionPrediction(this.constantValue);
    }

    /**
     * Predicts every instance of the given array, in order, via {@link #predict(ILabeledInstance)}.
     *
     * @param iLabeledInstanceArr the instances to predict for; must not be {@code null}
     * @return a batch with one prediction per input instance, in input order
     * @throws NullPointerException if {@code iLabeledInstanceArr} is {@code null}
     * @throws IllegalStateException if the regressor has not been fitted yet
     */
    @Override // ai.libs.jaicore.ml.core.learner.ASupervisedLearner
    public IRegressionResultBatch predict(final ILabeledInstance[] iLabeledInstanceArr) throws PredictionException, InterruptedException {
        Objects.requireNonNull(iLabeledInstanceArr, "iLabeledInstanceArr");
        List<IRegressionPrediction> predictions = new ArrayList<>(iLabeledInstanceArr.length);
        for (ILabeledInstance instance : iLabeledInstanceArr) {
            predictions.add(this.predict(instance));
        }
        return new SingleTargetRegressionPredictionBatch(predictions);
    }
}
