/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.nlp;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.logging.LoggerMessageFormat;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.nlp.NlpHelpers;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;

public class ZeroShotClassificationProcessor
extends NlpTask.Processor {
    private final int entailmentPos;
    private final int contraPos;
    private final String[] labels;
    private final String hypothesisTemplate;
    private final boolean isMultiLabel;
    private final String resultsField;

    ZeroShotClassificationProcessor(NlpTokenizer tokenizer, ZeroShotClassificationConfig config) {
        super(tokenizer);
        List<String> lowerCased = config.getClassificationLabels().stream().map(s -> s.toLowerCase(Locale.ROOT)).toList();
        this.entailmentPos = lowerCased.indexOf("entailment");
        this.contraPos = lowerCased.indexOf("contradiction");
        if (this.entailmentPos == -1 || this.contraPos == -1) {
            throw ExceptionsHelper.badRequestException((String)"zero_shot_classification requires [entailment] and [contradiction] in classification_labels", (Object[])new Object[0]);
        }
        this.labels = (String[])config.getLabels().orElse(List.of()).toArray(String[]::new);
        this.hypothesisTemplate = config.getHypothesisTemplate();
        this.isMultiLabel = config.isMultiLabel();
        this.resultsField = config.getResultsField();
    }

    @Override
    public void validateInputs(List<String> inputs) {
    }

    @Override
    public NlpTask.RequestBuilder getRequestBuilder(NlpConfig nlpConfig) {
        String[] labelsValue;
        if (nlpConfig instanceof ZeroShotClassificationConfig) {
            ZeroShotClassificationConfig zeroShotConfig = (ZeroShotClassificationConfig)nlpConfig;
            labelsValue = zeroShotConfig.getLabels().orElse(List.of()).toArray(new String[0]);
        } else {
            labelsValue = this.labels;
        }
        if (labelsValue == null || labelsValue.length == 0) {
            throw ExceptionsHelper.badRequestException((String)"zero_shot_classification requires non-empty [labels]", (Object[])new Object[0]);
        }
        return new RequestBuilder(this.tokenizer, labelsValue, this.hypothesisTemplate);
    }

    @Override
    public NlpTask.ResultProcessor getResultProcessor(NlpConfig nlpConfig) {
        String resultsFieldValue;
        boolean isMultiLabelValue;
        String[] labelsValue;
        if (nlpConfig instanceof ZeroShotClassificationConfig) {
            ZeroShotClassificationConfig zeroShotConfig = (ZeroShotClassificationConfig)nlpConfig;
            labelsValue = zeroShotConfig.getLabels().orElse(List.of()).toArray(new String[0]);
            isMultiLabelValue = zeroShotConfig.isMultiLabel();
            resultsFieldValue = zeroShotConfig.getResultsField();
        } else {
            labelsValue = this.labels;
            isMultiLabelValue = this.isMultiLabel;
            resultsFieldValue = this.resultsField;
        }
        return new ResultProcessor(this.entailmentPos, this.contraPos, labelsValue, isMultiLabelValue, resultsFieldValue);
    }

    record RequestBuilder(NlpTokenizer tokenizer, String[] labels, String hypothesisTemplate) implements NlpTask.RequestBuilder
    {
        @Override
        public NlpTask.Request buildRequest(List<String> inputs, String requestId, Tokenization.Truncate truncate, int span) throws IOException {
            if (inputs.size() > 1) {
                throw ExceptionsHelper.badRequestException((String)"Unable to do zero-shot classification on more than one text input at a time", (Object[])new Object[0]);
            }
            if (span > -1) {
                throw ExceptionsHelper.badRequestException((String)"Unable to span zero-shot classification on long text input", (Object[])new Object[0]);
            }
            ArrayList<TokenizationResult.Tokens> tokenizations = new ArrayList<TokenizationResult.Tokens>(this.labels.length);
            int seqId = 0;
            NlpTokenizer.InnerTokenization firstSequenceTokenization = this.tokenizer.innerTokenize(inputs.get(0));
            for (String label : this.labels) {
                tokenizations.add(this.tokenizer.tokenize(inputs.get(0), firstSequenceTokenization, LoggerMessageFormat.format(null, (String)this.hypothesisTemplate, (Object[])new Object[]{label}), truncate, seqId++));
            }
            TokenizationResult result = this.tokenizer.buildTokenizationResult(tokenizations);
            return result.buildRequest(requestId, truncate);
        }
    }

    record ResultProcessor(int entailmentPos, int contraPos, String[] labels, boolean isMultiLabel, String resultsField) implements NlpTask.ResultProcessor
    {
        @Override
        public InferenceResults processResult(TokenizationResult tokenization, PyTorchInferenceResult pyTorchResult) {
            double[] normalizedScores;
            if (pyTorchResult.getInferenceResult().length < 1) {
                throw new ElasticsearchStatusException("Zero shot classification result has no data", RestStatus.INTERNAL_SERVER_ERROR, new Object[0]);
            }
            if (pyTorchResult.getInferenceResult()[0].length != this.labels.length) {
                throw new ElasticsearchStatusException("Expected exactly [{}] values in zero shot classification result; got [{}]", RestStatus.INTERNAL_SERVER_ERROR, new Object[]{this.labels.length, pyTorchResult.getInferenceResult().length});
            }
            if (this.isMultiLabel) {
                normalizedScores = new double[pyTorchResult.getInferenceResult()[0].length];
                int v = 0;
                for (double[] vals : pyTorchResult.getInferenceResult()[0]) {
                    if (vals.length != 3) {
                        throw new ElasticsearchStatusException("Expected exactly [{}] values in inner zero shot classification result; got [{}]", RestStatus.INTERNAL_SERVER_ERROR, new Object[]{3, vals.length});
                    }
                    normalizedScores[v++] = NlpHelpers.convertToProbabilitiesBySoftMax(new double[]{vals[this.entailmentPos], vals[this.contraPos]})[0];
                }
            } else {
                double[] entailmentScores = new double[pyTorchResult.getInferenceResult()[0].length];
                int v = 0;
                for (double[] vals : pyTorchResult.getInferenceResult()[0]) {
                    if (vals.length != 3) {
                        throw new ElasticsearchStatusException("Expected exactly [{}] values in inner zero shot classification result; got [{}]", RestStatus.INTERNAL_SERVER_ERROR, new Object[]{3, vals.length});
                    }
                    entailmentScores[v++] = vals[this.entailmentPos];
                }
                normalizedScores = NlpHelpers.convertToProbabilitiesBySoftMax(entailmentScores);
            }
            int[] sortedIndices = IntStream.range(0, normalizedScores.length).boxed().sorted(Comparator.comparing(i -> normalizedScores[(Integer)i]).reversed()).mapToInt(i -> i).toArray();
            return new NlpClassificationInferenceResults(this.labels[sortedIndices[0]], Arrays.stream(sortedIndices).mapToObj(i -> new TopClassEntry((Object)this.labels[i], normalizedScores[i])).collect(Collectors.toList()), Optional.ofNullable(this.resultsField).orElse("predicted_value"), Double.valueOf(normalizedScores[sortedIndices[0]]), tokenization.anyTruncated());
        }
    }
}

