package io.github.jbellis.jvector.pq;

import io.github.jbellis.jvector.util.MathUtil;
import io.github.jbellis.jvector.vector.Matrix;
import io.github.jbellis.jvector.vector.VectorUtil;
import io.github.jbellis.jvector.vector.VectorizationProvider;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;

/**
 * K-means clusterer using k-means++ seeding, with an optional anisotropic refinement phase.
 *
 * <p>All points must share one dimension {@code d}. The {@code k} centroids are stored
 * back-to-back in a single {@link VectorFloat} of length {@code k * d}; centroid {@code c}
 * occupies the slice {@code [c * d, (c + 1) * d)}.
 *
 * <p>Clustering runs in two phases (see {@link #cluster(int, int)}):
 * <ol>
 *   <li>unweighted iterations that assign each point to its nearest centroid by squared L2
 *       distance and move each centroid to the mean of its points, and</li>
 *   <li>anisotropic iterations, which penalize the component of a point's residual
 *       {@code x - c} that is parallel to {@code x} more heavily than the orthogonal component.
 *       The relative weight is derived from the anisotropic threshold {@code T}.</li>
 * </ol>
 *
 * <p>Instances hold mutable clustering state and perform no synchronization.
 */
public class KMeansPlusPlusClusterer {
    private static final VectorTypeSupport vectorTypeSupport =
            VectorizationProvider.getInstance().getVectorTypeSupport();

    /**
     * Anisotropic threshold meaning "no anisotropic weighting". With this threshold,
     * {@link #cluster(int, int)} runs only the unweighted phase.
     */
    public static final float UNWEIGHTED = -1.0f;

    /** Number of clusters. */
    private final int k;
    /** The points being clustered; all are expected to share one dimension. */
    private final VectorFloat<?>[] points;
    /** {@code assignments[i]} is the cluster that {@code points[i]} currently belongs to. */
    private final int[] assignments;
    /** All k centroids, concatenated; centroid {@code c} starts at offset {@code c * d}. */
    private final VectorFloat<?> centroids;
    /** Anisotropic threshold T, validated to lie in [-1, 1). */
    private final float anisotropicThreshold;
    /** Number of points currently assigned to each cluster. */
    private final int[] centroidDenoms;
    /** Component-wise sum of the points currently assigned to each cluster. */
    private final VectorFloat<?>[] centroidNums;

    /**
     * Creates an unweighted clusterer whose initial centroids are chosen by k-means++ seeding.
     *
     * @param points the points to cluster
     * @param k the number of clusters, in {@code [1, points.length]}
     * @throws IllegalArgumentException if {@code k} is out of range or the input is invalid
     */
    public KMeansPlusPlusClusterer(VectorFloat<?>[] points, int k) {
        this(points, chooseInitialCentroids(points, k), UNWEIGHTED);
    }

    /**
     * Creates a clusterer whose initial centroids are chosen by k-means++ seeding.
     *
     * @param points the points to cluster
     * @param k the number of clusters, in {@code [1, points.length]}
     * @param anisotropicThreshold threshold T in {@code [-1, 1)}, or {@link #UNWEIGHTED}
     * @throws IllegalArgumentException if any argument is out of range
     */
    public KMeansPlusPlusClusterer(VectorFloat<?>[] points, int k, float anisotropicThreshold) {
        this(points, chooseInitialCentroids(points, k), anisotropicThreshold);
    }

    /**
     * Creates a clusterer from explicit initial centroids and assigns every point to its nearest
     * centroid.
     *
     * @param points the points to cluster; must be non-empty
     * @param centroids k centroids concatenated; copied, the caller's vector is not modified
     * @param anisotropicThreshold threshold T in {@code [-1, 1)}, or {@link #UNWEIGHTED}
     * @throws IllegalArgumentException if the threshold is NaN or outside {@code [-1, 1)}, if
     *     {@code points} is empty, or if the centroid length is not a positive multiple of the
     *     point dimension
     */
    public KMeansPlusPlusClusterer(VectorFloat<?>[] points, VectorFloat<?> centroids, float anisotropicThreshold) {
        if (Float.isNaN(anisotropicThreshold) || anisotropicThreshold < -1.0 || anisotropicThreshold >= 1.0) {
            throw new IllegalArgumentException("Valid range for anisotropic threshold T is -1.0 <= t < 1.0");
        }
        if (points.length == 0) {
            throw new IllegalArgumentException("Cannot cluster an empty set of points");
        }
        int dimension = points[0].length();
        if (dimension == 0 || centroids.length() == 0 || centroids.length() % dimension != 0) {
            throw new IllegalArgumentException(String.format(
                    "Centroids length %d is not a positive multiple of the point dimension %d",
                    centroids.length(), dimension));
        }
        this.points = points;
        this.k = centroids.length() / dimension;
        this.centroids = centroids.copy();
        this.anisotropicThreshold = anisotropicThreshold;
        this.centroidDenoms = new int[this.k];
        this.centroidNums = new VectorFloat<?>[this.k];
        for (int c = 0; c < this.k; c++) {
            this.centroidNums[c] = vectorTypeSupport.createFloatVector(dimension);
        }
        this.assignments = new int[points.length];
        initializeAssignedPoints();
    }

    /**
     * Computes the ratio of parallel to orthogonal residual cost,
     * {@code max(1, (dimension - 1) * T^2 / (1 - T^2))}.
     *
     * <p>For {@code |T| == 1} the denominator is zero, so the result is +Infinity (or NaN when
     * {@code dimension == 1}). Callers should not use that value as a real weight.
     *
     * @param threshold the anisotropic threshold T; asserted finite
     * @param dimension the vector dimension
     * @return the parallel-cost multiplier, at least 1 for {@code |T| < 1}
     */
    public static float computeParallelCostMultiplier(double threshold, int dimension) {
        assert Double.isFinite(threshold) : "threshold=" + threshold;
        double thresholdSquared = threshold * threshold;
        return (float) Math.max(1.0d, thresholdSquared / ((1.0d - thresholdSquared) / (dimension - 1)));
    }

    /**
     * Runs up to {@code unweightedIterations} unweighted iterations, then up to
     * {@code anisotropicIterations} anisotropic iterations.
     *
     * <p>Each phase stops early once an iteration reassigns at most 1% of the points. The
     * anisotropic phase is skipped entirely when the threshold is {@link #UNWEIGHTED}, because its
     * cost multiplier is not finite in that case.
     *
     * @return the internal centroid vector (live, not a copy)
     */
    public VectorFloat<?> cluster(int unweightedIterations, int anisotropicIterations) {
        double convergenceThreshold = 0.01d * points.length;
        for (int iteration = 0; iteration < unweightedIterations; iteration++) {
            if (clusterOnceUnweighted() <= convergenceThreshold) {
                break;
            }
        }
        if (anisotropicThreshold != UNWEIGHTED) {
            for (int iteration = 0; iteration < anisotropicIterations; iteration++) {
                if (clusterOnceAnisotropic() <= convergenceThreshold) {
                    break;
                }
            }
        }
        return centroids;
    }

    /**
     * Performs one unweighted iteration: recompute centroids as the means of their assigned
     * points, then reassign every point to its nearest centroid.
     *
     * @return the number of points whose assignment changed
     */
    public int clusterOnceUnweighted() {
        updateCentroidsUnweighted();
        return updateAssignedPointsUnweighted();
    }

    /**
     * Performs one anisotropic iteration: recompute centroids under the anisotropic loss, then
     * reassign every point by weighted distance.
     *
     * <p>Does not check for {@link #UNWEIGHTED}; calling this with that threshold produces a
     * non-finite cost multiplier.
     *
     * @return the number of points whose assignment changed
     */
    public int clusterOnceAnisotropic() {
        updateCentroidsAnisotropic();
        return updateAssignedPointsAnisotropic();
    }

    /**
     * Chooses {@code k} initial centroids by k-means++ seeding.
     *
     * <p>The first centroid is a uniformly random point. Each later centroid is sampled with
     * probability proportional to its squared distance from the nearest centroid chosen so far.
     *
     * @return a new vector holding the k chosen points concatenated
     * @throws IllegalArgumentException if {@code k <= 0} or {@code k > points.length}
     */
    private static VectorFloat<?> chooseInitialCentroids(VectorFloat<?>[] points, int k) {
        if (k <= 0) {
            throw new IllegalArgumentException("Number of clusters must be positive.");
        }
        if (k > points.length) {
            throw new IllegalArgumentException(String.format("Number of clusters %d cannot exceed number of points %d", k, points.length));
        }
        ThreadLocalRandom random = ThreadLocalRandom.current();
        VectorFloat<?> chosenCentroids = vectorTypeSupport.createFloatVector(k * points[0].length());

        // nearestDistances[j]: squared L2 distance from points[j] to its closest centroid chosen so far.
        float[] nearestDistances = new float[points.length];
        Arrays.fill(nearestDistances, Float.MAX_VALUE);

        VectorFloat<?> first = points[random.nextInt(points.length)];
        chosenCentroids.copyFrom(first, 0, 0, first.length());
        updateNearestDistances(points, first, nearestDistances);

        for (int c = 1; c < k; c++) {
            VectorFloat<?> chosen = points[sampleProportionalTo(nearestDistances, random)];
            chosenCentroids.copyFrom(chosen, 0, c * chosen.length(), chosen.length());
            updateNearestDistances(points, chosen, nearestDistances);
        }
        assertFinite(chosenCentroids);
        return chosenCentroids;
    }

    /**
     * Picks an index with probability proportional to its weight.
     *
     * <p>Falls back to a uniformly random index if the scan ends without selecting one.
     */
    private static int sampleProportionalTo(float[] weights, ThreadLocalRandom random) {
        float total = 0.0f;
        for (float weight : weights) {
            total += weight;
        }
        float remaining = random.nextFloat() * total;
        for (int j = 0; j < weights.length; j++) {
            remaining -= weights[j];
            // NOTE(review): absolute 1e-6 tolerance instead of <= 0, presumably to absorb float
            // rounding in the running subtraction.
            if (remaining < 1.0E-6d) {
                return j;
            }
        }
        return random.nextInt(weights.length);
    }

    /** Lowers each {@code distances[j]} to the squared L2 distance from {@code points[j]} to {@code centroid} if closer. */
    private static void updateNearestDistances(VectorFloat<?>[] points, VectorFloat<?> centroid, float[] distances) {
        for (int j = 0; j < points.length; j++) {
            distances[j] = Math.min(distances[j], VectorUtil.squareL2Distance(points[j], centroid));
        }
    }

    /** Assigns every point to its nearest centroid and builds the per-cluster counts and sums. */
    private void initializeAssignedPoints() {
        for (int i = 0; i < points.length; i++) {
            VectorFloat<?> point = points[i];
            int nearest = getNearestCluster(point);
            centroidDenoms[nearest]++;
            VectorUtil.addInPlace(centroidNums[nearest], point);
            assignments[i] = nearest;
        }
    }

    /**
     * Reassigns every point to its nearest centroid by squared L2 distance.
     *
     * @return the number of points whose assignment changed
     */
    private int updateAssignedPointsUnweighted() {
        int changed = 0;
        for (int i = 0; i < points.length; i++) {
            int nearest = getNearestCluster(points[i]);
            if (nearest != assignments[i]) {
                reassign(i, nearest);
                changed++;
            }
        }
        return changed;
    }

    /**
     * Moves {@code points[pointIndex]} into {@code newCluster}, keeping the per-cluster counts and
     * sums consistent with {@link #assignments}.
     */
    private void reassign(int pointIndex, int newCluster) {
        VectorFloat<?> point = points[pointIndex];
        int oldCluster = assignments[pointIndex];
        centroidDenoms[oldCluster]--;
        VectorUtil.subInPlace(centroidNums[oldCluster], point);
        centroidDenoms[newCluster]++;
        VectorUtil.addInPlace(centroidNums[newCluster], point);
        assignments[pointIndex] = newCluster;
    }

    /**
     * Reassigns every point to the centroid with the smallest anisotropic weighted distance.
     *
     * <p>Per-cluster counts and sums are updated as well, so a later unweighted iteration starts
     * from consistent state.
     *
     * @return the number of points whose assignment changed
     */
    private int updateAssignedPointsAnisotropic() {
        int dimension = points[0].length();
        float parallelCostMultiplier = computeParallelCostMultiplier(anisotropicThreshold, dimension);
        float[] centroidNormsSquared = new float[k];
        for (int c = 0; c < k; c++) {
            int offset = c * dimension;
            centroidNormsSquared[c] = VectorUtil.dotProduct(centroids, offset, centroids, offset, dimension);
        }

        int changed = 0;
        for (int i = 0; i < points.length; i++) {
            VectorFloat<?> point = points[i];
            float pointNormSquared = VectorUtil.dotProduct(point, point);
            int best = assignments[i];
            float bestDistance = Float.MAX_VALUE;
            for (int c = 0; c < k; c++) {
                float distance = weightedDistance(point, c, parallelCostMultiplier, centroidNormsSquared[c], pointNormSquared);
                if (distance < bestDistance) {
                    bestDistance = distance;
                    best = c;
                }
            }
            if (best != assignments[i]) {
                reassign(i, best);
                changed++;
            }
        }
        return changed;
    }

    /**
     * Anisotropic distance between {@code x} and centroid {@code centroid}.
     *
     * <p>Computed as {@code multiplier * ||r_par||^2 + ||r_perp||^2}, where {@code r = x - c} and
     * {@code r_par} is the projection of {@code r} onto {@code x}.
     *
     * @param cNormSquared {@code ||c||^2} for this centroid
     * @param xNormSquared {@code ||x||^2}
     */
    private float weightedDistance(VectorFloat<?> x, int centroid, float parallelCostMultiplier, float cNormSquared, float xNormSquared) {
        float cDotX = VectorUtil.dotProduct(centroids, centroid * x.length(), x, 0, x.length());
        // <c - x, x> = -<r, x>
        float residualDotX = cDotX - xNormSquared;
        // ||r||^2 = ||c||^2 - 2<c, x> + ||x||^2
        float residualSquaredNorm = cNormSquared - 2.0f * cDotX + xNormSquared;
        // ||r_par||^2 = <r, x>^2 / ||x||^2, matching the 1/||x||^2-normalized outer products used in
        // updateCentroidsAnisotropic. A zero vector has no direction, so its parallel error is 0.
        float parallelError = xNormSquared > 0.0f ? MathUtil.square(residualDotX) / xNormSquared : 0.0f;
        float perpendicularError = residualSquaredNorm - parallelError;
        return parallelCostMultiplier * parallelError + perpendicularError;
    }

    /** Returns the index of the centroid with the smallest squared L2 distance to {@code point} (lowest index on ties). */
    private int getNearestCluster(VectorFloat<?> point) {
        int dimension = point.length();
        float bestDistance = Float.MAX_VALUE;
        int best = 0;
        for (int c = 0; c < k; c++) {
            float distance = VectorUtil.squareL2Distance(point, 0, centroids, c * dimension, dimension);
            if (distance < bestDistance) {
                bestDistance = distance;
                best = c;
            }
        }
        return best;
    }

    /** When assertions are enabled, asserts that every component of {@code vector} is finite. */
    private static void assertFinite(VectorFloat<?> vector) {
        boolean assertionsEnabled = false;
        // Intentional side effect: the assignment only executes when assertions are enabled,
        // so the O(n) scan below is skipped entirely otherwise.
        assert assertionsEnabled = true;
        if (!assertionsEnabled) {
            return;
        }
        for (int i = 0; i < vector.length(); i++) {
            assert Float.isFinite(vector.get(i)) : "vector " + vector + " contains non-finite value";
        }
    }

    /**
     * Sets each centroid to the mean of its assigned points, using the running sums and counts.
     * A cluster with no points is re-seeded at a uniformly random point.
     */
    private void updateCentroidsUnweighted() {
        for (int c = 0; c < k; c++) {
            if (centroidDenoms[c] == 0) {
                initializeCentroidToRandomPoint(c);
                continue;
            }
            VectorFloat<?> mean = centroidNums[c].copy();
            VectorUtil.scale(mean, 1.0f / centroidDenoms[c]);
            centroids.copyFrom(mean, 0, c * mean.length(), mean.length());
        }
    }

    /** Overwrites centroid {@code c} with a uniformly random point. */
    private void initializeCentroidToRandomPoint(int c) {
        int dimension = points[0].length();
        VectorFloat<?> randomPoint = points[ThreadLocalRandom.current().nextInt(points.length)];
        centroids.copyFrom(randomPoint, 0, c * dimension, dimension);
    }

    /**
     * Recomputes each centroid under the anisotropic loss.
     *
     * <p>With {@code m = 1 / parallelCostMultiplier}, {@code n} assigned points, and their mean
     * {@code mu}, the centroid is {@code M^-1 * mu}, where
     * {@code M = (1 - m) / n * sum(x x^T / ||x||^2) + m I}. Zero-norm points are left out of the
     * outer-product sum. A cluster with no points is re-seeded at a uniformly random point.
     */
    private void updateCentroidsAnisotropic() {
        int dimension = points[0].length();
        float inverseMultiplier = 1.0f / computeParallelCostMultiplier(anisotropicThreshold, dimension);

        List<List<Integer>> members = new ArrayList<>(k);
        for (int c = 0; c < k; c++) {
            members.add(new ArrayList<>());
        }
        for (int i = 0; i < assignments.length; i++) {
            members.get(assignments[i]).add(i);
        }

        for (int c = 0; c < k; c++) {
            List<Integer> cluster = members.get(c);
            if (cluster.isEmpty()) {
                initializeCentroidToRandomPoint(c);
                continue;
            }
            VectorFloat<?> mean = vectorTypeSupport.createFloatVector(dimension);
            Matrix system = new Matrix(dimension, dimension);
            for (int index : cluster) {
                VectorFloat<?> point = points[index];
                VectorUtil.addInPlace(mean, point);
                float normSquared = VectorUtil.dotProduct(point, point);
                if (normSquared > 0.0f) {
                    Matrix projection = Matrix.outerProduct(point, point);
                    projection.scale(1.0f / normSquared);
                    system.addInPlace(projection);
                }
            }
            system.scale((1.0f - inverseMultiplier) / cluster.size());
            VectorUtil.scale(mean, 1.0f / cluster.size());
            for (int d = 0; d < dimension; d++) {
                system.addTo(d, d, inverseMultiplier);
            }
            centroids.copyFrom(system.invert().multiply(mean), 0, c * dimension, dimension);
        }
    }

    /**
     * Returns the component-wise mean of {@code list}.
     *
     * @throws IllegalArgumentException if {@code list} is empty
     */
    public static VectorFloat<?> centroidOf(List<VectorFloat<?>> list) {
        if (list.isEmpty()) {
            throw new IllegalArgumentException("Can't compute centroid of empty points list");
        }
        VectorFloat<?> sum = VectorUtil.sum(list);
        VectorUtil.scale(sum, 1.0f / list.size());
        return sum;
    }

    /** Returns the internal centroid vector (live, not a copy): k centroids concatenated. */
    public VectorFloat<?> getCentroids() {
        return centroids;
    }
}
