import math
from typing import Any, Callable, Dict, Optional

import numpy as np
import torch
import torch.nn as nn
from spikingjelly.activation_based import layer
from spikingjelly.activation_based import surrogate, neuron

import cpp_neuron

def get_activation(name):
    """Resolve an activation class from a (sub)string identifier.

    Matching is by substring containment, checked in this order:

    * contains ``"bptt"`` and ``"one"`` -> ``BPTTLIFOne``
    * contains ``"bptt"``               -> ``BPTTLIF``
    * contains ``"vanilla"``            -> ``VanillaLIF``

    Args:
        name: Identifier to inspect (anything supporting ``in``).

    Returns:
        The matching activation class (not an instance).

    Raises:
        AssertionError: If none of the known keywords occurs in ``name``.
    """
    if "bptt" in name:
        # "one" only matters once we know this is a BPTT variant.
        return BPTTLIFOne if "one" in name else BPTTLIF
    if "vanilla" in name:
        return VanillaLIF
    raise AssertionError("No such activation!!")

def get_thresolds(model):
    """Collect the firing threshold of every LIF activation in ``model``.

    Walks ``model.modules()`` and, for each ``BPTTLIF``, ``BPTTLIFOne`` or
    ``VanillaLIF`` instance, reads its ``thresh`` attribute. A plain Python
    float is used as-is; anything else is converted with ``.item()``.

    Args:
        model: Object exposing a ``modules()`` iterator (e.g. an ``nn.Module``).

    Returns:
        ``np.ndarray`` with one threshold per matched module, in traversal
        order.
    """
    collected = []
    for module in model.modules():
        if not isinstance(module, (BPTTLIF, BPTTLIFOne, VanillaLIF)):
            continue
        value = module.thresh
        collected.append(value if isinstance(value, float) else value.item())
    return np.array(collected)

def static_firing_rate(model):
    """Gather the last recorded ``r`` statistic of every LIF activation.

    Args:
        model: Object exposing a ``modules()`` iterator (e.g. an ``nn.Module``).

    Returns:
        ``np.ndarray`` holding the ``r`` attribute of each ``BPTTLIF``,
        ``BPTTLIFOne`` or ``VanillaLIF`` module, in traversal order.
    """
    return np.array([
        module.r
        for module in model.modules()
        if isinstance(module, (BPTTLIF, BPTTLIFOne, VanillaLIF))
    ])


def static_spike_count(model):
    """Gather the last recorded ``s`` statistic of every LIF activation.

    Every matched module must expose an ``s`` attribute; a module without
    one makes this function raise ``AttributeError``.

    Args:
        model: Object exposing a ``modules()`` iterator (e.g. an ``nn.Module``).

    Returns:
        ``np.ndarray`` holding the ``s`` attribute of each ``BPTTLIF``,
        ``BPTTLIFOne`` or ``VanillaLIF`` module, in traversal order.
    """
    return np.array([
        module.s
        for module in model.modules()
        if isinstance(module, (BPTTLIF, BPTTLIFOne, VanillaLIF))
    ])

def zif_backward(x, thre):
    """Evaluate the triangular window ``max(0, 1 - |x / thre|)`` element-wise.

    Args:
        x: Tensor of values to evaluate the window at.
        thre: Scale of the window (tensor or number); the window reaches zero
            where ``|x| == |thre|``.

    Returns:
        Tensor with the same shape as ``x / thre`` holding the window values.
    """
    scaled = x / thre
    window = 1. - scaled.abs()
    return window.clamp_min(0)

class ZIF(torch.autograd.Function):
    """Step function with a triangular surrogate gradient.

    Forward emits ``1.0`` where ``input >= 0`` and ``0.0`` elsewhere.
    Backward replaces the step's derivative with
    ``(1 / gama**2) * max(0, gama - |input|)``.
    """

    @staticmethod
    def forward(ctx, input, gama):
        """Return the binary step of ``input``; ``gama`` sets the surrogate width."""
        spikes = input.ge(0).float()
        # gama is stashed as a 1-element tensor so it can travel through
        # save_for_backward together with the other tensors.
        width = torch.tensor([gama])
        ctx.save_for_backward(input, spikes, width)
        return spikes

    @staticmethod
    def backward(ctx, grad_output):
        """Return the surrogate gradient w.r.t. ``input``; ``gama`` gets no gradient."""
        membrane, _spikes, width = ctx.saved_tensors
        gama = width[0].item()
        scale = (1 / gama) * (1 / gama)
        surrogate = scale * ((gama - membrane.abs()).clamp(min=0))
        return grad_output * surrogate, None
    
class Surrogate(torch.autograd.Function):
    """Threshold-scaled spike function with a triangular surrogate gradient.

    Forward emits ``thre`` where ``input >= thre`` and ``0`` elsewhere.
    Backward uses the window ``max(0, 1 - |(input - thre) / thre|)`` for the
    input gradient and also produces a scalar gradient for ``thre``.
    """

    @staticmethod
    def forward(ctx, input, thre, sp):
        """Return ``(input >= thre) * thre``.

        ``thre`` must be a tensor (it is stored via ``save_for_backward``);
        ``sp`` is a plain number that weights the extra threshold term in
        ``backward``.
        """
        fired = input.ge(thre).float()
        ctx.save_for_backward(input, fired, thre, torch.tensor(sp))
        return fired * thre

    @staticmethod
    def backward(ctx, grad_input):
        """Return gradients for ``(input, thre, sp)``; ``sp`` gets ``None``."""
        membrane, fired, thre, sp = ctx.saved_tensors
        # Distance from the threshold, measured in units of the threshold.
        distance = (membrane - thre) / thre
        window = (1.0 - distance.abs()).clamp(min=0)
        grad_membrane = grad_input * window
        # NOTE(review): the window is applied a second time to the already
        # windowed gradient here -- kept exactly as in the original.
        thre_terms = grad_membrane * window + sp * fired * window
        grad_thre = -thre_terms.sum(0).mean()
        return grad_membrane, grad_thre, None
    
class SurrogateOne(torch.autograd.Function):
    """Binary spike function with a triangular surrogate gradient.

    Forward emits ``1.0`` where ``input >= thre`` and ``0.0`` elsewhere (the
    output is not scaled by the threshold). Backward uses the window
    ``max(0, 1 - |input - thre|)`` for the input gradient and also produces
    a scalar gradient for ``thre``.
    """

    @staticmethod
    def forward(ctx, input, thre, sp):
        """Return the 0/1 spike tensor ``input >= thre``.

        ``thre`` must be a tensor (it is stored via ``save_for_backward``);
        ``sp`` is a plain number that weights the extra threshold term in
        ``backward``.
        """
        fired = input.ge(thre).float()
        ctx.save_for_backward(input, fired, thre, torch.tensor(sp))
        return fired

    @staticmethod
    def backward(ctx, grad_input):
        """Return gradients for ``(input, thre, sp)``; ``sp`` gets ``None``."""
        membrane, fired, thre, sp = ctx.saved_tensors
        # Unnormalised distance from the threshold.
        distance = membrane - thre
        window = (1.0 - distance.abs()).clamp(min=0)
        grad_membrane = grad_input * window
        # NOTE(review): the window is applied a second time to the already
        # windowed gradient here -- kept exactly as in the original.
        thre_terms = grad_membrane * window + sp * fired * window
        grad_thre = -thre_terms.sum(0).mean()
        return grad_membrane, grad_thre, None
    
class LIFThreFunction(torch.autograd.Function):
    """Autograd wrapper around the ``cpp_neuron`` LIF kernels, threshold-scaled output.

    Forward and backward are delegated to ``cpp_neuron.forward_with_thre_hard``
    and ``cpp_neuron.backward_with_thre_hard``; their exact semantics live in
    the extension and are not visible from this file.
    """

    @staticmethod
    def forward(ctx, x, thre, T, tau, sp):
        """Run the kernel and return ``(spikes * thre, thre_updates)``.

        ``thre_updates`` is a scalar computed as
        ``(sp * zif_backward(mems, thre) * spikes).sum(1).mean()``.
        """
        spikes, mems = cpp_neuron.forward_with_thre_hard(x, thre, T, tau)
        # Scalars are wrapped as tensors so they can be stored with the rest.
        scalars = [torch.tensor(value) for value in (T, tau, sp)]
        ctx.save_for_backward(spikes, mems, thre, *scalars)
        window = zif_backward(mems, thre)
        thre_updates = (sp * window * spikes).sum(1).mean()
        return spikes * thre, thre_updates

    @staticmethod
    def backward(ctx, grad_output, grad_output2):
        """Return gradients for ``(x, thre, T, tau, sp)``.

        ``grad_output2`` (gradient of ``thre_updates``) is ignored, and the
        kernel's last argument is fixed to ``0.`` rather than the saved ``sp``.
        """
        spikes, mems, thre, T, tau, _sp = ctx.saved_tensors
        kernel_grads = cpp_neuron.backward_with_thre_hard(
            grad_output, mems, spikes, thre, T.item(), tau.item(), 0.
        )
        grad_x, grad_thre = kernel_grads
        return grad_x, grad_thre, None, None, None
    
class LIFVanillaFunction(torch.autograd.Function):
    """Autograd wrapper around the ``cpp_neuron`` LIF kernels with a fixed threshold.

    Unlike ``LIFThreFunction`` the raw spikes are returned unscaled and no
    gradient is propagated to the threshold.
    """

    @staticmethod
    def forward(ctx, x, thre, T, tau):
        """Run ``cpp_neuron.forward_with_thre_hard`` and return its spikes."""
        spikes, mems = cpp_neuron.forward_with_thre_hard(x, thre, T, tau)
        # Scalars are wrapped as tensors so they can be stored with the rest.
        scalars = [torch.tensor(value) for value in (T, tau)]
        ctx.save_for_backward(spikes, mems, thre, *scalars)
        return spikes

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(x, thre, T, tau)``; only ``x`` gets one."""
        spikes, mems, thre, T, tau = ctx.saved_tensors
        kernel_grads = cpp_neuron.backward_with_thre_hard(
            grad_output, mems, spikes, thre, T.item(), tau.item(), 0.
        )
        # The kernel also yields a threshold gradient, which is discarded here.
        grad_x = kernel_grads[0]
        return grad_x, None, None, None
    
class LIFOneFunction(torch.autograd.Function):
    """Autograd wrapper around the ``cpp_neuron`` LIF kernels, unscaled spike output.

    Forward uses ``cpp_neuron.forward_with_thre_hard``; backward uses
    ``cpp_neuron.backward_with_one_hard``. Their exact semantics live in the
    extension and are not visible from this file.
    """

    @staticmethod
    def forward(ctx, x, thre, T, tau, sp):
        """Run the kernel and return ``(spikes, thre_updates)``.

        ``thre_updates`` is a scalar computed as
        ``(sp * zif_backward(mems, thre) * spikes).sum(1).mean()``.
        """
        spikes, mems = cpp_neuron.forward_with_thre_hard(x, thre, T, tau)
        # Scalars are wrapped as tensors so they can be stored with the rest.
        scalars = [torch.tensor(value) for value in (T, tau, sp)]
        ctx.save_for_backward(spikes, mems, thre, *scalars)
        window = zif_backward(mems, thre)
        thre_updates = (sp * window * spikes).sum(1).mean()
        return spikes, thre_updates

    @staticmethod
    def backward(ctx, grad_output, grad_output2):
        """Return gradients for ``(x, thre, T, tau, sp)``.

        ``grad_output2`` (gradient of ``thre_updates``) is ignored, and the
        kernel's last argument is fixed to ``0.`` rather than the saved ``sp``.
        """
        spikes, mems, thre, T, tau, _sp = ctx.saved_tensors
        kernel_grads = cpp_neuron.backward_with_one_hard(
            grad_output, mems, spikes, thre, T.item(), tau.item(), 0.
        )
        grad_x, grad_thre = kernel_grads
        return grad_x, grad_thre, None, None, None
    
# class LIFThreSBPFunction(torch.autograd.Function):
#     @staticmethod
#     def forward(ctx, x, thre, T, sp):
#         spikes, mems = cpp_neuron.forward_with_thre(x, thre, T)
#         ctx.save_for_backward(spikes, mems, thre, torch.tensor(T), torch.tensor(sp))
#         return spikes * thre

#     @staticmethod
#     def backward(ctx, grad_output):
#         (spikes, mems, thre, T, sp) = ctx.saved_tensors
#         grad_x, grad_thre = cpp_neuron.backward_with_thre_sbp(grad_output, mems, spikes, thre, T.item(), sp.item())
#         return grad_x, grad_thre, None, None

# class SBPLIF(nn.Module):
#     def __init__(self, T=4, thresh=1., tau=1., gama=1.0, sp=1e-4):
#         super(SBPLIF, self).__init__()
#         self.act = LIFThreSBPFunction.apply
#         self.thresh = nn.Parameter(torch.tensor(thresh), requires_grad=True)
#         self.tau = tau
#         self.gama = gama
#         self.T = T
#         self.r = 0.
#         self.running_mean = 0.
#         self.running_var = 0.
#         self.sp = sp
#         self.macs = 0.

#     def forward(self, x):
#         x = self.act(x, self.thresh, self.T, self.sp)
#         self.r = (x.mean()/self.thresh).item() # N C H W
#         return x

class BPTTLIF(nn.Module):
    """LIF activation with a learnable threshold and threshold-scaled output.

    The computation is delegated to ``LIFThreFunction``. After every forward
    pass the module exposes a few statistics as plain attributes:

    * ``r``            -- ``mean(output) / thresh`` as a Python float.
    * ``s``            -- ``(output.mean(1) / thresh).sum()`` as a Python float.
    * ``act_value``    -- per-row mean of the output flattened from dim 1,
      summed over dim 0 (a tensor).
    * ``update_value`` -- running sum of the second value returned by
      ``LIFThreFunction``; it is never reset inside this class.

    Args:
        T: Forwarded to ``LIFThreFunction`` (presumably the number of time
            steps -- confirm against the ``cpp_neuron`` kernels).
        thresh: Initial value of the learnable threshold.
        tau: Forwarded to ``LIFThreFunction``.
        gama: Stored on the module; not used inside this class.
        sp: Forwarded to ``LIFThreFunction``.
    """

    def __init__(self, T, thresh=1., tau=1., gama=1.0, sp=1e-4):
        super().__init__()
        self.act = LIFThreFunction.apply

        # Hyper-parameters.
        self.T = T
        self.tau = tau
        self.gama = gama
        self.sp = sp

        # The threshold is a trainable scalar parameter.
        self.thresh = nn.Parameter(torch.tensor(thresh), requires_grad=True)

        # Statistics refreshed by forward() (macs is only initialised here).
        self.r = 0.
        self.s = 0.
        self.macs = 0.
        self.act_value = 0.
        self.update_value = 0.

    def forward(self, x):
        """Apply the LIF activation to ``x`` and refresh the statistics."""
        out, thre_update = self.act(x, self.thresh, self.T, self.tau, self.sp)
        self.update_value += thre_update

        flattened = out.reshape(out.size(0), -1)
        self.act_value = flattened.mean(1).sum()

        # The output is scaled by the threshold, so divide it back out.
        self.r = (out.mean() / self.thresh).item()
        self.s = (out.mean(1) / self.thresh).sum().item()
        return out

# class LIFThreFunctionNew(torch.autograd.Function):
#     @staticmethod
#     def forward(ctx, x, thre, T, tau, sp):
#         spikes, mems = cpp_neuron.forward_with_thre_hard(x, torch.tensor(1.0), T, tau)
#         ctx.save_for_backward(spikes, mems, thre, torch.tensor(T), torch.tensor(tau), torch.tensor(sp))
#         thre_updates = (sp * zif_backward(mems, thre) * spikes).sum(1).mean()
#         return spikes * thre, thre_updates

#     @staticmethod
#     def backward(ctx, grad_output, grad_output2):
#         (spikes, mems, thre, T, tau, sp) = ctx.saved_tensors
#         grad_x, grad_thre = cpp_neuron.backward_with_thre_hard(grad_output, mems, spikes, thre, T.item(), tau.item(), 0.)
#         return grad_x, grad_thre, None, None, None


# class BPTTLIF(nn.Module):
#     def __init__(self, T, thresh=1.0, tau=1., gama=1.0, sp=0.):
#         super(BPTTLIF, self).__init__()
#         self.act = Surrogate.apply
#         self.thresh = nn.Parameter(torch.tensor(thresh), requires_grad=True)
#         self.tau = tau
#         self.gama = gama
#         self.T = T
#         self.r = 0
#         self.act_value = 0.
#         self.sp = sp

#     def forward(self, x):
#         mem = 0.
#         spike_pot = []
#         for t in range(self.T):
#             mem = mem * self.tau + x[t, ...]
#             spike = self.act(mem, self.thresh, self.sp) / self.thresh.detach()
#             mem = (1 - spike) * mem
#             spike_pot.append(spike)
#         x = torch.stack(spike_pot, dim=0)
#         self.r = x.mean().item()
#         self.act_value = x.reshape(x.size(0), -1).mean(1).sum()
#         return x * self.thresh.detach()
    
# class BPTTLIFOne(nn.Module):
#     def __init__(self, T, thresh=1.0, tau=1., gama=1.0, sp=0.):
#         super(BPTTLIFOne, self).__init__()
#         self.act = SurrogateOne.apply
#         self.thresh = nn.Parameter(torch.tensor(thresh), requires_grad=True)
#         self.tau = tau
#         self.gama = gama
#         self.T = T
#         self.r = 0
#         self.act_value = 0.
#         self.sp = sp

#     def forward(self, x):
#         mem = 0.
#         spike_pot = []
#         for t in range(self.T):
#             mem = mem * self.tau + x[t, ...]
#             spike = self.act(mem, self.thresh, self.sp)
#             mem = (1 - spike) * mem
#             spike_pot.append(spike)
#         x = torch.stack(spike_pot, dim=0)
#         self.r = x.mean().item()
#         self.act_value = x.reshape(x.size(0), -1).mean(1).sum()
#         return x

class VanillaLIF(nn.Module):
    """LIF activation with a fixed (non-learnable) threshold.

    The computation is delegated to ``LIFVanillaFunction``. After every
    forward pass the module exposes:

    * ``r``         -- ``mean(output)`` as a Python float.
    * ``s``         -- ``output.mean(1).sum()`` as a Python float.
    * ``act_value`` -- per-row mean of the output flattened from dim 1,
      summed over dim 0 (a tensor).

    Args:
        T: Forwarded to ``LIFVanillaFunction`` (presumably the number of time
            steps -- confirm against the ``cpp_neuron`` kernels).
        thresh: Threshold value; stored as a plain tensor attribute (neither
            a parameter nor a registered buffer).
        tau: Forwarded to ``LIFVanillaFunction``.
        gama: Stored on the module; not used by the active code path.
        sp: Accepted for signature parity with the other LIF modules; unused.
    """

    def __init__(self, T, thresh=1.0, tau=1., gama=1.0, sp=0.):
        super().__init__()
        # Active backend is the C++ kernel; the disabled PyTorch variant in
        # forward() would need ``ZIF.apply`` here instead.
        self.act = LIFVanillaFunction.apply

        # Hyper-parameters.
        self.T = T
        self.tau = tau
        self.gama = gama
        self.thresh = torch.tensor(thresh)

        # Statistics refreshed by forward().
        self.r = 0
        self.s = 0
        self.act_value = 0.

    def forward(self, x):
        """Apply the LIF activation to ``x`` and refresh the statistics."""
        spikes = self.act(x, self.thresh, self.T, self.tau)
        self.r = spikes.mean().item()
        self.s = spikes.mean(1).sum().item()

        # Disabled pure-PyTorch variant left by the original author
        # (requires ``self.act = ZIF.apply``):
        #     mem = 0.
        #     spike_pot = []
        #     for t in range(self.T):
        #         mem = mem * self.tau + x[t, ...]
        #         spike = self.act(mem - self.thresh, self.gama)
        #         mem = (1 - spike) * mem
        #         spike_pot.append(spike)
        #     x = torch.stack(spike_pot, dim=0)
        #     self.r = x.mean().item()

        # l0/l1/l2-style activity statistic.
        flattened = spikes.reshape(spikes.size(0), -1)
        self.act_value = flattened.mean(1).sum()

        # Disabled "hoyer" variant left by the original author:
        #     self.act_value = x.reshape(x.size(0), -1).mean(1).sum().sqrt()
        #     shape_ = x.shape[1]
        #     self.act_value = self.act_value.sum().sqrt() / shape_
        return spikes
    
class BPTTLIFOne(nn.Module):
    """LIF activation with a learnable threshold and unscaled spike output.

    The computation is delegated to ``LIFOneFunction``. After every forward
    pass the module exposes a few statistics as plain attributes:

    * ``r``            -- ``mean(output)`` as a Python float.
    * ``s``            -- ``output.mean(1).sum()`` as a Python float.
    * ``act_value``    -- per-row mean of the output flattened from dim 1,
      summed over dim 0 (a tensor).
    * ``update_value`` -- running sum of the second value returned by
      ``LIFOneFunction``; it is never reset inside this class.

    Args:
        T: Forwarded to ``LIFOneFunction`` (presumably the number of time
            steps -- confirm against the ``cpp_neuron`` kernels).
        thresh: Initial value of the learnable threshold.
        tau: Forwarded to ``LIFOneFunction``.
        gama: Stored on the module; not used inside this class.
        sp: Forwarded to ``LIFOneFunction``.
    """

    def __init__(self, T, thresh=1., tau=1., gama=1.0, sp=1e-4):
        super(BPTTLIFOne, self).__init__()
        self.act = LIFOneFunction.apply
        # The threshold is a trainable scalar parameter.
        self.thresh = nn.Parameter(torch.tensor(thresh), requires_grad=True)
        self.tau = tau
        self.gama = gama
        self.T = T
        self.r = 0.
        # Defined so static_spike_count() can read it even before the first
        # forward pass; it previously raised AttributeError on this class.
        self.s = 0.
        self.sp = sp
        self.macs = 0.
        self.act_value = 0.
        self.update_value = 0.

    def forward(self, x):
        """Apply the LIF activation to ``x`` and refresh the statistics."""
        x, tmp = self.act(x, self.thresh, self.T, self.tau, self.sp)
        self.update_value += tmp
        self.act_value = x.reshape(x.size(0), -1).mean(1).sum()
        # The output is not scaled by the threshold, so no division is
        # needed (same formulas as VanillaLIF).
        self.r = x.mean().item()
        self.s = x.mean(1).sum().item()
        return x

class BN(layer.BatchNorm2d):
    """``layer.BatchNorm2d`` preconfigured with fixed settings.

    Only ``num_features`` is configurable; every other option is pinned
    (affine, running statistics tracked, ``step_mode='m'``).
    """

    def __init__(self, num_features):
        fixed_options = dict(
            eps=1e-5,
            momentum=0.1,
            affine=True,
            track_running_stats=True,
            step_mode='m',
        )
        super().__init__(num_features, **fixed_options)

class ConvBlock(nn.Module):
    """Convolution -> normalisation -> activation block.

    The convolution is a ``layer.Conv2d`` created with ``step_mode='m'``,
    ``dilation=1`` and zero padding.

    Args:
        in_channels: Number of input channels of the convolution.
        out_channels: Number of output channels of the convolution; also
            passed as the first argument to ``norm_layer``.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
        groups: Convolution groups.
        bias: Whether the convolution has a bias term.
        norm_layer: Callable building the normalisation layer; called as
            ``norm_layer(out_channels, **norm_layer_kwargs)``.
        norm_layer_kwargs: Extra keyword arguments for ``norm_layer``.
            ``None`` (the default) means no extra arguments.
        activation: Callable building the activation; called as
            ``activation(**activation_kwargs)``.
        activation_kwargs: Keyword arguments for ``activation``. ``None``
            (the default) means no arguments. Note that the default
            activation, ``BPTTLIF``, has a required ``T`` argument, so it
            must be supplied here (e.g. ``{"T": 4}``).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        groups: int = 1,
        bias: bool = False,
        norm_layer: Callable[..., Any] = BN,
        norm_layer_kwargs: Optional[Dict] = None,
        activation: Callable[..., Any] = BPTTLIF,
        activation_kwargs: Optional[Dict] = None,
    ) -> None:
        super(ConvBlock, self).__init__()
        # None sentinels replace the former shared mutable ``{}`` defaults.
        if norm_layer_kwargs is None:
            norm_layer_kwargs = {}
        if activation_kwargs is None:
            activation_kwargs = {}

        self.conv = layer.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=1,
            groups=groups,
            bias=bias,
            padding_mode='zeros',
            step_mode='m',
        )
        self.norm_layer = norm_layer(out_channels, **norm_layer_kwargs)
        self.activation = activation(**activation_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply convolution, normalisation and activation, in that order."""
        out = self.conv(x)
        out = self.norm_layer(out)
        out = self.activation(out)
        return out