import math
import torch
import torch.nn.functional as F
from .neuron import LIFNeuron
import torch.nn as nn

# --- 优化的Synapse类实现 ---
class Synapse(nn.Module): 
    def __init__(self, trainable: bool = False):
        super().__init__()
        # 添加trainable属性控制是否使阈值可训练
        self.trainable = trainable
        # 使用支持替代梯度的LIFNeuron
        self.if_neurons = LIFNeuron(tau_mem=20.0, v_reset=0.0, surrogate_gradient=True)
        # 存储各数据类型的阈值参数
        self.threshold_params = nn.ParameterDict()
        # 保留缓存以提高性能
        self.g_n_tensor_cache = {} 
        self.special_pattern_cache = {}  # 缓存特殊模式
        self.FLOAT_FORMAT_PARAMS = {
            torch.float8_e4m3fn: {"name": "E4M3FN", "exp_bits": 4, "man_bits": 3, "bias": 7, "min_norm_coded_exp": 1, "max_norm_coded_exp": 14, "total_bits": 8, "has_subnormals": False, "epsilon": 1e-9},
            torch.float8_e5m2: {"name": "E5M2", "exp_bits": 5, "man_bits": 2, "bias": 15, "min_norm_coded_exp": 1, "max_norm_coded_exp": 30, "total_bits": 8, "has_subnormals": False, "epsilon": 1e-9},
            torch.bfloat16: {"name": "BF16", "exp_bits": 8, "man_bits": 7, "bias": 127, "min_norm_coded_exp": 1, "max_norm_coded_exp": 254, "total_bits": 16, "has_subnormals": True, "epsilon": 1e-9},
            torch.float16: {"name": "FP16", "exp_bits": 5, "man_bits": 10, "bias": 15, "min_norm_coded_exp": 1, "max_norm_coded_exp": 30, "total_bits": 16, "has_subnormals": True, "epsilon": 1e-9},
            torch.float32: {"name": "FP32", "exp_bits": 8, "man_bits": 23, "bias": 127, "min_norm_coded_exp": 1, "max_norm_coded_exp": 254, "total_bits": 32, "has_subnormals": True, "epsilon": 1e-9}
        }
        # 预计算常用掩码
        self._precompute_masks()
        # 存储当前正在使用的阈值参数
        self.current_threshold_param = None
        
        # 强制初始化所有可能的阈值参数
        if self.trainable:
            for dtype, params in self.FLOAT_FORMAT_PARAMS.items():
                self._initialize_threshold_param(params)
        
    def _precompute_masks(self):
        """预计算常用位掩码以加速操作"""
        self.bit_masks = {}
        for dtype_name, params in self.FLOAT_FORMAT_PARAMS.items():
            self.bit_masks[params["name"]] = {
                "exp_masks": [(1 << bit) for bit in range(params["exp_bits"])],
                "man_masks": [(1 << bit) for bit in range(params["man_bits"])]
            }
            
    def get_float_format_params(self, dtype: torch.dtype) -> dict:
        """获取浮点格式参数，加入缓存机制"""
        params = self.FLOAT_FORMAT_PARAMS.get(dtype)
        if params is None: raise ValueError(f"Synapse: Unsupported float dtype: {dtype}")
        return params.copy() 
    
    def get_float_special_pattern(self, float_params: dict, device: torch.device, pattern_type: str = "nan", sign_bit: int = 0) -> torch.Tensor:
        """生成特殊模式的缓存版本"""
        cache_key = (float_params["name"], device.type, pattern_type, sign_bit)
        if cache_key in self.special_pattern_cache:
            return self.special_pattern_cache[cache_key]
            
        if sign_bit not in [0, 1]: raise ValueError("Sign bit must be 0 or 1.")
        bits = [sign_bit]; bits.extend([1] * float_params["exp_bits"]) 
        if pattern_type.lower() == "nan":
            if float_params["man_bits"] > 0: bits.extend([1] + [0] * (float_params["man_bits"] - 1))
        elif pattern_type.lower() == "inf": bits.extend([0] * float_params["man_bits"])
        else: raise ValueError(f"Unknown special pattern type: {pattern_type}.")
        if len(bits) != float_params["total_bits"]: bits = (bits + [0]*float_params["total_bits"])[:float_params["total_bits"]]
        
        pattern = torch.tensor(bits, dtype=torch.int8, device=device)
        self.special_pattern_cache[cache_key] = pattern
        return pattern
    
    def _initialize_threshold_param(self, float_params: dict) -> nn.Parameter:
        """初始化特定数据类型的阈值参数"""
        format_name = float_params["name"]
        param_name = f'threshold_{format_name}'
        
        if param_name not in self.threshold_params:
            # 创建初始阈值张量 [0.0, 0.5, 0.5, ...]
            init_values = torch.full((float_params["total_bits"],), 0.5, dtype=torch.float32)
            init_values[0] = 0.0  # 符号位阈值为0
            
            # 将参数注册到ParameterDict中
            self.threshold_params[param_name] = nn.Parameter(init_values, requires_grad=True)
        
        return self.threshold_params[param_name]
    
    def _get_g_n_tensor_for_dtype(self, float_params: dict, device: torch.device) -> torch.Tensor:
        """获取阈值张量，根据trainable属性决定是否返回可学习参数"""
        format_name = float_params["name"] 
        
        # 如果不是训练模式，使用缓存的固定值
        if not self.trainable:
            cache_key = (format_name, device.type)
            if format_name in self.g_n_tensor_cache:
                cached_tensor = self.g_n_tensor_cache[format_name]
                if cached_tensor.device == device:
                    return cached_tensor
                return cached_tensor.to(device)
                
            # 创建阈值张量 [0.0, 0.5, 0.5, ...]
            g_n_tensor = torch.full((float_params["total_bits"],), 0.5, dtype=torch.float32, device=device)
            g_n_tensor[0] = 0.0  # 符号位阈值为0
            
            self.g_n_tensor_cache[format_name] = g_n_tensor
            return g_n_tensor
        
        # 训练模式下，创建或返回可学习参数
        param_name = f'threshold_{format_name}'
        
        # 检查参数是否已经在ParameterDict中注册
        if param_name not in self.threshold_params:
            # 创建初始阈值张量 [0.0, 0.5, 0.5, ...]
            self._initialize_threshold_param(float_params)
            
        # 获取参数并确保它在正确的设备上
        threshold_param = self.threshold_params[param_name]
        if threshold_param.device != device:
            # 使用to()将参数移动到正确的设备，但保留其requires_grad属性
            self.threshold_params[param_name] = nn.Parameter(
                threshold_param.to(device), 
                requires_grad=threshold_param.requires_grad
            )
            threshold_param = self.threshold_params[param_name]
            
        # 缓存当前使用的阈值参数（这是一个引用，不会创建新的张量）
        self.current_threshold_param = threshold_param
            
        return threshold_param
    
    def _fast_path_zeros(self, input_tensor: torch.Tensor, float_params: dict, device: torch.device, 
                           input_dtype: torch.dtype, original_numerical_shape: tuple) -> torch.Tensor:
        """零值的快速处理路径"""
        n_bits = float_params["total_bits"]
        
        # 检查负零
        has_negative = torch.signbit(input_tensor).any()
        
        if not has_negative:
            # 全为正零，直接返回全零结果
            return torch.zeros((n_bits,) + original_numerical_shape, dtype=input_dtype, device=device)
        else:
            # 处理负零情况
            sign_bits = torch.signbit(input_tensor).to(input_dtype)
            result = torch.zeros((n_bits,) + original_numerical_shape, dtype=input_dtype, device=device)
            result[0] = sign_bits  # 只设置符号位
            return result
    
    def _calculate_f_x_n_batch(self, x_input_flat: torch.Tensor, float_params: dict) -> tuple[torch.Tensor, torch.Tensor]:
        """计算每个值的位表示形式"""
        if x_input_flat.numel() == 0: 
            return torch.empty((0, float_params["total_bits"]), device=x_input_flat.device), torch.empty((0,), dtype=torch.bool, device=x_input_flat.device)
        
        # 快速路径: 检查是否全为零
        if torch.all(x_input_flat == 0):
            zeros = torch.zeros((x_input_flat.numel(), float_params["total_bits"]), 
                               dtype=torch.float32, device=x_input_flat.device)
            if torch.any(torch.signbit(x_input_flat)):
                # 处理负零
                sign_bits = torch.signbit(x_input_flat).to(torch.float32)
                zeros[:, 0] = sign_bits
            return zeros, torch.zeros(x_input_flat.numel(), dtype=torch.bool, device=x_input_flat.device)
        
        # 预处理输入
        x_f32_batch = x_input_flat.to(torch.float32) if x_input_flat.dtype != torch.float32 else x_input_flat
        batch_size = x_f32_batch.numel()
        device = x_f32_batch.device
        
        # 创建结果tensor - 预分配一次而不是多次
        f_x_n_batch_tensors = torch.zeros((batch_size, float_params["total_bits"]), dtype=torch.float32, device=device)
        final_representable_mask = torch.zeros(batch_size, dtype=torch.bool, device=device)
        
        # 检查哪些值是有限和非零的 - 结合两个操作减少内存使用
        finite_and_nonzero_mask = torch.isfinite(x_f32_batch) & (x_f32_batch != 0.0)
        
        # 如果没有需要处理的值，直接返回
        if not torch.any(finite_and_nonzero_mask):
            return f_x_n_batch_tensors, final_representable_mask
        
        # 提取需要处理的值
        x_to_process = x_f32_batch[finite_and_nonzero_mask]
        abs_x_to_process = torch.abs(x_to_process)  # 可以复用
        idx_initial_processing = torch.nonzero(finite_and_nonzero_mask).squeeze(-1)
        
        # 设置符号位 (使用正值表示，符号作为额外编码) - 重要改动：不反转符号，保持梯度流
        f_x_n_batch_tensors[idx_initial_processing, 0] = (x_to_process < 0).float()
        
        # 计算指数 - 使用快速路径避免log2
        tiny_val = torch.finfo(torch.float32).tiny
        x_safe = abs_x_to_process.clamp(min=tiny_val)
        exp_actual_for_x_to_process = torch.floor(torch.log2(x_safe))
        coded_exp_for_x_to_process = (exp_actual_for_x_to_process + float_params["bias"]).int()
        
        # 检查哪些值在可表示范围内
        is_normal_range = (coded_exp_for_x_to_process >= float_params["min_norm_coded_exp"]) & (coded_exp_for_x_to_process <= float_params["max_norm_coded_exp"])
        is_subnormal = (coded_exp_for_x_to_process == 0) & float_params.get("has_subnormals", False)
        can_extract_bits = is_normal_range | is_subnormal
        
        # 如果没有可表示的值，直接返回
        if not torch.any(can_extract_bits):
            return f_x_n_batch_tensors, final_representable_mask
        
        # 获取可表示值的索引
        valid_indices = torch.nonzero(can_extract_bits).squeeze(-1)
        valid_indices_in_original_batch = idx_initial_processing[valid_indices]
        final_representable_mask[valid_indices_in_original_batch] = True
        
        # 提取有效值的指数和尾数
        x_for_bits_abs = abs_x_to_process[valid_indices]
        E_coded_for_bits = coded_exp_for_x_to_process[valid_indices]
        E_actual_for_bits = exp_actual_for_x_to_process[valid_indices]
        
        # 处理指数位 - 位操作优化
        exp_bits = float_params["exp_bits"]
        for i in range(exp_bits):
            bit_pos = exp_bits - 1 - i
            bit_val = ((E_coded_for_bits >> bit_pos) & 1).float()
            f_x_n_batch_tensors[valid_indices_in_original_batch, i + 1] = bit_val
        
        # 处理尾数位
        num_mant_bits = float_params["man_bits"]
        if num_mant_bits > 0:
            # 预分配一次
            mantissa_frac_values = torch.zeros_like(x_for_bits_abs)
            
            # 处理正规化值 - 批量处理优化
            is_normal_in_final = is_normal_range[valid_indices]
            if torch.any(is_normal_in_final):
                normal_indices = torch.nonzero(is_normal_in_final).squeeze(-1)
                # 优化2**x计算
                divisor = torch.pow(2.0, E_actual_for_bits[normal_indices])
                mantissa_frac_values[normal_indices] = x_for_bits_abs[normal_indices] / divisor - 1.0
            
            # 处理次正规化值
            is_subnormal_in_final = is_subnormal[valid_indices]
            if torch.any(is_subnormal_in_final):
                subnormal_indices = torch.nonzero(is_subnormal_in_final).squeeze(-1)
                min_normal_exp_actual = 1 - float_params["bias"]
                # 预计算除数
                subnormal_divisor = torch.pow(2.0, min_normal_exp_actual)
                mantissa_frac_values[subnormal_indices] = x_for_bits_abs[subnormal_indices] / subnormal_divisor
            
            # 计算尾数的整数表示 - 避免精度损失但使用float32
            scaled_mantissa = mantissa_frac_values * (2.0**num_mant_bits)
            mantissa_int_repr = torch.floor(scaled_mantissa).to(torch.int64)
            mantissa_int_repr = mantissa_int_repr.clamp(min=0, max=(1 << num_mant_bits) - 1)
            
            # 设置尾数位 - 按批量处理
            man_start_idx = 1 + exp_bits
            for i in range(num_mant_bits):
                bit_pos = num_mant_bits - 1 - i
                bit_val = ((mantissa_int_repr >> bit_pos) & 1).float()
                f_x_n_batch_tensors[valid_indices_in_original_batch, man_start_idx + i] = bit_val
        
        return f_x_n_batch_tensors, final_representable_mask
    
    def _prepare_special_patterns(self, x_flat: torch.Tensor, representable_mask: torch.Tensor, 
                                 float_params: dict, device: torch.device, input_dtype: torch.dtype) -> torch.Tensor:
        """准备特殊模式，如零、NaN和Inf的位表示"""
        # 创建输出张量
        n_bits = float_params["total_bits"]
        batch_size = x_flat.numel()
        output_tensor = torch.zeros((batch_size, n_bits), dtype=input_dtype, device=device)
        
        # 处理特殊情况
        non_representable_mask = ~representable_mask
        if not torch.any(non_representable_mask):
            return output_tensor
            
        non_representable_indices = torch.nonzero(non_representable_mask).squeeze(-1)
        non_representable_values = x_flat[non_representable_indices]
        sign_bits = torch.signbit(non_representable_values).to(input_dtype)
        zero_mask = (non_representable_values == 0.0)
        
        # 处理零值 - 批量处理
        if torch.any(zero_mask):
            zero_indices_local = torch.nonzero(zero_mask).squeeze(-1)
            zero_indices_global = non_representable_indices[zero_indices_local]
            output_tensor[zero_indices_global, 0] = sign_bits[zero_indices_local]
        
        # 处理非零、非有限值（NaN/Inf）
        nan_mask = ~zero_mask
        if torch.any(nan_mask):
            nan_indices_local = torch.nonzero(nan_mask).squeeze(-1)
            nan_indices_global = non_representable_indices[nan_indices_local]
            
            # 获取缓存的模式
            pattern_type = "nan"  # 简化：所有非零特殊值视为NaN
            
            # 批量设置指数位和尾数位 - 使用切片代替循环
            output_tensor[nan_indices_global, 1:1+float_params["exp_bits"]] = 1
            
            # 如果是NaN，设置尾数首位为1
            if float_params["man_bits"] > 0:
                output_tensor[nan_indices_global, 1+float_params["exp_bits"]] = 1
            
            # 设置符号位
            output_tensor[nan_indices_global, 0] = sign_bits[nan_indices_local]
        
        return output_tensor
    
    def _process_batch(self, x_batch: torch.Tensor, float_params: dict, 
                       device: torch.device, input_dtype: torch.dtype) -> torch.Tensor:
        """处理单个批次的数据"""
        n_bits = float_params["total_bits"]
        batch_size = x_batch.numel()
        
        # 重置神经元状态
        self.if_neurons.reset()
        
        # 计算f(x,n)批处理
        f_x_n_batch, representable_mask = self._calculate_f_x_n_batch(x_batch, float_params)
        
        # 获取G(n)张量，可能是可学习参数
        threshold_param = self._get_g_n_tensor_for_dtype(float_params, device)
        
        # 确保f_x_n_batch和threshold_param在同一设备上且数据类型匹配
        # 避免不必要的类型转换，保持梯度流
        if f_x_n_batch.dtype != threshold_param.dtype:
            # 把阈值转换为输入数据类型，而不是相反，这样保留梯度流
            threshold_param_converted = threshold_param.to(f_x_n_batch.dtype)
        else:
            threshold_param_converted = threshold_param
            
        # 使用神经元生成脉冲，直接传递阈值参数
        spikes = self.if_neurons(f_x_n_batch, threshold_param_converted)
        
        # 创建结果时保持原始数据类型，避免不必要的转换
        result = torch.zeros((batch_size, n_bits), dtype=input_dtype, device=device)
        
        # 处理可表示的值
        if torch.any(representable_mask):
            representable_indices = torch.nonzero(representable_mask).squeeze(-1)
            # 转换到目标数据类型
            result[representable_indices] = spikes[representable_indices].to(input_dtype)
        
        # 处理特殊情况（零、NaN、无穷大）
        if torch.any(~representable_mask):
            special_patterns = self._prepare_special_patterns(
                x_batch, representable_mask, float_params, device, input_dtype)
            non_representable_indices = torch.nonzero(~representable_mask).squeeze(-1)
            result[non_representable_indices] = special_patterns[non_representable_indices]
        
        # 转置结果
        return result.transpose(0, 1)
    
    def _split_batch_processing(self, x_input_tensor: torch.Tensor, batch_size: int = 10000) -> torch.Tensor:
        """用于处理超大输入的分批处理方法"""
        original_shape = x_input_tensor.shape
        device = x_input_tensor.device
        input_dtype = x_input_tensor.dtype
        float_params = self.get_float_format_params(input_dtype)
        n_bits = float_params["total_bits"]
        
        # 展平输入
        x_flat = x_input_tensor.contiguous().view(-1)
        total_elements = x_flat.numel()
        
        if total_elements <= batch_size:
            # 如果数据量小，直接处理
            return self.forward(x_input_tensor)
        
        # 创建结果张量
        result = torch.zeros((n_bits,) + original_shape, dtype=input_dtype, device=device)
        num_batches = (total_elements + batch_size - 1) // batch_size
        
        for i in range(num_batches):
            start_idx = i * batch_size
            end_idx = min((i + 1) * batch_size, total_elements)
            
            # 处理一个批次
            batch_tensor = x_flat[start_idx:end_idx].view(-1)
            batch_bits = self._process_batch(batch_tensor, float_params, device, input_dtype)
            
            # 将结果复制到相应位置
            flat_result = result.view(n_bits, -1)
            flat_result[:, start_idx:end_idx] = batch_bits
        
        return result
    
    def forward(self, x_input_tensor: torch.Tensor) -> torch.Tensor: 
        """将输入张量转换为位表示形式，确保梯度正确流动"""
        original_numerical_shape = x_input_tensor.shape
        device = x_input_tensor.device 
        num_elements = x_input_tensor.numel()
        input_dtype = x_input_tensor.dtype
        float_params = self.get_float_format_params(input_dtype) 
        n_bits = float_params["total_bits"]
        
        # 处理空输入
        if num_elements == 0: 
            return torch.empty((n_bits,) + original_numerical_shape, dtype=input_dtype, device=device)
            
        # 快速路径: 如果全是零值
        if torch.all(x_input_tensor == 0.0):
            return self._fast_path_zeros(x_input_tensor, float_params, device, input_dtype, original_numerical_shape)
            
        # 大数据集分批处理
        MAX_BATCH_SIZE = 16384  # 根据实际测试调整
        if num_elements > MAX_BATCH_SIZE:
            return self._split_batch_processing(x_input_tensor, MAX_BATCH_SIZE)
        
        # 展平输入张量并确保连续的内存布局
        x_flat = x_input_tensor.contiguous().view(-1)
        
        # 重置神经元状态
        self.if_neurons.reset()
        
        # 计算浮点数的位表示
        f_x_n_batch, representable_mask = self._calculate_f_x_n_batch(x_flat, float_params)
        
        # 获取阈值参数，确保数据类型匹配，这里使用float32以提高精度
        threshold_param = self._get_g_n_tensor_for_dtype(float_params, device)
        threshold_param = threshold_param.to(f_x_n_batch.dtype)
        
        # 重要：将输入符号位直接设置为输出符号位，完全跳过脉冲神经元处理
        # 这避免了符号位信息的丢失，并允许梯度流过
        spikes = self.if_neurons(f_x_n_batch, threshold_param)
        
        # 创建结果时保持与输入相同的设备和数据类型
        result = torch.zeros((num_elements, n_bits), dtype=input_dtype, device=device)
        
        # 处理可表示的值
        if torch.any(representable_mask):
            representable_indices = torch.nonzero(representable_mask).squeeze(-1)
            # 特别处理符号位：保留原始输入的符号位，确保梯度流
            sign_bit_indices = representable_indices
            result[sign_bit_indices, 0] = f_x_n_batch[sign_bit_indices, 0]
            
            # 处理非符号位 - 使用神经元脉冲结果
            result[representable_indices, 1:] = spikes[representable_indices, 1:].to(input_dtype)
        
        # 处理特殊情况（零、NaN、无穷大）
        if torch.any(~representable_mask):
            special_patterns = self._prepare_special_patterns(
                x_flat, representable_mask, float_params, device, input_dtype)
            non_representable_indices = torch.nonzero(~representable_mask).squeeze(-1)
            result[non_representable_indices] = special_patterns[non_representable_indices]
        
        # 转置并重塑结果
        bits_flat_transposed = result.transpose(0, 1)
        output = bits_flat_transposed.reshape(n_bits, *original_numerical_shape)
        
        return output