from transformers import AutoModel
from rebuttal_model import GDMask
import torch
from rebuttal_pruning import run_gd_mask, run_sparse_gpt, run_wanda
from rebuttal_gen_X import load_embeddings
from rebuttal_eval import eval_metric
import numpy as np

torch.random.manual_seed(0)

# Index of the decoder layer whose attention q/k projections are copied.
ATTENTION_LAYER_INDEX = 15
# Second positional argument passed to every pruning routine.
# NOTE(review): presumably the target sparsity ratio -- confirm against
# rebuttal_pruning.
PRUNING_RATIO = 0.7
# Prefix lengths (along axis 1 of the embedding tensor) to evaluate.
SEQUENCE_LENGTHS = (64, 128, 256, 512, 1024, 2048, 4096)


def copy_qk_projections(target, source_attention, device):
    """Overwrite ``target``'s q/k projection weights with copies from a source.

    Args:
        target: Object exposing ``q_proj.weight`` and ``k_proj.weight``
            parameters; their ``.data`` is replaced in place.
        source_attention: Object exposing ``q_proj.weight`` and
            ``k_proj.weight`` tensors to copy from.
        device: Device the copied weights are moved to.

    Returns:
        ``target``, to allow call chaining.
    """
    for name in ('q_proj', 'k_proj'):
        source_weight = getattr(source_attention, name).weight
        # detach + clone: an independent copy that shares neither storage
        # nor autograd history with the source model.
        getattr(target, name).weight.data = (
            source_weight.detach().clone().to(device)
        )
    return target


def build_qk_model(llama_model, device, layer_index=ATTENTION_LAYER_INDEX):
    """Create a ``GDMask`` on ``device`` seeded with one layer's q/k weights.

    Args:
        llama_model: Loaded model exposing ``config`` and ``layers``.
        device: Device on which the new model and its weights are placed.
        layer_index: Index into ``llama_model.layers`` to copy from.

    Returns:
        A new ``GDMask`` whose ``q_proj``/``k_proj`` weights are copies of the
        selected layer's self-attention projections.
    """
    model = GDMask(llama_model.config).to(device)
    attention = llama_model.layers[layer_index].self_attn
    return copy_qk_projections(model, attention, device)


if __name__ == '__main__':
    # Prefer the first GPU, but stay runnable on CPU-only machines.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    llama_model = AutoModel.from_pretrained("llama3.2")

    # Build all four models before any pruning runs so that random-number
    # consumption after the seed happens in a fixed order.
    ground_truth_model = build_qk_model(llama_model, device)
    my_model_gd = build_qk_model(llama_model, device)
    my_model_wanda = build_qk_model(llama_model, device)
    my_model_sparse = build_qk_model(llama_model, device)

    my_model_gd = run_gd_mask(
        my_model_gd, PRUNING_RATIO, epochs=400, lr=0.005, device=device, lam=0.05
    )
    my_model_wanda = run_wanda(my_model_wanda, PRUNING_RATIO, device=device)
    my_model_sparse = run_sparse_gpt(
        my_model_sparse, PRUNING_RATIO, 1, 1, device=device
    )

    # Load embeddings and join them along axis 1 into a single tensor.
    embeddings = load_embeddings()
    concatenated_array = np.concatenate(embeddings, axis=1)
    concat_tensor = torch.tensor(concatenated_array, dtype=torch.float32).to(device)

    # Evaluation only prints metrics, so skip autograd bookkeeping; this keeps
    # memory bounded for the longer prefixes.
    with torch.no_grad():
        for n in SEQUENCE_LENGTHS:
            # Slicing never raises: if axis 1 holds fewer than n entries the
            # slice is simply shorter than the n reported below.
            cut_tensor = concat_tensor[:, :n, :]

            # Each model returns a pair; the reference uses the first element.
            f, _ = ground_truth_model(cut_tensor)

            _, f_m_gd = my_model_gd(cut_tensor)
            gd_metric = eval_metric(f, f_m_gd)

            _, f_m_wanda = my_model_wanda(cut_tensor)
            wanda_metric = eval_metric(f, f_m_wanda)

            # NOTE(review): unlike the GD and Wanda branches, the SparseGPT
            # metric is computed from the FIRST output of the model. Kept
            # as-is; confirm this asymmetry is intentional.
            f_sparse, _ = my_model_sparse(cut_tensor)
            sparse_metric = eval_metric(f, f_sparse)

            print(f"n: {n}, gd_metric: {gd_metric}, wanda_metric: {wanda_metric}, sparse_metric: {sparse_metric}")

        

        
