import json
import re
import argparse
from call_math_tool import *
from prompt_convertion import *


def collect_cot_data(args):
    """Write chain-of-thought (CoT) chat samples for ``args.domain``.

    Each source file is read as one JSON object per line.  Every sample
    becomes a two-turn conversation (user question/task, assistant rationale
    plus final answer) and is written as one JSON line to
    ``train/data/processed/<domain>/<domain>_cot.jsonl``.

    Args:
        args: Namespace-like object; only ``args.domain`` is read.  Domains
            other than ``"maths"`` and ``"complex_qa"`` produce no output.

    Raises:
        IndexError: For ``"maths"`` when a sample's ``answer`` does not
            contain the ``####`` separator.
    """
    if args.domain == "maths":
        with open("maths/prm800k_converted_time_073011_size_10000.json", 'r') as f:
            prm_data = [json.loads(d) for d in f]
        with open("maths/gsm8k_converted_time_080300_size_all.json", 'r') as f:
            gsm_data = [json.loads(d) for d in f]
        data = prm_data + gsm_data

        with open("train/data/processed/maths/maths_cot.jsonl", 'w') as f:
            for i, d in enumerate(data):
                # The answer field is "<rationale>#### <final answer>".
                rationale = d["answer"].split("#### ")[0].replace("\n", " ").strip()
                final_answer = d["answer"].split("####")[1].strip()
                messages = [{
                    "role": "user",
                    "content": d["question"]
                },
                {
                    "role": "assistant",
                    "content": rationale + f" Answer is {final_answer}."
                }]

                f.write(json.dumps({
                    "dataset": "maths",
                    "id": f"maths_{i}",
                    "messages": messages
                }) + "\n")
    elif args.domain == "complex_qa":
        with open("complex_qa/musique_converted_time_080823_size_all.json", 'r') as f:
            musique_data = [json.loads(d) for d in f]
        with open("complex_qa/strategyqa_converted_time_080815_size_all.json", 'r') as f:
            strategyqa_data = [json.loads(d) for d in f]
        data = musique_data + strategyqa_data

        answer_marker = "Based on these evidences and decomposed questions"
        question_marker = "We need to answer these questions:"
        with open("train/data/processed/complex_qa/complex_qa_cot.jsonl", 'w') as f:
            for i, d in enumerate(data):
                plan = d["natural_language_plan"]
                # Keep the reasoning that precedes the decomposed-question
                # list and the concluding answer sentence; drop the middle.
                content = plan.split(question_marker)[0].strip()
                pos_answer_part = plan.find(answer_marker)
                # str.find returns -1 when the marker is absent; slicing with
                # -1 would append the plan's last character as a bogus
                # "answer", so only append when the marker really exists.
                if pos_answer_part != -1:
                    content = content + ' ' + plan[pos_answer_part:]
                messages = [{
                    "role": "user",
                    "content": d["task"]
                },
                {
                    "role": "assistant",
                    "content": content
                }]

                f.write(json.dumps({
                    "dataset": "complex_qa",
                    "id": f"complex_qa_{i}",
                    "messages": messages
                }) + "\n")


def collect_direct_data(args):
    """Write direct-answer (no rationale) chat samples for ``args.domain``.

    Every sample becomes a two-turn conversation (user question, assistant
    answer) written as one JSON line to
    ``train/data/processed/<domain>/<domain>_direct.jsonl``.

    Args:
        args: Namespace-like object; only ``args.domain`` is read.  Handled
            domains are ``"maths"``, ``"complex_qa"`` and ``"icode"``; any
            other value produces no output.

    Raises:
        IndexError: For ``"maths"`` when a sample's ``answer`` does not
            contain the ``####`` separator.
    """
    def write_sample(out, idx, question, answer):
        # One JSON line per conversation; ids are "<domain>_<idx>".
        messages = [{
            "role": "user",
            "content": question
        },
        {
            "role": "assistant",
            "content": answer
        }]
        out.write(json.dumps({
            "dataset": args.domain,
            "id": f"{args.domain}_{idx}",
            "messages": messages
        }) + "\n")

    if args.domain == "maths":
        with open("maths/data/prm800k_converted_time_073011_size_10000.json", 'r') as f:
            prm_data = [json.loads(d) for d in f]
        # NOTE(review): the AQuA file is parsed but never added to ``data``.
        # The load is kept so the set of required input files is unchanged;
        # confirm whether AQuA was meant to be part of the direct set.
        with open("maths/data/aqua_converted_time_080823_size_0_10000.json", 'r') as f:
            aqua_data = [json.loads(d) for d in f]
        with open("maths/data/gsm8k_converted_time_080300_size_all.json", 'r') as f:
            gsm_data = [json.loads(d) for d in f]
        data = prm_data + gsm_data

        with open("train/data/processed/maths/maths_direct.jsonl", 'w') as f:
            for i, d in enumerate(data):
                # Only the text after "####" (the final answer) is kept.
                write_sample(f, i, d["question"],
                             f"Answer is {d['answer'].split('####')[1].strip()}.")
    elif args.domain == "complex_qa":
        with open("complex_qa/data/musique/musique_ans_v1.0_train.jsonl", 'r') as f:
            musique_data = [json.loads(d) for d in f]
        with open("complex_qa/data/strategyqa/strategyqa_train.json", 'r') as f:
            strategyqa_data = json.load(f)

        with open("train/data/processed/complex_qa/complex_qa_direct.jsonl", 'w') as f:
            for i, d in enumerate(musique_data):
                write_sample(f, i, d["question"], d["answer"])
            # Continue numbering after the MuSiQue samples: restarting at 0
            # would give both sources the same ids within one output file.
            for i, d in enumerate(strategyqa_data, start=len(musique_data)):
                # StrategyQA answers are not strings in the raw file, hence str().
                write_sample(f, i, d["question"], str(d["answer"]))
    elif args.domain == "icode":
        with open("icode/data/ic_spider_train.json", 'r') as f:
            spider_data = json.load(f)
        data = spider_data

        with open("train/data/processed/icode/icode_direct.jsonl", 'w') as f:
            for i, d in enumerate(data):
                # The prompt is the query followed by whatever the converted
                # description contains before "The gold SQL query".
                prompt = d["query"] + " " + convert_sql_to_human_solution(d).split("The gold SQL query")[0].strip()
                write_sample(f, i, prompt, d["gold"])


def collect_plan_data(args):
    """Build planning-module training conversations for ``args.domain``.

    Each converted sample's ``subgoal_plan`` is split by ``parse_subgoals``
    into parallel lists of subgoals and actions.  Two formulations exist:

    * one-time (``"iterative"`` not in ``args.formulation``): a single
      assistant turn holding all subgoals joined by ``"; "``;
    * iterative: one assistant turn per subgoal (prefixed with
      ``"No, I will keep planning. "`` after the first), each followed by a
      user turn ending in ``" Should we stop planning?"``, and a closing
      ``"Yes, I will stop planning."`` assistant turn.

    Samples with an empty action list for a subgoal, or whose last action of
    a subgoal cannot be parsed for its result variable, are dropped.

    With ``args.decompose_training`` every (user, assistant) pair is written
    as its own two-turn sample whose user text carries the preceding
    conversation as a prefix.

    Args:
        args: Namespace-like object providing ``domain``, ``formulation`` and
            ``decompose_training``.

    Raises:
        NameError: If ``args.domain`` is not one of ``"maths"``,
            ``"complex_qa"``, ``"icode"`` or ``"web_agent"`` (no data loaded).
    """
    if args.domain == "maths":
        with open("maths/prm800k_converted_time_073011_size_10000.json", 'r') as f:
            prm_data = [json.loads(d) for d in f]
        with open("maths/gsm8k_converted_time_080300_size_all.json", 'r') as f:
            gsm_data = [json.loads(d) for d in f]
        with open("maths/asdiv_converted_time_082022_size_all.json", 'r') as f:
            asdiv_data = [json.loads(d) for d in f]
        data = prm_data + gsm_data + asdiv_data
    elif args.domain == "complex_qa":
        with open("complex_qa/musique_converted_time_080823_size_all.json", 'r') as f:
            musique_data = [json.loads(d) for d in f]
        for sample in musique_data:
            sample["source"] = "musique"
        with open("complex_qa/strategyqa_converted_time_080815_size_all.json", 'r') as f:
            strategyqa_data = [json.loads(d) for d in f]
        for sample in strategyqa_data:
            sample["source"] = "strategyqa"
        data = musique_data + strategyqa_data
    elif args.domain == "icode":
        with open("icode/data/spider_converted_time_090919_size_all.json", 'r') as f:
            spider_data = [json.loads(d) for d in f]
        for sample in spider_data:
            sample["source"] = "spider"
        data = spider_data
    elif args.domain == "web_agent":
        with open("web_agent/data/mind2web_converted_time_091300_size_all.json", 'r') as f:
            mind2web_data = [json.loads(d) for d in f]
        for sample in mind2web_data:
            sample["source"] = "mind2web"
        data = mind2web_data

    if "iterative" in args.formulation:
        if not args.decompose_training:
            output_fn = f"train/data/processed/{args.domain}/{args.domain}_plan_091300.jsonl"
        else:
            output_fn = f"train/data/processed/{args.domain}/{args.domain}_plan_decompose_091300.jsonl"
    else:
        output_fn = f"train/data/processed/{args.domain}/{args.domain}_plan_onetime_091300.jsonl"

    # Compiled once; used to pull the table name out of icode DESCRIBE actions.
    describe_pattern = re.compile(r"DESCRIBE (.*?) =")

    all_messages = list()
    for d in data:
        subgoal_plan = d["subgoal_plan"]
        subgoals, actions = parse_subgoals(subgoal_plan)
        user = f"Please provide a reasonable subgoal-based plan to solve the given task.\nTask: {d['task']}; Initial Environment Description: None."

        unavailable = False
        messages = []
        messages.append({
            "role": "user",
            "content": user
        })

        if "iterative" not in args.formulation:
            messages.append({
                "role": "assistant",
                "content": "; ".join(subgoals)
            })
        else:
            for j, (subgoal, action) in enumerate(zip(subgoals, actions)):
                if action == []:
                    # A subgoal without grounded actions makes the sample unusable.
                    unavailable = True
                    break
                if args.domain != "icode":
                    if args.domain != "web_agent":
                        # NOTE(review): no result text is appended here, so the
                        # user turn reads "...Subgoal N is Should we stop
                        # planning?" -- confirm this is intended.
                        user = f"The executed result for Subgoal {str(j+1)} is "
                    else:
                        last_subgoal = subgoal[subgoal.find(": ")+2:]
                        last_subgoal = last_subgoal[0].lower() + last_subgoal[1:]
                        user = f"We have already " + last_subgoal       # we won't put the html results here (too long..)
                else:
                    user = f"The executed result for Subgoal {str(j+1)} is "
                    if j == 0:      # the output of first 2 subgoals doesn't contain '='
                        for act in action:
                            user += act[act.rfind('=')+2:].strip()    # just include the final result
                    elif j == 1:
                        # Append the described table names to the subgoal text
                        # and list each table's columns in the user turn.
                        subgoal = subgoal[:-1]
                        for act in action:
                            tables = describe_pattern.findall(act)
                            if tables:
                                table = tables[0]
                                subgoal += f", {table}"
                                user += f"Table {table} has columns {act[act.rfind('=')+2:].strip()}; "
                        subgoal += '.'
                        user = user[:-2]
                    else:
                        for act in action:
                            user += act[act.find('=')+2:].strip() + ' '   # just include the final result
                user = user.strip() + " Should we stop planning?"

                try:
                    # Last result variable named on the left of " = " in the
                    # subgoal's final action.
                    current_variable = action[-1].split(" = ")[0].split(':')[1].strip().split(", ")[-1]
                    if current_variable[0] == "R":
                        current_variable_idx = int(current_variable[1:])
                    else:
                        # NOTE(review): current_variable_idx is never reset, so
                        # it may still hold a value from an earlier subgoal or
                        # an earlier sample; on the very first use it is
                        # unbound and the sample is dropped via the handler.
                        user = user.replace("Output", "R" + str(current_variable_idx + 1))

                    if j != 0:
                        messages.append({
                            "role": "assistant",
                            "content": "No, I will keep planning. " + subgoal
                        })
                    else:
                        messages.append({
                            "role": "assistant",
                            "content": subgoal
                        })

                    messages.append({
                        "role": "user",
                        "content": user
                    })
                except Exception:
                    # Best-effort: an unparsable action drops the sample.
                    # Unlike a bare ``except:``, this lets KeyboardInterrupt
                    # and SystemExit propagate.
                    unavailable = True
                    break

            messages.append({
                "role": "assistant",
                "content": "Yes, I will stop planning."
            })

        if not unavailable:
            all_messages.append(messages)

    if not args.decompose_training:
        with open(output_fn, 'w') as f:
            for i, messages in enumerate(all_messages):
                f.write(json.dumps({
                    "dataset": f"{args.domain}",
                    "id": f"{args.domain}_{i}",
                    "messages": messages
                }) + "\n")
    else:
        sample_idx = 0      # ids run consecutively across all conversations
        with open(output_fn, 'w') as f:
            for message in all_messages:
                context = ""
                # Turns alternate user/assistant, so step through them in pairs.
                for i in range(0, len(message), 2):
                    user = context + ' ' + message[i]["content"]
                    assistant = message[i+1]["content"]
                    new_messages = [{
                        "role": "user",
                        "content": user
                    }, {
                        "role": "assistant",
                        "content": assistant
                    }]
                    f.write(json.dumps({
                        "dataset": f"{args.domain}",
                        "id": f"{args.domain}_{str(sample_idx)}",
                        "messages": new_messages
                    }) + "\n")
                    sample_idx += 1
                    # The next pair sees everything said so far as its prefix.
                    context = user + ' ' + assistant


def collect_ground_data(args):
    """Build grounding-module training conversations for ``args.domain``.

    Each converted sample's ``subgoal_plan`` is split by ``parse_subgoals``
    into parallel lists of subgoals and actions.  The user turn carries a
    domain-specific instruction (available actions plus the task) and the
    subgoal(s) to ground; the assistant turn lists the grounded actions
    joined by ``"; "``.

    * one-time (``"iterative"`` not in ``args.formulation``): one user turn
      with all subgoals and one assistant turn with all actions;
    * iterative: one user/assistant pair per subgoal.  For ``complex_qa`` and
      ``icode`` later user turns also quote the execution results recorded
      for earlier subgoals.

    Samples with an empty action list for a subgoal, or (non-icode) whose
    result variable cannot be resolved, are dropped.  With
    ``args.decompose_training`` every pair is written as its own sample whose
    user text carries the preceding conversation as a prefix.

    Args:
        args: Namespace-like object providing ``domain``, ``formulation`` and
            ``decompose_training``.

    Raises:
        NameError: If ``args.domain`` is not one of ``"maths"``,
            ``"complex_qa"``, ``"icode"`` or ``"web_agent"`` (no data loaded).
    """
    if args.domain == "maths":
        with open("maths/prm800k_converted_time_073011_size_10000.json", 'r') as f:
            prm_data = [json.loads(d) for d in f]
        with open("maths/gsm8k_converted_time_080300_size_all.json", 'r') as f:
            gsm_data = [json.loads(d) for d in f]
        with open("maths/asdiv_converted_time_082022_size_all.json", 'r') as f:
            asdiv_data = [json.loads(d) for d in f]
        data = prm_data + gsm_data + asdiv_data
    elif args.domain == "complex_qa":
        with open("complex_qa/musique_converted_time_080823_size_all.json", 'r') as f:
            musique_data = [json.loads(d) for d in f]
        with open("complex_qa/strategyqa_converted_time_080815_size_all.json", 'r') as f:
            strategyqa_data = [json.loads(d) for d in f]
        data = musique_data + strategyqa_data
    elif args.domain == "icode":
        with open("icode/data/spider_converted_time_090919_size_all.json", 'r') as f:
            spider_data = [json.loads(d) for d in f]
        for sample in spider_data:
            sample["source"] = "spider"
        data = spider_data
    elif args.domain == "web_agent":
        with open("web_agent/data/mind2web_converted_time_091300_size_all.json", 'r') as f:
            mind2web_data = [json.loads(d) for d in f]
        for sample in mind2web_data:
            sample["source"] = "mind2web"
        data = mind2web_data

    if "iterative" in args.formulation:
        if not args.decompose_training:
            output_fn = f"train/data/processed/{args.domain}/{args.domain}_ground_091300.jsonl"
        else:
            output_fn = f"train/data/processed/{args.domain}/{args.domain}_ground_decompose_091300.jsonl"
    else:
        output_fn = f"train/data/processed/{args.domain}/{args.domain}_ground_onetime_091300.jsonl"

    # Compiled once; used to pull the table name out of icode DESCRIBE actions.
    describe_pattern = re.compile(r"DESCRIBE (.*?) =")

    all_messages = list()
    for d in data:
        subgoal_plan = d["subgoal_plan"]
        subgoals, actions = parse_subgoals(subgoal_plan)
        # The prompt strings below are training data; keep them verbatim.
        if args.domain == "maths":
            user_inst = "Please ground the given subgoal to corresponding executable actions for solving the given task. The grounded actions must be the one in available action list.\n\n" \
                        "The available action list is 'Calculator', 'SetEquation', 'SolveEquation', 'Count', 'SolveInequality', 'Code', and 'Define'.\n" \
                        "Calculator(formula): Calculate the input formula; SetEquation(equation): Set up an equation to be solved; SolveEquation(equation): Solve the previous set equation; Count(list): Count the number of elements in the given list; SolveInequality(inequality): Solve the previous set inequality; Code(pseudo_code): Generate a Python function that corresponds to the pseudo code; Define(variable/number): Define a variable or a number for latter usage.\n\n" \
                        f"Task: {d['task']} \n"
        elif args.domain == "complex_qa":
            user_inst = "Please ground the given subgoal to corresponding executable actions for solving the given task. The grounded actions must be the one in available action list.\n\n" \
                        "The available action list is 'KnowledgeQuery', 'ParagraphRetrieve', 'QA', 'Calculator', and 'Code'.\n" \
                        "Calculator(formula): Calculate the input formula; KnowledgeQuery(query): Capture the relevant webpages according to the query; ParagraphRetrieve(context, query): Given a query, retrieve the most relevant paragraphs from the given context; QA(context, query): Given context, answer the given query; Calculator(formula): Calculate the input formula; Code(pseudo_code): Generate a Python function that corresponds to the pseudo code.\n\n" \
                        f"Task: {d['task']} \n"
        elif args.domain == "icode":
            user_inst = "Please ground the given subgoal to corresponding executable actions for solving the given task. The grounded actions must be the one in available action list.\n\n" \
                        "The available action list is one of the valid operations used in SQL.\n\n" \
                        f"Task: {d['task']} \n"
        elif args.domain == "web_agent":
            user_inst = "Please ground the given subgoal to corresponding executable actions for solving the given task. The grounded actions must be the one in available action list.\n\n" \
                        "The available action list is 'CLICK', 'TYPE', 'SELECT'.\n" \
                        "CLICK(Env, Query): Click the relevant html region in Env according to Query; TYPE(Env, Query, Text): Type Text into the relevant html region in Env according to Query; SELECT(Env, Query, Text): Select the value Text of the relevant selection box in Env according to Query.\n\n" \
                        f"Task: {d['task']} \n"

        unavailable = False
        messages = []

        if "iterative" not in args.formulation:
            user = user_inst + f"Subgoals to be grounded: {'; '.join(subgoals)}\n"

            grounded = []
            if args.domain != "icode":
                for action in actions:
                    for act in action:
                        action_variable, action_name, action_args, action_results = parse_actions(act, args.domain)
                        grounded.append(f"{', '.join(action_variable)} = {action_name}({action_args})")
            else:
                for action in actions:
                    for act in action:
                        # Text after the first ": ", minus the final character.
                        grounded.append(act[act.find(": ")+2: -1])
            assistant = "; ".join(grounded)

            messages.append({
                "role": "user",
                "content": user
            })
            messages.append({
                "role": "assistant",
                "content": assistant
            })
        else:
            # Maps a subgoal label (text before its first ':') to the
            # execution result text quoted in later user turns.
            subgoal_result = dict()
            for j, (subgoal, action) in enumerate(zip(subgoals, actions)):
                if action == []:
                    # A subgoal without grounded actions makes the sample unusable.
                    unavailable = True
                    break
                if j == 0:
                    user = user_inst + f"Subgoal to be grounded: {subgoal}\n"
                else:
                    if args.domain == "complex_qa":
                        user = f"Subgoal to be grounded: {subgoal} The grounding could be based on the following results:"
                        for subgoal_idx in sorted(list(subgoal_result.keys())):
                            user += f" The execution result for {subgoal_idx} is {subgoal_result[subgoal_idx]}"
                        user += "\n"
                    elif args.domain == "icode":
                        user = f"Subgoal to be grounded: {subgoal} The grounding could be based on the following results:"
                        for subgoal_idx in sorted(list(subgoal_result.keys())):
                            # From the fourth subgoal on, Subgoal 2's result is left out.
                            if j >= 3 and subgoal_idx == "Subgoal 2":
                                continue
                            else:
                                user += f" The execution result for {subgoal_idx} is {subgoal_result[subgoal_idx]}"
                        user += "\n"
                    else:
                        user = f"Subgoal to be grounded: {subgoal}\n"

                grounded = []
                cur_subgoal_idx = subgoal[: subgoal.find(':')]
                if args.domain != "icode":
                    for act in action:
                        action_variable, action_name, action_args, action_results = parse_actions(act, args.domain)
                        print(action_variable, action_name, action_args, action_results)
                        grounded.append(f"{', '.join(action_variable)} = {action_name}({action_args})")
                    # Only the last action's results are recorded for the subgoal.
                    subgoal_result[cur_subgoal_idx] = action_results
                else:
                    for act in action:
                        action_variable, _, _, _ = parse_actions(act)
                        if j == 0:      # the output of first 2 subgoals doesn't contain '='
                            grounded.append(act[act.find(": ")+2: act.rfind('=')].strip())
                            result_text = act[act.rfind('=')+2:].strip()
                        elif j == 1:
                            grounded.append(act[act.find(": ")+2: act.rfind('=')].strip())
                            tables = describe_pattern.findall(act)
                            if tables:
                                prefix = f"Table {tables[0]} has columns "
                            else:
                                prefix = ""
                            result_text = prefix + act[act.rfind('=')+2:].strip()
                        else:
                            grounded.append(act[act.find(": ")+2:].strip())   # just include the final result
                            result_text = act[act.find('=')+2:].strip()
                        # Results of several actions in one subgoal are
                        # accumulated, separated by "; ".
                        if cur_subgoal_idx not in subgoal_result:
                            subgoal_result[cur_subgoal_idx] = result_text
                        else:
                            subgoal_result[cur_subgoal_idx] += "; " + result_text
                assistant = "; ".join(grounded)

                if args.domain != "icode":
                    try:
                        current_variable = action_variable[-1]
                        if current_variable[0] == "R":
                            current_variable_idx = int(current_variable[1:])
                        else:
                            # NOTE(review): current_variable_idx is never
                            # reset, so it may still hold a value from an
                            # earlier subgoal or an earlier sample; on the
                            # very first use it is unbound and the sample is
                            # dropped via the handler.
                            assistant = assistant.replace("Output", "R" + str(current_variable_idx + 1))
                    except Exception:
                        # Best-effort: an unresolvable variable drops the
                        # sample.  Unlike a bare ``except:``, this lets
                        # KeyboardInterrupt and SystemExit propagate.
                        unavailable = True
                        break

                messages.append({
                    "role": "user",
                    "content": user
                })
                messages.append({
                    "role": "assistant",
                    "content": assistant
                })

        if not unavailable:
            all_messages.append(messages)

    if not args.decompose_training:
        with open(output_fn, 'w') as f:
            for i, messages in enumerate(all_messages):
                f.write(json.dumps({
                    "dataset": f"{args.domain}",
                    "id": f"{args.domain}_{i}",
                    "messages": messages
                }) + "\n")
    else:
        decomposed_messages = list()
        with open(output_fn, 'w') as f:
            for message in all_messages:
                context = ""
                # Turns alternate user/assistant, so step through them in pairs.
                for i in range(0, len(message), 2):
                    user = context + ' ' + message[i]["content"]
                    assistant = message[i+1]["content"]
                    new_messages = [{
                        "role": "user",
                        "content": user
                    }, {
                        "role": "assistant",
                        "content": assistant
                    }]
                    decomposed_messages.append(new_messages)
                    f.write(json.dumps({
                        "dataset": f"{args.domain}",
                        "id": f"{args.domain}_{str(len(decomposed_messages)-1)}",
                        "messages": new_messages
                    }) + "\n")
                    # The next pair sees everything said so far as its prefix.
                    context = user + ' ' + assistant
        print(len(decomposed_messages))


def collect_plan_ground_data(args):
    """Build "without separate modules" training data for ``args.domain``.

    When ``"plan_ground"`` is in ``args.formulation`` a conversation starts
    with the planning instruction; otherwise it starts with the grounding
    instruction followed by the task.  Assistant turns (all subgoals, then
    all grounded actions) are only emitted for the one-time formulation,
    i.e. when ``"iterative"`` is not in ``args.formulation``.

    Only the ``"maths"`` and ``"complex_qa"`` domains load data; any other
    domain fails with ``NameError`` once the samples are iterated.
    """
    def read_jsonl(path):
        # One JSON object per line.
        with open(path, 'r') as fh:
            return [json.loads(line) for line in fh]

    domain = args.domain
    if domain == "maths":
        prm_data = read_jsonl("maths/prm800k_converted_time_073011_size_10000.json")
        gsm_data = read_jsonl("maths/gsm8k_converted_time_080300_size_all.json")
        data = gsm_data + prm_data
    elif domain == "complex_qa":
        musique_data = read_jsonl("complex_qa/musique_converted_time_080823_size_all.json")
        strategyqa_data = read_jsonl("complex_qa/strategyqa_converted_time_080815_size_all.json")
        data = musique_data + strategyqa_data

    one_time = "iterative" not in args.formulation
    with_plan = "plan_ground" in args.formulation

    # The file name encodes both switches.
    stem = "plan_ground" if with_plan else "ground"
    mode = "_onetime" if one_time else ""
    output_fn = f"train/data/processed/{domain}/{domain}_{stem}_wo_module{mode}_081722.jsonl"

    def render_actions(action_groups):
        # "<vars> = <name>(<args>)" for every action, joined by "; ".
        rendered = []
        for group in action_groups:
            for act in group:
                action_variable, action_name, action_args, _ = parse_actions(act)
                rendered.append(f"{', '.join(action_variable)} = {action_name}({action_args})")
        return "; ".join(rendered)

    all_messages = []
    for d in data:
        subgoals, actions = parse_subgoals(d["subgoal_plan"])
        user_inst_plan = f"Please provide a reasonable subgoal-based plan to solve the given task.\nTask: {d['task']}; Initial Environment Description: None."
        if domain == "maths":
            user_inst_ground = (
                "Please ground the generated subgoals to corresponding executable actions for solving the given task. The grounded actions must be the one in available action list.\n\n"
                "The available action list is 'Calculator', 'SetEquation', 'SolveEquation', 'Count', 'SolveInequality', 'Code', and 'Define'.\n"
                "Calculator(formula): Calculate the input formula; SetEquation(equation): Set up an equation to be solved; SolveEquation(equation): Solve the previous set equation; Count(list): Count the number of elements in the given list; SolveInequality(inequality): Solve the previous set inequality; Code(pseudo_code): Generate a Python function that corresponds to the pseudo code; Define(variable/number): Define a variable or a number for latter usage.\n\n"
            )
        elif domain == "complex_qa":
            user_inst_ground = (
                "Please ground the generated subgoals to corresponding executable actions for solving the given task. The grounded actions must be the one in available action list.\n\n"
                "The available action list is 'KnowledgeQuery', 'ParagraphRetrieve', 'QA', 'Calculator', and 'Code'.\n"
                "Calculator(formula): Calculate the input formula; KnowledgeQuery(query): Capture the relevant webpages according to the query; ParagraphRetrieve(context, query): Given a query, retrieve the most relevant paragraphs from the given context; QA(context, query): Given context, answer the given query; Calculator(formula): Calculate the input formula; Code(pseudo_code): Generate a Python function that corresponds to the pseudo code.\n\n"
            )

        if with_plan:
            messages = [{"role": "user", "content": user_inst_plan}]
            if one_time:
                messages.append({"role": "assistant", "content": "; ".join(subgoals)})
                messages.append({"role": "user", "content": user_inst_ground})
                messages.append({"role": "assistant", "content": render_actions(actions)})
        else:
            messages = [{"role": "user", "content": user_inst_ground + f"Task: {d['task']}"}]
            if one_time:
                messages.append({"role": "assistant", "content": render_actions(actions)})

        all_messages.append(messages)

    with open(output_fn, 'w') as f:
        for idx, conversation in enumerate(all_messages):
            record = {
                "dataset": f"{domain}",
                "id": f"{domain}_{idx}",
                "messages": conversation
            }
            f.write(json.dumps(record) + "\n")


def collect_unified_plan_data(domains):
    """Concatenate per-domain planning files into one unified JSONL file.

    Files are merged in the fixed order maths, complex_qa, web_agent,
    skipping domains not present in ``domains``.  The output name embeds the
    domains joined by ``"_"`` in the order given.
    """
    # NOTE(review): the input date stamps (090919 / 081722) differ from the
    # 091300 stamp used for the output -- confirm these are the intended files.
    sources = (
        ("maths", "train/data/processed/maths/maths_plan_090919.jsonl"),
        ("complex_qa", "train/data/processed/complex_qa/complex_qa_plan_081722.jsonl"),
        ("web_agent", "train/data/processed/web_agent/web_agent_plan_091300.jsonl"),
    )

    merged = []
    for domain, path in sources:
        if domain not in domains:
            continue
        with open(path, "r") as src:
            merged.extend(json.loads(line) for line in src)

    target = f"train/data/processed/unified/unified_{'_'.join(domains)}_plan_091300.jsonl"
    with open(target, "w") as dst:
        for record in merged:
            dst.write(json.dumps(record) + '\n')


def collect_unified_ground_data(domains):
    """Concatenate per-domain grounding files into one unified JSONL file.

    Files are merged in the fixed order maths, complex_qa, web_agent,
    skipping domains not present in ``domains``.  The output name embeds the
    domains joined by ``"_"`` in the order given.
    """
    # NOTE(review): the input date stamps (090919 / 081722) differ from the
    # 091300 stamp used for the output -- confirm these are the intended files.
    sources = (
        ("maths", "train/data/processed/maths/maths_ground_090919.jsonl"),
        ("complex_qa", "train/data/processed/complex_qa/complex_qa_ground_081722.jsonl"),
        ("web_agent", "train/data/processed/web_agent/web_agent_ground_091300.jsonl"),
    )

    merged = []
    for domain, path in sources:
        if domain not in domains:
            continue
        with open(path, "r") as src:
            merged.extend(json.loads(line) for line in src)

    target = f"train/data/processed/unified/unified_{'_'.join(domains)}_ground_091300.jsonl"
    with open(target, "w") as dst:
        for record in merged:
            dst.write(json.dumps(record) + '\n')


def main(args):
    """Dispatch to the collector(s) selected by the parsed CLI arguments.

    A non-empty ``args.unified`` (comma-separated domain list) takes
    precedence and only merges existing per-domain plan/ground files.
    Otherwise ``args.formulation`` picks the collector; a formulation that
    matches none of the known kinds does nothing.
    """
    if args.unified:
        domains = args.unified.split(",")
        collect_unified_plan_data(domains)
        collect_unified_ground_data(domains)
        return

    formulation = args.formulation
    if formulation == "cot":
        collect_cot_data(args)
        return
    if formulation == "direct":
        collect_direct_data(args)
        return
    if "uniagent" in formulation:
        # The modular agent needs both planning and grounding data.
        collect_plan_data(args)
        collect_ground_data(args)
        return
    if "wo_module" in formulation:
        collect_plan_ground_data(args)


if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    cli = argparse.ArgumentParser()

    # Which data formulation to build.
    cli.add_argument('--formulation', dest='formulation', type=str, default="uniagent_iterative")
    # Which task domain to process.
    cli.add_argument('--domain', dest='domain', type=str, default="maths")
    # Emit one sample per conversation turn pair instead of whole conversations.
    cli.add_argument('--decompose_training', dest='decompose_training', action='store_true')
    # Comma-separated domains whose existing plan/ground files should be merged.
    cli.add_argument('--unified', dest='unified', type=str)

    main(cli.parse_args())