package ai.libs.jaicore.ml.core.filter.sampling.inmemory.stratified.sampling;

import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer;
import org.apache.commons.math3.ml.distance.DistanceMeasure;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.api4.java.ai.ml.core.dataset.IDataset;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/libs/jaicore/ml/core/filter/sampling/inmemory/stratified/sampling/KMeansStratifier.class */
/**
 * Stratifier that partitions a dataset into strata by running k-means++ clustering
 * ({@link KMeansPlusPlusClusterer} from Apache Commons Math) and handing the resulting
 * clusters to the {@link ClusterStratiAssigner} base class via {@code setClusters}.
 */
public class KMeansStratifier extends ClusterStratiAssigner {

    private static final Logger LOGGER = LoggerFactory.getLogger(KMeansStratifier.class);

    /**
     * Max-iterations argument for {@link KMeansPlusPlusClusterer}; a negative value means
     * "no maximum", i.e. iterate until the cluster assignment converges.
     */
    private static final int UNLIMITED_ITERATIONS = -1;

    /** Number of clusters (= strata) requested from k-means++. */
    private final int numberOfStrati;

    /**
     * Creates a k-means++ based stratifier.
     *
     * @param numberOfStrati  number of clusters to compute; each cluster becomes one stratum
     * @param distanceMeasure distance measure used by the clusterer
     * @param randomSeed      seed for the random generator used by k-means++ center selection
     * @throws IllegalArgumentException if {@code numberOfStrati} is smaller than 1
     */
    public KMeansStratifier(final int numberOfStrati, final DistanceMeasure distanceMeasure, final int randomSeed) {
        // Fail fast: a non-positive cluster count cannot produce any strata and would
        // otherwise only surface as an obscure error deep inside the clusterer.
        if (numberOfStrati < 1) {
            throw new IllegalArgumentException("Number of strati must be at least 1, but was " + numberOfStrati);
        }
        this.numberOfStrati = numberOfStrati;
        this.randomSeed = randomSeed;
        this.distanceMeasure = distanceMeasure;
    }

    /**
     * Clusters the given dataset with k-means++ and registers the clusters as strata.
     *
     * @param dataset the dataset to stratify; it is also stored via {@code setDataset}
     * @return the configured number of strata
     */
    @Override // ai.libs.jaicore.ml.core.filter.sampling.inmemory.stratified.sampling.IStratifier
    @SuppressWarnings({ "rawtypes", "unchecked" }) // dataset element type is only known as a wildcard here
    public int createStrati(final IDataset<?> dataset) {
        this.setDataset(dataset);

        // A freshly seeded generator per call keeps the clustering reproducible.
        JDKRandomGenerator rng = new JDKRandomGenerator();
        rng.setSeed(this.randomSeed);
        KMeansPlusPlusClusterer clusterer = new KMeansPlusPlusClusterer(this.numberOfStrati, UNLIMITED_ITERATIONS, this.distanceMeasure, rng);

        LOGGER.info("Clustering dataset with {} instances.", dataset.size());
        this.setClusters(clusterer.cluster(dataset));
        LOGGER.info("Finished clustering");
        return this.numberOfStrati;
    }
}
