package ai.djl.huggingface.translator;

import ai.djl.Model;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.modality.nlp.translator.NamedEntity;
import ai.djl.modality.nlp.translator.TokenClassificationServingTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.Pair;
import java.io.IOException;
import java.io.Serializable;
import java.lang.reflect.Type;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:META-INF/bundled-dependencies/tokenizers-0.22.1.jar:ai/djl/huggingface/translator/TokenClassificationTranslatorFactory.class */
/**
 * A {@link TranslatorFactory} that creates token-classification translators backed by a {@link
 * HuggingFaceTokenizer} loaded from the model directory.
 *
 * <p>Supported input/output pairs: {@code String -> NamedEntity[]}, {@code String[] ->
 * NamedEntity[][]} (batch), and {@code Input -> Output} (serving).
 */
public class TokenClassificationTranslatorFactory implements TranslatorFactory, Serializable {
    private static final long serialVersionUID = 1L;

    /** Unmodifiable set of (input, output) type pairs this factory can build translators for. */
    private static final Set<Pair<Type, Type>> SUPPORTED_TYPES;

    static {
        Set<Pair<Type, Type>> types = new HashSet<>();
        types.add(new Pair<>(String.class, NamedEntity[].class));
        types.add(new Pair<>(String[].class, NamedEntity[][].class));
        types.add(new Pair<>(Input.class, Output.class));
        // Shared by all instances and returned to callers: expose read-only so no caller can
        // corrupt the factory's advertised capabilities.
        SUPPORTED_TYPES = Collections.unmodifiableSet(types);
    }

    /**
     * Returns the (input, output) type pairs this factory supports.
     *
     * @return an unmodifiable set of supported type pairs
     */
    @Override // ai.djl.translate.TranslatorFactory
    public Set<Pair<Type, Type>> getSupportedTypes() {
        return SUPPORTED_TYPES;
    }

    /**
     * Builds a token-classification translator for the requested input/output types.
     *
     * <p>The tokenizer is loaded from {@code model.getModelPath()} and managed by {@code
     * model.getNDManager()}; {@code arguments} are forwarded to both the tokenizer builder and the
     * translator builder.
     *
     * @param input the requested translator input class
     * @param output the requested translator output class
     * @param model the model whose path and manager are used to load the tokenizer
     * @param arguments configuration passed to the tokenizer and translator builders
     * @return a translator for the requested types
     * @throws IllegalArgumentException if the input/output pair is not supported
     * @throws TranslateException if the tokenizer fails to load
     */
    @Override // ai.djl.translate.TranslatorFactory
    @SuppressWarnings("unchecked") // each cast is guarded by an exact class check below
    public <I, O> Translator<I, O> newInstance(
            Class<I> input, Class<O> output, Model model, Map<String, ?> arguments)
            throws TranslateException {
        boolean single = input == String.class && output == NamedEntity[].class;
        boolean batch = input == String[].class && output == NamedEntity[][].class;
        boolean serving = input == Input.class && output == Output.class;
        // Validate before building: loading the native tokenizer is expensive and pointless
        // when the request is going to be rejected anyway.
        if (!single && !batch && !serving) {
            throw new IllegalArgumentException("Unsupported input/output types.");
        }
        try {
            HuggingFaceTokenizer tokenizer =
                    HuggingFaceTokenizer.builder(arguments)
                            .optTokenizerPath(model.getModelPath())
                            .optManager(model.getNDManager())
                            .build();
            TokenClassificationTranslator translator =
                    TokenClassificationTranslator.builder(tokenizer, arguments).build();
            if (single) {
                return (Translator<I, O>) translator;
            }
            if (batch) {
                return (Translator<I, O>) translator.toBatchTranslator();
            }
            return (Translator<I, O>) new TokenClassificationServingTranslator(translator);
        } catch (IOException e) {
            throw new TranslateException("Failed to load tokenizer.", e);
        }
    }
}
