package io.github.jbellis.jvector.quantization;

import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex;
import io.github.jbellis.jvector.quantization.NVQuantization;
import io.github.jbellis.jvector.util.RamUsageEstimator;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorUtil;
import io.github.jbellis.jvector.vector.types.VectorFloat;

/* loaded from: input_file:io/github/jbellis/jvector/quantization/NVQScorer.class */
/**
 * Builds score functions that compare a full-precision query vector against vectors quantized
 * with {@link NVQuantization}.
 *
 * <p>Each {@code *ScoreFunctionFor} method prepares the query once and returns a lambda that
 * scores individual {@link NVQuantization.QuantizedVector}s against it. The preparation splits
 * the query into the quantizer's sub-vectors and applies the 8-bit in-place shuffle, presumably
 * to match the layout expected by the 8-bit kernels.
 */
public class NVQScorer {
    final NVQuantization nvq;

    /** Scores one quantized vector against the query captured when the function was created. */
    @FunctionalInterface
    public interface NVQScoreFunction {
        /**
         * @param quantizedVector the quantized database vector to compare with the query
         * @return the similarity score of {@code quantizedVector} with respect to the query
         */
        float similarityTo(NVQuantization.QuantizedVector quantizedVector);
    }

    /**
     * @param nVQuantization the quantization whose global mean, sub-vector split and bit width
     *     are used by every score function this scorer creates
     */
    public NVQScorer(NVQuantization nVQuantization) {
        this.nvq = nVQuantization;
    }

    /**
     * Returns a score function for {@code vectorFloat} under the given similarity function.
     *
     * @param vectorFloat the full-precision query vector
     * @param vectorSimilarityFunction DOT_PRODUCT, EUCLIDEAN or COSINE
     * @return a function scoring quantized vectors against the query
     * @throws IllegalArgumentException if the similarity function is not one of the above, or if
     *     the quantizer's bits per dimension is not supported
     */
    public NVQScoreFunction scoreFunctionFor(VectorFloat<?> vectorFloat, VectorSimilarityFunction vectorSimilarityFunction) {
        // Switch on the enum constants directly so dispatch depends neither on ordinal values
        // nor on unrelated numeric constants used as case labels.
        switch (vectorSimilarityFunction) {
            case DOT_PRODUCT:
                return dotProductScoreFunctionFor(vectorFloat);
            case EUCLIDEAN:
                return euclideanScoreFunctionFor(vectorFloat);
            case COSINE:
                return cosineScoreFunctionFor(vectorFloat);
            default:
                throw new IllegalArgumentException("Unsupported similarity function " + vectorSimilarityFunction);
        }
    }

    /**
     * Dot-product scoring. The score is
     * {@code ((1 + sum of per-sub-vector nvqDotProduct8bit) + query·globalMean) / 2}.
     */
    private NVQScoreFunction dotProductScoreFunctionFor(VectorFloat<?> vectorFloat) {
        // The global-mean term is computed once per query. Presumably the quantized vectors are
        // stored relative to globalMean, so q·v = q·(v - mean) + q·mean. TODO confirm.
        float queryDotMean = VectorUtil.dotProduct(vectorFloat, this.nvq.globalMean);
        VectorFloat<?>[] querySubVectors = this.nvq.getSubVectors(vectorFloat);
        switch (this.nvq.bitsPerDimension) {
            case EIGHT:
                shuffleQueryInPlace8bit(querySubVectors);
                return quantizedVector -> {
                    float partialDot = 0.0f;
                    for (int i = 0; i < querySubVectors.length; i++) {
                        NVQuantization.QuantizedSubVector sub = quantizedVector.subVectors[i];
                        partialDot += VectorUtil.nvqDotProduct8bit(querySubVectors[i], sub.bytes, sub.growthRate, sub.midpoint, sub.minValue, sub.maxValue);
                    }
                    // Shift and halve. Presumably this maps a unit-vector dot product into [0, 1].
                    return ((1.0f + partialDot) + queryDotMean) / 2.0f;
                };
            default:
                throw unsupportedBitsPerDimension();
        }
    }

    /**
     * Euclidean scoring. The score is {@code 1 / (1 + d)}, where {@code d} is the sum of the
     * per-sub-vector {@code nvqSquareL2Distance8bit} results (presumably the squared L2
     * distance).
     */
    private NVQScoreFunction euclideanScoreFunctionFor(VectorFloat<?> vectorFloat) {
        // The query is centered on globalMean before splitting. Presumably this matches how the
        // database vectors were quantized. TODO confirm.
        VectorFloat<?>[] querySubVectors = this.nvq.getSubVectors(VectorUtil.sub(vectorFloat, this.nvq.globalMean));
        switch (this.nvq.bitsPerDimension) {
            case EIGHT:
                shuffleQueryInPlace8bit(querySubVectors);
                return quantizedVector -> {
                    float distance = 0.0f;
                    for (int i = 0; i < querySubVectors.length; i++) {
                        NVQuantization.QuantizedSubVector sub = quantizedVector.subVectors[i];
                        distance += VectorUtil.nvqSquareL2Distance8bit(querySubVectors[i], sub.bytes, sub.growthRate, sub.midpoint, sub.minValue, sub.maxValue);
                    }
                    return 1.0f / (1.0f + distance);
                };
            default:
                throw unsupportedBitsPerDimension();
        }
    }

    /**
     * Cosine scoring. The score is {@code (1 + cos) / 2}.
     *
     * <p>Here {@code cos} is the summed first component of {@code nvqCosine8bit}, divided by the
     * query norm and by the square root of the summed second component. Presumably the two
     * components are the partial dot product and the partial squared norm of the reconstructed
     * database vector.
     *
     * <p>NOTE(review): a zero-norm query makes {@code queryNorm} 0, so the score is not finite
     * (NaN or infinite). Confirm that callers never pass zero vectors for COSINE.
     */
    private NVQScoreFunction cosineScoreFunctionFor(VectorFloat<?> vectorFloat) {
        float queryNorm = (float) Math.sqrt(VectorUtil.dotProduct(vectorFloat, vectorFloat));
        VectorFloat<?>[] querySubVectors = this.nvq.getSubVectors(vectorFloat);
        VectorFloat<?>[] meanSubVectors = this.nvq.getSubVectors(this.nvq.globalMean);
        switch (this.nvq.bitsPerDimension) {
            case EIGHT:
                for (int i = 0; i < querySubVectors.length; i++) {
                    VectorUtil.nvqShuffleQueryInPlace8bit(querySubVectors[i]);
                    VectorUtil.nvqShuffleQueryInPlace8bit(meanSubVectors[i]);
                }
                return quantizedVector -> {
                    float dot = 0.0f;
                    float squaredNorm = 0.0f;
                    for (int i = 0; i < querySubVectors.length; i++) {
                        NVQuantization.QuantizedSubVector sub = quantizedVector.subVectors[i];
                        float[] partial = VectorUtil.nvqCosine8bit(querySubVectors[i], sub.bytes, sub.growthRate, sub.midpoint, sub.minValue, sub.maxValue, meanSubVectors[i]);
                        dot += partial[0];
                        squaredNorm += partial[1];
                    }
                    float cosine = (dot / queryNorm) / ((float) Math.sqrt(squaredNorm));
                    return (1.0f + cosine) / 2.0f;
                };
            default:
                throw unsupportedBitsPerDimension();
        }
    }

    /** Applies {@code nvqShuffleQueryInPlace8bit} to every sub-vector, in array order. */
    private static void shuffleQueryInPlace8bit(VectorFloat<?>[] subVectors) {
        for (VectorFloat<?> subVector : subVectors) {
            VectorUtil.nvqShuffleQueryInPlace8bit(subVector);
        }
    }

    /** Builds the exception thrown when the quantizer's bit width has no scoring kernel. */
    private IllegalArgumentException unsupportedBitsPerDimension() {
        return new IllegalArgumentException("Unsupported bits per dimension " + this.nvq.bitsPerDimension);
    }
}
