package io.prestosql.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.Iterables;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.iterative.Lookup;
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.SetOperationNode;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/* loaded from: input_file:io/prestosql/sql/planner/iterative/rule/SetOperationMerge.class */
/**
 * Flattens nested set operations of the same kind into a single {@link SetOperationNode}.
 *
 * <p>A source is considered mergeable when its runtime class is exactly the class of the node being
 * rewritten. A merged child is replaced by its own sources, and the parent's symbol mapping is
 * translated through the child's mapping. Sources that are not mergeable are kept unchanged.
 *
 * <p>Every call to {@link #merge()} or {@link #mergeFirstSource()} builds its result from scratch,
 * so one instance may be invoked repeatedly without sources leaking between calls.
 */
class SetOperationMerge {
    private final Rule.Context context;
    private final SetOperationNode node;
    private final SetOperationNodeInstantiator instantiator;

    public SetOperationMerge(SetOperationNode setOperationNode, Rule.Context context, SetOperationNodeInstantiator setOperationNodeInstantiator) {
        this.node = setOperationNode;
        this.context = context;
        this.instantiator = setOperationNodeInstantiator;
    }

    /**
     * Merges only the first source into this node when it is a set operation of the same class;
     * all other sources are kept as they are.
     *
     * <p>NOTE(review): presumably intended for operations such as EXCEPT, where only a nested first
     * operand can be flattened without changing the result. Confirm against callers.
     *
     * @return the merged node, or {@link Optional#empty()} if the first source is not mergeable
     */
    public Optional<SetOperationNode> mergeFirstSource() {
        List<PlanNode> sources = resolveSources();
        // Defensive guard: with no sources there is nothing to merge (avoids get(0) failing).
        if (sources.isEmpty() || !isMergeable(sources.get(0))) {
            return Optional.empty();
        }

        List<PlanNode> newSources = new ArrayList<>();
        ImmutableListMultimap.Builder<Symbol, Symbol> mappings = ImmutableListMultimap.builder();
        addMergedMappings((SetOperationNode) sources.get(0), 0, newSources, mappings);
        for (int i = 1; i < sources.size(); i++) {
            addOriginalMappings(sources.get(i), i, newSources, mappings);
        }
        return Optional.of(instantiator.create(node.getId(), newSources, mappings.build(), node.getOutputSymbols()));
    }

    /**
     * Merges every source that is a set operation of the same class into this node; other sources
     * are kept as they are.
     *
     * @return the merged node, or {@link Optional#empty()} if no source is mergeable
     */
    public Optional<SetOperationNode> merge() {
        List<PlanNode> sources = resolveSources();
        if (sources.stream().noneMatch(this::isMergeable)) {
            return Optional.empty();
        }

        List<PlanNode> newSources = new ArrayList<>();
        ImmutableListMultimap.Builder<Symbol, Symbol> mappings = ImmutableListMultimap.builder();
        for (int i = 0; i < sources.size(); i++) {
            PlanNode source = sources.get(i);
            if (isMergeable(source)) {
                addMergedMappings((SetOperationNode) source, i, newSources, mappings);
            } else {
                addOriginalMappings(source, i, newSources, mappings);
            }
        }
        return Optional.of(instantiator.create(node.getId(), newSources, mappings.build(), node.getOutputSymbols()));
    }

    /** Resolves each source of {@code node} through the lookup's groups into concrete plan nodes. */
    private List<PlanNode> resolveSources() {
        Lookup lookup = context.getLookup();
        return node.getSources().stream()
                .flatMap(lookup::resolveGroup)
                .collect(Collectors.toList());
    }

    /**
     * Returns whether {@code candidate} can be flattened into {@code node}. Exact class equality is
     * used consistently for both the pre-check and the per-source dispatch.
     */
    private boolean isMergeable(PlanNode candidate) {
        return node.getClass().equals(candidate.getClass());
    }

    /**
     * Replaces a mergeable child with its own sources. For each output symbol of {@code node}, the
     * child's output at position {@code sourceIndex} is translated into the child's input symbols.
     */
    private void addMergedMappings(SetOperationNode child, int sourceIndex, List<PlanNode> newSources, ImmutableListMultimap.Builder<Symbol, Symbol> mappings) {
        newSources.addAll(child.getSources());
        for (Map.Entry<Symbol, Collection<Symbol>> entry : node.getSymbolMapping().asMap().entrySet()) {
            Symbol childOutput = Iterables.get(entry.getValue(), sourceIndex);
            mappings.putAll(entry.getKey(), child.getSymbolMapping().get(childOutput));
        }
    }

    /** Keeps a non-mergeable source unchanged, along with its column of the original mapping. */
    private void addOriginalMappings(PlanNode source, int sourceIndex, List<PlanNode> newSources, ImmutableListMultimap.Builder<Symbol, Symbol> mappings) {
        newSources.add(source);
        for (Map.Entry<Symbol, Collection<Symbol>> entry : node.getSymbolMapping().asMap().entrySet()) {
            mappings.put(entry.getKey(), Iterables.get(entry.getValue(), sourceIndex));
        }
    }
}
