import ast
import random
from datasets import load_dataset, get_dataset_config_names
import prompt_templates
import argparse
import os
import tools
import json
# Seed the global RNG so the random choices made by this script (e.g. where
# the correct option is inserted among the distractors) are reproducible.
random.seed(74)


def get_traingen_dataset(prompts, chosens, rejecteds, ids, subsets, cut_num=500):
    """Build pairwise-comparison records from the TRAINGEN prompt templates.

    Every (prompt, chosen, rejected) triple produces two records per output
    list: one with the chosen response shown as Response 1 (answer "1") and
    a mirrored one with it shown as Response 2 (answer "2").

    Args:
        prompts, chosens, rejecteds, ids, subsets: parallel iterables; they
            are zipped, so the shortest one bounds the iteration.
        cut_num: maximum number of input triples to convert.

    Returns:
        A ``(NOANS records, ANS records)`` tuple of lists, filled from
        ``TRAINGEN_NOANS_PROMPT1`` and ``TRAINGEN_ANS_PROMPT1`` respectively.
    """

    def fill(template, question, first, second):
        # Substitution order is fixed: question first, then both responses.
        with_question = template.replace("[Question_replace]", question)
        with_first = with_question.replace("[Response1_replace]", first)
        return with_first.replace("[Response2_replace]", second)

    def record(content, sample_id, subset_name, label, question, first, second):
        return {
            "content": content,
            "id": sample_id,
            "subset": subset_name,
            "answer": label,
            "assembly_prompt": question,
            "assembly_response1": first,
            "assembly_response2": second,
        }

    noans_records = []
    ans_records = []
    rows = zip(prompts, chosens, rejecteds, ids, subsets)
    for position, (question, better, worse, sample_id, subset_name) in enumerate(rows, start=1):
        if position > cut_num:
            break

        # (label of the better slot, response shown first, response shown second)
        orderings = (("1", better, worse), ("2", worse, better))

        for label, first, second in orderings:
            text = fill(prompt_templates.TRAINGEN_NOANS_PROMPT1,
                        question, first, second)
            noans_records.append(
                record(text, sample_id, subset_name, label, question, first, second))

        for label, first, second in orderings:
            text = fill(prompt_templates.TRAINGEN_ANS_PROMPT1,
                        question, first, second)
            text = text.replace("[BetterResponse_replace]", "Response " + label)
            ans_records.append(
                record(text, sample_id, subset_name, label, question, first, second))

    return noans_records, ans_records


def get_preference_dataset(prompts, chosens, rejecteds, ids, subsets, cut_num=500):
    """Build pairwise-preference records from the PREFERENCE prompt templates.

    For each (prompt, chosen, rejected) triple two presentations are emitted:
    chosen-first (answer "1") and rejected-first (answer "2"). This is done
    once with ``PREFERENCE_NOANS_PROMPT1`` and once with
    ``PREFERENCE_ANS_PROMPT1``, where the better response is also named in
    the prompt text.

    Args:
        prompts, chosens, rejecteds, ids, subsets: parallel iterables that
            are zipped together.
        cut_num: maximum number of input triples to convert.

    Returns:
        ``(NOANS records, ANS records)``; each list holds two records per
        converted triple.
    """
    without_answer = []
    with_answer = []

    triples = zip(prompts, chosens, rejecteds, ids, subsets)
    for seen, (question, preferred, dispreferred, sample_id, subset_name) in enumerate(triples, start=1):
        if seen > cut_num:
            break

        # Slot assignments: (answer label, Response 1 text, Response 2 text).
        layouts = [
            ("1", preferred, dispreferred),
            ("2", dispreferred, preferred),
        ]

        for label, slot_one, slot_two in layouts:
            body = prompt_templates.PREFERENCE_NOANS_PROMPT1
            body = body.replace("[Question_replace]", question)
            body = body.replace("[Response1_replace]", slot_one)
            body = body.replace("[Response2_replace]", slot_two)
            without_answer.append({
                "content": body,
                "id": sample_id,
                "subset": subset_name,
                "answer": label,
                "assembly_prompt": question,
                "assembly_response1": slot_one,
                "assembly_response2": slot_two,
            })

        for label, slot_one, slot_two in layouts:
            body = prompt_templates.PREFERENCE_ANS_PROMPT1
            body = body.replace("[Question_replace]", question)
            body = body.replace("[Response1_replace]", slot_one)
            body = body.replace("[Response2_replace]", slot_two)
            body = body.replace("[BetterResponse_replace]", f"Response {label}")
            with_answer.append({
                "content": body,
                "id": sample_id,
                "subset": subset_name,
                "answer": label,
                "assembly_prompt": question,
                "assembly_response1": slot_one,
                "assembly_response2": slot_two,
            })

    return without_answer, with_answer


def get_factualqa_dataset(prompts, answers, ids, subsets, cut_num=1000):
    """Build factual-QA prompt records with and without the gold answer.

    Args:
        prompts, answers, ids, subsets: parallel iterables that are zipped
            together. Each answer may be a single value or a list/tuple of
            equally valid answers.
        cut_num: maximum number of questions to convert.

    Returns:
        ``(prompt_NOANS_jsons, prompt_ANS_jsons)``: one record per question
        in each list, filled from ``FQ_NOANS_PROMPT1`` and ``FQ_ANS_PROMPT1``.
        The ``"answer"`` field always carries the answer exactly as given;
        only the text substituted into the ANS prompt is normalized.
    """
    prompt_NOANS_jsons = []
    prompt_ANS_jsons = []
    rows = zip(prompts, answers, ids, subsets)
    for count, (prompt, answer, sample_id, subset) in enumerate(rows, start=1):
        if count > cut_num:
            break

        prompt_NOANS_jsons.append({
            "content": prompt_templates.FQ_NOANS_PROMPT1.replace(
                "[Question_replace]", prompt),
            "id": sample_id,
            "subset": subset,
            "answer": answer,
            "assembly_prompt": prompt,
        })

        # Render the gold answer as text. isinstance() also accepts tuples
        # and list subclasses; str() keeps non-string answers from breaking
        # join()/replace().
        if isinstance(answer, (list, tuple)):
            parts = [str(item) for item in answer]
            if len(parts) > 1:
                answer_str = ("All the following answers are correct: "
                              + ", ".join(parts))
            else:
                # A single alternative is used verbatim; an empty collection
                # yields an empty answer text instead of an IndexError.
                answer_str = parts[0] if parts else ""
        else:
            answer_str = str(answer)

        prompt_ANS_jsons.append({
            "content": prompt_templates.FQ_ANS_PROMPT1.replace(
                "[Question_replace]", prompt).replace(
                "[CorrectAnswer_replace]", answer_str),
            "id": sample_id,
            "subset": subset,
            "answer": answer,
            "assembly_prompt": prompt,
        })
    return prompt_NOANS_jsons, prompt_ANS_jsons


def get_factualqa_RAG_dataset(prompts, answers, documents, ids, subsets, cut_num=1000,
                              max_documents=10):
    """Build retrieval-augmented factual-QA records with and without the answer.

    Args:
        prompts, answers, documents, ids, subsets: parallel iterables that
            are zipped together. Each entry of ``documents`` is a sliceable
            sequence of document texts for that question; each answer may be
            a single value or a list/tuple of equally valid answers.
        cut_num: maximum number of questions to convert.
        max_documents: how many leading documents are rendered into the
            prompt (``None`` keeps all of them). Defaults to 10.

    Returns:
        ``(prompt_NOANS_jsons, prompt_ANS_jsons)``: one record per question
        in each list, filled from ``FQ_RAG_NOANS_PROMPT1`` and
        ``FQ_RAG_ANS_PROMPT1``. The ``"answer"`` field carries the answer
        exactly as given; ``"assembly_documents"`` holds the rendered
        document text.
    """
    prompt_NOANS_jsons = []
    prompt_ANS_jsons = []
    rows = zip(prompts, answers, documents, ids, subsets)
    for count, (prompt, answer, document, sample_id, subset) in enumerate(rows, start=1):
        if count > cut_num:
            break

        # Number the kept documents from 1 and concatenate them in one pass.
        documents_str = "".join(
            'Document {}:\n {}\n'.format(rank, text)
            for rank, text in enumerate(document[:max_documents], start=1)
        )

        prompt_NOANS_jsons.append({
            "content": prompt_templates.FQ_RAG_NOANS_PROMPT1.replace(
                "[Question_replace]", prompt).replace(
                "[RetrievedDocuments_replace]", documents_str),
            "id": sample_id,
            "subset": subset,
            "answer": answer,
            "assembly_prompt": prompt,
            "assembly_documents": documents_str,
        })

        # Render the gold answer as text. isinstance() also accepts tuples
        # and list subclasses; str() keeps non-string answers from breaking
        # join()/replace().
        if isinstance(answer, (list, tuple)):
            parts = [str(item) for item in answer]
            if len(parts) > 1:
                answer_str = ("All the following answers are correct: "
                              + ", ".join(parts))
            else:
                # A single alternative is used verbatim; an empty collection
                # yields an empty answer text instead of an IndexError.
                answer_str = parts[0] if parts else ""
        else:
            answer_str = str(answer)

        prompt_ANS_jsons.append({
            "content": prompt_templates.FQ_RAG_ANS_PROMPT1.replace(
                "[Question_replace]", prompt).replace(
                "[CorrectAnswer_replace]", answer_str).replace(
                "[RetrievedDocuments_replace]", documents_str),
            "id": sample_id,
            "subset": subset,
            "answer": answer,
            "assembly_prompt": prompt,
            "assembly_documents": documents_str,
        })
    return prompt_NOANS_jsons, prompt_ANS_jsons


def get_math_dataset(prompts, answers, ids, subsets, cut_num=1000):
    """Build math prompt records with and without the gold answer.

    The final answer is taken as the text after the last ``####`` marker in
    each gold answer (the whole value when no marker is present), stripped
    of surrounding whitespace.

    Args:
        prompts, answers, ids, subsets: parallel iterables that are zipped
            together. Answers may be strings or other values (e.g. numbers);
            non-strings are converted with ``str()`` first.
        cut_num: maximum number of problems to convert.

    Returns:
        ``(prompt_NOANS_jsons, prompt_ANS_jsons)``: one record per problem
        in each list, filled from ``MATH_NOANS_PROMPT1`` and
        ``MATH_ANS_PROMPT1``. The ``"answer"`` field holds the extracted
        final-answer string.
    """
    prompt_NOANS_jsons = []
    prompt_ANS_jsons = []
    rows = zip(prompts, answers, ids, subsets)
    for count, (prompt, answer, sample_id, subset) in enumerate(rows, start=1):
        if count > cut_num:
            break

        # str() first: a numeric gold answer has no .split() and cannot be
        # passed to str.replace() below.
        answer = str(answer).split("####")[-1].strip()

        prompt_NOANS_jsons.append({
            "content": prompt_templates.MATH_NOANS_PROMPT1.replace(
                "[Question_replace]", prompt),
            "id": sample_id,
            "subset": subset,
            "answer": answer,
            "assembly_prompt": prompt,
        })

        prompt_ANS_jsons.append({
            "content": prompt_templates.MATH_ANS_PROMPT1.replace(
                "[Question_replace]", prompt).replace(
                "[CorrectAnswer_replace]", answer),
            "id": sample_id,
            "subset": subset,
            "answer": answer,
            "assembly_prompt": prompt,
        })
    return prompt_NOANS_jsons, prompt_ANS_jsons


def get_multiplechoice_dataset(prompts, answers, choices, ids, subsets, cut_num=1000):
    """Build multiple-choice prompt records with and without the gold answer.

    The options of each question are rendered one per line as
    ``<A> text``, ``<B> text``, ... in the order given.

    Args:
        prompts, answers, choices, ids, subsets: parallel iterables that are
            zipped together; each entry of ``choices`` is the iterable of
            options for that question.
        cut_num: maximum number of questions to convert.

    Returns:
        ``(NOANS records, ANS records)``: one record per question in each
        list, filled from ``MULTIPLECHOICE_NOANS_PROMPT1`` and
        ``MULTIPLECHOICE_ANS_PROMPT1``. ``"assembly_choices"`` keeps the
        options as passed in, not the rendered text.
    """
    noans_records, ans_records = [], []

    rows = zip(prompts, answers, choices, ids, subsets)
    for position, (question, gold, options, sample_id, subset_name) in enumerate(rows, start=1):
        if position > cut_num:
            break

        # Letter the options consecutively starting at "A".
        options_block = "".join(
            f"<{chr(ord('A') + rank)}> {option}\n"
            for rank, option in enumerate(options)
        )

        # Fields common to both variants; "content" is placed in front so
        # the key order of the records stays the same.
        shared = {
            "id": sample_id,
            "subset": subset_name,
            "answer": gold,
            "assembly_prompt": question,
            "assembly_choices": options,
        }

        noans_text = prompt_templates.MULTIPLECHOICE_NOANS_PROMPT1
        noans_text = noans_text.replace("[Question_replace]", question)
        noans_text = noans_text.replace("[Choices_replace]", options_block)
        noans_records.append({"content": noans_text, **shared})

        ans_text = prompt_templates.MULTIPLECHOICE_ANS_PROMPT1
        ans_text = ans_text.replace("[Question_replace]", question)
        ans_text = ans_text.replace("[CorrectAnswer_replace]", gold)
        ans_text = ans_text.replace("[Choices_replace]", options_block)
        ans_records.append({"content": ans_text, **shared})

    return noans_records, ans_records


def get_code_execution_dataset(codes, inputs, answers, ids, subsets, cut_num=1000):
    """Build code-execution prompt records with and without the gold output.

    Args:
        codes, inputs, answers, ids, subsets: parallel iterables that are
            zipped together (source code, its input, and the expected
            answer per sample).
        cut_num: maximum number of samples to convert.

    Returns:
        ``(NOANS records, ANS records)``: one record per sample in each
        list, filled from ``CODE_EXECUTION_NOANS_PROMPT1`` and
        ``CODE_EXECUTION_ANS_PROMPT1``.
    """
    noans_records = []
    ans_records = []

    samples = zip(codes, inputs, answers, ids, subsets)
    for position, (source, call_input, expected, sample_id, subset_name) in enumerate(samples, start=1):
        if position > cut_num:
            break

        def as_record(content):
            # Both variants share every field except the prompt text.
            return {
                "content": content,
                "id": sample_id,
                "subset": subset_name,
                "answer": expected,
                "assembly_code": source,
                "assembly_input": call_input,
            }

        question_only = prompt_templates.CODE_EXECUTION_NOANS_PROMPT1
        question_only = question_only.replace("[Function_replace]", source)
        question_only = question_only.replace("[Input_replace]", call_input)
        noans_records.append(as_record(question_only))

        with_answer = prompt_templates.CODE_EXECUTION_ANS_PROMPT1
        with_answer = with_answer.replace("[Function_replace]", source)
        with_answer = with_answer.replace("[Input_replace]", call_input)
        with_answer = with_answer.replace("[CorrectAnswer_replace]", expected)
        ans_records.append(as_record(with_answer))

    return noans_records, ans_records


def get_code_execution_dataset_codescope(codes, inputs, answers, ids, subsets, cut_num=1000):
    """Build code-execution records using the second (programme) templates.

    Same record layout as the function-based variant, but the source is
    substituted for ``[Programme_replace]`` in
    ``CODE_EXECUTION_NOANS_PROMPT2`` / ``CODE_EXECUTION_ANS_PROMPT2``.

    Args:
        codes, inputs, answers, ids, subsets: parallel iterables that are
            zipped together.
        cut_num: maximum number of samples to convert.

    Returns:
        ``(NOANS records, ANS records)``: one record per sample in each list.
    """

    def render(template, programme, stdin_text):
        # Programme first, then its input, matching the template contract.
        return template.replace("[Programme_replace]", programme).replace(
            "[Input_replace]", stdin_text)

    questions_only = []
    questions_with_answer = []

    samples = zip(codes, inputs, answers, ids, subsets)
    for position, (programme, stdin_text, expected, sample_id, subset_name) in enumerate(samples, start=1):
        if position > cut_num:
            break

        metadata = {
            "id": sample_id,
            "subset": subset_name,
            "answer": expected,
            "assembly_code": programme,
            "assembly_input": stdin_text,
        }

        hidden = render(prompt_templates.CODE_EXECUTION_NOANS_PROMPT2,
                        programme, stdin_text)
        questions_only.append({"content": hidden, **metadata})

        revealed = render(prompt_templates.CODE_EXECUTION_ANS_PROMPT2,
                          programme, stdin_text)
        revealed = revealed.replace("[CorrectAnswer_replace]", expected)
        questions_with_answer.append({"content": revealed, **metadata})

    return questions_only, questions_with_answer


def get_readanswer_choice_dataset(questions, contexts, answers, ids, subsets, cut_num=1000):
    """Build reading-comprehension (choice) records with and without the answer.

    Args:
        questions, contexts, answers, ids, subsets: parallel iterables that
            are zipped together.
        cut_num: maximum number of samples to convert.

    Returns:
        ``(NOANS records, ANS records)``: one record per sample in each
        list, filled from ``READANSWER_CHOICE_NOANS_PROMPT1`` and
        ``READANSWER_CHOICE_ANS_PROMPT1``.
    """
    noans_records = []
    ans_records = []

    samples = zip(questions, contexts, answers, ids, subsets)
    for position, (query, passage, gold, sample_id, subset_name) in enumerate(samples, start=1):
        if position > cut_num:
            break

        # Substitutions are applied in a fixed order: question, context,
        # and (ANS variant only) the correct answer.
        noans_text = prompt_templates.READANSWER_CHOICE_NOANS_PROMPT1
        for placeholder, value in (("[Question_replace]", query),
                                   ("[Context_replace]", passage)):
            noans_text = noans_text.replace(placeholder, value)

        noans_records.append({
            "content": noans_text,
            "id": sample_id,
            "subset": subset_name,
            "answer": gold,
            "assembly_question": query,
            "assembly_context": passage,
        })

        ans_text = prompt_templates.READANSWER_CHOICE_ANS_PROMPT1
        for placeholder, value in (("[Question_replace]", query),
                                   ("[Context_replace]", passage),
                                   ("[CorrectAnswer_replace]", gold)):
            ans_text = ans_text.replace(placeholder, value)

        ans_records.append({
            "content": ans_text,
            "id": sample_id,
            "subset": subset_name,
            "answer": gold,
            "assembly_question": query,
            "assembly_context": passage,
        })

    return noans_records, ans_records


def get_readanswer_dataset(questions, contexts, answers, ids, subsets, cut_num=1000):
    """Build reading-comprehension records with and without the gold answer.

    Args:
        questions, contexts, answers, ids, subsets: parallel iterables that
            are zipped together.
        cut_num: maximum number of samples to convert.

    Returns:
        ``(NOANS records, ANS records)``: one record per sample in each
        list, filled from ``READANSWER_NOANS_PROMPT1`` and
        ``READANSWER_ANS_PROMPT1``.
    """

    def make_record(content, sample_id, subset_name, gold, query, passage):
        return {
            "content": content,
            "id": sample_id,
            "subset": subset_name,
            "answer": gold,
            "assembly_question": query,
            "assembly_context": passage,
        }

    hidden_answer_records = []
    shown_answer_records = []

    samples = zip(questions, contexts, answers, ids, subsets)
    for position, (query, passage, gold, sample_id, subset_name) in enumerate(samples, start=1):
        if position > cut_num:
            break

        hidden = (prompt_templates.READANSWER_NOANS_PROMPT1
                  .replace("[Question_replace]", query)
                  .replace("[Context_replace]", passage))
        hidden_answer_records.append(
            make_record(hidden, sample_id, subset_name, gold, query, passage))

        shown = (prompt_templates.READANSWER_ANS_PROMPT1
                 .replace("[Question_replace]", query)
                 .replace("[Context_replace]", passage)
                 .replace("[CorrectAnswer_replace]", gold))
        shown_answer_records.append(
            make_record(shown, sample_id, subset_name, gold, query, passage))

    return hidden_answer_records, shown_answer_records


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate prompts for Reasoning Gen RM dataset")
    parser.add_argument("--dataset_path", required=True, type=str)
    args = parser.parse_args()
    tmp_dir = os.path.basename(args.dataset_path)
    OUTPUT_DIR = os.path.join("../data/", tmp_dir)
    if not os.path.exists(OUTPUT_DIR):
        os.makedirs(OUTPUT_DIR)

    if tmp_dir.endswith("datasets--openai--gsm8k"):
        dataset = load_dataset(args.dataset_path, 'main')
    elif tmp_dir.endswith("datasets--Idavidrein--gpqa"):
        dataset = load_dataset(args.dataset_path, 'gpqa_diamond')
    elif tmp_dir.endswith("datasets--RUC-NLPIR--FlashRAG_datasets"):
        raise NotImplementedError
    elif tmp_dir.endswith("datasets--RUC-NLPIR--FlashRAG_datasets@hotpotqa"):
        dataset = load_dataset(args.dataset_path.split("@")[0], 'hotpotqa')
    elif tmp_dir.endswith("datasets--RUC-NLPIR--FlashRAG_datasets@2wikimultihopqa"):
        dataset = load_dataset(
            args.dataset_path.split("@")[0], '2wikimultihopqa')
    elif tmp_dir.endswith("datasets--RUC-NLPIR--FlashRAG_datasets@hotpotqa_RAG"):
        dataset = load_dataset(args.dataset_path.split("@")[0], 'hotpotqa')
    elif tmp_dir.endswith("datasets--RUC-NLPIR--FlashRAG_datasets@2wikimultihopqa_RAG"):
        dataset = load_dataset(
            args.dataset_path.split("@")[0], '2wikimultihopqa')
    elif tmp_dir.endswith("datasets--Fsoft-AIC--CodeMMLU"):
        configs = get_dataset_config_names(args.dataset_path)
        dataset = {}
        for cfg in configs:
            print(f"Loading config: {cfg}")
            dataset[cfg] = load_dataset(args.dataset_path, cfg)
    elif "data/datasets" in args.dataset_path:
        dataset = args.dataset_path
        tmp_dir = dataset
    else:
        dataset = load_dataset(args.dataset_path)
    print("raw")

    prompts = None
    chosens = None
    rejecteds = None
    ids = None
    subsets = None
    prompt_ANS_jsons = None
    prompt_NOANS_jsons = None
    if tmp_dir.endswith("datasets--allenai--reward-bench"):
        print("datasets--allenai--reward-bench")
        subsets = ['alpacaeval-easy',
                   'alpacaeval-length',
                   "alpacaeval-hard",  # 2
                   "mt-bench-easy",  # 3
                   "mt-bench-med",  # 4
                   "mt-bench-hard",  # 5
                   "refusals-dangerous",  # 6
                   "refusals-offensive",  # 7
                   "llmbar-natural",  # 8
                   "llmbar-adver-neighbor",
                   "llmbar-adver-GPTInst",
                   "llmbar-adver-GPTOut",
                   "llmbar-adver-manual",
                   "xstest"]
        dataset = dataset.filter(lambda ex: ex["subset"] in subsets)
        print("filtered")
        print(dataset)
        dataset['filtered'] = dataset['filtered'].shuffle(seed=74)
        prompts = dataset['filtered']['prompt']
        chosens = dataset['filtered']['chosen']
        rejecteds = dataset['filtered']['rejected']
        ids = dataset['filtered']['id']
        subsets = dataset['filtered']['subset']
        prompt_NOANS_jsons, prompt_ANS_jsons = get_preference_dataset(
            prompts, chosens, rejecteds, ids, subsets)

    elif tmp_dir.endswith("datasets--HuggingFaceH4--ultrafeedback_binarized"):
        print("datasets--HuggingFaceH4--ultrafeedback_binarized")
        print(dataset)
        prompts = dataset['train_prefs']['prompt']
        chosens = [each[1]['content']
                   for each in dataset['train_prefs']['chosen']]
        rejecteds = [each[1]['content']
                     for each in dataset['train_prefs']['rejected']]
        ids = dataset['train_prefs']['prompt_id']
        subsets = ['train_prefs'] * len(prompts)
        prompt_NOANS_jsons, prompt_ANS_jsons = get_preference_dataset(
            prompts, chosens, rejecteds, ids, subsets)

    elif tmp_dir.endswith("datasets--RUC-NLPIR--FlashRAG_datasets@2wikimultihopqa"):
        print("datasets--RUC-NLPIR--FlashRAG_datasets@2wikimultihopqa")
        print(dataset)
        # train
        prompts = [each['question'] for each in dataset['train']]
        answers = [each['golden_answers'][0] for each in dataset['train']]
        ids = [each['id'] for each in dataset['train']]
        subsets = [each['metadata']["type"] for each in dataset['train']]
        prompt_NOANS_jsons, prompt_ANS_jsons = get_factualqa_dataset(
            prompts, answers, ids, subsets)

    elif tmp_dir.endswith("datasets--RUC-NLPIR--FlashRAG_datasets@hotpotqa"):
        print("datasets--RUC-NLPIR--FlashRAG_datasets@hotpotqa")
        print(dataset)
        # train
        prompts = [each['question'] for each in dataset['train']]
        answers = [each['golden_answers'][0] for each in dataset['train']]
        ids = [each['id'] for each in dataset['train']]
        subsets = [each['metadata']["type"] for each in dataset['train']]
        prompt_NOANS_jsons, prompt_ANS_jsons = get_factualqa_dataset(
            prompts, answers, ids, subsets)

    elif tmp_dir.endswith("datasets--openai--gsm8k"):
        print("datasets--openai--gsm8k")
        print(dataset)
        prompts = dataset['train']['question']
        answers = dataset['train']['answer']
        ids = [i for i in range(len(prompts))]
        subsets = ['main-train'] * len(prompts)
        prompt_NOANS_jsons, prompt_ANS_jsons = get_math_dataset(
            prompts, answers, ids, subsets)

    elif tmp_dir.endswith("datasets--THU-KEG--RM-Bench"):
        print("datasets--THU-KEG--RM-Bench")
        print(dataset)

        prompts = dataset['train']['prompt']
        chosens = [each[0]
                   for each in dataset['train']['chosen']]
        rejecteds = [each[0]
                     for each in dataset['train']['rejected']]
        ids = dataset['train']['id']
        subsets = ids = dataset['train']['domain']

        prompt_NOANS_jsons, prompt_ANS_jsons = get_preference_dataset(
            prompts, chosens, rejecteds, ids, subsets)

    elif tmp_dir.endswith("datasets--math-ai--aime24"):
        print("datasets--math-ai--aime24")
        print(dataset)

        prompts = dataset['test']['problem']
        answers = [each.split(r"{")[1].split(r"}")[0]
                   for each in dataset['test']['solution']]
        ids = dataset['test']['id']
        subsets = ['main-test30*20=600'] * len(prompts)
        # 30*10=300
        prompts = [x for x in prompts for _ in range(5)]
        answers = [x for x in answers for _ in range(5)]
        ids = [x for x in ids for _ in range(5)]
        subsets = [x for x in subsets for _ in range(5)]
        prompt_NOANS_jsons, prompt_ANS_jsons = get_math_dataset(
            prompts, answers, ids, subsets)

    elif tmp_dir.endswith("datasets--math-ai--amc23"):
        print("datasets--math-ai--amc23")
        print(dataset)

        prompts = dataset['test']['question']
        answers = [each
                   for each in dataset['test']['answer']]
        ids = dataset['test']['id']
        subsets = ['main-test40*10=400'] * len(prompts)
        # 40*10=400
        prompts = [x for x in prompts for _ in range(5)]
        answers = [x for x in answers for _ in range(5)]
        ids = [x for x in ids for _ in range(5)]
        subsets = [x for x in subsets for _ in range(5)]
        prompt_NOANS_jsons, prompt_ANS_jsons = get_math_dataset(
            prompts, answers, ids, subsets)

    elif tmp_dir.endswith("datasets--math-ai--math500"):
        print("datasets--math-ai--math500")
        print(dataset)

        prompts = dataset['test']['problem']
        answers = [each
                   for each in dataset['test']['answer']]
        ids = dataset['test']['unique_id']
        subsets = dataset['test']['subject']
        prompt_NOANS_jsons, prompt_ANS_jsons = get_math_dataset(
            prompts, answers, ids, subsets)

    elif tmp_dir.endswith("datasets--Idavidrein--gpqa"):
        print("datasets--Idavidrein--gpqa")
        print(dataset)

        prompts = dataset['train']["Question"]
        choices = []
        answers = []
        wrong_answers = [[a1, a2, a3]
                         for a1, a2, a3 in zip(dataset['train']["Incorrect Answer 1"], dataset['train']["Incorrect Answer 2"], dataset['train']["Incorrect Answer 3"])]
        for rank, each in enumerate(dataset['train']['Correct Answer']):
            answer_rank = random.randint(0, 3)
            tmp = wrong_answers[rank].copy()
            tmp.insert(answer_rank, each)
            choices.append(tmp)
            answers.append(chr(ord("A") + answer_rank))
        ids = dataset['train']['Record ID']
        subsets = dataset['train']['High-level domain']
        # 198
        prompts = [x for x in prompts for _ in range(1)]
        answers = [x for x in answers for _ in range(1)]
        choices = [x for x in choices for _ in range(1)]
        ids = [x for x in ids for _ in range(1)]
        subsets = [x for x in subsets for _ in range(1)]
        prompt_NOANS_jsons, prompt_ANS_jsons = get_multiplechoice_dataset(
            prompts, answers, choices, ids, subsets)

    elif tmp_dir.endswith("datasets--TIGER-Lab--MMLU-Pro"):
        print("datasets--TIGER-Lab--MMLU-Pro")
        print(dataset)

        prompts = dataset['test']["question"]
        choices = dataset['test']["options"]
        answers = dataset['test']["answer"]

        ids = dataset['test']['question_id']
        subsets = dataset['test']['category']
        prompt_NOANS_jsons, prompt_ANS_jsons = get_multiplechoice_dataset(
            prompts, answers, choices, ids, subsets)

    elif tmp_dir.endswith("datasets--Fsoft-AIC--CodeMMLU"):
        print("datasets--Fsoft-AIC--CodeMMLU")
        print(dataset)

        prompts = []
        choices = []
        answers = []
        ids = []
        subsets = []
        for subset_name, subset in dataset.items():
            prompts.extend(subset['test']['question'][:100])
            choices.extend(subset['test']['choices'][:100])
            answers.extend(subset['test']['answer'][:100])
            ids.extend(subset['test']['task_id'][:100])
            subsets.extend([subset_name] * 100)
        prompt_NOANS_jsons, prompt_ANS_jsons = get_multiplechoice_dataset(
            prompts, answers, choices, ids, subsets)

    elif tmp_dir.endswith("datasets--livecodebench--execution-v2"):
        print("datasets--livecodebench--execution-v2")
        print(dataset)

        codes = dataset["test"]["code"]
        answers = dataset["test"]["output"]
        inputs = dataset["test"]["input"]
        ids = dataset["test"]["id"]
        subsets = dataset["test"]["difficulty"]

        prompt_NOANS_jsons, prompt_ANS_jsons = get_code_execution_dataset(
            codes, inputs, answers, ids, subsets)

    elif tmp_dir.endswith("datasets--cruxeval-org--cruxeval"):
        print("datasets--cruxeval-org--cruxeval")
        print(dataset)

        codes = dataset["test"]["code"]
        answers = dataset["test"]["output"]
        inputs = ["f("+each+")" for each in dataset["test"]["input"]]
        ids = dataset["test"]["id"]
        subsets = ["test"]*len(codes)

        prompt_NOANS_jsons, prompt_ANS_jsons = get_code_execution_dataset(
            codes, inputs, answers, ids, subsets)

    elif tmp_dir.endswith("datasets--WeixiangYan--CodeScope"):
        print("datasets--WeixiangYan--CodeScope")
        print(dataset)
        codes = []
        answers = []
        inputs = []
        ids = []
        subsets = []
        for each in dataset["test"]:
            for (each_input, each_output) in zip(json.loads(each["sample_inputs"]), json.loads(each["sample_outputs"])):
                inputs.append(each_input)
                answers.append(each_output)
                codes.append(each["source_code"])
                ids.append(each["notes"])
                subsets.append(each["lang_cluster"])

        prompt_NOANS_jsons, prompt_ANS_jsons = get_code_execution_dataset_codescope(
            codes, inputs, answers, ids, subsets)

    elif tmp_dir.endswith("datasets--allenai--qasc"):
        print("datasets--allenai--qasc")
        print(dataset)

        questions = dataset["train"]["formatted_question"]
        contexts = []
        for f1, f2 in zip(dataset["train"]["fact1"], dataset["train"]["fact2"]):
            contexts.append(f1 + "\n" + f2)

        answers = dataset["train"]["answerKey"]
        ids = dataset["train"]["id"]
        subsets = dataset["train"]["combinedfact"]

        prompt_NOANS_jsons, prompt_ANS_jsons = get_readanswer_choice_dataset(
            questions, contexts, answers, ids, subsets)

    elif tmp_dir.endswith("datasets--TAUR-Lab--MuSR"):
        print("datasets--TAUR-Lab--MuSR")
        print(dataset)
        subsets = []
        questions = []
        contexts = []
        answers = []
        ids = []
        for subset in ["team_allocation", "object_placements", "murder_mysteries"]:
            for question_str, choices_list in zip(dataset[subset]["question"], dataset[subset]["choices"]):
                choices_str = ""
                for idx, each in enumerate(ast.literal_eval(choices_list)):
                    choices_str += "<"+chr(ord("A") + idx)+">"+f" {each}\n"
                questions.append(question_str+choices_str)
            contexts += dataset[subset]["narrative"]
            ids += dataset[subset]["choices"]

            answers += [chr(ord("A") + int(each))
                        for each in dataset[subset]["answer_index"]]
            subsets += [subset for _ in range(len(dataset[subset]["choices"]))]

        prompt_NOANS_jsons, prompt_ANS_jsons = get_readanswer_choice_dataset(
            questions, contexts, answers, ids, subsets)

    elif tmp_dir.endswith("datasets--ucinlp--drop"):
        print("datasets--ucinlp--drop")
        print(dataset)

        questions = dataset["train"]["question"]
        contexts = dataset["train"]['passage']

        answers = [each["spans"][0]
                   for each in dataset["train"]["answers_spans"]]
        ids = dataset["train"]["query_id"]
        subsets = ["train"] * len(questions)

        prompt_NOANS_jsons, prompt_ANS_jsons = get_readanswer_dataset(
            questions, contexts, answers, ids, subsets)

    elif tmp_dir.endswith("datasets--allenai--reward-bench-2"):
        # reward-bench-2 branch: build pairwise preference prompts from the
        # "test" split and write train/test/validation JSONL files.
        print("datasets--allenai--reward-bench-2")
        # Keep only examples whose "subset" is one of these five categories.
        # NOTE(review): the lambda below looks up `subsets` by name, and
        # `subsets` is rebound to an empty list further down; this is only
        # correct if filter() evaluates eagerly -- confirm.
        subsets = ['Factuality', "Precise IF", 'Focus', 'Safety', "Math"]
        dataset = dataset.filter(lambda ex: ex["subset"] in subsets)
        print("filtered")
        print(dataset)
        # Fixed seed so the example order (and therefore the positional
        # split below) is the same on every run.
        dataset["test"] = dataset["test"].shuffle(seed=74)
        # Parallel lists: one entry per (prompt, chosen, rejected) pair.
        prompts = []
        chosens = []
        rejecteds = []
        ids = []
        subsets = []
        for prompt, chosen, rejected_list, id, subset in zip(dataset['test']['prompt'], dataset['test']['chosen'], dataset['test']['rejected'], dataset['test']['id'], dataset['test']['subset']):
            # Expand each example into one pair per rejected response; only
            # the first chosen response (chosen[0]) is used.
            for rejected in rejected_list:
                prompts.append(prompt)
                chosens.append(chosen[0])
                rejecteds.append(rejected)
                ids.append(id)
                subsets.append(subset)
        print(len(chosens), len(rejecteds))

        # get_preference_dataset is defined elsewhere in this file;
        # presumably cut_num caps the number of pairs processed -- confirm.
        prompt_NOANS_jsons, prompt_ANS_jsons = get_preference_dataset(
            prompts, chosens, rejecteds, ids, subsets, cut_num=100000)

        # Positional split: first 2/3 of the records -> train, remaining
        # 1/3 -> test, and the same 1/3 is also written as validation.
        # NOTE(review): pairs for the same prompt are appended consecutively
        # above, so a positional split may place records from one prompt in
        # both train and test -- confirm whether that overlap is acceptable.
        train_prompt_NOANS_jsons = prompt_NOANS_jsons[:len(
            prompt_NOANS_jsons) * 2//3]

        # Each output file is truncated first; tools.write_jsonl is then
        # called once per record (presumably appending one line -- confirm).
        open(os.path.join(OUTPUT_DIR, "train_NOANS.jsonl"),
             'w').close()  # clear file
        for each in train_prompt_NOANS_jsons:
            tools.write_jsonl(each, os.path.join(
                OUTPUT_DIR, "train_NOANS.jsonl"))

        test_prompt_NOANS_jsons = prompt_NOANS_jsons[len(
            prompt_NOANS_jsons) * 2//3:]
        open(os.path.join(OUTPUT_DIR, "test_NOANS.jsonl"),
             'w').close()  # clear file
        for each in test_prompt_NOANS_jsons:
            tools.write_jsonl(each, os.path.join(
                OUTPUT_DIR, "test_NOANS.jsonl"))
        # Validation reuses the test records (identical content).
        open(os.path.join(OUTPUT_DIR, "validation_NOANS.jsonl"),
             'w').close()  # clear file
        for each in test_prompt_NOANS_jsons:
            tools.write_jsonl(each, os.path.join(
                OUTPUT_DIR, "validation_NOANS.jsonl"))

        # Same 2/3 - 1/3 split for the with-answer prompts.
        train_prompt_ANS_jsons = prompt_ANS_jsons[:len(
            prompt_ANS_jsons) * 2//3]
        open(os.path.join(OUTPUT_DIR, "train_ANS.jsonl"),
             'w').close()  # clear file
        for each in train_prompt_ANS_jsons:
            tools.write_jsonl(each, os.path.join(
                OUTPUT_DIR, "train_ANS.jsonl"))

        test_prompt_ANS_jsons = prompt_ANS_jsons[len(prompt_ANS_jsons) * 2//3:]
        open(os.path.join(OUTPUT_DIR, "test_ANS.jsonl"), 'w').close()  # clear file
        for each in test_prompt_ANS_jsons:
            tools.write_jsonl(each, os.path.join(OUTPUT_DIR, "test_ANS.jsonl"))
        open(os.path.join(OUTPUT_DIR, "validation_ANS.jsonl"),
             'w').close()  # clear file
        for each in test_prompt_ANS_jsons:
            tools.write_jsonl(each, os.path.join(
                OUTPUT_DIR, "validation_ANS.jsonl"))
        # NOTE(review): the unconditional code after this if/elif chain
        # rewrites test_NOANS.jsonl / test_ANS.jsonl from the full
        # prompt_NOANS_jsons / prompt_ANS_jsons lists, which here still hold
        # every record (train portion included). That replaces the 1/3 test
        # split written above -- confirm this is intended.

    elif tmp_dir.endswith("datasets--RUC-NLPIR--FlashRAG_datasets@2wikimultihopqa_RAG"):
        # 2WikiMultihopQA (FlashRAG) branch: build RAG QA prompts. The
        # "train" split is divided 2/3 train / 1/3 validation; the "dev"
        # split becomes the test set.
        print("datasets--RUC-NLPIR--FlashRAG_datasets@2wikimultihopqa_RAG")
        print(dataset)
        # train
        prompts = [each['question'] for each in dataset['train']]
        # Only the first golden answer of each example is used.
        answers = [each['golden_answers'][0] for each in dataset['train']]
        ids = [each['id'] for each in dataset['train']]
        # One list of document strings per example: every entry of
        # metadata.context.content is joined into one space-separated string.
        documents = []
        for each in dataset['train']:
            tmp = []
            for doc in each['metadata']["context"]['content']:
                tmp.append(' '.join(doc))
            documents.append(tmp)
        # The question type from the metadata serves as the subset label.
        subsets = [each['metadata']["type"] for each in dataset['train']]

        # get_factualqa_RAG_dataset is defined elsewhere in this file;
        # presumably cut_num caps the number of examples used -- confirm.
        prompt_NOANS_jsons, prompt_ANS_jsons = get_factualqa_RAG_dataset(
            prompts, answers, documents, ids, subsets, cut_num=15000)

        # Positional split of the train-derived prompts: first 2/3 -> train,
        # last 1/3 -> validation. Each file is truncated before writing.
        train_prompt_NOANS_jsons = prompt_NOANS_jsons[:len(
            prompt_NOANS_jsons) * 2//3]
        open(os.path.join(OUTPUT_DIR, "train_NOANS.jsonl"),
             'w').close()  # clear file
        for each in train_prompt_NOANS_jsons:
            tools.write_jsonl(each, os.path.join(
                OUTPUT_DIR, "train_NOANS.jsonl"))

        validation_prompt_NOANS_jsons = prompt_NOANS_jsons[len(
            prompt_NOANS_jsons) * 2//3:]
        open(os.path.join(OUTPUT_DIR, "validation_NOANS.jsonl"),
             'w').close()  # clear file
        for each in validation_prompt_NOANS_jsons:
            tools.write_jsonl(each, os.path.join(
                OUTPUT_DIR, "validation_NOANS.jsonl"))

        train_prompt_ANS_jsons = prompt_ANS_jsons[:len(
            prompt_ANS_jsons) * 2//3]
        open(os.path.join(OUTPUT_DIR, "train_ANS.jsonl"),
             'w').close()  # clear file
        for each in train_prompt_ANS_jsons:
            tools.write_jsonl(each, os.path.join(
                OUTPUT_DIR, "train_ANS.jsonl"))

        validation_prompt_ANS_jsons = prompt_ANS_jsons[len(
            prompt_ANS_jsons) * 2//3:]
        open(os.path.join(OUTPUT_DIR, "validation_ANS.jsonl"),
             'w').close()  # clear file
        for each in validation_prompt_ANS_jsons:
            tools.write_jsonl(each, os.path.join(
                OUTPUT_DIR, "validation_ANS.jsonl"))

        # test
        # Same field extraction as above, applied to the "dev" split.
        prompts = [each['question'] for each in dataset['dev']]
        answers = [each['golden_answers'][0] for each in dataset['dev']]
        ids = [each['id'] for each in dataset['dev']]
        documents = []
        for each in dataset['dev']:
            tmp = []
            for doc in each['metadata']["context"]['content']:
                tmp.append(' '.join(doc))
            documents.append(tmp)
        subsets = [each['metadata']["type"] for each in dataset['dev']]

        # This rebinds prompt_NOANS_jsons / prompt_ANS_jsons to the
        # dev-derived prompts, so code after the if/elif chain sees the test
        # records only, not the train ones.
        prompt_NOANS_jsons, prompt_ANS_jsons = get_factualqa_RAG_dataset(
            prompts, answers, documents, ids, subsets, cut_num=1000)

        test_prompt_NOANS_jsons = prompt_NOANS_jsons
        open(os.path.join(OUTPUT_DIR, "test_NOANS.jsonl"),
             'w').close()  # clear file
        for each in test_prompt_NOANS_jsons:
            tools.write_jsonl(each, os.path.join(
                OUTPUT_DIR, "test_NOANS.jsonl"))

        test_prompt_ANS_jsons = prompt_ANS_jsons
        open(os.path.join(OUTPUT_DIR, "test_ANS.jsonl"),
             'w').close()  # clear file
        for each in test_prompt_ANS_jsons:
            tools.write_jsonl(each, os.path.join(
                OUTPUT_DIR, "test_ANS.jsonl"))

    elif tmp_dir.endswith("datasets--RUC-NLPIR--FlashRAG_datasets@hotpotqa_RAG"):
        # HotpotQA (FlashRAG) branch: build RAG QA prompts. "hard"-level
        # examples of the "train" split are divided 2/3 train / 1/3
        # validation; the "dev" split becomes the test set.
        print("datasets--RUC-NLPIR--FlashRAG_datasets@hotpotqa_RAG")
        print(dataset)
        # train
        # NOTE(review): `levels` is never read or updated in this branch --
        # looks like a leftover; confirm before removing.
        levels = {}
        prompts = []
        answers = []
        ids = []
        subsets = []
        documents = []
        for each in dataset['train']:
            # Only train examples whose metadata level is "hard" are kept.
            if each['metadata']["level"] == "hard":
                prompts.append(each['question'])
                # Only the first golden answer is used.
                answers.append(each['golden_answers'][0])
                ids.append(each['id'])
                subsets.append(each['metadata']["type"])

                # Every entry of metadata.context.sentences is joined into
                # one space-separated document string.
                tmp = []
                for doc in each['metadata']["context"]['sentences']:
                    tmp.append(' '.join(doc))
                documents.append(tmp)

        # get_factualqa_RAG_dataset is defined elsewhere in this file;
        # presumably cut_num caps the number of examples used -- confirm.
        prompt_NOANS_jsons, prompt_ANS_jsons = get_factualqa_RAG_dataset(
            prompts, answers, documents, ids, subsets, cut_num=200000)

        # Positional split of the train-derived prompts: first 2/3 -> train,
        # last 1/3 -> validation. Each file is truncated before writing.
        train_prompt_NOANS_jsons = prompt_NOANS_jsons[:len(
            prompt_NOANS_jsons) * 2//3]
        open(os.path.join(OUTPUT_DIR, "train_NOANS.jsonl"),
             'w').close()  # clear file
        for each in train_prompt_NOANS_jsons:
            tools.write_jsonl(each, os.path.join(
                OUTPUT_DIR, "train_NOANS.jsonl"))

        validation_prompt_NOANS_jsons = prompt_NOANS_jsons[len(
            prompt_NOANS_jsons) * 2//3:]
        open(os.path.join(OUTPUT_DIR, "validation_NOANS.jsonl"),
             'w').close()  # clear file
        for each in validation_prompt_NOANS_jsons:
            tools.write_jsonl(each, os.path.join(
                OUTPUT_DIR, "validation_NOANS.jsonl"))

        train_prompt_ANS_jsons = prompt_ANS_jsons[:len(
            prompt_ANS_jsons) * 2//3]
        open(os.path.join(OUTPUT_DIR, "train_ANS.jsonl"),
             'w').close()  # clear file
        for each in train_prompt_ANS_jsons:
            tools.write_jsonl(each, os.path.join(
                OUTPUT_DIR, "train_ANS.jsonl"))

        validation_prompt_ANS_jsons = prompt_ANS_jsons[len(
            prompt_ANS_jsons) * 2//3:]
        open(os.path.join(OUTPUT_DIR, "validation_ANS.jsonl"),
             'w').close()  # clear file
        for each in validation_prompt_ANS_jsons:
            tools.write_jsonl(each, os.path.join(
                OUTPUT_DIR, "validation_ANS.jsonl"))

        # test
        # Unlike the train split, the "dev" split is NOT filtered by level.
        prompts = [each['question'] for each in dataset['dev']]
        answers = [each['golden_answers'][0] for each in dataset['dev']]
        ids = [each['id'] for each in dataset['dev']]
        documents = []
        for each in dataset['dev']:
            tmp = []
            for doc in each['metadata']["context"]['sentences']:
                tmp.append(' '.join(doc))
            documents.append(tmp)
        subsets = [each['metadata']["type"] for each in dataset['dev']]

        # This rebinds prompt_NOANS_jsons / prompt_ANS_jsons to the
        # dev-derived prompts, so code after the if/elif chain sees the test
        # records only, not the train ones.
        prompt_NOANS_jsons, prompt_ANS_jsons = get_factualqa_RAG_dataset(
            prompts, answers, documents, ids, subsets, cut_num=1000)

        test_prompt_NOANS_jsons = prompt_NOANS_jsons
        open(os.path.join(OUTPUT_DIR, "test_NOANS.jsonl"),
             'w').close()  # clear file
        for each in test_prompt_NOANS_jsons:
            tools.write_jsonl(each, os.path.join(
                OUTPUT_DIR, "test_NOANS.jsonl"))

        test_prompt_ANS_jsons = prompt_ANS_jsons
        open(os.path.join(OUTPUT_DIR, "test_ANS.jsonl"),
             'w').close()  # clear file
        for each in test_prompt_ANS_jsons:
            tools.write_jsonl(each, os.path.join(
                OUTPUT_DIR, "test_ANS.jsonl"))

    # Common tail, executed after whichever dataset branch ran above.
    # Fails if either prompt list is still None (presumably their initial
    # value when no branch matched -- confirm against the code above the
    # if/elif chain).
    if prompt_ANS_jsons is None or prompt_NOANS_jsons is None:
        raise ValueError("No prompts found in the dataset.")
    print(f"Number of prompts without answer: {len(prompt_NOANS_jsons)}")
    print(f"Number of prompts with answer: {len(prompt_ANS_jsons)}")
    # Unconditionally (re)write the test files from the current contents of
    # prompt_NOANS_jsons / prompt_ANS_jsons.
    # NOTE(review): branches that already wrote test_*.jsonl have it
    # rewritten here. For the FlashRAG branches the lists hold the dev-derived
    # prompts at this point, so presumably the same content is written again.
    # For the reward-bench-2 branch the lists hold ALL records, so the 1/3
    # test split written there is replaced by the full set -- confirm this
    # is intended.
    test_prompt_NOANS_jsons = prompt_NOANS_jsons
    open(os.path.join(OUTPUT_DIR, "test_NOANS.jsonl"),
         'w').close()  # clear file
    for each in test_prompt_NOANS_jsons:
        tools.write_jsonl(each, os.path.join(
            OUTPUT_DIR, "test_NOANS.jsonl"))

    test_prompt_ANS_jsons = prompt_ANS_jsons
    open(os.path.join(OUTPUT_DIR, "test_ANS.jsonl"), 'w').close()  # clear file
    for each in test_prompt_ANS_jsons:
        tools.write_jsonl(each, os.path.join(OUTPUT_DIR, "test_ANS.jsonl"))

    print(f"Finished writing {tmp_dir}")
