#!/usr/bin/env python3
"""
Launch AntFall training runs, one per (dispersion objective, seed) pair, covering
the objectives described in the estimation chapter. Each objective can carry its
own intrinsic reward scale and final manager/controller reward balance so the
dispersion bonus can be kept meaningful without overwhelming the external task
reward. All jobs are spawned in parallel and pinned to the provided GPU list in
a round-robin manner.

Example
-------
  python scripts/run_antfall_dispersion.py --gpus 0,1 --seeds 19
"""

from __future__ import annotations

import argparse
import os
import shlex
import subprocess
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import List


@dataclass(frozen=True)
class DispersionRun:
    """Configuration for a single dispersion objective.

    Instances are immutable and hashable: ``extra_args`` is normalised to a
    tuple on construction, so a list may still be passed by callers.
    """

    # Objective identifier forwarded as ``main.py --dispersion``.
    key: str
    # Human-readable name embedded in the ``--algo`` tag of each run.
    label: str
    # Forwarded as ``--man_rew_scale``.
    man_rew_scale: float = 1.0
    # Forwarded as ``--man_ctrl_rew_balance_end``.
    balance_end: float = 0.5
    # Additional CLI tokens appended verbatim to the run's command line.
    extra_args: tuple[str, ...] = field(default_factory=tuple)

    def __post_init__(self) -> None:
        # A list here would make the generated __hash__ raise TypeError and
        # let callers mutate a "frozen" instance; store an immutable copy.
        # frozen=True forbids normal assignment, so bypass it this once.
        object.__setattr__(self, "extra_args", tuple(self.extra_args))


# One entry per dispersion objective swept by this script, launched in this
# order. Every entry currently relies on the DispersionRun defaults for
# man_rew_scale and balance_end; override those fields per entry to keep a
# particular dispersion bonus from dominating the task reward. The "cvar" and
# "chance" objectives get their sampling flags appended by launch_runs from
# --cvar_alpha / --chance_eps / --cvar_samples rather than via extra_args.
DISPERSION_RUNS: List[DispersionRun] = [
    DispersionRun(key="var", label="total_variance"),
    DispersionRun(key="logdet", label="log_det"),
    DispersionRun(key="logtrace", label="log_trace"),
    DispersionRun(key="dir_perp", label="directional_variance"),
    DispersionRun(key="maxeig", label="max_eigenvalue"),
    DispersionRun(key="anisotropy", label="anisotropy"),
    DispersionRun(key="w2", label="wasserstein"),
    DispersionRun(key="chance", label="chance_constraint"),
    DispersionRun(key="cvar", label="cvar_tail"),
]


def parse_list(raw: str) -> List[int]:
    """Convert a comma separated string into a list of ints.

    Whitespace around tokens is ignored, as are empty tokens (for example
    from a trailing comma or a whitespace-only string). This matches how
    ``--gpus`` is parsed. An empty or ``None`` input yields ``[]``.

    Raises:
        ValueError: if a non-empty token is not an integer.
    """
    if not raw:
        return []
    tokens = (token.strip() for token in raw.split(","))
    return [int(token) for token in tokens if token]


def _objective_args(run: DispersionRun, args: argparse.Namespace) -> List[str]:
    """Return the run's extra CLI tokens plus objective-specific sampling flags."""
    extra_args = list(run.extra_args)
    if run.key == "cvar":
        extra_args += [
            "--disp-alpha",
            f"{args.cvar_alpha}",
            "--disp-samples",
            f"{args.cvar_samples}",
        ]
    elif run.key == "chance":
        # The chance objective reuses --cvar_samples as its sample count.
        extra_args += [
            "--disp-eps",
            f"{args.chance_eps}",
            "--disp-samples",
            f"{args.cvar_samples}",
        ]
    return extra_args


def _build_command(
    run: DispersionRun,
    args: argparse.Namespace,
    seed: int,
    algo_tag: str,
    log_dir: Path,
    model_dir: Path,
) -> List[str]:
    """Assemble the ``main.py`` argv for one (objective, seed) pair, without extra args."""
    cmd = [
        sys.executable,
        "main.py",
        "--env_name", "AntFall",
        "--algo", algo_tag,
        "--dispersion", run.key,
        "--seed", f"{seed}",
        # int(): argparse does not run type=int over non-string defaults, so a
        # float default (e.g. 5e6) would otherwise be rendered as "5000000.0".
        "--max_timesteps", f"{int(args.max_timesteps)}",
        "--eval_freq", f"{int(args.eval_freq)}",
        "--log_dir", str(log_dir),
        "--model_dir", str(model_dir),
        "--manager_propose_freq", "10",
        "--train_manager_freq", "10",
        "--man_rew_scale", f"{run.man_rew_scale}",
        "--man_ctrl_rew_balance_start", f"{args.man_ctrl_balance_start}",
        "--man_ctrl_rew_balance_end", f"{run.balance_end}",
        "--man_ctrl_rew_balance_steps", f"{args.man_ctrl_balance_steps}",
        "--reach_warmup_samples", f"{args.reach_warmup_samples}",
        "--reach_warmup_rounds", f"{args.reach_warmup_rounds}",
        "--log_dispersion_stats",
    ]
    if args.freeze_worker:
        cmd.append("--freeze_worker")
    return cmd


def launch_runs(args: argparse.Namespace) -> int:
    """Spawn one ``main.py`` training process per (objective, seed) pair.

    Each child runs with the repository root (the parent of this script's
    directory) as its working directory. ``CUDA_VISIBLE_DEVICES`` is set to
    the ``--gpus`` ids in round-robin order over all launched runs.

    Args:
        args: Namespace produced by ``build_parser``.

    Returns:
        0 for a dry run; otherwise the exit code reported by
        ``wait_for_processes``.
    """
    repo_root = Path(__file__).resolve().parents[1]
    # Absolute paths: children run with cwd=repo_root, so a relative path would
    # point elsewhere than the directory created here unless the launcher was
    # itself started from the repo root.
    log_dir = Path(args.log_dir).expanduser().resolve()
    model_dir = Path(args.model_dir).expanduser().resolve()
    log_dir.mkdir(parents=True, exist_ok=True)
    model_dir.mkdir(parents=True, exist_ok=True)

    seeds = parse_list(args.seeds) or [args.seed]
    gpu_ids = [gpu.strip() for gpu in args.gpus.split(",") if gpu.strip()] or ["0"]

    processes: List[tuple[str, subprocess.Popen]] = []
    run_index = 0
    try:
        for run in DISPERSION_RUNS:
            extra_args = _objective_args(run, args)
            for seed in seeds:
                algo_tag = f"{args.experiment_prefix}_{run.label}_s{seed}"
                cmd = _build_command(run, args, seed, algo_tag, log_dir, model_dir)
                cmd.extend(extra_args)

                gpu = gpu_ids[run_index % len(gpu_ids)]
                run_index += 1
                env = os.environ.copy()
                env["CUDA_VISIBLE_DEVICES"] = gpu

                printable = " ".join(shlex.quote(token) for token in cmd)
                print(f"[GPU {gpu}] Launching {algo_tag}: {printable}")
                if args.dry_run:
                    continue

                proc = subprocess.Popen(cmd, cwd=str(repo_root), env=env)
                processes.append((algo_tag, proc))
    except BaseException:
        # A failed spawn (or Ctrl+C) part-way through would otherwise leave
        # the runs already started orphaned; stop them before propagating.
        for _, proc in processes:
            proc.terminate()
        for _, proc in processes:
            proc.wait()
        raise

    if args.dry_run:
        return 0

    return wait_for_processes(processes)


def wait_for_processes(
    processes: List[tuple[str, subprocess.Popen]],
    terminate_timeout: float = 30.0,
) -> int:
    """Wait for all subprocesses and report the first non-zero exit code.

    Processes are waited on in list order. Every failing run is reported, and
    the returned value is the first non-zero exit code in that order. It is 0
    if all runs succeeded or ``processes`` is empty.

    On KeyboardInterrupt, every still-running child is asked to terminate.
    Each child is then given up to ``terminate_timeout`` seconds to exit
    before being killed, and 1 is returned.
    """
    exit_code = 0
    try:
        for name, proc in processes:
            ret = proc.wait()
            if ret != 0:
                print(f"[ERROR] Run '{name}' exited with code {ret}.")
                if exit_code == 0:
                    exit_code = ret
        return exit_code
    except KeyboardInterrupt:
        print("\n[WARN] KeyboardInterrupt received. Terminating child processes...")
        for _, proc in processes:
            if proc.poll() is None:
                proc.terminate()
        for _, proc in processes:
            try:
                proc.wait(timeout=terminate_timeout)
            except subprocess.TimeoutExpired:
                # Child ignored the termination request; escalate so the
                # launcher cannot hang forever on shutdown.
                proc.kill()
                proc.wait()
        return 1


def build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser for the AntFall dispersion sweep.

    Integer options use integer defaults: argparse applies ``type=`` only to
    string defaults, so a float default (e.g. ``5e6``) would reach the child
    command line as ``"5000000.0"``.
    """
    parser = argparse.ArgumentParser(description="Parallel AntFall dispersion sweeps.")
    parser.add_argument("--seeds", default="", help="Comma separated list of seeds.")
    parser.add_argument("--seed", type=int, default=19, help="Fallback seed if --seeds is empty.")
    parser.add_argument("--gpus", default="0", help="Comma separated CUDA device ids.")
    parser.add_argument("--max_timesteps", type=int, default=5_000_000, help="Training horizon.")
    parser.add_argument("--eval_freq", type=int, default=5_000, help="Evaluation frequency.")
    parser.add_argument("--man_ctrl_balance_start", type=float, default=0.0,
                        help="Initial intrinsic weighting for the manager.")
    parser.add_argument("--man_ctrl_balance_steps", type=int, default=400_000,
                        help="Steps over which the intrinsic weighting anneals.")
    parser.add_argument("--cvar_alpha", type=float, default=0.1,
                        help="Tail mass for the CVaR dispersion objective.")
    parser.add_argument("--cvar_samples", type=int, default=128,
                        help="Number of MDN samples when estimating CVaR or chance.")
    parser.add_argument("--chance_eps", type=float, default=2.0,
                        help="Radius ε for the chance dispersion objective.")
    parser.add_argument("--reach_warmup_samples", type=int, default=8_000,
                        help="Reach buffer samples before enabling shaping.")
    parser.add_argument("--reach_warmup_rounds", type=int, default=2,
                        help="Reach-net optimisation rounds before enabling shaping.")
    parser.add_argument("--log_dir", default="./logs/antfall_dispersion",
                        help="Base directory for TensorBoard logs.")
    parser.add_argument("--model_dir", default="./models/antfall_dispersion",
                        help="Base directory for model checkpoints.")
    parser.add_argument("--experiment_prefix", default="S3_AntFall_oct27",
                        help="Prefix used for the --algo tag.")
    parser.add_argument("--freeze_worker", action="store_true",
                        help="Enable --freeze_worker for all runs.")
    parser.add_argument("--dry_run", action="store_true",
                        help="Print planned commands without executing them.")
    return parser


def main() -> int:
    """Parse ``sys.argv`` and launch the sweep, returning the process exit code."""
    namespace = build_parser().parse_args()
    return launch_runs(namespace)


if __name__ == "__main__":
    # sys.exit raises SystemExit with the returned code.
    sys.exit(main())
