package io.trino.plugin.opa;

import com.fasterxml.jackson.databind.JsonNode;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import com.google.common.net.MediaType;
import io.opentelemetry.api.OpenTelemetry;
import io.opentelemetry.api.trace.Tracer;
import io.trino.execution.QueryIdGenerator;
import io.trino.plugin.opa.HttpClientUtils;
import io.trino.plugin.opa.OpaQueryException;
import io.trino.spi.security.AccessDeniedException;
import io.trino.spi.security.Identity;
import io.trino.spi.security.SystemAccessControlFactory;
import io.trino.spi.security.SystemSecurityContext;
import java.net.URI;
import java.time.Instant;
import java.util.Arrays;
import java.util.Optional;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Stream;
import org.junit.jupiter.api.Named;
import org.junit.jupiter.params.provider.Arguments;

/* loaded from: input_file:io/trino/plugin/opa/TestHelpers.class */
/**
 * Shared fixtures for the OPA access-control tests: canned OPA HTTP responses, JUnit argument
 * providers describing error scenarios, and factories for mock HTTP clients and
 * {@link OpaAccessControl} instances.
 */
public final class TestHelpers {
    /** HTTP 200 response whose {@code result} is {@code true}. */
    public static final HttpClientUtils.MockResponse OK_RESPONSE =
            new HttpClientUtils.MockResponse("{\n    \"decision_id\": \"\",\n    \"result\": true\n}\n", 200);

    /** HTTP 200 response whose {@code result} is {@code false}; expected to surface as {@link AccessDeniedException}. */
    public static final HttpClientUtils.MockResponse NO_ACCESS_RESPONSE =
            new HttpClientUtils.MockResponse("{\n    \"decision_id\": \"\",\n    \"result\": false\n}\n", 200);

    /** HTTP 200 response with a body that is not valid JSON. */
    public static final HttpClientUtils.MockResponse MALFORMED_RESPONSE =
            new HttpClientUtils.MockResponse("{ \"this\"\": is broken_json; }\n", 200);

    /** HTTP 404 response with an empty JSON object; expected to surface as a policy-not-found error. */
    public static final HttpClientUtils.MockResponse UNDEFINED_RESPONSE =
            new HttpClientUtils.MockResponse("{}", 404);

    /** HTTP 400 response with an empty JSON object. */
    public static final HttpClientUtils.MockResponse BAD_REQUEST_RESPONSE =
            new HttpClientUtils.MockResponse("{}", 400);

    /** HTTP 500 response with an empty body. */
    public static final HttpClientUtils.MockResponse SERVER_ERROR_RESPONSE =
            new HttpClientUtils.MockResponse("", 500);

    /** Access-control context handed to {@link OpaAccessControlFactory} by the authorizer factories below. */
    public static final SystemAccessControlFactory.SystemAccessControlContext SYSTEM_ACCESS_CONTROL_CONTEXT =
            new TestingSystemAccessControlContext("TEST_VERSION");

    /** Adapts a call on {@link OpaAccessControl} into a uniform "was access allowed?" check. */
    public abstract static class MethodWrapper {
        /**
         * Invokes the wrapped call against the given authorizer.
         *
         * @param opaAccessControl the authorizer under test
         * @return {@code true} if the call reported that access is allowed
         */
        public abstract boolean isAccessAllowed(OpaAccessControl opaAccessControl);
    }

    /** Wraps a check whose {@code Boolean} return value is itself the access decision. */
    public static class ReturningMethodWrapper extends MethodWrapper {
        private final Function<OpaAccessControl, Boolean> callable;

        public ReturningMethodWrapper(Function<OpaAccessControl, Boolean> callable) {
            this.callable = callable;
        }

        @Override
        public boolean isAccessAllowed(OpaAccessControl opaAccessControl) {
            // Unboxing intentionally throws NullPointerException if the wrapped call returns null.
            return callable.apply(opaAccessControl).booleanValue();
        }
    }

    /** Minimal context implementation reporting a fixed version and no telemetry. */
    static final class TestingSystemAccessControlContext implements SystemAccessControlFactory.SystemAccessControlContext {
        private final String trinoVersion;

        public TestingSystemAccessControlContext(String trinoVersion) {
            this.trinoVersion = trinoVersion;
        }

        @Override
        public String getVersion() {
            return trinoVersion;
        }

        // NOTE(review): telemetry accessors return null; confirm the code under test never dereferences them.
        @Override
        public OpenTelemetry getOpenTelemetry() {
            return null;
        }

        @Override
        public Tracer getTracer() {
            return null;
        }
    }

    /**
     * Wraps a check that signals denial by throwing {@link AccessDeniedException}: completing
     * normally means allowed, throwing an exception whose message contains "Access Denied" means
     * denied, and any other {@code AccessDeniedException} fails the test.
     */
    public static class ThrowingMethodWrapper extends MethodWrapper {
        private final Consumer<OpaAccessControl> callable;

        public ThrowingMethodWrapper(Consumer<OpaAccessControl> callable) {
            this.callable = callable;
        }

        @Override
        public boolean isAccessAllowed(OpaAccessControl opaAccessControl) {
            try {
                callable.accept(opaAccessControl);
                return true;
            }
            catch (AccessDeniedException e) {
                String message = e.getMessage();
                if (message != null && message.contains("Access Denied")) {
                    return false;
                }
                // Keep the original exception as the cause so its actual message is visible on failure.
                throw new AssertionError("Expected AccessDenied exception to contain 'Access Denied' in the message", e);
            }
        }
    }

    private TestHelpers() {
    }

    /**
     * Pairs every base test case with every case from {@link #allErrorCasesArgumentProvider()}.
     *
     * @param stream base test-case arguments
     * @return one {@link Arguments} per pair: the base arguments followed by the error-case arguments
     */
    public static Stream<Arguments> createFailingTestCases(Stream<Arguments> stream) {
        return crossJoin(stream, allErrorCasesArgumentProvider());
    }

    /**
     * Pairs every base test case with every case from {@link #illegalResponseArgumentProvider()}.
     *
     * @param stream base test-case arguments
     * @return one {@link Arguments} per pair: the base arguments followed by the error-case arguments
     */
    public static Stream<Arguments> createIllegalResponseTestCases(Stream<Arguments> stream) {
        return crossJoin(stream, illegalResponseArgumentProvider());
    }

    /** Cartesian product of two argument streams, flattening each pair into a single {@link Arguments}. */
    private static Stream<Arguments> crossJoin(Stream<Arguments> baseCases, Stream<Arguments> errorCases) {
        Set<Arguments> baseSet = baseCases.collect(ImmutableSet.toImmutableSet());
        Set<Arguments> errorSet = errorCases.collect(ImmutableSet.toImmutableSet());
        return Sets.cartesianProduct(baseSet, errorSet).stream()
                .map(pair -> Arguments.of(pair.stream()
                        .flatMap(arguments -> Arrays.stream(arguments.get()))
                        .toArray()));
    }

    /**
     * Error cases for OPA responses that are not a usable decision.
     *
     * @return arguments of the form (named mock response, expected exception class, expected message fragment)
     */
    public static Stream<Arguments> illegalResponseArgumentProvider() {
        return Stream.of(
                Arguments.of(Named.of("Undefined policy response", UNDEFINED_RESPONSE), OpaQueryException.PolicyNotFound.class, "did not return a value"),
                Arguments.of(Named.of("Bad request response", BAD_REQUEST_RESPONSE), OpaQueryException.OpaServerError.class, "returned status 400"),
                Arguments.of(Named.of("Server error response", SERVER_ERROR_RESPONSE), OpaQueryException.OpaServerError.class, "returned status 500"),
                Arguments.of(Named.of("Malformed JSON response", MALFORMED_RESPONSE), OpaQueryException.class, "Failed to deserialize"));
    }

    /**
     * All error cases: the illegal-response cases plus a well-formed "no access" decision.
     *
     * @return arguments of the form (named mock response, expected exception class, expected message fragment)
     */
    public static Stream<Arguments> allErrorCasesArgumentProvider() {
        Arguments noAccess = Arguments.of(Named.of("No access response", NO_ACCESS_RESPONSE), AccessDeniedException.class, "Access Denied");
        return Stream.concat(illegalResponseArgumentProvider(), Stream.of(noAccess));
    }

    /** Builds a security context for {@code identity} with a freshly generated query id and the current time. */
    public static SystemSecurityContext systemSecurityContextFromIdentity(Identity identity) {
        return new SystemSecurityContext(identity, new QueryIdGenerator().createNextQueryId(), Instant.now());
    }

    /**
     * Creates an instrumented mock HTTP client for {@code uri}.
     *
     * <p>"POST" and the JSON UTF-8 media type are passed to the client, presumably as the expected
     * request method and content type (confirm against {@code HttpClientUtils}).
     *
     * @param uri endpoint the mock serves
     * @param function maps the request JSON body to the mock response to return
     */
    public static HttpClientUtils.InstrumentedHttpClient createMockHttpClient(URI uri, Function<JsonNode, HttpClientUtils.MockResponse> function) {
        return new HttpClientUtils.InstrumentedHttpClient(uri, "POST", MediaType.JSON_UTF_8.toString(), function);
    }

    /** Creates an authorizer configured with only {@code opa.policy.uri}, backed by the given client. */
    public static OpaAccessControl createOpaAuthorizer(URI uri, HttpClientUtils.InstrumentedHttpClient instrumentedHttpClient) {
        return OpaAccessControlFactory.create(
                ImmutableMap.of("opa.policy.uri", uri.toString()),
                Optional.of(instrumentedHttpClient),
                Optional.of(SYSTEM_ACCESS_CONTROL_CONTEXT));
    }

    /**
     * Creates an authorizer configured with both {@code opa.policy.uri} and
     * {@code opa.policy.batched-uri}, backed by the given client.
     */
    public static OpaAccessControl createOpaAuthorizer(URI uri, URI uri2, HttpClientUtils.InstrumentedHttpClient instrumentedHttpClient) {
        return OpaAccessControlFactory.create(
                ImmutableMap.of(
                        "opa.policy.uri", uri.toString(),
                        "opa.policy.batched-uri", uri2.toString()),
                Optional.of(instrumentedHttpClient),
                Optional.of(SYSTEM_ACCESS_CONTROL_CONTEXT));
    }
}
