/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;

import com.ibm.icu.text.BreakIterator;
import java.io.CharArrayReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.Reader;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.CharBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Locale;
import java.util.Optional;
import java.util.OptionalInt;
import org.apache.lucene.analysis.charfilter.BaseCharFilter;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.CharsRef;
import org.apache.lucene.util.CharsRefBuilder;
import org.apache.lucene.util.UnicodeUtil;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.MultiCharSequence;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizerUtils;

public class PrecompiledCharMapNormalizer
extends BaseCharFilter {
    private final int[] offsets;
    private final byte[] normalizedStrUtf8Bytes;
    private final byte[] reusableCharByteBuffer = new byte[4];
    private final char[] reusableCharDecodeBuffer = new char[8];
    private Reader transformedInput;

    static Config fromBase64EncodedResource(String resourcePath) throws IOException {
        byte[] bytes = Base64.getDecoder().wrap(PrecompiledCharMapNormalizer.class.getResourceAsStream(resourcePath)).readAllBytes();
        int offset = 0;
        int trieSize = ByteBuffer.wrap(bytes, offset, 4).order(ByteOrder.LITTLE_ENDIAN).getInt();
        offset += 4;
        int size = trieSize / 4;
        int[] offsets = new int[size];
        for (int i = 0; i < size; ++i) {
            offsets[i] = ByteBuffer.wrap(bytes, offset, 4).order(ByteOrder.LITTLE_ENDIAN).getInt();
            offset += 4;
        }
        String utf8Str = new String(bytes, offset, bytes.length - offset, StandardCharsets.UTF_8);
        return new Config(offsets, utf8Str);
    }

    public PrecompiledCharMapNormalizer(int[] offsets, String normalizedStr, Reader in) {
        super(in);
        this.offsets = offsets;
        this.normalizedStrUtf8Bytes = normalizedStr.getBytes(StandardCharsets.UTF_8);
    }

    private boolean hasLeaf(int v) {
        return (v >>> 8 & 1) == 1;
    }

    private int label(int v) {
        return v & 0x800000FF;
    }

    private int value(int v) {
        return v & Integer.MAX_VALUE;
    }

    private int offset(int v) {
        return v >>> 10 << ((v & 0x200) >>> 6);
    }

    OptionalInt commonPrefix(byte[] inputBytes) {
        return this.commonPrefix(inputBytes, 0, inputBytes.length);
    }

    private OptionalInt commonPrefix(byte[] inputBytes, int offset, int len) {
        int pos = 0;
        OptionalInt vs = OptionalInt.empty();
        int v = this.offsets[pos];
        pos ^= this.offset(v);
        for (int i = offset; i < offset + len; ++i) {
            int k = inputBytes[i];
            if (k < 0) {
                k += 256;
            }
            if (k == 0) break;
            v = this.offsets[pos ^= k];
            if (this.label(v) != k) {
                return vs;
            }
            pos ^= this.offset(v);
            if (!this.hasLeaf(v)) continue;
            vs = OptionalInt.of(this.value(this.offsets[pos]));
            return vs;
        }
        return vs;
    }

    private Optional<BytesRef> normalizePart(byte[] strBytes, int offset, int len) {
        int firstIndex;
        int secondIndex;
        OptionalInt index = this.commonPrefix(strBytes, offset, len);
        if (index.isEmpty()) {
            return Optional.empty();
        }
        for (secondIndex = firstIndex = index.getAsInt(); secondIndex < this.normalizedStrUtf8Bytes.length && this.normalizedStrUtf8Bytes[secondIndex] != 0; ++secondIndex) {
        }
        if (secondIndex == firstIndex) {
            return Optional.of(new BytesRef(BytesRef.EMPTY_BYTES));
        }
        return Optional.of(new BytesRef(this.normalizedStrUtf8Bytes, firstIndex, secondIndex - firstIndex));
    }

    Reader normalize(CharSequence str) {
        ByteBuffer byteBuffer = StandardCharsets.UTF_8.encode(CharBuffer.wrap(str));
        byte[] strBytes = new byte[byteBuffer.limit()];
        byteBuffer.get(strBytes);
        int[] strCp = str.codePoints().toArray();
        BreakIterator b = BreakIterator.getCharacterInstance((Locale)Locale.ROOT);
        b.setText(str);
        int startIter = b.first();
        int codePointPos = 0;
        CharsRefBuilder strBuilder = new CharsRefBuilder();
        strBuilder.grow(strBytes.length);
        int bytePos = 0;
        int normalizedCharPos = 0;
        int end = b.next();
        while (end != -1) {
            Optional<BytesRef> maybeSubStr;
            int byteLen = 0;
            int numCp = Character.codePointCount(str, startIter, end);
            for (int i = codePointPos; i < numCp + codePointPos; ++i) {
                byteLen += TokenizerUtils.numUtf8Bytes(strCp[i]);
            }
            codePointPos += numCp;
            if (byteLen < 6 && (maybeSubStr = this.normalizePart(strBytes, bytePos, byteLen)).isPresent()) {
                BytesRef subStr = maybeSubStr.get();
                int numChars = UnicodeUtil.UTF8toUTF16((byte[])subStr.bytes, (int)subStr.offset, (int)subStr.length, (char[])this.reusableCharDecodeBuffer);
                normalizedCharPos += numChars;
                if (numChars != end - startIter) {
                    this.addOffCorrectMap(normalizedCharPos, this.getLastCumulativeDiff() + end - startIter - numChars);
                }
                strBuilder.append(this.reusableCharDecodeBuffer, 0, numChars);
                bytePos += byteLen;
            } else {
                int charByteIndex = 0;
                for (int i = startIter; i < end; ++i) {
                    int utf8CharBytes = TokenizerUtils.numUtf8Bytes(str.charAt(i));
                    Optional<BytesRef> maybeSubStr2 = this.normalizePart(strBytes, charByteIndex + bytePos, utf8CharBytes);
                    if (maybeSubStr2.isPresent()) {
                        BytesRef subStr = maybeSubStr2.get();
                        int numChars = UnicodeUtil.UTF8toUTF16((byte[])subStr.bytes, (int)subStr.offset, (int)subStr.length, (char[])this.reusableCharDecodeBuffer);
                        normalizedCharPos += numChars;
                        if (numChars < 1) {
                            this.addOffCorrectMap(normalizedCharPos, this.getLastCumulativeDiff() + 1);
                        } else if (numChars > 1) {
                            this.addOffCorrectMap(normalizedCharPos, this.getLastCumulativeDiff() - 1);
                        }
                        strBuilder.append(this.reusableCharDecodeBuffer, 0, numChars);
                    } else {
                        ++normalizedCharPos;
                        strBuilder.append(str.charAt(i));
                    }
                    charByteIndex += utf8CharBytes;
                }
                bytePos += byteLen;
            }
            startIter = end;
            end = b.next();
        }
        return new CharArrayReader(strBuilder.chars(), 0, strBuilder.length());
    }

    public int read(char[] cbuf, int off, int len) throws IOException {
        if (this.transformedInput == null) {
            this.fill();
        }
        return this.transformedInput.read(cbuf, off, len);
    }

    public int read() throws IOException {
        if (this.transformedInput == null) {
            this.fill();
        }
        return this.transformedInput.read();
    }

    private void fill() throws IOException {
        ArrayList<CharSequence> charArrays = new ArrayList<CharSequence>();
        char[] temp = new char[1024];
        int cnt = this.input.read(temp);
        while (cnt > 0) {
            charArrays.add((CharSequence)new CharsRef(Arrays.copyOfRange(temp, 0, cnt), 0, cnt));
            cnt = this.input.read(temp);
        }
        this.transformedInput = this.normalize(new MultiCharSequence(charArrays));
    }

    record Config(int[] offsets, String utf8str) {
    }
}

