import argparse
import json
import os
import re
import sys
from typing import Dict, List

import torch
import numpy as np
from PIL import Image
from tqdm import tqdm

# Resolve repository directories relative to this script: EXPERIMENTS_DIR is the
# parent of the directory containing this file, ROOT_DIR is one level above that.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
EXPERIMENTS_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Appended search paths, presumably so the project-local imports below
# (Qwen_VL, vcd_utils, agla_utils, ...) resolve when run as a script.
sys.path.append(ROOT_DIR)
sys.path.append(EXPERIMENTS_DIR)
sys.path.append(os.path.join(EXPERIMENTS_DIR, "halc_utils"))
# Inserted at index 0 so the vendored transformers-4.31.0 source tree takes
# precedence over any installed `transformers` package.
sys.path.insert(0, os.path.join(EXPERIMENTS_DIR, "transformers-4.31.0", "src"))

from transformers import AutoTokenizer, set_seed

from Qwen_VL.modeling_qwen import QWenLMHeadModel
from vcd_utils.vcd_add_noise import add_diffusion_noise
from vcd_utils.vcd_sample import evolve_vcd_sampling


def apply_D_to_cache(cache, D, cut_idx, start_layer_idx, end_layer_idx, key_only, value_only, key_value, num_image_tokens, dscr_lambda=1.0):
    """Apply a DSCR refinement matrix ``D`` to a token segment of a KV cache.

    For every layer in ``[start_layer_idx, end_layer_idx)`` the token segment
    ``[cut_idx, cut_idx + n_apply)`` of the key and/or value tensor is replaced by
    ``(1 - dscr_lambda) * seg + dscr_lambda * einsum(D_sub, seg)`` along the token
    axis, where ``D_sub = D[:n_apply, :n_apply]`` and ``n_apply`` is the smallest of
    D's row/column count, ``num_image_tokens`` and the number of tokens available
    after ``cut_idx``. D itself is never resized.

    Args:
        cache: Iterable of per-layer ``(key, value)`` tensor pairs. 4-D tensors are
            handled in either ``(B, T, H, D)`` ("bthd") or ``(B, H, T, D)`` ("bhtd")
            layout, inferred per layer from the key; layers whose key is not 4-D are
            passed through unchanged.
        D: Refinement matrix, 2-D ``(N, M)`` or 4-D ``(1, 1, N, M)``.
        cut_idx: Index of the first token of the segment to refine.
        start_layer_idx: First layer (inclusive) to modify.
        end_layer_idx: Last layer (exclusive) to modify; ``None`` means every layer
            from ``start_layer_idx`` onward.
        key_only, value_only, key_value: Keys are refined when ``key_only or
            key_value``; values when ``value_only or key_value``.
        num_image_tokens: Upper bound on the refined segment length.
        dscr_lambda: Mixing weight, clamped to ``[0, 1]``; 0 = no DSCR, 1 = full DSCR.

    Returns:
        A tuple with one ``(key, value)`` pair per input layer.

    Raises:
        ValueError: If ``D`` is not 2-D after squeezing the two leading dims of a 4-D input.
    """
    if D.dim() == 4:
        D = D.squeeze(0).squeeze(0)
    if D.dim() != 2:
        raise ValueError(f"Unsupported D shape: {tuple(D.shape)}")

    start_layer_idx = int(start_layer_idx)
    # None = no upper bound (main's --dscr-end-layer defaults to None).
    end_layer_idx = None if end_layer_idx is None else int(end_layer_idx)
    cut_idx = int(cut_idx)
    num_image_tokens = int(num_image_tokens)
    dscr_lambda = float(max(0.0, min(1.0, dscr_lambda)))
    # D is never resized, so it can refine at most this many tokens
    # (256 for the 16x16-grid matrix built in main).
    d_limit = min(int(D.shape[0]), int(D.shape[1]))
    refine_keys = bool(key_only or key_value)
    refine_values = bool(value_only or key_value)

    def _infer_kv_layout_qwen(k: torch.Tensor) -> str:
        """
        Qwen-VL past_key_values layout can be either:
          - bthd: (B, T, H, D)
          - bhtd: (B, H, T, D)
        We infer by comparing dim-1 and dim-2 sizes (T is typically much larger than H).
        """
        if k.dim() != 4:
            return "unknown"
        return "bthd" if k.size(1) > k.size(2) else "bhtd"

    def _refine_segment(t: torch.Tensor, layout: str, n_apply: int) -> torch.Tensor:
        """Return ``t`` with its token segment mixed with ``D[:n_apply, :n_apply]`` applied to it."""
        D_sub = D[:n_apply, :n_apply].to(device=t.device, dtype=t.dtype)
        seg_end = cut_idx + n_apply
        if layout == "bthd":
            seg = t[:, cut_idx:seg_end, :, :]
            seg_refined = torch.einsum("ij,bjhd->bihd", D_sub, seg)
            seg_mixed = (1.0 - dscr_lambda) * seg + dscr_lambda * seg_refined
            return torch.cat((t[:, :cut_idx, :, :], seg_mixed, t[:, seg_end:, :, :]), dim=1)
        seg = t[:, :, cut_idx:seg_end, :]
        seg_refined = torch.einsum("ij,bhjd->bhid", D_sub, seg)
        seg_mixed = (1.0 - dscr_lambda) * seg + dscr_lambda * seg_refined
        return torch.cat((t[:, :, :cut_idx, :], seg_mixed, t[:, :, seg_end:, :]), dim=2)

    out = []
    for layer_idx, (k, v) in enumerate(cache):
        in_range = layer_idx >= start_layer_idx and (end_layer_idx is None or layer_idx < end_layer_idx)
        # Layout is inferred from the key and assumed to hold for the value as well.
        layout = _infer_kv_layout_qwen(k) if in_range else "unknown"
        if layout == "unknown":
            # Layer outside the requested range, or not a 4-D tensor: keep as-is.
            out.append((k, v))
            continue

        t_len = int(k.shape[1] if layout == "bthd" else k.shape[2])
        n_apply = min(d_limit, num_image_tokens, t_len - cut_idx)
        if n_apply > 0:
            if refine_keys:
                k = _refine_segment(k, layout, n_apply)
            if refine_values:
                v = _refine_segment(v, layout, n_apply)
        out.append((k, v))

    return tuple(out)


def infer_pkv_layout(k: torch.Tensor, prompt_len: int) -> str:
    """Guess the axis order of a 4-D past-key/value tensor.

    Axis 1 is taken as the token axis if its size equals ``prompt_len``; failing
    that, axis 2 is taken if it matches. If neither matches, the larger of the two
    axes is assumed to be the token axis (ties go to axis 2).

    Returns:
        "bthd" for (B, T, H, D), "bhtd" for (B, H, T, D), or "unknown" when
        ``k`` is not 4-D.
    """
    if k.dim() != 4:
        return "unknown"
    axis1, axis2 = k.size(1), k.size(2)
    token_axis_is_1 = axis1 == prompt_len or (axis2 != prompt_len and axis1 > axis2)
    return "bthd" if token_axis_is_1 else "bhtd"


def trim_cache_last_token(past_key_values, prompt_len: int):
    """Drop the final token position from every layer of a KV cache.

    Each layer's layout is detected with ``infer_pkv_layout``. Layers of unknown
    layout are kept unchanged. If any recognised layer holds one token or fewer,
    nothing is trimmed and the original ``past_key_values`` object is returned.
    Trimmed tensors are made contiguous.

    Returns:
        ``None`` for a ``None`` cache; otherwise a tuple of ``(key, value)`` pairs,
        or the untouched input when trimming is impossible.
    """
    if past_key_values is None:
        return None

    result = []
    for key, value in past_key_values:
        layout = infer_pkv_layout(key, prompt_len)
        if layout not in ("bthd", "bhtd"):
            result.append((key, value))
            continue
        token_axis = 1 if layout == "bthd" else 2
        if key.size(token_axis) <= 1:
            return past_key_values
        # Index tuple equivalent to t[:, :-1] (bthd) or t[:, :, :-1] (bhtd).
        drop_last = (slice(None),) * token_axis + (slice(None, -1),)
        result.append((key[drop_last].contiguous(), value[drop_last].contiguous()))
    return tuple(result)


def recompute_first_logits_from_refined_cache(model, input_ids, cache, attention_mask=None):
    """Re-run the last prompt token against a (refined) KV cache.

    The final prompt position is trimmed from ``cache`` and ``input_ids[:, -1:]`` is
    fed through ``model`` (with ``images=None``) using the trimmed cache as
    ``past_key_values``. The returned cache and logits therefore reflect any edits
    made to the cached prefix.

    Args:
        model: Qwen-VL LM head model.
        input_ids: Prompt token ids; ``input_ids.shape[1]`` is the prompt length.
        cache: Per-layer ``(key, value)`` pairs covering the full prompt.
        attention_mask: Optional mask over the full prompt. When the cache was
            trimmed it is forwarded whole, because with ``past_key_values`` the mask
            must cover past + current tokens (the convention generation also uses).

    Returns:
        ``(past_key_values, logits)`` where ``logits`` is ``outputs.logits[:, -1, :]``.
    """
    prompt_len = input_ids.shape[1]
    prefix_cache = trim_cache_last_token(cache, prompt_len)

    last_token_ids = input_ids[:, -1:]
    step_mask = attention_mask
    if attention_mask is not None and (prefix_cache is None or prefix_cache is cache):
        # Nothing was trimmed (no cache, or a cache too short to trim), so the past
        # length is not prompt_len - 1; keep the single-column mask for this case.
        step_mask = attention_mask[:, -1:]

    with torch.inference_mode():
        outputs_last = model(
            input_ids=last_token_ids,
            attention_mask=step_mask,
            images=None,
            use_cache=True,
            past_key_values=prefix_cache,
            return_dict=True,
        )
    return outputs_last.past_key_values, outputs_last.logits[:, -1, :]


def load_questions(path: str) -> List[Dict]:
    """Load question records from a JSON-array file or a JSON-Lines file.

    The format is detected from the first non-whitespace character: ``[`` means a
    single JSON array, anything else is parsed as one JSON object per non-blank
    line. A leading UTF-8 BOM is stripped (``utf-8-sig``). Previously, leading
    whitespace or a BOM made a JSON array be mis-parsed as JSON Lines.

    Raises:
        json.JSONDecodeError: If the content is not valid JSON / JSON Lines.
    """
    with open(path, "r", encoding="utf-8-sig") as f:
        content = f.read()
    stripped = content.lstrip()
    if stripped.startswith("["):
        return json.loads(stripped)
    # Split only on "\n" (text mode already normalised line endings); str.splitlines
    # would also split on characters such as U+2028 that may appear inside JSON strings.
    return [json.loads(line) for line in content.split("\n") if line.strip()]


def discover_sets(gt_dir: str) -> List[str]:
    """List dataset names found in ``gt_dir``.

    Every entry ending in ``.json`` or ``.jsonl`` contributes its name with the last
    extension removed; duplicates (e.g. ``x.json`` and ``x.jsonl``) collapse to one.
    The names are returned sorted.
    """
    names = {
        entry.rsplit(".", 1)[0]
        for entry in os.listdir(gt_dir)
        if entry.endswith((".json", ".jsonl"))
    }
    return sorted(names)


def pick_img_dir(img_root: str, dataset: str) -> str:
    """Resolve the image directory for ``dataset`` under ``img_root``.

    Candidates are tried in order: ``dataset`` as given, its lowercase form, and,
    for the posters subset only, ``posters`` then ``Posters``. The first candidate
    that is an existing directory is returned; otherwise ``img_root`` itself.
    """
    lowered = dataset.lower()
    candidates = [dataset, lowered]
    if lowered == "posters":
        candidates.extend(["posters", "Posters"])
    for name in candidates:
        path = os.path.join(img_root, name)
        if os.path.isdir(path):
            return path
    return img_root


def pick_depth_dir(depth_root: str, dataset: str) -> str:
    """Resolve the depth-map directory for ``dataset`` under ``depth_root``.

    Uses exactly the same lookup rules as ``pick_img_dir``.
    """
    resolved = pick_img_dir(depth_root, dataset)
    return resolved


def normalize_format(fmt: str) -> str:
    """Return the instruction suffix appended to a question for answer format ``fmt``.

    ``"yn_format"`` and ``"ow_format"`` map to fixed instructions; any other value
    (e.g. ``"no_format"``) yields an empty string.
    """
    suffixes = {
        "yn_format": " Please only answer yes or no.",
        "ow_format": " Please answer this question with one word.",
    }
    return suffixes.get(fmt, "")


def normalize_yn_answer(text: str) -> str:
    """Reduce a free-form reply to ``"yes"`` or ``"no"`` when possible.

    The reply is stripped and lowercased, and the first standalone word ``yes`` or
    ``no`` is returned. If neither occurs, the stripped, lowercased reply is returned.
    """
    lowered = text.strip().lower()
    hit = re.search(r"\b(yes|no)\b", lowered)
    if hit is None:
        return lowered
    return hit.group(1)


def build_yn_prefix_allowed_tokens_fn(tokenizer, input_len: int):
    """Build a ``prefix_allowed_tokens_fn`` that constrains generation to yes/no.

    At the first generated position (current length == ``input_len``), only the
    first token of each yes/no spelling variant is allowed. Afterwards only the stop
    token is allowed: ``tokenizer.eos_token_id``, falling back to ``tokenizer.eod_id``.
    If neither exists, the yes/no tokens remain allowed at every step.

    Returns:
        The callable ``(batch_id, input_ids) -> list[int]``, or ``None`` if no variant
        encodes to at least one token.
    """
    variants = ("yes", "no", "Yes", "No", " yes", " no")
    encoded = (tokenizer.encode(variant, add_special_tokens=False) for variant in variants)
    first_ids = {ids[0] for ids in encoded if ids}
    if not first_ids:
        return None

    stop_id = getattr(tokenizer, "eos_token_id", None)
    if stop_id is None:
        stop_id = getattr(tokenizer, "eod_id", None)

    def _fn(batch_id, input_ids):
        seq_len = input_ids.shape[-1] if hasattr(input_ids, "shape") else len(input_ids)
        if seq_len != input_len and stop_id is not None:
            return [stop_id]
        return list(first_ids)

    return _fn


def main() -> None:
    parser = argparse.ArgumentParser("MME (Qwen-VL)")
    parser.add_argument("--model-path", type=str, default="Qwen/Qwen-VL")
    parser.add_argument("--gt-dir", type=str, required=True)
    parser.add_argument("--image-root", type=str, required=True)
    parser.add_argument("--out-root", type=str, required=True)
    parser.add_argument("--depth-root", type=str, default=None)
    parser.add_argument("--datasets", nargs="*", default=None)

    parser.add_argument("--method", type=str, default="baseline",
                        choices=["baseline", "vcd", "opera", "halc", "damo", "agla"])
    parser.add_argument("--format", type=str, default="no_format",
                        choices=["no_format", "ow_format", "yn_format"])
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--max-new-tokens", type=int, default=2)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top-p", dest="top_p", type=float, default=1.0)
    parser.add_argument("--top-k", dest="top_k", type=int, default=None)

    # VCD
    parser.add_argument("--noise-step", type=int, default=500)
    parser.add_argument("--cd-alpha", type=float, default=0.5)
    parser.add_argument("--cd-beta", type=float, default=0.1)

    # OPERA
    parser.add_argument("--beam", type=int, default=5)
    parser.add_argument("--scale-factor", type=float, default=50.0)
    parser.add_argument("--threshold", type=int, default=15)
    parser.add_argument("--num-attn-candidates", type=int, default=5)
    parser.add_argument("--penalty-weights", type=float, default=0.6)
    parser.add_argument("--opera-attn-layer", type=int, default=-1)

    # HALC
    parser.add_argument("--halc-detector", type=str, default="dino", choices=["dino", "owlv2"])
    parser.add_argument("--halc-box-threshold", type=float, default=0.45)
    parser.add_argument("--halc-k-candidate-num", type=int, default=4)
    parser.add_argument("--halc-context-window", type=int, default=4)
    parser.add_argument("--halc-expand-ratio", type=float, default=0.6)
    parser.add_argument("--halc-context-domain", type=str, default="upper", choices=["upper", "lower"])
    parser.add_argument("--halc-contrast-weight", type=float, default=0.05)
    parser.add_argument("--halc-score-type", type=str, default="BLIP",
                        choices=["CLIP", "BLIP", "Random", "Perplexity", "HPSv2"])
    parser.add_argument("--halc-debugger", type=int, default=0, choices=[0, 1, 2])
    parser.add_argument("--halc-mature-layer", type=int, default=None)
    parser.add_argument("--halc-base-layer", type=int, default=4)
    parser.add_argument("--halc-candidate-layers", type=int, nargs="*", default=None)
    parser.add_argument("--halc-relative-top", type=float, default=0.1)
    parser.add_argument("--halc-beam-search", action="store_true", default=False)
    parser.add_argument("--halc-num-beams", type=int, default=1)
    parser.add_argument("--halc-check-all-entities", action="store_true", default=False,
                        help="Bypass POS filter in HALC so hyperparameters affect yes/no tasks")
    # DAMO
    parser.add_argument("--tau", type=float, default=-0.3)
    parser.add_argument("--beta-1", dest="beta_1", type=float, default=0.05)
    parser.add_argument("--beta-2", dest="beta_2", type=float, default=0.20)
    parser.add_argument("--alpha", type=float, default=0.7)
    parser.add_argument("--damo-start-layer", dest="damo_start_layer", type=int, default=16,
                        help="DAMO momentum starts at this layer (default 16 for QwenVL/32 layers, 50%%)")

    # AGLA
    parser.add_argument("--agla-alpha", type=float, default=2.0)
    parser.add_argument("--agla-beta", type=float, default=0.5)

    # DSCR
    parser.add_argument("--use-dscr", action="store_true", default=False)
    parser.add_argument("--dscr-alpha", type=float, default=0.8)
    parser.add_argument("--dscr-beta", type=float, default=0.6)
    parser.add_argument("--dscr-sigma", type=float, default=0.5)
    parser.add_argument("--dscr-keep-ratio", type=float, default=1.0)
    parser.add_argument("--dscr-lambda", type=float, default=1.0)
    parser.add_argument("--dscr-start-layer", type=int, default=0)
    parser.add_argument("--dscr-end-layer", type=int, default=None)
    parser.add_argument("--dscr-key-only", action="store_true", default=False)
    parser.add_argument("--dscr-value-only", action="store_true", default=False)
    parser.add_argument("--dscr-key-value", action="store_true", default=False)

    parser.add_argument("--run-name", type=str, default=None)
    parser.add_argument(
        "--debug-span",
        action="store_true",
        default=False,
        help="Debug: print token/span info for the first few samples (to verify <img>/<imgpad> span).",
    )
    parser.add_argument(
        "--debug-span-limit",
        type=int,
        default=2,
        help="How many samples to print debug-span logs for (default: 2).",
    )

    args = parser.parse_args()
    set_seed(args.seed)

    # Force yn_format for halc and opera methods
    if args.method in ("halc, opera"):
        args.format = "yn_format"

    fmt = normalize_format(args.format)
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
    tokenizer.padding_side = "left"
    tokenizer.pad_token_id = tokenizer.eod_id
    model = QWenLMHeadModel.from_pretrained(
        args.model_path, device_map="auto", trust_remote_code=True
    ).eval()
    qwen_visual = getattr(model.transformer, "visual", None)
    qwen_image_tokens = None
    if qwen_visual is not None and hasattr(qwen_visual, "attn_pool"):
        qwen_image_tokens = getattr(qwen_visual.attn_pool, "num_queries", None)

    if args.method == "vcd":
        evolve_vcd_sampling()
    if args.method == "agla":
        from agla_utils.agla_sample import evolve_agla_sampling
        from agla_utils.augmentation import augmentation
        from lavis.models import load_model_and_preprocess
        from torchvision import transforms

        evolve_agla_sampling()
        blip_device = "cuda" if torch.cuda.is_available() else "cpu"
        blip_model, blip_vis_processors, blip_text_processors = load_model_and_preprocess(
            "blip_image_text_matching", "large", device=blip_device, is_eval=True
        )
        blip_loader = transforms.Compose([transforms.ToTensor()])

    if args.use_dscr and not args.depth_root:
        raise ValueError("--depth-root is required when --use-dscr is set")

    if args.datasets:
        datasets = args.datasets
    else:
        datasets = discover_sets(args.gt_dir)

    run_name = args.run_name or f"qwenvl_{args.method}_seed{args.seed}"
    os.makedirs(args.out_root, exist_ok=True)

    debug_printed = False
    debug_span_printed = 0

    def log_once(msg: str) -> None:
        nonlocal debug_printed
        if not debug_printed:
            print(msg)
            debug_printed = True

    def _find_positions_1d(ids_1d: torch.Tensor, token_id: int) -> List[int]:
        if token_id is None or int(token_id) < 0:
            return []
        pos = (ids_1d == int(token_id)).nonzero(as_tuple=True)[0]
        return [int(x) for x in pos.tolist()]

    def _find_first_run_start_end(sorted_pos: List[int]) -> List[List[int]]:
        """Return list of [start,end_exclusive] runs from sorted positions."""
        if not sorted_pos:
            return []
        runs = []
        s = sorted_pos[0]
        e = s + 1
        for p in sorted_pos[1:]:
            if p == e:
                e += 1
            else:
                runs.append([s, e])
                s = p
                e = p + 1
        runs.append([s, e])
        return runs

    def maybe_debug_span(
        *,
        prompt: str,
        enc,
        tokenizer,
        model,
        image_start_id: int,
        image_end_id: int,
        imgpad_id: int,
        start_idx: int,
        cut_idx: int,
        image_start_idx: int,
        image_end_idx: int,
        image_token_len: int,
        qwen_image_tokens,
        dscr_enabled: bool,
    ) -> None:
        nonlocal debug_span_printed
        if not args.debug_span:
            return
        if debug_span_printed >= int(args.debug_span_limit):
            return

        ids = enc.input_ids[0].detach().cpu()
        # Token positions
        pos_img = _find_positions_1d(ids, image_start_id)
        pos_end = _find_positions_1d(ids, image_end_id)
        pos_pad = _find_positions_1d(ids, imgpad_id)
        pad_runs = _find_first_run_start_end(sorted(pos_pad))

        # Config visual ids (if any)
        vis_cfg = getattr(getattr(model, "config", None), "visual", None)
        cfg_start_id = None
        cfg_end_id = None
        if isinstance(vis_cfg, dict) and "image_start_id" in vis_cfg:
            try:
                cfg_start_id = int(vis_cfg["image_start_id"])
                cfg_end_id = int(cfg_start_id + 1)
            except Exception:
                cfg_start_id = None
                cfg_end_id = None

        # Decode a small window around our chosen span
        left = max(0, int(image_start_idx) - 16)
        right = min(int(ids.shape[0]), int(image_end_idx) + 16)
        window_ids = ids[left:right].tolist()
        window_str = tokenizer.decode(window_ids, skip_special_tokens=False)

        print("========== [DEBUG SPAN] ==========")
        print(f"[debug] dscr_enabled={dscr_enabled} qwen_image_tokens(num_queries)={qwen_image_tokens}")
        print(f"[debug] prompt: {prompt}")
        print(f"[debug] token_ids: <img>={image_start_id} </img>={image_end_id} <imgpad>={imgpad_id}")
        if cfg_start_id is not None:
            print(f"[debug] model.config.visual image_start_id={cfg_start_id} image_end_id={cfg_end_id}")
        else:
            print("[debug] model.config.visual image_start_id: (none)")
        print(f"[debug] positions: <img>={pos_img} </img>={pos_end}")
        print(f"[debug] positions: <imgpad> count={len(pos_pad)} runs={pad_runs[:4]}{'...' if len(pad_runs) > 4 else ''}")
        print(
            f"[debug] span_used: start_idx(<img>)={start_idx} cut_idx(</img>)={cut_idx} | "
            f"image_start_idx={image_start_idx} image_end_idx={image_end_idx} image_token_len={image_token_len}"
        )
        print(f"[debug] decoded_window[{left}:{right}]: {window_str}")
        print("==================================")

        debug_span_printed += 1

    def infer_visual_span_qwen(tokenizer, model, input_ids_1d: torch.Tensor) -> List[int]:
        """
        Return [vis_start, vis_end_exclusive] for Qwen-VL.
        Priority:
          1) contiguous <imgpad> run
          2) config.visual image_start_id / image_end_id (end_id = start_id + 1 on Qwen-VL)
          3) empty
        """
        # 1) <imgpad> contiguous run
        imgpad_id = None
        try:
            imgpad_id = tokenizer.convert_tokens_to_ids("<imgpad>")
        except Exception:
            imgpad_id = None
        if imgpad_id is not None and int(imgpad_id) >= 0:
            pos = (input_ids_1d == int(imgpad_id)).nonzero(as_tuple=True)[0]
            if pos.numel() > 0:
                s = int(pos[0].item())
                e = s
                for p in pos.tolist():
                    if int(p) == e:
                        e += 1
                    else:
                        break
                return [s, e]

        # 2) config-based (image_end_id is start_id + 1 on Qwen-VL)
        vis_cfg = getattr(getattr(model, "config", None), "visual", None)
        if isinstance(vis_cfg, dict) and "image_start_id" in vis_cfg:
            try:
                start_id = int(vis_cfg["image_start_id"])
                end_id = start_id + 1
            except Exception:
                return [0, 0]

            s_pos = (input_ids_1d == start_id).nonzero(as_tuple=True)[0]
            if s_pos.numel() == 0:
                return [0, 0]
            s0 = int(s_pos[0].item())

            e_pos = (input_ids_1d == end_id).nonzero(as_tuple=True)[0]
            if e_pos.numel() == 0:
                return [0, 0]

            e0 = None
            for p in e_pos.tolist():
                if int(p) > s0:
                    e0 = int(p)
                    break
            if e0 is None or s0 + 1 >= e0:
                return [0, 0]
            return [s0 + 1, e0]

        return [0, 0]

    for dataset in datasets:
        qfile_json = os.path.join(args.gt_dir, f"{dataset}.json")
        qfile_jsonl = os.path.join(args.gt_dir, f"{dataset}.jsonl")
        qfile = qfile_json if os.path.isfile(qfile_json) else qfile_jsonl
        if not os.path.isfile(qfile):
            print(f"[skip] missing questions for {dataset}")
            continue

        questions = load_questions(qfile)
        img_dir = pick_img_dir(args.image_root, dataset)
        depth_dir = pick_depth_dir(args.depth_root, dataset) if args.use_dscr else None
        answers_file = os.path.join(args.out_root, f"{run_name}_{dataset}.jsonl")

        with open(answers_file, "w", encoding="utf-8") as fout:
            for q in tqdm(questions, desc=f"{dataset}"):
                qid = q.get("question_id", q.get("id"))
                image_file = q.get("image", "")
                question = q.get("text", q.get("question", ""))

                image_path = os.path.join(img_dir, image_file)
                prompt = f"<img>{image_path}</img>{question}{fmt} Answer:"
                input_ids = tokenizer([prompt], return_tensors="pt", padding="longest").to(model.device)

                image_start = tokenizer.convert_tokens_to_ids("<img>")
                image_end = tokenizer.convert_tokens_to_ids("</img>")
                imgpad_id = None
                try:
                    imgpad_id = tokenizer.convert_tokens_to_ids("<imgpad>")
                except Exception:
                    imgpad_id = None
                start_idx = int((input_ids.input_ids[0] == image_start).nonzero(as_tuple=True)[0].item())
                cut_idx = int((input_ids.input_ids[0] == image_end).nonzero(as_tuple=True)[0].item())
                # IMPORTANT (Qwen-VL): visual tokens are usually <imgpad> runs, not the path text between <img>...</img>.
                ids_1d = input_ids.input_ids[0]
                vis_s, vis_e = infer_visual_span_qwen(tokenizer, model, ids_1d)
                if vis_e > vis_s:
                    image_start_idx = int(vis_s)
                    image_end_idx = int(vis_e - 1)
                    image_token_len = int(vis_e - vis_s)
                else:
                    # Fallback: previous heuristic (may include path tokens; keep only as last resort)
                    image_token_len = qwen_image_tokens or (cut_idx - start_idx - 1)
                    image_start_idx = start_idx + 1
                    seq_len = input_ids.input_ids.shape[1]
                    image_end_idx = min(image_start_idx + image_token_len - 1, seq_len - 1)
                    image_token_len = max(0, image_end_idx - image_start_idx + 1)

                maybe_debug_span(
                    prompt=prompt,
                    enc=input_ids,
                    tokenizer=tokenizer,
                    model=model,
                    image_start_id=image_start,
                    image_end_id=image_end,
                    imgpad_id=imgpad_id,
                    start_idx=start_idx,
                    cut_idx=cut_idx,
                    image_start_idx=image_start_idx,
                    image_end_idx=image_end_idx,
                    image_token_len=image_token_len,
                    qwen_image_tokens=qwen_image_tokens,
                    dscr_enabled=bool(args.use_dscr),
                )

                image = Image.open(image_path).convert("RGB")
                image_tensor = model.transformer.visual.image_transform(image).to(model.device)

                image_tensor_cd = None
                if args.method == "vcd":
                    image_tensor_cd = add_diffusion_noise(image_tensor, int(args.noise_step))
                elif args.method == "agla":
                    tensor_image = blip_loader(image.resize((384, 384)))
                    blip_image = blip_vis_processors["eval"](image).unsqueeze(0).to(blip_device)
                    blip_question = blip_text_processors["eval"](question)
                    tokenized_text = blip_model.tokenizer(
                        blip_question, padding="longest", truncation=True, return_tensors="pt"
                    ).to(blip_device)
                    augmented_image = augmentation(
                        blip_image, blip_question, tensor_image, blip_model, tokenized_text, image
                    )
                    image_tensor_cd = model.transformer.visual.image_transform(augmented_image)

                cache_clean = None
                cache_noisy = None
                cache_augmented = None
                first_logits_clean = None
                first_logits_noisy = None
                first_logits_augmented = None
                kv_outputs = None
                
                if args.method == "vcd" and args.use_dscr and image_tensor_cd is not None:
                    attention_mask_clean = torch.ones_like(input_ids.input_ids, device=input_ids.input_ids.device)
                    with torch.inference_mode():
                        out_noisy = model(
                            input_ids=input_ids.input_ids,
                            images=image_tensor_cd.unsqueeze(0).half().to(model.device),
                            attention_mask=attention_mask_clean,
                            use_cache=True,
                            return_dict=True,
                        )
                    cache_noisy = out_noisy.past_key_values
                    first_logits_noisy = out_noisy.logits[:, -1, :]
                
                if args.method == "agla" and args.use_dscr and image_tensor_cd is not None:
                    attention_mask_clean = torch.ones_like(input_ids.input_ids, device=input_ids.input_ids.device)
                    with torch.inference_mode():
                        out_augmented = model(
                            input_ids=input_ids.input_ids,
                            images=image_tensor_cd.unsqueeze(0).half().to(model.device),
                            attention_mask=attention_mask_clean,
                            use_cache=True,
                            return_dict=True,
                        )
                    cache_augmented = out_augmented.past_key_values
                    first_logits_augmented = out_augmented.logits[:, -1, :]
                
                if args.use_dscr:
                    # Build D matrix: fixed 16x16=256 grid (same as object_hallucination_vqa_qwenvl.py).
                    # apply_D_to_cache does NOT resize D; it applies D[:n_apply,:n_apply] with n_apply = min(256, image_token_len).
                    import torch.nn.functional as F
                    depth_path = os.path.join(depth_dir, os.path.splitext(image_file)[0] + ".npy")
                    depth_np = np.load(depth_path, allow_pickle=True)
                    depth_tensor = torch.tensor(depth_np, dtype=torch.float32).unsqueeze(0).unsqueeze(1)

                    depth_patch = F.interpolate(depth_tensor, size=(16, 16), mode="bilinear", align_corners=False)
                    depth_patch = torch.clamp(1.0e-6 + 1.0 / depth_patch, 0.001, 1000).view(1, -1)
                    depth_patch_min = depth_patch.min()
                    depth_patch_max = depth_patch.max()
                    depth_patch = (depth_patch - depth_patch_min) / (depth_patch_max - depth_patch_min + 1e-6)

                    depth_diff = depth_patch - depth_patch.transpose(1, 0)
                    gaussian_weight_depth = torch.exp(-(depth_diff ** 2) / (2 * (args.dscr_sigma ** 2)))

                    pixel_positions = torch.tensor(
                        [[i // 16, i % 16] for i in range(256)],
                        dtype=torch.float32,
                    ) / 15.0
                    position_diff = torch.cdist(pixel_positions, pixel_positions, p=2)
                    gaussian_weight_position = torch.exp(-(position_diff ** 2) / (2 * (args.dscr_sigma ** 2)))

                    gaussian_weight = (gaussian_weight_depth ** args.dscr_alpha) + (gaussian_weight_position ** args.dscr_beta)
                    D = gaussian_weight / (gaussian_weight.sum(dim=-1, keepdim=True) + 1e-6)

                    #######
                    # Forward to get initial cache (always, like llava15)
                    attention_mask_clean = torch.ones_like(input_ids.input_ids, device=input_ids.input_ids.device)
                    with torch.inference_mode():
                        kv_outputs = model(
                            input_ids=input_ids.input_ids,
                            images=image_tensor.unsqueeze(0).half().to(model.device),
                            attention_mask=attention_mask_clean,
                            use_cache=True,
                            return_dict=True,
                        )
                    
                    # Apply D to cache (same as qwenvl: segment 1:cut_idx, length min(256, cut_idx-1))
                    dscr_key_value = args.dscr_key_value or (
                        not args.dscr_key_only and not args.dscr_value_only
                    )
                    dscr_segment_start = 1
                    dscr_segment_len = min(256, max(0, cut_idx - 1))
                    cache_clean = apply_D_to_cache(
                        cache=kv_outputs.past_key_values,
                        D=D,
                        cut_idx=dscr_segment_start,
                        start_layer_idx=int(args.dscr_start_layer),
                        end_layer_idx=int(args.dscr_end_layer),
                        key_only=bool(args.dscr_key_only),
                        value_only=bool(args.dscr_value_only),
                        key_value=bool(dscr_key_value),
                        num_image_tokens=dscr_segment_len,
                        dscr_lambda=args.dscr_lambda,
                    )
                    

                    if args.method != "baseline":
                        attention_mask_clean = torch.ones_like(input_ids.input_ids, device=input_ids.input_ids.device)
                        cache_clean, first_logits_clean = recompute_first_logits_from_refined_cache(
                            model=model,
                            input_ids=input_ids.input_ids,
                            cache=cache_clean,
                            attention_mask=attention_mask_clean,
                        )
                    else:
                        first_logits_clean = kv_outputs.logits[:, -1, :]
                    
                    if args.method == "vcd" and cache_noisy is not None:
                        cache_noisy_dscr = apply_D_to_cache(
                            cache=cache_noisy,
                            D=D,
                            cut_idx=dscr_segment_start,
                            start_layer_idx=int(args.dscr_start_layer),
                            end_layer_idx=int(args.dscr_end_layer),
                            key_only=bool(args.dscr_key_only),
                            value_only=bool(args.dscr_value_only),
                            key_value=bool(dscr_key_value),
                            num_image_tokens=dscr_segment_len,
                            dscr_lambda=args.dscr_lambda,
                        )
                        cache_noisy_dscr, _ = recompute_first_logits_from_refined_cache(
                            model=model,
                            input_ids=input_ids.input_ids,
                            cache=cache_noisy_dscr,
                            attention_mask=attention_mask_clean,
                        )
                        model._dscr_past_key_values_cd = cache_noisy_dscr

                gen_dscr_kwargs = {}
                cache_augmented_for_agla = None
                if cache_clean is not None:
                    gen_dscr_kwargs = {"past_key_values": cache_clean, "use_cache": True}
                    if args.method == "agla" and cache_augmented is not None:
                        cache_augmented_for_agla = cache_augmented
                elif kv_outputs is not None:
                    # Baseline: use kv_outputs.past_key_values like llava15
                    gen_dscr_kwargs = {"past_key_values": kv_outputs.past_key_values, "use_cache": True}

                dscr_note = ""
                if args.use_dscr:
                    dscr_note = (
                        f" | DSCR segment_start=1 segment_len={dscr_segment_len} (qwenvl-style 1:cut_idx) "
                        f"a={args.dscr_alpha} b={args.dscr_beta} s={args.dscr_sigma} "
                        f"k={args.dscr_keep_ratio} l={args.dscr_lambda} "
                        f"layers={args.dscr_start_layer}-{args.dscr_end_layer} "
                        f"key_only={args.dscr_key_only} value_only={args.dscr_value_only} "
                        f"key_value={dscr_key_value}"
                    )

                with torch.inference_mode():
                    if args.method == "opera":
                        key_position = {
                            "image_start": image_start_idx,
                            "image_end": image_end_idx,
                            "response_start": input_ids.input_ids.shape[1] - 1,
                        }
                        # OPERA: keep original settings even with DSCR
                        do_sample_opera = False
                        temp_opera = float(args.temperature)
                        log_once(
                            f"[OPERA] do_sample={do_sample_opera} temperature={temp_opera} "
                            f"key_position={key_position} beam={args.beam} "
                            f"scale_factor={args.scale_factor} threshold={args.threshold} "
                            f"num_attn_candidates={args.num_attn_candidates} penalty_weights={args.penalty_weights}"
                            f"{dscr_note}"
                        )
                        output_ids = model.generate(
                            input_ids=input_ids.input_ids,
                            images=image_tensor.unsqueeze(0).half().to(model.device),
                            do_sample=do_sample_opera,
                            temperature=temp_opera,
                            top_p=float(args.top_p),
                            top_k=args.top_k,
                            max_new_tokens=args.max_new_tokens,
                            key_position=key_position,
                            opera_decoding=True,
                            num_beams=args.beam,
                            output_attentions=True,
                            scale_factor=args.scale_factor,
                            threshold=args.threshold,
                            num_attn_candidates=args.num_attn_candidates,
                            penalty_weights=args.penalty_weights,
                            opera_attn_layer=args.opera_attn_layer,
                            **gen_dscr_kwargs,
                        )
                    elif args.method == "halc":
                        from decoder_zoo.HALC.context_density.halc import halc_assistant

                        halc_params = {
                            "detector": args.halc_detector,
                            "box_threshold": args.halc_box_threshold,
                            "debugger": args.halc_debugger,
                            "k_candidate_num": args.halc_k_candidate_num,
                            "LVLM_backbone": "qwen-vl",
                            "score_type": args.halc_score_type,
                            "context_window": args.halc_context_window,
                            "expand_ratio": args.halc_expand_ratio,
                            "context_domain": args.halc_context_domain,
                            "contrast_weight": args.halc_contrast_weight,
                            "tokenizer": tokenizer,
                            "check_all_entities": args.halc_check_all_entities,
                            "dscr_active": bool(args.use_dscr),
                        }
                        vis_processor = lambda img: model.transformer.visual.image_transform(img).to(model.device)
                        halc = halc_assistant(
                            model=model,
                            vis_processor=vis_processor,
                            device=model.device,
                            halc_params=halc_params,
                            max_new_tokens=args.max_new_tokens,
                        )
                        halc.update_input(image_path, prompt)
                        mature_layer = (
                            args.halc_mature_layer
                            if args.halc_mature_layer is not None
                            else int(getattr(model.config, "num_hidden_layers", 32)) - 1
                        )
                        candidate_layers = args.halc_candidate_layers
                        base_layer = None if candidate_layers else args.halc_base_layer
                        # HALC: keep original settings even with DSCR
                        do_sample_halc = False
                        temp_halc = float(args.temperature)
                        log_once(
                            f"[HALC] do_sample={do_sample_halc} temperature={temp_halc} "
                            f"mature_layer={mature_layer} base_layer={base_layer} "
                            f"candidate_layers={candidate_layers} relative_top={args.halc_relative_top} "
                            f"detector={args.halc_detector} expand_ratio={args.halc_expand_ratio} "
                            f"context_window={args.halc_context_window} context_domain={args.halc_context_domain} "
                            f"contrast_weight={args.halc_contrast_weight} score_type={args.halc_score_type} "
                            f"box_threshold={args.halc_box_threshold} beam_search={args.halc_beam_search} "
                            f"num_beams={args.halc_num_beams} "
                            f"check_all_entities={args.halc_check_all_entities}"
                            f"{dscr_note}"
                        )
                        input_len = input_ids.input_ids.shape[1]
                        yn_prefix_fn = (
                            build_yn_prefix_allowed_tokens_fn(tokenizer, input_len)
                            if args.format == "yn_format"
                            else None
                        )
                        output_ids = model.generate(
                            input_ids=input_ids.input_ids,
                            images=image_tensor.unsqueeze(0).half().to(model.device),
                            do_sample=do_sample_halc,
                            temperature=temp_halc,
                            top_p=float(args.top_p),
                            top_k=args.top_k,
                            max_new_tokens=args.max_new_tokens,
                            prefix_allowed_tokens_fn=yn_prefix_fn,
                            halc_decoding=True,
                            dola_decoding=False,
                            beam_search=bool(args.halc_beam_search),
                            num_beams=args.halc_num_beams,
                            halc_assistant=halc,
                            mature_layer=mature_layer,
                            base_layer=base_layer,
                            candidate_premature_layers=candidate_layers,
                            relative_top=args.halc_relative_top,
                            output_attentions=True,
                            output_hidden_states=True,
                            **gen_dscr_kwargs,
                        )
                        if isinstance(output_ids, tuple):
                            output_ids = output_ids[0]
                    elif args.method == "vcd":
                        # VCD: keep original settings even with DSCR
                        do_sample_vcd = True
                        temp_vcd = float(args.temperature)
                        if image_tensor_cd is not None:
                            log_once(
                                f"[VCD] do_sample={do_sample_vcd} temperature={temp_vcd} "
                                f"noise_step={args.noise_step} cd_alpha={args.cd_alpha} "
                                f"cd_beta={args.cd_beta} images_cd_mean={image_tensor_cd.mean().item():.4f} "
                                f"std={image_tensor_cd.std().item():.4f}{dscr_note}"
                            )
                        output_ids = model.generate(
                            input_ids=input_ids.input_ids,
                            images=image_tensor.unsqueeze(0).half().to(model.device),
                            images_cd=image_tensor_cd.unsqueeze(0).half().to(model.device),
                            cd_alpha=args.cd_alpha,
                            cd_beta=args.cd_beta,
                            do_sample=do_sample_vcd,
                            temperature=temp_vcd,
                            top_p=float(args.top_p),
                            top_k=args.top_k,
                            max_new_tokens=args.max_new_tokens,
                            pad_token_id=tokenizer.eod_id,
                            eos_token_id=tokenizer.eod_id,
                            **gen_dscr_kwargs,
                        )
                    elif args.method == "damo":
                        # DAMO: keep original settings even with DSCR
                        do_sample_damo = (args.temperature > 0)
                        temp_damo = float(args.temperature)
                        log_once(
                            f"[DAMO] do_sample={do_sample_damo} temperature={temp_damo} "
                            f"tau={args.tau} beta_1={args.beta_1} beta_2={args.beta_2} "
                            f"alpha={args.alpha} damo_start_layer={args.damo_start_layer}{dscr_note}"
                        )
                        output_ids = model.generate(
                            input_ids=input_ids.input_ids,
                            images=image_tensor.unsqueeze(0).half().to(model.device),
                            do_sample=do_sample_damo,
                            temperature=temp_damo,
                            top_p=float(args.top_p),
                            top_k=args.top_k,
                            max_new_tokens=args.max_new_tokens,
                            output_hidden_states=True,
                            tau=args.tau,
                            beta_1=args.beta_1,
                            beta_2=args.beta_2,
                            alpha=args.alpha,
                            damo_start_layer=args.damo_start_layer,
                            pad_token_id=tokenizer.eod_id,
                            eos_token_id=tokenizer.eod_id,
                            **gen_dscr_kwargs,
                        )
                    elif args.method == "agla":
                        # AGLA: keep original settings even with DSCR
                        do_sample_agla = True
                        temp_agla = float(args.temperature)
                        if image_tensor_cd is not None:
                            log_once(
                                f"[AGLA] do_sample={do_sample_agla} temperature={temp_agla} "
                                f"cd_alpha={args.agla_alpha} cd_beta={args.agla_beta} "
                                f"aug_shape={tuple(image_tensor_cd.shape)}{dscr_note}"
                            )
                        if cache_augmented_for_agla is not None:
                            model._agla_cache_augmented = cache_augmented_for_agla
                        try:
                            output_ids = model.generate(
                                input_ids=input_ids.input_ids,
                                images=image_tensor.unsqueeze(0).half().to(model.device),
                                images_cd=image_tensor_cd.unsqueeze(0).half().to(model.device),
                                cd_alpha=args.agla_alpha,
                                cd_beta=args.agla_beta,
                                do_sample=do_sample_agla,
                                temperature=temp_agla,
                                top_p=float(args.top_p),
                                top_k=args.top_k,
                                max_new_tokens=args.max_new_tokens,
                                pad_token_id=tokenizer.eod_id,
                                eos_token_id=tokenizer.eod_id,
                                **gen_dscr_kwargs,
                            )
                        finally:
                            # Clean up temporary attributes
                            if hasattr(model, '_agla_cache_augmented'):
                                delattr(model, '_agla_cache_augmented')
                    else:

                        if cache_clean is not None:
                            # DSCR-only: always greedy (do_sample=False, temperature=0.0)
                            do_sample_dscr = False
                            temp_dscr = 0.0
                            log_once(f"[BASELINE+DSCR] do_sample={do_sample_dscr} temperature={temp_dscr} (fixed for DSCR){dscr_note}")
                            images_b1 = image_tensor.unsqueeze(0).half().to(model.device)
                            image_tensor_cd_for_gen = image_tensor_cd.unsqueeze(0).half().to(model.device) if image_tensor_cd is not None else None
                            output_ids = model.generate(
                                input_ids=input_ids.input_ids.to(model.device),
                                attention_mask=input_ids.attention_mask.to(model.device),
                                images=images_b1,
                                images_cd=image_tensor_cd_for_gen,
                                cd_alpha=getattr(args, "cd_alpha", 0.5),
                                cd_beta=getattr(args, "cd_beta", 0.1),
                                do_sample=do_sample_dscr,
                                temperature=temp_dscr,
                                top_p=float(args.top_p),
                                top_k=args.top_k,
                                max_new_tokens=args.max_new_tokens,
                                min_new_tokens=1,
                                length_penalty=1.0,
                                num_return_sequences=1,
                                output_hidden_states=True,
                                pad_token_id=tokenizer.eod_id,
                                eos_token_id=tokenizer.eod_id,
                                **gen_dscr_kwargs,
                            )
                        else:
                            # Baseline (no DSCR): model.generate with images, no past_key_values
                            temp_baseline = float(args.temperature)
                            do_sample_baseline = temp_baseline > 0.0
                            if not do_sample_baseline:
                                temp_baseline = 0.0
                            log_once(f"[BASELINE] do_sample={do_sample_baseline} temperature={temp_baseline}{dscr_note}")
                            images_b1 = image_tensor.unsqueeze(0).half().to(model.device)
                            output_ids = model.generate(
                                input_ids=input_ids.input_ids,
                                images=images_b1,
                                do_sample=do_sample_baseline,
                                temperature=temp_baseline,
                                top_p=float(args.top_p),
                                top_k=args.top_k,
                                max_new_tokens=args.max_new_tokens,
                                pad_token_id=tokenizer.eod_id,
                                eos_token_id=tokenizer.eod_id,
                            )

                input_len = input_ids.input_ids.size(1)
                decode_ids = output_ids[0]
                if decode_ids.shape[0] > input_len:
                    decode_ids = decode_ids[input_len:]
                decoded = tokenizer.decode(
                    decode_ids.cpu(),
                    skip_special_tokens=True,
                ).strip()
                if args.format == "yn_format":
                    decoded = normalize_yn_answer(decoded)

                fout.write(
                    json.dumps(
                        {
                            "question_id": qid,
                            "prompt": prompt,
                            "text": decoded,
                            "model_id": "qwen-vl",
                            "image": image_file,
                            "metadata": {},
                        }
                    )
                    + "\n"
                )


if __name__ == "__main__":
    # Runs only when this file is executed directly (e.g. `python <this file> ...`),
    # not when it is imported. Control is handed to `main()`, which is defined
    # earlier in this file and handles argument parsing and the generation loop.
    main()
