/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.compute.aggregation.blockhash;

import java.io.IOException;
import java.util.HashMap;
import java.util.Objects;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BitArray;
import org.elasticsearch.common.util.BytesRefHash;
import org.elasticsearch.compute.aggregation.AggregatorMode;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.aggregation.SeenGroupIds;
import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.BytesRefVector;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.ReleasableIterator;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.index.analysis.AnalysisRegistry;
import org.elasticsearch.xpack.core.ml.job.config.CategorizationAnalyzerConfig;
import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash;
import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationPartOfSpeechDictionary;
import org.elasticsearch.xpack.ml.aggs.categorization.SerializableTokenListCategory;
import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategory;
import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;

/**
 * {@link BlockHash} that assigns group ids to text values according to the category
 * computed for them by a {@link TokenListCategorizer}.
 *
 * <p>Group id {@link #NULL_ORD} is reserved for null input and for text for which the
 * categorizer produces no category; every real category is exposed as
 * {@code categoryId + 1}.
 *
 * <p>When the {@link AggregatorMode} input is not partial, raw text is read from the input
 * channel and categorized. When the input is partial, the channel is expected to carry the
 * bytes written by {@link #serializeCategorizer()}, which are merged into this instance's
 * categorizer.
 */
public class CategorizeBlockHash
extends BlockHash {
    private static final CategorizationAnalyzerConfig ANALYZER_CONFIG = CategorizationAnalyzerConfig.buildStandardEsqlCategorizationAnalyzer();
    /** Group id used for null values and for text that yields no category. */
    private static final int NULL_ORD = 0;
    private final int channel;
    private final AggregatorMode aggregatorMode;
    private final TokenListCategorizer.CloseableTokenListCategorizer categorizer;
    /** Categorizes raw text; {@code null} when the input is partial (already categorized upstream). */
    private final CategorizeEvaluator evaluator;
    private boolean seenNull = false;

    /**
     * @param blockFactory     factory used for every block and vector built by this hash
     * @param channel          index of the page block that holds the input
     * @param aggregatorMode   decides whether input is raw text or serialized categorizer state,
     *                         and whether {@link #getKeys()} emits intermediate or final keys
     * @param analysisRegistry required when the input is not partial; unused otherwise
     * @throws RuntimeException if the categorization analyzer cannot be created
     */
    CategorizeBlockHash(BlockFactory blockFactory, int channel, AggregatorMode aggregatorMode, AnalysisRegistry analysisRegistry) {
        super(blockFactory);
        this.channel = channel;
        this.aggregatorMode = aggregatorMode;
        // NOTE(review): 2048 and 0.7f are passed through unchanged; presumably the initial hash
        // capacity and the categorizer's similarity threshold - confirm against those classes.
        this.categorizer = new TokenListCategorizer.CloseableTokenListCategorizer(
            new CategorizationBytesRefHash(new BytesRefHash(2048L, blockFactory.bigArrays())),
            CategorizationPartOfSpeechDictionary.getInstance(),
            0.7f
        );
        if (aggregatorMode.isInputPartial()) {
            this.evaluator = null;
            return;
        }
        CategorizationAnalyzer analyzer;
        try {
            Objects.requireNonNull(analysisRegistry);
            analyzer = new CategorizationAnalyzer(analysisRegistry, ANALYZER_CONFIG);
        } catch (Exception e) {
            // The categorizer was already created above; release it before failing.
            this.categorizer.close();
            throw new RuntimeException(e);
        }
        this.evaluator = new CategorizeEvaluator(analyzer);
    }

    /** Returns whether the null group ({@link #NULL_ORD}) has been observed so far. */
    boolean seenNull() {
        return seenNull;
    }

    /**
     * Computes the group ids for {@code page} and hands them to {@code addInput} at offset 0.
     * Nothing is passed on when no group ids were produced (empty intermediate page).
     */
    @Override
    public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
        // try-with-resources skips close() for a null resource, so a null block is safe here.
        try (IntBlock block = add(page)) {
            if (block != null) {
                addInput.add(0, block);
            }
        }
    }

    /**
     * Returns a single key block: the serialized categorizer state when the output is partial,
     * otherwise the category regexes.
     */
    @Override
    public Block[] getKeys() {
        Block keys = aggregatorMode.isOutputPartial() ? buildIntermediateBlock() : buildFinalBlock();
        return new Block[] { keys };
    }

    /** Returns the range of group ids in use, starting at the null group only if it was seen. */
    @Override
    public IntVector nonEmpty() {
        return IntVector.range(firstGroupId(), categorizer.getCategoryCount() + 1, blockFactory);
    }

    /** Returns the same id range as {@link #nonEmpty()}, expressed as a {@link BitArray}. */
    @Override
    public BitArray seenGroupIds(BigArrays bigArrays) {
        return new SeenGroupIds.Range(firstGroupId(), categorizer.getCategoryCount() + 1).seenGroupIds(bigArrays);
    }

    /** Lookup is not supported by this hash. */
    @Override
    public final ReleasableIterator<IntBlock> lookup(Page page, ByteSizeValue targetBlockSize) {
        throw new UnsupportedOperationException();
    }

    /** Releases the evaluator (may be {@code null} for partial input) and the categorizer. */
    @Override
    public void close() {
        Releasables.close(evaluator, categorizer);
    }

    /** Smallest group id in use: {@link #NULL_ORD} once a null has been seen, otherwise the first category id. */
    private int firstGroupId() {
        return seenNull ? NULL_ORD : NULL_ORD + 1;
    }

    /** Dispatches to raw-text or intermediate-state handling depending on the aggregator mode. */
    private IntBlock add(Page page) {
        return aggregatorMode.isInputPartial() ? addIntermediate(page) : addInitial(page);
    }

    /** Categorizes the raw text in the input channel and returns one group id per value. */
    IntBlock addInitial(Page page) {
        return evaluator.eval((BytesRefBlock) page.getBlock(channel));
    }

    /**
     * Merges serialized categorizer state from the input channel into this hash.
     *
     * @return the remapped group ids, a single {@link #NULL_ORD} if the block is all null,
     *         or {@code null} when the page has no positions
     */
    private IntBlock addIntermediate(Page page) {
        if (page.getPositionCount() == 0) {
            return null;
        }
        BytesRefBlock categorizerState = (BytesRefBlock) page.getBlock(channel);
        if (categorizerState.areAllValuesNull()) {
            seenNull = true;
            return blockFactory.newConstantIntBlockWith(NULL_ORD, 1);
        }
        // Only the value at index 0 is read; the same serialized state is written to every
        // position by buildIntermediateBlock().
        return recategorize(categorizerState.getBytesRef(0, new BytesRef()), null).asBlock();
    }

    /**
     * Merges the categories serialized in {@code bytes} into this categorizer and translates
     * group ids from the sender's numbering into this instance's numbering.
     *
     * @param bytes state in the format written by {@link #serializeCategorizer()}
     * @param ids   sender-side group ids to translate, or {@code null} to emit the new id of
     *              every sender-side group in ascending sender-id order
     * @return the translated group ids
     * @throws RuntimeException wrapping any {@link IOException} raised while reading {@code bytes}
     */
    IntVector recategorize(BytesRef bytes, IntVector ids) {
        HashMap<Integer, Integer> idMap = new HashMap<>();
        try (StreamInput in = new BytesArray(bytes).streamInput()) {
            if (in.readBoolean()) {
                // The sender saw nulls: the null group maps onto itself.
                seenNull = true;
                idMap.put(NULL_ORD, NULL_ORD);
            }
            int count = in.readVInt();
            for (int oldCategoryId = 0; oldCategoryId < count; ++oldCategoryId) {
                int newCategoryId = categorizer.mergeWireCategory(new SerializableTokenListCategory(in)).getId();
                // Group ids are category ids shifted by one to leave room for NULL_ORD.
                idMap.put(oldCategoryId + 1, newCategoryId + 1);
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        try (IntVector.Builder newIdsBuilder = blockFactory.newIntVectorBuilder(idMap.size())) {
            if (ids == null) {
                // Keys are contiguous, starting at 0 when the null group is present and at 1 otherwise.
                int idOffset = idMap.containsKey(NULL_ORD) ? 0 : 1;
                for (int i = 0; i < idMap.size(); ++i) {
                    newIdsBuilder.appendInt(idMap.get(i + idOffset));
                }
            } else {
                for (int i = 0; i < ids.getPositionCount(); ++i) {
                    newIdsBuilder.appendInt(idMap.get(ids.getInt(i)));
                }
            }
            return newIdsBuilder.build();
        }
    }

    /**
     * Builds the partial-output key block: the serialized categorizer repeated once per group,
     * or a constant null block when there are no categories.
     */
    private Block buildIntermediateBlock() {
        int nullPositions = seenNull ? 1 : 0;
        int categoryCount = categorizer.getCategoryCount();
        if (categoryCount == 0) {
            return blockFactory.newConstantNullBlock(nullPositions);
        }
        return blockFactory.newConstantBytesRefBlockWith(serializeCategorizer(), categoryCount + nullPositions);
    }

    /**
     * Serializes this hash's state: the seen-null flag, the category count, then each category
     * in the order returned by {@code toCategoriesById()}.
     *
     * @throws RuntimeException wrapping any {@link IOException} raised while writing
     */
    BytesRef serializeCategorizer() {
        try (BytesStreamOutput out = new BytesStreamOutput()) {
            out.writeBoolean(seenNull);
            out.writeVInt(categorizer.getCategoryCount());
            for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
                category.writeTo(out);
            }
            return out.bytes().toBytesRef();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Builds the final key block containing one regex per category, preceded by a null entry
     * when the null group was seen.
     */
    private Block buildFinalBlock() {
        BytesRefBuilder scratch = new BytesRefBuilder();
        int categoryCount = categorizer.getCategoryCount();
        if (seenNull) {
            // One extra slot for the leading null key.
            try (BytesRefBlock.Builder result = blockFactory.newBytesRefBlockBuilder(categoryCount + 1)) {
                result.appendNull();
                for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
                    scratch.copyChars(category.getRegex());
                    result.appendBytesRef(scratch.get());
                    scratch.clear();
                }
                return result.build();
            }
        }
        // No nulls: a vector builder can be used and the result is exposed as a block.
        try (BytesRefVector.Builder result = blockFactory.newBytesRefVectorBuilder(categoryCount)) {
            for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
                scratch.copyChars(category.getRegex());
                result.appendBytesRef(scratch.get());
                scratch.clear();
            }
            return result.build().asBlock();
        }
    }

    /**
     * Turns raw text values into group ids by running them through the enclosing hash's
     * categorizer with the supplied analyzer. Owns the analyzer and closes it on {@link #close()}.
     */
    private final class CategorizeEvaluator
    implements Releasable {
        private final CategorizationAnalyzer analyzer;

        CategorizeEvaluator(CategorizationAnalyzer analyzer) {
            this.analyzer = analyzer;
        }

        /** Categorizes every value of {@code vBlock}, using the vector path when the block has one. */
        IntBlock eval(BytesRefBlock vBlock) {
            BytesRefVector vVector = vBlock.asVector();
            if (vVector == null) {
                return eval(vBlock.getPositionCount(), vBlock);
            }
            return eval(vBlock.getPositionCount(), vVector).asBlock();
        }

        /**
         * Block path: null positions map to {@link #NULL_ORD}; multi-valued positions produce a
         * multi-valued entry with one group id per value.
         */
        IntBlock eval(int positionCount, BytesRefBlock vBlock) {
            try (IntBlock.Builder result = CategorizeBlockHash.this.blockFactory.newIntBlockBuilder(positionCount)) {
                BytesRef vScratch = new BytesRef();
                for (int p = 0; p < positionCount; ++p) {
                    if (vBlock.isNull(p)) {
                        CategorizeBlockHash.this.seenNull = true;
                        result.appendInt(NULL_ORD);
                        continue;
                    }
                    int first = vBlock.getFirstValueIndex(p);
                    int count = vBlock.getValueCount(p);
                    if (count == 1) {
                        result.appendInt(process(vBlock.getBytesRef(first, vScratch)));
                        continue;
                    }
                    int end = first + count;
                    result.beginPositionEntry();
                    for (int i = first; i < end; ++i) {
                        result.appendInt(process(vBlock.getBytesRef(i, vScratch)));
                    }
                    result.endPositionEntry();
                }
                return result.build();
            }
        }

        /** Vector path: exactly one value per position. */
        IntVector eval(int positionCount, BytesRefVector vVector) {
            try (IntVector.FixedBuilder result = CategorizeBlockHash.this.blockFactory.newIntVectorFixedBuilder(positionCount)) {
                BytesRef vScratch = new BytesRef();
                for (int p = 0; p < positionCount; ++p) {
                    result.appendInt(p, process(vVector.getBytesRef(p, vScratch)));
                }
                return result.build();
            }
        }

        /**
         * Categorizes a single value.
         *
         * @return {@code categoryId + 1}, or {@link #NULL_ORD} (and marks null as seen) when the
         *         categorizer returns no category for the text
         */
        int process(BytesRef v) {
            TokenListCategory category = CategorizeBlockHash.this.categorizer.computeCategory(v.utf8ToString(), analyzer);
            if (category == null) {
                CategorizeBlockHash.this.seenNull = true;
                return NULL_ORD;
            }
            return category.getId() + 1;
        }

        /** Closes the analyzer owned by this evaluator. */
        @Override
        public void close() {
            analyzer.close();
        }
    }
}

