package ai.djl.modality.nlp.translator;

import ai.djl.Model;
import ai.djl.modality.nlp.Decoder;
import ai.djl.modality.nlp.Encoder;
import ai.djl.modality.nlp.EncoderDecoder;
import ai.djl.modality.nlp.embedding.TrainableTextEmbedding;
import ai.djl.modality.nlp.preprocess.LowerCaseConvertor;
import ai.djl.modality.nlp.preprocess.PunctuationSeparator;
import ai.djl.modality.nlp.preprocess.SimpleTokenizer;
import ai.djl.modality.nlp.preprocess.TextProcessor;
import ai.djl.modality.nlp.preprocess.TextTruncator;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.SequentialBlock;
import ai.djl.translate.Batchifier;
import ai.djl.translate.PaddingStackBatchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;

/* loaded from: input_file:META-INF/bundled-dependencies/api-0.22.1.jar:ai/djl/modality/nlp/translator/SimpleText2TextTranslator.class */
/**
 * A {@link Translator} that turns a source sentence into the inputs of an {@link EncoderDecoder}
 * model and turns the model's predicted token indices back into a sentence.
 *
 * <p>The source and target {@link TrainableTextEmbedding}s are looked up lazily from the model's
 * block tree the first time they are needed: the source embedding from the first child of the
 * {@link EncoderDecoder} (cast to {@link Encoder}) and the target embedding from its second child
 * (cast to {@link Decoder}).
 */
public class SimpleText2TextTranslator implements Translator<String, String> {
    /** Resolved on the first call to {@link #processInput}; {@code null} until then. */
    private TrainableTextEmbedding sourceEmbedding;

    /** Resolved on the first call to {@link #processOutput}; {@code null} until then. */
    private TrainableTextEmbedding targetEmbedding;

    private final SimpleTokenizer tokenizer = new SimpleTokenizer();

    /** Pre-processing steps applied, in order, to the input sentence. */
    private final List<TextProcessor> textProcessors =
            Arrays.asList(
                    new SimpleTokenizer(),
                    new LowerCaseConvertor(Locale.ENGLISH),
                    new PunctuationSeparator(),
                    new TextTruncator(10));

    /**
     * Converts the predicted token indices into a sentence, stopping at the first
     * {@code "<eos>"} token.
     *
     * @param translatorContext the context used to reach the model
     * @param nDList the model output; must hold exactly one array of at most 2 dimensions
     * @return the sentence built from the tokens preceding {@code "<eos>"}
     * @throws IllegalArgumentException if the output array has more than 2 dimensions
     */
    @Override // ai.djl.translate.PostProcessor
    public String processOutput(TranslatorContext translatorContext, NDList nDList) {
        // Fetch the single output once; singletonOrThrow() also validates the list size.
        NDArray prediction = nDList.singletonOrThrow();
        if (prediction.getShape().dimension() > 2) {
            throw new IllegalArgumentException("Input must correspond to one sentence. Shape must be of 2 or less dimensions");
        }
        if (this.targetEmbedding == null) {
            this.targetEmbedding = findTargetEmbedding(translatorContext.getModel());
        }
        NDArray tokenIndices = prediction.toType(DataType.INT32, false).flatten();
        List<String> tokens = new ArrayList<>();
        for (String token : this.targetEmbedding.unembedText(tokenIndices)) {
            if ("<eos>".equals(token)) {
                break;
            }
            tokens.add(token);
        }
        return this.tokenizer.buildSentence(tokens);
    }

    /**
     * Runs the text processors over the input sentence and converts the result, plus a
     * {@code "<bos>"} token, into the two input arrays.
     *
     * @param translatorContext the context used to reach the model and to allocate arrays
     * @param str the source sentence
     * @return a list holding the encoded source tokens followed by the encoded {@code "<bos>"}
     *     token
     */
    @Override // ai.djl.translate.PreProcessor
    public NDList processInput(TranslatorContext translatorContext, String str) {
        if (this.sourceEmbedding == null) {
            this.sourceEmbedding = findSourceEmbedding(translatorContext.getModel());
        }
        List<String> tokens = Collections.singletonList(str);
        for (TextProcessor processor : this.textProcessors) {
            tokens = processor.preprocess(tokens);
        }
        // Allocate on the context's manager rather than the model's manager so that the
        // per-call input arrays are not left attached to the model-wide manager.
        NDManager manager = translatorContext.getNDManager();
        NDArray source = manager.create(this.sourceEmbedding.preprocessTextToEmbed(tokens));
        NDArray target =
                manager.create(
                        this.sourceEmbedding.preprocessTextToEmbed(
                                Collections.singletonList("<bos>")));
        return new NDList(source, target);
    }

    /**
     * Builds a {@link PaddingStackBatchifier} without valid lengths that pads array 0 along
     * dimension 0 with the {@code "<pad>"} token, using a padding size of 10.
     *
     * @return the batchifier
     */
    @Override // ai.djl.translate.Translator
    public Batchifier getBatchifier() {
        return PaddingStackBatchifier.builder()
                .optIncludeValidLengths(false)
                .addPad(0, 0, this::createPadArray, 10)
                .build();
    }

    /**
     * Creates a one-element array holding the index of the {@code "<pad>"} token.
     *
     * <p>Reads {@code sourceEmbedding}, which is only assigned by {@link #processInput}, so this
     * must not be invoked before the first input has been processed.
     */
    private NDArray createPadArray(NDManager nDManager) {
        long padIndex =
                this.sourceEmbedding.preprocessTextToEmbed(Collections.singletonList("<pad>"))[0];
        return nDManager.ones(new Shape(1)).mul(Long.valueOf(padIndex));
    }

    /** Walks model block -> child 0 (encoder) -> child 0 (sequential) -> child 0 (embedding). */
    private static TrainableTextEmbedding findSourceEmbedding(Model model) {
        EncoderDecoder encoderDecoder = (EncoderDecoder) model.getBlock();
        Encoder encoder = (Encoder) encoderDecoder.getChildren().get(0).getValue();
        SequentialBlock sequential = (SequentialBlock) encoder.getChildren().get(0).getValue();
        return (TrainableTextEmbedding) sequential.getChildren().get(0).getValue();
    }

    /** Walks model block -> child 1 (decoder) -> child 0 (sequential) -> child 0 (embedding). */
    private static TrainableTextEmbedding findTargetEmbedding(Model model) {
        EncoderDecoder encoderDecoder = (EncoderDecoder) model.getBlock();
        Decoder decoder = (Decoder) encoderDecoder.getChildren().get(1).getValue();
        SequentialBlock sequential = (SequentialBlock) decoder.getChildren().get(0).getValue();
        return (TrainableTextEmbedding) sequential.getChildren().get(0).getValue();
    }
}
