/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue;

import java.math.BigInteger;
import java.util.Arrays;
import java.util.List;
import org.apache.lucene.util.ArrayUtil;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.AbstractMultivalueFunction;
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMedianDoubleEvaluator;
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMedianIntEvaluator;
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMedianLongEvaluator;
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMedianUnsignedLongEvaluator;
import org.elasticsearch.xpack.esql.planner.PlannerUtils;
import org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter;
import org.elasticsearch.xpack.esql.type.EsqlDataTypes;
import org.elasticsearch.xpack.ql.expression.Expression;
import org.elasticsearch.xpack.ql.expression.TypeResolutions;
import org.elasticsearch.xpack.ql.tree.Node;
import org.elasticsearch.xpack.ql.tree.NodeInfo;
import org.elasticsearch.xpack.ql.tree.Source;
import org.elasticsearch.xpack.ql.type.DataTypes;

/**
 * Reduces a multivalued numeric field to a single value: the median of its values.
 * <p>
 * With an odd number of values the middle value (after sorting) is returned. With an even
 * number the two middle values are averaged. For {@code integer}, {@code long} and
 * {@code unsigned_long} inputs that average is rounded down.
 */
public class MvMedian extends AbstractMultivalueFunction {
    @FunctionInfo(
        returnType = { "double", "integer", "long", "unsigned_long" },
        description = "Converts a multivalued field into a single valued field containing the median value."
    )
    public MvMedian(Source source, @Param(name = "number", type = { "double", "integer", "long", "unsigned_long" }) Expression field) {
        super(source, field);
    }

    /**
     * Accepts numeric data types for which {@code EsqlDataTypes.isRepresentable} holds;
     * anything else fails resolution with the expected type described as {@code "numeric"}.
     */
    @Override
    protected Expression.TypeResolution resolveFieldType() {
        return TypeResolutions.isType(field(), t -> t.isNumeric() && EsqlDataTypes.isRepresentable(t), sourceText(), null, "numeric");
    }

    /**
     * Picks the evaluator matching the {@link ElementType} the field's values are stored as.
     * {@code unsigned_long} shares the {@code LONG} element type with {@code long}, so the two
     * are told apart by the field's data type.
     */
    @Override
    protected EvalOperator.ExpressionEvaluator.Factory evaluator(EvalOperator.ExpressionEvaluator.Factory fieldEval) {
        // Unqualified enum labels: qualified ones (ElementType.DOUBLE) are only legal from Java 21.
        return switch (PlannerUtils.toElementType(field().dataType())) {
            case DOUBLE -> new MvMedianDoubleEvaluator.Factory(fieldEval);
            case INT -> new MvMedianIntEvaluator.Factory(fieldEval);
            case LONG -> field().dataType() == DataTypes.UNSIGNED_LONG
                ? new MvMedianUnsignedLongEvaluator.Factory(fieldEval)
                : new MvMedianLongEvaluator.Factory(fieldEval);
            default -> throw EsqlIllegalArgumentException.illegalDataType(field().dataType());
        };
    }

    @Override
    public Expression replaceChildren(List<Expression> newChildren) {
        return new MvMedian(source(), newChildren.get(0));
    }

    @Override
    protected NodeInfo<? extends Expression> info() {
        return NodeInfo.create(this, MvMedian::new, field());
    }

    // ---------------------------------------------------------------- double

    /** Appends {@code v} to the scratch buffer, growing its backing array when full. */
    static void process(Doubles doubles, double v) {
        if (doubles.values.length < doubles.count + 1) {
            doubles.values = ArrayUtil.grow(doubles.values, doubles.count + 1);
        }
        doubles.values[doubles.count++] = v;
    }

    /**
     * Sorts the collected values, returns their median and resets the buffer for reuse.
     * NOTE(review): assumes at least one value was collected.
     */
    static double finish(Doubles doubles) {
        Arrays.sort(doubles.values, 0, doubles.count);
        int middle = doubles.count / 2;
        double median = doubles.count % 2 == 1
            ? doubles.values[middle]
            : avgWithoutOverflow(doubles.values[middle - 1], doubles.values[middle]);
        doubles.count = 0;
        return median;
    }

    /**
     * Median of the {@code count} values starting at {@code firstValue}, which must already be
     * sorted ascending in the block. NOTE(review): assumes {@code count >= 1}.
     */
    static double ascending(DoubleBlock values, int firstValue, int count) {
        int middle = firstValue + count / 2;
        if (count % 2 == 1) {
            return values.getDouble(middle);
        }
        return avgWithoutOverflow(values.getDouble(middle - 1), values.getDouble(middle));
    }

    /**
     * Mean of two doubles that does not overflow to infinity when both are large finite values
     * of the same sign. For every other input (including infinities and NaN) the result is
     * exactly {@code (a + b) / 2}.
     */
    static double avgWithoutOverflow(double a, double b) {
        double sum = a + b;
        if (Double.isInfinite(sum) && Double.isFinite(a) && Double.isFinite(b)) {
            // Only reachable for huge same-signed operands, where halving is exact.
            return a / 2.0 + b / 2.0;
        }
        return sum / 2.0;
    }

    // ---------------------------------------------------------------- long

    /** Appends {@code v} to the scratch buffer, growing its backing array when full. */
    static void process(Longs longs, long v) {
        if (longs.values.length < longs.count + 1) {
            longs.values = ArrayUtil.grow(longs.values, longs.count + 1);
        }
        longs.values[longs.count++] = v;
    }

    /**
     * Sorts the collected values, returns their median (even counts round down) and resets the
     * buffer for reuse. NOTE(review): assumes at least one value was collected.
     */
    static long finish(Longs longs) {
        Arrays.sort(longs.values, 0, longs.count);
        int middle = longs.count / 2;
        boolean odd = longs.count % 2 == 1;
        longs.count = 0;
        return odd ? longs.values[middle] : avgWithoutOverflow(longs.values[middle - 1], longs.values[middle]);
    }

    /**
     * Median of the {@code count} values starting at {@code firstValue}, which must already be
     * sorted ascending in the block; even counts round down. NOTE(review): assumes {@code count >= 1}.
     */
    static long ascending(LongBlock values, int firstValue, int count) {
        int middle = firstValue + count / 2;
        if (count % 2 == 1) {
            return values.getLong(middle);
        }
        return avgWithoutOverflow(values.getLong(middle - 1), values.getLong(middle));
    }

    /** {@code floor((a + b) / 2)} computed without the intermediate sum overflowing. */
    static long avgWithoutOverflow(long a, long b) {
        return (a & b) + ((a ^ b) >> 1);
    }

    // ---------------------------------------------------------------- unsigned_long

    /** Appends {@code v} (the encoded unsigned value) to the scratch buffer. */
    static void processUnsignedLong(Longs longs, long v) {
        process(longs, v);
    }

    /**
     * Sorts the collected encoded values, returns their median (even counts round down) and
     * resets the buffer for reuse.
     * NOTE(review): sorting uses signed {@code long} order, so this relies on the unsigned_long
     * encoding ordering the same way under signed comparison — confirm against
     * {@code EsqlDataTypeConverter}.
     */
    static long finishUnsignedLong(Longs longs) {
        if (longs.count % 2 == 1) {
            return finish(longs);
        }
        Arrays.sort(longs.values, 0, longs.count);
        int middle = longs.count / 2;
        longs.count = 0;
        return avgUnsignedLongsWithoutOverflow(longs.values[middle - 1], longs.values[middle]);
    }

    /**
     * Median of the {@code count} encoded values starting at {@code firstValue}, which must
     * already be sorted ascending in the block. NOTE(review): assumes {@code count >= 1}.
     */
    static long ascendingUnsignedLong(LongBlock values, int firstValue, int count) {
        int middle = firstValue + count / 2;
        if (count % 2 == 1) {
            return values.getLong(middle);
        }
        return avgUnsignedLongsWithoutOverflow(values.getLong(middle - 1), values.getLong(middle));
    }

    /** Floor of the mean of two encoded unsigned longs, computed in {@link BigInteger} to avoid overflow. */
    private static long avgUnsignedLongsWithoutOverflow(long a, long b) {
        BigInteger sum = EsqlDataTypeConverter.unsignedLongToBigInteger(a).add(EsqlDataTypeConverter.unsignedLongToBigInteger(b));
        return EsqlDataTypeConverter.bigIntegerToUnsignedLong(sum.shiftRight(1));
    }

    // ---------------------------------------------------------------- int

    /** Appends {@code v} to the scratch buffer, growing its backing array when full. */
    static void process(Ints ints, int v) {
        if (ints.values.length < ints.count + 1) {
            ints.values = ArrayUtil.grow(ints.values, ints.count + 1);
        }
        ints.values[ints.count++] = v;
    }

    /**
     * Sorts the collected values, returns their median (even counts round down) and resets the
     * buffer for reuse. NOTE(review): assumes at least one value was collected.
     */
    static int finish(Ints ints) {
        Arrays.sort(ints.values, 0, ints.count);
        int middle = ints.count / 2;
        boolean odd = ints.count % 2 == 1;
        ints.count = 0;
        return odd ? ints.values[middle] : avgWithoutOverflow(ints.values[middle - 1], ints.values[middle]);
    }

    /**
     * Median of the {@code count} values starting at {@code firstValue}, which must already be
     * sorted ascending in the block; even counts round down. NOTE(review): assumes {@code count >= 1}.
     */
    static int ascending(IntBlock values, int firstValue, int count) {
        int middle = firstValue + count / 2;
        if (count % 2 == 1) {
            return values.getInt(middle);
        }
        return avgWithoutOverflow(values.getInt(middle - 1), values.getInt(middle));
    }

    /** {@code floor((a + b) / 2)} computed without the intermediate sum overflowing. */
    static int avgWithoutOverflow(int a, int b) {
        return (a & b) + ((a ^ b) >> 1);
    }

    // ---------------------------------------------------------------- scratch buffers

    /** Reusable buffer: {@code values[0, count)} hold the values collected so far. */
    static class Doubles {
        public double[] values = new double[2];
        public int count;
    }

    /** Reusable buffer: {@code values[0, count)} hold the values collected so far. */
    static class Longs {
        public long[] values = new long[2];
        public int count;
    }

    /** Reusable buffer: {@code values[0, count)} hold the values collected so far. */
    static class Ints {
        public int[] values = new int[2];
        public int count;
    }
}

