package org.deeplearning4j.nn.params;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.Distributions;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/deeplearning4j/nn/params/VariationalAutoencoderParamInitializer.class */
/**
 * Parameter initializer for {@link VariationalAutoencoder} layers.
 *
 * <p>All parameters live in a single flattened row vector. It is a sequence of dense
 * (weight + bias) layers, stored in this order:
 * <ol>
 *   <li>encoder layers {@code e0 .. e(k-1)}: the first takes {@code nIn} inputs, each later one
 *       takes the previous encoder size;</li>
 *   <li>the {@code pZXMean} head: last encoder size to {@code nOut};</li>
 *   <li>the {@code pZXLogStd2} head: last encoder size to {@code nOut};</li>
 *   <li>decoder layers {@code d0 .. d(m-1)}: the first takes {@code nOut} inputs, each later one
 *       takes the previous decoder size;</li>
 *   <li>the {@code pXZ} output layer: last decoder size to
 *       {@code outputDistribution.distributionInputSize(nIn)}.</li>
 * </ol>
 * Within each dense layer, the {@code nIn * nOut} weights come first, followed by the
 * {@code nOut} biases.
 */
public class VariationalAutoencoderParamInitializer extends DefaultParamInitializer {
    private static final VariationalAutoencoderParamInitializer INSTANCE = new VariationalAutoencoderParamInitializer();
    public static final String WEIGHT_KEY_SUFFIX = "W";
    public static final String BIAS_KEY_SUFFIX = "b";
    /** Key prefix for encoder layers; keys are {@code "e" + index + suffix}. */
    public static final String ENCODER_PREFIX = "e";
    /** Key prefix for decoder layers; keys are {@code "d" + index + suffix}. */
    public static final String DECODER_PREFIX = "d";
    public static final String PZX_PREFIX = "pZX";
    public static final String PZX_MEAN_PREFIX = "pZXMean";
    public static final String PZX_LOGSTD2_PREFIX = "pZXLogStd2";
    public static final String PZX_MEAN_W = "pZXMeanW";
    public static final String PZX_MEAN_B = "pZXMeanb";
    public static final String PZX_LOGSTD2_W = "pZXLogStd2W";
    public static final String PZX_LOGSTD2_B = "pZXLogStd2b";
    public static final String PXZ_PREFIX = "pXZ";
    public static final String PXZ_W = "pXZW";
    public static final String PXZ_B = "pXZb";

    /** Returns the shared initializer instance. */
    public static VariationalAutoencoderParamInitializer getInstance() {
        return INSTANCE;
    }

    /**
     * Counts the parameters of the layer: {@code (nIn + 1) * nOut} for every dense layer in the
     * flattened layout.
     *
     * @param conf configuration whose layer is a {@link VariationalAutoencoder}
     * @return total number of parameters
     * @throws IllegalArgumentException if the encoder or decoder layer-size array is null or empty
     */
    @Override
    public int numParams(NeuralNetConfiguration conf) {
        int total = 0;
        for (DenseLayout dense : layout((VariationalAutoencoder) conf.getLayer())) {
            total += dense.paramCount();
        }
        return total;
    }

    /**
     * Splits {@code paramsView} into per-layer weight and bias views, optionally initializing
     * them, and registers every parameter key as a variable on {@code conf}.
     *
     * @param conf             configuration whose layer is a {@link VariationalAutoencoder}
     * @param paramsView       flattened parameter view; slices are taken from row 0
     * @param initializeParams if true, weights are initialized with the layer's weight-init
     *                         scheme and distribution, and biases are set to 0.0
     * @return map from parameter key to view, in flattened storage order
     * @throws IllegalArgumentException if {@code paramsView.length()} differs from
     *                                  {@link #numParams(NeuralNetConfiguration)}, or if the
     *                                  encoder/decoder layer sizes are null or empty
     */
    @Override
    public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) {
        int expectedLength = numParams(conf);
        if (paramsView.length() != expectedLength) {
            throw new IllegalArgumentException("Incorrect paramsView length: Expected length " + expectedLength + ", got length " + paramsView.length());
        }
        VariationalAutoencoder layer = (VariationalAutoencoder) conf.getLayer();
        WeightInit weightInit = layer.getWeightInit();
        Distribution dist = Distributions.createDistribution(layer.getDist());
        return mapDenseViews(conf, paramsView, weightInit, dist, initializeParams, true);
    }

    /**
     * Splits a flattened gradient view into per-layer weight and bias views. The keys and
     * shapes are identical to those produced by {@link #init} with
     * {@code initializeParams == false}.
     *
     * @param conf         configuration whose layer is a {@link VariationalAutoencoder}
     * @param gradientView flattened gradient view; slices are taken from row 0
     * @return map from parameter key to gradient view, in flattened storage order
     * @throws IllegalArgumentException if the encoder/decoder layer sizes are null or empty
     */
    @Override
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
        // Reuse the exact view-construction path used for parameters. Gradient views are then
        // guaranteed to have the same shape and ordering as the parameters they belong to.
        return mapDenseViews(conf, gradientView, null, null, false, false);
    }

    /**
     * Walks the flattened layout and creates the weight and bias views for each dense layer.
     *
     * @param registerVariables whether to call {@code conf.addVariable} for each key (the
     *                          parameter path does this; the gradient path does not)
     */
    private Map<String, INDArray> mapDenseViews(NeuralNetConfiguration conf, INDArray flatView, WeightInit weightInit,
                    Distribution dist, boolean initializeParams, boolean registerVariables) {
        VariationalAutoencoder layer = (VariationalAutoencoder) conf.getLayer();
        Map<String, INDArray> views = new LinkedHashMap<>();
        int offset = 0;
        for (DenseLayout dense : layout(layer)) {
            int weightCount = dense.weightCount();
            INDArray weightView = rowSlice(flatView, offset, weightCount);
            offset += weightCount;
            INDArray biasView = rowSlice(flatView, offset, dense.nOut);
            offset += dense.nOut;

            INDArray weights = createWeightMatrix(dense.nIn, dense.nOut, weightInit, dist, weightView, initializeParams);
            // Bias init value is hard-coded to 0.0; it is not read from the layer configuration.
            INDArray bias = createBias(dense.nOut, 0.0d, biasView, initializeParams);

            views.put(dense.weightKey, weights);
            views.put(dense.biasKey, bias);
            if (registerVariables) {
                conf.addVariable(dense.weightKey);
                conf.addVariable(dense.biasKey);
            }
        }
        return views;
    }

    /** Returns columns {@code [from, from + length)} of row 0 of {@code rowVector}. */
    private static INDArray rowSlice(INDArray rowVector, int from, int length) {
        return rowVector.get(NDArrayIndex.point(0), NDArrayIndex.interval(from, from + length));
    }

    /**
     * Builds the ordered list of dense layers that make up the flattened parameter vector.
     * This is the single source of truth for the layout used by all public methods.
     */
    private static List<DenseLayout> layout(VariationalAutoencoder layer) {
        int nIn = layer.getNIn();
        int nOut = layer.getNOut();
        int[] encoderLayerSizes = layer.getEncoderLayerSizes();
        int[] decoderLayerSizes = layer.getDecoderLayerSizes();
        requireNonEmpty(encoderLayerSizes, "encoderLayerSizes");
        requireNonEmpty(decoderLayerSizes, "decoderLayerSizes");

        List<DenseLayout> dense = new ArrayList<>(encoderLayerSizes.length + decoderLayerSizes.length + 3);

        int previous = nIn;
        for (int i = 0; i < encoderLayerSizes.length; i++) {
            dense.add(new DenseLayout(ENCODER_PREFIX + i + WEIGHT_KEY_SUFFIX, ENCODER_PREFIX + i + BIAS_KEY_SUFFIX,
                            previous, encoderLayerSizes[i]));
            previous = encoderLayerSizes[i];
        }

        // Both pZX heads read from the last encoder layer.
        dense.add(new DenseLayout(PZX_MEAN_W, PZX_MEAN_B, previous, nOut));
        dense.add(new DenseLayout(PZX_LOGSTD2_W, PZX_LOGSTD2_B, previous, nOut));

        previous = nOut;
        for (int i = 0; i < decoderLayerSizes.length; i++) {
            dense.add(new DenseLayout(DECODER_PREFIX + i + WEIGHT_KEY_SUFFIX, DECODER_PREFIX + i + BIAS_KEY_SUFFIX,
                            previous, decoderLayerSizes[i]));
            previous = decoderLayerSizes[i];
        }

        int distributionInputSize = layer.getOutputDistribution().distributionInputSize(nIn);
        dense.add(new DenseLayout(PXZ_W, PXZ_B, previous, distributionInputSize));
        return dense;
    }

    /** Fails fast with a descriptive message instead of an index error deep in the layout code. */
    private static void requireNonEmpty(int[] sizes, String name) {
        if (sizes == null || sizes.length == 0) {
            throw new IllegalArgumentException("VariationalAutoencoder " + name
                            + " must contain at least one layer size, got " + Arrays.toString(sizes));
        }
    }

    /** One dense (weight + bias) layer in the flattened parameter vector. */
    private static final class DenseLayout {
        final String weightKey;
        final String biasKey;
        final int nIn;
        final int nOut;

        DenseLayout(String weightKey, String biasKey, int nIn, int nOut) {
            this.weightKey = weightKey;
            this.biasKey = biasKey;
            this.nIn = nIn;
            this.nOut = nOut;
        }

        /** Number of weight entries ({@code nIn * nOut}). */
        int weightCount() {
            return nIn * nOut;
        }

        /** Number of weight plus bias entries ({@code (nIn + 1) * nOut}). */
        int paramCount() {
            return weightCount() + nOut;
        }
    }
}
