import torch
import torch.nn as nn
import surrogate
import neuron
import json
from spikingjelly.clock_driven import layer


class SensorNetVgg6DP(nn.Module):
    """Spiking VGG-6 classifier with dropout after every spiking layer.

    Feature extractor: four 3x3 convolutions (3 -> 32 -> 32 -> 64 -> 64
    channels, padding=1 so spatial size is preserved), each followed by a
    ``neuron.GeneralLIFNode`` and a ``layer.Dropout(0.2)``, with a 2x2 average
    pool after every second convolution.
    Classifier: ``Linear(64 * 8 * 8, 4096)`` -> LIF -> dropout ->
    ``Linear(4096, 10)`` -> LIF.

    Because the flattened width is ``64 * 8 * 8`` after two halving pools,
    the network expects 3-channel 32x32 frames.

    The module order (and hence the ``nn.Sequential`` indices / ``state_dict``
    keys) is unchanged from the original flat layout.

    Args:
        k: forwarded unchanged to every ``neuron.GeneralLIFNode``.
        lam: forwarded unchanged to every ``neuron.GeneralLIFNode``.
        T: number of time steps ``forward`` reads from its input.
        v_threshold: forwarded to every ``neuron.GeneralLIFNode``.
        v_reset: forwarded to every ``neuron.GeneralLIFNode``.
        grad: surrogate-gradient class; a new instance
            ``grad(**json.loads(grad_kargs))`` is built for each neuron node.
        grad_kargs: JSON object string of keyword arguments for ``grad``.
        **kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self, k, lam, T, v_threshold=1.0, v_reset=0.0, grad=surrogate.Sigmoid, grad_kargs='{}', **kwargs):
        super().__init__()
        self.grad = grad
        self.grad_kargs = json.loads(grad_kargs)
        self.T = T

        def spiking():
            # Each layer gets its own neuron node and its own surrogate instance.
            return neuron.GeneralLIFNode(k=k, lam=lam, v_threshold=v_threshold, v_reset=v_reset,
                                         surrogate_function=self.grad(**self.grad_kargs))

        def stage(in_ch, out_ch):
            # conv -> LIF -> dropout, constructed in this order so weight
            # initialisation and Sequential indices match the flat layout.
            return (nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=True),
                    spiking(),
                    layer.Dropout(0.2))

        self.conv = nn.Sequential(
            *stage(3, 32),
            *stage(32, 32),
            nn.AvgPool2d(2, 2),  # 32 * 16 * 16
            *stage(32, 64),
            *stage(64, 64),
            nn.AvgPool2d(2, 2),  # 64 * 8 * 8
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 8 * 8, 4096, bias=True),
            spiking(),
            layer.Dropout(0.2),
            nn.Linear(4096, 10, bias=True),
            spiking(),
        )

    def forward(self, x):
        """Average the classifier output over the first ``self.T`` time steps.

        ``x`` is indexed by time step along its first dimension; each ``x[t]``
        goes through ``conv`` then ``fc``. Presumably ``x`` is
        (T, batch, 3, 32, 32) — TODO confirm against the data loader.
        Neuron state is not reset here; the caller is responsible for that.
        """
        rate = self.fc(self.conv(x[0]))
        for t in range(1, self.T):
            rate += self.fc(self.conv(x[t]))  # in-place accumulation, as before
        return rate / self.T


class SensorNetVgg6(nn.Module):
    """Spiking VGG-6 classifier (dropout only in the classifier).

    Feature extractor: four 3x3 convolutions (3 -> 32 -> 32 -> 64 -> 64
    channels, padding=1), each followed by a ``neuron.GeneralLIFNode``, with a
    2x2 average pool after every second convolution.
    Classifier: ``Linear(64 * 8 * 8, 4096)`` -> LIF -> ``layer.Dropout(0.2)``
    -> ``Linear(4096, 10)`` -> LIF.

    The flattened width ``64 * 8 * 8`` after two halving pools means inputs
    are expected to be 3-channel 32x32 frames.

    Module order (and so ``nn.Sequential`` indices / ``state_dict`` keys) is
    the same as the original flat layout.

    Args:
        k: forwarded unchanged to every ``neuron.GeneralLIFNode``.
        lam: forwarded unchanged to every ``neuron.GeneralLIFNode``.
        T: number of time steps ``forward`` reads from its input.
        v_threshold: forwarded to every ``neuron.GeneralLIFNode``.
        v_reset: forwarded to every ``neuron.GeneralLIFNode``.
        grad: surrogate-gradient class; a new instance
            ``grad(**json.loads(grad_kargs))`` is built for each neuron node.
        grad_kargs: JSON object string of keyword arguments for ``grad``.
        **kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self, k, lam, T, v_threshold=1.0, v_reset=0.0, grad=surrogate.Sigmoid, grad_kargs='{}', **kwargs):
        super().__init__()
        self.grad = grad
        self.grad_kargs = json.loads(grad_kargs)
        self.T = T

        def spiking():
            # Fresh neuron node + surrogate instance for every layer.
            return neuron.GeneralLIFNode(k=k, lam=lam, v_threshold=v_threshold, v_reset=v_reset,
                                         surrogate_function=self.grad(**self.grad_kargs))

        def conv_lif(in_ch, out_ch):
            # Shape-preserving 3x3 conv followed by its spiking node.
            return nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=True), spiking()

        self.conv = nn.Sequential(
            *conv_lif(3, 32),
            *conv_lif(32, 32),
            nn.AvgPool2d(2, 2),  # 32 * 16 * 16
            *conv_lif(32, 64),
            *conv_lif(64, 64),
            nn.AvgPool2d(2, 2),  # 64 * 8 * 8
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 8 * 8, 4096, bias=True),
            spiking(),
            layer.Dropout(0.2),
            nn.Linear(4096, 10, bias=True),
            spiking(),
        )

    def forward(self, x):
        """Average the classifier output over the first ``self.T`` time steps.

        ``x`` is indexed by time step along its first dimension; presumably
        (T, batch, 3, 32, 32) — TODO confirm. Neuron state is not reset here.
        """
        rate = self.fc(self.conv(x[0]))
        for t in range(1, self.T):
            rate += self.fc(self.conv(x[t]))
        return rate / self.T


class SensorNetVgg9(nn.Module):
    """Spiking VGG-9 classifier whose input channel count comes from ``in_c``.

    Seven 3x3 convolutions (in_c -> 64 -> 64 -> 128 -> 128 -> 256 -> 256 ->
    256), each followed by a ``neuron.GeneralLIFNode``, then
    ``Linear(.., 1024)`` -> LIF -> ``layer.Dropout(0.2)`` -> ``Linear(1024, 10)``
    -> LIF.

    When ``in_c == 3`` there are three 2x2 average pools and the classifier
    expects ``256 * 4 * 4`` features (i.e. 32x32 inputs). For any other
    ``in_c`` an extra pool is inserted and the classifier expects
    ``256 * 8 * 8`` features, which implies 128x128 inputs — confirm against
    the data pipeline.

    Args:
        k, lam: forwarded unchanged to every ``neuron.GeneralLIFNode``.
        T: number of time steps ``forward`` reads from its input.
        v_threshold, v_reset: forwarded to every ``neuron.GeneralLIFNode``.
        grad: surrogate-gradient class, instantiated per neuron node.
        grad_kargs: JSON object string of keyword arguments for ``grad``.
        **kwargs: must contain ``in_c`` (number of input channels).
    """

    def __init__(self, k, lam, T, v_threshold=1.0, v_reset=0.0, grad=surrogate.Sigmoid, grad_kargs='{}', **kwargs):
        super().__init__()
        self.grad = grad
        self.grad_kargs = json.loads(grad_kargs)
        self.T = T
        self.input_channels = kwargs['in_c']
        three_channel = self.input_channels == 3

        def lif():
            # One neuron node (with its own surrogate) per weight layer.
            return neuron.GeneralLIFNode(k=k, lam=lam, v_threshold=v_threshold, v_reset=v_reset,
                                         surrogate_function=self.grad(**self.grad_kargs))

        def conv_lif(in_ch, out_ch):
            return nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=True), lif()

        self.conv = nn.Sequential(
            *conv_lif(self.input_channels, 64),
            *conv_lif(64, 64),
            nn.AvgPool2d(2, 2),  # 64 * 16 * 16 for 32x32 input
            *conv_lif(64, 128),
            *conv_lif(128, 128),
            nn.AvgPool2d(2, 2),  # 128 * 8 * 8 for 32x32 input
            *conv_lif(128, 256),
            *conv_lif(256, 256),
            # Extra downsampling only when the input is not 3-channel.
            nn.Identity() if three_channel else nn.AvgPool2d(2, 2),
            *conv_lif(256, 256),
            nn.AvgPool2d(2, 2),  # 256 * 4 * 4 (in_c == 3) or 256 * 8 * 8 (otherwise)
        )

        flat_features = 256 * 4 * 4 if three_channel else 256 * 8 * 8
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, 1024, bias=True),
            lif(),
            layer.Dropout(0.2),
            nn.Linear(1024, 10, bias=True),
            lif(),
        )

    def forward(self, x):
        """Average the classifier output over the first ``self.T`` steps of ``x``.

        ``x[t]`` is the frame for step ``t``. Neuron state is not reset here.
        """
        acc = self.fc(self.conv(x[0]))
        for step in range(1, self.T):
            acc += self.fc(self.conv(x[step]))
        return acc / self.T


class SensorNetVgg9_ShiftedLIF(nn.Module):
    """``SensorNetVgg9`` topology whose averaged output is offset by a constant.

    The layer stack is the same as ``SensorNetVgg9``: seven 3x3 conv + LIF
    layers with ``in_c``-dependent pooling, then two fully connected LIF
    layers (10 outputs). ``forward`` adds
    ``(1 - k) / 2 - lam / v_threshold`` to the time-averaged output. The
    physical meaning of that shift is not visible here — see the
    ``GeneralLIFNode`` definition.

    Args:
        k, lam: forwarded to every ``neuron.GeneralLIFNode``; also stored
            for the output shift.
        T: number of time steps ``forward`` reads from its input.
        v_threshold: forwarded to every neuron node; also used in the shift.
        v_reset: forwarded to every ``neuron.GeneralLIFNode``.
        grad: surrogate-gradient class, instantiated per neuron node.
        grad_kargs: JSON object string of keyword arguments for ``grad``.
        **kwargs: must contain ``in_c`` (number of input channels).
    """

    def __init__(self, k, lam, T, v_threshold=1.0, v_reset=0.0, grad=surrogate.Sigmoid, grad_kargs='{}', **kwargs):
        super().__init__()
        self.grad = grad
        self.grad_kargs = json.loads(grad_kargs)
        self.k = k
        self.lam = lam
        self.T = T
        self.input_channels = kwargs['in_c']
        self.v_threshold = v_threshold
        three_channel = self.input_channels == 3

        def lif():
            return neuron.GeneralLIFNode(k=k, lam=lam, v_threshold=v_threshold, v_reset=v_reset,
                                         surrogate_function=self.grad(**self.grad_kargs))

        def conv_lif(in_ch, out_ch):
            return nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=True), lif()

        self.conv = nn.Sequential(
            *conv_lif(self.input_channels, 64),
            *conv_lif(64, 64),
            nn.AvgPool2d(2, 2),  # 64 * 16 * 16 for 32x32 input
            *conv_lif(64, 128),
            *conv_lif(128, 128),
            nn.AvgPool2d(2, 2),  # 128 * 8 * 8 for 32x32 input
            *conv_lif(128, 256),
            *conv_lif(256, 256),
            nn.Identity() if three_channel else nn.AvgPool2d(2, 2),
            *conv_lif(256, 256),
            nn.AvgPool2d(2, 2),  # 256 * 4 * 4 (in_c == 3) or 256 * 8 * 8 (otherwise)
        )

        flat_features = 256 * 4 * 4 if three_channel else 256 * 8 * 8
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, 1024, bias=True),
            lif(),
            layer.Dropout(0.2),
            nn.Linear(1024, 10, bias=True),
            lif(),
        )

    def forward(self, x):
        """Return the time-averaged output over ``self.T`` steps plus a constant shift.

        Neuron state is not reset here.
        """
        acc = self.fc(self.conv(x[0]))
        for step in range(1, self.T):
            acc += self.fc(self.conv(x[step]))
        mean = acc / self.T
        shift = (1 - self.k) / 2.0 - self.lam / self.v_threshold
        return mean + shift


class SensorNetVgg11(nn.Module):
    """Spiking VGG-11 classifier for 3-channel inputs.

    Eight 3x3 convolutions (3 -> 64 -> 64 -> 128 -> 128 -> 256 -> 256 -> 256
    -> 512), each followed by a ``neuron.GeneralLIFNode``, with a 2x2 average
    pool after every pair. Then three fully connected layers
    (``512 * 2 * 2 -> 1024 -> 1024 -> 10``), each followed by a LIF node and,
    except for the last, a ``layer.Dropout(0.2)``.

    Four halving pools down to 2x2 mean the network expects 32x32 inputs.

    Args:
        k, lam: forwarded unchanged to every ``neuron.GeneralLIFNode``.
        T: number of time steps ``forward`` reads from its input.
        v_threshold, v_reset: forwarded to every ``neuron.GeneralLIFNode``.
        grad: surrogate-gradient class, instantiated per neuron node.
        grad_kargs: JSON object string of keyword arguments for ``grad``.
        **kwargs: accepted and ignored.
    """

    def __init__(self, k, lam, T, v_threshold=1.0, v_reset=0.0, grad=surrogate.Sigmoid, grad_kargs='{}', **kwargs):
        super().__init__()
        self.grad = grad
        self.grad_kargs = json.loads(grad_kargs)
        self.T = T

        def lif():
            return neuron.GeneralLIFNode(k=k, lam=lam, v_threshold=v_threshold, v_reset=v_reset,
                                         surrogate_function=self.grad(**self.grad_kargs))

        def conv_lif(in_ch, out_ch):
            return nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=True), lif()

        self.conv = nn.Sequential(
            *conv_lif(3, 64),
            *conv_lif(64, 64),
            nn.AvgPool2d(2, 2),  # 64 * 16 * 16
            *conv_lif(64, 128),
            *conv_lif(128, 128),
            nn.AvgPool2d(2, 2),  # 128 * 8 * 8
            *conv_lif(128, 256),
            *conv_lif(256, 256),
            nn.AvgPool2d(2, 2),  # 256 * 4 * 4
            *conv_lif(256, 256),
            *conv_lif(256, 512),
            nn.AvgPool2d(2, 2),  # 512 * 2 * 2
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512 * 2 * 2, 1024, bias=True),
            lif(),
            layer.Dropout(0.2),
            nn.Linear(1024, 1024, bias=True),
            lif(),
            layer.Dropout(0.2),
            nn.Linear(1024, 10, bias=True),
            lif(),
        )

    def forward(self, x):
        """Average the classifier output over the first ``self.T`` steps of ``x``.

        Neuron state is not reset here.
        """
        acc = self.fc(self.conv(x[0]))
        for step in range(1, self.T):
            acc += self.fc(self.conv(x[step]))
        return acc / self.T


class SensorNetVgg9_256c(nn.Module):
    """Wider spiking VGG-9 classifier (128/256-channel convolutions).

    Seven 3x3 convolutions (3 -> 128 -> 128 -> 256 -> 256 -> 256 -> 256 ->
    256, padding=1), each followed by a ``neuron.GeneralLIFNode``, with 2x2
    average pools after the 2nd, 4th and 7th convolutions. Then
    ``Linear(256 * 4 * 4, 1024)`` -> LIF -> ``layer.Dropout(0.2)`` ->
    ``Linear(1024, 10)`` -> LIF.

    Three halving pools down to 4x4 mean the network expects 3-channel 32x32
    inputs. Module order (and so ``state_dict`` keys) matches the original
    flat layout.

    Args:
        k, lam: forwarded unchanged to every ``neuron.GeneralLIFNode``.
        T: number of time steps ``forward`` reads from its input.
        v_threshold, v_reset: forwarded to every ``neuron.GeneralLIFNode``.
        grad: surrogate-gradient class, instantiated per neuron node.
        grad_kargs: JSON object string of keyword arguments for ``grad``.
        **kwargs: accepted and ignored (input channels are fixed at 3).
    """

    def __init__(self, k, lam, T, v_threshold=1.0, v_reset=0.0, grad=surrogate.Sigmoid, grad_kargs='{}', **kwargs):
        super().__init__()
        self.grad = grad
        self.grad_kargs = json.loads(grad_kargs)
        self.T = T

        def lif():
            # Fresh neuron node + surrogate instance per layer.
            return neuron.GeneralLIFNode(k=k, lam=lam, v_threshold=v_threshold, v_reset=v_reset,
                                         surrogate_function=self.grad(**self.grad_kargs))

        def conv_lif(in_ch, out_ch):
            return nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=True), lif()

        self.conv = nn.Sequential(
            *conv_lif(3, 128),
            *conv_lif(128, 128),
            nn.AvgPool2d(2, 2),  # 128 * 16 * 16
            *conv_lif(128, 256),
            *conv_lif(256, 256),
            nn.AvgPool2d(2, 2),  # 256 * 8 * 8
            *conv_lif(256, 256),
            *conv_lif(256, 256),
            *conv_lif(256, 256),
            nn.AvgPool2d(2, 2),  # 256 * 4 * 4
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * 4 * 4, 1024, bias=True),
            lif(),
            layer.Dropout(0.2),
            nn.Linear(1024, 10, bias=True),
            lif(),
        )

    def forward(self, x):
        """Average the classifier output over the first ``self.T`` steps of ``x``.

        ``x[t]`` is the frame for step ``t``; presumably (T, batch, 3, 32, 32)
        — TODO confirm. Neuron state is not reset here.
        """
        rate = self.fc(self.conv(x[0]))
        for t in range(1, self.T):
            rate += self.fc(self.conv(x[t]))
        return rate / self.T

    
class SensorNetL9BN(nn.Module):
    """Spiking 9-layer classifier with batch normalisation and a linear readout.

    Six 3x3 conv -> ``BatchNorm2d`` -> ``neuron.GeneralLIFNode`` stages
    (3 -> 128 -> 256 -> 256 | 256 -> 256 -> 256 channels) with a 2x2 max pool
    after the 3rd and 6th stages. Then a classifier
    dropout -> ``Linear(256 * 8 * 8, 2048)`` -> LIF -> dropout ->
    ``Linear(2048, 100)`` -> LIF -> dropout -> ``Linear(100, 10)``; all
    dropouts are ``layer.Dropout(0.25)``.

    The final layer is a plain ``nn.Linear`` with no spiking node, so
    ``forward`` returns the time-average of its real-valued outputs. Two
    halving pools down to 8x8 mean the network expects 3-channel 32x32
    inputs. The same ``BatchNorm2d`` modules are applied at every time step.

    Args:
        k, lam: forwarded unchanged to every ``neuron.GeneralLIFNode``.
        T: number of time steps ``forward`` reads from its input.
        v_threshold, v_reset: forwarded to every ``neuron.GeneralLIFNode``.
        grad: surrogate-gradient class, instantiated per neuron node.
        grad_kargs: JSON object string of keyword arguments for ``grad``.
        **kwargs: accepted and ignored.
    """

    def __init__(self, k, lam, T, v_threshold=1.0, v_reset=0.0, grad=surrogate.Sigmoid, grad_kargs='{}', **kwargs):
        super().__init__()
        self.grad = grad
        self.grad_kargs = json.loads(grad_kargs)
        self.T = T

        def lif():
            return neuron.GeneralLIFNode(k=k, lam=lam, v_threshold=v_threshold, v_reset=v_reset,
                                         surrogate_function=self.grad(**self.grad_kargs))

        def conv_bn_lif(in_ch, out_ch):
            # conv -> BN -> LIF, built in that order to keep init and indices stable.
            return (nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=True),
                    nn.BatchNorm2d(out_ch),
                    lif())

        self.conv = nn.Sequential(
            *conv_bn_lif(3, 128),
            *conv_bn_lif(128, 256),
            *conv_bn_lif(256, 256),
            nn.MaxPool2d(2, 2),  # 256 * 16 * 16
            *conv_bn_lif(256, 256),
            *conv_bn_lif(256, 256),
            *conv_bn_lif(256, 256),
            nn.MaxPool2d(2, 2),  # 256 * 8 * 8
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            layer.Dropout(0.25),
            nn.Linear(256 * 8 * 8, 2048, bias=True),
            lif(),
            layer.Dropout(0.25),
            nn.Linear(2048, 100, bias=True),
            lif(),
            layer.Dropout(0.25),
            nn.Linear(100, 10, bias=True),
        )

    def forward(self, x):
        """Average the linear readout over the first ``self.T`` steps of ``x``.

        Neuron state is not reset here.
        """
        acc = self.fc(self.conv(x[0]))
        for step in range(1, self.T):
            acc += self.fc(self.conv(x[step]))
        return acc / self.T