import torch
import torch.nn as nn
import surrogate
import neuron
import json
from spikingjelly.clock_driven import layer
# from spikingjelly.clock_driven import functional, neuron
from torchvision.models.utils import load_state_dict_from_url

# Public API of this module: the model class plus one factory per depth.
__all__ = [
    "SpikingVGGLike",
    "spiking_vgglike_5_bn", "spiking_vgglike_6_bn", "spiking_vgglike_7_bn",
    "spiking_vgglike_8_bn", "spiking_vgglike_9_bn", "spiking_vgglike_16_bn",
]


class SpikingVGGLike(nn.Module):
    """VGG-style classifier built from spiking ``neuron.GeneralLIFNode`` layers.

    The modules in :attr:`conv`, :attr:`avgpool` and :attr:`fc` are applied once
    per time step, and the classifier outputs are averaged over ``T`` steps
    (see :meth:`forward`).

    Args:
        cfg: Layer specification for :meth:`make_layers`. Each int ``v`` adds a
            3x3 convolution with ``v`` output channels (plus an optional norm
            layer) followed by a ``GeneralLIFNode``. The string ``'M'`` adds a
            2x2 max-pool with stride 2.
        k: Forwarded to every ``GeneralLIFNode``.
        lam: Forwarded to every ``GeneralLIFNode``.
        T: Number of time steps :meth:`forward` reads from its input.
        v_threshold: Forwarded to every ``GeneralLIFNode``.
        v_reset: Forwarded to every ``GeneralLIFNode``.
        grad: Surrogate-gradient class, instantiated as
            ``grad(**json.loads(grad_kargs))``.
        grad_kargs: JSON object string holding keyword arguments for ``grad``.
        batch_norm: If true, insert ``norm_layer(v)`` after every convolution.
        norm_layer: Norm-layer factory. ``None`` means ``nn.BatchNorm2d``.
        num_classes: Output width of the final ``nn.Linear``.
        in_channels: Channel count of each per-step input frame (default 3).
    """

    def __init__(self, cfg, k, lam, T, v_threshold=1.0, v_reset=0.0, grad=surrogate.Sigmoid, grad_kargs='{}',
                 batch_norm=False, norm_layer=None, num_classes=10, in_channels=3):
        super(SpikingVGGLike, self).__init__()
        self.k = k
        self.lam = lam
        self.T = T
        self.v_threshold = v_threshold
        self.v_reset = v_reset
        self.grad = grad
        self.grad_kargs = json.loads(grad_kargs)

        # One surrogate-function instance is shared by all convolutional neurons.
        # Each classifier neuron below gets its own instance.
        self.conv = self.make_layers(cfg=cfg, k=k, lam=lam, v_threshold=v_threshold, v_reset=v_reset,
                                     batch_norm=batch_norm, norm_layer=norm_layer,
                                     surrogate_function=self.grad(**self.grad_kargs),
                                     in_channels=in_channels)
        self.avgpool = nn.AdaptiveAvgPool2d((3, 3))
        # The classifier input width follows the last convolution in ``cfg``
        # (512 for every entry of ``cfgs``) rather than being hard-coded.
        feature_channels = self._feature_channels(cfg, in_channels)
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(feature_channels * 3 * 3, 1024, bias=True),
            neuron.GeneralLIFNode(k=k, lam=lam, v_threshold=v_threshold, v_reset=v_reset,
                                  surrogate_function=self.grad(**self.grad_kargs)),
            layer.Dropout(0.5),
            nn.Linear(1024, 1024, bias=True),
            neuron.GeneralLIFNode(k=k, lam=lam, v_threshold=v_threshold, v_reset=v_reset,
                                  surrogate_function=self.grad(**self.grad_kargs)),
            layer.Dropout(0.5),
            # No spiking neuron after the last layer: the outputs are real-valued.
            nn.Linear(1024, num_classes, bias=True),
        )

    def forward(self, x):
        """Apply the network to each of the first ``T`` time steps of ``x`` and average.

        Args:
            x: Indexable by time step. ``x[t]`` is fed to :attr:`conv` for each
                ``t`` in ``range(T)``, so ``x`` needs at least ``T`` entries along
                its first axis. Presumably a tensor of shape ``(T, N, C, H, W)``
                -- TODO confirm against the data pipeline.

        Returns:
            The mean of the :attr:`fc` outputs over the ``T`` time steps. The final
            layer of :attr:`fc` is a plain ``nn.Linear``, so this is an average of
            real-valued outputs, not a spike count.
        """
        # NOTE(review): nothing here resets neuron state between calls. If
        # GeneralLIFNode keeps membrane state, the caller must reset it per sample.
        out_sum = self.fc(self.avgpool(self.conv(x[0])))
        for t in range(1, self.T):
            out_sum += self.fc(self.avgpool(self.conv(x[t])))
        return out_sum / self.T

    @staticmethod
    def _feature_channels(cfg, in_channels=3):
        """Return the channel count of the feature map produced by ``make_layers(cfg, ...)``."""
        conv_widths = [v for v in cfg if v != 'M']
        # Without any convolution the pooled features keep the input channel count.
        return conv_widths[-1] if conv_widths else in_channels

    @staticmethod
    def make_layers(cfg, k, lam, v_threshold, v_reset, surrogate_function, batch_norm=False, norm_layer=None,
                    in_channels=3):
        """Build the convolutional feature extractor described by ``cfg``.

        Args:
            cfg: Iterable of ints (conv output channels) and ``'M'`` (max-pool).
            k, lam, v_threshold, v_reset: Forwarded to every ``GeneralLIFNode``.
            surrogate_function: Instance passed to, and shared by, every neuron.
            batch_norm: If true, insert ``norm_layer(v)`` after each convolution.
            norm_layer: Norm-layer factory. ``None`` means ``nn.BatchNorm2d``.
            in_channels: Channel count of the input to the first convolution.

        Returns:
            An ``nn.Sequential`` with the layers in ``cfg`` order.
        """
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        layers = []
        for v in cfg:
            if v == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            # bias=True is kept even before a norm layer so parameter sets (and
            # existing checkpoints) stay compatible.
            layers.append(nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=True))
            if batch_norm:
                layers.append(norm_layer(v))
            layers.append(neuron.GeneralLIFNode(k=k, lam=lam, v_threshold=v_threshold, v_reset=v_reset,
                                                surrogate_function=surrogate_function))
            in_channels = v
        return nn.Sequential(*layers)


# Layer specifications consumed by ``SpikingVGGLike.make_layers``: an int is a
# 3x3 convolution with that many output channels and "M" is a 2x2 max-pool.
# Trailing comments give the number of convolutions; SpikingVGGLike adds three
# linear layers on top of them.
cfgs = {
    "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],  # 8 conv
    "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],  # 10 conv
    "D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],  # 13 conv
    "E": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M",
          512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],  # 16 conv
    # NOTE(review): despite its name, "TEN" has 8 convolutions (11 weight layers
    # including the classifier) and no factory uses it.
    "TEN": [64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, "M"],  # 8 conv
    "NINE": [64, "M", 128, 256, "M", 512, 512, "M", 512, "M"],  # 6 conv
    "EIGHT": [64, "M", 128, 256, "M", 512, 512, "M"],  # 5 conv
    "SEVEN": [64, "M", 128, 256, "M", 512, "M"],  # 4 conv
    "SIX": [64, "M", 256, "M", 512, "M"],  # 3 conv
    "FIVE": [64, "M", 512, "M"],  # 2 conv
}


# Keyword arguments that the public factories accept through ``**kwargs`` and
# forward to ``_spiking_vgglike``. Any other keyword is still ignored.
_FORWARDED_KWARGS = ('num_classes', 'v_threshold', 'v_reset', 'norm_layer')


def _forwarded_kwargs(kwargs):
    """Return the subset of ``kwargs`` that ``_spiking_vgglike`` understands."""
    return {key: kwargs[key] for key in _FORWARDED_KWARGS if key in kwargs}


def _spiking_vgglike(cfg, k, lam, T, batch_norm, norm_layer: callable = None,
                     v_threshold=1.0, v_reset=0.0,
                     grad=surrogate.Sigmoid, grad_kargs='{}', num_classes=10):
    """Build a ``SpikingVGGLike`` from the named entry of ``cfgs``.

    Args:
        cfg: Key into ``cfgs`` (e.g. ``'FIVE'``). An unknown key raises ``KeyError``.
        k, lam, T, v_threshold, v_reset, grad, grad_kargs: Passed to ``SpikingVGGLike``.
        batch_norm: Whether to insert a norm layer after each convolution.
        norm_layer: Norm-layer factory. Only used when ``batch_norm`` is true.
        num_classes: Output width of the classifier.

    Returns:
        A freshly constructed ``SpikingVGGLike``.
    """
    if not batch_norm:
        # The norm-layer factory is irrelevant without batch norm.
        norm_layer = None
    return SpikingVGGLike(cfg=cfgs[cfg], k=k, lam=lam, T=T,
                          v_threshold=v_threshold, v_reset=v_reset,
                          grad=grad, grad_kargs=grad_kargs,
                          batch_norm=batch_norm, norm_layer=norm_layer,
                          num_classes=num_classes)


def spiking_vgglike_5_bn(k, lam, T, grad=surrogate.Sigmoid, grad_kargs='{}', **kwargs):
    """Batch-normalized SpikingVGGLike on ``cfgs['FIVE']`` (2 conv + 3 linear layers).

    ``num_classes``, ``v_threshold``, ``v_reset`` and ``norm_layer`` may be given
    as keyword arguments. Other keyword arguments are ignored.
    """
    return _spiking_vgglike('FIVE', k=k, lam=lam, T=T, batch_norm=True, grad=grad, grad_kargs=grad_kargs,
                            **_forwarded_kwargs(kwargs))


def spiking_vgglike_6_bn(k, lam, T, grad=surrogate.Sigmoid, grad_kargs='{}', **kwargs):
    """Batch-normalized SpikingVGGLike on ``cfgs['SIX']`` (3 conv + 3 linear layers).

    ``num_classes``, ``v_threshold``, ``v_reset`` and ``norm_layer`` may be given
    as keyword arguments. Other keyword arguments are ignored.
    """
    return _spiking_vgglike('SIX', k=k, lam=lam, T=T, batch_norm=True, grad=grad, grad_kargs=grad_kargs,
                            **_forwarded_kwargs(kwargs))


def spiking_vgglike_7_bn(k, lam, T, grad=surrogate.Sigmoid, grad_kargs='{}', **kwargs):
    """Batch-normalized SpikingVGGLike on ``cfgs['SEVEN']`` (4 conv + 3 linear layers).

    ``num_classes``, ``v_threshold``, ``v_reset`` and ``norm_layer`` may be given
    as keyword arguments. Other keyword arguments are ignored.
    """
    return _spiking_vgglike('SEVEN', k=k, lam=lam, T=T, batch_norm=True, grad=grad, grad_kargs=grad_kargs,
                            **_forwarded_kwargs(kwargs))


def spiking_vgglike_8_bn(k, lam, T, grad=surrogate.Sigmoid, grad_kargs='{}', **kwargs):
    """Batch-normalized SpikingVGGLike on ``cfgs['EIGHT']`` (5 conv + 3 linear layers).

    ``num_classes``, ``v_threshold``, ``v_reset`` and ``norm_layer`` may be given
    as keyword arguments. Other keyword arguments are ignored.
    """
    return _spiking_vgglike('EIGHT', k=k, lam=lam, T=T, batch_norm=True, grad=grad, grad_kargs=grad_kargs,
                            **_forwarded_kwargs(kwargs))


def spiking_vgglike_9_bn(k, lam, T, grad=surrogate.Sigmoid, grad_kargs='{}', **kwargs):
    """Batch-normalized SpikingVGGLike on ``cfgs['NINE']`` (6 conv + 3 linear layers).

    ``num_classes``, ``v_threshold``, ``v_reset`` and ``norm_layer`` may be given
    as keyword arguments. Other keyword arguments are ignored.
    """
    return _spiking_vgglike('NINE', k=k, lam=lam, T=T, batch_norm=True, grad=grad, grad_kargs=grad_kargs,
                            **_forwarded_kwargs(kwargs))


def spiking_vgglike_16_bn(k, lam, T, grad=surrogate.Sigmoid, grad_kargs='{}', **kwargs):
    """Batch-normalized SpikingVGGLike on ``cfgs['E']`` (16 conv + 3 linear layers).

    ``num_classes``, ``v_threshold``, ``v_reset`` and ``norm_layer`` may be given
    as keyword arguments. Other keyword arguments are ignored.
    """
    # NOTE(review): the other factories are named after conv + linear layer count,
    # under which 'E' would be "19" and 'D' (13 conv) would be "16"; torchvision's
    # VGG-16 also uses 'D'. Left as 'E' so existing checkpoints still load -- confirm intent.
    return _spiking_vgglike('E', k=k, lam=lam, T=T, batch_norm=True, grad=grad, grad_kargs=grad_kargs,
                            **_forwarded_kwargs(kwargs))


if __name__ == '__main__':
    # This module only defines the model class and its factory functions;
    # running it as a script intentionally does nothing.
    pass
