package io.trino.server.security.jwt;

import com.fasterxml.jackson.annotation.JsonAnySetter;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.airlift.json.JsonCodec;
import io.airlift.log.Logger;
import java.math.BigInteger;
import java.security.PublicKey;
import java.security.interfaces.ECPublicKey;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.ECParameterSpec;
import java.security.spec.ECPoint;
import java.util.Base64;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;

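/**
 * Decodes JSON Web Key (JWK) Set documents into {@link PublicKey} instances.
 *
 * <p>Only RSA and EC keys that carry a non-empty {@code kid} are decoded; all other entries are skipped.
 * (Decompiled from io/trino/server/security/jwt/JwkDecoder.class.)
 */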
public final class JwkDecoder {
    /** Logger used to report keys that are skipped because they cannot be decoded. */
    private static final Logger log = Logger.get(JwkDecoder.class);
    /** Codec for the top-level JWK Set document, i.e. an object with a {@code "keys"} array. */
    private static final JsonCodec<Keys> KEYS_CODEC = JsonCodec.jsonCodec(JwkDecoder.Keys.class);

    /**
     * Elliptic-curve public key built from a JWK entry.
     *
     * <p>The curve parameters and public point are exposed through {@link ECPublicKey}. The format is
     * reported as {@code "JWK"}, and no encoded form is available.
     */
    public static class JwkEcPublicKey implements JwkPublicKey, ECPublicKey {
        private final String keyId;
        private final ECParameterSpec parameterSpec;
        private final ECPoint w;

        /**
         * Creates the key.
         *
         * @param keyId the JWK {@code kid}
         * @param parameterSpec the curve parameters
         * @param w the public point
         * @throws NullPointerException if any argument is null
         */
        public JwkEcPublicKey(String keyId, ECParameterSpec parameterSpec, ECPoint w) {
            // Checks run in declaration order, so the first null argument determines the message.
            this.keyId = Objects.requireNonNull(keyId, "keyId is null");
            this.parameterSpec = Objects.requireNonNull(parameterSpec, "parameterSpec is null");
            this.w = Objects.requireNonNull(w, "w is null");
        }

        @Override
        public String getAlgorithm() {
            return "EC";
        }

        @Override
        public String getFormat() {
            return "JWK";
        }

        /**
         * Always fails: this key has no encoded representation.
         *
         * <p>NOTE(review): the {@code java.security.Key} contract allows returning {@code null} when
         * encoding is unsupported; this implementation throws instead.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        public byte[] getEncoded() {
            throw new UnsupportedOperationException();
        }

        @Override
        public String getKeyId() {
            return keyId;
        }

        @Override
        public ECParameterSpec getParams() {
            return parameterSpec;
        }

        @Override
        public ECPoint getW() {
            return w;
        }
    }

    /**
     * Public key decoded from a JWK, which remembers the key id it was published under.
     */
    public interface JwkPublicKey extends PublicKey {
        /**
         * Returns the key id.
         *
         * @return the JWK {@code kid} this key was decoded from
         */
        String getKeyId();
    }

    /**
     * RSA public key built from a JWK entry.
     *
     * <p>The modulus and public exponent are exposed through {@link RSAPublicKey}. The format is reported
     * as {@code "JWK"}, and no encoded form is available.
     */
    public static class JwkRsaPublicKey implements JwkPublicKey, RSAPublicKey {
        private final String keyId;
        private final BigInteger modulus;
        private final BigInteger exponent;

        /**
         * Creates the key.
         *
         * <p>Note the argument order: the exponent comes <em>before</em> the modulus, which is the reverse
         * of {@code java.security.spec.RSAPublicKeySpec}.
         *
         * @param keyId the JWK {@code kid}
         * @param exponent the public exponent
         * @param modulus the modulus
         * @throws NullPointerException if any argument is null
         */
        public JwkRsaPublicKey(String keyId, BigInteger exponent, BigInteger modulus) {
            // Checks run in declaration order, so the first null argument determines the message.
            this.keyId = Objects.requireNonNull(keyId, "keyId is null");
            this.exponent = Objects.requireNonNull(exponent, "exponent is null");
            this.modulus = Objects.requireNonNull(modulus, "modulus is null");
        }

        @Override
        public String getAlgorithm() {
            return "RSA";
        }

        @Override
        public String getFormat() {
            return "JWK";
        }

        /**
         * Always fails: this key has no encoded representation.
         *
         * <p>NOTE(review): the {@code java.security.Key} contract allows returning {@code null} when
         * encoding is unsupported; this implementation throws instead.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        public byte[] getEncoded() {
            throw new UnsupportedOperationException();
        }

        @Override
        public String getKeyId() {
            return keyId;
        }

        @Override
        public BigInteger getModulus() {
            return modulus;
        }

        @Override
        public BigInteger getPublicExponent() {
            return exponent;
        }
    }

    /**
     * One entry of a JWK Set, as read from JSON.
     *
     * <p>The {@code kty} and {@code kid} members are bound through the constructor. Every other member is
     * collected by {@link #set} and can be read back with {@link #getStringProperty}.
     */
    public static class Key {
        private final String kty;
        private final Optional<String> kid;
        // Filled by Jackson through the @JsonAnySetter below during deserialization.
        private final Map<String, Object> other = new HashMap<>();

        /**
         * Creates a key entry.
         *
         * @param kty the JWK key type; must not be null
         * @param kid the JWK key id; must not be null. An absent {@code kid} is presumably mapped to
         *     {@code Optional.empty()} by the codec's Optional support (TODO confirm).
         * @throws NullPointerException if either argument is null
         */
        @JsonCreator
        public Key(@JsonProperty("kty") String kty, @JsonProperty("kid") Optional<String> kid) {
            this.kty = Objects.requireNonNull(kty, "kty is null");
            this.kid = Objects.requireNonNull(kid, "kid is null");
        }

        public String getKty() {
            return kty;
        }

        public Optional<String> getKid() {
            return kid;
        }

        /**
         * Returns a non-empty string member of this key.
         *
         * @param name the JSON member name, such as {@code "n"} or {@code "crv"}
         * @return the member's value; empty if the member is absent, is not a JSON string, or is empty
         */
        public Optional<String> getStringProperty(String name) {
            Object value = other.get(name);
            if (value instanceof String && !Strings.isNullOrEmpty((String) value)) {
                return Optional.of((String) value);
            }
            return Optional.empty();
        }

        /**
         * Records a JSON member that is not bound through the constructor.
         *
         * @param name the member name
         * @param value the deserialized member value
         */
        @JsonAnySetter
        public void set(String name, Object value) {
            other.put(name, value);
        }
    }

    /**
     * Top-level JWK Set document: a JSON object with a {@code "keys"} array.
     */
    public static class Keys {
        private final List<Key> keys;

        /**
         * Creates the document.
         *
         * @param keys the key entries; copied into an immutable list (Guava's copyOf rejects null
         *     elements)
         * @throws NullPointerException if {@code keys} is null
         */
        @JsonCreator
        public Keys(@JsonProperty("keys") List<Key> keys) {
            // Typed copy; the previous raw (Collection) cast produced an unchecked, raw ImmutableList.
            this.keys = ImmutableList.copyOf(Objects.requireNonNull(keys, "keys is null"));
        }

        /**
         * Returns the key entries.
         *
         * @return the key entries, as an immutable list
         */
        public List<Key> getKeys() {
            return keys;
        }
    }

    // Holder for static helpers only; the private constructor blocks instantiation from outside.
    private JwkDecoder() {}

    /**
     * Parses a JWK Set JSON document and returns the usable public keys, indexed by key id.
     *
     * <p>Entries that {@link #tryDecodeJwkKey} cannot decode are skipped: those without a {@code kid},
     * those with an unsupported {@code kty}, and those with missing or malformed key material.
     *
     * <p>NOTE(review): if two decodable entries share a {@code kid}, Guava's {@code toImmutableMap}
     * throws {@code IllegalArgumentException} and the whole document is rejected. RFC 7517 only says
     * {@code kid} values SHOULD be distinct.
     *
     * @param str the JWK Set JSON text
     * @return an immutable map from key id to public key
     */
    public static Map<String, PublicKey> decodeKeys(String str) {
        Keys keys = KEYS_CODEC.fromJson(str);
        // Typed pipeline; the previous raw (Map) cast disabled compile-time checking of the result type.
        return keys.getKeys().stream()
                .map(JwkDecoder::tryDecodeJwkKey)
                .filter(Optional::isPresent)
                .<JwkPublicKey>map(Optional::get)
                .collect(ImmutableMap.toImmutableMap(JwkPublicKey::getKeyId, Function.identity()));
    }

    /**
     * Decodes one JWK entry into a public key, if the entry is supported and well formed.
     *
     * @param key the parsed JWK entry
     * @return the decoded key. Empty if the entry has no non-empty {@code kid}, if its {@code kty} is
     *     neither {@code "RSA"} nor {@code "EC"}, or if its key material is missing or malformed.
     */
    public static Optional<? extends JwkPublicKey> tryDecodeJwkKey(Key key) {
        // A key without an id could never be selected later, so it is skipped.
        Optional<String> kid = key.getKid();
        if (kid.isEmpty() || kid.get().isEmpty()) {
            return Optional.empty();
        }
        String keyId = kid.get();

        // The Key constructor guarantees kty is non-null. Matching is exact (case-sensitive),
        // as RFC 7517 requires for "kty" values.
        switch (key.getKty()) {
            case "RSA":
                return tryDecodeRsaKey(keyId, key);
            case "EC":
                return tryDecodeEcKey(keyId, key);
            default:
                // Other key types (for example symmetric keys) are silently ignored.
                return Optional.empty();
        }
    }

    /**
     * Decodes the RSA members {@code n} (modulus) and {@code e} (exponent) of a JWK entry.
     *
     * <p>The modulus is decoded first. If it is missing or malformed, the exponent is not examined.
     * Malformed values are logged by {@link #decodeBigint}; missing ones are not.
     *
     * @param keyId the key id to attach to the result and to use in log messages
     * @param key the JWK entry
     * @return the decoded key, or empty if either member is missing or malformed
     */
    public static Optional<JwkRsaPublicKey> tryDecodeRsaKey(String keyId, Key key) {
        Optional<BigInteger> modulus = key.getStringProperty("n")
                .flatMap(encoded -> decodeBigint(keyId, "modulus", encoded));
        if (modulus.isEmpty()) {
            return Optional.empty();
        }

        Optional<BigInteger> exponent = key.getStringProperty("e")
                .flatMap(encoded -> decodeBigint(keyId, "exponent", encoded));
        if (exponent.isEmpty()) {
            return Optional.empty();
        }

        // JwkRsaPublicKey takes the exponent before the modulus.
        return Optional.of(new JwkRsaPublicKey(keyId, exponent.get(), modulus.get()));
    }

    /**
     * Decodes the EC members {@code crv}, {@code x} and {@code y} of a JWK entry.
     *
     * <p>The curve is resolved through {@code EcCurve.tryGet}. An unknown or missing curve is logged at
     * error level. Malformed coordinates are logged by {@link #decodeBigint}; missing ones are not. The
     * point is not checked to lie on the curve.
     *
     * @param keyId the key id to attach to the result and to use in log messages
     * @param key the JWK entry
     * @return the decoded key, or empty if the curve is unsupported or a coordinate is missing or malformed
     */
    public static Optional<JwkEcPublicKey> tryDecodeEcKey(String keyId, Key key) {
        Optional<String> curveName = key.getStringProperty("crv");
        Optional<ECParameterSpec> curve = curveName.flatMap(EcCurve::tryGet);
        if (curve.isEmpty()) {
            // Log the bare curve name. Passing the Optional printed "Optional[...]" / "Optional.empty".
            log.error("JWK EC %s curve '%s' is not supported", keyId, curveName.orElse(""));
            return Optional.empty();
        }

        Optional<BigInteger> x = key.getStringProperty("x")
                .flatMap(encoded -> decodeBigint(keyId, "x", encoded));
        if (x.isEmpty()) {
            return Optional.empty();
        }

        Optional<BigInteger> y = key.getStringProperty("y")
                .flatMap(encoded -> decodeBigint(keyId, "y", encoded));
        if (y.isEmpty()) {
            return Optional.empty();
        }

        return Optional.of(new JwkEcPublicKey(keyId, curve.get(), new ECPoint(x.get(), y.get())));
    }

    /**
     * Decodes a base64url-encoded JWK member as an unsigned, big-endian integer.
     *
     * <p>The decompiler reported that this method was originally private. It is kept public so that
     * existing callers keep working.
     *
     * @param keyId the key id, used only in the log message
     * @param name a human-readable member name (for example {@code "modulus"}), used only in the log message
     * @param encodedValue the base64url text to decode
     * @return the non-negative value, or empty (after logging the error) if the text is not valid base64url
     */
    public static Optional<BigInteger> decodeBigint(String keyId, String name, String encodedValue) {
        try {
            byte[] magnitude = Base64.getUrlDecoder().decode(encodedValue);
            // Signum 1 makes the bytes an unsigned magnitude, so the result is never negative.
            return Optional.of(new BigInteger(1, magnitude));
        } catch (IllegalArgumentException invalid) {
            log.error(invalid, "JWK %s %s is not a valid number", keyId, name);
            return Optional.empty();
        }
    }
}
