"""
Optionally use react trajs from the library as react trajs; optionally update the library with current ReAct trajs.
If not using the library, ReAct Agent will be run.
"""
import os
import time
import json
from copy import deepcopy

from absl.flags import FLAGS

from UIAgents.Agent_RPA.agent_rpa import Agent_RPA
from UIAgents.Agent_RPA.utils.rpa_bank_utils import RPABank
from .agent_rpa_builder import RPA_Builder_Agent
from .env_operation import EnvOperation
from .episode_runner import run_episode
from .run_tasks_react import run_task
from .utils import JSON_models
from .utils.llm_client import get_llm_wrapper
from .utils.JSON_models import ExpResultLine
from .utils.agent_utils import print_with_color, record_exp_result
from .utils.traj_utils import ReactTrajBank, RPAExecTrajBank


def _build_exec_feedback(task_successful, agent_done) -> str:
  """Builds the feedback sentence describing benchmark outcome vs. code termination.

  Args:
    task_successful: Benchmark score; 1.0 means success, 0.0 means failure,
      anything else is reported as a partial completion percentage.
    agent_done: Truthy when the executed code emitted its end marker.

  Returns:
    The feedback text handed to later stages together with the exec traj.
  """
  if task_successful == 1.0:
    if agent_done:
      return "The benchmark indicates task success, and the code executed successfully."
    return "The benchmark indicates task success, but the code did not execute smoothly and did not terminate normally. Please carefully review the execution history to identify which part of the code did not perform as expected."

  if task_successful == 0.0:
    if agent_done:
      return "The benchmark task failed, but the code output an end marker. Please carefully review the execution history to identify which part of the code did not execute as expected."
    return "Benchmark task failed without an end marker in the code output. Please review the execution history to identify which part of the code caused the failure."

  partial_prefix = f"The benchmark judged the task as partially completed (approximately {task_successful * 100:.0f}%), "
  if agent_done:
    return partial_prefix + "but the code output an end marker."
  return partial_prefix + "and the code did not execute smoothly and did not terminate normally."


def run_rpa_verification(
  env_op: EnvOperation,
  agent: Agent_RPA,
  log_path: str
) -> JSON_models.RPAExecTraj:
  """Runs the current task once in RPA mode and annotates the resulting exec traj.

  The agent is switched to RPA mode, one episode is executed through
  `run_episode`, and the returned trajectory is enriched with a feedback
  message, the environment success flag and (outside test mode, for agents
  other than 'agent_react') a conclusion produced by the agent's concluder.

  Args:
    env_op: Environment wrapper; its `task_type` and `task` describe the task
      that is currently loaded.
    agent: The RPA agent that executes the task.
    log_path: The path to save the log files of this run.

  Returns:
    The RPA execution trajectory of the episode, with `exec_result.exec_feedback`,
    `success` and optionally `conclusion` filled in.
  """
  agent.rpa_mode = True

  # Fixed pause before the episode starts.
  # NOTE(review): presumably gives the environment time to settle after reset — confirm.
  time.sleep(3)

  # Mode: RPA Verification
  print("========================================================================")
  print(f'Running rpa_verification for {env_op.task_type} with goal "{env_op.task}"')

  # Try to complete the task by executing the RPA code.
  episode_result = run_episode(env_op, agent, log_path)
  print(f'{"Task Successful ✅" if episode_result.env_success else "Task Failed ❌"}; {env_op.task}\n')

  rpa_exec_traj = episode_result.agent_traj

  # Append the benchmark/termination feedback to whatever feedback the execution already produced.
  env_and_exec_feedback = _build_exec_feedback(episode_result.task_successful, episode_result.agent_done)
  exec_result = rpa_exec_traj.exec_result
  if exec_result.exec_feedback is None:
    exec_result.exec_feedback = env_and_exec_feedback
  else:
    exec_result.exec_feedback = exec_result.exec_feedback + "\n" + env_and_exec_feedback

  rpa_exec_traj.success = episode_result.env_success

  # The conclusion is only generated while building, and never for the plain ReAct agent.
  if not FLAGS.test_rpa_mode and FLAGS.agent_name != 'agent_react':
    agent.reflection_history = []
    concluder_result = agent.Concluder_Agent(goal=env_op.task, log_task_path=log_path, episode_results=episode_result)
    rpa_exec_traj.conclusion = concluder_result.episode_conclusion

  return rpa_exec_traj


def _mark_verified_tasks(exp_result_line, verified_tasks) -> None:
  """Sets `task_<n>` to '1' on the result line for every verified task key.

  `run` builds these keys as '<goal>_<task number>'. The number is taken from
  the text after the last underscore, so task numbers with more than one
  digit are parsed correctly.
  """
  for verified_task in verified_tasks:
    verified_num = int(verified_task.rsplit('_', 1)[-1])
    setattr(exp_result_line, f"task_{verified_num}", '1')


def _collect_encountered_task_goals(task_list, task_num: int) -> list:
  """Reads the distinct, non-empty "intent" goals of tasks 1..task_num from their config files.

  Goals are read from `config_files/<task_id>.json` so the live environment
  does not have to be reset; tasks without a config file are skipped.
  """
  goals = []
  for seen_idx in range(1, task_num + 1):
    config_file = f"config_files/{task_list[seen_idx]}.json"
    if not os.path.exists(config_file):
      continue
    with open(config_file, encoding="utf-8") as f:
      goal = json.load(f)["intent"]
    if goal and goal not in goals:
      goals.append(goal)
  return goals


def run(
  task_list: list[int],
  env_op: EnvOperation,
  agent: Agent_RPA,
  rpa_bank: RPABank,
  explore_rpa_banks: list[RPABank],
  react_traj_bank: ReactTrajBank,
  task_templates: dict,
  cnt_task_type: int
):
  """Runs the e2e system on the tasks of one task type.

  With `FLAGS.test_rpa_mode` set (employment phase), every task in `task_list`
  is executed with the RPA code stored in `rpa_bank` and one row is appended
  to `test_result.csv`.

  Otherwise (building phase) up to `FLAGS.num_tasks_to_explore` tasks are
  explored. For each task, up to `FLAGS.max_attempts_per_task` attempts are
  made: from the second attempt on, the candidate RPA code is verified on all
  explored tasks; on failure a ReAct run (fresh or continued from the current
  page) provides a trajectory and the RPA builder generates new code. Results
  are appended to `building_result.csv` and, if a candidate exists at the end,
  it is merged into `rpa_bank`.

  Args:
    task_list: Task ids. The employment phase iterates all entries from index
      0; the building phase reads indices 1..`FLAGS.num_tasks_to_explore`, so
      the list must be long enough for that range.
    env_op: The environment to run it on.
    agent: The agent to run the task.
    rpa_bank: The rpa bank to use; updated in place in the building phase.
    explore_rpa_banks: One bank per explored task (index `task_num - 1`); the
      candidate bank is merged into it and saved after each explored task.
    react_traj_bank: Store of ReAct trajectories to reuse and update.
    task_templates: Not used by this function.
    cnt_task_type: Running number of the task type, recorded as "Num" in the
      test result row.

  Returns:
    None. Results are written to CSV/log files and to the banks passed in.
  """

  task_type = env_op.task_type
  task_type_log_path = os.path.join(FLAGS.log_folder_exp, task_type)
  os.makedirs(task_type_log_path, exist_ok=True)

  max_task_num = FLAGS.num_tasks_to_explore  # Number of tasks to explore per task_type
  max_attempts = FLAGS.max_attempts_per_task  # Max attempts to build RPA per task

  rpa_exec_traj_bank = RPAExecTrajBank()

  if FLAGS.test_rpa_mode:  # Employment phase
    print_with_color("\n======================== Employment Phase ==============================", 'blue')
    exp_result_csv = os.path.join(FLAGS.log_folder_exp, "test_result.csv")
    test_result = {"Num": cnt_task_type, "Task Type": task_type, "Task Template": task_type}
    action_info = {}
    for task_num, task_id in enumerate(task_list):
      log_path = os.path.join(task_type_log_path, f'Task_{task_num}')
      env_op.reset(task_id, log_path)
      # Run rpa verification for the task
      agent.rpa_bank = rpa_bank
      agent.record_token = JSON_models.RecordToken(file_path=FLAGS.log_folder_exp, task_type=task_type,
                                                   task_num=f'Task {task_num}', stage='Test')
      rpa_exec_traj = run_rpa_verification(env_op=env_op, agent=agent, log_path=log_path)
      rpa_exec_traj_bank.save_temp(save_path=log_path, rpa_exec_traj=rpa_exec_traj)
      test_result[f"Test_{task_num}"] = int(rpa_exec_traj.success)

      # Record which helper calls the RPA code relies on (1 = present, 0 = absent).
      rpa_code = rpa_bank.rpa_dict[task_type]['rpa_code']
      action_info[f"{task_num}_has_ask_mllm"] = int('env_op.ask_mllm' in rpa_code)
      action_info[f"{task_num}_has_get_ui_info"] = int('env_op.get_ui_content' in rpa_code)

    test_result.update(action_info)
    record_exp_result(exp_result_csv, test_result)
  else:
    print_with_color("\n===================== Building Phase ============================", 'blue')
    exploration_path = os.path.join(task_type_log_path, 'Building')  # TaskType/Building
    exp_result_csv = os.path.join(FLAGS.log_folder_exp, "building_result.csv")

    rpa_bank_candidate = RPABank(load_local_bank=False)  # get an empty rpa_bank to store candidate rpa
    RPABuilder_Agent = RPA_Builder_Agent(get_llm_wrapper(FLAGS.builder_llm, enable_logging=FLAGS.enable_llm_logging))

    cnt_generate_rpa = 0
    cnt_fetch_info = 0

    for task_num in range(1, max_task_num + 1):
      task_id = task_list[task_num]
      cur_task_num = task_num
      task_num_path = os.path.join(exploration_path, f'task_{task_num}')  # TaskType/Building/task_1
      env_op.reset(task_id, task_num_path)
      task_goal = env_op.task
      print(task_type, task_goal)
      # Snapshots taken before this task is explored; used to roll back when a new RPA verifies fewer tasks.
      rpa_bank_candidate_temp = deepcopy(rpa_bank_candidate)
      rpa_exec_traj_temp = deepcopy(rpa_exec_traj_bank)
      exp_result_line = ExpResultLine(task_type=task_type, task_num=f'task_{task_num}', task_goal=task_goal)
      flag_all_success = False

      # Try up to max_attempts; exit when a verified RPA is obtained
      for attempt_idx in range(max_attempts):
        attempt_path = os.path.join(task_num_path, f'attempt_{attempt_idx + 1}')
        flag_init = True  # whether ReAct must start from a freshly initialised env (decided by the breakpoint analysis)
        rpa_exec_traj = None
        FLAGS.cur_attempt_cnt = attempt_idx + 1

        print_with_color(f"\n======== Task_num: {task_num}, Attempt: {attempt_idx + 1} ========\n", 'blue')

        # If rpa_bank_candidate has this task_type, verify RPA on all explored tasks.
        # The first attempt skips verification and goes straight to ReAct + RPA builder.
        if attempt_idx > 0 and task_type in rpa_bank_candidate.rpa_dict:
          print_with_color(f"\n======== Verify 1-to-{max_task_num} tasks ========\n", 'blue')

          agent.rpa_bank = rpa_bank_candidate  # use to run verification
          fit_path = os.path.join(attempt_path, 'RPAFitExec')  # task_1/attempt_1/RPAFitExec
          flag_all_success = True
          rpa_exec_traj = None
          # Verify the current task first, then every other task number.
          to_verify_tasks = [task_num] + [other for other in range(1, max_task_num + 1) if other != task_num]
          cur_verified_tasks = []

          for j in to_verify_tasks:
            print_with_color(f"\n======== Verify task {j}\n", 'blue')
            log_path = os.path.join(fit_path, f'task_{j}')
            task_id = task_list[j]
            env_op.reset(task_id, log_path)
            task_goal = env_op.task
            cur_task_num = j

            if f'{task_goal}_{cur_task_num}' in rpa_bank_candidate.rpa_dict[task_type]['verified_tasks']:
              print_with_color(f'Task has been verified: {task_goal}\nTry next task.\n', 'green')
              continue

            agent.record_token = JSON_models.RecordToken(file_path=FLAGS.log_folder_exp, task_type=task_type,
                                                         task_num=f'Task {cur_task_num}', stage='Verification')
            rpa_exec_traj = run_rpa_verification(env_op=env_op, agent=agent, log_path=log_path)
            cur_rpa = rpa_bank_candidate.rpa_dict[task_type]['rpa_code']
            rpa_exec_traj_bank.add_rpa_exec_traj(rpa_code=cur_rpa, rpa_exec_traj=rpa_exec_traj)
            rpa_exec_traj_bank.save_temp(save_path=log_path, rpa_exec_traj=rpa_exec_traj)
            if not rpa_exec_traj.success:
              flag_all_success = False
              break  # stop at the first failing task; cur_task_num now names it
            cur_verified_tasks.append(f'{task_goal}_{cur_task_num}')

          rpa_bank_candidate.update_verified_tasks(task_type=task_type, verified_tasks=cur_verified_tasks)
          # Compare verified task counts and keep the RPA code with the most verified tasks
          if task_type in rpa_bank_candidate_temp.rpa_dict:
            prev_verified_num = rpa_bank_candidate_temp.rpa_dict[task_type].get("verified_tasks_num", 0)
            curr_verified_num = rpa_bank_candidate.rpa_dict[task_type].get("verified_tasks_num", 0)
            if prev_verified_num > curr_verified_num:
              # Previous RPA has more verified tasks, rollback to previous version
              print_with_color(f"\n[Rollback] Previous RPA verified {prev_verified_num} tasks, current RPA verified {curr_verified_num} tasks. Keeping previous version.", 'yellow')
              rpa_bank_candidate = deepcopy(rpa_bank_candidate_temp)
              rpa_exec_traj_bank = deepcopy(rpa_exec_traj_temp)
            elif curr_verified_num > prev_verified_num:
              # Current RPA has more verified tasks, keep current version
              print_with_color(f"\n[Keep] Current RPA verified {curr_verified_num} tasks, previous RPA verified {prev_verified_num} tasks. Keeping current version.", 'green')
            else:
              # Same number of verified tasks, keep current version (last generated)
              print_with_color(f"\n[Keep] Both RPA versions verified {curr_verified_num} tasks. Keeping current version.", 'green')
          print_with_color('\n---------------\nVerified Tasks:', 'green')
          for m, task in enumerate(cur_verified_tasks):
            print_with_color(f"{m} {task}", 'green')

          if flag_all_success:
            print_with_color(f"\n======== Verify 1-to-{max_task_num} tasks all success ========\n", 'blue')
            # record exp result
            _mark_verified_tasks(exp_result_line, cur_verified_tasks)
            rpa_bank_candidate.save_temp(task_type=task_type, save_path=fit_path)  # Save current RPA to log dir
            break  # exit attempt
          elif attempt_idx == max_attempts - 1:
            # Last attempt failed verification: record what was verified and move on to the next task.
            _mark_verified_tasks(exp_result_line, cur_verified_tasks)
            break
          ## -----end: verify tasks

          ## -----start: call MLLM to evaluate the current page
          exec_evaluator_path = os.path.join(attempt_path, 'ExecEvaluatorAgent')
          exec_evaluator_output = agent.Breakpoint_Analyzer_Agent(rpa_exec_traj=rpa_exec_traj,
                                                             log_path=exec_evaluator_path)
          exec_continue = exec_evaluator_output.to_continue  # exec_continue = y or n
          print_with_color(f"Breakpoint_Analyzer_Agent result: {exec_continue}", 'blue')
          ## -----end: call MLLM to evaluate the current page

          if exec_continue.lower() == 'y':
            flag_init = False
            agent.completed_tasks = [f'{exec_evaluator_output.completed_tasks}']

        ## -----start: ReAct + RPABuilder(with VerificationResult)
        log_react_path = os.path.join(attempt_path, 'ReAct')
        list_react_traj = react_traj_bank.get_react_traj(task_type, f'{env_op.task}_{cur_task_num}')
        flag_react_traj_exists = list_react_traj is not None
        # As rpa verification failed, run Fix ReAct
        if not flag_init:
          print_with_color("Need to continue ReAct from cur_page.", 'yellow')
          env_op.done = False
          agent.record_token = JSON_models.RecordToken(file_path=FLAGS.log_folder_exp, task_type=task_type,
                                                       task_num=f'Task {cur_task_num}', stage='Fix ReAct')
          _, list_fix_react_traj = run_task(env_op=env_op,
                                            agent=agent, to_init_task=False, log_path=log_react_path, react_round=1)
          if list_fix_react_traj[-1].success:
            rpa_exec_traj.fix_evaluator_analysis = f"Observation: {exec_evaluator_output.observation}\nCode Diagnosis: {exec_evaluator_output.code_diagnosis}"

            # Batch translate fix_react_traj actions before passing to RPA Builder
            if FLAGS.use_action_translator:
              print_with_color("🔧 Batch translating fix_react_traj actions...", 'cyan')
              list_fix_react_traj = agent.batch_translate_actions(
                react_trajs=list_fix_react_traj,
                log_path=log_react_path
              )

            rpa_exec_traj.fix_react_traj = list_fix_react_traj

            # The fix trajectory travels with rpa_exec_traj, so no fresh ReAct run is needed.
            flag_react_traj_exists = True
            list_react_traj = []

        if flag_react_traj_exists:
          print_with_color("ReAct Traj Exists, skip running ReAct Agent.",
                           'green')  # Also printed when react continues after rpa exec
          react_traj_bank.save_temp(list_react_traj=list_react_traj, save_path=log_react_path)
        else:
          print_with_color("ReAct Traj Doesn't Exist, running ReAct Agent.", 'yellow')
          print_with_color("Need to restart ReAct.", 'yellow')

          agent.record_token = JSON_models.RecordToken(file_path=FLAGS.log_folder_exp, task_type=task_type,
                                                       task_num=f'Task {cur_task_num}', stage='ReAct')
          _, list_react_traj = run_task(env_op=env_op,
                                        agent=agent, to_init_task=True, log_path=log_react_path,
                                        react_round=FLAGS.reflection_rounds + 1)
          react_traj_bank.save_temp(list_react_traj=list_react_traj, save_path=log_react_path)

          # Batch translate actions before RPA Builder
          if FLAGS.use_action_translator:
            print_with_color("🔧 Batch translating actions before RPA Builder...", 'cyan')
            list_react_traj = agent.batch_translate_actions(
              react_trajs=list_react_traj,
              log_path=log_react_path
            )
            # Note: The trajectory is added to the bank further below, once ReAct is known to have succeeded.

          # If the ReAct run failed for this task, give up on it and move to the next task.
          if not list_react_traj[-1].success:
            break

          exp_result_line.round_0 = '/'
          exp_result_line.round_1 = '/'
          exp_result_line.round_2 = '/'
          setattr(exp_result_line, f'round_{len(list_react_traj) - 1}', '1')  # successful round
          for cnt in range(len(list_react_traj) - 1):
            setattr(exp_result_line, f'round_{cnt}', '0')  # failed round

          # Store traj in the bank only when ReAct succeeded and the task was started from scratch
          if flag_init:
            react_traj_bank.add_react_traj(task_type, f'{env_op.task}_{cur_task_num}', list_react_traj)
          if FLAGS.update_react_trajs_bank:
            react_traj_bank.save()

        # RPA Builder(with VerificationResult)
        print_with_color("\n\nRunning RPABuilder_Agent", 'blue')
        rpa_builder_path = os.path.join(attempt_path, 'RPABuilder')  # init_task/RPABuilder
        RPABuilder_Agent.record_token = JSON_models.RecordToken(file_path=FLAGS.log_folder_exp, task_type=task_type,
                                                                  task_num=f'Task {cur_task_num}',
                                                                  stage='RPA Builder')

        # Goals of all tasks seen so far (task 1 up to the current task_num).
        encountered_task_goals = _collect_encountered_task_goals(task_list, task_num)

        rpa_info, cur_cnt_fetch_info = RPABuilder_Agent.generate_rpa_code(log_task_path=rpa_builder_path,
                                                            task_type=task_type,
                                                            task_template=task_type,
                                                            list_react_traj=react_traj_bank.get_last_traj(task_type),
                                                            pre_rpa_exec_traj=rpa_exec_traj,
                                                            encountered_task_goals=encountered_task_goals or None)
        cnt_generate_rpa += 1
        cnt_fetch_info += cur_cnt_fetch_info
        # save rpa_candidate
        rpa_bank_candidate.add_rpa(rpa_info)
        rpa_bank_candidate.update_based_on_task(task_type, task_num)
        rpa_bank_candidate.save_temp(task_type=task_type, save_path=rpa_builder_path)  # Save current RPA to log dir
        ## -----end: ReAct + RPABuilder

      if rpa_bank_candidate.rpa_dict.get(task_type):
        candidate_entry = rpa_bank_candidate.rpa_dict[task_type]
        if 'env_op.ask_mllm' in candidate_entry['rpa_code']:
          exp_result_line.has_ask_mllm = 1
        if 'env_op.get_ui_content' in candidate_entry['rpa_code']:
          exp_result_line.has_get_ui_info = 1
        exp_result_line.based_on_task = candidate_entry.get('based_on_task', '0')

      # save record to csv
      record_exp_result(exp_result_csv, exp_result_line.dict())

      cur_rpa_bank = explore_rpa_banks[task_num - 1]
      cur_rpa_bank.merge_from(rpa_bank_candidate)
      cur_rpa_bank.save_temp(save_path=FLAGS.log_folder_exp, file_name=f'temp_rpa_{task_num}.json')

      if flag_all_success:
        break  # exit building phase

    # Summary row: every task starts as 'abandon' and is flipped to '1' below if verified.
    exp_result_line = ExpResultLine(task_type=task_type, task_num='Final', task_goal='-', task_1='abandon',
                                    task_2='abandon', task_3='abandon', task_4='abandon', task_5='abandon',
                                    cnt_fetch_info=(cnt_fetch_info / cnt_generate_rpa) if cnt_generate_rpa > 0 else 0,)
    # Debug: Check if rpa_bank_candidate has RPA for this task_type
    print_with_color(f"\n[DEBUG] Checking rpa_bank_candidate for task_type: {task_type}", 'yellow')
    print_with_color(f"[DEBUG] rpa_bank_candidate.rpa_dict.keys(): {list(rpa_bank_candidate.rpa_dict.keys())}", 'yellow')
    if task_type in rpa_bank_candidate.rpa_dict:
      final_entry = rpa_bank_candidate.rpa_dict[task_type]
      # Read once with a default so recording and printing agree when the key is missing.
      final_verified_tasks = final_entry.get('verified_tasks', [])

      # -----start: recording
      exp_result_line.based_on_task = final_entry['based_on_task']
      # NOTE(review): assumes the stored keys keep the '<goal>_<task number>' form used during verification.
      _mark_verified_tasks(exp_result_line, final_verified_tasks)
      if 'env_op.ask_mllm' in final_entry['rpa_code']:
        exp_result_line.has_ask_mllm = 1
      if 'env_op.get_ui_content' in final_entry['rpa_code']:
        exp_result_line.has_get_ui_info = 1
      # -----end: recording

      for m, task in enumerate(final_verified_tasks):
        print_with_color(f"{m} {task}", 'blue')

      rpa_bank.merge_from(rpa_bank_candidate)  # update rpa_bank
      rpa_bank.save_temp(save_path=FLAGS.log_folder_exp,
                           file_name='temp_rpa.json')  # Also save updated rpa_bank in local experiment folder

      if FLAGS.update_rpa_bank:  # Update the complete rpa bank
        rpa_bank.save()
      print_with_color(f"====================== End -- Task Type : {task_type} =========================", 'blue')
      print_with_color('Successfully build the RPA Code', 'blue')
    else:
      print_with_color(f"====================== End -- Task Type : {task_type} =========================", 'blue')
      print_with_color("Oh no, failed to create the RPA Code.", 'red')

    record_exp_result(exp_result_csv, exp_result_line.dict())