
import torch
from transformers import AutoModel
from rebuttal_model import GDMask
from rebuttal_gen_X import load_embeddings

def eval_metric(f,f_m):
    """Return the relative Frobenius-norm error of ``f_m`` with respect to ``f``.

    The value is ``||f - f_m||_F / ||f||_F``, with both norms taken over all
    elements of their argument.

    Args:
        f: Reference tensor.
        f_m: Tensor compared against ``f``; it must be broadcast-compatible
            with ``f`` for the subtraction.

    Returns:
        A zero-dimensional tensor holding the ratio. There is no guard for a
        zero reference norm, so an all-zero ``f`` produces ``nan`` or ``inf``.
    """
    residual_norm = torch.norm(f - f_m, p='fro')
    reference_norm = torch.norm(f, p='fro')
    return residual_norm / reference_norm
    

if __name__ == '__main__':
    # Evaluate each saved model against a reference GDMask whose q/k
    # projections are copied from layer 15 of the pretrained Llama model, and
    # print the per-embedding relative error for every model.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Build the reference model from the pretrained attention weights.
    llama_model = AutoModel.from_pretrained("llama3.2")
    ground_truth_model = GDMask(llama_model.config).to(device)
    attention = llama_model.layers[15].self_attn
    ground_truth_model.q_proj.weight.data = attention.q_proj.weight.clone().to(device)
    ground_truth_model.k_proj.weight.data = attention.k_proj.weight.clone().to(device)

    # Checkpoints to evaluate, in the order their results are printed.
    checkpoint_paths = {
        'gd': 'rebuttal_model/my_model_gd.pth',
        'wanda': 'rebuttal_model/my_model_wanda.pth',
        'sparse': 'rebuttal_model/my_model_sparse.pth',
    }
    # The checkpoints are whole pickled modules (they are called directly
    # below), so `weights_only=False` is required on PyTorch versions where
    # weights-only loading is the default. Unpickling executes arbitrary code:
    # only load checkpoint files you trust.
    # `map_location` lets checkpoints saved on a GPU load on a CPU-only host.
    evaluated_models = {
        name: torch.load(path, map_location=device, weights_only=False).to(device)
        for name, path in checkpoint_paths.items()
    }

    # load embeddings
    embeddings = load_embeddings()
    embeddings = [torch.tensor(embedding, dtype=torch.float32).to(device) for embedding in embeddings]

    # Loop over the embeddings and compare every model with the reference.
    metrics = {name: [] for name in evaluated_models}
    # Evaluation only: disabling autograd avoids building a graph per forward
    # pass and keeping it alive through the stored metric values.
    with torch.no_grad():
        for embedding in embeddings:
            # NOTE(review): the reference uses the FIRST returned value while
            # the evaluated models use the SECOND -- presumably the model
            # returns (unmasked, masked) outputs; confirm against GDMask.
            f, _ = ground_truth_model(embedding)
            for name, model in evaluated_models.items():
                _, f_m = model(embedding)
                # Store plain Python floats so the printed lists are readable
                # and hold no device memory.
                metrics[name].append(eval_metric(f, f_m).item())

    for name, values in metrics.items():
        print(f"{name}_metrics: {values}")

