import sys
import os
import json
import numpy as np
import random
import torch
import pandas as pd
from tqdm import tqdm
import torch.optim as optim
import torch.nn.functional as F
# os.environ["CUDA_VISIBLE_DEVICES"]="0"
# Make the parent of the working directory, and its "src" folder, importable
# so the project-local imports below (decoding_algorithm, utils.*) resolve.
current_path = os.path.abspath(os.path.dirname(os.getcwd()))
sys.path.extend([current_path, os.path.join(current_path, "src")])
from decoding_algorithm import ContrastiveDecoding
from utils.format_data_mmlu import subcategories, categories, gen_prompt, format_problem
# ---- Experiment configuration ----------------------------------------------
# Passed to gen_prompt as `k` (presumably the number of few-shot examples);
# values 1..5 have been tried.
PROMPT_NUM = 3
# Passed as `n` to gen_prompt / format_problem (answer options per question).
N_CHOICE = 4
# Passed as `bias_ans` to gen_prompt for the biased prompt variant.
BIAS_ANS = "A"
# MMLU root; expected to contain "dev/" and "test/" CSV folders.
data_path = "../../data/MMLU"

# ---- Model ------------------------------------------------------------------
# Other checkpoints that have been used:
#   /mnt/llms/model/google-gemma/gemma-2b
#   /mnt/llms/model/meta-llama/Llama-2-13b-hf
model_name = "/mnt/llms/model/meta-llama/Llama-2-7b-hf"
llm = ContrastiveDecoding(model_name)
stop_word_list = ["Q:"]
llm.set_stop_words(stop_word_list)
llm.set_label_id("mmlu")
# Subject names are the "<subject>" part of every "<subject>_test.csv" file
# in the test folder, in sorted order.
subjects = sorted(
    csv_name.split("_test.csv")[0]
    for csv_name in os.listdir(os.path.join(data_path, "test"))
    if "_test.csv" in csv_name
)

# Accumulators for correctness results: overall, per subcategory, per category.
all_cors = []
subcat_cors = {
    sub_name: []
    for sub_names in subcategories.values()
    for sub_name in sub_names
}
cat_cors = {cat_name: [] for cat_name in categories}
# Build one record per MMLU test question. Each record carries:
#   "prompt"      - unbiased few-shot prompt followed by the question
#   "bias_prompt" - biased few-shot prompt followed by the question
#   "problem"     - the formatted question alone
#   "y_true"      - the label returned by format_problem
all_data = []
for subject in subjects:
    # To restrict the run to one category, uncomment (other keys:
    # "humanities", "social sciences", "other (business, health, misc.)"):
    # if subcategories[subject][0] not in categories["other (business, health, misc.)"]:
    #     continue

    # Only the first 5 dev rows are offered to gen_prompt as demonstrations.
    dev_df = pd.read_csv(
        os.path.join(data_path, "dev", subject + "_dev.csv"), header=None
    )[:5]
    test_df = pd.read_csv(
        os.path.join(data_path, "test", subject + "_test.csv"), header=None
    )
    for i in range(test_df.shape[0]):
        # Author's observation: choosing a different BIAS_ANS does not change
        # which layer is ultimately found.
        # NOTE(review): neither gen_prompt call depends on `i`; if gen_prompt is
        # deterministic these could be hoisted out of this loop.
        base_prompt = gen_prompt(
            dev_df, subject, PROMPT_NUM, bias=False, bias_ans=BIAS_ANS, n=N_CHOICE
        )
        biased_prompt = gen_prompt(
            dev_df, subject, PROMPT_NUM, bias=True, bias_ans=BIAS_ANS, n=N_CHOICE
        )
        problem, y_true = format_problem(test_df, i, n=N_CHOICE)
        all_data.append(
            {
                "prompt": base_prompt + problem,
                "bias_prompt": biased_prompt + problem,
                "problem": problem,
                "y_true": y_true,
            }
        )


# Goal: find which layer, when intervened on, improves final accuracy the most.
# Randomly split the records: the first SAMPLE_NUM shuffled items form the
# training set; the following items (up to shuffled position 1000) form the
# test set.
SAMPLE_NUM = 8
data_bias = []
data_no_bias = []

indexes = list(range(len(all_data)))
random.shuffle(indexes)
train_idx, test_idx = indexes[:SAMPLE_NUM], indexes[SAMPLE_NUM:1000]
data_train = [all_data[j] for j in train_idx]
data_test = [all_data[j] for j in test_idx]

# Baseline score on the training samples, before any fine-tuning.
print(llm.eval_model(data_train))

print("begin to train")

# Fine-tune only the parameters whose (lower-cased) name contains "layers.17.";
# every other parameter is frozen. Trainable parameter names are printed.
for name, param in llm.model.named_parameters():
    if "layers.17." not in name.lower():
        param.requires_grad = False
    else:
        print(name)
        param.requires_grad = True
optimizer = optim.Adam(
    (p for p in llm.model.parameters() if p.requires_grad), lr=0.00001
)

for data in tqdm(data_train):
    optimizer.zero_grad()
    # Train on the biased prompt (data["prompt"] would be the unbiased variant).
    content = data["bias_prompt"]
    input_ids = llm.tokenizer(content, return_tensors="pt").input_ids.to(llm.device)
    # Presumably logits are (batch=1, seq_len, vocab): [-1] selects the single
    # sequence and [-1, :] its last position, i.e. the next-token logits.
    logits = llm.model(input_ids)["logits"][-1][-1, :]
    # Class-index target for the gold answer token. This avoids building a
    # vocab x vocab identity matrix just to obtain a one-hot row.
    target = torch.as_tensor(
        llm.label_id_list[data["y_true"]], dtype=torch.long, device=logits.device
    )
    # cross_entropy applies log_softmax internally, so it must receive raw
    # logits (softmaxing first would apply softmax twice and flatten gradients).
    # Upcast to float32 so the loss is numerically stable under half precision.
    loss = F.cross_entropy(logits.float().unsqueeze(0), target.view(1))
    loss.backward()
    optimizer.step()
print("====finish===")
# NOTE(review): this re-evaluates on the training samples; data_test is built
# above but never evaluated -- confirm whether held-out accuracy was intended.
print(llm.eval_model(data_train))
