#!/usr/bin/env python
import argparse
import os
import json
import re
import sys
from typing import List

from tqdm import tqdm
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

# Import-path setup for running this file directly as a script.
# ROOT_DIR is two directories above the folder containing this file;
# EXPERIMENTS_DIR is one directory above that folder.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
EXPERIMENTS_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(ROOT_DIR)
sys.path.append(EXPERIMENTS_DIR)
# Inserted at position 0 so the vendored transformers-4.31.0 sources are found
# before any other `transformers` on sys.path.
sys.path.insert(0, os.path.join(EXPERIMENTS_DIR, "transformers-4.31.0", "src"))

from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
)
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path

from transformers import set_seed

# Image-token count used as the default segment length in apply_D_to_cache and
# as the fallback in eval_model when the vision tower exposes no `num_patches`.
NUM_IMAGE_TOKENS = 576  # 24 x 24


# -------------------------
# VQA answer normalization and scoring
# -------------------------
def _normalize_vqa_answer(ans: str) -> str:
    """Canonicalize a free-form answer string for exact-match VQA scoring.

    Steps, in order: trim and lower-case, keep only the first line, drop a
    leading "answer:" and/or "the answer is" prefix, replace ASCII
    punctuation with spaces, remove the articles a/an/the, and map the number
    words zero..ten to digits. Runs of whitespace collapse to single spaces.

    Args:
        ans: Raw answer text (model prediction or ground-truth annotation).

    Returns:
        The normalized answer; may be an empty string.
    """
    ans = ans.strip().lower()
    ans = ans.split("\n")[0]

    if ans.startswith("answer:"):
        ans = ans[len("answer:"):].strip()
    if ans.startswith("the answer is"):
        ans = ans[len("the answer is"):].strip()

    # Every ASCII punctuation character. Inside a raw-string character class
    # the hyphen, brackets and backslash must each be escaped; an unescaped
    # "]" would terminate the class early and the pattern would stop matching
    # ordinary punctuation such as "." or "!".
    punct = r'[!"#$%&\'()*+,\-./:;<=>?@\[\\\]^_`{|}~]'
    ans = re.sub(punct, " ", ans)
    ans = re.sub(r"\b(a|an|the)\b", " ", ans)

    number_map = {
        "zero": "0", "one": "1", "two": "2", "three": "3", "four": "4",
        "five": "5", "six": "6", "seven": "7", "eight": "8", "nine": "9", "ten": "10",
    }

    # split()/join() also collapses the spaces left behind by the substitutions.
    return " ".join(number_map.get(word, word) for word in ans.split())


def compute_vqa_accuracy(pred: str, gt_answers: List[str]) -> float:
    """Score one prediction with the simplified VQA metric min(1, matches / 3).

    The ground-truth answers decide how the prediction is reduced to a core
    answer before comparison:

    * all non-empty answers are "yes"/"no": the first of the whole words
      "yes" then "no" found in the prediction is used;
    * all non-empty answers are digits or number words zero..ten: the single
      number found in the prediction is used;
    * otherwise (or when the above finds nothing usable) the prediction is
      passed through ``_normalize_vqa_answer``.

    Args:
        pred: Model prediction text.
        gt_answers: Annotator answers; ``None`` entries are ignored.

    Returns:
        A score in [0, 1]. 0.0 when there is no non-empty ground-truth answer
        or the prediction normalizes to an empty string.
    """
    if not gt_answers:
        return 0.0

    # Filter missing annotations once so classification and matching below
    # operate on the same list (a None answer must never reach str methods).
    valid_gts = [a for a in gt_answers if a is not None]

    gt_set = {a.strip().lower() for a in valid_gts}
    gt_set.discard("")
    if not gt_set:
        # Nothing to match against. Returning here also avoids all() over an
        # empty set classifying the question as both yes/no and numeric.
        return 0.0

    number_words = {
        "zero": "0", "one": "1", "two": "2", "three": "3", "four": "4",
        "five": "5", "six": "6", "seven": "7", "eight": "8", "nine": "9", "ten": "10",
    }

    def to_digit(s: str):
        """Return s as a digit string, or None if it is not a known number."""
        s = s.strip().lower()
        if s.isdigit():
            return s
        return number_words.get(s)

    is_yn = gt_set <= {"yes", "no"}
    is_num = all(to_digit(a) is not None for a in gt_set)

    raw = pred.strip().lower()

    if is_yn:
        # Whole-word matching: a plain substring test would read "unknown",
        # "none" or "snow" as "no" and "eyes" as "yes".
        if re.search(r"\byes\b", raw):
            pred_core = "yes"
        elif re.search(r"\bno\b", raw):
            pred_core = "no"
        else:
            pred_core = _normalize_vqa_answer(pred)
    elif is_num:
        digits = re.findall(r"\d+", raw)
        if len(digits) == 1:
            pred_core = digits[0]
        else:
            # No unambiguous digit run; look for exactly one number token.
            converted = (to_digit(tok) for tok in raw.split())
            cand = [d for d in converted if d is not None]
            if len(cand) == 1:
                pred_core = cand[0]
            else:
                pred_core = _normalize_vqa_answer(pred)
    else:
        pred_core = _normalize_vqa_answer(pred)

    if len(pred_core) == 0:
        return 0.0

    norm_gts = [_normalize_vqa_answer(a) for a in valid_gts]
    if is_num:
        # Bring ground truths to digit form so "two" and "2" compare equal.
        norm_gts = [to_digit(g) or g for g in norm_gts]

    match_count = sum(1 for g in norm_gts if g == pred_core)
    return min(1.0, match_count / 3.0)


# -------------------------
# DSCR KV-cache refinement helpers
# -------------------------
def apply_D_to_cache(
    cache,
    D,
    cut_idx: int,
    start_layer_idx: int,
    end_layer_idx: int,
    key_only: bool,
    value_only: bool,
    key_value: bool,
    num_image_tokens: int = NUM_IMAGE_TOKENS,
    dscr_lambda: float = 1.0,
):
    """Blend the image-token slice of a KV cache with its D-refined version.

    For every layer whose index lies in [start_layer_idx, end_layer_idx), the
    positions cut_idx .. cut_idx + num_image_tokens along dim 2 of the key
    and/or value tensor are replaced by
    ``(1 - lambda) * seg + lambda * (D @ seg)``. Layers outside that range are
    passed through untouched.

    Args:
        cache: Iterable of per-layer (key, value) pairs.
        D: Refinement matrix, 2-D, or 4-D with two leading dims that are
            squeezed away.
        cut_idx: Index along dim 2 where the image segment starts.
        start_layer_idx: First layer to refine (inclusive).
        end_layer_idx: End of the refined layer range (exclusive).
        key_only: Refine keys.
        value_only: Refine values.
        key_value: Refine both keys and values.
        num_image_tokens: Length of the image segment.
        dscr_lambda: Mixing weight, clamped to [0, 1].

    Returns:
        A tuple of (key, value) pairs, one per input layer.

    Raises:
        ValueError: If D is neither 2-D nor 4-D.
    """
    if D.dim() == 4:
        D = D.squeeze(0).squeeze(0)
    elif D.dim() != 2:
        raise ValueError(f"Unsupported D shape: {tuple(D.shape)}")

    first_layer = int(start_layer_idx)
    stop_layer = int(end_layer_idx)
    mix = float(max(0.0, min(1.0, dscr_lambda)))
    seg_end = cut_idx + num_image_tokens

    refine_keys = key_only or key_value
    refine_values = value_only or key_value

    # D converted once per (device, dtype); layers may live on different devices.
    converted = {}

    def _refine(tensor):
        """Return tensor with its image segment mixed with the refined one."""
        spec = (tensor.device, tensor.dtype)
        D_local = converted.get(spec)
        if D_local is None:
            D_local = D.to(device=tensor.device, dtype=tensor.dtype)
            converted[spec] = D_local
        seg = tensor[:, :, cut_idx:seg_end, :]
        seg_refined = torch.einsum("ij,bhjd->bhid", D_local, seg)
        seg_mixed = (1.0 - mix) * seg + mix * seg_refined
        return torch.cat(
            (tensor[:, :, :cut_idx, :], seg_mixed, tensor[:, :, seg_end:, :]),
            dim=2,
        )

    refined_layers = []
    for layer_idx, (k, v) in enumerate(cache):
        if first_layer <= layer_idx < stop_layer:
            if refine_keys:
                k = _refine(k)
            if refine_values:
                v = _refine(v)
        refined_layers.append((k, v))

    return tuple(refined_layers)


def recompute_first_logits_from_refined_cache(model, input_ids, cache):
    """Re-run the last prompt token against a refined KV cache.

    The final cached position of every layer is dropped and the last token of
    ``input_ids`` is forwarded again on top of the remaining prefix, so that
    its cache entry and the next-token logits are computed from the refined
    cache.

    Args:
        model: Callable model accepting ``images``, ``use_cache``,
            ``past_key_values`` and ``return_dict`` keyword arguments.
        input_ids: Prompt token ids; only the last column is forwarded.
        cache: Iterable of per-layer (key, value) pairs.

    Returns:
        ``(past_key_values, logits)`` where logits are those of the last
        position of the model output.
    """
    # Strip the last position along dim 2 from each layer's key and value.
    prefix_cache = tuple(
        (
            layer_kv[0][:, :, :-1, :].contiguous(),
            layer_kv[1][:, :, :-1, :].contiguous(),
        )
        for layer_kv in cache
    )

    final_token = input_ids[:, -1:]
    with torch.inference_mode():
        result = model(
            final_token,
            images=None,
            use_cache=True,
            past_key_values=prefix_cache,
            return_dict=True,
        )

    return result.past_key_values, result.logits[:, -1, :]


def build_dscr_matrix(
    depth: np.ndarray,
    num_image_tokens: int,
    alpha: float,
    beta: float,
    sigma: float,
    self_keep: float,
    device: torch.device,
    no_depth: bool = False,
    no_spatial: bool = False,
) -> torch.Tensor:
    """Build the row-normalized DSCR mixing matrix over image tokens.

    Two Gaussian affinities (sharing ``sigma``) are computed between every
    pair of tokens on a square grid: one from the difference of normalized
    inverse depth, one from the distance between normalized grid positions.
    They are raised to ``alpha`` / ``beta`` and summed (or used alone for the
    ablations), the diagonal is optionally overwritten with ``self_keep``,
    and each row is divided by its sum.

    Args:
        depth: Depth map; resized bilinearly to the token grid.
        num_image_tokens: Number of image tokens; the grid side is the
            integer part of its square root.
        alpha: Exponent applied to the depth affinity.
        beta: Exponent applied to the spatial affinity.
        sigma: Gaussian bandwidth for both affinities.
        self_keep: Diagonal weight, clamped to [0, 1]; values below 1 replace
            the diagonal before normalization.
        device: Device on which the matrix is built.
        no_depth: Use the spatial affinity only.
        no_spatial: Use the depth affinity only.

    Returns:
        A (num_image_tokens, num_image_tokens) float32 tensor.

    Raises:
        ValueError: If both ``no_depth`` and ``no_spatial`` are set.
    """
    if no_depth and no_spatial:
        raise ValueError("--dscr-no-depth and --dscr-no-spatial cannot both be set")

    side = int(num_image_tokens ** 0.5)
    keep = float(max(0.0, min(1.0, self_keep)))
    gauss_denom = 2 * sigma ** 2 + 1e-12

    # --- depth affinity -------------------------------------------------
    depth_map = torch.tensor(depth, dtype=torch.float32, device=device)[None, None]
    grid_depth = F.interpolate(depth_map, size=(side, side), mode="bilinear", align_corners=False)
    # Work with clamped inverse depth, min-max scaled to [0, 1).
    inv_depth = torch.clamp(1.0e-6 + 1.0 / grid_depth, 0.001, 1000).view(-1)
    lo, hi = inv_depth.min(), inv_depth.max()
    inv_depth = (inv_depth - lo) / (hi - lo + 1e-6)
    depth_gap = torch.abs(inv_depth.unsqueeze(0) - inv_depth.unsqueeze(1))
    gw_depth = torch.exp(-(depth_gap ** 2) / gauss_denom)

    # --- spatial affinity -----------------------------------------------
    # (row, col) of each token in row-major order, scaled by (side - 1).
    coords = torch.tensor(
        [divmod(i, side) for i in range(num_image_tokens)],
        dtype=torch.float32, device=device,
    )
    coords = coords / float(max(side - 1, 1))
    grid_dist = torch.cdist(coords, coords, p=2)
    gw_pos = torch.exp(-(grid_dist ** 2) / gauss_denom)

    # --- combine --------------------------------------------------------
    if no_depth:
        weights = gw_pos ** beta
    elif no_spatial:
        weights = gw_depth ** alpha
    else:
        weights = (gw_depth ** alpha) + (gw_pos ** beta)

    if keep < 1.0:
        identity = torch.eye(num_image_tokens, dtype=weights.dtype, device=device)
        weights = weights * (1.0 - identity) + keep * identity

    totals = weights.sum(dim=-1, keepdim=True)
    empty_rows = totals < 1e-12
    if empty_rows.any():
        # Rows with no mass fall back to uniform weights.
        weights[empty_rows.squeeze(-1)] = 1.0
        totals = weights.sum(dim=-1, keepdim=True)

    return weights / (totals + 1e-6)


# -------------------------
# Evaluation loop
# -------------------------
def _load_json_or_jsonl(path):
    """Read records from a file that holds either a JSON array or JSON Lines."""
    with open(path, "r", encoding="utf-8") as f:
        first_char = f.read(1)
        f.seek(0)
        if first_char == "[":
            return json.load(f)
        return [json.loads(line) for line in f if line.strip()]


def _load_depth_map(path):
    """Load a depth array from a .npy file, unwrapping a saved {"depth": ...} dict."""
    # NOTE(review): allow_pickle=True unpickles file contents; only point
    # --depth-folder at trusted data.
    depth = np.load(path, allow_pickle=True)
    # A dict written with np.save comes back as a 0-d object array, not as a
    # dict, so it has to be unwrapped before the dict check can succeed.
    if isinstance(depth, np.ndarray) and depth.dtype == object and depth.ndim == 0:
        depth = depth.item()
    if isinstance(depth, dict) and "depth" in depth:
        depth = depth["depth"]
    return depth


def eval_model(args):
    """Run LLaVA VQA inference over a question file and report accuracy.

    For each question the image is loaded, a prompt is built from the
    conversation template, and an answer is generated either directly or,
    with ``args.use_dscr``, after refining the prefill KV cache with a DSCR
    matrix built from the image's depth map. Every prediction is appended to
    ``args.answers_file`` as one JSON line and scored against the ground
    truth in ``args.gt_file``.

    Questions whose image (or, with DSCR, depth file) is missing are skipped
    with a warning and are not counted in the accuracy.

    Args:
        args: Parsed command-line namespace (see ``main``).

    Raises:
        ValueError: If the ground-truth file is missing, or if no processed
            question has a ground-truth entry.
        RuntimeError: If the tokenized prompt contains no image token.
    """
    disable_torch_init()

    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)

    print(f"[INFO] Load LLaVA model from {model_path} (device_map='auto')")
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path,
        args.model_base,
        model_name,
        load_8bit=False,
        load_4bit=False,
        device_map="auto",
    )

    model_dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16

    # Fall back to the device of the first parameter when the model object
    # does not provide a usable `.device`.
    try:
        model_device = model.device
    except Exception:
        model_device = next(model.parameters()).device

    vision_tower = model.get_vision_tower() if hasattr(model, "get_vision_tower") else None
    image_token_len = (
        vision_tower.num_patches
        if vision_tower is not None and hasattr(vision_tower, "num_patches")
        else NUM_IMAGE_TOKENS
    )

    num_hidden_layers = getattr(model.config, "num_hidden_layers", 32)
    dscr_end_layer = args.dscr_end_layer if args.dscr_end_layer is not None else num_hidden_layers

    print(f"[INFO] Load questions from {args.question_file}")
    questions = _load_json_or_jsonl(args.question_file)
    print(f"[INFO] Number of questions loaded: {len(questions)}")

    if args.gt_file is None or not os.path.exists(args.gt_file):
        raise ValueError(f"gt_file not found or invalid path: {args.gt_file}")

    print(f"[INFO] Load VQA GT from {args.gt_file}")
    gt_dict = {}
    for ent in _load_json_or_jsonl(args.gt_file):
        qid = ent["question_id"]
        raw_answers = ent["answers"]
        # Answers are either {"answer": ...} dicts or plain strings.
        if len(raw_answers) > 0 and isinstance(raw_answers[0], dict):
            ans_list = [a["answer"] for a in raw_answers]
        else:
            ans_list = list(raw_answers)
        gt_dict[qid] = ans_list
    print(f"[INFO] Loaded GT for {len(gt_dict)} questions")

    # dirname is "" for a bare file name, and os.makedirs("") raises.
    answers_dir = os.path.dirname(args.answers_file)
    if answers_dir:
        os.makedirs(answers_dir, exist_ok=True)

    total_cnt = 0
    total_acc = 0.0
    pred_cnt = 0

    if args.format == "yn_format":
        fmt_suffix = " Please only answer yes or no."
    elif args.format == "ow_format":
        fmt_suffix = " Please answer this question with one word."
    else:
        fmt_suffix = ""

    if args.use_dscr:
        # Which cache tensors to refine; with no explicit choice, refine both.
        dscr_key_only_eff = bool(args.dscr_key_only)
        dscr_value_only_eff = bool(args.dscr_value_only)
        dscr_key_value_eff = bool(args.dscr_key_value) or (
            not dscr_key_only_eff and not dscr_value_only_eff
        )

    print(f"[INFO] Start inference. Image folder: {args.image_folder}")
    print(f"[INFO] use_dscr = {args.use_dscr}")

    # The context manager closes the answers file even if inference raises.
    with open(args.answers_file, "w", encoding="utf-8") as ans_file:
        for line in tqdm(questions):
            qid = line["question_id"]
            image_file = line["image"]
            qs = line["text"]
            cur_prompt = qs

            image_path = os.path.join(args.image_folder, image_file)
            if not os.path.exists(image_path):
                print(f"[WARN] image not found: {image_path}")
                continue

            image = Image.open(image_path).convert("RGB")

            if model.config.mm_use_im_start_end:
                qs_full = (
                    DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
                    + "\n" + qs + fmt_suffix
                )
            else:
                qs_full = DEFAULT_IMAGE_TOKEN + "\n" + qs + fmt_suffix

            conv = conv_templates[args.conv_mode].copy()
            conv.append_message(conv.roles[0], qs_full)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()

            input_ids = tokenizer_image_token(
                prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt",
            ).unsqueeze(0).to(model_device)

            image_tensor = image_processor.preprocess(image, return_tensors="pt")["pixel_values"].to(
                device=model_device,
                dtype=model_dtype,
            )

            # Position of the image placeholder; used as the start of the
            # image segment in the KV cache.
            pos = torch.where(input_ids[0] == IMAGE_TOKEN_INDEX)[0]
            if pos.numel() == 0:
                raise RuntimeError("IMAGE_TOKEN_INDEX not found in input_ids")
            cut_idx = int(pos[0].item())

            if args.use_dscr:
                depth_file = os.path.join(
                    args.depth_folder,
                    os.path.splitext(image_file)[0] + ".npy",
                )
                if not os.path.exists(depth_file):
                    print(f"[WARN] depth file not found: {depth_file}")
                    continue

                depth = _load_depth_map(depth_file)

                D = build_dscr_matrix(
                    depth=depth,
                    num_image_tokens=image_token_len,
                    alpha=float(args.dscr_alpha),
                    beta=float(args.dscr_beta),
                    sigma=float(args.dscr_sigma),
                    self_keep=float(args.dscr_self_keep),
                    device=model_device,
                    no_depth=args.dscr_no_depth,
                    no_spatial=args.dscr_no_spatial,
                )

                # Prefill: one forward pass over the whole prompt to get the cache.
                attention_mask = torch.ones_like(input_ids, device=model_device)
                with torch.inference_mode():
                    prefill_outputs = model(
                        input_ids=input_ids,
                        images=image_tensor,
                        attention_mask=attention_mask,
                        use_cache=True,
                        output_attentions=False,
                        return_dict=True,
                    )

                past_key_values = apply_D_to_cache(
                    cache=prefill_outputs.past_key_values,
                    D=D,
                    cut_idx=cut_idx,
                    start_layer_idx=int(args.dscr_start_layer),
                    end_layer_idx=int(dscr_end_layer),
                    key_only=dscr_key_only_eff,
                    value_only=dscr_value_only_eff,
                    key_value=dscr_key_value_eff,
                    num_image_tokens=image_token_len,
                    dscr_lambda=float(args.dscr_lambda),
                )

                past_key_values, _ = recompute_first_logits_from_refined_cache(
                    model=model,
                    input_ids=input_ids,
                    cache=past_key_values,
                )

                # NOTE(review): generate() receives the full prompt together
                # with a cache that already covers it; this presumably relies
                # on the model's generation input preparation keeping only the
                # last token. Confirm against the vendored transformers/LLaVA.
                with torch.inference_mode():
                    output_ids = model.generate(
                        input_ids=input_ids,
                        images=image_tensor,
                        use_cache=True,
                        past_key_values=past_key_values,
                        do_sample=args.do_sample,
                        temperature=float(args.temperature),
                        max_new_tokens=args.max_new_tokens,
                    )

            else:
                with torch.inference_mode():
                    output_ids = model.generate(
                        input_ids=input_ids,
                        images=image_tensor,
                        use_cache=True,
                        do_sample=args.do_sample,
                        temperature=float(args.temperature),
                        max_new_tokens=args.max_new_tokens,
                    )

            # Decode only the tokens after the prompt.
            input_token_len = input_ids.shape[1]
            outputs = tokenizer.batch_decode(
                output_ids[:, input_token_len:],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )[0].strip()

            stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
            if outputs.endswith(stop_str):
                outputs = outputs[:-len(stop_str)].strip()

            pred_answer = outputs.split("\n")[0].strip()
            pred_cnt += 1

            if qid in gt_dict:
                gt_answers = gt_dict[qid]
                acc = compute_vqa_accuracy(pred_answer, gt_answers)
                total_acc += acc
                total_cnt += 1
            else:
                print(f"[WARN] GT not found for question_id={qid}")

            ans_file.write(
                json.dumps(
                    {
                        "question_id": qid,
                        "prompt": cur_prompt + fmt_suffix,
                        "text": outputs,
                        "pred_answer": pred_answer,
                        "model_id": model_name,
                        "image": image_file,
                        "metadata": {},
                    }
                )
                + "\n"
            )
            # Flush per record so partial results survive an interrupted run.
            ans_file.flush()

    print(f"[INFO] Total predictions written: {pred_cnt}")
    print(f"[INFO] Output saved to: {args.answers_file}")

    if total_cnt == 0:
        raise ValueError("No questions matched with GT. Check gt_file / question_file paths.")

    avg_acc = total_acc / total_cnt * 100.0
    print(f"[RESULT] VQA accuracy on {total_cnt} samples: {avg_acc:.2f}%")
    print("[INFO] Done.")


def main(argv=None):
    """Parse command-line options, seed the RNGs, and run the evaluation.

    Args:
        argv: Optional argument list; ``None`` makes argparse read
            ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--model-path", type=str, default="liuhaotian/llava-v1.5-13b")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--image-folder", type=str, default="/path/to/data")
    parser.add_argument("--question-file", type=str, default="/path/to/data")
    parser.add_argument(
        "--gt-file", type=str, default="/path/to/data",
        help="Answer subset file generated by vqa_v2_subset.py",
    )
    parser.add_argument("--answers-file", type=str, default="./output/VQA_v2/llava15_vqa_answers_sample500.jsonl")
    parser.add_argument("--conv-mode", type=str, default="vicuna_v1")
    parser.add_argument("--max-new-tokens", type=int, default=8)
    parser.add_argument("--dtype", type=str, default="fp16", choices=["bf16", "fp16"])
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--format", type=str, default="ow_format",
        choices=["no_format", "ow_format", "yn_format"],
    )
    parser.add_argument("--do-sample", action="store_true", default=False)
    parser.add_argument("--temperature", type=float, default=0.0)

    # DSCR options. Legacy spellings (--alpha, --beta, --sigma, --start-layer,
    # --end-layer, --dscr_lambda, --dscr_self_keep, --dscr-keep-ratio) are
    # declared as extra option strings of the same argument.
    parser.add_argument("--use_dscr", action="store_true", default=False)
    parser.add_argument("--depth-folder", type=str, default="/path/to/data")
    parser.add_argument("--dscr-alpha", "--alpha", dest="dscr_alpha", type=float, default=0.6)
    parser.add_argument("--dscr-beta", "--beta", dest="dscr_beta", type=float, default=0.8)
    parser.add_argument("--dscr-sigma", "--sigma", dest="dscr_sigma", type=float, default=0.6)
    parser.add_argument(
        "--dscr-start-layer", "--start-layer", dest="dscr_start_layer", type=int, default=10,
    )
    parser.add_argument(
        "--dscr-end-layer", "--end-layer", dest="dscr_end_layer", type=int, default=None,
        help="Exclusive end layer index for DSCR (default: total number of model layers).",
    )
    parser.add_argument(
        "--dscr-lambda", "--dscr_lambda", dest="dscr_lambda", type=float, default=1.0,
    )
    parser.add_argument(
        "--dscr-self-keep", "--dscr-keep-ratio", "--dscr_self_keep",
        dest="dscr_self_keep", type=float, default=1.0,
    )
    parser.add_argument("--dscr-key-only", dest="dscr_key_only", action="store_true", default=False)
    parser.add_argument("--dscr-value-only", dest="dscr_value_only", action="store_true", default=False)
    parser.add_argument("--dscr-key-value", dest="dscr_key_value", action="store_true", default=False)
    parser.add_argument("--dscr-no-depth", dest="dscr_no_depth", action="store_true", default=False,
                        help="Spatial-only: remove depth term from D matrix")
    parser.add_argument("--dscr-no-spatial", dest="dscr_no_spatial", action="store_true", default=False,
                        help="Depth-only: remove spatial term from D matrix")

    args = parser.parse_args(argv)

    # Reject the contradictory ablation combination before the model is
    # loaded instead of failing on the first sample.
    if args.use_dscr and args.dscr_no_depth and args.dscr_no_spatial:
        parser.error("--dscr-no-depth and --dscr-no-spatial cannot both be set")

    set_seed(args.seed)

    print("llava15_vqa.py")
    print(f"use_dscr: {args.use_dscr}")

    eval_model(args)


# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
