from typing import Literal, Callable, Any
import os

import kvikio
import torch

from torch import Tensor
from torch_sparse import SparseTensor

from tensornvme import DiskOffloader

from loguru import logger

from utils.gdstensor import ZERO_SIZE
from utils.gdstensor import GDSTensor


class Task:
    def __init__(self, callbacks: list[Callable[[], Any]],
                 args: list[Any] | None = None):
        self.callbacks = callbacks
        self.args = args

    def wait(self):
        if self.args is not None:
            for callback, arg in zip(self.callbacks, self.args):
                if arg is not None:
                    callback(arg)
                else:
                    callback()
        else:
            for callback in self.callbacks:
                callback()

class device:
    """Lightweight device descriptor that wraps a device-type string (e.g. 'cpu')."""

    def __init__(self, type: str):
        # Keep the device type exactly as given; __str__ hands it back unchanged.
        self.type = type

    def __str__(self) -> str:
        label = self.type
        return label

# Byte count passed to ``untyped_storage().resize_()`` to release a tensor's
# storage while keeping the tensor object itself alive (it is re-grown later).
# NOTE(review): this rebinds the ZERO_SIZE imported from utils.gdstensor at the
# top of the file; confirm the two values agree or keep only one definition.
ZERO_SIZE = 0

class AdjacencyMatrixWithGDS(SparseTensor):
    r"""
    Adjacency Matrix with GPU Direct Storage (GDS).
    This class references the logic from `AdjacencyMatrixWithOffloader` but uses
    GDSTensor for rowptr, col, and value to move them between CPU, CUDA, and storage.

    Important:
      - We assume that at initialization time, `rowptr`, `col`, and `value` are CPU tensors.
      - We also assume that `rowptr`, `col`, and `value` do NOT require gradients.
      - `value` is optional; every transfer path simply skips it when it is None.
    """

    def __init__(
        self,
        rowptr: Tensor | None = None,
        col: Tensor | None = None,
        value: Tensor | None = None,
        sparse_sizes: tuple[int | None, int | None] | None = None,
        is_sorted: bool = False,
        storage_path: str = "/mnt/fast_nvme/gds_adjmat",  # a folder path to store rowptr/col/value
    ):
        """
        By default, we expect the inputs to be CPU tensors.

        :param rowptr: CPU tensor containing rowptr (required)
        :param col: CPU tensor containing col indices (required)
        :param value: optional CPU tensor containing edge values
        :param sparse_sizes: shape of the adjacency matrix
        :param is_sorted: whether col indices are sorted within each row
        :param storage_path: folder path for storing data if device=storage
        :raises AssertionError: if rowptr/col are missing or any tensor is not on CPU
        """
        # Validate before the base constructor runs so that a missing tensor
        # fails with a clear message.
        assert rowptr is not None, "rowptr must be provided"
        assert col is not None, "col must be provided"

        super().__init__(
            rowptr=rowptr,
            col=col,
            value=value,
            sparse_sizes=sparse_sizes,
            is_sorted=is_sorted,
        )

        # At init, everything must live on the CPU.
        assert self.storage._rowptr.device.type == "cpu"
        assert self.storage._col.device.type == "cpu"
        if self.storage._value is not None:
            assert self.storage._value.device.type == "cpu"

        # Wrap the tensors held by the storage rather than the raw arguments:
        # the base constructor may replace them (NOTE(review): torch_sparse's
        # SparseStorage permutes col/value when is_sorted=False), and the
        # wrappers must hold the final tensors. Still physically on CPU.
        self.storage._rowptr = GDSTensor(self.storage._rowptr, storage_path)
        self.storage._col = GDSTensor(self.storage._col, storage_path)
        if self.storage._value is not None:
            self.storage._value = GDSTensor(self.storage._value, storage_path)

        self.at: Literal["storage", "cpu", "cuda"] = "cpu"
        self.ready: bool = True
        self._task: Task | None = None

    @property
    def device(self) -> str:
        """Return the logical location of the matrix: 'cpu', 'cuda', or 'storage'.

        NOTE(review): if the SparseTensor base class exposes ``device`` as a
        method, this property replaces it, and ``adj.device()`` call sites will
        fail on this class.
        """
        return self.at

    def __repr__(self) -> str:
        return str(self)

    def __str__(self) -> str:
        # Show the wrapped tensors plus the logical location / readiness state.
        s = (
            "AdjacencyMatrixWithGDS(\n"
            f"  rowptr={self.storage._rowptr},\n"
            f"  col={self.storage._col},\n"
            f"  value={self.storage._value}\n"
            ")"
        )
        s += f"[at={self.at}][ready={self.ready}]"
        return s

    def _gds_tensors(self) -> list[GDSTensor]:
        """Return the wrapped tensors that are present, in rowptr/col/value order."""
        return [t for t in (self.storage._rowptr, self.storage._col, self.storage._value)
                if t is not None]

    def _synchronize_tensors(self) -> None:
        """Call ``synchronize()`` on every wrapped GDSTensor."""
        for gds_tensor in self._gds_tensors():
            gds_tensor.synchronize()

    def synchronize(self):
        r"""
        If the tensor is used with async=True,
        this function will wait for any GDS tasks to finish.
        You must call this function before reusing the adjacency matrix if
        you used async transfers.
        """
        # Detach the pending task before running it so that a callback which
        # (directly or indirectly) calls synchronize() again cannot re-run the
        # same task and recurse without end.
        task, self._task = self._task, None
        if self.ready:
            return
        if task is not None:
            task.wait()
        self._synchronize_tensors()
        self.ready = True

    @torch.no_grad()
    def to(self, destination: str, async_: bool = False):
        """
        Move the adjacency matrix to 'cpu', 'cuda', or 'storage'.

        :param destination: must be one of {'cpu', 'cuda', 'storage'}
        :param async_: if True, start the transfers and return immediately;
            call ``synchronize()`` before using the matrix again.
        """
        assert destination in ["cpu", "cuda", "storage"], f"Invalid destination: {destination}"
        assert self.ready, "You must call synchronize() before calling to() again."

        if destination == self.at:
            logger.debug(f"Already on {destination}, skipping .to()")
            return

        # Unready until the transfers (async or sync) have completed.
        self.ready = False

        # Each GDSTensor performs its own in-place move (CPU/CUDA/storage).
        for gds_tensor in self._gds_tensors():
            self._rowptr_move(gds_tensor, destination, async_)

        if async_:
            # Completion is deferred to synchronize(), which always
            # synchronizes every GDSTensor, so no extra callback is needed.
            # (self.synchronize must never be registered as a task callback:
            # synchronize() runs the task and would re-enter itself.)
            self._task = None
        else:
            self.synchronize()

        self.at = destination

    def _rowptr_move(self, gds_tensor: GDSTensor, destination: str, async_: bool):
        """
        Move a single GDSTensor (rowptr, col or value) to ``destination`` in place.

        With ``async_=True`` the move is only started; ``synchronize()`` waits for it.
        """
        if gds_tensor.device == destination:
            return  # no-op
        gds_tensor.to_inplace(destination, async_)

    def gpu_to_storage(self, async_: bool = True):
        """
        Explicit method to move from GPU to Storage (if we are on GPU).
        If you're on CPU, this does nothing. If you're already on Storage, does nothing.
        """
        if self.at == 'cuda':
            self.to('storage', async_=async_)

    def storage_to_gpu(self, async_: bool = True):
        """
        Explicit method to move from Storage to GPU (if we are on Storage).
        """
        if self.at == 'storage':
            self.to('cuda', async_=async_)

class AdjacencyMatrixWithSimpleGDS(SparseTensor):
    r""" Adjacency Matrix with GPU Direct Storage (GDS).
    But this class has much simpler logic than `AdjacencyMatrixWithGDS` for compatibility.

    The tensors (rowptr, col and optional value) are moved to CUDA once, at
    construction. From then on only their untyped storages are shrunk to
    ``ZERO_SIZE`` or re-grown, and the data is written to / read from
    per-instance files with kvikio ``CuFile``. ``physical_device`` records where
    the data currently lives: ``'storage'`` or the ``target_device`` label given
    to :meth:`storage_to_device`. Asynchronous transfers leave ``ready`` False
    until :meth:`synchronize` is called.
    """

    def __init__(self, rowptr: Tensor | None = None, col: Tensor | None = None, value: Tensor | None = None,
                 storage_path: str = "/mnt/fast_nvme/grinnder_storage",
                 sparse_sizes: tuple[int | None, int | None] | None = None, is_sorted: bool = False):
        """
        :param rowptr: CPU tensor containing rowptr
        :param col: CPU tensor containing col indices
        :param value: optional CPU tensor containing edge values
        :param storage_path: existing folder in which the per-instance files are created
        :param sparse_sizes: shape of the adjacency matrix
        :param is_sorted: whether col indices are sorted within each row
        :raises AssertionError: if ``rowptr`` is not on the CPU
        """
        # as we only use GDS here, we don't need to store the offloader
        self.physical_device: Literal['cpu', 'cuda', 'storage']

        super().__init__(rowptr=rowptr, col=col, value=value, sparse_sizes=sparse_sizes, is_sorted=is_sorted)

        # At init, get the initial device
        self.physical_device = rowptr.device.type
        assert self.physical_device == 'cpu', "Initial device must be CPU"

        # Synchronization state for asynchronous transfers.
        self._task: Task | None = None
        self.ready: bool = True

        # GDS parameters: one file per tensor, named after this instance.
        self.storage_path = os.path.join(storage_path, f'adjmat_{id(self)}')
        self._row_ptr_path = self.storage_path + '_rowptr'
        self._col_path = self.storage_path + '_col'
        self._value_path = self.storage_path + '_value'
        self._gds_n_threads = 32

        # Move the tensors to CUDA for GDS; contiguous() guarantees one dense
        # buffer per tensor for the file transfers.
        self.storage._rowptr = self.storage._rowptr.to('cuda').contiguous()
        self.storage._col = self.storage._col.to('cuda').contiguous()
        if self.storage._value is not None:
            self.storage._value = self.storage._value.to('cuda').contiguous()

        # Record the byte sizes of the CUDA storages, which are the ones that
        # get shrunk and re-grown later. Measuring the original CPU tensors
        # would over-allocate whenever they were views into a larger storage.
        self._rowptr_untyped_size = self.storage._rowptr.untyped_storage().size()
        self._col_untyped_size = self.storage._col.untyped_storage().size()
        if self.storage._value is not None:
            self._value_untyped_size = self.storage._value.untyped_storage().size()

        # Write the data out and release the device memory right away.
        self.device_to_storage(async_=False)

    # ------------------------------------------------------------------ helpers

    def _gds_pairs(self) -> list[tuple[Tensor, str]]:
        """Return ``(tensor, file path)`` for rowptr, col and (if present) value."""
        pairs = [(self.storage._rowptr, self._row_ptr_path),
                 (self.storage._col, self._col_path)]
        if self.storage._value is not None:
            pairs.append((self.storage._value, self._value_path))
        return pairs

    def _task_size(self, tensor: Tensor) -> int:
        """Per-task byte count for a synchronous kvikio transfer of ``tensor``.

        Splits the transfer across ``_gds_n_threads`` tasks but never returns 0.
        For tensors smaller than ``_gds_n_threads`` bytes, plain integer division
        would give a zero task size, which kvikio does not accept
        (NOTE(review): per kvikio's parallel I/O checks).
        """
        return max(tensor.nbytes // self._gds_n_threads, 1)

    def _release_device_memory(self) -> None:
        """Shrink every tensor's untyped storage to ``ZERO_SIZE`` bytes."""
        for tensor, _ in self._gds_pairs():
            tensor.untyped_storage().resize_(ZERO_SIZE)

    def _restore_device_memory(self) -> None:
        """Re-grow each untyped storage to the byte size recorded in ``__init__``."""
        self.storage._rowptr.untyped_storage().resize_(self._rowptr_untyped_size)
        self.storage._col.untyped_storage().resize_(self._col_untyped_size)
        if self.storage._value is not None:
            self.storage._value.untyped_storage().resize_(self._value_untyped_size)

    def _remove_storage_files(self) -> None:
        """Delete the on-disk copies of rowptr, col and (if present) value."""
        for _, path in self._gds_pairs():
            os.remove(path)

    # --------------------------------------------------------------- public API

    def synchronize(self):
        """
        Wait for any GDS tasks to finish.
        """
        if self.ready:
            self._task = None
            return
        if self._task is not None:
            self._task.wait()
        self._task = None
        self.ready = True

    def device_to_storage(self, async_: bool = True):
        """
        Explicit method to move from Host/GPU (Device) to Storage.

        Writes each tensor to its file, then releases its device memory and
        sets ``physical_device = 'storage'``. With ``async_=True`` the writes
        are only started here; they finish (and memory is released) in
        :meth:`synchronize`.
        """
        if not self.ready:  # finish any pending transfer first
            self.synchronize()
        if self.physical_device == 'storage':
            return  # already on storage

        # Unready until the writes (sync or async) have completed.
        self.ready = False

        assert self.storage._rowptr is not None
        assert self.storage._col is not None

        pairs = self._gds_pairs()
        files = [kvikio.CuFile(path, 'w') for _, path in pairs]

        def finalize():
            self._release_device_memory()
            self.physical_device = 'storage'
            self.ready = True

        if not async_:
            try:
                for (tensor, _), f in zip(pairs, files):
                    f.write(tensor, task_size=self._task_size(tensor))
            finally:
                # Close the files even if a write fails.
                for f in files:
                    f.close()
            finalize()
        else:
            futures = [f.pwrite(tensor) for (tensor, _), f in zip(pairs, files)]
            # On synchronize(): wait for every write, close every file, then
            # release the device memory.
            self._task = Task([future.get for future in futures]
                              + [f.close for f in files]
                              + [finalize])

    def storage_to_device(self, async_: bool = True, discard: bool = False, target_device: str = 'cpu'):
        """
        Explicit method to move from Storage to Host/GPU (Device) (if we are on Storage).

        :param async_: start the reads and return; call ``synchronize()`` to complete.
        :param discard: delete the storage files once the data has been read back.
        :param target_device: label stored in ``physical_device`` on completion;
            the tensors themselves are read back into the CUDA tensors set up in
            ``__init__``.
        """
        if not self.ready:  # finish any pending transfer first
            self.synchronize()
        if self.physical_device == target_device:
            return  # already on cuda/host

        # Unready until the reads (sync or async) have completed.
        self.ready = False

        assert self.storage._rowptr is not None
        assert self.storage._col is not None

        pairs = self._gds_pairs()
        files = [kvikio.CuFile(path, 'r') for _, path in pairs]

        def finalize():
            self.physical_device = target_device
            self.ready = True
            if discard:
                self._remove_storage_files()

        # Re-allocate device memory before reading into it.
        self._restore_device_memory()

        if not async_:
            try:
                for (tensor, _), f in zip(pairs, files):
                    f.read(tensor, task_size=self._task_size(tensor))
            finally:
                # Close the files even if a read fails.
                for f in files:
                    f.close()
            # Same completion as the async path; in particular physical_device
            # must be updated here as well.
            finalize()
        else:
            futures = [f.pread(tensor) for (tensor, _), f in zip(pairs, files)]
            self._task = Task([future.get for future in futures]
                              + [f.close for f in files]
                              + [finalize])

    def discard_from_device(self, async_: bool = True, target_device: str = 'cpu'):
        """
        Discard the data from Host/GPU memory.
        As we already have the data in storage, we can discard the data from Host/GPU memory.
        For discarding, we just resize the storage tensors to zero size.

        NOTE(review): if the files were deleted by ``storage_to_device(discard=True)``,
        the data is lost after this call.
        """
        # Synchronize first: a pending async storage_to_device only updates
        # physical_device when it completes.
        if not self.ready:
            self.synchronize()
        assert self.physical_device == target_device, "Data must be on Host/GPU"

        # Now start discarding
        self.ready = False

        def discard_core():
            self._release_device_memory()
            self.physical_device = 'storage'
            self.ready = True

        if not async_:
            discard_core()
        else:
            self._task = Task([discard_core])
        

class AdjacencyMatrixWithOffloader(SparseTensor):
    """SparseTensor whose tensors can be offloaded with a tensornvme ``DiskOffloader``.

    ``offloader`` is None after construction and must be assigned by the caller
    to enable disk I/O. While it is None, 'storage' only means the tensors are
    kept on the CPU. ``at`` tracks the current location ('storage', 'cpu' or 'cuda').
    """

    def __init__(self, rowptr: Tensor | None = None, col: Tensor | None = None, value: Tensor | None = None,
                 sparse_sizes: tuple[int | None, int | None] | None = None, is_sorted: bool = False):
        self.offloader: DiskOffloader | None
        self.at: Literal['storage', 'cpu', 'cuda']

        super().__init__(rowptr=rowptr, col=col, value=value, sparse_sizes=sparse_sizes, is_sorted=is_sorted)
        self.offloader = None
        self.at = rowptr.device.type

    def synchronize(self):
        """Wait for all pending offloader I/O (no-op when no offloader is set)."""
        if self.offloader is not None:
            self.offloader.synchronize()

    def gpu_to_storage(self, async_: bool = True):  # TODO: GDS
        """Move the tensors to CPU and, if an offloader is set, write them to disk.

        :param async_: when False, wait for the disk writes before returning.
        """
        self.synchronize()
        if self.at == 'storage':
            return
        assert self.storage._rowptr is not None
        assert self.storage._col is not None
        self.storage._rowptr = self.storage._rowptr.to('cpu')
        self.storage._col = self.storage._col.to('cpu')

        if self.offloader is not None:
            self.offloader.async_write(self.storage._rowptr)
            self.offloader.async_write(self.storage._col)

        if self.storage._value is not None:
            self.storage._value = self.storage._value.to('cpu')
            if self.offloader is not None:
                self.offloader.async_write(self.storage._value)

        if self.offloader is not None and not async_:
            self.offloader.synchronize()
        self.at = 'storage'

    def storage_to_gpu(self, async_: bool = True):
        """Move the tensors to CUDA, reading them back from disk first if offloaded.

        :param async_: kept for interface compatibility. The disk reads fill the
            CPU tensors that are copied to CUDA right afterwards, so the reads
            are always awaited before the copy, whatever this flag says.
        """
        self.synchronize()
        if self.at == 'cuda':
            return
        assert self.storage._rowptr is not None
        assert self.storage._col is not None

        # Only data written out by gpu_to_storage (at == 'storage') can be read
        # back; tensors that never left the CPU are already valid in memory.
        read_back = self.offloader is not None and self.at == 'storage'
        if read_back:
            self.offloader.async_read(self.storage._rowptr)
            self.offloader.async_read(self.storage._col)
            # The .to('cuda') copies below run immediately, so the reads must
            # be complete first; otherwise incomplete data would be copied.
            self.offloader.synchronize()
        self.storage._rowptr = self.storage._rowptr.to('cuda')
        self.storage._col = self.storage._col.to('cuda')

        if self.storage._value is not None:
            if read_back:
                self.offloader.sync_read(self.storage._value)
            self.storage._value = self.storage._value.to('cuda')

        self.at = 'cuda'