################
"This code file is used to plot the attention entropy of the model under a given input."
################
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import copy
from tqdm import *
# Restrict this process to the GPUs with indices 0 and 1.
# NOTE(review): this assignment runs after `import torch`; it presumably still
# takes effect because CUDA is only initialized on first use — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

def _build_chat_texts(tokenizer, questions, answer=None):
    """Render every question into a chat-template string.

    Each conversation consists of the fixed system instruction and the user
    question; when ``answer`` is not ``None`` it is appended as an assistant
    turn. The strings are returned untokenized, in the order of ``questions``.
    """
    texts = []
    for question in questions:
        messages = [
            {"role": "system", "content": r"Please put your final answer within \boxed{}."},
            {"role": "user", "content": question},
        ]
        if answer is not None:
            messages.append({"role": "assistant", "content": answer})
        # NOTE(review): the generation prompt is also appended after an
        # assistant turn, exactly as the original code did. Whether that extra
        # header is wanted in the analysed sequence should be confirmed.
        texts.append(tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True))
    return texts


def draw_attention_map(layers:list[int], prompt:list[str], model, tokenizer, save_dir:str = "./attn_map", model_type:str="before"):
    """Collect head-averaged attention maps for ``prompt`` plus a one-token answer.

    The function greedily generates a single new token for the prompt, decodes
    it, re-renders the conversation with that text as the assistant turn, runs
    one forward pass over the result and extracts the attention of the
    requested layers.

    Args:
        layers: Indices into ``outputs.attentions`` selecting which attention
            tensors to return.
        prompt: User questions. Only the first sequence contributes to the
            result: the first decoded answer is reused for every question and
            only batch element 0 of each attention tensor is returned.
            NOTE(review): the tokenizer is called without padding, so several
            prompts of different token length presumably cannot be batched.
        model: Causal LM exposing ``device``, ``generate`` and a forward call
            whose output has an ``attentions`` attribute.
        tokenizer: Tokenizer exposing ``apply_chat_template``, ``__call__``
            and ``batch_decode``.
        save_dir: Unused here; kept so existing callers keep working.
        model_type: Unused here; kept so existing callers keep working.

    Returns:
        ``(maps, titles)``. ``maps`` holds one float32 NumPy array per entry of
        ``layers`` (the mean over dim 1 of the attention tensor, batch element
        0; a ``(seq, seq)`` matrix assuming the usual
        ``(batch, heads, seq, seq)`` layout). ``titles`` holds the matching
        ``"attn_map of block <layer>"`` strings.
    """
    # ---- Step 1: greedy-decode exactly one new token for the prompt.
    model_inputs = tokenizer(_build_chat_texts(tokenizer, prompt), return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=1,
        do_sample=False,
        num_beams=1
    )
    # Drop the echoed prompt tokens so only the newly generated part is decoded.
    new_token_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    answer = tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)[0]
    print(f"answer:{answer[:100]}")

    # ---- Step 2: re-tokenize the conversation including the generated answer.
    new_model_inputs = tokenizer(
        _build_chat_texts(tokenizer, prompt, answer=answer), return_tensors="pt"
    ).to(model.device)

    print("model.device: ", model.device)

    input_token_length = len(model_inputs["input_ids"][0])
    output_token_length = len(new_model_inputs["input_ids"][0]) - input_token_length
    print(f"input_token_length: {input_token_length}")
    print(f"output_token_length: {output_token_length}")

    # ---- Step 3: one forward pass; request attentions explicitly so the
    # result does not depend on how the model's config was set up at load time.
    with torch.no_grad():
        outputs = model(**new_model_inputs, output_attentions=True)

    wait_for_drawing = []
    subtitles = []
    for layer in layers:
        attention_avg = outputs.attentions[layer].mean(dim=1)  # average over dim 1 (heads)
        # Cast to float32 before leaving torch: NumPy has no bfloat16 dtype.
        wait_for_drawing.append(attention_avg[0].to(dtype=torch.float32).cpu().numpy())
        subtitles.append(f"attn_map of block {layer}")
    return wait_for_drawing, subtitles

def draw_attention(layers:list[int], wait_for_drawing, subtitles, save_dir:str = "./attn_map", model_type:str="before"):
    """Plot attention matrices side by side and save them as one PNG.

    Args:
        layers: Layer indices; only its length is used, to decide how many
            subplot columns to create.
        wait_for_drawing: Sequence of 2-D arrays, one per subplot. Only the
            top-left 20x20 corner of each matrix is drawn.
        subtitles: Titles, indexed in parallel with ``wait_for_drawing``.
        save_dir: Output directory. It is created if it does not exist.
        model_type: Tag used in the file name. The value ``"minus"`` switches
            the colour range from [0, 1] to [-0.4, 0.4].

    The figure is written to
    ``<save_dir>/attention_heatmap_<model_type>.png``; nothing is returned.
    """
    length = len(layers)

    # savefig() does not create missing directories, so make sure it exists.
    os.makedirs(save_dir, exist_ok=True)

    fig, axes = plt.subplots(1, length, figsize=(length*10, 8))
    if length == 1:
        # With a single column plt.subplots returns a bare Axes, not an array.
        axes = [axes]

    if model_type != "minus":
        vmin, vmax = 0, 1
    else:
        vmin, vmax = -0.4, 0.4

    im = None
    for i, (matrix, ax) in enumerate(zip(wait_for_drawing, axes)):
        processed = matrix[:20,:20]
        im = ax.imshow(processed, cmap='viridis', vmin=vmin, vmax=vmax)
        ax.set_title(subtitles[i], fontsize=25, pad=5)
        ax.tick_params(axis='both', labelsize=25)

    save_path = os.path.join(save_dir, f"attention_heatmap_{model_type}.png")
    fig.tight_layout()
    # One shared colour bar, taken from the last image drawn (all images use
    # the same vmin/vmax, so it is valid for every subplot).
    fig.colorbar(im, ax=axes)
    fig.savefig(save_path, bbox_inches='tight', dpi=300)
    plt.close(fig)
    print("finished!")

def _transplant_singular_values(tgt_linear, model_linear, reasoning_linear, weight_dtype, use_model_left_basis=True):
    """Overwrite ``tgt_linear.weight`` with ``U @ diag(S_b) @ Vh_a``.

    ``Vh_a`` comes from the SVD of ``model_linear.weight`` and ``S_b`` from
    the SVD of ``reasoning_linear.weight``. ``U`` is ``U_a`` (from
    ``model_linear``) when ``use_model_left_basis`` is true, else ``U_b``
    (from ``reasoning_linear``). Both SVDs run in float32 on the device of the
    target weight; the product is cast to ``weight_dtype``.
    """
    device = tgt_linear.weight.device
    weight_a = model_linear.weight.to(torch.float32).detach().to(device)
    weight_b = reasoning_linear.weight.to(torch.float32).detach().to(device)
    U_a, _, Vh_a = torch.linalg.svd(weight_a, full_matrices=False)
    U_b, S_b, _ = torch.linalg.svd(weight_b, full_matrices=False)
    left_basis = U_a if use_model_left_basis else U_b
    tgt_linear.weight = torch.nn.Parameter((left_basis @ torch.diag(S_b) @ Vh_a).to(weight_dtype))


def get_new_model(using_Q:bool=True, model_name:str="meta-llama/Llama-3.1-8B-Instruct", reasoning_model_name:str="meta-llama/Llama-3.1-8B"):
    """Load two checkpoints and build a third model with transplanted singular values.

    A deep copy of the ``reasoning_model_name`` checkpoint is modified layer
    by layer: every attention projection (o/q/k/v) and every MLP projection
    (gate/up/down) is replaced by ``U @ diag(S_b) @ Vh_a``, where ``a`` is the
    ``model_name`` checkpoint and ``b`` the ``reasoning_model_name``
    checkpoint.

    Args:
        using_Q: For the attention projections, use ``U_a`` as the left factor
            when true and ``U_b`` when false. The MLP projections always use
            ``U_a``.
        model_name: Checkpoint providing the singular vectors and the
            tokenizer. Other checkpoints noted by the original author:
            Qwen/Qwen2.5-Math-1.5B-Instruct, Qwen/Qwen2.5-Math-7B-Instruct,
            Qwen/Qwen2.5-14B-Instruct, meta-llama/Llama-3.1-8B,
            deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B,
            deepseek-ai/DeepSeek-R1-Distill-Llama-8B.
        reasoning_model_name: Checkpoint providing the singular values; it is
            also the one that gets deep-copied into the returned target model.

    Returns:
        ``(model, tgt_model, reasoning_model, tokenizer)``, all three models in
        eval mode. Note that three full sets of weights are held in memory.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
    model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="auto",
            device_map="auto",
            output_attentions=True)

    reasoning_model = AutoModelForCausalLM.from_pretrained(
            reasoning_model_name,
            torch_dtype="auto",
            device_map="auto",
            output_attentions=True)

    tgt_model = copy.deepcopy(reasoning_model)
    depth_model = len(tgt_model.model.layers)

    # Processing order inside a layer matches the original hand-unrolled code.
    attn_projections = ("o_proj", "q_proj", "k_proj", "v_proj")
    mlp_projections = ("gate_proj", "up_proj", "down_proj")

    for idx in tqdm(range(depth_model)):
        src_layer = model.model.layers[idx]
        ref_layer = reasoning_model.model.layers[idx]
        tgt_layer = tgt_model.model.layers[idx]

        # The dtype of o_proj is used for every rebuilt weight in this layer.
        type_weight = tgt_layer.self_attn.o_proj.weight.dtype

        # ##################### replacing the singular values of SA; ################
        for name in attn_projections:
            _transplant_singular_values(
                getattr(tgt_layer.self_attn, name),
                getattr(src_layer.self_attn, name),
                getattr(ref_layer.self_attn, name),
                type_weight,
                use_model_left_basis=using_Q)

        # ##################### replacing the singular values of FFN; ################
        for name in mlp_projections:
            _transplant_singular_values(
                getattr(tgt_layer.mlp, name),
                getattr(src_layer.mlp, name),
                getattr(ref_layer.mlp, name),
                type_weight,
                use_model_left_basis=True)

    model.eval()
    tgt_model.eval()
    reasoning_model.eval()

    return model, tgt_model, reasoning_model, tokenizer


def attention_entropy(attn_weights, eps=1e-12):
    """Return ``-sum(p * log(p))`` taken over every entry of ``attn_weights``.

    The reduction runs over all axes at once, so an input of any shape, e.g.
    (seq_len, seq_len), (num_heads, seq_len, seq_len) or
    (batch, num_heads, seq_len, seq_len), yields a single scalar: the total
    entropy, not a per-row or per-head average.

    Args:
        attn_weights: Array-like of attention probabilities.
        eps: Lower clipping bound applied before the logarithm so that zero
            entries do not produce ``log(0)``; values are also clipped to at
            most 1.0.

    Returns:
        The summed entropy in nats as a NumPy scalar.
    """
    probs = np.clip(attn_weights, eps, 1.0)
    neg_entropy = np.sum(probs * np.log(probs))
    return -neg_entropy

# ---- Script: compare per-layer attention entropy before/after the transplant.
model, tgt_model, reasoning_model, tokenizer = get_new_model()

# Candidate prompts used by the author (only the one assigned below is active):
# 01 "Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?"
# 02 "What size of cannula would you use in a patient who needed a rapid blood transfusion (as of 2020 medical knowledge)?"
# 03 "The sanctions against the school were a punishing blow, and they seemed to what the efforts the school had made to change?"
# 04 "Two quantum states with energies E1 and E2 have a lifetime of 10^-9 sec and 10^-8 sec, respectively. We want to clearly distinguish these two energy levels. Which one of the following options could be their energy difference so that they be clearly resolved?"
prompt = ["Two quantum states with energies E1 and E2 have a lifetime of 10^-9 sec and 10^-8 sec, respectively. We want to clearly distinguish these two energy levels. Which one of the following options could be their energy difference so that they be clearly resolved?"]
# Layer indices whose attention is analysed.
# draw_picture = [0, 15, 31, 47, 63]
draw_picture = [0,3,5,8,10,13,15,18,20,23,25]
# a: maps from the unmodified model, b: maps from the transplanted model.
a, s = draw_attention_map(draw_picture, prompt=prompt, model=model, tokenizer=tokenizer, model_type="before")
b, _ = draw_attention_map(draw_picture, prompt=prompt, model=tgt_model, tokenizer=tokenizer, model_type="after")
c = []  # not used in the visible part of the script; kept in case later code reads it

# One scalar entropy per analysed layer.
a_attn_en = [attention_entropy(i) for i in a]
b_attn_en = [attention_entropy(i) for i in b]
layer_ids = draw_picture

bar_width = 0.35
x = np.arange(len(layer_ids))
plt.figure(figsize=(20, 6))

# Grouped bars: "Before" to the left of each tick, "After" to the right.
bars1 = plt.bar(x - bar_width/2, a_attn_en, width=bar_width, label='Before', alpha=0.7)
bars2 = plt.bar(x + bar_width/2, b_attn_en, width=bar_width, label='After', alpha=0.7)

# Annotate each pair with the signed change (after - before) above the taller bar.
for x_pos, before, after in zip(x, a_attn_en, b_attn_en):
    plt.text(x_pos, max(before, after) + 0.1, f"{after - before:+.2f}",
             ha='center', va='bottom', fontsize=24)

plt.xticks(x, layer_ids)
plt.tick_params(axis='both', labelsize=24)

plt.xlabel('Layer ID', fontsize=24)
plt.ylabel('Attention Entropy', fontsize=24)

# Fixed y-range: values outside [40, 320] are not fully visible in the plot.
plt.ylim(40,320)
plt.legend(fontsize=24)
plt.tight_layout()
# savefig() does not create missing directories, so create the target first.
output_dir = "./attn_entropy_GPQA"
os.makedirs(output_dir, exist_ok=True)
save_path = os.path.join(output_dir, "attention_entropy_Instruct_8B.png")
plt.savefig(save_path, bbox_inches='tight', dpi=300)
plt.close()
