package org.apache.flink.runtime.state.heap;

import java.util.HashMap;
import java.util.Set;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.flink.runtime.state.KeyExtractorFunction;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
import org.apache.flink.runtime.state.KeyGroupedInternalPriorityQueue;
import org.apache.flink.runtime.state.PriorityComparator;
import org.apache.flink.runtime.state.heap.HeapPriorityQueueElement;
import org.apache.flink.util.Preconditions;

/**
 * A heap-based priority queue that additionally behaves like a set: an element (by {@code
 * equals}/{@code hashCode}) is contained at most once.
 *
 * <p>Deduplication uses one {@link HashMap} per key group of the local {@link KeyGroupRange}. An
 * element's key group is computed from the key returned by the {@link KeyExtractorFunction},
 * via {@link KeyGroupRangeAssignment#assignToKeyGroup}, using the total number of key groups.
 * Keeping one map per key group lets {@link #getSubsetForKeyGroup(int)} return the elements of a
 * single key group directly from its map.
 *
 * <p>This class performs no synchronization of its own.
 *
 * @param <T> type of the contained elements.
 */
public class HeapPriorityQueueSet<T extends HeapPriorityQueueElement> extends HeapPriorityQueue<T>
        implements KeyGroupedInternalPriorityQueue<T> {

    /** Extracts the key of an element; the key determines the element's key group. */
    private final KeyExtractorFunction<T> keyExtractor;

    /**
     * One deduplication map per key group of {@link #keyGroupRange}, indexed by the key group's
     * offset from the start of the range. Every element is stored as both key and value, so the
     * stored instance can be looked up with an equal (but possibly different) probe object.
     */
    private final HashMap<T, T>[] deduplicationMapsByKeyGroup;

    /** The key groups this queue holds elements for. */
    private final KeyGroupRange keyGroupRange;

    /** Total number of key groups, as passed to key-group assignment. */
    private final int totalNumberOfKeyGroups;

    /**
     * Creates an empty queue.
     *
     * @param elementPriorityComparator comparator handed to the underlying heap.
     * @param keyExtractor extracts an element's key, used to determine its key group.
     * @param minimumCapacity capacity hint; passed to the heap and spread evenly over the
     *     per-key-group deduplication maps.
     * @param keyGroupRange the key groups this queue is responsible for.
     * @param totalNumberOfKeyGroups total number of key groups used for key-group assignment.
     */
    public HeapPriorityQueueSet(
            @Nonnull PriorityComparator<T> elementPriorityComparator,
            @Nonnull KeyExtractorFunction<T> keyExtractor,
            @Nonnegative int minimumCapacity,
            @Nonnull KeyGroupRange keyGroupRange,
            @Nonnegative int totalNumberOfKeyGroups) {

        super(elementPriorityComparator, minimumCapacity);

        this.keyExtractor = keyExtractor;
        this.totalNumberOfKeyGroups = totalNumberOfKeyGroups;
        this.keyGroupRange = keyGroupRange;

        final int keyGroupsInLocalRange = keyGroupRange.getNumberOfKeyGroups();
        // Spread the requested capacity over the per-key-group maps. The divisor is guarded so
        // that an empty key-group range yields an empty map array instead of an
        // ArithmeticException (division by zero).
        final int deduplicationSetSize =
                1 + minimumCapacity / Math.max(1, keyGroupsInLocalRange);

        // Java cannot create generic arrays directly. The cast is safe because the array is
        // private and only ever holds HashMap<T, T> instances.
        @SuppressWarnings("unchecked")
        final HashMap<T, T>[] maps = (HashMap<T, T>[]) new HashMap[keyGroupsInLocalRange];
        for (int i = 0; i < keyGroupsInLocalRange; ++i) {
            maps[i] = new HashMap<>(deduplicationSetSize);
        }
        this.deduplicationMapsByKeyGroup = maps;
    }

    /**
     * Removes the head element from the heap and from its key group's deduplication map.
     *
     * @return the instance that was stored in the deduplication map for the polled element, or
     *     {@code null} if the heap returned no element.
     */
    @Override
    @Nullable
    public T poll() {
        final T polled = super.poll();
        return polled != null ? getDedupMapForElement(polled).remove(polled) : null;
    }

    /**
     * Adds the element unless an equal element is already contained.
     *
     * <p>The heap is only touched when the element was newly inserted into its deduplication map.
     * Because of the short-circuit, a duplicate never reaches the heap.
     *
     * @param element the element to add.
     * @return {@code false} if an equal element was already present; otherwise the result of the
     *     underlying heap's {@code add}. NOTE(review): per the InternalPriorityQueue contract that
     *     result presumably reports whether the head element changed — confirm.
     */
    @Override
    public boolean add(@Nonnull T element) {
        return getDedupMapForElement(element).putIfAbsent(element, element) == null
                && super.add(element);
    }

    /**
     * Removes the element equal to {@code toRemove}, if contained.
     *
     * <p>The heap is asked to remove the instance that was actually stored, not the probe. The
     * probe may be a different (merely equal) object — presumably the heap locates elements via
     * per-instance state, so the stored instance is required; confirm.
     *
     * @param toRemove an element equal to the one to remove.
     * @return {@code false} if no equal element was contained; otherwise the result of the
     *     underlying heap's {@code remove}.
     */
    @Override
    public boolean remove(@Nonnull T toRemove) {
        final T storedElement = getDedupMapForElement(toRemove).remove(toRemove);
        return storedElement != null && super.remove(storedElement);
    }

    /** Clears the heap and every per-key-group deduplication map. */
    @Override
    public void clear() {
        super.clear();
        for (HashMap<T, T> elementMap : deduplicationMapsByKeyGroup) {
            elementMap.clear();
        }
    }

    /** Returns the deduplication map for a global key group id that must lie in the local range. */
    private HashMap<T, T> getDedupMapForKeyGroup(@Nonnegative int keyGroupId) {
        return deduplicationMapsByKeyGroup[globalKeyGroupToLocalIndex(keyGroupId)];
    }

    /** Returns the deduplication map of the key group the element's key is assigned to. */
    private HashMap<T, T> getDedupMapForElement(T element) {
        final int keyGroup =
                KeyGroupRangeAssignment.assignToKeyGroup(
                        keyExtractor.extractKeyFromElement(element), totalNumberOfKeyGroups);
        return getDedupMapForKeyGroup(keyGroup);
    }

    /**
     * Converts a global key group id into an index into {@link #deduplicationMapsByKeyGroup}.
     * Rejects key groups outside the local range via {@link Preconditions#checkArgument}.
     */
    private int globalKeyGroupToLocalIndex(int keyGroup) {
        Preconditions.checkArgument(
                keyGroupRange.contains(keyGroup),
                "Key group %s is not in %s.",
                keyGroup,
                keyGroupRange);
        return keyGroup - keyGroupRange.getStartKeyGroup();
    }

    /**
     * Returns the elements of the given key group.
     *
     * <p>The result is the live {@code keySet()} view of that key group's deduplication map.
     * Removing elements through it would bypass the heap, so callers should treat it as
     * read-only.
     *
     * @param keyGroupId a global key group id within this queue's key-group range.
     * @return the set of elements currently held for the key group.
     */
    @Override
    @Nonnull
    public Set<T> getSubsetForKeyGroup(int keyGroupId) {
        return getDedupMapForKeyGroup(keyGroupId).keySet();
    }
}
