package io.prestosql.sql.planner.iterative.rule;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.prestosql.matching.Capture;
import io.prestosql.matching.Captures;
import io.prestosql.matching.Pattern;
import io.prestosql.metadata.Metadata;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.optimizations.ScalarAggregationToJoinRewriter;
import io.prestosql.sql.planner.plan.AggregationNode;
import io.prestosql.sql.planner.plan.Assignments;
import io.prestosql.sql.planner.plan.CorrelatedJoinNode;
import io.prestosql.sql.planner.plan.Patterns;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.tree.BooleanLiteral;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Stream;

/* loaded from: input_file:io/prestosql/sql/planner/iterative/rule/TransformCorrelatedScalarAggregationToJoin.class */
/**
 * Supplies iterative-optimizer rules that match a {@link CorrelatedJoinNode} with a non-empty
 * correlation, a {@code TRUE} filter, and a subquery that is a scalar aggregation (an
 * {@link AggregationNode} with no grouping columns), optionally topped by a {@link ProjectNode}.
 *
 * <p>The rewrite itself is delegated to {@link ScalarAggregationToJoinRewriter}. If the rewriter
 * returns a {@link CorrelatedJoinNode}, the rule reports no change. Otherwise the rewritten plan is
 * wrapped in a {@link ProjectNode}. That node exposes the rewritten plan's output symbols that were
 * also outputs of the original correlated join, plus, for the projection variant, the captured
 * projection's assignments.
 */
public class TransformCorrelatedScalarAggregationToJoin {
    private final Metadata metadata;

    /**
     * Creates the rule factory.
     *
     * @param metadata metadata handed to each {@link ScalarAggregationToJoinRewriter}; must not be null
     * @throws NullPointerException if {@code metadata} is null
     */
    public TransformCorrelatedScalarAggregationToJoin(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    /**
     * Returns both rule variants: one for a subquery with a projection over the aggregation, and one
     * for a bare aggregation subquery.
     */
    public Set<Rule<?>> rules() {
        return ImmutableSet.of(
                new TransformCorrelatedScalarAggregationWithProjection(this.metadata),
                new TransformCorrelatedScalarAggregationWithoutProjection(this.metadata));
    }

    /**
     * Returns the output symbols of {@code rewritten} that were also output symbols of
     * {@code original}, in the order {@code rewritten} produces them.
     */
    private static List<Symbol> retainOriginalOutputs(CorrelatedJoinNode original, PlanNode rewritten) {
        // Hash the original outputs once so the filter is O(1) per symbol.
        Set<Symbol> originalOutputs = new HashSet<>(original.getOutputSymbols());
        return rewritten.getOutputSymbols().stream()
                .filter(originalOutputs::contains)
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Matches {@code CorrelatedJoin(subquery = Project(Aggregation[no grouping keys]))}.
     */
    @VisibleForTesting
    public static final class TransformCorrelatedScalarAggregationWithProjection implements Rule<CorrelatedJoinNode> {
        private static final Capture<ProjectNode> PROJECTION = Capture.newCapture();
        private static final Capture<AggregationNode> AGGREGATION = Capture.newCapture();
        private static final Pattern<CorrelatedJoinNode> PATTERN = Patterns.correlatedJoin()
                .with(Pattern.nonEmpty(Patterns.CorrelatedJoin.correlation()))
                .with(Patterns.CorrelatedJoin.filter().equalTo(BooleanLiteral.TRUE_LITERAL))
                .with(Patterns.CorrelatedJoin.subquery().matching(
                        Patterns.project()
                                .capturedAs(PROJECTION)
                                .with(Patterns.source().matching(
                                        Patterns.aggregation()
                                                .with(Pattern.empty(Patterns.Aggregation.groupingColumns()))
                                                .capturedAs(AGGREGATION)))));

        private final Metadata metadata;

        @VisibleForTesting
        TransformCorrelatedScalarAggregationWithProjection(Metadata metadata) {
            this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        }

        @Override
        public Pattern<CorrelatedJoinNode> getPattern() {
            return PATTERN;
        }

        @Override
        public Rule.Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Rule.Context context) {
            PlanNode rewritten = new ScalarAggregationToJoinRewriter(
                    this.metadata, context.getSymbolAllocator(), context.getIdAllocator(), context.getLookup())
                    .rewriteScalarAggregation(correlatedJoinNode, captures.get(AGGREGATION));

            // A CorrelatedJoinNode result is treated as "not rewritten" (presumably the rewriter
            // returns its input when it cannot decorrelate — confirm in ScalarAggregationToJoinRewriter).
            if (rewritten instanceof CorrelatedJoinNode) {
                return Rule.Result.empty();
            }

            // Pass through the original join's surviving outputs, then re-apply the captured projection.
            Assignments assignments = Assignments.builder()
                    .putIdentities(retainOriginalOutputs(correlatedJoinNode, rewritten))
                    .putAll(captures.get(PROJECTION).getAssignments())
                    .build();
            return Rule.Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), rewritten, assignments));
        }
    }

    /**
     * Matches {@code CorrelatedJoin(subquery = Aggregation[no grouping keys])}.
     */
    @VisibleForTesting
    public static final class TransformCorrelatedScalarAggregationWithoutProjection implements Rule<CorrelatedJoinNode> {
        private static final Capture<AggregationNode> AGGREGATION = Capture.newCapture();
        private static final Pattern<CorrelatedJoinNode> PATTERN = Patterns.correlatedJoin()
                .with(Pattern.nonEmpty(Patterns.CorrelatedJoin.correlation()))
                .with(Patterns.CorrelatedJoin.filter().equalTo(BooleanLiteral.TRUE_LITERAL))
                .with(Patterns.CorrelatedJoin.subquery().matching(
                        Patterns.aggregation()
                                .with(Pattern.empty(Patterns.Aggregation.groupingColumns()))
                                .capturedAs(AGGREGATION)));

        private final Metadata metadata;

        @VisibleForTesting
        TransformCorrelatedScalarAggregationWithoutProjection(Metadata metadata) {
            this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        }

        @Override
        public Pattern<CorrelatedJoinNode> getPattern() {
            return PATTERN;
        }

        @Override
        public Rule.Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Rule.Context context) {
            PlanNode rewritten = new ScalarAggregationToJoinRewriter(
                    this.metadata, context.getSymbolAllocator(), context.getIdAllocator(), context.getLookup())
                    .rewriteScalarAggregation(correlatedJoinNode, captures.get(AGGREGATION));

            // A CorrelatedJoinNode result is treated as "not rewritten" (presumably the rewriter
            // returns its input when it cannot decorrelate — confirm in ScalarAggregationToJoinRewriter).
            if (rewritten instanceof CorrelatedJoinNode) {
                return Rule.Result.empty();
            }

            // Expose only the symbols the original correlated join produced.
            return Rule.Result.ofPlanNode(new ProjectNode(
                    context.getIdAllocator().getNextId(),
                    rewritten,
                    Assignments.identity(retainOriginalOutputs(correlatedJoinNode, rewritten))));
        }
    }
}
