/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.nlp;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;

/**
 * NLP task processor for text embedding models.
 *
 * <p>Requests are built by the request builder obtained from the tokenizer at
 * construction time. Results are converted either into a single
 * {@link TextEmbeddingResults} or, when chunked output is requested, into a
 * {@link ChunkedTextEmbeddingResults} holding one entry per chunk.
 */
public class TextEmbeddingProcessor extends NlpTask.Processor {
    /** Results field name used when the caller supplies {@code null}. */
    private static final String DEFAULT_RESULTS_FIELD = "predicted_value";

    private final NlpTask.RequestBuilder requestBuilder;

    /**
     * @param tokenizer tokenizer handed to the base processor; its request
     *                  builder is captured once and reused for every request
     */
    TextEmbeddingProcessor(NlpTokenizer tokenizer) {
        super(tokenizer);
        this.requestBuilder = tokenizer.requestBuilder();
    }

    /**
     * Intentionally a no-op: this processor performs no input validation.
     *
     * @param inputs the input texts (ignored)
     */
    @Override
    public void validateInputs(List<String> inputs) {
        // nothing to validate
    }

    /**
     * @param config ignored; the same builder is returned for every config
     * @return the request builder captured from the tokenizer at construction
     */
    @Override
    public NlpTask.RequestBuilder getRequestBuilder(NlpConfig config) {
        return this.requestBuilder;
    }

    /**
     * @param config supplies the results field name; it is read each time the
     *               returned processor is invoked
     * @return a result processor delegating to
     *         {@link #processResult(TokenizationResult, PyTorchInferenceResult, String, boolean)}
     */
    @Override
    public NlpTask.ResultProcessor getResultProcessor(NlpConfig config) {
        return (tokenization, pyTorchResult, chunkResults) -> {
            return TextEmbeddingProcessor.processResult(tokenization, pyTorchResult, config.getResultsField(), chunkResults);
        };
    }

    /**
     * Converts a raw PyTorch inference result into an embedding result.
     *
     * @param tokenization  tokenization of the request; supplies the truncation
     *                      flag and, for chunked output, the token offsets and
     *                      input text of each chunk
     * @param pyTorchResult raw inference output; only index {@code [0]} of the
     *                      outermost dimension is read
     * @param resultsField  name of the results field, or {@code null} to use
     *                      {@code "predicted_value"}
     * @param chunkResults  {@code true} to return one embedding per chunk,
     *                      {@code false} to return only the first embedding
     * @return a {@link ChunkedTextEmbeddingResults} when {@code chunkResults} is
     *         set, otherwise a {@link TextEmbeddingResults}
     */
    static InferenceResults processResult(TokenizationResult tokenization, PyTorchInferenceResult pyTorchResult, String resultsField, boolean chunkResults) {
        if (!chunkResults) {
            // Single-result mode: only the very first embedding is reported.
            return new TextEmbeddingResults(resolveResultsField(resultsField), pyTorchResult.getInferenceResult()[0][0], tokenization.anyTruncated());
        }

        List<ChunkedTextEmbeddingResults.EmbeddingChunk> chunks = new ArrayList<>();
        for (int chunk = 0; chunk < pyTorchResult.getInferenceResult()[0].length; chunk++) {
            // Text is resolved before the embedding is read, pairing entry
            // `chunk` of the tokenization with entry `chunk` of the result.
            String chunkText = matchedTextOf(tokenization, chunk);
            chunks.add(new ChunkedTextEmbeddingResults.EmbeddingChunk(chunkText, pyTorchResult.getInferenceResult()[0][chunk]));
        }
        return new ChunkedTextEmbeddingResults(resolveResultsField(resultsField), chunks, tokenization.anyTruncated());
    }

    /**
     * Extracts the slice of input text covered by one tokenization entry: from
     * the start offset of the first token to the end offset of the last token
     * in the entry's first token list, taken from the entry's first input.
     */
    private static String matchedTextOf(TokenizationResult tokenization, int index) {
        int from = tokenization.getTokenization(index).tokens().get(0).get(0).startOffset();
        int lastToken = tokenization.getTokenization(index).tokens().get(0).size() - 1;
        int to = tokenization.getTokenization(index).tokens().get(0).get(lastToken).endOffset();
        return tokenization.getTokenization(index).input().get(0).substring(from, to);
    }

    /** Returns {@code resultsField}, or the default field name when it is {@code null}. */
    private static String resolveResultsField(String resultsField) {
        return Optional.ofNullable(resultsField).orElse(DEFAULT_RESULTS_FIELD);
    }
}

