import torch
import torch.nn as nn
import surrogate
import neuron
import json
from spikingjelly.clock_driven import layer


class Cifar10NetVgg9(nn.Module):
    """VGG-9-style spiking CNN for 3-channel 32 x 32 inputs.

    Seven 3x3 convolutions with three 2x2 average-pooling stages, then two
    fully connected layers. Every conv/linear layer is followed by a
    ``neuron.GeneralLIFNode``. The first ``nn.Linear`` expects a 256 x 4 x 4
    feature map; since the padded 3x3 convolutions preserve spatial size and
    each pooling halves it, input frames must be 32 x 32.

    Args:
        k: Forwarded unchanged to every ``neuron.GeneralLIFNode``.
        lam: Forwarded unchanged to every ``neuron.GeneralLIFNode``.
        T: Number of time steps ``forward`` iterates over. ``forward`` always
            processes ``x[0]`` and divides by ``T``, so ``T`` must be >= 1.
        v_threshold: Forwarded to every ``GeneralLIFNode``.
        v_reset: Forwarded to every ``GeneralLIFNode``.
        grad: Callable (by default ``surrogate.Sigmoid``) invoked once per
            neuron layer as ``grad(**json.loads(grad_kargs))``; the result is
            that layer's ``surrogate_function``.
        grad_kargs: JSON object string holding keyword arguments for ``grad``.
        **kwargs: Accepted and ignored (presumably so one shared option dict
            can be passed to any model constructor -- confirm with callers).
    """

    def __init__(self, k, lam, T, v_threshold=1.0, v_reset=0.0, grad=surrogate.Sigmoid, grad_kargs='{}', **kwargs):
        super().__init__()
        self.grad = grad
        self.grad_kargs = json.loads(grad_kargs)
        self.T = T

        def lif_node():
            # A fresh neuron per layer, each with its own surrogate-function instance.
            return neuron.GeneralLIFNode(k=k, lam=lam, v_threshold=v_threshold, v_reset=v_reset,
                                         surrogate_function=self.grad(**self.grad_kargs))

        # Module order inside each nn.Sequential determines the state_dict keys
        # (conv.0.weight, conv.2.weight, ...) and the order in which parameters
        # are initialised; keep it stable so existing checkpoints still load.
        self.conv = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, padding=1, bias=True),
            lif_node(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=True),
            lif_node(),
            nn.AvgPool2d(2, 2),  # 128 * 16 * 16

            nn.Conv2d(128, 256, kernel_size=3, padding=1, bias=True),
            lif_node(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=True),
            lif_node(),
            nn.AvgPool2d(2, 2),  # 256 * 8 * 8

            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=True),
            lif_node(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=True),
            lif_node(),

            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=True),
            lif_node(),
            nn.AvgPool2d(2, 2),  # 256 * 4 * 4
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * 4 * 4, 1024, bias=True),
            lif_node(),
            layer.Dropout(0.2),

            nn.Linear(1024, 10, bias=True),
            lif_node(),
        )

    def forward(self, x):
        """Run ``self.T`` time steps and return the per-step mean of the output.

        Args:
            x: Indexable by time step; ``x[t]`` for ``t in range(self.T)`` is
                fed to ``self.conv`` and so must have 3 channels at 32 x 32.
                Presumably shaped (T, N, 3, 32, 32) -- confirm with the data
                pipeline.

        Returns:
            The sum over time of the ``self.fc`` outputs divided by ``self.T``.
            The output has 10 values per sample. If ``GeneralLIFNode`` emits
            0/1 spikes, this is a per-class firing rate.

        Note:
            This method does not reset any neuron or dropout state. If those
            modules are stateful (as spikingjelly modules are), the caller
            presumably resets the network between batches -- confirm.
        """
        out_spikes_counter = self.fc(self.conv(x[0]))
        for t in range(1, self.T):
            # Accumulate out of place: the first step's tensor comes straight
            # from the last GeneralLIFNode and may be referenced elsewhere
            # (kept on the node, captured by a hook, or saved for autograd),
            # so it must not be mutated.
            out_spikes_counter = out_spikes_counter + self.fc(self.conv(x[t]))
        return out_spikes_counter / self.T


class Cifar10NetVgg11(nn.Module):
    """VGG-11-style spiking CNN for 3-channel 32 x 32 inputs.

    Eight 3x3 convolutions with four 2x2 average-pooling stages, then three
    fully connected layers. Every conv/linear layer is followed by a
    ``neuron.GeneralLIFNode``. The first ``nn.Linear`` expects a 512 x 2 x 2
    feature map; since the padded 3x3 convolutions preserve spatial size and
    each pooling halves it, input frames must be 32 x 32.

    Args:
        k: Forwarded unchanged to every ``neuron.GeneralLIFNode``.
        lam: Forwarded unchanged to every ``neuron.GeneralLIFNode``.
        T: Number of time steps ``forward`` iterates over. ``forward`` always
            processes ``x[0]`` and divides by ``T``, so ``T`` must be >= 1.
        v_threshold: Forwarded to every ``GeneralLIFNode``.
        v_reset: Forwarded to every ``GeneralLIFNode``.
        grad: Callable (by default ``surrogate.Sigmoid``) invoked once per
            neuron layer as ``grad(**json.loads(grad_kargs))``; the result is
            that layer's ``surrogate_function``.
        grad_kargs: JSON object string holding keyword arguments for ``grad``.
        **kwargs: Accepted and ignored (presumably so one shared option dict
            can be passed to any model constructor -- confirm with callers).
    """

    def __init__(self, k, lam, T, v_threshold=1.0, v_reset=0.0, grad=surrogate.Sigmoid, grad_kargs='{}', **kwargs):
        super().__init__()
        self.grad = grad
        self.grad_kargs = json.loads(grad_kargs)
        self.T = T

        def lif_node():
            # A fresh neuron per layer, each with its own surrogate-function instance.
            return neuron.GeneralLIFNode(k=k, lam=lam, v_threshold=v_threshold, v_reset=v_reset,
                                         surrogate_function=self.grad(**self.grad_kargs))

        # Module order inside each nn.Sequential determines the state_dict keys
        # (conv.0.weight, conv.2.weight, ...) and the order in which parameters
        # are initialised; keep it stable so existing checkpoints still load.
        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=True),
            lif_node(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=True),
            lif_node(),
            nn.AvgPool2d(2, 2),  # 64 * 16 * 16

            nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=True),
            lif_node(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=True),
            lif_node(),
            nn.AvgPool2d(2, 2),  # 128 * 8 * 8

            nn.Conv2d(128, 256, kernel_size=3, padding=1, bias=True),
            lif_node(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=True),
            lif_node(),
            nn.AvgPool2d(2, 2),  # 256 * 4 * 4

            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=True),
            lif_node(),
            nn.Conv2d(256, 512, kernel_size=3, padding=1, bias=True),
            lif_node(),
            nn.AvgPool2d(2, 2),  # 512 * 2 * 2
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512 * 2 * 2, 1024, bias=True),
            lif_node(),
            layer.Dropout(0.2),

            nn.Linear(1024, 1024, bias=True),
            lif_node(),
            layer.Dropout(0.2),

            nn.Linear(1024, 10, bias=True),
            lif_node(),
        )

    def forward(self, x):
        """Run ``self.T`` time steps and return the per-step mean of the output.

        Args:
            x: Indexable by time step; ``x[t]`` for ``t in range(self.T)`` is
                fed to ``self.conv`` and so must have 3 channels at 32 x 32.
                Presumably shaped (T, N, 3, 32, 32) -- confirm with the data
                pipeline.

        Returns:
            The sum over time of the ``self.fc`` outputs divided by ``self.T``.
            The output has 10 values per sample. If ``GeneralLIFNode`` emits
            0/1 spikes, this is a per-class firing rate.

        Note:
            This method does not reset any neuron or dropout state. If those
            modules are stateful (as spikingjelly modules are), the caller
            presumably resets the network between batches -- confirm.
        """
        out_spikes_counter = self.fc(self.conv(x[0]))
        for t in range(1, self.T):
            # Accumulate out of place: the first step's tensor comes straight
            # from the last GeneralLIFNode and may be referenced elsewhere
            # (kept on the node, captured by a hook, or saved for autograd),
            # so it must not be mutated.
            out_spikes_counter = out_spikes_counter + self.fc(self.conv(x[t]))
        return out_spikes_counter / self.T
