/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.simdvec;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.Constants;
import org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport;
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;

/**
 * Static vector-utility routines operating on byte/float vectors and on bit-packed byte arrays.
 *
 * <p>Some operations are delegated to a platform-selected {@link ESVectorUtilSupport}
 * implementation obtained from {@link ESVectorizationProvider}; others are implemented here
 * directly in scalar Java.
 */
public class ESVectorUtil {
    /**
     * Handle to either {@link #andBitCountInt} or {@link #andBitCountLong}, chosen once at class
     * initialization based on the OS architecture. Its type is exactly
     * {@code (byte[], byte[]) int}, which every {@code invokeExact} call site must match.
     */
    private static final MethodHandle BIT_COUNT_MH;

    /** Platform-selected implementation used by {@link #ipByteBinByte}. */
    private static final ESVectorUtilSupport IMPL;

    /**
     * Validates the dimensions and delegates to the platform-selected implementation.
     *
     * <p>NOTE(review): the exact semantics are defined by
     * {@code ESVectorUtilSupport#ipByteBinByte}. The length check implies that {@code q} holds four
     * bytes per byte of {@code d} (presumably a 4-bit-sliced query against a bit-packed document
     * vector); confirm against that implementation.
     *
     * @param q query bytes; must have length {@code 4 * d.length}
     * @param d document bytes
     * @return the value computed by the delegate implementation
     * @throws IllegalArgumentException if {@code q.length != 4 * d.length}
     */
    public static long ipByteBinByte(byte[] q, byte[] d) {
        if (q.length != d.length * 4) {
            throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= 4 x " + d.length);
        }
        return IMPL.ipByteBinByte(q, d);
    }

    /**
     * Sums the elements of {@code q} whose corresponding bit in {@code d} is set.
     *
     * <p>Bit {@code k} of the unpacked mask is read from {@code d[k / 8]}, most-significant bit
     * first. So bit 7 of {@code d[i]} selects {@code q[8 * i]} and bit 0 selects {@code q[8 * i + 7]}.
     *
     * @param q byte vector; must have length {@code 8 * d.length}
     * @param d bit mask packed eight bits per byte
     * @return the sum of the selected {@code q} elements (signed bytes, widened to {@code int})
     * @throws IllegalArgumentException if {@code q.length != 8 * d.length}
     */
    public static int ipByteBit(byte[] q, byte[] d) {
        if (q.length != d.length * 8) {
            throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= 8 x " + d.length);
        }
        int result = 0;
        for (int i = 0; i < d.length; ++i) {
            final byte mask = d[i];
            // Walk bits MSB -> LSB so that bit j maps to offset (7 - j) within this group of 8.
            for (int j = Byte.SIZE - 1; j >= 0; --j) {
                if ((mask & (1 << j)) != 0) {
                    result += q[i * Byte.SIZE + Byte.SIZE - 1 - j];
                }
            }
        }
        return result;
    }

    /**
     * Sums the elements of {@code q} whose corresponding bit in {@code d} is set.
     *
     * <p>Uses the same MSB-first bit layout as {@link #ipByteBit(byte[], byte[])}.
     *
     * @param q float vector; must have length {@code 8 * d.length}
     * @param d bit mask packed eight bits per byte
     * @return the sum of the selected {@code q} elements, accumulated in index order
     * @throws IllegalArgumentException if {@code q.length != 8 * d.length}
     */
    public static float ipFloatBit(float[] q, byte[] d) {
        if (q.length != d.length * 8) {
            throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= 8 x " + d.length);
        }
        float result = 0.0f;
        for (int i = 0; i < d.length; ++i) {
            final byte mask = d[i];
            for (int j = Byte.SIZE - 1; j >= 0; --j) {
                if ((mask & (1 << j)) != 0) {
                    result += q[i * Byte.SIZE + Byte.SIZE - 1 - j];
                }
            }
        }
        return result;
    }

    /**
     * Counts the bits set in the bitwise AND of {@code a} and {@code b}.
     *
     * @param a first bit-packed array
     * @param b second bit-packed array; must have the same length as {@code a}
     * @return the population count of {@code a & b}
     * @throws IllegalArgumentException if the arrays differ in length
     */
    public static int andBitCount(byte[] a, byte[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
        }
        try {
            // invokeExact is signature-polymorphic. The (int) cast makes the call-site type
            // (byte[],byte[])int, which matches BIT_COUNT_MH exactly. Without it the call site is
            // typed as returning Object, and invokeExact would reject it.
            return (int) BIT_COUNT_MH.invokeExact(a, b);
        } catch (Throwable t) {
            // invokeExact declares Throwable. Rethrow unchecked throwables unchanged and wrap
            // anything checked.
            if (t instanceof Error) {
                throw (Error) t;
            }
            if (t instanceof RuntimeException) {
                throw (RuntimeException) t;
            }
            throw new RuntimeException(t);
        }
    }

    /**
     * Population count of {@code a & b}, reading four bytes at a time as native-order ints and
     * handling the remaining bytes one at a time. Assumes {@code b.length >= a.length}.
     */
    static int andBitCountInt(byte[] a, byte[] b) {
        int distance = 0;
        // Largest multiple of Integer.BYTES that is <= a.length.
        final int upperBound = a.length & ~(Integer.BYTES - 1);
        int i = 0;
        // Byte order is irrelevant here: popcount(x & y) does not depend on how bytes are arranged.
        for (; i < upperBound; i += Integer.BYTES) {
            distance += Integer.bitCount((int) BitUtil.VH_NATIVE_INT.get(a, i) & (int) BitUtil.VH_NATIVE_INT.get(b, i));
        }
        for (; i < a.length; ++i) {
            distance += Integer.bitCount(a[i] & b[i] & 0xFF);
        }
        return distance;
    }

    /**
     * Population count of {@code a & b}, reading eight bytes at a time as native-order longs and
     * handling the remaining bytes one at a time. Assumes {@code b.length >= a.length}.
     */
    static int andBitCountLong(byte[] a, byte[] b) {
        int distance = 0;
        // Largest multiple of Long.BYTES that is <= a.length.
        final int upperBound = a.length & ~(Long.BYTES - 1);
        int i = 0;
        // Byte order is irrelevant here: popcount(x & y) does not depend on how bytes are arranged.
        for (; i < upperBound; i += Long.BYTES) {
            distance += Long.bitCount((long) BitUtil.VH_NATIVE_LONG.get(a, i) & (long) BitUtil.VH_NATIVE_LONG.get(b, i));
        }
        for (; i < a.length; ++i) {
            distance += Integer.bitCount(a[i] & b[i] & 0xFF);
        }
        return distance;
    }

    static {
        // The int-wide variant is selected on aarch64 and the long-wide variant elsewhere.
        // NOTE(review): presumably this choice is for performance; confirm with benchmarks.
        final String bitCountMethod = Constants.OS_ARCH.equals("aarch64") ? "andBitCountInt" : "andBitCountLong";
        try {
            BIT_COUNT_MH = MethodHandles.lookup()
                .findStatic(ESVectorUtil.class, bitCountMethod, MethodType.methodType(int.class, byte[].class, byte[].class));
        } catch (IllegalAccessException | NoSuchMethodException e) {
            // Both target methods are declared in this class, so this indicates a programming error.
            throw new AssertionError(e);
        }
        IMPL = ESVectorizationProvider.getInstance().getVectorUtilSupport();
    }
}

