package com.linkedin.venice.storagenode;

import com.github.luben.zstd.ZstdDictTrainer;
import com.linkedin.venice.client.exceptions.VeniceClientException;
import com.linkedin.venice.client.store.AvroGenericStoreClient;
import com.linkedin.venice.client.store.ClientConfig;
import com.linkedin.venice.client.store.ClientFactory;
import com.linkedin.venice.compression.CompressionStrategy;
import com.linkedin.venice.compression.CompressorFactory;
import com.linkedin.venice.compression.VeniceCompressor;
import com.linkedin.venice.controllerapi.ControllerClient;
import com.linkedin.venice.controllerapi.ControllerResponse;
import com.linkedin.venice.controllerapi.UpdateStoreQueryParams;
import com.linkedin.venice.controllerapi.VersionCreationResponse;
import com.linkedin.venice.exceptions.VeniceException;
import com.linkedin.venice.integration.utils.ServiceFactory;
import com.linkedin.venice.integration.utils.VeniceClusterWrapper;
import com.linkedin.venice.meta.Version;
import com.linkedin.venice.pushmonitor.ExecutionStatus;
import com.linkedin.venice.serialization.DefaultSerializer;
import com.linkedin.venice.serialization.VeniceKafkaSerializer;
import com.linkedin.venice.serialization.avro.VeniceAvroKafkaSerializer;
import com.linkedin.venice.tehuti.MetricsUtils;
import com.linkedin.venice.utils.TestUtils;
import com.linkedin.venice.utils.Utils;
import com.linkedin.venice.writer.VeniceWriter;
import com.linkedin.venice.writer.VeniceWriterOptions;
import java.io.Closeable;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import org.apache.avro.Schema;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.util.Utf8;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

@Test(singleThreaded = true)
public class StorageNodeComputeTest {
    private static final Logger LOGGER = LogManager.getLogger(StorageNodeComputeTest.class);

    /** Avro schema of the store key: a plain string. */
    private static final String KEY_SCHEMA = "\"string\"";
    private static final String VALUE_SCHEMA_FOR_COMPUTE = "{  \"namespace\": \"example.compute\",      \"type\": \"record\",          \"name\": \"MemberFeature\",         \"fields\": [                 { \"name\": \"id\", \"type\": \"string\" },                      { \"name\": \"name\", \"type\": \"string\" },                    { \"name\": \"namemap\", \"type\":  {\"type\" : \"map\", \"values\" : \"int\" }},                    { \"name\": \"member_feature\", \"type\": { \"type\": \"array\", \"items\": \"float\" } }          ]        }       ";

    /** Prefix of every key pushed by {@link #pushSyntheticDataForCompute}; the suffix is the record index. */
    private static final String KEY_PREFIX = "key_";
    /** Prefix of the "id" field of every pushed value; the suffix is the record index. */
    private static final String VALUE_PREFIX = "value_";
    /** A key that is never pushed; requesting it must not add an entry to the compute result. */
    private static final String UNKNOWN_KEY = "unknown_key";
    /** Number of records pushed per version. */
    private static final int NUM_RECORDS = 10;
    /** Number of compute requests issued per test. */
    private static final int COMPUTE_ROUNDS = 100;
    /** Length of the "name" field for LARGE_VALUE records (1 MiB + 1 chars); chunking is enabled for those runs. */
    private static final int LARGE_VALUE_NAME_LENGTH = 1024 * 1024 + 1;
    /** Sample buffer size handed to {@link ZstdDictTrainer} (200 MiB). */
    private static final int ZSTD_SAMPLE_BUFFER_SIZE = 200 * 1024 * 1024;
    /** Target dictionary size handed to {@link ZstdDictTrainer} (100 KiB). */
    private static final int ZSTD_DICT_SIZE = 100 * 1024;

    /** One started store client per serializer-reuse mode, created in {@link #setUp()}. */
    private final Map<SerializerReuse, AvroGenericStoreClient<String, Object>> clientsMap =
        new EnumMap<>(SerializerReuse.class);
    private VeniceClusterWrapper veniceCluster;
    private int valueSchemaId;
    private String storeName;
    private String routerAddr;
    private VeniceKafkaSerializer keySerializer;
    private VeniceKafkaSerializer valueSerializer;
    private CompressorFactory compressorFactory;

    /** Avro implementation toggle. NOTE(review): not referenced anywhere in this class. */
    enum AvroImpl {
        VANILLA_AVRO(false),
        FAST_AVRO(true);

        final boolean config;

        AvroImpl(boolean z) {
            this.config = z;
        }
    }

    /** Whether the store client reuses objects for serialization ({@code setReuseObjectsForSerialization}). */
    enum SerializerReuse {
        NO_REUSE(false),
        REUSE(true);

        final boolean config;

        SerializerReuse(boolean z) {
            this.config = z;
        }
    }

    /** Whether records carry a large "name" field; {@code config} also drives the chunking setting. */
    enum ValueSize {
        SMALL_VALUE(false),
        LARGE_VALUE(true);

        final boolean config;

        ValueSize(boolean z) {
            this.config = z;
        }
    }

    /**
     * Starts the cluster and adds one extra server with fast-avro compute enabled. Adds a router
     * configured with a 1 ms long-tail retry threshold, so that retries are actually observed.
     * Creates the store and one client per {@link SerializerReuse} mode.
     */
    @BeforeClass(alwaysRun = true)
    public void setUp() throws InterruptedException, ExecutionException, VeniceClientException {
        veniceCluster = ServiceFactory.getVeniceCluster(1, 1, 0, 2, 100, false, false);

        Properties serverProperties = new Properties();
        serverProperties.put("server.compute.fast.avro.enabled", true);
        veniceCluster.addVeniceServer(new Properties(), serverProperties);

        Properties routerProperties = new Properties();
        routerProperties.put("router.long.tail.retry.for.single.get.threshold.ms", 1);
        routerProperties.put("router.long.tail.retry.for.batch.get.threshold.ms", "1-:1");
        routerProperties.put("router.smart.long.tail.retry.enabled", false);
        veniceCluster.addVeniceRouter(routerProperties);
        routerAddr = "http://" + veniceCluster.getVeniceRouters().get(0).getAddress();

        storeName = Version.parseStoreFromKafkaTopicName(
            veniceCluster.getNewStoreVersion(KEY_SCHEMA, VALUE_SCHEMA_FOR_COMPUTE).getKafkaTopic());
        valueSchemaId = 1;
        keySerializer = new VeniceAvroKafkaSerializer(KEY_SCHEMA);
        valueSerializer = new VeniceAvroKafkaSerializer(VALUE_SCHEMA_FOR_COMPUTE);

        for (SerializerReuse serializerReuse : SerializerReuse.values()) {
            clientsMap.put(
                serializerReuse,
                ClientFactory.getAndStartGenericAvroClient(
                    ClientConfig.defaultGenericClientConfig(storeName)
                        .setVeniceURL(routerAddr)
                        .setUseFastAvro(true)
                        .setReuseObjectsForSerialization(serializerReuse.config)));
        }
        compressorFactory = new CompressorFactory();
    }

    /** Closes the clients first (they talk to the cluster), then the cluster, then the compressor factory. */
    @AfterClass(alwaysRun = true)
    public void cleanUp() {
        clientsMap.values().forEach(client -> Utils.closeQuietlyWithErrorLogged(client));
        if (veniceCluster != null) {
            veniceCluster.close();
        }
        Utils.closeQuietlyWithErrorLogged(compressorFactory);
    }

    /** Cartesian product of compression strategy x value size x serializer reuse. */
    @DataProvider(name = "testPermutations")
    public static Object[][] testPermutations() {
        CompressionStrategy[] compressionStrategies =
            {CompressionStrategy.NO_OP, CompressionStrategy.GZIP, CompressionStrategy.ZSTD_WITH_DICT};
        List<Object[]> permutations = new ArrayList<>();
        for (CompressionStrategy compressionStrategy : compressionStrategies) {
            for (ValueSize valueSize : ValueSize.values()) {
                for (SerializerReuse serializerReuse : SerializerReuse.values()) {
                    permutations.add(new Object[] {compressionStrategy, serializerReuse, valueSize});
                }
            }
        }
        return permutations.toArray(new Object[0][]);
    }

    /**
     * Pushes a new version with the given settings, then issues {@link #COMPUTE_ROUNDS} compute
     * requests. Each request projects "id" and computes dot product, cosine similarity, hadamard
     * product and two counts. Each result is checked against values derived from the record index.
     * Finally, asserts that the server-side compute metrics and router retry metric were incremented.
     */
    @Test(timeOut = 30000, dataProvider = "testPermutations")
    public void testCompute(CompressionStrategy compressionStrategy, SerializerReuse serializerReuse, ValueSize valueSize)
        throws Exception {
        UpdateStoreQueryParams params = new UpdateStoreQueryParams();
        params.setCompressionStrategy(compressionStrategy);
        params.setReadComputationEnabled(true);
        params.setChunkingEnabled(valueSize.config);
        ControllerResponse updateStoreResponse = veniceCluster.updateStore(storeName, params);
        Assert.assertFalse(updateStoreResponse.isError(), "Error updating store settings: " + updateStoreResponse.getError());

        VersionCreationResponse newVersion = veniceCluster.getNewVersion(storeName, false);
        Assert.assertFalse(newVersion.isError(), "Error creation new version: " + newVersion.getError());
        int version = newVersion.getVersion();
        String topic = newVersion.getKafkaTopic();

        // The writer is only needed for the push; scoping it here closes it exactly once.
        try (VeniceWriter<Object, byte[], byte[]> writer =
            TestUtils.getVeniceWriterFactory(veniceCluster.getKafka().getAddress())
                .createVeniceWriter(
                    new VeniceWriterOptions.Builder(topic).setKeySerializer(keySerializer)
                        .setValueSerializer(new DefaultSerializer())
                        .setChunkingEnabled(valueSize.config)
                        .build())) {
            pushSyntheticDataForCompute(
                topic, KEY_PREFIX, VALUE_PREFIX, NUM_RECORDS, veniceCluster, writer, version, compressionStrategy,
                valueSize.config);
        }

        AvroGenericStoreClient<String, Object> client = clientsMap.get(serializerReuse);
        // Request inputs are identical for every round, so build them once.
        Set<String> keys = buildKeySet();
        List<Float> dotProductParam = Arrays.asList(100.0f, 0.1f);
        List<Float> cosineSimilarityParam = Arrays.asList(123.4f, 5.6f);
        List<Float> hadamardProductParam = Arrays.asList(135.7f, 246.8f);

        for (int round = 0; round < COMPUTE_ROUNDS; round++) {
            TestUtils.waitForNonDeterministicAssertion(30L, TimeUnit.SECONDS, false, true, () -> {
                Map<String, ? extends GenericRecord> result = client.compute()
                    .project("id")
                    .dotProduct("member_feature", dotProductParam, "member_score")
                    .cosineSimilarity("member_feature", cosineSimilarityParam, "cosine_similarity_result")
                    .hadamardProduct("member_feature", hadamardProductParam, "hadamard_product_result")
                    .count("namemap", "namemap_count")
                    .count("member_feature", "member_feature_count")
                    .execute(keys)
                    .get(2L, TimeUnit.SECONDS);
                // UNKNOWN_KEY was requested but never written, so only the real records come back.
                Assert.assertEquals(result.size(), NUM_RECORDS);
                result.forEach(
                    (key, record) -> verifyComputeRecord(
                        key, record, dotProductParam, cosineSimilarityParam, hadamardProductParam));
            });
        }

        double minExpectedOperationCount = COMPUTE_ROUNDS * NUM_RECORDS;
        Assert.assertTrue(
            MetricsUtils.getSum("." + storeName + "--compute_dot_product_count.Total", veniceCluster.getVeniceServers())
                >= minExpectedOperationCount);
        Assert.assertTrue(
            MetricsUtils.getSum("." + storeName + "--compute_cosine_similarity_count.Total", veniceCluster.getVeniceServers())
                >= minExpectedOperationCount);
        Assert.assertTrue(
            MetricsUtils.getSum("." + storeName + "--compute_hadamard_product_count.Total", veniceCluster.getVeniceServers())
                >= minExpectedOperationCount);
        Assert.assertTrue(
            MetricsUtils.getSum(".total--compute_streaming_retry_count.LambdaStat", veniceCluster.getVeniceRouters()) > 0.0d,
            "After 100 reads, there should be some compute retry requests");
    }

    /**
     * Pushes a small data set, then issues {@link #COMPUTE_ROUNDS} dot-product requests whose
     * parameter vector grows by one element per round. Logs the server-side request part-count and
     * size metrics, and asserts that the max part count stays at 1.
     */
    @Test(timeOut = 30000, groups = {"flaky"})
    public void testComputeRequestSize() throws Exception {
        UpdateStoreQueryParams params = new UpdateStoreQueryParams();
        params.setReadComputationEnabled(true);
        ControllerResponse updateStoreResponse = veniceCluster.updateStore(storeName, params);
        Assert.assertFalse(updateStoreResponse.isError(), "Error updating store settings: " + updateStoreResponse.getError());

        VersionCreationResponse newVersion = veniceCluster.getNewVersion(storeName);
        Assert.assertFalse(newVersion.isError(), "Error creation new version: " + newVersion.getError());
        int version = newVersion.getVersion();
        String topic = newVersion.getKafkaTopic();

        // Resources are closed in reverse order (client, then writer), each exactly once.
        try (VeniceWriter<Object, byte[], byte[]> writer =
            TestUtils.getVeniceWriterFactory(veniceCluster.getKafka().getAddress())
                .createVeniceWriter(new VeniceWriterOptions.Builder(topic).setKeySerializer(keySerializer).build());
            AvroGenericStoreClient<String, Object> client = ClientFactory.getAndStartGenericAvroClient(
                ClientConfig.defaultGenericClientConfig(storeName).setVeniceURL(routerAddr))) {
            pushSyntheticDataForCompute(
                topic, KEY_PREFIX, VALUE_PREFIX, NUM_RECORDS, veniceCluster, writer, version, CompressionStrategy.NO_OP,
                false);

            Set<String> keys = buildKeySet();
            for (int round = 1; round <= COMPUTE_ROUNDS; round++) {
                // A fresh vector of `round` elements per request so the request payload grows each round.
                List<Float> dotProductParam = new ArrayList<>(round);
                for (int i = 0; i < round; i++) {
                    dotProductParam.add((float) (i * 0.1d));
                }
                Map<String, ?> result = client.compute()
                    .dotProduct("member_feature", dotProductParam, "member_score")
                    .execute(keys)
                    .get(2L, TimeUnit.SECONDS);
                Assert.assertEquals(result.size(), NUM_RECORDS);

                double maxPartCount =
                    MetricsUtils.getMax(".total--compute_request_part_count.Max", veniceCluster.getVeniceServers());
                LOGGER.info(
                    "Current round: {}/{}\n - max part_count: {}\n - min part_count: {}\n - avg part_count: {}\n - max request size: {}\n - min request size: {}\n - avg request size: {}",
                    round,
                    COMPUTE_ROUNDS,
                    maxPartCount,
                    MetricsUtils.getMin(".total--compute_request_part_count.Min", veniceCluster.getVeniceServers()),
                    MetricsUtils.getAvg(".total--compute_request_part_count.Avg", veniceCluster.getVeniceServers()),
                    MetricsUtils.getMax(".total--compute_request_size_in_bytes.Max", veniceCluster.getVeniceServers()),
                    MetricsUtils.getMin(".total--compute_request_size_in_bytes.Min", veniceCluster.getVeniceServers()),
                    MetricsUtils.getAvg(".total--compute_request_size_in_bytes.Avg", veniceCluster.getVeniceServers()));
                Assert.assertEquals(
                    Double.valueOf(maxPartCount),
                    Double.valueOf(1.0d),
                    "Expected a max of one part in compute request");
            }
        }
    }

    /** Returns the {@link #NUM_RECORDS} pushed keys plus {@link #UNKNOWN_KEY}. */
    private static Set<String> buildKeySet() {
        Set<String> keys = new HashSet<>();
        for (int i = 0; i < NUM_RECORDS; i++) {
            keys.add(KEY_PREFIX + i);
        }
        keys.add(UNKNOWN_KEY);
        return keys;
    }

    /**
     * Checks one compute result record against the synthetic value pushed for {@code key}.
     * The pushed member_feature is [i + 1, (i + 1) * 10] for record index i, and namemap is empty.
     * Expected-value expressions keep the original float evaluation order.
     */
    private static void verifyComputeRecord(
        String key,
        GenericRecord record,
        List<Float> dotProductParam,
        List<Float> cosineSimilarityParam,
        List<Float> hadamardProductParam) {
        int keyIndex = getKeyIndex(key, KEY_PREFIX);
        int factor = keyIndex + 1;

        Assert.assertEquals(record.get("id"), new Utf8(VALUE_PREFIX + keyIndex));

        Assert.assertEquals(
            record.get("member_score"),
            Float.valueOf(dotProductParam.get(0) * factor + dotProductParam.get(1) * factor * 10));

        float cosineNumerator = cosineSimilarityParam.get(0) * factor + cosineSimilarityParam.get(1) * (factor * 10);
        float paramNorm = (float) Math.sqrt(
            cosineSimilarityParam.get(0) * cosineSimilarityParam.get(0)
                + cosineSimilarityParam.get(1) * cosineSimilarityParam.get(1));
        float featureNorm = (float) Math.sqrt(factor * factor + (factor * 10.0f) * (factor * 10.0f));
        Assert.assertEquals(
            ((Float) record.get("cosine_similarity_result")).floatValue(),
            cosineNumerator / (paramNorm * featureNorm),
            1.0E-6f);

        Assert.assertEquals(((Integer) record.get("member_feature_count")).intValue(), 2);
        Assert.assertEquals(((Integer) record.get("namemap_count")).intValue(), 0);

        List<Float> expectedHadamard = new ArrayList<>(2);
        expectedHadamard.add(hadamardProductParam.get(0) * factor);
        expectedHadamard.add(hadamardProductParam.get(1) * factor * 10);
        Assert.assertEquals(record.get("hadamard_product_result"), expectedHadamard);
    }

    /**
     * Parses the integer suffix of {@code key} after {@code prefix}.
     *
     * @return the parsed index, or -1 if {@code key} does not start with {@code prefix}
     * @throws NumberFormatException if the suffix is not an integer
     */
    private static int getKeyIndex(String key, String prefix) {
        if (key.startsWith(prefix)) {
            return Integer.parseInt(key.substring(prefix.length()));
        }
        return -1;
    }

    /**
     * Writes {@code numKeys} synthetic records to {@code topic} and waits up to 30s for the version
     * to become current.
     *
     * <p>Record i has key {@code keyPrefix + i}, id {@code valuePrefix + i}, an empty namemap and
     * member_feature [i + 1, (i + 1) * 10]. Its name is {@code valuePrefix + i + "_name"}, or
     * {@link #LARGE_VALUE_NAME_LENGTH} copies of the first digit of i when {@code useLargeValues}.
     * For ZSTD_WITH_DICT, a dictionary is trained from the serialized values and sent in start-of-push.
     *
     * @throws VeniceException (inside the wait loop) when the push job reports ERROR
     */
    private void pushSyntheticDataForCompute(
        String topic,
        String keyPrefix,
        String valuePrefix,
        int numKeys,
        VeniceClusterWrapper clusterWrapper,
        VeniceWriter<Object, byte[], byte[]> veniceWriter,
        int pushVersion,
        CompressionStrategy compressionStrategy,
        boolean useLargeValues) throws Exception {
        Schema valueSchema = new Schema.Parser().parse(VALUE_SCHEMA_FOR_COMPUTE);
        List<byte[]> serializedValues = new ArrayList<>(numKeys);
        for (int i = 0; i < numKeys; i++) {
            GenericData.Record value = new GenericData.Record(valueSchema);
            value.put("id", valuePrefix + i);
            String name = valuePrefix + i + "_name";
            if (useLargeValues) {
                char[] chars = new char[LARGE_VALUE_NAME_LENGTH];
                Arrays.fill(chars, Integer.toString(i).charAt(0));
                name = new String(chars);
            }
            value.put("name", name);
            value.put("namemap", Collections.emptyMap());
            List<Float> memberFeature = new ArrayList<>(2);
            memberFeature.add((float) (i + 1));
            memberFeature.add((float) ((i + 1) * 10));
            value.put("member_feature", memberFeature);
            serializedValues.add(valueSerializer.serialize(topic, value));
        }

        Optional<ByteBuffer> compressionDictionary = Optional.empty();
        VeniceCompressor compressor;
        if (compressionStrategy == CompressionStrategy.ZSTD_WITH_DICT) {
            ZstdDictTrainer trainer = new ZstdDictTrainer(ZSTD_SAMPLE_BUFFER_SIZE, ZSTD_DICT_SIZE);
            for (byte[] sample : serializedValues) {
                trainer.addSample(sample);
            }
            byte[] dictionary = trainer.trainSamples();
            compressionDictionary = Optional.of(ByteBuffer.wrap(dictionary));
            compressor = compressorFactory.createVersionSpecificCompressorIfNotExist(compressionStrategy, topic, dictionary);
        } else {
            compressor = compressorFactory.getCompressor(compressionStrategy);
        }

        veniceWriter.broadcastStartOfPush(false, useLargeValues, compressionStrategy, compressionDictionary, new HashMap<>());
        List<Future<?>> pendingWrites = new ArrayList<>(numKeys);
        for (int i = 0; i < numKeys; i++) {
            pendingWrites.add(veniceWriter.put(keyPrefix + i, compressor.compress(serializedValues.get(i)), valueSchemaId));
        }
        // Wait for every put to be acknowledged before ending the push.
        for (Future<?> pendingWrite : pendingWrites) {
            pendingWrite.get();
        }
        veniceWriter.broadcastEndOfPush(new HashMap<>());

        try (ControllerClient controllerClient =
            new ControllerClient(clusterWrapper.getClusterName(), clusterWrapper.getAllControllersURLs())) {
            TestUtils.waitForNonDeterministicAssertion(30L, TimeUnit.SECONDS, true, () -> {
                if (controllerClient.queryJobStatus(topic).getStatus().equals(ExecutionStatus.ERROR.name())) {
                    throw new VeniceException("Push failed.");
                }
                int currentVersion = controllerClient.getStore(storeName).getStore().getCurrentVersion();
                if (currentVersion == pushVersion) {
                    // Make routers pick up the new current version before the test starts reading.
                    clusterWrapper.refreshAllRouterMetaData();
                }
                Assert.assertEquals(currentVersion, pushVersion, "New version not online yet.");
            });
        }
    }
}
