###############################
"This code file is used to draw SVSM."
###############################

import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer
from matplotlib import cm
import os
# Expose only GPUs 0 and 1 to this process.
# NOTE(review): this runs after `import torch`; it presumably still takes effect
# because CUDA is not initialized until first use -- confirm on the target setup.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

def return_div_S(model, reasoning_model, layer_idx:int, type_layer:str='q') -> np.ndarray:
    """Return the element-wise ratio of singular values of one projection weight.

    The same projection matrix is picked from decoder layer ``layer_idx`` of
    both models, the singular values of each matrix are computed in float32,
    and the spectrum of ``reasoning_model`` is divided by that of ``model``.

    Args:
        model: Model exposing ``model.layers[i].self_attn`` and
            ``model.layers[i].mlp`` projection modules; its singular values
            are the denominator.
        reasoning_model: Model with the same layout; its singular values are
            the numerator.
        layer_idx: Index into ``model.layers`` of both models.
        type_layer: One of ``'q'``, ``'k'``, ``'v'``, ``'o'`` (attention
            projections) or ``'gate'``, ``'up'``, ``'down'`` (MLP projections).

    Returns:
        1-D NumPy array ``S_reasoning / S_non``, in the order the singular
        values are returned by ``torch.linalg.svdvals``.

    Raises:
        ValueError: If ``type_layer`` is not a supported key, or if the two
            weights yield a different number of singular values.

    Note:
        Zero singular values in ``model`` are not guarded against; the result
        of such a division follows NumPy's floating-point rules.
    """
    # (parent sub-module, projection attribute) for every supported key.
    projection_paths = {
        'q': ('self_attn', 'q_proj'),
        'k': ('self_attn', 'k_proj'),
        'v': ('self_attn', 'v_proj'),
        'o': ('self_attn', 'o_proj'),
        'gate': ('mlp', 'gate_proj'),
        'up': ('mlp', 'up_proj'),
        'down': ('mlp', 'down_proj'),
    }
    if type_layer not in projection_paths:
        raise ValueError(
            f"Unsupported type_layer {type_layer!r}; "
            f"expected one of {sorted(projection_paths)}"
        )
    parent_name, proj_name = projection_paths[type_layer]

    def _float_weight(net):
        # Detach before casting so the dtype conversion is never recorded by autograd.
        block = getattr(net.model.layers[layer_idx], parent_name)
        return getattr(block, proj_name).weight.detach().to(torch.float32)

    # Only the singular values are needed, so skip building the U / V factors.
    S_non = torch.linalg.svdvals(_float_weight(model)).cpu().numpy()
    S_reasoning = torch.linalg.svdvals(_float_weight(reasoning_model)).cpu().numpy()

    if S_non.shape != S_reasoning.shape:
        raise ValueError(
            f"Singular value counts differ for layer {layer_idx} ({type_layer}): "
            f"{S_reasoning.shape} vs {S_non.shape}"
        )

    div_S = S_reasoning / S_non
    return div_S


def draw_3d_heatmap(data_list, type_layer: str, clip_max:float=2.2, clip_min:float=0.2,
                    save_dir:str="./Draw_picture"):
    """Render per-layer ratio vectors as a 3D surface and save it as a PNG.

    Each entry of ``data_list`` becomes one row of the surface (y axis, in
    list order); the position inside the entry is the x axis and the clipped
    value is the height and colour.

    Args:
        data_list: Sequence of equal-length 1-D arrays, one per layer.
        type_layer: Projection key used in the title, the output folder name
            and the file name.
        clip_max: Upper bound applied to the data, the colour scale and the z axis.
        clip_min: Lower bound applied to the data, the colour scale and the z axis.
        save_dir: Root output directory. The image is written to
            ``<save_dir>/3d_heatmap_<type_layer>/svd__<type_layer>_3d_heatmap.png``.

    Returns:
        None. The figure is written to disk and a completion line is printed.
    """
    data_matrix = np.clip(np.vstack(data_list), a_min=clip_min, a_max=clip_max)
    num_layers, num_features = data_matrix.shape
    # X runs along the position inside each vector, Y along the layer index.
    X, Y = np.meshgrid(np.arange(num_features), np.arange(num_layers))

    fig = plt.figure(figsize=(15, 10))
    try:
        ax = fig.add_subplot(111, projection='3d')

        surf = ax.plot_surface(
            X, Y, data_matrix,
            cmap=cm.jet,
            edgecolor='none',
            rstride=1,
            cstride=1,
            vmin=clip_min,
            vmax=clip_max,
        )

        fig.colorbar(surf, ax=ax, shrink=0.5, aspect=10)
        ax.set_xlabel('Singular Value Multiples', fontsize=16, labelpad=10)
        ax.set_ylabel('Layer ID', fontsize=16, labelpad=10)
        ax.set_zlim(clip_min, clip_max)
        ax.tick_params(axis='both', labelsize=14)
        ax.set_title(f"3D Heatmap of layers.{type_layer}_proj", fontsize=16)

        ax.view_init(elev=30, azim=45 + 180)

        out_dir = os.path.join(save_dir, f"3d_heatmap_{type_layer}")
        # exist_ok already covers the "directory is present" case, no pre-check needed.
        os.makedirs(out_dir, exist_ok=True)
        save_path = os.path.join(out_dir, f"svd__{type_layer}_3d_heatmap.png")
        fig.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
    finally:
        # Release this specific figure even when plotting or saving fails.
        plt.close(fig)
    print(f"layers.{type_layer}_proj 3D heatmap finished!")

# Keyword arguments shared by both checkpoint loads.
load_options = {"torch_dtype": "auto", "device_map": "auto"}

# Checkpoint whose singular values form the numerator of the ratio.
# Other names previously listed here:
#   deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
#   deepseek-ai/DeepSeek-R1-Distill-Qwen-7B
#   deepseek-ai/DeepSeek-R1-Distill-Qwen-14B
#   deepseek-ai/DeepSeek-R1-Distill-Llama-8B
reasoning_model_name = "meta-llama/Llama-3.1-8B-Instruct"
reasoning_model = AutoModelForCausalLM.from_pretrained(reasoning_model_name, **load_options)
reasoning_model.eval()

# Checkpoint whose singular values form the denominator of the ratio.
# Other names previously listed here:
#   Qwen/Qwen2.5-Math-1.5B-Instruct
#   Qwen/Qwen2.5-Math-7B
#   Qwen/Qwen2.5-Math-14B
#   meta-llama/Llama-3.1-8B
model_name = "meta-llama/Llama-3.1-8B"
model = AutoModelForCausalLM.from_pretrained(model_name, **load_options)
model.eval()

model_length = len(model.model.layers)

# One accumulator per projection type; each receives one ratio vector per layer.
list_o, list_q, list_k, list_v = [], [], [], []
list_down, list_up, list_gate = [], [], []

# Projection keys in the order they are evaluated within a layer,
# paired with the list that collects their results.
per_layer_targets = (
    ("gate", list_gate),
    ("up", list_up),
    ("down", list_down),
    ("q", list_q),
    ("k", list_k),
    ("v", list_v),
    ("o", list_o),
)

for idx in range(model_length):
    for proj_key, collected in per_layer_targets:
        collected.append(return_div_S(model, reasoning_model, idx, proj_key))
    print(f"{idx} finished.")

# Plots are produced in this order, one image per projection type.
plot_queue = (
    (list_o, "o"),
    (list_q, "q"),
    (list_k, "k"),
    (list_v, "v"),
    (list_up, "up"),
    (list_gate, "gate"),
    (list_down, "down"),
)
for collected, proj_key in plot_queue:
    draw_3d_heatmap(collected, proj_key)