from __future__ import annotations

import argparse
import os
import shlex
import subprocess
import sys
import time
from collections import deque
from typing import Dict, List, Tuple

import yaml

from phijax.experiments.sweeps import assemble_runs


# Bookkeeping entry for one in-flight run: the child process launched for it and
# the index of the expanded sweep run it is executing.
ProcInfo = Tuple[subprocess.Popen, int]  # (process, run_idx)


def launch_one(
    *,
    gpu: int,
    run_idx: int,
    config_path: str,
    device: str | None,
    tm: bool,
    ns: bool,
) -> subprocess.Popen:
    """Start one ``main.py`` run restricted to a single GPU and return its process.

    The child inherits this process's environment, with ``CUDA_VISIBLE_DEVICES``
    overridden to ``gpu``. It also inherits this process's stdout/stderr.

    Args:
        gpu: GPU index exported to the child as ``CUDA_VISIBLE_DEVICES``.
        run_idx: Index of the expanded sweep run, forwarded as ``--run_idx``.
        config_path: Config file path, forwarded as ``--config``.
        device: Optional value forwarded as ``--device``; omitted when ``None``.
        tm: If true, ``--tm`` is passed to the child.
        ns: If true, ``--ns`` is passed to the child.

    Returns:
        The started ``subprocess.Popen``; the caller is responsible for polling it.
    """
    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = str(gpu)

    # Use the interpreter running this launcher rather than whatever "python"
    # resolves to on PATH, so children run in the same (virtual) environment.
    # NOTE: "main.py" is resolved relative to the current working directory.
    cmd: List[str] = [
        sys.executable,
        "main.py",
        "--config",
        config_path,
        "--run_idx",
        str(run_idx),
    ]
    if device is not None:
        cmd += ["--device", device]
    if tm:
        cmd.append("--tm")
    if ns:
        cmd.append("--ns")

    # Inherit stdout/stderr so logs appear in real time. If you want per-run
    # log files, pass stdout=/stderr= file handles to Popen instead.
    # shlex.join quotes arguments, so the logged command is unambiguous.
    print(f"[start] gpu={gpu} run_idx={run_idx} :: {shlex.join(cmd)}", flush=True)
    return subprocess.Popen(cmd, env=env)


def _terminate_all(running: Dict[int, ProcInfo], gpus: List[int], grace_s: float = 10.0) -> None:
    """Terminate every tracked child, wait for it to exit, then clear ``running``.

    Each child first gets ``Popen.terminate()``. Any child still alive after a
    shared ``grace_s``-second deadline is force-killed with ``Popen.kill()``.
    This ensures no run is left behind holding a GPU.

    Args:
        running: slot index -> (process, run_idx); emptied on return.
        gpus: GPU index for each slot, used only for the log line.
        grace_s: Seconds to wait for graceful exit before killing.
    """
    for slot, (proc, ridx) in running.items():
        print(f"[kill] gpu={gpus[slot]} run_idx={ridx}", flush=True)
        proc.terminate()
    deadline = time.monotonic() + grace_s
    for proc, _ in running.values():
        try:
            proc.wait(timeout=max(0.0, deadline - time.monotonic()))
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()
    running.clear()


def main():
    """Run every expanded sweep run from ``--config`` across a pool of GPU slots.

    Each entry of ``--gpus`` is one concurrent slot. A slot runs one ``main.py``
    child at a time (started via ``launch_one``) and is refilled from the pending
    queue after its child exits. Repeating an index (``--gpus 0,0``) gives that
    GPU several concurrent slots.

    Exits via ``SystemExit(1)`` as soon as a run fails when ``--stop_on_fail`` is
    set; remaining children are terminated first. Otherwise, after all runs
    finish, exits via ``SystemExit(2)`` if any run failed, or returns normally.

    Raises:
        ValueError: if ``--gpus`` contains no GPU indices.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", required=True)
    ap.add_argument(
        "--gpus",
        default="0,1",
        help="Comma-separated GPU indices, e.g. 0,1 (repeat an index, e.g. 0,0, to run several jobs on one GPU)",
    )
    ap.add_argument("--poll_s", type=float, default=1.0, help="Polling interval in seconds")
    ap.add_argument("--tm", action="store_true", default=False)
    ap.add_argument("--ns", action="store_true", default=False)
    ap.add_argument("--device", type=str, default=None)
    ap.add_argument("--stop_on_fail", action="store_true", default=False)
    args = ap.parse_args()

    gpus = [int(x) for x in args.gpus.split(",") if x.strip() != ""]
    if not gpus:
        raise ValueError("No GPUs provided. Example: --gpus 0,1")

    # Use a context manager so the config file handle is closed promptly.
    with open(args.config, "r", encoding="utf-8") as fh:
        base_cfg = yaml.safe_load(fh)
    runs = assemble_runs(base_cfg)
    num_runs = len(runs)

    # Runs are dispatched in index order; deque gives O(1) popleft.
    pending = deque(range(num_runs))
    # Keyed by slot (position in `gpus`), not by GPU id. Keying by GPU id let a
    # repeated index overwrite the entry of a live child, losing track of it.
    running: Dict[int, ProcInfo] = {}  # slot -> (proc, run_idx)
    finished_ok: List[int] = []
    finished_fail: List[Tuple[int, int]] = []  # (run_idx, exit_code)

    def start(slot: int) -> None:
        """Pop the next pending run and launch it on the GPU of ``slot``."""
        ridx = pending.popleft()
        proc = launch_one(
            gpu=gpus[slot],
            run_idx=ridx,
            config_path=args.config,
            device=args.device,
            tm=args.tm,
            ns=args.ns,
        )
        running[slot] = (proc, ridx)

    print(f"Expanded runs: {num_runs}")
    print(f"GPU slots: {gpus}")

    try:
        # Initial fill: one run per slot, as far as runs are available.
        for slot in range(len(gpus)):
            if not pending:
                break
            start(slot)

        # Main loop: reap finished children, refill freed slots, then wait.
        while pending or running:
            freed: List[int] = []

            for slot, (proc, ridx) in list(running.items()):
                ret = proc.poll()
                if ret is None:
                    continue

                # Process finished
                freed.append(slot)
                del running[slot]
                gpu = gpus[slot]

                if ret == 0:
                    print(f"[done]  gpu={gpu} run_idx={ridx} exit=0", flush=True)
                    finished_ok.append(ridx)
                else:
                    print(f"[fail] gpu={gpu} run_idx={ridx} exit={ret}", flush=True)
                    finished_fail.append((ridx, ret))
                    if args.stop_on_fail:
                        print("Stopping because --stop_on_fail was set.", flush=True)
                        _terminate_all(running, gpus)
                        raise SystemExit(1)

            for slot in freed:
                if not pending:
                    break
                start(slot)

            # No need to sleep once nothing is left in flight.
            if running:
                time.sleep(args.poll_s)
    finally:
        # On Ctrl-C or any unexpected error, don't leave orphaned children behind.
        # This is a no-op after normal completion or --stop_on_fail (running is empty).
        _terminate_all(running, gpus)

    # Summary
    finished_ok.sort()
    finished_fail.sort(key=lambda x: x[0])
    print("\nAll jobs completed.")
    print(f"OK:   {len(finished_ok)} -> {finished_ok}")
    if finished_fail:
        print(f"FAIL: {len(finished_fail)} -> {finished_fail}")
        raise SystemExit(2)


if __name__ == "__main__":
    # Script entry point. main() returns None on success, so this exits with
    # status 0; its own SystemExit(1)/SystemExit(2) propagate unchanged.
    raise SystemExit(main())
