package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Verify;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.BoundSignature;
import io.trino.metadata.Metadata;
import io.trino.metadata.TableHandle;
import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.connector.AggregationApplicationResult;
import io.trino.spi.connector.Assignment;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.Variable;
import io.trino.spi.predicate.TupleDomain;
import io.trino.sql.PlannerContext;
import io.trino.sql.planner.ConnectorExpressionTranslator;
import io.trino.sql.planner.LiteralEncoder;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.SymbolReference;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.IntStream;
import java.util.stream.Stream;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/PushAggregationIntoTableScan.class */
public class PushAggregationIntoTableScan implements Rule<AggregationNode> {
    private static final Capture<TableScanNode> TABLE_SCAN = Capture.newCapture();
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation().with(Patterns.Aggregation.step().equalTo(AggregationNode.Step.SINGLE)).matching(PushAggregationIntoTableScan::allArgumentsAreSimpleReferences).matching(aggregationNode -> {
        return aggregationNode.getGroupingSets().getGroupingSetCount() <= 1;
    }).matching(PushAggregationIntoTableScan::hasNoMasks).with(Patterns.source().matching(Patterns.tableScan().capturedAs(TABLE_SCAN)));
    private final PlannerContext plannerContext;

    public PushAggregationIntoTableScan(PlannerContext plannerContext) {
        this.plannerContext = (PlannerContext) Objects.requireNonNull(plannerContext, "plannerContext is null");
    }

    @Override // io.trino.sql.planner.iterative.Rule
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    @Override // io.trino.sql.planner.iterative.Rule
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.isAllowPushdownIntoConnectors(session);
    }

    private static boolean allArgumentsAreSimpleReferences(AggregationNode aggregationNode) {
        Stream<R> flatMap = aggregationNode.getAggregations().values().stream().flatMap(aggregation -> {
            return aggregation.getArguments().stream();
        });
        Class<SymbolReference> cls = SymbolReference.class;
        Objects.requireNonNull(SymbolReference.class);
        return flatMap.allMatch((v1) -> {
            return r1.isInstance(v1);
        });
    }

    private static boolean hasNoMasks(AggregationNode aggregationNode) {
        return aggregationNode.getAggregations().values().stream().allMatch(aggregation -> {
            return aggregation.getMask().isEmpty();
        });
    }

    @Override // io.trino.sql.planner.iterative.Rule
    public Rule.Result apply(AggregationNode aggregationNode, Captures captures, Rule.Context context) {
        return (Rule.Result) pushAggregationIntoTableScan(this.plannerContext, context, aggregationNode, (TableScanNode) captures.get(TABLE_SCAN), aggregationNode.getAggregations(), aggregationNode.getGroupingSets().getGroupingKeys()).map(Rule.Result::ofPlanNode).orElseGet(Rule.Result::empty);
    }

    public static Optional<PlanNode> pushAggregationIntoTableScan(PlannerContext plannerContext, Rule.Context context, PlanNode planNode, TableScanNode tableScanNode, Map<Symbol, AggregationNode.Aggregation> map, List<Symbol> list) {
        Map<String, ColumnHandle> map2 = (Map) tableScanNode.getAssignments().entrySet().stream().collect(ImmutableMap.toImmutableMap(entry -> {
            return ((Symbol) entry.getKey()).getName();
        }, (v0) -> {
            return v0.getValue();
        }));
        List list2 = (List) map.entrySet().stream().collect(ImmutableList.toImmutableList());
        List<AggregateFunction> list3 = (List) list2.stream().map((v0) -> {
            return v0.getValue();
        }).map(aggregation -> {
            return toAggregateFunction(plannerContext.getMetadata(), context, aggregation);
        }).collect(ImmutableList.toImmutableList());
        List list4 = (List) list2.stream().map((v0) -> {
            return v0.getKey();
        }).collect(ImmutableList.toImmutableList());
        Optional<AggregationApplicationResult<TableHandle>> applyAggregation = plannerContext.getMetadata().applyAggregation(context.getSession(), tableScanNode.getTable(), list3, map2, ImmutableList.of((List) list.stream().map(symbol -> {
            return (ColumnHandle) map2.get(symbol.getName());
        }).collect(ImmutableList.toImmutableList())));
        if (applyAggregation.isEmpty()) {
            return Optional.empty();
        }
        AggregationApplicationResult<TableHandle> aggregationApplicationResult = applyAggregation.get();
        ImmutableList.Builder builder = new ImmutableList.Builder();
        builder.addAll(tableScanNode.getOutputSymbols());
        ImmutableBiMap.Builder builder2 = new ImmutableBiMap.Builder();
        builder2.putAll(tableScanNode.getAssignments());
        HashMap hashMap = new HashMap();
        for (Assignment assignment : aggregationApplicationResult.getAssignments()) {
            Symbol newSymbol = context.getSymbolAllocator().newSymbol(assignment.getVariable(), assignment.getType());
            builder.add(newSymbol);
            builder2.put(newSymbol, assignment.getColumn());
            hashMap.put(assignment.getVariable(), newSymbol);
        }
        List list5 = (List) aggregationApplicationResult.getProjections().stream().map(connectorExpression -> {
            return ConnectorExpressionTranslator.translate(context.getSession(), connectorExpression, (Map<String, Symbol>) hashMap, new LiteralEncoder(plannerContext));
        }).collect(ImmutableList.toImmutableList());
        Verify.verify(list4.size() == list5.size());
        Assignments.Builder builder3 = Assignments.builder();
        IntStream.range(0, list4.size()).forEach(i -> {
            builder3.put((Symbol) list4.get(i), (Expression) list5.get(i));
        });
        ImmutableBiMap build = builder2.build();
        ImmutableBiMap inverse = build.inverse();
        list.forEach(symbol2 -> {
            ColumnHandle columnHandle = (ColumnHandle) map2.get(symbol2.getName());
            builder3.put(symbol2, ((Symbol) inverse.get((ColumnHandle) aggregationApplicationResult.getGroupingColumnMapping().getOrDefault(columnHandle, columnHandle))).toSymbolReference());
        });
        return Optional.of(new ProjectNode(context.getIdAllocator().getNextId(), new TableScanNode(context.getIdAllocator().getNextId(), (TableHandle) aggregationApplicationResult.getHandle(), builder.build(), build, TupleDomain.all(), Rules.deriveTableStatisticsForPushdown(context.getStatsProvider(), context.getSession(), aggregationApplicationResult.isPrecalculateStatistics(), planNode), tableScanNode.isUpdateTarget(), Optional.empty()), builder3.build()));
    }

    /* JADX INFO: Access modifiers changed from: private */
    public static AggregateFunction toAggregateFunction(Metadata metadata, Rule.Context context, AggregationNode.Aggregation aggregation) {
        String canonicalName = metadata.getFunctionMetadata(aggregation.getResolvedFunction()).getCanonicalName();
        BoundSignature signature = aggregation.getResolvedFunction().getSignature();
        ImmutableList.Builder builder = new ImmutableList.Builder();
        for (int i = 0; i < aggregation.getArguments().size(); i++) {
            builder.add(new Variable(aggregation.getArguments().get(i).getName(), signature.getArgumentTypes().get(i)));
        }
        return new AggregateFunction(canonicalName, signature.getReturnType(), builder.build(), (List) aggregation.getOrderingScheme().map((v0) -> {
            return v0.toSortItems();
        }).orElse(ImmutableList.of()), aggregation.isDistinct(), aggregation.getFilter().map(symbol -> {
            return new Variable(symbol.getName(), context.getSymbolAllocator().getTypes().get(symbol));
        }));
    }
}
