/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.nlp;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.ml.inference.nlp.NlpHelpers;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import org.elasticsearch.xpack.ml.inference.nlp.TaskType;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DelimitedToken;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;

/**
 * Named-entity-recognition (NER) processor.
 *
 * <p>Turns the per-token class scores returned by the model into I-O-B tags, merges runs of tags
 * that belong to the same entity into {@link NerResults.EntityGroup}s, and renders an annotated
 * copy of the input text.
 */
public class NerProcessor extends NlpTask.Processor {

    /** Tag set used when the configuration supplies no classification labels. */
    static final IobTag[] DEFAULT_IOB_TAGS = new IobTag[] {
        IobTag.fromTag("O"),
        IobTag.fromTag("B_MISC"),
        IobTag.fromTag("I_MISC"),
        IobTag.fromTag("B_PER"),
        IobTag.fromTag("I_PER"),
        IobTag.fromTag("B_ORG"),
        IobTag.fromTag("I_ORG"),
        IobTag.fromTag("B_LOC"),
        IobTag.fromTag("I_LOC") };

    private final NlpTask.RequestBuilder requestBuilder;
    /** IOB tag for each model output column (indexed by the arg-max of a token's scores). */
    private final IobTag[] iobMap;
    private final String resultsField;
    /** Taken from the tokenization's {@code doLowerCase}; entity text is then cut from a lower-cased input. */
    private final boolean ignoreCase;

    /**
     * Creates a processor for the given tokenizer and NER configuration.
     *
     * @param tokenizer tokenizer whose request builder is used for every request
     * @param config    NER configuration (classification labels, results field, tokenization)
     * @throws ValidationException if a configured label is not an I-O-B tag or is duplicated
     */
    NerProcessor(NlpTokenizer tokenizer, NerConfig config) {
        super(tokenizer);
        List<String> labels = config.getClassificationLabels();
        validate(labels);
        this.iobMap = buildIobMap(labels);
        this.requestBuilder = tokenizer.requestBuilder();
        this.resultsField = config.getResultsField();
        this.ignoreCase = config.getTokenization().doLowerCase();
    }

    /**
     * Checks that every label parses as an I-O-B tag and that no tag occurs twice.
     * A {@code null} or empty list is accepted because {@link #buildIobMap} then falls back to
     * {@link #DEFAULT_IOB_TAGS}.
     *
     * <p>Duplicates are detected with {@link IobTag} record equality, which compares the original
     * label text as well as the entity; e.g. {@code B-PER} and {@code B_PER} are not reported.
     *
     * @throws ValidationException listing every invalid or duplicated label
     */
    private static void validate(List<String> classificationLabels) {
        if (classificationLabels == null || classificationLabels.isEmpty()) {
            return;
        }
        ValidationException ve = new ValidationException();
        HashSet<IobTag> seen = new HashSet<>();
        for (String label : classificationLabels) {
            try {
                IobTag iobTag = IobTag.fromTag(label);
                if (!seen.add(iobTag)) {
                    ve.addValidationError("the classification label [" + label + "] is duplicated in the list " + classificationLabels);
                }
            } catch (IllegalArgumentException iae) {
                ve.addValidationError("classification label [" + label + "] is not an entity I-O-B tag.");
            }
        }
        if (!ve.validationErrors().isEmpty()) {
            throw ve;
        }
    }

    /**
     * Builds the column-index to tag mapping.
     *
     * @param classificationLabels configured labels, in model output order; may be {@code null}
     * @return {@link #DEFAULT_IOB_TAGS} (the shared array itself) when no labels are configured,
     *     otherwise one parsed tag per label
     * @throws IllegalArgumentException if a label is not an I-O-B tag
     */
    static IobTag[] buildIobMap(List<String> classificationLabels) {
        if (classificationLabels == null || classificationLabels.isEmpty()) {
            return DEFAULT_IOB_TAGS;
        }
        IobTag[] map = new IobTag[classificationLabels.size()];
        for (int i = 0; i < map.length; i++) {
            map[i] = IobTag.fromTag(classificationLabels.get(i));
        }
        return map;
    }

    /** NER performs no task-specific input validation. */
    @Override
    public void validateInputs(List<String> inputs) {}

    /** Returns the request builder obtained from the tokenizer at construction; {@code config} is not consulted. */
    @Override
    public NlpTask.RequestBuilder getRequestBuilder(NlpConfig config) {
        return requestBuilder;
    }

    /**
     * Returns a result processor. If {@code config} is a {@link NerConfig}, its results field
     * overrides the one captured at construction; the tag map and case handling never change.
     */
    @Override
    public NlpTask.ResultProcessor getResultProcessor(NlpConfig config) {
        if (config instanceof NerConfig nerConfig) {
            return new NerResultProcessor(iobMap, nerConfig.getResultsField(), ignoreCase);
        }
        return new NerResultProcessor(iobMap, resultsField, ignoreCase);
    }

    /**
     * Renders {@code seq} with every entity replaced by {@code [text](CLASS&text)}, where spaces in the
     * second copy of the text are replaced by {@code +}. Entities whose start position is {@code -1}
     * are skipped.
     *
     * <p>NOTE(review): assumes entities are ordered by position and do not overlap, as produced by
     * {@code groupTaggedTokens}; an entity starting before the previous one ended makes
     * {@link StringBuilder#append(CharSequence, int, int)} throw.
     *
     * @return {@code seq} unchanged when there are no entities
     */
    static String buildAnnotatedText(String seq, List<NerResults.EntityGroup> entities) {
        if (entities.isEmpty()) {
            return seq;
        }
        StringBuilder annotated = new StringBuilder();
        int curPos = 0;
        for (NerResults.EntityGroup entity : entities) {
            if (entity.getStartPos() == -1) {
                continue;
            }
            if (entity.getStartPos() != curPos) {
                // copy the plain text between the previous entity and this one
                annotated.append(seq, curPos, entity.getStartPos());
            }
            String entitySeq = seq.substring(entity.getStartPos(), entity.getEndPos());
            annotated.append('[')
                .append(entitySeq)
                .append(']')
                .append('(')
                .append(entity.getClassName())
                .append('&')
                .append(entitySeq.replace(" ", "+"))
                .append(')');
            curPos = entity.getEndPos();
        }
        if (curPos < seq.length()) {
            annotated.append(seq, curPos, seq.length());
        }
        return annotated.toString();
    }

    /**
     * An I-O-B classification label.
     *
     * @param tag    the label exactly as configured (e.g. {@code B-PER}, {@code i_loc}, {@code O})
     * @param entity the upper-cased label with its two-character B/I prefix removed, or {@code O}
     */
    record IobTag(String tag, String entity) {

        /**
         * Parses a label. Labels starting (case-insensitively) with {@code B-}, {@code I-},
         * {@code B_} or {@code I_}, or equal to {@code O}/{@code o}, are accepted.
         *
         * @throws IllegalArgumentException for any other label
         */
        static IobTag fromTag(String tag) {
            String upper = tag.toUpperCase(Locale.ROOT);
            if (upper.startsWith("B-") || upper.startsWith("I-") || upper.startsWith("B_") || upper.startsWith("I_")) {
                return new IobTag(tag, upper.substring(2));
            }
            if (upper.equals("O")) {
                return new IobTag(tag, upper);
            }
            throw new IllegalArgumentException("classification label [" + tag + "] is not an entity I-O-B tag.");
        }

        /** True for B-/B_ tags, which always start a new entity. */
        boolean isBeginning() {
            return tag.startsWith("b") || tag.startsWith("B");
        }

        /** True for the "outside" tag {@code O} (either case). */
        boolean isNone() {
            return tag.equals("o") || tag.equals("O");
        }

        @Override
        public String toString() {
            return tag;
        }
    }

    /**
     * Converts a model response into {@link NerResults}.
     *
     * @param iobMap       IOB tag for each model output column
     * @param resultsField field the results are written to; {@code null} becomes {@code predicted_value}
     * @param ignoreCase   whether entity text is cut from a lower-cased copy of the input
     */
    record NerResultProcessor(IobTag[] iobMap, String resultsField, boolean ignoreCase) implements NlpTask.ResultProcessor {

        NerResultProcessor {
            resultsField = Optional.ofNullable(resultsField).orElse("predicted_value");
        }

        /**
         * Tags, groups and annotates the first tokenized sequence (index 0) of the request.
         *
         * @throws ElasticsearchStatusException if the tokenization is empty
         * @throws RuntimeException (from {@code chunkingNotSupportedException}) when {@code chunkResult} is set
         */
        @Override
        public InferenceResults processResult(TokenizationResult tokenization, PyTorchInferenceResult pyTorchResult, boolean chunkResult) {
            if (tokenization.isEmpty()) {
                throw new ElasticsearchStatusException("no valid tokenization to build result", RestStatus.INTERNAL_SERVER_ERROR);
            }
            if (chunkResult) {
                throw NlpTask.Processor.chunkingNotSupportedException(TaskType.NER);
            }
            TokenizationResult.Tokens tokens = tokenization.getTokenization(0);
            String input = tokens.input().get(0);
            double[][] normalizedScores = NlpHelpers.convertToProbabilitiesBySoftMax(pyTorchResult.getInferenceResult()[0]);
            List<TaggedToken> taggedTokens = tagTokens(tokens, normalizedScores, iobMap);
            // entity text comes from the (possibly lower-cased) input; the annotation uses the original
            String entitySource = ignoreCase ? input.toLowerCase(Locale.ROOT) : input;
            List<NerResults.EntityGroup> entities = groupTaggedTokens(taggedTokens, entitySource);
            return new NerResults(resultsField, NerProcessor.buildAnnotatedText(input, entities), entities, tokenization.anyTruncated());
        }

        /**
         * Assigns one tag per input position. Consecutive model tokens whose {@code tokenMap} entry is
         * the same are treated as one block. Their score rows are averaged, and the arg-max column
         * selects the tag from {@code iobMap}. Tokens with a negative {@code tokenMap} entry are
         * skipped; the code treats them as special tokens.
         *
         * <p>NOTE(review): assumes {@code scores} has one row per token id. Rows are truncated (or
         * zero-padded) to {@code iobMap.length} columns.
         *
         * @return one {@link TaggedToken} per block, carrying the block's first delimited token
         */
        static List<TaggedToken> tagTokens(TokenizationResult.Tokens tokenization, double[][] scores, IobTag[] iobMap) {
            List<TaggedToken> taggedTokens = new ArrayList<>();
            int[] tokenMap = tokenization.tokenMap();
            int numTokens = tokenization.tokenIds().length;
            int startTokenIndex = 0;
            int numSpecialTokens = 0;
            while (startTokenIndex < numTokens) {
                int inputMapping = tokenMap[startTokenIndex];
                if (inputMapping < 0) {
                    startTokenIndex++;
                    numSpecialTokens++;
                    continue;
                }
                int endTokenIndex = startTokenIndex;
                while (endTokenIndex < tokenMap.length - 1 && tokenMap[endTokenIndex + 1] == inputMapping) {
                    endTokenIndex++;
                }
                double[] avgScores = Arrays.copyOf(scores[startTokenIndex], iobMap.length);
                for (int i = startTokenIndex + 1; i <= endTokenIndex; i++) {
                    // bound by avgScores.length so wide score rows are truncated exactly like the copyOf above
                    int columns = Math.min(avgScores.length, scores[i].length);
                    for (int j = 0; j < columns; j++) {
                        avgScores[j] += scores[i][j];
                    }
                }
                int numTokensInBlock = endTokenIndex - startTokenIndex + 1;
                if (numTokensInBlock > 1) {
                    for (int j = 0; j < avgScores.length; j++) {
                        avgScores[j] /= numTokensInBlock;
                    }
                }
                int maxScoreIndex = NlpHelpers.argmax(avgScores);
                taggedTokens.add(
                    new TaggedToken(
                        tokenization.tokens().get(0).get(startTokenIndex - numSpecialTokens),
                        iobMap[maxScoreIndex],
                        avgScores[maxScoreIndex]
                    )
                );
                startTokenIndex = endTokenIndex + 1;
            }
            return taggedTokens;
        }

        /**
         * Groups tagged tokens into entities. A group starts at any token that is not {@code O}. It
         * continues while the following tokens have the same entity and are not B- tags. The group
         * score is the mean of its token scores. Its text is {@code inputSeq} between the first
         * token's start offset and the last token's end offset.
         *
         * @return an empty list when {@code tokens} is empty
         */
        static List<NerResults.EntityGroup> groupTaggedTokens(List<TaggedToken> tokens, String inputSeq) {
            if (tokens.isEmpty()) {
                return Collections.emptyList();
            }
            List<NerResults.EntityGroup> entities = new ArrayList<>();
            int startTokenIndex = 0;
            while (startTokenIndex < tokens.size()) {
                TaggedToken token = tokens.get(startTokenIndex);
                if (token.tag().isNone()) {
                    startTokenIndex++;
                    continue;
                }
                double scoreSum = token.score();
                int endTokenIndex = startTokenIndex + 1;
                for (; endTokenIndex < tokens.size(); endTokenIndex++) {
                    TaggedToken endToken = tokens.get(endTokenIndex);
                    if (endToken.tag().isBeginning() || !endToken.tag().entity().equals(token.tag().entity())) {
                        break;
                    }
                    scoreSum += endToken.score();
                }
                int startPos = token.token().startOffset();
                int endPos = tokens.get(endTokenIndex - 1).token().endOffset();
                String entity = inputSeq.substring(startPos, endPos);
                double meanScore = scoreSum / (endTokenIndex - startTokenIndex);
                entities.add(new NerResults.EntityGroup(entity, token.tag().entity(), meanScore, startPos, endPos));
                startTokenIndex = endTokenIndex;
            }
            return entities;
        }

        /** A delimited input token with its predicted tag and (averaged) score. */
        record TaggedToken(DelimitedToken token, IobTag tag, double score) {
            @Override
            public String toString() {
                return "{token:" + token + ", " + tag + ", " + score + "}";
            }
        }
    }
}

