package smile.feature.imputation;

import java.lang.reflect.Array;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import smile.data.AbstractTuple;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.measure.NominalScale;
import smile.data.transform.Transform;
import smile.data.type.DataType;
import smile.data.type.DataTypes;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.data.vector.BaseVector;
import smile.data.vector.BooleanVector;
import smile.data.vector.ByteVector;
import smile.data.vector.CharVector;
import smile.data.vector.DoubleVector;
import smile.data.vector.FloatVector;
import smile.data.vector.IntVector;
import smile.data.vector.LongVector;
import smile.data.vector.ShortVector;
import smile.data.vector.Vector;
import smile.math.MathEx;
import smile.sort.IQAgent;

/**
 * Imputes missing values with a fixed replacement value per column.
 *
 * <p>A value counts as missing when it is {@code null} or a {@link Number} whose
 * {@code doubleValue()} is NaN. Replacement values are keyed by column name. Columns without
 * an entry are left untouched. Use {@link #fit(DataFrame, double, double, String...)} to learn
 * replacement values from data.
 */
public class SimpleImputer implements Transform {
    /** Replacement value for each imputed column, keyed by column name. */
    private final Map<String, Object> values;

    /**
     * Creates an imputer from the given replacement values.
     *
     * @param map the replacement value for each column, keyed by column name. The map is
     *            copied, preserving its iteration order, so later changes to it do not affect
     *            this imputer.
     */
    public SimpleImputer(Map<String, Object> map) {
        this.values = new LinkedHashMap<>(map);
    }

    /**
     * Tests whether a value counts as missing.
     *
     * @param obj the value to test.
     * @return true if {@code obj} is null, or is a Number whose double value is NaN.
     */
    public static boolean isMissing(Object obj) {
        if (obj == null) {
            return true;
        }
        return obj instanceof Number && Double.isNaN(((Number) obj).doubleValue());
    }

    /**
     * Tests whether any field of a tuple is missing, as defined by {@link #isMissing(Object)}.
     *
     * @param tuple the tuple to inspect.
     * @return true if at least one field is missing.
     */
    public static boolean hasMissing(Tuple tuple) {
        int length = tuple.length();
        for (int i = 0; i < length; i++) {
            if (isMissing(tuple.get(i))) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns a view of the tuple in which missing fields are replaced.
     *
     * <p>The replacement is looked up on each {@code get} call. A missing field whose column
     * has no replacement value reads as {@code null}.
     *
     * @param tuple the input tuple.
     * @return a tuple with the same schema as {@code tuple}.
     */
    public Tuple apply(final Tuple tuple) {
        final StructType schema = tuple.schema();
        return new AbstractTuple() {
            public Object get(int i) {
                Object value = tuple.get(i);
                return SimpleImputer.isMissing(value)
                        ? SimpleImputer.this.values.get(schema.field(i).name)
                        : value;
            }

            public StructType schema() {
                return schema;
            }
        };
    }

    /**
     * Returns a new data frame with missing values replaced.
     *
     * <p>Columns without a replacement value are reused as-is. In primitive double/float
     * columns NaN is missing. In boxed columns {@code null} is missing, and so is NaN for boxed
     * Double/Float. Other primitive columns cannot represent a missing value and are reused
     * as-is. Columns are processed with a parallel stream.
     *
     * @param data the input data frame.
     * @return the imputed data frame.
     */
    public DataFrame apply(DataFrame data) {
        int n = data.nrow();
        StructType schema = data.schema();
        BaseVector[] vectors = new BaseVector[schema.length()];
        IntStream.range(0, schema.length()).parallel().forEach(j -> {
            StructField field = schema.field(j);
            Object value = values.get(field.name);
            if (value == null) {
                vectors[j] = data.column(j);
            } else {
                vectors[j] = imputeColumn(data, j, field, value, n);
            }
        });
        return DataFrame.of(vectors);
    }

    /** Builds a copy of column {@code j} in which missing entries are replaced by {@code value}. */
    private static BaseVector imputeColumn(DataFrame data, int j, StructField field, Object value, int n) {
        DataType type = field.type;
        if (type.id() == DataType.ID.Double) {
            double fill = ((Number) value).doubleValue();
            double[] column = data.doubleVector(j).array();
            double[] x = new double[n];
            for (int i = 0; i < n; i++) {
                x[i] = Double.isNaN(column[i]) ? fill : column[i];
            }
            return DoubleVector.of(field, x);
        }
        if (type.id() == DataType.ID.Float) {
            float fill = ((Number) value).floatValue();
            float[] column = data.floatVector(j).array();
            float[] x = new float[n];
            for (int i = 0; i < n; i++) {
                x[i] = Float.isNaN(column[i]) ? fill : column[i];
            }
            return FloatVector.of(field, x);
        }
        if (type.isPrimitive()) {
            // Non-floating primitive columns (int, long, boolean, char, ...) have no missing
            // representation, so there is nothing to impute. fit() still learns values for such
            // columns (e.g. nominal int columns), so rejecting them would break fit(df).apply(df).
            return data.column(j);
        }
        if (type == DataTypes.BooleanObjectType) {
            boolean fill = ((Boolean) value).booleanValue();
            boolean[] x = new boolean[n];
            for (int i = 0; i < n; i++) {
                Boolean v = (Boolean) data.get(i, j);
                x[i] = v == null ? fill : v.booleanValue();
            }
            return BooleanVector.of(field, x);
        }
        if (type == DataTypes.ByteObjectType) {
            byte fill = ((Number) value).byteValue();
            byte[] x = new byte[n];
            for (int i = 0; i < n; i++) {
                Byte v = (Byte) data.get(i, j);
                x[i] = v == null ? fill : v.byteValue();
            }
            return ByteVector.of(field, x);
        }
        if (type == DataTypes.CharObjectType) {
            char fill = ((Character) value).charValue();
            char[] x = new char[n];
            for (int i = 0; i < n; i++) {
                Character v = (Character) data.get(i, j);
                x[i] = v == null ? fill : v.charValue();
            }
            return CharVector.of(field, x);
        }
        if (type == DataTypes.DoubleObjectType) {
            double fill = ((Number) value).doubleValue();
            double[] x = new double[n];
            for (int i = 0; i < n; i++) {
                Double v = (Double) data.get(i, j);
                x[i] = (v == null || v.isNaN()) ? fill : v.doubleValue();
            }
            return DoubleVector.of(field, x);
        }
        if (type == DataTypes.FloatObjectType) {
            float fill = ((Number) value).floatValue();
            float[] x = new float[n];
            for (int i = 0; i < n; i++) {
                Float v = (Float) data.get(i, j);
                x[i] = (v == null || v.isNaN()) ? fill : v.floatValue();
            }
            return FloatVector.of(field, x);
        }
        if (type == DataTypes.IntegerObjectType) {
            int fill = ((Number) value).intValue();
            int[] x = new int[n];
            for (int i = 0; i < n; i++) {
                Integer v = (Integer) data.get(i, j);
                x[i] = v == null ? fill : v.intValue();
            }
            return IntVector.of(field, x);
        }
        if (type == DataTypes.LongObjectType) {
            long fill = ((Number) value).longValue();
            long[] x = new long[n];
            for (int i = 0; i < n; i++) {
                Long v = (Long) data.get(i, j);
                x[i] = v == null ? fill : v.longValue();
            }
            return LongVector.of(field, x);
        }
        if (type == DataTypes.ShortObjectType) {
            short fill = ((Number) value).shortValue();
            short[] x = new short[n];
            for (int i = 0; i < n; i++) {
                Short v = (Short) data.get(i, j);
                x[i] = v == null ? fill : v.shortValue();
            }
            return ShortVector.of(field, x);
        }
        // NOTE(review): the array's component type is the replacement's runtime class, so a
        // column value not assignable to that class would raise ArrayStoreException here.
        Object[] x = (Object[]) Array.newInstance(value.getClass(), n);
        for (int i = 0; i < n; i++) {
            Object v = data.get(i, j);
            x[i] = v == null ? value : v;
        }
        return Vector.of(field, x);
    }

    @Override
    public String toString() {
        return values.entrySet().stream()
                .map(e -> e.getKey() + " -> " + e.getValue())
                .collect(Collectors.joining(",\n  ", "SimpleImputer(\n  ", "\n)"));
    }

    /**
     * Fits an imputer that uses the median for numeric columns.
     * Equivalent to {@code fit(data, 0.5, 0.5, columns)}.
     *
     * @param data the training data.
     * @param columns the columns to impute. If empty, all columns are used.
     * @return the fitted imputer.
     */
    public static SimpleImputer fit(DataFrame data, String... columns) {
        return fit(data, 0.5d, 0.5d, columns);
    }

    /**
     * Fits an imputer from the observed (non-missing) values of each column.
     *
     * <p>Replacement values are chosen per column as follows:
     * <ul>
     * <li>String columns: the empty string.</li>
     * <li>Boolean, char and nominal-scale columns: the mode of the observed values.</li>
     * <li>Other numeric columns: the mean of the observed values lying between the
     *     {@code lower} and {@code upper} quantiles, which are estimated with {@link IQAgent}.
     *     When {@code lower == upper}, the quantile itself is used.</li>
     * </ul>
     * Columns with no observed values get no replacement and are therefore left untouched by
     * {@code apply}.
     *
     * @param data the training data.
     * @param lower the lower quantile, in [0, 1].
     * @param upper the upper quantile, in [lower, 1].
     * @param columns the columns to impute. If empty, all columns are used.
     * @return the fitted imputer.
     * @throws IllegalArgumentException if the data frame is empty or the quantile bounds are
     *         invalid (including NaN).
     */
    public static SimpleImputer fit(DataFrame data, double lower, double upper, String... columns) {
        if (data.isEmpty()) {
            throw new IllegalArgumentException("Empty data frame");
        }
        // Negated comparisons so that NaN bounds are rejected too.
        if (!(lower >= 0.0d)) {
            throw new IllegalArgumentException("Invalid lower: " + lower);
        }
        if (!(upper <= 1.0d)) {
            throw new IllegalArgumentException("Invalid upper: " + upper);
        }
        if (lower > upper) {
            throw new IllegalArgumentException(String.format("Invalid lower=%f > upper=%f", lower, upper));
        }

        StructType schema = data.schema();
        String[] names = columns.length == 0 ? data.names() : columns;
        Map<String, Object> fitted = new HashMap<>();
        for (String column : names) {
            StructField field = schema.field(column);
            if (field.type.isString()) {
                fitted.put(field.name, "");
            } else if (field.type.isBoolean() || field.type.isChar() || field.measure instanceof NominalScale) {
                // Assumes toIntArray() encodes missing entries as Integer.MIN_VALUE (inferred
                // from this omit call) -- TODO confirm against the vector implementations.
                int[] codes = MathEx.omit(data.column(column).toIntArray(), Integer.MIN_VALUE);
                if (codes.length == 0) {
                    continue; // nothing observed to learn from
                }
                int mode = MathEx.mode(codes);
                if (field.type.isBoolean()) {
                    fitted.put(field.name, mode != 0);
                } else if (field.type.isChar()) {
                    fitted.put(field.name, (char) mode);
                } else {
                    fitted.put(field.name, mode);
                }
            } else if (field.type.isNumeric()) {
                double[] observed = MathEx.omitNaN(data.column(column).toDoubleArray());
                if (observed.length == 0) {
                    continue; // nothing observed to learn from
                }
                fitted.put(field.name, trimmedMean(observed, lower, upper));
            }
        }
        return new SimpleImputer(fitted);
    }

    /**
     * Returns the mean of the values within the estimated [lower, upper] quantile range, or the
     * quantile itself when {@code lower == upper}. If no value falls inside the estimated range,
     * the midpoint of the two quantile estimates is returned instead of dividing by zero.
     */
    private static double trimmedMean(double[] x, double lower, double upper) {
        IQAgent agent = new IQAgent();
        for (double xi : x) {
            agent.add(xi);
        }
        if (lower == upper) {
            return agent.quantile(lower);
        }

        double lo = agent.quantile(lower);
        double hi = agent.quantile(upper);
        int count = 0;
        double sum = 0.0d;
        for (double xi : x) {
            if (xi >= lo && xi <= hi) {
                count++;
                sum += xi;
            }
        }
        return count > 0 ? sum / count : (lo + hi) / 2;
    }

    /**
     * Replaces NaN entries of a matrix with the mean of the non-NaN entries in the same column.
     * The input is not modified; a copy made by {@code MathEx.clone} is returned.
     *
     * @param data the data matrix, one row per sample. Every row is read up to the length of
     *             row 0.
     * @return the imputed copy.
     * @throws IllegalArgumentException if {@code data} has no rows, or if an entire row or an
     *         entire column is NaN.
     */
    public static double[][] impute(double[][] data) {
        if (data.length == 0) {
            throw new IllegalArgumentException("Empty data");
        }

        int p = data[0].length;
        int[] observed = new int[p];
        double[] mean = new double[p];
        for (int i = 0; i < data.length; i++) {
            int rowMissing = 0;
            for (int j = 0; j < p; j++) {
                double x = data[i][j];
                if (Double.isNaN(x)) {
                    rowMissing++;
                } else {
                    observed[j]++;
                    mean[j] += x;
                }
            }
            if (rowMissing == p) {
                throw new IllegalArgumentException("The whole row " + i + " is missing");
            }
        }

        for (int j = 0; j < p; j++) {
            if (observed[j] == 0) {
                throw new IllegalArgumentException("The whole column " + j + " is missing");
            }
            mean[j] /= observed[j];
        }

        double[][] result = MathEx.clone(data);
        for (double[] row : result) {
            for (int j = 0; j < p; j++) {
                if (Double.isNaN(row[j])) {
                    row[j] = mean[j];
                }
            }
        }
        return result;
    }
}
