package com.linkedin.venice.router.api.path;

import com.linkedin.alpini.netty4.misc.BasicFullHttpRequest;
import com.linkedin.alpini.router.api.RouterException;
import com.linkedin.venice.compute.ComputeRequestWrapper;
import com.linkedin.venice.compute.protocol.request.ComputeOperation;
import com.linkedin.venice.compute.protocol.request.ComputeRequestV1;
import com.linkedin.venice.compute.protocol.request.ComputeRequestV2;
import com.linkedin.venice.compute.protocol.request.CosineSimilarity;
import com.linkedin.venice.compute.protocol.request.DotProduct;
import com.linkedin.venice.compute.protocol.request.enums.ComputeOperationType;
import com.linkedin.venice.partitioner.VenicePartitioner;
import com.linkedin.venice.router.api.VenicePartitionFinder;
import com.linkedin.venice.schema.avro.ReadAvroProtocolDefinition;
import com.linkedin.venice.serializer.SerializerDeserializerFactory;
import com.linkedin.venice.utils.Utils;
import io.netty.buffer.Unpooled;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpVersion;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import org.apache.avro.Schema;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;
import org.testng.Assert;
import org.testng.annotations.Test;

/* loaded from: input_file:com/linkedin/venice/router/api/path/TestVeniceComputePath.class */
public class TestVeniceComputePath {
    private static String resultSchemaStr = "{  \"namespace\": \"example.compute\",      \"type\": \"record\",          \"name\": \"MemberFeature\",         \"fields\": [                 { \"name\": \"id\", \"type\": \"string\" },                { \"name\": \"member_score\", \"type\": \"double\" }          ]        }       ";

    private ComputeRequestV1 getComputeRequest() {
        DotProduct dotProduct = new DotProduct();
        dotProduct.field = "member_feature";
        ArrayList arrayList = new ArrayList(3);
        arrayList.add(Float.valueOf(0.4f));
        arrayList.add(Float.valueOf(66.6f));
        arrayList.add(Float.valueOf(5.2f));
        dotProduct.dotProductParam = arrayList;
        dotProduct.resultFieldName = "member_score";
        ComputeOperation computeOperation = new ComputeOperation();
        computeOperation.operationType = 0;
        computeOperation.operation = dotProduct;
        ArrayList arrayList2 = new ArrayList();
        arrayList2.add(computeOperation);
        ComputeRequestV1 computeRequestV1 = new ComputeRequestV1();
        computeRequestV1.operations = arrayList2;
        computeRequestV1.resultSchemaStr = resultSchemaStr;
        return computeRequestV1;
    }

    private BasicFullHttpRequest getComputeHttpRequest(String str, byte[] bArr, int i) {
        BasicFullHttpRequest basicFullHttpRequest = new BasicFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/compute/" + str, Unpooled.wrappedBuffer(bArr), 0L, 0L);
        basicFullHttpRequest.headers().add("X-VENICE-API-VERSION", Integer.valueOf(i));
        return basicFullHttpRequest;
    }

    private VenicePartitionFinder getVenicePartitionFinder(int i) {
        VenicePartitionFinder venicePartitionFinder = (VenicePartitionFinder) Mockito.mock(VenicePartitionFinder.class);
        VenicePartitioner venicePartitioner = (VenicePartitioner) Mockito.mock(VenicePartitioner.class);
        Mockito.when(Integer.valueOf(venicePartitioner.getPartitionId((ByteBuffer) Mockito.any(ByteBuffer.class), Mockito.anyInt()))).thenReturn(Integer.valueOf(i));
        Mockito.when(venicePartitionFinder.findPartitioner(ArgumentMatchers.anyString(), Mockito.anyInt())).thenReturn(venicePartitioner);
        return venicePartitionFinder;
    }

    @Test
    public void testDeserializationCorrectness() throws RouterException {
        String uniqueString = Utils.getUniqueString("test_store");
        String str = uniqueString + "_v1";
        ArrayList arrayList = new ArrayList();
        for (int i = 0; i < 5; i++) {
            arrayList.add(ByteBuffer.wrap(("key_" + i).getBytes()));
        }
        ComputeRequestV1 computeRequest = getComputeRequest();
        byte[] serialize = SerializerDeserializerFactory.getAvroGenericSerializer(ComputeRequestV1.getClassSchema()).serialize(computeRequest);
        int length = serialize.length;
        byte[] serializeObjects = SerializerDeserializerFactory.getAvroGenericSerializer(ReadAvroProtocolDefinition.COMPUTE_REQUEST_CLIENT_KEY_V1.getSchema()).serializeObjects(arrayList);
        ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
        try {
            byteArrayOutputStream.write(serialize);
            byteArrayOutputStream.write(serializeObjects);
        } catch (Exception e) {
            e.printStackTrace();
            Assert.fail("Failed to write bytes to output stream", e);
        }
        for (int i2 = 1; i2 <= 3; i2++) {
            VeniceComputePath veniceComputePath = new VeniceComputePath(uniqueString, 1, str, getComputeHttpRequest(str, byteArrayOutputStream.toByteArray(), i2), getVenicePartitionFinder(-1), 10, false, -1, false, 1);
            Assert.assertEquals(veniceComputePath.getComputeRequestLengthInBytes(), length);
            ComputeRequestWrapper computeRequest2 = veniceComputePath.getComputeRequest();
            Assert.assertTrue(Schema.parse(computeRequest2.getResultSchemaStr().toString()).equals(Schema.parse(computeRequest.resultSchemaStr.toString())));
            Assert.assertEquals(computeRequest2.getOperations(), computeRequest.operations);
        }
    }

    @Test
    public void testComputeRequestVersionBackwardCompatible() {
        ComputeRequestV1 computeRequestV1 = new ComputeRequestV1();
        computeRequestV1.resultSchemaStr = "test_result_schema";
        computeRequestV1.operations = new LinkedList();
        LinkedList linkedList = new LinkedList();
        linkedList.add(Float.valueOf(0.1f));
        ComputeOperation computeOperation = new ComputeOperation();
        computeOperation.operationType = ComputeOperationType.DOT_PRODUCT.getValue();
        DotProduct dotProduct = (DotProduct) ComputeOperationType.DOT_PRODUCT.getNewInstance();
        dotProduct.field = "dotProductField";
        dotProduct.resultFieldName = "dotProductResultField";
        dotProduct.dotProductParam = linkedList;
        computeOperation.operation = dotProduct;
        computeRequestV1.operations.add(computeOperation);
        LinkedList linkedList2 = new LinkedList();
        linkedList2.add(Float.valueOf(0.2f));
        ComputeOperation computeOperation2 = new ComputeOperation();
        computeOperation2.operationType = ComputeOperationType.COSINE_SIMILARITY.getValue();
        CosineSimilarity cosineSimilarity = (CosineSimilarity) ComputeOperationType.COSINE_SIMILARITY.getNewInstance();
        cosineSimilarity.field = "cosineSimilarityField";
        cosineSimilarity.resultFieldName = "cosineSimilarityResultField";
        cosineSimilarity.cosSimilarityParam = linkedList2;
        computeOperation2.operation = cosineSimilarity;
        computeRequestV1.operations.add(computeOperation2);
        ComputeRequestV2 computeRequestV2 = (ComputeRequestV2) SerializerDeserializerFactory.getAvroSpecificDeserializer(ReadAvroProtocolDefinition.COMPUTE_REQUEST_V1.getSchema(), ComputeRequestV2.class).deserialize(SerializerDeserializerFactory.getAvroGenericSerializer(ReadAvroProtocolDefinition.COMPUTE_REQUEST_V1.getSchema()).serialize(computeRequestV1));
        Assert.assertEquals(computeRequestV2.resultSchemaStr.toString(), "test_result_schema");
        ComputeOperation computeOperation3 = (ComputeOperation) computeRequestV2.operations.get(0);
        Assert.assertEquals(computeOperation3.operationType, ComputeOperationType.DOT_PRODUCT.getValue());
        Assert.assertEquals(computeOperation3.operation, dotProduct);
        ComputeOperation computeOperation4 = (ComputeOperation) computeRequestV2.operations.get(1);
        Assert.assertEquals(computeOperation4.operationType, ComputeOperationType.COSINE_SIMILARITY.getValue());
        Assert.assertEquals(computeOperation4.operation, cosineSimilarity);
    }
}
