/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.optimizer.rules.physical.local;

import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialExtent;
import org.elasticsearch.xpack.esql.optimizer.LocalPhysicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.PhysicalOptimizerRules;
import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
import org.elasticsearch.xpack.esql.plan.physical.EvalExec;
import org.elasticsearch.xpack.esql.plan.physical.FieldExtractExec;
import org.elasticsearch.xpack.esql.plan.physical.FilterExec;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
import org.elasticsearch.xpack.esql.plan.physical.UnaryExec;

/**
 * Local physical optimizer rule that switches {@link SpatialExtent} aggregations over shape fields to the
 * {@link MappedFieldType.FieldExtractPreference#EXTRACT_SPATIAL_BOUNDS} field-extract preference.
 * <p>
 * A field attribute is a candidate when all of the following hold:
 * <ul>
 *   <li>it is a {@link FieldAttribute} aggregated by a {@link SpatialExtent};</li>
 *   <li>its type is {@code GEO_SHAPE} or {@code CARTESIAN_SHAPE};</li>
 *   <li>the search stats report doc values for it;</li>
 *   <li>it is not referenced by any other (non-{@link SpatialExtent}) aggregate function of the same aggregation;</li>
 *   <li>it is not referenced by an {@link EvalExec} or a {@link FilterExec} condition encountered while walking
 *       down the aggregate's plan.</li>
 * </ul>
 * For every candidate, the matching {@link SpatialExtent} functions receive the bounds preference. Every
 * {@link FieldExtractExec} that extracts the attribute is also given it as a bounds attribute.
 */
public class SpatialShapeBoundsExtraction
extends PhysicalOptimizerRules.ParameterizedOptimizerRule<AggregateExec, LocalPhysicalOptimizerContext> {
    /**
     * Applies the rule to one aggregate.
     *
     * @param aggregate the aggregate node the rule is applied to
     * @param ctx       local optimizer context; its search stats are used to check for doc values
     * @return {@code aggregate} unchanged when no candidate attribute is found, otherwise the rewritten plan
     */
    @Override
    protected PhysicalPlan rule(AggregateExec aggregate, LocalPhysicalOptimizerContext ctx) {
        Set<Attribute> foundAttributes = findSpatialShapeBoundsAttributes(aggregate, ctx);
        if (foundAttributes.isEmpty()) {
            return aggregate;
        }
        // Dispatch with instanceof rather than on Class#getSimpleName(), for two reasons.
        // Subclasses of AggregateExec/FieldExtractExec are handled the same way as in
        // findSpatialShapeBoundsAttributes. Unrelated classes that merely share a simple name
        // are never cast.
        return aggregate.transformDown(PhysicalPlan.class, exec -> {
            if (exec instanceof AggregateExec agg) {
                return transformAggregateExec(agg, foundAttributes);
            }
            if (exec instanceof FieldExtractExec fieldExtractExec) {
                return transformFieldExtractExec(fieldExtractExec, foundAttributes);
            }
            return exec;
        });
    }

    /**
     * Walks the plan top-down and collects the attributes whose spatial bounds can be extracted.
     * <p>
     * Each {@link AggregateExec} adds its candidates. Each {@link EvalExec} and {@link FilterExec}
     * visited afterwards removes any attribute it references.
     *
     * @param aggregate root of the walk
     * @param ctx       optimizer context supplying search stats
     * @return the (mutable) set of candidate attributes; empty when none qualify
     */
    private static Set<Attribute> findSpatialShapeBoundsAttributes(AggregateExec aggregate, LocalPhysicalOptimizerContext ctx) {
        Set<Attribute> foundAttributes = new HashSet<>();
        // transformDown serves only as a top-down traversal here: every node is returned unchanged
        // and the resulting plan is discarded.
        aggregate.transformDown(UnaryExec.class, exec -> {
            if (exec instanceof AggregateExec agg) {
                foundAttributes.addAll(boundsCandidates(agg, ctx));
            } else if (exec instanceof EvalExec evalExec) {
                foundAttributes.removeAll(evalExec.references());
            } else if (exec instanceof FilterExec filterExec) {
                foundAttributes.removeAll(filterExec.condition().references());
            }
            return exec;
        });
        return foundAttributes;
    }

    /**
     * Returns the shape {@link FieldAttribute}s that qualify as candidates within a single aggregation.
     * A qualifying attribute is aggregated by a {@link SpatialExtent} of {@code agg}, has doc values, and is
     * not used by any other aggregate function of {@code agg}.
     */
    private static List<FieldAttribute> boundsCandidates(AggregateExec agg, LocalPhysicalOptimizerContext ctx) {
        List<AggregateFunction> aggregateFunctions = agg.aggregates()
            .stream()
            .flatMap(e -> extractAggregateFunction(e).stream())
            .toList();
        // A field also used by another aggregate function is excluded. Presumably those functions
        // need the original field values, not just the bounds.
        var fieldsAppearingInNonSpatialExtents = aggregateFunctions.stream()
            .filter(af -> !(af instanceof SpatialExtent))
            .flatMap(af -> af.references().stream())
            .filter(FieldAttribute.class::isInstance)
            .map(f -> ((FieldAttribute) f).field())
            .collect(Collectors.toSet());
        return aggregateFunctions.stream()
            .filter(SpatialExtent.class::isInstance)
            .map(SpatialExtent.class::cast)
            .map(AggregateFunction::field)
            .filter(FieldAttribute.class::isInstance)
            .map(FieldAttribute.class::cast)
            .filter(f -> isShape(f.field().getDataType())
                && !fieldsAppearingInNonSpatialExtents.contains(f.field())
                && ctx.searchStats().hasDocValues(f.fieldName()))
            .toList();
    }

    /**
     * Marks on {@code fieldExtractExec} the candidates it extracts as bounds attributes.
     * The candidates used are the intersection of {@code foundAttributes} and its attributes to extract.
     */
    private static PhysicalPlan transformFieldExtractExec(FieldExtractExec fieldExtractExec, Set<Attribute> foundAttributes) {
        // Copy first so the shared foundAttributes set is not modified by retainAll.
        var boundsAttributes = new HashSet<>(foundAttributes);
        boundsAttributes.retainAll(fieldExtractExec.attributesToExtract());
        return fieldExtractExec.withBoundsAttributes(boundsAttributes);
    }

    /**
     * Switches every {@link SpatialExtent} in {@code agg} whose field is a candidate to the
     * {@code EXTRACT_SPATIAL_BOUNDS} extract preference; other expressions are left as they are.
     */
    private static PhysicalPlan transformAggregateExec(AggregateExec agg, Set<Attribute> foundAttributes) {
        return agg.transformExpressionsDown(
            SpatialExtent.class,
            spatialExtent -> foundAttributes.contains(spatialExtent.field())
                ? spatialExtent.withFieldExtractPreference(MappedFieldType.FieldExtractPreference.EXTRACT_SPATIAL_BOUNDS)
                : spatialExtent
        );
    }

    /** Returns {@code true} for the shape data types this rule supports: {@code GEO_SHAPE} and {@code CARTESIAN_SHAPE}. */
    private static boolean isShape(DataType dataType) {
        return dataType == DataType.GEO_SHAPE || dataType == DataType.CARTESIAN_SHAPE;
    }

    /**
     * Returns the aggregate function wrapped by {@code expr} when {@code expr} is an {@link Alias} whose
     * child is an {@link AggregateFunction}; otherwise returns an empty optional.
     */
    private static Optional<AggregateFunction> extractAggregateFunction(NamedExpression expr) {
        return expr instanceof Alias as && as.child() instanceof AggregateFunction af ? Optional.of(af) : Optional.empty();
    }
}

