import torch

from norse.torch.functional.heaviside import heaviside

class Triangle(torch.autograd.Function):
    @staticmethod
    @torch.jit.ignore
    def forward(ctx, x: torch.Tensor, alpha: float | tuple[float, float]) -> torch.Tensor:
        ctx.save_for_backward(x)
        ctx.alpha = alpha
        return heaviside(x)
    
    @staticmethod
    @torch.jit.ignore
    def backward(ctx, grad):
        factor = ctx.alpha - ctx.saved_tensors[0].abs()
        grad *= (1 / ctx.alpha) ** 2 * factor.clamp(min=0)
        return grad, None
    
@torch.jit.ignore
def triangle_fn(x: torch.Tensor, alpha: float = 0.3) -> torch.Tensor:
    """Functional entry point for the ``Triangle`` autograd function.

    Args:
        x: Input tensor forwarded to ``Triangle.apply``.
        alpha: Second argument forwarded to ``Triangle.apply``; defaults to 0.3.

    Returns:
        Whatever ``Triangle.apply(x, alpha)`` returns.
    """
    result = Triangle.apply(x, alpha)
    return result
