from knowledge_neurons import (
    KnowledgeNeurons,
    initialize_model_and_tokenizer,
    model_type,
)
import random
from loaddata import BaseDataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch
import json
import random
import torch.nn.functional as F
from torch.optim import Adam, SGD
from draw import get_data, find_neu
import numpy as np
from rouge import Rouge


def seed_everything(seed):
    """Seed torch, NumPy and ``random`` with a rank-dependent seed.

    The effective seed is ``rank * 100000 + seed``, where ``rank`` is the
    ``torch.distributed`` rank when a process group is initialized and 0
    otherwise, so different ranks receive different seeds.

    Args:
        seed: base seed; must be less than 10000.

    Raises:
        ValueError: if ``seed >= 10000``.
    """
    if seed >= 10000:
        raise ValueError("seed number should be less than 10000")
    dist = torch.distributed
    # When is_available() is False, torch.distributed exposes no other API,
    # so is_initialized() must only be called after this check.
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0
    seed = (rank * 100000) + seed

    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

def set_requires_grad(requires_grad, *models):
    """
    Sets requires_grad true or false for all parameters within the
    models passed.

    Args:
        requires_grad: value assigned to each ``requires_grad`` flag.
        *models: any mix of ``torch.nn.Module`` (every tensor from
            ``parameters()`` is updated) and ``torch.Tensor`` /
            ``torch.nn.Parameter`` (updated directly).

    Raises:
        AssertionError: if an argument is none of the supported types.
            Raised explicitly rather than via ``assert`` so the check is not
            stripped when Python runs with ``-O``.
    """
    for model in models:
        if isinstance(model, torch.nn.Module):
            for param in model.parameters():
                param.requires_grad = requires_grad
        elif isinstance(model, (torch.nn.Parameter, torch.Tensor)):
            model.requires_grad = requires_grad
        else:
            raise AssertionError("unknown type %r" % type(model))

def batch_decode(logits, labels, tok):
    """Greedy-decode ``logits`` and decode ``labels`` into strings using ``tok``.

    Positions where ``labels == -100`` are overwritten with ``tok.pad_token_id``
    in both predictions and labels before decoding with
    ``skip_special_tokens=True`` (assumes ``tok.pad_token_id`` is set and is a
    special token, so those positions vanish from the strings — confirm).

    Args:
        logits: model logits; argmax is taken over the last dim, and the result
            is indexed with positions found in ``labels``. Presumably
            (batch, seq, vocab) against (batch, seq) labels — confirm.
        labels: target token ids, with -100 marking ignored positions.
        tok: tokenizer providing ``pad_token_id`` and ``batch_decode``.

    Returns:
        (decoded_labels, decoded_preds): two lists of decoded strings.
    """
    # argmax over the last dim (the original comment claimed batch x vocab).
    # No .clone(): argmax does not modify its input. np.array(...) copies, so
    # the masking below never mutates the caller's tensors.
    pre = np.array(torch.argmax(logits.detach(), dim=-1).cpu())
    labels = np.array(labels.cpu())
    ind = np.where(labels == -100)

    pre[ind] = tok.pad_token_id
    labels[ind] = tok.pad_token_id
    # NOTE(review): the prediction at position t is compared with the label at
    # t. If the model is a causal LM whose logits at t predict token t+1, the
    # two are off by one unless the dataset pre-shifts labels — confirm.
    decoded_preds = tok.batch_decode(pre, skip_special_tokens=True)
    decoded_labels = tok.batch_decode(labels, skip_special_tokens=True)
    return decoded_labels, decoded_preds

class compute_acc:
    """Callable wrapper around a metric function.

    Calling an instance forwards ``(s_label, s_pre)`` to the wrapped function
    and returns its result unchanged.
    """

    def __init__(self, func):
        # Metric callable invoked as func(s_label, s_pre).
        self.func = func

    def __call__(self, s_label, s_pre):
        metric = self.func
        result = metric(s_label, s_pre)
        return result

def comput_rouge(s_label, s_pre):
    """Per-example ROUGE F1 of predictions against references.

    Empty strings are replaced by "*" before scoring (presumably because
    ``Rouge`` rejects empty inputs — confirm).

    Args:
        s_label: reference strings.
        s_pre: predicted strings, aligned with ``s_label``.

    Returns:
        dict with keys "rouge-1", "rouge-2", "rouge-l", each mapping to a list
        of per-example F1 scores.
    """

    def _fill_blanks(texts):
        # Substitute a placeholder for zero-length entries.
        filled = []
        for text in texts:
            if len(text) <= 0:
                text += "*"
            filled.append(text)
        return filled

    scorer = Rouge()
    refs = _fill_blanks(s_label)
    hyps = _fill_blanks(s_pre)
    scores = {"rouge-1": [], "rouge-2": [], "rouge-l": []}
    per_pair = scorer.get_scores(hyps=hyps, refs=refs, avg=False)
    for metrics in per_pair:
        for name in metrics:
            scores[name].append(metrics[name]["f"])
    return scores
    
def exa_match(s_label, s_pre):
    """Elementwise exact-match between labels and predictions.

    Returns:
        {"acc": array} where the array holds one boolean per example, True
        when the label string equals the prediction string.
    """
    gold = np.array(s_label)
    guess = np.array(s_pre)
    matches = gold == guess
    return {"acc": matches}

def math_acc(s_label, s_pre):
    """Exact-match on the final answer after the last "###  " marker.

    Each string is cut at its last occurrence of the literal separator
    "###  " (three hashes, two spaces). A string without it is compared
    whole.

    Returns:
        {"acc": array} with one boolean per example.
    """
    sep = "###  "

    def _final_answer(text):
        return text.rpartition(sep)[2]

    gold = np.array(list(map(_final_answer, s_label)))
    guess = np.array(list(map(_final_answer, s_pre)))
    return {"acc": gold == guess}

def evaluate(eval_loader, kn, eval_caler, tok=None):
    """Run ``kn.model`` over ``eval_loader`` without gradients and average a metric.

    Args:
        eval_loader: iterable of batches. Each batch is a mapping whose
            "tok_data" entry is passed to the model as keyword arguments and
            must contain "labels".
        kn: object exposing the model as ``kn.model``. The model output must be
            indexable by "logits".
        eval_caler: callable invoked as ``eval_caler(s_label=..., s_pre=...)``
            returning a dict of metric name -> per-example scores, which are
            summed here.
        tok: tokenizer used to decode predictions and labels. Defaults to the
            module-level ``tokenizer`` global for backward compatibility.

    Returns:
        dict mapping each metric name to its mean over all decoded examples, as
        plain floats. Empty if the loader yielded no examples.
    """
    if tok is None:
        tok = tokenizer  # module-level global; kept as the default for old callers
    totals = {}
    num = 0
    with torch.no_grad():
        for batch in tqdm(eval_loader, desc="evaling"):
            logits = kn.model(**batch["tok_data"])["logits"]
            labels = batch["tok_data"]["labels"]
            str_labels, str_preds = batch_decode(logits=logits, labels=labels, tok=tok)
            scores = eval_caler(s_label=str_labels, s_pre=str_preds)
            for metric in scores:
                totals[metric] = totals.get(metric, 0) + np.sum(scores[metric])
            num += len(str_labels)
    if num == 0:
        # Nothing was decoded; avoid dividing by zero.
        return {}
    # float(...) avoids np.float64 reprs (e.g. "np.float64(0.5)") in printed results.
    return {metric: float(total / num) for metric, total in totals.items()}

def gen_random_neur(neur, neur_layer):
    """Draw a random control set matching the per-layer counts of ``neur``.

    For each layer appearing in ``neur`` (visited in ascending order), sample
    as many neuron indices as ``neur`` has in that layer. Samples are drawn
    from ``range(neur_layer)`` minus that layer's original indices, then
    sorted.

    Args:
        neur: sequence of ``[layer, neuron_index]`` pairs.
        neur_layer: number of candidate neuron indices per layer.

    Returns:
        list of ``[int(layer), neuron_index]`` pairs, grouped by layer.

    Raises:
        ValueError: from ``random.sample`` if a layer needs more samples than
            there are remaining candidates.
    """
    buckets = {str(layer_id): [] for layer_id in sorted(set(pair[0] for pair in neur))}
    for pair in neur:
        buckets[str(pair[0])].append(pair[1])

    sampled = []
    for layer_key, taken in buckets.items():
        candidates = list(set(range(neur_layer)).difference(set(taken)))
        picks = sorted(random.sample(candidates, k=len(taken)))
        sampled.extend([int(layer_key), idx] for idx in picks)
    return sampled



# Seed torch, NumPy and random; 42 is the answer to life, the universe, and everything.
seed_everything(42)

# caler = compute_acc(exa_match)
# caler = compute_acc(math_acc)
caler = compute_acc(comput_rouge)
batch_size = 32
lr = 1e-5  # not used by the evaluation code below
epoch = 5  # not used by the evaluation code below

# ==================== experiment configuration ====================
data_name = "code2k"  # dataset whose attribution data selects neurons; also names the output dir
num_data = 22  # get_data is called with ind_list = 0 .. num_data - 1
data_rato = 50  # forwarded to get_data(rato=...); meaning defined in draw.get_data — TODO confirm
tho = 12  # threshold forwarded to find_neu
te = "random"  # "random": erase a random control set of the same per-layer size instead
test_data_names = ["emotion", "code2k", "gsm8k_cho", "imdb", "meta_math", "gsm8k_cho_new", "meta_math_new"]
dtype = "bf16"  # forwarded to initialize_model_and_tokenizer
test_size = None  # forwarded to BaseDataset(num=...); presumably None means all examples — confirm
# ==================================================================

file_name = "fn_data.json"  # evaluation file used for the baseline run
test_file_name = "fn_data.json"  # evaluation file used after erasing neurons
# emotion 2400
num_neur = 11008  # candidate neuron indices per layer for gen_random_neur (presumably Llama-2-7B FFN width — confirm)

MODEL_NAME = "/share/projset/Model_edit/model_saves/models--meta-llama--Llama-2-7b-hf"
ml_model, tokenizer = initialize_model_and_tokenizer(MODEL_NAME, dtype=dtype)
kn_ml = KnowledgeNeurons(ml_model, tokenizer, model_type=model_type(MODEL_NAME))


# te_n = [[0, 912], [0, 3920], [0, 5899], [0, 6100], [0, 7027], [0, 7374], [0, 7642], [0, 9243], [0, 10559], [1, 277], [1, 3864], [1, 4108], [1, 5616], [1, 7890], [3, 4347], [3, 5119], [4, 1542], [4, 2339], [4, 2982], [4, 5299], [4, 6604], [4, 9117], [4, 10165], [5, 10881], [6, 4301], [7, 7775], [7, 7926], [8, 3917], [8, 6084], [8, 6932], [11, 7005], [12, 1474], [12, 1928], [12, 2669], [13, 455], [13, 1817], [13, 3363], [14, 2808], [14, 6019], [14, 7301], [14, 8834], [14, 10277], [14, 10346], [15, 7774], [15, 8758], [15, 10283], [16, 5], [16, 4486], [16, 6874], [16, 8655], [16, 10103], [17, 3014], [17, 4373], [17, 7098], [17, 8337], [17, 8903], [17, 9264], [17, 9580], [17, 10057], [18, 3344], [18, 4270], [18, 4724], [18, 4979], [18, 6858], [18, 7556], [18, 7625], [18, 8223], [18, 9745], [19, 628], [19, 4898], [19, 8144], [20, 4235], [20, 4277], [20, 4524], [20, 6626], [20, 8388], [20, 8744], [20, 9202], [20, 9268], [20, 9606], [20, 9818], [20, 10150], [20, 10638], [21, 1288], [21, 2711], [21, 3422], [21, 5258], [21, 8519], [21, 9568], [22, 1420], [22, 3043], [22, 3797], [22, 5917], [22, 6952], [22, 7639], [22, 9907], [23, 859], [23, 2832], [23, 3413], [23, 6270], [24, 895], [24, 8175], [24, 10316], [24, 10582], [25, 600], [25, 756], [25, 1410], [25, 2874], [25, 3977], [25, 4189], [25, 6381], [25, 7166], [25, 8164], [25, 9297], [25, 9475], [25, 9988], [25, 10041], [26, 2445], [26, 4600], [26, 6678], [26, 7484]]
# Select the neurons to erase: those found by find_neu on `data_name`'s
# attribution data or, when te == "random", a random control set with the
# same per-layer counts.
neu_data = get_data(ind_list=list(range(num_data)), data_name=data_name, rato=data_rato)
te_n = find_neu(neu_data, threshold=tho)
if te == "random":
    ran_n = gen_random_neur(te_n, neur_layer=num_neur)
    te_n = ran_n
# Freeze every model parameter; this script only evaluates.
set_requires_grad(False, kn_ml.model)
# patch_list = kn_ml.load_trainable_neurons(te_n)

CKPT_ROOT = "/share/projset/knowledge-neurons/ckpt"


def _evaluate_suite(data_file):
    """Evaluate ``kn_ml`` on ``CKPT_ROOT/<name>/<data_file>`` for each name in ``test_data_names``.

    Datasets whose name contains "code" are scored with ``comput_rouge``;
    all others with ``exa_match``.

    Returns:
        dict mapping each test dataset name to the metric dict from ``evaluate``.
    """
    results = {}
    for test_name in test_data_names:
        data = BaseDataset(
            tokenizer=tokenizer,
            device=kn_ml.model.device,
            path=f"{CKPT_ROOT}/{test_name}/{data_file}",
            num=test_size,
        )
        loader = DataLoader(data, batch_size=batch_size, shuffle=False, collate_fn=data.collate_fn)
        metric_fn = comput_rouge if "code" in test_name else exa_match
        results[test_name] = evaluate(eval_caler=compute_acc(metric_fn), eval_loader=loader, kn=kn_ml)
    return results


base_acc_list = {k: [] for k in test_data_names}
erase_acc_list = {k: [] for k in test_data_names}

# Baseline: the unmodified model.
for test_name, acc in _evaluate_suite(file_name).items():
    base_acc_list[test_name].append(acc)
for k in base_acc_list:
    print(k)
    print(f"base_acc: {base_acc_list[k]}")
    print("==================")

# Erase the selected neurons, then re-evaluate.
kn_ml.unloda_trainable_neurons(te_n, mode="erase")

for test_name, acc in _evaluate_suite(test_file_name).items():
    erase_acc_list[test_name].append(acc)

for k in erase_acc_list:
    print(k)
    print(f"base_acc: {base_acc_list[k]}")
    print(f"erase_acc: {erase_acc_list[k]}")
    print("==================")

# The `with` block closes the file; no explicit close() is needed.
with open(f"{CKPT_ROOT}/{data_name}/erase_result_{te}_tho-{tho}.json", "w", encoding="utf-8") as f:
    json.dump({"base_acc": base_acc_list, "erase_acc": erase_acc_list}, f, ensure_ascii=False)
