import pickle
import os
import torch



def load_attn_value(idx, second_dir, setting):
    """Load the pickled attention values for a single sample.

    The file read is
    ``attn_values/meta-llama/Llama-2-7b-chat-hf/<second_dir>/<setting>/<idx>.pkl``,
    resolved relative to the current working directory.

    Args:
        idx: Sample index; converted with ``str()`` to form the file name.
        second_dir: First sub-directory under the model directory.
        setting: Second sub-directory under ``second_dir``.

    Returns:
        The unpickled object, or ``None`` when the file does not exist.
    """
    pkl_file = f'attn_values/meta-llama/Llama-2-7b-chat-hf/{second_dir}/{setting}/{idx!s}.pkl'
    if os.path.exists(pkl_file):
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # this at files produced by trusted runs.
        with open(pkl_file, 'rb') as handle:
            return pickle.load(handle)
    return None

def eval_sparsity(second_dir, setting):
    """Print summary values of the stored Frobenius attention norms.

    Unpickles
    ``attn_values/meta-llama/Llama-2-7b-chat-hf/<second_dir>/<setting>/attn_norm_fro.pkl``
    (relative to the current working directory), prints ``torch.norm`` of the
    loaded object, then prints the element at index 998.

    Args:
        second_dir: First sub-directory under the model directory.
        setting: Second sub-directory under ``second_dir``.

    Raises:
        FileNotFoundError: if the norm file does not exist (from ``open``).
    """
    norms_file = f'attn_values/meta-llama/Llama-2-7b-chat-hf/{second_dir}/{setting}/attn_norm_fro.pkl'
    with open(norms_file, 'rb') as handle:
        attn_norms = pickle.load(handle)
    # assumes attn_norms is a torch.Tensor with at least 999 entries along
    # dim 0 -- TODO confirm against the code that writes attn_norm_fro.pkl
    overall = torch.norm(attn_norms)
    print(overall)
    print(attn_norms[998])
# Driver: for every (setting, second_dir) pair, print the norm summary and the
# shape of entry 900 of sample 0's stored attention values.
settings = ['gold_only']
second_dirs = ['original', 'check_point']
for setting in settings:
    for second_dir in second_dirs:
        print(second_dir, setting)
        eval_sparsity(second_dir, setting)
        attn_values = load_attn_value(0, second_dir, setting)
        # load_attn_value returns None when 0.pkl is missing; indexing None
        # would raise a TypeError and abort the remaining combinations, so
        # report the gap and move on instead.
        if attn_values is None:
            print('no attention values for sample 0 in', second_dir, setting)
        else:
            # assumes attn_values is indexable with >900 entries whose items
            # have a .shape (e.g. tensors) -- TODO confirm
            print(attn_values[900].shape)

    
    