package ai.libs.jaicore.ml.core.filter.sampling.inmemory.stratified.sampling;

import ai.libs.jaicore.basic.algorithm.EAlgorithmState;
import ai.libs.jaicore.ml.core.dataset.DatasetDeriver;
import ai.libs.jaicore.ml.core.dataset.DatasetUtil;
import ai.libs.jaicore.ml.core.filter.sampling.SampleElementAddedEvent;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.ASamplingAlgorithm;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.SimpleRandomSampling;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.WaitForSamplingStepEvent;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Random;
import org.api4.java.ai.ml.core.dataset.IDataset;
import org.api4.java.ai.ml.core.dataset.IInstance;
import org.api4.java.ai.ml.core.exception.DatasetCreationException;
import org.api4.java.algorithm.events.IAlgorithmEvent;
import org.api4.java.algorithm.exceptions.AlgorithmException;
import org.api4.java.algorithm.exceptions.AlgorithmExecutionCanceledException;
import org.api4.java.algorithm.exceptions.AlgorithmTimeoutedException;
import org.api4.java.common.control.ILoggingCustomizable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/libs/jaicore/ml/core/filter/sampling/inmemory/stratified/sampling/StratifiedSampling.class */
/**
 * Stratified sampling: the input dataset is partitioned into strati by an {@link IStratifier}, every stratum
 * contributes a share of the sample proportional to its size (drawn with {@link SimpleRandomSampling} inside the
 * stratum), and the combined sample is finally shuffled.
 *
 * <p>The algorithm is driven step-wise by {@link #nextWithException()}:
 * <ol>
 * <li>CREATED: the stratifier creates the strati and one {@link DatasetDeriver} per stratum is prepared.</li>
 * <li>first ACTIVE step: every data point of the input is assigned to its stratum.</li>
 * <li>second ACTIVE step: the per-stratum sub-samples are drawn and merged into the final sample.</li>
 * <li>third ACTIVE step: the algorithm terminates.</li>
 * </ol>
 *
 * @param <D> the type of the dataset being sampled
 */
public class StratifiedSampling<D extends IDataset<?>> extends ASamplingAlgorithm<D> {

    /** Number of data points processed between two termination checks while assigning strati. */
    private static final int TERMINATION_CHECK_INTERVAL = 100;

    private Logger logger;
    private final IStratifier stratificationTechnique;
    private final Random random;
    /** One deriver per stratum; the deriver of stratum k collects the input indices of the points in stratum k. */
    private DatasetDeriver<D>[] stratiBuilder;
    private boolean allDatapointsAssigned;
    private boolean simpleRandomSamplingStarted;

    /**
     * Creates a stratified sampler.
     *
     * @param iStratifier technique that determines the number of strati and the stratum of each instance
     * @param random source of randomness used to distribute leftover sample slots, to sample inside the strati,
     *        and to shuffle the final sample
     * @param d the dataset to sample from
     */
    public StratifiedSampling(final IStratifier iStratifier, final Random random, final D d) {
        super(d);
        this.logger = LoggerFactory.getLogger(StratifiedSampling.class);
        this.stratiBuilder = null;
        this.allDatapointsAssigned = false;
        this.stratificationTechnique = iStratifier;
        this.random = random;
    }

    /**
     * Performs the next step of the sampling algorithm depending on its current state.
     *
     * @return the event describing the step that was performed
     * @throws AlgorithmException if a data point is mapped to a non-existing stratum, the strati cannot be sampled,
     *         or the algorithm becomes inactive before the expected sample size is reached
     */
    public IAlgorithmEvent nextWithException() throws InterruptedException, AlgorithmException, AlgorithmTimeoutedException, AlgorithmExecutionCanceledException {
        switch (getState()) {
            case CREATED:
                if (!this.allDatapointsAssigned) {
                    this.initializeStrati();
                }
                this.simpleRandomSamplingStarted = false;
                this.logger.info("Stratified sampler initialized.");
                return activate();
            case ACTIVE:
                if (!this.allDatapointsAssigned) {
                    return this.assignDatapointsToStrati();
                }
                if (this.simpleRandomSamplingStarted) {
                    this.logger.info("Stratified sampling completed.");
                    return terminate();
                }
                try {
                    startSimpleRandomSamplingForStrati();
                    this.simpleRandomSamplingStarted = true;
                    return new WaitForSamplingStepEvent(this);
                } catch (DatasetCreationException e) {
                    throw new AlgorithmException("Could not create sample from strati.", e);
                }
            case INACTIVE:
                if (this.sample.size() < this.sampleSize) {
                    throw new AlgorithmException("Expected sample size was not reached before termination");
                }
                return terminate();
            default:
                throw new IllegalStateException("Unknown algorithm state " + getState());
        }
    }

    /**
     * Lets the stratifier create the strati and prepares one empty {@link DatasetDeriver} per stratum.
     *
     * @throws IllegalStateException if no stratum was created or the input dataset changed (by hash code) meanwhile
     */
    private void initializeStrati() throws InterruptedException, AlgorithmException, AlgorithmTimeoutedException, AlgorithmExecutionCanceledException {
        D input = getInput();
        int hashCodeBefore = input.hashCode();
        this.stratificationTechnique.setNumCPUs(getNumCPUs());
        int numStrati = this.stratificationTechnique.createStrati(input);
        // Reject zero (and negative) counts before allocating, so both get the same descriptive error.
        if (numStrati <= 0) {
            throw new IllegalStateException("Stratification technique has not created any stratum.");
        }
        @SuppressWarnings("unchecked")
        DatasetDeriver<D>[] builders = (DatasetDeriver<D>[]) Array.newInstance(DatasetDeriver.class, numStrati);
        for (int k = 0; k < builders.length; k++) {
            builders[k] = new DatasetDeriver<>(input);
        }
        this.stratiBuilder = builders;
        // The stratifier must only read the data; a changed hash code indicates it was modified.
        if (input.hashCode() != hashCodeBefore) {
            throw new IllegalStateException("Original dataset has been modified!");
        }
    }

    /**
     * Assigns every data point of the input to the deriver of its stratum.
     *
     * @return the event signalling that the assignment step has been completed
     * @throws AlgorithmException if the stratifier returns an index outside the created strati
     * @throws IllegalStateException if the strati do not contain exactly the data points of the input
     */
    private IAlgorithmEvent assignDatapointsToStrati() throws InterruptedException, AlgorithmException, AlgorithmTimeoutedException, AlgorithmExecutionCanceledException {
        this.logger.info("Starting to sort all datapoints into their strati.");
        D input = getInput();
        int size = input.size();
        for (int i = 0; i < size; i++) {
            IInstance instance = (IInstance) input.get(i);
            if (i % TERMINATION_CHECK_INTERVAL == 0) {
                checkAndConductTermination();
            }
            this.logger.debug("Computing stratum for next data point {}", instance);
            int stratum = this.stratificationTechnique.getStratum(instance);
            if (stratum < 0 || stratum >= this.stratiBuilder.length) {
                throw new AlgorithmException("No existing strati for index " + stratum);
            }
            this.stratiBuilder[stratum].add(i);
            this.logger.debug("Added data point {} to stratum {}. {} datapoints remaining.", instance, stratum, size - i - 1);
        }
        this.allDatapointsAssigned = true;

        int numAssigned = 0;
        for (DatasetDeriver<D> deriver : this.stratiBuilder) {
            this.logger.debug("Elements in stratum: {}", deriver.currentSizeOfTarget());
            numAssigned += deriver.currentSizeOfTarget();
        }
        this.logger.info("Finished stratum assignments. Assigned {} data points in total.", numAssigned);
        if (numAssigned != input.size()) {
            throw new IllegalStateException("Not all data have been collected.");
        }
        return new SampleElementAddedEvent(this);
    }

    /**
     * Determines how many elements are drawn from each stratum.
     *
     * <p>Each stratum first receives the floor of its proportional share {@code stratumSize * sampleSize / datasetSize}.
     * The slots still missing to reach the sample size are then handed out one at a time to randomly chosen,
     * distinct strati. Only strati that still have unused elements are eligible; in particular, empty strati never
     * receive a slot, because that slot could never be filled.
     *
     * @return the number of elements to draw per stratum, summing up to the sample size
     * @throws IllegalStateException if the shares cannot be made to add up to the sample size
     */
    private int[] computeSampleSizesPerStratum() {
        int numStrati = this.stratiBuilder.length;
        int[] shares = new int[numStrati];
        double samplingRate = this.sampleSize / (double) getInput().size();
        int total = 0;
        ArrayList<Integer> candidates = new ArrayList<>();
        for (int k = 0; k < numStrati; k++) {
            int stratumSize = this.stratiBuilder[k].currentSizeOfTarget();
            if (stratumSize < 0) {
                throw new IllegalStateException("Builder for stratum " + k + " has a negative current target size: " + stratumSize);
            }
            shares[k] = (int) Math.floor(stratumSize * samplingRate);
            if (shares[k] < 0) {
                throw new IllegalStateException("Determined negative stratum size " + shares[k] + " for " + k + "-th stratum.");
            }
            total += shares[k];
            if (shares[k] < stratumSize) {
                candidates.add(k);
            }
        }

        // Flooring loses less than one element per stratum; distribute the rest without replacement.
        while (total < this.sampleSize) {
            if (candidates.isEmpty()) {
                throw new IllegalStateException("Cannot assign the remaining " + (this.sampleSize - total) + " samples, because no stratum has unused elements left.");
            }
            Collections.shuffle(candidates, this.random);
            int chosen = candidates.remove(0);
            shares[chosen]++;
            total++;
        }
        if (total != this.sampleSize) {
            throw new IllegalStateException("Number of samples is " + total + " where it should be " + this.sampleSize);
        }
        return shares;
    }

    /**
     * Draws the per-stratum sub-samples, merges them into {@code this.sample} and shuffles the result.
     *
     * @throws IllegalStateException if no sample size was configured or any consistency check on sizes fails
     * @throws DatasetCreationException if a stratum or the final sample cannot be built
     */
    private void startSimpleRandomSamplingForStrati() throws InterruptedException, DatasetCreationException, AlgorithmTimeoutedException, AlgorithmExecutionCanceledException {
        if (this.sampleSize == -1) {
            throw new IllegalStateException("No valid sample size specified");
        }
        this.logger.info("Now drawing simple random elements in each stratum.");
        int[] sharesPerStratum = this.computeSampleSizesPerStratum();

        DatasetDeriver<D> sampleDeriver = new DatasetDeriver<>(getInput());
        for (int k = 0; k < this.stratiBuilder.length; k++) {
            DatasetDeriver<D> stratumDeriver = this.stratiBuilder[k];
            D stratum = stratumDeriver.build();
            int share = sharesPerStratum[k];
            if (stratum.isEmpty()) {
                this.logger.warn("{}-th stratum is empty!", k);
            } else if (share == 0) {
                this.logger.warn("No samples for stratum {}", k);
            } else if (share == stratum.size()) {
                // The whole stratum is taken; no need to sample within it.
                sampleDeriver.addIndices(stratumDeriver.getIndicesOfNewInstancesInOriginalDataset());
            } else {
                checkAndConductTermination();
                SimpleRandomSampling stratumSampler = new SimpleRandomSampling(this.random, stratum);
                stratumSampler.setSampleSize(share);
                this.logger.info("Setting sample size for {}-th stratus to {}", k, share);
                try {
                    this.logger.debug("Calling SimpleRandomSampling");
                    stratumSampler.m92call();
                    this.logger.debug("SimpleRandomSampling finished");
                } catch (InterruptedException e) {
                    throw e;
                } catch (Exception e) {
                    // Logged only; the size check below detects an incomplete sub-sample.
                    this.logger.error("Unexpected exception during simple random sampling!", e);
                }
                if (stratumSampler.getChosenIndices().size() != share) {
                    throw new IllegalStateException("Number of samples drawn for stratum " + k + " is " + stratumSampler.getChosenIndices().size() + ", but it should be " + share);
                }
                // Chosen indices refer to the stratum dataset; translate them to indices of the original input.
                sampleDeriver.addIndices(stratumDeriver.getIndicesOfNewInstancesInOriginalDataset(stratumSampler.getChosenIndices()));
            }
        }
        if (sampleDeriver.currentSizeOfTarget() != this.sampleSize) {
            throw new IllegalStateException("The deriver says that the target has " + sampleDeriver.currentSizeOfTarget() + " elements, but it should have been configured for " + this.sampleSize);
        }
        checkAndConductTermination();
        this.logger.info("Strati sub-samples completed, building the final sample and shuffling it.");
        this.sample = sampleDeriver.build();
        if (this.sample.size() != this.sampleSize) {
            throw new IllegalStateException("The sample deriver has produced a sample with " + this.sample.size() + " elements while it should have " + this.sampleSize);
        }
        Collections.shuffle(this.sample, this.random);
        this.logger.info("Overall stratified shuffled sample completed.");
    }

    /**
     * Switches this sampler to the given logger and, if the stratifier supports it, configures the stratifier to
     * log under {@code str + ".stratifier"}.
     *
     * @param str the logger name
     */
    @Override // ai.libs.jaicore.ml.core.filter.sampling.inmemory.ASamplingAlgorithm
    public void setLoggerName(final String str) {
        this.logger = LoggerFactory.getLogger(str);
        if (this.stratificationTechnique instanceof ILoggingCustomizable) {
            ((ILoggingCustomizable) this.stratificationTechnique).setLoggerName(str + ".stratifier");
        }
    }

    /**
     * @return the name of the logger currently used by this sampler
     */
    @Override // ai.libs.jaicore.ml.core.filter.sampling.inmemory.ASamplingAlgorithm
    public String getLoggerName() {
        return this.logger.getName();
    }
}
