import torch
import torch.nn as nn
import surrogate
import neuron
import json


class MNIST_5layer(nn.Module):
    """Five-layer spiking CNN sized for 1-channel 28x28 frames.

    Layout: two (Conv2d -> GeneralLIFNode -> MaxPool2d) stages, then three
    (Linear -> GeneralLIFNode) stages ending in 10 outputs. ``forward``
    applies the same network once per time step and returns the sum of the
    per-step outputs divided by ``T``.

    Args:
        k, lam, v_threshold, v_reset: passed unchanged to every
            ``neuron.GeneralLIFNode``.
        T: number of time steps consumed by ``forward``.
        grad: surrogate-gradient class; every neuron layer gets its own instance.
        grad_kargs: JSON object string decoded into keyword arguments for ``grad``.
    """

    def __init__(self, k, lam, T, v_threshold=1.0, v_reset=0.0, grad=surrogate.Sigmoid, grad_kargs='{}'):
        super().__init__()
        self.grad = grad
        self.grad_kargs = json.loads(grad_kargs)
        self.T = T

        def spiking():
            # Fresh neuron layer with its own surrogate-function instance.
            return neuron.GeneralLIFNode(k=k, lam=lam, v_threshold=v_threshold, v_reset=v_reset,
                                         surrogate_function=self.grad(**self.grad_kargs))

        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=True),
            spiking(),
            nn.MaxPool2d(2, 2),  # -> 32 x 14 x 14
            nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=True),
            spiking(),
            nn.MaxPool2d(2, 2),  # -> 32 x 7 x 7
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * 7 * 7, 500, bias=True),
            spiking(),
            nn.Linear(500, 100, bias=True),
            spiking(),
            nn.Linear(100, 10, bias=True),
            spiking(),
        )

    def forward(self, x):
        """Return the network output summed over ``self.T`` steps, divided by ``self.T``.

        ``x[t]`` is the input for step ``t``. Step 0 is always evaluated;
        steps ``1 .. self.T - 1`` are accumulated in place onto it.
        """
        spike_total = self.fc(self.conv(x[0]))
        for step in range(1, self.T):
            spike_total.add_(self.fc(self.conv(x[step])))
        return spike_total / self.T


class NMNIST_5layer(nn.Module):
    """Five-layer spiking CNN sized for 2-channel 34x34 frames.

    Layout: two (Conv2d -> GeneralLIFNode -> MaxPool2d) stages
    (34 -> 17 -> 8 spatially), then three (Linear -> GeneralLIFNode) stages
    ending in 10 outputs. ``forward`` runs the network once per time step
    and returns the summed output divided by ``T``.

    Args:
        k, lam, v_threshold, v_reset: passed unchanged to every
            ``neuron.GeneralLIFNode``.
        T: number of time steps consumed by ``forward``.
        grad: surrogate-gradient class; every neuron layer gets its own instance.
        grad_kargs: JSON object string decoded into keyword arguments for ``grad``.
    """

    def __init__(self, k, lam, T, v_threshold=1.0, v_reset=0.0, grad=surrogate.Sigmoid, grad_kargs='{}'):
        super().__init__()
        self.grad = grad
        self.grad_kargs = json.loads(grad_kargs)
        self.T = T

        def lif_layer():
            # Independent neuron layer; the surrogate object is never shared.
            surrogate_fn = self.grad(**self.grad_kargs)
            return neuron.GeneralLIFNode(k=k, lam=lam, v_threshold=v_threshold, v_reset=v_reset,
                                         surrogate_function=surrogate_fn)

        self.conv = nn.Sequential(
            nn.Conv2d(2, 32, kernel_size=3, padding=1, bias=True),
            lif_layer(),
            nn.MaxPool2d(2, 2),  # -> 32 x 17 x 17
            nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=True),
            lif_layer(),
            nn.MaxPool2d(2, 2),  # -> 32 x 8 x 8 (floor of 17 / 2)
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * 8 * 8, 500, bias=True),
            lif_layer(),
            nn.Linear(500, 100, bias=True),
            lif_layer(),
            nn.Linear(100, 10, bias=True),
            lif_layer(),
        )

    def forward(self, x):
        """Return the per-step outputs summed over ``self.T`` steps, divided by ``self.T``.

        ``x[t]`` is the input for step ``t``; step 0 is always evaluated.
        """
        accumulated = self.fc(self.conv(x[0]))
        for t_idx in range(1, self.T):
            accumulated.add_(self.fc(self.conv(x[t_idx])))
        return accumulated / self.T


class CIFAR10_5layer(nn.Module):
    """Five-layer spiking CNN sized for 3-channel 32x32 frames.

    Layout: two (Conv2d -> GeneralLIFNode -> MaxPool2d) stages
    (32 -> 16 -> 8 spatially), then three (Linear -> GeneralLIFNode) stages
    ending in 10 outputs. ``forward`` runs the network once per time step
    and returns the summed output divided by ``T``.

    Args:
        k, lam, v_threshold, v_reset: passed unchanged to every
            ``neuron.GeneralLIFNode``.
        T: number of time steps consumed by ``forward``.
        grad: surrogate-gradient class; every neuron layer gets its own instance.
        grad_kargs: JSON object string decoded into keyword arguments for ``grad``.
    """

    def __init__(self, k, lam, T, v_threshold=1.0, v_reset=0.0, grad=surrogate.Sigmoid, grad_kargs='{}'):
        super().__init__()
        self.grad = grad
        self.grad_kargs = json.loads(grad_kargs)
        self.T = T

        def spiking():
            # Fresh neuron layer with its own surrogate-function instance.
            return neuron.GeneralLIFNode(k=k, lam=lam, v_threshold=v_threshold, v_reset=v_reset,
                                         surrogate_function=self.grad(**self.grad_kargs))

        conv_layers = []
        for c_in, c_out in ((3, 32), (32, 32)):
            # Each stage halves H and W: 32x32 -> 16x16 -> 8x8.
            conv_layers += [nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=True),
                            spiking(),
                            nn.MaxPool2d(2, 2)]
        self.conv = nn.Sequential(*conv_layers)

        fc_layers = [nn.Flatten()]
        for f_in, f_out in ((32 * 8 * 8, 500), (500, 100), (100, 10)):
            fc_layers += [nn.Linear(f_in, f_out, bias=True), spiking()]
        self.fc = nn.Sequential(*fc_layers)

    def forward(self, x):
        """Return the per-step outputs summed over ``self.T`` steps, divided by ``self.T``.

        ``x[t]`` is the input for step ``t``; step 0 is always evaluated.
        """
        spike_total = self.fc(self.conv(x[0]))
        for step in range(1, self.T):
            spike_total.add_(self.fc(self.conv(x[step])))
        return spike_total / self.T


class CIFAR10_128channel(nn.Module):
    """Spiking CNN for 3-channel 32x32 frames with conv widths 64/128/128/128.

    Layout: four (Conv2d -> GeneralLIFNode -> MaxPool2d) stages
    (32 -> 16 -> 8 -> 4 -> 2 spatially), then three
    (Linear -> GeneralLIFNode) stages ending in 10 outputs. ``forward``
    runs the network once per time step and returns the summed output
    divided by ``T``.

    Args:
        k, lam, v_threshold, v_reset: passed unchanged to every
            ``neuron.GeneralLIFNode``.
        T: number of time steps consumed by ``forward``.
        grad: surrogate-gradient class; every neuron layer gets its own instance.
        grad_kargs: JSON object string decoded into keyword arguments for ``grad``.
    """

    def __init__(self, k, lam, T, v_threshold=1.0, v_reset=0.0, grad=surrogate.Sigmoid, grad_kargs='{}'):
        super().__init__()
        self.grad = grad
        self.grad_kargs = json.loads(grad_kargs)
        self.T = T

        def spiking():
            # Fresh neuron layer with its own surrogate-function instance.
            return neuron.GeneralLIFNode(k=k, lam=lam, v_threshold=v_threshold, v_reset=v_reset,
                                         surrogate_function=self.grad(**self.grad_kargs))

        # Stage outputs: 64x16x16, 128x8x8, 128x4x4, 128x2x2.
        conv_layers = []
        for c_in, c_out in ((3, 64), (64, 128), (128, 128), (128, 128)):
            conv_layers += [nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=True),
                            spiking(),
                            nn.MaxPool2d(2, 2)]
        self.conv = nn.Sequential(*conv_layers)

        fc_layers = [nn.Flatten()]
        for f_in, f_out in ((128 * 2 * 2, 500), (500, 100), (100, 10)):
            fc_layers += [nn.Linear(f_in, f_out, bias=True), spiking()]
        self.fc = nn.Sequential(*fc_layers)

    def forward(self, x):
        """Return the per-step outputs summed over ``self.T`` steps, divided by ``self.T``.

        ``x[t]`` is the input for step ``t``; step 0 is always evaluated.
        """
        spike_total = self.fc(self.conv(x[0]))
        for step in range(1, self.T):
            spike_total.add_(self.fc(self.conv(x[step])))
        return spike_total / self.T


class CIFAR10_256channel(nn.Module):
    """Spiking CNN for 3-channel 32x32 frames with conv widths 128/256/256/256.

    Layout: four (Conv2d -> GeneralLIFNode -> MaxPool2d) stages
    (32 -> 16 -> 8 -> 4 -> 2 spatially), then three
    (Linear -> GeneralLIFNode) stages ending in 10 outputs. ``forward``
    runs the network once per time step and returns the summed output
    divided by ``T``.

    Args:
        k, lam, v_threshold, v_reset: passed unchanged to every
            ``neuron.GeneralLIFNode``.
        T: number of time steps consumed by ``forward``.
        grad: surrogate-gradient class; every neuron layer gets its own instance.
        grad_kargs: JSON object string decoded into keyword arguments for ``grad``.
    """

    def __init__(self, k, lam, T, v_threshold=1.0, v_reset=0.0, grad=surrogate.Sigmoid, grad_kargs='{}'):
        super().__init__()
        self.grad = grad
        self.grad_kargs = json.loads(grad_kargs)
        self.T = T

        def spiking():
            # Fresh neuron layer with its own surrogate-function instance.
            return neuron.GeneralLIFNode(k=k, lam=lam, v_threshold=v_threshold, v_reset=v_reset,
                                         surrogate_function=self.grad(**self.grad_kargs))

        # Stage outputs: 128x16x16, 256x8x8, 256x4x4, 256x2x2.
        conv_layers = []
        for c_in, c_out in ((3, 128), (128, 256), (256, 256), (256, 256)):
            conv_layers += [nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=True),
                            spiking(),
                            nn.MaxPool2d(2, 2)]
        self.conv = nn.Sequential(*conv_layers)

        fc_layers = [nn.Flatten()]
        for f_in, f_out in ((256 * 2 * 2, 500), (500, 100), (100, 10)):
            fc_layers += [nn.Linear(f_in, f_out, bias=True), spiking()]
        self.fc = nn.Sequential(*fc_layers)

    def forward(self, x):
        """Return the per-step outputs summed over ``self.T`` steps, divided by ``self.T``.

        ``x[t]`` is the input for step ``t``; step 0 is always evaluated.
        """
        spike_total = self.fc(self.conv(x[0]))
        for step in range(1, self.T):
            spike_total.add_(self.fc(self.conv(x[step])))
        return spike_total / self.T


class CIFAR10_512channel(nn.Module):
    """Spiking CNN for 3-channel 32x32 frames with conv widths 128/256/512/512.

    Layout: four (Conv2d -> GeneralLIFNode -> MaxPool2d) stages
    (32 -> 16 -> 8 -> 4 -> 2 spatially), then three
    (Linear -> GeneralLIFNode) stages ending in 10 outputs. ``forward``
    runs the network once per time step and returns the summed output
    divided by ``T``.

    Args:
        k, lam, v_threshold, v_reset: passed unchanged to every
            ``neuron.GeneralLIFNode``.
        T: number of time steps consumed by ``forward``.
        grad: surrogate-gradient class; every neuron layer gets its own instance.
        grad_kargs: JSON object string decoded into keyword arguments for ``grad``.
    """

    def __init__(self, k, lam, T, v_threshold=1.0, v_reset=0.0, grad=surrogate.Sigmoid, grad_kargs='{}'):
        super().__init__()
        self.grad = grad
        self.grad_kargs = json.loads(grad_kargs)
        self.T = T

        def spiking():
            # Fresh neuron layer with its own surrogate-function instance.
            return neuron.GeneralLIFNode(k=k, lam=lam, v_threshold=v_threshold, v_reset=v_reset,
                                         surrogate_function=self.grad(**self.grad_kargs))

        self.conv = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, padding=1, bias=True),
            spiking(),
            nn.MaxPool2d(2, 2),  # -> 128 x 16 x 16

            nn.Conv2d(128, 256, kernel_size=3, padding=1, bias=True),
            spiking(),
            nn.MaxPool2d(2, 2),  # -> 256 x 8 x 8

            # Must consume the 256 channels produced above; the previous
            # Conv2d(512, 512) raised a channel-mismatch error on every forward.
            nn.Conv2d(256, 512, kernel_size=3, padding=1, bias=True),
            spiking(),
            nn.MaxPool2d(2, 2),  # -> 512 x 4 x 4

            nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True),
            spiking(),
            nn.MaxPool2d(2, 2),  # -> 512 x 2 x 2
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            # Flattened conv output is 512 * 2 * 2 (was wrongly 128 * 2 * 2).
            nn.Linear(512 * 2 * 2, 500, bias=True),
            spiking(),
            nn.Linear(500, 100, bias=True),
            spiking(),
            nn.Linear(100, 10, bias=True),
            spiking(),
        )

    def forward(self, x):
        """Return the per-step outputs summed over ``self.T`` steps, divided by ``self.T``.

        ``x[t]`` is the input for step ``t``; step 0 is always evaluated.
        """
        out_spikes_counter = self.fc(self.conv(x[0]))
        for t in range(1, self.T):
            out_spikes_counter += self.fc(self.conv(x[t]))
        return out_spikes_counter / self.T


class CIFAR10_512channel_L9(nn.Module):
    """Nine-weight-layer spiking CNN (six conv + three linear) for 3-channel 32x32 frames.

    Conv trunk (spatial size 32 -> 16 -> 8 -> 4 -> 2):
    Conv(3->128) pool, Conv(128->512), Conv(512->512) pool,
    Conv(512->512), Conv(512->512) pool, Conv(512->512) pool;
    every conv is followed by a GeneralLIFNode.
    Head: Flatten -> Linear(2048->512) LIF -> Linear(512->128) LIF -> Linear(128->10).

    Unlike the other models in this file, the last Linear has no spiking
    neuron after it. ``forward`` therefore returns the mean of the real-valued
    linear outputs over ``T`` steps rather than a spike rate.

    Args:
        k, lam, v_threshold, v_reset: passed unchanged to every
            ``neuron.GeneralLIFNode``.
        T: number of time steps consumed by ``forward``.
        grad: surrogate-gradient class; every neuron layer gets its own instance.
        grad_kargs: JSON object string decoded into keyword arguments for ``grad``.
        batchnorm: NOTE(review): currently accepted but ignored; no
            normalization layers are built regardless of its value.
    """

    def __init__(self, k, lam, T, v_threshold=1.0, v_reset=0.0, grad=surrogate.Sigmoid, grad_kargs='{}',
                 batchnorm=False):
        super().__init__()
        self.grad = grad
        self.grad_kargs = json.loads(grad_kargs)
        self.T = T

        def spiking():
            # Fresh neuron layer with its own surrogate-function instance.
            return neuron.GeneralLIFNode(k=k, lam=lam, v_threshold=v_threshold, v_reset=v_reset,
                                         surrogate_function=self.grad(**self.grad_kargs))

        self.conv = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, padding=1, bias=True),
            spiking(),
            nn.MaxPool2d(2, 2),  # -> 128 x 16 x 16

            nn.Conv2d(128, 512, kernel_size=3, padding=1, bias=True),
            spiking(),

            nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True),
            spiking(),
            nn.MaxPool2d(2, 2),  # -> 512 x 8 x 8

            nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True),
            spiking(),
            nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True),
            spiking(),
            nn.MaxPool2d(2, 2),  # -> 512 x 4 x 4

            nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True),
            spiking(),
            nn.MaxPool2d(2, 2),  # -> 512 x 2 x 2
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512 * 2 * 2, 512, bias=True),
            spiking(),
            nn.Linear(512, 128, bias=True),
            spiking(),
            # Plain linear readout: intentionally no spiking neuron here.
            nn.Linear(128, 10, bias=True),
        )

    def forward(self, x):
        """Return the readout summed over ``self.T`` steps, divided by ``self.T``.

        ``x[t]`` is the input for step ``t``; step 0 is always evaluated.
        """
        out_spikes_counter = self.fc(self.conv(x[0]))
        for t in range(1, self.T):
            out_spikes_counter += self.fc(self.conv(x[t]))
        return out_spikes_counter / self.T


class CIFAR10DVS_5layer(nn.Module):
    """Five-layer spiking CNN for 2-channel event frames.

    Each conv stage is Conv2d(stride 2) -> GeneralLIFNode -> MaxPool2d(2, 2),
    so each stage shrinks H and W by 4x. The ``Linear(32 * 8 * 8, ...)``
    head therefore presumably expects 128x128 input frames (128 -> 32 -> 8);
    TODO confirm against the data loader. Three (Linear -> GeneralLIFNode)
    stages end in 10 outputs. ``forward`` runs the network once per time
    step and returns the summed output divided by ``T``.

    Args:
        k, lam, v_threshold, v_reset: passed unchanged to every
            ``neuron.GeneralLIFNode``.
        T: number of time steps consumed by ``forward``.
        grad: surrogate-gradient class; every neuron layer gets its own instance.
        grad_kargs: JSON object string decoded into keyword arguments for ``grad``.
    """

    def __init__(self, k, lam, T, v_threshold=1.0, v_reset=0.0, grad=surrogate.Sigmoid, grad_kargs='{}'):
        super().__init__()
        self.grad = grad
        self.grad_kargs = json.loads(grad_kargs)
        self.T = T

        def spiking():
            # Fresh neuron layer with its own surrogate-function instance.
            return neuron.GeneralLIFNode(k=k, lam=lam, v_threshold=v_threshold, v_reset=v_reset,
                                         surrogate_function=self.grad(**self.grad_kargs))

        conv_layers = []
        for c_in, c_out in ((2, 32), (32, 32)):
            conv_layers += [nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1, bias=True),
                            spiking(),
                            nn.MaxPool2d(2, 2)]
        self.conv = nn.Sequential(*conv_layers)

        fc_layers = [nn.Flatten()]
        for f_in, f_out in ((32 * 8 * 8, 1000), (1000, 500), (500, 10)):
            fc_layers += [nn.Linear(f_in, f_out, bias=True), spiking()]
        self.fc = nn.Sequential(*fc_layers)

    def forward(self, x):
        """Return the per-step outputs summed over ``self.T`` steps, divided by ``self.T``.

        ``x[t]`` is the input for step ``t``; step 0 is always evaluated.
        """
        spike_total = self.fc(self.conv(x[0]))
        for step in range(1, self.T):
            spike_total.add_(self.fc(self.conv(x[step])))
        return spike_total / self.T


if __name__ == "__main__":
    # Smoke test: build the 128-channel model for T=10 steps and push a random
    # input laid out as (T, batch, C, H, W) = (10, 64, 3, 32, 32) through it.
    demo_net = CIFAR10_128channel(k=0.2, lam=0.8, T=10)
    demo_input = torch.rand(10, 64, 3, 32, 32)
    demo_output = demo_net(demo_input)
    print(demo_output.shape)
