from transformers import AutoModel
from rebuttal_model import GDMask
import torch
from rebuttal_pruning import run_gd_mask, run_sparse_gpt, run_wanda
from rebuttal_gen_X import load_embeddings
from rebuttal_eval import eval_metric
import numpy as np

# Seed torch's global RNG at import time so repeated runs start from the same state.
torch.random.manual_seed(0)

# Index of the decoder layer whose self-attention q/k projections are copied.
_LAYER_INDEX = 15
# Number of leading entries kept along axis 1 of the concatenated embeddings.
_SEQ_LEN = 512
# Values of rho swept by the experiment; each is handed to every pruning routine.
# (presumably a sparsity / keep ratio -- confirm against rebuttal_pruning)
_RHO_LIST = [0.4, 0.5, 0.6, 0.7, 0.8]


def _build_layer_model(llama_model, device):
    """Return a GDMask on ``device`` seeded with one layer's q/k projections.

    A fresh ``GDMask`` is built from ``llama_model.config`` and moved to
    ``device``; its ``q_proj`` and ``k_proj`` weights are then replaced by
    clones of the corresponding self-attention weights of decoder layer
    ``_LAYER_INDEX``. Cloning keeps the source model untouched when the
    returned model is pruned later.

    Args:
        llama_model: Loaded model exposing ``config`` and ``layers``.
        device: Torch device string the model and weights are placed on.

    Returns:
        The initialized ``GDMask`` instance.
    """
    attn = llama_model.layers[_LAYER_INDEX].self_attn
    # Move the whole module so any state besides the two projections is on
    # the same device as the weights assigned below.
    model = GDMask(llama_model.config).to(device)
    model.q_proj.weight.data = attn.q_proj.weight.clone().to(device)
    model.k_proj.weight.data = attn.k_proj.weight.clone().to(device)
    return model


def main():
    """Compare GD-mask, Wanda and SparseGPT pruning across several rho values.

    For each rho, three identically initialized models are pruned (one per
    method), evaluated on the same embedding tensor against the unpruned
    reference output via ``eval_metric``, and the three metrics are printed.
    """
    # Prefer the first GPU, but stay runnable on machines without CUDA.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    llama_model = AutoModel.from_pretrained("llama3.2")
    ground_truth_model = _build_layer_model(llama_model, device)

    # Load embeddings, join them along axis 1 and keep the first _SEQ_LEN
    # entries of that axis (assumes axis 1 is the sequence axis -- TODO confirm).
    embeddings = load_embeddings()
    concatenated_array = np.concatenate(embeddings, axis=1)
    concat_tensor = torch.tensor(concatenated_array, dtype=torch.float32).to(device)
    cut_tensor = concat_tensor[:, :_SEQ_LEN, :]

    # The reference model and the input never change inside the loop, so the
    # reference output is computed once. No gradients are needed for it.
    with torch.no_grad():
        f, _ = ground_truth_model(cut_tensor)

    for rho in _RHO_LIST:
        # Every method starts from its own copy of the same original weights.
        my_model_gd = _build_layer_model(llama_model, device)
        my_model_wanda = _build_layer_model(llama_model, device)
        my_model_sparse = _build_layer_model(llama_model, device)

        my_model_gd = run_gd_mask(my_model_gd, rho, epochs=400, lr=0.005, device=device, lam=0.05)
        my_model_wanda = run_wanda(my_model_wanda, rho, device=device)
        # NOTE(review): the meaning of the two positional ``1`` arguments is
        # defined by run_sparse_gpt and is not visible here.
        my_model_sparse = run_sparse_gpt(my_model_sparse, rho, 1, 1, device=device)

        # Evaluation only: skip autograd bookkeeping to save memory.
        with torch.no_grad():
            _, f_m_gd = my_model_gd(cut_tensor)
            gd_metric = eval_metric(f, f_m_gd)
            _, f_m_wanda = my_model_wanda(cut_tensor)
            wanda_metric = eval_metric(f, f_m_wanda)
            # NOTE(review): unlike the GD and Wanda branches, this branch scores
            # the FIRST forward output. Kept as in the original -- confirm it is
            # intended (e.g. because SparseGPT rewrites the weights directly).
            f_sparse, _ = my_model_sparse(cut_tensor)
            sparse_metric = eval_metric(f, f_sparse)

        print(f"rho: {rho}, gd_metric: {gd_metric}, wanda_metric: {wanda_metric}, sparse_metric: {sparse_metric}")

        # Drop the pruned models before the next iteration allocates new ones.
        del my_model_gd
        del my_model_wanda
        del my_model_sparse


if __name__ == '__main__':
    main()

        

        
