/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.expression.function.aggregate;

import java.io.IOException;
import java.time.Duration;
import java.util.List;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.RateDoubleAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.RateIntAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.RateLongAggregatorFunctionSupplier;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
import org.elasticsearch.xpack.esql.core.tree.Node;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.OptionalArgument;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.esql.planner.ToAggregator;
import org.elasticsearch.xpack.esql.type.EsqlDataTypes;

public class Rate
extends AggregateFunction
implements OptionalArgument,
ToAggregator {
    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Rate", Rate::new);
    private static final TimeValue DEFAULT_UNIT = TimeValue.timeValueSeconds((long)1L);
    private final Expression timestamp;
    private final Expression unit;

    @FunctionInfo(returnType={"double"}, description="compute the rate of a counter field. Available in METRICS command only", isAggregation=true)
    public Rate(Source source, @Param(name="field", type={"counter_long|counter_integer|counter_double"}, description="counter field") Expression field, Expression timestamp, @Param(optional=true, name="unit", type={"time_duration"}, description="the unit") Expression unit) {
        super(source, field, unit != null ? List.of(timestamp, unit) : List.of(timestamp));
        this.timestamp = timestamp;
        this.unit = unit;
    }

    public Rate(StreamInput in) throws IOException {
        this(Source.readFrom((StreamInput)((PlanStreamInput)in)), (Expression)in.readNamedWriteable(Expression.class), (Expression)in.readNamedWriteable(Expression.class), (Expression)in.readOptionalNamedWriteable(Expression.class));
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        this.source().writeTo(out);
        out.writeNamedWriteable((NamedWriteable)this.field());
        out.writeNamedWriteable((NamedWriteable)this.timestamp);
        out.writeOptionalNamedWriteable((NamedWriteable)this.unit);
    }

    public String getWriteableName() {
        return Rate.ENTRY.name;
    }

    public static Rate withUnresolvedTimestamp(Source source, Expression field, Expression unit) {
        return new Rate(source, field, (Expression)new UnresolvedAttribute(source, "@timestamp"), unit);
    }

    protected NodeInfo<Rate> info() {
        return NodeInfo.create((Node)this, Rate::new, (Object)this.field(), (Object)this.timestamp, (Object)this.unit);
    }

    public Rate replaceChildren(List<Expression> newChildren) {
        if (this.unit != null) {
            if (newChildren.size() == 3) {
                return new Rate(this.source(), newChildren.get(0), newChildren.get(1), newChildren.get(2));
            }
            assert (false) : "expected 3 children for field, @timestamp, and unit; got " + newChildren;
            throw new IllegalArgumentException("expected 3 children for field, @timestamp, and unit; got " + newChildren);
        }
        if (newChildren.size() == 2) {
            return new Rate(this.source(), newChildren.get(0), newChildren.get(1), null);
        }
        assert (false) : "expected 2 children for field and @timestamp; got " + newChildren;
        throw new IllegalArgumentException("expected 2 children for field and @timestamp; got " + newChildren);
    }

    public DataType dataType() {
        return DataType.DOUBLE;
    }

    @Override
    protected Expression.TypeResolution resolveType() {
        Expression.TypeResolution resolution = TypeResolutions.isType((Expression)this.field(), dt -> dt == DataType.COUNTER_LONG || dt == DataType.COUNTER_INTEGER || dt == DataType.COUNTER_DOUBLE, (String)this.sourceText(), (TypeResolutions.ParamOrdinal)TypeResolutions.ParamOrdinal.FIRST, (String[])new String[]{"counter_long", "counter_integer", "counter_double"});
        if (this.unit != null) {
            resolution = resolution.and(TypeResolutions.isType((Expression)this.unit, dt -> dt.isWholeNumber() || EsqlDataTypes.isTemporalAmount(dt), (String)this.sourceText(), (TypeResolutions.ParamOrdinal)TypeResolutions.ParamOrdinal.SECOND, (String[])new String[]{"time_duration"}));
        }
        return resolution;
    }

    long unitInMillis() {
        Object foldValue;
        if (this.unit == null) {
            return DEFAULT_UNIT.millis();
        }
        if (!this.unit.foldable()) {
            throw new IllegalArgumentException("function [" + this.sourceText() + "] has invalid unit [" + this.unit.sourceText() + "]");
        }
        try {
            foldValue = this.unit.fold();
        }
        catch (Exception e) {
            throw new IllegalArgumentException("function [" + this.sourceText() + "] has invalid unit [" + this.unit.sourceText() + "]");
        }
        if (foldValue instanceof Duration) {
            Duration duration = (Duration)foldValue;
            return duration.toMillis();
        }
        throw new IllegalArgumentException("function [" + this.sourceText() + "] has invalid unit [" + this.unit.sourceText() + "]");
    }

    @Override
    public List<Expression> inputExpressions() {
        return List.of(this.field(), this.timestamp);
    }

    @Override
    public AggregatorFunctionSupplier supplier(List<Integer> inputChannels) {
        if (inputChannels.size() != 2 && inputChannels.size() != 3) {
            throw new IllegalArgumentException("rate requires two for raw input or three channels for partial input; got " + inputChannels);
        }
        long unitInMillis = this.unitInMillis();
        DataType type = this.field().dataType();
        return switch (type) {
            case DataType.COUNTER_LONG -> new RateLongAggregatorFunctionSupplier(inputChannels, unitInMillis);
            case DataType.COUNTER_INTEGER -> new RateIntAggregatorFunctionSupplier(inputChannels, unitInMillis);
            case DataType.COUNTER_DOUBLE -> new RateDoubleAggregatorFunctionSupplier(inputChannels, unitInMillis);
            default -> throw EsqlIllegalArgumentException.illegalDataType(type);
        };
    }

    public String toString() {
        if (this.unit != null) {
            return "rate(" + this.field() + "," + this.unit + ")";
        }
        return "rate(" + this.field() + ")";
    }

    Expression timestamp() {
        return this.timestamp;
    }

    Expression unit() {
        return this.unit;
    }
}

