package smile.base.mlp;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Properties;
import java.util.stream.Collectors;
import smile.math.MathEx;
import smile.math.TimeFunction;

/* loaded from: input_file:smile/base/mlp/MultilayerPerceptron.class */
/**
 * Base class of a layered feed-forward network: an {@link InputLayer}, one or
 * more hidden layers, and an {@link OutputLayer}, trained by back-propagation
 * with optional momentum, RMSProp, L2 weight decay and gradient clipping.
 */
public abstract class MultilayerPerceptron implements Serializable {
    private static final long serialVersionUID = 2L;

    /** The input size of the first (input) layer. */
    protected int p;
    /** The output layer. */
    protected OutputLayer output;
    /** The input layer followed by the hidden layers, i.e. every layer except the output layer. */
    protected Layer[] net;
    /**
     * Per-thread target buffer, sized to the output layer's output size.
     * Transient, so it is recreated by {@link #init()} after deserialization.
     */
    protected transient ThreadLocal<double[]> target;
    /** The learning rate schedule, evaluated at iteration {@link #t}. */
    protected TimeFunction learningRate = TimeFunction.constant(0.01);
    /** The momentum schedule; {@code null} means a momentum factor of 0. */
    protected TimeFunction momentum = null;
    /** RMSProp discounting factor passed to the layers' update; 0.0 by default. */
    protected double rho = 0.0;
    /** RMSProp epsilon passed to the layers' update. */
    protected double epsilon = 1E-7;
    /** L2 regularization factor used to derive the per-step weight decay. */
    protected double lambda = 0.0;
    /** Element-wise gradient clipping threshold; 0.0 disables it. */
    protected double clipValue = 0.0;
    /** Gradient L2-norm clipping threshold; 0.0 disables it and takes precedence over clipValue. */
    protected double clipNorm = 0.0;
    /**
     * The iteration counter at which the learning rate and momentum schedules
     * are evaluated. It is not advanced in this class; presumably subclasses
     * do so (TODO confirm).
     */
    protected int t = 0;

    /**
     * Creates a network from its layers.
     *
     * @param layers the layers, from the input layer to the output layer.
     * @throws IllegalArgumentException if there are fewer than three layers,
     *         the first is not an InputLayer, the last is not an OutputLayer,
     *         or adjacent layer sizes do not match.
     */
    public MultilayerPerceptron(Layer... layers) {
        if (layers.length <= 2) {
            throw new IllegalArgumentException("Too few layers: " + layers.length);
        }
        if (!(layers[0] instanceof InputLayer)) {
            throw new IllegalArgumentException("The first layer is not an InputLayer: " + layers[0]);
        }
        Layer last = layers[layers.length - 1];
        if (!(last instanceof OutputLayer)) {
            throw new IllegalArgumentException("The last layer is not an OutputLayer: " + last);
        }

        for (int i = 1; i < layers.length; i++) {
            Layer lower = layers[i - 1];
            Layer upper = layers[i];
            if (upper.getInputSize() != lower.getOutputSize()) {
                throw new IllegalArgumentException(String.format(
                        "Invalid network architecture. Layer %d has %d neurons while layer %d takes %d inputs",
                        i - 1, lower.getOutputSize(), i, upper.getInputSize()));
            }
        }

        this.output = (OutputLayer) last;
        this.net = Arrays.copyOf(layers, layers.length - 1);
        this.p = layers[0].getInputSize();
        init();
    }

    /** Restores the non-serialized per-thread state after default deserialization. */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        init();
    }

    /** (Re)creates the per-thread target buffer. */
    private void init() {
        target = ThreadLocal.withInitial(() -> new double[output.getOutputSize()]);
    }

    @Override
    public String toString() {
        String layers = Arrays.stream(net).map(Object::toString).collect(Collectors.joining(" -> "));
        StringBuilder sb = new StringBuilder(
                String.format("%s -> %s(learning rate = %s", layers, output, learningRate));
        if (momentum != null) {
            sb.append(String.format(", momentum = %s", momentum));
        }
        if (lambda != 0.0) {
            sb.append(String.format(", weight decay = %f", lambda));
        }
        if (rho != 0.0) {
            sb.append(String.format(", RMSProp = %f", rho));
        }
        return sb.append(')').toString();
    }

    /**
     * Sets the learning rate schedule.
     * @param rate the learning rate as a function of the iteration counter.
     */
    public void setLearningRate(TimeFunction rate) {
        this.learningRate = rate;
    }

    /**
     * Sets the momentum schedule.
     * @param momentum the momentum factor as a function of the iteration counter, or null for none.
     */
    public void setMomentum(TimeFunction momentum) {
        this.momentum = momentum;
    }

    /**
     * Sets the RMSProp parameters.
     * @param rho the discounting factor, in [0, 1).
     * @param epsilon a positive constant added to the denominator.
     * @throws IllegalArgumentException if either value is out of range.
     */
    public void setRMSProp(double rho, double epsilon) {
        if (rho < 0.0 || rho >= 1.0) {
            throw new IllegalArgumentException("Invalid rho = " + rho);
        }
        if (epsilon <= 0.0) {
            throw new IllegalArgumentException("Invalid epsilon = " + epsilon);
        }
        this.rho = rho;
        this.epsilon = epsilon;
    }

    /**
     * Sets the L2 regularization factor.
     * @param lambda a non-negative factor.
     * @throws IllegalArgumentException if lambda is negative.
     */
    public void setWeightDecay(double lambda) {
        if (lambda < 0.0) {
            throw new IllegalArgumentException("Invalid weight decay factor: " + lambda);
        }
        this.lambda = lambda;
    }

    /**
     * Sets the element-wise gradient clipping threshold; 0 disables it.
     * @param clipValue a non-negative threshold.
     * @throws IllegalArgumentException if clipValue is negative.
     */
    public void setClipValue(double clipValue) {
        if (clipValue < 0.0) {
            throw new IllegalArgumentException("Invalid gradient clipping value: " + clipValue);
        }
        this.clipValue = clipValue;
    }

    /**
     * Sets the gradient L2-norm clipping threshold; 0 disables it.
     * @param clipNorm a non-negative threshold.
     * @throws IllegalArgumentException if clipNorm is negative.
     */
    public void setClipNorm(double clipNorm) {
        if (clipNorm < 0.0) {
            throw new IllegalArgumentException("Invalid gradient clipping norm: " + clipNorm);
        }
        this.clipNorm = clipNorm;
    }

    /** Returns the learning rate at the current iteration counter. */
    public double getLearningRate() {
        return learningRate.apply(t);
    }

    /** Returns the momentum factor at the current iteration counter, or 0 if no schedule is set. */
    public double getMomentum() {
        return momentum == null ? 0.0 : momentum.apply(t);
    }

    /** Returns the L2 regularization factor. */
    public double getWeightDecay() {
        return lambda;
    }

    /** Returns the element-wise gradient clipping threshold. */
    public double getClipValue() {
        return clipValue;
    }

    /** Returns the gradient L2-norm clipping threshold. */
    public double getClipNorm() {
        return clipNorm;
    }

    /**
     * Forward-propagates a sample through every layer.
     * Originally protected (per decompiler note); kept public for compatibility.
     *
     * @param x the input vector.
     * @param training if true, {@code propagateDropout()} is applied to each
     *        layer in {@link #net} (input and hidden) after it propagates.
     */
    public void propagate(double[] x, boolean training) {
        double[] input = x;
        for (Layer layer : net) {
            layer.propagate(input);
            if (training) {
                layer.propagateDropout();
            }
            input = layer.output();
        }
        output.propagate(input);
    }

    /**
     * Clips a gradient in place. Norm clipping rescales the whole vector when its
     * L2 norm exceeds clipNorm; otherwise, if clipValue is positive, each element
     * is clamped to [-clipValue, clipValue]. Norm clipping takes precedence.
     */
    private void clipGradient(double[] gradient) {
        if (clipNorm > 0.0) {
            double norm = MathEx.norm(gradient);
            if (norm > clipNorm) {
                double scale = clipNorm / norm;
                for (int i = 0; i < gradient.length; i++) {
                    gradient[i] *= scale;
                }
            }
            return;
        }

        if (clipValue > 0.0) {
            for (int i = 0; i < gradient.length; i++) {
                if (gradient[i] > clipValue) {
                    gradient[i] = clipValue;
                } else if (gradient[i] < -clipValue) {
                    gradient[i] = -clipValue;
                }
            }
        }
    }

    /**
     * Back-propagates the error of the current target and accumulates either
     * plain gradients or gradient updates in each trainable layer.
     * Originally protected (per decompiler note); kept public for compatibility.
     *
     * @param update if true, compute gradient updates with the current learning
     *        rate, momentum and weight decay; otherwise only compute gradients.
     * @throws IllegalArgumentException if update is true and the learning rate
     *         or momentum is invalid.
     * @throws IllegalStateException if update is true and the derived weight
     *         decay factor is below 0.9.
     */
    public void backpropagate(boolean update) {
        output.computeOutputGradient(target.get(), 1.0);
        clipGradient(output.gradient());

        // Walk down from the output layer to the first hidden layer (net[1]);
        // net[0] is the input layer and has no gradient of its own. The walk
        // variable must be typed Layer because hidden layers are not OutputLayers.
        Layer upper = output;
        for (int i = net.length - 1; i > 0; i--) {
            upper.backpropagate(net[i].gradient());
            upper = net[i];
            upper.backpopagateDropout();
            clipGradient(upper.gradient());
        }
        // The first hidden layer has no trainable layer below it.
        upper.backpropagate(null);

        if (update) {
            double eta = checkedLearningRate();
            double alpha = checkedMomentum();
            double decay = checkedDecay(eta);

            double[] x = net[0].output();
            for (int i = 1; i < net.length; i++) {
                Layer layer = net[i];
                layer.computeGradientUpdate(x, eta, alpha, decay);
                x = layer.output();
            }
            output.computeGradientUpdate(x, eta, alpha, decay);
        } else {
            double[] x = net[0].output();
            for (int i = 1; i < net.length; i++) {
                Layer layer = net[i];
                layer.computeGradient(x);
                x = layer.output();
            }
            output.computeGradient(x);
        }
    }

    /**
     * Applies the accumulated gradient updates to every trainable layer.
     * Originally protected (per decompiler note); kept public for compatibility.
     *
     * @param m the mini-batch size passed through to each layer's update.
     * @throws IllegalArgumentException if the learning rate or momentum is invalid.
     * @throws IllegalStateException if the derived weight decay factor is below 0.9.
     */
    public void update(int m) {
        double eta = checkedLearningRate();
        double alpha = checkedMomentum();
        double decay = checkedDecay(eta);

        for (int i = 1; i < net.length; i++) {
            net[i].update(m, eta, alpha, decay, rho, epsilon);
        }
        output.update(m, eta, alpha, decay, rho, epsilon);
    }

    /** Returns the current learning rate, rejecting non-positive or NaN values. */
    private double checkedLearningRate() {
        double eta = getLearningRate();
        if (!(eta > 0.0)) {
            throw new IllegalArgumentException("Invalid learning rate: " + eta);
        }
        return eta;
    }

    /** Returns the current momentum factor, rejecting values outside [0, 1) and NaN. */
    private double checkedMomentum() {
        double alpha = getMomentum();
        if (!(alpha >= 0.0 && alpha < 1.0)) {
            throw new IllegalArgumentException("Invalid momentum factor: " + alpha);
        }
        return alpha;
    }

    /** Returns the multiplicative weight decay 1 - 2 * eta * lambda, rejecting values below 0.9 and NaN. */
    private double checkedDecay(double eta) {
        double decay = 1.0 - 2.0 * eta * lambda;
        if (!(decay >= 0.9)) {
            throw new IllegalStateException(String.format(
                    "Invalid learning rate (eta = %s) and/or L2 regularization (lambda = %s) such that weight decay = %s",
                    eta, lambda, decay));
        }
        return decay;
    }

    /**
     * Sets hyperparameters from properties. Recognized keys:
     * {@code smile.mlp.learning_rate}, {@code smile.mlp.weight_decay},
     * {@code smile.mlp.momentum}, {@code smile.mlp.clip_value},
     * {@code smile.mlp.clip_norm}, {@code smile.mlp.RMSProp.rho} and
     * {@code smile.mlp.RMSProp.epsilon} (default 1E-7, read only when rho is set).
     * Absent keys leave the current value unchanged.
     *
     * @param params the hyperparameters.
     * @throws NumberFormatException if a numeric property cannot be parsed.
     */
    public void setParameters(Properties params) {
        String rate = params.getProperty("smile.mlp.learning_rate");
        if (rate != null) {
            setLearningRate(TimeFunction.of(rate));
        }

        String weightDecay = params.getProperty("smile.mlp.weight_decay");
        if (weightDecay != null) {
            setWeightDecay(Double.parseDouble(weightDecay));
        }

        String alpha = params.getProperty("smile.mlp.momentum");
        if (alpha != null) {
            setMomentum(TimeFunction.of(alpha));
        }

        String value = params.getProperty("smile.mlp.clip_value");
        if (value != null) {
            setClipValue(Double.parseDouble(value));
        }

        String norm = params.getProperty("smile.mlp.clip_norm");
        if (norm != null) {
            setClipNorm(Double.parseDouble(norm));
        }

        String rhoProperty = params.getProperty("smile.mlp.RMSProp.rho");
        if (rhoProperty != null) {
            double eps = Double.parseDouble(params.getProperty("smile.mlp.RMSProp.epsilon", "1E-7"));
            setRMSProp(Double.parseDouble(rhoProperty), eps);
        }
    }
}
