/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.nlp;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.nlp.NlpHelpers;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;

/**
 * {@link NlpTask.Processor} for the fill-mask task.
 *
 * <p>Validates that every input contains exactly one occurrence of the tokenizer's mask token
 * and converts the model output at the mask position into a {@link FillMaskResults}.
 */
public class FillMaskProcessor implements NlpTask.Processor {
    /** Number of top classes returned when the supplied config is not a {@link FillMaskConfig}. */
    private static final int DEFAULT_NUM_RESULTS = 5;

    /** Results field used when none is supplied. */
    private static final String DEFAULT_RESULTS_FIELD = "predicted_value";

    private final NlpTokenizer tokenizer;

    /**
     * @param tokenizer tokenizer that supplies the mask token and builds requests
     * @param config    not used by this constructor; per-request config is passed to
     *                  {@link #getResultProcessor(NlpConfig)} instead
     */
    FillMaskProcessor(NlpTokenizer tokenizer, FillMaskConfig config) {
        this.tokenizer = tokenizer;
    }

    /**
     * Checks that the request is non-empty and that each input contains the mask token exactly once.
     *
     * <p>An empty request and inputs with no mask token are collected and reported together in a
     * single {@link ValidationException} once all inputs have been inspected. An input containing
     * more than one mask token fails immediately with a bad-request exception.
     *
     * @param inputs the input strings to validate
     * @throws ValidationException if the request is empty or any input has no mask token
     */
    @Override
    public void validateInputs(List<String> inputs) {
        ValidationException ve = new ValidationException();
        if (inputs.isEmpty()) {
            ve.addValidationError("input request is empty");
        }

        String mask = this.tokenizer.getMaskToken();
        for (String input : inputs) {
            int firstMaskIndex = input.indexOf(mask);
            if (firstMaskIndex < 0) {
                ve.addValidationError("no " + mask + " token could be found in the input");
                // Without a first occurrence there cannot be a second one.
                continue;
            }

            // Resume the search just past the first occurrence to detect a duplicate mask.
            int secondMaskIndex = input.indexOf(mask, firstMaskIndex + mask.length());
            if (secondMaskIndex > 0) {
                throw ExceptionsHelper.badRequestException("only one {} token should exist in the input", mask);
            }
        }

        if (!ve.validationErrors().isEmpty()) {
            throw ve;
        }
    }

    /**
     * Returns the tokenizer's request builder; {@code config} is not consulted.
     */
    @Override
    public NlpTask.RequestBuilder getRequestBuilder(NlpConfig config) {
        return this.tokenizer.requestBuilder();
    }

    /**
     * Builds a result processor that delegates to
     * {@link #processResult(TokenizationResult, PyTorchInferenceResult, NlpTokenizer, int, String)}.
     *
     * <p>When {@code config} is a {@link FillMaskConfig} its number of top classes and results
     * field are used; otherwise the defaults ({@code 5} and {@code "predicted_value"}) apply.
     */
    @Override
    public NlpTask.ResultProcessor getResultProcessor(NlpConfig config) {
        if (config instanceof FillMaskConfig) {
            FillMaskConfig fillMaskConfig = (FillMaskConfig) config;
            return (tokenization, result) -> processResult(
                tokenization,
                result,
                this.tokenizer,
                fillMaskConfig.getNumTopClasses(),
                fillMaskConfig.getResultsField()
            );
        }
        return (tokenization, result) -> processResult(
            tokenization,
            result,
            this.tokenizer,
            DEFAULT_NUM_RESULTS,
            DEFAULT_RESULTS_FIELD
        );
    }

    /**
     * Converts the model output for the first tokenization into a {@link FillMaskResults}.
     *
     * <p>The scores at the mask token's position are passed through a softmax and the best-scoring
     * vocabulary entries are selected. The top entry becomes the predicted value and is substituted
     * for the mask token in the original input to form the predicted sequence.
     *
     * @param tokenization  tokenization of the request; only the first tokenization is read
     * @param pyTorchResult raw inference output; row {@code [0][maskIndex]} is used
     * @param tokenizer     tokenizer supplying the mask token and its id
     * @param numResults    number of top classes to return; {@code -1} requests all of them,
     *                      {@code 0} returns an empty top-classes list (the single best prediction
     *                      is still reported), and any other value below one is treated as one
     * @param resultsField  name of the results field; {@code null} falls back to
     *                      {@code "predicted_value"}
     * @return the fill-mask result
     * @throws ElasticsearchStatusException if the tokenization is empty, the mask token id is
     *         unknown to the tokenizer, or the mask token id is absent from the token ids
     */
    static InferenceResults processResult(
        TokenizationResult tokenization,
        PyTorchInferenceResult pyTorchResult,
        NlpTokenizer tokenizer,
        int numResults,
        String resultsField
    ) {
        if (tokenization.getTokenizations().isEmpty() || tokenization.getTokenizations().get(0).getTokenIds().length == 0) {
            throw new ElasticsearchStatusException("tokenization is empty", RestStatus.INTERNAL_SERVER_ERROR);
        }
        if (tokenizer.getMaskTokenId().isEmpty()) {
            throw ExceptionsHelper.conflictStatusException(
                "The token id for the mask token {} is not known in the tokenizer. Check the vocabulary contains the mask token",
                tokenizer.getMaskToken()
            );
        }

        int maskTokenId = tokenizer.getMaskTokenId().getAsInt();
        int[] tokenIds = tokenization.getTokenizations().get(0).getTokenIds();
        int maskTokenIndex = indexOfTokenId(tokenIds, maskTokenId);
        if (maskTokenIndex == -1) {
            // Arrays.toString renders the ids themselves; List.of(int[]) would wrap the array as a
            // single element and print only its identity hash.
            throw new ElasticsearchStatusException(
                "mask token id [{}] not found in the tokenization {}",
                RestStatus.INTERNAL_SERVER_ERROR,
                maskTokenId,
                Arrays.toString(tokenIds)
            );
        }

        double[] normalizedScores = NlpHelpers.convertToProbabilitiesBySoftMax(pyTorchResult.getInferenceResult()[0][maskTokenIndex]);
        // At least one entry is always requested because the best prediction is reported even
        // when the caller asked for zero top classes.
        int k = numResults == -1 ? Integer.MAX_VALUE : Math.max(numResults, 1);
        NlpHelpers.ScoreAndIndex[] scoreAndIndices = NlpHelpers.topK(k, normalizedScores);

        List<TopClassEntry> results = new ArrayList<>(scoreAndIndices.length);
        if (numResults != 0) {
            for (NlpHelpers.ScoreAndIndex scoreAndIndex : scoreAndIndices) {
                String predictedToken = tokenization.getFromVocab(scoreAndIndex.index);
                results.add(new TopClassEntry(predictedToken, scoreAndIndex.score, scoreAndIndex.score));
            }
        }

        String predictedValue = tokenization.getFromVocab(scoreAndIndices[0].index);
        String predictedSequence = tokenization.getTokenizations().get(0).getInput().replace(tokenizer.getMaskToken(), predictedValue);
        return new FillMaskResults(
            predictedValue,
            predictedSequence,
            results,
            Optional.ofNullable(resultsField).orElse(DEFAULT_RESULTS_FIELD),
            scoreAndIndices[0].score,
            tokenization.anyTruncated()
        );
    }

    /** Returns the position of the first occurrence of {@code tokenId} in {@code tokenIds}, or -1. */
    private static int indexOfTokenId(int[] tokenIds, int tokenId) {
        for (int i = 0; i < tokenIds.length; ++i) {
            if (tokenIds[i] == tokenId) {
                return i;
            }
        }
        return -1;
    }
}

