package org.deeplearning4j.nn.layers.variational;

import java.beans.ConstructorProperties;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.variational.CompositeReconstructionDistribution;
import org.deeplearning4j.nn.conf.layers.variational.LossFunctionWrapper;
import org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.params.VariationalAutoencoderParamInitializer;
import org.deeplearning4j.optimize.Solver;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.api.blas.Level1;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:org/deeplearning4j/nn/layers/variational/VariationalAutoencoder.class */
public class VariationalAutoencoder implements Layer {
    /** Current input array; assigned by {@code setInput} and reset to null by {@code clear}. */
    protected INDArray input;
    /** Flattened parameter array; assigned by {@code setParamsViewArray}, overwritten in place by {@code setParams}, returned by {@code params()}. */
    protected INDArray paramsFlattened;
    /** Flattened gradient array; assigned by {@code setBackpropGradientsViewArray} and wrapped in a DefaultGradient by {@code computeGradientAndScore}. */
    protected INDArray gradientsFlattened;
    /** Parameter arrays keyed by parameter name; assigned by {@code setParamTable}. */
    protected Map<String, INDArray> params;
    /** Gradient arrays keyed by parameter name; obtained from the layer's param initializer in {@code setBackpropGradientsViewArray}. */
    protected transient Map<String, INDArray> gradientViews;
    /** Configuration of this layer; assigned by the constructor and {@code setConf}. */
    protected NeuralNetConfiguration conf;
    /** Returned by {@code getOptimizer}; NOTE(review): no assignment is visible in this part of the file. */
    protected ConvexOptimizer optimizer;
    /** Result of the most recent {@code computeGradientAndScore} call; returned by {@code gradient()}. */
    protected Gradient gradient;
    /** Mask array; assigned by {@code setMaskArray} and reset to null by {@code clear}. */
    protected INDArray maskArray;
    /** NOTE(review): not referenced in the visible part of this file - presumably used by fit(); confirm. */
    protected Solver solver;
    /** Encoder layer sizes, read from the layer configuration in the constructor. */
    protected int[] encoderLayerSizes;
    /** Decoder layer sizes, read from the layer configuration in the constructor. */
    protected int[] decoderLayerSizes;
    /** Output distribution of the layer configuration; supplies the reconstruction score and gradient in {@code computeGradientAndScore}. */
    protected ReconstructionDistribution reconstructionDistribution;
    /** Activation function applied to the PZX outputs, read from the layer configuration. */
    protected IActivation pzxActivationFn;
    /** Number of iterations of the sampling loop in {@code computeGradientAndScore}, read from the layer configuration. */
    protected int numSamples;
    /** Score computed by the most recent {@code computeGradientAndScore} call; returned by {@code score()}. */
    protected double score = 0.0d;
    /** All listeners registered through {@code setListeners}. */
    protected Collection<IterationListener> iterationListeners = new ArrayList();
    /** The registered listeners that are also TrainingListener instances; null until {@code setListeners} is first called. */
    protected Collection<TrainingListener> trainingListeners = null;
    /** Layer index; accessed through {@code setIndex} / {@code getIndex}. */
    protected int index = 0;
    /** Set to true once {@code backpropGradient} has zeroed the gradient views of the pretrain-only parameters. */
    protected boolean zeroedPretrainParamGradients = false;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/deeplearning4j/nn/layers/variational/VariationalAutoencoder$VAEFwdHelper.class */
    /**
     * Mutable holder for the three results of the encoder forward pass: the per-layer
     * pre-activation arrays, the pre-activation of the PZX mean output, and the per-layer
     * activation arrays.
     *
     * <p>Equality, hashing and the string form cover all three properties; the array
     * properties are compared and hashed deeply.
     */
    public static class VAEFwdHelper {
        private INDArray[] encoderPreOuts;
        private INDArray pzxMeanPreOut;
        private INDArray[] encoderActivations;

        /**
         * @param iNDArrayArr  encoder pre-activation arrays
         * @param iNDArray     pre-activation of the PZX mean output
         * @param iNDArrayArr2 encoder activation arrays
         */
        @ConstructorProperties({"encoderPreOuts", "pzxMeanPreOut", "encoderActivations"})
        public VAEFwdHelper(INDArray[] iNDArrayArr, INDArray iNDArray, INDArray[] iNDArrayArr2) {
            encoderPreOuts = iNDArrayArr;
            pzxMeanPreOut = iNDArray;
            encoderActivations = iNDArrayArr2;
        }

        /** @return the encoder pre-activation arrays (the stored reference, not a copy) */
        public INDArray[] getEncoderPreOuts() {
            return encoderPreOuts;
        }

        /** @return the pre-activation of the PZX mean output */
        public INDArray getPzxMeanPreOut() {
            return pzxMeanPreOut;
        }

        /** @return the encoder activation arrays (the stored reference, not a copy) */
        public INDArray[] getEncoderActivations() {
            return encoderActivations;
        }

        /** @param iNDArrayArr the new encoder pre-activation arrays */
        public void setEncoderPreOuts(INDArray[] iNDArrayArr) {
            encoderPreOuts = iNDArrayArr;
        }

        /** @param iNDArray the new pre-activation of the PZX mean output */
        public void setPzxMeanPreOut(INDArray iNDArray) {
            pzxMeanPreOut = iNDArray;
        }

        /** @param iNDArrayArr the new encoder activation arrays */
        public void setEncoderActivations(INDArray[] iNDArrayArr) {
            encoderActivations = iNDArrayArr;
        }

        /**
         * Compares the three properties in declaration order, after checking that the
         * other object accepts this one via {@link #canEqual(Object)}.
         */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof VAEFwdHelper)) {
                return false;
            }
            VAEFwdHelper other = (VAEFwdHelper) obj;
            if (!other.canEqual(this)) {
                return false;
            }
            if (!Arrays.deepEquals(getEncoderPreOuts(), other.getEncoderPreOuts())) {
                return false;
            }
            INDArray mine = getPzxMeanPreOut();
            INDArray theirs = other.getPzxMeanPreOut();
            // A null mean only matches another null mean; otherwise defer to INDArray.equals.
            boolean meansMatch = (mine == null) ? (theirs == null) : mine.equals(theirs);
            if (!meansMatch) {
                return false;
            }
            return Arrays.deepEquals(getEncoderActivations(), other.getEncoderActivations());
        }

        /** @return true when {@code obj} is a VAEFwdHelper, i.e. may be compared with this one */
        protected boolean canEqual(Object obj) {
            return obj instanceof VAEFwdHelper;
        }

        /** Hash built from the three properties with multiplier 59; a null mean contributes 43. */
        @Override
        public int hashCode() {
            final int multiplier = 59;
            int result = 1;
            result = (result * multiplier) + Arrays.deepHashCode(getEncoderPreOuts());
            INDArray mean = getPzxMeanPreOut();
            result = (result * multiplier) + (mean == null ? 43 : mean.hashCode());
            result = (result * multiplier) + Arrays.deepHashCode(getEncoderActivations());
            return result;
        }

        @Override
        public String toString() {
            return "VariationalAutoencoder.VAEFwdHelper(encoderPreOuts="
                    + Arrays.deepToString(getEncoderPreOuts())
                    + ", pzxMeanPreOut="
                    + getPzxMeanPreOut()
                    + ", encoderActivations="
                    + Arrays.deepToString(getEncoderActivations())
                    + ")";
        }
    }

    public VariationalAutoencoder(NeuralNetConfiguration neuralNetConfiguration) {
        this.conf = neuralNetConfiguration;
        this.encoderLayerSizes = ((org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) neuralNetConfiguration.getLayer()).getEncoderLayerSizes();
        this.decoderLayerSizes = ((org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) neuralNetConfiguration.getLayer()).getDecoderLayerSizes();
        this.reconstructionDistribution = ((org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) neuralNetConfiguration.getLayer()).getOutputDistribution();
        this.pzxActivationFn = ((org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) neuralNetConfiguration.getLayer()).getPzxActivationFn();
        this.numSamples = ((org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) neuralNetConfiguration.getLayer()).getNumSamples();
    }

    /**
     * Not supported by this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void update(Gradient gradient) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Not supported by this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void update(INDArray iNDArray, String str) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /** @return the score stored by the most recent {@code computeGradientAndScore} call (0.0 initially) */
    @Override // org.deeplearning4j.nn.api.Model
    public double score() {
        return score;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Computes the score and the gradients of all parameters for the current input.
     *
     * <p>Runs the encoder forward pass once, then repeats the following {@code numSamples}
     * times: draw a latent sample, run the decoder, add the reconstruction term to the score,
     * and backpropagate through decoder, PZX outputs and encoder. Gradients are written into
     * the arrays held in {@code gradientViews}; each sample's contribution is scaled by
     * {@code 1 / numSamples}. On completion {@code this.score} and {@code this.gradient} are set.
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void computeGradientAndScore() {
        INDArray iNDArray;
        // Encoder forward pass; the second flag makes doForward keep the pre-activation copies.
        VAEFwdHelper doForward = doForward(true, true);
        IActivation activationFn = conf().getLayer().getActivationFn();
        // Pre-activation of the PZX_LOGSTD2 output: last encoder activation x PZX_LOGSTD2_W + PZX_LOGSTD2_B.
        INDArray addiRowVector = doForward.encoderActivations[doForward.encoderActivations.length - 1].mmul(this.params.get(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_W)).addiRowVector(this.params.get(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_B));
        // Activate copies of both PZX pre-activations; the originals are needed again for backprop.
        // NOTE(review): getActivation's return value is ignored, so it presumably works in place - confirm.
        INDArray dup = doForward.pzxMeanPreOut.dup();
        INDArray dup2 = addiRowVector.dup();
        this.pzxActivationFn.getActivation(dup, true);
        this.pzxActivationFn.getActivation(dup2, true);
        // exp = exp(dup2) and sqrt = sqrt(exp); sqrt scales the random noise when sampling below.
        INDArray exp = Transforms.exp(dup2, true);
        INDArray sqrt = Transforms.sqrt(exp, true);
        // size = input.size(0), size2 = pzxMeanPreOut.size(1); together they give the shape of the noise array.
        int size = this.input.size(0);
        int size2 = doForward.pzxMeanPreOut.size(1);
        // Collects the gradient view of every parameter, keyed by parameter name.
        HashMap hashMap = new HashMap();
        // Scale applied to each sample's gradient contribution.
        double d = 1.0d / this.numSamples;
        Level1 level1 = Nd4j.getBlasWrapper().level1();
        // Only used when numSamples > 1: per-encoder-layer results of backprop on an all-ones array,
        // computed for the first sample and reused for the others.
        INDArray[] iNDArrayArr = this.numSamples > 1 ? new INDArray[this.encoderLayerSizes.length] : null;
        int i = 0;
        while (i < this.numSamples) {
            // Last gemm argument: 0 for the first sample, 1 afterwards. Presumably this makes the first
            // sample overwrite the gradient view and later samples accumulate, matching the
            // assign-vs-axpy handling of the bias gradients below - confirm against the Nd4j.gemm API.
            double d2 = i == 0 ? 0.0d : 1.0d;
            // Latent sample: dup + sqrt * random noise.
            INDArray randn = Nd4j.randn(size, size2);
            INDArray addi = sqrt.mul(randn).addi(dup);
            int length = this.decoderLayerSizes.length;
            INDArray iNDArray2 = addi;
            // Decoder forward pass: iNDArrayArr2 keeps pre-activation copies, iNDArrayArr3 the activations.
            INDArray[] iNDArrayArr2 = new INDArray[length];
            INDArray[] iNDArrayArr3 = new INDArray[length];
            for (int i2 = 0; i2 < length; i2++) {
                iNDArray2 = iNDArray2.mmul(this.params.get("d" + i2 + "W")).addiRowVector(this.params.get("d" + i2 + "b"));
                iNDArrayArr2[i2] = iNDArray2.dup();
                activationFn.getActivation(iNDArray2, true);
                iNDArrayArr3[i2] = iNDArray2;
            }
            INDArray iNDArray3 = this.params.get(VariationalAutoencoderParamInitializer.PXZ_W);
            INDArray iNDArray4 = this.params.get(VariationalAutoencoderParamInitializer.PXZ_B);
            if (i == 0) {
                // First sample only: initialise the score with
                // (-0.5 / size) * sum(1 + dup2 - dup^2 - exp) plus (L1 + L2) / size.
                INDArray negi = dup.mul(dup).addi(exp).negi();
                negi.addi(dup2).addi(Double.valueOf(1.0d));
                this.score = (((-0.5d) / size) * negi.sumNumber().doubleValue()) + ((calcL1(false) + calcL2(false)) / size);
            }
            // Reconstruction pre-output; its negative log probability, divided by numSamples, is added to the score.
            INDArray addiRowVector2 = iNDArray2.mmul(iNDArray3).addiRowVector(iNDArray4);
            this.score += this.reconstructionDistribution.negLogProbability(this.input, addiRowVector2, true) / this.numSamples;
            // Training listeners are notified for the first sample only, with the activations keyed by layer name.
            if (this.trainingListeners != null && this.trainingListeners.size() > 0 && i == 0) {
                LinkedHashMap linkedHashMap = new LinkedHashMap();
                for (int i3 = 0; i3 < doForward.encoderActivations.length; i3++) {
                    linkedHashMap.put("e" + i3, doForward.encoderActivations[i3]);
                }
                linkedHashMap.put(VariationalAutoencoderParamInitializer.PZX_PREFIX, addi);
                for (int i4 = 0; i4 < iNDArrayArr3.length; i4++) {
                    linkedHashMap.put("d" + i4, iNDArrayArr3[i4]);
                }
                linkedHashMap.put(VariationalAutoencoderParamInitializer.PXZ_PREFIX, this.reconstructionDistribution.generateAtMean(addiRowVector2));
                Iterator<TrainingListener> it = this.trainingListeners.iterator();
                while (it.hasNext()) {
                    it.next().onForwardPass(this, linkedHashMap);
                }
            }
            // Gradient supplied by the reconstruction distribution for the reconstruction pre-output.
            INDArray gradient = this.reconstructionDistribution.gradient(this.input, addiRowVector2);
            INDArray iNDArray5 = this.gradientViews.get(VariationalAutoencoderParamInitializer.PXZ_W);
            INDArray iNDArray6 = this.gradientViews.get(VariationalAutoencoderParamInitializer.PXZ_B);
            // PXZ_W gradient: (last decoder activation)^T x gradient, written into the view.
            Nd4j.gemm(iNDArrayArr3[iNDArrayArr3.length - 1], gradient, iNDArray5, true, false, d, d2);
            // PXZ_B gradient: sum over dimension 0; assigned (and scaled) for the first sample, accumulated via axpy afterwards.
            if (i == 0) {
                iNDArray6.assign(gradient.sum(new int[]{0}));
                if (this.numSamples > 1) {
                    iNDArray6.muli(Double.valueOf(d));
                }
            } else {
                level1.axpy(iNDArray6.length(), d, gradient.sum(new int[]{0}), iNDArray6);
            }
            hashMap.put(VariationalAutoencoderParamInitializer.PXZ_W, iNDArray5);
            hashMap.put(VariationalAutoencoderParamInitializer.PXZ_B, iNDArray6);
            // Error passed back through PXZ_W to the last decoder layer.
            INDArray transpose = iNDArray3.mmul(gradient.transpose()).transpose();
            // Backprop through the decoder layers, last to first.
            int i5 = length - 1;
            while (i5 >= 0) {
                String str = "d" + i5 + "W";
                String str2 = "d" + i5 + "b";
                INDArray iNDArray7 = (INDArray) activationFn.backprop(iNDArrayArr2[i5], transpose).getFirst();
                INDArray iNDArray8 = this.params.get(str);
                INDArray iNDArray9 = this.gradientViews.get(str);
                INDArray iNDArray10 = this.gradientViews.get(str2);
                // The input of decoder layer 0 is the latent sample; other layers use the previous activation.
                Nd4j.gemm(i5 == 0 ? addi : iNDArrayArr3[i5 - 1], iNDArray7, iNDArray9, true, false, d, d2);
                if (i == 0) {
                    iNDArray10.assign(iNDArray7.sum(new int[]{0}));
                    if (this.numSamples > 1) {
                        iNDArray10.muli(Double.valueOf(d));
                    }
                } else {
                    level1.axpy(iNDArray10.length(), d, iNDArray7.sum(new int[]{0}), iNDArray10);
                }
                hashMap.put(str, iNDArray9);
                hashMap.put(str2, iNDArray10);
                transpose = iNDArray8.mmul(iNDArray7.transpose()).transpose();
                i5--;
            }
            INDArray iNDArray11 = this.params.get(VariationalAutoencoderParamInitializer.PZX_MEAN_W);
            INDArray iNDArray12 = this.params.get(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_W);
            // iNDArray13 is the error arriving at the latent sample from the decoder.
            INDArray iNDArray13 = transpose;
            // Error for the mean output: iNDArray13 + dup.
            INDArray add = iNDArray13.add(dup);
            // Error for the LOGSTD2 output: 0.5 * (iNDArray13 * randn * sqrt + exp - 1).
            INDArray muli = iNDArray13.mul(randn).muli(sqrt).addi(exp).subi(1).muli(Double.valueOf(0.5d));
            // Backprop both errors through the PZX activation, using fresh copies of the pre-activations.
            INDArray iNDArray14 = (INDArray) this.pzxActivationFn.backprop(doForward.getPzxMeanPreOut().dup(), add).getFirst();
            INDArray iNDArray15 = (INDArray) this.pzxActivationFn.backprop(addiRowVector.dup(), muli).getFirst();
            INDArray iNDArray16 = doForward.encoderActivations[doForward.encoderActivations.length - 1];
            INDArray iNDArray17 = this.gradientViews.get(VariationalAutoencoderParamInitializer.PZX_MEAN_W);
            INDArray iNDArray18 = this.gradientViews.get(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_W);
            Nd4j.gemm(iNDArray16, iNDArray14, iNDArray17, true, false, d, d2);
            Nd4j.gemm(iNDArray16, iNDArray15, iNDArray18, true, false, d, d2);
            INDArray iNDArray19 = this.gradientViews.get(VariationalAutoencoderParamInitializer.PZX_MEAN_B);
            INDArray iNDArray20 = this.gradientViews.get(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_B);
            // NOTE(review): the mean-bias branches repeat the backprop expression that produced iNDArray14,
            // on fresh copies, instead of reusing iNDArray14 - confirm whether that is required before simplifying.
            if (i == 0) {
                iNDArray19.assign(((INDArray) this.pzxActivationFn.backprop(doForward.getPzxMeanPreOut().dup(), iNDArray13.add(dup)).getFirst()).sum(new int[]{0}));
                iNDArray20.assign(iNDArray15.sum(new int[]{0}));
                if (this.numSamples > 1) {
                    iNDArray19.muli(Double.valueOf(d));
                    iNDArray20.muli(Double.valueOf(d));
                }
            } else {
                level1.axpy(iNDArray19.length(), d, ((INDArray) this.pzxActivationFn.backprop(doForward.getPzxMeanPreOut().dup(), iNDArray13.add(dup)).getFirst()).sum(new int[]{0}), iNDArray19);
                level1.axpy(iNDArray20.length(), d, iNDArray15.sum(new int[]{0}), iNDArray20);
            }
            hashMap.put(VariationalAutoencoderParamInitializer.PZX_MEAN_W, iNDArray17);
            hashMap.put(VariationalAutoencoderParamInitializer.PZX_MEAN_B, iNDArray19);
            hashMap.put(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_W, iNDArray18);
            hashMap.put(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_B, iNDArray20);
            // Error for the last encoder layer: iNDArray14 x PZX_MEAN_W^T, with iNDArray15 x PZX_LOGSTD2_W^T
            // combined into the same array by the second gemm call (both scaling arguments are 1).
            INDArray gemm = Nd4j.gemm(iNDArray14, iNDArray11, false, true);
            Nd4j.gemm(iNDArray15, iNDArray12, gemm, false, true, 1.0d, 1.0d);
            // Backprop through the encoder layers, last to first.
            int length2 = this.encoderLayerSizes.length - 1;
            while (length2 >= 0) {
                String str3 = "e" + length2 + "W";
                String str4 = "e" + length2 + "b";
                INDArray iNDArray21 = this.params.get(str3);
                INDArray iNDArray22 = this.gradientViews.get(str3);
                INDArray iNDArray23 = this.gradientViews.get(str4);
                INDArray iNDArray24 = doForward.encoderPreOuts[length2];
                if (this.numSamples > 1) {
                    // Multi-sample case: backprop of an all-ones array is computed for the first sample,
                    // cached, and multiplied into the incoming error (in place) for every sample.
                    if (i == 0) {
                        iNDArrayArr[length2] = (INDArray) activationFn.backprop(doForward.encoderPreOuts[length2], Nd4j.ones(doForward.encoderPreOuts[length2].shape())).getFirst();
                    }
                    iNDArray = gemm.muli(iNDArrayArr[length2]);
                } else {
                    iNDArray = (INDArray) activationFn.backprop(iNDArray24, gemm).getFirst();
                }
                // The input of encoder layer 0 is the layer input; other layers use the previous activation.
                Nd4j.gemm(length2 == 0 ? this.input : doForward.encoderActivations[length2 - 1], iNDArray, iNDArray22, true, false, d, d2);
                if (i == 0) {
                    iNDArray23.assign(iNDArray.sum(new int[]{0}));
                    if (this.numSamples > 1) {
                        iNDArray23.muli(Double.valueOf(d));
                    }
                } else {
                    level1.axpy(iNDArray23.length(), d, iNDArray.sum(new int[]{0}), iNDArray23);
                }
                hashMap.put(str3, iNDArray22);
                hashMap.put(str4, iNDArray23);
                gemm = iNDArray21.mmul(iNDArray.transpose()).transpose();
                length2--;
            }
            i++;
        }
        // Wrap the flattened gradient array and register the per-parameter views in a fixed order:
        // encoder layers, PZX mean, PZX logstd2, decoder layers, PXZ.
        DefaultGradient defaultGradient = new DefaultGradient(this.gradientsFlattened);
        Map<String, INDArray> gradientForVariable = defaultGradient.gradientForVariable();
        for (int i6 = 0; i6 < this.encoderLayerSizes.length; i6++) {
            String str5 = "e" + i6 + "W";
            gradientForVariable.put(str5, hashMap.get(str5));
            String str6 = "e" + i6 + "b";
            gradientForVariable.put(str6, hashMap.get(str6));
        }
        gradientForVariable.put(VariationalAutoencoderParamInitializer.PZX_MEAN_W, hashMap.get(VariationalAutoencoderParamInitializer.PZX_MEAN_W));
        gradientForVariable.put(VariationalAutoencoderParamInitializer.PZX_MEAN_B, hashMap.get(VariationalAutoencoderParamInitializer.PZX_MEAN_B));
        gradientForVariable.put(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_W, hashMap.get(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_W));
        gradientForVariable.put(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_B, hashMap.get(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_B));
        for (int i7 = 0; i7 < this.decoderLayerSizes.length; i7++) {
            String str7 = "d" + i7 + "W";
            gradientForVariable.put(str7, hashMap.get(str7));
            String str8 = "d" + i7 + "b";
            gradientForVariable.put(str8, hashMap.get(str8));
        }
        gradientForVariable.put(VariationalAutoencoderParamInitializer.PXZ_W, hashMap.get(VariationalAutoencoderParamInitializer.PXZ_W));
        gradientForVariable.put(VariationalAutoencoderParamInitializer.PXZ_B, hashMap.get(VariationalAutoencoderParamInitializer.PXZ_B));
        this.gradient = defaultGradient;
    }

    /** Does nothing in this implementation; the argument is ignored. */
    @Override // org.deeplearning4j.nn.api.Model
    public void accumulateScore(double d) {
        // no-op
    }

    /** @return the flattened parameter array itself (not a copy) */
    @Override // org.deeplearning4j.nn.api.Model
    public INDArray params() {
        return paramsFlattened;
    }

    /** @return the total length of all parameter arrays, pretrain-only parameters included */
    @Override // org.deeplearning4j.nn.api.Model
    public int numParams() {
        final boolean backpropOnly = false;
        return numParams(backpropOnly);
    }

    /**
     * Sums the lengths of the parameter arrays.
     *
     * @param z when true, parameters for which {@link #isPretrainParam(String)} returns true are left out
     * @return the summed length of the counted parameter arrays
     */
    @Override // org.deeplearning4j.nn.api.Model
    public int numParams(boolean z) {
        int total = 0;
        for (Map.Entry<String, INDArray> param : params.entrySet()) {
            if (z && isPretrainParam(param.getKey())) {
                continue;
            }
            total += param.getValue().length();
        }
        return total;
    }

    /**
     * Copies the given values into the existing flattened parameter array.
     *
     * @param iNDArray new parameter values; must have the same length as the current parameter array
     * @throws IllegalArgumentException if the lengths differ
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setParams(INDArray iNDArray) {
        if (iNDArray.length() == paramsFlattened.length()) {
            paramsFlattened.assign(iNDArray);
            return;
        }
        throw new IllegalArgumentException("Cannot set parameters: expected parameters vector of length " + paramsFlattened.length() + " but got parameters array of length " + iNDArray.length());
    }

    /**
     * Replaces the flattened parameter array reference.
     *
     * @param iNDArray the new flattened parameter array
     * @throws IllegalArgumentException if a parameter table is already set and the array length
     *         differs from {@link #numParams()}
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setParamsViewArray(INDArray iNDArray) {
        // The length can only be checked once the parameter table exists.
        boolean lengthMismatch = params != null && iNDArray.length() != numParams();
        if (lengthMismatch) {
            throw new IllegalArgumentException("Invalid input: expect params of length " + numParams() + ", got params of length " + iNDArray.length());
        }
        paramsFlattened = iNDArray;
    }

    /**
     * Replaces the flattened gradient array and rebuilds the per-parameter gradient map from it
     * using the layer's parameter initializer.
     *
     * @param iNDArray the new flattened gradient array
     * @throws IllegalArgumentException if a parameter table is already set and the array length
     *         differs from {@link #numParams()}
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setBackpropGradientsViewArray(INDArray iNDArray) {
        // The length can only be checked once the parameter table exists.
        boolean lengthMismatch = params != null && iNDArray.length() != numParams();
        if (lengthMismatch) {
            throw new IllegalArgumentException("Invalid input: expect gradients array of length " + numParams() + ", got gradient array of length of length " + iNDArray.length());
        }
        gradientsFlattened = iNDArray;
        gradientViews = conf.getLayer().initializer().getGradientsFromFlattened(conf, iNDArray);
    }

    /** Does nothing in this implementation. */
    @Override // org.deeplearning4j.nn.api.Model
    public void applyLearningRateScoreDecay() {
        // no-op
    }

    /**
     * Stores the given array as the layer input and then calls the no-argument {@code fit()}.
     *
     * @param iNDArray the input to fit on
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void fit(INDArray iNDArray) {
        this.setInput(iNDArray);
        this.fit();
    }

    /**
     * Delegates to {@link #fit(INDArray)}.
     *
     * @param iNDArray the input to fit on
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void iterate(INDArray iNDArray) {
        this.fit(iNDArray);
    }

    /** @return the gradient stored by the most recent {@code computeGradientAndScore} call; null before the first call */
    @Override // org.deeplearning4j.nn.api.Model
    public Gradient gradient() {
        return gradient;
    }

    /** @return a new pair of the values currently returned by {@link #gradient()} and {@link #score()} */
    @Override // org.deeplearning4j.nn.api.Model
    public Pair<Gradient, Double> gradientAndScore() {
        Gradient currentGradient = gradient();
        Double currentScore = Double.valueOf(score());
        return new Pair<>(currentGradient, currentScore);
    }

    /** @return the size of dimension 0 of the current input; an input must have been set */
    @Override // org.deeplearning4j.nn.api.Model
    public int batchSize() {
        return input.size(0);
    }

    /** @return the configuration this layer currently uses */
    @Override // org.deeplearning4j.nn.api.Model
    public NeuralNetConfiguration conf() {
        return conf;
    }

    /**
     * Replaces the configuration reference. The values copied out of the configuration by the
     * constructor (layer sizes, distribution, activation, sample count) are not refreshed here.
     *
     * @param neuralNetConfiguration the new configuration
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setConf(NeuralNetConfiguration neuralNetConfiguration) {
        conf = neuralNetConfiguration;
    }

    /** @return the current input array, or null if none is set */
    @Override // org.deeplearning4j.nn.api.Model
    public INDArray input() {
        return input;
    }

    /**
     * Not supported by this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void validateInput() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /** @return the value of the {@code optimizer} field, which may be null */
    @Override // org.deeplearning4j.nn.api.Model
    public ConvexOptimizer getOptimizer() {
        return optimizer;
    }

    /**
     * Looks up a single parameter array by name.
     *
     * @param str the parameter key
     * @return the array stored under that key in the parameter table, or null if the key is absent
     */
    @Override // org.deeplearning4j.nn.api.Model
    public INDArray getParam(String str) {
        return params.get(str);
    }

    /**
     * Not supported by this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void initParams() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /** @return a new insertion-ordered map holding every entry of the parameter table (the arrays themselves are shared) */
    @Override // org.deeplearning4j.nn.api.Model
    public Map<String, INDArray> paramTable() {
        return new LinkedHashMap<>(params);
    }

    /**
     * Builds a new insertion-ordered map of the parameter table.
     *
     * @param z when true, parameters for which {@link #isPretrainParam(String)} returns true are left out
     * @return the (possibly filtered) copy of the parameter table; the arrays themselves are shared
     */
    @Override // org.deeplearning4j.nn.api.Model
    public Map<String, INDArray> paramTable(boolean z) {
        Map<String, INDArray> selected = new LinkedHashMap<>();
        for (Map.Entry<String, INDArray> param : params.entrySet()) {
            if (z && isPretrainParam(param.getKey())) {
                continue;
            }
            selected.put(param.getKey(), param.getValue());
        }
        return selected;
    }

    /**
     * Replaces the parameter table with the given map; the map is stored as-is, not copied.
     *
     * @param map parameter arrays keyed by parameter name
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setParamTable(Map<String, INDArray> map) {
        params = map;
    }

    /**
     * Not supported by this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setParam(String str, INDArray iNDArray) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /** Drops the references to the current input and mask array. */
    @Override // org.deeplearning4j.nn.api.Model
    public void clear() {
        input = null;
        maskArray = null;
    }

    /**
     * Classifies a parameter by its key.
     *
     * @param str the parameter key
     * @return false when the key starts with {@code "e"} or with
     *         {@code VariationalAutoencoderParamInitializer.PZX_MEAN_PREFIX}; true for every other key
     */
    public boolean isPretrainParam(String str) {
        return !(str.startsWith("e") || str.startsWith(VariationalAutoencoderParamInitializer.PZX_MEAN_PREFIX));
    }

    /**
     * Sums {@code 0.5 * l2 * norm2^2} over the parameters whose configured L2 coefficient is positive.
     *
     * @param z when true, parameters for which {@link #isPretrainParam(String)} returns true are left out
     * @return the L2 term, or 0.0 when regularization is disabled in the configuration
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public double calcL2(boolean z) {
        if (!conf.isUseRegularization()) {
            return 0.0d;
        }
        double total = 0.0d;
        for (Map.Entry<String, INDArray> param : paramTable().entrySet()) {
            double coefficient = conf().getL2ByParam(param.getKey());
            // Written as !(x > 0) so that a NaN coefficient is skipped as well.
            if (!(coefficient > 0.0d) || (z && isPretrainParam(param.getKey()))) {
                continue;
            }
            double norm2 = param.getValue().norm2Number().doubleValue();
            total += 0.5d * coefficient * norm2 * norm2;
        }
        return total;
    }

    /**
     * Sums {@code l1 * norm1} over the parameters whose configured L1 coefficient is positive.
     *
     * @param z when true, parameters for which {@link #isPretrainParam(String)} returns true are left out
     * @return the L1 term, or 0.0 when regularization is disabled in the configuration
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public double calcL1(boolean z) {
        if (!conf.isUseRegularization()) {
            return 0.0d;
        }
        double total = 0.0d;
        for (Map.Entry<String, INDArray> param : paramTable().entrySet()) {
            double coefficient = conf().getL1ByParam(param.getKey());
            // Written as !(x > 0) so that a NaN coefficient is skipped as well.
            if (!(coefficient > 0.0d) || (z && isPretrainParam(param.getKey()))) {
                continue;
            }
            total += coefficient * param.getValue().norm1Number().doubleValue();
        }
        return total;
    }

    /** @return always {@code Layer.Type.FEED_FORWARD} */
    @Override // org.deeplearning4j.nn.api.Layer
    public Layer.Type type() {
        final Layer.Type layerType = Layer.Type.FEED_FORWARD;
        return layerType;
    }

    /**
     * Not supported by this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public Gradient error(INDArray iNDArray) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Not supported by this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public INDArray derivativeActivation(INDArray iNDArray) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Not supported by this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public Gradient calcGradient(Gradient gradient, INDArray iNDArray) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Backpropagates an incoming error through the PZX mean output and the encoder layers only.
     *
     * <p>On the first call, the gradient views of all pretrain-only parameters are set to zero.
     * Gradients for the PZX mean and encoder parameters are then written into their gradient
     * views and registered in the returned gradient object.
     *
     * @param iNDArray the error arriving at this layer's output
     * @return the gradient for the PZX mean and encoder parameters, paired with the error to
     *         pass to the layer below
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public Pair<Gradient, INDArray> backpropGradient(INDArray iNDArray) {
        // One-time reset of the gradient views that this method never writes to.
        if (!zeroedPretrainParamGradients) {
            for (Map.Entry<String, INDArray> view : gradientViews.entrySet()) {
                if (isPretrainParam(view.getKey())) {
                    view.getValue().assign(0);
                }
            }
            zeroedPretrainParamGradients = true;
        }

        DefaultGradient result = new DefaultGradient();
        VAEFwdHelper fwd = doForward(true, true);

        // PZX mean output: error through the activation, then weight and bias gradients.
        INDArray delta = (INDArray) pzxActivationFn.backprop(fwd.pzxMeanPreOut, iNDArray).getFirst();
        INDArray meanWeights = params.get(VariationalAutoencoderParamInitializer.PZX_MEAN_W);
        INDArray meanWeightGrad = gradientViews.get(VariationalAutoencoderParamInitializer.PZX_MEAN_W);
        Nd4j.gemm(fwd.encoderActivations[fwd.encoderActivations.length - 1], delta, meanWeightGrad, true, false, 1.0d, 0.0d);
        INDArray meanBiasGrad = gradientViews.get(VariationalAutoencoderParamInitializer.PZX_MEAN_B);
        meanBiasGrad.assign(delta.sum(new int[]{0}));
        result.gradientForVariable().put(VariationalAutoencoderParamInitializer.PZX_MEAN_W, meanWeightGrad);
        result.gradientForVariable().put(VariationalAutoencoderParamInitializer.PZX_MEAN_B, meanBiasGrad);

        INDArray epsilon = meanWeights.mmul(delta.transpose()).transpose();
        int layerCount = encoderLayerSizes.length;
        IActivation encoderActivation = conf().getLayer().getActivationFn();

        // Encoder layers, last to first.
        for (int layer = layerCount - 1; layer >= 0; layer--) {
            String weightKey = "e" + layer + "W";
            String biasKey = "e" + layer + "b";
            INDArray weights = params.get(weightKey);
            INDArray weightGrad = gradientViews.get(weightKey);
            INDArray biasGrad = gradientViews.get(biasKey);
            INDArray layerDelta = (INDArray) encoderActivation.backprop(fwd.encoderPreOuts[layer], epsilon).getFirst();
            // Layer 0 was fed the layer input; every other layer the previous activation.
            Nd4j.gemm(layer == 0 ? input : fwd.encoderActivations[layer - 1], layerDelta, weightGrad, true, false, 1.0d, 0.0d);
            biasGrad.assign(layerDelta.sum(new int[]{0}));
            result.gradientForVariable().put(weightKey, weightGrad);
            result.gradientForVariable().put(biasKey, biasGrad);
            epsilon = weights.mmul(layerDelta.transpose()).transpose();
        }
        return new Pair<>(result, epsilon);
    }

    /**
     * Not supported by this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public void merge(Layer layer, int i) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Not supported by this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public INDArray activationMean() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Computes the pre-output for the given input in {@code TEST} mode.
     *
     * @param iNDArray the input; it is stored as the layer input
     * @return the pre-activation of the PZX mean output
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public INDArray preOutput(INDArray iNDArray) {
        final Layer.TrainingMode mode = Layer.TrainingMode.TEST;
        return preOutput(iNDArray, mode);
    }

    /**
     * Computes the pre-output for the given input.
     *
     * @param iNDArray     the input; it is stored as the layer input
     * @param trainingMode {@code TRAIN} selects training behaviour, any other value test behaviour
     * @return the pre-activation of the PZX mean output
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public INDArray preOutput(INDArray iNDArray, Layer.TrainingMode trainingMode) {
        boolean training = (trainingMode == Layer.TrainingMode.TRAIN);
        return preOutput(iNDArray, training);
    }

    /**
     * Stores the given array as the layer input and computes the pre-output for it.
     *
     * @param iNDArray the input
     * @param z        training flag forwarded to the encoder activation functions
     * @return the pre-activation of the PZX mean output
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public INDArray preOutput(INDArray iNDArray, boolean z) {
        this.setInput(iNDArray);
        return this.preOutput(z);
    }

    /**
     * Runs the encoder forward pass on the current input, without keeping the per-layer
     * pre-activation copies.
     *
     * @param z training flag forwarded to the encoder activation functions
     * @return the pre-activation of the PZX mean output
     */
    public INDArray preOutput(boolean z) {
        VAEFwdHelper fwd = doForward(z, false);
        return fwd.pzxMeanPreOut;
    }

    /**
     * Runs the current input through the encoder layers and the PZX mean output.
     *
     * @param z  training flag handed to the encoder activation function
     * @param z2 when true, a copy of each layer's pre-activation is stored; when false the
     *           entries of the pre-activation array in the result stay null
     * @return a helper holding the encoder pre-activations, the pre-activation of the PZX mean
     *         output, and the encoder activations
     * @throws IllegalStateException if no input has been set
     */
    private VAEFwdHelper doForward(boolean z, boolean z2) {
        if (input == null) {
            throw new IllegalStateException("Cannot do forward pass with null input");
        }
        final int layerCount = encoderLayerSizes.length;
        INDArray[] preOuts = new INDArray[encoderLayerSizes.length];
        INDArray[] activations = new INDArray[encoderLayerSizes.length];
        INDArray current = input;
        for (int layer = 0; layer < layerCount; layer++) {
            INDArray weights = params.get("e" + layer + "W");
            current = current.mmul(weights).addiRowVector(params.get("e" + layer + "b"));
            if (z2) {
                // Copy taken before the activation is applied to 'current'.
                preOuts[layer] = current.dup();
            }
            conf.getLayer().getActivationFn().getActivation(current, z);
            activations[layer] = current;
        }
        INDArray pzxMeanPreOut = current
                .mmul(params.get(VariationalAutoencoderParamInitializer.PZX_MEAN_W))
                .addiRowVector(params.get(VariationalAutoencoderParamInitializer.PZX_MEAN_B));
        return new VAEFwdHelper(preOuts, pzxMeanPreOut, activations);
    }

    /**
     * Computes the activations for the current input.
     *
     * @param trainingMode {@code TRAIN} selects training behaviour, any other value test behaviour
     * @return the activated PZX mean output
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public INDArray activate(Layer.TrainingMode trainingMode) {
        boolean training = (trainingMode == Layer.TrainingMode.TRAIN);
        return activate(training);
    }

    /**
     * Stores the given array as the layer input and computes the activations for it.
     *
     * <p>Follows the same set-input-then-delegate pattern as {@code activate(INDArray, boolean)}
     * and {@code activate(INDArray)}, so all input-taking overloads return the activations.
     *
     * @param iNDArray     the input
     * @param trainingMode {@code TRAIN} selects training behaviour, any other value test behaviour
     * @return the activated PZX mean output
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public INDArray activate(INDArray iNDArray, Layer.TrainingMode trainingMode) {
        setInput(iNDArray);
        return activate(trainingMode);
    }

    /**
     * Computes the pre-output for the current input and passes it to the PZX activation function.
     *
     * @param z training flag forwarded to the forward pass and to the activation function
     * @return the array returned by {@code preOutput(z)}, after it has been handed to
     *         {@code pzxActivationFn.getActivation}
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public INDArray activate(boolean z) {
        INDArray output = preOutput(z);
        // The return value is not used; the same array reference is returned below.
        pzxActivationFn.getActivation(output, z);
        return output;
    }

    /**
     * Stores the given array as the layer input and computes the activations for it.
     *
     * @param iNDArray the input
     * @param z        training flag
     * @return the activated PZX mean output
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public INDArray activate(INDArray iNDArray, boolean z) {
        this.setInput(iNDArray);
        return this.activate(z);
    }

    /** @return the activations for the current input, computed with the training flag set to false */
    @Override // org.deeplearning4j.nn.api.Layer
    public INDArray activate() {
        final boolean training = false;
        return activate(training);
    }

    /**
     * Stores the given array as the layer input and computes the activations for it with the
     * training flag set to false.
     *
     * @param iNDArray the input
     * @return the activated PZX mean output
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public INDArray activate(INDArray iNDArray) {
        this.setInput(iNDArray);
        return this.activate();
    }

    /**
     * Not supported by this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public Layer transpose() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Cloning is not supported by this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Layer
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public Layer m95clone() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /** @return a new list holding the registered iteration listeners, or null if the listener collection is null */
    @Override // org.deeplearning4j.nn.api.Layer
    public Collection<IterationListener> getListeners() {
        return (iterationListeners == null) ? null : new ArrayList<>(iterationListeners);
    }

    /**
     * Replaces the registered listeners with the given ones.
     *
     * @param iterationListenerArr the listeners to register
     */
    @Override // org.deeplearning4j.nn.api.Layer, org.deeplearning4j.nn.api.Model
    public void setListeners(IterationListener... iterationListenerArr) {
        Collection<IterationListener> asCollection = Arrays.asList(iterationListenerArr);
        setListeners(asCollection);
    }

    /**
     * Replaces the registered listeners with the contents of the given collection.
     *
     * <p>Both internal listener collections are created if absent and otherwise cleared.
     * Every supplied listener is then added to the iteration listeners; those that are also
     * {@link TrainingListener} instances are additionally added to the training listeners.
     * The backing lists are created with explicit element types instead of raw
     * {@code ArrayList}s.
     *
     * @param collection the listeners to register; {@code null} or empty leaves both
     *                   internal collections empty
     */
    @Override // org.deeplearning4j.nn.api.Layer, org.deeplearning4j.nn.api.Model
    public void setListeners(Collection<IterationListener> collection) {
        if (this.iterationListeners == null) {
            this.iterationListeners = new ArrayList<IterationListener>();
        } else {
            this.iterationListeners.clear();
        }
        if (this.trainingListeners == null) {
            this.trainingListeners = new ArrayList<TrainingListener>();
        } else {
            this.trainingListeners.clear();
        }

        if (collection == null || collection.isEmpty()) {
            return;
        }

        this.iterationListeners.addAll(collection);
        for (IterationListener listener : collection) {
            // Training listeners are tracked in a second collection as well.
            if (listener instanceof TrainingListener) {
                this.trainingListeners.add((TrainingListener) listener);
            }
        }
    }

    /**
     * Stores the index assigned to this layer.
     *
     * @param i the index value to store
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public void setIndex(int i) {
        index = i;
    }

    /**
     * Returns the index previously stored with {@link #setIndex(int)}.
     *
     * @return the stored index
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public int getIndex() {
        return index;
    }

    /**
     * Stores the given array as this layer's input. No copy is made.
     *
     * @param iNDArray the input array; may be {@code null} to clear the input
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public void setInput(INDArray iNDArray) {
        input = iNDArray;
    }

    /**
     * No-op: the supplied value is ignored by this layer.
     *
     * @param i the minibatch size; not used
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public void setInputMiniBatchSize(int i) {
    }

    /**
     * Returns the size of dimension 0 of the current input array.
     *
     * <p>The input must have been set beforehand; with a {@code null} input this
     * throws a {@link NullPointerException}.
     *
     * @return {@code input.size(0)}
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public int getInputMiniBatchSize() {
        return input.size(0);
    }

    /**
     * Stores the given mask array on this layer. No copy is made.
     *
     * @param iNDArray the mask array to store
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public void setMaskArray(INDArray iNDArray) {
        maskArray = iNDArray;
    }

    /**
     * Returns the mask array previously stored with {@link #setMaskArray(INDArray)}.
     *
     * @return the stored mask array (the same instance, not a copy)
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public INDArray getMaskArray() {
        return maskArray;
    }

    /**
     * Reports that this layer is a pretrain layer.
     *
     * @return always {@code true}
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public boolean isPretrainLayer() {
        return true;
    }

    /**
     * Not supported by this layer.
     *
     * @param iNDArray  unused
     * @param maskState unused
     * @param i         unused
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray iNDArray, MaskState maskState, int i) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Fits this layer on the currently set input.
     *
     * <p>On the first call a {@link Solver} is built for this layer (using {@code conf()} and
     * the current listeners) and, if the updater reports a positive state size for this layer,
     * an uninitialised {@code 1 x stateSize} array is handed to the updater as its state view.
     * The solver is reused on later calls. The solver's optimizer is then stored in
     * {@code optimizer} and {@code optimize()} is invoked.
     *
     * @throws IllegalStateException if no input has been set
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void fit() {
        if (input == null) {
            throw new IllegalStateException("Cannot fit layer: layer input is null (not set)");
        }

        if (solver == null) {
            solver = new Solver.Builder()
                    .model(this)
                    .configure(conf())
                    .listeners(getListeners())
                    .build();

            // Give the updater a state view array, sized as the updater itself requests.
            Updater layerUpdater = solver.getOptimizer().getUpdater();
            int updaterStateSize = layerUpdater.stateSizeForLayer(this);
            if (updaterStateSize > 0) {
                INDArray stateView = Nd4j.createUninitialized(new int[]{1, updaterStateSize}, Nd4j.order().charValue());
                layerUpdater.setStateViewArray(this, stateView, true);
            }
        }

        optimizer = solver.getOptimizer();
        solver.optimize();
    }

    /**
     * Computes the reconstruction probability as the exponential of
     * {@link #reconstructionLogProbability(INDArray, int)}.
     *
     * @param iNDArray the data to evaluate
     * @param i        the number of samples, forwarded unchanged
     * @return the result of {@code Transforms.exp(logProbability, false)}
     * @throws IllegalArgumentException      as thrown by {@code reconstructionLogProbability}
     * @throws UnsupportedOperationException as thrown by {@code reconstructionLogProbability}
     */
    public INDArray reconstructionProbability(INDArray iNDArray, int i) {
        INDArray logProbability = reconstructionLogProbability(iNDArray, i);
        return Transforms.exp(logProbability, false);
    }

    /**
     * Estimates the reconstruction log probability of the given data by sampling.
     *
     * <p>The data is set as the layer input and passed through {@code doForward(true, true)}.
     * For each of the {@code i} samples, a standard-normal array is scaled by the computed
     * sigma and shifted by the mean pre-output, pushed through every decoder layer and the
     * final PXZ weights/bias, and scored with
     * {@code reconstructionDistribution.exampleNegLogProbability}. The scores are summed and
     * divided by {@code -i}. The layer input is reset to {@code null} before returning.
     *
     * @param iNDArray the data to evaluate
     * @param i        the number of samples to draw; must be positive
     * @return the summed negative log probabilities divided by {@code -i}
     * @throws IllegalArgumentException      if {@code i <= 0}
     * @throws UnsupportedOperationException if the reconstruction distribution is a
     *                                       {@link LossFunctionWrapper}
     */
    public INDArray reconstructionLogProbability(INDArray iNDArray, int i) {
        if (i <= 0) {
            throw new IllegalArgumentException("Invalid input: numSamples must be > 0. Got: " + i);
        }
        if (reconstructionDistribution instanceof LossFunctionWrapper) {
            throw new UnsupportedOperationException("Cannot calculate reconstruction log probability when using a LossFunction (via LossFunctionWrapper) instead of a ReconstructionDistribution: ILossFunction instances are not in general probabilistic, hence it is not possible to calculate reconstruction probability");
        }

        // Forward pass up to the mean pre-output.
        setInput(iNDArray);
        VAEFwdHelper fwd = doForward(true, true);
        IActivation decoderActivationFn = conf().getLayer().getActivationFn();

        // Second head: log(std^2) from the last encoder activations.
        INDArray logStd2Weights = params.get(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_W);
        INDArray logStd2Bias = params.get(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_B);
        INDArray meanZ = fwd.pzxMeanPreOut;
        INDArray lastEncoderActivations = fwd.encoderActivations[fwd.encoderActivations.length - 1];
        INDArray logStd2Z = lastEncoderActivations.mmul(logStd2Weights).addiRowVector(logStd2Bias);

        // Return values are discarded here, so these calls are relied upon to act in place.
        pzxActivationFn.getActivation(meanZ, false);
        pzxActivationFn.getActivation(logStd2Z, false);

        // sigma = sqrt(exp(logStd2)); the sqrt result is discarded, so it must act in place.
        INDArray sigmaZ = Transforms.exp(logStd2Z, false);
        Transforms.sqrt(sigmaZ, false);

        int minibatch = input.size(0);
        int latentSize = meanZ.size(1);

        INDArray pxzWeights = params.get(VariationalAutoencoderParamInitializer.PXZ_W);
        INDArray pxzBias = params.get(VariationalAutoencoderParamInitializer.PXZ_B);

        // Look up the decoder parameters once, outside the sampling loop.
        int nDecoderLayers = decoderLayerSizes.length;
        INDArray[] decoderWeights = new INDArray[nDecoderLayers];
        INDArray[] decoderBiases = new INDArray[nDecoderLayers];
        for (int layer = 0; layer < nDecoderLayers; layer++) {
            decoderWeights[layer] = params.get("d" + layer + "W");
            decoderBiases[layer] = params.get("d" + layer + "b");
        }

        INDArray sumNegLogProbability = null;
        for (int sample = 0; sample < i; sample++) {
            // z = mean + sigma * e, where e is drawn by Nd4j.randn
            INDArray noise = Nd4j.randn(minibatch, latentSize);
            INDArray current = noise.muli(sigmaZ).addi(meanZ);

            // Decoder forward pass.
            for (int layer = 0; layer < nDecoderLayers; layer++) {
                current = current.mmul(decoderWeights[layer]).addiRowVector(decoderBiases[layer]);
                decoderActivationFn.getActivation(current, false);
            }

            INDArray distributionPreOut = current.mmul(pxzWeights).addiRowVector(pxzBias);
            INDArray negLogProbability = reconstructionDistribution.exampleNegLogProbability(iNDArray, distributionPreOut);
            if (sample == 0) {
                sumNegLogProbability = negLogProbability;
            } else {
                sumNegLogProbability.addi(negLogProbability);
            }
        }

        setInput(null);
        // Dividing by -i averages over the samples and flips the sign.
        return sumNegLogProbability.divi(Integer.valueOf(-i));
    }

    /**
     * Decodes the given latent space values and asks the reconstruction distribution
     * for its output at the mean.
     *
     * @param iNDArray the latent space values; validated by the decoding step
     * @return the result of {@code reconstructionDistribution.generateAtMean} applied to the
     *         decoder output
     * @throws IllegalArgumentException if dimension 1 of {@code iNDArray} has the wrong size
     */
    public INDArray generateAtMeanGivenZ(INDArray iNDArray) {
        INDArray distributionPreOut = decodeGivenLatentSpaceValues(iNDArray);
        return reconstructionDistribution.generateAtMean(distributionPreOut);
    }

    /**
     * Decodes the given latent space values and asks the reconstruction distribution
     * for a random sample.
     *
     * <p>The decoder output is what gets passed to
     * {@code reconstructionDistribution.generateRandom}, mirroring
     * {@link #generateAtMeanGivenZ(INDArray)}. Previously the decoder output was computed and
     * discarded, and the raw latent values were passed to the distribution instead.
     *
     * @param iNDArray the latent space values; validated by the decoding step
     * @return the result of {@code reconstructionDistribution.generateRandom} applied to the
     *         decoder output
     * @throws IllegalArgumentException if dimension 1 of {@code iNDArray} has the wrong size
     */
    public INDArray generateRandomGivenZ(INDArray iNDArray) {
        INDArray distributionPreOut = decodeGivenLatentSpaceValues(iNDArray);
        return this.reconstructionDistribution.generateRandom(distributionPreOut);
    }

    /**
     * Runs the given latent space values through the decoder layers and the final
     * PXZ weights/bias.
     *
     * <p>Each decoder layer multiplies by its {@code "d<i>W"} parameter, adds the
     * {@code "d<i>b"} row vector and applies the layer configuration's activation function
     * (its return value is discarded, so the call is relied upon to act in place).
     *
     * @param iNDArray the latent space values; dimension 1 must match dimension 1 of the
     *                 {@code PZX_MEAN_W} parameter
     * @return the last decoder activations multiplied by {@code PXZ_W} plus {@code PXZ_B}
     * @throws IllegalArgumentException if dimension 1 of {@code iNDArray} has the wrong size
     */
    private INDArray decodeGivenLatentSpaceValues(INDArray iNDArray) {
        int actualSize = iNDArray.size(1);
        int expectedSize = params.get(VariationalAutoencoderParamInitializer.PZX_MEAN_W).size(1);
        if (actualSize != expectedSize) {
            throw new IllegalArgumentException("Invalid latent space values: expected size " + expectedSize + ", got size (dimension 1) = " + actualSize);
        }

        int nDecoderLayers = decoderLayerSizes.length;
        INDArray current = iNDArray;
        IActivation decoderActivationFn = conf().getLayer().getActivationFn();
        for (int layer = 0; layer < nDecoderLayers; layer++) {
            INDArray weights = params.get("d" + layer + "W");
            INDArray bias = params.get("d" + layer + "b");
            current = current.mmul(weights).addiRowVector(bias);
            decoderActivationFn.getActivation(current, false);
        }

        INDArray pxzWeights = params.get(VariationalAutoencoderParamInitializer.PXZ_W);
        INDArray pxzBias = params.get(VariationalAutoencoderParamInitializer.PXZ_B);
        return current.mmul(pxzWeights).addiRowVector(pxzBias);
    }

    /**
     * Reports whether the configured reconstruction distribution has a loss function.
     *
     * @return the value of {@code reconstructionDistribution.hasLossFunction()}
     */
    public boolean hasLossFunction() {
        return reconstructionDistribution.hasLossFunction();
    }

    /**
     * Computes the reconstruction error score array for the given data.
     *
     * <p>The data is activated with the training flag {@code false}, the result is decoded
     * with {@link #generateAtMeanGivenZ(INDArray)}, and the reconstruction is scored against
     * the data. A {@link CompositeReconstructionDistribution} scores via
     * {@code computeLossFunctionScoreArray}; otherwise the distribution is cast to
     * {@link LossFunctionWrapper} and its loss function's {@code computeScoreArray} is called
     * with an {@link ActivationIdentity} and a {@code null} mask argument.
     *
     * @param iNDArray the data to reconstruct and score
     * @return the score array produced by the loss function
     * @throws IllegalStateException if {@link #hasLossFunction()} is {@code false}
     */
    public INDArray reconstructionError(INDArray iNDArray) {
        if (!hasLossFunction()) {
            throw new IllegalStateException("Cannot use reconstructionError method unless the variational autoencoder is configured with a standard loss function (via LossFunctionWrapper). For VAEs utilizing a reconstruction distribution, use the reconstructionProbability or reconstructionLogProbability methods");
        }

        INDArray latentActivations = activate(iNDArray, false);
        INDArray reconstruction = generateAtMeanGivenZ(latentActivations);

        if (reconstructionDistribution instanceof CompositeReconstructionDistribution) {
            CompositeReconstructionDistribution composite = (CompositeReconstructionDistribution) reconstructionDistribution;
            return composite.computeLossFunctionScoreArray(iNDArray, reconstruction);
        }

        LossFunctionWrapper wrapper = (LossFunctionWrapper) reconstructionDistribution;
        return wrapper.getLossFunction().computeScoreArray(iNDArray, reconstruction, new ActivationIdentity(), (INDArray) null);
    }

    /**
     * Returns the map of gradient views held by this layer.
     *
     * @return the {@code gradientViews} map itself (not a copy)
     */
    public Map<String, INDArray> getGradientViews() {
        return gradientViews;
    }
}
