import openai
import json
import os
import argparse
from tqdm import tqdm
from zeno_build.models import lm_config
from zeno_build.prompts import chat_prompt
from incontext_maths_prm import *
from incontext_maths_gsm import *
from incontext_qa import *
from incontext_icode import *
from incontext_icode_sql import *
from incontext_web_agent import *
from fast_prompt import *

def _run_maths_conversion(pending, target, output_fn, system_info, model_config,
                          max_tokens, build_prompt, build_record):
    """Query the chat model in rounds until ``target`` records are written.

    Each round sends one prompt per pending record, appends one JSON line to
    ``output_fn`` for every record that received a non-empty prediction, and
    re-queues the records whose prediction came back empty.

    Args:
        pending: Source records still waiting for a conversion.
        target: Number of converted records to write before stopping.
        output_fn: Path of the JSON-lines file that results are appended to.
        system_info: System message sent with every prompt.
        model_config: Model configuration forwarded to the completion helper.
        max_tokens: Completion length limit forwarded to the completion helper.
        build_prompt: Callable mapping a record to the user prompt text.
        build_record: Callable mapping ``(record, prediction)`` to the dict
            that is serialised to the output file.

    Returns:
        The number of records written.
    """
    # Create the output directory up front so a missing folder cannot make
    # the write fail after the (paid) completions have already been fetched.
    out_dir = os.path.dirname(output_fn)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    converted_num = 0
    # `pending` must be checked as well: when fewer records exist than
    # `target` asks for, the pool empties and the loop would never end.
    # NOTE(review): a record whose prediction is always empty is still
    # retried without limit, as before.
    while pending and converted_num < target:
        full_contexts = [
            chat_prompt.ChatMessages([
                {"role": "system", "content": system_info},
                {"role": "user", "content": build_prompt(solu)},
            ])
            for solu in pending
        ]

        predictions = asyncio.run(
            generate_from_openai_chat_completion(
                full_contexts=full_contexts,
                model_config=model_config,
                temperature=0,
                max_tokens=max_tokens
            )
        )

        retry = []
        with open(output_fn, "a+") as f:
            for solu, pred in zip(pending, predictions):
                if not pred:
                    # Empty prediction: ask again in the next round.
                    retry.append(solu)
                    continue
                converted_num += 1
                f.write(json.dumps(build_record(solu, pred)) + '\n')
                if converted_num >= target:
                    break
        pending = retry

    return converted_num


def maths_convertion(args):
    """Convert maths solutions into subgoal-based plans with GPT-4.

    The dataset is selected by substring match on ``args.data_fn``
    ("prm", "gsm", "aqua" or "asdiv", tested in that order); any other file
    name makes this function a no-op. Results are appended as JSON lines to a
    file under ``maths/``.

    Args:
        args: Parsed command-line namespace providing ``data_fn``,
            ``st_idx``, ``ed_idx`` and ``convert_all``.
    """
    # NOTE(review): the two literals join without a space ("...action listinto
    # subgoal..."); left untouched so prompts stay identical to earlier runs.
    system_info = "You are a helpful assist to convert natural language plans or structural action list" \
                    "into subgoal-based plans and their corresponding structured actions."
    model_config = lm_config.LMConfig(provider="openai_chat", model="gpt-4")
    data_num = args.ed_idx - args.st_idx

    def load_selection(stem):
        """Read the JSON-lines input and return (records, output path, target count)."""
        with open(args.data_fn, "r") as f:
            records = [json.loads(line) for line in f]
        if args.convert_all:
            return records, f"maths/{stem}_size_all.json", len(records)
        return (
            records[args.st_idx: args.ed_idx],
            f"maths/{stem}_size_{args.st_idx}_{args.ed_idx}.json",
            data_num,
        )

    if "prm" in args.data_fn:
        with open(args.data_fn, "r") as f:
            math_solu = [json.loads(line) for line in f]
        # NOTE(review): this branch takes the first `data_num` records and
        # ignores `convert_all`; `st_idx` only influences the count.
        math_solu = math_solu[:data_num]

        _run_maths_conversion(
            math_solu,
            data_num,
            f"maths/prm800k_converted_time_073011_size_{data_num}.json",
            system_info,
            model_config,
            500,
            build_prompt=lambda solu: (
                instruction_maths_prm
                + f"Task: {solu['question']['problem']}\n\n"
                + f"Natural language plan: {solu['question']['ground_truth_solution']}\n\n"
            ),
            build_record=lambda solu, pred: {
                "task": solu['question']['problem'],
                "natural_language_plan": solu['question']['ground_truth_solution'],
                "subgoal_plan": pred,
            },
        )
    elif "gsm" in args.data_fn:
        math_solu, output_fn, target = load_selection("gsm8k_converted_time_082101")

        def gsm_plan(solu):
            # Take the text before "####", split it on " ** " and join every
            # piece after the first one.
            human_solu = solu['answer'].split("####")[0].split(" ** ")
            return ' '.join(human_solu[1:])

        _run_maths_conversion(
            math_solu,
            target,
            output_fn,
            system_info,
            model_config,
            500,
            build_prompt=lambda solu: (
                instruction_maths_gsm
                + f"Task: {solu['question']}\n\n"
                + f"Natural language plan: {gsm_plan(solu)}\n\n"
            ),
            build_record=lambda solu, pred: {
                "task": solu['question'],
                # The text after "####" is stored as the answer.
                "ans": solu['answer'].split("####")[1].strip(),
                "natural_language_plan": gsm_plan(solu),
                "subgoal_plan": pred,
            },
        )
    elif "aqua" in args.data_fn:
        math_solu, output_fn, target = load_selection("aqua_converted_time_080823")

        # Only rationales with more than two lines are converted.
        math_solu = [
            solu for solu in math_solu
            if len(solu["rationale"].split('\n')) > 2
        ]
        # Never wait for more conversions than there are usable records.
        target = min(target, len(math_solu))

        def aqua_plan(solu):
            # The last line of the rationale is left out of the plan.
            return ' '.join(solu['rationale'].split('\n')[:-1])

        # The original file repeated this loop a second time with an older
        # output name; its condition was already false when reached, so the
        # copy never executed and has been removed.
        _run_maths_conversion(
            math_solu,
            target,
            output_fn,
            system_info,
            model_config,
            200,
            build_prompt=lambda solu: (
                instruction_maths_gsm
                + f"Task: {solu['question']}\n\n"
                + f"Natural language plan: {aqua_plan(solu)}\n\n"
            ),
            build_record=lambda solu, pred: {
                "task": solu['question'],
                "natural_language_plan": aqua_plan(solu),
                "subgoal_plan": pred,
            },
        )
    elif "asdiv" in args.data_fn:
        math_solu, output_fn, target = load_selection("asdiv_converted_time_082101")

        _run_maths_conversion(
            math_solu,
            target,
            output_fn,
            system_info,
            model_config,
            500,
            build_prompt=lambda solu: (
                instruction_maths_gsm
                + f"Task: {solu['question']}\n\n"
                + f"Natural language plan: {solu['solution']}\n\n"
            ),
            build_record=lambda solu, pred: {
                "task": solu['question'],
                "natural_language_plan": solu['solution'],
                "subgoal_plan": pred,
            },
        )


def convert_qa_to_human_solution(x, data_fn):
    """Render a QA record as a natural-language solution string.

    Args:
        x: One dataset record. For StrategyQA the keys ``facts``,
            ``evidence``, ``decomposition`` and ``answer`` are read; for
            MuSiQue the keys ``paragraphs``, ``question_decomposition`` and
            ``answer`` are read.
        data_fn: Dataset file name; the format is picked by checking whether
            it contains "strategyqa" or "musique" (in that order).

    Returns:
        The solution text, or ``None`` for a StrategyQA record in which no
        evidence annotation is free of "no_evidence" entries (including a
        record with no evidence annotations at all).

    Raises:
        ValueError: If ``data_fn`` names neither supported dataset.
        IndexError: If a StrategyQA record has more decomposition steps than
            collected evidence entries.
    """
    if "strategyqa" in data_fn:
        # Pick the first evidence annotation without a "no_evidence" marker.
        # Both names start defined so an empty evidence list yields None
        # instead of an UnboundLocalError.
        useful = False
        all_ref = []
        for evidence in x["evidence"]:
            useful = True
            all_ref = []
            for ref in evidence:
                if "no_evidence" in ref:
                    useful = False
                    break
                # Every list inside the step becomes one entry; a step
                # without any list contributes one empty placeholder.
                nested = [r for r in ref if isinstance(r, list)]
                all_ref.extend(nested if nested else [[]])
            if useful:
                break

        if not useful:
            return None

        parts = [
            "We find relevant facts: ",
            ' '.join(x['facts']),
            " We need to answer these questions: ",
        ]
        for i, q in enumerate(x['decomposition']):
            quoted = [f"'{r}'" for r in all_ref[i]]
            parts.append(str(i + 1) + ". " + q + ' ')
            if quoted:
                parts.append(f"(Can be answered based on paragraph {', '.join(quoted)}) ")
        parts.append(f"Based on these evidences and decomposed questions, the answer is {str(x['answer'])}.")
        return ''.join(parts)

    if "musique" in data_fn:
        # Index paragraphs by their "idx" so each sub-question can look up
        # the paragraph named by its "paragraph_support_idx".
        paras = {para["idx"]: para for para in x["paragraphs"]}
        steps = x["question_decomposition"]

        parts = ["We find relevant facts: "]
        parts.extend(
            ' ' + paras[q["paragraph_support_idx"]]["paragraph_text"]
            for q in steps
        )
        parts.append(" We need to answer these questions:")
        for i, q in enumerate(steps):
            title = paras[q['paragraph_support_idx']]['title']
            parts.append(
                ' ' + str(i + 1) + ". " + q["question"]
                + f" (Can be answered based on paragraph '{title}')"
            )
        parts.append(f" Based on these evidences and decomposed questions, the answer is {x['answer']}.")
        return ''.join(parts)

    # Previously this path fell through to an unbound local variable.
    raise ValueError(f"Unsupported QA dataset file name: {data_fn!r}")


def convert_sql_to_human_solution(x):
    """Describe a text-to-SQL record in plain English.

    Reads the table names from ``x["db_tables"]``, the ``(name, type)``
    column entries from ``x["db_tables_type"]`` and the reference query from
    ``x["gold"]``, and returns a single descriptive string.
    """
    table_names = list(x["db_tables"].keys())
    pieces = [f"We have {len(table_names)} SQL tables: ", ', '.join(table_names), '. ']

    for name in table_names:
        heading = name[0].upper() + name[1:] + " has columns: "
        listing = ', '.join(
            col[0] + f" (type: {col[1]})" for col in x["db_tables_type"][name]
        )
        if listing:
            pieces.append(heading + listing + '. ')
        else:
            # With no columns the trailing ": " of the heading is what gets
            # trimmed before the full stop.
            pieces.append(heading[:-2] + '. ')

    pieces.append(f"The gold SQL query for this task is {x['gold']}.")
    return ''.join(pieces)


def _run_qa_conversion(pending, target, output_fn, system_info, model_config):
    """Query the chat model in rounds until ``target`` QA records are written.

    Args:
        pending: List of ``(qa_record, natural_language_plan)`` pairs still
            waiting for a conversion.
        target: Number of converted records to write before stopping.
        output_fn: Path of the JSON-lines file that results are appended to.
        system_info: System message sent with every prompt.
        model_config: Model configuration forwarded to the completion helper.

    Returns:
        The number of records written.
    """
    # Create the output directory before spending any API calls.
    out_dir = os.path.dirname(output_fn)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    converted_num = 0
    # Stop when the pool is empty as well, otherwise an unreachable target
    # keeps the loop spinning on empty batches.
    # NOTE(review): a record whose prediction is always empty is still
    # retried without limit, as before.
    while pending and converted_num < target:
        full_contexts = []
        for qa, plan in pending:
            final_prompt = instruction_qa + f"Task: {qa['question']}\n\n"
            if plan:
                final_prompt += f"Natural language plan: {plan}\n\n"
            full_contexts.append(
                chat_prompt.ChatMessages([
                    {"role": "system", "content": system_info},
                    {"role": "user", "content": final_prompt},
                ])
            )

        predictions = asyncio.run(
            generate_from_openai_chat_completion(
                full_contexts=full_contexts,
                model_config=model_config,
                temperature=0,
                max_tokens=500
            )
        )

        retry = []
        with open(output_fn, "a+") as f:
            for (qa, plan), pred in zip(pending, predictions):
                if not pred:
                    # Empty prediction: ask again in the next round.
                    retry.append((qa, plan))
                    continue
                if plan:
                    converted_num += 1
                    inst = {
                        "task": qa['question'],
                        "natural_language_plan": plan,
                        "subgoal_plan": pred,
                    }
                    f.write(json.dumps(inst) + '\n')
                if converted_num >= target:
                    break
        pending = retry

    return converted_num


def complex_qa_convertion(args):
    """Convert complex-QA solutions into subgoal-based plans with GPT-4.

    The dataset is selected by substring match on ``args.data_fn``
    ("strategyqa", then "musique"); any other file name makes this function a
    no-op. Results are appended as JSON lines to a file under ``complex_qa/``.

    Args:
        args: Parsed command-line namespace providing ``data_fn``,
            ``st_idx``, ``ed_idx`` and ``convert_all``.
    """
    # NOTE(review): the two literals join without a space ("...action listinto
    # subgoal..."); left untouched so prompts stay identical to earlier runs.
    system_info = "You are a helpful assist to convert natural language plans or structural action list" \
                    "into subgoal-based plans and their corresponding structured actions."
    model_config = lm_config.LMConfig(provider="openai_chat", model="gpt-4")
    data_num = args.ed_idx - args.st_idx

    if "strategyqa" in args.data_fn:
        is_musique = False
        stem = "strategyqa_converted_time_080815"
        with open(args.data_fn, "r") as f:
            qa_data = json.load(f)
        print(len(qa_data))
    elif "musique" in args.data_fn:
        is_musique = True
        stem = "musique_converted_time_080823"
        with open(args.data_fn, "r") as f:
            qa_data = [json.loads(line) for line in f]
    else:
        return

    if not args.convert_all:
        qa_data = qa_data[args.st_idx: args.ed_idx]
        output_fn = f"complex_qa/{stem}_size_{args.st_idx}_{args.ed_idx}.json"
    else:
        output_fn = f"complex_qa/{stem}_size_all.json"
        data_num = len(qa_data)

    if is_musique:
        # Keep records whose every sub-question names a supporting paragraph.
        # The test must be `is not None`: a plain truthiness check also
        # rejected the valid paragraph index 0.
        qa_data = [
            qa for qa in qa_data
            if all(q['paragraph_support_idx'] is not None
                   for q in qa['question_decomposition'])
        ]

    # Build each plan once and drop records for which none can be produced.
    pending = []
    for qa in qa_data:
        plan = convert_qa_to_human_solution(qa, args.data_fn)
        if plan:
            pending.append((qa, plan))

    # Never wait for more conversions than there are usable records.
    _run_qa_conversion(
        pending,
        min(data_num, len(pending)),
        output_fn,
        system_info,
        model_config,
    )


def web_agent_convertion(args):
    """Convert web-agent action traces into subgoal-based plans with GPT-4.

    Loads every file in ``web_agent/data`` whose name contains "train_",
    concatenates their JSON contents, and appends one JSON line per converted
    record to a file in the same directory.

    Args:
        args: Parsed command-line namespace providing ``st_idx``, ``ed_idx``
            and ``convert_all``.
    """
    # NOTE(review): the two literals join without a space ("...action listinto
    # subgoal..."); left untouched so prompts stay identical to earlier runs.
    system_info = "You are a helpful assist to convert natural language plans or structural action list" \
                    "into subgoal-based plans and their corresponding structured actions."
    model_config = lm_config.LMConfig(provider="openai_chat", model="gpt-4")
    data_num = args.ed_idx - args.st_idx
    converted_num = 0

    data = list()
    # os.listdir returns names in arbitrary order; sort them so that
    # st_idx/ed_idx select the same records on every run and machine.
    for fn in sorted(os.listdir("web_agent/data")):
        if "train_" in fn:
            with open(os.path.join("web_agent/data", fn), "r") as f:
                data += json.load(f)

    if not args.convert_all:
        data = data[args.st_idx: args.ed_idx]
        output_fn = f"web_agent/data/mind2web_converted_time_091300_size_{args.st_idx}_{args.ed_idx}.json"
    else:
        output_fn = "web_agent/data/mind2web_converted_time_091300_size_all.json"
        data_num = len(data)

    # Stop when the pool is empty as well: records with an empty action list
    # are never counted, and a slice can be shorter than requested, so the
    # target alone may be unreachable.
    # NOTE(review): a record whose prediction is always empty is still
    # retried without limit, as before.
    while data and converted_num < data_num:
        full_contexts = list()
        for d in data:
            final_prompt = instruction_web_agent + f"Task: {d['confirmed_task']}\n\n"
            final_prompt += f"Natural language plan: {'; '.join(d['action_reprs'])}\n\n"
            full_contexts.append(
                chat_prompt.ChatMessages([{"role": "system", "content": system_info}, {"role": "user", "content": final_prompt}])
            )

        predictions = asyncio.run(
            generate_from_openai_chat_completion(
                full_contexts=full_contexts,
                model_config=model_config,
                temperature=0,
                max_tokens=800
            )
        )

        new_data = list()
        with open(output_fn, "a+") as f:
            for d, pred in zip(data, predictions):
                if not pred:
                    # Empty prediction: ask again in the next round.
                    new_data.append(d)
                    continue
                if d['action_reprs']:
                    converted_num += 1
                    inst = {
                        "task": d['confirmed_task'],
                        "natural_language_plan": d['action_reprs'],
                        "subgoal_plan": pred,
                    }
                    f.write(json.dumps(inst) + '\n')
                if converted_num >= data_num:
                    break
        data = new_data


def _run_icode_conversion(pending, target, output_fn, system_info, model_config,
                          max_tokens, build_prompt, build_record):
    """Query the chat model in rounds until ``target`` code records are written.

    Args:
        pending: Source records still waiting for a conversion.
        target: Number of converted records to write before stopping.
        output_fn: Path of the JSON-lines file that results are appended to.
        system_info: System message sent with every prompt.
        model_config: Model configuration forwarded to the completion helper.
        max_tokens: Completion length limit forwarded to the completion helper.
        build_prompt: Callable mapping a record to the user prompt text.
        build_record: Callable mapping ``(record, prediction)`` to the dict
            to serialise; it is written only when its
            ``"natural_language_plan"`` value is truthy.

    Returns:
        The number of records written.
    """
    # Create the output directory before spending any API calls.
    out_dir = os.path.dirname(output_fn)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    converted_num = 0
    # Stop when the pool is empty as well, otherwise an unreachable target
    # (short slice, or records skipped for an empty plan) never terminates.
    # NOTE(review): a record whose prediction is always empty is still
    # retried without limit, as before.
    while pending and converted_num < target:
        full_contexts = [
            chat_prompt.ChatMessages([
                {"role": "system", "content": system_info},
                {"role": "user", "content": build_prompt(d)},
            ])
            for d in pending
        ]

        predictions = asyncio.run(
            generate_from_openai_chat_completion(
                full_contexts=full_contexts,
                model_config=model_config,
                temperature=0,
                max_tokens=max_tokens
            )
        )

        retry = []
        with open(output_fn, "a+") as f:
            for d, pred in zip(pending, predictions):
                if not pred:
                    # Empty prediction: ask again in the next round.
                    retry.append(d)
                    continue
                inst = build_record(d, pred)
                if inst["natural_language_plan"]:
                    converted_num += 1
                    f.write(json.dumps(inst) + '\n')
                if converted_num >= target:
                    break
        pending = retry

    return converted_num


def icode_convertion(args):
    """Convert code-generation solutions into subgoal-based plans with GPT-4.

    The dataset is selected by substring match on ``args.data_fn``
    ("nl2bash", then "spider"); any other file name makes this function a
    no-op. Results are appended as JSON lines to a file under ``icode/data/``.

    Args:
        args: Parsed command-line namespace providing ``data_fn``,
            ``st_idx``, ``ed_idx`` and ``convert_all``.
    """
    # NOTE(review): the two literals join without a space ("...action listinto
    # subgoal..."); left untouched so prompts stay identical to earlier runs.
    system_info = "You are a helpful assist to convert programming language plans or structural action list" \
                    "into subgoal-based plans and their corresponding structured actions."
    model_config = lm_config.LMConfig(provider="openai_chat", model="gpt-4")
    data_num = args.ed_idx - args.st_idx

    if "nl2bash" in args.data_fn:
        stem = "nl2bash_converted_time_090919"
        max_tokens = 500
        with open(args.data_fn, "r") as f:
            data = [json.loads(d) for d in f]

        def build_prompt(d):
            return (
                instruction_icode
                + f"Task: {d['nl']}\n\n"
                + f"Programming language plan: {d['bash']}\n\n"
            )

        def build_record(d, pred):
            return {
                "task": d['nl'],
                "natural_language_plan": d['bash'],
                "subgoal_plan": pred,
            }
    elif "spider" in args.data_fn:
        stem = "spider_converted_time_090919"
        max_tokens = 1000
        with open(args.data_fn, "r") as f:
            data = json.load(f)

        def build_prompt(d):
            return (
                instruction_icode_sql
                + f"Task: {d['query']}\n\n"
                + f"Natural language plan: {convert_sql_to_human_solution(d)}\n\n"
            )

        def build_record(d, pred):
            return {
                "task": d['query'],
                "natural_language_plan": convert_sql_to_human_solution(d),
                "subgoal_plan": pred,
            }
    else:
        return

    print(len(data))

    if not args.convert_all:
        data = data[args.st_idx: args.ed_idx]
        output_fn = f"icode/data/{stem}_size_{args.st_idx}_{args.ed_idx}.json"
    else:
        output_fn = f"icode/data/{stem}_size_all.json"
        data_num = len(data)

    _run_icode_conversion(
        data,
        data_num,
        output_fn,
        system_info,
        model_config,
        max_tokens,
        build_prompt,
        build_record,
    )


def main(args):
    """Run the converter that matches ``args.domain``.

    Recognised domains are "maths", "complex_qa", "icode" and "web_agent";
    any other value does nothing.
    """
    # Lambdas defer the lookup of each converter until it is actually chosen.
    dispatch = {
        "maths": lambda: maths_convertion(args),
        "complex_qa": lambda: complex_qa_convertion(args),
        "icode": lambda: icode_convertion(args),
        "web_agent": lambda: web_agent_convertion(args),
    }
    runner = dispatch.get(args.domain)
    if runner is not None:
        runner()


if __name__ == "__main__":
    # Command-line entry point: --domain picks the converter, --data_fn names
    # the input file, --convert_all is a boolean flag, and --st_idx/--ed_idx
    # are integers defaulting to 0. argparse derives each destination name
    # from its option string, so the parsed attributes are unchanged.
    cli = argparse.ArgumentParser()
    cli.add_argument('--domain', type=str)
    cli.add_argument('--data_fn', type=str)
    cli.add_argument('--convert_all', action="store_true")
    for index_flag in ('--st_idx', '--ed_idx'):
        cli.add_argument(index_flag, type=int, default=0)

    args = cli.parse_args()

    main(args)
