from torch import nn

class SNNNorm3D(nn.BatchNorm2d):
    """Batch normalization layer intended for spiking neural networks (TDBN).

    Recreated from the paper: Learnable Surrogate Gradient for Direct Training
    Spiking Neural Networks, https://www.ijcai.org/proceedings/2023/0335.pdf

    NOTE(review): in the SNN literature "TDBN" usually refers to
    threshold-dependent batch normalization (Zheng et al., 2021); confirm
    which expansion and reference are intended.

    Currently this is a thin wrapper around ``nn.BatchNorm2d``. Its only
    differences are a smaller default ``eps`` (1e-6 instead of 1e-5) and the
    class name. No threshold-dependent scaling is applied (for example, by a
    firing threshold ``v_th`` or a factor ``eta``). Because the base class is
    ``BatchNorm2d``, inputs must be 4D ``(N, C, H, W)`` despite the "3D" in
    the name.

    NOTE(review): the "3D" name does not match the 2D base class. Confirm
    whether callers flatten a time dimension into the batch before calling.

    Args:
        num_features: Number of channels ``C`` in the input.
        eps: Value added to the variance to avoid division by zero.
        affine: If True, learn a per-channel scale (``weight``) and shift
            (``bias``).
        momentum: Running-statistics update factor, forwarded to
            ``BatchNorm2d``. ``None`` selects a cumulative moving average.
            Keyword-only.
        track_running_stats: If True, keep running mean/variance buffers for
            use in eval mode. Forwarded to ``BatchNorm2d``. Keyword-only.
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-6,
        affine: bool = True,
        *,
        momentum: "float | None" = 0.1,
        track_running_stats: bool = True,
    ):
        # The new options are keyword-only, so existing positional calls
        # (num_features, eps, affine) keep their meaning. Defaults match
        # BatchNorm2d's own defaults.
        super().__init__(
            num_features,
            eps=eps,
            momentum=momentum,
            affine=affine,
            track_running_stats=track_running_stats,
        )
 