import math
import torch
import matplotlib.pyplot as plt
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import numpy as np
import seaborn as sns


# Checkpoint location and the device the processed inputs are moved to.
model_path = '/data/R1-Onevision-7B-RL/'
device = 'cuda'

# Load the vision-language model in bfloat16 and let `device_map='auto'`
# place the weights.  Eager attention is requested explicitly --
# presumably so that attention weights can be returned at generation time.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_path,
    device_map='auto',
    attn_implementation="eager",
    torch_dtype=torch.bfloat16,
)
model = model.eval()  # inference mode

# Processor (tokenizer + image preprocessing) loaded from the same checkpoint.
processor = AutoProcessor.from_pretrained(
    model_path,
    use_fast=True,
    padding_side='left',
    trust_remote_code=True,
)


# The single example being analysed: one image plus one free-form question.
image_path = '/data/example.png'
question = 'Describe this image. What is the name of the food shown in the image?'

# One user turn whose content is the image followed by the question text.
messages_query = [{
    "role": "user",
    "content": [
        {"type": "image", "image": image_path},
        {"type": "text", "text": f"{question}"},
    ],
}]

# Resolve the image payload and render the chat template to a prompt string
# (with the generation prompt appended), then build the model inputs.
image_inputs, _ = process_vision_info(messages_query)
text_query = processor.apply_chat_template(
    messages_query,
    tokenize=False,
    add_generation_prompt=True,
)
inputs = processor(text=[text_query], images=image_inputs, padding=True, return_tensors="pt")
inputs = inputs.to(device)

# Prompt of the only batch element, both as ids and as token strings.
input_ids = inputs['input_ids'][0].tolist()
input_tokens = processor.tokenizer.convert_ids_to_tokens(input_ids)

# The image placeholder tokens sit strictly between the first
# <|vision_start|> and the first <|vision_end|> marker.
vision_start_token_id = processor.tokenizer.convert_tokens_to_ids('<|vision_start|>')
vision_end_token_id = processor.tokenizer.convert_tokens_to_ids('<|vision_end|>')
pos = 1 + input_ids.index(vision_start_token_id)
pos_end = input_ids.index(vision_end_token_id)

# Everything after <|vision_end|> up to the end of the prompt.
question_start = pos_end + 1
question_end = inputs['input_ids'].shape[1]

image_indices = list(range(pos, pos_end))

# Text span following the image; this also covers the chat-template tokens
# that come after <|vision_end|>, not only the question itself.
question_indices_template = list(range(question_start, question_end))

# The bare question tokenised on its own, used to locate it inside the prompt.
question_tokens = processor.tokenizer.tokenize(question)

def find_subsequence_indices(full_list, sub_list):
    """Locate the first contiguous occurrence of ``sub_list`` in ``full_list``.

    Args:
        full_list: Sequence to search in.
        sub_list: Sequence to look for; compared against slices of
            ``full_list`` with ``==``.

    Returns:
        The list of consecutive positions in ``full_list`` covered by the
        first match, or ``[]`` when there is no match (an empty ``sub_list``
        also yields ``[]``).
    """
    span = len(sub_list)
    matching_starts = (
        start
        for start in range(len(full_list) - span + 1)
        if full_list[start:start + span] == sub_list
    )
    first = next(matching_starts, None)
    if first is None:
        return []
    return list(range(first, first + span))

# Positions of the bare question inside the prompt ([] when not found).
question_indices = find_subsequence_indices(input_tokens, question_tokens)

# Columns of interest: all image tokens followed by everything after the
# image (the question together with the surrounding template tokens).
selected_indices = [*image_indices, *question_indices_template]
selected_tokens = list(map(input_tokens.__getitem__, selected_indices))

# Debug dump of the token groups.
print("==================")
print("All Input token:", input_tokens)
print("Image token:", list(map(input_tokens.__getitem__, image_indices)))
print("Question token:", list(map(input_tokens.__getitem__, question_indices)))
print("MCQ token:", selected_tokens)

# Generate an answer while recording attention weights, then plot, per layer,
# how much the generated tokens attend to the selected prompt tokens.
with torch.no_grad():
    generated_output = model.generate(
        **inputs,
        max_new_tokens=4096,
        output_attentions=True,
        return_dict_in_generate=True
    )

generated_ids = generated_output.sequences
attentions = generated_output.attentions
# NOTE: for decoder-only models `sequences` is expected to hold the prompt ids
# followed by the new ids, so the decoded text below includes the prompt.
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print("======================")
print(f"模型回答: {generated_text}")

# `attentions` is indexed as attentions[generation_step][layer]; each entry is
# indexed below as (batch, head, query position, key position).
num_layers = len(attentions[0])
num_generated_tokens = len(attentions)
input_len = inputs['input_ids'].shape[1]

attention_matrix = np.zeros((num_layers, len(selected_indices)))

for layer_idx in range(num_layers):
    # For every generation step take the attention of the newest query
    # position over the prompt keys and average over heads.  The cast to
    # float32 happens *before* averaging so that neither the head mean nor
    # the step mean is rounded to bfloat16.
    per_step = [
        attentions[t][layer_idx][0, :, -1, :input_len].to(torch.float32).mean(dim=0)
        for t in range(num_generated_tokens)
    ]
    att = torch.stack(per_step).mean(dim=0).detach().cpu().numpy()

    # Min-max normalise within the layer.  A constant row has zero spread;
    # it is left at 0 instead of being turned into NaN by a 0/0 division.
    att_selected = att[selected_indices]
    low = att_selected.min()
    spread = att_selected.max() - low
    if spread > 0:
        attention_matrix[layer_idx] = (att_selected - low) / spread

plt.figure(figsize=(20, 10))

# Number of image tokens shown at each end of the image span.
edge_keep = 50

if len(image_indices) > 10:
    # Show the first and last `edge_keep` image tokens followed by the
    # question tokens.  The tail is sliced from what remains after the head,
    # so a short image span never produces duplicated columns.
    head_indices = image_indices[:edge_keep]
    tail_indices = image_indices[len(head_indices):][-edge_keep:]
    omitted = len(image_indices) - len(head_indices) - len(tail_indices)

    # Prompt position -> column of `attention_matrix` (one dict lookup per
    # column instead of a linear `list.index` scan).
    column_of = {token_idx: col for col, token_idx in enumerate(selected_indices)}

    head_columns = np.asarray([column_of[i] for i in head_indices], dtype=int)
    pieces = [attention_matrix[:, head_columns]]
    trimmed_tokens = [input_tokens[i] for i in head_indices]

    if omitted > 0:
        # The '...' tick label needs a column of its own, otherwise every
        # label to its right is shifted by one.  A NaN column is used so the
        # gap is drawn as an empty (masked) stripe by seaborn.
        pieces.append(np.full((num_layers, 1), np.nan))
        trimmed_tokens.append('...')

    rest_indices = tail_indices + question_indices
    rest_columns = np.asarray([column_of[i] for i in rest_indices], dtype=int)
    pieces.append(attention_matrix[:, rest_columns])
    trimmed_tokens += [input_tokens[i] for i in rest_indices]

    attention_matrix_trimmed = np.hstack(pieces)
else:
    trimmed_tokens = selected_tokens
    attention_matrix_trimmed = attention_matrix

sns.heatmap(
    attention_matrix_trimmed,
    cmap='viridis',
    xticklabels=trimmed_tokens,
    yticklabels=[f"Layer {i + 1}" for i in range(num_layers)],
    cbar_kws={'label': 'Attention Score'}
)
plt.title("Attention Scores for Generated Tokens on Image and Question Tokens")
plt.xlabel("Input Tokens (Image + Question)")
plt.ylabel("Layers")
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.savefig('attention_heatmap_trimmed.png', dpi=300, bbox_inches='tight')
plt.close()
