package org.deeplearning4j.util;

import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.MultiLayerUpdater;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.schedule.ISchedule;

/* loaded from: input_file:org/deeplearning4j/util/NetworkUtils.class */
/**
 * Static utilities for already-constructed networks: converting a {@link MultiLayerNetwork} into an
 * equivalent chain-shaped {@link ComputationGraph}, and changing the learning rate (a fixed value or an
 * {@link ISchedule}) of one layer or of all layers after construction.
 *
 * <p>This class has only static members and cannot be instantiated.
 */
public class NetworkUtils {

    /** Name of the single graph input created by {@link #toComputationGraph(MultiLayerNetwork)}. */
    private static final String INPUT_NAME = "in";

    private NetworkUtils() {
        // static utility holder; never instantiated
    }

    /**
     * Converts a {@link MultiLayerNetwork} into a {@link ComputationGraph} containing the same layers.
     *
     * <p>The graph has one input, {@code "in"}, followed by a linear chain of layers named {@code "0"},
     * {@code "1"}, ... (one per layer configuration, taken from a clone of the network's layer-wise
     * configuration), each using the input preprocessor registered for that index. The last layer of
     * the chain is the graph output. The graph is initialised and receives the network's parameters;
     * if the network's updater returns a non-null state view array, that state is copied into the
     * graph's updater as well.
     *
     * <p>NOTE(review): copying {@code params()} and the updater state as flat arrays assumes the graph
     * lays them out in the same order as the network; nothing here checks that.
     *
     * @param multiLayerNetwork the network to convert
     * @return a new, initialised {@link ComputationGraph}
     */
    public static ComputationGraph toComputationGraph(MultiLayerNetwork multiLayerNetwork) {
        ComputationGraph graph = new ComputationGraph(buildChainConfiguration(multiLayerNetwork));
        graph.init();
        graph.setParams(multiLayerNetwork.params());

        INDArray sourceUpdaterState = multiLayerNetwork.getUpdater().getStateViewArray();
        if (sourceUpdaterState == null) {
            // No updater state array to transfer.
            return graph;
        }
        graph.getUpdater().getUpdaterStateViewArray().assign(sourceUpdaterState);
        return graph;
    }

    /**
     * Builds the chain-shaped graph configuration used by {@link #toComputationGraph}: the input
     * {@code "in"} feeds layer {@code "0"}, which feeds {@code "1"}, and so on; the last layer is the
     * output.
     */
    private static ComputationGraphConfiguration buildChainConfiguration(MultiLayerNetwork network) {
        ComputationGraphConfiguration.GraphBuilder builder = new NeuralNetConfiguration.Builder().graphBuilder();
        MultiLayerConfiguration copiedConf = network.getLayerWiseConfigurations().m30clone();

        builder.addInputs(INPUT_NAME);
        String upstream = INPUT_NAME;
        int layerIndex = 0;
        for (NeuralNetConfiguration layerConf : copiedConf.getConfs()) {
            String layerName = Integer.toString(layerIndex);
            builder.addLayer(layerName, layerConf.getLayer(), copiedConf.getInputPreProcess(layerIndex), upstream);
            upstream = layerName;
            layerIndex++;
        }
        builder.setOutputs(upstream);
        return builder.build();
    }

    /**
     * Sets a fixed learning rate (with a {@code null} schedule) on every layer whose configuration is a
     * {@link BaseLayer} with a non-null {@link IUpdater} that reports having a learning rate, then
     * rebuilds the network's updater once, carrying over its previous state array.
     *
     * @param multiLayerNetwork the network to modify
     * @param d                 the new fixed learning rate
     */
    public static void setLearningRate(MultiLayerNetwork multiLayerNetwork, double d) {
        setLearningRate(multiLayerNetwork, d, (ISchedule) null);
    }

    /**
     * Sets a learning-rate schedule (with the fixed rate passed as {@code NaN}) on every eligible layer,
     * then rebuilds the network's updater once, carrying over its previous state array.
     *
     * @param multiLayerNetwork the network to modify
     * @param iSchedule         the new learning-rate schedule
     */
    public static void setLearningRate(MultiLayerNetwork multiLayerNetwork, ISchedule iSchedule) {
        setLearningRate(multiLayerNetwork, Double.NaN, iSchedule);
    }

    /** Applies the rate/schedule to layers {@code 0..getnLayers()-1}, then refreshes the updater once. */
    private static void setLearningRate(MultiLayerNetwork network, double newLr, ISchedule newSchedule) {
        int layerCount = network.getnLayers();
        for (int layerIdx = 0; layerIdx < layerCount; layerIdx++) {
            setLearningRate(network, layerIdx, newLr, newSchedule, false);
        }
        refreshUpdater(network);
    }

    /**
     * Sets a fixed learning rate (with a {@code null} schedule) on a single layer. If that layer's
     * configuration is a {@link BaseLayer}, the network's updater is rebuilt afterwards, even when the
     * layer's updater had no learning rate to change.
     *
     * @param multiLayerNetwork the network to modify
     * @param i                 index of the layer, as accepted by {@code getLayer(int)}
     * @param d                 the new fixed learning rate
     */
    public static void setLearningRate(MultiLayerNetwork multiLayerNetwork, int i, double d) {
        setLearningRate(multiLayerNetwork, i, d, (ISchedule) null, true);
    }

    /**
     * Sets a learning-rate schedule on a single layer. If that layer's configuration is a
     * {@link BaseLayer}, the network's updater is rebuilt afterwards.
     *
     * @param multiLayerNetwork the network to modify
     * @param i                 index of the layer, as accepted by {@code getLayer(int)}
     * @param iSchedule         the new learning-rate schedule
     */
    public static void setLearningRate(MultiLayerNetwork multiLayerNetwork, int i, ISchedule iSchedule) {
        setLearningRate(multiLayerNetwork, i, Double.NaN, iSchedule, true);
    }

    /** Updates one layer's configuration; refreshes the updater only if it was a BaseLayer and asked to. */
    private static void setLearningRate(MultiLayerNetwork network, int layerIdx, double newLr,
                                        ISchedule newSchedule, boolean refreshAfter) {
        Layer layerConf = network.getLayer(layerIdx).conf().getLayer();
        boolean isBaseLayer = applyLearningRate(layerConf, newLr, newSchedule);
        if (isBaseLayer && refreshAfter) {
            refreshUpdater(network);
        }
    }

    /**
     * Discards the network's updater and fetches it again via {@code getUpdater()}, copying the previous
     * state view array into the new {@link MultiLayerUpdater}.
     *
     * <p>NOTE(review): presumably {@code getUpdater()} lazily recreates the updater from the (modified)
     * layer configurations, and the copied state assumes the new updater's state layout matches the
     * old one. Confirm both.
     */
    private static void refreshUpdater(MultiLayerNetwork network) {
        INDArray savedState = network.getUpdater().getStateViewArray();
        network.setUpdater(null);
        MultiLayerUpdater rebuilt = (MultiLayerUpdater) network.getUpdater();
        rebuilt.setStateViewArray(savedState);
    }

    /**
     * Sets a fixed learning rate (with a {@code null} schedule) on every eligible layer of the graph, then
     * rebuilds the graph's updater once, carrying over its previous state array.
     *
     * @param computationGraph the graph to modify
     * @param d                the new fixed learning rate
     */
    public static void setLearningRate(ComputationGraph computationGraph, double d) {
        setLearningRate(computationGraph, d, (ISchedule) null);
    }

    /**
     * Sets a learning-rate schedule (with the fixed rate passed as {@code NaN}) on every eligible layer of
     * the graph, then rebuilds the graph's updater once, carrying over its previous state array.
     *
     * @param computationGraph the graph to modify
     * @param iSchedule        the new learning-rate schedule
     */
    public static void setLearningRate(ComputationGraph computationGraph, ISchedule iSchedule) {
        setLearningRate(computationGraph, Double.NaN, iSchedule);
    }

    /** Applies the rate/schedule to every layer returned by {@code getLayers()}, then refreshes once. */
    private static void setLearningRate(ComputationGraph graph, double newLr, ISchedule newSchedule) {
        for (org.deeplearning4j.nn.api.Layer runtimeLayer : graph.getLayers()) {
            String layerName = runtimeLayer.conf().getLayer().getLayerName();
            setLearningRate(graph, layerName, newLr, newSchedule, false);
        }
        refreshUpdater(graph);
    }

    /**
     * Sets a fixed learning rate (with a {@code null} schedule) on the named layer. If that layer's
     * configuration is a {@link BaseLayer}, the graph's updater is rebuilt afterwards.
     *
     * @param computationGraph the graph to modify
     * @param str              the layer name, as accepted by {@code getLayer(String)}
     * @param d                the new fixed learning rate
     */
    public static void setLearningRate(ComputationGraph computationGraph, String str, double d) {
        setLearningRate(computationGraph, str, d, (ISchedule) null, true);
    }

    /**
     * Sets a learning-rate schedule on the named layer. If that layer's configuration is a
     * {@link BaseLayer}, the graph's updater is rebuilt afterwards.
     *
     * @param computationGraph the graph to modify
     * @param str              the layer name, as accepted by {@code getLayer(String)}
     * @param iSchedule        the new learning-rate schedule
     */
    public static void setLearningRate(ComputationGraph computationGraph, String str, ISchedule iSchedule) {
        setLearningRate(computationGraph, str, Double.NaN, iSchedule, true);
    }

    /** Updates the named layer's configuration; refreshes only if it was a BaseLayer and asked to. */
    private static void setLearningRate(ComputationGraph graph, String layerName, double newLr,
                                        ISchedule newSchedule, boolean refreshAfter) {
        Layer layerConf = graph.getLayer(layerName).conf().getLayer();
        boolean isBaseLayer = applyLearningRate(layerConf, newLr, newSchedule);
        if (isBaseLayer && refreshAfter) {
            refreshUpdater(graph);
        }
    }

    /**
     * Discards the graph's updater and fetches it again via {@code getUpdater()}, copying the previous
     * state view array into the new one.
     *
     * <p>NOTE(review): presumably {@code getUpdater()} lazily recreates the updater; the copied state
     * assumes an unchanged state layout. Confirm.
     */
    private static void refreshUpdater(ComputationGraph graph) {
        INDArray savedState = graph.getUpdater().getStateViewArray();
        graph.setUpdater(null);
        graph.getUpdater().setStateViewArray(savedState);
    }

    /**
     * Writes a new learning rate or schedule into a single layer configuration.
     *
     * <p>When {@code newSchedule} is non-null, the updater receives {@code (NaN, newSchedule)}. Otherwise
     * it receives {@code (newLr, null)}. Configurations that are not {@link BaseLayer}s, or whose updater
     * is {@code null} or reports no learning rate, are left untouched.
     *
     * @param layerConf   the layer configuration to modify
     * @param newLr       the fixed learning rate to use when no schedule is given
     * @param newSchedule the schedule to use, or {@code null} for a fixed rate
     * @return {@code true} if {@code layerConf} is a {@link BaseLayer}, whether or not it was modified
     */
    private static boolean applyLearningRate(Layer layerConf, double newLr, ISchedule newSchedule) {
        if (!(layerConf instanceof BaseLayer)) {
            return false;
        }
        IUpdater updater = ((BaseLayer) layerConf).getIUpdater();
        if (updater == null || !updater.hasLearningRate()) {
            return true;
        }
        // With a schedule the fixed rate is NaN; without one, newSchedule is null and the fixed rate is used.
        updater.setLrAndSchedule(newSchedule == null ? newLr : Double.NaN, newSchedule);
        return true;
    }
}
