/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.core.optimizer;

import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.function.BiFunction;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.Nullability;
import org.elasticsearch.xpack.esql.core.expression.function.Function;
import org.elasticsearch.xpack.esql.core.expression.function.scalar.ScalarFunction;
import org.elasticsearch.xpack.esql.core.expression.function.scalar.SurrogateFunction;
import org.elasticsearch.xpack.esql.core.expression.predicate.BinaryPredicate;
import org.elasticsearch.xpack.esql.core.expression.predicate.Negatable;
import org.elasticsearch.xpack.esql.core.expression.predicate.Predicates;
import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And;
import org.elasticsearch.xpack.esql.core.expression.predicate.logical.BinaryLogic;
import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Not;
import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or;
import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNotNull;
import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNull;
import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.BinaryComparison;
import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.Equals;
import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.In;
import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.NotEquals;
import org.elasticsearch.xpack.esql.core.plan.logical.Filter;
import org.elasticsearch.xpack.esql.core.plan.logical.Limit;
import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.core.rule.Rule;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.util.CollectionUtils;
import org.elasticsearch.xpack.esql.core.util.ReflectionUtils;

public final class OptimizerRules {

    /**
     * Selects which tree-walk a rule uses: {@code UP} makes the rules below call the
     * {@code transform...Up} variant, {@code DOWN} the {@code transform...Down} variant.
     */
    public enum TransformDirection {
        UP,
        DOWN
    }

    /**
     * Base class for plan rules that rewrite every expression of type {@code E} found in a plan.
     *
     * <p>Subclasses implement {@link #rule(Expression)}; {@link #apply(LogicalPlan)} walks the plan's
     * expressions in the configured {@link TransformDirection}.
     *
     * @param <E> the expression type this rule is interested in
     */
    public static abstract class OptimizerExpressionRule<E extends Expression>
    extends Rule<LogicalPlan, LogicalPlan> {
        private final TransformDirection direction;
        // Derived reflectively from the runtime class; presumably the concrete E type argument — see ReflectionUtils.
        private final Class<E> expressionTypeToken = ReflectionUtils.detectSuperTypeForRuleLike(this.getClass());

        /**
         * @param direction whether expressions are transformed bottom-up or top-down
         */
        public OptimizerExpressionRule(TransformDirection direction) {
            this.direction = direction;
        }

        /**
         * Applies {@link #rule(Expression)} to every expression of type {@code E} in {@code plan}.
         */
        @Override
        public final LogicalPlan apply(LogicalPlan plan) {
            if (this.direction == TransformDirection.DOWN) {
                return (LogicalPlan) plan.transformExpressionsDown(this.expressionTypeToken, this::rule);
            }
            return (LogicalPlan) plan.transformExpressionsUp(this.expressionTypeToken, this::rule);
        }

        /**
         * Plan-level hook; this base implementation returns the plan unchanged.
         */
        protected LogicalPlan rule(LogicalPlan plan) {
            return plan;
        }

        /**
         * Rewrites a single matching expression.
         *
         * @param var1 the expression to rewrite
         * @return the replacement expression (possibly the same instance)
         */
        protected abstract Expression rule(E var1);

        /**
         * @return the expression class this rule matches
         */
        public Class<E> expressionToken() {
            return this.expressionTypeToken;
        }
    }

    /**
     * Base class for rules that rewrite plan nodes of type {@code SubPlan}.
     *
     * <p>{@link #apply(LogicalPlan)} walks the plan in the configured {@link TransformDirection}
     * (top-down by default) and calls {@link #rule(LogicalPlan)} on each matching node.
     *
     * @param <SubPlan> the plan node type this rule is interested in
     */
    public static abstract class OptimizerRule<SubPlan extends LogicalPlan>
    extends Rule<SubPlan, LogicalPlan> {
        private final TransformDirection direction;

        /** Creates a rule that walks the plan top-down. */
        public OptimizerRule() {
            this(TransformDirection.DOWN);
        }

        /**
         * @param direction whether matching nodes are visited bottom-up or top-down
         */
        protected OptimizerRule(TransformDirection direction) {
            this.direction = direction;
        }

        /**
         * Applies {@link #rule(LogicalPlan)} to every node of type {@code SubPlan} in {@code plan}.
         */
        @Override
        public final LogicalPlan apply(LogicalPlan plan) {
            if (this.direction != TransformDirection.DOWN) {
                return plan.transformUp(this.typeToken(), this::rule);
            }
            return plan.transformDown(this.typeToken(), this::rule);
        }

        /**
         * Rewrites a single matching plan node.
         *
         * @param var1 the node to rewrite
         * @return the replacement plan (possibly the same instance)
         */
        protected abstract LogicalPlan rule(SubPlan var1);
    }

    public static class PropagateNullable
    extends OptimizerExpressionRule<And> {
        public PropagateNullable() {
            super(TransformDirection.DOWN);
        }

        @Override
        public Expression rule(And and) {
            List<Expression> splits = Predicates.splitAnd(and);
            LinkedHashSet<Expression> nullExpressions = new LinkedHashSet<Expression>();
            LinkedHashSet<Expression> notNullExpressions = new LinkedHashSet<Expression>();
            LinkedList<Expression> others = new LinkedList<Expression>();
            for (Expression ex : splits) {
                if (ex instanceof IsNull) {
                    IsNull isn = (IsNull)ex;
                    nullExpressions.add(isn.field());
                    continue;
                }
                if (ex instanceof IsNotNull) {
                    IsNotNull isnn = (IsNotNull)ex;
                    notNullExpressions.add(isnn.field());
                    continue;
                }
                others.add(ex);
            }
            if (Sets.haveNonEmptyIntersection(nullExpressions, notNullExpressions)) {
                return Literal.of(and, Boolean.FALSE);
            }
            boolean modified = PropagateNullable.replace(nullExpressions, others, splits, this::nullify);
            if (modified |= PropagateNullable.replace(notNullExpressions, others, splits, this::nonNullify)) {
                return Predicates.combineAnd(splits);
            }
            return and;
        }

        private static boolean replace(Iterable<Expression> pattern, List<Expression> target, List<Expression> originalExpressions, BiFunction<Expression, Expression, Expression> replacer) {
            boolean modified = false;
            for (Expression s : pattern) {
                for (int i = 0; i < target.size(); ++i) {
                    Expression replacement;
                    Expression t = target.get(i);
                    if (!t.anyMatch(s::semanticEquals) || (replacement = replacer.apply(t, s)) == t) continue;
                    modified = true;
                    target.set(i, replacement);
                    originalExpressions.replaceAll(e -> t.semanticEquals((Expression)e) ? replacement : e);
                }
            }
            return modified;
        }

        protected Expression nullify(Expression exp, Expression nullExp) {
            return exp.nullable() == Nullability.TRUE ? Literal.of(exp, null) : exp;
        }

        protected Expression nonNullify(Expression exp, Expression nonNullExp) {
            return exp;
        }
    }

    /**
     * Bottom-up rule that folds null-dependent expressions.
     *
     * <ul>
     *   <li>{@code IS NOT NULL} / {@code IS NULL} over a field whose nullability is {@code FALSE} becomes
     *       {@code TRUE} / {@code FALSE}.</li>
     *   <li>An {@code IN} whose value is null becomes a null literal.</li>
     *   <li>Any other non-{@link Alias}, nullable expression with a null child becomes a null literal.</li>
     * </ul>
     */
    public static class FoldNull
    extends OptimizerExpressionRule<Expression> {
        public FoldNull() {
            super(TransformDirection.UP);
        }

        @Override
        public Expression rule(Expression e) {
            Expression replaced = this.tryReplaceIsNullIsNotNull(e);
            if (replaced != e) {
                return replaced;
            }
            boolean foldToNull;
            if (e instanceof In) {
                foldToNull = Expressions.isNull(((In) e).value());
            } else {
                foldToNull = !(e instanceof Alias)
                    && e.nullable() == Nullability.TRUE
                    && Expressions.anyMatch(e.children(), Expressions::isNull);
            }
            return foldToNull ? Literal.of(e, null) : e;
        }

        /**
         * Replaces {@code IS NOT NULL} / {@code IS NULL} with a boolean literal when the tested field is
         * declared non-nullable.
         *
         * @return the literal replacement, or {@code e} itself when no replacement applies
         */
        protected Expression tryReplaceIsNullIsNotNull(Expression e) {
            if (e instanceof IsNotNull) {
                if (((IsNotNull) e).field().nullable() == Nullability.FALSE) {
                    return new Literal(e.source(), Boolean.TRUE, DataType.BOOLEAN);
                }
            } else if (e instanceof IsNull) {
                if (((IsNull) e).field().nullable() == Nullability.FALSE) {
                    return new Literal(e.source(), Boolean.FALSE, DataType.BOOLEAN);
                }
            }
            return e;
        }
    }

    /**
     * Replaces a {@link Limit} whose limit folds to the {@code Integer} {@code 0} with the plan returned by
     * {@link #skipPlan(Limit)}.
     */
    public static abstract class SkipQueryOnLimitZero
    extends OptimizerRule<Limit> {
        @Override
        protected LogicalPlan rule(Limit limit) {
            Expression limitExpression = limit.limit();
            // only an Integer zero matches; other numeric types fall through unchanged
            boolean zeroLimit = limitExpression.foldable() && Integer.valueOf(0).equals(limitExpression.fold());
            return zeroLimit ? this.skipPlan(limit) : limit;
        }

        /**
         * @param var1 the zero-limit node
         * @return the plan that replaces it
         */
        protected abstract LogicalPlan skipPlan(Limit var1);
    }

    /**
     * Walks the plan's expressions bottom-up and passes every expression of the configured cast type to
     * {@link #maybePruneCast(Expression)}.
     *
     * @param <C> the cast expression type
     */
    public static abstract class PruneCast<C extends Expression>
    extends Rule<LogicalPlan, LogicalPlan> {
        private final Class<C> castClass;

        /**
         * @param castType the cast expression class to look for
         */
        public PruneCast(Class<C> castType) {
            this.castClass = castType;
        }

        @Override
        public final LogicalPlan apply(LogicalPlan plan) {
            return rule(plan);
        }

        /**
         * Runs {@link #maybePruneCast(Expression)} over every matching expression, bottom-up.
         */
        protected final LogicalPlan rule(LogicalPlan plan) {
            return (LogicalPlan) plan.transformExpressionsUp(castClass, this::maybePruneCast);
        }

        /**
         * @param var1 a cast expression
         * @return the replacement expression (possibly the same instance)
         */
        protected abstract Expression maybePruneCast(C var1);
    }

    /**
     * Simplifies a {@link Filter}'s condition by folding null operands of {@code AND}/{@code OR}.
     *
     * <ul>
     *   <li>A condition that folds to {@code TRUE} removes the filter (its child is returned).</li>
     *   <li>A condition that folds to {@code FALSE} or null is handed to {@link #skipPlan(Filter)}.</li>
     *   <li>Otherwise the filter is rebuilt only if the condition changed.</li>
     * </ul>
     */
    public static abstract class PruneFilters
    extends OptimizerRule<Filter> {
        @Override
        protected LogicalPlan rule(Filter filter) {
            Expression originalCondition = filter.condition();
            Expression folded = originalCondition.transformUp(BinaryLogic.class, PruneFilters::foldBinaryLogic);
            if (folded instanceof Literal) {
                if (Literal.TRUE.equals(folded)) {
                    return filter.child();
                }
                if (Literal.FALSE.equals(folded) || Expressions.isNull(folded)) {
                    return this.skipPlan(filter);
                }
            }
            return folded.equals(originalCondition) ? filter : new Filter(filter.source(), filter.child(), folded);
        }

        /**
         * @param var1 a filter whose condition can never be true
         * @return the plan that replaces it
         */
        protected abstract LogicalPlan skipPlan(Filter var1);

        /**
         * Folds null operands of a binary logic node.
         *
         * <ul>
         *   <li>{@code null OR null} becomes null.</li>
         *   <li>{@code null OR x} and {@code x OR null} become {@code x}.</li>
         *   <li>{@code null AND x} and {@code x AND null} become null.</li>
         * </ul>
         *
         * <p>These rewrites treat null like false, which matches filter semantics.
         * NOTE(review): transformUp also visits BinaryLogic nested under a NOT, where that equivalence may not
         * hold — TODO confirm.
         */
        private static Expression foldBinaryLogic(BinaryLogic binaryLogic) {
            if (binaryLogic instanceof Or) {
                Or or = (Or) binaryLogic;
                boolean leftIsNull = Expressions.isNull(or.left());
                boolean rightIsNull = Expressions.isNull(or.right());
                if (leftIsNull) {
                    return rightIsNull ? new Literal(binaryLogic.source(), null, DataType.NULL) : or.right();
                }
                if (rightIsNull) {
                    return or.left();
                }
            } else if (binaryLogic instanceof And) {
                And and = (And) binaryLogic;
                if (Expressions.isNull(and.left()) || Expressions.isNull(and.right())) {
                    return new Literal(binaryLogic.source(), null, DataType.NULL);
                }
            }
            return binaryLogic;
        }
    }

    /**
     * Top-down rule that replaces every {@link SurrogateFunction} with the expression returned by its
     * {@code substitute()} method; other expressions are returned unchanged.
     */
    public static class ReplaceSurrogateFunction
    extends OptimizerExpressionRule<Expression> {
        public ReplaceSurrogateFunction() {
            super(TransformDirection.DOWN);
        }

        @Override
        protected Expression rule(Expression e) {
            return e instanceof SurrogateFunction ? ((SurrogateFunction) e).substitute() : e;
        }
    }

    public static class CombineDisjunctionsToIn
    extends OptimizerExpressionRule<Or> {
        public CombineDisjunctionsToIn() {
            super(TransformDirection.UP);
        }

        @Override
        protected Expression rule(Or or) {
            Expression e = or;
            List<Expression> exps = Predicates.splitOr(e);
            LinkedHashMap<Expression, Set> found = new LinkedHashMap<Expression, Set>();
            ZoneId zoneId = null;
            LinkedList<Expression> ors = new LinkedList<Expression>();
            for (Expression exp : exps) {
                if (exp instanceof Equals) {
                    Equals eq = (Equals)exp;
                    if (eq.right().foldable()) {
                        found.computeIfAbsent(eq.left(), k -> new LinkedHashSet()).add(eq.right());
                    } else {
                        ors.add(exp);
                    }
                    if (zoneId != null) continue;
                    zoneId = eq.zoneId();
                    continue;
                }
                if (exp instanceof In) {
                    In in = (In)exp;
                    found.computeIfAbsent(in.value(), k -> new LinkedHashSet()).addAll(in.list());
                    if (zoneId != null) continue;
                    zoneId = in.zoneId();
                    continue;
                }
                ors.add(exp);
            }
            if (!found.isEmpty()) {
                ZoneId finalZoneId = zoneId;
                found.forEach((k, v) -> ors.add(v.size() == 1 ? this.createEquals((Expression)k, (Set<Expression>)v, finalZoneId) : this.createIn((Expression)k, (List<Expression>)new ArrayList<Expression>((Collection<Expression>)v), finalZoneId)));
                Expression combineOr = Predicates.combineOr(ors);
                if (!e.semanticEquals(combineOr)) {
                    e = combineOr;
                }
            }
            return e;
        }

        protected Equals createEquals(Expression k, Set<Expression> v, ZoneId finalZoneId) {
            return new Equals(k.source(), k, v.iterator().next(), finalZoneId);
        }

        protected In createIn(Expression key, List<Expression> values, ZoneId zoneId) {
            return new In(key.source(), key, values, zoneId);
        }
    }

    /**
     * Bottom-up rule applying boolean-algebra simplifications to {@code AND}, {@code OR} and {@code NOT}.
     *
     * <ul>
     *   <li>Constant elimination: {@code TRUE AND x -> x}, {@code FALSE OR x -> x}, {@code FALSE AND x -> FALSE},
     *       {@code TRUE OR x -> TRUE}.</li>
     *   <li>Idempotence: {@code x AND x -> x}, {@code x OR x -> x}.</li>
     *   <li>Factoring and absorption: {@code (a OR b) AND (a OR c) -> a OR (b AND c)} and
     *       {@code (a AND b) OR (a AND c) -> a AND (b OR c)}.</li>
     *   <li>Negation: {@code NOT TRUE}/{@code NOT FALSE} are folded; negatable operands are negated; a double
     *       {@code NOT} is removed.</li>
     * </ul>
     */
    public static class BooleanSimplification
    extends OptimizerExpressionRule<ScalarFunction> {
        public BooleanSimplification() {
            super(TransformDirection.UP);
        }

        @Override
        public Expression rule(ScalarFunction e) {
            if (e instanceof And) {
                return simplifyAnd((And) e);
            }
            if (e instanceof Or) {
                return simplifyOr((Or) e);
            }
            return e instanceof Not ? this.simplifyNot((Not) e) : e;
        }

        /** Simplifies a conjunction; see the class documentation for the rewrites applied. */
        private static Expression simplifyAnd(And and) {
            Expression left = and.left();
            Expression right = and.right();
            if (Literal.TRUE.equals(left)) {
                return right;
            }
            if (Literal.TRUE.equals(right)) {
                return left;
            }
            if (Literal.FALSE.equals(left) || Literal.FALSE.equals(right)) {
                return new Literal(and.source(), Boolean.FALSE, DataType.BOOLEAN);
            }
            if (left.semanticEquals(right)) {
                return left;
            }
            List<Expression> leftTerms = Predicates.splitOr(left);
            List<Expression> rightTerms = Predicates.splitOr(right);
            List<Expression> shared = Predicates.inCommon(leftTerms, rightTerms);
            if (shared.isEmpty()) {
                return and;
            }
            List<Expression> leftRest = Predicates.subtract(leftTerms, shared);
            List<Expression> rightRest = Predicates.subtract(rightTerms, shared);
            // absorption: a AND (a OR b) -> a
            if (leftRest.isEmpty() || rightRest.isEmpty()) {
                return Predicates.combineOr(shared);
            }
            Expression leftCombined = Predicates.combineOr(leftRest);
            Expression rightCombined = Predicates.combineOr(rightRest);
            return Predicates.combineOr(CollectionUtils.combine(shared, new And(leftCombined.source(), leftCombined, rightCombined)));
        }

        /** Simplifies a disjunction; see the class documentation for the rewrites applied. */
        private static Expression simplifyOr(Or or) {
            Expression left = or.left();
            Expression right = or.right();
            if (Literal.TRUE.equals(left) || Literal.TRUE.equals(right)) {
                return new Literal(or.source(), Boolean.TRUE, DataType.BOOLEAN);
            }
            if (Literal.FALSE.equals(left)) {
                return right;
            }
            if (Literal.FALSE.equals(right)) {
                return left;
            }
            if (left.semanticEquals(right)) {
                return left;
            }
            List<Expression> leftTerms = Predicates.splitAnd(left);
            List<Expression> rightTerms = Predicates.splitAnd(right);
            List<Expression> shared = Predicates.inCommon(leftTerms, rightTerms);
            if (shared.isEmpty()) {
                return or;
            }
            List<Expression> leftRest = Predicates.subtract(leftTerms, shared);
            List<Expression> rightRest = Predicates.subtract(rightTerms, shared);
            // absorption: a OR (a AND b) -> a
            if (leftRest.isEmpty() || rightRest.isEmpty()) {
                return Predicates.combineAnd(shared);
            }
            Expression leftCombined = Predicates.combineAnd(leftRest);
            Expression rightCombined = Predicates.combineAnd(rightRest);
            return Predicates.combineAnd(CollectionUtils.combine(shared, new Or(leftCombined.source(), leftCombined, rightCombined)));
        }

        /** Folds constant negations, negates negatable operands and removes a double negation. */
        private Expression simplifyNot(Not n) {
            Expression operand = n.field();
            if (Literal.TRUE.semanticEquals(operand)) {
                return new Literal(n.source(), Boolean.FALSE, DataType.BOOLEAN);
            }
            if (Literal.FALSE.semanticEquals(operand)) {
                return new Literal(n.source(), Boolean.TRUE, DataType.BOOLEAN);
            }
            Expression negated = this.maybeSimplifyNegatable(operand);
            if (negated != null) {
                return negated;
            }
            return operand instanceof Not ? ((Not) operand).field() : n;
        }

        /**
         * @return the negation of {@code e} if it is {@link Negatable}, otherwise {@code null}
         */
        protected Expression maybeSimplifyNegatable(Expression e) {
            return e instanceof Negatable ? ((Negatable) e).negate() : null;
        }
    }

    /**
     * Bottom-up rule that removes comparisons of a function against a boolean literal.
     *
     * <ul>
     *   <li>{@code f == TRUE} and {@code f != FALSE} become {@code f}.</li>
     *   <li>{@code f == FALSE} and {@code f != TRUE} become {@code NOT f}.</li>
     * </ul>
     */
    public static final class BooleanFunctionEqualsElimination
    extends OptimizerExpressionRule<BinaryComparison> {
        public BooleanFunctionEqualsElimination() {
            super(TransformDirection.UP);
        }

        @Override
        protected Expression rule(BinaryComparison bc) {
            boolean isEquals = bc instanceof Equals;
            if (!isEquals && !(bc instanceof NotEquals)) {
                return bc;
            }
            Expression left = bc.left();
            if (!(left instanceof Function)) {
                return bc;
            }
            Expression right = bc.right();
            boolean keepOperand;
            if (Literal.TRUE.equals(right)) {
                keepOperand = isEquals;
            } else if (Literal.FALSE.equals(right)) {
                keepOperand = !isEquals;
            } else {
                return bc;
            }
            return keepOperand ? left : new Not(left.source(), left);
        }
    }
}

