from .Rectangle import rectangle_fn

import torch

from norse.torch.functional.heaviside import heaviside

class AsymRectangle(torch.autograd.Function):
    r"""Asymmetric rectangular surrogate gradient generalization of the rectangular surrogate gradient as in
        Wu, Yujie & Deng, Lei & Li, Guoqi & Zhu, Jun & Xie, Yuan & Shi, L.P.. (2019).
        Direct Training for Spiking Neural Networks: Faster, Larger, Better.    
        Proceedings of the AAAI Conference on Artificial Intelligence. 33. 1311-1318.
        10.1609/aaai.v33i01.33011311.
    """

    @staticmethod
    @torch.jit.ignore
    def forward(ctx, x: torch.Tensor, alpha: float, beta: float) -> torch.Tensor:
        ctx.save_for_backward(x)
        ctx.alpha = alpha
        ctx.beta = beta
        return heaviside(x)

    @staticmethod
    @torch.jit.ignore
    def backward(ctx, grad_output):
        """
        Compute the gradient of the Asymmetric Rectangle function.

        Saved tensor x has shape (B, C, H, W).
        alpha and beta are Tensors of shape (C,).

        Args:
        - ctx: torch.autograd.Function - The context object.
        - grad_output: torch.Tensor - The gradient of the output.

        Returns:
        - grad: torch.Tensor - The gradient of the input.
        """
        (x,) = ctx.saved_tensors
        alpha: torch.Tensor = ctx.alpha
        beta: torch.Tensor = ctx.beta
        grad_input: torch.Tensor = grad_output.clone()
        mask = (x > alpha.view(1, -1, 1, 1)) & (x < beta.view(1, -1, 1, 1))
        grad = grad_input * mask# * (1 / torch.abs(beta - alpha).view(1, -1, 1, 1))

        return grad, None, None
    
@torch.jit.ignore
def asym_rectangle_fn(x: torch.Tensor, alpha: torch.Tensor | tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    """Apply the rectangular surrogate spike function, symmetric or asymmetric.

    Args:
        x: Input tensor. Its device is also the target device for tuple bounds.
        alpha: Either a single value, which is forwarded unchanged to
            ``rectangle_fn``, or a tuple whose first two entries are the lower
            and upper window bounds for ``AsymRectangle``.

    Returns:
        The output of ``rectangle_fn(x, alpha)`` for a non-tuple ``alpha``,
        otherwise ``AsymRectangle.apply(x, lower, upper)``.
    """
    # Anything that is not a tuple selects the symmetric rectangle surrogate.
    if not isinstance(alpha, tuple):
        return rectangle_fn(x, alpha)

    lower, upper = alpha[0], alpha[1]
    target = x.device
    # Relocate both bounds together if either one lives on another device.
    if any(bound.device != target for bound in (lower, upper)):
        lower, upper = lower.to(target), upper.to(target)
    return AsymRectangle.apply(x, lower, upper)