/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.optimizer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BlockUtils;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.AttributeMap;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.expression.predicate.Predicates;
import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNotNull;
import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules;
import org.elasticsearch.xpack.esql.core.plan.logical.Filter;
import org.elasticsearch.xpack.esql.core.plan.logical.Limit;
import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.core.plan.logical.OrderBy;
import org.elasticsearch.xpack.esql.core.rule.ParameterizedRule;
import org.elasticsearch.xpack.esql.core.rule.ParameterizedRuleExecutor;
import org.elasticsearch.xpack.esql.core.rule.Rule;
import org.elasticsearch.xpack.esql.core.rule.RuleExecutor;
import org.elasticsearch.xpack.esql.core.tree.Node;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.util.CollectionUtils;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce;
import org.elasticsearch.xpack.esql.optimizer.LocalLogicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer;
import org.elasticsearch.xpack.esql.optimizer.rules.PropagateEmptyRelation;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
import org.elasticsearch.xpack.esql.plan.logical.Eval;
import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.plan.logical.RegexExtract;
import org.elasticsearch.xpack.esql.plan.logical.TopN;
import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation;
import org.elasticsearch.xpack.esql.planner.AbstractPhysicalOperationProviders;
import org.elasticsearch.xpack.esql.planner.PlannerUtils;
import org.elasticsearch.xpack.esql.stats.SearchStats;

public class LocalLogicalPlanOptimizer
extends ParameterizedRuleExecutor<LogicalPlan, LocalLogicalOptimizerContext> {
    public LocalLogicalPlanOptimizer(LocalLogicalOptimizerContext localLogicalOptimizerContext) {
        super((Object)localLogicalOptimizerContext);
    }

    protected List<RuleExecutor.Batch<LogicalPlan>> batches() {
        RuleExecutor.Batch local = new RuleExecutor.Batch("Local rewrite", RuleExecutor.Limiter.ONCE, new Rule[]{new ReplaceTopNWithLimitAndSort(), new ReplaceMissingFieldWithNull(), new InferIsNotNull(), new InferNonNullAggConstraint()});
        ArrayList<RuleExecutor.Batch<LogicalPlan>> rules = new ArrayList<RuleExecutor.Batch<LogicalPlan>>();
        rules.add(local);
        rules.addAll(Arrays.asList(LogicalPlanOptimizer.operators(), LogicalPlanOptimizer.cleanup()));
        this.replaceRules(rules);
        return rules;
    }

    private List<RuleExecutor.Batch<LogicalPlan>> replaceRules(List<RuleExecutor.Batch<LogicalPlan>> listOfRules) {
        for (RuleExecutor.Batch<LogicalPlan> batch : listOfRules) {
            Rule[] rules = batch.rules();
            for (int i = 0; i < rules.length; ++i) {
                if (!(rules[i] instanceof PropagateEmptyRelation)) continue;
                rules[i] = new LocalPropagateEmptyRelation();
            }
        }
        return listOfRules;
    }

    public LogicalPlan localOptimize(LogicalPlan plan) {
        return (LogicalPlan)this.execute((Node)plan);
    }

    public static class ReplaceTopNWithLimitAndSort
    extends OptimizerRules.OptimizerRule<TopN> {
        public ReplaceTopNWithLimitAndSort() {
            super(OptimizerRules.TransformDirection.UP);
        }

        protected LogicalPlan rule(TopN plan) {
            return new Limit(plan.source(), plan.limit(), (LogicalPlan)new OrderBy(plan.source(), plan.child(), plan.order()));
        }
    }

    private static class ReplaceMissingFieldWithNull
    extends ParameterizedRule<LogicalPlan, LogicalPlan, LocalLogicalOptimizerContext> {
        private ReplaceMissingFieldWithNull() {
        }

        public LogicalPlan apply(LogicalPlan plan, LocalLogicalOptimizerContext localLogicalOptimizerContext) {
            return (LogicalPlan)plan.transformUp(p -> this.missingToNull((LogicalPlan)p, localLogicalOptimizerContext.searchStats()));
        }

        /*
         * WARNING - void declaration
         */
        private LogicalPlan missingToNull(LogicalPlan plan, SearchStats stats) {
            if (plan instanceof EsRelation || plan instanceof LocalRelation) {
                return plan;
            }
            if (plan instanceof Aggregate) {
                Aggregate a = (Aggregate)((Object)plan);
                return a;
            }
            if (plan instanceof Project) {
                Project project = (Project)((Object)plan);
                List<? extends NamedExpression> projections = project.projections();
                ArrayList<void> newProjections = new ArrayList<void>(projections.size());
                LinkedHashMap nullLiteral = Maps.newLinkedHashMapWithExpectedSize((int)DataType.types().size());
                for (NamedExpression namedExpression : projections) {
                    void var9_9;
                    FieldAttribute f2;
                    if (namedExpression instanceof FieldAttribute && !stats.exists((f2 = (FieldAttribute)namedExpression).fieldName())) {
                        DataType dt = f2.dataType();
                        Alias nullAlias = (Alias)nullLiteral.get(f2.dataType());
                        if (nullAlias == null) {
                            Alias alias = new Alias(f2.source(), f2.name(), null, (Expression)Literal.of((Expression)f2, null), f2.id());
                            nullLiteral.put(dt, alias);
                            Attribute attribute = alias.toAttribute();
                        } else {
                            Alias alias = new Alias(f2.source(), f2.name(), f2.qualifier(), (Expression)nullAlias.toAttribute(), f2.id());
                        }
                    }
                    newProjections.add(var9_9);
                }
                if (nullLiteral.size() > 0) {
                    plan = new Eval(project.source(), project.child(), new ArrayList<Alias>(nullLiteral.values()));
                    plan = new Project(project.source(), (LogicalPlan)plan, (List<? extends NamedExpression>)newProjections);
                }
            } else if (plan instanceof Eval || plan instanceof Filter || plan instanceof OrderBy || plan instanceof RegexExtract || plan instanceof TopN) {
                plan = (LogicalPlan)plan.transformExpressionsOnlyUp(FieldAttribute.class, f -> stats.exists(f.fieldName()) ? f : Literal.of((Expression)f, null));
            }
            return plan;
        }
    }

    static class InferIsNotNull
    extends Rule<LogicalPlan, LogicalPlan> {
        InferIsNotNull() {
        }

        public LogicalPlan apply(LogicalPlan plan) {
            AttributeMap aliases = new AttributeMap();
            plan = (LogicalPlan)plan.transformUp(p -> this.inspectPlan((LogicalPlan)p, (AttributeMap<Expression>)aliases));
            return plan;
        }

        private LogicalPlan inspectPlan(LogicalPlan plan, AttributeMap<Expression> aliases) {
            plan.forEachExpression(Alias.class, a -> aliases.put(a.toAttribute(), (Object)a.child()));
            LogicalPlan newPlan = (LogicalPlan)plan.transformExpressionsOnlyUp(IsNotNull.class, inn -> this.inferNotNullable((IsNotNull)inn, aliases));
            return newPlan;
        }

        private Expression inferNotNullable(IsNotNull inn, AttributeMap<Expression> aliases) {
            IsNotNull result = inn;
            Set<Expression> refs = this.resolveExpressionAsRootAttributes(inn.field(), aliases);
            if (refs.size() > 0) {
                List innList = CollectionUtils.combine(refs.stream().map(r -> new IsNotNull(inn.source(), r)).toList(), (Object[])new Expression[]{inn});
                result = Predicates.combineAnd((List)innList);
            }
            return result;
        }

        protected Set<Expression> resolveExpressionAsRootAttributes(Expression exp, AttributeMap<Expression> aliases) {
            LinkedHashSet<Expression> resolvedExpressions = new LinkedHashSet<Expression>();
            boolean changed = this.doResolve(exp, aliases, resolvedExpressions);
            return changed ? resolvedExpressions : Collections.emptySet();
        }

        private boolean doResolve(Expression exp, AttributeMap<Expression> aliases, Set<Expression> resolvedExpressions) {
            boolean changed = false;
            if (InferIsNotNull.skipExpression(exp)) {
                resolvedExpressions.add(exp);
            } else {
                for (Expression e : exp.references()) {
                    Expression resolved = (Expression)aliases.resolve((Object)e, (Object)e);
                    if (resolved instanceof Attribute) {
                        Attribute a = (Attribute)resolved;
                        if (resolved == e) {
                            resolvedExpressions.add((Expression)a);
                            changed |= resolved != exp;
                            continue;
                        }
                    }
                    changed |= this.doResolve(resolved, aliases, resolvedExpressions);
                }
            }
            return changed;
        }

        private static boolean skipExpression(Expression e) {
            return e instanceof Coalesce;
        }
    }

    static class InferNonNullAggConstraint
    extends ParameterizedOptimizerRule<Aggregate, LocalLogicalOptimizerContext> {
        InferNonNullAggConstraint() {
        }

        @Override
        protected LogicalPlan rule(Aggregate aggregate, LocalLogicalOptimizerContext context) {
            if (aggregate.groupings().size() > 0) {
                return aggregate;
            }
            SearchStats stats = context.searchStats();
            Aggregate plan = aggregate;
            List<? extends NamedExpression> aggs = aggregate.aggregates();
            LinkedHashSet nonNullAggFields = Sets.newLinkedHashSetWithExpectedSize((int)aggs.size());
            for (NamedExpression namedExpression : aggs) {
                FieldAttribute fa;
                Expression expression = Alias.unwrap((Expression)namedExpression);
                if (!(expression instanceof AggregateFunction)) continue;
                AggregateFunction af = (AggregateFunction)expression;
                Expression field = af.field();
                if (!field.foldable() && field instanceof FieldAttribute && stats.isIndexed((fa = (FieldAttribute)field).name())) {
                    nonNullAggFields.add(field);
                    continue;
                }
                return plan;
            }
            if (nonNullAggFields.size() > 0) {
                Expression condition = Predicates.combineOr(nonNullAggFields.stream().map(f -> new IsNotNull(aggregate.source(), f)).toList());
                plan = aggregate.replaceChild((LogicalPlan)new Filter(aggregate.source(), aggregate.child(), condition));
            }
            return plan;
        }
    }

    private static class LocalPropagateEmptyRelation
    extends PropagateEmptyRelation {
        private LocalPropagateEmptyRelation() {
        }

        @Override
        protected void aggOutput(NamedExpression agg, AggregateFunction aggFunc, BlockFactory blockFactory, List<Block> blocks) {
            List<Attribute> output = AbstractPhysicalOperationProviders.intermediateAttributes(List.of(agg), List.of());
            for (Attribute o : output) {
                Count count;
                DataType dataType = o.dataType();
                Boolean value = dataType == DataType.BOOLEAN ? (Comparable<Boolean>)Boolean.valueOf(true) : (Comparable<Boolean>)(aggFunc instanceof Count && (!(count = (Count)aggFunc).foldable() || count.fold() != null) ? Long.valueOf(0L) : null);
                BlockUtils.BuilderWrapper wrapper = BlockUtils.wrapperFor((BlockFactory)blockFactory, (ElementType)PlannerUtils.toElementType(dataType), (int)1);
                wrapper.accept((Object)value);
                blocks.add(wrapper.builder().build());
            }
        }
    }

    static abstract class ParameterizedOptimizerRule<SubPlan extends LogicalPlan, P>
    extends ParameterizedRule<SubPlan, LogicalPlan, P> {
        ParameterizedOptimizerRule() {
        }

        public final LogicalPlan apply(LogicalPlan plan, P context) {
            return (LogicalPlan)plan.transformUp(this.typeToken(), t -> this.rule(t, context));
        }

        protected abstract LogicalPlan rule(SubPlan var1, P var2);
    }
}

