/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.nlp;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.function.IntPredicate;
import org.apache.lucene.util.PriorityQueue;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ml.inference.results.QuestionAnsweringInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import org.elasticsearch.xpack.ml.inference.nlp.TaskType;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;

public class QuestionAnsweringProcessor
extends NlpTask.Processor {
    QuestionAnsweringProcessor(NlpTokenizer tokenizer) {
        super(tokenizer);
    }

    @Override
    public void validateInputs(List<String> inputs) {
    }

    @Override
    public NlpTask.RequestBuilder getRequestBuilder(NlpConfig nlpConfig) {
        if (nlpConfig instanceof QuestionAnsweringConfig) {
            QuestionAnsweringConfig questionAnsweringConfig = (QuestionAnsweringConfig)nlpConfig;
            return new RequestBuilder(this.tokenizer, questionAnsweringConfig.getQuestion());
        }
        throw ExceptionsHelper.badRequestException((String)"please provide configuration update for question_answering task including the desired [question]", (Object[])new Object[0]);
    }

    @Override
    public NlpTask.ResultProcessor getResultProcessor(NlpConfig nlpConfig) {
        if (nlpConfig instanceof QuestionAnsweringConfig) {
            QuestionAnsweringConfig questionAnsweringConfig = (QuestionAnsweringConfig)nlpConfig;
            int maxAnswerLength = questionAnsweringConfig.getMaxAnswerLength();
            int numTopClasses = questionAnsweringConfig.getNumTopClasses();
            String resultsFieldValue = questionAnsweringConfig.getResultsField();
            return new ResultProcessor(questionAnsweringConfig.getQuestion(), maxAnswerLength, numTopClasses, resultsFieldValue);
        }
        throw ExceptionsHelper.badRequestException((String)"please provide configuration update for question_answering task including the desired [question]", (Object[])new Object[0]);
    }

    static void topScores(double[] start, double[] end, int numAnswersToGather, Consumer<ScoreAndIndices> topScoresCollector, int seq2Start, int tokenSize, int maxAnswerLength, int spanIndex) {
        if (start.length != end.length) {
            throw new ElasticsearchStatusException("question answering result has invalid dimensions; possible start tokens [{}] must equal possible end tokens [{}]", RestStatus.INTERNAL_SERVER_ERROR, new Object[]{start.length, end.length});
        }
        double[] startNormalized = QuestionAnsweringProcessor.normalizeWith(start, i -> {
            if (i == 0) {
                return true;
            }
            return i >= seq2Start && i < tokenSize - 1;
        }, -10000.0);
        double[] endNormalized = QuestionAnsweringProcessor.normalizeWith(end, i -> {
            if (i == 0) {
                return true;
            }
            return i >= seq2Start && i < tokenSize - 1;
        }, -10000.0);
        startNormalized[0] = 0.0;
        endNormalized[0] = 0.0;
        if (numAnswersToGather == 1) {
            ScoreAndIndices toReturn = new ScoreAndIndices(0, 0, 0.0, spanIndex);
            double maxScore = 0.0;
            for (int i2 = seq2Start; i2 < tokenSize; ++i2) {
                if (startNormalized[i2] == 0.0) continue;
                for (int j = i2; j < maxAnswerLength + i2 && j < tokenSize; ++j) {
                    double score = startNormalized[i2] * endNormalized[j];
                    if (!(score > maxScore)) continue;
                    maxScore = score;
                    toReturn = new ScoreAndIndices(i2 - seq2Start, j - seq2Start, score, spanIndex);
                }
            }
            topScoresCollector.accept(toReturn);
            return;
        }
        for (int i3 = seq2Start; i3 < tokenSize; ++i3) {
            for (int j = i3; j < maxAnswerLength + i3 && j < tokenSize; ++j) {
                topScoresCollector.accept(new ScoreAndIndices(i3 - seq2Start, j - seq2Start, startNormalized[i3] * endNormalized[j], spanIndex));
            }
        }
    }

    static double[] normalizeWith(double[] values, IntPredicate mutateIndex, double predicateValue) {
        double[] toReturn = new double[values.length];
        for (int i = 0; i < values.length; ++i) {
            toReturn[i] = values[i];
            if (mutateIndex.test(i)) continue;
            toReturn[i] = predicateValue;
        }
        double expSum = 0.0;
        for (double v : toReturn) {
            expSum += Math.exp(v);
        }
        double diff = Math.log(expSum);
        for (int i = 0; i < toReturn.length; ++i) {
            toReturn[i] = Math.exp(toReturn[i] - diff);
        }
        return toReturn;
    }

    record RequestBuilder(NlpTokenizer tokenizer, String question) implements NlpTask.RequestBuilder
    {
        @Override
        public NlpTask.Request buildRequest(List<String> inputs, String requestId, Tokenization.Truncate truncate, int span, Integer windowSize) throws IOException {
            if (inputs.size() > 1) {
                throw ExceptionsHelper.badRequestException((String)"Unable to do question answering on more than one text input at a time", (Object[])new Object[0]);
            }
            String context = inputs.get(0);
            List<TokenizationResult.Tokens> tokenizations = this.tokenizer.tokenize(this.question, context, truncate, span, 0);
            TokenizationResult result = this.tokenizer.buildTokenizationResult(tokenizations);
            return result.buildRequest(requestId, truncate);
        }
    }

    record ResultProcessor(String question, int maxAnswerLength, int numTopClasses, String resultsField) implements NlpTask.ResultProcessor
    {
        @Override
        public InferenceResults processResult(TokenizationResult tokenization, PyTorchInferenceResult pyTorchResult, boolean chunkResult) {
            if (chunkResult) {
                throw NlpTask.Processor.chunkingNotSupportedException(TaskType.NER);
            }
            if (pyTorchResult.getInferenceResult().length < 1) {
                throw new ElasticsearchStatusException("question answering result has no data", RestStatus.INTERNAL_SERVER_ERROR, new Object[0]);
            }
            if (pyTorchResult.getInferenceResult().length % 2 != 0) {
                throw new ElasticsearchStatusException("question answering result has invalid dimension, number of dimensions must be a multiple of 2 found [{}]", RestStatus.INTERNAL_SERVER_ERROR, new Object[]{pyTorchResult.getInferenceResult().length});
            }
            int numAnswersToGather = Math.max(this.numTopClasses, 1);
            ScoreAndIndicesPriorityQueue finalEntries = new ScoreAndIndicesPriorityQueue(numAnswersToGather);
            List<TokenizationResult.Tokens> tokensList = tokenization.getTokensBySequenceId().get(0);
            int numberOfSpans = pyTorchResult.getInferenceResult().length / 2;
            if (numberOfSpans != tokensList.size()) {
                throw new ElasticsearchStatusException("question answering result has invalid dimensions; the number of spans [{}] does not match batched token size [{}]", RestStatus.INTERNAL_SERVER_ERROR, new Object[]{numberOfSpans, tokensList.size()});
            }
            for (int spanIndex = 0; spanIndex < numberOfSpans; ++spanIndex) {
                double[][] starts = pyTorchResult.getInferenceResult()[spanIndex * 2];
                double[][] ends = pyTorchResult.getInferenceResult()[spanIndex * 2 + 1];
                assert (starts.length == 1);
                assert (ends.length == 1);
                if (starts.length != ends.length) {
                    throw new ElasticsearchStatusException("question answering result has invalid dimensions; start positions [{}] must equal potential end [{}]", RestStatus.INTERNAL_SERVER_ERROR, new Object[]{starts.length, ends.length});
                }
                QuestionAnsweringProcessor.topScores(starts[0], ends[0], numAnswersToGather, arg_0 -> ((ScoreAndIndicesPriorityQueue)finalEntries).insertWithOverflow(arg_0), tokensList.get(spanIndex).seqPairOffset(), tokensList.get(spanIndex).tokenIds().length, this.maxAnswerLength, spanIndex);
            }
            QuestionAnsweringInferenceResults.TopAnswerEntry[] topAnswerList = new QuestionAnsweringInferenceResults.TopAnswerEntry[numAnswersToGather];
            for (int i = numAnswersToGather - 1; i >= 0; --i) {
                ScoreAndIndices scoreAndIndices = (ScoreAndIndices)finalEntries.pop();
                TokenizationResult.Tokens tokens = tokensList.get(scoreAndIndices.spanIndex());
                int startOffset = tokens.tokens().get(1).get(scoreAndIndices.startToken).startOffset();
                int endOffset = tokens.tokens().get(1).get(scoreAndIndices.endToken).endOffset();
                topAnswerList[i] = new QuestionAnsweringInferenceResults.TopAnswerEntry(tokens.input().get(1).substring(startOffset, endOffset), scoreAndIndices.score(), startOffset, endOffset);
            }
            QuestionAnsweringInferenceResults.TopAnswerEntry finalAnswer = topAnswerList[0];
            return new QuestionAnsweringInferenceResults(finalAnswer.answer(), finalAnswer.startOffset(), finalAnswer.endOffset(), this.numTopClasses > 0 ? Arrays.asList(topAnswerList) : List.of(), Optional.ofNullable(this.resultsField).orElse("predicted_value"), finalAnswer.score(), tokenization.anyTruncated());
        }
    }

    record ScoreAndIndices(int startToken, int endToken, double score, int spanIndex) implements Comparable<ScoreAndIndices>
    {
        @Override
        public int compareTo(ScoreAndIndices o) {
            return Double.compare(this.score, o.score);
        }
    }

    static class ScoreAndIndicesPriorityQueue
    extends PriorityQueue<ScoreAndIndices> {
        ScoreAndIndicesPriorityQueue(int maxSize) {
            super(maxSize);
        }

        protected boolean lessThan(ScoreAndIndices a, ScoreAndIndices b) {
            return a.compareTo(b) < 0;
        }
    }
}

