import json
import os
import pprint
from datetime import datetime
from time import sleep
from tqdm import tqdm
import argparse
from collections import Counter
from distutils.util import strtobool

from src.evol.data_utils import load_data
from src.evol.openai_backend import call_chatgpt
from src.utils.data_utils import extract_answer_math
from src.utils.code_utils import execute_tora
from src.utils.math_utils import compare_ans


def parse_args():
    """Build the command-line interface and return the parsed options.

    Every option except ``--verbose`` (a boolean flag) takes a single value
    with the default and type listed in ``option_specs`` below. Options are
    registered in a fixed order, so ``--help`` lists them in that order.

    Returns:
        argparse.Namespace with one attribute per option.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--verbose", action="store_true")
    # (flag, default, type) for every value-taking option, in registration order.
    option_specs = [
        ("--prompt_path", None, str),
        ("--dataset", "gsm", str),
        ("--data_path", None, str),
        ("--model", "gpt-3.5-turbo", str),
        ("--temperature", 0.0, float),
        ("--top_p", 1.0, float),
        ("--max_tokens", 1024, int),
        ("--num_seqs", 1, int),
        ("--num_skips", 0, int),
        ("--input_col", "question", str),
        ("--output_col", "answer", str),
        ("--max_iter", 3, int),
    ]
    for flag, default_value, value_type in option_specs:
        cli.add_argument(flag, default=default_value, type=value_type)
    return cli.parse_args()

def load_prompt(prompt_path):
    """Read a prompt template from disk.

    The file is decoded as UTF-8. Leading and trailing whitespace is removed,
    and exactly two newlines are appended so the caller can concatenate text
    after the template.
    """
    with open(prompt_path, encoding="utf-8") as handle:
        text = handle.read()
    return f"{text.strip()}\n\n"

def stop_tora(result):
    """Return True when ``result`` contains the literal ``\\boxed`` marker.

    The caller uses this to decide that the model has produced a final boxed
    answer and the reasoning loop can end.
    """
    return "\\boxed" in result

def _run_tora(tora_prompt, args):
    """Run the model / code-execution loop for one filled prompt; return the transcript.

    Calls the model, appends its output to the prompt, and executes the emitted
    code with ``execute_tora``. A non-empty execution result is fed back as an
    ```output block. The loop ends when a model output contains ``\\boxed``, or
    after ``args.max_iter + 1`` model calls. No code is executed after the
    final call.

    Returns:
        The concatenation of every model output and every injected ```output
        block, in order.
    """
    full_output = ""
    for itr in range(1, args.max_iter + 2):
        # NOTE(review): only the first returned completion ([0]) is used even though
        # num_seqs is forwarded as num_beams; args.top_p is parsed but not forwarded.
        output = call_chatgpt(
            messages=[
                {"role": "system", "content": "You are a helpful expert for math problem solving."},
                {"role": "user", "content": tora_prompt},
            ],
            stop=["```output", "---"],
            max_tokens=args.max_tokens,
            temperature=args.temperature,
            num_beams=args.num_seqs,
        )[0]
        print(output)
        tora_prompt += output
        full_output += output
        # Stop on a boxed answer, or on the last allowed call: any code it emits
        # would have no further model call to consume the execution result.
        if stop_tora(output) or itr == args.max_iter + 1:
            break
        code_output, code_report = execute_tora(output)
        print("code output", code_output)
        if code_output != "":
            # When execution yields no output (None), show the report instead
            # (presumably an error message -- confirm against execute_tora).
            code_output = f"```output\n{code_output if code_output is not None else code_report}\n```\n"
            tora_prompt += code_output
            full_output += code_output
    return full_output


def main(args):
    """Evaluate the model with tool-integrated reasoning and print its accuracy.

    Samples come from ``load_data(args.dataset, args.data_path)``, starting at
    index ``args.num_skips``. For each sample:

    * ``{question}`` in the prompt template is replaced with
      ``x[args.input_col]`` and the reasoning loop is run;
    * the extracted answer is scored against ``x[args.output_col]`` with
      ``compare_ans``;
    * the sample, with ``generation``, ``pred_answer``, ``label_answer`` and
      ``score`` added, is written as one JSON line to
      ``result/<model>/<dataset>/t<temperature>_n<num_seqs>.jsonl``.

    The output file is truncated when ``num_skips == 0`` and appended to
    otherwise, so an interrupted run can be resumed. The printed accuracy
    covers only the samples processed in this run.
    """
    # load prompt
    prompt = load_prompt(args.prompt_path)
    print(prompt)
    print("%" * 30, "CoT", "%" * 30)

    output_dir = f"result/{args.model}/{args.dataset}"
    os.makedirs(output_dir, exist_ok=True)
    output_path = f"{output_dir}/t{args.temperature}_n{args.num_seqs}.jsonl"
    samples = load_data(args.dataset, args.data_path)
    scores = []

    mode = "w" if args.num_skips == 0 else "a"
    # Explicit UTF-8: ensure_ascii=False below can emit non-ASCII characters,
    # which would fail under a non-UTF-8 platform default encoding.
    with open(output_path, mode, encoding="utf-8") as f:
        pbar = tqdm(
            samples[args.num_skips:], initial=args.num_skips, total=len(samples)
        )
        for x in pbar:
            tora_prompt = prompt.replace("{question}", x[args.input_col])
            full_output = _run_tora(tora_prompt, args)
            pred_ans = extract_answer_math(full_output)
            label_ans = x[args.output_col]
            score = int(compare_ans(pred_ans, label_ans))
            scores.append(score)
            # Copy so the loaded sample object is not mutated in place.
            save_sample = dict(x)
            save_sample["generation"] = [full_output]
            save_sample["pred_answer"] = pred_ans
            save_sample["label_answer"] = label_ans
            save_sample["score"] = score
            # One compact object per line. Indentation would spread a record
            # across lines and break the .jsonl format.
            f.write(json.dumps(save_sample, ensure_ascii=False) + "\n")
            f.flush()

    if scores:
        print(f"Accuracy - {sum(scores) / len(scores)}")
    else:
        # Empty dataset or num_skips >= len(samples): nothing was evaluated.
        print("Accuracy - N/A (no samples evaluated)")

if __name__ == "__main__":
    # Script entry point: parse the command-line flags once and run the evaluation.
    args = parse_args()
    main(args)