package com.linkedin.alpini.netty4.handlers;

import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.LongAdder;
import java.util.function.Function;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Netty handler that tracks, per remote client {@link InetAddress}, how many connections are open
 * and how many HTTP requests are in flight, and logs a warning when a single client exceeds the
 * configured connection limit.
 *
 * <p>The handler is {@link ChannelHandler.Sharable}, so one instance may be installed in many
 * pipelines. Per-client state is therefore kept in a {@link ConcurrentHashMap}: an entry is created
 * when a client's first connection becomes active and removed when its last tracked connection
 * becomes inactive.
 */
@ChannelHandler.Sharable
public class ClientConnectionTracker extends ChannelDuplexHandler {
    private static final Logger LOG = LogManager.getLogger(ClientConnectionTracker.class);

    /** Connections per client above which {@link #whenOverLimit} is invoked. */
    private final int _connectionCountLimit;

    /**
     * Once a reported client's connection count drops below this value, its "reported" flag is
     * cleared so that a later surge is warned about again.
     */
    private final int _connectionWarningResetThreshold;

    // volatile: the setter may be called from a different thread than the event loops reading it.
    private volatile boolean _shouldLogMetricForSiRootCausingTheirOwnClientConnectionSurge;

    private final Map<InetAddress, ConnectionStats> _statsMap = new ConcurrentHashMap<>();

    /**
     * Counters for a single client host: open connections, in-flight requests, and whether an
     * over-limit warning has already been logged for the current surge.
     *
     * <p>Counters are {@link LongAdder}s; the "...AndGet" methods return a snapshot sum that is not
     * atomic with the preceding update.
     */
    public static class ConnectionStats {
        private final String _clientHostName;
        private final LongAdder _connectionCount = new LongAdder();
        private final LongAdder _activeRequestCount = new LongAdder();
        // volatile: one stats object is shared by all connections from the same client, which may
        // be served by different threads.
        private volatile boolean _reported = false;

        /**
         * @param inetAddress the client address; its {@code toString()} is used as the host label
         */
        public ConnectionStats(InetAddress inetAddress) {
            this._clientHostName = inetAddress.toString();
        }

        /** Increments the open-connection count and returns {@code this}. */
        public ConnectionStats increment() {
            _connectionCount.increment();
            return this;
        }

        /** @return the current number of in-flight requests */
        public int activeRequestCount() {
            return _activeRequestCount.intValue();
        }

        /** Increments the in-flight request count and returns a snapshot of the new value. */
        public int increaseActiveRequestAndGet() {
            _activeRequestCount.increment();
            return _activeRequestCount.intValue();
        }

        /** Decrements the in-flight request count and returns a snapshot of the new value. */
        public int decreaseActiveRequestAndGet() {
            _activeRequestCount.decrement();
            return _activeRequestCount.intValue();
        }

        /** Decrements the open-connection count and returns a snapshot of the new value. */
        public int decrementAndGet() {
            _connectionCount.decrement();
            return _connectionCount.intValue();
        }

        /** @return the current number of open connections */
        public int connectionCount() {
            return _connectionCount.intValue();
        }

        /** Marks that an over-limit warning has been logged; returns {@code this}. */
        public ConnectionStats reported() {
            _reported = true;
            return this;
        }

        /** Clears the "warning logged" flag; returns {@code this}. */
        public ConnectionStats resetReported() {
            _reported = false;
            return this;
        }

        /** @return whether an over-limit warning has been logged for the current surge */
        public boolean isReported() {
            return _reported;
        }

        @Override
        public String toString() {
            int connectionCount = connectionCount();
            int activeRequestCount = activeRequestCount();
            return "[host: " + _clientHostName + ", connections: " + connectionCount + ", inflight-requests: "
                + activeRequestCount + ", available connections: " + (connectionCount - activeRequestCount) + "]";
        }
    }

    /**
     * @param connectionCountLimit connections per client above which a warning is logged
     * @param connectionWarningResetThreshold count below which a reported client's warning state
     *     is reset; must not exceed {@code connectionCountLimit}
     * @throws IllegalArgumentException if the threshold is higher than the limit
     */
    public ClientConnectionTracker(int connectionCountLimit, int connectionWarningResetThreshold) {
        if (connectionWarningResetThreshold > connectionCountLimit) {
            throw new IllegalArgumentException(String.format(
                "connectionWarningResetThreshold(%d) must not be higher than connectionCountLimit(%d).",
                connectionWarningResetThreshold,
                connectionCountLimit));
        }
        this._connectionCountLimit = connectionCountLimit;
        this._connectionWarningResetThreshold = connectionWarningResetThreshold;
    }

    /**
     * Enables or disables an info log line each time a tracked client connection becomes active.
     *
     * @return {@code this}, for chaining
     */
    public ClientConnectionTracker setShouldLogMetricForSiRootCausingTheirOwnClientConnectionSurge(
        boolean shouldLog) {
        this._shouldLogMetricForSiRootCausingTheirOwnClientConnectionSurge = shouldLog;
        return this;
    }

    /** Exposes the live per-client map (package-private). */
    Map<InetAddress, ConnectionStats> statsMap() {
        return _statsMap;
    }

    /**
     * @return the stats for the channel's remote host, or {@code null} if the remote address is not
     *     an {@link InetSocketAddress} or the host is not currently tracked
     */
    ConnectionStats getStatsByContext(ChannelHandlerContext channelHandlerContext) {
        return getHostInetAddressFromContext(channelHandlerContext).map(_statsMap::get).orElse(null);
    }

    @Override
    public void channelActive(ChannelHandlerContext channelHandlerContext) throws Exception {
        checkConnectionLimit(channelHandlerContext).ifPresent(connectionStats -> {
            if (_shouldLogMetricForSiRootCausingTheirOwnClientConnectionSurge) {
                LOG.info("channelActive for client: {}", connectionStats);
            }
        });
        super.channelActive(channelHandlerContext);
    }

    @Override
    public void channelRead(ChannelHandlerContext channelHandlerContext, Object obj) throws Exception {
        if (obj instanceof HttpRequest) {
            onRequest(channelHandlerContext);
        }
        super.channelRead(channelHandlerContext, obj);
    }

    @Override
    public void write(ChannelHandlerContext channelHandlerContext, Object obj, ChannelPromise channelPromise)
        throws Exception {
        // NOTE(review): the in-flight count is decremented when the response *head* is written, not
        // when the last content is written; informational (1xx) responses would also decrement.
        if (obj instanceof HttpResponse) {
            onResponse(channelHandlerContext);
        }
        super.write(channelHandlerContext, obj, channelPromise);
    }

    @Override
    public void channelInactive(ChannelHandlerContext channelHandlerContext) throws Exception {
        onInActive(channelHandlerContext);
        super.channelInactive(channelHandlerContext);
    }

    /**
     * Returns the remote host address of the channel. Empty if there is no channel, the remote
     * address is not an {@link InetSocketAddress}, or the address is unresolved.
     */
    private static Optional<InetAddress> getHostInetAddressFromContext(ChannelHandlerContext channelHandlerContext) {
        // Optional.map turns a null getAddress() (unresolved address) into an empty Optional.
        return getHostInetSocketAddressFromContext(channelHandlerContext).map(InetSocketAddress::getAddress);
    }

    /** Returns the channel's remote address if it is an {@link InetSocketAddress}. */
    private static Optional<InetSocketAddress> getHostInetSocketAddressFromContext(
        ChannelHandlerContext channelHandlerContext) {
        return Optional.ofNullable(channelHandlerContext.channel())
            .map(channel -> channel.remoteAddress())
            .filter(InetSocketAddress.class::isInstance)
            .map(InetSocketAddress.class::cast);
    }

    /**
     * Registers a new connection for the channel's remote host and invokes {@link #whenOverLimit}
     * if the host's connection count now exceeds the limit.
     *
     * @return the host's updated stats, or empty if the remote host could not be determined
     */
    protected Optional<ConnectionStats> checkConnectionLimit(ChannelHandlerContext channelHandlerContext) {
        return getHostInetAddressFromContext(channelHandlerContext).map(address -> {
            // Create-or-increment inside a single compute() so a concurrent onInActive() cannot
            // remove the entry between the lookup and the increment (which would lose the count).
            ConnectionStats stats = _statsMap.compute(
                address,
                (key, existing) -> (existing != null ? existing : new ConnectionStats(key)).increment());
            // Called outside compute() so overridable code never runs while holding the map's bin lock.
            if (stats.connectionCount() > _connectionCountLimit) {
                whenOverLimit(stats, channelHandlerContext);
            }
            return stats;
        });
    }

    /**
     * Hook invoked when a host exceeds the connection limit. The default implementation logs one
     * warning per surge; the flag is reset in {@link #onInActive} once the count drops below the
     * reset threshold.
     */
    protected void whenOverLimit(ConnectionStats connectionStats, ChannelHandlerContext channelHandlerContext) {
        if (connectionStats.isReported()) {
            return;
        }
        LOG.warn(
            "Connection cap for {} is {}, exceeding the limit {}.",
            connectionStats._clientHostName,
            connectionStats.connectionCount(),
            _connectionCountLimit);
        connectionStats.reported();
    }

    /** Increments the in-flight request count for the channel's host; returns it, or 0 if untracked. */
    private int onRequest(ChannelHandlerContext channelHandlerContext) {
        return modifyRequestCount(channelHandlerContext, ConnectionStats::increaseActiveRequestAndGet);
    }

    /** Decrements the in-flight request count for the channel's host; returns it, or 0 if untracked. */
    private int onResponse(ChannelHandlerContext channelHandlerContext) {
        return modifyRequestCount(channelHandlerContext, ConnectionStats::decreaseActiveRequestAndGet);
    }

    /**
     * Applies {@code function} to the stats of the channel's remote host.
     *
     * @return the function's result, or 0 if the host is unknown/untracked or the function returns null
     */
    private int modifyRequestCount(
        ChannelHandlerContext channelHandlerContext,
        Function<ConnectionStats, Integer> function) {
        return getHostInetAddressFromContext(channelHandlerContext)
            .map(_statsMap::get)
            .map(function)
            .orElse(0);
    }

    /**
     * Unregisters a connection for the channel's remote host. Resets the "reported" flag once the
     * count falls below the reset threshold, and removes the host's entry when no connections remain.
     */
    private void onInActive(ChannelHandlerContext channelHandlerContext) {
        getHostInetAddressFromContext(channelHandlerContext).ifPresent(
            address -> _statsMap.computeIfPresent(address, (key, stats) -> {
                int remaining = stats.decrementAndGet();
                if (remaining < _connectionWarningResetThreshold && stats.isReported()) {
                    stats.resetReported();
                }
                // Returning null from computeIfPresent removes the mapping.
                return remaining <= 0 ? null : stats;
            }));
    }
}
