###############################
"""Compute and plot layer-by-layer linear CKA similarity heatmaps between two causal language models."""
###############################

from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import torch
import numpy as np
from sklearn.metrics.pairwise import linear_kernel
import matplotlib.pyplot as plt
from tqdm import *
import os
import copy
# Expose only GPUs 0 and 1 to this process.
# NOTE(review): this assignment runs after `import torch`; it should still take
# effect as long as CUDA has not been initialized by then — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

def cka(X, Y):
    """Compute linear Centered Kernel Alignment (CKA) between two feature sets.

    Both inputs are centered per feature (column), turned into sample-by-sample
    Gram matrices ``K = X Xᵀ`` and ``L = Y Yᵀ``, and compared with
    ``<K, L>_F / (||K||_F * ||L||_F)``.

    Args:
        X: 2-D tensor of shape ``(n_samples, dim_x)``; one row per example.
        Y: 2-D tensor of shape ``(n_samples, dim_y)``; rows must correspond to
            the rows of ``X`` and live on the same device.

    Returns:
        A 0-d tensor holding the CKA score (within ``[0, 1]`` up to rounding).
        The result is NaN when either input is constant across samples,
        because the normaliser is then zero.
    """
    # Half-precision inputs (float16 / bfloat16) lose accuracy, and float16 can
    # overflow, once activations are multiplied into Gram matrices. Compute in
    # at least float32; float64 inputs stay float64.
    compute_dtype = torch.promote_types(
        torch.promote_types(X.dtype, Y.dtype), torch.float32
    )
    X = X.to(compute_dtype)
    Y = Y.to(compute_dtype)

    X_centered = X - X.mean(dim=0, keepdim=True)
    Y_centered = Y - Y.mean(dim=0, keepdim=True)

    gram_x = X_centered @ X_centered.t()
    gram_y = Y_centered @ Y_centered.t()

    # Gram matrices are symmetric, so trace(K @ L) == sum(K * L) and
    # trace(K @ K) == ||K||_F ** 2. This avoids two O(n^3) matrix products.
    hsic = (gram_x * gram_y).sum()
    normaliser = torch.linalg.norm(gram_x) * torch.linalg.norm(gram_y)
    return hsic / normaliser

def get_hidden_states(model, tokenizer, texts, layer_idx=None):
    """Run one padded forward pass over ``texts`` and return hidden states.

    Args:
        model: Callable model exposing ``.device`` whose output has a
            ``hidden_states`` attribute (e.g. a Hugging Face causal LM).
        tokenizer: Tokenizer called as
            ``tokenizer(texts, return_tensors="pt", padding=True)``.
        texts: Batch of input strings; all of them are encoded and run
            together in a single forward pass.
        layer_idx: Optional index into the hidden-state sequence. When
            ``None`` the full sequence is returned.

    Returns:
        ``outputs.hidden_states`` (all entries) when ``layer_idx`` is ``None``,
        otherwise the single entry ``outputs.hidden_states[layer_idx]``.
    """
    inputs = tokenizer(texts, return_tensors="pt", padding=True).to(model.device)
    with torch.no_grad():
        # Ask for hidden states explicitly so the result does not depend on the
        # model having been loaded with `output_hidden_states=True`.
        outputs = model(**inputs, output_hidden_states=True)

    hidden_states = outputs.hidden_states
    if layer_idx is None:
        return hidden_states
    return hidden_states[layer_idx]

# Attention projections whose weights are recombined, in processing order.
_ATTENTION_PROJECTIONS = ("o_proj", "q_proj", "k_proj", "v_proj")


def _recombine_projection(base_weight, reasoning_weight, target_weight, using_Q):
    """Build one replacement weight from the SVDs of two projection weights.

    With ``W_a = U_a S_a Vh_a`` (base) and ``W_b = U_b S_b Vh_b`` (reasoning):

    * ``using_Q=True``:  ``U_b diag(S_b) U_bᵀ U_a Vh_a``
    * ``using_Q=False``: ``U_b diag(S_b) Vh_a``

    The SVDs are computed in float32 on ``target_weight``'s device; the result
    is cast back to ``target_weight``'s dtype and wrapped in a new Parameter.
    """
    device = target_weight.device
    w_a = base_weight.detach().to(device=device, dtype=torch.float32)
    w_b = reasoning_weight.detach().to(device=device, dtype=torch.float32)

    U_a, _, Vh_a = torch.linalg.svd(w_a, full_matrices=False)
    U_b, S_b, _ = torch.linalg.svd(w_b, full_matrices=False)

    # Scaling the columns of U_b equals U_b @ diag(S_b) without building the
    # dense diagonal matrix.
    scaled_U_b = U_b * S_b
    if using_Q:
        merged = scaled_U_b @ (U_b.T @ U_a) @ Vh_a
    else:
        merged = scaled_U_b @ Vh_a
    return torch.nn.Parameter(merged.to(target_weight.dtype))


def get_new_model_sv(model, reasoning_model, if_re:bool=True, using_Q:bool=True):
    """Return a copy of ``reasoning_model`` with recombined attention weights.

    A deep copy of ``reasoning_model`` is created. Unless ``if_re`` is false,
    the ``o_proj``, ``q_proj``, ``k_proj`` and ``v_proj`` weights of every
    layer in the copy are replaced by a matrix that keeps the left singular
    vectors and singular values of ``reasoning_model`` and takes the right
    singular vectors of ``model`` (see ``_recombine_projection``).

    The MLP projections (``gate_proj`` / ``up_proj`` / ``down_proj``) are left
    untouched; an earlier, disabled experiment applied the ``using_Q`` formula
    to them as well.

    Args:
        model: Base model; must expose ``model.layers[i].self_attn.<proj>.weight``
            with the same shapes as ``reasoning_model``.
        reasoning_model: Model that is copied and supplies ``U_b`` and ``S_b``.
            It is not modified apart from being switched to eval mode.
        if_re: When false, no weights are changed and the plain deep copy is
            returned.
        using_Q: Selects which of the two recombination formulas is used.

    Returns:
        ``(tgt_model, reasoning_model)``, both in eval mode, where
        ``tgt_model`` is the (possibly modified) deep copy.
    """
    tgt_model = copy.deepcopy(reasoning_model)
    if not if_re:
        print("--No using--")
        return tgt_model.eval(), reasoning_model.eval()

    depth_model = len(tgt_model.model.layers)
    for idx in tqdm(range(depth_model), desc="parameter manipulation"):
        base_attn = model.model.layers[idx].self_attn
        reasoning_attn = reasoning_model.model.layers[idx].self_attn
        tgt_attn = tgt_model.model.layers[idx].self_attn

        for proj_name in _ATTENTION_PROJECTIONS:
            tgt_proj = getattr(tgt_attn, proj_name)
            # Device and dtype come from the projection being replaced, so each
            # new weight lands where the old one was.
            tgt_proj.weight = _recombine_projection(
                getattr(base_attn, proj_name).weight,
                getattr(reasoning_attn, proj_name).weight,
                tgt_proj.weight,
                using_Q,
            )

    tgt_model.eval()
    reasoning_model.eval()
    return tgt_model, reasoning_model

################################# loading model
# Checkpoints this script has been pointed at:
#   base models:       Qwen/Qwen2.5-Math-1.5B-Instruct
#                      Qwen/Qwen2.5-Math-7B-Instruct
#                      Qwen/Qwen2.5-Math-14B-Instruct
#                      meta-llama/Llama-3.1-8B
#   reasoning models:  deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
#                      deepseek-ai/DeepSeek-R1-Distill-Qwen-7B
#                      deepseek-ai/DeepSeek-R1-Distill-Qwen-14B
#                      deepseek-ai/DeepSeek-R1-Distill-Llama-8B

# Output directory for the heatmap; created on first run.
dir_name = "./CKA_8B"
if not os.path.exists(dir_name):
    os.mkdir(dir_name)
    print("ok")

model1_name = "meta-llama/Llama-3.1-8B"
model2_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"

# The tokenizer (and its chat template) comes from the second checkpoint and
# is shared by both models; it pads on the left.
tokenizer = AutoTokenizer.from_pretrained(model2_name, padding_side='left')

# Both checkpoints are loaded with identical options.
_load_kwargs = dict(
    torch_dtype="auto",
    device_map="auto",
    output_hidden_states=True,
)
model1 = AutoModelForCausalLM.from_pretrained(model1_name, **_load_kwargs)
model2 = AutoModelForCausalLM.from_pretrained(model2_name, **_load_kwargs)

if "Llama" in model1_name:
    # Reuse EOS as the padding token for Llama checkpoints.
    tokenizer.pad_token = tokenizer.eos_token

# With if_re=False, get_new_model_sv returns an unmodified deep copy of
# `model2` as the first element, so `model1` is rebound to that copy.
model1, model2 = get_new_model_sv(
    model=model1,
    reasoning_model=model2,
    if_re=False,
    using_Q=False,
)  # TODO

################################ loading data.
dataset = load_dataset("gsm8k", "main")
questions = dataset["train"]["question"][:100]  # first 100 examples.

# Every prompt shares the same system message; only the user turn changes.
_system_turn = {"role": "system", "content": r"Please put your final answer within \boxed{}."}

# Render each question to a plain string (no tokenization yet) that ends with
# the generation prompt.
# NOTE(review): the rendered template may already contain special tokens such
# as BOS; it is tokenized again later with the tokenizer's default
# special-token handling — confirm no duplicate BOS is introduced.
texts = [
    tokenizer.apply_chat_template(
        [_system_turn, {"role": "user", "content": question}],
        tokenize=False,
        add_generation_prompt=True,
    )
    for question in questions
]

hidden_states1 = get_hidden_states(model1, tokenizer, texts)
hidden_states2 = get_hidden_states(model2, tokenizer, texts)

num_layers = len(hidden_states1)
# Entry 0 of each hidden-state sequence is skipped (presumably the embedding
# output — confirm), so the matrix has num_layers - 1 rows and columns.
cka_matrix = np.zeros((num_layers-1, num_layers-1))
hidden_dim = hidden_states1[0].size()[-1]
print(type(hidden_dim))

# Slice the last sequence position of every example once per layer, instead of
# re-slicing and re-transferring inside the O(L^2) loop below. Everything is
# gathered on a single device and upcast to float32 so the CKA products are not
# evaluated in half precision.
_cka_device = hidden_states1[1].device
_feats1 = [
    hidden_states1[i][:, -1, :].to(device=_cka_device, dtype=torch.float32)
    for i in range(1, num_layers)
]
_feats2 = [
    hidden_states2[j][:, -1, :].to(device=_cka_device, dtype=torch.float32)
    for j in range(1, num_layers)
]

for row, X in enumerate(tqdm(_feats1, desc="Model1 Layers")):
    for col, Y in enumerate(_feats2):
        # Convert to a Python float explicitly: implicit tensor -> NumPy
        # conversion can fail for CUDA or bfloat16 tensors.
        cka_matrix[row, col] = float(cka(X, Y))

# Draw the CKA matrix as a heatmap: rows (y-axis) index layers of `model1`,
# columns (x-axis) index layers of `model2`, with layer 1 at the bottom-left.
# NOTE(review): both axes currently carry the same label — confirm this matches
# the model pair actually being compared.
fig, ax = plt.subplots()
heatmap = ax.imshow(cka_matrix, cmap="viridis", vmin=0, vmax=1, origin='lower')
fig.colorbar(heatmap, ax=ax)
ax.set_xlabel("Reasoning model(Llama-3.1-8B)")  # without Q transformation
ax.set_ylabel("Reasoning model(Llama-3.1-8B)")  # Reasoning model(with Q transformation)
ax.set_title("CKA Similarity Between All Layers")

save_path = os.path.join(dir_name, "cka_heatmap_Reasoning_8B_Q_non.png")
fig.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
plt.close(fig)
