package com.datastax.oss.streaming.ai.embeddings;

import ai.djl.MalformedModelException;
import ai.djl.huggingface.translator.TextEmbeddingTranslatorFactory;
import ai.djl.inference.Predictor;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.lang.reflect.ParameterizedType;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentLinkedQueue;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:pulsar-transformations.nar:META-INF/bundled-dependencies/streaming-ai-3.1.4.jar:com/datastax/oss/streaming/ai/embeddings/AbstractHuggingFaceEmbeddingService.class */
public abstract class AbstractHuggingFaceEmbeddingService<IN, OUT> implements EmbeddingsService, AutoCloseable {
    /**
     * Name under which the allowed model URL prefixes are looked up: first as an
     * environment variable, then as a JVM system property.
     */
    public static final String URL_PREFIXES_SYSTEM_PROP = "ALLOWED_HF_URLS";
    /**
     * URL prefix {@code djl://ai.djl.huggingface.pytorch}. The same text is part of the
     * built-in default allow-list; the constant itself is not referenced by the code in this class.
     */
    public static final String DLJ_BASE_URL = "djl://ai.djl.huggingface.pytorch";
    /** Model loaded by the constructor and released by {@link #close()}. Package-private and non-final. */
    ZooModel<IN, OUT> model;
    private static final Logger log = LoggerFactory.getLogger((Class<?>) AbstractHuggingFaceEmbeddingService.class);
    /** URL prefixes a {@code modelUrl} must start with; resolved once, when this class is initialised. */
    public static final Set<String> allowedUrlPrefixes = getHuggingFaceAllowedUrlPrefixes();
    /**
     * Predictor cached for the current thread.
     * NOTE(review): this is static, so the cached predictor is shared by every instance of
     * this class (and its subclasses) used on the same thread — confirm that only one
     * service instance is expected per JVM.
     */
    private static final ThreadLocal<Predictor<?, ?>> predictorThreadLocal = new ThreadLocal<>();
    /** Every predictor created by {@code compute}, kept so that {@link #close()} can close them all. Also static. */
    private static final ConcurrentLinkedQueue<Predictor<?, ?>> predictorList = new ConcurrentLinkedQueue<>();

    /* loaded from: input_file:pulsar-transformations.nar:META-INF/bundled-dependencies/streaming-ai-3.1.4.jar:com/datastax/oss/streaming/ai/embeddings/AbstractHuggingFaceEmbeddingService$HuggingFaceConfig.class */
    /**
     * Mutable settings describing which model to load and how to load it.
     *
     * <p>Instances are normally created through {@link #builder()}. When the builder is not
     * given an engine, options or arguments, they default to {@code "PyTorch"}, an empty map
     * and an empty map respectively. Explicitly passing {@code null} to the builder is kept
     * as {@code null} rather than replaced by the default.
     */
    public static class HuggingFaceConfig {
        String engine;
        Map<String, String> options;
        Map<String, String> arguments;
        String modelUrl;
        String modelName;

        /** Fluent builder for {@link HuggingFaceConfig}. */
        public static class HuggingFaceConfigBuilder {
            // For the defaulted properties a separate flag records whether the caller
            // supplied a value, so that an explicit null is distinguishable from "unset".
            private String engineOverride;
            private boolean engineProvided;
            private Map<String, String> optionsOverride;
            private boolean optionsProvided;
            private Map<String, String> argumentsOverride;
            private boolean argumentsProvided;
            private String modelUrl;
            private String modelName;

            HuggingFaceConfigBuilder() {
            }

            /** Sets the engine name, overriding the {@code "PyTorch"} default. */
            public HuggingFaceConfigBuilder engine(String str) {
                this.engineProvided = true;
                this.engineOverride = str;
                return this;
            }

            /** Sets the model options, overriding the empty-map default. */
            public HuggingFaceConfigBuilder options(Map<String, String> map) {
                this.optionsProvided = true;
                this.optionsOverride = map;
                return this;
            }

            /** Sets the model arguments, overriding the empty-map default. */
            public HuggingFaceConfigBuilder arguments(Map<String, String> map) {
                this.argumentsProvided = true;
                this.argumentsOverride = map;
                return this;
            }

            /** Sets the URL the model is loaded from. */
            public HuggingFaceConfigBuilder modelUrl(String str) {
                this.modelUrl = str;
                return this;
            }

            /** Sets the model name. */
            public HuggingFaceConfigBuilder modelName(String str) {
                this.modelName = str;
                return this;
            }

            /** Creates the configuration, filling in defaults for properties that were never set. */
            public HuggingFaceConfig build() {
                String resolvedEngine = this.engineProvided ? this.engineOverride : "PyTorch";
                Map<String, String> resolvedOptions =
                        this.optionsProvided ? this.optionsOverride : Map.of();
                Map<String, String> resolvedArguments =
                        this.argumentsProvided ? this.argumentsOverride : Map.of();
                return new HuggingFaceConfig(
                        resolvedEngine, resolvedOptions, resolvedArguments, this.modelUrl, this.modelName);
            }

            @Override
            public String toString() {
                StringBuilder text = new StringBuilder(
                        "AbstractHuggingFaceEmbeddingService.HuggingFaceConfig.HuggingFaceConfigBuilder(");
                text.append("engine$value=").append(this.engineOverride);
                text.append(", options$value=").append(this.optionsOverride);
                text.append(", arguments$value=").append(this.argumentsOverride);
                text.append(", modelUrl=").append(this.modelUrl);
                text.append(", modelName=").append(this.modelName);
                return text.append(")").toString();
            }
        }

        HuggingFaceConfig(
                String engine,
                Map<String, String> options,
                Map<String, String> arguments,
                String modelUrl,
                String modelName) {
            this.engine = engine;
            this.options = options;
            this.arguments = arguments;
            this.modelUrl = modelUrl;
            this.modelName = modelName;
        }

        /** Returns a new, empty builder. */
        public static HuggingFaceConfigBuilder builder() {
            return new HuggingFaceConfigBuilder();
        }

        public String getEngine() {
            return this.engine;
        }

        public Map<String, String> getOptions() {
            return this.options;
        }

        public Map<String, String> getArguments() {
            return this.arguments;
        }

        public String getModelUrl() {
            return this.modelUrl;
        }

        public String getModelName() {
            return this.modelName;
        }

        public void setEngine(String str) {
            this.engine = str;
        }

        public void setOptions(Map<String, String> map) {
            this.options = map;
        }

        public void setArguments(Map<String, String> map) {
            this.arguments = map;
        }

        public void setModelUrl(String str) {
            this.modelUrl = str;
        }

        public void setModelName(String str) {
            this.modelName = str;
        }

        /**
         * Two configurations are equal when all five properties are equal
         * (null-safe) and the other object agrees via {@link #canEqual(Object)}.
         */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof HuggingFaceConfig)) {
                return false;
            }
            HuggingFaceConfig other = (HuggingFaceConfig) obj;
            return other.canEqual(this)
                    && Objects.equals(getEngine(), other.getEngine())
                    && Objects.equals(getOptions(), other.getOptions())
                    && Objects.equals(getArguments(), other.getArguments())
                    && Objects.equals(getModelUrl(), other.getModelUrl())
                    && Objects.equals(getModelName(), other.getModelName());
        }

        /** Hook letting subclasses opt out of equality with this type. */
        protected boolean canEqual(Object obj) {
            return obj instanceof HuggingFaceConfig;
        }

        /** Combines the five properties with multiplier 59, using 43 for a null property. */
        @Override
        public int hashCode() {
            Object[] properties = {
                getEngine(), getOptions(), getArguments(), getModelUrl(), getModelName()
            };
            int result = 1;
            for (Object property : properties) {
                result = (result * 59) + (property == null ? 43 : property.hashCode());
            }
            return result;
        }

        @Override
        public String toString() {
            StringBuilder text =
                    new StringBuilder("AbstractHuggingFaceEmbeddingService.HuggingFaceConfig(");
            text.append("engine=").append(getEngine());
            text.append(", options=").append(getOptions());
            text.append(", arguments=").append(getArguments());
            text.append(", modelUrl=").append(getModelUrl());
            text.append(", modelName=").append(getModelName());
            return text.append(")").toString();
        }
    }

    /**
     * Resolves the URL prefixes from which models may be loaded.
     *
     * <p>The comma-separated list is read from the {@code ALLOWED_HF_URLS} environment
     * variable; if that is unset or empty, from the system property of the same name; and
     * if that is also unset or empty, the built-in default
     * {@code file://,djl://ai.djl.huggingface.pytorch} is used.
     *
     * <p>Each entry is trimmed. Blank entries are dropped because an empty prefix would
     * match every URL and silently disable the allow-list. Duplicate entries are tolerated
     * ({@code Set.of} would throw on them and fail class initialisation).
     *
     * @return an unmodifiable set of non-empty prefixes; empty if the configured value
     *     contains no usable entry, in which case no URL is allowed
     */
    private static Set<String> getHuggingFaceAllowedUrlPrefixes() {
        String configured = System.getenv(URL_PREFIXES_SYSTEM_PROP);
        if (configured == null || configured.isEmpty()) {
            configured = System.getProperty(URL_PREFIXES_SYSTEM_PROP);
        }
        if (configured == null || configured.isEmpty()) {
            configured = "file://,djl://ai.djl.huggingface.pytorch";
        }
        Set<String> prefixes = new LinkedHashSet<>();
        for (String candidate : configured.split(",")) {
            String prefix = candidate.trim();
            if (!prefix.isEmpty()) {
                prefixes.add(prefix);
            }
        }
        return Set.copyOf(prefixes);
    }

    /**
     * Closes every predictor that was created through {@code compute} and then the model.
     *
     * <p>The predictor queue is static, so this drains predictors created through any
     * instance of this class, not only this one. A failure while closing one resource does
     * not prevent the remaining resources from being closed: the first failure is rethrown
     * after cleanup, with later failures attached as suppressed exceptions.
     *
     * @throws Exception the first exception raised while closing a predictor or the model
     */
    @Override // com.datastax.oss.streaming.ai.embeddings.EmbeddingsService, java.lang.AutoCloseable
    public void close() throws Exception {
        Exception failure = null;
        Predictor<?, ?> predictor;
        while ((predictor = predictorList.poll()) != null) {
            try {
                predictor.close();
            } catch (Exception e) {
                if (failure == null) {
                    failure = e;
                } else {
                    failure.addSuppressed(e);
                }
            }
        }
        // Drop the calling thread's cached predictor so it cannot be reused after being closed.
        predictorThreadLocal.remove();
        if (this.model != null) {
            try {
                this.model.close();
            } catch (Exception e) {
                if (failure == null) {
                    failure = e;
                } else {
                    failure.addSuppressed(e);
                }
            }
        }
        if (failure != null) {
            throw failure;
        }
    }

    /**
     * Validates the configuration and loads the model it describes.
     *
     * <p>The model's input and output classes are taken from the type arguments the concrete
     * subclass supplies for {@code IN} and {@code OUT}, so the direct superclass of the
     * runtime class must be this class parameterised with concrete classes.
     *
     * @param huggingFaceConfig model settings; {@code modelName} and {@code modelUrl} are required
     * @throws NullPointerException if the config, its model name or its model URL is null
     * @throws IllegalAccessException if the model URL does not start with an allowed prefix
     * @throws IOException if loading the model fails with an I/O error
     * @throws ModelNotFoundException if no matching model is found
     * @throws MalformedModelException if the model files are invalid
     */
    @SuppressWarnings("unchecked") // the Class casts are determined by the subclass's own type arguments
    public AbstractHuggingFaceEmbeddingService(HuggingFaceConfig huggingFaceConfig) throws IOException, ModelNotFoundException, MalformedModelException, IllegalAccessException {
        Objects.requireNonNull(huggingFaceConfig, "huggingFaceConfig");
        Objects.requireNonNull(huggingFaceConfig.modelName, "modelName");
        // Fail with a clear message instead of a bare NPE from inside the prefix check.
        Objects.requireNonNull(huggingFaceConfig.modelUrl, "modelUrl");
        checkIfUrlIsAllowed(huggingFaceConfig.modelUrl);

        ParameterizedType superType = (ParameterizedType) getClass().getGenericSuperclass();
        Class<IN> inputType = (Class<IN>) superType.getActualTypeArguments()[0];
        Class<OUT> outputType = (Class<OUT>) superType.getActualTypeArguments()[1];

        Criteria.Builder<IN, OUT> criteria = Criteria.builder().setTypes(inputType, outputType);
        criteria.optModelUrls(huggingFaceConfig.modelUrl);
        log.info("Loading model from {}", huggingFaceConfig.modelUrl);
        criteria.optModelName(huggingFaceConfig.modelName);
        criteria.optEngine(huggingFaceConfig.engine != null ? huggingFaceConfig.engine : "PyTorch");
        if (huggingFaceConfig.options != null) {
            huggingFaceConfig.options.forEach(criteria::optOption);
        }
        if (huggingFaceConfig.arguments != null) {
            huggingFaceConfig.arguments.forEach(criteria::optArgument);
        }
        criteria.optTranslatorFactory(new TextEmbeddingTranslatorFactory());
        this.model = criteria.build().loadModel();
    }

    /**
     * Verifies that the given model URL starts with one of the allowed prefixes.
     *
     * @param str the model URL to check
     * @throws IllegalAccessException if no allowed prefix matches the URL
     */
    private void checkIfUrlIsAllowed(String str) throws IllegalAccessException {
        // A lambda (not a bound method reference) keeps the URL from being dereferenced
        // unless there is at least one prefix to compare against.
        boolean permitted = allowedUrlPrefixes.stream().anyMatch(prefix -> str.startsWith(prefix));
        if (!permitted) {
            throw new IllegalAccessException("modelUrl is not allowed: " + str);
        }
    }

    /**
     * Runs batch prediction on the given inputs using a predictor cached for the calling thread.
     *
     * <p>A new predictor is created from the model when the thread has none yet, or when its
     * cached predictor is no longer in the live-predictor queue. {@link #close()} is the only
     * code that removes predictors from that queue, and it closes each one it removes, so a
     * predictor missing from the queue has already been closed and must not be reused.
     *
     * <p>NOTE(review): the cache is static, so a thread that uses two live service instances
     * still shares a single predictor between them — confirm one instance per JVM is intended.
     *
     * @param list the model inputs
     * @return one output per input, as returned by the predictor's batch call
     * @throws TranslateException if the predictor fails to translate inputs or outputs
     */
    @SuppressWarnings("unchecked") // predictors stored in the cache are created from this.model below
    public List<OUT> compute(List<IN> list) throws TranslateException {
        Predictor<IN, OUT> predictor = (Predictor<IN, OUT>) predictorThreadLocal.get();
        if (predictor == null || !predictorList.contains(predictor)) {
            predictor = this.model.newPredictor();
            predictorThreadLocal.set(predictor);
            predictorList.add(predictor);
        }
        return predictor.batchPredict(list);
    }

    /**
     * Converts the raw texts into the model's input type.
     * {@code computeEmbeddings} passes the result straight to {@code compute}.
     *
     * @param list the texts to embed
     * @return the corresponding model inputs
     */
    abstract List<IN> convertInput(List<String> list);

    /**
     * Converts the model outputs produced by {@code compute} into embedding vectors.
     * {@code computeEmbeddings} returns the result unchanged.
     *
     * @param list the model outputs
     * @return the embeddings as lists of doubles
     */
    abstract List<List<Double>> convertOutput(List<OUT> list);

    /**
     * Computes embeddings for the given texts: converts them to model inputs, runs the
     * model, and converts the model outputs to lists of doubles.
     *
     * @param list the texts to embed
     * @return the embeddings produced by {@code convertOutput}
     * @throws RuntimeException wrapping the {@link TranslateException} if prediction fails
     */
    @Override // com.datastax.oss.streaming.ai.embeddings.EmbeddingsService
    public List<List<Double>> computeEmbeddings(List<String> list) {
        try {
            List<IN> modelInputs = convertInput(list);
            List<OUT> modelOutputs = compute(modelInputs);
            return convertOutput(modelOutputs);
        } catch (TranslateException translateFailure) {
            log.error("failed to run compute", translateFailure);
            throw new RuntimeException("failed to run compute", translateFailure);
        }
    }
}
