package smile.glm;

import java.io.Serializable;
import java.util.Properties;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.data.CategoricalEncoder;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.glm.model.Model;
import smile.math.MathEx;
import smile.math.matrix.Matrix;
import smile.math.special.Erf;
import smile.stat.Hypothesis;
import smile.validation.ModelSelection;

/* loaded from: input_file:smile/glm/GLM.class */
/**
 * Generalized linear model fitted by iteratively reweighted least squares (IRLS).
 * The link, variance, deviance and likelihood functions are supplied by a
 * {@link Model}; see {@link #fit(Formula, DataFrame, Model, double, int)}.
 */
public class GLM implements Serializable {
    private static final long serialVersionUID = 2;
    private static final Logger logger = LoggerFactory.getLogger(GLM.class);

    /** The model formula (response and predictors). */
    protected Formula formula;
    /** Coefficient names; {@code fit} stores the design-matrix column names, one per coefficient. */
    String[] predictors;
    /** The GLM specification (link, variance, deviance, likelihood) used for fitting. */
    protected Model model;
    /** The fitted coefficients. */
    protected double[] beta;
    /** Per-coefficient rows of {estimate, std. error, z value, two-sided p-value}; may be null. */
    protected double[][] ztest;
    /** The fitted mean responses; its length is the sample size. */
    protected double[] mu;
    /** Deviance of the model that predicts the mean response for every sample. */
    protected double nullDeviance;
    /** Residual deviance of the fitted model. */
    protected double deviance;
    /** Deviance residuals of the fitted model. */
    protected double[] devianceResiduals;
    /** Residual degrees of freedom: number of samples minus number of coefficients. */
    protected int df;
    /** Log-likelihood of the fitted model. */
    protected double logLikelihood;

    /**
     * Constructor.
     *
     * @param formula the model formula.
     * @param predictors the coefficient names.
     * @param model the GLM specification.
     * @param beta the fitted coefficients.
     * @param logLikelihood the log-likelihood of the fitted model.
     * @param deviance the residual deviance.
     * @param nullDeviance the null deviance.
     * @param mu the fitted mean values; its length is taken as the sample size.
     * @param devianceResiduals the deviance residuals.
     * @param ztest the coefficient z-test table, or null.
     */
    public GLM(Formula formula, String[] predictors, Model model, double[] beta, double logLikelihood, double deviance, double nullDeviance, double[] mu, double[] devianceResiduals, double[][] ztest) {
        this.formula = formula;
        this.model = model;
        this.predictors = predictors;
        this.beta = beta;
        this.logLikelihood = logLikelihood;
        this.deviance = deviance;
        this.nullDeviance = nullDeviance;
        this.mu = mu;
        this.devianceResiduals = devianceResiduals;
        this.ztest = ztest;
        this.df = mu.length - beta.length;
    }

    /** Returns the fitted coefficients (the internal array, not a copy). */
    public double[] coefficients() {
        return this.beta;
    }

    /** Returns the z-test table: per coefficient {estimate, std. error, z value, p-value}; may be null. */
    public double[][] ztest() {
        return this.ztest;
    }

    /** Returns the deviance residuals (the internal array, not a copy). */
    public double[] devianceResiduals() {
        return this.devianceResiduals;
    }

    /** Returns the fitted mean values (the internal array, not a copy). */
    public double[] fittedValues() {
        return this.mu;
    }

    /** Returns the residual deviance. */
    public double deviance() {
        return this.deviance;
    }

    /** Returns the log-likelihood of the fitted model. */
    public double logLikelihood() {
        return this.logLikelihood;
    }

    /** Returns the Akaike information criterion, counting every coefficient as a parameter. */
    public double AIC() {
        return ModelSelection.AIC(this.logLikelihood, this.beta.length);
    }

    /** Returns the Bayesian information criterion, using the number of fitted values as the sample size. */
    public double BIC() {
        return ModelSelection.BIC(this.logLikelihood, this.beta.length, this.mu.length);
    }

    /**
     * Predicts the mean response of one sample.
     *
     * @param tuple the sample; its design row is built with the bias flag set and dummy
     *              encoding, mirroring {@code formula.matrix(data, true)} in {@code fit}.
     * @return the inverse link applied to the linear predictor.
     */
    public double predict(Tuple tuple) {
        double[] x = this.formula.x(tuple).toArray(true, CategoricalEncoder.DUMMY, new String[0]);
        double eta = 0.0;
        for (int i = 0; i < this.beta.length; i++) {
            eta += x[i] * this.beta[i];
        }
        return this.model.invlink(eta);
    }

    /**
     * Predicts the mean responses of a data frame.
     *
     * @param dataFrame the samples.
     * @return the inverse link applied to each row's linear predictor.
     */
    public double[] predict(DataFrame dataFrame) {
        double[] eta = this.formula.matrix(dataFrame, true).mv(this.beta);
        for (int i = 0; i < eta.length; i++) {
            eta[i] = this.model.invlink(eta[i]);
        }
        return eta;
    }

    /**
     * Returns the display label of the i-th coefficient. {@code fit} stores one name per
     * coefficient. If a caller of the public constructor supplied one name fewer, the
     * trailing coefficient is labeled "Intercept" (the intercept-last convention).
     */
    private String coefficientName(int i) {
        if (this.predictors != null && i < this.predictors.length) {
            return this.predictors[i];
        }
        return i == this.beta.length - 1 ? "Intercept" : "beta[" + i + "]";
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(String.format("Generalized Linear Model - %s:\n", this.model));

        // Summarize a copy so the stored residuals are not handed to the quantile helpers.
        double[] r = this.devianceResiduals.clone();
        sb.append("\nDeviance Residuals:\n");
        sb.append("       Min          1Q      Median          3Q         Max\n");
        sb.append(String.format("%10.4f  %10.4f  %10.4f  %10.4f  %10.4f%n",
                MathEx.min(r), MathEx.q1(r), MathEx.median(r), MathEx.q3(r), MathEx.max(r)));

        sb.append("\nCoefficients:\n");
        if (this.ztest != null) {
            sb.append("                  Estimate Std. Error    z value   Pr(>|z|)\n");
            // One row per coefficient, intercept included.
            for (int i = 0; i < this.ztest.length; i++) {
                double[] row = this.ztest[i];
                sb.append(String.format("%-15s %10.3e %10.3e %10.4f %10.5f %s%n",
                        coefficientName(i), row[0], row[1], row[2], row[3], Hypothesis.significance(row[3])));
            }
            sb.append("---------------------------------------------------------------------\n");
            sb.append("Significance codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1\n");
        } else {
            for (int i = 0; i < this.beta.length; i++) {
                sb.append(String.format("%-15s %10.4f%n", coefficientName(i), this.beta[i]));
            }
        }

        // Null model has one parameter (the intercept): n - 1 = df + (p - 1).
        sb.append(String.format("%n    Null deviance: %.1f on %d degrees of freedom", this.nullDeviance, this.df + this.beta.length - 1));
        sb.append(String.format("%nResidual deviance: %.1f on %d degrees of freedom", this.deviance, this.df));
        sb.append(String.format("%nAIC: %.4f     BIC: %.4f%n", AIC(), BIC()));
        return sb.toString();
    }

    /**
     * Fits a generalized linear model with default hyper-parameters.
     *
     * @param formula the model formula.
     * @param dataFrame the training data.
     * @param model the GLM specification.
     * @return the fitted model.
     */
    public static GLM fit(Formula formula, DataFrame dataFrame, Model model) {
        return fit(formula, dataFrame, model, new Properties());
    }

    /**
     * Fits a generalized linear model.
     *
     * @param formula the model formula.
     * @param dataFrame the training data.
     * @param model the GLM specification.
     * @param properties hyper-parameters: {@code smile.glm.tolerance} (default 1E-5) and
     *                   {@code smile.glm.iterations} (default 50).
     * @return the fitted model.
     */
    public static GLM fit(Formula formula, DataFrame dataFrame, Model model, Properties properties) {
        double tol = Double.parseDouble(properties.getProperty("smile.glm.tolerance", "1E-5"));
        int maxIter = Integer.parseInt(properties.getProperty("smile.glm.iterations", "50"));
        return fit(formula, dataFrame, model, tol, maxIter);
    }

    /**
     * Fits a generalized linear model by iteratively reweighted least squares.
     * Iteration stops when the deviance fails to decrease by at least {@code tol},
     * or after {@code maxIter} evaluations. The returned coefficients, fitted values,
     * deviance, residuals and log-likelihood all describe the same final fit.
     *
     * @param formula the model formula.
     * @param dataFrame the training data.
     * @param model the GLM specification.
     * @param tol the deviance-decrease tolerance; must be positive.
     * @param maxIter the maximum number of iterations; must be positive.
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code tol} or {@code maxIter} is not positive,
     *         or the design matrix has no more rows than columns.
     */
    public static GLM fit(Formula formula, DataFrame dataFrame, Model model, double tol, int maxIter) {
        if (tol <= 0.0d) {
            throw new IllegalArgumentException("Invalid tolerance: " + tol);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }

        Matrix X = formula.matrix(dataFrame, true);
        double[] y = formula.y(dataFrame).toDoubleArray();
        int n = X.nrow();
        int p = X.ncol();
        if (n <= p) {
            throw new IllegalArgumentException(String.format("The input matrix is not over determined: %d rows, %d columns", n, p));
        }

        double[] eta = new double[n];        // linear predictor
        double[] mu = new double[n];         // fitted means
        double[] w = new double[n];          // square-root IRLS weights
        double[] z = new double[n];          // weighted working response
        double[] residuals = new double[n];  // deviance residuals, filled by model.deviance
        Matrix WX = new Matrix(n, p);        // reused buffer for the weighted design matrix

        // Start from model.mustart(y) and its linked value rather than from a coefficient guess.
        IntStream.range(0, n).parallel().forEach(i -> {
            mu[i] = model.mustart(y[i]);
            eta[i] = model.link(mu[i]);
            updateWorkingResponse(model, y, eta, mu, w, z, i);
        });

        Matrix.QR qr = weightedQR(X, w, WX);
        double[] beta = qr.solve(z);

        double dev = Double.POSITIVE_INFINITY;
        boolean stopped = false;
        for (int iter = 1; iter <= maxIter; iter++) {
            // Evaluate the current coefficients.
            X.mv(beta, eta);
            IntStream.range(0, n).parallel().forEach(i -> {
                mu[i] = model.invlink(eta[i]);
                updateWorkingResponse(model, y, eta, mu, w, z, i);
            });

            double newDev = model.deviance(y, mu, residuals);
            logger.info(String.format("Deviance after %3d iterations: %.5f", iter, newDev));
            // Stop when the deviance decrease is below tol (this includes any increase).
            stopped = dev - newDev < tol;
            dev = newDev;
            if (stopped || iter == maxIter) {
                // Do not refit here, so beta, mu, residuals and dev stay consistent.
                break;
            }

            qr = weightedQR(X, w, WX);
            beta = qr.solve(z);
        }
        if (!stopped) {
            logger.warn("IRLS did not converge in {} iterations", maxIter);
        }

        // Coefficient covariance from the last weighted least-squares fit: (X'WX)^-1.
        Matrix inverse = qr.CholeskyOfAtA().inverse();
        double[][] ztest = new double[p][4];
        for (int j = 0; j < p; j++) {
            double se = Math.sqrt(inverse.get(j, j));
            double zval = beta[j] / se;
            ztest[j][0] = beta[j];
            ztest[j][1] = se;
            ztest[j][2] = zval;
            // Two-sided p-value 2 * (1 - Phi(|z|)) == erfc(|z| / sqrt(2)); this form avoids
            // the cancellation of 2 - erfc(-|z| / sqrt(2)) for large |z|.
            ztest[j][3] = Erf.erfc(Math.abs(zval) / Math.sqrt(2.0));
        }

        double logLik = model.logLikelihood(y, mu);
        double nullDev = model.nullDeviance(y, MathEx.mean(y));
        return new GLM(formula, X.colNames(), model, beta, logLik, dev, nullDev, mu, residuals, ztest);
    }

    /**
     * Recomputes the IRLS quantities for row {@code i} from the current {@code eta[i]}
     * and {@code mu[i]}. It sets the square-root weight
     * {@code w[i] = 1 / (dlink(mu) * sqrt(variance(mu)))} and the weighted working response
     * {@code z[i] = (eta + (y - mu) * dlink(mu)) * w[i]}. {@code model.dlink} is presumably
     * the derivative of the link function.
     */
    private static void updateWorkingResponse(Model model, double[] y, double[] eta, double[] mu, double[] w, double[] z, int i) {
        double dlink = model.dlink(mu[i]);
        w[i] = 1.0d / (dlink * Math.sqrt(model.variance(mu[i])));
        z[i] = (eta[i] + (y[i] - mu[i]) * dlink) * w[i];
    }

    /** Overwrites {@code WX} with the rows of {@code X} scaled by {@code w} and returns its QR decomposition. */
    private static Matrix.QR weightedQR(Matrix X, double[] w, Matrix WX) {
        int n = X.nrow();
        int p = X.ncol();
        for (int j = 0; j < p; j++) {
            for (int i = 0; i < n; i++) {
                WX.set(i, j, X.get(i, j) * w[i]);
            }
        }
        return WX.qr(true);
    }
}
