import argparse
import torch
import os
import json
from tqdm import tqdm
import shortuuid
import math

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model, load_pretrained_model_both
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
from torch.utils.data import Dataset, DataLoader
from peft import PeftModel

from PIL import Image
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
import re

# --------------------
# Helpers for batching
# --------------------

def split_list(lst, n):
    """Split ``lst`` into at most ``n`` contiguous chunks of near-equal size.

    Every chunk holds ``ceil(len(lst) / n)`` items except possibly the last,
    which holds the remainder. Because chunk sizes are rounded up, fewer than
    ``n`` chunks may be returned (e.g. 5 items with ``n=4`` gives 3 chunks).

    Args:
        lst: Sliceable sequence to partition.
        n: Requested number of chunks; must be positive.

    Returns:
        A list of slices of ``lst``; an empty list when ``lst`` is empty.

    Raises:
        ValueError: If ``n`` is not positive.
    """
    if n <= 0:
        raise ValueError(f"n must be a positive number of chunks, got {n}")
    total = len(lst)
    if total == 0:
        # A chunk size of 0 would be an illegal step for range().
        return []
    chunk_size = math.ceil(total / n)
    return [lst[start:start + chunk_size] for start in range(0, total, chunk_size)]


def get_chunk(lst, n, k):
    """Return the ``k``-th of ``n`` shards of ``lst``.

    Shards are produced by ``split_list``. Because shard sizes are rounded
    up, there can be fewer than ``n`` shards (or none for an empty ``lst``);
    a valid shard index ``0 <= k < n`` that falls past the last shard yields
    an empty list instead of raising ``IndexError``.

    Args:
        lst: Sequence to shard.
        n: Total number of shards requested.
        k: Index of the shard to return.

    Returns:
        The selected shard, or ``[]`` when shard ``k`` has no items.

    Raises:
        IndexError: If ``k`` is outside the valid shard range.
    """
    # Avoid delegating an empty input, which has no meaningful chunk size.
    chunks = split_list(lst, n) if len(lst) > 0 else []
    if 0 <= k < n and k >= len(chunks):
        # A worker with nothing left to process gets an empty shard.
        return []
    return chunks[k]


# --------------------
# Dataset
# --------------------
class CustomDataset(Dataset):
    """Dataset yielding tokenized prompt + preprocessed image per question.

    Each entry of ``questions`` is indexed with the keys ``"image"`` (a path
    relative to ``image_folder``) and ``"text"`` (the question string).
    """

    def __init__(self, questions, image_folder, tokenizer, image_processor, model_config, conv_mode):
        self.questions = questions
        self.image_folder = image_folder
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.model_config = model_config
        self.conv_mode = conv_mode

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, index):
        """Return ``(input_ids, image_tensor, image_size, record, image_file)``.

        ``image_size`` is the PIL ``size`` attribute of the RGB-converted
        image; ``record`` is the untouched question entry.
        """
        record = self.questions[index]
        image_name = record["image"]
        question_text = record["text"]

        prompt = self._build_prompt(question_text)

        pil_image = Image.open(os.path.join(self.image_folder, image_name)).convert('RGB')
        pixel_values = process_images([pil_image], self.image_processor, self.model_config)[0]

        token_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')

        return token_ids, pixel_values, pil_image.size, record, image_name

    def _build_prompt(self, question_text):
        """Prefix the question with the image placeholder and render the conversation prompt."""
        if self.model_config.mm_use_im_start_end:
            image_marker = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
        else:
            image_marker = DEFAULT_IMAGE_TOKEN

        conversation = conv_templates[self.conv_mode].copy()
        conversation.append_message(conversation.roles[0], image_marker + '\n' + question_text)
        # The second role's turn is left open (None) before rendering the prompt.
        conversation.append_message(conversation.roles[1], None)
        return conversation.get_prompt()


def collate_fn(batch):
    """Collate dataset samples into a batch.

    The first two fields of every sample (token ids and image tensor) are
    stacked along a new leading dimension; the remaining three fields are
    returned as tuples, one element per sample.
    """
    ids_column, pixel_column, sizes, records, file_names = zip(*batch)
    stacked_ids = torch.stack(ids_column, dim=0)
    stacked_pixels = torch.stack(pixel_column, dim=0)
    return stacked_ids, stacked_pixels, sizes, records, file_names


def create_data_loader(questions, image_folder, tokenizer, image_processor, model_config, conv_mode, batch_size=1, num_workers=1):
    """Wrap the questions in a ``CustomDataset`` and return an unshuffled ``DataLoader``.

    Only ``batch_size == 1`` is accepted; any other value trips the assertion.
    """
    assert batch_size == 1, "batch_size must be 1 for attention alignment"
    return DataLoader(
        CustomDataset(questions, image_folder, tokenizer, image_processor, model_config, conv_mode),
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
        collate_fn=collate_fn,
    )


# --------------------
# Attention utilities
# --------------------

def _get_device(model):
    return next(model.parameters()).device


def aggregate_llm_attentions(attentions, layer_agg="mean", head_agg="mean"):
    """Collapse a sequence of per-layer attention tensors into a single tensor.

    The tensors are stacked on a new leading (layer) axis, reduced over
    axis 2 of the stacked tensor according to ``head_agg``, then reduced
    over the layer axis according to ``layer_agg``.

    Args:
        attentions: Sequence of equally-shaped tensors, one per layer.
        layer_agg: ``"mean"`` averages over layers; ``"last"`` keeps the final layer.
        head_agg: ``"mean"`` or ``"max"`` reduction over axis 2 of the stack.

    Raises:
        ValueError: If ``head_agg`` or ``layer_agg`` is not a supported value.
    """
    stacked = torch.stack(attentions, dim=0)

    if head_agg not in ("mean", "max"):
        raise ValueError("head_agg must be 'mean' or 'max'")
    per_layer = stacked.mean(dim=2) if head_agg == "mean" else stacked.max(dim=2).values

    if layer_agg == "mean":
        return per_layer.mean(dim=0)
    if layer_agg == "last":
        return per_layer[-1]
    raise ValueError("layer_agg must be 'mean' or 'last'")


def get_token_id_safe(tokenizer, token_str):
    """Best-effort lookup of a token's integer id.

    Calls ``tokenizer.convert_tokens_to_ids(token_str)``; if the result is a
    list its first element is used. Returns ``None`` when the lookup yields
    ``None`` or when any exception occurs along the way.
    """
    try:
        token_id = tokenizer.convert_tokens_to_ids(token_str)
        token_id = token_id[0] if isinstance(token_id, list) else token_id
        return None if token_id is None else int(token_id)
    except Exception:
        # Deliberately swallow everything: callers treat None as "not found".
        return None


def find_vision_start_index(input_ids_row, tokenizer):
    """Locate the first image-placeholder position in one row of token ids.

    Candidates are tried in order: ``IMAGE_TOKEN_INDEX``, then the tokenizer
    id of ``DEFAULT_IMAGE_TOKEN``, then that of ``DEFAULT_IM_START_TOKEN``.

    Returns:
        Index of the first occurrence of the first candidate that is present.

    Raises:
        RuntimeError: If none of the candidates occurs in the row.
    """
    token_ids = input_ids_row.tolist()

    try:
        return token_ids.index(IMAGE_TOKEN_INDEX)
    except Exception:
        # Sentinel not present (or not searchable): fall back to tokenizer ids.
        pass

    for marker in (DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN):
        marker_id = get_token_id_safe(tokenizer, marker)
        if marker_id is not None and marker_id in token_ids:
            return token_ids.index(marker_id)

    raise RuntimeError("Cannot locate image placeholder span in input_ids.")


def compute_lang_to_image_maps(model, sequences, prompt_input_ids, tokenizer, image_tensor, image_sizes, num_vis_tokens, grid_size, layer_agg="mean", head_agg="mean", normalize=True, virtual_prompt_len=0):
    """Build one attention map over the image grid for every generated token.

    Runs a single forward pass of ``model`` over ``sequences`` with
    ``output_attentions=True``, aggregates the returned attentions with
    ``aggregate_llm_attentions`` (keeping batch element 0), and for each of
    the last ``gen_len`` query rows extracts the attention over a span of key
    positions beginning at the image placeholder found in
    ``prompt_input_ids[0]``. Each slice is optionally renormalized, resampled
    to ``num_vis_tokens`` entries when its length differs, and reshaped to
    ``grid_size``.

    Args:
        model: Callable model; must accept the keyword arguments used below
            (including ``images``, ``image_sizes`` and ``trace``).
        sequences: Token ids fed to the forward pass. Assumed to be
            ``(1, prompt_len + gen_len)``, i.e. prompt followed by generated
            tokens — TODO confirm against the model's ``generate`` output.
        prompt_input_ids: Prompt token ids; only row 0 and ``shape[1]`` are used.
        tokenizer: Used to look up placeholder token ids.
        image_tensor: Image input, cast to float16 on the model device.
        image_sizes: Passed through to the model unchanged.
        num_vis_tokens: Number of entries each map is resampled to; must equal
            ``grid_size[0] * grid_size[1]`` for the reshape to succeed.
        grid_size: ``(Gh, Gw)`` shape of each returned map.
        layer_agg: Forwarded to ``aggregate_llm_attentions``.
        head_agg: Forwarded to ``aggregate_llm_attentions``.
        normalize: When true, clamp weights to be non-negative and scale the
            slice (and again the final map) to sum to 1.
        virtual_prompt_len: Currently unused; the only reference is in a
            commented-out line below.

    Returns:
        A list of ``gen_len`` CPU tensors, each of shape ``(Gh, Gw)``.

    Raises:
        RuntimeError: If the model output has no attentions, if the computed
            generation length is not positive, or (from
            ``find_vision_start_index``) if no image placeholder is found.
    """
    device = _get_device(model)
    with torch.inference_mode():
        out = model(
            input_ids=sequences.to(device),
            images=image_tensor.to(device, dtype=torch.float16),
            image_sizes=image_sizes,
            use_cache=False,
            output_attentions=True,
            return_dict=True,
            trace=True,  # Enable tracing for debugging
        )

    if not hasattr(out, "attentions") or out.attentions is None:
        raise RuntimeError("Model output does not include attentions. Ensure the model supports output_attentions=True.")

    # shape: [seq_len, seq_len]
    llm_attn = aggregate_llm_attentions(out.attentions, layer_agg=layer_agg, head_agg=head_agg)[0]

    vision_token_start_rel = find_vision_start_index(prompt_input_ids[0], tokenizer)
    # vision_token_start = virtual_prompt_len + vision_token_start_rel
    vision_token_start = vision_token_start_rel
    # NOTE(review): the span is widened by a hard-coded 128 rather than by
    # `virtual_prompt_len`; presumably this covers prompt-tuning virtual
    # tokens — confirm the intended offset against the model implementation.
    vision_token_end = vision_token_start + num_vis_tokens + 128

    # Number of generated tokens = total sequence length minus prompt length.
    gen_len = sequences.shape[1] - prompt_input_ids.shape[1]
    if gen_len <= 0:
        raise RuntimeError(f"Invalid generation length computed: gen_len={gen_len}. sequences_len={sequences.shape[1]}, prompt_len={prompt_input_ids.shape[1]}")

    # Generated tokens are taken to be the last `gen_len` query rows of the
    # attention matrix, which may be longer than `sequences`.
    q_start = llm_attn.shape[0] - gen_len

    token_maps = []
    Gh, Gw = grid_size

    for t in range(gen_len):
        q_idx = q_start + t
        # Attention from this generated token to the candidate image span.
        w = llm_attn[q_idx, vision_token_start:vision_token_end]
        if normalize:
            w = w.clamp_min(0)
            denom = w.sum().clamp_min(1e-6)  # guard against division by zero
            w = w / denom
        # The slice requests num_vis_tokens + 128 entries, so unless it is
        # truncated to exactly num_vis_tokens this linear resampling runs.
        if w.shape[0] != num_vis_tokens:
            w = F.interpolate(w.view(1,1,-1), size=num_vis_tokens, mode="linear", align_corners=False).view(-1)
        lang_map = w.view(Gh, Gw)
        if normalize:
            # Renormalize after resampling so each map sums to 1 again.
            s = lang_map.sum().clamp_min(1e-6)
            lang_map = lang_map / s
        token_maps.append(lang_map.detach().cpu())

    # With gen_len > 0 enforced above the loop always appends, so this is a
    # defensive check only.
    if len(token_maps) == 0:
        raise RuntimeError("No token maps computed. Likely mismatch in sequence vs attention alignment.")

    # debug print
    print(f"[DEBUG] cross-attn maps: seq_len={llm_attn.shape[0]}, vis_start={vision_token_start}, vis_end={vision_token_end}, gen_len={gen_len}, num_vis_tokens={num_vis_tokens}")

    return token_maps


def upsample_map_to_image(attn_map, image_hw):
    """Bilinearly resize a 2-D map to ``image_hw`` and min-max scale it.

    Args:
        attn_map: 2-D tensor (unpacking its shape rejects any other rank).
        image_hw: ``(H, W)`` target size.

    Returns:
        Tensor of shape ``(H, W)`` whose minimum is 0 and whose maximum is
        just below 1 (the range is padded by ``1e-6`` to avoid dividing by zero).
    """
    _grid_h, _grid_w = attn_map.shape  # rank guard: only 2-D maps unpack
    target_h, target_w = image_hw
    resized = F.interpolate(
        attn_map[None, None],  # add batch and channel axes for interpolate
        size=(target_h, target_w),
        mode='bilinear',
        align_corners=False,
    )[0, 0]
    lowest = resized.min()
    highest = resized.max()
    return (resized - lowest) / (highest - lowest + 1e-6)


def overlay_heatmap_on_image(image_path, attn_map, out_path):
    """Blend a "jet"-coloured attention map over an image and save the result.

    The image at ``image_path`` is read as RGB, ``attn_map`` is resized to the
    image's height and width via ``upsample_map_to_image``, and the two are
    mixed 60% image / 40% heatmap before being written to ``out_path``.
    """
    pixels = np.array(Image.open(image_path).convert("RGB"))
    height, width = pixels.shape[:2]

    scaled_map = upsample_map_to_image(attn_map, (height, width)).numpy()
    # The colormap returns RGBA; keep only the RGB channels.
    colors = plt.get_cmap("jet")(scaled_map)[:, :, :3]

    blended = np.clip(0.6 * pixels / 255.0 + 0.4 * colors, 0.0, 1.0)
    Image.fromarray((blended * 255).astype(np.uint8)).save(out_path)


def sanitize_token_for_filename(tok: str) -> str:
    """Turn a token into a short string that is safe inside a filename.

    Non-string input is converted with ``str``. The "▁" marker becomes "_",
    every character outside ``[a-zA-Z0-9._-]`` is dropped, an empty result
    falls back to ``"tok"``, and the output is capped at 40 characters.
    """
    text = tok if isinstance(tok, str) else str(tok)
    cleaned = re.sub(r"[^a-zA-Z0-9._-]", "", text.replace("▁", "_"))
    return (cleaned or "tok")[:40]


# --------------------
# Main eval with attention mapping
# --------------------

def eval_model(args):
    """Generate an answer per question and record per-token attention maps.

    For every entry of the JSONL question file the model generates a
    response, ``compute_lang_to_image_maps`` is asked for one map per
    generated token, heatmap overlays are optionally saved under
    ``args.save_attn_dir``, and one JSON record is appended to
    ``args.answers_file``.

    Args:
        args: Parsed CLI namespace. Reads ``model_path``, ``model_base``,
            ``use_prompt_tuning``, ``question_file``, ``answers_file``,
            ``conv_mode`` (may be rewritten in place), ``image_folder``,
            ``save_attn_dir``, ``save_heatmap``, ``temperature``, ``top_p``,
            ``num_beams``, ``max_new_tokens``, ``layer_agg``, ``head_agg``
            and ``virtual_prompt_len``.

    A ``RuntimeError`` raised while computing the attention maps is reported
    and the sample is still written, with zero maps.
    """
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)

    tokenizer, model, image_processor, _context_len = load_pretrained_model_both(
        model_path, args.model_base, model_name, args.use_prompt_tuning
    )

    device = _get_device(model)

    # Close the question file deterministically and tolerate blank lines
    # (e.g. a trailing newline), which json.loads would reject.
    with open(os.path.expanduser(args.question_file), "r") as question_fh:
        questions = [json.loads(raw) for raw in question_fh if raw.strip()]

    answers_file = os.path.expanduser(args.answers_file)
    answers_dir = os.path.dirname(answers_file)
    if answers_dir:
        # A bare filename has an empty dirname, and os.makedirs("") raises.
        os.makedirs(answers_dir, exist_ok=True)

    if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
        args.conv_mode = args.conv_mode + '_mmtag'
        print(f'Plain model detected. Auto-switching conv mode to {args.conv_mode}.')

    data_loader = create_data_loader(questions, args.image_folder, tokenizer, image_processor, model.config, args.conv_mode)

    if args.save_attn_dir:
        os.makedirs(args.save_attn_dir, exist_ok=True)

    with open(answers_file, "w") as ans_file:
        for (input_ids, image_tensor, image_sizes, lines, image_files) in tqdm(data_loader, total=len(questions)):
            # The loader is fixed at batch size 1, so element 0 is the sample.
            line = lines[0]
            image_file = image_files[0]
            idx = line.get("question_id", line.get("id", shortuuid.uuid()))
            cur_prompt = line["text"]

            input_ids = input_ids.to(device=device, non_blocking=True)
            with torch.inference_mode():
                gen = model.generate(
                    inputs=input_ids,
                    images=image_tensor.to(device=device, dtype=torch.float16, non_blocking=True),
                    image_sizes=image_sizes,
                    do_sample=args.temperature > 0,
                    temperature=args.temperature,
                    top_p=args.top_p,
                    num_beams=args.num_beams,
                    max_new_tokens=args.max_new_tokens,
                    use_cache=True,
                    return_dict_in_generate=True,
                    output_attentions=False,
                )

            sequences = gen.sequences
            outputs = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0].strip()

            # Directly use cross-attention without visual backbone self-attention
            # NOTE(review): assumes a square grid with a 16-pixel patch and that
            # image_tensor.shape[2] is the image height — confirm for the encoder in use.
            Gh = Gw = int((image_tensor.shape[2]//16))
            num_vis_tokens = Gh * Gw

            try:
                token_maps = compute_lang_to_image_maps(
                    model,
                    sequences,
                    input_ids,
                    tokenizer,
                    image_tensor,
                    image_sizes,
                    num_vis_tokens,
                    (Gh, Gw),
                    layer_agg=args.layer_agg,
                    head_agg=args.head_agg,
                    normalize=True,
                    virtual_prompt_len=args.virtual_prompt_len,
                )
            except RuntimeError as e:
                # Best effort: keep the generated answer even without maps.
                print(f"[Warning] Skipping sample {idx} due to attention mapping error: {e}")
                token_maps = []

            saved_maps = []
            if args.save_heatmap and len(token_maps) > 0:
                tokens = tokenizer.convert_ids_to_tokens(sequences[0])
                prompt_visible_len = input_ids.shape[1]
                image_path = os.path.join(args.image_folder, image_file)
                for t, m in enumerate(token_maps):
                    # Label each map with the token at prompt_len + t, clamped
                    # to the last token if the sequence is shorter.
                    idx_in_seq = prompt_visible_len + t
                    token_str_raw = tokens[min(idx_in_seq, len(tokens)-1)]
                    token_str = sanitize_token_for_filename(token_str_raw)
                    overlay_path = os.path.join(args.save_attn_dir, f"{idx}_tok{t:03d}_{token_str}.png")
                    overlay_heatmap_on_image(image_path, m, overlay_path)
                    saved_maps.append(overlay_path)

            ans_id = shortuuid.uuid()
            record = {
                "question_id": idx,
                "prompt": cur_prompt,
                "text": outputs,
                "answer_id": ans_id,
                "model_id": model_name,
                "grid_size": {"h": Gh, "w": Gw},
                "num_token_maps": len(token_maps),
                "heatmaps": saved_maps,
            }
            ans_file.write(json.dumps(record) + "\n")


if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, run the evaluation.
    parser = argparse.ArgumentParser()
    add_arg = parser.add_argument

    # Model and data locations.
    add_arg("--model-path", type=str, default="facebook/opt-350m")
    add_arg("--model-base", type=str, default=None)
    add_arg("--image-folder", type=str, default="")
    add_arg("--question-file", type=str, default="tables/question.jsonl")
    add_arg("--answers-file", type=str, default="answer.jsonl")
    add_arg("--conv-mode", type=str, default="llava_v1")

    # Generation settings.
    add_arg("--temperature", type=float, default=0.2)
    add_arg("--top_p", type=float, default=None)
    add_arg("--num_beams", type=int, default=1)
    add_arg("--max_new_tokens", type=int, default=128)

    # Prompt tuning is on by default; --no_use_prompt_tuning switches it off.
    add_arg("--use_prompt_tuning", action="store_true", default=True)
    add_arg("--no_use_prompt_tuning", dest="use_prompt_tuning", action="store_false")

    # How attentions are reduced across layers and heads.
    add_arg("--layer_agg", type=str, default="mean", choices=["mean", "last"])
    add_arg("--head_agg", type=str, default="mean", choices=["mean", "max"])

    # Heatmap output.
    add_arg("--save_attn_dir", type=str, default="attn_maps")
    add_arg("--save_heatmap", action="store_true")
    add_arg("--virtual_prompt_len", type=int, default=128, help="Number of virtual prompt tokens prepended during prompt tuning")

    args = parser.parse_args()
    eval_model(args)
