package io.prestosql.plugin.geospatial;

import com.google.common.collect.ImmutableMap;
import io.prestosql.spi.Plugin;
import io.prestosql.spi.type.VarcharType;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.assertions.PlanMatchPattern;
import io.prestosql.sql.planner.iterative.rule.ExtractSpatialJoins;
import io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest;
import io.prestosql.sql.planner.iterative.rule.test.PlanBuilder;
import io.prestosql.sql.planner.iterative.rule.test.RuleAssert;
import io.prestosql.sql.planner.iterative.rule.test.RuleTester;
import io.prestosql.sql.planner.plan.JoinNode;
import org.testng.annotations.Test;

/* loaded from: input_file:io/prestosql/plugin/geospatial/TestExtractSpatialInnerJoin.class */
/**
 * Rule tests for {@link ExtractSpatialJoins.ExtractSpatialInnerJoin}.
 *
 * <p>Every case builds a filter over an inner join that has no equi-join criteria and then
 * asserts either that the rule does not fire or that the resulting plan matches a
 * {@code spatialJoin} pattern. A {@link GeoPlugin} is handed to the {@link BaseRuleTest}
 * constructor so that the geospatial functions and types used in the expressions are available.
 */
public class TestExtractSpatialInnerJoin extends BaseRuleTest {
    /** Empty criteria array passed to every join built in this class. */
    private static final JoinNode.EquiJoinClause[] NO_CRITERIA = new JoinNode.EquiJoinClause[0];

    public TestExtractSpatialInnerJoin() {
        super(new Plugin[] {new GeoPlugin()});
    }

    /** Plans for which the rule is expected not to fire. */
    @Test
    public void testDoesNotFire() {
        // First ST_Contains argument is a literal; the left join input has no columns.
        assertRuleApplication()
                .on(p -> p.filter(
                        PlanBuilder.expression("ST_Contains(ST_GeometryFromText('POLYGON ...'), b)"),
                        p.join(
                                JoinNode.Type.INNER,
                                p.values(new Symbol[0]),
                                p.values(new Symbol[] {p.symbol("b")}),
                                NO_CRITERIA)))
                .doesNotFire();

        // Spatial predicate is one operand of an OR.
        assertRuleApplication()
                .on(p -> p.filter(
                        PlanBuilder.expression("ST_Contains(ST_GeometryFromText(wkt), point) OR name_1 != name_2"),
                        p.join(
                                JoinNode.Type.INNER,
                                p.values(new Symbol[] {p.symbol("wkt", VarcharType.VARCHAR), p.symbol("name_1")}),
                                p.values(new Symbol[] {p.symbol("point", GeometryType.GEOMETRY), p.symbol("name_2")}),
                                NO_CRITERIA)))
                .doesNotFire();

        // Spatial predicate is negated.
        assertRuleApplication()
                .on(p -> p.filter(
                        PlanBuilder.expression("NOT ST_Contains(ST_GeometryFromText(wkt), point)"),
                        p.join(
                                JoinNode.Type.INNER,
                                p.values(new Symbol[] {p.symbol("wkt", VarcharType.VARCHAR), p.symbol("name_1")}),
                                p.values(new Symbol[] {p.symbol("point", GeometryType.GEOMETRY), p.symbol("name_2")}),
                                NO_CRITERIA)))
                .doesNotFire();

        // Distance compared with "greater than".
        assertRuleApplication()
                .on(p -> p.filter(
                        PlanBuilder.expression("ST_Distance(a, b) > 5"),
                        p.join(
                                JoinNode.Type.INNER,
                                p.values(new Symbol[] {p.symbol("a", GeometryType.GEOMETRY)}),
                                p.values(new Symbol[] {p.symbol("b", GeometryType.GEOMETRY)}),
                                NO_CRITERIA)))
                .doesNotFire();

        // Both ST_Distance operands are typed as spherical geography.
        assertRuleApplication()
                .on(p -> p.filter(
                        PlanBuilder.expression("ST_Distance(a, b) < 5"),
                        p.join(
                                JoinNode.Type.INNER,
                                p.values(new Symbol[] {p.symbol("a", SphericalGeographyType.SPHERICAL_GEOGRAPHY)}),
                                p.values(new Symbol[] {p.symbol("b", SphericalGeographyType.SPHERICAL_GEOGRAPHY)}),
                                NO_CRITERIA)))
                .doesNotFire();

        // Both ST_Contains operands are typed as spherical geography.
        assertRuleApplication()
                .on(p -> p.filter(
                        PlanBuilder.expression("ST_Contains(polygon, point)"),
                        p.join(
                                JoinNode.Type.INNER,
                                p.values(new Symbol[] {p.symbol("polygon", SphericalGeographyType.SPHERICAL_GEOGRAPHY)}),
                                p.values(new Symbol[] {p.symbol("point", SphericalGeographyType.SPHERICAL_GEOGRAPHY)}),
                                NO_CRITERIA)))
                .doesNotFire();

        // ST_Distance over a to_spherical_geography(...) expression and a spherical geography column.
        assertRuleApplication()
                .on(p -> p.filter(
                        PlanBuilder.expression("ST_Distance(to_spherical_geography(ST_GeometryFromText(wkt)), point) < 5"),
                        p.join(
                                JoinNode.Type.INNER,
                                p.values(new Symbol[] {p.symbol("wkt", VarcharType.VARCHAR)}),
                                p.values(new Symbol[] {p.symbol("point", SphericalGeographyType.SPHERICAL_GEOGRAPHY)}),
                                NO_CRITERIA)))
                .doesNotFire();

        // ST_Contains over a to_spherical_geography(...) expression and a spherical geography column.
        assertRuleApplication()
                .on(p -> p.filter(
                        PlanBuilder.expression("ST_Contains(to_spherical_geography(ST_GeometryFromText(wkt)), point)"),
                        p.join(
                                JoinNode.Type.INNER,
                                p.values(new Symbol[] {p.symbol("wkt", VarcharType.VARCHAR)}),
                                p.values(new Symbol[] {p.symbol("point", SphericalGeographyType.SPHERICAL_GEOGRAPHY)}),
                                NO_CRITERIA)))
                .doesNotFire();
    }

    /**
     * ST_Distance filters in every supported operand order and comparison direction, with the
     * radius given as a plain column, a literal, an arithmetic expression, and with the points
     * given either as columns or as ST_Point(...) expressions.
     */
    @Test
    public void testDistanceQueries() {
        // Radius is the column "r".
        testSimpleDistanceQuery("ST_Distance(a, b) <= r", "ST_Distance(a, b) <= r");
        testSimpleDistanceQuery("ST_Distance(b, a) <= r", "ST_Distance(b, a) <= r");
        testSimpleDistanceQuery("r >= ST_Distance(a, b)", "ST_Distance(a, b) <= r");
        testSimpleDistanceQuery("r >= ST_Distance(b, a)", "ST_Distance(b, a) <= r");
        testSimpleDistanceQuery("ST_Distance(a, b) < r", "ST_Distance(a, b) < r");
        testSimpleDistanceQuery("ST_Distance(b, a) < r", "ST_Distance(b, a) < r");
        testSimpleDistanceQuery("r > ST_Distance(a, b)", "ST_Distance(a, b) < r");
        testSimpleDistanceQuery("r > ST_Distance(b, a)", "ST_Distance(b, a) < r");
        testSimpleDistanceQuery("ST_Distance(a, b) <= r AND name_a != name_b", "ST_Distance(a, b) <= r AND name_a != name_b");
        testSimpleDistanceQuery("r > ST_Distance(a, b) AND name_a != name_b", "ST_Distance(a, b) < r AND name_a != name_b");

        // Radius is a decimal literal.
        testRadiusExpressionInDistanceQuery("ST_Distance(a, b) <= decimal '1.2'", "ST_Distance(a, b) <= radius", "decimal '1.2'");
        testRadiusExpressionInDistanceQuery("ST_Distance(b, a) <= decimal '1.2'", "ST_Distance(b, a) <= radius", "decimal '1.2'");
        testRadiusExpressionInDistanceQuery("decimal '1.2' >= ST_Distance(a, b)", "ST_Distance(a, b) <= radius", "decimal '1.2'");
        testRadiusExpressionInDistanceQuery("decimal '1.2' >= ST_Distance(b, a)", "ST_Distance(b, a) <= radius", "decimal '1.2'");
        testRadiusExpressionInDistanceQuery("ST_Distance(a, b) < decimal '1.2'", "ST_Distance(a, b) < radius", "decimal '1.2'");
        testRadiusExpressionInDistanceQuery("ST_Distance(b, a) < decimal '1.2'", "ST_Distance(b, a) < radius", "decimal '1.2'");
        testRadiusExpressionInDistanceQuery("decimal '1.2' > ST_Distance(a, b)", "ST_Distance(a, b) < radius", "decimal '1.2'");
        testRadiusExpressionInDistanceQuery("decimal '1.2' > ST_Distance(b, a)", "ST_Distance(b, a) < radius", "decimal '1.2'");
        testRadiusExpressionInDistanceQuery("ST_Distance(a, b) <= decimal '1.2' AND name_a != name_b", "ST_Distance(a, b) <= radius AND name_a != name_b", "decimal '1.2'");
        testRadiusExpressionInDistanceQuery("decimal '1.2' > ST_Distance(a, b) AND name_a != name_b", "ST_Distance(a, b) < radius AND name_a != name_b", "decimal '1.2'");

        // Radius is the arithmetic expression "2 * r".
        testRadiusExpressionInDistanceQuery("ST_Distance(a, b) <= 2 * r", "ST_Distance(a, b) <= radius", "2 * r");
        testRadiusExpressionInDistanceQuery("ST_Distance(b, a) <= 2 * r", "ST_Distance(b, a) <= radius", "2 * r");
        testRadiusExpressionInDistanceQuery("2 * r >= ST_Distance(a, b)", "ST_Distance(a, b) <= radius", "2 * r");
        testRadiusExpressionInDistanceQuery("2 * r >= ST_Distance(b, a)", "ST_Distance(b, a) <= radius", "2 * r");
        testRadiusExpressionInDistanceQuery("ST_Distance(a, b) < 2 * r", "ST_Distance(a, b) < radius", "2 * r");
        testRadiusExpressionInDistanceQuery("ST_Distance(b, a) < 2 * r", "ST_Distance(b, a) < radius", "2 * r");
        testRadiusExpressionInDistanceQuery("2 * r > ST_Distance(a, b)", "ST_Distance(a, b) < radius", "2 * r");
        testRadiusExpressionInDistanceQuery("2 * r > ST_Distance(b, a)", "ST_Distance(b, a) < radius", "2 * r");
        testRadiusExpressionInDistanceQuery("ST_Distance(a, b) <= 2 * r AND name_a != name_b", "ST_Distance(a, b) <= radius AND name_a != name_b", "2 * r");
        testRadiusExpressionInDistanceQuery("2 * r > ST_Distance(a, b) AND name_a != name_b", "ST_Distance(a, b) < radius AND name_a != name_b", "2 * r");

        // Points are ST_Point(...) expressions; radius is the literal 5.
        testPointExpressionsInDistanceQuery("ST_Distance(ST_Point(lng_a, lat_a), ST_Point(lng_b, lat_b)) <= 5", "ST_Distance(point_a, point_b) <= radius", "5");
        testPointExpressionsInDistanceQuery("ST_Distance(ST_Point(lng_b, lat_b), ST_Point(lng_a, lat_a)) <= 5", "ST_Distance(point_b, point_a) <= radius", "5");
        testPointExpressionsInDistanceQuery("5 >= ST_Distance(ST_Point(lng_a, lat_a), ST_Point(lng_b, lat_b))", "ST_Distance(point_a, point_b) <= radius", "5");
        testPointExpressionsInDistanceQuery("5 >= ST_Distance(ST_Point(lng_b, lat_b), ST_Point(lng_a, lat_a))", "ST_Distance(point_b, point_a) <= radius", "5");
        testPointExpressionsInDistanceQuery("ST_Distance(ST_Point(lng_a, lat_a), ST_Point(lng_b, lat_b)) < 5", "ST_Distance(point_a, point_b) < radius", "5");
        testPointExpressionsInDistanceQuery("ST_Distance(ST_Point(lng_b, lat_b), ST_Point(lng_a, lat_a)) < 5", "ST_Distance(point_b, point_a) < radius", "5");
        testPointExpressionsInDistanceQuery("5 > ST_Distance(ST_Point(lng_a, lat_a), ST_Point(lng_b, lat_b))", "ST_Distance(point_a, point_b) < radius", "5");
        testPointExpressionsInDistanceQuery("5 > ST_Distance(ST_Point(lng_b, lat_b), ST_Point(lng_a, lat_a))", "ST_Distance(point_b, point_a) < radius", "5");
        testPointExpressionsInDistanceQuery("ST_Distance(ST_Point(lng_a, lat_a), ST_Point(lng_b, lat_b)) <= 5 AND name_a != name_b", "ST_Distance(point_a, point_b) <= radius AND name_a != name_b", "5");
        testPointExpressionsInDistanceQuery("5 > ST_Distance(ST_Point(lng_a, lat_a), ST_Point(lng_b, lat_b)) AND name_a != name_b", "ST_Distance(point_a, point_b) < radius AND name_a != name_b", "5");

        // Points are ST_Point(...) expressions; radius is an expression over lat_b.
        testPointAndRadiusExpressionsInDistanceQuery("ST_Distance(ST_Point(lng_a, lat_a), ST_Point(lng_b, lat_b)) <= 500 / (111000 * cos(lat_b))", "ST_Distance(point_a, point_b) <= radius", "500 / (111000 * cos(lat_b))");
        testPointAndRadiusExpressionsInDistanceQuery("ST_Distance(ST_Point(lng_b, lat_b), ST_Point(lng_a, lat_a)) <= 500 / (111000 * cos(lat_b))", "ST_Distance(point_b, point_a) <= radius", "500 / (111000 * cos(lat_b))");
        testPointAndRadiusExpressionsInDistanceQuery("500 / (111000 * cos(lat_b)) >= ST_Distance(ST_Point(lng_a, lat_a), ST_Point(lng_b, lat_b))", "ST_Distance(point_a, point_b) <= radius", "500 / (111000 * cos(lat_b))");
        testPointAndRadiusExpressionsInDistanceQuery("500 / (111000 * cos(lat_b)) >= ST_Distance(ST_Point(lng_b, lat_b), ST_Point(lng_a, lat_a))", "ST_Distance(point_b, point_a) <= radius", "500 / (111000 * cos(lat_b))");
        testPointAndRadiusExpressionsInDistanceQuery("ST_Distance(ST_Point(lng_a, lat_a), ST_Point(lng_b, lat_b)) < 500 / (111000 * cos(lat_b))", "ST_Distance(point_a, point_b) < radius", "500 / (111000 * cos(lat_b))");
        testPointAndRadiusExpressionsInDistanceQuery("ST_Distance(ST_Point(lng_b, lat_b), ST_Point(lng_a, lat_a)) < 500 / (111000 * cos(lat_b))", "ST_Distance(point_b, point_a) < radius", "500 / (111000 * cos(lat_b))");
        testPointAndRadiusExpressionsInDistanceQuery("500 / (111000 * cos(lat_b)) > ST_Distance(ST_Point(lng_a, lat_a), ST_Point(lng_b, lat_b))", "ST_Distance(point_a, point_b) < radius", "500 / (111000 * cos(lat_b))");
        testPointAndRadiusExpressionsInDistanceQuery("500 / (111000 * cos(lat_b)) > ST_Distance(ST_Point(lng_b, lat_b), ST_Point(lng_a, lat_a))", "ST_Distance(point_b, point_a) < radius", "500 / (111000 * cos(lat_b))");
        testPointAndRadiusExpressionsInDistanceQuery("ST_Distance(ST_Point(lng_a, lat_a), ST_Point(lng_b, lat_b)) <= 500 / (111000 * cos(lat_b)) AND name_a != name_b", "ST_Distance(point_a, point_b) <= radius AND name_a != name_b", "500 / (111000 * cos(lat_b))");
        testPointAndRadiusExpressionsInDistanceQuery("500 / (111000 * cos(lat_b)) > ST_Distance(ST_Point(lng_a, lat_a), ST_Point(lng_b, lat_b)) AND name_a != name_b", "ST_Distance(point_a, point_b) < radius AND name_a != name_b", "500 / (111000 * cos(lat_b))");
    }

    /**
     * Applies the rule to a filter over a join of (a, name_a) with (b, name_b, r) and expects a
     * spatial join directly over the two values nodes.
     *
     * @param filter    filter expression placed above the join
     * @param newFilter filter expected on the resulting spatial join
     */
    private void testSimpleDistanceQuery(String filter, String newFilter) {
        assertRuleApplication()
                .on(p -> p.filter(
                        PlanBuilder.expression(filter),
                        p.join(
                                JoinNode.Type.INNER,
                                p.values(new Symbol[] {p.symbol("a", GeometryType.GEOMETRY), p.symbol("name_a")}),
                                p.values(new Symbol[] {p.symbol("b", GeometryType.GEOMETRY), p.symbol("name_b"), p.symbol("r")}),
                                NO_CRITERIA)))
                .matches(PlanMatchPattern.spatialJoin(
                        newFilter,
                        PlanMatchPattern.values(ImmutableMap.of("a", 0, "name_a", 1)),
                        PlanMatchPattern.values(ImmutableMap.of("b", 0, "name_b", 1, "r", 2))));
    }

    /**
     * Same input plan as {@link #testSimpleDistanceQuery}, but the expected right side carries a
     * projection that defines {@code radius}.
     *
     * @param filter           filter expression placed above the join
     * @param newFilter        filter expected on the resulting spatial join
     * @param radiusExpression expression expected to be projected as {@code radius}
     */
    private void testRadiusExpressionInDistanceQuery(String filter, String newFilter, String radiusExpression) {
        assertRuleApplication()
                .on(p -> p.filter(
                        PlanBuilder.expression(filter),
                        p.join(
                                JoinNode.Type.INNER,
                                p.values(new Symbol[] {p.symbol("a", GeometryType.GEOMETRY), p.symbol("name_a")}),
                                p.values(new Symbol[] {p.symbol("b", GeometryType.GEOMETRY), p.symbol("name_b"), p.symbol("r")}),
                                NO_CRITERIA)))
                .matches(PlanMatchPattern.spatialJoin(
                        newFilter,
                        PlanMatchPattern.values(ImmutableMap.of("a", 0, "name_a", 1)),
                        PlanMatchPattern.project(
                                ImmutableMap.of("radius", PlanMatchPattern.expression(radiusExpression)),
                                PlanMatchPattern.values(ImmutableMap.of("b", 0, "name_b", 1, "r", 2)))));
    }

    /**
     * Applies the rule to a filter over a join of (lat_a, lng_a, name_a) with
     * (lat_b, lng_b, name_b) and expects {@code point_a} projected on the left, and
     * {@code point_b} projected above {@code radius} on the right.
     *
     * @param filter           filter expression placed above the join
     * @param newFilter        filter expected on the resulting spatial join
     * @param radiusExpression expression expected to be projected as {@code radius}
     */
    private void testPointExpressionsInDistanceQuery(String filter, String newFilter, String radiusExpression) {
        assertRuleApplication()
                .on(p -> p.filter(
                        PlanBuilder.expression(filter),
                        p.join(
                                JoinNode.Type.INNER,
                                p.values(new Symbol[] {p.symbol("lat_a"), p.symbol("lng_a"), p.symbol("name_a")}),
                                p.values(new Symbol[] {p.symbol("lat_b"), p.symbol("lng_b"), p.symbol("name_b")}),
                                NO_CRITERIA)))
                .matches(PlanMatchPattern.spatialJoin(
                        newFilter,
                        PlanMatchPattern.project(
                                ImmutableMap.of("point_a", PlanMatchPattern.expression("ST_Point(lng_a, lat_a)")),
                                PlanMatchPattern.values(ImmutableMap.of("lat_a", 0, "lng_a", 1, "name_a", 2))),
                        PlanMatchPattern.project(
                                ImmutableMap.of("point_b", PlanMatchPattern.expression("ST_Point(lng_b, lat_b)")),
                                PlanMatchPattern.project(
                                        ImmutableMap.of("radius", PlanMatchPattern.expression(radiusExpression)),
                                        PlanMatchPattern.values(ImmutableMap.of("lat_b", 0, "lng_b", 1, "name_b", 2))))));
    }

    /**
     * Variant used when the radius is itself an expression over right-side columns. The input
     * plan and the expected plan shape are exactly those of
     * {@link #testPointExpressionsInDistanceQuery}, so this simply forwards to it.
     *
     * @param filter           filter expression placed above the join
     * @param newFilter        filter expected on the resulting spatial join
     * @param radiusExpression expression expected to be projected as {@code radius}
     */
    private void testPointAndRadiusExpressionsInDistanceQuery(String filter, String newFilter, String radiusExpression) {
        testPointExpressionsInDistanceQuery(filter, newFilter, radiusExpression);
    }

    /** ST_Contains over plain columns: the join becomes a spatial join over the same values nodes. */
    @Test
    public void testConvertToSpatialJoin() {
        // Single spatial predicate.
        assertRuleApplication()
                .on(p -> p.filter(
                        PlanBuilder.expression("ST_Contains(a, b)"),
                        p.join(
                                JoinNode.Type.INNER,
                                p.values(new Symbol[] {p.symbol("a")}),
                                p.values(new Symbol[] {p.symbol("b")}),
                                NO_CRITERIA)))
                .matches(PlanMatchPattern.spatialJoin(
                        "ST_Contains(a, b)",
                        PlanMatchPattern.values(ImmutableMap.of("a", 0)),
                        PlanMatchPattern.values(ImmutableMap.of("b", 0))));

        // Spatial predicate AND-ed with a non-spatial condition.
        assertRuleApplication()
                .on(p -> p.filter(
                        PlanBuilder.expression("name_1 != name_2 AND ST_Contains(a, b)"),
                        p.join(
                                JoinNode.Type.INNER,
                                p.values(new Symbol[] {p.symbol("a"), p.symbol("name_1")}),
                                p.values(new Symbol[] {p.symbol("b"), p.symbol("name_2")}),
                                NO_CRITERIA)))
                .matches(PlanMatchPattern.spatialJoin(
                        "name_1 != name_2 AND ST_Contains(a, b)",
                        PlanMatchPattern.values(ImmutableMap.of("a", 0, "name_1", 1)),
                        PlanMatchPattern.values(ImmutableMap.of("b", 0, "name_2", 1))));

        // Two spatial predicates AND-ed together.
        assertRuleApplication()
                .on(p -> p.filter(
                        PlanBuilder.expression("ST_Contains(a1, b1) AND ST_Contains(a2, b2)"),
                        p.join(
                                JoinNode.Type.INNER,
                                p.values(new Symbol[] {p.symbol("a1"), p.symbol("a2")}),
                                p.values(new Symbol[] {p.symbol("b1"), p.symbol("b2")}),
                                NO_CRITERIA)))
                .matches(PlanMatchPattern.spatialJoin(
                        "ST_Contains(a1, b1) AND ST_Contains(a2, b2)",
                        PlanMatchPattern.values(ImmutableMap.of("a1", 0, "a2", 1)),
                        PlanMatchPattern.values(ImmutableMap.of("b1", 0, "b2", 1))));
    }

    /** First ST_Contains argument is an expression over a left-side column. */
    @Test
    public void testPushDownFirstArgument() {
        assertRuleApplication()
                .on(p -> p.filter(
                        PlanBuilder.expression("ST_Contains(ST_GeometryFromText(wkt), point)"),
                        p.join(
                                JoinNode.Type.INNER,
                                p.values(new Symbol[] {p.symbol("wkt", VarcharType.VARCHAR)}),
                                p.values(new Symbol[] {p.symbol("point", GeometryType.GEOMETRY)}),
                                NO_CRITERIA)))
                .matches(PlanMatchPattern.spatialJoin(
                        "ST_Contains(st_geometryfromtext, point)",
                        PlanMatchPattern.project(
                                ImmutableMap.of("st_geometryfromtext", PlanMatchPattern.expression("ST_GeometryFromText(wkt)")),
                                PlanMatchPattern.values(ImmutableMap.of("wkt", 0))),
                        PlanMatchPattern.values(ImmutableMap.of("point", 0))));

        // Second argument is a constant point and the right join input has no columns.
        assertRuleApplication()
                .on(p -> p.filter(
                        PlanBuilder.expression("ST_Contains(ST_GeometryFromText(wkt), ST_Point(0, 0))"),
                        p.join(
                                JoinNode.Type.INNER,
                                p.values(new Symbol[] {p.symbol("wkt", VarcharType.VARCHAR)}),
                                p.values(new Symbol[0]),
                                NO_CRITERIA)))
                .doesNotFire();
    }

    /** Second ST_Contains argument is an expression over right-side columns. */
    @Test
    public void testPushDownSecondArgument() {
        assertRuleApplication()
                .on(p -> p.filter(
                        PlanBuilder.expression("ST_Contains(polygon, ST_Point(lng, lat))"),
                        p.join(
                                JoinNode.Type.INNER,
                                p.values(new Symbol[] {p.symbol("polygon", GeometryType.GEOMETRY)}),
                                p.values(new Symbol[] {p.symbol("lat"), p.symbol("lng")}),
                                NO_CRITERIA)))
                .matches(PlanMatchPattern.spatialJoin(
                        "ST_Contains(polygon, st_point)",
                        PlanMatchPattern.values(ImmutableMap.of("polygon", 0)),
                        PlanMatchPattern.project(
                                ImmutableMap.of("st_point", PlanMatchPattern.expression("ST_Point(lng, lat)")),
                                PlanMatchPattern.values(ImmutableMap.of("lat", 0, "lng", 1)))));

        // First argument is a literal geometry and the left join input has no columns.
        assertRuleApplication()
                .on(p -> p.filter(
                        PlanBuilder.expression("ST_Contains(ST_GeometryFromText('POLYGON ...'), ST_Point(lng, lat))"),
                        p.join(
                                JoinNode.Type.INNER,
                                p.values(new Symbol[0]),
                                p.values(new Symbol[] {p.symbol("lat"), p.symbol("lng")}),
                                NO_CRITERIA)))
                .doesNotFire();
    }

    /** Both ST_Contains arguments are expressions, one per join side. */
    @Test
    public void testPushDownBothArguments() {
        assertRuleApplication()
                .on(p -> p.filter(
                        PlanBuilder.expression("ST_Contains(ST_GeometryFromText(wkt), ST_Point(lng, lat))"),
                        p.join(
                                JoinNode.Type.INNER,
                                p.values(new Symbol[] {p.symbol("wkt", VarcharType.VARCHAR)}),
                                p.values(new Symbol[] {p.symbol("lat"), p.symbol("lng")}),
                                NO_CRITERIA)))
                .matches(PlanMatchPattern.spatialJoin(
                        "ST_Contains(st_geometryfromtext, st_point)",
                        PlanMatchPattern.project(
                                ImmutableMap.of("st_geometryfromtext", PlanMatchPattern.expression("ST_GeometryFromText(wkt)")),
                                PlanMatchPattern.values(ImmutableMap.of("wkt", 0))),
                        PlanMatchPattern.project(
                                ImmutableMap.of("st_point", PlanMatchPattern.expression("ST_Point(lng, lat)")),
                                PlanMatchPattern.values(ImmutableMap.of("lat", 0, "lng", 1)))));
    }

    /** Same as {@link #testPushDownBothArguments} but with the join inputs swapped. */
    @Test
    public void testPushDownOppositeOrder() {
        assertRuleApplication()
                .on(p -> p.filter(
                        PlanBuilder.expression("ST_Contains(ST_GeometryFromText(wkt), ST_Point(lng, lat))"),
                        p.join(
                                JoinNode.Type.INNER,
                                p.values(new Symbol[] {p.symbol("lat"), p.symbol("lng")}),
                                p.values(new Symbol[] {p.symbol("wkt", VarcharType.VARCHAR)}),
                                NO_CRITERIA)))
                .matches(PlanMatchPattern.spatialJoin(
                        "ST_Contains(st_geometryfromtext, st_point)",
                        PlanMatchPattern.project(
                                ImmutableMap.of("st_point", PlanMatchPattern.expression("ST_Point(lng, lat)")),
                                PlanMatchPattern.values(ImmutableMap.of("lat", 0, "lng", 1))),
                        PlanMatchPattern.project(
                                ImmutableMap.of("st_geometryfromtext", PlanMatchPattern.expression("ST_GeometryFromText(wkt)")),
                                PlanMatchPattern.values(ImmutableMap.of("wkt", 0)))));
    }

    /** Spatial predicates with expression arguments inside an AND. */
    @Test
    public void testPushDownAnd() {
        // Non-spatial condition AND spatial predicate.
        assertRuleApplication()
                .on(p -> p.filter(
                        PlanBuilder.expression("name_1 != name_2 AND ST_Contains(ST_GeometryFromText(wkt), ST_Point(lng, lat))"),
                        p.join(
                                JoinNode.Type.INNER,
                                p.values(new Symbol[] {p.symbol("wkt", VarcharType.VARCHAR), p.symbol("name_1")}),
                                p.values(new Symbol[] {p.symbol("lat"), p.symbol("lng"), p.symbol("name_2")}),
                                NO_CRITERIA)))
                .matches(PlanMatchPattern.spatialJoin(
                        "name_1 != name_2 AND ST_Contains(st_geometryfromtext, st_point)",
                        PlanMatchPattern.project(
                                ImmutableMap.of("st_geometryfromtext", PlanMatchPattern.expression("ST_GeometryFromText(wkt)")),
                                PlanMatchPattern.values(ImmutableMap.of("wkt", 0, "name_1", 1))),
                        PlanMatchPattern.project(
                                ImmutableMap.of("st_point", PlanMatchPattern.expression("ST_Point(lng, lat)")),
                                PlanMatchPattern.values(ImmutableMap.of("lat", 0, "lng", 1, "name_2", 2)))));

        // Two spatial predicates: the expected plan projects only the first one's argument.
        assertRuleApplication()
                .on(p -> p.filter(
                        PlanBuilder.expression("ST_Contains(ST_GeometryFromText(wkt1), geometry1) AND ST_Contains(ST_GeometryFromText(wkt2), geometry2)"),
                        p.join(
                                JoinNode.Type.INNER,
                                p.values(new Symbol[] {p.symbol("wkt1", VarcharType.VARCHAR), p.symbol("wkt2", VarcharType.VARCHAR)}),
                                p.values(new Symbol[] {p.symbol("geometry1"), p.symbol("geometry2")}),
                                NO_CRITERIA)))
                .matches(PlanMatchPattern.spatialJoin(
                        "ST_Contains(st_geometryfromtext, geometry1) AND ST_Contains(ST_GeometryFromText(wkt2), geometry2)",
                        PlanMatchPattern.project(
                                ImmutableMap.of("st_geometryfromtext", PlanMatchPattern.expression("ST_GeometryFromText(wkt1)")),
                                PlanMatchPattern.values(ImmutableMap.of("wkt1", 0, "wkt2", 1))),
                        PlanMatchPattern.values(ImmutableMap.of("geometry1", 0, "geometry2", 1))));
    }

    /**
     * Starts a rule assertion for a new {@code ExtractSpatialInnerJoin} instance wired to the
     * metadata, split manager, page source manager and type analyzer of the shared tester.
     */
    private RuleAssert assertRuleApplication() {
        RuleTester ruleTester = tester();
        return ruleTester.assertThat(new ExtractSpatialJoins.ExtractSpatialInnerJoin(
                ruleTester.getMetadata(),
                ruleTester.getSplitManager(),
                ruleTester.getPageSourceManager(),
                ruleTester.getTypeAnalyzer()));
    }
}
