import asyncio
from turtle import st
from typing import TYPE_CHECKING, Literal, Callable, Any
import psutil
import time
import os

import torch
from torch import Tensor
from torch.cuda.nvtx import range_push, range_pop

import kvikio
import kvikio.defaults

def get_memory_used():
    """
    Return this process's memory footprint in bytes: resident set size plus swap.

    Uses the named fields of `psutil.Process().memory_full_info()` instead of
    tuple positions, because the field layout differs between platforms
    (position 9 is `swap` only on Linux; other platforms have fewer fields or
    a different meaning there). Platforms that do not report per-process swap
    contribute 0 for it.
    """
    mem_info = psutil.Process().memory_full_info()
    return mem_info.rss + getattr(mem_info, "swap", 0)  # in bytes

def get_cuda_memory_used():
    """Bytes currently occupied by tensors on the current CUDA device, per PyTorch's allocator."""
    allocated_bytes = torch.cuda.memory_allocated()
    return allocated_bytes

class Task:
    def __init__(self, callbacks: list[Callable[[], Any]],
                 args: list[Any] | None = None):
        self.callbacks = callbacks
        self.args = args

    def wait(self):
        """Runs each callback with its corresponding arg (if any)."""
        if self.args is not None:
            for callback, arg in zip(self.callbacks, self.args):
                if arg is not None:
                    callback(arg)
                else:
                    callback()
        else:
            for callback in self.callbacks:
                callback()

# Byte size passed to `untyped_storage().resize_()` to release the memory
# backing a tensor while its data lives on 'storage'.
ZERO_SIZE = 0

class GDSTensor(torch.Tensor):
    r"""
    A subclass of `torch.Tensor` that adds GPU Direct Storage (GDS) capabilities.
    
    - We keep track of a "logical device" which can be one of `'cpu'`, `'cuda'`, or `'storage'`.
    - We store the underlying data in `self` (since we're a Tensor subclass).
    - We store additional metadata (file path, tasks, etc.) as normal Python attributes.
    - Be note that we cannot change the original base device (CPU or CUDA) of the tensor.
    - In other words, we cannot transfer a CPU tensor to CUDA or vice versa.
    - However, we can move the data to/from storage, and to/from CPU/GPU.
    - As an alternative, i have implemented a `to_inplace` method that can move the data to any device.
    """

    @staticmethod
    def __new__(cls, tensor: Tensor, storage_path: str) -> 'GDSTensor':
        """
        Creates a new GDSTensor from an existing CPU or CUDA tensor.
        The new GDSTensor is 'physically' identical to the input tensor (same shape, dtype, etc.),
        but it has extra metadata for GDS.
        """
        if tensor.requires_grad:
            raise ValueError("GDS does not support tensors with gradients.")

        # Create a subclassed Tensor:
        #   - We do .detach() to ensure no gradient,
        #   - .clone() is optional, but sometimes safer to avoid altering the original.
        #   - Then we call as_subclass(cls) to produce a GDSTensor.
        # If you want to share the same storage with the original tensor, skip clone().
        # For safety, let's do .detach().clone().
        obj = tensor.detach().clone().as_subclass(cls)

        obj._base_device = tensor.device.type # we track the original device

        # Now store GDS-specific metadata:
        obj._untyped_size = tensor.untyped_storage().size()
        obj._file_path = os.path.join(storage_path, str(id(obj)))
        obj._n_threads = 32
        # We keep track of a "logical device" which can be 'cpu', 'cuda', or 'storage':
        obj._logical_device: Literal['storage', 'cpu', 'cuda'] = tensor.device.type
        obj._task: Task | None = None
        obj.ready: bool = True

        # to support cpu <-> cuda async operations
        # original torch.tensor does not allow inter-device operations
        # Therefore, we need to use streams to make it work
        # obj.d2h_stream = torch.cuda.Stream()
        # obj.h2d_stream = torch.cuda.Stream()

        return obj

    ########################################################################
    # Basic properties
    ########################################################################

    @property
    def nbytes(self) -> int:
        """Returns the number of bytes occupied by this tensor."""
        return self.numel() * self.element_size()

    @property
    def logical_device(self) -> str:
        """
        Returns the "logical device" of this GDSTensor:
          - 'cpu'
          - 'cuda'
          - 'storage' (on NVMe via GDS)
        """
        return self._logical_device

    def __repr__(self) -> str:
        return str(self)

    def __str__(self) -> str:
        if self._logical_device != 'storage':
            s = "GDST" + super().__repr__()[1:]  # e.g. "GDSTensor(..."
        else:
            s = "GDSTensor(device='storage')"
        s += f"[ready={self.ready}]"
        return s

    ########################################################################
    # Overriding movement methods & custom logic
    ########################################################################

    def synchronize(self):
        """
        Wait for any async tasks to finish if `self.ready` is False.
        Must call before re-using the data if an async operation was triggered.
        """
        if self.ready:
            self._task = None
            return
        if self._task is not None:
            self._task.wait()
        self._task = None
        self.ready = True

    def upload(self, target_tensor: Tensor, async_: bool = True,
               cuda_stream: torch.cuda.Stream | None = None) -> None:
        """
        "Upload" from storage -> GPU.
        Reads from the file at self._file_path into `target_tensor`.
        """
        # We only allow upload if we're logically on 'storage':
        if self._logical_device != 'storage':
            raise RuntimeError("Upload is only allowed for storage-based tensor.")
        assert self.ready
        assert self.shape == target_tensor.shape, "Shapes must match."

        # Identify the target device:
        device_str = str(target_tensor.device)
        if 'cpu' in device_str:
            raise NotImplementedError("storage -> cpu direct upload not implemented here.")
        elif 'cuda' in device_str:
            pass
        elif 'storage' in device_str:
            return  # do nothing
        else:
            raise NotImplementedError()

        self.ready = False

        def _make_ready():
            self.ready = True

        match self._logical_device, 'cuda':
            case 'storage', 'cuda':
                if not os.path.exists(self._file_path):
                    raise FileNotFoundError(f"{self._file_path} does not exist.")
                f = kvikio.CuFile(self._file_path, 'r')
                if not async_:
                    # read synchronously
                    f.read(target_tensor, task_size=self.nbytes // self._n_threads)
                    f.close()
                    _make_ready()
                else:
                    if cuda_stream is None:
                        raise ValueError("cuda_stream must be provided for async upload.")
                    future = f.raw_read_async(
                        buf=target_tensor.data,
                        size=target_tensor.nbytes,
                        file_offset=0,
                        stream=cuda_stream.cuda_stream
                    )
                    self._task = Task([future.check_bytes_done, f.close, _make_ready])
            case _:
                raise NotImplementedError(f"Upload not implemented for {self._logical_device} -> cuda")

    def overwrite(self, source_tensor: Tensor,
                  async_: bool = True,
                  cuda_stream: torch.cuda.Stream | None = None) -> None:
        """
        Overwrite the existing file on disk (when self is 'storage') using `source_tensor`.
        This is the reverse of `upload`: we write from CPU or GPU to disk.
        """
        if self._logical_device != 'storage':
            raise RuntimeError("overwrite is only allowed for storage-based tensor.")
        # For safety:
        if source_tensor.requires_grad:
            source_tensor = source_tensor.detach()
        assert self.ready
        assert self.shape == source_tensor.shape, "Shape mismatch."

        device_str = str(source_tensor.device)
        if 'cpu' in device_str:
            dev = 'cpu'
        elif 'cuda' in device_str:
            dev = 'cuda'
        elif 'storage' in device_str:
            dev = 'storage'
        else:
            raise NotImplementedError()

        self.ready = False

        def _make_ready():
            self.ready = True

        def _resize_zero():
            # Attempt to resize our underlying storage to 0:
            try:
                self.untyped_storage().resize_(ZERO_SIZE)
            except:
                # fallback: reallocate self to empty
                tmp = torch.empty_like(self)
                tmp.untyped_storage().resize_(ZERO_SIZE)
                self.copy_(tmp)

        def _del_tensor(t):
            del t

        match dev, self._logical_device:
            case 'cpu', 'storage':
                # synchronous CPU->storage
                if id(self) == id(source_tensor):
                    _make_ready()
                else:
                    def _cpu_to_storage():
                        # remove old file
                        if os.path.exists(self._file_path):
                            os.remove(self._file_path)
                        else:
                            raise FileNotFoundError(f"{self._file_path} does not exist.")
                        source_tensor.numpy().tofile(self._file_path)

                        # Now we physically share data with source_tensor? 
                        # Because we are "the same shape", but to keep consistency,
                        # let's just do a zero-resize to reflect that the real data is on disk.
                        # This is a bit of a hack, but it works.

                        _resize_zero()
                        _make_ready()
                    loop = asyncio.get_event_loop()
                    future = loop.run_in_executor(None, _cpu_to_storage)
                    if not async_:
                        loop.run_until_complete(future)
                    else:
                        self._task = Task([lambda: loop.run_until_complete(future)])
            case 'cuda', 'storage':
                if id(self) == id(source_tensor):
                    _make_ready()
                else:
                    if os.path.exists(self._file_path):
                        os.remove(self._file_path)
                    else:
                        raise FileNotFoundError(f"{self._file_path} does not exist.")
                    f = kvikio.CuFile(self._file_path, 'w')
                    if not async_:
                        f.write(source_tensor, task_size=self.nbytes // self._n_threads)
                        f.close()
                        _del_tensor(source_tensor)
                        _resize_zero()
                        _make_ready()
                    else:
                        if cuda_stream is None:
                            raise ValueError("cuda_stream must be provided for async overwrite.")
                        future = f.raw_write_async(
                            buf=source_tensor,
                            size=source_tensor.nbytes,
                            file_offset=0,
                            dev_offset=0,
                            stream=cuda_stream.cuda_stream,
                        )
                        self._task = Task(
                            [future.check_bytes_done, f.close, _resize_zero, _del_tensor, _make_ready],
                            [None, None, None, source_tensor, None]
                        )
            case _:
                raise NotImplementedError(f"Overwrite not implemented for {dev} -> {self._logical_device}")

        # We remain logically 'storage'
        self._logical_device = 'storage'

    def to_duplicate(self, device: Literal['storage']) -> None:
        """
        Copies data to 'storage' *without* discarding the CPU/GPU data. 
        In practice, here we only keep the "duplicate" on storage, and 
        forcibly zero out the main memory usage (like caching).
        """
        assert device == 'storage'
        if self._logical_device == device:
            return

        assert not self.requires_grad
        assert self.ready

        self.ready = False

        def _make_ready():
            self.ready = True

        def _resize_zero():
            try:
                self.untyped_storage().resize_(ZERO_SIZE)
            except:
                tmp = torch.empty_like(self)
                tmp.untyped_storage().resize_(ZERO_SIZE)
                self.copy_(tmp)

        match self._logical_device, device:
            case 'cpu', 'storage':
                if os.path.exists(self._file_path):
                    _resize_zero()
                    _make_ready()
                else:
                    raise FileNotFoundError(f"{self._file_path} not found.")
            case 'cuda', 'storage':
                if os.path.exists(self._file_path):
                    _resize_zero()
                    _make_ready()
                else:
                    raise FileNotFoundError(f"{self._file_path} not found.")
            case _:
                raise NotImplementedError()

        self._logical_device = 'storage'

    def to_inplace(self, device: Literal['storage', 'cpu', 'cuda'], async_: bool = False) -> Task | None:
        """
        Move the physical data from self._logical_device -> `device`.
        This modifies self in-place.
        """
        if self._logical_device == device:
            return None

        # GDS doesn't support gradients, so we skip them:
        assert not self.requires_grad
        assert self.ready
        
        self.ready = False

        def _make_ready():
            self.ready = True

        def _resize_zero():
            try:
                self.untyped_storage().resize_(ZERO_SIZE)
            except:
                tmp = torch.empty_like(self)
                tmp.untyped_storage().resize_(ZERO_SIZE)
                self.copy_(tmp)

        def _resize_back():
            self.untyped_storage().resize_(self._untyped_size)

        task = None

        print(f"Moving {self._logical_device} -> {device}")
        match self._logical_device, device:
            case 'storage', 'cpu':
                # Synchronous or async file read into CPU memory
                def _storage_to_cpu():
                    dtype = self.dtype
                    # read the entire file into self:
                    # The shape is the same, so we do from_file + reshape:
                    new_t = torch.from_file(
                        filename=self._file_path,
                        size=self._untyped_size // dtype.itemsize,
                        dtype=dtype
                    ).reshape_as(self)
                    # Copy new_t -> self in-place:
                    self.copy_(new_t)
                    self._logical_device = 'cpu'
                    _make_ready()

                loop = asyncio.get_event_loop()
                future = loop.run_in_executor(None, _storage_to_cpu)
                if not async_:
                    loop.run_until_complete(future)
                else:
                    task = Task([lambda: loop.run_until_complete(future)])
                
            case 'cpu', 'storage':
                def _cpu_to_storage():
                    if os.path.exists(self._file_path):
                        os.remove(self._file_path)
                    # Write CPU data to disk
                    self.numpy().tofile(self._file_path)
                    _resize_zero()
                    self._logical_device = 'storage'
                    _make_ready()
                loop = asyncio.get_event_loop()
                future = loop.run_in_executor(None, _cpu_to_storage)
                if not async_:
                    loop.run_until_complete(future)
                else:
                    task = Task([lambda: loop.run_until_complete(future)])
            
            case 'storage', 'cuda':
                # read from file into a CUDA tensor
                f = kvikio.CuFile(self._file_path, 'r')
                _resize_back()  # restore space
                # move self to cuda:
                self.data = self.cuda()  # or `self = self.cuda()` won't quite work in-place, so do self.copy_(self.cuda()) or so
                
                # we can use a new tensor to read into
                # new_t = torch.empty_like(self, device='cuda')
                # or we can use the existing tensor
                new_t = self.data  # use the existing tensor
                # if error occurs, just use new_t = self
                if not async_:
                    f.read(new_t, task_size=self.nbytes // self._n_threads)
                    f.close()
                    # then copy new_t -> self
                    # self.copy_(new_t)
                    # for the second method, we can just use the existing tensor directly
                    self._logical_device = 'cuda'
                    _make_ready()
                else:
                    future = f.pread(new_t, task_size=self.nbytes // self._n_threads)
                    def finalize():
                        # self.copy_(new_t) 
                        # for the second method, we can just use the existing tensor directly
                        self._logical_device = 'cuda'
                        _make_ready()
                    task = Task([future.get, f.close, finalize])
            
            case 'cuda', 'storage':
                f = kvikio.CuFile(self._file_path, 'w')
                if not async_:
                    f.write(self, task_size=self.nbytes // self._n_threads)
                    f.close()
                    _resize_zero()
                    self._logical_device = 'storage'
                    _make_ready()
                else:
                    future = f.pwrite(self, task_size=self.nbytes // self._n_threads)
                    def finalize():
                        _resize_zero()
                        self._logical_device = 'storage'
                        _make_ready()
                    task = Task([future.get, f.close, finalize])

            case _:
                # Fallback: create a brand-new tensor on the target device.
                if not async_:
                    print(f"Moving {self._logical_device} -> {device} (fallback)")
                    tmp = super().to(device)  # a plain Tensor on `device`
                    
                    # Now wrap `tmp` in a fresh GDSTensor (or re-subclass it).
                    # We'll copy over critical metadata from the old self.
                    new_self = tmp.as_subclass(GDSTensor)
                    new_self._file_path = self._file_path
                    new_self._untyped_size = self._untyped_size
                    new_self._n_threads = self._n_threads
                    new_self._task = None
                    new_self.ready = True
                    new_self._logical_device = device

                    # Return the new object. This means outside references to `self`
                    # won't see the device change, but that is the PyTorch model:
                    return new_self
                else:
                    assert False, "Async fallback not implemented yet."
                    # For an async approach, you can do something like:
                    tmp = super().to(device, non_blocking=True)
                    stream = torch.cuda.current_stream() if device.startswith("cuda") else None

                    def finalize():
                        if stream is not None:
                            stream.synchronize()
                        # Now wrap in GDSTensor
                        new_self = tmp.as_subclass(GDSTensor)
                        new_self._file_path = self._file_path
                        new_self._untyped_size = self._untyped_size
                        new_self._n_threads = self._n_threads
                        new_self._task = None
                        new_self.ready = True
                        new_self._logical_device = device
                        return new_self  # But note we cannot truly "return" here 
                                        # in a callback context. We'll store it.

                    # The callback pattern might need a redesign so you can bubble 
                    # up the `new_self`. For example, store it in `self._moved_version`,
                    # or return a "future" that yields the new object.
                    
                    # We'll do a minimal approach just calling finalize() after stream sync:
                    def callback():
                        finalize()
                    
                    self._task = Task([callback])
                    return self

        self._task = task
        return task

# #############################################################################
# Test
# #############################################################################
if __name__ == '__main__':
    # Manual benchmark / smoke test: moves a large tensor between CPU, GPU and
    # NVMe storage, printing elapsed time and host/GPU memory after each step.
    # Needs a CUDA device, kvikio, and the directory /mnt/fast_nvme/gdstensor/.

    # Set kvikio defaults
    kvikio.defaults.num_threads_reset(32)  # number of threads in kvikio's default thread pool
    # NOTE(review): passing True *enables* kvikio's compatibility mode, in which
    # kvikio falls back to POSIX I/O instead of cuFile/GDS. Pass False to
    # actually exercise GDS.
    kvikio.defaults.compat_mode_reset(True)
    
    def get_memory_used_gb():
        # Process memory (RSS + swap, as computed by get_memory_used) in GB (1e9 bytes).
        return get_memory_used() / 1e9

    a = torch.rand((10, 1000, 1000, 250))  # 2.5e9 elements: ~10 GB assuming the default float32 dtype
    n_bytes = a.numel() * a.element_size()
    # NOTE(review): confirm `set_task_size` exists in the installed kvikio
    # version; older releases expose `task_size_reset` instead.
    kvikio.defaults.set_task_size(n_bytes // 32)

    print(f"n_bytes: {n_bytes}")

    # Wrap a copy of `a`, then drop the original so only the GDSTensor's
    # memory shows up in the measurements below.
    gd_a = GDSTensor(a, '/mnt/fast_nvme/gdstensor/')
    del a

    # print(gd_a)
    # print(f"Device: {gd_a.logical_device}")

    print(f"GDS initialized (CPU)")
    print(f"Device: {gd_a.logical_device}")
    print(f"CPU: {get_memory_used_gb():.3f} GB")
    print(f"GPU: {torch.cuda.memory_allocated() / 1e9:.3f} GB")
    print("---")

    # Move around:
    # to_inplace defaults to async_=False, so synchronize() here only resets
    # the ready/task bookkeeping.
    gd_a.to_inplace('cuda'); gd_a.synchronize()
    print(f"Device: {gd_a.logical_device}")
    print(f"CPU: {get_memory_used_gb():.3f} GB")
    print(f"GPU: {torch.cuda.memory_allocated() / 1e9:.3f} GB")

    # Debug stop: this AssertionError ends the script, so the warm-up moves
    # and timed benchmarks below are unreachable until this line is removed.
    assert False, "Stop here"
    # Warm-up round trips before timing.
    gd_a.to_inplace('cpu'); gd_a.synchronize()
    gd_a.to_inplace('cuda'); gd_a.synchronize()
    gd_a.to_inplace('cpu'); gd_a.synchronize()
    gd_a.to_inplace('cuda'); gd_a.synchronize()
    gd_a.to_inplace('cpu'); gd_a.synchronize()

    # Each timed step below: wall-clock time around one move + synchronize(),
    # wrapped in an NVTX range (range_push/range_pop) so it is labelled in a
    # profiler timeline, followed by the resulting device and memory usage.
    start = time.perf_counter()
    range_push('CPU->GPU')
    gd_a.to_inplace('cuda')
    gd_a.synchronize()
    range_pop()
    end = time.perf_counter()
    print(f"(CPU->GPU) Time: {end - start:.3f} s")
    print(f"Device: {gd_a.logical_device}")
    print(f"CPU: {get_memory_used_gb():.3f} GB")
    print(f"GPU: {torch.cuda.memory_allocated() / 1e9:.3f} GB")
    print("---")

    start = time.perf_counter()
    range_push('GPU->CPU')
    gd_a.to_inplace('cpu')
    gd_a.synchronize()
    range_pop()
    end = time.perf_counter()
    # Return cached, now-unused GPU blocks held by PyTorch's allocator.
    torch.cuda.empty_cache()
    print(f"(GPU->CPU) Time: {end - start:.3f} s")
    print(f"Device: {gd_a.logical_device}")
    print(f"CPU: {get_memory_used_gb():.3f} GB")
    print(f"GPU: {torch.cuda.memory_allocated() / 1e9:.3f} GB")
    print("---")

    start = time.perf_counter()
    range_push('CPU->Storage')
    gd_a.to_inplace('storage')
    gd_a.synchronize()
    range_pop()
    end = time.perf_counter()
    print(f"(CPU->Storage) Time: {end - start:.3f} s")
    print(f"Device: {gd_a.logical_device}")
    print(f"CPU: {get_memory_used_gb():.3f} GB")
    print(f"GPU: {torch.cuda.memory_allocated() / 1e9:.3f} GB")
    print("---")

    start = time.perf_counter()
    range_push('Storage->GPU')
    gd_a.to_inplace('cuda')
    gd_a.synchronize()
    range_pop()
    end = time.perf_counter()
    print(f"(Storage->GPU) Time: {end - start:.3f} s")
    print(f"Device: {gd_a.logical_device}")
    print(f"CPU: {get_memory_used_gb():.3f} GB")
    print(f"GPU: {torch.cuda.memory_allocated() / 1e9:.3f} GB")
    print("---")

    start = time.perf_counter()
    range_push('GPU->Storage')
    gd_a.to_inplace('storage')
    gd_a.synchronize()
    range_pop()
    end = time.perf_counter()
    print(f"(GPU->Storage) Time: {end - start:.3f} s")
    print(f"Device: {gd_a.logical_device}")
    print(f"CPU: {get_memory_used_gb():.3f} GB")
    print(f"GPU: {torch.cuda.memory_allocated() / 1e9:.3f} GB")
    print("---")

    start = time.perf_counter()
    range_push('Storage->CPU')
    gd_a.to_inplace('cpu')
    gd_a.synchronize()
    range_pop()
    end = time.perf_counter()
    print(f"(Storage->CPU) Time: {end - start:.6f} s")
    print(f"Device: {gd_a.logical_device}")
    print(f"CPU: {get_memory_used_gb():.3f} GB")
    print(f"GPU: {torch.cuda.memory_allocated() / 1e9:.3f} GB")
    print("---")

    start = time.perf_counter()
    range_push('CPU->GPU')
    gd_a.to_inplace('cuda')
    gd_a.synchronize()
    range_pop()
    end = time.perf_counter()
    print(f"(CPU->GPU) Time: {end - start:.3f} s")
    print(f"Device: {gd_a.logical_device}")
    print(f"CPU: {get_memory_used_gb():.3f} GB")
    print(f"GPU: {torch.cuda.memory_allocated() / 1e9:.3f} GB")
    print("---")
