/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.planner;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Function;
import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.AggregatorMode;
import org.elasticsearch.compute.aggregation.GroupingAggregator;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.operator.AggregationOperator;
import org.elasticsearch.compute.operator.HashAggregationOperator;
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
import org.elasticsearch.xpack.esql.planner.AggregateMapper;
import org.elasticsearch.xpack.esql.planner.Layout;
import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner;
import org.elasticsearch.xpack.esql.planner.PhysicalOperationProviders;
import org.elasticsearch.xpack.esql.planner.PlannerUtils;
import org.elasticsearch.xpack.esql.planner.ToAggregator;
import org.elasticsearch.xpack.ql.InvalidArgumentException;
import org.elasticsearch.xpack.ql.expression.Alias;
import org.elasticsearch.xpack.ql.expression.Attribute;
import org.elasticsearch.xpack.ql.expression.Expression;
import org.elasticsearch.xpack.ql.expression.Expressions;
import org.elasticsearch.xpack.ql.expression.NameId;
import org.elasticsearch.xpack.ql.expression.NamedExpression;
import org.elasticsearch.xpack.ql.expression.function.aggregate.AggregateFunction;

public abstract class AbstractPhysicalOperationProviders
implements PhysicalOperationProviders {
    private final AggregateMapper aggregateMapper = new AggregateMapper();

    @Override
    public final LocalExecutionPlanner.PhysicalOperation groupingPhysicalOperation(AggregateExec aggregateExec, LocalExecutionPlanner.PhysicalOperation source, LocalExecutionPlanner.LocalExecutionPlannerContext context) {
        Layout.Builder layout = new Layout.Builder();
        Object operatorFactory = null;
        AggregateExec.Mode mode = aggregateExec.getMode();
        List<? extends NamedExpression> aggregates = aggregateExec.aggregates();
        Layout sourceLayout = source.layout;
        if (aggregateExec.groupings().isEmpty()) {
            ArrayList aggregatorFactories = new ArrayList();
            if (mode == AggregateExec.Mode.FINAL) {
                layout.append(aggregates);
            } else {
                layout.append(this.aggregateMapper.mapNonGrouping(aggregates));
            }
            this.aggregatesToFactory(aggregates, mode, sourceLayout, false, s -> aggregatorFactories.add(s.supplier.aggregatorFactory(s.mode)));
            if (!aggregatorFactories.isEmpty()) {
                operatorFactory = new AggregationOperator.AggregationOperatorFactory(aggregatorFactories, mode == AggregateExec.Mode.FINAL ? AggregatorMode.FINAL : AggregatorMode.INITIAL);
            }
        } else {
            ArrayList<GroupingAggregator.Factory> aggregatorFactories = new ArrayList<GroupingAggregator.Factory>();
            ArrayList<GroupSpec> groupSpecs = new ArrayList<GroupSpec>(aggregateExec.groupings().size());
            for (Expression expression : aggregateExec.groupings()) {
                Attribute groupAttribute = Expressions.attribute((Expression)expression);
                if (groupAttribute == null) {
                    throw new EsqlIllegalArgumentException("Unexpected non-named expression[{}] as grouping in [{}]", expression, aggregateExec);
                }
                Layout.ChannelSet groupAttributeLayout = new Layout.ChannelSet(new HashSet<NameId>(), groupAttribute.dataType());
                groupAttributeLayout.nameIds().add(groupAttribute.id());
                for (NamedExpression namedExpression : aggregates) {
                    Alias a;
                    Expression expression2;
                    if (!(namedExpression instanceof Alias) || !((expression2 = (a = (Alias)namedExpression).child()) instanceof Attribute)) continue;
                    Attribute attr = (Attribute)expression2;
                    if (groupAttribute.id().equals((Object)attr.id())) {
                        groupAttributeLayout.nameIds().add(a.id());
                        continue;
                    }
                    if (mode != AggregateExec.Mode.PARTIAL || !groupAttribute.semanticEquals((Expression)a.toAttribute())) continue;
                    groupAttribute = attr;
                    break;
                }
                layout.append(groupAttributeLayout);
                Layout.ChannelAndType groupInput = source.layout.get(groupAttribute.id());
                groupSpecs.add(new GroupSpec(groupInput == null ? null : Integer.valueOf(groupInput.channel()), groupAttribute));
            }
            if (mode == AggregateExec.Mode.FINAL) {
                for (NamedExpression namedExpression : aggregates) {
                    if (!(Alias.unwrap((Expression)namedExpression) instanceof AggregateFunction)) continue;
                    layout.append(namedExpression);
                }
            } else {
                layout.append(this.aggregateMapper.mapGrouping(aggregates));
            }
            this.aggregatesToFactory(aggregates, mode, sourceLayout, true, s -> aggregatorFactories.add(s.supplier.groupingAggregatorFactory(s.mode)));
            operatorFactory = groupSpecs.size() == 1 && ((GroupSpec)groupSpecs.get((int)0)).channel == null ? this.ordinalGroupingOperatorFactory(source, aggregateExec, aggregatorFactories, ((GroupSpec)groupSpecs.get((int)0)).attribute, ((GroupSpec)groupSpecs.get(0)).elementType(), context) : new HashAggregationOperator.HashAggregationOperatorFactory(groupSpecs.stream().map(GroupSpec::toHashGroupSpec).toList(), aggregatorFactories, context.pageSize(aggregateExec.estimatedRowSize()));
        }
        if (operatorFactory != null) {
            return source.with((Operator.OperatorFactory)operatorFactory, layout.build());
        }
        throw new EsqlIllegalArgumentException("no operator factory");
    }

    public static List<Attribute> intermediateAttributes(List<? extends NamedExpression> aggregates, List<? extends Expression> groupings) {
        AggregateMapper aggregateMapper = new AggregateMapper();
        ArrayList<Attribute> attrs = new ArrayList();
        if (groupings.isEmpty()) {
            attrs = Expressions.asAttributes(aggregateMapper.mapNonGrouping(aggregates));
        } else {
            for (Expression expression : groupings) {
                Attribute groupAttribute = Expressions.attribute((Expression)expression);
                if (groupAttribute == null) {
                    throw new EsqlIllegalArgumentException("Unexpected non-named expression[{}] as grouping", expression);
                }
                HashSet<NameId> grpAttribIds = new HashSet<NameId>();
                grpAttribIds.add(groupAttribute.id());
                for (NamedExpression namedExpression : aggregates) {
                    Alias a;
                    Expression expression2;
                    if (!(namedExpression instanceof Alias) || !((expression2 = (a = (Alias)namedExpression).child()) instanceof Attribute)) continue;
                    Attribute attr = (Attribute)expression2;
                    if (!groupAttribute.id().equals((Object)attr.id())) continue;
                    grpAttribIds.add(a.id());
                }
                attrs.add(groupAttribute);
            }
            attrs.addAll(Expressions.asAttributes(aggregateMapper.mapGrouping(aggregates)));
        }
        return attrs;
    }

    /*
     * Enabled force condition propagation
     * Lifted jumps to return sites
     */
    private void aggregatesToFactory(List<? extends NamedExpression> aggregates, AggregateExec.Mode mode, Layout layout, boolean grouping, Consumer<AggFunctionSupplierContext> consumer) {
        for (NamedExpression namedExpression : aggregates) {
            List<Object> sourceAttr;
            Alias alias;
            Expression child;
            if (!(namedExpression instanceof Alias) || !((child = (alias = (Alias)namedExpression).child()) instanceof AggregateFunction)) continue;
            AggregateFunction aggregateFunction = (AggregateFunction)child;
            AggregatorMode aggMode = null;
            if (mode == AggregateExec.Mode.PARTIAL) {
                aggMode = AggregatorMode.INITIAL;
                Expression field = aggregateFunction.field();
                if (field.foldable()) {
                    if (!(aggregateFunction instanceof Count)) throw new InvalidArgumentException("Does not support yet aggregations over constants - [{}]", new Object[]{aggregateFunction.sourceText()});
                    sourceAttr = Collections.emptyList();
                } else {
                    Attribute attr2 = Expressions.attribute((Expression)field);
                    if (attr2 == null) {
                        throw new EsqlIllegalArgumentException("Cannot work with target field [{}] for agg [{}]", field.sourceText(), aggregateFunction.sourceText());
                    }
                    sourceAttr = List.of(attr2);
                }
            } else {
                if (mode != AggregateExec.Mode.FINAL) throw new EsqlIllegalArgumentException("illegal aggregation mode");
                aggMode = AggregatorMode.FINAL;
                sourceAttr = grouping ? this.aggregateMapper.mapGrouping((Expression)aggregateFunction) : this.aggregateMapper.mapNonGrouping((Expression)aggregateFunction);
            }
            List aggParams = aggregateFunction.parameters();
            Object[] params = new Object[aggParams.size()];
            for (int i2 = 0; i2 < params.length; ++i2) {
                params[i2] = ((Expression)aggParams.get(i2)).fold();
            }
            List<Integer> inputChannels = sourceAttr.stream().map(attr -> layout.get(attr.id()).channel()).toList();
            if (inputChannels.size() > 0) assert (inputChannels.size() > 0 && inputChannels.stream().allMatch(i -> i >= 0));
            if (!(aggregateFunction instanceof ToAggregator)) throw new EsqlIllegalArgumentException("aggregate functions must extend ToAggregator");
            ToAggregator agg = (ToAggregator)aggregateFunction;
            consumer.accept(new AggFunctionSupplierContext(agg.supplier(inputChannels), aggMode));
        }
    }

    public abstract Operator.OperatorFactory ordinalGroupingOperatorFactory(LocalExecutionPlanner.PhysicalOperation var1, AggregateExec var2, List<GroupingAggregator.Factory> var3, Attribute var4, ElementType var5, LocalExecutionPlanner.LocalExecutionPlannerContext var6);

    private record GroupSpec(Integer channel, Attribute attribute) {
        HashAggregationOperator.GroupSpec toHashGroupSpec() {
            if (this.channel == null) {
                throw new EsqlIllegalArgumentException("planned to use ordinals but tried to use the hash instead");
            }
            return new HashAggregationOperator.GroupSpec(this.channel.intValue(), this.elementType());
        }

        ElementType elementType() {
            return PlannerUtils.toElementType(this.attribute.dataType());
        }
    }

    private record AggFunctionSupplierContext(AggregatorFunctionSupplier supplier, AggregatorMode mode) {
    }
}

