import json
import random
import time
import os
import re
import math
from fractions import Fraction
from typing import Iterable, Union, Sequence, List, Dict, Any, Tuple
from numbers import Real, Integral
from math import ceil
from numbers import Number
from decimal import Decimal, ROUND_HALF_UP, ROUND_FLOOR, ROUND_UP, ROUND_DOWN
from math import sqrt, isclose
import numbers
from itertools import permutations
from functools import reduce
import argparse
import calendar
import datetime


# Spelled-out quantities that may appear in a question in place of digits.
_WORD_TO_NUMBER = {
    'one': 1,
    'two': 2,
    'three': 3,
    'four': 4,
    'five': 5,
    'six': 6,
    'seven': 7,
    'eight': 8,
    'nine': 9,
    'ten': 10,
    'eleven': 11,
    'twelve': 12,
    'thirteen': 13,
    'fourteen': 14,
    'fifteen': 15,
    'sixteen': 16,
    'seventeen': 17,
    'eighteen': 18,
    'nineteen': 19,
    'twenty': 20,
    'thirty': 30,
    'forty': 40,
    'fifty': 50,
    'sixty': 60,
    'seventy': 70,
    'eighty': 80,
    'ninety': 90,
    'twice': 2,
    'half': 0.5,
}


def find_number_in_question(question, modified_question):
    """Extract the number that ``modified_question`` replaced with ``<variable>``.

    ``question`` is the original problem text and ``modified_question`` is the
    same text with a number replaced by the ``<variable>`` placeholder. The
    common prefix and suffix of both strings are removed one character at a
    time (re-stripping whitespace after each step), leaving the differing span
    of ``question``, which is then parsed.

    Args:
        question: Original question text.
        modified_question: Question text containing ``<variable>``.

    Returns:
        An ``int`` or ``float`` parsed from the span (after removing ``$``,
        ``,`` and ``%``); a two-part ``a/b`` span is returned as ``a / b``;
        otherwise the value of a known number word (e.g. ``"two"`` -> 2,
        ``"half"`` -> 0.5).

    Raises:
        KeyError: If the span is neither numeric nor a known number word,
            including when the two questions do not differ at all.
    """
    # Strip the shared prefix. The emptiness guards stop the loop once either
    # string is exhausted (e.g. identical inputs) instead of raising IndexError.
    while question and modified_question and question[0] == modified_question[0]:
        question = question[1:].strip()
        modified_question = modified_question[1:].strip()

    # Strip the shared suffix the same way.
    while question and modified_question and question[-1] == modified_question[-1]:
        question = question[:-1].strip()
        modified_question = modified_question[:-1].strip()

    # Drop leftover leading/trailing words (e.g. "5 apples" -> "5"), but keep
    # the original span when nothing but letters remains (e.g. "two").
    question_clear = re.sub(r'^[a-zA-Z\s]+', '', question).strip()
    question_clear = re.sub(r'[a-zA-Z\s]+$', '', question_clear).strip()
    if question_clear:
        question = question_clear

    # With several placeholders left in the differing span, only the last
    # space-separated token of the original span is used.
    if modified_question.count('<variable>') >= 2:
        question = question.split(' ')[-1]

    question = question.replace('$', '').replace(',', '').replace('%', '')

    if '/' in question:
        # Use a separate name so `question` stays a string for the fallbacks
        # below (rebinding it to a list made them fail with AttributeError).
        parts = question.split('/')
        if len(parts) == 2:
            try:
                return float(parts[0]) / float(parts[1])
            except (ValueError, ZeroDivisionError):
                print(f"Error in question: {parts}")

    try:
        return int(question)
    except ValueError:
        pass
    try:
        return float(question)
    except ValueError:
        pass
    return _WORD_TO_NUMBER[question.lower()]

def get_final_answer(code_str):
    """Execute solution code and return the value it assigns to ``previous_answer``.

    WARNING: this is NOT a sandbox. The namespace below exposes the full
    ``__builtins__`` (so ``open``, ``__import__`` etc. are available); only pass
    code from a source you trust.

    Args:
        code_str: Python source expected to assign ``previous_answer``.

    Returns:
        ``(True, float(previous_answer))`` on success; otherwise
        ``(False, message)`` where ``message`` reports a missing variable or is
        ``str()`` of the exception raised while executing or converting.
    """
    # One dict serves as both globals and locals. With two separate dicts,
    # exec() runs the code as if inside a class body, so functions defined in
    # code_str could not see other top-level functions/imports from code_str.
    namespace = {
        "__builtins__": __builtins__,
        "math": math,
        "Fraction": Fraction,
        "Iterable": Iterable,
        "Union": Union,
        "Sequence": Sequence,
        "Real": Real,
        "Integral": Integral,
        "Number": Number,
        "ceil": ceil,
        "Decimal": Decimal,
        "ROUND_HALF_UP": ROUND_HALF_UP,
        "sqrt": sqrt,
        "isclose": isclose,
        "numbers": numbers,
        "List": List,
        "Dict": Dict,
        "Any": Any,
        "permutations": permutations,
        "Tuple": Tuple,
        "ROUND_FLOOR": ROUND_FLOOR,
        "ROUND_UP": ROUND_UP,
        "ROUND_DOWN": ROUND_DOWN,
        "reduce": reduce,
        "re": re,
        "calendar": calendar,
        "datetime": datetime,
    }

    variable_name = 'previous_answer'

    try:
        exec(code_str, namespace)
        if variable_name not in namespace:
            return False, f"Variable '{variable_name}' not found"
        # float() stays inside the try so non-numeric results are reported,
        # not raised.
        return True, float(namespace[variable_name])
    except Exception as e:
        return False, str(e)

def _is_acceptable_answer(answer, only_int_answers, only_small_int_answers, tol=1e-10):
    """Return whether a computed final answer may be written to the dataset.

    Non-finite answers (inf/NaN) are always rejected: they cannot be turned
    into an ``int`` and ``json.dump`` would write non-standard ``NaN`` /
    ``Infinity`` tokens.
    """
    if not math.isfinite(answer):
        return False
    # Compare with the nearest integer rather than int(), which truncates and
    # would reject float noise just below an integer (e.g. 434.99999999999994).
    if only_int_answers and abs(answer - round(answer)) > tol:
        return False
    if only_small_int_answers and not -1000 <= answer <= 1000:
        return False
    return True


def main():
    """Build a dataset of chained math problems and append it as JSON lines.

    Problems are taken in shuffled groups of ``--num_subproblems``. The first
    problem of a group is asked verbatim; each following problem is its
    ``modified_question`` with ``<variable>`` replaced by a letter standing for
    the previous answer. The entries' ``code_clean``/``function_call`` snippets
    are chained into one program whose ``previous_answer`` becomes the record's
    ``final_answer``. Groups whose program fails, or whose answer is rejected by
    the integer filters, are re-queued and reshuffled for up to
    ``--num_rounds`` rounds. Accepted ``{"prompt", "python_code",
    "final_answer"}`` records are appended to ``--output_dataset``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="datasets/gsm8k_test_with_code_with_variables_in_questions_200_new.jsonl")
    # Both were optional before, but a missing value crashed the run
    # (--output_dataset only after all the work had been done).
    parser.add_argument("--output_dataset", type=str, required=True)
    parser.add_argument("--num_subproblems", type=int, required=True)
    parser.add_argument("--num_repetitions", type=int, default=1)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--only_int_answers", action="store_true", default=False)
    parser.add_argument("--only_small_int_answers", action="store_true", default=False)
    parser.add_argument("--prompt_version", type=int, default=0)
    # Number of shuffle-and-retry passes over the still-unused problems.
    parser.add_argument("--num_rounds", type=int, default=10)
    args = parser.parse_args()

    if args.num_subproblems < 1:
        parser.error("--num_subproblems must be at least 1")

    if args.only_small_int_answers:
        args.only_int_answers = True

    # Version 1 is only defined for exactly two subproblems; reject anything
    # else before doing any work.
    if not (args.prompt_version == 0 or (args.prompt_version == 1 and args.num_subproblems == 2)):
        print(f"Incorrect prompt version specified: {args.prompt_version} (number of subproblems: {args.num_subproblems})")
        return

    with open(args.dataset, 'r', encoding='utf-8') as infile:
        data = [json.loads(line) for line in infile if line.strip()]

    random.seed(args.seed)

    num_subproblems = args.num_subproblems
    cnt_errors_in_find_number_in_question = 0
    cnt_results = 0

    for _repetition in range(args.num_repetitions):

        result = []
        prompt = ""
        python_code = ""
        variable_names = ['X', 'Y', 'Z', 'U', 'V', 'W']
        new_available_indexes = list(range(len(data)))

        for _round in range(args.num_rounds):

            available_indexes = new_available_indexes
            random.shuffle(available_indexes)
            new_available_indexes = []

            print(f"available_indexes: {len(available_indexes)}")

            # The extra final iteration (entry_id == len) carries no entry; it
            # only flushes the last complete group or re-queues a partial one.
            for entry_id in range(len(available_indexes) + 1):

                entry = data[available_indexes[entry_id]] if entry_id < len(available_indexes) else None

                if entry is not None:
                    if entry['function_call'] is None:
                        print(entry)
                        raise ValueError("Function call is None")
                    if entry['modified_question'] is None:
                        print(entry)
                        raise ValueError("Modified question is None")

                position = entry_id % num_subproblems
                step_id = position * 2

                if position == 0:

                    random.shuffle(variable_names)

                    # Evaluate the group that just ended.
                    if entry_id > 0:
                        ok, final_answer = get_final_answer(python_code)
                        if not ok or not _is_acceptable_answer(final_answer, args.only_int_answers, args.only_small_int_answers):
                            print(f"Error: {final_answer}")
                            print("================================================")
                            # Re-queue the whole failed group for the next round.
                            new_available_indexes.extend(available_indexes[entry_id - num_subproblems:entry_id])
                        else:
                            if args.only_int_answers:
                                final_answer = round(final_answer)
                            if args.prompt_version == 0:
                                prompt = prompt[:-1] + " In the end, provide only the final numerical answer."
                            result.append({"prompt": prompt, "python_code": python_code, "final_answer": final_answer})

                    if entry_id == len(available_indexes):
                        break

                    if args.prompt_version == 0:
                        prompt = f"""
Step 1: Solve the following math problem step by step:

{entry['question']}
"""
                    else:
                        prompt = f"""
You have two math problems to solve. Solve the first math problem step by step. Take your final answer and substitute it for {variable_names[0]} in the second math problem. Then solve the updated version of the second math problem step by step. Do not use {variable_names[0]} in your calculations, use the exact numerical value instead. In the end, provide your final answer to the second math problem.

Problem 1: {entry['question']}
"""

                    try:
                        x = find_number_in_question(entry['question'].strip(), entry['modified_question'].strip())
                    except Exception:
                        cnt_errors_in_find_number_in_question += 1
                        # Make this group's program fail at the next group
                        # boundary so it is re-queued. An explicit raise (unlike
                        # `assert False`) survives `python -O`, and resetting the
                        # code drops the previous, already-evaluated group.
                        python_code = "raise ValueError('could not find the substituted number in the first question')\n"
                        continue

                    python_code = f"""
{entry['code_clean']}

previous_answer = {entry['function_call'].replace('previous_answer', str(x))}
"""

                elif entry_id < len(available_indexes):

                    variable_name = variable_names[entry_id % len(variable_names)]

                    if args.prompt_version == 0:
                        prompt += f"""
Step {step_id}: Take your final answer from Step {step_id - 1} and substitute it for {variable_name} in the following problem:

{entry['modified_question'].replace('<variable>', variable_name)}

Write out the updated version of the problem with the number from Step {step_id - 1} in place of {variable_name}.

Step {step_id + 1}: Solve the updated problem from Step {step_id} step by step.
"""
                    else:
                        prompt += f"""
Problem 2: {entry['modified_question'].replace('<variable>', variable_names[0])}
"""

                    python_code += f"""
{entry['code_clean']}

previous_answer = {entry['function_call']}
"""

                if entry_id == len(available_indexes):
                    # Incomplete trailing group: carry its problems to the next round.
                    new_available_indexes.extend(available_indexes[entry_id // num_subproblems * num_subproblems:])

        print(f"new_available_indexes: {new_available_indexes}")

        random.shuffle(result)
        cnt_results += len(result)

        # Explicit UTF-8 because records are dumped with ensure_ascii=False.
        with open(args.output_dataset, 'a', encoding='utf-8') as outfile:
            for record in result:
                json.dump(record, outfile, ensure_ascii=False)
                outfile.write("\n")

    print(f"cnt_errors_in_find_number_in_question: {cnt_errors_in_find_number_in_question}")
    print(f"cnt_results: {cnt_results}")


if __name__ == "__main__":
    main()