import sys 
sys.path.append('../')

import asyncio
import time
from datetime import datetime
import random
import subprocess
import random
import numpy as np
import json 
import autogen 
import argparse 
from color import slow_type_approximation, slow_type_target
from utils import Logger, cancel, register_handler

import config

def concurrent_calls():
    """Count the asyncio tasks that are still in flight.

    Relies on ``asyncio.all_tasks()``, so it has to be called while an
    event loop is running.

    Returns:
        int: number of tasks that are neither done nor cancelled.
    """
    return sum(
        1
        for task in asyncio.all_tasks()
        if not (task.done() or task.cancelled())
    )

def turn_to_bool(user_input):
    """Map the exact strings 'True' / 'False' to booleans.

    Any other value is handed back unchanged.
    """
    for literal, value in (('True', True), ('False', False)):
        if user_input == literal:
            return value
    return user_input

def parse_user_input(user_input, s, t):
    """Resolve what the user typed into the action to adopt.

    Args:
        user_input: text entered by the user.
        s: the approximation agent's answer for the step.
        t: the target agent's answers for the step; ``t[0][1]`` is read as
            the target action.

    Returns:
        ``s`` for 'approximation answer', ``t[0][1]`` for 'target answer' or
        an empty string, otherwise ``user_input`` itself.
    """
    if user_input == 'approximation answer':
        return s
    # An empty reply is treated the same as explicitly choosing the target.
    if user_input in ('target answer', ''):
        return t[0][1]
    return user_input
    
async def simulate_execution(action_trajectory, delay=0.1):
    """Stand in for executing an action trajectory by pausing briefly.

    Args:
        action_trajectory: the trajectory that would be executed; it is not
            inspected by this stub.
        delay: seconds to sleep; defaults to the 0.1 s pause used so far.

    Returns:
        None.
    """
    await asyncio.sleep(delay)

############# set-time simulation code for speculative planning #############
def pseudo_interaction_function(sas, tas, to_print_id, logger, target_logger, prev_steps, target_tasks):
    """Report the target agent's verdict for one step and return the action to keep.

    Logs the target answer, updates the approximation counters on ``config``
    and, when the target agrees with the approximation and a next
    approximation already exists, announces that next approximation, stores
    its index in ``config.HIL_INTERACTION`` and calls ``register_handler``.

    If a ``KeyboardInterrupt`` arrives while this is happening, the user is
    prompted on stdin and the reply is resolved with ``parse_user_input``.

    Args:
        sas: approximation answers, indexed by in-round step id.
        tas: per-step lists of target answers; ``tas[i][0]`` is read as
            ``(step_id, action)``.
        to_print_id: index of the step being reported.
        logger: main logger (must provide ``log``).
        target_logger: logger for target-agent output.
        prev_steps: steps verified in earlier rounds; only its length is
            used, to offset the step numbers shown in the logs.
        target_tasks: forwarded to ``register_handler``.

    Returns:
        The target action, or whatever the user chose after an interrupt.
    """
    s = sas[to_print_id]
    t = tas[to_print_id]
    target_action = t[0][1]
    target_logger.log(f'Target: Step {t[0][0] + len(prev_steps)+1} - {t[0][1]}')
    try:
        config.TOTAL_APPROXIMATION_CALLS += 1
        if s == target_action:
            config.TOTAL_CORRECT_APPROXIMATION_CALLS += 1
            logger.log(f'The target agent thinks step {len(prev_steps) + to_print_id+1} should be '+ str(t[0][1]) + ', which agrees with the approximation agent.')
            # Only announce the next approximation when it already exists.
            # An explicit bounds check replaces a catch-all handler so that
            # an interrupt raised here still reaches the handler below.
            next_id = to_print_id + 1
            if next_id < len(sas):
                logger.log(f'The approximation agent thinks step {len(prev_steps) + to_print_id+2} should be ' + str(sas[next_id]))
                config.HIL_INTERACTION = next_id
                register_handler(target_tasks=target_tasks)
        else:
            logger.log(f'The target agent thinks step {len(prev_steps) + to_print_id+1} should be '+ str(t[0][1]) + ', correcting what the approximation agent thinks which is ' + str(s) + '.')
        # Without an interrupt the simulated user gives an empty reply,
        # which resolves to the target action.
        user_input = parse_user_input('', s, t)
    except KeyboardInterrupt:
        user_input = input('')

        user_input = parse_user_input(user_input, s, t)
        # Compare string forms: the resolved input may be the target action
        # itself (e.g. a bool) rather than the text that was typed.
        if str(user_input) != str(target_action):
            logger.log(f'Since you think the action should be {user_input} ... we will follow your suggestion :)')
        logger.log('-------------------')

    return user_input

def simulate_within_T_interaction(sas, tas, flatten_tas, printed_ids, logger, target_logger, prev_steps, target_tasks):
    """Report target answers that have become reportable, in step order.

    A step is reportable once target answers exist for every step id from 0
    up to it without a gap, and an approximation exists for each of those
    steps. Nothing is reported if a previously reported step already
    disagrees with its approximation.

    Args:
        sas: approximation answers, indexed by in-round step id.
        tas: per-step lists of target answers; ``tas[i][0]`` is read as
            ``(step_id, action)``. May be modified in place when the
            interaction returns an action different from the target's.
        flatten_tas: all target answers, sorted by step id.
        printed_ids: per-step lists recording which ids were reported;
            appended to in place.
        logger, target_logger, prev_steps, target_tasks: forwarded to
            ``pseudo_interaction_function``.

    Returns:
        tuple: ``(printed_ids, tas)``.
    """
    ## conduct on-time printing out
    tas_ids = [t[0] for t in flatten_tas]
    flatten_printed_ids = [printed_id for ids in printed_ids if ids for printed_id in ids]

    # Longest prefix of target ids that is exactly 0, 1, 2, ... and is
    # covered by the approximations produced so far.
    expected_ids = list(range(len(flatten_tas)))
    tas_length = len(tas_ids)
    for l in range(tas_length + 1):
        if not (tas_ids[:l] == expected_ids[:l] and len(sas) >= l):
            tas_length = l - 1
            break

    if tas_length <= 0:
        return printed_ids, tas

    # if the printed ids contain wrong results, nothing more is printed
    for printed_id in flatten_printed_ids:
        if not (sas[printed_id] == tas[printed_id][0][1]):
            return printed_ids, tas

    # Sorted explicitly: the loop below inspects the previous element, so it
    # must not depend on the iteration order of a set.
    to_print_ids = sorted(set(tas_ids[:tas_length]) - set(flatten_printed_ids))
    for order_id, to_print_id in enumerate(to_print_ids):
        if order_id > 0:
            previous_id = to_print_ids[order_id - 1]
            # Stop as soon as the step reported just before was a correction.
            if not (sas[previous_id] == tas[previous_id][0][1]):
                break
        printed_ids[to_print_id].append(to_print_id)
        user_input = pseudo_interaction_function(sas, tas, to_print_id, logger, target_logger, prev_steps, target_tasks)
        agrees_with_target = str(user_input) == str(tas[to_print_id][0][1])
        if agrees_with_target and str(user_input).lower() != 'terminate':
            continue
        if not agrees_with_target:
            # The returned action overrides the target's answer for this step.
            change_one_tas_position = list(tas[to_print_id][0])
            change_one_tas_position[1] = user_input
            tas[to_print_id][0] = change_one_tas_position
        return printed_ids, tas

    ## finish printing out
    return printed_ids, tas

def simulate_leftover_interaction(sas, tas, flatten_tas, printed_ids, logger, target_logger, prev_steps, target_tasks):
    """Report target answers that were not reported while the tasks were running.

    Runs only when both the approximations and the target answers outnumber
    the steps already reported, and every reported step agrees with its
    approximation. Reporting proceeds step by step and stops at the first
    step with no target answer, or right after the first step whose final
    answer differs from the approximation.

    Args:
        sas: approximation answers, indexed by in-round step id.
        tas: per-step lists of target answers; ``tas[i][0]`` is read as
            ``(step_id, action)``. May be modified in place when the
            interaction returns an action different from the target's.
        flatten_tas: all target answers, sorted by step id.
        printed_ids: per-step lists recording which ids were reported;
            appended to in place.
        logger, target_logger, prev_steps, target_tasks: forwarded to
            ``pseudo_interaction_function``.

    Returns:
        tuple: ``(printed_ids, tas)``.
    """
    # when sas is slower than tas
    # we need to print out the extra here
    flatten_printed_ids = [printed_id for ids in printed_ids if ids for printed_id in ids]
    printed_count = len(flatten_printed_ids)

    has_leftover = len(sas) > printed_count and len(flatten_tas) > printed_count
    # if the printed ids contain wrong results, nothing more is printed
    all_printed_correct = all(
        sas[printed_id] == tas[printed_id][0][1] for printed_id in flatten_printed_ids
    )
    if not (has_leftover and all_printed_correct):
        return printed_ids, tas

    for print_id in range(printed_count, min(len(sas), len(tas))):
        if not tas[print_id]:
            break
        user_input = pseudo_interaction_function(sas, tas, print_id, logger, target_logger, prev_steps, target_tasks)
        printed_ids[print_id].append(print_id)
        if str(user_input) != str(tas[print_id][0][1]):
            # The returned action overrides the target's answer for this step.
            change_one_tas_position = list(tas[print_id][0])
            change_one_tas_position[1] = user_input
            tas[print_id][0] = change_one_tas_position
        # Always checked, including when the user accepted the target's
        # answer: once a step corrects the approximation, the later
        # approximations were built on a wrong step and are not reported.
        if not (sas[print_id] == tas[print_id][0][1]):
            break

    return printed_ids, tas

############# set-time simulation code for speculative planning #############

async def simulate_random_time_A_generate(prompt, total_step_number, acc, logger, approximation_logger, sleep_time=2):
    """Simulate one approximation-agent call.

    Sleeps for ``sleep_time`` seconds, draws the answer with
    ``np.random.choice`` (``True`` with probability ``acc``), adds 10 to
    ``config.TOTAL_TOKEN_GENERATION``, logs the answer and runs
    ``simulate_execution`` on the prompt extended with it.

    Args:
        prompt: prompt text; the answer is appended before simulated execution.
        total_step_number: zero-based global step index (logged one-based).
        acc: probability that the drawn answer is ``True``.
        logger: unused here; kept for call compatibility.
        approximation_logger: logger receiving the approximation result.
        sleep_time: seconds to sleep before answering.

    Returns:
        The drawn answer, as returned by ``np.random.choice``.
    """
    await asyncio.sleep(sleep_time)
    result = np.random.choice([True, False], p=[acc, 1-acc])
    config.TOTAL_TOKEN_GENERATION += 10
    approximation_logger.log(f'Approximation: Step {total_step_number+1} - {result}')

    await simulate_execution(prompt + str(result))

    return result

async def simulate_random_time_T_generate(prompt, total_step_number, tas, sas, target_tasks, printed_ids=None, current_step=0, logger=None, target_logger=None, prev_steps=None):
    """Simulate one target-agent call.

    Sleeps for 8 seconds, takes ``True`` as the target's answer, adds 20 to
    ``config.TOTAL_TOKEN_GENERATION`` and hands the answer to
    ``postprocess_T_generation``.

    If the task is cancelled, the ``CancelledError`` is absorbed. When
    ``config.USERINPUT`` is set at that moment, the flag is cleared, the user
    is prompted on stdin, and a follow-up ``postprocess_T_generation`` task
    is created and appended to ``target_tasks``.

    Args:
        prompt: unused here; kept for call compatibility.
        total_step_number: zero-based global index of the step.
        tas: per-step lists of target answers (updated by post-processing).
        sas: approximation answers for the current round.
        target_tasks: list of target tasks for the current round.
        printed_ids: per-step lists of reported ids; a fresh ``[[]]`` is used
            when omitted.
        current_step: global index of the first step of the current round.
        logger: main logger.
        target_logger: logger for target-agent output.
        prev_steps: steps verified in earlier rounds; a fresh ``[]`` is used
            when omitted.

    Returns:
        tuple: ``(tas, printed_ids)`` on normal completion; ``None`` when the
        task was cancelled.
    """
    # Fresh containers per call: literal list defaults would be shared by
    # every call, and printed_ids is mutated downstream.
    if printed_ids is None:
        printed_ids = [[]]
    if prev_steps is None:
        prev_steps = []
    try:
        sleep_time = 8#random.randint(2, 10)
        await asyncio.sleep(sleep_time)
        result = True
        config.TOTAL_TOKEN_GENERATION += 20
        tas, printed_ids = await postprocess_T_generation(result, total_step_number, tas, sas, target_tasks, printed_ids=printed_ids, current_step=current_step, logger=logger, target_logger=target_logger, prev_steps=prev_steps)
        return tas, printed_ids
    except asyncio.CancelledError:
        if config.USERINPUT:
            config.USERINPUT=False
            result = input("What do you think this step should be?\n")
            # The typed text is discarded: the step is recorded as True.
            result = True
            user_input_task = asyncio.create_task(postprocess_T_generation(result, total_step_number, tas, sas, target_tasks, printed_ids=printed_ids, current_step=current_step, logger=logger, target_logger=target_logger, prev_steps=prev_steps))
            target_tasks.append(user_input_task)

async def postprocess_T_generation(result, total_step_number, tas, sas, target_tasks, printed_ids=None, current_step=0, logger=None, target_logger=None, prev_steps=None):
    """Record a target answer, report what can be reported, and flag a mismatch.

    The answer is stored under its in-round step id
    (``total_step_number - current_step``). After the reporting pass, if any
    target answer disagrees with the approximation for the same step, all
    pending tasks whose name starts with 'approximation' (and that are not
    in ``target_tasks``) are cancelled and an exception is raised.

    Args:
        result: the target agent's answer for the step.
        total_step_number: zero-based global index of the step.
        tas: per-step lists of target answers; appended to in place.
        sas: approximation answers for the current round.
        target_tasks: list of target tasks for the current round.
        printed_ids: per-step lists of reported ids; a fresh ``[[]]`` is used
            when omitted.
        current_step: global index of the first step of the current round.
        logger: main logger.
        target_logger: logger for target-agent output.
        prev_steps: steps verified in earlier rounds; a fresh ``[]`` is used
            when omitted.

    Returns:
        tuple: ``(tas, printed_ids)``.

    Raises:
        Exception: when a target answer contradicts its approximation. The
            message ends with the in-round id of the offending step.
    """
    # Fresh containers per call: literal list defaults would be shared by
    # every call, and printed_ids is mutated downstream.
    if printed_ids is None:
        printed_ids = [[]]
    if prev_steps is None:
        prev_steps = []

    def flatten_sorted(nested):
        # Merge the per-step answer lists into one list ordered by step id.
        return sorted((entry for step in nested if step for entry in step), key=lambda x: x[0])

    in_step_number = total_step_number - current_step
    tas[in_step_number].append((in_step_number,result))

    # cancel next target_tasks after which we know is incorrect
    # but there maybe unknown result before this point, so we don't kill all processes
    printed_ids, tas = simulate_within_T_interaction(sas, tas, flatten_sorted(tas), printed_ids, logger, target_logger, prev_steps, target_tasks)

    # if it is a wrong result
    # we break out the target processes and cancel processes that comes after it
    # (tas is flattened again because the interaction may have edited it)
    for ta in flatten_sorted(tas):
        if len(sas) > ta[0] and not sas[ta[0]] == ta[1]:
            pending_approximation_tasks = [
                task for task in asyncio.all_tasks()
                if not task.cancelled() and not task.done()
                and task not in target_tasks
                and task.get_name().startswith('approximation')
            ]
            for pending_approximation_task in pending_approximation_tasks:
                await cancel(pending_approximation_task)
            # The message must keep ending with the step id.
            raise Exception(f'approximation error happen in step {total_step_number} for current step {current_step}, the target id is {ta[0]}')

    return tas, printed_ids

async def simulate_random_time_onebreakingpoint_speculative_planning(prompt, current_step, k, acc, logger, target_logger, approximation_logger, approximation_sleep_time=2, prev_steps=None):
    """Run one speculative round and return the steps it settles.

    For each of up to ``k`` steps an approximation task and a target task are
    started; the approximation is awaited before the next step begins, while
    target tasks keep running in the background. The loop stops early as soon
    as a recorded target answer contradicts the approximation for the same
    step. The target tasks are then awaited, leftover target answers are
    reported, and the approximation list is cut at the first contradicted
    step, whose entry is replaced by the target's answer.

    Args:
        prompt: prompt text, extended locally with each approximation.
        current_step: global index of the first step of this round.
        k: maximum number of steps to speculate in this round.
        acc: probability that an approximation is ``True``.
        logger: main logger.
        target_logger: logger for target-agent output.
        approximation_logger: logger for approximation-agent output.
        approximation_sleep_time: seconds each approximation call sleeps.
        prev_steps: steps verified in earlier rounds; a fresh ``[]`` is used
            when omitted.

    Returns:
        list: the answers settled in this round.
    """
    if prev_steps is None:
        prev_steps = []

    def flatten_sorted(nested):
        # Merge the per-step answer lists into one list ordered by step id.
        return sorted((entry for step in nested if step for entry in step), key=lambda x: x[0])

    async def drain_finished(tasks):
        # Await tasks that have all finished so that the exceptions they hold
        # are handled here (avoids "exception not handled" reports).
        # Cancellation is absorbed as before; KeyboardInterrupt/SystemExit
        # are no longer swallowed.
        try:
            await asyncio.gather(*tasks, return_exceptions=False)
        except (Exception, asyncio.CancelledError):
            pass

    sas = []
    tas = []
    target_tasks = []
    printed_ids = []
    for i in range(k):
        break_out_approximation = False
        tas.append([])
        printed_ids.append([])
        approximation = asyncio.create_task(simulate_random_time_A_generate(prompt, current_step+i, acc, logger, approximation_logger=approximation_logger, sleep_time=approximation_sleep_time), name=f"approximation_{current_step+i}")
        target = asyncio.create_task(simulate_random_time_T_generate(prompt, current_step+i, tas, sas, target_tasks, printed_ids, current_step, logger, target_logger=target_logger, prev_steps=prev_steps), name=f"target_{i}")
        target_tasks.append(target)

        # track the highest number of simultaneously pending tasks
        concurrent_api_calls = concurrent_calls()
        if concurrent_api_calls >= config.MAX_CONCURRENT_CALLS:
            config.MAX_CONCURRENT_CALLS = concurrent_api_calls

        try:
            sa = await approximation
            sas.append(sa)

            # check if we need to print out the approximation result here
            # which is when previous steps of target agent are all done
            flatten_tas = flatten_sorted(tas[:len(sas)-1])
            flatten_ids = [t[0] for t in flatten_tas]
            flatten_tas_action = [t[1] for t in flatten_tas]
            if flatten_ids == list(range(len(flatten_ids))) and len(flatten_ids) == len(sas)-1 and all([s==t for s, t in zip(sas[:-1], flatten_tas_action)]):
                flattened_printed_ids = [printed_id[0] for printed_id in printed_ids if printed_id != []]
                # announce when nothing was reported yet, or when exactly
                # the steps before this one have been reported
                if not flattened_printed_ids or len(sas) == len(flattened_printed_ids)+1:
                    logger.log(f'The approximation agent thinks step {current_step+i+1} should be ' + str(sa))
                    config.HIL_INTERACTION = len(sas)-1
                    register_handler(target_tasks=target_tasks)

            # modify the prompt based on latest approximation result
            prompt += f' A step {current_step+i} {str(sa)}'
        except asyncio.CancelledError:
            # the awaited approximation was cancelled; keep going so the
            # mismatch check below can end the loop
            pass

        # tas is a list of lists, so it is flattened in order to compare with sas
        # halt the ongoing approximation loop
        for t in flatten_sorted(tas):
            if len(sas) > t[0] and not sas[t[0]] == t[1]:
                break_out_approximation = True
                for process_id, one_task in enumerate(target_tasks):
                    if not one_task.cancelled() and not one_task.done() and process_id > t[0]:
                        await cancel(one_task)
                break

        if break_out_approximation:
            break

    # after halting the approximation loop
    # we need to collect the target results
    # organize to sas, see how much we want to preserve
    # SHOULD NOT exclude finished tasks, because exceptions are only thrown when tasks are finished
    pending_tasks = [t for t in target_tasks if not t.cancelled()]
    while pending_tasks:
        if all(task.done() for task in pending_tasks):
            await drain_finished(pending_tasks)
            break
        try:
            # should not await cancelled tasks
            # return_exceptions=False is also the default value
            await asyncio.gather(*pending_tasks, return_exceptions=False)
            break
        except Exception as e:
            # cancel t_j for j > i if t_i != s_i
            if all(task.done() for task in pending_tasks):
                await drain_finished(pending_tasks)
                break
            # The mismatch exception's message ends with the in-round step
            # id; read the whole trailing number so ids >= 10 are handled.
            try:
                mistaken_process_id = int(str(e).rsplit(' ', 1)[-1])
            except ValueError:
                # not a mismatch report: surface the original error
                raise e from None
            # only cancel t_j such that j > i
            # NOTE(review): process_id here is the position in pending_tasks,
            # while the filter below and the approximation loop above index
            # target_tasks -- confirm the two numberings always coincide.
            for process_id, one_task in enumerate(pending_tasks):
                if not one_task.cancelled() and not one_task.done() and process_id > mistaken_process_id:
                    await cancel(one_task)

            pending_tasks = [t for process_id, t in enumerate(target_tasks) if not t.cancelled() and process_id != mistaken_process_id]

    # get user input or interruption
    printed_ids, tas = simulate_leftover_interaction(sas, tas, flatten_sorted(tas), printed_ids, logger, target_logger, prev_steps, target_tasks)

    # get the final tas result
    flatten_tas = flatten_sorted(tas)
    for step_number, (s, t) in enumerate(zip(sas, flatten_tas)):
        if t[0] == step_number and not s == t[1]:
            # keep the agreed prefix and adopt the target's answer for the
            # first contradicted step
            sas = sas[:step_number]+[t[1]]
            await simulate_execution(prompt + str(sas))
            break

    return sas

# wait signal, one signal variable for each task
# https://stackoverflow.com/questions/59073556/how-to-cancel-all-remaining-tasks-in-gather-if-one-fails

async def simulate_random_time_speculative_planning(prompt, k, acc, logger, target_logger, approximation_logger, approximation_sleep_time=2, total_steps=10):
    """Run speculative rounds until at least ``total_steps`` steps are settled.

    Each round calls
    ``simulate_random_time_onebreakingpoint_speculative_planning`` starting
    at the number of steps settled so far, appends its result to the plan and
    extends the prompt with the newly verified steps. The elapsed wall-clock
    time is logged at the end.

    Args:
        prompt: initial prompt text.
        k: number of steps to speculate per round; must be positive.
        acc: probability that an approximation is ``True``.
        logger: main logger.
        target_logger: logger for target-agent output.
        approximation_logger: logger for approximation-agent output.
        approximation_sleep_time: seconds each approximation call sleeps.
        total_steps: minimum number of steps to settle before stopping.

    Returns:
        list: all settled steps; a final round may push the length past
        ``total_steps``.

    Raises:
        ValueError: if ``k`` is not positive, since a round would then never
            produce a step and the loop would not terminate.
    """
    if k <= 0:
        raise ValueError(f'k must be a positive integer, got {k!r}')

    begin_time = datetime.now()
    steps = []
    while len(steps) < total_steps:
        result = await simulate_random_time_onebreakingpoint_speculative_planning(prompt, len(steps), k, acc, logger, target_logger, approximation_logger, approximation_sleep_time=approximation_sleep_time, prev_steps=steps)
        steps += result
        prompt += ' verified ' + ' '.join([str(s) for s in result])

    end_time = datetime.now()
    logger.log(f'{end_time} - {begin_time} = {end_time - begin_time}')

    return steps


if __name__ == '__main__':
    # Command-line options: (flag, type, default, help text).
    parser = argparse.ArgumentParser(description='OpenAGI')
    for flag, value_type, default_value, help_text in (
        ('--k', int, 10, 'number of approximation steps to generate everytime'),
        ('--acc', float, 0.5, 'accuracy of the approximation agent'),
        ('--seed', int, 42, 'random seed'),
        ('--approximation_sleep_time', int, 5, 'sleep time for approximation agent'),
    ):
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
    args = parser.parse_args()

    ## global variables: counters and flags shared through the config module
    config.MAX_CONCURRENT_CALLS = 0
    config.TOTAL_APPROXIMATION_CALLS = 0
    config.TOTAL_CORRECT_APPROXIMATION_CALLS = 0
    config.TOTAL_TOKEN_GENERATION = 0
    config.USERINPUT=False

    # NOTE: the seed is only written to the logs below; the seeding calls
    # themselves are disabled.
    # random.seed(args.seed)
    # np.random.seed(args.seed)
    logger = Logger(f'logs/simulate_human_inthe_loop/k{args.k}_accuracy{args.acc}_ast{args.approximation_sleep_time}.log', on=True)
    target_logger = Logger(f'logs/simulate_human_inthe_loop/target_k{args.k}_accuracy{args.acc}_ast{args.approximation_sleep_time}.log', on=True)
    approximation_logger = Logger(f'logs/simulate_human_inthe_loop/approximation_k{args.k}_accuracy{args.acc}_ast{args.approximation_sleep_time}.log', on=True)

    for seed_sink in (logger, target_logger, approximation_logger):
        seed_sink.log('random seed: ' + str(args.seed))

    # Run the simulation and report its summary statistics.
    steps = asyncio.run(
        simulate_random_time_speculative_planning(
            'test',
            args.k,
            args.acc,
            logger,
            target_logger,
            approximation_logger,
            approximation_sleep_time=args.approximation_sleep_time,
        )
    )
    for summary_line in (
        'final result simulation finished: ' + str(steps),
        'max concurrent calls: ' + str(config.MAX_CONCURRENT_CALLS-1),
        'total token generation: ' + str(config.TOTAL_TOKEN_GENERATION),
    ):
        logger.log(summary_line)
    # Every settled step is expected to equal True.
    assert steps == [True]*len(steps)
