package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.HashBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.sql.planner.ExpressionNodeInliner;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.TypeAnalyzer;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.SubscriptExpression;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Optimizer rule that pushes row dereferences ({@link SubscriptExpression}s returned by
 * {@code DereferencePushdown.extractRowSubscripts}) found in a {@link ProjectNode} and in the
 * filter of the {@link JoinNode} directly beneath it down into the join's sources.
 *
 * <p>Each pushed dereference is assigned a fresh symbol. That symbol is computed by a new
 * {@link ProjectNode} placed on top of whichever join source (left or right) outputs the
 * dereference's base symbol. The original projection and the join filter are then rewritten to
 * reference the fresh symbols. Dereferences whose base symbol takes part in the join's
 * equi-join criteria are left in place.
 */
public class PushDownDereferenceThroughJoin implements Rule<ProjectNode> {
    private static final Capture<JoinNode> CHILD = Capture.newCapture();
    private final TypeAnalyzer typeAnalyzer;

    /**
     * @param typeAnalyzer analyzer passed to dereference extraction and symbol allocation; must not be null
     * @throws NullPointerException if {@code typeAnalyzer} is null
     */
    public PushDownDereferenceThroughJoin(TypeAnalyzer typeAnalyzer) {
        this.typeAnalyzer = Objects.requireNonNull(typeAnalyzer, "typeAnalyzer is null");
    }

    /** Matches a {@link ProjectNode} whose source is a {@link JoinNode}, capturing the join. */
    @Override
    public Pattern<ProjectNode> getPattern() {
        return Patterns.project().with(Patterns.source().matching(Patterns.join().capturedAs(CHILD)));
    }

    /**
     * Rewrites {@code Project(Join(L, R))} into
     * {@code Project'(Join'(Project_L(L), Project_R(R)))}, where the per-side projections compute
     * the pushed-down dereferences. {@code Project_L} and {@code Project_R} are only added when
     * that side receives at least one dereference.
     *
     * @return {@link Rule.Result#empty()} when no dereference is eligible for pushdown
     * @throws IllegalArgumentException if a dereference's base symbol is output by neither join source
     */
    @Override
    public Rule.Result apply(ProjectNode projectNode, Captures captures, Rule.Context context) {
        JoinNode joinNode = captures.get(CHILD);

        // Candidate dereferences come from both the projection and the join filter.
        ImmutableList.Builder<Expression> expressionsBuilder = ImmutableList.builder();
        expressionsBuilder.addAll(projectNode.getAssignments().getExpressions());
        joinNode.getFilter().ifPresent(expressionsBuilder::add);
        Set<SubscriptExpression> candidates = DereferencePushdown.extractRowSubscripts(
                expressionsBuilder.build(),
                false,
                context.getSession(),
                typeAnalyzer,
                context.getSymbolAllocator().getTypes());

        // Dereferences whose base symbol is an equi-join criteria symbol are not pushed down.
        Set<Symbol> criteriaSymbols = criteriaSymbols(joinNode);
        Set<SubscriptExpression> dereferences = candidates.stream()
                .filter(dereference -> !criteriaSymbols.contains(DereferencePushdown.getBase(dereference)))
                .collect(ImmutableSet.toImmutableSet());
        if (dereferences.isEmpty()) {
            return Rule.Result.empty();
        }

        // Allocate a fresh symbol per dereference, and build dereference -> symbol-reference mappings
        // used to rewrite both the projection and the join filter.
        Assignments dereferenceAssignments = Assignments.of(
                dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer);
        Map<Expression, Expression> mappings = HashBiMap.create(dereferenceAssignments.getMap())
                .inverse()
                .entrySet().stream()
                .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference()));
        Assignments newAssignments = projectNode.getAssignments()
                .rewrite(expression -> ExpressionNodeInliner.replaceExpression(expression, mappings));

        // Route each dereference to the join source that outputs its (single) base symbol.
        Assignments.Builder leftAssignments = Assignments.builder();
        Assignments.Builder rightAssignments = Assignments.builder();
        for (Map.Entry<Symbol, Expression> entry : dereferenceAssignments.entrySet()) {
            Symbol baseSymbol = Iterables.getOnlyElement(SymbolsExtractor.extractAll(entry.getValue()));
            if (joinNode.getLeft().getOutputSymbols().contains(baseSymbol)) {
                leftAssignments.put(entry.getKey(), entry.getValue());
            } else if (joinNode.getRight().getOutputSymbols().contains(baseSymbol)) {
                rightAssignments.put(entry.getKey(), entry.getValue());
            } else {
                throw new IllegalArgumentException(String.format("Unexpected symbol %s in projectNode", baseSymbol));
            }
        }

        PlanNode leftNode = createProjectNodeIfRequired(joinNode.getLeft(), leftAssignments.build(), context.getIdAllocator());
        PlanNode rightNode = createProjectNodeIfRequired(joinNode.getRight(), rightAssignments.build(), context.getIdAllocator());

        // The new join only outputs symbols the rewritten projection references. distinct() keeps a
        // symbol referenced by several assignments from appearing more than once in the join outputs.
        List<Symbol> referencedSymbols = newAssignments.getExpressions().stream()
                .flatMap(expression -> SymbolsExtractor.extractAll(expression).stream())
                .distinct()
                .collect(Collectors.toList());

        JoinNode newJoinNode = new JoinNode(
                context.getIdAllocator().getNextId(),
                joinNode.getType(),
                leftNode,
                rightNode,
                joinNode.getCriteria(),
                producedBy(leftNode, referencedSymbols),
                producedBy(rightNode, referencedSymbols),
                joinNode.isMaySkipOutputDuplicates(),
                // The filter must use the symbols now computed below the join.
                joinNode.getFilter().map(filter -> ExpressionNodeInliner.replaceExpression(filter, mappings)),
                joinNode.getLeftHashSymbol(),
                joinNode.getRightHashSymbol(),
                joinNode.getDistributionType(),
                joinNode.isSpillable(),
                joinNode.getDynamicFilters(),
                joinNode.getReorderJoinStatsAndCost());

        return Rule.Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), newJoinNode, newAssignments));
    }

    /** Collects the left and right symbols of every equi-join clause of {@code joinNode}. */
    private static Set<Symbol> criteriaSymbols(JoinNode joinNode) {
        ImmutableSet.Builder<Symbol> symbols = ImmutableSet.builder();
        joinNode.getCriteria().forEach(clause -> {
            symbols.add(clause.getLeft());
            symbols.add(clause.getRight());
        });
        return symbols.build();
    }

    /** Returns the elements of {@code symbols} (in order) that {@code node} outputs. */
    private static List<Symbol> producedBy(PlanNode node, List<Symbol> symbols) {
        Set<Symbol> outputs = ImmutableSet.copyOf(node.getOutputSymbols());
        return symbols.stream()
                .filter(outputs::contains)
                .collect(Collectors.toList());
    }

    /**
     * Wraps {@code planNode} in a {@link ProjectNode} that passes through all of its outputs and
     * additionally computes {@code assignments}; returns {@code planNode} unchanged when
     * {@code assignments} is empty.
     */
    private static PlanNode createProjectNodeIfRequired(PlanNode planNode, Assignments assignments, PlanNodeIdAllocator planNodeIdAllocator) {
        if (assignments.isEmpty()) {
            return planNode;
        }
        return new ProjectNode(
                planNodeIdAllocator.getNextId(),
                planNode,
                Assignments.builder()
                        .putIdentities(planNode.getOutputSymbols())
                        .putAll(assignments)
                        .build());
    }
}
