import ast
import json
import os
import re

# A plain number written with thousands separators, e.g. "1,000" or "-12,345.5".
_THOUSANDS_NUMBER = re.compile(r"[-+]?\d{1,3}(?:,\d{3})+(?:\.\d*)?")


def _eval_arithmetic(node):
    """Evaluate an AST made only of int/float literals, unary +/- and + - * /.

    Raises:
        ValueError: for any other node (names, calls, attributes, ...), so
            arbitrary code in the input can never run.
        ZeroDivisionError: on division by zero.
    """
    # type() rather than isinstance() so that bools are rejected.
    if isinstance(node, ast.Constant) and type(node.value) in (int, float):
        return node.value
    if isinstance(node, ast.UnaryOp):
        if isinstance(node.op, ast.UAdd):
            return +_eval_arithmetic(node.operand)
        if isinstance(node.op, ast.USub):
            return -_eval_arithmetic(node.operand)
    if isinstance(node, ast.BinOp):
        left = _eval_arithmetic(node.left)
        right = _eval_arithmetic(node.right)
        if isinstance(node.op, ast.Add):
            return left + right
        if isinstance(node.op, ast.Sub):
            return left - right
        if isinstance(node.op, ast.Mult):
            return left * right
        if isinstance(node.op, ast.Div):
            return left / right
    raise ValueError(f"unsupported expression node: {type(node).__name__}")


def _parse_number(value):
    """Convert *value* to a float, or return None if it is not numeric.

    Accepts ints/floats and strings holding either a plain numeral
    ("18", "-2.5", "0123", "1,000") or simple arithmetic over numeric
    literals ("3/4", "2+3"). Everything else (words, units, empty strings,
    None, bools, division by zero) yields None.
    """
    if isinstance(value, bool):
        return None
    if isinstance(value, (int, float)):
        try:
            return float(value)
        except OverflowError:
            return None
    text = str(value).strip()
    if _THOUSANDS_NUMBER.fullmatch(text):
        text = text.replace(",", "")
    try:
        return float(text)
    except ValueError:
        pass
    # Labels come from model generations: parse with ast instead of eval so
    # nothing in them can execute.
    try:
        return float(_eval_arithmetic(ast.parse(text, mode="eval").body))
    except (SyntaxError, ValueError, ZeroDivisionError, OverflowError, RecursionError):
        return None


def equals(pred, label, tol=1.0):
    """Return True when *pred* and *label* are numbers at most *tol* apart.

    Each side is parsed independently (see ``_parse_number``), so a compound
    label such as "2+3" is compared as 5 rather than being spliced into
    ``pred - 2 + 3``. If either side does not parse as a number, the result
    is False; no exception propagates to the caller.

    Args:
        pred: Predicted answer (string or number).
        label: Reference answer (string or number).
        tol: Maximum allowed absolute difference. Defaults to 1.0
            (deliberately loosened from 1e-5).

    Returns:
        bool: whether both values parsed and ``abs(pred - label) <= tol``.
    """
    pred_value = _parse_number(pred)
    label_value = _parse_number(label)
    if pred_value is None or label_value is None:
        return False
    return abs(pred_value - label_value) <= tol

def normalize_solution(solution, answer, ignore=False):
    """Rewrite *solution* so it ends with a canonical "The answer is <number>".

    "####" markers and a lowercase "the answer is" are both rewritten to
    "The answer is". The text after the first such marker is scanned for
    numbers (optional sign, optional thousands commas and decimal part). The
    last one found, with commas removed, becomes the final answer.

    Args:
        solution: Raw solution text, or None.
        answer: Reference answer; compared with the extracted number via
            ``equals`` (which is always called once a number is found).
        ignore: When True, accept the extracted number even if it does not
            match *answer*.

    Returns:
        (ok, text). If ok is True, text is the reasoning followed by
        "The answer is <number>". Otherwise ok is False and text is the
        marker-normalized input (None when *solution* is None).
    """
    if solution is None:
        return False, None

    marker = "The answer is"
    normalized = solution.replace('####', marker).replace("the answer is", marker)
    if marker not in normalized:
        return False, normalized

    reasoning, tail = normalized.split(marker, 1)
    numbers = re.findall(r"[-+]?\d{1,3}(?:,\d{3})*\.?\d*", tail)
    if not numbers:
        return False, normalized

    final = numbers[-1].replace(",", "")
    matched = equals(final, answer)
    if not (matched or ignore):
        return False, normalized
    return True, f"{reasoning}The answer is {final}"


def main(start_iter, stop_iter, shuffle=False, show_examples=False, ignore=False,
         input_dir="./data/train", output_path="../data/gsm8k_train.json",
         include_original=False):
    """Merge generated problem/solution pairs from per-problem JSON files.

    Every ``*.json`` file in *input_dir* is read in sorted filename order, so
    the output is reproducible. For each file, ``generation_<i>`` entries for
    i in ``range(start_iter, stop_iter)`` are visited until the first missing
    one. Each entry's "informal problem", "formal problem", "solution" and
    "answer" fields are read, and a missing field defaults to "".

    A generation is counted as invalid when its informal problem or solution
    is empty or null, or when the informal problem mentions "SMT-LIB". A valid
    generation is kept when ``normalize_solution`` accepts it (or *ignore* is
    set). Otherwise it is counted as incorrect.

    Args:
        start_iter: First generation index to read.
        stop_iter: One past the last generation index to read.
        shuffle: Shuffle the merged samples before writing.
        show_examples: Print a few invalid and incorrect examples.
        ignore: Keep every valid generation regardless of answer correctness.
        input_dir: Directory containing the per-problem JSON files.
        output_path: Where the merged list of {"question", "answer"} is written.
        include_original: Also emit each file's own "problem"/"solution"
            ("####" rewritten to "The answer is"). The earlier test-split
            variant of this script did this, reading "./data/test" and
            writing "../data/gsm8k_test.json".
    """
    merged_data = []
    invalid_data = []
    incorrect_data = []
    # os.listdir order is arbitrary; sort for deterministic output.
    for filename in sorted(os.listdir(input_dir)):
        if not filename.endswith('.json'):
            continue
        with open(os.path.join(input_dir, filename), 'r', encoding='utf-8') as f:
            data = json.load(f)
        if include_original:
            merged_data.append({"question": data["problem"],
                                "answer": data["solution"].replace('####', 'The answer is')})
        for i in range(start_iter, stop_iter):
            key = f"generation_{i}"
            if key not in data:
                break  # generations are assumed contiguous
            generation = data[key]
            informal_problem = generation.get("informal problem", "")
            formal_problem = generation.get("formal problem", "")
            solution = generation.get("solution", "")
            answer = generation.get("answer", "")
            # `not x` treats JSON null the same as "" (and avoids `"SMT-LIB" in None`).
            if not informal_problem or "SMT-LIB" in informal_problem or not solution:
                invalid_data.append({"question": informal_problem, "solution": solution, "answer": answer})
                continue
            # Forward `ignore` so accepted-but-mismatched answers are still
            # canonicalized to "... The answer is <number>".
            correct, normalize_answer = normalize_solution(solution, answer, ignore=ignore)
            if correct or ignore:  # ignore means do not use any criteria
                merged_data.append({"question": informal_problem, "answer": normalize_answer})
            else:
                incorrect_data.append({"question": informal_problem, "smt": formal_problem,
                                       "solution": solution, "answer": answer})

    print('invalid data=====', len(invalid_data))
    print('incorrect data=====', len(incorrect_data))
    print('total merged data=====', len(merged_data))
    if show_examples:
        print("examples of invalid data:", "="*100)
        for example in invalid_data[:5]:
            print(f"Question: {example['question']} \nSolution: {example['solution']} \nAnswer: {example['answer']}")
        print("examples of incorrect data:", "="*100)
        for example in incorrect_data[:100]:
            print(f"SMT: {example['smt']}\nQuestion: {example['question']}\nSolution: {example['solution']} \nAnswer: {example['answer']}")

    if shuffle:
        import random
        random.shuffle(merged_data)
    with open(output_path, 'w', encoding='utf-8') as output_file:
        json.dump(merged_data, output_file, indent=2)


if __name__ == "__main__":
    # Merge generations 50..64 from every problem file, keeping each valid
    # generation whether or not its answer matches (ignore=True).
    main(
        start_iter=50,
        stop_iter=65,
        shuffle=False,
        show_examples=False,
        ignore=True,
    )