/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.analysis;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Consumer;
import java.util.stream.Stream;
import org.elasticsearch.xpack.esql.core.analyzer.VerifierChecks;
import org.elasticsearch.xpack.esql.core.capabilities.Unresolvable;
import org.elasticsearch.xpack.esql.core.common.Failure;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.AttributeMap;
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.expression.predicate.BinaryOperator;
import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.BinaryComparison;
import org.elasticsearch.xpack.esql.core.plan.logical.Limit;
import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.core.plan.logical.OrderBy;
import org.elasticsearch.xpack.esql.core.plan.logical.UnaryPlan;
import org.elasticsearch.xpack.esql.core.tree.Node;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Rate;
import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Neg;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
import org.elasticsearch.xpack.esql.plan.logical.Eval;
import org.elasticsearch.xpack.esql.plan.logical.Lookup;
import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.plan.logical.RegexExtract;
import org.elasticsearch.xpack.esql.plan.logical.Row;
import org.elasticsearch.xpack.esql.stats.FeatureMetric;
import org.elasticsearch.xpack.esql.stats.Metrics;
import org.elasticsearch.xpack.esql.type.EsqlDataTypes;

public class Verifier {
    private final Metrics metrics;

    public Verifier(Metrics metrics) {
        this.metrics = metrics;
    }

    Collection<Failure> verify(LogicalPlan plan, BitSet partialMetrics) {
        assert (partialMetrics != null);
        LinkedHashSet<Failure> failures = new LinkedHashSet<Failure>();
        AttributeMap aliases = new AttributeMap();
        plan.forEachUp(p -> {
            if (!p.childrenResolved()) {
                return;
            }
            if (p instanceof Unresolvable) {
                Unresolvable u = (Unresolvable)p;
                failures.add(Failure.fail((Node)p, (String)u.unresolvedMessage(), (Object[])new Object[0]));
            } else if (p.resolved()) {
                p.forEachExpressionUp(Alias.class, a -> aliases.put(a.toAttribute(), (Object)a.child()));
                return;
            }
            Consumer<Expression> unresolvedExpressions = e -> {
                if (e.resolved()) {
                    return;
                }
                e.forEachUp(ae -> {
                    if (!ae.childrenResolved()) {
                        return;
                    }
                    if (ae instanceof Unresolvable) {
                        Unresolvable u = (Unresolvable)ae;
                        if (!(p instanceof Project) || !(u instanceof UnsupportedAttribute)) {
                            failures.add(Failure.fail((Node)ae, (String)u.unresolvedMessage(), (Object[])new Object[0]));
                        }
                    }
                    if (ae.typeResolved().unresolved()) {
                        failures.add(Failure.fail((Node)ae, (String)ae.typeResolved().message(), (Object[])new Object[0]));
                    }
                });
            };
            if (p instanceof Aggregate) {
                Aggregate agg = (Aggregate)((Object)p);
                List<Expression> groupings = agg.groupings();
                groupings.forEach(unresolvedExpressions);
                List<? extends NamedExpression> aggs = agg.aggregates();
                int size = aggs.size() - groupings.size();
                aggs.subList(0, size).forEach(unresolvedExpressions);
            } else if (p instanceof Lookup) {
                Lookup lookup = (Lookup)((Object)p);
                Expression tableName = lookup.tableName();
                if (tableName instanceof Unresolvable) {
                    Unresolvable u = (Unresolvable)tableName;
                    failures.add(Failure.fail((Node)tableName, (String)u.unresolvedMessage(), (Object[])new Object[0]));
                } else {
                    lookup.matchFields().forEach(unresolvedExpressions);
                }
            } else {
                p.forEachExpression(unresolvedExpressions);
            }
        });
        if (!failures.isEmpty()) {
            return failures;
        }
        plan.forEachDown(p -> {
            if (!p.childrenResolved()) {
                return;
            }
            VerifierChecks.checkFilterConditionType((LogicalPlan)p, (Set)failures);
            Verifier.checkAggregate(p, failures);
            Verifier.checkRegexExtractOnlyOnStrings(p, failures);
            Verifier.checkRow(p, failures);
            Verifier.checkEvalFields(p, failures);
            Verifier.checkOperationsOnUnsignedLong(p, failures);
            Verifier.checkBinaryComparison(p, failures);
            Verifier.checkForSortOnSpatialTypes(p, failures);
        });
        Verifier.checkRemoteEnrich(plan, failures);
        if (failures.isEmpty()) {
            this.gatherMetrics(plan, partialMetrics);
        }
        return failures;
    }

    private static void checkAggregate(LogicalPlan p, Set<Failure> failures) {
        if (p instanceof Aggregate) {
            Aggregate agg = (Aggregate)p;
            List<Expression> groupings = agg.groupings();
            AttributeSet groupRefs = new AttributeSet();
            groupings.forEach(e -> {
                FieldAttribute f;
                e.forEachUp(g -> {
                    if (g instanceof AggregateFunction) {
                        AggregateFunction af = (AggregateFunction)((Object)((Object)g));
                        failures.add(Failure.fail((Node)g, (String)"cannot use an aggregate [{}] for grouping", (Object[])new Object[]{af}));
                    } else if (g instanceof GroupingFunction) {
                        GroupingFunction gf = (GroupingFunction)g;
                        gf.children().forEach(c -> c.forEachDown(GroupingFunction.class, inner -> failures.add(Failure.fail((Node)inner, (String)"cannot nest grouping functions; found [{}] inside [{}]", (Object[])new Object[]{inner.sourceText(), gf.sourceText()}))));
                    }
                });
                Attribute attr = Expressions.attribute((Expression)e);
                if (attr != null) {
                    groupRefs.add(attr);
                }
                if (e instanceof FieldAttribute && (f = (FieldAttribute)e).dataType().isCounter()) {
                    failures.add(Failure.fail((Node)e, (String)"cannot group by on [{}] type for grouping [{}]", (Object[])new Object[]{f.dataType().typeName(), e.sourceText()}));
                }
            });
            List<? extends NamedExpression> aggs = agg.aggregates();
            aggs.subList(0, aggs.size() - groupings.size()).forEach(e -> {
                Expression exp = Alias.unwrap((Expression)e);
                if (exp.foldable()) {
                    failures.add(Failure.fail((Node)exp, (String)"expected an aggregate function but found [{}]", (Object[])new Object[]{exp.sourceText()}));
                }
                Verifier.checkInvalidNamedExpressionUsage(exp, groupings, groupRefs, failures, 0);
            });
            if (agg.aggregateType() == Aggregate.AggregateType.METRICS) {
                aggs.forEach(a -> Verifier.checkRateAggregates((Expression)a, 0, failures));
            } else {
                agg.forEachExpression(Rate.class, r -> failures.add(Failure.fail((Node)r, (String)"the rate aggregate[{}] can only be used within the metrics command", (Object[])new Object[]{r.sourceText()})));
            }
        } else {
            p.forEachExpression(GroupingFunction.class, gf -> failures.add(Failure.fail((Node)gf, (String)"cannot use grouping function [{}] outside of a STATS command", (Object[])new Object[]{gf.sourceText()})));
        }
    }

    private static void checkRateAggregates(Expression expr, int nestedLevel, Set<Failure> failures) {
        if (expr instanceof AggregateFunction) {
            ++nestedLevel;
        }
        if (expr instanceof Rate) {
            Rate r = (Rate)expr;
            if (nestedLevel != 2) {
                failures.add(Failure.fail((Node)expr, (String)"the rate aggregate [{}] can only be used within the metrics command and inside another aggregate", (Object[])new Object[]{r.sourceText()}));
            }
        }
        for (Expression child : expr.children()) {
            Verifier.checkRateAggregates(child, nestedLevel, failures);
        }
    }

    private static void checkInvalidNamedExpressionUsage(Expression e, List<Expression> groups, AttributeSet groupRefs, Set<Failure> failures, int level) {
        block10: {
            block11: {
                GroupingFunction gf;
                block12: {
                    block9: {
                        if (!(e instanceof AggregateFunction)) break block9;
                        AggregateFunction af = (AggregateFunction)e;
                        af.field().forEachDown(AggregateFunction.class, f -> {
                            if (!(f instanceof Rate)) {
                                failures.add(Failure.fail((Node)f, (String)"nested aggregations [{}] not allowed inside other aggregations [{}]", (Object[])new Object[]{f, af}));
                            }
                        });
                        break block10;
                    }
                    if (!(e instanceof GroupingFunction)) break block11;
                    gf = (GroupingFunction)e;
                    if (groups.stream().anyMatch(ex -> {
                        Alias a;
                        return ex instanceof Alias && (a = (Alias)ex).child().semanticEquals((Expression)gf);
                    })) break block12;
                    failures.add(Failure.fail((Node)gf, (String)"can only use grouping function [{}] part of the BY clause", (Object[])new Object[]{gf.sourceText()}));
                    break block10;
                }
                if (level != 0) break block10;
                Verifier.addFailureOnGroupingUsedNakedInAggs(failures, (Expression)gf, "function");
                break block10;
            }
            if (!e.foldable()) {
                if (groups.contains(e) || groupRefs.contains((Object)e)) {
                    if (level == 0) {
                        Verifier.addFailureOnGroupingUsedNakedInAggs(failures, e, "key");
                    }
                } else if (e instanceof NamedExpression) {
                    NamedExpression ne = (NamedExpression)e;
                    boolean foundInGrouping = false;
                    for (Expression g : groups) {
                        if (!g.anyMatch(se -> se.semanticEquals((Expression)ne))) continue;
                        foundInGrouping = true;
                        failures.add(Failure.fail((Node)e, (String)"column [{}] cannot be used as an aggregate once declared in the STATS BY grouping key [{}]", (Object[])new Object[]{ne.name(), g.sourceText()}));
                        break;
                    }
                    if (!foundInGrouping) {
                        failures.add(Failure.fail((Node)e, (String)"column [{}] must appear in the STATS BY clause or be used in an aggregate function", (Object[])new Object[]{ne.name()}));
                    }
                } else {
                    for (Expression child : e.children()) {
                        Verifier.checkInvalidNamedExpressionUsage(child, groups, groupRefs, failures, level + 1);
                    }
                }
            }
        }
    }

    private static void addFailureOnGroupingUsedNakedInAggs(Set<Failure> failures, Expression e, String element) {
        failures.add(Failure.fail((Node)e, (String)"grouping {} [{}] cannot be used as an aggregate once declared in the STATS BY clause", (Object[])new Object[]{element, e.sourceText()}));
    }

    private static void checkRegexExtractOnlyOnStrings(LogicalPlan p, Set<Failure> failures) {
        RegexExtract re;
        Expression expr;
        DataType type;
        if (p instanceof RegexExtract && !EsqlDataTypes.isString(type = (expr = (re = (RegexExtract)p).input()).dataType())) {
            failures.add(Failure.fail((Node)expr, (String)"{} only supports KEYWORD or TEXT values, found expression [{}] type [{}]", (Object[])new Object[]{re.getClass().getSimpleName(), expr.sourceText(), type}));
        }
    }

    private static void checkRow(LogicalPlan p, Set<Failure> failures) {
        if (p instanceof Row) {
            Row row = (Row)p;
            row.fields().forEach(a -> {
                if (!EsqlDataTypes.isRepresentable(a.dataType())) {
                    failures.add(Failure.fail((Node)a, (String)"cannot use [{}] directly in a row assignment", (Object[])new Object[]{a.child().sourceText()}));
                }
            });
        }
    }

    private static void checkEvalFields(LogicalPlan p, Set<Failure> failures) {
        if (p instanceof Eval) {
            Eval eval = (Eval)p;
            eval.fields().forEach(field -> {
                DataType dataType = field.dataType();
                if (!EsqlDataTypes.isRepresentable(dataType)) {
                    failures.add(Failure.fail((Node)field, (String)"EVAL does not support type [{}] in expression [{}]", (Object[])new Object[]{dataType.typeName(), field.child().sourceText()}));
                }
                field.forEachDown(AggregateFunction.class, af -> {
                    if (af instanceof Rate) {
                        failures.add(Failure.fail((Node)af, (String)"aggregate function [{}] not allowed outside METRICS command", (Object[])new Object[]{af.sourceText()}));
                    } else {
                        failures.add(Failure.fail((Node)af, (String)"aggregate function [{}] not allowed outside STATS command", (Object[])new Object[]{af.sourceText()}));
                    }
                });
            });
        }
    }

    private static void checkOperationsOnUnsignedLong(LogicalPlan p, Set<Failure> failures) {
        p.forEachExpression(e -> {
            Failure f = null;
            if (e instanceof BinaryOperator) {
                BinaryOperator bo = (BinaryOperator)e;
                f = Verifier.validateUnsignedLongOperator(bo);
            } else if (e instanceof Neg) {
                Neg neg = (Neg)e;
                f = Verifier.validateUnsignedLongNegation(neg);
            }
            if (f != null) {
                failures.add(f);
            }
        });
    }

    private static void checkBinaryComparison(LogicalPlan p, Set<Failure> failures) {
        p.forEachExpression(BinaryComparison.class, bc -> {
            Failure f = Verifier.validateBinaryComparison(bc);
            if (f != null) {
                failures.add(f);
            }
        });
    }

    private void gatherMetrics(LogicalPlan plan, BitSet b) {
        plan.forEachDown(p -> FeatureMetric.set(p, b));
        int i = b.nextSetBit(0);
        while (i >= 0) {
            this.metrics.inc(FeatureMetric.values()[i]);
            i = b.nextSetBit(i + 1);
        }
    }

    public static Failure validateBinaryComparison(BinaryComparison bc) {
        if (bc.left().dataType().isNumeric()) {
            if (!bc.right().dataType().isNumeric()) {
                return Failure.fail((Node)bc, (String)"first argument of [{}] is [numeric] so second argument must also be [numeric] but was [{}]", (Object[])new Object[]{bc.sourceText(), bc.right().dataType().typeName()});
            }
            return null;
        }
        ArrayList<DataType> allowed = new ArrayList<DataType>();
        allowed.add(DataType.KEYWORD);
        allowed.add(DataType.TEXT);
        allowed.add(DataType.IP);
        allowed.add(DataType.DATETIME);
        allowed.add(DataType.VERSION);
        allowed.add(DataType.GEO_POINT);
        allowed.add(DataType.GEO_SHAPE);
        allowed.add(DataType.CARTESIAN_POINT);
        allowed.add(DataType.CARTESIAN_SHAPE);
        if (bc instanceof Equals || bc instanceof NotEquals) {
            allowed.add(DataType.BOOLEAN);
        }
        Expression.TypeResolution r = TypeResolutions.isType((Expression)bc.left(), allowed::contains, (String)bc.sourceText(), (TypeResolutions.ParamOrdinal)TypeResolutions.ParamOrdinal.FIRST, (String[])((String[])Stream.concat(Stream.of("numeric"), allowed.stream().map(DataType::typeName)).toArray(String[]::new)));
        if (!r.resolved()) {
            return Failure.fail((Node)bc, (String)r.message(), (Object[])new Object[0]);
        }
        if (DataType.isString((DataType)bc.left().dataType()) && DataType.isString((DataType)bc.right().dataType())) {
            return null;
        }
        if (bc.left().dataType() != bc.right().dataType()) {
            return Failure.fail((Node)bc, (String)"first argument of [{}] is [{}] so second argument must also be [{}] but was [{}]", (Object[])new Object[]{bc.sourceText(), bc.left().dataType().typeName(), bc.left().dataType().typeName(), bc.right().dataType().typeName()});
        }
        return null;
    }

    public static Failure validateUnsignedLongOperator(BinaryOperator<?, ?, ?, ?> bo) {
        DataType leftType = bo.left().dataType();
        DataType rightType = bo.right().dataType();
        if ((leftType == DataType.UNSIGNED_LONG || rightType == DataType.UNSIGNED_LONG) && leftType != rightType) {
            return Failure.fail(bo, (String)"first argument of [{}] is [{}] and second is [{}]. [{}] can only be operated on together with another [{}]", (Object[])new Object[]{bo.sourceText(), leftType.typeName(), rightType.typeName(), DataType.UNSIGNED_LONG.typeName(), DataType.UNSIGNED_LONG.typeName()});
        }
        return null;
    }

    private static Failure validateUnsignedLongNegation(Neg neg) {
        DataType childExpressionType = neg.field().dataType();
        if (childExpressionType.equals((Object)DataType.UNSIGNED_LONG)) {
            return Failure.fail((Node)neg, (String)"negation unsupported for arguments of type [{}] in expression [{}]", (Object[])new Object[]{childExpressionType.typeName(), neg.sourceText()});
        }
        return null;
    }

    private static void checkForSortOnSpatialTypes(LogicalPlan p, Set<Failure> localFailures) {
        if (p instanceof OrderBy) {
            OrderBy ob = (OrderBy)p;
            ob.forEachExpression(Attribute.class, attr -> {
                DataType dataType = attr.dataType();
                if (EsqlDataTypes.isSpatial(dataType)) {
                    localFailures.add(Failure.fail((Node)attr, (String)("cannot sort on " + dataType.typeName()), (Object[])new Object[0]));
                }
            });
        }
    }

    private static void checkRemoteEnrich(LogicalPlan plan, Set<Failure> failures) {
        boolean[] agg = new boolean[]{false};
        boolean[] limit = new boolean[]{false};
        boolean[] enrichCoord = new boolean[]{false};
        plan.forEachUp(UnaryPlan.class, u -> {
            Enrich enrich;
            if (u instanceof Limit) {
                limit[0] = true;
            }
            if (u instanceof Aggregate) {
                agg[0] = true;
            } else if (u instanceof Enrich && (enrich = (Enrich)u).mode() == Enrich.Mode.COORDINATOR) {
                enrichCoord[0] = true;
            }
            if (u instanceof Enrich && (enrich = (Enrich)u).mode() == Enrich.Mode.REMOTE) {
                if (limit[0]) {
                    failures.add(Failure.fail((Node)enrich, (String)"ENRICH with remote policy can't be executed after LIMIT", (Object[])new Object[0]));
                }
                if (agg[0]) {
                    failures.add(Failure.fail((Node)enrich, (String)"ENRICH with remote policy can't be executed after STATS", (Object[])new Object[0]));
                }
                if (enrichCoord[0]) {
                    failures.add(Failure.fail((Node)enrich, (String)"ENRICH with remote policy can't be executed after another ENRICH with coordinator policy", (Object[])new Object[0]));
                }
            }
        });
    }
}

