package com.datastax.oss.streaming.ai;

import com.datastax.oss.streaming.ai.embeddings.EmbeddingsService;
import com.samskivert.mustache.Mustache;
import com.samskivert.mustache.Template;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.avro.Schema;

/* loaded from: input_file:META-INF/bundled-dependencies/streaming-ai-3.1.1.jar:com/datastax/oss/streaming/ai/ComputeAIEmbeddingsStep.class */
/**
 * Transform step that renders a Mustache template against the current record (converted with
 * {@link TransformContext#toJsonRecord()}), asks an {@link EmbeddingsService} for the embedding
 * of the rendered text, and writes that embedding into a result field typed as an Avro
 * {@code array<double>}.
 */
public class ComputeAIEmbeddingsStep implements TransformStep {
    private final Template template;
    private final String embeddingsFieldName;
    private final EmbeddingsService embeddingsService;
    // Handed to TransformContext.setResultField on every call; presumably caches the derived
    // output schema per input schema so it is not rebuilt per record — TODO confirm.
    private final Map<Schema, Schema> avroValueSchemaCache = new ConcurrentHashMap<>();
    private final Map<Schema, Schema> avroKeySchemaCache = new ConcurrentHashMap<>();

    /**
     * Creates the step.
     *
     * @param templateText Mustache template whose rendering is the text to embed; compiled once here
     * @param embeddingsFieldName name of the result field that receives the embedding
     * @param embeddingsService service used to compute embeddings; closed by {@link #close()}
     */
    public ComputeAIEmbeddingsStep(
            String templateText, String embeddingsFieldName, EmbeddingsService embeddingsService) {
        this.template = Mustache.compiler().compile(templateText);
        this.embeddingsFieldName = embeddingsFieldName;
        this.embeddingsService = embeddingsService;
    }

    /** Closes the underlying embeddings service, if one was supplied. */
    @Override // com.datastax.oss.streaming.ai.TransformStep, java.lang.AutoCloseable
    public void close() throws Exception {
        if (this.embeddingsService != null) {
            this.embeddingsService.close();
        }
    }

    /**
     * Renders the template for the current record, computes its embedding and stores it in the
     * configured result field.
     *
     * @param transformContext the record being transformed; its result field is updated in place
     */
    @Override // com.datastax.oss.streaming.ai.TransformStep
    public void process(TransformContext transformContext) {
        String text = this.template.execute(transformContext.toJsonRecord());
        // The service is batch-oriented; send a single-element batch and keep its only result.
        Schema embeddingSchema = Schema.createArray(Schema.create(Schema.Type.DOUBLE));
        transformContext.setResultField(
                this.embeddingsService.computeEmbeddings(List.of(text)).get(0),
                this.embeddingsFieldName,
                embeddingSchema,
                this.avroKeySchemaCache,
                this.avroValueSchemaCache);
    }
}
