/*
 * Decompiled with CFR 0.152.
 */
package com.nvidia.cuvs.internal;

import com.nvidia.cuvs.CuVSDeviceMatrix;
import com.nvidia.cuvs.CuVSMatrix;
import com.nvidia.cuvs.CuVSResources;
import com.nvidia.cuvs.internal.CuVSDeviceMatrixImpl;
import com.nvidia.cuvs.internal.common.CloseableRMMAllocation;
import java.lang.foreign.ValueLayout;

/**
 * A {@link CuVSDeviceMatrix} backed by a {@link CloseableRMMAllocation}.
 *
 * <p>Instances are obtained through the static {@code create} factories. Each matrix takes
 * ownership of the allocation it is built on and closes it in {@link #close()}.
 */
public final class CuVSDeviceMatrixRMMImpl
extends CuVSDeviceMatrixImpl
implements CuVSDeviceMatrix {
    /** Backing allocation; owned by this matrix and released in {@link #close()}. */
    private final CloseableRMMAllocation rmmAllocation;

    private CuVSDeviceMatrixRMMImpl(CuVSResources resources, CloseableRMMAllocation rmmAllocation, long size, long columns, CuVSMatrix.DataType dataType, ValueLayout valueLayout) {
        super(resources, rmmAllocation.handle(), size, columns, dataType, valueLayout);
        this.rmmAllocation = rmmAllocation;
    }

    private CuVSDeviceMatrixRMMImpl(CuVSResources resources, CloseableRMMAllocation rmmAllocation, long size, long columns, long rowStride, long columnStride, CuVSMatrix.DataType dataType, ValueLayout valueLayout) {
        super(resources, rmmAllocation.handle(), size, columns, rowStride, columnStride, dataType, valueLayout);
        this.rmmAllocation = rmmAllocation;
    }

    /**
     * Allocates a matrix of {@code size} rows and {@code columns} columns with no row padding.
     *
     * <p>The allocation is {@code size * columns * elementByteSize} bytes. If constructing the
     * matrix fails after the allocation succeeds, the allocation is released before the
     * exception propagates.
     *
     * @param resources the resources used to perform the allocation
     * @param size      number of rows
     * @param columns   number of columns
     * @param dataType  element type; determines the element byte size
     * @return the new matrix, which owns its allocation
     * @throws IllegalArgumentException if {@code size} or {@code columns} is negative
     * @throws ArithmeticException      if the required byte size overflows a {@code long}
     */
    public static CuVSDeviceMatrixImpl create(CuVSResources resources, long size, long columns, CuVSMatrix.DataType dataType) {
        requireNonNegativeDimensions(size, columns);
        ValueLayout valueLayout = CuVSDeviceMatrixRMMImpl.valueLayoutFromType(dataType);
        long byteSize = Math.multiplyExact(Math.multiplyExact(size, columns), valueLayout.byteSize());
        try (CuVSResources.ScopedAccess resourcesAccess = resources.access()) {
            CloseableRMMAllocation rmmAllocation = CloseableRMMAllocation.allocateRMMSegment(resourcesAccess.handle(), byteSize);
            try {
                return new CuVSDeviceMatrixRMMImpl(resources, rmmAllocation, size, columns, dataType, valueLayout);
            } catch (RuntimeException | Error e) {
                // The matrix never took ownership of the allocation, so release it here.
                releaseAfterFailure(rmmAllocation, e);
                throw e;
            }
        }
    }

    /**
     * Allocates a matrix of {@code size} rows and {@code columns} columns with an optional row
     * stride.
     *
     * <p>If {@code rowStride <= 0}, each row occupies {@code columns} elements. Otherwise each
     * row occupies {@code rowStride} elements, and {@code rowStride} must be at least
     * {@code columns}. Strided columns are not supported: {@code columnStride} must be
     * {@code -1}. If constructing the matrix fails after the allocation succeeds, the
     * allocation is released before the exception propagates.
     *
     * @param resources    the resources used to perform the allocation
     * @param size         number of rows
     * @param columns      number of columns
     * @param rowStride    elements per row, or a non-positive value for tightly packed rows
     * @param columnStride must be {@code -1}
     * @param dataType     element type; determines the element byte size
     * @return the new matrix, which owns its allocation
     * @throws IllegalArgumentException      if {@code size} or {@code columns} is negative, or if
     *                                       a positive {@code rowStride} is less than {@code columns}
     * @throws UnsupportedOperationException if {@code columnStride != -1}
     * @throws ArithmeticException           if the required byte size overflows a {@code long}
     */
    public static CuVSDeviceMatrixImpl create(CuVSResources resources, long size, long columns, long rowStride, long columnStride, CuVSMatrix.DataType dataType) {
        requireNonNegativeDimensions(size, columns);
        ValueLayout valueLayout = CuVSDeviceMatrixRMMImpl.valueLayoutFromType(dataType);
        long elementSize = valueLayout.byteSize();
        long rowSize;
        if (rowStride <= 0L) {
            // No explicit stride: allocate tightly packed rows.
            rowSize = Math.multiplyExact(columns, elementSize);
        } else if (rowStride >= columns) {
            rowSize = Math.multiplyExact(rowStride, elementSize);
        } else {
            throw new IllegalArgumentException("Row stride cannot be less than the number of columns");
        }
        if (columnStride != -1L) {
            throw new UnsupportedOperationException("Stridden columns are currently not supported; columnStride must be equal to -1");
        }
        long byteSize = Math.multiplyExact(size, rowSize);
        try (CuVSResources.ScopedAccess resourcesAccess = resources.access()) {
            CloseableRMMAllocation rmmAllocation = CloseableRMMAllocation.allocateRMMSegment(resourcesAccess.handle(), byteSize);
            try {
                return new CuVSDeviceMatrixRMMImpl(resources, rmmAllocation, size, columns, rowStride, columnStride, dataType, valueLayout);
            } catch (RuntimeException | Error e) {
                // The matrix never took ownership of the allocation, so release it here.
                releaseAfterFailure(rmmAllocation, e);
                throw e;
            }
        }
    }

    /**
     * Closes this matrix and then releases its RMM allocation.
     *
     * <p>The allocation is released even if {@code super.close()} throws.
     */
    @Override
    public void close() {
        try {
            super.close();
        } finally {
            this.rmmAllocation.close();
        }
    }

    /** Rejects negative matrix dimensions before any allocation is attempted. */
    private static void requireNonNegativeDimensions(long size, long columns) {
        if (size < 0L || columns < 0L) {
            throw new IllegalArgumentException("Matrix dimensions must be non-negative: size=" + size + ", columns=" + columns);
        }
    }

    /** Closes an allocation whose matrix failed to construct, keeping the original failure primary. */
    private static void releaseAfterFailure(CloseableRMMAllocation allocation, Throwable failure) {
        try {
            allocation.close();
        } catch (RuntimeException | Error closeFailure) {
            failure.addSuppressed(closeFailure);
        }
    }
}

