/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.optimizer.rules.logical;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
import org.elasticsearch.index.IndexMode;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.util.Holder;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.FromPartial;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Rate;
import org.elasticsearch.xpack.esql.expression.function.aggregate.ToPartial;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;

/**
 * Optimizer rule that rewrites every {@link Aggregate} of type {@link Aggregate.AggregateType#METRICS}.
 * <p>
 * When at least one aggregate function contains a {@link Rate}, the aggregate is split into two stacked aggregates:
 * <ol>
 *   <li>a first {@code METRICS} aggregate grouped by the {@code _tsid} attribute (plus the time bucket computed over
 *       {@code @timestamp}, if one of the groupings refers to such a bucket). It computes each distinct {@link Rate},
 *       the {@link ToPartial} state of every other aggregate function, and a {@link Values} of every other grouping
 *       key;</li>
 *   <li>a second {@code STANDARD} aggregate, grouped by the original keys, that evaluates the outer expressions over
 *       the first-pass rates and merges the partial states with {@link FromPartial}.</li>
 * </ol>
 * When no {@link Rate} is present, the aggregate is instead turned into a {@code STANDARD} aggregate over
 * {@code STANDARD}-mode relations from which {@code _tsid} has been dropped.
 */
public final class TranslateMetricsAggregate extends OptimizerRules.OptimizerRule<Aggregate> {
    /** Name of the attribute the first pass always groups by; looked up in the {@link EsRelation} output. */
    private static final String TSID_FIELD = "_tsid";
    /** Name of the attribute a grouping {@link Bucket} must be computed over to count as the time bucket. */
    private static final String TIMESTAMP_FIELD = "@timestamp";

    public TranslateMetricsAggregate() {
        super(OptimizerRules.TransformDirection.UP);
    }

    /**
     * Translates {@code METRICS} aggregates and leaves every other aggregate type untouched.
     *
     * @param aggregate the aggregate node being visited
     * @return the translated plan, or {@code aggregate} itself when it is not a {@code METRICS} aggregate
     */
    @Override
    protected LogicalPlan rule(Aggregate aggregate) {
        return aggregate.aggregateType() == Aggregate.AggregateType.METRICS ? translate(aggregate) : aggregate;
    }

    /**
     * Translates a {@code METRICS} aggregate into the two-pass form, or into a plain {@code STANDARD} aggregate
     * when none of its aggregate functions contains a {@link Rate}.
     *
     * @param metrics the {@code METRICS} aggregate to translate
     * @return the rewritten plan
     * @throws IllegalArgumentException if rates are present but the source relations expose no {@code _tsid} or no
     *     {@code @timestamp} attribute, or if more than one time bucket over {@code @timestamp} is found
     * @throws EsqlIllegalArgumentException if a grouping is not an {@link Attribute}
     */
    LogicalPlan translate(Aggregate metrics) {
        Map<Rate, Alias> rateAggs = new HashMap<>();
        List<NamedExpression> firstPassAggs = new ArrayList<>();
        List<NamedExpression> secondPassAggs = new ArrayList<>();
        for (NamedExpression agg : metrics.aggregates()) {
            // Only aliased aggregate functions are split here. Other entries are not carried over explicitly;
            // newAggregate re-appends every grouping's attribute to each pass's output.
            if (agg instanceof Alias alias && alias.child() instanceof AggregateFunction af) {
                Holder<Boolean> changed = new Holder<>(Boolean.FALSE);
                // Hoist every Rate into the first pass (one alias per distinct Rate, keyed by equals/hashCode)
                // and replace it in the outer expression with a reference to that first-pass output.
                Expression outerAgg = af.transformDown(Rate.class, rate -> {
                    changed.set(Boolean.TRUE);
                    Alias rateAgg = rateAggs.computeIfAbsent(rate, k -> {
                        Alias newRateAgg = new Alias(rate.source(), agg.name(), rate);
                        firstPassAggs.add(newRateAgg);
                        return newRateAgg;
                    });
                    return rateAgg.toAttribute();
                });
                if (changed.get()) {
                    secondPassAggs.add(new Alias(alias.source(), alias.name(), outerAgg, agg.id()));
                } else {
                    // No rate inside: emit a partial state per first-pass group and merge it in the second pass.
                    Alias toPartial = new Alias(agg.source(), alias.name(), new ToPartial(agg.source(), af.field(), af));
                    FromPartial fromPartial = new FromPartial(agg.source(), toPartial.toAttribute(), af);
                    firstPassAggs.add(toPartial);
                    secondPassAggs.add(new Alias(alias.source(), alias.name(), fromPartial, alias.id()));
                }
            }
        }
        if (rateAggs.isEmpty()) {
            return toStandardAggregate(metrics);
        }

        Holder<Attribute> tsid = new Holder<>();
        Holder<Attribute> timestamp = new Holder<>();
        metrics.forEachDown(EsRelation.class, r -> {
            for (Attribute attr : r.output()) {
                if (attr.name().equals(TSID_FIELD)) {
                    tsid.set(attr);
                }
                if (attr.name().equals(TIMESTAMP_FIELD)) {
                    timestamp.set(attr);
                }
            }
        });
        if (tsid.get() == null || timestamp.get() == null) {
            throw new IllegalArgumentException("_tsid or @timestamp field are missing from the metrics source");
        }

        // The first pass groups by _tsid (and the time bucket, if any); the second re-groups by the original keys.
        List<Expression> firstPassGroupings = new ArrayList<>();
        firstPassGroupings.add(tsid.get());
        List<Expression> secondPassGroupings = new ArrayList<>();
        NamedExpression timeBucket = findTimeBucket(metrics.child(), timestamp.get());
        for (Expression group : metrics.groupings()) {
            if (!(group instanceof Attribute g)) {
                throw new EsqlIllegalArgumentException("expected named expression for grouping; got " + group);
            }
            NamedExpression newFinalGroup;
            if (timeBucket != null && g.id().equals(timeBucket.id())) {
                // The time bucket stays a grouping key of the first pass.
                newFinalGroup = timeBucket.toAttribute();
                firstPassGroupings.add(newFinalGroup);
            } else {
                // Any other key is carried through the first pass as VALUES(key) under the same id.
                newFinalGroup = new Alias(g.source(), g.name(), new Values(g.source(), g), g.id());
                firstPassAggs.add(newFinalGroup);
            }
            secondPassGroupings.add(new Alias(g.source(), g.name(), newFinalGroup.toAttribute(), g.id()));
        }
        return newAggregate(
            newAggregate(metrics.child(), Aggregate.AggregateType.METRICS, firstPassAggs, firstPassGroupings),
            Aggregate.AggregateType.STANDARD,
            secondPassAggs,
            secondPassGroupings
        );
    }

    /**
     * Finds the named expression in {@code plan} that has a {@link Bucket} over {@code timestamp} as a direct child.
     *
     * @param plan the plan whose expressions are searched (bottom-up)
     * @param timestamp the timestamp attribute the bucket must be computed over
     * @return the matching named expression, or {@code null} if there is none
     * @throws IllegalArgumentException if more than one such expression is found
     */
    private static NamedExpression findTimeBucket(LogicalPlan plan, Attribute timestamp) {
        Holder<NamedExpression> timeBucketRef = new Holder<>();
        plan.forEachExpressionUp(NamedExpression.class, e -> {
            for (Expression child : e.children()) {
                if (child instanceof Bucket bucket && bucket.field().equals(timestamp)) {
                    if (timeBucketRef.get() != null) {
                        throw new IllegalArgumentException("expected at most one time bucket");
                    }
                    timeBucketRef.set(e);
                }
            }
        });
        return timeBucketRef.get();
    }

    /**
     * Rewrites {@code metrics} as a {@code STANDARD} aggregate with the same groupings and aggregates.
     * <p>
     * Every {@link EsRelation} below it is replaced by a {@code STANDARD}-mode relation whose output is the
     * aggregate's input set minus {@code _tsid}.
     */
    private static Aggregate toStandardAggregate(Aggregate metrics) {
        LogicalPlan child = metrics.child().transformDown(EsRelation.class, r -> {
            List<Attribute> attributes = new ArrayList<>(AttributeSet.of(metrics.inputSet()));
            attributes.removeIf(a -> a.name().equals(TSID_FIELD));
            // NOTE(review): @timestamp is kept. The former guard here removed @timestamp only when no attribute was
            // named @timestamp, so it could never remove anything and was dropped. If the intent was to drop
            // @timestamp when the aggregate does not reference it, that check must be made against the
            // aggregate's references instead — TODO confirm.
            return new EsRelation(r.source(), r.indexPattern(), IndexMode.STANDARD, r.indexNameWithModes(), attributes);
        });
        return new Aggregate(metrics.source(), child, Aggregate.AggregateType.STANDARD, metrics.groupings(), metrics.aggregates());
    }

    /**
     * Builds an aggregate over {@code child} whose output is {@code aggregates} followed by the attribute of each
     * grouping.
     */
    private static Aggregate newAggregate(
        LogicalPlan child,
        Aggregate.AggregateType type,
        List<? extends NamedExpression> aggregates,
        List<Expression> groupings
    ) {
        return new Aggregate(
            child.source(),
            child,
            type,
            groupings,
            Stream.concat(aggregates.stream(), groupings.stream().map(Expressions::attribute)).toList()
        );
    }
}

