import argparse
import json
import os
import random
import sys
from typing import List, Dict

import torch
import numpy as np
from PIL import Image
from tqdm import tqdm

# Resolve the experiments directory (two levels above this file) and the
# repository root (one level above that), then make both importable so the
# project packages below (mplug_owl2, vcd_utils) resolve without installation.
EXPERIMENTS_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
ROOT_DIR = os.path.dirname(EXPERIMENTS_DIR)
for _import_dir in (ROOT_DIR, EXPERIMENTS_DIR):
    sys.path.append(_import_dir)
# Put the vendored transformers-4.31.0 sources ahead of any installed copy.
sys.path.insert(0, os.path.join(EXPERIMENTS_DIR, "transformers-4.31.0", "src"))

from transformers import set_seed

from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from mplug_owl2.conversation import conv_templates
from mplug_owl2.model.builder import load_pretrained_model
from mplug_owl2.utils import disable_torch_init
from mplug_owl2.mm_utils import tokenizer_image_token, get_model_name_from_path

from vcd_utils.vcd_add_noise import add_diffusion_noise
from vcd_utils.vcd_sample import evolve_vcd_sampling

# Upper bound on the number of visual tokens assumed to follow the image
# placeholder in the KV cache; apply_D_to_cache caps the mixed image segment at
# this length by default. NOTE(review): 576 presumably matches a 24x24 patch grid
# (apply_D_to_cache was copied from a LLaVA script), while main()'s DSCR path
# treats the image as 64 tokens for mPLUG-Owl2 — confirm which count applies.
NUM_IMAGE_TOKENS = 576


def _resize_square(D, size):
    """Return ``D`` bilinearly resized to ``(size, size)``, or ``D`` itself if already that shape."""
    if tuple(D.shape) == (size, size):
        return D
    import torch.nn.functional as F

    return F.interpolate(
        D.unsqueeze(0).unsqueeze(0),
        size=(size, size),
        mode="bilinear",
        align_corners=False,
    ).squeeze(0).squeeze(0)


def _mix_token_segment(x, D_seg, cut_idx, seg_len):
    """Return ``x`` with positions ``[cut_idx, cut_idx + seg_len)`` of dim 2 replaced by ``D_seg @ segment``."""
    seg = x[:, :, cut_idx:cut_idx + seg_len, :]
    mixed = torch.einsum("ij,bhjd->bhid", D_seg, seg)
    return torch.cat((x[:, :, :cut_idx, :], mixed, x[:, :, cut_idx + seg_len:, :]), dim=2)


def apply_D_to_cache(cache, D_576_576, cut_idx, start_layer_idx, end_layer_idx, key_only, value_only, key_value,
                     num_image_tokens=None):
    """Mix the image-token span of selected KV-cache layers with a token-mixing matrix.

    Copied from object_hallucination_vqa_llava4.py, adapted for variable image token length.

    For every layer with ``start_layer_idx <= i < end_layer_idx``, positions
    ``[cut_idx, cut_idx + n)`` along dim 2 of the key and/or value tensor are replaced by
    ``D @ segment`` (contracting over that axis), where
    ``n = min(num_image_tokens, seq_len - cut_idx)``. If ``D`` is not ``n x n`` it is
    bilinearly resized to that size. When ``n <= 0`` (``cut_idx`` at or past the end of
    the sequence), the tensor is returned unchanged.

    NOTE(review): bilinear resizing does not preserve row sums, so a row-normalised D is
    no longer row-normalised after resizing — confirm that scaling is intended.

    Args:
        cache: sequence of per-layer ``(key, value)`` 4-D tensors whose dim 2 is the
            token axis (einsum ``bhjd``).
        D_576_576: mixing matrix of shape ``(T, T)`` or ``(1, 1, T, T)``.
        cut_idx: first token position of the image segment.
        start_layer_idx, end_layer_idx: half-open range of layers to modify.
        key_only, value_only, key_value: keys are modified when ``key_only or key_value``;
            values when ``value_only or key_value``.
        num_image_tokens: cap on the segment length; defaults to ``NUM_IMAGE_TOKENS``.

    Returns:
        Tuple of ``(key, value)`` pairs, one per layer; untouched tensors are returned as-is.

    Raises:
        ValueError: if ``D_576_576`` does not reduce to a 2-D matrix.
    """
    if D_576_576.dim() == 4:
        D = D_576_576.squeeze(0).squeeze(0)
    elif D_576_576.dim() == 2:
        D = D_576_576
    else:
        raise ValueError(f"Unsupported D shape: {tuple(D_576_576.shape)}")
    if D.dim() != 2:
        # e.g. (B, 1, T, T) with B > 1: squeeze(0) is a no-op and the einsum would fail.
        raise ValueError(f"Unsupported D shape: {tuple(D_576_576.shape)}")

    max_tokens = NUM_IMAGE_TOKENS if num_image_tokens is None else int(num_image_tokens)
    cut_idx = int(cut_idx)
    start_layer_idx = int(start_layer_idx)
    end_layer_idx = int(end_layer_idx)

    # Resized + cast copies of D keyed by (device, dtype, size). Always derived from the
    # base D so repeated lookups never re-interpolate an already-resized matrix.
    D_cache = {}

    def refine(x):
        seg_len = min(max_tokens, x.shape[2] - cut_idx)
        if seg_len <= 0:
            return x
        tag = (x.device, x.dtype, seg_len)
        D_seg = D_cache.get(tag)
        if D_seg is None:
            # Resize in D's own dtype first, then cast (matches the original ordering).
            D_seg = _resize_square(D, seg_len).to(device=x.device, dtype=x.dtype)
            D_cache[tag] = D_seg
        return _mix_token_segment(x, D_seg, cut_idx, seg_len)

    refine_keys = key_only or key_value
    refine_values = value_only or key_value

    out = []
    for layer_idx, (k, v) in enumerate(cache):
        if start_layer_idx <= layer_idx < end_layer_idx:
            if refine_keys:
                k = refine(k)
            if refine_values:
                v = refine(v)
        out.append((k, v))
    return tuple(out)


def recompute_first_logits_from_refined_cache(model, input_ids, cache):
    """Replay the last prompt token on top of a (refined) KV cache.

    Copied from object_hallucination_vqa_llava4.py.

    Every layer's key/value tensor is truncated by one position along dim 2, and the
    final token of ``input_ids`` is fed through ``model`` with that truncated cache
    (no images, ``use_cache=True``), under ``torch.inference_mode``.

    Returns:
        ``(past_key_values, logits)`` where ``logits`` are the model's logits at the
        last position (``outputs.logits[:, -1, :]``).
    """
    prefix_cache = tuple(
        (layer[0][:, :, :-1, :].contiguous(), layer[1][:, :, :-1, :].contiguous())
        for layer in cache
    )
    replay_ids = input_ids[:, -1:]

    with torch.inference_mode():
        replay = model(
            replay_ids,
            images=None,
            use_cache=True,
            past_key_values=prefix_cache,
            return_dict=True,
        )
    last_logits = replay.logits[:, -1, :]
    return replay.past_key_values, last_logits


def build_prompt(conv_mode: str) -> str:
    """Render the single-turn captioning prompt for the ``conv_mode`` template.

    The user message is just the image placeholder token followed by a newline; the
    assistant message is appended as ``None`` so the rendered prompt ends at the
    assistant's turn.
    """
    template = conv_templates[conv_mode].copy()
    user_role, assistant_role = template.roles[0], template.roles[1]
    template.append_message(user_role, DEFAULT_IMAGE_TOKEN + "\n")
    template.append_message(assistant_role, None)
    rendered = template.get_prompt()
    return rendered


def load_coco_images(ann_file: str, seed: int, num_samples: int) -> List[Dict]:
    """Load the ``images`` list from a COCO annotation file in a seeded random order.

    The list is shuffled with ``random.Random(seed)``, so the order is reproducible.
    A positive ``num_samples`` keeps only the first ``num_samples`` entries after
    shuffling; zero or a negative value keeps them all. A file without an ``images``
    key yields an empty list.
    """
    with open(ann_file, "r") as handle:
        coco = json.load(handle)

    entries = coco.get("images", [])
    random.Random(seed).shuffle(entries)
    return entries[:num_samples] if num_samples > 0 else entries


def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the CLI parser: model/data paths, sampling, VCD, OPERA and DSCR options."""
    parser = argparse.ArgumentParser("CHAIR (mPLUG) VCD/OPERA")
    parser.add_argument("--model-path", type=str, default="MAGAer13/mplug-owl2-llama2-7b")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--data-path", type=str, default="/path/to/data")
    parser.add_argument("--coco-ann-file", type=str, required=True)
    parser.add_argument("--answers-file", type=str, required=True)
    parser.add_argument("--conv-mode", type=str, default="mplug_owl2")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num-samples", type=int, default=500)
    parser.add_argument("--max-new-tokens", type=int, default=256)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top-p", dest="top_p", type=float, default=1.0)
    parser.add_argument("--top-k", dest="top_k", type=int, default=None)

    # VCD
    parser.add_argument("--use-vcd", action="store_true", default=False)
    parser.add_argument("--noise-step", type=int, default=500)
    parser.add_argument("--cd-alpha", type=float, default=1.0)
    parser.add_argument("--cd-beta", type=float, default=0.1)

    # OPERA
    parser.add_argument("--use-opera", action="store_true", default=False)
    parser.add_argument("--beam", type=int, default=5)
    parser.add_argument("--scale-factor", "--scale_factor", dest="scale_factor", type=float, default=50.0)
    parser.add_argument("--threshold", type=int, default=15)
    parser.add_argument("--num-attn-candidates", "--num_attn_candidates", dest="num_attn_candidates", type=int, default=5)
    parser.add_argument("--penalty-weights", "--penalty_weights", dest="penalty_weights", type=float, default=1.0)

    # DSCR
    parser.add_argument("--use-dscr", action="store_true", default=False)
    parser.add_argument("--depth-folder", type=str, default="/path/to/data")
    parser.add_argument("--dscr-alpha", type=float, default=0.4)
    parser.add_argument("--dscr-beta", type=float, default=0.2)
    parser.add_argument("--dscr-sigma", type=float, default=0.2)
    parser.add_argument("--dscr-keep-ratio", type=float, default=1.0)
    parser.add_argument("--dscr-lambda", type=float, default=0.01)
    parser.add_argument("--dscr-start-layer", type=int, default=8)
    parser.add_argument("--dscr-end-layer", type=int, default=24)
    parser.add_argument("--dscr-key-only", action="store_true", default=False)
    parser.add_argument("--dscr-value-only", action="store_true", default=False)
    parser.add_argument("--dscr-key-value", action="store_true", default=False)
    return parser


def _build_dscr_matrix(depth_np, alpha, beta, sigma, keep_ratio):
    """Build the row-normalised DSCR token-mixing matrix from a depth map (matches chair_mplug.py).

    The depth map is bilinearly resized to an 8x8 grid (64 tokens), turned into
    ``1/depth``, clamped to [0.001, 1000], flattened and min-max normalised. Two Gaussian
    kernels with bandwidth ``sigma`` are built over the 64 cells: one on pairwise depth
    difference, one on pairwise grid distance (coordinates scaled by 1/7). They are
    combined additively as ``K_depth ** alpha + K_pos ** beta``. If ``keep_ratio``
    (clamped to [0, 1]) is below 1, the diagonal is set to it. Rows are then normalised.

    Args:
        depth_np: depth array with 2, 3 or 4 dims; a single-channel map is expected.
        alpha, beta: exponents for the depth and position kernels.
        sigma: shared Gaussian bandwidth.
        keep_ratio: diagonal (self) weight used when below 1.

    Returns:
        A (64, 64) float32 tensor for a single-channel depth map.

    Raises:
        ValueError: if the depth array is not 2-, 3- or 4-dimensional.
    """
    import torch.nn.functional as F

    depth_tensor = torch.tensor(depth_np, dtype=torch.float32)
    # Match chair_mplug.py: bring depth to 4-D for F.interpolate.
    if depth_tensor.dim() == 2:
        depth_tensor = depth_tensor.unsqueeze(0).unsqueeze(0)
    elif depth_tensor.dim() == 3:
        if depth_tensor.shape[0] == 1:
            depth_tensor = depth_tensor.unsqueeze(1)
        else:
            depth_tensor = depth_tensor.unsqueeze(0)
    elif depth_tensor.dim() != 4:
        raise ValueError(f"Unexpected depth_tensor dim: {depth_tensor.shape}")

    depth_patch = F.interpolate(depth_tensor, size=(8, 8), mode="bilinear", align_corners=False)
    depth_patch = torch.clamp(1.0e-6 + 1.0 / depth_patch, 0.001, 1000).view(1, -1)
    depth_patch_min = depth_patch.min()
    depth_patch_max = depth_patch.max()
    depth_patch = (depth_patch - depth_patch_min) / (depth_patch_max - depth_patch_min + 1e-6)

    bandwidth = 2 * (sigma ** 2) + 1e-12
    depth_diff = depth_patch - depth_patch.transpose(1, 0)  # pairwise differences
    gaussian_weight_depth = torch.exp(-(depth_diff ** 2) / bandwidth)

    pixel_positions = torch.tensor([[i // 8, i % 8] for i in range(64)], dtype=torch.float32) / 7.0
    position_diff = torch.cdist(pixel_positions, pixel_positions, p=2)
    gaussian_weight_position = torch.exp(-(position_diff ** 2) / bandwidth)

    # Match chair_mplug.py: the two kernels are added, not multiplied.
    gaussian_weight = (gaussian_weight_depth ** alpha) + (gaussian_weight_position ** beta)

    self_keep = max(0.0, min(1.0, float(keep_ratio)))
    if self_keep < 1.0:
        eye = torch.eye(gaussian_weight.shape[0], device=gaussian_weight.device, dtype=gaussian_weight.dtype)
        gaussian_weight = gaussian_weight * (1.0 - eye) + self_keep * eye

    row_sums = gaussian_weight.sum(dim=-1, keepdim=True)
    zero_row_mask = row_sums < 1e-12
    if zero_row_mask.any():
        gaussian_weight[zero_row_mask.squeeze(-1)] = 1.0
        row_sums = gaussian_weight.sum(dim=-1, keepdim=True)
    return gaussian_weight / (row_sums + 1e-6)


def _dscr_mix_layers(tensors, D, cut_idx, num_image_tokens, start_layer_idx, end_layer_idx, dscr_lambda):
    """Blend D-smoothed image tokens into layers ``[start_layer_idx, end_layer_idx)``.

    ``tensors`` is a per-layer list of cache tensors laid out as (B, H, T, head_dim).
    For each selected layer, batch element 0's positions
    ``[cut_idx, cut_idx + num_image_tokens)`` become
    ``(1 - dscr_lambda) * seg + dscr_lambda * (D @ seg)``. Selected layers keep only
    batch element 0, so this assumes batch size 1. Other layers are returned unchanged.

    Raises:
        ValueError: if the image segment extends past the end of the cache.
    """
    seq_len = tensors[0].shape[2]
    if cut_idx + num_image_tokens > seq_len:
        raise ValueError(
            f"DSCR image segment [{cut_idx}, {cut_idx + num_image_tokens}) exceeds cache length {seq_len}"
        )
    stacked = torch.stack(tensors, dim=0)  # (L, B, H, T, head_dim)
    window = stacked[start_layer_idx:end_layer_idx, 0]  # (L', H, T, head_dim), batch 0 only
    seg = window[:, :, cut_idx:cut_idx + num_image_tokens, :]
    mixed = (1.0 - dscr_lambda) * seg + dscr_lambda * torch.matmul(D, seg)
    refined = torch.cat(
        (window[:, :, :cut_idx, :], mixed, window[:, :, cut_idx + num_image_tokens:, :]),
        dim=2,
    ).unsqueeze(1)  # restore the batch dim -> (L', 1, H, T, head_dim)
    result = list(tensors)
    result[start_layer_idx:end_layer_idx] = list(refined)
    return result


def _build_dscr_cache(model, input_ids, image_bchw, D, cut_idx, *, start_layer_idx, end_layer_idx,
                      dscr_lambda, refine_keys, refine_values, num_image_tokens=64):
    """Run one prompt forward pass and return its KV cache with DSCR-refined image tokens.

    ``num_image_tokens`` defaults to 64 (8x8) to match chair_mplug.py. ``dscr_lambda``
    is clamped to [0, 1].

    Returns:
        Tuple of per-layer ``(key, value)`` pairs.
    """
    attention_mask_clean = torch.ones_like(input_ids, device=input_ids.device)
    with torch.inference_mode():
        kv_out = model(
            input_ids=input_ids,
            images=image_bchw,
            attention_mask=attention_mask_clean,
            use_cache=True,
            return_dict=True,
        )
    past = kv_out.past_key_values

    # Cast D to the cache's dtype/device. These are locals here, so the caller's
    # model-input device is left untouched.
    D = D.to(dtype=past[0][0].dtype, device=past[0][0].device)

    keys = [layer[0] for layer in past]
    keys = [key.to(keys[0].device) for key in keys]
    values = [layer[1] for layer in past]
    values = [val.to(values[0].device) for val in values]

    dscr_lambda = max(0.0, min(1.0, float(dscr_lambda)))
    start_layer_idx = int(start_layer_idx)
    end_layer_idx = int(end_layer_idx)

    if refine_keys:
        keys = _dscr_mix_layers(keys, D, cut_idx, num_image_tokens, start_layer_idx, end_layer_idx, dscr_lambda)
    if refine_values:
        values = _dscr_mix_layers(values, D, cut_idx, num_image_tokens, start_layer_idx, end_layer_idx, dscr_lambda)
    return tuple(zip(keys, values))


def main() -> None:
    """Caption COCO images with mPLUG-Owl2 and write one JSON line per image.

    Without method flags the model samples (``do_sample=True``). ``--use-vcd`` adds a
    diffusion-noised copy of the image (``images_cd``) plus ``cd_alpha``/``cd_beta`` to
    ``generate``. ``--use-opera`` enables ``opera_decoding`` with ``num_beams=--beam``
    and no sampling. ``--use-dscr`` pre-computes the prompt KV cache and mixes its
    image-token keys/values with a depth/position matrix before generation. Each output
    line is ``{"image_id": ..., "caption": ...}``.
    """
    args = _build_arg_parser().parse_args()
    set_seed(args.seed)
    disable_torch_init()

    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, _ = load_pretrained_model(
        model_path, args.model_base, model_name
    )
    # Device for all model inputs; stays fixed across images.
    device = next(model.parameters()).device

    if args.use_vcd:
        evolve_vcd_sampling()

    # DSCR refines both keys and values when --dscr-key-value is given or when neither
    # --dscr-key-only nor --dscr-value-only is.
    dscr_key_value = args.dscr_key_value or (
        not args.dscr_key_only and not args.dscr_value_only
    )

    # Print configuration for debugging
    print("=" * 80)
    print("[CHAIR mPLUG CONFIG]")
    print(f"  Model: {model_name}")
    print(f"  Method: {'VCD' if args.use_vcd else 'OPERA' if args.use_opera else 'BASELINE'}")
    if args.use_vcd:
        print(f"  VCD: noise_step={args.noise_step} cd_alpha={args.cd_alpha} cd_beta={args.cd_beta}")
    if args.use_opera:
        print(f"  OPERA: beam={args.beam} scale_factor={args.scale_factor} threshold={args.threshold} "
              f"num_attn_candidates={args.num_attn_candidates} penalty_weights={args.penalty_weights}")
    if args.use_dscr:
        print(f"  DSCR: alpha={args.dscr_alpha} beta={args.dscr_beta} sigma={args.dscr_sigma} "
              f"keep_ratio={args.dscr_keep_ratio} lambda={args.dscr_lambda} "
              f"start_layer={args.dscr_start_layer} end_layer={args.dscr_end_layer} "
              f"key_only={args.dscr_key_only} value_only={args.dscr_value_only} key_value={dscr_key_value}")
    print(f"  Temperature: {args.temperature}, top_p: {args.top_p}, top_k: {args.top_k}")
    print(f"  Max new tokens: {args.max_new_tokens}, Seed: {args.seed}")
    print("=" * 80)

    images = load_coco_images(args.coco_ann_file, args.seed, args.num_samples)

    out_dir = os.path.dirname(args.answers_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # The prompt is identical for every image: tokenize it and locate the image
    # placeholder once rather than per iteration.
    prompt = build_prompt(args.conv_mode)
    input_ids = tokenizer_image_token(
        prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
    ).unsqueeze(0).to(device)
    pos = torch.where(input_ids[0] == IMAGE_TOKEN_INDEX)[0]
    if pos.numel() == 0:
        raise RuntimeError("IMAGE_TOKEN_INDEX not found in input_ids")
    cut_idx = int(pos[0].item())

    is_baseline = not (args.use_vcd or args.use_opera or args.use_dscr)
    is_dscr_only = args.use_dscr and not (args.use_vcd or args.use_opera)

    with open(args.answers_file, "w") as f:
        for info in tqdm(images):
            image_id = info["id"]
            image_file = info["file_name"]
            img_path = os.path.join(args.data_path, image_file)
            # Close the file handle promptly; convert() returns a loaded copy.
            with Image.open(img_path) as raw_image:
                image = raw_image.convert("RGB")
            image_tensor = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]

            image_bchw = image_tensor.unsqueeze(0).to(device=device, dtype=model.dtype)
            image_bchw_cd = None
            if args.use_vcd:
                image_tensor_cd = add_diffusion_noise(image_tensor, int(args.noise_step))
                image_bchw_cd = image_tensor_cd.unsqueeze(0).to(device=device, dtype=model.dtype)

            cache_clean = None
            if args.use_dscr:
                depth_path = os.path.join(args.depth_folder, os.path.splitext(image_file)[0] + ".npy")
                # NOTE(security): allow_pickle=True lets a crafted .npy execute code on load;
                # only point --depth-folder at trusted data.
                depth_np = np.load(depth_path, allow_pickle=True)
                D = _build_dscr_matrix(
                    depth_np, args.dscr_alpha, args.dscr_beta, args.dscr_sigma, args.dscr_keep_ratio
                )
                cache_clean = _build_dscr_cache(
                    model, input_ids, image_bchw, D, cut_idx,
                    start_layer_idx=args.dscr_start_layer,
                    end_layer_idx=args.dscr_end_layer,
                    dscr_lambda=args.dscr_lambda,
                    refine_keys=args.dscr_key_only or dscr_key_value,
                    refine_values=args.dscr_value_only or dscr_key_value,
                )

            if is_dscr_only:
                # DSCR only: match chair_mplug.py (no sampling, no temperature, 1024-token budget).
                gen_kwargs = dict(
                    input_ids=input_ids,
                    images=image_bchw,
                    do_sample=False,
                    max_new_tokens=1024,
                    use_cache=True,
                    past_key_values=cache_clean,
                )
            else:
                gen_kwargs = dict(
                    input_ids=input_ids,
                    images=image_bchw,
                    do_sample=is_baseline or args.use_vcd,
                    temperature=args.temperature,
                    top_p=args.top_p,
                    top_k=args.top_k,
                    max_new_tokens=args.max_new_tokens,
                )
                if cache_clean is not None:
                    gen_kwargs.update(use_cache=True, past_key_values=cache_clean)

            if args.use_vcd:
                gen_kwargs.update(
                    images_cd=image_bchw_cd,
                    cd_alpha=args.cd_alpha,
                    cd_beta=args.cd_beta,
                )

            if args.use_opera:
                # NOTE(review): this assumes 576 image tokens, while the DSCR path above
                # treats the image as 64 tokens — confirm which matches this model.
                key_position = {
                    "image_start": cut_idx,
                    "image_end": cut_idx + 576 - 1,
                    "response_start": input_ids.shape[1] + 576 - 1,
                }
                gen_kwargs.update(
                    key_position=key_position,
                    opera_decoding=True,
                    num_beams=args.beam,
                    output_attentions=True,
                    scale_factor=args.scale_factor,
                    threshold=args.threshold,
                    num_attn_candidates=args.num_attn_candidates,
                    penalty_weights=args.penalty_weights,
                    do_sample=False,
                )

            with torch.inference_mode():
                output_ids = model.generate(**gen_kwargs)

            decoded = tokenizer.decode(
                output_ids[0, input_ids.shape[1]:], skip_special_tokens=True
            ).strip()

            json.dump({"image_id": image_id, "caption": decoded}, f)
            f.write("\n")


if __name__ == "__main__":
    # Run the captioning script only when executed directly, not on import.
    main()
