/*
 * Decompiled with CFR 0.152.
 */
package com.nvidia.cuvs.internal;

import com.nvidia.cuvs.CuVSResources;
import com.nvidia.cuvs.CuVSResourcesInfo;
import com.nvidia.cuvs.GPUInfo;
import com.nvidia.cuvs.GPUInfoProvider;
import com.nvidia.cuvs.internal.common.LinkerHelper;
import com.nvidia.cuvs.internal.common.Util;
import com.nvidia.cuvs.internal.panama.cudaDeviceProp;
import com.nvidia.cuvs.internal.panama.headers_h;
import com.nvidia.cuvs.internal.panama.headers_h_1;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.util.ArrayList;
import java.util.List;

/**
 * {@link GPUInfoProvider} implementation that queries the CUDA runtime through the
 * Panama-generated bindings ({@code headers_h}, {@code headers_h_1}, {@code cudaDeviceProp}).
 */
public class GPUInfoProviderImpl implements GPUInfoProvider {

    /** Major part of the minimum accepted CUDA compute capability (7.0). */
    private static final int MIN_COMPUTE_CAPABILITY_MAJOR = 7;

    /** Minor part of the minimum accepted CUDA compute capability (7.0). */
    private static final int MIN_COMPUTE_CAPABILITY_MINOR = 0;

    /** Minimum total device memory for a GPU to be reported as compatible: 8 GiB (0x200000000). */
    private static final long MIN_DEVICE_MEMORY_IN_BYTES = 8L * 1024 * 1024 * 1024;

    /**
     * Returns whether the GPU's compute capability is at least
     * {@code MIN_COMPUTE_CAPABILITY_MAJOR.MIN_COMPUTE_CAPABILITY_MINOR}.
     *
     * @param gpuInfo the GPU to check
     * @return {@code true} if the (major, minor) capability pair meets the minimum
     */
    private static boolean hasMinimumCapability(GPUInfo gpuInfo) {
        int major = gpuInfo.computeCapabilityMajor();
        return major > MIN_COMPUTE_CAPABILITY_MAJOR
                || (major == MIN_COMPUTE_CAPABILITY_MAJOR
                        && gpuInfo.computeCapabilityMinor() >= MIN_COMPUTE_CAPABILITY_MINOR);
    }

    /**
     * Returns every CUDA device reported by the runtime.
     *
     * <p>The list is computed once, on first use, and then cached. It is unmodifiable because
     * the same instance is handed to every caller.
     *
     * @return an unmodifiable list of all detected GPUs, in device-id order
     */
    @Override
    public List<GPUInfo> availableGPUs() {
        return AvailableGpuInitializer.AVAILABLE_GPUS;
    }

    /**
     * Returns the subset of {@link #availableGPUs()} that meets both of these requirements:
     * a compute capability of at least 7.0, and at least 8 GiB of total device memory.
     *
     * @return a new, caller-owned list of compatible GPUs, in device-id order
     */
    @Override
    public List<GPUInfo> compatibleGPUs() {
        List<GPUInfo> compatible = new ArrayList<>();
        for (GPUInfo gpuInfo : AvailableGpuInitializer.AVAILABLE_GPUS) {
            if (hasMinimumCapability(gpuInfo)
                    && gpuInfo.totalDeviceMemoryInBytes() >= MIN_DEVICE_MEMORY_IN_BYTES) {
                compatible.add(gpuInfo);
            }
        }
        return compatible;
    }

    /**
     * Queries free and total memory of the device that {@code resources} is bound to.
     *
     * <p>{@code cudaMemGetInfo} reports on the current CUDA device. If that is not
     * {@code resources.deviceId()}, this method switches to it for the query. It then switches
     * back to the previously current device, on the failure path as well as on success.
     *
     * @param resources the resources whose device should be inspected
     * @return the free and total memory (in the units reported by {@code cudaMemGetInfo})
     */
    @Override
    public CuVSResourcesInfo getCurrentInfo(CuVSResources resources) {
        int targetDeviceId = resources.deviceId();
        try (Arena localArena = Arena.ofConfined()) {
            MemorySegment deviceIdPtr = localArena.allocate(LinkerHelper.C_INT);
            Util.checkCudaError(headers_h_1.cudaGetDevice(deviceIdPtr), "cudaGetDevice");
            int originalDeviceId = deviceIdPtr.get(LinkerHelper.C_INT, 0L);

            boolean switchedDevice = targetDeviceId != originalDeviceId;
            if (switchedDevice) {
                setDevice(targetDeviceId);
            }

            MemorySegment freeMemoryPtr = localArena.allocate(headers_h_1.size_t);
            MemorySegment totalMemoryPtr = localArena.allocate(headers_h_1.size_t);
            try {
                Util.checkCudaError(headers_h.cudaMemGetInfo(freeMemoryPtr, totalMemoryPtr), "cudaMemGetInfo");
            } catch (RuntimeException | Error failure) {
                // Restore the caller's device even when the query fails. A restore failure
                // must not mask the original error, so it is attached as suppressed.
                if (switchedDevice) {
                    try {
                        setDevice(originalDeviceId);
                    } catch (RuntimeException | Error restoreFailure) {
                        failure.addSuppressed(restoreFailure);
                    }
                }
                throw failure;
            }
            if (switchedDevice) {
                setDevice(originalDeviceId);
            }

            return new CuVSResourcesInfo(
                    freeMemoryPtr.get(headers_h_1.size_t, 0L),
                    totalMemoryPtr.get(headers_h_1.size_t, 0L));
        }
    }

    /**
     * Makes {@code deviceId} the current CUDA device, checking the result code.
     *
     * @param deviceId the CUDA device ordinal to activate
     */
    private static void setDevice(int deviceId) {
        Util.checkCudaError(headers_h_1.cudaSetDevice(deviceId), "cudaSetDevice");
    }

    /**
     * Holder for the lazily computed device list.
     *
     * <p>The list is built when this class is first initialized, which happens on the first
     * access to {@link #AVAILABLE_GPUS}.
     *
     * <p>NOTE(review): if the CUDA queries throw during that initialization (presumably via
     * {@code Util.checkCudaError}), the failure surfaces as a class-initialization error. Later
     * accesses will then fail as well. Confirm this is the intended failure mode.
     */
    private static class AvailableGpuInitializer {
        static final List<GPUInfo> AVAILABLE_GPUS = AvailableGpuInitializer.getAvailableGpusInfo();

        private AvailableGpuInitializer() {
        }

        /**
         * Enumerates all CUDA devices and reads their properties.
         *
         * @return an unmodifiable list with one {@link GPUInfo} per device, indexed by device id
         */
        private static List<GPUInfo> getAvailableGpusInfo() {
            try (Arena localArena = Arena.ofConfined()) {
                MemorySegment numGpus = localArena.allocate(LinkerHelper.C_INT);
                Util.checkCudaError(headers_h_1.cudaGetDeviceCount(numGpus), "cudaGetDeviceCount");
                int gpuCount = numGpus.get(LinkerHelper.C_INT, 0L);

                List<GPUInfo> gpus = new ArrayList<>();
                // A single native properties struct is reused for every device. Each GPUInfo
                // reads its values (a Java String and primitives) out of it before the next
                // iteration overwrites it.
                MemorySegment deviceProp = cudaDeviceProp.allocate(localArena);
                for (int deviceId = 0; deviceId < gpuCount; ++deviceId) {
                    Util.checkCudaError(Util.cudaGetDeviceProperties(deviceProp, deviceId), "cudaGetDeviceProperties");
                    gpus.add(new GPUInfo(
                            deviceId,
                            cudaDeviceProp.name(deviceProp).getString(0L),
                            cudaDeviceProp.totalGlobalMem(deviceProp),
                            cudaDeviceProp.major(deviceProp),
                            cudaDeviceProp.minor(deviceProp),
                            cudaDeviceProp.asyncEngineCount(deviceProp) > 0,
                            cudaDeviceProp.concurrentKernels(deviceProp) > 0));
                }
                // Shared across all callers via the static field, so it must not be mutable.
                return List.copyOf(gpus);
            }
        }
    }
}

