/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;

import java.io.IOException;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.lucene.analysis.CharArrayMap;
import org.apache.lucene.analysis.CharArraySet;
import org.apache.lucene.analysis.Tokenizer;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.OffsetAttribute;
import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;
import org.apache.lucene.util.CharsRef;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BpeTokenReader;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.CharTrie;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DelimitedToken;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.MultiCharSequence;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizerUtils;

public class BpeTokenizer
extends Tokenizer {
    private final CharTermAttribute termAtt = (CharTermAttribute)this.addAttribute(CharTermAttribute.class);
    private final OffsetAttribute offsetAtt = (OffsetAttribute)this.addAttribute(OffsetAttribute.class);
    private final PositionIncrementAttribute posIncAtt = (PositionIncrementAttribute)this.addAttribute(PositionIncrementAttribute.class);
    private static final char[] BYTES_CHAR = BpeTokenizer.byteEncoder();
    private static final char ENCODED_SPACE_CHAR = BYTES_CHAR[32];
    private final StringBuilder inputStr = new StringBuilder();
    private final LinkedList<BpeToken> tokens = new LinkedList();
    private final List<BpeToken> tokenizedValues = new ArrayList<BpeToken>();
    private final CharArrayMap<Integer> mergeRanks;
    private final CharArrayMap<Integer> vocabulary;
    private final CharSequence unknownToken;
    private final CharArraySet neverSplitSet;
    private final CharTrie neverSplit;
    private final int tokenizedUnknown;
    private final boolean prefixSpace;
    private boolean filled;

    static char[] byteEncoder() {
        List bytes = IntStream.concat(IntStream.range(Character.codePointAt("!", 0), Character.codePointAt("~", 0) + 1), IntStream.concat(IntStream.range(Character.codePointAt("\u00a1", 0), Character.codePointAt("\u00ac", 0) + 1), IntStream.range(Character.codePointAt("\u00ae", 0), Character.codePointAt("\u00ff", 0) + 1))).boxed().collect(Collectors.toList());
        ArrayList chars = new ArrayList(bytes);
        int n = 0;
        for (int i = 0; i < 256; ++i) {
            if (bytes.contains(i)) continue;
            bytes.add(i);
            chars.add(256 + n);
            ++n;
        }
        char[] charArray = new char[chars.size()];
        for (int j = 0; j < bytes.size(); ++j) {
            charArray[((Integer)bytes.get((int)j)).intValue()] = Character.toChars((Integer)chars.get(j))[0];
        }
        return charArray;
    }

    public static BpeTokenizer build(List<String> neverSplit, List<String> dictionary, List<String> merges, String unknownToken, boolean isPrefixSpace) {
        CharArraySet neverSplitSet = new CharArraySet(neverSplit, false);
        CharTrie neverSplitTree = CharTrie.build(neverSplit);
        CharArrayMap mergeRanks = new CharArrayMap(merges.size(), false);
        int mergePos = 0;
        for (String merge : merges) {
            mergeRanks.put(Strings.replace((String)merge, (String)" ", (String)""), (Object)mergePos++);
        }
        CharArrayMap vocabHash = new CharArrayMap(dictionary.size(), false);
        int vocabPos = 0;
        for (String v : dictionary) {
            vocabHash.put(v, (Object)vocabPos++);
        }
        return new BpeTokenizer(isPrefixSpace, (CharArrayMap<Integer>)mergeRanks, neverSplitSet, neverSplitTree, (CharArrayMap<Integer>)vocabHash, unknownToken);
    }

    public BpeTokenizer(boolean prefixSpace, CharArrayMap<Integer> mergeRanks, CharArraySet neverSplitSet, CharTrie neverSplit, CharArrayMap<Integer> vocabulary, CharSequence unknownToken) {
        this.mergeRanks = mergeRanks;
        this.neverSplitSet = neverSplitSet;
        this.neverSplit = neverSplit;
        this.vocabulary = vocabulary;
        if (!vocabulary.containsKey(unknownToken)) {
            throw new IllegalArgumentException("provided vocabulary does not contain the unknown token of [" + unknownToken.toString() + "]");
        }
        this.unknownToken = unknownToken;
        this.tokenizedUnknown = (Integer)vocabulary.get(unknownToken);
        this.prefixSpace = prefixSpace;
    }

    List<BpeToken> getTokenizedValues() {
        return this.tokenizedValues;
    }

    public void reset() throws IOException {
        super.reset();
        this.fillBuffer(this.input);
        this.tokens.clear();
        this.tokenizedValues.clear();
        this.filled = false;
    }

    public final void end() throws IOException {
        super.end();
        this.offsetAtt.setOffset(this.inputStr.length(), this.inputStr.length());
    }

    public final boolean incrementToken() throws IOException {
        if (this.filled && this.tokens.isEmpty()) {
            return false;
        }
        if (this.tokens.isEmpty()) {
            this.fillTokens();
        }
        if (this.tokens.isEmpty()) {
            return false;
        }
        this.clearAttributes();
        BpeToken token = this.tokens.removeFirst();
        this.tokenizedValues.add(token);
        this.termAtt.setEmpty().append(token.charSequence());
        this.offsetAtt.setOffset(token.startOffset(), token.endOffset());
        if (token.subWordToken) {
            this.posIncAtt.setPositionIncrement(0);
        }
        return true;
    }

    private void fillTokens() {
        boolean firstFind = true;
        LinkedList<DelimitedToken> largeTokensWithNeverSplits = TokenizerUtils.splitOutNeverSplit(this.inputStr.toString(), this.neverSplit, this.neverSplitSet);
        int split = 0;
        for (DelimitedToken token : largeTokensWithNeverSplits) {
            Optional<TokenizerUtils.CharSequenceRef> tokenSequence;
            if (this.neverSplitSet.contains(token.charSequence())) {
                Integer tokenId = (Integer)this.vocabulary.get(token.charSequence());
                BpeToken toAdd = tokenId == null ? new BpeToken(this.unknownToken, false, this.tokenizedUnknown, token.startOffset(), token.endOffset()) : new BpeToken(token.charSequence().toString(), false, tokenId, token.startOffset(), token.endOffset());
                this.tokens.add(toAdd);
                firstFind = false;
                ++split;
                continue;
            }
            int offsetOffset = token.startOffset();
            CharSequence delimitedTokenSequence = token.charSequence();
            if (split < largeTokensWithNeverSplits.size() - 1 && delimitedTokenSequence.charAt(delimitedTokenSequence.length() - 1) == ' ') {
                delimitedTokenSequence = new TokenizerUtils.CharSequenceRef(delimitedTokenSequence, 0, delimitedTokenSequence.length() - 1);
            }
            BpeTokenReader tokenReader = new BpeTokenReader(delimitedTokenSequence);
            while ((tokenSequence = tokenReader.next()).isPresent()) {
                boolean addedSpace = false;
                int offsetStart = tokenSequence.get().getOffset();
                int offsetEnd = tokenSequence.get().getOffset() + tokenSequence.get().length();
                Object subStr = tokenSequence.get().toString();
                if (firstFind && this.prefixSpace && !((String)subStr).startsWith(" ")) {
                    subStr = " " + (String)subStr;
                    addedSpace = true;
                }
                firstFind = false;
                byte[] bytes = ((String)subStr).getBytes(StandardCharsets.UTF_8);
                char[] cs = new char[bytes.length];
                for (int i = 0; i < bytes.length; ++i) {
                    int b = bytes[i];
                    if (b < 0) {
                        b += 256;
                    }
                    cs[i] = BYTES_CHAR[b];
                }
                ArrayList<CharSequence> bpeTokens = new ArrayList<CharSequence>(cs.length);
                for (int i = 0; i < cs.length; ++i) {
                    bpeTokens.add((CharSequence)new CharsRef(cs, i, 1));
                }
                while (bpeTokens.size() > 1) {
                    int i;
                    int minRank = Integer.MAX_VALUE;
                    CharSequencePair minSeq = null;
                    List<CharSequencePair> pairs = BpeTokenizer.pairs(bpeTokens);
                    for (CharSequencePair sequence : pairs) {
                        int rank = (Integer)this.mergeRanks.getOrDefault((Object)sequence, (Object)Integer.MAX_VALUE);
                        if (rank >= minRank) continue;
                        minSeq = sequence;
                        minRank = rank;
                    }
                    if (minSeq == null) break;
                    ArrayList<CharSequence> mergedBpeTokens = new ArrayList<CharSequence>(bpeTokens.size() - 1);
                    for (i = 0; i < minSeq.firstPos; ++i) {
                        mergedBpeTokens.add((CharSequence)bpeTokens.get(i));
                    }
                    mergedBpeTokens.add(minSeq);
                    for (i = minSeq.secondPos + 1; i < bpeTokens.size(); ++i) {
                        mergedBpeTokens.add((CharSequence)bpeTokens.get(i));
                    }
                    bpeTokens = mergedBpeTokens;
                }
                boolean subWordToken = false;
                for (CharSequence charSequence : bpeTokens) {
                    Integer tokenId = (Integer)this.vocabulary.get(charSequence);
                    int startOffsetAdj = !subWordToken && charSequence.charAt(0) == ENCODED_SPACE_CHAR && !addedSpace && charSequence.length() > 1 ? 1 : 0;
                    BpeToken toAdd = tokenId == null ? new BpeToken(this.unknownToken, subWordToken, this.tokenizedUnknown, offsetStart + offsetOffset + startOffsetAdj, offsetEnd + offsetOffset) : new BpeToken(charSequence.toString(), subWordToken, tokenId, offsetStart + offsetOffset + startOffsetAdj, offsetEnd + offsetOffset);
                    this.tokens.add(toAdd);
                    subWordToken = true;
                }
            }
        }
        this.filled = true;
    }

    private static List<CharSequencePair> pairs(List<CharSequence> tokens) {
        ArrayList<CharSequencePair> pairs = new ArrayList<CharSequencePair>(tokens.size() - 1);
        for (int i = 0; i < tokens.size() - 1; ++i) {
            pairs.add(new CharSequencePair(MultiCharSequence.from(tokens.get(i), tokens.get(i + 1)), i, i + 1));
        }
        return pairs;
    }

    private void fillBuffer(Reader input) throws IOException {
        int len;
        char[] buffer = new char[1024];
        this.inputStr.setLength(0);
        while ((len = input.read(buffer)) > 0) {
            this.inputStr.append(buffer, 0, len);
        }
    }

    public static class BpeToken
    extends DelimitedToken.Encoded {
        private final boolean subWordToken;

        public BpeToken(CharSequence charSequence, boolean subWordToken, int tokenId, int startOffset, int endOffset) {
            super(charSequence, tokenId, startOffset, endOffset);
            this.subWordToken = subWordToken;
        }
    }

    private record CharSequencePair(CharSequence pair, int firstPos, int secondPos) implements CharSequence
    {
        @Override
        public int length() {
            return this.pair.length();
        }

        @Override
        public char charAt(int index) {
            return this.pair.charAt(index);
        }

        @Override
        public CharSequence subSequence(int start, int end) {
            return this.pair.subSequence(start, end);
        }

        @Override
        public String toString() {
            return this.pair.toString();
        }
    }
}

