from ..NeuronConfig import NeuronConfig
from ..Neuron import Neuron

from .LIFParameters import LIFParameters
from .LIFState import LIFState

import torch

from SNN.SurrogateGradients import threshold

import os

class ITLIF(Neuron):
    """
    An Iterative Leaky Integrate-and-Fire neuron model as described in the work:
    S. Lian, J. Shen, Q. Liu, Z. Wang, R. Yan, and H. Tang, “Learnable Surrogate Gradient for Direct Training Spiking Neural Networks,” 
    in Proceedings of the Thirty-Second International Joint Conference on Artificial Intelligence, Macau, SAR China: International 
    Joint Conferences on Artificial Intelligence Organization, Aug. 2023, pp. 3002–3010. doi: 10.24963/ijcai.2023/335.

    Each call to :meth:`forward` advances the neuron by one time step:

    1. ``v_in = p.tau_mem_inv * v + input`` -- the carried voltage is scaled
       by ``p.tau_mem_inv`` and the input is added directly.
    2. ``z = threshold(v_in - p.v_th, p.method, -p.alpha, p.alpha)``, where
       ``threshold`` comes from ``SNN.SurrogateGradients``. It is presumably
       a step function with a surrogate gradient; confirm there.
    3. Wherever ``z > 0`` the carried voltage is hard-reset to ``p.v_reset``.

    With ``return_voltage=True`` the neuron returns ``v_in`` instead of
    ``z``. It then carries ``v_in`` (with no reset) into the next step,
    presumably so it can act as a non-spiking readout layer.

    Args:
        p: Neuron parameters. ``None`` (the default) builds a fresh
            ``LIFParameters()`` for this instance.
        config: Neuron configuration. ``None`` (the default) builds a fresh
            ``NeuronConfig()`` for this instance.
        return_voltage: If True, output the membrane voltage instead of spikes.
        name: Optional name for this neuron layer.
        *args, **kwargs: Forwarded unchanged to ``Neuron.__init__``.
    """

    def __init__(
        self,
        p: LIFParameters | None = None,
        config: NeuronConfig | None = None,
        return_voltage: bool = False,
        name: str | None = None,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # Defaults are built per instance. A default-argument instance would be
        # created once at import time and shared by every ITLIF omitting it.
        self.p = p if p is not None else LIFParameters()
        self.config = config if config is not None else NeuronConfig()
        self.return_voltage = return_voltage
        self.name = name

    def forward(
        self, input_spikes: torch.Tensor, state: LIFState | None = None
    ) -> tuple[torch.Tensor, LIFState]:
        """Advance the neuron by one time step.

        Args:
            input_spikes: Input for this step. It is added directly to the
                decayed voltage, so it must broadcast against ``state.v``.
            state: State carried from the previous step. ``None`` starts
                from ``self.initial_state(input_spikes)``.

        Returns:
            A pair ``(output, next_state)``.

            - Normally ``output`` is the spike tensor and ``next_state`` holds
              the voltage after the hard reset.
            - When ``self.return_voltage`` is True, ``output`` is the
              pre-reset voltage ``v_in``, and ``next_state`` holds that same
              un-reset voltage.
        """
        if state is None:
            state = self.initial_state(input_spikes)

        v_in = self.p.tau_mem_inv * state.v + input_spikes

        # Debug hook: calling self.record_spikes(v_in) here dumps the voltage
        # of every step to disk.

        z_new = threshold(v_in - self.p.v_th, self.p.method, -self.p.alpha, self.p.alpha)
        # Hard reset: neurons that fired restart from v_reset; others keep v_in.
        v_new = torch.where(z_new > 0, self.p.v_reset, v_in)

        # Spike bookkeeping runs in both output modes. `.item()` copies the
        # count to the host, which synchronizes when tensors live on a GPU.
        self.updates_spikes(z_new.detach().sum().item(), z_new.numel())

        if self.return_voltage:
            return v_in, LIFState(v_in)

        return z_new, LIFState(v_new)

    def record_spikes(
        self,
        voltage: torch.Tensor,
        output_dir: str = "./SNN/output/CIFAR10/LIF_ResNet19",
    ):
        """Save ``voltage`` to a ``.pt`` file under ``output_dir``.

        Despite the name, this saves whatever tensor it is given; the hook in
        ``forward`` would pass the membrane voltage. The file name embeds
        ``self.layer`` and ``self.count``, and ``self.count`` is incremented
        after each save.

        NOTE(review): neither ``layer`` nor ``count`` is initialised in this
        class. They are presumably set by ``Neuron`` or by the caller; confirm
        before enabling the hook.
        """
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(output_dir, exist_ok=True)
        file_name = f"Voltage_Layer_{self.layer}_Count_{self.count}.pt"
        torch.save(voltage, os.path.join(output_dir, file_name))
        self.count += 1

    def initial_state(self, input_tensor: torch.Tensor) -> LIFState:
        """Return a fresh state with the same shape, dtype and device as ``input_tensor``.

        The voltage is filled with ``p.v_leak`` when ``return_voltage`` is
        True, and with ``p.v_reset`` otherwise. It is created with
        ``requires_grad=True``.
        """
        fill_value = (
            self.p.v_leak.item() if self.return_voltage else self.p.v_reset.item()
        )
        return LIFState(
            v=torch.full_like(
                input_tensor,
                fill_value,
                device=input_tensor.device,
                dtype=input_tensor.dtype,
                requires_grad=True,
            )
        )
