import json
import os
import pprint
import asyncio
from datetime import datetime
from time import sleep
from tqdm import tqdm
import argparse
from collections import Counter
from distutils.util import strtobool
import multiprocessing as mp

from src.evol.data_utils import load_data
from src.evol.openai_backend import call_chatgpt, LLM
from src.utils.data_utils import extract_answer_math, extract_answer_number
from src.evol.openai_utils import num_tokens_from_messages
from src.utils.code_utils import execute_tora, execute_code_interactive
from src.utils.math_utils import compare_ans, vote
from src.utils.file_utils import load_jsonl_ml


def parse_args():
    """Parse the command-line options for the ToRA batch runner.

    Returns:
        argparse.Namespace holding every option declared below.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--verbose", action="store_true")
    # (flag, type, default) triples, registered in this exact order so the
    # help text and namespace layout match the declaration order.
    valued_options = [
        ("--prompt_path", str, None),
        ("--strategy_path", str, None),
        ("--dataset", str, "gsm"),
        ("--data_path", str, None),
        ("--model", str, "gpt-3.5-turbo"),
        ("--temperature", float, 0.0),
        ("--top_p", float, 1.0),
        ("--max_tokens", int, 1024),
        ("--num_seqs", int, 1),
        ("--num_skips", int, 0),
        ("--input_col", str, "question"),
        ("--output_col", str, "answer"),
        ("--max_iter", int, 3),
        ("--num_process", int, 1),
        ("--output_path", str, None),
        ("--batch_size", int, 10),
    ]
    for flag, kind, fallback in valued_options:
        cli.add_argument(flag, default=fallback, type=kind)
    cli.add_argument("--do_recover", action="store_true")
    return cli.parse_args()


def load_prompt(prompt_path):
    """Read the UTF-8 prompt template at ``prompt_path``.

    Returns:
        The file contents with leading/trailing whitespace removed.
    """
    with open(prompt_path, encoding="utf-8") as handle:
        text = handle.read()
    return text.strip()


def stop_tora(result):
    """Report whether a generation chunk already contains a ``\\boxed`` answer."""
    return "\\boxed" in result


def clean_strategy(strategy):
    """Extract the strategy text from a model response.

    Keeps only what follows the last "## Strategy" header (or the whole text
    if there is none), truncated at the next "##" header, and stripped.
    """
    _, _, tail = strategy.rpartition("## Strategy")
    body, _, _ = tail.strip().partition("##")
    return body.strip()


def get_max_tokens_model(
    batch_messages,
    max_tokens,
    model,
    *,
    context_limit=4000,
    min_tokens=512,
    step=100,
    fallback_model="gpt-3.5-turbo-16k",
    fallback_max_tokens=1024,
    token_counter=None,
):
    """Choose a completion budget and model so every prompt in the batch fits.

    The budget starts at ``max_tokens`` and is lowered in ``step`` decrements
    while a prompt plus the budget exceeds ``context_limit``. It is never
    lowered below ``min_tokens``; a smaller requested budget is kept as-is.
    The budget is shared across the batch, so the result is the smallest one
    any prompt needed. If some prompt still does not fit, the whole batch is
    switched to ``fallback_model`` with ``fallback_max_tokens``.

    Args:
        batch_messages: Iterable of chat message lists.
        max_tokens: Requested completion budget.
        model: Requested model name.
        context_limit: Context window (prompt + completion) of ``model``.
        min_tokens: Smallest budget worth requesting before falling back.
        step: Amount the budget is lowered per attempt.
        fallback_model: Model used when a prompt cannot fit.
        fallback_max_tokens: Budget used together with ``fallback_model``.
        token_counter: Callable returning the prompt token count of one
            message list; defaults to ``num_tokens_from_messages``.

    Returns:
        ``(max_tokens, model)`` to use for the request.
    """
    count = num_tokens_from_messages if token_counter is None else token_counter
    for messages in batch_messages:
        num_tokens = count(messages)
        # Shrink only while the shrunk budget would still be >= min_tokens.
        while (
            num_tokens + max_tokens > context_limit
            and max_tokens - step >= min_tokens
        ):
            max_tokens -= step
        if num_tokens + max_tokens > context_limit:
            # Even the smallest acceptable budget overflows the context
            # window (this also covers a budget that lands exactly on
            # min_tokens, which previously slipped through).
            return fallback_max_tokens, fallback_model
    return max_tokens, model


def parse_output(output):
    """Split a response into ``(strategy, solution)`` around "## Solution".

    Returns ``(None, None)`` unless the header occurs exactly once. The part
    before it is normalized with ``clean_strategy``; the part after is stripped.
    """
    if output.count("## Solution") != 1:
        return None, None
    head, _, tail = output.partition("## Solution")
    return clean_strategy(head), tail.strip()


def batch_tora(llm, batch_messages_base, run_args=None):
    """Run tool-integrated (ToRA-style) generation for a batch of chat prompts.

    Round 0 samples ``num_seqs`` completions per prompt. After each round,
    every completion whose newest chunk has no boxed answer (see
    ``stop_tora``) is resumed greedily in the next round. If that chunk ends
    with a code fence, it is first passed to ``execute_code_interactive``.
    Any non-empty result without "SyntaxError" is appended as an ```output
    block. There are at most ``max_iters`` rounds: initially 2, raised by one
    (up to 5) for each empty chunk seen.

    Args:
        llm: Client exposing an async ``achat(messages, model=..., stop=...,
            max_tokens=..., temperature=..., num_beams=...)`` method.
        batch_messages_base: List of ``[system, user]`` message lists.
        run_args: Namespace with ``max_tokens``, ``model``, ``temperature`` and
            ``num_seqs``. Defaults to the module-level ``args`` set by this
            script's ``__main__`` block; callers that run elsewhere (e.g. in
            spawned worker processes, where that global is undefined) should
            pass it explicitly.

    Returns:
        One list per prompt of accumulated completion strings.
    """
    if not batch_messages_base:
        # Nothing to generate (e.g. every sample of the batch was recovered).
        return []
    cfg = args if run_args is None else run_args
    max_iters = 2
    itr = 0
    while itr < max_iters:
        if itr == 0:
            # First round: sample num_seqs completions per prompt.
            max_tokens, model = get_max_tokens_model(
                batch_messages_base, cfg.max_tokens, cfg.model
            )
            print(max_tokens)
            batch_outputs = asyncio.run(
                llm.achat(
                    batch_messages_base,
                    model=model,
                    stop=["```output", "---"],
                    max_tokens=max_tokens,
                    temperature=cfg.temperature,
                    num_beams=cfg.num_seqs,
                )
            )
            batch_full_outputs = [["" for _ in outputs] for outputs in batch_outputs]
        else:
            # Later rounds: greedily continue only the unfinished completions,
            # one request per (prompt, sample) pair in remain_ids.
            max_tokens, model = get_max_tokens_model(
                batch_messages, cfg.max_tokens, cfg.model
            )
            print(max_tokens)
            batch_outputs = asyncio.run(
                llm.achat(
                    batch_messages,
                    model=model,
                    stop=["```output", "---"],
                    max_tokens=max_tokens,
                    temperature=0,
                    num_beams=1,
                )
            )
            # Scatter continuations back to their (prompt, sample) slots;
            # finished slots stay None and are skipped below.
            temp_outputs = [[None for _ in outputs] for outputs in batch_full_outputs]
            for outputs, (i, j) in zip(batch_outputs, remain_ids):
                temp_outputs[i][j] = outputs[0]
            batch_outputs = temp_outputs
        remain_ids = []
        for i, outputs in enumerate(batch_outputs):
            for j, output in enumerate(outputs):
                if output is None:
                    continue
                batch_full_outputs[i][j] += output
                if stop_tora(output):
                    continue
                if output.strip().endswith("```"):
                    # Presumably generation was cut at the "```output" stop
                    # sequence right after a code block: execute it.
                    code_output = execute_code_interactive(output)
                    print("code output", code_output)
                    if len(code_output) > 0 and "SyntaxError" not in code_output:
                        batch_full_outputs[i][j] += f"```output\n{code_output}\n```\n"
                elif output == "" and max_iters < 5:
                    # An empty chunk grants one extra round (capped at 5).
                    max_iters += 1
                remain_ids.append((i, j))
        if len(remain_ids) == 0:
            break
        # Resume each unfinished completion: original user prompt followed
        # by everything generated (and executed) so far.
        batch_messages = [
            [
                {
                    "role": "system",
                    "content": "You are a helpful expert for math problem solving.",
                },
                {
                    "role": "user",
                    "content": batch_messages_base[i][1]["content"]
                    + batch_full_outputs[i][j],
                },
            ]
            for i, j in remain_ids
        ]
        itr += 1
    return batch_full_outputs


def need_recover(sample):
    """Return True if any stored generation of ``sample`` should be redone.

    A generation counts as broken when it starts with a code fence or
    contains the text "SyntaxError".
    """
    return any(
        text.startswith("```") or "SyntaxError" in text
        for text in sample["generation"]
    )


def main(args, samples, idx):
    """Solve ``samples`` with ToRA-style generation and write one JSON record each.

    Args:
        args: Parsed CLI namespace (see ``parse_args``).
        samples: List of sample dicts; each needs "question" and "answer" keys.
        idx: Shard index when running inside a worker pool (appended to the
            output file name as ``_{idx}``), or -1 for a single process.
    """
    prompt = load_prompt(args.prompt_path)
    if idx <= 0:
        print(prompt)
    if args.output_path is None:
        output_path = f"result/{args.model}/{args.dataset}/t{args.temperature}_n{args.num_seqs}-train_specific.jsonl"
    else:
        output_path = args.output_path
    if idx != -1:
        output_path = output_path.replace(".jsonl", f"_{idx}.jsonl")
    # Create the directory of the file actually written, including the
    # parent of a custom --output_path.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    print("%" * 30, "Tora", "%" * 30)
    print("Start PID %d and save to %s" % (os.getpid(), output_path))

    scores = []
    if args.do_recover:
        # output_path is truncated below, so back up the previous run first.
        ref_samples = load_jsonl_ml(output_path)
        cache_path = output_path.replace(".jsonl", ".cache.jsonl")
        with open(cache_path, "w", encoding="utf-8") as f:
            for s in ref_samples:
                f.write(json.dumps(s, ensure_ascii=False, indent=4) + "\n")
    else:
        ref_samples = []
    llm = LLM()
    batch_size = args.batch_size
    if samples:
        print(samples[0]["question"])
    # NOTE(review): --num_skips only switches the file to append mode here;
    # no samples are actually skipped.
    mode = "w" if args.num_skips == 0 else "a"
    with open(output_path, mode, encoding="utf-8") as f:
        for i in tqdm(range(0, len(samples), batch_size)):
            batch_samples = samples[i : i + batch_size]
            run_samples, run_ids = [], []
            batch_save_samples = [{} for _ in batch_samples]
            print(len(ref_samples))
            for j, sample in enumerate(batch_samples):
                if i + j < len(ref_samples):
                    if not need_recover(ref_samples[i + j]):
                        # Previous result looks fine: copy it through as-is.
                        batch_save_samples[j] = ref_samples[i + j]
                        continue
                    print("Need", j)
                run_samples.append(sample)
                run_ids.append(j)
            print(run_ids)
            batch_messages = [
                [
                    {
                        "role": "system",
                        "content": "You are a helpful expert for math problem solving.",
                    },
                    {
                        "role": "user",
                        "content": prompt.replace("{question}", s["question"]),
                    },
                ]
                for s in run_samples
            ]
            batch_outputs = batch_tora(llm, batch_messages)
            for s, outputs, run_id in zip(run_samples, batch_outputs, run_ids):
                solutions = [o.strip() for o in outputs]
                pred_anss = [extract_answer_math(sol) for sol in solutions]
                label_ans = s["answer"]
                # 1 if any sampled answer matches the label, else 0.
                score = int(any(compare_ans(p, label_ans) for p in pred_anss))
                scores.append(score)
                save_sample = dict(s)  # do not mutate the caller's sample
                save_sample["generation"] = solutions
                save_sample["strategy"] = []
                save_sample["pred_answers"] = pred_anss
                save_sample["label_answer"] = label_ans
                save_sample["score"] = score
                batch_save_samples[run_id] = save_sample
            print("save", len(batch_save_samples))
            for save_sample in batch_save_samples:
                f.write(json.dumps(save_sample, ensure_ascii=False, indent=4) + "\n")
                f.flush()
    # Only samples generated in this run are scored; records copied from a
    # previous run (--do_recover) are not counted.
    if scores:
        print(f"Accuracy - {sum(scores) / len(scores)}")
    else:
        print("Accuracy - n/a (no samples were generated in this run)")


if __name__ == "__main__":
    args = parse_args()
    samples = load_data(args.dataset, args.data_path)
    if args.num_process == 1:
        main(args, samples, idx=-1)
    else:
        # Contiguous shards of equal size; the last shard also absorbs the
        # remainder of the integer division.
        num_each_split = len(samples) // args.num_process
        p = mp.Pool(args.num_process)
        async_results = []
        for idx in range(args.num_process):
            start = idx * num_each_split
            if idx == args.num_process - 1:
                end = len(samples)
            else:
                end = (idx + 1) * num_each_split
            split_data = samples[start:end]
            async_results.append(p.apply_async(main, args=(args, split_data, idx)))
        p.close()
        p.join()
        # apply_async drops worker exceptions unless .get() is called; surface
        # them so a crashed shard does not go unnoticed.
        failed = []
        for idx, result in enumerate(async_results):
            try:
                result.get()
            except Exception as err:
                print(f"Worker {idx} failed: {err!r}")
                failed.append(idx)
        print("All of the child processes over!")
        if failed:
            raise SystemExit(f"{len(failed)} worker(s) failed: {failed}")
