/*
 * Decompiled with CFR 0.152.
 */
package com.nvidia.cuvs.spi;

import com.nvidia.cuvs.BruteForceIndex;
import com.nvidia.cuvs.CagraIndex;
import com.nvidia.cuvs.CagraMergeParams;
import com.nvidia.cuvs.CuVSDeviceMatrix;
import com.nvidia.cuvs.CuVSHostMatrix;
import com.nvidia.cuvs.CuVSMatrix;
import com.nvidia.cuvs.CuVSResources;
import com.nvidia.cuvs.GPUInfoProvider;
import com.nvidia.cuvs.HnswIndex;
import com.nvidia.cuvs.TieredIndex;
import com.nvidia.cuvs.internal.BruteForceIndexImpl;
import com.nvidia.cuvs.internal.CagraIndexImpl;
import com.nvidia.cuvs.internal.CuVSDeviceMatrixImpl;
import com.nvidia.cuvs.internal.CuVSDeviceMatrixRMMImpl;
import com.nvidia.cuvs.internal.CuVSHostMatrixArenaImpl;
import com.nvidia.cuvs.internal.CuVSHostMatrixImpl;
import com.nvidia.cuvs.internal.CuVSMatrixInternal;
import com.nvidia.cuvs.internal.CuVSResourcesImpl;
import com.nvidia.cuvs.internal.GPUInfoProviderImpl;
import com.nvidia.cuvs.internal.HnswIndexImpl;
import com.nvidia.cuvs.internal.TieredIndexImpl;
import com.nvidia.cuvs.internal.common.PinnedMemoryBuffer;
import com.nvidia.cuvs.internal.common.Util;
import com.nvidia.cuvs.internal.panama.headers_h;
import com.nvidia.cuvs.internal.panama.headers_h_1;
import com.nvidia.cuvs.spi.CuVSProvider;
import com.nvidia.cuvs.spi.NativeDependencyLoader;
import com.nvidia.cuvs.spi.ProviderInitializationException;
import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.net.URISyntaxException;
import java.nio.file.FileSystemNotFoundException;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.util.Locale;
import java.util.Objects;
import java.util.function.IntUnaryOperator;
import java.util.jar.JarFile;
import java.util.jar.Manifest;

/**
 * {@link CuVSProvider} implementation backed by the native {@code libcuvs_c} library, accessed
 * through the Java Foreign Function &amp; Memory API.
 *
 * <p>Instances are obtained through {@link #create()}, which loads the native libraries and, when
 * possible, checks the native library version against this jar's manifest.
 */
final class JDKProvider implements CuVSProvider {

    /** Handle to {@code createNativeDataset}; returned by {@link #newNativeMatrixBuilder()}. */
    private static final MethodHandle createNativeDataset$mh;

    /**
     * Handle to {@code createNativeDatasetWithStrides}; returned by
     * {@link #newNativeMatrixBuilderWithStrides()}.
     */
    private static final MethodHandle createNativeDatasetWithStrides$mh;

    static {
        try {
            MethodHandles.Lookup lookup = MethodHandles.lookup();
            createNativeDataset$mh = lookup.findStatic(
                    JDKProvider.class,
                    "createNativeDataset",
                    MethodType.methodType(CuVSMatrix.class, MemorySegment.class, int.class, int.class, CuVSMatrix.DataType.class));
            createNativeDatasetWithStrides$mh = lookup.findStatic(
                    JDKProvider.class,
                    "createNativeDatasetWithStrides",
                    MethodType.methodType(
                            CuVSMatrix.class, MemorySegment.class, int.class, int.class, int.class, int.class, CuVSMatrix.DataType.class));
        } catch (IllegalAccessException | NoSuchMethodException e) {
            // Both targets are private static methods of this class, so this indicates a build defect.
            throw new RuntimeException(e);
        }
    }

    private JDKProvider() {
    }

    /**
     * Loads the native libraries and creates a provider.
     *
     * <p>If this class was loaded from a jar whose manifest carries an {@code Implementation-Version},
     * the version reported by {@code cuvsVersionGet}, formatted as {@code %02d.%02d.%d}, must equal it.
     *
     * @return a new provider instance
     * @throws ProviderInitializationException if the native {@code libcuvs_c} version does not match
     *     the manifest version, or as propagated from {@code NativeDependencyLoader.loadLibraries()}
     */
    static CuVSProvider create() throws ProviderInitializationException {
        NativeDependencyLoader.loadLibraries();
        String mavenVersion = readCuVSVersionFromManifest();
        try (Arena localArena = Arena.ofConfined()) {
            MemorySegment majorPtr = localArena.allocate(headers_h.uint16_t);
            MemorySegment minorPtr = localArena.allocate(headers_h.uint16_t);
            MemorySegment patchPtr = localArena.allocate(headers_h.uint16_t);
            Util.checkCuVSError(headers_h.cuvsVersionGet(majorPtr, minorPtr, patchPtr), "cuvsVersionGet");
            // The native values are uint16_t: widen without sign extension so values above
            // Short.MAX_VALUE are not rendered as negative numbers.
            int major = Short.toUnsignedInt(majorPtr.get(headers_h.uint16_t, 0L));
            int minor = Short.toUnsignedInt(minorPtr.get(headers_h.uint16_t, 0L));
            int patch = Short.toUnsignedInt(patchPtr.get(headers_h.uint16_t, 0L));
            String cuvsVersionString = String.format(Locale.ROOT, "%02d.%02d.%d", major, minor, patch);
            if (mavenVersion != null && !cuvsVersionString.equals(mavenVersion)) {
                throw new ProviderInitializationException(String.format(
                        Locale.ROOT, "libcuvs_c version mismatch: expected [%s], found [%s]", mavenVersion, cuvsVersionString));
            }
        }
        return new JDKProvider();
    }

    /**
     * Reads {@code Implementation-Version} from the manifest of the jar this class was loaded from.
     *
     * <p>The version check is best effort. This method returns {@code null} instead of failing when
     * the class was not loaded from a local jar file (for example an exploded classes directory or a
     * non-{@code file:} code source), when the jar cannot be read, or when it has no manifest or no
     * such attribute.
     *
     * @return the manifest version, or {@code null} if unavailable
     */
    private static String readCuVSVersionFromManifest() {
        var codeSource = JDKProvider.class.getProtectionDomain().getCodeSource();
        if (codeSource == null || codeSource.getLocation() == null) {
            return null;
        }
        Path location;
        try {
            // Go through URI rather than URL.getPath(): getPath() keeps percent-escapes (e.g. %20),
            // which made jars under paths with spaces unreadable and silently disabled the check.
            location = Path.of(codeSource.getLocation().toURI());
        } catch (URISyntaxException | IllegalArgumentException | FileSystemNotFoundException e) {
            return null;
        }
        if (!Files.isRegularFile(location)) {
            return null;
        }
        try (JarFile jarFile = new JarFile(location.toFile())) {
            Manifest manifest = jarFile.getManifest();
            return manifest == null ? null : manifest.getMainAttributes().getValue("Implementation-Version");
        } catch (IOException e) {
            return null;
        }
    }

    /** Target of {@code createNativeDataset$mh}: wraps {@code memorySegment} in a host matrix. */
    private static CuVSMatrix createNativeDataset(MemorySegment memorySegment, int size, int dimensions, CuVSMatrix.DataType dataType) {
        return new CuVSHostMatrixImpl(memorySegment, size, dimensions, dataType);
    }

    /** Target of {@code createNativeDatasetWithStrides$mh}: strided variant of {@code createNativeDataset}. */
    private static CuVSMatrix createNativeDatasetWithStrides(
            MemorySegment memorySegment, int size, int dimensions, int rowStride, int columnStride, CuVSMatrix.DataType dataType) {
        return new CuVSHostMatrixImpl(memorySegment, size, dimensions, rowStride, columnStride, dataType);
    }

    /**
     * Creates resources that use {@code tempDirectory} as their scratch directory.
     *
     * @throws NullPointerException if {@code tempDirectory} is null
     * @throws IllegalArgumentException if {@code tempDirectory} does not exist or is not a directory
     */
    @Override
    public CuVSResources newCuVSResources(Path tempDirectory) {
        Objects.requireNonNull(tempDirectory);
        if (Files.notExists(tempDirectory)) {
            throw new IllegalArgumentException("does not exist:" + tempDirectory);
        }
        if (!Files.isDirectory(tempDirectory)) {
            throw new IllegalArgumentException("not a directory:" + tempDirectory);
        }
        return new CuVSResourcesImpl(tempDirectory);
    }

    @Override
    public BruteForceIndex.Builder newBruteForceIndexBuilder(CuVSResources cuVSResources) {
        return BruteForceIndexImpl.newBuilder(Objects.requireNonNull(cuVSResources));
    }

    @Override
    public CagraIndex.Builder newCagraIndexBuilder(CuVSResources cuVSResources) {
        return CagraIndexImpl.newBuilder(Objects.requireNonNull(cuVSResources));
    }

    @Override
    public HnswIndex.Builder newHnswIndexBuilder(CuVSResources cuVSResources) {
        return HnswIndexImpl.newBuilder(Objects.requireNonNull(cuVSResources));
    }

    @Override
    public TieredIndex.Builder newTieredIndexBuilder(CuVSResources cuVSResources) {
        return TieredIndexImpl.newBuilder(Objects.requireNonNull(cuVSResources));
    }

    /**
     * Merges CAGRA indexes with default merge parameters.
     *
     * @throws IllegalArgumentException if {@code indexes} is null or empty
     */
    @Override
    public CagraIndex mergeCagraIndexes(CagraIndex[] indexes) {
        if (indexes == null || indexes.length == 0) {
            throw new IllegalArgumentException("At least one index must be provided for merging");
        }
        return CagraIndexImpl.merge(indexes);
    }

    /**
     * Merges CAGRA indexes with the given merge parameters.
     *
     * @throws IllegalArgumentException if {@code indexes} is null or empty
     */
    @Override
    public CagraIndex mergeCagraIndexes(CagraIndex[] indexes, CagraMergeParams mergeParams) {
        if (indexes == null || indexes.length == 0) {
            throw new IllegalArgumentException("At least one index must be provided for merging");
        }
        return CagraIndexImpl.merge(indexes, mergeParams);
    }

    @Override
    public GPUInfoProvider gpuInfoProvider() {
        return new GPUInfoProviderImpl();
    }

    @Override
    public CuVSMatrix.Builder<CuVSHostMatrix> newHostMatrixBuilder(long size, long columns, CuVSMatrix.DataType dataType) {
        return new HostMatrixBuilder(size, columns, dataType);
    }

    @Override
    public CuVSMatrix.Builder<CuVSHostMatrix> newHostMatrixBuilder(
            long size, long columns, int rowStride, int columnStride, CuVSMatrix.DataType dataType) {
        return new HostMatrixBuilder(size, columns, rowStride, columnStride, dataType);
    }

    @Override
    public CuVSMatrix.Builder<CuVSDeviceMatrix> newDeviceMatrixBuilder(
            CuVSResources resources, long size, long columns, CuVSMatrix.DataType dataType) {
        return new DeviceMatrixBuilder(resources, size, columns, dataType);
    }

    @Override
    public CuVSMatrix.Builder<CuVSDeviceMatrix> newDeviceMatrixBuilder(
            CuVSResources resources, long size, long columns, int rowStride, int columnStride, CuVSMatrix.DataType dataType) {
        return new DeviceMatrixBuilder(resources, size, columns, rowStride, columnStride, dataType);
    }

    /** Returns a handle of type {@code (MemorySegment, int, int, DataType) -> CuVSMatrix}. */
    @Override
    public MethodHandle newNativeMatrixBuilder() {
        return createNativeDataset$mh;
    }

    /** Returns a handle of type {@code (MemorySegment, int, int, int, int, DataType) -> CuVSMatrix}. */
    @Override
    public MethodHandle newNativeMatrixBuilderWithStrides() {
        return createNativeDatasetWithStrides$mh;
    }

    /**
     * Copies a rectangular {@code float} array into a new host matrix.
     *
     * @throws NullPointerException if {@code vectors} or any row is null
     * @throws IllegalArgumentException if {@code vectors} is empty or its rows differ in length
     */
    @Override
    public CuVSMatrix newMatrixFromArray(float[][] vectors) {
        int columns = checkedColumnCount(vectors, i -> vectors[i].length);
        CuVSHostMatrixArenaImpl dataset = new CuVSHostMatrixArenaImpl(vectors.length, columns, CuVSMatrix.DataType.FLOAT);
        Util.copy(dataset.memorySegment(), vectors);
        return dataset;
    }

    /**
     * Copies a rectangular {@code int} array into a new host matrix.
     *
     * @throws NullPointerException if {@code vectors} or any row is null
     * @throws IllegalArgumentException if {@code vectors} is empty or its rows differ in length
     */
    @Override
    public CuVSMatrix newMatrixFromArray(int[][] vectors) {
        int columns = checkedColumnCount(vectors, i -> vectors[i].length);
        CuVSHostMatrixArenaImpl dataset = new CuVSHostMatrixArenaImpl(vectors.length, columns, CuVSMatrix.DataType.INT);
        Util.copy(dataset.memorySegment(), vectors);
        return dataset;
    }

    /**
     * Copies a rectangular {@code byte} array into a new host matrix.
     *
     * @throws NullPointerException if {@code vectors} or any row is null
     * @throws IllegalArgumentException if {@code vectors} is empty or its rows differ in length
     */
    @Override
    public CuVSMatrix newMatrixFromArray(byte[][] vectors) {
        int columns = checkedColumnCount(vectors, i -> vectors[i].length);
        CuVSHostMatrixArenaImpl dataset = new CuVSHostMatrixArenaImpl(vectors.length, columns, CuVSMatrix.DataType.BYTE);
        Util.copy(dataset.memorySegment(), vectors);
        return dataset;
    }

    /**
     * Validates that {@code vectors} is non-null, non-empty and rectangular.
     *
     * @param vectors the row array (only its length and null-ness are inspected directly)
     * @param rowLength returns the length of row {@code i}
     * @return the common row length
     */
    private static int checkedColumnCount(Object[] vectors, IntUnaryOperator rowLength) {
        Objects.requireNonNull(vectors);
        if (vectors.length == 0) {
            throw new IllegalArgumentException("vectors should not be empty");
        }
        int columns = rowLength.applyAsInt(0);
        for (int i = 1; i < vectors.length; i++) {
            int length = rowLength.applyAsInt(i);
            if (length != columns) {
                // A ragged row would otherwise be copied at the wrong offsets or past the matrix end.
                throw new IllegalArgumentException(String.format(
                        Locale.ROOT, "All vectors must have the same length: vector [0] has [%d] elements, vector [%d] has [%d]",
                        columns, i, length));
            }
        }
        return columns;
    }

    /** Builds a host matrix by copying each added row straight into its arena-backed segment. */
    private static final class HostMatrixBuilder
            extends MatrixBuilder<CuVSHostMatrixImpl>
            implements CuVSMatrix.Builder<CuVSHostMatrix> {

        private HostMatrixBuilder(long size, long columns, CuVSMatrix.DataType dataType) {
            super(new CuVSHostMatrixArenaImpl(size, columns, dataType), size, columns);
        }

        private HostMatrixBuilder(long size, long columns, int rowStride, int columnStride, CuVSMatrix.DataType dataType) {
            super(new CuVSHostMatrixArenaImpl(size, columns, rowStride, columnStride, dataType), size, columns, rowStride);
        }

        @Override
        protected void internalAddVector(MemorySegment vector) {
            if (currentRow >= size) {
                throw new ArrayIndexOutOfBoundsException();
            }
            MemorySegment.copy(vector, 0L, matrix.memorySegment(), currentRow * rowSize, rowBytes);
            // Advance only after a successful copy so a failed copy does not leave a gap.
            currentRow++;
        }

        @Override
        public CuVSHostMatrix build() {
            return (CuVSHostMatrix) matrix;
        }
    }

    /**
     * Builds a device matrix by staging rows in a pinned host buffer and copying full batches to the
     * device with {@code cudaMemcpy2DAsync}, synchronizing the stream after each batch.
     */
    private static final class DeviceMatrixBuilder
            extends MatrixBuilder<CuVSDeviceMatrixImpl>
            implements CuVSMatrix.Builder<CuVSDeviceMatrix> {

        private final MemorySegment stream;
        private final PinnedMemoryBuffer hostBuffer;
        /** Number of rows the staging buffer holds before it is flushed (never more than {@code size}). */
        private final long bufferRowCount;
        /** Rows currently staged in {@code hostBuffer} and not yet copied to the device. */
        private int currentBufferRow;
        /** Set once {@link #build()} has flushed and released the staging buffer. */
        private boolean built;

        // NOTE(review): bufferRowCount assumes rowBytes > 0 and that hostBuffer holds at least one
        // row; confirm PinnedMemoryBuffer's sizing guarantees this.
        private DeviceMatrixBuilder(CuVSResources resources, long size, long columns, CuVSMatrix.DataType dataType) {
            super(CuVSDeviceMatrixRMMImpl.create(resources, size, columns, dataType), size, columns);
            this.stream = Util.getStream(resources);
            this.hostBuffer = new PinnedMemoryBuffer(size, columns, matrix.valueLayout());
            this.bufferRowCount = Math.min(hostBuffer.size() / rowBytes, size);
            this.currentBufferRow = 0;
        }

        private DeviceMatrixBuilder(
                CuVSResources resources, long size, long columns, int rowStride, int columnStride, CuVSMatrix.DataType dataType) {
            super(CuVSDeviceMatrixRMMImpl.create(resources, size, columns, rowStride, columnStride, dataType), size, columns, rowStride);
            this.stream = Util.getStream(resources);
            this.hostBuffer = new PinnedMemoryBuffer(size, columns, matrix.valueLayout());
            this.bufferRowCount = Math.min(hostBuffer.size() / rowBytes, size);
            this.currentBufferRow = 0;
        }

        @Override
        protected void internalAddVector(MemorySegment vector) {
            if (built) {
                // The staging buffer has been closed by build(); writing to it would be unsafe.
                throw new IllegalStateException("build() has already been called on this builder");
            }
            if (currentRow >= size) {
                throw new ArrayIndexOutOfBoundsException();
            }
            long hostBufferOffset = (long) currentBufferRow * rowBytes;
            MemorySegment.copy(vector, 0L, hostBuffer.address(), hostBufferOffset, rowBytes);
            currentRow++;
            currentBufferRow++;
            if (currentBufferRow == bufferRowCount) {
                flushBuffer();
            }
        }

        /** Copies the staged rows to their destination rows on the device and waits for the copy. */
        private void flushBuffer() {
            if (currentBufferRow > 0) {
                // The staged rows are the last currentBufferRow rows added so far.
                long deviceMemoryOffset = (currentRow - currentBufferRow) * rowSize;
                MemorySegment dst = matrix.memorySegment().asSlice(deviceMemoryOffset);
                // dpitch = rowSize honors the destination row stride; the source rows are packed.
                Util.checkCudaError(
                        headers_h.cudaMemcpy2DAsync(
                                dst, rowSize, hostBuffer.address(), rowBytes, rowBytes, currentBufferRow,
                                Util.CudaMemcpyKind.HOST_TO_DEVICE.kind, stream),
                        "cudaMemcpy2DAsync");
                currentBufferRow = 0;
                // Wait so the host buffer can be safely reused for the next batch.
                Util.checkCudaError(headers_h_1.cudaStreamSynchronize(stream), "cudaStreamSynchronize");
            }
        }

        /**
         * Flushes any staged rows and releases the staging buffer. Repeated calls return the same
         * matrix without releasing the buffer again.
         */
        @Override
        public CuVSDeviceMatrix build() {
            if (!built) {
                flushBuffer();
                hostBuffer.close();
                built = true;
            }
            return (CuVSDeviceMatrix) matrix;
        }
    }

    /**
     * Shared row-by-row filling logic for matrix builders.
     *
     * @param <T> the concrete matrix type being filled
     */
    private abstract static class MatrixBuilder<T extends CuVSMatrixInternal> {
        protected final long columns;
        protected final long size;
        protected final T matrix;
        protected final long elementSize;
        /** Byte distance between consecutive rows in the destination (honors a positive row stride). */
        protected final long rowSize;
        /** Payload bytes copied per row: {@code columns * elementSize}. */
        protected final long rowBytes;
        /** Number of rows added so far. */
        protected long currentRow;

        protected MatrixBuilder(T matrix, long size, long columns) {
            // A non-positive stride means densely packed rows, i.e. rowSize == rowBytes.
            this(matrix, size, columns, 0);
        }

        protected MatrixBuilder(T matrix, long size, long columns, int rowStride) {
            this.columns = columns;
            this.size = size;
            this.matrix = matrix;
            this.elementSize = matrix.valueLayout().byteSize();
            this.rowBytes = columns * elementSize;
            this.rowSize = rowStride > 0 ? (long) rowStride * elementSize : rowBytes;
            this.currentRow = 0;
        }

        public void addVector(float[] vector) {
            addChecked(vector.length, MemorySegment.ofArray(vector));
        }

        public void addVector(byte[] vector) {
            addChecked(vector.length, MemorySegment.ofArray(vector));
        }

        public void addVector(int[] vector) {
            addChecked(vector.length, MemorySegment.ofArray(vector));
        }

        /** Validates element count and element width, then appends the row. */
        private void addChecked(int length, MemorySegment vector) {
            if ((long) length != columns) {
                throw new IllegalArgumentException(String.format(
                        Locale.ROOT, "Expected a vector of size [%d], got [%d]", columns, length));
            }
            if (vector.byteSize() != rowBytes) {
                // Element width differs from the matrix data type (e.g. float[] into a BYTE matrix);
                // copying rowBytes would read garbage or run past the source array.
                // length == columns > 0 here, since equal zero-length rows cannot mismatch.
                throw new IllegalArgumentException(String.format(
                        Locale.ROOT, "Vector element size [%d bytes] does not match the matrix element size [%d bytes]",
                        vector.byteSize() / length, elementSize));
            }
            internalAddVector(vector);
        }

        /** Appends one row whose byte size has already been validated to equal {@code rowBytes}. */
        protected abstract void internalAddVector(MemorySegment vector);
    }
}

