package io.github.jbellis.jvector.pq;

import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.graph.disk.LVQPackedVectors;
import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.util.RamUsageEstimator;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorUtil;
import io.github.jbellis.jvector.vector.VectorizationProvider;
import io.github.jbellis.jvector.vector.types.ByteSequence;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.function.Supplier;
import java.util.stream.IntStream;

/* loaded from: input_file:io/github/jbellis/jvector/pq/LocallyAdaptiveVectorQuantization.class */
/**
 * Locally-adaptive vector quantization (LVQ).
 *
 * <p>Every vector is first centered by subtracting {@link #globalMean}. Each centered component is then
 * scalar-quantized to one unsigned byte, using a per-vector bias (the minimum centered component) and a
 * per-vector scale ({@code (max - min) / 255}).
 *
 * <p>Packed layout (see {@link QuantizedVector#writePacked} and {@link PackedVector#getQuantized}):
 * <ul>
 *   <li>components are grouped into blocks of 64;</li>
 *   <li>inside a block, components {@code j, j+16, j+32, j+48} (for {@code 0 <= j < 16}) are stored next
 *       to each other;</li>
 *   <li>a trailing partial block is zero-padded to a full 64 bytes.</li>
 * </ul>
 */
public class LocallyAdaptiveVectorQuantization implements VectorCompressor<LocallyAdaptiveVectorQuantization.QuantizedVector> {
    private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport();

    /** Number of components per block in the packed layout. */
    private static final int PACKED_BLOCK = 64;
    /** Stride between interleaved components inside a packed block (so 64 / 16 = 4 lanes). */
    private static final int PACKED_LANE = 16;

    /** Mean that is subtracted from every vector before quantization. */
    public final VectorFloat<?> globalMean;

    /**
     * A quantized vector in the interleaved packed layout, as written by {@link QuantizedVector#writePacked}.
     */
    public static class PackedVector {
        public final ByteSequence<?> bytes;
        public final float bias;
        public final float scale;

        public PackedVector(ByteSequence<?> bytes, float bias, float scale) {
            this.bytes = bytes;
            this.bias = bias;
            this.scale = scale;
        }

        /**
         * Returns the unsigned quantized value (0..255) of logical component {@code i}.
         * This translates the logical index into its position in the interleaved block.
         */
        public int getQuantized(int i) {
            int block = i / PACKED_BLOCK;
            int inBlock = i % PACKED_BLOCK;
            int offset = block * PACKED_BLOCK
                    + (inBlock % PACKED_LANE) * (PACKED_BLOCK / PACKED_LANE)
                    + inBlock / PACKED_LANE;
            return Byte.toUnsignedInt(bytes.get(offset));
        }

        /** Returns logical component {@code i} reconstructed as {@code quantized * scale + bias}. */
        public float getDequantized(int i) {
            return (getQuantized(i) * scale) + bias;
        }

        /**
         * Dequantizes every stored byte into a new float vector of length {@code bytes.length()}.
         * NOTE(review): if the bytes include zero padding of the last block, the padded positions are also
         * decoded. The global mean is not added back. Confirm that callers expect this.
         */
        public VectorFloat<?> decode() {
            VectorFloat<?> decoded = vectorTypeSupport.createFloatVector(bytes.length());
            for (int i = 0; i < bytes.length(); i++) {
                decoded.set(i, getDequantized(i));
            }
            return decoded;
        }

        /** Returns a copy whose byte storage is independent of this instance's. */
        public PackedVector copy() {
            // NOTE(review): copy2() looks like a decompiler-renamed ByteSequence.copy(); verify against the API.
            return new PackedVector(bytes.copy2(), bias, scale);
        }
    }

    /**
     * Output of {@link #encode}: one quantized byte per component of the centered vector, plus the bias and
     * scale needed to dequantize it.
     */
    public static class QuantizedVector {
        private final ByteSequence<?> bytes;
        private final float bias;
        private final float scale;

        public QuantizedVector(ByteSequence<?> bytes, float bias, float scale) {
            this.bytes = bytes;
            this.bias = bias;
            this.scale = scale;
        }

        /** Writes {@code source[index]}, or a zero byte when {@code index} is past the end (block padding). */
        private static void writeByteSafely(DataOutput out, ByteSequence<?> source, int index) throws IOException {
            out.writeByte(index < source.length() ? source.get(index) : 0);
        }

        /**
         * Writes bias and scale (two floats), followed by the bytes in the interleaved packed layout.
         * Within each 64-component block, components j, j+16, j+32, j+48 are written together for
         * j = 0..15. A trailing partial block is zero-padded, so the total is
         * {@code 8 + ceil(length / 64) * 64} bytes.
         */
        public void writePacked(DataOutput out) throws IOException {
            out.writeFloat(bias);
            out.writeFloat(scale);
            int length = bytes.length();
            for (int base = 0; base < length; base += PACKED_BLOCK) {
                for (int j = base; j < base + PACKED_LANE; j++) {
                    for (int lane = 0; lane < PACKED_BLOCK / PACKED_LANE; lane++) {
                        writeByteSafely(out, bytes, j + lane * PACKED_LANE);
                    }
                }
            }
        }
    }

    /**
     * Creates a quantizer that centers vectors on the given mean.
     *
     * @param globalMean mean subtracted from every vector before quantization
     */
    public LocallyAdaptiveVectorQuantization(VectorFloat<?> globalMean) {
        this.globalMean = globalMean;
    }

    /**
     * Builds a quantizer whose global mean is the centroid of all vectors in {@code ravv}.
     *
     * @param ravv the vectors to average; an empty source makes the centroid computation fail
     * @return a quantizer centered on that centroid
     */
    public static LocallyAdaptiveVectorQuantization compute(RandomAccessVectorValues ravv) {
        RandomAccessVectorValues localRavv = ravv.threadLocalSupplier().get();
        int count = localRavv.size();
        List<VectorFloat<?>> vectors = new ArrayList<>(count);
        for (int i = 0; i < count; i++) {
            VectorFloat<?> v = localRavv.getVector(i);
            // A shared-value RAVV reuses one buffer across calls, so keep a private copy of each vector.
            // Otherwise every list entry would alias the last vector read.
            vectors.add(localRavv.isValueShared() ? v.copy() : v);
        }
        return new LocallyAdaptiveVectorQuantization(KMeansPlusPlusClusterer.centroidOf(vectors));
    }

    /**
     * Encodes every vector of {@code ravv} in parallel on {@code simdExecutor}.
     *
     * @return quantized vectors, indexed by ordinal
     */
    @Override
    public QuantizedVector[] encodeAll(RandomAccessVectorValues ravv, ForkJoinPool simdExecutor) {
        // Each worker thread reads through its own view, so RAVVs that reuse a shared buffer are never
        // read concurrently.
        Supplier<RandomAccessVectorValues> threadLocalRavv = ravv.threadLocalSupplier();
        return simdExecutor.submit(() -> IntStream.range(0, ravv.size())
                        .parallel()
                        .mapToObj(i -> encode(threadLocalRavv.get().getVector(i)))
                        .toArray(QuantizedVector[]::new))
                .join();
    }

    /**
     * Centers {@code vector} on the global mean and quantizes each component to 8 bits.
     * Uses bias = min and scale = (max - min) / 255 of the centered vector.
     */
    @Override
    public QuantizedVector encode(VectorFloat<?> vector) {
        VectorFloat<?> centered = VectorUtil.sub(vector, globalMean);
        float max = VectorUtil.max(centered);
        float min = VectorUtil.min(centered);
        float step = (max - min) / 255.0f;
        ByteSequence<?> quantized = vectorTypeSupport.createByteSequence(centered.length());
        for (int i = 0; i < centered.length(); i++) {
            quantized.set(i, quantizeFloatToByte(centered.get(i), min, step));
        }
        return new QuantizedVector(quantized, min, step);
    }

    /**
     * Maps {@code value} to the nearest level {@code round((value - min) / step)}, clamped to 0..255.
     * When every component is equal, step is 0 and the quotient is 0/0 = NaN. {@code Math.round(NaN)} is 0,
     * so the level is 0.
     */
    private static byte quantizeFloatToByte(float value, float min, float step) {
        int level = Math.round((value - min) / step);
        return (byte) Math.max(0, Math.min(255, level));
    }

    /**
     * Serializes the quantizer as the mean's dimension (int) followed by the mean's components.
     *
     * @throws IllegalArgumentException if {@code version < 3}
     */
    @Override
    public void write(DataOutput out, int version) throws IOException {
        if (version < 3) {
            throw new IllegalArgumentException("LVQ requires version 3 or greater");
        }
        out.writeInt(globalMean.length());
        vectorTypeSupport.writeFloatVector(out, globalMean);
    }

    /**
     * Returns the size in bytes of one vector written by {@link QuantizedVector#writePacked}.
     * This is the dimension rounded up to a multiple of 64, plus two floats for bias and scale.
     */
    @Override
    public int compressedVectorSize() {
        int dims = globalMean.length();
        int padded = dims % PACKED_BLOCK == 0 ? dims : (dims / PACKED_BLOCK + 1) * PACKED_BLOCK;
        return padded + 2 * Float.BYTES;
    }

    /** Returns the size in bytes written by {@link #write}: the dimension int plus the mean's floats. */
    @Override
    public int compressorSize() {
        return Integer.BYTES + Float.BYTES * globalMean.length();
    }

    /**
     * Dot-product reranker returning {@code (1 + dot) / 2}.
     * The score adds {@code <query, globalMean>} to {@code lvqDotProduct}. Presumably this reconstructs
     * {@code <query, centered + mean>} from the centered quantized vector; verify against VectorUtil.
     */
    private ScoreFunction.Reranker dotProductScoreFunctionFrom(VectorFloat<?> query, LVQPackedVectors packedVectors) {
        final float querySum = VectorUtil.sum(query);
        final float queryGlobalBias = VectorUtil.dotProduct(query, globalMean);
        return new ScoreFunction.Reranker() {
            @Override
            public VectorFloat<?> similarityTo(int[] nodes) {
                VectorFloat<?> results = vectorTypeSupport.createFloatVector(nodes.length);
                for (int i = 0; i < nodes.length; i++) {
                    results.set(i, similarityTo(nodes[i]));
                }
                return results;
            }

            @Override
            public float similarityTo(int node) {
                float dot = VectorUtil.lvqDotProduct(query, packedVectors.getPackedVector(node), querySum) + queryGlobalBias;
                return (1.0f + dot) / 2.0f;
            }
        };
    }

    /**
     * Euclidean reranker returning {@code 1 / (1 + squaredDistance)}.
     * The query is centered once up front. Presumably this puts it in the same space as the stored vectors.
     */
    private ScoreFunction.Reranker euclideanScoreFunctionFrom(VectorFloat<?> query, LVQPackedVectors packedVectors) {
        final VectorFloat<?> centeredQuery = VectorUtil.sub(query, globalMean);
        return new ScoreFunction.Reranker() {
            @Override
            public VectorFloat<?> similarityTo(int[] nodes) {
                VectorFloat<?> results = vectorTypeSupport.createFloatVector(nodes.length);
                for (int i = 0; i < nodes.length; i++) {
                    results.set(i, similarityTo(nodes[i]));
                }
                return results;
            }

            @Override
            public float similarityTo(int node) {
                return 1.0f / (1.0f + VectorUtil.lvqSquareL2Distance(centeredQuery, packedVectors.getPackedVector(node)));
            }
        };
    }

    /** Cosine reranker returning {@code (1 + cosine) / 2}; the global mean is passed to {@code lvqCosine}. */
    private ScoreFunction.Reranker cosineScoreFunctionFrom(VectorFloat<?> query, LVQPackedVectors packedVectors) {
        return new ScoreFunction.Reranker() {
            @Override
            public VectorFloat<?> similarityTo(int[] nodes) {
                VectorFloat<?> results = vectorTypeSupport.createFloatVector(nodes.length);
                for (int i = 0; i < nodes.length; i++) {
                    results.set(i, similarityTo(nodes[i]));
                }
                return results;
            }

            @Override
            public float similarityTo(int node) {
                return (1.0f + VectorUtil.lvqCosine(query, packedVectors.getPackedVector(node), globalMean)) / 2.0f;
            }
        };
    }

    /**
     * Returns an exact reranker that scores {@code query} against LVQ-packed vectors.
     *
     * @throws IllegalArgumentException if the similarity function is not DOT_PRODUCT, EUCLIDEAN or COSINE
     */
    public ScoreFunction.Reranker scoreFunctionFrom(VectorFloat<?> query, VectorSimilarityFunction similarityFunction, LVQPackedVectors packedVectors) {
        switch (similarityFunction) {
            case DOT_PRODUCT:
                return dotProductScoreFunctionFrom(query, packedVectors);
            case EUCLIDEAN:
                return euclideanScoreFunctionFrom(query, packedVectors);
            case COSINE:
                return cosineScoreFunctionFrom(query, packedVectors);
            default:
                throw new IllegalArgumentException("Unsupported similarity function: " + similarityFunction);
        }
    }

    /** Always throws: LVQ vectors are consumed through {@link LVQPackedVectors}, not CompressedVectors. */
    @Override
    public CompressedVectors createCompressedVectors(Object[] compressedVectors) {
        throw new UnsupportedOperationException("LVQ does not produce a compressed vectors implementation");
    }

    /** Reads a quantizer written by {@link #write}: the dimension (int), then that many floats of the mean. */
    public static LocallyAdaptiveVectorQuantization load(RandomAccessReader in) throws IOException {
        return new LocallyAdaptiveVectorQuantization(vectorTypeSupport.readFloatVector(in, in.readInt()));
    }

    /** Two quantizers are equal when their global means are equal. */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        return globalMean.equals(((LocallyAdaptiveVectorQuantization) obj).globalMean);
    }

    /** Consistent with {@link #equals}: derived solely from the global mean. */
    @Override
    public int hashCode() {
        return globalMean.hashCode();
    }
}
