/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;

import java.io.IOException;
import java.io.Reader;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.OptionalInt;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.Tokenizer;
import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.XLMRobertaTokenization;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DelimitedToken;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.PrecompiledCharMapNormalizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.UnigramTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.XLMRobertaTokenizationResult;

/**
 * Tokenizer for XLM-RoBERTa models.
 *
 * <p>Input text is passed through a precompiled character-map normalizer (when the loaded
 * normalizer config has offsets) and then split by a {@link UnigramTokenizer} built from the
 * vocabulary and its per-entry scores. When special tokens are enabled, the ids of
 * {@code <s>} and {@code </s>} are resolved from the vocabulary and handed to the tokens builder.
 */
public class XLMRobertaTokenizer extends NlpTokenizer {
    public static final String UNKNOWN_TOKEN = "<unk>";
    public static final String SEPARATOR_TOKEN = "</s>";
    public static final String PAD_TOKEN = "<pad>";
    public static final String CLASS_TOKEN = "<s>";
    public static final String MASK_TOKEN = "<mask>";

    /** Tokens always kept whole by the unigram tokenizer, in addition to caller-supplied ones. */
    private static final Set<String> NEVER_SPLIT = Set.of(MASK_TOKEN);

    private final XLMAnalyzer xlmAnalyzer;
    protected final List<String> originalVocab;
    private final SortedMap<String, Integer> vocab;
    protected final boolean withSpecialTokens;
    protected final int sepTokenId;
    private final int clsTokenId;
    protected final int padTokenId;
    private final int maxSequenceLength;

    /**
     * Creates a tokenizer. Callers normally obtain one through {@link #builder}.
     *
     * @param originalVocab     vocabulary entries; {@link Builder} maps each entry to its list index
     * @param vocab             token to id lookup, checked for the required special tokens
     * @param scores            per-entry scores handed to the unigram tokenizer
     * @param withSpecialTokens whether {@code </s>} and {@code <s>} must be present and resolved
     * @param maxSequenceLength value reported by {@link #maxSequenceLength()}
     * @param neverSplit        extra tokens that must not be split; merged with {@code <mask>}
     * @throws IOException propagated from loading the precompiled normalizer resource
     *         (a conflict status exception from {@link ExceptionsHelper#conflictStatusException}
     *         is thrown if a required token is missing from {@code vocab})
     */
    protected XLMRobertaTokenizer(
        List<String> originalVocab,
        SortedMap<String, Integer> vocab,
        List<Double> scores,
        boolean withSpecialTokens,
        int maxSequenceLength,
        Set<String> neverSplit
    ) throws IOException {
        // Validate the vocabulary before building the analyzer so an unusable vocabulary fails
        // fast without reading the normalizer resource or allocating an analyzer that is never closed.
        requireToken(vocab, UNKNOWN_TOKEN);
        requireToken(vocab, PAD_TOKEN);
        if (withSpecialTokens) {
            Set<String> missingSpecialTokens = Sets.difference(Set.of(SEPARATOR_TOKEN, CLASS_TOKEN), vocab.keySet());
            if (!missingSpecialTokens.isEmpty()) {
                throw ExceptionsHelper.conflictStatusException(
                    "stored vocabulary is missing required {} token(s)",
                    missingSpecialTokens
                );
            }
            this.sepTokenId = vocab.get(SEPARATOR_TOKEN);
            this.clsTokenId = vocab.get(CLASS_TOKEN);
        } else {
            // Sentinel ids: special tokens are not resolved when withSpecialTokens is false.
            this.sepTokenId = -1;
            this.clsTokenId = -1;
        }
        this.originalVocab = originalVocab;
        this.vocab = vocab;
        this.withSpecialTokens = withSpecialTokens;
        this.maxSequenceLength = maxSequenceLength;
        this.padTokenId = vocab.get(PAD_TOKEN);
        this.xlmAnalyzer = new XLMAnalyzer(
            originalVocab,
            scores,
            new ArrayList<>(Sets.union(NEVER_SPLIT, neverSplit)),
            UNKNOWN_TOKEN
        );
    }

    /** Throws a conflict status exception if {@code token} is not a key of {@code vocab}. */
    private static void requireToken(SortedMap<String, Integer> vocab, String token) {
        if (!vocab.containsKey(token)) {
            throw ExceptionsHelper.conflictStatusException("stored vocabulary is missing required [{}] token", token);
        }
    }

    @Override
    int sepTokenId() {
        return sepTokenId;
    }

    @Override
    int maxSequenceLength() {
        return maxSequenceLength;
    }

    @Override
    boolean isWithSpecialTokens() {
        return withSpecialTokens;
    }

    /** Number of extra tokens reserved when two sequences are tokenized together. */
    @Override
    int getNumExtraTokensForSeqPair() {
        return 4;
    }

    @Override
    int clsTokenId() {
        return clsTokenId;
    }

    @Override
    public String getPadToken() {
        return PAD_TOKEN;
    }

    /** @return the literal unknown token, {@code <unk>} */
    public String getUnknownToken() {
        return UNKNOWN_TOKEN;
    }

    /** Closes the underlying analyzer. */
    public void close() {
        xlmAnalyzer.close();
    }

    @Override
    public TokenizationResult buildTokenizationResult(List<TokenizationResult.Tokens> tokenizations) {
        return new XLMRobertaTokenizationResult(originalVocab, tokenizations, padTokenId);
    }

    /**
     * Returns a request builder that tokenizes every input, using each input's index as its
     * sequence id, and builds a request from the combined tokens.
     */
    @Override
    public NlpTask.RequestBuilder requestBuilder() {
        return (inputs, requestId, truncate, span) -> buildTokenizationResult(
            IntStream.range(0, inputs.size())
                .mapToObj(seqId -> tokenize(inputs.get(seqId), truncate, span, seqId))
                .flatMap(List::stream)
                .collect(Collectors.toList())
        ).buildRequest(requestId, truncate);
    }

    @Override
    public OptionalInt getPadTokenId() {
        return OptionalInt.of(padTokenId);
    }

    /** @return the id of {@code <mask>}, or empty if the vocabulary does not contain it */
    @Override
    public OptionalInt getMaskTokenId() {
        Integer maskId = vocab.get(MASK_TOKEN);
        return maskId == null ? OptionalInt.empty() : OptionalInt.of(maskId);
    }

    @Override
    public String getMaskToken() {
        return MASK_TOKEN;
    }

    /** @return the vocabulary list this tokenizer was created with (not copied) */
    @Override
    public List<String> getVocabulary() {
        return originalVocab;
    }

    @Override
    TokenizationResult.TokensBuilder createTokensBuilder(int clsTokenId, int sepTokenId, boolean withSpecialTokens) {
        return new XLMRobertaTokenizationResult.XLMRobertaTokensBuilder(withSpecialTokens, clsTokenId, sepTokenId);
    }

    /**
     * Runs {@code seq} through the analyzer and collects the produced tokens together with each
     * token's position, accumulated from the position-increment attribute starting at -1.
     *
     * @throws UncheckedIOException wrapping any {@link IOException} raised by the token stream
     */
    @Override
    public NlpTokenizer.InnerTokenization innerTokenize(String seq) {
        List<Integer> tokenPositionMap = new ArrayList<>();
        try (TokenStream ts = xlmAnalyzer.tokenStream("input", seq)) {
            ts.reset();
            PositionIncrementAttribute tokenPos = ts.addAttribute(PositionIncrementAttribute.class);
            int currPos = -1;
            while (ts.incrementToken()) {
                currPos += tokenPos.getPositionIncrement();
                tokenPositionMap.add(currPos);
            }
            // Complete the TokenStream consumer contract (reset, incrementToken*, end, close).
            ts.end();
        } catch (IOException ex) {
            throw new UncheckedIOException(ex);
        }
        return new NlpTokenizer.InnerTokenization(new ArrayList<>(xlmAnalyzer.getTokens()), tokenPositionMap);
    }

    /** Starts a builder seeded from the model's vocabulary, scores and tokenization settings. */
    public static Builder builder(List<String> vocab, List<Double> scores, XLMRobertaTokenization tokenization) {
        return new Builder(vocab, scores, tokenization);
    }

    /**
     * Lucene analyzer that applies the precompiled character-map normalizer and tokenizes with a
     * {@link UnigramTokenizer}.
     */
    static class XLMAnalyzer extends Analyzer {
        private final List<String> vocabulary;
        private final List<String> neverSplit;
        private final double[] scores;
        // NOTE(review): overwritten each time createComponents runs; if components are created
        // per thread, getTokens() returns the most recently created tokenizer's output. Confirm
        // this analyzer is not used concurrently.
        private UnigramTokenizer innerTokenizer;
        private final String unknownToken;
        private final PrecompiledCharMapNormalizer.Config normalizer;

        XLMAnalyzer(List<String> vocabulary, List<Double> scores, List<String> neverSplit, String unknownToken) throws IOException {
            this.vocabulary = vocabulary;
            this.neverSplit = neverSplit;
            this.unknownToken = unknownToken;
            this.scores = scores.stream().mapToDouble(Double::doubleValue).toArray();
            // Path kept exactly as-is: it mixes '/' and '.' separators and presumably matches the
            // resource directory name on the classpath; verify before changing.
            this.normalizer = PrecompiledCharMapNormalizer.fromBase64EncodedResource(
                "/org/elasticsearch/xpack/ml/inference.nlp.tokenizers/spm_precompiled_normalizer.txt"
            );
        }

        /** Wraps the input in the character-map normalizer when the loaded config has offsets. */
        @Override
        protected Reader initReader(String fieldName, Reader reader) {
            if (normalizer.offsets().length > 0) {
                return new PrecompiledCharMapNormalizer(normalizer.offsets(), normalizer.utf8str(), reader);
            }
            return reader;
        }

        /** Builds a fresh unigram tokenizer and remembers it so its tokens can be read back. */
        @Override
        protected Analyzer.TokenStreamComponents createComponents(String fieldName) {
            UnigramTokenizer tokenizer = UnigramTokenizer.build(neverSplit, vocabulary, scores, unknownToken);
            innerTokenizer = tokenizer;
            return new Analyzer.TokenStreamComponents((Tokenizer) tokenizer);
        }

        /** @return tokens from the last created tokenizer, or an empty list if none exists yet */
        public List<DelimitedToken.Encoded> getTokens() {
            return innerTokenizer == null ? List.of() : innerTokenizer.getTokenizedValues();
        }
    }

    /** Builder for {@link XLMRobertaTokenizer}. */
    public static class Builder {
        protected final List<String> originalVocab;
        protected final List<Double> scores;
        protected final SortedMap<String, Integer> vocab;
        protected boolean withSpecialTokens;
        protected int maxSequenceLength;
        protected Set<String> neverSplit;

        protected Builder(List<String> vocab, List<Double> scores, XLMRobertaTokenization tokenization) {
            this.originalVocab = vocab;
            this.vocab = buildSortedVocab(vocab);
            this.scores = scores;
            this.withSpecialTokens = tokenization.withSpecialTokens();
            this.maxSequenceLength = tokenization.maxSequenceLength();
        }

        /** Maps each vocabulary entry to its list index; a later duplicate overwrites an earlier one. */
        private static SortedMap<String, Integer> buildSortedVocab(List<String> vocab) {
            SortedMap<String, Integer> sortedVocab = new TreeMap<>();
            for (int i = 0; i < vocab.size(); i++) {
                sortedVocab.put(vocab.get(i), i);
            }
            return sortedVocab;
        }

        public Builder setNeverSplit(Set<String> neverSplit) {
            this.neverSplit = neverSplit;
            return this;
        }

        public Builder setMaxSequenceLength(int maxSequenceLength) {
            this.maxSequenceLength = maxSequenceLength;
            return this;
        }

        public Builder setWithSpecialTokens(boolean withSpecialTokens) {
            this.withSpecialTokens = withSpecialTokens;
            return this;
        }

        /**
         * Builds the tokenizer. An unset never-split set defaults to empty.
         *
         * @throws IOException propagated from the tokenizer constructor
         */
        public XLMRobertaTokenizer build() throws IOException {
            if (neverSplit == null) {
                neverSplit = Collections.emptySet();
            }
            return new XLMRobertaTokenizer(originalVocab, vocab, scores, withSpecialTokens, maxSequenceLength, neverSplit);
        }
    }
}

