/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.optimizer.rules.logical;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.elasticsearch.index.IndexMode;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.util.CollectionUtils;
import org.elasticsearch.xpack.esql.core.util.Holder;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.FromPartial;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Rate;
import org.elasticsearch.xpack.esql.expression.function.aggregate.TimeSeriesAggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.ToPartial;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.TimeSeriesAggregate;

/**
 * Rewrites a {@link TimeSeriesAggregate} that has no time bucket into a two-phase plan.
 * <ol>
 *   <li>A first-phase {@link TimeSeriesAggregate}:
 *     <ul>
 *       <li>it is grouped by {@code _tsid}, plus the time bucket when a {@link Bucket} over
 *           {@code @timestamp} is found in the child plan;</li>
 *       <li>it evaluates {@code perTimeSeriesAggregation()} of every
 *           {@link TimeSeriesAggregateFunction};</li>
 *       <li>it wraps every other aggregate function in {@link ToPartial};</li>
 *       <li>it carries the remaining groupings through {@link Values}.</li>
 *     </ul>
 *   </li>
 *   <li>A second-phase {@link Aggregate} over the original groupings. It combines the first-phase
 *       columns, either through the rewritten outer aggregate or through {@link FromPartial}.</li>
 * </ol>
 * When the aggregate contains no time-series aggregate function, a plain {@link Aggregate} with the
 * same child, groupings and aggregates is returned instead.
 */
public final class TranslateTimeSeriesAggregate extends OptimizerRules.OptimizerRule<Aggregate> {

    /** Name of the time-series id attribute that is looked up on, or added to, the source relation. */
    private static final String TSID_FIELD = "_tsid";

    /** Name of the timestamp attribute that a time bucket must be computed over. */
    private static final String TIMESTAMP_FIELD = "@timestamp";

    public TranslateTimeSeriesAggregate() {
        super(OptimizerRules.TransformDirection.UP);
    }

    /**
     * Translates {@link TimeSeriesAggregate} nodes that have no time bucket yet.
     * All other aggregates are returned unchanged.
     */
    @Override
    protected LogicalPlan rule(Aggregate aggregate) {
        if (aggregate instanceof TimeSeriesAggregate) {
            TimeSeriesAggregate ts = (TimeSeriesAggregate) aggregate;
            if (ts.timeBucket() == null) {
                return translate(ts);
            }
        }
        return aggregate;
    }

    /**
     * Splits {@code aggregate} into the first-phase / second-phase pair described on the class.
     *
     * @param aggregate the time-series aggregate to translate
     * @return the second-phase {@link Aggregate} on top of the first-phase {@link TimeSeriesAggregate},
     *     or a plain {@link Aggregate} if no time-series aggregate function is present
     * @throws IllegalArgumentException if no {@code @timestamp} attribute is found in any source
     *     relation, or if more than one time bucket over {@code @timestamp} is found
     * @throws EsqlIllegalArgumentException if a grouping is not an {@link Attribute}
     */
    LogicalPlan translate(TimeSeriesAggregate aggregate) {
        // Maps each first-stage function to its alias, so identical inner aggregations are computed only once.
        Map<AggregateFunction, Alias> timeSeriesAggs = new HashMap<>();
        List<NamedExpression> firstPassAggs = new ArrayList<>();
        List<NamedExpression> secondPassAggs = new ArrayList<>();
        Holder<Boolean> hasRateAggregates = new Holder<>(Boolean.FALSE);
        InternalNames internalNames = new InternalNames();

        for (NamedExpression agg : aggregate.aggregates()) {
            // Only aliased aggregate functions are rewritten; other entries are re-added from the groupings.
            if (!(agg instanceof Alias)) {
                continue;
            }
            Alias alias = (Alias) agg;
            if (!(alias.child() instanceof AggregateFunction)) {
                continue;
            }
            AggregateFunction af = (AggregateFunction) alias.child();

            Holder<Boolean> changed = new Holder<>(Boolean.FALSE);
            Expression outerAgg = af.transformDown(TimeSeriesAggregateFunction.class, tsAgg -> {
                changed.set(Boolean.TRUE);
                if (tsAgg instanceof Rate) {
                    hasRateAggregates.set(Boolean.TRUE);
                }
                AggregateFunction firstStageFn = tsAgg.perTimeSeriesAggregation();
                Alias newAgg = timeSeriesAggs.computeIfAbsent(firstStageFn, k -> {
                    Alias firstStageAlias = new Alias(tsAgg.source(), internalNames.next(tsAgg.functionName()), firstStageFn);
                    firstPassAggs.add(firstStageAlias);
                    return firstStageAlias;
                });
                // The outer (second-phase) aggregate reads the first-phase column by reference.
                return newAgg.toAttribute();
            });

            if (changed.get()) {
                secondPassAggs.add(new Alias(alias.source(), alias.name(), outerAgg, alias.id()));
            } else {
                // A regular aggregate function is split into a partial first phase and a final second phase.
                Alias toPartial = new Alias(agg.source(), alias.name(), new ToPartial(agg.source(), af.field(), af));
                FromPartial fromPartial = new FromPartial(agg.source(), toPartial.toAttribute(), af);
                firstPassAggs.add(toPartial);
                secondPassAggs.add(new Alias(alias.source(), alias.name(), fromPartial, alias.id()));
            }
        }

        if (timeSeriesAggs.isEmpty()) {
            // No time-series aggregate function was found: run a regular aggregation instead.
            return new Aggregate(aggregate.source(), aggregate.child(), aggregate.groupings(), aggregate.aggregates());
        }

        Holder<Attribute> tsid = new Holder<>();
        Holder<Attribute> timestamp = new Holder<>();
        aggregate.forEachDown(EsRelation.class, r -> {
            for (Attribute attr : r.output()) {
                if (attr.name().equals(TSID_FIELD)) {
                    tsid.set(attr);
                }
                if (attr.name().equals(TIMESTAMP_FIELD)) {
                    timestamp.set(attr);
                }
            }
        });
        if (tsid.get() == null) {
            // _tsid is not in the relation output; synthesize it and add it to the relation below.
            tsid.set(new MetadataAttribute(aggregate.source(), TSID_FIELD, DataType.KEYWORD, false));
        }
        if (timestamp.get() == null) {
            throw new IllegalArgumentException("_tsid or @timestamp field are missing from the time-series source");
        }

        // Find the (single) named expression in the child plan that directly wraps a Bucket over @timestamp.
        Holder<NamedExpression> timeBucketRef = new Holder<>();
        aggregate.child().forEachExpressionUp(NamedExpression.class, e -> {
            for (Expression child : e.children()) {
                if (child instanceof Bucket && ((Bucket) child).field().equals(timestamp.get())) {
                    if (timeBucketRef.get() != null) {
                        throw new IllegalArgumentException("expected at most one time bucket");
                    }
                    timeBucketRef.set(e);
                }
            }
        });
        NamedExpression timeBucket = timeBucketRef.get();

        List<Expression> firstPassGroupings = new ArrayList<>();
        firstPassGroupings.add(tsid.get());
        List<Expression> secondPassGroupings = new ArrayList<>();
        for (Expression group : aggregate.groupings()) {
            if (!(group instanceof Attribute)) {
                throw new EsqlIllegalArgumentException("expected named expression for grouping; got " + group);
            }
            Attribute g = (Attribute) group;
            NamedExpression newFinalGroup;
            if (timeBucket != null && g.id().equals(timeBucket.id())) {
                // The time bucket remains a first-phase grouping next to _tsid.
                newFinalGroup = timeBucket.toAttribute();
                firstPassGroupings.add(newFinalGroup);
            } else {
                // Every other grouping is carried through the first phase as VALUES(g).
                newFinalGroup = new Alias(g.source(), g.name(), new Values(g.source(), g), g.id());
                firstPassAggs.add(newFinalGroup);
            }
            secondPassGroupings.add(new Alias(g.source(), g.name(), newFinalGroup.toAttribute(), g.id()));
        }

        // Loop-invariant: the rate flag no longer changes once the aggregates have been rewritten.
        boolean keepIndexMode = hasRateAggregates.get();
        LogicalPlan newChild = aggregate.child().transformUp(EsRelation.class, r -> {
            // Relations keep their own index mode only when a rate aggregation is present.
            IndexMode indexMode = keepIndexMode ? r.indexMode() : IndexMode.STANDARD;
            List<Attribute> output;
            if (r.output().contains(tsid.get())) {
                output = r.output();
            } else {
                output = CollectionUtils.combine(r.output(), tsid.get());
            }
            return new EsRelation(r.source(), r.indexPattern(), indexMode, r.indexNameWithModes(), output);
        });

        TimeSeriesAggregate firstPhase = new TimeSeriesAggregate(
            newChild.source(),
            newChild,
            firstPassGroupings,
            mergeExpressions(firstPassAggs, firstPassGroupings),
            (Bucket) Alias.unwrap(timeBucket)
        );
        return new Aggregate(firstPhase.source(), firstPhase, secondPassGroupings, mergeExpressions(secondPassAggs, secondPassGroupings));
    }

    /**
     * Returns a new list containing {@code aggregates} followed by the attribute of every grouping.
     * The input lists are not modified.
     */
    private static List<? extends NamedExpression> mergeExpressions(List<? extends NamedExpression> aggregates, List<Expression> groupings) {
        List<NamedExpression> merged = new ArrayList<>(aggregates.size() + groupings.size());
        merged.addAll(aggregates);
        for (Expression g : groupings) {
            merged.add(Expressions.attribute(g));
        }
        return merged;
    }

    /**
     * Generates internal column names of the form {@code prefix_$N}.
     * N counts up from 1 independently for each prefix.
     */
    private static class InternalNames {
        final Map<String, Integer> next = new HashMap<>();

        private InternalNames() {
        }

        String next(String prefix) {
            int id = this.next.merge(prefix, 1, Integer::sum);
            return prefix + "_$" + id;
        }
    }
}

