package io.prestosql.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import io.prestosql.Session;
import io.prestosql.SystemSessionProperties;
import io.prestosql.matching.Capture;
import io.prestosql.matching.Captures;
import io.prestosql.matching.Pattern;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.optimizations.SymbolMapper;
import io.prestosql.sql.planner.plan.Patterns;
import io.prestosql.sql.planner.plan.TableWriterNode;
import io.prestosql.sql.planner.plan.UnionNode;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/* loaded from: input_file:io/prestosql/sql/planner/iterative/rule/PushTableWriteThroughUnion.class */
/**
 * Optimizer rule that rewrites a {@link TableWriterNode} sitting directly on top of a
 * {@link UnionNode} into a {@link UnionNode} of table writers, one writer per original
 * union source.
 *
 * <p>The rule only matches table writers that have no partitioning scheme, and it is only
 * active when {@link SystemSessionProperties#isPushTableWriteThroughUnion(Session)} returns
 * {@code true} for the session.
 */
public class PushTableWriteThroughUnion implements Rule<TableWriterNode> {
    private static final Capture<UnionNode> CHILD = Capture.newCapture();

    // Match: TableWriterNode (without a partitioning scheme) whose source is a UnionNode.
    private static final Pattern<TableWriterNode> PATTERN = Patterns.tableWriterNode()
            .matching(tableWriterNode -> tableWriterNode.getPartitioningScheme().isEmpty())
            .with(Patterns.source().matching(Patterns.union().capturedAs(CHILD)));

    /**
     * @return the pattern selecting un-partitioned table writers whose source is a union
     */
    @Override
    public Pattern<TableWriterNode> getPattern() {
        return PATTERN;
    }

    /**
     * @param session the session whose system properties are consulted
     * @return whether the push-table-write-through-union property is enabled for {@code session}
     */
    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.isPushTableWriteThroughUnion(session);
    }

    /**
     * Replaces {@code writer(union(s1, ..., sn))} with {@code union(writer(s1), ..., writer(sn))}.
     *
     * @param tableWriterNode the matched table writer
     * @param captures captures holding the union node that feeds the writer
     * @param context rule context supplying the plan node id and symbol allocators
     * @return a result wrapping the new union node
     */
    @Override
    public Rule.Result apply(TableWriterNode tableWriterNode, Captures captures, Rule.Context context) {
        UnionNode unionNode = captures.get(CHILD);
        int sourceCount = unionNode.getSources().size();

        List<TableWriterNode> rewrittenSources = new ArrayList<>(sourceCount);
        // One entry per union source: writer output symbol -> symbol produced by that source's writer.
        List<Map<Symbol, Symbol>> sourceMappings = new ArrayList<>(sourceCount);
        for (int source = 0; source < sourceCount; source++) {
            rewrittenSources.add(rewriteSource(tableWriterNode, unionNode, source, sourceMappings, context));
        }

        // Fold the per-source mappings into the output -> inputs multimap of the new union.
        ImmutableListMultimap.Builder<Symbol, Symbol> unionMappings = ImmutableListMultimap.builder();
        sourceMappings.forEach(mappings -> mappings.forEach(unionMappings::put));
        ImmutableListMultimap<Symbol, Symbol> outputToInputs = unionMappings.build();

        // copyOf lets the element type be inferred from the UnionNode constructor parameter,
        // so the writers can be collected with their precise type above.
        return Rule.Result.ofPlanNode(new UnionNode(
                context.getIdAllocator().getNextId(),
                ImmutableList.copyOf(rewrittenSources),
                outputToInputs,
                ImmutableList.copyOf(outputToInputs.keySet())));
    }

    /**
     * Builds the table writer for the {@code i}-th source of the union.
     *
     * <p>Writer output symbols that are also union outputs are mapped to the corresponding
     * input symbol of source {@code i}; every other writer output symbol gets a freshly
     * allocated symbol. The resulting output mapping is appended to {@code list}.
     *
     * @param tableWriterNode the original table writer
     * @param unionNode the union feeding the original writer
     * @param i index of the union source to rewrite
     * @param list accumulator receiving this source's output symbol mapping (mutated)
     * @param context rule context supplying the plan node id and symbol allocators
     * @return the table writer re-mapped onto source {@code i}
     */
    private static TableWriterNode rewriteSource(TableWriterNode tableWriterNode, UnionNode unionNode, int i, List<Map<Symbol, Symbol>> list, Rule.Context context) {
        Map<Symbol, Symbol> inputMappings = getInputSymbolMapping(unionNode, i);

        // Full mapping handed to the SymbolMapper: union outputs plus newly allocated writer outputs.
        ImmutableMap.Builder<Symbol, Symbol> mappings = ImmutableMap.builder();
        mappings.putAll(inputMappings);

        // Mapping restricted to the writer's output symbols; reported back through 'list'.
        ImmutableMap.Builder<Symbol, Symbol> outputMappings = ImmutableMap.builder();
        for (Symbol outputSymbol : tableWriterNode.getOutputSymbols()) {
            if (inputMappings.containsKey(outputSymbol)) {
                outputMappings.put(outputSymbol, inputMappings.get(outputSymbol));
            } else {
                Symbol newSymbol = context.getSymbolAllocator().newSymbol(outputSymbol);
                outputMappings.put(outputSymbol, newSymbol);
                mappings.put(outputSymbol, newSymbol);
            }
        }
        list.add(outputMappings.build());

        SymbolMapper symbolMapper = new SymbolMapper(mappings.build());
        return symbolMapper.map(tableWriterNode, unionNode.getSources().get(i), context.getIdAllocator().getNextId());
    }

    /**
     * @param unionNode the union whose symbol mapping is read
     * @param i index of the union source
     * @return a map from each key of the union's symbol mapping to the {@code i}-th symbol
     *         listed for that key
     */
    private static Map<Symbol, Symbol> getInputSymbolMapping(UnionNode unionNode, int i) {
        return unionNode.getSymbolMapping().keySet().stream()
                .collect(ImmutableMap.toImmutableMap(
                        symbol -> symbol,
                        symbol -> unionNode.getSymbolMapping().get(symbol).get(i)));
    }
}
