import math
import torch
import argparse
import os
import re
import matplotlib.pyplot as plt
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from qwen_vl_utils import process_vision_info
import numpy as np
import seaborn as sns

# Default checkpoint directory; the command line can override it with --model_id.
MODEL_ID = "/data/model/MM-Eureka-Qwen-7B/"

def load_model(model_id, device="cuda:0"):
    """Load the Qwen2.5-VL checkpoint and its processor for inference.

    Args:
        model_id: local directory or hub id of the checkpoint.
        device: value passed as ``device_map``; the whole model is placed there.

    Returns:
        (model, processor): the model in eval mode, in bfloat16, using eager
        attention (the fused SDPA/flash kernels do not return the attention
        weights that ``output_attentions=True`` asks for), and a processor that
        pads on the left.
    """
    processor = AutoProcessor.from_pretrained(
        model_id, trust_remote_code=True, padding_side='left', use_fast=True
    )
    # device_map already loads every weight onto `device`, so no extra .to() is needed.
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",
        device_map=device,
    ).eval()
    return model, processor

def clean_think_output(output_text):
    """Extract the reasoning part of a model response.

    Rules, in order:
    - If a ``<think>...</think>`` block exists, use the contents of the first one.
    - Otherwise, if only ``<think>`` is present (e.g. generation stopped at
      ``max_new_tokens`` before closing it), use everything after it.
    - Otherwise, if only ``</think>`` is present, use everything before it.
    - With no tags at all, use the whole text.

    In every case, anything from a ``**Answer:**`` marker onward is dropped and
    surrounding whitespace is stripped.
    """
    think_match = re.search(r"<think>(.*?)</think>", output_text, re.DOTALL)
    if think_match:
        think_content = think_match.group(1)
    else:
        # Unbalanced tags: drop the dangling tag (and text on the wrong side of it)
        # instead of leaking a literal "<think>" / the answer into the thinking.
        _, opened, after = output_text.partition("<think>")
        if opened:
            think_content = after
        else:
            think_content = output_text.partition("</think>")[0]
    think_content = think_content.strip()
    return re.split(r"\*\*Answer:\*\*", think_content, maxsplit=1)[0].strip()

def calculate_plt_size(attention_layer_num):
    """Pick a near-square subplot grid that fits ``attention_layer_num`` panels.

    The column count is the ceiling of the square root of the panel count. The
    row count is the smallest number of rows that holds every panel.

    Returns:
        (rows, cols)
    """
    n_cols = math.ceil(math.sqrt(attention_layer_num))
    return math.ceil(attention_layer_num / n_cols), n_cols

def _chat_messages(image_path, text):
    """Build a single user turn containing the image followed by ``text``."""
    return [
        {"role": "user", "content": [
            {"type": "image", "image": image_path},
            {"type": "text", "text": text}
        ]}
    ]


def _build_inputs(processor, image_path, text, image_inputs, video_inputs, device):
    """Apply the chat template to an (image, text) user turn and tokenize it onto ``device``."""
    prompt = processor.apply_chat_template(
        _chat_messages(image_path, text), tokenize=False, add_generation_prompt=True
    )
    return processor(
        text=[prompt],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(device)


def _image_token_span(input_ids, processor):
    """Return (start, end) such that ``input_ids[start:end]`` are the image placeholder tokens.

    Raises ValueError (from ``list.index``) if the prompt has no vision markers.
    """
    tokenizer = processor.tokenizer
    start_id = tokenizer.convert_tokens_to_ids('<|vision_start|>')
    end_id = tokenizer.convert_tokens_to_ids('<|vision_end|>')
    return input_ids.index(start_id) + 1, input_ids.index(end_id)


def _generate_image_attention(model, processor, inputs, max_new_tokens):
    """Sample up to ``max_new_tokens`` tokens and summarise their attention onto the image.

    Returns:
        (text, num_new_tokens, layer_sum, num_steps):
        - text: the decoded continuation, with special tokens skipped.
        - num_new_tokens: the number of generated token ids.
        - layer_sum: a ``(num_layers, num_image_tokens)`` float32 tensor. Per layer,
          it holds the sum over generation steps of the head-averaged attention
          that the last query position pays to each image token. It is ``None``
          if no step was produced.
        - num_steps: the number of steps summed into ``layer_sum``.
    """
    start, end = _image_token_span(inputs["input_ids"][0].tolist(), processor)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            output_attentions=True,
            return_dict_in_generate=True,
        )
        prompt_len = inputs["input_ids"].shape[1]
        new_ids = out.sequences[0, prompt_len:]
        text = processor.batch_decode([new_ids], skip_special_tokens=True)[0]

        # Reduce each step immediately. Only a small (layers, image_tokens) running
        # sum survives this call, so the full per-step attention tensors are freed
        # before the next generate() instead of piling up across all calls.
        layer_sum = None
        for step in out.attentions:
            step_att = torch.stack(
                [layer[0, :, -1, start:end].mean(dim=0) for layer in step]
            ).to(torch.float32)
            layer_sum = step_att if layer_sum is None else layer_sum + step_att
        num_steps = len(out.attentions)
    del out
    return text, len(new_ids), layer_sum, num_steps


def _plot_layer_heatmaps(mean_att, grid_width, path):
    """Save one heatmap per layer of ``mean_att`` (layers x image tokens) to ``path``.

    Each layer's row is reshaped to ``(-1, grid_width)``, i.e. into the merged
    image-patch grid.
    """
    num_layers = mean_att.shape[0]
    rows, cols = calculate_plt_size(num_layers)
    # squeeze=False keeps `axes` 2-D even for a 1x1 grid, so .flat always works.
    fig, axes = plt.subplots(rows, cols, figsize=(10.8, 16), squeeze=False)
    for i, ax in enumerate(axes.flat):
        if i < num_layers:
            ax.imshow(mean_att[i].reshape(-1, grid_width), cmap="viridis", interpolation="nearest")
            ax.set_title(f"The {i+1} layer", fontsize=10, pad=5)
        ax.axis("off")
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    fig.savefig(path, dpi=300, bbox_inches='tight')
    plt.close(fig)


def get_response(model, processor, image_path, question,
                 max_tokens_thinking=40000, enforce_num=8, ignore_str="Wait,",
                 final_answer_tokens=1024, heatmap_path='heatmap_4.png'):
    """Answer ``question`` about an image with budget-forced thinking, and plot image attention.

    The pipeline runs in three phases:
    1. The model generates freely; the reasoning part is kept.
    2. Up to ``enforce_num`` times, while the thinking budget lasts,
       ``ignore_str`` is appended to the accumulated thinking. The question plus
       that thinking is then sent again as a new user turn and the cleaned
       output is appended.
    3. The model is asked for only the final answer, given all of the thinking.

    The attention of every generated step onto the image tokens is averaged
    over heads and over all steps of all calls. It is saved to ``heatmap_path``
    as one heatmap per layer.

    Args:
        model: Qwen2.5-VL model loaded with eager attention (needed for attention weights).
        processor: the matching processor.
        image_path: image reference understood by ``process_vision_info``.
        question: question text.
        max_tokens_thinking: new-token budget shared by all thinking calls.
        enforce_num: maximum number of extra thinking rounds.
        ignore_str: text inserted before each extra thinking round.
        final_answer_tokens: new-token limit for the final-answer call.
        heatmap_path: output file for the heatmap figure.

    Returns:
        The stripped, decoded final answer.
    """
    image_inputs, video_inputs = process_vision_info(_chat_messages(image_path, f"{question}"))
    att_sums = []
    att_steps = 0

    def _run(prompt_text, max_new_tokens):
        # One sampled generate() call; its image attention is folded into att_sums/att_steps.
        nonlocal att_steps
        call_inputs = _build_inputs(
            processor, image_path, prompt_text, image_inputs, video_inputs, model.device
        )
        text, num_new, layer_sum, num_steps = _generate_image_attention(
            model, processor, call_inputs, max_new_tokens
        )
        if layer_sum is not None:
            att_sums.append(layer_sum)
            att_steps += num_steps
        return call_inputs, text, num_new

    inputs, output_full, num_new = _run(f"{question}", max_tokens_thinking)

    # image_grid_thw is (t, h, w) in patches; tokens are merged merge_size x merge_size,
    # so one row of the token grid holds w // merge_size tokens.
    merge_size = getattr(processor.image_processor, "merge_size", 2)
    grid_width = int(inputs["image_grid_thw"][0, -1]) // merge_size

    think_content = "<think>\n" + clean_think_output(output_full) + "\n"
    remaining_tokens = max_tokens_thinking - num_new

    for _ in range(enforce_num):
        if remaining_tokens <= 0:
            break
        think_content += f"\n{ignore_str}\n"
        _, additional_think, num_new = _run(f"{question}\n\n{think_content}", remaining_tokens)
        think_content += "\n" + clean_think_output(additional_think.strip())
        remaining_tokens -= num_new

    think_content += "\n</think>\n\n"
    final_prompt = f"{question}\n{think_content}"

    _, final_answer, _ = _run(f"{final_prompt} Only provide the final answer.", final_answer_tokens)
    final_output = final_answer.strip()
    torch.cuda.empty_cache()

    if att_steps:
        mean_att = (torch.stack(att_sums).sum(dim=0) / att_steps).cpu().numpy()
        _plot_layer_heatmaps(mean_att, grid_width, heatmap_path)

    print(final_prompt)
    print(final_output)

    return final_output

def process_single_sample(model, processor, image_path, question):
    """Run the think-then-answer pipeline on one (image, question) pair and return its answer."""
    return get_response(model, processor, image_path, question)

if __name__ == "__main__":
    # Command-line entry: one image, one question, optional checkpoint override.
    cli = argparse.ArgumentParser()
    cli_options = (
        ('--image_path', {'type': str, 'required': True}),
        ('--question', {'type': str, 'required': True}),
        ('--model_id', {'type': str, 'default': MODEL_ID}),
    )
    for flag, spec in cli_options:
        cli.add_argument(flag, **spec)
    cli_args = cli.parse_args()

    loaded_model, loaded_processor = load_model(cli_args.model_id)
    process_single_sample(loaded_model, loaded_processor, cli_args.image_path, cli_args.question)