/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.nativeaccess.jdk;

import java.lang.foreign.FunctionDescriptor;
import java.lang.foreign.Linker;
import java.lang.foreign.MemoryLayout;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import org.elasticsearch.nativeaccess.jdk.LinkerHelper;
import org.elasticsearch.nativeaccess.lib.VectorLibrary;

/**
 * {@link VectorLibrary} implementation that delegates to the native {@code vec} library through the
 * Foreign Function &amp; Memory API.
 *
 * <p>The static initializer loads {@code vec}, creates downcall handles for its {@code dot8s},
 * {@code sqr8s}, {@code dot8s_stride} and {@code sqr8s_stride} symbols, and publishes the Java entry
 * points {@link #dotProduct} and {@link #squareDistance} as {@link MethodHandle}s via
 * {@link #dotProductHandle()} and {@link #squareDistanceHandle()}.
 */
public final class JdkVectorLibrary implements VectorLibrary {
    /** Downcall handle for native {@code int dot8s_stride(void)}. */
    static final MethodHandle dot8stride$mh;
    /** Downcall handle for native {@code int sqr8s_stride(void)}. */
    static final MethodHandle sqr8stride$mh;
    /** Downcall handle for native {@code int dot8s(void*, void*, int)}. */
    static final MethodHandle dot8s$mh;
    /** Downcall handle for native {@code int sqr8s(void*, void*, int)}. */
    static final MethodHandle sqr8s$mh;

    /**
     * Element stride assumed for the native dot-product kernel. It is checked against
     * {@code dot8s_stride()} during class initialization when assertions are enabled. It must be a
     * power of two, because the prefix length is computed with a bit mask.
     */
    static final int DOT_STRIDE = 32;
    /**
     * Element stride assumed for the native square-distance kernel. It is checked against
     * {@code sqr8s_stride()} during class initialization when assertions are enabled. It must be a
     * power of two, because the prefix length is computed with a bit mask.
     */
    static final int SQR_STRIDE = 16;

    /** Handle to {@link #dotProduct(MemorySegment, MemorySegment, int)}. */
    static final MethodHandle DOT_HANDLE;
    /** Handle to {@link #squareDistance(MemorySegment, MemorySegment, int)}. */
    static final MethodHandle SQR_HANDLE;

    /**
     * Computes the sum of {@code a[i] * b[i]} over the first {@code length} signed bytes of the two
     * segments.
     *
     * <p>The largest multiple of {@link #DOT_STRIDE} that does not exceed {@code length} is passed to
     * the native {@code dot8s} kernel. The remaining elements are accumulated in Java.
     *
     * @param a first vector, read as signed bytes
     * @param b second vector, read as signed bytes; must have the same byte size as {@code a}
     * @param length number of leading elements to include; expected to be non-negative
     * @return the accumulated {@code int} result
     * @throws IllegalArgumentException if the segment sizes differ, or if {@code length} exceeds them
     */
    static int dotProduct(MemorySegment a, MemorySegment b, int length) {
        checkArguments(a, b, length);
        int i = 0;
        int res = 0;
        if (length >= DOT_STRIDE) {
            // Round length down to a multiple of the (power-of-two) stride; that prefix goes native.
            i = length & ~(DOT_STRIDE - 1);
            res = dot8s(a, b, i);
        }
        // Scalar tail: the final (length % DOT_STRIDE) elements, or all of them for short vectors.
        for (; i < length; i++) {
            res += a.get(ValueLayout.JAVA_BYTE, (long) i) * b.get(ValueLayout.JAVA_BYTE, (long) i);
        }
        assert i == length;
        return res;
    }

    /**
     * Computes the sum of {@code (a[i] - b[i])^2} over the first {@code length} signed bytes of the
     * two segments.
     *
     * <p>The largest multiple of {@link #SQR_STRIDE} that does not exceed {@code length} is passed to
     * the native {@code sqr8s} kernel. The remaining elements are accumulated in Java.
     *
     * @param a first vector, read as signed bytes
     * @param b second vector, read as signed bytes; must have the same byte size as {@code a}
     * @param length number of leading elements to include; expected to be non-negative
     * @return the accumulated {@code int} result
     * @throws IllegalArgumentException if the segment sizes differ, or if {@code length} exceeds them
     */
    static int squareDistance(MemorySegment a, MemorySegment b, int length) {
        checkArguments(a, b, length);
        int i = 0;
        int res = 0;
        if (length >= SQR_STRIDE) {
            // Round length down to a multiple of the (power-of-two) stride; that prefix goes native.
            i = length & ~(SQR_STRIDE - 1);
            res = sqr8s(a, b, i);
        }
        // Scalar tail: the final (length % SQR_STRIDE) elements, or all of them for short vectors.
        for (; i < length; i++) {
            int dist = a.get(ValueLayout.JAVA_BYTE, (long) i) - b.get(ValueLayout.JAVA_BYTE, (long) i);
            res += dist * dist;
        }
        assert i == length;
        return res;
    }

    /**
     * Validates the arguments shared by {@link #dotProduct} and {@link #squareDistance}.
     *
     * @throws IllegalArgumentException if the segment sizes differ, or if {@code length} exceeds them
     */
    private static void checkArguments(MemorySegment a, MemorySegment b, int length) {
        assert length >= 0;
        if (a.byteSize() != b.byteSize()) {
            throw new IllegalArgumentException("dimensions differ: " + a.byteSize() + "!=" + b.byteSize());
        }
        if ((long) length > a.byteSize()) {
            throw new IllegalArgumentException("length: " + length + ", greater than vector dimensions: " + a.byteSize());
        }
    }

    // The (int) casts below matter: invokeExact is signature-polymorphic, so without a cast the
    // call-site return type is Object. That does not compile into an int return, and it would not
    // match the handles' exact (...)int type.

    /** Invokes native {@code dot8s_stride()}. */
    private static int dot8Stride() {
        try {
            return (int) dot8stride$mh.invokeExact();
        } catch (Throwable t) {
            throw new AssertionError(t);
        }
    }

    /** Invokes native {@code sqr8s_stride()}. */
    private static int sqr8Stride() {
        try {
            return (int) sqr8stride$mh.invokeExact();
        } catch (Throwable t) {
            throw new AssertionError(t);
        }
    }

    /** Invokes native {@code dot8s(a, b, length)}. */
    private static int dot8s(MemorySegment a, MemorySegment b, int length) {
        try {
            return (int) dot8s$mh.invokeExact(a, b, length);
        } catch (Throwable t) {
            throw new AssertionError(t);
        }
    }

    /** Invokes native {@code sqr8s(a, b, length)}. */
    private static int sqr8s(MemorySegment a, MemorySegment b, int length) {
        try {
            return (int) sqr8s$mh.invokeExact(a, b, length);
        } catch (Throwable t) {
            throw new AssertionError(t);
        }
    }

    /** Returns a handle of type {@code (MemorySegment, MemorySegment, int)int} bound to {@link #dotProduct}. */
    @Override
    public MethodHandle dotProductHandle() {
        return DOT_HANDLE;
    }

    /** Returns a handle of type {@code (MemorySegment, MemorySegment, int)int} bound to {@link #squareDistance}. */
    @Override
    public MethodHandle squareDistanceHandle() {
        return SQR_HANDLE;
    }

    static {
        System.loadLibrary("vec");
        FunctionDescriptor strideDesc = FunctionDescriptor.of(ValueLayout.JAVA_INT);
        FunctionDescriptor kernelDesc =
            FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.JAVA_INT);
        dot8stride$mh = LinkerHelper.downcallHandle("dot8s_stride", strideDesc, new Linker.Option[0]);
        sqr8stride$mh = LinkerHelper.downcallHandle("sqr8s_stride", strideDesc, new Linker.Option[0]);
        dot8s$mh = LinkerHelper.downcallHandle("dot8s", kernelDesc, new Linker.Option[0]);
        sqr8s$mh = LinkerHelper.downcallHandle("sqr8s", kernelDesc, new Linker.Option[0]);
        // The Java-side prefix/tail split relies on these strides matching the native kernels.
        assert dot8Stride() == DOT_STRIDE : dot8Stride() + " != " + DOT_STRIDE;
        assert sqr8Stride() == SQR_STRIDE : sqr8Stride() + " != " + SQR_STRIDE;
        try {
            MethodHandles.Lookup lookup = MethodHandles.lookup();
            MethodType mt = MethodType.methodType(int.class, MemorySegment.class, MemorySegment.class, int.class);
            DOT_HANDLE = lookup.findStatic(JdkVectorLibrary.class, "dotProduct", mt);
            SQR_HANDLE = lookup.findStatic(JdkVectorLibrary.class, "squareDistance", mt);
        } catch (IllegalAccessException | NoSuchMethodException e) {
            throw new RuntimeException(e);
        }
    }
}

