/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.planner;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.elasticsearch.compute.aggregation.IntermediateStateDesc;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.expression.SurrogateExpression;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
import org.elasticsearch.xpack.esql.expression.function.aggregate.CountDistinct;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Max;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Median;
import org.elasticsearch.xpack.esql.expression.function.aggregate.MedianAbsoluteDeviation;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Min;
import org.elasticsearch.xpack.esql.expression.function.aggregate.NumericAggregate;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
import org.elasticsearch.xpack.ql.expression.Alias;
import org.elasticsearch.xpack.ql.expression.AttributeMap;
import org.elasticsearch.xpack.ql.expression.Expression;
import org.elasticsearch.xpack.ql.expression.FieldAttribute;
import org.elasticsearch.xpack.ql.expression.MetadataAttribute;
import org.elasticsearch.xpack.ql.expression.NamedExpression;
import org.elasticsearch.xpack.ql.expression.ReferenceAttribute;
import org.elasticsearch.xpack.ql.expression.function.Function;
import org.elasticsearch.xpack.ql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.ql.tree.Source;
import org.elasticsearch.xpack.ql.type.DataType;
import org.elasticsearch.xpack.ql.type.DataTypes;

/**
 * Maps ES|QL aggregate expressions to {@link ReferenceAttribute}s describing their intermediate state.
 *
 * <p>For every supported aggregate function class, input type fragment and grouping mode, the
 * intermediate state layout is resolved once, at construction time. It is read by reflectively invoking
 * the static {@code intermediateStateDesc()} method of the compute-engine class whose name is built by
 * {@link #determineAggName}. For example, {@code Max} + {@code "Long"} + grouping resolves to
 * {@code org.elasticsearch.compute.aggregation.MaxLongGroupingAggregatorFunction}.
 *
 * <p>Mapping results are memoized per (alias-unwrapped expression, grouping) pair in an unsynchronized
 * {@link HashMap}.
 */
public class AggregateMapper {
    /** Type-name fragments used in generated aggregator class names for numeric inputs. */
    static final List<String> NUMERIC = List.of("Int", "Long", "Double");

    /** All aggregate function classes known to this mapper, before surrogate filtering. */
    static final List<? extends Class<? extends Function>> AGG_FUNCTIONS = List.of(
        Count.class,
        CountDistinct.class,
        Max.class,
        Median.class,
        MedianAbsoluteDeviation.class,
        Min.class,
        Percentile.class,
        Sum.class
    );

    /** Intermediate state layout for every supported (function class, type, grouping) combination. */
    private final Map<AggDef, List<IntermediateStateDesc>> mapper;

    /**
     * Memoized mapping results. The grouping flag is part of the key because grouping and
     * non-grouping aggregators are different classes whose state is looked up separately.
     */
    private final Map<CacheKey, List<? extends NamedExpression>> cache = new HashMap<>();

    /** Creates a mapper for every class in {@link #AGG_FUNCTIONS} that is not a {@link SurrogateExpression}. */
    AggregateMapper() {
        this(AGG_FUNCTIONS.stream().filter(Predicate.not(SurrogateExpression.class::isAssignableFrom)).toList());
    }

    /**
     * Creates a mapper for the given aggregate function classes, eagerly resolving the intermediate
     * state of every type/grouping variant.
     *
     * @param aggregateFunctionClasses classes to support; see {@link #typeAndNames} for which types each expands to
     * @throws EsqlIllegalArgumentException if a generated aggregator class or its
     *         {@code intermediateStateDesc()} method cannot be found or invoked
     */
    AggregateMapper(List<? extends Class<? extends Function>> aggregateFunctionClasses) {
        this.mapper = aggregateFunctionClasses.stream()
            .flatMap(AggregateMapper::typeAndNames)
            .flatMap(AggregateMapper::groupingAndNonGrouping)
            .collect(Collectors.toUnmodifiableMap(aggDef -> aggDef, AggregateMapper::lookupIntermediateState));
    }

    /**
     * Maps aggregates for a non-grouping aggregation, de-duplicating the results by attribute.
     *
     * @param aggregates aggregate expressions, possibly wrapped in {@link Alias}
     * @return the intermediate state attributes of all aggregates
     */
    public List<? extends NamedExpression> mapNonGrouping(List<? extends Expression> aggregates) {
        return doMapping(aggregates, false);
    }

    /**
     * Maps a single aggregate for a non-grouping aggregation.
     *
     * @param aggregate aggregate expression, possibly wrapped in {@link Alias}
     * @return its intermediate state attributes (empty for plain field/metadata/reference attributes)
     */
    public List<? extends NamedExpression> mapNonGrouping(Expression aggregate) {
        return map(aggregate, false).toList();
    }

    /**
     * Maps aggregates for a grouping aggregation, de-duplicating the results by attribute.
     *
     * @param aggregates aggregate expressions, possibly wrapped in {@link Alias}
     * @return the intermediate state attributes of all aggregates
     */
    public List<? extends NamedExpression> mapGrouping(List<? extends Expression> aggregates) {
        return doMapping(aggregates, true);
    }

    /** Maps every aggregate and collapses results that share the same attribute. */
    private List<? extends NamedExpression> doMapping(List<? extends Expression> aggregates, boolean grouping) {
        AttributeMap<NamedExpression> attrToExpressions = new AttributeMap<>();
        aggregates.stream().flatMap(agg -> map(agg, grouping)).forEach(ne -> attrToExpressions.put(ne.toAttribute(), ne));
        return attrToExpressions.values().stream().toList();
    }

    /**
     * Maps a single aggregate for a grouping aggregation.
     *
     * @param aggregate aggregate expression, possibly wrapped in {@link Alias}
     * @return its intermediate state attributes (empty for plain field/metadata/reference attributes)
     */
    public List<? extends NamedExpression> mapGrouping(Expression aggregate) {
        return map(aggregate, true).toList();
    }

    /** Returns the (memoized) intermediate state attributes for one alias-unwrapped aggregate. */
    private Stream<? extends NamedExpression> map(Expression aggregate, boolean grouping) {
        Expression unwrapped = unwrapAlias(aggregate);
        return cache.computeIfAbsent(new CacheKey(unwrapped, grouping), key -> computeEntryForAgg(key.aggregate(), key.grouping()))
            .stream();
    }

    /**
     * Computes the intermediate state attributes for one aggregate.
     *
     * @throws EsqlIllegalArgumentException if the expression is neither an aggregate function nor a
     *         field/metadata/reference attribute, or if no intermediate state is registered for it
     */
    private List<? extends NamedExpression> computeEntryForAgg(Expression aggregate, boolean grouping) {
        AggDef aggDef = aggDefOrNull(aggregate, grouping);
        if (aggDef != null) {
            return isToNE(getNonNull(aggDef)).toList();
        }
        if (aggregate instanceof FieldAttribute || aggregate instanceof MetadataAttribute || aggregate instanceof ReferenceAttribute) {
            // Plain attributes are not aggregates, so they contribute no intermediate state.
            return List.of();
        }
        throw new EsqlIllegalArgumentException("unknown agg: " + aggregate.getClass() + ": " + aggregate);
    }

    /** Looks up the intermediate state for {@code aggDef}, failing if it was not registered at construction. */
    private List<IntermediateStateDesc> getNonNull(AggDef aggDef) {
        List<IntermediateStateDesc> states = mapper.get(aggDef);
        if (states == null) {
            throw new EsqlIllegalArgumentException("Cannot find intermediate state for: " + aggDef);
        }
        return states;
    }

    /**
     * Expands an aggregate function class into (class, type-name fragment) pairs, one per supported input type.
     * {@link NumericAggregate} subclasses get {@link #NUMERIC}; {@link Count} gets a single empty fragment.
     * Every other class is expanded like {@link CountDistinct}; it trips the assertion when assertions are enabled.
     */
    static Stream<Tuple<Class<?>, String>> typeAndNames(Class<?> clazz) {
        List<String> types;
        if (NumericAggregate.class.isAssignableFrom(clazz)) {
            types = NUMERIC;
        } else if (clazz == Count.class) {
            types = List.of("");
        } else {
            assert clazz == CountDistinct.class : "Expected CountDistinct, got: " + clazz;
            types = Stream.concat(NUMERIC.stream(), Stream.of("Boolean", "BytesRef")).toList();
        }
        return types.stream().map(type -> new Tuple<>(clazz, type));
    }

    /** Produces both the grouping and the non-grouping {@link AggDef} for a (class, type) pair. */
    static Stream<AggDef> groupingAndNonGrouping(Tuple<Class<?>, String> tuple) {
        return Stream.of(new AggDef(tuple.v1(), tuple.v2(), true), new AggDef(tuple.v1(), tuple.v2(), false));
    }

    /**
     * Builds the {@link AggDef} for an aggregate function, or returns {@code null} when the expression is
     * not an {@link AggregateFunction}.
     */
    static AggDef aggDefOrNull(Expression aggregate, boolean grouping) {
        if (aggregate instanceof AggregateFunction aggregateFunction) {
            Class<?> aggClass = aggregateFunction.getClass();
            return new AggDef(aggClass, dataTypeToString(aggregateFunction.field().dataType(), aggClass), grouping);
        }
        return null;
    }

    /**
     * Invokes the static {@code intermediateStateDesc()} method of the generated aggregator class for {@code aggDef}.
     *
     * @throws EsqlIllegalArgumentException if the class or method cannot be resolved, or if invocation fails
     */
    @SuppressWarnings("unchecked") // invokeExact is declared to return Object; the handle's type is ()List
    static List<IntermediateStateDesc> lookupIntermediateState(AggDef aggDef) {
        MethodHandle handle = lookup(aggDef.aggClazz(), aggDef.type(), aggDef.grouping());
        try {
            // The cast fixes the call-site type to ()List, which invokeExact requires to match the handle exactly.
            return (List<IntermediateStateDesc>) handle.invokeExact();
        } catch (Throwable t) {
            throw new EsqlIllegalArgumentException(t);
        }
    }

    /** Resolves a handle to the static, no-arg {@code intermediateStateDesc()} method returning {@link List}. */
    static MethodHandle lookup(Class<?> clazz, String type, boolean grouping) {
        try {
            return MethodHandles.lookup()
                .findStatic(Class.forName(determineAggName(clazz, type, grouping)), "intermediateStateDesc", MethodType.methodType(List.class));
        } catch (ClassNotFoundException | IllegalAccessException | NoSuchMethodException e) {
            throw new EsqlIllegalArgumentException(e);
        }
    }

    /** Builds the fully-qualified generated aggregator class name, e.g. {@code ...MaxLongGroupingAggregatorFunction}. */
    static String determineAggName(Class<?> clazz, String type, boolean grouping) {
        StringBuilder sb = new StringBuilder();
        sb.append("org.elasticsearch.compute.aggregation.");
        sb.append(clazz.getSimpleName());
        sb.append(type);
        sb.append(grouping ? "Grouping" : "");
        sb.append("AggregatorFunction");
        return sb.toString();
    }

    /** Converts intermediate state descriptors into {@link ReferenceAttribute}s with matching names and types. */
    static Stream<NamedExpression> isToNE(List<IntermediateStateDesc> intermediateStateDescs) {
        return intermediateStateDescs.stream().map(is -> new ReferenceAttribute(Source.EMPTY, is.name(), toDataType(is.type())));
    }

    /**
     * Maps a compute-engine element type to an ES|QL data type.
     *
     * @throws EsqlIllegalArgumentException for element types without a mapping
     */
    static DataType toDataType(ElementType elementType) {
        // Unqualified enum labels: qualified ones (ElementType.BOOLEAN) only compile on Java 21+.
        return switch (elementType) {
            case BOOLEAN -> DataTypes.BOOLEAN;
            case BYTES_REF -> DataTypes.KEYWORD;
            case INT -> DataTypes.INTEGER;
            case LONG -> DataTypes.LONG;
            case DOUBLE -> DataTypes.DOUBLE;
            default -> throw new EsqlIllegalArgumentException("unsupported agg type: " + elementType);
        };
    }

    /**
     * Maps an input data type to the type-name fragment used in generated aggregator class names.
     * {@link Count} always maps to the empty fragment regardless of input type.
     *
     * @throws EsqlIllegalArgumentException for unsupported input types
     */
    static String dataTypeToString(DataType type, Class<?> aggClass) {
        if (aggClass == Count.class) {
            return "";
        }
        if (type.equals(DataTypes.BOOLEAN)) {
            return "Boolean";
        }
        if (type.equals(DataTypes.INTEGER)) {
            return "Int";
        }
        if (type.equals(DataTypes.LONG) || type.equals(DataTypes.DATETIME)) {
            return "Long";
        }
        if (type.equals(DataTypes.DOUBLE)) {
            return "Double";
        }
        if (type.equals(DataTypes.KEYWORD) || type.equals(DataTypes.IP) || type.equals(DataTypes.TEXT)) {
            return "BytesRef";
        }
        throw new EsqlIllegalArgumentException("illegal agg type: " + type.typeName());
    }

    /** Returns the child of an {@link Alias}, or the expression itself otherwise. */
    static Expression unwrapAlias(Expression expression) {
        if (expression instanceof Alias alias) {
            return alias.child();
        }
        return expression;
    }

    /** Identifies one generated aggregator: function class, type-name fragment and grouping mode. */
    record AggDef(Class<?> aggClazz, String type, boolean grouping) {
    }

    /** Memoization key: an alias-unwrapped aggregate plus the grouping mode it was mapped for. */
    private record CacheKey(Expression aggregate, boolean grouping) {
    }
}

