package smile.base.cart;

import java.util.Arrays;
import java.util.stream.IntStream;
import smile.math.MathEx;
import smile.sort.QuickSelect;

/* loaded from: input_file:smile/base/cart/Loss.class */
public interface Loss {

    /* loaded from: input_file:smile/base/cart/Loss$Type.class */
    /**
     * Names of loss families.
     *
     * <p>NOTE(review): nothing in this file reads this enum; it is presumably consumed by callers
     * elsewhere. The constant order is left untouched because external code may rely on it.
     */
    public enum Type {
        /** Same spelling as the string returned by {@code toString()} of the losses built by {@code ls()}. */
        LeastSquares,
        /** Same spelling as the prefix of the string returned by {@code toString()} of {@code quantile(double)} losses. */
        Quantile,
        /** Same spelling as the string returned by {@code toString()} of the loss built by {@code lad()}. */
        LeastAbsoluteDeviation,
        /** Same spelling as the prefix of the string returned by {@code toString()} of {@code huber(double)} losses. */
        Huber
    }

    /**
     * Computes a single aggregated value for a group of samples.
     *
     * <p>In every implementation in this file the elements of {@code iArr} are used as indices into
     * the loss's internal arrays. Only the least-squares implementations read {@code iArr2}; they use
     * {@code iArr2[s]} as an integer weight for sample index {@code s}. The other implementations
     * ignore {@code iArr2}.
     *
     * <p>NOTE(review): presumably the group is the set of samples in one tree node and the weights
     * are per-sample occurrence counts - confirm against the callers.
     *
     * @param iArr sample indices to aggregate over.
     * @param iArr2 per-sample integer weights, indexed by sample index.
     * @return the aggregated value; its definition depends on the implementation.
     */
    double output(int[] iArr, int[] iArr2);

    /**
     * Computes the initial constant estimate and initializes the residuals.
     *
     * <p>In this file: the implementations from {@code ls()}, {@code quantile(double)}, {@code lad()}
     * and {@code huber(double)} allocate their residual array as {@code dArr[i]} minus the returned
     * constant; {@code logistic(int[])} ignores {@code dArr} and fills its residuals with the returned
     * constant; {@code ls(double[])} and the four-argument {@code logistic} always throw
     * {@link IllegalStateException}.
     *
     * @param dArr the values the constant is computed from (unused by some implementations).
     * @return the constant estimate.
     */
    double intercept(double[] dArr);

    /**
     * Returns the response array exposed by this loss.
     *
     * <p>Most implementations in this file recompute the contents on each call and return their
     * internal array without copying it.
     *
     * @return the response array.
     */
    double[] response();

    /**
     * Returns the residual array held by this loss, without copying it.
     *
     * <p>The implementation built by {@code ls(double[])} throws {@link IllegalStateException}
     * instead.
     *
     * @return the residual array.
     */
    double[] residual();

    /**
     * Creates a least-squares loss whose residuals are set up by {@code intercept(double[])}.
     *
     * @return a new loss; its residual array is {@code null} until {@code intercept} has been called.
     */
    static Loss ls() {
        return new Loss() {
            /** Deviations of the training values from their mean; allocated by {@code intercept}. */
            double[] residual;

            /** Returns the weighted mean of the residuals at {@code indices}, weighted by {@code weights}. */
            @Override
            public double output(int[] indices, int[] weights) {
                int totalWeight = 0;
                double weightedSum = 0.0;
                for (int k = 0; k < indices.length; k++) {
                    final int sample = indices[k];
                    final int weight = weights[sample];
                    totalWeight += weight;
                    weightedSum += residual[sample] * weight;
                }
                // Floating-point division: an all-zero weight total yields NaN rather than an exception.
                return weightedSum / totalWeight;
            }

            /** Returns the mean of {@code values} and stores each value's deviation from it. */
            @Override
            public double intercept(double[] values) {
                final int n = values.length;
                residual = new double[n];
                final double mu = MathEx.mean(values);
                int k = 0;
                while (k < n) {
                    residual[k] = values[k] - mu;
                    k++;
                }
                return mu;
            }

            /** For least squares the response is the residual array itself (same instance). */
            @Override
            public double[] response() {
                return residual;
            }

            /** Returns the internal residual array without copying. */
            @Override
            public double[] residual() {
                return residual;
            }

            @Override
            public String toString() {
                return "LeastSquares";
            }
        };
    }

    /**
     * Creates a least-squares loss over a caller-supplied array.
     *
     * <p>The array is stored by reference (no copy) and handed back by {@code response()}.
     * {@code intercept(double[])} and {@code residual()} are not supported on the returned object
     * and always throw {@link IllegalStateException}.
     *
     * @param dArr the values to aggregate; indexed by sample index.
     * @return a new loss backed by {@code dArr}.
     */
    static Loss ls(final double[] dArr) {
        return new Loss() {
            /** The caller's array, shared rather than copied. */
            final double[] residual = dArr;

            /** Returns the weighted mean of the stored values at {@code indices}, weighted by {@code weights}. */
            @Override
            public double output(int[] indices, int[] weights) {
                int totalWeight = 0;
                double weightedSum = 0.0;
                for (int k = 0; k < indices.length; k++) {
                    final int sample = indices[k];
                    final int weight = weights[sample];
                    totalWeight += weight;
                    weightedSum += residual[sample] * weight;
                }
                // Floating-point division: an all-zero weight total yields NaN rather than an exception.
                return weightedSum / totalWeight;
            }

            /** Not supported by this variant. */
            @Override
            public double intercept(double[] unused) {
                throw new IllegalStateException("This method should not be called.");
            }

            /** Returns the array passed to the factory (same instance). */
            @Override
            public double[] response() {
                return residual;
            }

            /** Not supported by this variant. */
            @Override
            public double[] residual() {
                throw new IllegalStateException("This method should not be called.");
            }

            @Override
            public String toString() {
                return "LeastSquares";
            }
        };
    }

    /**
     * Creates a quantile loss for the given quantile level.
     *
     * <p>Two defects of the previous text are fixed here: {@code output} referenced an undeclared
     * variable where the length of the freshly gathered array was meant, and the range check let
     * {@code NaN} through because {@code NaN} fails both {@code <=} and {@code >=}.
     *
     * @param d the quantile level, strictly between 0 and 1.
     * @return a new loss; its arrays are {@code null} until {@code intercept(double[])} has run.
     * @throws IllegalArgumentException if {@code d} is {@code NaN} or not strictly inside (0, 1).
     */
    static Loss quantile(final double d) {
        // Negated conjunction so that NaN, which fails every comparison, is rejected as well.
        if (!(d > 0.0d && d < 1.0d)) {
            throw new IllegalArgumentException("Invalid percentile: " + d);
        }
        return new Loss() {
            /** Scratch copy of the input during {@code intercept}; afterwards holds residual signs. */
            double[] response;
            /** Input values minus the constant returned by {@code intercept}. */
            double[] residual;

            /**
             * Gathers the residuals at the given indices into a fresh array and delegates to
             * {@code QuickSelect.select} with rank {@code (int) (count * d)}. {@code iArr2} is unused.
             */
            @Override
            public double output(int[] iArr, int[] iArr2) {
                double[] selected = Arrays.stream(iArr).mapToDouble(i -> this.residual[i]).toArray();
                return QuickSelect.select(selected, (int) (selected.length * d));
            }

            /**
             * Selects the element of rank {@code (int) (n * d)} from a copy of {@code dArr} and stores
             * every input value's offset from it as the residual. {@code dArr} itself is only read.
             */
            @Override
            public double intercept(double[] dArr) {
                int length = dArr.length;
                this.response = new double[length];
                this.residual = new double[length];
                // The selection runs on the copy held in response, never on the caller's array.
                System.arraycopy(dArr, 0, this.response, 0, length);
                double select = QuickSelect.select(this.response, (int) (length * d));
                for (int i = 0; i < length; i++) {
                    this.residual[i] = dArr[i] - select;
                }
                return select;
            }

            /** Overwrites the response array with the sign of each residual and returns it. */
            @Override
            public double[] response() {
                for (int i = 0; i < this.residual.length; i++) {
                    this.response[i] = Math.signum(this.residual[i]);
                }
                return this.response;
            }

            /** Returns the internal residual array without copying. */
            @Override
            public double[] residual() {
                return this.residual;
            }

            @Override
            public String toString() {
                return String.format("Quantile(%3.1f%%)", 100.0d * d);
            }
        };
    }

    /**
     * Creates a least-absolute-deviation loss.
     *
     * @return a new loss; its arrays are {@code null} until {@code intercept(double[])} has run.
     */
    static Loss lad() {
        return new Loss() {
            /** Scratch copy of the input during {@code intercept}; afterwards holds residual signs. */
            double[] response;
            /** Input values minus the constant returned by {@code intercept}. */
            double[] residual;

            /**
             * Copies the residuals at {@code indices} into a fresh array and returns
             * {@code QuickSelect.median} of it. {@code weights} is unused.
             */
            @Override
            public double output(int[] indices, int[] weights) {
                final double[] picked = new double[indices.length];
                for (int k = 0; k < picked.length; k++) {
                    picked[k] = residual[indices[k]];
                }
                return QuickSelect.median(picked);
            }

            /**
             * Takes {@code QuickSelect.median} of a copy of {@code values} and stores every value's
             * offset from it as the residual. {@code values} itself is only read.
             */
            @Override
            public double intercept(double[] values) {
                final int n = values.length;
                // The median runs on this private copy, never on the caller's array.
                response = Arrays.copyOf(values, n);
                residual = new double[n];
                final double center = QuickSelect.median(response);
                for (int k = 0; k < n; k++) {
                    residual[k] = values[k] - center;
                }
                return center;
            }

            /** Overwrites the response array with the sign of each residual and returns it. */
            @Override
            public double[] response() {
                final int n = residual.length;
                for (int k = 0; k < n; k++) {
                    response[k] = Math.signum(residual[k]);
                }
                return response;
            }

            /** Returns the internal residual array without copying. */
            @Override
            public double[] residual() {
                return residual;
            }

            @Override
            public String toString() {
                return "LeastAbsoluteDeviation";
            }
        };
    }

    /**
     * Creates a Huber loss whose clipping threshold is taken from the absolute residuals.
     *
     * <p>The threshold ({@code delta}) is refreshed on every {@code response()} call as the element
     * of rank {@code (int) (n * d)} among the absolute residuals. Until {@code response()} has been
     * called it keeps its field default of {@code 0.0}.
     *
     * <p>Compared with the previous text the range check now also rejects {@code NaN}, which used
     * to slip through because {@code NaN} fails both {@code <=} and {@code >=}.
     *
     * @param d the level, strictly between 0 and 1, used to pick the threshold.
     * @return a new loss; its arrays are {@code null} until {@code intercept(double[])} has run.
     * @throws IllegalArgumentException if {@code d} is {@code NaN} or not strictly inside (0, 1).
     */
    static Loss huber(final double d) {
        // Negated conjunction so that NaN, which fails every comparison, is rejected as well.
        if (!(d > 0.0d && d < 1.0d)) {
            throw new IllegalArgumentException("Invalid percentile: " + d);
        }
        return new Loss() {
            /** Scratch space for selections; after {@code response()} holds the clipped residuals. */
            double[] response;
            /** Input values minus the constant returned by {@code intercept}. */
            double[] residual;
            /** Clipping threshold; recomputed by {@code response()}. */
            private double delta;

            /**
             * Returns the median of the residuals at the given indices plus the mean of their
             * deviations from that median, each deviation clipped to {@code [-delta, delta]}.
             * {@code iArr2} is unused.
             */
            @Override
            public double output(int[] iArr, int[] iArr2) {
                double[] selected = Arrays.stream(iArr).mapToDouble(i -> this.residual[i]).toArray();
                double median = QuickSelect.median(selected);
                double clippedSum = 0.0d;
                for (int sample : iArr) {
                    double deviation = this.residual[sample] - median;
                    clippedSum += Math.signum(deviation) * Math.min(this.delta, Math.abs(deviation));
                }
                return median + (clippedSum / iArr.length);
            }

            /**
             * Takes {@code QuickSelect.median} of a copy of {@code dArr} and stores every value's
             * offset from it as the residual. {@code dArr} itself is only read.
             */
            @Override
            public double intercept(double[] dArr) {
                int length = dArr.length;
                this.response = new double[length];
                this.residual = new double[length];
                // The median runs on the copy held in response, never on the caller's array.
                System.arraycopy(dArr, 0, this.response, 0, length);
                double median = QuickSelect.median(this.response);
                for (int i = 0; i < length; i++) {
                    this.residual[i] = dArr[i] - median;
                }
                return median;
            }

            /**
             * Recomputes {@code delta} from the absolute residuals, then fills the response array
             * with each residual clipped to {@code [-delta, delta]} and returns it.
             */
            @Override
            public double[] response() {
                int length = this.residual.length;
                // First pass: response doubles as the scratch array handed to the selection.
                for (int i = 0; i < length; i++) {
                    this.response[i] = Math.abs(this.residual[i]);
                }
                this.delta = QuickSelect.select(this.response, (int) (length * d));
                // Second pass: overwrite the scratch contents with the clipped residuals.
                for (int i = 0; i < length; i++) {
                    double r = this.residual[i];
                    if (Math.abs(r) <= this.delta) {
                        this.response[i] = r;
                    } else {
                        this.response[i] = this.delta * Math.signum(r);
                    }
                }
                return this.response;
            }

            /** Returns the internal residual array without copying. */
            @Override
            public double[] residual() {
                return this.residual;
            }

            @Override
            public String toString() {
                return String.format("Huber(%3.1f%%)", 100.0d * d);
            }
        };
    }

    /**
     * Creates a logistic loss from integer labels.
     *
     * <p>Each label {@code v} is remapped to {@code 2 * v - 1} when the loss is constructed.
     * NOTE(review): presumably labels are 0/1 so the remapped values are -1/+1 - confirm with callers.
     *
     * @param iArr the labels, one per sample.
     * @return a new loss with response and residual arrays of the same length as {@code iArr}.
     */
    static Loss logistic(final int[] iArr) {
        final int n = iArr.length;
        return new Loss() {
            /** Labels remapped to {@code 2 * v - 1}. */
            final int[] y = Arrays.stream(iArr).map(label -> 2 * label - 1).toArray();
            /** Refreshed by {@code response()}. */
            final double[] response = new double[n];
            /** Filled with the intercept by {@code intercept}; zero until then. */
            final double[] residual = new double[n];

            /**
             * Returns the sum of the responses at {@code indices} divided by the sum of
             * {@code |r| * (2 - |r|)} over the same responses. {@code weights} is unused.
             */
            @Override
            public double output(int[] indices, int[] weights) {
                double numerator = 0.0;
                double denominator = 0.0;
                for (int k = 0; k < indices.length; k++) {
                    final double r = response[indices[k]];
                    final double magnitude = Math.abs(r);
                    numerator += r;
                    denominator += magnitude * (2.0 - magnitude);
                }
                return numerator / denominator;
            }

            /**
             * Returns {@code 0.5 * log((1 + m) / (1 - m))}, where {@code m} is the mean of the
             * remapped labels, and fills the residual array with that value. The argument is ignored.
             */
            @Override
            public double intercept(double[] ignored) {
                final double mu = MathEx.mean(y);
                final double bias = 0.5 * Math.log((1.0 + mu) / (1.0 - mu));
                Arrays.fill(residual, bias);
                return bias;
            }

            /** Recomputes {@code 2y / (1 + exp(2y * residual))} for every sample and returns the array. */
            @Override
            public double[] response() {
                for (int k = 0; k < n; k++) {
                    final int label = y[k];
                    response[k] = (2.0 * label) / (1.0 + Math.exp((2 * label) * residual[k]));
                }
                return response;
            }

            /** Returns the internal residual array without copying. */
            @Override
            public double[] residual() {
                return residual;
            }

            @Override
            public String toString() {
                return "Logistic";
            }
        };
    }

    /**
     * Creates a logistic loss for one label value against all others.
     *
     * <p>Labels equal to {@code i} become 1 and every other label becomes 0 when the loss is
     * constructed. {@code intercept(double[])} is not supported and always throws
     * {@link IllegalStateException}.
     *
     * <p>NOTE(review): presumably {@code i2} is the number of classes and {@code dArr[s][i]} the
     * current estimated probability of class {@code i} for sample {@code s} - confirm with callers.
     *
     * @param i the label value treated as the positive class; also the column read from {@code dArr}.
     * @param i2 the integer used in the {@code (i2 - 1) / i2} output scaling and in {@code toString()}.
     * @param iArr the labels, one per sample.
     * @param dArr per-sample rows; {@code response()} reads column {@code i} of each row on every call.
     * @return a new loss with response and residual arrays of the same length as {@code iArr}.
     */
    static Loss logistic(final int i, final int i2, final int[] iArr, final double[][] dArr) {
        final int n = iArr.length;
        return new Loss() {
            /** 1 where the label equals {@code i}, 0 elsewhere. */
            final int[] y = Arrays.stream(iArr).map(label -> label == i ? 1 : 0).toArray();
            /** Refreshed by {@code response()}. */
            final double[] response = new double[n];
            /** Never written by this implementation, so it stays all zeros. */
            final double[] residual = new double[n];

            /**
             * Sums the responses at {@code indices} and the terms {@code |r| * (1 - |r|)}. When the
             * latter sum is below 1e-10 the plain mean of the responses is returned; otherwise the
             * ratio of the two sums scaled by {@code (i2 - 1) / i2}. {@code weights} is unused.
             */
            @Override
            public double output(int[] indices, int[] weights) {
                double numerator = 0.0;
                double denominator = 0.0;
                for (int k = 0; k < indices.length; k++) {
                    final double r = response[indices[k]];
                    final double magnitude = Math.abs(r);
                    numerator += r;
                    denominator += magnitude * (1.0 - magnitude);
                }
                if (denominator < 1.0E-10) {
                    // Guard against dividing by a vanishing denominator.
                    return numerator / indices.length;
                }
                final double scale = (i2 - 1.0) / i2;
                return scale * (numerator / denominator);
            }

            /** Not supported by this variant. */
            @Override
            public double intercept(double[] unused) {
                throw new IllegalStateException("This method should not be called.");
            }

            /** Recomputes {@code y[s] - dArr[s][i]} for every sample and returns the array. */
            @Override
            public double[] response() {
                for (int k = 0; k < n; k++) {
                    response[k] = y[k] - dArr[k][i];
                }
                return response;
            }

            /** Returns the internal residual array without copying. */
            @Override
            public double[] residual() {
                return residual;
            }

            @Override
            public String toString() {
                return String.format("Logistic(%d)", i2);
            }
        };
    }

    /**
     * Parses the textual form of a loss.
     *
     * <p>Accepted inputs:
     * <ul>
     *   <li>{@code "LeastSquares"} - same as {@code ls()};</li>
     *   <li>{@code "LeastAbsoluteDeviation"} - same as {@code lad()};</li>
     *   <li>{@code "Quantile(p)"} and {@code "Huber(p)"} with {@code p} as a fraction, e.g.
     *       {@code "Quantile(0.5)"};</li>
     *   <li>{@code "Quantile(x%)"} and {@code "Huber(x%)"} with {@code x} as a percentage, e.g.
     *       {@code "Quantile(50.0%)"}. This is the form produced by {@code toString()} of those
     *       losses, which previously failed to parse because of the trailing percent sign.</li>
     * </ul>
     *
     * <p>The previous text of this method was decompiler residue of a string switch (an integer
     * assigned to a {@code boolean} and a {@code switch} over it) and was not valid Java. It is
     * replaced by the equivalent {@code switch} on the string itself.
     *
     * @param str the textual form of the loss.
     * @return the corresponding loss.
     * @throws IllegalArgumentException if the name is not recognized, the number cannot be parsed
     *     ({@link NumberFormatException}), or the parsed level is rejected by the factory.
     * @throws NullPointerException if {@code str} is {@code null}.
     */
    static Loss valueOf(String str) {
        switch (str) {
            case "LeastSquares":
                return ls();
            case "LeastAbsoluteDeviation":
                return lad();
            default:
                break;
        }

        boolean isQuantile = str.startsWith("Quantile(");
        boolean isHuber = str.startsWith("Huber(");
        if ((isQuantile || isHuber) && str.endsWith(")")) {
            // Text between the opening parenthesis of the name and the final ')'.
            String arg = str.substring(str.indexOf('(') + 1, str.length() - 1).trim();
            double p;
            if (arg.endsWith("%")) {
                // Percent form, as printed by toString(): convert back to a fraction.
                p = Double.parseDouble(arg.substring(0, arg.length() - 1)) / 100.0d;
            } else {
                p = Double.parseDouble(arg);
            }
            return isQuantile ? quantile(p) : huber(p);
        }

        throw new IllegalArgumentException("Unsupported loss: " + str);
    }
}
