import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Literal, Callable, Any

import psutil
import torch
from torch import Tensor

import kvikio

def get_memory_used():
    """Return this process's resident set size plus swapped-out memory, in bytes.

    Reads ``psutil.Process().memory_full_info()``. Only some platforms report a
    ``swap`` field there; where it is absent it is treated as 0.
    """
    mem_info = psutil.Process().memory_full_info()
    # Use named fields: the positional layout of this namedtuple differs per
    # platform (index 9 is `swap` only on Linux; on macOS it does not exist).
    return mem_info.rss + getattr(mem_info, 'swap', 0)

def get_cuda_memory_used(device=None):
    """Return the number of bytes currently allocated by tensors on a CUDA device.

    Thin wrapper over ``torch.cuda.memory_allocated``; ``device=None`` selects
    the current CUDA device.
    """
    # The original body computed this value but never returned it.
    return torch.cuda.memory_allocated(device)


class Task:
    """Handle for pending work that completes once its callbacks have run.

    ``callbacks`` is a list of zero-argument callables. :meth:`wait` invokes
    them in list order and discards their return values.
    """

    def __init__(self, callbacks: list[Callable[[], Any]]):
        self.callbacks = callbacks

    def wait(self):
        """Run every stored callback, first to last, on the calling thread.

        Exceptions are not caught: a failing callback stops the sequence and
        propagates to the caller.
        """
        for step in iter(self.callbacks):
            step()


# Size (in bytes, as passed to UntypedStorage.resize_) that GDTensor shrinks a
# tensor's storage to once its data has been written out to disk.
ZERO_SIZE = 0

class GDTensor:

    def __init__(self, tensor: Tensor, storage_path: str):
        assert not tensor.requires_grad
        self._tensor = tensor
        self._untyped_size = tensor.untyped_storage().size()
        self._file_path = storage_path + '/' + str(id(self))

        self.at: Literal['storage', 'cpu', 'cuda'] = self._tensor.device.type
        self._task: Task | None = None
        self.ready: bool = True

    def clone(self):
        pass

    def to(self, *args, **kwargs):
        raise NotImplementedError()

    def __repr__(self):
        return str(self)
    
    def __str__(self):
        if self.at != 'storage':
            s = "GDT" + repr(self._tensor)[1:]
        else:
            s = "GDTensor(device='storage')"
        s += f"[ready={self.ready}]"
        return s

    # tensor delegation
    def __getattr__(self, name: str):
        return getattr(self._tensor, name)
    
    def synchronize(self):
        if self.ready:
            self._task = None
            return
        if self._task is not None:
            self._task.wait()
        self._task = None

    def to_inplace(self, device: Literal['storage', 'cpu', 'cuda'], async_: bool = True) -> Task | None:
        if self.at == device:
            return
        assert not self._tensor.requires_grad
        assert self.ready
        
        self.ready = False
        def make_ready():
            self.ready = True
        
        def resize_zero():
            # TODO: In torch 2.3, untyped_storage has resizable() function. Replace this.
            # if self._tensor.untyped_storage().resizable():
            try:
                self._tensor.untyped_storage().resize_(ZERO_SIZE)
            except:
                # self._tensor = torch.empty((0,)*self._tensor.ndim, dtype=self._tensor.dtype, layout=self._tensor.layout, device=self._tensor.device)
                self._tensor = torch.empty_like(self._tensor)
                self._tensor.untyped_storage().resize_(ZERO_SIZE)

        def resize_back():
            self._tensor.untyped_storage().resize_(self._untyped_size)
        
        task = None
        match self.at, device:
            case 'storage', 'cpu':
                def storage_to_cpu():
                    dtype = self._tensor.dtype
                    self._tensor = torch.from_file(self._file_path, size=self._untyped_size // dtype.itemsize, dtype=dtype).reshape_as(self._tensor)
                    make_ready()
                loop = asyncio.get_event_loop()
                future = loop.run_in_executor(None, storage_to_cpu)
                if not async_:
                    loop.run_until_complete(future)
                else:
                    task = Task([lambda: loop.run_until_complete(future)])

            case 'cpu', 'storage':
                def cpu_to_storage():
                    self._tensor.numpy().tofile(self._file_path)
                    resize_zero()
                    make_ready()
                loop = asyncio.get_event_loop()
                future = loop.run_in_executor(None, cpu_to_storage)
                if not async_:
                    loop.run_until_complete(future)
                else:
                    task = Task([lambda: loop.run_until_complete(future)])

            # case 'storage', 'cuda':
            #     f = kvikio.CuFile(self._file_path, 'r')
            #     # self._tensor = self._tensor.cuda()  # this exits the program???
            #     resize_back()
            #     self._tensor = self._tensor.cuda()
            #     if not async_:
            #         f.read(self._tensor)
            #         f.close()
            #         make_ready()
            #     else:
            #         future = f.pread(self._tensor)
            #         task = Task([future.get, f.close, make_ready])
            # case 'cuda', 'storage':
            #     f = kvikio.CuFile(self._file_path, 'w')
            #     if not async_:
            #         f.write(self._tensor)
            #         f.close()
            #         resize_zero()
            #         make_ready()
            #     else:
            #         future = f.pwrite(self._tensor)
            #         task = Task([future.get, f.close, resize_zero, make_ready])

            case _:
                if not async_:
                    self._tensor = self._tensor.to(device)
                    make_ready()
                else:
                    self._tensor = self._tensor.to(device, non_blocking=True)
                    stream = torch.cuda.current_stream()
                    task = Task([stream.synchronize, make_ready])

        self._task = task
        self.at = device
        # return task