import torch
from spikingjelly.clock_driven.neuron import BaseNode
from spikingjelly.clock_driven import surrogate


class LIFNode(BaseNode):
    """Leaky integrate-and-fire neuron with membrane time constant ``tau``.

    This class only defines the charge step (plus ``extra_repr``); everything
    else is inherited from :class:`BaseNode`. The charge step is

    * ``v_reset is None``:  ``V[t] = V[t-1] + (X[t] - V[t-1]) / tau``
    * otherwise:            ``V[t] = V[t-1] + (X[t] - (V[t-1] - v_reset)) / tau``

    With zero input and ``tau >= 1`` the potential relaxes toward ``0``
    (``v_reset is None``) or toward ``v_reset`` (otherwise).

    :param tau: membrane time constant; each step moves the potential ``1 / tau``
        of the way toward its target. Presumably expected to be > 0 — TODO confirm.
    :param v_threshold: firing threshold, forwarded to :class:`BaseNode`.
    :param v_reset: reset potential forwarded to :class:`BaseNode`; ``None``
        selects the offset-free branch above (in SpikingJelly this presumably
        means soft reset — confirm against the installed version).
    :param surrogate_function: surrogate-gradient function forwarded to
        :class:`BaseNode`. ``None`` (the default) creates a fresh
        ``surrogate.Sigmoid()`` for this node.
    :param detach_reset: forwarded to :class:`BaseNode`.
    :param monitor_state: forwarded to :class:`BaseNode`.
    """

    def __init__(self, tau=100.0, v_threshold=1.0, v_reset=0.0, surrogate_function=None,
                 detach_reset=False,
                 monitor_state=False):
        # A ``surrogate.Sigmoid()`` default in the signature would be evaluated once at
        # class-definition time, so every default-constructed node would share a single
        # surrogate instance (and any state stored on it). Build one per node instead.
        if surrogate_function is None:
            surrogate_function = surrogate.Sigmoid()
        super().__init__(v_threshold, v_reset, surrogate_function, detach_reset, monitor_state)
        self.tau = tau

    def extra_repr(self):
        """Return the hyper-parameters shown in this module's ``repr``."""
        return f'v_threshold={self.v_threshold}, v_reset={self.v_reset}, tau={self.tau}'

    def neuronal_charge(self, dv: torch.Tensor):
        """Integrate one time step of input ``dv`` into the membrane potential ``self.v``.

        :param dv: input for the current time step.
        """
        # Offset of the potential from the value it leaks toward (0 or v_reset).
        leak = self.v if self.v_reset is None else self.v - self.v_reset
        # Out-of-place on purpose: ``self.v += ...`` would mutate whatever tensor
        # ``self.v`` currently aliases, raises for leaf tensors that require grad and
        # for inference-mode tensors, and cannot broadcast a smaller ``self.v`` up to
        # ``dv``'s shape.
        self.v = self.v + (dv - leak) / self.tau


class GeneralLIFNode(BaseNode):
    """LIF-style neuron with an explicit leak factor ``k`` and input gain ``lam``.

    This class only defines the charge step (plus ``extra_repr``); everything
    else is inherited from :class:`BaseNode`. The charge step is

    * ``v_reset is None``:  ``V[t] = k * V[t-1] + lam * X[t]``
    * otherwise:            ``V[t] = v_reset + k * (V[t-1] - v_reset) + lam * X[t]``

    With ``k = 1 - 1/tau`` and ``lam = 1/tau`` this is the same update as
    :class:`LIFNode`.

    :param k: multiplicative leak applied to the potential (relative to
        ``v_reset`` when it is not ``None``) each step.
    :param lam: gain applied to the input ``dv`` each step.
    :param v_threshold: firing threshold, forwarded to :class:`BaseNode`.
    :param v_reset: reset potential forwarded to :class:`BaseNode`; ``None``
        selects the offset-free branch above (in SpikingJelly this presumably
        means soft reset — confirm against the installed version).
    :param surrogate_function: surrogate-gradient function forwarded to
        :class:`BaseNode`. ``None`` (the default) creates a fresh
        ``surrogate.Sigmoid()`` for this node.
    :param detach_reset: forwarded to :class:`BaseNode`.
    :param monitor_state: forwarded to :class:`BaseNode`.
    """

    def __init__(self, k, lam, v_threshold=1.0, v_reset=0.0, surrogate_function=None, detach_reset=False,
                 monitor_state=False):
        # A ``surrogate.Sigmoid()`` default in the signature would be evaluated once at
        # class-definition time, so every default-constructed node would share a single
        # surrogate instance (and any state stored on it). Build one per node instead.
        if surrogate_function is None:
            surrogate_function = surrogate.Sigmoid()
        super().__init__(v_threshold, v_reset, surrogate_function, detach_reset, monitor_state)
        self.k = k
        self.lam = lam

    def extra_repr(self):
        """Return the hyper-parameters shown in this module's ``repr``."""
        return f'v_threshold={self.v_threshold}, v_reset={self.v_reset}, k={self.k}, lambda={self.lam}'

    def neuronal_charge(self, dv: torch.Tensor):
        """Integrate one time step of input ``dv`` into the membrane potential ``self.v``.

        :param dv: input for the current time step.
        """
        if self.v_reset is None:
            self.v = self.v * self.k + dv * self.lam
        else:
            # Decay the offset from v_reset, then add v_reset back. Without the
            # ``v_reset +`` term the potential relaxes to -k * v_reset / (1 - k)
            # instead of v_reset whenever v_reset != 0. Results are unchanged for
            # v_reset == 0.
            self.v = self.v_reset + (self.v - self.v_reset) * self.k + dv * self.lam
