package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.Type;
import io.trino.sql.PlannerContext;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.optimizations.PlanNodeDecorrelator;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.AssignUniqueId;
import io.trino.sql.planner.plan.CorrelatedJoinNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.tree.BooleanLiteral;

import java.util.Objects;
import java.util.Optional;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/TransformCorrelatedDistinctAggregationWithoutProjection.class */
public class TransformCorrelatedDistinctAggregationWithoutProjection implements Rule<CorrelatedJoinNode> {
    private static final Capture<AggregationNode> AGGREGATION = Capture.newCapture();
    private static final Pattern<CorrelatedJoinNode> PATTERN = Patterns.correlatedJoin().with(Patterns.CorrelatedJoin.type().equalTo(CorrelatedJoinNode.Type.LEFT)).with(Pattern.nonEmpty(Patterns.CorrelatedJoin.correlation())).with(Patterns.CorrelatedJoin.filter().equalTo(BooleanLiteral.TRUE_LITERAL)).with(Patterns.CorrelatedJoin.subquery().matching(Patterns.aggregation().matching((v0) -> {
        return AggregationDecorrelation.isDistinctOperator(v0);
    }).capturedAs(AGGREGATION)));
    private final PlannerContext plannerContext;

    public TransformCorrelatedDistinctAggregationWithoutProjection(PlannerContext plannerContext) {
        this.plannerContext = (PlannerContext) Objects.requireNonNull(plannerContext, "plannerContext is null");
    }

    @Override // io.trino.sql.planner.iterative.Rule
    public Pattern<CorrelatedJoinNode> getPattern() {
        return PATTERN;
    }

    @Override // io.trino.sql.planner.iterative.Rule
    public Rule.Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Rule.Context context) {
        Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelateFilters = new PlanNodeDecorrelator(this.plannerContext, context.getSymbolAllocator(), context.getLookup()).decorrelateFilters(((AggregationNode) captures.get(AGGREGATION)).getSource(), correlatedJoinNode.getCorrelation());
        if (decorrelateFilters.isEmpty()) {
            return Rule.Result.empty();
        }
        PlanNode node = decorrelateFilters.get().getNode();
        AssignUniqueId assignUniqueId = new AssignUniqueId(context.getIdAllocator().getNextId(), correlatedJoinNode.getInput(), context.getSymbolAllocator().newSymbol("unique", (Type) BigintType.BIGINT));
        JoinNode joinNode = new JoinNode(context.getIdAllocator().getNextId(), JoinNode.Type.LEFT, assignUniqueId, node, ImmutableList.of(), assignUniqueId.getOutputSymbols(), node.getOutputSymbols(), false, decorrelateFilters.get().getCorrelatedPredicates(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), ImmutableMap.of(), Optional.empty());
        AggregationNode aggregationNode = (AggregationNode) captures.get(AGGREGATION);
        AggregationNode aggregationNode2 = new AggregationNode(aggregationNode.getId(), joinNode, aggregationNode.getAggregations(), AggregationNode.singleGroupingSet(ImmutableList.builder().addAll(joinNode.getLeftOutputSymbols()).addAll(aggregationNode.getGroupingKeys()).build()), ImmutableList.of(), aggregationNode.getStep(), Optional.empty(), Optional.empty());
        return Rule.Result.ofPlanNode(Util.restrictOutputs(context.getIdAllocator(), aggregationNode2, ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols())).orElse(aggregationNode2));
    }
}
