import os.path
import time
from typing import Any

from absl import flags

from UIAgents.Agent_RPA.agent_rpa import Agent_RPA
from UIAgents.Agent_RPA.utils.rpa_bank_utils import RPABank
from .env_operation import EnvOperation
from .episode_runner import run_episode
from .utils import JSON_models
from .utils.agent_utils import print_with_color, record_exp_result
from .utils.traj_utils import ReactTrajBank

FLAGS = flags.FLAGS


def run_task(
  env_op: EnvOperation,
  agent: Agent_RPA,
  log_path: str,
  to_init_task: bool = True,
  react_round: int = 1,
) -> tuple[dict[str, Any], list[JSON_models.ReActTraj]]:
  """Runs the environment's current task for up to `react_round` rounds.

  Each round runs one episode via `run_episode`, wraps its outcome in a
  `JSON_models.ReActTraj`, and stops early as soon as the episode reports
  `env_success`. The agent's `rpa_mode`, `reflection` and
  `reflection_history` are reset before the first round.

  Args:
    env_op: Environment to run on; its `task` attribute is used as the goal.
    agent: The agent to run the task.
    log_path: Directory under which each round logs to `round_<i>`.
    to_init_task: Forwarded to `run_episode` for the first round. It is
      forced to True before every subsequent round.
    react_round: Maximum number of rounds to attempt. With 0 no episode is
      run and the returned trajectory list is empty.

  Returns:
    A tuple `(episode_data, trajs)`. `episode_data` is always an empty dict
    (kept for interface compatibility); `trajs` holds one `ReActTraj` per
    round that was executed, in order.
  """
  agent.rpa_mode = None
  agent.reflection = None
  agent.reflection_history = []
  trajs = []
  goal = env_op.task

  time.sleep(3)
  # Mode: GUI Agent
  for round_idx in range(react_round):
    # NOTE(review): this appends to the current stage label, so the suffixes
    # accumulate across rounds ("... Round 0 Round 1") unless something else
    # resets `stage` in between -- confirm whether that is intended.
    agent.record_token.stage = agent.record_token.stage + f' Round {round_idx}'
    print_with_color("========================================================================", 'blue')
    print_with_color(f'Running {env_op.task_type} with goal "{goal}"', 'blue')
    print_with_color(f"****** Round {round_idx} ******", 'blue')
    log_file_path = os.path.join(log_path, f'round_{round_idx}')

    # try to complete the task by ReAct
    episode_result = run_episode(env_op, agent, log_file_path, to_init_task)
    agent_traj = episode_result.agent_traj

    # get ReAct Trajectory for RPABuilder
    react_traj = JSON_models.ReActTraj(
      task=goal,
      pre_reflection=agent.reflection,
      traj=agent_traj,
      action_history=episode_result.action_history,
      success=episode_result.env_success,
      env_op_traj=[step_info.exec_step_info for step_info in agent_traj],
    )
    trajs.append(react_traj)

    print(f"\n********* task_successful_round_{round_idx}: {episode_result.task_successful} *********\n")
    print(f'{"Task Successful ✅" if episode_result.env_success else "Task Failed ❌"}; {goal}\n')

    # conclusion for the task
    if (not FLAGS.test_rpa_mode) and FLAGS.agent_name != 'agent_react':
      concluder_result = agent.Concluder_Agent(goal=goal, log_task_path=log_file_path, episode_results=episode_result)
      react_traj.conclusion = concluder_result.episode_conclusion

    if episode_result.env_success:
      break

    if round_idx != react_round - 1:  # Avoid extra 'Restart with final-state-reflection'
      # restart with reflection
      print("\n\n*************** Restart with final-state-reflection *************")
      to_init_task = True
      time.sleep(2)

  return {}, trajs


def run(
  task_list: list[int],
  env_op: EnvOperation,
  agent: Agent_RPA,
  rpa_bank: RPABank,
  explore_rpa_banks: list[RPABank],
  react_traj_bank: ReactTrajBank,
  task_templates: dict,
  cnt_task_type: int
) -> None:
  """Runs every task in `task_list` once and records the per-task outcome.

  For each task index a `Task_<idx>` log directory is created under
  `<FLAGS.log_folder_exp>/<task_type>`, the environment is reset to that
  task, a single-round `run_task` is executed, and the resulting
  trajectories are handed to `react_traj_bank.save_temp`. After all tasks,
  one result row is passed to `record_exp_result` together with the path
  `<FLAGS.log_folder_exp>/react_result.csv`.

  Args:
    task_list: Task indices to run; each is passed to `env_op.reset`.
    env_op: The environment to run on. Its `task_type` names the log
      sub-folder and is stored in the result row.
    agent: The agent to run the tasks. Its `record_token` is replaced
      before every task.
    rpa_bank: Not used by this function.
    explore_rpa_banks: Not used by this function.
    react_traj_bank: Receives each task's trajectories via `save_temp`.
    task_templates: Not used by this function.
    cnt_task_type: Stored under the "Num" key of the result row.

  Raises:
    FileExistsError: If a task's log directory already exists (it is
      created without `exist_ok`).
  """
  task_type = env_op.task_type
  task_type_log_path = os.path.join(FLAGS.log_folder_exp, task_type)
  os.makedirs(task_type_log_path, exist_ok=True)

  exp_result_csv = os.path.join(FLAGS.log_folder_exp, "react_result.csv")
  result = {"Num": cnt_task_type, "Task Type": task_type, "Task Template": task_type}

  for task_idx in task_list:
    print_with_color(f"\n========\nTask {task_idx}:", 'blue')

    log_path = os.path.join(task_type_log_path, f'Task_{task_idx}')
    os.makedirs(log_path)

    env_op.reset(task_idx, log_path)
    task_goal = env_op.task
    print(task_type, task_goal)

    agent.record_token = JSON_models.RecordToken(file_path=FLAGS.log_folder_exp, task_type=task_type,
                                                 task_num=f'Task {task_idx}', stage='ReAct Round 0')
    # The first element returned by run_task is not needed here.
    _, list_react_traj = run_task(env_op=env_op, agent=agent, to_init_task=True, log_path=log_path, react_round=1)
    react_traj_bank.save_temp(list_react_traj=list_react_traj, save_path=log_path)
    # Record the success flag of the last round as 0/1.
    result[f"ReAct_{task_idx}"] = int(list_react_traj[-1].success)

  record_exp_result(exp_result_csv, result)
    
