package org.nd4j.linalg.learning;

import java.io.Serializable;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:org/nd4j/linalg/learning/Adam.class */
/**
 * Adam gradient updater.
 *
 * <p>Keeps two state arrays: {@code m}, a decayed running average of the gradient, and
 * {@code v}, a decayed running average of the element-wise squared gradient.
 * {@link #getGradient(INDArray, int)} updates both and overwrites the supplied gradient
 * with the resulting step.
 *
 * <p>Both state arrays are {@code null} until
 * {@link #setStateViewArray(INDArray, int[], char, boolean)} (or
 * {@link #setM(INDArray)} / {@link #setV(INDArray)}) has been called.
 */
public class Adam implements Serializable, GradientUpdater {
    public static final double DEFAULT_ADAM_EPSILON = 1.0E-8d;
    public static final double DEFAULT_ADAM_BETA1_MEAN_DECAY = 0.9d;
    public static final double DEFAULT_ADAM_BETA2_VAR_DECAY = 0.999d;

    /** Learning rate used by the no-argument constructor. */
    private static final double DEFAULT_LEARNING_RATE = 0.001d;

    private double learningRate;
    private double beta1;
    private double beta2;
    private double epsilon;
    /** Running average of the gradient; {@code null} until state has been supplied. */
    private INDArray m;
    /** Running average of the squared gradient; {@code null} until state has been supplied. */
    private INDArray v;

    /** Creates an updater with learning rate 0.001 and the default beta1, beta2 and epsilon. */
    public Adam() {
        this(DEFAULT_LEARNING_RATE);
    }

    /**
     * Creates an updater with the default beta1, beta2 and epsilon.
     *
     * @param learningRate the learning rate
     */
    public Adam(double learningRate) {
        this(learningRate, DEFAULT_ADAM_BETA1_MEAN_DECAY, DEFAULT_ADAM_BETA2_VAR_DECAY);
    }

    /**
     * Creates an updater with the default epsilon.
     *
     * @param learningRate the learning rate
     * @param beta1        decay factor applied to {@code m}
     * @param beta2        decay factor applied to {@code v}
     */
    public Adam(double learningRate, double beta1, double beta2) {
        this(learningRate, beta1, beta2, DEFAULT_ADAM_EPSILON);
    }

    /**
     * Creates an updater with every hyper-parameter given explicitly.
     *
     * @param learningRate the learning rate
     * @param beta1        decay factor applied to {@code m}
     * @param beta2        decay factor applied to {@code v}
     * @param epsilon      constant added to {@code sqrt(v)} before dividing
     */
    public Adam(double learningRate, double beta1, double beta2, double epsilon) {
        this.learningRate = learningRate;
        this.beta1 = beta1;
        this.beta2 = beta2;
        this.epsilon = epsilon;
    }

    /**
     * Returns the number of state values needed for the given input size: one for
     * {@code m} and one for {@code v} per input element.
     *
     * @param inputSize number of input elements
     * @return {@code 2 * inputSize}
     */
    @Override
    public int stateSizeForInputSize(int inputSize) {
        return 2 * inputSize;
    }

    /**
     * Binds {@code m} and {@code v} to the two halves of the given row vector, reshaped to
     * {@code gradientShape} without copying.
     *
     * @param viewArray     row vector backing the state; first half becomes {@code m},
     *                      second half becomes {@code v}
     * @param gradientShape shape both state arrays are reshaped to
     * @param gradientOrder {@code 'f'} requests the order flag passed to the reshape to be
     *                      {@code true}; any other value passes {@code false}
     * @param initialize    if {@code true}, the whole view is set to zero first
     * @throws IllegalArgumentException if {@code viewArray} is not a row vector
     * @throws IllegalStateException    if either reshaped view comes back {@code null}
     */
    @Override
    public void setStateViewArray(INDArray viewArray, int[] gradientShape, char gradientOrder, boolean initialize) {
        if (!viewArray.isRowVector()) {
            throw new IllegalArgumentException("Invalid input: expect row vector input");
        }
        if (initialize) {
            viewArray.assign(0);
        }

        int length = viewArray.length();
        int half = length / 2;
        boolean fortranOrder = gradientOrder == 'f';

        INDArray mView = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, half));
        INDArray vView = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(half, length));

        this.m = Shape.newShapeNoCopy(mView, gradientShape, fortranOrder);
        this.v = Shape.newShapeNoCopy(vView, gradientShape, fortranOrder);
        if (this.m == null || this.v == null) {
            throw new IllegalStateException("Could not correctly reshape gradient view arrays");
        }
    }

    /**
     * Updates the learning rate from the first argument, if one is given. Any further
     * arguments are ignored; a {@code null} or empty argument array is a no-op.
     *
     * @param args optional arguments; {@code args[0]} is read as the new learning rate and
     *             may be any {@link Number}
     * @throws ClassCastException   if {@code args[0]} is not a {@link Number}
     * @throws NullPointerException if {@code args[0]} is {@code null}
     */
    @Override
    public void update(Object... args) {
        if (args != null && args.length > 0) {
            // Accept any Number: a boxed Integer or Float learning rate used to fail the Double cast.
            this.learningRate = ((Number) args[0]).doubleValue();
        }
    }

    /**
     * Applies one Adam step. Updates {@code m} and {@code v} in place and overwrites
     * {@code gradient} with {@code stepSize * m / (sqrt(v) + epsilon)}.
     *
     * @param gradient  raw gradient; overwritten with the update and returned
     * @param iteration zero-based iteration index; {@code iteration + 1} is the exponent
     *                  used for the beta powers
     * @return the same {@code gradient} instance, now holding the update
     * @throws IllegalStateException if the state arrays have not been set
     */
    @Override
    public INDArray getGradient(INDArray gradient, int iteration) {
        if (this.m == null || this.v == null) {
            throw new IllegalStateException("Updater has not been initialized with view state");
        }

        // m <- beta1 * m + (1 - beta1) * g
        this.m.muli(this.beta1).addi(gradient.mul(1.0d - this.beta1));
        // v <- beta2 * v + (1 - beta2) * g * g
        this.v.muli(this.beta2).addi(gradient.mul(gradient).muli(1.0d - this.beta2));

        int step = iteration + 1;
        double beta1Pow = FastMath.pow(this.beta1, step);
        double beta2Pow = FastMath.pow(this.beta2, step);
        double stepSize = (this.learningRate * FastMath.sqrt(1.0d - beta2Pow)) / (1.0d - beta1Pow);
        if (Double.isNaN(stepSize) || stepSize == 0.0d) {
            // Fall back to epsilon when the computed step size is NaN or exactly zero.
            stepSize = this.epsilon;
        }

        // sqrt is taken on a copy (dup = true) so that v itself is left untouched.
        INDArray denominator = Transforms.sqrt(this.v, true).addi(this.epsilon);
        gradient.assign(this.m.mul(stepSize).divi(denominator));
        return gradient;
    }

    /**
     * Creates a new aggregator for averaging several Adam updaters.
     *
     * @param addThis if {@code true}, this updater is aggregated into the result immediately
     * @return a fresh {@link AdamAggregator}
     * @throws IllegalStateException if {@code addThis} is {@code true} and this updater has
     *                               no state arrays yet
     */
    @Override
    public GradientUpdaterAggregator getAggregator(boolean addThis) {
        AdamAggregator aggregator = new AdamAggregator();
        if (addThis) {
            aggregator.aggregate(this);
        }
        return aggregator;
    }

    /** @return the current learning rate */
    public double getLearningRate() {
        return this.learningRate;
    }

    /** @return the decay factor applied to {@code m} */
    public double getBeta1() {
        return this.beta1;
    }

    /** @return the decay factor applied to {@code v} */
    public double getBeta2() {
        return this.beta2;
    }

    /** @return the constant added to {@code sqrt(v)} before dividing */
    public double getEpsilon() {
        return this.epsilon;
    }

    /** @return the gradient running average, or {@code null} if not yet set */
    public INDArray getM() {
        return this.m;
    }

    /** @return the squared-gradient running average, or {@code null} if not yet set */
    public INDArray getV() {
        return this.v;
    }

    /** @param learningRate the new learning rate */
    public void setLearningRate(double learningRate) {
        this.learningRate = learningRate;
    }

    /** @param beta1 the new decay factor for {@code m} */
    public void setBeta1(double beta1) {
        this.beta1 = beta1;
    }

    /** @param beta2 the new decay factor for {@code v} */
    public void setBeta2(double beta2) {
        this.beta2 = beta2;
    }

    /** @param epsilon the new constant added to {@code sqrt(v)} */
    public void setEpsilon(double epsilon) {
        this.epsilon = epsilon;
    }

    /** @param m the array to use as the gradient running average (stored by reference) */
    public void setM(INDArray m) {
        this.m = m;
    }

    /** @param v the array to use as the squared-gradient running average (stored by reference) */
    public void setV(INDArray v) {
        this.v = v;
    }

    /**
     * Two updaters are equal when all four hyper-parameters compare equal via
     * {@link Double#compare(double, double)} and both state arrays are equal (or both
     * {@code null}).
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof Adam)) {
            return false;
        }
        Adam other = (Adam) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        if (Double.compare(getLearningRate(), other.getLearningRate()) != 0
                || Double.compare(getBeta1(), other.getBeta1()) != 0
                || Double.compare(getBeta2(), other.getBeta2()) != 0
                || Double.compare(getEpsilon(), other.getEpsilon()) != 0) {
            return false;
        }
        return nullSafeEquals(getM(), other.getM()) && nullSafeEquals(getV(), other.getV());
    }

    /**
     * Hook that lets subclasses restrict which instances they may be equal to.
     *
     * @param obj the object on the other side of the comparison
     * @return {@code true} if {@code obj} is an {@code Adam}
     */
    protected boolean canEqual(Object obj) {
        return obj instanceof Adam;
    }

    @Override
    public int hashCode() {
        int result = 1;
        result = (result * 59) + hashDouble(getLearningRate());
        result = (result * 59) + hashDouble(getBeta1());
        result = (result * 59) + hashDouble(getBeta2());
        result = (result * 59) + hashDouble(getEpsilon());
        INDArray mState = getM();
        result = (result * 59) + (mState == null ? 43 : mState.hashCode());
        INDArray vState = getV();
        result = (result * 59) + (vState == null ? 43 : vState.hashCode());
        return result;
    }

    @Override
    public String toString() {
        return "Adam(learningRate=" + getLearningRate()
                + ", beta1=" + getBeta1()
                + ", beta2=" + getBeta2()
                + ", epsilon=" + getEpsilon()
                + ", m=" + getM()
                + ", v=" + getV() + ")";
    }

    /** Null-tolerant equality: two nulls are equal, otherwise delegates to {@code a.equals(b)}. */
    private static boolean nullSafeEquals(Object a, Object b) {
        return a == null ? b == null : a.equals(b);
    }

    /** Folds the 64-bit representation of a double into an int hash. */
    private static int hashDouble(double value) {
        long bits = Double.doubleToLongBits(value);
        return (int) ((bits >>> 32) ^ bits);
    }

    /**
     * Averages several {@link Adam} updaters. Keeps running sums of the state arrays and
     * hyper-parameters plus a count, and divides by the count in {@link #getUpdater()}.
     */
    public static class AdamAggregator implements GradientUpdaterAggregator {
        /** Sum of the aggregated {@code m} arrays; {@code null} while nothing has been aggregated. */
        private INDArray mSum;
        /** Sum of the aggregated {@code v} arrays; {@code null} while nothing has been aggregated. */
        private INDArray vSum;
        private double lrSum;
        private double beta1Sum;
        private double beta2Sum;
        private double epsilonSum;
        /** Number of updaters folded into the sums so far. */
        private int count = 0;

        /**
         * Builds an updater whose hyper-parameters and state arrays are the averages of
         * everything aggregated so far.
         *
         * @return a new {@link Adam} holding the averaged values
         * @throws IllegalStateException if nothing has been aggregated yet
         */
        @Override
        public GradientUpdater getUpdater() {
            if (this.count == 0 || this.mSum == null || this.vSum == null) {
                // Previously this divided by zero and then failed with a bare NullPointerException.
                throw new IllegalStateException("Cannot create Adam updater: no updaters have been aggregated");
            }
            Adam averaged = new Adam(
                    this.lrSum / this.count,
                    this.beta1Sum / this.count,
                    this.beta2Sum / this.count,
                    this.epsilonSum / this.count);
            averaged.setM(this.mSum.div(this.count));
            averaged.setV(this.vSum.div(this.count));
            return averaged;
        }

        /**
         * Adds one updater's state and hyper-parameters to the running sums.
         *
         * @param gradientUpdater the updater to add; must be an {@link Adam} with state set
         * @throws UnsupportedOperationException if {@code gradientUpdater} is not an {@link Adam}
         * @throws IllegalStateException         if the updater's state arrays are {@code null}
         */
        @Override
        public void aggregate(GradientUpdater gradientUpdater) {
            if (!(gradientUpdater instanceof Adam)) {
                throw new UnsupportedOperationException("Cannot aggregate Adam with updater: " + gradientUpdater);
            }
            Adam adam = (Adam) gradientUpdater;
            if (adam.m == null || adam.v == null) {
                throw new IllegalStateException("Cannot aggregate Adam updater that has not been initialized with view state");
            }
            if (this.mSum == null) {
                // First contribution: copy, so later in-place sums never touch the source updater.
                this.mSum = adam.m.dup();
                this.vSum = adam.v.dup();
                this.lrSum = adam.learningRate;
                this.beta1Sum = adam.beta1;
                this.beta2Sum = adam.beta2;
                this.epsilonSum = adam.epsilon;
            } else {
                this.mSum.addi(adam.m);
                this.vSum.addi(adam.v);
                this.lrSum += adam.learningRate;
                this.beta1Sum += adam.beta1;
                this.beta2Sum += adam.beta2;
                this.epsilonSum += adam.epsilon;
            }
            this.count++;
        }

        /**
         * Merges another aggregator's sums and count into this one. An empty aggregator on
         * either side is handled: an empty argument contributes nothing, and an empty
         * receiver takes copies of the argument's sums.
         *
         * @param gradientUpdaterAggregator the aggregator to merge in; must be an
         *                                  {@link AdamAggregator}
         * @return this aggregator
         * @throws IllegalArgumentException if the argument is not an {@link AdamAggregator}
         */
        @Override
        public GradientUpdaterAggregator combine(GradientUpdaterAggregator gradientUpdaterAggregator) {
            if (!(gradientUpdaterAggregator instanceof AdamAggregator)) {
                throw new IllegalArgumentException("Cannot combine AdamAggregator with aggregator: " + gradientUpdaterAggregator);
            }
            AdamAggregator other = (AdamAggregator) gradientUpdaterAggregator;
            if (other.mSum == null) {
                // The other aggregator never saw an updater, so there is nothing to merge.
                return this;
            }
            if (this.mSum == null) {
                // This aggregator is empty: start from copies of the other side's sums.
                this.mSum = other.mSum.dup();
                this.vSum = other.vSum.dup();
                this.lrSum = other.lrSum;
                this.beta1Sum = other.beta1Sum;
                this.beta2Sum = other.beta2Sum;
                this.epsilonSum = other.epsilonSum;
            } else {
                this.mSum.addi(other.mSum);
                this.vSum.addi(other.vSum);
                this.lrSum += other.lrSum;
                this.beta1Sum += other.beta1Sum;
                this.beta2Sum += other.beta2Sum;
                this.epsilonSum += other.epsilonSum;
            }
            this.count += other.count;
            return this;
        }
    }
}
