/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;

import java.io.IOException;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import org.apache.lucene.analysis.CharArrayMap;
import org.apache.lucene.analysis.CharArraySet;
import org.apache.lucene.analysis.TokenFilter;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.OffsetAttribute;
import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;
import org.apache.lucene.util.AttributeSource;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BasicTokenFilter;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DelimitedToken;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.MultiCharSequence;

public final class WordPieceTokenFilter
extends TokenFilter {
    private final LinkedList<WordPieceToken> tokens;
    private final CharTermAttribute termAtt = (CharTermAttribute)this.addAttribute(CharTermAttribute.class);
    private final OffsetAttribute offsetAtt = (OffsetAttribute)this.addAttribute(OffsetAttribute.class);
    private final PositionIncrementAttribute posIncAtt = (PositionIncrementAttribute)this.addAttribute(PositionIncrementAttribute.class);
    private static final CharSequence CONTINUATION = "##";
    private AttributeSource.State current;
    private final CharArraySet neverSplit;
    private final CharArrayMap<Integer> vocabulary;
    private final List<WordPieceToken> tokenizedValues;
    private final int maxInputCharsPerWord;
    private final int tokenizedUnknown;
    private final CharSequence unknownToken;

    public static WordPieceTokenFilter build(boolean isLowerCase, boolean isTokenizeCjkChars, boolean isStripAccents, List<String> neverSplit, List<String> dictionary, String unknownToken, int maxInputCharsPerWord, TokenStream input) throws IOException {
        CharArrayMap vocabMap = new CharArrayMap(dictionary.size(), isLowerCase);
        int i = 0;
        for (String word : dictionary) {
            vocabMap.put(word, (Object)i++);
        }
        input = BasicTokenFilter.build(isTokenizeCjkChars, isStripAccents, neverSplit, input);
        return new WordPieceTokenFilter((TokenStream)input, new CharArraySet(neverSplit, isLowerCase), (CharArrayMap<Integer>)vocabMap, unknownToken, maxInputCharsPerWord);
    }

    public WordPieceTokenFilter(TokenStream input, CharArraySet neverSplit, CharArrayMap<Integer> vocabulary, CharSequence unknownToken, int maxInputCharsPerWord) {
        super(input);
        this.tokens = new LinkedList();
        this.neverSplit = neverSplit;
        this.vocabulary = vocabulary;
        this.tokenizedValues = new ArrayList<WordPieceToken>();
        if (!vocabulary.containsKey(unknownToken)) {
            throw new IllegalArgumentException("provided vocabulary does not contain the unknown token of [" + unknownToken.toString() + "]");
        }
        this.unknownToken = unknownToken;
        this.tokenizedUnknown = (Integer)vocabulary.get(unknownToken);
        this.maxInputCharsPerWord = maxInputCharsPerWord;
    }

    public List<WordPieceToken> getTokenizedValues() {
        return this.tokenizedValues;
    }

    public void reset() throws IOException {
        super.reset();
        this.tokens.clear();
        this.tokenizedValues.clear();
        this.current = null;
    }

    public boolean incrementToken() throws IOException {
        if (!this.tokens.isEmpty()) {
            assert (this.current != null);
            WordPieceToken token = this.tokens.removeFirst();
            this.restoreState(this.current);
            this.termAtt.setEmpty().append(token.charSequence());
            this.offsetAtt.setOffset(token.startOffset(), token.endOffset());
            this.posIncAtt.setPositionIncrement(0);
            return true;
        }
        this.current = null;
        if (this.input.incrementToken()) {
            if (this.neverSplit.contains((CharSequence)this.termAtt)) {
                Integer maybeTokenized = (Integer)this.vocabulary.get((CharSequence)this.termAtt);
                this.tokenizedValues.add(new WordPieceToken(this.termAtt.toString(), Objects.requireNonNullElse(maybeTokenized, this.tokenizedUnknown), this.offsetAtt.startOffset(), this.offsetAtt.endOffset()));
                return true;
            }
            if (this.termAtt.length() > this.maxInputCharsPerWord) {
                this.tokenizedValues.add(new WordPieceToken(this.unknownToken, this.tokenizedUnknown, this.offsetAtt.startOffset(), this.offsetAtt.endOffset()));
                this.termAtt.setEmpty().append(this.unknownToken);
                return true;
            }
            boolean isBad = false;
            int start = 0;
            int length = this.termAtt.length();
            while (start < length) {
                int end;
                CharSequence currentValidSubStr = null;
                for (end = length; start < end; --end) {
                    CharSequence subStr = start > 0 ? new MultiCharSequence(List.of(CONTINUATION, this.termAtt.subSequence(start, end))) : this.termAtt.subSequence(start, end);
                    if (!this.vocabulary.containsKey(subStr)) continue;
                    currentValidSubStr = subStr;
                    break;
                }
                if (currentValidSubStr == null) {
                    isBad = true;
                    break;
                }
                int encoding = (Integer)this.vocabulary.get(currentValidSubStr);
                WordPieceToken t = new WordPieceToken(currentValidSubStr, encoding, this.offsetAtt.startOffset(), this.offsetAtt.endOffset());
                this.tokens.add(t);
                start = end;
            }
            if (isBad) {
                this.tokens.clear();
                WordPieceToken t = new WordPieceToken(this.unknownToken, this.tokenizedUnknown, this.offsetAtt.startOffset(), this.offsetAtt.endOffset());
                this.tokenizedValues.add(t);
                this.termAtt.setEmpty().append(this.unknownToken);
            } else {
                this.tokenizedValues.addAll(this.tokens);
                this.current = this.captureState();
                WordPieceToken token = this.tokens.removeFirst();
                this.termAtt.setEmpty().append(token.charSequence());
                this.offsetAtt.setOffset(token.startOffset(), token.endOffset());
            }
            return true;
        }
        return false;
    }

    public static class WordPieceToken
    extends DelimitedToken.Encoded
    implements CharSequence {
        WordPieceToken(CharSequence sequence, int encoding, int startOffset, int endOffset) {
            super(sequence, encoding, startOffset, endOffset);
        }

        @Override
        public int length() {
            return this.charSequence().length();
        }

        @Override
        public char charAt(int index) {
            return this.charSequence().charAt(index);
        }

        @Override
        public CharSequence subSequence(int start, int end) {
            return this.charSequence().subSequence(start, end);
        }

        @Override
        public String toString() {
            return this.charSequence().toString();
        }
    }
}

