package org.deeplearning4j.rl4j.learning.async;

import org.deeplearning4j.rl4j.learning.HistoryProcessor;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.StepCountable;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for one worker thread of an asynchronous learner.
 *
 * <p>{@link #run()} drives the loop: it resets the network returned by {@link #getCurrent()},
 * initialises the MDP through {@link Learning#initMdp}, and then repeatedly calls
 * {@link #trainSubEpoch(Encodable, int)} until the shared {@link AsyncGlobal} reports that
 * training is complete or no longer running. Each time an epoch ends, an
 * {@link AsyncStatEntry} is appended to the {@link DataManager}.
 *
 * <p>The step and epoch counters are plain {@code int} fields without any synchronization.
 *
 * @param <O>  observation type handled by the MDP
 * @param <A>  action type handled by the MDP
 * @param <AS> action space of the MDP
 * @param <NN> neural network type used by this thread
 */
public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet> extends Thread implements StepCountable {
    private static final Logger log = LoggerFactory.getLogger(AsyncThread.class);

    /** Minimum number of steps that must elapse between two monitor recordings started by {@link #preEpoch()}. */
    private static final int MONITOR_INTERVAL_STEPS = 10000;

    private final int threadNumber;
    private IHistoryProcessor historyProcessor;
    private int stepCounter = 0;
    private int epochCounter = 0;
    // Starts one full interval in the past so the first preEpoch() is already eligible to start a recording.
    private int lastMonitor = -MONITOR_INTERVAL_STEPS;

    /**
     * Immutable statistics record built by {@link AsyncThread#run()} at the end of every epoch
     * and handed to {@link DataManager#appendStat}.
     */
    public static final class AsyncStatEntry implements DataManager.StatEntry {
        private final int stepCounter;
        private final int epochCounter;
        private final double reward;
        private final int episodeLength;
        private final double score;

        /**
         * @param stepCounter   step counter of the thread when the epoch ended
         * @param epochCounter  epoch counter of the thread when the epoch ended
         * @param reward        reward accumulated over the epoch
         * @param episodeLength number of steps taken during the epoch
         * @param score         score reported by the last sub-epoch of the epoch
         */
        public AsyncStatEntry(int stepCounter, int epochCounter, double reward, int episodeLength, double score) {
            this.stepCounter = stepCounter;
            this.epochCounter = epochCounter;
            this.reward = reward;
            this.episodeLength = episodeLength;
            this.score = score;
        }

        @Override // org.deeplearning4j.rl4j.util.DataManager.StatEntry
        public int getStepCounter() {
            return stepCounter;
        }

        @Override // org.deeplearning4j.rl4j.util.DataManager.StatEntry
        public int getEpochCounter() {
            return epochCounter;
        }

        @Override // org.deeplearning4j.rl4j.util.DataManager.StatEntry
        public double getReward() {
            return reward;
        }

        public int getEpisodeLength() {
            return episodeLength;
        }

        public double getScore() {
            return score;
        }

        /** Two entries are equal when all five recorded values are equal ({@code double}s compared via {@link Double#compare}). */
        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof AsyncStatEntry)) {
                return false;
            }
            AsyncStatEntry other = (AsyncStatEntry) obj;
            return stepCounter == other.stepCounter
                    && epochCounter == other.epochCounter
                    && Double.compare(reward, other.reward) == 0
                    && episodeLength == other.episodeLength
                    && Double.compare(score, other.score) == 0;
        }

        @Override
        public int hashCode() {
            int result = 1;
            result = result * 59 + stepCounter;
            result = result * 59 + epochCounter;
            result = result * 59 + hashDouble(reward);
            result = result * 59 + episodeLength;
            result = result * 59 + hashDouble(score);
            return result;
        }

        @Override
        public String toString() {
            return "AsyncThread.AsyncStatEntry(stepCounter=" + stepCounter
                    + ", epochCounter=" + epochCounter
                    + ", reward=" + reward
                    + ", episodeLength=" + episodeLength
                    + ", score=" + score + ")";
        }
    }

    /**
     * Immutable result of one call to {@link AsyncThread#trainSubEpoch(Encodable, int)}.
     *
     * @param <O> observation type
     */
    public static final class SubEpochReturn<O> {
        private final int steps;
        private final O lastObs;
        private final double reward;
        private final double score;

        /**
         * @param steps   number of steps performed during the sub-epoch
         * @param lastObs last observation; {@link AsyncThread#run()} feeds it to the next sub-epoch
         * @param reward  reward obtained during the sub-epoch
         * @param score   score reported for the sub-epoch
         */
        public SubEpochReturn(int steps, O lastObs, double reward, double score) {
            this.steps = steps;
            this.lastObs = lastObs;
            this.reward = reward;
            this.score = score;
        }

        public int getSteps() {
            return steps;
        }

        public O getLastObs() {
            return lastObs;
        }

        public double getReward() {
            return reward;
        }

        public double getScore() {
            return score;
        }

        /** Two results are equal when steps, last observation (null-safe), reward and score are all equal. */
        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof SubEpochReturn)) {
                return false;
            }
            SubEpochReturn<?> other = (SubEpochReturn<?>) obj;
            if (steps != other.steps) {
                return false;
            }
            boolean sameObs = lastObs == null ? other.lastObs == null : lastObs.equals(other.lastObs);
            return sameObs
                    && Double.compare(reward, other.reward) == 0
                    && Double.compare(score, other.score) == 0;
        }

        @Override
        public int hashCode() {
            int result = 1;
            result = result * 59 + steps;
            // 43 is the placeholder hash used for a null observation.
            result = result * 59 + (lastObs == null ? 43 : lastObs.hashCode());
            result = result * 59 + hashDouble(reward);
            result = result * 59 + hashDouble(score);
            return result;
        }

        @Override
        public String toString() {
            return "AsyncThread.SubEpochReturn(steps=" + steps
                    + ", lastObs=" + lastObs
                    + ", reward=" + reward
                    + ", score=" + score + ")";
        }
    }

    /**
     * @param asyncGlobal  shared global state; it is not stored by this base class, subclasses
     *                     expose it through {@link #getAsyncGlobal()}
     * @param threadNumber number identifying this thread in log messages and video file names
     */
    public AsyncThread(AsyncGlobal<NN> asyncGlobal, int threadNumber) {
        this.threadNumber = threadNumber;
    }

    /** Folds the 64-bit representation of a {@code double} into an {@code int} (same value as {@code Double.valueOf(value).hashCode()}). */
    private static int hashDouble(double value) {
        long bits = Double.doubleToLongBits(value);
        return (int) (bits ^ (bits >>> 32));
    }

    /** Installs a new {@link HistoryProcessor} built from the given configuration. */
    public void setHistoryProcessor(IHistoryProcessor.Configuration configuration) {
        this.historyProcessor = new HistoryProcessor(configuration);
    }

    /** Installs the given history processor; {@code null} disables monitoring in {@link #preEpoch()} and {@link #postEpoch()}. */
    public void setHistoryProcessor(IHistoryProcessor historyProcessor) {
        this.historyProcessor = historyProcessor;
    }

    /** Stops the history processor's monitor, if a history processor is set. */
    protected void postEpoch() {
        IHistoryProcessor processor = getHistoryProcessor();
        if (processor != null) {
            processor.stopMonitor();
        }
    }

    /**
     * Starts a monitor recording when at least {@link #MONITOR_INTERVAL_STEPS} steps have passed
     * since the previous one, a history processor is set, and the data manager saves data.
     * The video file name encodes the thread number, epoch counter and step counter.
     */
    protected void preEpoch() {
        boolean intervalElapsed = getStepCounter() - lastMonitor >= MONITOR_INTERVAL_STEPS;
        if (!intervalElapsed || getHistoryProcessor() == null || !getDataManager().isSaveData()) {
            return;
        }
        lastMonitor = getStepCounter();
        String videoPath = getDataManager().getVideoDir() + "/video-" + threadNumber + "-" + getEpochCounter() + "-" + getStepCounter() + ".mp4";
        getHistoryProcessor().startMonitor(videoPath, getMdp().getObservationSpace().getShape());
    }

    /**
     * Runs the training loop of this thread.
     *
     * <p>Sub-epochs are trained until the shared {@link AsyncGlobal} reports completion or is no
     * longer running. An epoch ends when its step count reaches
     * {@code getConf().getMaxEpochStep()} or the MDP is done; the epoch statistics are then
     * recorded and the network and MDP are re-initialised.
     *
     * <p>If an exception escapes the loop it is logged with its stack trace and the shared
     * global state is flagged as not running. {@link #postEpoch()} is always invoked exactly once
     * when the thread leaves this method.
     */
    @Override // java.lang.Thread, java.lang.Runnable
    public void run() {
        try {
            log.info("ThreadNum-{} Started!", threadNumber);
            Learning.InitMdp<O> initMdp = resetAndInitMdp();
            O obs = initMdp.getLastObs();
            double rewards = initMdp.getReward();
            int length = initMdp.getSteps();
            preEpoch();

            while (!getAsyncGlobal().isTrainingComplete() && getAsyncGlobal().isRunning()) {
                // Never train past the end of the current epoch.
                int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - length);
                SubEpochReturn<O> subEpoch = trainSubEpoch(obs, maxSteps);
                obs = subEpoch.getLastObs();
                stepCounter += subEpoch.getSteps();
                length += subEpoch.getSteps();
                rewards += subEpoch.getReward();
                double score = subEpoch.getScore();

                if (length >= getConf().getMaxEpochStep() || getMdp().isDone()) {
                    postEpoch();
                    AsyncStatEntry stat = new AsyncStatEntry(getStepCounter(), epochCounter, rewards, length, score);
                    getDataManager().appendStat(stat);
                    log.info("ThreadNum-{} Epoch: {}, reward: {}", threadNumber, getEpochCounter(), stat.getReward());

                    initMdp = resetAndInitMdp();
                    obs = initMdp.getLastObs();
                    rewards = initMdp.getReward();
                    length = initMdp.getSteps();
                    epochCounter++;
                    preEpoch();
                }
            }
        } catch (Exception e) {
            // Pass the exception itself as the last argument so the logger records the full stack trace.
            log.error("Thread crashed: {}", e.toString(), e);
            getAsyncGlobal().setRunning(false);
        } finally {
            // Single cleanup point for both normal termination and crashes.
            postEpoch();
        }
    }

    /** Resets the current network and re-initialises the MDP, returning the initial MDP state. */
    private Learning.InitMdp<O> resetAndInitMdp() {
        getCurrent().reset();
        return Learning.initMdp(getMdp(), historyProcessor);
    }

    /** @return the network this thread currently works with; it is reset at the start of every epoch */
    protected abstract NN getCurrent();

    /** @return the number of this thread, as provided by the subclass */
    protected abstract int getThreadNumber();

    /** @return the shared global state consulted for the {@code isTrainingComplete}/{@code isRunning} flags */
    public abstract AsyncGlobal<NN> getAsyncGlobal();

    /** @return the MDP this thread interacts with */
    public abstract MDP<O, A, AS> getMdp();

    /** @return the configuration supplying {@code getNstep()} and {@code getMaxEpochStep()} */
    public abstract AsyncConfiguration getConf();

    /** @return the data manager that receives epoch statistics and provides the video directory */
    protected abstract DataManager getDataManager();

    /**
     * @param net network to obtain a policy for
     * @return the policy supplied by the subclass for the given network
     */
    public abstract Policy<O, A> getPolicy(NN net);

    /**
     * Trains one sub-epoch.
     *
     * @param obs    observation to start from (the last observation of the previous sub-epoch, or the initial MDP observation)
     * @param nstep  upper bound on the number of steps, as computed by {@link #run()}
     * @return steps taken, last observation, reward and score of the sub-epoch
     */
    protected abstract SubEpochReturn<O> trainSubEpoch(O obs, int nstep);

    @Override // org.deeplearning4j.rl4j.learning.StepCountable
    public int getStepCounter() {
        return stepCounter;
    }

    public void setStepCounter(int stepCounter) {
        this.stepCounter = stepCounter;
    }

    public int getEpochCounter() {
        return epochCounter;
    }

    public void setEpochCounter(int epochCounter) {
        this.epochCounter = epochCounter;
    }

    /** @return the history processor in use, or {@code null} if none has been set */
    public IHistoryProcessor getHistoryProcessor() {
        return historyProcessor;
    }

    /** @return the step counter value at which {@link #preEpoch()} last started a monitor recording */
    public int getLastMonitor() {
        return lastMonitor;
    }
}
