package org.deeplearning4j.nn.params;

import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.Distributions;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/deeplearning4j/nn/params/DefaultParamInitializer.class */
/**
 * Parameter initializer for {@link FeedForwardLayer}s.
 *
 * <p>The layer's parameters live in a single flattened view laid out as {@code nIn * nOut}
 * weight entries followed by {@code nOut} bias entries. This class slices that view into the
 * weight ({@link #WEIGHT_KEY}) and bias ({@link #BIAS_KEY}) parameters. It holds no mutable
 * state; {@link #getInstance()} returns a shared instance.
 */
public class DefaultParamInitializer implements ParamInitializer {
    private static final DefaultParamInitializer INSTANCE = new DefaultParamInitializer();

    /** Key under which the weight parameter (and its gradient) is stored. */
    public static final String WEIGHT_KEY = "W";

    /** Key under which the bias parameter (and its gradient) is stored. */
    public static final String BIAS_KEY = "b";

    /** Returns the shared instance of this initializer. */
    public static DefaultParamInitializer getInstance() {
        return INSTANCE;
    }

    /**
     * Returns the number of parameters of the configured layer: {@code nIn * nOut} weights plus
     * {@code nOut} biases.
     *
     * @param neuralNetConfiguration configuration whose layer must be a {@link FeedForwardLayer}
     * @return total parameter count
     * @throws IllegalArgumentException if the layer is not a {@link FeedForwardLayer}, or if the
     *     count does not fit in an {@code int}
     */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public int numParams(NeuralNetConfiguration neuralNetConfiguration) {
        FeedForwardLayer layer = asFeedForwardLayer(neuralNetConfiguration);
        // Computed in long so that huge layers are rejected instead of silently wrapping around.
        long total = (long) layer.getNIn() * layer.getNOut() + layer.getNOut();
        return checkedCount(total, layer);
    }

    /**
     * Splits {@code paramsView} into weight and bias views and optionally initializes them.
     *
     * <p>Both slices are taken from row 0 of {@code paramsView}. The view is presumably a
     * {@code 1 x numParams} row vector; confirm against callers. The returned map is a
     * synchronized, insertion-ordered map containing {@link #WEIGHT_KEY} then {@link #BIAS_KEY}.
     * Per {@link Collections#synchronizedMap}, callers must synchronize on it manually while
     * iterating. Both keys are also registered as variables on the configuration.
     *
     * @param neuralNetConfiguration configuration whose layer must be a {@link FeedForwardLayer}
     * @param iNDArray flattened parameter view of length {@link #numParams}
     * @param z whether to initialize the parameter values (otherwise the views are only wrapped)
     * @return map from parameter key to parameter view
     * @throws IllegalArgumentException if the layer is not a {@link FeedForwardLayer}
     * @throws IllegalStateException if the view length does not equal {@link #numParams}
     */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public Map<String, INDArray> init(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray, boolean z) {
        FeedForwardLayer layer = asFeedForwardLayer(neuralNetConfiguration);
        int numParams = numParams(neuralNetConfiguration);
        if (iNDArray.length() != numParams) {
            throw new IllegalStateException("Expected params view of length " + numParams + ", got length " + iNDArray.length());
        }
        int weightCount = weightCount(layer);
        INDArray weightView = rowSlice(iNDArray, 0, weightCount);
        INDArray biasView = rowSlice(iNDArray, weightCount, weightCount + layer.getNOut());

        Map<String, INDArray> params = Collections.synchronizedMap(new LinkedHashMap<String, INDArray>());
        params.put(WEIGHT_KEY, createWeightMatrix(neuralNetConfiguration, weightView, z));
        params.put(BIAS_KEY, createBias(neuralNetConfiguration, biasView, z));
        neuralNetConfiguration.addVariable(WEIGHT_KEY);
        neuralNetConfiguration.addVariable(BIAS_KEY);
        return params;
    }

    /**
     * Splits a flattened gradient view into weight and bias gradient views.
     *
     * <p>The weight gradient is reshaped in 'f' (column-major) order to {@code [nIn, nOut]}. The
     * bias gradient is returned as the raw row-0 slice, without a reshape. The returned map is a
     * plain (unsynchronized) insertion-ordered map.
     *
     * @param neuralNetConfiguration configuration whose layer must be a {@link FeedForwardLayer}
     * @param iNDArray flattened gradient view, laid out like the parameter view
     * @return map from parameter key to gradient view
     * @throws IllegalArgumentException if the layer is not a {@link FeedForwardLayer}, or if the
     *     weight count does not fit in an {@code int}
     */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray) {
        FeedForwardLayer layer = asFeedForwardLayer(neuralNetConfiguration);
        int nIn = layer.getNIn();
        int nOut = layer.getNOut();
        int weightCount = weightCount(layer);
        INDArray weightGradient = rowSlice(iNDArray, 0, weightCount).reshape('f', nIn, nOut);
        INDArray biasGradient = rowSlice(iNDArray, weightCount, weightCount + nOut);

        Map<String, INDArray> gradients = new LinkedHashMap<String, INDArray>();
        gradients.put(WEIGHT_KEY, weightGradient);
        gradients.put(BIAS_KEY, biasGradient);
        return gradients;
    }

    /**
     * Creates the bias parameter from the layer's {@code nOut} and {@code biasInit} settings.
     * NOTE(review): the decompiler reported this was originally {@code protected}.
     *
     * @param neuralNetConfiguration configuration whose layer must be a {@link FeedForwardLayer}
     * @param iNDArray bias view to (optionally) fill
     * @param z whether to fill the view with the layer's bias init value
     * @return the same view that was passed in
     */
    public INDArray createBias(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray, boolean z) {
        FeedForwardLayer layer = asFeedForwardLayer(neuralNetConfiguration);
        return createBias(layer.getNOut(), layer.getBiasInit(), iNDArray, z);
    }

    /**
     * Fills {@code iNDArray} with the constant {@code d} when {@code z} is true. Otherwise the
     * view is left untouched. NOTE(review): the decompiler reported this was originally
     * {@code protected}.
     *
     * @param i number of bias entries (length of the constant array assigned into the view)
     * @param d constant bias value
     * @param iNDArray bias view
     * @param z whether to assign the constant
     * @return the same view that was passed in
     */
    public INDArray createBias(int i, double d, INDArray iNDArray, boolean z) {
        if (z) {
            iNDArray.assign(Nd4j.valueArrayOf(i, d));
        }
        return iNDArray;
    }

    /**
     * Creates the weight parameter for the configured layer. When {@code z} is false, no weight
     * init scheme or distribution is resolved; the view is only reshaped. NOTE(review): the
     * decompiler reported this was originally {@code protected}.
     *
     * @param neuralNetConfiguration configuration whose layer must be a {@link FeedForwardLayer}
     * @param iNDArray flattened weight view of {@code nIn * nOut} entries
     * @param z whether to initialize the weight values
     * @return the weight view as produced by {@link WeightInitUtil}
     */
    public INDArray createWeightMatrix(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray, boolean z) {
        FeedForwardLayer layer = asFeedForwardLayer(neuralNetConfiguration);
        if (!z) {
            return createWeightMatrix(layer.getNIn(), layer.getNOut(), null, null, iNDArray, false);
        }
        Distribution distribution = Distributions.createDistribution(layer.getDist());
        return createWeightMatrix(layer.getNIn(), layer.getNOut(), layer.getWeightInit(), distribution, iNDArray, true);
    }

    /**
     * Produces the weight matrix of shape {@code [i, i2]} over {@code iNDArray}. When {@code z}
     * is true it delegates to {@link WeightInitUtil#initWeights}; otherwise it delegates to
     * {@link WeightInitUtil#reshapeWeights}. NOTE(review): the decompiler reported this was
     * originally {@code protected}.
     *
     * @param i nIn (rows)
     * @param i2 nOut (columns)
     * @param weightInit weight init scheme; may be null when {@code z} is false
     * @param distribution distribution for the init scheme; may be null when {@code z} is false
     * @param iNDArray flattened weight view
     * @param z whether to initialize the values or only reshape
     * @return the weight matrix view
     */
    public INDArray createWeightMatrix(int i, int i2, WeightInit weightInit, Distribution distribution, INDArray iNDArray, boolean z) {
        int[] shape = {i, i2};
        if (z) {
            return WeightInitUtil.initWeights(i, i2, shape, weightInit, distribution, iNDArray);
        }
        return WeightInitUtil.reshapeWeights(shape, iNDArray);
    }

    /**
     * Returns the configuration's layer as a {@link FeedForwardLayer}. Throws the same
     * descriptive error {@link #init} always used, instead of an opaque ClassCastException.
     */
    private static FeedForwardLayer asFeedForwardLayer(NeuralNetConfiguration conf) {
        if (!(conf.getLayer() instanceof FeedForwardLayer)) {
            throw new IllegalArgumentException("unsupported layer type: " + conf.getLayer().getClass().getName());
        }
        return (FeedForwardLayer) conf.getLayer();
    }

    /** Returns {@code nIn * nOut}, rejecting products that overflow {@code int}. */
    private static int weightCount(FeedForwardLayer layer) {
        return checkedCount((long) layer.getNIn() * layer.getNOut(), layer);
    }

    /** Narrows a parameter count to {@code int}; int offsets are used to index the flat view. */
    private static int checkedCount(long count, FeedForwardLayer layer) {
        if (count > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Layer with nIn=" + layer.getNIn() + ", nOut=" + layer.getNOut()
                    + " requires " + count + " parameters, which exceeds Integer.MAX_VALUE");
        }
        return (int) count;
    }

    /** Returns the view of columns {@code [from, to)} in row 0 of {@code view}. */
    private static INDArray rowSlice(INDArray view, int from, int to) {
        return view.get(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(from, to)});
    }
}
