# save as llava_causal_tracing.py (replace your original file content)
import argparse
import torch
import os
import json
from tqdm import tqdm
import shortuuid
import math
from PIL import Image
from torch.utils.data import Dataset, DataLoader

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model, load_pretrained_model_both
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path

# NOTE: you already had these imports in your original file; kept minimal here

def split_list(lst, n):
    """Split ``lst`` into contiguous chunks of ``ceil(len(lst) / n)`` items.

    The last chunk holds whatever remains, so it may be shorter, and fewer
    than ``n`` chunks are returned when the list is too short to fill them.
    An empty ``lst`` yields an empty list of chunks.

    Raises:
        ZeroDivisionError: if ``n`` is 0 and ``lst`` is not empty.
    """
    if len(lst) == 0:
        # ceil(0 / n) is 0, and a zero step makes range() raise ValueError.
        return []
    chunk_size = math.ceil(len(lst) / n)
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]

def get_chunk(lst, n, k):
    """Return chunk number ``k`` of the chunks ``split_list(lst, n)`` produces.

    ``k`` is used as a plain list index into that result.
    """
    return split_list(lst, n)[k]

class CustomDataset(Dataset):
    """Dataset yielding tokenized prompts and processed images for question records.

    Each record in ``questions`` is read through its ``"image"`` key (a file
    name joined onto ``image_folder``) and its ``"text"`` key (the question).
    """

    def __init__(self, questions, image_folder, tokenizer, image_processor, model_config, args):
        self.questions = questions
        self.image_folder = image_folder
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.model_config = model_config
        # args.conv_mode selects the conversation template in __getitem__
        self.args = args

    def __getitem__(self, index):
        """Build one sample.

        Returns:
            (input_ids, image_tensor, image_size, line), where ``image_size``
            is the PIL ``size`` of the RGB image and ``line`` is the
            untouched question record.
        """
        line = self.questions[index]
        image_file = line["image"]
        qs = line["text"]
        # Prefix the question with the image placeholder, optionally wrapped
        # in the start/end markers when the model config asks for them.
        if self.model_config.mm_use_im_start_end:
            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
        else:
            qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

        conv = conv_templates[self.args.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        # The assistant turn is appended with no content (None).
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        # The context manager closes the underlying file as soon as the RGB
        # copy exists, instead of leaving the handle open until garbage
        # collection.
        with Image.open(os.path.join(self.image_folder, image_file)) as raw_image:
            image = raw_image.convert('RGB')
        image_tensor = process_images([image], self.image_processor, self.model_config)

        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
        return input_ids, image_tensor, image.size, line

    def __len__(self):
        """Return the number of question records."""
        return len(self.questions)

def collate_fn(batch):
    """Collate dataset samples into batched tensors.

    ``batch`` is a sequence of ``(input_ids, image_tensor, image_size, line)``
    tuples. The ids and image tensors are each stacked along a new leading
    dimension; the sizes and records are returned as tuples.
    """
    ids_column, image_column, sizes, records = zip(*batch)
    stacked_ids = torch.stack(ids_column, dim=0)
    stacked_images = torch.stack(image_column, dim=0)
    return stacked_ids, stacked_images, sizes, records

def create_data_loader(questions, image_folder, tokenizer, image_processor, model_config, args, batch_size=1, num_workers=1):
    """Wrap the questions in a CustomDataset and return an unshuffled DataLoader.

    Only ``batch_size == 1`` is accepted; any other value fails the assertion
    before the dataset is built.
    """
    assert batch_size == 1, "batch_size must be 1"
    return DataLoader(
        CustomDataset(questions, image_folder, tokenizer, image_processor, model_config, args),
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
        collate_fn=collate_fn,
    )

################################################################################
# Hook utilities to capture module inputs/outputs and do restoration
################################################################################
def find_target_modules(model):
    """
    Bucket the entries of model.named_modules() by substring heuristics on their names.

    Buckets (each a list of (name, module) pairs, in named_modules() order):
      - 've':      lower-cased name mentions a vision keyword AND the module has children
      - 'mi':      lower-cased name mentions a multimodal-interface keyword
      - 'llm_mlp': lower-cased name mentions an MLP/FFN keyword; if at least one of those
                   names also mentions a block keyword (matched against the name as-is,
                   not lower-cased), only those are kept, otherwise the list is unfiltered

    A summary line is printed, plus the first six llm_mlp candidates when any exist.
    """
    vision_keys = ['vision', 'visual', 'vit', 'backbone']
    interface_keys = ['multimodal', 'mm_adapter', 'multimodal_adapter', 'mm_adapter_layer', 'mm_linear','mm_projector']
    mlp_keys = ['mlp', 'ffn', 'feed_forward', 'mlp_out', 'mlp_fc', 'mlp_block','fn']
    block_keys = ['transformer', 'blocks', 'h.', 'encoder', 'decoder', 'model.layers', 'layers', 'blocks']

    def mentions(text, keys):
        # True when any keyword occurs as a substring of text
        return any(key in text for key in keys)

    ve_hits, mi_hits, mlp_hits = [], [], []
    for name, module in model.named_modules():
        lowered = name.lower()
        if mentions(lowered, vision_keys) and len(list(module.children())) > 0:
            ve_hits.append((name, module))
        if mentions(lowered, interface_keys):
            mi_hits.append((name, module))
        if mentions(lowered, mlp_keys):
            mlp_hits.append((name, module))

    # Prefer MLP candidates that live under something that looks like a block stack.
    in_blocks = [(name, module) for name, module in mlp_hits if mentions(name, block_keys)]
    matches = {
        've': ve_hits,
        'mi': mi_hits,
        'llm_mlp': in_blocks if len(in_blocks) >= 1 else mlp_hits,
    }

    n_ve, n_mi, n_mlp = len(matches['ve']), len(matches['mi']), len(matches['llm_mlp'])
    print(f"Found {n_ve} visual encoder candidate modules, {n_mi} multimodal interface candidates, {n_mlp} llm-mlp candidates.")
    if n_mlp > 0:
        print("Example llm_mlp candidates:", matches['llm_mlp'][:6])
    return matches

class ActivationStore:
    """Per-module storage for CPU copies of captured activations."""

    def __init__(self):
        # saved[module_name] = {'input': tensor, 'output': tensor}; either key may be absent
        self.saved = {}

    def save_input(self, module_name, tensor):
        """Store a detached CPU clone of ``tensor`` as the module's input."""
        snapshot = tensor.detach().cpu().clone()
        self.saved.setdefault(module_name, {})['input'] = snapshot

    def save_output(self, module_name, tensor):
        """Store a detached CPU clone of ``tensor`` as the module's output."""
        snapshot = tensor.detach().cpu().clone()
        self.saved.setdefault(module_name, {})['output'] = snapshot

    def get_input(self, module_name):
        """Return the stored input for the module, or None if none was saved."""
        entry = self.saved.get(module_name)
        return None if entry is None else entry.get('input')

    def get_output(self, module_name):
        """Return the stored output for the module, or None if none was saved."""
        entry = self.saved.get(module_name)
        return None if entry is None else entry.get('output')

def register_capture_hooks(model, target_module_names):
    """
    Attach capture hooks to the named modules of ``model``.

    For every name found in model.named_modules(), a forward pre-hook records the first
    positional input and a forward hook records the output (its first element when the
    output is a tuple) into a fresh ActivationStore. Names that are not found are
    reported with a warning and skipped.

    Returns (store, handles): the ActivationStore and the list of hook handles
    (two per hooked module).
    """
    store = ActivationStore()

    def input_recorder(name):
        def hook(mod, input):
            # Nothing to record when the module received no positional arguments.
            if len(input) == 0:
                return
            store.save_input(name, input[0])
        return hook

    def output_recorder(name):
        def hook(mod, input, output):
            captured = output[0] if isinstance(output, tuple) else output
            store.save_output(name, captured)
        return hook

    modules_by_name = dict(model.named_modules())
    handles = []
    for mname in target_module_names:
        if mname not in modules_by_name:
            print(f"WARNING: module {mname} not found in model.named_modules(); skipping.")
            continue
        target = modules_by_name[mname]
        handles.append(target.register_forward_pre_hook(input_recorder(mname)))
        handles.append(target.register_forward_hook(output_recorder(mname)))
    return store, handles

def remove_handles(handles):
    """Call ``remove()`` on every handle, continuing past individual failures.

    Removal stays best-effort so one bad handle cannot prevent the others
    from being detached, but a failure is reported rather than silently
    discarded: a hook that was not removed stays registered on its module.
    """
    for handle in handles:
        try:
            handle.remove()
        except Exception as exc:
            print(f"[WARN] failed to remove hook handle {handle!r}: {exc}")

################################################################################
# Forward runner helpers (attempt multiple forward call styles for robustness)
################################################################################
def model_forward(model, input_ids, images, image_sizes, return_dict=True):
    """
    Call ``model`` with the token ids plus image arguments and return whatever it returns.

    Two keyword spellings are tried for the token ids, in order:
      1. model(inputs=input_ids, images=..., image_sizes=..., return_dict=...)
      2. model(input_ids=input_ids, images=..., image_sizes=..., return_dict=...)

    The call always goes through ``model`` itself. An earlier version first tried
    ``model.model.forward`` (the inner sub-module), which, had it succeeded, would have
    returned a result from a different module than the one whose output callers read
    ``.logits`` from.

    Any exception from the first spelling triggers the second; an exception from the
    second propagates to the caller with the first attached as its context.
    """
    shared = dict(images=images, image_sizes=image_sizes, return_dict=return_dict)
    try:
        return model(inputs=input_ids, **shared)
    except Exception:
        # Kept broad on purpose: the failure mode of an unsupported keyword depends on
        # the model's signature. The retry runs inside the handler so both errors stay
        # visible in the traceback if it fails too.
        return model(input_ids=input_ids, **shared)

def tokens_to_first_token_id(tokenizer, text):
    """Return the id of the first token of ``text``, or None if there is none.

    ``tokenizer.encode(text, add_special_tokens=False)`` is tried first. When
    it does not give a non-empty list/tuple, the tokenizer is called directly
    and the ``'input_ids'`` entry of its result is used; a result without that
    entry, or with an empty one, yields None.
    """
    toks = tokenizer.encode(text, add_special_tokens=False)
    if isinstance(toks, (list, tuple)) and len(toks) > 0:
        return toks[0]
    # fallback: call the tokenizer itself and read 'input_ids' from the result
    enc = tokenizer(text, add_special_tokens=False)
    if 'input_ids' not in enc:
        # Previously a KeyError; None matches the "no token" contract above.
        return None
    ids = enc['input_ids']
    return ids[0] if len(ids) > 0 else None

def probs_from_logits(logits, target_token_id, tokenizer=None):
    """
    Turn logits into a probability at the final position.

    logits: 3-D ([batch, seq_len, vocab]; the last position along dim 1 is used) or
            2-D ([batch, vocab]); None is passed through as None.
    target_token_id: vocabulary index to read, or None.
    tokenizer: unused; kept so existing call sites keep working.

    Returns:
      - None when logits is None
      - a float, the probability of target_token_id for batch row 0, when an id is given
      - a (max_prob, argmax_index) tuple of Python scalars when target_token_id is None
        (this relies on .item(), so it needs a single batch row)

    Raises RuntimeError when logits is neither 2-D nor 3-D.
    """
    if logits is None:
        return None
    if logits.dim() == 3:
        # pick last position
        last = logits[:, -1, :]
    elif logits.dim() == 2:
        last = logits
    else:
        raise RuntimeError("Unexpected logits dim: %s" % (logits.dim(),))
    # Softmax in float32: with half-precision logits the probabilities would be rounded
    # to half precision, which can swallow the small differences callers subtract.
    probs = torch.softmax(last.float(), dim=-1)
    if target_token_id is None:
        # return max prob and argmax
        val, idx = torch.max(probs, dim=-1)
        return val.item(), idx.item()
    return probs[0, target_token_id].item()

################################################################################
# Main eval + causal tracing logic
################################################################################
def do_causal_tracing_for_sample(model, tokenizer, input_ids, image_tensor, image_size, args, save_folder):
    """
    Run causal tracing for one sample and save the results.

    Steps:
      - clean forward pass, capturing inputs/outputs of every llm_mlp candidate module
      - greedy generation (max 8 new tokens); the first token id of the decoded, stripped
        text is used as the target token
      - corrupted forward pass (gaussian noise scaled by args.noise_sigma added to the image)
      - for each candidate module: rerun on the corrupted image with that module's first
        input replaced by its clean input, and record IE = p_restored - p_corrupted
      - build the patch x layer IE matrix with compute_patch_importance_matrix and hand it
        to plot_ie_heatmap_new together with a patch_importance_<uid>.png path

    Reads args.max_layers (optional cap on the number of candidate modules) and
    args.noise_sigma. Writes causal_tracing_<uid>.json into save_folder.

    Returns the dict that was written to the JSON file.
    Raises RuntimeError when no target token id can be derived from the generated text.
    """
    model = model.eval().to('cuda')
    device = 'cuda'
    # Both tensors get one extra leading dimension here.
    input_ids = input_ids.unsqueeze(0).to(device)
    image_tensor = image_tensor.unsqueeze(0).to(device).to(dtype=torch.float16)

    # find candidate modules (do only once per model ideally; but OK per sample)
    matches = find_target_modules(model)
    # We'll use names of llm_mlp candidates for per-layer restoration
    llm_mlp_names = [n for n, m in matches['llm_mlp']]
    if len(llm_mlp_names) == 0:
        print("No llm_mlp candidate modules found automatically: please inspect model.named_modules() and pass module names manually.")
    # cap to args.max_layers if provided
    if args.max_layers and args.max_layers > 0:
        llm_mlp_names = llm_mlp_names[:args.max_layers]

    # --- CLEAN RUN ---
    # register capture hooks for these modules to save inputs/outputs
    store_clean, handles_clean = register_capture_hooks(model, llm_mlp_names)
    try:
        with torch.inference_mode():
            out_clean = model_forward(model, input_ids, image_tensor, [image_size])
    finally:
        # Always detach the capture hooks, even when the forward pass raises;
        # otherwise they would stay registered on the model.
        remove_handles(handles_clean)
    # try to obtain logits
    logits_clean = getattr(out_clean, 'logits', None)

    # determine target token id:
    # use the model's own clean generation first token as "correct"
    with torch.inference_mode():
        # small generate - produce short answer to get first token
        gen_ids = model.generate(inputs=input_ids, images=image_tensor, image_sizes=[image_size],
                                 do_sample=False, temperature=0.0, top_p=1.0, num_beams=1, max_new_tokens=8, use_cache=True)
    gen_text = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)[0].strip()
    # get first token id of generated text (serves as 'correct' token)
    target_token_id = tokens_to_first_token_id(tokenizer, gen_text)
    if target_token_id is None:
        raise RuntimeError("Cannot get target token id from generated text: '%s'" % gen_text)
    # compute p_clean (prob model assigns to target token when run clean)
    p_clean = probs_from_logits(logits_clean, target_token_id)

    # --- CORRUPTED RUN ---
    # make corrupted image by adding gaussian noise
    noise = torch.randn_like(image_tensor) * args.noise_sigma
    # NOTE(review): clamp(0, 1) presumes image_tensor values lie in [0, 1]. If
    # process_images applies mean/std normalisation, this also clips the clean
    # signal rather than only bounding the noise -- confirm against the processor.
    corrupted_image = (image_tensor + noise).clamp(0, 1)

    # capture corrupted activations (p_corrupted is the baseline for every IE below)
    store_cor, handles_cor = register_capture_hooks(model, llm_mlp_names)
    try:
        with torch.inference_mode():
            out_cor = model_forward(model, input_ids, corrupted_image, [image_size])
    finally:
        remove_handles(handles_cor)
    logits_cor = getattr(out_cor, 'logits', None)
    p_cor = probs_from_logits(logits_cor, target_token_id)

    # --- RESTORATION PER MODULE (LLM MLP candidates) ---
    def make_pre_hook(injection_tensor):
        """Build a forward pre-hook that swaps the module's first input for injection_tensor."""
        def pre_hook(module, inputs):
            # inputs is a tuple, we replace the first tensor with injection tensor moved to device and dtype
            inj = injection_tensor.to(device).to(inputs[0].dtype)
            # If shapes mismatch, try to broadcast, then fall back to reshape
            if inj.shape != inputs[0].shape:
                try:
                    inj = inj.expand(inputs[0].shape)
                except Exception:
                    inj = inj.reshape(inputs[0].shape)
            return (inj,) + tuple(inputs[1:])
        return pre_hook

    per_module_results = []
    name_to_module = dict(model.named_modules())
    for mname in llm_mlp_names:
        clean_input = store_clean.get_input(mname)
        if clean_input is None:
            # if not captured, skip
            print(f"[WARN] No clean input captured for module {mname}, skipping restoration.")
            per_module_results.append({'module': mname, 'p_restored': None, 'IE': None})
            continue
        mod = name_to_module.get(mname, None)
        if mod is None:
            per_module_results.append({'module': mname, 'p_restored': None, 'IE': None})
            continue

        handle = mod.register_forward_pre_hook(make_pre_hook(clean_input))
        try:
            with torch.inference_mode():
                out_rest = model_forward(model, input_ids, corrupted_image, [image_size])
            logits_rest = getattr(out_rest, 'logits', None)
            p_rest = probs_from_logits(logits_rest, target_token_id)
            IE = None
            if p_rest is not None and p_cor is not None:
                IE = p_rest - p_cor
        except Exception as e:
            print(f"[ERROR] restoration forward failed for module {mname}: {e}")
            p_rest = None
            IE = None
        finally:
            # The injection hook must not survive into the next module's run.
            remove_handles([handle])
        per_module_results.append({'module': mname, 'p_restored': p_rest, 'IE': IE})

    # Save JSON summarizing
    result = {
        'gen_text': gen_text,
        'target_token_id': int(target_token_id),
        'p_clean': float(p_clean) if p_clean is not None else None,
        'p_corrupted': float(p_cor) if p_cor is not None else None,
        'per_module': per_module_results
    }
    # write file
    uid = shortuuid.uuid()[:8]
    out_file = os.path.join(save_folder, f"causal_tracing_{uid}.json")
    with open(out_file, 'w') as f:
        json.dump(result, f, indent=2)
    print(f"[SAVE] saved tracing result to {out_file}")

    # Pass the NAME (index 0 of each (name, module) pair) of the first mi / ve candidate,
    # as the *_module_name parameters expect; None when no candidate was found instead
    # of failing with IndexError.
    mi_name = matches['mi'][0][0] if matches['mi'] else None
    ve_name = matches['ve'][0][0] if matches['ve'] else None
    ie_matrix, p_restored_matrix, image_positions = compute_patch_importance_matrix(
            model=model,
            tokenizer=tokenizer,
            input_ids=input_ids,   # already carries the leading dimension added above
            corrupted_image=corrupted_image.squeeze(0).cpu(),  # keep as tensor
            image_size=image_size,
            store_clean=store_clean,
            store_cor=store_cor,
            llm_mlp_names=llm_mlp_names,
            mi_module_name=mi_name,
            ve_module_name=ve_name,
            target_token_id=target_token_id,
            p_corrupted=p_cor,
            args=args,
            image_tensor=image_tensor,
            device='cuda'
        )
    save_path = os.path.join(save_folder, f"patch_importance_{uid}.png")
    print(f"[SAVE] saving patch importance heatmap to {save_path}")
    plot_ie_heatmap_new(ie_matrix, p_restored_matrix=p_restored_matrix, title='Patch x Layer IE heatmap', \
                        figsize=(10,6), cmap='Greens', save_path=save_path, vmin=None, vmax=None, \
                        highlight_threshold=0.4,image_tensor=image_tensor, image_size=image_size, patch_size=model.model.vision_tower.config.patch_size)
    return result
# import torch
# import numpy as np
# import matplotlib.pyplot as plt

# def compute_patch_importance_matrix(
#     model, input_ids, image_tensor, image_sizes,
#     image_token_id, num_layers=None, last_k=4, target_pos="last"
# ):
#     """
#     返回: numpy.ndarray, 形状 [num_image_tokens, num_layers]
#           每列对应一个层的 patch importance 分布
#     """
#     all_importances = []  # 存储每层的结果

#     # Forward hooks 容器
#     handles = []
#     layer_idx = 0

#     def hook_fn(module, input, output):
#         nonlocal layer_idx
#         if (num_layers is not None) and (layer_idx >= num_layers):
#             return
#         import ipdb; ipdb.set_trace()
#         attn = output[1]  # 假设返回 (hidden_states, attn_weights)
#         # attn: [batch, num_heads, seq_len, seq_len]

#         # 取 target_pos 的 token 作为 query
#         if target_pos == "last":
#             q_idx = input_ids.shape[1] - 1
#         elif isinstance(target_pos, int):
#             q_idx = target_pos
#         else:
#             raise ValueError("target_pos must be 'last' or int")

#         # 平均多头
#         attn_mean = attn[0].mean(0)  # [seq_len, seq_len]

#         # 提取图像 token 的权重
#         img_mask = (input_ids[0] == image_token_id)
#         img_indices = torch.nonzero(img_mask, as_tuple=True)[0]

#         patch_attn = attn_mean[q_idx, img_indices]  # [num_img_tokens]

#         # 归一化
#         patch_importance = patch_attn / patch_attn.sum()

#         all_importances.append(patch_importance.detach().cpu().numpy())

#         layer_idx += 1

#     # 注册所有注意力层的 hook
#     for name, module in model.named_modules():
#         if "attn" in name.lower() and hasattr(module, "forward"):
#             h = module.register_forward_hook(hook_fn)
#             handles.append(h)
#     # import ipdb; ipdb.set_trace()
#     # 前向传播
#     with torch.no_grad():
#         _ = model(input_ids=input_ids, images=image_tensor, image_sizes=image_sizes, output_attentions = True,return_dict=True)

#     # 移除 hook
#     for h in handles:
#         h.remove()

#     # 转成二维矩阵 [num_img_tokens, num_layers]
#     importance_matrix = np.stack(all_importances, axis=1)
#     importance_matrix /= importance_matrix.sum(axis=0, keepdims=True)  # 每列归一化

#     return importance_matrix
import numpy as np
import torch
import matplotlib.pyplot as plt
from matplotlib import patches

def detect_image_positions_from_mi(store_clean, store_cor, mi_module_name, ve_module_name):
    """
    Locate the run of sequence positions where the mi module's clean and corrupted
    outputs differ the most.

    The window length K is taken from dim 1 of the ve module's clean output. The
    per-position L2 norm of (mi_clean - mi_cor) is computed, and the K-long window with
    the largest summed norm is selected.

    Returns: (image_positions, K) where image_positions is a list of consecutive indices.
             When the sequence is shorter than K, every position of the sequence is returned.
    Raises:  RuntimeError if the ve output, or either mi output, is missing from the stores.
    """
    ve_out = store_clean.get_output(ve_module_name)
    if ve_out is None:
        raise RuntimeError(f"No VE output found in store for {ve_module_name}. Cannot infer num_patches.")
    num_patches = ve_out.shape[1]  # assumes ve output is [1, num_patches, dim] -- TODO confirm

    mi_clean = store_clean.get_output(mi_module_name)
    mi_cor = store_cor.get_output(mi_module_name)
    if mi_clean is None or mi_cor is None:
        raise RuntimeError(f"No MI outputs found for {mi_module_name} in clean/cor stores.")

    # Work in float32: stored activations keep their original dtype, and running the
    # norm and cumulative sum in a lower-precision dtype would lose accuracy (and mix
    # dtypes with the float32 zero prefix below).
    diffs = torch.norm(mi_clean.float() - mi_cor.float(), dim=-1).squeeze(0)  # [seq_len]
    seq_len = diffs.shape[0]
    K = num_patches
    if seq_len < K:
        # No K-long window fits, so fall back to the whole sequence.
        return list(range(seq_len)), K

    # sliding-window sum via cumsum: window_sum[i] = sum(diffs[i:i+K])
    cumsum = torch.cat([diffs.new_zeros(1), torch.cumsum(diffs, dim=0)])
    window_sum = cumsum[K:] - cumsum[:-K]
    start = int(torch.argmax(window_sum).item())
    image_positions = list(range(start, start + K))
    return image_positions, K

def compute_patch_importance_matrix(
    model,
    tokenizer,
    input_ids,                 # 1D or 2D tensor; a 1D tensor gets a leading dimension added
    corrupted_image,           # image tensor, moved to `device` before use
    image_size,
    store_clean,               # ActivationStore from clean run (hooks captured)
    store_cor,                 # ActivationStore from corrupted run (not read here)
    llm_mlp_names,             # list of module names (strings) to inject into, one column each
    mi_module_name,            # not read here
    ve_module_name,            # not read here
    target_token_id,           # int: token id we measure probability for
    p_corrupted,               # float baseline p_corrupted for that sample
    args,
    image_tensor,              # only its last two dims (h, w) are read, to count patches
    device='cuda'
):
    """
    Measure, per (patch, module) pair, how much restoring one clean patch vector helps.

    For every module in llm_mlp_names and every patch position, a forward pre-hook
    overwrites that single position of the module's first input with the vector captured
    in the clean run, the corrupted image is run through the model, and the probability
    of target_token_id is recorded.

    The number of patches is (h // patch_size) * (w // patch_size), with patch_size read
    from model.model.vision_tower.config and (h, w) from image_tensor. Patch positions are
    taken to be sequence indices 1..num_patches.
    NOTE(review): that indexing presumes index 0 is a CLS-like token in every hooked
    module's input -- verify for the modules actually passed in.

    Returns:
      ie_matrix: numpy array [num_patches, num_layers] of p_restored - p_corrupted (NaN where skipped/failed)
      p_restored_matrix: numpy array of the same shape holding p_restored
      image_positions: list of the sequence indices used for the patches
    """
    model = model.eval().to(device)

    patch_size = model.model.vision_tower.config.patch_size
    h, w = image_tensor.shape[-2:]
    num_patches = (h // patch_size) * (w // patch_size)
    image_positions = list(range(1, num_patches + 1))  # position 0 is usually the CLS token

    num_layers = len(llm_mlp_names)
    ie_matrix = np.full((num_patches, num_layers), np.nan, dtype=float)
    p_restored_matrix = np.full((num_patches, num_layers), np.nan, dtype=float)

    modules_by_name = dict(model.named_modules())

    # Make sure the ids carry a leading dimension and live on the target device.
    input_ids_batch = (input_ids.unsqueeze(0) if input_ids.dim() == 1 else input_ids).to(device)
    corrupted_image = corrupted_image.to(device)

    def single_patch_injector(patch_pos, clean_patch_vec):
        """Pre-hook factory: overwrite one sequence position of the first input."""
        def pre_hook(module, inputs):
            original = inputs[0]
            patched = original.clone()
            patched[:, patch_pos, :] = clean_patch_vec.to(original.device).to(original.dtype)
            return (patched,) + tuple(inputs[1:])
        return pre_hook

    print(f"[INFO] computing patch importance matrix for {num_patches} patches across {num_layers} layers...")
    for layer_idx, mname in enumerate(llm_mlp_names):
        mod = modules_by_name.get(mname, None)
        if mod is None:
            print(f"[WARN] module {mname} not found in model.named_modules(); skipping layer {layer_idx}")
            continue

        # Prefer the captured clean input; fall back to the captured output.
        clean_in = store_clean.get_input(mname)
        if clean_in is None:
            clean_in = store_clean.get_output(mname)
        if clean_in is None:
            print(f"[WARN] No clean activation (in/out) captured for module {mname}; skipping.")
            continue

        print(f"[INFO] processing layer {layer_idx+1}/{num_layers}: {mname} with clean_in shape {clean_in.shape}")
        for p, position in enumerate(image_positions):
            try:
                clean_patch_vec = clean_in[:, position, :].detach().cpu()
            except Exception as e:
                print(f"[WARN] cannot extract clean patch vector for {mname} patch {p}: {e}")
                continue

            hook_handle = mod.register_forward_pre_hook(single_patch_injector(position, clean_patch_vec))

            # Forward pass on the corrupted image with only this patch restored.
            try:
                with torch.inference_mode():
                    out_rest = model_forward(model, input_ids_batch, corrupted_image, [image_size])
                logits_rest = getattr(out_rest, 'logits', None)
                p_rest = probs_from_logits(logits_rest, target_token_id)
                if p_rest is None or p_corrupted is None:
                    ie = None
                else:
                    ie = p_rest - p_corrupted
                p_restored_matrix[p, layer_idx] = np.nan if p_rest is None else float(p_rest)
                ie_matrix[p, layer_idx] = np.nan if ie is None else float(ie)
            except Exception as e:
                print(f"[ERROR] forward-with-single-patch-injection failed for layer {mname} patch {p}: {e}")
                p_restored_matrix[p, layer_idx] = np.nan
                ie_matrix[p, layer_idx] = np.nan
            finally:
                try:
                    hook_handle.remove()
                except Exception:
                    pass

    return ie_matrix, p_restored_matrix, image_positions



def plot_importance_heatmap(importance_matrix, patch_size=14, image_size=336):
    """
    Show a patch x layer importance matrix as a "hot" heatmap.

    importance_matrix: numpy.ndarray, [num_img_tokens, num_layers]; its row count must
    equal (image_size // patch_size) ** 2, otherwise the assertion fails.
    """
    expected_rows = (image_size // patch_size) ** 2
    actual_rows = importance_matrix.shape[0]
    assert actual_rows == expected_rows, \
        f"patch数不一致: got {actual_rows}, expected {expected_rows}"

    fig = plt.figure(figsize=(10, 6))
    ax = fig.gca()
    heat = ax.imshow(importance_matrix, cmap="hot", aspect="auto")
    fig.colorbar(heat, ax=ax, label="Importance")
    ax.set_xlabel("Layer")
    ax.set_ylabel("Image Patch Index")
    ax.set_title("Patch Importance Across Layers")
    plt.show()
def plot_ie_heatmap(ie_matrix, p_restored_matrix=None, title='Patch x Layer IE heatmap',
                    figsize=(10,6), cmap='Greens', save_path=None, vmin=None, vmax=None,
                    highlight_threshold=None):
    """
    Draw the patch x layer IE matrix as a heatmap and optionally save it.

    ie_matrix: np.array shape [num_patches, num_layers]
    p_restored_matrix: accepted for call-site symmetry; not used by this function
    vmin / vmax: colour limits; default to the NaN-ignoring min / max of ie_matrix
    highlight_threshold: float or None. When given, every contiguous run of columns whose
        mean IE (NaN-ignoring, over patches) exceeds the threshold is outlined with a
        dashed red rectangle.
    save_path: when truthy, the figure is written there at dpi=150.
    """
    num_patches, num_layers = ie_matrix.shape
    plt.figure(figsize=figsize)
    ax = plt.gca()
    # choose vmin/vmax if not given
    if vmin is None:
        vmin = np.nanmin(ie_matrix)
    if vmax is None:
        vmax = np.nanmax(ie_matrix)
    im = ax.imshow(ie_matrix, aspect='auto', cmap=cmap, vmin=vmin, vmax=vmax)
    plt.colorbar(im, ax=ax, label='IE = p_restored - p_corrupted')

    ax.set_xlabel('LLM Layers')
    ax.set_ylabel('Image Patches (index)')
    ax.set_title(title)
    ax.set_xticks(np.arange(num_layers))
    # thin out the x ticks when there are many layers
    if num_layers > 30:
        ax.set_xticks(np.linspace(0, num_layers-1, 20, dtype=int))
    ax.set_yticks(np.arange(num_patches))
    # thin out the y ticks when there are many patches
    if num_patches > 40:
        ax.set_yticks(np.linspace(0, num_patches-1, 20, dtype=int))

    # highlight columns where average IE across patches exceeds threshold
    if highlight_threshold is not None:
        col_mean = np.nanmean(ie_matrix, axis=0)
        # find contiguous intervals where mean > threshold
        mask = col_mean > highlight_threshold
        start = None
        for i, val in enumerate(mask):
            if val and start is None:
                start = i
            if (not val or i == len(mask)-1) and start is not None:
                # A run closed by a below-threshold column ended on the previous column;
                # a run still open at the last column ends on that column.
                end = i if val else i - 1
                # draw rectangle: x=start-0.5, y=-0.5, width=(end-start+1), height=num_patches
                rect = patches.Rectangle((start-0.5, -0.5), end-start+1, num_patches,
                                         linewidth=1.5, edgecolor='red', facecolor='none', linestyle='--')
                ax.add_patch(rect)
                start = None

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150)
    # plt.show()
    
def expand_patch_map(patch_map: torch.Tensor, image_size, patch_size: int):
    """
    Expand a (Hp, Wp) patch map to a (H, W) map, where image_size is (H, W).

    Each patch value fills its own patch_size x patch_size region (clipped at the image
    border) without interpolation; pixels not covered by any patch stay zero.
    """
    H, W = image_size
    Hp, Wp = patch_map.shape

    # Pre-compute the pixel span of every patch row and patch column.
    row_spans = [slice(i * patch_size, min((i + 1) * patch_size, H)) for i in range(Hp)]
    col_spans = [slice(j * patch_size, min((j + 1) * patch_size, W)) for j in range(Wp)]

    out = patch_map.new_zeros((H, W))
    for i, rows in enumerate(row_spans):
        for j, cols in enumerate(col_spans):
            out[rows, cols] = patch_map[i, j]
    return out

def _true_runs(mask):
    """Return inclusive ``(start, end)`` index pairs for each run of truthy values in *mask*."""
    runs = []
    run_start = None
    for idx, flag in enumerate(mask):
        if flag:
            if run_start is None:
                run_start = idx
        elif run_start is not None:
            # The run ended on the previous index, not on this falsy one.
            runs.append((run_start, idx - 1))
            run_start = None
    if run_start is not None:
        runs.append((run_start, len(mask) - 1))
    return runs


def _with_suffix(save_path, suffix):
    """Insert *suffix* before the extension of *save_path* (``.png`` is used if it has none)."""
    root, ext = os.path.splitext(save_path)
    return f"{root}{suffix}{ext or '.png'}"


def plot_ie_heatmap_new(
    ie_matrix, p_restored_matrix=None, title='Patch x Layer IE heatmap',
    figsize=(10,6), cmap='Greens', save_path=None, vmin=None, vmax=None,
    highlight_threshold=None,
    image_tensor=None, image_size=None, patch_size=None
):
    """
    Plot a patch-by-layer IE heatmap and, optionally, a patch overlay on the image.

    Args:
        ie_matrix: np.ndarray of shape [num_patches, num_layers].
        p_restored_matrix: accepted for interface compatibility; not used here.
        title: title of the matrix heatmap.
        figsize: figure size of the matrix heatmap.
        cmap: colormap used for both figures.
        save_path: if truthy, the matrix figure is written to
            ``<root>_matrix<ext>`` and the overlay to ``<root>_overlay<ext>``,
            and each figure is closed after saving. If falsy, nothing is
            written and the figures are left open for the caller.
        vmin, vmax: color limits; default to the NaN-aware min / max of
            ``ie_matrix``.
        highlight_threshold: if given, every contiguous run of layers whose
            mean IE over patches exceeds it is outlined with a dashed red box.
        image_tensor: optional image for the overlay. A torch.Tensor is
            squeezed on its first dimension twice and must then be
            channel-first (C, H, W); any other object is passed to ``imshow``
            unchanged.
        image_size: (H, W) of the overlay canvas.
        patch_size: side length, in pixels, of one patch.

    The overlay is drawn only when ``image_tensor``, ``image_size`` and
    ``patch_size`` are all provided.

    Raises:
        ValueError: if the overlay is requested and ``num_patches`` cannot be
            arranged as a ``floor(sqrt(num_patches))``-row grid.
    """
    num_patches, num_layers = ie_matrix.shape

    # ----------------------
    # Figure 1: patch x layer matrix heatmap
    # ----------------------
    matrix_fig = plt.figure(figsize=figsize)
    ax = matrix_fig.gca()
    if vmin is None:
        vmin = np.nanmin(ie_matrix)
    if vmax is None:
        vmax = np.nanmax(ie_matrix)
    im = ax.imshow(ie_matrix, aspect='auto', cmap=cmap, vmin=vmin, vmax=vmax)
    matrix_fig.colorbar(im, ax=ax, label='IE = p_restored - p_corrupted')

    ax.set_xlabel('LLM Layers')
    ax.set_ylabel('Image Patches (index)')
    ax.set_title(title)
    # Thin out the ticks when there are too many layers / patches to label.
    if num_layers > 30:
        ax.set_xticks(np.linspace(0, num_layers-1, 20, dtype=int))
    else:
        ax.set_xticks(np.arange(num_layers))
    if num_patches > 40:
        ax.set_yticks(np.linspace(0, num_patches-1, 20, dtype=int))
    else:
        ax.set_yticks(np.arange(num_patches))

    if highlight_threshold is not None:
        col_mean = np.nanmean(ie_matrix, axis=0)
        for first, last in _true_runs(col_mean > highlight_threshold):
            # Cells are centred on integer coordinates, hence the -0.5 offsets;
            # the box spans columns first..last inclusive.
            rect = patches.Rectangle((first-0.5, -0.5), last-first+1, num_patches,
                                     linewidth=1.5, edgecolor='red', facecolor='none', linestyle='--')
            ax.add_patch(rect)
    matrix_fig.tight_layout()
    if save_path:
        matrix_fig.savefig(_with_suffix(save_path, "_matrix"), dpi=150)
        # Release the figure: this function is called once per sample.
        plt.close(matrix_fig)

    # ----------------------
    # Figure 2: patch heatmap aligned with the image
    # ----------------------
    if image_tensor is not None and image_size is not None and patch_size is not None:
        # (num_patches, num_layers) -> (num_patches,) by averaging over layers
        patch_scores = np.nanmean(ie_matrix, axis=1)
        grid_h = int(np.sqrt(num_patches))
        if grid_h == 0 or num_patches % grid_h != 0:
            raise ValueError(
                f"cannot arrange {num_patches} patches into a grid with {grid_h} rows"
            )
        grid_w = num_patches // grid_h
        patch_map = patch_scores.reshape(grid_h, grid_w)

        # Upsample to image resolution by block replication (no interpolation).
        patch_map_up = expand_patch_map(torch.tensor(patch_map).float(), image_size, patch_size).numpy().astype(np.float32)

        overlay_fig = plt.figure(figsize=(6,6))
        overlay_ax = overlay_fig.gca()
        if isinstance(image_tensor, torch.Tensor):
            img = image_tensor.squeeze(0).squeeze(0).permute(1,2,0).cpu().numpy().astype(np.float32)
        else:
            img = image_tensor

        overlay_ax.imshow(img)
        overlay_ax.imshow(patch_map_up, cmap=cmap, alpha=0.5)
        overlay_ax.axis("off")
        overlay_ax.set_title("Patch Importance Overlay")

        if save_path:
            overlay_fig.savefig(_with_suffix(save_path, "_overlay"), dpi=150)
            plt.close(overlay_fig)
################################################################################
# Main eval flow modified to run causal tracing optionally
################################################################################
def eval_model(args):
    """
    Load the model and, if requested, run causal tracing on every question.

    Reads one JSON object per line from ``args.question_file``, builds the
    data loader, and iterates over it in lockstep with the question list.
    When ``args.do_causal_tracing`` is set, ``do_causal_tracing_for_sample`` is
    called for each sample and a summary (generated text, clean / corrupted
    probabilities and the five modules with the largest IE) is appended to
    ``args.answers_file`` as one JSON line. Tracing artefacts go to
    ``args.save_folder``.

    NOTE: plain answer generation (``model.generate``) is disabled in this
    script, so without ``--do_causal_tracing`` the answers file stays empty.
    """
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, _context_len = load_pretrained_model_both(
        model_path, args.model_base, model_name, args.use_prompt_tuning)

    # Close the question file promptly and tolerate blank lines in the JSONL.
    with open(os.path.expanduser(args.question_file), "r") as question_file:
        questions = [json.loads(row) for row in question_file if row.strip()]

    answers_file = os.path.expanduser(args.answers_file)
    answers_dir = os.path.dirname(answers_file)
    # A bare filename (e.g. the default "answer.jsonl") has an empty dirname,
    # and os.makedirs("") raises FileNotFoundError.
    if answers_dir:
        os.makedirs(answers_dir, exist_ok=True)

    data_loader = create_data_loader(questions, args.image_folder, tokenizer, image_processor, model.config, args)

    save_folder = os.path.expanduser(args.save_folder)
    os.makedirs(save_folder, exist_ok=True)

    with open(answers_file, "w") as ans_file:
        for (input_ids, image_tensor, image_sizes, _lines), line in tqdm(zip(data_loader, questions), total=len(questions)):
            if not args.do_causal_tracing:
                continue

            idx = line["question_id"]
            print(f"Running causal tracing for question_id={idx}")
            # The tracing routine receives CPU tensors, so the batch is not
            # moved to the GPU here.
            res = do_causal_tracing_for_sample(
                model, tokenizer,
                input_ids.squeeze(0).cpu(), image_tensor.squeeze(0).cpu(),
                image_sizes[0], args, save_folder)

            # Modules ranked by descending IE; entries without an IE are dropped.
            ranked_modules = sorted(
                ((entry['module'], entry['IE']) for entry in res['per_module'] if entry['IE'] is not None),
                key=lambda pair: -pair[1],
            )
            summary = {
                'question_id': idx,
                'gen_text': res['gen_text'],
                'p_clean': res['p_clean'],
                'p_corrupted': res['p_corrupted'],
                'top_module_IE': ranked_modules[:5]
            }
            ans_file.write(json.dumps({'causal_tracing_summary': summary}) + "\n")
            # Flush per sample so partial results survive a crash mid-run.
            ans_file.flush()

def _build_arg_parser():
    """Create the command-line parser for the evaluation / causal-tracing script."""
    cli = argparse.ArgumentParser()
    add = cli.add_argument

    # Model, data and output locations.
    add("--model-path", type=str, default="facebook/opt-350m")
    add("--model-base", type=str, default=None)
    add("--image-folder", type=str, default="")
    add("--question-file", type=str, default="tables/question.jsonl")
    add("--answers-file", type=str, default="answer.jsonl")
    add("--conv-mode", type=str, default="llava_v1")

    # Decoding settings.
    add("--temperature", type=float, default=0.2)
    add("--top_p", type=float, default=None)
    add("--num_beams", type=int, default=1)
    add("--max_new_tokens", type=int, default=128)

    # Adapter loading switches; each on/off pair shares one destination.
    add("--use_lora", action="store_true", help="Enable LoRA weights loading")
    add("--no_use_lora", action="store_false", dest="use_lora", help="Disable LoRA weights loading")
    add("--use_prompt_tuning", action="store_true", help="Enable Prompt Tuning weights loading", default=True)
    add("--no_use_prompt_tuning", action="store_false", dest="use_prompt_tuning", help="Disable Prompt Tuning weights loading")
    add("--num_virtual_tokens", type=int, default=128)
    add("--prompt_tuning_init_text", type=str, default="init prompt text")

    # Causal tracing options.
    add("--do_causal_tracing", action="store_true", help="Run causal tracing per sample (slow)")
    add("--save-folder", type=str, default="./causal_tracing_results")
    add("--noise_sigma", type=float, default=0.07, help="Gaussian noise sigma for corrupted image (in image range 0..1)")
    add("--max_layers", type=int, default=12, help="Max number of llm-mlp modules to try (0 means all found)")
    return cli


if __name__ == "__main__":
    parser = _build_arg_parser()
    args = parser.parse_args()

    eval_model(args)
