import argparse
import os
from typing import List
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

# Inference-only script: disable autograd globally so no computation graphs are
# recorded (the generate helpers are additionally decorated with torch.no_grad()).
torch.set_grad_enabled(False)


# ---------------------- 工具函数 ----------------------
def count_vision_tokens(processor, input_ids: torch.Tensor) -> int:
    """Return how many vision-related special tokens occur in ``input_ids[0]``.

    Only the first sequence of the batch is inspected. Marker strings that are
    absent from the tokenizer vocabulary are ignored. Used purely as a sanity
    check that the chat template actually inserted image placeholders.
    """
    vocab = processor.tokenizer.get_vocab()
    vision_markers = ("<|vision_start|>", "<|image_pad|>", "<|vision_end|>", "<|vision_pad|>")

    marker_ids = set()
    for marker in vision_markers:
        if marker in vocab:
            marker_ids.add(vocab[marker])

    total = 0
    for token_id in input_ids[0].tolist():
        if token_id in marker_ids:
            total += 1
    return total


def build_inputs(processor, image_path: str, question: str, device):
    """Build multimodal model inputs following the official Qwen2.5-VL recipe.

    Pipeline: messages -> chat template (text) -> process_vision_info -> processor(...).

    Args:
        processor: the model's ``AutoProcessor``.
        image_path: a bare local file path, or an already-formed image
            reference (``file://...``, ``http://...``, ``https://...`` or a
            ``data:image...`` URI), which is passed through unchanged.
        question: the user's text prompt.
        device: device the returned tensors are moved to.

    Returns:
        The processor output (``input_ids``, ``attention_mask`` and vision
        tensors), moved to ``device``.

    Raises:
        RuntimeError: if no vision special tokens end up in ``input_ids``.
    """
    # Only bare local paths need an absolute file:// URI. References that already
    # carry a scheme must be passed through untouched; prefixing a URL or data URI
    # with file:// would turn it into a nonexistent local path.
    passthrough_prefixes = ("file://", "http://", "https://", "data:image")
    if image_path.startswith(passthrough_prefixes):
        image_uri = image_path
    else:
        image_uri = "file://" + os.path.abspath(image_path)

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_uri},
                {"type": "text", "text": question},
            ],
        }
    ]

    # 1) The chat template inserts <|vision_start|><|image_pad|><|vision_end|> into the text.
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # 2) Extract the image/video inputs referenced by the messages.
    image_inputs, video_inputs = process_vision_info(messages)

    # 3) Feed text and vision inputs to the processor together.
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(device)

    # Sanity check: the prompt must contain at least one vision token.
    n_vis = count_vision_tokens(processor, inputs["input_ids"])
    if n_vis == 0:
        print("[DEBUG] chat template text ↓↓↓")
        print(text)
        raise RuntimeError("No vision tokens in input_ids; check messages/template.")
    return inputs


def decode_new_only(processor, gen_ids, prompt_len):
    """Decode just the generated continuation of the first sequence.

    Tokens before ``prompt_len`` (the prompt) are dropped; special tokens are
    skipped and surrounding whitespace is stripped from the result.
    """
    first_row = gen_ids[0]
    generated = first_row[prompt_len:]
    decoded = processor.tokenizer.decode(generated, skip_special_tokens=True)
    return decoded.strip()


# ---------------------- 基线：原生 generate ----------------------
@torch.no_grad()
def generate_baseline(model, processor, inputs, max_new_tokens=128, do_sample=False, temperature=0.7, top_p=0.9, min_new_tokens=0):
    """Run plain ``model.generate`` on the full (image + text) inputs.

    Args:
        model: the generation model.
        processor: processor whose ``tokenizer`` decodes the output ids.
        inputs: mapping of model inputs; must contain ``input_ids``.
        max_new_tokens / min_new_tokens: generation length bounds.
        do_sample: sample if truthy, otherwise greedy decoding.
        temperature / top_p: sampling parameters, used only when ``do_sample``.

    Returns:
        ``(text, token_ids)``: the decoded continuation of the first sequence
        (special tokens skipped, whitespace stripped) and the raw list of its
        newly generated token ids.
    """
    prompt_len = inputs["input_ids"].shape[1]
    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        min_new_tokens=min_new_tokens,
        do_sample=bool(do_sample),
        use_cache=True,
    )
    # temperature / top_p only affect sampling; in greedy mode transformers
    # ignores them and merely warns, so forward them only when sampling.
    if do_sample:
        gen_kwargs.update(temperature=float(temperature), top_p=float(top_p))

    out_ids = model.generate(**inputs, **gen_kwargs)
    new_only = out_ids[:, prompt_len:]
    text = processor.tokenizer.decode(new_only[0], skip_special_tokens=True).strip()
    return text, new_only[0].tolist()


# ---------------------- Replay：无图 + 重放前缀后续写 ----------------------
@torch.no_grad()
def generate_replay_after_n(
    model,
    processor,
    question: str,
    baseline_tokens: List[int],
    n_prefix: int,
    max_new_tokens=128,
    do_sample=False,
    temperature=0.7,
    top_p=0.9,
    min_new_tokens=0
):
    """Continue the baseline answer without the image, after replaying its first tokens.

    A text-only chat prompt is built from ``question``. The first ``n_prefix``
    baseline token ids are appended verbatim as the start of the assistant turn,
    and the model continues from there.

    Args:
        model / processor: the generation model and its processor.
        question: the user's text prompt (no image is attached).
        baseline_tokens: token ids generated by the baseline run.
        n_prefix: how many baseline tokens to replay; negative values are
            treated as 0.
        max_new_tokens / min_new_tokens: generation length bounds.
        do_sample: sample if truthy, otherwise greedy decoding.
        temperature / top_p: sampling parameters, used only when ``do_sample``.

    Returns:
        The replayed prefix plus the continuation, decoded together and stripped.
    """
    tok = processor.tokenizer
    device = model.device

    # Replay the baseline ids verbatim instead of decoding and re-tokenizing them.
    # A decode->encode round trip can split multi-byte characters and yield a
    # different tokenization. Special tokens (e.g. an early EOS) are dropped, as
    # skip_special_tokens=True would have done on the text path.
    special_ids = set(tok.all_special_ids)
    n_keep = max(0, int(n_prefix))
    prefix_ids = [t for t in baseline_tokens[:n_keep] if t not in special_ids]
    prefix_text = tok.decode(prefix_ids, skip_special_tokens=True)

    # Text-only conversation (no image).
    messages = [{"role": "user", "content": [{"type": "text", "text": question}]}]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # Append the replayed prefix right where the assistant turn begins.
    prompt_ids = tok(text, return_tensors="pt", add_special_tokens=False).input_ids.to(device)
    if prefix_ids:
        prefix_tensor = torch.tensor([prefix_ids], dtype=prompt_ids.dtype, device=device)
        in_ids = torch.cat([prompt_ids, prefix_tensor], dim=1)
    else:
        in_ids = prompt_ids
    attn = torch.ones_like(in_ids, dtype=torch.long, device=device)

    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        min_new_tokens=min_new_tokens,
        do_sample=bool(do_sample),
        use_cache=True,
    )
    # temperature / top_p only affect sampling; forward them only when sampling.
    if do_sample:
        gen_kwargs.update(temperature=float(temperature), top_p=float(top_p))

    out_ids = model.generate(input_ids=in_ids, attention_mask=attn, **gen_kwargs)

    # Newly generated ids only (everything after prompt + replayed prefix).
    cont_ids = out_ids[0, in_ids.shape[1]:].tolist()
    cont = tok.decode(cont_ids, skip_special_tokens=True).strip()

    # Show the replayed prefix, the continuation, then a separator before the full result.
    print(f"    [Replay 前缀]: {prefix_text}")
    print(f"    [Replay 续写]: {cont}")
    print("    ---- 完整结果 ----")

    # Decode prefix and continuation as one sequence. Concatenating two separately
    # decoded (and stripped) strings would drop the whitespace at the seam.
    return tok.decode(prefix_ids + cont_ids, skip_special_tokens=True).strip()


# ---------------------- 主函数 ----------------------
def main():
    """CLI entry point: run the image+text baseline, then the text-only replay.

    Loads the model and processor named by ``--model_id``, generates a baseline
    answer for ``--image_path`` / ``--question``, and then regenerates without
    the image after replaying the first ``--n_prefix`` baseline tokens. Both
    results are printed.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model_id", type=str, default="Qwen/Qwen2.5-VL-32B-Instruct")
    ap.add_argument("--image_path", type=str, required=True)
    ap.add_argument("--question", type=str, required=True)
    ap.add_argument("--n_prefix", type=int, default=10, help="用于 Replay 方法的前缀 token 数")
    ap.add_argument("--max_new_tokens", type=int, default=128)
    ap.add_argument("--min_new_tokens", type=int, default=0)
    ap.add_argument("--do_sample", type=int, default=0)
    ap.add_argument("--temperature", type=float, default=0.7)
    ap.add_argument("--top_p", type=float, default=0.9)
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args()

    # A negative prefix length is meaningless; reject it up front (exits with usage).
    if args.n_prefix < 0:
        ap.error("--n_prefix must be >= 0")

    # Optional: seed the RNGs for reproducibility (only matters when sampling).
    # Best effort: a failure must not abort the run, but it should not be silent.
    try:
        import random
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
    except Exception as exc:
        print(f"[WARN] Failed to set random seed {args.seed}: {exc}")

    print(f"[INFO] Loading model: {args.model_id}")
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        args.model_id, torch_dtype="auto", device_map="auto"
    )
    processor = AutoProcessor.from_pretrained(args.model_id)

    # Build the multimodal inputs (following the official recipe).
    inputs = build_inputs(processor, args.image_path, args.question, model.device)

    # Baseline: image + text, no intervention.
    print("\n================ 基线（无干预） ================")
    base_text, base_tok = generate_baseline(
        model, processor, inputs,
        max_new_tokens=args.max_new_tokens,
        min_new_tokens=args.min_new_tokens,
        do_sample=bool(args.do_sample),
        temperature=args.temperature,
        top_p=args.top_p
    )
    print(base_text)

    # Replay: drop the image and replay the baseline prefix.
    print("\n================ 干预（Replay：无图重放前缀） ================")
    replay = generate_replay_after_n(
        model, processor, args.question, base_tok, args.n_prefix,
        max_new_tokens=args.max_new_tokens,
        min_new_tokens=args.min_new_tokens,
        do_sample=bool(args.do_sample),
        temperature=args.temperature,
        top_p=args.top_p
    )
    print(replay)


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()