""" Backward-pass (BW) gradient offloading and uploading handler functions
Reference: https://pytorch.org/docs/stable/notes/extending.html#extending-autograd
"""

from typing import Any, Tuple, List

import torch

from utils.pool import AsyncIOPool

def grad_offload(input: torch.Tensor, pool: AsyncIOPool) -> torch.Tensor:
    """Route ``input`` through :class:`GradientOffload`.

    ``pool`` is passed as the second argument to ``GradientOffload.apply``
    and the result of that call is returned.
    """
    offload_fn = GradientOffload.apply
    return offload_fn(input, pool)

def grad_upload(input: torch.Tensor, pool: AsyncIOPool) -> torch.Tensor:
    """Route ``input`` through :class:`GradientUpload`.

    ``input`` and ``pool`` are forwarded, in that order, to
    ``GradientUpload.apply`` and its result is returned.
    """
    call_args = (input, pool)
    return GradientUpload.apply(*call_args)

class LayerwiseConnection(torch.autograd.Function):
    """ Identity autograd function marking the hand-off layer (l) -> layer (l+1).

    Forward returns its single argument unchanged and backward passes the
    incoming gradient straight through, so calling this enables the
    message passing of layers without altering values or gradients.

    NOTE(review): the original annotation says ``input`` is a ``List`` of
    layer (l)'s embeddings. autograd only tracks Tensor arguments/outputs of a
    Function, so a plain list presumably passes through untracked and
    ``backward`` is never reached for it — confirm callers pass a Tensor.
    """
    @staticmethod
    def setup_context(ctx: Any, inputs: Tuple[Any], output: Any) -> Any:
        r""" Nothing is saved: backward is a pure pass-through. """
        _ = inputs
    @staticmethod
    def forward(input: Any) -> Any:
        r""" Return ``input`` (e.g. layer (l)'s embeddings) unchanged. """
        return input
    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> Any:
        r""" Pass ``grad_output`` through unchanged.

        ``forward`` has exactly one input, so exactly one gradient is
        returned (rather than relying on autograd trimming extra ``None``s).
        """
        return grad_output

class GradientOffload(torch.autograd.Function):
    """ Identity in forward; hook point for offloading gradients in backward.

    Call this before the forward computation. Since the first FW op is the
    last one reached during BW, this is where the gradients are meant to be
    offloaded after the final BW step. The offloading itself is not yet
    implemented: backward currently returns ``grad_output`` unchanged.
    """
    @staticmethod
    def forward(input: torch.Tensor, pool: AsyncIOPool) -> torch.Tensor:
        # Forward is a pure pass-through of the tensor.
        return input
    @staticmethod
    def setup_context(ctx: Any, inputs: Tuple[Any], output: Any) -> Any:
        r""" Stash the pool (second element of ``inputs``) on ``ctx``
        so backward can reach it.
        """
        _, ctx.pool = inputs
    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> Any:
        r""" Return the gradient for ``input`` unchanged and ``None`` for
        ``pool``.
        """
        offload_pool = ctx.pool  # placeholder: offloading procedure goes here
        return grad_output, None

class GradientUpload(torch.autograd.Function):
    """ Identity in forward; hook point for uploading gradients in backward.

    Call this after the forward computation. Since the last FW op is the
    first one reached during BW, this is where the gradients are meant to be
    uploaded before the rest of BW runs. The uploading itself is not yet
    implemented: backward currently returns ``grad_output`` unchanged.
    """
    @staticmethod
    def setup_context(ctx: Any, inputs: Tuple[Any], output: Any) -> Any:
        r""" At FW, save :class:AsyncIOPool on ``ctx`` for
        future use to upload the gradients.
        """
        _, pool = inputs # (input, pool)
        ctx.pool = pool
    @staticmethod
    def forward(input: torch.Tensor, pool: "AsyncIOPool") -> torch.Tensor:
        r""" Return ``input`` unchanged.

        ``pool`` must be accepted here because ``apply(input, pool)`` passes
        every argument to ``forward``; forward itself ignores it.
        """
        return input
    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> Any:
        r""" Return one gradient per forward input: ``grad_output`` for
        ``input`` and ``None`` for the non-differentiable ``pool``.
        """
        # TODO: uploading procedure using ``ctx.pool`` goes here.
        return grad_output, None