import torch
import numpy as np
import matplotlib.pyplot as plt
import PIL.Image
from torchvision.transforms.functional import resize, to_pil_image
import torch.nn.functional as F
import math
import os
from torch import nn
from diffusers.models.attention_processor import Attention, AttnProcessor2_0


class StoreAttnProcessor2_0(AttnProcessor2_0):
    """``AttnProcessor2_0`` that additionally records cross-attention probabilities.

    The module output is produced by the parent processor and returned
    unchanged. For cross-attention calls (``encoder_hidden_states`` is not
    None) the attention probabilities are recomputed with
    ``attn.get_attention_scores`` and stored, detached, on ``attn.attn_probs``
    so a forward hook can pick them up for visualization.
    """

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, **kwargs):
        # Compute the real output first; everything below is inspection only.
        output = super().__call__(attn, hidden_states,
                                  encoder_hidden_states, attention_mask, **kwargs)

        if encoder_hidden_states is not None:  # only cross-attention is recorded
            try:
                attention_probs = self._compute_attention_probs(
                    attn, hidden_states, encoder_hidden_states, attention_mask,
                    kwargs.get("temb"))
                attn.attn_probs = attention_probs.detach()
            except Exception as e:
                # Best effort: never let visualization break inference, but drop
                # probabilities left over from an earlier call so they are not
                # captured again as if they belonged to this step.
                attn.attn_probs = None
                print(
                    f"[StoreAttnProcessor] Error calculating/storing scores for ID {id(attn)}: {e}")

        # Return the untouched output so model behavior is unaffected.
        return output

    @staticmethod
    def _compute_attention_probs(attn, hidden_states, encoder_hidden_states, attention_mask, temb=None):
        """Recompute cross-attention probabilities for *attn*.

        Applies the same query/key preprocessing as diffusers' ``AttnProcessor``
        (optional spatial/group norm, flattening of 4D ``(B, C, H, W)`` inputs,
        attention-mask preparation, optional cross-norm and q/k norms) so the
        stored scores correspond to those used in the forward pass.

        Returns:
            Tensor shaped ``[batch * heads, query_len, key_len]`` as produced by
            ``attn.get_attention_scores``.
        """
        # Scores are only inspected, so skip building an autograd graph.
        with torch.no_grad():
            spatial_norm = getattr(attn, "spatial_norm", None)
            if spatial_norm is not None:
                hidden_states = spatial_norm(hidden_states, temb)

            if hidden_states.ndim == 4:
                batch_size, channel, height, width = hidden_states.shape
                hidden_states = hidden_states.view(
                    batch_size, channel, height * width).transpose(1, 2)

            # The mask must be expanded/reshaped to [batch * heads, 1, key_len]
            # before it can be added inside get_attention_scores.
            batch_size, key_length, _ = encoder_hidden_states.shape
            attention_mask = attn.prepare_attention_mask(
                attention_mask, key_length, batch_size)

            group_norm = getattr(attn, "group_norm", None)
            if group_norm is not None:
                hidden_states = group_norm(
                    hidden_states.transpose(1, 2)).transpose(1, 2)

            query = attn.to_q(hidden_states)
            if getattr(attn, "norm_cross", None):
                encoder_hidden_states = attn.norm_encoder_hidden_states(
                    encoder_hidden_states)
            key = attn.to_k(encoder_hidden_states)

            query = attn.head_to_batch_dim(query)
            key = attn.head_to_batch_dim(key)

            # Per-head normalization over the last (head_dim) axis, when configured.
            norm_q = getattr(attn, "norm_q", None)
            if norm_q is not None:
                query = norm_q(query)
            norm_k = getattr(attn, "norm_k", None)
            if norm_k is not None:
                key = norm_k(key)

            return attn.get_attention_scores(query, key, attention_mask)


# Module-level state shared by the hook, registration and visualization helpers.
# Map keys are str(id(module)) of the hooked Attention modules.
text_attention_maps = dict()   # scores from modules not flagged as image attention
image_attention_maps = dict()  # scores from modules flagged with _is_image_attn
hook_handles = list()          # handles returned by register_forward_hook
module_id_to_name = dict()     # str(id(module)) -> qualified module name


def attention_hook(module, input, output):
    """Forward hook that files a module's stored attention probabilities.

    Meant to be registered with ``register_forward_hook`` on diffusers
    ``Attention`` modules. Reads ``module.attn_probs`` (populated by the
    storing attention processor), converts it to a detached float32 CPU
    tensor and stores it under ``str(id(module))`` in ``image_attention_maps``
    when the module's ``_is_image_attn`` flag is truthy, otherwise in
    ``text_attention_maps``.

    Args:
        module: Module the hook fired on; anything but ``Attention`` is ignored.
        input: Forward inputs (unused; required by the hook signature).
        output: Forward output (unused).
    """
    if not isinstance(module, Attention):
        return

    attn_scores = getattr(module, "attn_probs", None)
    # Nothing recorded (e.g. self-attention), or something that is not a
    # tensor: skip instead of raising inside the model's forward pass.
    if not isinstance(attn_scores, torch.Tensor):
        return

    module_id = str(id(module))
    if not attn_scores.is_floating_point():
        print(
            f"Warning: Captured attention scores are not float type ({attn_scores.dtype}) from {module.__class__.__name__} (ID: {module_id}). Skipping visualization for this layer.")
        return

    # Move to CPU float32 so later processing is device/precision independent.
    scores = attn_scores.detach().cpu().to(torch.float32)
    is_image_attn = getattr(module, "_is_image_attn", False)
    target_maps = image_attention_maps if is_image_attn else text_attention_maps
    target_maps[module_id] = scores


# register_attention_hooks, clear_hooks_and_maps, visualize_text_and_image_attention, _visualize_single_attention 函数保持不变
def register_attention_hooks(model):
    """Attach ``attention_hook`` to every diffusers ``Attention`` module in *model*.

    ``Attention`` modules found anywhere under a submodule's ``image_attentions``
    ``nn.ModuleList`` are tagged ``_is_image_attn = True``; all other
    ``Attention`` modules get ``False``. ``module_id_to_name`` is rebuilt as
    ``str(id(module)) -> qualified name`` and every hook handle is appended to
    ``hook_handles`` so ``clear_hooks_and_maps`` can remove it later.

    Args:
        model: ``torch.nn.Module`` to instrument.
    """
    global hook_handles, module_id_to_name

    module_id_to_name = {}

    # 1. Collect ids of Attention layers nested under any `image_attentions` list.
    image_attn_module_ids = set()
    for _, parent_module in model.named_modules():
        image_attentions = getattr(parent_module, "image_attentions", None)
        if not isinstance(image_attentions, nn.ModuleList):
            continue
        for image_block in image_attentions:
            for _, sub_module in image_block.named_modules():
                if isinstance(sub_module, Attention):
                    image_attn_module_ids.add(id(sub_module))

    # 2. Hook each Attention module once and tag whether it is image attention.
    hooked_ids = set()
    for name, module in model.named_modules():
        if not isinstance(module, Attention):
            continue
        module_id = id(module)
        if module_id in hooked_ids:
            continue
        module_id_to_name[str(module_id)] = name
        setattr(module, "_is_image_attn", module_id in image_attn_module_ids)
        hook_handles.append(module.register_forward_hook(attention_hook))
        hooked_ids.add(module_id)

    # Report hooks added by this call (hook_handles may hold earlier ones too).
    print(f"Registered {len(hooked_ids)} hooks on Attention modules.")
    print(
        f"Identified {len(image_attn_module_ids)} potential image attention module IDs.")
    print(f"module_id_to_name: {module_id_to_name}")


def clear_hooks_and_maps():
    """Detach every registered forward hook and reset the captured-state globals.

    Each handle in ``hook_handles`` is removed, then the handle list, both
    attention-map dicts and the id-to-name mapping are rebound to fresh,
    empty containers.
    """
    global text_attention_maps, image_attention_maps, hook_handles, module_id_to_name
    for registered_handle in hook_handles:
        registered_handle.remove()
    hook_handles, module_id_to_name = [], {}
    text_attention_maps, image_attention_maps = {}, {}


def visualize_text_and_image_attention(image, height, width, text_layer_ids=None, image_layer_ids=None,
                                       head_idx=0, save_path_prefix="attention_map"):
    """Render captured text and image attention maps for the selected modules.

    Args:
        image: Image to overlay the maps on (passed to ``_visualize_single_attention``).
        height, width: Expected query-grid size, forwarded unchanged.
        text_layer_ids: Module ids to render from ``text_attention_maps``.
            Keys there are ``str(id(module))``; int ids are converted.
        image_layer_ids: Same, for ``image_attention_maps``.
        head_idx: Attention head to display.
        save_path_prefix: Output prefix; files are written as
            ``{prefix}_text_id{name}.png`` and ``{prefix}_image_id{name}.png``.
            Its directory component, if any, is created.
    """
    save_dir = os.path.dirname(save_path_prefix)
    # os.makedirs("") raises, e.g. for the default prefix with no directory.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    _visualize_attention_group(image, height, width, text_attention_maps, text_layer_ids,
                               "文本注意力", "text", head_idx, save_path_prefix)
    _visualize_attention_group(image, height, width, image_attention_maps, image_layer_ids,
                               "图像注意力", "image", head_idx, save_path_prefix)


def _visualize_attention_group(image, height, width, attention_maps, layer_ids,
                               title_label, file_tag, head_idx, save_path_prefix):
    """Render every id of *layer_ids* present in *attention_maps*; others are skipped."""
    if not attention_maps or not layer_ids:
        return
    for layer_id in layer_ids:
        key = str(layer_id)  # maps are keyed by str(id(module))
        if key not in attention_maps:
            continue
        layer_short_name = module_id_to_name.get(key, f"unknown_{key[-6:]}")
        _visualize_single_attention(
            image, height, width, attention_maps[key],
            f"{title_label} ({layer_short_name}, 头: {head_idx})",
            head_idx, f"{save_path_prefix}_{file_tag}_id{layer_short_name}.png"
        )


def _visualize_single_attention(image, height, width, attn_scores, title, head_idx=0, save_path="attention_map.png"):
    """Save a side-by-side figure of *image* and one head's attention heat map.

    The chosen head's scores are averaged over the key axis (one value per
    query position). That vector is reshaped to a 2D grid (see
    ``_infer_map_hw``), bilinearly resized to the image size, min-max
    normalized and overlaid on the image with the ``jet`` colormap.

    Args:
        image: Image to show; ``np.array(image).shape[:2]`` is used as (H, W),
            e.g. a PIL image or an HWC array.
        height, width: Expected query-grid size (presumably the latent
            resolution of the attention layer — TODO confirm with callers).
        attn_scores: Tensor shaped ``[batch * heads, q, k]`` or
            ``[batch, heads, q, k]``.
        title: Title of the overlay panel.
        head_idx: Head to display; values past the end fall back to 0.
        save_path: Destination PNG path.

    Errors are printed with a traceback instead of being raised.
    """
    fig = None
    try:
        scores = attn_scores.detach().cpu().float()
        if scores.dim() == 3:
            # [batch * num_heads, q, k]: the leading axis mixes batch and heads,
            # so head_idx indexes it directly (one head only when batch_size == 1).
            num_heads = scores.shape[0]
            if head_idx >= num_heads:
                print(
                    f"错误: head_idx {head_idx} 超出范围 (共 {num_heads} 个头)。回退到头 0。")
                head_idx = 0
            head_attention = scores[head_idx]  # [q, k]

        elif scores.dim() == 4:
            # [batch, num_heads, q, k]: only the first batch item is shown.
            batch, num_heads = scores.shape[0], scores.shape[1]
            if batch != 1:
                print(
                    f"Warning: 可视化假设 batch_size=1，但得到 batch_size={batch}。将只使用第一个 batch item。")
            if head_idx >= num_heads:
                print(
                    f"错误: head_idx {head_idx} 超出范围 (共 {num_heads} 个头)。回退到头 0。")
                head_idx = 0
            head_attention = scores[0, head_idx]  # [q, k]
        else:
            print(f"错误: 不支持的注意力分数维度 {attn_scores.dim()}。")
            return

        # Average over keys -> one value per query (spatial) position, shape [q].
        processed_attention = head_attention.mean(dim=1)
        map_size = processed_attention.numel()

        map_hw = _infer_map_hw(map_size, height, width)
        if map_hw is None:
            print(
                f"错误: 无法将 map_size {map_size} (seq_len_q) 重塑为 {height}x{width} 或近似正方形。")
            return
        side_len_h, side_len_w = map_hw

        # Upsample the grid to the image resolution ([1, 1, h, w] for interpolate).
        img_h, img_w = np.array(image).shape[:2]
        attention_map = processed_attention.reshape(1, 1, side_len_h, side_len_w)
        attention_map_resized = F.interpolate(
            attention_map, size=(img_h, img_w), mode='bilinear', align_corners=False
        )[0, 0].numpy()

        # Min-max normalize for display; a constant map becomes all zeros.
        map_min, map_max = attention_map_resized.min(), attention_map_resized.max()
        if map_max > map_min:
            normalized_map = (attention_map_resized - map_min) / (map_max - map_min)
        else:
            normalized_map = np.zeros_like(attention_map_resized)

        fig = plt.figure(figsize=(12, 6))
        plt.subplot(1, 2, 1)
        plt.imshow(image)
        plt.title("Input Image")
        plt.axis('off')

        plt.subplot(1, 2, 2)
        plt.imshow(image)
        plt.imshow(normalized_map, cmap='jet', alpha=0.6)  # 'jet' heat map overlay
        plt.title(title)
        plt.axis('off')

        plt.tight_layout()
        plt.savefig(save_path)
        print(f"Attention map saved to {save_path}")

    except Exception as e:
        import traceback
        print(f"可视化注意力图时出错 for '{title}': {e}")
        traceback.print_exc()  # full stack trace for debugging
    finally:
        # Close the figure even when an error occurred, so figures don't leak.
        if fig is not None:
            plt.close(fig)


def _infer_map_hw(map_size, height, width):
    """Return a ``(h, w)`` grid with ``h * w == map_size``, or None if none fits.

    Candidates, in order: ``(height, width)``; the same grid downsampled by
    2, 4, ..., 64 (ceil and floor rounding), which keeps the aspect ratio of
    lower-resolution layers; finally a square grid.
    """
    if map_size == height * width:
        return height, width

    factor = 2
    while factor <= 64:
        ceil_hw = (-(-height // factor), -(-width // factor))
        floor_hw = (height // factor, width // factor)
        for grid_h, grid_w in (ceil_hw, floor_hw):
            if grid_h > 0 and grid_w > 0 and grid_h * grid_w == map_size:
                print(
                    f"Warning: map_size {map_size} != height*width {height*width}. "
                    f"Reshaping to {grid_h}x{grid_w} ({factor}x downsampled).")
                return grid_h, grid_w
        factor *= 2

    # Last resort: a square grid.
    side_len = math.isqrt(map_size)
    if side_len * side_len == map_size:
        print(
            f"Warning: map_size {map_size} != height*width {height*width}. 将注意力图重塑为 {side_len}x{side_len}。")
        return side_len, side_len
    return None
