import json
import os
from typing import Dict, List, Any
from dataclasses import dataclass
from datasets import load_dataset
from PIL import Image
from io import BytesIO
import time
import json
import os
from pathlib import Path
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, AutoModelForZeroShotObjectDetection
from qwen_vl_utils import process_vision_info, smart_resize
import torch
from transformers import StoppingCriteria, StoppingCriteriaList
from PIL import Image
import os  
import re  
import random
import cv2
import warnings
warnings.filterwarnings("ignore")
# Seed the Python and torch RNGs so repeated runs start from the same state.
random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)


# Pixel bounds forwarded to the processor as `min_pixels` / `max_pixels`.
min_pixels = 4 * 28 * 28
max_pixels = 2048 * 28 * 28
device = "cuda:5"

# Checkpoint whose attention maps are visualised.
# Both assignments of `model_path` were commented out, so the script failed
# with a NameError before loading anything. The default below is the later of
# the two original candidates (the one that would have won had both been
# active); set the MODEL_PATH environment variable to analyse another
# checkpoint, e.g. "/data/jcy/ckpt/Qwen2.5-VL-7B-Instruct".
model_path = os.environ.get("MODEL_PATH", "./ckpt/Qwen2.5-VL-rl-0511")

# NOTE(review): the forward pass below requests `output_attentions=True`;
# confirm that the attention backend selected by the installed transformers
# version actually returns attention weights.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    # attn_implementation="flash_attention_2",
    device_map=device,
)
# The processor is always loaded from the base checkpoint directory,
# independently of `model_path`.
processor = AutoProcessor.from_pretrained(
    "/data/jcy/ckpt/Qwen2.5-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
)


# Fixed example conversation: a question about an image, and an assistant
# reply that contains a cropped image between two reasoning segments.
question = "What is the blue thing behind the man sitting at the table with a blue plastic tablecloth?"
question += "\nAnswer the question using a single word or phrase"
res1 = """<think>
Step 1: To identify the blue thing behind the man, I need to focus on the area behind him. The blue plastic tablecloth is visible on the table in front of him, but I need to look further back to find the blue object.

Step 2: I will crop and zoom in on the area behind the man to get a clearer view of what is blue.
{"bbox_2d": [301, 1, 390, 206]}
"""
res2 = """Step 3: After cropping, I can see that the blue object is an umbrella. It is closed and hanging above the man's head. The umbrella has some white text on it, but the text is not clearly legible in this view.
</think>
<answer>umbrella</answer>
"""
image_path = "/data/jcy/project/VLM-R1/assets/t2.png"
crop1_path = "/data/jcy/project/VLM-R1/assets/cuda_sample_1_crop_1.png"

# Conversation used to render the prompt text (user turn + full assistant turn).
text_messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": image_path},
            {"type": "text", "text": question},
        ],
    },
    {
        "role": "assistant",
        "content": [
            {"type": "text", "text": res1},
            {"type": "image", "image": crop1_path},
            {"type": "text", "text": res2},
        ],
    },
]

# Separate message list used only to load the images. It lists the same two
# images, in the same order as they appear in `text_messages`.
img_messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": image_path},
            {"type": "image", "image": crop1_path},
        ],
    }
]

# Render the whole conversation as text; no generation prompt is appended
# because the assistant turn is already part of the input.
text = processor.apply_chat_template(
    text_messages, tokenize=False, add_generation_prompt=False
)

image_inputs, video_inputs = process_vision_info(img_messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
).to(device)

gen_len = len(inputs["input_ids"][0])
print(f"##DEBUG: Generated length: {gen_len}")

# `inputs` was already moved to `device` above, so no further transfer is needed.
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
pixel_values = inputs["pixel_values"]
image_grid_thw = inputs["image_grid_thw"]

# Only the attention maps are inspected afterwards, so run the forward pass
# without building an autograd graph.
with torch.no_grad():
    output = model(
        input_ids,
        attention_mask=attention_mask,
        pixel_values=pixel_values,
        image_grid_thw=image_grid_thw,
        output_attentions=True,
    )

print(output["attentions"][0].shape)


import matplotlib.pyplot as plt
import seaborn as sns
import torch
import numpy as np
from PIL import Image

def visualize_token_attention(attentions, st, ed, start_range, end_range, layer_id=0, note_list=None, head_id=None):
    """Draw a heatmap of one layer's attention for a window of tokens.

    Rows of the heatmap are the tokens ``st .. ed - 1`` (second-to-last axis
    of the attention tensor); columns are the tokens ``0 .. end_range - 1``
    (last axis), with the columns before ``start_range`` zeroed out so they
    do not dominate the colour scale.

    Args:
        attentions: Indexable collection of per-layer attention tensors, each
            indexed here as ``[batch, head, row_token, col_token]``. Only
            batch element 0 is plotted.
        st: First row token index (inclusive).
        ed: Last row token index (exclusive).
        start_range: Columns with index below this value are set to 0.
        end_range: Number of leading columns kept in the plot.
        layer_id: Index into ``attentions`` of the layer to plot.
        note_list: Optional iterable of token indices to mark with a dashed
            red vertical line; indices outside
            ``[start_range, end_range)`` are ignored.
        head_id: Attention head to plot. When ``None`` the heads are averaged.

    Returns:
        The matplotlib figure containing the heatmap. It is left open (and
        current) so the caller can save and close it.
    """
    layer_attn_np = attentions[layer_id].detach().cpu().float().numpy()

    if head_id is None:
        attn_weights = layer_attn_np.mean(axis=1)
    else:
        attn_weights = layer_attn_np[:, head_id, :, :]

    # Copy before masking. Without the copy the slice is a view, and for a
    # float32 CPU tensor with an explicit head_id the zeroing below would be
    # written straight into the caller's attention tensor.
    token_attention = attn_weights[0, st:ed, :end_range].copy()
    token_attention[:, :start_range] = 0

    max_val = token_attention.max()
    min_val = token_attention.min()
    print(f"Max attention weight: {max_val}")
    print(f"Min attention weight: {min_val}")

    # Label roughly every tenth row with its absolute token index.
    row_count = token_attention.shape[0]
    y_tick_interval = max(1, row_count // 10)
    row_offsets = np.arange(0, row_count, y_tick_interval)
    y_tick_labels = st + row_offsets

    fig, ax = plt.subplots(figsize=(36, 8))
    sns.heatmap(
        token_attention,
        cmap="viridis",
        xticklabels=50,
        cbar_kws={'label': 'Attention Weight'},
        ax=ax,
    )
    # seaborn draws cell i over [i, i + 1], so cell centres sit at i + 0.5.
    ax.set_yticks(row_offsets + 0.5)
    ax.set_yticklabels(y_tick_labels)

    for note_idx in note_list or ():
        if start_range <= note_idx < end_range:
            # Left edge of column `note_idx` in heatmap coordinates. The
            # previous `note_idx - 0.5` is the imshow convention and cut
            # through the middle of the preceding column instead.
            ax.axvline(x=note_idx, color='red', linestyle='--', linewidth=1.5, alpha=0.7)

    ax.set_title(f"Attention Heatmap (Layer {layer_id})")
    ax.set_xlabel("Target Token Index")
    ax.set_ylabel("Source Token Index")

    fig.tight_layout()
    return fig


# Row window [st, ed) and column window [start_range, end_range) passed to
# visualize_token_attention.
st = 470
ed = 533
start_range = 11
end_range = 533

# Alternative windows used previously:
# st = 241;  ed = 377;  start_range = 11;  end_range = 377
# st = 1384; ed = 1524; end_range = 1384

# Token indices marked with a vertical line in every heatmap.
note_list = [14, 240, 378, 469]

# Create the output directory up front; saving into a missing directory
# raises FileNotFoundError.
output_dir = "./vis_ours"
os.makedirs(output_dir, exist_ok=True)

# One heatmap per returned layer (instead of a hard-coded layer count).
attentions = output["attentions"]
for layer_id in range(len(attentions)):
    fig = visualize_token_attention(attentions, st, ed, start_range, end_range, layer_id, note_list)
    fig.savefig(f"{output_dir}/token_{st}_{ed}_attention_layer_{layer_id}.png")
    # Close each figure explicitly so figures do not accumulate in memory.
    plt.close(fig)

