###############################
"This code file is used to plot the similarity matrix, difference matrix and orthogonality matrix of singular vectors."
###############################

import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
# Restrict the CUDA devices visible to this process to indices 0 and 1.
# NOTE(review): this assignment runs after `import torch`; presumably that is
# fine because CUDA is only initialised on first use — confirm, or move it
# above the torch import.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

# Maps each supported ``type_layer`` key to (sub-module, projection) attribute
# names inside one decoder layer.
_PROJECTION_ATTRS = {
    'q': ('self_attn', 'q_proj'),
    'k': ('self_attn', 'k_proj'),
    'v': ('self_attn', 'v_proj'),
    'o': ('self_attn', 'o_proj'),
    'gate': ('mlp', 'gate_proj'),
    'up': ('mlp', 'up_proj'),
    'down': ('mlp', 'down_proj'),
}


def _save_heatmap(matrix, x_labels, y_labels, save_path):
    """Draw ``matrix`` as a viridis heatmap on a fixed [0, 1] colour scale and save it to ``save_path``."""
    fig, ax = plt.subplots(1, 1, figsize=(16, 16))
    try:
        im = ax.imshow(matrix, cmap='viridis', origin='lower', vmin=0.0, vmax=1.0)
        ax.set_xlabel(x_labels, fontsize=36, labelpad=10)
        ax.set_ylabel(y_labels, fontsize=36, labelpad=10)
        ax.tick_params(axis='both', labelsize=40)
        cbar = fig.colorbar(im, ax=ax, shrink=0.8)
        cbar.ax.tick_params(labelsize=40)
        fig.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
    finally:
        # Always release the figure, even if rendering/saving fails, so that
        # figures do not pile up across the many calls made per model.
        plt.close(fig)


def draw_compare(model, reasoning_model, x_labels, y_labels, layer_idx:int, type_layer:str='q',
                 *, top_k:int=25, save_root:str="./Draw_picture"):
    """Plot how the singular vectors of one projection differ between two models.

    The weight of the selected projection is taken from decoder layer
    ``layer_idx`` of both models, each weight is decomposed with a reduced SVD,
    and four heatmaps (top-left ``top_k`` x ``top_k`` corner only) are written
    to ``<save_root>/sim_<layer_idx>/``:

    * ``..._U.png``: ``|U_r^T U_m|``, similarity of left singular vectors.
    * ``..._V.png``: ``|V_r^T V_m|``, similarity of right singular vectors.
    * ``..._UV_Diff.png``: ``| |U_r^T U_m| - |V_r^T V_m| |``.
    * ``..._Orthogonality.png``: ``|(U_r^T U_m)^T (V_r^T V_m)|``.

    Here ``_m`` refers to ``model`` and ``_r`` to ``reasoning_model``. In every
    matrix the rows are indexed by ``reasoning_model`` singular vectors and the
    columns by ``model`` singular vectors for the two similarity plots, so the
    x axis of those plots corresponds to ``model`` and the y axis to
    ``reasoning_model``.

    Args:
        model: Model exposing ``model.layers[i].self_attn.{q,k,v,o}_proj`` and
            ``model.layers[i].mlp.{gate,up,down}_proj`` modules with a
            ``weight`` tensor.
        reasoning_model: Second model with the same layout; the two selected
            weights must have the same shape for the products to be defined.
        x_labels: Caption for the x axis of every heatmap.
        y_labels: Caption for the y axis of every heatmap.
        layer_idx: Index of the decoder layer to analyse.
        type_layer: One of ``'q'``, ``'k'``, ``'v'``, ``'o'``, ``'gate'``,
            ``'up'``, ``'down'``.
        top_k: Number of leading singular vectors shown per axis.
        save_root: Directory under which the ``sim_<layer_idx>`` folder is
            created.

    Raises:
        ValueError: If ``type_layer`` is not a supported key or ``top_k`` is
            smaller than 1.
    """
    # Validate first so a typo fails with a clear message instead of an
    # AttributeError on ``None`` further down.
    if type_layer not in _PROJECTION_ATTRS:
        raise ValueError(
            f"Unsupported type_layer {type_layer!r}; expected one of {sorted(_PROJECTION_ATTRS)}"
        )
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k!r}")

    block_name, proj_name = _PROJECTION_ATTRS[type_layer]
    tgt_layer = getattr(getattr(model.model.layers[layer_idx], block_name), proj_name)
    reasoning_tgt_layer = getattr(
        getattr(reasoning_model.model.layers[layer_idx], block_name), proj_name
    )

    # SVD is run in float32 regardless of the dtype the checkpoints were loaded in.
    weight_matrix = tgt_layer.weight.to(torch.float32).detach()
    reasoning_weight_matrix = reasoning_tgt_layer.weight.to(torch.float32).detach()

    # torch.linalg.svd returns (U, S, Vh); the singular values are not needed.
    U_base, _, Vh_base = torch.linalg.svd(weight_matrix, full_matrices=False)
    U_reasoning, _, Vh_reasoning = torch.linalg.svd(reasoning_weight_matrix, full_matrices=False)

    # The two weights may live on different devices; align the reasoning-side
    # factors with the base-side ones before multiplying.
    U_reasoning = U_reasoning.to(U_base)
    Vh_reasoning = Vh_reasoning.to(Vh_base)

    # Signed overlaps, computed once and reused for every plot.
    u_overlap = torch.matmul(U_reasoning.T, U_base)      # U_r^T U_m
    v_overlap = torch.matmul(Vh_reasoning, Vh_base.T)    # V_r^T V_m

    # Absolute values: the sign of an individual singular vector is arbitrary.
    UU_t = np.abs(u_overlap.cpu().numpy())
    VV_t = np.abs(v_overlap.cpu().numpy())
    Q = torch.abs(torch.matmul(u_overlap.T, v_overlap)).cpu().numpy()

    save_dir = os.path.join(save_root, f"sim_{layer_idx}")
    os.makedirs(save_dir, exist_ok=True)

    heatmaps = (
        ("U", UU_t),
        ("V", VV_t),
        ("UV_Diff", np.abs(UU_t - VV_t)),
        ("Orthogonality", Q),
    )
    for suffix, matrix in heatmaps:
        save_path = os.path.join(save_dir, f"svd_layer{layer_idx}_{type_layer}_{suffix}.png")
        _save_heatmap(matrix[:top_k, :top_k], x_labels, y_labels, save_path)

    print(f"layer{layer_idx}.{type_layer}_proj finished!")
    print("------------")


# Axis captions for the heatmaps. draw_compare() puts `x_labels` on the x axis,
# whose entries come from `model`, and `y_labels` on the y axis, whose entries
# come from `reasoning_model`, so each caption names the checkpoint that is
# actually loaded into the corresponding variable below.
x_labels = "Reasoning model(DeepSeek-R1-Distill-Qwen-1.5B)"
y_labels = "Base model(Qwen2.5-Math-1.5B)"

# NOTE(review): the variable names are inverted relative to their contents:
# `model` holds the distilled reasoning checkpoint and `reasoning_model` holds
# the base checkpoint. The candidate lists in the comments below are kept where
# the original author put them. Swapping the two assignments would change the
# orthogonality plot, so confirm the intent before renaming.

# Qwen/Qwen2.5-Math-1.5B
# Qwen/Qwen2.5-Math-7B
# Qwen/Qwen2.5-Math-14B
# meta-llama/Llama-3.1-8B
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"  # please confirm the actual model name
model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto")
model.eval()

# deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
# deepseek-ai/DeepSeek-R1-Distill-Qwen-7B
# deepseek-ai/DeepSeek-R1-Distill-Qwen-14B
# deepseek-ai/DeepSeek-R1-Distill-Llama-8B
reasoning_model_name = "Qwen/Qwen2.5-Math-1.5B"  # please confirm the actual model name
reasoning_model = AutoModelForCausalLM.from_pretrained(
        reasoning_model_name,
        torch_dtype="auto",
        device_map="auto")
reasoning_model.eval()

# Produce the four heatmaps for every projection of every decoder layer,
# in the same order as before: all projection kinds of layer 0, then layer 1, ...
model_length = len(model.model.layers)
for idx in range(model_length):
    for layer_type in ('q', 'k', 'v', 'o', 'gate', 'up', 'down'):
        draw_compare(model, reasoning_model, x_labels, y_labels, idx, layer_type)
