import torch
import math
from torch import nn
from spikingjelly.activation_based import neuron, surrogate
from typing import Callable, Optional
from matplotlib import pyplot as plt


def generate_thresh(num_rows, num_cols, device=None):
    """Build a sinusoidal (positional-encoding style) threshold offset table.

    Entry ``[0, p, j]`` (``0 <= p < num_cols``, ``0 <= j < num_rows``) is
    ``cos(p * w_j)`` for even ``j`` and ``sin(p * w_j)`` for odd ``j``, where
    ``w_j = 1 / 10000 ** (2 * (j // 2) / num_rows)``.

    Args:
        num_rows: Size of the last dimension of the result; must be even.
        num_cols: Size of the middle dimension of the result (one row per
            position ``p``).
        device: Device for the returned tensor. ``None`` (the default) keeps
            the previous behaviour of placing it on CUDA when CUDA is
            available, and falls back to CPU otherwise instead of crashing.

    Returns:
        A floating-point tensor of shape ``[1, num_cols, num_rows]``.

    Raises:
        AssertionError: if ``num_rows`` is odd.
    """
    assert num_rows % 2 == 0, "num_rows must be even"
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    pos = torch.arange(num_cols).unsqueeze(1)  # [num_cols, 1]
    i = torch.arange(num_rows).unsqueeze(0)  # [1, num_rows]
    # Columns 2m and 2m+1 share one frequency, as in transformer encodings.
    angle_rates = 1.0 / torch.pow(10000, (2 * (i // 2)) / num_rows)  # [1, num_rows]
    angle_rads = pos * angle_rates  # [num_cols, num_rows]

    pos_encoding = torch.zeros_like(angle_rads)
    pos_encoding[:, 0::2] = torch.cos(angle_rads[:, 0::2])
    pos_encoding[:, 1::2] = torch.sin(angle_rads[:, 1::2])

    return pos_encoding.unsqueeze(0).to(device)  # [1, num_cols, num_rows]
            
class PEBaseNode(neuron.base.MemoryModule):
    """Spiking-neuron base class with a position-dependent firing threshold.

    Subclasses implement ``neuronal_charge``; this class provides firing
    (``surrogate_function(v - threshold)``) and a soft reset
    (``v <- v - threshold * spike``).
    """

    def __init__(self, v_threshold: float = 1.0,
                 surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False,
                 step_mode='s', num_heads=1, embedding_dim=384, token_num=64, k=0.3):
        """Create the node.

        Args:
            v_threshold: Scalar base threshold. The stored ``self.v_threshold``
                is the tensor
                ``v_threshold + k * generate_thresh(embedding_dim, token_num)``
                with shape ``[1, token_num, embedding_dim]``.
            surrogate_function: Callable applied to ``v - threshold`` to
                produce the spike output.
            detach_reset: If true, the spike is detached before it is used
                in the reset.
            step_mode: Step mode passed through to spikingjelly.
            num_heads: Accepted for interface compatibility; unused here.
            embedding_dim: Last dimension of the threshold table.
            token_num: Middle dimension of the threshold table (maximum
                number of positions the threshold covers).
            k: Scale of the sinusoidal threshold offset.
        """
        super().__init__()
        # Registered as a non-persistent buffer so ``module.to(...)`` /
        # ``.cuda()`` / ``.cpu()`` move it together with the rest of the model,
        # while ``state_dict`` keys stay exactly as before (no checkpoint change).
        self.register_buffer(
            'v_threshold',
            v_threshold + k * generate_thresh(embedding_dim, token_num),
            persistent=False,
        )
        self.surrogate_function = surrogate_function
        self.detach_reset = detach_reset
        self.step_mode = step_mode
        self.register_memory(name='v', value=0.)

    def single_step_forward(self, x: torch.Tensor):
        """Charge with ``x``, fire, soft-reset, and return the spike output."""
        self.neuronal_charge(x)
        spike = self.neuronal_fire()
        self.neuronal_reset(spike)
        return spike

    def neuronal_charge(self, x: torch.Tensor):
        """Update ``self.v`` from input ``x``; must be implemented by subclasses."""
        raise NotImplementedError

    def _sliced_threshold(self):
        """Return the threshold rows for the first ``self.v.shape[1]`` positions.

        Assumes ``self.v`` is 3-D with ``v.shape[1] <= token_num`` and
        ``v.shape[2] == embedding_dim`` — TODO confirm against callers.
        """
        return self.v_threshold[:, 0:self.v.shape[1], :]

    def neuronal_fire(self):
        """Return ``surrogate_function(v - threshold)``."""
        return self.surrogate_function(self.v - self._sliced_threshold())

    def neuronal_reset(self, spike):
        """Soft reset: subtract ``threshold * spike`` from the membrane potential."""
        spike_d = spike.detach() if self.detach_reset else spike
        self.v = self.v - self._sliced_threshold() * spike_d

class PELIFNode(PEBaseNode):
    """Leaky integrate-and-fire node on top of ``PEBaseNode``.

    While in training mode and with ``zo_loss`` enabled, every call to
    ``neuronal_fire`` adds an MSE term to ``self.zo_loss_value``. The term
    compares the mean of ``v`` over dim 0 with the detached mean of the spike
    output over dim 0. This class never resets ``zo_loss_value``; presumably
    the training loop reads and clears it (TODO confirm against callers).
    """

    def __init__(self, v_threshold: float, tau:float, decay_input: bool = True,
                 surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False,
                 step_mode='s', num_heads=1, embedding_dim=384, token_num=64, k=0.3, zo_loss=True):
        super().__init__(
            v_threshold, surrogate_function, detach_reset, step_mode,
            num_heads, embedding_dim, token_num, k,
        )
        # Leak time constant and input-decay switch used by neuronal_charge.
        self.tau = tau
        self.decay_input = decay_input
        # Auxiliary-loss switch and running accumulator (starts as int 0).
        self.zo_loss = zo_loss
        self.zo_loss_value = 0
        self.loss_fn = nn.MSELoss()

    def neuronal_charge(self, x: torch.Tensor):
        """Leaky integration of input ``x`` into the membrane potential.

        decay_input=True:  v <- v + (x - v) / tau
        decay_input=False: v <- v - v / tau + x
        """
        v_prev = self.v
        if not self.decay_input:
            self.v = v_prev + (-v_prev) / self.tau + x
            return
        self.v = v_prev + (x - v_prev) / self.tau

    def neuronal_fire(self):
        """Return the spike output and, when training, accumulate the aux loss."""
        v_now = self.v
        threshold = self.v_threshold[:, :v_now.shape[1], :]
        spike = self.surrogate_function(v_now - threshold)

        if self.training and self.zo_loss:
            target = spike.mean(dim=0).detach()
            self.zo_loss_value += self.loss_fn(v_now.mean(dim=0), target)

        return spike
    
    
class APELIFNode(PEBaseNode):
    """Leaky integrate-and-fire node on top of ``PEBaseNode`` (no auxiliary loss)."""

    def __init__(self, v_threshold: float, tau:float, decay_input: bool = True,
                 surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False,
                 step_mode='s', num_heads=1, embedding_dim=384, token_num=64, k=0.3):
        super().__init__(
            v_threshold, surrogate_function, detach_reset, step_mode,
            num_heads, embedding_dim, token_num, k,
        )
        # Leak time constant and input-decay switch used by neuronal_charge.
        self.tau = tau
        self.decay_input = decay_input

    def neuronal_charge(self, x: torch.Tensor):
        """Leaky integration of input ``x`` into the membrane potential.

        decay_input=True:  v <- v + (x - v) / tau
        decay_input=False: v <- v - v / tau + x
        """
        v_prev = self.v
        if not self.decay_input:
            self.v = v_prev + (-v_prev) / self.tau + x
            return
        self.v = v_prev + (x - v_prev) / self.tau

    def neuronal_fire(self):
        """Return ``surrogate_function(v - threshold)`` for the current positions."""
        v_now = self.v
        threshold = self.v_threshold[:, :v_now.shape[1], :]
        return self.surrogate_function(v_now - threshold)
    