import torch
from torch import nn

class TDBN3D(nn.Module):
    """ Implementation of threshold-dependent Batch Normalization (tdBN) for (T, B, C, H, W) data.
    Recreated from the paper: Learnable Surrogate Gradient for Direct Training Spiking Neural Networks
    https://www.ijcai.org/proceedings/2023/0335.pdf

    The time and batch axes are merged before a ``BatchNorm2d`` call, so the
    per-channel statistics are pooled jointly over (T, B, H, W). The affine
    weight is initialised to ``v_th * eta`` and the bias to zero, which scales
    the normalized pre-activations relative to the firing threshold.

    Args:
    - num_features: int - The number of channels C of the input tensor.
    - v_th: torch.Tensor - The threshold value (a 0-dim tensor; a Python float also works).
    - eta: float - Extra scale on the threshold (alpha in tdBN), used for residual connections.
    """
    def __init__(self, num_features: int, v_th: torch.Tensor, eta: float = 1.0):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features, affine=True)
        # True once bn.weight / bn.bias hold their intended values: after the
        # first forward pass, or after a checkpoint supplying them was loaded.
        self.init_weight = False
        self.factor = v_th * eta

        with torch.no_grad():
            self.bn.weight.fill_(self.factor)
            self.bn.bias.zero_()

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        """Mark the affine parameters as initialized when a checkpoint provides them.

        Without this, the lazy init in ``forward`` would overwrite freshly
        loaded ``bn.weight`` / ``bn.bias`` with ``v_th * eta`` / zero on the
        first forward pass after ``load_state_dict``.
        """
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
        if prefix + "bn.weight" in state_dict or prefix + "bn.bias" in state_dict:
            self.init_weight = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` and return a tensor of the same shape.

        params:
        - x: torch.Tensor - input tensor of shape (T, B, C, H, W)
        """
        if not self.init_weight:
            # Re-apply the threshold-dependent init right before first use.
            # This undoes any change made to the affine parameters between
            # construction and the first forward (e.g. a generic model-wide
            # initializer). Checkpoint loading sets the flag beforehand (see
            # _load_from_state_dict), so trained parameters are not clobbered.
            with torch.no_grad():
                self.bn.weight.fill_(self.factor)
                self.bn.bias.zero_()
            self.init_weight = True

        T, B = x.shape[:2]
        # Merge time and batch so BatchNorm2d pools statistics over T, B, H, W.
        out = self.bn(x.flatten(0, 1))
        return out.view(T, B, *out.shape[1:])
    
if __name__ == "__main__":
    # Demo: normalize a random (T=2, B=1, C=4, H=4, W=4) input and report its stats.
    v_th = torch.as_tensor(0.5)
    tdbn = TDBN3D(4, v_th, 1.0)
    x = torch.randn(2, 1, 4, 4, 4)
    # Run the layer once and reuse the result: in training mode every call
    # also updates the BatchNorm running statistics.
    out = tdbn(x)
    print("Before TDBN")
    print(x)
    print("After TDBN")
    print(out)
    print("Mean")
    print(out.mean())
    print("Std")
    print(out.std())