import argparse
import json
import os
import re
import sys
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

# Grandparent of this script's directory (three dirname() hops from the file path).
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Parent of this script's directory; also the home of the vendored transformers tree below.
EXPERIMENTS_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# These sys.path edits must run before the project imports that follow
# (mplug_owl2, vcd_utils are presumably resolved from one of these roots — confirm).
sys.path.append(ROOT_DIR)
sys.path.append(EXPERIMENTS_DIR)
# Prepend the vendored transformers-4.31.0 source tree so it takes precedence
# over any installed `transformers` package for the imports below.
sys.path.insert(0, os.path.join(EXPERIMENTS_DIR, "transformers-4.31.0", "src"))

from transformers import set_seed

from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from mplug_owl2.conversation import conv_templates, SeparatorStyle
from mplug_owl2.model.builder import load_pretrained_model
from mplug_owl2.utils import disable_torch_init
from mplug_owl2.mm_utils import tokenizer_image_token, get_model_name_from_path
from vcd_utils.vcd_add_noise import add_diffusion_noise
from vcd_utils.vcd_sample import evolve_vcd_sampling

# Number of visual tokens that occupy the image slot in the KV cache; the DSCR
# code treats them as a square 8 x 8 grid, hence the explicit product.
NUM_IMAGE_TOKENS = 8 * 8


def recompute_first_logits_from_refined_cache(model, input_ids, cache):
    """Re-encode the final prompt token on top of a (refined) KV cache.

    Every layer's key and value tensors are cut back by one position along
    dim 2. The last token of ``input_ids`` is then fed through ``model``
    against that trimmed cache, so its cache entry and logits are recomputed
    from the refined context.

    Returns:
        ``(past_key_values, logits)``, where ``logits`` are the model outputs
        at the last position, i.e. the scores for the next token.
    """
    # Drop the final cached position of every layer (the one being re-encoded).
    trimmed_keys = [layer[0][:, :, :-1, :].contiguous() for layer in cache]
    trimmed_values = [layer[1][:, :, :-1, :].contiguous() for layer in cache]
    prefix_cache = tuple(zip(trimmed_keys, trimmed_values))

    final_token = input_ids[:, -1:]
    with torch.inference_mode():
        result = model(
            final_token,
            images=None,
            use_cache=True,
            past_key_values=prefix_cache,
            return_dict=True,
        )
    return result.past_key_values, result.logits[:, -1, :]


def _mix_image_segment(t, D, cut_idx, num_tokens, dscr_lambda):
    """Return *t* with its image-token span blended toward ``D @ span``.

    *t* is sliced as ``(batch, heads, seq, head_dim)``. Positions
    ``[cut_idx, cut_idx + num_tokens)`` along dim 2 become
    ``(1 - dscr_lambda) * seg + dscr_lambda * (D @ seg)``; every other
    position is copied unchanged.

    Raises:
        ValueError: if the span does not fit inside dim 2 of *t*.
    """
    seq_len = t.shape[2]
    if cut_idx < 0 or cut_idx + num_tokens > seq_len:
        raise ValueError(
            f"Image span [{cut_idx}, {cut_idx + num_tokens}) does not fit in cache of length {seq_len}"
        )
    seg = t[:, :, cut_idx:cut_idx + num_tokens, :]
    # Row i of D holds the weights with which each image token contributes to token i.
    seg_refined = torch.einsum("ij,bhjd->bhid", D, seg)
    seg_mixed = (1.0 - dscr_lambda) * seg + dscr_lambda * seg_refined
    return torch.cat(
        (t[:, :, :cut_idx, :], seg_mixed, t[:, :, cut_idx + num_tokens:, :]),
        dim=2,
    )


def apply_D_to_cache(cache, D_64_64, cut_idx, start_layer_idx, end_layer_idx, key_only, value_only, key_value, dscr_lambda=1.0):
    """Smooth the image-token span of selected layers' KV cache with matrix ``D``.

    Args:
        cache: Iterable of per-layer ``(key, value)`` tensors, each sliced as
            ``(batch, heads, seq, head_dim)``.
        D_64_64: Square mixing matrix of shape ``(N, N)`` or ``(1, 1, N, N)``.
            ``N``, the number of image tokens to refine, is taken from its
            shape. The name is historical: the runner passes N = 64.
        cut_idx: Cache position of the first image token.
        start_layer_idx: First layer to refine (inclusive).
        end_layer_idx: Layer at which refinement stops (exclusive).
        key_only: Refine keys.
        value_only: Refine values.
        key_value: Refine both keys and values. With all three flags False
            the cache is returned unchanged.
        dscr_lambda: Blend weight, clamped to ``[0, 1]``. 0 keeps the
            original span; 1 replaces it with ``D @ span``.

    Returns:
        A tuple of ``(key, value)`` pairs. Tensors outside the layer range, or
        not selected by the flags, are passed through as-is.

    Raises:
        ValueError: if ``D_64_64`` is not ``(N, N)`` / ``(1, 1, N, N)`` square,
            or the image span does not fit inside a refined tensor.
    """
    if D_64_64.dim() == 4:
        D = D_64_64.squeeze(0).squeeze(0)
    elif D_64_64.dim() == 2:
        D = D_64_64
    else:
        raise ValueError(f"Unsupported D shape: {tuple(D_64_64.shape)}")
    # squeeze() is a no-op on non-singleton leading dims, so re-check the result.
    if D.dim() != 2 or D.shape[0] != D.shape[1]:
        raise ValueError(f"Unsupported D shape: {tuple(D_64_64.shape)}")
    num_tokens = int(D.shape[0])

    start_layer_idx = int(start_layer_idx)
    end_layer_idx = int(end_layer_idx)
    cut_idx = int(cut_idx)
    dscr_lambda = float(max(0.0, min(1.0, dscr_lambda)))
    refine_keys = bool(key_only or key_value)
    refine_values = bool(value_only or key_value)

    # Cast D once per (device, dtype) pair; layers may live on different devices.
    D_by_target = {}

    def _D_like(t):
        tag = (t.device, t.dtype)
        if tag not in D_by_target:
            D_by_target[tag] = D.to(device=t.device, dtype=t.dtype)
        return D_by_target[tag]

    out = []
    for layer_idx, (k, v) in enumerate(cache):
        if start_layer_idx <= layer_idx < end_layer_idx:
            if refine_keys:
                k = _mix_image_segment(k, _D_like(k), cut_idx, num_tokens, dscr_lambda)
            if refine_values:
                v = _mix_image_segment(v, _D_like(v), cut_idx, num_tokens, dscr_lambda)
        out.append((k, v))

    return tuple(out)


def load_questions(path: str) -> List[Dict]:
    """Load question records from a JSON-array file or a JSON Lines file.

    The format is detected from the first non-whitespace character: ``[``
    means a JSON array, anything else is parsed as JSON Lines (blank lines
    skipped). A leading UTF-8 BOM is tolerated.

    Args:
        path: Path to the ``.json`` / ``.jsonl`` question file.

    Returns:
        The list of decoded records (empty for an empty file).
    """
    # utf-8-sig transparently drops a BOM, which would otherwise hide the "[".
    with open(path, "r", encoding="utf-8-sig") as f:
        content = f.read()
    # Skip leading whitespace/newlines so pretty-printed arrays are still detected.
    stripped = content.lstrip()
    if stripped.startswith("["):
        return json.loads(stripped)
    return [json.loads(line) for line in content.splitlines() if line.strip()]


def discover_sets(gt_dir: str) -> List[str]:
    """Return the sorted, de-duplicated question-set names found in *gt_dir*.

    A set name is a ``.json`` or ``.jsonl`` file name with its final
    extension removed; a name present in both formats is listed once.
    """
    names = {
        entry.rsplit(".", 1)[0]
        for entry in os.listdir(gt_dir)
        if entry.endswith((".json", ".jsonl"))
    }
    return sorted(names)


def normalize_format(fmt: str) -> str:
    """Map a ``--format`` choice to the instruction suffix appended to each question.

    Unrecognised values, including ``"no_format"``, map to the empty string.
    """
    suffix_by_format = {
        "yn_format": " Please only answer yes or no.",
        "ow_format": " Please answer this question with one word.",
    }
    return suffix_by_format.get(fmt, "")


def normalize_yn_answer(text: str) -> str:
    """Collapse a free-form model reply to ``"yes"`` or ``"no"`` when possible.

    Matching is done on whole lower-cased words, so words that merely contain
    "yes"/"no" as a substring ("eyes", "know", "snow") are not misread.
    "yes" takes precedence over the negative words (no / not / none / nope).
    If neither matches, the stripped, lower-cased text is returned.
    """
    cleaned = text.strip().lower()
    words = set(re.findall(r"[a-z]+", cleaned))
    if "yes" in words:
        return "yes"
    if words & {"no", "not", "none", "nope"}:
        return "no"
    return cleaned


def build_prompt(qs: str, conv_mode: str, fmt: str) -> Tuple[str, str]:
    """Render a single-turn multimodal prompt and return it with its stop string.

    Args:
        qs: The question text.
        conv_mode: Key into ``conv_templates``.
        fmt: Answer-format suffix appended to the question (e.g. from
            ``normalize_format``).

    Returns:
        ``(prompt, stop_str)``. ``stop_str`` is ``conv.sep2`` for
        ``SeparatorStyle.TWO`` templates and ``conv.sep`` otherwise.
    """
    conv = conv_templates[conv_mode].copy()
    # Image placeholder first, on its own line, then the question and format hint.
    conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + qs + fmt)
    # Empty assistant turn, presumably so the rendered prompt ends where the answer begins.
    conv.append_message(conv.roles[1], None)
    stop_str = conv.sep2 if conv.sep_style == SeparatorStyle.TWO else conv.sep
    return conv.get_prompt(), stop_str


def _parse_args() -> argparse.Namespace:
    """Define and parse the command-line options of this runner."""
    parser = argparse.ArgumentParser("POPE (mPLUG, one-load multi-set runner)")
    parser.add_argument("--model-path", type=str, default="MAGAer13/mplug-owl2-llama2-7b")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--gt-dir", type=str, required=True)
    parser.add_argument("--image-root", type=str, required=True)
    parser.add_argument("--out-root", type=str, required=True)
    parser.add_argument("--depth-root", type=str, default=None)
    parser.add_argument("--datasets", nargs="*", default=None)

    parser.add_argument("--method", type=str, default="baseline", choices=["baseline", "vcd", "opera"])
    parser.add_argument("--conv-mode", type=str, default="mplug_owl2")
    parser.add_argument("--format", type=str, default="yn_format", choices=["no_format", "ow_format", "yn_format"])
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--max-new-tokens", type=int, default=2)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--top-p", dest="top_p", type=float, default=1.0)
    parser.add_argument("--top-k", dest="top_k", type=int, default=0)

    # VCD
    parser.add_argument("--noise-step", type=int, default=500)
    parser.add_argument("--cd-alpha", type=float, default=0.5)
    parser.add_argument("--cd-beta", type=float, default=0.1)

    # OPERA
    parser.add_argument("--beam", type=int, default=5)
    parser.add_argument("--scale-factor", type=float, default=50.0)
    parser.add_argument("--threshold", type=int, default=15)
    parser.add_argument("--num-attn-candidates", type=int, default=5)
    parser.add_argument("--penalty-weights", type=float, default=1.0)

    # DSCR
    parser.add_argument("--use-dscr", action="store_true", default=False)
    parser.add_argument("--dscr-alpha", type=float, default=0.6)
    parser.add_argument("--dscr-beta", type=float, default=0.8)
    parser.add_argument("--dscr-sigma", type=float, default=0.6)
    parser.add_argument("--dscr-lambda", type=float, default=1.0)
    parser.add_argument("--dscr-start-layer", type=int, default=8)
    parser.add_argument("--dscr-end-layer", type=int, default=24)
    parser.add_argument("--dscr-key-only", action="store_true", default=False)
    parser.add_argument("--dscr-value-only", action="store_true", default=False)
    parser.add_argument("--dscr-key-value", action="store_true", default=False)

    parser.add_argument("--run-name", type=str, default=None)
    parser.add_argument("--do-sample", action="store_true", default=False,
                        help="Enable sampling (do_sample=True) for baseline method")
    return parser.parse_args()


def _resolve_question_file(gt_dir: str, dataset: str) -> Optional[str]:
    """Return ``<gt_dir>/<dataset>.json`` (preferred) or ``.jsonl`` if it exists, else None."""
    for ext in (".json", ".jsonl"):
        candidate = os.path.join(gt_dir, f"{dataset}{ext}")
        if os.path.isfile(candidate):
            return candidate
    return None


def _build_dscr_matrix(depth_np, sigma: float, alpha: float, beta: float) -> torch.Tensor:
    """Build the row-normalised ``(N, N)`` float32 DSCR mixing matrix from a depth map.

    ``N`` is ``NUM_IMAGE_TOKENS``, laid out as a square grid. The steps are:

    1. Resize the depth map to that grid.
    2. Invert it and min-max normalise it.
    3. Combine a depth-similarity Gaussian (raised to ``alpha``) and a
       grid-distance Gaussian (raised to ``beta``), both with bandwidth
       ``sigma``, by summing them.
    4. Normalise each row to sum to ~1.
    """
    side = int(round(NUM_IMAGE_TOKENS ** 0.5))
    # NOTE(review): assumes the .npy holds a 2-D (H, W) depth array — confirm.
    depth_tensor = torch.tensor(depth_np, dtype=torch.float32).unsqueeze(0).unsqueeze(1)
    depth_patch = torch.nn.functional.interpolate(
        depth_tensor, size=(side, side), mode="bilinear", align_corners=False
    )
    # Inverse depth, clamped to [0.001, 1000] (handles zeros/negatives), then scaled to [0, 1).
    depth_patch = torch.clamp(1.0e-6 + 1.0 / depth_patch, 0.001, 1000).view(-1)
    depth_patch = (depth_patch - depth_patch.min()) / (depth_patch.max() - depth_patch.min() + 1e-6)

    bandwidth = 2 * sigma ** 2 + 1e-12
    depth_diff = torch.abs(depth_patch.unsqueeze(0) - depth_patch.unsqueeze(1))
    gaussian_weight_depth = torch.exp(-(depth_diff ** 2) / bandwidth)

    # (row, col) of each token on the grid, scaled to [0, 1].
    pixel_positions = torch.tensor(
        [[i // side, i % side] for i in range(NUM_IMAGE_TOKENS)], dtype=torch.float32
    ) / float(side - 1)
    position_diff = torch.cdist(pixel_positions, pixel_positions, p=2)
    gaussian_weight_position = torch.exp(-(position_diff ** 2) / bandwidth)

    gaussian_weight = (gaussian_weight_depth ** alpha) + (gaussian_weight_position ** beta)
    return gaussian_weight / (gaussian_weight.sum(dim=-1, keepdim=True) + 1e-6)


def main() -> None:
    """Run POPE inference for every requested question set with a single model load.

    For each dataset, writes ``<out_root>/<run_name>_<dataset>.jsonl`` with one
    JSON record (question_id, prompt, text, model_id, image, metadata) per
    question.

    Raises:
        ValueError: if ``--use-dscr`` is given without ``--depth-root``, or
            baseline ``--do-sample`` is requested with ``--temperature <= 0``.
    """
    args = _parse_args()

    # Validate option combinations before the slow model load.
    if args.use_dscr and not args.depth_root:
        raise ValueError("--depth-root is required when --use-dscr is set")
    # Per its help text, --do-sample only affects the baseline method;
    # VCD and OPERA keep their own (non-sampling) decoding settings.
    do_sample = bool(args.do_sample) and args.method == "baseline"
    if do_sample and args.temperature <= 0:
        raise ValueError("--do-sample requires --temperature > 0")

    set_seed(args.seed)
    fmt = normalize_format(args.format)

    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, _ = load_pretrained_model(model_path, args.model_base, model_name)
    device = next(model.parameters()).device

    if args.method == "vcd":
        # Presumably installs VCD-aware decoding so generate() accepts images_cd / cd_alpha / cd_beta.
        evolve_vcd_sampling()

    datasets = args.datasets or discover_sets(args.gt_dir)

    run_name = args.run_name or f"mplug_pope_{args.method}_seed{args.seed}"
    os.makedirs(args.out_root, exist_ok=True)

    # With no explicit DSCR target flag, refine both keys and values.
    dscr_key_value_eff = args.dscr_key_value or (not args.dscr_key_only and not args.dscr_value_only)

    for dataset in datasets:
        qfile = _resolve_question_file(args.gt_dir, dataset)
        if qfile is None:
            print(f"[skip] missing questions for {dataset}")
            continue

        questions = load_questions(qfile)
        answers_file = os.path.join(args.out_root, f"{run_name}_{dataset}.jsonl")

        with open(answers_file, "w", encoding="utf-8") as fout:
            for q in tqdm(questions, desc=f"{dataset}"):
                qid = q.get("question_id", q.get("id"))
                image_file = q.get("image", "")
                text = q.get("text", q.get("question", ""))

                prompt, stop_str = build_prompt(text, args.conv_mode, fmt)
                input_ids = tokenizer_image_token(
                    prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
                ).unsqueeze(0).to(device)

                # Position of the image placeholder; the image tokens start here in the KV cache.
                image_token_pos = torch.where(input_ids[0] == IMAGE_TOKEN_INDEX)[0]
                if image_token_pos.numel() == 0:
                    raise RuntimeError("IMAGE_TOKEN_INDEX not found in input_ids")
                cut_idx = int(image_token_pos[0].item())

                # Close the source file once the RGB copy has been made.
                with Image.open(os.path.join(args.image_root, image_file)) as raw_image:
                    image = raw_image.convert("RGB")
                image_tensor = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
                image_bchw = image_tensor.unsqueeze(0).to(device=device, dtype=model.dtype)

                image_bchw_cd = None
                if args.method == "vcd":
                    image_tensor_cd = add_diffusion_noise(image_tensor, int(args.noise_step))
                    image_bchw_cd = image_tensor_cd.unsqueeze(0).to(device=device, dtype=model.dtype)

                # Prefill the full prompt once to obtain its KV cache.
                with torch.inference_mode():
                    kv_outputs = model(
                        input_ids=input_ids,
                        images=image_bchw,
                        use_cache=True,
                        return_dict=True,
                    )

                cache_for_gen = kv_outputs.past_key_values

                if args.use_dscr:
                    depth_path = os.path.join(args.depth_root, os.path.splitext(image_file)[0] + ".npy")
                    # NOTE(review): allow_pickle=True lets a crafted .npy execute code on load;
                    # depth maps should be plain numeric arrays, so prefer allow_pickle=False
                    # once the depth files are confirmed to be non-object arrays.
                    depth_np = np.load(depth_path, allow_pickle=True)
                    ref_kv = kv_outputs.past_key_values[0][0]
                    D = _build_dscr_matrix(depth_np, args.dscr_sigma, args.dscr_alpha, args.dscr_beta)
                    D = D.to(dtype=ref_kv.dtype, device=ref_kv.device).unsqueeze(0).unsqueeze(0)

                    # apply_D_to_cache uses an exclusive end index, so cap at the layer
                    # count; capping at count - 1 could never refine the last layer.
                    num_hidden_layers = len(kv_outputs.past_key_values)
                    end_layer_idx = min(int(args.dscr_end_layer), num_hidden_layers)
                    cache_for_gen = apply_D_to_cache(
                        cache=kv_outputs.past_key_values,
                        D_64_64=D,
                        cut_idx=cut_idx,
                        start_layer_idx=int(args.dscr_start_layer),
                        end_layer_idx=end_layer_idx,
                        key_only=bool(args.dscr_key_only),
                        value_only=bool(args.dscr_value_only),
                        key_value=bool(dscr_key_value_eff),
                        dscr_lambda=float(args.dscr_lambda),
                    )

                    cache_for_gen, _ = recompute_first_logits_from_refined_cache(
                        model=model,
                        input_ids=input_ids,
                        cache=cache_for_gen,
                    )

                # NOTE(review): generate() receives the full input_ids together with a cache
                # that already covers the whole prompt. Depending on how the model's
                # prepare_inputs_for_generation trims input_ids, the last prompt token may
                # be encoded a second time on top of the cache — confirm.
                gen_kwargs = dict(
                    input_ids=input_ids,
                    images=image_bchw,
                    do_sample=do_sample,
                    temperature=(0.0 if args.method == "vcd" else float(args.temperature)),
                    top_p=float(args.top_p),
                    top_k=args.top_k,
                    max_new_tokens=args.max_new_tokens,
                    use_cache=True,
                    past_key_values=cache_for_gen,
                )

                if args.method == "vcd":
                    gen_kwargs.update(
                        dict(
                            images_cd=image_bchw_cd,
                            cd_alpha=args.cd_alpha,
                            cd_beta=args.cd_beta,
                        )
                    )
                elif args.method == "opera":
                    key_position = {
                        "image_start": cut_idx,
                        "image_end": cut_idx + NUM_IMAGE_TOKENS - 1,
                        "response_start": input_ids.shape[1] + NUM_IMAGE_TOKENS - 1,
                    }
                    gen_kwargs.update(
                        dict(
                            key_position=key_position,
                            opera_decoding=True,
                            num_beams=args.beam,
                            output_attentions=True,
                            scale_factor=args.scale_factor,
                            threshold=args.threshold,
                            num_attn_candidates=args.num_attn_candidates,
                            penalty_weights=args.penalty_weights,
                            do_sample=False,
                        )
                    )

                with torch.inference_mode():
                    output_ids = model.generate(**gen_kwargs)

                input_len = input_ids.shape[1]
                output = tokenizer.batch_decode(output_ids[:, input_len:], skip_special_tokens=True)[0].strip()
                # Strip the stop string before normalising. The guard matters: with an
                # empty stop string, output[:-0] would erase the whole answer.
                if stop_str and output.endswith(stop_str):
                    output = output[:-len(stop_str)].strip()
                if args.format == "yn_format":
                    output = normalize_yn_answer(output)

                fout.write(
                    json.dumps(
                        {
                            "question_id": qid,
                            "prompt": prompt,
                            "text": output,
                            "model_id": model_name,
                            "image": image_file,
                            "metadata": {},
                        }
                    )
                    + "\n"
                )


# Script entry point: run the evaluation only when executed directly, not on import.
if __name__ == "__main__":
    main()
