package io.github.jbellis.jvector.quantization;

import io.github.jbellis.jvector.annotations.VisibleForTesting;
import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.quantization.PQDecoder;
import io.github.jbellis.jvector.util.RamUsageEstimator;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorUtil;
import io.github.jbellis.jvector.vector.VectorizationProvider;
import io.github.jbellis.jvector.vector.types.ByteSequence;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Objects;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import java.util.stream.IntStream;

/**
 * Vectors compressed with a {@link ProductQuantization} (PQ), stored as one PQ code per vector.
 *
 * <p>Codes are packed back to back into one or more {@link ByteSequence} chunks. When the layout is
 * produced by {@link #load(RandomAccessReader)} or
 * {@link #encodeAndBuild(ProductQuantization, int, RandomAccessVectorValues, ForkJoinPool)}:
 * <ul>
 *   <li>each chunk is at most {@link #MAX_CHUNK_SIZE} bytes;</li>
 *   <li>every chunk except the last holds exactly {@link #vectorsPerChunk} codes;</li>
 *   <li>the last chunk holds whatever remains.</li>
 * </ul>
 */
public abstract class PQVectors implements CompressedVectors {
    private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport();

    /**
     * Largest number of bytes placed in a single chunk of compressed data.
     * NOTE(review): presumably kept slightly below {@code Integer.MAX_VALUE} to stay under the JVM's
     * maximum array length — confirm.
     */
    static final int MAX_CHUNK_SIZE = Integer.MAX_VALUE - 16;

    /** The quantizer whose codebooks encode and decode the stored codes. */
    final ProductQuantization pq;
    /** Chunks of packed PQ codes; only the first {@link #validChunkCount()} entries are written / accounted. */
    protected ByteSequence<?>[] compressedDataChunks;
    /** Number of codes per full chunk: ordinal {@code i} lives in chunk {@code i / vectorsPerChunk}. */
    protected int vectorsPerChunk;

    /**
     * @param pq the quantizer used to encode (and later decode) the stored vectors
     */
    protected PQVectors(ProductQuantization pq) {
        this.pq = pq;
    }

    /**
     * Reads PQ vectors from the reader's current position.
     *
     * <p>Expected layout, mirroring {@link #write(DataOutput, int)}:
     * <ol>
     *   <li>the serialized {@link ProductQuantization};</li>
     *   <li>an {@code int} vector count;</li>
     *   <li>an {@code int} compressed dimension (bytes per code);</li>
     *   <li>the codes of all vectors, back to back.</li>
     * </ol>
     *
     * @param in reader positioned at the start of the serialized data
     * @return the loaded, immutable vectors
     * @throws IOException if reading fails
     * @throws IllegalArgumentException if the stored vector count or dimension is invalid
     */
    public static ImmutablePQVectors load(RandomAccessReader in) throws IOException {
        ProductQuantization pq = ProductQuantization.load(in);
        int vectorCount = in.readInt();
        int compressedDimension = in.readInt();

        int[] layout = calculateChunkParameters(vectorCount, compressedDimension);
        int vectorsPerChunk = layout[0];
        int totalChunks = layout[1];
        int fullSizeChunks = layout[2];
        int remainingVectors = layout[3];

        ByteSequence<?>[] chunks = new ByteSequence<?>[totalChunks];
        int fullChunkBytes = vectorsPerChunk * compressedDimension;
        for (int c = 0; c < fullSizeChunks; c++) {
            chunks[c] = vectorTypeSupport.readByteSequence(in, fullChunkBytes);
        }
        if (totalChunks > fullSizeChunks) {
            // trailing partial chunk
            chunks[fullSizeChunks] = vectorTypeSupport.readByteSequence(in, remainingVectors * compressedDimension);
        }
        return new ImmutablePQVectors(pq, chunks, vectorCount, vectorsPerChunk);
    }

    /**
     * Computes how {@code totalVectors} codes of {@code compressedDimension} bytes each are split into
     * chunks of at most {@link #MAX_CHUNK_SIZE} bytes.
     *
     * @param totalVectors number of codes to store; must be positive
     * @param compressedDimension bytes per code; must be non-negative
     * @return {@code {vectorsPerChunk, totalChunks, fullSizeChunks, remainingVectors}}, where
     *     {@code remainingVectors} is the number of codes in the trailing partial chunk (0 if none)
     * @throws IllegalArgumentException if {@code totalVectors <= 0}, {@code compressedDimension < 0},
     *     or a single code does not fit in {@link #MAX_CHUNK_SIZE} bytes
     */
    @VisibleForTesting
    static int[] calculateChunkParameters(int totalVectors, int compressedDimension) {
        // A zero count would give vectorsPerChunk == 0 below and be misreported as a
        // "dimension too large" error, so reject it with an accurate message up front.
        if (totalVectors <= 0) {
            throw new IllegalArgumentException("Invalid vector count " + totalVectors);
        }
        if (compressedDimension < 0) {
            throw new IllegalArgumentException("Invalid compressed dimension " + compressedDimension);
        }

        // long arithmetic: the product can exceed Integer.MAX_VALUE
        long totalBytes = (long) totalVectors * compressedDimension;
        // totalBytes > MAX_CHUNK_SIZE implies compressedDimension > 0, so the division is safe
        int vectorsPerChunk = totalBytes <= MAX_CHUNK_SIZE ? totalVectors : MAX_CHUNK_SIZE / compressedDimension;
        if (vectorsPerChunk == 0) {
            throw new IllegalArgumentException("Compressed dimension " + compressedDimension + " too large for chunking");
        }

        int fullSizeChunks = totalVectors / vectorsPerChunk;
        int remainingVectors = totalVectors % vectorsPerChunk;
        int totalChunks = remainingVectors == 0 ? fullSizeChunks : fullSizeChunks + 1;
        return new int[] {vectorsPerChunk, totalChunks, fullSizeChunks, remainingVectors};
    }

    /**
     * Seeks to {@code offset} and then behaves like {@link #load(RandomAccessReader)}.
     *
     * @param in reader to read from
     * @param offset absolute position of the serialized data
     * @return the loaded vectors
     * @throws IOException if seeking or reading fails
     */
    public static PQVectors load(RandomAccessReader in, long offset) throws IOException {
        in.seek(offset);
        return load(in);
    }

    /**
     * Encodes every vector of {@code ravv} with {@code pq}, in parallel on {@code simdExecutor}.
     *
     * <p>Vectors for which {@code ravv} returns {@code null} get a zeroed code. The chunk layout is
     * the same as the one {@link #load(RandomAccessReader)} expects, so the result round-trips
     * through {@link #write(DataOutput, int)}.
     *
     * @param pq quantizer used for encoding
     * @param vectorCount number of codes to allocate; must be positive.
     *     NOTE(review): ordinals {@code 0..ravv.size()-1} are encoded, so this presumably must be
     *     {@code >= ravv.size()} — confirm with callers.
     * @param ravv source vectors
     * @param simdExecutor pool that runs the parallel encoding
     * @return the encoded, immutable vectors
     * @throws IllegalArgumentException if {@code vectorCount <= 0} or a code is too large to chunk
     */
    public static ImmutablePQVectors encodeAndBuild(ProductQuantization pq, int vectorCount, RandomAccessVectorValues ravv, ForkJoinPool simdExecutor) {
        int compressedDimension = pq.compressedVectorSize();
        // Share the chunking logic with load(): the previous floor-based chunk count produced an
        // oversized (overflowing) last chunk whenever vectorCount was not a multiple of vectorsPerChunk.
        int[] layout = calculateChunkParameters(vectorCount, compressedDimension);
        int vectorsPerChunk = layout[0];
        int totalChunks = layout[1];
        int fullSizeChunks = layout[2];
        int remainingVectors = layout[3];

        ByteSequence<?>[] chunks = new ByteSequence<?>[totalChunks];
        int fullChunkBytes = vectorsPerChunk * compressedDimension;
        for (int c = 0; c < fullSizeChunks; c++) {
            chunks[c] = vectorTypeSupport.createByteSequence(fullChunkBytes);
        }
        if (totalChunks > fullSizeChunks) {
            chunks[fullSizeChunks] = vectorTypeSupport.createByteSequence(remainingVectors * compressedDimension);
        }

        // each worker thread gets its own view of the source vectors
        Supplier<RandomAccessVectorValues> ravvPerThread = ravv.threadLocalSupplier();
        int subspaceCount = pq.getSubspaceCount();
        simdExecutor.submit(() -> {
            IntStream.range(0, ravv.size()).parallel().forEach(ordinal -> {
                RandomAccessVectorValues localRavv = ravvPerThread.get();
                ByteSequence<?> target = get(chunks, ordinal, vectorsPerChunk, subspaceCount);
                VectorFloat<?> vector = localRavv.getVector(ordinal);
                if (vector != null) {
                    pq.encodeTo2(vector, target);
                } else {
                    target.zero();
                }
            });
        }).join();
        return new ImmutablePQVectors(pq, chunks, vectorCount, vectorsPerChunk);
    }

    /**
     * Serializes the quantizer, the vector count, the subspace count, and then the first
     * {@link #validChunkCount()} chunks of codes.
     *
     * @param out destination
     * @param version format version passed through to {@link ProductQuantization}'s writer
     * @throws IOException if writing fails
     */
    @Override
    public void write(DataOutput out, int version) throws IOException {
        pq.write(out, version);
        out.writeInt(count());
        out.writeInt(pq.getSubspaceCount());
        for (int c = 0; c < validChunkCount(); c++) {
            vectorTypeSupport.writeByteSequence(out, compressedDataChunks[c]);
        }
    }

    /**
     * @return how many leading entries of {@link #compressedDataChunks} hold data
     */
    protected abstract int validChunkCount();

    /** Equal when of the same class, with equal quantizers and identical codes for every ordinal. */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        PQVectors other = (PQVectors) obj;
        if (!Objects.equals(pq, other.pq) || count() != other.count()) {
            return false;
        }
        for (int ordinal = 0; ordinal < count(); ordinal++) {
            if (!get(ordinal).equals(other.get(ordinal))) {
                return false;
            }
        }
        return true;
    }

    /** Hashes the quantizer, the count, and every code, consistently with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        int result = 1;
        result = 31 * result + pq.hashCode();
        result = 31 * result + count();
        for (int ordinal = 0; ordinal < count(); ordinal++) {
            result = 31 * result + get(ordinal).hashCode();
        }
        return result;
    }

    /**
     * Returns a score function for {@code q} backed by the {@link PQDecoder} implementation matching
     * {@code similarityFunction}.
     *
     * @throws IllegalArgumentException if the similarity function is not supported
     */
    @Override
    public ScoreFunction.ApproximateScoreFunction precomputedScoreFunctionFor(VectorFloat<?> q, VectorSimilarityFunction similarityFunction) {
        switch (similarityFunction) {
            case DOT_PRODUCT:
                return new PQDecoder.DotProductDecoder(this, q);
            case EUCLIDEAN:
                return new PQDecoder.EuclideanDecoder(this, q);
            case COSINE:
                return new PQDecoder.CosineDecoder(this, q);
            default:
                throw new IllegalArgumentException("Unsupported similarity function " + similarityFunction);
        }
    }

    /**
     * Returns a score function that compares {@code q} directly against the codebook centroids
     * selected by each stored code, recomputing the partial products on every call.
     *
     * <p>If the quantizer has a global centroid, it is subtracted from {@code q} first. Scores are:
     * <ul>
     *   <li>dot product: {@code (1 + dot) / 2};</li>
     *   <li>Euclidean: {@code 1 / (1 + squaredDistance)};</li>
     *   <li>cosine: {@code (1 + cosine) / 2}.</li>
     * </ul>
     *
     * @throws IllegalArgumentException if the similarity function is not supported
     */
    @Override
    public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat<?> q, VectorSimilarityFunction similarityFunction) {
        VectorFloat<?> centeredQuery = pq.globalCentroid == null ? q : VectorUtil.sub(q, pq.globalCentroid);
        switch (similarityFunction) {
            case DOT_PRODUCT:
                return node -> {
                    ByteSequence<?> code = get(node);
                    int subspaceCount = pq.getSubspaceCount();
                    float dot = 0.0f;
                    for (int m = 0; m < subspaceCount; m++) {
                        int centroidIndex = Byte.toUnsignedInt(code.get(m));
                        int length = pq.subvectorSizesAndOffsets[m][0];
                        int queryOffset = pq.subvectorSizesAndOffsets[m][1];
                        dot += VectorUtil.dotProduct(pq.codebooks[m], centroidIndex * length, centeredQuery, queryOffset, length);
                    }
                    return (1.0f + dot) / 2.0f;
                };
            case EUCLIDEAN:
                return node -> {
                    ByteSequence<?> code = get(node);
                    int subspaceCount = pq.getSubspaceCount();
                    float squaredDistance = 0.0f;
                    for (int m = 0; m < subspaceCount; m++) {
                        int centroidIndex = Byte.toUnsignedInt(code.get(m));
                        int length = pq.subvectorSizesAndOffsets[m][0];
                        int queryOffset = pq.subvectorSizesAndOffsets[m][1];
                        squaredDistance += VectorUtil.squareL2Distance(pq.codebooks[m], centroidIndex * length, centeredQuery, queryOffset, length);
                    }
                    return 1.0f / (1.0f + squaredDistance);
                };
            case COSINE: {
                // query norm is independent of the node, so compute it once
                float queryNormSquared = VectorUtil.dotProduct(centeredQuery, centeredQuery);
                return node -> {
                    ByteSequence<?> code = get(node);
                    int subspaceCount = pq.getSubspaceCount();
                    float dot = 0.0f;
                    float nodeNormSquared = 0.0f;
                    for (int m = 0; m < subspaceCount; m++) {
                        int centroidIndex = Byte.toUnsignedInt(code.get(m));
                        int length = pq.subvectorSizesAndOffsets[m][0];
                        int queryOffset = pq.subvectorSizesAndOffsets[m][1];
                        int centroidOffset = centroidIndex * length;
                        dot += VectorUtil.dotProduct(pq.codebooks[m], centroidOffset, centeredQuery, queryOffset, length);
                        nodeNormSquared += VectorUtil.dotProduct(pq.codebooks[m], centroidOffset, pq.codebooks[m], centroidOffset, length);
                    }
                    float cosine = dot / (float) Math.sqrt(queryNormSquared * nodeNormSquared);
                    return (1.0f + cosine) / 2.0f;
                };
            }
            default:
                throw new IllegalArgumentException("Unsupported similarity function " + similarityFunction);
        }
    }

    /**
     * Returns the PQ code of vector {@code ordinal} as a slice of its chunk.
     *
     * @throws IndexOutOfBoundsException if {@code ordinal} is outside {@code [0, count())}
     */
    public ByteSequence<?> get(int ordinal) {
        if (ordinal < 0 || ordinal >= count()) {
            throw new IndexOutOfBoundsException("Ordinal " + ordinal + " out of bounds for vector count " + count());
        }
        return get(compressedDataChunks, ordinal, vectorsPerChunk, pq.getSubspaceCount());
    }

    /**
     * Slices the code of {@code ordinal} out of {@code chunks}, assuming {@code vectorsPerChunk}
     * codes of {@code subspaceCount} bytes per full chunk.
     */
    static ByteSequence<?> get(ByteSequence<?>[] chunks, int ordinal, int vectorsPerChunk, int subspaceCount) {
        int chunkIndex = ordinal / vectorsPerChunk;
        int offsetInChunk = (ordinal % vectorsPerChunk) * subspaceCount;
        return chunks[chunkIndex].slice(offsetInChunk, subspaceCount);
    }

    /** Delegates to {@link ProductQuantization#reusablePartialSums()}. */
    public VectorFloat<?> reusablePartialSums() {
        return pq.reusablePartialSums();
    }

    /** Delegates to {@link ProductQuantization#partialSquaredMagnitudes()}. */
    public AtomicReference<VectorFloat<?>> partialSquaredMagnitudes() {
        return pq.partialSquaredMagnitudes();
    }

    /** @return size in bytes of an uncompressed vector ({@code float} per dimension) */
    @Override
    public int getOriginalSize() {
        return pq.originalDimension * Float.BYTES;
    }

    /** @return size in bytes of one PQ code */
    @Override
    public int getCompressedSize() {
        return pq.compressedVectorSize();
    }

    @Override
    public ProductQuantization getCompressor() {
        return pq;
    }

    /** Estimates heap usage: the quantizer, the chunk array, and every valid chunk. */
    @Override
    public long ramBytesUsed() {
        long codebooksSize = pq.ramBytesUsed();
        long chunksArraySize = RamUsageEstimator.NUM_BYTES_OBJECT_HEADER
                + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER
                + (long) validChunkCount() * RamUsageEstimator.NUM_BYTES_OBJECT_REF;
        long dataSize = 0;
        for (int c = 0; c < validChunkCount(); c++) {
            dataSize += compressedDataChunks[c].ramBytesUsed();
        }
        return codebooksSize + chunksArraySize + dataSize;
    }

    @Override
    public String toString() {
        return "PQVectors{pq=" + pq + ", count=" + count() + "}";
    }
}
