/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.nlp;

import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig;
import org.elasticsearch.xpack.ml.inference.nlp.NlpHelpers;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;

/**
 * NLP task processor for text classification: tokenizes input via the configured tokenizer and
 * turns the raw model scores into a softmax-normalized, ranked list of class labels.
 */
public class TextClassificationProcessor implements NlpTask.Processor {
    /** Results field used when no results field is configured. */
    private static final String DEFAULT_RESULTS_FIELD = "predicted_value";

    private final NlpTask.RequestBuilder requestBuilder;
    private final String[] classLabels;
    private final int numTopClasses;

    /**
     * Creates a processor from a tokenizer and the classification config it was built with.
     *
     * @param tokenizer supplies the request builder used for every inference request
     * @param config    supplies the class labels and the number of top classes to report
     *                  (a negative value means "report all labels")
     */
    TextClassificationProcessor(NlpTokenizer tokenizer, TextClassificationConfig config) {
        this.requestBuilder = tokenizer.requestBuilder();
        List<String> labels = config.getClassificationLabels();
        this.classLabels = labels.toArray(String[]::new);
        this.numTopClasses = resolveNumTopClasses(config.getNumTopClasses(), this.classLabels.length);
    }

    /** No input validation is performed by this processor. */
    @Override
    public void validateInputs(List<String> inputs) {
    }

    /** Returns the request builder captured from the tokenizer at construction; {@code config} is ignored. */
    @Override
    public NlpTask.RequestBuilder getRequestBuilder(NlpConfig config) {
        return this.requestBuilder;
    }

    /**
     * Returns a result processor. If {@code config} is a {@link TextClassificationConfig}, its labels,
     * top-class count and results field are used (read each time the processor runs); otherwise the
     * values captured at construction time and the default results field are used.
     */
    @Override
    public NlpTask.ResultProcessor getResultProcessor(NlpConfig config) {
        if (config instanceof TextClassificationConfig) {
            TextClassificationConfig textClassificationConfig = (TextClassificationConfig) config;
            // Negative top-class counts are resolved to "all labels" inside processResult.
            return (tokenization, pytorchResult) -> processResult(
                tokenization,
                pytorchResult,
                textClassificationConfig.getNumTopClasses(),
                textClassificationConfig.getClassificationLabels(),
                textClassificationConfig.getResultsField()
            );
        }
        return (tokenization, pytorchResult) -> processResult(
            tokenization,
            pytorchResult,
            this.numTopClasses,
            Arrays.asList(this.classLabels),
            DEFAULT_RESULTS_FIELD
        );
    }

    /**
     * Converts the raw PyTorch output into classification results.
     *
     * <p>Only the first entry of the first batch item ({@code [0][0]}) is used; its scores are
     * softmax-normalized and ranked highest first.
     *
     * @param tokenization  the tokenization of the request; used to report truncation
     * @param pyTorchResult the raw model output
     * @param numTopClasses how many ranked classes to return; a negative value means all labels
     * @param labels        class labels, one per model output score
     * @param resultsField  results field name; {@code null} falls back to {@code "predicted_value"}
     * @return the top label, the ranked top classes, and the top label's probability
     * @throws ElasticsearchStatusException if the result is empty or the number of scores does not
     *                                      match the number of labels
     */
    static InferenceResults processResult(
        TokenizationResult tokenization,
        PyTorchInferenceResult pyTorchResult,
        int numTopClasses,
        List<String> labels,
        String resultsField
    ) {
        var inferenceResult = pyTorchResult.getInferenceResult();
        // Guard both outer dimensions before indexing [0][0].
        if (inferenceResult.length < 1 || inferenceResult[0].length < 1) {
            throw new ElasticsearchStatusException("Text classification result has no data", RestStatus.INTERNAL_SERVER_ERROR);
        }
        var scores = inferenceResult[0][0];
        if (scores.length != labels.size()) {
            throw new ElasticsearchStatusException(
                "Expected exactly [{}] values in text classification result; got [{}]",
                RestStatus.INTERNAL_SERVER_ERROR,
                labels.size(),
                scores.length
            );
        }
        // With no labels and no scores there is no top class to report.
        if (scores.length == 0) {
            throw new ElasticsearchStatusException("Text classification result has no data", RestStatus.INTERNAL_SERVER_ERROR);
        }

        double[] normalizedScores = NlpHelpers.convertToProbabilitiesBySoftMax(scores);
        // Indices ordered by descending probability; the sort is stable, so ties keep label order.
        int[] sortedIndices = IntStream.range(0, normalizedScores.length)
            .boxed()
            .sorted(Comparator.comparingDouble((Integer i) -> normalizedScores[i]).reversed())
            .mapToInt(Integer::intValue)
            .toArray();

        int topIndex = sortedIndices[0];
        List<TopClassEntry> topClasses = Arrays.stream(sortedIndices)
            .limit(resolveNumTopClasses(numTopClasses, labels.size()))
            .mapToObj(i -> new TopClassEntry(labels.get(i), normalizedScores[i]))
            .collect(Collectors.toList());

        return new NlpClassificationInferenceResults(
            labels.get(topIndex),
            topClasses,
            Optional.ofNullable(resultsField).orElse(DEFAULT_RESULTS_FIELD),
            Double.valueOf(normalizedScores[topIndex]),
            tokenization.anyTruncated()
        );
    }

    /** Resolves a configured top-class count: a negative value means "all {@code numLabels} labels". */
    private static int resolveNumTopClasses(int configured, int numLabels) {
        return configured < 0 ? numLabels : configured;
    }
}

