import argparse
import torch
import os
import json
from tqdm import tqdm
import sys
import numpy as np

ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
EXPERIMENTS_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(ROOT_DIR)
sys.path.append(EXPERIMENTS_DIR)
sys.path.insert(0, os.path.join(EXPERIMENTS_DIR, "transformers-4.31.0", "src"))

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria

from PIL import Image

from transformers import set_seed
from vcd_utils.vcd_add_noise import add_diffusion_noise
from vcd_utils.vcd_sample import evolve_vcd_sampling

NUM_IMAGE_TOKENS = 576


def recompute_first_logits_from_refined_cache(model, input_ids, cache):
    """Re-run the final prompt token against a (refined) KV cache.

    Every key/value tensor in ``cache`` has its last position along dim 2
    removed, and ``model`` is then called on the last column of ``input_ids``
    with that shortened cache (and ``images=None``), so the entry for the
    final token is produced again from the refined prefix.

    Args:
        model: Callable invoked as ``model(ids, images=..., use_cache=...,
            past_key_values=..., return_dict=...)``; its result must expose
            ``past_key_values`` and ``logits``.
        input_ids: 2-D tensor of token ids; only ``input_ids[:, -1:]`` is used.
        cache: Iterable of per-layer ``(key, value)`` pairs of 4-D tensors.
            Dim 2 is sliced as the token axis -- presumably
            (batch, heads, seq, head_dim); confirm against the model.

    Returns:
        Tuple ``(past_key_values, logits)`` where ``logits`` is
        ``outputs.logits[:, -1, :]`` from the single-token forward pass.
    """
    # Strip the last cached position from each layer in a single pass.
    trimmed_cache = tuple(
        (
            layer_kv[0][:, :, :-1, :].contiguous(),
            layer_kv[1][:, :, :-1, :].contiguous(),
        )
        for layer_kv in cache
    )

    final_token = input_ids[:, -1:]
    with torch.inference_mode():
        result = model(
            final_token,
            images=None,
            use_cache=True,
            past_key_values=trimmed_cache,
            return_dict=True,
        )

    refreshed_cache = result.past_key_values
    first_logits = result.logits[:, -1, :]
    return refreshed_cache, first_logits


def _blend_image_segment(tensor, mix, cut_idx, weight):
    """Return ``tensor`` with positions ``[cut_idx, cut_idx + N)`` of dim 2 mixed by ``mix``."""
    num_tokens = mix.shape[0]
    seq_len = tensor.shape[2]
    if cut_idx < 0 or cut_idx + num_tokens > seq_len:
        raise ValueError(
            f"Image segment [{cut_idx}, {cut_idx + num_tokens}) does not fit in a "
            f"cache of length {seq_len}"
        )

    seg = tensor[:, :, cut_idx:cut_idx + num_tokens, :]
    # Row i of the refined segment is the mix[i, :]-weighted sum of the
    # original segment rows.
    seg_refined = torch.einsum("ij,bhjd->bhid", mix, seg)
    seg_mixed = (1.0 - weight) * seg + weight * seg_refined
    return torch.cat(
        (tensor[:, :, :cut_idx, :], seg_mixed, tensor[:, :, cut_idx + num_tokens:, :]),
        dim=2,
    )


def apply_D_to_cache(cache, D_576_576, cut_idx, start_layer_idx, end_layer_idx, key_only, value_only, key_value, dscr_lambda=1.0):
    """Blend the image-token span of a KV cache with a token-mixing matrix.

    For every layer ``l`` with ``start_layer_idx <= l < end_layer_idx`` the
    span ``[cut_idx, cut_idx + N)`` along dim 2 of the keys and/or values is
    replaced by ``(1 - lambda) * seg + lambda * (D @ seg)``. Other layers and
    other positions are passed through unchanged. The input tensors are not
    modified; new tensors are built with ``torch.cat``.

    Args:
        cache: Iterable of per-layer ``(key, value)`` pairs of 4-D tensors
            indexed as ``[b, h, token, d]``.
        D_576_576: Mixing matrix, either ``(N, N)`` or ``(1, 1, N, N)``.
            Despite the parameter name, ``N`` is taken from the matrix itself
            rather than from a hard-coded token count.
        cut_idx: Index along dim 2 where the image-token span starts.
        start_layer_idx: First layer (inclusive) to refine.
        end_layer_idx: Last layer bound (exclusive).
        key_only: Refine keys.
        value_only: Refine values.
        key_value: Refine both keys and values.
        dscr_lambda: Blend weight, clamped to ``[0, 1]``.

    Returns:
        Tuple of ``(key, value)`` pairs, one per input layer.

    Raises:
        ValueError: If ``D_576_576`` is not a square 2-D matrix (after
            dropping two leading singleton dims from a 4-D input), or if the
            span does not fit inside a tensor that is to be refined.
    """
    if D_576_576.dim() == 4:
        mix = D_576_576.squeeze(0).squeeze(0)
    elif D_576_576.dim() == 2:
        mix = D_576_576
    else:
        raise ValueError(f"Unsupported D shape: {tuple(D_576_576.shape)}")
    if mix.dim() != 2 or mix.shape[0] != mix.shape[1]:
        raise ValueError(f"Unsupported D shape: {tuple(D_576_576.shape)}")

    cut_idx = int(cut_idx)
    start_layer_idx = int(start_layer_idx)
    end_layer_idx = int(end_layer_idx)
    dscr_lambda = float(max(0.0, min(1.0, dscr_lambda)))

    refine_keys = bool(key_only or key_value)
    refine_values = bool(value_only or key_value)

    # Cast the mixing matrix at most once per (device, dtype) combination;
    # layers may live on different devices.
    cast_cache = {}

    def _mix_for(tensor):
        tag = (tensor.device, tensor.dtype)
        if tag not in cast_cache:
            cast_cache[tag] = mix.to(device=tensor.device, dtype=tensor.dtype)
        return cast_cache[tag]

    out = []
    for layer_idx, (k, v) in enumerate(cache):
        if start_layer_idx <= layer_idx < end_layer_idx:
            if refine_keys:
                k = _blend_image_segment(k, _mix_for(k), cut_idx, dscr_lambda)
            if refine_values:
                v = _blend_image_segment(v, _mix_for(v), cut_idx, dscr_lambda)
        out.append((k, v))

    return tuple(out)


def str2bool(x):
    """Parse a command-line style boolean.

    ``bool`` inputs are returned unchanged. Anything else is converted with
    ``str()``, stripped and lower-cased, then matched against the accepted
    spellings ``1/true/t/yes/y`` and ``0/false/f/no/n``.

    Raises:
        argparse.ArgumentTypeError: If the text matches neither set.
    """
    if isinstance(x, bool):
        return x

    spellings = {
        "1": True, "true": True, "t": True, "yes": True, "y": True,
        "0": False, "false": False, "f": False, "no": False, "n": False,
    }
    token = str(x).strip().lower()
    if token not in spellings:
        raise argparse.ArgumentTypeError(f"Invalid bool value: {x}")
    return spellings[token]


def clamp01(x: float) -> float:
    """Limit ``x`` to the closed interval [0, 1].

    Values below 0.0 become 0.0 and values above 1.0 become 1.0; anything
    else (including values that compare neither below nor above, such as
    NaN) is returned as given.
    """
    return 0.0 if x < 0.0 else (1.0 if x > 1.0 else x)


def _load_questions(question_file):
    """Read questions from a JSON array file or a JSON-lines file."""
    with open(question_file, "r", encoding="utf-8") as f:
        # A leading "[" means one JSON array; otherwise one JSON object per line.
        first = f.read(1)
        f.seek(0)
        if first == "[":
            return json.load(f)
        return [json.loads(line) for line in f if line.strip()]


def _dscr_position_term(sigma, beta):
    """Return the (576, 576) spatial Gaussian affinity of the 24x24 grid, raised to ``beta``."""
    grid = torch.tensor(
        [[j // 24, j % 24] for j in range(NUM_IMAGE_TOKENS)],
        dtype=torch.float32,
    ) / 23.0  # (576, 2), coordinates scaled to [0, 1]
    position_diff = torch.cdist(grid, grid, p=2)  # (576, 576)
    affinity = torch.exp(-(position_diff ** 2) / (2 * sigma ** 2 + 1e-12))
    return affinity ** beta


def _dscr_mixing_matrix(depth, position_term, sigma, alpha):
    """Combine depth affinity with ``position_term`` into a row-normalized (576, 576) matrix."""
    # assumes depth is a 2-D (H, W) array -- the two unsqueezes build the
    # 4-D input that interpolate() needs; TODO confirm against the depth files.
    depth_tensor = torch.tensor(depth, dtype=torch.float32).unsqueeze(0).unsqueeze(1)
    depth_patch = torch.nn.functional.interpolate(
        depth_tensor, size=(24, 24), mode="bilinear", align_corners=False
    )  # (1, 1, 24, 24)

    # Reciprocal of the resized depth, clamped, then min-max scaled to [0, 1).
    depth_patch = torch.clamp(1.0e-6 + 1.0 / depth_patch, 0.001, 1000).view(-1)  # (576,)
    depth_patch = (depth_patch - depth_patch.min()) / (depth_patch.max() - depth_patch.min() + 1e-6)

    depth_diff = torch.abs(depth_patch.unsqueeze(0) - depth_patch.unsqueeze(1))  # (576, 576)
    depth_affinity = torch.exp(-(depth_diff ** 2) / (2 * sigma ** 2 + 1e-12))

    weight = (depth_affinity ** alpha) + position_term
    return weight / (weight.sum(dim=-1, keepdim=True) + 1e-6)


def _normalize_response(idx, outputs):
    """Collapse answers for integer ids >= 1005 to "Yes"/"No" when they start with y/n."""
    # NOTE(review): 1005 is presumably the first yes/no question id of the
    # benchmark -- confirm against the question file.
    if isinstance(idx, int) and idx >= 1005:
        low = outputs.lower().strip()
        if low.startswith("y"):
            return "Yes"
        if low.startswith("n"):
            return "No"
    return outputs


def eval_model(args):
    """Generate one response per question and write them all to a JSON file.

    Loads the model named by ``args.model_path``, reads the questions from
    ``args.question_file`` (JSON array or JSON lines; each item needs ``id``,
    ``image`` and ``query``), runs generation for each one, and finally dumps
    a list of ``{"id", "response"}`` dicts to ``args.answers_file``.

    Decoding variants, selected by flags on ``args``:
      * ``use_vcd``   -- patches sampling via ``evolve_vcd_sampling()`` and
        passes a noised copy of the image as ``images_cd``.
      * ``use_dscr``  -- pre-fills the KV cache, refines its image-token span
        with a depth/position mixing matrix, and generates from that cache.
      * ``use_opera`` -- passes the OPERA arguments (beam size, key positions,
        penalties) to ``generate``.

    Raises:
        FileNotFoundError: If DSCR is on and the ``.npy`` depth file for an
            image is missing from ``args.depth_folder``.
        RuntimeError: If the tokenized prompt contains no image placeholder.
    """
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)

    questions = _load_questions(args.question_file)

    answers_file = os.path.expanduser(args.answers_file)
    answers_dir = os.path.dirname(answers_file)
    # A bare file name has an empty dirname, and os.makedirs("") raises.
    if answers_dir:
        os.makedirs(answers_dir, exist_ok=True)

    dscr_alpha_eff = float(args.dscr_alpha)
    dscr_beta_eff = float(args.dscr_beta)
    dscr_sigma_eff = float(args.dscr_sigma)
    dscr_start_layer_eff = int(args.dscr_start_layer)
    dscr_end_layer_eff = int(args.dscr_end_layer)
    dscr_key_only_eff = bool(args.dscr_key_only)
    dscr_value_only_eff = bool(args.dscr_value_only)
    # Fall back to refining both keys and values when neither "only" flag is set.
    dscr_key_value_eff = bool(args.dscr_key_value) or (not dscr_key_only_eff and not dscr_value_only_eff)

    # The spatial term depends only on sigma/beta, so build it once instead
    # of once per question.
    position_term = _dscr_position_term(dscr_sigma_eff, dscr_beta_eff) if args.use_dscr else None

    if args.use_vcd:
        evolve_vcd_sampling()

    responses = []

    for i, line in enumerate(tqdm(questions)):
        idx = line["id"]
        image_file = line["image"]
        qs = line["query"]

        if args.use_dscr:
            # Fail before any model work if the depth map is missing.
            depth_file = os.path.join(args.depth_folder, os.path.splitext(image_file)[0] + ".npy")
            if not os.path.isfile(depth_file):
                raise FileNotFoundError(f"Depth file not found: {depth_file}")

        if model.config.mm_use_im_start_end:
            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + qs
        else:
            qs = DEFAULT_IMAGE_TOKEN + "\n" + qs

        conv = conv_templates[args.conv_mode].copy()
        conv.append_message(conv.roles[0], qs + args.format)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(
            prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
        ).unsqueeze(0).cuda()

        # Position of the image placeholder inside the (unexpanded) prompt.
        image_token_pos = torch.where(input_ids[0] == IMAGE_TOKEN_INDEX)[0]
        if image_token_pos.numel() < 1:
            raise RuntimeError(f"No image token placeholder ({IMAGE_TOKEN_INDEX}) found in input_ids.")
        cut_idx = int(image_token_pos[0].item())

        # Close the file handle as soon as the pixels are decoded.
        with Image.open(os.path.join(args.image_folder, image_file)) as raw_image:
            image = raw_image.convert("RGB")
        image_tensor = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
        images = image_tensor.unsqueeze(0).half().cuda()

        if args.use_vcd:
            noisy = add_diffusion_noise(image_tensor, args.noise_step).unsqueeze(0).half().cuda()
            # NOTE(review): the original code unsqueezes a second time before
            # generate(), so images_cd has one more leading dim than `images`.
            # Kept as-is; confirm the model's image path expects this.
            images_cd = noisy.unsqueeze(0)
        else:
            images_cd = None

        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2

        # Pre-fill KV cache only when DSCR or OPERA is active.
        # For pure baseline / pure VCD, let generate() handle the first forward
        # internally to avoid RoPE off-by-one from LLaVA image-token expansion.
        kv_outputs = None
        if args.use_dscr or args.use_opera:
            with torch.inference_mode():
                kv_outputs = model(
                    input_ids,
                    images=images,
                    use_cache=True,
                    output_attentions=True,
                    return_dict=True,
                )

        prefilled_cache = None
        if args.use_dscr:
            # NOTE(review): allow_pickle=True lets np.load execute pickled
            # objects; only use depth files from a trusted source.
            depth = np.load(depth_file, allow_pickle=True)

            reference = kv_outputs.past_key_values[0][0]
            D = _dscr_mixing_matrix(depth, position_term, dscr_sigma_eff, dscr_alpha_eff)  # (576, 576)
            D = D.to(dtype=reference.dtype, device=reference.device).unsqueeze(0).unsqueeze(0)

            num_hidden_layers = len(kv_outputs.past_key_values)
            # NOTE(review): apply_D_to_cache treats the end index as exclusive,
            # so clamping to num_hidden_layers - 1 leaves the last layer
            # unrefined. Kept as-is; confirm this is intended.
            end_layer_idx = min(dscr_end_layer_eff, num_hidden_layers - 1)

            if i == 0:
                print("DSCR ON")
                print("dscr_lambda:", args.dscr_lambda)
                print("dscr_keep:", args.dscr_keep)
                print("layers:", args.dscr_start_layer, args.dscr_end_layer)

            refined_cache = apply_D_to_cache(
                cache=kv_outputs.past_key_values,
                D_576_576=D,
                cut_idx=cut_idx,
                start_layer_idx=dscr_start_layer_eff,
                end_layer_idx=end_layer_idx,
                key_only=dscr_key_only_eff,
                value_only=dscr_value_only_eff,
                key_value=bool(dscr_key_value_eff),
                dscr_lambda=float(args.dscr_lambda),
            )

            # Rebuild the last prompt position so it attends to the refined prefix.
            prefilled_cache, _ = recompute_first_logits_from_refined_cache(
                model=model,
                input_ids=input_ids,
                cache=refined_cache,
            )
        elif args.use_opera:
            prefilled_cache = kv_outputs.past_key_values

        # Arguments shared by every decoding variant.
        gen_kwargs = dict(
            images=images,
            images_cd=images_cd,
            cd_alpha=args.cd_alpha,
            cd_beta=args.cd_beta,
            do_sample=False,
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            max_new_tokens=200,
            use_cache=True,
        )

        if args.use_opera:
            # Built per question: response_start depends on this prompt's
            # length, so a value cached from the first question would be
            # wrong for any prompt of a different length.
            key_position = {
                "image_start": cut_idx,
                "image_end": cut_idx + NUM_IMAGE_TOKENS - 1,
                "response_start": input_ids.shape[1] + NUM_IMAGE_TOKENS - 1,
            }
            gen_kwargs.update(
                past_key_values=prefilled_cache,
                key_position=key_position,
                opera_decoding=args.use_opera,
                num_beams=args.beam,
                output_attentions=True,
                scale_factor=args.scale_factor,
                threshold=args.threshold,
                num_attn_candidates=args.num_attn_candidates,
                penalty_weights=args.penalty_weights,
            )
        elif args.use_dscr:
            if args.use_vcd:
                # The VCD+DSCR path pins temperature to 0.0 (kept from the
                # original code; the other paths use args.temperature).
                gen_kwargs["temperature"] = 0.0
            gen_kwargs["past_key_values"] = prefilled_cache
        else:
            # Only the baseline / VCD-only path stops on the separator keyword.
            stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)
            gen_kwargs["stopping_criteria"] = [stopping_criteria]

        with torch.inference_mode():
            output_ids = model.generate(input_ids, **gen_kwargs)

        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids")

        outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0].strip()
        if outputs.endswith(stop_str):
            outputs = outputs[: -len(stop_str)].strip()

        responses.append({"id": idx, "response": _normalize_response(idx, outputs)})

    with open(answers_file, "w", encoding="utf-8") as f:
        json.dump(responses, f, ensure_ascii=False, indent=2)


if __name__ == "__main__":
    # Script entry point: parse the CLI, seed the RNGs, translate the
    # --format choice into its prompt suffix, report the active decoding
    # modes, then run the evaluation.
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Model, data locations and basic sampling settings.
    add("--model-path", type=str, default="liuhaotian/llava-v1.5-13b")
    add("--model-base", type=str, default=None)
    add("--image-folder", type=str, default="/path/to/data")
    add("--question-file", type=str, default="/path/to/data")
    add("--answers-file", type=str, default="./output/test/time_test.json")
    add("--conv-mode", type=str, default="vicuna_v1")
    add("--num-chunks", type=int, default=1)
    add("--chunk-idx", type=int, default=0)
    add("--temperature", type=float, default=1.0)
    add("--top_p", type=float, default=1.0)
    add("--top_k", type=int, default=None)

    # VCD: noised-image contrastive decoding.
    add("--use_vcd", action="store_true", default=False)
    add("--noise_step", type=int, default=500)
    add("--cd_alpha", type=float, default=1.0)
    add("--cd_beta", type=float, default=0.5)
    add("--seed", type=int, default=42)

    # Instruction appended to every query (see the suffix table below).
    add("--format", type=str, default="no_format", choices=["no_format", "ow_format", "yn_format"])

    # OPERA decoding.
    add("--use_opera", action="store_true", default=False)
    add("--gpu-id", type=int, default=0)
    add("--batch_size", type=int, default=1)
    add("--num_workers", type=int, default=1)
    add("--beam", type=int, default=5)
    add("--sample", action="store_true")
    add("--scale_factor", type=float, default=50.0)
    add("--threshold", type=int, default=15)
    add("--num_attn_candidates", type=int, default=5)
    add("--penalty_weights", type=float, default=1.0)

    # DSCR: depth/position based KV-cache refinement.
    add("--use-dscr", action="store_true", default=False)
    add("--depth-folder", type=str, default="/path/to/data")
    add("--dscr-alpha", type=float, default=0.6)
    add("--dscr-beta", type=float, default=0.8)
    add("--dscr-sigma", type=float, default=0.6)
    add("--dscr-start-layer", type=int, default=10)
    add("--dscr-end-layer", type=int, default=40)
    add("--dscr-lambda", type=float, default=0.1)
    add("--dscr-keep", type=float, default=1.0)
    add("--dscr-key-only", type=str2bool, default=True)
    add("--dscr-value-only", type=str2bool, default=False)
    add("--dscr-key-value", type=str2bool, default=False)

    args = parser.parse_args()

    set_seed(args.seed)

    # Swap the symbolic format name for the literal text appended to each
    # query; "no_format" maps to no suffix.
    format_suffixes = {
        "yn_format": " Please only answer yes or no.",
        "ow_format": " Please answer this question with one word.",
    }
    args.format = format_suffixes.get(args.format, "")

    print("amber_eval_llava.py")

    # One status line per decoding mode, in the order opera, vcd, dscr.
    opera_status = "opera:T" if args.use_opera else "opera:F"
    vcd_status = "vcd:T" if args.use_vcd else "vcd:F"
    dscr_status = "dscr:T" if args.use_dscr else "dscr:F"
    for status_line in (opera_status, vcd_status, dscr_status):
        print(status_line)

    eval_model(args)
