#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Gaze-driven sparse cropping for Qwen2.5-VL with:
- processor size enforcement (prevent upscaling)
- two-scale inputs (global low-res + ROI high-res)
- segmented timing & memory (prep/prefill/gen; allocated & reserved)
"""

import os, time, math, csv, argparse
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoProcessor, Qwen2_5_VLForConditionalGeneration
from tqdm import tqdm
# ==== 你已有的数据集类（保持不变，确保同目录可 import）====
from voilagaze_dataset import VoilaDataset

# Seed the CPU RNG and every CUDA RNG with 0 at import time.
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
# Turn off cuDNN benchmark mode.
torch.backends.cudnn.benchmark = False
# torch.use_deterministic_algorithms(True)  # may be slightly slower


# Hugging Face model id loaded by load_model_and_processor().
MODEL_ID   = "Qwen/Qwen2.5-VL-3B-Instruct"
IMAGE_SIZE = 224
DTYPE      = torch.bfloat16  # bf16 is recommended on A100/H100

# Per-channel mean/std used to undo VOILA's Flamingo-style image normalization.
# Shaped (3, 1, 1) so they broadcast over a channel-first (3, H, W) array.
FLAMINGO_MEAN = np.array([0.481, 0.458, 0.408]).reshape(3,1,1)
FLAMINGO_STD  = np.array([0.269, 0.261, 0.276]).reshape(3,1,1)
def tensor_to_pil(img_t3hw):
    """Turn a normalized channel-first image tensor back into an 8-bit PIL image.

    The tensor is detached, moved to CPU, de-normalized with
    FLAMINGO_STD / FLAMINGO_MEAN (both shaped (3, 1, 1)), scaled to 0..255,
    clipped, cast to uint8 and transposed to channel-last before being
    wrapped by ``Image.fromarray``.
    """
    chw = img_t3hw.detach().cpu().numpy()
    denormalized = chw * FLAMINGO_STD + FLAMINGO_MEAN
    scaled = np.clip(denormalized * 255.0, 0, 255)
    hwc_u8 = scaled.astype(np.uint8).transpose(1, 2, 0)
    return Image.fromarray(hwc_u8)

# -------------------------
# 打印/测量工具
# -------------------------
def decode_generated(tokenizer, out_ids, L_in=None):
    """Decode the first sequence of ``out_ids`` to stripped text.

    When ``L_in`` is given and the sequence is longer than 4 tokens, the first
    ``L_in`` positions (the prompt) are cut off before decoding; otherwise the
    whole sequence is decoded. Special tokens are always skipped.
    """
    decode_everything = L_in is None or out_ids.shape[1] <= 4
    ids = out_ids if decode_everything else out_ids[:, L_in:]
    texts = tokenizer.batch_decode(ids, skip_special_tokens=True)
    return texts[0].strip()

def measure_section(fn):
    """Run ``fn()`` and report its wall-clock time and peak CUDA memory.

    Args:
        fn: Zero-argument callable to measure.

    Returns:
        ``(out, dt, mem_alloc, mem_resv)`` where ``out`` is ``fn()``'s result,
        ``dt`` is the elapsed time in seconds, and ``mem_alloc`` / ``mem_resv``
        are the peak allocated / reserved CUDA memory in MiB observed while
        ``fn`` ran. Both memory figures are 0.0 when CUDA is unavailable.
    """
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        # Drain pending kernels so earlier work is not billed to this section,
        # then restart the peak counters.
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()

    # perf_counter is monotonic and high resolution, unlike time.time().
    t0 = time.perf_counter()
    out = fn()
    if use_cuda:
        # Wait for asynchronously launched kernels before stopping the clock.
        torch.cuda.synchronize()
    dt = time.perf_counter() - t0

    if use_cuda:
        mem_alloc = torch.cuda.max_memory_allocated() / 1024**2
        mem_resv = torch.cuda.max_memory_reserved() / 1024**2
    else:
        mem_alloc = mem_resv = 0.0
    return out, dt, mem_alloc, mem_resv

# ========= FLOPs 估算（解析式） =========
def _vision_token_count(pack):
    if "pixel_values" not in pack:
        return 0
    pv = pack["pixel_values"]
    if isinstance(pv, (list, tuple)):
        return sum(int(t.shape[0]) for t in pv)
    elif isinstance(pv, torch.Tensor):
        return int(pv.shape[0])
    return 0

def _text_token_count_total(pack):
    # 用 input_ids 的总长度（含 padding），更贴近实际算子计算量
    return int(pack["input_ids"].shape[1]) if "input_ids" in pack else 0


def _model_dims(model):
    cfg = model.config
    d   = int(getattr(cfg, "hidden_size"))
    L   = int(getattr(cfg, "num_hidden_layers"))
    nh  = int(getattr(cfg, "num_attention_heads"))
    dh  = d // nh
    dff = int(getattr(cfg, "intermediate_size", 4*d))
    return d, L, nh, dh, dff

def _flops_prefill(L_ctx, d, L, nh, dh, dff):
    """
    近似计数（乘加=2 FLOPs）：
    - Attn(QK^T+AV):   ~ 4 * nh * L_ctx^2 * dh
    - Proj(Q,K,V,O):  ~ 4 * L_ctx * d^2
    - FFN:            ~ 2 * L_ctx * d * dff
    """
    attn = 4.0 * nh * (L_ctx ** 2) * dh
    proj = 4.0 * L_ctx * (d ** 2)
    ffn  = 2.0 * L_ctx * d * dff
    return L * (attn + proj + ffn), {"attn": L*attn, "proj": L*proj, "ffn": L*ffn}

def _flops_decode_step(L_ctx, d, L, nh, dh, dff):
    """
    单步解码（KV 已缓存）：
    - Attn:           ~ 4 * nh * L_ctx * dh
    - Proj(Q,K,V,O):  ~ 4 * d^2
    - FFN:            ~ 2 * d * dff
    """
    attn = 4.0 * nh * L_ctx * dh
    proj = 4.0 * (d ** 2)
    ffn  = 2.0 * d * dff
    return L * (attn + proj + ffn), {"attn": L*attn, "proj": L*proj, "ffn": L*ffn}

def _human_flops(x):
    if x >= 1e15: return f"{x/1e15:.2f} PFLOPs"
    if x >= 1e12: return f"{x/1e12:.2f} TFLOPs"
    if x >= 1e9:  return f"{x/1e9:.2f} GFLOPs"
    if x >= 1e6:  return f"{x/1e6:.2f} MFLOPs"
    return f"{x:.0f} FLOPs"

def compute_flops_report(tag, model, pack, out_ids=None, t_gen=None, verbose=True):
    """Estimate prefill and decode FLOPs for one request.

    The context length is the padded ``input_ids`` length of ``pack``. Decode
    cost is the single-step cost at that context length multiplied by the
    number of newly generated tokens.

    Args:
        tag: Label for the report; currently not used.
        model: Model whose ``config`` supplies the dimensions.
        pack: Processor output containing ``input_ids``.
        out_ids: Generated ids (prompt + new tokens); decode figures stay 0
            when this is None.
        t_gen: Generation time in seconds, used for tokens/sec.
        verbose: Currently not used.

    Returns:
        dict with keys ``L_ctx``, ``prefill_flops``, ``decode_step_flops``,
        ``decode_total_flops``, ``total_flops`` and ``tokens_per_sec``
        (0.0 when it cannot be computed).
    """
    d, n_layers, nh, dh, dff = _model_dims(model)
    ctx_len = _text_token_count_total(pack)
    prefill_flops, _ = _flops_prefill(ctx_len, d, n_layers, nh, dh, dff)

    step_flops = 0.0
    decode_flops = 0.0
    tokens_per_sec = 0.0
    if out_ids is not None and "input_ids" in pack:
        new_tokens = int(out_ids.shape[1] - pack["input_ids"].shape[1])
        step_flops, _ = _flops_decode_step(ctx_len, d, n_layers, nh, dh, dff)
        decode_flops = new_tokens * step_flops
        if t_gen and new_tokens > 0:
            tokens_per_sec = new_tokens / t_gen

    return {
        "L_ctx":              ctx_len,
        "prefill_flops":      prefill_flops,
        "decode_step_flops":  step_flops,
        "decode_total_flops": decode_flops,
        "total_flops":        decode_flops + prefill_flops,
        "tokens_per_sec":     tokens_per_sec,
    }


def empty_and_gc():
    """Collect Python garbage, then release cached CUDA memory and reset peaks.

    Garbage collection runs first so tensors kept alive only by reference
    cycles are freed before the CUDA cache is emptied. The CUDA steps are
    skipped when CUDA is unavailable, so the helper is safe on CPU-only hosts.
    """
    import gc

    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    if hasattr(torch.cuda, "reset_peak_memory_stats"):
        torch.cuda.reset_peak_memory_stats()


def total_context_tokens(pack):
    """Return ``(L_text, L_vis)`` for a processor output.

    ``L_text`` is ``input_ids.shape[1]`` (0 if absent). ``L_vis`` is the
    leading dimension of ``pixel_values``, summed over elements when it is a
    list/tuple (0 if absent or of another type).
    """
    n_text = pack["input_ids"].shape[1] if "input_ids" in pack else 0

    pixel_values = pack["pixel_values"] if "pixel_values" in pack else None
    if isinstance(pixel_values, torch.Tensor):
        n_vis = int(pixel_values.shape[0])
    elif isinstance(pixel_values, (list, tuple)):
        n_vis = sum(int(item.shape[0]) for item in pixel_values)
    else:
        n_vis = 0
    return n_text, n_vis


# -------------------------
# 处理器分辨率强制：避免上采样
# -------------------------
from contextlib import contextmanager

@contextmanager
def enforce_image_budget(proc, shortest_edge=None, max_pixels=None):
    """Temporarily override size limits on ``proc.image_processor``.

    Inside the ``with`` block:
    - if ``shortest_edge`` is given, ``size`` is set to
      ``{"shortest_edge": int(shortest_edge)}`` (falling back to the plain int
      if that assignment raises), and ``do_resize`` / ``keep_ratio`` are set
      to True when those attributes exist;
    - if ``max_pixels`` is given and the attribute exists, it is set to
      ``int(max_pixels)``.

    On exit, every tracked attribute that existed beforehand is restored on a
    best-effort basis. A processor without ``image_processor`` is left alone.
    """
    ip = getattr(proc, "image_processor", None)
    if ip is None:
        yield
        return

    tracked = ("size", "max_pixels", "do_resize", "resample",
               "keep_ratio", "crop_size", "image_grid_pinpoints")
    saved = {attr: getattr(ip, attr) for attr in tracked if hasattr(ip, attr)}

    try:
        if shortest_edge is not None:
            # Try the dict form first, then the bare int; give up silently if
            # neither can be assigned.
            candidates = (
                lambda: {"shortest_edge": int(shortest_edge)},
                lambda: int(shortest_edge),
            )
            for make_size in candidates:
                try:
                    ip.size = make_size()
                except Exception:
                    continue
                break
            for flag in ("do_resize", "keep_ratio"):
                if hasattr(ip, flag):
                    setattr(ip, flag, True)
        if max_pixels is not None and hasattr(ip, "max_pixels"):
            ip.max_pixels = int(max_pixels)
        yield
    finally:
        # Put back whatever was there before, ignoring attributes that refuse
        # to be written.
        for attr, previous in saved.items():
            try:
                setattr(ip, attr, previous)
            except Exception:
                pass

# -------------------------
# 构造 Qwen 输入（单图 / 双图）
# -------------------------
def build_mm_inputs_single(proc, pil_img, prompt, device):
    """Build processor inputs for one image plus a text prompt.

    A single user turn (image placeholder followed by the prompt) is rendered
    through the processor's chat template with a generation prompt and vision
    ids, then encoded together with ``pil_img`` as PyTorch tensors and moved
    to ``device``.
    """
    user_content = [
        {"type": "image"},
        {"type": "text", "text": f"{prompt}"},
    ]
    conversation = [{"role": "user", "content": user_content}]
    rendered = proc.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
        add_vision_id=True,
    )
    batch = proc(text=[rendered], images=[pil_img], return_tensors="pt")
    return batch.to(device)

def build_mm_inputs_two(proc, pil_img_global, pil_img_roi, prompt, device):
    """Build processor inputs for two images (global view, then ROI) plus text.

    A single user turn holding two image placeholders followed by the prompt
    is rendered through the chat template with a generation prompt and vision
    ids. Both images are encoded in that order as PyTorch tensors and the
    batch is moved to ``device``.
    """
    user_content = [
        {"type": "image"},  # image 1: global view
        {"type": "image"},  # image 2: ROI crop
        {"type": "text", "text": f"{prompt}"},
    ]
    conversation = [{"role": "user", "content": user_content}]
    rendered = proc.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
        add_vision_id=True,
    )
    images = [pil_img_global, pil_img_roi]
    batch = proc(text=[rendered], images=images, return_tensors="pt")
    return batch.to(device)

# -------------------------
# Gaze 最小包围框 + 目标分辨率
# -------------------------
def gaze_bbox(hm, mass=0.5, min_size=28, align=28):
    """Return the padded bounding box of the strongest gaze pixels.

    Pixels are taken in descending heat order until their cumulative sum
    reaches ``mass`` times the total heat. The axis-aligned box around those
    pixels is then grown by 5% of the longer heatmap side and clamped to the
    heatmap. An all-zero heatmap selects every pixel.

    Args:
        hm: Heatmap tensor indexed as ``hm[0, 0]`` to obtain a 2-D (H, W) map.
        mass: Fraction of the total heat the selected pixels must cover.
        min_size: Accepted for backward compatibility; not used.
        align: Accepted for backward compatibility; not used.

    Returns:
        ``(y0, y1, x0, x1)`` in heatmap coordinates, with exclusive upper
        bounds. Only the bounding box is returned; no target size is computed.
    """
    heat = hm[0, 0]
    n_rows, n_cols = heat.shape
    flat = heat.flatten()
    vals, order = torch.sort(flat, descending=True)
    total = heat.sum()

    # Index of the first sorted pixel at which the running sum reaches the
    # requested mass; fall back to "all pixels" for an empty/zero heatmap.
    reached = (vals.cumsum(0) >= (total * mass)).nonzero()
    if total > 0 and reached.numel() > 0:
        last = reached[0].item()
    else:
        last = len(vals) - 1

    mask = torch.zeros_like(flat, dtype=torch.bool)
    mask[order[:last + 1]] = True
    ys, xs = torch.where(mask.view(n_rows, n_cols))
    if ys.numel() == 0:
        return (0, n_rows, 0, n_cols)

    pad = int(0.05 * max(n_rows, n_cols))
    y0 = max(0, ys.min().item() - pad)
    x0 = max(0, xs.min().item() - pad)
    y1 = min(n_rows, ys.max().item() + 1 + pad)
    x1 = min(n_cols, xs.max().item() + 1 + pad)
    return (y0, y1, x0, x1)

def crop_and_resize(pil_img, bbox):
    """Crop ``pil_img`` to ``bbox`` and resample the crop at its own size.

    ``bbox`` is ``(y0, y1, x0, x1)``; the crop box passed to PIL is
    ``(x0, y0, x1, y1)``. The result is resized with bicubic resampling to
    ``(x1 - x0, y1 - y0)``, i.e. the crop's own width and height.
    """
    top, bottom, left, right = bbox
    region = pil_img.crop((left, top, right, bottom))
    out_size = (right - left, bottom - top)
    return region.resize(out_size, Image.BICUBIC)

# -------------------------
# 模型加载
# -------------------------
def load_model_and_processor(use_flash_attn=True):
    """Load the MODEL_ID checkpoint in eval mode together with its processor.

    The model is loaded with dtype ``DTYPE`` and ``device_map="auto"``; when
    ``use_flash_attn`` is true the ``flash_attention_2`` attention
    implementation is requested. TF32 matmuls and "high" float32 matmul
    precision are enabled globally as a side effect.

    Returns:
        ``(model, proc)``.
    """
    load_kwargs = {"torch_dtype": DTYPE, "device_map": "auto"}
    if use_flash_attn:
        load_kwargs["attn_implementation"] = "flash_attention_2"

    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(MODEL_ID, **load_kwargs)
    model = model.eval()
    proc = AutoProcessor.from_pretrained(MODEL_ID)

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.set_float32_matmul_precision("high")
    return model, proc

# -------------------------
# 可选：保存可视化
# -------------------------
def _draw_trace_path_on_bgr(
    canvas_bgr,
    trace_xy_norm,
    line_color=(0, 255, 255),   # yellow line (BGR)
    start_color=(0, 255, 0),    # green start marker
    end_color=(0, 0, 255),      # red end marker
    line_thickness=2,
    point_radius=3,
    max_points=None             # debugging aid: draw only the first N points
):
    """Draw a gaze trace as a polyline on ``canvas_bgr`` (in place).

    ``trace_xy_norm`` is converted to an (N, 2) float array of normalized
    (x, y) points; anything that is not a non-empty N x 2 array leaves the
    canvas untouched. Coordinates are clipped to [0, 0.999] and scaled by
    (W - 1, H - 1). Consecutive points are joined by line segments, the first
    point gets a filled ``start_color`` dot and, when there are at least two
    points, the last gets a filled ``end_color`` dot.

    Returns:
        The same ``canvas_bgr`` array.
    """
    import numpy as np
    import cv2

    H, W = canvas_bgr.shape[:2]
    trace = np.asarray(trace_xy_norm, dtype=np.float32)
    is_point_list = trace.ndim == 2 and trace.shape[1] == 2 and trace.shape[0] > 0
    if not is_point_list:
        return canvas_bgr

    px = np.clip(trace[:, 0], 0.0, 0.999) * (W - 1)
    py = np.clip(trace[:, 1], 0.0, 0.999) * (H - 1)
    points = np.stack([px, py], axis=1).astype(np.int32)
    if max_points is not None:
        points = points[:max_points]

    for prev_pt, next_pt in zip(points[:-1], points[1:]):
        cv2.line(canvas_bgr, tuple(prev_pt), tuple(next_pt), line_color, line_thickness)

    marker_radius = point_radius + 1
    if len(points) > 0:
        cv2.circle(canvas_bgr, tuple(points[0]), marker_radius, start_color, thickness=-1)
    if len(points) > 1:
        cv2.circle(canvas_bgr, tuple(points[-1]), marker_radius, end_color, thickness=-1)

    return canvas_bgr

def _draw_grid_and_selected_tokens(
    pil_img,
    hm_2d,
    out_path,
    align=28,
    rho=0.5,
    bbox=None,
    grid_color=(128,128,128),
    sel_color=(255,0,0),
    roi_color=(255,215,0),  # gold
    thickness=1
):
    """Save the image with an ``align`` x ``align`` grid and gaze-selected cells.

    The heatmap is clipped to [0, 1], normalized by its sum, and resized to
    the image. Cells are ranked by the heat they contain and selected in
    descending order until the accumulated heat reaches ``rho`` times the
    total; selected cells get a translucent ``sel_color`` fill. Grid lines are
    drawn over the area covered by whole cells.

    Args:
        pil_img: Image to draw on (converted with ``np.array``).
        hm_2d: Heatmap tensor; 4-D and 3-D inputs are reduced to 2-D by taking
            the leading index.
        out_path: Path the rendered image is written to.
        align: Grid cell edge length in pixels.
        rho: Fraction of the total heat the selected cells must cover.
        bbox: Optional ``(y0, y1, x0, x1)`` in heatmap coordinates; when given
            it is scaled to image coordinates and outlined in ``roi_color``.
        grid_color, sel_color, roi_color: RGB colors.
        thickness: Grid line thickness.
    """
    import cv2
    import numpy as np

    img = np.array(pil_img).copy()
    H, W = img.shape[:2]

    # 1) Bring the heatmap to 2-D numpy, normalize it, and match the image size.
    hm = hm_2d.detach().cpu().numpy()
    if hm.ndim == 4: hm = hm[0,0]
    elif hm.ndim == 3: hm = hm[0]
    hm = np.clip(hm, 0, 1)

    if hm.max() > 0: hm = hm / hm.sum()
    hm_resized = cv2.resize(hm, (W, H), interpolation=cv2.INTER_AREA)

    # 2) Heat contained in every whole grid cell, strongest first.
    cell_h, cell_w = align, align
    ny, nx = H // cell_h, W // cell_w
    masses = []
    for gy in range(ny):
        for gx in range(nx):
            y0, y1 = gy*cell_h, (gy+1)*cell_h
            x0, x1 = gx*cell_w, (gx+1)*cell_w
            mass = hm_resized[y0:y1, x0:x1].sum()
            masses.append(((gy, gx), mass))
    masses.sort(key=lambda t: t[1], reverse=True)

    # 3) Accumulate cells until the requested share of heat is covered.
    selected = set()
    total = hm_resized.sum()
    cutoff = rho * total if total > 0 else 0
    cur = 0.0
    for (gy, gx), m in masses:
        if cur >= cutoff: break
        selected.add((gy, gx))
        cur += m

    # 4) Translucent fill for the selected cells.
    canvas = img.copy()
    overlay = canvas.copy()
    for (gy, gx) in selected:
        y0, y1 = gy*cell_h, (gy+1)*cell_h
        x0, x1 = gx*cell_w, (gx+1)*cell_w
        cv2.rectangle(overlay, (x0, y0), (x1, y1), sel_color, -1)
    alpha = 0.25
    canvas = cv2.addWeighted(overlay, alpha, canvas, 1 - alpha, 0)

    # Grid lines span only the area covered by whole cells.
    grid_w, grid_h = nx*cell_w, ny*cell_h
    for gy in range(1, ny):
        y = gy * cell_h
        cv2.line(canvas, (0, y), (grid_w, y), grid_color, thickness)
    for gx in range(1, nx):
        x = gx * cell_w
        # Vertical line: both end points share the same x.
        cv2.line(canvas, (x, 0), (x, grid_h), grid_color, thickness)

    # Optional ROI outline, mapped from heatmap to image coordinates because
    # the two may differ in size.
    if bbox is not None:
        y0,y1,x0,x1 = bbox
        hH, hW = hm.shape[:2]
        scale_y, scale_x = H / float(hH), W / float(hW)
        Y0, Y1 = int(round(y0*scale_y)), int(round(y1*scale_y))
        X0, X1 = int(round(x0*scale_x)), int(round(x1*scale_x))
        cv2.rectangle(canvas, (X0, Y0), (X1, Y1), roi_color, 2)

    # The canvas is RGB; OpenCV writes BGR.
    cv2.imwrite(out_path, cv2.cvtColor(canvas, cv2.COLOR_RGB2BGR))

def selected_tokens_from_bbox(bbox, hm_shape, H, W, align=28, rule="intersect"):
    """Pick the grid cells ("tokens") covered by an ROI given in heatmap coords.

    Args:
        bbox: ``(y0, y1, x0, x1)`` in heatmap coordinates.
        hm_shape: ``(hH, hW)`` heatmap size.
        H, W: Image size in pixels.
        align: Grid cell edge length in pixels.
        rule: ``"center_in"`` selects a cell only when its center lies inside
            the ROI; any other value selects every cell that overlaps the ROI
            with positive area.

    Returns:
        ``(selected, ny, nx)`` where ``selected`` is a set of ``(gy, gx)``
        cell indices and ``ny`` / ``nx`` are the grid dimensions (rounded up,
        so edge cells may be partial).
    """
    import numpy as np

    hm_h, hm_w = hm_shape
    scale_y = H / float(hm_h)
    scale_x = W / float(hm_w)

    def _to_image(value, scale, limit):
        # Scale to image pixels, round, and clamp into [0, limit].
        return max(0, min(int(round(value * scale)), limit))

    y0_h, y1_h, x0_h, x1_h = bbox
    top, bottom = sorted((_to_image(y0_h, scale_y, H), _to_image(y1_h, scale_y, H)))
    left, right = sorted((_to_image(x0_h, scale_x, W), _to_image(x1_h, scale_x, W)))

    ny = int(np.ceil(H / float(align)))
    nx = int(np.ceil(W / float(align)))

    def _is_selected(gy, gx):
        cell_top, cell_bottom = gy * align, min((gy + 1) * align, H)
        cell_left, cell_right = gx * align, min((gx + 1) * align, W)
        if rule == "center_in":
            center_x = 0.5 * (cell_left + cell_right)
            center_y = 0.5 * (cell_top + cell_bottom)
            return (left <= center_x < right) and (top <= center_y < bottom)
        overlap_w = min(cell_right, right) - max(cell_left, left)
        overlap_h = min(cell_bottom, bottom) - max(cell_top, top)
        return overlap_w > 0 and overlap_h > 0

    selected = {
        (gy, gx)
        for gy in range(ny)
        for gx in range(nx)
        if _is_selected(gy, gx)
    }
    return selected, ny, nx


def draw_grid_tokens_colored(
    pil_img,
    selected,
    ny, nx,
    align=28,
    out_path="grid_tokens.png",
    color_selected=(255, 0, 0),     # red: selected cells
    color_unselected=(0, 128, 255), # blue: unselected cells
    grid_color=(0, 255, 0),         # green: grid lines
    grid_thickness=1,
    fill_alpha=0.28,
    bbox=None,
    hm_shape=None,
    roi_color=(255, 215, 0)
):
    """Save the image with every grid cell tinted by its selection state.

    Each of the ``ny`` x ``nx`` cells (edge length ``align``, clipped to the
    image) is filled with ``color_selected`` if its ``(gy, gx)`` index is in
    ``selected`` and with ``color_unselected`` otherwise. The fill is blended
    over the image with weight ``fill_alpha`` and grid lines are drawn on top.

    Args:
        pil_img: Image to draw on (converted with ``np.array``).
        selected: Set of ``(gy, gx)`` indices of selected cells.
        ny, nx: Number of grid rows / columns.
        align: Grid cell edge length in pixels.
        out_path: Path the rendered image is written to.
        color_selected, color_unselected, grid_color: RGB colors.
        grid_thickness: Grid line thickness.
        fill_alpha: Blend weight of the colored fill.
        bbox, hm_shape, roi_color: Accepted for interface compatibility but
            currently ignored; no ROI outline is drawn by this function.
    """
    import cv2, numpy as np
    img = np.array(pil_img).copy()
    H, W = img.shape[:2]

    canvas = img.copy()
    overlay = canvas.copy()

    # Tint every cell: one color for selected, another for unselected.
    for gy in range(ny):
        y0, y1 = gy*align, min((gy+1)*align, H)
        for gx in range(nx):
            x0, x1 = gx*align, min((gx+1)*align, W)
            color = color_selected if (gy, gx) in selected else color_unselected
            cv2.rectangle(overlay, (x0, y0), (x1, y1), color, thickness=-1)

    canvas = cv2.addWeighted(overlay, fill_alpha, canvas, 1 - fill_alpha, 0)

    # Redraw the grid lines on top of the blended fill, clamped to the image.
    for gy in range(1, ny):
        y = min(gy*align, H-1)
        cv2.line(canvas, (0, y), (W-1, y), grid_color, grid_thickness)
    for gx in range(1, nx):
        x = min(gx*align, W-1)
        cv2.line(canvas, (x, 0), (x, H-1), grid_color, grid_thickness)

    # The canvas is RGB; OpenCV writes BGR.
    cv2.imwrite(out_path, cv2.cvtColor(canvas, cv2.COLOR_RGB2BGR))


# def save_image_and_heatmap(
#     pil_img, 
#     hm_2d, 
#     out_prefix, 
#     crop_img=None, 
#     align=28, 
#     rho=0.5, 
#     bbox=None
# ):
#     import cv2, numpy as np

#     img = np.array(pil_img)
#     H, W = img.shape[:2]

#     # ---- 准备 heatmap 到 numpy、灰度与彩色版本 ----
#     hm = hm_2d.detach().cpu().numpy()
#     if hm.ndim == 4: hm = hm[0,0]
#     elif hm.ndim == 3: hm = hm[0]
#     hm = np.clip(hm, 0, 1)
#     hm_u8 = (hm * 255).astype(np.uint8)

#     # 1) 原图
#     Image.fromarray(img).save(f"{out_prefix}_image.png")

#     # 2) 灰度热力图
#     cv2.imwrite(f"{out_prefix}_heatmap_gray.png", hm_u8)

#     # 3) 彩色热力图（单独存一份，便于对比）
#     hm_color = cv2.applyColorMap(hm_u8, cv2.COLORMAP_JET)
#     cv2.imwrite(f"{out_prefix}_heatmap_color.png", hm_color)

#     # 4) 叠加（把热力图 resize 到与原图一致再 overlay）
#     hm_color_resized = cv2.resize(hm_color, (W, H), interpolation=cv2.INTER_LINEAR)
#     overlay = cv2.addWeighted(cv2.cvtColor(img, cv2.COLOR_RGB2BGR), 1.0, hm_color_resized, 0.4, 0)
#     cv2.imwrite(f"{out_prefix}_overlay.png", overlay[:, :, ::-1])  # 保存为 RGB 顺序

#     # 5) Trace（用等高线/轮廓近似眼动轨迹）
#     #    这里我们在多个阈值上提取等高线，叠加到原图上，形成“轨迹”感
#     trace = img.copy()
#     trace_bgr = cv2.cvtColor(trace, cv2.COLOR_RGB2BGR)
#     # 选择若干阈值绘制等高线
#     for thresh in (64, 96, 128, 160, 192):  # 约等于 0.25~0.75
#         ret, bw = cv2.threshold(hm_u8, thresh, 255, cv2.THRESH_BINARY)
#         contours, _ = cv2.findContours(bw, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
#         # 将热力图坐标 resize 到图像坐标
#         scale_y, scale_x = H / float(hm.shape[0]), W / float(hm.shape[1])
#         for cnt in contours:
#             cnt = cnt.squeeze(1).astype(np.float32)
#             cnt[:,0] *= scale_x
#             cnt[:,1] *= scale_y
#             cnt = cnt.astype(np.int32).reshape(-1,1,2)
#             cv2.polylines(trace_bgr, [cnt], isClosed=True, color=(0, 255, 255), thickness=2)
#     cv2.imwrite(f"{out_prefix}_trace.png", trace_bgr[:, :, ::-1])

#     # 6) 可选保存 ROI 裁剪图
#     if crop_img is not None:
#         crop_img.save(f"{out_prefix}_crop.png")

#     # 7) 用 ROI 的 bbox（热图坐标）得到“选中的 tokens 信息”
#     selected, ny, nx = selected_tokens_from_bbox(
#         bbox=bbox,
#         hm_shape=hm.shape,  # (hH, hW)
#         H=H, W=W,
#         align=align,
#         rule="intersect"     # 或 "center_in" 更严格
#     )

#     # 8) 画网格 + 两色着色 + 重描网格线 + ROI 框
#     draw_grid_tokens_colored(
#         pil_img,
#         selected, ny, nx,
#         align=align,
#         out_path=f"{out_prefix}_grid_tokens.png",
#         color_selected=(255, 0, 0),
#         color_unselected=(0, 128, 255),
#         grid_color=(0, 255, 0),
#         grid_thickness=1,
#         fill_alpha=0.28,
#         bbox=bbox,
#         hm_shape=hm.shape,
#         roi_color=(255, 215, 0)
#     )


def save_image_and_heatmap(
    pil_img,
    hm_2d,
    out_prefix,
    crop_img=None,
    img_global=None,
    align=28,
    rho=0.5,
    bbox=None,
    trace_xy_norm=None,     # normalized trace points [[x, y], ...], x, y in [0, 1]
    draw_trace_on="image",  # "image" | "overlay" | "both"
    trace_max_points=None   # draw only the first N trace points (debugging)
):
    """Write a set of visualization PNGs for one sample.

    Files written (all prefixed with ``out_prefix``):
    - ``_image.png``: the input image;
    - ``_heatmap_gray.png`` / ``_heatmap_color.png``: the heatmap as 8-bit
      gray and JET-colored images;
    - ``_overlay.png``: the colored heatmap resized and blended on the image;
    - ``_trace.png``: heatmap contours at several thresholds on the image;
    - ``_crop.png`` / ``_global.png``: ``crop_img`` / ``img_global`` if given;
    - ``_grid_tokens.png``: grid cells colored by ROI membership, only when
      ``bbox`` is given;
    - ``_image_trace.png`` / ``_overlay_trace.png``: the gaze trace polyline,
      when ``trace_xy_norm`` is given, according to ``draw_trace_on``.

    Args:
        pil_img: Source image.
        hm_2d: Heatmap tensor; 4-D and 3-D inputs are reduced to 2-D by taking
            the leading index.
        out_prefix: Path prefix for every output file.
        crop_img: Optional ROI crop to save.
        img_global: Optional global view to save.
        align: Grid cell edge length for the token grid.
        rho: Currently not used.
        bbox: Optional ``(y0, y1, x0, x1)`` ROI in heatmap coordinates. When
            None, the token-grid image is skipped instead of failing.
        trace_xy_norm: Optional normalized gaze trace.
        draw_trace_on: Which base image(s) receive the trace.
        trace_max_points: Optional cap on the number of trace points drawn.
    """
    import cv2, numpy as np
    img = np.array(pil_img)
    H, W = img.shape[:2]

    # ---- Heatmap as numpy: clipped float and 8-bit versions ----
    hm = hm_2d.detach().cpu().numpy()
    if hm.ndim == 4: hm = hm[0,0]
    elif hm.ndim == 3: hm = hm[0]
    hm = np.clip(hm, 0, 1)
    hm_u8 = (hm * 255).astype(np.uint8)

    # 1) Source image
    Image.fromarray(img).save(f"{out_prefix}_image.png")

    # 2) Gray heatmap
    cv2.imwrite(f"{out_prefix}_heatmap_gray.png", hm_u8)

    # 3) Colored heatmap, saved separately for comparison
    hm_color = cv2.applyColorMap(hm_u8, cv2.COLORMAP_JET)
    cv2.imwrite(f"{out_prefix}_heatmap_color.png", hm_color)

    # 4) Overlay: resize the colored heatmap to the image and blend
    hm_color_resized = cv2.resize(hm_color, (W, H), interpolation=cv2.INTER_LINEAR)
    img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    overlay_bgr = cv2.addWeighted(img_bgr.copy(), 1.0, hm_color_resized, 0.4, 0)
    cv2.imwrite(f"{out_prefix}_overlay.png", overlay_bgr)

    # 5) Contours of the heatmap at several thresholds, scaled to image size
    trace = img.copy()
    trace_bgr = cv2.cvtColor(trace, cv2.COLOR_RGB2BGR)
    scale_y, scale_x = H / float(hm.shape[0]), W / float(hm.shape[1])
    for thresh in (64, 96, 128, 160, 192):
        _, bw = cv2.threshold(hm_u8, thresh, 255, cv2.THRESH_BINARY)
        contours, _ = cv2.findContours(bw, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        for cnt in contours:
            cnt = cnt.squeeze(1).astype(np.float32)
            cnt[:,0] *= scale_x
            cnt[:,1] *= scale_y
            cnt = cnt.astype(np.int32).reshape(-1,1,2)
            cv2.polylines(trace_bgr, [cnt], isClosed=True, color=(0, 255, 255), thickness=2)
    cv2.imwrite(f"{out_prefix}_trace.png", trace_bgr)

    # 6) Optional ROI crop and global view
    if crop_img is not None:
        crop_img.save(f"{out_prefix}_crop.png")

    if img_global is not None:
        img_global.save(f"{out_prefix}_global.png")

    # 7) Token grid colored by ROI membership. Requires a bbox: without one
    #    there is nothing to select, so the image is skipped.
    if bbox is not None:
        selected, ny, nx = selected_tokens_from_bbox(
            bbox=bbox,
            hm_shape=hm.shape,  # (hH, hW)
            H=H, W=W,
            align=align,
            rule="intersect"
        )
        draw_grid_tokens_colored(
            pil_img,
            selected, ny, nx,
            align=align,
            out_path=f"{out_prefix}_grid_tokens.png",
            color_selected=(204, 78, 1),
            color_unselected=(0, 128, 255),
            grid_color=(128, 128, 128),
            grid_thickness=1,
            fill_alpha=0.25,
            bbox=bbox,
            hm_shape=hm.shape,
            roi_color=(255, 215, 0)
        )

    # 8) Gaze trace polyline (in point order) on the image and/or the overlay
    if trace_xy_norm is not None:
        if draw_trace_on in ("image", "both"):
            img_trace_bgr = img_bgr.copy()
            img_trace_bgr = _draw_trace_path_on_bgr(
                img_trace_bgr, trace_xy_norm,
                line_color=(0,255,255), start_color=(0,255,0), end_color=(0,0,255),
                line_thickness=2, point_radius=3, max_points=trace_max_points
            )
            cv2.imwrite(f"{out_prefix}_image_trace.png", img_trace_bgr)

        if draw_trace_on in ("overlay", "both"):
            overlay_trace_bgr = overlay_bgr.copy()
            overlay_trace_bgr = _draw_trace_path_on_bgr(
                overlay_trace_bgr, trace_xy_norm,
                line_color=(0,255,255), start_color=(0,255,0), end_color=(0,0,255),
                line_thickness=2, point_radius=3, max_points=trace_max_points
            )
            cv2.imwrite(f"{out_prefix}_overlay_trace.png", overlay_trace_bgr)

# -------------------------
# 运行：Baseline（单图）
# -------------------------
@torch.no_grad()
def run_baseline_segmented(model, proc, pil_img, prompt, max_new_tokens=24, enforce_size=None):
    """Run the single-image baseline and measure each phase separately.

    Three phases are timed with ``measure_section``: preprocessing (processor
    call), prefill (one forward pass over the prompt) and generation (greedy
    decoding via ``model.generate``).

    Args:
        model: Loaded vision-language model.
        proc: Matching processor (must expose ``tokenizer``).
        pil_img: Input image.
        prompt: Text prompt.
        max_new_tokens: Generation length cap.
        enforce_size: Optional shortest-edge override applied to the image
            processor during preprocessing.

    Returns:
        ``(txt, stat)``: the decoded continuation and a dict of timings
        (``t_*_s``, ``total_s``), peak memory per phase (A = prep,
        B = prefill, C = generation; allocated and reserved, MiB), token
        counts and the FLOPs estimate from ``compute_flops_report``.
    """
    device = model.device if hasattr(model, "device") else "cuda"

    # 1) Preprocessing
    def _prep():
        with enforce_image_budget(proc, shortest_edge=enforce_size, max_pixels=None):
            return build_mm_inputs_single(proc, pil_img, prompt, device)
    pack, t_prep, memA, resA = measure_section(_prep)

    # 2) Prefill: one forward pass over image + text prefix. The outputs are
    #    dropped inside the section so logits and KV cache do not stay alive
    #    and inflate the peak memory recorded for generation.
    def _prefill():
        model(**pack, use_cache=True)
    _, t_prefill, memB, resB = measure_section(_prefill)

    # 3) Greedy generation
    def _gen():
        return model.generate(
            **pack,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            eos_token_id=proc.tokenizer.eos_token_id,
            pad_token_id=proc.tokenizer.eos_token_id,
            use_cache=True,
        )
    out_ids, t_gen, memC, resC = measure_section(_gen)
    L_text, L_vis = total_context_tokens(pack)

    txt = decode_generated(proc.tokenizer, out_ids, L_in=pack["input_ids"].shape[1])
    flops = compute_flops_report("BASE", model, pack, out_ids, t_gen, verbose=True)

    # NOTE(review): "tokens_text" equals the full input_ids length here, while
    # run_two_scale_segmented reports L_text - L_vis — confirm which is meant.
    stat = {
        "t_prep_s": t_prep, "t_prefill_s": t_prefill, "t_gen_s": t_gen,
        "memA_alloc_MB": memA, "memA_resv_MB": resA,
        "memB_alloc_MB": memB, "memB_resv_MB": resB,
        "memC_alloc_MB": memC, "memC_resv_MB": resC,
        "total_s": t_prep + t_prefill + t_gen,
        "tokens_total": L_text, "tokens_text": L_text, "tokens_vision": L_vis,
    }
    stat.update(flops)
    return txt, stat

# -------------------------
# 运行：Two-Scale（全局 + ROI）
# -------------------------
@torch.no_grad()
def run_two_scale_segmented(model, proc, pil_img_global, pil_img_roi, prompt,
                            max_new_tokens=24, global_edge=224):
    """Run the two-image (global + ROI) variant and measure each phase.

    Three phases are timed with ``measure_section``: preprocessing, prefill
    (one forward pass over the prompt) and greedy generation.

    Args:
        model: Loaded vision-language model.
        proc: Matching processor (must expose ``tokenizer``).
        pil_img_global: Global view, passed as the first image.
        pil_img_roi: ROI crop, passed as the second image.
        prompt: Text prompt.
        max_new_tokens: Generation length cap.
        global_edge: Currently not used; both images are passed to the
            processor at their own size.

    Returns:
        ``(txt, stat)``: the decoded continuation and a dict of timings
        (``t_*_s``, ``total_s``), peak memory per phase (A = prep,
        B = prefill, C = generation; allocated and reserved, MiB), token
        counts and the FLOPs estimate from ``compute_flops_report``.
    """
    device = model.device if hasattr(model, "device") else "cuda"

    # 1) Preprocessing. No size override is applied: both arguments are None,
    #    so the processor keeps its own settings.
    def _prep():
        with enforce_image_budget(proc, shortest_edge=None, max_pixels=None):
            return build_mm_inputs_two(proc, pil_img_global, pil_img_roi, prompt, device)
    pack, t_prep, memA, resA = measure_section(_prep)

    # 2) Prefill. The outputs are dropped inside the section so logits and KV
    #    cache do not stay alive and inflate the peak memory recorded for
    #    generation.
    def _prefill():
        model(**pack, use_cache=True)
    _, t_prefill, memB, resB = measure_section(_prefill)

    # 3) Greedy generation
    def _gen():
        return model.generate(
            **pack,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            eos_token_id=proc.tokenizer.eos_token_id,
            pad_token_id=proc.tokenizer.eos_token_id,
            use_cache=True,
        )
    out_ids, t_gen, memC, resC = measure_section(_gen)
    L_text, L_vis = total_context_tokens(pack)

    txt = decode_generated(proc.tokenizer, out_ids, L_in=pack["input_ids"].shape[1])
    # Label this report as the two-scale run rather than the baseline.
    flops = compute_flops_report("TWO_SCALE", model, pack, out_ids, t_gen, verbose=True)

    # NOTE(review): L_vis is the leading dimension of pixel_values; if that is
    # a patch count larger than the number of image placeholders in
    # input_ids, "tokens_text" can go negative — confirm against the processor.
    stat = {
        "t_prep_s": t_prep, "t_prefill_s": t_prefill, "t_gen_s": t_gen,
        "memA_alloc_MB": memA, "memA_resv_MB": resA,
        "memB_alloc_MB": memB, "memB_resv_MB": resB,
        "memC_alloc_MB": memC, "memC_resv_MB": resC,
        "total_s": t_prep + t_prefill + t_gen,
        "tokens_total": L_text, "tokens_text": L_text-L_vis, "tokens_vision": L_vis,
    }
    stat.update(flops)
    return txt, stat
# -------------------------
# 主程序
# -------------------------
def main():
    """Command-line entry point: run the selected variants and log to CSV.

    For each of the first ``--max_samples`` dataset items, a gaze-derived ROI
    is cropped and up to three variants are run, each only when its flag is
    given:

    * ``--baseline``  -- single full image
    * ``--roi_only``  -- single ROI crop
    * ``--two_scale`` -- resized global image plus ROI crop

    One CSV row is appended per sample. Columns belonging to a variant that
    was not requested are written as empty strings, so any combination of
    the three flags (including none) is valid. With ``--save_viz`` a
    visualization is saved per sample as well.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--voila_path",  type=str, required=True)
    ap.add_argument("--images_path", type=str, required=True)
    ap.add_argument("--train_config_path", type=str, required=True)
    ap.add_argument("--max_samples", type=int, default=10)

    # ROI parameters
    ap.add_argument("--mass", type=float, default=0.5, help="gaze 累计质量比例 [0,1]")
    ap.add_argument("--min_size", type=int, default=112, help="裁剪后最小分辨率（28 的倍数附近）")
    ap.add_argument("--align", type=int, default=14)

    # Generation parameters
    ap.add_argument("--max_new_tokens", type=int, default=24)

    # Two-scale / global image size
    ap.add_argument("--two_scale", action="store_true", help="启用全局+ROI 双图输入")
    ap.add_argument("--global_edge", type=int, default=224, help="全局图最短边")
    ap.add_argument("--enforce_baseline_edge", type=int, default=None,
                    help="Baseline 单图时强制处理器最短边（避免被上采样）；默认不强制")

    # Output
    ap.add_argument("--out_csv", type=str, default="results_cropping_sparse.csv")
    ap.add_argument("--save_viz", action="store_true")
    ap.add_argument("--viz_dir", type=str, default="./viz_crop")

    # Model acceleration
    ap.add_argument("--flash_attn", action="store_true", help="尝试使用 FlashAttention-2")

    ap.add_argument("--roi_only", action="store_true", help="只喂 ROI 单图（禁用 two-scale）")
    ap.add_argument("--baseline", action="store_true", help="只喂单图（禁用 two-scale）")

    args = ap.parse_args()

    def fmt_stat(stat, key):
        """Format ``stat[key]`` with one decimal; "" when the variant was skipped."""
        return "" if stat is None else f"{stat[key]:.1f}"

    # Dataset (the tokenizer only exists so that __getitem__ runs through).
    dummy_tok = AutoTokenizer.from_pretrained("mosaicml/mpt-7b-instruct")
    class Dummy:
        task="voila"; tokenizer=dummy_tok; max_src_length=256; max_tgt_length=256; seed=42; patch_image_size=IMAGE_SIZE
    ds = VoilaDataset(Dummy(), args.voila_path, args.images_path, args.train_config_path)
    anno = ds.dataset
    n = min(args.max_samples, len(ds))
    print(f"[Data] total={len(ds)}; run first {n} samples")

    # Model
    model, proc = load_model_and_processor(use_flash_attn=args.flash_attn)
    print(proc.image_processor)
    # Optional experiment (disabled): proc.image_processor.merge_size = 1

    # Lower the processor's size floor so that small inputs are accepted.
    # NOTE(review): the original note says sizes stay aligned to multiples
    # of 14 -- confirm against the image processor's behavior.
    proc.image_processor.size = {"shortest_edge": 14}
    proc.image_processor.min_pixels = 1500

    # CSV header. UTF-8 is set explicitly so that non-ASCII questions or
    # captions do not depend on the platform's default encoding.
    os.makedirs(os.path.dirname(args.out_csv) or ".", exist_ok=True)
    with open(args.out_csv, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow([
            "idx","id","question","gt_answer", "caption",
            "baseline_txt","two_scale_txt","roi_txt",
            # roi/global meta
            "roi_bbox", "roi_area", "global_edge", 
            # flops
            "flops", "flops_two_scale", "flops_roi",
            # tokens
            "tokens", "tokens_two_scale", "tokens_roi", "tokens_vision"
        ])

    if args.save_viz:
        os.makedirs(args.viz_dir, exist_ok=True)

    # Loop over samples
    for i in tqdm(range(n), desc="Processing samples"):
        ex = ds[i]
        sid = ex["id"]; qa = anno[sid]
        pil_img = tensor_to_pil(ex["patch_images"][0,0])  # 224x224 per the original note
        hm      = ex["trace_heatmap"]                     # [1,1,H,W] per the original note
        prompt  = qa.get("question","")
        gt      = qa.get("answer","")
        caption = qa.get("caption","")

        # Per-sample defaults: every variant is optional, so each result must
        # be bound before the CSV row is assembled (otherwise running without
        # all three flags raises NameError).
        base_txt = two_txt = ts_txt = ""
        base_stat = two_stat = ts_stat = None

        # Baseline (optionally forcing the processor's shortest edge).
        if args.baseline:
            base_txt, base_stat = run_baseline_segmented(
                model, proc, pil_img, prompt,
                max_new_tokens=args.max_new_tokens,
                enforce_size=args.enforce_baseline_edge
            )

        # ROI crop (computed once per sample and shared by the variants).
        bbox = gaze_bbox(hm, mass=args.mass, min_size=args.min_size, align=args.align)
        img_roi = crop_and_resize(pil_img, bbox)
        y0,y1,x0,x1 = bbox
        roi_area = abs((y1 - y0) * (x1 - x0))

        # Single-image run on the ROI crop.
        if args.roi_only:
            ts_txt, ts_stat = run_baseline_segmented(
                model, proc, img_roi, prompt,
                max_new_tokens=args.max_new_tokens,
                enforce_size=args.enforce_baseline_edge
            )

        # The global image feeds the two-scale run and the visualization, so
        # it is built whenever either one is requested.
        img_global = None
        if args.two_scale or args.save_viz:
            img_global = pil_img.resize((args.global_edge, args.global_edge), Image.BICUBIC)

        # Two-scale run (resized global image + ROI crop).
        if args.two_scale:
            two_txt, two_stat = run_two_scale_segmented(
                model, proc, img_global, img_roi, prompt,
                max_new_tokens=args.max_new_tokens,
                global_edge=args.global_edge
            )

        # Console summary for the variants that actually ran.
        print(f"\n[{i}] Q: {prompt}")
        print(f"GT: {gt}")
        if args.baseline:
            print(f"BASE   flops={base_stat['total_flops']:.3f} "
                f"-> {base_txt}")

        if args.two_scale:
            print(f"2SCALE flops={two_stat['total_flops']:.3f} "
                  f"-> {two_txt}")

        if args.roi_only:
            print(f"ROI    flops={ts_stat['total_flops']:.3f} "
                  f"-> {ts_txt} \n")

        # Append this sample's row; reopening per sample keeps finished rows
        # on disk if a later sample fails. Timing fields such as t_prep_s are
        # present in the two-scale stats but are intentionally not exported.
        with open(args.out_csv, "a", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow([
                i, sid, prompt, gt, caption,
                base_txt, two_txt, ts_txt,
                bbox, roi_area, args.global_edge, 
                fmt_stat(base_stat, "total_flops"),
                fmt_stat(two_stat, "total_flops"),
                fmt_stat(ts_stat, "total_flops"),
                fmt_stat(base_stat, "tokens_total"),
                fmt_stat(two_stat, "tokens_total"),
                fmt_stat(ts_stat, "tokens_total"),
                fmt_stat(ts_stat, "tokens_vision"),
            ])

        # Optional visualization
        if args.save_viz:
            prefix = os.path.join(args.viz_dir, f"{i:03d}")
            # mass/align/bbox are passed through so that the drawn grid
            # matches the cropping algorithm's selection.
            save_image_and_heatmap(
                pil_img, hm, prefix,
                crop_img=img_roi,
                img_global=img_global,
                align=args.align,
                rho=args.mass,
                bbox=bbox,
                trace_xy_norm=anno[sid]["trace"],  # normalized trace taken from the annotation JSON
                draw_trace_on="image",             # draw on the original image only; "both" draws on both
                trace_max_points=None              # e.g. 200 would draw only the first 200 points
            )

    print(f"\n[Done] CSV -> {args.out_csv}")
    if args.save_viz:
        print(f"[Done] Visualizations -> {args.viz_dir}")

if __name__ == "__main__":
    # Script entry point: enable the TF32 flag for CUDA matmuls, then run the CLI.
    torch.backends.cuda.matmul.allow_tf32 = True
    main()
