from pdb import run
from typing import Optional, List, Literal
import numpy as np
import torch
from torch import Tensor
from torch_sparse import SparseTensor
from utils.gdstensor import GDSTensor
import os

from utils.compiler import load_grinnder_ext

# Load the native extension once, at import time (via utils.compiler), and
# expose its entry points as module-level names used by the classes below.
grinnder_ext = load_grinnder_ext()

(
    synchronize,
    h2d_synchronize,
    d2h_synchronize,
    gather_async,
    fill_async,
    upload_async,
    scatter_async,
) = (
    grinnder_ext.synchronize,
    grinnder_ext.h2d_synchronize,
    grinnder_ext.d2h_synchronize,
    grinnder_ext.gather_async,
    grinnder_ext.fill_async,
    grinnder_ext.upload_async,
    grinnder_ext.scatter_async,
)

"""
HostStorageTensor Class
: it is a partition-wise storage tensor
: it is used for the host / SSD storage
There are two types of communication with the accelerator
: Host -> Accelerator & Accelerator -> Host
"""
class HostStorageTensors(object):
    r""" Host Storage Blocks
        Partition-wise storage tensor
        (w/o redundancy)
    """
    def __init__(self, num_parts: int,
                 embedding_dim: int,
                 part_sizes: List[int],
                 device: str = 'cpu',
                 is_grad: bool = False,
                 storage_path: Optional[str] = None):
        self._num_parts: int = num_parts
        self._embedding_dim = embedding_dim
        self._part_sizes: List[int] = part_sizes
        self._device: str = device
        self._is_grad: bool = is_grad # whether it saves the gradients
        self._storage_path = storage_path

        self._ever_storage = False

        if self._device == 'storage':
            assert self._storage_path is not None, "Storage path should be provided"
            self._ever_storage = True

        self._tensors = [None] * num_parts
        for idx in range(num_parts):
            self._tensors[idx] = self._generate_tensor(idx)
    
    def device(self) -> str:
        return self._device
    
    def get_tensors(self) -> List[Tensor]:
        if isinstance(self._tensors[0], Tensor):
            return self._tensors
        elif isinstance(self._tensors[0], GDSTensor):
            return [tensor.tensor() for tensor in self._tensors]
        else:
            raise NotImplementedError("Only Tensor and GDSTensor are supported")

    def storage_to_cpu(self):
        assert self._ever_storage, "This is not a storage tensor"
        assert self._device == 'storage', "tensor should be in storage"
        for idx in range(self._num_parts):
            assert isinstance(self._tensors[idx], GDSTensor), "Only GDSTensor is supported"
            self._tensors[idx].to_inplace('cpu', async_=False)
        self._device = 'cpu'

        # ========================
        # Uncomment the below code for debugging
        # Save the tensors to the other storage for debugging
        # FW CASE
        # if not self._is_grad:
        #     # check whether the same name exists
        #     # if exists increase the idx
        #     major_idx = 0
        #     while os.path.exists(f'/large_data/acts_n_grads/act_{major_idx}'):
        #         major_idx += 1
        #     os.makedirs(f'/large_data/acts_n_grads/act_{major_idx}')
        #     for idx in range(self._num_parts):
        #         # save the tensors to the other storage
        #         torch.save(self._tensors[idx], f'/large_data/acts_n_grads/act_{major_idx}/{idx}.pt')
        # import code; code.interact(local=locals())
        # ========================

    def cpu_to_storage(self, use_duplicate=True):
        assert self._ever_storage, "This is not a storage tensor"
        assert self._device == 'cpu', "tensor should be in cpu"
        if use_duplicate:
            for idx in range(self._num_parts):
                assert isinstance(self._tensors[idx], GDSTensor), "Only GDSTensor is supported"
                self._tensors[idx].to_duplicate('storage')
        else:
            # ========================
            # Uncomment the below code for debugging
            # Save the tensors to the other storage for debugging
            # BW CASE
            # if self._is_grad:
            #     # check whether the same name exists
            #     # if exists decrease the idx
            #     major_idx = -1
            #     while os.path.exists(f'/large_data/acts_n_grads/grad_{major_idx}'):
            #         major_idx -= 1
            #     os.makedirs(f'/large_data/acts_n_grads/grad_{major_idx}')
            #     for idx in range(self._num_parts):
            #         # save the tensors to the other storage
            #         torch.save(self._tensors[idx], f'/large_data/acts_n_grads/grad_{major_idx}/{idx}.pt')
            # import code; code.interact(local=locals())
            # ========================
            for idx in range(self._num_parts):
                assert isinstance(self._tensors[idx], GDSTensor), "Only GDSTensor is supported"
                self._tensors[idx].to_inplace('storage')
            for idx in range(self._num_parts):
                self._tensors[idx].synchronize()
        self._device = 'storage'

    def reset_tensors(self, epoch: int):
        if self._device == 'cpu':
            for idx in range(self._num_parts):
                self._tensors[idx].fill_(0.0)
        elif self._device == 'storage':
            for idx in range(self._num_parts): # upload from storage
                assert isinstance(self._tensors[idx], GDSTensor), "Only GDSTensor is supported"                
                zero_tensor = torch.zeros(self._part_sizes[idx],
                                          self._embedding_dim,
                                          device='cpu',
                                          requires_grad=False)
                self._tensors[idx].overwrite(zero_tensor, async_=False)
        else:
            raise NotImplementedError("Only CPU and Storage are supported")

    def to(self):
        pass # TODO: implement host -> storage, storage -> host

    def __getitem__(self, idx: int) -> Tensor:
        return self._tensors[idx]

    def _generate_tensor(self, idx: int) -> Tensor:
        if self._device == 'cpu':
            return torch.empty(self._part_sizes[idx],
                            self._embedding_dim,
                            device=self._device,
                            requires_grad=False)
        elif self._device == 'storage':
            empty_tensor = torch.empty(self._part_sizes[idx],
                            self._embedding_dim,
                            device='cpu',
                            requires_grad=False)
            gds_tensor = GDSTensor(empty_tensor, self._storage_path)
            gds_tensor.to_inplace('storage', async_=False)
            return gds_tensor
        else:
            raise NotImplementedError("Only CPU and Storage are supported")

    @torch.no_grad()
    def sync_fill(self, idx: int, data: Tensor):
        assert data.shape == self._tensors[idx].shape, "Data shape should be the same"
        if self._device == 'cpu':
            self._tensors[idx].copy_(data)
        elif self._device == 'storage':
            self._tensors[idx].overwrite(data, async_=False)
        else:
            raise NotImplementedError("Only CPU and Storage are supported")

    @torch.no_grad() # device -> host
    def async_fill(self, idx: int, data: Tensor,
                   d2h_stream:torch.cuda.Stream, wait_stream: torch.cuda.Stream):
        # recommended for cuda tensors
        assert data.shape == self._tensors[idx].shape, "Data shape should be the same"
        assert str(data.device)[:4] == 'cuda', "Data should be in cuda device"
        self._async_fill(idx, idx, data, d2h_stream, wait_stream) # use d2h stream

    @torch.no_grad()
    def _async_fill(self, stream_id: int, idx: int, data: Tensor,
                    d2h_stream: torch.cuda.Stream, wait_stream: torch.cuda.Stream):
        assert stream_id == idx, "Stream id should be the same as idx"
        # with torch.cuda.stream(self._d2h_stream(stream_id)):
        d2h_stream.wait_stream(wait_stream)
        with torch.cuda.stream(d2h_stream):
            if self._device == 'cpu':
                fill_async(data, self._tensors[idx])
            elif self._device == 'storage':
                assert isinstance(self._tensors[idx], GDSTensor), "Only GDSTensor is supported"
                self._tensors[idx].overwrite(data, async_=True, cuda_stream=d2h_stream)
            else:
                raise NotImplementedError("Only CPU and Storage are supported")

    # call for scattering gradients
    # it also accumulates the gradients
    # this must call d2h_synchronize

    @torch.no_grad() # device -> host
    def sync_scatter(self, pid: int, data: Tensor, boundaries: List):
        assert str(data.device)[:4] == 'cuda', "Data should be in cuda device"
        
        intra_numel = self._tensors[pid].shape[0]
        self._tensors[pid].copy_(data[0:intra_numel])

        offset = intra_numel
        for i, (boundary, tensor) in enumerate(zip(boundaries, self._tensors)):
            if i == pid: continue
            cur_numel = boundary.shape[0]
            tensor[boundary].copy_(data[offset:offset+cur_numel])
            offset += cur_numel

        assert offset == data.shape[0], "Offset should be the same"

    @torch.no_grad() # device -> host
    def async_scatter(self, pid: int, data: Tensor, boundaries: List,
                      d2h_stream: torch.cuda.Stream, wait_stream: torch.cuda.Stream | None):
        # we handle the gradients
        assert str(data.device)[:4] == 'cuda', "Data should be in cuda device"
        self._async_scatter(pid, pid, data, boundaries, d2h_stream, wait_stream) # use push stream 0

    @torch.no_grad() # device -> host
    def _async_scatter(self, steam_id: int, pid: int, data: Tensor, boundaries: List,
                       d2h_stream: torch.cuda.Stream, wait_stream: torch.cuda.Stream | None):
        assert steam_id == pid, "Stream id should be the same as pid"
        # with torch.cuda.stream(self._d2h_stream(steam_id)):
        if wait_stream is not None: # in case we need to wait
            d2h_stream.wait_stream(wait_stream)
        with torch.cuda.stream(d2h_stream):
            boundaries[pid] = torch.zeros(0, dtype=torch.int32)
            scatter_async(pid, data, self.get_tensors(), boundaries)

    @torch.no_grad()
    def d2h_synchronize(self, d2h_stream: torch.cuda.Stream):
        if self._device == 'cpu':
            d2h_synchronize()
            torch.cuda.synchronize(d2h_stream)
        elif self._device == 'storage':
            for idx in range(self._num_parts):
                assert isinstance(self._tensors[idx], GDSTensor), "Only GDSTensor is supported"
                self._tensors[idx].synchronize()
            torch.cuda.synchronize(d2h_stream)
        else:
            raise NotImplementedError("Only CPU and Storage are supported")

    @torch.no_grad() # host -> device
    def sync_upload(self, idx: int, target: Tensor):
        assert target.shape == self._tensors[idx].shape, "Data shape should be the same"
        assert str(target.device)[:4] == 'cuda', "Target should be in cuda device"
        target.copy_(self._tensors[idx])

    @torch.no_grad() # host -> device
    def async_upload(self, idx: int, target: Tensor,
                     h2d_stream: torch.cuda.Stream, wait_stream: torch.cuda.Stream | None = None):
        assert target.shape == self._tensors[idx].shape, "Data shape should be the same"
        assert str(target.device)[:4] == 'cuda', "Target should be in cuda device"
        self._async_upload(idx, idx, target, h2d_stream, wait_stream) # use pull stream 0

    @torch.no_grad() # host -> device
    def _async_upload(self, stream_id: int, idx: int, target: Tensor,
                      h2d_stream: torch.cuda.Stream, wait_stream: torch.cuda.Stream | None = None):
        assert stream_id == idx, "Stream id should be the same as idx"
        # with torch.cuda.stream(self._h2d_stream(stream_id)):
        if wait_stream is not None: # in case we need to wait
            h2d_stream.wait_stream(wait_stream)
        with torch.cuda.stream(h2d_stream):
            if self._device == 'cpu':
                if isinstance(self._tensors[idx], Tensor):
                    upload_async(self._tensors[idx], target)
                elif isinstance(self._tensors[idx], GDSTensor):
                    upload_async(self._tensors[idx].tensor(), target)
                else:
                    raise NotImplementedError("Only Tensor and GDSTensor are supported")
            elif self._device == 'storage':
                assert isinstance(self._tensors[idx], GDSTensor), "Only GDSTensor is supported"
                self._tensors[idx].upload(target, async_=True, cuda_stream=h2d_stream)
            else:
                raise NotImplementedError("Only CPU and Storage are supported")
    
    @torch.no_grad
    def h2d_synchronize(self, h2d_stream: torch.cuda.Stream):
        if self._device == 'cpu':
            h2d_synchronize()
            torch.cuda.synchronize(h2d_stream)
        elif self._device == 'storage':
            for idx in range(self._num_parts):
                assert isinstance(self._tensors[idx], GDSTensor), "Only GDSTensor is supported"
                self._tensors[idx].synchronize()
            torch.cuda.synchronize(h2d_stream)
        else:
            raise NotImplementedError("Only CPU and Storage are supported")

class AcceleratorTensors(object):
    r""" Push and Pull Buffer
    This buffer gathers the activations of the previous layer across the partitions.
    It also holds gradients on the device (gradient buffers themselves do not
    require grad).

    One CUDA buffer of shape ``(part_sizes[idx], embedding_dim)`` is kept per
    partition. Buffers are created with a zero-sized storage and are only
    materialized when data is pulled into them; ``free_from_device``,
    ``flush_cuda_buffers`` and ``reset_tensors`` release the storage again.

    Args:
        num_parts: number of partitions (one buffer each).
        embedding_dim: number of columns of every buffer.
        part_sizes: number of rows of each partition's buffer.
        device: must be ``'cuda'``.
        for_grad: ``True`` for gradient buffers (created with
            ``requires_grad=False``); ``False`` for activation buffers
            (created with ``requires_grad=True``).
    """
    def __init__(self, num_parts: int,
                 embedding_dim: int,
                 part_sizes: List[int],
                 device='cuda',
                 for_grad: bool = False):
        self._num_parts = num_parts
        self._embedding_dim = embedding_dim
        self._part_sizes = part_sizes
        self._device = device

        assert self._device == 'cuda', "Only CUDA is supported while init"

        self._for_grad = for_grad
        self._requires_grad = not for_grad

        # cuda buffers, one per partition (storage starts out empty)
        self._cuda_buffers = [self._generate_tensor(idx) for idx in range(num_parts)]

        # [GDSTensor, stream] pairs queued by _async_direct_pull(is_gds=True)
        # and drained by h2d_synchronize(is_gds=True)
        self._need_syncs = []

    def __getitem__(self, idx: int) -> Tensor:
        return self._cuda_buffers[idx]

    def reset_tensors(self):
        """Release every buffer's storage and drop its ``.grad``."""
        for buf in self._cuda_buffers:
            buf.untyped_storage().resize_(0)
            buf.grad = None

    def _generate_tensor(self, idx: int) -> Tensor:
        """Create the buffer for partition ``idx`` with its storage already released."""
        gen_t = torch.empty(self._part_sizes[idx],
                           self._embedding_dim,
                           device=self._device,
                           requires_grad=self._requires_grad)
        gen_t.untyped_storage().resize_(0)
        return gen_t

    def _materialize(self, idx: int) -> None:
        """Resize buffer ``idx``'s storage so it can hold all of its elements."""
        buf = self._cuda_buffers[idx]
        buf.untyped_storage().resize_(buf.numel() * buf.element_size())

    @torch.no_grad()
    def free_from_device(self, idx: int):
        """Release the device storage of buffer ``idx`` (shape metadata is kept)."""
        self._cuda_buffers[idx].untyped_storage().resize_(0)

    @torch.no_grad() # host -> device
    def sync_gather(self, pid: int, storage_tensors: HostStorageTensors, boundaries: List):
        r""" Pull the embeddings from the previous layer

        Fills buffer ``pid`` with partition ``pid``'s own rows followed by, for
        every other partition ``i`` (in order), the rows ``boundaries[i]`` of that
        partition. ``boundaries[pid]`` is skipped. The buffer must be empty on entry.
        """
        buf = self._cuda_buffers[pid]
        # check whether the device tensor was previously empty
        assert buf.untyped_storage().size() == 0, "Buffer should be empty"
        self._materialize(pid)
        # unwrap GDSTensors the same way the async path does
        host_tensors = storage_tensors.get_tensors()
        offset = host_tensors[pid].shape[0]
        buf[0:offset].copy_(host_tensors[pid]) # intra copy
        for i, (host_tensor, boundary) in enumerate(zip(host_tensors, boundaries)):
            if i == pid:
                continue # boundary for pid is None
            cur_numel = boundary.shape[0]
            # boundary copy (reading via an index tensor yields a gathered copy)
            buf[offset:offset + cur_numel].copy_(host_tensor[boundary])
            offset += cur_numel
        assert offset == buf.shape[0], "Offset should be the same"

    @torch.no_grad() # host -> device
    def async_gather(self, pid: int, storage_tensors: HostStorageTensors, boundaries: List,
                     h2d_stream: torch.cuda.Stream, wait_stream: torch.cuda.Stream | None):
        r""" Pull the embeddings from the previous layer (asynchronously on ``h2d_stream``)
        """
        self._async_gather(pid, pid, storage_tensors, boundaries, h2d_stream, wait_stream)

    @torch.no_grad() # host -> device
    def _async_gather(self, stream_id: int, pid: int, storage_tensors: HostStorageTensors, boundaries: List,
                      h2d_stream: torch.cuda.Stream, wait_stream: torch.cuda.Stream | None):
        r""" Pull the embeddings from the previous layer via the extension's ``gather_async``

        Raises:
            NotImplementedError: (from ``get_tensors``) if the storage buffers are
                neither ``Tensor`` nor ``GDSTensor``; previously such buffers were
                silently skipped, leaving the freshly allocated buffer uninitialized.
        """
        assert stream_id == pid, "Stream id should be the same as pid"
        if wait_stream is not None: # in case we need to wait
            h2d_stream.wait_stream(wait_stream)
        with torch.cuda.stream(h2d_stream):
            # NOTE: mutates the caller's list — the pid slot becomes an empty int32
            # index tensor before the list is handed to the extension
            boundaries[pid] = torch.zeros(0, dtype=torch.int32)
            self._materialize(pid)
            gather_async(pid, storage_tensors.get_tensors(), self._cuda_buffers[pid], boundaries)

    @torch.no_grad()
    def sync_direct_pull(self, pid: int, data: Tensor, is_saved_tensor_hook: bool = False):
        """Synchronously copy the CPU tensor ``data`` into buffer ``pid``.

        Args:
            pid: partition / buffer index.
            data: CPU tensor with the buffer's shape.
            is_saved_tensor_hook: when ``True`` the buffer is always (re)allocated
                and the ``for_grad`` check is skipped; otherwise this is only
                allowed on gradient buffers.
        """
        if not is_saved_tensor_hook:
            assert self._for_grad, "This is only for gradients"
        else:
            self._materialize(pid)
        buf = self._cuda_buffers[pid]
        assert data.shape == buf.shape, "Data shape should be the same"
        assert data.device.type == 'cpu', "Data should be in cpu device"
        # Copying into a zero-sized storage would write through a null pointer;
        # allocate lazily if no earlier pull materialized this buffer.
        if buf.untyped_storage().size() == 0:
            self._materialize(pid)
        buf.copy_(data)

    @torch.no_grad() # host -> device
    def async_direct_pull(self, pid: int, data: Tensor,
                          h2d_stream: torch.cuda.Stream, wait_stream: torch.cuda.Stream | None = None,
                          is_gds: bool = False):
        """Asynchronously copy ``data`` (a CPU tensor, or a ``GDSTensor`` if ``is_gds``) into buffer ``pid``."""
        assert self._for_grad, "This is only for gradients"
        assert data.shape == self._cuda_buffers[pid].shape, "Data shape should be the same"
        self._async_direct_pull(pid, pid, data, h2d_stream, wait_stream, is_gds)

    @torch.no_grad()
    def _async_direct_pull(self, stream_id: int, pid: int, data: Tensor,
                           h2d_stream: torch.cuda.Stream, wait_stream: torch.cuda.Stream | None = None,
                           is_gds: bool = False):
        assert stream_id == pid, "Stream id should be the same as pid"
        if wait_stream is not None: # in case we need to wait
            h2d_stream.wait_stream(wait_stream)
        with torch.cuda.stream(h2d_stream):
            self._materialize(pid)
            if is_gds:
                assert isinstance(data, GDSTensor), "Only GDSTensor is supported"
                data.upload(self._cuda_buffers[pid], async_=True, cuda_stream=h2d_stream)
                # remembered so h2d_synchronize(is_gds=True) can wait for it
                self._need_syncs.append([data, h2d_stream])
            else:
                upload_async(data, self._cuda_buffers[pid])

    @torch.no_grad()
    def h2d_synchronize(self, h2d_stream: torch.cuda.Stream, is_gds: bool = False):
        """Wait for pending host -> device transfers.

        With ``is_gds`` it waits for every upload queued by ``_async_direct_pull``
        and clears the queue; otherwise it calls the extension's ``h2d_synchronize``.
        NOTE(review): ``torch.cuda.synchronize`` takes a device, not a stream; a
        Stream argument appears to fall back to syncing the whole current device.
        """
        if is_gds: # only for async_direct_pull
            for tensor, stream in self._need_syncs:
                assert isinstance(tensor, GDSTensor), "Only GDSTensor is supported"
                tensor.synchronize()
                torch.cuda.synchronize(stream)
            self._need_syncs = []
        else:
            h2d_synchronize() # module-level extension function, not this method
            torch.cuda.synchronize(h2d_stream)

    @torch.no_grad()
    def flush_cuda_buffers(self):
        """Release the device storage of every buffer."""
        for idx in range(self._num_parts):
            self.free_from_device(idx)