package ai.libs.jaicore.ml.core.dataset;

import ai.libs.jaicore.basic.reconstruction.ReconstructionInstruction;
import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.jaicore.basic.sets.SetUtil;
import ai.libs.jaicore.ml.core.dataset.schema.LabeledInstanceSchema;
import ai.libs.jaicore.ml.core.dataset.schema.attribute.IntBasedCategoricalAttribute;
import ai.libs.jaicore.ml.core.dataset.schema.attribute.NumericAttribute;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.api4.java.ai.ml.core.dataset.schema.ILabeledInstanceSchema;
import org.api4.java.ai.ml.core.dataset.schema.attribute.IAttribute;
import org.api4.java.ai.ml.core.dataset.schema.attribute.ICategoricalAttribute;
import org.api4.java.ai.ml.core.dataset.schema.attribute.INumericAttribute;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import org.api4.java.common.reconstruction.IReconstructible;

/* loaded from: input_file:ai/libs/jaicore/ml/core/dataset/DatasetUtil.class */
/**
 * Static helper methods for converting the label type of labeled datasets, building datasets from
 * collections of maps, and expanding datasets with derived (squared, logarithmic, product) attributes.
 */
public class DatasetUtil {

    /** Expansion code: add the square of every numeric attribute (new attribute name suffix {@code _2}). */
    public static final int EXPANSION_SQUARES = 1;

    /** Expansion code: add the natural logarithm of every numeric attribute (new attribute name suffix {@code _log}). */
    public static final int EXPANSION_LOGARITHM = 2;

    /** Expansion code: add the product of every pair and every triple of attributes (new attribute name prefix {@code x}). */
    public static final int EXPANSION_PRODUCTS = 3;

    private DatasetUtil() {
        /* utility class; not meant to be instantiated */
    }

    /**
     * Counts how often each label occurs in the given dataset.
     *
     * @param dataset the dataset whose labels are counted
     * @return a mutable map from each label value occurring in the dataset to its number of occurrences
     */
    public static Map<Object, Integer> getLabelCounts(final ILabeledDataset<?> dataset) {
        Map<Object, Integer> labelCounts = new HashMap<>();
        dataset.forEach(instance -> labelCounts.merge(instance.getLabel(), 1, Integer::sum));
        return labelCounts;
    }

    /**
     * Computes the sum, over all labels occurring in either dataset, of the absolute difference of the
     * label's occurrence counts in the two datasets. A label missing from one dataset counts as 0 there.
     *
     * @param firstDataset the first dataset
     * @param secondDataset the second dataset
     * @return the total absolute difference of label counts
     */
    public static int getLabelCountDifference(final ILabeledDataset<?> firstDataset, final ILabeledDataset<?> secondDataset) {
        Map<Object, Integer> firstCounts = getLabelCounts(firstDataset);
        Map<Object, Integer> secondCounts = getLabelCounts(secondDataset);
        Set<Object> allLabels = new HashSet<>(firstCounts.keySet());
        allLabels.addAll(secondCounts.keySet());
        int difference = 0;
        for (Object label : allLabels) {
            // a label may occur in only one of the two datasets, so default missing counts to 0
            difference += Math.abs(firstCounts.getOrDefault(label, 0) - secondCounts.getOrDefault(label, 0));
        }
        return difference;
    }

    /**
     * Creates a copy of the given dataset whose label attribute is replaced and whose labels are mapped
     * through {@code labelMap}.
     *
     * @param dataset the source dataset; only dense and sparse instances are supported
     * @param newLabelAttribute the label attribute of the resulting dataset
     * @param labelMap maps each original label to its value under the new label attribute
     * @param methodNameForReconstruction name of the public one-argument {@code DatasetUtil} method that is recorded
     *        as reconstruction instruction if the source dataset is reconstructible
     * @return the converted dataset
     * @throws UnsupportedOperationException if an instance is neither dense nor sparse, or the reconstruction method
     *         cannot be resolved
     * @throws IllegalStateException if a mapped label is not a valid value of the new label attribute
     */
    private static ILabeledDataset<?> convertTargetOfDataset(final ILabeledDataset<?> dataset, final IAttribute newLabelAttribute, final Map<Object, ? extends Object> labelMap,
            final String methodNameForReconstruction) {
        Dataset convertedDataset = new Dataset(new LabeledInstanceSchema(dataset.getRelationName(), new ArrayList<>(dataset.getInstanceSchema().getAttributeList()), newLabelAttribute));
        int numAttributes = dataset.getNumAttributes();
        for (ILabeledInstance instance : dataset) {
            Object newLabel = labelMap.get(instance.getLabel());
            ILabeledInstance convertedInstance;
            if (instance instanceof DenseInstance) {
                convertedInstance = new DenseInstance(instance.getAttributes(), newLabel);
            } else if (instance instanceof SparseInstance) {
                convertedInstance = new SparseInstance(numAttributes, ((SparseInstance) instance).getAttributeMap(), newLabel);
            } else {
                throw new UnsupportedOperationException("Cannot convert the label of instances of type " + instance.getClass().getName());
            }
            if (!convertedDataset.getLabelAttribute().isValidValue(convertedInstance.getLabel())) {
                throw new IllegalStateException("Value " + convertedInstance.getLabel() + " is not a valid label value for label attribute " + convertedDataset.getLabelAttribute());
            }
            convertedDataset.add(convertedInstance);
        }

        /* carry over the construction plan of the source and append the conversion step itself */
        if (dataset instanceof IReconstructible) {
            ((IReconstructible) dataset).getConstructionPlan().getInstructions().forEach(convertedDataset::addInstruction);
            try {
                convertedDataset.addInstruction(new ReconstructionInstruction(DatasetUtil.class.getMethod(methodNameForReconstruction, ILabeledDataset.class), new Object[] { "this" }));
            } catch (NoSuchMethodException | SecurityException e) {
                throw new UnsupportedOperationException(e);
            }
        }
        return convertedDataset;
    }

    /**
     * Converts a dataset into one with a categorical label. If the label attribute is already categorical, the
     * dataset itself is returned. Otherwise, the string representations of all occurring labels become the
     * categories of a new {@link IntBasedCategoricalAttribute} with the old label attribute's name.
     *
     * @param dataset the dataset to convert
     * @return a dataset with a categorical label attribute
     */
    public static ILabeledDataset<?> convertToClassificationDataset(final ILabeledDataset<?> dataset) {
        IAttribute labelAttribute = dataset.getLabelAttribute();
        if (labelAttribute instanceof ICategoricalAttribute) {
            return dataset;
        }
        Set<String> labelNames = dataset.stream().map(instance -> instance.getLabel().toString()).collect(Collectors.toSet());
        IntBasedCategoricalAttribute newLabelAttribute = new IntBasedCategoricalAttribute(labelAttribute.getName(), new ArrayList<>(labelNames));
        Map<Object, Object> labelMap = new HashMap<>();
        for (ILabeledInstance instance : dataset) {
            Object label = instance.getLabel();
            if (!labelMap.containsKey(label)) {
                labelMap.put(label, newLabelAttribute.deserializeAttributeValue(label.toString()));
            }
        }
        return convertTargetOfDataset(dataset, newLabelAttribute, labelMap, "convertToClassificationDataset");
    }

    /**
     * Converts a dataset with a categorical label into one with a numeric label. If the label attribute is
     * already numeric, the dataset itself is returned. If every category name parses as a number, that number
     * becomes the numeric label; otherwise each category is encoded by its position in the category list.
     *
     * @param dataset the dataset to convert
     * @return a dataset with a numeric label attribute
     * @throws UnsupportedOperationException if the label attribute is neither numeric nor categorical
     */
    public static ILabeledDataset<?> convertToRegressionDataset(final ILabeledDataset<?> dataset) {
        IAttribute labelAttribute = dataset.getLabelAttribute();
        if (labelAttribute instanceof INumericAttribute) {
            return dataset;
        }
        if (!(labelAttribute instanceof ICategoricalAttribute)) {
            throw new UnsupportedOperationException("Cannot convert label attribute of type " + labelAttribute.getClass().getName() + " into a numeric label attribute.");
        }
        ICategoricalAttribute categoricalLabel = (ICategoricalAttribute) labelAttribute;
        NumericAttribute newLabelAttribute = new NumericAttribute(labelAttribute.getName());
        List<String> categories = categoricalLabel.getLabels();
        Map<Object, Double> labelMap = new HashMap<>();
        try {
            for (String category : categories) {
                labelMap.put(categoricalLabel.deserializeAttributeValue(category), parseNumericLabel(category));
            }
        } catch (NumberFormatException e) {
            /* at least one category is not numeric: encode every category by the position of its first occurrence */
            labelMap.clear();
            for (int i = 0; i < categories.size(); i++) {
                labelMap.putIfAbsent(categoricalLabel.deserializeAttributeValue(categories.get(i)), (double) i);
            }
        }
        return convertTargetOfDataset(dataset, newLabelAttribute, labelMap, "convertToRegressionDataset");
    }

    /** Parses a category name as an integer if possible, otherwise as a double; throws {@link NumberFormatException} if neither works. */
    private static double parseNumericLabel(final String category) {
        try {
            return Integer.parseInt(category);
        } catch (NumberFormatException e) {
            return Double.parseDouble(category);
        }
    }

    /**
     * Builds a dataset from a collection of maps, using the alphabetically sorted keys of the first map as
     * attribute order.
     *
     * @param collection the rows of the dataset; all maps must have the same key set
     * @param labelName the key that holds the (numeric) label
     * @return the created dataset
     * @throws IllegalArgumentException if the collection is empty
     */
    public static ILabeledDataset<?> getDatasetFromMapCollection(final Collection<Map<String, Object>> collection, final String labelName) {
        if (collection.isEmpty()) {
            throw new IllegalArgumentException("Cannot derive the attribute names from an empty collection of maps.");
        }
        List<String> attributeNames = collection.iterator().next().keySet().stream().sorted().collect(Collectors.toList());
        return getDatasetFromMapCollection(collection, labelName, attributeNames);
    }

    /**
     * Builds a dataset from a collection of maps. Attribute types are derived from the first map: {@link Number}
     * values yield numeric attributes and {@link Boolean} values yield categorical attributes with the
     * categories {@code "false"} and {@code "true"}. The label attribute is always numeric.
     *
     * @param collection the rows of the dataset
     * @param labelName the key that holds the label
     * @param attributeNames all keys, including the label key, in the desired attribute order
     * @return the created dataset with relation name {@code "rel"}
     * @throws IllegalStateException if some map's key set differs from {@code attributeNames}
     * @throws IllegalArgumentException if an attribute type must be inferred but the collection is empty
     * @throws UnsupportedOperationException if an attribute value is neither a number nor a boolean
     */
    public static ILabeledDataset<?> getDatasetFromMapCollection(final Collection<Map<String, Object>> collection, final String labelName, final List<String> attributeNames) {
        Set<String> expectedKeys = new HashSet<>(attributeNames);
        for (Map<String, Object> row : collection) {
            if (!expectedKeys.equals(row.keySet())) {
                throw new IllegalStateException("Keys " + row.keySet() + " of a map do not match the expected attribute names " + expectedKeys);
            }
        }

        /* the first row serves as prototype for the types of the attributes */
        Map<String, Object> prototype = collection.isEmpty() ? null : collection.iterator().next();
        List<IAttribute> attributes = new ArrayList<>();
        for (String attributeName : attributeNames) {
            if (attributeName.equals(labelName)) {
                continue;
            }
            if (prototype == null) {
                throw new IllegalArgumentException("Cannot infer the type of attribute \"" + attributeName + "\" from an empty collection of maps.");
            }
            Object value = prototype.get(attributeName);
            if (value instanceof Number) {
                attributes.add(new NumericAttribute(attributeName));
            } else if (value instanceof Boolean) {
                attributes.add(new IntBasedCategoricalAttribute(attributeName, Arrays.asList("false", "true")));
            } else {
                throw new UnsupportedOperationException("Cannot derive an attribute type for \"" + attributeName + "\" from value " + value);
            }
        }
        LabeledInstanceSchema schema = new LabeledInstanceSchema("rel", attributes, new NumericAttribute(labelName));
        Dataset dataset = new Dataset(schema);
        for (Map<String, Object> row : collection) {
            dataset.add(getInstanceFromMap(schema, row, labelName));
        }
        return dataset;
    }

    /**
     * Creates a dense instance from a map without any expansion functions, i.e. every schema attribute must
     * be present in the map.
     *
     * @param schema the schema the instance must comply with
     * @param map attribute values by attribute name, including the label
     * @param labelName the key that holds the label
     * @return the created instance
     */
    public static ILabeledInstance getInstanceFromMap(final ILabeledInstanceSchema schema, final Map<String, Object> map, final String labelName) {
        return getInstanceFromMap(schema, map, labelName, new HashMap<>());
    }

    /**
     * Creates a dense instance from a map. Schema attributes that are absent from the map are computed, in
     * schema order, by the expansion function registered for them.
     *
     * @param schema the schema the instance must comply with
     * @param map attribute values by attribute name, including the label
     * @param labelName the key that holds the label
     * @param expansionFunctions functions computing values of attributes that are absent from {@code map}
     * @return the created instance
     * @throws IllegalArgumentException if an attribute is absent from {@code map} and has no expansion function
     * @throws IllegalStateException if the created instance does not have as many attributes as the schema
     */
    public static ILabeledInstance getInstanceFromMap(final ILabeledInstanceSchema schema, final Map<String, Object> map, final String labelName,
            final Map<IAttribute, Function<ILabeledInstance, Double>> expansionFunctions) {
        List<Object> attributeValues = new ArrayList<>(schema.getNumAttributes());
        List<Integer> missingIndices = new ArrayList<>();
        int index = 0;
        for (IAttribute attribute : schema.getAttributeList()) {
            if (map.containsKey(attribute.getName())) {
                attributeValues.add(map.get(attribute.getName()));
            } else {
                missingIndices.add(index);
                attributeValues.add(null); // placeholder, filled in by the expansion function below
            }
            index++;
        }
        DenseInstance instance = new DenseInstance(attributeValues, map.get(labelName));
        if (instance.getNumAttributes() != schema.getNumAttributes()) {
            throw new IllegalStateException("Created dense instance with " + instance.getNumAttributes() + " attributes where the scheme requires " + schema.getNumAttributes());
        }
        for (int missingIndex : missingIndices) {
            IAttribute attribute = schema.getAttribute(missingIndex);
            Function<ILabeledInstance, Double> expansionFunction = expansionFunctions.get(attribute);
            if (expansionFunction == null) {
                throw new IllegalArgumentException("No value given for attribute " + attribute.getName() + " and no expansion function available to compute it.");
            }
            instance.setAttributeValue(missingIndex, expansionFunction.apply(instance));
        }
        return instance;
    }

    /**
     * Determines the new attributes of the requested expansions together with the functions computing them.
     * Squares and logarithms are created for numeric attributes only; products are created for every pair and
     * every triple of attributes, named {@code x_<name1>_<name2>[_<name3>]} with names in lexicographic order.
     *
     * @param dataset the dataset whose attributes are expanded
     * @param expansions any of {@link #EXPANSION_SQUARES}, {@link #EXPANSION_LOGARITHM}, {@link #EXPANSION_PRODUCTS}
     * @return the list of new attributes and a map assigning each new attribute its computation function
     * @throws InterruptedException if the thread is interrupted while the product attributes are enumerated
     * @throws UnsupportedOperationException if an unknown expansion code is given
     * @throws IllegalStateException if a product attribute name collides with an existing or already added attribute
     */
    public static Pair<List<IAttribute>, Map<IAttribute, Function<ILabeledInstance, Double>>> getPairOfNewAttributesAndExpansionMap(final ILabeledDataset<?> dataset, final int... expansions)
            throws InterruptedException {
        List<IAttribute> attributeList = dataset.getInstanceSchema().getAttributeList();
        boolean computeSquares = false;
        boolean computeLogs = false;
        boolean computeProducts = false;
        for (int expansion : expansions) {
            switch (expansion) {
            case EXPANSION_SQUARES:
                computeSquares = true;
                break;
            case EXPANSION_LOGARITHM:
                computeLogs = true;
                break;
            case EXPANSION_PRODUCTS:
                computeProducts = true;
                break;
            default:
                throw new UnsupportedOperationException("Unknown expansion " + expansion);
            }
        }

        List<IAttribute> newAttributes = new ArrayList<>();
        Map<IAttribute, Function<ILabeledInstance, Double>> transformations = new HashMap<>();
        for (int attId = 0; attId < dataset.getNumAttributes(); attId++) {
            final int attributeIndex = attId;
            IAttribute attribute = dataset.getAttribute(attId);
            if (!(attribute instanceof INumericAttribute)) {
                continue;
            }
            // squares and logarithms are independent; both are created if both are requested
            if (computeSquares) {
                NumericAttribute squareAttribute = new NumericAttribute(attribute.getName() + "_2");
                newAttributes.add(squareAttribute);
                transformations.put(squareAttribute, instance -> Math.pow(toDouble(instance.getAttributeValue(attributeIndex)), 2.0));
            }
            if (computeLogs) {
                NumericAttribute logAttribute = new NumericAttribute(attribute.getName() + "_log");
                newAttributes.add(logAttribute);
                transformations.put(logAttribute, instance -> Math.log(toDouble(instance.getAttributeValue(attributeIndex))));
            }
        }

        if (computeProducts) {
            /* enumerate pairs and triples directly instead of materializing the exponentially large powerset */
            int numAttributes = attributeList.size();
            for (int i = 0; i < numAttributes; i++) {
                for (int j = i + 1; j < numAttributes; j++) {
                    if (Thread.interrupted()) {
                        throw new InterruptedException("Computation of product expansions has been interrupted.");
                    }
                    addProductAttribute(attributeList, newAttributes, transformations, i, j);
                    for (int k = j + 1; k < numAttributes; k++) {
                        addProductAttribute(attributeList, newAttributes, transformations, i, j, k);
                    }
                }
            }
        }
        return new Pair<>(newAttributes, transformations);
    }

    /** Registers the product attribute of the attributes at the given indices, multiplying factors in lexicographic name order. */
    private static void addProductAttribute(final List<IAttribute> attributeList, final List<IAttribute> newAttributes, final Map<IAttribute, Function<ILabeledInstance, Double>> transformations,
            final int... factorIndices) {
        List<Integer> sortedIndices = new ArrayList<>();
        for (int factorIndex : factorIndices) {
            sortedIndices.add(factorIndex);
        }
        sortedIndices.sort((a, b) -> attributeList.get(a).getName().compareTo(attributeList.get(b).getName()));
        StringBuilder name = new StringBuilder("x");
        for (int index : sortedIndices) {
            name.append('_').append(attributeList.get(index).getName());
        }
        NumericAttribute productAttribute = new NumericAttribute(name.toString());
        if (attributeList.contains(productAttribute)) {
            throw new IllegalStateException("Dataset already has attribute " + productAttribute.getName());
        }
        if (newAttributes.contains(productAttribute)) {
            throw new IllegalStateException("Already added attribute " + productAttribute.getName());
        }
        newAttributes.add(productAttribute);
        transformations.put(productAttribute, instance -> {
            double product = 1.0;
            for (int index : sortedIndices) {
                product *= toDouble(instance.getAttributeValue(index));
            }
            return product;
        });
    }

    /** Interprets an attribute value as a double by parsing its string representation. */
    private static double toDouble(final Object value) {
        return Double.parseDouble(value.toString());
    }

    /**
     * Expands the given dataset by the given expansions.
     *
     * @param dataset the dataset to expand
     * @param expansions any of {@link #EXPANSION_SQUARES}, {@link #EXPANSION_LOGARITHM}, {@link #EXPANSION_PRODUCTS}
     * @return a new dataset with the original attributes followed by the expansion attributes
     * @throws InterruptedException if the thread is interrupted while the expansion is computed
     */
    public static ILabeledDataset<?> getExpansionOfDataset(final ILabeledDataset<?> dataset, final int... expansions) throws InterruptedException {
        return getExpansionOfDataset(dataset, getPairOfNewAttributesAndExpansionMap(dataset, expansions));
    }

    /**
     * Expands the given dataset by precomputed expansion attributes.
     *
     * @param dataset the dataset to expand
     * @param expansion the new attributes and the functions computing them
     * @return a new dataset, named after the original with suffix {@code _expansion}, whose attributes are the
     *         original attributes followed by the new ones
     */
    public static ILabeledDataset<?> getExpansionOfDataset(final ILabeledDataset<?> dataset, final Pair<List<IAttribute>, Map<IAttribute, Function<ILabeledInstance, Double>>> expansion) {
        List<IAttribute> attributes = new ArrayList<>(dataset.getInstanceSchema().getAttributeList());
        attributes.addAll(expansion.getX());
        Dataset expandedDataset = new Dataset(new LabeledInstanceSchema(dataset.getRelationName() + "_expansion", attributes, dataset.getLabelAttribute()));
        for (ILabeledInstance instance : dataset) {
            expandedDataset.add(getExpansionOfInstance(instance, expansion));
        }
        return expandedDataset;
    }

    /**
     * Expands a single instance by precomputed expansion attributes.
     *
     * @param instance the instance to expand
     * @param expansion the new attributes and the functions computing them
     * @return a dense instance with the original attribute values followed by the computed ones, and the original label
     * @throws IllegalArgumentException if no function is registered for one of the new attributes
     */
    public static ILabeledInstance getExpansionOfInstance(final ILabeledInstance instance, final Pair<List<IAttribute>, Map<IAttribute, Function<ILabeledInstance, Double>>> expansion) {
        List<Object> values = new ArrayList<>(Arrays.asList(instance.getAttributes()));
        Map<IAttribute, Function<ILabeledInstance, Double>> functions = expansion.getY();
        for (IAttribute newAttribute : expansion.getX()) {
            Function<ILabeledInstance, Double> function = functions.get(newAttribute);
            if (function == null) {
                throw new IllegalArgumentException("No expansion function given for attribute " + newAttribute.getName());
            }
            values.add(function.apply(instance));
        }
        return new DenseInstance(values, instance.getLabel());
    }
}
