import asyncio
import time
from datetime import datetime
import random
import subprocess
import random
import json 
import autogen 
import argparse 
from color import slow_type_approximation, slow_type_target

from autogen import AssistantAgent, UserProxyAgent, config_list_from_json
from autogen import AssistantAgent, UserProxyAgent, config_list_from_json, GroupChat, GroupChatManager
from autogen.agentchat.contrib.agent_builder import AgentBuilder

from utils import Logger, cancel, register_async_handler
import config
import tiktoken

# Tokenizer used once at the end of a run (in the __main__ block) to count the
# tokens of every model response collected in config.TOTAL_TOKEN_GENERATION.
encoding = tiktoken.get_encoding("cl100k_base")

def concurrent_calls():
    """Return how many tasks of the running event loop are still pending.

    Must be called from inside a running event loop (``asyncio.all_tasks``
    raises ``RuntimeError`` otherwise). The calling task itself is counted.
    """
    still_running = 0
    for task in asyncio.all_tasks():
        if task.done() or task.cancelled():
            continue
        still_running += 1
    return still_running

### functions that can be customized
def parse_user_input(user_input, s, t):
    """Translate the user's reply into the action to use for the current step.

    Args:
        user_input: Text typed by the user ('' when nothing was typed).
        s: The approximation agent's action for this step.
        t: The target agent's result list for this step; ``t[0][1]`` is its action.

    Returns:
        ``s`` for 'approximation answer'; the target action for
        'target answer' or an empty reply; otherwise ``user_input`` verbatim.
    """
    if user_input == 'approximation answer':
        return s
    if user_input in ('target answer', ''):
        return t[0][1]
    return user_input

### functions that can be customized
def interaction_function(sas, tas, to_print_id, logger, target_logger, prev_steps, target_tasks):
    """Report the target agent's result for one step and return the action to keep.

    Logs the target result for step ``to_print_id`` and compares it with the
    approximation. On agreement, if the approximation agent has already
    produced the following step, that step is announced, ``config.HIL_INTERACTION``
    is set to it and ``register_async_handler`` is called (presumably to let
    the user interrupt the pending target tasks - verify in ``utils``).
    A ``KeyboardInterrupt`` raised while this runs makes the function read a
    replacement action from stdin instead.

    Args:
        sas: Approximation actions, indexed by in-breakpoint step id.
        tas: Per-step target results; ``tas[i][0][1]`` is the target action for step ``i``.
        to_print_id: In-breakpoint index of the step being reported.
        logger: Logger for the interaction transcript.
        target_logger: Logger that records every target result.
        prev_steps: Steps decided in earlier breakpoints (offsets the printed step numbers).
        target_tasks: Target tasks, forwarded to ``register_async_handler``.

    Returns:
        The action for this step: the target action, or the user's override.
    """
    s = sas[to_print_id]
    t = tas[to_print_id]
    target_logger.log(f'Target: Step {t[0][0] + len(prev_steps)+1} - {t[0][1]}')
    try:
        config.TOTAL_APPROXIMATION_CALLS += 1
        if judge_to_be_true(s, t[0][1]):
            config.TOTAL_CORRECT_APPROXIMATION_CALLS += 1
            logger.log(f'The target agent thinks step {len(prev_steps) + to_print_id+1} should be '+ str(t[0][1]) + ', which agrees with the approximation agent.')
            # Announce the next step only if the approximation agent has produced it.
            # An explicit check replaces a bare ``except: pass``. That bare except also
            # swallowed any KeyboardInterrupt raised here, so the interrupt never reached
            # the handler below.
            next_id = to_print_id + 1
            if next_id < len(sas):
                logger.log(f'The approximation agent thinks step {len(prev_steps) + to_print_id+2} should be ' + sas[next_id])
                config.HIL_INTERACTION = next_id
                register_async_handler(target_tasks=target_tasks)
        else:
            logger.log(f'The target agent thinks step {len(prev_steps) + to_print_id+1} should be '+ str(t[0][1]) + ', correcting what the approximation agent thinks which is ' + str(s) + '.')
        # No user input in the normal path: an empty reply selects the target action.
        user_input = parse_user_input('', s, t)
    except KeyboardInterrupt:
        user_input = input('')

        user_input = parse_user_input(user_input, s, t)
        if not judge_to_be_true(user_input, str(t[0][1])):
            logger.log(f'Since you think the action should be {user_input} ... we will follow your suggestion :)')
        logger.log('-------------------')

    return user_input

### functions that can be customized
def judge_to_be_true(s, t):
    """Decide whether approximation action ``s`` counts as agreeing with target action ``t``.

    The default criterion is plain equality. This function is meant to be
    customized (e.g. with a looser matching rule).
    """
    return bool(s == t)

############# autogen code for speculative planning #############
def load_data(args, path="data/openagi_task_description.txt"):
    """Load the OpenAGI task descriptions, one per line.

    Args:
        args: Parsed command-line arguments. Currently unused; kept for
            call-site compatibility.
        path: Text file with one task description per line. Defaults to the
            file this script has always read.

    Returns:
        The stripped lines, so ``load_data(args)[task_id]`` is the description
        of ``task_id``. A trailing newline yields a final ``''`` entry.
    """
    # Explicit encoding: the platform default is locale-dependent, so the same
    # file could decode differently (or fail to decode) on another machine.
    with open(path, "r", encoding="utf-8") as f:
        data = f.read()
    return [line.strip() for line in data.split("\n")]

def parse_response(response):
    """Extract the chosen tool name from a model response.

    Two formats are recognised, in this order:

    * XML-style tags: the text between the last opening ``<x`` and the last
      ``</`` is taken, e.g. ``<tool>Image Captioning</tool>`` gives
      ``'Image Captioning'``. If the last opening tag comes after the last
      closing one the result is ``''``. A response with a closing tag but no
      opening tag raises ``IndexError`` (the callers retry on any exception).
    * Markdown bold: the text between the first two ``**`` markers.

    Leftover ``<`` and ``>`` characters are removed. Returns ``''`` when
    neither format is present.
    """
    angle_brackets = {ord('<'): None, ord('>'): None}
    if '<' in response and '>' in response and '</' in response:
        openings = [pos for pos, (ch, nxt) in enumerate(zip(response, response[1:])) if ch == '<' and nxt != '/']
        closings = [pos for pos in range(len(response) - 1) if response.startswith('</', pos)]
        snippet = response[openings[-1]:closings[-1]]
        return snippet.replace('<tool>', '').translate(angle_brackets)
    if response.count('**') >= 2:
        _, _, after_first_marker = response.partition('**')
        bold_text, _, _ = after_first_marker.partition('**')
        return bold_text.translate(angle_brackets)
    return ''

def ordinal(n):
    """Return ``n`` with its English ordinal suffix, e.g. ``1 -> '1st'``, ``12 -> '12th'``."""
    if 11 <= n % 100 <= 13:
        # 11th, 12th and 13th never take st/nd/rd.
        return f'{n}th'
    suffixes = ('th', 'st', 'nd', 'rd', 'th')
    return f'{n}{suffixes[min(n % 10, 4)]}'

def simulate_within_T_interaction(sas, tas, flatten_tas, printed_ids, logger, target_logger, prev_steps, target_tasks):
    """Report target results as soon as a contiguous prefix of steps is available.

    Finds the longest prefix of steps ``0..L-1`` that has a target result and
    an approximation for every step. If no already-reported step was wrong,
    it reports the not-yet-reported steps of that prefix in ascending order.
    Reporting stops after a step whose approximation disagreed with the target.
    It also returns early on 'terminate' or when the user overrides an action.

    Args:
        sas: Approximation actions, indexed by in-breakpoint step id.
        tas: Per-step target results; ``tas[i]`` is a list whose first entry is
            ``(i, action)``. Mutated in place when the user overrides an action.
        flatten_tas: All entries of ``tas`` flattened and sorted by step id.
        printed_ids: Per-step lists of already-reported ids. Mutated in place.
        logger, target_logger, prev_steps, target_tasks: Forwarded to
            ``interaction_function``.

    Returns:
        ``(printed_ids, tas)`` after the updates.
    """
    tas_ids = [entry[0] for entry in flatten_tas]
    flatten_printed_ids = [pid for ids in printed_ids if ids for pid in ids]

    # Longest prefix whose target ids are exactly 0..L-1 and that has an
    # approximation for every step.
    tas_length = 0
    while (tas_length < len(tas_ids)
           and tas_ids[tas_length] == tas_length
           and len(sas) > tas_length):
        tas_length += 1
    if tas_length == 0:
        return printed_ids, tas

    # Once a reported step turned out wrong, nothing after it is reported.
    if any(not judge_to_be_true(sas[pid], tas[pid][0][1]) for pid in flatten_printed_ids):
        return printed_ids, tas

    # Sorted explicitly: set iteration order is not guaranteed to be ascending,
    # but steps must be reported in order and the loop inspects the previous step.
    to_print_ids = sorted(set(tas_ids[:tas_length]) - set(flatten_printed_ids))
    for order_id, to_print_id in enumerate(to_print_ids):
        if order_id > 0:
            prev_id = to_print_ids[order_id - 1]
            if not judge_to_be_true(sas[prev_id], tas[prev_id][0][1]):
                break
        printed_ids[to_print_id].append(to_print_id)
        user_input = interaction_function(sas, tas, to_print_id, logger, target_logger, prev_steps, target_tasks)
        if str(user_input) == str(tas[to_print_id][0][1]):
            if user_input.lower() == 'terminate':
                return printed_ids, tas
            continue
        # The user overrode the target action: record it and stop reporting.
        overridden = list(tas[to_print_id][0])
        overridden[1] = user_input
        tas[to_print_id][0] = overridden
        return printed_ids, tas

    return printed_ids, tas

def simulate_leftover_interaction(sas, tas, flatten_tas, printed_ids, logger, target_logger, prev_steps, target_tasks):
    """Report the steps that were not shown while the target tasks were running.

    Runs only if both the approximation and the target agent have results
    beyond the reported steps, and every reported step was correct. Reporting
    then resumes right after the reported steps and stops at:
    - the first step without a target result;
    - the first step whose approximation disagrees with the (possibly
      user-overridden) target action;
    - a 'terminate' action.

    Args:
        sas: Approximation actions, indexed by in-breakpoint step id.
        tas: Per-step target results; ``tas[i][0][1]`` is the target action for
            step ``i``. Mutated in place when the user overrides an action.
        flatten_tas: All entries of ``tas`` flattened and sorted by step id.
        printed_ids: Per-step lists of already-reported ids. Mutated in place.
        logger, target_logger, prev_steps, target_tasks: Forwarded to
            ``interaction_function``.

    Returns:
        ``(printed_ids, tas)`` after the updates.
    """
    flatten_printed_ids = [pid for ids in printed_ids if ids for pid in ids]
    n_printed = len(flatten_printed_ids)

    if not (len(sas) > n_printed and len(flatten_tas) > n_printed):
        return printed_ids, tas
    if any(not judge_to_be_true(sas[pid], tas[pid][0][1]) for pid in flatten_printed_ids):
        return printed_ids, tas

    # Assumes the reported ids form the contiguous prefix 0..n_printed-1.
    for print_id in range(n_printed, min(len(sas), len(tas))):
        if not tas[print_id]:
            break  # no target result for this step (e.g. its task was cancelled)
        user_input = interaction_function(sas, tas, print_id, logger, target_logger, prev_steps, target_tasks)
        printed_ids[print_id].append(print_id)
        if str(user_input) != str(tas[print_id][0][1]):
            # Record the user's override in place of the target action.
            overridden = list(tas[print_id][0])
            overridden[1] = user_input
            tas[print_id][0] = overridden
        # Stop at the first disagreement (later approximations build on a wrong
        # step) or at 'terminate'. A ``continue`` previously skipped these checks
        # whenever the user accepted the target action. Steps past a wrong
        # approximation or past 'terminate' were then still reported and counted.
        if not judge_to_be_true(sas[print_id], tas[print_id][0][1]):
            break
        if user_input.lower() == 'terminate':
            break

    return printed_ids, tas

async def A_generate(assistant, prompt, total_step_number, logger, approximation_logger):
    """Ask the approximation agent for the action of step ``total_step_number + 1``.

    Any error (failed API call, unparsable response) triggers a retry. The
    function gives up and returns ``''`` once the attempt counter reaches 10,
    i.e. after 9 tries. Cancellation is propagated so that a speculative
    branch known to be wrong actually stops.

    Args:
        assistant: The approximation agent; ``a_generate_reply`` is awaited.
        prompt: Task prompt plus the action trajectory so far.
        total_step_number: Number of steps already in the trajectory.
        logger: Unused; kept for call-site compatibility.
        approximation_logger: Records every approximation result.

    Returns:
        The parsed action, or ``''`` if every attempt failed.
    """
    start = time.time()
    prompt += f"\n\nDirectly tell me what **the ONE NEXT action step** based on the current action trajectory should be. (Remember to use xml tag <tool> and </tool> for formatting.)\nWhat should be the action in Step {total_step_number+1}?\n\nStep {total_step_number+1}:"
    n = 0
    while True:
        try:
            n += 1
            if n >= 10:
                result = ''
                return result
            response = await assistant.a_generate_reply(messages=[{'content':prompt, 'role':'user'}])
            config.TOTAL_TOKEN_GENERATION.append(response)
            result = parse_response(response)
            end = time.time()
            approximation_logger.log(f'Approximation: Step {total_step_number+1} - {result}')
            print(f'approximation: step time for {total_step_number}: '+ str(end-start))
            return result
        except asyncio.CancelledError:
            # The former bare ``except:`` swallowed cancellation and issued another
            # request, so cancelled approximations kept running.
            raise
        except Exception:
            continue

async def ReAct(assistant, prompt, total_step_number):
    """Generate a ReAct-style step: first a thought, then an action conditioned on it.

    Makes two calls to ``assistant``. Only the intermediate thought is
    recorded in ``config.TOTAL_TOKEN_GENERATION`` here. The callers
    (``T_generate`` and ``T_generate_without_handling_signal``) append every
    response they receive, so recording the action here as well counted it twice.

    Returns:
        The raw action response; callers parse it with ``parse_response``.
    """
    prompt += f"\n\nCarefully think about **the ONE NEXT action step** based on the current action trajectory."
    prompt += f"\nGenerate thought only.\nThought {total_step_number}:"
    thought = await assistant.a_generate_reply(messages=[{'content':prompt, 'role':'user'}])
    config.TOTAL_TOKEN_GENERATION.append(thought)
    prompt += " " + thought
    prompt += f"\nGenerate Action only based on thoughts. Remember to use xml tag <tool> and </tool> for formatting. \nAction {total_step_number}:"
    # generate action based on thought
    return await assistant.a_generate_reply(messages=[{'content':prompt, 'role':'user'}])

async def postprocess_T_generation(result, total_step_number, tas, sas, target_tasks, printed_ids=None, current_step=0, logger=None, target_logger=None, prev_steps=None):
    """Record a target-agent result for one step and react to it.

    Appends ``(in_step_id, result)`` to ``tas``, lets the user see any newly
    verified steps via ``simulate_within_T_interaction``, then checks for
    termination and for disagreement with the approximation agent.

    Args:
        result: Parsed target action for step ``total_step_number + 1``.
        total_step_number: Global step index of this result.
        tas: Per-step target results (mutated in place).
        sas: Approximation actions for this breakpoint.
        target_tasks: Target tasks; excluded when cancelling approximations.
        printed_ids: Per-step lists of reported ids (mutated). Defaults to ``[[]]``.
        current_step: Global index of the first step of this breakpoint.
        logger, target_logger: Loggers forwarded to the interaction helpers.
        prev_steps: Steps decided in earlier breakpoints. Defaults to ``[]``.

    Returns:
        ``(tas, printed_ids)``.

    Raises:
        Exception('terminate the whole process!'): if the target ids form a
            contiguous prefix and one of the steps that also has an
            approximation is 'terminate'. Callers match this message verbatim.
        Exception('approximation error happen in ...'): if a target result
            disagrees with the approximation for the same step, after the
            pending approximation tasks are cancelled. The target id ends
            the message.
    """
    # None defaults: the former ``[[]]`` / ``[]`` defaults were shared across
    # calls, and printed_ids is mutated by the interaction helper.
    if printed_ids is None:
        printed_ids = [[]]
    if prev_steps is None:
        prev_steps = []

    in_step_number = total_step_number - current_step
    tas[in_step_number].append((in_step_number, result))

    flatten_tas = sorted((entry for step_results in tas if step_results for entry in step_results), key=lambda entry: entry[0])
    printed_ids, tas = simulate_within_T_interaction(sas, tas, flatten_tas, printed_ids, logger, target_logger, prev_steps, target_tasks)

    # If the target result is terminate, halt everything by raising.
    # (Uses the entries as they were before the interaction above.)
    flatten_ids = [entry[0] for entry in flatten_tas]
    if flatten_ids == list(range(len(flatten_ids))):
        for step_number, (_, entry) in enumerate(zip(sas, flatten_tas)):
            if entry[0] == step_number and entry[1].lower() == 'terminate':
                raise Exception('terminate the whole process!')

    # Re-flatten: the interaction may have replaced entries with user overrides.
    # On the first disagreement, stop the pending approximations and report it.
    flatten_tas = sorted((entry for step_results in tas if step_results for entry in step_results), key=lambda entry: entry[0])
    for ta in flatten_tas:
        if len(sas) > ta[0] and not judge_to_be_true(sas[ta[0]], ta[1]):
            pending_approximation_tasks = [t for t in asyncio.all_tasks() if not t.cancelled() and not t.done() and t not in target_tasks and t.get_name().startswith('approximation')]
            for pending_approximation_task in pending_approximation_tasks:
                await cancel(pending_approximation_task)
            raise Exception(f'approximation error happen in step {total_step_number} for current step {current_step}, the target id is {ta[0]}')

    return tas, printed_ids

async def T_generate_without_handling_signal(args, assistant, prompt, total_step_number, tas, sas, target_tasks, printed_ids, current_step, logger, target_logger, prev_steps, suggestion=None):
    """Run the target agent for one step, optionally steered by a user suggestion.

    Same generation as ``T_generate`` but without its interrupt handling; it is
    used after the user has already been asked for input. Failed attempts are
    retried after 0.1 s until the attempt counter reaches 10 (9 tries), after
    which ``''`` is recorded. Cancellation is propagated instead of retried.

    Args:
        args: CLI arguments; ``args.target_type`` selects the prompting style.
        assistant: The target agent.
        prompt: Task prompt plus the trajectory so far.
        total_step_number: Global index of the step being planned.
        tas, sas, target_tasks, printed_ids, current_step, logger,
            target_logger, prev_steps: Forwarded to ``postprocess_T_generation``.
        suggestion: Optional free-text hint appended to the prompt.

    Returns:
        ``(tas, printed_ids)`` from ``postprocess_T_generation``.
    """
    if suggestion:
        prompt += "\n\nFor this step, the user provides the following suggestion. Take this suggestion in mind when planning for this step:\nSuggestion:" + suggestion
    if args.target_type == 'direct':
        prompt += f"\n\nDirectly tell me what **the ONE NEXT action step** based on the current action trajectory should be. (Remember to use xml tag <tool> and </tool> for formatting.)\nWhat should be the action in Step {total_step_number+1}?\n\nStep {total_step_number+1}:"
    else:
        prompt += f"\n\nCarefully think about **the ONE NEXT action step** based on the current action trajectory, by first providing a clear reasoning chain. And then decide which tool to use for the current step. (Remember to use xml tag <tool> and </tool> for formatting.)\nWhat should be the action in Step {total_step_number+1}?\n\nStep {total_step_number+1}:"

    # call agent to generate the response
    n = 0
    while True:
        try:
            n += 1
            if n >= 10:
                result = ''
                break
            if args.target_type == 'react':
                response = await ReAct(assistant, prompt, total_step_number)
            else:
                response = await assistant.a_generate_reply(messages=[{'content':prompt, 'role':'user'}])
            config.TOTAL_TOKEN_GENERATION.append(response)
            result = parse_response(response)
            break
        except asyncio.CancelledError:
            # The former bare ``except:`` turned a cancellation into another retry.
            raise
        except Exception:
            await asyncio.sleep(0.1)
            continue
    tas, printed_ids = await postprocess_T_generation(result, total_step_number, tas, sas, target_tasks, printed_ids=printed_ids, current_step=current_step, logger=logger, target_logger=target_logger, prev_steps=prev_steps)
    return tas, printed_ids

async def T_generate(args, assistant, prompt, total_step_number, tas, sas, target_tasks, printed_ids, current_step, logger, target_logger, prev_steps):
    """Run the target agent for step ``total_step_number + 1`` and record its result.

    The target agent is prompted directly (``args.target_type == 'direct'``),
    through ``ReAct`` (``'react'``), or with a reasoning-chain prompt (any
    other type). Failed attempts are retried after 0.1 s until the attempt
    counter reaches 10 (9 tries), after which ``''`` is recorded. The parsed
    action is passed to ``postprocess_T_generation``.

    A ``CancelledError``, ``TimeoutError`` or other ``Exception`` escaping the
    generation/post-processing is handled identically. If ``config.USERINPUT``
    is set, the flag is cleared and ``handle_user_input_choices`` is scheduled
    as a new task appended to ``target_tasks``; otherwise the exception is
    dropped. Either way the function then returns normally.

    NOTE(review): this also drops the 'terminate the whole process!' and
    'approximation error ...' exceptions raised by ``postprocess_T_generation``,
    so they do not propagate to code that gathers this task - confirm this is intended.

    Returns:
        ``(tas, printed_ids)``.
    """
    async def handle_user_input_choices(total_step_number, tas, sas, target_tasks, printed_ids, current_step, logger, target_logger, prev_steps):
        """Ask the user on stdin how to continue this step, then resume it.

        'idea' records the fixed placeholder action 'any tool'; any other
        answer re-runs the target agent with the fixed placeholder suggestion
        'Think very carefully.'. The text the user types is currently discarded.

        NOTE(review): ``input()`` is synchronous and blocks the event loop
        while waiting. ``prompt`` is read from the enclosing scope when this
        runs, i.e. after the step instruction was appended below, so
        ``T_generate_without_handling_signal`` appends that instruction again.
        """
        result = input("Do you have an idea what to do or just some opinions? Enter 'idea' if you know what to do for next step. Enter 'opinion' if you just want to add some suggestions.\n")
        if result == 'idea':
            result = input("What do you think this step should be?\n")
            result = 'any tool' # just provide a placeholder here
            tas, printed_ids = await postprocess_T_generation(result, total_step_number, tas, sas, target_tasks, printed_ids=printed_ids, current_step=current_step, logger=logger, target_logger=target_logger, prev_steps=prev_steps)
        else:
            suggestion = input("What are your suggestions?\n")
            suggestion = 'Think very carefully.' # placeholder: the typed suggestion is not used
            tas, printed_ids = await T_generate_without_handling_signal(args, assistant, prompt, total_step_number, tas, sas, target_tasks, printed_ids, current_step, logger, target_logger, prev_steps, suggestion=suggestion)
        return tas, printed_ids
    
    start = time.time()
    if args.target_type == 'direct':
        prompt += f"\n\nDirectly tell me what **the ONE NEXT action step** based on the current action trajectory should be. (Remember to use xml tag <tool> and </tool> for formatting.)\nWhat should be the action in Step {total_step_number+1}?\n\nStep {total_step_number+1}:"
    else:
        prompt += f"\n\nCarefully think about **the ONE NEXT action step** based on the current action trajectory, by first providing a clear reasoning chain. And then decide which tool to use for the current step. (Remember to use xml tag <tool> and </tool> for formatting.)\nWhat should be the action in Step {total_step_number+1}?\n\nStep {total_step_number+1}:"
    
    try:
        # call agent to generate the response
        n = 0
        while True:
            try:
                n += 1
                if n >= 10:
                    result = ''
                    break
                if args.target_type == 'react':
                    response = await ReAct(assistant, prompt, total_step_number)
                else:
                    response = await assistant.a_generate_reply(messages=[{'content':prompt, 'role':'user'}])
                config.TOTAL_TOKEN_GENERATION.append(response)
                result = parse_response(response)
                break
            except:
                # NOTE(review): this bare except also catches asyncio.CancelledError raised
                # by the API call, so such a cancellation is retried instead of reaching
                # the CancelledError handler below (it only gets there if it lands in the
                # sleep here or in postprocess_T_generation).
                await asyncio.sleep(0.1)
                continue
        tas, printed_ids = await postprocess_T_generation(result, total_step_number, tas, sas, target_tasks, printed_ids=printed_ids, current_step=current_step, logger=logger, target_logger=target_logger, prev_steps=prev_steps)
    except asyncio.CancelledError as e:
        # The three handlers are identical: hand control to the user once if requested.
        if config.USERINPUT:
            config.USERINPUT=False
            user_input_task = asyncio.create_task(handle_user_input_choices(total_step_number, tas, sas, target_tasks, printed_ids, current_step, logger, target_logger, prev_steps))
            target_tasks.append(user_input_task)
    except asyncio.exceptions.TimeoutError:
        if config.USERINPUT:
            config.USERINPUT=False
            user_input_task = asyncio.create_task(handle_user_input_choices(total_step_number, tas, sas, target_tasks, printed_ids, current_step, logger, target_logger, prev_steps))
            target_tasks.append(user_input_task)
    except Exception as e:
        # Without USERINPUT set, the exception is dropped here.
        if config.USERINPUT:
            config.USERINPUT=False
            user_input_task = asyncio.create_task(handle_user_input_choices(total_step_number, tas, sas, target_tasks, printed_ids, current_step, logger, target_logger, prev_steps))
            target_tasks.append(user_input_task)
            
    end = time.time()
    print(f'target: step time for {total_step_number}: '+ str(end-start))

    return tas, printed_ids

def _flatten_sorted_tas(tas):
    """Concatenate the non-empty per-step target lists of ``tas`` and sort them by step id (stable)."""
    flat = [entry for step_results in tas if step_results for entry in step_results]
    return sorted(flat, key=lambda entry: entry[0])


def _target_id_from_error(error):
    """Return the target step id at the end of an 'approximation error ...' message.

    ``postprocess_T_generation`` ends that message with ``the target id is {id}``.
    The whole trailing run of digits is used, so ids of 10 or more are read
    correctly (reading only the last character turned 12 into 2). A message
    not ending in a digit still raises ``ValueError`` (``IndexError`` if empty).
    """
    message = str(error)
    trailing_digits = message[len(message.rstrip('0123456789')):]
    return int(trailing_digits or message[-1])


async def onebreakingpoint_speculative_planning(args, app_assistant, tar_assistant, prompt, current_step, logger, target_logger, approximation_logger, prev_steps):
    """Plan speculatively until the first disagreement ("breaking point").

    For up to ``args.k`` steps, an approximation task (awaited before moving
    on) and a target task (left running) are started per step. Each new
    approximation is appended to the prompt so the next step builds on it.
    The loop stops on a 'terminate' approximation, or as soon as a finished
    target result disagrees with its approximation; later target tasks are
    then cancelled. Remaining target tasks are awaited and leftover results
    reported.

    Args:
        args: CLI arguments (``k``, ``target_type``).
        app_assistant: Approximation agent.
        tar_assistant: Target agent.
        prompt: Task prompt including the trajectory decided so far.
        current_step: Number of steps decided before this breakpoint.
        logger, target_logger, approximation_logger: Loggers.
        prev_steps: Steps decided before this breakpoint.

    Returns:
        The approximations of this breakpoint, truncated at the first one that
        disagrees with its target result, with that step replaced by the
        target's action.
    """
    sas = []            # approximation results
    tas = []            # per-step target results
    target_tasks = []
    printed_ids = []
    for i in range(args.k):
        break_out_approximation = False

        tas.append([])
        printed_ids.append([])
        approximation = asyncio.create_task(A_generate(app_assistant, prompt, current_step+i, logger=logger, approximation_logger=approximation_logger), name=f"approximation_{current_step+i}")
        target = asyncio.create_task(T_generate(args, tar_assistant, prompt, current_step+i, tas, sas, target_tasks=target_tasks, printed_ids=printed_ids, current_step=current_step, logger=logger, target_logger=target_logger, prev_steps=prev_steps), name=f"target_{i}")
        target_tasks.append(target)

        # Track the peak number of pending tasks for the final metrics.
        concurrent_api_calls = concurrent_calls()
        if concurrent_api_calls >= config.MAX_CONCURRENT_CALLS:
            config.MAX_CONCURRENT_CALLS = concurrent_api_calls

        try:
            sa = await approximation
            sas.append(sa)

            # Announce the new approximation right away when every earlier step
            # already has a matching target result.
            flatten_tas = _flatten_sorted_tas(tas[:len(sas)-1])
            flatten_ids = [t[0] for t in flatten_tas]
            flatten_tas_action = [t[1] for t in flatten_tas]
            if flatten_ids == list(range(len(flatten_ids))) and len(flatten_ids) == len(sas)-1 and all([judge_to_be_true(s, t) for s, t in zip(sas[:-1], flatten_tas_action)]):
                flattened_printed_ids = [printed_id[0] for printed_id in printed_ids if printed_id != []]

                if flattened_printed_ids:
                    if len(sas) == len(flattened_printed_ids)+1:
                        logger.log(f'in breaking, The approximation agent thinks step {current_step+i+1} should be ' + sa)
                        config.HIL_INTERACTION = len(sas)-1
                        register_async_handler(target_tasks=target_tasks)
                else:
                    logger.log(f'in breaking, The approximation agent thinks step {current_step+i+1} should be ' + sa)
                    config.HIL_INTERACTION = len(sas)-1
                    register_async_handler(target_tasks=target_tasks)

            # Extend the trajectory with the latest approximation for the next step.
            if '## Current Action Trajectory:' not in prompt:
                prompt += f'\n\n## Current Action Trajectory:\n'
            prompt += f'\nAction {current_step+i+1} in the decided action trajectory: {str(sa)}.'
        except asyncio.CancelledError:
            # NOTE(review): if the approximation was cancelled, ``sa`` below still
            # holds the previous step's value.
            pass

        # A 'terminate' approximation ends the loop whether or not the target agrees.
        if sa.lower() == 'terminate':
            break_out_approximation = True

        # Stop speculating at the first finished target result that disagrees with
        # its approximation, and cancel the target tasks of later steps.
        flatten_tas = _flatten_sorted_tas(tas)
        for t in flatten_tas:
            if len(sas) > t[0]:
                if not judge_to_be_true(sas[t[0]], t[1]):
                    break_out_approximation = True
                    for process_id, one_task in enumerate(target_tasks):
                        if not one_task.cancelled() and not one_task.done() and process_id > t[0]:
                            await cancel(one_task)
                    break

        if break_out_approximation:
            break

    # Collect the target results. Finished tasks are deliberately kept: their
    # exceptions only surface when they are gathered.
    pending_tasks = [t for t in target_tasks if not t.cancelled()]
    while pending_tasks:
        break_while_loop = False
        try:
            if [pending_task.done() for pending_task in pending_tasks] == [True]*len(pending_tasks):
                break_while_loop = True
                try:
                    await asyncio.gather(*pending_tasks, return_exceptions=False)
                    break
                except:
                    break
            await asyncio.gather(*pending_tasks, return_exceptions=False)
            break_while_loop = True
            break
        except Exception as e:
            if str(e) == 'terminate the whole process!':
                # TERMINATE reached: stop all pending tasks and finalize.
                for pending_task in pending_tasks:
                    await cancel(pending_task)
                flatten_tas = _flatten_sorted_tas(tas)
                printed_ids, tas = simulate_leftover_interaction(sas, tas, flatten_tas, printed_ids, logger, target_logger, prev_steps, target_tasks)

                flatten_tas = _flatten_sorted_tas(tas)
                for step_number, (s, t) in enumerate(zip(sas, flatten_tas)):
                    if t[0] == step_number and not judge_to_be_true(s, t[1]):
                        sas = sas[:step_number]+[flatten_tas[step_number][1]]
                        break

                return sas
            else:
                if [pending_task.done() for pending_task in pending_tasks] == [True]*len(pending_tasks):
                    break_while_loop = True
                    # Gather once more so exceptions of finished tasks are retrieved.
                    try:
                        await asyncio.gather(*pending_tasks, return_exceptions=False)
                        break
                    except:
                        break
                if break_while_loop:
                    break
                mistaken_process_id = _target_id_from_error(e)
                # Cancel t_j for j > i when t_i != s_i. Indices are positions in
                # target_tasks, as in the other loops here; enumerating
                # pending_tasks misaligned them once tasks had been dropped.
                for process_id, one_task in enumerate(target_tasks):
                    if one_task in pending_tasks and not one_task.cancelled() and not one_task.done() and process_id > mistaken_process_id:
                        await cancel(one_task)

                pending_tasks = [t for process_id, t in enumerate(target_tasks) if not t.cancelled() and process_id != mistaken_process_id]

        if break_while_loop:
            break

    # Report whatever was not shown yet, then build the final result.
    flatten_tas = _flatten_sorted_tas(tas)
    printed_ids, tas = simulate_leftover_interaction(sas, tas, flatten_tas, printed_ids, logger, target_logger, prev_steps, target_tasks)

    flatten_tas = _flatten_sorted_tas(tas)
    for step_number, (s, t) in enumerate(zip(sas, flatten_tas)):
        if t[0] == step_number and not judge_to_be_true(s, t[1]):
            sas = sas[:step_number]+[flatten_tas[step_number][1]]
            break

    return sas

async def speculative_planning(args, app_assistant, tar_assistant, prompt, logger, target_logger, approximation_logger):
    """Run speculative planning breakpoint by breakpoint until the plan terminates.

    Each round calls ``onebreakingpoint_speculative_planning`` from the current
    trajectory and appends the steps it decides to the prompt and the result.
    Planning stops once the last decided step is 'terminate'
    (case-insensitive), or when a round decides no step at all. The elapsed
    wall-clock time is logged.

    Returns:
        The full list of decided steps.
    """
    begin_time = datetime.now()
    steps = []
    while True:
        result = await onebreakingpoint_speculative_planning(args, app_assistant, tar_assistant, prompt, len(steps), logger, target_logger, approximation_logger, prev_steps=steps)

        print(result)

        # A round with no decided step (e.g. --k 0) would repeat forever with the
        # same prompt; it previously crashed on result[-1] instead.
        if not result:
            logger.log('speculative planning stopped: a round decided no steps')
            break

        if '## Current Action Trajectory:' not in prompt:
            prompt += f'\n\n## Current Action Trajectory:\n'
        prompt += ''.join(f'\nAction {len(steps) + j + 1} in the decided action trajectory: {action}.' for j, action in enumerate(result))

        # Extend in place: ``steps`` is the same list passed as ``prev_steps`` above.
        steps += result

        if result[-1].lower() == 'terminate':
            break

    end_time = datetime.now()
    logger.log(f'{end_time} - {begin_time} = {end_time - begin_time}')

    return steps

if __name__ == '__main__':
    ## global variables: run-wide counters and flags shared through the config module
    config.MAX_CONCURRENT_CALLS = 0
    config.TOTAL_APPROXIMATION_CALLS = 0
    config.TOTAL_CORRECT_APPROXIMATION_CALLS = 0
    config.TOTAL_TOKEN_GENERATION = []
    config.USERINPUT=False

    random.seed(2)
    parser = argparse.ArgumentParser(description='OpenAGI')
    # NOTE: load_data() does not read --data; it opens data/openagi_task_description.txt.
    # The default below used to misspell that file name ("descrition").
    parser.add_argument('--data', type=str, default='data/openagi_task_description.txt', help='task description file')
    parser.add_argument('--task_id', type=int, default=29, help='task id')
    parser.add_argument('--k', type=int, default=4, help='number of approximation steps to generate everytime')
    parser.add_argument('--target_type', type=str, default='react', help='cot, react, multi-agent, direct')
    args = parser.parse_args()

    logger = Logger(f'logs/{args.target_type}/simulation_datapoint{args.task_id}_k{args.k}.log', on=True)
    target_logger = Logger(f'logs/{args.target_type}/target_datapoint{args.task_id}_k{args.k}.log', on=True)
    approximation_logger = Logger(f'logs/{args.target_type}/approximation_datapoint{args.task_id}_k{args.k}.log', on=True)

    tasks = load_data(args)
    task_description = tasks[args.task_id]
    logger.log('task description: ' + task_description)
    target_logger.log('task description: ' + task_description)
    approximation_logger.log('task description: ' + task_description)
    tools = """
Available tools are as follows:

(1) <tool>Sentiment Analysis</tool>
(2) <tool>Text Summarization</tool>
(3) <tool>Machine Translation</tool>
(4) <tool>Fill Mask</tool>
(5) <tool>Question Answering</tool>
(6) <tool>Image Classification</tool>
(7) <tool>Object Detection</tool>
(8) <tool>Colorization</tool>
(9) <tool>Image Super-Resolution</tool>
(10) <tool>Image Denoising</tool>
(11) <tool>Image Deblurring</tool>
(12) <tool>Visual Question Answering</tool>
(13) <tool>Image Captioning</tool>
(14) <tool>Text-to-Image Generation</tool>
(15) <tool>TERMINATE</tool>

For each step of the plan, please specify the tool you would like to use. But if you think the task is completed, please use <tool>TERMINATE</tool> to end the conversation.

Please use xml tags to specify the tool when responsing. For example, <tool>Sentiment Analysis</tool> for Sentiment Analysis.
"""

    prompt = "## Problem: " + task_description + "\nPlease solve this problem using the following tools step by step:\n" + tools

    # The approximation agent is the cheaper model only for the 'direct' setting.
    if args.target_type == 'direct':
        app_config_list = [{
            "model": "gpt-3.5-turbo",
            "api_key": "",
            "api_type": "openai",
            "cache_seed": None,
            "seed":0
        },]
    else:
        app_config_list = [{
            "model": "gpt-4-turbo",
            "api_key": "",
            "api_type": "openai",
            "cache_seed": None,
            "seed":0
        },]
    app_assistant = AssistantAgent("assistant", llm_config={"config_list": app_config_list}, human_input_mode='NEVER')

    tar_config_list = [{
        "model": "gpt-4-turbo",
        "api_key": "",
        "api_type": "openai",
        "cache_seed": None,
        "seed":0
    },]
    tar_assistant = AssistantAgent("assistant", llm_config={"config_list": tar_config_list}, human_input_mode='NEVER')

    # run the speculative planning
    steps = asyncio.run(speculative_planning(args, app_assistant, tar_assistant, prompt, logger, target_logger, approximation_logger))
    print(steps)
    # record the metrics
    logger.log('final result for the speculative planning ' + str(steps))
    logger.log('max concurrent calls: ' + str(config.MAX_CONCURRENT_CALLS-1)) # speculative_planning will add one more call
    # TOTAL_APPROXIMATION_CALLS is only incremented when a step is reported, so it
    # can still be 0 here; avoid a ZeroDivisionError that would lose the token count.
    if config.TOTAL_APPROXIMATION_CALLS:
        accuracy = config.TOTAL_CORRECT_APPROXIMATION_CALLS/config.TOTAL_APPROXIMATION_CALLS
    else:
        accuracy = 'n/a (no approximation was judged)'
    logger.log('accuracy of approximation agent: ' + str(accuracy))
    # Responses are recorded before parsing, so non-string replies (e.g. None) can
    # appear; skip them. disallowed_special=() counts text such as '<|endoftext|>'
    # as ordinary text instead of making tiktoken raise ValueError.
    token_number = sum(len(encoding.encode(response, disallowed_special=())) for response in config.TOTAL_TOKEN_GENERATION if isinstance(response, str))
    logger.log('total token generation: ' + str(token_number))
        
