from .SNNNorm3D import SNNNorm3D

import torch
from torch import nn

class TDBN3D(SNNNorm3D):
    """Threshold-dependent batch normalization (tdBN) for time-major SNN data.

    tdBN was introduced in "Going Deeper With Directly-Trained Larger Spiking
    Neural Networks" (Zheng et al., AAAI 2021). This implementation follows its
    use in "Learnable Surrogate Gradient for Direct Training Spiking Neural
    Networks": https://www.ijcai.org/proceedings/2023/0335.pdf

    Two things differ from the plain ``SNNNorm3D`` base class:

    * ``forward`` merges the time (T) and batch (B) dimensions into a single
      leading dimension before normalizing, then splits them again afterwards.
    * ``reset_parameters`` initialises the affine weight to ``v_th * eta``
      instead of the usual 1.

    Args:
        num_features: Number of channels (C) of the input tensor.
        v_th: Firing threshold. Must be a scalar (0-dim tensor or Python
            number) because it is used to fill the weight.
        eta: Scaling factor applied to the threshold when initialising the
            weight (alpha in the tdBN paper); used for residual connections.
    """

    def __init__(self, num_features: int, v_th: torch.Tensor, eta: float = 1.0):
        # Must be assigned BEFORE super().__init__(): the base constructor
        # presumably calls reset_parameters() (as torch's _NormBase does),
        # and that override reads self.v_th and self.eta.
        self.v_th = v_th
        self.eta = eta
        super().__init__(num_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize a time-major input.

        Args:
            x: Tensor of shape (T, B, C, *rest). T and B are merged into one
                leading dimension for the base norm. ``rest`` must be whatever
                ``SNNNorm3D`` expects after the channel dimension (presumably
                (D, H, W) for a 3D norm -- TODO confirm against SNNNorm3D).

        Returns:
            A contiguous tensor of shape (T, B, *base_output.shape[1:]). This is
            the same shape as ``x`` when the base norm preserves shape.

        Raises:
            ValueError: If ``x`` has fewer than 3 dimensions (no channel dim).
        """
        if x.dim() < 3:
            raise ValueError(
                "TDBN3D expects an input of shape (T, B, C, ...), "
                f"got {tuple(x.shape)}"
            )
        T, B = x.shape[0], x.shape[1]
        # Every (time step, sample) pair becomes one sample along dim 0.
        out = super().forward(x.flatten(0, 1))
        # Splitting dim 0 back into (T, B) is always a valid view.
        # .contiguous() keeps the guarantee of a contiguous result.
        return out.view(T, B, *out.shape[1:]).contiguous()

    def reset_parameters(self) -> None:
        """Reset the running statistics and initialise the affine parameters.

        Unlike standard batch norm, which initialises the weight to 1, the
        weight is filled with ``v_th * eta``. The bias is zeroed.
        """
        self.reset_running_stats()
        if self.affine:
            nn.init.constant_(self.weight, self.v_th * self.eta)
            nn.init.zeros_(self.bias)
    
if __name__ == "__main__":
    # Small demo: normalize random data and inspect the output statistics.
    v_th = torch.as_tensor(0.5)
    eta = 1.0
    tdbn = TDBN3D(4, v_th, eta)
    # (T=2, B=1, C=4, 4, 4) becomes (2, 4, 4, 4) once T and B are merged.
    # NOTE(review): a BatchNorm3d-like SNNNorm3D would need one more trailing
    # dim, e.g. torch.randn(2, 1, 4, 4, 4, 4) -- confirm against SNNNorm3D.
    x = torch.randn(2, 1, 4, 4, 4)
    # Run the layer once and reuse the result. In training mode each call
    # presumably also updates the running statistics.
    out = tdbn(x)
    print("Before TDBN")
    print(x)
    print("After TDBN")
    print(out)
    # Zero bias and weight = v_th * eta are set in reset_parameters.
    # If SNNNorm3D normalizes like batch norm, expect mean ~ 0 and
    # std ~ v_th * eta.
    print("Mean")
    print(out.mean())
    print("Std")
    print(out.std())