import os
from generation_models import message_construct_func, GPT_response, extract_and_check
import re
import pandas as pd
import subprocess
from typing import List, Tuple, Dict
import time
import numpy as np
import ast
from argparse import ArgumentParser

# Define types
# A state maps a stack name (e.g. "stack1") to the blocks on that stack,
# listed bottom to top (the last element is the top block).
State = Dict[str, List[str]]
Action = Tuple[str, str, str]  # (block, from, to)

def extract_code(text):
    """Return the contents of every ```python fenced block in *text*.

    Only fences opened with exactly "```python" followed by a newline are
    recognised; the fence markers themselves are not included in the result.
    An empty list is returned when no such block exists.
    """
    # DOTALL lets '.' span newlines so multi-line code bodies are captured;
    # the non-greedy group stops at the first closing fence.
    fenced_blocks = re.finditer(r'```python\n(.*?)\n```', text, re.DOTALL)
    return [fenced.group(1) for fenced in fenced_blocks]

def extract_equation_with_GPT4(response):
    """Ask the 'gpt-4o' model to pull the final move list out of *response*.

    A fixed extraction instruction is prepended to *response* and the combined
    text is sent through ``GPT_response``. Whatever ``GPT_response`` returns is
    passed back unchanged (presumably text containing a ``<<<...>>>`` answer --
    confirm against ``generation_models``).
    """
    instruction = (
        'Your task is to extract the final answer from the given answer by another LLM:\n'
        'Note that the equation should be in the form like <<<answer>>>, <<<Move B from 2 to table\nMove A from 1 to 2\nMove C from 3 to 1\nMove D from 3 to 2\nMove B from 1 to 2>>>,\n'
        'Here is the reponse, return your answer with the format <<<equation>>>, like <<<Move B from 2 to table\nMove A from 1 to 2>>>. '
        'Input text: '
    )
    full_prompt = instruction + response

    # The same combined prompt is supplied both as the query and as the only
    # entry of the user prompt history; the response history starts empty.
    return GPT_response(
        '',
        full_prompt,
        model_name='gpt-4o',
        code_interpreter=False,
        user_prompt_list=[full_prompt],
        response_total_list=[],
    )


def read_state_from_file(filename: str) -> Tuple[Dict[str, List[str]], Dict[str, List[str]]]:
    """
    Read the initial and goal states from a text file.

    The file is expected to contain an "Initial State:" header and a
    "Goal State:" header, each followed by lines of the form
    "<stack name>: <block> <block> ...". A stack line with nothing after the
    colon denotes an empty stack. Blank lines are ignored. Stack lines that
    appear before any header are stored in the initial state.

    Args:
        filename: Path of the task description file.

    Returns:
        A pair ``(initial_state, goal_state)``; each maps a stack name to the
        list of block names in the order they appear on the line.
    """
    initial_state: Dict[str, List[str]] = {}
    goal_state: Dict[str, List[str]] = {}
    current_state = initial_state

    with open(filename, 'r') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            if line == "Initial State:":
                current_state = initial_state
            elif line == "Goal State:":
                current_state = goal_state
            else:
                # Split on the first colon only, so "stack1: A B", "stack1:"
                # and "stack1:A B" all yield the stack name "stack1". A line
                # without any colon is kept as an empty stack named by the
                # whole line instead of losing its last character.
                stack, _, blocks = line.partition(':')
                current_state[stack.strip()] = blocks.split()

    return initial_state, goal_state

def validate_response(initial_state: Dict[str, List[str]], goal_state: Dict[str, List[str]], response: str) -> Tuple[bool, str]:
    """
    Validate the LLM's response and check if it reaches the goal state.

    Each non-blank line of *response* must read
    "Move <block> from <source> to <destination>". A source or destination
    that does not contain the text "stack" gets the prefix "stack" (so "2"
    means "stack2"). The moves are simulated on a copy of *initial_state*;
    the caller's dictionaries are not modified.

    Args:
        initial_state: Mapping of stack name to its blocks, bottom to top.
        goal_state: Mapping of stack name to its blocks, bottom to top.
        response: Newline-separated list of moves.

    Returns:
        ``(True, message)`` when every move is legal and the final non-empty
        stacks equal the non-empty stacks of *goal_state*; otherwise
        ``(False, reason)``.
    """
    current_state = {stack: list(blocks) for stack, blocks in initial_state.items()}

    # Blank lines between moves carry no information, so they are skipped
    # instead of invalidating an otherwise well-formed plan.
    moves = [line for line in response.strip().split('\n') if line.strip()]
    if not moves:
        # An empty answer is still reported as a format error.
        return False, "Invalid move format: "

    for move in moves:
        parts = move.split()
        if len(parts) != 6 or parts[0] != "Move" or parts[2] != "from" or parts[4] != "to":
            return False, f"Invalid move format: {move}"

        block, source, destination = parts[1], parts[3], parts[5]
        if 'stack' not in source:
            source = 'stack' + source
        if 'stack' not in destination:
            destination = 'stack' + destination

        if source not in current_state or destination not in current_state:
            return False, f"Invalid source or destination stack: {move}"

        # Only the top block (last list element) of a stack may be moved.
        if not current_state[source] or current_state[source][-1] != block:
            return False, f"Invalid move: {move}. Block {block} is not at the top of the source stack."

        current_state[destination].append(current_state[source].pop())

    def non_empty_stacks(state):
        # Empty stacks are ignored so that a state listing an empty stack
        # compares equal to one that omits it.
        return {name: blocks for name, blocks in state.items() if blocks}

    if non_empty_stacks(current_state) == non_empty_stacks(goal_state):
        return True, "Goal state reached successfully!"
    return False, "The final state does not match the goal state."

# Post-process the saved Blocksworld responses and tally the results
def post_process_blocksworld(dataset_input_dir, save_input_dir, model_name, AutoGen_prompt_system_message,
               AutoGen_prompt_concatenate, code_interpreter, Encourage_code_execution_interpreter, all_code_with_COT,
               all_code_without_COT, all_text, multi_turn_planning, multi_turn_planning_round_num, method_9_all_text_all_code_summarizer, method_10_LLM_estimates_scores_first):
    """Score the saved Blocksworld responses of one experiment configuration.

    For each of the 24 task shapes and 5 sample indices the function:

    1. reads the task from ``<dataset_input_dir>/<shape>_<index>/blocksworld_task.txt``;
    2. reads the saved model output ``response_code_1.txt`` from the matching
       result directory;
    3. writes every ```python block of that output to ``code_1_<k>.py`` and, if
       ``code_1_0.py`` exists, runs it in a subprocess and uses its stdout as
       the response;
    4. extracts the answer (falling back to ``extract_equation_with_GPT4`` when
       ``extract_and_check`` yields an empty string) and validates it;
    5. writes ``response_answer.txt`` and ``success_failure.txt`` for the sample.

    Afterwards the four counters (correct/wrong, with/without code) are
    printed, written to ``<result dir>/<counter name>.txt`` and appended to
    ``<save_input_dir>/acc_result_log_<model_name>.txt``.

    Args:
        dataset_input_dir: Directory containing the per-sample task folders.
        save_input_dir: Directory containing the result folders; also receives
            the accuracy log.
        model_name: Model identifier; part of the result folder and log names.
        AutoGen_prompt_system_message, AutoGen_prompt_concatenate,
        code_interpreter, Encourage_code_execution_interpreter,
        all_code_with_COT, all_code_without_COT, all_text, multi_turn_planning,
        multi_turn_planning_round_num: Experiment settings. In this function
            they are only used to build the result folder name and the
            printed/logged run description.
        method_9_all_text_all_code_summarizer,
        method_10_LLM_estimates_scores_first: Booleans selecting a dedicated
            result folder name; at most one may be true.

    Raises:
        ValueError: If both method flags are true.
    """
    print('\n' + '*'*30)
    print(f'Model_name: {model_name}, AutoGen_prompt_system_message: {AutoGen_prompt_system_message}, '
          f'AutoGen_prompt_concatenate: {AutoGen_prompt_concatenate}, code_interpreter: {code_interpreter}, Encourage_code_execution_interpreter: {Encourage_code_execution_interpreter}, '
          f'all_code_with_COT: {all_code_with_COT}, all_code_without_COT: {all_code_without_COT}, all_text: {all_text}, multi_turn_planning: {multi_turn_planning}, multi_turn_planning_round_num: {multi_turn_planning_round_num}\n')

    # Same configuration line is printed for every sample and once at the end.
    config_summary = f'\nModel_name: {model_name}, AutoGen_prompt_system_message: {AutoGen_prompt_system_message}, AutoGen_prompt_concatenate: {AutoGen_prompt_concatenate}, code_interpreter: {code_interpreter}, Encourage_code_execution_interpreter: {Encourage_code_execution_interpreter}, all_code_with_COT: {all_code_with_COT}, all_code_without_COT: {all_code_without_COT}, all_text: {all_text}'

    total_sample_num = 0
    answer_with_code_wrong = 0
    answer_with_code_correct = 0
    answer_without_code_wrong = 0
    answer_without_code_correct = 0

    if method_9_all_text_all_code_summarizer and method_10_LLM_estimates_scores_first:
        raise ValueError(
            "method_9_all_text_all_code_summarizer and method_10_LLM_estimates_scores_first can't be both True")
    elif method_9_all_text_all_code_summarizer:
        base_save_code_dir = save_input_dir + f'/result_blocksworld_{model_name}_method_9_all_text_all_code_summarizer'
    elif method_10_LLM_estimates_scores_first:
        base_save_code_dir = save_input_dir + f'/result_blocksworld_{model_name}_method_10_LLM_estimates_scores_first'
    else:
        base_save_code_dir = save_input_dir + f'/result_blocksworld_{model_name}_interpreter_{code_interpreter}_AutoGen_system_{AutoGen_prompt_system_message}_AutoGen_concatenate_{AutoGen_prompt_concatenate}_all_code_with_COT_{all_code_with_COT}_all_code_without_COT_{all_code_without_COT}_Encourage_code_execution_interpreter_{Encourage_code_execution_interpreter}_all_text_{all_text}_multi_turn_planning_{multi_turn_planning}_round_num_{multi_turn_planning_round_num}'

    for num_blocks, initial_stacks, goal_stacks in [
        (2, 3, 2), (2, 3, 3), (2, 4, 2), (2, 4, 3), (2, 4, 4), (2, 5, 2), (2, 5, 3), (2, 5, 4),
        (3, 3, 2), (3, 3, 3), (3, 4, 2), (3, 4, 3), (3, 4, 4), (3, 5, 2), (3, 5, 3), (3, 5, 4),
        (4, 3, 2), (4, 3, 3), (4, 4, 2), (4, 4, 3), (4, 4, 4), (4, 5, 2), (4, 5, 3), (4, 5, 4)
    ]:
        for index in range(5):
            dataset_base_dir_sample = os.path.join(dataset_input_dir, f"{num_blocks}_{initial_stacks}_{goal_stacks}_{index}/")
            # Read states from file
            initial_state, goal_state = read_state_from_file(dataset_base_dir_sample + "blocksworld_task.txt")
            print(f'num_blocks: {num_blocks}, initial_stacks: {initial_stacks}, goal_stacks: {goal_stacks}, index: {index}')

            save_code_dir = os.path.join(base_save_code_dir, f"{num_blocks}_{initial_stacks}_{goal_stacks}_{index}/")

            with open(save_code_dir + "/response_code_1.txt", "r") as f:
                response = f.read()
            print(config_summary)

            # Decided on the raw model output, before it may be replaced by the
            # stdout of the executed code below.
            used_code = 'code_interpreter' in response or '```python' in response

            code_block_list = extract_code(response)
            # A separate loop variable keeps the sample `index` intact.
            for code_index, code_string in enumerate(code_block_list):
                with open(save_code_dir + f"/code_1_{code_index}.py", "w") as f:
                    f.write(code_string)

            # Test the generated code
            # NOTE(review): this executes model-generated code without any
            # sandbox; only run it on trusted outputs or in an isolated
            # environment. The path is also interpolated into the command
            # text, so a quote character in the path would break it.
            if os.path.exists(save_code_dir + "/code_1_0.py"):
                try:
                    # The snippet is exec'd and its variable `result` printed,
                    # so the generated code is expected to define `result`.
                    result = subprocess.run(
                        ["python3", "-c", f"exec(open('{save_code_dir}/code_1_0.py').read()); print(result)"],
                        capture_output=True,
                        text=True,
                        timeout=15
                    )
                    response = result.stdout
                except Exception:
                    # Best effort: a timeout or launch failure counts as an
                    # empty response rather than aborting the whole run.
                    print('Code execution error')
                    response = ""

            response_answer, _ = extract_and_check(response)
            if response_answer == '':
                response_answer = extract_equation_with_GPT4(response)

            is_valid, message = validate_response(initial_state, goal_state, response_answer)
            total_sample_num += 1

            print(f"Response: {response}")
            print(f"Response_answer: {response_answer}")
            print(f"Response is valid: {is_valid}")
            print(f"Message: {message}")

            with open(save_code_dir + "/response_answer.txt", "w") as f:
                f.write(response_answer)

            with open(save_code_dir + "/success_failure.txt", "w") as f:
                f.write(str(is_valid))

            if is_valid:
                print('True')
                if used_code:
                    answer_with_code_correct += 1
                else:
                    answer_without_code_correct += 1
            else:
                if used_code:
                    answer_with_code_wrong += 1
                else:
                    answer_without_code_wrong += 1

    # Insertion order fixes both the print order and the file write order.
    counters = {
        'answer_with_code_wrong': answer_with_code_wrong,
        'answer_without_code_wrong': answer_without_code_wrong,
        'answer_with_code_correct': answer_with_code_correct,
        'answer_without_code_correct': answer_without_code_correct,
    }

    print(f'\ntotal_sample_num: {total_sample_num}')
    for counter_name, counter_value in counters.items():
        print(f'{counter_name}: {counter_value}')

    for counter_name, counter_value in counters.items():
        with open(base_save_code_dir + f"/{counter_name}.txt", "w") as f:
            f.write(str(counter_value))

    print(config_summary)
    print('*' * 30)

    run_info = f"{model_name}, round_num={multi_turn_planning_round_num}, " \
               f"{code_interpreter}, {AutoGen_prompt_system_message}, {AutoGen_prompt_concatenate}, " \
               f"{Encourage_code_execution_interpreter}, {all_code_with_COT}, {all_code_without_COT}, {all_text}, {multi_turn_planning}, multi_turn_planning_round_num: {multi_turn_planning_round_num}" \
               f"\nmethod_9_{method_9_all_text_all_code_summarizer}, method_10_{method_10_LLM_estimates_scores_first}\n"

    run_info_result = f'answer_with_code_wrong:{answer_with_code_wrong}, answer_without_code_wrong: {answer_without_code_wrong}, answer_with_code_correct: {answer_with_code_correct}, answer_without_code_correct: {answer_without_code_correct}\n'
    log_file_result = os.path.join(save_input_dir, f"acc_result_log_{model_name}.txt")
    # Append directly so the function also works when this module is imported
    # (no helper defined under the __main__ guard is required).
    with open(log_file_result, 'a') as f:
        f.write(run_info + run_info_result + "\n")

# Script entry point: post-process every experiment configuration listed below
# for the model given on the command line.
if __name__ == '__main__':
    # gpt-4o, gpt-4o-mini, gpt-3.5-turbo for OpenAi API
    # gpt-4o, gpt-35-turbo for Azure API
    # multi_turn_planning_round_num = [1, 2, 4, 8, 16]

    parser = ArgumentParser()
    parser.add_argument('-model_name', '--model_name', default='gpt-35-turbo-16k-0613')
    args = parser.parse_args()
    model_name = args.model_name

    # Input tasks and saved results, relative to the current working directory.
    dataset_input_dir = '../dataset_gather/Blocksworld_dataset'
    save_input_dir = '../results_gather/blocksworld'

    def log_run_info(log_file, run_info):
        # Append one entry (plus a trailing newline) to the given log file.
        with open(log_file, 'a') as f:
            f.write(run_info + "\n")

    # First batch: configurations with multi_turn_planning=True.
    # for multi_turn_planning_round_num in [1, 2, 4, 8, 16]:
    for multi_turn_planning_round_num in [1]:
        # for model_name in ['gpt-4o', 'gpt-4o-mini', 'gpt-3.5-turbo', 'gpt-35-turbo-16k-0613']:
        # for model_name in ['gpt-4o']:
        log_file = os.path.join(save_input_dir, f"run_log_{model_name}.txt")

        # Each tuple is one configuration, in this positional order:
        #   dataset_input_dir, save_input_dir, model_name, code_interpreter,
        #   AutoGen_prompt_system_message, AutoGen_prompt_concatenate,
        #   Encourage_code_execution_interpreter, all_code_with_COT,
        #   all_code_without_COT, all_text, multi_turn_planning,
        #   multi_turn_planning_round_num, method_9_all_text_all_code_summarizer,
        #   method_10_LLM_estimates_scores_first
        # NOTE: the loop targets rebind the module-level names of the same
        # name (with identical values for the first three fields).
        for dataset_input_dir, save_input_dir, model_name, code_interpreter, AutoGen_prompt_system_message, AutoGen_prompt_concatenate, \
            Encourage_code_execution_interpreter, all_code_with_COT, all_code_without_COT, all_text, multi_turn_planning, multi_turn_planning_round_num \
                , method_9_all_text_all_code_summarizer, method_10_LLM_estimates_scores_first in [
            # all optional flags off
            (dataset_input_dir, save_input_dir, model_name, False, False, False, False, False, False, False, True, multi_turn_planning_round_num, False, False),
            # AutoGen_prompt_system_message=True
            (dataset_input_dir, save_input_dir, model_name, False, True, False, False, False, False, False, True, multi_turn_planning_round_num, False, False),
            # AutoGen_prompt_concatenate=True
            (dataset_input_dir, save_input_dir, model_name, False, False, True, False, False, False, False, True, multi_turn_planning_round_num, False, False),
            # all_text=True
            (dataset_input_dir, save_input_dir, model_name, False, False, False, False, False, False, True, True, multi_turn_planning_round_num, False, False),
            # all_code_with_COT=True
            (dataset_input_dir, save_input_dir, model_name, False, False, False, False, True, False, False, True, multi_turn_planning_round_num, False, False),
            # all_code_without_COT=True
            (dataset_input_dir, save_input_dir, model_name, False, False, False, False, False, True, False, True, multi_turn_planning_round_num, False, False),
            # method_9_all_text_all_code_summarizer=True
            (dataset_input_dir, save_input_dir, model_name, False, False, False, False, False, False, False, True, multi_turn_planning_round_num, True, False),
            # method_10_LLM_estimates_scores_first=True
            (dataset_input_dir, save_input_dir, model_name, False, False, False, False, False, False, False, True, multi_turn_planning_round_num, False, True),
        ]:
            post_process_blocksworld(dataset_input_dir, save_input_dir, model_name, AutoGen_prompt_system_message,
                                 AutoGen_prompt_concatenate, code_interpreter, Encourage_code_execution_interpreter,
                                 all_code_with_COT,
                                 all_code_without_COT, all_text, multi_turn_planning, multi_turn_planning_round_num,
                                 method_9_all_text_all_code_summarizer, method_10_LLM_estimates_scores_first)

            # Log the completed run
            run_info = f"Completed run: {model_name}, round_num={multi_turn_planning_round_num}, " \
                       f"{code_interpreter}, {AutoGen_prompt_system_message}, {AutoGen_prompt_concatenate}, " \
                       f"{Encourage_code_execution_interpreter}, {all_code_with_COT}, {all_code_without_COT}, {all_text}, {multi_turn_planning}, multi_turn_planning_round_num: {multi_turn_planning_round_num}" \
                       f"\nmethod_9_{method_9_all_text_all_code_summarizer}, method_10_{method_10_LLM_estimates_scores_first}\n"
            log_run_info(log_file, run_info)


    # Second batch: code_interpreter=True configurations with
    # multi_turn_planning=False and a fixed round number of 1. The tuples use
    # the same positional order as the first batch.
    #for model_name in ['gpt-4o', 'gpt-4o-mini', 'gpt-3.5-turbo', 'gpt-35-turbo-16k-0613']:
    #for model_name in ['gpt-4o']:
    log_file = os.path.join(save_input_dir, f"run_log_{model_name}.txt")

    for dataset_input_dir, save_input_dir, model_name, code_interpreter, AutoGen_prompt_system_message, AutoGen_prompt_concatenate, \
        Encourage_code_execution_interpreter, all_code_with_COT, all_code_without_COT, all_text, multi_turn_planning, multi_turn_planning_round_num\
        ,method_9_all_text_all_code_summarizer, method_10_LLM_estimates_scores_first in [
        # code_interpreter=True only
        (dataset_input_dir, save_input_dir, model_name, True, False, False, False, False, False, False, False, 1, False, False),
        # code_interpreter=True and Encourage_code_execution_interpreter=True
        (dataset_input_dir, save_input_dir, model_name, True, False, False, True, False, False, False, False, 1, False, False)
    ]:
        post_process_blocksworld(dataset_input_dir, save_input_dir, model_name, AutoGen_prompt_system_message,
                      AutoGen_prompt_concatenate, code_interpreter, Encourage_code_execution_interpreter,
                      all_code_with_COT,
                      all_code_without_COT, all_text, multi_turn_planning, multi_turn_planning_round_num, method_9_all_text_all_code_summarizer, method_10_LLM_estimates_scores_first)

        # Log the completed run
        run_info = f"Completed run: {model_name}, round_num={multi_turn_planning_round_num}, " \
                   f"{code_interpreter}, {AutoGen_prompt_system_message}, {AutoGen_prompt_concatenate}, " \
                   f"{Encourage_code_execution_interpreter}, {all_code_with_COT}, {all_code_without_COT}, {all_text}, {multi_turn_planning}, multi_turn_planning_round_num: {multi_turn_planning_round_num}" \
                   f"\nmethod_9_{method_9_all_text_all_code_summarizer}, method_10_{method_10_LLM_estimates_scores_first}\n"
        log_run_info(log_file, run_info)
