from dotenv import load_dotenv
from openai import AzureOpenAI
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
import json
import re
import pandas as pd
import os
import subprocess
import io
import sys
from openai import OpenAI
from typing_extensions import override
from openai import AssistantEventHandler
from generation_models import message_construct_func, GPT_response, count_total_tokens, extract_and_check
import random
import math
import json
from typing import List, Tuple, Dict
import time
import numpy as np
import ast
from prompt import *
from argparse import ArgumentParser

def extract_code(text):
    """Return the contents of every ```python fenced block found in ``text``.

    Only fences that open with ```python immediately followed by a newline
    are recognised.  The fence markers are not part of the returned strings.
    An empty list is returned when ``text`` contains no such block.
    """
    fenced_python = r'```python\n(.*?)\n```'
    # DOTALL lets "." cross newlines so multi-line blocks are captured whole;
    # the non-greedy group stops at the first closing fence.
    return re.findall(fenced_python, text, flags=re.DOTALL)

def extract_equation_with_GPT4(response):
    """Ask the 'gpt-4o' model to isolate the final answer in ``response``.

    A fixed instruction (answer must be wrapped as <<<...>>>, or
    <<<No answer found>>> when nothing is answered) is prepended to
    ``response`` and the combined text is sent through ``GPT_response`` with
    the code interpreter disabled.

    :param response: Text produced by another LLM.
    :return: Whatever ``GPT_response`` returns for the extraction request.
    """
    instruction = (
        'Your task is to extract the final answer of the given answer by another LLM:\n'
        'Here is the response, return your answer with the format <<<list>>>, like <<<Yes>>>, <<<No>>>.\n'
        'If the input text does not have <<<>>> and is already the pure answer, add <<<>>> and return your answer.\n'
        'Note that if you find no final answer is answered, then directly answer <<<No answer found>>>.\n'
        'Input text: '
    )
    full_prompt = instruction + response

    return GPT_response('', full_prompt, model_name='gpt-4o', code_interpreter=False,
                        user_prompt_list=[full_prompt], response_total_list=[])

def create_prompt(boxes: List[int], lifters: List[int], estimated_steps) -> str:
    """Render the BoxLift task description that is sent to the model.

    :param boxes: Box weights, interpolated into the prompt as-is.
    :param lifters: Lifter capacities, interpolated into the prompt as-is.
    :param estimated_steps: Step budget quoted in rule 6 of the prompt.
    :return: The full prompt text, ending with "Your answer:" and two newlines.
    """
    # Plain (non-f) template: the literal text contains no braces other than
    # the three named placeholders, so str.format renders it unchanged.
    template = """Task: BoxLift

You are given a list of boxes with the following weights: {boxes}
And a list of lifters with the following maximum lifting capacities: {lifters}

Your task is to assign the lifters to lift all the boxes in multiple steps, following these rules:
1. Multiple boxes can be lifted in each step.
2. Each lifter can only lift one box at a time.
3. Each lifting agent can be used only once in each step.
4. Multiple lifters can combine together to lift one box if the box is too heavy for a single lifter.
5. Try to lift all the boxes using the minimum number of steps possible.
6. You need to lift all the boxes in less than or equal to {estimated_steps} steps.

Please provide your solution in the following format:
Step 1: [(Box weight, [Lifter indices]), (Box weight, [Lifter indices]), ...]
Step 2: [(Box weight, [Lifter indices]), (Box weight, [Lifter indices]), ...]
...

For example:
Step 1: [(50, [0, 2]), (30, [1]), (20, [3])]
This means in Step 1, lifters 0 and 2 are lifting a box weighing 50, lifter 1 is lifting a box weighing 30, and lifter 3 is lifting a box weighing 20.

Surround the answer with <<<content>>>.

For example, <<<Step 1: [(50, [0, 2]), (30, [1]), (20, [3])]\nStep 2: [(40, [0, 1]), (20, [2]), (20, [3])]\nStep 3:...>>>

Ensure all boxes are lifted and provide the most efficient solution possible.

Your answer:\n
"""
    return template.format(boxes=boxes, lifters=lifters, estimated_steps=estimated_steps)


def verify_solution(boxes: List[int], lifters: List[int], solution: str, estimated_steps) -> Tuple[bool, List[int], List[str]]:
    """Check a textual BoxLift plan against the rules of the task.

    The plan is split on the word "Step"; at most ``estimated_steps`` steps
    are evaluated (extra steps are reported as 'Too many steps' and ignored).
    Every assignment ``(box_weight, [lifter indices])`` that passes all checks
    removes one matching box from the remaining list; a failing assignment is
    recorded and skipped, and evaluation continues.

    :param boxes: Weights of the boxes that must be lifted (not modified).
    :param lifters: Capacity of each lifter, addressed by list index.
    :param solution: Plan text such as "Step 1: [(50, [0, 2]), (30, [1])]".
    :param estimated_steps: Maximum number of steps that are evaluated.
    :return: ``(all_lifted, remaining_boxes, failure_reasons)`` where
        ``failure_reasons`` is always a list of message strings.  A step that
        cannot be parsed stops verification immediately with
        ``(False, remaining_boxes, failure_reasons)`` and 'Invalid format' as
        the last reason.
    """
    remaining_boxes = boxes.copy()
    success_failure_list = []

    steps = solution.split("Step")[1:]  # text before the first "Step" is dropped
    if len(steps) > estimated_steps:
        success_failure_list.append('Too many steps')

    for step in steps[:max(estimated_steps, 0)]:
        used_lifters = set()  # lifters already busy within this step
        try:
            # NOTE(security): the plan is model-generated text.  It is parsed
            # with ast.literal_eval (literals only) instead of eval so that it
            # cannot execute arbitrary code; the documented plan format needs
            # nothing beyond list/tuple/int literals.
            assignments = ast.literal_eval(step.split(":")[1].strip())
            for box_weight, lifter_indices in assignments:
                if box_weight not in remaining_boxes:
                    success_failure_list.append('Invalid box weight')

                # Negative indices are rejected too: they would silently alias
                # real lifters (e.g. -1 is the last lifter) and dodge the
                # once-per-step check below.
                elif any(not 0 <= lifter < len(lifters) for lifter in lifter_indices):
                    success_failure_list.append('Invalid lifter index')

                # A lifter may appear only once per step, which also rules out
                # listing the same lifter twice for a single box.
                elif (len(set(lifter_indices)) != len(lifter_indices)
                      or any(lifter in used_lifters for lifter in lifter_indices)):
                    success_failure_list.append('Lifter used more than once')

                elif sum(lifters[i] for i in lifter_indices) < box_weight:
                    success_failure_list.append('Insufficient lifter strength')

                else:
                    remaining_boxes.remove(box_weight)
                    used_lifters.update(lifter_indices)
        except Exception:
            # Anything that does not parse or unpack as expected is a format
            # error; KeyboardInterrupt/SystemExit are deliberately not caught.
            success_failure_list.append('Invalid format')
            return False, remaining_boxes, success_failure_list

    return len(remaining_boxes) == 0, remaining_boxes, success_failure_list


def estimate_steps(boxes: List[int], lifters: List[int]) -> int:
    """Estimate how many steps a simple greedy strategy needs to lift all boxes.

    In every step the boxes are visited from heaviest to lightest.  A box is
    lifted when the lifters still free in this step are strong enough
    together; lifters are then taken from the front of the free list until
    their summed capacity reaches the box weight.

    :param boxes: Weights of the boxes to lift (not modified).
    :param lifters: Capacities of the lifters (not modified).
    :return: Number of steps the greedy strategy used; 0 for no boxes.
    :raises ValueError: If a full step lifts nothing, i.e. the remaining boxes
        cannot be lifted even with every lifter free (this includes an empty
        ``lifters`` list).  Without this check the loop would never end.
    """
    remaining_boxes = sorted(boxes, reverse=True)  # heaviest first
    steps = 0

    while remaining_boxes:
        steps += 1
        available_lifters = lifters.copy()
        boxes_before_step = len(remaining_boxes)

        i = 0
        while i < len(remaining_boxes) and available_lifters:
            box = remaining_boxes[i]

            if sum(available_lifters) >= box:
                # Take lifters from the front until their strength suffices.
                lift_strength = 0
                lifters_needed = 0
                for lifter in available_lifters:
                    lift_strength += lifter
                    lifters_needed += 1
                    if lift_strength >= box:
                        break

                del available_lifters[:lifters_needed]
                remaining_boxes.pop(i)  # next box slides into position i
            else:
                i += 1  # too heavy for who is left; try a lighter box

        if len(remaining_boxes) == boxes_before_step:
            # Every step starts with all lifters free, so a step without
            # progress would repeat forever.
            raise ValueError(
                f"Boxes {remaining_boxes} cannot be lifted by the given lifters {lifters}")

    return steps


def read_test_case(filename: str) -> Tuple[List[int], List[int]]:
    """
    Load one BoxLift test case from a JSON file.

    The file must hold an object with the keys "boxes" and "lifters".

    :param filename: Path of the JSON file to read.
    :return: ``(boxes, lifters)`` exactly as stored under those two keys.
    """
    with open(filename, 'r') as handle:
        test_case = json.load(handle)

    boxes = test_case["boxes"]
    lifters = test_case["lifters"]
    return boxes, lifters

def _captured_text(stream):
    """Return output captured by a timed-out subprocess as ``str``.

    ``subprocess.TimeoutExpired`` carries bytes (or ``None``) even when the
    process was started with ``text=True``, so it is decoded here; nothing
    captured yields an empty string.
    """
    if not stream:
        return ""
    if isinstance(stream, bytes):
        return stream.decode(errors="replace")
    return stream


def run_boxlift(dataset_input_dir, save_input_dir, model_name, AutoGen_prompt_system_message,
               AutoGen_prompt_concatenate, code_interpreter, Encourage_code_execution_interpreter, all_code_with_COT,
               all_code_without_COT, all_text, multi_turn_planning, multi_turn_planning_round_num, method_9_all_text_all_code_summarizer, method_10_LLM_estimates_scores_first):
    """Run one prompting configuration over the BoxLift test cases and save all artefacts.

    Four hard-coded problem sizes with ten iterations each are processed.  For
    every test case the JSON file is read from ``dataset_input_dir``, the task
    prompt is built, the boolean flags select which prompt prefix (or system
    message) is used, and the model is queried for up to
    ``multi_turn_planning_round_num`` rounds (one round when
    ``multi_turn_planning`` is falsy).  Prompts, responses, extracted code
    and per-round timings are written below a configuration-specific
    directory in ``save_input_dir``.

    Per round:
    * a response containing 'TERMINATE' ends the rounds for that test case;
    * otherwise, if the response contains ```python blocks, the first block
      is executed with ``python3`` (10 s timeout) and its output is fed back
      as the next user prompt;
    * otherwise the "without code" follow-up prompt is queued.

    Average wall-clock time per test case is written separately for test
    cases that did / did not execute code.

    :raises ValueError: If both method_9 and method_10 flags are True, or if
        the flag combination matches none of the supported prompt setups.
    """
    print('\n' + '*'*30)
    print(f'Model_name: {model_name}, AutoGen_prompt_system_message: {AutoGen_prompt_system_message}, '
          f'AutoGen_prompt_concatenate: {AutoGen_prompt_concatenate}, code_interpreter: {code_interpreter}, Encourage_code_execution_interpreter: {Encourage_code_execution_interpreter}, '
          f'all_code_with_COT: {all_code_with_COT}, all_code_without_COT: {all_code_without_COT}, all_text: {all_text}, multi_turn_planning: {multi_turn_planning}, multi_turn_planning_round_num: {multi_turn_planning_round_num}, '
          f'method_9_{method_9_all_text_all_code_summarizer}, method_10_{method_10_LLM_estimates_scores_first}\n')

    total_test_num = 0
    num_with_code = 0
    num_without_code = 0
    time_with_code = 0
    time_without_code = 0

    if method_9_all_text_all_code_summarizer == False and method_10_LLM_estimates_scores_first == False:
        base_save_code_dir = save_input_dir + f'/result_boxlift_{model_name}_interpreter_{code_interpreter}_AutoGen_system_{AutoGen_prompt_system_message}_AutoGen_concatenate_{AutoGen_prompt_concatenate}_all_code_with_COT_{all_code_with_COT}_all_code_without_COT_{all_code_without_COT}_Encourage_code_execution_interpreter_{Encourage_code_execution_interpreter}_all_text_{all_text}_multi_turn_planning_{multi_turn_planning}_round_num_{multi_turn_planning_round_num}'
    elif method_9_all_text_all_code_summarizer == True and method_10_LLM_estimates_scores_first == False:
        base_save_code_dir = save_input_dir + f'/result_boxlift_{model_name}_method_9_all_text_all_code_summarizer'
    elif method_9_all_text_all_code_summarizer == False and method_10_LLM_estimates_scores_first == True:
        base_save_code_dir = save_input_dir + f'/result_boxlift_{model_name}_method_10_LLM_estimates_scores_first'
    else:
        raise ValueError(
            "method_9_all_text_all_code_summarizer and method_10_LLM_estimates_scores_first can't be both True")

    if not os.path.exists(base_save_code_dir):
        os.makedirs(base_save_code_dir)

    for num_boxes, num_lifters, min_box_weight, max_box_weight, min_lifter_capacity, max_lifter_capacity in \
            [(10, 3, 10, 100, 40, 80), (15, 4, 20, 200, 30, 120), (20, 5, 30, 300, 40, 160), (25, 6, 40, 400, 50, 200)]:
        for iteration_num in range(10):
            with_code = False
            total_test_num += 1
            print(f'\n\nNum_boxes = {num_boxes}, Num_lifters = {num_lifters}, Iteration_num = {iteration_num}')
            save_code_dir = os.path.join(base_save_code_dir, f"{num_boxes}_{num_lifters}_{iteration_num}/")
            if not os.path.exists(save_code_dir):
                os.makedirs(save_code_dir)

            boxes, lifters = read_test_case(dataset_input_dir + f'/BoxLift_{num_boxes}_{num_lifters}/BoxLift{iteration_num}/BoxLift.json')
            print(f"Initial boxes: {boxes}")
            print(f"Initial lifters: {lifters}")

            estimated_steps = estimate_steps(boxes, lifters)
            print(f"Estimated number of steps: {estimated_steps}")
            question = create_prompt(boxes, lifters, estimated_steps)

            response_total_list = []
            system_message = ""

            if method_9_all_text_all_code_summarizer == True and method_10_LLM_estimates_scores_first == False:
                # The summarizer combines the saved round-1 answers of the
                # all-text run (Agent_1) and the all-code run (Agent_2), so
                # those two runs must already exist on disk.
                base_save_code_dir_all_text = save_input_dir + f'/result_boxlift_{model_name}_interpreter_False_AutoGen_system_False_AutoGen_concatenate_False_all_code_with_COT_False_all_code_without_COT_False_Encourage_code_execution_interpreter_False_all_text_True_multi_turn_planning_True_round_num_{multi_turn_planning_round_num}'
                base_save_code_dir_all_code = save_input_dir + f'/result_boxlift_{model_name}_interpreter_False_AutoGen_system_False_AutoGen_concatenate_False_all_code_with_COT_False_all_code_without_COT_True_Encourage_code_execution_interpreter_False_all_text_False_multi_turn_planning_True_round_num_{multi_turn_planning_round_num}'
                save_code_dir_all_text = os.path.join(base_save_code_dir_all_text, f"{num_boxes}_{num_lifters}_{iteration_num}/")
                save_code_dir_all_code = os.path.join(base_save_code_dir_all_code, f"{num_boxes}_{num_lifters}_{iteration_num}/")
                with open(save_code_dir_all_text + f"/response_code_1.txt", "r") as f:
                    response_all_text = f.read()
                with open(save_code_dir_all_code + f"/response_code_1.txt", "r") as f:
                    response_all_code = f.read()
                input_prompt = combined_agent_prompt + '###The input question is: \n' + question + f'\n\n'
                input_prompt = input_prompt + f'\nThe response from Agent_1 is: {response_all_text}'

                if os.path.exists(save_code_dir_all_text + f"/code_1_0.py"):
                    with open(save_code_dir_all_text + f"/response_answer.txt", "r") as f:
                        extracted_all_text = f.read()
                    input_prompt = input_prompt + f'\nThe execution result from the Agent_1 code is: {extracted_all_text}'

                input_prompt = input_prompt + f'\n\nThe response from Agent_2 is: {response_all_code}'
                if os.path.exists(save_code_dir_all_code + f"/code_1_0.py"):
                    with open(save_code_dir_all_code + f"/response_answer.txt", "r") as f:
                        extracted_all_code = f.read()
                    input_prompt = input_prompt + f'\nThe execution result from the Agent_2 code is: {extracted_all_code}'
                input_prompt = input_prompt + f'\n\nNow you need to analyze the problem based on their answers and output final answer with the required format in the original question. Your analysis and answer:\n'

                user_prompt_list = [input_prompt]

            elif method_9_all_text_all_code_summarizer == False and method_10_LLM_estimates_scores_first == True:
                user_prompt_list = [method_10_self_estimate_score_prompt + question]

            elif Encourage_code_execution_interpreter == True and code_interpreter == True and all_code_with_COT == False and all_code_without_COT == False and all_text == False:
                user_prompt_list = [Encourage_code_execution_prompt_for_code_interpreter + question]
            elif AutoGen_prompt_system_message == True and AutoGen_prompt_concatenate == False and all_code_with_COT == False and all_code_without_COT == False and all_text == False:
                system_message = AutoGen_prompt
                user_prompt_list = [question]
            elif AutoGen_prompt_system_message == False and AutoGen_prompt_concatenate == True and all_code_with_COT == False and all_code_without_COT == False and all_text == False:
                user_prompt_list = [AutoGen_prompt + question]
            elif AutoGen_prompt_system_message == False and AutoGen_prompt_concatenate == False and all_code_with_COT == False and all_code_without_COT == False and all_text == False:
                user_prompt_list = [question]
            elif code_interpreter == False and all_code_with_COT == True and all_code_without_COT == False and all_text == False:
                user_prompt_list = [with_COT_all_code_prompt_2 + question]
            elif code_interpreter == False and all_code_with_COT == False and all_code_without_COT == True and all_text == False:
                user_prompt_list = [without_COT_all_code_prompt_1 + question]
            elif code_interpreter == False and all_code_with_COT == False and all_code_without_COT == False and all_text == True:
                user_prompt_list = [text_output_prompt + question]
            else:
                # Previously this fell through and failed later with a
                # NameError on user_prompt_list; fail with a clear message.
                raise ValueError(
                    "Unsupported combination of prompt configuration flags: no prompt setup matches")
            with open(save_code_dir + f"/system_message.txt", "w") as f:
                f.write(system_message)

            boxes_initial = {
                "initial_boxes": boxes
            }
            with open(save_code_dir + f'/boxes_initial_{num_boxes}_{num_lifters}_{iteration_num}.json',
                      'w') as f:
                json.dump(boxes_initial, f)

            if multi_turn_planning:
                round_number = multi_turn_planning_round_num
            else:
                round_number = 1

            execution_time_total = 0
            for round_index in range(round_number):
                with open(save_code_dir + f"/input_prompt_{round_index + 1}.txt", "w") as f:
                    f.write(user_prompt_list[round_index])

                print(f'Round {round_index + 1}')
                start_time = time.time()

                # 15000 tokens limit for gpt-3.5-turbo
                if count_total_tokens(user_prompt_list, response_total_list) > 15000 and model_name in ['gpt-3.5-turbo',
                                                                                                        'gpt-35-turbo-16k-0613']:
                    break

                response_code = GPT_response("", user_prompt_list[0], model_name=model_name,
                                             code_interpreter=code_interpreter, user_prompt_list=user_prompt_list,
                                             response_total_list=response_total_list)
                code_block_list = extract_code(response_code)
                for index_code, code_string in enumerate(code_block_list):
                    # Saved as .py: this is the file that is executed below and
                    # whose existence the method_9 summarizer checks for.
                    with open(save_code_dir + f"/code_{round_index + 1}_{index_code}.py", "w") as f:
                        f.write(code_string)
                with open(save_code_dir + f"/response_code_{round_index + 1}.txt", "w") as f:
                    f.write(response_code)

                terminated = 'TERMINATE' in response_code
                if terminated:
                    print(f'Terminate in round {round_index + 1}. Completed!')
                elif code_block_list:
                    # Decide on the blocks of THIS response rather than on file
                    # existence, so a stale file from an earlier run is never
                    # executed by mistake.
                    with_code = True
                    try:
                        result = subprocess.run(
                            ["python3", save_code_dir + f"/code_{round_index + 1}_0.py"],
                            capture_output=True, text=True, timeout=10
                        )
                        output = result.stdout
                        errors = result.stderr
                    except subprocess.TimeoutExpired as e:
                        output = _captured_text(e.stdout)
                        errors = _captured_text(e.stderr)
                        errors += f"\nTimeoutExpired: Command '{e.cmd}' timed out after {e.timeout} seconds"

                    multi_turn_question = multi_turn_planning_prompt_with_code + f'The execution result from the code is:\noutput: {output}, errors: {errors}'
                    user_prompt_list.append(multi_turn_question)
                    response_total_list.append(response_code)
                else:
                    multi_turn_question = multi_turn_planning_prompt_without_code
                    user_prompt_list.append(multi_turn_question)
                    response_total_list.append(response_code)

                # Timing covers the model call plus any local code execution.
                execution_time = time.time() - start_time
                execution_time_total += execution_time
                with open(save_code_dir + f"/execution_time_{round_index + 1}.txt", "w") as f:
                    f.write(str(execution_time))

                if terminated:
                    break

            if with_code:
                num_with_code += 1
                time_with_code += execution_time_total
            else:
                num_without_code += 1
                time_without_code += execution_time_total

    print(
        f'ratio_with_code: {num_with_code / total_test_num}, ratio_without_code: {num_without_code / total_test_num}')
    if num_with_code != 0:
        print(f'average_time_with_code: {time_with_code / num_with_code}')
        with open(base_save_code_dir + f"/average_time_with_code.txt", "w") as f:
            f.write(str(time_with_code / num_with_code))
    if num_without_code != 0:
        print(f'average_time_without_code: {time_without_code / num_without_code}')
        with open(base_save_code_dir + f"/average_time_without_code.txt", "w") as f:
            f.write(str(time_without_code / num_without_code))

    print(f'Model_name: {model_name}, AutoGen_prompt_system_message: {AutoGen_prompt_system_message}, '
          f'AutoGen_prompt_concatenate: {AutoGen_prompt_concatenate}, code_interpreter: {code_interpreter}, Encourage_code_execution_interpreter: {Encourage_code_execution_interpreter}, '
          f'all_code_with_COT: {all_code_with_COT}, all_code_without_COT: {all_code_without_COT}, all_text: {all_text}, multi_turn_planning: {multi_turn_planning}, multi_turn_planning_round_num: {multi_turn_planning_round_num}, '
          f'method_9_{method_9_all_text_all_code_summarizer}, method_10_{method_10_LLM_estimates_scores_first}\n')
    print('*' * 30)

if __name__ == '__main__':
    # Model names noted by the authors:
    #   OpenAI API: gpt-4o, gpt-4o-mini, gpt-3.5-turbo
    #   Azure API:  gpt-4o, gpt-35-turbo
    # Round counts noted for multi-turn sweeps: [1, 2, 4, 8, 16]; only [1] is run here.

    parser = ArgumentParser()
    parser.add_argument('-model_name', '--model_name', default='gpt-35-turbo-16k-0613')
    args = parser.parse_args()
    model_name = args.model_name

    dataset_input_dir = '../dataset_gather/BoxLift_dataset'
    save_input_dir = '../results_gather/BoxLift'

    if not os.path.exists(save_input_dir):
        os.makedirs(save_input_dir)

    def log_run_info(log_file, run_info):
        # Append one entry per finished configuration.
        with open(log_file, 'a') as f:
            f.write(run_info + "\n")

    # One log per model; every configuration below appends to it.
    log_file = os.path.join(save_input_dir, f"run_log_{model_name}.txt")

    def run_and_log(code_interpreter, AutoGen_prompt_system_message, AutoGen_prompt_concatenate,
                    Encourage_code_execution_interpreter, all_code_with_COT, all_code_without_COT,
                    all_text, multi_turn_planning, multi_turn_planning_round_num,
                    method_9_all_text_all_code_summarizer, method_10_LLM_estimates_scores_first):
        # Execute a single configuration, then record it in the run log.
        run_boxlift(dataset_input_dir, save_input_dir, model_name, AutoGen_prompt_system_message,
                    AutoGen_prompt_concatenate, code_interpreter, Encourage_code_execution_interpreter,
                    all_code_with_COT,
                    all_code_without_COT, all_text, multi_turn_planning, multi_turn_planning_round_num,
                    method_9_all_text_all_code_summarizer, method_10_LLM_estimates_scores_first)

        run_info = f"Completed run: {model_name}, round_num={multi_turn_planning_round_num}, " \
                   f"{code_interpreter}, {AutoGen_prompt_system_message}, {AutoGen_prompt_concatenate}, " \
                   f"{Encourage_code_execution_interpreter}, {all_code_with_COT}, {all_code_without_COT}, {all_text}, {multi_turn_planning}, multi_turn_planning_round_num: {multi_turn_planning_round_num}" \
                   f"\nmethod_9_{method_9_all_text_all_code_summarizer}, method_10_{method_10_LLM_estimates_scores_first}\n"
        log_run_info(log_file, run_info)

    # Column order of every configuration tuple:
    #   code_interpreter, AutoGen system message, AutoGen concatenate,
    #   encourage interpreter, all_code_with_COT, all_code_without_COT,
    #   all_text, multi_turn_planning, round_num, method_9, method_10
    for round_num in [1]:
        multi_turn_configs = [
            (False, False, False, False, False, False, False, True, round_num, False, False),
            (False, True, False, False, False, False, False, True, round_num, False, False),
            (False, False, True, False, False, False, False, True, round_num, False, False),
            (False, False, False, False, False, False, True, True, round_num, False, False),
            (False, False, False, False, True, False, False, True, round_num, False, False),
            (False, False, False, False, False, True, False, True, round_num, False, False),
            # method_9 reads the saved results of the all_text and
            # all_code_without_COT runs above, so it has to come after them.
            (False, False, False, False, False, False, False, True, round_num, True, False),
            (False, False, False, False, False, False, False, True, round_num, False, True),
        ]
        for config in multi_turn_configs:
            run_and_log(*config)

    # Single-round runs with the code interpreter enabled.
    interpreter_configs = [
        (True, False, False, False, False, False, False, False, 1, False, False),
        (True, False, False, True, False, False, False, False, 1, False, False),
    ]
    for config in interpreter_configs:
        run_and_log(*config)