package org.deeplearning4j.parallelism.factory;

import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.parallelism.ParallelWrapper;
import org.deeplearning4j.parallelism.trainer.SymmetricTrainer;
import org.deeplearning4j.parallelism.trainer.Trainer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/parallelism/factory/SymmetricTrainerContext.class */
/**
 * {@link TrainerContext} that creates {@link SymmetricTrainer} worker threads for a
 * {@link ParallelWrapper}.
 *
 * <p>{@link #init} and {@link #finalizeRound} do nothing. {@link #finalizeTraining} copies the
 * parameters of the first worker model back into the original model; no averaging across
 * workers is performed.
 */
public class SymmetricTrainerContext implements TrainerContext {
    private static final Logger log = LoggerFactory.getLogger(SymmetricTrainerContext.class);

    /**
     * No-op: this context requires no initialization.
     *
     * @param model ignored
     * @param args  ignored
     */
    @Override
    public void init(Model model, Object... args) {
        // Intentionally empty: nothing to prepare for symmetric training.
    }

    /**
     * Creates a daemon {@link SymmetricTrainer} named {@code "SymmetricTrainer thread <threadId>"}.
     * The trainer is configured but not started by this method.
     *
     * @param uuid               identifier string forwarded to the trainer
     * @param threadId           worker index; used in the thread name and forwarded to the trainer
     * @param model              model handed to the trainer
     * @param rootDevice         not used by this implementation
     * @param useMDS             flag forwarded to the trainer (presumably selects MultiDataSet
     *                           input; confirm against {@link SymmetricTrainer})
     * @param wrapper            owning wrapper, forwarded to the trainer
     * @param workspaceMode      workspace mode forwarded to the trainer
     * @param averagingFrequency not used by this implementation
     * @return the newly created trainer
     */
    @Override
    public Trainer create(String uuid, int threadId, Model model, int rootDevice, boolean useMDS,
                    ParallelWrapper wrapper, WorkspaceMode workspaceMode, int averagingFrequency) {
        SymmetricTrainer trainer = new SymmetricTrainer(model, uuid, threadId, workspaceMode, wrapper, useMDS);
        trainer.setName("SymmetricTrainer thread " + threadId);
        // Daemon so that lingering worker threads do not keep the JVM alive.
        trainer.setDaemon(true);
        return trainer;
    }

    /**
     * No-op: nothing is synchronized between rounds in this context.
     *
     * @param originalModel ignored
     * @param models        ignored
     */
    @Override
    public void finalizeRound(Model originalModel, Model... models) {
        // Intentionally empty.
    }

    /**
     * Copies the parameters of the first worker model into {@code originalModel}.
     * Parameters of any additional worker models are ignored.
     *
     * @param originalModel model receiving the final parameters
     * @param models        worker models; must contain at least one element
     * @throws IllegalArgumentException if {@code models} is null or empty
     */
    @Override
    public void finalizeTraining(Model originalModel, Model... models) {
        // Fail with a descriptive error instead of an ArrayIndexOutOfBoundsException / NPE.
        if (models == null || models.length == 0) {
            throw new IllegalArgumentException("finalizeTraining requires at least one worker model");
        }
        originalModel.setParams(models[0].params());
    }
}
