from ..NeuronConfig import NeuronConfig
from ..Neuron import Neuron

from .QIFParameters import QIFParameters
from .QIFState import QIFState

import torch

from SNN.SurrogateGradients import threshold

import os

class ITQIF(Neuron):
    """
    A Quadratic Integrate-and-Fire Neuron.

    Each call to ``forward`` does three things:
    - advances the membrane potential one step with ``update``;
    - emits spikes through the surrogate-gradient ``threshold`` applied to ``v - p.v_th``;
    - resets the units that spiked to ``p.v_reset``.

    Args:
        p: Neuron parameters. This class reads ``a``, ``v_reset``, ``v_c``,
            ``v_th`` and ``method`` from it. When omitted, a fresh
            ``QIFParameters()`` is built for this instance.
        config: Neuron configuration. When omitted, a fresh ``NeuronConfig()``
            is built for this instance.
        return_voltage: If True, ``forward`` returns the pre-reset membrane
            potential instead of spikes, and carries that potential forward
            without applying the reset.
        name: Optional label stored on the instance.
        *args, **kwargs: Forwarded unchanged to ``Neuron.__init__``.
    """
    def __init__(
        self,
        p: QIFParameters | None = None,
        config: NeuronConfig | None = None,
        return_voltage: bool = False,
        name: str | None = None,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        # Build defaults per instance. A default written in the signature is
        # evaluated once at import time and would be shared by every neuron.
        self.p = p if p is not None else QIFParameters()
        self.config = config if config is not None else NeuronConfig()
        self.return_voltage = return_voltage
        self.name = name

        # Pre-compute window for surrogate gradient. It is fixed at
        # construction: later changes to self.p are not reflected here.
        self.left, self.right = compute_window(self.p.a, self.p.v_reset, self.p.v_c, self.p.v_th)

    def forward(self, input_spikes: torch.Tensor, state: QIFState | None = None) -> tuple[torch.Tensor, QIFState]:
        """Run one time step of the neuron.

        Args:
            input_spikes: Input for this step. Its shape, dtype and device are
                used to build a fresh state when ``state`` is None.
            state: Previous state, or None to start from ``initial_state``.

        Returns:
            ``(spikes, QIFState(v_after_reset))`` normally, or
            ``(v_before_reset, QIFState(v_before_reset))`` when
            ``return_voltage`` is set.
        """
        if state is None:
            state = self.initial_state(input_spikes)

        v_in = update(self.p.a, state.v, self.p.v_reset, self.p.v_c, input_spikes)

        # Debug hook: uncomment to dump per-step voltages via record_spikes.
        # self.record_spikes(v_in)

        z_new = threshold(v_in - self.p.v_th, self.p.method, self.left, self.right)
        # Units with a positive spike output are reset; the rest keep v_in.
        v_new = torch.where(z_new > 0, self.p.v_reset, v_in)

        # Spike-count bookkeeping. updates_spikes is not defined in this class
        # (presumably provided by Neuron). .item() forces a host sync per step.
        self.updates_spikes(z_new.detach().sum().item(), total = input_spikes.numel())

        if self.return_voltage:
            return v_in, QIFState(v_in)
        return z_new, QIFState(v_new)

    def initial_state(self, input_tensor: torch.Tensor):
        """Return a fresh state shaped like ``input_tensor``.

        ``v`` is filled with ``p.v_reset`` and created with
        ``requires_grad=True``. ``i`` is all zeros. Both match
        ``input_tensor``'s device and dtype. ``p.v_reset`` must hold a single
        element, because ``.item()`` is called on it.
        """
        return QIFState(
            v=torch.full_like(
                input_tensor,
                self.p.v_reset.item(),
                device=input_tensor.device,
                dtype=input_tensor.dtype,
                requires_grad=True,
            ),
            i=torch.zeros_like(
                input_tensor,
                device=input_tensor.device,
                dtype=input_tensor.dtype,
            ),
        )

    def record_spikes(
        self,
        voltage: torch.Tensor,
        directory: str = "./SNN/output/CIFAR10/QIF_ResNet19",
    ):
        """Save ``voltage`` to ``directory`` and bump the per-layer counter.

        The file is named ``Voltage_Layer_{self.layer}_Count_{self.count}.pt``
        and ``directory`` is created if needed.

        NOTE(review): ``self.layer`` and ``self.count`` are not initialised
        anywhere in this class. They must be set externally (or by the base
        class) before this is called.
        """
        # exist_ok avoids the race between checking for and creating the dir.
        os.makedirs(directory, exist_ok=True)
        torch.save(voltage, f"{directory}/Voltage_Layer_{self.layer}_Count_{self.count}.pt")
        self.count += 1

# One discrete QIF membrane update:
#     v_next = a * (v - v_reset) * (v - v_c) + input_spikes
# Earlier versions dropped the (v - v_reset) factor and were only correct for
# v_reset == 0. The general form gives exactly the same result in that case,
# and stays correct for any other reset potential.
@torch.jit.script
def update(a: torch.Tensor, v: torch.Tensor, v_reset: torch.Tensor, v_c: torch.Tensor, input_spikes: torch.Tensor) -> torch.Tensor:
    return a * (v - v_reset) * (v - v_c) + input_spikes

# Surrogate-gradient window (left, right) = (mu - std, mu + std) for the
# pre-threshold potential a * (v - v_reset) * (v - v_c) + input, where:
#     mu  = a * (v_th**2 + v_reset * v_c)
#     std = v_th * sqrt(1 + a**2 * (2 * v_th**2 + (v_c + v_reset)**2))
# With v_reset == 0 this reduces exactly to the previous v_reset-free form.
# NOTE(review): these are the mean and std of that expression when v ~ N(0, v_th**2)
# and the input is independent, zero-mean, with variance v_th**2. This
# reading of the intended assumption is inferred; confirm it.
@torch.jit.script
def compute_window(a: torch.Tensor, v_reset: torch.Tensor, v_c: torch.Tensor, v_th: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    mu: torch.Tensor = a * (v_th**2 + v_reset * v_c)
    std: torch.Tensor = v_th * torch.sqrt(1 + a**2 * (2*v_th**2 + (v_c + v_reset)**2))
    return mu - std, mu + std

# @torch.jit.script
# def update(a: torch.Tensor, v: torch.Tensor, v_reset: torch.Tensor, v_c: torch.Tensor, input_spikes: torch.Tensor) -> torch.Tensor:
#     return a * (v - v_reset) * (v - v_c) + input_spikes

# @torch.jit.script
# def compute_window(a: torch.Tensor, v_reset: torch.Tensor, v_c: torch.Tensor, v_th: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
#     mu: torch.Tensor = a * (v_th**2 + v_reset * v_c)
#     std: torch.Tensor = v_th * torch.sqrt(1 + a**2 * (2*v_th**2 + (v_c + v_reset)**2))
#     return mu - std, mu + std