/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.expression.function.aggregate;

import java.io.IOException;
import java.util.List;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.VersionId;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.logging.LoggerMessageFormat;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.tree.Node;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.SurrogateExpression;
import org.elasticsearch.xpack.esql.expression.function.Example;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvAvg;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;

/**
 * {@code WEIGHTED_AVG(number, weight)} aggregation: the weighted average of a numeric expression.
 * <p>
 * The function is a {@link SurrogateExpression}: {@link #surrogate()} rewrites it into
 * {@code SUM(number * weight) / SUM(weight)}, or into simpler forms when an argument is foldable.
 * Its result type is always {@code double}.
 */
public class WeightedAvg extends AggregateFunction implements SurrogateExpression {
    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
        Expression.class,
        "WeightedAvg",
        WeightedAvg::new
    );

    /** Message template: ordinal of the offending argument, source text of the call, offending value. */
    private static final String INVALID_WEIGHT_ERROR = "{} argument of [{}] cannot be null or 0, received [{}]";

    private final Expression weight;

    /**
     * User-facing constructor; the aggregation is unfiltered ({@link Literal#TRUE}).
     *
     * @param source location of the call in the query
     * @param field  the values being averaged
     * @param weight the weight applied to each value
     */
    @FunctionInfo(returnType={"double"}, description="The weighted average of a numeric expression.", isAggregation=true, examples={@Example(file="stats", tag="weighted-avg")})
    public WeightedAvg(Source source, @Param(name="number", type={"double", "integer", "long"}, description="A numeric value.") Expression field, @Param(name="weight", type={"double", "integer", "long"}, description="A numeric weight.") Expression weight) {
        this(source, field, Literal.TRUE, weight);
    }

    /**
     * Full constructor.
     *
     * @param source location of the call in the query
     * @param field  the values being averaged
     * @param filter per-row filter of the aggregation
     * @param weight the weight applied to each value; passed to the superclass as the only extra parameter
     */
    public WeightedAvg(Source source, Expression field, Expression filter, Expression weight) {
        super(source, field, filter, List.of(weight));
        this.weight = weight;
    }

    /**
     * Deserializing constructor.
     * <p>
     * From {@code V_8_16_0} onward the stream carries the filter and then the parameter list, whose
     * first element is the weight. Older streams carry no filter ({@link Literal#TRUE} is assumed)
     * and a single weight expression. Constructor arguments are evaluated left to right, which keeps
     * the reads in wire order.
     */
    private WeightedAvg(StreamInput in) throws IOException {
        this(
            Source.readFrom((PlanStreamInput) in),
            in.readNamedWriteable(Expression.class),
            in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? in.readNamedWriteable(Expression.class) : Literal.TRUE,
            in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)
                ? in.readNamedWriteableCollectionAsList(Expression.class).get(0)
                : in.readNamedWriteable(Expression.class)
        );
    }

    /** Writes the weight in the legacy (pre-parameter-list) wire format. */
    @Override
    protected void deprecatedWriteParams(StreamOutput out) throws IOException {
        out.writeNamedWriteable(weight);
    }

    @Override
    public String getWriteableName() {
        return ENTRY.name;
    }

    /**
     * Checks that both arguments are numeric but not {@code unsigned_long}. Also checks that the
     * weight is not the {@code null} type. A foldable weight must additionally not fold to
     * {@code null} or to zero.
     */
    @Override
    protected Expression.TypeResolution resolveType() {
        if (childrenResolved() == false) {
            return new Expression.TypeResolution("Unresolved children");
        }
        Expression.TypeResolution resolution = TypeResolutions.isType(
            field(),
            dt -> dt.isNumeric() && dt != DataType.UNSIGNED_LONG,
            sourceText(),
            TypeResolutions.ParamOrdinal.FIRST,
            "numeric except unsigned_long or counter types"
        );
        if (resolution.unresolved()) {
            return resolution;
        }
        resolution = TypeResolutions.isType(
            weight(),
            dt -> dt.isNumeric() && dt != DataType.UNSIGNED_LONG,
            sourceText(),
            TypeResolutions.ParamOrdinal.SECOND,
            "numeric except unsigned_long or counter types"
        );
        if (resolution.unresolved()) {
            return resolution;
        }
        if (weight.dataType() == DataType.NULL) {
            return invalidWeight(null);
        }
        if (weight.foldable() == false) {
            return Expression.TypeResolution.TYPE_RESOLVED;
        }
        Object weightVal = weight.fold(FoldContext.small());
        if (isNullOrZero(weightVal)) {
            return invalidWeight(weightVal);
        }
        return Expression.TypeResolution.TYPE_RESOLVED;
    }

    /** Builds the "cannot be null or 0" resolution failure for the weight argument. */
    private Expression.TypeResolution invalidWeight(Object received) {
        return new Expression.TypeResolution(
            LoggerMessageFormat.format(null, INVALID_WEIGHT_ERROR, TypeResolutions.ParamOrdinal.SECOND, sourceText(), received)
        );
    }

    /**
     * Returns true when a folded weight is {@code null} or numerically zero.
     * <p>
     * The comparison is numeric on purpose. Boxed {@code equals} also requires the same boxed type,
     * so {@code Long.valueOf(0).equals(0)} is false. {@code Double.equals} also treats {@code -0.0}
     * and {@code 0.0} as different. Both of those would let a zero weight through.
     */
    private static boolean isNullOrZero(Object value) {
        if (value == null) {
            return true;
        }
        if (value instanceof Number) {
            return ((Number) value).doubleValue() == 0.0;
        }
        return false;
    }

    @Override
    public DataType dataType() {
        return DataType.DOUBLE;
    }

    @Override
    protected NodeInfo<WeightedAvg> info() {
        return NodeInfo.create(this, WeightedAvg::new, field(), filter(), weight);
    }

    /** Children are ordered as {@code [field, filter, weight]}. */
    @Override
    public WeightedAvg replaceChildren(List<Expression> newChildren) {
        return new WeightedAvg(source(), newChildren.get(0), newChildren.get(1), newChildren.get(2));
    }

    @Override
    public WeightedAvg withFilter(Expression filter) {
        return new WeightedAvg(source(), field(), filter, weight());
    }

    /**
     * Rewrites this aggregation into simpler expressions.
     * <ul>
     *   <li>Foldable field: {@code MV_AVG(field)}.
     *       NOTE(review): this branch does not carry the filter — confirm that is intended.</li>
     *   <li>Foldable weight: the constant weight cancels out of the ratio, so {@code SUM(field) / COUNT(field)}
     *       is used. {@link #resolveType()} rejects a foldable weight of null or zero.</li>
     *   <li>Otherwise: {@code SUM(field * weight) / SUM(weight)}.</li>
     * </ul>
     * Every division produces {@link #dataType()}.
     */
    @Override
    public Expression surrogate() {
        Source s = source();
        Expression field = field();
        Expression weight = weight();
        if (field.foldable()) {
            return new MvAvg(s, field);
        }
        if (weight.foldable()) {
            return new Div(s, new Sum(s, field, filter()), new Count(s, field, filter()), dataType());
        }
        return new Div(s, new Sum(s, new Mul(s, field, weight), filter()), new Sum(s, weight, filter()), dataType());
    }

    /** Returns the weight expression. */
    public Expression weight() {
        return weight;
    }
}

