

import torch
import torch.nn as nn
from spikingjelly.clock_driven import layer
from spikingjelly.clock_driven import layer,neuron,surrogate



def create_conv_sequential(in_channels, out_channels, number_layer,  use_max_pool):
    """Build a stack of Conv2d -> BatchNorm2d -> LIFNode -> 2x2 pooling stages.

    The first stage maps ``in_channels`` to ``out_channels``. Each of the
    remaining ``number_layer - 1`` stages maps ``out_channels`` to
    ``out_channels``. The first stage is always built, so a ``number_layer``
    below 1 still yields exactly one stage.

    Args:
        in_channels: Number of channels fed into the first stage.
        out_channels: Number of channels produced by every stage.
        number_layer: Total number of stages to build.
        use_max_pool: If truthy, every stage ends in ``nn.MaxPool2d(2, 2)``;
            otherwise it ends in ``nn.AvgPool2d(2, 2)``.

    Returns:
        An ``nn.Sequential`` holding all stages in order.
    """
    pool_cls = nn.MaxPool2d if use_max_pool else nn.AvgPool2d

    # Input channel count for each stage: the first stage reads in_channels,
    # every later stage reads the previous stage's out_channels.
    stage_inputs = [in_channels] + [out_channels] * (number_layer - 1)

    modules = []
    for stage_in in stage_inputs:
        # Construction order (conv, BN, LIF, pool) matches stage order, which
        # keeps RNG consumption for weight init in the same sequence.
        modules += [
            nn.Conv2d(in_channels=stage_in, out_channels=out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            neuron.LIFNode(tau=2.0, surrogate_function=surrogate.ATan(), detach_reset=True),
            pool_cls(2, 2),
        ]
    return nn.Sequential(*modules)

def create_2fc(channels, h, w, dpp, class_num):
    """Build the two-layer spiking classifier head.

    Layout: Flatten -> Dropout -> Linear(in, in // 4) -> LIFNode
            -> Dropout -> Linear(in // 4, class_num * 10) -> LIFNode,
    where ``in = channels * h * w``.

    The head emits ``class_num * 10`` units. ``NeuromorphicNet.boost``
    (``nn.AvgPool1d(10, 10)``) then averages each group of 10 into one score
    per class.

    Args:
        channels: Channel count of the incoming feature map.
        h: Height of the incoming feature map.
        w: Width of the incoming feature map.
        dpp: Dropout probability passed to both ``layer.Dropout`` modules.
        class_num: Number of output classes.

    Returns:
        An ``nn.Sequential`` implementing the head.
    """
    in_features = channels * h * w
    hidden_features = in_features // 4
    out_features = class_num * 10

    def _spiking_node():
        # Same neuron configuration as used in the conv stages.
        return neuron.LIFNode(tau=2.0, surrogate_function=surrogate.ATan(), detach_reset=True)

    return nn.Sequential(
        nn.Flatten(),
        layer.Dropout(dpp),
        nn.Linear(in_features, hidden_features, bias=False),
        _spiking_node(),
        layer.Dropout(dpp),
        nn.Linear(hidden_features, out_features, bias=False),
        _spiking_node(),
    )

class NeuromorphicNet(nn.Module):
    """Base network that sums a pooled per-time-step output over time.

    Subclasses must assign ``self.conv`` (feature extractor) and ``self.fc``
    (classifier head). Both are ``None`` here, so ``forward`` cannot run on a
    bare instance.

    Args:
        T: Stored as ``self.T``. ``forward`` does not read it; it iterates
            over the time dimension of its input instead.
        use_max_pool: Stored as ``self.use_max_pool`` for reference.
    """

    def __init__(self, T=20, use_max_pool=True):
        super().__init__()
        self.T = T
        self.use_max_pool = use_max_pool

        # Training bookkeeping; presumably updated by an external training
        # loop (not read anywhere in this module).
        self.train_times = 0
        self.max_test_accuracy = 0
        self.epoch = 0

        # Placeholders that subclasses replace with real modules.
        self.conv = None
        self.fc = None

        # Averages every non-overlapping window of 10 fc outputs into one value.
        self.boost = nn.AvgPool1d(10, 10)

    def _boosted_step(self, frame):
        """Run one time slice through conv -> fc -> boost and drop the pool channel dim."""
        features = self.fc(self.conv(frame))
        return self.boost(features.unsqueeze(1)).squeeze(1)

    def forward(self, x):
        """Sum the boosted fc output across all time steps.

        Args:
            x: 5-D tensor whose first two axes are swapped so that time comes
                first (per the original layout note, the result is
                ``[T, N, 2, *, *]``).

        Returns:
            Per-sample accumulated scores. These are spike counts when ``fc``
            ends in a spiking neuron, as in ``create_2fc``.

        Raises:
            IndexError: if the time dimension is empty.
        """
        frames = x.permute(1, 0, 2, 3, 4)  # [T, N, 2, *, *]
        total = self._boosted_step(frames[0])
        for frame in frames[1:]:
            total = total + self._boosted_step(frame)
        return total


class DVS128GestureNet(NeuromorphicNet):
    """Spiking CNN for 11-class DVS128 Gesture input (2-channel 128x128 frames).

    Args:
        T: Forwarded to ``NeuromorphicNet`` and stored as ``self.T``.
        use_max_pool: Selects ``MaxPool2d`` (truthy) or ``AvgPool2d`` after
            each conv stage. It is also stored as ``self.use_max_pool``.
        channels: Output channel count of every conv stage.
        number_layer: Number of conv stages. Each stage halves the spatial
            size, so the classifier sees ``(128 >> number_layer)`` per side.

    Raises:
        ValueError: if ``number_layer`` is less than 1, or large enough to pool
            the 128x128 input below 1x1. Either case would otherwise build a
            conv/fc size mismatch that only fails later, inside ``forward``.
    """

    def __init__(self, T=20, use_max_pool=True, channels=128, number_layer=5):
        w = 128
        h = 128
        # create_conv_sequential always builds at least one stage, so
        # number_layer < 1 would make the fc input size disagree with the
        # conv output.
        if number_layer < 1 or (min(w, h) >> number_layer) < 1:
            raise ValueError(
                f"number_layer must be in [1, {min(w, h).bit_length() - 1}], got {number_layer}"
            )

        # Forward the caller's pooling choice. Previously True was hard-coded,
        # so self.use_max_pool misreported AvgPool configurations.
        super().__init__(T=T, use_max_pool=use_max_pool)
        self.conv = create_conv_sequential(2, channels, number_layer=number_layer,
                                           use_max_pool=use_max_pool)
        self.fc = create_2fc(channels=channels, w=w >> number_layer, h=h >> number_layer,
                             dpp=0.5, class_num=11)


