package io.github.jbellis.jvector.pq;

import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorUtil;
import io.github.jbellis.jvector.vector.VectorizationProvider;
import io.github.jbellis.jvector.vector.types.ByteSequence;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;

/**
 * Base class for approximate score functions that compare a full-precision query vector against
 * product-quantized (PQ) vectors held in a {@link PQVectors} instance.
 *
 * <p>Throughout this class, {@code pq.subvectorSizesAndOffsets[m][0]} is used as the length of
 * subspace {@code m} and {@code [m][1]} as its offset within a full vector. Per-centroid lookup
 * tables are indexed by {@code m * clusterCount + centroid}.
 *
 * <p>NOTE(review): the decompiler reported that this class was originally package-private; it is
 * kept {@code public} here so that existing references keep compiling.
 */
public abstract class PQDecoder implements ScoreFunction.ApproximateScoreFunction {
    private static final VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport();

    /** The compressed vectors (and their {@link ProductQuantization}) this decoder scores against. */
    protected final PQVectors cv;

    protected PQDecoder(PQVectors pqVectors) {
        this.cv = pqVectors;
    }

    /**
     * Decoder that builds, once per query, a table of partial similarities between the query and
     * every centroid of every subspace codebook, so that scoring an encoded vector is delegated to
     * {@link VectorUtil#assembleAndSum}.
     */
    protected static abstract class CachingDecoder extends PQDecoder {
        /**
         * Per-(subspace, centroid) partial similarities, filled in the constructor. The buffer comes
         * from {@link PQVectors#reusablePartialSums()} — presumably shared/reused storage, so it
         * should not be retained past this decoder's use (TODO confirm).
         */
        protected final VectorFloat<?> partialSums;

        /**
         * @param pqVectors the compressed vectors to score against
         * @param query the full-precision query vector
         * @param similarityFunction the similarity used by {@link VectorUtil#calculatePartialSums}
         *     to fill {@link #partialSums}
         */
        protected CachingDecoder(PQVectors pqVectors, VectorFloat<?> query, VectorSimilarityFunction similarityFunction) {
            super(pqVectors);
            ProductQuantization pq = this.cv.pq;
            this.partialSums = pqVectors.reusablePartialSums();

            // The query is centered on the PQ global centroid (when one exists) before being compared
            // with the codebooks — presumably because the codebooks were trained on centered data.
            VectorFloat<?> center = pq.globalCentroid;
            VectorFloat<?> centeredQuery = center == null ? query : VectorUtil.sub(query, center);

            int clusterCount = pq.getClusterCount();
            for (int m = 0; m < pq.getSubspaceCount(); m++) {
                int size = pq.subvectorSizesAndOffsets[m][0];
                int offset = pq.subvectorSizesAndOffsets[m][1];
                VectorUtil.calculatePartialSums(pq.codebooks[m], m, size, clusterCount, centeredQuery, offset, similarityFunction, this.partialSums);
            }
        }

        /**
         * Returns the approximate similarity for an encoded vector by delegating to
         * {@link VectorUtil#assembleAndSum}, which presumably sums the cached partial similarities
         * selected by each code byte.
         *
         * @param encoded the PQ codes of one vector
         */
        protected float decodedSimilarity(ByteSequence<?> encoded) {
            return VectorUtil.assembleAndSum(this.partialSums, this.cv.pq.getClusterCount(), encoded);
        }
    }

    /**
     * Approximate cosine-similarity decoder.
     *
     * <p>It keeps two lookup tables indexed by {@code subspace * clusterCount + centroid}:
     * <ul>
     *   <li>dot products between each centroid and the centered query, rebuilt per query;</li>
     *   <li>squared magnitudes of each centroid, which are query-independent.</li>
     * </ul>
     * The magnitude table is stored via {@link PQVectors#partialSquaredMagnitudes()} and computed
     * only while that slot is still {@code null}.
     */
    static class CosineDecoder extends PQDecoder {
        /** Dot product of each centroid with the corresponding slice of the centered query. */
        protected final VectorFloat<?> partialSums;
        /** Squared L2 norm of each centroid. */
        protected final VectorFloat<?> aMagnitude;
        /** Squared L2 norm of the centered query. */
        protected final float bMagnitude;

        public CosineDecoder(PQVectors pqVectors, VectorFloat<?> query) {
            super(pqVectors);
            ProductQuantization pq = this.cv.pq;

            // Query-independent table: reuse it if already published, otherwise build it once.
            this.aMagnitude = pqVectors.partialSquaredMagnitudes().updateAndGet(cached -> {
                if (cached != null) {
                    return cached;
                }
                return computePartialSquaredMagnitudes(pq);
            });

            this.partialSums = pqVectors.reusablePartialSums();
            VectorFloat<?> center = pq.globalCentroid;
            VectorFloat<?> centeredQuery = center == null ? query : VectorUtil.sub(query, center);

            int clusterCount = pq.getClusterCount();
            for (int m = 0; m < pq.getSubspaceCount(); m++) {
                int size = pq.subvectorSizesAndOffsets[m][0];
                int offset = pq.subvectorSizesAndOffsets[m][1];
                VectorFloat<?> codebook = pq.codebooks[m];
                for (int j = 0; j < clusterCount; j++) {
                    this.partialSums.set(m * clusterCount + j, VectorUtil.dotProduct(codebook, j * size, centeredQuery, offset, size));
                }
            }
            this.bMagnitude = VectorUtil.dotProduct(centeredQuery, centeredQuery);
        }

        /**
         * Builds the table of squared L2 norms of every centroid in every subspace codebook.
         *
         * <p>Kept free of side effects because the calling update function may be re-applied.
         */
        private static VectorFloat<?> computePartialSquaredMagnitudes(ProductQuantization pq) {
            int clusterCount = pq.getClusterCount();
            VectorFloat<?> magnitudes = vts.createFloatVector(pq.getSubspaceCount() * clusterCount);
            for (int m = 0; m < pq.getSubspaceCount(); m++) {
                int size = pq.subvectorSizesAndOffsets[m][0];
                VectorFloat<?> codebook = pq.codebooks[m];
                for (int j = 0; j < clusterCount; j++) {
                    magnitudes.set(m * clusterCount + j, VectorUtil.dotProduct(codebook, j * size, codebook, j * size, size));
                }
            }
            return magnitudes;
        }

        /** Rescales the approximate cosine {@code c} of the given node to {@code (1 + c) / 2}. */
        @Override // io.github.jbellis.jvector.graph.similarity.ScoreFunction
        public float similarityTo(int node) {
            return (1.0f + decodedCosine(node)) / 2.0f;
        }

        /**
         * Returns the approximate cosine between the (centered) query and the encoded vector of
         * {@code node}: {@code dot / sqrt(|a|^2 * |b|^2)}, assembled from the per-centroid tables.
         *
         * <p>NOTE(review): if either squared-magnitude sum is zero the result is NaN or infinite.
         */
        protected float decodedCosine(int node) {
            ByteSequence<?> encoded = this.cv.get(node);
            int clusterCount = this.cv.pq.getClusterCount();
            float dot = 0.0f;
            float aMagnitudeSum = 0.0f;
            for (int m = 0; m < encoded.length(); m++) {
                // Codes are read as unsigned bytes so that centroid indices 128..255 are valid.
                int index = m * clusterCount + Byte.toUnsignedInt(encoded.get(m));
                dot += this.partialSums.get(index);
                aMagnitudeSum += this.aMagnitude.get(index);
            }
            return (float) (dot / Math.sqrt(aMagnitudeSum * this.bMagnitude));
        }
    }

    /** Approximate dot-product decoder; rescales the approximate dot product {@code d} to {@code (1 + d) / 2}. */
    static class DotProductDecoder extends CachingDecoder {
        public DotProductDecoder(PQVectors pqVectors, VectorFloat<?> query) {
            super(pqVectors, query, VectorSimilarityFunction.DOT_PRODUCT);
        }

        @Override // io.github.jbellis.jvector.graph.similarity.ScoreFunction
        public float similarityTo(int node) {
            return (1.0f + decodedSimilarity(this.cv.get(node))) / 2.0f;
        }
    }

    /**
     * Approximate Euclidean decoder.
     *
     * <p>The partial-sum table is built with {@link VectorSimilarityFunction#EUCLIDEAN}, so
     * {@code decodedSimilarity} is presumably the approximate squared L2 distance {@code d}
     * (TODO confirm). The score is {@code 1 / (1 + d)}.
     */
    static class EuclideanDecoder extends CachingDecoder {
        public EuclideanDecoder(PQVectors pqVectors, VectorFloat<?> query) {
            super(pqVectors, query, VectorSimilarityFunction.EUCLIDEAN);
        }

        @Override // io.github.jbellis.jvector.graph.similarity.ScoreFunction
        public float similarityTo(int node) {
            return 1.0f / (1.0f + decodedSimilarity(this.cv.get(node)));
        }
    }
}
