import argparse
import json
import os
import pickle
import random
import sys
from pathlib import Path
import numpy as np
import torch

# Make sibling packages (e.g. A_data_preprocess) importable by adding the parent directory of this file.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))

# os.environ['CUDA_VISIBLE_DEVICES'] = "0"
# Report the visible CUDA setup at import time. get_device_name()/current_device() raise when no
# CUDA device is present, so they are guarded to keep the module importable on CPU-only hosts.
print(torch.cuda.is_available())
print(torch.cuda.device_count())
if torch.cuda.is_available():
    print(torch.cuda.get_device_name())
    print(torch.cuda.current_device())
from tqdm import tqdm
from transformers import GenerationConfig, AutoModelForCausalLM, AutoTokenizer, Qwen2Tokenizer

from A_data_preprocess.util.data import load_data, task2lang

# Multiple-choice option labels. NOTE(review): not referenced elsewhere in this file; confirm it is still needed.
choices = list("ABCD")


def set_seed(seed, set_gpu=True):
    """Seed the Python, NumPy and torch RNGs from ``int(seed)``.

    When ``set_gpu`` is true and CUDA is available, also seed every CUDA device and switch
    cuDNN to deterministic, non-benchmark mode (reproducible, but slower).
    """
    seed_value = int(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed_value)

    if not (set_gpu and torch.cuda.is_available()):
        return

    # Necessary for reproducibility; lower performance
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.cuda.manual_seed_all(seed_value)


def _truncate_demo_left(text, tokenizer, max_tokens):
    """Return ``text`` cut down to at most ``max_tokens`` tokens, keeping the END of the string.

    The tail is kept because each demonstration ends with its answer (``str(demo["output"])``).
    Special tokens are not added while measuring, so decoding does not inject BOS/EOS strings.
    """
    token_ids = tokenizer(text, add_special_tokens=False).input_ids
    if len(token_ids) <= max_tokens:
        return text
    return tokenizer.decode(token_ids[len(token_ids) - max_tokens:])


def load_data(dataset_name, file_path, set_up, src, tgt, seed, k, n_clusters, tokenizer, mode='test', k_num=0,
              per_max_len=512):
    """Load a pre-processed evaluation split and prepend in-context demonstrations to every example.

    The pickle is read from ``<file_path>/process_data/<dataset_name>/<mode>/`` as
    ``<set_up>_src=<src>_tgt=<tgt>_k=<k>_seed=<seed>_class=<n_clusters>.pkl``. If that file is
    missing, ``<set_up>_src=<src>_s=<tgt>_k=<k>_class=<n_clusters>.pkl`` is used instead.

    For each example ``dp``, the first ``k_num`` entries of ``dp['demos']`` (or ``dp['demons']``)
    are rendered as ``input + str(output)``, with "The correct option is " replaced by "Output: ".
    Each rendering is left-truncated to ``per_max_len`` tokens. The renderings are joined by blank
    lines and prepended to ``dp['input']``, which is rewritten in place with the same replacement.

    Args:
        dataset_name, mode: select the ``process_data/<dataset_name>/<mode>`` directory.
        file_path: root directory containing ``process_data``.
        set_up, src, tgt, seed, k, n_clusters: fields of the pickle file name. ``src=None`` means 'all'.
        tokenizer: Hugging Face tokenizer, used only to measure and truncate demonstrations.
        k_num: number of demonstrations to prepend (0 = none).
        per_max_len: per-demonstration token budget. ``None`` disables truncation. Strings are
            converted with ``int``.

    Returns:
        The unpickled sequence of examples, with each ``dp['input']`` rewritten.

    Raises:
        ValueError: if the ``process_data/<dataset_name>/<mode>`` directory does not exist.
        FileNotFoundError: if neither candidate pickle exists.
    """
    from itertools import islice

    if src is None:
        src = 'all'
    if per_max_len is not None:
        # The CLI may deliver this as a string; comparing str to int would raise TypeError.
        per_max_len = int(per_max_len)

    # e.g. cluster_560m_100_hard_src=en_tgt=en_k=4_seed=100_class=20.pkl
    data_dir = os.path.join(file_path, 'process_data/{}/{}'.format(dataset_name, mode))
    if not os.path.exists(data_dir):
        print(data_dir)
        raise ValueError('We dont have data for specified target language')

    primary = os.path.join(data_dir,
                           '{}_src={}_tgt={}_k={}_seed={}_class={}.pkl'.format(set_up, src, tgt, k, seed,
                                                                               n_clusters))
    fallback = os.path.join(data_dir,
                            '{}_src={}_s={}_k={}_class={}.pkl'.format(set_up, src, tgt, k, n_clusters))
    # Only a missing primary file triggers the fallback; a corrupt pickle is no longer silently masked.
    # NOTE(security): pickle.load can execute arbitrary code -- only point file_path at trusted data.
    try:
        handle = open(primary, 'rb')
    except FileNotFoundError:
        handle = open(fallback, 'rb')
    with handle:
        data = pickle.load(handle)

    for dp in tqdm(data, desc='Tokenising Data'):
        demo_texts = []
        if k_num > 0:
            # Some pickles store the demonstrations under 'demons' instead of 'demos'.
            demos = dp['demos'] if 'demos' in dp else dp['demons']
            for demo in islice(demos, k_num):
                text = demo["input"].replace("The correct option is ", "Output: ") + str(demo["output"])
                if per_max_len is not None:
                    text = _truncate_demo_left(text, tokenizer, per_max_len)
                demo_texts.append(text)

        demonstrations = "\n\n".join(demo_texts)
        dp['input'] = demonstrations + ("\n\n" if demonstrations != "" else "") + dp['input'].replace(
            "The correct option is ", "Output: ")

    return data


def _count_lines(path):
    """Return the number of lines in the text file at ``path``, or 0 if it does not exist."""
    if not Path(path).exists():
        return 0
    with open(path, encoding="utf-8") as f:
        return sum(1 for _ in f)


def _score_generations(gen_datas_jsonl):
    """Split the records of a generations JSONL file into ``(correct_results, wrong_results)``.

    Each record gains two keys:
    - ``extract_pred_num``: the model output stripped of surrounding whitespace and of every "-".
    - ``is_correct``: true when ``extract_pred_num`` equals ``correct_answer_num``.

    A missing file yields two empty lists.
    """
    correct_results, wrong_results = [], []
    if not Path(gen_datas_jsonl).exists():
        return correct_results, wrong_results
    with open(gen_datas_jsonl, encoding="utf-8") as f:
        for line in f:
            result = {**json.loads(line), "is_correct": False}
            result["extract_pred_num"] = result["output_str"].strip().replace("-", "")
            if result["correct_answer_num"] == result["extract_pred_num"]:
                result["is_correct"] = True
                correct_results.append(result)
            else:
                wrong_results.append(result)
    return correct_results, wrong_results


def main(
        args,
        is_fp16: bool = True,
        save_dir: str = None,
):
    """Evaluate a causal LM on every (seed, k_num, language) combination and write accuracies.

    For each combination, generations are appended to a JSONL file and the file is then scored.
    Existing lines are resumed from unless ``args.overwrite_results`` is set. Per-language
    accuracy and the average are then written to ``gen_datas_<k_num>_use_logic.txt``.

    Args:
        args: parsed CLI namespace (see the ``__main__`` block): model_path, batch_size, lang_only,
            dataset_name, k, seeds, src, set_up, n_clusters, file_path, use_logic,
            overwrite_results, per_max_len, k_list.
        is_fp16: load the model weights in float16 when True.
        save_dir: accepted for backward compatibility but ignored. The output directory is
            recomputed per seed as ``<two dirs above this file>/<model name>/<dataset>/<set_up>_<seed>[_<src>]``.
    """
    batch_size = args.batch_size
    print(f"main start, is_fp16:{is_fp16}, batch_size:{batch_size}")

    model_path = args.model_path
    model, tokenizer = get_model(model_path, is_fp16=is_fp16)
    print("model loaded")

    batch_llama = get_batch_llama(model, tokenizer)

    langs = task2lang[args.dataset_name] if args.lang_only is None else args.lang_only
    seeds = [32] if "sim_in" in args.set_up else args.seeds.split(',')

    set_up_init = args.set_up
    output_root = os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))
    model_name = model_path.rstrip("/").split("/")[-1]

    for s in seeds:
        # Build the run directory from the ORIGINAL set-up name. Previously args.set_up was mutated
        # below, so later seeds got paths like "cluster_1b7_32_44".
        run_name = set_up_init + "_" + str(s)
        if args.src != 'en':
            run_name = f"{run_name}_{args.src}"
        save_dir = os.path.join(output_root, model_name, args.dataset_name, run_name)
        Path(save_dir).mkdir(parents=True, exist_ok=True)
        set_seed(s)

        # Cluster-based set-ups carry the seed in the data-file name passed to load_data.
        if "cluster_1b7" in set_up_init or "cluster_560m" in set_up_init:
            set_up = set_up_init + "_" + str(s)
        else:
            set_up = set_up_init

        if args.k_list == "":
            k_list = [0, 2, 3, 4] if int(s) == 32 and "random" in set_up else [2, 3, 4]
        else:
            k_list = args.k_list.split(',')

        for k_num in map(int, k_list):
            results = {}
            for lang in langs:
                print(f'===========we are testing in {lang}====================')

                xcopa_datas = load_data(dataset_name=args.dataset_name,
                                        file_path=args.file_path,
                                        tgt=lang, src=args.src,
                                        set_up=set_up,
                                        seed=s,
                                        k=args.k, mode='test', n_clusters=args.n_clusters, k_num=k_num,
                                        tokenizer=tokenizer,
                                        per_max_len=args.per_max_len)

                logic_tag = "_use_logic" if args.use_logic else ""
                gen_datas_jsonl = Path(save_dir) / f"gen_{lang}_datas_{k_num}{logic_tag}_{s}.jsonl"
                if args.overwrite_results:
                    # Start from an empty file. Appending below would otherwise mix old and new
                    # generations, and both would be scored.
                    if gen_datas_jsonl.exists():
                        gen_datas_jsonl.unlink()
                    start_index = 0
                else:
                    # Resume: one JSONL line is written per already-processed example.
                    start_index = _count_lines(gen_datas_jsonl)
                print(f"start_index: {start_index}")

                for i in tqdm(range(start_index, len(xcopa_datas), batch_size)):
                    cur_xcopa_batch = xcopa_datas[i: i + batch_size]
                    if tokenizer.chat_template is not None:
                        input_str_list, output_str_list = xcopa_batch_gen(
                            args.use_logic, cur_xcopa_batch, batch_llama, tokenizer,
                            lang_=lang, model_path=model_path,
                        )
                    else:
                        input_str_list, output_str_list = xcopa_batch_gen_no_chat(
                            args.use_logic, cur_xcopa_batch, batch_llama,
                            lang_=lang, model_path=model_path,
                        )
                    # UTF-8 explicitly: records are dumped with ensure_ascii=False.
                    with open(gen_datas_jsonl, "a", encoding="utf-8") as f:
                        for j, (xcopa_data, input_str, output_str) in enumerate(
                                zip(cur_xcopa_batch, input_str_list, output_str_list)
                        ):
                            json.dump(
                                dict(
                                    index=i + j,
                                    correct_answer_num=str(xcopa_data["output"]),
                                    output_str=output_str,
                                    input_str=input_str,
                                ),
                                f,
                                ensure_ascii=False,
                            )
                            f.write("\n")

                # calculate acc
                correct_results, wrong_results = _score_generations(gen_datas_jsonl)
                n_correct, n_wrong = len(correct_results), len(wrong_results)
                accuracy = n_correct / (n_correct + n_wrong) if (n_correct + n_wrong) else 0.0
                print(f'=======done {lang}============')
                print(f"Accuracy={n_correct}/({n_correct}+{n_wrong})={accuracy}")
                results[lang] = accuracy

            average = sum(results.values()) / len(results) if results else 0.0
            # NOTE(review): the "_use_logic" suffix is applied regardless of args.use_logic.
            result_txt = Path(save_dir) / f"gen_datas_{k_num}_use_logic.txt"
            with open(result_txt, "w", encoding="utf-8") as f:
                for key, value in results.items():
                    f.write(f"EM {key}={value}\n")
                f.write(f"EM AVG={average}\n")
            print(average)


def xcopa_batch_gen(
        use_logic, xcopa_questions, batch_llm, tokenizer, lang_="en", model_path=""
):
    """Generate answers for a batch of examples through the tokenizer's chat template.

    Each example's ``"input"`` becomes a user turn. Models whose path contains "qwen"
    (case-insensitive) also get a leading system turn. The templated prompts, rendered with
    ``add_generation_prompt=True``, are passed to ``batch_llm``. ``use_logic`` and ``lang_``
    are accepted but unused.

    Returns:
        (conversations, generations): the message lists, not the templated strings, and the
        list returned by ``batch_llm``.
    """
    add_system_turn = "qwen" in model_path.lower()

    def _conversation(user_text):
        turns = []
        if add_system_turn:
            turns.append({"role": "system", "content": "You are a helpful assistant."})
        turns.append({"role": "user", "content": user_text})
        return turns

    conversations = [_conversation(question["input"]) for question in xcopa_questions]
    prompts = [
        tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        for conversation in conversations
    ]
    generations = batch_llm(prompts)
    return conversations, generations


def xcopa_batch_gen_no_chat(
        use_logic, xcopa_questions, batch_llm, lang_="en", model_path=""
):
    """Generate answers for a batch of examples, feeding each raw ``"input"`` string as the prompt.

    Used when the tokenizer has no chat template. ``use_logic``, ``lang_`` and ``model_path`` are
    accepted but unused.

    Returns:
        (prompts, generations): the list of prompts passed to ``batch_llm`` and its result.
    """
    prompts = []
    for question in xcopa_questions:
        prompts.append(question["input"])
    generations = batch_llm(prompts)
    return prompts, generations


def get_batch_llama(model: AutoModelForCausalLM, tokenizer: AutoTokenizer, max_length: int = 4096):
    """Build a closure that greedily decodes at most 2 new tokens for a batch of prompt strings.

    The returned function pads the prompts into one tensor batch on ``model.device``. It calls
    ``model.generate`` with ``do_sample=False`` and ``pad_token_id = tokenizer.eos_token_id``,
    drops the prompt positions from every output row, and decodes the rest with special tokens
    skipped. ``max_length`` is accepted but unused.

    NOTE(review): dropping the first ``len(prompt)`` positions assumes left padding
    (get_model loads the tokenizer with padding_side="left").
    """
    @torch.inference_mode()
    def batch_llama(input_strs):
        encoded = tokenizer(input_strs, padding=True, return_tensors="pt").to(model.device)
        prompt_ids = encoded.input_ids
        gen_config = GenerationConfig(
            max_new_tokens=2,
            do_sample=False,
            temperature=0.0,
            pad_token_id=tokenizer.eos_token_id,
        )
        with torch.no_grad():
            sequences = model.generate(
                input_ids=prompt_ids,
                attention_mask=encoded.attention_mask,
                generation_config=gen_config,
            ).tolist()
            # Keep only the freshly generated tokens of each row.
            continuations = [seq[len(prompt_ids[row]):] for row, seq in enumerate(sequences)]
            return tokenizer.batch_decode(continuations, skip_special_tokens=True)

    return batch_llama


def get_model(model_path: str, is_fp16: bool = False):
    """Load a left-padding tokenizer and a causal LM from ``model_path``, moving the model to CUDA.

    Tokenizer loading:
    - If ``model_path`` contains "epoch", the tokenizer is loaded from the part of the path
      before the first "epoch".
    - A missing pad token is replaced by the EOS token.
    - Several tokenizer attributes are printed for inspection.

    Model loading: the model is loaded in float16 when ``is_fp16`` is true (otherwise with the
    default dtype) and put in eval mode.

    Returns:
        (model, tokenizer)
    """
    print(model_path)
    # str.split returns the whole string when "epoch" is absent, so this also covers plain paths.
    tokenizer_path = model_path.split("epoch")[0]
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, padding_side="left")

    print(tokenizer.pad_token)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    print('new pad ', tokenizer.pad_token)
    for attr_name in ("bos_token", "unk_token", "eos_token", "truncation_side", "padding_side"):
        print(getattr(tokenizer, attr_name))

    load_kwargs = {"torch_dtype": torch.float16} if is_fp16 else {}
    model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs).cuda()
    model.eval()
    print(model.dtype)

    return model, tokenizer


if __name__ == "__main__":
    # Command-line entry point: parse evaluation options and run main().
    parser = argparse.ArgumentParser(description="Eval the finetued SFT model")
    parser.add_argument(
        "--model_path",
        type=str,
        help="Path to baseline model",
        default="/data/home10/shenyl/cached_models/meta-llama/Meta-Llama-3.1-8B/",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        help="batchsize",
        default=4
    )
    parser.add_argument(
        "--lang_only",
        type=str,
        nargs='+',
        help="specific language to test",
        default=None
    )
    parser.add_argument('--dataset_name', type=str, default='xcopa', help='The dataset_name')
    parser.add_argument('--k', type=int, default=4)
    parser.add_argument('--seeds', type=str, default='32,44,100')
    parser.add_argument('--src', type=str, default='en')
    parser.add_argument('--set_up', type=str,
                        default='sim_in_cross')  # ,choices=['sim_in_cross', 'cluster_1b7', 'cluster_in_cross', 'random', 'cluster_1b7_true','cluster_1b7_only_hard', 'cluster_1b7_all']
    parser.add_argument("--n_clusters", type=int, default=5)
    parser.add_argument("--file_path", type=str, default='')
    parser.add_argument('--use_logic', default=False, action='store_true')
    parser.add_argument('--overwrite_results', default=False, action='store_true')
    # type=int: without it a value given on the command line arrives as a str and is later
    # compared against integer token counts.
    parser.add_argument('--per_max_len', type=int, default=256)
    parser.add_argument("--k_list", type=str, default="")

    args = parser.parse_args()

    # argparse has already handled the CLI, so call main directly. The former
    # `fire.Fire(main(args=args))` ran main() first and then passed its return value (None) to Fire.
    main(args=args)
