package io.prestosql.server.security;

import com.google.common.base.Joiner;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.io.ByteStreams;
import com.google.common.net.MediaType;
import io.prestosql.server.HttpRequestSessionContext;
import io.prestosql.server.InternalAuthenticationManager;
import io.prestosql.server.ui.WebUiAuthenticationManager;
import io.prestosql.spi.security.Identity;
import java.io.IOException;
import java.io.PrintWriter;
import java.security.Principal;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import javax.inject.Inject;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;

/* loaded from: input_file:io/prestosql/server/security/AuthenticationFilter.class */
/**
 * Servlet filter that establishes the authenticated {@link Identity} for a request before
 * handing it to the rest of the filter chain.
 *
 * <p>Requests are dispatched in this order:
 * <ol>
 *   <li>requests recognised by the {@link InternalAuthenticationManager} are authenticated by it;</li>
 *   <li>requests recognised by {@link WebUiAuthenticationManager#isUiRequest} are delegated to the
 *       UI authentication manager;</li>
 *   <li>requests that do not support authentication (no authenticators configured, or a request
 *       that is neither secure nor forwarded as HTTPS) go through the insecure path;</li>
 *   <li>all other requests are offered to each configured {@link Authenticator} in order; the
 *       first one that succeeds wins, otherwise a 401 response is produced.</li>
 * </ol>
 */
public class AuthenticationFilter implements Filter {
    private static final String HTTPS_PROTOCOL = "https";
    private static final String X_FORWARDED_PROTO = "X-Forwarded-Proto";
    private static final String WWW_AUTHENTICATE = "WWW-Authenticate";
    private static final String INTERNAL_USER = "<internal>";
    private final List<Authenticator> authenticators;
    private final boolean httpsForwardingEnabled;
    private final InternalAuthenticationManager internalAuthenticationManager;
    private final WebUiAuthenticationManager uiAuthenticationManager;

    /**
     * Creates the filter.
     *
     * @param authenticators authenticators to try, in order; copied defensively
     * @param securityConfig source of the "forwarded HTTPS" setting
     * @param internalAuthenticationManager handler for internal requests
     * @param uiAuthenticationManager handler for web UI requests
     * @throws NullPointerException if any argument is null
     */
    @Inject
    public AuthenticationFilter(List<Authenticator> authenticators, SecurityConfig securityConfig, InternalAuthenticationManager internalAuthenticationManager, WebUiAuthenticationManager uiAuthenticationManager) {
        this.authenticators = ImmutableList.copyOf(Objects.requireNonNull(authenticators, "authenticators is null"));
        this.httpsForwardingEnabled = Objects.requireNonNull(securityConfig, "securityConfig is null").getEnableForwardingHttps();
        this.internalAuthenticationManager = Objects.requireNonNull(internalAuthenticationManager, "internalAuthenticationManager is null");
        this.uiAuthenticationManager = Objects.requireNonNull(uiAuthenticationManager, "uiAuthenticationManager is null");
    }

    /** No initialisation is required. */
    @Override
    public void init(FilterConfig filterConfig) {
    }

    /** No resources are held, so there is nothing to release. */
    @Override
    public void destroy() {
    }

    /**
     * Authenticates the request and, on success, continues the filter chain with the
     * authenticated identity attached to the request.
     *
     * <p>When every authenticator rejects the request, the request body is drained, each distinct
     * authenticate header collected from the failures is added as {@code WWW-Authenticate}, and a
     * 401 response is sent whose text is the distinct failure messages joined with {@code " | "}
     * (or {@code "Unauthorized"} when no message was provided).
     *
     * @param servletRequest the incoming request; must be an {@link HttpServletRequest}
     * @param servletResponse the outgoing response; must be an {@link HttpServletResponse}
     * @param filterChain the remainder of the chain
     * @throws IOException if reading the request or writing the response fails
     * @throws ServletException if a downstream filter or servlet fails
     */
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        HttpServletRequest request = (HttpServletRequest) servletRequest;
        HttpServletResponse response = (HttpServletResponse) servletResponse;

        if (internalAuthenticationManager.isInternalRequest(request)) {
            Principal principal = internalAuthenticationManager.authenticateInternalRequest(request);
            if (principal == null) {
                // Internal request that failed authentication: reject with a bare 401.
                response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
                response.setContentType(MediaType.PLAIN_TEXT_UTF_8.toString());
                return;
            }
            withAuthenticatedIdentity(filterChain, request, response, Identity.forUser(INTERNAL_USER).withPrincipal(principal).build());
            return;
        }

        if (WebUiAuthenticationManager.isUiRequest(request)) {
            uiAuthenticationManager.handleUiRequest(request, response, filterChain);
            return;
        }

        if (!doesRequestSupportAuthentication(request)) {
            handleInsecureRequest(filterChain, request, response);
            return;
        }

        // Insertion-ordered sets: duplicates are dropped while the authenticator order is kept.
        LinkedHashSet<String> messages = new LinkedHashSet<>();
        LinkedHashSet<String> authenticateHeaders = new LinkedHashSet<>();
        for (Authenticator authenticator : authenticators) {
            Identity authenticatedIdentity;
            // Only the authenticate call is guarded, so a failure raised further down the
            // chain can never be mistaken for an authentication failure.
            try {
                authenticatedIdentity = authenticator.authenticate(request);
            } catch (AuthenticationException e) {
                if (e.getMessage() != null) {
                    messages.add(e.getMessage());
                }
                e.getAuthenticateHeader().ifPresent(authenticateHeaders::add);
                continue;
            }
            withAuthenticatedIdentity(filterChain, request, response, authenticatedIdentity);
            return;
        }

        // Every authenticator rejected the request.
        skipRequestBody(request);
        for (String header : authenticateHeaders) {
            response.addHeader(WWW_AUTHENTICATE, header);
        }
        if (messages.isEmpty()) {
            messages.add("Unauthorized");
        }
        sendErrorMessage(response, HttpServletResponse.SC_UNAUTHORIZED, Joiner.on(" | ").join(messages));
    }

    /**
     * Sends a plain-text error response whose status message and body are both
     * {@code errorMessage}.
     *
     * @param response the response to write to
     * @param errorCode the HTTP status code
     * @param errorMessage the text to send; must not be null
     * @throws IOException if the response writer cannot be obtained
     */
    private static void sendErrorMessage(HttpServletResponse response, int errorCode, String errorMessage) throws IOException {
        // setStatus(int, String) is deprecated in the Servlet API; it is kept so the status
        // message that is sent today does not change.
        response.setStatus(errorCode, errorMessage);
        response.setContentType(MediaType.PLAIN_TEXT_UTF_8.toString());
        try (PrintWriter writer = response.getWriter()) {
            writer.write(errorMessage);
        }
    }

    /**
     * Handles a request for which authentication is not supported.
     *
     * <p>Without basic-auth credentials the request continues unauthenticated. Credentials that
     * carry a password are rejected with 403; credentials with only a user name continue the
     * chain as that user. A failure to extract the credentials is reported as 403.
     *
     * @throws IOException if writing the response fails
     * @throws ServletException if a downstream filter or servlet fails
     */
    private static void handleInsecureRequest(FilterChain filterChain, HttpServletRequest request, HttpServletResponse response) throws IOException, ServletException {
        Optional<BasicAuthCredentials> credentials;
        try {
            credentials = BasicAuthCredentials.extractBasicAuthCredentials(request);
        } catch (AuthenticationException e) {
            // The message may be null (doFilter guards against the same case); writing a null
            // string to the response writer would fail with a NullPointerException.
            String message = e.getMessage() != null ? e.getMessage() : "Forbidden";
            sendErrorMessage(response, HttpServletResponse.SC_FORBIDDEN, message);
            return;
        }

        if (!credentials.isPresent()) {
            filterChain.doFilter(request, response);
            return;
        }
        if (credentials.get().getPassword().isPresent()) {
            sendErrorMessage(response, HttpServletResponse.SC_FORBIDDEN, "Password not allowed for insecure request");
            return;
        }
        withAuthenticatedIdentity(filterChain, request, response, Identity.ofUser(credentials.get().getUser()));
    }

    /**
     * Returns whether the configured authenticators should be applied to the request: at least
     * one authenticator exists, and the request is either secure or, when HTTPS forwarding is
     * enabled, carries an {@code X-Forwarded-Proto} header equal to {@code https} ignoring case.
     */
    private boolean doesRequestSupportAuthentication(HttpServletRequest request) {
        if (authenticators.isEmpty()) {
            return false;
        }
        if (request.isSecure()) {
            return true;
        }
        return httpsForwardingEnabled && Strings.nullToEmpty(request.getHeader(X_FORWARDED_PROTO)).equalsIgnoreCase(HTTPS_PROTOCOL);
    }

    /**
     * Stores {@code identity} as a request attribute, continues the chain with the identity's
     * principal (if any) exposed as the user principal, and finally destroys whatever identity is
     * stored under the attribute at that point, whether the chain succeeded or failed.
     *
     * @throws IOException if a downstream filter or servlet fails with an I/O error
     * @throws ServletException if a downstream filter or servlet fails
     */
    private static void withAuthenticatedIdentity(FilterChain filterChain, HttpServletRequest request, HttpServletResponse response, Identity identity) throws IOException, ServletException {
        request.setAttribute(HttpRequestSessionContext.AUTHENTICATED_IDENTITY, identity);
        try {
            filterChain.doFilter(withPrincipal(request, identity.getPrincipal()), response);
        } finally {
            // Re-read the attribute rather than using the local variable, so the identity that
            // is destroyed is the one currently attached to the request.
            Optional.ofNullable(request.getAttribute(HttpRequestSessionContext.AUTHENTICATED_IDENTITY))
                    .map(Identity.class::cast)
                    .ifPresent(Identity::destroy);
        }
    }

    /**
     * Returns the request unchanged when no principal is present; otherwise returns a wrapper
     * whose {@code getUserPrincipal()} reports the given principal.
     *
     * @throws NullPointerException if {@code principal} is null
     */
    private static ServletRequest withPrincipal(HttpServletRequest request, Optional<Principal> principal) {
        Objects.requireNonNull(principal, "principal is null");
        if (!principal.isPresent()) {
            return request;
        }
        Principal userPrincipal = principal.get();
        return new HttpServletRequestWrapper(request) {
            @Override
            public Principal getUserPrincipal() {
                return userPrincipal;
            }
        };
    }

    /**
     * Reads the request body to the end, discarding it, and closes the input stream.
     *
     * @throws IOException if reading the body fails
     */
    private static void skipRequestBody(HttpServletRequest request) throws IOException {
        try (ServletInputStream inputStream = request.getInputStream()) {
            ByteStreams.copy(inputStream, ByteStreams.nullOutputStream());
        }
    }
}
