import argparse
import torch
import torch.nn as nn  # 添加对nn模块的导入
import matplotlib.pyplot as plt
import numpy as np
import os
from pathlib import Path
import random
import seaborn as sns

import models
from util.datasets import build_dataset

# Global registry mapping a capture name to the most recent tensor stored under
# that name. It is written by the forward hooks built in make_hook and by the
# patched attention forward, and read (then cleared entry by entry) by
# plot_distributions.
activation_dict = {}


# Factory for forward hooks that record a module's output under a fixed name.
def make_hook(step_name):
    """Build a forward hook that stores the module output in ``activation_dict``.

    Args:
        step_name: Key under which the captured tensor is stored. A later call
            of the same hook overwrites the earlier entry.

    Returns:
        A callable taking ``(module, input, output)`` positionally. Outputs that
        are not a single ``torch.Tensor`` (for example tuples) are ignored.
    """

    def hook(_module, _inputs, output):
        if not isinstance(output, torch.Tensor):
            return
        # Store the full batch, detached from autograd and moved to the CPU,
        # so it can be turned into a histogram later.
        captured = output.detach().cpu()
        activation_dict[step_name] = captured

    return hook


# Attach a capture hook to every non-container module of the model.
def register_all_module_hooks(model):
    """Register a forward hook on each module of ``model`` to record its output.

    Instances of ``nn.Sequential``, ``nn.ModuleList`` and
    ``models.Spiking_vit_MetaFormer_Spike_SepConv`` are skipped. Every other
    module is captured under ``"<ClassName>_<k>"``, where ``k`` counts the
    modules of that class in ``named_modules()`` order, starting at 1.

    Args:
        model: Module whose ``named_modules()`` are traversed.
    """
    skipped_types = (
        nn.Sequential,
        nn.ModuleList,
        models.Spiking_vit_MetaFormer_Spike_SepConv,
    )
    per_type_count = {}

    for _, submodule in model.named_modules():
        if isinstance(submodule, skipped_types):
            continue

        type_name = type(submodule).__name__
        ordinal = per_type_count.get(type_name, 0) + 1
        per_type_count[type_name] = ordinal

        # Capture name is "<type>_<ordinal>".
        submodule.register_forward_hook(make_hook(f"{type_name}_{ordinal}"))

    print(f"已为模型注册 {sum(per_type_count.values())} 个钩子函数，覆盖 {len(per_type_count)} 种不同类型的模块")
    print(f"模块类型计数: {per_type_count}")


# Register fine-grained hooks inside every MS_Attention_linear layer.
def register_attention_linear_hooks(model):
    """Hook the sub-modules of each ``MS_Attention_linear`` layer in ``model``.

    Layers are numbered from 1 in ``named_modules()`` order and captures are
    stored under ``"attn_layer<k>_<name>"``. For ``q_conv``/``k_conv``/``v_conv``
    and ``proj_conv`` the hook is placed on the last element of the (non-empty)
    sub-sequence; the projection is captured as ``proj_output``. The spike
    modules are hooked directly. Attributes a layer does not have are skipped.

    Args:
        model: Module whose ``named_modules()`` are searched.
    """
    from models import MS_Attention_linear

    conv_stacks = ("q_conv", "k_conv", "v_conv")
    spike_attrs = ("head_spike", "q_spike", "k_spike", "v_spike", "attn_spike")

    def hook_last(owner, attr, capture_name):
        # Hook only the final stage of a non-empty indexable sub-module.
        if hasattr(owner, attr) and len(getattr(owner, attr)) > 0:
            getattr(owner, attr)[-1].register_forward_hook(make_hook(capture_name))

    layer_count = 0
    for _, module in model.named_modules():
        if not isinstance(module, MS_Attention_linear):
            continue

        layer_count += 1
        layer_prefix = f"attn_layer{layer_count}_"

        # Outputs of the q / k / v convolution stacks.
        for attr in conv_stacks:
            hook_last(module, attr, f"{layer_prefix}{attr}")

        # Outputs of the spike functions.
        for attr in spike_attrs:
            if hasattr(module, attr):
                getattr(module, attr).register_forward_hook(make_hook(f"{layer_prefix}{attr}"))

        # Output of the final projection.
        hook_last(module, "proj_conv", f"{layer_prefix}proj_output")

    print(f"已为 {layer_count} 个 MS_Attention_linear 层注册详细钩子")


# Patch MS_Attention_linear.forward so every intermediate attention tensor is recorded.
def modify_ms_attention_linear():
    """Replace ``MS_Attention_linear.forward`` with an instrumented re-implementation.

    The replacement stores each intermediate tensor in ``activation_dict`` under
    ``"attn_details<k>_<step>"``, where ``k`` is a 1-based index given to each
    layer instance the first time its forward runs. The patch is applied to the
    class, so it affects every instance.

    NOTE(review): the replacement re-implements the forward pass (including the
    fixed 4.0 scaling of q and k) instead of wrapping the original method.
    Confirm it still matches ``models.MS_Attention_linear.forward`` whenever the
    model code changes.
    """
    # Resolve project symbols once here rather than on every forward call.
    from models import MS_Attention_linear, fp16_optimized_exp2_softmax

    # Maps id(layer) -> sequential index. The raw id() is only an identity
    # token (a memory address in CPython), so deriving the key from
    # `id(self) % 1000` could make two layers share a prefix and overwrite each
    # other's captures, and it differed from run to run. id() is still safe as
    # a dict key here because the layers stay alive for the whole run.
    layer_indices = {}

    def new_forward(self, x):
        B, C, H, W = x.shape
        N = H * W
        C_v = int(C * self.lamda_ratio)

        layer_id = layer_indices.setdefault(id(self), len(layer_indices) + 1)
        layer_prefix = f"attn_details{layer_id}_"

        def record(step, tensor):
            # Store a detached CPU tensor for this layer and step.
            activation_dict[f"{layer_prefix}{step}"] = tensor.detach().cpu()

        def split_heads(t, channels):
            # (B, channels, H, W) -> (B, num_heads, N, channels // num_heads)
            return (
                t.flatten(2)
                .transpose(-1, -2)
                .reshape(B, N, self.num_heads, channels // self.num_heads)
                .permute(0, 2, 1, 3)
                .contiguous()
            )

        # Input spike, then the q / k / v convolutions.
        x = self.head_spike(x)
        q = self.q_conv(x)
        k = self.k_conv(x)
        v = self.v_conv(x)
        record("q_conv_out", q)
        record("k_conv_out", k)
        record("v_conv_out", v)

        # Spike functions are applied to q and k only.
        q = self.q_spike(q)
        k = self.k_spike(k)
        record("q_spike_out", q)
        record("k_spike_out", k)

        # Split into heads; v uses C_v channels, q and k use C.
        q = split_heads(q, C)
        k = split_heads(k, C)
        v = split_heads(v, C_v)
        record("q_reshaped", q)
        record("k_reshaped", k)
        record("v_reshaped", v)

        # Fixed scaling of q and k (see NOTE(review) in the docstring).
        q_scaled = q * 4.0
        k_scaled = k * 4.0
        record("q_scaled", q_scaled)
        record("k_scaled", k_scaled)

        # Raw attention scores and their softmax.
        attn = q_scaled @ k_scaled.transpose(-2, -1)
        record("attn_raw", attn)

        x = fp16_optimized_exp2_softmax(attn)
        record("attn_softmax", x)

        # Apply the attention weights to the values.
        x = x @ v
        record("attn_output", x)

        # Merge the heads back to (B, C_v, N) before the attention spike.
        x = x.transpose(2, 3).reshape(B, C_v, N).contiguous()
        record("reshaped_before_spike", x)

        x = self.attn_spike(x)
        record("after_attn_spike", x)

        # Restore the spatial layout and project back to C channels.
        x = x.reshape(B, C_v, H, W)
        record("reshaped_spatial", x)

        x = self.proj_conv(x).reshape(B, C, H, W)
        record("final_output", x)

        return x

    # Install the instrumented forward on the class.
    MS_Attention_linear.forward = new_forward
    print("已修改 MS_Attention_linear.forward 以捕获更多中间步骤")


# Register detailed hooks inside every SepConv_Spike layer.
def register_sepconv_hooks(model):
    """Hook the spike and convolution parts of each ``SepConv_Spike`` layer.

    Layers are numbered from 1 in ``named_modules()`` order and captures are
    stored under ``"sepconv<k>_<attr>"``. ``spike1``..``spike3`` are hooked when
    present; ``pwconv1``, ``dwconv`` and ``pwconv2`` are hooked only when they
    are ``nn.Sequential`` instances.

    Args:
        model: Module whose ``named_modules()`` are searched.
    """
    from models import SepConv_Spike

    spike_attrs = ("spike1", "spike2", "spike3")
    conv_attrs = ("pwconv1", "dwconv", "pwconv2")

    sepconv_layers = [m for _, m in model.named_modules() if isinstance(m, SepConv_Spike)]

    for index, layer in enumerate(sepconv_layers, start=1):
        layer_prefix = f"sepconv{index}_"

        # Spike function outputs.
        for attr in spike_attrs:
            if hasattr(layer, attr):
                getattr(layer, attr).register_forward_hook(make_hook(f"{layer_prefix}{attr}"))

        # Convolution outputs; only sub-modules that are nn.Sequential are hooked.
        for attr in conv_attrs:
            if hasattr(layer, attr) and isinstance(getattr(layer, attr), nn.Sequential):
                getattr(layer, attr).register_forward_hook(make_hook(f"{layer_prefix}{attr}"))

    layer_count = len(sepconv_layers)
    print(f"已为 {layer_count} 个 SepConv_Spike 层注册详细钩子")


# Register hooks inside every MS_ConvBlock_spike_SepConv layer.
def register_convblock_hooks(model):
    """Hook the spike, conv and batch-norm parts of each conv block.

    Blocks are numbered from 1 in ``named_modules()`` order and captures are
    stored under ``"convblock<k>_<attr>"`` for whichever of ``spike1``,
    ``spike2``, ``conv1``, ``conv2``, ``bn1`` and ``bn2`` the block has.

    Args:
        model: Module whose ``named_modules()`` are searched.
    """
    from models import MS_ConvBlock_spike_SepConv

    hooked_attrs = ("spike1", "spike2", "conv1", "conv2", "bn1", "bn2")

    layer_count = 0
    for _, block in model.named_modules():
        if not isinstance(block, MS_ConvBlock_spike_SepConv):
            continue

        layer_count += 1
        for attr in hooked_attrs:
            if not hasattr(block, attr):
                continue
            capture_name = f"convblock{layer_count}_{attr}"
            getattr(block, attr).register_forward_hook(make_hook(capture_name))

    print(f"已为 {layer_count} 个 MS_ConvBlock_spike_SepConv 层注册详细钩子")


# Register hooks inside every MS_MLP layer.
def register_mlp_hooks(model):
    """Hook the spike, conv and batch-norm stages of each ``MS_MLP`` layer.

    Layers are numbered from 1 in ``named_modules()`` order and captures are
    stored under ``"mlp<k>_<attr>"`` for whichever of the ``fc1_*`` / ``fc2_*``
    ``spike``, ``conv`` and ``bn`` attributes the layer has.

    Args:
        model: Module whose ``named_modules()`` are searched.
    """
    from models import MS_MLP

    # Same order as the stages are listed per fully-connected branch.
    component_names = [
        f"{branch}_{stage}"
        for branch in ("fc1", "fc2")
        for stage in ("spike", "conv", "bn")
    ]

    mlp_layers = (m for _, m in model.named_modules() if isinstance(m, MS_MLP))

    layer_count = 0
    for layer_count, mlp in enumerate(mlp_layers, start=1):
        layer_prefix = f"mlp{layer_count}_"
        for component in component_names:
            if hasattr(mlp, component):
                getattr(mlp, component).register_forward_hook(make_hook(f"{layer_prefix}{component}"))

    print(f"已为 {layer_count} 个 MS_MLP 层注册详细钩子")


# Register hooks inside every MS_DownSampling layer.
def register_downsampling_hooks(model):
    """Hook the encoder parts of each ``MS_DownSampling`` layer.

    Layers are numbered from 1 in ``named_modules()`` order and captures are
    stored under ``"downsample<k>_<attr>"`` for whichever of ``encode_conv``,
    ``encode_bn`` and ``encode_spike`` the layer has.

    Args:
        model: Module whose ``named_modules()`` are searched.
    """
    from models import MS_DownSampling

    layer_count = 0
    for _, candidate in model.named_modules():
        if not isinstance(candidate, MS_DownSampling):
            continue

        layer_count += 1
        for part in ("encode_conv", "encode_bn", "encode_spike"):
            if hasattr(candidate, part):
                target = getattr(candidate, part)
                target.register_forward_hook(make_hook(f"downsample{layer_count}_{part}"))

    print(f"已为 {layer_count} 个 MS_DownSampling 层注册详细钩子")


# Plot and save one distribution figure per captured activation.
# Categories used to sort figures into sub-directories. "other" must stay last:
# it is the fallback used when no other category name occurs in a key.
_DISTRIBUTION_MODULE_TYPES = (
    "attn_layer", "attn_details", "sepconv", "convblock", "mlp", "downsample",
    "MultiSpike", "Conv2d", "BatchNorm2d", "other",
)

# Tensors with more elements than this are randomly subsampled before plotting.
_MAX_PLOT_POINTS = 10000000


def _classify_activation_key(key):
    """Return the first category whose name occurs in ``key``, else ``"other"``."""
    for type_prefix in _DISTRIBUTION_MODULE_TYPES[:-1]:
        if type_prefix in key:
            return type_prefix
    return "other"


def _has_activation_data(data):
    """Return True when ``data`` is a captured tensor with at least one element."""
    return data is not None and data.numel() != 0


def _percent_of(count, total):
    """Return ``count`` as a percentage of ``total``; 0.0 when ``total`` is 0."""
    return count / total * 100 if total else 0.0


def plot_distributions(output_dir):
    """Write a histogram + statistics figure for every entry of ``activation_dict``.

    For each key (in sorted order) whose value is a non-empty tensor, a PNG is
    saved to ``<output_dir>/<category>/<key>_distribution.png`` and a line is
    appended to ``<output_dir>/distribution_index.txt``. The index also holds
    per-category counts and a closing summary. After an entry is plotted it is
    set to ``None`` in ``activation_dict`` to release the tensor.

    Args:
        output_dir: Directory that receives the category sub-directories and
            the index file; created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Process keys in sorted order so the index is stable.
    sorted_keys = sorted(activation_dict.keys())

    # One sub-directory per category.
    for module_type in _DISTRIBUTION_MODULE_TYPES:
        os.makedirs(os.path.join(output_dir, module_type), exist_ok=True)

    # Count the plottable entries per category up front for the index header.
    type_counters = {t: 0 for t in _DISTRIBUTION_MODULE_TYPES}
    for key in sorted_keys:
        if _has_activation_data(activation_dict[key]):
            type_counters[_classify_activation_key(key)] += 1

    index_file = os.path.join(output_dir, "distribution_index.txt")
    with open(index_file, 'w', encoding='utf-8') as f:
        f.write("分布可视化索引\n")
        f.write("=" * 50 + "\n\n")
        f.write("按模块类型分类:\n\n")
        for module_type in _DISTRIBUTION_MODULE_TYPES:
            f.write(f"{module_type}: {type_counters[module_type]}个分布\n")
        f.write("\n完整分布列表:\n\n")

    # Summary counters.
    total_distributions = 0
    zero_heavy_distributions = 0
    negative_distributions = 0

    for key in sorted_keys:
        data = activation_dict[key]
        if not _has_activation_data(data):
            print(f"跳过 {key} - 无数据")
            continue

        total_distributions += 1
        module_type = _classify_activation_key(key)

        # File-system friendly name.
        filename = key.replace('.', '_').replace('/', '_')

        plt.figure(figsize=(12, 8))

        # Upper panel: histogram.
        plt.subplot(2, 1, 1)

        # Flatten for the histogram. Tensor.numpy() rejects dtypes such as
        # bfloat16, so anything that is not already float32/float64 is
        # converted to float32 first.
        values = data.detach().flatten()
        if values.dtype not in (torch.float32, torch.float64):
            values = values.to(torch.float32)
        flat_data = values.numpy()

        # Distribution traits are computed on the full tensor, before sampling.
        if np.sum(flat_data == 0) / len(flat_data) > 0.5:
            zero_heavy_distributions += 1

        if np.min(flat_data) < 0:
            negative_distributions += 1

        # Subsample very large tensors to keep plotting tractable.
        if len(flat_data) > _MAX_PLOT_POINTS:
            print(f"采样 {key} 数据用于可视化（非常大的tensor）")
            indices = np.random.choice(len(flat_data), size=_MAX_PLOT_POINTS, replace=False)
            flat_data = flat_data[indices]

        sns.histplot(flat_data, bins=100, kde=True)
        plt.title(f'单批次分布 - {key}')
        plt.xlabel('值')
        plt.ylabel('频率')

        # Lower panel: text-only statistics.
        plt.subplot(2, 1, 2)
        plt.axis('off')

        mean_val = np.mean(flat_data)
        median_val = np.median(flat_data)
        min_val = np.min(flat_data)
        max_val = np.max(flat_data)
        std_val = np.std(flat_data)

        zero_count = np.sum(flat_data == 0)
        zero_percent = (zero_count / len(flat_data)) * 100

        nan_count = np.sum(np.isnan(flat_data))
        inf_count = np.sum(np.isinf(flat_data))

        percentiles = [1, 5, 25, 50, 75, 95, 99]
        percentile_values = np.percentile(flat_data, percentiles)

        stats_text = f"""
        统计信息 - {key}:

        形状: {data.shape}
        均值: {mean_val:.6f}
        中位数: {median_val:.6f}
        最小值: {min_val:.6f}
        最大值: {max_val:.6f}
        标准差: {std_val:.6f}

        零值数量: {zero_count} ({zero_percent:.2f}%)
        NaN值: {nan_count}
        Inf值: {inf_count}

        百分位数:
        """

        for p, val in zip(percentiles, percentile_values):
            stats_text += f"    {p}%: {val:.6f}\n"

        plt.text(0.1, 0.5, stats_text, fontsize=12, family='monospace')

        rel_path = os.path.join(module_type, f"{filename}_distribution.png")
        save_path = os.path.join(output_dir, rel_path)
        plt.tight_layout()
        plt.savefig(save_path, dpi=300)
        plt.close()

        # Append per figure so a partial index survives an interrupted run.
        with open(index_file, 'a', encoding='utf-8') as f:
            f.write(f"{key}: {rel_path}\n")

        print(f"保存分布图: {key} -> {module_type}目录")

        # Drop the tensor to free memory.
        activation_dict[key] = None

    # Percentages are 0.0 when nothing was plotted (avoids division by zero).
    zero_heavy_pct = _percent_of(zero_heavy_distributions, total_distributions)
    negative_pct = _percent_of(negative_distributions, total_distributions)

    with open(index_file, 'a', encoding='utf-8') as f:
        f.write("\n\n分布分析摘要\n")
        f.write("=" * 50 + "\n\n")
        f.write(f"总分布数量: {total_distributions}\n")
        f.write(
            f"零值占比超过50%的分布数量: {zero_heavy_distributions} ({zero_heavy_pct:.2f}%)\n")
        f.write(
            f"包含负值的分布数量: {negative_distributions} ({negative_pct:.2f}%)\n")

    print(f"\n分布统计摘要:")
    print(f"总分布数量: {total_distributions}")
    print(
        f"零值占比超过50%的分布: {zero_heavy_distributions} ({zero_heavy_pct:.2f}%)")
    print(f"包含负值的分布: {negative_distributions} ({negative_pct:.2f}%)")


def get_args(argv=None):
    """Parse the command-line options of the distribution visualizer.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding the parsed options.

    Raises:
        SystemExit: Raised by argparse on invalid arguments, when the required
            ``--resume`` is missing, or after printing ``-h/--help``.
    """
    # The parser is parsed directly (it is not used as a parent parser), so the
    # automatic -h/--help option is left enabled.
    parser = argparse.ArgumentParser('分布可视化脚本')

    # Model options
    parser.add_argument('--model', default='Efficient_Spiking_Transformer_l', type=str, metavar='MODEL',
                        help='要训练的模型名称')
    parser.add_argument('--model_mode', default='ms', type=str, help='要训练的模型模式')
    parser.add_argument('--input_size', default=224, type=int, help='输入图像大小')
    parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT', help='Drop path率')

    # Dataset options
    parser.add_argument('--data_path', default='/dev/shm/imagenet-zdh/ImageNet-1K', type=str, help='数据集路径')
    parser.add_argument('--nb_classes', default=1000, type=int, help='类别数量')
    parser.add_argument('--batch_size', default=1, type=int, help='批次大小 - 仅处理一个批次')
    parser.add_argument('--num_workers', default=4, type=int)
    parser.add_argument('--pin_mem', action='store_true', help='在DataLoader中锁定CPU内存')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    # Output options
    parser.add_argument('--output_dir', default='./single_batch_distribution', help='保存输出的路径')

    # Checkpoint options (--resume is mandatory)
    parser.add_argument('--resume', default='', help='从检查点恢复', required=True)

    # Device options
    parser.add_argument('--device', default='cuda', help='训练/测试使用的设备')
    parser.add_argument('--seed', default=0, type=int)

    # Distributed options
    parser.add_argument('--world_size', default=1, type=int, help='分布式进程数量')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://', help='设置分布式训练的url')

    # Time-step option
    parser.add_argument('--time_steps', default=1, type=int)

    return parser.parse_args(argv)


def main():
    """Capture and plot activation distributions for a single validation batch.

    Steps: parse arguments, seed the RNGs, build the validation dataset and
    loader, create the model and load the checkpoint, install all capture hooks,
    run one forward pass under ``torch.no_grad()`` and hand the captured tensors
    to ``plot_distributions``. Errors raised while fetching the batch, running
    the model or plotting are printed with a traceback instead of propagating.
    """
    import traceback

    args = get_args()

    print("\n" + "=" * 50)
    print("单批次分布可视化工具 - Efficient_Spiking_Transformer_l")
    print("=" * 50 + "\n")

    print("工作目录: {}".format(os.path.dirname(os.path.realpath(__file__))))
    print("参数:\n{}".format(args).replace(", ", ",\n"))

    device = torch.device(args.device)

    # Fix the random seeds for reproducibility.
    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Fall back to the CPU when CUDA was requested but is unavailable. Checking
    # device.type also covers indexed devices such as 'cuda:0', which a plain
    # string comparison with 'cuda' would miss.
    if device.type == 'cuda' and not torch.cuda.is_available():
        print("警告: 要求使用CUDA，但CUDA不可用。将切换到CPU。")
        device = torch.device('cpu')
        args.device = 'cpu'

    # Create the output directory.
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
        print(f"将保存分布图到: {os.path.abspath(args.output_dir)}")

    # Build the validation dataset.
    print("构建数据集...")
    dataset_val = build_dataset(is_train=False, args=args)
    sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    # Cap the batch size at 4 to limit memory use while capturing activations.
    actual_batch_size = min(4, args.batch_size)
    print(f"使用批次大小 {actual_batch_size} 进行可视化 (原始: {args.batch_size})")

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val,
        sampler=sampler_val,
        batch_size=actual_batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False,
    )

    # Create the model.
    print(f"创建模型: {args.model}")
    model = models.__dict__[args.model]()
    model.T = args.time_steps

    # Load the checkpoint.
    if args.resume:
        if os.path.isfile(args.resume):
            print(f"加载检查点: {args.resume}")
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            checkpoint = torch.load(args.resume, map_location='cpu')

            # Accept the common checkpoint layouts.
            if 'model' in checkpoint:
                checkpoint_model = checkpoint['model']
            elif 'state_dict' in checkpoint:
                checkpoint_model = checkpoint['state_dict']
            else:
                checkpoint_model = checkpoint

            # strict=False: missing/unexpected keys are reported in msg, not raised.
            msg = model.load_state_dict(checkpoint_model, strict=False)
            print(f"检查点加载完成: {msg}")
        else:
            print(f"未找到检查点: {args.resume}")
            return

    model.to(device)
    model.eval()

    # Patch MS_Attention_linear.forward to capture the extra attention steps.
    modify_ms_attention_linear()

    # Register the hooks that capture the distributions.
    register_all_module_hooks(model)
    register_attention_linear_hooks(model)
    register_sepconv_hooks(model)
    register_convblock_hooks(model)
    register_mlp_hooks(model)
    register_downsampling_hooks(model)

    # Process a single batch.
    try:
        print("获取单个批次的图像...")
        data_iter = iter(data_loader_val)
        images, _ = next(data_iter)

        images = images.to(device)

        print(f"执行前向传播分析单个批次（{len(images)}张图像）...")
        with torch.no_grad():
            # Release cached CUDA memory before the forward pass.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            print("开始模型推理...")
            model(images)
            print("模型推理完成，已捕获所有中间激活值")

            print(f"总共捕获了 {len(activation_dict)} 个不同的激活分布")

        print("前向传播完成。正在绘制分布图...")
        plot_distributions(args.output_dir)

        print(f"分布分析完成! 所有图表已保存到 {args.output_dir} 目录")

    except RuntimeError as e:
        if "CUDA out of memory" in str(e) or "CUBLAS_STATUS_ALLOC_FAILED" in str(e):
            print("CUDA内存不足！尝试减小批次大小并重新运行")
            print(f"错误详情: {e}")
        else:
            print(f"运行时错误: {e}")
        traceback.print_exc()
    except Exception as e:
        # Top-level boundary of the script: report the failure with its traceback.
        print(f"处理过程中发生错误: {e}")
        traceback.print_exc()


# Run the visualizer only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()