import asyncio
from typing import TYPE_CHECKING, Literal, Callable, Any
import psutil
import time
import os

import torch
from torch import Tensor
from torch.cuda.nvtx import range_push, range_pop

import kvikio
import kvikio.defaults

def get_memory_used():
    """Return the current process's memory usage as reported by psutil.

    The result is the sum of fields 0 and 9 of ``memory_full_info()``; the
    original note labels these "Res + Swap" (presumably the Linux field
    layout -- confirm on other platforms). No unit conversion is applied
    here, so the value is in whatever unit psutil reports.
    """
    info = psutil.Process().memory_full_info()
    resident, swapped = info[0], info[9]
    return resident + swapped

def get_cuda_memory_used():
    """Return ``torch.cuda.memory_allocated()`` for the current device.

    The previous version computed the value but never returned it, so callers
    always received ``None``.
    """
    return torch.cuda.memory_allocated()


class Task:
    def __init__(self, callbacks: list[Callable[[], Any]],
                 args: list[Any] | None = None):
        self.callbacks = callbacks
        self.args = args

    def wait(self):
        if self.args is not None:
            for callback, arg in zip(self.callbacks, self.args):
                if arg is not None:
                    callback(arg)
                else:
                    callback()
        else:
            for callback in self.callbacks:
                callback()

class device:
    """Lightweight device descriptor that only carries a type name."""

    def __init__(self, type: str):
        # Stored verbatim; no validation is performed on the name.
        self.type = type

    def __str__(self) -> str:
        """Return the stored type name."""
        kind = self.type
        return kind


# Byte size passed to ``untyped_storage().resize_()`` to drop a tensor's
# backing memory once its data has been moved elsewhere.
ZERO_SIZE = 0

class GDSTensor(): # GPU Direct Storage Tensor

    def __init__(self, tensor: Tensor, storage_path: str):
        assert not tensor.requires_grad
        self._tensor = tensor
        self._untyped_size = tensor.untyped_storage().size()
        self._file_path = storage_path + '/' + str(id(self))
        self._n_threads = 32

        self.at: Literal['storage', 'cpu', 'cuda'] = self._tensor.device.type
        self._task: Task | None = None
        self.ready: bool = True

    def clone(self):
        pass

    def to(self, *args, **kwargs):
        raise NotImplementedError()

    @property
    def device(self):
        return self.at

    def __repr__(self):
        return str(self)
    
    def __str__(self):
        if self.at != 'storage':
            s = "GDST" + repr(self._tensor)[1:]
        else:
            s = "GDSTensor(device='storage')"
        s += f"[ready={self.ready}]"
        return s

    # tensor delegation
    def __getattr__(self, name: str):
        return getattr(self._tensor, name)
    
    # def __getitem__(self) -> Tensor:
    #     return self._tensor # return tensor directly
    
    def tensor(self) -> Tensor:
        return self._tensor

    def synchronize(self):
        if self.ready:
            self._task = None
            return
        if self._task is not None:
            self._task.wait()
        self._task = None

    @torch.no_grad()
    def upload(self, target_tensor: Tensor, async_: bool = True, cuda_stream: torch.cuda.Stream | None = None) -> None:
        assert self.at == 'storage', "Upload is only allowed for storage tensor."
        assert self.ready
        assert self._tensor.shape == target_tensor.shape, 'The shape of the tensor should be the same.'

        if 'cpu' in str(target_tensor.device):
            raise NotImplementedError()
        elif 'cuda' in str(target_tensor.device):
            device = 'cuda'
        elif 'storage' in str(target_tensor.device):
            return # do nothing
        else:
            raise NotImplementedError()

        self.ready = False

        def _make_ready():
            self.ready = True

        task = None
        match self.at, device:
            case 'storage', 'cuda':
                if not os.path.exists(self._file_path):
                    assert False, "The file does not exist."
                f = kvikio.CuFile(self._file_path, 'r')
                if not async_:
                    f.read(target_tensor, task_size=self._tensor.nbytes//self._n_threads)
                    f.close()
                    _make_ready()
                else:
                    assert cuda_stream is not None, "cuda_stream must be provided."
                    future = f.raw_read_async(
                        buf=target_tensor.data,
                        size=target_tensor.nbytes,
                        file_offset=0,
                        stream=cuda_stream.cuda_stream
                    )
                    # future = f.pwrite(self._tensor, task_size=self._tensor.nbytes//self._n_threads)
                    task = Task([future.check_bytes_done, f.close,
                                    _make_ready])
            case _:
                assert False, "Not implemented."
        
        self._task = task

    @torch.no_grad()
    def overwrite(self, tensor: Tensor, async_: bool = True, cuda_stream: torch.cuda.Stream | None = None) -> None:
        assert self.at == 'storage', "Overwrite is only allowed for storage tensor."
        if tensor.requires_grad:
            tensor = tensor.detach()
        assert self.ready
        assert self._tensor.shape == tensor.shape, 'The shape of the tensor should be the same.'

        if 'cpu' in str(tensor.device):
            device = 'cpu'
        elif 'cuda' in str(tensor.device):
            device = 'cuda'
        elif 'storage' in str(tensor.device):
            device = 'storage'
        else:
            raise NotImplementedError()

        self.ready = False

        def _make_ready():
            self.ready = True
        def _resize_zero(): # resize zero for moving tensor to another device
            # TODO: In torch 2.3, untyped_storage has resizable() function. Replace this.
            # if self._tensor.untyped_storage().resizable():
            try:
                self._tensor.untyped_storage().resize_(ZERO_SIZE)
            except:
                # self._tensor = torch.empty((0,)*self._tensor.ndim, dtype=self._tensor.dtype, layout=self._tensor.layout, device=self._tensor.device)
                self._tensor = torch.empty_like(self._tensor)
                self._tensor.untyped_storage().resize_(ZERO_SIZE)

        def _del_tensor(tensor):
            del tensor

        task = None
        match device, self.at:
            case 'cpu', 'storage':
                if id(self._tensor) == id(tensor):
                    _make_ready()
                else:
                    def _cpu_to_storage():
                        del self._tensor
                        if os.path.exists(self._file_path):
                            os.remove(self._file_path)
                        else:
                            assert False, "The file does not exist."
                        tensor.numpy().tofile(self._file_path)
                        self._tensor = tensor
                        _resize_zero()
                        self.at = 'storage'
                        _make_ready()
                    loop = asyncio.get_event_loop()
                    future = loop.run_in_executor(None, _cpu_to_storage)
                    if not async_:
                        loop.run_until_complete(future)
                    else:
                        task = Task([lambda: loop.run_until_complete(future)])
            case 'cuda', 'storage':
                if id(self._tensor) == id(tensor):
                    _make_ready()
                else:
                    if os.path.exists(self._file_path):
                        os.remove(self._file_path)
                    else:
                        assert False, "The file does not exist."
                    f = kvikio.CuFile(self._file_path, 'w')
                    if not async_:
                        f.write(tensor, task_size=self._tensor.nbytes//self._n_threads)
                        f.close()
                        del tensor
                        _resize_zero()
                        self.at = 'storage'
                        _make_ready()
                    else:
                        assert cuda_stream is not None, "cuda_stream must be provided."
                        future = f.raw_write_async(
                            buf=tensor,
                            size=tensor.nbytes,
                            file_offset=0,
                            dev_offset=0,
                            stream=cuda_stream.cuda_stream,
                        )
                        # future = f.pwrite(self._tensor, task_size=self._tensor.nbytes//self._n_threads)
                        task = Task([future.check_bytes_done, f.close,
                                     _resize_zero, _del_tensor, _make_ready],
                                     [None, None, None, tensor, None]
                                     )
                        # # future = f.pwrite(self._tensor, task_size=self._tensor.nbytes//self._n_threads)
                        # task = Task([future.check_bytes_done, f.close,
                        #              _resize_zero, _make_ready],
                        #              [None, None, None, None]
                        #              )
            case _:
                assert False, "Not implemented."
        
        self._task = task
        self.at = 'storage' # must be storage (overwrite)

    @torch.no_grad()
    def to_duplicate(self, device: Literal['storage']) -> None:
        # we keep the duplicate data only on storage!
        assert device == 'storage', "To duplicate, the device should be 'storage'."
        if self.at == device:
            return

        assert not self._tensor.requires_grad
        assert self.ready

        self.ready = False
        def _make_ready():
            self.ready = True

        def _resize_zero(): # resize zero for moving tensor to another device
            # TODO: In torch 2.3, untyped_storage has resizable() function. Replace this.
            # if self._tensor.untyped_storage().resizable():
            try:
                self._tensor.untyped_storage().resize_(ZERO_SIZE)
            except:
                # self._tensor = torch.empty((0,)*self._tensor.ndim, dtype=self._tensor.dtype, layout=self._tensor.layout, device=self._tensor.device)
                self._tensor = torch.empty_like(self._tensor)
                self._tensor.untyped_storage().resize_(ZERO_SIZE)

        def _resize_back():
            self._tensor.untyped_storage().resize_(self._untyped_size)

        match self.at, device:
            case 'cpu', 'storage':
                if os.path.exists(self._file_path):
                    _resize_zero()
                    _make_ready()
                else:
                    assert False, "The file does not exist."
            case 'cuda', 'storage':
                if os.path.exists(self._file_path):
                    _resize_zero()
                    _make_ready()
                else:
                    assert False, "The file does not exist."
            case _:
                assert False, "Not implemented."
        
        self.at = 'storage'
        
    @torch.no_grad()
    def to_inplace(self, device: Literal['storage', 'cpu', 'cuda'], async_: bool = True) -> Task | None:

        if self.at == device:
            return

        # We do not support gradient for GDS tensor
        assert not self._tensor.requires_grad
        assert self.ready
        
        self.ready = False
        def _make_ready():
            self.ready = True
        
        def _resize_zero(): # resize zero for moving tensor to another device
            # TODO: In torch 2.3, untyped_storage has resizable() function. Replace this.
            # if self._tensor.untyped_storage().resizable():
            try:
                self._tensor.untyped_storage().resize_(ZERO_SIZE)
            except:
                # self._tensor = torch.empty((0,)*self._tensor.ndim, dtype=self._tensor.dtype, layout=self._tensor.layout, device=self._tensor.device)
                self._tensor = torch.empty_like(self._tensor)
                self._tensor.untyped_storage().resize_(ZERO_SIZE)

        def _resize_back():
            self._tensor.untyped_storage().resize_(self._untyped_size)
        
        task = None
        match self.at, device:
            case 'storage', 'cpu': # storage -> cpu
                def _storage_to_cpu():
                    dtype = self._tensor.dtype
                    self._tensor = torch.from_file(self._file_path, size=self._untyped_size // dtype.itemsize, dtype=dtype).reshape_as(self._tensor)
                    self.at = 'cpu'
                    _make_ready()
                loop = asyncio.get_event_loop() 
                future = loop.run_in_executor(None, _storage_to_cpu)
                if not async_:
                    loop.run_until_complete(future)
                else:
                    task = Task([lambda: loop.run_until_complete(future)])
            case 'cpu', 'storage': # cpu -> storage
                def cpu_to_storage():
                    if os.path.exists(self._file_path):
                        os.remove(self._file_path)
                    self._tensor.numpy().tofile(self._file_path)
                    _resize_zero()
                    self.at = 'storage'
                    _make_ready()
                loop = asyncio.get_event_loop()
                future = loop.run_in_executor(None, cpu_to_storage)
                if not async_:
                    loop.run_until_complete(future)
                else:
                    task = Task([lambda: loop.run_until_complete(future)])
            case 'storage', 'cuda':
                f = kvikio.CuFile(self._file_path, 'r')
                # self._tensor = self._tensor.cuda()  # this exits the program???
                _resize_back()
                self._tensor = self._tensor.cuda()
                if not async_:
                    f.read(self._tensor, task_size=self._tensor.nbytes//self._n_threads)
                    f.close()
                    self.at = 'cuda'
                    _make_ready()
                else:
                    future = f.pread(self._tensor, task_size=self._tensor.nbytes//self._n_threads)
                    task = Task([future.get, f.close, _make_ready])
            case 'cuda', 'storage':
                f = kvikio.CuFile(self._file_path, 'w')
                if not async_:
                    f.write(self._tensor, task_size=self._tensor.nbytes//self._n_threads)
                    f.close()
                    _resize_zero()
                    self.at = 'storage'
                    _make_ready()
                else:
                    future = f.pwrite(self._tensor, task_size=self._tensor.nbytes//self._n_threads)
                    task = Task([future.get, f.close, _resize_zero, _make_ready])
            case _:
                if not async_:
                    self._tensor = self._tensor.to(device)
                    _make_ready()
                else:
                    self._tensor = self._tensor.to(device, non_blocking=True)
                    stream = torch.cuda.current_stream()
                    task = Task([stream.synchronize, _make_ready])

        self._task = task
        self.at = device
        # return task


# Test
if __name__ == '__main__':
    # Manual benchmark: moves one large tensor between CPU, GPU and storage
    # and prints the elapsed time and memory usage after each hop.

    # set kvikio-related settings
    kvikio.defaults.num_threads_reset(32)  # n_threads 32
    # NOTE(review): the original comment on this line read "must use GDS", but
    # compat mode is, as far as I know, kvikio's POSIX fallback (i.e. *not*
    # GDS) -- confirm which behavior is intended.
    kvikio.defaults.compat_mode_reset(True)

    def get_memory_used_gb():
        # The module-level get_memory_used() is not converted to GB; it used
        # to be shadowed here by a variant that already divided by 1e9, so
        # every printed CPU figure was divided twice.
        return get_memory_used() / 1e9

    a = torch.rand((10, 1000, 1000, 250))  # 10 GB
    print(f"CPU: {get_memory_used_gb():.3f} GB")

    n_bytes = a.numel() * a.element_size()
    # set_task_size() is a context manager and has no effect when merely
    # called; the *_reset variant applies the value, like num_threads_reset
    # above.
    kvikio.defaults.task_size_reset(n_bytes // 32)

    print(f"n_bytes: {n_bytes}")

    gd_a = GDSTensor(a, '/mnt/fast_nvme/gdstensor/')
    del a

    def report(separator=True):
        # Print the current location and CPU / GPU memory usage.
        print(f"Device: {gd_a.device}")
        print(f"CPU: {get_memory_used_gb():.3f} GB")
        print(f"GPU: {torch.cuda.memory_allocated() / 1e9:.3f} GB")
        if separator:
            print("---------------------")

    def move(target):
        # Start the transfer and block until it has finished.
        gd_a.to_inplace(target)
        gd_a.synchronize()

    print("GDS initialized (CPU)")
    report()

    move('cuda')
    report(separator=False)

    # Warm-up round trips before timing anything.
    for target in ('cpu', 'cuda', 'cpu', 'cuda', 'cpu'):
        move(target)

    # (label, target); the trailing figures are the timings in seconds that
    # were noted next to the original print statements.
    steps = [
        ('CPU->GPU', 'cuda'),         # 1.523
        ('GPU->CPU', 'cpu'),          # 3.200
        ('CPU->Storage', 'storage'),  # 3.094
        ('Storage->GPU', 'cuda'),     # 2.002
        ('GPU->Storage', 'storage'),  # 7.862
        ('Storage->CPU', 'cpu'),      # 0.0002
        ('CPU->GPU', 'cuda'),         # 0.875
    ]
    for label, target in steps:
        start = time.perf_counter()
        range_push(label)
        move(target)
        range_pop()
        end = time.perf_counter()
        if label == 'GPU->CPU':
            # Return cached blocks so the GPU figure reflects live tensors.
            torch.cuda.empty_cache()
        print(f"({label}) Time: {end - start} s")
        report()
