package org.deeplearning4j.nn.layers.convolution;

import java.util.Arrays;
import org.deeplearning4j.eval.EvaluationBinary;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.nn.layers.LayerHelper;
import org.deeplearning4j.nn.layers.mkldnn.MKLDNNConvHelper;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.util.ConvolutionUtils;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.convolution.Convolution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.util.OneTimeLogger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.class */
/**
 * 2D convolution layer implementation.
 *
 * <p>The forward pass ({@link #preOutput(boolean, boolean, LayerWorkspaceMgr)}) and the backward pass
 * ({@link #backpropGradient(INDArray, LayerWorkspaceMgr)}) first try a platform {@link ConvolutionHelper}
 * (cuDNN on the CUDA backend, MKL-DNN on the CPU backend). If no helper is available, or a helper call
 * failed and fallback is allowed by the layer configuration, the built-in im2col + matrix-multiply
 * implementation is used instead.
 *
 * <p>As stated by the input validation messages, the input is expected to be a rank 4 array with shape
 * {@code [minibatchSize, layerInputDepth, inputHeight, inputWidth]}.
 */
public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.ConvolutionLayer> {
    protected static final Logger log = LoggerFactory.getLogger(ConvolutionLayer.class);

    /** Marker text identifying helper exceptions that must never trigger the built-in fallback. */
    private static final String ALLOCATION_FAILURE_MARKER = "Failed to allocate";

    /** Cached 2d im2col array from the last training forward pass (only set when the FF cache is active). */
    protected INDArray i2d;
    /** Platform helper, or {@code null} when the built-in implementation should always be used. */
    protected ConvolutionHelper helper = null;
    /** Number of helper calls that failed and were recovered from by falling back. */
    protected int helperCountFail = 0;
    protected ConvolutionMode convolutionMode;
    /** Placeholder bias passed to the helper when the layer is configured without a bias. */
    protected transient INDArray dummyBias;
    /** Placeholder bias gradient passed to the helper when the layer is configured without a bias. */
    protected transient INDArray dummyBiasGrad;

    /**
     * Creates the layer, initializes the platform helper and reads the convolution mode from the
     * layer configuration.
     *
     * @param neuralNetConfiguration configuration holding the convolution layer settings
     * @param dataType               data type used for parameters and activations
     */
    public ConvolutionLayer(NeuralNetConfiguration neuralNetConfiguration, DataType dataType) {
        super(neuralNetConfiguration, dataType);
        initializeHelper();
        this.convolutionMode = ((org.deeplearning4j.nn.conf.layers.ConvolutionLayer) conf().getLayer()).getConvolutionMode();
    }

    /**
     * Selects the platform helper based on the "backend" property reported by the executioner:
     * cuDNN (loaded reflectively) for "CUDA", MKL-DNN for "CPU". The helper is discarded again when
     * it reports that it is not supported.
     */
    void initializeHelper() {
        String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
        if ("CUDA".equalsIgnoreCase(backend)) {
            try {
                this.helper = Class.forName("org.deeplearning4j.nn.layers.convolution.CudnnConvolutionHelper")
                        .asSubclass(ConvolutionHelper.class)
                        .getConstructor(DataType.class)
                        .newInstance(this.dataType);
                log.debug("CudnnConvolutionHelper successfully initialized");
                if (!this.helper.checkSupported()) {
                    this.helper = null;
                }
            } catch (Throwable t) {
                // Never keep a helper whose initialization or support check failed; otherwise the
                // support check below would run again on it, this time outside of any handler.
                this.helper = null;
                if (t instanceof ClassNotFoundException) {
                    OneTimeLogger.info(log, "cuDNN not found: use cuDNN for better GPU performance by including the deeplearning4j-cuda module. For more information, please refer to: https://deeplearning4j.org/docs/latest/deeplearning4j-config-cudnn", t);
                } else {
                    log.warn("Could not initialize CudnnConvolutionHelper", t);
                }
            }
        } else if ("CPU".equalsIgnoreCase(backend)) {
            this.helper = new MKLDNNConvHelper(this.dataType);
            log.debug("Created MKLDNNConvHelper, layer {}", layerConf().getLayerName());
        }

        if (this.helper != null && !this.helper.checkSupported()) {
            log.debug("Removed helper {} as not supported", this.helper.getClass());
            this.helper = null;
        }
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.CONVOLUTIONAL;
    }

    /**
     * Computes the parameter gradients and the gradient with respect to the layer input.
     *
     * <p>The platform helper is tried first (see {@link #shouldUseHelper()}); when it is unavailable or
     * fails recoverably, the gradients are computed with im2col / gemm / col2im.
     *
     * @param iNDArray          gradient arriving from the layer above; it is passed through the
     *                          activation function's backprop together with the pre-activations
     * @param layerWorkspaceMgr workspace manager used for parameter noise and for allocating the result
     * @return pair of (gradients for "W" and, if the layer has a bias, "b"; gradient w.r.t. the input
     *         after dropout backprop)
     * @throws RuntimeException if the helper fails and fallback is disabled in the configuration
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray iNDArray, LayerWorkspaceMgr layerWorkspaceMgr) {
        assertInputSet(true);
        INDArray weights = getParamWithNoise("W", true, layerWorkspaceMgr);
        INDArray bias = getParamWithNoise("b", true, layerWorkspaceMgr);
        INDArray castInput = this.input.castTo(this.dataType);

        int miniBatch = (int) castInput.size(0);
        int inH = (int) castInput.size(2);
        int inW = (int) castInput.size(3);
        int outDepth = (int) weights.size(0);
        int inDepth = (int) weights.size(1);
        int kH = (int) weights.size(2);
        int kW = (int) weights.size(3);

        int[] dilation = layerConf().getDilation();
        int[] kernelSize = layerConf().getKernelSize();
        int[] stride = layerConf().getStride();
        int[][] padAndOut = paddingAndOutputSize(castInput, kernelSize, stride, dilation);
        int[] padding = padAndOut[0];
        int outH = padAndOut[1][0];
        int outW = padAndOut[1][1];

        INDArray biasGradView = this.gradientViews.get("b");
        INDArray weightGradView = this.gradientViews.get("W");
        // 2d view onto the weight gradient buffer; gemm below writes the result directly into it.
        INDArray weightGradView2d = Shape.newShapeNoCopy(weightGradView, new int[]{outDepth, inDepth * kH * kW}, false).transpose();

        IActivation activationFn = layerConf().getActivationFn();
        Pair<INDArray, INDArray> forward = preOutput4d(true, true, layerWorkspaceMgr);
        INDArray delta = activationFn.backprop(forward.getFirst(), iNDArray).getFirst();

        if (shouldUseHelper()) {
            if (!hasBias() && !(this.helper instanceof MKLDNNConvHelper)) {
                if (this.dummyBiasGrad == null) {
                    this.dummyBiasGrad = createDummyBiasArray();
                }
                biasGradView = this.dummyBiasGrad;
            }

            Pair<Gradient, INDArray> helperResult = null;
            try {
                helperResult = this.helper.backpropGradient(castInput, weights, bias, delta, kernelSize, stride, padding,
                        biasGradView, weightGradView, activationFn, layerConf().getCudnnAlgoMode(),
                        layerConf().getCudnnBwdFilterAlgo(), layerConf().getCudnnBwdDataAlgo(),
                        this.convolutionMode, dilation, layerWorkspaceMgr);
            } catch (Exception e) {
                if (isAllocationFailure(e)) {
                    throw e;
                }
                handleHelperFailure(e, "Error during ConvolutionLayer MKL/CuDNN helper backprop - isCudnnAllowFallback() is set to false");
            }

            if (helperResult != null) {
                helperResult.setSecond(backpropDropOutIfPresent(helperResult.getRight()));
                return helperResult;
            }
        }

        // Built-in implementation.
        INDArray delta2d = delta.permute(1, 0, 2, 3).reshape('c', new int[]{outDepth, miniBatch * outH * outW});

        INDArray im2col2d = forward.getSecond();
        if (im2col2d == null) {
            // The forward pass did not provide the im2col array (e.g. it went through the helper): recompute it.
            INDArray col = Nd4j.createUninitialized(this.dataType, new long[]{miniBatch, outH, outW, inDepth, kH, kW}, 'c');
            INDArray colPermuted = col.permute(0, 3, 4, 5, 1, 2);
            Convolution.im2col(castInput, kH, kW, stride[0], stride[1], padding[0], padding[1], dilation[0], dilation[1],
                    this.convolutionMode == ConvolutionMode.Same, colPermuted);
            im2col2d = col.reshape('c', miniBatch * outH * outW, inDepth * kH * kW);
        }

        // Weight gradient: alpha = 1, beta = 0, written into the gradient view.
        Nd4j.gemm(im2col2d, delta2d, weightGradView2d, true, true, 1.0d, 0.0d);

        // Gradient w.r.t. the input.
        INDArray weights2d = weights.permute(3, 2, 1, 0).reshape('f', inDepth * kH * kW, outDepth);
        INDArray epsNext2d = weights2d.mmul(delta2d);
        INDArray epsNext6d = Shape.newShapeNoCopy(epsNext2d, new int[]{kW, kH, inDepth, outW, outH, miniBatch}, true)
                .permute(5, 2, 1, 0, 4, 3);
        INDArray epsNext = layerWorkspaceMgr
                .createUninitialized(ArrayType.ACTIVATION_GRAD, epsNext6d.dataType(), new long[]{inDepth, miniBatch, inH, inW}, 'c')
                .permute(1, 0, 2, 3);
        Convolution.col2im(epsNext6d, epsNext, stride[0], stride[1], padding[0], padding[1], inH, inW, dilation[0], dilation[1]);

        DefaultGradient gradient = new DefaultGradient();
        if (layerConf().hasBias()) {
            delta2d.sum(biasGradView, 1);
            gradient.setGradientFor("b", biasGradView);
        }
        gradient.setGradientFor("W", weightGradView, 'c');

        this.weightNoiseParams.clear();
        return new Pair<>(gradient, backpropDropOutIfPresent(epsNext));
    }

    /**
     * Hook used by {@link #backpropGradient(INDArray, LayerWorkspaceMgr)} to obtain the pre-activations;
     * this implementation simply delegates to {@link #preOutput(boolean, boolean, LayerWorkspaceMgr)}.
     *
     * @param z                 passed through as the first flag of {@code preOutput}
     * @param z2                passed through as the second flag of {@code preOutput}
     * @param layerWorkspaceMgr workspace manager
     * @return the result of {@code preOutput}
     */
    public Pair<INDArray, INDArray> preOutput4d(boolean z, boolean z2, LayerWorkspaceMgr layerWorkspaceMgr) {
        return preOutput(z, z2, layerWorkspaceMgr);
    }

    /**
     * Verifies that the current input is a rank 4 array.
     *
     * @throws DL4JInvalidInputException if the input rank is not 4
     */
    protected void validateInputRank() {
        if (this.input.rank() == 4) {
            return;
        }
        String layerName = this.conf.getLayer().getLayerName();
        if (layerName == null) {
            layerName = "(not named)";
        }
        String rank2Hint = this.input.rank() == 2
                ? " (Wrong input type (see InputType.convolutionalFlat()) or wrong data type?)"
                : "";
        throw new DL4JInvalidInputException("Got rank " + this.input.rank() + " array as input to ConvolutionLayer (layer name = " + layerName + ", layer index = " + this.index + ") with shape " + Arrays.toString(this.input.shape()) + ". Expected rank 4 array with shape [minibatchSize, layerInputDepth, inputHeight, inputWidth]." + rank2Hint + " " + layerId());
    }

    /**
     * Verifies that dimension 1 of the current input equals the expected number of channels.
     *
     * @param i expected number of input channels
     * @throws DL4JInvalidInputException if the input has a different number of channels
     */
    protected void validateInputDepth(int i) {
        if (this.input.size(1) == i) {
            return;
        }
        String layerName = this.conf.getLayer().getLayerName();
        if (layerName == null) {
            layerName = "(not named)";
        }
        throw new DL4JInvalidInputException("Cannot do forward pass in Convolution layer (layer name = " + layerName + ", layer index = " + this.index + "): input array channels does not match CNN layer configuration (data input channels = " + this.input.size(1) + ", [minibatch,inputDepth,height,width]=" + Arrays.toString(this.input.shape()) + "; expected input channels = " + i + ") " + layerId());
    }

    /**
     * Computes the pre-activation output of the layer.
     *
     * @param z                 training flag: forwarded to the parameter-noise lookup and, when the
     *                          FF cache workspace is configured and open, enables caching of the
     *                          im2col array in {@link #i2d}
     * @param z2                "for backprop" flag: when {@code true}, previously cached results are
     *                          reused if present and the 2d im2col array is returned as the second
     *                          element of the pair
     * @param layerWorkspaceMgr workspace manager used to allocate the activations
     * @return pair of (pre-activations, 2d im2col array or {@code null}); the second element is always
     *         {@code null} when the result comes from the platform helper
     * @throws DL4JInvalidInputException if the input rank or channel count is invalid
     * @throws RuntimeException          if the helper fails and fallback is disabled in the configuration
     */
    public Pair<INDArray, INDArray> preOutput(boolean z, boolean z2, LayerWorkspaceMgr layerWorkspaceMgr) {
        assertInputSet(false);
        INDArray bias = getParamWithNoise("b", z, layerWorkspaceMgr);
        INDArray weights = getParamWithNoise("W", z, layerWorkspaceMgr);
        validateInputRank();
        INDArray castInput = this.input.castTo(this.dataType);

        int miniBatch = (int) castInput.size(0);
        int outDepth = (int) weights.size(0);
        int inDepth = (int) weights.size(1);
        validateInputDepth(inDepth);
        int kH = (int) weights.size(2);
        int kW = (int) weights.size(3);

        int[] dilation = layerConf().getDilation();
        int[] kernelSize = layerConf().getKernelSize();
        int[] stride = layerConf().getStride();
        int[][] padAndOut = paddingAndOutputSize(castInput, kernelSize, stride, dilation);
        int[] padding = padAndOut[0];
        int outH = padAndOut[1][0];
        int outW = padAndOut[1][1];

        if (shouldUseHelper()) {
            if (this.preOutput != null && z2) {
                return new Pair<>(this.preOutput, null);
            }
            if (!hasBias()) {
                if (this.dummyBias == null) {
                    this.dummyBias = createDummyBiasArray();
                }
                bias = this.dummyBias;
            }

            INDArray helperOutput = null;
            try {
                helperOutput = this.helper.preOutput(castInput, weights, bias, kernelSize, stride, padding,
                        layerConf().getCudnnAlgoMode(), layerConf().getCudnnFwdAlgo(),
                        this.convolutionMode, dilation, layerWorkspaceMgr);
            } catch (Exception e) {
                if (isAllocationFailure(e)) {
                    throw e;
                }
                handleHelperFailure(e, "Error during ConvolutionLayer MKL/CuDNN helper forward pass - isCudnnAllowFallback() is set to false");
            }

            if (helperOutput != null) {
                return new Pair<>(helperOutput, null);
            }
        }

        // Built-in implementation.
        if (this.preOutput != null && this.i2d != null && z2) {
            return new Pair<>(this.preOutput, this.i2d);
        }

        INDArray col = Nd4j.createUninitialized(weights.dataType(), new long[]{miniBatch, outH, outW, inDepth, kH, kW}, 'c');
        INDArray colPermuted = col.permute(0, 3, 4, 5, 1, 2);
        Convolution.im2col(castInput.castTo(colPermuted.dataType()), kH, kW, stride[0], stride[1], padding[0], padding[1],
                dilation[0], dilation[1], this.convolutionMode == ConvolutionMode.Same, colPermuted);

        INDArray im2col2d = Shape.newShapeNoCopy(col, new int[]{miniBatch * outH * outW, inDepth * kH * kW}, false);
        INDArray weights2d = weights.permute(3, 2, 1, 0).reshape('f', kW * kH * inDepth, outDepth);

        INDArray z2d = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, weights.dataType(),
                new long[]{im2col2d.size(0), weights2d.size(1)}, 'f');
        im2col2d.mmuli(weights2d, z2d);
        if (layerConf().hasBias()) {
            z2d.addiRowVector(bias);
        }

        // Reinterpret the 2d product as 4d and reorder to [miniBatch, outDepth, outH, outW].
        INDArray z4d = Shape.newShapeNoCopy(z2d, new int[]{outW, outH, miniBatch, outDepth}, true).permute(2, 3, 1, 0);

        if (isFeedForwardCacheActive(z, layerWorkspaceMgr)) {
            try (MemoryWorkspace ignored = layerWorkspaceMgr.notifyScopeBorrowed(ArrayType.FF_CACHE)) {
                this.i2d = im2col2d.unsafeDuplication();
            }
        }

        return new Pair<>(z4d, z2 ? im2col2d : null);
    }

    /**
     * Runs the forward pass: applies dropout if necessary, computes the pre-activations and applies
     * the configured activation function (through the helper when possible).
     *
     * @param z                 training flag
     * @param layerWorkspaceMgr workspace manager
     * @return the layer activations
     * @throws IllegalArgumentException if no input has been set
     */
    @Override
    public INDArray activate(boolean z, LayerWorkspaceMgr layerWorkspaceMgr) {
        if (this.input == null) {
            throw new IllegalArgumentException("Cannot perform forward pass with null input " + layerId());
        }
        if (this.cacheMode == null) {
            this.cacheMode = CacheMode.NONE;
        }
        applyDropOutIfNecessary(z, layerWorkspaceMgr);

        INDArray preActivation = preOutput(z, false, layerWorkspaceMgr).getFirst();

        if (isFeedForwardCacheActive(z, layerWorkspaceMgr)) {
            try (MemoryWorkspace ignored = layerWorkspaceMgr.notifyScopeBorrowed(ArrayType.FF_CACHE)) {
                this.preOutput = preActivation.unsafeDuplication();
            }
        }

        if (this.helper != null && Shape.strideDescendingCAscendingF(preActivation)) {
            INDArray helperActivation = this.helper.activate(preActivation, layerConf().getActivationFn(), z);
            if (helperActivation != null) {
                return helperActivation;
            }
        }
        return layerConf().getActivationFn().getActivation(preActivation, z);
    }

    @Override
    public boolean hasBias() {
        return layerConf().hasBias();
    }

    @Override
    public boolean isPretrainLayer() {
        return false;
    }

    @Override
    public LayerHelper getHelper() {
        return this.helper;
    }

    /**
     * Not supported for this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void fit(INDArray iNDArray, LayerWorkspaceMgr layerWorkspaceMgr) {
        throw new UnsupportedOperationException("Not supported");
    }

    /** Sets the parameters using {@code 'c'} ordering. */
    @Override
    public void setParams(INDArray iNDArray) {
        setParams(iNDArray, 'c');
    }

    /**
     * Propagates a mask array through this layer.
     *
     * @param iNDArray  mask array, may be {@code null}
     * @param maskState current mask state, returned unchanged
     * @param i         minibatch size (not used by this implementation)
     * @return the (possibly reduced) mask array together with {@code maskState}; a {@code null} mask
     *         is passed through as-is
     */
    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray iNDArray, MaskState maskState, int i) {
        if (iNDArray == null) {
            return new Pair<>(iNDArray, maskState);
        }
        INDArray reducedMask = ConvolutionUtils.cnn2dMaskReduction(iNDArray, layerConf().getKernelSize(),
                layerConf().getStride(), layerConf().getPadding(), layerConf().getDilation(),
                layerConf().getConvolutionMode());
        return new Pair<>(reducedMask, maskState);
    }

    /** Returns whether the platform helper should be attempted: it exists and has either never failed or fallback is disabled. */
    private boolean shouldUseHelper() {
        return this.helper != null && (this.helperCountFail == 0 || !layerConf().isCudnnAllowFallback());
    }

    /** Returns whether activations should be cached: training mode, caching enabled, and the FF cache workspace configured and open. */
    private boolean isFeedForwardCacheActive(boolean training, LayerWorkspaceMgr workspaceMgr) {
        return training
                && this.cacheMode != CacheMode.NONE
                && workspaceMgr.hasConfiguration(ArrayType.FF_CACHE)
                && workspaceMgr.isWorkspaceOpen(ArrayType.FF_CACHE);
    }

    /** Allocates a {@code [1, nOut]} placeholder array while workspaces are scoped out. */
    private INDArray createDummyBiasArray() {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return Nd4j.create(new long[]{1, layerConf().getNOut()});
        }
    }

    /** Returns {@code {padding, outputSize}} for the given input, honoring {@link ConvolutionMode#Same}. */
    private int[][] paddingAndOutputSize(INDArray in, int[] kernelSize, int[] stride, int[] dilation) {
        int[] padding;
        int[] outputSize;
        if (this.convolutionMode == ConvolutionMode.Same) {
            outputSize = ConvolutionUtils.getOutputSize(in, kernelSize, stride, null, this.convolutionMode, dilation);
            int[] inputHw = {(int) in.size(2), (int) in.size(3)};
            padding = ConvolutionUtils.getSameModeTopLeftPadding(outputSize, inputHw, kernelSize, stride, dilation);
        } else {
            padding = layerConf().getPadding();
            outputSize = ConvolutionUtils.getOutputSize(in, kernelSize, stride, padding, this.convolutionMode, dilation);
        }
        return new int[][]{padding, outputSize};
    }

    /** Returns whether the exception message marks an allocation failure; tolerates a {@code null} message. */
    private static boolean isAllocationFailure(Exception e) {
        String message = e.getMessage();
        return message != null && message.contains(ALLOCATION_FAILURE_MARKER);
    }

    /**
     * Handles a non-allocation helper failure: rethrows (wrapped, cause preserved) when fallback is
     * disabled, otherwise records the failure and logs that the built-in implementation is used.
     */
    private void handleHelperFailure(Exception e, String fallbackDisabledMessage) {
        if (!layerConf().isCudnnAllowFallback()) {
            throw new RuntimeException(fallbackDisabledMessage, e);
        }
        this.helperCountFail++;
        if (this.helper instanceof MKLDNNConvHelper) {
            log.warn("MKL-DNN execution failed - falling back on built-in implementation", e);
        } else {
            log.warn("CuDNN execution failed - falling back on built-in implementation", e);
        }
    }
}
