/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.optimizer.rules.logical;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.util.Holder;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.TemporaryNameUtils;
import org.elasticsearch.xpack.esql.plan.logical.Eval;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.Stats;

/**
 * Optimizer rule that pulls nested expressions out of a {@link Stats} node into an {@link Eval}
 * inserted between the {@code Stats} and its original child.
 * <ul>
 *   <li>Aliased groupings are moved into the {@code Eval} and replaced by their attribute.</li>
 *   <li>An aggregate alias whose name matches an aliased grouping is replaced by that grouping's
 *       attribute.</li>
 *   <li>When an aggregate function's field is neither an {@link Attribute} nor foldable, the field
 *       is extracted into a synthetic alias; canonically equal expressions share a single alias.</li>
 *   <li>Grouping functions appearing inside aggregates are replaced by the attribute of the
 *       aliased grouping that produced them.</li>
 * </ul>
 * Aggregate functions whose arguments themselves contain an aggregate function are left untouched.
 * If nothing needs extracting the plan is returned unchanged.
 */
public final class ReplaceStatsNestedExpressionWithEval extends OptimizerRules.OptimizerRule<LogicalPlan> {

    @Override
    protected LogicalPlan rule(LogicalPlan p) {
        if (p instanceof Stats) {
            return rule((Stats) p);
        }
        return p;
    }

    /**
     * Rewrites a single {@link Stats} node.
     *
     * @param aggregate the stats node to rewrite
     * @return the rewritten plan, or {@code aggregate} itself when no expression had to be extracted
     */
    // Deliberately not @Override: this is a private overload, not an override of the parent rule.
    private LogicalPlan rule(Stats aggregate) {
        List<Alias> evals = new ArrayList<>();
        Map<String, Attribute> evalNames = new HashMap<>();
        Map<GroupingFunction, Attribute> groupingAttributes = new HashMap<>();
        List<Expression> newGroupings = new ArrayList<>(aggregate.groupings());
        boolean groupingChanged = false;

        // Process groupings first so aggregates that repeat a grouping alias can be resolved to it.
        for (int i = 0, s = newGroupings.size(); i < s; i++) {
            Expression g = newGroupings.get(i);
            if (!(g instanceof Alias)) {
                continue;
            }
            Alias groupingAlias = (Alias) g;
            groupingChanged = true;
            Attribute attr = groupingAlias.toAttribute();
            // Move the alias into the Eval and reference it by attribute from the Stats.
            evals.add(groupingAlias);
            evalNames.put(groupingAlias.name(), attr);
            newGroupings.set(i, attr);
            if (groupingAlias.child() instanceof GroupingFunction) {
                groupingAttributes.put((GroupingFunction) groupingAlias.child(), attr);
            }
        }

        Holder<Boolean> aggsChanged = new Holder<>(false);
        List<? extends NamedExpression> aggs = aggregate.aggregates();
        List<NamedExpression> newAggs = new ArrayList<>(aggs.size());

        // Canonical expression -> attribute, so identical nested expressions share one eval alias.
        Map<Expression, Attribute> expToAttribute = new HashMap<>();
        for (Alias alias : evals) {
            expToAttribute.put(alias.child().canonical(), alias.toAttribute());
        }

        // Boxed in an array so the lambdas below can advance it; feeds syntheticName().
        int[] counter = new int[] { 0 };
        for (NamedExpression agg : aggs) {
            NamedExpression rewritten = (NamedExpression) agg.transformDown(Alias.class, as -> {
                Expression child = as.child();

                if (child instanceof AggregateFunction) {
                    AggregateFunction aggFunction = (AggregateFunction) child;
                    // Aggregates nested inside aggregates are not rewritten.
                    if (hasNestedAggregate(aggFunction)) {
                        return as;
                    }
                    // Common case: the aggregate already targets a plain attribute.
                    if (aggFunction.field() instanceof Attribute) {
                        return as;
                    }
                }

                // The aggregate repeats an aliased grouping: reference the grouping's attribute.
                Attribute ref = evalNames.get(as.name());
                if (ref != null) {
                    aggsChanged.set(true);
                    return ref;
                }

                // Extract each aggregate function's non-attribute, non-foldable field into the Eval.
                Expression replaced = child.transformUp(AggregateFunction.class, af -> {
                    Expression field = af.field();
                    if (field instanceof Attribute || field.foldable()) {
                        return af;
                    }
                    Attribute attr = expToAttribute.computeIfAbsent(field.canonical(), k -> {
                        Alias newAlias = new Alias(k.source(), syntheticName(k, af, counter[0]++), k, null, true);
                        evals.add(newAlias);
                        return newAlias.toAttribute();
                    });
                    aggsChanged.set(true);
                    // The field is the first child; swap it for the extracted attribute.
                    List<Expression> newChildren = new ArrayList<>(af.children());
                    newChildren.set(0, attr);
                    return af.replaceChildren(newChildren);
                });

                // Point grouping functions at the attribute of their aliased grouping.
                replaced = replaced.transformDown(GroupingFunction.class, gf -> {
                    aggsChanged.set(true);
                    // NOTE(review): yields null when gf has no matching aliased grouping; presumably
                    // plan verification rules that out before this rule runs — confirm.
                    return groupingAttributes.get(gf);
                });

                return as.replaceChild(replaced);
            });
            newAggs.add(rewritten);
        }

        if (evals.isEmpty()) {
            return (LogicalPlan) aggregate;
        }

        List<Expression> groupings = groupingChanged ? newGroupings : aggregate.groupings();
        List<? extends NamedExpression> aggregates = aggsChanged.get() ? newAggs : aggregate.aggregates();
        Eval newEval = new Eval(aggregate.source(), aggregate.child(), evals);
        return (LogicalPlan) aggregate.with(newEval, groupings, aggregates);
    }

    /**
     * Returns {@code true} when any argument of {@code af} contains another {@link AggregateFunction}.
     */
    private static boolean hasNestedAggregate(AggregateFunction af) {
        Holder<Boolean> found = new Holder<>(Boolean.FALSE);
        af.children().forEach(e -> e.forEachDown(AggregateFunction.class, unused -> found.set(Boolean.TRUE)));
        return found.get();
    }

    /**
     * Builds the name of the synthetic alias holding {@code expression}, which was extracted from {@code af}.
     */
    static String syntheticName(Expression expression, AggregateFunction af, int counter) {
        return TemporaryNameUtils.temporaryName(expression, af, counter);
    }
}

