from knowledge_neurons import (
    KnowledgeNeurons,
    initialize_model_and_tokenizer,
    model_type,
)
import random
from loaddata import BaseDataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch
import json
import random
import torch.nn.functional as F
from torch.optim import Adam, SGD
from draw import get_data, find_neu
import numpy as np
from rouge import Rouge
import time


def seed_everything(seed):
    """Seed Python's ``random``, NumPy and PyTorch RNGs.

    When a ``torch.distributed`` process group is initialized, every rank gets
    its own seed ``rank * 100000 + seed``; otherwise rank 0 is assumed.

    Args:
        seed: base seed; must be below 10000.

    Raises:
        ValueError: if ``seed >= 10000``.
    """
    if seed >= 10000:
        raise ValueError("seed number should be less than 10000")
    dist = torch.distributed
    # On builds without distributed support torch.distributed exposes no other
    # APIs (is_initialized/get_rank), so check availability first.
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0
    seed = (rank * 100000) + seed

    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

def set_requires_grad(requires_grad, *models):
    """Set ``requires_grad`` on every parameter of the given models/tensors.

    Args:
        requires_grad: the flag value to assign.
        *models: ``torch.nn.Module`` instances (all their ``parameters()`` are
            updated) and/or ``torch.Tensor`` / ``torch.nn.Parameter`` objects
            (updated directly).

    Raises:
        AssertionError: if an argument is neither a Module nor a Tensor.
    """
    for model in models:
        if isinstance(model, torch.nn.Module):
            for param in model.parameters():
                param.requires_grad = requires_grad
        elif isinstance(model, (torch.nn.Parameter, torch.Tensor)):
            model.requires_grad = requires_grad
        else:
            # Explicit raise instead of `assert False` so the check is not
            # stripped under `python -O`; the exception type is unchanged.
            raise AssertionError("unknown type %r" % type(model))

def batch_decode(logits, labels, tok):
    """Decode argmax predictions and reference labels into strings.

    Positions whose label is -100 are overwritten with ``tok.pad_token_id`` in
    both the predictions and the labels, and both are then decoded with
    ``skip_special_tokens=True`` (so, assuming pad is a special token, only
    labelled positions remain in the text).

    Args:
        logits: tensor whose last dimension is the vocabulary; argmax is taken
            over it. Presumably (batch, seq_len, vocab) -- TODO confirm.
        labels: tensor with the same shape as the argmax result; -100 marks
            ignored positions.
        tok: tokenizer exposing ``pad_token_id`` and ``batch_decode``.

    Returns:
        ``(decoded_labels, decoded_preds)``: two lists of strings.
    """
    # argmax already allocates a new tensor, so no clone of the full logits is
    # needed; np.array(...) copies, so the masking below never mutates `labels`.
    pre = np.array(torch.argmax(logits.detach(), dim=-1).cpu())
    labels = np.array(labels.cpu())
    ind = np.where(labels == -100)

    pre[ind] = tok.pad_token_id
    labels[ind] = tok.pad_token_id
    # Decode with the tokenizer passed in (previously the global `tokenizer`).
    decoded_preds = tok.batch_decode(pre, skip_special_tokens=True)
    decoded_labels = tok.batch_decode(labels, skip_special_tokens=True)
    return decoded_labels, decoded_preds

class compute_acc:
    """Callable adapter around a metric function.

    Calling an instance as ``caler(s_label=..., s_pre=...)`` forwards both
    arguments positionally to the wrapped function and returns its result
    unchanged.
    """

    def __init__(self, func):
        # The wrapped metric, e.g. `exa_match` or `comput_rouge`.
        self.func = func

    def __call__(self, s_label, s_pre):
        metric = self.func
        result = metric(s_label, s_pre)
        return result

def comput_rouge(s_label, s_pre):
    """Per-sample ROUGE F-scores between predictions and references.

    Empty strings are replaced by ``"*"`` before scoring so that every pair
    has non-empty text.

    Args:
        s_label: list of reference strings.
        s_pre: list of predicted strings (same length as ``s_label``).

    Returns:
        dict with keys ``"rouge-1"``, ``"rouge-2"``, ``"rouge-l"``, each a list
        holding the ``"f"`` value of every (prediction, reference) pair.
    """
    scorer = Rouge()
    refs = [text if len(text) > 0 else "*" for text in s_label]
    hyps = [text if len(text) > 0 else "*" for text in s_pre]

    scores = {"rouge-1": [], "rouge-2": [], "rouge-l": []}
    per_pair_scores = scorer.get_scores(hyps=hyps, refs=refs, avg=False)
    for pair in per_pair_scores:
        for metric, stats in pair.items():
            scores[metric].append(stats["f"])
    return scores
    
def exa_match(s_label, s_pre):
    """Element-wise exact string match.

    Returns:
        ``{"acc": bool ndarray}`` that is True where prediction equals label.
    """
    references = np.array(s_label)
    predictions = np.array(s_pre)
    matches = references == predictions
    return {"acc": matches}

def math_acc(s_label, s_pre):
    """Exact match on the text after the last ``"###  "`` marker of each string.

    Strings without the marker are compared in full.

    Returns:
        ``{"acc": bool ndarray}`` that is True where the final answers match.
    """
    def final_answer(text):
        # Text after the last separator (whole string when it is absent).
        return text.rpartition("###  ")[2]

    label_answers = np.array([final_answer(s) for s in s_label])
    pred_answers = np.array(list(map(final_answer, s_pre)))
    return {"acc": label_answers == pred_answers}

def evaluate(eval_loader, kn, eval_caler, tok=None):
    """Evaluate ``kn.model`` on ``eval_loader`` and return mean per-sample metrics.

    For every batch the model is run under ``torch.no_grad()``. Its logits and
    the batch labels are decoded with ``batch_decode`` and scored by
    ``eval_caler``. The per-sample metric values are summed and finally divided
    by the total number of decoded samples.

    Args:
        eval_loader: iterable of batches; each batch is a dict whose
            ``"tok_data"`` entry is passed as keyword arguments to the model
            and contains ``"labels"``.
        kn: object exposing ``.model`` (a ``torch.nn.Module``).
        eval_caler: callable ``(s_label=..., s_pre=...)`` returning a dict of
            metric name -> per-sample values.
        tok: tokenizer used for decoding; defaults to the module-level
            ``tokenizer``.

    Returns:
        dict mapping each metric name to its mean over all evaluated samples.
    """
    if tok is None:
        tok = tokenizer
    model = kn.model
    was_training = model.training
    # Evaluate in inference mode even if the caller left the model in train
    # mode; the caller's mode is restored afterwards.
    model.eval()
    correct_dic = {}
    num = 0
    try:
        with torch.no_grad():
            for batch in tqdm(eval_loader, desc="evaling"):
                output = model(**batch["tok_data"])
                str_labels, str_pre = batch_decode(
                    logits=output["logits"],
                    labels=batch["tok_data"]["labels"],
                    tok=tok,
                )
                scores = eval_caler(s_label=str_labels, s_pre=str_pre)
                for k, v in scores.items():
                    correct_dic[k] = correct_dic.get(k, 0) + np.sum(v)
                num += len(str_labels)
    finally:
        model.train(was_training)
    return {ke: va / num for ke, va in correct_dic.items()}

def gen_random_neur(neur, neur_layer):
    """Sample a random counterpart of ``neur`` with the same per-layer counts.

    For each layer that appears in ``neur`` (``[layer, index]`` pairs), draw as
    many distinct indices from ``range(neur_layer)`` as ``neur`` selects in that
    layer, excluding the indices ``neur`` already uses there.

    Args:
        neur: iterable of ``[layer, neuron_index]`` pairs.
        neur_layer: number of neurons per layer (the sampling population size).

    Returns:
        list of ``[layer, neuron_index]`` pairs ordered by layer, then index.

    Raises:
        ValueError: from ``random.sample`` when a layer has fewer remaining
            indices than must be drawn.
    """
    grouped = {}
    for layer_id in sorted(set(pair[0] for pair in neur)):
        grouped[str(layer_id)] = []
    for pair in neur:
        grouped[str(pair[0])].append(pair[1])

    # Loop-invariant population of all neuron indices in a layer.
    population = set(range(neur_layer))
    picked = []
    for key, taken in grouped.items():
        candidates = list(population.difference(set(taken)))
        chosen = sorted(random.sample(candidates, k=len(taken)))
        picked.extend([int(key), idx] for idx in chosen)
    return picked



seed_everything(42)
# the random seed to life, the universe, and everything


# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
batch_size = 32
lr = 1e-5
# lr used for emotion: 1e-5
# Which parameters are fine-tuned:
#   "random"       -> random neurons with the same per-layer counts as the
#                     located ones (see gen_random_neur), excluding them
#   "full"         -> every parameter whose name contains kn_ml.input_ff_attr
#   anything else  -> the neurons returned by find_neu
te = "random"
data_name = "code2k"
# code2k is scored with ROUGE; every other training set with exact match.
caler = compute_acc(exa_match) if data_name != "code2k" else compute_acc(comput_rouge)
# caler = compute_acc(comput_rouge)
num_data = 22  # get_data is queried with indices 0 .. num_data - 1
data_rato = 50  # forwarded to get_data(rato=...); TODO confirm its meaning
file_name = "fn_data.json"
test_file_name = "test4FT_data.json"
# In "full" mode only the training dataset itself is evaluated.
test_data_names = ["emotion", "code2k", "gsm8k_cho", "imdb", "meta_math", "gsm8k_cho_new", "meta_math_new"] if te != "full" else [data_name]
dtype = "bf16"
test_size = 2400  # forwarded as `num` to each test BaseDataset
epoch = 10
train_size = None  # forwarded as `num` to BaseDataset; presumably None = whole file
# sizes used previously: emotion 2400, code 2400
# ---------------------------------------------------------------------------
num_neur = 11008  # neurons per FF layer: population size for random sampling

MODEL_NAME = "/share/projset/Model_edit/model_saves/models--meta-llama--Llama-2-7b-hf"
ml_model, tokenizer = initialize_model_and_tokenizer(MODEL_NAME, dtype=dtype)
kn_ml = KnowledgeNeurons(ml_model, tokenizer, model_type=model_type(MODEL_NAME))


# te_n = [[0, 912], [0, 3920], [0, 5899], [0, 6100], [0, 7027], [0, 7374], [0, 7642], [0, 9243], [0, 10559], [1, 277], [1, 3864], [1, 4108], [1, 5616], [1, 7890], [3, 4347], [3, 5119], [4, 1542], [4, 2339], [4, 2982], [4, 5299], [4, 6604], [4, 9117], [4, 10165], [5, 10881], [6, 4301], [7, 7775], [7, 7926], [8, 3917], [8, 6084], [8, 6932], [11, 7005], [12, 1474], [12, 1928], [12, 2669], [13, 455], [13, 1817], [13, 3363], [14, 2808], [14, 6019], [14, 7301], [14, 8834], [14, 10277], [14, 10346], [15, 7774], [15, 8758], [15, 10283], [16, 5], [16, 4486], [16, 6874], [16, 8655], [16, 10103], [17, 3014], [17, 4373], [17, 7098], [17, 8337], [17, 8903], [17, 9264], [17, 9580], [17, 10057], [18, 3344], [18, 4270], [18, 4724], [18, 4979], [18, 6858], [18, 7556], [18, 7625], [18, 8223], [18, 9745], [19, 628], [19, 4898], [19, 8144], [20, 4235], [20, 4277], [20, 4524], [20, 6626], [20, 8388], [20, 8744], [20, 9202], [20, 9268], [20, 9606], [20, 9818], [20, 10150], [20, 10638], [21, 1288], [21, 2711], [21, 3422], [21, 5258], [21, 8519], [21, 9568], [22, 1420], [22, 3043], [22, 3797], [22, 5917], [22, 6952], [22, 7639], [22, 9907], [23, 859], [23, 2832], [23, 3413], [23, 6270], [24, 895], [24, 8175], [24, 10316], [24, 10582], [25, 600], [25, 756], [25, 1410], [25, 2874], [25, 3977], [25, 4189], [25, 6381], [25, 7166], [25, 8164], [25, 9297], [25, 9475], [25, 9988], [25, 10041], [26, 2445], [26, 4600], [26, 6678], [26, 7484]]
# Locate neurons from the first `num_data` samples; in "random" mode swap them
# for a random set with the same per-layer counts.
neu_data = get_data(ind_list=list(range(num_data)), data_name=data_name, rato=data_rato)
te_n = find_neu(neu_data)
if te == "random":
    te_n = gen_random_neur(te_n, neur_layer=num_neur)

# Freeze everything first; only what is handed to the optimizer gets updated.
set_requires_grad(False, kn_ml.model)

if te == "full":
    # Unfreeze the model and optimise every parameter whose name contains the
    # FF-input attribute name.
    set_requires_grad(True, kn_ml.model)
    trainable = [
        w for name, w in kn_ml.model.named_parameters()
        if kn_ml.input_ff_attr in name
    ]
else:
    # Optimise only the weights of the patches attached for the chosen neurons.
    patch_list = kn_ml.load_trainable_neurons(te_n)
    trainable = [
        layer.weight
        for patch in patch_list
        for layer in patch.delta_neurons
    ]
opt = Adam(params=trainable, lr=lr, eps=1e-3)


# Histories written to the result file at the end of the run.
loss_li = []  # training loss of every non-skipped step
acc_li = {}  # metric name -> per-step batch mean on the training batches
# eps_emotion = 1e-2
epoch_train_acc_li = []  # per-epoch `evaluate` result on the training file
epoch_val_acc_li = {k: [] for k in test_data_names}  # test set -> per-epoch results

start_time = time.time()
for ep in range(epoch):
    data = BaseDataset(tokenizer=tokenizer, device=kn_ml.model.device, path=f"/share/projset/knowledge-neurons/ckpt/{data_name}/{file_name}", num=train_size)
    loader = DataLoader(data, batch_size=batch_size, shuffle=True, collate_fn=data.collate_fn)
    loop = tqdm(loader, desc=f"now/total -> {ep}/{epoch-1}")
    for i in loop:
        if te == "full":
            kn_ml.unloda_trainable_neurons(te_n, mode="erase")
        opt.zero_grad()
        kn_ml.model.train()
        output = kn_ml.model(**i["tok_data"])
        logits = output["logits"]
        labels = i["tok_data"]["labels"]
        # Only positions with a real label (!= -100) contribute to the loss.
        ans_indice = torch.where(labels != -100)
        te_logits = logits[ans_indice]
        te_labels = labels[ans_indice]
        if torch.isnan(te_logits).any():
            # Low-precision (e.g. fp16) training can produce NaN logits; skip
            # the batch. `loss` is unbound if the very first batch is skipped,
            # so report the last recorded loss (nan if there is none yet).
            loop.set_postfix(loss=loss_li[-1] if loss_li else float("nan"), nan="True")
            continue
        loss = F.cross_entropy(te_logits.to(te_labels.device), te_labels)
        loss.backward()
        with torch.no_grad():
            kn_ml.model.eval()
            str_lables, str_pre = batch_decode(logits=logits, labels=labels, tok=tokenizer)
            acc_dic = caler(s_label=str_lables, s_pre=str_pre)
            tem_dic = {k: np.mean(v) for k, v in acc_dic.items()}
            # Append to each metric's history, creating it on first sight, so
            # the per-step history is kept instead of being reset every step.
            for key, value in tem_dic.items():
                acc_li.setdefault(key, []).append(value)
        loss_value = loss.item()
        loop.set_postfix(loss=loss_value, **tem_dic)
        loss_li.append(loss_value)
        opt.step()
        if te == "full":
            kn_ml.unloda_trainable_neurons(te_n, mode="erase")

    # End of epoch: score the whole training file, then every test set.
    data = BaseDataset(tokenizer=tokenizer, device=kn_ml.model.device, path=f"/share/projset/knowledge-neurons/ckpt/{data_name}/{file_name}", num=train_size)
    loader = DataLoader(data, batch_size=batch_size, shuffle=False, collate_fn=data.collate_fn)
    train_acc = evaluate(eval_caler=caler, eval_loader=loader, kn=kn_ml)
    epoch_train_acc_li.append(train_acc)
    print(f"train_acc: {train_acc}")

    for test_name in test_data_names:
        data = BaseDataset(tokenizer=tokenizer, device=kn_ml.model.device, path=f"/share/projset/knowledge-neurons/ckpt/{test_name}/{test_file_name}", num=test_size)
        # These test sets are evaluated with a batch size of 10 (presumably to
        # bound memory on longer inputs -- TODO confirm).
        small_batch = test_name in ("imdb", "meta_math") or "new" in test_name
        loader = DataLoader(data, batch_size=10 if small_batch else batch_size, shuffle=False, collate_fn=data.collate_fn)
        test_caler = compute_acc(exa_match) if "code" not in test_name else compute_acc(comput_rouge)
        val_acc = evaluate(eval_caler=test_caler, eval_loader=loader, kn=kn_ml)
        epoch_val_acc_li[test_name].append(val_acc)
        print(f"{test_name}_val_acc: {val_acc}")

end_time = time.time()

with open(f"/share/projset/knowledge-neurons/ckpt/{data_name}/result_{te}.json", "w", encoding="utf-8") as f:
    json.dump({"fn_train_acc": epoch_train_acc_li, "fn_val_acc": epoch_val_acc_li, "time": end_time-start_time,"loss": loss_li, "acc": acc_li, "neurons": te_n}, f, ensure_ascii=False)
