# import argparse
# import torch
# import os
# import json
# from tqdm import tqdm
# import shortuuid
# import math

# from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
# from llava.conversation import conv_templates, SeparatorStyle
# from llava.model.builder import load_pretrained_model, load_pretrained_model_both
# from llava.utils import disable_torch_init
# from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
# from torch.utils.data import Dataset, DataLoader
# from peft import PeftModel

# from PIL import Image
# import numpy as np
# import torch.nn.functional as F
# import matplotlib.pyplot as plt

# # --------------------
# # Helpers for batching
# # --------------------

# def split_list(lst, n):
#     chunk_size = math.ceil(len(lst) / n)
#     return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]


# def get_chunk(lst, n, k):
#     chunks = split_list(lst, n)
#     return chunks[k]


# # --------------------
# # Dataset
# # --------------------
# class CustomDataset(Dataset):
#     def __init__(self, questions, image_folder, tokenizer, image_processor, model_config, conv_mode):
#         self.questions = questions
#         self.image_folder = image_folder
#         self.tokenizer = tokenizer
#         self.image_processor = image_processor
#         self.model_config = model_config
#         self.conv_mode = conv_mode

#     def __getitem__(self, index):
#         line = self.questions[index]
#         image_file = line["image"]
#         qs = line["text"]

#         if self.model_config.mm_use_im_start_end:
#             qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
#         else:
#             qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

#         conv = conv_templates[self.conv_mode].copy()
#         conv.append_message(conv.roles[0], qs)
#         conv.append_message(conv.roles[1], None)
#         prompt = conv.get_prompt()

#         image = Image.open(os.path.join(self.image_folder, image_file)).convert('RGB')
#         image_tensor = process_images([image], self.image_processor, self.model_config)[0]

#         input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')

#         return input_ids, image_tensor, image.size, line, image_file

#     def __len__(self):
#         return len(self.questions)


# def collate_fn(batch):
#     input_ids, image_tensors, image_sizes, lines, image_files = zip(*batch)
#     input_ids = torch.stack(input_ids, dim=0)
#     image_tensors = torch.stack(image_tensors, dim=0)
#     return input_ids, image_tensors, image_sizes, lines, image_files


# def create_data_loader(questions, image_folder, tokenizer, image_processor, model_config, conv_mode, batch_size=1, num_workers=1):
#     assert batch_size == 1, "batch_size must be 1 for attention alignment"
#     dataset = CustomDataset(questions, image_folder, tokenizer, image_processor, model_config, conv_mode)
#     data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, collate_fn=collate_fn)
#     return data_loader


# # --------------------
# # Attention utilities
# # --------------------

# def _get_device(model):
#     return next(model.parameters()).device


# def aggregate_llm_attentions(attentions, layer_agg="mean", head_agg="mean"):
#     attn = torch.stack(attentions, dim=0)
#     if head_agg == "mean":
#         attn = attn.mean(dim=2)
#     elif head_agg == "max":
#         attn = attn.max(dim=2).values
#     else:
#         raise ValueError("head_agg must be 'mean' or 'max'")

#     if layer_agg == "mean":
#         attn = attn.mean(dim=0)
#     elif layer_agg == "last":
#         attn = attn[-1]
#     else:
#         raise ValueError("layer_agg must be 'mean' or 'last'")
#     return attn


# def get_vision_backbone(model):
#     vt = model.get_vision_tower()
#     if hasattr(vt, 'vision_tower') and hasattr(vt.vision_tower, 'vision_model'):
#         return vt.vision_tower.vision_model
#     if hasattr(vt, 'vision_model'):
#         return vt.vision_model
#     if hasattr(vt, 'model') and hasattr(vt.model, 'vision_model'):
#         return vt.model.vision_model
#     raise RuntimeError("Cannot locate vision backbone that exposes attentions.")


# def get_vit_patch_attention_maps(model, image_tensor, layer_index=-1):
#     device = _get_device(model)
#     vision_model = get_vision_backbone(model)
#     with torch.inference_mode():
#         out = vision_model(pixel_values=image_tensor.to(device), output_attentions=True, return_dict=True)
#     atn = out.attentions[layer_index][0]
#     atn = atn.mean(dim=0)

#     atn = atn[1:, 1:]
#     num_patches = atn.shape[0]

#     g = int(num_patches ** 0.5)
#     if g * g != num_patches:
#         Gh, Gw = g, num_patches // g
#     else:
#         Gh = Gw = g

#     vis_maps = atn.reshape(num_patches, Gh, Gw)
#     return vis_maps, (Gh, Gw)


# def find_image_placeholder_positions(input_ids_row, image_token_index: int):
#     pos = (input_ids_row == image_token_index).nonzero(as_tuple=True)[0]
#     return pos.tolist()


# def compute_lang_to_image_maps(model, input_ids_full, image_tensor, image_sizes, vis_maps, grid_size, layer_agg="mean", head_agg="mean", normalize=True):
#     device = _get_device(model)
#     with torch.inference_mode():
#         out = model(
#             input_ids=input_ids_full.to(device),
#             images=image_tensor.to(device, dtype=torch.float16),
#             image_sizes=image_sizes,
#             use_cache=False,
#             output_attentions=True,
#             return_dict=True,
#         )

#     llm_attn = aggregate_llm_attentions(out.attentions, layer_agg=layer_agg, head_agg=head_agg)[0]

#     placeholder_positions = find_image_placeholder_positions(input_ids_full[0], IMAGE_TOKEN_INDEX)
#     if len(placeholder_positions) == 0:
#         raise RuntimeError("No <image> placeholder found in input_ids.")
#     im_placeholder = placeholder_positions[0]

#     num_vis_tokens = vis_maps.shape[0]
#     vision_token_start = im_placeholder
#     vision_token_end = im_placeholder + num_vis_tokens

#     if not hasattr(input_ids_full, "_prompt_len"):
#         raise RuntimeError("input_ids_full must have attribute _prompt_len (original prompt length).")
#     prompt_len = input_ids_full._prompt_len

#     token_maps = []
#     Gh, Gw = grid_size

#     for tgt_idx in range(prompt_len, input_ids_full.shape[1]):
#         w = llm_attn[tgt_idx, vision_token_start:vision_token_end]
#         if normalize:
#             w = w.clamp_min(0)
#             denom = w.sum().clamp_min(1e-6)
#             w = w / denom
#         lang_map = (w.view(-1, 1, 1) * vis_maps).sum(dim=0)
#         if normalize:
#             s = lang_map.sum().clamp_min(1e-6)
#             lang_map = lang_map / s
#         token_maps.append(lang_map.detach().cpu())

#     return token_maps


# def upsample_map_to_image(attn_map, image_hw):
#     Gh, Gw = attn_map.shape
#     H, W = image_hw
#     grid = attn_map.unsqueeze(0).unsqueeze(0)
#     grid = F.interpolate(grid, size=(H, W), mode='bilinear', align_corners=False)[0, 0]
#     grid = (grid - grid.min()) / (grid.max() - grid.min() + 1e-6)
#     return grid


# def overlay_heatmap_on_image(image_path, attn_map, out_path):
#     img = Image.open(image_path).convert("RGB")
#     img = np.array(img)
#     H, W = img.shape[:2]

#     attn_map_resized = upsample_map_to_image(attn_map, (H, W)).numpy()
#     cmap = plt.get_cmap("jet")
#     heatmap = cmap(attn_map_resized)[:, :, :3]
#     overlay = (0.6 * img / 255.0 + 0.4 * heatmap)
#     overlay = (overlay / overlay.max() * 255).astype(np.uint8)

#     Image.fromarray(overlay).save(out_path)


# # --------------------
# # Main eval with attention mapping
# # --------------------

# def eval_model(args):
#     disable_torch_init()
#     model_path = os.path.expanduser(args.model_path)
#     model_name = get_model_name_from_path(model_path)

#     tokenizer, model, image_processor, context_len = load_pretrained_model_both(
#         model_path, args.model_base, model_name, args.use_prompt_tuning
#     )

#     device = _get_device(model)

#     questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
#     answers_file = os.path.expanduser(args.answers_file)
#     os.makedirs(os.path.dirname(answers_file), exist_ok=True)

#     if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
#         args.conv_mode = args.conv_mode + '_mmtag'
#         print(f'Plain model detected. Auto-switching conv mode to {args.conv_mode}.')

#     data_loader = create_data_loader(questions, args.image_folder, tokenizer, image_processor, model.config, args.conv_mode)

#     os.makedirs(args.save_attn_dir, exist_ok=True)

#     with open(answers_file, "w") as ans_file:
#         for (input_ids, image_tensor, image_sizes, lines, image_files) in tqdm(data_loader, total=len(questions)):
#             line = lines[0]
#             image_file = image_files[0]
#             idx = line.get("question_id", line.get("id", shortuuid.uuid()))
#             cur_prompt = line["text"]

#             input_ids = input_ids.to(device=device, non_blocking=True)
#             with torch.inference_mode():
#                 gen = model.generate(
#                     inputs=input_ids,
#                     images=image_tensor.to(device=device, dtype=torch.float16, non_blocking=True),
#                     image_sizes=image_sizes,
#                     do_sample=True if args.temperature > 0 else False,
#                     temperature=args.temperature,
#                     top_p=args.top_p,
#                     num_beams=args.num_beams,
#                     max_new_tokens=args.max_new_tokens,
#                     use_cache=True,
#                     return_dict_in_generate=True,
#                     output_attentions=False,
#                 )

#             sequences = gen.sequences
#             outputs = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0].strip()

#             full_ids = sequences
#             full_ids._prompt_len = input_ids.shape[1]

#             vis_maps, grid_size = get_vit_patch_attention_maps(model, image_tensor, layer_index=args.vis_layer_index)

#             token_maps = compute_lang_to_image_maps(
#                 model,
#                 full_ids,
#                 image_tensor,
#                 image_sizes,
#                 vis_maps,
#                 grid_size,
#                 layer_agg=args.layer_agg,
#                 head_agg=args.head_agg,
#                 normalize=True,
#             )

#             saved_maps = []
#             if args.save_heatmap:
#                 H, W = image_sizes[0][1], image_sizes[0][0]
#                 tokens = tokenizer.convert_ids_to_tokens(sequences[0])
#                 for t, m in enumerate(token_maps):
#                     token_str = tokens[full_ids._prompt_len + t]
#                     overlay_path = os.path.join(args.save_attn_dir, f"{idx}_tok{t:03d}_{token_str}.png")
#                     overlay_heatmap_on_image(os.path.join(args.image_folder, image_file), m, overlay_path)
#                     saved_maps.append(overlay_path)

#             ans_id = shortuuid.uuid()
#             record = {
#                 "question_id": idx,
#                 "prompt": cur_prompt,
#                 "text": outputs,
#                 "answer_id": ans_id,
#                 "model_id": model_name,
#                 "grid_size": {"h": grid_size[0], "w": grid_size[1]},
#                 "num_token_maps": len(token_maps),
#                 "heatmaps": saved_maps,
#             }
#             ans_file.write(json.dumps(record) + "\n")


# if __name__ == "__main__":
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
#     parser.add_argument("--model-base", type=str, default=None)
#     parser.add_argument("--image-folder", type=str, default="")
#     parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
#     parser.add_argument("--answers-file", type=str, default="answer.jsonl")
#     parser.add_argument("--conv-mode", type=str, default="llava_v1")

#     parser.add_argument("--temperature", type=float, default=0.2)
#     parser.add_argument("--top_p", type=float, default=None)
#     parser.add_argument("--num_beams", type=int, default=1)
#     parser.add_argument("--max_new_tokens", type=int, default=128)

#     parser.add_argument("--use_prompt_tuning", action='store_true', default=True)
#     parser.add_argument("--no_use_prompt_tuning", dest='use_prompt_tuning', action='store_false')

#     parser.add_argument("--layer_agg", type=str, default="mean", choices=["mean", "last"])
#     parser.add_argument("--head_agg", type=str, default="mean", choices=["mean", "max"])
#     parser.add_argument("--vis_layer_index", type=int, default=-1)

#     parser.add_argument("--save_attn_dir", type=str, default="attn_maps")
#     parser.add_argument("--save_heatmap", action='store_true')

#     args = parser.parse_args()

#     eval_model(args)

import argparse
import torch
import os
import json
from tqdm import tqdm
import shortuuid
import math

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model, load_pretrained_model_both
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
from torch.utils.data import Dataset, DataLoader
from peft import PeftModel

from PIL import Image
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
import re

# --------------------
# Helpers for batching
# --------------------

def split_list(lst, n):
    """Split ``lst`` into contiguous chunks of ``ceil(len(lst) / n)`` items.

    Args:
        lst: Sequence to partition; must support ``len`` and slicing.
        n: Desired number of chunks. Expected to be a positive integer.

    Returns:
        A list of slices of ``lst`` in their original order. The last chunk
        may be shorter than the others, and fewer than ``n`` chunks are
        returned when ``lst`` has fewer than ``n`` items. An empty ``lst``
        yields an empty list.
    """
    if len(lst) == 0:
        # A zero chunk size would make ``range`` fail with
        # "arg 3 must not be zero", so short-circuit here.
        return []
    chunk_size = math.ceil(len(lst) / n)
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]


def get_chunk(lst, n, k):
    """Return the ``k``-th (zero-based) chunk of ``split_list(lst, n)``."""
    return split_list(lst, n)[k]


# --------------------
# Dataset
# --------------------
class CustomDataset(Dataset):
    """Dataset that turns question records into prompt ids plus a processed image.

    Each entry of ``questions`` is indexed with the keys ``"image"`` (a path
    joined onto ``image_folder``) and ``"text"`` (the question string).
    """

    def __init__(self, questions, image_folder, tokenizer, image_processor, model_config, conv_mode):
        # Data source.
        self.questions, self.image_folder = questions, image_folder
        # Pre-processing components.
        self.tokenizer, self.image_processor = tokenizer, image_processor
        # Prompt-construction settings.
        self.model_config, self.conv_mode = model_config, conv_mode

    def __getitem__(self, index):
        """Build the model inputs for the record at ``index``.

        Returns:
            A tuple ``(input_ids, image_tensor, image_size, record, image_file)``
            where ``image_size`` is the PIL ``size`` of the loaded image and
            ``record`` is the untouched question entry.
        """
        record = self.questions[index]
        image_name = record["image"]
        question = record["text"]

        # The image placeholder always precedes the question; some models
        # additionally wrap it in explicit start/end markers.
        if self.model_config.mm_use_im_start_end:
            image_marker = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
        else:
            image_marker = DEFAULT_IMAGE_TOKEN
        question = image_marker + '\n' + question

        # One user turn followed by an empty assistant turn to be completed.
        conversation = conv_templates[self.conv_mode].copy()
        conversation.append_message(conversation.roles[0], question)
        conversation.append_message(conversation.roles[1], None)
        prompt_text = conversation.get_prompt()

        image_path = os.path.join(self.image_folder, image_name)
        pil_image = Image.open(image_path).convert('RGB')
        pixel_values = process_images([pil_image], self.image_processor, self.model_config)[0]

        token_ids = tokenizer_image_token(prompt_text, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')

        return token_ids, pixel_values, pil_image.size, record, image_name

    def __len__(self):
        """Return the number of question records."""
        return len(self.questions)


def collate_fn(batch):
    """Collate dataset samples into batched tensors plus per-sample tuples.

    Args:
        batch: Iterable of 5-tuples
            ``(input_ids, image_tensor, image_size, record, image_file)``.

    Returns:
        ``(input_ids, image_tensors, image_sizes, records, image_files)`` where
        the first two are stacked along a new leading dimension and the
        remaining three are tuples with one entry per sample.
    """
    ids_per_sample, pixels_per_sample, sizes, records, file_names = zip(*batch)
    return (
        torch.stack(ids_per_sample, dim=0),
        torch.stack(pixels_per_sample, dim=0),
        sizes,
        records,
        file_names,
    )


def create_data_loader(questions, image_folder, tokenizer, image_processor, model_config, conv_mode, batch_size=1, num_workers=1):
    """Build a sequential (unshuffled) ``DataLoader`` over ``CustomDataset``.

    Only ``batch_size == 1`` is accepted; any other value trips the assertion
    below.
    """
    assert batch_size == 1, "batch_size must be 1 for attention alignment"
    return DataLoader(
        CustomDataset(questions, image_folder, tokenizer, image_processor, model_config, conv_mode),
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
        collate_fn=collate_fn,
    )


# --------------------
# Attention utilities
# --------------------

def _get_device(model):
    """Return the device of the first parameter yielded by ``model.parameters()``."""
    first_param = next(model.parameters())
    return first_param.device


def aggregate_llm_attentions(attentions, layer_agg="mean", head_agg="mean"):
    """Collapse per-layer attention tensors over heads and then over layers.

    The tensors in ``attentions`` are stacked along a new leading (layer)
    dimension. Dimension 2 of the stacked tensor is reduced first
    (assumed to be the head axis, i.e. each layer tensor is laid out as
    ``(batch, heads, query, key)`` — TODO confirm against the model), then
    the layer dimension.

    Args:
        attentions: Sequence of equally-shaped tensors, one per layer.
        layer_agg: ``"mean"`` averages over layers, ``"last"`` keeps only the
            final layer.
        head_agg: ``"mean"`` or ``"max"`` reduction over dimension 2.

    Returns:
        The aggregated tensor with both the layer and head dimensions removed.

    Raises:
        ValueError: If ``head_agg`` or ``layer_agg`` is not a supported value.
    """
    stacked = torch.stack(attentions, dim=0)

    if head_agg == "max":
        per_layer = stacked.max(dim=2).values
    elif head_agg == "mean":
        per_layer = stacked.mean(dim=2)
    else:
        raise ValueError("head_agg must be 'mean' or 'max'")

    if layer_agg == "last":
        return per_layer[-1]
    if layer_agg == "mean":
        return per_layer.mean(dim=0)
    raise ValueError("layer_agg must be 'mean' or 'last'")


def get_vision_backbone(model):
    """Locate the ``vision_model`` object behind ``model.get_vision_tower()``.

    The tower is probed in a fixed order: ``tower.vision_tower.vision_model``,
    then ``tower.vision_model``, then ``tower.model.vision_model``. The first
    match wins.

    Raises:
        RuntimeError: If none of the probed attribute paths exists.
    """
    tower = model.get_vision_tower()
    missing = object()  # sentinel so an attribute that is present but None still counts as present

    # ``None`` stands for "look on the tower itself" rather than on a wrapper attribute.
    for wrapper_name in ('vision_tower', None, 'model'):
        holder = tower if wrapper_name is None else getattr(tower, wrapper_name, missing)
        if holder is not missing and hasattr(holder, 'vision_model'):
            return holder.vision_model

    raise RuntimeError("Cannot locate vision backbone that exposes attentions.")


def get_vit_patch_attention_maps(model, image_tensor, layer_index=-1):
    """Extract head-averaged patch-to-patch attention maps from the vision encoder.

    Args:
        model: Multimodal model exposing ``get_vision_tower()``.
        image_tensor: Pre-processed pixel values; the first batch element is
            the one whose attention is returned.
        layer_index: Index into the encoder's per-layer attention tuple.

    Returns:
        ``(vis_maps, (Gh, Gw))`` where ``vis_maps`` has shape
        ``(num_patches, Gh, Gw)`` — one attention map over the patch grid per
        source patch — and ``Gh * Gw == num_patches``.

    Raises:
        RuntimeError: If the vision encoder returns no attention weights.
    """
    vision_model = get_vision_backbone(model)

    # Feed the encoder inputs on its own device and in its own floating dtype:
    # the raw pre-processed tensor may be float32 while the tower runs in
    # half precision, which would otherwise fail with a dtype mismatch.
    ref_param = next(vision_model.parameters(), None)
    if ref_param is None:
        pixel_values = image_tensor.to(_get_device(model))
    elif ref_param.dtype.is_floating_point:
        pixel_values = image_tensor.to(device=ref_param.device, dtype=ref_param.dtype)
    else:
        pixel_values = image_tensor.to(ref_param.device)

    with torch.inference_mode():
        out = vision_model(pixel_values=pixel_values, output_attentions=True, return_dict=True)
    if out.attentions is None:
        raise RuntimeError(
            "Vision backbone returned no attentions; it may use an attention "
            "implementation that does not expose attention weights."
        )

    atn = out.attentions[layer_index][0]  # first batch element
    atn = atn.mean(dim=0)  # average over the leading (head) dimension

    # Drop the first row/column — presumably the CLS token, as in CLIP ViT.
    atn = atn[1:, 1:]
    num_patches = atn.shape[0]

    # Pick the most square grid whose sides multiply exactly to num_patches,
    # so the reshape below cannot fail for non-square patch counts.
    Gh = int(num_patches ** 0.5)
    while Gh > 1 and num_patches % Gh != 0:
        Gh -= 1
    Gw = num_patches // Gh if Gh > 0 else 0

    vis_maps = atn.reshape(num_patches, Gh, Gw)
    return vis_maps, (Gh, Gw)


def get_token_id_safe(tokenizer, token_str):
    """Look up the vocabulary id of ``token_str`` without ever raising.

    Args:
        tokenizer: Object exposing ``convert_tokens_to_ids``; ``unk_token`` and
            ``unk_token_id`` attributes are consulted when present.
        token_str: Token text to look up.

    Returns:
        The integer id, or ``None`` when the lookup fails, yields ``None``, or
        resolves to the tokenizer's unknown-token id for a token that is not
        itself the unknown token (i.e. the token is not in the vocabulary).
    """
    try:
        token_id = tokenizer.convert_tokens_to_ids(token_str)
        if isinstance(token_id, list):
            token_id = token_id[0]
        if token_id is None:
            return None
        token_id = int(token_id)

        # Many tokenizers map out-of-vocabulary tokens to the unk id instead of
        # failing; reporting that id would make callers match unrelated tokens.
        unk_id = getattr(tokenizer, "unk_token_id", None)
        if (
            unk_id is not None
            and token_id == unk_id
            and token_str != getattr(tokenizer, "unk_token", None)
        ):
            return None
        return token_id
    except Exception:
        # Deliberately broad: this helper is a best-effort probe and callers
        # treat ``None`` as "token not available".
        return None


def find_vision_start_index(input_ids_row, tokenizer):
    """Return the position of the first image marker in a row of token ids.

    The row is searched, in order, for ``IMAGE_TOKEN_INDEX``, then for the
    tokenizer id of ``DEFAULT_IMAGE_TOKEN``, then for the tokenizer id of
    ``DEFAULT_IM_START_TOKEN``.

    Raises:
        RuntimeError: If none of the markers occurs in the row.
    """
    id_list = input_ids_row.tolist()

    # Defensive: any failure while probing for the sentinel index falls
    # through to the tokenizer-based lookups below.
    try:
        if IMAGE_TOKEN_INDEX in id_list:
            return id_list.index(IMAGE_TOKEN_INDEX)
    except Exception:
        pass

    for marker in (DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN):
        marker_id = get_token_id_safe(tokenizer, marker)
        if marker_id is not None and marker_id in id_list:
            return id_list.index(marker_id)

    raise RuntimeError("Cannot locate image placeholder span in input_ids.")


def compute_lang_to_image_maps(model, input_ids_full, prompt_input_ids, tokenizer, image_tensor, image_sizes, vis_maps, grid_size, layer_agg="mean", head_agg="mean", normalize=True):
    """Project language-model attention over vision tokens onto the patch grid.

    For every position from ``input_ids_full._prompt_len`` to the end of
    ``input_ids_full``, the aggregated attention row over the vision-token
    span is used as weights for a weighted sum of ``vis_maps``.

    Args:
        model: Multimodal model; called once with ``output_attentions=True``.
        input_ids_full: Token ids fed to the forward pass. Must carry a
            ``_prompt_len`` attribute holding the original prompt length.
        prompt_input_ids: Prompt-only ids; row 0 is searched for the image marker.
        tokenizer: Tokenizer used as a fallback to identify the image marker.
        image_tensor: Pre-processed image, cast to float16 for the forward pass.
        image_sizes: Passed through to the model unchanged.
        vis_maps: Tensor of per-patch maps; dimension 0 indexes vision patches.
        grid_size: Accepted for interface compatibility; not used here.
        layer_agg: Layer aggregation mode for ``aggregate_llm_attentions``.
        head_agg: Head aggregation mode for ``aggregate_llm_attentions``.
        normalize: When true, weights are clamped to be non-negative and both
            the weights and each resulting map are normalised to sum to one.

    Returns:
        A list of CPU tensors, one per target position, each shaped like a
        single entry of ``vis_maps``.

    Raises:
        RuntimeError: If ``input_ids_full`` lacks ``_prompt_len``, the image
            marker cannot be located, or the model returns no attentions.
    """
    # Cheap validation first, so a bad call fails before the forward pass.
    if not hasattr(input_ids_full, "_prompt_len"):
        raise RuntimeError("input_ids_full must have attribute _prompt_len (original prompt length).")
    prompt_len = input_ids_full._prompt_len
    vision_token_start = find_vision_start_index(prompt_input_ids[0], tokenizer)

    device = _get_device(model)
    with torch.inference_mode():
        out = model(
            input_ids=input_ids_full.to(device),
            images=image_tensor.to(device, dtype=torch.float16),
            image_sizes=image_sizes,
            use_cache=False,
            output_attentions=True,
            return_dict=True,
        )
    if out.attentions is None:
        raise RuntimeError("Model returned no attentions despite output_attentions=True.")

    # Index 0 selects the first batch element of the aggregated attention.
    llm_attn = aggregate_llm_attentions(out.attentions, layer_agg=layer_agg, head_agg=head_agg)[0]

    # Align length mismatch: llm_attn slice length (num_llm_vis_tokens) vs vis_maps (num_vit_patches).
    # NOTE(review): this subtracts the *prompt* length from the attention width;
    # if llm_attn spans prompt + generated tokens the estimate is too large by
    # the number of generated tokens (the min() below caps it) — confirm.
    num_llm_vis_tokens = llm_attn.shape[1] - prompt_input_ids.shape[1] + 1
    num_vis_tokens = vis_maps.shape[0]
    vision_token_end = vision_token_start + min(num_llm_vis_tokens, num_vis_tokens)

    token_maps = []
    # NOTE(review): rows of llm_attn are addressed with positions taken from
    # input_ids_full; if the model expands the image placeholder into several
    # tokens internally these positions may need an offset — confirm.
    for tgt_idx in range(prompt_len, input_ids_full.shape[1]):
        w = llm_attn[tgt_idx, vision_token_start:vision_token_end]
        if normalize:
            w = w.clamp_min(0)
            denom = w.sum().clamp_min(1e-6)
            w = w / denom
        # If mismatch, interpolate weights to vis_maps count
        if w.shape[0] != num_vis_tokens:
            w = F.interpolate(w.view(1, 1, -1), size=num_vis_tokens, mode="linear", align_corners=False).view(-1)
        lang_map = (w.view(-1, 1, 1) * vis_maps).sum(dim=0)
        if normalize:
            s = lang_map.sum().clamp_min(1e-6)
            lang_map = lang_map / s
        token_maps.append(lang_map.detach().cpu())

    return token_maps


def upsample_map_to_image(attn_map, image_hw):
    """Bilinearly resize a 2-D attention map and min-max scale it to [0, 1].

    Args:
        attn_map: 2-D tensor of attention values on the patch grid.
        image_hw: ``(height, width)`` of the target image in pixels.

    Returns:
        A ``(height, width)`` tensor whose minimum is 0 and whose maximum is
        (almost) 1; a constant input map yields all zeros.

    Raises:
        ValueError: If ``attn_map`` is not two-dimensional.
    """
    if attn_map.dim() != 2:
        raise ValueError(f"attn_map must be 2-D, got shape {tuple(attn_map.shape)}")
    H, W = image_hw

    # Half-precision and integer maps are promoted to float32: bilinear
    # interpolation is not reliably available for them on CPU, and the 1e-6
    # epsilon below is too close to fp16 resolution to be dependable.
    if attn_map.dtype != torch.float64:
        attn_map = attn_map.float()

    grid = attn_map[None, None]  # add batch and channel dims for interpolate
    grid = F.interpolate(grid, size=(H, W), mode='bilinear', align_corners=False)[0, 0]

    lowest, highest = grid.min(), grid.max()
    return (grid - lowest) / (highest - lowest + 1e-6)


def overlay_heatmap_on_image(image_path, attn_map, out_path):
    """Blend a "jet"-coloured attention heatmap over an image and save it.

    Args:
        image_path: Path of the source image; it is loaded as RGB.
        attn_map: 2-D attention tensor, resized to the image resolution via
            ``upsample_map_to_image``.
        out_path: Destination path for the blended image.
    """
    pixels = np.array(Image.open(image_path).convert("RGB"))
    height, width = pixels.shape[:2]

    normalized = upsample_map_to_image(attn_map, (height, width)).numpy()
    # The colormap returns RGBA; keep only the RGB channels.
    colors = plt.get_cmap("jet")(normalized)[:, :, :3]

    # 60% image, 40% heatmap, both on a 0..1 scale.
    blended = np.clip(0.6 * pixels / 255.0 + 0.4 * colors, 0.0, 1.0)
    as_bytes = (blended * 255).astype(np.uint8)

    Image.fromarray(as_bytes).save(out_path)


def sanitize_token_for_filename(tok: str) -> str:
    """Turn a tokenizer token into a short, filesystem-friendly string.

    Non-string input is stringified, the "▁" word-boundary marker becomes
    ``_``, every character outside ``[a-zA-Z0-9._-]`` is dropped, an empty
    result is replaced by ``"tok"``, and the output is capped at 40 characters.
    """
    text = tok if isinstance(tok, str) else str(tok)
    cleaned = re.sub(r"[^a-zA-Z0-9._-]", "", text.replace("▁", "_"))
    return (cleaned or "tok")[:40]


# --------------------
# Main eval with attention mapping
# --------------------

def eval_model(args):
    """Generate answers for a JSONL question file and compute per-token attention maps.

    For each question the model generates a response, the vision encoder's
    patch attention is extracted, and one language-to-image map is computed
    per position after the prompt. One JSON record per question is written
    to ``args.answers_file``; when ``args.save_heatmap`` is set, an overlay
    image per token map is also written under ``args.save_attn_dir``.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``model_path``, ``model_base``, ``use_prompt_tuning``,
            ``question_file``, ``answers_file``, ``conv_mode``,
            ``image_folder``, ``save_attn_dir``, ``save_heatmap``,
            ``temperature``, ``top_p``, ``num_beams``, ``max_new_tokens``,
            ``vis_layer_index``, ``layer_agg`` and ``head_agg``.
    """
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)

    tokenizer, model, image_processor, context_len = load_pretrained_model_both(
        model_path, args.model_base, model_name, args.use_prompt_tuning
    )

    device = _get_device(model)

    # Read the JSONL file inside a context manager so the handle is closed,
    # and skip blank lines (e.g. a trailing newline) instead of failing on them.
    with open(os.path.expanduser(args.question_file), "r") as question_fh:
        questions = [json.loads(q) for q in question_fh if q.strip()]

    answers_file = os.path.expanduser(args.answers_file)
    answers_dir = os.path.dirname(answers_file)
    if answers_dir:
        # A bare filename has an empty dirname, and os.makedirs("") raises.
        os.makedirs(answers_dir, exist_ok=True)

    if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
        args.conv_mode = args.conv_mode + '_mmtag'
        print(f'Plain model detected. Auto-switching conv mode to {args.conv_mode}.')

    data_loader = create_data_loader(questions, args.image_folder, tokenizer, image_processor, model.config, args.conv_mode)

    os.makedirs(args.save_attn_dir, exist_ok=True)

    with open(answers_file, "w") as ans_file:
        for (input_ids, image_tensor, image_sizes, lines, image_files) in tqdm(data_loader, total=len(questions)):
            # The loader is fixed at batch size 1, so element 0 is the whole batch.
            line = lines[0]
            image_file = image_files[0]
            idx = line.get("question_id", line.get("id", shortuuid.uuid()))
            cur_prompt = line["text"]

            input_ids = input_ids.to(device=device, non_blocking=True)
            with torch.inference_mode():
                gen = model.generate(
                    inputs=input_ids,
                    images=image_tensor.to(device=device, dtype=torch.float16, non_blocking=True),
                    image_sizes=image_sizes,
                    do_sample=True if args.temperature > 0 else False,
                    temperature=args.temperature,
                    top_p=args.top_p,
                    num_beams=args.num_beams,
                    max_new_tokens=args.max_new_tokens,
                    use_cache=True,
                    return_dict_in_generate=True,
                    output_attentions=False,
                )

            sequences = gen.sequences
            outputs = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0].strip()

            # compute_lang_to_image_maps reads the prompt length from this
            # ad-hoc attribute on the tensor.
            full_ids = sequences
            full_ids._prompt_len = input_ids.shape[1]

            vis_maps, grid_size = get_vit_patch_attention_maps(model, image_tensor, layer_index=args.vis_layer_index)

            token_maps = compute_lang_to_image_maps(
                model,
                full_ids,
                input_ids,  # use prompt ids to locate the image span robustly
                tokenizer,
                image_tensor,
                image_sizes,
                vis_maps,
                grid_size,
                layer_agg=args.layer_agg,
                head_agg=args.head_agg,
                normalize=True,
            )

            saved_maps = []
            if args.save_heatmap:
                tokens = tokenizer.convert_ids_to_tokens(sequences[0])
                for t, m in enumerate(token_maps):
                    token_str_raw = tokens[full_ids._prompt_len + t]
                    token_str = sanitize_token_for_filename(token_str_raw)
                    overlay_path = os.path.join(args.save_attn_dir, f"{idx}_tok{t:03d}_{token_str}.png")
                    overlay_heatmap_on_image(os.path.join(args.image_folder, image_file), m, overlay_path)
                    saved_maps.append(overlay_path)

            ans_id = shortuuid.uuid()
            record = {
                "question_id": idx,
                "prompt": cur_prompt,
                "text": outputs,
                "answer_id": ans_id,
                "model_id": model_name,
                "grid_size": {"h": grid_size[0], "w": grid_size[1]},
                "num_token_maps": len(token_maps),
                "heatmaps": saved_maps,
            }
            ans_file.write(json.dumps(record) + "\n")


if __name__ == "__main__":
    cli = argparse.ArgumentParser()

    # Model, data locations and conversation template.
    cli.add_argument("--model-path", default="facebook/opt-350m", type=str)
    cli.add_argument("--model-base", default=None, type=str)
    cli.add_argument("--image-folder", default="", type=str)
    cli.add_argument("--question-file", default="tables/question.jsonl", type=str)
    cli.add_argument("--answers-file", default="answer.jsonl", type=str)
    cli.add_argument("--conv-mode", default="llava_v1", type=str)

    # Generation settings.
    cli.add_argument("--temperature", default=0.2, type=float)
    cli.add_argument("--top_p", default=None, type=float)
    cli.add_argument("--num_beams", default=1, type=int)
    cli.add_argument("--max_new_tokens", default=128, type=int)

    # Prompt tuning is on by default; the second flag switches it off.
    cli.add_argument("--use_prompt_tuning", default=True, action="store_true")
    cli.add_argument("--no_use_prompt_tuning", action="store_false", dest="use_prompt_tuning")

    # Attention aggregation.
    cli.add_argument("--layer_agg", default="mean", choices=["mean", "last"], type=str)
    cli.add_argument("--head_agg", default="mean", choices=["mean", "max"], type=str)
    cli.add_argument("--vis_layer_index", default=-1, type=int)

    # Heatmap output.
    cli.add_argument("--save_attn_dir", default="attn_maps", type=str)
    cli.add_argument("--save_heatmap", action="store_true")

    eval_model(cli.parse_args())
