import torch
import torch.nn as nn
import surrogate
import neuron
import json
from models.basic_model import *
from models.Cifar10Net import *
from models.Cifar10DvsNet import *

# Input channel count for each supported dataset.
# Keys are lower-case dataset names.
IN_CHANNELS = dict(
    mnist=1,
    cifar10=3,
    nmnist=2,
    cifar10dvs=2,
)

# Number of output classes for each supported dataset (all currently 10).
NUM_CLASSES = dict.fromkeys(('mnist', 'cifar10', 'nmnist', 'cifar10dvs'), 10)


class GeneralLIFModel(nn.Module):
    """Wrap a conventional network so it runs as a spiking network over ``T`` steps.

    On construction, ``model`` is modified in place:
      * every submodule whose class is named ``ReLU`` (at any depth) is replaced
        by a ``neuron.GeneralLIFNode``;
      * the first ``Conv2d`` found in a depth-first walk is rebuilt so that its
        ``in_channels`` equals ``IN_CHANNELS[dataset.lower()]``.

    ``forward`` feeds one time slice per step through the converted model and
    returns the outputs averaged over ``T``.

    Args:
        model: The ``nn.Module`` to convert (mutated in place).
        dataset: Dataset name, looked up lower-cased in ``IN_CHANNELS``.
        k, lam: Forwarded to every created ``neuron.GeneralLIFNode``.
        T: Number of time steps ``forward`` iterates over.
        v_threshold, v_reset: Forwarded to every created ``neuron.GeneralLIFNode``.
        grad: Surrogate-gradient class; one instance is created per LIF node.
        grad_kargs: JSON object string of keyword arguments passed to ``grad``.
    """

    def __init__(self, model, dataset,
                 k, lam, T, v_threshold=1.0, v_reset=0.0, grad=surrogate.Sigmoid, grad_kargs='{}'):
        super(GeneralLIFModel, self).__init__()
        self.k = k
        self.lam = lam
        self.T = T
        self.v_threshold = v_threshold
        self.v_reset = v_reset
        self.grad = grad
        self.grad_kargs = grad_kargs
        self.model = model
        self.dataset = dataset
        self.model = self.convert_to_lif(self.model, k, lam, T, v_threshold, v_reset, grad, grad_kargs)
        # convert_to_lif returns the same (mutated) object, so adapt that one.
        self.adapt_first_conv_input_channel(self.model, dataset)

    def convert_to_lif(self, model, k, lam, T, v_threshold=1.0, v_reset=0.0, grad=surrogate.Sigmoid, grad_kargs='{}'):
        """Recursively replace ``ReLU`` submodules of ``model`` with LIF nodes.

        Args:
            model: Module whose children are rewritten in place.
            k, lam, v_threshold, v_reset: Forwarded to ``neuron.GeneralLIFNode``.
            T: Only forwarded to the recursive calls; not used by the conversion.
            grad: Surrogate-gradient class; ``grad(**json.loads(grad_kargs))``
                is instantiated separately for each created node.
            grad_kargs: JSON object string of keyword arguments for ``grad``.

        Returns:
            The same ``model`` object, after conversion.
        """
        grad_kwargs = json.loads(grad_kargs)
        # Iterate over a snapshot so that replacing children cannot disturb iteration.
        for name, module in list(model._modules.items()):
            if hasattr(module, "_modules"):
                # Forward the caller's threshold/reset. They were previously
                # hard-coded to 1.0/0.0, which silently ignored custom values
                # for every ReLU nested inside a container.
                model._modules[name] = self.convert_to_lif(module, k=k, lam=lam, T=T,
                                                           v_threshold=v_threshold, v_reset=v_reset,
                                                           grad=grad, grad_kargs=grad_kargs)
            # Match by class name, so only modules whose class is literally
            # called 'ReLU' are converted (subclasses are not).
            if module.__class__.__name__ == 'ReLU':
                model._modules[name] = neuron.GeneralLIFNode(k=k, lam=lam,
                                                             v_threshold=v_threshold,
                                                             v_reset=v_reset,
                                                             surrogate_function=grad(**grad_kwargs))
        return model

    def adapt_first_conv_input_channel(self, model, dataset):
        """Rebuild the first ``Conv2d`` (depth-first order) to accept the dataset's channels.

        The replacement copies every hyper-parameter of the original layer except
        ``in_channels``, which becomes ``IN_CHANNELS[dataset.lower()]``. Its
        weights are freshly initialised; the original layer's weights are
        discarded.

        NOTE(review): the new layer is created with PyTorch's default device and
        dtype, so this should run before the model is moved to a device.

        Returns:
            True if a ``Conv2d`` was found and replaced under ``model``,
            False otherwise.

        Raises:
            KeyError: if a ``Conv2d`` is found and ``dataset`` is not a key of
                ``IN_CHANNELS``.
        """
        for name, module in list(model._modules.items()):
            if module.__class__.__name__ == 'Conv2d':
                model._modules[name] = nn.Conv2d(
                    in_channels=IN_CHANNELS[dataset.lower()],
                    out_channels=module.out_channels,
                    kernel_size=module.kernel_size,
                    stride=module.stride,
                    padding=module.padding,
                    dilation=module.dilation,
                    groups=module.groups,
                    bias=(module.bias is not None),
                    padding_mode=module.padding_mode
                )
                return True

            # Propagate success up from any depth. The previous bare ``return``
            # yielded None, so once the conv was nested two or more levels deep
            # the outer level kept searching and replaced a second Conv2d.
            if hasattr(module, "_modules") and self.adapt_first_conv_input_channel(module, dataset):
                return True
        return False

    def forward(self, x):
        """Run the converted model once per time step and average the outputs.

        ``x`` is indexed along its first dimension as ``x[0] .. x[T-1]``.
        Presumably its shape is ``(T, batch, ...)``; confirm against the data
        pipeline.

        NOTE(review): nothing here resets ``neuron.GeneralLIFNode`` state between
        calls. If those nodes are stateful, the caller must reset them between
        batches — confirm.

        Returns:
            Sum of the per-step model outputs divided by ``self.T``.
        """
        out_spikes_counter = self.model(x[0])
        for t in range(1, self.T):
            # Out-of-place add: an in-place ``+=`` would overwrite the model's
            # output tensor, which autograd may have saved for the backward pass.
            out_spikes_counter = out_spikes_counter + self.model(x[t])
        return out_spikes_counter / self.T
