package io.prestosql.sql.planner.optimizations;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import io.prestosql.Session;
import io.prestosql.execution.warnings.WarningCollector;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.TableHandle;
import io.prestosql.sql.planner.PlanNodeIdAllocator;
import io.prestosql.sql.planner.SymbolAllocator;
import io.prestosql.sql.planner.TypeProvider;
import io.prestosql.sql.planner.plan.ChildReplacer;
import io.prestosql.sql.planner.plan.DeleteNode;
import io.prestosql.sql.planner.plan.ExchangeNode;
import io.prestosql.sql.planner.plan.FilterNode;
import io.prestosql.sql.planner.plan.JoinNode;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.planner.plan.SemiJoinNode;
import io.prestosql.sql.planner.plan.SimplePlanRewriter;
import io.prestosql.sql.planner.plan.StatisticsWriterNode;
import io.prestosql.sql.planner.plan.TableFinishNode;
import io.prestosql.sql.planner.plan.TableScanNode;
import io.prestosql.sql.planner.plan.TableWriterNode;
import io.prestosql.sql.planner.plan.UnionNode;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Plan optimizer that begins the table write operations referenced by a plan.
 *
 * <p>At each {@link TableFinishNode}, the write target found beneath it is passed to the matching
 * {@link Metadata} call ({@code beginCreateTable}, {@code beginInsert} or {@code beginDelete}).
 * The resulting target replaces the original one in the finish node and in the
 * {@link TableWriterNode} / {@link DeleteNode} below it. For deletes, the table handle of the
 * scanned table is replaced as well. {@link StatisticsWriterNode} targets are replaced with the
 * handle returned by {@code beginStatisticsCollection}.
 */
public class BeginTableWrite implements PlanOptimizer {
    private final Metadata metadata;

    /**
     * State carried through one rewrite: the write target seen at a {@link TableFinishNode} and
     * the target created for it. An instance holds at most one such pair.
     */
    public static class Context {
        private Optional<TableWriterNode.WriterTarget> handle = Optional.empty();
        private Optional<TableWriterNode.WriterTarget> materializedHandle = Optional.empty();

        /**
         * Records {@code materializedTarget} as the replacement for {@code target}.
         *
         * @param target the write target as it appears in the input plan
         * @param materializedTarget the target returned by the corresponding metadata "begin" call
         * @throws IllegalStateException if a target has already been recorded in this context
         */
        public void addMaterializedHandle(TableWriterNode.WriterTarget target, TableWriterNode.WriterTarget materializedTarget) {
            Preconditions.checkState(!this.handle.isPresent(), "can only have one WriterTarget in a subtree");
            this.handle = Optional.of(target);
            this.materializedHandle = Optional.of(materializedTarget);
        }

        /**
         * Returns the replacement previously recorded for {@code target}.
         *
         * @param target the write target as it appears in the input plan
         * @return the recorded materialized target
         * @throws IllegalStateException if nothing was recorded, or a different target was recorded
         */
        public Optional<TableWriterNode.WriterTarget> getMaterializedHandle(TableWriterNode.WriterTarget target) {
            // Test presence first. Calling get() on an empty Optional (e.g. a writer reached before
            // any TableFinishNode recorded a target) would otherwise surface as a bare
            // NoSuchElementException instead of this descriptive IllegalStateException.
            Preconditions.checkState(
                    this.handle.isPresent() && this.handle.get().equals(target),
                    "can't find materialized handle for WriterTarget");
            return this.materializedHandle;
        }
    }

    /**
     * Rewriter that swaps write references for the targets created by {@link Metadata}.
     * It is a non-static inner class because it calls the enclosing optimizer's {@code metadata}.
     */
    private class Rewriter extends SimplePlanRewriter<Context> {
        private final Session session;

        public Rewriter(Session session) {
            this.session = session;
        }

        /** Rewrites the source, then substitutes the target recorded by the enclosing finish node. */
        @Override
        public PlanNode visitTableWriter(TableWriterNode tableWriterNode, SimplePlanRewriter.RewriteContext<Context> rewriteContext) {
            PlanNode rewrittenSource = tableWriterNode.getSource().accept(this, rewriteContext);
            TableWriterNode.WriterTarget materializedTarget =
                    rewriteContext.get().getMaterializedHandle(tableWriterNode.getTarget()).get();
            return new TableWriterNode(
                    tableWriterNode.getId(),
                    rewrittenSource,
                    materializedTarget,
                    tableWriterNode.getRowCountSymbol(),
                    tableWriterNode.getFragmentSymbol(),
                    tableWriterNode.getColumns(),
                    tableWriterNode.getColumnNames(),
                    tableWriterNode.getPartitioningScheme(),
                    tableWriterNode.getStatisticsAggregation(),
                    tableWriterNode.getStatisticsAggregationDescriptor());
        }

        /**
         * Substitutes the recorded delete target. The table scan beneath the delete is pointed at
         * that target's table handle.
         */
        @Override
        public PlanNode visitDelete(DeleteNode deleteNode, SimplePlanRewriter.RewriteContext<Context> rewriteContext) {
            TableWriterNode.DeleteTarget deleteTarget =
                    (TableWriterNode.DeleteTarget) rewriteContext.get().getMaterializedHandle(deleteNode.getTarget()).get();
            return new DeleteNode(
                    deleteNode.getId(),
                    rewriteDeleteTableScan(deleteNode.getSource(), deleteTarget.getHandle()),
                    deleteTarget,
                    deleteNode.getRowId(),
                    deleteNode.getOutputSymbols());
        }

        /**
         * Rewrites the source, then replaces the statistics reference with the handle returned by
         * {@code beginStatisticsCollection}.
         */
        @Override
        public PlanNode visitStatisticsWriterNode(StatisticsWriterNode statisticsWriterNode, SimplePlanRewriter.RewriteContext<Context> rewriteContext) {
            PlanNode rewrittenSource = statisticsWriterNode.getSource().accept(this, rewriteContext);
            StatisticsWriterNode.WriteStatisticsReference reference =
                    (StatisticsWriterNode.WriteStatisticsReference) statisticsWriterNode.getTarget();
            StatisticsWriterNode.WriteStatisticsHandle statisticsHandle = new StatisticsWriterNode.WriteStatisticsHandle(
                    BeginTableWrite.this.metadata.beginStatisticsCollection(this.session, reference.getHandle()));
            return new StatisticsWriterNode(
                    statisticsWriterNode.getId(),
                    rewrittenSource,
                    statisticsHandle,
                    statisticsWriterNode.getRowCountSymbol(),
                    statisticsWriterNode.isRowCountEnabled(),
                    statisticsWriterNode.getDescriptor());
        }

        /**
         * Begins the write for the target found beneath this node and records the mapping in the
         * context before rewriting the source, so that writer/delete nodes below can look it up.
         */
        @Override
        public PlanNode visitTableFinish(TableFinishNode tableFinishNode, SimplePlanRewriter.RewriteContext<Context> rewriteContext) {
            PlanNode source = tableFinishNode.getSource();
            TableWriterNode.WriterTarget originalTarget = getTarget(source);
            TableWriterNode.WriterTarget materializedTarget = createWriterTarget(originalTarget);
            rewriteContext.get().addMaterializedHandle(originalTarget, materializedTarget);
            return new TableFinishNode(
                    tableFinishNode.getId(),
                    source.accept(this, rewriteContext),
                    materializedTarget,
                    tableFinishNode.getRowCountSymbol(),
                    tableFinishNode.getStatisticsAggregation(),
                    tableFinishNode.getStatisticsAggregationDescriptor());
        }

        /**
         * Finds the write target of the writer or delete node beneath a {@link TableFinishNode}.
         * Exchange and union nodes are looked through.
         *
         * @throws IllegalArgumentException if an unsupported node is encountered, or if the
         *         branches of an exchange/union refer to different targets
         */
        public TableWriterNode.WriterTarget getTarget(PlanNode planNode) {
            if (planNode instanceof TableWriterNode) {
                return ((TableWriterNode) planNode).getTarget();
            }
            if (planNode instanceof DeleteNode) {
                return ((DeleteNode) planNode).getTarget();
            }
            if (planNode instanceof ExchangeNode || planNode instanceof UnionNode) {
                // All branches must resolve to one shared target; getOnlyElement rejects anything else.
                Set<TableWriterNode.WriterTarget> targets = planNode.getSources().stream()
                        .map(this::getTarget)
                        .collect(Collectors.toSet());
                return Iterables.getOnlyElement(targets);
            }
            throw new IllegalArgumentException("Invalid child for TableFinishNode: " + planNode.getClass().getSimpleName());
        }

        /**
         * Calls the {@link Metadata} "begin" operation matching the target's type and wraps the
         * returned handle in the corresponding materialized target.
         *
         * @throws IllegalArgumentException for target types other than create, insert or delete
         */
        private TableWriterNode.WriterTarget createWriterTarget(TableWriterNode.WriterTarget target) {
            if (target instanceof TableWriterNode.CreateReference) {
                TableWriterNode.CreateReference create = (TableWriterNode.CreateReference) target;
                return new TableWriterNode.CreateTarget(
                        BeginTableWrite.this.metadata.beginCreateTable(this.session, create.getCatalog(), create.getTableMetadata(), create.getLayout()),
                        create.getTableMetadata().getTable());
            }
            if (target instanceof TableWriterNode.InsertReference) {
                TableWriterNode.InsertReference insert = (TableWriterNode.InsertReference) target;
                return new TableWriterNode.InsertTarget(
                        BeginTableWrite.this.metadata.beginInsert(this.session, insert.getHandle(), insert.getColumns()),
                        BeginTableWrite.this.metadata.getTableMetadata(this.session, insert.getHandle()).getTable());
            }
            if (target instanceof TableWriterNode.DeleteTarget) {
                TableWriterNode.DeleteTarget delete = (TableWriterNode.DeleteTarget) target;
                return new TableWriterNode.DeleteTarget(
                        BeginTableWrite.this.metadata.beginDelete(this.session, delete.getHandle()),
                        delete.getSchemaTableName());
            }
            throw new IllegalArgumentException("Unhandled target type: " + target.getClass().getSimpleName());
        }

        /**
         * Replaces the table handle of the {@link TableScanNode} under a delete. Only filter,
         * project, semi-join (probe side), and inner joins whose right side is at most scalar
         * (left side) are traversed.
         *
         * @throws IllegalArgumentException if any other node is encountered on the path
         */
        private PlanNode rewriteDeleteTableScan(PlanNode planNode, TableHandle tableHandle) {
            if (planNode instanceof TableScanNode) {
                TableScanNode scan = (TableScanNode) planNode;
                return new TableScanNode(
                        scan.getId(),
                        tableHandle,
                        scan.getOutputSymbols(),
                        scan.getAssignments(),
                        scan.getEnforcedConstraint());
            }
            if (planNode instanceof FilterNode) {
                PlanNode source = rewriteDeleteTableScan(((FilterNode) planNode).getSource(), tableHandle);
                return ChildReplacer.replaceChildren(planNode, ImmutableList.of(source));
            }
            if (planNode instanceof ProjectNode) {
                PlanNode source = rewriteDeleteTableScan(((ProjectNode) planNode).getSource(), tableHandle);
                return ChildReplacer.replaceChildren(planNode, ImmutableList.of(source));
            }
            if (planNode instanceof SemiJoinNode) {
                SemiJoinNode semiJoin = (SemiJoinNode) planNode;
                PlanNode source = rewriteDeleteTableScan(semiJoin.getSource(), tableHandle);
                return ChildReplacer.replaceChildren(planNode, ImmutableList.of(source, semiJoin.getFilteringSource()));
            }
            if (planNode instanceof JoinNode) {
                JoinNode join = (JoinNode) planNode;
                if (join.getType() == JoinNode.Type.INNER && QueryCardinalityUtil.isAtMostScalar(join.getRight())) {
                    PlanNode left = rewriteDeleteTableScan(join.getLeft(), tableHandle);
                    return ChildReplacer.replaceChildren(planNode, ImmutableList.of(left, join.getRight()));
                }
            }
            throw new IllegalArgumentException("Invalid descendant for DeleteNode: " + planNode.getClass().getName());
        }
    }

    public BeginTableWrite(Metadata metadata) {
        this.metadata = metadata;
    }

    /** Rewrites {@code planNode} with a fresh {@link Context} for this invocation. */
    @Override
    public PlanNode optimize(PlanNode planNode, Session session, TypeProvider typeProvider, SymbolAllocator symbolAllocator, PlanNodeIdAllocator planNodeIdAllocator, WarningCollector warningCollector) {
        return SimplePlanRewriter.rewriteWith(new Rewriter(session), planNode, new Context());
    }
}
