from ..NeuronConfig import NeuronConfig
from ..Neuron import Neuron

from .LNMParameters import LNMParameters
from .LNMState import LNMState

from SNN.LearnableMembrane import LearnablePolynomialMembrane

from SNN.SurrogateGradients import threshold

import torch
from torch import nn

class LNM(Neuron):
    """
    Learnable Neuron Model (LNM) based on the LIF model with polynomial membrane potential update.

    Each call to ``forward`` updates the membrane potential as ``v + f(v) + input``,
    where ``f`` is a ``LearnablePolynomialMembrane`` (built without bias). Wherever the
    thresholded output is positive, the potential is hard-reset to ``p.v_reset``.
    """

    def __init__(
        self,
        p: LNMParameters | None = None,
        config: NeuronConfig | None = None,
    ):
        """
        Args:
            p: Neuron parameters (this class reads ``tau_mem_inv``, ``v_th``,
                ``v_reset`` and ``method``). ``None`` builds a fresh
                ``LNMParameters()`` for this instance.
            config: Neuron configuration; ``config.poly_degree`` is the degree passed
                to the membrane polynomial. ``None`` builds a fresh ``NeuronConfig()``.

        Raises:
            ValueError: if ``config.poly_degree`` is not greater than 0.
        """
        super().__init__()
        # Build defaults per instance: an instance written directly in the signature
        # would be created once at class-definition time and shared by every LNM.
        self.p = LNMParameters() if p is None else p
        self.config = NeuronConfig() if config is None else config

        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        if not self.config.poly_degree > 0:
            raise ValueError("Polynomial degree must be greater than 0")

        self.update = LearnablePolynomialMembrane(self.config.poly_degree, bias=False)
        # NOTE(review): the window is computed once from the construction-time
        # tau_mem_inv and v_th; if those values change later (e.g. are trained),
        # self.window will not follow them — confirm this is intended.
        self.window = compute_window(self.p.tau_mem_inv, self.p.v_th)

    def forward(
        self,
        input_spikes: torch.Tensor,
        state: LNMState | None = None,
    ) -> tuple[torch.Tensor, LNMState]:
        """
        Advance the neuron by one step.

        Args:
            input_spikes: Tensor added directly to the membrane potential.
            state: Previous state. ``None`` starts from
                ``initial_state(input_spikes)``, i.e. a potential filled with
                ``p.v_reset``.

        Returns:
            ``(z_new, new_state)``: the output of ``threshold`` for this step and
            the state holding the post-reset membrane potential.
        """
        if state is None:
            state = self.initial_state(input_spikes)

        # Polynomial membrane update plus the injected input.
        v_in: torch.Tensor = state.v + self.update(state.v) + input_spikes

        # Spike on the distance above threshold. Presumably `threshold` emits spikes
        # forward and uses `method`/`window` to shape a surrogate gradient — confirm
        # in SNN.SurrogateGradients.
        z_new = threshold(v_in - self.p.v_th, self.p.method, self.window)

        # Hard reset wherever the output is positive. The boolean mask is detached,
        # so the reset decision itself is not differentiated.
        reset_gate = (z_new > 0).detach()
        v_new = torch.where(reset_gate, self.p.v_reset, v_in)

        # Spike bookkeeping, presumably implemented on the Neuron base class.
        # NOTE: .item() forces a host sync when z_new lives on an accelerator.
        self.updates_spikes(z_new.sum().item(), total=input_spikes.numel())

        return z_new, LNMState(v_new)

    def initial_state(self, input_tensor: torch.Tensor) -> LNMState:
        """
        Build a starting state whose potential matches ``input_tensor`` in shape,
        device and dtype, filled with ``p.v_reset``.

        The tensor is created with ``requires_grad=True``, so it is a leaf that can
        receive gradients.
        """
        return LNMState(
            v=torch.full_like(
                input_tensor,
                self.p.v_reset.item(),
                device=input_tensor.device,
                dtype=input_tensor.dtype,
                requires_grad=True,
            )
        )

@torch.jit.script
def compute_window(beta: torch.Tensor, v_th: torch.Tensor) -> torch.Tensor:
    """
    Return ``v_th * sqrt(1 + beta**2)`` elementwise, with broadcasting.

    The factor ``sqrt(1 + beta**2)`` is the Euclidean norm of the vector ``(1, beta)``.
    LNM passes its ``tau_mem_inv`` parameter as ``beta``.
    """
    norm_factor = torch.sqrt(torch.pow(beta, 2) + 1)
    return v_th * norm_factor