#!/usr/bin/env python
import argparse
import os
import sys
import json
import re
import math

from tqdm import tqdm
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

# Resolve this script's location once and derive the two ancestor directories
# from it, so the path setup does not depend on the current working directory.
_SCRIPT_PATH = os.path.abspath(__file__)
# Parent of the directory that contains this file.
EXPERIMENTS_DIR = os.path.dirname(os.path.dirname(_SCRIPT_PATH))
# One level above EXPERIMENTS_DIR.
ROOT_DIR = os.path.dirname(EXPERIMENTS_DIR)
sys.path.extend([ROOT_DIR, EXPERIMENTS_DIR])
# Placed at the front of sys.path so this source tree is searched first when
# `transformers` is imported below (presumably a bundled copy -- confirm).
sys.path.insert(0, os.path.join(EXPERIMENTS_DIR, "transformers-4.49.0", "src"))

from transformers import (
    AutoProcessor,
    Qwen2_5_VLForConditionalGeneration,
    DynamicCache,
    set_seed,
)


def _normalize_vqa_answer(ans: str) -> str:
    ans = ans.strip().lower()
    ans = ans.split("\n")[0]

    if ans.startswith("answer:"):
        ans = ans[len("answer:"):].strip()
    if ans.startswith("the answer is"):
        ans = ans[len("the answer is"):].strip()

    punct = r'[!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~]'
    ans = re.sub(punct, " ", ans)
    ans = re.sub(r"\b(a|an|the)\b", " ", ans)

    number_map = {
        "zero": "0", "one": "1", "two": "2", "three": "3", "four": "4",
        "five": "5", "six": "6", "seven": "7", "eight": "8", "nine": "9", "ten": "10",
    }

    words = ans.split()
    new_words = []
    for w in words:
        if w in number_map:
            new_words.append(number_map[w])
        else:
            new_words.append(w)

    return " ".join(new_words).strip()


def compute_vqa_accuracy(pred: str, gt_answers) -> float:
    """Score one prediction with the ``min(1, matches / 3)`` VQA rule.

    The ground-truth answers decide how the prediction is reduced to a core
    answer before matching:

    * all non-empty answers are ``yes``/``no``: the prediction becomes
      ``"yes"`` or ``"no"`` if it contains that *word* (``yes`` wins when
      both occur);
    * all non-empty answers are numbers (digits or ``zero`` .. ``ten``): the
      prediction becomes its single digit run, or failing that its single
      number word;
    * otherwise, or when the rules above find nothing, the prediction is
      passed through ``_normalize_vqa_answer``.

    Args:
        pred: Model output text.
        gt_answers: Iterable of annotator answers; ``None`` entries are
            ignored. May itself be ``None`` or empty.

    Returns:
        A score in ``[0.0, 1.0]``; ``0.0`` when there is no usable ground
        truth or the reduced prediction is empty.
    """
    # Filter missing annotations once so detection and matching see the same list.
    valid_gts = [a for a in (gt_answers or []) if a is not None]
    if not valid_gts:
        return 0.0

    number_words = {
        "zero": "0", "one": "1", "two": "2", "three": "3", "four": "4",
        "five": "5", "six": "6", "seven": "7", "eight": "8", "nine": "9", "ten": "10",
    }

    def to_digit(s: str):
        s = s.strip().lower()
        if s.isdigit():
            return s
        return number_words.get(s)

    # Require at least one non-empty answer: all() over nothing is True and
    # would otherwise classify an all-blank annotation set as yes/no AND numeric.
    non_empty = {a.strip().lower() for a in valid_gts} - {""}
    is_yn = bool(non_empty) and non_empty <= {"yes", "no"}
    is_num = bool(non_empty) and all(to_digit(a) is not None for a in non_empty)

    raw = pred.strip().lower()
    # Word-level tokens: keeps "know"/"not" from counting as "no" and lets
    # "two." be recognised as the number word "two".
    words = re.findall(r"[a-z0-9]+", raw)

    pred_core = None
    if is_yn:
        if "yes" in words:
            pred_core = "yes"
        elif "no" in words:
            pred_core = "no"
    elif is_num:
        digits = re.findall(r"\d+", raw)
        if len(digits) == 1:
            pred_core = digits[0]
        else:
            cand = [d for d in map(to_digit, words) if d is not None]
            if len(cand) == 1:
                pred_core = cand[0]

    if pred_core is None:
        pred_core = _normalize_vqa_answer(pred)
    if not pred_core:
        return 0.0

    norm_gts = [_normalize_vqa_answer(a) for a in valid_gts]
    if is_num:
        norm_gts = [to_digit(g) or g for g in norm_gts]

    match_count = sum(1 for g in norm_gts if g == pred_core)
    return min(1.0, match_count / 3.0)


def apply_D_to_cache(
    cache,
    D: torch.Tensor,
    cut_idx: int,
    start_layer_idx: int,
    end_layer_idx: int,
    key_only: bool,
    value_only: bool,
    key_value: bool,
    num_image_tokens: int,
    dscr_lambda: float = 1.0,
):
    """Blend the image-token slice of cached keys/values with ``D @ slice``.

    For every layer with ``start_layer_idx <= layer < end_layer_idx`` the
    positions ``cut_idx : cut_idx + num_image_tokens`` along dim 2 are
    replaced by ``(1 - lambda) * seg + lambda * einsum("ij,bhjd->bhid", D, seg)``.
    All other layers and positions are passed through unchanged.

    Args:
        cache: Iterable of per-layer ``(key, value)`` tensor pairs, indexed
            here as ``[batch, head, position, dim]``.
        D: Mixing matrix, either 2-D or 4-D (the two leading dims of a 4-D
            input are squeezed away).
        cut_idx: Index of the first image position.
        start_layer_idx: First layer to modify (inclusive).
        end_layer_idx: Layer at which to stop (exclusive).
        key_only: Modify keys.
        value_only: Modify values.
        key_value: Modify both keys and values.
        num_image_tokens: Length of the image slice.
        dscr_lambda: Blend weight, clamped to ``[0, 1]``.

    Returns:
        A tuple of ``(key, value)`` pairs, one per input layer.

    Raises:
        ValueError: If ``D`` is neither 2-D nor 4-D.
    """
    if D.dim() not in (2, 4):
        raise ValueError(f"Unsupported D shape: {tuple(D.shape)}")
    if D.dim() == 4:
        D = D.squeeze(0).squeeze(0)

    first_layer = int(start_layer_idx)
    stop_layer = int(end_layer_idx)
    lam = float(max(0.0, min(1.0, dscr_lambda)))
    image_end = cut_idx + num_image_tokens
    touch_keys = key_only or key_value
    touch_values = value_only or key_value

    # D converted once per (device, dtype) pair and reused across layers.
    converted = {}

    def _refine(states):
        slot = (states.device, states.dtype)
        mixer = converted.get(slot)
        if mixer is None:
            mixer = D.to(device=states.device, dtype=states.dtype)
            converted[slot] = mixer
        image_part = states[:, :, cut_idx:image_end, :]
        propagated = torch.einsum("ij,bhjd->bhid", mixer, image_part)
        blended = (1.0 - lam) * image_part + lam * propagated
        return torch.cat(
            (states[:, :, :cut_idx, :], blended, states[:, :, image_end:, :]),
            dim=2,
        )

    refined = []
    for layer_idx, (k, v) in enumerate(cache):
        if first_layer <= layer_idx < stop_layer:
            if touch_keys:
                k = _refine(k)
            if touch_values:
                v = _refine(v)
        refined.append((k, v))
    return tuple(refined)


def recompute_first_logits_from_refined_cache(model, input_ids, cache, attention_mask=None):
    """Re-run only the last prompt token on top of an edited KV cache.

    Every layer of ``cache`` (legacy ``(key, value)`` pairs) is trimmed by its
    final position along dim 2, wrapped in a ``DynamicCache`` and handed to
    ``model`` together with the last column of ``input_ids``. No image inputs
    are passed (``pixel_values`` and ``image_grid_thw`` are ``None``).

    Args:
        model: Callable model accepting the keyword arguments used below.
        input_ids: Token ids; only ``input_ids[:, -1:]`` is fed to the model.
        cache: Sequence of per-layer ``(key, value)`` tensors.
        attention_mask: Optional mask; only its last column is forwarded.

    Returns:
        ``(past_key_values, logits)`` where ``logits`` is the output at the
        last position, i.e. ``outputs.logits[:, -1, :]``.
    """
    # Everything except the final cached position becomes the prefix.
    trimmed_layers = tuple(
        (
            layer[0][:, :, :-1, :].contiguous(),
            layer[1][:, :, :-1, :].contiguous(),
        )
        for layer in cache
    )
    prefix_cache = DynamicCache.from_legacy_cache(trimmed_layers)

    final_token = input_ids[:, -1:]
    # NOTE(review): the mask covers only the new token, not prefix + token.
    # Presumably fine for a single unpadded sequence -- confirm against the
    # transformers copy this script imports.
    final_mask = None if attention_mask is None else attention_mask[:, -1:]

    with torch.inference_mode():
        result = model(
            input_ids=final_token,
            attention_mask=final_mask,
            pixel_values=None,
            image_grid_thw=None,
            use_cache=True,
            past_key_values=prefix_cache,
            return_dict=True,
        )
    return result.past_key_values, result.logits[:, -1, :]


def build_dscr_matrix(
    depth: np.ndarray,
    H_tokens: int,
    W_tokens: int,
    alpha: float,
    beta: float,
    sigma: float,
    self_keep: float,
    no_depth: bool = False,
    no_spatial: bool = False,
) -> torch.Tensor:
    """Build a row-normalised token-to-token affinity matrix.

    Two Gaussian kernels with the same bandwidth ``sigma`` are computed over
    an ``H_tokens x W_tokens`` grid flattened in row-major order:

    * depth kernel: on min-max normalised, clamped inverse depth, after the
      depth map is bilinearly resized to the token grid;
    * spatial kernel: on grid coordinates scaled to ``[0, 1]`` per axis.

    They are combined as ``depth ** alpha + spatial ** beta`` (or only one of
    them when ``no_depth`` / ``no_spatial`` is set).

    Args:
        depth: Depth map; unsqueezed twice and resized to the token grid
            (assumes a 2-D array -- TODO confirm with the depth producer).
        H_tokens: Token rows.
        W_tokens: Token columns.
        alpha: Exponent on the depth kernel.
        beta: Exponent on the spatial kernel.
        sigma: Gaussian bandwidth shared by both kernels.
        self_keep: Clamped to ``[0, 1]``; when below 1 the diagonal is
            overwritten with this value before normalisation.
        no_depth: Use the spatial kernel only.
        no_spatial: Use the depth kernel only.

    Returns:
        Float32 tensor of shape ``(H_tokens * W_tokens, H_tokens * W_tokens)``
        with each row divided by its sum (plus ``1e-6``).

    Raises:
        ValueError: If both ``no_depth`` and ``no_spatial`` are set.
    """
    if no_depth and no_spatial:
        raise ValueError("--dscr-no-depth and --dscr-no-spatial cannot both be set")

    self_keep = float(max(0.0, min(1.0, self_keep)))
    # Shared Gaussian denominator; the epsilon guards against sigma == 0.
    bandwidth = 2 * sigma ** 2 + 1e-12

    # --- depth kernel -----------------------------------------------------
    depth_map = torch.tensor(depth, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
    depth_map = F.interpolate(
        depth_map, size=(H_tokens, W_tokens), mode="bilinear", align_corners=False,
    )
    inv_depth = torch.clamp(1.0e-6 + 1.0 / depth_map, 0.001, 1000).view(1, -1)
    lowest, highest = inv_depth.min(), inv_depth.max()
    inv_depth = (inv_depth - lowest) / (highest - lowest + 1e-6)
    # (1, T) minus (T, 1) broadcasts to all pairwise differences.
    pairwise_depth = inv_depth - inv_depth.transpose(1, 0)
    depth_kernel = torch.exp(-(pairwise_depth ** 2) / bandwidth)

    # --- spatial kernel ---------------------------------------------------
    row_coord = torch.arange(H_tokens, dtype=torch.float32)
    col_coord = torch.arange(W_tokens, dtype=torch.float32)
    row_grid, col_grid = torch.meshgrid(row_coord, col_coord, indexing="ij")
    coords = torch.stack([row_grid, col_grid], dim=-1).view(-1, 2)
    # max(..., 1) avoids dividing by zero on a single-row / single-column grid.
    coords[:, 0] /= max(H_tokens - 1, 1)
    coords[:, 1] /= max(W_tokens - 1, 1)
    spatial_kernel = torch.exp(-(torch.cdist(coords, coords, p=2) ** 2) / bandwidth)

    # --- combine ----------------------------------------------------------
    if no_depth:
        affinity = spatial_kernel ** beta
    elif no_spatial:
        affinity = depth_kernel ** alpha
    else:
        affinity = (depth_kernel ** alpha) + (spatial_kernel ** beta)

    n_tokens = affinity.shape[0]
    if self_keep < 1.0:
        diagonal = torch.eye(n_tokens, dtype=affinity.dtype, device=affinity.device)
        affinity = affinity * (1.0 - diagonal) + self_keep * diagonal

    # --- row normalisation --------------------------------------------------
    row_totals = affinity.sum(dim=-1, keepdim=True)
    empty_rows = row_totals < 1e-12
    if empty_rows.any():
        # A row with no mass falls back to uniform weights.
        affinity[empty_rows.squeeze(-1)] = 1.0
        row_totals = affinity.sum(dim=-1, keepdim=True)

    return affinity / (row_totals + 1e-6)


def _load_json_or_jsonl(path):
    """Read a list of records from ``path``.

    The file is parsed as one JSON document when its first character is
    ``[``; otherwise every non-blank line is parsed as its own JSON document.
    """
    with open(path, "r", encoding="utf-8") as f:
        first_char = f.read(1)
        f.seek(0)
        if first_char == "[":
            return json.load(f)
        return [json.loads(line) for line in f if line.strip()]


def _load_gt_answers(gt_file):
    """Return ``{question_id: [answer, ...]}`` read from ``gt_file``.

    An entry's ``"answers"`` may be a list of plain values or a list of dicts
    that carry the text under the ``"answer"`` key.
    """
    gt_dict = {}
    for ent in _load_json_or_jsonl(gt_file):
        qid = ent["question_id"]
        raw_answers = ent["answers"]
        if len(raw_answers) > 0 and isinstance(raw_answers[0], dict):
            gt_dict[qid] = [a["answer"] for a in raw_answers]
        else:
            gt_dict[qid] = list(raw_answers)
    return gt_dict


def _shrink_image(image, image_file, max_dim, max_side):
    """Downscale ``image`` so its longest side respects both size caps.

    The first cap (``max_dim``) is always applied and logged; the second
    (``max_side``) is applied only when it is a positive number. Each side is
    kept at one pixel or more so a very elongated image cannot collapse to a
    zero-sized resize target.
    """
    w, h = image.size
    if max(w, h) > max_dim:
        scale = max_dim / max(w, h)
        nw, nh = max(1, int(w * scale)), max(1, int(h * scale))
        image = image.resize((nw, nh), Image.Resampling.LANCZOS)
        print(f"[Resize] {image_file}: {w}x{h} -> {nw}x{nh}")

    if max_side is not None and max_side > 0:
        w, h = image.size
        longest = max(w, h)
        if longest > max_side:
            scale = max_side / longest
            nw, nh = max(1, int(round(w * scale))), max(1, int(round(h * scale)))
            image = image.resize((nw, nh), Image.Resampling.LANCZOS)
    return image


def eval_model(args):
    """Run VQA inference (optionally with DSCR cache refinement) and score it.

    For every question the image is loaded and resized, a chat prompt is
    built, an answer is generated and one JSON line is appended to
    ``args.answers_file``. Questions whose image (or, with DSCR, depth file)
    is missing are skipped with a warning. Predictions with a ground-truth
    entry are scored with ``compute_vqa_accuracy`` and the mean is printed.

    With ``args.use_dscr`` the prompt is first run through the model once, the
    resulting KV cache is modified by ``apply_D_to_cache`` using the matrix
    from ``build_dscr_matrix``, the last prompt token is recomputed on top of
    it, and generation continues from that cache.

    Args:
        args: Namespace with the attributes read below (as produced by
            ``main``): model/processor options, file paths, prompt format,
            generation options and the ``dscr_*`` settings.

    Raises:
        ValueError: If ``gt_file`` is missing, if DSCR is enabled without a
            depth folder, or if no prediction could be matched to ground truth.
        RuntimeError: If a prompt contains no image tokens, or the DSCR matrix
            size does not match the number of image tokens.
    """
    print(f"[INFO] Load model from {args.model_path}")

    # Only a recognised placement strategy is forwarded; anything else
    # (including a missing/empty value) falls back to "auto".
    requested_map = str(getattr(args, "device_map", None) or "").strip().lower()
    device_map = requested_map if requested_map in ("auto", "balanced", "sequential") else "auto"
    dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16

    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        args.model_path,
        torch_dtype=dtype,
        device_map=device_map,
    )
    # Inputs are sent to wherever the first parameter was placed.
    device = next(model.parameters()).device
    model.eval()

    num_hidden_layers = getattr(model.config, "num_hidden_layers", 28)

    processor = AutoProcessor.from_pretrained(
        args.model_path,
        max_pixels=args.max_image_pixels,
    )
    print(
        f"[Config] processor max_pixels={args.max_image_pixels} "
        f"(≈{int(args.max_image_pixels**0.5)}x{int(args.max_image_pixels**0.5)} square equiv)"
    )
    image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)

    dscr_end_layer = args.dscr_end_layer if args.dscr_end_layer is not None else num_hidden_layers

    print(f"[INFO] Load questions from {args.question_file}")
    questions = _load_json_or_jsonl(args.question_file)
    print(f"[INFO] Number of questions loaded: {len(questions)}")

    if args.gt_file is None or not os.path.exists(args.gt_file):
        raise ValueError(f"gt_file not found or invalid path: {args.gt_file}")
    print(f"[INFO] Load VQA GT from {args.gt_file}")
    gt_dict = _load_gt_answers(args.gt_file)
    print(f"[INFO] Loaded GT for {len(gt_dict)} questions")

    # Per-run DSCR settings, resolved once and checked before the answers file
    # is truncated.
    if args.use_dscr:
        if args.depth_folder is None:
            raise ValueError("use_dscr is True but depth_folder was not provided")
        dscr_key_only_eff = bool(args.dscr_key_only)
        dscr_value_only_eff = bool(args.dscr_value_only)
        # With neither *-only flag set, both keys and values are refined.
        dscr_key_value_eff = bool(args.dscr_key_value) or (
            not dscr_key_only_eff and not dscr_value_only_eff
        )

    # dirname is "" for a bare file name, and os.makedirs("") raises.
    out_dir = os.path.dirname(args.answers_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    fmt_suffix = {
        "yn_format": " Please only answer yes or no.",
        "ow_format": " Please answer this question with one word.",
    }.get(args.format, "")

    max_image_dim = int(args.max_image_pixels ** 0.5 * 1.2)

    # Options shared by both generate() calls below.
    gen_kwargs = dict(
        use_cache=True,
        do_sample=args.do_sample,
        temperature=float(args.temperature),
        top_p=1.0,
        top_k=None,
        max_new_tokens=args.max_new_tokens,
    )

    total_cnt = 0
    total_acc = 0.0
    pred_cnt = 0

    print(f"[INFO] Start inference. Image folder: {args.image_folder}")

    # The context manager closes the file even when a sample raises.
    with open(args.answers_file, "w", encoding="utf-8") as ans_file:
        for line in tqdm(questions):
            idx = line["question_id"]
            image_file = line["image"]
            qs = line["text"]
            prompt = qs + fmt_suffix

            image_path = os.path.join(args.image_folder, image_file)
            if not os.path.exists(image_path):
                print(f"[WARN] image not found: {image_path}")
                continue

            # Checked before decoding/tokenising so skipped samples stay cheap.
            depth_file = None
            if args.use_dscr:
                depth_file = os.path.join(
                    args.depth_folder,
                    os.path.splitext(image_file)[0] + ".npy",
                )
                if not os.path.exists(depth_file):
                    print(f"[WARN] depth file not found: {depth_file}")
                    continue

            # convert() returns a new image, so the source handle can be closed.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert("RGB")
            image = _shrink_image(image, image_file, max_image_dim, args.max_image_side)

            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": prompt},
                    ],
                }
            ]
            text = processor.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )

            inputs = processor(text=[text], images=[image], return_tensors="pt")
            input_ids = inputs["input_ids"].to(device)
            attention_mask = inputs["attention_mask"].to(device)
            pixel_values = inputs["pixel_values"].to(device, dtype=model.dtype)
            image_grid_thw = inputs["image_grid_thw"].to(device)

            img_positions = torch.nonzero(input_ids[0] == image_token_id, as_tuple=False).squeeze(-1)
            if img_positions.numel() == 0:
                raise RuntimeError("No image tokens found in input_ids.")
            # The image tokens are treated as one contiguous run starting here.
            cut_idx = int(img_positions[0].item())
            num_image_tokens = int(img_positions.shape[0])

            if args.use_dscr:
                # SECURITY NOTE: allow_pickle=True lets a crafted .npy file run
                # arbitrary code on load. Only point --depth-folder at trusted
                # data; switch to allow_pickle=False if the depth files are
                # plain numeric arrays.
                depth = np.load(depth_file, allow_pickle=True)

                _, grid_h, grid_w = image_grid_thw[0].tolist()
                merge_size = processor.image_processor.merge_size
                H_tokens = int(grid_h) // merge_size
                W_tokens = int(grid_w) // merge_size

                D = build_dscr_matrix(
                    depth=depth,
                    H_tokens=H_tokens,
                    W_tokens=W_tokens,
                    alpha=float(args.dscr_alpha),
                    beta=float(args.dscr_beta),
                    sigma=float(args.dscr_sigma),
                    self_keep=float(args.dscr_self_keep),
                    no_depth=args.dscr_no_depth,
                    no_spatial=args.dscr_no_spatial,
                )
                if D.shape[-1] != num_image_tokens:
                    # Fail with a readable message instead of an einsum shape error.
                    raise RuntimeError(
                        f"DSCR matrix covers {D.shape[-1]} tokens "
                        f"({H_tokens}x{W_tokens}) but the prompt has "
                        f"{num_image_tokens} image tokens."
                    )

                with torch.inference_mode():
                    outputs_clean = model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        pixel_values=pixel_values,
                        image_grid_thw=image_grid_thw,
                        use_cache=True,
                        output_attentions=False,
                        return_dict=True,
                    )

                refined_legacy = apply_D_to_cache(
                    cache=outputs_clean.past_key_values.to_legacy_cache(),
                    D=D,
                    cut_idx=cut_idx,
                    start_layer_idx=int(args.dscr_start_layer),
                    end_layer_idx=int(dscr_end_layer),
                    key_only=dscr_key_only_eff,
                    value_only=dscr_value_only_eff,
                    key_value=dscr_key_value_eff,
                    num_image_tokens=num_image_tokens,
                    dscr_lambda=float(args.dscr_lambda),
                )

                prefix_cache, _ = recompute_first_logits_from_refined_cache(
                    model=model,
                    input_ids=input_ids,
                    cache=refined_legacy,
                    attention_mask=attention_mask,
                )

                # NOTE(review): generate() receives the full prompt together
                # with the cache returned above -- confirm the transformers
                # copy in use supports this combination.
                with torch.inference_mode():
                    output_ids = model.generate(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        pixel_values=None,
                        image_grid_thw=None,
                        past_key_values=prefix_cache,
                        **gen_kwargs,
                    )
            else:
                with torch.inference_mode():
                    output_ids = model.generate(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        pixel_values=pixel_values,
                        image_grid_thw=image_grid_thw,
                        **gen_kwargs,
                    )

            # Keep only the tokens produced after the prompt.
            gen_start = input_ids.shape[1]
            outputs_str = processor.decode(
                output_ids[0][gen_start:],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            ).strip()

            pred_answer = outputs_str.split("\n")[0].strip()
            pred_cnt += 1

            if idx in gt_dict:
                total_acc += compute_vqa_accuracy(pred_answer, gt_dict[idx])
                total_cnt += 1
            else:
                print(f"[WARN] GT not found for question_id={idx}")

            ans_file.write(
                json.dumps(
                    {
                        "question_id": idx,
                        "prompt": prompt,
                        "text": outputs_str,
                        "pred_answer": pred_answer,
                        "model_id": args.model_path,
                        "image": image_file,
                        "metadata": {},
                    }
                )
                + "\n"
            )
            # Flushed per sample so partial results survive an interrupted run.
            ans_file.flush()

    print(f"[INFO] Total predictions written: {pred_cnt}")
    print(f"[INFO] Output saved to: {args.answers_file}")

    if total_cnt == 0:
        raise ValueError(
            "No questions matched with GT. Check gt_file format or question_file path."
        )

    avg_acc = total_acc / total_cnt * 100.0
    print(f"[RESULT] VQA accuracy on {total_cnt} samples: {avg_acc:.2f}%")
    print("[INFO] Done.")


def main():
    """Parse command-line options, seed the RNGs and run ``eval_model``.

    Every DSCR option and its legacy alias are registered as one argparse
    action, so both spellings share a single default. Flag combinations that
    ``build_dscr_matrix`` would reject are refused here, before any model is
    loaded.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--model-path", type=str, default="Qwen/Qwen2.5-VL-7B-Instruct")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--image-folder", type=str, default="/path/to/data")
    parser.add_argument("--question-file", type=str, default="/path/to/data")
    parser.add_argument(
        "--gt-file", type=str, default="/path/to/data",
        help="Answer subset file generated by vqa_v2_subset.py",
    )
    parser.add_argument("--answers-file", type=str, default="./output/VQA_v2/qwen2_5vl_vqa_answers_sample500.jsonl")
    parser.add_argument("--max-new-tokens", type=int, default=8)
    parser.add_argument("--dtype", type=str, default="bf16", choices=["bf16", "fp16"])
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--format", type=str, default="ow_format",
        choices=["no_format", "ow_format", "yn_format"],
    )
    parser.add_argument(
        "--device-map", type=str, default="auto",
        help="e.g. 'auto' or 'balanced'",
    )
    parser.add_argument(
        "--max-image-pixels", type=int, default=28 * 28 * 2048,
        help="Processor max_pixels for smart_resize (default 1605632≈1268^2).",
    )
    parser.add_argument(
        "--max-image-side", type=int, default=None,
        help="If set, resize image so max(w,h) <= this before processing.",
    )
    # NOTE(review): --do-sample together with the default --temperature 0.0 is
    # likely rejected by generate(); confirm and pass a positive temperature.
    parser.add_argument("--do-sample", action="store_true", default=False)
    parser.add_argument("--temperature", type=float, default=0.0)

    # DSCR (the second spelling of an option, where present, is an alias)
    parser.add_argument("--use_dscr", "--use-dscr", dest="use_dscr", action="store_true", default=False)
    parser.add_argument("--depth-folder", type=str, default="/path/to/data")
    parser.add_argument("--dscr-alpha", "--alpha", dest="dscr_alpha", type=float, default=0.6)
    parser.add_argument("--dscr-beta", "--beta", dest="dscr_beta", type=float, default=0.8)
    parser.add_argument("--dscr-sigma", "--sigma", dest="dscr_sigma", type=float, default=0.6)
    parser.add_argument("--dscr-start-layer", "--start-layer", dest="dscr_start_layer", type=int, default=16)
    parser.add_argument(
        "--dscr-end-layer", "--end-layer", dest="dscr_end_layer", type=int, default=None,
        help="Last layer to apply DSCR (None = total model layers).",
    )
    parser.add_argument("--dscr-lambda", dest="dscr_lambda", type=float, default=1.0)
    parser.add_argument(
        "--dscr-self-keep", "--dscr-keep-ratio", dest="dscr_self_keep", type=float, default=1.0,
    )
    parser.add_argument("--dscr-key-only", dest="dscr_key_only", action="store_true", default=False)
    parser.add_argument("--dscr-value-only", dest="dscr_value_only", action="store_true", default=False)
    parser.add_argument("--dscr-key-value", dest="dscr_key_value", action="store_true", default=False)
    parser.add_argument("--dscr-no-depth", dest="dscr_no_depth", action="store_true", default=False,
                        help="Spatial-only: remove depth term from D matrix")
    parser.add_argument("--dscr-no-spatial", dest="dscr_no_spatial", action="store_true", default=False,
                        help="Depth-only: remove spatial term from D matrix")

    args = parser.parse_args()

    # Fail at parse time instead of on the first sample, after the model loaded.
    if args.use_dscr and args.dscr_no_depth and args.dscr_no_spatial:
        parser.error("--dscr-no-depth and --dscr-no-spatial cannot both be set")

    set_seed(args.seed)
    eval_model(args)


if __name__ == "__main__":
    main()
