package ai.djl.training;

import ai.djl.ndarray.NDList;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.listener.TrainingListener;
import ai.djl.translate.TranslateException;
import ai.djl.util.Preconditions;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.function.Consumer;

/* loaded from: input_file:ai/djl/training/EasyTrain.class */
/**
 * Static helpers for running a training loop, a single training batch, or a single validation
 * batch with a {@link Trainer}.
 */
public final class EasyTrain {

    private EasyTrain() {}

    /**
     * Trains for a number of epochs, validating after each epoch and notifying listeners.
     *
     * @param trainer the trainer to train with
     * @param numEpoch the number of epochs to run; values {@code <= 0} run nothing
     * @param trainingDataset the dataset to train on
     * @param validateDataset the dataset to validate on after each epoch, or {@code null} to skip
     *     validation
     * @throws IOException if iterating a dataset fails
     * @throws TranslateException if iterating a dataset fails to process its input
     */
    public static void fit(
            Trainer trainer, int numEpoch, Dataset trainingDataset, Dataset validateDataset)
            throws IOException, TranslateException {
        for (int epoch = 0; epoch < numEpoch; epoch++) {
            for (Batch batch : trainer.iterateDataset(trainingDataset)) {
                // Close the batch even if training on it throws, so its arrays are not leaked.
                try {
                    trainBatch(trainer, batch);
                    trainer.step();
                } finally {
                    batch.close();
                }
            }
            evaluateDataset(trainer, validateDataset);
            trainer.notifyListeners(listener -> listener.onEpoch(trainer));
        }
    }

    /**
     * Runs forward and backward passes for one batch, split across the trainer's devices.
     *
     * <p>Does not update parameters; call {@link Trainer#step()} afterwards. When the batch is
     * split into more than one part and the trainer has an executor service, the splits run
     * concurrently and this method waits for all of them before closing the gradient collector.
     *
     * @param trainer the trainer to use
     * @param batch the batch to train on
     * @throws IllegalArgumentException if the batch and the trainer use different engines
     */
    public static void trainBatch(Trainer trainer, Batch batch) {
        checkSameEngine(trainer, batch);
        Batch[] splits = batch.split(trainer.getDevices(), false);
        TrainingListener.BatchData batchData =
                new TrainingListener.BatchData(
                        batch, new ConcurrentHashMap<>(), new ConcurrentHashMap<>());
        // The collector must stay open until every split's backward pass has finished.
        try (GradientCollector collector = trainer.newGradientCollector()) {
            forEachSplit(trainer, splits, split -> trainSplit(trainer, collector, batchData, split));
        }
        trainer.notifyListeners(listener -> listener.onTrainingBatch(trainer, batchData));
    }

    /**
     * Runs a forward pass, backward pass and metric bookkeeping for one device split.
     *
     * <p>May run on an executor thread; {@code batchData} maps are concurrent.
     */
    private static void trainSplit(
            Trainer trainer,
            GradientCollector collector,
            TrainingListener.BatchData batchData,
            Batch split) {
        NDList data = split.getData();
        NDList labels = split.getLabels();
        NDList predictions = trainer.forward(data, labels);

        long start = System.nanoTime();
        collector.backward(trainer.getLoss().evaluate(labels, predictions));
        trainer.addMetric("backward", start);

        start = System.nanoTime();
        batchData.getLabels().put(labels.get(0).getDevice(), labels);
        batchData.getPredictions().put(predictions.get(0).getDevice(), predictions);
        trainer.addMetric("training-metrics", start);
    }

    /**
     * Evaluates one batch, split across the trainer's devices, and notifies listeners.
     *
     * @param trainer the trainer to use
     * @param batch the batch to validate
     * @throws IllegalArgumentException if the batch and the trainer use different engines
     */
    public static void validateBatch(Trainer trainer, Batch batch) {
        checkSameEngine(trainer, batch);
        Batch[] splits = batch.split(trainer.getDevices(), false);
        TrainingListener.BatchData batchData =
                new TrainingListener.BatchData(
                        batch, new ConcurrentHashMap<>(), new ConcurrentHashMap<>());
        forEachSplit(trainer, splits, split -> validateSplit(trainer, batchData, split));
        trainer.notifyListeners(listener -> listener.onValidationBatch(trainer, batchData));
    }

    /** Evaluates one device split and records its labels and predictions. */
    private static void validateSplit(
            Trainer trainer, TrainingListener.BatchData batchData, Batch split) {
        NDList data = split.getData();
        NDList labels = split.getLabels();
        NDList predictions = trainer.evaluate(data);
        batchData.getLabels().put(labels.get(0).getDevice(), labels);
        batchData.getPredictions().put(predictions.get(0).getDevice(), predictions);
    }

    /**
     * Validates every batch of {@code dataset}; does nothing when {@code dataset} is null.
     *
     * @param trainer the trainer to use
     * @param dataset the dataset to validate, or {@code null}
     * @throws IOException if iterating the dataset fails
     * @throws TranslateException if iterating the dataset fails to process its input
     */
    public static void evaluateDataset(Trainer trainer, Dataset dataset)
            throws IOException, TranslateException {
        if (dataset == null) {
            return;
        }
        for (Batch batch : trainer.iterateDataset(dataset)) {
            try {
                validateBatch(trainer, batch);
            } finally {
                batch.close();
            }
        }
    }

    /** Rejects batches whose NDManager belongs to a different engine than the trainer's. */
    private static void checkSameEngine(Trainer trainer, Batch batch) {
        Preconditions.checkArgument(
                trainer.getManager().getEngine() == batch.getManager().getEngine(),
                "The data must be on the same engine as the trainer. You may need to change one of your NDManagers.");
    }

    /**
     * Applies {@code action} to every split and returns only after all of them have completed.
     *
     * <p>Splits run on the trainer's executor when there is more than one split and an executor
     * is configured; otherwise they run sequentially on the calling thread.
     */
    private static void forEachSplit(Trainer trainer, Batch[] splits, Consumer<Batch> action) {
        if (splits.length <= 1 || !trainer.getExecutorService().isPresent()) {
            for (Batch split : splits) {
                action.accept(split);
            }
            return;
        }
        ExecutorService executor = trainer.getExecutorService().get();
        List<CompletableFuture<Void>> futures = new ArrayList<>(splits.length);
        for (Batch split : splits) {
            futures.add(CompletableFuture.runAsync(() -> action.accept(split), executor));
        }
        awaitAll(futures);
    }

    /**
     * Blocks until every future completes, rethrowing the first failure's unchecked cause.
     *
     * <p>{@code allOf} only completes once every input has completed, so no split is still
     * running when this method returns or throws.
     */
    private static void awaitAll(List<CompletableFuture<Void>> futures) {
        try {
            CompletableFuture.allOf(futures.toArray(new CompletableFuture<?>[0])).join();
        } catch (CompletionException e) {
            Throwable cause = e.getCause();
            if (cause instanceof RuntimeException) {
                throw (RuntimeException) cause;
            }
            if (cause instanceof Error) {
                throw (Error) cause;
            }
            throw e;
        }
    }
}
