from transformers import AutoModel
from rebuttal_model import GDMask
import torch
from rebuttal_pruning import run_gd_mask, run_sparse_gpt, run_wanda




# Index into ``llama_model.layers`` of the layer whose self-attention
# q/k projection weights are copied into each GDMask instance.
LAYER_INDEX = 15

# Directory that receives the three saved ``.pth`` files.
OUTPUT_DIR = 'rebuttal_model'


def copy_qk_weights(target, attention, device):
    """Copy the q_proj/k_proj weights of *attention* into *target*.

    Each weight is detached and cloned before being moved to *device*, so
    *target* never shares storage (or autograd history) with the source
    module.

    Args:
        target: Module exposing ``q_proj.weight`` and ``k_proj.weight``;
            their ``.data`` is replaced in place.
        attention: Module exposing ``q_proj.weight`` and ``k_proj.weight``
            to copy from; it is left unmodified.
        device: Device (string or ``torch.device``) the copies are moved to.

    Returns:
        *target*, to allow construction and initialisation in one expression.
    """
    for proj_name in ('q_proj', 'k_proj'):
        source_weight = getattr(attention, proj_name).weight
        # detach() first: only the values are wanted, not a graph edge back
        # to the pretrained parameter.
        getattr(target, proj_name).weight.data = (
            source_weight.detach().clone().to(device)
        )
    return target


def main():
    """Prune copies of one attention layer with three methods and save them.

    Loads the pretrained model "llama3.2", builds three ``GDMask`` instances
    initialised with the q/k projection weights of layer ``LAYER_INDEX``,
    passes them through ``run_gd_mask``, ``run_wanda`` and ``run_sparse_gpt``
    respectively, and saves each returned object under ``OUTPUT_DIR``.
    """
    import os

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Create the output directory before the expensive work so that a bad
    # path fails immediately instead of after all three runs have finished.
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    llama_model = AutoModel.from_pretrained("llama3.2")
    attention = llama_model.layers[LAYER_INDEX].self_attn

    # Instances are constructed in the order gd, wanda, sparse, matching the
    # original script, so any random initialisation inside GDMask consumes
    # the RNG in the same sequence.
    my_model_gd = copy_qk_weights(GDMask(llama_model.config), attention, device)
    my_model_wanda = copy_qk_weights(GDMask(llama_model.config), attention, device)
    my_model_sparse = copy_qk_weights(GDMask(llama_model.config), attention, device)

    # 0.7 is the second positional argument shared by all three methods
    # (presumably the target sparsity -- confirm in rebuttal_pruning).
    # The meaning of 64 and 32 for run_sparse_gpt is likewise defined there.
    my_model_gd = run_gd_mask(my_model_gd, 0.7, epochs=80, lr=0.05, device=device)
    my_model_wanda = run_wanda(my_model_wanda, 0.7, device=device)
    my_model_sparse = run_sparse_gpt(my_model_sparse, 0.7, 64, 32, device=device)

    # Save the three resulting objects (whole objects, not state_dicts, so
    # existing torch.load callers keep working).
    torch.save(my_model_gd, os.path.join(OUTPUT_DIR, 'my_model_gd.pth'))
    torch.save(my_model_wanda, os.path.join(OUTPUT_DIR, 'my_model_wanda.pth'))
    torch.save(my_model_sparse, os.path.join(OUTPUT_DIR, 'my_model_sparse.pth'))


if __name__ == '__main__':
    main()


    