package org.nd4j.linalg.activations.impl;

import org.apache.commons.math3.util.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.impl.transforms.ELU;
import org.nd4j.linalg.api.ops.impl.transforms.ELUDerivative;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions;

/* loaded from: input_file:org/nd4j/linalg/activations/impl/ActivationELU.class */
/**
 * Exponential linear unit (ELU) activation with a configurable {@code alpha}.
 *
 * <p>Forward pass: {@code f(x) = x} for {@code x >= 0} and {@code alpha * (exp(x) - 1)} for
 * {@code x < 0}. For the default alpha the {@link ELU} op is applied directly; for any other alpha
 * this class runs the (alpha-less) {@link ELU} op on a copy and applies the alpha scaling itself.
 */
public class ActivationELU extends BaseActivationFunction {
    /** Alpha used by the no-argument constructor. */
    public static final double DEFAULT_ALPHA = 1.0d;

    /** Scale applied to the negative ({@code x < 0}) side of the activation and its gradient. */
    private double alpha;

    /** Creates an ELU activation with {@link #DEFAULT_ALPHA}. */
    public ActivationELU() {
        this(DEFAULT_ALPHA);
    }

    /**
     * Creates an ELU activation with the given alpha.
     *
     * @param alpha scale for the negative side of the activation; no range validation is performed
     */
    public ActivationELU(double alpha) {
        this.alpha = alpha;
    }

    /**
     * Applies ELU to {@code in} in place and returns it.
     *
     * @param in       pre-activation values; overwritten with the activations
     * @param training unused by this activation
     * @return {@code in}, now holding the activations
     */
    @Override
    public INDArray getActivation(INDArray in, boolean training) {
        if (alpha == DEFAULT_ALPHA) {
            Nd4j.getExecutioner().execAndReturn((TransformOp) new ELU(in));
            return in;
        }
        // Unit-alpha ELU on a copy, scaled by alpha, then written back only where the input is
        // negative: non-negative inputs are already their own activation.
        INDArray scaledElu = Nd4j.getExecutioner().execAndReturn((TransformOp) new ELU(in.dup()));
        scaledElu.muli(alpha);
        BooleanIndexing.replaceWhere(in, scaledElu, Conditions.lessThan(0));
        return in;
    }

    /**
     * Back-propagates {@code epsilon} (dL/da) through the activation, returning dL/dz.
     *
     * <p>The local gradient is 1 for {@code in >= 0} and {@code alpha * exp(in)} for {@code in < 0}.
     * For the default alpha the derivative op is built on {@code in} itself (no copy), so
     * {@code in} may be overwritten; for other alphas only copies of {@code in} are modified.
     *
     * @param in      pre-activation values (z)
     * @param epsilon upstream gradient dL/da; presumably the same shape as {@code in}
     * @return a pair of (dL/dz, {@code null}); the second element is always {@code null}
     */
    @Override
    public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
        if (alpha == DEFAULT_ALPHA) {
            INDArray dLdz = Nd4j.getExecutioner().execAndReturn(new ELU(in).derivative());
            dLdz.muli(epsilon);
            return new Pair<>(dLdz, null);
        }
        // Unit-alpha ELU derivative on a copy (expected: 1 where in >= 0, exp(in) where in < 0).
        INDArray dLdz = Nd4j.getExecutioner().execAndReturn((TransformOp) new ELUDerivative(in.dup()));
        // Per-element multiplier chosen from the sign of the INPUT: 1 for in >= 0, alpha for in < 0.
        // Testing the scaled gradient for equality with alpha instead would misfire whenever
        // alpha * exp(in) == alpha -- e.g. for every negative entry when alpha == 0.
        INDArray scale = in.dup();
        // Order matters: the >= 0 entries must become 1 before the < 0 entries become alpha;
        // otherwise a non-negative alpha would itself be caught by the >= 0 replacement.
        BooleanIndexing.replaceWhere(scale, 1, Conditions.greaterThanOrEqual(0));
        BooleanIndexing.replaceWhere(scale, alpha, Conditions.lessThan(0));
        dLdz.muli(scale);
        dLdz.muli(epsilon);
        return new Pair<>(dLdz, null);
    }

    @Override
    public String toString() {
        return "elu(alpha=" + this.alpha + ")";
    }

    /** Two instances are equal when their alphas compare equal under {@link Double#compare}. */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof ActivationELU)) {
            return false;
        }
        ActivationELU other = (ActivationELU) obj;
        return other.canEqual(this) && Double.compare(getAlpha(), other.getAlpha()) == 0;
    }

    /** Symmetry hook for {@link #equals(Object)} so subclasses can opt out of equality. */
    protected boolean canEqual(Object obj) {
        return obj instanceof ActivationELU;
    }

    @Override
    public int hashCode() {
        final int prime = 59;
        long bits = Double.doubleToLongBits(getAlpha());
        return prime + (int) (bits ^ (bits >>> 32));
    }

    /** @return the alpha this activation was constructed with */
    public double getAlpha() {
        return this.alpha;
    }
}
