package io.prestosql.sql.planner.optimizations;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.prestosql.Session;
import io.prestosql.spi.block.SortOrder;
import io.prestosql.sql.planner.assertions.BasePlanTest;
import io.prestosql.sql.planner.assertions.PlanMatchPattern;
import io.prestosql.sql.planner.assertions.RowNumberSymbolMatcher;
import io.prestosql.sql.planner.assertions.TopNRowNumberSymbolMatcher;
import io.prestosql.sql.planner.plan.FilterNode;
import io.prestosql.sql.planner.plan.TopNRowNumberNode;
import io.prestosql.sql.planner.plan.WindowNode;
import java.util.Optional;
import org.testng.annotations.Test;

/**
 * Plan tests for pushing LIMIT and row_number() predicates down into window evaluation.
 * <p>
 * With the {@code optimize_top_n_row_number} session property enabled, a partitioned/ordered
 * row_number() window under a LIMIT or an upper-bound predicate is expected to plan as a
 * {@link TopNRowNumberNode}. With the property disabled, the plain {@link WindowNode} is expected to remain.
 * An un-ordered {@code row_number() OVER()} is expected to plan as a row-number node that carries
 * the same per-partition row cap.
 */
public class TestWindowFilterPushDown
        extends BasePlanTest
{
    // Window over lineitem that is shared by the LIMIT and filter tests.
    private static final String LINEITEM_ROW_NUMBER =
            "SELECT row_number() OVER (PARTITION BY suppkey ORDER BY orderkey) partition_row_number FROM lineitem";

    // Prefix for nation queries whose window orders by name; each caller appends its predicate.
    private static final String ORDERED_NATION_ROW_NUMBER =
            "SELECT * FROM (SELECT name, row_number() OVER(ORDER BY name) FROM nation) t(name, row_number) WHERE ";

    // Prefix for nation queries with an empty window specification; each caller appends its predicate.
    private static final String UNORDERED_NATION_ROW_NUMBER =
            "SELECT * FROM (SELECT name, row_number() OVER() FROM nation) t(name, row_number) WHERE ";

    // Residual predicate expected to survive when the query also has a non-trivial lower bound.
    private static final String RESIDUAL_RANGE_FILTER = "(row_number > BIGINT '1') AND (row_number < BIGINT '3')";

    @Test
    public void testLimitAboveWindow()
    {
        String sql = LINEITEM_ROW_NUMBER + " LIMIT 10";

        assertPlanWithSession(
                sql,
                optimizeTopNRowNumber(true),
                true,
                PlanMatchPattern.anyTree(
                        PlanMatchPattern.limit(10L,
                                PlanMatchPattern.anyTree(
                                        PlanMatchPattern.node(TopNRowNumberNode.class, lineitemScan())))));

        assertPlanWithSession(
                sql,
                optimizeTopNRowNumber(false),
                true,
                PlanMatchPattern.anyTree(
                        PlanMatchPattern.limit(10L,
                                PlanMatchPattern.anyTree(
                                        PlanMatchPattern.node(WindowNode.class, lineitemScan())))));
    }

    @Test
    public void testFilterAboveWindow()
    {
        String sql = "SELECT * FROM (" + LINEITEM_ROW_NUMBER + ") WHERE partition_row_number < 10";

        // Enabled: no FilterNode is expected directly over the TopNRowNumber node.
        assertPlanWithSession(
                sql,
                optimizeTopNRowNumber(true),
                true,
                PlanMatchPattern.anyTree(
                        PlanMatchPattern.anyNot(FilterNode.class,
                                PlanMatchPattern.node(TopNRowNumberNode.class, lineitemScan()))));

        // Disabled: the WindowNode is kept and the filter still sits somewhere above it.
        assertPlanWithSession(
                sql,
                optimizeTopNRowNumber(false),
                true,
                PlanMatchPattern.anyTree(
                        PlanMatchPattern.node(FilterNode.class,
                                PlanMatchPattern.anyTree(
                                        PlanMatchPattern.node(WindowNode.class, lineitemScan())))));

        // row_number() is never below 1, so this predicate is expected to plan as a values node.
        assertPlanWithSession(
                ORDERED_NATION_ROW_NUMBER + "row_number < 0",
                optimizeTopNRowNumber(true),
                true,
                nameAndRowNumberOutput(PlanMatchPattern.values("name", "row_number")));

        // Each of these predicates caps every partition at a single row and needs no residual filter.
        assertPlanWithSession(
                ORDERED_NATION_ROW_NUMBER + "row_number < 2",
                optimizeTopNRowNumber(true),
                true,
                nameAndRowNumberOutput(topNRowNumberOrderedByName(1)));
        assertPlanWithSession(
                ORDERED_NATION_ROW_NUMBER + "row_number <= 1",
                optimizeTopNRowNumber(true),
                true,
                nameAndRowNumberOutput(topNRowNumberOrderedByName(1)));
        assertPlanWithSession(
                ORDERED_NATION_ROW_NUMBER + "row_number <= 1 AND row_number > -10",
                optimizeTopNRowNumber(true),
                true,
                nameAndRowNumberOutput(topNRowNumberOrderedByName(1)));

        // The "> 1" bound is not expressible as a row cap: the cap becomes 2 and the filter is expected to remain.
        assertPlanWithSession(
                ORDERED_NATION_ROW_NUMBER + "row_number > 1 AND row_number < 3",
                optimizeTopNRowNumber(true),
                true,
                nameAndRowNumberOutput(PlanMatchPattern.filter(RESIDUAL_RANGE_FILTER, topNRowNumberOrderedByName(2))));
    }

    @Test
    public void testFilterAboveRowNumber()
    {
        // Same predicates as above, planned with the default session against an un-ordered window.
        assertPlan(
                UNORDERED_NATION_ROW_NUMBER + "row_number < 0",
                nameAndRowNumberOutput(PlanMatchPattern.values("name", "row_number")));
        assertPlan(
                UNORDERED_NATION_ROW_NUMBER + "row_number < 2",
                nameAndRowNumberOutput(rowNumberOverNation(1)));
        assertPlan(
                UNORDERED_NATION_ROW_NUMBER + "row_number <= 1",
                nameAndRowNumberOutput(rowNumberOverNation(1)));
        assertPlan(
                UNORDERED_NATION_ROW_NUMBER + "row_number <= 1 AND row_number > -10",
                nameAndRowNumberOutput(rowNumberOverNation(1)));
        assertPlan(
                UNORDERED_NATION_ROW_NUMBER + "row_number > 1 AND row_number < 3",
                nameAndRowNumberOutput(PlanMatchPattern.filter(RESIDUAL_RANGE_FILTER, rowNumberOverNation(2))));
    }

    /**
     * Builds a session derived from the query runner's default session with
     * {@code optimize_top_n_row_number} set to the given value.
     *
     * @param enabled value written (as "true"/"false") to the system property
     * @return the derived session
     */
    private Session optimizeTopNRowNumber(boolean enabled)
    {
        return Session.builder(getQueryRunner().getDefaultSession())
                .setSystemProperty("optimize_top_n_row_number", String.valueOf(enabled))
                .build();
    }

    // Wraps a source pattern in the (name, row_number) output projection shared by the nation queries.
    private static PlanMatchPattern nameAndRowNumberOutput(PlanMatchPattern source)
    {
        return PlanMatchPattern.output(ImmutableList.of("name", "row_number"), source);
    }

    // Any subtree that bottoms out in a scan of lineitem.
    private static PlanMatchPattern lineitemScan()
    {
        return PlanMatchPattern.anyTree(PlanMatchPattern.tableScan("lineitem"));
    }

    // A single node over a nation scan whose NAME column is aliased as "name".
    private static PlanMatchPattern nationNameScan()
    {
        return PlanMatchPattern.any(PlanMatchPattern.tableScan("nation", ImmutableMap.of("NAME", "name")));
    }

    // Fresh TopNRowNumber pattern (unpartitioned, ordered by name ASC NULLS LAST) that binds "row_number".
    private static PlanMatchPattern topNRowNumberOrderedByName(int maxRowCountPerPartition)
    {
        return PlanMatchPattern.topNRowNumber(
                topN -> {
                    topN.maxRowCountPerPartition(maxRowCountPerPartition)
                            .specification(ImmutableList.of(), ImmutableList.of("name"), ImmutableMap.of("name", SortOrder.ASC_NULLS_LAST));
                },
                nationNameScan())
                .withAlias("row_number", new TopNRowNumberSymbolMatcher());
    }

    // Fresh RowNumber pattern with the given per-partition cap that binds "row_number".
    private static PlanMatchPattern rowNumberOverNation(int maxRowCountPerPartition)
    {
        return PlanMatchPattern.rowNumber(
                matcher -> {
                    matcher.maxRowCountPerPartition(Optional.of(maxRowCountPerPartition));
                },
                nationNameScan())
                .withAlias("row_number", new RowNumberSymbolMatcher());
    }
}
