import torch

from norse.torch.functional.heaviside import heaviside

class Rectangle(torch.autograd.Function):
    r"""Heaviside spike function with a rectangular (boxcar) surrogate gradient.

    Forward pass: returns ``heaviside(x)``.

    Backward pass: the incoming gradient is multiplied by

    .. math::
        h(x) = \frac{1}{\alpha} \mathbb{1}\left[|x| < \alpha\right],

    i.e. ``grad_output / alpha`` strictly inside ``(-alpha, alpha)`` and ``0``
    elsewhere (including exactly at ``|x| == alpha``). ``alpha`` itself
    receives no gradient.

    Reference:
        Wu, Yujie & Deng, Lei & Li, Guoqi & Zhu, Jun & Xie, Yuan & Shi, L.P. (2019).
        Direct Training for Spiking Neural Networks: Faster, Larger, Better.
        Proceedings of the AAAI Conference on Artificial Intelligence, 33, 1311-1318.
        doi:10.1609/aaai.v33i01.33011311

    NOTE(review): the cited paper appears to define the window as
    ``|u - V_th| < a / 2`` with height ``1 / a`` (unit area), whereas this
    implementation uses half-width ``alpha`` with height ``1 / alpha``
    (area 2). Confirm this scaling is intentional.
    """

    @staticmethod
    @torch.jit.ignore
    def forward(ctx, x: torch.Tensor, alpha: float) -> torch.Tensor:
        """Save ``x`` and ``alpha`` for backward and return ``heaviside(x)``.

        Raises:
            ValueError: if ``alpha`` is not strictly positive.
        """
        # alpha is both the half-width and the inverse height of the window:
        # alpha == 0 would make every gradient 0 / 0 = NaN, and a negative
        # alpha silently yields an empty window. Reject both up front.
        # (`not alpha > 0` also rejects NaN.)
        if not alpha > 0:
            raise ValueError(f"alpha must be positive, got {alpha}")
        ctx.save_for_backward(x)
        ctx.alpha = alpha
        return heaviside(x)

    @staticmethod
    @torch.jit.ignore
    def backward(ctx, grad_output: torch.Tensor):
        """Return ``(grad_output * 1[|x| < alpha] / alpha, None)``."""
        (x,) = ctx.saved_tensors
        alpha = ctx.alpha
        # Cast the boolean window to grad_output's dtype so the result keeps
        # the incoming gradient's precision. grad_output is only read here,
        # so no defensive clone is needed.
        window = (x.abs() < alpha).to(grad_output.dtype)
        # alpha is a non-differentiable hyperparameter -> no gradient for it.
        return grad_output * window / alpha, None
    
@torch.jit.ignore
def rectangle_fn(x: torch.Tensor, alpha: float = 0.3) -> torch.Tensor:
    """Spike function whose backward pass uses the rectangular surrogate.

    Thin functional front-end for ``Rectangle.apply``.

    Args:
        x: Input tensor passed to the ``Rectangle`` autograd function.
        alpha: Window parameter forwarded unchanged to ``Rectangle``.

    Returns:
        The output of ``Rectangle.apply(x, alpha)``.
    """
    spikes = Rectangle.apply(x, alpha)
    return spikes