package io.github.jbellis.jvector.graph;

import io.github.jbellis.jvector.annotations.VisibleForTesting;
import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.ConcurrentNeighborMap;
import io.github.jbellis.jvector.graph.OnHeapGraphIndex;
import io.github.jbellis.jvector.graph.SearchResult;
import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider;
import io.github.jbellis.jvector.util.AtomicFixedBitSet;
import io.github.jbellis.jvector.util.Bits;
import io.github.jbellis.jvector.util.ExceptionUtils;
import io.github.jbellis.jvector.util.ExplicitThreadLocal;
import io.github.jbellis.jvector.util.PhysicalCoreExecutor;
import io.github.jbellis.jvector.util.ThreadSafeGrowableBitSet;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorUtil;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Objects;
import java.util.OptionalDouble;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.agrona.collections.IntArrayList;
import org.agrona.collections.IntArrayQueue;

/* loaded from: input_file:io/github/jbellis/jvector/graph/GraphIndexBuilder.class */
public class GraphIndexBuilder implements Closeable {
    // Beam width handed (as both width arguments) to every GraphSearcher.searchInternal call made by this builder.
    private final int beamWidth;
    // Reusable scratch NodeArray, obtained through ExplicitThreadLocal.get(), that receives graph-search results.
    private final ExplicitThreadLocal<NodeArray> naturalScratch;
    // Reusable scratch NodeArray that receives candidates taken from the set of in-progress insertions.
    private final ExplicitThreadLocal<NodeArray> concurrentScratch;
    // Vector dimension; within this class it is only consulted by maybeImproveOlderNode().
    private final int dimension;
    // Multiplied by the max connection count to size the overflow degree, and passed to backlink() on insertion.
    private final float neighborOverflow;
    // Forwarded to the OnHeapGraphIndex constructor; not otherwise read in this class.
    private final float alpha;

    // The graph being built.
    @VisibleForTesting
    final OnHeapGraphIndex graph;
    // Average of the finite values returned by enforceDegree during cleanup(); NaN until cleanup() computes it.
    private double averageShortEdges;
    // Ordinals whose addGraphNode call is currently executing.
    private final ConcurrentSkipListSet<Integer> insertionsInProgress;
    // Supplies the score providers used for every search and similarity computation in this class.
    private final BuildScoreProvider scoreProvider;
    // Executor used for node insertion, orphan reconnection and neighbor replacement tasks.
    private final ForkJoinPool simdExecutor;
    // Executor used for connectivity traversal, degree enforcement and deleted-neighbor discovery tasks.
    private final ForkJoinPool parallelExecutor;
    // Source of GraphSearcher instances created over this builder's graph.
    private final ExplicitThreadLocal<GraphSearcher> searchers;
    // Countdown decremented once per insertion; when it reaches zero the entry point is recomputed.
    private final AtomicInteger updateEntryNodeIn;
    // Assertion switch: true when assertions are disabled for this class (assigned in the static initializer).
    static final /* synthetic */ boolean $assertionsDisabled;

    /**
     * A {@link Bits} view that accepts every index except one excluded value.
     */
    public static class ExcludingBits implements Bits {
        private final int excluded;

        /**
         * @param i the single index for which {@link #get(int)} answers {@code false}
         */
        public ExcludingBits(int i) {
            excluded = i;
        }

        /**
         * @param i the index to test
         * @return {@code false} only when {@code i} equals the excluded index
         */
        @Override
        public boolean get(int i) {
            return excluded != i;
        }
    }

    /**
     * Creates a builder that scores nodes through
     * {@link BuildScoreProvider#randomAccessScoreProvider} over the supplied vectors.
     *
     * @param randomAccessVectorValues source of the vectors; its {@code dimension()} is forwarded as the dimension
     * @param vectorSimilarityFunction similarity function handed to the score provider
     * @param i                        maximum connections per node; must be positive
     * @param i2                       beam width used for searches during construction; must be positive
     * @param f                        neighbor overflow factor
     * @param f2                       alpha value forwarded to the graph index
     */
    public GraphIndexBuilder(RandomAccessVectorValues randomAccessVectorValues, VectorSimilarityFunction vectorSimilarityFunction, int i, int i2, float f, float f2) {
        this(
                BuildScoreProvider.randomAccessScoreProvider(randomAccessVectorValues, vectorSimilarityFunction),
                randomAccessVectorValues.dimension(),
                i,
                i2,
                f,
                f2);
    }

    /**
     * Creates a builder that runs on {@link PhysicalCoreExecutor#pool()} and
     * {@link ForkJoinPool#commonPool()}.
     *
     * @param buildScoreProvider provider of the scoring functions used while building
     * @param i                  vector dimension
     * @param i2                 maximum connections per node; must be positive
     * @param i3                 beam width used for searches during construction; must be positive
     * @param f                  neighbor overflow factor
     * @param f2                 alpha value forwarded to the graph index
     */
    public GraphIndexBuilder(BuildScoreProvider buildScoreProvider, int i, int i2, int i3, float f, float f2) {
        this(
                buildScoreProvider,
                i,
                i2,
                i3,
                f,
                f2,
                PhysicalCoreExecutor.pool(),
                ForkJoinPool.commonPool());
    }

    public GraphIndexBuilder(BuildScoreProvider buildScoreProvider, int i, int i2, int i3, float f, float f2, ForkJoinPool forkJoinPool, ForkJoinPool forkJoinPool2) {
        this.averageShortEdges = Double.NaN;
        this.insertionsInProgress = new ConcurrentSkipListSet<>();
        this.updateEntryNodeIn = new AtomicInteger(10000);
        this.scoreProvider = buildScoreProvider;
        this.dimension = i;
        this.neighborOverflow = f;
        this.alpha = f2;
        if (i2 <= 0) {
            throw new IllegalArgumentException("maxConn must be positive");
        }
        if (i3 <= 0) {
            throw new IllegalArgumentException("beamWidth must be positive");
        }
        this.beamWidth = i3;
        this.simdExecutor = forkJoinPool;
        this.parallelExecutor = forkJoinPool2;
        this.graph = new OnHeapGraphIndex(i2, (int) (i2 * f), buildScoreProvider, f2);
        this.searchers = ExplicitThreadLocal.withInitial(() -> {
            return new GraphSearcher(this.graph);
        });
        this.naturalScratch = ExplicitThreadLocal.withInitial(() -> {
            return new NodeArray(Math.max(i3, i2 + 1));
        });
        this.concurrentScratch = ExplicitThreadLocal.withInitial(() -> {
            return new NodeArray(Math.max(i3, i2 + 1));
        });
    }

    /**
     * Inserts every vector of {@code randomAccessVectorValues} (ordinals {@code 0..size-1}) in
     * parallel on the simd executor, then runs {@link #cleanup()}.
     *
     * @param randomAccessVectorValues the vectors to index
     * @return the graph held by this builder
     */
    public OnHeapGraphIndex build(RandomAccessVectorValues randomAccessVectorValues) {
        final Supplier<RandomAccessVectorValues> perThreadVectors = randomAccessVectorValues.threadLocalSupplier();
        final int nodeCount = randomAccessVectorValues.size();

        Runnable insertAll = () -> IntStream.range(0, nodeCount)
                .parallel()
                .forEach(node -> addGraphNode(node, perThreadVectors.get().getVector(node)));
        this.simdExecutor.submit(insertAll).join();

        cleanup();
        return this.graph;
    }

    /**
     * Finalizes the graph: validates the entry node, removes deleted nodes, enforces the degree
     * limit on every node (recording the average of the finite values returned), recomputes the
     * entry point and reconnects orphaned nodes. Does nothing further once the graph is empty.
     */
    public void cleanup() {
        if (this.graph.size() != 0) {
            this.graph.validateEntryNode();
            removeDeletedNodes();

            // Removing deleted nodes may have emptied the graph; the remaining steps need an entry node.
            if (this.graph.size() != 0) {
                OptionalDouble meanShortEdges = this.parallelExecutor.submit(() -> {
                    return IntStream.range(0, this.graph.getIdUpperBound())
                            .parallel()
                            .mapToDouble(this.graph.nodes::enforceDegree)
                            .filter(Double::isFinite)
                            .average();
                }).join();
                this.averageShortEdges = meanShortEdges.orElse(Double.NaN);

                updateEntryPoint();
                reconnectOrphanedNodes();
            }
        }
    }

    /**
     * Best-effort reconnection of nodes that cannot be reached from the entry node.
     *
     * <p>Runs at most three passes. Each pass marks everything reachable from the entry node, then
     * for every unreachable node tries to add an edge to it from one of its own neighbors, or
     * failing that from a node found by searching the graph. A node is used as the source of a
     * reconnecting edge at most once across all passes. The method returns early once a pass
     * finds no unreachable node.
     *
     * @throws RuntimeException wrapping any exception raised while searching or closing the searcher
     */
    private void reconnectOrphanedNodes() {
        // Candidate lists found by earlier searches, keyed by orphaned node and reused across passes.
        ConcurrentHashMap<Integer, NodeArray> searchPathNeighbors = new ConcurrentHashMap<>();
        // Nodes already chosen as the source of a reconnecting edge (seeded with the entry node).
        Set<Integer> connectionTargets = ConcurrentHashMap.newKeySet();
        connectionTargets.add(Integer.valueOf(this.graph.entry()));

        for (int pass = 0; pass < 3; pass++) {
            // Mark everything reachable from the entry node.
            AtomicFixedBitSet connectedNodes = new AtomicFixedBitSet(this.graph.getIdUpperBound());
            connectedNodes.set(this.graph.entry());
            ConcurrentNeighborMap.Neighbors entryNeighbors = this.graph.getNeighbors(this.graph.entry());
            this.parallelExecutor.submit(() -> {
                IntStream.range(0, entryNeighbors.size()).parallel().forEach(idx -> {
                    findConnected(connectedNodes, entryNeighbors.getNode(idx));
                });
            }).join();

            AtomicInteger orphansSeen = new AtomicInteger();
            this.simdExecutor.submit(() -> {
                IntStream.range(0, this.graph.getIdUpperBound()).parallel().forEach(node -> {
                    if (connectedNodes.get(node) || !this.graph.containsNode(node)) {
                        return;
                    }
                    orphansSeen.incrementAndGet();

                    // First choice: link in from one of the orphan's own neighbors.
                    if (connectToClosestNeighbor(node, this.graph.getNeighbors(node), connectionTargets)) {
                        return;
                    }

                    // Otherwise fall back to searched candidates, re-searching when the cached
                    // list is missing or already exhausted.
                    NodeArray candidates = searchPathNeighbors.get(Integer.valueOf(node));
                    if (candidates == null || isSubset(candidates, connectionTargets)) {
                        try {
                            SearchResult result;
                            // try-with-resources closes the searcher even when the search throws.
                            try (GraphSearcher searcher = this.searchers.get()) {
                                Bits acceptable = createExcludeBits(node, connectionTargets);
                                result = searcher.searchInternal(this.scoreProvider.searchProviderFor(node), this.beamWidth, this.beamWidth, 0.0f, 0.0f, this.graph.entry(), acceptable);
                            }
                            candidates = new NodeArray(result.getNodes().length);
                            toScratchCandidates(result.getNodes(), candidates);
                            searchPathNeighbors.put(Integer.valueOf(node), candidates);
                        } catch (Exception e) {
                            throw new RuntimeException(e);
                        }
                    }
                    connectToClosestNeighbor(node, candidates, connectionTargets);
                });
            }).join();

            if (orphansSeen.get() == 0) {
                return;
            }
        }
    }

    /**
     * Tells whether every node stored in {@code nodeArray} is contained in {@code set}.
     *
     * @param nodeArray the nodes to check
     * @param set       the candidate superset
     * @return {@code true} when no node of {@code nodeArray} is missing from {@code set}
     *         (trivially {@code true} for an empty array)
     */
    private boolean isSubset(NodeArray nodeArray, Set<Integer> set) {
        int idx = 0;
        while (idx < nodeArray.size()) {
            Integer node = Integer.valueOf(nodeArray.getNode(idx));
            if (!set.contains(node)) {
                return false;
            }
            idx++;
        }
        return true;
    }

    /**
     * Walks {@code nodeArray} in order and, for the first entry that is not yet in {@code set},
     * records it in {@code set} and calls {@code insertNotDiverse(entry, i, score)} on the
     * graph's neighbor map.
     *
     * @param i         the node passed as the second argument of {@code insertNotDiverse}
     * @param nodeArray candidate nodes together with their scores
     * @param set       nodes already used; updated when a candidate is chosen
     * @return {@code true} if a candidate was chosen, {@code false} if every candidate was already in {@code set}
     */
    private boolean connectToClosestNeighbor(int i, NodeArray nodeArray, Set<Integer> set) {
        int idx = 0;
        while (idx < nodeArray.size()) {
            final int candidate = nodeArray.getNode(idx);
            final float candidateScore = nodeArray.getScore(idx);
            // Set.add doubles as the "not used before" test and the reservation.
            boolean firstUse = set.add(Integer.valueOf(candidate));
            if (firstUse) {
                this.graph.nodes.insertNotDiverse(candidate, i, candidateScore);
                return true;
            }
            idx++;
        }
        return false;
    }

    /**
     * Marks in {@code atomicFixedBitSet} every node reachable from {@code i} by following
     * neighbor edges (breadth-first, using a queue of pending nodes).
     *
     * <p>A node whose bit was already set when polled is skipped, so traversals that share a bit
     * set do not expand the same node twice.
     *
     * @param atomicFixedBitSet visited set, updated in place
     * @param i                 the node to start from
     * @throws RuntimeException wrapping any exception raised while obtaining, using or closing the graph view
     */
    private void findConnected(AtomicFixedBitSet atomicFixedBitSet, int i) {
        IntArrayQueue pending = new IntArrayQueue();
        pending.addInt(i);
        // try-with-resources releases the view even when traversal throws.
        try (OnHeapGraphIndex.ConcurrentGraphIndexView view = this.graph.getView()) {
            while (!pending.isEmpty()) {
                int node = pending.pollInt();
                if (atomicFixedBitSet.getAndSet(node)) {
                    continue;
                }
                NodesIterator neighbors = view.getNeighborsIterator(node);
                while (neighbors.hasNext()) {
                    int neighbor = neighbors.nextInt();
                    // Already-visited neighbors would be discarded on poll anyway; skipping them
                    // here keeps the queue small.
                    if (!atomicFixedBitSet.get(neighbor)) {
                        pending.addInt(neighbor);
                    }
                }
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * @return the graph held by this builder (the same instance on every call)
     */
    public OnHeapGraphIndex getGraph() {
        return graph;
    }

    /**
     * @return the current size of the in-progress insertion set
     */
    public int insertsInProgress() {
        final int inFlight = insertionsInProgress.size();
        return inFlight;
    }

    /**
     * Looks up the vector for {@code i} in {@code randomAccessVectorValues} and inserts it.
     *
     * @param i                        ordinal of the node to add
     * @param randomAccessVectorValues source from which the node's vector is read
     * @return the value returned by {@link #addGraphNode(int, VectorFloat)}
     * @deprecated use {@link #addGraphNode(int, VectorFloat)}, to which this method delegates
     */
    @Deprecated
    public long addGraphNode(int i, RandomAccessVectorValues randomAccessVectorValues) {
        VectorFloat<?> vector = randomAccessVectorValues.getVector(i);
        return addGraphNode(i, vector);
    }

    /**
     * Adds node {@code i} to the graph and wires up its neighbors.
     *
     * <p>The node is registered in the graph and in the in-progress set, then a graph search
     * (excluding the node itself) supplies "natural" candidates, while the snapshot of the
     * in-progress set supplies "concurrent" candidates. Both feed {@code updateNeighbors}.
     * Afterwards the entry point and an older node may be refreshed. The node always leaves
     * the in-progress set, even on failure.
     *
     * @param i           ordinal of the node to add
     * @param vectorFloat the node's vector
     * @return {@code graph.ramBytesUsedOneNode()}
     * @throws RuntimeException wrapping any exception raised during the insertion
     */
    public long addGraphNode(int i, VectorFloat<?> vectorFloat) {
        this.graph.addNode(i);
        final Integer boxedNode = Integer.valueOf(i);
        this.insertionsInProgress.add(boxedNode);
        // Snapshot taken right after registering ourselves.
        ConcurrentSkipListSet<Integer> inProgressSnapshot = this.insertionsInProgress.clone();
        try {
            try (GraphSearcher searcher = this.searchers.get()) {
                NodeArray naturalPooled = this.naturalScratch.get();
                NodeArray concurrentPooled = this.concurrentScratch.get();
                int entryPoint = this.graph.entry();
                ExcludingBits notSelf = new ExcludingBits(i);
                SearchScoreProvider ssp = this.scoreProvider.searchProviderFor(vectorFloat);

                SearchResult found = searcher.searchInternal(ssp, this.beamWidth, this.beamWidth, 0.0f, 0.0f, entryPoint, notSelf);
                NodeArray natural = toScratchCandidates(found.getNodes(), naturalPooled);
                NodeArray concurrent = getConcurrentCandidates(i, inProgressSnapshot, concurrentPooled, ssp.scoreFunction());
                updateNeighbors(i, natural, concurrent);

                maybeUpdateEntryPoint(i);
                maybeImproveOlderNode();
            }
            return this.graph.ramBytesUsedOneNode();
        } catch (Exception e) {
            throw new RuntimeException(e);
        } finally {
            this.insertionsInProgress.remove(boxedNode);
        }
    }

    /**
     * When the dimension is at most 3 and the graph holds more than 20000 nodes, picks a random
     * ordinal below {@code graph.size()} (up to three tries) and re-runs
     * {@link #improveConnections(int)} on the first one that exists and is not marked deleted.
     */
    private void maybeImproveOlderNode() {
        boolean eligible = this.dimension <= 3 && this.graph.size() > 20000;
        if (!eligible) {
            return;
        }
        int attempt = 0;
        while (attempt < 3) {
            int candidate = ThreadLocalRandom.current().nextInt(this.graph.size());
            boolean live = this.graph.containsNode(candidate) && !this.graph.getDeletedNodes().get(candidate);
            if (live) {
                improveConnections(candidate);
                return;
            }
            attempt++;
        }
    }

    /**
     * Offers {@code i} to the graph as its initial entry node and decrements the entry-point
     * countdown, recomputing the entry point when the countdown reaches exactly zero.
     *
     * @param i the node that was just inserted
     */
    private void maybeUpdateEntryPoint(int i) {
        this.graph.maybeSetInitialEntryNode(i);
        int remaining = this.updateEntryNodeIn.decrementAndGet();
        if (remaining != 0) {
            return;
        }
        updateEntryPoint();
    }

    /**
     * Sets the graph's entry node to the approximate medoid and re-arms the entry-point
     * countdown: by {@code graph.size()} (after improving the new entry node's connections) when
     * a medoid was found, or by 10000 when the medoid lookup returned a negative value.
     */
    private void updateEntryPoint() {
        final int medoid = approximateMedioid();
        this.graph.updateEntryNode(medoid);
        if (medoid >= 0) {
            improveConnections(medoid);
            this.updateEntryNodeIn.addAndGet(this.graph.size());
            return;
        }
        this.updateEntryNodeIn.addAndGet(10000);
    }

    /**
     * Re-searches the graph for node {@code i} (excluding {@code i} itself), inserts the results
     * as its diverse neighbors and backlinks them with an overflow factor of {@code 1.0f}.
     *
     * @param i the node whose connections are rebuilt
     * @throws RuntimeException wrapping any exception raised while searching, updating
     *                          neighbors or closing the searcher
     */
    private void improveConnections(int i) {
        try {
            NodeArray scratch;
            SearchResult result;
            // try-with-resources closes the searcher even when the search throws.
            try (GraphSearcher searcher = this.searchers.get()) {
                scratch = this.naturalScratch.get();
                result = searcher.searchInternal(this.scoreProvider.searchProviderFor(i), this.beamWidth, this.beamWidth, 0.0f, 0.0f, this.graph.entry(), new ExcludingBits(i));
            }
            NodeArray natural = toScratchCandidates(result.getNodes(), scratch);
            this.graph.nodes.backlink(this.graph.nodes.insertDiverse(i, natural), i, 1.0f);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Marks node {@code i} as deleted in the graph by delegating to {@code graph.markDeleted}.
     *
     * @param i ordinal of the node to mark
     */
    public void markNodeDeleted(int i) {
        final int node = i;
        graph.markDeleted(node);
    }

    /**
     * Removes every node that is marked deleted, after rewiring the live nodes that pointed at one.
     *
     * <p>Works from a copy of the deleted-node set taken on entry. For each live node that has a
     * deleted neighbor, the live neighbors of those deleted neighbors become replacement
     * candidates. If none exist, random live nodes are used instead. The entry point is
     * recomputed when the current entry node is among the deleted ones.
     *
     * @return {@code 0} when nothing is marked deleted; otherwise the number of deleted nodes
     *         multiplied by {@code graph.ramBytesUsedOneNode()}
     */
    public synchronized long removeDeletedNodes() {
        // Snapshot of the deleted set; all decisions below are made against this copy.
        ThreadSafeGrowableBitSet copy = this.graph.getDeletedNodes().copy();
        int cardinality = copy.cardinality();
        if (cardinality == 0) {
            return 0L;
        }
        // Ordinals of nodes that exist and are not deleted; used as the random fallback pool below.
        IntArrayList intArrayList = new IntArrayList();
        for (int i = 0; i < this.graph.getIdUpperBound(); i++) {
            if (this.graph.containsNode(i) && !copy.get(i)) {
                intArrayList.add(Integer.valueOf(i));
            }
        }
        // Maps each live node that has a deleted neighbor to its set of replacement candidates.
        ConcurrentHashMap concurrentHashMap = new ConcurrentHashMap();
        this.parallelExecutor.submit(() -> {
            IntStream.range(0, this.graph.getIdUpperBound()).parallel().forEach(i2 -> {
                ConcurrentNeighborMap.Neighbors neighbors = this.graph.getNeighbors(i2);
                // Skip ordinals with no neighbor list and nodes that are themselves deleted.
                if (neighbors == null || copy.get(i2)) {
                    return;
                }
                NodesIterator it = neighbors.iterator();
                while (it.hasNext()) {
                    int nextInt = it.nextInt();
                    if (copy.get(nextInt)) {
                        // The entry is created even if no candidate ends up being added, so the
                        // node is still processed in the second phase.
                        Set set = (Set) concurrentHashMap.computeIfAbsent(Integer.valueOf(i2), num -> {
                            return ConcurrentHashMap.newKeySet();
                        });
                        // Candidates: live neighbors of the deleted neighbor, excluding i2 itself.
                        NodesIterator it2 = this.graph.getNeighbors(nextInt).iterator();
                        while (it2.hasNext()) {
                            int nextInt2 = it2.nextInt();
                            if (i2 != nextInt2 && !copy.get(nextInt2)) {
                                set.add(Integer.valueOf(nextInt2));
                            }
                        }
                    }
                }
            });
        }).join();
        // Second phase: score the candidates and replace the deleted neighbors of each affected node.
        this.simdExecutor.submit(() -> {
            ((Stream) concurrentHashMap.entrySet().stream().parallel()).forEach(entry -> {
                int intValue = ((Integer) entry.getKey()).intValue();
                ScoreFunction scoreFunction = this.scoreProvider.searchProviderFor(intValue).scoreFunction();
                NodeArray nodeArray = new NodeArray(this.graph.maxDegree);
                for (Integer num : (Set) entry.getValue()) {
                    nodeArray.insertSorted(num.intValue(), scoreFunction.similarityTo(num.intValue()));
                }
                // No candidate found: fall back to random live nodes, trying at most
                // 2 * maxDegree() draws and stopping once maxDegree distinct nodes are collected.
                if (nodeArray.size() == 0) {
                    ThreadLocalRandom current = ThreadLocalRandom.current();
                    for (int i2 = 0; i2 < 2 * this.graph.maxDegree(); i2++) {
                        int intValue2 = intArrayList.get(current.nextInt(intArrayList.size())).intValue();
                        if (intValue2 != intValue && !nodeArray.contains(intValue2)) {
                            nodeArray.insertSorted(intValue2, scoreFunction.similarityTo(intValue2));
                        }
                        if (nodeArray.size() == this.graph.maxDegree) {
                            break;
                        }
                    }
                }
                this.graph.nodes.replaceDeletedNeighbors(intValue, copy, nodeArray);
            });
        }).join();
        // The entry node is about to be removed: choose a new one first.
        if (copy.get(this.graph.entry())) {
            updateEntryPoint();
        }
        // Assertion (active only when assertions are enabled): the snapshot was not modified.
        if (!$assertionsDisabled && copy.cardinality() != cardinality) {
            throw new AssertionError("cardinality changed");
        }
        // Remove each node in the snapshot; the loop ends when nextSetBit returns Integer.MAX_VALUE.
        int nextSetBit = copy.nextSetBit(0);
        while (true) {
            int i2 = nextSetBit;
            if (i2 == Integer.MAX_VALUE) {
                return cardinality * this.graph.ramBytesUsedOneNode();
            }
            this.graph.removeNode(i2);
            nextSetBit = copy.nextSetBit(i2 + 1);
        }
    }

    /**
     * Builds a {@link Bits} filter that rejects {@code i} and every member of {@code set}.
     * The set is consulted on each call, so later additions to it are honored.
     *
     * @param i   a single index to reject
     * @param set further indexes to reject
     * @return a filter answering {@code true} for all other indexes
     */
    private static Bits createExcludeBits(int i, Set<Integer> set) {
        return candidate -> candidate != i && !set.contains(Integer.valueOf(candidate));
    }

    /**
     * Finds a node close to the approximate centroid reported by the score provider.
     *
     * @return {@code -1} when the graph is empty or the search yields no result; the value of
     *         {@link #randomLiveNode()} when the centroid's dot product with itself is below
     *         {@code 1e-6}; otherwise the first node returned by a graph search for the centroid
     * @throws RuntimeException wrapping any exception raised while searching or closing the searcher
     */
    private int approximateMedioid() {
        if (this.graph.size() == 0) {
            return -1;
        }
        VectorFloat<?> centroid = this.scoreProvider.approximateCentroid();
        // A centroid this close to the zero vector is not searched for; use a random live node.
        if (VectorUtil.dotProduct(centroid, centroid) < 1.0E-6d) {
            return randomLiveNode();
        }
        int entryPoint = this.graph.entry();
        SearchScoreProvider ssp = this.scoreProvider.searchProviderFor(centroid);
        // try-with-resources closes the searcher on every path, including when the search throws.
        try (GraphSearcher searcher = this.searchers.get()) {
            SearchResult result = searcher.searchInternal(ssp, this.beamWidth, this.beamWidth, 0.0f, 0.0f, entryPoint, Bits.ALL);
            if (result.getNodes().length == 0) {
                return -1;
            }
            return result.getNodes()[0].node;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Installs the neighbors of node {@code i} from the two candidate lists and backlinks them
     * using the configured neighbor overflow.
     *
     * <p>If one list is empty the other is used as is; otherwise the two are merged with
     * {@code NodeArray.merge}.
     *
     * @param i          the node being connected
     * @param nodeArray  first candidate list
     * @param nodeArray2 second candidate list
     */
    private void updateNeighbors(int i, NodeArray nodeArray, NodeArray nodeArray2) {
        final NodeArray candidates;
        if (nodeArray2.size() == 0) {
            candidates = nodeArray;
        } else if (nodeArray.size() == 0) {
            candidates = nodeArray2;
        } else {
            candidates = NodeArray.merge(nodeArray, nodeArray2);
        }
        this.graph.nodes.backlink(this.graph.nodes.insertDiverse(i, candidates), i, this.neighborOverflow);
    }

    /**
     * Clears {@code nodeArray} and appends every entry of {@code nodeScoreArr} to it, in array
     * order, via {@code addInOrder}.
     *
     * @param nodeScoreArr the search results to copy
     * @param nodeArray    the destination, overwritten by this call
     * @return {@code nodeArray}, for call chaining
     */
    private static NodeArray toScratchCandidates(SearchResult.NodeScore[] nodeScoreArr, NodeArray nodeArray) {
        nodeArray.clear();
        for (int idx = 0; idx < nodeScoreArr.length; idx++) {
            SearchResult.NodeScore hit = nodeScoreArr[idx];
            nodeArray.addInOrder(hit.node, hit.score);
        }
        return nodeArray;
    }

    /**
     * Clears {@code nodeArray} and fills it with every member of {@code set} other than
     * {@code i}, each scored with {@code scoreFunction} and inserted via {@code insertSorted}.
     *
     * @param i             the node to leave out
     * @param set           the nodes to consider
     * @param nodeArray     the destination, overwritten by this call
     * @param scoreFunction scores each considered node
     * @return {@code nodeArray}, for call chaining
     */
    private NodeArray getConcurrentCandidates(int i, Set<Integer> set, NodeArray nodeArray, ScoreFunction scoreFunction) {
        nodeArray.clear();
        for (Integer member : set) {
            final int other = member.intValue();
            if (other == i) {
                continue;
            }
            nodeArray.insertSorted(other, scoreFunction.similarityTo(other));
        }
        return nodeArray;
    }

    /**
     * Closes the searcher holder. Any exception it raises is handed to
     * {@code ExceptionUtils.throwIoException}.
     *
     * @throws IOException declared by {@link Closeable#close()}
     */
    @Override
    public void close() throws IOException {
        try {
            searchers.close();
        } catch (Exception closeFailure) {
            ExceptionUtils.throwIoException(closeFailure);
        }
    }

    /**
     * Picks a random node that exists in the graph and is not marked deleted.
     *
     * <p>Tries three random ordinals below the id upper bound first. If none is live, it
     * collects all live ordinals and picks one of those.
     *
     * @return a live node ordinal, or {@code -1} when the id upper bound is zero or no live node exists
     */
    @VisibleForTesting
    int randomLiveNode() {
        ThreadLocalRandom rng = ThreadLocalRandom.current();

        // Cheap path: a few random probes.
        int attempt = 0;
        while (attempt < 3) {
            int upperBound = this.graph.getIdUpperBound();
            if (upperBound == 0) {
                return -1;
            }
            int candidate = rng.nextInt(upperBound);
            if (this.graph.containsNode(candidate) && !this.graph.getDeletedNodes().get(candidate)) {
                return candidate;
            }
            attempt++;
        }

        // Fallback: enumerate every live node and choose among them.
        ArrayList<Integer> liveNodes = new ArrayList<>();
        for (int node = 0; node < this.graph.getIdUpperBound(); node++) {
            boolean live = this.graph.containsNode(node) && !this.graph.getDeletedNodes().get(node);
            if (live) {
                liveNodes.add(Integer.valueOf(node));
            }
        }
        if (liveNodes.isEmpty()) {
            return -1;
        }
        int pick = rng.nextInt(liveNodes.size());
        return liveNodes.get(pick).intValue();
    }

    /**
     * Consistency check, effective only when assertions are enabled for this class: no node may
     * be marked deleted, and every edge must point at a node that exists in the graph.
     *
     * @throws AssertionError if a check fails while assertions are enabled
     */
    @VisibleForTesting
    void validateAllNodesLive() {
        if (!$assertionsDisabled && this.graph.getDeletedNodes().cardinality() != 0) {
            throw new AssertionError();
        }
        for (int node = 0; node < this.graph.getIdUpperBound(); node++) {
            if (!this.graph.containsNode(node)) {
                continue;
            }
            NodesIterator edges = this.graph.getNeighbors(node).iterator();
            while (edges.hasNext()) {
                int target = edges.nextInt();
                if (!$assertionsDisabled && !this.graph.containsNode(target)) {
                    String message = String.format("Edge %d -> %d is invalid", Integer.valueOf(node), Integer.valueOf(target));
                    throw new AssertionError(message);
                }
            }
        }
    }

    /**
     * @return the average recorded by the last {@link #cleanup()} that reached degree
     *         enforcement, or {@code NaN} if none has been recorded
     */
    public double getAverageShortEdges() {
        return averageShortEdges;
    }

    /**
     * Populates the (empty) graph from {@code randomAccessReader}.
     *
     * <p>Reads three header ints (node count, entry node, and a third value that is discarded).
     * Then for each node it reads the ordinal, the neighbor count and that many neighbor
     * ordinals. Neighbor scores are recomputed with the exact score function of the node.
     *
     * @param randomAccessReader the source to read from
     * @throws IOException           if reading fails
     * @throws IllegalStateException if the graph already contains nodes
     */
    public void load(RandomAccessReader randomAccessReader) throws IOException {
        if (this.graph.size() != 0) {
            throw new IllegalStateException("Cannot load into a non-empty graph");
        }

        final int nodeCount = randomAccessReader.readInt();
        final int entryNode = randomAccessReader.readInt();
        // Third header value is consumed but not used here.
        randomAccessReader.readInt();

        for (int n = 0; n < nodeCount; n++) {
            final int node = randomAccessReader.readInt();
            final int neighborCount = randomAccessReader.readInt();
            ScoreFunction.ExactScoreFunction scorer = this.scoreProvider.searchProviderFor(node).exactScoreFunction();
            NodeArray neighbors = new NodeArray(neighborCount);
            for (int k = 0; k < neighborCount; k++) {
                final int neighbor = randomAccessReader.readInt();
                neighbors.addInOrder(neighbor, scorer.similarityTo(neighbor));
            }
            this.graph.addNode(node, neighbors);
        }

        this.graph.updateEntryNode(entryNode);
    }

    // Captures the class's assertion status once at class initialization; the flag is true when
    // assertions are disabled, and guards the explicit AssertionError checks in this class.
    static {
        $assertionsDisabled = !GraphIndexBuilder.class.desiredAssertionStatus();
    }
}
