"""
Optionally use react trajs from the library; optionally update the library with current ReAct trajs.
If not using the library, ReAct Agent will be run.
"""
import os
import time
from copy import deepcopy

from absl.flags import FLAGS

from UIAgents.Agent_RPA.agent_rpa import Agent_RPA
from UIAgents.Agent_RPA.utils.rpa_bank_utils import RPABank
from .agent_rpa_builder import RPA_Builder_Agent
from .env_operation import EnvOperation
from .episode_runner import run_episode
from .run_tasks_react import run_task
from .utils import JSON_models
from .utils.JSON_API import get_llm_wrapper
from .utils.JSON_models import ExpResultLine
from .utils.agent_utils import print_with_color, record_exp_result
from .utils.traj_utils import ReactTrajBank, RPAExecTrajBank


def run_rpa_verification(
  env_op: EnvOperation,
  agent: Agent_RPA,
  log_path: str
) -> JSON_models.RPAExecTraj:
  """Executes the agent in RPA mode on the current task and annotates the resulting trajectory.

  The agent is switched to RPA mode, one episode is run through `run_episode`,
  and a textual feedback message describing the benchmark outcome is appended
  to the trajectory's execution result. Outside of test mode (and when the
  agent is not 'agent_react') the agent's concluder is also invoked and its
  conclusion/reflection are stored on the trajectory.

  Args:
    env_op: Environment wrapper; its `task_type` and `task` describe the current task.
    agent: The agent that executes the episode.
    log_path: Directory passed to `run_episode` and to the concluder for logging.

  Returns:
    The RPA execution trajectory taken from the episode result, with
    `exec_result.exec_feedback`, `success_score`, `success` and (optionally)
    `conclusion` / `reflection` filled in.
  """
  agent.rpa_mode = True
  
  time.sleep(3)
  # Mode: RPA Verification
  print("========================================================================")
  print(f'Running rpa_verification for {env_op.task_type} with goal "{env_op.task}"')
  
  # Attempt the task by executing the RPA code (RPAExec / RPAVerification).
  episode_result = run_episode(env_op, agent, log_path)
  verdict = "Task Successful ✅" if episode_result.is_success else "Task Failed ❌"
  print(f'{verdict}; {env_op.task}\n')
  
  rpa_exec_traj = episode_result.agent_traj
  
  # Build the feedback text that is later handed to the RPA builder. It combines
  # the benchmark score with whether the code emitted its end marker.
  score = episode_result.task_successful
  finished = episode_result.agent_done
  if score == 1.0:
    feedback = (
      "The benchmark indicates task success, and the code executed successfully."
      if finished else
      "The benchmark indicates task success, but the code did not execute smoothly and did not terminate normally. Please carefully review the execution history to identify which part of the code did not perform as expected."
    )
  elif score == 0.0:
    feedback = (
      "The benchmark task failed, but the code output an end marker. Please carefully review the execution history to identify which part of the code did not execute as expected."
      if finished else
      "Benchmark task failed without an end marker in the code output. Please review the execution history to identify which part of the code caused the failure."
    )
  else:
    partial_prefix = f"The benchmark judged the task as partially completed (approximately {score * 100:.0f}%), "
    feedback = partial_prefix + (
      "but the code output an end marker."
      if finished else
      "and the code did not execute smoothly and did not terminate normally."
    )
  
  # Append to any feedback already present on the execution result.
  exec_result = rpa_exec_traj.exec_result
  earlier_feedback = exec_result.exec_feedback
  exec_result.exec_feedback = feedback if earlier_feedback is None else earlier_feedback + feedback
  
  rpa_exec_traj.success_score = episode_result.success_score
  rpa_exec_traj.success = episode_result.is_success
  
  # The concluder is skipped in test mode and for the plain ReAct agent.
  if FLAGS.test_rpa_mode or FLAGS.agent_name == 'agent_react':
    return rpa_exec_traj
  
  agent.reflection_history = []
  concluder_result = agent.Concluder_Agent(goal=env_op.task, log_task_path=log_path, episode_results=episode_result)
  rpa_exec_traj.conclusion = concluder_result.episode_conclusion
  rpa_exec_traj.reflection = concluder_result.reflection
  return rpa_exec_traj


def _task_num_from_key(verified_task: str) -> int:
  """Returns the task number encoded at the end of a '<goal>_<num>' key.

  `run` builds verified-task keys as f'{task_goal}_{cur_task_num}', so the
  number is everything after the last underscore. Reading only the final
  character would record task 10 as task 0 and task 12 as task 2.
  """
  return int(verified_task.rsplit('_', 1)[-1])


def run(
  env_op: EnvOperation,
  agent: Agent_RPA,
  rpa_bank: RPABank,
  explore_rpa_banks: list[RPABank],
  react_traj_bank: ReactTrajBank,
  task_templates: dict,
  cnt_task_type: int
):
  """Runs either the employment (test) phase or the building phase for one task type.

  Which phase runs is decided by `FLAGS.test_rpa_mode`:

  * Employment phase: every task index in `FLAGS.to_test_tasks` is executed
    with the RPA stored in `rpa_bank`, and one row is written to
    `test_result.csv`.
  * Building phase: for up to `FLAGS.num_tasks_to_explore` tasks and up to
    `FLAGS.max_attempts_per_task` attempts each, ReAct trajectories are
    collected (or reused from `react_traj_bank`), RPA code is generated by the
    RPA builder, and from the second attempt on the candidate RPA is verified
    against all explored tasks. Per-task rows and a final summary row are
    written to `building_result.csv`, and a successful candidate is merged
    into `rpa_bank`.

  Args:
    env_op: Environment wrapper; provides `task_type`, `task` and `reset`.
    agent: The agent used for RPA execution, ReAct runs and analysis.
    rpa_bank: The RPA bank used for testing and updated after building.
    explore_rpa_banks: One bank per explored task number (indexed by
      `task_num - 1`); each receives the candidate bank after its task.
    react_traj_bank: Store of previously recorded ReAct trajectories.
    task_templates: Mapping from task type to its task template.
    cnt_task_type: Running number of this task type, recorded as "Num" in the
      test result row.

  Returns:
    None. All results are persisted through the banks and CSV records.
  """
  
  task_type = env_op.task_type
  task_type_log_path = os.path.join(FLAGS.log_folder_exp, task_type)
  os.makedirs(task_type_log_path, exist_ok=True)
  
  max_task_num = FLAGS.num_tasks_to_explore  # Number of tasks to explore per task_type
  max_attempts = FLAGS.max_attempts_per_task  # Max attempts to build RPA per task
  
  rpa_exec_traj_bank = RPAExecTrajBank()
  
  if FLAGS.test_rpa_mode:  # Employment phase
    print_with_color("\n======================== Employment Phase ==============================", 'blue')
    exp_result_csv = os.path.join(FLAGS.log_folder_exp, "test_result.csv")
    test_result = {"Num": cnt_task_type, "Task Type": task_type, "Task Template": task_templates[task_type]}
    action_info = {}
    for task_idx in FLAGS.to_test_tasks:
      log_path = os.path.join(task_type_log_path, f'Task_{task_idx}')
      env_op.reset(task_idx, log_path)
      # Run rpa verification for the task
      agent.rpa_bank = rpa_bank
      agent.record_token = JSON_models.RecordToken(file_path=FLAGS.log_folder_exp, task_type=task_type,
                                                   task_num=f'Task {task_idx}', stage='Test')
      rpa_exec_traj = run_rpa_verification(env_op=env_op, agent=agent, log_path=log_path)
      rpa_exec_traj_bank.save_temp(save_path=log_path, rpa_exec_traj=rpa_exec_traj)
      test_result.update({f"Test_{task_idx}": int(rpa_exec_traj.success)})
      
      # Record which helper calls the RPA code for this task type relies on.
      rpa_code = rpa_bank.rpa_dict[task_type]['rpa_code']
      action_info[f"{task_idx}_has_ask_mllm"] = 1 if 'env_op.ask_mllm' in rpa_code else 0
      action_info[f"{task_idx}_has_get_ui_info"] = 1 if 'env_op.get_ui_content' in rpa_code else 0
    
    # Per-test success columns come first, followed by the action-info columns.
    test_result.update(action_info)
    record_exp_result(exp_result_csv, test_result)
  else:
    print_with_color("\n===================== Building Phase ============================", 'blue')
    exploration_path = os.path.join(task_type_log_path, 'Building')  # TaskType/Building
    exp_result_csv = os.path.join(FLAGS.log_folder_exp, "building_result.csv")
    
    rpa_bank_candidate = RPABank(load_local_bank=False)  # get an empty rpa_bank to store candidate rpa
    RPABuilder_Agent = RPA_Builder_Agent(get_llm_wrapper(FLAGS.builder_llm))
    
    cnt_generate_rpa = 0
    cnt_fetch_info = 0
    encountered_task_goals: list[str] = []
    
    def record_encountered_goal(goal: str):
      # Keep each non-empty goal once, in first-seen order.
      if goal and goal not in encountered_task_goals:
        encountered_task_goals.append(goal)
    
    for task_num in range(1, max_task_num + 1):
      cur_task_num = task_num
      task_num_path = os.path.join(exploration_path, f'task_{task_num}')  # TaskType/Building/task_1
      env_op.reset(task_num, task_num_path)
      task_goal = env_op.task
      record_encountered_goal(task_goal)
      print(task_type, task_goal)
      # Snapshots taken before this task's attempts; used for rollback when a
      # new candidate verifies fewer tasks than the previous one.
      rpa_bank_candidate_temp = deepcopy(rpa_bank_candidate)
      rpa_exec_traj_temp = deepcopy(rpa_exec_traj_bank)
      exp_result_line = ExpResultLine(task_type=task_type, task_num=f'task_{task_num}', task_goal=task_goal)
      flag_all_success = False
      
      # Try up to max_attempts; exit when a verified RPA is obtained.
      for attempt_idx in range(max_attempts):
        attempt_path = os.path.join(task_num_path, f'attempt_{attempt_idx + 1}')
        flag_init = True  # whether the env must be re-initialised; decided by the breakpoint analyzer
        rpa_exec_traj = None
        FLAGS.cur_attempt_cnt = attempt_idx + 1
        
        print_with_color(f"\n======== Task_num: {task_num}, Attempt: {attempt_idx + 1} ========\n", 'blue')
        
        # Verify all tasks.
        # If rpa_bank_candidate has this task_type, verify RPA on all explored tasks.
        # On the first attempt this is skipped and ReAct is run to generate new code.
        if attempt_idx and task_type in rpa_bank_candidate.rpa_dict:
          print_with_color(f"\n======== Verify 1-to-{max_task_num} tasks ========\n", 'blue')
          
          agent.rpa_bank = rpa_bank_candidate  # use to run verification
          fit_path = os.path.join(attempt_path, 'RPAFitExec')  # task_1/attempt_2/RPAFitExec
          flag_all_success = True
          rpa_exec_traj = None
          # The current task is verified first, then the remaining ones in order.
          to_verify_tasks = [task_num] + [n for n in range(1, max_task_num + 1) if n != task_num]
          cur_verified_tasks = []
          
          for j in to_verify_tasks:
            print_with_color(f"\n======== Verify task {j}\n", 'blue')
            log_path = os.path.join(fit_path, f'task_{j}')
            env_op.reset(j, log_path)
            task_goal = env_op.task
            record_encountered_goal(task_goal)
            cur_task_num = j
            
            if f'{task_goal}_{cur_task_num}' in rpa_bank_candidate.rpa_dict[task_type]['verified_tasks']:
              print_with_color(f'Task has been verified: {task_goal}\nTry next task.\n', 'green')
              continue
            
            agent.record_token = JSON_models.RecordToken(file_path=FLAGS.log_folder_exp, task_type=task_type,
                                                         task_num=f'Task {cur_task_num}', stage='Verification')
            rpa_exec_traj = run_rpa_verification(env_op=env_op, agent=agent, log_path=log_path)
            cur_rpa = rpa_bank_candidate.rpa_dict[task_type]['rpa_code']
            rpa_exec_traj_bank.add_rpa_exec_traj(rpa_code=cur_rpa, rpa_exec_traj=rpa_exec_traj)
            rpa_exec_traj_bank.save_temp(save_path=log_path, rpa_exec_traj=rpa_exec_traj)
            if not rpa_exec_traj.success:
              # Stop at the first failing task; env_op and cur_task_num stay on it
              # so the fix-up ReAct below continues from this task.
              flag_all_success = False
              break
            cur_verified_tasks.append(f'{task_goal}_{cur_task_num}')
          
          rpa_bank_candidate.update_verified_tasks(task_type=task_type, verified_tasks=cur_verified_tasks)
          if task_type in rpa_bank_candidate_temp.rpa_dict:
            prev_verified_num = rpa_bank_candidate_temp.rpa_dict[task_type].get("verified_tasks_num")
            if prev_verified_num and \
              rpa_bank_candidate.rpa_dict[task_type]['verified_tasks_num'] < prev_verified_num:
              # Rollback: the snapshot verified more tasks than the new candidate.
              rpa_bank_candidate = deepcopy(rpa_bank_candidate_temp)
              rpa_exec_traj_bank = deepcopy(rpa_exec_traj_temp)
          print_with_color('\n---------------\nVerified Tasks:', 'green')
          for m, task in enumerate(cur_verified_tasks):
            print_with_color(f"{m} {task}", 'green')
          
          if flag_all_success:
            print_with_color(f"\n======== Verify 1-to-{max_task_num} tasks all success ========\n", 'blue')
            # record exp result
            for verified_task in cur_verified_tasks:
              setattr(exp_result_line, f"task_{_task_num_from_key(verified_task)}", '1')
            rpa_bank_candidate.save_temp(task_type=task_type, save_path=fit_path)  # Save current RPA to log dir
            break  # exit attempt
          elif attempt_idx == max_attempts - 1:
            # Last attempt failed verification: record what was verified and move to the next task.
            for verified_task in cur_verified_tasks:
              setattr(exp_result_line, f"task_{_task_num_from_key(verified_task)}", '1')
            break  # t++
          ## -----end: verify t-to-1 tasks
          
          ## -----start: call MLLM to evaluate the current page
          exec_evaluator_path = os.path.join(attempt_path, 'ExecEvaluatorAgent')
          exec_evaluator_output = agent.Breakpoint_Analyzer_Agent(rpa_exec_traj=rpa_exec_traj,
                                                             log_path=exec_evaluator_path)
          exec_continue = exec_evaluator_output.to_continue  # exec_continue = y or n
          print_with_color(f"Breakpoint_Analyzer_Agent result: {exec_continue}", 'blue')
          ## -----end: call MLLM to evaluate the current page
          
          if exec_continue.lower() == 'y':
            flag_init = False
            agent.completed_tasks = [f'{exec_evaluator_output.completed_tasks}']
        
        ## -----start: ReAct + RPABuilder(with VerificationResult)
        log_react_path = os.path.join(attempt_path, 'ReAct')
        list_react_traj = react_traj_bank.get_react_traj(task_type, f'{env_op.task}_{cur_task_num}')
        flag_react_traj_exists = list_react_traj is not None
        # As rpa verification failed, run Fix ReAct
        if not flag_init:
          print_with_color("Need to continue ReAct from cur_page.", 'yellow')
          env_op.done = False
          agent.record_token = JSON_models.RecordToken(file_path=FLAGS.log_folder_exp, task_type=task_type,
                                                       task_num=f'Task {cur_task_num}', stage='Fix ReAct')
          _, fix_react_traj = run_task(env_op=env_op,
                                       agent=agent, to_init_task=False, log_path=log_react_path, react_round=1)
          fix_react_traj = fix_react_traj.pop(0)
          if fix_react_traj.success:
            rpa_exec_traj.fix_evaluator_analysis = exec_evaluator_output.observation
            rpa_exec_traj.fix_react_traj = fix_react_traj.traj
            flag_react_traj_exists = True
            list_react_traj = []
        
        if flag_react_traj_exists:
          print_with_color("ReAct Traj Exists, skip running ReAct Agent.",
                           'green')  # Also printed when react continues after rpa exec
          react_traj_bank.save_temp(list_react_traj=list_react_traj, save_path=log_react_path)
        else:
          print_with_color("ReAct Traj Doesn't Exist, running ReAct Agent.", 'yellow')
          print_with_color("Need to restart ReAct.", 'yellow')
          
          agent.record_token = JSON_models.RecordToken(file_path=FLAGS.log_folder_exp, task_type=task_type,
                                                       task_num=f'Task {cur_task_num}', stage='ReAct')
          _, list_react_traj = run_task(env_op=env_op,
                                        agent=agent, to_init_task=True, log_path=log_react_path,
                                        react_round=FLAGS.reflection_rounds + 1)
          react_traj_bank.save_temp(list_react_traj=list_react_traj, save_path=log_react_path)
          # If the ReAct run failed for this task, give up on it and move to the next task.
          if not list_react_traj[-1].success:
            break  # t++, i=1
          
          # Mark rounds: '/' = not run, '0' = failed round, '1' = successful (last) round.
          exp_result_line.round_0 = '/'
          exp_result_line.round_1 = '/'
          exp_result_line.round_2 = '/'
          setattr(exp_result_line, f'round_{len(list_react_traj) - 1}', '1')  # successful round
          for cnt in range(0, len(list_react_traj) - 1):
            setattr(exp_result_line, f'round_{cnt}', '0')  # failed round
          
          # Store traj in the ReAct bank only when ReAct succeeded
          # and the task was started from scratch.
          if flag_init:
            react_traj_bank.add_react_traj(task_type, f'{env_op.task}_{cur_task_num}', list_react_traj)
          if FLAGS.update_react_trajs_bank:
            react_traj_bank.save()
        
        # Batch Action Translation (before RPA Builder)
        print_with_color("\n\nPerforming Batch Action Translation", 'blue')
        react_trajs = react_traj_bank.get_last_traj(task_type)
        if react_trajs:
          translated_trajs = agent.batch_translate_actions(
            react_trajs=react_trajs,
            log_path=os.path.join(attempt_path, 'ActionTranslation')
          )
        else:
          translated_trajs = []
        
        # RPA Builder(with VerificationResult)
        print_with_color("\n\nRunning RPABuilder_Agent", 'blue')
        rpa_builder_path = os.path.join(attempt_path, 'RPABuilder')  # task_1/attempt_1/RPABuilder
        RPABuilder_Agent.record_token = JSON_models.RecordToken(file_path=FLAGS.log_folder_exp, task_type=task_type,
                                                                  task_num=f'Task {cur_task_num}',
                                                                  stage='RPA Builder')
        rpa_info, cur_cnt_fetch_info = RPABuilder_Agent.generate_rpa_code(log_task_path=rpa_builder_path,
                                                            task_type=task_type,
                                                            task_template=task_templates[task_type],
                                                            list_react_traj=translated_trajs,
                                                            pre_rpa_exec_traj=rpa_exec_traj,
                                                            encountered_task_goals=encountered_task_goals)
        cnt_generate_rpa += 1
        cnt_fetch_info += cur_cnt_fetch_info
        # save rpa_candidate
        rpa_bank_candidate.add_rpa(rpa_info)
        rpa_bank_candidate.update_based_on_task(task_type, task_num)
        rpa_bank_candidate.save_temp(task_type=task_type, save_path=rpa_builder_path)  # Save current RPA to log dir
        ## -----end: ReAct + RPABuilder
      
      if rpa_bank_candidate.rpa_dict.get(task_type):
        if 'env_op.ask_mllm' in rpa_bank_candidate.rpa_dict[task_type]['rpa_code']:
          exp_result_line.has_ask_mllm = 1
        if 'env_op.get_ui_content' in rpa_bank_candidate.rpa_dict[task_type]['rpa_code']:
          exp_result_line.has_get_ui_info = 1
        exp_result_line.based_on_task = rpa_bank_candidate.rpa_dict[task_type].get('based_on_task', '0')
      
      # save record to csv
      record_exp_result(exp_result_csv, exp_result_line.dict())
      
      cur_rpa_bank = explore_rpa_banks[task_num - 1]
      cur_rpa_bank.merge_from(rpa_bank_candidate)
      cur_rpa_bank.save_temp(save_path=FLAGS.log_folder_exp, file_name=f'temp_rpa_{task_num}.json')
      
      if flag_all_success:
        break  # exit building phase
    
    # Final summary row: task_1..task_5 start as 'abandon' and are overwritten
    # with '1' for every task the final candidate has verified.
    exp_result_line = ExpResultLine(task_type=task_type, task_num='Final', task_goal='-', task_1='abandon',
                                    task_2='abandon', task_3='abandon', task_4='abandon', task_5='abandon',
                                    cnt_fetch_info=(cnt_fetch_info / cnt_generate_rpa) if cnt_generate_rpa > 0 else 0,)
    if task_type in rpa_bank_candidate.rpa_dict:
      # -----start: recording
      final_entry = rpa_bank_candidate.rpa_dict[task_type]
      verified_tasks = final_entry.get('verified_tasks', [])
      exp_result_line.based_on_task = final_entry['based_on_task']
      for verified_task in verified_tasks:
        setattr(exp_result_line, f"task_{_task_num_from_key(verified_task)}", '1')
      if 'env_op.ask_mllm' in final_entry['rpa_code']:
        exp_result_line.has_ask_mllm = 1
      if 'env_op.get_ui_content' in final_entry['rpa_code']:
        exp_result_line.has_get_ui_info = 1
      # -----end: recording
      
      for m, task in enumerate(verified_tasks):
        print_with_color(f"{m} {task}", 'blue')
      
      rpa_bank.merge_from(rpa_bank_candidate)  # update rpa_bank
      rpa_bank.save_temp(save_path=FLAGS.log_folder_exp,
                           file_name='temp_rpa.json')  # Also save updated rpa_bank in local experiment folder
      
      if FLAGS.update_rpa_bank:  # Update the complete rpa bank
        rpa_bank.save()
      print_with_color(f"====================== End -- Task Type : {task_type} =========================", 'blue')
      print_with_color('Successfully build the RPA Code', 'blue')
    else:
      print_with_color(f"====================== End -- Task Type : {task_type} =========================", 'blue')
      print_with_color("Oh no, failed to create the RPA Code.", 'red')
    
    record_exp_result(exp_result_csv, exp_result_line.dict())
