package org.deeplearning4j.earlystopping.scorecalc;

import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.shade.jackson.annotation.JsonIgnore;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/**
 * {@link ScoreCalculator} that scores a {@link ComputationGraph} by accumulating its loss over
 * every minibatch produced by either a {@link DataSetIterator} or a {@link MultiDataSetIterator}.
 *
 * <p>Each minibatch score is multiplied by the number of examples in that minibatch (dimension 0
 * of the first features array), and the example counts are summed. With {@code average == true}
 * the result is the total divided by the total example count. Otherwise the weighted total is
 * returned. NOTE(review): this weighting assumes {@code ComputationGraph#score} returns the mean
 * loss over the minibatch — TODO confirm.
 *
 * <p>Each data-source constructor sets exactly one of the two iterators. Both iterators are
 * {@link JsonIgnore}d, so an instance created through the no-argument constructor (presumably
 * used for JSON deserialization) has no data source and cannot compute a score.
 */
@Deprecated
public class DataSetLossCalculatorCG implements ScoreCalculator<ComputationGraph> {

    /** Single-input data source; {@code null} when the multi-dataset constructor was used. */
    @JsonIgnore
    private DataSetIterator dataSetIterator;

    /** Multi-input data source; {@code null} when the single-dataset constructor was used. */
    @JsonIgnore
    private MultiDataSetIterator multiDataSetIterator;

    /** If true, divide the summed loss by the total number of examples seen. */
    @JsonProperty
    private boolean average;

    /**
     * Creates a calculator that scores against a single-input {@link DataSetIterator}.
     *
     * @param dataSetIterator iterator to score against; it is reset before every evaluation
     * @param average         whether to return the per-example mean rather than the weighted sum
     */
    public DataSetLossCalculatorCG(DataSetIterator dataSetIterator, boolean average) {
        this.dataSetIterator = dataSetIterator;
        this.average = average;
    }

    /**
     * Creates a calculator that scores against a {@link MultiDataSetIterator}.
     *
     * @param multiDataSetIterator iterator to score against; it is reset before every evaluation
     * @param average              whether to return the per-example mean rather than the weighted sum
     */
    public DataSetLossCalculatorCG(MultiDataSetIterator multiDataSetIterator, boolean average) {
        this.multiDataSetIterator = multiDataSetIterator;
        this.average = average;
    }

    /**
     * No-argument constructor. It leaves both iterators unset, so {@link #calculateScore} will
     * throw until a data source is available.
     */
    public DataSetLossCalculatorCG() {
    }

    /**
     * Resets the configured iterator and computes the example-weighted loss of the network over
     * all of its minibatches.
     *
     * @param computationGraph network to score
     * @return the weighted loss sum, or the per-example mean when {@code average} is true. The
     *         mean is NaN if the iterator produced no examples.
     * @throws IllegalStateException if neither a DataSetIterator nor a MultiDataSetIterator is set
     */
    @Override
    public double calculateScore(ComputationGraph computationGraph) {
        double lossSum = 0.0;
        // long: protects against overflow and accepts size(0) whether it returns int or long.
        long exampleCount = 0;

        if (dataSetIterator != null) {
            dataSetIterator.reset();
            while (dataSetIterator.hasNext()) {
                DataSet batch = dataSetIterator.next();
                long batchSize = batch.getFeatureMatrix().size(0);
                // Weight by batch size so that a smaller final minibatch is not over-represented.
                lossSum += computationGraph.score(batch) * batchSize;
                exampleCount += batchSize;
            }
        } else if (multiDataSetIterator != null) {
            multiDataSetIterator.reset();
            while (multiDataSetIterator.hasNext()) {
                MultiDataSet batch = multiDataSetIterator.next();
                // The example count is taken from the first features array only.
                long batchSize = batch.getFeatures(0).size(0);
                lossSum += computationGraph.score(batch) * batchSize;
                exampleCount += batchSize;
            }
        } else {
            throw new IllegalStateException(
                    "DataSetLossCalculatorCG has no DataSetIterator or MultiDataSetIterator to score against");
        }

        // With no examples and average == true this is 0.0 / 0, i.e. NaN (unchanged behaviour).
        return average ? lossSum / exampleCount : lossSum;
    }

    /** Describes the calculator by whichever iterator it was built with, plus the averaging flag. */
    @Override
    public String toString() {
        Object source = dataSetIterator != null ? dataSetIterator : multiDataSetIterator;
        return "DataSetLossCalculatorCG(" + source + ",average=" + average + ")";
    }
}
