/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.nlp;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.EnumSet;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
import org.elasticsearch.xpack.ml.inference.nlp.NlpHelpers;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;

public class NerProcessor
implements NlpTask.Processor {
    /** Request builder obtained from the tokenizer when this processor is created. */
    private final NlpTask.RequestBuilder requestBuilder;
    /** Tag for each classification index, as produced by {@link #buildIobMap(List)}. */
    private final IobTag[] iobMap;
    /** Results field taken from the configuration; may be {@code null}. */
    private final String resultsField;
    /** Mirrors the tokenization's "do lower case" setting. */
    private final boolean ignoreCase;

    /**
     * Creates a processor from the tokenizer and NER configuration.
     *
     * @param tokenizer supplies the request builder
     * @param config    supplies the classification labels, results field and tokenization settings
     * @throws ValidationException if the configured classification labels are rejected by
     *                             {@link #validate(List)}
     */
    NerProcessor(NlpTokenizer tokenizer, NerConfig config) {
        List<String> labels = config.getClassificationLabels();
        // Validate first so that buildIobMap never sees an unknown label.
        validate(labels);
        iobMap = buildIobMap(labels);
        requestBuilder = tokenizer.requestBuilder();
        resultsField = config.getResultsField();
        ignoreCase = config.getTokenization().doLowerCase();
    }

    /**
     * Checks that every classification label names an {@link IobTag} constant and that no
     * tag is listed twice. A {@code null} or empty list is accepted without checks.
     *
     * <p>All problems are collected into a single exception; when at least one problem was
     * found, a final message listing the valid tags is appended before throwing.
     *
     * @param classificationLabels the configured labels, possibly {@code null}
     * @throws ValidationException if any label is unknown or duplicated
     */
    private void validate(List<String> classificationLabels) {
        boolean nothingToCheck = classificationLabels == null || classificationLabels.isEmpty();
        if (nothingToCheck) {
            return;
        }
        ValidationException errors = new ValidationException();
        EnumSet<IobTag> seen = EnumSet.noneOf(IobTag.class);
        for (String label : classificationLabels) {
            IobTag parsed;
            try {
                parsed = IobTag.valueOf(label);
            } catch (IllegalArgumentException notATag) {
                errors.addValidationError("classification label [" + label + "] is not an entity I-O-B tag.");
                continue;
            }
            // EnumSet.add returns false when the tag was already present.
            if (!seen.add(parsed)) {
                errors.addValidationError("the classification label [" + label + "] is duplicated in the list " + classificationLabels);
            }
        }
        if (errors.validationErrors().isEmpty()) {
            return;
        }
        errors.addValidationError("Valid entity I-O-B tags are " + Arrays.toString(IobTag.values()));
        throw errors;
    }

    /**
     * Builds the array that maps a classification index to its {@link IobTag}.
     *
     * @param classificationLabels tag names in classification-index order; when {@code null}
     *                             or empty the declaration order of {@link IobTag} is used
     * @return one tag per label, in the same order as the labels
     * @throws IllegalArgumentException if a label is not the name of an {@link IobTag} constant
     */
    static IobTag[] buildIobMap(List<String> classificationLabels) {
        boolean useDefaultOrder = classificationLabels == null || classificationLabels.isEmpty();
        return useDefaultOrder
            ? IobTag.values()
            : classificationLabels.stream().map(IobTag::valueOf).toArray(IobTag[]::new);
    }

    /**
     * Performs no checks: any list of inputs is accepted by this processor.
     *
     * @param inputs the input texts (unused)
     */
    @Override
    public void validateInputs(List<String> inputs) {
        // Intentionally empty.
    }

    /**
     * Returns the request builder captured from the tokenizer at construction time.
     *
     * @param config ignored; the same builder is returned for every configuration
     * @return the request builder of this processor
     */
    @Override
    public NlpTask.RequestBuilder getRequestBuilder(NlpConfig config) {
        return requestBuilder;
    }

    /**
     * Creates a result processor that shares this processor's tag map and case setting.
     *
     * @param config when this is a {@link NerConfig} its results field is used; otherwise
     *               the results field this processor was constructed with is used
     * @return a new {@link NerResultProcessor}
     */
    @Override
    public NlpTask.ResultProcessor getResultProcessor(NlpConfig config) {
        String field = config instanceof NerConfig
            ? ((NerConfig) config).getResultsField()
            : resultsField;
        return new NerResultProcessor(iobMap, field, ignoreCase);
    }

    /**
     * Rewrites {@code seq} so that every located entity is wrapped as
     * {@code [text](CLASS&text)}, with spaces in the second copy of the text replaced by
     * {@code +}. Text between entities is copied through unchanged.
     *
     * <p>Entities whose start position is {@code -1} are skipped.
     *
     * @param seq      the text the entity positions refer to
     * @param entities the entity groups, visited in list order
     * @return the annotated text, or {@code seq} itself when there are no entities
     */
    static String buildAnnotatedText(String seq, List<NerResults.EntityGroup> entities) {
        if (entities.isEmpty()) {
            return seq;
        }
        StringBuilder out = new StringBuilder();
        int cursor = 0;
        for (NerResults.EntityGroup group : entities) {
            int start = group.getStartPos();
            if (start == -1) {
                // Entity text was not located in the sequence; nothing to annotate.
                continue;
            }
            if (start != cursor) {
                // Copy the plain text between the previous entity and this one.
                out.append(seq, cursor, start);
            }
            int end = group.getEndPos();
            String matched = seq.substring(start, end);
            out.append("[").append(matched).append("]");
            out.append("(").append(group.getClassName()).append("&").append(matched.replace(" ", "+")).append(")");
            cursor = end;
        }
        int length = seq.length();
        if (cursor < length) {
            // Trailing text after the last annotated entity.
            out.append(seq, cursor, length);
        }
        return out.toString();
    }

    /**
     * Entity I-O-B tags. Each tag carries the {@link Entity} it belongs to; {@code O}
     * carries {@link Entity#NONE}.
     *
     * <p>The declaration order matters: {@link #buildIobMap(List)} falls back to
     * {@link #values()} when no classification labels are configured.
     */
    enum IobTag {
        O(Entity.NONE),
        B_MISC(Entity.MISC),
        I_MISC(Entity.MISC),
        B_PER(Entity.PER),
        I_PER(Entity.PER),
        B_ORG(Entity.ORG),
        I_ORG(Entity.ORG),
        B_LOC(Entity.LOC),
        I_LOC(Entity.LOC);

        private final Entity entity;

        IobTag(Entity entity) {
            this.entity = entity;
        }

        /** @return the entity class this tag belongs to */
        Entity getEntity() {
            return entity;
        }

        /** @return {@code true} when the constant's name starts with the letter B (any case) */
        boolean isBeginning() {
            String lowered = name().toLowerCase(Locale.ROOT);
            return lowered.startsWith("b");
        }
    }

    static class NerResultProcessor
    implements NlpTask.ResultProcessor {
        /** Tag for each classification index. */
        private final IobTag[] iobMap;
        /** Name of the results field; never {@code null}. */
        private final String resultsField;
        /** When set, entities are located in a lower-cased copy of the input. */
        private final boolean ignoreCase;

        /**
         * @param iobMap       tag for each classification index
         * @param resultsField results field name; {@code null} selects {@code "predicted_value"}
         * @param ignoreCase   whether entity lookup runs against the lower-cased input
         */
        NerResultProcessor(IobTag[] iobMap, String resultsField, boolean ignoreCase) {
            this.iobMap = iobMap;
            this.resultsField = resultsField != null ? resultsField : "predicted_value";
            this.ignoreCase = ignoreCase;
        }

        /**
         * Converts the raw model output for the first tokenization into {@link NerResults}.
         *
         * <p>Only the first tokenization and the first element of the inference result are
         * used. Scores are normalised with soft-max, tokens are tagged, tagged tokens are
         * grouped into entities, and the original (not lower-cased) input is annotated.
         *
         * @param tokenization  the tokenization of the request
         * @param pyTorchResult the model output
         * @return the NER results, including whether any input was truncated
         * @throws ElasticsearchStatusException if there is no tokenization or it has no tokens
         */
        @Override
        public InferenceResults processResult(TokenizationResult tokenization, PyTorchResult pyTorchResult) {
            if (tokenization.getTokenizations().isEmpty() || tokenization.getTokenizations().get(0).getTokens().length == 0) {
                throw new ElasticsearchStatusException("no valid tokenization to build result", RestStatus.INTERNAL_SERVER_ERROR);
            }
            TokenizationResult.Tokenization first = tokenization.getTokenizations().get(0);
            String input = first.getInput();

            double[][] probabilities = NlpHelpers.convertToProbabilitiesBySoftMax(pyTorchResult.getInferenceResult()[0]);
            List<TaggedToken> tagged = tagTokens(first, probabilities, iobMap);

            // Entity positions are searched in a lower-cased copy when case is ignored.
            String searchText = ignoreCase ? input.toLowerCase(Locale.ROOT) : input;
            List<NerResults.EntityGroup> groups = groupTaggedTokens(tagged, searchText);

            String annotated = NerProcessor.buildAnnotatedText(input, groups);
            return new NerResults(resultsField, annotated, groups, tokenization.anyTruncated());
        }

        /**
         * Merges word-piece tokens back into words and assigns each word the tag with the
         * highest average score.
         *
         * <p>Tokens whose token-map entry is negative are skipped (presumably special
         * tokens; confirm against the tokenizer). Consecutive tokens that share the same
         * token-map entry form one word. The word's score for each class is the mean of
         * the scores of its tokens.
         *
         * <p>Only the first {@code iobMap.length} classes of every score row are considered,
         * for multi-token and single-token words alike, so a model that emits more classes
         * than there are labels no longer fails on multi-token words only.
         *
         * @param tokenization the tokens and their mapping to input words
         * @param scores       per-token class scores, indexed {@code [token][class]}
         * @param iobMap       tag for each class index
         * @return one tagged token per word, in input order
         */
        static List<TaggedToken> tagTokens(TokenizationResult.Tokenization tokenization, double[][] scores, IobTag[] iobMap) {
            List<TaggedToken> taggedTokens = new ArrayList<>();
            String[] tokens = tokenization.getTokens();
            int[] tokenMap = tokenization.getTokenMap();
            int startTokenIndex = 0;
            while (startTokenIndex < tokens.length) {
                int inputMapping = tokenMap[startTokenIndex];
                if (inputMapping < 0) {
                    ++startTokenIndex;
                    continue;
                }
                // Extend the block over every following token mapped to the same word.
                int endTokenIndex = startTokenIndex;
                StringBuilder word = new StringBuilder(tokens[startTokenIndex]);
                while (endTokenIndex < tokens.length - 1 && tokenMap[endTokenIndex + 1] == inputMapping) {
                    String piece = tokens[++endTokenIndex];
                    // Strip the "##" continuation marker only when it is really there, so a
                    // piece without it is neither corrupted nor causes an index error.
                    word.append(piece.startsWith("##") ? piece.substring(2) : piece);
                }
                // copyOf truncates or zero-pads the first row to the number of known tags.
                double[] avgScores = Arrays.copyOf(scores[startTokenIndex], iobMap.length);
                for (int i = startTokenIndex + 1; i <= endTokenIndex; ++i) {
                    // Bound by both arrays so that extra model outputs are ignored, exactly
                    // as copyOf ignored them for the first row.
                    int width = Math.min(avgScores.length, scores[i].length);
                    for (int j = 0; j < width; ++j) {
                        avgScores[j] += scores[i][j];
                    }
                }
                int numTokensInBlock = endTokenIndex - startTokenIndex + 1;
                if (numTokensInBlock > 1) {
                    for (int i = 0; i < avgScores.length; ++i) {
                        avgScores[i] /= numTokensInBlock;
                    }
                }
                int maxScoreIndex = NlpHelpers.argmax(avgScores);
                double score = avgScores[maxScoreIndex];
                taggedTokens.add(new TaggedToken(word.toString(), iobMap[maxScoreIndex], score));
                startTokenIndex = endTokenIndex + 1;
            }
            return taggedTokens;
        }

        /**
         * Groups consecutive tagged tokens of the same entity into entity groups.
         *
         * <p>A group starts at any token whose entity is not {@link Entity#NONE} and extends
         * over the following tokens until one is a "beginning" tag or belongs to a different
         * entity. The group's text is its words joined by single spaces, its score is the
         * mean token score, and its position is the first occurrence of that text in
         * {@code inputSeq} at or after the end of the previously located group
         * ({@code -1} for both start and end when it is not found).
         *
         * @param tokens   tagged tokens in input order
         * @param inputSeq the text in which entity positions are searched
         * @return the entity groups in order of appearance; an empty list for no tokens
         */
        static List<NerResults.EntityGroup> groupTaggedTokens(List<TaggedToken> tokens, String inputSeq) {
            if (tokens.isEmpty()) {
                return Collections.emptyList();
            }
            List<NerResults.EntityGroup> groups = new ArrayList<>();
            int searchFrom = 0;
            int first = 0;
            while (first < tokens.size()) {
                TaggedToken head = tokens.get(first);
                if (head.tag.getEntity() == Entity.NONE) {
                    first++;
                    continue;
                }
                StringBuilder text = new StringBuilder(head.word);
                double total = head.score;
                int next = first + 1;
                while (next < tokens.size()) {
                    TaggedToken candidate = tokens.get(next);
                    boolean sameEntity = candidate.tag.getEntity() == head.tag.getEntity();
                    if (candidate.tag.isBeginning() || !sameEntity) {
                        break;
                    }
                    text.append(" ").append(candidate.word);
                    total += candidate.score;
                    next++;
                }
                String entity = text.toString();
                int startPos = inputSeq.indexOf(entity, searchFrom);
                int endPos = startPos == -1 ? -1 : startPos + entity.length();
                double meanScore = total / (double) (next - first);
                groups.add(new NerResults.EntityGroup(entity, head.tag.getEntity().toString(), meanScore, startPos, endPos));
                if (startPos != -1) {
                    // Later entities are searched for after this one.
                    searchFrom = endPos;
                }
                first = next;
            }
            return groups;
        }

            /** Immutable holder for a word, the tag assigned to it and the score of that tag. */
            static class TaggedToken {
                // The word text.
                private final String word;
                // The tag chosen for the word.
                private final IobTag tag;
                // The score associated with the chosen tag.
                private final double score;

                /**
                 * @param word  the word text
                 * @param tag   the tag assigned to the word
                 * @param score the score of the assigned tag
                 */
                TaggedToken(String word, IobTag tag, double score) {
                    this.word = word;
                    this.tag = tag;
                    this.score = score;
                }
            }
    }

    public static enum Entity implements Writeable
    {
        NONE,
        MISC,
        PER,
        ORG,
        LOC;


        public void writeTo(StreamOutput out) throws IOException {
            out.writeEnum((Enum)this);
        }

        public String toString() {
            return this.name().toUpperCase(Locale.ROOT);
        }
    }
}

