###################################################
# This is currently the most optimized version of the script.
###################################################

import asyncio
import time
from datetime import datetime
import random
import subprocess
import random
import json
import autogen
import argparse
from color import slow_type_approximation, slow_type_target
from tool_agents_sp import DirectAgent, ReactAgent
import argparse
from datasets import load_dataset
import nltk

from autogen import AssistantAgent, UserProxyAgent, config_list_from_json
from autogen import (
    AssistantAgent,
    UserProxyAgent,
    config_list_from_json,
    GroupChat,
    GroupChatManager,
)
from autogen.agentchat.contrib.agent_builder import AgentBuilder

from util import Logger, cancel
import config
import tiktoken

# Tokenizer used at the end of the run to count the tokens of every generated
# response collected in config.TOTAL_TOKEN_GENERATION.
encoding = tiktoken.get_encoding("cl100k_base")


def concurrent_calls():
    """Return the number of tasks of the running event loop that are still pending.

    Must be called while an asyncio event loop is running.
    """
    return sum(
        1 for task in asyncio.all_tasks() if not (task.done() or task.cancelled())
    )


### functions that can be customized
def parse_user_input(user_input, s, t):
    """Translate a raw user reply into the action that should be kept.

    * ``"approximation answer"`` selects the approximation action ``s[0]``;
    * ``"target answer"`` or an empty reply selects the target action
      ``t[0][1][0]``;
    * any other text is used verbatim as the action.
    """
    if user_input == "approximation answer":
        return s[0]
    if user_input in ("target answer", ""):
        return t[0][1][0]
    return user_input


### functions that can be customized
def interaction_function(s, t, logger):
    """Report one approximation/target comparison and return the action to keep.

    Increments ``config.TOTAL_APPROXIMATION_CALLS`` (and
    ``config.TOTAL_CORRECT_APPROXIMATION_CALLS`` when ``judge_to_be_true``
    accepts the approximation action ``s[0]`` against the target action
    ``t[0][1][0]``), prints both actions, and by default returns the target
    action. If a ``KeyboardInterrupt`` arrives while this runs, the user is
    asked for a replacement action instead, which goes through
    ``parse_user_input``.

    ``logger`` is accepted for interface compatibility and is not used.
    """
    try:
        config.TOTAL_APPROXIMATION_CALLS += 1
        proposed = s[0]
        verified = t[0][1][0]
        agrees = judge_to_be_true(proposed, verified)
        if agrees:
            config.TOTAL_CORRECT_APPROXIMATION_CALLS += 1
        print("The approximation agent thinks it is: " + str(proposed))
        if agrees:
            print("The target agent thinks it is " + str(verified) + " as well.")
        else:
            print("But the target agent thinks it is " + str(verified) + " instead.")
        # No reply from the user: use the default choice (the target action).
        chosen = parse_user_input("", s, t)
    except KeyboardInterrupt:
        # The user interrupted: read their action from stdin.
        chosen = parse_user_input(input(""), s, t)
        if not judge_to_be_true(chosen, str(t[0][1][0])):
            print(
                f"Since you think the action should be {chosen} ... we will follow your suggestion :)"
            )
        print("-------------------")

    return chosen


### functions that can be customized
# def judge_to_be_true(s, t):
#     if s == t:
#         return True
#     else:
#         return False


def judge_to_be_true(s, t, threshold=0.5):
    """Decide whether an approximation action should count as the target action.

    Actions look like ``Name[argument]``. Each action is split into:

    * a function name: the text before the first ``[``;
    * an argument span: from the first ``[`` up to, but not including, the
      first ``]``.

    The actions match when the names are identical and the normalized
    Levenshtein similarity ``1 - edit_distance / max_len`` of the two argument
    spans is greater than ``threshold``. ``nltk.edit_distance`` is applied to
    the two strings, so it compares them character by character.

    If an action cannot be parsed that way (missing ``[``/``]``, non-string
    input, both spans empty, ...), the result falls back to plain equality
    ``s == t``.

    Args:
        s: approximation action.
        t: target action.
        threshold: similarity that the argument spans must exceed (default 0.5).

    Returns:
        bool: whether the two actions are considered equivalent.
    """
    try:
        approximation_function_name = s.split("[")[0].strip()
        target_function_name = t.split("[")[0].strip()
        # Different names can never match. Equal strings always have equal
        # names, so returning early does not change the fallback result.
        if approximation_function_name != target_function_name:
            return False

        approximation_function_arg = s[s.index("[") : s.index("]")].strip()
        target_function_arg = t[t.index("[") : t.index("]")].strip()

        longest = max(len(approximation_function_arg), len(target_function_arg))
        distance = nltk.edit_distance(approximation_function_arg, target_function_arg)
        # Division by zero (both spans empty) falls through to the equality check.
        return 1 - distance / longest > threshold
    except Exception:
        # Catch only ordinary errors. KeyboardInterrupt must propagate because
        # interaction_function uses it as the user-override signal.
        return bool(s == t)


############# autogen code for speculative planning #############
def ordinal(n):
    """Return ``n`` with its English ordinal suffix, e.g. ``1 -> "1st"``, ``12 -> "12th"``."""
    # 11, 12 and 13 (and 111, 212, ...) always take "th".
    if 11 <= n % 100 <= 13:
        return f"{n}th"
    # Last digit 0 and 4-9 map to "th", 1/2/3 to "st"/"nd"/"rd".
    suffixes = ("th", "st", "nd", "rd", "th")
    return f"{n}{suffixes[min(n % 10, 4)]}"


def simulate_within_T_interaction(
    sas, tas, flatten_tas, printed_ids, logger, tar_assistant
):
    """Show the steps that have become ready after a target step finished.

    A step is ready when two conditions hold:

    * the sorted target entries form a contiguous run of ids ``0, 1, 2, ...``
      up to it;
    * an approximation step exists for it.

    Ready steps that have not been shown yet are passed to
    ``interaction_function`` in step order. Showing stops after the first step
    whose approximation the target disagrees with, or when the user keeps a
    ``terminate`` action.

    Args:
        sas: approximation steps, each ``[action, observation]``.
        tas: per-step target slots. A filled slot is ``[[step_id, payload]]``
            where ``payload[0]`` is the target action. Mutated in place when
            the user overrides an action.
        flatten_tas: the filled entries of ``tas``, sorted by step id.
        printed_ids: per-step lists of step ids already shown (mutated in place).
        logger: forwarded to ``interaction_function``.
        tar_assistant: target agent used to execute a user-supplied action.

    Returns:
        ``(printed_ids, tas)``: the same, possibly mutated, objects.
    """
    tas_ids = [entry[0] for entry in flatten_tas]
    flatten_printed_ids = [i for ids in printed_ids if ids for i in ids]

    # Length of the ready prefix: target ids are exactly 0, 1, 2, ... and an
    # approximation step exists for each of them.
    tas_length = 0
    while (
        tas_length < len(tas_ids)
        and tas_ids[tas_length] == tas_length
        and len(sas) > tas_length
    ):
        tas_length += 1
    if tas_length == 0:
        return printed_ids, tas

    # Once a wrong approximation has been shown, nothing after it may be shown.
    if any(
        not judge_to_be_true(sas[pid][0], tas[pid][0][1][0])
        for pid in flatten_printed_ids
    ):
        return printed_ids, tas

    # Sort because set iteration order is not guaranteed to follow step order.
    to_print_ids = sorted(set(tas_ids[:tas_length]) - set(flatten_printed_ids))
    for order_id, to_print_id in enumerate(to_print_ids):
        # Stop once the previously shown step turned out to be wrong.
        if order_id > 0:
            previous_id = to_print_ids[order_id - 1]
            if not judge_to_be_true(sas[previous_id][0], tas[previous_id][0][1][0]):
                break
        printed_ids[to_print_id].append(to_print_id)
        user_input = interaction_function(sas[to_print_id], tas[to_print_id], logger)
        if str(user_input) == str(tas[to_print_id][0][1][0]):
            if user_input.lower() == "terminate":
                return printed_ids, tas
            continue

        # The user chose a different action: execute it on the target agent and
        # store it in this step's slot. The copied entry is [step_id, payload],
        # i.e. tas[i][0], not the int step id tas[i][0][0].
        change_one_tas_position = list(tas[to_print_id][0])
        tar_assistant.execute(user_input)
        # NOTE(review): payload[1] now holds the observation, while T_generate
        # reads payload[1] as the "finished" flag; confirm this is intended.
        change_one_tas_position[1] = [
            user_input,
            tar_assistant.current_observation,
        ]
        tas[to_print_id][0] = change_one_tas_position
        return printed_ids, tas

    return printed_ids, tas


def simulate_leftover_interaction(
    sas, tas, flatten_tas, printed_ids, logger, tar_assistant
):
    """Show the steps that were still unshown when the target tasks stopped.

    This covers the case where the approximation agent was slower than the
    target agent. Steps after the already shown ones are passed to
    ``interaction_function`` in order. Showing stops at three points:

    * after the first step the target disagrees with, because later
      approximation steps were built on it;
    * after a ``terminate`` action;
    * at the first step without a target result.

    Args:
        sas: approximation steps, each ``[action, observation]``.
        tas: per-step target slots; a filled slot is ``[[step_id, payload]]``
            where ``payload[0]`` is the target action. Mutated in place when
            the user overrides an action.
        flatten_tas: the filled entries of ``tas``, sorted by step id.
        printed_ids: per-step lists of step ids already shown (mutated in place).
        logger: forwarded to ``interaction_function``.
        tar_assistant: target agent used to execute a user-supplied action.

    Returns:
        ``(printed_ids, tas)``: the same, possibly mutated, objects.
    """
    flatten_printed_ids = [i for ids in printed_ids if ids for i in ids]
    printed_count = len(flatten_printed_ids)

    # Nothing to do unless both agents produced steps beyond the shown ones.
    if len(sas) <= printed_count or len(flatten_tas) <= printed_count:
        return printed_ids, tas

    # Once a wrong approximation has been shown, nothing after it may be shown.
    if any(
        not judge_to_be_true(sas[pid][0], tas[pid][0][1][0])
        for pid in flatten_printed_ids
    ):
        return printed_ids, tas

    for print_id in range(printed_count, min(len(sas), len(tas))):
        if not tas[print_id]:
            break
        user_input = interaction_function(sas[print_id], tas[print_id], logger)
        printed_ids[print_id].append(print_id)
        if str(user_input) != str(tas[print_id][0][1][0]):
            # The user chose a different action: execute it and store it in
            # this step's slot. Copy the [step_id, payload] entry (tas[i][0]).
            change_one_tas_position = list(tas[print_id][0])
            tar_assistant.execute(user_input)
            change_one_tas_position[1] = [
                user_input,
                tar_assistant.current_observation,
            ]
            tas[print_id][0] = change_one_tas_position
        # These checks run for every shown step, so later steps built on a
        # wrong approximation are not shown.
        if not judge_to_be_true(sas[print_id][0], tas[print_id][0][1][0]):
            break
        if user_input.lower() == "terminate":
            break

    return printed_ids, tas


async def A_generate(assistant, total_step_number, logger):
    """Generate and execute one approximation step.

    Calls ``assistant.direct_act()`` for the next action and records it once in
    ``config.TOTAL_TOKEN_GENERATION`` for the final token count. If the agent
    did not finish, the action is executed on ``assistant``.

    Args:
        assistant: approximation agent exposing ``direct_act``, ``execute`` and
            ``current_observation``.
        total_step_number: absolute step index, used only for the timing print.
        logger: accepted for interface compatibility; not used.

    Returns:
        ``(action, observation)``. The observation is ``"terminate"`` when the
        agent reports it is finished.
    """
    start = time.time()

    # The approximation agent acts directly, without a separate thought.
    action, finished = await assistant.direct_act()
    # Record the generated text exactly once so the token total is not inflated.
    config.TOTAL_TOKEN_GENERATION.append(action)
    end = time.time()
    print(f"approximation: step time for {total_step_number}: " + str(end - start))

    if finished:
        return action, "terminate"
    assistant.execute(action)
    return action, assistant.current_observation


async def T_generate(
    assistant,
    total_step_number,
    tas,
    sas,
    previous_steps,
    target_tasks,
    printed_ids,
    current_step,
    logger,
):
    """Run one target step and reconcile it with the approximation steps.

    The target agent builds its scratchpad from ``previous_steps + sas``: the
    steps committed earlier plus the approximation steps produced so far. It
    then produces a thought/action for step ``total_step_number``. The result
    is stored in ``tas[total_step_number - current_step]`` as
    ``[[in_step_number, (action, finished)]]``. ``simulate_within_T_interaction``
    then shows any steps that have become ready.

    Args:
        assistant: target agent; used through ``create_scratchpad``,
            ``think_and_act`` and (via the interaction helper) ``execute``.
        total_step_number: absolute index of the step being generated.
        tas: shared per-step target slots (mutated in place).
        sas: shared list of ``[action, observation]`` approximation steps.
        previous_steps: steps committed in earlier breaking points.
        target_tasks: target tasks of this breaking point; these are excluded
            when cancelling approximation tasks.
        printed_ids: shared per-step lists of already shown step ids.
        current_step: absolute index of the first step of this breaking point.
        logger: forwarded to the interaction helper.

    Returns:
        ``(tas, printed_ids)``.

    Raises:
        Exception("terminate the whole process!"): the target ids form a
            contiguous run from 0 and one of the steps paired with an
            approximation step has a truthy ``finished`` flag.
        Exception("approximation error happen in step ... the target id is {id}"):
            a stored target action differs (exact string comparison) from the
            approximation action of the same step. Pending approximation tasks
            are cancelled first. The caller reads the id from the end of this
            message, so it must stay last.
    """
    start = time.time()

    # add approximation result to the prompt
    scratchpad = ""
    scratchpad = assistant.create_scratchpad(scratchpad, previous_steps + sas)
    # call agent to generate the response
    thought, action, finished = await assistant.think_and_act(
        scratchpad, total_step_number
    )
    config.TOTAL_TOKEN_GENERATION.append(thought)
    config.TOTAL_TOKEN_GENERATION.append(action)
    # NOTE(review): presumably think_and_act returns "cancelled" when it was
    # interrupted; the step is then left unfilled. Confirm in the agent class.
    if action == "cancelled":
        return tas, printed_ids

    # Slot index relative to the first step of this breaking point.
    in_step_number = total_step_number - current_step
    tas[in_step_number] = [[in_step_number, (action, finished)]]

    # Show the steps that became ready now that this target result exists.
    # Results of other steps may still be unknown, so nothing is cancelled here.
    flatten_tas = []
    for t in tas:
        if t:
            flatten_tas += t
    flatten_tas = sorted(flatten_tas, key=lambda x: x[0], reverse=False)
    printed_ids, tas = simulate_within_T_interaction(
        sas, tas, flatten_tas, printed_ids, logger, assistant
    )

    # if the target result is terminate, we break the loop
    # (only checked when the target ids form a contiguous run from 0, and only
    # for steps that have an approximation step to pair with)
    flatten_ids = [t[0] for t in flatten_tas]
    if flatten_ids == list(range(len(flatten_ids))):
        for step_number, (s, t) in enumerate(zip(sas, flatten_tas)):
            if t[0] == step_number and t[1][1]:  # terminate
                # throw Exception here to halt everything
                end = time.time()
                print(f"target: step time for {total_step_number}: " + str(end - start))
                raise Exception("terminate the whole process!")

    # If an approximation step is wrong, cancel the pending approximation tasks
    # and raise so that the caller can cancel the target tasks after it.
    # Re-flatten because simulate_within_T_interaction may have changed tas.
    flatten_tas = []
    for t in tas:
        if t:
            flatten_tas += t
    flatten_tas = sorted(flatten_tas, key=lambda x: x[0], reverse=False)
    for ta in flatten_tas:
        if len(sas) > ta[0]:
            # if approximation action != target action
            # NOTE(review): this is an exact string comparison, while the caller
            # and the interaction helpers use judge_to_be_true; confirm intended.
            if not sas[ta[0]][0] == ta[1][0]:
                # print('raise exception here')
                end = time.time()
                # Approximation tasks are found by their task-name prefix.
                pending_approximation_tasks = [
                    t
                    for t in asyncio.all_tasks()
                    if not t.cancelled()
                    and not t.done()
                    and t not in target_tasks
                    and t.get_name().startswith("approximation")
                ]
                for pending_approximation_task in pending_approximation_tasks:
                    print("cancel approximation process")
                    await cancel(pending_approximation_task)
                print(f"target: step time for {total_step_number}: " + str(end - start))
                raise Exception(
                    f"approximation error happen in step {total_step_number} for current step {current_step}, the target id is {ta[0]}"
                )

    end = time.time()
    print(f"target: step time for {total_step_number}: " + str(end - start))
    return tas, printed_ids


def _flatten_sorted_tas(tas):
    """Flatten the per-step target slots into one list sorted by step id.

    ``tas[i]`` is either empty (no target result yet) or a list of
    ``[step_id, payload]`` entries; empty slots are skipped.
    """
    flatten_tas = []
    for t in tas:
        if t:
            flatten_tas += t
    return sorted(flatten_tas, key=lambda x: x[0])


def _parse_mistaken_process_id(error):
    """Return the step id at the end of a T_generate "approximation error" message.

    T_generate ends that message with ``"the target id is {id}"``. The text
    after the last occurrence of that marker is parsed, so multi-digit ids work.
    Returns ``None`` when the message carries no such id, which presumably
    means the exception was not raised by T_generate.
    """
    marker = "the target id is "
    message = str(error)
    if marker not in message:
        return None
    try:
        return int(message.rsplit(marker, 1)[1])
    except ValueError:
        return None


def _finalize_speculative_steps(
    sas, tas, printed_ids, previous_steps, app_assistant, tar_assistant, logger
):
    """Show the leftover steps, then cut ``sas`` at the first wrong step.

    At the first step whose approximation the target disagrees with (according
    to ``judge_to_be_true``), the following happens:

    * the target action is executed on ``tar_assistant``;
    * that step becomes ``[target_action, observation]`` and later steps are
      dropped;
    * the approximation agent's scratchpad is updated at that position.

    Returns the resulting list of ``[action, observation]`` steps.
    """
    flatten_tas = _flatten_sorted_tas(tas)
    printed_ids, tas = simulate_leftover_interaction(
        sas, tas, flatten_tas, printed_ids, logger, tar_assistant
    )

    # Re-flatten: the leftover interaction may have replaced target actions.
    flatten_tas = _flatten_sorted_tas(tas)
    for step_number, (s, t) in enumerate(zip(sas, flatten_tas)):
        if t[0] == step_number and not judge_to_be_true(s[0], t[1][0]):
            tar_assistant.execute(t[1][0])
            to_replace_action = [
                flatten_tas[step_number][1][0],
                tar_assistant.current_observation,
            ]
            sas = sas[:step_number] + [to_replace_action]
            app_assistant.update_scratchpad(
                sas[-1][0], sas[-1][1], len(previous_steps) + step_number
            )
            break
    return sas


async def onebreakingpoint_speculative_planning(
    args, app_assistant, tar_assistant, previous_steps, current_step, logger
):
    """Run one breaking point of speculative planning.

    Up to ``args.k`` steps run as follows:

    * the approximation agent generates steps one after another (A_generate);
    * for every step a target task (T_generate) is started in the background
      to verify it;
    * approximation stops early when a step terminates or when a finished
      target result disagrees with an approximation step; target tasks after
      the wrong step are then cancelled;
    * the remaining target tasks are awaited, leftover steps are shown to the
      user, and the approximation steps are truncated at the first wrong one,
      which is replaced by the target action.

    Args:
        args: parsed CLI args; ``args.k`` is the maximum number of steps.
        app_assistant: approximation agent.
        tar_assistant: target agent.
        previous_steps: steps committed in earlier breaking points.
        current_step: absolute index of the first step of this breaking point.
        logger: forwarded to the helpers.

    Returns:
        List of ``[action, observation]`` steps accepted in this breaking point.
    """
    sas = []  # approximation steps: [action, observation]
    tas = []  # per-step target slots, filled by T_generate
    target_tasks = []
    printed_ids = []
    for i in range(args.k):
        break_out_approximation = False

        tas.append([])
        printed_ids.append([])
        # Approximation step i; returns (action, observation).
        approximation = asyncio.create_task(
            A_generate(app_assistant, current_step + i, logger=logger),
            # T_generate finds approximation tasks by this name prefix.
            name=f"approximation_{current_step+i}",
        )
        # Target step i; verifies the approximation steps in the background.
        target = asyncio.create_task(
            T_generate(
                tar_assistant,
                current_step + i,
                tas,
                sas,
                previous_steps,
                target_tasks=target_tasks,
                printed_ids=printed_ids,
                current_step=current_step,
                logger=logger,
            )
        )
        target_tasks.append(target)

        concurrent_api_calls = concurrent_calls()
        if concurrent_api_calls >= config.MAX_CONCURRENT_CALLS:
            config.MAX_CONCURRENT_CALLS = concurrent_api_calls

        # sa stays None if T_generate cancelled this approximation task; the
        # mismatch check below then ends the loop.
        sa = None
        try:
            action, observation = await approximation
            sa = [action, observation]
            sas.append(sa)
        except asyncio.CancelledError:
            pass

        # Once the approximation terminates there is nothing left to speculate.
        # NOTE(review): A_generate returns the string "terminate" as observation
        # when finished, so `sa[1] == True` may never hold; confirm.
        if sa is not None and (sa[1] == True or sa[0].lower() == "terminate"):
            break_out_approximation = True

        # Stop approximating at the first finished target result that
        # disagrees, and cancel the target tasks for the steps after it.
        for t in _flatten_sorted_tas(tas):
            if len(sas) > t[0] and not judge_to_be_true(sas[t[0]][0], t[1][0]):
                break_out_approximation = True
                for process_id, one_task in enumerate(target_tasks):
                    if (
                        process_id > t[0]
                        and not one_task.cancelled()
                        and not one_task.done()
                    ):
                        await cancel(one_task)
                break

        if break_out_approximation:
            break

    # Collect the target results. Finished tasks are kept because their
    # exceptions only surface when they are gathered.
    pending_tasks = [t for t in target_tasks if not t.cancelled()]
    while pending_tasks:
        if all(task.done() for task in pending_tasks):
            # Everything finished. return_exceptions=True retrieves every
            # result/exception (no "never retrieved" warnings) without raising.
            await asyncio.gather(*pending_tasks, return_exceptions=True)
            break
        try:
            await asyncio.gather(*pending_tasks, return_exceptions=False)
            break
        except Exception as e:
            if str(e) == "terminate the whole process!":
                # The target terminated: nothing still running is needed.
                for pending_task in pending_tasks:
                    await cancel(pending_task)
                break
            if all(task.done() for task in pending_tasks):
                await asyncio.gather(*pending_tasks, return_exceptions=True)
                break
            mistaken_process_id = _parse_mistaken_process_id(e)
            if mistaken_process_id is None:
                # Not a T_generate mismatch report: surface the real error.
                raise
            # Cancel t_j for j > i when step i was wrong; target task j
            # handles step j, so index target_tasks by step id.
            for process_id, one_task in enumerate(target_tasks):
                if (
                    process_id > mistaken_process_id
                    and not one_task.cancelled()
                    and not one_task.done()
                ):
                    await cancel(one_task)
            # NOTE(review): this drops the task of the mistaken *step*, which is
            # not necessarily the task that raised; confirm intended.
            pending_tasks = [
                t
                for process_id, t in enumerate(target_tasks)
                if not t.cancelled() and process_id != mistaken_process_id
            ]

    return _finalize_speculative_steps(
        sas, tas, printed_ids, previous_steps, app_assistant, tar_assistant, logger
    )


async def speculative_planning(args, app_assistant, tar_assistant, logger):
    """Run breaking points until one ends with a terminating step.

    Each round appends the steps accepted by
    ``onebreakingpoint_speculative_planning`` to ``steps``. The same list
    object, extended in place, is passed back in as the previous steps. The
    run ends when the last accepted action is ``terminate``
    (case-insensitive) or its observation equals ``True``. The wall-clock
    duration is printed.

    Returns:
        The list of all accepted ``[action, observation]`` steps.
    """
    started_at = datetime.now()
    steps = []
    done = False
    while not done:
        round_steps = await onebreakingpoint_speculative_planning(
            args, app_assistant, tar_assistant, steps, len(steps), logger
        )
        steps.extend(round_steps)

        last_step = round_steps[-1]
        done = last_step[0].lower() == "terminate" or last_step[1] == True

    finished_at = datetime.now()
    print(f"{finished_at} - {started_at} = {finished_at - started_at}")

    return steps


if __name__ == "__main__":
    ## global variables shared with the rest of the project through `config`
    config.MAX_CONCURRENT_CALLS = 0
    config.TOTAL_APPROXIMATION_CALLS = 0
    config.TOTAL_CORRECT_APPROXIMATION_CALLS = 0
    config.TOTAL_TOKEN_GENERATION = []
    random.seed(2)
    tools_list = [
        "notebook",
        "flights",
        "attractions",
        "accommodations",
        "restaurants",
        "googleDistanceMatrix",
        "planner",
        "cities",
    ]

    parser = argparse.ArgumentParser()
    # Only these two splits are loaded below; reject anything else up front.
    parser.add_argument(
        "--set_type", type=str, default="validation", choices=["validation", "test"]
    )
    parser.add_argument("--model_name", type=str, default="gpt-4-turbo")
    parser.add_argument("--output_dir", type=str, default="./")
    parser.add_argument("--task_id", type=int, default=0)
    parser.add_argument(
        "--k",
        type=int,
        default=4,
        help="number of approximation steps to generate everytime",
    )

    args = parser.parse_args()

    # For both supported splits the dataset config name equals the split name.
    query_data_list = load_dataset("osunlp/TravelPlanner", args.set_type)[
        args.set_type
    ]

    # 0-based index into the split. Negative values count from the end, and an
    # out-of-range value raises IndexError, exactly like list indexing.
    task_id = range(len(query_data_list))[args.task_id]
    # select query
    query = query_data_list[task_id]["query"]

    logger = Logger(
        f"simulation_datapoint{task_id}_k{args.k}.log",
        on=False,
    )

    # setup target agent
    target_agent = ReactAgent(
        None,
        tools=tools_list,
        max_steps=30,
        react_llm_name=args.model_name,
        planner_llm_name=args.model_name,
    )
    target_agent.query = query

    # setup approximation agent
    approximation_agent = DirectAgent(
        None,
        tools=tools_list,
        max_steps=30,
        react_llm_name=args.model_name,
        planner_llm_name=args.model_name,
    )
    approximation_agent.query = query

    # run the speculative planning
    steps = asyncio.run(
        speculative_planning(args, approximation_agent, target_agent, logger)
    )

    # record the metrics
    logger.log("final result for the speculative planning " + str([s[0] for s in steps]))
    logger.log(
        "max concurrent calls: " + str(config.MAX_CONCURRENT_CALLS - 1)
    )  # the task running speculative_planning is counted as one more call
    if config.TOTAL_APPROXIMATION_CALLS:
        approximation_accuracy = (
            config.TOTAL_CORRECT_APPROXIMATION_CALLS / config.TOTAL_APPROXIMATION_CALLS
        )
    else:
        # No approximation step was ever compared with a target step.
        approximation_accuracy = "n/a"
    logger.log("accuracy of approximation agent: " + str(approximation_accuracy))
    token_number = sum(
        len(encoding.encode(response)) for response in config.TOTAL_TOKEN_GENERATION
    )
    logger.log("total token generation: " + str(token_number))
    logger.log("total number of steps: " + str(len(steps)))
