package smile.classification;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Properties;
import java.util.function.ToDoubleFunction;
import java.util.function.ToIntFunction;
import java.util.stream.IntStream;
import smile.data.Dataset;
import smile.data.Instance;
import smile.math.MathEx;

/* loaded from: input_file:smile/classification/Classifier.class */
/**
 * A classifier maps an instance of type {@code T} to an integer class label.
 *
 * <p>Besides the mandatory {@link #predict(Object)}, the interface offers
 * optional capabilities that are disabled by default and signalled by an
 * {@link UnsupportedOperationException}:
 * <ul>
 *   <li>a real-valued {@link #score(Object)};</li>
 *   <li>"soft" classification via {@link #predict(Object, double[])};</li>
 *   <li>online learning via {@link #update(Object, int)}.</li>
 * </ul>
 * The interface also adapts the classifier to {@link ToIntFunction}
 * (delegating to {@code predict}) and {@link ToDoubleFunction}
 * (delegating to {@code score}).
 *
 * @param <T> the type of the instances being classified.
 */
public interface Classifier<T> extends ToIntFunction<T>, ToDoubleFunction<T>, Serializable {

    /**
     * Fits a classifier of type {@code M} from training data.
     *
     * @param <T> the type of the training instances.
     * @param <M> the type of the fitted model.
     */
    public interface Trainer<T, M extends Classifier<T>> {
        /**
         * Fits a model using an empty {@link Properties} object, i.e. without
         * any explicitly supplied hyper-parameters.
         *
         * @param tArr the training instances.
         * @param iArr the training labels.
         * @return the fitted model.
         */
        default M fit(T[] tArr, int[] iArr) {
            return fit(tArr, iArr, new Properties());
        }

        /**
         * Fits a model.
         *
         * @param tArr       the training instances.
         * @param iArr       the training labels.
         * @param properties implementation-specific settings (presumably the
         *                   hyper-parameters; confirm with the implementation).
         * @return the fitted model.
         */
        M fit(T[] tArr, int[] iArr, Properties properties);
    }

    /**
     * Returns the number of classes.
     *
     * @return the number of classes.
     */
    int numClasses();

    /**
     * Returns the class labels known to this classifier.
     *
     * @return the class labels.
     */
    int[] classes();

    /**
     * Predicts the class label of an instance.
     *
     * @param t the instance to classify.
     * @return the predicted class label.
     */
    int predict(T t);

    /**
     * Returns a real-valued score for an instance. The default implementation
     * does not support scoring.
     *
     * @param t the instance to score.
     * @return the score.
     * @throws UnsupportedOperationException always, unless overridden.
     */
    default double score(T t) {
        throw new UnsupportedOperationException();
    }

    /**
     * {@link ToIntFunction} adapter; equivalent to {@link #predict(Object)}.
     *
     * @param t the instance to classify.
     * @return the predicted class label.
     */
    @Override
    default int applyAsInt(T t) {
        return predict(t);
    }

    /**
     * {@link ToDoubleFunction} adapter; equivalent to {@link #score(Object)}.
     *
     * @param t the instance to score.
     * @return the score.
     */
    @Override
    default double applyAsDouble(T t) {
        return score(t);
    }

    /**
     * Predicts the class label of every instance in an array.
     *
     * @param tArr the instances to classify.
     * @return the predicted labels, in the order of {@code tArr}.
     */
    default int[] predict(T[] tArr) {
        return Arrays.stream(tArr).mapToInt(this::predict).toArray();
    }

    /**
     * Predicts the class label of every instance in a list.
     *
     * @param list the instances to classify.
     * @return the predicted labels, in the encounter order of {@code list}.
     */
    default int[] predict(List<T> list) {
        return list.stream().mapToInt(this::predict).toArray();
    }

    /**
     * Predicts the class label of every instance in a dataset.
     *
     * @param dataset the instances to classify.
     * @return the predicted labels, in the order of the dataset's stream.
     */
    default int[] predict(Dataset<T> dataset) {
        return dataset.stream().mapToInt(this::predict).toArray();
    }

    /**
     * Probes whether this classifier supports soft classification by calling
     * {@link #predict(Object, double[])} with a {@code null} instance.
     *
     * <p>The result is {@code false} when the probe throws the
     * {@link UnsupportedOperationException} raised by the default
     * implementation, and also when the probe completes without throwing.
     * Any other exception (including an {@code UnsupportedOperationException}
     * with a different or absent message) yields {@code true}.
     *
     * @return whether soft classification appears to be supported.
     */
    default boolean soft() {
        try {
            predict(null, new double[numClasses()]);
            return false;
        } catch (UnsupportedOperationException e) {
            // Constant-first comparison: getMessage() may legitimately be null.
            return !"soft classification with a hard classifier".equals(e.getMessage());
        } catch (Exception e) {
            return true;
        }
    }

    /**
     * Predicts the class label of an instance and writes per-class output
     * into {@code dArr} (presumably the posterior probabilities; confirm with
     * the implementation). The default implementation does not support this.
     *
     * @param t    the instance to classify.
     * @param dArr the output buffer.
     * @return the predicted class label.
     * @throws UnsupportedOperationException always, unless overridden.
     */
    default int predict(T t, double[] dArr) {
        throw new UnsupportedOperationException("soft classification with a hard classifier");
    }

    /**
     * Soft-classifies every instance in an array using a parallel stream.
     *
     * @param tArr the instances to classify.
     * @param dArr the output buffers; {@code dArr[i]} is passed to
     *             {@link #predict(Object, double[])} together with {@code tArr[i]}.
     * @return the predicted labels, in the order of {@code tArr}.
     */
    default int[] predict(T[] tArr, double[][] dArr) {
        return IntStream.range(0, tArr.length)
                .parallel()
                .map(i -> predict(tArr[i], dArr[i]))
                .toArray();
    }

    /**
     * Soft-classifies every instance in a list using a parallel stream.
     *
     * <p>One buffer of length {@link #numClasses()} is allocated per instance
     * and appended to {@code list2} before the predictions are computed.
     *
     * @param list  the instances to classify.
     * @param list2 the list that receives the newly allocated output buffers.
     * @return the predicted labels, in index order of {@code list}.
     */
    default int[] predict(List<T> list, List<double[]> list2) {
        int n = list.size();
        double[][] buffers = new double[n][numClasses()];
        Collections.addAll(list2, buffers);
        return IntStream.range(0, n)
                .parallel()
                .map(i -> predict(list.get(i), buffers[i]))
                .toArray();
    }

    /**
     * Soft-classifies every instance in a dataset using a parallel stream.
     *
     * <p>One buffer of length {@link #numClasses()} is allocated per instance
     * and appended to {@code list} before the predictions are computed.
     *
     * @param dataset the instances to classify.
     * @param list    the list that receives the newly allocated output buffers.
     * @return the predicted labels, in index order of {@code dataset}.
     */
    default int[] predict(Dataset<T> dataset, List<double[]> list) {
        int n = dataset.size();
        double[][] buffers = new double[n][numClasses()];
        Collections.addAll(list, buffers);
        return IntStream.range(0, n)
                .parallel()
                .map(i -> predict(dataset.get(i), buffers[i]))
                .toArray();
    }

    /**
     * Probes whether this classifier supports online learning by calling
     * {@link #update(Object, int)} with a {@code null} instance and label 0.
     *
     * <p>The result is {@code false} when the probe throws the
     * {@link UnsupportedOperationException} raised by the default
     * implementation, and also when the probe completes without throwing.
     * Any other exception (including an {@code UnsupportedOperationException}
     * with a different or absent message) yields {@code true}.
     *
     * @return whether online learning appears to be supported.
     */
    default boolean online() {
        try {
            update(null, 0);
            return false;
        } catch (UnsupportedOperationException e) {
            // Constant-first comparison: getMessage() may legitimately be null.
            return !"update a batch learner".equals(e.getMessage());
        } catch (Exception e) {
            return true;
        }
    }

    /**
     * Updates the model with a single labelled instance. The default
     * implementation does not support online learning.
     *
     * @param t the instance.
     * @param i the class label of the instance.
     * @throws UnsupportedOperationException always, unless overridden.
     */
    default void update(T t, int i) {
        throw new UnsupportedOperationException("update a batch learner");
    }

    /**
     * Updates the model sequentially with a batch of labelled instances.
     *
     * @param tArr the instances.
     * @param iArr the class labels; must have the same length as {@code tArr}.
     * @throws IllegalArgumentException if the two arrays differ in length.
     */
    default void update(T[] tArr, int[] iArr) {
        if (tArr.length != iArr.length) {
            throw new IllegalArgumentException(String.format(
                    "Input vector x of size %d not equal to length %d of y", tArr.length, iArr.length));
        }
        for (int k = 0; k < tArr.length; k++) {
            update(tArr[k], iArr[k]);
        }
    }

    /**
     * Updates the model with every labelled instance of a dataset, using each
     * element's {@code x()} as the instance and {@code label()} as its class.
     *
     * @param dataset the labelled instances.
     */
    default void update(Dataset<Instance<T>> dataset) {
        dataset.stream().forEach(instance -> update(instance.x(), instance.label()));
    }

    /**
     * Combines several classifiers into one.
     *
     * <p>Hard predictions are the mode of the members' predictions; soft
     * predictions average the members' per-class outputs and return the index
     * of the largest average. The ensemble reports itself as soft (online)
     * only if every member does; both flags are evaluated once, at creation.
     * {@code numClasses()} and {@code classes()} are taken from the first member.
     *
     * @param classifierArr the member classifiers.
     * @param <T> the type of the instances being classified.
     * @return the ensemble classifier.
     */
    @SafeVarargs
    static <T> Classifier<T> ensemble(final Classifier<T>... classifierArr) {
        return new Classifier<T>() {
            /** Whether all members support soft classification. */
            private final boolean soft = Arrays.stream(classifierArr).allMatch(Classifier::soft);
            /** Whether all members support online learning. */
            private final boolean online = Arrays.stream(classifierArr).allMatch(Classifier::online);

            @Override
            public boolean soft() {
                return this.soft;
            }

            @Override
            public boolean online() {
                return this.online;
            }

            @Override
            public int numClasses() {
                return classifierArr[0].numClasses();
            }

            @Override
            public int[] classes() {
                return classifierArr[0].classes();
            }

            @Override
            public int predict(T t) {
                // Majority vote over the members' hard predictions.
                int[] votes = new int[classifierArr.length];
                for (int k = 0; k < classifierArr.length; k++) {
                    votes[k] = classifierArr[k].predict(t);
                }
                return MathEx.mode(votes);
            }

            @Override
            public int predict(T t, double[] dArr) {
                Arrays.fill(dArr, 0.0);
                // Scratch buffer reused for each member's output.
                double[] memberOutput = new double[dArr.length];
                for (Classifier<T> member : classifierArr) {
                    member.predict(t, memberOutput);
                    for (int k = 0; k < memberOutput.length; k++) {
                        dArr[k] += memberOutput[k];
                    }
                }
                // Turn the accumulated sums into the average over all members.
                for (int k = 0; k < dArr.length; k++) {
                    dArr[k] /= classifierArr.length;
                }
                return MathEx.whichMax(dArr);
            }

            @Override
            public void update(T t, int i) {
                for (Classifier<T> member : classifierArr) {
                    member.update(t, i);
                }
            }
        };
    }
}
