import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .soma import Soma # 假设这些是您更新后的、保持形状的Soma
from .synapse import Synapse # 和 Synapse

# --- Shared base class for the SNN layer wrappers; reconciles the Synapse's trainable flag at construction ---
class SNNLayerBase(nn.Module):
    """Common scaffolding shared by the SNN layer wrappers in this module.

    Stores the working float dtype, the resolved device, one ``Soma`` and one
    ``Synapse`` instance, and the ``"total_bits"`` entry that the soma reports
    for the chosen dtype.

    Args:
        float_dtype: Floating-point dtype the layer operates in.
        device_str: Explicit device string. When ``None`` and CUDA is
            available, ``"cuda"`` is used; otherwise the given string, with
            a falsy value falling back to ``"cpu"``.
        trainable_synapse: Forwarded to ``Synapse(trainable=...)``.
    """

    def __init__(self, float_dtype: torch.dtype, device_str: str = None, trainable_synapse: bool = False):
        super().__init__()
        self.float_dtype = float_dtype

        # Auto-select CUDA only when the caller did not name a device.
        use_default_cuda = torch.cuda.is_available() and device_str is None
        if use_default_cuda:
            device_name = "cuda"
        else:
            device_name = device_str or "cpu"
        self.device = torch.device(device_name)

        self.soma = Soma()
        # The trainable flag is handed to the synapse at construction time.
        self.synapse = Synapse(trainable=trainable_synapse)
        self.num_bits_in_float = self.soma.get_float_format_params(float_dtype)["total_bits"]

        # Nothing left to do unless a trainable synapse was requested but the
        # constructed synapse does not report itself as trainable.
        if not trainable_synapse or self.synapse.trainable:
            return

        print("Warning: SNNLayerBase detected mismatch - trainable_synapse=True but Synapse is not trainable")
        self.synapse.trainable = True
        # Initialise a threshold parameter for every known float format,
        # provided the synapse exposes the initialiser.
        for format_params in self.synapse.FLOAT_FORMAT_PARAMS.values():
            if hasattr(self.synapse, '_initialize_threshold_param'):
                self.synapse._initialize_threshold_param(format_params)

    @property
    def weight(self):
        """Weight of the wrapped PyTorch layer, or ``None`` if no layer is attached."""
        if not hasattr(self, 'pytorch_layer'):
            return None
        return self.pytorch_layer.weight

class LinearLayerSNN(SNNLayerBase):
    """Bit-tensor wrapper around a ``torch.nn.Linear`` layer.

    ``forward`` passes the incoming bit tensor through the soma to obtain
    numerical values, applies the wrapped layer's affine transform, and hands
    the result to the synapse to produce the output bit tensor.

    Args:
        pytorch_linear_layer: The linear layer whose weight/bias are used.
        float_dtype: Floating-point dtype the layer operates in.
        device: Optional device string, forwarded to the base class.
        trainable_synapse: Forwarded to the base class.
    """

    def __init__(self, pytorch_linear_layer: torch.nn.Linear, float_dtype: torch.dtype,
                 device: str = None, trainable_synapse: bool = False):
        super().__init__(float_dtype, device, trainable_synapse)
        # Meta tensors carry no storage, so they must be materialised with
        # to_empty(); ordinary tensors are simply moved with to().
        weight_is_meta = getattr(pytorch_linear_layer.weight, 'is_meta', False)
        if weight_is_meta:
            placed_layer = pytorch_linear_layer.to_empty(device=self.device)
        else:
            placed_layer = pytorch_linear_layer.to(device=self.device)
        self.pytorch_layer = placed_layer
        self.in_features = pytorch_linear_layer.in_features
        self.out_features = pytorch_linear_layer.out_features

    def forward(self, x_bit_tensor_input: torch.Tensor) -> torch.Tensor:
        """Apply the linear transform to a bit tensor.

        Args:
            x_bit_tensor_input: Bit tensor laid out as
                ``(N_bits, *BatchDims, InFeatures)``.

        Returns:
            Whatever ``self.synapse.forward`` returns for the numerical
            result of shape ``(*BatchDims, OutFeatures)``.

        Raises:
            ValueError: If the first dimension is not ``num_bits_in_float``
                or the last dimension is not ``in_features``.
        """
        bit_shape = x_bit_tensor_input.shape
        if bit_shape[0] != self.num_bits_in_float:
            raise ValueError(f"Input bit tensor's first dim must be {self.num_bits_in_float}.")
        if bit_shape[-1] != self.in_features:
            raise ValueError(f"Input bit tensor's last effective numerical feature dimension "
                             f"({bit_shape[-1]}) must match layer in_features ({self.in_features}).")

        # The soma works on the device the input already lives on.
        decoded = self.soma.forward(x_bit_tensor_input, self.float_dtype, x_bit_tensor_input.device)

        # Collapse all batch dimensions so the matmul sees (Batch, InFeatures).
        flat_in = decoded.reshape(-1, self.in_features)

        # Apply the affine map through F.linear with explicitly fetched
        # parameters, casting them only when the weight dtype differs from
        # the input dtype.
        layer_weight = self.pytorch_layer.weight
        layer_bias = self.pytorch_layer.bias
        if layer_weight.dtype != flat_in.dtype:
            layer_weight = layer_weight.to(flat_in.dtype)
            if layer_bias is not None:
                layer_bias = layer_bias.to(flat_in.dtype)

        # Tensor.to() returns the tensor itself when the dtype already matches.
        flat_out = F.linear(flat_in, layer_weight, layer_bias).to(self.float_dtype)

        # Restore the original batch dimensions, now ending in OutFeatures.
        out_shape = bit_shape[1:-1] + (self.out_features,)
        return self.synapse.forward(flat_out.reshape(out_shape))

class ConvLayerSNN(SNNLayerBase):
    """Bit-tensor wrapper around a ``torch.nn.Conv2d`` layer.

    ``forward`` passes the incoming bit tensor through the soma to obtain
    numerical values, runs the wrapped convolution, and hands the result to
    the synapse to produce the output bit tensor.

    Args:
        pytorch_conv_layer: The convolution layer to wrap.
        float_dtype: Floating-point dtype the layer operates in.
        device: Optional device string, forwarded to the base class.
        trainable_synapse: Forwarded to the base class.
    """

    def __init__(self, pytorch_conv_layer: torch.nn.Conv2d, float_dtype: torch.dtype,
                 device: str = None, trainable_synapse: bool = False):
        super().__init__(float_dtype, device, trainable_synapse)
        # Meta tensors carry no storage, so they must be materialised with
        # to_empty(); ordinary tensors are simply moved with to().
        if hasattr(pytorch_conv_layer.weight, 'is_meta') and pytorch_conv_layer.weight.is_meta:
            self.pytorch_layer = pytorch_conv_layer.to_empty(device=self.device)
        else:
            self.pytorch_layer = pytorch_conv_layer.to(device=self.device)
        self.in_channels = pytorch_conv_layer.in_channels
        # out_channels is not stored: the synapse receives the conv output
        # directly and works from its shape.

    def forward(self, x_bit_tensor_input: torch.Tensor) -> torch.Tensor:
        """Apply the convolution to a bit tensor.

        Args:
            x_bit_tensor_input: Bit tensor laid out as
                ``(N_bits, N_batch, C_in, H_in, W_in)``.

        Returns:
            Whatever ``self.synapse.forward`` returns for the numerical
            convolution output ``(N, C_out, H_out, W_out)``.

        Raises:
            ValueError: If the first dimension is not ``num_bits_in_float``
                or dimension 2 is not ``in_channels``.
        """
        if x_bit_tensor_input.shape[0] != self.num_bits_in_float:
            raise ValueError(f"Input bit tensor's first dim must be {self.num_bits_in_float}.")
        if x_bit_tensor_input.shape[2] != self.in_channels:  # C_in sits at index 2
            raise ValueError("Input bit tensor's C_in dim (index 2) must match layer in_channels.")

        # The soma works on the device the input already lives on.
        input_device = x_bit_tensor_input.device
        x_numerical = self.soma.forward(x_bit_tensor_input, self.float_dtype, input_device)

        conv = self.pytorch_layer
        if x_numerical.dtype == conv.weight.dtype:
            # Dtypes agree: call the module itself so hooks still run.
            y_op_output = conv(x_numerical)
        else:
            # Cast the parameters to the input dtype (the cast is
            # differentiable, so gradients still reach the stored weights).
            weight_cast = conv.weight.to(x_numerical.dtype)
            bias_cast = None if conv.bias is None else conv.bias.to(x_numerical.dtype)
            # Delegate to the routine nn.Conv2d.forward itself uses instead
            # of a bare F.conv2d call: it honours stride, padding, dilation
            # and groups AND non-"zeros" padding modes (reflect / replicate /
            # circular), which a direct F.conv2d call would silently replace
            # with zero padding.
            y_op_output = conv._conv_forward(x_numerical, weight_cast, bias_cast)

        # Tensor.to() returns the tensor itself when the dtype already matches.
        y_quantized = y_op_output.to(self.float_dtype)

        # Synapse input: (N, C_out, H_out, W_out).
        return self.synapse.forward(y_quantized)

class ActivationLayerSNN(SNNLayerBase):
    """Bit-tensor wrapper around an element-wise activation module.

    ``forward`` passes the incoming bit tensor through the soma to obtain
    numerical values, applies the activation, and hands the result to the
    synapse to produce the output bit tensor.

    Args:
        activation_module_instance: The activation module to wrap.
        float_dtype: Floating-point dtype the layer operates in.
        device: Optional device string, forwarded to the base class.
        trainable_synapse: Forwarded to the base class.
    """

    # Activations that are evaluated in float32 when the input arrives in a
    # dtype other than float32/float64.
    _UPCAST_ACTIVATIONS = (nn.ReLU, nn.Sigmoid, nn.Tanh, nn.LeakyReLU, nn.SiLU)
    # Dtypes that must never be converted to float32 before the activation:
    # float32 needs no cast and float64 would lose precision.
    _FULL_PRECISION_DTYPES = (torch.float32, torch.float64)

    def __init__(self, activation_module_instance: torch.nn.Module,
                 float_dtype: torch.dtype, device: str = None, trainable_synapse: bool = False):
        super().__init__(float_dtype, device, trainable_synapse)
        # Only activations that own a `weight` can be on the meta device;
        # those must be materialised with to_empty(), everything else is
        # simply moved with to().
        module_weight = getattr(activation_module_instance, 'weight', None)
        if getattr(module_weight, 'is_meta', False):
            self.activation_module = activation_module_instance.to_empty(device=self.device)
        else:
            self.activation_module = activation_module_instance.to(self.device)

    def forward(self, x_bit_tensor_input: torch.Tensor) -> torch.Tensor:
        """Apply the activation to a bit tensor.

        Args:
            x_bit_tensor_input: Bit tensor laid out as ``(N_bits, *NumericalDims)``.

        Returns:
            Whatever ``self.synapse.forward`` returns for the activated
            values, cast to ``float_dtype``.

        Raises:
            ValueError: If the first dimension is not ``num_bits_in_float``.
        """
        if x_bit_tensor_input.shape[0] != self.num_bits_in_float:
            raise ValueError(f"Input bit tensor's first dim must be {self.num_bits_in_float}.")

        # The soma works on the device the input already lives on.
        input_device = x_bit_tensor_input.device
        x_numerical = self.soma.forward(x_bit_tensor_input, self.float_dtype, input_device)

        # Evaluate the listed activations in float32 for reduced-precision
        # inputs. float64 is deliberately excluded: routing it through
        # float32 would throw away precision that the final cast back to
        # float_dtype cannot recover.
        needs_upcast = (
            x_numerical.dtype not in self._FULL_PRECISION_DTYPES
            and isinstance(self.activation_module, self._UPCAST_ACTIVATIONS)
        )
        if needs_upcast:
            x_numerical = x_numerical.to(torch.float32)
        activated = self.activation_module(x_numerical)

        # Tensor.to() returns the tensor itself when the dtype already matches.
        activated_quantized = activated.to(self.float_dtype)

        # Synapse input: (*NumericalDims).
        return self.synapse.forward(activated_quantized)

class EmbeddingLayerSNN(SNNLayerBase):
    """Wrapper around a ``torch.nn.Embedding`` layer that emits a bit tensor.

    Unlike the other layers in this module, the input is a tensor of integer
    indices rather than a bit tensor; only the output goes through the
    synapse.

    Args:
        pytorch_embedding_layer: The embedding layer to wrap.
        float_dtype: Floating-point dtype the embedding weights are cast to.
        device: Optional device string, forwarded to the base class.
        trainable_synapse: Forwarded to the base class.
    """

    def __init__(self, pytorch_embedding_layer: torch.nn.Embedding,
                 float_dtype: torch.dtype, device: str = None, trainable_synapse: bool = False):
        super().__init__(float_dtype, device, trainable_synapse)
        # Meta tensors carry no storage, so they must be materialised with
        # to_empty(); ordinary tensors are simply moved with to().
        if hasattr(pytorch_embedding_layer.weight, 'is_meta') and pytorch_embedding_layer.weight.is_meta:
            self.pytorch_layer = pytorch_embedding_layer.to_empty(device=self.device)
        else:
            self.pytorch_layer = pytorch_embedding_layer.to(device=self.device)

        # Keep the embedding table in float_dtype so lookups come out in the
        # layer's precision. Skipped if the weight is still a meta tensor,
        # which has no data to cast.
        if not hasattr(self.pytorch_layer.weight, 'is_meta') or not self.pytorch_layer.weight.is_meta:
            self.pytorch_layer.weight.data = self.pytorch_layer.weight.data.to(float_dtype)

        self.embedding_dim = pytorch_embedding_layer.embedding_dim

    def forward(self, x_indices: torch.Tensor) -> torch.Tensor:
        """Look up embeddings for integer indices and encode them as bits.

        Args:
            x_indices: Integer index tensor (not a bit tensor). A warning is
                printed if its dtype is not int64 or int32.

        Returns:
            Whatever ``self.synapse.forward`` returns for the looked-up
            embeddings of shape ``(*x_indices.shape, embedding_dim)``.
        """
        # torch.long / torch.int are aliases of int64 / int32.
        if x_indices.dtype not in (torch.int64, torch.int32):
            print(f"警告: EmbeddingLayerSNN输入x_indices类型为{x_indices.dtype}，通常应为Long或Int类型。")

        # Follow the embedding table's *current* device rather than the
        # device cached at construction: the cached value goes stale if the
        # module is later moved with .to()/.cuda(), which would otherwise
        # cause a device-mismatch error in the lookup.
        x_indices = x_indices.to(self.pytorch_layer.weight.device)

        embedding_output = self.pytorch_layer(x_indices)

        # Synapse input: (*x_indices.shape, embedding_dim).
        return self.synapse.forward(embedding_output)