package io.github.jbellis.jvector.pq;

import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.pq.PQDecoder;
import io.github.jbellis.jvector.util.RamUsageEstimator;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorUtil;
import io.github.jbellis.jvector.vector.VectorizationProvider;
import io.github.jbellis.jvector.vector.types.ByteSequence;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;

/* loaded from: input_file:io/github/jbellis/jvector/pq/PQVectors.class */
/**
 * A collection of product-quantization-encoded vectors together with the
 * {@link ProductQuantization} that produced them.
 *
 * <p>Each encoded vector is a {@link ByteSequence} holding one byte per PQ subspace; the
 * byte (read as an unsigned value) selects a centroid in that subspace's codebook.
 */
public class PQVectors implements CompressedVectors {
    private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport();
    final ProductQuantization pq;
    private final List<ByteSequence<?>> compressedVectors;

    /**
     * Creates an instance backed by the given list. The list is stored as-is (no copy is made).
     *
     * @param productQuantization the quantizer whose codebooks the encoded vectors refer to
     * @param list the encoded vectors, indexed by ordinal
     */
    public PQVectors(ProductQuantization productQuantization, List<ByteSequence<?>> list) {
        this.pq = productQuantization;
        this.compressedVectors = list;
    }

    /**
     * Creates an instance from an array of encoded vectors. The array contents are copied
     * into an unmodifiable list via {@link List#of(Object[])}, which rejects null elements.
     *
     * @param productQuantization the quantizer whose codebooks the encoded vectors refer to
     * @param byteSequenceArr the encoded vectors, indexed by ordinal
     */
    public PQVectors(ProductQuantization productQuantization, ByteSequence<?>[] byteSequenceArr) {
        this(productQuantization, List.of(byteSequenceArr));
    }

    /** @return the number of encoded vectors held by this instance */
    @Override // io.github.jbellis.jvector.pq.CompressedVectors
    public int count() {
        return this.compressedVectors.size();
    }

    /**
     * Serializes this instance: first the quantizer, then the vector count, then the
     * subspace count, then every encoded vector in ordinal order.
     *
     * @param dataOutput destination to write to
     * @throws IOException if the underlying output fails
     */
    @Override // io.github.jbellis.jvector.pq.CompressedVectors
    public void write(DataOutput dataOutput) throws IOException {
        this.pq.write(dataOutput);
        dataOutput.writeInt(this.compressedVectors.size());
        dataOutput.writeInt(this.pq.getSubspaceCount());
        for (ByteSequence<?> encoded : this.compressedVectors) {
            vectorTypeSupport.writeByteSequence(dataOutput, encoded);
        }
    }

    /**
     * Reads an instance from the reader's current position, consuming data in the same
     * order that {@link #write(DataOutput)} emits it.
     *
     * @param randomAccessReader source positioned at the start of the serialized data
     * @return the deserialized instance
     * @throws IOException if reading fails, or if the stored vector count or per-vector
     *         length is negative
     */
    public static PQVectors load(RandomAccessReader randomAccessReader) throws IOException {
        ProductQuantization quantizer = ProductQuantization.load(randomAccessReader);
        int vectorCount = randomAccessReader.readInt();
        if (vectorCount < 0) {
            throw new IOException("Invalid compressed vector count " + vectorCount);
        }
        List<ByteSequence<?>> vectors = new ArrayList<>(vectorCount);
        int encodedLength = randomAccessReader.readInt();
        if (encodedLength < 0) {
            throw new IOException("Invalid compressed vector dimension " + encodedLength);
        }
        for (int ordinal = 0; ordinal < vectorCount; ordinal++) {
            vectors.add(vectorTypeSupport.readByteSequence(randomAccessReader, encodedLength));
        }
        return new PQVectors(quantizer, vectors);
    }

    /**
     * Seeks the reader to the given position and then delegates to
     * {@link #load(RandomAccessReader)}.
     *
     * @param randomAccessReader source to read from
     * @param j position passed to {@code seek} before reading
     * @return the deserialized instance
     * @throws IOException if seeking or reading fails
     */
    public static PQVectors load(RandomAccessReader randomAccessReader, long j) throws IOException {
        randomAccessReader.seek(j);
        return load(randomAccessReader);
    }

    /**
     * Two instances are equal when they are of the same class and both their quantizers
     * and their encoded-vector lists are equal.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        PQVectors other = (PQVectors) obj;
        return Objects.equals(this.pq, other.pq)
                && Objects.equals(this.compressedVectors, other.compressedVectors);
    }

    @Override
    public int hashCode() {
        return Objects.hash(this.pq, this.compressedVectors);
    }

    /**
     * Returns a score function backed by one of the {@link PQDecoder} implementations,
     * chosen by similarity function.
     *
     * @param vectorFloat the query vector
     * @param vectorSimilarityFunction the similarity to approximate
     * @return a decoder-based score function for the query
     * @throws IllegalArgumentException if the similarity function is not one of
     *         DOT_PRODUCT, EUCLIDEAN or COSINE
     */
    @Override // io.github.jbellis.jvector.pq.CompressedVectors
    public ScoreFunction.ApproximateScoreFunction precomputedScoreFunctionFor(VectorFloat<?> vectorFloat, VectorSimilarityFunction vectorSimilarityFunction) {
        switch (vectorSimilarityFunction) {
            case DOT_PRODUCT:
                return new PQDecoder.DotProductDecoder(this, vectorFloat);
            case EUCLIDEAN:
                return new PQDecoder.EuclideanDecoder(this, vectorFloat);
            case COSINE:
                return new PQDecoder.CosineDecoder(this, vectorFloat);
            default:
                throw new IllegalArgumentException("Unsupported similarity function " + vectorSimilarityFunction);
        }
    }

    /**
     * Returns a score function that, for each node, walks the node's encoded bytes and
     * accumulates the per-subspace similarity between the selected codebook centroid and
     * the matching slice of the query. Nothing is precomputed per query except, for
     * COSINE, the query's squared magnitude.
     *
     * @param vectorFloat the query vector
     * @param vectorSimilarityFunction the similarity to approximate
     * @return a score function taking a node ordinal
     * @throws IllegalArgumentException if the similarity function is not one of
     *         DOT_PRODUCT, EUCLIDEAN or COSINE
     */
    @Override // io.github.jbellis.jvector.pq.CompressedVectors
    public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat<?> vectorFloat, VectorSimilarityFunction vectorSimilarityFunction) {
        // NOTE: in every lambda below, `node` is the ordinal of the encoded vector being
        // scored and `m` is the subspace index; they must remain distinct variables.
        switch (vectorSimilarityFunction) {
            case DOT_PRODUCT:
                return node -> {
                    ByteSequence<?> encoded = get(node);
                    int subspaceCount = this.pq.getSubspaceCount();
                    float dot = 0.0f;
                    for (int m = 0; m < subspaceCount; m++) {
                        int centroid = Byte.toUnsignedInt(encoded.get(m));
                        // [m][0] is used as the subvector length, [m][1] as the offset into the query
                        int subvectorSize = this.pq.subvectorSizesAndOffsets[m][0];
                        int queryOffset = this.pq.subvectorSizesAndOffsets[m][1];
                        dot += VectorUtil.dotProduct(this.pq.codebooks[m], centroid * subvectorSize, vectorFloat, queryOffset, subvectorSize);
                    }
                    return (1.0f + dot) / 2.0f;
                };
            case EUCLIDEAN:
                return node -> {
                    ByteSequence<?> encoded = get(node);
                    int subspaceCount = this.pq.getSubspaceCount();
                    float squaredDistance = 0.0f;
                    for (int m = 0; m < subspaceCount; m++) {
                        int centroid = Byte.toUnsignedInt(encoded.get(m));
                        int subvectorSize = this.pq.subvectorSizesAndOffsets[m][0];
                        int queryOffset = this.pq.subvectorSizesAndOffsets[m][1];
                        squaredDistance += VectorUtil.squareL2Distance(this.pq.codebooks[m], centroid * subvectorSize, vectorFloat, queryOffset, subvectorSize);
                    }
                    return 1.0f / (1.0f + squaredDistance);
                };
            case COSINE: {
                // The query's squared magnitude does not depend on the node, so compute it once.
                final float queryMagnitudeSquared = VectorUtil.dotProduct(vectorFloat, vectorFloat);
                return node -> {
                    ByteSequence<?> encoded = get(node);
                    int subspaceCount = this.pq.getSubspaceCount();
                    float dot = 0.0f;
                    float centroidMagnitudeSquared = 0.0f;
                    for (int m = 0; m < subspaceCount; m++) {
                        int centroid = Byte.toUnsignedInt(encoded.get(m));
                        int subvectorSize = this.pq.subvectorSizesAndOffsets[m][0];
                        int queryOffset = this.pq.subvectorSizesAndOffsets[m][1];
                        int centroidOffset = centroid * subvectorSize;
                        dot += VectorUtil.dotProduct(this.pq.codebooks[m], centroidOffset, vectorFloat, queryOffset, subvectorSize);
                        centroidMagnitudeSquared += VectorUtil.dotProduct(this.pq.codebooks[m], centroidOffset, this.pq.codebooks[m], centroidOffset, subvectorSize);
                    }
                    float cosine = dot / ((float) Math.sqrt(queryMagnitudeSquared * centroidMagnitudeSquared));
                    return (1.0f + cosine) / 2.0f;
                };
            }
            default:
                throw new IllegalArgumentException("Unsupported similarity function " + vectorSimilarityFunction);
        }
    }

    /**
     * @param i ordinal of the encoded vector
     * @return the encoded vector stored at that ordinal
     */
    public ByteSequence<?> get(int i) {
        return this.compressedVectors.get(i);
    }

    /** @return the quantizer associated with these encoded vectors */
    public ProductQuantization getProductQuantization() {
        return this.pq;
    }

    /** Delegates to {@link ProductQuantization#reusablePartialSums()}. */
    public VectorFloat<?> reusablePartialSums() {
        return this.pq.reusablePartialSums();
    }

    /** Delegates to {@link ProductQuantization#partialMagnitudes()}. */
    public AtomicReference<VectorFloat<?>> partialMagnitudes() {
        return this.pq.partialMagnitudes();
    }

    /** @return the quantizer's original dimension multiplied by {@link Float#BYTES} */
    @Override // io.github.jbellis.jvector.pq.CompressedVectors
    public int getOriginalSize() {
        return this.pq.originalDimension * Float.BYTES;
    }

    /** @return the compressed vector size as reported by the quantizer */
    @Override // io.github.jbellis.jvector.pq.CompressedVectors
    public int getCompressedSize() {
        return this.pq.compressedVectorSize();
    }

    /** @return the quantizer associated with these encoded vectors */
    @Override // io.github.jbellis.jvector.pq.CompressedVectors
    public ProductQuantization getCompressor() {
        return this.pq;
    }

    /**
     * Estimates memory use as the quantizer's reported size plus the estimated size of the
     * first encoded vector multiplied by the vector count. Only the first vector is
     * measured; the estimate therefore assumes all encoded vectors have the same footprint.
     */
    @Override // io.github.jbellis.jvector.util.Accountable
    public long ramBytesUsed() {
        long quantizerBytes = this.pq.memorySize();
        if (this.compressedVectors.isEmpty()) {
            return quantizerBytes;
        }
        return quantizerBytes + (RamUsageEstimator.sizeOf(this.compressedVectors.get(0)) * this.compressedVectors.size());
    }

    @Override
    public String toString() {
        return "PQVectors{pq=" + this.pq + ", count=" + this.compressedVectors.size() + "}";
    }
}
