package org.apache.spark.network.client;

import com.facebook.presto.spark.$internal.io.netty.channel.Channel;
import com.facebook.presto.spark.$internal.org.slf4j.Logger;
import com.facebook.presto.spark.$internal.org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.Iterator;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.spark.network.protocol.ChunkFetchFailure;
import org.apache.spark.network.protocol.ChunkFetchSuccess;
import org.apache.spark.network.protocol.ResponseMessage;
import org.apache.spark.network.protocol.RpcFailure;
import org.apache.spark.network.protocol.RpcResponse;
import org.apache.spark.network.protocol.StreamChunkId;
import org.apache.spark.network.protocol.StreamFailure;
import org.apache.spark.network.protocol.StreamResponse;
import org.apache.spark.network.server.MessageHandler;
import org.apache.spark.network.util.NettyUtils;
import org.apache.spark.network.util.TransportFrameDecoder;
import org.spark_project.guava.annotations.VisibleForTesting;

/* loaded from: input_file:org/apache/spark/network/client/TransportResponseHandler.class */
/**
 * Client-side handler that routes {@link ResponseMessage}s received on a channel to the callbacks
 * registered for the matching outstanding requests.
 *
 * <p>Chunk fetches are keyed by {@link StreamChunkId}, RPCs by request id, and stream requests are
 * matched to their callbacks in FIFO order. The bookkeeping collections are concurrent, so requests
 * may be registered/removed from threads other than the one delivering responses
 * (NOTE(review): threading model inferred from the concurrent collections; confirm with callers).
 */
public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
    private final Logger logger = LoggerFactory.getLogger(TransportResponseHandler.class);

    private final Channel channel;

    /** Outstanding chunk fetches, keyed by the chunk being fetched. */
    private final Map<StreamChunkId, ChunkReceivedCallback> outstandingFetches = new ConcurrentHashMap<>();

    /** Outstanding RPCs, keyed by request id. */
    private final Map<Long, RpcResponseCallback> outstandingRpcs = new ConcurrentHashMap<>();

    /** Callbacks for stream requests; stream responses are matched to them in FIFO order. */
    private final Queue<StreamCallback> streamCallbacks = new ConcurrentLinkedQueue<>();

    /** Set once a stream interceptor has been installed; cleared by {@link #deactivateStream()}. */
    private volatile boolean streamActive;

    /** {@link System#nanoTime()} recorded when the most recent request was registered. */
    private final AtomicLong timeOfLastRequestNs = new AtomicLong(0);

    /**
     * @param channel the channel whose responses this handler processes; used for the remote
     *     address in log messages and to install stream interceptors on its pipeline
     */
    public TransportResponseHandler(Channel channel) {
        this.channel = channel;
    }

    /** Registers the callback for a pending chunk fetch and records the request time. */
    public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback chunkReceivedCallback) {
        updateTimeOfLastRequest();
        outstandingFetches.put(streamChunkId, chunkReceivedCallback);
    }

    /** Forgets a pending chunk fetch without notifying its callback. */
    public void removeFetchRequest(StreamChunkId streamChunkId) {
        outstandingFetches.remove(streamChunkId);
    }

    /** Registers the callback for a pending RPC and records the request time. */
    public void addRpcRequest(long requestId, RpcResponseCallback rpcResponseCallback) {
        updateTimeOfLastRequest();
        outstandingRpcs.put(requestId, rpcResponseCallback);
    }

    /** Forgets a pending RPC without notifying its callback. */
    public void removeRpcRequest(long requestId) {
        outstandingRpcs.remove(requestId);
    }

    /** Enqueues the callback for a pending stream request and records the request time. */
    public void addStreamCallback(StreamCallback streamCallback) {
        updateTimeOfLastRequest();
        streamCallbacks.offer(streamCallback);
    }

    /**
     * Marks that no stream is currently active.
     * NOTE(review): presumably invoked by the installed StreamInterceptor when a stream ends;
     * confirm against StreamInterceptor.
     */
    @VisibleForTesting
    public void deactivateStream() {
        streamActive = false;
    }

    /**
     * Notifies every outstanding chunk-fetch and RPC callback of {@code cause}, then clears both
     * maps. A callback that throws is logged and skipped so the remaining callbacks are still
     * notified and the maps are still cleared.
     *
     * <p>NOTE(review): pending stream callbacks are not failed here because the queue does not
     * retain the stream ids that {@code StreamCallback.onFailure} requires.
     */
    private void failOutstandingRequests(Throwable cause) {
        for (Map.Entry<StreamChunkId, ChunkReceivedCallback> entry : outstandingFetches.entrySet()) {
            try {
                entry.getValue().onFailure(entry.getKey().chunkIndex, cause);
            } catch (Exception e) {
                logger.warn("ChunkReceivedCallback.onFailure threw an exception", e);
            }
        }
        for (Map.Entry<Long, RpcResponseCallback> entry : outstandingRpcs.entrySet()) {
            try {
                entry.getValue().onFailure(cause);
            } catch (Exception e) {
                logger.warn("RpcResponseCallback.onFailure threw an exception", e);
            }
        }
        // Requests registered after iteration began are dropped here as well; the connection is
        // already unusable at this point.
        outstandingFetches.clear();
        outstandingRpcs.clear();
    }

    @Override // org.apache.spark.network.server.MessageHandler
    public void channelActive() {
    }

    /** Fails all outstanding requests with an {@link IOException} when the connection closes. */
    @Override // org.apache.spark.network.server.MessageHandler
    public void channelInactive() {
        int outstanding = numOutstandingRequests();
        if (outstanding > 0) {
            String remoteAddress = NettyUtils.getRemoteAddress(channel);
            logger.error("Still have {} requests outstanding when connection from {} is closed",
                outstanding, remoteAddress);
            failOutstandingRequests(new IOException("Connection from " + remoteAddress + " closed"));
        }
    }

    /** Fails all outstanding requests with the exception raised on the channel. */
    @Override // org.apache.spark.network.server.MessageHandler
    public void exceptionCaught(Throwable th) {
        int outstanding = numOutstandingRequests();
        if (outstanding > 0) {
            logger.error("Still have {} requests outstanding when connection from {} is closed",
                outstanding, NettyUtils.getRemoteAddress(channel));
            failOutstandingRequests(th);
        }
    }

    /**
     * Dispatches a response to the handler for its concrete type.
     *
     * @throws IllegalStateException if the message type is not one this handler understands
     */
    @Override // org.apache.spark.network.server.MessageHandler
    public void handle(ResponseMessage responseMessage) throws Exception {
        if (responseMessage instanceof ChunkFetchSuccess) {
            handleChunkFetchSuccess((ChunkFetchSuccess) responseMessage);
        } else if (responseMessage instanceof ChunkFetchFailure) {
            handleChunkFetchFailure((ChunkFetchFailure) responseMessage);
        } else if (responseMessage instanceof RpcResponse) {
            handleRpcResponse((RpcResponse) responseMessage);
        } else if (responseMessage instanceof RpcFailure) {
            handleRpcFailure((RpcFailure) responseMessage);
        } else if (responseMessage instanceof StreamResponse) {
            handleStreamResponse((StreamResponse) responseMessage);
        } else if (responseMessage instanceof StreamFailure) {
            handleStreamFailure((StreamFailure) responseMessage);
        } else {
            throw new IllegalStateException("Unknown response type: " + responseMessage.type());
        }
    }

    /** Delivers a fetched chunk to its callback; the body is released on every path. */
    private void handleChunkFetchSuccess(ChunkFetchSuccess success) throws Exception {
        // remove() both looks up and claims the callback atomically, so it fires at most once.
        ChunkReceivedCallback callback = outstandingFetches.remove(success.streamChunkId);
        try {
            if (callback == null) {
                logger.warn("Ignoring response for block {} from {} since it is not outstanding",
                    success.streamChunkId, NettyUtils.getRemoteAddress(channel));
            } else {
                callback.onSuccess(success.streamChunkId.chunkIndex, success.body());
            }
        } finally {
            success.body().release();
        }
    }

    /** Reports a failed chunk fetch to its callback, if it is still outstanding. */
    private void handleChunkFetchFailure(ChunkFetchFailure failure) {
        ChunkReceivedCallback callback = outstandingFetches.remove(failure.streamChunkId);
        if (callback == null) {
            logger.warn("Ignoring response for block {} from {} ({}) since it is not outstanding",
                failure.streamChunkId, NettyUtils.getRemoteAddress(channel), failure.errorString);
        } else {
            callback.onFailure(failure.streamChunkId.chunkIndex, new ChunkFetchFailureException(
                "Failure while fetching " + failure.streamChunkId + ": " + failure.errorString));
        }
    }

    /** Delivers an RPC reply to its callback; the body is released on every path. */
    private void handleRpcResponse(RpcResponse response) throws Exception {
        RpcResponseCallback callback = outstandingRpcs.remove(response.requestId);
        try {
            if (callback == null) {
                logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding",
                    response.requestId, NettyUtils.getRemoteAddress(channel), response.body().size());
            } else {
                callback.onSuccess(response.body().nioByteBuffer());
            }
        } finally {
            // Previously the body leaked when no callback was outstanding.
            response.body().release();
        }
    }

    /** Reports a failed RPC to its callback, if it is still outstanding. */
    private void handleRpcFailure(RpcFailure failure) {
        RpcResponseCallback callback = outstandingRpcs.remove(failure.requestId);
        if (callback == null) {
            logger.warn("Ignoring response for RPC {} from {} ({}) since it is not outstanding",
                failure.requestId, NettyUtils.getRemoteAddress(channel), failure.errorString);
        } else {
            callback.onFailure(new RuntimeException(failure.errorString));
        }
    }

    /**
     * Matches a stream response to the oldest pending stream callback. Empty streams complete
     * immediately; otherwise a StreamInterceptor is installed on the channel's frame decoder so
     * the following bytes are routed to the callback.
     */
    private void handleStreamResponse(StreamResponse response) {
        StreamCallback callback = streamCallbacks.poll();
        if (callback == null) {
            logger.error("Could not find callback for StreamResponse.");
            return;
        }
        if (response.byteCount <= 0) {
            try {
                callback.onComplete(response.streamId);
            } catch (Exception e) {
                logger.warn("Error in stream handler onComplete().", e);
            }
            return;
        }
        try {
            TransportFrameDecoder frameDecoder =
                (TransportFrameDecoder) channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME);
            frameDecoder.setInterceptor(
                new StreamInterceptor(this, response.streamId, response.byteCount, callback));
            streamActive = true;
        } catch (Exception e) {
            logger.error("Error installing stream handler.", e);
            deactivateStream();
        }
    }

    /** Reports a stream failure to the oldest pending stream callback. */
    private void handleStreamFailure(StreamFailure failure) {
        StreamCallback callback = streamCallbacks.poll();
        if (callback == null) {
            logger.warn("Stream failure with unknown callback: {}", failure.error);
            return;
        }
        try {
            callback.onFailure(failure.streamId, new RuntimeException(failure.error));
        } catch (IOException e) {
            logger.warn("Error in stream failure handler.", e);
        }
    }

    /**
     * Returns the number of requests still awaiting a response: pending chunk fetches, pending
     * RPCs, queued stream callbacks, plus one if a stream is currently active.
     */
    public int numOutstandingRequests() {
        return outstandingFetches.size() + outstandingRpcs.size() + streamCallbacks.size()
            + (streamActive ? 1 : 0);
    }

    /** Returns the {@link System#nanoTime()} value recorded for the most recent request. */
    public long getTimeOfLastRequestNs() {
        return timeOfLastRequestNs.get();
    }

    /** Records {@link System#nanoTime()} as the time of the most recent request. */
    public void updateTimeOfLastRequest() {
        timeOfLastRequestNs.set(System.nanoTime());
    }
}
