import torch.nn as nn
from copy import deepcopy
from soul.neuron import functional
# Public API of this module: only the network class; multi_time_forward is an internal helper.
__all__ = ['SpikingLeNet']
def multi_time_forward(x_seq, stateless_module):
    """Apply module(s) to every time step at once by folding time into the batch.

    The first two dimensions of ``x_seq`` (time ``T`` and batch ``B``) are merged,
    so the module(s) see an ordinary batch of ``T * B`` samples. The leading
    dimension of the result is then split back into ``[T, B]``. This is meant for
    modules that keep no state across time steps (conv, batch-norm, pooling,
    linear); the function itself does not check that.

    Args:
        x_seq: Tensor whose first two dimensions are ``[T, B]``, e.g.
            ``[T, B, C, H, W]`` or ``[T, B, D]``.
        stateless_module: Either a single callable module, or an ordered
            container of modules (``list``, ``tuple``, ``nn.Sequential`` or
            ``nn.ModuleList``) that are applied one after another.

    Returns:
        Tensor of shape ``[T, B, ...]``, where ``...`` is the per-sample output
        shape produced by the module(s).
    """
    T, B = x_seq.shape[0], x_seq.shape[1]
    y = x_seq.flatten(0, 1)  # [T, B, ...] -> [T*B, ...]

    # nn.ModuleList is not callable on its own, so containers are iterated explicitly.
    if isinstance(stateless_module, (list, tuple, nn.Sequential, nn.ModuleList)):
        for m in stateless_module:
            y = m(y)
    else:
        y = stateless_module(y)

    # Split the merged leading dim back into [T, B]; trailing dims are whatever the module produced.
    return y.view(T, B, *y.shape[1:])

class SpikingLeNet(nn.Module):
    """LeNet-style spiking CNN.

    The network has three conv -> batch-norm -> spiking-neuron stages (the
    first two followed by 2x2 max-pooling), a spiking hidden linear layer, and
    a linear classification head applied to the time-averaged hidden output.

    ``config`` keys read by ``__init__``: ``num_classes``, ``time_step``,
    ``input_channels``, ``input_height``, ``input_width``, ``hidden_dim`` and
    ``neuron``. ``config['neuron']`` is used as a template: each spiking layer
    receives its own ``deepcopy`` of it, so the layers are independent objects.
    """

    def __init__(self, config):
        super().__init__()

        self.num_classes = config['num_classes']
        # Number of time steps; forward() also uses it to repeat static [B, C, H, W] inputs.
        self.T = config['time_step']

        C, H, W = config['input_channels'], config['input_height'], config['input_width']
        in_h, in_w = H, W
        lif = config['neuron']

        # Stage 1: 5x5 conv with no padding (spatial -4), then 2x2/stride-2 pool (floor halve).
        self.conv1 = nn.Conv2d(C, 32, kernel_size=5, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(32)
        self.lif1 = deepcopy(lif)

        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        H = (H - 4) // 2
        W = (W - 4) // 2

        # Stage 2: same conv + pool pattern.
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=0)
        self.bn2 = nn.BatchNorm2d(64)
        self.lif2 = deepcopy(lif)

        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        H = (H - 4) // 2
        W = (W - 4) // 2

        # Stage 3: conv without pooling.
        self.conv3 = nn.Conv2d(64, 96, kernel_size=5, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(96)
        self.lif3 = deepcopy(lif)

        H -= 4
        W -= 4

        # Every intermediate size only shrinks, so a non-positive final size means the
        # input is too small (each side must be >= 32). Without this check a
        # zero-width nn.Linear would be built and the failure would surface much later.
        if H <= 0 or W <= 0:
            raise ValueError(
                f"input size {in_h}x{in_w} is too small for SpikingLeNet "
                f"(final feature map would be {H}x{W}; each side must be >= 32)"
            )

        self.ln1 = nn.Linear(96 * H * W, config['hidden_dim'])
        self.lif4 = deepcopy(lif)

        self.head = nn.Linear(config['hidden_dim'], self.num_classes)

    def forward(self, x):
        """Run the network over all time steps and return time-averaged logits.

        Args:
            x: Input of shape ``[T, B, C, H, W]``. A static ``[B, C, H, W]``
                input is repeated ``self.T`` times along a new leading time
                dimension.

        Returns:
            Tensor of shape ``[B, num_classes]``: the head applied to the
            hidden spiking layer's output averaged over the time dimension.
        """
        # Presumably clears the neurons' internal state so nothing carries over
        # from the previous call (semantics defined by soul.neuron.functional).
        functional.reset_net(self)

        if x.dim() == 4:
            # Static input: present the same frame at every time step.
            x = x.unsqueeze(0).repeat(self.T, 1, 1, 1, 1)

        x = multi_time_forward(x, [self.conv1, self.bn1])
        x = self.lif1(x)

        x = multi_time_forward(x, self.pool1)

        x = multi_time_forward(x, [self.conv2, self.bn2])
        x = self.lif2(x)

        x = multi_time_forward(x, self.pool2)

        x = multi_time_forward(x, [self.conv3, self.bn3])
        x = self.lif3(x)

        x = x.flatten(2)  # (T, B, C, H, W) -> (T, B, CHW)

        x = multi_time_forward(x, self.ln1)
        x = self.lif4(x)

        x = self.head(x.mean(0))  # (T, B, D) -> (B, num_cls)

        return x
   