import argparse
import torch
import os
import json
from tqdm import tqdm
import shortuuid

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model, load_pretrained_model_both
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
from torch.utils.data import Dataset, DataLoader
from peft import PromptTuningConfig, PromptTuningInit, get_peft_model

from PIL import Image
import math
# 加载LoRA适配器
from peft import PeftModel
import pdb
import sys
# import llava-interp
# from llava-interp.src import create_interactive_logit_lens
import torch
import json
from pathlib import Path
import base64
from io import BytesIO
from PIL import Image
import torch
import numpy as np
from PIL import Image, ImageDraw

def find_image_token_span(input_ids, image_token_id):
    """Locate the first contiguous run of ``image_token_id`` in the first sequence.

    Args:
        input_ids: batched token ids; only row 0 is inspected.
        image_token_id: id marking image placeholder positions.

    Returns:
        ``(start, count)`` for the first contiguous run of the id, or ``(None, 0)``
        when the id does not occur at all.
    """
    tokens = input_ids[0].tolist()
    start = None
    count = 0
    for pos, tok in enumerate(tokens):
        if tok != image_token_id:
            if start is not None:
                # The first run has ended; later runs are ignored.
                break
            continue
        if start is None:
            start = pos
        count += 1
    return start, count

@torch.no_grad()
def compute_patch_importance_attention(model,
                                       input_ids,
                                       image_tensor,
                                       image_sizes,
                                       image_token_id,
                                       last_k=4,
                                       target_pos="last"):
    """Score image patches using attention weights captured from ``nn.MultiheadAttention``.

    A forward hook is registered on every ``torch.nn.MultiheadAttention`` submodule whose
    qualified name contains ``"cross"`` or ``"vision"``. The model is run once, and each
    attention-weight tensor those modules return is recorded. Every recorded map is
    brought to ``[batch, queries, keys]``, averaging over heads when the module returned
    per-head weights. The last ``last_k`` maps (batch element 0) are averaged. The row for
    the chosen query position is then clamped at 0 and normalized to sum to 1.

    NOTE(review): LLaVA-style decoders may not contain any ``nn.MultiheadAttention``
    module, in which case nothing is captured and RuntimeError is raised. Reading the
    language model's self-attentions via ``output_attentions=True`` is an alternative;
    confirm against the model in use.

    Args:
        model: multimodal model, called as ``model(input_ids=..., images=...,
            image_sizes=..., output_attentions=False, use_cache=False, return_dict=True)``.
        input_ids: text token ids for the forward call.
        image_tensor: preprocessed image tensor, passed as ``images``.
        image_sizes: passed through to the forward call.
        image_token_id: currently unused; kept for interface compatibility.
        last_k: number of trailing captured maps to average (must be >= 1).
        target_pos: ``"last"`` for the last query row, otherwise an int-like query index.

    Returns:
        1-D float32 numpy array with one non-negative score per key position,
        summing to ~1.

    Raises:
        ValueError: if ``last_k`` < 1.
        RuntimeError: if no attention weights were captured.
    """
    if last_k < 1:
        raise ValueError(f"last_k must be >= 1, got {last_k}")

    cross_attn_maps = []

    def hook_cross_attn(module, inputs, output):
        # nn.MultiheadAttention returns (attn_output, attn_weights). attn_weights is None
        # when the module is called with need_weights=False; nothing to record then.
        weights = output[1] if isinstance(output, tuple) and len(output) > 1 else None
        if weights is None:
            return
        weights = weights.detach().float().cpu()
        # An unbatched query is 2-D (L, E); its weights lack the batch dim, so add it.
        if inputs and torch.is_tensor(inputs[0]) and inputs[0].dim() == 2:
            weights = weights.unsqueeze(0)
        # average_attn_weights=False gives [B, H, Q, K]; the default gives [B, Q, K].
        if weights.dim() == 4:
            weights = weights.mean(dim=1)
        cross_attn_maps.append(weights)

    handles = [
        module.register_forward_hook(hook_cross_attn)
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.MultiheadAttention)
        and ("cross" in name or "vision" in name)
    ]

    # Always detach the hooks, even if the forward pass raises; otherwise they would
    # keep firing (and appending to a dead list) on every later call of the model.
    try:
        model(
            input_ids=input_ids,
            images=image_tensor,
            image_sizes=image_sizes,
            output_attentions=False,
            use_cache=False,
            return_dict=True
        )
    finally:
        for handle in handles:
            handle.remove()

    if not cross_attn_maps:
        raise RuntimeError("没有捕捉到 cross-attention 权重，请确认模型结构和 hook 条件。")

    # Batch element 0 of every captured map -> [num_maps, Q, K].
    attn = torch.stack([m[0] for m in cross_attn_maps], dim=0)

    q_pos = attn.size(1) - 1 if target_pos == "last" else int(target_pos)

    # Average the query row over the last K captured maps -> [K_len].
    k = min(last_k, attn.size(0))
    v = attn[-k:, q_pos].mean(dim=0).clamp(min=0)

    v = v / (v.sum() + 1e-8)
    return v.numpy()

def overlay_patch_heatmap(
    image: Image.Image,
    patch_scores: np.ndarray,
    image_size=336, patch_size=14,
    alpha=0.55
) -> Image.Image:
    """Draw per-patch scores as a red overlay on a center-cropped, resized image.

    The image is center-cropped to a square and resized to ``image_size`` x
    ``image_size``, the same preprocessing as the interactive logit-lens HTML. It is
    then overlaid with one ``patch_size`` square per score in row-major order. Scores
    are min-max normalized to [0, 1], and each square's opacity is
    ``255 * score * alpha``.

    Args:
        image: source image; converted to RGBA after resizing.
        patch_scores: array-like with ``(image_size // patch_size) ** 2`` values.
        image_size: side length in pixels of the square output image.
        patch_size: side length in pixels of one patch.
        alpha: maximum overlay opacity in [0, 1].

    Returns:
        RGBA image of size ``(image_size, image_size)``.

    Raises:
        ValueError: if the number of scores is not ``grid ** 2``.
    """
    grid = image_size // patch_size
    scores = np.asarray(patch_scores, dtype=np.float32).ravel()
    if scores.size != grid * grid:
        raise ValueError(f"patch_scores长度({scores.size})与grid^2({grid*grid})不符")

    # Square center crop, then resize; PIL's resize needs a (width, height) tuple.
    img_w, img_h = image.size
    side = min(img_w, img_h)
    left = (img_w - side) / 2
    top = (img_h - side) / 2
    image_cropped = image.crop((left, top, left + side, top + side))
    image_resized = image_cropped.resize((image_size, image_size), Image.LANCZOS).convert("RGBA")

    # Min-max normalize to [0, 1]; the epsilon keeps constant scores from dividing by 0.
    if scores.size:
        scores = (scores - scores.min()) / (scores.max() - scores.min() + 1e-8)

    overlay = Image.new("RGBA", (image_size, image_size), (0, 0, 0, 0))
    draw = ImageDraw.Draw(overlay, "RGBA")

    for idx, val in enumerate(scores):
        row, col = divmod(idx, grid)
        x0, y0 = col * patch_size, row * patch_size
        # ImageDraw.rectangle includes both corners; stop one pixel short so that
        # neighbouring patches do not share an edge.
        x1, y1 = x0 + patch_size - 1, y0 + patch_size - 1
        # Opacity encodes the weight; (255, 0, 0) could be swapped for a colormap.
        draw.rectangle([x0, y0, x1, y1], fill=(255, 0, 0, int(255 * float(val) * alpha)))

    return Image.alpha_composite(image_resized, overlay)

_LOGIT_LENS_HTML_TEMPLATE = r"""
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Interactive Logit Lens</title>
    <style>
        body { margin: 0; padding: 0; font-family: Arial, sans-serif; }
        .container { display: flex; }
        .image-container { 
            flex: 0 0 auto; 
            margin: 20px; 
            position: relative;
            width: 336px; /* Set to match image width */
        }
        .image-container img {
            display: block;
        }
        #heatmapCanvas {
            position: absolute;
            left: 0;
            top: 0;
            pointer-events: none; /* let mouse events reach the image underneath */
        }
        .highlight-box {
            position: absolute;
            border: 2px solid red;
            pointer-events: none;
            display: none;
        }
        .table-container { 
            flex: 1 1 auto;
            position: relative;
            max-height: 90vh;
            overflow: auto;
            margin: 20px;
        }
        table { 
            border-collapse: separate;
            border-spacing: 0;
        }
        th, td { 
            border: 1px solid #ddd; 
            padding: 8px; 
            text-align: center;
            min-width: 80px;
        }
        th { 
            background-color: #f2f2f2; 
            font-weight: bold;
        }
        .corner-header {
            position: sticky;
            top: 0;
            left: 0;
            z-index: 3;
            background-color: #f2f2f2;
        }
        .row-header {
            position: sticky;
            left: 0;
            z-index: 2;
            background-color: #f2f2f2;
        }
        .col-header {
            position: sticky;
            top: 0;
            z-index: 1;
            background-color: #f2f2f2;
        }
        #tooltip {
            display: none;
            position: fixed;
            background: white;
            border: 1px solid black;
            padding: 5px;
            z-index: 1000;
            pointer-events: none;
            max-width: 300px;
            font-size: 14px;
        }
        .highlighted-row {
            background-color: #ffff99;
        }
        .image-info {
            margin-top: 10px;
            font-size: 14px;
            width: 100%;
            word-wrap: break-word;
        }
        .prompt {
            font-weight: bold;
            margin-bottom: 5px;
        }
        .instructions {
            font-style: italic;
        }
    </style>
</head>
<body>
    <div class="container">
        <div class="image-container">
            <img src="data:image/png;base64,IMAGEPLACEHOLDER" alt="Input Image" style="width: 336px; height: 336px;">
            <canvas id="heatmapCanvas" width="336" height="336"></canvas>
            <div class="highlight-box"></div>
            <div class="image-info">
                <p class="prompt">Prompt: "PROMPTPLACEHOLDER"</p>
                <p class="instructions">Instructions: Click on image to lock the patch, click on image/table to unlock</p>
                <p>Info: MISCPLACEHOLDER</p>
            </div>
        </div>
        <div class="table-container">
            <table id="logitLens"></table>
        </div>
    </div>
    <div id="tooltip"></div>
<script>
    // All constants are declared before any function that uses them is called.
    const data = DATAPLACEMENT;
    const tokenLabels = TOKENLABELSPLACEMENT;
    const patchScores = PATCHSCORESPLACEMENT;
    const imageSize = IMAGESIZEPLACEHOLDER;
    const patchSize = PATCHSIZEPLACEHOLDER;
    const gridSize = Math.floor(imageSize / patchSize);

    const tooltip = document.getElementById('tooltip');
    const highlightBox = document.querySelector('.highlight-box');
    const image = document.querySelector('.image-container img');
    const table = document.getElementById('logitLens');
    const heatmapCanvas = document.getElementById('heatmapCanvas');
    const ctx = heatmapCanvas.getContext('2d');

    let isLocked = false;
    let highlightedRow = null;
    let lockedPatchIndex = null;

    // Decoded tokens such as "<s>" must be shown as text, not parsed as HTML.
    function escapeHtml(text) {
        return String(text)
            .replace(/&/g, '&amp;')
            .replace(/</g, '&lt;')
            .replace(/>/g, '&gt;')
            .replace(/"/g, '&quot;');
    }

    // 1-based patch index encoded in an image-token label like "<IMG001>", else null.
    function patchIndexFromLabel(label) {
        const match = /^<IMG(\d+)>$/.exec(label);
        return match ? parseInt(match[1], 10) : null;
    }

    function drawHeatmap() {
        const cell = heatmapCanvas.width / gridSize;
        // Min-max normalize the weights.
        const minV = Math.min(...patchScores);
        const maxV = Math.max(...patchScores);
        const denom = (maxV - minV) || 1;
        ctx.clearRect(0, 0, heatmapCanvas.width, heatmapCanvas.height);
        for (let i = 0; i < patchScores.length; i++) {
            const v = (patchScores[i] - minV) / denom;
            const r = Math.floor(i / gridSize), c = i % gridSize;
            // Red, with opacity proportional to the weight.
            ctx.fillStyle = `rgba(255,0,0,${0.55 * v})`;
            ctx.fillRect(c * cell, r * cell, cell, cell);
        }
    }

    function showTooltip(e, layer, pos, shouldScroll = false) {
        let content = data[layer][pos].map(([token, prob]) => `${escapeHtml(token)}: ${prob}`).join('<br>');
        const patchIndex = patchIndexFromLabel(tokenLabels[pos]);
        if (patchIndex !== null) {
            // For image tokens, also show the patch weight.
            const weight = patchScores[patchIndex - 1];
            if (weight !== undefined) {
                content = `<b>patch weight:</b> ${Number(weight).toFixed(4)}<br>` + content;
            }
        }
        tooltip.innerHTML = content;
        tooltip.style.display = 'block';
        updateTooltipPosition(e);

        if (patchIndex !== null) {
            highlightImagePatch(patchIndex);
            highlightTableRow(pos, shouldScroll);
        } else {
            highlightBox.style.display = 'none';
            unhighlightTableRow();
        }
    }

    function hideTooltip() {
        tooltip.style.display = 'none';
        if (!isLocked) {
            highlightBox.style.display = 'none';
            unhighlightTableRow();
        }
    }

    function updateTooltipPosition(e) {
        const tooltipRect = tooltip.getBoundingClientRect();
        const viewportWidth = window.innerWidth;
        const viewportHeight = window.innerHeight;

        let x = e.clientX + 10;
        let y = e.clientY + 10;

        if (x + tooltipRect.width > viewportWidth) {
            x = e.clientX - tooltipRect.width - 10;
        }

        if (y + tooltipRect.height > viewportHeight) {
            y = e.clientY - tooltipRect.height - 10;
        }

        x = Math.max(0, x);
        y = Math.max(0, y);

        tooltip.style.left = x + 'px';
        tooltip.style.top = y + 'px';
    }

    function highlightImagePatch(patchIndex) {
        const scaleFactor = image.width / imageSize;
        const row = Math.floor((patchIndex - 1) / gridSize);
        const col = (patchIndex - 1) % gridSize;

        const left = col * patchSize * scaleFactor;
        const top = row * patchSize * scaleFactor;
        const size = patchSize * scaleFactor;

        highlightBox.style.left = `${left}px`;
        highlightBox.style.top = `${top}px`;
        highlightBox.style.width = `${size}px`;
        highlightBox.style.height = `${size}px`;
        highlightBox.style.display = 'block';
    }

    function highlightTableRow(rowIndex, shouldScroll = false) {
        if (highlightedRow) {
            highlightedRow.classList.remove('highlighted-row');
        }
        highlightedRow = table.rows[rowIndex + 1];  // +1 to account for header row
        highlightedRow.classList.add('highlighted-row');
        if (shouldScroll) {
            highlightedRow.scrollIntoView({ behavior: 'smooth', block: 'center' });
        }
    }

    function unhighlightTableRow() {
        if (highlightedRow) {
            highlightedRow.classList.remove('highlighted-row');
            highlightedRow = null;
        }
    }

    function getPatchIndexFromMouseEvent(e) {
        const rect = image.getBoundingClientRect();
        const x = e.clientX - rect.left;
        const y = e.clientY - rect.top;
        // Clamp so a pointer on the far edge still maps to the last row/column.
        const patchX = Math.min(gridSize - 1, Math.max(0, Math.floor(x / (image.width / gridSize))));
        const patchY = Math.min(gridSize - 1, Math.max(0, Math.floor(y / (image.height / gridSize))));
        return patchY * gridSize + patchX + 1;
    }

    function getTokenIndexFromPatchIndex(patchIndex) {
        return tokenLabels.findIndex(label => label === `<IMG${patchIndex.toString().padStart(3, '0')}>`);
    }

    // Create corner header
    const cornerHeader = table.createTHead().insertRow();
    cornerHeader.insertCell().textContent = 'Token/Layer';
    cornerHeader.cells[0].classList.add('corner-header');

    // Create layer headers
    for (let i = 0; i < data.length; i++) {
        const th = document.createElement('th');
        th.textContent = `Layer ${i + 1}`;
        th.classList.add('col-header');
        cornerHeader.appendChild(th);
    }

    // Create rows with token labels
    for (let pos = 0; pos < tokenLabels.length; pos++) {
        const row = table.insertRow();
        const rowHeader = row.insertCell();
        rowHeader.textContent = tokenLabels[pos];
        rowHeader.classList.add('row-header');

        for (let layer = 0; layer < data.length; layer++) {
            const cell = row.insertCell();
            cell.textContent = data[layer][pos][0][0];

            cell.addEventListener('mouseover', (e) => {
                if (!isLocked) {
                    showTooltip(e, layer, pos, false);
                }
            });
            cell.addEventListener('mousemove', updateTooltipPosition);
            cell.addEventListener('mouseout', () => {
                if (!isLocked) {
                    hideTooltip();
                }
            });
        }
    }

    image.addEventListener('mousemove', (e) => {
        if (!isLocked) {
            const patchIndex = getPatchIndexFromMouseEvent(e);
            highlightImagePatch(patchIndex);
            const tokenIndex = getTokenIndexFromPatchIndex(patchIndex);
            if (tokenIndex !== -1) {
                showTooltip(e, 0, tokenIndex, true);
            }
        }
    });

    image.addEventListener('mouseout', () => {
        if (!isLocked) {
            hideTooltip();
        }
    });

    image.addEventListener('click', (e) => {
        isLocked = !isLocked;
        if (isLocked) {
            lockedPatchIndex = getPatchIndexFromMouseEvent(e);
            highlightImagePatch(lockedPatchIndex);
            const tokenIndex = getTokenIndexFromPatchIndex(lockedPatchIndex);
            if (tokenIndex !== -1) {
                highlightTableRow(tokenIndex, true);
            }
        } else {
            lockedPatchIndex = null;
            hideTooltip();
        }
    });

    table.addEventListener('click', () => {
        isLocked = false;
        lockedPatchIndex = null;
        hideTooltip();
    });

    drawHeatmap();
</script>
</body>
</html>
"""


def _logit_lens_token_labels(tokenizer, prompt, img_token_id, img_token_count):
    """Return one row label per id of ``tokenizer.encode(prompt)``.

    Each ``img_token_id`` is expanded into ``img_token_count`` labels
    ``<IMG001>``, ``<IMG002>``, ... .
    """
    labels = []
    for token_id in tokenizer.encode(prompt):
        if token_id == img_token_id:
            # One-indexed because the page's JavaScript treats these as 1-based.
            labels.extend(f"<IMG{i:03d}>" for i in range(1, img_token_count + 1))
        else:
            labels.append(tokenizer.decode([token_id]))
    return labels


@torch.no_grad()
def _logit_lens_top_tokens(hidden_states, norm, lm_head, tokenizer, k=5):
    """Return, for each hidden state, the top-``k`` predictions at each position.

    Uses batch element 0 and ``softmax(lm_head(norm(h)))``. Each position holds a list
    of ``(token, "prob")`` pairs. ``no_grad`` keeps autograd from recording (or trying
    to save inference tensors for) these ops.
    """
    all_top_tokens = []
    for layer_hidden in hidden_states:
        logits = lm_head(norm(layer_hidden[:1])).float()
        probs = torch.softmax(logits, dim=-1)
        top_values, top_indices = torch.topk(probs[0], k=k, dim=-1)
        # Convert once per layer instead of one .item() call (and device sync) per value.
        top_values = top_values.tolist()
        top_indices = top_indices.tolist()
        all_top_tokens.append([
            [(tokenizer.decode(idx), f"{prob:.4f}") for idx, prob in zip(ids, vals)]
            for ids, vals in zip(top_indices, top_values)
        ])
    return all_top_tokens


def _center_crop_png_base64(image, image_size):
    """Center-crop ``image`` to a square, resize it, and return it as a base64 PNG string.

    The output is resized to ``image_size`` x ``image_size``.
    """
    img_w, img_h = image.size
    side = min(img_w, img_h)
    left = (img_w - side) / 2
    top = (img_h - side) / 2
    image_cropped = image.crop((left, top, left + side, top + side))
    image_resized = image_cropped.resize((image_size, image_size), Image.LANCZOS)
    buffered = BytesIO()
    image_resized.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()


def create_interactive_logit_lens(hidden_states, norm, lm_head, tokenizer, image, model_name, image_filename, prompt, save_folder = ".", image_size=336, patch_size=14, misc_text="", patch_scores=None, img_token_id=32000):
    """Write an interactive logit-lens HTML page for one image/prompt pair.

    Every entry of ``hidden_states`` (including an embedding-layer output, if the
    caller passes one) is passed through ``norm`` and ``lm_head``. The top-5 tokens and
    probabilities are recorded for each position of batch element 0. The page shows:
    the center-cropped image with a per-patch heatmap; a table of top-1 tokens
    (rows = token labels, columns = hidden-state entries); and a top-5 tooltip.

    Args:
        hidden_states: sequence of tensors indexed ``[batch, position, ...]``.
        norm: module applied to each hidden state before ``lm_head``.
        lm_head: module producing vocabulary logits.
        tokenizer: object providing ``encode`` and ``decode``.
        image: PIL image displayed on the page.
        model_name: used in the output filename.
        image_filename: its stem is used in the output filename.
        prompt: tokenized for the row labels and shown (HTML-escaped) on the page.
        save_folder: directory for the HTML file; created if missing.
        image_size: side of the square the image is resized to.
        patch_size: side of one patch; ``(image_size // patch_size) ** 2`` labels are
            inserted per image token.
        misc_text: free text shown (HTML-escaped) under the prompt.
        patch_scores: per-patch weights for the heatmap; zeros when ``None``.
        img_token_id: id in ``tokenizer.encode(prompt)`` expanded into patch labels.

    Returns:
        ``pathlib.Path`` of the written ``{model_name}_{stem}_logit_lens.html`` file.
    """
    import html
    import re

    img_token_count = (image_size // patch_size) ** 2  # 576 for 336x336 with 14x14 patches
    if patch_scores is None:
        patch_scores = [0.0] * img_token_count

    token_labels = _logit_lens_token_labels(tokenizer, prompt, img_token_id, img_token_count)
    all_top_tokens = _logit_lens_top_tokens(hidden_states, norm, lm_head, tokenizer)

    # The page indexes data[layer][pos] for every label; extra labels would crash it.
    sequence_length = len(all_top_tokens[0]) if all_top_tokens else 0
    if len(token_labels) != sequence_length:
        print(f"Warning: {len(token_labels)} token labels vs {sequence_length} hidden-state "
              f"positions; rows may not align with the prompt tokens.")
        token_labels = token_labels[:sequence_length]

    def script_json(obj):
        # Escape "</" so token text cannot close the surrounding <script> element.
        return json.dumps(obj).replace("</", "<\\/")

    replacements = {
        "IMAGEPLACEHOLDER": _center_crop_png_base64(image, image_size),
        "DATAPLACEMENT": script_json(all_top_tokens),
        "TOKENLABELSPLACEMENT": script_json(token_labels),
        "IMAGESIZEPLACEHOLDER": str(image_size),
        "PATCHSIZEPLACEHOLDER": str(patch_size),
        "PROMPTPLACEHOLDER": html.escape(prompt),
        "MISCPLACEHOLDER": html.escape(misc_text),
        "PATCHSCORESPLACEMENT": script_json([float(s) for s in patch_scores]),
    }
    # One pass, so substituted text that happens to contain a placeholder name is
    # never replaced a second time.
    pattern = re.compile("|".join(map(re.escape, replacements)))
    html_content = pattern.sub(lambda m: replacements[m.group(0)], _LOGIT_LENS_HTML_TEMPLATE)

    save_dir = Path(save_folder)
    save_dir.mkdir(parents=True, exist_ok=True)
    output_path = save_dir / f"{model_name}_{Path(image_filename).stem}_logit_lens.html"

    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(html_content)

    print(f"Interactive logit lens HTML has been saved to: {output_path}")
    return output_path



def split_list(lst, n):
    """Split ``lst`` into at most ``n`` consecutive chunks of (roughly) equal size.

    Every chunk except possibly the last has ``ceil(len(lst) / n)`` elements. Fewer
    than ``n`` chunks can therefore be returned (e.g. when ``len(lst) < n``), and an
    empty ``lst`` yields ``[]``.

    Raises:
        ValueError: if ``n`` is not a positive integer.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}")
    # Exact integer ceiling division; max(1, ...) keeps the range step non-zero when
    # lst is empty (range() rejects a step of 0).
    chunk_size = max(1, -(-len(lst) // n))
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]


def get_chunk(lst, n, k):
    """Return chunk ``k`` of ``split_list(lst, n)``.

    ``split_list`` may return fewer than ``n`` chunks (for example when
    ``len(lst) < n``). A valid index ``len(chunks) <= k < n`` therefore yields an empty
    list instead of raising IndexError, so every one of ``n`` workers gets a chunk.
    Other indices (including negative ones) are used as-is.
    """
    chunks = split_list(lst, n)
    if len(chunks) <= k < n:
        return []
    return chunks[k]


# Custom dataset class
class CustomDataset(Dataset):
    """Map-style dataset yielding ``(input_ids, image_tensor, image_size)`` per question.

    Each record must provide ``"image"`` (a path relative to ``image_folder``) and
    ``"text"`` (the question text).
    """

    def __init__(self, questions, image_folder, tokenizer, image_processor, model_config, conv_mode=None):
        self.questions = questions
        self.image_folder = image_folder
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.model_config = model_config
        # Conversation template name. When None, __getitem__ falls back to the
        # module-level ``args.conv_mode`` set by the ``__main__`` block.
        # NOTE(review): that global may not exist in DataLoader workers started with
        # the "spawn" method, so passing conv_mode explicitly is safer.
        self.conv_mode = conv_mode

    def __getitem__(self, index):
        """Build the prompt for record ``index``; return ``(input_ids, image_tensor, image.size)``."""
        line = self.questions[index]
        image_file = line["image"]
        qs = line["text"]
        # Prepend the image placeholder token(s) in the form the model config expects.
        if self.model_config.mm_use_im_start_end:
            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
        else:
            qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

        conv_mode = self.conv_mode if self.conv_mode is not None else args.conv_mode
        conv = conv_templates[conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        image = Image.open(os.path.join(self.image_folder, image_file)).convert('RGB')
        image_tensor = process_images([image], self.image_processor, self.model_config)[0]

        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')

        return input_ids, image_tensor, image.size

    def __len__(self):
        return len(self.questions)


def collate_fn(batch):
    """Collate ``(input_ids, image_tensor, image_size)`` samples into one batch.

    Token ids and image tensors are stacked along a new leading dimension; the image
    sizes are returned as a tuple, one entry per sample.
    """
    ids_seq, images_seq, sizes = zip(*batch)
    return torch.stack(ids_seq, dim=0), torch.stack(images_seq, dim=0), sizes


# DataLoader
def create_data_loader(questions, image_folder, tokenizer, image_processor, model_config, batch_size=1, num_workers=1):
    """Wrap the questions in a CustomDataset and return a non-shuffling DataLoader.

    Only ``batch_size == 1`` is supported (enforced with an assert).
    """
    assert batch_size == 1, "batch_size must be 1"
    return DataLoader(
        CustomDataset(questions, image_folder, tokenizer, image_processor, model_config),
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
        collate_fn=collate_fn,
    )

def load_model_with_adapters(args, model_path):
    """Load the base model and optionally attach prompt-tuning adapter weights.

    LoRA is not loaded separately here: pointing ``model_path`` at ``model_base``
    means LoRA is not enabled.

    Args:
        args: namespace providing ``model_base`` and, optionally, ``use_prompt_tuning``.
        model_path: checkpoint directory; prompt-tuning weights are looked up in its
            ``llava-prompt_tuning`` subdirectory.

    Returns:
        ``(tokenizer, model, image_processor, context_len)``. ``model`` is wrapped by
        ``PeftModel.from_pretrained`` when prompt-tuning weights were found.
    """
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path, args.model_base, get_model_name_from_path(model_path)
    )

    wants_prompt_tuning = getattr(args, 'use_prompt_tuning', False)
    prompt_dir = os.path.join(model_path, "llava-prompt_tuning")
    if wants_prompt_tuning and os.path.exists(prompt_dir):
        print("Loading Prompt Tuning weights...")
        model = PeftModel.from_pretrained(model, prompt_dir, is_trainable=False)

    return tokenizer, model, image_processor, context_len

def eval_model(args):
    """Answer every question and save a logit-lens page and patch heatmap per sample.

    Reads ``args.question_file`` (JSONL with ``question_id``, ``text`` and ``image``
    keys) and loads images from ``args.image_folder``. Appends one JSON answer line per
    question to ``args.answers_file``. The logit-lens HTML is written next to the
    answers file, and heatmap PNGs go into its ``823heatmap`` subdirectory. Tensors are
    moved to ``'cuda'``, so a GPU is required.
    """
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model_both(
        model_path, args.model_base, model_name, args.use_prompt_tuning
    )

    with open(os.path.expanduser(args.question_file), "r") as question_file:
        questions = [json.loads(q) for q in question_file if q.strip()]

    answers_file = os.path.expanduser(args.answers_file)
    # dirname() is "" for a bare filename such as the default "answer.jsonl", and
    # os.makedirs("") raises, so fall back to the current directory.
    output_dir = os.path.dirname(answers_file) or "."
    heatmap_dir = os.path.join(output_dir, "823heatmap")
    os.makedirs(heatmap_dir, exist_ok=True)

    if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
        args.conv_mode = args.conv_mode + '_mmtag'
        print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.')

    data_loader = create_data_loader(questions, args.image_folder, tokenizer, image_processor, model.config)

    with open(answers_file, "w") as ans_file:
        for (input_ids, image_tensor, image_sizes), line in tqdm(zip(data_loader, questions), total=len(questions)):
            idx = line["question_id"]
            cur_prompt = line["text"]

            input_ids = input_ids.to(device='cuda', non_blocking=True)
            # Move the image to the GPU once and reuse it for every call below.
            images = image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True)

            with torch.inference_mode():
                # One forward pass for the per-layer hidden states used by the logit lens.
                outputs = model(
                    input_ids,
                    images=images,
                    image_sizes=image_sizes,
                    output_hidden_states=True,
                    use_cache=False,
                )
                output_ids = model.generate(
                    inputs=input_ids,
                    images=images,
                    image_sizes=image_sizes,
                    do_sample=True if args.temperature > 0 else False,
                    temperature=args.temperature,
                    top_p=args.top_p,
                    num_beams=args.num_beams,
                    max_new_tokens=args.max_new_tokens,
                    use_cache=True
                )
            outputs_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()

            # Attention-based patch scores are best-effort: if the hooks capture nothing
            # (RuntimeError), keep going and just skip the heatmap for this sample.
            try:
                patch_scores = compute_patch_importance_attention(
                    model=model,
                    input_ids=input_ids,
                    image_tensor=images,
                    image_sizes=image_sizes,
                    image_token_id=IMAGE_TOKEN_INDEX,
                    last_k=4,
                    target_pos="last"
                )
            except RuntimeError as err:
                print(f"Skipping patch heatmap for {line['image']}: {err}")
                patch_scores = None

            orig_img = Image.open(os.path.join(args.image_folder, line["image"])).convert("RGB")

            # hidden_states were produced under inference_mode; stay in that mode so
            # autograd never tries to record ops on those inference tensors.
            with torch.inference_mode():
                create_interactive_logit_lens(
                    hidden_states=outputs.hidden_states,
                    norm=model.model.norm,
                    lm_head=model.lm_head,
                    tokenizer=tokenizer,
                    image=orig_img,
                    model_name=model_name,
                    image_filename=line["image"],
                    prompt=cur_prompt,
                    save_folder=output_dir,
                    misc_text=f"Answer: {outputs_text}",
                    patch_scores=patch_scores
                )

            if patch_scores is not None:
                # Use the default 336px / 14px grid, matching the logit-lens page;
                # image_sizes holds the original (w, h), not the model input size.
                heatmap_img = overlay_patch_heatmap(
                    image=orig_img,
                    patch_scores=patch_scores,
                    patch_size=14,
                    alpha=0.55
                )
                heatmap_path = os.path.join(heatmap_dir, f"{model_name}_{Path(line['image']).stem}_heatmap.png")
                heatmap_img.save(heatmap_path)

            ans_id = shortuuid.uuid()
            ans_file.write(json.dumps({
                "question_id": idx,
                "prompt": cur_prompt,
                "text": outputs_text,
                "answer_id": ans_id,
                "model_id": model_name,
            }) + "\n")
            ans_file.flush()

if __name__ == "__main__":
    # Command-line entry point. ``args`` deliberately stays at module scope because
    # CustomDataset.__getitem__ reads ``args.conv_mode`` from the module globals.
    parser = argparse.ArgumentParser()

    # Model and data locations.
    parser.add_argument("--model-path", default="facebook/opt-350m", type=str)
    parser.add_argument("--model-base", default=None, type=str)
    parser.add_argument("--image-folder", default="", type=str)
    parser.add_argument("--question-file", default="tables/question.jsonl", type=str)
    parser.add_argument("--answers-file", default="answer.jsonl", type=str)
    parser.add_argument("--conv-mode", default="llava_v1", type=str)

    # Decoding settings.
    parser.add_argument("--temperature", default=0.2, type=float)
    parser.add_argument("--top_p", default=None, type=float)
    parser.add_argument("--num_beams", default=1, type=int)
    parser.add_argument("--max_new_tokens", default=128, type=int)

    # Adapter switches; each enable/disable pair writes to the same destination.
    parser.add_argument("--use_lora", action='store_true',
                        help="Enable LoRA weights loading")
    parser.add_argument("--no_use_lora", dest='use_lora', action='store_false',
                        help="Disable LoRA weights loading")
    # NOTE(review): store_true combined with default=True makes --use_prompt_tuning a
    # no-op; prompt tuning is on unless --no_use_prompt_tuning is given.
    parser.add_argument("--use_prompt_tuning", action='store_true', default=True,
                        help="Enable Prompt Tuning weights loading")
    parser.add_argument("--no_use_prompt_tuning", dest='use_prompt_tuning', action='store_false',
                        help="Disable Prompt Tuning weights loading")
    parser.add_argument("--num_virtual_tokens", default=128, type=int)
    parser.add_argument("--prompt_tuning_init_text", default="init prompt text", type=str)

    args = parser.parse_args()
    eval_model(args)
