from typing import Optional, List

import numpy as np
import cupy
import torch
from torch import Tensor
from torch.utils.dlpack import to_dlpack, from_dlpack
from tensornvme import DiskOffloader

from utils.compiler import load_grinnder_ext
# from utils.gdtensor import GDTensor

# Load the GriNNder extension module returned by `load_grinnder_ext()`.
# This runs at import time.
grinnder_ext = load_grinnder_ext()

# Module-level aliases for the extension entry points.
synchronize = grinnder_ext.synchronize
read_async = grinnder_ext.read_async
write_async = grinnder_ext.write_async
contiguous_write_async = grinnder_ext.contiguous_write_async
# NOTE: the two aliases below drop the `_async` suffix of the extension names.
write_with_reduction = grinnder_ext.write_with_reduction_async
conti_write_with_reduction = grinnder_ext.conti_write_with_reduction_async

class Buffer(object):
    r"""Buffer for GriNNder. """
    def __init__(self, num_parts: int,
                 embedding_dim: Optional[int] = -1,
                 device=None,
                 for_grad: Optional[bool] = False,
                 part_sizes: Optional[List] = None,
                 layer_wise_cache: Optional[bool] = False,
                 storage_offload: Optional[bool] = False,
                 storage_path: Optional[str] = None):

        # for initial feature init
        self.ever_pulled = False

        self.pool_size = 4

        self._push_cache: list[Tensor | None] = [None] * self.pool_size
        self._push_streams: list[torch.cuda.Stream | None] = [None] * self.pool_size
        self._push_index: int = -1

        self.num_parts = num_parts
        self.for_grad = for_grad
        self.part_sizes: list[int] = part_sizes
        assert len(part_sizes) == num_parts, '#part sizes must match num_parts'

        self.layer_wise_cache = layer_wise_cache
        self.storage_offload = storage_offload
        self.storage_path = storage_path
        if self.storage_offload:
            self.offloader = DiskOffloader(f'{self.storage_path}/{id(self)}', backend='aio')
        else:
            self.offloader = None

        self.embs = list[Tensor]()

        self._hook_registered = [False for _ in range(num_parts)]

        self._generate_buffer(embedding_dim)

        self._device = device

    def _generate_buffer(self, embedding_dim: int):
        r""" Generate buffer for the embedding """
        for i in range(self.num_parts):
            emb = torch.zeros(self.part_sizes[i], embedding_dim,
                              device='cpu',
                              pin_memory=False)
            if self.storage_offload:
                # emb.storage().resize_(0)
                if self.layer_wise_cache:
                    emb.storage().resize_(0)
                else:
                    # offload buffer
                    self.offloader.sync_write(emb)
            self.embs.append(emb)

    def layer_wise_upload_to_host(self):
        r""" upload all the buffers to the host """
        assert self.storage_offload, 'only call for storage offloading'
        if self.storage_offload:
            for emb in self.embs:
                self.offloader.async_read_only(emb)
            self.offloader.synchronize()

    def layer_wise_offload_to_storage(self):
        r""" offload all the buffers to the storage """
        assert self.storage_offload, 'only call for storage offloading'
        if self.storage_offload:
            for emb in self.embs:
                self.offloader.async_write(emb)
            self.offloader.synchronize()

    @torch.no_grad()
    def reset_buffer(self):
        """ reset buffers after 1 iteration """
        if self.storage_offload and not self.layer_wise_cache:
        # if self.storage_offload:
            for i, emb in enumerate(self.embs):
                shape = emb.shape
                self.offloader.del_from_stroage(emb)
                new_emb = torch.zeros(shape, device='cpu', pin_memory=False)
                self.offloader.async_write(new_emb)
                self.embs[i] = new_emb
            self.offloader.synchronize()
        elif self.storage_offload and self.layer_wise_cache:
            for emb in self.embs:
                # emb.zero_()
                emb.storage().resize_(0)
        else:
            for emb in self.embs:
                emb.zero_()

    def enable_grad_and_register_hook(self, hook, pid: int) -> None:
        """ enable grad and register hook """
        if not self._hook_registered[pid]:
            assert not self.for_grad, 'only call for fw'
            self.embs[pid].requires_grad_()
            self.embs[pid].register_hook(hook)
            self._hook_registered[pid] = True

    def _push_stream(self, idx: int) -> torch.cuda.Stream:
        if self._push_streams[idx] is None:
            assert str(self._device)[:4] == 'cuda'
            self._push_streams[idx] = torch.cuda.Stream(self._device)
        return self._push_streams[idx]

    @torch.no_grad()
    def _async_push_add(self, src: Tensor, dst: Tensor, boundaries: Optional[Tensor] = None) -> None:
        r""" Issue a push of `src` into `dst` on the next push stream of the pool.

        The pool slot is chosen round-robin over `self.pool_size`. Whatever was
        issued on that slot before is synchronized first, and `src` is kept in
        `self._push_cache` until the slot is synchronized again.

        Args:
            src: values to push.
            dst: destination buffer.
            boundaries: optional index into `dst`; when omitted the contiguous
                code path is used.
        """
        # Advance to the next pool slot and wait for the push issued on it earlier.
        self._push_index = (self._push_index + 1) % self.pool_size
        self._synchronize_push(self._push_index)
        # Hold a reference to `src` until this slot is synchronized again.
        self._push_cache[self._push_index] = src
        with torch.cuda.stream(self._push_stream(self._push_index)):
            if src.device == dst.device:
                # Same-device fallback: accumulate with plain tensor ops.
                print('[GriNNder: Warning] Unexpected device copy')
                if boundaries is None:
                    dst.data[:src.shape[0]] += src.data[:]
                else:
                    dst[boundaries] += src
            elif self.storage_offload:
                # Expected situation here: src is on GPU, dst is offloaded to storage.
                # `dst` is read back through the offloader before being updated.
                self.offloader.sync_read(dst)
                if boundaries is None:
                    # NOTE(review): this branch calls `contiguous_write_async`, while the
                    # non-offload branch below calls `conti_write_with_reduction` for the
                    # same case -- confirm that a write without reduction is intended.
                    contiguous_write_async(src, dst)
                else:
                    write_with_reduction(src, dst, boundaries)
                # Hand `dst` back to the offloader; completion is not awaited here.
                self.offloader.async_write(dst)
            else:
                # Host-memory destination: use the extension's reduction writes.
                if boundaries is None:
                    conti_write_with_reduction(src, dst)
                else:
                    write_with_reduction(src, dst, boundaries)

    @torch.no_grad()
    def _synchronize_push(self, idx: Optional[int] = None) -> None:
        # Synchronize the push command of stream `idx` or all commands:
        if idx is None:
            for idx in range(self.pool_size):
                self._synchronize_push(idx)
            self._push_index = -1
        else:
            torch.cuda.synchronize(self._push_stream(idx))
            self._push_cache[idx] = None

    def to(self, device:str, pid: int) -> Tensor:
        """ move buffer to the device """
        assert not self.for_grad, 'only call for fw'
        if self.embs[pid] is not None:
            self.embs[pid] = self.embs[pid].to(device, non_blocking=True)
        return self.embs[pid]

    def register_buffer(self, buffer: Tensor, part_id: int):
        """ register buffer to the buffer """
        assert not self.for_grad, 'only call for fw'
        assert self.embs[part_id] is not None, 'buffer must be initialized'
        if self.storage_offload: # upload
            self.offloader.sync_read(self.embs[part_id])
            # self.embs[part_id].storage().resize_(self.embs[part_id].numel())
        assert buffer.shape == self.embs[part_id].shape, 'input shape must match the buffer shape'
        self.embs[part_id].copy_(buffer)
        if self.storage_offload: # reoffload after change values
            self.offloader.sync_write(self.embs[part_id])
            self.offloader.synchronize()

    @torch.no_grad()
    def pull_from_prev_buffer_cpp(self, part_id: int, prev_embs: List, boundaries: List) -> None:
        # now only support for the host memory case ....
        pass

    @torch.no_grad()
    def pull_from_prev_buffer(self, part_id: int, prev_embs: List, boundaries: List,
                              prev_offloader: Optional[DiskOffloader] = None,
                              init_feat: Optional[bool] = False) -> bool:
        r""" Pull from the previous buffer in FW.

        Fills `self.embs[part_id]` from `prev_embs`: the first
        `prev_embs[part_id].shape[0]` rows are a copy of the partition's own
        previous buffer, followed, for every other partition `i` in index
        order, by the rows `prev_embs[i][boundaries[i]]`.

        Args:
            part_id: partition whose buffer is filled.
            prev_embs: per-partition buffers to pull from.
            boundaries: per-partition row indices into `prev_embs[i]`; the
                entry at `part_id` is skipped.
            prev_offloader: offloader owning `prev_embs`. It is dereferenced
                whenever `storage_offload` is set and `layer_wise_cache` is
                not, so it must be provided in that mode.
            init_feat: not used by the active code (only by the disabled
                snippet below).

        Returns:
            Always `True`.
        """
        assert not self.for_grad, 'only call for fw'
        assert self.embs[part_id] is not None, 'buffer must be initialized'

        # (disabled) if initial feature is already pulled, we do not need to pull again
        # if self.ever_pulled and init_feat:
        #     if self.storage_offload:
        #         self.offloader.sync_read(self.embs[part_id])
        #     return True

        # if self.layer_wise_cache: then we do not need to read from storage

        torch.cuda.nvtx.range_push(f'Pull {part_id}...') # set layer profile

        # At first, copy the inner part from the previous buffer
        offset = 0
        # Resize the destination storage to hold all of its elements; other methods
        # of this class shrink it to 0 in layer-wise cache mode.
        self.embs[part_id].storage().resize_(self.embs[part_id].numel())
        if self.storage_offload and not self.layer_wise_cache: # upload the target buffer
            prev_offloader.sync_read_only(prev_embs[part_id])

        self.embs[part_id][offset:prev_embs[part_id].shape[0]].copy_(prev_embs[part_id])
        offset += prev_embs[part_id].shape[0]
        if self.storage_offload and not self.layer_wise_cache:
            prev_offloader.read_only_callback(prev_embs[part_id]) # we just reoffload them async

        if self.storage_offload and not self.layer_wise_cache: # prologue of overlapping (2-stage pipelining)
            prev_offloader.async_read_only(prev_embs[0])

        for i, (prev_out, boundary) in enumerate(zip(prev_embs, boundaries)):
            if self.storage_offload and not self.layer_wise_cache:
                # wait for the read issued for partition i (prologue or previous iteration)
                prev_offloader.sync_read_events()
            if i + 1 < len(prev_embs) and self.storage_offload and not self.layer_wise_cache:
                # start reading the next partition while the current one is processed
                prev_offloader.async_read_only(prev_embs[i+1])
            if i == part_id:
                # own partition was already copied above; only release the read
                if self.storage_offload and not self.layer_wise_cache:
                    prev_offloader.async_read_only_callback(prev_out)
                continue
            cur_numel = boundary.shape[0]
            # gather the rows selected by `boundary` and append them after the rows copied so far
            self.embs[part_id][offset:offset+cur_numel].copy_(prev_out[boundary])
            offset += cur_numel
            if self.storage_offload and not self.layer_wise_cache:
                prev_offloader.async_read_only_callback(prev_out)
        torch.cuda.nvtx.range_pop() # now self.embs[part_id] is ready to use on CPU
        # self.ever_pulled = True
        return True


    # TODO: change adj_t into gdtensor
    # @torch.no_grad()
    # def pull_from_init_gdtensor(self, part_id: int, prev_embs: List[GDTensor], boundaries: List[Tensor]) -> None:
    #     assert not self.for_grad, 'only call for fw'
    #     assert self.embs[part_id] is not None, 'buffer must be initialized'

    #     torch.cuda.nvtx.range_push(f'Pull {part_id}...') # set layer profile

    #     # At first, copy the inner part from the previous buffer
    #     offset = 0
    #     if self.storage_offload: # upload the target buffer
    #         self.offloader.sync_read(self.embs[part_id])
    #         prev_embs[part_id].to_inplace('cpu')
    #         prev_offloader.sync_read_only(prev_embs[part_id])

    #     self.embs[part_id][offset:prev_embs[part_id].shape[0]].copy_(prev_embs[part_id])
    #     offset += prev_embs[part_id].shape[0]
    #     if self.storage_offload:
    #         prev_offloader.read_only_callback(prev_embs[part_id]) # we just reoffload them async

    #     if self.storage_offload: # prologue of overlapping (2-stage pipelining)
    #         prev_offloader.async_read_only(prev_embs[0])

    #     for i, (prev_out, boundary) in enumerate(zip(prev_embs, boundaries)):
    #         if self.storage_offload:
    #             prev_offloader.sync_read_events()
    #         if i + 1 < len(prev_embs) and self.storage_offload:
    #             prev_offloader.async_read_only(prev_embs[i+1])
    #         if i == part_id:
    #             if self.storage_offload:
    #                 prev_offloader.async_read_only_callback(prev_out)
    #             continue
    #         cur_numel = boundary.shape[0]
    #         self.embs[part_id][offset:offset+cur_numel].copy_(prev_out[boundary])
    #         offset += cur_numel
    #         if self.storage_offload:
    #             prev_offloader.async_read_only_callback(prev_out)
    #     torch.cuda.nvtx.range_pop() # now self.embs[part_id] is ready to use on CPU

    #     return True


    @torch.no_grad() # it must be the torch.no_grad cuz it's in BW procedure.
    def update_grads_to_other_parts(self, grad, part_id, boundaries):
        r""" update the gradients to other parts
        """
        torch.cuda.nvtx.range_push(f'Push {part_id}...')
        # same as the below commented one.
        if not self.layer_wise_cache:
            self._async_push_add(grad[:self.part_sizes[part_id]], self.embs[part_id])
        else:
            self.embs[part_id][:self.part_sizes[part_id]] += grad[:self.part_sizes[part_id]].to('cpu')
        # emb: storage / grad: cuda / boundaries: cpu
        offset = self.part_sizes[part_id] # start offset
        for i in range(self.num_parts):
            if i == part_id:
                continue # already updated the self
            num_cur_updates = boundaries[i].shape[0]
            # same as the below commented one.
            if not self.layer_wise_cache:
                self._async_push_add(grad[offset:offset+num_cur_updates], self.embs[i], boundaries[i])
            else:
                self.embs[i][boundaries[i]] += grad[offset:offset+num_cur_updates].to('cpu')
            offset += num_cur_updates
        if self.storage_offload:
            self.offloader.synchronize() # must synchronize
        torch.cuda.nvtx.range_pop()