/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.OptionalInt;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizationResult;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.WordPieceAnalyzer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.WordPieceTokenFilter;

/**
 * Tokenizer that feeds input text through a {@link WordPieceAnalyzer} and resolves the
 * ids of the special tokens (unknown, pad, separator, class, mask) from the supplied vocabulary.
 *
 * <p>Instances own the analyzer they create; call {@link #close()} to close it.
 * Instances are normally obtained through {@link #builder(List, Tokenization)}.
 */
public class BertTokenizer extends NlpTokenizer {
    public static final String UNKNOWN_TOKEN = "[UNK]";
    public static final String SEPARATOR_TOKEN = "[SEP]";
    public static final String PAD_TOKEN = "[PAD]";
    public static final String CLASS_TOKEN = "[CLS]";
    public static final String MASK_TOKEN = "[MASK]";
    /** Tokens that are always added to the caller-supplied never-split set. */
    private static final Set<String> NEVER_SPLIT = Set.of(MASK_TOKEN);

    private final WordPieceAnalyzer wordPieceAnalyzer;
    protected final List<String> originalVocab;
    protected final boolean withSpecialTokens;
    private final int maxSequenceLength;
    /** Id of the separator token, or {@code -1} when special tokens are disabled. */
    private final int sepTokenId;
    /** Id of the class token, or {@code -1} when special tokens are disabled. */
    private final int clsTokenId;
    private final String padToken;
    private final int padTokenId;
    private final String maskToken;
    /** Present only when the vocabulary contains the mask token. */
    private final OptionalInt maskTokenId;
    private final String unknownToken;

    /**
     * Creates a tokenizer that uses the default special tokens declared on this class.
     * The mask token is always added to {@code neverSplit}.
     *
     * @param originalVocab       vocabulary in its original order
     * @param vocab               token to id lookup for {@code originalVocab}
     * @param doLowerCase         forwarded to the analyzer
     * @param doTokenizeCjKChars  forwarded to the analyzer
     * @param doStripAccents      forwarded to the analyzer
     * @param withSpecialTokens   whether separator and class tokens are required and resolved
     * @param maxSequenceLength   value reported by {@link #maxSequenceLength()}
     * @param neverSplit          additional tokens the analyzer must not split; must not be {@code null}
     */
    protected BertTokenizer(List<String> originalVocab, SortedMap<String, Integer> vocab, boolean doLowerCase, boolean doTokenizeCjKChars, boolean doStripAccents, boolean withSpecialTokens, int maxSequenceLength, Set<String> neverSplit) {
        this(originalVocab, vocab, doLowerCase, doTokenizeCjKChars, doStripAccents, withSpecialTokens, maxSequenceLength, Sets.union(neverSplit, NEVER_SPLIT), SEPARATOR_TOKEN, CLASS_TOKEN, PAD_TOKEN, MASK_TOKEN, UNKNOWN_TOKEN);
    }

    /**
     * Creates a tokenizer with explicit special tokens.
     *
     * <p>The unknown and pad tokens must always be present in {@code vocab}; the separator and
     * class tokens are only required when {@code withSpecialTokens} is {@code true}. The mask
     * token is optional. A missing required token is reported through
     * {@code ExceptionsHelper.conflictStatusException}.
     *
     * @param originalVocab       vocabulary in its original order
     * @param vocab               token to id lookup for {@code originalVocab}
     * @param doLowerCase         forwarded to the analyzer
     * @param doTokenizeCjKChars  forwarded to the analyzer
     * @param doStripAccents      forwarded to the analyzer
     * @param withSpecialTokens   whether separator and class tokens are required and resolved
     * @param maxSequenceLength   value reported by {@link #maxSequenceLength()}
     * @param neverSplit          tokens the analyzer must not split
     * @param sepToken            separator token text
     * @param clsToken            class token text
     * @param padToken            padding token text
     * @param maskToken           mask token text
     * @param unknownToken        unknown token text
     */
    protected BertTokenizer(List<String> originalVocab, SortedMap<String, Integer> vocab, boolean doLowerCase, boolean doTokenizeCjKChars, boolean doStripAccents, boolean withSpecialTokens, int maxSequenceLength, Set<String> neverSplit, String sepToken, String clsToken, String padToken, String maskToken, String unknownToken) {
        // Validate the vocabulary first: the analyzer is closeable, so it must not be
        // created on a path where the constructor can still fail and nobody could close it.
        requireToken(vocab, unknownToken);
        requireToken(vocab, padToken);
        if (withSpecialTokens) {
            Set<String> missingSpecialTokens = Sets.difference(Set.of(sepToken, clsToken), vocab.keySet());
            if (!missingSpecialTokens.isEmpty()) {
                throw ExceptionsHelper.conflictStatusException("stored vocabulary is missing required {} token(s)", new Object[] {missingSpecialTokens});
            }
            this.sepTokenId = vocab.get(sepToken);
            this.clsTokenId = vocab.get(clsToken);
        } else {
            this.sepTokenId = -1;
            this.clsTokenId = -1;
        }
        this.originalVocab = originalVocab;
        this.withSpecialTokens = withSpecialTokens;
        this.maxSequenceLength = maxSequenceLength;
        this.padToken = padToken;
        this.padTokenId = vocab.get(padToken);
        this.maskToken = maskToken;
        this.maskTokenId = vocab.containsKey(maskToken) ? OptionalInt.of(vocab.get(maskToken)) : OptionalInt.empty();
        this.unknownToken = unknownToken;
        // Overridable factory invoked from the constructor: overrides run before the
        // subclass constructor body and should rely only on the arguments passed in.
        this.wordPieceAnalyzer = this.createWordPieceAnalyzer(originalVocab, new ArrayList<String>(neverSplit), doLowerCase, doTokenizeCjKChars, doStripAccents, unknownToken);
    }

    /** Throws a conflict-status exception when {@code token} is not a key of {@code vocab}. */
    private static void requireToken(SortedMap<String, Integer> vocab, String token) {
        if (!vocab.containsKey(token)) {
            throw ExceptionsHelper.conflictStatusException("stored vocabulary is missing required [{}] token", new Object[] {token});
        }
    }

    /**
     * Factory for the analyzer used by this tokenizer; subclasses may override it to supply
     * a different analyzer. Called once, from the constructor.
     *
     * @param vocabulary          vocabulary in its original order
     * @param neverSplit          tokens that must not be split; copied before being handed on
     * @param doLowerCase         forwarded to the analyzer
     * @param doTokenizeCjKChars  forwarded to the analyzer
     * @param doStripAccents      forwarded to the analyzer
     * @param unknownToken        unknown token text
     * @return a new analyzer
     */
    protected WordPieceAnalyzer createWordPieceAnalyzer(List<String> vocabulary, List<String> neverSplit, boolean doLowerCase, boolean doTokenizeCjKChars, boolean doStripAccents, String unknownToken) {
        return new WordPieceAnalyzer(vocabulary, new ArrayList<String>(neverSplit), doLowerCase, doTokenizeCjKChars, doStripAccents, unknownToken);
    }

    /** @return the separator token id, or {@code -1} when special tokens are disabled */
    @Override
    int sepTokenId() {
        return this.sepTokenId;
    }

    /** @return the maximum sequence length this tokenizer was configured with */
    @Override
    int maxSequenceLength() {
        return this.maxSequenceLength;
    }

    /** @return whether this tokenizer was configured with special tokens */
    @Override
    boolean isWithSpecialTokens() {
        return this.withSpecialTokens;
    }

    /**
     * @param maxWindowSize window size to derive the span from
     * @return half (integer division) of the window left after subtracting
     *         {@link #numExtraTokensForSingleSequence()}
     */
    @Override
    int defaultSpanForChunking(int maxWindowSize) {
        return (maxWindowSize - this.numExtraTokensForSingleSequence()) / 2;
    }

    /**
     * @return the constant {@code 3}; presumably one class token plus two separators
     *         (NOTE(review): confirm against the tokens builder)
     */
    @Override
    int getNumExtraTokensForSeqPair() {
        return 3;
    }

    /**
     * @return the constant {@code 2}; presumably one class token plus one separator
     *         (NOTE(review): confirm against the tokens builder)
     */
    @Override
    int numExtraTokensForSingleSequence() {
        return 2;
    }

    /** @return the class token id, or {@code -1} when special tokens are disabled */
    @Override
    int clsTokenId() {
        return this.clsTokenId;
    }

    /** @return the padding token text */
    @Override
    public String getPadToken() {
        return this.padToken;
    }

    /** @return the unknown token text */
    public String getUnknownToken() {
        return this.unknownToken;
    }

    /** @return the padding token id; always present because the constructor requires the pad token */
    @Override
    public OptionalInt getPadTokenId() {
        return OptionalInt.of(this.padTokenId);
    }

    /** @return the mask token id, or empty when the vocabulary does not contain the mask token */
    @Override
    public OptionalInt getMaskTokenId() {
        return this.maskTokenId;
    }

    /** @return the mask token text */
    @Override
    public String getMaskToken() {
        return this.maskToken;
    }

    /** @return the vocabulary list this tokenizer was created with (not copied) */
    @Override
    public List<String> getVocabulary() {
        return this.originalVocab;
    }

    /**
     * Wraps the given tokenizations in a {@link BertTokenizationResult} together with the
     * vocabulary and the pad token id.
     */
    @Override
    public TokenizationResult buildTokenizationResult(List<TokenizationResult.Tokens> tokenizations) {
        return new BertTokenizationResult(this.originalVocab, tokenizations, this.padTokenId);
    }

    /** Creates a {@link BertTokenizationResult.BertTokensBuilder} for the given settings. */
    @Override
    TokenizationResult.TokensBuilder createTokensBuilder(int clsTokenId, int sepTokenId, boolean withSpecialTokens) {
        return new BertTokenizationResult.BertTokensBuilder(withSpecialTokens, clsTokenId, sepTokenId);
    }

    /**
     * Returns a request builder that tokenizes every input, using the input's index as its
     * sequence id, collects all resulting tokenizations into one result and builds the
     * request from it.
     */
    @Override
    public NlpTask.RequestBuilder requestBuilder() {
        return (inputs, requestId, truncate, span, windowSize) -> {
            List<TokenizationResult.Tokens> tokenizations = IntStream.range(0, inputs.size())
                .mapToObj(seqId -> this.tokenize(inputs.get(seqId), truncate, span, seqId, windowSize))
                .flatMap(perInput -> perInput.stream())
                .collect(Collectors.toList());
            return this.buildTokenizationResult(tokenizations).buildRequest(requestId, truncate);
        };
    }

    /**
     * Runs {@code seq} through the analyzer and records, for each emitted token, its position
     * obtained by accumulating the position increments (the first position is
     * {@code -1 + firstIncrement}).
     *
     * @param seq text to tokenize
     * @return the analyzer's tokens (copied) paired with the per-token position list
     * @throws UncheckedIOException if the token stream fails with an {@link IOException}
     */
    @Override
    public NlpTokenizer.InnerTokenization innerTokenize(String seq) {
        List<Integer> tokenPositionMap = new ArrayList<>();
        try (TokenStream ts = this.wordPieceAnalyzer.tokenStream("input", seq)) {
            ts.reset();
            PositionIncrementAttribute tokenPos = ts.addAttribute(PositionIncrementAttribute.class);
            int currPos = -1;
            // NOTE(review): Lucene's documented consumer workflow also calls end() after the
            // last incrementToken(); not added here because its effect on the analyzer's
            // captured token list is not visible from this class - confirm before changing.
            while (ts.incrementToken()) {
                currPos += tokenPos.getPositionIncrement();
                tokenPositionMap.add(currPos);
            }
        } catch (IOException ex) {
            throw new UncheckedIOException(ex);
        }
        // The tokens are read from the analyzer after the stream is closed and copied so the
        // result does not share the analyzer's list.
        return new NlpTokenizer.InnerTokenization(new ArrayList<WordPieceTokenFilter.WordPieceToken>(this.wordPieceAnalyzer.getTokens()), tokenPositionMap);
    }

    /** Closes the analyzer owned by this tokenizer. */
    public void close() {
        this.wordPieceAnalyzer.close();
    }

    /**
     * @param vocab        vocabulary in its original order
     * @param tokenization settings used to seed the builder
     * @return a new builder
     */
    public static Builder builder(List<String> vocab, Tokenization tokenization) {
        return new Builder(vocab, tokenization);
    }

    /**
     * Builder for {@link BertTokenizer}. Lower-casing, special-token use, maximum sequence
     * length and span are seeded from a {@link Tokenization}; the remaining options have
     * defaults and can be changed through the setters.
     */
    public static class Builder {
        protected final List<String> originalVocab;
        protected final SortedMap<String, Integer> vocab;
        protected boolean doLowerCase;
        protected boolean doTokenizeCjKChars = true;
        protected boolean withSpecialTokens;
        /** Read from the tokenization settings; not forwarded by {@link #build()}. */
        protected int span = -1;
        protected int maxSequenceLength;
        /** {@code null} means "follow {@code doLowerCase}"; resolved in {@link #build()}. */
        protected Boolean doStripAccents = null;
        /** {@code null} means "no extra tokens"; resolved in {@link #build()}. */
        protected Set<String> neverSplit;

        /**
         * @param vocab        vocabulary in its original order; kept by reference
         * @param tokenization source of the initial settings
         */
        protected Builder(List<String> vocab, Tokenization tokenization) {
            this.originalVocab = vocab;
            this.vocab = Builder.buildSortedVocab(vocab);
            this.doLowerCase = tokenization.doLowerCase();
            this.withSpecialTokens = tokenization.withSpecialTokens();
            this.maxSequenceLength = tokenization.maxSequenceLength();
            this.span = tokenization.getSpan();
        }

        /**
         * Maps every token to its index in {@code vocab}. If a token occurs more than once,
         * the highest index is kept because later puts overwrite earlier ones.
         */
        private static SortedMap<String, Integer> buildSortedVocab(List<String> vocab) {
            SortedMap<String, Integer> sortedVocab = new TreeMap<>();
            for (int i = 0; i < vocab.size(); ++i) {
                sortedVocab.put(vocab.get(i), i);
            }
            return sortedVocab;
        }

        /** Sets whether input is lower-cased. */
        public Builder setDoLowerCase(boolean doLowerCase) {
            this.doLowerCase = doLowerCase;
            return this;
        }

        /** Sets the CJK-character tokenization flag (defaults to {@code true}). */
        public Builder setDoTokenizeCjKChars(boolean doTokenizeCjKChars) {
            this.doTokenizeCjKChars = doTokenizeCjKChars;
            return this;
        }

        /** Sets accent stripping; {@code null} makes {@link #build()} use the lower-case setting. */
        public Builder setDoStripAccents(Boolean doStripAccents) {
            this.doStripAccents = doStripAccents;
            return this;
        }

        /** Sets extra never-split tokens; {@code null} makes {@link #build()} use an empty set. */
        public Builder setNeverSplit(Set<String> neverSplit) {
            this.neverSplit = neverSplit;
            return this;
        }

        /** Overrides the maximum sequence length taken from the tokenization settings. */
        public Builder setMaxSequenceLength(int maxSequenceLength) {
            this.maxSequenceLength = maxSequenceLength;
            return this;
        }

        /** Overrides the special-tokens flag taken from the tokenization settings. */
        public Builder setWithSpecialTokens(boolean withSpecialTokens) {
            this.withSpecialTokens = withSpecialTokens;
            return this;
        }

        /**
         * Resolves unset options to their defaults (stored back on this builder) and creates
         * the tokenizer.
         *
         * @return a new tokenizer
         */
        public BertTokenizer build() {
            if (this.doStripAccents == null) {
                this.doStripAccents = this.doLowerCase;
            }
            if (this.neverSplit == null) {
                this.neverSplit = Collections.emptySet();
            }
            return new BertTokenizer(this.originalVocab, this.vocab, this.doLowerCase, this.doTokenizeCjKChars, this.doStripAccents, this.withSpecialTokens, this.maxSequenceLength, this.neverSplit);
        }
    }
}

