from .Rectangle import rectangle_fn

import torch

from norse.torch.functional.heaviside import heaviside

class AsymRectangle(torch.autograd.Function):
    r"""Asymmetric rectangular surrogate gradient generalization of the rectangular surrogate gradient as in
        Wu, Yujie & Deng, Lei & Li, Guoqi & Zhu, Jun & Xie, Yuan & Shi, L.P.. (2019).
        Direct Training for Spiking Neural Networks: Faster, Larger, Better.    
        Proceedings of the AAAI Conference on Artificial Intelligence. 33. 1311-1318.
        10.1609/aaai.v33i01.33011311.
    """

    @staticmethod
    @torch.jit.ignore
    def forward(ctx, x: torch.Tensor, alpha: float, beta: float) -> torch.Tensor:
        ctx.save_for_backward(x)
        ctx.alpha = alpha
        ctx.beta = beta
        return heaviside(x)

    @staticmethod
    @torch.jit.ignore
    def backward(ctx, grad_output):
        """
        Compute the gradient of the Asymmetric Rectangle function.

        Saved tensor x has shape (B, C, H, W).
        alpha and beta are Tensors of shape (C,).

        Args:
        - ctx: torch.autograd.Function - The context object.
        - grad_output: torch.Tensor - The gradient of the output.

        Returns:
        - grad: torch.Tensor - The gradient of the input.
        """
        (x,) = ctx.saved_tensors
        alpha: torch.Tensor = ctx.alpha
        beta: torch.Tensor = ctx.beta
        grad_input: torch.Tensor = grad_output.clone()

        mask = (alpha.view(1, -1, 1, 1) < x) & (x < beta.view(1, -1, 1, 1))
        grad = grad_input * mask

        return grad, None, None
    
@torch.jit.ignore
def asym_rectangle_fn(
    x: torch.Tensor,
    alpha: torch.Tensor | float = 0.3,
    beta: torch.Tensor | float | None = None,
) -> torch.Tensor:
    """Apply the (a)symmetric rectangular surrogate-gradient function to ``x``.

    Args:
        x: Input tensor.
        alpha: Lower bound of the gradient window when ``beta`` is given;
            otherwise the single parameter forwarded to ``rectangle_fn``.
            May be a tensor or a plain number.
        beta: Upper bound of the gradient window. When ``None``, the call is
            delegated to ``rectangle_fn(x, alpha)``.

    Returns:
        The result of ``rectangle_fn(x, alpha)`` when ``beta`` is ``None``,
        otherwise the result of ``AsymRectangle.apply(x, alpha, beta)``.
    """
    # Only tensors carry a device; the default alpha is a plain float and
    # must not be asked for ``.device``.
    if isinstance(alpha, torch.Tensor) and alpha.device != x.device:
        alpha = alpha.to(x.device)
    if beta is None:
        return rectangle_fn(x, alpha)

    # AsymRectangle's backward reshapes both bounds, so hand it tensors that
    # already live on x's device (as_tensor is a no-op for matching tensors).
    alpha = torch.as_tensor(alpha, device=x.device)
    beta = torch.as_tensor(beta, device=x.device)
    return AsymRectangle.apply(x, alpha, beta)