#!/usr/bin/env python
import argparse
import json
import math
import os
import re
from typing import List, Optional

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

import sys

# ROOT_DIR is three directory levels above this file; EXPERIMENTS_DIR is two.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
EXPERIMENTS_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Make project-local helper packages importable by the lazy imports inside
# main() (e.g. vcd_utils, agla_utils, decoder_zoo). Which of these directories
# each package actually lives in is not visible from this file.
sys.path.append(ROOT_DIR)
sys.path.append(EXPERIMENTS_DIR)
sys.path.append(os.path.join(EXPERIMENTS_DIR, "halc_utils"))
# Prepend (not append) the vendored transformers 4.49.0 source tree, so that the
# `from transformers import ...` lines below resolve to it ahead of any installed
# transformers package, provided that directory exists. NOTE(review): main()
# passes non-standard generate() kwargs (opera_decoding, key_position, ...),
# which presumably only this patched copy accepts — confirm.
sys.path.insert(0, os.path.join(EXPERIMENTS_DIR, "transformers-4.49.0", "src"))

from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, set_seed
from transformers.cache_utils import DynamicCache


def _mix_image_segment(t, D_local, cut_idx, num_image_tokens, dscr_lambda):
    """Blend the image-token span of one 4-D cache tensor with its D-refined version.

    Positions ``cut_idx .. cut_idx + num_image_tokens - 1`` on axis 2 become
    ``(1 - dscr_lambda) * seg + dscr_lambda * (D_local @ seg)``; all other
    positions are copied through unchanged.

    Raises:
        ValueError: if the span runs past the end of the sequence axis.
    """
    end = cut_idx + num_image_tokens
    seg = t[:, :, cut_idx:end, :]
    if seg.shape[2] != num_image_tokens:
        # Slicing silently truncates; catch it here instead of failing in einsum.
        raise ValueError(
            f"Image span [{cut_idx}, {end}) exceeds cache sequence length {t.shape[2]}"
        )
    seg_refined = torch.einsum("ij,bhjd->bhid", D_local, seg)
    seg_mixed = (1.0 - dscr_lambda) * seg + dscr_lambda * seg_refined
    return torch.cat((t[:, :, :cut_idx, :], seg_mixed, t[:, :, end:, :]), dim=2)


def apply_D_to_cache(cache, D, cut_idx, start_layer_idx, end_layer_idx, key_only, value_only, key_value, num_image_tokens, dscr_lambda=1.0):
    """Apply the DSCR refinement matrix ``D`` to the image-token span of a KV cache.

    For every layer index in ``[start_layer_idx, end_layer_idx)``:
      * the keys are refined if ``key_only`` or ``key_value`` is set;
      * the values are refined if ``value_only`` or ``key_value`` is set.
    Refining replaces sequence positions ``cut_idx .. cut_idx + num_image_tokens - 1``
    with ``(1 - dscr_lambda) * seg + dscr_lambda * D @ seg``. Layers outside the
    range pass through untouched.

    Args:
        cache: iterable yielding one ``(key, value)`` pair per layer.
            - Each tensor is 4-D and sliced as ``[:, :, seq, :]``.
            - main() passes the model's ``past_key_values`` directly.
        D: square ``(N, N)`` refinement matrix. Leading singleton axes, e.g.
            ``(1, N, N)`` or ``(1, 1, N, N)``, are squeezed away.
        cut_idx: sequence index of the first image token.
        start_layer_idx: first layer to refine (inclusive).
        end_layer_idx: layer at which refinement stops (exclusive).
        key_only: refine keys only.
        value_only: refine values only.
        key_value: refine both keys and values.
        num_image_tokens: ``N``, the length of the image span. ``<= 0`` means
            there is nothing to refine and the cache is returned as-is.
        dscr_lambda: blend weight, clamped to ``[0, 1]``.

    Returns:
        tuple of per-layer ``(key, value)`` pairs (legacy cache layout).

    Raises:
        ValueError: if
            - ``D`` cannot be reduced to a 2-D ``(N, N)`` matrix with
              ``N == num_image_tokens``;
            - ``cut_idx`` is negative while the span is non-empty;
            - the span runs past the end of a cache tensor.
    """
    original_shape = tuple(D.shape)
    # Drop leading singleton (batch/head) axes down to a 2-D matrix.
    while D.dim() > 2 and D.shape[0] == 1:
        D = D.squeeze(0)
    if D.dim() != 2:
        raise ValueError(f"Unsupported D shape: {original_shape}")

    start_layer_idx = int(start_layer_idx)
    end_layer_idx = int(end_layer_idx)
    num_image_tokens = int(num_image_tokens)
    dscr_lambda = float(max(0.0, min(1.0, dscr_lambda)))

    if num_image_tokens <= 0:
        # No image span (e.g. prompt without image tokens): nothing to refine.
        return tuple((k, v) for k, v in cache)

    cut_idx = int(cut_idx)
    if cut_idx < 0:
        # A negative start would silently slice from the end of the sequence.
        raise ValueError(f"cut_idx must be >= 0 for a non-empty image span, got {cut_idx}")
    if tuple(D.shape) != (num_image_tokens, num_image_tokens):
        raise ValueError(
            f"D shape {original_shape} does not match num_image_tokens={num_image_tokens}"
        )

    refine_keys = bool(key_only or key_value)
    refine_values = bool(value_only or key_value)

    # Cast D once per (device, dtype) pair; layers may be sharded across devices.
    D_cache = {}

    def _D_like(t):
        tag = (t.device, t.dtype)
        D_local = D_cache.get(tag)
        if D_local is None:
            D_local = D.to(device=t.device, dtype=t.dtype)
            D_cache[tag] = D_local
        return D_local

    out = []
    for layer_idx, (k, v) in enumerate(cache):
        if start_layer_idx <= layer_idx < end_layer_idx:
            if refine_keys:
                k = _mix_image_segment(k, _D_like(k), cut_idx, num_image_tokens, dscr_lambda)
            if refine_values:
                v = _mix_image_segment(v, _D_like(v), cut_idx, num_image_tokens, dscr_lambda)
        out.append((k, v))

    return tuple(out)


def recompute_first_logits_from_refined_cache(model, input_ids, cache, attention_mask=None):
    """Recompute the last prompt token on top of a refined KV cache.

    Steps:
      1. Drop the final sequence position from every layer of ``cache``.
         ``cache`` is a sequence of per-layer ``(key, value)`` tensors covering
         the full prompt, e.g. the legacy tuple returned by
         ``apply_D_to_cache``; each tensor is sliced as ``[:, :, seq, :]``.
      2. Wrap the remaining prefix in a ``DynamicCache``. Qwen2.5-VL's forward
         expects a ``Cache`` object with ``get_seq_length()``, not a tuple.
      3. Feed only the last prompt token through ``model``, so its KV entries
         and logits are produced from the refined prefix. No image inputs are
         passed; the image tokens reach the model only through the cached
         prefix.

    Args:
        model: the Qwen2.5-VL model.
            NOTE(review): the new token's M-RoPE position comes from
            ``model.rope_deltas``, which is presumably set by the caller's
            preceding full-prompt forward — confirm.
        input_ids: prompt token ids, ``(batch, seq)``.
        cache: refined per-layer ``(key, value)`` pairs for all ``seq`` tokens.
        attention_mask: optional 2-D prompt mask, ``(batch, seq)``.

    Returns:
        ``(past_key_values, logits)``:
            - ``past_key_values``: the cache after the last token.
            - ``logits``: next-token logits at the last prompt position,
              shape ``(batch, vocab)``.

    Raises:
        ValueError: if ``cache`` contains no layers.
    """
    prefix_pairs = tuple(
        (k[:, :, :-1, :].contiguous(), v[:, :, :-1, :].contiguous()) for k, v in cache
    )
    if not prefix_pairs:
        raise ValueError("cache must contain at least one layer")
    prefix_len = prefix_pairs[0][0].shape[2]
    prefix_cache = DynamicCache.from_legacy_cache(prefix_pairs)

    last_token_ids = input_ids[:, -1:]
    # Pass cache_position explicitly, as generate() does. In transformers
    # 4.49's Qwen2.5-VL forward, a cached step without cache_position gets a
    # RoPE offset of 0, placing the recomputed token at position 0 instead of
    # prefix_len + rope_deltas.
    cache_position = torch.arange(prefix_len, prefix_len + 1, device=input_ids.device)
    with torch.inference_mode():
        outputs_last = model(
            input_ids=last_token_ids,
            # With a cache, the 2-D mask must span past + current tokens, so
            # the full prompt mask is passed rather than only its last column.
            attention_mask=attention_mask,
            pixel_values=None,
            image_grid_thw=None,
            use_cache=True,
            past_key_values=prefix_cache,
            cache_position=cache_position,
            return_dict=True,
        )
    return outputs_last.past_key_values, outputs_last.logits[:, -1, :]


def load_questions(path: str) -> List[dict]:
    """Load question records from a JSON-array file or a JSON Lines file.

    Format detection:
      - The first non-whitespace character ``[`` means the whole file is one
        JSON array.
      - Anything else is parsed as JSON Lines: one JSON value per non-blank
        line.

    The file is decoded as UTF-8, and a leading byte-order mark is tolerated.
    """
    # utf-8-sig strips a BOM, which would otherwise hide the leading "[" and
    # break json.loads on the first JSONL line.
    with open(path, "r", encoding="utf-8-sig") as f:
        content = f.read()
    # Skip leading whitespace/newlines before deciding on the format.
    if content.lstrip().startswith("["):
        return json.loads(content)
    # Text-mode read already normalised newlines to "\n". Splitting only on
    # "\n" (not str.splitlines) keeps raw U+2028 etc. inside JSON strings intact.
    return [json.loads(line) for line in content.split("\n") if line.strip()]


def discover_sets(gt_dir: str) -> List[str]:
    """List the dataset names available in ``gt_dir``.

    A dataset name is a ``.json``/``.jsonl`` filename with its last extension
    removed. Duplicates (``x.json`` alongside ``x.jsonl``) collapse to one entry,
    and the result is sorted.
    """
    names = {
        entry.rsplit(".", 1)[0]
        for entry in os.listdir(gt_dir)
        if entry.endswith((".json", ".jsonl"))
    }
    return sorted(names)


def pick_img_dir(img_root: str, dataset: str) -> str:
    """Return the image subdirectory for ``dataset`` under ``img_root``.

    Candidate subdirectory names, tried in order:
      1. the dataset name as given;
      2. its lowercase form;
      3. for "posters" (any case), the literal names ``posters`` then ``Posters``.

    The first candidate that is an existing directory wins. If none match,
    ``img_root`` itself is returned.
    """
    lowered = dataset.lower()
    candidates = [dataset, lowered]
    if lowered == "posters":
        candidates += ["posters", "Posters"]
    for name in candidates:
        path = os.path.join(img_root, name)
        if os.path.isdir(path):
            return path
    return img_root


def pick_depth_dir(depth_root: str, dataset: str) -> str:
    """Return the depth-map subdirectory for ``dataset`` under ``depth_root``.

    Uses exactly the same lookup rules as ``pick_img_dir``, including falling
    back to ``depth_root`` itself when no matching subdirectory exists.
    """
    resolved = pick_img_dir(img_root=depth_root, dataset=dataset)
    return resolved


def normalize_format(fmt: str) -> str:
    """Map a ``--format`` choice to the instruction suffix appended to each question.

    Any format other than ``yn_format``/``ow_format`` (e.g. ``no_format``)
    yields an empty suffix.
    """
    suffixes = {
        "yn_format": " Please only answer yes or no.",
        "ow_format": " Please answer this question with one word.",
    }
    return suffixes.get(fmt, "")


_YN_PATTERN = re.compile(r"\b(yes|no)\b")


def normalize_yn_answer(text: str) -> str:
    """Reduce a model answer to ``"yes"``/``"no"`` when one appears as a whole word.

    The text is stripped and lower-cased first. The earliest whole-word
    ``yes``/``no`` is returned; if neither occurs, the cleaned text is
    returned unchanged.
    """
    cleaned = text.strip().lower()
    found = _YN_PATTERN.search(cleaned)
    if found is None:
        return cleaned
    return found.group(1)


def build_yn_prefix_allowed_tokens_fn(tokenizer, input_len: int):
    """Build a ``prefix_allowed_tokens_fn`` that forces a yes/no first token.

    Allowed tokens per step:
      - First generated step (sequence length still ``input_len``): the first
        token id of each yes/no spelling ("yes", "no", "Yes", "No", " yes",
        " no").
      - Every later step: only the EOS id (``eos_token_id``, else ``eod_id``),
        or the yes/no set again if the tokenizer has neither.

    Returns ``None`` when no spelling encodes to any token.
    """
    spellings = ("yes", "no", "Yes", "No", " yes", " no")
    first_ids = set()
    for spelling in spellings:
        encoded = tokenizer.encode(spelling, add_special_tokens=False)
        if encoded:
            first_ids.add(encoded[0])
    if not first_ids:
        return None

    stop_id = getattr(tokenizer, "eos_token_id", None)
    if stop_id is None:
        stop_id = getattr(tokenizer, "eod_id", None)

    def _restrict(batch_id, input_ids):
        length = input_ids.shape[-1] if hasattr(input_ids, "shape") else len(input_ids)
        if length != input_len and stop_id is not None:
            return [stop_id]
        return list(first_ids)

    return _restrict



def main() -> None:
    parser = argparse.ArgumentParser("MME (Qwen2.5-VL)")
    parser.add_argument("--model-path", type=str, default="Qwen/Qwen2.5-VL-7B-Instruct")
    parser.add_argument("--gt-dir", type=str, required=True)
    parser.add_argument("--image-root", type=str, required=True)
    parser.add_argument("--out-root", type=str, required=True)
    parser.add_argument("--depth-root", type=str, default=None)
    parser.add_argument("--datasets", nargs="*", default=None)

    parser.add_argument("--method", type=str, default="baseline",
                        choices=["baseline", "vcd", "opera", "halc", "damo", "agla"])
    parser.add_argument("--format", type=str, default="no_format",
                        choices=["no_format", "ow_format", "yn_format"])
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--max-new-tokens", type=int, default=2)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top-p", dest="top_p", type=float, default=1.0)
    parser.add_argument("--top-k", dest="top_k", type=int, default=None)
    parser.add_argument("--dtype", type=str, default="bf16", choices=["bf16", "fp16"])
    parser.add_argument("--gpu-id", type=int, default=0)
    parser.add_argument("--device-map", type=str, default=None,
                        help="e.g. 'auto' or 'balanced' to load model across multiple GPUs")
    parser.add_argument("--max-image-side", type=int, default=None,
                        help="If set, resize image so max(w,h) <= this (avoids OOM with device_map on large images)")
    # --opera-max-image-side removed: OPERA now pre-computes KV cache
    # from full-resolution image (no resize needed)
    parser.add_argument("--max-image-pixels", type=int, default=28*28*2048,
                        help="Processor max_pixels for smart_resize (default 1605632≈1268^2). "
                             "Increase for higher quality, decrease to avoid OOM.")
    parser.add_argument("--rank", type=int, default=0,
                        help="Rank of this process (0..world_size-1) for multi-GPU data parallelism")
    parser.add_argument("--world-size", type=int, default=1,
                        help="Total number of processes for multi-GPU data parallelism")

    # DSCR
    parser.add_argument("--use-dscr", action="store_true", default=False)
    parser.add_argument("--dscr-alpha", type=float, default=0.8)
    parser.add_argument("--dscr-beta", type=float, default=0.6)
    parser.add_argument("--dscr-sigma", type=float, default=0.5)
    parser.add_argument("--dscr-keep-ratio", type=float, default=1.0)
    parser.add_argument("--dscr-lambda", type=float, default=1.0)
    parser.add_argument("--dscr-self-keep", type=float, default=1.0)
    parser.add_argument("--dscr-start-layer", type=int, default=0)
    parser.add_argument("--dscr-end-layer", type=int, default=None)
    parser.add_argument("--dscr-key-only", action="store_true", default=False)
    parser.add_argument("--dscr-value-only", action="store_true", default=False)
    parser.add_argument("--dscr-key-value", action="store_true", default=False)

    # VCD
    parser.add_argument("--noise-step", type=int, default=500)
    parser.add_argument("--cd-alpha", type=float, default=0.5)
    parser.add_argument("--cd-beta", type=float, default=0.1)

    # OPERA
    parser.add_argument("--beam", type=int, default=5)
    parser.add_argument("--scale-factor", type=float, default=50.0)
    parser.add_argument("--threshold", type=int, default=15)
    parser.add_argument("--num-attn-candidates", type=int, default=5)
    parser.add_argument("--penalty-weights", type=float, default=0.6)

    # HALC
    parser.add_argument("--halc-detector", type=str, default="dino", choices=["dino", "owlv2"])
    parser.add_argument("--halc-box-threshold", type=float, default=0.45)
    parser.add_argument("--halc-k-candidate-num", type=int, default=4)
    parser.add_argument("--halc-context-window", type=int, default=4)
    parser.add_argument("--halc-expand-ratio", type=float, default=0.6)
    parser.add_argument("--halc-context-domain", type=str, default="upper", choices=["upper", "lower"])
    parser.add_argument("--halc-contrast-weight", type=float, default=0.05)
    parser.add_argument(
        "--halc-score-type",
        type=str,
        default="BLIP",
        choices=["CLIP", "BLIP", "Random", "Perplexity", "HPSv2"],
    )
    parser.add_argument("--halc-debugger", type=int, default=0, choices=[0, 1, 2])
    parser.add_argument("--halc-mature-layer", type=int, default=None)
    parser.add_argument("--halc-base-layer", type=int, default=0)
    parser.add_argument("--halc-candidate-layers", type=int, nargs="*", default=None)
    parser.add_argument("--halc-relative-top", type=float, default=0.1)
    parser.add_argument("--halc-beam-search", action="store_true", default=False)
    parser.add_argument("--halc-num-beams", type=int, default=1)
    parser.add_argument("--halc-check-all-entities", action="store_true", default=False,
                        help="Bypass POS filter in HALC so hyperparameters affect yes/no tasks")
    # DAMO
    parser.add_argument("--tau", type=float, default=-0.3)
    parser.add_argument("--beta-1", dest="beta_1", type=float, default=0.05)
    parser.add_argument("--beta-2", dest="beta_2", type=float, default=0.20)
    parser.add_argument("--alpha", type=float, default=0.7)
    parser.add_argument("--damo-start-layer", dest="damo_start_layer", type=int, default=14,
                        help="DAMO momentum starts at this layer (default 14 for Qwen2.5VL/28 layers, 50%%)")

    # AGLA
    parser.add_argument("--agla-alpha", type=float, default=2.0)
    parser.add_argument("--agla-beta", type=float, default=0.5)
    parser.add_argument("--run-name", type=str, default=None)
    parser.add_argument("--debug-first-step", action="store_true", default=False,
                        help="Debug: print first step logits comparison for AGLA")

    args = parser.parse_args()
    set_seed(args.seed)

    if args.method in ("halc, opera"):
        args.format = "yn_format"

    if args.use_dscr and not args.depth_root:
        raise ValueError("--depth-root is required when --use-dscr is set")

    fmt = normalize_format(args.format)
    use_device_map = getattr(args, "device_map", None) and str(args.device_map).strip().lower() in ("auto", "balanced", "sequential")
    if use_device_map:
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            args.model_path,
            torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float16,
            device_map=args.device_map.strip().lower(),
        )
        device = next(model.parameters()).device
    else:
        device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            args.model_path,
            torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float16,
        ).to(device)
    model.eval()

    # Set processor max_pixels for image resolution control
    # OPERA now pre-computes KV cache before beam expansion, so no extra VRAM penalty.
    use_opera = args.method == "opera"
    effective_max_pixels = args.max_image_pixels
    processor = AutoProcessor.from_pretrained(
        args.model_path,
        max_pixels=effective_max_pixels,
    )
    print(f"[Config] processor max_pixels={effective_max_pixels} "
          f"(≈{int(effective_max_pixels**0.5)}x{int(effective_max_pixels**0.5)} square equiv)")
    image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)

    if args.dscr_end_layer is None:
        args.dscr_end_layer = getattr(model.config, "num_hidden_layers", 28)

    use_halc = args.method == "halc"
    use_vcd = args.method == "vcd"
    use_opera = args.method == "opera"
    use_damo = args.method == "damo"
    use_agla = args.method == "agla"

    add_diffusion_noise = None
    if use_vcd:
        from vcd_utils.vcd_add_noise import add_diffusion_noise
        from vcd_utils.vcd_sample import evolve_vcd_sampling_qwen25

        evolve_vcd_sampling_qwen25()  # Qwen2.5-VL uses transformers>=4.49 _sample signature

    if use_agla:
        from agla_utils.agla_sample import evolve_agla_sampling_qwen25
        from agla_utils.augmentation import augmentation
        from lavis.models import load_model_and_preprocess
        from torchvision import transforms

        evolve_agla_sampling_qwen25()
        blip_device = "cuda" if torch.cuda.is_available() else "cpu"
        blip_model, blip_vis_processors, blip_text_processors = load_model_and_preprocess(
            "blip_image_text_matching", "large", device=blip_device, is_eval=True
        )
        blip_loader = transforms.Compose([transforms.ToTensor()])

    halc = None
    if use_halc:
        from decoder_zoo.HALC.context_density.halc import halc_assistant

        halc_params = {
            "detector": args.halc_detector,
            "box_threshold": args.halc_box_threshold,
            "debugger": args.halc_debugger,
            "k_candidate_num": args.halc_k_candidate_num,
            "LVLM_backbone": "qwen2.5-vl",
            "score_type": args.halc_score_type,
            "context_window": args.halc_context_window,
            "expand_ratio": args.halc_expand_ratio,
            "context_domain": args.halc_context_domain,
            "contrast_weight": args.halc_contrast_weight,
            "tokenizer": processor.tokenizer,
            "max_image_pixels": effective_max_pixels,
            "check_all_entities": args.halc_check_all_entities,
        }

        def _vis_processor(img):
            # Use image_processor directly (not the full processor which requires text)
            return processor.image_processor(images=[img], return_tensors="pt")

        halc = halc_assistant(
            model=model,
            vis_processor=_vis_processor,
            device=device,
            halc_params=halc_params,
            max_new_tokens=args.max_new_tokens,
        )

    if args.datasets:
        datasets = args.datasets
    else:
        datasets = discover_sets(args.gt_dir)

    # Multi-GPU data parallelism: each process handles a subset of datasets
    world_size = max(1, getattr(args, "world_size", 1))
    rank = max(0, min(getattr(args, "rank", 0), world_size - 1))
    if world_size > 1:
        datasets = [d for i, d in enumerate(datasets) if i % world_size == rank]
        if not datasets:
            return
        print(f"[rank {rank}/{world_size}] processing {len(datasets)} dataset(s): {datasets}")

    run_name = args.run_name or f"qwen25_{args.method}_seed{args.seed}"
    os.makedirs(args.out_root, exist_ok=True)

    debug_printed = False

    def log_once(msg: str) -> None:
        nonlocal debug_printed
        if not debug_printed:
            print(msg)
            debug_printed = True

    for dataset in datasets:
        qfile_json = os.path.join(args.gt_dir, f"{dataset}.json")
        qfile_jsonl = os.path.join(args.gt_dir, f"{dataset}.jsonl")
        qfile = qfile_json if os.path.isfile(qfile_json) else qfile_jsonl
        if not os.path.isfile(qfile):
            print(f"[skip] missing questions for {dataset}")
            continue

        questions = load_questions(qfile)
        img_dir = pick_img_dir(args.image_root, dataset)
        depth_dir = pick_depth_dir(args.depth_root, dataset) if args.use_dscr else None
        answers_file = os.path.join(args.out_root, f"{run_name}_{dataset}.jsonl")

        with open(answers_file, "w", encoding="utf-8") as ans_file:
            for q in tqdm(questions, desc=f"{dataset}"):
                qid = q.get("question_id", q.get("id"))
                image_file = q.get("image", "")
                question = q.get("text", q.get("question", ""))
                cur_prompt = question

                image_path = os.path.join(img_dir, image_file)
                image = Image.open(image_path).convert("RGB")
                
                # Pre-resize safety net: cap max dimension so PIL doesn't load
                # huge images into CPU RAM. The processor's smart_resize handles
                # the actual resolution constraint via max_pixels.
                # Set safety cap ~20% above what smart_resize would produce.
                MAX_IMAGE_DIM = int(effective_max_pixels ** 0.5 * 1.2)
                w, h = image.size
                max_dim = max(w, h)
                if max_dim > MAX_IMAGE_DIM:
                    scale = MAX_IMAGE_DIM / max_dim
                    nw, nh = int(w * scale), int(h * scale)
                    image = image.resize((nw, nh), Image.Resampling.LANCZOS)
                    print(f"[Resize] {image_file}: {w}x{h} -> {nw}x{nh} (pre-resize safety)")
                
                max_side = getattr(args, "max_image_side", None)
                if max_side is not None and max_side > 0:
                    w, h = image.size
                    m = max(w, h)
                    if m > max_side:
                        scale = max_side / m
                        nw, nh = int(round(w * scale)), int(round(h * scale))
                        image = image.resize((nw, nh), Image.Resampling.LANCZOS)

                messages = [
                    {
                        "role": "user",
                        "content": [
                            {"type": "image"},
                            {"type": "text", "text": question + fmt},
                        ],
                    }
                ]
                text = processor.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                )

                inputs = processor(text=[text], images=[image], return_tensors="pt")
                input_ids = inputs["input_ids"].to(device)
                attention_mask = inputs["attention_mask"].to(device)
                pixel_values = inputs["pixel_values"].to(device, dtype=model.dtype)
                image_grid_thw = inputs["image_grid_thw"].to(device)
                image_positions = torch.nonzero(input_ids[0] == image_token_id, as_tuple=False).squeeze(-1)
                image_start_idx = int(image_positions[0].item()) if image_positions.numel() > 0 else -1
                image_token_len = int(image_positions.numel())

                cache_clean = None
                if args.use_dscr:
                    # Build D matrix (copied from llava4, adapted for variable image token length)
                    import torch.nn.functional as F
                    import math
                    depth_path = os.path.join(depth_dir, os.path.splitext(image_file)[0] + ".npy")
                    depth = np.load(depth_path, allow_pickle=True)
                    depth_tensor = torch.tensor(depth, dtype=torch.float32).unsqueeze(0).unsqueeze(1)
                    
                    grid_size = int(math.sqrt(image_token_len))
                    if grid_size * grid_size == image_token_len:
                        depth_patch = F.interpolate(depth_tensor, size=(grid_size, grid_size), mode="bilinear", align_corners=False)
                    else:
                        depth_patch = depth_tensor.reshape(depth_tensor.shape[0], depth_tensor.shape[1], -1)
                        depth_patch = F.interpolate(depth_patch, size=image_token_len, mode="linear", align_corners=False)
                        depth_patch = depth_patch.view(1, 1, -1, 1)
                    
                    depth_patch = torch.clamp(1.0e-6 + 1.0 / depth_patch, 0.001, 1000).view(-1)
                    
                    dmin = depth_patch.min()
                    dmax = depth_patch.max()
                    depth_patch = (depth_patch - dmin) / (dmax - dmin + 1e-6)
                    
                    depth_diff = torch.abs(depth_patch.unsqueeze(0) - depth_patch.unsqueeze(1))
                    gaussian_weight_depth = torch.exp(-(depth_diff ** 2) / (2 * (args.dscr_sigma ** 2) + 1e-12))
                    
                    if grid_size * grid_size == image_token_len:
                        pixel_positions = torch.tensor([[i // grid_size, i % grid_size] for i in range(image_token_len)], dtype=torch.float32)
                        pixel_positions = pixel_positions / float(grid_size - 1)
                    else:
                        pixel_positions = torch.linspace(0, 1, image_token_len, dtype=torch.float32).unsqueeze(1)
                    position_diff = torch.cdist(pixel_positions, pixel_positions, p=2)
                    gaussian_weight_position = torch.exp(-(position_diff ** 2) / (2 * (args.dscr_sigma ** 2) + 1e-12))
                    
                    gaussian_weight = (gaussian_weight_depth ** args.dscr_alpha) + (gaussian_weight_position ** args.dscr_beta)
                    
                    # Apply keep_ratio (self_keep) if needed
                    keep_ratio = getattr(args, "dscr_keep_ratio", None) or getattr(args, "dscr_self_keep", 1.0)
                    if keep_ratio is not None and keep_ratio < 1.0:
                        eye = torch.eye(image_token_len, dtype=gaussian_weight.dtype, device=gaussian_weight.device)
                        gaussian_weight = gaussian_weight * (1.0 - eye) + keep_ratio * eye
                    
                    D = gaussian_weight / (gaussian_weight.sum(dim=-1, keepdim=True) + 1e-6)
                    
                    # Forward to get initial cache
                    with torch.inference_mode():
                        kv_out = model(
                            input_ids=input_ids,
                            attention_mask=attention_mask,
                            pixel_values=pixel_values,
                            image_grid_thw=image_grid_thw,
                            use_cache=True,
                            return_dict=True,
                        )
                    
                    # Apply D to cache
                    dscr_key_value = args.dscr_key_value or (
                        not args.dscr_key_only and not args.dscr_value_only
                    )
                    # OPERA+DSCR: skip last layer so OPERA sees original attention
                    num_hidden_layers = getattr(model.config, "num_hidden_layers", 28)
                    dscr_end_eff = int(args.dscr_end_layer) if args.dscr_end_layer is not None else num_hidden_layers
                    if use_opera:
                        dscr_end_eff = min(dscr_end_eff, num_hidden_layers - 1)
                    cache_clean = apply_D_to_cache(
                        cache=kv_out.past_key_values,
                        D=D,
                        cut_idx=image_start_idx,
                        start_layer_idx=int(args.dscr_start_layer),
                        end_layer_idx=dscr_end_eff,
                        key_only=bool(args.dscr_key_only),
                        value_only=bool(args.dscr_value_only),
                        key_value=bool(dscr_key_value),
                        num_image_tokens=image_token_len,
                        dscr_lambda=args.dscr_lambda,
                    )
                    
                    # Recompute first logits
                    cache_clean, first_logits_clean = recompute_first_logits_from_refined_cache(
                        model=model,
                        input_ids=input_ids,
                        cache=cache_clean,
                        attention_mask=attention_mask,
                    )

                gen_dscr_kwargs = {}
                if cache_clean is not None:
                    gen_dscr_kwargs = {"past_key_values": cache_clean, "use_cache": True}

                # AGLA augmentation (BLIP GradCAM) must run outside inference_mode so hooks can be registered
                images_cd_agla = None
                image_grid_thw_cd_agla = None
                input_ids_cd_agla = None
                if use_agla:
                    tensor_image = blip_loader(image.resize((384, 384)))
                    blip_image = blip_vis_processors["eval"](image).unsqueeze(0).to(blip_device)
                    blip_question = blip_text_processors["eval"](question + fmt)
                    tokenized_text = blip_model.tokenizer(
                        blip_question, padding="longest", truncation=True, return_tensors="pt"
                    ).to(blip_device)
                    augmented_image = augmentation(
                        blip_image, blip_question, tensor_image, blip_model, tokenized_text, image
                    )
                    # Qwen2.5-VL dynamic resolution: resize augmented image to match
                    # the original image size so the processor produces the same
                    # image_grid_thw (and hence the same number of image tokens).
                    # Without this, different token counts between clean and augmented
                    # break the generation loop's cache_position tracking after prefill.
                    augmented_image = augmented_image.resize(image.size)
                    aug_inputs = processor(text=[text], images=[augmented_image], return_tensors="pt")
                    images_cd_agla = aug_inputs["pixel_values"].to(device, dtype=model.dtype)
                    image_grid_thw_cd_agla = aug_inputs["image_grid_thw"].to(device)
                    input_ids_cd_agla = aug_inputs["input_ids"].to(device)

                dscr_note = ""
                if args.use_dscr:
                    dscr_key_value = args.dscr_key_value or (
                        not args.dscr_key_only and not args.dscr_value_only
                    )
                    keep_ratio = getattr(args, "dscr_keep_ratio", 1.0)
                    dscr_note = (
                        f" | DSCR img_start={image_start_idx} img_len={image_token_len} "
                        f"a={args.dscr_alpha} b={args.dscr_beta} s={args.dscr_sigma} "
                        f"k={keep_ratio} l={args.dscr_lambda} "
                        f"layers={args.dscr_start_layer}-{args.dscr_end_layer} "
                        f"key_only={args.dscr_key_only} value_only={args.dscr_value_only} "
                        f"key_value={dscr_key_value}"
                    )

                with torch.inference_mode():
                    if use_opera:
                        positions = image_positions
                        if positions.numel() == 0:
                            raise ValueError("No image tokens found for OPERA key_position.")
                        # response_start = end of prompt; in the full-attention matrix
                        # this is where generated tokens start (matching LLaVA).
                        key_position = {
                            "image_start": int(positions[0].item()),
                            "image_end": int(positions[-1].item()),
                            "response_start": input_ids.shape[1],
                        }
                        # OPERA: keep original settings even with DSCR
                        do_sample_opera = False
                        temp_opera = float(args.temperature)
                        log_once(
                            f"[OPERA] do_sample={do_sample_opera} temperature={temp_opera} "
                            f"key_position={key_position} beam={args.beam} "
                            f"scale_factor={args.scale_factor} threshold={args.threshold} "
                            f"num_attn_candidates={args.num_attn_candidates} penalty_weights={args.penalty_weights}"
                            f"{dscr_note}"
                        )
                        input_len = input_ids.shape[1]
                        yn_prefix_fn = (
                            build_yn_prefix_allowed_tokens_fn(processor.tokenizer, input_len)
                            if args.format == "yn_format"
                            else None
                        )

                        from transformers.cache_utils import DynamicCache as _OPERA_DC
                        if cache_clean is not None:
                            # DSCR mode: use DSCR-refined cache
                            opera_cache = cache_clean
                        else:
                            # Non-DSCR mode: forward full image once to get KV cache
                            kv_opera = model(
                                input_ids=input_ids,
                                attention_mask=attention_mask,
                                pixel_values=pixel_values,
                                image_grid_thw=image_grid_thw,
                                use_cache=True,
                                return_dict=True,
                            )
                            opera_cache = kv_opera.past_key_values

                        trimmed = _OPERA_DC()
                        for _li in range(len(opera_cache)):
                            _k, _v = opera_cache[_li]
                            trimmed.key_cache.append(_k[:, :, :-1, :].contiguous())
                            trimmed.value_cache.append(_v[:, :, :-1, :].contiguous())
                        trimmed._seen_tokens = trimmed.key_cache[0].shape[2]
                        model._opera_dscr_cache = trimmed
                        output_ids = model.generate(
                            input_ids=input_ids,
                            pixel_values=None,
                            image_grid_thw=None,
                            do_sample=do_sample_opera,
                            temperature=temp_opera,
                            top_p=float(args.top_p),
                            top_k=args.top_k,
                            max_new_tokens=args.max_new_tokens,
                            prefix_allowed_tokens_fn=yn_prefix_fn,
                            key_position=key_position,
                            opera_decoding=True,
                            num_beams=args.beam,
                            output_attentions=True,
                            scale_factor=args.scale_factor,
                            threshold=args.threshold,
                            num_attn_candidates=args.num_attn_candidates,
                            penalty_weights=args.penalty_weights,
                            use_cache=True,
                        )
                        if hasattr(model, "_opera_dscr_cache"):
                            delattr(model, "_opera_dscr_cache")
                    elif use_halc:
                        mature_layer = (
                            args.halc_mature_layer
                            if args.halc_mature_layer is not None
                            else int(getattr(model.config, "num_hidden_layers", 32)) - 1
                        )
                        candidate_layers = args.halc_candidate_layers
                        base_layer = None if candidate_layers else args.halc_base_layer
                        # HALC: keep original settings even with DSCR
                        do_sample_halc = False
                        temp_halc = float(args.temperature)
                        log_once(
                            f"[HALC] do_sample={do_sample_halc} temperature={temp_halc} "
                            f"mature_layer={mature_layer} base_layer={base_layer} "
                            f"candidate_layers={candidate_layers} relative_top={args.halc_relative_top} "
                            f"detector={args.halc_detector} expand_ratio={args.halc_expand_ratio} "
                            f"context_window={args.halc_context_window} context_domain={args.halc_context_domain} "
                            f"contrast_weight={args.halc_contrast_weight} score_type={args.halc_score_type} "
                            f"box_threshold={args.halc_box_threshold} beam_search={args.halc_beam_search} "
                            f"num_beams={args.halc_num_beams} "
                            f"check_all_entities={args.halc_check_all_entities}"
                            f"{dscr_note}"
                        )

                        dscr_params_for_halc = None
                        depth_np_for_halc = None
                        if args.use_dscr:
                            depth_path_halc = os.path.join(depth_dir, os.path.splitext(image_file)[0] + ".npy")
                            depth_np_for_halc = np.load(depth_path_halc, allow_pickle=True)
                            dscr_key_value_halc = args.dscr_key_value or (
                                not args.dscr_key_only and not args.dscr_value_only
                            )
                            dscr_params_for_halc = {
                                "alpha": args.dscr_alpha,
                                "beta": args.dscr_beta,
                                "sigma": args.dscr_sigma,
                                "keep_ratio": args.dscr_keep_ratio,
                                "lambda": args.dscr_lambda,
                                "start_layer": args.dscr_start_layer,
                                "end_layer": args.dscr_end_layer,
                                "key_only": args.dscr_key_only,
                                "value_only": args.dscr_value_only,
                                "key_value": dscr_key_value_halc,
                            }
                        halc.update_input(
                            image_path, text,
                            depth_np=depth_np_for_halc,
                            dscr_params=dscr_params_for_halc,
                            original_grid_thw=image_grid_thw,
                        )
                        # Build yn_prefix_fn for HALC (same as OPERA)
                        input_len_halc = input_ids.shape[1]
                        yn_prefix_fn_halc = (
                            build_yn_prefix_allowed_tokens_fn(processor.tokenizer, input_len_halc)
                            if args.format == "yn_format"
                            else None
                        )
                        halc_dscr_kwargs = {}
                        if args.use_dscr and depth_np_for_halc is not None:
                            dscr_kv_halc = args.dscr_key_value or (
                                not args.dscr_key_only and not args.dscr_value_only
                            )
                            halc_dscr_kwargs = {
                                "dscr_depth": torch.tensor(
                                    depth_np_for_halc, dtype=torch.float32
                                ).to(device),
                                "dscr_alpha": args.dscr_alpha,
                                "dscr_beta": args.dscr_beta,
                                "dscr_sigma": args.dscr_sigma,
                                "dscr_start_layer": args.dscr_start_layer,
                                "dscr_end_layer": args.dscr_end_layer,
                                "dscr_lambda": args.dscr_lambda,
                                "dscr_self_keep": getattr(args, "dscr_keep_ratio", 1.0),
                                "dscr_key_only": args.dscr_key_only,
                                "dscr_value_only": args.dscr_value_only,
                                "dscr_key_value": dscr_kv_halc,
                            }
                        output_ids = model.generate(
                            input_ids=input_ids,
                            # attention_mask removed: Qwen2.5-VL expands sequence with image tokens
                            pixel_values=pixel_values,
                            image_grid_thw=image_grid_thw,
                            do_sample=do_sample_halc,
                            temperature=temp_halc,
                            top_p=float(args.top_p),
                            top_k=args.top_k,
                            max_new_tokens=args.max_new_tokens,
                            prefix_allowed_tokens_fn=yn_prefix_fn_halc,
                            halc_decoding=True,
                            dola_decoding=False,
                            beam_search=bool(args.halc_beam_search),
                            num_beams=args.halc_num_beams,
                            halc_assistant=halc,
                            mature_layer=mature_layer,
                            base_layer=base_layer,
                            candidate_premature_layers=candidate_layers,
                            relative_top=args.halc_relative_top,
                            output_attentions=True,
                            output_hidden_states=True,
                            **halc_dscr_kwargs,
                        )
                    elif use_vcd:
                        images_cd = add_diffusion_noise(pixel_values.cpu(), int(args.noise_step)).to(
                            device, dtype=model.dtype
                        )

                        cache_clean_cd = None
                        if args.use_dscr and cache_clean is not None:
                            with torch.inference_mode():
                                kv_out_cd = model(
                                    input_ids=input_ids,
                                    attention_mask=attention_mask,
                                    pixel_values=images_cd,  # noised image
                                    image_grid_thw=image_grid_thw,
                                    use_cache=True,
                                    return_dict=True,
                                )
                            dscr_key_value = args.dscr_key_value or (
                                not args.dscr_key_only and not args.dscr_value_only
                            )
                            _num_layers = getattr(model.config, "num_hidden_layers", 28)
                            _vcd_end = int(args.dscr_end_layer) if args.dscr_end_layer is not None else _num_layers
                            cache_clean_cd = apply_D_to_cache(
                                cache=kv_out_cd.past_key_values,
                                D=D,  # same D from clean image
                                cut_idx=image_start_idx,
                                start_layer_idx=int(args.dscr_start_layer),
                                end_layer_idx=_vcd_end,
                                key_only=bool(args.dscr_key_only),
                                value_only=bool(args.dscr_value_only),
                                key_value=bool(dscr_key_value),
                                num_image_tokens=image_token_len,
                                dscr_lambda=args.dscr_lambda,
                            )
                            cache_clean_cd, _ = recompute_first_logits_from_refined_cache(
                                model=model,
                                input_ids=input_ids,
                                cache=cache_clean_cd,
                                attention_mask=attention_mask,
                            )
                        
                        # VCD: keep original settings even with DSCR
                        do_sample_vcd = True
                        temp_vcd = float(args.temperature)
                        log_once(
                            f"[VCD] do_sample={do_sample_vcd} temperature={temp_vcd} "
                            f"noise_step={args.noise_step} cd_alpha={args.cd_alpha} "
                            f"cd_beta={args.cd_beta} images_cd_mean={images_cd.mean().item():.4f} "
                            f"std={images_cd.std().item():.4f}{dscr_note}"
                        )
                        
                        # Build gen_vcd_kwargs - pass CD cache via model attribute (bypasses validation)
                        gen_vcd_kwargs = {}
                        if cache_clean is not None:
                            gen_vcd_kwargs["past_key_values"] = cache_clean
                            gen_vcd_kwargs["use_cache"] = True
                        if cache_clean_cd is not None:
                            model._dscr_past_key_values_cd = cache_clean_cd
                        
                        # DSCR cache already contains image embeddings - don't pass pixel_values
                        vcd_pixel_values = None if cache_clean is not None else pixel_values
                        vcd_image_grid_thw = None if cache_clean is not None else image_grid_thw
                        vcd_images_cd = None if cache_clean_cd is not None else images_cd
                        vcd_image_grid_thw_cd = None if cache_clean_cd is not None else image_grid_thw
                        
                        try:
                            output_ids = model.generate(
                                input_ids=input_ids,
                                # attention_mask removed: Qwen2.5-VL expands sequence with image tokens
                                pixel_values=vcd_pixel_values,
                                image_grid_thw=vcd_image_grid_thw,
                                images_cd=vcd_images_cd,
                                image_grid_thw_cd=vcd_image_grid_thw_cd,
                                cd_alpha=args.cd_alpha,
                                cd_beta=args.cd_beta,
                                do_sample=do_sample_vcd,
                                temperature=temp_vcd,
                                top_p=float(args.top_p),
                                top_k=args.top_k,
                                max_new_tokens=args.max_new_tokens,
                                **gen_vcd_kwargs,
                            )
                        finally:
                            # Clean up CD cache attribute
                            if hasattr(model, "_dscr_past_key_values_cd"):
                                delattr(model, "_dscr_past_key_values_cd")
                    elif use_agla:
                        images_cd = images_cd_agla
                        image_grid_thw_cd = image_grid_thw_cd_agla
                        # AGLA+DSCR: keep augmented cache raw (no DSCR) to complement DSCR-smoothed clean.
                        cache_augmented_for_agla = None
                        # Use actual post-merge token count from input_ids_cd_agla
                        # (image_grid_thw_cd.prod() gives PRE-merge count which is
                        # merge_size^2 times larger in Qwen2.5-VL).
                        if input_ids_cd_agla is not None:
                            agla_image_token_len = int((input_ids_cd_agla[0] == image_token_id).sum().item())
                        elif image_grid_thw_cd is not None:
                            agla_image_token_len = int(image_grid_thw_cd.prod(dim=1).sum().item())
                        else:
                            agla_image_token_len = 0
                        if args.use_dscr and cache_clean is not None and agla_image_token_len == image_token_len:
                            # Compute augmented cache but do NOT apply DSCR (keep raw).
                            with torch.inference_mode():
                                kv_out_cd = model(
                                    input_ids=input_ids,
                                    attention_mask=attention_mask,
                                    pixel_values=images_cd,  # augmented image
                                    image_grid_thw=image_grid_thw_cd,
                                    use_cache=True,
                                    return_dict=True,
                                )
                            cache_augmented_for_agla = kv_out_cd.past_key_values
                            # Recompute first logits from raw augmented cache
                            cache_augmented_for_agla, _ = recompute_first_logits_from_refined_cache(
                                model=model,
                                input_ids=input_ids,
                                cache=cache_augmented_for_agla,
                                attention_mask=attention_mask,
                            )
                        elif args.use_dscr and cache_clean is not None and agla_image_token_len != image_token_len:
                            log_once(
                                f"[AGLA+DSCR] Skipping augmented cache: "
                                f"image_token_len mismatch ({agla_image_token_len} vs {image_token_len})"
                            )
                        
                        # AGLA: keep original settings even with DSCR
                        do_sample_agla = True
                        temp_agla = float(args.temperature)
                        log_once(
                            f"[AGLA] do_sample={do_sample_agla} temperature={temp_agla} "
                            f"cd_alpha={args.agla_alpha} cd_beta={args.agla_beta} "
                            f"aug_shape={tuple(images_cd.shape)}{dscr_note}"
                        )
                        
                        # Build gen_agla_kwargs
                        # - clean cache (DSCR applied) goes as past_key_values
                        # - augmented cache (RAW, no DSCR) goes via model attribute
                        gen_agla_kwargs = {}
                        if cache_clean is not None:
                            gen_agla_kwargs["past_key_values"] = cache_clean
                            gen_agla_kwargs["use_cache"] = True
                        if cache_augmented_for_agla is not None:
                            model._agla_cache_augmented = cache_augmented_for_agla
                        
                        # DSCR cache already contains image embeddings - don't pass pixel_values
                        agla_pixel_values = None if cache_clean is not None else pixel_values
                        agla_image_grid_thw = None if cache_clean is not None else image_grid_thw
                        agla_images_cd = None if cache_augmented_for_agla is not None else images_cd
                        agla_image_grid_thw_cd = None if cache_augmented_for_agla is not None else image_grid_thw_cd
                        
                        if args.debug_first_step:
                            model._agla_debug_first_step = True
                            model._agla_debug_tokenizer = processor.tokenizer
                        # Pass augmented input_ids for the CD path (Qwen2.5-VL
                        # dynamic resolution means augmented image may have a
                        # different number of image tokens).
                        agla_input_ids_cd = None if cache_augmented_for_agla is not None else input_ids_cd_agla
                        try:
                            output_ids = model.generate(
                                input_ids=input_ids,
                                # attention_mask removed: Qwen2.5-VL expands sequence with image tokens
                                pixel_values=agla_pixel_values,
                                image_grid_thw=agla_image_grid_thw,
                                images_cd=agla_images_cd,
                                image_grid_thw_cd=agla_image_grid_thw_cd,
                                input_ids_cd=agla_input_ids_cd,
                                cd_alpha=args.agla_alpha,
                                cd_beta=args.agla_beta,
                                do_sample=do_sample_agla,
                                temperature=temp_agla,
                                top_p=float(args.top_p),
                                top_k=args.top_k,
                                max_new_tokens=args.max_new_tokens,
                                **gen_agla_kwargs,
                            )
                        finally:
                            if hasattr(model, "_agla_debug_first_step"):
                                delattr(model, "_agla_debug_first_step")
                            if hasattr(model, "_agla_debug_tokenizer"):
                                delattr(model, "_agla_debug_tokenizer")
                            # Clean up augmented cache attribute
                            if hasattr(model, "_agla_cache_augmented"):
                                delattr(model, "_agla_cache_augmented")
                    elif use_damo:
                        # DAMO: keep original settings even with DSCR
                        do_sample_damo = (args.temperature > 0)
                        temp_damo = float(args.temperature)
                        log_once(
                            f"[DAMO] do_sample={do_sample_damo} temperature={temp_damo} "
                            f"tau={args.tau} beta_1={args.beta_1} beta_2={args.beta_2} "
                            f"alpha={args.alpha} damo_start_layer={args.damo_start_layer}{dscr_note}"
                        )
                        # DSCR cache already contains image embeddings - don't pass pixel_values
                        damo_pixel_values = None if cache_clean is not None else pixel_values
                        damo_image_grid_thw = None if cache_clean is not None else image_grid_thw
                        output_ids = model.generate(
                            input_ids=input_ids,
                            # attention_mask removed: Qwen2.5-VL expands sequence with image tokens
                            pixel_values=damo_pixel_values,
                            image_grid_thw=damo_image_grid_thw,
                            do_sample=do_sample_damo,
                            temperature=temp_damo,
                            top_p=float(args.top_p),
                            top_k=args.top_k,
                            max_new_tokens=args.max_new_tokens,
                            output_hidden_states=True,
                            tau=args.tau,
                            beta_1=args.beta_1,
                            beta_2=args.beta_2,
                            alpha=args.alpha,
                            damo_start_layer=args.damo_start_layer,
                            **gen_dscr_kwargs,
                        )
                    else:
                        # Baseline: always sampling. DSCR: always greedy.
                        if cache_clean is not None:
                            log_once(f"[BASELINE+DSCR] do_sample=False temperature=0.0{dscr_note}")
                            # DSCR cache already contains image embeddings - don't pass pixel_values
                            output_ids = model.generate(
                                input_ids=input_ids,
                                # attention_mask removed: Qwen2.5-VL expands sequence with image tokens
                                pixel_values=None,
                                image_grid_thw=None,
                                do_sample=False,
                                temperature=0.0,
                                top_p=1.0,
                                top_k=None,
                                max_new_tokens=args.max_new_tokens,
                                **gen_dscr_kwargs,
                            )
                        else:
                            temp_baseline = float(args.temperature)
                            do_sample_baseline = temp_baseline > 0.0
                            if not do_sample_baseline:
                                temp_baseline = 0.0
                            log_once(f"[BASELINE] do_sample={do_sample_baseline} temperature={temp_baseline}{dscr_note}")
                            output_ids = model.generate(
                                input_ids=input_ids,
                                # attention_mask removed: Qwen2.5-VL expands sequence with image tokens
                                pixel_values=pixel_values,
                                image_grid_thw=image_grid_thw,
                                do_sample=do_sample_baseline,
                                temperature=temp_baseline,
                                top_p=float(args.top_p),
                                top_k=args.top_k,
                                max_new_tokens=args.max_new_tokens,
                            )

                input_token_len = input_ids.shape[1]
                decode_ids = output_ids[0]
                if decode_ids.shape[0] > input_token_len:
                    decode_ids = decode_ids[input_token_len:]
                outputs_str = processor.decode(
                    decode_ids,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=False,
                ).strip()
                if args.format == "yn_format":
                    outputs_str = normalize_yn_answer(outputs_str)

                model_id = os.path.basename(args.model_path.rstrip("/")) or args.model_path
                ans_file.write(
                    json.dumps(
                        {
                            "question_id": qid,
                            "prompt": text,
                            "text": outputs_str,
                            "model_id": model_id,
                            "image": image_file,
                            "metadata": {},
                        }
                    )
                    + "\n"
                )
                ans_file.flush()


# Script entry point: run main() only when this file is executed directly
# (e.g. `python <this_file> ...`), not when it is imported as a module.
if __name__ == "__main__":
    main()
