package org.apache.lucene.util.hnsw;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.Semaphore;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.Supplier;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.GrowableBitSet;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.ThreadInterruptedException;
import org.apache.lucene.util.hnsw.ConcurrentOnHeapHnswGraph;
import org.apache.lucene.util.hnsw.NeighborSimilarity;

/**
 * Builds a {@link ConcurrentOnHeapHnswGraph} while allowing many threads to insert nodes at the
 * same time.
 *
 * <p>All per-thread state (vector readers, graph searchers, scratch arrays and queues) lives in
 * {@link ExplicitThreadLocal} holders, so one builder instance can be called from every worker
 * thread of an executor; {@link #buildAsync} does exactly that. Nodes whose insertion is still
 * running are tracked so that concurrently inserted nodes are offered to each other as neighbor
 * candidates.
 *
 * @param <T> the vector value type; {@code byte[]} for {@link VectorEncoding#BYTE} and {@code
 *     float[]} for {@link VectorEncoding#FLOAT32} (see {@link #scoreBetween(VectorEncoding,
 *     VectorSimilarityFunction, Object, Object)})
 */
public class ConcurrentHnswGraphBuilder<T> {
    public static final int DEFAULT_MAX_CONN = 16;
    public static final int DEFAULT_BEAM_WIDTH = 100;
    public static final String HNSW_COMPONENT = "HNSW";

    private final int beamWidth;
    private final float neighborOverflow;
    private final VectorEncoding vectorEncoding;
    private final VectorSimilarityFunction similarityFunction;

    /** The vector source given to the constructor; kept so {@link #buildAsync} can reject it. */
    private final RandomAccessVectorValues<T> sourceVectors;

    // Two independent per-thread readers over the same source, so both operands of a similarity
    // can be materialized at once (presumably because a reader may reuse its result buffer).
    private final ExplicitThreadLocal<RandomAccessVectorValues<T>> vectors;
    private final ExplicitThreadLocal<RandomAccessVectorValues<T>> vectorsCopy;

    private final ExplicitThreadLocal<HnswGraphSearcher<T>> graphSearcher;
    private final ExplicitThreadLocal<NeighborQueue> beamCandidates;
    private final ExplicitThreadLocal<NeighborArray> naturalScratch;
    private final ExplicitThreadLocal<NeighborArray> concurrentScratch;

    /** Picks the top level of each newly inserted node. */
    private final Supplier<Integer> levelSupplier;

    final ConcurrentOnHeapHnswGraph hnsw;

    /** Markers of nodes whose {@link #addGraphNode(int, Object)} call has started but not finished. */
    private final ConcurrentSkipListSet<ConcurrentOnHeapHnswGraph.NodeAtLevel> insertionsInProgress =
            new ConcurrentSkipListSet<>();

    private InfoStream infoStream = InfoStream.getDefault();

    /**
     * A per-thread value holder keyed on the calling thread's id.
     *
     * <p>Values are stored in a {@link ConcurrentHashMap} owned by this object rather than in the
     * threads themselves. Entries are never removed, so a value created for a thread is retained for
     * as long as this holder is reachable.
     */
    public static abstract class ExplicitThreadLocal<U> {
        private final ConcurrentHashMap<Long, U> map = new ConcurrentHashMap<>();
        private final Function<Long, U> initialSupplier = threadId -> initialValue();

        private ExplicitThreadLocal() {
        }

        /** Returns the calling thread's value, creating it with {@link #initialValue()} on first use. */
        public U get() {
            return map.computeIfAbsent(Thread.currentThread().getId(), initialSupplier);
        }

        /** Creates the value for a thread that has not called {@link #get()} before. */
        protected abstract U initialValue();

        /** Returns a holder whose per-thread values are produced by {@code supplier}. */
        public static <U> ExplicitThreadLocal<U> withInitial(final Supplier<U> supplier) {
            return new ExplicitThreadLocal<U>() {
                @Override
                protected U initialValue() {
                    return supplier.get();
                }
            };
        }
    }

    /**
     * Creates a builder.
     *
     * @param vectorValues source of every node's vector; only read through per-thread {@code copy()}s
     * @param vectorEncoding encoding of the vector values
     * @param similarityFunction similarity used to score pairs of vectors
     * @param maxConn passed to the graph and used to shape the level distribution; must be positive
     * @param beamWidth size of the candidate queue used while searching for neighbors; must be
     *     positive
     * @param neighborOverflow passed to {@code ConcurrentNeighborSet.backlink} when adding reverse
     *     edges
     * @param alpha passed to every {@code ConcurrentNeighborSet} (presumably a diversity factor —
     *     TODO confirm); a value above 1 also puts every node on level 0 only
     * @throws NullPointerException if {@code vectorEncoding} or {@code similarityFunction} is null
     * @throws IllegalArgumentException if {@code maxConn} or {@code beamWidth} is not positive
     */
    public ConcurrentHnswGraphBuilder(
            RandomAccessVectorValues<T> vectorValues,
            VectorEncoding vectorEncoding,
            VectorSimilarityFunction similarityFunction,
            int maxConn,
            int beamWidth,
            float neighborOverflow,
            float alpha) {
        this.sourceVectors = vectorValues;
        this.vectors = createThreadSafeVectors(vectorValues);
        this.vectorsCopy = createThreadSafeVectors(vectorValues);
        this.vectorEncoding = Objects.requireNonNull(vectorEncoding);
        this.similarityFunction = Objects.requireNonNull(similarityFunction);
        this.neighborOverflow = neighborOverflow;
        if (maxConn <= 0) {
            throw new IllegalArgumentException("maxConn must be positive");
        }
        if (beamWidth <= 0) {
            throw new IllegalArgumentException("beamWidth must be positive");
        }
        this.beamWidth = beamWidth;

        // Node-to-node similarity used by the neighbor sets: the first node is always read through
        // `vectors` and the second through `vectorsCopy`.
        NeighborSimilarity similarity = new NeighborSimilarity() {
            @Override
            public float score(int node1, int node2) {
                try {
                    return ConcurrentHnswGraphBuilder.this.scoreBetween(
                            ConcurrentHnswGraphBuilder.this.vectors.get().vectorValue(node1),
                            ConcurrentHnswGraphBuilder.this.vectorsCopy.get().vectorValue(node2));
                } catch (IOException e) {
                    throw new UncheckedIOException(e);
                }
            }

            @Override
            public NeighborSimilarity.ScoreFunction scoreProvider(int node1) {
                T value1;
                try {
                    value1 = ConcurrentHnswGraphBuilder.this.vectors.get().vectorValue(node1);
                } catch (IOException e) {
                    throw new UncheckedIOException(e);
                }
                // NOTE(review): value1 comes from this thread's `vectors` reader. If that reader
                // reuses its result buffer, another vectors.get().vectorValue(...) on this thread
                // while the returned function is in use would change value1 — confirm callers never
                // interleave the two.
                return node2 -> {
                    try {
                        return ConcurrentHnswGraphBuilder.this.scoreBetween(
                                value1, ConcurrentHnswGraphBuilder.this.vectorsCopy.get().vectorValue(node2));
                    } catch (IOException e) {
                        throw new UncheckedIOException(e);
                    }
                };
            }
        };
        this.hnsw = new ConcurrentOnHeapHnswGraph(
                maxConn, (a, b) -> new ConcurrentNeighborSet(a, b, similarity, alpha));

        if (alpha > 1.0f) {
            // Single-layer graph: every node's top level is 0.
            this.levelSupplier = () -> 0;
        } else {
            // Exponentially distributed levels with multiplier 1/ln(maxConn); ln(1) == 0, so
            // maxConn == 1 falls back to a multiplier of 1.
            double levelMultiplier = maxConn == 1 ? 1.0d : 1.0d / Math.log(maxConn);
            this.levelSupplier = () -> {
                double u;
                do {
                    u = ThreadLocalRandom.current().nextDouble();
                } while (u == 0.0d); // -log(0) is infinite
                return (int) (-Math.log(u) * levelMultiplier);
            };
        }

        this.graphSearcher = ExplicitThreadLocal.withInitial(() -> new HnswGraphSearcher<T>(
                vectorEncoding,
                similarityFunction,
                new NeighborQueue(beamWidth, true),
                new GrowableBitSet(this.vectors.get().size())));
        this.naturalScratch = ExplicitThreadLocal.withInitial(
                () -> new NeighborArray(Math.max(beamWidth, maxConn + 1), true));
        this.concurrentScratch = ExplicitThreadLocal.withInitial(
                () -> new NeighborArray(Math.max(beamWidth, maxConn + 1), true));
        this.beamCandidates = ExplicitThreadLocal.withInitial(() -> new NeighborQueue(beamWidth, false));
    }

    /** Creates a builder with {@code neighborOverflow} and {@code alpha} both set to 1. */
    public ConcurrentHnswGraphBuilder(
            RandomAccessVectorValues<T> vectorValues,
            VectorEncoding vectorEncoding,
            VectorSimilarityFunction similarityFunction,
            int maxConn,
            int beamWidth) {
        this(vectorValues, vectorEncoding, similarityFunction, maxConn, beamWidth, 1.0f, 1.0f);
    }

    /**
     * Inserts every vector of {@code vectorsToAdd} (node ids {@code 0 .. size-1}) using tasks on
     * {@code executor}.
     *
     * <p>All insertions run first. Once every insertion task has finished, a second pass calls
     * {@code cleanup()} on each node's neighbor set at every level. The coordinating work runs on
     * the default async executor of {@link CompletableFuture#supplyAsync(Supplier)}, not on {@code
     * executor}.
     *
     * @param vectorsToAdd vectors to insert; must not be the instance given to the constructor
     * @param executor runs the per-node tasks
     * @param maxConcurrentTasks maximum number of per-node tasks submitted but not yet finished
     * @return a future that completes with the graph, or exceptionally if a per-node task failed
     * @throws IllegalArgumentException if {@code vectorsToAdd} is the constructor's vector source or
     *     {@code maxConcurrentTasks} is not positive (which would otherwise block forever)
     */
    public Future<ConcurrentOnHeapHnswGraph> buildAsync(
            RandomAccessVectorValues<T> vectorsToAdd, ExecutorService executor, int maxConcurrentTasks) {
        // Compare against the raw source: `vectors` is a thread-local wrapper and can never be equal.
        if (vectorsToAdd == sourceVectors) {
            throw new IllegalArgumentException("Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
        }
        if (maxConcurrentTasks <= 0) {
            throw new IllegalArgumentException("maxConcurrentTasks must be positive: " + maxConcurrentTasks);
        }
        if (infoStream.isEnabled(HNSW_COMPONENT)) {
            infoStream.message(HNSW_COMPONENT, "build graph from " + vectorsToAdd.size() + " vectors");
        }
        return addVectors(vectorsToAdd, executor, maxConcurrentTasks);
    }

    /** Runs the insertion pass and then the cleanup pass; see {@link #buildAsync}. */
    private Future<ConcurrentOnHeapHnswGraph> addVectors(
            RandomAccessVectorValues<T> vectorsToAdd, ExecutorService executor, int maxConcurrentTasks) {
        Semaphore permits = new Semaphore(maxConcurrentTasks);
        AtomicReference<Throwable> asyncException = new AtomicReference<>(null);
        ExplicitThreadLocal<RandomAccessVectorValues<T>> threadVectors = createThreadSafeVectors(vectorsToAdd);
        return CompletableFuture.supplyAsync(() -> {
            int size = vectorsToAdd.size();
            runForEachNode(size, executor, permits, maxConcurrentTasks, asyncException,
                    node -> addGraphNode(node, threadVectors.get()));
            runForEachNode(size, executor, permits, maxConcurrentTasks, asyncException, node -> {
                for (int level = 0; level < hnsw.numLevels(); level++) {
                    ConcurrentNeighborSet neighbors = hnsw.getNeighbors(level, node);
                    if (neighbors != null) {
                        neighbors.cleanup();
                    }
                }
            });
            Throwable failure = asyncException.get();
            if (failure != null) {
                throw new CompletionException(failure);
            }
            hnsw.validateEntryNode();
            return hnsw;
        });
    }

    /** Per-node work submitted by {@link #runForEachNode}. */
    @FunctionalInterface
    private interface NodeTask {
        void run(int node) throws Exception;
    }

    /**
     * Submits {@code task} for nodes {@code 0 .. size-1}, with at most {@code totalPermits} tasks
     * outstanding, and returns once every submitted task has finished.
     *
     * <p>Submission stops early once {@code asyncException} holds a failure. A task's failure is
     * recorded there instead of propagating; each task returns its permit exactly once.
     */
    private static void runForEachNode(
            int size,
            ExecutorService executor,
            Semaphore permits,
            int totalPermits,
            AtomicReference<Throwable> asyncException,
            NodeTask task) {
        for (int node = 0; node < size && asyncException.get() == null; node++) {
            final int current = node;
            acquire(permits, 1);
            try {
                executor.submit(() -> {
                    try {
                        task.run(current);
                    } catch (Throwable t) {
                        asyncException.set(t);
                    } finally {
                        permits.release();
                    }
                });
            } catch (RuntimeException e) {
                // The task was never scheduled (e.g. rejected), so give its permit back.
                permits.release();
                throw e;
            }
        }
        // Holding every permit is only possible once no submitted task still holds one.
        acquire(permits, totalPermits);
        permits.release(totalPermits);
    }

    /** Acquires {@code count} permits, translating interruption into Lucene's unchecked form. */
    private static void acquire(Semaphore permits, int count) {
        try {
            permits.acquire(count);
        } catch (InterruptedException e) {
            throw new ThreadInterruptedException(e);
        }
    }

    /** Returns a holder that lazily gives each thread its own {@code values.copy()}. */
    private static <T> ExplicitThreadLocal<RandomAccessVectorValues<T>> createThreadSafeVectors(
            RandomAccessVectorValues<T> values) {
        return ExplicitThreadLocal.withInitial(() -> {
            try {
                return values.copy();
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            }
        });
    }

    /**
     * Inserts node {@code node}, reading its vector from {@code values}.
     *
     * @return the result of {@link #addGraphNode(int, Object)}
     * @throws IOException if reading a vector fails
     */
    public long addGraphNode(int node, RandomAccessVectorValues<T> values) throws IOException {
        return addGraphNode(node, values.vectorValue(node));
    }

    public void setInfoStream(InfoStream infoStream) {
        this.infoStream = infoStream;
    }

    public ConcurrentOnHeapHnswGraph getGraph() {
        return hnsw;
    }

    /** Returns how many {@link #addGraphNode(int, Object)} calls are currently registered as running. */
    public int insertsInProgress() {
        return insertionsInProgress.size();
    }

    /**
     * Inserts {@code node} with vector {@code value} into the graph.
     *
     * <p>The node gets a random top level and is added on every level from there down to 0. It is
     * then registered as in progress, and its neighbors are chosen as follows.
     *
     * <ol>
     *   <li>Greedy descent: starting from the graph's entry node, a width-1 search runs on each level
     *       above the node's top level.
     *   <li>Shared levels: on each level the node shares with the entry, a beam search runs. Its
     *       results, plus the other in-progress nodes at that level, are merged into the node's
     *       neighbor set.
     *   <li>Higher levels: on levels above the entry's level, only in-progress nodes are candidates.
     * </ol>
     *
     * @return {@code hnsw.ramBytesUsedOneNode(level)} for the node's top level
     * @throws IOException if reading a vector fails
     */
    public long addGraphNode(int node, T value) throws IOException {
        int nodeLevel = levelSupplier.get();
        for (int level = nodeLevel; level >= 0; level--) {
            hnsw.addNode(level, node);
        }
        HnswGraph view = hnsw.getView();
        ConcurrentOnHeapHnswGraph.NodeAtLevel marker = new ConcurrentOnHeapHnswGraph.NodeAtLevel(nodeLevel, node);
        insertionsInProgress.add(marker);
        // Snapshot taken after registering, so it includes this node; getConcurrentCandidates skips it.
        ConcurrentSkipListSet<ConcurrentOnHeapHnswGraph.NodeAtLevel> inProgress = insertionsInProgress.clone();
        try {
            ConcurrentOnHeapHnswGraph.NodeAtLevel entry = hnsw.entry();
            int[] entryPoints = entry.node >= 0 ? new int[] {entry.node} : new int[0];
            HnswGraphSearcher<T> searcher = graphSearcher.get();

            NeighborQueue closest = new NeighborQueue(1, false);
            for (int level = entry.level; level > nodeLevel; level--) {
                closest.clear();
                searcher.searchLevel(closest, value, 1, level, entryPoints, vectors.get(), view, null, Integer.MAX_VALUE);
                entryPoints = new int[] {closest.pop()};
            }

            NeighborQueue candidates = beamCandidates.get();
            for (int level = Math.min(nodeLevel, entry.level); level >= 0; level--) {
                candidates.clear();
                searcher.searchLevel(candidates, value, beamWidth, level, entryPoints, vectors.get(), view, null, Integer.MAX_VALUE);
                entryPoints = candidates.nodes();
                updateNeighbors(node, level, getNaturalCandidates(candidates),
                        getConcurrentCandidates(level, node, inProgress, marker));
            }

            for (int level = entry.level + 1; level <= nodeLevel; level++) {
                NeighborArray noNaturalCandidates = naturalScratch.get();
                noNaturalCandidates.clear();
                updateNeighbors(node, level, noNaturalCandidates,
                        getConcurrentCandidates(level, node, inProgress, marker));
            }

            hnsw.markComplete(nodeLevel, node);
            return hnsw.ramBytesUsedOneNode(nodeLevel);
        } finally {
            insertionsInProgress.remove(marker);
        }
    }

    /** Merges both candidate lists into {@code node}'s neighbor set at {@code level}, then adds back-links. */
    private void updateNeighbors(int node, int level, NeighborArray natural, NeighborArray concurrent)
            throws IOException {
        ConcurrentNeighborSet neighbors = hnsw.getNeighbors(level, node);
        neighbors.insertDiverse(natural, concurrent);
        neighbors.backlink(neighbor -> hnsw.getNeighbors(level, neighbor), neighborOverflow);
    }

    /**
     * Drains {@code results} into this thread's natural-candidate scratch array.
     *
     * <p>Entries are written from the last index down to 0 in the order the queue pops them.
     */
    private NeighborArray getNaturalCandidates(NeighborQueue results) {
        NeighborArray candidates = naturalScratch.get();
        candidates.clear();
        int count = results.size();
        for (int i = count - 1; i >= 0; i--) {
            float score = results.topScore();
            candidates.node()[i] = results.pop();
            candidates.score()[i] = score;
        }
        candidates.size = count;
        return candidates;
    }

    /**
     * Scores {@code newNode} against every in-progress node (other than {@code self}) whose top level
     * is at least {@code level}. Results go into this thread's concurrent-candidate scratch array.
     */
    private NeighborArray getConcurrentCandidates(
            int level,
            int newNode,
            Set<ConcurrentOnHeapHnswGraph.NodeAtLevel> inProgress,
            ConcurrentOnHeapHnswGraph.NodeAtLevel self) throws IOException {
        NeighborArray candidates = concurrentScratch.get();
        candidates.clear();
        T newValue = null;
        for (ConcurrentOnHeapHnswGraph.NodeAtLevel other : inProgress) {
            // Identity, not equals(): skip exactly the marker registered by the calling insertion.
            if (other.level < level || other == self) {
                continue;
            }
            if (newValue == null) {
                // Loop-invariant; read lazily so an insert with no eligible peers reads nothing.
                newValue = vectors.get().vectorValue(newNode);
            }
            candidates.insertSorted(other.node, scoreBetween(newValue, vectorsCopy.get().vectorValue(other.node)));
        }
        return candidates;
    }

    /** Scores two vectors with this builder's encoding and similarity function. */
    protected float scoreBetween(T a, T b) {
        return scoreBetween(vectorEncoding, similarityFunction, a, b);
    }

    /**
     * Scores {@code a} against {@code b}, casting both to {@code byte[]} for {@link
     * VectorEncoding#BYTE} and to {@code float[]} for {@link VectorEncoding#FLOAT32}.
     *
     * @throws ClassCastException if the values do not match {@code encoding}
     * @throws IllegalArgumentException for any other encoding
     */
    static <T> float scoreBetween(VectorEncoding encoding, VectorSimilarityFunction similarity, T a, T b) {
        switch (encoding) {
            case BYTE:
                return similarity.compare((byte[]) a, (byte[]) b);
            case FLOAT32:
                return similarity.compare((float[]) a, (float[]) b);
            default:
                throw new IllegalArgumentException();
        }
    }
}
