package org.apache.lucene.util.hnsw;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.Semaphore;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.GrowableBitSet;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.NamedThreadFactory;
import org.apache.lucene.util.ThreadInterruptedException;
import org.apache.lucene.util.hnsw.ConcurrentOnHeapHnswGraph;

/* loaded from: input_file:org/apache/lucene/util/hnsw/ConcurrentHnswGraphBuilder.class */
/**
 * Builds a {@link ConcurrentOnHeapHnswGraph}, allowing nodes to be inserted from several threads at
 * once.
 *
 * <p>Per-thread scratch state (graph searchers and neighbor arrays) is kept in {@link
 * ExplicitThreadLocal}s. Nodes whose insertion is still running are tracked so that concurrent
 * insertions can link to each other.
 *
 * @param <T> the vector representation returned by {@link RandomAccessVectorValues#vectorValue}
 *     ({@code byte[]} or {@code float[]}, per {@link #scoreBetween(Object, Object)})
 */
public class ConcurrentHnswGraphBuilder<T> {
    /** Default {@code maxConn} value exposed for callers; not referenced within this class. */
    public static final int DEFAULT_MAX_CONN = 16;
    /** Default {@code beamWidth} value exposed for callers; not referenced within this class. */
    public static final int DEFAULT_BEAM_WIDTH = 100;
    /** InfoStream component name used for this builder's messages. */
    public static final String HNSW_COMPONENT = "HNSW";

    private final int beamWidth;
    /** Level-generation normalization factor: {@code 1 / ln(maxConn)}, or 1 when maxConn is 1. */
    private final double ml;
    private final ExplicitThreadLocal<NeighborArray> scratchNeighbors;
    private final VectorSimilarityFunction similarityFunction;
    private final VectorEncoding vectorEncoding;
    private final RandomAccessVectorValues<T> vectors;
    private final ExplicitThreadLocal<HnswGraphSearcher<T>> graphSearcher;
    final ConcurrentOnHeapHnswGraph hnsw;
    /** Nodes whose insertion has started but not yet finished (on any thread). */
    private final ConcurrentSkipListSet<ConcurrentOnHeapHnswGraph.NodeAtLevel> insertionsInProgress =
        new ConcurrentSkipListSet<>();
    private InfoStream infoStream = InfoStream.getDefault();
    /** Second accessor over the same vectors, so two vectors can be read without aliasing. */
    private final RandomAccessVectorValues<T> vectorsCopy;

    /**
     * A per-thread value holder keyed by thread id.
     *
     * <p>Values live in a map owned by this instance. Entries are never removed, so a value created
     * for a thread stays reachable for as long as this holder is.
     *
     * @param <U> the held value type
     */
    public abstract static class ExplicitThreadLocal<U> {
        private final ConcurrentHashMap<Long, U> map = new ConcurrentHashMap<>();

        private ExplicitThreadLocal() {}

        /** Returns the calling thread's value, creating it with {@link #initialValue()} if absent. */
        public U get() {
            return map.computeIfAbsent(Thread.currentThread().getId(), threadId -> initialValue());
        }

        /** Creates the value for a thread that has none yet. */
        protected abstract U initialValue();

        /** Returns a holder whose per-thread values are produced by {@code supplier}. */
        public static <U> ExplicitThreadLocal<U> withInitial(final Supplier<U> supplier) {
            return new ExplicitThreadLocal<U>() {
                @Override
                protected U initialValue() {
                    return supplier.get();
                }
            };
        }
    }

    /**
     * Static factory equivalent to the constructor.
     *
     * @throws IOException if copying {@code vectors} fails
     */
    public static <T> ConcurrentHnswGraphBuilder<T> create(
            RandomAccessVectorValues<T> vectors,
            VectorEncoding vectorEncoding,
            VectorSimilarityFunction similarityFunction,
            int maxConn,
            int beamWidth)
            throws IOException {
        return new ConcurrentHnswGraphBuilder<>(vectors, vectorEncoding, similarityFunction, maxConn, beamWidth);
    }

    /**
     * Creates a builder.
     *
     * @param vectors source of the vectors referenced by graph node ids
     * @param vectorEncoding encoding of the vectors; must not be null
     * @param similarityFunction similarity used to score pairs of vectors; must not be null
     * @param maxConn maximum connections per node; must be positive
     * @param beamWidth size of the candidate queue used while searching for neighbors; must be
     *     positive
     * @throws IOException if copying {@code vectors} fails
     * @throws IllegalArgumentException if {@code maxConn} or {@code beamWidth} is not positive
     */
    public ConcurrentHnswGraphBuilder(
            RandomAccessVectorValues<T> vectors,
            VectorEncoding vectorEncoding,
            VectorSimilarityFunction similarityFunction,
            int maxConn,
            int beamWidth)
            throws IOException {
        this.vectors = vectors;
        this.vectorsCopy = vectors.copy();
        this.vectorEncoding = Objects.requireNonNull(vectorEncoding);
        this.similarityFunction = Objects.requireNonNull(similarityFunction);
        if (maxConn <= 0) {
            throw new IllegalArgumentException("maxConn must be positive");
        }
        if (beamWidth <= 0) {
            throw new IllegalArgumentException("beamWidth must be positive");
        }
        this.beamWidth = beamWidth;
        this.ml = maxConn == 1 ? 1.0d : 1.0d / Math.log(1.0d * maxConn);
        this.hnsw = new ConcurrentOnHeapHnswGraph(maxConn);
        this.graphSearcher =
            ExplicitThreadLocal.withInitial(
                () ->
                    new HnswGraphSearcher<T>(
                        vectorEncoding,
                        similarityFunction,
                        new NeighborQueue(beamWidth, true),
                        new GrowableBitSet(vectors.size())));
        this.scratchNeighbors =
            ExplicitThreadLocal.withInitial(() -> new NeighborArray(Math.max(beamWidth, maxConn + 1), false));
    }

    /**
     * Inserts every vector of {@code vectorsToAdd} and blocks until the graph is complete.
     *
     * @param vectorsToAdd vectors to insert; must not be the instance given to the constructor
     * @param parallel when true, one worker per available processor; otherwise a single worker
     * @return the built graph
     * @throws IOException wrapping any failure raised while inserting a node
     */
    public ConcurrentOnHeapHnswGraph build(RandomAccessVectorValues<T> vectorsToAdd, boolean parallel)
            throws IOException {
        int threadCount = parallel ? Runtime.getRuntime().availableProcessors() : 1;
        ExecutorService executor =
            parallel
                ? Executors.newFixedThreadPool(threadCount, new NamedThreadFactory("Concurrent HNSW builder"))
                : Executors.newSingleThreadExecutor(new NamedThreadFactory("Concurrent HNSW builder"));
        try {
            return buildAsync(vectorsToAdd, executor, threadCount).get();
        } catch (InterruptedException e) {
            throw new ThreadInterruptedException(e);
        } catch (ExecutionException e) {
            throw new IOException(e);
        } finally {
            executor.shutdown();
        }
    }

    /** Equivalent to {@code build(vectorsToAdd, true)}. */
    public ConcurrentOnHeapHnswGraph build(RandomAccessVectorValues<T> vectorsToAdd) throws IOException {
        return build(vectorsToAdd, true);
    }

    /**
     * Submits one insertion task per vector to {@code executor}, keeping at most {@code
     * maxConcurrency} tasks in flight.
     *
     * <p>Submission happens on the calling thread. The returned future completes once every task
     * has finished.
     *
     * @throws IllegalArgumentException if {@code vectorsToAdd} is the constructor's vector source
     */
    public Future<ConcurrentOnHeapHnswGraph> buildAsync(
            RandomAccessVectorValues<T> vectorsToAdd, ExecutorService executor, int maxConcurrency) {
        if (vectorsToAdd == this.vectors) {
            throw new IllegalArgumentException(
                "Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
        }
        if (infoStream.isEnabled(HNSW_COMPONENT)) {
            infoStream.message(HNSW_COMPONENT, "build graph from " + vectorsToAdd.size() + " vectors");
        }
        return addVectors(vectorsToAdd, executor, maxConcurrency);
    }

    private Future<ConcurrentOnHeapHnswGraph> addVectors(
            RandomAccessVectorValues<T> vectorsToAdd, ExecutorService executor, int maxConcurrency) {
        Semaphore permits = new Semaphore(maxConcurrency);
        Set<Integer> inFlight = ConcurrentHashMap.newKeySet();
        // Keeps the first failure reported by any worker.
        AtomicReference<Throwable> failure = new AtomicReference<>();
        for (int n = 0; n < vectorsToAdd.size(); n++) {
            final int node = n;
            try {
                permits.acquire();
            } catch (InterruptedException e) {
                throw new ThreadInterruptedException(e);
            }
            inFlight.add(node);
            try {
                executor.submit(
                    () -> {
                        try {
                            addGraphNode(node, vectorsToAdd);
                        } catch (Throwable t) {
                            failure.compareAndSet(null, t);
                        } finally {
                            permits.release();
                            inFlight.remove(node);
                        }
                    });
            } catch (RuntimeException e) {
                // The task never ran (e.g. rejected): undo the bookkeeping so the completion
                // future below does not wait for it forever.
                permits.release();
                inFlight.remove(node);
                throw e;
            }
        }
        return CompletableFuture.supplyAsync(
            () -> {
                // Poll until every submitted task has removed itself from the in-flight set.
                while (!inFlight.isEmpty()) {
                    try {
                        TimeUnit.MILLISECONDS.sleep(10L);
                    } catch (InterruptedException e) {
                        throw new ThreadInterruptedException(e);
                    }
                }
                Throwable firstFailure = failure.get();
                if (firstFailure != null) {
                    throw new CompletionException(firstFailure);
                }
                hnsw.validateEntryNode();
                return hnsw;
            });
    }

    /**
     * Inserts node {@code node}, reading its vector from {@code values}.
     *
     * @return the value reported by {@code ramBytesUsedOneNode} for the new node's level
     */
    public long addGraphNode(int node, RandomAccessVectorValues<T> values) throws IOException {
        return addGraphNode(node, values.vectorValue(node));
    }

    /** Sets the InfoStream used for progress messages. */
    public void setInfoStream(InfoStream infoStream) {
        this.infoStream = infoStream;
    }

    /** Returns the graph being built. */
    public ConcurrentOnHeapHnswGraph getGraph() {
        return hnsw;
    }

    /**
     * Inserts {@code node} with vector {@code value} at a randomly chosen top level.
     *
     * <p>Steps:
     *
     * <ol>
     *   <li>Descend greedily from the current entry point through the levels above the node's
     *       level.
     *   <li>On each shared level, search with {@code beamWidth} candidates and link to the
     *       results.
     *   <li>Link to insertions that were in progress when this one started.
     *   <li>On levels above the old entry level, link only to those in-progress insertions.
     * </ol>
     *
     * @return the value reported by {@code ramBytesUsedOneNode} for the new node's level
     */
    public long addGraphNode(int node, T value) throws IOException {
        int nodeLevel = getRandomGraphLevel(ml);
        for (int level = nodeLevel; level >= 0; level--) {
            hnsw.addNode(level, node);
        }

        HnswGraph view = hnsw.getView();
        ConcurrentOnHeapHnswGraph.NodeAtLevel progressMarker =
            new ConcurrentOnHeapHnswGraph.NodeAtLevel(nodeLevel, node);
        insertionsInProgress.add(progressMarker);
        // Snapshot of concurrent insertions (this one included); used to link to nodes that
        // the searches below may not see yet.
        ConcurrentSkipListSet<ConcurrentOnHeapHnswGraph.NodeAtLevel> inProgressSnapshot =
            insertionsInProgress.clone();
        try {
            ConcurrentOnHeapHnswGraph.NodeAtLevel entry = hnsw.entry();
            int[] entryPoints = entry.node >= 0 ? new int[] {entry.node} : new int[0];
            HnswGraphSearcher<T> searcher = graphSearcher.get();

            // Greedy (top-1) descent through the levels above this node's level.
            NeighborQueue best = new NeighborQueue(1, false);
            for (int level = entry.level; level > nodeLevel; level--) {
                best.clear();
                searcher.searchLevel(best, value, 1, level, entryPoints, vectors, view, null, Integer.MAX_VALUE);
                entryPoints = new int[] {best.pop()};
            }

            // Beam search and linking on every level shared with the existing graph.
            NeighborQueue candidates = new NeighborQueue(beamWidth, false);
            for (int level = Math.min(nodeLevel, entry.level); level >= 0; level--) {
                candidates.clear();
                searcher.searchLevel(
                    candidates, value, beamWidth, level, entryPoints, vectors, view, null, Integer.MAX_VALUE);
                entryPoints = candidates.nodes();
                addForwardLinks(level, node, candidates);
                addForwardLinks(level, node, inProgressSnapshot, progressMarker);
                addBackLinks(level, node);
            }

            // Levels above the old entry level hold only in-progress insertions.
            for (int level = entry.level + 1; level <= nodeLevel; level++) {
                addForwardLinks(level, node, inProgressSnapshot, progressMarker);
                addBackLinks(level, node);
            }

            hnsw.markComplete(nodeLevel, node);
        } finally {
            insertionsInProgress.remove(progressMarker);
        }
        return hnsw.ramBytesUsedOneNode(nodeLevel);
    }

    /** Links {@code node} at {@code level} to the search results in {@code candidates}, draining it. */
    private void addForwardLinks(int level, int node, NeighborQueue candidates) throws IOException {
        hnsw.getNeighbors(level, node).insertDiverse(popToScratch(candidates), this::scoreBetween);
    }

    /**
     * Links {@code node} at {@code level} to the other in-progress insertions that reach that level.
     *
     * <p>{@code self} is excluded by identity. The snapshot is a shallow clone, so it holds the
     * same marker instance.
     */
    private void addForwardLinks(
            int level,
            int node,
            Set<ConcurrentOnHeapHnswGraph.NodeAtLevel> inProgress,
            ConcurrentOnHeapHnswGraph.NodeAtLevel self)
            throws IOException {
        NeighborQueue candidates = new NeighborQueue(inProgress.size(), false);
        for (ConcurrentOnHeapHnswGraph.NodeAtLevel other : inProgress) {
            if (other.level >= level && other != self) {
                candidates.add(other.node, scoreBetween(other.node, node));
            }
        }
        hnsw.getNeighbors(level, node).insertDiverse(popToScratch(candidates), this::scoreBetween);
    }

    /** Adds reverse links from {@code node}'s neighbors at {@code level} back to {@code node}. */
    private void addBackLinks(int level, int node) throws IOException {
        hnsw.getNeighbors(level, node)
            .backlink(neighbor -> hnsw.getNeighbors(level, neighbor.intValue()), this::scoreBetween);
    }

    /**
     * Scores two stored vectors by node id.
     *
     * <p>The operands come from two different accessors. A {@link RandomAccessVectorValues}
     * implementation may return a reused buffer from {@code vectorValue}, so two reads through one
     * accessor could alias and compare a vector with itself.
     */
    private float scoreBetween(int node1, int node2) {
        try {
            T first = vectors.vectorValue(node1);
            T second = vectorsCopy.vectorValue(node2);
            return scoreBetween(first, second);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Scores two vectors with the configured similarity function.
     *
     * @throws IllegalArgumentException for an unsupported encoding
     */
    protected float scoreBetween(T v1, T v2) {
        switch (vectorEncoding) {
            case BYTE:
                return similarityFunction.compare((byte[]) v1, (byte[]) v2);
            case FLOAT32:
                return similarityFunction.compare((float[]) v1, (float[]) v2);
            default:
                throw new IllegalArgumentException();
        }
    }

    /**
     * Drains {@code candidates} into this thread's scratch array, keeping each node's own score.
     *
     * <p>The returned array is reused by the next call on the same thread.
     */
    private NeighborArray popToScratch(NeighborQueue candidates) {
        NeighborArray scratch = scratchNeighbors.get();
        scratch.clear();
        int count = candidates.size();
        for (int i = 0; i < count; i++) {
            // Read the score before popping: after pop() the top is the next node.
            float score = candidates.topScore();
            scratch.add(candidates.pop(), score);
        }
        return scratch;
    }

    /** Draws a level as {@code (int) (-ln(u) * ml)} for a uniform {@code u} in (0, 1). */
    int getRandomGraphLevel(double ml) {
        double u;
        do {
            u = ThreadLocalRandom.current().nextDouble();
        } while (u == 0.0d); // -ln(0) is infinite
        return (int) ((-Math.log(u)) * ml);
    }
}
