# TODO: remove this. Placed here to remember
import logging
from typing import Any

from filelock import FileLock

from utils.debug_utils import set_env_variables

set_env_variables()

import argparse
import concurrent.futures
import json
import os
import random
import re
import shutil
import subprocess
import sys
import time
from concurrent.futures import Future
from datetime import datetime
from multiprocessing import Manager
from pathlib import Path

import numpy as np

# from vwa_utils.vwa_utils import get_expired_sites
from browser_env.auto_login import is_expired_for_sites
from constants.constants import AGENTS_CONFIG_DIR, DEFAULT_RESULTS_DIR, RESET_ENV_SCRIPT
from llms.constants import API_KEYS_PATH, API_KEYS_REPO
from utils.concurrency_utils import get_file_lock, single_instance_lock
from utils.eval_utils import get_agent_attribute, set_seed
from utils.file_utils import (
    find_files,
    get_config_base_dir_from_txt,
    get_ids_from_tst_config_list,
    get_task_ids_from_csv,
    resolve_path_conflict,
)
from utils.signal_utils import signal_manager
from vwa_utils.captioner_utils import start_captioner

# Default output directory used by `create_task_txt_file` for generated task-list files.
TEMP_FILES_DIR = ".temp_files"
# Default for `reset_cookies(exc_comb=...)`; when true, `--exc_comb` is appended to the auto_login command.
EXC_COOKIE_SITE_COMB = True

# Global variable to store the current tmux session id for cleanup in case of a kill command.
CURRENT_TMUX_SESSION: str | None = None

# NOTE(review): not referenced in the visible part of this file; presumably a lock timeout in seconds — confirm.
LOCK_TIMEOUT = 120

# ===============================================================================
# Argument parsing
# ===============================================================================


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Run tasks in parallel.")
    parser.add_argument(
        "-t",
        "--tasks-file",
        default="",
        help="Path to file with tasks (first line is config_dir, rest are numeric IDs)",
    )

    parser.add_argument(
        "-n",
        "--num-processes",
        type=int,
        default=4,
        help="Number of parallel worker processes to keep busy.",
    )

    parser.add_argument(
        "-b",
        "--tasks-per-process",
        type=int,
        default=10,
        help="Number of tasks per worker process.",
    )

    parser.add_argument(
        "-c",
        "--config-dir",
        default="",
        help="If no tasks file, use this directory to find numeric tasks (min..max).",
    )

    parser.add_argument(
        "-a",
        "--agent-config",
        default="",
        help="YAML configuration for the Agent",
    )

    parser.add_argument(
        "-ra",
        "--reset-after",
        type=int,
        default=0,
        help="Reset environments after N tasks are finished. Set -1 for no reset, 0 for estimate based on num_processes and tasks_per_process.",
    )

    parser.add_argument(
        "-d",
        "--domains-to-reset",
        default="",
        help="Domains to reset. If none, will reset only the domain inferred from `config_dir`.",
    )

    parser.add_argument(
        "-di",
        "--domains-to-reset-on-init",
        default="all",
        help="Domains to reset on initialization.",
    )

    # Max running time
    parser.add_argument(
        "-mrt",
        "--max-running-time",
        type=int,
        default=0,
        help="Max time a process can run in minutes. Set -1 for no limit. If set to 0 will be estimated as `avg-running-time-per-task` *  `tasks-per-process`.",
    )

    parser.add_argument(
        "-art",
        "--avg-running-time-per-task",
        type=float,
        default=180,
        help="Average running time per task in seconds. Used to compute the max running time if `max-running-time` is set to 0.",
    )

    parser.add_argument(
        "-ma",
        "--max-attempts-per-task",
        type=int,
        default=2,
        help="Maximum number of attempts per task.",
    )

    parser.add_argument(
        "-sc",
        "--skip-cookies",
        action="store_true",
        help="Skip creating autologin cookies at initialization.",
    )

    parser.add_argument(
        "-sr",
        "--skip-reset",
        action="store_true",
        help="Skip start/reset of environments at initialization.",
    )

    parser.add_argument(
        "-cd",
        "--captioner-device",
        default="server-cuda",
        help="Device to host captioner on.",
    )
    parser.add_argument(
        "-seed",
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducibility.",
    )

    parser.add_argument(
        "-st",
        "--shuffle-tasks",
        action="store_true",
        help="Shuffle the task list.",
    )

    parser.add_argument(
        "-mr",
        "--max-wait-reset",
        type=int,
        default=120,
        help="Maximum time to wait for workers to finish remaining tasks after a reset is triggered.",
    )

    parser.add_argument(
        "-ck",
        "--copy_api_keys",
        action="store_true",
        help="Make a copy of the api keys file at initialization.",
    )

    return parser.parse_args()


# ===============================================================================
# LINK Logger helpers
# ===============================================================================


# TODO: add more fine-grained timer to also print times ex env reset
class Printer:
    def __init__(self, print_every: int = 30) -> None:
        self.running_time = 0
        self.print_every = print_every

    def start_timer(self) -> None:
        self.start_time = time.time()
        self.last_print_time = self.start_time

    def update_running_time(self) -> None:
        self.running_time = time.time() - self.start_time

    def print_stats(self, tasks_finished: set[int], task_list: list[int] | set[int], config_dir: str) -> None:
        template = "[Orchestrator] [{time}]: {message}"
        if time.time() - self.last_print_time > self.print_every:
            self.last_print_time = time.time()
            self.update_running_time()
            percent_finished = len(tasks_finished) / len(task_list) * 100
            if len(tasks_finished) > 0:
                avg_time_per_task = self.running_time / len(tasks_finished)
            else:
                avg_time_per_task = -1

            msg = f"Finished {percent_finished:.2f}% ({len(tasks_finished)}/{len(task_list)}) tasks for {config_dir}."
            print(template.format(time=datetime.now().strftime("%Y-%d-%m-%H:%M"), message=msg))
            print(
                f"Total running time: {int(self.running_time / 60)} minutes and {int(self.running_time % 60)} seconds."
            )
            if avg_time_per_task != -1:
                print(
                    f"Avg time per task: {int(avg_time_per_task / 60)} minutes and {int(avg_time_per_task % 60)} seconds."
                )

    def print_batch_completion(self, batch_id: int, run_time: float, future: Future) -> None:
        """
        Prints the completion status for a batch.
        """
        error_msg = ""
        try:
            ret_code = future.result()
            if ret_code == 0:
                error_msg = "finished with no errors."
            elif ret_code == 1:
                error_msg = "finished with a timeout."
            else:
                error_msg = f"finished with error code {ret_code}."
        except Exception as e:
            error_msg = f"encountered an exception: {e}"

        print(
            f"[Orchestrator]: Batch {batch_id} finished after {int(run_time / 60)} minutes and {int(run_time % 60)} seconds.\n"
            f"[Orchestrator]: Batch {batch_id} {error_msg}"
        )

    def print_run_info(
        self,
        agent_config: str,
        config_dir: str,
        total_tasks: int,
        total_batches: int,
        tasks_per_process: int,
        num_workers: int,
        max_attempts_per_task: int,
        shuffle_tasks: bool,
        max_running_time: float,
        first_batch_max_run_time: float,
        avg_running_time_per_task: float,
        reset_after: float,
        max_wait_reset: int,
        domains_to_reset: str | list[str],
    ) -> None:
        """
        Prints an initial overview of the parallel run settings.
        """
        print("\n-------------- Starting parallel run --------------")
        print(f"Agent config: [{agent_config}] | test_config_dir: [{config_dir}]")
        print(f"Total tasks: [{total_tasks}] | Total batches: [{total_batches}]")
        print(f"Tasks per process: [{tasks_per_process}] | Num processes: [{num_workers}]")
        print(f"Max attempts per task: [{max_attempts_per_task}] | Shuffle tasks: [{shuffle_tasks}]")
        max_time_str = (
            "NO LIMIT"
            if max_running_time < 0
            else (f"{first_batch_max_run_time} (auto)" if max_running_time == 0 else f"{max_running_time}")
        )
        print(f"Max run time per process: [{max_time_str}] | Avg run time per task: [{avg_running_time_per_task}]")
        print(
            f"Reset after: [{reset_after if reset_after != np.inf else 'NO RESET'}] | Max wait reset: [{max_wait_reset}] | domains_to_reset [{domains_to_reset}]"
        )
        print("-----------------------------------------------------")

    def print_message(self, message: str) -> None:
        print(f"[Orchestrator] {message}")

    def print_dispatched_batch(self, batch_id: int, task_list: list[int]) -> None:
        # List comprehension is to enforce native python ints for cleaner logging. (to prevent `np.int64(2)` instead of `2`)
        print(f"[Orchestrator] Dispatched batch {batch_id}: {[int(task) for task in task_list]}")

    def print_waiting_workers(self, waiting_count: int) -> None:
        print(f"[Orchestrator] No batches to dispatch. Waiting {waiting_count} workers to finish...")

    def print_all_completed(self, config_dir: str, agent_config: str, result_dir: str) -> None:
        print(f"[Orchestrator] All tasks completed for {config_dir}, agent config {agent_config}.")
        print(f"Results saved in {result_dir}")


# ===============================================================================
# LINK Task I/O helpers
# ===============================================================================


def get_range_from_config_dir(config_dir: str) -> tuple[int, int]:
    """
    Return the smallest and largest numeric prefix among the entries of `config_dir`.

    Only entries whose name starts with one or more digits are considered; the
    leading digit run is parsed as the id.

    Raises:
        FileNotFoundError: `config_dir` is not an existing directory.
        ValueError: no entry name starts with a digit.
    """
    if not os.path.isdir(config_dir):
        raise FileNotFoundError(f"Config dir not found: {config_dir}")

    prefix_matches = (re.match(r"^(\d+)", entry) for entry in os.listdir(config_dir))
    found_ids = [int(m.group(1)) for m in prefix_matches if m]

    if not found_ids:
        raise ValueError(f"No numeric files found in {config_dir}")

    return min(found_ids), max(found_ids)


def get_task_list(
    tasks_file: str = "",
    config_dir: str = "",
) -> tuple[str, list[int]]:
    """
    Resolve the test config directory and the task ids to run.

    Args:
        tasks_file: Path to a tasks file. When given it takes precedence: the
            config directory and the ids are both read from it.
        config_dir: Used only when `tasks_file` is empty. The ids are the full
            range between the smallest and largest numeric prefix found in it.

    Returns:
        (test_config_dir, task_ids)

    Raises:
        ValueError: the tasks file yields no existing config directory, or neither
            `tasks_file` nor `config_dir` is given.
    """
    if tasks_file:
        # Get test config dir and task ids from tasks file
        test_config_dir = get_config_base_dir_from_txt(txt_path=tasks_file)
        task_ids = sorted(get_ids_from_tst_config_list(txt_path=tasks_file))

        # Check if test config dir is valid
        if not test_config_dir or not os.path.isdir(test_config_dir):
            raise ValueError(f"[ERROR] No config directory found in file: {tasks_file}")

        return test_config_dir, task_ids

    # numeric range mode from --config-dir
    if not config_dir:
        raise ValueError("[ERROR] No config directory given.")

    start_id, end_id = get_range_from_config_dir(config_dir)
    # In this mode the config dir is the one passed in. It was previously never
    # assigned, so the return statement raised UnboundLocalError.
    return config_dir, list(range(start_id, end_id + 1))


def get_tasks_success_failed(dir: str, return_failed: bool = False) -> tuple[set[int], set[int]]:
    """
    Collect the ids of successful (and optionally failed) tasks found under `dir`.

    Every directory below `dir` that holds an `args.json` is inspected: ids from its
    `summary_data.csv` go to the successful set and, only when `return_failed` is
    true, ids from its `failed_tasks.txt` go to the failed set.

    Returns:
        (successful_tasks, failed_tasks); the second set stays empty when
        `return_failed` is false.
    """
    succeeded: set[int] = set()
    failed: set[int] = set()

    for args_path in find_files(dir, "args.json", upwards=False, downwards=True):
        run_dir = os.path.dirname(args_path)
        summary_csv = os.path.join(run_dir, "summary_data.csv")
        failed_txt = os.path.join(run_dir, "failed_tasks.txt")

        if os.path.exists(summary_csv):
            succeeded.update(get_task_ids_from_csv(summary_csv))

        if return_failed and os.path.exists(failed_txt):
            failed.update(get_ids_from_tst_config_list(txt_path=failed_txt))

    return succeeded, failed


def create_task_txt_file(
    test_config_dir: str,
    task_list: list[int] | set[int],
    out_dir: str = TEMP_FILES_DIR,
    filename: str = "tasks.txt",
    overwrite: bool = False,
) -> str:
    """
    Write a tasks file: `test_config_dir` on the first line, then one task id per line.

    When `overwrite` is false the target path is passed through
    `resolve_path_conflict` (with `int_suffix=True`) before writing; otherwise
    `out_dir/filename` is written directly. `out_dir` is created if missing.

    Returns:
        The path of the file that was written, as a string.
    """
    tasks = list(task_list) if isinstance(task_list, set) else task_list

    requested_path = f"{out_dir}/{filename}"
    # The conflict check runs before the directory is created, as in the original flow.
    target_path = requested_path if overwrite else resolve_path_conflict(requested_path, int_suffix=True)

    os.makedirs(out_dir, exist_ok=True)
    with open(target_path, "w") as out:
        out.write(test_config_dir + "\n")
        out.writelines(f"{task_id}\n" for task_id in tasks)

    return str(target_path)


def write_unfinished_tasks(
    original_task_list: list[int] | set[int],
    test_config_dir: str,
    out_dir: str,
    tasks_finished: set[int] | list[int] | None = None,
    result_dir: str = "",
    filename: str = "unfinished_tasks.txt",
    overwrite: bool = True,
) -> str:
    """
    Write the tasks of `original_task_list` that are not finished to a tasks file.

    The finished set is `tasks_finished` when given; otherwise, if `result_dir` is
    non-empty, the successful tasks found there by `get_tasks_success_failed`.

    Returns:
        Path of the written file, as returned by `create_task_txt_file`.

    Raises:
        ValueError: no finished set was given and none could be retrieved.
    """
    finished = tasks_finished
    if finished is None:
        if result_dir:
            finished, _ = get_tasks_success_failed(result_dir, return_failed=False)
        if finished is None:
            raise ValueError("No tasks finished or not able to retrieve from 'result_dir'.")

    finished_set = set(finished) if isinstance(finished, list) else finished
    remaining = set(original_task_list) - finished_set

    return create_task_txt_file(
        test_config_dir=test_config_dir,
        task_list=remaining,
        out_dir=out_dir,
        filename=filename,
        overwrite=overwrite,
    )


# ===============================================================================
# LINK Tmux helpers
# ===============================================================================


def create_tmux_session() -> str:
    """
    Start a detached tmux session and return the session id tmux prints for it.

    Raises:
        subprocess.CalledProcessError: tmux exited with a non-zero status; its
            stderr is printed before the error is re-raised.
    """
    new_session_cmd = ["tmux", "new-session", "-d", "-P", "-F", "#{session_id}"]
    try:
        completed = subprocess.run(new_session_cmd, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as e:
        print(f"Failed to create tmux session: {e.stderr}")
        raise e

    session_id = completed.stdout.strip()
    print(f"Created tmux session with id: {session_id}")
    return session_id


def is_pane_running(pane_id: str, tmux_session_id: str) -> tuple[bool, list[str]]:
    """
    Check whether `pane_id` appears in the pane ids reported by `tmux list-panes`.

    Returns:
        (found, pane_ids): whether `pane_id` is listed, and the full list of ids.

    Raises:
        subprocess.CalledProcessError: the tmux command exited with a non-zero status.
    """
    # NOTE(review): `-a` is passed together with `-t <session>`; confirm which of the two scopes the listing.
    listing = subprocess.run(
        ["tmux", "list-panes", "-t", tmux_session_id, "-a", "-F", "#{pane_id}"],
        capture_output=True,
        text=True,
        check=True,
    )
    live_panes = listing.stdout.splitlines()
    return pane_id in live_panes, live_panes


def run_tmux_pane(tmux_session_id: str, cmd: str, new_window: bool = False) -> subprocess.CompletedProcess:
    """
    Run `cmd` in a new tmux pane (`split-window`) or, if `new_window`, a new window.

    tmux is asked to print the new pane id, so on success it is in `stdout`.
    The call never raises on a non-zero exit; inspect `returncode` / `stderr`.
    """
    tmux_subcommand = "new-window" if new_window else "split-window"
    return subprocess.run(
        # -P prints information about the new pane, -F selects the pane id as format.
        ["tmux", tmux_subcommand, "-t", tmux_session_id, "-P", "-F", "#{pane_id}", cmd],
        check=False,
        capture_output=True,
        text=True,
    )


def run_tmux_pane_fallback(tmux_session_id: str, cmd: str) -> subprocess.CompletedProcess:
    """
    Run `cmd` in a split pane, retrying in a new window when there is no room.

    The retry happens only if the split fails and its stderr contains
    "no space for new pane" (case-insensitive). Returns the result of the last
    tmux invocation that was made.
    """
    split_attempt = run_tmux_pane(tmux_session_id, cmd, new_window=False)

    if split_attempt.returncode == 0:
        return split_attempt
    if "no space for new pane" not in split_attempt.stderr.lower():
        return split_attempt

    return run_tmux_pane(tmux_session_id, cmd, new_window=True)


# ===============================================================================
# LINK Environment helpers
# ===============================================================================


def reset_environments(
    reset_env_script: str = RESET_ENV_SCRIPT,
    domains: list[str] | str = "all_vwa",
) -> None:
    """
    Run the environment reset script for the given domains.

    Args:
        reset_env_script: Path of the script to execute.
        domains: A single string, or a list that is joined with spaces and passed
            to the script as ONE argument.

    A non-zero exit status is reported on stdout and not re-raised (best effort).
    """
    domains_str = domains if isinstance(domains, str) else " ".join(domains)

    try:
        subprocess.run([reset_env_script, domains_str], check=True)
    except subprocess.CalledProcessError as e:
        # The script's output is not captured (it goes to the terminal), so
        # `e.stderr` is always None here; report the exception itself, which
        # carries the command and its exit status.
        print(f"[ERROR] Failed to reset environments: {e}")

def reset_cookies(
    sites: list[str] | str | None = None,
    exc_comb: bool = EXC_COOKIE_SITE_COMB,
    max_wait: float = 60 * 1.5,
    auth_folder: str = "./.auth",
) -> None:
    """
    Reset cookies by calling the auto_login module.
    This function forces mutual exclusion by generating a lock based on the auto_login script file.

    Args:
        sites: Site name or list of site names passed as `--site_list`. `None` or an
            empty list omits the option.
        exc_comb: When true, `--exc_comb` is appended to the command.
        max_wait: Seconds to wait for the single-instance lock before giving up.
        auth_folder: Value passed as `--auth_folder`.

    Failures (lock timeout or non-zero exit of auto_login) are printed, not raised.
    """
    if sites is None:
        site_list: list[str] = []
    elif isinstance(sites, str):
        site_list = [sites]
    else:
        site_list = list(sites)

    print("[INFO] Starting cookies reset.")

    # Identifier to create a lock on
    auto_login_script = "browser_env.auto_login.py"

    # Command to run
    cmd = ["python", "-m", "browser_env.auto_login", "--auth_folder", auth_folder]
    if site_list:
        cmd.extend(["--site_list", *site_list])
    if exc_comb:
        cmd.append("--exc_comb")

    try:
        with single_instance_lock(identifier=auto_login_script, retry_interval=0.5, max_wait=max_wait):
            print(f"[INFO] Lock acquired based on {auto_login_script}. Running auto_login exclusively.")
            try:
                subprocess.run(cmd, check=True)
            except subprocess.CalledProcessError as e:
                # Output is not captured, so `e.stderr` would always be None;
                # print the exception (command + exit status) instead.
                print(f"[ERROR] Failed to reset cookies: {e}")
    except TimeoutError:
        print(
            f"[ERROR] Timeout reached ({max_wait} sec) while waiting for lock on {auto_login_script}. Skipping cookies reset."
        )


def check_cookies_expired(sites: list[str] | str = []):
    """
    Return True if `is_expired_for_sites` reports at least one non-empty entry.

    A single site name is wrapped in a list before the call.
    """
    site_list = [sites] if isinstance(sites, str) else sites

    expiry_report = is_expired_for_sites(site_list)
    for cookies in expiry_report.values():
        if len(cookies) > 0:
            return True
    return False


# ===============================================================================
# Workers
# ===============================================================================


def run_sh_script_tmux(
    task_id: int,
    test_config_dir: str,
    agent_config: str,
    tmux_session_id: str,
    stop_dict: dict[int, bool],
    task_list: list[int] = [],
    start_id: int | None = None,
    end_id: int | None = None,
    result_dir: str = "",
    captioner_device: str = "server-cuda",
    wait_before_check: int = 30,
    max_retry_ls_tmux: int = 3,
    wait_until_dead: int = 30,
) -> int:
    """
    Launch `./scripts/runs/run.sh` in a tmux pane and block until it ends or is stopped.

    Args:
        task_id: Worker/batch identifier; used in log messages, in the name of the
            generated task file, and as the key looked up in `stop_dict`.
        test_config_dir: Passed to run.sh as `-c`.
        agent_config: Passed to run.sh as `-a`.
        tmux_session_id: Session in which the pane is created.
        stop_dict: Shared mapping; a truthy value under `task_id` requests an early stop.
        task_list: If non-empty, written to a tasks file that is passed as `-t`.
        start_id, end_id: Used as `-s` / `-e` only when `task_list` is empty and both are set.
        result_dir: Passed as `-d` when non-empty.
        captioner_device: Passed as `-m` when non-empty.
        wait_before_check: Seconds between two checks that the pane still exists.
        max_retry_ls_tmux: Number of pane-listing failures tolerated before giving up.
        wait_until_dead: Seconds to wait for the pane to close after sending Ctrl-C.

    Returns:
        0 when the pane closed by itself, 1 when stopped through `stop_dict`,
        -1 when pane listing kept failing, or tmux's return code when the pane
        could not be created.
    """
    cmd = ["./scripts/runs/run.sh"]

    if task_list:
        # Create a temporary text file with the task list.
        task_list_file = create_task_txt_file(test_config_dir, task_list, filename=f"tasks_{task_id}.txt")
        cmd += ["-t", task_list_file]
    elif start_id is not None and end_id is not None:
        cmd += ["-s", str(start_id), "-e", str(end_id)]

    cmd += ["-c", test_config_dir]
    cmd += ["-a", agent_config]

    if result_dir:
        cmd += ["-d", result_dir]

    if captioner_device:
        cmd += ["-m", captioner_device]

    # Skip cookies in run.sh.
    cmd.append("-r")

    # Build the shell line executed in the tmux pane; `exit` closes the pane when
    # run.sh ends, which is how completion is detected below.
    # NOTE: arguments are joined without quoting, so paths must not contain spaces.
    inner_cmd = " ".join(cmd) + "; exit"

    # Create a new tmux pane and capture its pane id.
    result = run_tmux_pane_fallback(tmux_session_id, inner_cmd)

    if result.returncode != 0:
        print(f"[Worker {task_id}] TMUX ERROR {result.returncode}:\n{result.stderr}")
        return result.returncode

    # Get the pane id from stdout.
    pane_id = result.stdout.strip()
    print(f"[Worker {task_id}] Launched tmux pane with id: {pane_id}")

    # Block until the pane no longer exists or a stop signal is received.
    # `retry_count` counts pane-listing failures over the whole wait; it is not
    # reset after a successful listing.
    retry_count = 0
    while True:
        try:
            # If stop signal received, stop early.
            if stop_dict.get(task_id, False):
                print(f"[Worker {task_id}]: Stop signal received. Attempting graceful stop...")

                # Try 'ctrl + c' to stop `run.py` gracefully.
                ret_ctrl_c = subprocess.run(["tmux", "send-keys", "-t", pane_id, "C-c"], check=False)
                kill_pane = True

                # Wait for the pane to die.
                if ret_ctrl_c.returncode == 0:
                    start_time = time.time()
                    while time.time() - start_time < wait_until_dead:
                        if not is_pane_running(pane_id, tmux_session_id)[0]:
                            kill_pane = False
                            break
                        time.sleep(0.1)

                # If the pane is still running, forcefully kill it.
                if kill_pane:
                    # Forcefully kill the tmux pane (obs.: panes are unique within the same server, regardless of the session/window).
                    # stderr is captured so the failure message below shows tmux's
                    # actual error instead of `None`.
                    ret = subprocess.run(
                        ["tmux", "kill-pane", "-t", pane_id],
                        check=False,
                        capture_output=True,
                        text=True,
                    )
                    if ret.returncode != 0:
                        print(
                            f"[Worker {task_id}] Stopping early, but failed to kill tmux pane {pane_id}: {ret.stderr}"
                        )

                # Remove itself from the stop dict.
                del stop_dict[task_id]
                return 1

            # If pane is not running -> batch finished -> return 0.
            if not is_pane_running(pane_id, tmux_session_id)[0]:
                print(f"[Worker {task_id}] FINISHED: TMUX pane {pane_id} closed.")
                return 0
            time.sleep(wait_before_check)

        except subprocess.CalledProcessError as e:
            if retry_count < max_retry_ls_tmux:
                # If error listing panes, retry up to `max_retry_ls_tmux` times.
                print(f"[Worker {task_id}] Error listing panes: {e.stderr}. Retrying...")
                retry_count += 1
                time.sleep(5)
            else:
                # If error listing panes more than `max_retry_ls_tmux` times, stop with return code -1.
                print(f"[Worker {task_id}] Failed to check if tmux pane is running. Stopping.")
                return -1


# ===============================================================================
# Dispatcher
# ===============================================================================


class TaskBatch:
    def __init__(
        self,
        task_list: list[int],
        hard_reset: bool = False,
        start_run_time: float = -1,
        batch_id: int = -1,
        max_run_time: float = 0,
        avg_run_time_per_task: float | None = None,
    ):
        self.task_list = task_list  # List of task ids.
        self.hard_reset = hard_reset  # Marks if batch processing suffered hard reset. Useful to skip updating attempts for tasks in the batch.
        self.start_run_time = start_run_time  # Time when the batch started processing.
        self.batch_id = batch_id  # ID of the batch.
        self.max_run_time = self._set_max_run_time(max_run_time, avg_run_time_per_task)

    def _set_max_run_time(self, max_run_time: float, avg_run_time_per_task: float | None) -> float:
        if max_run_time < 0:
            return np.inf
        elif max_run_time == 0:
            if avg_run_time_per_task is None:
                raise ValueError("avg_run_time_per_task must be provided if max_run_time is 0.")
            else:
                return avg_run_time_per_task * len(self.task_list)
        else:
            return max_run_time


def create_task_batches(
    task_list: list[int] | set[int],
    tasks_per_worker: int,
    max_run_time: float = 0,
    avg_run_time_per_task: float | None = None,
) -> list[TaskBatch]:
    """
    Split `task_list` into near-equal batches of at most `tasks_per_worker` tasks.

    The number of batches is ceil(len(task_list) / tasks_per_worker) and the tasks
    are spread as evenly as possible over them with `np.array_split`.

    Args:
        task_list: Task ids to distribute; an empty collection yields `[]`.
        tasks_per_worker: Upper bound on the batch size; must be positive.
        max_run_time, avg_run_time_per_task: Forwarded to every `TaskBatch`.

    Raises:
        ValueError: `max_run_time` is 0 without `avg_run_time_per_task`, or
            `tasks_per_worker` is not positive.
    """
    if len(task_list) == 0:
        return []

    if max_run_time == 0 and avg_run_time_per_task is None:
        raise ValueError("avg_run_time_per_task must be provided if `max_run_time` is 0.")

    # Fail with a clear message instead of a ZeroDivisionError / numpy error below.
    if tasks_per_worker <= 0:
        raise ValueError(f"tasks_per_worker must be a positive integer, got {tasks_per_worker}.")

    tasks = list(task_list)

    # Ceiling division: number of batches required.
    num_batches = (len(tasks) + tasks_per_worker - 1) // tasks_per_worker

    # `.tolist()` converts back to native Python ints; `list(array)` would keep
    # numpy scalars (e.g. `np.int64(2)`), which is not what `list[int]` promises.
    return [
        TaskBatch(batch.tolist(), max_run_time=max_run_time, avg_run_time_per_task=avg_run_time_per_task)
        for batch in np.array_split(tasks, num_batches)
    ]


def shuffle_task_list(task_list: list[int] | set[int]) -> list[int]:
    if isinstance(task_list, set):
        task_list = list(task_list)
    random.shuffle(task_list)
    return task_list


def recompute_task_batches(
    waiting_batches: list[TaskBatch],
    submitted_batches: list[TaskBatch],
    attempts_per_task: dict[int, int],
    tasks_finished: set[int],
    tasks_failed: set[int],
    max_attempts: int = 2,
    tasks_per_worker: int = 10,
    max_run_time: float = 0,
    avg_run_time_per_task: float | None = None,
) -> list[TaskBatch]:
    """
    Rebuild the batches to dispatch from not-yet-dispatched and already-dispatched work.

    - `waiting_batches`: every task that is not in `tasks_finished` is kept; attempt
      counts are untouched.
    - `submitted_batches`: for every unfinished task, `attempts_per_task` is
      incremented in place when the task is in `tasks_failed` or when its batch did
      not suffer a hard reset. The task is retried only while its attempt count is
      below `max_attempts`; otherwise a message is printed and it is dropped.

    Returns:
        New batches built by `create_task_batches` from the union of both groups.
    """
    # Filtering on `tasks_finished` should be a no-op for waiting tasks; kept as a safeguard.
    still_waiting = {task for batch in waiting_batches for task in batch.task_list if task not in tasks_finished}

    retry_tasks = []
    for batch in submitted_batches:
        for task in batch.task_list:
            if task in tasks_finished:
                continue

            # A hard reset is not the task's fault, so it only counts as an attempt
            # when the task was explicitly reported as failed.
            if task in tasks_failed or not batch.hard_reset:
                attempts_per_task[task] += 1

            if attempts_per_task[task] >= max_attempts:
                print(f"[Info] Task {task} reached maximum attempts, skipping.")
                continue

            retry_tasks.append(task)

    return create_task_batches(
        still_waiting.union(retry_tasks), tasks_per_worker, max_run_time, avg_run_time_per_task
    )


# LINK Reset environments and cookies helpers
def wait_until_all_workers_dead(
    futures: list[Future],
    max_wait_time: int = 120,
) -> bool:
    """
    Poll `futures` every 0.5 s until all are done or `max_wait_time` seconds elapse.

    Returns:
        True as soon as every future reports done, False if the time ran out
        first (including immediately, when `max_wait_time` is not positive).
    """
    began_at = time.time()
    while True:
        if time.time() - began_at >= max_wait_time:
            return False
        still_running = any(not worker.done() for worker in futures)
        if not still_running:
            return True
        time.sleep(0.5)


def send_stop_signals(
    future_to_batch: dict[Future, TaskBatch],
    stop_dict: dict[int, bool],
    max_wait_dead_workers: int = 120,
) -> None:
    """
    Signal every batch that exceeded its `max_run_time` to stop, then wait for them.

    A batch is overdue when the time since its `start_run_time` is greater than its
    `max_run_time`; its `batch_id` is then set to True in `stop_dict`. If any batch
    was signalled, the call blocks up to `max_wait_dead_workers` seconds for the
    corresponding futures to finish.
    """
    now = time.time()
    overdue_futures = []

    for future, batch in future_to_batch.items():
        if now - batch.start_run_time > batch.max_run_time:
            print(
                f"[Orchestrator] Stopping batch {batch.batch_id}. Running for more than {batch.max_run_time} seconds."
            )
            stop_dict[batch.batch_id] = True
            overdue_futures.append(future)

    # Wait until the signalled workers are dead
    if overdue_futures:
        wait_until_all_workers_dead(overdue_futures, max_wait_dead_workers)


def wait_reset_restart(
    future_to_batch: dict[Future, TaskBatch],
    tmux_session_id: str,
    max_wait_reset: float,
    domains_to_reset: list[str] | str,
    stop_dict: dict[int, bool],
    max_wait_dead_workers: int = 120,
    reset_env_flag: bool = False,
    reset_cookies_flag: bool = False,
    cookies_expired: bool = False,
) -> str:
    """
    Stop all running workers, optionally reset environments/cookies, and open a new tmux session.

    Args:
        future_to_batch: Running worker futures mapped to their batches.
        tmux_session_id: Session hosting the workers; it is killed.
        max_wait_reset: Seconds to wait before stopping workers (skipped when `cookies_expired`).
        domains_to_reset: Passed to `reset_environments` and, as sites, to `reset_cookies`.
        stop_dict: Shared mapping in which each batch id is set to True to request a stop.
        max_wait_dead_workers: Seconds to wait for the futures to finish (applied up to twice).
        reset_env_flag: Reset the environments (this also forces a cookie reset).
        reset_cookies_flag: Reset the cookies.
        cookies_expired: Skip the initial wait.

    Returns:
        The id of the newly created tmux session.

    Raises:
        RuntimeError: workers are still not done after the tmux session was killed.
        subprocess.CalledProcessError: `tmux kill-session` failed (it runs with check=True).
    """
    # (i) wait until all workers time out,
    # (ii) kill tmux session,
    # (iii) reset environments and cookies
    # (iv) create a new tmux session.

    print(f"Environment reset triggered. Waiting for {max_wait_reset} seconds...")

    if cookies_expired:
        # If cookie is expired, immediately reset
        # NOTE(review): this value is never read afterwards, so no wait happens on this path — confirm intent.
        max_wait_reset = 0.5
    else:
        # Plain grace period: nothing is checked while waiting.
        start_time = time.time()
        while time.time() - start_time < max_wait_reset:
            time.sleep(0.5)

    # Send stop signal to all workers and wait until they are dead.
    for batch in future_to_batch.values():
        stop_dict[batch.batch_id] = True
    workers_dead = wait_until_all_workers_dead(list(future_to_batch.keys()), max_wait_dead_workers)

    # Kill the tmux session. This will forcefully stop all workers; if they are not dead yet, wait again.
    subprocess.run(["tmux", "kill-session", "-t", tmux_session_id], check=True)
    if not workers_dead:
        workers_dead = wait_until_all_workers_dead(list(future_to_batch.keys()), max_wait_dead_workers)

    if not workers_dead:
        raise RuntimeError("Failed to kill workers during environment reset.")

    # Reset environment if flag is true
    if reset_env_flag:
        print(f"Resetting environments...")
        reset_environments(domains=domains_to_reset)

    # Reset when expired cookies OR if the env was reset (obligatory)
    if reset_cookies_flag or reset_env_flag:
        print(f"Resetting cookies...")
        reset_cookies(sites=domains_to_reset, exc_comb=EXC_COOKIE_SITE_COMB)

    # Mark all batches as suffering a hard reset.
    for batch in future_to_batch.values():
        batch.hard_reset = True

    # Start new tmux session
    print(f"Creating new tmux session...")
    try:
        tmux_session_id = create_tmux_session()
    except Exception as e:
        print(f"[ERROR] Failed to create tmux session on environment reset: {e}")
        raise e

    return tmux_session_id


def restore_api_keys_file(
    src_file: str = API_KEYS_REPO,
    dest_file: str = API_KEYS_PATH,
    min_keys: dict[str, int] = {"google": 1, "openai": 0},
    logger: logging.Logger | None = None,
) -> None:
    """
    Refill providers in `dest_file` that are running low on API keys from `src_file`.

    For every provider present in `dest_file`, if it holds no more keys than its
    threshold in `min_keys`, its key list is replaced by the one in `src_file` (when
    the provider exists there). `dest_file` is rewritten in place under a file lock.

    Args:
        src_file: JSON file with the full set of keys; if it does not exist nothing is done.
        dest_file: JSON file to repair.
        min_keys: Per-provider threshold. Providers not listed use 0, i.e. they are
            refilled only when they have no keys left.
        logger: Currently unused; kept for interface compatibility.

    Errors while reading or writing `dest_file` are printed, not raised.
    """
    if not os.path.exists(src_file):
        return

    # Context manager so the source file handle is closed deterministically.
    with open(src_file) as src:
        all_api_keys = json.load(src)

    try:
        # Lock the file that is actually modified (previously always API_KEYS_PATH,
        # which left a custom `dest_file` unprotected).
        lock = get_file_lock(dest_file)
        with lock, open(dest_file, "r+") as f:
            api_keys = json.load(f)

            for provider in api_keys:
                # `.get` avoids a KeyError for providers missing from `min_keys`,
                # which used to abort the whole restore.
                threshold = min_keys.get(provider, 0)
                if len(api_keys[provider]) > threshold:
                    continue

                print(
                    "\n--------------WARNING--------------:\n"
                    f"{dest_file} has <= {threshold} keys for {provider}. Restoring from {src_file}.\n"
                    "------------------------------------"
                )
                if provider in all_api_keys:
                    api_keys[provider] = all_api_keys[provider]

            # Rewrite in place: rewind, dump, then drop any leftover tail.
            f.seek(0)
            json.dump(api_keys, f, indent=2)
            f.truncate()
    except Exception as e:
        print(f"[ERROR] Failed to restore API keys file: {e}")


def preprocess_args(
    num_workers: int,
    task_list: list[int] | set[int],
    tasks_per_worker: int,
    max_running_time: float,
    avg_running_time_per_task: float,
    reset_after: int,
) -> tuple[list[int] | set[int], int, float, int]:
    # Remove any duplicates in the task list.
    if isinstance(task_list, list):
        task_list = list(set(task_list))

    if tasks_per_worker == 0:
        tasks_per_worker = max(1, len(task_list) // num_workers)

    # Convert to seconds
    max_running_time = max_running_time * 60 if max_running_time > 0 else max_running_time

    if reset_after < 0:
        reset_after = np.inf  # type: ignore

    return task_list, tasks_per_worker, max_running_time, reset_after


# LINK TaskDispatcher
def run_tasks_dispatcher(
    task_list: list[int] | set[int],
    num_workers: int,
    config_dir: str,
    agent_config: str,
    result_dir: str,
    tasks_per_worker: int = 0,
    captioner_device: str = "server-cuda",
    reset_after: int = 0,
    domains_to_reset: list[str] | str = "",
    avg_running_time_per_task: float = 3 * 60,  # Average X minutes per task
    max_attempts_per_task: int = 2,
    shuffle_tasks: bool = True,
    max_running_time: float = 0,
    max_wait_reset: int = 2 * 60,  # Wait X minutes for workers to time out.
    print_every: int = 60,
    check_cookies_every: float = 2 * 60,  # Seconds between cookie checks (default: 2 minutes)
) -> None:
    """
    Runs each integer task in 'task_list' with up to 'num_workers' processes,
    scheduling a new batch as soon as a worker finishes one.

    The task ids are split into batches; each batch is submitted to a
    `ProcessPoolExecutor` as one `run_sh_script_tmux` call bound to the current
    tmux session. The main loop then polls for finished workers and, on every
    iteration, reads progress from `result_dir`, decides whether an environment
    and/or cookie reset is due, rewrites the unfinished-task list and tops up
    the API keys file. When a worker finishes, the waiting batches are
    recomputed from the finished/failed tasks and the next batch, if any, is
    submitted.

    Side effect: the module-level `CURRENT_TMUX_SESSION` is updated whenever a
    tmux session is created here or returned by `wait_reset_restart`.

    Args:
        task_list (list[int] | set[int]): Task ids to run. A list is de-duplicated first.
        num_workers (int): Number of concurrent processes allowed at each time.
        config_dir (str): path to the test configuration directory.
        agent_config (str): path to the agent configuration file.
        result_dir (str): path to the results directory. Progress is read from here and
            `tasks.txt` is written here.

        tasks_per_worker (int, optional): number of tasks to assign to each process.
            If 0, divide `task_list` evenly among `num_workers` (at least 1 per worker).
            Defaults to 0.

        captioner_device (str, optional): device to run the captioner on; forwarded to each worker.
            Defaults to "server-cuda".

        reset_after (int, optional): Environments will reset after this many newly finished tasks.
            If negative, no reset.
            If 0, use `max(tasks_per_worker, (n - 1) * tasks_per_worker)` where `n` is the smaller of
            `num_workers` and the number of batches.
            If >0, reset after this many tasks.

        domains_to_reset (list[str] | str, optional): Domains passed to the cookie check, the cookie
            reset and `wait_reset_restart`. Defaults to "".

        avg_running_time_per_task (float, optional): Average running time per task, forwarded to the
            batching helpers. Defaults to 180 (presumably seconds, TODO confirm units).

        max_attempts_per_task (int, optional): Forwarded to `recompute_task_batches` as `max_attempts`.
            Defaults to 2.

        shuffle_tasks (bool, optional): If True, shuffle the task list before batching. Defaults to True.

        max_running_time (float, optional): Max minutes for each process to finish `tasks_per_worker` tasks.
            Positive values are converted to seconds before being handed to the batching helpers;
            0 and negative values are passed through unchanged and, per the original notes, mean:
            If -1, no limit.
            If 0, estimate as: `num of tasks in the batch` * `avg_running_time_per_task`.

        max_wait_reset (int, optional): Maximum time to wait for workers to finish remaining tasks after
            a reset is triggered; forwarded to `wait_reset_restart`. Defaults to 120.

        print_every (int, optional): Forwarded to `Printer`. Defaults to 60.

        check_cookies_every (float, optional): Minimum number of seconds between two cookie-expiry
            checks inside the loop. Defaults to 120.

    Raises:
        ValueError: If batching yields no batches to run.
    """
    global CURRENT_TMUX_SESSION

    # Preprocess arguments
    task_list, tasks_per_worker, max_running_time, reset_after = preprocess_args(
        num_workers, task_list, tasks_per_worker, max_running_time, avg_running_time_per_task, reset_after
    )

    # Keep an unordered copy of the full task set (used for the unfinished-tasks file), then shuffle.
    task_set = set(task_list)
    task_list = shuffle_task_list(task_list) if shuffle_tasks else task_list

    # Write task list to file
    create_task_txt_file(config_dir, sorted(task_list), out_dir=result_dir, filename="tasks.txt", overwrite=True)

    # Create batches of tasks
    task_batches = create_task_batches(
        task_list,
        tasks_per_worker,
        max_run_time=max_running_time,
        avg_run_time_per_task=avg_running_time_per_task,
    )
    # Double check that there are tasks to run.
    if len(task_batches) == 0:
        raise ValueError("No tasks to run.")

    # Additional normalization if reset_after is 0.
    if reset_after == 0:
        n = min(num_workers, len(task_batches))
        reset_after = max(tasks_per_worker, int((n - 1) * tasks_per_worker))

    # Helper to print stats and run info.
    printer = Printer(print_every=print_every)

    # Print run configuration info.
    printer.print_run_info(
        agent_config=agent_config,
        config_dir=config_dir,
        total_tasks=len(task_list),
        total_batches=len(task_batches),
        tasks_per_process=tasks_per_worker,
        num_workers=num_workers,
        max_attempts_per_task=max_attempts_per_task,
        shuffle_tasks=shuffle_tasks,
        max_running_time=max_running_time,
        first_batch_max_run_time=task_batches[0].max_run_time,
        avg_running_time_per_task=avg_running_time_per_task,
        reset_after=reset_after,
        max_wait_reset=max_wait_reset,
        domains_to_reset=domains_to_reset,
    )

    # Create a tmux session that will be used for running tasks.
    tmux_session_id = create_tmux_session()
    CURRENT_TMUX_SESSION = tmux_session_id

    # Auxiliary variables
    batch: TaskBatch
    future_to_batch: dict[Future, TaskBatch] = {}  # Maps futures to their associated batch
    tasks_finished_prev: int = 0  # Number of tasks finished in the previous reset
    reset_flag: bool = False  # Flag to indicate if the environments should be reset
    reset_cookies_flag: bool = False  # Flag to indicate if the cookies should be reset
    cookies_expired: bool = False  # Flag to indicate if the cookies are expired
    recent_reset: bool  # Flag to avoid consecutive resets (can happen in edge cases).
    batch_id_counter: int = 0  # Keeps track of IDs to assign unique IDs to batches.
    last_cookies_check: float = time.time()  # Time of last cookies check
    # Maps tasks to num retries the task has suffered.
    attempts_per_task: dict[int, int] = {task_id: 0 for task_id in task_list}

    # Manager-backed dictionary, handed to every worker, used to send stop signals.
    manager = Manager()
    stop_dict = manager.dict()

    # Check if any sites in domains_to_reset are expired
    if check_cookies_expired(domains_to_reset):
        print(f"[INFO] SITES EXPIRED. Cookies will be renovated for {domains_to_reset}")
        reset_cookies(sites=domains_to_reset, exc_comb=EXC_COOKIE_SITE_COMB)
        last_cookies_check = time.time()

    # Start parallel run
    with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor:
        # Submit initial tasks.
        for _ in range(min(num_workers, len(task_batches))):
            # Get task batch and assign ID.
            batch = task_batches.pop()
            batch_id_counter += 1
            batch.batch_id = batch_id_counter

            # Submit task batch.
            future = executor.submit(
                run_sh_script_tmux,
                task_id=batch_id_counter,  # Use batch ID for debug and stop signals.
                test_config_dir=config_dir,
                stop_dict=stop_dict,
                agent_config=agent_config,
                tmux_session_id=tmux_session_id,  # Now running in tmux mode.
                task_list=batch.task_list,
                result_dir=result_dir,
                captioner_device=captioner_device,
            )
            batch.start_run_time = time.time()
            future_to_batch[future] = batch
            printer.print_dispatched_batch(batch_id_counter, batch.task_list)

        # ----------------------------------------------------------------------
        # LOOP: Wait, process, dispatch jobs
        # ----------------------------------------------------------------------
        printer.start_timer()
        while future_to_batch:
            # Send stop signals if workers have been running for too long.
            send_stop_signals(future_to_batch, stop_dict)

            # Get finished workers.
            if reset_flag or reset_cookies_flag:
                # A reset was requested in the previous iteration: perform it now.
                if future_to_batch:
                    tmux_session_id = wait_reset_restart(
                        future_to_batch=future_to_batch,
                        tmux_session_id=tmux_session_id,
                        max_wait_reset=max_wait_reset,
                        domains_to_reset=domains_to_reset,
                        stop_dict=stop_dict,
                        reset_env_flag=reset_flag,
                        reset_cookies_flag=reset_cookies_flag,
                        cookies_expired=cookies_expired,
                    )
                    reset_flag, reset_cookies_flag, recent_reset = False, False, True
                    cookies_expired = False
                    CURRENT_TMUX_SESSION = tmux_session_id
                    # After a reset every in-flight batch is treated as finished, so each one is
                    # recomputed and its pending tasks re-dispatched below.
                    done = set(future_to_batch.keys())
            else:
                # Normal polling: short timeout so the orchestrator work below runs regularly.
                done, _ = concurrent.futures.wait(
                    future_to_batch, timeout=0.5, return_when=concurrent.futures.FIRST_COMPLETED
                )
                recent_reset = False

            # -------------------------------------------------------------------
            # Orchestrator do work.
            # ------------------------------------------------------------------
            tasks_finished, _ = get_tasks_success_failed(result_dir, return_failed=False)
            printer.print_stats(tasks_finished, task_list, config_dir)

            # Check if it is time to trigger a reset.
            if len(tasks_finished) - tasks_finished_prev >= reset_after and not recent_reset:
                # Set flag for reset in next iteration
                reset_flag = True
                tasks_finished_prev = len(tasks_finished)

            # Check if needs to renovate cookies
            if time.time() - last_cookies_check >= check_cookies_every:
                printer.print_message(f"Checking cookies for {domains_to_reset}")
                # The check result is kept in `cookies_expired` and handed to the next reset.
                if cookies_expired := check_cookies_expired(domains_to_reset):
                    printer.print_message(f"SITES EXPIRED. Cookies will be renovated for {domains_to_reset}")
                    reset_cookies_flag = True
                else:
                    printer.print_message(f"Cookies still valid for {domains_to_reset}")
                last_cookies_check = time.time()

            # Save unfinished tasks for future reference
            write_unfinished_tasks(
                original_task_list=task_set,
                tasks_finished=tasks_finished,
                test_config_dir=config_dir,
                out_dir=result_dir,
            )

            # Restore API keys file if it gets close to empty.
            # This is a workaround for cases where workers finish but aren't able to write back to the .json file the API keys they were using.
            # Obs.: not the best solution, but it works for now and scenario is pretty rare.
            restore_api_keys_file(min_keys={"google": 1, "openai": 0})

            # If no worker has finished, loop.
            if not done:
                continue

            # Process finished workers.
            for future in done:
                # Get the worker's task batch.
                completed_batch = future_to_batch.pop(future)
                run_time = time.time() - completed_batch.start_run_time

                # Print finish reason for debugging purposes.
                printer.print_batch_completion(completed_batch.batch_id, run_time, future)

                # Update task_batches removing finished tasks so far and adding failed tasks.
                finished_tasks, failed_tasks = get_tasks_success_failed(result_dir, return_failed=True)

                task_batches = recompute_task_batches(
                    waiting_batches=task_batches,
                    submitted_batches=[completed_batch],
                    tasks_finished=finished_tasks,
                    tasks_failed=failed_tasks,
                    attempts_per_task=attempts_per_task,
                    tasks_per_worker=tasks_per_worker,
                    max_attempts=max_attempts_per_task,
                    max_run_time=max_running_time,
                    avg_run_time_per_task=avg_running_time_per_task,
                )

                if task_batches:
                    # Submit new batch if available.
                    # Get task batch and assign ID.
                    next_batch = task_batches.pop()
                    batch_id_counter += 1
                    next_batch.batch_id = batch_id_counter

                    # Submit new task batch.
                    future_new = executor.submit(
                        run_sh_script_tmux,
                        task_id=batch_id_counter,
                        test_config_dir=config_dir,
                        stop_dict=stop_dict,
                        agent_config=agent_config,
                        tmux_session_id=tmux_session_id,
                        task_list=next_batch.task_list,
                        result_dir=result_dir,
                        captioner_device=captioner_device,
                    )
                    future_to_batch[future_new] = next_batch
                    next_batch.start_run_time = time.time()
                    printer.print_dispatched_batch(batch_id_counter, next_batch.task_list)
                else:
                    # If no new batch is available, print message and wait for workers to finish.
                    # Obs.: a new batch can be available in future iterations if tasks fail and are retried.
                    printer.print_waiting_workers(len(future_to_batch))

    # Orchestration done.
    printer.print_all_completed(config_dir, agent_config, result_dir)


def cleanup(tmux_session_id: str | None, temp_files_dir: str = TEMP_FILES_DIR) -> None:
    """Best-effort teardown: kill the worker tmux session and remove temporary files.

    Neither step raises. Cleanup typically runs from a `finally` block, where an
    exception here would hide the original error and skip any later teardown.

    Args:
        tmux_session_id: Session to kill; nothing is killed when it is None or empty.
        temp_files_dir: Directory tree to delete. A missing directory is not an error.
    """
    if tmux_session_id:
        try:
            subprocess.run(["tmux", "kill-session", "-t", tmux_session_id], check=True)
        except (subprocess.CalledProcessError, OSError) as e:
            # The session may already be gone, or tmux may be unavailable;
            # the temporary files must still be removed.
            print(f"[WARN] Could not kill tmux session {tmux_session_id}: {e}")

    # Standard-library equivalent of `rm -rf` that never raises.
    shutil.rmtree(temp_files_dir, ignore_errors=True)


def main() -> None:
    """Script entry point: prepare the run, launch the dispatcher, then tear down.

    Steps, in order:
      1. Parse CLI arguments (overridden with fixed values when a debugger/tracer is attached).
      2. Resolve the test config directory and task ids, and validate the agent config file.
      3. Infer the domain from the config directory name and default `domains_to_reset` to it.
      4. Optionally copy the API keys backup, reset environments and reset cookies.
      5. Start the captioner, create the results directory and dump the run parameters.
      6. Run the dispatcher. Any exception it raises is printed, not re-raised.
      7. Always run `cleanup` and write the list of unfinished tasks.

    Raises:
        ValueError: If the agent config file is not found or the domain cannot be inferred.
    """
    args = parse_args()

    # Override for debugging purposes
    if sys.gettrace():
        # Basic params
        args.shuffle_tasks = True
        args.agent_config = "agent_config_base.yaml"
        args.seed = 42
        args.captioner_device = "server-cuda:0"
        args.skip_reset = True
        args.skip_cookies = True
        args.max_attempts_per_task = 2
        args.tasks_file = "evaluation_harness/task_subsets/reddit.txt"

        # # Parallelization params
        args.num_processes = 4
        args.tasks_per_process = 5
        args.max_running_time = 0
        args.avg_running_time_per_task = 180

        # # Environment reset params
        # args.reset_after = 10
        # args.max_wait_reset = 120
        args.domains_to_reset = ""

    # Get config dir and task list
    test_config_dir, task_ids = get_task_list(tasks_file=args.tasks_file, config_dir=args.config_dir)

    if not args.agent_config or not os.path.isfile(f"{AGENTS_CONFIG_DIR}/{args.agent_config}"):
        raise ValueError(f"Agent config {args.agent_config} not found in {AGENTS_CONFIG_DIR}.")

    # The domain is the config directory name without "test_" (e.g. "test_reddit" -> "reddit").
    domain = Path(test_config_dir).name.replace("test_", "")

    # Force domain recognition for reset, cookie renovation, other purposes.
    if not domain:
        raise ValueError(f"Could not infer domain from {test_config_dir}.")

    args.domains_to_reset = args.domains_to_reset or domain

    # Copy any existing api_keys.json to api_keys_copy.json
    if os.path.isfile(API_KEYS_REPO) and args.copy_api_keys:
        # Copy to api_keys_copy.json
        shutil.copy(API_KEYS_REPO, API_KEYS_PATH)

    # Skip initial reset of all envs if requested
    if not args.skip_reset:
        print("[INFO] Resetting environments...")
        # TODO fix the logic here
        if "all" in args.domains_to_reset_on_init:
            domains = "all_vwa"
        else:
            domains = domain
        reset_environments(domains=domains)

    # Skip cookies creation if requested
    if not args.skip_cookies:
        print("[INFO] Resetting cookies...")
        if "all" in args.domains_to_reset_on_init:
            reset_cookies(exc_comb=False)
        else:
            reset_cookies(sites=[domain], exc_comb=EXC_COOKIE_SITE_COMB)

    # Host captioner on tmux session, if not running
    print("[INFO] Checking / starting captioner...")
    start_captioner(
        model_name="Salesforce/blip2-flan-t5-xl",
        model_device=args.captioner_device,
        tmux_session_name="vwa_captioner",
    )

    # Build result dir
    results_dir = DEFAULT_RESULTS_DIR
    model = get_agent_attribute(f"{AGENTS_CONFIG_DIR}/{args.agent_config}", "executor_agent:lm_config:model")
    if model:
        results_dir = f"{results_dir}/{model}"
    date_ann = datetime.now().strftime("%Y-%m-%d-%H%M")
    results_dir = f"{results_dir}/p_run-{domain}-{date_ann}"
    # Example: results/gpt-4o-mini-2024-07-18/p_run-reddit-2025-02-24-1430
    os.makedirs(results_dir, exist_ok=True)
    # Record the effective arguments next to the results.
    with open(f"{results_dir}/prun_params.json", "w") as f:
        json.dump(vars(args), f, indent=2)

    print(f"[INFO] Running tasks from {test_config_dir} with {len(task_ids)} tasks.")

    print(f"\n[INFO] Results will be saved in {results_dir}")

    # Set seed
    set_seed(args.seed)

    # Run tasks dispatcher
    try:
        run_tasks_dispatcher(
            task_list=task_ids,
            num_workers=args.num_processes,
            config_dir=test_config_dir,
            agent_config=args.agent_config,
            result_dir=results_dir,
            captioner_device=args.captioner_device,
            tasks_per_worker=args.tasks_per_process,
            reset_after=args.reset_after,
            domains_to_reset=args.domains_to_reset,
            max_running_time=args.max_running_time,
            avg_running_time_per_task=args.avg_running_time_per_task,
            max_attempts_per_task=args.max_attempts_per_task,
            shuffle_tasks=args.shuffle_tasks,
            max_wait_reset=args.max_wait_reset,
        )

    except Exception as e:
        print(f"An error occurred during parallel run: {e}")
    finally:
        # Nested try/finally: if cleanup raises, the unfinished-task list is
        # still written before the exception propagates.
        try:
            cleanup(CURRENT_TMUX_SESSION, temp_files_dir=TEMP_FILES_DIR)
        finally:
            # Save unfinished tasks for future reference
            write_unfinished_tasks(
                original_task_list=task_ids,
                test_config_dir=test_config_dir,
                out_dir=results_dir,
                result_dir=results_dir,
            )


# Run the pipeline only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
