import math
import torch
import torch.nn.functional as F
import torch.nn as nn
import os
from .soma import Soma
from .synapse import Synapse
from .layer import *
from .neuron import LIFNeuron

class SNNModel(nn.Module):
    """SNN counterpart of an ``nn.Sequential`` ANN.

    Every supported ANN module (``nn.Linear``, ``nn.ReLU``/``nn.Sigmoid``/
    ``nn.Tanh``/``nn.SiLU``, ``nn.Embedding``, ``nn.Conv2d``) is wrapped in the
    matching ``*LayerSNN`` class and stored in ``snn_layers``. ``forward``
    encodes the numerical input with ``entry_synapse`` (skipped when the first
    ANN layer is an embedding), runs the wrapped layers in order and decodes
    the result with ``exit_soma``.
    """

    def __init__(self, ann_sequential_model: nn.Sequential, float_dtype: torch.dtype, device: str = None,
                 trainable_entry_synapse: bool = False, trainable_layers_synapse: bool = False,
                 trainable_layers_indices: list = None):
        """
        Build the SNN model from an ANN ``nn.Sequential``.

        Args:
            ann_sequential_model: The original PyTorch Sequential model.
            float_dtype: Floating point type, e.g. ``torch.float32``.
            device: Device string. When ``None``, ``self.device`` resolves to
                ``"cuda"`` if available, otherwise ``"cpu"``.
            trainable_entry_synapse: Whether the entry synapse threshold is
                trainable.
            trainable_layers_synapse: Whether the synapse thresholds of all
                layers are trainable.
            trainable_layers_indices: Indices (into ``ann_sequential_model``)
                of the layers whose synapse should be trainable. When given it
                takes precedence over ``trainable_layers_synapse``.
        """
        super().__init__()
        self.float_dtype = float_dtype

        if device is None and torch.cuda.is_available():
            resolved_device = "cuda"
        else:
            resolved_device = device or "cpu"
        self.device = torch.device(resolved_device)

        self.entry_synapse = Synapse(trainable=trainable_entry_synapse)
        self.exit_soma = Soma()

        self.snn_layers = nn.ModuleList()
        # Position of every wrapped layer inside the original Sequential.
        # Unsupported modules are skipped below, so positions in `snn_layers`
        # can differ from the "<index>." prefixes of a Sequential checkpoint.
        self._ann_layer_indices = []

        # An embedding as first layer consumes the raw input directly.
        self.first_layer_is_embedding = (
            len(ann_sequential_model) > 0
            and isinstance(ann_sequential_model[0], nn.Embedding)
        )

        # Without an explicit index list, the global flag selects all or none.
        if trainable_layers_indices is None:
            if trainable_layers_synapse:
                trainable_layers_indices = list(range(len(ann_sequential_model)))
            else:
                trainable_layers_indices = []

        for ann_index, module in enumerate(ann_sequential_model):
            if isinstance(module, nn.Linear):
                wrapper_cls = LinearLayerSNN
            elif isinstance(module, (nn.ReLU, nn.Sigmoid, nn.Tanh, nn.SiLU)):
                wrapper_cls = ActivationLayerSNN
            elif isinstance(module, nn.Embedding):
                wrapper_cls = EmbeddingLayerSNN
            elif isinstance(module, nn.Conv2d):
                wrapper_cls = ConvLayerSNN
            else:
                # NOTE(review): unsupported modules are dropped silently, as
                # before; confirm that this is intended for e.g. nn.Flatten.
                continue

            # NOTE(review): the layers receive the raw `device` argument
            # (possibly None), not the resolved `self.device`.
            self.snn_layers.append(
                wrapper_cls(module, float_dtype, device,
                            trainable_synapse=ann_index in trainable_layers_indices)
            )
            self._ann_layer_indices.append(ann_index)

    def forward(self, x_numerical_input: torch.Tensor) -> torch.Tensor:
        """
        Run the input through the SNN layers and decode the result.

        Args:
            x_numerical_input: Numerical input tensor. When the first layer is
                an embedding it is handed to that layer unchanged; otherwise
                it is first encoded by ``entry_synapse``.

        Returns:
            The value returned by ``exit_soma.forward`` for the output of the
            last SNN layer.
        """
        input_device = x_numerical_input.device

        if self.first_layer_is_embedding:
            # The embedding layer takes the raw input, no synapse encoding.
            current_data_bits = self.snn_layers[0](x_numerical_input)
            remaining_layers = self.snn_layers[1:]
        else:
            # Initial encoding: numerical values -> bits.
            current_data_bits = self.entry_synapse.forward(x_numerical_input)
            remaining_layers = self.snn_layers

        for layer_module in remaining_layers:
            current_data_bits = layer_module(current_data_bits)

        # Final decoding: bits -> numerical values, on the input's device.
        return self.exit_soma.forward(current_data_bits, self.float_dtype, str(input_device))

    def load_pretrained_weights(self, state_dict, strict=True, format_type="sequential"):
        """
        Load parameters of a pretrained ANN into the wrapped ``pytorch_layer``s.

        Args:
            state_dict (dict): State dict of the pretrained model.
            strict (bool): When True, every ``snn_layers.*.pytorch_layer.*``
                entry of this model must receive a value and every mapped key
                must exist in this model. State that belongs only to the SNN
                part is never expected from the checkpoint.
            format_type (str): Layout of the checkpoint keys:
                - "sequential": ``nn.Sequential`` keys such as "0.weight",
                  "1.bias".
                - "named": named modules such as "encoder.layer1.weight";
                  matched heuristically by layer type name and tensor shape.
                - "flat": bare "weight" / "bias" keys (single-layer models).

        Returns:
            Number of parameters mapped from ``state_dict`` (0 when nothing
            matched, in which case the model is left untouched).

        Raises:
            ValueError: If ``format_type`` is not supported.
            RuntimeError: If ``strict`` is True and the mapped parameters do
                not cover the wrapped layers exactly, or on a shape mismatch.
        """
        new_state_dict = {}
        loaded_params = 0
        own_state = self.state_dict()

        if format_type == "sequential":
            # Checkpoint prefixes follow the original Sequential indices,
            # which differ from snn_layers indices if modules were skipped.
            ann_indices = getattr(self, "_ann_layer_indices", None)
            if ann_indices is None or len(ann_indices) != len(self.snn_layers):
                ann_indices = range(len(self.snn_layers))

            snn_index_by_prefix = {
                str(ann_index): snn_index
                for snn_index, (ann_index, layer) in enumerate(zip(ann_indices, self.snn_layers))
                if hasattr(layer, 'pytorch_layer')
            }

            for key, value in state_dict.items():
                # "2.weight" -> head "2", param_name "weight".
                head, dot, param_name = key.partition(".")
                if not dot or head not in snn_index_by_prefix:
                    continue
                new_key = f"snn_layers.{snn_index_by_prefix[head]}.pytorch_layer.{param_name}"
                new_state_dict[new_key] = value.to(device=self.device, dtype=self.float_dtype)
                loaded_params += 1

        elif format_type == "named":
            # (state-dict prefix, type tag) for every wrapped layer that can
            # be matched by type name.
            candidates = []
            for snn_index, layer in enumerate(self.snn_layers):
                if not hasattr(layer, 'pytorch_layer'):
                    continue
                if isinstance(layer, LinearLayerSNN):
                    type_tag = 'linear'
                elif isinstance(layer, ConvLayerSNN):
                    type_tag = 'conv'
                elif isinstance(layer, EmbeddingLayerSNN):
                    type_tag = 'embedding'
                else:
                    continue
                candidates.append((f"snn_layers.{snn_index}.pytorch_layer.", type_tag))

            for key, value in state_dict.items():
                lowered_key = key.lower()
                target_name = None
                for prefix, type_tag in candidates:
                    if type_tag not in lowered_key:
                        continue
                    for own_name, own_tensor in own_state.items():
                        # Skip targets that are already taken so that several
                        # same-shaped layers are filled one after another
                        # instead of all overwriting the first one.
                        if (own_name.startswith(prefix)
                                and own_name not in new_state_dict
                                and own_tensor.shape == value.shape):
                            target_name = own_name
                            break
                    if target_name is not None:
                        break

                if target_name is not None:
                    new_state_dict[target_name] = value.to(device=self.device, dtype=self.float_dtype)
                    loaded_params += 1

        elif format_type == "flat":
            # Only meaningful for a model consisting of one wrapped layer.
            if len(self.snn_layers) == 1 and hasattr(self.snn_layers[0], 'pytorch_layer'):
                for key, value in state_dict.items():
                    if key in ["weight", "bias"]:
                        new_key = f"snn_layers.0.pytorch_layer.{key}"
                        new_state_dict[new_key] = value.to(device=self.device, dtype=self.float_dtype)
                        loaded_params += 1

        else:
            raise ValueError(f"不支持的参数格式类型: {format_type}")

        if loaded_params == 0:
            print("警告: 没有找到任何可加载的参数！请检查预训练模型格式。")
            return 0

        # Strictness is evaluated on the wrapped ANN parameters only. Passing
        # strict=True to load_state_dict would raise for every other entry of
        # this model's state dict, which an ANN checkpoint cannot provide.
        missing_keys = [
            name for name in own_state
            if name.startswith("snn_layers.")
            and ".pytorch_layer." in name
            and name not in new_state_dict
        ]
        unexpected_keys = [name for name in new_state_dict if name not in own_state]

        # Fail before touching the model so a strict error leaves it unchanged.
        if strict and (missing_keys or unexpected_keys):
            raise RuntimeError(
                f"严格模式下参数不匹配: 缺失 {missing_keys}, 多余 {unexpected_keys}"
            )

        self.load_state_dict(new_state_dict, strict=False)

        if missing_keys:
            print(f"警告: 无法加载的参数: {missing_keys}")
        if unexpected_keys:
            print(f"警告: 预训练模型中存在但当前模型中没有的参数: {unexpected_keys}")

        return loaded_params

    @classmethod
    def from_pretrained(cls, ann_sequential_model, pretrained_path, float_dtype=torch.float32,
                        device=None, strict=True, format_type="sequential",
                        trainable_entry_synapse=False, trainable_layers_synapse=False,
                        trainable_layers_indices=None):
        """
        Create a model and load pretrained parameters from a checkpoint file.

        Args:
            ann_sequential_model: Original PyTorch Sequential model structure
                (its parameter values are not needed).
            pretrained_path: Path of the pretrained weights file.
            float_dtype: Floating point precision.
            device: Device name.
            strict: Forwarded to ``load_pretrained_weights``.
            format_type: Key layout, one of "sequential", "named", "flat".
            trainable_entry_synapse: Whether the entry synapse threshold is
                trainable.
            trainable_layers_synapse: Whether the synapse thresholds of all
                layers are trainable.
            trainable_layers_indices: Indices of the layers whose synapse
                should be trainable when not all of them are.

        Returns:
            An ``SNNModel`` instance with the pretrained parameters loaded.

        Raises:
            FileNotFoundError: If ``pretrained_path`` does not exist.
            ValueError: If the checkpoint is not a dict.
        """
        if not os.path.exists(pretrained_path):
            raise FileNotFoundError(f"预训练模型文件不存在: {pretrained_path}")

        model = cls(ann_sequential_model, float_dtype, device,
                    trainable_entry_synapse, trainable_layers_synapse, trainable_layers_indices)

        # NOTE(review): torch.load unpickles the file; only use checkpoints
        # from trusted sources (consider weights_only=True where the installed
        # torch version supports it).
        checkpoint = torch.load(pretrained_path, map_location=model.device)

        if not isinstance(checkpoint, dict):
            raise ValueError(f"不支持的checkpoint格式: {type(checkpoint)}")

        # Unwrap common training-checkpoint containers; otherwise assume the
        # dict itself is the state dict.
        state_dict = checkpoint
        for container_key in ('state_dict', 'model_state_dict', 'net'):
            if container_key in checkpoint:
                state_dict = checkpoint[container_key]
                break

        loaded_params = model.load_pretrained_weights(state_dict, strict, format_type)
        print(f"从 {pretrained_path} 加载了 {loaded_params} 个参数")

        return model