package org.apache.cassandra.index.sai.plan;

import com.google.common.base.Preconditions;
import java.nio.ByteBuffer;
import java.util.Comparator;
import java.util.Iterator;
import java.util.PriorityQueue;
import java.util.TreeMap;
import java.util.TreeSet;
import javax.annotation.Nullable;
import org.apache.cassandra.cql3.Operator;
import org.apache.cassandra.db.ColumnFamilyStore;
import org.apache.cassandra.db.DecoratedKey;
import org.apache.cassandra.db.Keyspace;
import org.apache.cassandra.db.ReadCommand;
import org.apache.cassandra.db.filter.RowFilter;
import org.apache.cassandra.db.partitions.BasePartitionIterator;
import org.apache.cassandra.db.partitions.PartitionIterator;
import org.apache.cassandra.db.rows.BaseRowIterator;
import org.apache.cassandra.db.rows.Row;
import org.apache.cassandra.db.rows.Unfiltered;
import org.apache.cassandra.index.SecondaryIndexManager;
import org.apache.cassandra.index.sai.StorageAttachedIndex;
import org.apache.cassandra.index.sai.utils.InMemoryPartitionIterator;
import org.apache.cassandra.index.sai.utils.InMemoryUnfilteredPartitionIterator;
import org.apache.cassandra.index.sai.utils.IndexTermType;
import org.apache.cassandra.index.sai.utils.PartitionInfo;
import org.apache.cassandra.schema.ColumnMetadata;
import org.apache.cassandra.utils.FBUtilities;
import org.apache.cassandra.utils.Pair;
import org.apache.commons.lang3.tuple.Triple;

/**
 * Post-filters the results of a read command that carries an {@code ANN} expression, keeping only the
 * {@code limit} rows whose indexed vector is most similar to the query vector.
 *
 * <p>Rows are scored with the similarity function configured on the vector's {@link StorageAttachedIndex};
 * higher scores are preferred. The surviving rows are returned in partition-key and clustering order, not in
 * score order. Non-row unfiltered items encountered while reading are returned unconditionally.
 */
public class VectorTopKProcessor {
    /**
     * Upper bound on the top-k queue's pre-allocated capacity. The queue still grows on demand, so this only
     * avoids an overflowing or oversized up-front allocation when the limit is very large.
     */
    private static final int MAX_INITIAL_QUEUE_CAPACITY = 1024;

    /** The read command whose results are being filtered. */
    private final ReadCommand command;
    /** The vector index serving the command's ANN expression. */
    private final StorageAttachedIndex index;
    /** Term type of {@link #index}; used to extract and decode vector values from rows. */
    private final IndexTermType indexTermType;
    /** The query vector decoded from the ANN expression's value. */
    private final float[] queryVector;
    /** Maximum number of rows to keep, taken from the command's limits. */
    private final int limit;

    /**
     * Creates a processor for {@code readCommand}.
     *
     * @param readCommand a read command whose row filter contains an {@code ANN} expression
     * @throws NullPointerException if no {@code ANN} expression in the command is served by a
     *                              {@link StorageAttachedIndex}
     */
    public VectorTopKProcessor(ReadCommand readCommand) {
        this.command = readCommand;
        Pair<StorageAttachedIndex, float[]> annIndexAndVector = findTopKIndex();
        Preconditions.checkNotNull(annIndexAndVector,
                "No ANN expression served by a StorageAttachedIndex found in read command");
        this.index = annIndexAndVector.left;
        this.indexTermType = annIndexAndVector.left.termType();
        this.queryVector = annIndexAndVector.right;
        this.limit = readCommand.limits().count();
    }

    /**
     * Consumes {@code p} and returns an in-memory iterator over the top-{@code limit} rows by similarity score,
     * plus every non-row unfiltered item, in partition and clustering order.
     *
     * @param p the partitions to filter; fully consumed and always closed by this method, even on failure
     * @return an {@link InMemoryPartitionIterator} if {@code p} is a {@link PartitionIterator}, otherwise an
     *         {@link InMemoryUnfilteredPartitionIterator}
     */
    public <U extends Unfiltered, R extends BaseRowIterator<U>, P extends BasePartitionIterator<R>> BasePartitionIterator<?> filter(P p) {
        // Min-heap on ascending score: the head is always the weakest candidate, so evicting it whenever the
        // queue grows past the limit leaves the highest-scoring rows.
        PriorityQueue<Triple<PartitionInfo, Row, Float>> topK =
                new PriorityQueue<>(initialQueueCapacity(limit), Comparator.comparing(Triple::getRight));
        // Output buffer ordered by partition key; each partition's contents are ordered by the table comparator.
        TreeMap<PartitionInfo, TreeSet<Unfiltered>> unfilteredByPartition =
                new TreeMap<>(Comparator.comparing((PartitionInfo info) -> info.key));

        // try-with-resources closes the source iterator even if reading or scoring throws.
        try (P partitions = p) {
            while (partitions.hasNext()) {
                // Each partition iterator is closed before advancing to the next partition.
                try (R partition = partitions.next()) {
                    DecoratedKey key = partition.partitionKey();
                    Row staticRow = partition.staticRow();
                    PartitionInfo partitionInfo = PartitionInfo.create(partition);
                    // The key and static row are shared by every row of the partition, so score them once.
                    float keyAndStaticScore = getScoreForRow(key, staticRow);

                    while (partition.hasNext()) {
                        Unfiltered unfiltered = partition.next();
                        if (!unfiltered.isRow()) {
                            // Non-row items bypass the top-k cut and are always returned.
                            unfilteredsFor(unfilteredByPartition, partitionInfo).add(unfiltered);
                            continue;
                        }
                        Row row = (Row) unfiltered;
                        topK.add(Triple.of(partitionInfo, row, keyAndStaticScore + getScoreForRow(null, row)));
                        // Rows are added one at a time, so a single eviction keeps size <= limit; unlike a
                        // while-loop this also terminates if the limit were ever negative.
                        if (topK.size() > limit) {
                            topK.poll();
                        }
                    }
                }
            }
        }

        // Move the surviving rows into the ordered buffer, restoring partition/clustering order.
        for (Triple<PartitionInfo, Row, Float> candidate : topK) {
            unfilteredsFor(unfilteredByPartition, candidate.getLeft()).add(candidate.getMiddle());
        }

        return p instanceof PartitionIterator
                ? new InMemoryPartitionIterator(command, unfilteredByPartition)
                : new InMemoryUnfilteredPartitionIterator(command, unfilteredByPartition);
    }

    /**
     * Returns the initial capacity for the top-k queue: {@code limit + 1} slots (the top-k plus the one
     * candidate awaiting eviction), clamped so {@code limit + 1} cannot overflow (e.g. for
     * {@code Integer.MAX_VALUE}), is capped at {@link #MAX_INITIAL_QUEUE_CAPACITY} + 1, and never drops below
     * the minimum of 1 that {@link PriorityQueue} requires.
     */
    private static int initialQueueCapacity(int limit) {
        return Math.max(1, Math.min(limit, MAX_INITIAL_QUEUE_CAPACITY) + 1);
    }

    /** Returns the clustering-ordered set for {@code partition} in {@code byPartition}, creating it on first use. */
    private TreeSet<Unfiltered> unfilteredsFor(TreeMap<PartitionInfo, TreeSet<Unfiltered>> byPartition,
                                               PartitionInfo partition) {
        return byPartition.computeIfAbsent(partition, ignored -> new TreeSet<>(command.metadata().comparator));
    }

    /**
     * Computes the similarity between the query vector and the indexed column's value in {@code row}.
     *
     * @param key the partition key when scoring a partition's static row, or {@code null} when scoring a
     *            non-static row (primary-key columns are then not scored)
     * @param row the row to read the indexed value from
     * @return the similarity score, or {@code 0} when the indexed column does not apply to this call or the
     *         row has no value for it
     */
    private float getScoreForRow(DecoratedKey key, Row row) {
        ColumnMetadata column = indexTermType.columnMetadata();

        // Primary-key columns are only scored on the per-partition call, where the key is supplied.
        // NOTE(review): combined with the static-row check below, a clustering-column vector always scores 0;
        // confirm clustering columns cannot carry a vector index.
        if (column.isPrimaryKeyColumn() && key == null) {
            return 0.0f;
        }
        // A static column's value only lives in the static row.
        if (column.isStatic() && !row.isStatic()) {
            return 0.0f;
        }
        // Clustering and regular columns have no value in the static row.
        if ((column.isClusteringColumn() || column.isRegular()) && row.isStatic()) {
            return 0.0f;
        }

        // NOTE(review): liveness is evaluated against the current wall-clock time from FBUtilities rather than
        // any timestamp carried by the command; confirm this is intended.
        ByteBuffer value = indexTermType.valueOf(key, row, FBUtilities.nowInSeconds());
        if (value == null) {
            return 0.0f;
        }
        float[] vector = indexTermType.decomposeVector(value);
        return index.indexWriterConfig().getSimilarityFunction().compare(vector, queryVector);
    }

    /**
     * Finds the first {@code ANN} expression in the command's row filter that is served by a
     * {@link StorageAttachedIndex}, and decodes that expression's value into the query vector.
     *
     * @return the serving index paired with the decoded query vector, or {@code null} if there is none
     */
    @Nullable
    private Pair<StorageAttachedIndex, float[]> findTopKIndex() {
        ColumnFamilyStore store = Keyspace.openAndGetStore(command.metadata());
        for (RowFilter.Expression expression : command.rowFilter().getExpressions()) {
            StorageAttachedIndex vectorIndex = findVectorIndexFor(store.indexManager, expression);
            if (vectorIndex != null) {
                // duplicate() so decoding does not disturb the expression's own buffer position.
                float[] vector = vectorIndex.termType().decomposeVector(expression.getIndexValue().duplicate());
                return Pair.create(vectorIndex, vector);
            }
        }
        return null;
    }

    /**
     * Returns the {@link StorageAttachedIndex} that the index manager selects for {@code expression}, if the
     * expression is an {@code ANN} expression and the selected index is a SAI index.
     *
     * @return the index, or {@code null} if the operator is not {@code ANN} or no SAI index is selected
     */
    @Nullable
    private StorageAttachedIndex findVectorIndexFor(SecondaryIndexManager secondaryIndexManager, RowFilter.Expression expression) {
        if (expression.operator() != Operator.ANN) {
            return null;
        }
        return secondaryIndexManager.getBestIndexFor(expression)
                                    .filter(StorageAttachedIndex.class::isInstance)
                                    .map(StorageAttachedIndex.class::cast)
                                    .orElse(null);
    }
}
