"""
Spiking Neural Network Components

This module implements core components for spiking neural networks:
- LIF (Leaky Integrate-and-Fire) neuron with surrogate gradient
- Temporal dimension handling layers (tdLayer, tdBatchNorm)
- Sequential to ANN container for temporal processing
"""

import torch
import torch.nn as nn


def add_dimention_distribute(x, T):
    """
    Add a temporal dimension to the input by repeating it across T timesteps.

    The input tensor is left untouched: the time axis is inserted on a new
    view rather than in place, so the caller's tensor keeps its original
    shape and the function can safely be called repeatedly on the same input.

    Args:
        x: Input tensor of shape (batch_size, channels, height, width)
        T: Number of timesteps

    Returns:
        Tensor of shape (batch_size, T, channels, height, width)
    """
    # Out-of-place unsqueeze: the in-place variant would silently reshape
    # the caller's tensor to (batch_size, 1, channels, height, width).
    return x.unsqueeze(1).repeat(1, T, 1, 1, 1)


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` whose padding equals its dilation."""
    conv_options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_options)


def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 (pointwise) ``nn.Conv2d``."""
    pointwise = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise


class SeqToANNContainer(nn.Module):
    """
    Run a stateless (ANN) module over a tensor whose first two axes are
    merged for the call and split again afterwards.

    The first two axes of the input are flattened into one, the wrapped
    module is applied to the flattened tensor, and the leading axis of the
    result is reshaped back to the original two leading sizes.

    Args:
        *args: One module (stored as-is) or several modules (chained in an
            ``nn.Sequential``).
    """
    def __init__(self, *args):
        super().__init__()
        # A single module is kept unwrapped; anything else is chained.
        self.module = args[0] if len(args) == 1 else nn.Sequential(*args)

    def forward(self, x_seq: torch.Tensor):
        lead_a, lead_b = x_seq.shape[0], x_seq.shape[1]
        merged = x_seq.flatten(0, 1).contiguous()
        result = self.module(merged)
        # Restore the two leading axes; trailing axes follow the module output.
        return result.view([lead_a, lead_b, *result.shape[1:]])


class tdLayer(nn.Module):
    """
    Wrap a layer so it is applied through ``SeqToANNContainer``, optionally
    followed by a normalization module.

    Args:
        layer: The neural network layer to wrap in a ``SeqToANNContainer``
        bn: Optional module applied to the wrapped layer's output; skipped
            when ``None``
    """
    def __init__(self, layer, bn=None):
        super().__init__()
        self.layer = SeqToANNContainer(layer)
        self.bn = bn

    def forward(self, x):
        out = self.layer(x)
        # No normalization configured: return the layer output directly.
        if self.bn is None:
            return out
        return self.bn(out)


class tdBatchNorm(nn.Module):
    """
    Batch normalization for inputs routed through ``SeqToANNContainer``.

    A single ``nn.BatchNorm2d`` is wrapped in a ``SeqToANNContainer``, which
    flattens the first two input axes before the normalization is applied,
    so one set of BatchNorm parameters and statistics serves the whole
    flattened tensor.

    Args:
        out_panel: Number of features passed to ``nn.BatchNorm2d``
    """
    def __init__(self, out_panel):
        super().__init__()
        norm = nn.BatchNorm2d(out_panel)
        # The same BatchNorm instance is reachable both directly (``bn``)
        # and through the container (``seqbn``).
        self.bn = norm
        self.seqbn = SeqToANNContainer(norm)

    def forward(self, x):
        return self.seqbn(x)


class ZIF(torch.autograd.Function):
    """
    Spike activation (ZIF) with a surrogate gradient.

    Forward: Heaviside step function; emits 1 where the input is strictly
    positive and 0 elsewhere.
    Backward: triangular surrogate ``max(0, gama - |input|) / gama**2``,
    which enables backpropagation through the non-differentiable spike
    generation.

    The spike tensor keeps the dtype of a floating-point input so that it can
    be fed to layers running in float64 or float16 without a dtype mismatch.
    Non-floating inputs still yield float32 spikes.
    """
    @staticmethod
    def forward(ctx, input, gama):
        """
        Args:
            input: Tensor compared against zero to decide where spikes occur
            gama: Width parameter of the triangular surrogate gradient

        Returns:
            Tensor of 0/1 spikes with the same shape as ``input``
        """
        spike_dtype = input.dtype if input.is_floating_point() else torch.float32
        # Only the input is needed in backward. gama is not a tensor that
        # takes part in autograd, so it is stashed on ctx directly instead of
        # being wrapped in a saved tensor.
        ctx.save_for_backward(input)
        ctx.gama = gama
        return (input > 0).to(spike_dtype)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Returns:
            Gradient w.r.t. ``input`` scaled by the surrogate, and ``None``
            for ``gama`` (no gradient is propagated to it).
        """
        (input,) = ctx.saved_tensors
        gama = ctx.gama
        # Triangular surrogate gradient: peak 1/gama at 0, zero beyond +/-gama.
        surrogate = (1 / gama) * (1 / gama) * (gama - input.abs()).clamp(min=0)
        return grad_output * surrogate, None


class LIFSpike(nn.Module):
    """
    Leaky Integrate-and-Fire (LIF) neuron model.

    For every step along dimension 1 of the input:
        membrane = membrane * tau + input
        spike = act(membrane - thresh, gama)
        membrane = (1 - spike) * membrane  (zeroed wherever a spike fired)

    The membrane starts at 0 and the per-step spikes are stacked back along
    dimension 1.

    Args:
        thresh: Firing threshold (default: 1.0)
        tau: Membrane potential decay factor (default: 0.5)
        gama: Surrogate gradient width parameter (default: 1.0)
    """
    def __init__(self, thresh=1.0, tau=0.5, gama=1.0):
        super().__init__()
        self.act = ZIF.apply
        self.thresh = thresh
        self.tau = tau
        self.gama = gama

    def forward(self, x):
        membrane = 0
        emitted = []
        # Walk over the slices of dimension 1, one step at a time.
        for current in x.unbind(dim=1):
            membrane = membrane * self.tau + current
            fired = self.act(membrane - self.thresh, self.gama)
            # Hard reset: positions that fired restart from zero.
            membrane = (1 - fired) * membrane
            emitted.append(fired)
        return torch.stack(emitted, dim=1)


class IFSpike(nn.Module):
    """
    Integrate-and-Fire (IF) style neuron.

    For every step along dimension 1 of the input:
        membrane = membrane * tau + input
        spike = act(membrane, gama)

    Unlike ``LIFSpike`` there is no threshold offset and the membrane is not
    reset after a spike.

    NOTE(review): the membrane is still multiplied by ``tau`` each step
    (default 0.5), so this implementation does decay despite the "IF" name;
    pass ``tau=1.0`` for a non-leaky integrator. Confirm the intended default.

    Args:
        tau: Factor applied to the membrane before each new input is added
            (default: 0.5)
        gama: Surrogate gradient width parameter (default: 1.0)
    """
    def __init__(self, tau=0.5, gama=1.0):
        super().__init__()
        self.act = ZIF.apply
        self.tau = tau
        self.gama = gama

    def forward(self, x):
        membrane = 0
        emitted = []
        # Walk over the slices of dimension 1, one step at a time.
        for current in x.unbind(dim=1):
            membrane = membrane * self.tau + current
            emitted.append(self.act(membrane, self.gama))
        return torch.stack(emitted, dim=1)
