#!/usr/bin/env python3
"""Parallel launcher for AntFall with a frozen worker policy across a fixed set of seeds.

- Spawns one subprocess per seed.
- Shares a single GPU via CUDA_VISIBLE_DEVICES.
- Optionally pins each process to a disjoint CPU core range using `taskset` if available.
- Sets per-process OMP/MKL threads based on detected cores per job (overridable via env).
"""

import os
import shlex
import shutil
import sys
import subprocess
from typing import List, Tuple

# ----------------------------
# Config (env-overridable)
# ----------------------------
# Seeds used when the SEEDS environment variable is absent or blank.
DEFAULT_SEEDS = [7, 19, 23]

# GPU identifier, kept as a string because it is only ever passed along as text.
GPU = os.getenv("GPU", "0")

# Manager hyper-parameters forwarded unchanged on the command line.
MAN_REW_SCALE = os.getenv("MAN_REW_SCALE", "0.24")
MAN_NOISE_SIGMA = os.getenv("MAN_NOISE_SIGMA", "0.6")

# Directory locations forwarded unchanged on the command line.
MODEL_DIR = os.getenv("MODEL_DIR", "./saved_models")
LOG_DIR = os.getenv("LOG_DIR", "./logs")
WORKER_MODEL_DIR = os.getenv("WORKER_MODEL_DIR", "./models")

# SEEDS is a whitespace-separated list of integers (e.g., SEEDS="1 2 3").
# A missing or blank value splits into an empty list, which falls back to
# DEFAULT_SEEDS.
SEEDS = [int(token) for token in os.getenv("SEEDS", "").split()] or DEFAULT_SEEDS

# Request CPU pinning even when there are fewer cores than jobs.
FORCE_TASKSET = os.getenv("FORCE_TASKSET", "").lower() in ("1", "true", "yes")

# ----------------------------
# Static base args (match original)
# ----------------------------
_ENV_AND_MANAGER_ARGS = [
    "--env_name", "AntFall",
    "--manager_propose_freq", "10",
    "--train_manager_freq", "10",
    "--man_ctrl_rew_balance_start", "0.1",
    "--man_ctrl_rew_balance_end", "0.32",
    "--man_ctrl_rew_balance_steps", "320000",
    "--man_rew_scale", MAN_REW_SCALE,
    "--man_noise_sigma", MAN_NOISE_SIGMA,
    "--reach_warmup_samples", "3000",
    "--reach_warmup_rounds", "1",
]
_FROZEN_WORKER_ARGS = [
    "--freeze_worker",
    "--worker_model_dir", WORKER_MODEL_DIR,
    "--worker_env_name", "AntFall",
    "--worker_algo", "S3_AntFall_seed_7",
]
_OUTPUT_ARGS = [
    "--model_dir", MODEL_DIR,
    "--log_dir", LOG_DIR,
    "--save_periodic",
]

# Arguments shared by every run; per-seed arguments are added by the launcher.
BASE_ARGS: List[str] = _ENV_AND_MANAGER_ARGS + _FROZEN_WORKER_ARGS + _OUTPUT_ARGS

def detect_cores() -> int:
    """Return the CPU count reported by ``os.cpu_count()``, never less than 1."""
    reported = os.cpu_count()
    # os.cpu_count() may return None when the count cannot be determined.
    if not reported or reported < 1:
        return 1
    return reported

def compute_affinities(num_jobs: int, total_cores: int) -> List[Tuple[int, int]]:
    """Split cores ``0..total_cores-1`` into one inclusive range per job.

    Args:
        num_jobs: Number of ranges to produce.
        total_cores: Number of cores available, indexed from 0.

    Returns:
        A list of ``num_jobs`` ``(start, end)`` tuples with inclusive bounds.

        * When ``num_jobs <= total_cores`` the ranges are contiguous and
          disjoint; each holds ``total_cores // num_jobs`` cores and the last
          job additionally receives the remainder.
        * When ``num_jobs > total_cores`` there are not enough cores for
          disjoint ranges, so each job gets a single core assigned round-robin
          (cores are shared). Every returned index is a valid core.
        * When ``num_jobs < 1`` the list is empty; when ``total_cores < 1``
          every job gets the placeholder ``(0, 0)``.
    """
    if num_jobs < 1:
        return []
    if total_cores < 1:
        return [(0, 0)] * num_jobs

    if num_jobs > total_cores:
        # More jobs than cores: wrap around instead of emitting indices past
        # the last core, which a pinning tool would reject.
        return [(job % total_cores, job % total_cores) for job in range(num_jobs)]

    width = total_cores // num_jobs
    affinities = [(job * width, (job + 1) * width - 1) for job in range(num_jobs)]
    # The last job absorbs whatever is left over after the even split.
    affinities[-1] = (affinities[-1][0], total_cores - 1)
    return affinities

def main() -> None:
    """Launch one ``main.py`` subprocess per seed and exit with an aggregate status.

    Behaviour:
      * Every child gets ``CUDA_VISIBLE_DEVICES`` set to ``GPU`` and
        ``OMP_NUM_THREADS`` / ``MKL_NUM_THREADS`` set to the per-job core count
        (override with ``OMP_NUM_THREADS_OVERRIDE`` / ``MKL_NUM_THREADS_OVERRIDE``).
      * When ``taskset`` is on PATH and there are at least as many cores as
        jobs (or ``FORCE_TASKSET`` is set), each child is pinned to the core
        range returned by ``compute_affinities``.
      * With ``DRY_RUN`` set to 1/true/yes the commands are printed instead of
        being executed.

    Exits with status 0 when every run returns 0, and 1 when no seeds are
    configured or any run fails. Returns normally when nothing was launched.

    Raises:
        OSError: If a child process cannot be started; children that were
            already started are terminated first.
        KeyboardInterrupt: Re-raised after terminating running children.
    """
    if not SEEDS:
        print("[ERROR] No seeds configured. Provide SEEDS=\"s1 s2 ...\".", file=sys.stderr)
        sys.exit(1)

    total_cores = detect_cores()
    num_jobs = len(SEEDS)
    cores_per_job = max(total_cores // num_jobs, 1)

    # Use taskset only if present and helpful
    have_taskset = shutil.which("taskset") is not None
    use_taskset = have_taskset and (total_cores >= num_jobs or FORCE_TASKSET)

    print(f"[INFO] Launching {num_jobs} AntFall runs on GPU {GPU}.")
    print(f"[INFO] System reports {total_cores} CPU cores; assigning ≈{cores_per_job} per job.")
    if use_taskset:
        print("[INFO] Using taskset for CPU affinity.")
    elif FORCE_TASKSET and not have_taskset:
        print("[WARN] FORCE_TASKSET requested but `taskset` not found; proceeding without affinity.")
    else:
        print("[INFO] Not using taskset (either unavailable or not beneficial).")

    # Without pinning the range is only a placeholder so the zip below lines up.
    affinities = compute_affinities(num_jobs, total_cores) if use_taskset else [(0, total_cores - 1)] * num_jobs

    # Thread-count overrides and the dry-run flag do not vary per job.
    omp_threads = os.environ.get("OMP_NUM_THREADS_OVERRIDE") or str(cores_per_job)
    mkl_threads = os.environ.get("MKL_NUM_THREADS_OVERRIDE") or str(cores_per_job)
    dry_run = os.environ.get("DRY_RUN", "").lower() in {"1", "true", "yes"}

    # The environment is identical for every child, so build it once.
    child_env = os.environ.copy()
    child_env["CUDA_VISIBLE_DEVICES"] = GPU
    child_env["OMP_NUM_THREADS"] = omp_threads
    child_env["MKL_NUM_THREADS"] = mkl_threads
    if use_taskset:
        # Default OpenMP thread binding only when each job has its own core
        # range. Without disjoint affinity masks, bound threads from different
        # jobs can land on the same leading cores and contend with each other.
        # Values already present in the caller's environment are kept as-is.
        child_env.setdefault("OMP_PROC_BIND", "TRUE")
        child_env.setdefault("OMP_PLACES", "cores")

    launched: List[Tuple[int, subprocess.Popen]] = []
    failures: List[Tuple[int, int]] = []
    try:
        for idx, (seed, (start, end)) in enumerate(zip(SEEDS, affinities), start=1):
            cmd: List[str] = [
                sys.executable, "main.py",
                "--algo", f"HAWK_AntFall_s{seed}_frozen",
                "--seed", str(seed),
                "--gid", GPU,
                *BASE_ARGS,
            ]

            # Optional taskset prefix
            if use_taskset:
                affinity_note = f"{start}-{end}"
                launch = ["taskset", "-c", affinity_note] + cmd
            else:
                affinity_note = "all"
                launch = cmd

            print(f"[RUN {idx}/{num_jobs}] Seed {seed} → man_rew_scale={MAN_REW_SCALE} cores={affinity_note} OMP={omp_threads}")

            # For debugging: show the exact command instead of running it
            if dry_run:
                print(" ".join(shlex.quote(x) for x in launch))
                continue

            launched.append((seed, subprocess.Popen(launch, env=child_env)))

        # Wait for completion and record which seeds failed
        for seed, proc in launched:
            ret = proc.wait()
            if ret != 0:
                failures.append((seed, ret))
    except (OSError, KeyboardInterrupt):
        # Do not leave already-started runs behind when a launch fails or the
        # user interrupts the launcher.
        for _, proc in launched:
            if proc.poll() is None:
                proc.terminate()
        raise

    if not launched:
        print("[WARN] Nothing was launched (DRY_RUN?). Exiting.")
        return

    if not failures:
        print("All AntFall runs finished successfully.")
        sys.exit(0)

    for seed, ret in failures:
        print(f"[FAIL] Seed {seed} exited with code {ret}.", file=sys.stderr)
    print("One or more AntFall runs failed.", file=sys.stderr)
    sys.exit(1)

# Run the launcher only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
