package org.apache.cassandra.index.sai.memory;

import io.github.jbellis.jvector.util.Bits;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.NavigableSet;
import java.util.Objects;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.concurrent.atomic.LongAdder;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import org.apache.cassandra.db.Clustering;
import org.apache.cassandra.db.DecoratedKey;
import org.apache.cassandra.db.PartitionPosition;
import org.apache.cassandra.dht.AbstractBounds;
import org.apache.cassandra.index.sai.QueryContext;
import org.apache.cassandra.index.sai.StorageAttachedIndex;
import org.apache.cassandra.index.sai.VectorQueryContext;
import org.apache.cassandra.index.sai.disk.format.IndexDescriptor;
import org.apache.cassandra.index.sai.disk.v1.segment.SegmentMetadata;
import org.apache.cassandra.index.sai.disk.v1.vector.OnHeapGraph;
import org.apache.cassandra.index.sai.iterators.KeyRangeIterator;
import org.apache.cassandra.index.sai.iterators.KeyRangeListIterator;
import org.apache.cassandra.index.sai.plan.Expression;
import org.apache.cassandra.index.sai.utils.IndexIdentifier;
import org.apache.cassandra.index.sai.utils.PrimaryKey;
import org.apache.cassandra.index.sai.utils.PrimaryKeys;
import org.apache.cassandra.index.sai.utils.RangeUtil;
import org.apache.cassandra.tracing.Tracing;
import org.apache.cassandra.utils.Pair;
import org.apache.cassandra.utils.bytecomparable.ByteComparable;

/**
 * In-memory SAI index for vector-typed columns.
 *
 * <p>Vectors are stored in an {@link OnHeapGraph} keyed by {@link PrimaryKey}. Alongside the graph, a
 * sorted set of the primary keys that currently have a vector is maintained. This lets searches that are
 * restricted to a key range either brute-force the (small) set of matching keys or filter graph
 * ordinals by range. The smallest and largest keys ever indexed are tracked; they are used as the
 * bounds of the iterators this index returns.
 */
public class VectorMemoryIndex extends MemoryIndex {
    private final OnHeapGraph<PrimaryKey> graph;
    // Incremented once per vector indexed through add(); not read within this class.
    private final LongAdder writeCount;
    // Bounds of every key passed to updateKeyBounds; they only ever widen. Null until the first vector is indexed.
    private PrimaryKey minimumKey;
    private PrimaryKey maximumKey;
    // Sorted primary keys that currently have a vector in the graph; used for range-restricted searches.
    private final NavigableSet<PrimaryKey> primaryKeys;

    /**
     * Accepts a graph ordinal only if at least one primary key mapped to that ordinal is contained in the
     * supplied list of candidate keys.
     */
    private class KeyFilteringBits implements Bits {
        private final List<PrimaryKey> results;

        public KeyFilteringBits(List<PrimaryKey> results) {
            this.results = results;
        }

        @Override
        public boolean get(int i) {
            Collection<PrimaryKey> keysAtOrdinal = graph.keysFromOrdinal(i);
            return results.stream().anyMatch(keysAtOrdinal::contains);
        }

        @Override
        public int length() {
            return results.size();
        }
    }

    /**
     * Accepts a graph ordinal when the optional delegate {@code bits} accepts it (or is null) and at least
     * one primary key at that ordinal has a partition key inside {@code keyRange}.
     */
    private class KeyRangeFilteringBits implements Bits {
        private final AbstractBounds<PartitionPosition> keyRange;

        @Nullable
        private final Bits bits;

        public KeyRangeFilteringBits(AbstractBounds<PartitionPosition> keyRange, @Nullable Bits bits) {
            this.keyRange = keyRange;
            this.bits = bits;
        }

        @Override
        public boolean get(int i) {
            if (bits != null && !bits.get(i)) {
                return false;
            }
            return graph.keysFromOrdinal(i).stream().anyMatch(pk -> keyRange.contains(pk.partitionKey()));
        }

        @Override
        public int length() {
            return graph.size();
        }
    }

    /**
     * Emits the keys of a priority queue in the queue's (natural) order. Skipping drains every queued key
     * that sorts before the target. The iterator's bounds are this index's current minimum/maximum keys.
     */
    private class ReorderingRangeIterator extends KeyRangeIterator {
        private final PriorityQueue<PrimaryKey> keyQueue;

        ReorderingRangeIterator(PriorityQueue<PrimaryKey> keyQueue) {
            super(VectorMemoryIndex.this.minimumKey, VectorMemoryIndex.this.maximumKey, keyQueue.size());
            this.keyQueue = keyQueue;
        }

        @Override
        protected void performSkipTo(PrimaryKey nextKey) {
            while (!keyQueue.isEmpty() && keyQueue.peek().compareTo(nextKey) < 0) {
                keyQueue.poll();
            }
        }

        @Override
        public void close() {
        }

        @Override
        protected PrimaryKey computeNext() {
            // PriorityQueue rejects null elements, so a null poll() means the queue is exhausted.
            PrimaryKey next = keyQueue.poll();
            return next == null ? endOfData() : next;
        }
    }

    public VectorMemoryIndex(StorageAttachedIndex storageAttachedIndex) {
        super(storageAttachedIndex);
        this.writeCount = new LongAdder();
        this.primaryKeys = new ConcurrentSkipListSet<>();
        this.graph = new OnHeapGraph<>(storageAttachedIndex.termType().indexType(), storageAttachedIndex.indexWriterConfig());
    }

    /**
     * Indexes the vector for the given row.
     *
     * @return the value returned by {@link OnHeapGraph#add}, or 0 when the value is null, empty, or
     *         rejected by {@code validateMaxTermSize}
     */
    @Override
    public synchronized long add(DecoratedKey key, Clustering<?> clustering, ByteBuffer value) {
        if (value == null || value.remaining() == 0 || !this.index.validateMaxTermSize(key, value, false)) {
            return 0L;
        }
        return index(primaryKeyFor(key, clustering), value);
    }

    /** Builds the primary key for a row, including the clustering only when the index has one. */
    private PrimaryKey primaryKeyFor(DecoratedKey key, Clustering<?> clustering) {
        return this.index.hasClustering()
               ? this.index.keyFactory().create(key, clustering)
               : this.index.keyFactory().create(key);
    }

    private long index(PrimaryKey primaryKey, ByteBuffer value) {
        updateKeyBounds(primaryKey);
        writeCount.increment();
        primaryKeys.add(primaryKey);
        return graph.add(value, primaryKey, OnHeapGraph.InvalidVectorBehavior.FAIL);
    }

    /**
     * Replaces a row's vector. A null or empty buffer means "no vector".
     *
     * <p>NOTE(review): unlike {@link #add} this method is not {@code synchronized}, yet it also updates
     * the key bounds; confirm the caller serializes updates.
     *
     * @return net change reported by the graph (bytes added minus bytes removed), or 0 if nothing changed
     */
    @Override
    public long update(DecoratedKey key, Clustering<?> clustering, ByteBuffer oldValue, ByteBuffer newValue) {
        int oldRemaining = oldValue == null ? 0 : oldValue.remaining();
        int newRemaining = newValue == null ? 0 : newValue.remaining();
        if (oldRemaining == 0 && newRemaining == 0) {
            return 0L;
        }

        boolean different;
        if (oldRemaining == newRemaining) {
            // Both are non-null here; ByteBuffer.equals compares the remaining bytes starting at each
            // buffer's position, unlike an absolute get(i) loop that starts at index 0.
            different = !oldValue.equals(newValue);
        } else {
            // lengths differ only when one side has no vector
            assert oldRemaining == 0 || newRemaining == 0;
            different = true;
        }
        if (!different) {
            return 0L;
        }

        PrimaryKey primaryKey = primaryKeyFor(key, clustering);
        // If the row previously had no vector, its key may not be within the bounds yet.
        updateKeyBounds(primaryKey);

        long bytesUsed = 0;
        // The new vector is added before the old one is removed.
        if (newRemaining > 0) {
            bytesUsed += graph.add(newValue, primaryKey, OnHeapGraph.InvalidVectorBehavior.FAIL);
            // Keep primaryKeys in sync with the graph so range-restricted searches can find this row.
            primaryKeys.add(primaryKey);
        }
        if (oldRemaining > 0) {
            bytesUsed -= graph.remove(oldValue, primaryKey);
        }
        if (newRemaining == 0 && oldRemaining > 0) {
            // the row no longer has a vector
            primaryKeys.remove(primaryKey);
        }
        return bytesUsed;
    }

    /** Widens {@code minimumKey}/{@code maximumKey} so they include {@code primaryKey}. */
    private void updateKeyBounds(PrimaryKey primaryKey) {
        if (minimumKey == null || primaryKey.compareTo(minimumKey) < 0) {
            minimumKey = primaryKey;
        }
        if (maximumKey == null || primaryKey.compareTo(maximumKey) > 0) {
            maximumKey = primaryKey;
        }
    }

    /**
     * Runs an ANN search for the expression's query vector, restricted to {@code keyRange}.
     *
     * <p>When the range covers the whole ring, the graph is searched directly, excluding shadowed keys.
     * Otherwise the indexed keys inside the range (minus shadowed keys) are collected. If there are few
     * enough of them they are returned directly for brute-force scoring; otherwise the graph search is
     * filtered to ordinals within the range.
     */
    @Override
    public KeyRangeIterator search(QueryContext queryContext, Expression expression, AbstractBounds<PartitionPosition> keyRange) {
        assert expression.getIndexOperator() == Expression.IndexOperator.ANN
                : "Only ANN is supported for vector search, received " + expression.getIndexOperator();

        VectorQueryContext vectorContext = queryContext.vectorContext();
        float[] queryVector = this.index.termType().decomposeVector(expression.lower().value.raw);

        Bits bits;
        if (RangeUtil.coversFullRing(keyRange)) {
            bits = vectorContext.bitsetForShadowedPrimaryKeys(graph);
        } else {
            boolean leftInclusive = keyRange.left.kind() != PartitionPosition.Kind.MAX_BOUND;
            boolean rightInclusive = keyRange.right.kind() != PartitionPosition.Kind.MIN_BOUND;
            // A minimum right token is treated as "no right bound".
            boolean rightUnbounded = keyRange.right.getToken().isMinimum();
            PrimaryKey left = this.index.keyFactory().create(keyRange.left.getToken());

            Set<PrimaryKey> resultKeys;
            if (rightUnbounded) {
                resultKeys = primaryKeys.tailSet(left, leftInclusive);
            } else {
                PrimaryKey right = this.index.keyFactory().create(keyRange.right.getToken());
                resultKeys = primaryKeys.subSet(left, leftInclusive, right, rightInclusive);
            }

            if (!vectorContext.getShadowedPrimaryKeys().isEmpty()) {
                resultKeys = resultKeys.stream()
                                       .filter(pk -> !vectorContext.containsShadowedPrimaryKey(pk))
                                       .collect(Collectors.toSet());
            }

            // size() on a concurrent skip-list view is O(n), so compute it once.
            int resultCount = resultKeys.size();
            if (resultCount == 0) {
                return KeyRangeIterator.empty();
            }

            int limit = vectorContext.limit();
            int graphSize = graph.size();
            int maxBruteForceRows = maxBruteForceRows(limit, resultCount, graphSize);
            Tracing.trace("Search range covers {} rows; max brute force rows is {} for memtable index with {} nodes, LIMIT {}", Integer.valueOf(resultCount), Integer.valueOf(maxBruteForceRows), Integer.valueOf(graphSize), Integer.valueOf(limit));
            if (resultCount < Math.max(limit, maxBruteForceRows)) {
                return new ReorderingRangeIterator(new PriorityQueue<>(resultKeys));
            }
            bits = new KeyRangeFilteringBits(keyRange, vectorContext.bitsetForShadowedPrimaryKeys(graph));
        }

        PriorityQueue<PrimaryKey> nearest = graph.search(queryVector, vectorContext.limit(), bits);
        return nearest.isEmpty() ? KeyRangeIterator.empty() : new ReorderingRangeIterator(nearest);
    }

    /**
     * Reduces already-materialized keys to this index's top {@code limit} results.
     *
     * <p>Keys are clipped to [minimumKey, maximumKey] with dropWhile/takeWhile, which assumes
     * {@code keys} is sorted. When the clipped list is small enough it is returned as-is. Otherwise the
     * graph is searched, restricted to ordinals mapped to one of those keys.
     */
    @Override
    public KeyRangeIterator limitToTopResults(List<PrimaryKey> keys, Expression expression, int limit) {
        if (minimumKey == null) {
            return KeyRangeIterator.empty();
        }
        List<PrimaryKey> candidates = keys.stream()
                                          .dropWhile(pk -> pk.compareTo(minimumKey) < 0)
                                          .takeWhile(pk -> pk.compareTo(maximumKey) <= 0)
                                          .collect(Collectors.toList());
        int graphSize = graph.size();
        int maxBruteForceRows = maxBruteForceRows(limit, candidates.size(), graphSize);
        Tracing.trace("SAI materialized {} rows; max brute force rows is {} for memtable index with {} nodes, LIMIT {}", Integer.valueOf(candidates.size()), Integer.valueOf(maxBruteForceRows), Integer.valueOf(graphSize), Integer.valueOf(limit));
        if (candidates.size() <= maxBruteForceRows) {
            return candidates.isEmpty() ? KeyRangeIterator.empty() : new KeyRangeListIterator(minimumKey, maximumKey, candidates);
        }
        float[] queryVector = this.index.termType().decomposeVector(expression.lower().value.raw);
        PriorityQueue<PrimaryKey> nearest = graph.search(queryVector, limit, new KeyFilteringBits(candidates));
        return nearest.isEmpty() ? KeyRangeIterator.empty() : new ReorderingRangeIterator(nearest);
    }

    /**
     * Returns the largest candidate count for which brute-force scoring is preferred over a graph search.
     * The result is at least {@code limit}.
     */
    private int maxBruteForceRows(int limit, int candidateCount, int graphSize) {
        return (int) Math.max(limit, 0.25d * this.index.indexWriterConfig().getMaximumNodeConnections() * expectedNodesVisited(limit, candidateCount, graphSize));
    }

    /**
     * Estimates how many graph nodes an ANN search visits. The constants appear to be an empirical fit;
     * they are not derived here.
     *
     * <p>The ratio {@code graphSize / min(nPermittedOrdinals, graphSize)} is computed in floating point.
     * Integer division would truncate the ratio, and it would throw {@link ArithmeticException} when
     * {@code nPermittedOrdinals} is 0. With floating point, a zero count yields an infinite or NaN
     * intermediate, which the final clamp absorbs.
     *
     * @param limit              number of results requested
     * @param nPermittedOrdinals number of rows the search is restricted to
     * @param graphSize          number of nodes in the graph
     * @return the estimate, clamped to {@code [min(limit, graphSize), graphSize]}
     */
    public static int expectedNodesVisited(int limit, int nPermittedOrdinals, int graphSize) {
        int sizeRestriction = Math.min(nPermittedOrdinals, graphSize);
        double raw = 0.7d * Math.pow(Math.log(graphSize), 2.0d)
                     * Math.pow(graphSize, 0.33d)
                     * Math.pow(Math.log(limit), 2.0d)
                     * Math.pow(Math.log((double) graphSize / sizeRestriction), 2.0d)
                     / Math.pow(sizeRestriction, 0.13d);
        // At least min(limit, graphSize) nodes are reported, and never more than the graph holds.
        return Math.min(Math.max((int) raw, Math.min(limit, graphSize)), graphSize);
    }

    /** Term iteration is not supported for vector indexes. */
    @Override
    public Iterator<Pair<ByteComparable, PrimaryKeys>> iterator() {
        throw new UnsupportedOperationException();
    }

    /** Delegates to {@link OnHeapGraph#writeData} to persist the graph. */
    @Override
    public SegmentMetadata.ComponentMetadataMap writeDirect(IndexDescriptor indexDescriptor, IndexIdentifier indexIdentifier, Function<PrimaryKey, Integer> postingTransformer) throws IOException {
        return graph.writeData(indexDescriptor, indexIdentifier, postingTransformer);
    }

    @Override
    public boolean isEmpty() {
        return graph.isEmpty();
    }

    /** Vector indexes track no minimum term; always returns null. */
    @Override
    @Nullable
    public ByteBuffer getMinTerm() {
        return null;
    }

    /** Vector indexes track no maximum term; always returns null. */
    @Override
    @Nullable
    public ByteBuffer getMaxTerm() {
        return null;
    }
}
