/*
 * Decompiled with CFR 0.152.
 */
package com.nvidia.cuvs.internal;

import com.nvidia.cuvs.CuVSDeviceMatrix;
import com.nvidia.cuvs.CuVSHostMatrix;
import com.nvidia.cuvs.CuVSMatrix;
import com.nvidia.cuvs.CuVSResources;
import com.nvidia.cuvs.RowView;
import com.nvidia.cuvs.internal.CuVSMatrixBaseImpl;
import com.nvidia.cuvs.internal.CuVSMatrixInternal;
import com.nvidia.cuvs.internal.SliceRowView;
import com.nvidia.cuvs.internal.common.PinnedMemoryBuffer;
import com.nvidia.cuvs.internal.common.Util;
import com.nvidia.cuvs.internal.panama.DLManagedTensor;
import com.nvidia.cuvs.internal.panama.DLTensor;
import com.nvidia.cuvs.internal.panama.headers_h;
import java.lang.foreign.Arena;
import java.lang.foreign.MemoryLayout;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.SequenceLayout;
import java.lang.foreign.ValueLayout;

/**
 * {@link CuVSDeviceMatrix} implementation whose data segment is described to native code as a
 * {@code kDLCUDA} tensor (see {@link #toTensor(Arena)}).
 *
 * <p>Row access through {@link #getRow(long)} is served from a {@link PinnedMemoryBuffer}: a
 * contiguous window of rows is copied from the matrix into that buffer on demand, and reads that
 * fall inside the current window are answered from the buffer without another copy.
 *
 * <p>The window bounds are plain mutable fields and no synchronization is performed in this class.
 */
public class CuVSDeviceMatrixImpl
extends CuVSMatrixBaseImpl
implements CuVSDeviceMatrix {
    /** First row (inclusive) currently held in {@link #hostBuffer}. */
    private long bufferedMatrixRowStart = 0L;
    /** Row after the last one (exclusive) currently held in {@link #hostBuffer}; equal to start when empty. */
    private long bufferedMatrixRowEnd = 0L;
    private final CuVSResources resources;
    /** Row stride passed to the tensor descriptor; a negative value means "no explicit strides". */
    private final long rowStride;
    /** Column stride passed to the tensor descriptor; only used when {@link #rowStride} is non-negative. */
    private final long columnStride;
    /** Host-side window used by {@link #getRow(long)}; released in {@link #close()}. */
    private final PinnedMemoryBuffer hostBuffer;

    /**
     * Creates a matrix without explicit strides (both strides are recorded as {@code -1}).
     *
     * @param resources resources used for copies issued by this matrix
     * @param deviceMemorySegment segment holding the matrix data
     * @param size number of rows
     * @param columns number of columns
     * @param dataType element data type
     * @param valueLayout layout of a single element
     */
    protected CuVSDeviceMatrixImpl(CuVSResources resources, MemorySegment deviceMemorySegment, long size, long columns, CuVSMatrix.DataType dataType, ValueLayout valueLayout) {
        this(resources, deviceMemorySegment, size, columns, -1L, -1L, dataType, valueLayout);
    }

    /**
     * Creates a matrix with the given strides and allocates the host-side row buffer.
     *
     * @param resources resources used for copies issued by this matrix
     * @param deviceMemorySegment segment holding the matrix data
     * @param size number of rows
     * @param columns number of columns
     * @param rowStride row stride, or a negative value for "no explicit strides"
     * @param columnStride column stride; ignored when {@code rowStride} is negative
     * @param dataType element data type
     * @param valueLayout layout of a single element
     */
    protected CuVSDeviceMatrixImpl(CuVSResources resources, MemorySegment deviceMemorySegment, long size, long columns, long rowStride, long columnStride, CuVSMatrix.DataType dataType, ValueLayout valueLayout) {
        super(deviceMemorySegment, dataType, valueLayout, size, columns);
        this.resources = resources;
        this.rowStride = rowStride;
        this.columnStride = columnStride;
        this.hostBuffer = new PinnedMemoryBuffer(size, columns, valueLayout);
    }

    /** Returns the row stride given at construction ({@code -1} when none was supplied). */
    @Override
    public long rowStride() {
        return this.rowStride;
    }

    /**
     * Builds a tensor descriptor for the whole matrix, tagged as {@code kDLCUDA}.
     *
     * @param arena arena in which the descriptor is allocated
     * @return the segment produced by {@code Util.prepareTensor}
     */
    @Override
    public MemorySegment toTensor(Arena arena) {
        // A negative row stride means no strides were supplied; pass null in that case.
        long[] strides = this.rowStride >= 0L ? new long[]{this.rowStride, this.columnStride} : null;
        long[] shape = new long[]{this.size, this.columns};
        return Util.prepareTensor(arena, this.memorySegment, shape, strides, this.code(), this.bits(), headers_h.kDLCUDA());
    }

    /**
     * Refills the host buffer with as many rows as fit, starting at {@code startRow}, and records
     * the new window bounds once the copy has been synchronized.
     *
     * @param startRow first row to place in the buffer
     */
    private void populateBuffer(long startRow) {
        try (Arena localArena = Arena.ofConfined()) {
            long rowBytes = this.columns * this.valueLayout.byteSize();
            // NOTE(review): assumes the buffer can hold at least one row (otherwise the window is
            // empty) and that rowBytes is non-zero -- TODO confirm against PinnedMemoryBuffer.
            long rowsThatFit = this.hostBuffer.size() / rowBytes;
            long endRow = Math.min(startRow + rowsThatFit, this.size);
            long rowCount = endRow - startRow;

            // Describe rows [startRow, endRow) of this matrix as a separate tensor.
            MemorySegment sliceManagedTensor = DLManagedTensor.allocate(localArena);
            DLManagedTensor.dl_tensor(sliceManagedTensor, DLTensor.allocate(localArena));
            Util.checkCuVSError(headers_h.cuvsMatrixSliceRows(0L, this.toTensor(localArena), startRow, endRow, sliceManagedTensor), "cuvsMatrixSliceRows");
            assert (DLTensor.shape(DLManagedTensor.dl_tensor(sliceManagedTensor)).get(headers_h.C_LONG, 0L) == rowCount);
            assert (DLTensor.shape(DLManagedTensor.dl_tensor(sliceManagedTensor)).getAtIndex(headers_h.C_LONG, 1L) == this.columns);

            // Destination: the host buffer, described as a CPU tensor of the same shape.
            MemorySegment bufferTensor = Util.prepareTensor(localArena, this.hostBuffer.address(), new long[]{rowCount, this.columns}, this.code(), this.bits(), headers_h.kDLCPU());
            try (CuVSResources.ScopedAccess resourceAccess = this.resources.access()) {
                Util.checkCuVSError(headers_h.cuvsMatrixCopy(resourceAccess.handle(), sliceManagedTensor, bufferTensor), "cuvsMatrixCopy");
                Util.checkCuVSError(headers_h.cuvsStreamSync(resourceAccess.handle()), "cuvsStreamSync");
                // Only publish the new window after the copy has been synchronized.
                this.bufferedMatrixRowStart = startRow;
                this.bufferedMatrixRowEnd = endRow;
            }
        }
    }

    /**
     * Returns a view of one row, refilling the host buffer first when the row is outside the
     * currently buffered window.
     *
     * <p>The returned view is a slice of the shared host buffer, so a later call that refills the
     * buffer overwrites the memory an earlier view points at.
     *
     * @param row zero-based row index
     * @return a {@link SliceRowView} over the buffered copy of the row
     * @throws IndexOutOfBoundsException if {@code row} is negative or not less than the row count
     */
    @Override
    public RowView getRow(long row) {
        if (row < 0L || row >= this.size) {
            throw new IndexOutOfBoundsException("Row index " + row + " out of bounds for matrix with " + this.size + " rows");
        }
        boolean buffered = row >= this.bufferedMatrixRowStart && row < this.bufferedMatrixRowEnd;
        if (!buffered) {
            this.populateBuffer(row);
        }
        long valueByteSize = this.valueLayout.byteSize();
        long rowBytes = this.columns * valueByteSize;
        long rowInBuffer = row - this.bufferedMatrixRowStart;
        MemorySegment rowSlice = this.hostBuffer.address().asSlice(rowInBuffer * rowBytes, rowBytes);
        return new SliceRowView(rowSlice, this.columns, this.valueLayout, this.dataType, valueByteSize);
    }

    /**
     * Copies every row into {@code array}. Preconditions are checked with {@code assert} only.
     *
     * @param array destination with at least {@code size} rows of at least {@code columns} ints
     */
    @Override
    public void toArray(int[][] array) {
        assert (this.dataType == CuVSMatrix.DataType.INT || this.dataType == CuVSMatrix.DataType.UINT);
        assert ((long)array.length >= this.size) : "Input array is not large enough";
        assert (array.length == 0 || (long)array[0].length >= this.columns) : "Input array is not large enough";
        this.copyRowsInto(array);
    }

    /**
     * Copies every row into {@code array}. Preconditions are checked with {@code assert} only.
     *
     * @param array destination with at least {@code size} rows of at least {@code columns} floats
     */
    @Override
    public void toArray(float[][] array) {
        assert (this.dataType == CuVSMatrix.DataType.FLOAT);
        assert ((long)array.length >= this.size) : "Input array is not large enough";
        assert (array.length == 0 || (long)array[0].length >= this.columns) : "Input array is not large enough";
        this.copyRowsInto(array);
    }

    /**
     * Copies every row into {@code array}. Preconditions are checked with {@code assert} only.
     *
     * @param array destination with at least {@code size} rows of at least {@code columns} bytes
     */
    @Override
    public void toArray(byte[][] array) {
        assert (this.dataType == CuVSMatrix.DataType.BYTE);
        assert ((long)array.length >= this.size) : "Input array is not large enough";
        assert (array.length == 0 || (long)array[0].length >= this.columns) : "Input array is not large enough";
        this.copyRowsInto(array);
    }

    /**
     * Shared body of the {@code toArray} overloads: copies rows {@code 0..size-1} one at a time
     * through a single scratch segment.
     *
     * @param rows destination; each element must be a primitive array matching {@code valueLayout}
     */
    private void copyRowsInto(Object[] rows) {
        try (Arena localArena = Arena.ofConfined()) {
            // Allocating from the layout (rather than a raw byte count) ties the scratch segment's
            // alignment to the element layout that MemorySegment.copy is later called with.
            MemorySegment rowScratch = localArena.allocate(MemoryLayout.sequenceLayout(this.columns, this.valueLayout));
            for (int r = 0; (long)r < this.size; ++r) {
                this.copyRow(rows[r], localArena, r, rowScratch);
            }
        }
    }

    /**
     * Copies a single row into {@code tmpRowSegment} and from there into the Java array.
     *
     * @param array destination primitive array (passed as {@code Object} so one helper serves all element types)
     * @param localArena arena for the temporary tensor descriptors
     * @param r row index to copy
     * @param tmpRowSegment host scratch segment large enough for one row
     */
    private void copyRow(Object array, Arena localArena, int r, MemorySegment tmpRowSegment) {
        MemorySegment sliceManagedTensor = DLManagedTensor.allocate(localArena);
        DLManagedTensor.dl_tensor(sliceManagedTensor, DLTensor.allocate(localArena));
        Util.checkCuVSError(headers_h.cuvsMatrixSliceRows(0L, this.toTensor(localArena), r, r + 1, sliceManagedTensor), "cuvsMatrixSliceRows");
        // The scratch segment is allocated from a Java arena, so it is described as a CPU tensor,
        // matching how populateBuffer describes its host-side destination.
        MemorySegment bufferTensor = Util.prepareTensor(localArena, tmpRowSegment, new long[]{1L, this.columns}, this.code(), this.bits(), headers_h.kDLCPU());
        try (CuVSResources.ScopedAccess resourceAccess = this.resources.access()) {
            Util.checkCuVSError(headers_h.cuvsMatrixCopy(resourceAccess.handle(), sliceManagedTensor, bufferTensor), "cuvsMatrixCopy");
            Util.checkCuVSError(headers_h.cuvsStreamSync(resourceAccess.handle()), "cuvsStreamSync");
        }
        MemorySegment.copy(tmpRowSegment, this.valueLayout, 0L, array, 0, (int)this.columns);
    }

    /**
     * Copies this matrix into {@code targetMatrix} using the resources this matrix was created with.
     *
     * @param targetMatrix destination; must also implement {@link CuVSMatrixInternal}
     */
    @Override
    public void toHost(CuVSHostMatrix targetMatrix) {
        CuVSDeviceMatrixImpl.copyMatrix(this, (CuVSMatrixInternal)((Object)targetMatrix), this.resources);
    }

    /**
     * Returns a delegate that forwards to this matrix instead of copying it. Closing the delegate
     * does nothing, so this matrix remains usable and must still be closed by its owner.
     *
     * @param resources not used by this implementation
     * @return a non-owning wrapper around this matrix
     */
    @Override
    public CuVSDeviceMatrix toDevice(CuVSResources resources) {
        return new CuVSDeviceMatrixDelegate(this);
    }

    /**
     * Copies this matrix into {@code targetMatrix} using the supplied resources.
     *
     * @param targetMatrix destination; must also implement {@link CuVSMatrixInternal}
     * @param cuVSResources resources used for the copy
     */
    @Override
    public void toDevice(CuVSDeviceMatrix targetMatrix, CuVSResources cuVSResources) {
        CuVSDeviceMatrixImpl.copyMatrix(this, (CuVSMatrixInternal)((Object)targetMatrix), cuVSResources);
    }

    /** Closes the host-side row buffer. */
    @Override
    public void close() {
        this.hostBuffer.close();
    }

    /**
     * Non-owning wrapper returned by {@link CuVSDeviceMatrixImpl#toDevice(CuVSResources)}: every
     * operation is forwarded to the wrapped matrix, and {@link #close()} is a no-op.
     */
    private static class CuVSDeviceMatrixDelegate
    implements CuVSDeviceMatrix,
    CuVSMatrixInternal {
        private final CuVSDeviceMatrixImpl deviceMatrix;

        private CuVSDeviceMatrixDelegate(CuVSDeviceMatrixImpl deviceMatrix) {
            this.deviceMatrix = deviceMatrix;
        }

        @Override
        public long size() {
            return this.deviceMatrix.size();
        }

        @Override
        public long columns() {
            return this.deviceMatrix.columns();
        }

        @Override
        public CuVSMatrix.DataType dataType() {
            return this.deviceMatrix.dataType();
        }

        @Override
        public RowView getRow(long row) {
            return this.deviceMatrix.getRow(row);
        }

        @Override
        public void toArray(int[][] array) {
            this.deviceMatrix.toArray(array);
        }

        @Override
        public void toArray(float[][] array) {
            this.deviceMatrix.toArray(array);
        }

        @Override
        public void toArray(byte[][] array) {
            this.deviceMatrix.toArray(array);
        }

        @Override
        public void toHost(CuVSHostMatrix hostMatrix) {
            this.deviceMatrix.toHost(hostMatrix);
        }

        /**
         * Copies the wrapped matrix into the given target.
         *
         * @param deviceMatrix destination matrix
         * @param cuVSResources resources used for the copy
         */
        @Override
        public void toDevice(CuVSDeviceMatrix deviceMatrix, CuVSResources cuVSResources) {
            // The parameter shadows the field: "this." is required so the wrapped matrix is the
            // source, rather than asking the target to copy onto itself.
            this.deviceMatrix.toDevice(deviceMatrix, cuVSResources);
        }

        /** Already a delegate over a device matrix, so it returns itself. */
        @Override
        public CuVSDeviceMatrix toDevice(CuVSResources cuVSResources) {
            return this;
        }

        @Override
        public MemorySegment memorySegment() {
            return this.deviceMatrix.memorySegment();
        }

        @Override
        public ValueLayout valueLayout() {
            return this.deviceMatrix.valueLayout();
        }

        @Override
        public long rowStride() {
            return this.deviceMatrix.rowStride();
        }

        @Override
        public MemorySegment toTensor(Arena arena) {
            return this.deviceMatrix.toTensor(arena);
        }

        /** Intentionally empty: the delegate does not own the wrapped matrix. */
        @Override
        public void close() {
        }
    }
}

