import sys

# Interactive debug mode is switched on automatically when a trace function is
# installed (`sys.gettrace()` returns non-None, e.g. when running under a debugger).
# NOTE(review): coverage tools / profilers that install a trace function will also
# enable this mode — confirm that is acceptable.
DEBUG_ENABLED = sys.gettrace() is not None
# Uncomment to force debug mode without a debugger attached.
# DEBUG_ENABLED = True
if DEBUG_ENABLED:
    print("\n====================\nINTERACTIVE DEBUG MODE ACTIVATED\n====================\n")
    # Imported inside the branch so `utils.debug_utils` is only loaded in debug mode.
    from utils.debug_utils import set_api_keys, set_env_variables

    # set_api_keys()
    # This runs before the remaining module imports below; presumably so that modules
    # which read these environment variables at import time see the debug values — confirm.
    set_env_variables(bash_script="scripts/environments/set_env_variables.sh", arg1="local_vwebarena")

import argparse
import json
import logging
import os
import random
import re
import shutil
import time
import traceback
import warnings
from datetime import datetime
from functools import partial
from pathlib import Path
from typing import Any

from beartype.roar import BeartypeDecorHintPep585DeprecationWarning

from agent import to_json
from agent.agent import Agent
from agent.modular_agent import ModularAgent, construct_modular_agent
from agent.prompt_constructor import LMParsingError
from agent.teacher_forcing import TeacherForcingAgent
from browser_env import ActionTypes, ScriptBrowserEnv, StateInfo, Trajectory, create_stop_action
from browser_env.helper_functions import RenderHelper, get_action_description

# Constants
from constants import DEFAULT_RESULTS_DIR, HTMLS_SUBDIR, LM_LOGS_SUBDIR
from constants.constants import AGENTS_CONFIG_DIR, RESULT_DIR_TEMPLATE
from evaluation_harness import evaluator_router
from llms.setup_utils import restore_api_keys_to_file
from scripts import generate_test_data
from utils import timing_utils as timer

# Utilities
from utils.data_recorder import DataRecorder
from utils.eval_utils import get_agent_config, log_error, set_seed
from utils.file_utils import get_config_base_dir_from_txt, get_ids_from_tst_config_list
from utils.logger_utils import logger, save_log_file, save_log_file_path
from utils.signal_utils import signal_manager
from utils.string_utils import safe_format
from utils.timing_utils import dump_timings, set_timings_global_id, time_block, timeit
from utils.trajectory_view import TrajectoryView
from vwa_utils.captioner_utils import define_captioning_fn
from vwa_utils.extract_trajectory_html import process_html_file, rebuild_trajectory_vwa_format
from vwa_utils.vwa_utils import (
    auto_login,
    early_stop,
    get_domain_from_test_config_dir,
    get_tasks_with_trajectory,
    load_task_config,
)

# Lower TensorFlow's C++ log verbosity (level "2").
# NOTE(review): this only takes effect if TensorFlow has not already been imported by
# one of the modules above — confirm.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

# Known-noisy third-party warnings to ignore, as (category, message-regex) pairs.
# An empty message regex matches every message of that category.
for _warning_category, _warning_message in (
    (BeartypeDecorHintPep585DeprecationWarning, ""),
    (FutureWarning, "`resume_download` is deprecated and will be removed in version 1.0.0."),
    (UserWarning, 'Field "model_id" has conflict with protected namespace "model_".'),
):
    warnings.filterwarnings("ignore", category=_warning_category, message=_warning_message)
del _warning_category, _warning_message

# Register `restore_api_keys_to_file` as a cleanup hook for termination signals.
signal_manager.add_cleanup_function(restore_api_keys_to_file)

# ===============================================================================
# CMD arguments
# ===============================================================================


def parse_args() -> argparse.Namespace:
    """Define and parse the command-line interface for an end-to-end evaluation run.

    Returns:
        The parsed arguments. Derived attributes such as ``action_set_tag``,
        ``agents_configs`` or ``htmls_dir`` are not CLI options; they are attached to
        the namespace later by the caller / ``preprocess_args``.
    """
    parser = argparse.ArgumentParser(description="Run end-to-end evaluation on the benchmark")

    # --- Agent configuration ---
    parser.add_argument(
        "--agent_config_file",
        type=str,
        default="agent_config_base.yaml",
        help="Filename for the YAML config file.",
    )

    parser.add_argument(
        "--active_modules",
        type=str,
        default="",
        help=(
            "If provided, activates certain modules defined in `agent_config_file`. "
            "Leave empty to activate all (default). "
            "Use <executor>:<text_refiner> to activate all modules containing the strings 'executor' and 'text_refiner'."
        ),
    )

    # --- Prompt configuration ---
    parser.add_argument(
        "--path_raw_prompts",
        type=str,
        default="./agent/prompts/raw/base",
        help="Path to raw files to build json prompts.",
    )

    # --- Subdir to store outputs ---
    parser.add_argument("--result_dir", type=str, nargs="?", const="", default="")

    # --- Execution configuration ---
    parser.add_argument("--manual_input", action="store_true", help="Low level action input comes from the user.")

    parser.add_argument(
        "--deployment_mode",
        choices=["tgi", "vllm", "automodel"],
        default="automodel",
        help="Deployment mode for the hugging-face models. If TGI, tgi_model_endpoint should be provided.",
    )

    parser.add_argument("--sleep_after_execution", type=float, default=0.0)

    parser.add_argument(
        "--tgi_model_endpoint",
        type=str,
        default="http://127.0.0.1:8080",  # if using TGI, must deploy in the address provided here
        help="Endpoint where the model is being deployed. Defaults to localhost:8080.",
    )

    # TGI execution
    parser.add_argument(
        "--local",
        action="store_true",
        help="Deploy a local TGI server running on --tgi_model_endpoint. Requires text-generation-launcher installation.",
    )

    # vLLM execution
    parser.add_argument("--num_gpus", type=int, default=1)

    parser.add_argument(
        "--max_model_len",
        type=int,
        default=-1,
        help="Model context length in vLLM. Use it to reduce GPU memory usage.",
    )

    parser.add_argument("--flash_attn", action="store_true", help="Uses flash attention in AutoModel engine.")

    parser.add_argument(
        "--eager", action="store_true", help="Force eager mode if using vLLM. Uses less GPU memory, but less efficient."
    )

    parser.add_argument("--debugging", action="store_true", help="Set debugging inputs.")

    parser.add_argument("--trajectory_html_path", type=str, nargs="?", const=None, default=None)

    # --- Task configuration ---
    parser.add_argument(
        "--test_config_base_dir", type=str, default="", nargs="?", help="Path to the test configuration base directory."
    )
    parser.add_argument("--test_start_idx", type=int, default=0)
    parser.add_argument("--test_end_idx", type=int, default=10000)
    parser.add_argument("--task_list", nargs="?", const=None, type=str, default=None)
    parser.add_argument("--max_tasks", type=int, default=None, help="Maximum number of tasks to evaluate.")
    parser.add_argument("--shuffle_tasks", action="store_true", help="Shuffle the task list.")
    parser.add_argument("--seed", type=int, default=42, help="Set seed for testing.")

    # ===========================================================================
    # (V)WA specific
    # ===========================================================================
    # --- Captioning configuration ---
    parser.add_argument(
        "--agent_captioning_model_device",
        type=str,
        default="cuda",
        help="Device to run captioning model on. By default, runs it on CUDA.",
    )

    parser.add_argument(
        "--eval_captioning_model_device",
        type=str,
        default="cpu",
        help="Device to run eval captioning model on. By default, runs it on CPU.",
    )

    parser.add_argument(
        "--eval_captioning_model",
        type=str,
        default="Salesforce/blip2-flan-t5-xl",
        choices=["Salesforce/blip2-flan-t5-xl", "llava-hf/llava-1.5-7b-hf"],
        help="Captioning backbone for VQA-type evals.",
    )

    parser.add_argument(
        "--no_caption_text_obs",
        action="store_true",
        help="If true, does not caption the text observations of the webpage, even if captioner is provided.",
    )

    # --- Observation configuration ---
    parser.add_argument(
        "--observation_type",
        choices=[
            "accessibility_tree",  # text: actree of the webpages
            "accessibility_tree_with_captioner",  # text: actree of the webpages; caption model adds image descriptions to the actree
            "image_som",  # image: SOM-marked screenshot of webpage; text: ID of items marked in the screenshot
            "html",  # TODO: codebase doesn't have an action parser for this
            "image",  # TODO: codebase doesn't have an action parser for this
        ],
        default="accessibility_tree",
        help="Observation type",
    )

    parser.add_argument(
        "--show_scroll_bar",
        action="store_true",
        help="Show the scroll bar in the observation",
    )

    parser.add_argument(
        "--current_viewport_only",
        action="store_true",
        help="Only use the current viewport for the observation",
    )
    parser.add_argument("--viewport_width", type=int, default=1280)
    parser.add_argument("--viewport_height", type=int, default=2048)

    # --- Evaluation configuration ---
    parser.add_argument(
        "--parsing_failure_th",
        type=int,
        default=3,
        help="When consecutive parsing failure exceeds this threshold, the agent will stop",
    )

    parser.add_argument(
        "--repeating_action_failure_th",
        type=int,
        default=5,
        help="When consecutive repeated actions exceed this threshold, the agent will terminate early.",
    )

    parser.add_argument(
        "--fuzzy_match_provider",
        type=str,
        default="openai",
        choices=["openai", "google", "huggingface"],
        # Implicit string concatenation instead of a backslash continuation, which
        # previously embedded a run of indentation spaces in the help text.
        help=(
            "LLM provider for fuzzy matching evaluation. If not provided, uses GPT4-Turbo. If GPT not available, "
            "uses Gemini Pro 1.5 or 1.0, whichever is available. If 'huggingface', uses LLAMA-3-Instruct-8b."
        ),
    )

    parser.add_argument(
        "--max_steps", type=int, default=30, help="Max number of environment steps allowed. If exceeded, FAIL."
    )

    # --- Outputs/results configuration ---
    parser.add_argument("--render_screenshot", action="store_true")
    parser.add_argument("--log_obs_lens", action="store_true", help="Log observation lengths for analysis.")

    # Real-time browser rendering
    parser.add_argument("--render", action="store_true", help="Shows the browser in the screen.")
    parser.add_argument(
        "--slow_mo",
        type=int,
        default=0,
        help="Slow down the browser by the specified amount",
    )

    parser.add_argument("--save_trace_enabled", action="store_true")

    return parser.parse_args()


def build_test_file_list(args: argparse.Namespace) -> list[str]:
    """Build the list of task config JSON paths to evaluate.

    Task ids come either from ``args.task_list`` or from the inclusive range
    ``[args.test_start_idx, args.test_end_idx]``. ``args.task_list`` is a text file with
    one task id per line; lines that are not purely digits are skipped. Each id ``i``
    maps to ``<args.test_config_base_dir>/<i>.json``.

    When ``args.task_list`` is used, the list file is also copied into
    ``args.result_dir`` unless it already lives there. When ``args.trajectory_html_path``
    is set, the list is narrowed with ``get_tasks_with_trajectory``. The list is then
    truncated to the first ``args.max_tasks`` entries (if ``max_tasks`` is positive) and
    optionally shuffled in place.

    Exits the process with status 0 if no tasks remain.
    """
    test_file_list: list[str] = []

    if args.task_list:
        logger.info(f"\nExecuting tasks from: {args.task_list}")
        # Read the task file and build JSON paths, keeping only numeric lines.
        with open(args.task_list, "r", encoding="utf-8") as f:
            stripped_lines = (line.strip() for line in f)
            test_file_list.extend(
                os.path.join(args.test_config_base_dir, f"{task_id}.json")
                for task_id in stripped_lines
                if task_id.isdigit()
            )
        # Copy the task list file to the result directory if not already there.
        src_path = os.path.abspath(args.task_list)
        dest_path = os.path.abspath(os.path.join(args.result_dir, os.path.basename(args.task_list)))
        if src_path != dest_path:
            shutil.copyfile(args.task_list, dest_path)
    else:
        logger.info(f"\nExecuting tasks in the range: {args.test_start_idx} to {args.test_end_idx}")
        test_file_list = [
            os.path.join(args.test_config_base_dir, f"{i}.json")
            for i in range(args.test_start_idx, args.test_end_idx + 1)
        ]

    # If a trajectory HTML path is provided, keep only the tasks that have trajectories.
    if args.trajectory_html_path:
        logger.info(f"Set of tasks updated to tasks with previous trajectories in {args.trajectory_html_path}")
        test_file_list = get_tasks_with_trajectory(args.trajectory_html_path, args.test_config_base_dir, test_file_list)

    # Remove finished tasks # TODO: utilize get_unfinished from experiments_utils
    # test_file_list = get_unfinished(test_file_list, args.result_dir)

    # Bind the count explicitly. A walrus inside the comparison, as in
    # `n := len(x) == 0`, would bind the boolean result instead of the count.
    num_tasks = len(test_file_list)
    if num_tasks == 0:
        logger.info("No task left to run")
        sys.exit(0)

    if args.max_tasks is not None and 0 < args.max_tasks < num_tasks:
        logger.info(f"`max_tasks` set to {args.max_tasks}. Evaluating the first {args.max_tasks} tasks of {num_tasks}.")
        test_file_list = test_file_list[: args.max_tasks]

    if args.shuffle_tasks:
        logger.info(f"Shuffling tasks. Seed: {args.seed}")
        random.shuffle(test_file_list)

    return test_file_list


# Pre-process args
def preprocess_args(args: argparse.Namespace) -> None:
    """Validate parsed args, prepare task files and create result/log directories (mutates ``args``).

    Steps:
      1. Read ``action_set_tag`` from the executor agent config and check it is
         compatible with ``args.observation_type``.
      2. Default ``captioning_model`` when the observation type needs a captioner.
      3. Convert raw prompts to JSON and generate the JSON task files.
      4. Resolve ``test_config_base_dir`` (from ``task_list`` when it provides one) and,
         without a task list, clamp ``test_start_idx``/``test_end_idx`` to the numbered
         task files present on disk.
      5. Build ``result_dir``: use a default path if it is empty, then append a
         per-range or per-task-list subdir. Create it together with the traces, HTML
         and LM-log directories.
      6. Store per-agent ``conversation_dir``/``usage_dir`` in ``args.agents_configs``.

    Raises:
        ValueError: on an incompatible action/observation type pair, a non-existent
            ``trajectory_html_path``, or when no numbered ``<id>.json`` task files are found.
    """
    # The executor's action space determines which observation types it can act on.
    args.action_set_tag = args.agents_configs["executor_agent"]["action_set_tag"]

    # Allowed observation types per action set; tags not listed here are not checked.
    # `image_som` is excluded for `id_accessibility_tree` because it gives a wrong hint in
    # get_action_description(..) (see that function for details). For `som`, plain
    # `image` is not supported yet.
    compatible_observation_types = {
        "id_accessibility_tree": ("image", "accessibility_tree", "accessibility_tree_with_captioner"),
        "som": ("image_som",),
    }
    allowed_observation_types = compatible_observation_types.get(args.action_set_tag)
    if allowed_observation_types is not None and args.observation_type not in allowed_observation_types:
        raise ValueError(
            f"Action type {args.action_set_tag} is incompatible with the observation type {args.observation_type}"
        )

    # `parse_args` defines no `--captioning_model` option, so the attribute may not exist
    # yet. Plain `args.captioning_model` would raise AttributeError here.
    if (
        args.observation_type == "accessibility_tree_with_captioner"
        and getattr(args, "captioning_model", None) is None
    ):
        args.captioning_model = "Salesforce/blip2-flan-t5-xl"
        logger.warning(
            f"Observation type `{args.observation_type}` requires captioner but none provided. "
            f"captioning model set to {args.captioning_model} (default)"
        )

    # Convert prompt python files to json
    to_json.run(path_raw_prompts=args.path_raw_prompts)

    # Create json task files (0.json, 1.json, ...)
    generate_test_data.main()

    # If trajectory_html_path provided, verify it exists
    if args.trajectory_html_path and not os.path.exists(args.trajectory_html_path):
        raise ValueError(f"Trajectory HTML path {args.trajectory_html_path} does not exist.")

    # If task_list provided, get the test_config_base_dir from it if available
    if args.task_list:
        test_config_base_dir = get_config_base_dir_from_txt(args.task_list)
        if test_config_base_dir:
            args.test_config_base_dir = test_config_base_dir

    # Without a task list, clamp the requested index range to the tasks that exist on disk.
    # Only files named `<id>.json` count as tasks; other JSON files in the folder are ignored.
    if args.task_list is None:
        task_ids = [
            int(stem)
            for stem in (Path(name).stem for name in os.listdir(args.test_config_base_dir) if name.endswith(".json"))
            if stem.isdecimal()
        ]
        if not task_ids:
            raise ValueError(f"No tasks found in {args.test_config_base_dir}")
        args.test_start_idx = max(args.test_start_idx, min(task_ids))
        args.test_end_idx = min(args.test_end_idx, max(task_ids))

    # ---------------------------------------------------------------------------
    # Results and log directory paths
    # ---------------------------------------------------------------------------

    # --- Construct result directory path ----
    # Creates a default `args.result_dir` if not provided
    if not args.result_dir:
        # Extract the domain name to annotate the result directory
        args.domain = domain = get_domain_from_test_config_dir(args.test_config_base_dir)

        # Annotate model and current date and time
        date_annotate = datetime.now().strftime("%Y-%m-%d-%H%M")
        model = args.agents_configs["executor_agent"]["lm_config"]["model"]  # TODO: change if using multiple models
        args.result_dir = safe_format(
            RESULT_DIR_TEMPLATE, results_dir=DEFAULT_RESULTS_DIR, model=model, domain=domain, annotation=date_annotate
        )
        # Example: results/gpt-4o-mini-2024-07-18/reddit_2025-02-22-1430

    # Create a subdir in `args.result_dir` for the task range or task list. This helps with
    # parallel runs: they share the initial `args.result_dir` but are separated by their tasks.
    if args.task_list is None:
        args.result_dir = f"{args.result_dir}/{args.test_start_idx}-{args.test_end_idx}"
    else:
        args.result_dir = f"{args.result_dir}/{Path(args.task_list).stem}"

    os.makedirs(args.result_dir, exist_ok=True)

    if args.save_trace_enabled:
        os.makedirs(os.path.join(args.result_dir, "traces"), exist_ok=True)

    # Save logger files paths to a text file in results folder
    save_log_file_path(args.result_dir)

    # Directories for execution HTML traces and LM call logs.
    args.htmls_dir = os.path.join(args.result_dir, HTMLS_SUBDIR)
    os.makedirs(args.htmls_dir, exist_ok=True)
    args.lm_logs_dir = os.path.join(args.result_dir, LM_LOGS_SUBDIR)
    os.makedirs(args.lm_logs_dir, exist_ok=True)

    # Record per-agent conversation/usage log paths (only for agents backed by an LM).
    for agent_name, agent_config in args.agents_configs.items():
        if "lm_config" in agent_config:
            agent_lm_log = f"{args.lm_logs_dir}/{agent_name}"
            agent_config["conversation_dir"] = f"{agent_lm_log}/conversation"
            agent_config["usage_dir"] = f"{agent_lm_log}/usage"


def set_debug_inputs(args: argparse.Namespace) -> None:
    """Overwrite CLI arguments with hard-coded values for an interactive debugging session.

    Does nothing unless the module-level ``DEBUG_ENABLED`` flag is truthy. After the
    overrides are applied, an existing ``result_dir`` whose path contains "debug" is
    deleted and recreated empty, so each debug run starts from a clean directory.
    """
    # FIXME @debug interactive
    if not DEBUG_ENABLED:
        return

    # Attribute name -> debug value, applied in this order.
    debug_overrides: dict[str, Any] = {
        "max_steps": 30,
        "manual_input": False,
        # "trajectory_html_path": "/home/mashalimay/webarena/WebGUIAgents/experiments/gemini-2.0-exp-flash/base_prev_utterances_critique_not_vague/shopping/htmls/",
        # "task_list": "tasks.txt",
        "test_start_idx": 122,  # inclusive
        "test_end_idx": 122,  # inclusive
        "test_config_base_dir": "config_files/vwa_not_vague/test_reddit",
        "agent_config_file": "agent_config_t_100_2p_tri_nocot_expert.yaml",
        "render": False,
        "agent_captioning_model_device": "server-cuda",  # or 'cuda'
        # "eval_captioning_model_device": "server-cuda",  # or 'cuda'
        "no_caption_text_obs": False,
        "viewport_width": 1280,  # CLI default: 1280
        "viewport_height": 2048,  # 720 suits small-context models, 2048 large-context ones (CLI default: 2048)
        "fuzzy_match_provider": "google",
        "render_screenshot": True,
        # One of: accessibility_tree, image_som, accessibility_tree_with_captioner, html, image
        "observation_type": "image_som",
        # Less frequently changed settings
        "path_raw_prompts": "agent/prompts/raw/base",
        "log_obs_lens": False,
        "result_dir": "results/debug",
        "local": False,
        "save_images": False,
        "deployment_mode": "automodel",
        "slow_mo": 0,
        "tgi_model_endpoint": "http://127.0.0.1:8080",
    }
    for attr_name, debug_value in debug_overrides.items():
        setattr(args, attr_name, debug_value)

    # Wipe a pre-existing debug results directory so stale outputs don't mix in.
    debug_results_path = os.path.abspath(args.result_dir)
    if "debug" in args.result_dir and os.path.exists(debug_results_path):
        shutil.rmtree(debug_results_path)
        os.makedirs(debug_results_path)


# ===============================================================================
# Evaluation
# ===============================================================================


def error_handler(error: Exception, config_file: str, result_dir: str, data_recorder: DataRecorder) -> bool:
    """Report an exception raised while running a task and count it as a failed execution.

    Logs the error's repr, records it via ``log_error``, prints the active traceback
    and increments ``data_recorder.num_failed_executions``.

    Returns:
        Always ``False``, so callers can assign the result straight to their success flag.
    """
    logger.info(f"[Error] {repr(error)}")
    log_error(error, config_file, result_dir)

    current_traceback = traceback.format_exc()
    print(current_traceback)

    data_recorder.num_failed_executions = data_recorder.num_failed_executions + 1
    task_succeeded = False
    return task_succeeded


def init_environment(args: argparse.Namespace) -> ScriptBrowserEnv:
    """Create the browser environment configured from the CLI arguments.

    The text-observation captioner (``args.caption_img_fn``) is passed only when
    ``args.no_caption_text_obs`` is false.
    """
    if args.no_caption_text_obs:
        text_obs_captioner = None
    else:
        text_obs_captioner = args.caption_img_fn

    viewport = {"width": args.viewport_width, "height": args.viewport_height}
    env_options = dict(
        headless=not args.render,
        slow_mo=args.slow_mo,
        observation_type=args.observation_type,
        current_viewport_only=args.current_viewport_only,
        viewport_size=viewport,
        save_trace_enabled=args.save_trace_enabled,
        sleep_after_execution=args.sleep_after_execution,
        captioning_fn=text_obs_captioner,
        show_scroll_bar=args.show_scroll_bar,
    )
    return ScriptBrowserEnv(**env_options)


def _log_executor_observation_len(
    data_recorder: DataRecorder, trajectory: Trajectory, task_id: Any, agent: Agent
) -> None:
    """Log observation length for `trajectory` using the executor's model, provider and tokenizer."""
    data_recorder.log_observation_len(
        trajectory,
        task_id,
        agent.get_model("executor"),
        agent.get_provider("executor"),
        agent.get_tokenizer("executor"),
    )


# LINK: Evaluation
def test(args: argparse.Namespace, agent: Agent, config_file_list: list[str]) -> None:
    """Run `agent` on every task in `config_file_list`, then evaluate and record each one.

    For each task config the function:
      1. loads the config and calls ``auto_login``;
      2. builds the evaluator and resets the agent and the environment;
      3. alternates agent actions and environment steps until a STOP action or until
         the environment reports termination. A STOP comes from the agent, from an
         early-stop condition, or from a ValueError while choosing the action;
      4. scores the trajectory with the evaluator and records it with ``DataRecorder``.

    Any other exception raised while running a task is reported via ``error_handler``
    and the loop moves on to the next task. The browser environment is always closed.
    At the end, summary statistics are logged and saved.

    Args:
        args: Parsed and preprocessed CLI arguments.
        agent: Agent producing the actions.
        config_file_list: Paths of the task config JSON files to run, in order.
    """
    test_start_time = time.time()  # TODO: remove this
    domain = get_domain_from_test_config_dir(args.test_config_base_dir)

    # Initialize data recorder to record experiment data
    data_recorder = DataRecorder(
        args.result_dir, config_file_list, args.test_config_base_dir, agent.get_action_splitter()
    )

    # Set early stop thresholds
    max_steps = args.max_steps
    early_stop_thresholds = {
        "parsing_failure": args.parsing_failure_th,
        "repeating_action": args.repeating_action_failure_th,
    }

    # Initialize environment; the outer `finally` closes it even if the run is interrupted.
    env = init_environment(args)
    try:
        # LINK: Iterate each task (defined in `config_file` jsons)
        for config_file in config_file_list:
            # Reset per-task state up front. Without this, an error raised before these
            # names are assigned would make the `finally` below hit unbound names (first
            # task) or act on the previous task's values (later tasks).
            task_id = None
            render_helper = None
            task_timer_started = False
            test_execution_success = True  # Set to False by `error_handler` on failure
            try:
                render_helper = RenderHelper(config_file, args.htmls_dir, args.action_set_tag)

                # Load task config
                config_dict, intent, task_id, intent_images = load_task_config(config_file)

                set_timings_global_id(f"{domain}-{task_id}")
                timer.start("RUN:test")
                task_timer_started = True

                # REVIEW: login renewal is disabled (renew=False). Renewing login after every
                # task creates issues with p_run: process A renovates login cookies and
                # process B cannot modify some websites.
                # TODO: add cookie logic by environment for parallel process
                config_file = auto_login(config_dict, config_file, auth_folder="./.auth", renew=False)
                evaluator = evaluator_router(
                    config_file, captioning_fn=args.eval_caption_img_fn, fuzzy_match_prov=args.fuzzy_match_provider
                )  # type: ignore

                # Log info
                logger.info(f"[Config file]: {config_file}")
                logger.info(f"[Intent]: {intent}")
                data_recorder.initialize_task(task_id, config_dict["sites"])
                task_start_time = time.time()

                # Prepare agent and environment
                agent.reset(config_file)
                trajectory: Trajectory = []
                trajectory_view = TrajectoryView(trajectory)
                obs, info = env.reset(options={"config_file": config_file})  # type: ignore
                state_info: StateInfo = {"observation": obs, "info": info}
                trajectory.append(state_info)
                meta_data: dict[str, Any] = {
                    "action_str_history": ["None"],
                    "task_id": task_id,
                    "data_recorder": data_recorder,
                    "manual_input": args.manual_input,
                    "env": env,
                    "trajectory": trajectory_view,
                    "args": args,
                    "config_file": config_file,
                    "evaluator": partial(evaluator, config_file=config_file),  # type: ignore
                }

                # Log observation length for the landing page
                if args.log_obs_lens:
                    _log_executor_observation_len(data_recorder, trajectory, task_id, agent)

                # LINK: Loop: execute current task
                while True:
                    save_log_file(args.result_dir)  # Write log content to `args.result_dir`
                    try:
                        # Stop if an early-stop threshold was reached
                        early_stop_flag, stop_info = early_stop(trajectory, max_steps, early_stop_thresholds)
                        if early_stop_flag:
                            action = create_stop_action(f"Early stop: {stop_info}")
                            logger.info(f"Early stop: {stop_info}")
                            if agent.score_logger:
                                agent.log_scores_per_round(TrajectoryView(trajectory + [action]), intent, meta_data)

                        # Otherwise, ask the agent for the next action
                        else:
                            action = agent.next_action(
                                trajectory_view,
                                intent,
                                intent_images=intent_images,
                                meta_data=meta_data,
                            )

                            # The agent itself may request an early stop; turn it into a STOP action.
                            if "early_stop" in action:
                                early_stop_reason = action["early_stop"]  # type: ignore
                                logger.info(f"Early stop: {early_stop_reason}")
                                action = create_stop_action(action["raw_prediction"])
                                action["early_stop"] = early_stop_reason  # type: ignore
                                if agent.score_logger:
                                    agent.log_scores_per_round(TrajectoryView(trajectory + [action]), intent, meta_data)

                    except ValueError as e:
                        action = create_stop_action(f"ERROR: {str(e)}")
                        logger.info(f"[Error] {str(e)}")
                        print(traceback.format_exc())

                    trajectory.append(action)

                    # Convert the action to a string for the history used as input to later generations.
                    action_str = get_action_description(
                        action=action,
                        observation_metadata=state_info["info"]["observation_metadata"],
                        action_set_tag=args.action_set_tag,
                        prompt_constructor=agent.get_prompt_constructor(),
                    )
                    # NOTE(review): newlines are removed (not replaced by a space) before
                    # whitespace is collapsed, so words split across lines get glued — confirm intended.
                    action_str = re.sub(r"\s+", " ", action_str.replace("\n", "")).strip()
                    meta_data["action_str_history"].append(action_str)

                    # Code inside 'get_action_description' appends an invalid flag to the action.
                    # In VWA this can only be checked with the logic in get_action_description,
                    # which makes the verification environment dependent.
                    timer.start("RUN:render")
                    render_helper.render(action, state_info, meta_data, args.render_screenshot)
                    timer.end("RUN:render")

                    # Record action data
                    data_recorder.record_action(task_id, action)

                    dump_timings(args.result_dir)

                    # If STOP action, end task loop.
                    if action["action_type"] == ActionTypes.STOP:
                        break

                    # Act on environment
                    obs, _, terminated, _, info = env.step(action)
                    state_info = {"observation": obs, "info": info}
                    trajectory.append(state_info)

                    # Record observation length of current state
                    if args.log_obs_lens:
                        _log_executor_observation_len(data_recorder, trajectory, task_id, agent)

                    if terminated:
                        trajectory.append(create_stop_action(""))  # Add action placeholder
                        break
                # END OF TASK

                # Evaluate and print score.
                # Obs: eval_caption_img_fn is used for running eval_vqa functions.
                score = evaluator(trajectory=trajectory, config_file=config_file, page=env.page)  # type: ignore
                logger.info(f"[Result] ({'PASS' if score == 1 else 'FAIL'}) {config_file}")
                elapsed_time = time.time() - task_start_time

                # Save trace
                if args.save_trace_enabled:
                    env.save_trace(Path(args.result_dir) / "traces" / f"{task_id}.zip")

                # Record and save task stats; trajectory alternates state, action, state, ...
                data_recorder.update_save_data(task_id, score, elapsed_time, num_actions=len(trajectory[1::2]))

            except Exception as e:
                test_execution_success = error_handler(e, config_file, args.result_dir, data_recorder)

            finally:
                if task_timer_started:
                    timer.end("RUN:test")
                if task_id is not None:
                    data_recorder.update_unfinished_failed_tasks(task_id, test_execution_success)
                else:
                    logger.info(f"[Error] Could not determine task id for {config_file}; task status not updated.")
                if render_helper is not None:
                    render_helper.close()
            # End of one task
    finally:
        env.close()

    # End of experiment
    test_end_time = time.time()
    scores = data_recorder.get_scores()
    if scores:
        logger.info(f"Average score: {sum(scores) / len(scores)}; {len(scores)} tasks")
    else:
        # Avoid ZeroDivisionError when no task produced a score (e.g. all failed).
        logger.info("Average score: N/A; 0 tasks")
    logger.info(
        f"NOTE: {len(config_file_list) - data_recorder.num_failed_executions} of {len(config_file_list)} completed without error."
    )
    logger.info(f"Total test time (min): {(test_end_time - test_start_time) / 60}")

    # Save summary execution stats
    data_recorder.save_execution_summary(test_end_time - test_start_time, agent.get_provider("executor"))


# LINK: Main
if __name__ == "__main__":
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    args = parse_args()

    # Set seed for numpy, random, torch
    set_seed(args.seed)

    # Set debugging inputs, if any
    if DEBUG_ENABLED:
        set_debug_inputs(args)

    # ------------------------------------------------------------------------------
    # Parse and validate configs, args
    # ------------------------------------------------------------------------------
    # Accept either a bare filename or a path that includes AGENTS_CONFIG_DIR.
    if AGENTS_CONFIG_DIR in args.agent_config_file:
        args.agent_config_file = Path(args.agent_config_file).name

    args.agent_config_path = str(Path(AGENTS_CONFIG_DIR) / args.agent_config_file)
    args.agents_configs = get_agent_config(args.agent_config_path, active_modules_str=args.active_modules)

    # Validate and preprocess args
    preprocess_args(args)

    # Build test file list
    test_file_list = build_test_file_list(args)

    # ------------------------------------------------------------------------------
    # Dump test_file_list and args to results folder
    # ------------------------------------------------------------------------------
    sorted_test_files = sorted(test_file_list)
    with open(os.path.join(args.result_dir, "attempted_tasks.txt"), "w", encoding="utf-8") as f:
        f.writelines(f"{test_file}\n" for test_file in sorted_test_files)

    # Overwrite `task_list` with the resolved list so args.json records what was attempted.
    # This must happen before the captioning functions are attached to `args` below;
    # at this point `args` should only hold JSON-serializable values.
    args.task_list = sorted_test_files
    with open(os.path.join(args.result_dir, "args.json"), "w", encoding="utf-8") as f:
        json.dump(vars(args), f, indent=4)

    # ------------------------------------------------------------------------------
    # Load captioner
    # ------------------------------------------------------------------------------
    logger.info(f"\nTotal {len(test_file_list)} tasks left.\n")

    args.agent_captioning_model = None
    if "request_refiner" in args.agents_configs:
        args.agent_captioning_model = args.agents_configs["request_refiner"]["captioning_model"]
    args.caption_img_fn, args.eval_caption_img_fn = define_captioning_fn(args)

    # ------------------------------------------------------------------------------
    # Build Agent
    # ------------------------------------------------------------------------------
    agent: Agent
    if args.trajectory_html_path:
        # Replay previous trajectories stored as HTMLs in trajectory_html_path
        agent = TeacherForcingAgent(trajectory_html_path=args.trajectory_html_path, agents_configs=args.agents_configs)
    else:
        agent = construct_modular_agent(args.agents_configs, args.caption_img_fn)

    # ------------------------------------------------------------------------------
    # Evaluate
    # ------------------------------------------------------------------------------
    logger.info(f"\nResults will be saved in {os.path.abspath(args.result_dir)}")

    try:
        test(args, agent, test_file_list)
        # Finishing up
        save_log_file(args.result_dir)
        logger.info(f"Results in {os.path.abspath(args.result_dir)}")

    except Exception as e:
        logger.error(f"Error in test: {e}")
        raise  # bare raise re-raises the active exception with its original traceback

    finally:
        restore_api_keys_to_file()
