/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.optimizer.rules.logical;

import java.util.ArrayList;
import java.util.List;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.expression.predicate.Predicates;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.Filter;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;

/**
 * Optimizer rule that hoists a filter shared by every aggregate function of an ungrouped
 * {@link Aggregate} into a {@link Filter} placed directly below that aggregate.
 *
 * <p>The rule only fires when all of the following hold:
 * <ul>
 *   <li>the aggregate has no groupings;</li>
 *   <li>it has at least one aggregate expression;</li>
 *   <li>every aggregate expression is an {@link Alias} wrapping an {@link AggregateFunction}
 *       that carries a filter;</li>
 *   <li>{@link Predicates#extractCommon} finds a non-null common part across those filters.</li>
 * </ul>
 * When it fires, the common part becomes the condition of a new {@link Filter} over the
 * aggregate's child. Each aggregate function's filter is replaced by the corresponding
 * remainder returned by {@code extractCommon}. In every other case the input plan is
 * returned unchanged.
 */
public final class ExtractAggregateCommonFilter
extends OptimizerRules.OptimizerRule<Aggregate> {
    public ExtractAggregateCommonFilter() {
        super(OptimizerRules.TransformDirection.UP);
    }

    /**
     * Attempts to extract the common filter of all aggregate functions of {@code aggregate}.
     *
     * @param aggregate the aggregate node being visited
     * @return a rewritten aggregate on top of a new {@link Filter}, or {@code aggregate}
     *     itself when the rule does not apply
     */
    @Override
    protected LogicalPlan rule(Aggregate aggregate) {
        List<? extends NamedExpression> aggregates = aggregate.aggregates();
        // Only ungrouped stats are rewritten. With no aggregate expressions there is
        // nothing to extract, so leave the plan untouched.
        if (!aggregate.groupings().isEmpty() || aggregates.isEmpty()) {
            return aggregate;
        }

        // Collect the filter of every aggregate function. Bail out as soon as one
        // expression is not an aliased, filtered aggregate function.
        List<Expression> filters = new ArrayList<Expression>(aggregates.size());
        for (NamedExpression namedExpression : aggregates) {
            if (!(namedExpression instanceof Alias)) {
                return aggregate;
            }
            Expression child = ((Alias) namedExpression).child();
            if (!(child instanceof AggregateFunction)) {
                return aggregate;
            }
            AggregateFunction aggFunction = (AggregateFunction) child;
            if (!aggFunction.hasFilter()) {
                return aggregate;
            }
            filters.add(aggFunction.filter());
        }

        Tuple<Expression, List<Expression>> common = Predicates.extractCommon(filters);
        if (common.v1() == null) {
            // No common filter found.
            return aggregate;
        }

        // NOTE(review): assumes extractCommon returns the remaining filters in the same
        // order as its input, one per input filter. The i-th remainder belongs to the
        // i-th aggregate.
        List<Expression> remainingFilters = common.v2();
        List<NamedExpression> newAggs = new ArrayList<NamedExpression>(aggregates.size());
        for (int i = 0; i < aggregates.size(); ++i) {
            // Safe cast: the loop above verified every entry is an Alias over an AggregateFunction.
            Alias alias = (Alias) aggregates.get(i);
            AggregateFunction newChild = ((AggregateFunction) alias.child()).withFilter(remainingFilters.get(i));
            newAggs.add(alias.replaceChild(newChild));
        }

        // Evaluate the shared filter once, below the aggregate, on the original child.
        Filter commonFilter = new Filter(aggregate.source(), aggregate.child(), common.v1());
        return aggregate.with(commonFilter, aggregate.groupings(), newAggs);
    }
}

