from knowledge_neurons import (
    KnowledgeNeurons,
    initialize_model_and_tokenizer,
    model_type,
)
import random
from tqdm import tqdm
import torch
import json
import random

# NOTE(review): `data_list` is not read anywhere in the visible script; it is
# kept because it is a module-level name -- confirm before deleting.
data_list = {key: [] for key in ["a", "b", "c", "d", "e", "f"]}


# Number of samples whose scores are averaged into each saved tensor.
batch_size = 100
# Sub-directory of the checkpoint root used for both the input JSON and the
# saved score tensors.
saved_dataset = "meta_math"

# Load the sample list. The `with` block closes the file on exit, so no
# explicit close is needed (the previous bare `f.close` was never called and
# had no effect).
with open(f"/share/projset/knowledge-neurons/ckpt/{saved_dataset}/fn_data.json", "r", encoding="utf-8") as f:
    new_data = json.load(f)

# Shuffle in place so that batches are drawn in random order. No seed is set
# in this script, so the batch composition differs from run to run.
random.shuffle(new_data)

print(len(new_data))

# new_data = [{"question": TEXT, "answer": ANS}]*10
# Filesystem location of the Llama-2-7b-hf checkpoint that is analysed.
MODEL_NAME = "/share/projset/Model_edit/model_saves/models--meta-llama--Llama-2-7b-hf"

# Model and tokenizer are requested with dtype="fp16".
ml_model, tokenizer = initialize_model_and_tokenizer(MODEL_NAME, dtype="fp16")

# Wrapper exposing the attribution API; the model type is derived from the
# checkpoint path by `model_type`.
kn_ml = KnowledgeNeurons(
    ml_model,
    tokenizer,
    model_type=model_type(MODEL_NAME),
)

# Layer count reported by the wrapper; drives the per-layer scoring loop.
layers = kn_ml.n_layers()


# Average the per-layer attribution scores over fixed-size batches of samples
# and save one tensor per batch.
#
# Only complete batches are processed: a trailing remainder of fewer than
# `batch_size` samples is skipped, so every saved mean covers exactly
# `batch_size` samples.
num_batches = len(new_data) // batch_size

for batch_idx in range(num_batches):
    batch = new_data[batch_idx * batch_size:(batch_idx + 1) * batch_size]
    score_sum = None  # running element-wise sum of the stacked sample scores

    for dat in tqdm(batch, desc=f"now/total -> {batch_idx}/{num_batches - 1}"):
        question = dat["question"]
        answer = dat["answer"]

        per_layer = []
        for layer_idx in range(layers):
            # Reuse the single `kn_ml` wrapper created during model setup
            # instead of rebuilding it for every layer of every sample: it was
            # always constructed from the same model, tokenizer and model type.
            # NOTE(review): assumes get_scores_for_layer keeps no per-call
            # state on the instance -- confirm against the installed
            # knowledge_neurons version.
            layer_scores = kn_ml.get_scores_for_layer(
                question,
                answer,
                layer_idx=layer_idx,
            )
            # Keep only a detached CPU copy so no autograd graph is retained
            # between layers.
            per_layer.append(layer_scores.detach().cpu().clone())
            # Drop the last reference to the original tensor before asking the
            # CUDA caching allocator to release unused blocks.
            del layer_scores
            torch.cuda.empty_cache()

        # One tensor per sample with the layer index as leading dimension.
        sample_scores = torch.stack(per_layer)
        if score_sum is None:
            score_sum = sample_scores
        else:
            # Operands are detached CPU tensors, so no grad guard is needed.
            score_sum = score_sum + sample_scores

    # Mean over the batch (each processed batch holds exactly `batch_size`
    # samples).
    batch_mean = score_sum / batch_size
    torch.save(batch_mean, f"/share/projset/knowledge-neurons/ckpt/{saved_dataset}/te{batch_idx}.pt")
