package ai.libs.jaicore.ml.hpo.multifidelity.hyperband;

import ai.libs.jaicore.basic.MathExt;
import ai.libs.jaicore.basic.algorithm.AOptimizer;
import ai.libs.jaicore.basic.algorithm.EAlgorithmState;
import ai.libs.jaicore.components.api.IEvaluatedSoftwareConfigurationSolution;
import ai.libs.jaicore.components.model.ComponentInstance;
import ai.libs.jaicore.components.model.ComponentInstanceUtil;
import ai.libs.jaicore.ml.core.dataset.DatasetUtil;
import ai.libs.jaicore.ml.hpo.multifidelity.MultiFidelitySoftwareConfigurationProblem;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.Semaphore;
import java.util.concurrent.locks.ReentrantLock;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.api4.java.algorithm.events.IAlgorithmEvent;
import org.api4.java.algorithm.exceptions.AlgorithmException;
import org.api4.java.algorithm.exceptions.AlgorithmExecutionCanceledException;
import org.api4.java.algorithm.exceptions.AlgorithmTimeoutedException;
import org.api4.java.common.attributedobjects.ObjectEvaluationFailedException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/libs/jaicore/ml/hpo/multifidelity/hyperband/Hyperband.class */
public class Hyperband extends AOptimizer<MultiFidelitySoftwareConfigurationProblem<Double>, HyperbandSolutionCandidate, MultiFidelityScore> {
    private static final Logger LOGGER = LoggerFactory.getLogger(Hyperband.class);
    private double eta;
    private double rMax;
    private double crashedEvaluationScore;
    private double b;
    private int sMax;
    private Random rand;
    private ExecutorService pool;

    /* renamed from: ai.libs.jaicore.ml.hpo.multifidelity.hyperband.Hyperband$1, reason: invalid class name */
    /* loaded from: input_file:ai/libs/jaicore/ml/hpo/multifidelity/hyperband/Hyperband$1.class */
    static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$ai$libs$jaicore$basic$algorithm$EAlgorithmState = new int[EAlgorithmState.values().length];

        static {
            try {
                $SwitchMap$ai$libs$jaicore$basic$algorithm$EAlgorithmState[EAlgorithmState.CREATED.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$ai$libs$jaicore$basic$algorithm$EAlgorithmState[EAlgorithmState.INACTIVE.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$ai$libs$jaicore$basic$algorithm$EAlgorithmState[EAlgorithmState.ACTIVE.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
        }
    }

    /* loaded from: input_file:ai/libs/jaicore/ml/hpo/multifidelity/hyperband/Hyperband$HyperbandSolutionCandidate.class */
    public class HyperbandSolutionCandidate implements IEvaluatedSoftwareConfigurationSolution<MultiFidelityScore> {
        private ComponentInstance ci;
        private MultiFidelityScore score;

        public HyperbandSolutionCandidate(ComponentInstance componentInstance, double d, double d2) {
            this.ci = componentInstance;
            this.score = new MultiFidelityScore(d, d2);
        }

        /* renamed from: getScore, reason: merged with bridge method [inline-methods] */
        public MultiFidelityScore m110getScore() {
            return this.score;
        }

        /* renamed from: getComponentInstance, reason: merged with bridge method [inline-methods] */
        public ComponentInstance m109getComponentInstance() {
            return this.ci;
        }

        public String toString() {
            return "c:" + this.score;
        }
    }

    /* loaded from: input_file:ai/libs/jaicore/ml/hpo/multifidelity/hyperband/Hyperband$MultiFidelityScore.class */
    public class MultiFidelityScore implements Comparable<MultiFidelityScore> {
        private final double r;
        private final double score;

        public MultiFidelityScore(double d, double d2) {
            this.r = d;
            this.score = d2;
        }

        @Override // java.lang.Comparable
        public int compareTo(MultiFidelityScore multiFidelityScore) {
            int compare = Double.compare(multiFidelityScore.r, this.r);
            return compare != 0 ? compare : Double.compare(this.score, multiFidelityScore.score);
        }

        public int hashCode() {
            int hashCode = (31 * 1) + getEnclosingInstance().hashCode();
            long doubleToLongBits = Double.doubleToLongBits(this.r);
            int i = (31 * hashCode) + ((int) (doubleToLongBits ^ (doubleToLongBits >>> 32)));
            long doubleToLongBits2 = Double.doubleToLongBits(this.score);
            return (31 * i) + ((int) (doubleToLongBits2 ^ (doubleToLongBits2 >>> 32)));
        }

        public boolean equals(Object obj) {
            if (!(obj instanceof MultiFidelityScore)) {
                return false;
            }
            MultiFidelityScore multiFidelityScore = (MultiFidelityScore) obj;
            return Math.abs(this.r - multiFidelityScore.r) < 1.0E-8d && Math.abs(this.score - multiFidelityScore.score) < 1.0E-8d;
        }

        public String toString() {
            return "(" + this.r + ";" + this.score + ")";
        }

        private Hyperband getEnclosingInstance() {
            return Hyperband.this;
        }
    }

    public Hyperband(IHyperbandConfig iHyperbandConfig, MultiFidelitySoftwareConfigurationProblem<Double> multiFidelitySoftwareConfigurationProblem) {
        super(iHyperbandConfig, multiFidelitySoftwareConfigurationProblem);
        this.pool = null;
        this.rand = new Random(iHyperbandConfig.getSeed());
    }

    public IAlgorithmEvent nextWithException() throws InterruptedException, AlgorithmExecutionCanceledException, AlgorithmTimeoutedException, AlgorithmException {
        switch (AnonymousClass1.$SwitchMap$ai$libs$jaicore$basic$algorithm$EAlgorithmState[getState().ordinal()]) {
            case DatasetUtil.EXPANSION_SQUARES /* 1 */:
                this.eta = m107getConfig().getEta();
                this.rMax = ((MultiFidelitySoftwareConfigurationProblem) getInput()).m104getCompositionEvaluator().getMaxBudget();
                this.crashedEvaluationScore = m107getConfig().getCrashScore();
                if (m107getConfig().getIterations().equals("auto")) {
                    this.sMax = (int) Math.floor(MathExt.logBase(this.rMax, this.eta));
                } else {
                    this.sMax = Integer.parseInt(m107getConfig().getIterations());
                }
                this.b = (this.sMax + 1) * this.rMax;
                if (m107getConfig().cpus() > 1) {
                    this.pool = Executors.newFixedThreadPool(m107getConfig().cpus());
                }
                LOGGER.info("Initialized HyperBand with eta={}, r_max={}, s_max={}, b={} and parallelizing with {} cpu cores.", new Object[]{Double.valueOf(this.eta), Double.valueOf(this.rMax), Integer.valueOf(this.sMax), Double.valueOf(this.b), Integer.valueOf(m107getConfig().cpus())});
                return super.activate();
            case DatasetUtil.EXPANSION_LOGARITHM /* 2 */:
                throw new AlgorithmException("Algorithm has already finished.");
            case DatasetUtil.EXPANSION_PRODUCTS /* 3 */:
            default:
                for (int i = this.sMax; i >= 0; i--) {
                    int ceil = (int) Math.ceil((this.b / this.rMax) * (Math.pow(this.eta, i) / (i + 1)));
                    double pow = this.rMax * Math.pow(this.eta, -i);
                    LOGGER.info("Execute round {} of HyperBand with n={}, r={}", new Object[]{Integer.valueOf((this.sMax - i) + 1), Integer.valueOf(ceil), Double.valueOf(pow)});
                    List<ComponentInstance> nCandidates = getNCandidates(ceil);
                    for (int i2 = 0; i2 <= i; i2++) {
                        int floor = (int) Math.floor(ceil / Math.pow(this.eta, i2));
                        double pow2 = pow * Math.pow(this.eta, i2);
                        System.out.println(floor + " " + pow2);
                        List<HyperbandSolutionCandidate> evaluate = evaluate(nCandidates, pow2);
                        evaluate.sort((hyperbandSolutionCandidate, hyperbandSolutionCandidate2) -> {
                            return hyperbandSolutionCandidate.m110getScore().compareTo(hyperbandSolutionCandidate2.m110getScore());
                        });
                        updateBestSeenSolution(evaluate.get(0));
                        nCandidates.clear();
                        Stream mapToObj = IntStream.range(0, (int) Math.floor(floor / this.eta)).mapToObj(i3 -> {
                            return ((HyperbandSolutionCandidate) evaluate.get(i3)).m109getComponentInstance();
                        });
                        Objects.requireNonNull(nCandidates);
                        mapToObj.forEach((v1) -> {
                            r1.add(v1);
                        });
                    }
                }
                if (this.pool != null) {
                    this.pool.shutdownNow();
                }
                return super.terminate();
        }
    }

    private List<HyperbandSolutionCandidate> evaluate(List<ComponentInstance> list, double d) throws InterruptedException {
        ReentrantLock reentrantLock = new ReentrantLock();
        ArrayList arrayList = new ArrayList(list.size());
        Semaphore semaphore = new Semaphore(0);
        ArrayList arrayList2 = new ArrayList(list.size());
        for (ComponentInstance componentInstance : list) {
            arrayList2.add(() -> {
                double d2;
                try {
                    d2 = ((Double) ((MultiFidelitySoftwareConfigurationProblem) getInput()).m104getCompositionEvaluator().evaluate(componentInstance, d)).doubleValue();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    d2 = this.crashedEvaluationScore;
                } catch (ObjectEvaluationFailedException e2) {
                    d2 = this.crashedEvaluationScore;
                }
                reentrantLock.lock();
                try {
                    arrayList.add(new HyperbandSolutionCandidate(componentInstance, d, d2));
                    reentrantLock.unlock();
                    semaphore.release();
                } catch (Throwable th) {
                    reentrantLock.unlock();
                    semaphore.release();
                    throw th;
                }
            });
        }
        if (this.pool != null) {
            Stream stream = arrayList2.stream();
            ExecutorService executorService = this.pool;
            Objects.requireNonNull(executorService);
            stream.forEach(executorService::submit);
            semaphore.acquire(list.size());
        } else {
            arrayList2.stream().forEach((v0) -> {
                v0.run();
            });
        }
        return arrayList;
    }

    private List<ComponentInstance> getNCandidates(int i) {
        ArrayList arrayList = new ArrayList(i);
        for (int i2 = 0; i2 < i; i2++) {
            arrayList.add(ComponentInstanceUtil.sampleRandomComponentInstance(((MultiFidelitySoftwareConfigurationProblem) getInput()).getRequiredInterface(), ((MultiFidelitySoftwareConfigurationProblem) getInput()).getComponents(), this.rand));
        }
        return arrayList;
    }

    /* renamed from: getConfig, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
    public IHyperbandConfig m107getConfig() {
        return (IHyperbandConfig) super.getConfig();
    }
}
