/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.OptionalInt;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.nlp.BertRequestBuilder;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BasicTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DelimitedToken;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.WordPieceTokenizer;

public class BertTokenizer
implements NlpTokenizer {
    public static final String UNKNOWN_TOKEN = "[UNK]";
    public static final String SEPARATOR_TOKEN = "[SEP]";
    public static final String PAD_TOKEN = "[PAD]";
    public static final String CLASS_TOKEN = "[CLS]";
    public static final String MASK_TOKEN = "[MASK]";
    public static final int SPECIAL_TOKEN_POSITION = -1;
    public static final int DEFAULT_MAX_INPUT_CHARS_PER_WORD = 100;
    private static final Set<String> NEVER_SPLIT = Set.of("[MASK]");
    private final WordPieceTokenizer wordPieceTokenizer;
    private final List<String> originalVocab;
    private final SortedMap<String, Integer> vocab;
    private final boolean doLowerCase;
    private final boolean doTokenizeCjKChars;
    private final boolean doStripAccents;
    protected final boolean withSpecialTokens;
    private final Set<String> neverSplit;
    private final int maxSequenceLength;
    private final NlpTask.RequestBuilder requestBuilder;
    private final String sepToken;
    protected final int sepTokenId;
    private final String clsToken;
    private final int clsTokenId;
    private final String padToken;
    private final String maskToken;
    private final String unknownToken;

    protected BertTokenizer(List<String> originalVocab, SortedMap<String, Integer> vocab, boolean doLowerCase, boolean doTokenizeCjKChars, boolean doStripAccents, boolean withSpecialTokens, int maxSequenceLength, Function<NlpTokenizer, NlpTask.RequestBuilder> requestBuilderFactory, Set<String> neverSplit) {
        this(originalVocab, vocab, doLowerCase, doTokenizeCjKChars, doStripAccents, withSpecialTokens, maxSequenceLength, requestBuilderFactory, Sets.union(neverSplit, NEVER_SPLIT), SEPARATOR_TOKEN, CLASS_TOKEN, PAD_TOKEN, MASK_TOKEN, UNKNOWN_TOKEN);
    }

    protected BertTokenizer(List<String> originalVocab, SortedMap<String, Integer> vocab, boolean doLowerCase, boolean doTokenizeCjKChars, boolean doStripAccents, boolean withSpecialTokens, int maxSequenceLength, Function<NlpTokenizer, NlpTask.RequestBuilder> requestBuilderFactory, Set<String> neverSplit, String sepToken, String clsToken, String padToken, String maskToken, String unknownToken) {
        this.wordPieceTokenizer = new WordPieceTokenizer(vocab, unknownToken, 100);
        this.originalVocab = originalVocab;
        this.vocab = vocab;
        this.doLowerCase = doLowerCase;
        this.doTokenizeCjKChars = doTokenizeCjKChars;
        this.doStripAccents = doStripAccents;
        this.withSpecialTokens = withSpecialTokens;
        this.neverSplit = neverSplit;
        this.maxSequenceLength = maxSequenceLength;
        this.requestBuilder = requestBuilderFactory.apply(this);
        if (!vocab.containsKey(unknownToken)) {
            throw ExceptionsHelper.conflictStatusException((String)"stored vocabulary is missing required [{}] token", (Object[])new Object[]{unknownToken});
        }
        if (!vocab.containsKey(padToken)) {
            throw ExceptionsHelper.conflictStatusException((String)"stored vocabulary is missing required [{}] token", (Object[])new Object[]{padToken});
        }
        if (withSpecialTokens) {
            Set missingSpecialTokens = Sets.difference(Set.of(sepToken, clsToken), vocab.keySet());
            if (!missingSpecialTokens.isEmpty()) {
                throw ExceptionsHelper.conflictStatusException((String)"stored vocabulary is missing required {} token(s)", (Object[])new Object[]{missingSpecialTokens});
            }
            this.sepTokenId = (Integer)vocab.get(sepToken);
            this.clsTokenId = (Integer)vocab.get(clsToken);
        } else {
            this.sepTokenId = -1;
            this.clsTokenId = -1;
        }
        this.sepToken = sepToken;
        this.clsToken = clsToken;
        this.padToken = padToken;
        this.maskToken = maskToken;
        this.unknownToken = unknownToken;
    }

    public String getSepToken() {
        return this.sepToken;
    }

    public String getClsToken() {
        return this.clsToken;
    }

    @Override
    public String getPadToken() {
        return this.padToken;
    }

    public String getUnknownToken() {
        return this.unknownToken;
    }

    @Override
    public OptionalInt getPadTokenId() {
        Integer pad = (Integer)this.vocab.get(this.padToken);
        if (pad != null) {
            return OptionalInt.of(pad);
        }
        return OptionalInt.empty();
    }

    @Override
    public OptionalInt getMaskTokenId() {
        Integer pad = (Integer)this.vocab.get(this.maskToken);
        if (pad != null) {
            return OptionalInt.of(pad);
        }
        return OptionalInt.empty();
    }

    @Override
    public String getMaskToken() {
        return this.maskToken;
    }

    @Override
    public TokenizationResult buildTokenizationResult(List<TokenizationResult.Tokenization> tokenizations) {
        TokenizationResult tokenizationResult = new TokenizationResult(this.originalVocab);
        for (TokenizationResult.Tokenization tokenization : tokenizations) {
            tokenizationResult.addTokenization(tokenization);
        }
        return tokenizationResult;
    }

    @Override
    public TokenizationResult.Tokenization tokenize(String seq, Tokenization.Truncate truncate) {
        InnerTokenization innerResult = this.innerTokenize(seq);
        List<Integer> wordPieceTokenIds = innerResult.wordPieceTokenIds;
        List<Integer> tokenPositionMap = innerResult.tokenPositionMap;
        int numTokens = this.withSpecialTokens ? wordPieceTokenIds.size() + 2 : wordPieceTokenIds.size();
        boolean isTruncated = false;
        if (numTokens > this.maxSequenceLength) {
            switch (truncate) {
                case FIRST: 
                case SECOND: {
                    isTruncated = true;
                    wordPieceTokenIds = wordPieceTokenIds.subList(0, this.withSpecialTokens ? this.maxSequenceLength - 2 : this.maxSequenceLength);
                    tokenPositionMap = tokenPositionMap.subList(0, this.withSpecialTokens ? this.maxSequenceLength - 2 : this.maxSequenceLength);
                    break;
                }
                case NONE: {
                    throw ExceptionsHelper.badRequestException((String)"Input too large. The tokenized input length [{}] exceeds the maximum sequence length [{}]", (Object[])new Object[]{numTokens, this.maxSequenceLength});
                }
            }
        }
        BertTokenizationBuilder bertTokenizationBuilder = this.bertTokenizationBuilder().addTokens(wordPieceTokenIds, tokenPositionMap).addEndTokensIfNecessary();
        return new TokenizationResult.Tokenization(seq, innerResult.tokens, isTruncated, bertTokenizationBuilder.buildIds(), bertTokenizationBuilder.buildMap());
    }

    @Override
    public TokenizationResult.Tokenization tokenize(String seq1, String seq2, Tokenization.Truncate truncate) {
        InnerTokenization innerResultSeq1 = this.innerTokenize(seq1);
        List<Integer> wordPieceTokenIdsSeq1 = innerResultSeq1.wordPieceTokenIds;
        List<Integer> tokenPositionMapSeq1 = innerResultSeq1.tokenPositionMap;
        InnerTokenization innerResultSeq2 = this.innerTokenize(seq2);
        List<Integer> wordPieceTokenIdsSeq2 = innerResultSeq2.wordPieceTokenIds;
        List<Integer> tokenPositionMapSeq2 = innerResultSeq2.tokenPositionMap;
        if (!this.withSpecialTokens) {
            throw new IllegalArgumentException("Unable to do sequence pair tokenization without special tokens");
        }
        int extraTokens = this.getNumExtraTokensForSeqPair();
        int numTokens = wordPieceTokenIdsSeq1.size() + wordPieceTokenIdsSeq2.size() + extraTokens;
        boolean isTruncated = false;
        if (numTokens > this.maxSequenceLength) {
            switch (truncate) {
                case FIRST: {
                    isTruncated = true;
                    if (wordPieceTokenIdsSeq2.size() > this.maxSequenceLength - extraTokens) {
                        throw ExceptionsHelper.badRequestException((String)"Attempting truncation [{}] but input is too large for the second sequence. The tokenized input length [{}] exceeds the maximum sequence length [{}], when taking special tokens into account", (Object[])new Object[]{truncate.toString(), wordPieceTokenIdsSeq2.size(), this.maxSequenceLength - extraTokens});
                    }
                    wordPieceTokenIdsSeq1 = wordPieceTokenIdsSeq1.subList(0, this.maxSequenceLength - extraTokens - wordPieceTokenIdsSeq2.size());
                    tokenPositionMapSeq1 = tokenPositionMapSeq1.subList(0, this.maxSequenceLength - extraTokens - wordPieceTokenIdsSeq2.size());
                    break;
                }
                case SECOND: {
                    isTruncated = true;
                    if (wordPieceTokenIdsSeq1.size() > this.maxSequenceLength - extraTokens) {
                        throw ExceptionsHelper.badRequestException((String)"Attempting truncation [{}] but input is too large for the first sequence. The tokenized input length [{}] exceeds the maximum sequence length [{}], when taking special tokens into account", (Object[])new Object[]{truncate.toString(), wordPieceTokenIdsSeq1.size(), this.maxSequenceLength - extraTokens});
                    }
                    wordPieceTokenIdsSeq2 = wordPieceTokenIdsSeq2.subList(0, this.maxSequenceLength - extraTokens - wordPieceTokenIdsSeq1.size());
                    tokenPositionMapSeq2 = tokenPositionMapSeq2.subList(0, this.maxSequenceLength - extraTokens - wordPieceTokenIdsSeq1.size());
                    break;
                }
                case NONE: {
                    throw ExceptionsHelper.badRequestException((String)"Input too large. The tokenized input length [{}] exceeds the maximum sequence length [{}]", (Object[])new Object[]{numTokens, this.maxSequenceLength});
                }
            }
        }
        BertTokenizationBuilder bertTokenizationBuilder = this.bertTokenizationBuilder().addTokens(wordPieceTokenIdsSeq1, tokenPositionMapSeq1).addTokens(wordPieceTokenIdsSeq2, tokenPositionMapSeq2).addEndTokensIfNecessary();
        ArrayList<DelimitedToken> tokens = new ArrayList<DelimitedToken>(innerResultSeq1.tokens);
        tokens.addAll(innerResultSeq2.tokens);
        return new TokenizationResult.Tokenization(seq1 + seq2, tokens, isTruncated, bertTokenizationBuilder.buildIds(), bertTokenizationBuilder.buildMap());
    }

    protected BertTokenizationBuilder bertTokenizationBuilder() {
        return new BertTokenizationBuilder();
    }

    protected int getNumExtraTokensForSeqPair() {
        return 3;
    }

    private InnerTokenization innerTokenize(String seq) {
        BasicTokenizer basicTokenizer = new BasicTokenizer(this.doLowerCase, this.doTokenizeCjKChars, this.doStripAccents, this.neverSplit);
        List<DelimitedToken> tokenSequences = basicTokenizer.tokenize(seq);
        ArrayList<Integer> wordPieceTokens = new ArrayList<Integer>();
        ArrayList<Integer> tokenPositionMap = new ArrayList<Integer>();
        for (int sourceIndex = 0; sourceIndex < tokenSequences.size(); ++sourceIndex) {
            String token = tokenSequences.get(sourceIndex).getToken();
            if (this.neverSplit.contains(token)) {
                wordPieceTokens.add(this.vocab.getOrDefault(token, (Integer)this.vocab.get(this.unknownToken)));
                tokenPositionMap.add(sourceIndex);
                continue;
            }
            List<Integer> tokens = this.wordPieceTokenizer.tokenize(tokenSequences.get(sourceIndex));
            for (int tokenCount = 0; tokenCount < tokens.size(); ++tokenCount) {
                tokenPositionMap.add(sourceIndex);
            }
            wordPieceTokens.addAll(tokens);
        }
        return new InnerTokenization(tokenSequences, wordPieceTokens, tokenPositionMap);
    }

    @Override
    public NlpTask.RequestBuilder requestBuilder() {
        return this.requestBuilder;
    }

    public int getMaxSequenceLength() {
        return this.maxSequenceLength;
    }

    public static Builder builder(List<String> vocab, Tokenization tokenization) {
        return new Builder(vocab, tokenization);
    }

    private static class InnerTokenization {
        List<DelimitedToken> tokens;
        List<Integer> wordPieceTokenIds;
        List<Integer> tokenPositionMap;

        InnerTokenization(List<DelimitedToken> tokens, List<Integer> wordPieceTokenIds, List<Integer> tokenPositionMap) {
            this.tokens = tokens;
            this.wordPieceTokenIds = wordPieceTokenIds;
            this.tokenPositionMap = tokenPositionMap;
        }
    }

    protected class BertTokenizationBuilder {
        Stream.Builder<IntStream> tokenIds = Stream.builder();
        Stream.Builder<IntStream> tokenMap = Stream.builder();
        int numSeq;

        BertTokenizationBuilder() {
            if (BertTokenizer.this.withSpecialTokens) {
                this.tokenIds.add(IntStream.of(BertTokenizer.this.clsTokenId));
                this.tokenMap.add(IntStream.of(-1));
            }
        }

        BertTokenizationBuilder addTokens(List<Integer> wordPieceTokenIds, List<Integer> tokenPositionMap) {
            if (this.numSeq > 0 && BertTokenizer.this.withSpecialTokens) {
                this.tokenIds.add(IntStream.of(BertTokenizer.this.sepTokenId));
                this.tokenMap.add(IntStream.of(-1));
            }
            this.tokenIds.add(wordPieceTokenIds.stream().mapToInt(Integer::valueOf));
            this.tokenMap.add(tokenPositionMap.stream().mapToInt(Integer::valueOf));
            ++this.numSeq;
            return this;
        }

        BertTokenizationBuilder addEndTokensIfNecessary() {
            if (BertTokenizer.this.withSpecialTokens) {
                this.tokenIds.add(IntStream.of(BertTokenizer.this.sepTokenId));
                this.tokenMap.add(IntStream.of(-1));
            }
            return this;
        }

        int[] buildIds() {
            return this.tokenIds.build().flatMapToInt(Function.identity()).toArray();
        }

        int[] buildMap() {
            return this.tokenMap.build().flatMapToInt(Function.identity()).toArray();
        }
    }

    public static class Builder {
        protected final List<String> originalVocab;
        protected final SortedMap<String, Integer> vocab;
        protected boolean doLowerCase = false;
        protected boolean doTokenizeCjKChars = true;
        protected boolean withSpecialTokens = true;
        protected int maxSequenceLength;
        protected Boolean doStripAccents = null;
        protected Set<String> neverSplit;
        protected Function<NlpTokenizer, NlpTask.RequestBuilder> requestBuilderFactory = BertRequestBuilder::new;

        protected Builder(List<String> vocab, Tokenization tokenization) {
            this.originalVocab = vocab;
            this.vocab = Builder.buildSortedVocab(vocab);
            this.doLowerCase = tokenization.doLowerCase();
            this.withSpecialTokens = tokenization.withSpecialTokens();
            this.maxSequenceLength = tokenization.maxSequenceLength();
        }

        private static SortedMap<String, Integer> buildSortedVocab(List<String> vocab) {
            TreeMap<String, Integer> sortedVocab = new TreeMap<String, Integer>();
            for (int i = 0; i < vocab.size(); ++i) {
                sortedVocab.put(vocab.get(i), i);
            }
            return sortedVocab;
        }

        public Builder setDoLowerCase(boolean doLowerCase) {
            this.doLowerCase = doLowerCase;
            return this;
        }

        public Builder setDoTokenizeCjKChars(boolean doTokenizeCjKChars) {
            this.doTokenizeCjKChars = doTokenizeCjKChars;
            return this;
        }

        public Builder setDoStripAccents(Boolean doStripAccents) {
            this.doStripAccents = doStripAccents;
            return this;
        }

        public Builder setNeverSplit(Set<String> neverSplit) {
            this.neverSplit = neverSplit;
            return this;
        }

        public Builder setMaxSequenceLength(int maxSequenceLength) {
            this.maxSequenceLength = maxSequenceLength;
            return this;
        }

        public Builder setWithSpecialTokens(boolean withSpecialTokens) {
            this.withSpecialTokens = withSpecialTokens;
            return this;
        }

        public Builder setRequestBuilderFactory(Function<NlpTokenizer, NlpTask.RequestBuilder> requestBuilderFactory) {
            this.requestBuilderFactory = requestBuilderFactory;
            return this;
        }

        public BertTokenizer build() {
            if (this.doStripAccents == null) {
                this.doStripAccents = this.doLowerCase;
            }
            if (this.neverSplit == null) {
                this.neverSplit = Collections.emptySet();
            }
            return new BertTokenizer(this.originalVocab, this.vocab, this.doLowerCase, this.doTokenizeCjKChars, this.doStripAccents, this.withSpecialTokens, this.maxSequenceLength, this.requestBuilderFactory, this.neverSplit);
        }
    }
}

