
# -*- coding: utf-8 -*-
import os
import math
import argparse
import torch
import torch.nn.functional as F
import numpy as np
import cv2
from PIL import Image
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info


# ===================== 通用工具 =====================

def interpolate_safe(x, size, mode, align_corners=False):
    """Resize ``x`` with ``F.interpolate``, passing ``align_corners`` only when allowed.

    ``align_corners`` is forwarded for the "linear", "bilinear", "bicubic" and
    "trilinear" modes; for every other mode (e.g. "nearest") it is dropped so
    that ``F.interpolate`` is called without it.

    Args:
        x: Input tensor handed to ``F.interpolate`` unchanged.
        size: Target spatial size.
        mode: Interpolation mode name.
        align_corners: Only used for the four modes listed above.

    Returns:
        The tensor produced by ``F.interpolate``.
    """
    extra = {}
    if mode in ("linear", "bilinear", "bicubic", "trilinear"):
        extra["align_corners"] = align_corners
    return F.interpolate(x, size=size, mode=mode, **extra)

def find_last_k_visual_blocks(model, k=4):
    """Return the module names of the last ``k`` vision-tower blocks.

    Only direct children of the block list are considered (names of the form
    ``<prefix><index>``); deeper sub-modules that merely end in a digit are
    ignored so they cannot be mistaken for block indices.

    Args:
        model: Any object exposing ``named_modules()`` that yields
            ``(name, module)`` pairs.
        k: Number of trailing blocks to return. Values ``<= 0`` yield an
            empty list.

    Returns:
        Block names sorted by ascending index. If fewer than ``k`` blocks
        exist, all of them are returned. If no block is found at all, the
        legacy guess for a 32-block tower is returned
        (``model.visual.blocks.{32-k} .. model.visual.blocks.31`` for k <= 32).
    """
    if k <= 0:
        # ``idx[-0:]`` would otherwise return *every* block.
        return []

    names = [name for name, _ in model.named_modules()]

    # NOTE(review): the second prefix covers layouts where the vision tower is
    # registered at the top level instead of under ``model.`` -- confirm
    # against the installed transformers version.
    for prefix in ("model.visual.blocks.", "visual.blocks."):
        cut = len(prefix)
        indices = sorted({
            int(name[cut:])
            for name in names
            if name.startswith(prefix) and name[cut:].isdecimal()
        })
        if indices:
            return [f"{prefix}{i}" for i in indices[-k:]]

    # Nothing matched: keep the historical fallback that assumes 32 blocks.
    base = max(0, 32 - k)
    return [f"model.visual.blocks.{i}" for i in range(base, base + k)]

def register_multi_hooks(model, layer_names):
    """Register forward hooks that record activations and output gradients.

    For each module whose name is in ``layer_names`` a forward hook stores a
    detached float32 copy of the module output (first element when the output
    is a tuple/list) in ``acts[name]``. When that output requires grad, a
    tensor hook is attached that stores a detached float32 copy of its
    gradient in ``grads[name]`` during backward. The tensor hook returns the
    gradient it received, so the backward pass itself is not modified and no
    dtype conversion leaks into the graph.

    Args:
        model: Module exposing ``named_modules()``.
        layer_names: Iterable of module names; a single string is treated as
            one name (not as a substring pattern).

    Returns:
        ``(acts, grads, handles)``: two dicts filled lazily by the hooks and
        the list of forward-hook handles (call ``.remove()`` on each when done).

    Raises:
        RuntimeError: If none of the requested names exists in the model.
    """
    # A bare string would otherwise be substring-matched by ``in`` and hook
    # unrelated modules (including the root module, whose name is "").
    wanted = {layer_names} if isinstance(layer_names, str) else set(layer_names)
    acts, grads, handles = {}, {}, []

    def make_hook(name):
        def fwd_hook(module, inp, out):
            y = out[0] if isinstance(out, (tuple, list)) else out
            acts[name] = y.detach().to(torch.float32)
            if not y.requires_grad:
                # e.g. forward under torch.no_grad(): register_hook would raise.
                return

            def bwd_hook(g):
                grads[name] = g.detach().to(torch.float32)  # record only
                return g  # hand back the original gradient (keeps its dtype)

            y.register_hook(bwd_hook)
        return fwd_hook

    for n, m in model.named_modules():
        if n in wanted:
            handles.append(m.register_forward_hook(make_hook(n)))

    if not handles:
        raise RuntimeError(f"未能在模型中找到指定层：{layer_names}")
    return acts, grads, handles

def to_tokens(feat_t, grad_t):
    """Normalise a (feature, gradient) pair to ``[B, T, C]`` token sequences.

    Accepted layouts:
      - 4D ``[B, C, H, W]`` -> ``[B, H*W, C]`` (treated as having no CLS token)
      - 2D ``[T, C]``       -> ``[1, T, C]``
      - 3D ``[B, T, C]``    -> used as is

    For 2D/3D inputs the presence of a CLS token and the grid side are
    inferred purely from the token count ``T``:
      - ``T > 1`` and ``T - 1`` is a perfect square -> CLS present,
        side = sqrt(T - 1);
      - else ``T`` is a perfect square -> no CLS, side = sqrt(T);
      - else the sequence is cropped to the largest square ``s*s <= T``.

    NOTE(review): this is a count heuristic only; it cannot tell a genuine
    CLS token from a non-square grid that happens to have ``s*s + 1`` tokens,
    and it assumes a square grid -- confirm this fits the model in use.

    Args:
        feat_t: Activation tensor (2D, 3D or 4D).
        grad_t: Gradient tensor in the same layout as ``feat_t``.

    Returns:
        ``(tokens_feat, tokens_grad, has_cls, side)``. For 4D input ``side``
        is ``H``.
    """
    if feat_t.dim() == 4:
        B, C, H, W = feat_t.shape
        tokens_f = feat_t.permute(0, 2, 3, 1).contiguous().view(B, H * W, C)
        tokens_g = grad_t.permute(0, 2, 3, 1).contiguous().view(B, H * W, C)
        return tokens_f, tokens_g, False, H  # H == W in the usual case

    if feat_t.dim() == 2:
        feat_t = feat_t.unsqueeze(0)
    if grad_t.dim() == 2:
        grad_t = grad_t.unsqueeze(0)

    _, T, _ = feat_t.shape  # unpacking also rejects anything that is not 3D

    # A lone token is a 1x1 grid, not "CLS plus an empty grid", hence T > 1.
    if T > 1:
        side_with_cls = math.isqrt(T - 1)
        if side_with_cls * side_with_cls == T - 1:
            return feat_t, grad_t, True, side_with_cls

    side = math.isqrt(T)  # exact integer square root, no float rounding
    if side * side == T:
        return feat_t, grad_t, False, side

    keep = side * side
    return feat_t[:, :keep, :], grad_t[:, :keep, :], False, side

def layer_cam_single(feat, grad):
    """Build a single-layer gradient-weighted activation map.

    The map is ``ReLU(sum_c(grad * act))`` per token, reshaped to a square
    grid, smoothed with a 3x3 average pool (stride 1, padding 1) and divided
    by its per-sample maximum (plus 1e-6).

    A leading CLS token reported by ``to_tokens`` is dropped when more than
    one token is present, and the sequence is cropped to the largest square
    number of tokens before reshaping.

    NOTE(review): the reshape assumes the tokens are laid out row by row on a
    square grid -- confirm for the model in use.

    Args:
        feat: Layer activation, in any layout ``to_tokens`` accepts.
        grad: Gradient w.r.t. that activation, same layout.

    Returns:
        Tensor of shape ``[B, 1, side, side]``.
    """
    feat_tok, grad_tok, has_cls, _ = to_tokens(feat, grad)
    if has_cls and feat_tok.shape[1] > 1:
        feat_tok, grad_tok = feat_tok[:, 1:, :], grad_tok[:, 1:, :]

    batch, n_tokens, _ = feat_tok.shape
    side = int(math.sqrt(n_tokens))
    n_keep = side * side
    feat_tok = feat_tok[:, :n_keep, :]
    grad_tok = grad_tok[:, :n_keep, :]

    # Channel-wise product summed per token -> [B, side*side]
    per_token = (grad_tok * feat_tok).sum(dim=-1)
    heat = F.relu(per_token).view(batch, 1, side, side)
    # Light smoothing to suppress isolated speckles.
    heat = F.avg_pool2d(heat, kernel_size=3, stride=1, padding=1)
    # Per-layer normalisation by the spatial maximum.
    peak = heat.amax(dim=(2, 3), keepdim=True)
    return heat / (peak + 1e-6)

def load_image_bgr(local_path):
    """Load an image as a BGR array, trying OpenCV first and PIL second.

    ``cv2.imread`` is attempted first; when it yields ``None`` the file is
    opened with PIL, converted to RGB and then to BGR channel order.

    Args:
        local_path: Filesystem path of the image.

    Returns:
        The image as returned by ``cv2.imread`` / ``cv2.cvtColor`` (BGR order).

    Raises:
        Whatever ``PIL.Image.open`` raises when the fallback cannot read the
        file either.
    """
    img = cv2.imread(local_path, cv2.IMREAD_COLOR)
    if img is None:
        # Context manager so the underlying file handle is closed promptly
        # instead of lingering until garbage collection.
        with Image.open(local_path) as pil:
            rgb = np.array(pil.convert("RGB"))
        img = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    return img


# ===================== “按高注意力区域自动聚类并平均高亮” =====================

def cluster_and_feather(
    cam_up_float: np.ndarray,
    percentile: float = 0.85,
    topk: int = 3,
    feather_px: int = 25,
):
    """Turn an upsampled CAM into a few evenly highlighted, feathered regions.

    Steps:
      1. Min-max normalise the CAM and threshold it at the ``percentile``
         quantile to select high-attention pixels.
      2. Label 8-connected components of the binary mask.
      3. Discard components of area <= 4 px, rank the rest by the sum of CAM
         values inside them and keep the best ``topk``.
      4. Fill each kept component with the mean CAM value of that component.
      5. Gaussian-blur the filled map (sigma = ``max(1, feather_px / 3)``)
         when ``feather_px > 0`` and min-max normalise the result.

    If no component survives step 3, the blurred and re-normalised CAM itself
    is returned instead.

    Args:
        cam_up_float: 2-D CAM array at image resolution.
        percentile: Quantile in [0, 1] passed to ``np.quantile``.
        topk: Number of components to keep; values <= 0 keep none.
        feather_px: Feathering scale in pixels.

    Returns:
        2-D array of the same height/width, min-max normalised.

    Raises:
        ValueError: If ``cam_up_float`` is not two-dimensional.
    """
    if cam_up_float.ndim != 2:
        raise ValueError(
            f"cam_up_float 必须是二维数组，实际形状为 {cam_up_float.shape}"
        )

    def rescale(arr):
        lo, hi = arr.min(), arr.max()
        return (arr - lo) / (hi - lo + 1e-6)

    sigma = max(1, feather_px / 3)
    cam_norm = rescale(cam_up_float)

    thr = np.quantile(cam_norm, percentile)
    mask = (cam_norm >= thr).astype(np.uint8)

    # Returns (num_labels, labels, stats, centroids); label 0 is background.
    num, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
    areas = stats[:, cv2.CC_STAT_AREA]

    # One pass over the image yields the CAM sum of every component, instead
    # of building a boolean mask per component.
    sums = np.bincount(labels.ravel(), weights=cam_norm.ravel(), minlength=num)

    # Rank components by summed CAM; tiny specks (<= 4 px) are noise.
    ranked = sorted(
        ((float(sums[label]), label) for label in range(1, num) if areas[label] > 4),
        reverse=True,
    )
    if not ranked:
        # No salient region: fall back to the feathered CAM as a whole.
        return rescale(cv2.GaussianBlur(cam_norm, (0, 0), sigmaX=sigma, sigmaY=sigma))

    # "Average highlight": each kept region is filled with its own mean.
    region_map = np.zeros(cam_norm.shape, dtype=np.float32)
    for score, label in ranked[:max(0, topk)]:
        region_map[labels == label] = score / float(areas[label])

    # Feathering: Gaussian blur gives a soft transition at region borders.
    if feather_px > 0:
        region_map = cv2.GaussianBlur(region_map, (0, 0), sigmaX=sigma, sigmaY=sigma)

    return rescale(region_map)


# ===================== 主流程 =====================

def run(
    model_id: str,
    image_path: str,
    question: str,
    save_path: str = "0.jpg",
    k_layers: int = 4,
    agg_mode: str = "sum",            # "sum" | "mean" | "last" | "weights"
    # Clustering / feathering parameters:
    cluster_percentile: float = 0.85,  # quantile (0~1) selecting high-attention pixels
    cluster_topk: int = 3,             # number of high-attention components kept
    feather_px: int = 25,              # Gaussian feathering scale in pixels
    max_new_tokens: int = 1024,        # generation budget for the answer
):
    """Answer a visual question and save a clustered attention heatmap.

    Pipeline: load model and processor, generate an answer greedily, re-run
    the model on prompt+answer with hooks on the last ``k_layers`` vision
    blocks, back-propagate the negative log-probability of the answer tokens,
    build one CAM per hooked layer, average them, upsample to the image size,
    post-process with ``cluster_and_feather`` and blend onto the image.

    NOTE(review): the answer span is located by assuming that tokenising
    ``prompt + answer_text`` reproduces the prompt tokens as a prefix -- confirm
    for the tokenizer in use. The per-layer CAM also assumes a square,
    row-major token grid (see ``layer_cam_single``).

    Args:
        model_id: Hugging Face model id or local path.
        image_path: Local image path, or a ``file://`` URI.
        question: Prompt text.
        save_path: Output image path; parent directories are created.
        k_layers: Number of trailing vision blocks to aggregate.
        agg_mode: How per-token log-probs form the objective: "sum", "mean",
            "last" (final answer token only) or "weights" (linear ramp from
            0.5 to 1.5 over the answer).
        cluster_percentile, cluster_topk, feather_px: Forwarded to
            ``cluster_and_feather``.
        max_new_tokens: Upper bound on generated tokens.

    Returns:
        ``(answer_text, save_path)``.

    Raises:
        ValueError: Unknown ``agg_mode`` (checked before any model is loaded).
        RuntimeError: Empty answer span, missing gradients for a hooked layer,
            or failure to write the output image.
    """
    import contextlib  # only needed here; keeps the module import block untouched

    if agg_mode not in ("sum", "mean", "last", "weights"):
        # Fail fast instead of after loading the model and generating.
        raise ValueError(f"未知聚合模式：{agg_mode}")

    # 1) Load model and processor.
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_id, torch_dtype="auto", device_map="auto"
    )
    processor = AutoProcessor.from_pretrained(model_id)

    # Only the vision tower keeps gradients, to save memory.
    # NOTE(review): "visual." covers layouts where the tower sits at the top
    # level rather than under "model." -- confirm for the installed version.
    for n, p in model.named_parameters():
        if not n.startswith(("model.visual", "visual.")):
            p.requires_grad_(False)

    # 2) Build chat messages and preprocess the multimodal input.
    if image_path.startswith("file://"):
        image_uri = image_path
    else:
        image_uri = "file://" + os.path.abspath(image_path)

    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image_uri},
            {"type": "text",  "text": question},
        ],
    }]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, _ = process_vision_info(messages)

    # 3) Generate the answer text (greedy decoding, no gradients).
    gen_inputs = processor(
        text=text, images=image_inputs, videos=None,
        padding=True, return_tensors="pt",
    ).to(model.device)

    with torch.no_grad():
        generated_ids = model.generate(**gen_inputs, max_new_tokens=max_new_tokens, do_sample=False)

    input_ids_len = gen_inputs.input_ids.shape[1]
    answer_ids = generated_ids[0][input_ids_len:]
    answer_text = processor.decode(
        answer_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    ).strip()
    print("模型答案：", answer_text)

    # 4) Differentiable forward input: the answer appended to the templated prompt.
    ft_inputs = processor(
        text=text + answer_text, images=image_inputs, videos=None,
        padding=True, return_tensors="pt",
    ).to(model.device)

    # 5) Hook the last K vision blocks.
    target_layers = find_last_k_visual_blocks(model, k=k_layers)
    print("用于聚合的层：", target_layers)
    acts, grads, handles = register_multi_hooks(model, target_layers)

    try:
        # 6) Objective = negative log-probability of the answer tokens.
        model.train(False)
        first_param = next(model.parameters(), None)
        model_dtype = first_param.dtype if first_param is not None else torch.float32
        use_amp = torch.cuda.is_available() and model_dtype in (torch.float16, torch.bfloat16)
        # Autocast to the model's own half dtype; the default (fp16) would
        # force casts of bf16 weights.
        amp_ctx = (
            torch.autocast(device_type="cuda", dtype=model_dtype)
            if use_amp else contextlib.nullcontext()
        )
        with amp_ctx:
            out = model(**ft_inputs, use_cache=False)

        input_ids = ft_inputs.input_ids
        seq_len = input_ids.shape[1]
        # Logit at position j predicts token j+1, hence the shift by one.
        start_s = max(0, input_ids_len - 1)
        end_s = min(start_s + max(0, seq_len - input_ids_len), seq_len - 1)
        if end_s <= start_s:
            raise RuntimeError("未能正确定位答案段。")

        # Softmax only over the answer span instead of the whole sequence
        # (which includes every image token) to limit memory.
        answer_logits = out.logits[:, start_s:end_s, :]
        answer_labels = input_ids[:, start_s + 1:end_s + 1]
        tok_logp = (
            torch.log_softmax(answer_logits, dim=-1)
            .gather(-1, answer_labels.unsqueeze(-1))
            .squeeze(-1)
        )

        if agg_mode == "sum":
            objective = -tok_logp.sum()
        elif agg_mode == "mean":
            objective = -tok_logp.mean()
        elif agg_mode == "last":
            objective = -tok_logp[:, -1].sum()
        else:  # "weights": later answer tokens count more
            w = torch.linspace(0.5, 1.5, steps=end_s - start_s, device=tok_logp.device)
            objective = -(tok_logp * w).sum()

        model.zero_grad(set_to_none=True)
        objective.backward()
    finally:
        # Always detach the hooks, even when forward/backward fails.
        for h in handles:
            h.remove()

    missing = [n for n in target_layers if n not in acts or n not in grads]
    if missing:
        raise RuntimeError(f"以下层没有捕获到激活或梯度：{missing}")

    # 7) Per-layer CAM, then average across layers -> [B, 1, h, w].
    cams = [layer_cam_single(acts[n], grads[n]) for n in target_layers]
    cam = torch.stack(cams, dim=0).mean(0)
    cam = cam / (cam.amax(dim=(2, 3), keepdim=True) + 1e-6)

    # 8) Upsample to the original image size (bilinear for smoothness).
    local_path = image_uri[len("file://"):]  # strip the prefix only
    orig = load_image_bgr(local_path)
    H_img, W_img = orig.shape[:2]
    cam_up = interpolate_safe(
        cam, size=(H_img, W_img), mode="bilinear", align_corners=False
    )[0, 0].cpu().numpy()
    cam_up = np.clip(cam_up, 0.0, 1.0).astype(np.float32)

    # 9) Connected-component clustering + mean fill + feathering.
    feather_map = cluster_and_feather(
        cam_up_float=cam_up,
        percentile=cluster_percentile,
        topk=cluster_topk,
        feather_px=feather_px,
    )

    # 10) Blend onto the original image and save.
    heat = (255 * feather_map).astype(np.uint8)
    heat_color = cv2.applyColorMap(heat, cv2.COLORMAP_JET)
    overlay = cv2.addWeighted(heat_color, 0.45, orig, 0.55, 0)

    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
    if not cv2.imwrite(save_path, overlay):
        # imwrite signals failure through its return value, not an exception.
        raise RuntimeError(f"保存热力图失败：{save_path}")
    print(f"✅ 已保存聚类-羽化注意力热力图：{save_path}")

    return answer_text, save_path


# ===================== CLI =====================

if __name__ == "__main__":
    # Command-line entry point: every option maps 1:1 onto a keyword of run().
    cli = argparse.ArgumentParser(
        description="Qwen2.5-VL Grad-CAM（多层聚合 + 连通域聚类平均 + 羽化过渡）"
    )

    # (flag, type, remaining add_argument keywords), in help-output order.
    option_table = (
        ("--model_id", str, {"default": "Qwen/Qwen2.5-VL-32B-Instruct"}),
        ("--image_path", str, {"required": True, "help": "本地图像路径（会自动加 file://）"}),
        ("--question", str, {"required": True, "help": "VQA 提示词（Instruct 风格）"}),
        ("--save_path", str, {"default": "vqa_gradcam_qwen2p5vl_cluster.jpg"}),
        ("--k_layers", int, {"default": 4, "help": "聚合的视觉层数（从最后往前数）"}),
        ("--agg_mode", str, {"default": "sum", "choices": ["sum", "mean", "last", "weights"]}),
        # Clustering / feathering options
        ("--cluster_percentile", float, {"default": 0.85, "help": "用于选高注意力像素的分位阈值(0~1)"}),
        ("--cluster_topk", int, {"default": 3, "help": "保留的高注意力连通域数量"}),
        ("--feather_px", int, {"default": 25, "help": "羽化像素尺度（越大过渡越柔）"}),
    )
    for flag, kind, extra in option_table:
        cli.add_argument(flag, type=kind, **extra)

    opts = cli.parse_args()
    run(
        model_id=opts.model_id,
        image_path=opts.image_path,
        question=opts.question,
        save_path=opts.save_path,
        k_layers=opts.k_layers,
        agg_mode=opts.agg_mode,
        cluster_percentile=opts.cluster_percentile,
        cluster_topk=opts.cluster_topk,
        feather_px=opts.feather_px,
    )
