package io.github.jbellis.jvector.quantization;

import io.github.jbellis.jvector.graph.disk.FusedADC;
import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.util.RamUsageEstimator;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorUtil;
import io.github.jbellis.jvector.vector.VectorizationProvider;
import io.github.jbellis.jvector.vector.types.ByteSequence;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
import java.util.Arrays;

/* loaded from: input_file:io/github/jbellis/jvector/quantization/FusedADCPQDecoder.class */
/**
 * Approximate score function that scores every neighbor of a node in one call, using the PQ codes packed
 * alongside the adjacency list ({@link FusedADC.PackedNeighbors}).
 *
 * <p>Scoring works in two phases:
 * <ol>
 *   <li>Until {@link #invocationThreshold} neighbor scores have been produced, scores are assembled in float
 *       from the query-to-codebook partial sums, and the worst distance seen so far is tracked.</li>
 *   <li>After that, the partial sums are quantized with step {@code (worstDistance - bestDistance) / 65535}.
 *       Subsequent calls take the {@code VectorUtil.bulkShuffleQuantizedSimilarity*} fast path.</li>
 * </ol>
 *
 * <p>Instances carry mutable per-query state (invocation count, worst distance, quantization step). They also
 * write into buffers obtained from the {@link ProductQuantization} ({@code reusablePartialSums()} etc.).
 * NOTE(review): presumably one decoder per search thread at a time — confirm before sharing instances.
 */
public abstract class FusedADCPQDecoder implements ScoreFunction.ApproximateScoreFunction {
    private static final VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport();

    protected final ProductQuantization pq;
    protected final VectorFloat<?> query;
    /** Used for {@link #similarityTo(int)}; exact scoring is fully delegated to it. */
    protected final ScoreFunction.ExactScoreFunction esf;
    /** Quantized copy of {@link #partialSums}; only meaningful once {@link #supportsQuantizedSimilarity} is set. */
    protected final ByteSequence<?> partialQuantizedSums;
    protected final FusedADC.PackedNeighbors neighbors;
    /** Caller-supplied output buffer, overwritten and returned by every {@link #edgeLoadingSimilarityTo(int)}. */
    protected final VectorFloat<?> results;
    /** Per-(subspace, cluster) contribution of the query, laid out as {@code subspace * clusterCount + code}. */
    protected final VectorFloat<?> partialSums;
    /** Per-subspace best partial value; their sum is {@link #bestDistance}. */
    protected final VectorFloat<?> partialBestDistances;
    /** Number of neighbor scores to compute exactly before switching to the quantized path. */
    protected final int invocationThreshold;
    protected float bestDistance;
    protected float worstDistance;
    /** Quantization step, computed once the threshold is reached. */
    protected float delta;
    protected final VectorSimilarityFunction vsf;
    protected int invocations = 0;
    protected boolean supportsQuantizedSimilarity = false;

    /**
     * Cosine decoder. Besides the query-to-codebook dot products, it needs the squared magnitude of every
     * codebook entry.
     *
     * <p>Those magnitudes do not depend on the query, so they are computed once per {@link ProductQuantization}.
     * They are cached there via {@code pq.partialSquaredMagnitudes()} and
     * {@code pq.partialQuantizedSquaredMagnitudes()}, together with their quantization parameters.
     */
    public static class CosineDecoder extends FusedADCPQDecoder {
        private final float queryMagnitudeSquared;
        private final VectorFloat<?> partialSquaredMagnitudes;
        private final ByteSequence<?> partialQuantizedSquaredMagnitudes;
        // scratch buffers for the non-quantized path, sized to the results buffer
        private final float[] resultSumAggregates;
        private final float[] resultMagnitudeAggregates;
        private float minSquaredMagnitude;
        private float squaredMagnitudeDelta;

        protected CosineDecoder(FusedADC.PackedNeighbors neighbors, ProductQuantization pq, VectorFloat<?> query, VectorFloat<?> results, ScoreFunction.ExactScoreFunction esf) {
            super(pq, query, neighbors.maxDegree(), neighbors, results, esf, VectorSimilarityFunction.COSINE);
            // start at the "best" end; updateWorstDistance only ever lowers it
            this.worstDistance = Float.MAX_VALUE;

            // The function passed to updateAndGet may be re-applied on contention. The compute path therefore
            // accumulates into locals and assigns the fields once, instead of using '+=' on a field.
            this.partialSquaredMagnitudes = pq.partialSquaredMagnitudes().updateAndGet(current -> {
                if (current != null) {
                    this.squaredMagnitudeDelta = pq.squaredMagnitudeDelta;
                    this.minSquaredMagnitude = pq.minSquaredMagnitude;
                    return current;
                }
                return computePartialSquaredMagnitudes(pq);
            });
            this.partialQuantizedSquaredMagnitudes = pq.partialQuantizedSquaredMagnitudes().get();

            // The base constructor skips COSINE, because the query's squared magnitude is needed from the same
            // per-subspace loop. The partial sums themselves are plain dot products.
            VectorFloat<?> centeredQuery = centerQuery(pq, query);
            float magnitudeSquared = 0.0f;
            for (int m = 0; m < pq.getSubspaceCount(); m++) {
                int size = pq.subvectorSizesAndOffsets[m][0];
                int offset = pq.subvectorSizesAndOffsets[m][1];
                VectorUtil.calculatePartialSums(pq.codebooks[m], m, size, pq.getClusterCount(), centeredQuery, offset, VectorSimilarityFunction.DOT_PRODUCT, this.partialSums, this.partialBestDistances);
                magnitudeSquared += VectorUtil.dotProduct(centeredQuery, offset, centeredQuery, offset, size);
            }
            this.queryMagnitudeSquared = magnitudeSquared;
            this.bestDistance = VectorUtil.sum(this.partialBestDistances);
            this.resultSumAggregates = new float[results.length()];
            this.resultMagnitudeAggregates = new float[results.length()];
        }

        /**
         * Computes the squared magnitude of every codebook entry and quantizes the table.
         *
         * <p>Stores the quantization parameters on this decoder and publishes them, plus the quantized table,
         * on {@code pq} for reuse by later decoders.
         *
         * @return the float table of squared magnitudes, laid out as {@code subspace * clusterCount + code}
         */
        private VectorFloat<?> computePartialSquaredMagnitudes(ProductQuantization pq) {
            int subspaceCount = pq.getSubspaceCount();
            int clusterCount = pq.getClusterCount();
            VectorFloat<?> partialMinMagnitudes = vts.createFloatVector(subspaceCount);
            VectorFloat<?> squaredMagnitudes = vts.createFloatVector(subspaceCount * clusterCount);
            float maxMagnitude = 0.0f;
            float minMagnitude = 0.0f;
            for (int m = 0; m < subspaceCount; m++) {
                int size = pq.subvectorSizesAndOffsets[m][0];
                VectorFloat<?> codebook = pq.codebooks[m];
                float minPartial = Float.POSITIVE_INFINITY;
                float maxPartial = 0.0f;
                for (int j = 0; j < clusterCount; j++) {
                    float partialMagnitude = VectorUtil.dotProduct(codebook, j * size, codebook, j * size, size);
                    minPartial = Math.min(minPartial, partialMagnitude);
                    maxPartial = Math.max(maxPartial, partialMagnitude);
                    squaredMagnitudes.set(m * clusterCount + j, partialMagnitude);
                }
                partialMinMagnitudes.set(m, minPartial);
                maxMagnitude += maxPartial;
                minMagnitude += minPartial;
            }
            this.minSquaredMagnitude = minMagnitude;
            this.squaredMagnitudeDelta = (maxMagnitude - minMagnitude) / 65535.0f;
            // two bytes per (subspace, cluster) entry
            ByteSequence<?> quantized = vts.createByteSequence(subspaceCount * clusterCount * 2);
            VectorUtil.quantizePartials(this.squaredMagnitudeDelta, squaredMagnitudes, partialMinMagnitudes, quantized);

            // publish for reuse by decoders created later against the same pq
            pq.squaredMagnitudeDelta = this.squaredMagnitudeDelta;
            pq.minSquaredMagnitude = this.minSquaredMagnitude;
            pq.partialQuantizedSquaredMagnitudes().set(quantized);
            return squaredMagnitudes;
        }

        /**
         * Scores all packed neighbors of {@code origin}.
         *
         * <p>The score is {@code (1 + cos) / 2}, where {@code cos = dot / sqrt(|node|^2 * |query|^2)}.
         * Both terms are summed from the per-subspace tables.
         *
         * @param origin node whose packed neighbors are scored
         * @return the shared {@link #results} buffer
         */
        @Override
        public VectorFloat<?> edgeLoadingSimilarityTo(int origin) {
            ByteSequence<?> packedNeighbors = this.neighbors.getPackedNeighbors(origin);
            if (this.supportsQuantizedSimilarity) {
                this.results.zero();
                VectorUtil.bulkShuffleQuantizedSimilarityCosine(packedNeighbors, this.pq.compressedVectorSize(), this.partialQuantizedSums, this.delta, this.bestDistance, this.partialQuantizedSquaredMagnitudes, this.squaredMagnitudeDelta, this.minSquaredMagnitude, this.queryMagnitudeSquared, this.results);
                return this.results;
            }

            int nodeCount = this.results.length();
            int subspaceCount = this.pq.getSubspaceCount();
            int clusterCount = this.pq.getClusterCount();
            Arrays.fill(this.resultSumAggregates, 0.0f);
            Arrays.fill(this.resultMagnitudeAggregates, 0.0f);
            // packed codes are laid out subspace-major: code of node j in subspace m is at m * nodeCount + j
            for (int m = 0; m < subspaceCount; m++) {
                for (int j = 0; j < nodeCount; j++) {
                    int index = m * clusterCount + Byte.toUnsignedInt(packedNeighbors.get(m * nodeCount + j));
                    this.resultSumAggregates[j] += this.partialSums.get(index);
                    this.resultMagnitudeAggregates[j] += this.partialSquaredMagnitudes.get(index);
                }
            }
            for (int j = 0; j < nodeCount; j++) {
                // the worst-distance bound is tracked on the raw dot-product sum, matching what quantizePartials encodes
                updateWorstDistance(this.resultSumAggregates[j]);
                float cosine = this.resultSumAggregates[j] / (float) Math.sqrt(this.resultMagnitudeAggregates[j] * this.queryMagnitudeSquared);
                this.invocations++;
                this.results.set(j, distanceToScore(cosine));
            }
            enableQuantizedPathIfReady();
            return this.results;
        }

        @Override
        protected float distanceToScore(float distance) {
            return (1.0f + distance) / 2.0f;
        }

        @Override
        protected void updateWorstDistance(float distance) {
            this.worstDistance = Math.min(this.worstDistance, distance);
        }
    }

    /** Dot-product decoder: a larger summed partial is better, so the worst distance is the minimum seen. */
    public static class DotProductDecoder extends FusedADCPQDecoder {
        public DotProductDecoder(FusedADC.PackedNeighbors neighbors, ProductQuantization pq, VectorFloat<?> query, VectorFloat<?> results, ScoreFunction.ExactScoreFunction esf) {
            super(pq, query, neighbors.maxDegree(), neighbors, results, esf, VectorSimilarityFunction.DOT_PRODUCT);
            // start at the "best" end; updateWorstDistance only ever lowers it
            this.worstDistance = Float.MAX_VALUE;
        }

        @Override
        protected float distanceToScore(float distance) {
            return (distance + 1.0f) / 2.0f;
        }

        @Override
        protected void updateWorstDistance(float distance) {
            this.worstDistance = Math.min(this.worstDistance, distance);
        }
    }

    /** Euclidean decoder: a smaller summed partial is better, so the worst distance is the maximum seen. */
    public static class EuclideanDecoder extends FusedADCPQDecoder {
        public EuclideanDecoder(FusedADC.PackedNeighbors neighbors, ProductQuantization pq, VectorFloat<?> query, VectorFloat<?> results, ScoreFunction.ExactScoreFunction esf) {
            super(pq, query, neighbors.maxDegree(), neighbors, results, esf, VectorSimilarityFunction.EUCLIDEAN);
            // start at the "best" end; updateWorstDistance only ever raises it
            this.worstDistance = 0.0f;
        }

        @Override
        protected float distanceToScore(float distance) {
            return 1.0f / (1.0f + distance);
        }

        @Override
        protected void updateWorstDistance(float distance) {
            this.worstDistance = Math.max(this.worstDistance, distance);
        }
    }

    /**
     * Initializes shared decoder state.
     *
     * <p>For every function except COSINE, this also computes the query's partial sums and {@link #bestDistance}.
     * {@link CosineDecoder} does that itself.
     *
     * @param pq                  product quantization whose codebooks and reusable buffers are used
     * @param query               query vector (centered on {@code pq.globalCentroid} when one is present)
     * @param invocationThreshold number of exact neighbor scores before switching to the quantized path
     * @param neighbors           packed-neighbor source for {@link #edgeLoadingSimilarityTo(int)}
     * @param results             output buffer reused across calls
     * @param esf                 exact score function backing {@link #similarityTo(int)}
     * @param vsf                 similarity function this decoder implements
     */
    protected FusedADCPQDecoder(ProductQuantization pq, VectorFloat<?> query, int invocationThreshold, FusedADC.PackedNeighbors neighbors, VectorFloat<?> results, ScoreFunction.ExactScoreFunction esf, VectorSimilarityFunction vsf) {
        this.pq = pq;
        this.query = query;
        this.esf = esf;
        this.invocationThreshold = invocationThreshold;
        this.neighbors = neighbors;
        this.results = results;
        this.vsf = vsf;
        this.partialSums = pq.reusablePartialSums();
        this.partialBestDistances = pq.reusablePartialBestDistances();
        if (vsf != VectorSimilarityFunction.COSINE) {
            VectorFloat<?> centeredQuery = centerQuery(pq, query);
            for (int m = 0; m < pq.getSubspaceCount(); m++) {
                int size = pq.subvectorSizesAndOffsets[m][0];
                int offset = pq.subvectorSizesAndOffsets[m][1];
                VectorUtil.calculatePartialSums(pq.codebooks[m], m, size, pq.getClusterCount(), centeredQuery, offset, vsf, this.partialSums, this.partialBestDistances);
            }
            this.bestDistance = VectorUtil.sum(this.partialBestDistances);
        }
        this.partialQuantizedSums = pq.reusablePartialQuantizedSums();
    }

    /** Returns {@code query} minus {@code pq.globalCentroid}, or {@code query} itself when there is no centroid. */
    private static VectorFloat<?> centerQuery(ProductQuantization pq, VectorFloat<?> query) {
        VectorFloat<?> centroid = pq.globalCentroid;
        return centroid == null ? query : VectorUtil.sub(query, centroid);
    }

    /**
     * Switches to the quantized fast path once enough exact scores have been observed.
     *
     * <p>It derives {@link #delta} from the best/worst bounds and quantizes {@link #partialSums} into
     * {@link #partialQuantizedSums}.
     */
    protected final void enableQuantizedPathIfReady() {
        if (this.invocations >= this.invocationThreshold) {
            this.delta = (this.worstDistance - this.bestDistance) / 65535.0f;
            VectorUtil.quantizePartials(this.delta, this.partialSums, this.partialBestDistances, this.partialQuantizedSums);
            this.supportsQuantizedSimilarity = true;
        }
    }

    /**
     * Scores all packed neighbors of {@code origin} into the shared {@link #results} buffer.
     *
     * @param origin node whose packed neighbors are scored
     * @return the shared {@link #results} buffer (overwritten by the next call)
     */
    @Override
    public VectorFloat<?> edgeLoadingSimilarityTo(int origin) {
        ByteSequence<?> packedNeighbors = this.neighbors.getPackedNeighbors(origin);
        this.results.zero();
        if (this.supportsQuantizedSimilarity) {
            VectorUtil.bulkShuffleQuantizedSimilarity(packedNeighbors, this.pq.compressedVectorSize(), this.partialQuantizedSums, this.delta, this.bestDistance, this.results, this.vsf);
            return this.results;
        }

        int nodeCount = this.results.length();
        int subspaceCount = this.pq.getSubspaceCount();
        int clusterCount = this.pq.getClusterCount();
        // packed codes are laid out subspace-major: code of node j in subspace m is at m * nodeCount + j
        for (int m = 0; m < subspaceCount; m++) {
            for (int j = 0; j < nodeCount; j++) {
                int code = Byte.toUnsignedInt(packedNeighbors.get(m * nodeCount + j));
                this.results.set(j, this.results.get(j) + this.partialSums.get(m * clusterCount + code));
            }
        }
        for (int j = 0; j < nodeCount; j++) {
            float distance = this.results.get(j);
            this.invocations++;
            updateWorstDistance(distance);
            this.results.set(j, distanceToScore(distance));
        }
        enableQuantizedPathIfReady();
        return this.results;
    }

    @Override
    public boolean supportsEdgeLoadingSimilarity() {
        return true;
    }

    /** Delegates to the exact score function supplied at construction. */
    @Override
    public float similarityTo(int node) {
        return this.esf.similarityTo(node);
    }

    /** Maps a summed distance/similarity to the score returned to callers. */
    protected abstract float distanceToScore(float distance);

    /** Folds an observed summed distance/similarity into {@link #worstDistance}. */
    protected abstract void updateWorstDistance(float distance);

    /**
     * Creates the decoder matching {@code similarityFunction}.
     *
     * @throws IllegalArgumentException if the similarity function has no fused-ADC decoder
     * @throws NullPointerException     if {@code similarityFunction} is null
     */
    public static FusedADCPQDecoder newDecoder(FusedADC.PackedNeighbors neighbors, ProductQuantization pq, VectorFloat<?> query, VectorFloat<?> results, VectorSimilarityFunction similarityFunction, ScoreFunction.ExactScoreFunction esf) {
        // switch on the enum itself: the decompiled version matched ordinals against unrelated constants
        // (RamUsageEstimator.MAX_DEPTH, OnDiskGraphIndex.CURRENT_VERSION) that only coincidentally equal 1 and 3
        switch (similarityFunction) {
            case DOT_PRODUCT:
                return new DotProductDecoder(neighbors, pq, query, results, esf);
            case EUCLIDEAN:
                return new EuclideanDecoder(neighbors, pq, query, results, esf);
            case COSINE:
                return new CosineDecoder(neighbors, pq, query, results, esf);
            default:
                throw new IllegalArgumentException("Unsupported similarity function: " + String.valueOf(similarityFunction));
        }
    }
}
