import argparse
import torch
import os
import json
from tqdm import tqdm
import shortuuid

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model, load_pretrained_model_both
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
from torch.utils.data import Dataset, DataLoader
from peft import PromptTuningConfig, PromptTuningInit, get_peft_model

from PIL import Image
import math
# Load LoRA adapters
from peft import PeftModel
import pdb
import sys
from .utils_attn import handle_attentions_i2t


def split_list(lst, n):
    """Split *lst* into at most *n* contiguous chunks of (roughly) equal size.

    Every chunk holds ``ceil(len(lst) / n)`` items except possibly the last,
    so fewer than *n* chunks can be returned (9 items, n=4 -> 3 chunks of 3).
    An empty *lst* yields ``[]``.

    Raises:
        ValueError: if ``n < 1``.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n!r}")
    if not lst:
        # Without this guard chunk_size would be 0 and range() would raise.
        return []
    chunk_size = math.ceil(len(lst) / n)  # ceiling division, not floor
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]


def get_chunk(lst, n, k):
    """Return chunk *k* of ``split_list(lst, n)``.

    ``split_list`` may produce fewer than *n* chunks (its chunk size is
    ``ceil(len(lst) / n)``), so for ``len(chunks) <= k < n`` the requested
    shard is legitimately empty. ``[]`` is returned in that case instead of
    raising IndexError. ``k >= n`` is still an IndexError (caller misuse).
    """
    chunks = split_list(lst, n)
    if len(chunks) <= k < n:
        return []
    return chunks[k]


# Custom dataset class
class CustomDataset(Dataset):
    """Map-style dataset yielding one LLaVA prompt/image pair per question.

    Each item is ``(input_ids, image_tensor, image_size, image_path)``:
    the tokenized conversation prompt (with the image placeholder token),
    the processed image tensor, the original ``PIL`` ``(width, height)`` and
    the path the image was loaded from.
    """

    def __init__(self, questions, image_folder, tokenizer, image_processor, model_config, conv_mode=None):
        """
        Args:
            questions: list of dicts with at least ``"image"`` and ``"text"`` keys.
            image_folder: directory that ``"image"`` paths are relative to.
            tokenizer / image_processor / model_config: passed through to
                ``tokenizer_image_token`` / ``process_images``.
            conv_mode: key into ``conv_templates``. When None, the
                module-level ``args.conv_mode`` is read at item-fetch time
                (legacy behaviour for the script entry point).
        """
        self.questions = questions
        self.image_folder = image_folder
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.model_config = model_config
        self.conv_mode = conv_mode

    def _resolve_conv_mode(self):
        """Return the conversation template name to use for prompts."""
        if self.conv_mode is not None:
            return self.conv_mode
        # Legacy fallback: the __main__ block defines a global `args`.
        try:
            return args.conv_mode
        except NameError as err:
            raise RuntimeError(
                "CustomDataset needs conv_mode=... when no global `args` is defined"
            ) from err

    def __getitem__(self, index):
        line = self.questions[index]
        image_file = line["image"]
        qs = line["text"]
        if self.model_config.mm_use_im_start_end:
            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
        else:
            qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

        conv = conv_templates[self._resolve_conv_mode()].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        image_path = os.path.join(self.image_folder, image_file)
        # `with` releases the file handle deterministically.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        image_tensor = process_images([image], self.image_processor, self.model_config)[0]

        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
        return input_ids, image_tensor, image.size, image_path

    def __len__(self):
        return len(self.questions)


def collate_fn(batch):
    """Merge ``CustomDataset`` samples into one batch.

    Token ids and image tensors are stacked along a new leading dimension.
    Image sizes and image paths stay as tuples, in sample order.
    """
    ids_seq, pixel_seq, size_seq, path_seq = zip(*batch)
    return (
        torch.stack(ids_seq, dim=0),
        torch.stack(pixel_seq, dim=0),
        size_seq,
        path_seq,
    )


# DataLoader
def create_data_loader(questions, image_folder, tokenizer, image_processor, model_config, batch_size=1, num_workers=1):
    """Wrap the questions in a ``CustomDataset`` and return an unshuffled DataLoader.

    Only ``batch_size == 1`` is accepted. ``collate_fn`` stacks ``input_ids``
    without padding, so prompts of different lengths could not be batched.
    """
    assert batch_size == 1, "batch_size must be 1"
    return DataLoader(
        CustomDataset(questions, image_folder, tokenizer, image_processor, model_config),
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
        collate_fn=collate_fn,
    )


# --- NOTE: these imports belong at the top of the file (several duplicate earlier ones) ---
import os
import json
import logging
import math
import shortuuid
from types import SimpleNamespace

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
# ------------------------------------------------

# Root logging at INFO so the visualisation progress messages are shown;
# this module logs through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)


def _recover_image_from_tensor(image_tensor):
    """
    Convert an image tensor/array back into an RGB ``PIL.Image``.

    image_tensor: torch.Tensor shaped [B, C, H, W] (only item 0 is used),
        [C, H, W] or [H, W]; anything else goes through ``np.array`` and is
        expected in [H, W] or [H, W, C] layout.
    Return: PIL.Image (RGB)

    Non-uint8 value ranges are guessed heuristically:
      * min < -0.5 and max <= 1.5 -> treated as [-1, 1], shifted to [0, 1];
      * min < -0.5 and max > 1.5  -> e.g. mean/std-normalised pixels;
                                     min-max rescaled to [0, 1];
      * then max <= 1.5           -> treated as [0, 1], scaled by 255;
      * otherwise                 -> treated as [0, 255] and clipped.
    """
    if isinstance(image_tensor, torch.Tensor):
        # .float(): NumPy has no bfloat16, so .numpy() would fail on bf16 inputs.
        t = image_tensor.detach().cpu().float()
        if t.dim() == 4:
            t = t[0]
        if t.dim() == 2:
            t = t.unsqueeze(0)  # bare H,W map -> single channel
        # C,H,W -> H,W,C
        arr = t.permute(1, 2, 0).numpy()
    else:
        arr = np.array(image_tensor)

    if arr.ndim == 2:
        # grayscale H,W -> H,W,1 so the channel check below works
        arr = arr[:, :, np.newaxis]

    if arr.dtype != np.uint8:
        mn, mx = arr.min(), arr.max()
        if mn < -0.5 and mx > 1.5:
            # Clipping normalised pixels to [0, 255] would yield a near-black
            # image; rescale the observed range instead (span > 2, no div-by-0).
            arr = (arr - mn) / float(mx - mn)
        elif mn < -0.5:
            # probably in [-1, 1]
            arr = (arr + 1.0) / 2.0
        if arr.max() <= 1.5:
            arr = np.clip(arr * 255.0, 0, 255).astype(np.uint8)
        else:
            arr = np.clip(arr, 0, 255).astype(np.uint8)

    if arr.shape[2] == 1:
        # uint8 2-D arrays map to mode "L" automatically.
        return Image.fromarray(arr[:, :, 0]).convert('RGB')
    # convert() also normalises 2/4-channel inputs (LA/RGBA) to the promised RGB.
    return Image.fromarray(arr).convert('RGB')


# def draw_heatmap_on_image(attn_2d, recovered_image_pil, colormap='jet', alpha=0.55):
#     """
#     attn_2d: numpy array (H_patch, W_patch) with values >=0 (not necessarily normalized).
#     recovered_image_pil: PIL.Image (RGB)
#     returns: PIL.Image (RGBA) blended overlay
#     """
#     # normalize
#     a = np.array(attn_2d, dtype=np.float32)
#     if np.isnan(a).any():
#         a = np.nan_to_num(a, 0.0)
#     if a.max() > 0:
#         a = a / float(a.max())
#     else:
#         a = a * 0.0
    
#     # Resize attention to image size
#     img_w, img_h = recovered_image_pil.size
#     attn_img = Image.fromarray((a * 255).astype(np.uint8)).resize((img_w, img_h), resample=Image.BILINEAR)
#     attn_arr = np.array(attn_img) / 255.0  # 0..1

#     # apply colormap
#     cmap = plt.get_cmap(colormap)
#     colored = cmap(attn_arr)  # H,W,4 rgba 0..1
#     colored = (colored * 255).astype(np.uint8)
#     overlay = Image.fromarray(colored).convert("RGBA")

#     base = recovered_image_pil.convert("RGBA")
#     # Blend: alpha controls overlay strength
#     blended = Image.blend(base, overlay, alpha=alpha)
#     return blended

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import warnings
import os

def draw_heatmap_on_image(attn_2d, recovered_image_pil, colormap='jet', alpha=0.55, save_path="/datanfs2/medllava/llava/Externalization_llava/922image"):
    """
    Blend a colour-mapped attention map over an image, optionally saving the results.

    Args:
        attn_2d: 2-D NumPy array of attention scores on any scale. Negative
            entries are clipped to 0 with a warning, NaNs become 0, and the
            map is divided by its maximum when that maximum is positive.
        recovered_image_pil: PIL.Image to draw on; its size sets the output size.
        colormap: matplotlib colormap name used to colour the scores.
        alpha: overlay strength for ``Image.blend``, clipped into [0, 1].
        save_path: directory that receives ``original.png`` (the input image)
            and ``heatmap.png`` (the blend). A falsy value skips saving.
            NOTE: the default is a machine-specific absolute path, and the two
            file names are fixed, so successive calls overwrite each other.

    Returns:
        PIL.Image (RGBA): the blended overlay.

    Raises:
        ValueError: wrong input types, a non-2-D map, or an unknown colormap.
    """
    # --- validate inputs -------------------------------------------------
    if not isinstance(recovered_image_pil, Image.Image):
        raise ValueError("recovered_image_pil must be a PIL.Image")
    if not (isinstance(attn_2d, np.ndarray) and attn_2d.ndim == 2):
        raise ValueError("attn_2d must be a 2D NumPy array")
    if (attn_2d < 0).any():
        warnings.warn("Negative values in attn_2d were clipped to 0")
        attn_2d = np.clip(attn_2d, 0, None)
    alpha = np.clip(alpha, 0, 1)

    # --- scale scores into [0, 1] ----------------------------------------
    scores = np.nan_to_num(np.array(attn_2d, dtype=np.float32), nan=0.0)
    peak = scores.max()
    if peak > 0:
        scores = scores / peak
    else:
        warnings.warn("Attention map is all zeros; heatmap will be empty")

    # --- upsample to the image and colour it -----------------------------
    width, height = recovered_image_pil.size
    gray = Image.fromarray((scores * 255).astype(np.uint8))
    upsampled = np.array(gray.resize((width, height), resample=Image.Resampling.BILINEAR)) / 255.0

    try:
        cmap = plt.get_cmap(colormap)
    except ValueError as e:
        raise ValueError(f"Invalid colormap name: {colormap}. Use a valid matplotlib colormap.") from e
    rgba = (cmap(upsampled) * 255).astype(np.uint8)
    overlay = Image.fromarray(rgba).convert("RGBA")

    blended = Image.blend(recovered_image_pil.convert("RGBA"), overlay, alpha=alpha)

    # --- optional side output --------------------------------------------
    if save_path:
        os.makedirs(save_path, exist_ok=True)
        recovered_image_pil.save(os.path.join(save_path, "original.png"), format="PNG")
        blended.save(os.path.join(save_path, "heatmap.png"), format="PNG")

    return blended


def _resolve_token_indices(token_selector, output_tokens):
    """Translate ``token_selector`` into a list of generated-token indices.

    None -> [0]; a list of ints -> used as-is; a string or list of strings ->
    each matched against ``output_tokens`` (exact match first, then the first
    token containing it as a substring). Unmatched strings are logged and skipped.
    """
    if token_selector is None:
        return [0]
    if isinstance(token_selector, str):
        # A bare string would otherwise be iterated character by character.
        token_selector = [token_selector]
    if all(isinstance(x, int) for x in token_selector):
        return [int(x) for x in token_selector]

    indices = []
    for tok in token_selector:
        if tok in output_tokens:
            indices.append(output_tokens.index(tok))
            continue
        # fallback substring search (first match)
        match = next((i for i, outtok in enumerate(output_tokens) if tok in outtok), None)
        if match is None:
            logger.warning(f"token '{tok}' not found in generated tokens; skipping")
        else:
            indices.append(match)
    return indices


def _detect_attention_layout(attentions):
    """Unwrap a loaded attention object and report its layout.

    Returns ``(attentions, per_token_format, num_layers)``. ``per_token_format``
    means entries are indexed ``attentions[token_idx][layer_idx]``; otherwise
    ``attentions[layer_idx]``. A dict is unwrapped via its ``'attentions'`` key.
    """
    if isinstance(attentions, dict):
        if 'attentions' not in attentions:
            raise ValueError("attention dict has no 'attentions' key")
        attentions = attentions['attentions']
    if not isinstance(attentions, (list, tuple)):
        raise ValueError(f"Unsupported attention container: {type(attentions).__name__}")
    if len(attentions) == 0:
        raise ValueError("Attention container is empty.")
    first = attentions[0]
    if torch.is_tensor(first):
        return attentions, False, len(attentions)
    return attentions, True, len(first)


def _select_query_row(mh, head, token_idx, per_token_format):
    """Return a private float32 NumPy copy of one head's attention over the keys.

    ``mh`` is [B, H, Q, K] or [B, H, K]; batch item 0 is used. For 4-D input the
    last query is used in per-token format, otherwise query ``token_idx``
    clamped to ``Q - 1``.
    """
    if mh.dim() == 4:
        q_len = mh.shape[2]
        q_index = -1 if per_token_format else min(token_idx, q_len - 1)
        row = mh[:, head, q_index, :]
    elif mh.dim() == 3:
        row = mh[:, head, :]
    else:
        raise ValueError(f"Unhandled attention tensor dim: {mh.dim()}")
    # .copy(): .numpy() on a CPU tensor shares memory, and callers edit the
    # result in place. .float() also makes bf16 attention convertible.
    return row[0].detach().cpu().float().numpy().copy()


def _slice_image_tokens(row, start, num_image_tokens):
    """Return ``row[start:start + num_image_tokens]``, with fallbacks on overrun.

    If the span runs past the row: a row exactly ``num_image_tokens`` long is
    used whole; otherwise the end is clamped and, if that leaves nothing, the
    last ``num_image_tokens`` keys are taken.
    """
    end = start + num_image_tokens
    if end > row.shape[-1]:
        if row.shape[-1] == num_image_tokens:
            start, end = 0, num_image_tokens
        else:
            end = row.shape[-1]
            if start >= end:
                start = max(0, end - num_image_tokens)
    return row[start:end]


def _to_patch_grid(img_slice, patch_grid):
    """Reshape a flat image-attention slice into a 2-D float32 grid (None if empty).

    Uses ``patch_grid`` when sizes match. Otherwise it infers a square grid if
    the length is a perfect square, else zero-pads or truncates to ``patch_grid``.
    """
    g_h, g_w = patch_grid
    expected = g_h * g_w
    length = img_slice.size
    if length == 0:
        return None
    if length != expected:
        side = int(math.sqrt(length))
        if side * side == length:
            g_h = g_w = side
            logger.info(f"Inferred patch grid {g_h}x{g_w} from attention length {length}")
        elif length < expected:
            img_slice = np.concatenate([img_slice, np.zeros(expected - length, dtype=img_slice.dtype)])
        else:
            img_slice = img_slice[:expected]
    return img_slice.reshape(g_h, g_w).astype(np.float32)


def _save_summary_heatmap(mean_attention_per_head, save_dir, answer_id):
    """Save a layer x head heatmap of mean image attention; return its path (None if no rows)."""
    layers_sorted = sorted(mean_attention_per_head)
    max_heads = max((len(mean_attention_per_head[layer]) for layer in layers_sorted), default=0)
    rows = [
        list(mean_attention_per_head[layer]) + [0.0] * (max_heads - len(mean_attention_per_head[layer]))
        for layer in layers_sorted
    ]
    if not rows:
        return None
    fig, ax = plt.subplots(figsize=(max(6, max_heads * 0.4), max(2, len(rows) * 0.35)))
    try:
        # Columns are positional (H1..Hn); with a head subset or skipped heads
        # they do not necessarily match real head ids.
        sns.heatmap(np.array(rows), ax=ax, cbar=True, xticklabels=[f"H{i+1}" for i in range(max_heads)],
                    yticklabels=[f"L{layer+1}" for layer in layers_sorted], cmap="viridis")
        ax.set_xlabel("Head")
        ax.set_ylabel("Layer")
        ax.set_title("Mean image attention per head (selected tokens)")
        summary_path = os.path.join(save_dir, f"{answer_id}_summary_heads_layers.png")
        fig.tight_layout()
        fig.savefig(summary_path)
    finally:
        # Always release the figure; leaked figures pile up across samples.
        plt.close(fig)
    return summary_path


def handle_attentions_i2t(state,
                          token_selector=None,
                          layer_indices="all",
                          head_indices="all",
                          patch_grid=None,
                          img_token_start=None,
                          num_image_tokens=None,
                          save_dir=None,
                          overlay_alpha=0.55,
                          image_recover=None):
    """
    Save per-layer/per-head image-attention overlays for selected generated tokens.

    state: dict, or an object whose ``__dict__`` is used, providing:
        - attention_key: path prefix; attentions are loaded from
          ``attention_key + '_attn.pt'``
        - img_token_start or image_idx: index of the image placeholder token
          (used unless ``img_token_start`` is passed explicitly)
        - output_ids_decoded: list[str] of decoded generated tokens
        - optional: patch_grid, num_image_tokens, answer_id,
          soft_prompt_length (default 128), recovered_image (PIL.Image),
          image_tensor (torch.Tensor)
    token_selector:
        - None -> [0] (first generated token)
        - list of ints -> indices of generated tokens
        - str or list of str -> matched in output_ids_decoded (exact match
          first, substring fallback); unmatched strings are skipped
    layer_indices / head_indices: "all" or an iterable of 0-based indices
    patch_grid: (H_patches, W_patches); state value or (24, 24) by default
    img_token_start: explicit image start index; takes precedence over state
    num_image_tokens: state value or product(patch_grid) by default
    save_dir: output directory; default ``<dir of attention_key>/attn_vis_prompt``
    overlay_alpha: heatmap blend strength
    image_recover: PIL.Image to draw on. Falls back to state['recovered_image'],
        then to an image rebuilt from state['image_tensor'].

    Image keys are read from ``img_start + soft_prompt_length`` onwards,
    presumably because prompt-tuning virtual tokens are prepended to the
    input (TODO confirm against the model).

    Returns:
        list[str]: paths of all overlays saved, plus the layer x head summary
        heatmap when it could be created.

    Raises:
        ValueError / FileNotFoundError for missing or malformed inputs.
    """
    # normalize state access
    if isinstance(state, dict):
        st = state
    else:
        st = state.__dict__ if hasattr(state, "__dict__") else vars(state)

    if 'attention_key' not in st:
        raise ValueError("state must have 'attention_key' (path prefix for saved attention file)")

    fn_attention = st['attention_key'] + '_attn.pt'
    if not os.path.exists(fn_attention):
        raise FileNotFoundError(f"Attention file not found: {fn_attention}")

    logger.info(f'Loading attentions from {fn_attention}')
    # torch.load unpickles: only point this at files written by this script.
    attentions = torch.load(fn_attention, map_location='cpu')  # keep on cpu

    recovered_image = image_recover
    if recovered_image is None:
        recovered_image = st.get('recovered_image')
    if recovered_image is None and st.get('image_tensor') is not None:
        recovered_image = _recover_image_from_tensor(st['image_tensor'])
    if recovered_image is None:
        raise ValueError("An image is required: pass image_recover or provide "
                         "'recovered_image' / 'image_tensor' in state")

    if patch_grid is None:
        patch_grid = tuple(st.get('patch_grid', (24, 24)))
    if num_image_tokens is None:
        num_image_tokens = st.get('num_image_tokens', patch_grid[0] * patch_grid[1])
    # explicit argument wins, consistent with patch_grid / num_image_tokens
    img_start = img_token_start if img_token_start is not None else st.get('img_token_start', st.get('image_idx'))
    if img_start is None:
        raise ValueError("image token start index must be provided in state as 'img_token_start' or 'image_idx'")
    soft_prompt_len = st.get('soft_prompt_length', 128)

    output_tokens = st.get('output_ids_decoded', None)
    if output_tokens is None:
        raise ValueError("state must contain 'output_ids_decoded' (list of decoded tokens)")

    token_idx_list = _resolve_token_indices(token_selector, output_tokens)
    if not token_idx_list:
        raise ValueError("No token indices selected/found.")

    attentions, per_token_format, num_layers = _detect_attention_layout(attentions)
    if isinstance(layer_indices, str) and layer_indices == "all":
        layer_idx_list = list(range(num_layers))
    else:
        layer_idx_list = list(layer_indices)
    if not layer_idx_list:
        raise ValueError("No layer indices selected.")

    if save_dir is None:
        save_dir = os.path.join(os.path.dirname(st.get('attention_key')), "attn_vis_prompt")
    os.makedirs(save_dir, exist_ok=True)

    # Inspect one tensor to learn the head count ([B,H,Q,K] or [B,H,K]).
    if per_token_format:
        sample_mh = attentions[token_idx_list[0]][layer_idx_list[0]]
    else:
        sample_mh = attentions[layer_idx_list[0]]
    if not isinstance(sample_mh, torch.Tensor):
        raise ValueError("Couldn't determine attention tensor shape.")
    if sample_mh.dim() not in (3, 4):
        raise ValueError(f"Unexpected attention tensor shape: {sample_mh.shape}")
    num_heads = sample_mh.shape[1]
    if isinstance(head_indices, str) and head_indices == "all":
        head_ids = list(range(num_heads))
    else:
        head_ids = list(head_indices)

    ans_id = st.get('answer_id', os.path.basename(st['attention_key']))
    image_key_start = img_start + soft_prompt_len

    saved_paths = []
    mean_attention_per_head = {}  # layer -> array of mean attention per head (for summary)

    for layer_idx in layer_idx_list:
        per_head_means = []
        for real_head in head_ids:
            img_attn_acc = None
            valid_token_count = 0
            for token_idx in token_idx_list:
                try:
                    mh = attentions[token_idx][layer_idx] if per_token_format else attentions[layer_idx]
                except (IndexError, KeyError, TypeError) as e:
                    logger.warning(f"Failed to index attentions with token={token_idx}, layer={layer_idx}: {e}")
                    continue
                if not isinstance(mh, torch.Tensor):
                    logger.warning(f"Attention entry is not a tensor: {type(mh)}; skipping")
                    continue

                row = _select_query_row(mh, real_head, token_idx, per_token_format)
                img_slice = _slice_image_tokens(row, image_key_start, num_image_tokens)
                if layer_idx == 16:
                    # NOTE(review): hard-coded display hack kept from the original.
                    # On layer index 16, every image key except the last 128 is
                    # scaled by 30, which also inflates that layer in the summary.
                    # `row` is a private copy, so the loaded tensor is unaffected.
                    img_slice[:-128] = img_slice[:-128] * 30
                img_attn_token = _to_patch_grid(img_slice, patch_grid)
                if img_attn_token is None:
                    logger.warning(f"Empty image-attention slice for token={token_idx}, layer={layer_idx}; skipping")
                    continue

                img_attn_acc = img_attn_token if img_attn_acc is None else img_attn_acc + img_attn_token
                valid_token_count += 1

            if valid_token_count == 0:
                logger.info(f"No valid tokens found for layer {layer_idx} head {real_head}")
                continue
            img_attn_acc = img_attn_acc / float(valid_token_count)

            # save_path=None: skip draw_heatmap_on_image's default side-save to a
            # fixed absolute path (overwritten on every call); out_path is the output.
            overlay = draw_heatmap_on_image(img_attn_acc, recovered_image, alpha=overlay_alpha, save_path=None)
            fname = f"{ans_id}_layer{layer_idx+1:02d}_head{real_head+1:02d}.png"
            out_path = os.path.join(save_dir, fname)
            overlay.save(out_path)
            saved_paths.append(out_path)

            per_head_means.append(np.mean(img_attn_acc))

        mean_attention_per_head[layer_idx] = np.array(per_head_means)

    # The summary is best-effort: failures are logged, not raised.
    try:
        summary_path = _save_summary_heatmap(mean_attention_per_head, save_dir, st.get('answer_id', 'ans'))
        if summary_path is not None:
            saved_paths.append(summary_path)
    except Exception as e:
        logger.warning(f"Failed to create summary heatmap: {e}")

    logger.info(f"Saved {len(saved_paths)} attention visualization files to {save_dir}")
    return saved_paths


# ----------------- Modified eval_model (replace or merge as needed) -----------------
def eval_model(args):
    """
    Generate answers for every question and save attention visualisations.

    For each question in ``args.question_file``, this function:
      * generates an answer;
      * appends a JSON line to ``args.answers_file``;
      * saves the attention weights to ``<answers dir>/attentions/<answer_id>_attn.pt``;
      * renders overlays via ``handle_attentions_i2t`` into
        ``<answers dir>/attn_visuals_withMem-test-prompt926``.

    Visualisation paths are appended to ``<answers dir>/attn_visualizations.jsonl``.
    A visualisation failure for one sample is logged and does not stop the run.
    """
    # Model
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model_both(model_path, args.model_base, model_name, args.use_prompt_tuning)
    with open(os.path.expanduser(args.question_file), "r") as question_fp:
        questions = [json.loads(q) for q in question_fp]
    answers_file = os.path.expanduser(args.answers_file)
    base_out_dir = os.path.dirname(answers_file)
    # dirname() is "" for a bare file name (e.g. "answer.jsonl"), and
    # os.makedirs("") raises FileNotFoundError.
    if base_out_dir:
        os.makedirs(base_out_dir, exist_ok=True)

    if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
        args.conv_mode = args.conv_mode + '_mmtag'
        print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.')

    data_loader = create_data_loader(questions, args.image_folder, tokenizer, image_processor, model.config)

    # prepare folders to save attentions/visualizations next to the answers file
    attn_save_dir = os.path.join(base_out_dir, "attentions")
    os.makedirs(attn_save_dir, exist_ok=True)
    vis_save_dir = os.path.join(base_out_dir, "attn_visuals_withMem-test-prompt926")
    os.makedirs(vis_save_dir, exist_ok=True)
    viz_map_file = os.path.join(base_out_dir, "attn_visualizations.jsonl")
    eos_token_id = tokenizer.eos_token_id

    with open(answers_file, "w") as ans_file:
        for (input_ids, image_tensor, image_sizes, image_path), line in tqdm(zip(data_loader, questions), total=len(questions)):
            idx = line["question_id"]
            cur_prompt = line["text"]
            input_ids = input_ids.to(device='cuda', non_blocking=True)
            # position of the image placeholder token in the prompt
            img_idx = torch.where(input_ids == IMAGE_TOKEN_INDEX)[1][0].item()
            # Reset per-sample attention buffers (presumably appended to by the
            # model during generate(); read back below).
            model.enc_attn_weights = []
            model.enc_attn_weights_vit = []
            with torch.inference_mode():
                outputs = model.generate(
                    inputs=input_ids,
                    images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True),
                    image_sizes=image_sizes,
                    do_sample=True if args.temperature > 0 else False,
                    temperature=args.temperature,
                    top_p=args.top_p,
                    num_beams=args.num_beams,
                    max_new_tokens=args.max_new_tokens,
                    use_cache=True,
                    output_attentions=True,
                    return_dict_in_generate=True,
                    output_scores=True,
                    eos_token_id=eos_token_id)

            with Image.open(image_path[0]) as raw_image:
                image_recover = raw_image.convert('RGB')

            # Decode outputs.sequences (one row, batch size is 1) as a batch of
            # sequences. batch_decode on the flattened id list would decode
            # each id on its own, so [0] would be only the first token.
            generated_text = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)[0].strip()
            output_ids = outputs.sequences.reshape(-1).tolist()
            output_ids_decoded = [tokenizer.decode(oid).strip() for oid in output_ids]

            # write answer entry
            ans_id = shortuuid.uuid()
            ans_file.write(json.dumps({"question_id": idx,
                                       "prompt": cur_prompt,
                                       "text": generated_text,
                                       "answer_id": ans_id,
                                       "model_id": model_name,
                                       }) + "\n")
            ans_file.flush()

            # ---------------- save attentions to disk in expected format ----------------
            # Prefer model.enc_attn_weights if populated, else fall back to outputs.attentions.
            attn_to_save = None
            if getattr(model, 'enc_attn_weights', None):
                attn_to_save = model.enc_attn_weights
                logger.info("Using model.enc_attn_weights for visualization.")
            elif getattr(outputs, 'attentions', None) is not None:
                attn_to_save = outputs.attentions
                logger.info("Using outputs.attentions for visualization.")
            else:
                logger.warning("No attentions found on model or outputs; skipping attention visualization for this sample.")

            if attn_to_save is None:
                continue

            attn_key_prefix = os.path.join(attn_save_dir, f"{ans_id}")
            torch.save(attn_to_save, attn_key_prefix + "_attn.pt")

            state = {
                'attention_key': attn_key_prefix,
                # the real image; rebuilding it from the normalised tensor is lossy
                'recovered_image': image_recover,  # PIL
                'image_idx': img_idx,
                'img_token_start': img_idx,
                'num_image_tokens': 24 * 24,   # default; change if your model uses different patch count
                'patch_grid': (24, 24),
                'output_ids_decoded': output_ids_decoded,
                'answer_id': ans_id,
                # optionally keep raw tensor if needed
                'image_tensor': image_tensor,
                'soft_prompt_length': 128
            }

            # Which tokens/layers/heads to visualize:
            #  - token_selector = None -> token 0 (default)
            #  - token_selector = [0, 1, 2] -> first 3 output tokens
            #  - token_selector = ["dog", "cat"] -> match tokens by text
            token_selector = None
            layer_indices = "all"   # or e.g. [0, 1, 2]
            head_indices = "all"    # or e.g. [0, 1, 2]

            try:
                saved_imgs = handle_attentions_i2t(SimpleNamespace(**state),
                                                   token_selector=token_selector,
                                                   layer_indices=layer_indices,
                                                   head_indices=head_indices,
                                                   save_dir=vis_save_dir,
                                                   overlay_alpha=0.55,
                                                   image_recover=image_recover)
                # record answer_id -> visualization files in a side file
                with open(viz_map_file, "a") as vmf:
                    vmf.write(json.dumps({"answer_id": ans_id, "visualizations": saved_imgs}) + "\n")
            except Exception as e:
                logger.exception(f"Failed to generate attention visualizations for {ans_id}: {e}")


if __name__ == "__main__":
    # Command-line entry point.
    # NOTE: `args` is created at module scope on purpose: CustomDataset.__getitem__
    # reads the global `args.conv_mode`, so it must exist before the DataLoader
    # fetches items. Keep this assignment global.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--image-folder", type=str, default="")
    parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
    parser.add_argument("--answers-file", type=str, default="answer.jsonl")
    parser.add_argument("--conv-mode", type=str, default="llava_v1")
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--top_p", type=float, default=None)
    parser.add_argument("--num_beams", type=int, default=1)
    parser.add_argument("--max_new_tokens", type=int, default=128)
    # --use_lora / --no_use_lora share the `use_lora` destination (default False).
    # NOTE(review): `use_lora` is not read anywhere visible in this file.
    parser.add_argument("--use_lora", action='store_true', help="Enable LoRA weights loading")
    parser.add_argument("--no_use_lora", action='store_false', dest='use_lora', help="Disable LoRA weights loading")
    # Prompt tuning is ON by default (store_true with default=True, so the
    # positive flag is redundant); only --no_use_prompt_tuning turns it off.
    parser.add_argument("--use_prompt_tuning", action='store_true', help="Enable Prompt Tuning weights loading",default=True)
    parser.add_argument("--no_use_prompt_tuning", action='store_false', dest='use_prompt_tuning', help="Disable Prompt Tuning weights loading")
    # NOTE(review): eval_model does not read these two; it hard-codes
    # 'soft_prompt_length': 128 instead of using --num_virtual_tokens.
    parser.add_argument("--num_virtual_tokens", type=int, default=128)
    parser.add_argument("--prompt_tuning_init_text", type=str, default="init prompt text")
    
    args = parser.parse_args()

    eval_model(args)
