################
"This script plots, layer by layer, a normalized Frobenius norm that measures how closely the left and right singular vectors of two models' weight matrices align."
################
import matplotlib.pyplot as plt
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

# Normalized Frobenius (NF) scores, one list per projection type; the main
# loop appends one value per decoder layer to each list.
q_proj, k_proj, v_proj, o_proj = [], [], [], []
up_proj, down_proj, gate_proj = [], [], []

# Calculate the NF (normalized Frobenius norm).
def draw_compare(model, reasoning_model, x_labels, y_labels, layer_idx: int, type_layer: str = 'q'):
    """Return a normalized Frobenius score comparing two models' singular vectors.

    The projection selected by ``type_layer`` in decoder layer ``layer_idx`` is
    read from both models and decomposed with a reduced SVD
    (``W = U diag(S) V^T``).  With ``A = U_reasoning^T U`` and
    ``B = V_reasoning^T V`` the score is ``||abs(A^T B) - I||_F / sqrt(r * c)``,
    where ``(r, c)`` is the shape of ``A^T B``.  Identical singular bases give a
    score of 0 (up to floating-point error).

    Args:
        model: Object exposing ``model.layers[i].self_attn.{q,k,v,o}_proj`` and
            ``model.layers[i].mlp.{gate,up,down}_proj``, each with a ``weight``
            tensor.
        reasoning_model: Second model with the same layout.
        x_labels: Unused; kept so existing callers keep working.
        y_labels: Unused; kept so existing callers keep working.
        layer_idx: Index into ``model.layers``.
        type_layer: One of ``'q'``, ``'k'``, ``'v'``, ``'o'``, ``'gate'``,
            ``'up'`` or ``'down'``.

    Returns:
        float: The normalized Frobenius score described above.

    Raises:
        ValueError: If ``type_layer`` is not one of the supported keys.
    """
    # Map each key to (sub-block attribute, projection attribute).
    locations = {
        'q': ('self_attn', 'q_proj'),
        'k': ('self_attn', 'k_proj'),
        'v': ('self_attn', 'v_proj'),
        'o': ('self_attn', 'o_proj'),
        'gate': ('mlp', 'gate_proj'),
        'up': ('mlp', 'up_proj'),
        'down': ('mlp', 'down_proj'),
    }
    if type_layer not in locations:
        raise ValueError(
            f"Unknown type_layer {type_layer!r}; expected one of {sorted(locations)}"
        )
    block_name, proj_name = locations[type_layer]

    def _weight(net):
        # Detach first so the float32 cast is not recorded by autograd.
        sub_block = getattr(net.model.layers[layer_idx], block_name)
        return getattr(sub_block, proj_name).weight.detach().to(torch.float32)

    # torch.linalg.svd returns (U, S, Vh) with Vh = V^T.
    u_base, _, vh_base = torch.linalg.svd(_weight(model), full_matrices=False)
    u_reason, _, vh_reason = torch.linalg.svd(_weight(reasoning_model), full_matrices=False)

    # The two models may sit on different devices; align with the first model.
    u_reason = u_reason.to(u_base)
    vh_reason = vh_reason.to(vh_base)

    left_overlap = torch.matmul(u_reason.T, u_base)      # U_reasoning^T U
    right_overlap = torch.matmul(vh_reason, vh_base.T)   # V_reasoning^T V

    # A sign flip applied to a matching (u_i, v_i) pair cancels in this product.
    Q = torch.abs(torch.matmul(left_overlap.T, right_overlap))
    identity = torch.eye(Q.size(0)).to(Q)
    distance = torch.linalg.norm(Q - identity, ord='fro')
    return distance.item() / np.sqrt(Q.size(0) * Q.size(1))


# Caption strings handed to draw_compare, which does not currently use them.
# NOTE(review): these captions and the "reasoning" checkpoint chosen below do
# not appear to match each other -- confirm which comparison is intended.
x_labels = "Reasoning model(Qwen2.5-Math-7B)"
y_labels = "base model(Qwen2.5-Math-7B)"

# Both checkpoints are loaded with the same options.
_load_kwargs = {"torch_dtype": "auto", "device_map": "auto"}

# Other checkpoints that can be used for the first model:
#   Qwen/Qwen2.5-Math-1.5B
#   Qwen/Qwen2.5-Math-7B
#   Qwen/Qwen2.5-Math-14B
#   meta-llama/Llama-3.1-8B
model_name = "Qwen/Qwen2.5-Math-7B"
model = AutoModelForCausalLM.from_pretrained(model_name, **_load_kwargs)
model.eval()

# Other checkpoints that can be used for the second model:
#   deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
#   deepseek-ai/DeepSeek-R1-Distill-Qwen-7B
#   deepseek-ai/DeepSeek-R1-Distill-Qwen-14B
#   deepseek-ai/DeepSeek-R1-Distill-Llama-8B
reasoning_model_name = "Qwen/Qwen2.5-7B"
reasoning_model = AutoModelForCausalLM.from_pretrained(reasoning_model_name, **_load_kwargs)
reasoning_model.eval()

# Pair each result list with the key draw_compare expects; the order matches
# the sequence in which the projections are evaluated for every layer.
_score_targets = (
    (q_proj, 'q'),
    (k_proj, 'k'),
    (v_proj, 'v'),
    (o_proj, 'o'),
    (gate_proj, 'gate'),
    (up_proj, 'up'),
    (down_proj, 'down'),
)

model_length = len(model.model.layers)
for idx in range(model_length):
    for scores, kind in _score_targets:
        scores.append(draw_compare(model, reasoning_model, x_labels, y_labels, idx, kind))
    # Progress marker: one line per finished layer.
    print(f"-------{idx}-------")


# Plot one curve per projection type; layers are numbered from 1 on the x-axis.
layer_ids = list(range(1, len(q_proj) + 1))
plt.figure(figsize=(8, 8))
# Tuple order fixes the legend order.
for curve_label, curve_scores in (
    ("q_proj", q_proj),
    ("k_proj", k_proj),
    ("v_proj", v_proj),
    ("o_proj", o_proj),
    ("up_proj", up_proj),
    ("down_proj", down_proj),
    ("gate_proj", gate_proj),
):
    plt.plot(layer_ids, curve_scores, label=curve_label, marker='o', linewidth=5, markersize=10)

plt.title("Base model VS Base model", fontsize=25)
plt.xlabel("Layer ID", fontsize=25)
plt.ylabel("Frobenius Norm / n", fontsize=25)
plt.ylim(ymin=0, ymax=0.1)
# Anchor the legend outside the axes so it does not cover the curves.
plt.legend(loc=2, fontsize=25, bbox_to_anchor=(1.05, 1))
plt.tick_params(axis='both', labelsize=25)
plt.grid(True, linestyle='--', alpha=0.6)

# savefig does not create missing directories; make sure the target exists so
# the run does not fail after all the SVDs have already been computed.
output_dir = "./new_orth_map"
os.makedirs(output_dir, exist_ok=True)
save_path = os.path.join(output_dir, "orth_map_BaseVSBase_7B.png")
plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
print("finished!")
plt.close()