import torch
from torch import nn
from spikingjelly.activation_based import neuron, surrogate
from typing import Callable


def generate_thresh(num_rows, num_cols, device="cuda"):
    """Build a sinusoidal positional-encoding table used as a threshold offset.

    Position ``p`` (``0 <= p < num_cols``) gets ``cos(p * r_j)`` in even
    columns ``j`` and ``sin(p * r_j)`` in odd columns, where
    ``r_j = 1 / 10000 ** (2 * (j // 2) / num_rows)``.

    Args:
        num_rows: encoding width (last axis, D). Must be even.
        num_cols: number of positions (middle axis, L).
        device: device the result is moved to. Defaults to ``"cuda"``,
            which matches the previously hard-coded ``.cuda()``. Pass
            ``"cpu"`` for CPU-only use, or ``None`` to skip the move.

    Returns:
        Tensor of shape ``[1, num_cols, num_rows]``.

    Raises:
        AssertionError: if ``num_rows`` is odd.
    """
    assert num_rows % 2 == 0, "num_rows must be even"

    pos = torch.arange(num_cols).unsqueeze(1)  # [L, 1]
    i = torch.arange(num_rows).unsqueeze(0)  # [1, D]
    # Each (even, odd) column pair shares one frequency.
    angle_rates = 1.0 / torch.pow(10000, (2 * (i // 2)) / num_rows)  # [1, D]
    angle_rads = pos * angle_rates  # [L, D]

    pos_encoding = torch.zeros_like(angle_rads)
    pos_encoding[:, 0::2] = torch.cos(angle_rads[:, 0::2])
    pos_encoding[:, 1::2] = torch.sin(angle_rads[:, 1::2])

    pos_encoding = pos_encoding.unsqueeze(0)  # [1, L, D]
    # Build first, then move, so the values do not depend on the target device.
    return pos_encoding if device is None else pos_encoding.to(device)


class SimSigmoid(nn.Module):
    """Sigmoid surrogate spike function with a norm-scaled straight-through term.

    ``forward(v, thr)`` returns, up to float rounding, the hard spikes
    ``(v - thr >= 0)``. Gradients flow through ``sigmoid(alpha * (v - thr))``
    and, through the scalar ``_norm``, back into ``v``.

    Args:
        alpha: slope of the sigmoid surrogate.
        spiking: stored on the module; ``forward`` does not read it.
    """

    def __init__(self, alpha=4.0, spiking=True):
        super().__init__()
        self.spiking = spiking
        self.alpha = alpha

    def forward(self, v, thr):
        x = v - thr
        s = (x >= 0).to(x)  # hard spikes in x's dtype/device (assumed [B, L, D])
        # L2 distance between the dim-0 means of the spikes and of v.
        _norm = torch.norm(s.mean(dim=0).flatten() - v.mean(0).flatten(), p=2, dim=0)
        # If _norm is exactly 0 (e.g. v all zeros below a positive threshold),
        # (s - x) / _norm is +-inf and _norm * inf is NaN. Use 1 in that case,
        # which makes this a plain straight-through estimator. For _norm > 0
        # torch.where selects _norm itself, so value and gradient are unchanged.
        scale = torch.where(_norm > 0, _norm, torch.ones_like(_norm))
        x = (x * self.alpha).sigmoid_()
        return x + scale * ((s - x) / scale).detach()
    
            
class PEBaseNode(neuron.base.MemoryModule):
    """Skeleton spiking neuron with a per-position threshold and soft reset.

    One step is charge -> fire -> reset. Subclasses implement
    ``neuronal_charge`` and must assign ``self.v_threshold`` themselves.
    """

    def __init__(self, v_threshold: float = 1.0,
                 surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False,
                 step_mode='s', num_heads=1, embedding_dim=384, token_num=64, k=0.3):
        """
        Store the surrogate, reset mode and step mode, and register ``v`` (initially 0.) as memory.

        Note: ``v_threshold`` is NOT stored here. ``neuronal_fire`` and
        ``neuronal_reset`` slice ``self.v_threshold[:, :L, :]``, so a subclass
        must set it to a tensor of rank >= 3. ``num_heads``,
        ``embedding_dim``, ``token_num`` and ``k`` are accepted for signature
        compatibility and are not read by this class.
        """
        super().__init__()
        self.surrogate_function = surrogate_function
        self.detach_reset = detach_reset
        self.step_mode = step_mode
        self.register_memory(name='v', value=0.)

    def _cropped_threshold(self):
        # Keep only the first v.shape[1] entries along the threshold's second axis.
        length = self.v.shape[1]
        return self.v_threshold[:, :length, :]

    def single_step_forward(self, x: torch.Tensor):
        self.neuronal_charge(x)
        fired = self.neuronal_fire()
        self.neuronal_reset(fired)
        return fired

    def neuronal_charge(self, x: torch.Tensor):
        raise NotImplementedError

    def neuronal_fire(self):
        return self.surrogate_function(self.v - self._cropped_threshold())

    def neuronal_reset(self, spike):
        spike_d = spike.detach() if self.detach_reset else spike
        # Soft reset: subtract the threshold wherever the neuron fired.
        self.v = self.v - self._cropped_threshold() * spike_d
        

class PELIFNode(PEBaseNode):
    """Leaky integrate-and-fire neuron with a position-dependent threshold.

    The threshold is
    ``v_threshold + k * generate_thresh(embedding_dim // num_heads, token_num)``,
    tiled ``num_heads`` times along the last axis, so its shape is
    ``[1, token_num, (embedding_dim // num_heads) * num_heads]``.

    Args:
        v_threshold: base threshold added to the positional offset.
        tau: membrane time constant.
        decay_input: if True, ``v += (x - v) / tau``; otherwise
            ``v += -v / tau + x``.
        surrogate_function: called as ``surrogate_function(v, thr)``.
        zo_loss: accepted for compatibility; not read by this class.
    """

    def __init__(self, v_threshold: float, tau:float, decay_input: bool = True,
                 surrogate_function: Callable = SimSigmoid(), detach_reset: bool = False,
                 step_mode='s', num_heads=1, embedding_dim=384, token_num=512, k=0.3, zo_loss=True):
        super().__init__(v_threshold, surrogate_function, detach_reset, step_mode, num_heads, embedding_dim, token_num, k)
        self.tau = tau
        self.decay_input = decay_input
        # Register as a non-persistent buffer so .to()/.cpu()/.cuda(i) and
        # DataParallel replication move it along with the module. A plain
        # attribute would stay on its creation device. persistent=False keeps
        # state_dict keys unchanged, so existing checkpoints still load.
        self.register_buffer(
            'v_threshold',
            v_threshold + k * generate_thresh(embedding_dim // num_heads, token_num).repeat(1, 1, num_heads),
            persistent=False,
        )

    def neuronal_charge(self, x: torch.Tensor):
        if self.decay_input:
            self.v = self.v + (x - self.v) / self.tau
        else:
            self.v = self.v + (-self.v) / self.tau + x

    def neuronal_fire(self):
        # Crop the threshold to the current token count, v.shape[1].
        s = self.surrogate_function(self.v, self.v_threshold[:, 0:self.v.shape[1], :])

        return s
    