/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.nlp;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;

/**
 * Processor for the text-expansion NLP task.
 * <p>
 * The model output for an input is treated as a vector indexed by vocabulary id. Every strictly positive
 * entry becomes a {@link TextExpansionResults.WeightedToken} that pairs the (possibly sanitized) vocabulary
 * token with its weight. Tokens in each result are sorted by descending weight.
 */
public class TextExpansionProcessor extends NlpTask.Processor {

    /** Results field name used when the config does not supply one. */
    private static final String DEFAULT_RESULTS_FIELD = "predicted_value";

    private final NlpTask.RequestBuilder requestBuilder;

    /**
     * Vocabulary id -> replacement token, holding entries only for vocabulary tokens that contain '.'.
     * Built once in the constructor and never reassigned.
     */
    private final Map<Integer, String> replacementVocab;

    /**
     * Creates a processor bound to {@code tokenizer}.
     *
     * @param tokenizer supplies the request builder and the vocabulary used to name output tokens
     */
    public TextExpansionProcessor(NlpTokenizer tokenizer) {
        super(tokenizer);
        this.requestBuilder = tokenizer.requestBuilder();
        this.replacementVocab = buildSanitizedVocabMap(tokenizer.getVocabulary());
    }

    /**
     * Builds a map from vocabulary index to a sanitized token, in which every '.' is replaced by "__".
     * Only tokens that contain at least one '.' get an entry; all other ids are absent from the map.
     * NOTE(review): presumably done so tokens can be used as field names where '.' would be read as an
     * object-path separator. Confirm against the consumers of these results.
     *
     * @param inputVocab the tokenizer vocabulary; the list index is the vocabulary id
     * @return a new mutable map holding only the ids whose token needed sanitizing
     */
    static Map<Integer, String> buildSanitizedVocabMap(List<String> inputVocab) {
        Map<Integer, String> sanitized = new HashMap<>();
        int id = 0;
        // Single pass with the iterator: avoids repeated get(i), which is O(n) per call on non-RandomAccess lists.
        for (String token : inputVocab) {
            if (token.contains(".")) {
                sanitized.put(id, token.replace(".", "__"));
            }
            id++;
        }
        return sanitized;
    }

    /** Accepts every input; this task performs no input validation. */
    @Override
    public void validateInputs(List<String> inputs) {
    }

    /** Returns the tokenizer's request builder; {@code config} is not consulted. */
    @Override
    public NlpTask.RequestBuilder getRequestBuilder(NlpConfig config) {
        return this.requestBuilder;
    }

    /**
     * Returns a result processor that delegates to
     * {@link #processResult(TokenizationResult, PyTorchInferenceResult, Map, String, boolean)}.
     * {@code config.getResultsField()} is read each time the returned processor is invoked.
     */
    @Override
    public NlpTask.ResultProcessor getResultProcessor(NlpConfig config) {
        return (tokenization, pyTorchResult, chunkResults) -> processResult(
            tokenization,
            pyTorchResult,
            this.replacementVocab,
            config.getResultsField(),
            chunkResults
        );
    }

    /**
     * Converts raw model output into text-expansion results.
     * <p>
     * In chunked mode, each row {@code i} of {@code pyTorchResult.getInferenceResult()[0]} is paired with
     * {@code tokenization.getTokenization(i)} and produces one chunk. In non-chunked mode, only row
     * {@code [0][0]} is used.
     *
     * @param tokenization     tokenization of the request, used for vocabulary lookup and chunk offsets
     * @param pyTorchResult    model output; {@code getInferenceResult()[0]} holds one vector per chunk
     * @param replacementVocab sanitized token overrides keyed by vocabulary id
     * @param resultsField     results field name; {@code "predicted_value"} is used when {@code null}
     * @param chunkResults     whether to emit one result per chunk instead of a single result
     * @return a {@link ChunkedTextExpansionResults} when chunking, otherwise a {@link TextExpansionResults};
     *         the truncation flag comes from {@code tokenization.anyTruncated()}
     */
    static InferenceResults processResult(
        TokenizationResult tokenization,
        PyTorchInferenceResult pyTorchResult,
        Map<Integer, String> replacementVocab,
        String resultsField,
        boolean chunkResults
    ) {
        // Hoisted: the original re-fetched this array on every loop iteration.
        double[][] inferenceResult = pyTorchResult.getInferenceResult()[0];
        String field = resultsFieldOrDefault(resultsField);
        if (chunkResults) {
            List<ChunkedTextExpansionResults.ChunkedResult> chunkedResults = new ArrayList<>(inferenceResult.length);
            for (int i = 0; i < inferenceResult.length; ++i) {
                String matchedText = chunkText(tokenization, i);
                List<TextExpansionResults.WeightedToken> weightedTokens = sparseVectorToTokenWeights(
                    inferenceResult[i],
                    tokenization,
                    replacementVocab
                );
                sortByWeightDescending(weightedTokens);
                chunkedResults.add(new ChunkedTextExpansionResults.ChunkedResult(matchedText, weightedTokens));
            }
            return new ChunkedTextExpansionResults(field, chunkedResults, tokenization.anyTruncated());
        }
        List<TextExpansionResults.WeightedToken> weightedTokens = sparseVectorToTokenWeights(
            inferenceResult[0],
            tokenization,
            replacementVocab
        );
        sortByWeightDescending(weightedTokens);
        return new TextExpansionResults(field, weightedTokens, tokenization.anyTruncated());
    }

    /**
     * Returns the part of chunk {@code chunk}'s first input that its first token list spans: from the first
     * token's start offset to the last token's end offset.
     * NOTE(review): assumes every chunk has at least one token; an empty token list would make
     * {@code get(0)} throw.
     */
    private static String chunkText(TokenizationResult tokenization, int chunk) {
        int startOffset = tokenization.getTokenization(chunk).tokens().get(0).get(0).startOffset();
        int lastIndex = tokenization.getTokenization(chunk).tokens().get(0).size() - 1;
        int endOffset = tokenization.getTokenization(chunk).tokens().get(0).get(lastIndex).endOffset();
        return tokenization.getTokenization(chunk).input().get(0).substring(startOffset, endOffset);
    }

    /** Sorts {@code tokens} in place, highest weight first. */
    private static void sortByWeightDescending(List<TextExpansionResults.WeightedToken> tokens) {
        tokens.sort((t1, t2) -> Float.compare(t2.weight(), t1.weight()));
    }

    /** Returns {@code resultsField}, or {@link #DEFAULT_RESULTS_FIELD} when it is {@code null}. */
    private static String resultsFieldOrDefault(String resultsField) {
        return Optional.ofNullable(resultsField).orElse(DEFAULT_RESULTS_FIELD);
    }

    /**
     * Converts a vector indexed by vocabulary id into weighted tokens.
     * <p>
     * Only strictly positive entries are kept; zero, negative and NaN entries are skipped. Weights are
     * narrowed to {@code float}.
     *
     * @return a new mutable list in ascending vocabulary-id order (callers sort it afterwards)
     */
    static List<TextExpansionResults.WeightedToken> sparseVectorToTokenWeights(
        double[] vector,
        TokenizationResult tokenization,
        Map<Integer, String> replacementVocab
    ) {
        List<TextExpansionResults.WeightedToken> weightedTokens = new ArrayList<>();
        for (int id = 0; id < vector.length; ++id) {
            double weight = vector[id];
            if (weight > 0.0) {
                weightedTokens.add(
                    new TextExpansionResults.WeightedToken(tokenForId(id, tokenization, replacementVocab), (float) weight)
                );
            }
        }
        return weightedTokens;
    }

    /**
     * Returns the token for vocabulary {@code id}: the sanitized replacement if there is one, otherwise
     * {@code tokenization.getFromVocab(id)}.
     */
    static String tokenForId(int id, TokenizationResult tokenization, Map<Integer, String> replacementVocab) {
        String token = replacementVocab.get(id);
        return token != null ? token : tokenization.getFromVocab(id);
    }
}

