package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.spi.Plugin;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.assertions.RowNumberSymbolMatcher;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.plan.Assignments;
import java.util.List;
import java.util.Optional;
import org.testng.annotations.Test;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/TestPushPredicateThroughProjectIntoRowNumber.class */
/**
 * Rule tests for {@link PushPredicateThroughProjectIntoRowNumber}.
 *
 * <p>Every case builds a {@code Filter -> Project -> RowNumber -> Values(a)} plan. The case then
 * checks either that the rule does not fire, or which plan it produces. The produced plan may:
 * <ul>
 *   <li>tighten the RowNumber per-partition limit;</li>
 *   <li>drop the filter, or keep part of it;</li>
 *   <li>collapse the subtree into a Values node.</li>
 * </ul>
 */
public class TestPushPredicateThroughProjectIntoRowNumber extends BaseRuleTest {
    public TestPushPredicateThroughProjectIntoRowNumber() {
        // No extra plugins are needed; this resolves to the (empty) varargs super constructor.
        super();
    }

    /**
     * Creates the rule under test, bound to the tester's planner context.
     */
    private PushPredicateThroughProjectIntoRowNumber rule() {
        return new PushPredicateThroughProjectIntoRowNumber(tester().getPlannerContext());
    }

    /**
     * Builds the expected pattern for the RowNumber node.
     *
     * <p>The pattern requires the given per-partition row limit. The node must read from a Values
     * node that outputs {@code a}. Its row number symbol is bound to the alias {@code row_number}.
     *
     * @param maxRowCountPerPartition the limit the RowNumber node is expected to carry
     * @return the matcher for that RowNumber subtree
     */
    private static PlanMatchPattern rowNumberOverValues(int maxRowCountPerPartition) {
        List<String> sourceOutputs = ImmutableList.of("a");
        return PlanMatchPattern.rowNumber(
                        builder -> builder.maxRowCountPerPartition(Optional.of(maxRowCountPerPartition)),
                        PlanMatchPattern.values(sourceOutputs))
                .withAlias("row_number", new RowNumberSymbolMatcher());
    }

    @Test
    public void testRowNumberSymbolPruned() {
        // The projection does not output row_number, so there is no row_number predicate to push.
        tester().assertThat(rule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    return p.filter(
                            PlanBuilder.expression("a = 1"),
                            p.project(
                                    Assignments.identity(a),
                                    p.rowNumber(ImmutableList.of(), Optional.empty(), p.symbol("row_number"), p.values(a))));
                })
                .doesNotFire();
    }

    @Test
    public void testNoUpperBoundForRowNumberSymbol() {
        // row_number is projected, but the predicate does not constrain it at all.
        tester().assertThat(rule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol rowNumberSymbol = p.symbol("row_number");
                    return p.filter(
                            PlanBuilder.expression("a = 1"),
                            p.project(
                                    Assignments.identity(a, rowNumberSymbol),
                                    p.rowNumber(ImmutableList.of(), Optional.empty(), rowNumberSymbol, p.values(a))));
                })
                .doesNotFire();
    }

    @Test
    public void testNonPositiveUpperBoundForRowNumberSymbol() {
        // row_number < -10 can never hold.
        // The whole subtree is expected to collapse into a Values node producing a and row_number.
        tester().assertThat(rule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol rowNumberSymbol = p.symbol("row_number");
                    return p.filter(
                            PlanBuilder.expression("a = 1 AND row_number < -10"),
                            p.project(
                                    Assignments.identity(a, rowNumberSymbol),
                                    p.rowNumber(ImmutableList.of(), Optional.empty(), rowNumberSymbol, p.values(a))));
                })
                .matches(PlanMatchPattern.values("a", "row_number"));
    }

    @Test
    public void testPredicateNotSatisfied() {
        // With no existing limit, the upper bound row_number < 5 becomes a limit of 4.
        // The lower bound row_number > 2 is not implied by that limit, so the filter is kept intact.
        tester().assertThat(rule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol rowNumberSymbol = p.symbol("row_number");
                    return p.filter(
                            PlanBuilder.expression("row_number > 2 AND row_number < 5"),
                            p.project(
                                    Assignments.identity(rowNumberSymbol),
                                    p.rowNumber(ImmutableList.of(), Optional.empty(), rowNumberSymbol, p.values(a))));
                })
                .matches(PlanMatchPattern.filter(
                        "row_number > 2 AND row_number < 5",
                        PlanMatchPattern.project(
                                ImmutableMap.of("row_number", PlanMatchPattern.expression("row_number")),
                                rowNumberOverValues(4))));
    }

    @Test
    public void testPredicateNotSatisfiedAndMaxRowCountNotUpdated() {
        // The existing limit of 3 is already tighter than the limit of 4 implied by row_number < 5.
        // The filter cannot be removed because of row_number > 2, so there is nothing to change.
        tester().assertThat(rule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol rowNumberSymbol = p.symbol("row_number");
                    return p.filter(
                            PlanBuilder.expression("row_number > 2 AND row_number < 5"),
                            p.project(
                                    Assignments.identity(rowNumberSymbol),
                                    p.rowNumber(ImmutableList.of(), Optional.of(3), rowNumberSymbol, p.values(a))));
                })
                .doesNotFire();
    }

    @Test
    public void testPredicateSatisfied() {
        // The existing limit of 3 already guarantees row_number < 5.
        // The filter is removed and the limit stays at 3.
        tester().assertThat(rule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol rowNumberSymbol = p.symbol("row_number");
                    return p.filter(
                            PlanBuilder.expression("row_number < 5"),
                            p.project(
                                    Assignments.identity(rowNumberSymbol),
                                    p.rowNumber(ImmutableList.of(), Optional.of(3), rowNumberSymbol, p.values(a))));
                })
                .matches(PlanMatchPattern.project(
                        ImmutableMap.of("row_number", PlanMatchPattern.expression("row_number")),
                        rowNumberOverValues(3)));

        // row_number < 3 is tighter than the existing limit of 5.
        // The limit is lowered to 2, which makes the filter redundant, so it is removed.
        tester().assertThat(rule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol rowNumberSymbol = p.symbol("row_number");
                    return p.filter(
                            PlanBuilder.expression("row_number < 3"),
                            p.project(
                                    Assignments.identity(rowNumberSymbol),
                                    p.rowNumber(ImmutableList.of(), Optional.of(5), rowNumberSymbol, p.values(a))));
                })
                .matches(PlanMatchPattern.project(
                        ImmutableMap.of("row_number", PlanMatchPattern.expression("row_number")),
                        rowNumberOverValues(2)));
    }

    @Test
    public void testPredicatePartiallySatisfied() {
        // row_number < 5 is implied by the existing limit of 3 and is dropped.
        // The unrelated conjunct a > 0 must stay in the filter.
        tester().assertThat(rule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol rowNumberSymbol = p.symbol("row_number");
                    return p.filter(
                            PlanBuilder.expression("row_number < 5 AND a > 0"),
                            p.project(
                                    Assignments.identity(rowNumberSymbol, a),
                                    p.rowNumber(ImmutableList.of(), Optional.of(3), rowNumberSymbol, p.values(a))));
                })
                .matches(PlanMatchPattern.filter(
                        "a > 0",
                        PlanMatchPattern.project(
                                ImmutableMap.of(
                                        "row_number", PlanMatchPattern.expression("row_number"),
                                        "a", PlanMatchPattern.expression("a")),
                                rowNumberOverValues(3))));

        // row_number < 5 is implied by the limit of 3 and is dropped.
        // row_number % 2 = 0 is not a bound, so it must stay in the filter.
        tester().assertThat(rule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol rowNumberSymbol = p.symbol("row_number");
                    return p.filter(
                            PlanBuilder.expression("row_number < 5 AND row_number % 2 = 0"),
                            p.project(
                                    Assignments.identity(rowNumberSymbol),
                                    p.rowNumber(ImmutableList.of(), Optional.of(3), rowNumberSymbol, p.values(a))));
                })
                .matches(PlanMatchPattern.filter(
                        "row_number % 2 = 0",
                        PlanMatchPattern.project(
                                ImmutableMap.of("row_number", PlanMatchPattern.expression("row_number")),
                                rowNumberOverValues(3))));
    }
}
