/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.compute.operator;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.IntFunction;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.PriorityQueue;
import org.elasticsearch.common.CheckedSupplier;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BitArray;
import org.elasticsearch.compute.Describable;
import org.elasticsearch.compute.aggregation.GroupingAggregator;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.aggregation.SeenGroupIds;
import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.DocBlock;
import org.elasticsearch.compute.data.DocVector;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.lucene.ValuesSourceReaderOperator;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.HashAggregationOperator;
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.index.mapper.BlockLoader;

/**
 * Groups input rows by the values of a single field ({@code groupingField}) and feeds them to a list of grouping
 * aggregators.
 * <p>
 * A page takes the <em>ordinals path</em> when both of these hold:
 * <ul>
 *   <li>its documents come from a single segment in non-decreasing doc order;</li>
 *   <li>the shard's {@link BlockLoader} supports ordinals.</li>
 * </ul>
 * Such pages are aggregated per (shard, segment) using the segment's ordinals (shifted by one, with {@code 0}
 * reserved for documents without a value) as group ids.
 * <p>
 * Every other page goes to a lazily created {@link ValuesAggregator}. It loads the field values through a
 * {@link ValuesSourceReaderOperator} and aggregates them with a {@link HashAggregationOperator}.
 * <p>
 * After {@link #finish()}, {@link #getOutput()} returns the values-based page first (if there is one). A later
 * call returns a single page that merges all per-segment ordinal results by term.
 */
public class OrdinalsGroupingOperator implements Operator {
    private final IntFunction<BlockLoader> blockLoaders;
    private final List<ValuesSourceReaderOperator.ShardContext> shardContexts;
    private final int docChannel;
    private final String groupingField;
    private final List<GroupingAggregator.Factory> aggregatorFactories;
    private final ElementType groupingElementType;
    /** One aggregator per (shard, segment) pair seen on the ordinals path. */
    private final Map<SegmentID, OrdinalSegmentAggregator> ordinalAggregators;
    private final DriverContext driverContext;
    private boolean finished = false;
    private final int maxPageSize;
    /** Created on the first page that cannot use ordinals; {@code null} until then and after its output is taken. */
    private ValuesAggregator valuesAggregator;

    /**
     * @param blockLoaders        maps a shard index to the {@link BlockLoader} for the grouping field
     * @param shardContexts       per-shard readers, indexed by shard index
     * @param groupingElementType element type of the grouping key
     * @param docChannel          channel holding the {@link DocBlock} of each input page
     * @param groupingField       name of the field to group by
     * @param aggregatorFactories factories for the aggregators to run; must not be {@code null}
     * @param maxPageSize         maximum page size handed to the fallback hash grouping
     * @param driverContext       context providing the block factory and big arrays
     */
    public OrdinalsGroupingOperator(
        IntFunction<BlockLoader> blockLoaders,
        List<ValuesSourceReaderOperator.ShardContext> shardContexts,
        ElementType groupingElementType,
        int docChannel,
        String groupingField,
        List<GroupingAggregator.Factory> aggregatorFactories,
        int maxPageSize,
        DriverContext driverContext
    ) {
        Objects.requireNonNull(aggregatorFactories);
        this.blockLoaders = blockLoaders;
        this.shardContexts = shardContexts;
        this.groupingElementType = groupingElementType;
        this.docChannel = docChannel;
        this.groupingField = groupingField;
        this.aggregatorFactories = aggregatorFactories;
        this.ordinalAggregators = new HashMap<>();
        this.maxPageSize = maxPageSize;
        this.driverContext = driverContext;
    }

    @Override
    public boolean needsInput() {
        return !finished;
    }

    /**
     * Routes a page either to the per-segment ordinals aggregator or to the fallback values aggregator.
     * <p>
     * Ownership of the page passes to whichever aggregator receives it. If anything fails before the hand-off,
     * the page's blocks are released here.
     *
     * @throws IllegalArgumentException if the operator has already been finished
     * @throws UncheckedIOException     if creating a segment aggregator fails with an {@link IOException}
     */
    @Override
    public void addInput(Page page) {
        checkState(needsInput(), "Operator is already finishing");
        Objects.requireNonNull(page, "page is null");
        final DocVector docVector = ((DocBlock) page.getBlock(docChannel)).asVector();
        final int shardIndex = docVector.shards().getInt(0);
        final BlockLoader blockLoader = blockLoaders.apply(shardIndex);
        boolean pagePassed = false;
        try {
            if (docVector.singleSegmentNonDecreasing() && blockLoader.supportsOrdinals()) {
                final IntVector segmentIndexVector = docVector.segments();
                assert segmentIndexVector.isConstant();
                final OrdinalSegmentAggregator ordinalAggregator = ordinalAggregators.computeIfAbsent(
                    new SegmentID(shardIndex, segmentIndexVector.getInt(0)),
                    k -> {
                        try {
                            return new OrdinalSegmentAggregator(
                                driverContext.blockFactory(),
                                this::createGroupingAggregators,
                                () -> {
                                    LeafReaderContext leaf = shardContexts.get(k.shardIndex())
                                        .reader()
                                        .leaves()
                                        .get(k.segmentIndex());
                                    return blockLoader.ordinals(leaf);
                                },
                                driverContext.bigArrays()
                            );
                        } catch (IOException e) {
                            throw new UncheckedIOException(e);
                        }
                    }
                );
                pagePassed = true;
                ordinalAggregator.addInput(docVector.docs(), page);
            } else {
                if (valuesAggregator == null) {
                    // The grouping key is read from the first channel after the input page's blocks
                    // (presumably where ValuesSourceReaderOperator appends the loaded field).
                    int channelIndex = page.getBlockCount();
                    valuesAggregator = new ValuesAggregator(
                        blockLoaders,
                        shardContexts,
                        groupingElementType,
                        docChannel,
                        groupingField,
                        channelIndex,
                        aggregatorFactories,
                        maxPageSize,
                        driverContext
                    );
                }
                pagePassed = true;
                valuesAggregator.addInput(page);
            }
        } finally {
            if (!pagePassed) {
                Releasables.closeExpectNoException(page::releaseBlocks);
            }
        }
    }

    /**
     * Builds one fresh {@link GroupingAggregator} per configured factory.
     * <p>
     * If any factory fails, the aggregators created so far are released before the exception propagates.
     */
    private List<GroupingAggregator> createGroupingAggregators() {
        boolean success = false;
        List<GroupingAggregator> aggregators = new ArrayList<>(aggregatorFactories.size());
        try {
            for (GroupingAggregator.Factory aggregatorFactory : aggregatorFactories) {
                aggregators.add(aggregatorFactory.apply(driverContext));
            }
            success = true;
            return aggregators;
        } finally {
            if (!success) {
                Releasables.close(aggregators);
            }
        }
    }

    /**
     * Returns {@code null} until {@link #finish()} has been called. After that it returns:
     * <ol>
     *   <li>the values-based page, if a {@link ValuesAggregator} exists;</li>
     *   <li>on a later call, the merged ordinals page, if any segment aggregators exist;</li>
     *   <li>otherwise {@code null}.</li>
     * </ol>
     * Each source is released after it has produced its page, even when producing the page fails.
     *
     * @throws UncheckedIOException if reading the segment doc values fails while merging
     */
    @Override
    public Page getOutput() {
        if (!finished) {
            return null;
        }
        if (valuesAggregator != null) {
            try {
                return valuesAggregator.getOutput();
            } finally {
                final ValuesAggregator aggregator = valuesAggregator;
                valuesAggregator = null;
                Releasables.close(aggregator);
            }
        }
        if (!ordinalAggregators.isEmpty()) {
            try {
                return mergeOrdinalsSegmentResults();
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            } finally {
                // Runs for every outcome (including runtime exceptions). The map is cleared even if
                // releasing one of its values throws.
                Releasables.close(() -> Releasables.close(ordinalAggregators.values()), ordinalAggregators::clear);
            }
        }
        return null;
    }

    @Override
    public void finish() {
        finished = true;
        if (valuesAggregator != null) {
            valuesAggregator.finish();
        }
    }

    /**
     * Merges every per-segment result into a single page.
     * <p>
     * Block 0 of the page holds the keys. If any segment saw documents without a value, position 0 is a null key.
     * Segment iterators are drained through a priority queue ordered by their current term. Consecutive equal
     * terms, across segments, are folded into the same group position. The remaining blocks are each
     * aggregator's evaluated output, in factory order.
     */
    private Page mergeOrdinalsSegmentResults() throws IOException {
        final PriorityQueue<AggregatedResultIterator> pq = new PriorityQueue<>(ordinalAggregators.size()) {
            @Override
            protected boolean lessThan(AggregatedResultIterator a, AggregatedResultIterator b) {
                return a.currentTerm.compareTo(b.currentTerm) < 0;
            }
        };
        final List<GroupingAggregator> aggregators = createGroupingAggregators();
        try {
            // Group 0 of every segment aggregator collects documents without a value.
            boolean seenNulls = false;
            for (OrdinalSegmentAggregator agg : ordinalAggregators.values()) {
                if (agg.seenNulls()) {
                    seenNulls = true;
                    for (int i = 0; i < aggregators.size(); i++) {
                        aggregators.get(i).addIntermediateRow(0, agg.aggregators.get(i), 0);
                    }
                }
            }
            for (OrdinalSegmentAggregator agg : ordinalAggregators.values()) {
                final AggregatedResultIterator it = agg.getResultIterator();
                if (it.next()) {
                    pq.add(it);
                }
            }
            final int startPosition = seenNulls ? 0 : -1;
            int position = startPosition;
            final BytesRefBuilder lastTerm = new BytesRefBuilder();
            final Block[] blocks;
            final int[] aggBlockCounts;
            try (BytesRefBlock.Builder keysBuilder = driverContext.blockFactory().newBytesRefBlockBuilder(1)) {
                if (seenNulls) {
                    keysBuilder.appendNull();
                }
                while (pq.size() > 0) {
                    final AggregatedResultIterator top = pq.top();
                    // Start a new group on the first term or whenever the term changes.
                    if (position == startPosition || !lastTerm.get().equals(top.currentTerm)) {
                        position++;
                        lastTerm.copyBytes(top.currentTerm);
                        keysBuilder.appendBytesRef(top.currentTerm);
                    }
                    for (int i = 0; i < top.aggregators.size(); i++) {
                        aggregators.get(i).addIntermediateRow(position, top.aggregators.get(i), top.currentPosition());
                    }
                    if (top.next()) {
                        pq.updateTop();
                    } else {
                        pq.pop();
                    }
                }
                aggBlockCounts = aggregators.stream().mapToInt(GroupingAggregator::evaluateBlockCount).toArray();
                blocks = new Block[1 + Arrays.stream(aggBlockCounts).sum()];
                blocks[0] = keysBuilder.build();
            }
            boolean success = false;
            try {
                try (IntVector selected = IntVector.range(0, blocks[0].getPositionCount(), driverContext.blockFactory())) {
                    int offset = 1;
                    for (int i = 0; i < aggregators.size(); i++) {
                        aggregators.get(i).evaluate(blocks, offset, selected, driverContext);
                        offset += aggBlockCounts[i];
                    }
                }
                success = true;
                return new Page(blocks);
            } finally {
                if (!success) {
                    Releasables.closeExpectNoException(blocks);
                }
            }
        } finally {
            Releasables.close(aggregators);
        }
    }

    @Override
    public boolean isFinished() {
        return finished && valuesAggregator == null && ordinalAggregators.isEmpty();
    }

    @Override
    public void close() {
        Releasables.close(() -> Releasables.close(ordinalAggregators.values()), valuesAggregator);
    }

    /** Throws {@link IllegalArgumentException} (kept for compatibility) when {@code condition} is false. */
    private static void checkState(boolean condition, String msg) {
        if (!condition) {
            throw new IllegalArgumentException(msg);
        }
    }

    @Override
    public String toString() {
        return this.getClass().getSimpleName() + "[aggregators=" + aggregatorFactories + "]";
    }

    /** Key identifying one segment of one shard. */
    record SegmentID(int shardIndex, int segmentIndex) {
    }

    /**
     * Aggregates the pages of one segment, using {@code ordinal + 1} as the group id (0 = no value).
     * It also records in {@link #visitedOrds} which group ids were actually seen.
     */
    static final class OrdinalSegmentAggregator implements Releasable, SeenGroupIds {
        private final BlockFactory blockFactory;
        private final List<GroupingAggregator> aggregators;
        private final CheckedSupplier<SortedSetDocValues, IOException> docValuesSupplier;
        private final BitArray visitedOrds;
        private BlockOrdinalsReader currentReader;

        /**
         * Opens the segment's doc values and allocates the seen-ids bit set and the aggregators.
         * On failure, whatever was allocated so far is released before the exception propagates.
         */
        OrdinalSegmentAggregator(
            BlockFactory blockFactory,
            Supplier<List<GroupingAggregator>> aggregatorsSupplier,
            CheckedSupplier<SortedSetDocValues, IOException> docValuesSupplier,
            BigArrays bigArrays
        ) throws IOException {
            boolean success = false;
            List<GroupingAggregator> groupingAggregators = null;
            BitArray bitArray = null;
            try {
                final SortedSetDocValues sortedSetDocValues = docValuesSupplier.get();
                bitArray = new BitArray(sortedSetDocValues.getValueCount(), bigArrays);
                groupingAggregators = aggregatorsSupplier.get();
                this.currentReader = new BlockOrdinalsReader(sortedSetDocValues, blockFactory);
                this.blockFactory = blockFactory;
                this.docValuesSupplier = docValuesSupplier;
                this.aggregators = groupingAggregators;
                this.visitedOrds = bitArray;
                success = true;
            } finally {
                if (!success) {
                    if (bitArray != null) {
                        Releasables.close(bitArray);
                    }
                    if (groupingAggregators != null) {
                        Releasables.close(groupingAggregators);
                    }
                }
            }
        }

        /**
         * Reads the ordinals of {@code docs}, marks them as visited and feeds them to every aggregator.
         * The page's blocks are always released.
         */
        void addInput(IntVector docs, Page page) {
            try {
                final GroupingAggregatorFunction.AddInput[] prepared = new GroupingAggregatorFunction.AddInput[aggregators.size()];
                for (int i = 0; i < prepared.length; i++) {
                    prepared[i] = aggregators.get(i).prepareProcessPage(this, page);
                }
                // Doc values only move forward; a new reader is needed on another thread or when docs go backwards.
                if (!BlockOrdinalsReader.canReuse(currentReader, docs.getInt(0))) {
                    currentReader = new BlockOrdinalsReader(docValuesSupplier.get(), blockFactory);
                }
                try (IntBlock ordinals = currentReader.readOrdinalsAdded1(docs)) {
                    for (int p = 0; p < ordinals.getPositionCount(); p++) {
                        final int start = ordinals.getFirstValueIndex(p);
                        final int end = start + ordinals.getValueCount(p);
                        for (int i = start; i < end; i++) {
                            final long ord = ordinals.getInt(i);
                            visitedOrds.set(ord);
                        }
                    }
                    for (GroupingAggregatorFunction.AddInput addInput : prepared) {
                        addInput.add(0, ordinals);
                    }
                }
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            } finally {
                page.releaseBlocks();
            }
        }

        /** Returns an iterator over the visited non-null group ids, backed by a fresh doc-values instance. */
        AggregatedResultIterator getResultIterator() throws IOException {
            return new AggregatedResultIterator(aggregators, visitedOrds, docValuesSupplier.get());
        }

        /** Whether any document without a value (group id 0) was seen. */
        boolean seenNulls() {
            return visitedOrds.get(0L);
        }

        @Override
        public BitArray seenGroupIds(BigArrays bigArrays) {
            final BitArray seen = new BitArray(0L, bigArrays);
            boolean success = false;
            try {
                // `or` grows `seen` as needed to cover visitedOrds.
                seen.or(visitedOrds);
                success = true;
                return seen;
            } finally {
                if (!success) {
                    Releasables.close(seen);
                }
            }
        }

        @Override
        public void close() {
            Releasables.close(visitedOrds, () -> Releasables.close(aggregators));
        }
    }

    /** Fallback path: loads field values and groups them by hash. */
    private static class ValuesAggregator implements Releasable {
        private final ValuesSourceReaderOperator extractor;
        private final HashAggregationOperator aggregator;

        ValuesAggregator(
            IntFunction<BlockLoader> blockLoaders,
            List<ValuesSourceReaderOperator.ShardContext> shardContexts,
            ElementType groupingElementType,
            int docChannel,
            String groupingField,
            int channelIndex,
            List<GroupingAggregator.Factory> aggregatorFactories,
            int maxPageSize,
            DriverContext driverContext
        ) {
            this.extractor = new ValuesSourceReaderOperator(
                driverContext.blockFactory(),
                List.of(new ValuesSourceReaderOperator.FieldInfo(groupingField, groupingElementType, blockLoaders)),
                shardContexts,
                docChannel
            );
            boolean success = false;
            try {
                this.aggregator = new HashAggregationOperator(
                    aggregatorFactories,
                    () -> BlockHash.build(
                        List.of(new HashAggregationOperator.GroupSpec(channelIndex, groupingElementType)),
                        driverContext.blockFactory(),
                        maxPageSize,
                        false
                    ),
                    driverContext
                );
                success = true;
            } finally {
                // Don't leak the extractor if the aggregator cannot be built.
                if (!success) {
                    Releasables.close(extractor);
                }
            }
        }

        void addInput(Page page) {
            extractor.addInput(page);
            final Page out = extractor.getOutput();
            if (out != null) {
                aggregator.addInput(out);
            }
        }

        void finish() {
            aggregator.finish();
        }

        Page getOutput() {
            return aggregator.getOutput();
        }

        @Override
        public void close() {
            Releasables.close(extractor, aggregator);
        }
    }

    /**
     * Walks the visited non-null group ids of one segment in increasing order. For each one it exposes the
     * corresponding term ({@code lookupOrd(groupId - 1)}).
     */
    private static class AggregatedResultIterator {
        private BytesRef currentTerm;
        private long currentOrd = 0L;
        private final List<GroupingAggregator> aggregators;
        private final BitArray ords;
        private final SortedSetDocValues dv;

        AggregatedResultIterator(List<GroupingAggregator> aggregators, BitArray ords, SortedSetDocValues dv) {
            this.aggregators = aggregators;
            this.ords = ords;
            this.dv = dv;
        }

        /** Group id (ordinal + 1) of the current term. */
        int currentPosition() {
            assert currentOrd != Long.MAX_VALUE : "Must not read position when iterator is exhausted";
            return Math.toIntExact(currentOrd);
        }

        /** Advances to the next visited group id; returns {@code false} (and clears the term) when exhausted. */
        boolean next() throws IOException {
            currentOrd = ords.nextSetBit(currentOrd + 1L);
            assert currentOrd > 0L : currentOrd;
            if (currentOrd < Long.MAX_VALUE) {
                currentTerm = dv.lookupOrd(currentOrd - 1L);
                return true;
            }
            currentTerm = null;
            return false;
        }
    }

    /** Reads ordinals for a batch of docs, emitting {@code ord + 1} per value and {@code 0} for docs without one. */
    static final class BlockOrdinalsReader {
        private final SortedSetDocValues sortedSetDocValues;
        private final Thread creationThread;
        private final BlockFactory blockFactory;

        BlockOrdinalsReader(SortedSetDocValues sortedSetDocValues, BlockFactory blockFactory) {
            this.sortedSetDocValues = sortedSetDocValues;
            this.blockFactory = blockFactory;
            this.creationThread = Thread.currentThread();
        }

        IntBlock readOrdinalsAdded1(IntVector docs) throws IOException {
            final int positionCount = docs.getPositionCount();
            try (IntBlock.Builder builder = blockFactory.newIntBlockBuilder(positionCount)) {
                for (int p = 0; p < positionCount; p++) {
                    final int doc = docs.getInt(p);
                    if (!sortedSetDocValues.advanceExact(doc)) {
                        builder.appendInt(0);
                        continue;
                    }
                    final int count = sortedSetDocValues.docValueCount();
                    if (count == 1) {
                        builder.appendInt(Math.toIntExact(sortedSetDocValues.nextOrd() + 1L));
                        continue;
                    }
                    builder.beginPositionEntry();
                    for (int i = 0; i < count; i++) {
                        builder.appendInt(Math.toIntExact(sortedSetDocValues.nextOrd() + 1L));
                    }
                    builder.endPositionEntry();
                }
                return builder.build();
            }
        }

        int docID() {
            return sortedSetDocValues.docID();
        }

        /**
         * A reader can be reused only if all of these hold: it exists, it was created on the current thread,
         * and it has not advanced past {@code startingDocID}.
         */
        static boolean canReuse(BlockOrdinalsReader reader, int startingDocID) {
            return reader != null && reader.creationThread == Thread.currentThread() && reader.docID() <= startingDocID;
        }
    }

    /** Factory creating one {@link OrdinalsGroupingOperator} per driver. */
    public record OrdinalsGroupingOperatorFactory(
        IntFunction<BlockLoader> blockLoaders,
        List<ValuesSourceReaderOperator.ShardContext> shardContexts,
        ElementType groupingElementType,
        int docChannel,
        String groupingField,
        List<GroupingAggregator.Factory> aggregators,
        int maxPageSize
    ) implements Operator.OperatorFactory {

        @Override
        public Operator get(DriverContext driverContext) {
            return new OrdinalsGroupingOperator(
                blockLoaders,
                shardContexts,
                groupingElementType,
                docChannel,
                groupingField,
                aggregators,
                maxPageSize,
                driverContext
            );
        }

        @Override
        public String describe() {
            return "OrdinalsGroupingOperator(aggs = "
                + aggregators.stream().map(Describable::describe).collect(Collectors.joining(", "))
                + ")";
        }
    }
}

