/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.optimizer.rules.physical.local;

import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
import org.elasticsearch.xpack.esql.core.expression.TypedAttribute;
import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
import org.elasticsearch.xpack.esql.plan.physical.FieldExtractExec;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
import org.elasticsearch.xpack.esql.plan.physical.UnaryExec;
import org.elasticsearch.xpack.esql.rule.Rule;

/**
 * Physical-plan rule that inserts {@link FieldExtractExec} nodes wherever a unary node
 * references field or metadata attributes that its input does not provide.
 *
 * <p>The plan is walked with {@code transformUp}; for every {@link UnaryExec} the rule
 * computes the attributes used by the node's expressions but absent from its input set and,
 * if there are any, places a {@link FieldExtractExec} for them between the node and its child.
 */
public class InsertFieldExtraction
extends Rule<PhysicalPlan, PhysicalPlan> {
    /**
     * Rewrites {@code plan} so that each unary node with missing field/metadata attributes
     * gets a {@link FieldExtractExec} as its new child.
     *
     * @param plan the physical plan to rewrite
     * @return the plan returned by {@code transformUp}, with extractors inserted where needed
     */
    @Override
    public PhysicalPlan apply(PhysicalPlan plan) {
        return plan.transformUp(UnaryExec.class, p -> {
            Set<Attribute> missing = InsertFieldExtraction.missingAttributes(p);

            // Special case: an aggregation with exactly one grouping. References of that grouping
            // are NOT extracted here unless the grouping also shows up as a leaf of one of the
            // non-grouping aggregate expressions.
            // NOTE(review): presumably the aggregation is able to load a single grouping field
            // on its own - confirm against the aggregation operator implementation.
            if (p instanceof AggregateExec agg && agg.groupings().size() == 1) {
                // Leaves of every aggregate expression that is not itself one of the groupings.
                List<? extends Expression> leaves = agg.aggregates().stream()
                    .filter(a -> !agg.groupings().contains(a))
                    .flatMap(a -> a.collectLeaves().stream())
                    .toList();
                // Groupings that none of those aggregate expressions reference as a leaf.
                List<? extends Expression> remove = agg.groupings().stream()
                    .filter(g -> !leaves.contains(g))
                    .toList();
                missing.removeAll(Expressions.references(remove));
            }

            if (missing.isEmpty()) {
                // Everything the node needs is already available from its input.
                return p;
            }
            // Snapshot the attributes (in encounter order) and splice the extractor in
            // between the node and its current child.
            FieldExtractExec extractor = new FieldExtractExec(p.source(), p.child(), List.copyOf(missing));
            return p.replaceChild(extractor);
        });
    }

    /**
     * Collects the {@link FieldAttribute}s and {@link MetadataAttribute}s used in the expressions
     * of {@code p} that are not contained in the node's input set.
     *
     * @param p the plan node to inspect
     * @return a mutable, insertion-ordered set of the missing attributes; empty if none are missing
     */
    private static Set<Attribute> missingAttributes(PhysicalPlan p) {
        Set<Attribute> missing = new LinkedHashSet<>();
        AttributeSet input = p.inputSet();
        p.forEachExpression(TypedAttribute.class, f -> {
            boolean extractable = f instanceof FieldAttribute || f instanceof MetadataAttribute;
            if (extractable && !input.contains(f)) {
                missing.add(f);
            }
        });
        return missing;
    }
}

