import json
import os
import pprint
from datetime import datetime
from time import sleep
from tqdm import tqdm
import argparse
from collections import Counter
from distutils.util import strtobool
from code import InteractiveConsole

from src.evol.data_utils import load_data
from src.evol.openai_backend import call_chatgpt
from src.utils.data_utils import extract_answer_math
from src.utils.code_utils import execute_tora, execute_interactive
from src.utils.math_utils import compare_ans


def parse_args():
    """Build the command-line interface for the ToRA evaluation run and parse ``sys.argv``.

    Returns:
        argparse.Namespace: ``verbose`` (bool flag) plus one attribute per typed
        option listed below, each holding its parsed value or its default.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--verbose", action="store_true")
    # (flag, default, converter) for every value-taking option, in registration order.
    typed_options = (
        ("--prompt_path", None, str),
        ("--dataset", "gsm", str),
        ("--data_path", None, str),
        ("--model", "gpt-3.5-turbo", str),
        ("--temperature", 0.0, float),
        ("--top_p", 1.0, float),
        ("--max_tokens", 1024, int),
        ("--num_seqs", 1, int),
        ("--num_skips", 0, int),
        ("--input_col", "question", str),
        ("--output_col", "answer", str),
        ("--max_iter", 3, int),
    )
    for flag, default_value, converter in typed_options:
        cli.add_argument(flag, default=default_value, type=converter)
    return cli.parse_args()

def load_prompt(prompt_path):
    """Read a prompt template from ``prompt_path`` (UTF-8).

    Surrounding whitespace is stripped, and exactly two newlines are appended so the
    template ends with a blank-line separator before any appended text.
    """
    with open(prompt_path, encoding="utf-8") as handle:
        raw_text = handle.read()
    return f"{raw_text.strip()}\n\n"

def stop_tora(result):
    """Return True when a model output contains a ``\\boxed`` final answer, else False."""
    return "\\boxed" in result

def _run_tora_episode(args, tora_prompt, console):
    """Alternate model calls and code execution for one question; return the transcript.

    Each model call is stopped at "```output" or "---". The text generated so far is
    executed in ``console``, and any non-empty result is appended inside an ```output
    fence before the next call. The loop ends as soon as an output contains ``\\boxed``,
    or after ``args.max_iter + 1`` model calls. The final call is never followed by code
    execution, so at most ``args.max_iter`` executions happen.

    Args:
        args: parsed CLI namespace (reads ``max_iter``, ``max_tokens``,
            ``temperature``, ``num_seqs``).
        tora_prompt: the prompt with the question already substituted in.
        console: interpreter used to run the generated code.

    Returns:
        str: every model output and code-output fence, concatenated. The prompt
        itself is not included.
    """
    full_output = ""
    for itr in range(1, args.max_iter + 2):
        output = call_chatgpt(
            messages=[
                {"role": "system", "content": "You are a helpful expert for math problem solving."},
                {"role": "user", "content": tora_prompt},
            ],
            stop=["```output", "---"],
            max_tokens=args.max_tokens,
            temperature=args.temperature,
            num_beams=args.num_seqs,
        )[0]
        print(output)
        tora_prompt += output
        full_output += output
        if stop_tora(output) or itr == args.max_iter + 1:
            break
        code_output = execute_interactive(output, console)
        print("code output", code_output)
        if len(code_output) > 0:
            code_output = f"```output\n{code_output}\n```\n"
            tora_prompt += code_output
            full_output += code_output
    return full_output


def main(args):
    """Run the ToRA evaluation over a dataset, save per-sample results, and print accuracy.

    For each sample after the first ``args.num_skips``:
      1. Substitute ``x[args.input_col]`` into the prompt template.
      2. Solve it with ``_run_tora_episode``.
      3. Score the extracted answer against ``x[args.output_col]``.

    Each scored sample is written as one JSON object per line (JSON Lines) to
    ``result/csv_base/<model>/<dataset>/t<temperature>_n<num_seqs>.jsonl``. The file is
    truncated when ``num_skips == 0`` and appended to otherwise, so an interrupted run
    can be resumed.

    The printed accuracy covers only the samples processed in this invocation.
    """
    prompt = load_prompt(args.prompt_path)
    print(prompt)
    print("%" * 30, "Tora", "%" * 30)

    output_dir = f"result/csv_base/{args.model}/{args.dataset}"
    os.makedirs(output_dir, exist_ok=True)
    output_path = f"{output_dir}/t{args.temperature}_n{args.num_seqs}.jsonl"
    samples = load_data(args.dataset, args.data_path)
    scores = []
    # Append when resuming, so samples saved by the earlier run are kept.
    mode = "w" if args.num_skips == 0 else "a"
    # ensure_ascii=False below writes raw non-ASCII characters; pin the encoding so the
    # output does not depend on the platform's default locale encoding.
    with open(output_path, mode, encoding="utf-8") as f:
        pbar = tqdm(
            samples[args.num_skips:], initial=args.num_skips, total=len(samples)
        )
        for x in pbar:
            # A new InteractiveConsole (with its own namespace) for each question, so
            # names defined while solving one problem are not visible in the next.
            console = InteractiveConsole()
            tora_prompt = prompt.replace("{question}", x[args.input_col])
            full_output = _run_tora_episode(args, tora_prompt, console)
            pred_ans = extract_answer_math(full_output)
            label_ans = x[args.output_col]
            score = int(compare_ans(pred_ans, label_ans))
            scores.append(score)
            # Alias, not a copy: the loaded sample dict is annotated in place.
            save_sample = x
            save_sample["generation"] = [full_output]
            save_sample["pred_answer"] = pred_ans
            save_sample["label_answer"] = label_ans
            save_sample["score"] = score
            # No indent: each record must stay on a single line to be valid JSON Lines.
            f.write(json.dumps(save_sample, ensure_ascii=False) + "\n")
            f.flush()
    # Guard against division by zero when nothing was processed (empty dataset or
    # num_skips >= len(samples)).
    if scores:
        print(f"Accuracy - {sum(scores) / len(scores)}")
    else:
        print("Accuracy - n/a (no samples processed)")

if __name__ == "__main__":
    # Script entry point: parse the CLI options and run the evaluation with them.
    main(parse_args())