""" BW gradient offloading and uploading handler functions
Reference: https://pytorch.org/docs/stable/notes/extending.html#extending-autograd
"""

from tabnanny import check
from typing import Any, Tuple, List, Optional, Literal

import warnings

import torch
from torch import Tensor
from torch.autograd.graph import saved_tensors_hooks
# from tensornvme import DiskOffloader
from loguru import logger

from utils.buffer import Buffer
from utils.new_buffer import HostStorageTensors, AcceleratorTensors


# JY: I'm here!!!!
def grad_offload(input: torch.Tensor, lid: int, pid: int,
                 host_storage_gradients: HostStorageTensors, boundaries: List[torch.Tensor],
                 d2h_stream: torch.cuda.Stream, wait_stream: torch.cuda.Stream) -> torch.Tensor:
    """Functional wrapper around :class:`GradientOffload`.

    Every argument is forwarded, in the same order, to
    ``GradientOffload.apply`` and the result of that call is returned.
    """
    # autograd ``Function.apply`` is called positionally, so keep the order
    # identical to ``GradientOffload.forward``'s signature.
    offload_args = (input, lid, pid, host_storage_gradients, boundaries,
                    d2h_stream, wait_stream)
    return GradientOffload.apply(*offload_args)

class GradientOffload(torch.autograd.Function):
    """Identity in forward; scatters the incoming gradient to host storage in backward.

    Call this before a layer's forward computation. Because the first FW
    operation is the last BW operation, this node's ``backward`` then runs
    after that layer's backward and can offload the final gradient.
    """

    @staticmethod
    def setup_context(ctx: Any, inputs: Tuple[Any], output: Any) -> Any:
        r"""Stash the non-tensor forward arguments on ``ctx`` for ``backward``.

        ``inputs`` mirrors the positional arguments of ``forward``:
        ``(input, lid, pid, host_storage_gradients, boundaries, d2h_stream, wait_stream)``.
        The leading tensor is not kept.
        """
        (_,
         ctx.lid,
         ctx.pid,
         ctx.host_storage_gradients,
         ctx.boundaries,
         ctx.d2h_stream,
         ctx.wait_stream) = inputs

    @staticmethod
    def forward(input: torch.Tensor, lid: int, pid: int, host_storage_gradients: HostStorageTensors, boundaries: List[torch.Tensor],
                d2h_stream: torch.cuda.Stream, wait_stream: torch.cuda.Stream) -> torch.Tensor:
        """Return ``input`` unchanged; the other arguments are only consumed by ``setup_context``."""
        # Conceptually, gradient offload mirrors activation upload.
        # TODO (priority 1): activation caching is still needed here.
        #   Method 1) if a cache already holds the activation, use it.
        #   Method 2) if the previous partition holds the activation of the
        #             current one, use it.
        # No host-to-device copy is performed here; per the original author's
        # note the tensor is expected to already live on the GPU.
        return input

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> Any:
        """Hand ``grad_output`` to the host gradient storage, except when ``lid`` is 0.

        Returns ``None`` for each of the seven ``forward`` inputs, so no
        gradient is propagated further upstream through this node.
        """
        if ctx.lid != 0:
            # Asynchronous scatter issued with the D2H stream and the stream to
            # wait on; the blocking variant / explicit d2h synchronize is
            # intentionally not used (TODO: overlapping).
            ctx.host_storage_gradients.async_scatter(
                ctx.pid, grad_output, ctx.boundaries, ctx.d2h_stream, ctx.wait_stream)

        # TODO: further offloading procedure, if any.
        return (None,) * 7

def grad_upload(pid: int, input: torch.Tensor, next_layer_tensor: HostStorageTensors) -> torch.Tensor:
    """Functional wrapper around :class:`GradientUpload`.

    This wrapper takes ``pid`` first, whereas ``GradientUpload.apply`` expects
    the tensor first; the arguments are reordered before the call.
    """
    upload_args = (input, pid, next_layer_tensor)
    return GradientUpload.apply(*upload_args)

class GradientUpload(torch.autograd.Function):
    """Identity in both directions, with a CUDA-device check on each pass.

    Call this after a layer's forward computation. Because the last FW
    operation is the first BW operation, this node's ``backward`` runs before
    that layer's backward.
    """

    @staticmethod
    def setup_context(ctx: Any, inputs: Tuple[Any], output: Any) -> Any:
        r"""Record ``pid`` and ``next_layer_tensor`` on ``ctx``.

        ``inputs`` mirrors ``forward``'s positional arguments:
        ``(input, pid, next_layer_tensor)``. The leading tensor is not kept.
        """
        _, ctx.pid, ctx.next_layer_tensor = inputs

    @staticmethod
    def forward(input: torch.Tensor, pid: int, next_layer_tensor: HostStorageTensors) -> torch.Tensor:
        """Check that ``input`` is a CUDA tensor and return it unchanged."""
        assert input.is_cuda, "Data should be in cuda device"

        # The device-to-host fill of ``next_layer_tensor`` (sync_fill /
        # async_fill with a D2H stream and wait event) is currently disabled.
        # (WARNING) Per the original author, an async fill at this point does
        # not work.
        #
        # WARNING: if async_fill is ever re-enabled here, it must be followed
        # by a synchronize before shrinking the input storage to zero,
        # i.e. input.untyped_storage().resize_(0).
        return input

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> Any:
        """Check that ``grad_output`` is a CUDA tensor and pass it through.

        Returns ``grad_output`` as the gradient of ``input`` and ``None`` for
        the two non-tensor ``forward`` arguments.
        """
        # No upload happens here: if the gradient already exists on the GPU it
        # does not need to be uploaded again.
        # TODO (priority 3): gradient caching is still needed here.
        # For storage offloading, the original note says the gradient was
        # already brought to host memory by the base model's `backward`.
        assert grad_output.is_cuda, 'The gradient should be in cuda device'
        return grad_output, None, None

class save_on_cpu(saved_tensors_hooks):
    """Saved-tensor hooks that replace a saved tensor with a small tuple.

    What is stored and how it is restored depends on ``checkpointing_strategy``:

    * ``'scattered'``: the tensor being saved must be the very object at
      ``accelerator_tensors[lid][pid]``; ``(device, accelerator_tensors[lid],
      h2d_stream)`` is stored. On unpack, ``h2d_synchronize(h2d_stream)`` is
      called and entry ``pid`` of the stored container is returned.
    * ``'cpu'``: ``(device, tensor.cpu(), accelerator_tensors[lid])`` is
      stored. On unpack, the CPU copy is passed to ``sync_direct_pull`` and
      entry ``pid`` of the stored container is returned.
    * any other value: ``NotImplementedError`` is raised when a tensor is
      packed or unpacked (not at construction time).
    """

    def __init__(self, lid: int, pid: int,  # pid = pid-1 for overlapping
                 accelerator_tensors: AcceleratorTensors,
                 checkpointing_strategy: Literal['scattered', 'cpu', 'storage'] = 'cpu',
                 h2d_stream: torch.cuda.Stream | None = None):
        def pack_to_cpu(tensor):
            """Pack hook: build the tuple stored in place of ``tensor``."""
            if checkpointing_strategy == 'scattered':  # Ours
                assert tensor is accelerator_tensors[lid][pid], 'They are not the same object'
                return (tensor.device, accelerator_tensors[lid], h2d_stream)
            if checkpointing_strategy == 'cpu':  # HongTu
                return (tensor.device, tensor.cpu(), accelerator_tensors[lid])
            # Remaining strategies (e.g. 'storage') have no implementation yet.
            raise NotImplementedError(f'{checkpointing_strategy} is not implemented yet')

        def unpack_from_cpu(packed):
            """Unpack hook: turn the tuple from ``pack_to_cpu`` back into a tensor."""
            if checkpointing_strategy == 'scattered':  # Ours
                device, layer_tensors, stream = packed
                assert device == layer_tensors[pid].device, f'{device} != {layer_tensors[pid].device}'
                # Synchronize on the H2D stream before handing the entry back.
                layer_tensors.h2d_synchronize(stream)
                return layer_tensors[pid]
            if checkpointing_strategy == 'cpu':  # HongTu
                _, host_copy, layer_tensors = packed
                assert host_copy.device.type == 'cpu', 'The tensor should be in cpu device'
                layer_tensors.sync_direct_pull(pid, host_copy, is_saved_tensor_hook=True)
                return layer_tensors[pid]
            raise NotImplementedError(f'{checkpointing_strategy} is not implemented yet')

        super().__init__(pack_to_cpu, unpack_from_cpu)
