import os
from generation_models import message_construct_func, GPT_response, extract_and_check, extract_code
import re
import pandas as pd
import subprocess
import time
from post_process_path_plan_continuous_func import check_trajectory
import string
import json
from argparse import ArgumentParser

def evaluate_response(word, target_letter, llm_response):
    """
    Evaluate the LLM's response for correctness.

    The response must contain a fragment of the form
    ``Count: <int>, Positions: [<comma separated ints>]``. An empty position
    list (``Positions: []``) is accepted so that a count of zero can be
    expressed. Positions are 1-based and compared as sets, i.e. order and
    duplicates in the reported positions are ignored.

    :param word: The test word
    :param target_letter: The letter that was counted
    :param llm_response: The response from the LLM
    :return: A tuple (is_correct, explanation)
    """
    # A missing/non-text response cannot contain the expected pattern.
    if not isinstance(llm_response, str):
        return False, "Response format is incorrect"

    # Extract count and positions from LLM response; `*` (not `+`) so that an
    # empty list "[]" still matches.
    match = re.search(r"Count: (\d+), Positions: \[([\d, ]*)\]", llm_response)
    if not match:
        return False, "Response format is incorrect"

    llm_count = int(match.group(1))
    # Pull out the digit runs directly instead of splitting on commas, so a
    # trailing comma or stray blank token does not raise ValueError.
    llm_positions = [int(pos) for pos in re.findall(r"\d+", match.group(2))]

    # Calculate correct count and positions
    correct_count = word.count(target_letter)
    correct_positions = [i + 1 for i, letter in enumerate(word) if letter == target_letter]

    if llm_count != correct_count:
        return False, f"Incorrect count. Expected {correct_count}, got {llm_count}"

    if set(llm_positions) != set(correct_positions):
        return False, f"Incorrect positions. Expected {correct_positions}, got {llm_positions}"

    return True, "Correct response"

def extract_equation_with_GPT4(response):
    """
    Ask the 'gpt-4o' model (via GPT_response) to pull the final answer out of
    another LLM's response text.

    The instruction prompt asks for the answer wrapped as
    ``<<<Count: N, Positions: [...]>>>`` and for
    ``<<<Count: 0, Positions: []>>>`` when no final answer is present.

    :param response: Raw text to extract the final answer from
    :return: Whatever GPT_response returns for the extraction query
    """
    instruction = (
        'Your task is to extract the final answer from the given answer by another LLM:\n'
        'Note that the final answer should follow strictly the format like Count: 5, Positions: [2, 4, 13, 17, 22], Count: 1, Positions: [5], '
        'Count: 4, Positions: [3, 11, 18, 24] \n'
        'Here is the response, return your answer with the format <<<final answer>>>, like <<<Count: 4, Positions: [3, 11, 18, 24]>>>.\n'
        'If the input text does not have <<<>>> and is already the pure answer, add <<<>>> and return your answer.\n'
        'Note that if you find no final answer are answered, then directly answer <<<Count: 0, Positions: []>>>.\n'
        'Input text: '
    )
    # The same text is sent both as the question and as the only user turn.
    full_query = instruction + response

    return GPT_response(
        '',
        full_query,
        model_name='gpt-4o',
        code_interpreter=False,
        user_prompt_list=[full_query],
        response_total_list=[],
    )

def read_words_from_file(filename):
    """
    Load the JSON content stored in a file.

    :param filename: Name of the JSON file to read from
    :return: The decoded JSON value (the test word data saved in that file)
    """
    with open(filename, 'r') as handle:
        raw_text = handle.read()
    return json.loads(raw_text)

def _run_generated_code(script_path, timeout=15):
    """Execute a saved code block in a subprocess and return its captured stdout.

    The script is exec'd and then one of several candidate result variables is
    printed; the next candidate is only tried while stdout is still empty.
    """
    # NOTE(security): this executes LLM-generated code on the local machine.
    completed = None
    for variable_name in ("result", "Waypoints", "waypoints", "trajectory"):
        completed = subprocess.run(
            # repr() keeps the embedded path a valid string literal even if it
            # contains quotes or backslashes.
            ["python3", "-c", f"exec(open({script_path!r}).read()); print({variable_name})"],
            capture_output=True, text=True, timeout=timeout)
        if completed.stdout != '':
            break
    return completed.stdout


def _extract_answer_with_retries(text, max_attempts=3):
    """Query the extraction model until it returns a non-None answer (at most max_attempts times)."""
    extracted = None
    attempts = 0
    while extracted is None and attempts < max_attempts:
        attempts += 1
        extracted = extract_equation_with_GPT4(text)
    return extracted


def post_process_letters(dataset_input_dir, save_input_dir, model_name, AutoGen_prompt_system_message,
               AutoGen_prompt_concatenate, code_interpreter, Encourage_code_execution_interpreter, all_code_with_COT,
               all_code_without_COT, all_text, multi_turn_planning, multi_turn_planning_round_num, method_9_all_text_all_code_summarizer, method_10_LLM_estimates_scores_first):
    """
    Score previously saved LLM responses for the letter-counting task.

    For every (length range, letter, letter frequency, index) sample this
    reads ``test_words.json`` from the dataset directory and
    ``response_code_1.txt`` from the result directory, saves any code blocks
    found in the response, runs the first one, extracts a final answer both
    from the (possibly code-produced) output and from the original response,
    and evaluates both with ``evaluate_response``. A sample counts as correct
    when either extracted answer is correct. Per-sample answers/feedback and
    the four aggregate counters are written to files, and a summary line is
    appended to ``acc_result_log_<model_name>.txt`` in ``save_input_dir``.

    :param dataset_input_dir: Directory prefix of the Letters datasets (used by plain string concatenation)
    :param save_input_dir: Directory that contains the per-method result folders
    :param model_name: Model name used in folder and log file names
    :param AutoGen_prompt_system_message: Flag encoded in the result folder name
    :param AutoGen_prompt_concatenate: Flag encoded in the result folder name
    :param code_interpreter: Flag encoded in the result folder name
    :param Encourage_code_execution_interpreter: Flag encoded in the result folder name
    :param all_code_with_COT: Flag encoded in the result folder name
    :param all_code_without_COT: Flag encoded in the result folder name
    :param all_text: Flag encoded in the result folder name
    :param multi_turn_planning: Flag encoded in the result folder name
    :param multi_turn_planning_round_num: Round number encoded in the result folder name
    :param method_9_all_text_all_code_summarizer: Select the method-9 result folder
    :param method_10_LLM_estimates_scores_first: Select the method-10 result folder
    :raises ValueError: If both method_9 and method_10 flags are set
    """
    print('\n' + '*'*30)
    print(f'Model_name: {model_name}, AutoGen_prompt_system_message: {AutoGen_prompt_system_message}, '
          f'AutoGen_prompt_concatenate: {AutoGen_prompt_concatenate}, code_interpreter: {code_interpreter}, Encourage_code_execution_interpreter: {Encourage_code_execution_interpreter}, '
          f'all_code_with_COT: {all_code_with_COT}, all_code_without_COT: {all_code_without_COT}, all_text: {all_text}, multi_turn_planning: {multi_turn_planning}, multi_turn_planning_round_num: {multi_turn_planning_round_num}\n')

    # Pick the result folder that matches the evaluated method.
    if method_9_all_text_all_code_summarizer and method_10_LLM_estimates_scores_first:
        raise ValueError(
            "method_9_all_text_all_code_summarizer and method_10_LLM_estimates_scores_first can't be both True")
    if method_9_all_text_all_code_summarizer:
        base_save_code_dir = save_input_dir + f'/result_letters_{model_name}_method_9_all_text_all_code_summarizer'
    elif method_10_LLM_estimates_scores_first:
        base_save_code_dir = save_input_dir + f'/result_letters_{model_name}_method_10_LLM_estimates_scores_first'
    else:
        base_save_code_dir = save_input_dir + f'/result_letters_{model_name}_interpreter_{code_interpreter}_AutoGen_system_{AutoGen_prompt_system_message}_AutoGen_concatenate_{AutoGen_prompt_concatenate}_all_code_with_COT_{all_code_with_COT}_all_code_without_COT_{all_code_without_COT}_Encourage_code_execution_interpreter_{Encourage_code_execution_interpreter}_all_text_{all_text}_multi_turn_planning_{multi_turn_planning}_round_num_{multi_turn_planning_round_num}'

    total_sample_num = 0
    answer_with_code_wrong = 0
    answer_with_code_correct = 0
    answer_without_code_wrong = 0
    answer_without_code_correct = 0

    for min_length, max_length in [(10, 15), (15, 20), (20, 25)]:
        base_dir = dataset_input_dir + f'Letters_dataset_min_length_{min_length}_max_length_{max_length}/'
        length_save_dir = os.path.join(base_save_code_dir, f"min_length_{min_length}_max_length_{max_length}/")

        for letter in string.ascii_lowercase:
            for letter_freq in range(1, 6):
                for index in range(1):
                    total_sample_num += 1

                    sample_name = f"{letter}_{letter_freq}_{index}/"
                    save_code_dir = os.path.join(length_save_dir, sample_name)
                    word = read_words_from_file(base_dir + sample_name + 'test_words.json')

                    with open(os.path.join(save_code_dir, "response_code_1.txt"), "r") as f:
                        response = f.read()
                    original_response = response
                    # Whether the saved response involved code at all.
                    has_code = 'code_interpreter' in response or '```python' in response
                    print('-------###-------###-------###-------')
                    print(
                        f"\nMin_length: {min_length}, Max_length: {max_length}, Letter: {letter}, Letter_freq: {letter_freq}, Test word: {word}")

                    # Separate loop variable so the sample `index` is not clobbered.
                    for block_index, code_string in enumerate(extract_code(response)):
                        with open(os.path.join(save_code_dir, f"code_1_{block_index}.py"), "w") as f:
                            f.write(code_string)

                    # Test the generated code: its stdout replaces the response
                    # used for the first extraction.
                    first_code_path = os.path.join(save_code_dir, "code_1_0.py")
                    if os.path.exists(first_code_path):
                        try:
                            response = _run_generated_code(first_code_path)
                        except Exception as exc:
                            # Best effort: keep the raw response when the code
                            # cannot be run (timeout, missing interpreter, ...).
                            print(f'Code execution failed, using the raw response instead: {exc}')

                    output_1 = _extract_answer_with_retries(response)
                    output_2 = _extract_answer_with_retries(original_response)

                    extracted_position_count_str_1, _ = extract_and_check(output_1)
                    is_correct_1, explanation_1 = evaluate_response(word, letter, extracted_position_count_str_1)
                    extracted_position_count_str_2, _ = extract_and_check(output_2)
                    is_correct_2, explanation_2 = evaluate_response(word, letter, extracted_position_count_str_2)

                    print(f'Position_count from response: {extracted_position_count_str_1}')
                    print(f'Position_count from original response: {extracted_position_count_str_2}')
                    correct_count = word.count(letter)
                    correct_positions = [i + 1 for i, test_letter in enumerate(word) if test_letter == letter]
                    print(f'Correct_count: {correct_count}, Correct_positions: {correct_positions}')

                    per_sample_outputs = [
                        ("position_count_1.txt", extracted_position_count_str_1),
                        ("position_count_2.txt", extracted_position_count_str_2),
                        ("feedback_1.txt", explanation_1),
                        ("feedback_2.txt", explanation_2),
                    ]
                    for file_name, content in per_sample_outputs:
                        with open(os.path.join(save_code_dir, file_name), "w") as f:
                            f.write(content)

                    # A sample is correct when either extraction is correct.
                    if not is_correct_1 and not is_correct_2:
                        print('False')
                        print(f'Feedback_1: {explanation_1}')
                        print(f'Feedback_2: {explanation_2}')
                        print(f'Original response: {original_response}')
                        if has_code:
                            answer_with_code_wrong += 1
                        else:
                            answer_without_code_wrong += 1
                    else:
                        print('True')
                        if has_code:
                            answer_with_code_correct += 1
                        else:
                            answer_without_code_correct += 1

    print(f'\ntotal_sample_num: {total_sample_num}')
    print(f'answer_with_code_wrong: {answer_with_code_wrong}')
    print(f'answer_without_code_wrong: {answer_without_code_wrong}')
    print(f'answer_with_code_correct: {answer_with_code_correct}')
    print(f'answer_without_code_correct: {answer_without_code_correct}')

    summary_counts = [
        ("answer_with_code_wrong.txt", answer_with_code_wrong),
        ("answer_without_code_wrong.txt", answer_without_code_wrong),
        ("answer_with_code_correct.txt", answer_with_code_correct),
        ("answer_without_code_correct.txt", answer_without_code_correct),
    ]
    for file_name, value in summary_counts:
        with open(os.path.join(base_save_code_dir, file_name), "w") as f:
            f.write(str(value))

    print(
        f'\nModel_name: {model_name}, AutoGen_prompt_system_message: {AutoGen_prompt_system_message}, AutoGen_prompt_concatenate: {AutoGen_prompt_concatenate}, code_interpreter: {code_interpreter}, Encourage_code_execution_interpreter: {Encourage_code_execution_interpreter}, all_code_with_COT: {all_code_with_COT}, all_code_without_COT: {all_code_without_COT}, all_text: {all_text}')
    print('*' * 30)

    run_info = f"{model_name}, round_num={multi_turn_planning_round_num}, " \
               f"{code_interpreter}, {AutoGen_prompt_system_message}, {AutoGen_prompt_concatenate}, " \
               f"{Encourage_code_execution_interpreter}, {all_code_with_COT}, {all_code_without_COT}, {all_text}, {multi_turn_planning}, multi_turn_planning_round_num: {multi_turn_planning_round_num}" \
               f"\nmethod_9_{method_9_all_text_all_code_summarizer}, method_10_{method_10_LLM_estimates_scores_first}\n"

    run_info_result = f'answer_with_code_wrong:{answer_with_code_wrong}, answer_without_code_wrong: {answer_without_code_wrong}, answer_with_code_correct: {answer_with_code_correct}, answer_without_code_correct: {answer_without_code_correct}\n'
    log_file_result = os.path.join(save_input_dir, f"acc_result_log_{model_name}.txt")
    # Append directly so this function does not rely on a logging helper that
    # only exists when the file is executed as a script.
    with open(log_file_result, 'a') as f:
        f.write(run_info + run_info_result + "\n")

def log_run_info(log_file, run_info):
    """Append ``run_info`` followed by a newline to ``log_file``.

    Defined at module level so that functions of this module can use it even
    when the file is imported instead of being run as a script.
    """
    with open(log_file, 'a') as f:
        f.write(run_info + "\n")


if __name__ == '__main__':
    # gpt-4o, gpt-4o-mini, gpt-3.5-turbo for OpenAi API
    # gpt-4o, gpt-35-turbo for Azure API
    # multi_turn_planning_round_num = [1, 2, 4, 8, 16]

    parser = ArgumentParser()
    parser.add_argument('-model_name', '--model_name', default='gpt-35-turbo-16k-0613')
    args = parser.parse_args()
    model_name = args.model_name

    dataset_input_dir = '../dataset_gather/'
    save_input_dir = '../results_gather/Letters'
    log_file = os.path.join(save_input_dir, f"run_log_{model_name}.txt")

    # Each entry: (code_interpreter, AutoGen_prompt_system_message, AutoGen_prompt_concatenate,
    #              Encourage_code_execution_interpreter, all_code_with_COT, all_code_without_COT,
    #              all_text, multi_turn_planning, multi_turn_planning_round_num,
    #              method_9_all_text_all_code_summarizer, method_10_LLM_estimates_scores_first)
    run_configs = []

    # Multi-turn planning runs, one group per round number.
    for round_num in [1]:
        run_configs.extend([
            (False, False, False, False, False, False, False, True, round_num, False, False),
            (False, True, False, False, False, False, False, True, round_num, False, False),
            (False, False, True, False, False, False, False, True, round_num, False, False),
            (False, False, False, False, False, False, True, True, round_num, False, False),
            (False, False, False, False, True, False, False, True, round_num, False, False),
            (False, False, False, False, False, True, False, True, round_num, False, False),
            (False, False, False, False, False, False, False, True, round_num, True, False),
            (False, False, False, False, False, False, False, True, round_num, False, True),
        ])

    # Code-interpreter runs (single round, no multi-turn planning).
    run_configs.extend([
        (True, False, False, False, False, False, False, False, 1, False, False),
        (True, False, False, True, False, False, False, False, 1, False, False),
    ])

    for (code_interpreter, AutoGen_prompt_system_message, AutoGen_prompt_concatenate,
         Encourage_code_execution_interpreter, all_code_with_COT, all_code_without_COT, all_text,
         multi_turn_planning, multi_turn_planning_round_num,
         method_9_all_text_all_code_summarizer, method_10_LLM_estimates_scores_first) in run_configs:
        post_process_letters(
            dataset_input_dir=dataset_input_dir,
            save_input_dir=save_input_dir,
            model_name=model_name,
            AutoGen_prompt_system_message=AutoGen_prompt_system_message,
            AutoGen_prompt_concatenate=AutoGen_prompt_concatenate,
            code_interpreter=code_interpreter,
            Encourage_code_execution_interpreter=Encourage_code_execution_interpreter,
            all_code_with_COT=all_code_with_COT,
            all_code_without_COT=all_code_without_COT,
            all_text=all_text,
            multi_turn_planning=multi_turn_planning,
            multi_turn_planning_round_num=multi_turn_planning_round_num,
            method_9_all_text_all_code_summarizer=method_9_all_text_all_code_summarizer,
            method_10_LLM_estimates_scores_first=method_10_LLM_estimates_scores_first)

        # Log the completed run
        run_info = f"Completed run: {model_name}, round_num={multi_turn_planning_round_num}, " \
                   f"{code_interpreter}, {AutoGen_prompt_system_message}, {AutoGen_prompt_concatenate}, " \
                   f"{Encourage_code_execution_interpreter}, {all_code_with_COT}, {all_code_without_COT}, {all_text}, {multi_turn_planning}, multi_turn_planning_round_num: {multi_turn_planning_round_num}" \
                   f"\nmethod_9_{method_9_all_text_all_code_summarizer}, method_10_{method_10_LLM_estimates_scores_first}\n"
        log_run_info(log_file, run_info)