package com.datastax.oss.driver.internal.core.channel;

import com.datastax.oss.driver.Assertions;
import com.datastax.oss.driver.api.core.CqlIdentifier;
import com.datastax.oss.driver.api.core.DefaultProtocolVersion;
import com.datastax.oss.driver.api.core.InvalidKeyspaceException;
import com.datastax.oss.driver.api.core.auth.AuthProvider;
import com.datastax.oss.driver.api.core.auth.AuthenticationException;
import com.datastax.oss.driver.api.core.config.DefaultDriverOption;
import com.datastax.oss.driver.api.core.config.DriverConfig;
import com.datastax.oss.driver.api.core.config.DriverExecutionProfile;
import com.datastax.oss.driver.api.core.metadata.EndPoint;
import com.datastax.oss.driver.internal.core.CassandraProtocolVersionRegistry;
import com.datastax.oss.driver.internal.core.ProtocolVersionRegistry;
import com.datastax.oss.driver.internal.core.TestResponses;
import com.datastax.oss.driver.internal.core.context.InternalDriverContext;
import com.datastax.oss.driver.internal.core.metadata.TestNodeFactory;
import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableList;
import com.datastax.oss.protocol.internal.Frame;
import com.datastax.oss.protocol.internal.ProtocolConstants;
import com.datastax.oss.protocol.internal.request.AuthResponse;
import com.datastax.oss.protocol.internal.request.Query;
import com.datastax.oss.protocol.internal.request.Register;
import com.datastax.oss.protocol.internal.request.Startup;
import com.datastax.oss.protocol.internal.response.AuthChallenge;
import com.datastax.oss.protocol.internal.response.AuthSuccess;
import com.datastax.oss.protocol.internal.response.Authenticate;
import com.datastax.oss.protocol.internal.response.Error;
import com.datastax.oss.protocol.internal.response.Ready;
import com.datastax.oss.protocol.internal.response.result.SetKeyspace;
import com.datastax.oss.protocol.internal.util.Bytes;
import io.netty.channel.ChannelFuture;
import io.netty.util.concurrent.Future;
import java.net.InetSocketAddress;
import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;

/**
 * Unit tests for {@link ProtocolInitHandler}.
 *
 * <p>Each test installs the handler in the test channel's pipeline (behind an {@link
 * InFlightHandler}), calls {@code connect}, and then plays the server side of the handshake.
 * It reads each request the handler writes and feeds back a response, checking whether the
 * connect future stays pending, succeeds or fails.
 */
public class ProtocolInitHandlerTest extends ChannelHandlerTestBase {
    private static final long QUERY_TIMEOUT_MILLIS = 100;
    private static final EndPoint END_POINT = TestNodeFactory.newEndPoint(1);
    /** Remote address passed to every {@code channel.connect} call in these tests. */
    private static final InetSocketAddress CONNECT_ADDRESS = new InetSocketAddress("localhost", 9042);
    /** Authenticator class name the simulated server announces in its AUTHENTICATE response. */
    private static final String SERVER_AUTHENTICATOR = "mockServerAuthenticator";

    @Mock
    private InternalDriverContext internalDriverContext;

    @Mock
    private DriverConfig driverConfig;

    @Mock
    private DriverExecutionProfile defaultProfile;

    private final ProtocolVersionRegistry protocolVersionRegistry = new CassandraProtocolVersionRegistry("test");
    private HeartbeatHandler heartbeatHandler;

    @Override
    @Before
    public void setup() {
        super.setup();
        MockitoAnnotations.initMocks(this);
        Mockito.when(internalDriverContext.getConfig()).thenReturn(driverConfig);
        Mockito.when(driverConfig.getDefaultProfile()).thenReturn(defaultProfile);
        Mockito.when(defaultProfile.getDuration(DefaultDriverOption.CONNECTION_INIT_QUERY_TIMEOUT)).thenReturn(Duration.ofMillis(QUERY_TIMEOUT_MILLIS));
        Mockito.when(defaultProfile.getDuration(DefaultDriverOption.HEARTBEAT_INTERVAL)).thenReturn(Duration.ofSeconds(30L));
        Mockito.when(internalDriverContext.getProtocolVersionRegistry()).thenReturn(protocolVersionRegistry);
        channel.pipeline().addLast("inflight", new InFlightHandler(DefaultProtocolVersion.V4, new StreamIdGenerator(100), Integer.MAX_VALUE, QUERY_TIMEOUT_MILLIS, channel.newPromise(), (EventCallback) null, "test"));
        heartbeatHandler = new HeartbeatHandler(defaultProfile);
    }

    /**
     * Creates a protocol-V4 {@link ProtocolInitHandler} for {@link #END_POINT} and appends it to the
     * pipeline under the name {@code "init"}.
     *
     * @param expectedClusterName cluster name the handler should verify, or {@code null} for none
     * @param options channel options (keyspace, events) the handler should apply
     * @return the installed handler
     */
    private ProtocolInitHandler addInitHandler(String expectedClusterName, DriverChannelOptions options) {
        ProtocolInitHandler handler = new ProtocolInitHandler(internalDriverContext, DefaultProtocolVersion.V4, expectedClusterName, END_POINT, options, heartbeatHandler);
        channel.pipeline().addLast("init", handler);
        return handler;
    }

    /**
     * Connects the test channel to {@link #CONNECT_ADDRESS}.
     *
     * <p>The result is typed as a Netty {@link Future} so that {@code Assertions.assertThat} binds to
     * the Netty-future overload without a cast.
     */
    private Future<Void> connect() {
        return channel.connect(CONNECT_ADDRESS);
    }

    @Test
    public void should_initialize() {
        addInitHandler(null, DriverChannelOptions.DEFAULT);
        Future<Void> connectFuture = connect();

        Frame startupFrame = readOutboundFrame();
        Assertions.assertThat(startupFrame.message).isInstanceOf(Startup.class);
        Assertions.assertThat(connectFuture).isNotDone();
        writeInboundFrame(buildInboundFrame(startupFrame, new Ready()));

        Frame clusterNameFrame = readOutboundFrame();
        Assertions.assertThat(clusterNameFrame.message).isInstanceOf(Query.class);
        writeInboundFrame(clusterNameFrame, TestResponses.clusterNameResponse("someClusterName"));

        Assertions.assertThat(connectFuture).isSuccess();
    }

    @Test
    public void should_add_heartbeat_handler_to_pipeline_on_success() {
        ProtocolInitHandler initHandler = addInitHandler(null, DriverChannelOptions.DEFAULT);
        Future<Void> connectFuture = connect();
        // The heartbeat handler must not be installed before the handshake completes.
        Assertions.assertThat(channel.pipeline().get("heartbeat")).isNull();

        Frame startupFrame = readOutboundFrame();
        Assertions.assertThat(startupFrame.message).isInstanceOf(Startup.class);
        Assertions.assertThat(connectFuture).isNotDone();
        writeInboundFrame(buildInboundFrame(startupFrame, new Ready()));

        Frame clusterNameFrame = readOutboundFrame();
        Assertions.assertThat(clusterNameFrame.message).isInstanceOf(Query.class);
        writeInboundFrame(clusterNameFrame, TestResponses.clusterNameResponse("someClusterName"));

        Assertions.assertThat(connectFuture).isSuccess();
        Assertions.assertThat(channel.pipeline().get("heartbeat")).isEqualTo(heartbeatHandler);
        Assertions.assertThat(channel.pipeline().last()).isNotEqualTo(initHandler);
    }

    @Test
    public void should_fail_to_initialize_if_init_query_times_out() throws InterruptedException {
        addInitHandler(null, DriverChannelOptions.DEFAULT);
        Future<Void> connectFuture = connect();

        // Consume STARTUP but never answer it; wait past QUERY_TIMEOUT_MILLIS, then let the
        // channel run its scheduled tasks so the timeout can fire.
        readOutboundFrame();
        TimeUnit.MILLISECONDS.sleep(200L);
        channel.runPendingTasks();

        Assertions.assertThat(connectFuture).isFailed();
    }

    @Test
    public void should_initialize_with_authentication() {
        addInitHandler(null, DriverChannelOptions.DEFAULT);
        AuthProvider authProvider = Mockito.mock(AuthProvider.class);
        MockAuthenticator authenticator = new MockAuthenticator();
        Mockito.when(authProvider.newAuthenticator(END_POINT, SERVER_AUTHENTICATOR)).thenReturn(authenticator);
        Mockito.when(internalDriverContext.getAuthProvider()).thenReturn(Optional.of(authProvider));
        Future<Void> connectFuture = connect();

        Frame requestFrame = readOutboundFrame();
        Assertions.assertThat(requestFrame.message).isInstanceOf(Startup.class);
        Assertions.assertThat(connectFuture).isNotDone();

        writeInboundFrame(requestFrame, new Authenticate(SERVER_AUTHENTICATOR));
        Mockito.verify(authProvider).newAuthenticator(END_POINT, SERVER_AUTHENTICATOR);

        requestFrame = readOutboundFrame();
        Assertions.assertThat(requestFrame.message).isInstanceOf(AuthResponse.class);
        Assertions.assertThat(Bytes.toHexString(((AuthResponse) requestFrame.message).token)).isEqualTo("0xcafebabe");
        Assertions.assertThat(connectFuture).isNotDone();

        // Several challenge/response rounds; the connection must stay pending throughout.
        for (int round = 0; round < 5; round++) {
            writeInboundFrame(requestFrame, new AuthChallenge(Bytes.fromHexString("0xabcd")));
            requestFrame = readOutboundFrame();
            Assertions.assertThat(requestFrame.message).isInstanceOf(AuthResponse.class);
            Assertions.assertThat(Bytes.toHexString(((AuthResponse) requestFrame.message).token)).isEqualTo("0xabcd");
            Assertions.assertThat(connectFuture).isNotDone();
        }

        writeInboundFrame(requestFrame, new AuthSuccess(Bytes.fromHexString("0xabcd")));
        Assertions.assertThat(authenticator.successToken).isEqualTo("0xabcd");

        Frame clusterNameFrame = readOutboundFrame();
        Assertions.assertThat(clusterNameFrame.message).isInstanceOf(Query.class);
        writeInboundFrame(clusterNameFrame, TestResponses.clusterNameResponse("someClusterName"));
        Assertions.assertThat(connectFuture).isSuccess();
    }

    @Test
    public void should_invoke_auth_provider_when_server_does_not_send_challenge() {
        addInitHandler(null, DriverChannelOptions.DEFAULT);
        AuthProvider authProvider = Mockito.mock(AuthProvider.class);
        Mockito.when(internalDriverContext.getAuthProvider()).thenReturn(Optional.of(authProvider));
        Future<Void> connectFuture = connect();

        Frame startupFrame = readOutboundFrame();
        Assertions.assertThat(startupFrame.message).isInstanceOf(Startup.class);
        // Server answers READY directly instead of AUTHENTICATE.
        writeInboundFrame(buildInboundFrame(startupFrame, new Ready()));
        Mockito.verify(authProvider).onMissingChallenge(END_POINT);

        Frame clusterNameFrame = readOutboundFrame();
        Assertions.assertThat(clusterNameFrame.message).isInstanceOf(Query.class);
        writeInboundFrame(clusterNameFrame, TestResponses.clusterNameResponse("someClusterName"));
        Assertions.assertThat(connectFuture).isSuccess();
    }

    @Test
    public void should_fail_to_initialize_if_server_sends_auth_error() throws Throwable {
        addInitHandler(null, DriverChannelOptions.DEFAULT);
        AuthProvider authProvider = Mockito.mock(AuthProvider.class);
        Mockito.when(authProvider.newAuthenticator(END_POINT, SERVER_AUTHENTICATOR)).thenReturn(new MockAuthenticator());
        Mockito.when(internalDriverContext.getAuthProvider()).thenReturn(Optional.of(authProvider));
        Future<Void> connectFuture = connect();

        Frame startupFrame = readOutboundFrame();
        Assertions.assertThat(startupFrame.message).isInstanceOf(Startup.class);
        Assertions.assertThat(connectFuture).isNotDone();
        writeInboundFrame(startupFrame, new Authenticate(SERVER_AUTHENTICATOR));

        Frame authResponseFrame = readOutboundFrame();
        Assertions.assertThat(authResponseFrame.message).isInstanceOf(AuthResponse.class);
        Assertions.assertThat(connectFuture).isNotDone();
        writeInboundFrame(authResponseFrame, new Error(ProtocolConstants.ErrorCode.AUTH_ERROR, "mock error"));

        Assertions.assertThat(connectFuture).isFailed(error ->
            Assertions.assertThat(error)
                .isInstanceOf(AuthenticationException.class)
                .hasMessage(String.format("Authentication error on node %s: server replied 'mock error'", END_POINT)));
    }

    @Test
    public void should_check_cluster_name_if_provided() {
        addInitHandler("expectedClusterName", DriverChannelOptions.DEFAULT);
        Future<Void> connectFuture = connect();
        writeInboundFrame(readOutboundFrame(), new Ready());

        Frame clusterNameFrame = readOutboundFrame();
        Assertions.assertThat(clusterNameFrame.message).isInstanceOf(Query.class);
        Assertions.assertThat(((Query) clusterNameFrame.message).query).isEqualTo("SELECT cluster_name FROM system.local");
        Assertions.assertThat(connectFuture).isNotDone();

        writeInboundFrame(clusterNameFrame, TestResponses.clusterNameResponse("expectedClusterName"));
        Assertions.assertThat(connectFuture).isSuccess();
    }

    @Test
    public void should_fail_to_initialize_if_cluster_name_does_not_match() throws Throwable {
        addInitHandler("expectedClusterName", DriverChannelOptions.DEFAULT);
        Future<Void> connectFuture = connect();
        writeInboundFrame(readOutboundFrame(), new Ready());
        writeInboundFrame(readOutboundFrame(), TestResponses.clusterNameResponse("differentClusterName"));

        Assertions.assertThat(connectFuture).isFailed(error ->
            Assertions.assertThat(error)
                .isInstanceOf(ClusterNameMismatchException.class)
                .hasMessageContaining(String.format("Node %s reports cluster name 'differentClusterName' that doesn't match our cluster name 'expectedClusterName'.", END_POINT)));
    }

    @Test
    public void should_initialize_with_keyspace() {
        addInitHandler(null, DriverChannelOptions.builder().withKeyspace(CqlIdentifier.fromCql("ks")).build());
        Future<Void> connectFuture = connect();
        writeInboundFrame(readOutboundFrame(), new Ready());
        writeInboundFrame(readOutboundFrame(), TestResponses.clusterNameResponse("someClusterName"));

        Frame useFrame = readOutboundFrame();
        Assertions.assertThat(useFrame.message).isInstanceOf(Query.class);
        Assertions.assertThat(((Query) useFrame.message).query).isEqualTo("USE \"ks\"");
        writeInboundFrame(useFrame, new SetKeyspace("ks"));

        Assertions.assertThat(connectFuture).isSuccess();
    }

    @Test
    public void should_initialize_with_events() {
        addInitHandler(null, DriverChannelOptions.builder().withEvents(ImmutableList.of("foo", "bar"), Mockito.mock(EventCallback.class)).build());
        Future<Void> connectFuture = connect();
        writeInboundFrame(readOutboundFrame(), new Ready());
        writeInboundFrame(readOutboundFrame(), TestResponses.clusterNameResponse("someClusterName"));

        Frame registerFrame = readOutboundFrame();
        Assertions.assertThat(registerFrame.message).isInstanceOf(Register.class);
        Assertions.assertThat(((Register) registerFrame.message).eventTypes).containsExactly("foo", "bar");
        writeInboundFrame(registerFrame, new Ready());

        Assertions.assertThat(connectFuture).isSuccess();
    }

    @Test
    public void should_initialize_with_keyspace_and_events() {
        addInitHandler(null, DriverChannelOptions.builder().withKeyspace(CqlIdentifier.fromCql("ks")).withEvents(ImmutableList.of("foo", "bar"), Mockito.mock(EventCallback.class)).build());
        Future<Void> connectFuture = connect();
        writeInboundFrame(readOutboundFrame(), new Ready());
        writeInboundFrame(readOutboundFrame(), TestResponses.clusterNameResponse("someClusterName"));

        // USE keyspace is expected before REGISTER.
        Frame useFrame = readOutboundFrame();
        Assertions.assertThat(useFrame.message).isInstanceOf(Query.class);
        Assertions.assertThat(((Query) useFrame.message).query).isEqualTo("USE \"ks\"");
        writeInboundFrame(useFrame, new SetKeyspace("ks"));

        Frame registerFrame = readOutboundFrame();
        Assertions.assertThat(registerFrame.message).isInstanceOf(Register.class);
        Assertions.assertThat(((Register) registerFrame.message).eventTypes).containsExactly("foo", "bar");
        writeInboundFrame(registerFrame, new Ready());

        Assertions.assertThat(connectFuture).isSuccess();
    }

    @Test
    public void should_fail_to_initialize_if_keyspace_is_invalid() {
        addInitHandler(null, DriverChannelOptions.builder().withKeyspace(CqlIdentifier.fromCql("ks")).build());
        Future<Void> connectFuture = connect();
        writeInboundFrame(readOutboundFrame(), new Ready());
        writeInboundFrame(readOutboundFrame(), TestResponses.clusterNameResponse("someClusterName"));

        Frame useFrame = readOutboundFrame();
        Assertions.assertThat(useFrame.message).isInstanceOf(Query.class);
        Assertions.assertThat(((Query) useFrame.message).query).isEqualTo("USE \"ks\"");
        writeInboundFrame(useFrame, new Error(ProtocolConstants.ErrorCode.INVALID, "invalid keyspace"));

        Assertions.assertThat(connectFuture).isFailed(error ->
            Assertions.assertThat(error)
                .isInstanceOf(InvalidKeyspaceException.class)
                .hasMessage("invalid keyspace"));
    }
}
