package io.github.jbellis.jvector.graph;

import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.util.AbstractLongHeap;
import io.github.jbellis.jvector.util.BoundedLongHeap;
import io.github.jbellis.jvector.util.NumericUtils;
import org.agrona.collections.Int2ObjectHashMap;
import org.agrona.collections.IntIntConsumer;

/* loaded from: input_file:io/github/jbellis/jvector/graph/NodeQueue.class */
/**
 * A queue of (node id, score) pairs backed by an {@link AbstractLongHeap}.
 *
 * <p>Each pair is packed into a single {@code long}:
 * <ul>
 *   <li>the high 32 bits hold the score converted with
 *       {@link NumericUtils#floatToSortableInt(float)};</li>
 *   <li>the low 32 bits hold the bitwise complement of the node id.</li>
 * </ul>
 * {@link Order} then maps that encoding onto the value actually stored in the heap, so the same
 * heap can serve as either a min- or a max-queue. The code treats heap index 1 as the top of the
 * heap (see {@link #forEachTop3}).
 */
public class NodeQueue {
    private final AbstractLongHeap heap;
    private final Order order;

    /** Receives a node id together with its score. */
    @FunctionalInterface
    public interface NodeConsumer {
        void accept(int node, float score);
    }

    /**
     * Selects which end of the encoded range sits at the top of the heap.
     *
     * <p>{@link #MIN_HEAP} stores encodings unchanged. {@link #MAX_HEAP} stores their bitwise
     * complement, which reverses their signed ordering. Both mappings are their own inverse, so
     * {@code apply} is used for both encoding and decoding.
     */
    public enum Order {
        MIN_HEAP {
            @Override
            long apply(long v) {
                return v;
            }
        },
        MAX_HEAP {
            @Override
            long apply(long v) {
                // identical to -1 - v
                return ~v;
            }
        };

        abstract long apply(long v);
    }

    /**
     * @param heap  storage for the encoded entries
     * @param order how encoded entries are mapped into {@code heap}
     */
    public NodeQueue(AbstractLongHeap heap, Order order) {
        this.heap = heap;
        this.order = order;
    }

    /** @return the number of entries currently held by the underlying heap */
    public int size() {
        return heap.size();
    }

    /**
     * Encodes the pair and pushes it onto the underlying heap.
     *
     * @return the result of the underlying heap's {@code push}
     *         (presumably whether the value was inserted -- confirm against AbstractLongHeap)
     */
    public boolean push(int newNode, float newScore) {
        return heap.push(encode(newNode, newScore));
    }

    /**
     * Packs (node, score) into a long ordered by score first, then node id.
     * This assumes floatToSortableInt preserves float ordering, as its name indicates.
     */
    private long encode(int node, float score) {
        // Widen BEFORE shifting. Java masks int shift distances to 5 bits, so (int << 32) is a
        // no-op and would leave the score in the low word instead of the high word.
        long scoreBits = ((long) NumericUtils.floatToSortableInt(score)) << 32;
        // Complementing the id makes lower (non-negative) ids encode higher among equal scores.
        long nodeBits = 0xFFFFFFFFL & ~node;
        return order.apply(scoreBits | nodeBits);
    }

    /** Recovers the score from a value stored in this queue's heap. */
    private float decodeScore(long heapValue) {
        return NumericUtils.sortableIntToFloat((int) (order.apply(heapValue) >> 32));
    }

    /** Recovers the node id from a value stored in this queue's heap. */
    private int decodeNodeId(long heapValue) {
        return (int) ~order.apply(heapValue);
    }

    /** Removes the top entry of the heap and returns its node id. */
    public int pop() {
        return decodeNodeId(heap.pop());
    }

    /** @return the node ids in heap storage order (not sorted by score) */
    public int[] nodesCopy() {
        int size = size();
        int[] nodes = new int[size];
        for (int i = 0; i < size; i++) {
            // heap positions are addressed starting at 1
            nodes[i] = decodeNodeId(heap.get(i + 1));
        }
        return nodes;
    }

    /**
     * Rescores entries of this queue with an exact score function and offers them to
     * {@code reranked}. This queue itself is only read, never modified.
     *
     * <p>Which entries are rescored:
     * <ul>
     *   <li>every entry whose approximate score is {@code >= rerankFloor};</li>
     *   <li>if none qualifies, only the entry with the highest approximate score.</li>
     * </ul>
     *
     * <p>How rescored entries are handled:
     * <ul>
     *   <li>While {@code reranked} holds fewer than {@code topK} entries, they are pushed directly.</li>
     *   <li>After that, an entry is pushed only if its exact score beats {@code reranked.topScore()}.
     *       The current top of {@code reranked} is then reported to {@code unused} as the evicted
     *       node. This assumes {@code reranked}'s heap evicts its top on such a push (bounded to
     *       {@code topK}) -- NOTE(review) confirm with callers.</li>
     * </ul>
     * Every entry not kept is added to {@code unused} with its approximate score.
     *
     * <p>NOTE(review): approximate scores of evicted and final entries are looked up by node id in
     * a map built during this call. If {@code reranked} already held other nodes on entry, the
     * lookup returns null and unboxing throws NullPointerException. Node id -1 is used as an
     * internal "not rescored" marker, so ids are assumed never to be -1.
     *
     * @param topK        target number of entries in {@code reranked}
     * @param reranker    exact score function applied to rescored nodes
     * @param rerankFloor minimum approximate score for an entry to be rescored
     * @param reranked    destination queue, receives exact scores
     * @param unused      receives every node that does not end up in {@code reranked}
     * @return {@link Float#POSITIVE_INFINITY} if {@code reranked} ends with fewer than
     *         {@code topK} entries; otherwise the lowest approximate score among its entries
     */
    public float rerank(int topK, ScoreFunction.ExactScoreFunction reranker, float rerankFloor,
                        NodeQueue reranked, NodesUnsorted unused) {
        int size = size();
        // ids[i] is the node at heap position i + 1, or -1 if that entry was not rescored
        int[] ids = new int[size];
        float[] exactScores = new float[size];
        Int2ObjectHashMap<Float> approximateScoresById = new Int2ObjectHashMap<>();

        // Pass 1: rescore entries at or above the floor, tracking the best approximate score
        float bestScore = Float.NEGATIVE_INFINITY;
        int bestIndex = -1;
        int scoresAboveFloor = 0;
        for (int i = 0; i < size; i++) {
            long heapValue = heap.get(i + 1);
            float score = decodeScore(heapValue);
            int nodeId = decodeNodeId(heapValue);
            if (score > bestScore) {
                bestScore = score;
                bestIndex = i;
            }
            if (score >= rerankFloor) {
                ids[i] = nodeId;
                exactScores[i] = reranker.similarityTo(nodeId);
                approximateScoresById.put(nodeId, Float.valueOf(score));
                scoresAboveFloor++;
            } else {
                ids[i] = -1;
            }
        }

        // Nothing reached the floor: fall back to rescoring the single best candidate
        if (scoresAboveFloor == 0 && bestIndex >= 0) {
            int bestNode = decodeNodeId(heap.get(bestIndex + 1));
            ids[bestIndex] = bestNode;
            exactScores[bestIndex] = reranker.similarityTo(bestNode);
            approximateScoresById.put(bestNode, Float.valueOf(bestScore));
        }

        // Pass 2: move rescored entries into `reranked`, everything else into `unused`.
        // push() cannot report which node it evicts, so the eviction is recorded by hand.
        for (int i = 0; i < ids.length; i++) {
            if (ids[i] == -1) {
                long heapValue = heap.get(i + 1);
                unused.add(decodeNodeId(heapValue), decodeScore(heapValue));
            } else if (reranked.size() < topK) {
                reranked.push(ids[i], exactScores[i]);
            } else if (exactScores[i] > reranked.topScore()) {
                int evicted = reranked.topNode();
                unused.add(evicted, approximateScoresById.get(evicted).floatValue());
                reranked.push(ids[i], exactScores[i]);
            } else {
                unused.add(ids[i], decodeScore(heap.get(i + 1)));
            }
        }

        if (reranked.size() < topK) {
            return Float.POSITIVE_INFINITY;
        }
        // Pass 3: worst approximate score among the entries that made it into `reranked`.
        // Decode with reranked's own Order: its heap values are encoded with that order,
        // which need not match this queue's.
        float worstApproximateInTopK = Float.POSITIVE_INFINITY;
        for (int i = 0; i < reranked.size(); i++) {
            int nodeId = reranked.decodeNodeId(reranked.heap.get(i + 1));
            worstApproximateInTopK =
                    Math.min(worstApproximateInTopK, approximateScoresById.get(nodeId).floatValue());
        }
        return worstApproximateInTopK;
    }

    /** @return the node id of the entry at the top of the heap */
    public int topNode() {
        return decodeNodeId(heap.top());
    }

    /** @return the score of the entry at the top of the heap */
    public float topScore() {
        return decodeScore(heap.top());
    }

    /** Clears the underlying heap. */
    public void clear() {
        heap.clear();
    }

    /**
     * Changes the capacity of the underlying heap.
     *
     * @throws ClassCastException if the underlying heap is not a {@link BoundedLongHeap}
     */
    public void setMaxSize(int maxSize) {
        ((BoundedLongHeap) heap).setMaxSize(maxSize);
    }

    @Override
    public String toString() {
        return "Nodes[" + heap.size() + "]";
    }

    /** Passes every (node, score) pair to the consumer, in heap storage order (not sorted). */
    public void foreach(NodeConsumer consumer) {
        for (int i = 0; i < heap.size(); i++) {
            long heapValue = heap.get(i + 1);
            consumer.accept(decodeNodeId(heapValue), decodeScore(heapValue));
        }
    }

    /**
     * Reports up to the first three entries in pop order to the consumer as (rank, node), with
     * rank 0, 1, 2. The queue is not modified.
     *
     * <p>The underlying heap keeps its smallest stored value on top: that is what MIN_HEAP's
     * identity mapping implies. So the runner-up is the root child with the smaller stored value.
     * Comparing stored values directly already accounts for the Order, since that is how it was
     * applied on push.
     */
    public void forEachTop3(IntIntConsumer consumer) {
        int size = heap.size();
        if (size == 0) {
            return;
        }
        consumer.accept(0, decodeNodeId(heap.get(1)));
        if (size == 1) {
            return;
        }
        if (size == 2) {
            consumer.accept(1, decodeNodeId(heap.get(2)));
            return;
        }
        long left = heap.get(2);
        long right = heap.get(3);
        if (left <= right) {
            consumer.accept(1, decodeNodeId(left));
            findThirdPlace(consumer, 2, 3, right);
        } else {
            consumer.accept(1, decodeNodeId(right));
            findThirdPlace(consumer, 3, 2, left);
        }
    }

    /**
     * Reports rank 2: the smallest stored value among the runner-up's sibling and the
     * runner-up's children. In a heap, the third entry must be one of these.
     */
    private void findThirdPlace(IntIntConsumer consumer, int secondIndex, int siblingIndex,
                                long siblingValue) {
        int size = heap.size();
        int bestIndex = siblingIndex;
        long bestValue = siblingValue;
        int leftChild = 2 * secondIndex;
        int rightChild = leftChild + 1;
        if (leftChild <= size) {
            long candidate = heap.get(leftChild);
            if (candidate < bestValue) {
                bestValue = candidate;
                bestIndex = leftChild;
            }
        }
        if (rightChild <= size) {
            long candidate = heap.get(rightChild);
            if (candidate < bestValue) {
                bestIndex = rightChild;
            }
        }
        consumer.accept(2, decodeNodeId(heap.get(bestIndex)));
    }
}
