package org.deeplearning4j.nn.layers.normalization;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer;
import org.deeplearning4j.optimize.api.IterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAddOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/layers/normalization/BatchNormalization.class */
public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.layers.BatchNormalization> {
    /** SLF4J logger for this layer. */
    protected static final Logger log = LoggerFactory.getLogger(BatchNormalization.class);
    /** Optional accelerated implementation; assigned reflectively by {@code initializeHelper()}, otherwise null. */
    BatchNormalizationHelper helper;
    /** Index value stored by {@code setIndex(int)} and returned by {@code getIndex()}. */
    protected int index;
    /** Iteration listeners registered through {@code setListeners(...)}. */
    protected List<IterationListener> listeners;
    /** Square root of the variance used by the most recent {@code preOutput} call; read again during backprop. */
    protected INDArray std;
    /** Input minus mean from the most recent {@code preOutput} call; read (and overwritten in place) during backprop. */
    protected INDArray xMu;
    /** {@code xMu} divided by {@code std} from the most recent {@code preOutput} call; read during backprop. */
    protected INDArray xHat;
    /**
     * Creates the layer from its configuration, starts with an empty listener list and
     * then attempts to load the optional helper via {@link #initializeHelper()}.
     *
     * @param neuralNetConfiguration configuration passed straight to the superclass constructor
     */
    public BatchNormalization(NeuralNetConfiguration neuralNetConfiguration) {
        super(neuralNetConfiguration);
        this.helper = null;
        this.index = 0;
        // Diamond instead of the raw type so the assignment is checked against List<IterationListener>.
        this.listeners = new ArrayList<>();
        initializeHelper();
    }

    /**
     * Tries to instantiate {@code CudnnBatchNormalizationHelper} reflectively and store it in
     * {@link #helper}.
     *
     * <p>A {@link ClassNotFoundException} is ignored silently; any other throwable is logged at
     * warn level. In both failure cases {@link #helper} is not assigned by this method.
     */
    void initializeHelper() {
        final String helperClassName = "org.deeplearning4j.nn.layers.normalization.CudnnBatchNormalizationHelper";
        try {
            Class<? extends BatchNormalizationHelper> helperClass =
                    Class.forName(helperClassName).asSubclass(BatchNormalizationHelper.class);
            this.helper = helperClass.newInstance();
            log.debug("CudnnBatchNormalizationHelper successfully loaded");
        } catch (Throwable failure) {
            // Only a missing class is expected; everything else is worth a warning.
            if (!(failure instanceof ClassNotFoundException)) {
                log.warn("Could not load CudnnBatchNormalizationHelper", failure);
            }
        }
    }

    /**
     * Reports the L2 term for this layer, which is always zero here.
     *
     * @param z ignored
     * @return {@code 0.0}
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public double calcL2(boolean z) {
        return 0.0;
    }

    /**
     * Reports the L1 term for this layer, which is always zero here.
     *
     * @param z ignored
     * @return {@code 0.0}
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public double calcL1(boolean z) {
        return 0.0;
    }

    /**
     * Identifies the kind of layer.
     *
     * @return always {@code Layer.Type.NORMALIZATION}
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public Layer.Type type() {
        return Layer.Type.NORMALIZATION;
    }

    /**
     * Not implemented for this layer.
     *
     * @param iNDArray ignored
     * @return always {@code null}
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public Gradient error(INDArray iNDArray) {
        return null;
    }

    /**
     * Not implemented for this layer.
     *
     * @param gradient ignored
     * @param iNDArray ignored
     * @return always {@code null}
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public Gradient calcGradient(Gradient gradient, INDArray iNDArray) {
        return null;
    }

    /**
     * Backpropagates through the normalization using the {@link #xHat}, {@link #xMu} and
     * {@link #std} arrays cached by the preceding {@code preOutput} call.
     *
     * <p>If a helper is present and the incoming array has rank 4, the work is first offered to the
     * helper; a {@code null} result from the helper falls through to the built-in implementation.
     * The built-in path supports rank 2 (reduction over dimension 0) and rank 4 (reduction over
     * dimensions 0, 2 and 3). With {@code S} denoting that reduction and {@code N} the number of
     * reduced elements per feature it computes:
     * <pre>
     *   dGamma = S(eps * xHat)            dBeta = S(eps)
     *   dxhat  = eps * gamma
     *   dLdVar = S(dxhat * xMu) * -0.5 * std^-3
     *   dLdmu  = -S(dxhat) / std + (-2/N) * S(xMu) * dLdVar
     *   dLdx   = dxhat / std + xMu * (2/N) * dLdVar + dLdmu / N
     * </pre>
     * The gradient entries for the global mean and variance are set to zero. When gamma/beta are
     * locked, the gamma/beta gradients are written to freshly allocated arrays instead of the
     * gradient views.
     *
     * <p>Note: {@link #xMu} is overwritten in place while computing {@code dLdx}, so the cached
     * forward-pass state is not reusable after this call.
     *
     * @param iNDArray error signal arriving from the layer above
     * @return the parameter gradients paired with the error signal for the layer below
     * @throws IllegalStateException if the array is neither rank 2 nor rank 4
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public Pair<Gradient, INDArray> backpropGradient(INDArray iNDArray) {
        INDArray epsilon = iNDArray;
        int[] shape = getShape(epsilon);
        int batchSize = epsilon.size(0);
        org.deeplearning4j.nn.conf.layers.BatchNormalization layerConf = layerConf();
        boolean lockGammaBeta = layerConf.isLockGammaBeta();

        INDArray gamma = null;
        INDArray dGammaView;
        INDArray dBetaView;
        INDArray dGlobalMeanView = this.gradientViews.get(BatchNormalizationParamInitializer.GLOBAL_MEAN);
        INDArray dGlobalVarView = this.gradientViews.get(BatchNormalizationParamInitializer.GLOBAL_VAR);
        if (lockGammaBeta) {
            // No gamma/beta gradient views are looked up in this mode; use scratch arrays.
            int[] vectorShape = {1, shape[1]};
            dGammaView = Nd4j.createUninitialized(vectorShape, 'c');
            dBetaView = Nd4j.createUninitialized(vectorShape, 'c');
        } else {
            gamma = getParam(BatchNormalizationParamInitializer.GAMMA);
            dGammaView = this.gradientViews.get(BatchNormalizationParamInitializer.GAMMA);
            dBetaView = this.gradientViews.get(BatchNormalizationParamInitializer.BETA);
        }

        if (this.helper != null && epsilon.rank() == 4) {
            if (lockGammaBeta) {
                // The helper takes gamma as an array, so materialize the fixed value.
                gamma = Nd4j.valueArrayOf(new int[]{1, shape[1]}, layerConf.getGamma());
            }
            Pair<Gradient, INDArray> helperResult = this.helper.backpropGradient(
                    this.input, epsilon, shape, gamma, dGammaView, dBetaView, layerConf.getEps());
            if (helperResult != null) {
                return helperResult;
            }
        }

        INDArray dGamma;
        INDArray dBeta;
        INDArray dLdx;
        if (epsilon.rank() == 2) {
            dGamma = epsilon.mul(this.xHat).sum(new int[]{0});
            dBeta = epsilon.sum(new int[]{0});
            INDArray dxhat = lockGammaBeta
                    ? epsilon.mul(Double.valueOf(layerConf.getGamma()))
                    : epsilon.mulRowVector(gamma);
            INDArray dLdVar = dxhat.mul(this.xMu).sum(new int[]{0})
                    .muli(Double.valueOf(-0.5d))
                    .muli(Transforms.pow(this.std, Double.valueOf(-3.0d), true));

            // dLdmu must be fully computed BEFORE the in-place chain below: that chain overwrites
            // dxhat, xMu and dLdVar, and Java evaluates call arguments left to right, so inlining
            // this expression as the last argument would read the already-mutated arrays.
            INDArray dxmu1 = dxhat.sum(new int[]{0}).divi(this.std).negi();
            INDArray dxmu2 = this.xMu.sum(new int[]{0}).muli(Double.valueOf((-2.0d) / batchSize)).muli(dLdVar);
            INDArray dLdmu = dxmu1.addi(dxmu2);

            // Array reuse: dxhat, xMu, dLdVar and dLdmu are all invalid after this statement.
            dLdx = dxhat.diviRowVector(this.std)
                    .addi(this.xMu.muliRowVector(dLdVar.muli(Double.valueOf(2.0d / batchSize))))
                    .addiRowVector(dLdmu.muli(Double.valueOf(1.0d / batchSize)));
        } else if (epsilon.rank() == 4) {
            if (!Shape.strideDescendingCAscendingF(epsilon)) {
                epsilon = epsilon.dup();
            }
            dGamma = epsilon.mul(this.xHat).sum(new int[]{0, 2, 3});
            dBeta = epsilon.sum(new int[]{0, 2, 3});
            INDArray dxhat = lockGammaBeta
                    ? epsilon.mul(Double.valueOf(layerConf.getGamma()))
                    : Nd4j.getExecutioner().execAndReturn(new BroadcastMulOp(
                            epsilon, gamma, Nd4j.createUninitialized(epsilon.shape(), epsilon.ordering()), new int[]{1}));
            INDArray dLdVar = dxhat.mul(this.xMu).sum(new int[]{0, 2, 3})
                    .muli(Double.valueOf(-0.5d))
                    .muli(Transforms.pow(this.std, Double.valueOf(-3.0d), true));

            // Number of elements reduced per feature map.
            int effectiveBatchSize = this.input.size(0) * this.input.size(2) * this.input.size(3);
            INDArray dLdmu = dxhat.sum(new int[]{0, 2, 3}).divi(this.std).negi()
                    .addi(this.xMu.sum(new int[]{0, 2, 3}).muli(Double.valueOf((-2.0d) / effectiveBatchSize)).muli(dLdVar));

            // Array reuse: dxhat and xMu are used as their own result buffers here.
            dLdx = Nd4j.getExecutioner().execAndReturn(new BroadcastDivOp(dxhat, this.std, dxhat, new int[]{1}))
                    .addi(Nd4j.getExecutioner().execAndReturn(new BroadcastMulOp(
                            this.xMu, dLdVar.muli(Double.valueOf(2.0d / effectiveBatchSize)), this.xMu, new int[]{1})));
            Nd4j.getExecutioner().execAndReturn(new BroadcastAddOp(
                    dLdx, dLdmu.muli(Double.valueOf(1.0d / effectiveBatchSize)), dLdx, new int[]{1}));
        } else {
            throw new IllegalStateException("The layer prior to BatchNorm in the configuration is not currently supported.");
        }

        // Publish the results (identical for both supported ranks).
        DefaultGradient result = new DefaultGradient();
        dGammaView.assign(dGamma);
        dBetaView.assign(dBeta);
        result.setGradientFor(BatchNormalizationParamInitializer.GAMMA, dGammaView);
        result.setGradientFor(BatchNormalizationParamInitializer.BETA, dBetaView);
        dGlobalMeanView.assign(0);
        dGlobalVarView.assign(0);
        result.setGradientFor(BatchNormalizationParamInitializer.GLOBAL_MEAN, dGlobalMeanView);
        result.setGradientFor(BatchNormalizationParamInitializer.GLOBAL_VAR, dGlobalVarView);
        return new Pair<>(result, dLdx);
    }

    /**
     * Unsupported for this layer.
     *
     * @param layer ignored
     * @param i ignored
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public void merge(Layer layer, int i) {
        throw new UnsupportedOperationException();
    }

    /**
     * Intentionally does nothing.
     *
     * @param iNDArray ignored
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Model
    public void fit(INDArray iNDArray) {
    }

    /**
     * Runs {@code preOutput} on the layer's stored {@code input}.
     *
     * @param z {@code true} selects {@code TrainingMode.TRAIN}, {@code false} selects {@code TrainingMode.TEST}
     * @return whatever {@code preOutput(input, mode)} returns
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activate(boolean z) {
        Layer.TrainingMode mode;
        if (z) {
            mode = Layer.TrainingMode.TRAIN;
        } else {
            mode = Layer.TrainingMode.TEST;
        }
        return preOutput(this.input, mode);
    }

    /**
     * Returns the {@code gradient} field (not declared in this class; presumably inherited from
     * {@code BaseLayer}).
     *
     * @return the current value of {@code gradient}
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Model
    public Gradient gradient() {
        return gradient;
    }

    /**
     * Convenience overload: delegates to {@code preOutput(iNDArray, TrainingMode.TRAIN)}.
     *
     * @param iNDArray array to normalize
     * @return the result of the training-mode {@code preOutput} call
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray preOutput(INDArray iNDArray) {
        return preOutput(iNDArray, Layer.TrainingMode.TRAIN);
    }

    /**
     * Normalizes the given activations and applies the scale (gamma) and shift (beta).
     *
     * <p>In training mode the mean and variance are computed from the array itself (over
     * dimension 0 for rank 2, over dimensions 0, 2 and 3 for rank 4) and the configured epsilon is
     * added to the variance. Otherwise the stored global mean and variance parameters are used.
     * {@link #std}, {@link #xMu} and {@link #xHat} are cached for the backward pass.
     *
     * <p>If a helper is present and the array has rank 4, the work is first offered to the helper;
     * a {@code null} result falls through to the built-in implementation. On the built-in
     * training path the global mean/variance parameters are updated afterwards: as a decayed
     * moving average when {@code isMinibatch()} is true, otherwise by direct assignment.
     *
     * <p>When gamma/beta are locked to 1 and 0 the returned array is the cached {@link #xHat}
     * instance itself rather than a copy.
     *
     * @param iNDArray activations to normalize
     * @param trainingMode {@code TRAIN} to use batch statistics and update the global estimates;
     *                     any other value uses the stored global statistics
     * @return the normalized, scaled and shifted activations
     * @throws IllegalStateException if the array is neither rank 2 nor rank 4
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray preOutput(INDArray iNDArray, Layer.TrainingMode trainingMode) {
        org.deeplearning4j.nn.conf.layers.BatchNormalization layerConf = layerConf();
        int[] shape = getShape(iNDArray);
        boolean training = trainingMode == Layer.TrainingMode.TRAIN;
        int rank = iNDArray.rank();

        INDArray mean;
        INDArray variance;
        if (training) {
            switch (rank) {
                case 2:
                    mean = iNDArray.mean(new int[]{0});
                    variance = iNDArray.var(false, new int[]{0});
                    break;
                case 4:
                    mean = iNDArray.mean(new int[]{0, 2, 3});
                    variance = iNDArray.var(false, new int[]{0, 2, 3});
                    break;
                default:
                    throw new IllegalStateException("Batch normalization on activations of rank " + rank + " not supported");
            }
            variance.addi(Double.valueOf(layerConf.getEps()));
        } else {
            mean = getParam(BatchNormalizationParamInitializer.GLOBAL_MEAN);
            variance = getParam(BatchNormalizationParamInitializer.GLOBAL_VAR);
        }
        this.std = Transforms.sqrt(variance, true);

        // The helper is only consulted for rank-4 input, matching backpropGradient. The rank of the
        // array actually being processed is used (not the stored input field), since that array is
        // what gets handed to the helper.
        boolean useHelper = this.helper != null && rank == 4;

        INDArray gamma = null;
        INDArray beta = null;
        INDArray globalMeanView = getParam(BatchNormalizationParamInitializer.GLOBAL_MEAN);
        INDArray globalVarView = getParam(BatchNormalizationParamInitializer.GLOBAL_VAR);
        boolean lockGammaBeta = layerConf.isLockGammaBeta();
        if (!lockGammaBeta) {
            gamma = getParam(BatchNormalizationParamInitializer.GAMMA);
            beta = getParam(BatchNormalizationParamInitializer.BETA);
        } else if (useHelper) {
            // The helper takes gamma/beta as arrays, so materialize the fixed values.
            int[] vectorShape = {1, layerConf().getNOut()};
            gamma = Nd4j.valueArrayOf(vectorShape, layerConf().getGamma());
            beta = Nd4j.valueArrayOf(vectorShape, layerConf().getBeta());
        }

        if (useHelper) {
            INDArray helperOutput = this.helper.preOutput(iNDArray, training, shape, gamma, beta,
                    globalMeanView, globalVarView, layerConf.getDecay(), layerConf.getEps());
            if (helperOutput != null) {
                return helperOutput;
            }
        }

        INDArray activations;
        if (rank == 2) {
            this.xMu = iNDArray.subRowVector(mean);
            this.xHat = this.xMu.divRowVector(this.std);
            if (lockGammaBeta) {
                activations = applyLockedGammaBeta(this.xHat, layerConf.getGamma(), layerConf.getBeta());
            } else {
                activations = this.xHat.mulRowVector(gamma).addiRowVector(beta);
            }
        } else if (rank == 4) {
            INDArray x = Shape.strideDescendingCAscendingF(iNDArray) ? iNDArray : iNDArray.dup();
            this.xMu = Nd4j.getExecutioner().execAndReturn(new BroadcastSubOp(
                    x, mean, Nd4j.createUninitialized(x.shape(), x.ordering()), new int[]{1}));
            this.xHat = Nd4j.getExecutioner().execAndReturn(new BroadcastDivOp(
                    this.xMu, this.std, Nd4j.createUninitialized(x.shape(), x.ordering()), new int[]{1}));
            if (lockGammaBeta) {
                activations = applyLockedGammaBeta(this.xHat, layerConf.getGamma(), layerConf.getBeta());
            } else {
                INDArray scaled = Nd4j.getExecutioner().execAndReturn(new BroadcastMulOp(
                        this.xHat, gamma, Nd4j.createUninitialized(x.shape(), x.ordering()), new int[]{1}));
                activations = Nd4j.getExecutioner().execAndReturn(new BroadcastAddOp(scaled, beta, scaled, new int[]{1}));
            }
        } else {
            throw new IllegalStateException("The layer prior to BatchNorm in the configuration is not currently supported.");
        }

        if (training) {
            if (layerConf.isMinibatch()) {
                // global = decay * global + (1 - decay) * batch. The in-place muli on mean/variance
                // is safe: in training mode they are batch statistics, not the global arrays.
                double decay = layerConf.getDecay();
                globalMeanView.muli(Double.valueOf(decay)).addi(mean.muli(Double.valueOf(1.0d - decay)));
                globalVarView.muli(Double.valueOf(decay)).addi(variance.muli(Double.valueOf(1.0d - decay)));
            } else {
                globalMeanView.assign(mean);
                globalVarView.assign(variance);
            }
        }
        return activations;
    }

    /**
     * Applies a fixed scalar scale and shift to normalized activations.
     *
     * <p>The work is skipped only when BOTH values are the identity (gamma 1 and beta 0); if either
     * differs, {@code normalized * gamma + beta} is computed into a new array.
     */
    private static INDArray applyLockedGammaBeta(INDArray normalized, double gamma, double beta) {
        if (gamma == 1.0d && beta == 0.0d) {
            return normalized;
        }
        return normalized.mul(Double.valueOf(gamma)).addi(Double.valueOf(beta));
    }

    /**
     * Unsupported for this layer.
     *
     * @param trainingMode ignored
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activate(Layer.TrainingMode trainingMode) {
        throw new UnsupportedOperationException();
    }

    /**
     * Delegates directly to {@code preOutput(iNDArray, trainingMode)}; this method itself does not
     * store the array in the layer's {@code input} field.
     *
     * @param iNDArray array to normalize
     * @param trainingMode mode forwarded unchanged to {@code preOutput}
     * @return the result of {@code preOutput}
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activate(INDArray iNDArray, Layer.TrainingMode trainingMode) {
        return preOutput(iNDArray, trainingMode);
    }

    /**
     * Boolean-flag overload of {@code preOutput}.
     *
     * @param iNDArray array to normalize
     * @param z {@code true} selects {@code TrainingMode.TRAIN}, {@code false} selects {@code TrainingMode.TEST}
     * @return the result of {@code preOutput(iNDArray, mode)}
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray preOutput(INDArray iNDArray, boolean z) {
        if (z) {
            return preOutput(iNDArray, Layer.TrainingMode.TRAIN);
        }
        return preOutput(iNDArray, Layer.TrainingMode.TEST);
    }

    /**
     * Unsupported for this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public Layer transpose() {
        throw new UnsupportedOperationException();
    }

    /**
     * Unsupported for this layer. The method name is a decompiler rename of {@code clone}
     * (see the note below).
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public Layer mo83clone() {
        throw new UnsupportedOperationException();
    }

    /**
     * Returns the listener list held by this layer. The internal list itself is returned, not a copy.
     *
     * @return the current listeners
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public Collection<IterationListener> getListeners() {
        return listeners;
    }

    /**
     * Replaces the current listeners with a new mutable list containing the given ones.
     *
     * @param iterationListenerArr listeners to register; the array is copied, so later changes to
     *                             it do not affect this layer
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer, org.deeplearning4j.nn.api.Model
    public void setListeners(IterationListener... iterationListenerArr) {
        // Diamond instead of the raw type so the element type is checked at compile time.
        this.listeners = new ArrayList<>(Arrays.asList(iterationListenerArr));
    }

    /**
     * Stores the given index value for this layer.
     *
     * @param i new index value
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public void setIndex(int i) {
        index = i;
    }

    /**
     * Returns the index value last stored via {@code setIndex(int)} (0 if never set).
     *
     * @return the stored index
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public int getIndex() {
        return index;
    }

    /**
     * Reports whether this is a pretrain layer.
     *
     * @return always {@code false}
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public boolean isPretrainLayer() {
        return false;
    }

    /**
     * Derives the {@code [1, n]} shape used for the per-feature vectors of this layer.
     *
     * <ul>
     *   <li>rank 2 or 4: {@code n} is the size of dimension 1;</li>
     *   <li>rank 3: {@code n} is size(1) * size(2), unless size(0) is greater than 1 and that
     *       product equals the array length, which is rejected;</li>
     *   <li>any other rank is rejected.</li>
     * </ul>
     *
     * @param iNDArray array whose shape is inspected
     * @return a new two-element array {@code {1, n}}
     * @throws IllegalStateException if the rank is not 2, 3 or 4
     * @throws IllegalArgumentException for the rejected rank-3 case described above
     */
    public int[] getShape(INDArray iNDArray) {
        final int rank = iNDArray.rank();
        switch (rank) {
            case 2:
            case 4:
                return new int[]{1, iNDArray.size(1)};
            case 3: {
                int featureCount = iNDArray.size(1) * iNDArray.size(2);
                if (iNDArray.size(0) > 1 && featureCount == iNDArray.length()) {
                    throw new IllegalArgumentException("Illegal input for batch size");
                }
                return new int[]{1, featureCount};
            }
            default:
                throw new IllegalStateException("Unable to process input of rank " + rank);
        }
    }
}
