import argparse
import json
import os
import pickle
import random
import sys
from pathlib import Path
import numpy as np
import torch

# Make the project root (parent of this script's directory) importable so project
# packages such as B_train_Topic_model resolve when this file is run as a script.
_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))
sys.path.append(_PROJECT_ROOT)
print(_PROJECT_ROOT)
from B_train_Topic_model.Topic_XICL.evaluate_mlqa import compute_f1

# os.environ['CUDA_VISIBLE_DEVICES'] = "0"
# Log the visible CUDA setup. get_device_name()/current_device() raise when CUDA is
# unavailable, so only query them when a GPU can actually be used (keeps the module
# importable on CPU-only machines).
print(torch.cuda.is_available())
print(torch.cuda.device_count())
if torch.cuda.is_available():
    print(torch.cuda.get_device_name())
    print(torch.cuda.current_device())
from tqdm import tqdm
from transformers import GenerationConfig, AutoModelForCausalLM, AutoTokenizer, Qwen2Tokenizer

from A_data_preprocess.util.data import load_data, task2lang

# Looks like multiple-choice option letters.
# NOTE(review): not referenced anywhere else in this file — possibly a leftover from a
# multiple-choice evaluation script; confirm before removing.
choices = ["A", "B", "C", "D"]


def set_seed(seed, set_gpu=True):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: an int, or anything ``int()`` accepts (e.g. the string "32" produced by
            splitting ``--seeds``).
        set_gpu: when True and CUDA is available, also seed every GPU and switch cuDNN
            to deterministic, non-benchmarking mode (reproducible but slower).
    """
    seed_value = int(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed_value)

    if not (set_gpu and torch.cuda.is_available()):
        return

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True  # reproducibility at the cost of speed
    cudnn.benchmark = False
    torch.cuda.manual_seed_all(seed_value)


def _format_demo(demo, tokenizer, per_max_len):
    """Render one demonstration as its reformatted input followed by its first gold answer.

    ``demo["output"]`` appears to be a stringified list of answers (e.g. ``"['a', 'b']"``);
    only the first answer is kept. The rendered text is capped at ``per_max_len`` tokens,
    keeping the tail (where the answer is appended).
    """
    first_answer = demo["output"].replace('[\'', '').replace('\']', '').split('\', \'')[0]
    text = demo["input"].replace("Question:", "\nQuestion:").replace("Answer:", "\nAnswer:\n") + first_answer
    # Count only the text's own tokens (no BOS/EOS) so they neither inflate the length
    # nor get decoded back into the demonstration.
    token_ids = tokenizer(text, add_special_tokens=False).input_ids
    if len(token_ids) > per_max_len:
        text = tokenizer.decode(token_ids[len(token_ids) - per_max_len:])
    return text


def load_data(dataset_name, file_path, set_up, src, tgt, seed, k, n_clusters, tokenizer, mode='test', k_num=0, per_max_len=512):
    """Load a pickled evaluation split and prepend in-context demonstrations to each prompt.

    Overrides the ``load_data`` imported from ``A_data_preprocess.util.data``.

    The pickle is looked up in ``<file_path>/process_data/<dataset_name>/<mode>/`` as
    ``<set_up>_src=<src>_tgt=<tgt>_k=<k>_seed=<seed>_class=<n_clusters>.pkl``; if that
    file does not exist, ``<set_up>_src=<src>_s=<tgt>_k=<k>_class=<n_clusters>.pkl``
    is used instead.

    Args:
        dataset_name, mode: select the data sub-directory.
        set_up, src, tgt, k, seed, n_clusters: fields of the pickle file name
            (``src=None`` is mapped to ``'all'``).
        tokenizer: tokenizer used to measure and truncate each demonstration.
        k_num: number of demonstrations to prepend; 0 means none.
        per_max_len: maximum number of tokens kept per demonstration (the tail is
            kept). Accepts an int-like string.

    Returns:
        The loaded examples, each with ``dp['input']`` rewritten in place.

    Raises:
        ValueError: if the data directory does not exist.
        FileNotFoundError: if neither pickle file exists.
    """
    if src is None:
        src = 'all'
    # argparse delivers a str when --per_max_len is passed on the command line.
    per_max_len = int(per_max_len)
    # cluster_560m_100_hard_src=en_tgt=en_k=4_seed=100_class=20.pkl
    file_path = os.path.join(file_path, 'process_data/{}/{}'.format(dataset_name, mode))
    if not os.path.exists(file_path):
        print(file_path)
        raise ValueError('We dont have data for specified target language')

    # NOTE: pickle.load can execute arbitrary code; only point file_path at trusted,
    # locally generated data.
    file_name = os.path.join(file_path,
                             '{}_src={}_tgt={}_k={}_seed={}_class={}.pkl'.format(set_up, src, tgt, k, seed,
                                                                                 n_clusters))
    try:
        with open(file_name, 'rb') as handle:
            data = pickle.load(handle)
    except FileNotFoundError:
        # No seed-specific file: fall back to the seed-independent naming scheme.
        file_name = os.path.join(file_path,
                                 '{}_src={}_s={}_k={}_class={}.pkl'.format(set_up, src, tgt, k,
                                                                           n_clusters))
        with open(file_name, 'rb') as handle:
            data = pickle.load(handle)

    for dp in tqdm(data, desc='Tokenising Data'):
        demonstrations = "\n"
        if k_num > 0:
            demonstrations = "\nFor example:"
            # Some data files store the demonstrations under 'demons' instead of 'demos'.
            demos = dp['demos'] if 'demos' in dp else dp['demons']
            for it, demo in enumerate(demos):
                if it >= k_num:
                    break
                demonstrations += "\n\n" + _format_demo(demo, tokenizer, per_max_len)

        question = dp['input'].replace("Question:", "\nQuestion:").replace("Answer:", "\nAnswer:")
        dp['input'] = demonstrations + ("\n\n" if demonstrations != "\n" else "") + question

    return data


def _score_generations(gen_datas, lang):
    """Score one language's generations with substring exact match and max-over-answers F1.

    Args:
        gen_datas: records read back from the generation jsonl; each has ``output_str``
            and ``correct_answer_num`` (a list of gold answers, or a single answer).
        lang: language code; for ``"ko"`` spaces are ignored in exact match.

    Returns:
        ``(correct_results, wrong_results, f1_sum)``: the annotated records split by
        exact-match outcome, and the sum (not mean) of per-example F1.
    """
    correct_results = []
    wrong_results = []
    f1_sum = 0.0
    for gen in gen_datas:
        result = dict(**gen, is_correct=False)
        pred = result["output_str"].strip()
        # Cut off text generated after the answer (a new passage or an echoed system turn).
        if "\nPassage:" in pred:
            pred = pred.split("\nPassage:")[0].strip()
        if "system\nYou" in pred:
            pred = pred.split("system\nYou")[0].strip()
        answers = result["correct_answer_num"]
        if not isinstance(answers, list):
            answers = [answers]
        result["correct_answer_num"] = answers
        f1_sum += max(compute_f1(a.lower().replace("\"", ""), pred.lower(), lang) for a in answers)

        # Exact match: some normalised gold answer occurs inside the normalised prediction.
        if lang == "ko":
            pred = pred.replace(" ", "")
        pred = pred.replace("–", "-")
        result["extract_pred_num"] = pred
        pred_lower = pred.lower()
        for truth in answers:
            if lang == "ko":
                truth = truth.replace(" ", "")
            truth = truth.replace("–", "-")
            if truth.lower().replace("\"", "") in pred_lower:
                result["is_correct"] = True
                break
        (correct_results if result["is_correct"] else wrong_results).append(result)
    return correct_results, wrong_results, f1_sum


def main(
        args,
        is_fp16: bool = True,
        save_dir: str = None,
):
    """Generate answers and score EM/F1 for every (seed, k_num, language) combination.

    For each seed, outputs go to
    ``<two directories above this file>/<model name>/<dataset_name>/<set_up>_<seed>/``:
    per-language generations in ``gen_<lang>_datas_<k_num>[_use_logic]_<seed>.jsonl``
    (appended to, so an interrupted run resumes from the existing line count unless
    ``args.overwrite_results`` is set) and a per-k summary in
    ``gen_datas_<k_num>_use_logic.txt``.

    Args:
        args: parsed CLI namespace from the ``__main__`` block. For ``cluster_1b7`` /
            ``cluster_560m`` set-ups, ``args.set_up`` is rewritten per seed (seed suffix).
        is_fp16: load the model weights in float16.
        save_dir: ignored; the output directory is always derived as above. Kept for
            backward compatibility.
    """
    batch_size = args.batch_size
    print(f"main start, is_fp16:{is_fp16}, batch_size:{batch_size}")

    model_path = args.model_path
    model, tokenizer = get_model(model_path, is_fp16=is_fp16)
    print("model loaded")

    batch_llama = get_batch_llama(model, tokenizer)
    use_chat_template = tokenizer.chat_template is not None

    langs = task2lang[args.dataset_name] if args.lang_only is None else args.lang_only
    seeds = [32] if "sim_in" in args.set_up else args.seeds.split(',')

    set_up_init = args.set_up
    output_root = os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))
    model_name = model_path.rstrip("/").split("/")[-1]
    for s in seeds:
        # Name the run directory after the ORIGINAL set-up. args.set_up may still carry the
        # previous seed's suffix (set below), which used to yield dirs like "<set_up>_32_44".
        run_dir = os.path.join(output_root, model_name, args.dataset_name, set_up_init + "_" + str(s))
        Path(run_dir).mkdir(parents=True, exist_ok=True)
        set_seed(s)

        if "cluster_1b7" in set_up_init or "cluster_560m" in set_up_init:
            # Cluster set-ups look up data files whose set-up name includes the seed.
            args.set_up = set_up_init + u"_{m}".format(m=s)
        k_list = [0, 2, 3, 4] if int(s) == 32 and "random" in args.set_up else [2, 3, 4]

        for k_num in k_list:
            results = {}
            results_f1 = {}
            for lang in langs:
                print(f'===========we are testing in {lang}====================')

                tydiqa_datas = load_data(dataset_name=args.dataset_name,
                                         file_path=args.file_path,
                                         tgt=lang, src="en",
                                         set_up=args.set_up,
                                         seed=s,
                                         k=args.k, mode='test', n_clusters=args.n_clusters, k_num=k_num,
                                         tokenizer=tokenizer,
                                         per_max_len=args.per_max_len)
                if tydiqa_datas:
                    print(tydiqa_datas[0])

                suffix = f"{k_num}_use_logic_{s}" if args.use_logic else f"{k_num}_{s}"
                gen_datas_jsonl = Path(run_dir) / f"gen_{lang}_datas_{suffix}.jsonl"
                if args.overwrite_results:
                    # Generations are appended below; remove stale records first or they
                    # would be scored together with the new ones.
                    if gen_datas_jsonl.exists():
                        gen_datas_jsonl.unlink()
                    start_index = 0
                elif gen_datas_jsonl.exists():
                    # Resume after the records already written by an earlier run.
                    with open(gen_datas_jsonl, encoding="utf-8") as f:
                        start_index = sum(1 for _ in f)
                else:
                    start_index = 0
                print(f"start_index: {start_index}")

                for i in tqdm(range(start_index, len(tydiqa_datas), batch_size)):
                    cur_tydiqa_batch = tydiqa_datas[i: i + batch_size]
                    if use_chat_template:
                        input_str_list, output_str_list = tydiqa_batch_gen(
                            args.use_logic, cur_tydiqa_batch, batch_llama, tokenizer,
                            lang_=lang, model_path=model_path)
                    else:
                        input_str_list, output_str_list = tydiqa_batch_gen_no_chat(
                            args.use_logic, cur_tydiqa_batch, batch_llama,
                            lang_=lang, model_path=model_path)
                    with open(gen_datas_jsonl, "a", encoding="utf-8") as f:
                        for j, (tydiqa_data, input_str, output_str) in enumerate(
                                zip(cur_tydiqa_batch, input_str_list, output_str_list)):
                            # "output" appears to be a stringified list of gold answers.
                            gold_answers = tydiqa_data["output"].replace('[\'', '').replace('\']', '').split('\', \'')
                            json.dump(
                                dict(
                                    index=i + j,
                                    correct_answer_num=gold_answers,
                                    output_str=output_str,
                                    input_str=input_str,
                                ),
                                f,
                                ensure_ascii=False
                            )
                            f.write("\n")

                # calculate acc
                gen_datas = []
                if gen_datas_jsonl.exists():
                    with open(gen_datas_jsonl, encoding="utf-8") as f:
                        gen_datas = [json.loads(line) for line in f]
                if not gen_datas:
                    print(f"no generations for {lang}; skipping scoring")
                    continue

                correct_results, wrong_results, f1_sum = _score_generations(gen_datas, lang)
                n_correct, n_wrong = len(correct_results), len(wrong_results)
                accuracy = n_correct / (n_correct + n_wrong)
                print(f'=======done {lang}============')
                print(f"Accuracy={n_correct}/({n_correct}+{n_wrong})={accuracy}")
                results[lang] = accuracy
                results_f1[lang] = f1_sum / len(gen_datas)
                print("f1 = ", results_f1[lang])

            if not results:
                print(f"no results for k_num={k_num}; summary not written")
                continue
            result_txt = Path(run_dir) / f"gen_datas_{k_num}_use_logic.txt"
            em_avg = sum(results.values()) / len(results)
            f1_avg = sum(results_f1.values()) / len(results_f1)
            with open(result_txt, "w", encoding="utf-8") as f:
                for key, value in results.items():
                    f.write(f"EM {key}={value}\n")
                f.write(f"EM AVG={em_avg}\n")
                for key, value in results_f1.items():
                    f.write(f"F1 {key}={value}\n")
                f.write(f"F1 AVG={f1_avg}\n")
            print(f1_avg)

# Per-language instruction templates used when use_logic is set. Each has a "{data}"
# slot for the example text and ends with an opening quote for the answer.
_LOGIC_PROMPT_TEMPLATES = {
    "en": """Answer the question from the given passage. Your answer should be directly extracted from the passage, and it should be a single entity, name, or number, not a sentence. \n{data}  Based on the passage, the answer to the question is\"""",
    "zh": """回答给定段落中的问题。 你的答案应该直接从文章中直接提取，并且应该是单个实体、名称或数字，而不是句子。 \n{data} 根据该段落，问题的答案是\"""",
    "bg": """প্রদত্ত অনুচ্ছেদ থেকে প্রশ্নের উত্তর দিন। আপনার উত্তরটি অনুচ্ছেদ থেকে সরাসরি তুলে নেওয়া উচিত এবং এটি একটি একক সত্তা, নাম বা সংখ্যা হওয়া উচিত, বাক্য নয়। \n{data} অনুচ্ছেদ অনুযায়ী, প্রশ্নের উত্তর হল \"""",
    "fi": """Vastaa kysymykseen annetusta tekstistä. Vastauksesi tulisi olla suoraan tekstistä poimittu, ja sen tulisi olla yksittäinen entiteetti, nimi tai numero, ei lause. \n{data} Tekstin perusteella kysymyksen vastaus on \"""",
    "id": """Jawablah pertanyaan dari teks yang diberikan. Jawaban Anda harus diambil langsung dari teks, dan harus berupa satu entitas, nama, atau angka, bukan kalimat. \n{data} Berdasarkan teks tersebut, jawaban dari pertanyaan adalah \"""",
    "ru": """Ответьте на вопрос из данного отрывка. Ваш ответ должен быть напрямую взят из отрывка и должен быть одним объектом, именем или числом, а не предложением. \n{data} Основываясь на отрывке, ответ на вопрос:  \"""",
    "te": """ఇచ్చిన ప్యారాగ్రాఫ్ నుండి ప్రశ్నకు సమాధానం ఇవ్వండి. మీ సమాధానం ప్యారాగ్రాఫ్ నుండి నేరుగా తీసుకోవాలి మరియు అది ఒకే ఒక ఎంటిటీ, పేరు లేదా సంఖ్యగా ఉండాలి, వాక్యంగా ఉండకూడదు. \n{data} ప్యారాగ్రాఫ్ ఆధారంగా, ప్రశ్నకు సమాధానం  \"""",
    "sw": """Jibu swali kutoka kwenye kifungu kilichotolewa. Jibu lako linapaswa kuchukuliwa moja kwa moja kutoka kwenye kifungu, na linapaswa kuwa kitu kimoja, jina, au nambari, si sentensi. \n{data} Kulingana na kifungu, jibu la swali ni \"""",
    "ko": """주어진 구절의 질문에 답하세요. 귀하의 답변은 해당 구절에서 직접 추출되어야 하며, 문장이 아닌 단일 개체, 이름 또는 번호여야 합니다. \n{data} 지문을 바탕으로 질문에 대한 답은\"""",
    "ar": """أجب عن السؤال من المقطع المحدد. يجب أن تستخرج إجابتك مباشرة من المقطع، ويجب أن تكون كيانًا واحدًا أو اسمًا أو رقمًا، وليس جملة. \n{data} بناءً على المقطع، إجابة السؤال هي\"""",
}
# The "bg" entry holds the Bengali prompt; also accept the ISO 639-1 code "bn" for it.
_LOGIC_PROMPT_TEMPLATES["bn"] = _LOGIC_PROMPT_TEMPLATES["bg"]


def tydiqa_batch_gen(
        use_logic, tydiqa_questions, batch_llm, tokenizer, lang_="en", model_path=""
):
    """Build chat-formatted prompts for a batch and run the generator on them.

    Args:
        use_logic: wrap each input in the language-specific extraction instruction.
        tydiqa_questions: examples, each with an ``"input"`` string.
        batch_llm: callable mapping a list of prompt strings to a list of outputs.
        tokenizer: provides ``apply_chat_template``.
        lang_: language code selecting the instruction; unlisted languages use the
            English instruction (the same one the no-chat path uses for every language).
        model_path: for Qwen models the instruction's first line goes into the system
            turn and the raw input into the user turn.

    Returns:
        ``(messages, outputs)``: the per-example chat message lists and the generations.
    """
    if use_logic:
        # Fall back to English instead of failing with an unbound template.
        prompt_template = _LOGIC_PROMPT_TEMPLATES.get(lang_, _LOGIC_PROMPT_TEMPLATES["en"])
    else:
        prompt_template = """{data}"""
    if "qwen" in model_path.lower():
        # Only a real instruction belongs in the system prompt; the bare "{data}"
        # placeholder must not leak into it.
        instruction = prompt_template.split("\n")[0] if use_logic else ""
        input_str_list = [[{"role": "system", "content": "You are a helpful assistant." + instruction},
                           {"role": "user", "content": d["input"]}]
                          for d in tydiqa_questions]
    else:
        input_str_list = [[{"role": "user", "content": prompt_template.format(data=d["input"])}]
                          for d in tydiqa_questions]
    input_str_list_ = [tokenizer.apply_chat_template(input_dict, tokenize=False, add_generation_prompt=True)
                       for input_dict in input_str_list]
    output_str_list = batch_llm(input_str_list_)
    return input_str_list, output_str_list


def tydiqa_batch_gen_no_chat(
        use_logic, tydiqa_questions, batch_llm, lang_="en", model_path=""
):
    """Build plain-text prompts (no chat template) for a batch and run the generator.

    ``lang_`` is accepted for signature parity with ``tydiqa_batch_gen`` but is not
    used: plain-text prompts always use the English instruction.

    Args:
        use_logic: wrap each input in the extraction instruction.
        tydiqa_questions: examples, each with an ``"input"`` string.
        batch_llm: callable mapping a list of prompt strings to a list of outputs.
        model_path: Llama models get an extra answer cue ("... the answer to the
            question is\"") after the input; other models get a trailing newline.

    Returns:
        ``(prompts, outputs)``: the prompt strings fed to ``batch_llm`` and its outputs.
    """
    if not use_logic:
        template = """{data}"""
    elif "llama" in model_path.lower():
        template = """Answer the question from the given passage. Your answer should be directly extracted from the passage, and it should be a single entity, name, or number, not a sentence. \n{data}  Based on the passage, the answer to the question is\""""
    else:
        template = """Answer the question from the given passage. Your answer should be directly extracted from the passage, and it should be a single entity, name, or number, not a sentence. \n{data}\n"""

    prompts = [template.format(data=question["input"]) for question in tydiqa_questions]
    generations = batch_llm(prompts)
    return prompts, generations


def get_batch_llama(model: AutoModelForCausalLM, tokenizer: AutoTokenizer, max_length: int = 4096):
    """Return a closure that greedily generates short continuations for a batch of prompts.

    Args:
        model: causal LM used for generation.
        tokenizer: tokenizer matching ``model``. ``padding=True`` requires it to define
            a pad token (``get_model`` ensures this).
        max_length: maximum number of prompt positions fed to the model; longer padded
            batches keep only their last ``max_length`` positions.

    Returns:
        ``batch_llama(input_strs) -> list[str]``, which decodes at most 32 new tokens per
        prompt with special tokens skipped.
    """
    @torch.inference_mode()
    def batch_llama(input_strs):
        encoded = tokenizer(input_strs, padding=True, return_tensors="pt").to(model.device)
        # Left-truncate to the last max_length positions (a no-op for shorter batches).
        start = max(encoded.input_ids.shape[1] - max_length, 0)
        input_ids = encoded.input_ids[:, start:]
        attention_mask = encoded.attention_mask[:, start:]
        output_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            generation_config=GenerationConfig(
                max_new_tokens=32,
                # Greedy decoding; temperature is unused here, so it is not set
                # (setting it only triggers a generation-config warning).
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
            ),
        )
        # Strip the prompt using the length actually fed to generate(). The untruncated
        # length would also cut off generated tokens whenever truncation happened.
        new_token_ids = output_ids[:, input_ids.shape[1]:].tolist()
        return tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)

    return batch_llama


def get_model(model_path: str, is_fp16: bool = False):
    """Load a causal LM and a left-padding tokenizer, move the model to CUDA and set eval mode.

    If ``model_path`` contains "epoch", the tokenizer is loaded from the part of the path
    before "epoch" (presumably epoch checkpoints do not ship their own tokenizer —
    confirm). A missing pad token is replaced by the EOS token. Tokenizer and model
    details are printed for debugging.

    Args:
        model_path: local path or hub id of the model.
        is_fp16: load the weights as ``torch.float16``; otherwise the library default.

    Returns:
        ``(model, tokenizer)``.
    """
    print(model_path)
    tokenizer_source = model_path.split("epoch")[0] if "epoch" in model_path else model_path
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, padding_side="left")

    print(tokenizer.pad_token)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    print('new pad ', tokenizer.pad_token)
    for attr_name in ("bos_token", "unk_token", "eos_token", "truncation_side", "padding_side"):
        print(getattr(tokenizer, attr_name))

    load_kwargs = {"torch_dtype": torch.float16} if is_fp16 else {}
    model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs).cuda()
    model.eval()
    print(model.dtype)

    return model, tokenizer


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Eval the finetued SFT model")
    parser.add_argument(
        "--model_path",
        type=str,
        help="Path to baseline model",
        default="",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        help="batchsize",
        default=1
    )
    parser.add_argument(
        "--lang_only",
        type=str,
        nargs='+',
        help="specific language to test",
        default=None
    )
    parser.add_argument('--dataset_name', type=str, default='tydiqa', help='The dataset_name')
    parser.add_argument('--k', type=int, default=4)
    parser.add_argument('--seeds', type=str, default='32,44,100')
    parser.add_argument('--set_up', type=str, default='cluster_in_cross')  # ,choices=['sim_in_cross', 'cluster_1b7', 'cluster_in_cross', 'random', 'cluster_1b7_true','cluster_1b7_only_hard', 'cluster_1b7_all']
    parser.add_argument("--n_clusters", type=int, default=20)
    parser.add_argument("--file_path", type=str, default='')
    # use_logic is on by default; a store_true flag with default=True can never turn it
    # off, so --no_use_logic is provided as the explicit switch.
    parser.add_argument('--use_logic', default=True, action='store_true')
    parser.add_argument('--no_use_logic', dest='use_logic', action='store_false')
    parser.add_argument('--overwrite_results', default=False, action='store_true')
    # Typed so values given on the command line reach load_data as ints.
    parser.add_argument('--per_max_len', type=int, default=512)

    args = parser.parse_args()

    # Call main directly. The previous fire.Fire(main(args=args)) passed main's return
    # value (None) to Fire, which then re-processed sys.argv against the caller's
    # namespace after the evaluation had already finished.
    main(args=args)
