package org.deeplearning4j.arbiter.optimize.runner;

import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Properties;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import org.deeplearning4j.arbiter.optimize.api.Candidate;
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
import org.deeplearning4j.arbiter.optimize.api.TaskCreator;
import org.deeplearning4j.arbiter.optimize.api.TaskCreatorProvider;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;

/* loaded from: input_file:org/deeplearning4j/arbiter/optimize/runner/LocalOptimizationRunner.class */
/**
 * Optimization runner that executes candidates on a local, fixed-size thread pool.
 *
 * <p>Each candidate is turned into a task by a {@link TaskCreator} and submitted to a pool of
 * {@code maxConcurrentTasks} daemon threads named {@code LocalCandidateExecutor-N}.
 */
public class LocalOptimizationRunner extends BaseOptimizationRunner {

    /** Number of concurrent tasks used when the caller does not specify one. */
    public static final int DEFAULT_MAX_CONCURRENT_TASKS = 1;

    /** Default upper bound on how long {@link #shutdown(boolean)} waits: two days. */
    private static final long DEFAULT_SHUTDOWN_MAX_WAIT_MS = TimeUnit.DAYS.toMillis(2);

    private final int maxConcurrentTasks;
    private TaskCreator taskCreator;
    private ListeningExecutorService executor;
    private long shutdownMaxWaitMS = DEFAULT_SHUTDOWN_MAX_WAIT_MS;

    /**
     * Creates a runner with {@link #DEFAULT_MAX_CONCURRENT_TASKS} worker threads and a
     * {@link TaskCreator} inferred from the configuration's parameter space.
     *
     * @param optimizationConfiguration the optimization configuration
     * @throws IllegalStateException if no default TaskCreator can be inferred
     */
    public LocalOptimizationRunner(OptimizationConfiguration optimizationConfiguration) {
        this(optimizationConfiguration, (TaskCreator) null);
    }

    /**
     * Creates a runner with {@link #DEFAULT_MAX_CONCURRENT_TASKS} worker threads.
     *
     * @param optimizationConfiguration the optimization configuration
     * @param taskCreator the task creator to use, or {@code null} to infer one from the
     *     configuration's parameter space
     * @throws IllegalStateException if {@code taskCreator} is null and no default can be inferred
     */
    public LocalOptimizationRunner(OptimizationConfiguration optimizationConfiguration, TaskCreator taskCreator) {
        this(DEFAULT_MAX_CONCURRENT_TASKS, optimizationConfiguration, taskCreator);
    }

    /**
     * Creates a runner with the given number of worker threads and a {@link TaskCreator}
     * inferred from the configuration's parameter space.
     *
     * @param maxConcurrentTasks number of worker threads; must be &gt; 0
     * @param optimizationConfiguration the optimization configuration
     * @throws IllegalArgumentException if {@code maxConcurrentTasks <= 0}
     * @throws IllegalStateException if no default TaskCreator can be inferred
     */
    public LocalOptimizationRunner(int maxConcurrentTasks, OptimizationConfiguration optimizationConfiguration) {
        this(maxConcurrentTasks, optimizationConfiguration, null);
    }

    /**
     * Creates a runner with the given number of worker threads and task creator, then calls
     * {@code init()} (inherited from {@link BaseOptimizationRunner}).
     *
     * @param maxConcurrentTasks number of worker threads; must be &gt; 0
     * @param optimizationConfiguration the optimization configuration
     * @param taskCreator the task creator to use, or {@code null} to infer one via
     *     {@link TaskCreatorProvider#defaultTaskCreatorFor} from the parameter space class
     * @throws IllegalArgumentException if {@code maxConcurrentTasks <= 0}
     * @throws IllegalStateException if {@code taskCreator} is null and no default can be inferred
     */
    public LocalOptimizationRunner(int maxConcurrentTasks, OptimizationConfiguration optimizationConfiguration,
                                   TaskCreator taskCreator) {
        super(optimizationConfiguration);
        if (maxConcurrentTasks <= 0) {
            throw new IllegalArgumentException("maxConcurrentTasks must be > 0 (got: " + maxConcurrentTasks + ")");
        }
        this.maxConcurrentTasks = maxConcurrentTasks;

        if (taskCreator == null) {
            Class<?> spaceClass = optimizationConfiguration.getCandidateGenerator().getParameterSpace().getClass();
            taskCreator = TaskCreatorProvider.defaultTaskCreatorFor(spaceClass);
            if (taskCreator == null) {
                throw new IllegalStateException("No TaskCreator was provided and a default TaskCreator cannot be inferred for ParameterSpace class " + spaceClass.getName() + ". Please provide a TaskCreator via the LocalOptimizationRunner constructor");
            }
        }
        this.taskCreator = taskCreator;

        this.executor = MoreExecutors.listeningDecorator(
                Executors.newFixedThreadPool(maxConcurrentTasks, new ThreadFactory() {
                    // One delegate for the whole pool instead of a fresh factory per thread.
                    private final ThreadFactory delegate = Executors.defaultThreadFactory();
                    private final AtomicLong counter = new AtomicLong(0);

                    @Override
                    public Thread newThread(Runnable runnable) {
                        Thread thread = delegate.newThread(runnable);
                        // Daemon threads so an unfinished search does not keep the JVM alive.
                        thread.setDaemon(true);
                        thread.setName("LocalCandidateExecutor-" + counter.getAndIncrement());
                        return thread;
                    }
                }));
        init();
    }

    /** @return the number of worker threads in this runner's pool */
    @Override
    protected int maxConcurrentTasks() {
        return maxConcurrentTasks;
    }

    /** Submits a single candidate; delegates to the list-based overload. */
    @Override
    protected ListenableFuture<OptimizationResult> execute(Candidate candidate, DataProvider dataProvider,
                                                           ScoreFunction scoreFunction) {
        return execute(Collections.singletonList(candidate), dataProvider, scoreFunction).get(0);
    }

    /**
     * Submits one task per candidate to the local executor.
     *
     * @return one future per candidate, in the same order as {@code candidates}
     */
    @Override
    protected List<ListenableFuture<OptimizationResult>> execute(List<Candidate> candidates, DataProvider dataProvider,
                                                                 ScoreFunction scoreFunction) {
        List<ListenableFuture<OptimizationResult>> futures = new ArrayList<>(candidates.size());
        for (Candidate candidate : candidates) {
            futures.add(executor.submit(
                    taskCreator.create(candidate, dataProvider, scoreFunction, statusListeners, this)));
        }
        return futures;
    }

    /** Submits a single candidate; delegates to the list-based overload. */
    @Override
    protected ListenableFuture<OptimizationResult> execute(Candidate candidate, Class<? extends DataSource> dataSource,
                                                           Properties dataSourceProperties, ScoreFunction scoreFunction) {
        return execute(Collections.singletonList(candidate), dataSource, dataSourceProperties, scoreFunction).get(0);
    }

    /**
     * Submits one task per candidate to the local executor, using a {@link DataSource} class
     * and its properties rather than a {@link DataProvider}.
     *
     * @return one future per candidate, in the same order as {@code candidates}
     */
    @Override
    protected List<ListenableFuture<OptimizationResult>> execute(List<Candidate> candidates,
                                                                 Class<? extends DataSource> dataSource,
                                                                 Properties dataSourceProperties,
                                                                 ScoreFunction scoreFunction) {
        List<ListenableFuture<OptimizationResult>> futures = new ArrayList<>(candidates.size());
        for (Candidate candidate : candidates) {
            futures.add(executor.submit(taskCreator.create(candidate, dataSource, dataSourceProperties,
                    scoreFunction, statusListeners, this)));
        }
        return futures;
    }

    /**
     * Shuts down the executor.
     *
     * @param awaitTermination if {@code false}, calls {@code shutdownNow()} and returns
     *     immediately; if {@code true}, stops accepting new tasks and blocks for up to
     *     {@link #setShutdownMaxWaitMS shutdownMaxWaitMS} milliseconds. If that timeout elapses,
     *     this method returns even though tasks may still be running.
     * @throws RuntimeException wrapping {@link InterruptedException} if interrupted while
     *     waiting; the thread's interrupt status is restored first
     */
    @Override
    public void shutdown(boolean awaitTermination) {
        if (!awaitTermination) {
            executor.shutdownNow();
            return;
        }
        try {
            executor.shutdown();
            executor.awaitTermination(shutdownMaxWaitMS, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            // Preserve the interruption so code further up the stack can still observe it.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    /**
     * Sets the maximum time later {@code shutdown(true)} calls will wait for running tasks.
     *
     * @param shutdownMaxWaitMS wait limit in milliseconds (default: two days)
     */
    public void setShutdownMaxWaitMS(long shutdownMaxWaitMS) {
        this.shutdownMaxWaitMS = shutdownMaxWaitMS;
    }
}
