import torch
from torch.utils.cpp_extension import load

# JIT-compile and load the GDS extension built from csrc/gds_kernels.cu.
# `-lcufile` is a linker flag, so it belongs in `extra_ldflags`; flags in
# `extra_cflags` are only used when compiling sources, never when linking,
# which would leave the extension without libcufile's symbols.
gds_ops = load(
    name='gds_ops',
    sources=['csrc/gds_kernels.cu'],
    extra_ldflags=['-lcufile'],
    verbose=True,
)

def driver_open():
    """Open the GDS driver via the compiled ``gds_ops`` extension.

    Delegates to ``gds_ops.gds_driver_open`` and returns None.

    NOTE(review): presumably this wraps cuFileDriverOpen inside
    csrc/gds_kernels.cu -- confirm against the extension source.
    """
    open_driver = gds_ops.gds_driver_open
    open_driver()

class GDSBuffer:
    """A CUDA byte buffer registered with the GDS extension.

    Allocates ``size`` bytes of uninitialized ``torch.uint8`` storage on the
    CUDA device and registers its device pointer through
    ``gds_ops.gds_buf_register``. Release the registration with
    :meth:`close`, by using the object as a context manager, or (as a
    fallback) when the object is garbage-collected.

    NOTE(review): presumably these calls wrap cuFileBufRegister /
    cuFileBufDeregister in csrc/gds_kernels.cu -- confirm.
    """

    def __init__(self, size: int):
        """Allocate ``size`` bytes on the CUDA device and register them.

        Args:
            size: Number of bytes to allocate and register.
        """
        # Set before anything that can raise so close()/__del__ are safe on a
        # partially constructed object.
        self._registered = False
        # Held on the instance so the device memory stays alive for as long
        # as the registration does.
        self.ptr = torch.empty(size, dtype=torch.uint8, device='cuda')
        self.handle = gds_ops.gds_buf_register(self.ptr.data_ptr(), size)
        self._registered = True

    def close(self) -> None:
        """Deregister the buffer; calling it again is a no-op."""
        # getattr: instances created without __init__ (or whose __init__
        # failed very early) have no flag at all.
        if not getattr(self, '_registered', False):
            return
        # Clear the flag first so a failing deregister is not retried from
        # __del__, which would risk deregistering the same handle twice.
        self._registered = False
        gds_ops.gds_buf_deregister(self.handle)

    def __enter__(self) -> "GDSBuffer":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def __del__(self):
        self.close()
