package io.github.jbellis.jvector.graph;

import io.github.jbellis.jvector.graph.ConcurrentNeighborSet;
import io.github.jbellis.jvector.graph.GraphSearcher;
import io.github.jbellis.jvector.graph.SearchResult;
import io.github.jbellis.jvector.util.Bits;
import io.github.jbellis.jvector.util.RamUsageEstimator;
import io.github.jbellis.jvector.vector.VectorEncoding;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import java.util.AbstractMap;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.stream.IntStream;

/* loaded from: input_file:io/github/jbellis/jvector/graph/GraphIndexBuilder.class */
public class GraphIndexBuilder<T> {
    // Beam width handed to the graph search that runs for every inserted node.
    private final int beamWidth;
    // Per-thread scratch array that getNaturalCandidates() clears and refills with search hits.
    private final ThreadLocal<NeighborArray> naturalScratch;
    // Per-thread scratch array that getConcurrentCandidates() clears and refills with in-flight nodes.
    private final ThreadLocal<NeighborArray> concurrentScratch;
    // Similarity function used by scoreBetween() for every vector comparison.
    private final VectorSimilarityFunction similarityFunction;
    // Overflow factor forwarded to ConcurrentNeighborSet.backlink() in updateNeighbors().
    private final float neighborOverflow;
    // Element encoding of the vectors; selects the byte[] or float[] comparison in scoreBetween().
    private final VectorEncoding vectorEncoding;
    // Per-thread searcher built over graph.getView() with concurrent updates enabled.
    private final ThreadLocal<GraphSearcher<?>> graphSearcher;
    // The index under construction; package-private, also exposed through getGraph().
    final OnHeapGraphIndex<T> graph;
    // Ids of nodes whose addGraphNode() call has started but not yet finished.
    private final ConcurrentSkipListSet<Integer> insertionsInProgress = new ConcurrentSkipListSet<>();
    // Per-thread copy of the source vectors, used for the "first" vector of a comparison.
    private final ThreadLocal<RandomAccessVectorValues<T>> vectors;
    // Second, independent per-thread copy, used for the "other" vector of a comparison.
    // NOTE(review): presumably needed because one copy may reuse its buffer between
    // vectorValue() calls -- confirm against RandomAccessVectorValues.
    private final ThreadLocal<RandomAccessVectorValues<T>> vectorsCopy;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: io.github.jbellis.jvector.graph.GraphIndexBuilder$1, reason: invalid class name */
    /* loaded from: input_file:io/github/jbellis/jvector/graph/GraphIndexBuilder$1.class */
    public static /* synthetic */ class AnonymousClass1 {
        /**
         * Lookup table indexed by {@code VectorEncoding.ordinal()}: BYTE maps to 1 and
         * FLOAT32 maps to 2; any other slot keeps the array default of 0.
         */
        static final /* synthetic */ int[] $SwitchMap$io$github$jbellis$jvector$vector$VectorEncoding;

        static {
            int[] caseLabels = new int[VectorEncoding.values().length];
            try {
                caseLabels[VectorEncoding.BYTE.ordinal()] = 1;
            } catch (NoSuchFieldError ignored) {
                // Constant missing at runtime: its slot simply stays 0.
            }
            try {
                caseLabels[VectorEncoding.FLOAT32.ordinal()] = 2;
            } catch (NoSuchFieldError ignored) {
                // Constant missing at runtime: its slot simply stays 0.
            }
            $SwitchMap$io$github$jbellis$jvector$vector$VectorEncoding = caseLabels;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:io/github/jbellis/jvector/graph/GraphIndexBuilder$ExcludingBits.class */
    /**
     * A {@link Bits} view in which every index is set except one excluded index.
     */
    public static class ExcludingBits implements Bits {
        private final int excluded;

        /**
         * @param i the single index for which {@link #get(int)} reports {@code false}
         */
        public ExcludingBits(int i) {
            excluded = i;
        }

        /**
         * @param i index to test
         * @return {@code false} only for the excluded index, {@code true} for every other value
         */
        @Override // io.github.jbellis.jvector.util.Bits
        public boolean get(int i) {
            boolean isExcluded = (excluded == i);
            return !isExcluded;
        }

        /**
         * Not supported by this view.
         *
         * @throws UnsupportedOperationException always
         */
        @Override // io.github.jbellis.jvector.util.Bits
        public int length() {
            throw new UnsupportedOperationException();
        }
    }

    /**
     * Creates a builder over the given vectors with an initially empty graph.
     *
     * @param randomAccessVectorValues source vectors; each thread lazily obtains two private
     *                                 copies of it through {@code copy2()}
     * @param vectorEncoding           element encoding of the vectors
     * @param vectorSimilarityFunction similarity function used for all scoring
     * @param i                        the "maxConn" value handed to the graph; must be positive
     * @param i2                       beam width for insertion-time searches; must be positive
     * @param f                        neighbor overflow factor later passed to {@code backlink}
     * @param f2                       forwarded unchanged as the last constructor argument of every
     *                                 {@code ConcurrentNeighborSet} the graph creates
     * @throws NullPointerException     if the vectors, encoding or similarity function is null
     * @throws IllegalArgumentException if {@code i} or {@code i2} is not positive
     */
    public GraphIndexBuilder(RandomAccessVectorValues<T> randomAccessVectorValues, VectorEncoding vectorEncoding, VectorSimilarityFunction vectorSimilarityFunction, int i, int i2, float f, float f2) {
        Objects.requireNonNull(randomAccessVectorValues);
        this.vectors = ThreadLocal.withInitial(randomAccessVectorValues::copy2);
        this.vectorsCopy = ThreadLocal.withInitial(randomAccessVectorValues::copy2);
        this.vectorEncoding = Objects.requireNonNull(vectorEncoding);
        this.similarityFunction = Objects.requireNonNull(vectorSimilarityFunction);
        this.neighborOverflow = f;
        if (i <= 0) {
            throw new IllegalArgumentException("maxConn must be positive");
        }
        if (i2 <= 0) {
            throw new IllegalArgumentException("beamWidth must be positive");
        }
        this.beamWidth = i2;

        // Two-stage scorer: fix one node's vector (read from the first copy), then score it
        // against other nodes whose vectors are read from the second copy. The two lambda
        // parameters need distinct names; reusing one name is illegal and hides which index
        // each lookup uses.
        NeighborSimilarity neighborSimilarity = node -> {
            T nodeVector = this.vectors.get().vectorValue(node);
            return other -> scoreBetween(nodeVector, this.vectorsCopy.get().vectorValue(other));
        };
        this.graph = new OnHeapGraphIndex<>(i, (num, num2) ->
                new ConcurrentNeighborSet(num.intValue(), num2.intValue(), neighborSimilarity, f2));
        this.graphSearcher = ThreadLocal.withInitial(() ->
                new GraphSearcher.Builder(this.graph.getView()).withConcurrentUpdates().build());

        // Both scratch arrays are sized to hold either a full beam or maxConn + 1 entries.
        final int scratchCapacity = Math.max(i2, i + 1);
        this.naturalScratch = ThreadLocal.withInitial(() -> new NeighborArray(scratchCapacity));
        this.concurrentScratch = ThreadLocal.withInitial(() -> new NeighborArray(scratchCapacity));
    }

    /**
     * Inserts every vector ordinal in {@code [0, size)} into the graph using a parallel
     * stream, then finalizes the graph through {@link #complete()}.
     *
     * @return the completed graph
     */
    public OnHeapGraphIndex<T> build() {
        int nodeCount = this.vectors.get().size();
        // Each worker thread reads from its own thread-local copy of the vectors; no raw-type
        // cast is needed because the thread-local already carries the element type.
        IntStream.range(0, nodeCount).parallel().forEach(node -> addGraphNode(node, this.vectors.get()));
        complete();
        return this.graph;
    }

    /**
     * Finalizes the graph: validates the entry node, runs {@code cleanup()} on every node's
     * neighbor set in parallel, moves the entry node to the result of
     * {@code approximateMedioid()}, and validates the entry node again.
     */
    public void complete() {
        graph.validateEntryNode();

        int nodeCount = graph.size();
        IntStream.range(0, nodeCount).parallel().forEach(node -> graph.getNeighbors(node).cleanup());

        int newEntryNode = approximateMedioid();
        graph.updateEntryNode(newEntryNode);
        graph.validateEntryNode();
    }

    /**
     * Looks up the vector for ordinal {@code i} in the supplied values and inserts it as a
     * graph node.
     *
     * @param i                        ordinal of the node to insert
     * @param randomAccessVectorValues source from which the node's vector is read
     * @return the value returned by {@link #addGraphNode(int, Object)} for this node
     */
    public long addGraphNode(int i, RandomAccessVectorValues<T> randomAccessVectorValues) {
        // Pass the vector as a T so the (int, T) overload is chosen; a cast to int would
        // match neither overload and could not succeed for byte[]/float[] vectors.
        T vector = randomAccessVectorValues.vectorValue(i);
        return addGraphNode(i, vector);
    }

    /**
     * @return the graph this builder writes to (the same instance for the builder's lifetime)
     */
    public OnHeapGraphIndex<T> getGraph() {
        return graph;
    }

    /**
     * @return the current size of the set of node ids whose insertion has not finished
     */
    public int insertsInProgress() {
        int pending = insertionsInProgress.size();
        return pending;
    }

    /**
     * Inserts node {@code i} with vector {@code t} into the graph.
     *
     * <p>The node is added to the graph and recorded as in progress, a snapshot of the
     * in-progress set is taken, the graph is searched from its current entry point (with
     * {@code i} itself excluded), and the node's neighbors are updated from both the search
     * hits and the snapshot. The node is always removed from the in-progress set when the
     * attempt ends, whether it succeeded or threw.
     *
     * @param i ordinal of the node to insert
     * @param t the node's vector
     * @return the result of {@code graph.ramBytesUsedOneNode(0)}
     */
    public long addGraphNode(int i, T t) {
        this.graph.addNode(i);

        final Integer nodeKey = Integer.valueOf(i);
        this.insertionsInProgress.add(nodeKey);
        ConcurrentSkipListSet<Integer> inProgressSnapshot = this.insertionsInProgress.clone();
        try {
            int entryPoint = this.graph.entry();
            SearchResult.NodeScore[] searchHits = this.graphSearcher.get()
                    .searchInternal(
                            candidate -> scoreBetween(this.vectorsCopy.get().vectorValue(candidate), t),
                            null,
                            this.beamWidth,
                            entryPoint,
                            new ExcludingBits(i))
                    .getNodes();

            NeighborArray natural = getNaturalCandidates(searchHits);
            NeighborArray concurrent = getConcurrentCandidates(i, inProgressSnapshot);
            updateNeighbors(i, natural, concurrent);
            this.graph.markComplete(i);
        } finally {
            // Single cleanup point for both the success and the failure path.
            this.insertionsInProgress.remove(nodeKey);
        }
        return this.graph.ramBytesUsedOneNode(0);
    }

    /**
     * Walks from the current entry node to a locally best node.
     *
     * <p>At each step the current node and each of its current neighbors are scored with
     * {@link #scaledNeighborScore}; the walk moves to the candidate with the smallest score
     * and stops when the current node itself has the smallest score. Ties keep the earliest
     * candidate, with the current node considered first.
     *
     * @return the node at which the walk stopped
     */
    private int approximateMedioid() {
        RandomAccessVectorValues<T> nodeVectors = this.vectors.get();
        RandomAccessVectorValues<T> neighborVectors = this.vectorsCopy.get();
        int current = this.graph.getView().entryNode();
        while (true) {
            ConcurrentNeighborSet.ConcurrentNeighborArray currentNeighbors = this.graph.getNeighbors(current).getCurrent();
            int[] candidateIds = currentNeighbors.node();
            int candidateCount = currentNeighbors.size;

            int best = current;
            double bestScore = scaledNeighborScore(current, nodeVectors, neighborVectors);
            for (int k = 0; k < candidateCount; k++) {
                int candidate = candidateIds[k];
                double candidateScore = scaledNeighborScore(candidate, nodeVectors, neighborVectors);
                // Double.compare keeps the same ordering (NaN, signed zero) as comparingDouble.
                if (Double.compare(candidateScore, bestScore) < 0) {
                    best = candidate;
                    bestScore = candidateScore;
                }
            }
            if (best == current) {
                return best;
            }
            current = best;
        }
    }

    /**
     * Sums the scores between {@code node} and each of its current neighbors, then divides
     * the sum by {@code neighborVectors.size()}.
     *
     * <p>The node's own vector is read from {@code nodeVectors} and each neighbor's vector
     * from {@code neighborVectors}, so both can be held at the same time.
     * NOTE(review): the divisor is the size of the vector source, not the neighbor count --
     * confirm that is intended.
     */
    private double scaledNeighborScore(int node, RandomAccessVectorValues<T> nodeVectors, RandomAccessVectorValues<T> neighborVectors) {
        ConcurrentNeighborSet.ConcurrentNeighborArray nodeNeighbors = this.graph.getNeighbors(node).getCurrent();
        double total = Arrays.stream(nodeNeighbors.node(), 0, nodeNeighbors.size)
                .mapToDouble(neighbor -> scoreBetween(nodeVectors.vectorValue(node), neighborVectors.vectorValue(neighbor)))
                .sum();
        return total / neighborVectors.size();
    }

    /**
     * Installs neighbors for {@code node} and links back from them.
     *
     * @param node       the node whose neighbor set is updated
     * @param natural    candidates produced by the graph search
     * @param concurrent candidates taken from the nodes that were being inserted concurrently
     */
    private void updateNeighbors(int node, NeighborArray natural, NeighborArray concurrent) {
        ConcurrentNeighborSet neighbors = this.graph.getNeighbors(node);
        neighbors.insertDiverse(natural, concurrent);
        // backlink resolves each neighbor's own set through the graph; the bound method
        // reference also null-checks the graph when it is created.
        neighbors.backlink(this.graph::getNeighbors, this.neighborOverflow);
    }

    /**
     * Copies search hits, in array order, into this thread's "natural" scratch array.
     *
     * @param searchHits node/score pairs returned by the graph search
     * @return the thread-local scratch array, cleared and refilled by this call
     */
    private NeighborArray getNaturalCandidates(SearchResult.NodeScore[] searchHits) {
        NeighborArray scratch = naturalScratch.get();
        scratch.clear();
        for (int k = 0; k < searchHits.length; k++) {
            SearchResult.NodeScore hit = searchHits[k];
            scratch.addInOrder(hit.node, hit.score);
        }
        return scratch;
    }

    /**
     * Scores {@code newNode} against every other node in {@code inProgress} and collects the
     * results, via {@code insertSorted}, in this thread's "concurrent" scratch array.
     *
     * @param newNode    the node being inserted; skipped if present in the set
     * @param inProgress ids of nodes that were being inserted when the snapshot was taken
     * @return the thread-local scratch array, cleared and refilled by this call
     */
    private NeighborArray getConcurrentCandidates(int newNode, Set<Integer> inProgress) {
        NeighborArray scratch = concurrentScratch.get();
        scratch.clear();
        for (Integer boxedOther : inProgress) {
            int other = boxedOther.intValue();
            if (other == newNode) {
                continue;
            }
            // The two vectors come from different copies so both can be held at once.
            T newNodeVector = vectors.get().vectorValue(newNode);
            T otherVector = vectorsCopy.get().vectorValue(other);
            scratch.insertSorted(other, scoreBetween(newNodeVector, otherVector));
        }
        return scratch;
    }

    /**
     * Scores two vectors using this builder's configured encoding and similarity function.
     *
     * @param t  first vector
     * @param t2 second vector
     * @return the similarity score computed by the static {@code scoreBetween} overload
     */
    protected float scoreBetween(T t, T t2) {
        return GraphIndexBuilder.scoreBetween(vectorEncoding, similarityFunction, t, t2);
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Scores two vectors by casting them to the array type implied by the encoding and
     * delegating to the similarity function.
     *
     * @param vectorEncoding           BYTE treats both vectors as {@code byte[]}, FLOAT32 as {@code float[]}
     * @param vectorSimilarityFunction function that performs the comparison
     * @param t                        first vector
     * @param t2                       second vector
     * @return the score returned by {@code vectorSimilarityFunction.compare}
     * @throws IllegalArgumentException if the encoding is neither BYTE nor FLOAT32
     * @throws ClassCastException       if a vector is not of the array type implied by the encoding
     */
    static <T> float scoreBetween(VectorEncoding vectorEncoding, VectorSimilarityFunction vectorSimilarityFunction, T t, T t2) {
        // Switch on the enum itself rather than through a hand-maintained ordinal table and
        // an unrelated numeric constant used as a case label.
        switch (vectorEncoding) {
            case BYTE:
                return vectorSimilarityFunction.compare((byte[]) t, (byte[]) t2);
            case FLOAT32:
                return vectorSimilarityFunction.compare((float[]) t, (float[]) t2);
            default:
                throw new IllegalArgumentException("Unsupported vector encoding: " + vectorEncoding);
        }
    }
}
