package org.apache.flink.runtime.checkpoint;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.runtime.executiongraph.Execution;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.state.ChainedStateHandle;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.StreamStateHandle;
import org.apache.flink.runtime.state.TaskStateHandles;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/flink/runtime/checkpoint/StateAssignmentOperation.class */
/**
 * Assigns the operator states of a restored checkpoint to the current execution attempts of the
 * given job vertices, re-distributing operator and keyed state when the parallelism of a vertex
 * differs from the parallelism recorded in the state.
 */
public class StateAssignmentOperation {
    private static final Logger LOG = LoggerFactory.getLogger(StateAssignmentOperation.class);

    /** Job vertices that receive state, keyed by their vertex ID. */
    private final Map<JobVertexID, ExecutionJobVertex> tasks;

    /** Restored operator states, keyed by the operator ID they were stored under. */
    private final Map<OperatorID, OperatorState> operatorStates;

    /** Whether state without a matching operator is tolerated (logged) instead of rejected. */
    private final boolean allowNonRestoredState;

    /**
     * Creates a new assignment operation.
     *
     * @param tasks job vertices to assign state to; must not be null
     * @param operatorStates restored operator states; must not be null
     * @param allowNonRestoredState if {@code true}, state that maps to no operator is skipped
     *     instead of causing a failure
     */
    public StateAssignmentOperation(
            Map<JobVertexID, ExecutionJobVertex> tasks,
            Map<OperatorID, OperatorState> operatorStates,
            boolean allowNonRestoredState) {
        this.tasks = Preconditions.checkNotNull(tasks);
        this.operatorStates = Preconditions.checkNotNull(operatorStates);
        this.allowNonRestoredState = allowNonRestoredState;
    }

    /**
     * Verifies that the restored state can be mapped to the job vertices and then assigns it,
     * vertex by vertex. Vertices for which none of the operators has restored state are skipped.
     *
     * @return always {@code true}
     * @throws IllegalStateException if state cannot be mapped to an operator (and non-restored
     *     state is not allowed) or if the parallelism preconditions are violated
     */
    public boolean assignStates() throws Exception {
        // Work on a copy so that consuming entries does not modify the map handed to us.
        Map<OperatorID, OperatorState> unassignedStates = new HashMap<>(operatorStates);

        checkStateMappingCompleteness(allowNonRestoredState, operatorStates, tasks);

        for (ExecutionJobVertex vertex : tasks.values()) {
            List<OperatorID> generatedIds = vertex.getOperatorIDs();
            List<OperatorID> userDefinedIds = vertex.getUserDefinedOperatorIDs();

            List<OperatorState> chainStates = new ArrayList<>();
            boolean anyRestoredState = false;

            for (int pos = 0; pos < generatedIds.size(); pos++) {
                // A user-defined ID, when present, takes precedence over the generated one.
                OperatorID userDefinedId = userDefinedIds.get(pos);
                OperatorID effectiveId = userDefinedId != null ? userDefinedId : generatedIds.get(pos);

                OperatorState restored = unassignedStates.remove(effectiveId);
                if (restored != null) {
                    anyRestoredState = true;
                } else {
                    // No restored state for this operator: use an empty placeholder so that the
                    // list stays aligned with the operator chain.
                    restored = new OperatorState(
                            effectiveId, vertex.getParallelism(), vertex.getMaxParallelism());
                }
                chainStates.add(restored);
            }

            if (anyRestoredState) {
                assignAttemptState(vertex, chainStates);
            }
        }
        return true;
    }

    /**
     * Assigns the given per-operator states to every subtask of the vertex by setting the
     * initial state on the subtask's current execution attempt.
     *
     * @param executionJobVertex vertex whose subtasks receive the state
     * @param chainStates one state per operator of the vertex, in operator order
     */
    private void assignAttemptState(ExecutionJobVertex executionJobVertex, List<OperatorState> chainStates) {
        List<OperatorID> operatorIDs = executionJobVertex.getOperatorIDs();

        // Must run before the max parallelism is read below: the check may override it.
        checkParallelismPreconditions(chainStates, executionJobVertex);

        int newParallelism = executionJobVertex.getParallelism();
        List<KeyGroupRange> keyGroupPartitions =
                createKeyGroupPartitions(executionJobVertex.getMaxParallelism(), newParallelism);

        // Indexed as [operator][subtask] -> handles; an operator entry is null if it has no state.
        List<List<Collection<OperatorStateHandle>>> managedByOperator = new ArrayList<>();
        List<List<Collection<OperatorStateHandle>>> rawByOperator = new ArrayList<>();
        reDistributePartitionableStates(chainStates, newParallelism, managedByOperator, rawByOperator);

        int chainLength = operatorIDs.size();
        for (int subtask = 0; subtask < newParallelism; subtask++) {
            Execution attempt = executionJobVertex.getTaskVertices()[subtask].getCurrentExecutionAttempt();

            List<StreamStateHandle> legacyStates = new ArrayList<>();
            List<Collection<OperatorStateHandle>> managedStates = new ArrayList<>();
            List<Collection<OperatorStateHandle>> rawStates = new ArrayList<>();
            Tuple2<Collection<KeyedStateHandle>, Collection<KeyedStateHandle>> keyedStates = null;

            for (int chainIndex = 0; chainIndex < chainLength; chainIndex++) {
                OperatorState operatorState = chainStates.get(chainIndex);
                int oldParallelism = operatorState.getParallelism();

                reAssignSubNonPartitionedStates(
                        operatorState, subtask, newParallelism, oldParallelism, legacyStates);
                reAssignSubPartitionableState(
                        managedByOperator, rawByOperator, subtask, chainIndex, managedStates, rawStates);

                // Keyed state is taken only from the operator at the last chain index.
                if (chainIndex == chainLength - 1) {
                    keyedStates = reAssignSubKeyedStates(
                            operatorState, keyGroupPartitions, subtask, newParallelism, oldParallelism);
                }
            }

            boolean nothingToRestore = keyedStates == null
                    && allElementsAreNull(legacyStates)
                    && allElementsAreNull(managedStates)
                    && allElementsAreNull(rawStates);
            if (nothingToRestore) {
                continue;
            }

            Collection<KeyedStateHandle> managedKeyed = keyedStates != null ? keyedStates.f0 : null;
            Collection<KeyedStateHandle> rawKeyed = keyedStates != null ? keyedStates.f1 : null;
            attempt.setInitialState(new TaskStateHandles(
                    new ChainedStateHandle<>(legacyStates),
                    managedStates,
                    rawStates,
                    managedKeyed,
                    rawKeyed));
        }
    }

    /**
     * Runs the parallelism precondition check for every given operator state.
     *
     * @param operatorStates states to check
     * @param executionJobVertex vertex the states are about to be assigned to
     * @throws IllegalStateException if any state violates the preconditions
     */
    public void checkParallelismPreconditions(List<OperatorState> operatorStates, ExecutionJobVertex executionJobVertex) {
        for (OperatorState operatorState : operatorStates) {
            checkParallelismPreconditions(operatorState, executionJobVertex);
        }
    }

    /**
     * Appends the managed and raw operator state of one operator/subtask pair to the output
     * lists, or {@code null} where the operator has no such state.
     *
     * @param managedByOperator re-distributed managed state, indexed [operator][subtask]
     * @param rawByOperator re-distributed raw state, indexed [operator][subtask]
     * @param subTaskIndex subtask to pick
     * @param operatorIndex operator to pick
     * @param managedOut receives the managed state entry
     * @param rawOut receives the raw state entry
     */
    private void reAssignSubPartitionableState(
            List<List<Collection<OperatorStateHandle>>> managedByOperator,
            List<List<Collection<OperatorStateHandle>>> rawByOperator,
            int subTaskIndex,
            int operatorIndex,
            List<Collection<OperatorStateHandle>> managedOut,
            List<Collection<OperatorStateHandle>> rawOut) {
        // NOTE(review): collectPartionableStates only keeps non-null handles, and the
        // same-parallelism shortcut in applyRepartitioner emits one entry per kept handle. If only
        // some subtasks carried operator state, the lookup by subtask index below may be
        // misaligned or out of bounds -- confirm against the callers/repartitioner contract.
        List<Collection<OperatorStateHandle>> managedForOperator = managedByOperator.get(operatorIndex);
        managedOut.add(managedForOperator != null ? managedForOperator.get(subTaskIndex) : null);

        List<Collection<OperatorStateHandle>> rawForOperator = rawByOperator.get(operatorIndex);
        rawOut.add(rawForOperator != null ? rawForOperator.get(subTaskIndex) : null);
    }

    /**
     * Determines the keyed state for one subtask. If the parallelism changed, the handles are
     * obtained by intersecting all old subtask states with the subtask's key-group range;
     * otherwise the state of the same subtask index is reused as is.
     *
     * @param operatorState state of the operator
     * @param keyGroupPartitions key-group range per new subtask index
     * @param subTaskIndex index of the subtask to compute the state for
     * @param newParallelism parallelism of the vertex
     * @param oldParallelism parallelism recorded in the state
     * @return (managed, raw) keyed state handles, or {@code null} if there are neither
     */
    private Tuple2<Collection<KeyedStateHandle>, Collection<KeyedStateHandle>> reAssignSubKeyedStates(
            OperatorState operatorState,
            List<KeyGroupRange> keyGroupPartitions,
            int subTaskIndex,
            int newParallelism,
            int oldParallelism) {
        List<KeyedStateHandle> managedHandles = null;
        List<KeyedStateHandle> rawHandles = null;

        if (newParallelism != oldParallelism) {
            KeyGroupRange subtaskRange = keyGroupPartitions.get(subTaskIndex);
            managedHandles = getManagedKeyedStateHandles(operatorState, subtaskRange);
            rawHandles = getRawKeyedStateHandles(operatorState, subtaskRange);
        } else {
            OperatorSubtaskState subtaskState = operatorState.getState(subTaskIndex);
            if (subtaskState != null) {
                KeyedStateHandle managed = subtaskState.getManagedKeyedState();
                KeyedStateHandle raw = subtaskState.getRawKeyedState();
                if (managed != null) {
                    managedHandles = Collections.singletonList(managed);
                }
                if (raw != null) {
                    rawHandles = Collections.singletonList(raw);
                }
            }
        }

        if (managedHandles == null && rawHandles == null) {
            return null;
        }
        return new Tuple2<>(managedHandles, rawHandles);
    }

    /**
     * Tells whether the list contains no non-null element (true for an empty list).
     *
     * @param elements list to inspect
     * @return {@code true} if every element is {@code null}
     */
    private <X> boolean allElementsAreNull(List<X> elements) {
        for (X element : elements) {
            if (element != null) {
                return false;
            }
        }
        return true;
    }

    /**
     * Appends the legacy (non-partitioned) operator state for one subtask to {@code out}, or
     * {@code null} if the parallelism changed or the subtask has no state.
     *
     * @param operatorState state of the operator
     * @param subTaskIndex index of the subtask
     * @param newParallelism parallelism of the vertex
     * @param oldParallelism parallelism recorded in the state
     * @param out receives exactly one entry
     */
    private void reAssignSubNonPartitionedStates(
            OperatorState operatorState,
            int subTaskIndex,
            int newParallelism,
            int oldParallelism,
            List<StreamStateHandle> out) {
        StreamStateHandle legacyState = null;
        if (oldParallelism == newParallelism) {
            OperatorSubtaskState subtaskState = operatorState.getState(subTaskIndex);
            if (subtaskState != null) {
                legacyState = subtaskState.getLegacyOperatorState();
            }
        }
        out.add(legacyState);
    }

    /**
     * Collects the managed and raw operator state of every operator and re-distributes it to
     * {@code newParallelism} subtasks using the round-robin repartitioner.
     *
     * @param operatorStates one state per operator
     * @param newParallelism parallelism of the vertex
     * @param managedOut receives one entry per operator (null if it has no managed state)
     * @param rawOut receives one entry per operator (null if it has no raw state)
     */
    private void reDistributePartitionableStates(
            List<OperatorState> operatorStates,
            int newParallelism,
            List<List<Collection<OperatorStateHandle>>> managedOut,
            List<List<Collection<OperatorStateHandle>>> rawOut) {
        List<List<OperatorStateHandle>> collectedManaged = new ArrayList<>();
        List<List<OperatorStateHandle>> collectedRaw = new ArrayList<>();
        collectPartionableStates(operatorStates, collectedManaged, collectedRaw);

        OperatorStateRepartitioner repartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE;
        for (int operatorIndex = 0; operatorIndex < operatorStates.size(); operatorIndex++) {
            int oldParallelism = operatorStates.get(operatorIndex).getParallelism();
            managedOut.add(applyRepartitioner(
                    repartitioner, collectedManaged.get(operatorIndex), oldParallelism, newParallelism));
            rawOut.add(applyRepartitioner(
                    repartitioner, collectedRaw.get(operatorIndex), oldParallelism, newParallelism));
        }
    }

    /**
     * Gathers, per operator, the non-null managed and raw operator state handles of all its
     * subtasks. An operator without any handle of a kind contributes {@code null} for that kind.
     *
     * @param operatorStates states to collect from
     * @param managedOut receives one (possibly null) list per operator
     * @param rawOut receives one (possibly null) list per operator
     */
    private void collectPartionableStates(
            List<OperatorState> operatorStates,
            List<List<OperatorStateHandle>> managedOut,
            List<List<OperatorStateHandle>> rawOut) {
        for (OperatorState operatorState : operatorStates) {
            // Lists are created lazily so that "no state at all" is represented as null.
            List<OperatorStateHandle> managedHandles = null;
            List<OperatorStateHandle> rawHandles = null;

            for (int subtask = 0; subtask < operatorState.getParallelism(); subtask++) {
                OperatorSubtaskState subtaskState = operatorState.getState(subtask);
                if (subtaskState == null) {
                    continue;
                }
                if (subtaskState.getManagedOperatorState() != null) {
                    if (managedHandles == null) {
                        managedHandles = new ArrayList<>();
                    }
                    managedHandles.add(subtaskState.getManagedOperatorState());
                }
                if (subtaskState.getRawOperatorState() != null) {
                    if (rawHandles == null) {
                        rawHandles = new ArrayList<>();
                    }
                    rawHandles.add(subtaskState.getRawOperatorState());
                }
            }
            managedOut.add(managedHandles);
            rawOut.add(rawHandles);
        }
    }

    /**
     * Collects the parts of the managed keyed state of all subtasks that intersect with the
     * given key-group range.
     *
     * @param operatorState state to read from
     * @param subtaskKeyGroupRange range to intersect with
     * @return the intersecting handles, or {@code null} if there is none
     */
    public static List<KeyedStateHandle> getManagedKeyedStateHandles(
            OperatorState operatorState, KeyGroupRange subtaskKeyGroupRange) {
        List<KeyedStateHandle> result = null;
        for (int subtask = 0; subtask < operatorState.getParallelism(); subtask++) {
            OperatorSubtaskState subtaskState = operatorState.getState(subtask);
            if (subtaskState == null || subtaskState.getManagedKeyedState() == null) {
                continue;
            }
            KeyedStateHandle intersection =
                    subtaskState.getManagedKeyedState().getIntersection(subtaskKeyGroupRange);
            if (intersection == null) {
                continue;
            }
            if (result == null) {
                result = new ArrayList<>();
            }
            result.add(intersection);
        }
        return result;
    }

    /**
     * Collects the parts of the raw keyed state of all subtasks that intersect with the given
     * key-group range.
     *
     * @param operatorState state to read from
     * @param subtaskKeyGroupRange range to intersect with
     * @return the intersecting handles, or {@code null} if there is none
     */
    public static List<KeyedStateHandle> getRawKeyedStateHandles(
            OperatorState operatorState, KeyGroupRange subtaskKeyGroupRange) {
        List<KeyedStateHandle> result = null;
        for (int subtask = 0; subtask < operatorState.getParallelism(); subtask++) {
            OperatorSubtaskState subtaskState = operatorState.getState(subtask);
            if (subtaskState == null || subtaskState.getRawKeyedState() == null) {
                continue;
            }
            KeyedStateHandle intersection =
                    subtaskState.getRawKeyedState().getIntersection(subtaskKeyGroupRange);
            if (intersection == null) {
                continue;
            }
            if (result == null) {
                result = new ArrayList<>();
            }
            result.add(intersection);
        }
        return result;
    }

    /**
     * Computes the key-group range of every operator index for the given parallelism.
     *
     * @param numberKeyGroups total number of key groups; must be at least {@code parallelism}
     * @param parallelism number of operator indexes to create ranges for
     * @return one range per operator index, in index order
     * @throws IllegalArgumentException if {@code numberKeyGroups < parallelism}
     */
    public static List<KeyGroupRange> createKeyGroupPartitions(int numberKeyGroups, int parallelism) {
        Preconditions.checkArgument(numberKeyGroups >= parallelism);
        List<KeyGroupRange> partitions = new ArrayList<>(parallelism);
        for (int operatorIndex = 0; operatorIndex < parallelism; operatorIndex++) {
            partitions.add(KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(
                    numberKeyGroups, parallelism, operatorIndex));
        }
        return partitions;
    }

    /**
     * Checks that the given state is compatible with the vertex it is assigned to. If the max
     * parallelism differs and was not explicitly configured on the vertex, the vertex is updated
     * to the max parallelism of the state.
     *
     * @param operatorState restored state
     * @param executionJobVertex vertex receiving the state
     * @throws IllegalStateException if the configured max parallelism differs from the state's,
     *     or if the state has non-partitioned state and the parallelism changed
     */
    private static void checkParallelismPreconditions(OperatorState operatorState, ExecutionJobVertex executionJobVertex) {
        if (operatorState.getMaxParallelism() != executionJobVertex.getMaxParallelism()) {
            if (executionJobVertex.isMaxParallelismConfigured()) {
                throw new IllegalStateException("The maximum parallelism (" + operatorState.getMaxParallelism() + ") with which the latest checkpoint of the execution job vertex " + executionJobVertex + " has been taken and the current maximum parallelism (" + executionJobVertex.getMaxParallelism() + ") changed. This is currently not supported.");
            }
            LOG.debug("Overriding maximum parallelism for JobVertex {} from {} to {}",
                    executionJobVertex.getJobVertexId(),
                    executionJobVertex.getMaxParallelism(),
                    operatorState.getMaxParallelism());
            executionJobVertex.setMaxParallelism(operatorState.getMaxParallelism());
        }

        int oldParallelism = operatorState.getParallelism();
        int newParallelism = executionJobVertex.getParallelism();
        if (operatorState.hasNonPartitionedState() && oldParallelism != newParallelism) {
            throw new IllegalStateException("Cannot restore the latest checkpoint because the operator " + executionJobVertex.getJobVertexId() + " has non-partitioned state and its parallelism changed. The operator " + executionJobVertex.getJobVertexId() + " has parallelism " + newParallelism + " whereas the corresponding state object has a parallelism of " + oldParallelism);
        }
    }

    /**
     * Verifies that every restored operator state belongs to an operator of one of the given
     * vertices. An operator matches if the state's key equals either one of the vertex's operator
     * IDs or one of its non-null user-defined operator IDs, which is the same set of IDs that
     * {@link #assignStates()} uses to look up state.
     *
     * @param allowNonRestoredState if {@code true}, unmatched state is only logged
     * @param operatorStates restored states keyed by operator ID
     * @param tasks vertices of the job
     * @throws IllegalStateException if state is unmatched and {@code allowNonRestoredState} is
     *     {@code false}
     */
    private static void checkStateMappingCompleteness(
            boolean allowNonRestoredState,
            Map<OperatorID, OperatorState> operatorStates,
            Map<JobVertexID, ExecutionJobVertex> tasks) {
        HashSet<OperatorID> knownOperatorIds = new HashSet<>();
        for (ExecutionJobVertex vertex : tasks.values()) {
            knownOperatorIds.addAll(vertex.getOperatorIDs());
            // State may be stored under a user-defined ID; entries are null where none was set.
            for (OperatorID userDefinedId : vertex.getUserDefinedOperatorIDs()) {
                if (userDefinedId != null) {
                    knownOperatorIds.add(userDefinedId);
                }
            }
        }

        for (Map.Entry<OperatorID, OperatorState> stateEntry : operatorStates.entrySet()) {
            if (knownOperatorIds.contains(stateEntry.getKey())) {
                continue;
            }
            OperatorState unmatched = stateEntry.getValue();
            if (!allowNonRestoredState) {
                throw new IllegalStateException("There is no operator for the state " + unmatched.getOperatorID());
            }
            LOG.info("Skipped checkpoint state for operator {}.", unmatched.getOperatorID());
        }
    }

    /**
     * Distributes the operator state handles of one operator to {@code newParallelism} subtasks.
     * If the parallelism is unchanged and no handle contains broadcast-mode state, each handle is
     * passed through as a singleton collection; otherwise the repartitioner is used.
     *
     * @param opStateRepartitioner repartitioner to delegate to
     * @param chainOpParallelStates handles of the operator, or {@code null}
     * @param oldParallelism parallelism recorded in the state
     * @param newParallelism parallelism of the vertex
     * @return one collection of handles per entry, or {@code null} if the input was {@code null}
     */
    public static List<Collection<OperatorStateHandle>> applyRepartitioner(
            OperatorStateRepartitioner opStateRepartitioner,
            List<OperatorStateHandle> chainOpParallelStates,
            int oldParallelism,
            int newParallelism) {
        if (chainOpParallelStates == null) {
            return null;
        }
        if (newParallelism != oldParallelism) {
            return opStateRepartitioner.repartitionState(chainOpParallelStates, newParallelism);
        }

        List<Collection<OperatorStateHandle>> repacked = new ArrayList<>(newParallelism);
        for (OperatorStateHandle handle : chainOpParallelStates) {
            for (OperatorStateHandle.StateMetaInfo metaInfo : handle.getStateNameToPartitionOffsets().values()) {
                // Broadcast state cannot take the pass-through shortcut.
                if (OperatorStateHandle.Mode.BROADCAST.equals(metaInfo.getDistributionMode())) {
                    return opStateRepartitioner.repartitionState(chainOpParallelStates, newParallelism);
                }
            }
            repacked.add(Collections.singletonList(handle));
        }
        return repacked;
    }

    /**
     * Intersects every given keyed state handle with the key-group range and returns the
     * non-null intersections.
     *
     * @param keyedStateHandles handles to intersect
     * @param subtaskKeyGroupRange range to intersect with
     * @return the non-null intersections, in iteration order; never {@code null}
     */
    public static List<KeyedStateHandle> getKeyedStateHandles(
            Collection<? extends KeyedStateHandle> keyedStateHandles, KeyGroupRange subtaskKeyGroupRange) {
        List<KeyedStateHandle> intersections = new ArrayList<>();
        for (KeyedStateHandle handle : keyedStateHandles) {
            KeyedStateHandle intersection = handle.getIntersection(subtaskKeyGroupRange);
            if (intersection != null) {
                intersections.add(intersection);
            }
        }
        return intersections;
    }
}
