import torch
import torch.nn as nn
import surrogate
import neuron
import json
from spikingjelly.clock_driven import layer


class CifarDvsNetVgg9(nn.Module):
    """VGG-style spiking CNN: 7 conv layers with average pooling, 2 FC layers.

    Every conv / linear layer is followed by a ``neuron.GeneralLIFNode``
    (including the 10-way output layer), so ``forward`` returns the time
    average of that node's output.

    The input has 2 channels (first ``Conv2d``). Four 2x average pools followed
    by ``Linear(256 * 8 * 8, ...)`` are consistent with 128x128 frames; the
    per-stage comments below assume that size.

    Args:
        k, lam: forwarded unchanged to every ``neuron.GeneralLIFNode``.
        T: number of time steps consumed by ``forward``.
        v_threshold, v_reset: forwarded to every ``neuron.GeneralLIFNode``.
        grad: surrogate-gradient class; a new instance is created per neuron
            layer.
        grad_kargs: JSON object string, decoded with ``json.loads`` and passed
            as keyword arguments to ``grad``.
        **kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self, k, lam, T, v_threshold=1.0, v_reset=0.0, grad=surrogate.Sigmoid, grad_kargs='{}', **kwargs):
        super().__init__()
        self.grad = grad
        self.grad_kargs = json.loads(grad_kargs)
        self.T = T

        def lif():
            # One fresh neuron (with its own surrogate-function instance) per
            # call; called in the same order the layers are listed so that
            # module indices / state_dict keys match the original layout.
            return neuron.GeneralLIFNode(k=k, lam=lam, v_threshold=v_threshold, v_reset=v_reset,
                                         surrogate_function=self.grad(**self.grad_kargs))

        self.conv = nn.Sequential(
            nn.Conv2d(2, 64, kernel_size=3, padding=1, bias=True),
            lif(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=True),
            lif(),
            nn.AvgPool2d(2, 2),  # 64 * 64 * 64

            nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=True),
            lif(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=True),
            lif(),
            nn.AvgPool2d(2, 2),  # 128 * 32 * 32

            nn.Conv2d(128, 256, kernel_size=3, padding=1, bias=True),
            lif(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=True),
            lif(),
            nn.AvgPool2d(2, 2),  # 256 * 16 * 16

            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=True),
            lif(),
            nn.AvgPool2d(2, 2),  # 256 * 8 * 8
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * 8 * 8, 1024, bias=True),
            lif(),
            layer.Dropout(0.2),

            nn.Linear(1024, 10, bias=True),
            lif(),
        )

    def forward(self, x):
        """Run ``self.T`` time steps and return the mean output per step.

        Args:
            x: indexed by time step along its first dimension; ``x[t]`` is fed
                to ``self.conv`` for ``t`` in ``range(self.T)``. Presumably
                shaped (T, N, 2, H, W) — TODO confirm against the data loader.

        Returns:
            The sum of ``self.fc(self.conv(x[t]))`` over the ``self.T`` steps,
            divided by ``self.T`` (presumably an output firing rate).

        This method does not reset any neuron state; if the neuron modules are
        stateful, the caller must reset them between samples.
        """
        out_spikes_counter = self.fc(self.conv(x[0]))
        for t in range(1, self.T):
            # Out-of-place add: the step-0 tensor is the output of the final
            # GeneralLIFNode. Mutating it in place could corrupt state that
            # node retains or a tensor autograd saved for backward.
            out_spikes_counter = out_spikes_counter + self.fc(self.conv(x[t]))
        return out_spikes_counter / self.T
    

class CifarDvsNetVgg10(nn.Module):
    """VGG-style spiking CNN: 7 conv layers with average pooling, 3 FC layers.

    Every conv / linear layer is followed by a ``neuron.GeneralLIFNode``
    (including the 10-way output layer), so ``forward`` returns the time
    average of that node's output.

    The input has 2 channels (first ``Conv2d``). Five 2x average pools followed
    by ``Linear(128 * 4 * 4, ...)`` are consistent with 128x128 frames; the
    per-stage comments below assume that size. After the second conv layer the
    channel count stays at 128.

    Args:
        k, lam: forwarded unchanged to every ``neuron.GeneralLIFNode``.
        T: number of time steps consumed by ``forward``.
        v_threshold, v_reset: forwarded to every ``neuron.GeneralLIFNode``.
        grad: surrogate-gradient class; a new instance is created per neuron
            layer.
        grad_kargs: JSON object string, decoded with ``json.loads`` and passed
            as keyword arguments to ``grad``.
        **kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self, k, lam, T, v_threshold=1.0, v_reset=0.0, grad=surrogate.Sigmoid, grad_kargs='{}', **kwargs):
        super().__init__()
        self.grad = grad
        self.grad_kargs = json.loads(grad_kargs)
        self.T = T

        def lif():
            # One fresh neuron (with its own surrogate-function instance) per
            # call; called in the same order the layers are listed so that
            # module indices / state_dict keys match the original layout.
            return neuron.GeneralLIFNode(k=k, lam=lam, v_threshold=v_threshold, v_reset=v_reset,
                                         surrogate_function=self.grad(**self.grad_kargs))

        self.conv = nn.Sequential(
            nn.Conv2d(2, 64, kernel_size=3, padding=1, bias=True),
            lif(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=True),
            lif(),
            nn.AvgPool2d(2, 2),  # 128 * 64 * 64

            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=True),
            lif(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=True),
            lif(),
            nn.AvgPool2d(2, 2),  # 128 * 32 * 32

            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=True),
            lif(),
            nn.AvgPool2d(2, 2),  # 128 * 16 * 16

            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=True),
            lif(),
            nn.AvgPool2d(2, 2),  # 128 * 8 * 8

            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=True),
            lif(),
            nn.AvgPool2d(2, 2),  # 128 * 4 * 4
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 1024, bias=True),
            lif(),
            layer.Dropout(0.2),

            nn.Linear(1024, 1024, bias=True),
            lif(),
            layer.Dropout(0.2),

            nn.Linear(1024, 10, bias=True),
            lif(),
        )

    def forward(self, x):
        """Run ``self.T`` time steps and return the mean output per step.

        Args:
            x: indexed by time step along its first dimension; ``x[t]`` is fed
                to ``self.conv`` for ``t`` in ``range(self.T)``. Presumably
                shaped (T, N, 2, H, W) — TODO confirm against the data loader.

        Returns:
            The sum of ``self.fc(self.conv(x[t]))`` over the ``self.T`` steps,
            divided by ``self.T`` (presumably an output firing rate).

        This method does not reset any neuron state; if the neuron modules are
        stateful, the caller must reset them between samples.
        """
        out_spikes_counter = self.fc(self.conv(x[0]))
        for t in range(1, self.T):
            # Out-of-place add: the step-0 tensor is the output of the final
            # GeneralLIFNode. Mutating it in place could corrupt state that
            # node retains or a tensor autograd saved for backward.
            out_spikes_counter = out_spikes_counter + self.fc(self.conv(x[t]))
        return out_spikes_counter / self.T
    
    
class CifarDvsNetL7(nn.Module):
    """Spiking CNN: four Conv-BN-LIF-MaxPool stages and a three-layer MLP head.

    Every hidden conv / linear layer is followed by a ``neuron.GeneralLIFNode``.
    The final 10-way ``nn.Linear`` has no neuron after it, so ``forward``
    averages that layer's raw outputs over time.

    Args:
        k, lam: forwarded unchanged to every ``neuron.GeneralLIFNode``.
        T: number of time steps consumed by ``forward``.
        v_threshold, v_reset: forwarded to every ``neuron.GeneralLIFNode``.
        grad: surrogate-gradient class; a new instance is created per neuron
            layer.
        grad_kargs: JSON object string, decoded with ``json.loads`` and passed
            as keyword arguments to ``grad``.
        **kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self, k, lam, T, v_threshold=1.0, v_reset=0.0, grad=surrogate.Sigmoid, grad_kargs='{}', **kwargs):
        super().__init__()
        self.grad = grad
        self.grad_kargs = json.loads(grad_kargs)
        self.T = T

        def make_neuron():
            # Fresh neuron plus its own surrogate-function instance.
            return neuron.GeneralLIFNode(
                k=k,
                lam=lam,
                v_threshold=v_threshold,
                v_reset=v_reset,
                surrogate_function=self.grad(**self.grad_kargs),
            )

        # Four identical 128-channel stages. Each MaxPool2d halves H and W.
        # Assuming a 128x128 input (consistent with Linear(128 * 8 * 8, ...)
        # below), the shape is 128x64x64 -> 128x32x32 -> 128x16x16 -> 128x8x8.
        # Modules are appended in Conv, BN, neuron, pool order so the flat
        # Sequential indices are unchanged.
        stages = []
        for in_channels in (2, 128, 128, 128):
            stages += [
                nn.Conv2d(in_channels, 128, kernel_size=3, padding=1, bias=True),
                nn.BatchNorm2d(128),
                make_neuron(),
                nn.MaxPool2d(2, 2),
            ]
        self.conv = nn.Sequential(*stages)

        self.fc = nn.Sequential(
            nn.Flatten(),
            layer.Dropout(0.25),
            nn.Linear(128 * 8 * 8, 512, bias=True),
            make_neuron(),
            layer.Dropout(0.25),
            nn.Linear(512, 100, bias=True),
            make_neuron(),
            layer.Dropout(0.25),
            # Output layer: no spiking neuron after it.
            nn.Linear(100, 10, bias=True),
        )

    def forward(self, x):
        """Average the network's output over the first ``self.T`` entries of ``x``.

        Args:
            x: indexed by time step along its first dimension; ``x[t]`` goes
                through ``self.conv`` then ``self.fc``. Presumably shaped
                (T, N, 2, H, W) — TODO confirm against the data loader.

        Returns:
            The per-step outputs summed over ``t = 0 .. self.T - 1`` and
            divided by ``self.T``.

        Neuron state is not reset here; if the neurons are stateful, the
        caller must reset them between samples.
        """
        def run_step(step):
            return self.fc(self.conv(x[step]))

        total = run_step(0)
        for step in range(1, self.T):
            total += run_step(step)
        return total / self.T