package smile.validation;

import smile.classification.Classifier;
import smile.classification.ClassifierTrainer;
import smile.math.Math;
import smile.regression.Regression;
import smile.regression.RegressionTrainer;

/* loaded from: input_file:smile/validation/Validation.class */
/**
 * Static helpers that estimate how well classifiers and regression models generalize.
 * The supported procedures are:
 * <ul>
 *   <li>testing an already-trained model on a hold-out set;</li>
 *   <li>leave-one-out cross validation (LOOCV);</li>
 *   <li>k-fold cross validation;</li>
 *   <li>bootstrap.</li>
 * </ul>
 *
 * <p>The overloads without an explicit measure score classifiers with {@link Accuracy}.
 * They score regression models with {@link RMSE}; the regression LOOCV overload computes
 * the root mean squared error inline. The other overloads accept one measure, or an array
 * of measures that are all evaluated on the same predictions.
 */
public class Validation {

    /** Exception-message prefix used when the number of cross-validation folds is invalid. */
    private static final String INVALID_CV_K = "Invalid k for k-fold cross validation: ";

    /** Exception-message prefix used when the number of bootstrap rounds is invalid. */
    private static final String INVALID_BOOTSTRAP_K = "Invalid k for k-fold bootstrap: ";

    /**
     * Tests a trained classifier on a test set.
     *
     * @param classifier a trained classifier.
     * @param x the test samples.
     * @param y the true class labels of the test samples.
     * @return the accuracy of the predictions on the test set.
     */
    public static <T> double test(Classifier<T> classifier, T[] x, int[] y) {
        return test(classifier, x, y, new Accuracy());
    }

    /**
     * Tests a trained regression model on a test set.
     *
     * @param regression a trained regression model.
     * @param x the test samples.
     * @param y the true responses of the test samples.
     * @return the root mean squared error of the predictions on the test set.
     */
    public static <T> double test(Regression<T> regression, T[] x, double[] y) {
        return test(regression, x, y, new RMSE());
    }

    /**
     * Tests a trained classifier on a test set with a custom measure.
     *
     * @param classifier a trained classifier.
     * @param x the test samples.
     * @param y the true class labels of the test samples.
     * @param measure the performance measure to compute.
     * @return the value of {@code measure} on the test set.
     */
    public static <T> double test(Classifier<T> classifier, T[] x, int[] y, ClassificationMeasure measure) {
        return measure.measure(y, predict(classifier, x));
    }

    /**
     * Tests a trained classifier on a test set with several measures.
     *
     * @param classifier a trained classifier.
     * @param x the test samples.
     * @param y the true class labels of the test samples.
     * @param measures the performance measures to compute.
     * @return one value per entry of {@code measures}, in the same order.
     */
    public static <T> double[] test(Classifier<T> classifier, T[] x, int[] y, ClassificationMeasure[] measures) {
        return measureAll(measures, y, predict(classifier, x));
    }

    /**
     * Tests a trained regression model on a test set with a custom measure.
     *
     * @param regression a trained regression model.
     * @param x the test samples.
     * @param y the true responses of the test samples.
     * @param measure the performance measure to compute.
     * @return the value of {@code measure} on the test set.
     */
    public static <T> double test(Regression<T> regression, T[] x, double[] y, RegressionMeasure measure) {
        return measure.measure(y, predict(regression, x));
    }

    /**
     * Tests a trained regression model on a test set with several measures.
     *
     * @param regression a trained regression model.
     * @param x the test samples.
     * @param y the true responses of the test samples.
     * @param measures the performance measures to compute.
     * @return one value per entry of {@code measures}, in the same order.
     */
    public static <T> double[] test(Regression<T> regression, T[] x, double[] y, RegressionMeasure[] measures) {
        return measureAll(measures, y, predict(regression, x));
    }

    /**
     * Leave-one-out cross validation of a classification model.
     *
     * @param trainer the classifier trainer, invoked once per sample.
     * @param x the samples.
     * @param y the class labels of the samples.
     * @return the fraction of samples whose held-out prediction equals the true label.
     */
    public static <T> double loocv(ClassifierTrainer<T> trainer, T[] x, int[] y) {
        int n = x.length;
        int[] prediction = loocvPredict(trainer, x, y);

        int correct = 0;
        for (int i = 0; i < n; i++) {
            if (prediction[i] == y[i]) {
                correct++;
            }
        }

        // Cast before dividing: int / int would truncate the accuracy to 0 or 1.
        return (double) correct / n;
    }

    /**
     * Leave-one-out cross validation of a regression model.
     *
     * @param trainer the regression trainer, invoked once per sample.
     * @param x the samples.
     * @param y the responses of the samples.
     * @return the root mean squared error of the held-out predictions.
     */
    public static <T> double loocv(RegressionTrainer<T> trainer, T[] x, double[] y) {
        int n = x.length;
        double[] prediction = loocvPredict(trainer, x, y);

        double rss = 0.0;
        for (int i = 0; i < n; i++) {
            rss += Math.sqr(prediction[i] - y[i]);
        }
        return Math.sqrt(rss / n);
    }

    /**
     * Leave-one-out cross validation of a classification model with a custom measure.
     *
     * @param trainer the classifier trainer, invoked once per sample.
     * @param x the samples.
     * @param y the class labels of the samples.
     * @param measure the performance measure to compute.
     * @return the value of {@code measure} over all held-out predictions.
     */
    public static <T> double loocv(ClassifierTrainer<T> trainer, T[] x, int[] y, ClassificationMeasure measure) {
        return measure.measure(y, loocvPredict(trainer, x, y));
    }

    /**
     * Leave-one-out cross validation of a classification model with several measures.
     *
     * @param trainer the classifier trainer, invoked once per sample.
     * @param x the samples.
     * @param y the class labels of the samples.
     * @param measures the performance measures to compute.
     * @return one value per entry of {@code measures}, in the same order.
     */
    public static <T> double[] loocv(ClassifierTrainer<T> trainer, T[] x, int[] y, ClassificationMeasure[] measures) {
        return measureAll(measures, y, loocvPredict(trainer, x, y));
    }

    /**
     * Leave-one-out cross validation of a regression model with a custom measure.
     *
     * @param trainer the regression trainer, invoked once per sample.
     * @param x the samples.
     * @param y the responses of the samples.
     * @param measure the performance measure to compute.
     * @return the value of {@code measure} over all held-out predictions.
     */
    public static <T> double loocv(RegressionTrainer<T> trainer, T[] x, double[] y, RegressionMeasure measure) {
        return measure.measure(y, loocvPredict(trainer, x, y));
    }

    /**
     * Leave-one-out cross validation of a regression model with several measures.
     *
     * @param trainer the regression trainer, invoked once per sample.
     * @param x the samples.
     * @param y the responses of the samples.
     * @param measures the performance measures to compute.
     * @return one value per entry of {@code measures}, in the same order.
     */
    public static <T> double[] loocv(RegressionTrainer<T> trainer, T[] x, double[] y, RegressionMeasure[] measures) {
        return measureAll(measures, y, loocvPredict(trainer, x, y));
    }

    /**
     * k-fold cross validation of a classification model.
     *
     * @param k the number of folds; must be at least 2.
     * @param trainer the classifier trainer, invoked once per fold.
     * @param x the samples.
     * @param y the class labels of the samples.
     * @return the accuracy of the out-of-fold predictions.
     * @throws IllegalArgumentException if {@code k < 2}.
     */
    public static <T> double cv(int k, ClassifierTrainer<T> trainer, T[] x, int[] y) {
        return cv(k, trainer, x, y, new Accuracy());
    }

    /**
     * k-fold cross validation of a regression model.
     *
     * @param k the number of folds; must be at least 2.
     * @param trainer the regression trainer, invoked once per fold.
     * @param x the samples.
     * @param y the responses of the samples.
     * @return the root mean squared error of the out-of-fold predictions.
     * @throws IllegalArgumentException if {@code k < 2}.
     */
    public static <T> double cv(int k, RegressionTrainer<T> trainer, T[] x, double[] y) {
        return cv(k, trainer, x, y, new RMSE());
    }

    /**
     * k-fold cross validation of a classification model with a custom measure.
     *
     * @param k the number of folds; must be at least 2.
     * @param trainer the classifier trainer, invoked once per fold.
     * @param x the samples.
     * @param y the class labels of the samples.
     * @param measure the performance measure to compute.
     * @return the value of {@code measure} over all out-of-fold predictions.
     * @throws IllegalArgumentException if {@code k < 2}.
     */
    public static <T> double cv(int k, ClassifierTrainer<T> trainer, T[] x, int[] y, ClassificationMeasure measure) {
        return measure.measure(y, cvPredict(k, trainer, x, y));
    }

    /**
     * k-fold cross validation of a classification model with several measures.
     *
     * @param k the number of folds; must be at least 2.
     * @param trainer the classifier trainer, invoked once per fold.
     * @param x the samples.
     * @param y the class labels of the samples.
     * @param measures the performance measures to compute.
     * @return one value per entry of {@code measures}, in the same order.
     * @throws IllegalArgumentException if {@code k < 2}.
     */
    public static <T> double[] cv(int k, ClassifierTrainer<T> trainer, T[] x, int[] y, ClassificationMeasure[] measures) {
        return measureAll(measures, y, cvPredict(k, trainer, x, y));
    }

    /**
     * k-fold cross validation of a regression model with a custom measure.
     *
     * @param k the number of folds; must be at least 2.
     * @param trainer the regression trainer, invoked once per fold.
     * @param x the samples.
     * @param y the responses of the samples.
     * @param measure the performance measure to compute.
     * @return the value of {@code measure} over all out-of-fold predictions.
     * @throws IllegalArgumentException if {@code k < 2}.
     */
    public static <T> double cv(int k, RegressionTrainer<T> trainer, T[] x, double[] y, RegressionMeasure measure) {
        return measure.measure(y, cvPredict(k, trainer, x, y));
    }

    /**
     * k-fold cross validation of a regression model with several measures.
     *
     * @param k the number of folds; must be at least 2.
     * @param trainer the regression trainer, invoked once per fold.
     * @param x the samples.
     * @param y the responses of the samples.
     * @param measures the performance measures to compute.
     * @return one value per entry of {@code measures}, in the same order.
     * @throws IllegalArgumentException if {@code k < 2}.
     */
    public static <T> double[] cv(int k, RegressionTrainer<T> trainer, T[] x, double[] y, RegressionMeasure[] measures) {
        return measureAll(measures, y, cvPredict(k, trainer, x, y));
    }

    /**
     * Bootstrap evaluation of a classification model.
     *
     * @param k the number of bootstrap rounds; must be at least 2.
     * @param trainer the classifier trainer, invoked once per round.
     * @param x the samples.
     * @param y the class labels of the samples.
     * @return the accuracy of each round, in round order.
     * @throws IllegalArgumentException if {@code k < 2}.
     */
    public static <T> double[] bootstrap(int k, ClassifierTrainer<T> trainer, T[] x, int[] y) {
        return bootstrap(k, trainer, x, y, new Accuracy());
    }

    /**
     * Bootstrap evaluation of a regression model.
     *
     * @param k the number of bootstrap rounds; must be at least 2.
     * @param trainer the regression trainer, invoked once per round.
     * @param x the samples.
     * @param y the responses of the samples.
     * @return the root mean squared error of each round, in round order.
     * @throws IllegalArgumentException if {@code k < 2}.
     */
    public static <T> double[] bootstrap(int k, RegressionTrainer<T> trainer, T[] x, double[] y) {
        return bootstrap(k, trainer, x, y, new RMSE());
    }

    /**
     * Bootstrap evaluation of a classification model with a custom measure.
     *
     * @param k the number of bootstrap rounds; must be at least 2.
     * @param trainer the classifier trainer, invoked once per round.
     * @param x the samples.
     * @param y the class labels of the samples.
     * @param measure the performance measure to compute.
     * @return the value of {@code measure} for each round, in round order.
     * @throws IllegalArgumentException if {@code k < 2}.
     */
    public static <T> double[] bootstrap(int k, ClassifierTrainer<T> trainer, T[] x, int[] y, ClassificationMeasure measure) {
        return firstColumn(bootstrap(k, trainer, x, y, new ClassificationMeasure[] {measure}));
    }

    /**
     * Bootstrap evaluation of a classification model with several measures.
     *
     * <p>In each round, the model is trained on the indices in {@code Bootstrap.train[round]}.
     * It is then scored on the indices in {@code Bootstrap.test[round]}, which are presumably
     * the out-of-bag samples (see {@link Bootstrap}).
     *
     * @param k the number of bootstrap rounds; must be at least 2.
     * @param trainer the classifier trainer, invoked once per round.
     * @param x the samples.
     * @param y the class labels of the samples.
     * @param measures the performance measures to compute.
     * @return a {@code k x measures.length} matrix; row {@code r} holds the scores of round {@code r}.
     * @throws IllegalArgumentException if {@code k < 2}.
     */
    public static <T> double[][] bootstrap(int k, ClassifierTrainer<T> trainer, T[] x, int[] y, ClassificationMeasure[] measures) {
        checkK(k, INVALID_BOOTSTRAP_K);

        double[][] scores = new double[k][];
        Bootstrap bootstrap = new Bootstrap(x.length, k);
        for (int round = 0; round < k; round++) {
            int[] trainIndex = bootstrap.train[round];
            Classifier<T> model = trainer.train(Math.slice(x, trainIndex), Math.slice(y, trainIndex));

            int[] testIndex = bootstrap.test[round];
            int[] truth = new int[testIndex.length];
            int[] prediction = new int[testIndex.length];
            for (int j = 0; j < testIndex.length; j++) {
                truth[j] = y[testIndex[j]];
                prediction[j] = model.predict(x[testIndex[j]]);
            }

            scores[round] = measureAll(measures, truth, prediction);
        }
        return scores;
    }

    /**
     * Bootstrap evaluation of a regression model with a custom measure.
     *
     * @param k the number of bootstrap rounds; must be at least 2.
     * @param trainer the regression trainer, invoked once per round.
     * @param x the samples.
     * @param y the responses of the samples.
     * @param measure the performance measure to compute.
     * @return the value of {@code measure} for each round, in round order.
     * @throws IllegalArgumentException if {@code k < 2}.
     */
    public static <T> double[] bootstrap(int k, RegressionTrainer<T> trainer, T[] x, double[] y, RegressionMeasure measure) {
        return firstColumn(bootstrap(k, trainer, x, y, new RegressionMeasure[] {measure}));
    }

    /**
     * Bootstrap evaluation of a regression model with several measures.
     *
     * <p>In each round, the model is trained on the indices in {@code Bootstrap.train[round]}.
     * It is then scored on the indices in {@code Bootstrap.test[round]}, which are presumably
     * the out-of-bag samples (see {@link Bootstrap}).
     *
     * @param k the number of bootstrap rounds; must be at least 2.
     * @param trainer the regression trainer, invoked once per round.
     * @param x the samples.
     * @param y the responses of the samples.
     * @param measures the performance measures to compute.
     * @return a {@code k x measures.length} matrix; row {@code r} holds the scores of round {@code r}.
     * @throws IllegalArgumentException if {@code k < 2}.
     */
    public static <T> double[][] bootstrap(int k, RegressionTrainer<T> trainer, T[] x, double[] y, RegressionMeasure[] measures) {
        checkK(k, INVALID_BOOTSTRAP_K);

        double[][] scores = new double[k][];
        Bootstrap bootstrap = new Bootstrap(x.length, k);
        for (int round = 0; round < k; round++) {
            int[] trainIndex = bootstrap.train[round];
            Regression<T> model = trainer.train(Math.slice(x, trainIndex), Math.slice(y, trainIndex));

            int[] testIndex = bootstrap.test[round];
            double[] truth = new double[testIndex.length];
            double[] prediction = new double[testIndex.length];
            for (int j = 0; j < testIndex.length; j++) {
                truth[j] = y[testIndex[j]];
                prediction[j] = model.predict(x[testIndex[j]]);
            }

            scores[round] = measureAll(measures, truth, prediction);
        }
        return scores;
    }

    /** Throws {@link IllegalArgumentException} with {@code messagePrefix + k} when {@code k < 2}. */
    private static void checkK(int k, String messagePrefix) {
        if (k < 2) {
            throw new IllegalArgumentException(messagePrefix + k);
        }
    }

    /** Applies {@code classifier} to every sample and returns the labels in sample order. */
    private static <T> int[] predict(Classifier<T> classifier, T[] x) {
        int[] prediction = new int[x.length];
        for (int i = 0; i < x.length; i++) {
            prediction[i] = classifier.predict(x[i]);
        }
        return prediction;
    }

    /** Applies {@code regression} to every sample and returns the responses in sample order. */
    private static <T> double[] predict(Regression<T> regression, T[] x) {
        double[] prediction = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            prediction[i] = regression.predict(x[i]);
        }
        return prediction;
    }

    /**
     * Computes the LOOCV prediction for every sample.
     *
     * <p>Each entry {@code prediction[j]} comes from a model trained without sample {@code j}.
     */
    private static <T> int[] loocvPredict(ClassifierTrainer<T> trainer, T[] x, int[] y) {
        int n = x.length;
        int[] prediction = new int[n];
        LOOCV loocv = new LOOCV(n);
        for (int i = 0; i < n; i++) {
            int[] trainIndex = loocv.train[i];
            int held = loocv.test[i];
            Classifier<T> model = trainer.train(Math.slice(x, trainIndex), Math.slice(y, trainIndex));
            prediction[held] = model.predict(x[held]);
        }
        return prediction;
    }

    /**
     * Computes the LOOCV prediction for every sample.
     *
     * <p>Each entry {@code prediction[j]} comes from a model trained without sample {@code j}.
     */
    private static <T> double[] loocvPredict(RegressionTrainer<T> trainer, T[] x, double[] y) {
        int n = x.length;
        double[] prediction = new double[n];
        LOOCV loocv = new LOOCV(n);
        for (int i = 0; i < n; i++) {
            int[] trainIndex = loocv.train[i];
            int held = loocv.test[i];
            Regression<T> model = trainer.train(Math.slice(x, trainIndex), Math.slice(y, trainIndex));
            prediction[held] = model.predict(x[held]);
        }
        return prediction;
    }

    /**
     * Validates {@code k} and then computes the out-of-fold prediction for every sample
     * in k-fold cross validation.
     */
    private static <T> int[] cvPredict(int k, ClassifierTrainer<T> trainer, T[] x, int[] y) {
        checkK(k, INVALID_CV_K);

        int n = x.length;
        int[] prediction = new int[n];
        CrossValidation folds = new CrossValidation(n, k);
        for (int fold = 0; fold < k; fold++) {
            int[] trainIndex = folds.train[fold];
            Classifier<T> model = trainer.train(Math.slice(x, trainIndex), Math.slice(y, trainIndex));
            for (int j : folds.test[fold]) {
                prediction[j] = model.predict(x[j]);
            }
        }
        return prediction;
    }

    /**
     * Validates {@code k} and then computes the out-of-fold prediction for every sample
     * in k-fold cross validation.
     */
    private static <T> double[] cvPredict(int k, RegressionTrainer<T> trainer, T[] x, double[] y) {
        checkK(k, INVALID_CV_K);

        int n = x.length;
        double[] prediction = new double[n];
        CrossValidation folds = new CrossValidation(n, k);
        for (int fold = 0; fold < k; fold++) {
            int[] trainIndex = folds.train[fold];
            Regression<T> model = trainer.train(Math.slice(x, trainIndex), Math.slice(y, trainIndex));
            for (int j : folds.test[fold]) {
                prediction[j] = model.predict(x[j]);
            }
        }
        return prediction;
    }

    /** Evaluates every classification measure on the same (truth, prediction) pair. */
    private static double[] measureAll(ClassificationMeasure[] measures, int[] truth, int[] prediction) {
        double[] result = new double[measures.length];
        for (int m = 0; m < measures.length; m++) {
            result[m] = measures[m].measure(truth, prediction);
        }
        return result;
    }

    /** Evaluates every regression measure on the same (truth, prediction) pair. */
    private static double[] measureAll(RegressionMeasure[] measures, double[] truth, double[] prediction) {
        double[] result = new double[measures.length];
        for (int m = 0; m < measures.length; m++) {
            result[m] = measures[m].measure(truth, prediction);
        }
        return result;
    }

    /** Extracts column 0 of a row-major matrix, returning one entry per row. */
    private static double[] firstColumn(double[][] matrix) {
        double[] column = new double[matrix.length];
        for (int r = 0; r < matrix.length; r++) {
            column[r] = matrix[r][0];
        }
        return column;
    }
}
