package ai.libs.jaicore.ml.clustering.learner;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.ml.clustering.CentroidCluster;
import org.apache.commons.math3.ml.clustering.Clusterable;
import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer;
import org.apache.commons.math3.ml.distance.DistanceMeasure;
import org.apache.commons.math3.ml.distance.ManhattanDistance;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.RandomGenerator;

/* loaded from: input_file:ai/libs/jaicore/ml/clustering/learner/GMeans.class */
/**
 * G-means style clustering.
 *
 * <p>The algorithm works in three phases:
 * <ol>
 * <li>Start from a single k-means++ cluster.</li>
 * <li>Tentatively split each cluster in two with k-means++. Keep the split only when an
 * Anderson-Darling test rejects normality of the cluster's points projected onto the difference
 * of the two child centers. Newly created children are tested again.</li>
 * <li>Once no more splits happen, dissolve clusters with at most two points into their nearest
 * remaining cluster.</li>
 * </ol>
 *
 * <p>Instances are stateful: {@link #cluster()} stores its result, which is afterwards available
 * through {@link #getGmeansCluster()}. Cluster centers ({@code double[]}) are used as keys of hash
 * maps. Arrays do not override {@code equals}/{@code hashCode}, so these maps are keyed by array
 * identity.
 *
 * @param <C> the type of points being clustered
 */
public class GMeans<C extends Clusterable> {

    /** Initial center(s) produced by the 1-means step of the most recent {@link #cluster()} call. */
    private final List<double[]> center = new ArrayList<>();

    /** Final clusters of the most recent {@link #cluster()} call. */
    private List<CentroidCluster<C>> gmeansCluster = new ArrayList<>();

    /** Current center -> member points (identity-keyed). */
    private final Map<double[], List<C>> currentPoints = new HashMap<>();

    /** Center -> member points of every tentative 2-means child computed so far (identity-keyed). */
    private final Map<double[], List<C>> intermediatePoints = new HashMap<>();

    /** Copy of the points passed at construction time. */
    private final List<C> points;

    private final DistanceMeasure distanceMeasure;

    private final RandomGenerator randomGenerator;

    /**
     * Creates a clusterer that uses the Manhattan distance and random seed {@code 1}.
     *
     * @param collection the points to cluster (copied)
     */
    public GMeans(Collection<C> collection) {
        this(collection, new ManhattanDistance(), 1L);
    }

    /**
     * Creates a clusterer.
     *
     * @param collection      the points to cluster (copied)
     * @param distanceMeasure distance used by k-means++ and when merging small clusters
     * @param seed            seed for the random generator handed to k-means++
     * @throws NullPointerException if {@code collection} or {@code distanceMeasure} is null
     */
    public GMeans(Collection<C> collection, DistanceMeasure distanceMeasure, long seed) {
        this.points = new ArrayList<>(Objects.requireNonNull(collection, "collection"));
        this.distanceMeasure = Objects.requireNonNull(distanceMeasure, "distanceMeasure");
        this.randomGenerator = new JDKRandomGenerator();
        this.randomGenerator.setSeed(seed);
    }

    /**
     * Runs G-means on the points given at construction time.
     *
     * <p>Any state from a previous call is discarded first, so repeated calls do not accumulate
     * clusters.
     *
     * @return the final clusters; the same list is returned by {@link #getGmeansCluster()}
     * @throws IllegalStateException if the two child centers of a split coincide or their
     *                               difference consists only of NaN entries, so no projection
     *                               direction exists
     */
    public List<CentroidCluster<C>> cluster() {
        this.center.clear();
        this.currentPoints.clear();
        this.intermediatePoints.clear();
        this.gmeansCluster = new ArrayList<>();

        KMeansPlusPlusClusterer<C> oneMeans =
                new KMeansPlusPlusClusterer<>(1, -1, this.distanceMeasure, this.randomGenerator);
        for (CentroidCluster<C> cluster : oneMeans.cluster(this.points)) {
            this.currentPoints.put(cluster.getCenter().getPoint(), cluster.getPoints());
        }
        this.center.addAll(this.currentPoints.keySet());

        // Work list of centers to test. Entries before 'next' were accepted as Gaussian.
        List<double[]> pending = new ArrayList<>(this.center);
        int next = 0;
        while (next < pending.size()) {
            double[] parent = pending.get(next);
            List<C> members = this.currentPoints.get(parent);
            if (members.size() < 2) {
                // Cannot be split: accept it and keep testing the remaining clusters.
                next++;
                continue;
            }
            double[][] children = splitInTwo(members);
            double[] direction = difference(children[0], children[1]);
            if (andersonDarlingTest(projectOnto(members, direction))) {
                next++;
            } else {
                this.currentPoints.remove(parent);
                this.currentPoints.put(children[0], this.intermediatePoints.get(children[0]));
                this.currentPoints.put(children[1], this.intermediatePoints.get(children[1]));
                pending.set(next, children[0]); // first child is re-tested in its parent's slot
                pending.add(children[1]);       // second child is tested later
            }
        }

        mergeCluster(this.currentPoints);
        for (Map.Entry<double[], List<C>> entry : this.currentPoints.entrySet()) {
            double[] centroid = entry.getKey();
            CentroidCluster<C> cluster = new CentroidCluster<>(() -> centroid);
            for (C point : entry.getValue()) {
                cluster.addPoint(point);
            }
            this.gmeansCluster.add(cluster);
        }
        return this.gmeansCluster;
    }

    /**
     * Splits {@code members} with 2-means++, records both children in {@code intermediatePoints}
     * and returns their two centers.
     */
    private double[][] splitInTwo(List<C> members) {
        KMeansPlusPlusClusterer<C> twoMeans =
                new KMeansPlusPlusClusterer<>(2, -1, this.distanceMeasure, this.randomGenerator);
        List<double[]> centers = new ArrayList<>(2);
        for (CentroidCluster<C> child : twoMeans.cluster(members)) {
            double[] childCenter = child.getCenter().getPoint();
            this.intermediatePoints.put(childCenter, child.getPoints());
            centers.add(childCenter);
        }
        return new double[][] {centers.get(0), centers.get(1)};
    }

    /**
     * Projects every member onto {@code direction}.
     *
     * <p>Computes {@code <x, v> / ||v||^2} for each member {@code x}. Coordinates where either the
     * point or the direction is NaN are skipped.
     *
     * @throws IllegalStateException if the squared norm of the non-NaN direction entries is zero
     */
    private double[] projectOnto(List<C> members, double[] direction) {
        double squaredNorm = 0.0;
        for (double component : direction) {
            if (!Double.isNaN(component)) {
                squaredNorm += Math.pow(component, 2.0);
            }
        }
        if (squaredNorm == 0.0) {
            throw new IllegalStateException("All entries in v are NaN, cannot compute w!");
        }
        double[] projection = new double[members.size()];
        for (int i = 0; i < members.size(); i++) {
            double[] coordinates = members.get(i).getPoint();
            for (int j = 0; j < coordinates.length; j++) {
                if (!Double.isNaN(coordinates[j]) && !Double.isNaN(direction[j])) {
                    projection[i] += (direction[j] * coordinates[j]) / squaredNorm;
                }
            }
        }
        return projection;
    }

    /**
     * Dissolves every cluster of {@code map} that has at most two points.
     *
     * <p>Each point of a dissolved cluster is moved to the nearest remaining center under the
     * configured distance measure; ties go to the center visited last. If every cluster is that
     * small, the last remaining one is kept so that no points are lost.
     *
     * @param map center -> member points; modified in place
     */
    protected void mergeCluster(Map<double[], List<C>> map) {
        List<double[]> smallCenters = new ArrayList<>();
        for (Map.Entry<double[], List<C>> entry : map.entrySet()) {
            if (entry.getValue().size() <= 2) {
                smallCenters.add(entry.getKey());
            }
        }
        for (double[] small : smallCenters) {
            if (map.size() <= 1) {
                // Only this cluster is left; removing it would leave nowhere to put its points.
                break;
            }
            for (C point : map.remove(small)) {
                double bestDistance = Double.MAX_VALUE;
                double[] nearest = null;
                for (double[] candidate : map.keySet()) {
                    double distance = this.distanceMeasure.compute(point.getPoint(), candidate);
                    if (distance <= bestDistance) {
                        nearest = candidate;
                        bestDistance = distance;
                    }
                }
                map.get(nearest).add(point);
            }
        }
    }

    /**
     * Anderson-Darling normality test on the non-NaN entries of {@code values}.
     *
     * <p>The entries are sorted and standardized with their sample mean and (n-1)-variance. The
     * statistic is
     * {@code A^2 = -n - (1/n) * sum_{i=1..n} (2i-1) * [ln Phi(z_i) + ln(1 - Phi(z_{n+1-i}))]}.
     * Samples with fewer than two non-NaN values, or with zero variance, are treated as normal.
     * {@code values} itself is not modified.
     *
     * @param values one-dimensional sample
     * @return {@code true} if normality is not rejected (the cluster is kept), {@code false}
     *         otherwise
     */
    protected boolean andersonDarlingTest(double[] values) {
        double[] sample = Arrays.stream(values).filter(value -> !Double.isNaN(value)).sorted().toArray();
        int n = sample.length;
        if (n < 2) {
            return true;
        }
        double sum = 0.0;
        for (double value : sample) {
            sum += value;
        }
        double mean = sum / n;
        double squaredDeviations = 0.0;
        for (double value : sample) {
            squaredDeviations += Math.pow(value - mean, 2.0);
        }
        double variance = squaredDeviations / (n - 1);
        if (variance == 0.0) {
            // A constant sample gives no evidence against normality.
            return true;
        }
        double[] z = standardize(sample, mean, variance);
        NormalDistribution standardNormal = new NormalDistribution((RandomGenerator) null, 0.0, 1.0);
        double weightedLogSum = 0.0;
        for (int i = 1; i <= n; i++) {
            double lowerTail = standardNormal.cumulativeProbability(z[i - 1]);
            double upperTail = 1.0 - standardNormal.cumulativeProbability(z[n - i]);
            weightedLogSum += (2 * i - 1) * (Math.log(lowerTail) + Math.log(upperTail));
        }
        double aSquared = -n - weightedLogSum / n;
        return aSquared <= criticalValue(n);
    }

    /**
     * Returns the acceptance threshold for the A^2 statistic at sample size {@code n}.
     *
     * <p>NOTE(review): these look like the 5% critical values for the case where mean and variance
     * are estimated from the sample; confirm against the source the values were taken from.
     */
    private static double criticalValue(int n) {
        if (n <= 10) {
            return 0.683;
        }
        if (n <= 20) {
            return 0.704;
        }
        if (n <= 50) {
            return 0.735;
        }
        if (n <= 100) {
            return 0.754;
        }
        return 0.787;
    }

    /** Returns {@code (x - mean) / sqrt(variance)} for every element of {@code values}. */
    private double[] standardize(double[] values, double mean, double variance) {
        double standardDeviation = Math.sqrt(variance);
        double[] standardized = new double[values.length];
        for (int i = 0; i < values.length; i++) {
            standardized[i] = (values[i] - mean) / standardDeviation;
        }
        return standardized;
    }

    /**
     * Component-wise difference {@code a - b}.
     *
     * <p>A component is NaN if either input component is NaN. The result has the length of
     * {@code a}; {@code b} must be at least as long.
     *
     * @param dArr  minuend
     * @param dArr2 subtrahend
     * @return a new array holding the differences
     */
    protected double[] difference(double[] dArr, double[] dArr2) {
        double[] result = new double[dArr.length];
        for (int i = 0; i < dArr.length; i++) {
            result[i] = Double.isNaN(dArr[i]) || Double.isNaN(dArr2[i]) ? Double.NaN : dArr[i] - dArr2[i];
        }
        return result;
    }

    /**
     * Returns the clusters produced by the most recent {@link #cluster()} call.
     *
     * @return the live result list; empty before {@link #cluster()} has run
     */
    public List<CentroidCluster<C>> getGmeansCluster() {
        return this.gmeansCluster;
    }

    /**
     * Returns the live list of initial center(s) from the 1-means step of the last
     * {@link #cluster()} call.
     */
    protected List<double[]> getCentersModifiable() {
        return this.center;
    }

    /**
     * Returns the internal copy of the points being clustered.
     *
     * @return the internal (mutable) list
     */
    public List<C> getPoints() {
        return this.points;
    }
}
