import json
import os
import pprint
from datetime import datetime
from time import sleep
from tqdm import tqdm
import argparse
from collections import Counter
from distutils.util import strtobool
import multiprocessing as mp


from src.evol.data_utils import load_data
from src.evol.openai_backend import call_chatgpt
from src.utils.data_utils import extract_answer_math
from src.utils.code_utils import execute_tora
from src.utils.math_utils import compare_ans, vote


def parse_args():
    """Build the command-line parser and return the parsed namespace.

    Every option except the boolean ``--verbose`` switch takes a typed value
    with a default; they are declared in a table so that flag, default and
    type sit side by side.
    """
    # (flag, default, type) for each value-taking option, in help order.
    valued_options = (
        ("--prompt_path", None, str),
        ("--dataset", "gsm", str),
        ("--data_path", None, str),
        ("--model", "gpt-3.5-turbo", str),
        ("--temperature", 0.0, float),
        ("--top_p", 1.0, float),
        ("--max_tokens", 1024, int),
        ("--num_seqs", 1, int),
        ("--num_skips", 0, int),
        ("--input_col", "question", str),
        ("--output_col", "answer", str),
        ("--max_iter", 3, int),
        ("--num_process", 1, int),
        ("--output_path", None, str),
    )

    cli = argparse.ArgumentParser()
    cli.add_argument("--verbose", action="store_true")
    for flag, fallback, caster in valued_options:
        cli.add_argument(flag, default=fallback, type=caster)
    return cli.parse_args()


def load_prompt(prompt_path):
    """Read the prompt file as UTF-8 and return it ending in one blank line.

    Leading and trailing whitespace of the file content is stripped, then
    two newline characters are appended.
    """
    with open(prompt_path, "r", encoding="utf-8") as handle:
        template = handle.read()
    return f"{template.strip()}\n\n"


def stop_tora(result):
    """Return True when ``result`` contains the literal text ``\\boxed``."""
    return "\\boxed" in result


def tora(args, prompt, x, temperature=0.0):
    """Run the generate/execute loop for one sample and return the transcript.

    The ``{question}`` placeholder in ``prompt`` is replaced by
    ``x[args.input_col]``. The model is then called at most
    ``args.max_iter + 1`` times. After each call the output is appended to
    both the running prompt and the transcript. The loop ends as soon as an
    output satisfies ``stop_tora`` or the final call has been made; otherwise
    the output is passed to ``execute_tora`` and, unless the returned code
    output is an empty string, a fenced ``output`` block is appended as well.

    Args:
        args: Namespace providing ``input_col``, ``max_iter``, ``model``,
            ``max_tokens`` and ``num_seqs``.
        prompt: Prompt template containing ``{question}``.
        x: Sample indexed by ``args.input_col``.
        temperature: Sampling temperature forwarded to ``call_chatgpt``.

    Returns:
        The concatenation of all model outputs and appended output blocks.
    """
    question = x[args.input_col]
    running_prompt = prompt.replace("{question}", question)
    transcript = ""
    final_round = args.max_iter + 1

    for round_no in range(1, final_round + 1):
        conversation = [
            {
                "role": "system",
                "content": "You are a helpful expert for math problem solving.",
            },
            {"role": "user", "content": running_prompt},
        ]
        completions = call_chatgpt(
            messages=conversation,
            model=args.model,
            stop=["```output", "---"],
            max_tokens=args.max_tokens,
            temperature=temperature,
            num_beams=args.num_seqs,
        )
        # Only the first returned completion is used.
        output = completions[0]
        print(output)
        running_prompt += output
        transcript += output

        if stop_tora(output) or round_no == final_round:
            break

        code_output, code_report = execute_tora(output)
        print("code output", code_output)
        if code_output != "":
            # A None code output falls back to the report returned alongside it.
            shown = code_report if code_output is None else code_output
            output_block = f"```output\n{shown}\n```\n"
            running_prompt += output_block
            transcript += output_block

    return transcript


def main(args, samples, idx):
    """Generate solutions for ``samples`` and write one JSON record per sample.

    For each sample, ``tora`` is called up to 10 times. The sampling
    temperature starts at ``args.temperature`` and rises by 0.1 per attempt.
    Retrying stops at the first attempt whose extracted answer scores 1
    against the label. All attempts' outputs and extracted answers are stored
    in the record; ``pred_answer`` and ``score`` come from the last attempt.

    Args:
        args: Parsed command-line namespace (see ``parse_args``).
        samples: Sliceable sequence of dict samples, each indexed by
            ``args.input_col`` and ``args.output_col``. The first
            ``args.num_skips`` entries are skipped.
        idx: Worker index appended to the output file name, or ``-1`` for a
            single-process run (no suffix). The prompt is printed only when
            ``idx <= 0``.
    """
    # Hard cap on sampling attempts per sample.
    max_attempts = 10

    # load prompt
    prompt = load_prompt(args.prompt_path)
    if idx <= 0:
        print(prompt)

    if args.output_path is None:
        output_path = f"result/{args.model}/{args.dataset}/t{args.temperature}_n{args.num_seqs}-train.jsonl"
    else:
        output_path = args.output_path
    if idx != -1:
        # Insert the worker index before the extension so that every worker
        # gets its own file even when the path does not end in ".jsonl".
        root, ext = os.path.splitext(output_path)
        output_path = f"{root}_{idx}{ext}"
    # Create the directory of the file actually written (default or custom).
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    print("%" * 30, "Tora", "%" * 30)
    # Logged after the suffix is applied so the message names the real file.
    print("Start PID %d and save to %s" % (os.getpid(), output_path))

    scores = []
    # Resume runs (num_skips > 0) append to the existing file instead of
    # truncating it. UTF-8 is explicit because ensure_ascii=False writes
    # non-ASCII characters verbatim.
    mode = "w" if args.num_skips == 0 else "a"
    with open(output_path, mode, encoding="utf-8") as f:
        pbar = tqdm(
            samples[args.num_skips :], initial=args.num_skips, total=len(samples)
        )
        for x in pbar:
            label_ans = x[args.output_col]
            batch_outputs = []
            batch_pred_anss = []
            for attempt in range(max_attempts):
                full_output = tora(
                    args, prompt, x, temperature=args.temperature + attempt * 0.1
                )
                pred_ans = extract_answer_math(full_output)
                score = int(compare_ans(pred_ans, label_ans))
                batch_outputs.append(full_output)
                batch_pred_anss.append(pred_ans)
                if score == 1:
                    break
            scores.append(score)

            # Copy so the caller's sample dict is not modified.
            save_sample = dict(x)
            save_sample["generation"] = batch_outputs
            save_sample["pred_answers"] = batch_pred_anss
            save_sample["pred_answer"] = pred_ans
            save_sample["label_answer"] = label_ans
            save_sample["score"] = score
            # NOTE(review): indent=4 makes each record span several lines, so
            # the file is not one-object-per-line despite its .jsonl name.
            # Kept as-is to preserve the existing on-disk format.
            f.write(json.dumps(save_sample, ensure_ascii=False, indent=4) + "\n")
            f.flush()

    # An empty shard, or num_skips >= len(samples), leaves nothing to average.
    if scores:
        print(f"Accuracy - {sum(scores) / len(scores)}")
    else:
        print("Accuracy - n/a (no samples were processed)")


if __name__ == "__main__":
    # Script entry point: run in-process for a single worker, otherwise
    # shard the samples across a process pool with one ``main`` call each.
    args = parse_args()
    samples = load_data(args.dataset, args.data_path)
    if args.num_process == 1:
        main(args, samples, idx=-1)
    else:
        num_each_split = len(samples) // args.num_process
        p = mp.Pool(args.num_process)
        pending = []
        for idx in range(args.num_process):
            start = idx * num_each_split
            # The last worker also takes the remainder of the division.
            if idx == args.num_process - 1:
                end = len(samples)
            else:
                end = start + num_each_split
            split_data = samples[start:end]
            async_result = p.apply_async(
                main,
                args=(
                    args,
                    split_data,
                    idx,
                ),
            )
            pending.append((idx, async_result))
        p.close()
        p.join()
        # apply_async stores a worker's exception in its result object;
        # fetch each one so failures are reported instead of vanishing.
        for idx, async_result in pending:
            try:
                async_result.get()
            except Exception as exc:
                print(f"Worker {idx} failed: {exc!r}")
        print("All of the child processes over!")
