# autopep8: off
# fmt: off
import datetime
import os
import shutil
import subprocess
import threading
import time
from typing import Any

from llms.constants import API_KEYS_PATH, API_KEYS_REPO
from utils.debug_utils import set_env_variables
from utils.signal_utils import signal_manager
from vwa_utils.captioner_utils import is_captioner_running

# Usage: python -u -m scripts.runs.p_run_many
# Run detached in the background:
# nohup python -u -m scripts.runs.p_run_many > p_run_many.log 2>&1 & disown
#===============================================
# Configs
#===============================================

#-----------------------------------------------
# Params for this script
#-----------------------------------------------
# TODO: before running the final experiments:
# - Add back the reset for REDDIT.
# - Add a function to provide images when the answer is a URL.

# Dict of agent config and the tasks to run on
# Task-list files. Each entry becomes the `-t` argument of one parallel_run launch.
task_subsets = [
    "evaluation_harness/task_subsets/shopping.txt",
    "evaluation_harness/task_subsets/reddit.txt",
    "evaluation_harness/task_subsets/classifieds.txt",
    # "results/gemini-2.5-flash-preview-04-17/p_run-reddit-2025-05-13-0631/unfinished_tasks.txt",
    # "results/gemini-2.5-flash-preview-04-17/p_run-shopping-2025-05-13-0631/unfinished_tasks.txt"
]
# Maps agent config file -> list of task-list files to run it on.
# One parallel_run instance is launched per (agent config, task list) pair,
# forwarded as `-a <agent config> -t <task list>`.
run_configs = {
    # "agent_config_t_100_2p_tri_webk.yaml": task_subsets,
    # "agent_config_t_100_2p_tri_ctrl_f.yaml": task_subsets,
    # "agent_config_t_100_2p_tri.yaml": task_subsets,
    "agent_config_t_100_2p_tri_expert_best.yaml": task_subsets,
    # "agent_config_t_100_2p_tri_nocot_expert.yaml": task_subsets,
}
reset_envs = True       # True => run `scripts/environments/start_reset_envs.sh all_vwa` before launching anything.
reset_cookies = True    # True => delete `cookies_dir` and regenerate it with `browser_env.auto_login`.
replace_keys = False    # True => in Main, copy the API keys file when `API_KEYS_REPO` exists.
cookies_dir = "./.auth" # Folder for login cookies (passed to auto_login as `--auth_folder`).

# True  => launch every (agent config, task list) pair concurrently (1 s apart), then wait for all of them.
# False => next launch of parallel_run.py will wait for the previous one to finish.
non_blocking = True


#-----------------------------------------------
# Params for parallel_run.py instances
#-----------------------------------------------
# Every value below is forwarded to parallel_run by `build_command`; the flag is noted per line.
num_processes = 1                       # `-n`: num of processes executing `run.py` at the same time.
tasks_per_process = 5                 # `-b`: max num of tasks in each `run.py` execution (i.e., size of a 'batch' of tasks).
max_running_time = -1                   # `-mrt`: max runtime (mins) of each run.py execution: -1 => no limit; 0 => dynamically set based on `avg_run_time_per_task` * `# tasks in the batch`.
avg_run_time_per_task = 30*60            # `-art`: avg runtime (secs) of each task. Used to compute the max running time if `max_running_time` is set to 0.

# Environment reset params
reset_after = -1                        # `-ra`: reset environment after N tasks. -1 => no reset; 0 => estimate from the batch size (presumably `tasks_per_process` — TODO confirm), the number of workers and `num_processes`.
max_wait_reset = 5                      # `-mr`: when a reset is triggered, wait `max_wait_reset` secs before resetting envs, giving a running task a chance to complete.
domains_to_reset = ""                   # `-d`: domains to reset. If empty, inferred from `test_config_dir` (per parallel_run; an empty string is still passed).

max_attempts_per_task = 3              # `-ma`: if a task execution fails, it will be retried up to `max_attempts_per_task` times.
                                          # Note: tasks stopped early due to env resets aren't included in the count.

captioner_device = "server-cuda:0"     # `-cd`: device to hold the captioner; also used when Main starts the captioner itself.

skip_initial_reset = True             # `-sr` when True. False => full start/reset of envs before parallel execution.
                                            # Note: this is not the reset *during* the task dispatches.
skip_initial_cookies = True           # `-sc` when True. False => full cookies creation before parallel execution.
                                            # Note: this is not the cookies creation *during* the task dispatches.

domain_reset_on_init = "domain"       # `-di`: forwarded only when non-empty. NOTE(review): meaning of "domain" is defined by parallel_run — confirm there.

shuffle_tasks = True                   # `-st` when True: shuffle the task list before parallel execution.
seed = 42                              # `-seed`: random seed.
                                            

copy_api_keys = False                  # `-ck` when True: copy api_keys.json to api_keys_copy.json; overwrite if exists.

#===============================================
# Helpers
#===============================================
def build_command(agent_config: str, task_list: str) -> list[str]:
    """
    Return the argv list that launches one ``scripts.runs.parallel_run`` instance.

    Valued options are taken from the module-level settings and emitted as
    ``<flag> <value>`` pairs. Boolean settings become bare flags, appended only
    when enabled. ``-di`` is added last, and only for a non-empty value.
    """
    # "-u": unbuffered output so the streamed log lines appear immediately.
    argv = ["python", "-u", "-m", "scripts.runs.parallel_run"]

    valued_options = (
        ("-a", agent_config),
        ("-t", task_list),
        ("-n", str(num_processes)),
        ("-b", str(tasks_per_process)),
        ("-mrt", str(max_running_time)),
        ("-art", str(avg_run_time_per_task)),
        ("-ra", str(reset_after)),
        ("-d", domains_to_reset),
        ("-mr", str(max_wait_reset)),
        ("-ma", str(max_attempts_per_task)),
        ("-cd", captioner_device),
        ("-seed", str(seed)),
    )
    for option, value in valued_options:
        argv += [option, value]

    toggles = (
        (shuffle_tasks, "-st"),
        (skip_initial_cookies, "-sc"),
        (skip_initial_reset, "-sr"),
        (copy_api_keys, "-ck"),
    )
    argv.extend(flag for enabled, flag in toggles if enabled)

    if domain_reset_on_init:
        argv += ["-di", domain_reset_on_init]
    return argv

def get_header(agent_config: str, task_list: str) -> str:
    """
    Return the banner written at the top of each per-process log file.

    The agent config and task list are listed between two 80-dash rules,
    preceded by one newline and followed by a blank line.
    """
    rule = "-" * 80
    parts = [
        "",
        rule,
        f"Running agent config: {agent_config}",
        f"Task list: {task_list}",
        rule,
        "",
        "",
    ]
    return "\n".join(parts)

def stream_output(proc: subprocess.Popen[Any], log_file: str) -> None:
    """
    Echo a subprocess's output to this script's stdout and append it to ``log_file``.

    Reads ``proc.stdout`` line by line until EOF, then waits for the process so
    it is reaped and ``proc.returncode`` is set when this returns.

    Args:
        proc: Process started with ``stdout=subprocess.PIPE`` in text mode
            (stderr may be merged into it).
        log_file: Path of the per-process log file; opened in append mode.
    """
    if proc.stdout is None:
        # Nothing was piped, so there is nothing to stream; still reap the child.
        proc.wait()
        return
    with open(log_file, "a") as lf:
        for line in proc.stdout:
            # Trim only the line terminator so leading indentation (e.g. traceback
            # frames) is preserved on the console; the file gets the raw line.
            print(line.rstrip("\n"))
            lf.write(line)
            lf.flush()
    # EOF on stdout does not mean the process has been reaped; wait for it.
    proc.wait()

# Cleanup function and signal handler
def cleanup_processes() -> None:
    """
    Force-kill (SIGKILL) lingering run.py / parallel_run processes.

    The kill is system-wide: every ``ps aux`` entry whose command line matches
    the pattern below is killed, not only children of this script. The tmux
    server is NOT killed.

    The module-path alternative (scripts.runs.parallel_run) is required because
    ``build_command`` launches parallel_run via ``python -m``. Such a command
    line contains no ``.py`` suffix, so a ``parallel_run.py``-only pattern
    never matched the processes this script starts.
    NOTE(review): if run.py is also started via ``-m``, add its module path too.
    """
    os.system(
        r"ps aux | grep -E 'run\.py|parallel_run\.py|scripts\.runs\.parallel_run'"
        r" | grep -v grep | awk '{print $2}' | xargs -r kill -9"
    )
    print("Cleanup complete.", flush=True)

# Run cleanup_processes through the project's signal manager.
signal_manager.add_cleanup_function(cleanup_processes)

#===============================================
# Main
#===============================================
def _new_log_file(agent_config: str, task_list: str, launch_index: int) -> str:
    """
    Create a per-launch log file under ``logs/``, write its header, and return its path.

    The launch index is appended to the second-resolution timestamp. This keeps
    two launches started within the same second (e.g. a blocking-mode run that
    exits immediately) from appending into the same file.
    """
    stamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    log_file = f"logs/prun_many_{stamp}_{launch_index}.log"
    with open(log_file, "a") as lf:
        lf.write(get_header(agent_config, task_list))
    return log_file


def _launch_parallel_run(agent_config: str, task_list: str) -> subprocess.Popen[str]:
    """Start one parallel_run instance with stderr merged into a text-mode stdout pipe."""
    return subprocess.Popen(
        build_command(agent_config, task_list),
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
        cwd=os.getcwd(),
    )


def _report_exit(proc: subprocess.Popen[str], agent_config: str, task_list: str) -> None:
    """Wait for ``proc`` (reaping it) and print a warning if it exited non-zero."""
    return_code = proc.wait()
    if return_code != 0:
        print(
            f"WARNING: parallel_run for {agent_config} / {task_list} exited with code {return_code}.",
            flush=True,
        )


try:
    os.makedirs("logs", exist_ok=True)

    # Start reset envs
    if reset_envs:
        print("Starting reset envs...", flush=True)
        subprocess.run(["scripts/environments/start_reset_envs.sh", "all_vwa"], check=True)

    # Set env variables
    print("Setting env variables...", flush=True)
    set_env_variables(arg1="local_vwebarena", arg2="localhost")
    print("Env variables set.", flush=True)

    # Reset cookies: wipe the folder, then log in again to regenerate it.
    if reset_cookies:
        print("Running auto login...", flush=True)
        shutil.rmtree(cookies_dir, ignore_errors=True)
        subprocess.run(["python", "-m", "browser_env.auto_login", "--auth_folder", cookies_dir], check=True)

    # Start captioner (only if one is not already running)
    if not is_captioner_running():
        print("Starting captioner...", flush=True)
        result = subprocess.run([
            "python", "-m", "vwa_utils.captioner_utils",
            "--model_name", "Salesforce/blip2-flan-t5-xl",
            "--model_device", captioner_device,
            "--port", "9555",
            "--endpoint", "http://localhost:9555/caption/",
            "--conda_env", "vwebarena",
            "--tmux_session_name", "vwa_captioner"], check=True)
        print(result, flush=True)

    # Replace the API keys file.
    # NOTE(review): the existence check is on API_KEYS_REPO, so it is used as the copy
    # source (copying API_KEYS_PATH onto itself raises shutil.SameFileError).
    # Confirm the intended direction.
    if replace_keys:
        if os.path.exists(API_KEYS_REPO):
            shutil.copy(API_KEYS_REPO, API_KEYS_PATH)

    # Launch parallel runs
    print(f"Launching parallel runs in {'non-blocking' if non_blocking else 'blocking'} mode.", flush=True)
    print(f"Configs: {run_configs}", flush=True)

    # Flatten {agent_config: [task_list, ...]} into ordered (agent_config, task_list) launches.
    jobs = [
        (agent_config, task_list)
        for agent_config, task_lists in run_configs.items()
        for task_list in task_lists
    ]

    if non_blocking:
        # Non-blocking mode: start every process, each streamed by its own thread.
        launched = []
        for launch_index, (agent_config, task_list) in enumerate(jobs):
            log_file = _new_log_file(agent_config, task_list, launch_index)
            proc = _launch_parallel_run(agent_config, task_list)
            t = threading.Thread(target=stream_output, args=(proc, log_file))
            t.start()
            launched.append((proc, t, agent_config, task_list))
            time.sleep(1)  # Stagger launches by one second.

        # Wait for every streaming thread to drain, then reap its process.
        for proc, t, agent_config, task_list in launched:
            t.join()
            _report_exit(proc, agent_config, task_list)

        print("All processes finished.", flush=True)
    else:
        # Blocking mode: run each process to completion before starting the next.
        for launch_index, (agent_config, task_list) in enumerate(jobs):
            log_file = _new_log_file(agent_config, task_list, launch_index)
            proc = _launch_parallel_run(agent_config, task_list)
            stream_output(proc, log_file)
            _report_exit(proc, agent_config, task_list)
except Exception as e:
    print(f"Error: {e}", flush=True)
    cleanup_processes()
    # Re-raise so the traceback is shown and the script exits non-zero
    # instead of reporting success after a failure.
    raise
