import os
from datasets import Dataset, load_dataset, concatenate_datasets
from tqdm import tqdm
from verl.utils.hdfs_io import copy, makedirs
import argparse
import json


# Disabled prompt variant: answer directly "based on intuition", no reasoning trace.
#INSTRUCT_PROMPT_SYSTEM1 = """You are a helpful assistant. Now the user asks you to solve a problem. Please directly output the answer based on intuition. When providing the answer, do not include the option letters (such as A, B, C, etc.). Instead, directly provide the specific content of the options. For example: Question: What is the capital of France? Choices: A: Berlin B: Paris Answer: Paris (instead of: "B: Paris")"""

# Active system prompt: the model reasons inside <think></think>, answers inside
# <answer></answer>, and must give the option text rather than the option letter.
# The wording (including "then provides") is kept exactly as-is so the prompt
# text does not change.
INSTRUCT_PROMPT_SYSTEM2 = (
    "You are a helpful assistant. "
    "Now the user asks you to solve a problem. "
    "Please first carefully think about this problem and then provides the user with the answer. "
    "Please output the thinking process in <think> </think> and final answer in <answer> </answer> tags. "
    "When providing the answer, do not include the option letters (such as A, B, C, etc.). "
    "Instead, directly provide the specific content of the options. "
    "For example: Question: What is the capital of France? "
    "Choices: A: Berlin B: Paris "
    'Answer: Paris (instead of: "B: Paris")'
)

# Disabled prompt variant: the model adaptively picks one of three answering modes.
#INSTRUCT_PROMPT_ADAPTIVE = """You are a helpful assistant. Now the user asks you to solve a problem. Adaptively choose one of the following three cognitive modes to solve the problem: (1) Based on intuition, directly output the answer in <answer></answer> tags. (2) Perform understanding and analysis of the social context, then provide the user with the answer. Please output the social context understanding and analysis in <social context understanding> </social context understanding> and final answer in <answer> </answer> tags. (3) Carefully think about the problem, then provides the user with the answer. Please output the thinking process in <think> </think> and final answer in <answer> </answer> tags. When providing the answer, do not include the option letters (such as A, B, C, etc.). Instead, directly provide the specific content of the options. For example: Question: What is the capital of France? Choices: A: Berlin B: Paris Answer: Paris (instead of: "B: Paris")"""



if __name__ == '__main__':
    # Command-line interface: only the output directory is configurable.
    # NOTE(review): the default ".../" looks like a redacted placeholder path.
    cli = argparse.ArgumentParser()
    cli.add_argument('--local_dir', type=str, default=".../")
    args = cli.parse_args()
    

    # Load custom JSONL dataset
    def gen_from_jsonl(path):
        """Yield one parsed JSON object per non-blank line of the JSONL file at *path*.

        The file is decoded as UTF-8 explicitly rather than with the locale default,
        and whitespace-only lines are skipped instead of crashing ``json.loads``.
        NOTE(review): not referenced elsewhere in the visible script.

        Raises:
            json.JSONDecodeError: if a non-blank line is not valid JSON.
        """
        with open(path, encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    yield json.loads(line)


    # ------------------------------------------------------------------
    # Raw benchmark loading. The ".../" paths appear to be redacted placeholders.
    # ------------------------------------------------------------------
    hitom_path = ".../Hi-ToM"

    def load_json(data_file, split):
        """Load *split* of a local JSON file through the generic "json" builder."""
        return load_dataset("json", data_files=data_file, split=split)

    def load_hitom(split_specs):
        """Load every Hi-ToM slice in *split_specs*, in order, and concatenate them."""
        return concatenate_datasets([load_dataset(hitom_path, split=spec) for spec in split_specs])

    def load_opentom(file_name):
        """Load a merged OpenToM JSON file, print a summary, and return its 'train' split."""
        opentom = load_dataset("json", data_files=".../OpenToM/" + file_name)['train']
        print(opentom)
        return opentom

    # --- Training sources -------------------------------------------------
    dataset_tomi = load_json(".../test_balanced.json", "train[:800]")

    # Hi-ToM is sliced in 100-row windows starting at 0, 100, 200, 600, 700, 800:
    # rows [0, 60) of each window are for training, rows [60, 80) form the
    # "third" test set and rows [80, 100) the "fourth" test set (loaded below).
    dataset_hitom = load_hitom([
        "train[:60]", "train[100:160]", "train[200:260]",
        "train[600:660]", "train[700:760]", "train[800:860]",
    ])

    dataset_exploretom = load_dataset("csv", data_files=".../ExploreToM-data-sample.csv", split="train[:2000]")
    dataset_ToMbench = load_json(".../train_combined.json", "train[:2420]")
    # NOTE(review): unlike the other placeholders this path starts with "/" — confirm.
    dataset_socialqa = load_json("/.../socialIWa_v1.4_trn_wDims.json", "train[:2000]")

    # --- In-distribution test sources (disjoint from the training slices) --
    dataset_test_tomi = load_json(".../test_balanced.json", "train[800:1000]")
    dataset_test_hitom_third = load_hitom([
        "train[60:80]", "train[160:180]", "train[260:280]",
        "train[660:680]", "train[760:780]", "train[860:880]",
    ])
    dataset_test_hitom_fourth = load_hitom([
        "train[80:100]", "train[180:200]", "train[280:300]",
        "train[680:700]", "train[780:800]", "train[880:900]",
    ])
    dataset_test_exploretom = load_dataset("csv", data_files=".../ExploreToM-data-sample.csv", split="train[2000:2300]")
    dataset_test_ToMbench = load_json(".../test_combined.json", "train[:431]")
    dataset_test_socialqa = load_json(".../socialIWa_v1.4_tst_wDims.json", "train[:120]")

    # --- Out-of-distribution test sources ---------------------------------
    dataset_test_simpletom_mental = load_json(".../mental-state-qa/test.json", "train[:120]")
    dataset_test_simpletom_behavior = load_json(".../behavior-qa/test.json", "train[:120]")
    dataset_test_simpletom_judgment = load_json(".../judgment-qa/test.json", "train[:120]")
    dataset_test_tomato_first = load_json(".../tomato_first.json", "train[:25]")
    dataset_test_tomato_second = load_json(".../tomato_second.json", "train[:25]")
    dataset_test_opentom_attitude = load_opentom("merged_attitude_data.json")
    dataset_test_opentom_location_cg_fo = load_opentom("merged_location_cg_fo_data.json")
    dataset_test_opentom_location_cg_so = load_opentom("merged_location_cg_so_data.json")


    # Recorded verbatim in every record's `extra_info`; nothing in this script
    # applies a chat template itself.
    apply_chat_template = True

    def _join_choices(choices):
        """Return *choices* joined with ", " if it is a list; otherwise return it unchanged."""
        return ", ".join(choices) if isinstance(choices, list) else choices

    def _format_hitom(example):
        """Hi-ToM: story and question plus the `choices` field."""
        choices_text = _join_choices(example["choices"])
        instruction = f"Story: {example['story']} Question: {example['question']} Choices: {choices_text}"
        return instruction, example['answer']

    def _format_tomi(example):
        """ToMi: the candidate answers come from the `containers` field."""
        choices_text = _join_choices(example["containers"])
        instruction = f"Story: {example['story']} Question: {example['question']} Choices: {choices_text}"
        return instruction, example['answer']

    def _format_exploretom(example):
        """ExploreToM: open-ended question over `story_structure`, no choices."""
        instruction = f"Story: {example['story_structure']} Question: {example['question']}"
        return instruction, example['expected_answer']

    def _format_tombench(example):
        """ToMbench: options A-D (C and D may be None) and a letter answer key."""
        option_a = example["OPTION-A"]
        option_b = example["OPTION-B"]
        option_c = example["OPTION-C"]
        option_d = example["OPTION-D"]
        parts = ["A: " + option_a, "B: " + option_b]
        if option_c is not None:
            parts.append("C: " + option_c)
        if option_d is not None:
            parts.append("D: " + option_d)
        formatted = " ".join(parts)
        instruction = f"Story: {example['STORY']} Question: {example['QUESTION']} Choices: {formatted}"
        # The answer column header is bilingual: "答案" ("answer"), a newline, then "ANSWER".
        letter = example["答案\nANSWER"]
        # Any letter other than A/B/C resolves to option D.
        answer = {"A": option_a, "B": option_b, "C": option_c}.get(letter, option_d)
        return instruction, answer

    def _format_socialqa(example):
        """Socialqa: three answers and a `label_letter`; anything but A/B resolves to C."""
        answer_a = example["answerA"]
        answer_b = example["answerB"]
        answer_c = example["answerC"]
        formatted = "A: " + answer_a + " " + "B: " + answer_b + " " + "C: " + answer_c
        instruction = f"Context: {example['context']} Question: {example['question']} Choices: {formatted}"
        answer = {"A": answer_a, "B": answer_b}.get(example["label_letter"], answer_c)
        return instruction, answer

    def _format_simpletom(example):
        """SimpleToM: two options in `choices["text"]`; answerKey 'A' picks the first, else the second."""
        texts = example["choices"]["text"]
        formatted = "A: " + texts[0] + " " + "B: " + texts[1]
        instruction = f"story: {example['story']} question: {example['question']} Choices: {formatted}"
        answer = texts[0] if example["answerKey"] == 'A' else texts[1]
        return instruction, answer

    def _format_tomato(example):
        """ToMATO: options `a0`..`a3` labelled A-D; the gold answer text is `a_str`."""
        formatted = " ".join(
            f"{label}: " + example[key] for label, key in zip("ABCD", ("a0", "a1", "a2", "a3"))
        )
        instruction = f"Conversation: {example['conversation']} Question: {example['q']} Choices: {formatted}"
        return instruction, example["a_str"]

    def _format_opentom_attitude(example):
        """OpenToM attitude questions with a fixed positive/negative/neutral choice set."""
        instruction = f"story: {example['narrative']} Question: {example['question']} Choices: positive, negative, neutral"
        return instruction, example["answer"]

    def _format_opentom_location(example):
        """OpenToM coarse-grained location questions with a fixed Yes/No choice set."""
        instruction = f"story: {example['narrative']} Question: {example['question']} Choices: Yes, No"
        return instruction, example["answer"]

    # Dataset tag -> formatter returning (user instruction, ground-truth answer).
    _FORMATTERS = {
        'HiToM': _format_hitom,
        'HiToM_Third': _format_hitom,
        'HiToM_Fourth': _format_hitom,
        'ToMI': _format_tomi,
        'ExploreToM': _format_exploretom,
        'ToMbench': _format_tombench,
        'Socialqa': _format_socialqa,
        'SimpleToM_mental': _format_simpletom,
        'SimpleToM_behavior': _format_simpletom,
        'SimpleToM_judgment': _format_simpletom,
        'ToMATO_first': _format_tomato,
        'ToMATO_second': _format_tomato,
        'OpenToM_attitude': _format_opentom_attitude,
        'OpenToM_location_cg_fo': _format_opentom_location,
        'OpenToM_location_cg_so': _format_opentom_location,
    }

    def make_map_fn(name, split, apply_chat_template, system_prompt=None):
        """Build a ``Dataset.map`` callback turning one raw example of *name* into a record.

        Args:
            name: Dataset tag; selects the formatter and is stored as ``data_source``.
            split: Split label (e.g. 'train' / 'test') stored in ``extra_info``.
            apply_chat_template: Flag stored verbatim in ``extra_info``.
            system_prompt: System message text; ``None`` means ``INSTRUCT_PROMPT_SYSTEM2``.

        Returns:
            ``process_fn(example, idx)`` for ``map(..., with_indices=True)``. It returns a dict
            with ``data_source``, ``prompt`` (system + user messages), ``ability``,
            ``reward_model`` (rule-style, ground truth = gold answer text) and ``extra_info``.

        Raises:
            ValueError: if *name* is not a known dataset tag (raised here, before mapping).
        """
        try:
            formatter = _FORMATTERS[name]
        except KeyError:
            raise ValueError(
                f"unknown dataset name {name!r}; expected one of {sorted(_FORMATTERS)}"
            ) from None
        system_content = INSTRUCT_PROMPT_SYSTEM2 if system_prompt is None else system_prompt

        def process_fn(example, idx):
            entire_instruction, answer = formatter(example)
            prompt = [
                {
                    "role": "system",
                    "content": system_content,
                },
                {
                    "role": "user",
                    "content": entire_instruction,
                },
            ]
            return {
                "data_source": name,
                "prompt": prompt,
                "ability": "logic",
                "reward_model": {
                    "style": "rule",
                    "ground_truth": answer,
                },
                "extra_info": {
                    'split': split,
                    'index': idx,
                    'apply_chat_template': apply_chat_template,
                },
            }

        return process_fn



    def to_verl(raw_dataset, name, split):
        """Return *raw_dataset* mapped through ``make_map_fn(name, split, ...)`` with row indices."""
        return raw_dataset.map(
            function=make_map_fn(name, split, apply_chat_template=apply_chat_template),
            with_indices=True,
        )

    # --- Training splits --------------------------------------------------
    train_dataset_tomi = to_verl(dataset_tomi, 'ToMI', 'train')
    train_dataset_hitom = to_verl(dataset_hitom, 'HiToM', 'train')
    train_dataset_exploretom = to_verl(dataset_exploretom, 'ExploreToM', 'train')
    train_dataset_tombench = to_verl(dataset_ToMbench, 'ToMbench', 'train')
    train_dataset_socialqa = to_verl(dataset_socialqa, 'Socialqa', 'train')

    # --- In-distribution test splits --------------------------------------
    test_dataset_tomi = to_verl(dataset_test_tomi, 'ToMI', 'test')
    test_dataset_hitom_third = to_verl(dataset_test_hitom_third, 'HiToM_Third', 'test')
    test_dataset_hitom_fourth = to_verl(dataset_test_hitom_fourth, 'HiToM_Fourth', 'test')
    test_dataset_exploretom = to_verl(dataset_test_exploretom, 'ExploreToM', 'test')
    test_dataset_tombench = to_verl(dataset_test_ToMbench, 'ToMbench', 'test')
    test_dataset_socialqa = to_verl(dataset_test_socialqa, 'Socialqa', 'test')

    # --- Out-of-distribution test splits ----------------------------------
    # SimpleToM's `choices` column is dropped after mapping — presumably so its
    # nested structure does not clash with other datasets' `choices` column when
    # everything is concatenated (TODO confirm).
    test_dataset_simpletom_mental = to_verl(
        dataset_test_simpletom_mental, 'SimpleToM_mental', 'test'
    ).remove_columns('choices')
    print(test_dataset_simpletom_mental)
    test_dataset_simpletom_behavior = to_verl(
        dataset_test_simpletom_behavior, 'SimpleToM_behavior', 'test'
    ).remove_columns('choices')
    test_dataset_simpletom_judgment = to_verl(
        dataset_test_simpletom_judgment, 'SimpleToM_judgment', 'test'
    ).remove_columns('choices')

    test_dataset_tomato_first = to_verl(dataset_test_tomato_first, 'ToMATO_first', 'test')
    test_dataset_tomato_second = to_verl(dataset_test_tomato_second, 'ToMATO_second', 'test')

    test_dataset_opentom_attitude = to_verl(dataset_test_opentom_attitude, 'OpenToM_attitude', 'test')
    test_dataset_opentom_location_cg_fo = to_verl(dataset_test_opentom_location_cg_fo, 'OpenToM_location_cg_fo', 'test')
    test_dataset_opentom_location_cg_so = to_verl(dataset_test_opentom_location_cg_so, 'OpenToM_location_cg_so', 'test')
    


    # Merge the per-benchmark records into the final train/test sets; rows keep
    # the order of the lists below.
    train_dataset = concatenate_datasets([
        train_dataset_tomi,
        train_dataset_hitom,
        train_dataset_exploretom,
        train_dataset_tombench,
        train_dataset_socialqa,
    ])
    test_dataset = concatenate_datasets([
        test_dataset_tomi,
        test_dataset_hitom_third,
        test_dataset_hitom_fourth,
        test_dataset_exploretom,
        test_dataset_tombench,
        test_dataset_socialqa,
        test_dataset_simpletom_mental,
        test_dataset_simpletom_behavior,
        test_dataset_simpletom_judgment,
        test_dataset_tomato_first,
        test_dataset_tomato_second,
        test_dataset_opentom_attitude,
        test_dataset_opentom_location_cg_fo,
        test_dataset_opentom_location_cg_so,
    ])

    print(train_dataset)
    print(len(train_dataset))
    print(len(test_dataset))
    print(train_dataset[0])

    # Expand "~" once and use the same resolved path for both the directory
    # creation and the parquet writes, so a "~/..." --local_dir works end to end.
    local_dir = os.path.expanduser(args.local_dir)
    os.makedirs(local_dir, exist_ok=True)

    # NOTE(review): batch_size=1 hands rows to the parquet writer one at a time,
    # which likely yields one row group per row — confirm this is intentional.
    train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet'), batch_size=1)
    test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet'), batch_size=1)



