"""Runs an agent on the environment."""

from .agent_rpa import Agent_RPA
from .env_operation import EnvOperation
from .utils.JSON_models import EpisodeResult
from .utils.agent_utils import print_with_color

# Base step budget for an episode in run_episode; run_episode adds 10 extra
# steps when the agent is in RPA mode or the task is not re-initialized.
MAX_ACTION_STEP = 20


def run_episode(
  env_op: EnvOperation,
  agent: Agent_RPA,
  log_task_path: str,
  to_init_task: bool = True
) -> EpisodeResult:
  """Runs an agent on the task held by `env_op`, e.g., "turn off wifi".

  Both the agent and the environment are reset first. In ReAct mode
  (`agent.rpa_mode` is False) the agent then steps until the environment
  reports `done` or the step budget is used up; in RPA mode the whole run is
  delegated to `agent.rpa_testing()`.

  Args:
    env_op: Environment wrapper. Its `task_type`, `task_idx` and `task` are
      read to set up the run; its `done`, `agent_done` and `reward` are read
      to stop and score the episode.
    agent: The agent to run on the environment.
    log_task_path: Forwarded to `agent.reset` and `env_op.reset`, and
      recorded in the returned result.
    to_init_task: Forwarded to both resets. When False, the step budget is
      extended by 10 steps (the same extension used in RPA mode).

  Returns:
    An `EpisodeResult` holding the agent trajectory (or the RPA testing
    output), the action history, and the success flags derived from
    `env_op.reward` and `env_op.agent_done`.
  """
  agent.reset(env_op.task_type, log_task_path, to_init_task)

  react_max_steps = MAX_ACTION_STEP
  if agent.rpa_mode or not to_init_task:
    # Extra budget for RPA testing or when the task is not re-initialized.
    react_max_steps += 10

  # Actions already present in the history count against the budget.
  remaining_steps = react_max_steps - len(agent.action_history)
  print(f'remaining_steps: {remaining_steps}')
  # Always allow at least one step, even if the history already fills the budget.
  remaining_steps = max(remaining_steps, 1)

  env_op.reset(env_op.task_idx, log_task_path, to_init_task, react_max_steps)
  # Read the task only after env_op.reset: for MobileMiniWoB tasks the goal
  # is only set once the task has been initialized.
  agent.cur_task = env_op.task

  if not agent.rpa_mode:  # ReAct mode: step until the environment ends the episode
    for _ in range(remaining_steps):
      agent.step()
      if env_op.done:
        print('Environment ends episode.')
        break
    output = agent.agent_traj
  else:  # RPA testing mode
    output = agent.rpa_testing()

  # `>=` rather than `==`: when the history already filled the budget, the
  # forced extra step above (or RPA testing) can leave it longer than
  # react_max_steps — assuming each step is recorded in action_history.
  if not env_op.agent_done and len(agent.action_history) >= react_max_steps:
    print_with_color('Agent did not indicate task is done. Reached max number of steps.', 'red')

  task_successful = (env_op.reward > 0)
  agent_done = env_op.agent_done
  # Task success only counts for the agent if the agent itself declared done.
  agent_successful = task_successful if agent_done else 0.0
  env_success = bool(agent_successful > 0.5)

  return EpisodeResult(
    task_goal=env_op.task,
    log_task_path=log_task_path,
    agent_traj=output,
    action_history=agent.action_history,
    env_success=env_success,
    task_successful=task_successful,
    agent_done=agent_done,
    agent_successful=agent_successful
  )
