import math
import torch
import torch.nn.functional as F
import torch.nn as nn

# Spike function re-implemented with an improved surrogate-gradient method
class SpikeFunction(torch.autograd.Function):
    """Heaviside spike non-linearity with a triangular surrogate gradient.

    Forward:  ``spike = (membrane_potential > threshold)`` as a float32 tensor.
    Backward: the step function's derivative is replaced by a triangular
    window centred on the threshold,

        d spike / d v  ~=  GRAD_SCALE * max(0, 1 - |v - v_th| / W) / W

    with ``W = WINDOW_WIDTH``. The gradient with respect to the threshold is
    the negative of the gradient with respect to the membrane potential.
    """

    # Half-width of the triangular surrogate window (in membrane-potential units).
    WINDOW_WIDTH = 1.0
    # Constant multiplier applied to the surrogate gradient.
    GRAD_SCALE = 1.5

    @staticmethod
    def forward(ctx, membrane_potential, threshold):
        """Emit 1.0 where the membrane potential strictly exceeds the threshold.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            membrane_potential: tensor of membrane potentials.
            threshold: firing threshold; a tensor broadcastable against
                ``membrane_potential`` or a plain Python number.

        Returns:
            float32 tensor of 0.0 / 1.0 spikes.
        """
        # save_for_backward only accepts tensors, so promote Python scalars.
        if not isinstance(threshold, torch.Tensor):
            threshold = torch.as_tensor(
                threshold,
                dtype=membrane_potential.dtype,
                device=membrane_potential.device,
            )
        ctx.save_for_backward(membrane_potential, threshold)
        # The forward pass is the exact, deterministic step function.
        return (membrane_potential > threshold).float()

    @staticmethod
    def backward(ctx, grad_output):
        """Propagate gradients through the triangular surrogate derivative.

        Returns:
            ``(membrane_grad, threshold_grad)``; an entry is ``None`` when the
            corresponding forward input does not require a gradient (which is
            always the case for a non-tensor threshold).
        """
        membrane_potential, threshold = ctx.saved_tensors
        needs_membrane_grad, needs_threshold_grad = ctx.needs_input_grad[:2]
        if not (needs_membrane_grad or needs_threshold_grad):
            return None, None

        width = SpikeFunction.WINDOW_WIDTH
        # Pre-activation: distance of the membrane potential from the threshold.
        u = membrane_potential - threshold
        # Triangular window: peak 1/width at u == 0, zero for |u| >= width.
        surrogate = (1.0 - torch.abs(u) / width).clamp(min=0.0) / width

        grad = grad_output * surrogate
        # Replace NaN/Inf entries with zero so one bad element cannot poison
        # the whole backward pass.
        grad = torch.where(torch.isfinite(grad), grad, torch.zeros_like(grad))
        grad = grad * SpikeFunction.GRAD_SCALE

        membrane_grad = None
        threshold_grad = None
        if needs_membrane_grad:
            # Reduce over broadcast dimensions so the shape matches the input.
            membrane_grad = grad.sum_to_size(membrane_potential.shape)
        if needs_threshold_grad:
            # Raising the threshold lowers the spike output: opposite sign.
            threshold_grad = (-grad).sum_to_size(threshold.shape)
        return membrane_grad, threshold_grad

# --- Improved LIF neuron ---
class LIFNeuron(nn.Module):
    def __init__(self, tau_mem: float = 20.0, v_threshold: float = 1.0, v_reset: float | None = 0.0, 
                 surrogate_gradient: bool = True, **kwargs):
        super().__init__()
        # 初始化参数
        self.v_threshold_init = float(v_threshold)
        self.v_reset_val = v_reset if v_reset is not None else 0.0
        self.register_buffer('v', torch.tensor(0.0, dtype=torch.float32)) 
        self.tau_mem = float(tau_mem)
        self.decay_factor = 0.0 if self.tau_mem <= 0 else math.exp(-1.0 / self.tau_mem)
        self.surrogate_gradient = surrogate_gradient  # 控制是否使用替代梯度
        
        # 缓存
        self._reset_val_cache = {}

    def reset(self):
        """重置神经元膜电位到重置值"""
        if not isinstance(self.v, torch.Tensor) or self.v.numel() == 0:
            self.v = torch.tensor(float(self.v_reset_val), dtype=torch.float32, 
                                device=self.v.device if isinstance(self.v, torch.Tensor) else torch.device("cpu"))
            return
            
        # 快速重置路径
        if isinstance(self.v_reset_val, (int, float)) or (isinstance(self.v_reset_val, torch.Tensor) and self.v_reset_val.numel() == 1):
            reset_val = float(self.v_reset_val.item() if isinstance(self.v_reset_val, torch.Tensor) else self.v_reset_val)
            self.v.fill_(reset_val)
            return
            
        # 张量重置值
        if isinstance(self.v_reset_val, torch.Tensor) and self.v_reset_val.shape == self.v.shape:
            self.v.copy_(self.v_reset_val.to(self.v.device))
            return
        
        # 默认情况
        self.v.fill_(0.0)  # 默认重置为0
    
    def _prepare_reset_tensor(self, potential_shape: tuple, current_device: torch.device) -> torch.Tensor:
        """准备重置值张量，使用缓存加速"""
        # 对于标量重置值，使用快速路径
        if isinstance(self.v_reset_val, (int, float)) or (isinstance(self.v_reset_val, torch.Tensor) and self.v_reset_val.numel() == 1):
            reset_val = float(self.v_reset_val.item() if isinstance(self.v_reset_val, torch.Tensor) else self.v_reset_val)
            
            # 查找缓存
            shape_key = potential_shape
            device_key = str(current_device)
            cache_key = (shape_key, device_key, reset_val)
            
            if cache_key in self._reset_val_cache:
                return self._reset_val_cache[cache_key]
                
            # 创建新的重置张量
            reset_tensor = torch.full(potential_shape, reset_val, dtype=torch.float32, device=current_device)
            
            # 缓存结果，但限制缓存大小
            if len(self._reset_val_cache) < 10:  # 限制缓存大小
                self._reset_val_cache[cache_key] = reset_tensor
                
            return reset_tensor
            
        # 对于张量重置值，需要转换和广播
        if isinstance(self.v_reset_val, torch.Tensor):
            reset_on_device = self.v_reset_val.to(current_device)
            
            if self.v_reset_val.ndim == 1 and len(potential_shape) > 1 and self.v_reset_val.shape[0] == potential_shape[1]:
                # 1D重置值扩展到输入形状
                return reset_on_device.unsqueeze(0).expand(potential_shape)
            elif self.v_reset_val.shape == potential_shape:
                # 重置值与输入形状相同
                return reset_on_device
                
        # 默认返回全零张量
        return torch.zeros(potential_shape, dtype=torch.float32, device=current_device)

    def forward(self, x_in: torch.Tensor, v_threshold: torch.Tensor = None): 
        """改进的前向传播函数，确保梯度正确流动
        
        参数:
            x_in: 输入张量
            v_threshold: 可选的阈值张量，如果为None则使用初始阈值
        """
        # 获取当前设备
        current_device = x_in.device
        current_dtype = x_in.dtype
        
        # 初始化膜电位（必要时）
        if not isinstance(self.v, torch.Tensor) or self.v.shape != x_in.shape:
            reset_val = float(self.v_reset_val if isinstance(self.v_reset_val, (int, float)) else 0.0)
            self.v = torch.full_like(x_in, reset_val, device=current_device)
        
        # 确保阈值是正确类型的张量
        if v_threshold is None:
            v_threshold = torch.tensor(self.v_threshold_init, dtype=current_dtype, device=current_device)
        elif not isinstance(v_threshold, torch.Tensor):
            v_threshold = torch.tensor(float(v_threshold), dtype=current_dtype, device=current_device)
        else:
            v_threshold = v_threshold.to(device=current_device, dtype=current_dtype)
        
        # 更新膜电位 - 从状态变量获取当前值，但分离计算图
        v_current = self.v.detach().clone()
        v_next = v_current + x_in  # 累加输入电流
        
        # 广播阈值到输入形状（如果需要）
        if v_threshold.shape != v_next.shape and v_threshold.numel() == 1:
            v_threshold = v_threshold.expand_as(v_next)
        
        # 生成脉冲使用条件替代梯度
        if self.surrogate_gradient and self.training:
            # 训练模式下使用替代梯度
            spikes = SpikeFunction.apply(v_next, v_threshold)
        else:
            # 评估模式使用标准阶跃函数
            spikes = (v_next > v_threshold).float()
        
        # 重置发放神经元的膜电位
        reset_val = self.v_reset_val if isinstance(self.v_reset_val, (int, float)) else 0.0
        reset_tensor = torch.full_like(v_next, reset_val)
        
        # 通过torch.where创建新的膜电位，保持计算图
        new_v = torch.where(spikes.bool(), reset_tensor, v_next)
        
        # 更新内部状态，但分离计算图
        self.v = new_v.detach().clone()
        
        return spikes