import torch
from torch import nn

class TDBN2D(nn.BatchNorm2d):
    """ Implementation of Temporal Difference Based Normalization (TDBN) for 2D data.
    Recreated from the paper: Learnable Surrogate Gradient for Direct Training Spiking Neural Networks
    https://www.ijcai.org/proceedings/2023/0335.pdf

    NOTE(review): in the SNN literature "tdBN" is usually expanded as
    "threshold-dependent batch normalization" - confirm the intended name.

    This class is specifically for usage with the SpikeFormer architectures and we assume the model has reshaped the input
    so that it is a regular 4D batch-norm input (see ``forward``).

    The layer is a plain ``nn.BatchNorm2d`` whose affine scale (``weight``) is
    initialised to ``v_th * eta`` instead of 1; the shift (``bias``) is
    initialised to 0. Everything else (running statistics, forward pass) is
    inherited unchanged.

    Args:
    - num_features: int - The number of features (channels) in the input tensor.
    - v_th: torch.Tensor - The threshold value. A Python number, a 0-dim tensor,
      a one-element tensor, or a per-channel tensor of shape ``(num_features,)``
      are all accepted. It is stored as a plain attribute and is NOT registered
      as a parameter/buffer of this module, even when it is an ``nn.Parameter``.
    - eta: float - The eta value used for residual connections
    - *args, **kwargs: forwarded unchanged to ``nn.BatchNorm2d``.
    """
    def __init__(self, num_features: int, v_th: torch.Tensor, eta: float = 1.0, *args, **kwargs):
        # Both values must exist *before* the parent constructor runs, because
        # the parent constructor calls reset_parameters(), which reads them.
        # object.__setattr__ bypasses nn.Module.__setattr__, which refuses
        # nn.Parameter values until Module.__init__ has run; for plain tensors
        # and floats the effect is the same as a normal attribute assignment.
        object.__setattr__(self, "v_th", v_th)
        object.__setattr__(self, "eta", eta)
        super().__init__(num_features, *args, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply batch normalization; identical to ``nn.BatchNorm2d.forward``.

        params:
        - x: torch.Tensor - input tensor of shape (TB, C, H, W)

        returns:
        - torch.Tensor - whatever ``nn.BatchNorm2d.forward`` returns for ``x``.
        """
        return super().forward(x)

    def reset_parameters(self) -> None:
        """
        Reset running statistics and re-initialise the affine parameters.

        When ``affine`` is enabled, ``weight`` is set to ``v_th * eta``
        (broadcast over the channels) and ``bias`` is set to zero.
        """
        self.reset_running_stats()
        if self.affine:
            with torch.no_grad():
                # copy_ broadcasts and casts to the weight's dtype/device, so
                # scalar, one-element and per-channel thresholds all work
                # (Tensor.fill_ would reject anything but a 0-dim value).
                scale = torch.as_tensor(self.v_th * self.eta)
                self.weight.copy_(scale)
            nn.init.zeros_(self.bias)
    