import os
from generation_models import message_construct_func, GPT_response, extract_and_check, extract_code
import re
import pandas as pd
import subprocess
import time
from post_process_path_plan_continuous_func import check_trajectory
import string
import json
from argparse import ArgumentParser

def is_equiv_func(target_answer, extracted_text):
    """Ask the GPT checker whether two answers hold the same numerical value.

    A fixed few-shot instruction (tolerating formatting differences and a
    relative difference below 1e-3) is followed by the pair under test, and
    the whole prompt is sent to ``GPT_response`` with the ``gpt-4o`` model.

    Args:
        target_answer: Reference answer; interpolated into the prompt as-is.
        extracted_text: Candidate answer; interpolated into the prompt as-is.

    Returns:
        Whatever ``GPT_response`` returns. The prompt asks the model to finish
        with ``<<<True>>>`` or ``<<<False>>>``.
    """
    # Raw strings keep the LaTeX-style backslashes (\text, \sqrt, ...) literal.
    instruction_fragments = (
        r'Evaluate whether the following numerical pair has the same values.',
        r'Neglect the format difference and the extra text like units and names and equations.',
        r'The value can be regarded as the same if they are < 1e-3 relative difference.',
        r'The examples are: ("12", "12.0", True), ("5*sqrt(13)", "15.97112779602377", False),',
        r'("10\text{ inches}", "10.0", True), ("42", "41.99999999999998", True), ("frac{63}{64}", "0.984375", True),',
        r'("frac{5\sqrt{5}}{3}", "5\sqrt{5}/3", True), (\tfrac34, "3/4", True), ("frac{1033}{4}+30\sqrt{3}", "169.0", False), ("AB=12+12\sqrt{3}", "12(\sqrt{3} + 1)", True),',
        r'((18, -18), (18, -18), True). ',
        r'In the end of your response, answer <<<True>>> or <<<False>>>',
    )
    checker_prompt = ''.join(instruction_fragments) + f'\n({target_answer}, {extracted_text}), Your answer:'

    return GPT_response(
        'Your are a helpful checker for math expressions.',
        checker_prompt,
        model_name='gpt-4o',
        code_interpreter=False,
        user_prompt_list=[checker_prompt],
        response_total_list=[],
    )

def extract_equation_with_GPT4(response):
    """Ask the GPT helper to pull the final numerical answer out of ``response``.

    Args:
        response: Text produced by another LLM (or by executed code); it is
            appended verbatim to the extraction instructions.

    Returns:
        Whatever ``GPT_response`` returns for the ``gpt-4o`` model. The prompt
        asks for the answer wrapped as ``<<<...>>>``, or
        ``<<<No answer found>>>`` when no final answer is present.
    """
    instructions = ''.join([
        'Your task is to extract the final numerical answer of the given answer by another LLM:\n',
        'Here is the response, return your answer with the format <<<list>>>, like <<<43243.4>>>.\n',
        'If the input text does not have <<<>>> and is already the pure answer, add <<<>>> and return your answer.\n',
        'Note that if you find no final answer is answered, then directly answer <<<No answer found>>>.\n',
        'Input text: ',
    ])
    query = instructions + response

    return GPT_response(
        '',
        query,
        model_name='gpt-4o',
        code_interpreter=False,
        user_prompt_list=[query],
        response_total_list=[],
    )

def _run_first_code_block(save_code_dir):
    """Execute ``code_1_0.py`` from ``save_code_dir`` and return what it printed.

    The script is run in a fresh ``python3`` process (15 s timeout) up to three
    times, each time printing a different candidate variable name (``result``,
    ``Answer``, ``answer``). The stdout of the first attempt that prints
    anything is returned; if every attempt prints nothing, the empty stdout of
    the last attempt is returned. Exceptions from ``subprocess.run`` (for
    example a timeout) propagate to the caller.
    """
    completed = None
    for variable_name in ('result', 'Answer', 'answer'):
        completed = subprocess.run(
            ["python3", "-c", f"exec(open('{save_code_dir}/code_1_0.py').read()); print({variable_name})"],
            capture_output=True, text=True, timeout=15)
        if completed.stdout != '':
            break
    return completed.stdout


def _extract_final_answer(text, max_attempts=3):
    """Extract the final answer from ``text``, retrying while the GPT call returns None.

    ``extract_equation_with_GPT4`` is called up to ``max_attempts`` times; the
    last output (possibly still None) is passed to ``extract_and_check`` and
    the first element of its result is returned.
    """
    output = None
    attempts = 0
    while output is None and attempts < max_attempts:
        attempts += 1
        output = extract_equation_with_GPT4(text)
    extracted_text, _ = extract_and_check(output)
    return extracted_text


def post_process_gsm(dataset_input_dir, save_input_dir, model_name, AutoGen_prompt_system_message,
               AutoGen_prompt_concatenate, code_interpreter, Encourage_code_execution_interpreter, all_code_with_COT,
               all_code_without_COT, all_text, multi_turn_planning, multi_turn_planning_round_num, method_9_all_text_all_code_summarizer, method_10_LLM_estimates_scores_first):
    """Score previously saved model responses on the GSM-hard dataset.

    Every fifth question of ``<dataset_input_dir>/gsmhardv2.jsonl`` is
    evaluated. For each one the saved ``response_code_1.txt`` is read from the
    per-sample results directory, its code blocks are written out as
    ``code_1_<index>.py``, and the first block (if present) is executed. The
    final answer is then extracted twice via GPT -- once from the execution
    output and once from the raw response -- and each is compared with the
    question's ``target``. A sample counts as wrong only when both comparisons
    yield ``'False'``.

    Per-sample verdicts and extracted answers are written into the sample
    directory, the four counters are written into the results directory, and a
    summary entry is appended to ``acc_result_log_<model_name>.txt`` inside
    ``save_input_dir``.

    Args:
        dataset_input_dir: Directory containing ``gsmhardv2.jsonl``.
        save_input_dir: Root directory holding the per-configuration results.
        model_name: Model whose responses are being scored.
        AutoGen_prompt_system_message, AutoGen_prompt_concatenate,
        code_interpreter, Encourage_code_execution_interpreter,
        all_code_with_COT, all_code_without_COT, all_text,
        multi_turn_planning, multi_turn_planning_round_num:
            Run-configuration values. Inside this function they are only used
            to build the results directory name and the log/print output.
        method_9_all_text_all_code_summarizer: Select the "method 9" results
            directory instead of the flag-derived one.
        method_10_LLM_estimates_scores_first: Select the "method 10" results
            directory instead of the flag-derived one.

    Raises:
        ValueError: If both ``method_9_...`` and ``method_10_...`` are set.
    """
    print('\n' + '*'*30)
    print(f'Model_name: {model_name}, AutoGen_prompt_system_message: {AutoGen_prompt_system_message}, '
          f'AutoGen_prompt_concatenate: {AutoGen_prompt_concatenate}, code_interpreter: {code_interpreter}, Encourage_code_execution_interpreter: {Encourage_code_execution_interpreter}, '
          f'all_code_with_COT: {all_code_with_COT}, all_code_without_COT: {all_code_without_COT}, all_text: {all_text}, multi_turn_planning: {multi_turn_planning}, multi_turn_planning_round_num: {multi_turn_planning_round_num}\n')

    if method_9_all_text_all_code_summarizer and method_10_LLM_estimates_scores_first:
        raise ValueError(
            "method_9_all_text_all_code_summarizer and method_10_LLM_estimates_scores_first can't be both True")
    if method_9_all_text_all_code_summarizer:
        base_save_code_dir = save_input_dir + f'/result_gsm_{model_name}_method_9_all_text_all_code_summarizer'
    elif method_10_LLM_estimates_scores_first:
        base_save_code_dir = save_input_dir + f'/result_gsm_{model_name}_method_10_LLM_estimates_scores_first'
    else:
        base_save_code_dir = save_input_dir + f'/result_gsm_{model_name}_interpreter_{code_interpreter}_AutoGen_system_{AutoGen_prompt_system_message}_AutoGen_concatenate_{AutoGen_prompt_concatenate}_all_code_with_COT_{all_code_with_COT}_all_code_without_COT_{all_code_without_COT}_Encourage_code_execution_interpreter_{Encourage_code_execution_interpreter}_all_text_{all_text}_multi_turn_planning_{multi_turn_planning}_round_num_{multi_turn_planning_round_num}'

    total_sample_num = 0
    answer_with_code_wrong = 0
    answer_with_code_correct = 0
    answer_without_code_wrong = 0
    answer_without_code_correct = 0

    data_path = dataset_input_dir + '/gsmhardv2.jsonl'
    with open(data_path, 'r') as file:
        question_json_list = [json.loads(line) for line in file]

    print(f'len(question_json_list): {len(question_json_list)}')

    # Only every fifth question is evaluated; len() of the range gives the
    # exact number of samples even when the dataset size is not a multiple of 5.
    sample_indices = range(0, len(question_json_list), 5)

    for total_test_num, i in enumerate(sample_indices, start=1):
        print('-------###-------###-------###-------')
        print(f'Current num is: {total_test_num}, Total num is: {len(sample_indices)}\n')

        save_code_dir = os.path.join(base_save_code_dir, f"Sample_{i}/")
        os.makedirs(save_code_dir, exist_ok=True)

        target_answer = question_json_list[i]['target']
        with open(save_code_dir + "/response_code_1.txt", "r") as f:
            original_response = f.read()

        # Samples are bucketed by whether the raw response contains code.
        response_has_code = 'code_interpreter' in original_response or '```python' in original_response

        for index, code_string in enumerate(extract_code(original_response)):
            with open(save_code_dir + f"/code_1_{index}.py", "w") as f:
                f.write(code_string)

        # Prefer the output of the generated code; fall back to the raw
        # response when there is no code or executing it fails.
        response = original_response
        if os.path.exists(save_code_dir + "/code_1_0.py"):
            try:
                response = _run_first_code_block(save_code_dir)
            except Exception as e:
                # Best effort: a crashing or timed-out script must not abort
                # the whole evaluation, but the failure should be visible.
                print(f'Warning: could not execute generated code for Sample_{i}: {e!r}')

        extracted_text_1 = _extract_final_answer(response)
        extracted_text_2 = _extract_final_answer(original_response)

        True_false_result_1, _ = extract_and_check(is_equiv_func(target_answer, extracted_text_1))
        True_false_result_2, _ = extract_and_check(is_equiv_func(target_answer, extracted_text_2))

        print(f'True_false_result from response: {True_false_result_1}')
        print(f'True_false_result from original_response: {True_false_result_2}')
        print(f'target_answer: {target_answer}')
        print(f'extracted_text from response: {extracted_text_1}')
        print(f'extracted_text from original_response: {extracted_text_2}')

        per_sample_outputs = (
            ("True_false_result_1.txt", True_false_result_1),
            ("True_false_result_2.txt", True_false_result_2),
            ("extracted_answer_1.txt", extracted_text_1),
            ("extracted_answer_2.txt", extracted_text_2),
        )
        for file_name, content in per_sample_outputs:
            with open(save_code_dir + "/" + file_name, "w") as f:
                f.write(content)

        # Wrong only if BOTH the executed-code answer and the raw-response
        # answer were judged 'False'; anything else counts as correct.
        is_wrong = True_false_result_1 == 'False' and True_false_result_2 == 'False'
        if is_wrong:
            print('False')
            if response_has_code:
                answer_with_code_wrong += 1
            else:
                answer_without_code_wrong += 1
        else:
            print('True')
            if response_has_code:
                answer_with_code_correct += 1
            else:
                answer_without_code_correct += 1
        total_sample_num += 1

    print(f'\ntotal_sample_num: {total_sample_num}')
    print(f'answer_with_code_wrong: {answer_with_code_wrong}')
    print(f'answer_without_code_wrong: {answer_without_code_wrong}')
    print(f'answer_with_code_correct: {answer_with_code_correct}')
    print(f'answer_without_code_correct: {answer_without_code_correct}')

    # The directory normally exists after the loop; create it so an empty
    # dataset still produces the summary files instead of failing.
    os.makedirs(base_save_code_dir, exist_ok=True)
    summary_counters = (
        ("answer_with_code_wrong.txt", answer_with_code_wrong),
        ("answer_without_code_wrong.txt", answer_without_code_wrong),
        ("answer_with_code_correct.txt", answer_with_code_correct),
        ("answer_without_code_correct.txt", answer_without_code_correct),
    )
    for file_name, counter in summary_counters:
        with open(base_save_code_dir + "/" + file_name, "w") as f:
            f.write(str(counter))

    print(
        f'\nModel_name: {model_name}, AutoGen_prompt_system_message: {AutoGen_prompt_system_message}, AutoGen_prompt_concatenate: {AutoGen_prompt_concatenate}, code_interpreter: {code_interpreter}, Encourage_code_execution_interpreter: {Encourage_code_execution_interpreter}, all_code_with_COT: {all_code_with_COT}, all_code_without_COT: {all_code_without_COT}, all_text: {all_text}')
    print('*' * 30)

    run_info = f"{model_name}, round_num={multi_turn_planning_round_num}, " \
               f"{code_interpreter}, {AutoGen_prompt_system_message}, {AutoGen_prompt_concatenate}, " \
               f"{Encourage_code_execution_interpreter}, {all_code_with_COT}, {all_code_without_COT}, {all_text}, {multi_turn_planning}, multi_turn_planning_round_num: {multi_turn_planning_round_num}" \
               f"\nmethod_9_{method_9_all_text_all_code_summarizer}, method_10_{method_10_LLM_estimates_scores_first}\n"

    run_info_result = f'answer_with_code_wrong: {answer_with_code_wrong}, answer_without_code_wrong: {answer_without_code_wrong}, answer_with_code_correct: {answer_with_code_correct}, answer_without_code_correct: {answer_without_code_correct}\n'
    log_file_result = os.path.join(save_input_dir, f"acc_result_log_{model_name}.txt")
    # Append directly rather than through a logging helper that only exists
    # when this file runs as a script, so the function also works on import.
    with open(log_file_result, 'a') as f:
        f.write(run_info + run_info_result + "\n")

if __name__ == '__main__':
    # Model names: gpt-4o, gpt-4o-mini, gpt-3.5-turbo (OpenAI API);
    # gpt-4o, gpt-35-turbo (Azure API).
    # Candidate multi-turn round counts are [1, 2, 4, 8, 16]; only [1] is run.

    cli = ArgumentParser()
    cli.add_argument('-model_name', '--model_name', default='gpt-35-turbo-16k-0613')
    model_name = cli.parse_args().model_name

    dataset_input_dir = '../dataset_gather'
    save_input_dir = '../results_gather/gsm'
    log_file = os.path.join(save_input_dir, f"run_log_{model_name}.txt")

    def log_run_info(log_file, run_info):
        """Append ``run_info`` followed by a newline to ``log_file``."""
        with open(log_file, 'a') as f:
            f.write(run_info + "\n")

    # Every run starts from "all flags off, single turn"; each configuration
    # below lists only the values that differ from this baseline.
    baseline_flags = {
        'code_interpreter': False,
        'AutoGen_prompt_system_message': False,
        'AutoGen_prompt_concatenate': False,
        'Encourage_code_execution_interpreter': False,
        'all_code_with_COT': False,
        'all_code_without_COT': False,
        'all_text': False,
        'multi_turn_planning': False,
        'multi_turn_planning_round_num': 1,
        'method_9_all_text_all_code_summarizer': False,
        'method_10_LLM_estimates_scores_first': False,
    }

    def evaluate_and_log(**overrides):
        """Score one configuration, then record its completion in the run log."""
        settings = {**baseline_flags, **overrides}
        post_process_gsm(dataset_input_dir, save_input_dir, model_name, **settings)

        completed_entry = (
            "Completed run: {model_name}, round_num={multi_turn_planning_round_num}, "
            "{code_interpreter}, {AutoGen_prompt_system_message}, {AutoGen_prompt_concatenate}, "
            "{Encourage_code_execution_interpreter}, {all_code_with_COT}, {all_code_without_COT}, {all_text}, {multi_turn_planning}, multi_turn_planning_round_num: {multi_turn_planning_round_num}"
            "\nmethod_9_{method_9_all_text_all_code_summarizer}, method_10_{method_10_LLM_estimates_scores_first}\n"
        ).format(model_name=model_name, **settings)
        log_run_info(log_file, completed_entry)

    # Multi-turn planning configurations without the code interpreter.
    multi_turn_variants = [
        {},
        {'AutoGen_prompt_system_message': True},
        {'AutoGen_prompt_concatenate': True},
        {'all_text': True},
        {'all_code_with_COT': True},
        {'all_code_without_COT': True},
        {'method_9_all_text_all_code_summarizer': True},
        {'method_10_LLM_estimates_scores_first': True},
    ]
    for round_num in [1]:
        for variant in multi_turn_variants:
            evaluate_and_log(multi_turn_planning=True,
                             multi_turn_planning_round_num=round_num,
                             **variant)

    # Single-turn code-interpreter configurations.
    interpreter_variants = [
        {'code_interpreter': True},
        {'code_interpreter': True, 'Encourage_code_execution_interpreter': True},
    ]
    for variant in interpreter_variants:
        evaluate_and_log(**variant)