package io.github.jbellis.jvector.pq;

import io.github.jbellis.jvector.graph.disk.FusedADCNeighbors;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorUtil;
import io.github.jbellis.jvector.vector.types.ByteSequence;
import io.github.jbellis.jvector.vector.types.VectorFloat;

/* loaded from: input_file:io/github/jbellis/jvector/pq/QuickADCPQDecoder.class */
/**
 * Approximate score function over product-quantized vectors that scores all neighbors of a node in one pass.
 *
 * <p>At construction time, the per-subspace partial sums between the query and every codebook centroid are
 * computed once. {@link #edgeLoadingSimilarityTo(int)} then combines those partial sums with a node's packed
 * neighbor codes via {@link VectorUtil#bulkShuffleSimilarity}. Single-node scoring via
 * {@link #similarityTo(int)} is delegated to the exact score function supplied by the caller.
 *
 * <p>Instances are created via {@link #newDecoder}; only DOT_PRODUCT and EUCLIDEAN are supported.
 */
public abstract class QuickADCPQDecoder implements ScoreFunction.ApproximateScoreFunction {
    /** Product quantization whose codebooks the partial sums are computed against. */
    protected final ProductQuantization pq;
    /** Query vector exactly as supplied by the caller (not centered on the global centroid). */
    protected final VectorFloat<?> query;
    /** Exact score function that {@link #similarityTo(int)} delegates to in the concrete decoders. */
    protected final ScoreFunction.ExactScoreFunction esf;

    /**
     * Decoder that precomputes the query-to-centroid partial sums for every PQ subspace.
     */
    protected static abstract class CachingDecoder extends QuickADCPQDecoder {
        /**
         * Partial-sum table filled in by the constructor. It is obtained from
         * {@link ProductQuantization#reusablePartialSums()}; NOTE(review): presumably a buffer reused across
         * decoders (e.g. per thread), so it should not be retained beyond this decoder's use — confirm.
         */
        protected final VectorFloat<?> partialSums;

        /**
         * Fills {@link #partialSums} for {@code query} using the given similarity function.
         *
         * @param pq                  product quantization providing codebooks, subvector layout and global centroid
         * @param query               query vector; it is centered on {@code pq.globalCentroid} first when one exists
         * @param similarityFunction  similarity used to compute the partial sums
         * @param esf                 exact score function forwarded to the base class
         */
        protected CachingDecoder(ProductQuantization pq,
                                 VectorFloat<?> query,
                                 VectorSimilarityFunction similarityFunction,
                                 ScoreFunction.ExactScoreFunction esf) {
            super(pq, query, esf);
            this.partialSums = pq.reusablePartialSums();

            // Codes were built relative to the global centroid (when present), so center the query the same way.
            VectorFloat<?> centroid = pq.globalCentroid;
            VectorFloat<?> centeredQuery = centroid == null ? query : VectorUtil.sub(query, centroid);

            // Loop-invariant: each subspace owns a contiguous run of clusterCount entries in partialSums.
            int clusterCount = pq.getClusterCount();
            for (int subspace = 0; subspace < pq.getSubspaceCount(); subspace++) {
                VectorUtil.calculatePartialSums(pq.codebooks[subspace],
                                                subspace * clusterCount,
                                                pq.subvectorSizesAndOffsets[subspace][0],
                                                clusterCount,
                                                centeredQuery,
                                                pq.subvectorSizesAndOffsets[subspace][1],
                                                similarityFunction,
                                                this.partialSums);
            }
        }
    }

    /**
     * Shared implementation for the concrete decoders: scores a node's packed neighbor codes against the cached
     * partial sums using a fixed similarity function. Subclasses differ only in that function.
     */
    protected static abstract class EdgeLoadingDecoder extends CachingDecoder {
        private final FusedADCNeighbors neighbors;
        private final VectorFloat<?> results;
        private final VectorSimilarityFunction similarityFunction;

        /**
         * @param neighbors           source of packed neighbor codes per node
         * @param pq                  product quantization used to build the partial sums
         * @param query               query vector
         * @param results             caller-supplied output buffer; it is zeroed, refilled and returned by every
         *                            {@link #edgeLoadingSimilarityTo(int)} call
         * @param similarityFunction  similarity used both for the partial sums and for bulk scoring
         * @param esf                 exact score function used by {@link #similarityTo(int)}
         */
        protected EdgeLoadingDecoder(FusedADCNeighbors neighbors,
                                     ProductQuantization pq,
                                     VectorFloat<?> query,
                                     VectorFloat<?> results,
                                     VectorSimilarityFunction similarityFunction,
                                     ScoreFunction.ExactScoreFunction esf) {
            super(pq, query, similarityFunction, esf);
            this.neighbors = neighbors;
            this.results = results;
            this.similarityFunction = similarityFunction;
        }

        /** Delegates directly to the exact score function supplied at construction. */
        @Override
        public float similarityTo(int node) {
            return this.esf.similarityTo(node);
        }

        /**
         * Scores the packed neighbor codes of {@code node} against the cached partial sums.
         *
         * <p>The returned vector is the shared {@code results} buffer passed to the constructor. It is
         * overwritten by the next call, so callers must consume or copy it first.
         */
        @Override
        public VectorFloat<?> edgeLoadingSimilarityTo(int node) {
            ByteSequence<?> packedNeighbors = this.neighbors.getPackedNeighbors(node);
            this.results.zero();
            VectorUtil.bulkShuffleSimilarity(packedNeighbors, this.pq.compressedVectorSize(), this.partialSums,
                                             this.results, this.similarityFunction);
            return this.results;
        }

        /** Always {@code true}: this decoder is built specifically for bulk neighbor scoring. */
        @Override
        public boolean supportsEdgeLoadingSimilarity() {
            return true;
        }
    }

    /** Decoder whose partial sums and bulk scoring both use {@link VectorSimilarityFunction#DOT_PRODUCT}. */
    public static class DotProductDecoder extends EdgeLoadingDecoder {
        public DotProductDecoder(FusedADCNeighbors neighbors,
                                 ProductQuantization pq,
                                 VectorFloat<?> query,
                                 VectorFloat<?> results,
                                 ScoreFunction.ExactScoreFunction esf) {
            super(neighbors, pq, query, results, VectorSimilarityFunction.DOT_PRODUCT, esf);
        }
    }

    /** Decoder whose partial sums and bulk scoring both use {@link VectorSimilarityFunction#EUCLIDEAN}. */
    public static class EuclideanDecoder extends EdgeLoadingDecoder {
        public EuclideanDecoder(FusedADCNeighbors neighbors,
                                ProductQuantization pq,
                                VectorFloat<?> query,
                                VectorFloat<?> results,
                                ScoreFunction.ExactScoreFunction esf) {
            super(neighbors, pq, query, results, VectorSimilarityFunction.EUCLIDEAN, esf);
        }
    }

    protected QuickADCPQDecoder(ProductQuantization pq, VectorFloat<?> query, ScoreFunction.ExactScoreFunction esf) {
        this.pq = pq;
        this.query = query;
        this.esf = esf;
    }

    /**
     * Creates the decoder matching {@code similarityFunction}.
     *
     * @param neighbors           source of packed neighbor codes per node
     * @param pq                  product quantization the codes were produced with
     * @param query               query vector
     * @param results             reusable output buffer returned by {@code edgeLoadingSimilarityTo}
     * @param similarityFunction  DOT_PRODUCT or EUCLIDEAN
     * @param esf                 exact score function used for single-node scoring
     * @return a {@link DotProductDecoder} or {@link EuclideanDecoder}
     * @throws IllegalArgumentException if {@code similarityFunction} is any other value
     */
    public static QuickADCPQDecoder newDecoder(FusedADCNeighbors neighbors,
                                               ProductQuantization pq,
                                               VectorFloat<?> query,
                                               VectorFloat<?> results,
                                               VectorSimilarityFunction similarityFunction,
                                               ScoreFunction.ExactScoreFunction esf) {
        switch (similarityFunction) {
            case DOT_PRODUCT:
                return new DotProductDecoder(neighbors, pq, query, results, esf);
            case EUCLIDEAN:
                return new EuclideanDecoder(neighbors, pq, query, results, esf);
            default:
                throw new IllegalArgumentException("Unsupported similarity function " + String.valueOf(similarityFunction));
        }
    }
}
