import os
from generation_models import message_construct_func, GPT_response, extract_and_check, extract_code
import re
import pandas as pd
import subprocess
from typing import List, Tuple, Dict
import time
import numpy as np
import ast
#from run_BoxNet1 import score_in_training_set
import json
import copy
from argparse import ArgumentParser

def _cell_key(coords):
    """Return the playground key for an ``(x, y)`` pair, e.g. ``(0.5, 1.5)`` -> ``'0.5_1.5'``."""
    return str(coords[0]) + '_' + str(coords[1])


def action_from_response(pg_dict_input, original_response_dict_list):
    """Apply a planned sequence of agent actions to a copy of the playground.

    Args:
        pg_dict_input: Mapping from cell keys such as ``'0.5_1.5'`` to the list
            of item names in that cell. It is deep-copied and never mutated.
        original_response_dict_list: List of steps. Each step is a dict mapping
            an agent label containing two coordinates (e.g. ``"Agent[0.5, 0.5]"``)
            to an action string ``"move(<item>, <location>)"``, where
            ``<location>`` is either ``square[x, y]`` or an item name such as
            ``target_blue``.

    Returns:
        The playground state after applying actions in order. Execution stops
        at the first invalid action and the state reached so far is returned.
        Two kinds of action are valid:

        * moving an item held in the agent's cell to an existing cell exactly
          one unit away along a single axis;
        * removing a ``box_<x>`` together with the matching ``target_<x>`` when
          both are in the agent's cell.
    """
    pg_dict_original = copy.deepcopy(pg_dict_input)
    for original_response_dict in original_response_dict_list:
        # Parse the whole step first; values that do not look like
        # "move(item, location)" are silently skipped.
        transformed_dict = {}
        for key, value in original_response_dict.items():
            coordinates = tuple(map(float, re.findall(r"\d+\.?\d*", key)))

            match = re.match(r"move\((.*?),\s(.*?)\)", value)
            if match:
                item, location = match.groups()
                if "square" in location:
                    location = tuple(map(float, re.findall(r"\d+\.?\d*", location)))
                transformed_dict[coordinates] = [item, location]

        for source, (item, location) in transformed_dict.items():
            try:
                source_items = pg_dict_original[_cell_key(source)]
                if item not in source_items:
                    return pg_dict_original

                if isinstance(location, tuple):
                    dx = abs(source[0] - location[0])
                    dy = abs(source[1] - location[1])
                    if not ((dx == 0 and dy == 1) or (dx == 1 and dy == 0)):
                        return pg_dict_original
                    # Resolve the destination BEFORE mutating anything: a move
                    # to a square outside the playground must not make the
                    # item disappear from its source cell.
                    destination_items = pg_dict_original[_cell_key(location)]
                    source_items.remove(item)
                    destination_items.append(item)
                elif (isinstance(location, str)
                      and location in source_items
                      and item[:4] == 'box_'
                      and location[:7] == 'target_'
                      and item[4:] == location[7:]):
                    # Box delivered to its matching target: both leave the cell.
                    source_items.remove(item)
                    source_items.remove(location)
                else:
                    return pg_dict_original
            except (KeyError, IndexError, TypeError, ValueError, AttributeError):
                # Unknown cell, malformed coordinates or unexpected cell
                # contents: treat as an invalid action and stop here.
                return pg_dict_original

    return pg_dict_original

def score_in_training_set(pg_dict, response):
    """Validate a plan, execute it and report whether the playground was cleared.

    Args:
        pg_dict: Initial playground state (cell key -> list of item names).
        response: JSON text expected to decode to a list of dicts mapping agent
            labels to ``"move(...)"`` action strings.

    Returns:
        A ``(state, status)`` tuple. When ``response`` cannot be decoded or
        does not have the expected list-of-dicts-of-strings shape, ``state`` is
        the untouched ``pg_dict`` and ``status`` is
        ``'response in the wrong format'``. Otherwise ``state`` is the
        playground returned by ``action_from_response`` and ``status`` is
        ``'success'`` when every cell of that state is empty, else
        ``'failure after full execution'``.
    """
    try:
        original_response_dict_list = json.loads(response)
        # Dry-run the same parsing the executor performs so that a malformed
        # structure is reported as a format error instead of raising later.
        for original_response_dict in original_response_dict_list:
            for key, value in original_response_dict.items():
                tuple(map(float, re.findall(r"\d+\.?\d*", key)))
                re.match(r"move\((.*?),\s(.*?)\)", value)
    except (ValueError, TypeError, AttributeError):
        # ValueError covers json.JSONDecodeError; TypeError/AttributeError
        # cover non-iterable, non-dict or non-string payloads.
        print('\nResponse in the wrong format!\n')
        return pg_dict, 'response in the wrong format'

    pg_dict_returned = action_from_response(pg_dict, original_response_dict_list)

    # Success must be judged on the state AFTER execution; the input dict is
    # never mutated by action_from_response.
    remaining_items = sum(len(items) for items in pg_dict_returned.values())
    if remaining_items == 0:
        success_failure = 'success'
    else:
        success_failure = 'failure after full execution'

    return pg_dict_returned, success_failure

def extract_equation_with_GPT4(response):
    """Ask ``gpt-4o`` (via ``GPT_response``) to pull the final ``<<<...>>>`` answer out of ``response``.

    The extraction instructions are prepended to ``response`` and the combined
    text is sent as a single user prompt; whatever ``GPT_response`` returns is
    passed back unchanged.
    """
    instructions = (
        'Your task is to extract the final answer from the given answer by another LLM:\n'
        'Note that the equation should be in the form like <<<answer>>>, <<<[{"Agent[0.5, 0.5]":"move(box_blue, square[0.5, 1.5])", "Agent[1.5, 0.5]":"move...}, {"Agent[0.5, 1.5]":"move(box_blue, target_blue)}, {...}...]>>>, \n'
        'Here is the reponse, return your answer with the format <<<equation>>>, like <<<[{"Agent[0.5, 0.5]":"move(box_blue, square[0.5, 1.5])", "Agent[1.5, 0.5]":"move...}, {"Agent[0.5, 1.5]":"move(box_blue, target_blue)}, {...}...]>>>. '
        'Input text: '
    )
    full_prompt = instructions + response

    return GPT_response(
        '',
        full_prompt,
        model_name='gpt-4o',
        code_interpreter=False,
        user_prompt_list=[full_prompt],
        response_total_list=[],
    )

#post_process_number_multiply
def post_process_boxnet1(dataset_input_dir, save_input_dir, model_name, AutoGen_prompt_system_message,
               AutoGen_prompt_concatenate, code_interpreter, Encourage_code_execution_interpreter, all_code_with_COT,
               all_code_without_COT, all_text, multi_turn_planning, multi_turn_planning_round_num, method_9_all_text_all_code_summarizer, method_10_LLM_estimates_scores_first):
    """Score the saved BoxNet1 responses of one method configuration.

    For every sample (grid sizes ``(1, 2)``, ``(2, 2)``, ``(2, 4)``; ten
    iterations each) the function:

    1. loads the initial playground from
       ``{dataset_input_dir}/env_pg_state_{rows}_{cols}/pg_state{i}/pg_state{i}.json``;
    2. reads the saved ``response_code_1.txt`` from the sample's result folder;
    3. writes every extracted code block to ``code_1_{index}.py`` and, when a
       first block exists, executes it in a ``python3`` subprocess and uses its
       stdout as the response;
    4. extracts the final answer (falling back to ``extract_equation_with_GPT4``
       when ``extract_and_check`` yields an empty string) and scores it with
       ``score_in_training_set``;
    5. writes ``Lifted_ratio_1.txt``, ``response_answer.txt`` and
       ``success_failure.txt`` next to the response.

    Aggregate counters and lifted ratios are then written to the
    configuration's result directory and a summary is appended to
    ``{save_input_dir}/acc_result_log_{model_name}.txt``.

    Args:
        dataset_input_dir: Root directory of the BoxNet1 playground states.
        save_input_dir: Root directory holding the per-configuration results.
        model_name: Model identifier; part of the result directory and log names.
        AutoGen_prompt_system_message, AutoGen_prompt_concatenate,
        code_interpreter, Encourage_code_execution_interpreter,
        all_code_with_COT, all_code_without_COT, all_text: Method flags. Here
            they only select the result directory name and the method label
            written to the log.
        multi_turn_planning, multi_turn_planning_round_num: Used in the result
            directory name and in console output.
        method_9_all_text_all_code_summarizer,
        method_10_LLM_estimates_scores_first: Select the dedicated method 9 /
            method 10 result directories; mutually exclusive.

    Raises:
        ValueError: If both method 9 and method 10 flags are ``True``.
    """
    print('\n' + '*'*30)
    print(f'Model_name: {model_name}, AutoGen_prompt_system_message: {AutoGen_prompt_system_message}, '
          f'AutoGen_prompt_concatenate: {AutoGen_prompt_concatenate}, code_interpreter: {code_interpreter}, Encourage_code_execution_interpreter: {Encourage_code_execution_interpreter}, '
          f'all_code_with_COT: {all_code_with_COT}, all_code_without_COT: {all_code_without_COT}, all_text: {all_text}, multi_turn_planning: {multi_turn_planning}, multi_turn_planning_round_num: {multi_turn_planning_round_num}\n')

    # Configuration banner printed once per sample and once at the end.
    config_summary = f'\nModel_name: {model_name}, AutoGen_prompt_system_message: {AutoGen_prompt_system_message}, AutoGen_prompt_concatenate: {AutoGen_prompt_concatenate}, code_interpreter: {code_interpreter}, Encourage_code_execution_interpreter: {Encourage_code_execution_interpreter}, all_code_with_COT: {all_code_with_COT}, all_code_without_COT: {all_code_without_COT}, all_text: {all_text}'

    total_sample_num = 0
    answer_with_code_wrong = 0
    answer_with_code_correct = 0
    answer_without_code_wrong = 0
    answer_without_code_correct = 0

    lifted_ratio_with_code = 0
    lifted_ratio_without_code = 0

    if method_9_all_text_all_code_summarizer == False and method_10_LLM_estimates_scores_first == False:
        base_save_code_dir = save_input_dir + f'/result_boxnet1_{model_name}_interpreter_{code_interpreter}_AutoGen_system_{AutoGen_prompt_system_message}_AutoGen_concatenate_{AutoGen_prompt_concatenate}_all_code_with_COT_{all_code_with_COT}_all_code_without_COT_{all_code_without_COT}_Encourage_code_execution_interpreter_{Encourage_code_execution_interpreter}_all_text_{all_text}_multi_turn_planning_{multi_turn_planning}_round_num_{multi_turn_planning_round_num}'
    elif method_9_all_text_all_code_summarizer == True and method_10_LLM_estimates_scores_first == False:
        base_save_code_dir = save_input_dir + f'/result_boxnet1_{model_name}_method_9_all_text_all_code_summarizer'
    elif method_9_all_text_all_code_summarizer == False and method_10_LLM_estimates_scores_first == True:
        base_save_code_dir = save_input_dir + f'/result_boxnet1_{model_name}_method_10_LLM_estimates_scores_first'
    else:
        raise ValueError(
            "method_9_all_text_all_code_summarizer and method_10_LLM_estimates_scores_first can't be both True")

    lifted_ratio_list = []
    for pg_row_num, pg_column_num in [(1, 2), (2, 2), (2, 4)]:
        for iteration_num in range(10):
            print('-------###-------###-------###-------')
            print(f'Row num is: {pg_row_num}, Column num is: {pg_column_num}, Iteration num is: {iteration_num}\n\n')

            total_sample_num += 1
            save_code_dir = os.path.join(base_save_code_dir, f"{pg_row_num}_{pg_column_num}_{iteration_num}")

            with open(
                    dataset_input_dir + f'/env_pg_state_{pg_row_num}_{pg_column_num}/pg_state{iteration_num}/pg_state{iteration_num}.json',
                    'r') as file:
                pg_dict = json.load(file)

            pg_dict_initial = copy.deepcopy(pg_dict)

            with open(save_code_dir + "/response_code_1.txt", "r") as f:
                response = f.read()
            print(config_summary)

            # True when the raw response shows signs of code usage; decides
            # which "with code" / "without code" counters this sample feeds.
            code_present = 'code_interpreter' in response or '```python' in response

            code_block_list = extract_code(response)
            for index, code_string in enumerate(code_block_list):
                with open(save_code_dir + f"/code_1_{index}.py", "w") as f:
                    f.write(code_string)

            # Run the first generated code block, if any, and use its stdout as
            # the response. The snippet also prints a variable named `result`.
            # NOTE(review): this executes model-generated code without
            # sandboxing, and the path is interpolated into the command string,
            # so a quote character in the path would break it.
            if os.path.exists(save_code_dir + "/code_1_0.py"):
                try:
                    result = subprocess.run(
                        ["python3", "-c", f"exec(open('{save_code_dir}/code_1_0.py').read()); print(result)"],
                        capture_output=True,
                        text=True,
                        timeout=15
                    )
                    response = result.stdout
                except Exception as e:
                    # Best effort: on timeout or launch failure keep the
                    # original text response, but do not hide the reason.
                    print(f'Running {save_code_dir}/code_1_0.py failed: {e!r}')

            response_answer, _ = extract_and_check(response)
            print(f"Response: {response_answer}")

            if response_answer == '':
                response_answer = extract_equation_with_GPT4(response)

            remaining_box_dict, success_failure = score_in_training_set(pg_dict, response_answer)

            boxes_all_list = [item for items in pg_dict_initial.values() for item in items if item.startswith('box')]
            boxes_remaining_list = [item for items in remaining_box_dict.values() for item in items if
                                    item.startswith('box')]
            lifted_ratio = (len(boxes_all_list) - len(boxes_remaining_list)) / len(boxes_all_list)
            lifted_ratio_list.append(lifted_ratio)

            print(f"Response: {response}")
            print(f"Response_answer: {response_answer}")
            print(f"Response is valid: {success_failure}")
            print(f'Initial boxes: {boxes_all_list}')
            print(f"Remaining boxes: {boxes_remaining_list}")
            print(f"Lifted ratio: {lifted_ratio}")

            with open(save_code_dir + "/Lifted_ratio_1.txt", "w") as f:
                f.write(str(lifted_ratio))

            with open(save_code_dir + "/response_answer.txt", "w") as f:
                f.write(response_answer)

            with open(save_code_dir + "/success_failure.txt", "w") as f:
                f.write(success_failure)

            if success_failure == 'success':
                print('True')
                if code_present:
                    answer_with_code_correct += 1
                else:
                    answer_without_code_correct += 1
            else:
                if code_present:
                    answer_with_code_wrong += 1
                else:
                    answer_without_code_wrong += 1
            if code_present:
                lifted_ratio_with_code += lifted_ratio
            else:
                lifted_ratio_without_code += lifted_ratio

    total_lifted_ratio = np.mean(lifted_ratio_list)

    print(f'\ntotal_sample_num: {total_sample_num}')
    print(f'answer_with_code_wrong: {answer_with_code_wrong}')
    print(f'answer_without_code_wrong: {answer_without_code_wrong}')
    print(f'answer_with_code_correct: {answer_with_code_correct}')
    print(f'answer_without_code_correct: {answer_without_code_correct}')
    print(f'Total lifted ratio: {total_lifted_ratio}')

    with open(base_save_code_dir + "/answer_with_code_wrong.txt", "w") as f:
        f.write(str(answer_with_code_wrong))
    with open(base_save_code_dir + "/answer_without_code_wrong.txt", "w") as f:
        f.write(str(answer_without_code_wrong))
    with open(base_save_code_dir + "/answer_with_code_correct.txt", "w") as f:
        f.write(str(answer_with_code_correct))
    with open(base_save_code_dir + "/answer_without_code_correct.txt", "w") as f:
        f.write(str(answer_without_code_correct))
    with open(base_save_code_dir + "/total_lifted_ratio.txt", "w") as f:
        f.write(str(total_lifted_ratio))
    with open(base_save_code_dir + "/lifted_ratio_list.txt", "w") as f:
        f.write(str(lifted_ratio_list))

    # Per-group averages are only defined when the group has samples.
    with_code_total = answer_with_code_wrong + answer_with_code_correct
    without_code_total = answer_without_code_wrong + answer_without_code_correct

    if with_code_total > 0:
        with_code_lifted_ratio = lifted_ratio_with_code / with_code_total
        print(f'with_code_lifted_ratio: {with_code_lifted_ratio}')
        with open(base_save_code_dir + "/with_code_lifted_ratio.txt", "w") as f:
            f.write(str(with_code_lifted_ratio))
    if without_code_total > 0:
        without_code_lifted_ratio = lifted_ratio_without_code / without_code_total
        print(f'without_code_lifted_ratio: {without_code_lifted_ratio}')
        with open(base_save_code_dir + "/without_code_lifted_ratio.txt", "w") as f:
            f.write(str(without_code_lifted_ratio))

    print(config_summary)
    print('*' * 30)

    if method_9_all_text_all_code_summarizer:
        run_info = f"{model_name}, method_9, all_text_all_code_summarizer\n"
    elif method_10_LLM_estimates_scores_first:
        run_info = f"{model_name}, method_10, LLM_estimates_scores_first\n"
    elif all_text:
        run_info = f"{model_name}, method_2, all_text\n"
    elif all_code_without_COT:
        run_info = f"{model_name}, method_3, all_code_without_COT\n"
    elif all_code_with_COT:
        run_info = f"{model_name}, method_4, all_code_with_COT\n"
    elif AutoGen_prompt_concatenate:
        run_info = f"{model_name}, method_5, AutoGen_prompt_concatenate\n"
    elif AutoGen_prompt_system_message:
        run_info = f"{model_name}, method_6, AutoGen_prompt_system_message\n"
    elif code_interpreter == True and Encourage_code_execution_interpreter == False:
        run_info = f"{model_name}, method_7, code_interpreter\n"
    elif code_interpreter == True and Encourage_code_execution_interpreter == True:
        run_info = f"{model_name}, method_8, code_interpreter + Encourage_code\n"
    else:
        run_info = f"{model_name}, method_1, only question\n"

    run_info += 'answer_with_code_correct, answer_with_code_wrong, answer_without_code_correct, answer_without_code_wrong\n'
    run_info_result = f'[{answer_with_code_correct}, {answer_with_code_wrong}, {answer_without_code_correct}, {answer_without_code_wrong}]\n'
    run_info_result += f'total_lifted_ratio: {total_lifted_ratio}'

    if with_code_total > 0:
        run_info_result += f', with_code_lifted_ratio: {with_code_lifted_ratio}'
    if without_code_total > 0:
        run_info_result += f', without_code_lifted_ratio: {without_code_lifted_ratio}\n'
    log_file_result = os.path.join(save_input_dir, f"acc_result_log_{model_name}.txt")
    # Appended inline so this function has no dependency on script-level
    # helpers and can be called after a plain import of the module.
    with open(log_file_result, 'a') as f:
        f.write(run_info + run_info_result + "\n")

def log_run_info(log_file, run_info):
    """Append ``run_info`` followed by a newline to ``log_file``.

    Defined at module level so that it is available both when the file is run
    as a script and when it is imported.
    """
    with open(log_file, 'a') as f:
        f.write(run_info + "\n")


if __name__ == '__main__':
    # gpt-4o, gpt-4o-mini, gpt-3.5-turbo for OpenAi API
    # gpt-4o, gpt-35-turbo for Azure API
    # multi_turn_planning_round_num = [1, 2, 4, 8, 16]

    parser = ArgumentParser()
    parser.add_argument('-model_name', '--model_name', default='gpt-35-turbo-16k-0613')
    args = parser.parse_args()
    model_name = args.model_name

    dataset_input_dir = '../dataset_gather/BoxNet1_dataset'
    save_input_dir = '../results_gather/BoxNet1'

    # One row per configuration to post-process. Column order:
    # (code_interpreter, AutoGen_prompt_system_message, AutoGen_prompt_concatenate,
    #  Encourage_code_execution_interpreter, all_code_with_COT, all_code_without_COT,
    #  all_text, multi_turn_planning, method_9_all_text_all_code_summarizer,
    #  method_10_LLM_estimates_scores_first)
    run_configs = [
        (False, False, False, False, False, False, False, True, False, False),  # logged as method_1
        (False, True, False, False, False, False, False, True, False, False),   # logged as method_6
        (False, False, True, False, False, False, False, True, False, False),   # logged as method_5
        (False, False, False, False, False, False, True, True, False, False),   # logged as method_2
        (False, False, False, False, True, False, False, True, False, False),   # logged as method_4
        (False, False, False, False, False, True, False, True, False, False),   # logged as method_3
        (False, False, False, False, False, False, False, True, True, False),   # logged as method_9
        (False, False, False, False, False, False, False, True, False, True),   # logged as method_10
    ]

    # for multi_turn_planning_round_num in [1, 2, 4, 8, 16]:
    for multi_turn_planning_round_num in [1]:
        log_file = os.path.join(save_input_dir, f"run_log_{model_name}.txt")

        for (code_interpreter, AutoGen_prompt_system_message, AutoGen_prompt_concatenate,
             Encourage_code_execution_interpreter, all_code_with_COT, all_code_without_COT,
             all_text, multi_turn_planning, method_9_all_text_all_code_summarizer,
             method_10_LLM_estimates_scores_first) in run_configs:
            post_process_boxnet1(dataset_input_dir, save_input_dir, model_name, AutoGen_prompt_system_message,
                             AutoGen_prompt_concatenate, code_interpreter, Encourage_code_execution_interpreter,
                             all_code_with_COT,
                             all_code_without_COT, all_text, multi_turn_planning, multi_turn_planning_round_num,
                             method_9_all_text_all_code_summarizer, method_10_LLM_estimates_scores_first)

            # Log the completed run
            run_info = f"Completed run: {model_name}, round_num={multi_turn_planning_round_num}, " \
                       f"{code_interpreter}, {AutoGen_prompt_system_message}, {AutoGen_prompt_concatenate}, " \
                       f"{Encourage_code_execution_interpreter}, {all_code_with_COT}, {all_code_without_COT}, {all_text}, {multi_turn_planning}, multi_turn_planning_round_num: {multi_turn_planning_round_num}" \
                       f"\nmethod_9_{method_9_all_text_all_code_summarizer}, method_10_{method_10_LLM_estimates_scores_first}\n"
            log_run_info(log_file, run_info)

# Disabled alternative run configuration, kept as an inert string literal that
# is never executed: it post-processes the two code_interpreter=True variants
# (without and with Encourage_code_execution_interpreter) for a single round.
'''
    #for model_name in ['gpt-4o', 'gpt-4o-mini', 'gpt-3.5-turbo', 'gpt-35-turbo-16k-0613']:
    #for model_name in ['gpt-4o']:
    log_file = os.path.join(save_input_dir, f"run_log_{model_name}.txt")

    for dataset_input_dir, save_input_dir, model_name, code_interpreter, AutoGen_prompt_system_message, AutoGen_prompt_concatenate, \
        Encourage_code_execution_interpreter, all_code_with_COT, all_code_without_COT, all_text, multi_turn_planning, multi_turn_planning_round_num\
        ,method_9_all_text_all_code_summarizer, method_10_LLM_estimates_scores_first in [
        (dataset_input_dir, save_input_dir, model_name, True, False, False, False, False, False, False, False, 1, False, False),
        (dataset_input_dir, save_input_dir, model_name, True, False, False, True, False, False, False, False, 1, False, False)
    ]:
        post_process_boxnet1(dataset_input_dir, save_input_dir, model_name, AutoGen_prompt_system_message,
                      AutoGen_prompt_concatenate, code_interpreter, Encourage_code_execution_interpreter,
                      all_code_with_COT,
                      all_code_without_COT, all_text, multi_turn_planning, multi_turn_planning_round_num, method_9_all_text_all_code_summarizer, method_10_LLM_estimates_scores_first)

        # Log the completed run
        run_info = f"Completed run: {model_name}, round_num={multi_turn_planning_round_num}, " \
                   f"{code_interpreter}, {AutoGen_prompt_system_message}, {AutoGen_prompt_concatenate}, " \
                   f"{Encourage_code_execution_interpreter}, {all_code_with_COT}, {all_code_without_COT}, {all_text}, {multi_turn_planning}, multi_turn_planning_round_num: {multi_turn_planning_round_num}" \
                   f"\nmethod_9_{method_9_all_text_all_code_summarizer}, method_10_{method_10_LLM_estimates_scores_first}\n"
        log_run_info(log_file, run_info)
'''