from __future__ import annotations

import argparse
import glob
import hashlib
import os
import shlex
import subprocess
import time
from collections import deque
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import yaml

from phijax.experiments.sweeps import assemble_runs


@dataclass
class RunRec:
    """Bookkeeping for one job that has been launched in a tmux window.

    The scheduler polls ``ok_path`` / ``fail_path`` to decide when the GPU
    slot held by this record can be handed to the next job.
    """

    cfg_path: str  # config file the job was built from
    run_idx: int  # index of the run within that config
    gpu: int  # GPU slot occupied by the job
    win_name: str  # tmux window the job runs in
    ok_path: str  # status file whose appearance means success
    fail_path: str  # status file holding the exit code on failure


@dataclass(frozen=True)
class Job:
    """An immutable, hashable unit of work: one run of one config file."""

    cfg_path: str  # config file to pass to the worker
    run_idx: int  # which of the config's runs to execute


def require_tmux() -> str:
    """Return the name of the tmux session this process is running in.

    Raises:
        SystemExit: if the ``TMUX`` environment variable is unset or empty,
            i.e. the script was not started from inside tmux.
    """
    inside_tmux = bool(os.environ.get("TMUX"))
    if not inside_tmux:
        message = "Error: not inside tmux (TMUX env var not set). Run this from within a tmux session."
        raise SystemExit(message)

    # "#S" is the tmux format for the current session's name.
    raw_name = subprocess.check_output(
        ["tmux", "display-message", "-p", "#S"],
        text=True,
    )
    return raw_name.strip()


def tmux_window_exists(session: str, win_name: str) -> bool:
    """Return True if *session* currently has a window named exactly *win_name*."""
    listing = subprocess.check_output(
        ["tmux", "list-windows", "-t", session, "-F", "#{window_name}"],
        text=True,
    )
    # One window name per output line; blank lines are ignored.
    for raw in listing.splitlines():
        name = raw.strip()
        if name and name == win_name:
            return True
    return False


def tmux_new_window(session: str, win_name: str) -> None:
    """Create a detached window named *win_name* in *session*.

    The target is written as ``"<session>:"``. The trailing colon leaves the
    window part empty, which tmux's new-window treats as "next unused index
    in that session". A bare ``-t <session>`` is a target-window that tmux
    first looks up as a window of the current session. That breaks for
    tmux's default numeric session names: ``-t 0`` means window index 0,
    which is already in use.
    """
    subprocess.check_call(["tmux", "new-window", "-d", "-t", f"{session}:", "-n", win_name])


def tmux_send(session: str, win_name: str, cmd: str) -> None:
    """Type *cmd* into the window ``session:win_name`` and press Enter (``C-m``)."""
    target = f"{session}:{win_name}"
    argv = ["tmux", "send-keys", "-t", target, cmd, "C-m"]
    subprocess.check_call(argv)


def sanitize_window_name(s: str) -> str:
    """Return a tmux-safe window name derived from *s*, at most 60 characters.

    Alphanumerics, ``_`` and ``-`` are kept; every other character becomes
    ``_``. This includes ``.`` and ``:``. tmux parses targets as
    ``session:window.pane``, so a period inside a window name is read as a
    pane separator by e.g. ``send-keys -t session:name``.
    """
    return "".join(ch if ch.isalnum() or ch in "_-" else "_" for ch in s)[:60]


def expand_config_paths(items: List[str]) -> List[str]:
    """
    Accepts:
      - file paths
      - directories (recursively picks *.yml, *.yaml)
      - glob patterns
    Returns sorted unique list of files.
    """
    files: List[str] = []
    for it in items:
        it = os.path.expanduser(it)

        # Glob pattern
        matches = glob.glob(it)
        if matches and any(os.path.isfile(m) for m in matches):
            files.extend([m for m in matches if os.path.isfile(m)])
            continue

        # Directory
        if os.path.isdir(it):
            for ext in ("*.yml", "*.yaml"):
                files.extend(glob.glob(os.path.join(it, "**", ext), recursive=True))
            continue

        # Plain file
        if os.path.isfile(it):
            files.append(it)
            continue

        raise FileNotFoundError(f"Config path/pattern not found: {it}")

    # normalize + unique + stable order
    uniq = sorted({os.path.abspath(f) for f in files})
    return uniq


def build_jobs(config_files: List[str]) -> List[Job]:
    """Expand every config file into one ``Job`` per run it defines.

    Each file is parsed with ``yaml.safe_load`` and passed to
    ``assemble_runs``. Only the number of runs returned is used here: a job
    is just ``(cfg_path, run_idx)``.

    Jobs are returned in ``config_files`` order, and by run index within
    each file.
    """
    jobs: List[Job] = []
    for cfg_path in config_files:
        # `with` closes the handle promptly instead of leaking it until GC.
        with open(cfg_path, "r", encoding="utf-8") as fh:
            cfg = yaml.safe_load(fh)
        n_runs = len(assemble_runs(cfg))
        jobs.extend(Job(cfg_path=cfg_path, run_idx=ridx) for ridx in range(n_runs))
    return jobs


def launch_in_tmux(
    *,
    session: str,
    win_name: str,
    gpu: int,
    job: Job,
    device: Optional[str],
    tm: bool,
    ns: bool,
    conda_env: str,
    status_dir: str,
) -> RunRec:
    """Start *job* on *gpu* in a new detached tmux window and return its record.

    The window activates *conda_env*, exports ``CUDA_VISIBLE_DEVICES=<gpu>``
    and runs ``python main_queue.py --config <cfg> --run_idx <idx>``. When
    that exits, the shell writes ``OK`` to ``ok_path`` (rc 0) or the exit code
    to ``fail_path``, then stays open for inspection.

    Args:
        session: tmux session to create the window in.
        win_name: preferred window name; ``_2``, ``_3``, ... is appended
            until the name is unused in *session*.
        gpu: value exported as ``CUDA_VISIBLE_DEVICES``.
        job: config path and run index to execute.
        device, tm, ns: forwarded as ``--device`` / ``--tm`` / ``--ns``.
        conda_env: environment passed to ``conda activate``.
        status_dir: directory for the status files (created if missing).

    Returns:
        A ``RunRec`` with the final window name and absolute status paths.
    """
    # The new window's shell starts in tmux's session directory, not
    # necessarily ours. Use absolute status paths and `cd` to our working
    # directory so relative `main_queue.py` and status files resolve
    # consistently.
    workdir = os.getcwd()
    status_dir = os.path.abspath(status_dir)
    os.makedirs(status_dir, exist_ok=True)

    # Status files are keyed by config stem, a short hash of the full config
    # path, and the run index. Configs in different directories can share a
    # stem and must not share status files.
    cfg_stem = Path(job.cfg_path).stem
    cfg_tag = hashlib.sha256(os.fsencode(os.path.abspath(job.cfg_path))).hexdigest()[:8]
    status_base = os.path.join(status_dir, f"{cfg_stem}__{cfg_tag}__run_{job.run_idx:05d}")
    ok_path = status_base + ".ok"
    fail_path = status_base + ".fail"

    # Remove leftovers from an earlier attempt so they aren't read as this run's result.
    for p in (ok_path, fail_path):
        if os.path.exists(p):
            os.remove(p)

    # Ensure unique window name
    base = win_name
    k = 1
    while tmux_window_exists(session, win_name):
        k += 1
        win_name = f"{base}_{k}"

    tmux_new_window(session, win_name)

    cmd: List[str] = ["python", "main_queue.py", "--config", job.cfg_path, "--run_idx", str(job.run_idx)]
    if device is not None:
        cmd += ["--device", device]
    if tm:
        cmd += ["--tm"]
    if ns:
        cmd += ["--ns"]

    # Quote the banner as one shell word. Embedding shlex.quote(path) inside
    # an outer single-quoted string lets paths containing quotes, spaces or
    # ';' escape the quoting.
    banner = f"[tmux-run] gpu={gpu} cfg={job.cfg_path} run_idx={job.run_idx} window={win_name}"
    shell = (
        f"cd {shlex.quote(workdir)}; "
        f"echo {shlex.quote(banner)}; "
        f"conda activate {shlex.quote(conda_env)}; "
        f"export CUDA_VISIBLE_DEVICES={gpu}; "
        f"{' '.join(shlex.quote(x) for x in cmd)}; "
        f"rc=$?; "
        f"if [ $rc -eq 0 ]; then "
        f"  echo OK > {shlex.quote(ok_path)}; "
        f"  echo '[tmux-run] done OK'; "
        f"else "
        f"  echo $rc > {shlex.quote(fail_path)}; "
        f"  echo '[tmux-run] done FAIL rc=' $rc; "
        f"fi; "
        f"echo '[tmux-run] leaving shell open'; "
        f"exec $SHELL"
    )

    tmux_send(session, win_name, shell)

    return RunRec(
        cfg_path=job.cfg_path,
        run_idx=job.run_idx,
        gpu=gpu,
        win_name=win_name,
        ok_path=ok_path,
        fail_path=fail_path,
    )


def main():
    """Run every run of the given configs across GPUs, one job per GPU, in tmux windows.

    Jobs are launched FIFO with ``launch_in_tmux``. Every ``--poll_s``
    seconds the status files (and the windows' existence) are checked so
    finished GPU slots can be refilled.

    Exit behavior:
        - ``SystemExit(2)`` as soon as a job fails when ``--stop_on_fail`` is
          set; jobs already running are left alone.
        - ``SystemExit(3)`` at the end if any job failed.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument(
        "--configs",
        nargs="+",
        required=True,
        help="One or more config files, directories, or glob patterns. Example: --configs config/runs/*.yaml config/more/",
    )
    ap.add_argument("--gpus", default="0,1", help="Comma-separated GPU indices, e.g. 0,1")
    ap.add_argument("--poll_s", type=float, default=2.0)
    ap.add_argument("--tm", action="store_true", default=False)
    ap.add_argument("--ns", action="store_true", default=False)
    ap.add_argument("--device", type=str, default=None)
    ap.add_argument("--conda_env", type=str, required=True, help="Conda env to activate in each tmux window")
    ap.add_argument("--status_dir", type=str, default=".queue_status")
    ap.add_argument("--stop_on_fail", action="store_true", default=False)
    args = ap.parse_args()

    session = require_tmux()

    gpus = [int(x) for x in args.gpus.split(",") if x.strip() != ""]
    if not gpus:
        raise ValueError("No GPUs provided. Example: --gpus 0,1")

    config_files = expand_config_paths(args.configs)
    jobs = build_jobs(config_files)

    # FIFO queue; deque makes taking the next job from the front O(1).
    pending: deque[Job] = deque(jobs)
    running: Dict[int, RunRec] = {}  # gpu -> RunRec
    ok: List[Tuple[str, int]] = []
    fail: List[Tuple[str, int, int]] = []  # (cfg_path, run_idx, rc)

    print(f"tmux session: {session}")
    print(f"GPU slots: {gpus}")
    print(f"configs ({len(config_files)}):")
    for p in config_files:
        print(f"  - {p}")
    print(f"Total pooled jobs: {len(jobs)}")
    print(f"status_dir: {args.status_dir}")
    print(f"conda_env: {args.conda_env}")

    def fill_free_gpus() -> None:
        # Give each idle GPU the next pending job, if any remain.
        for gpu in gpus:
            if gpu in running or not pending:
                continue

            job = pending.popleft()
            cfg_stem = Path(job.cfg_path).stem
            win = sanitize_window_name(f"{cfg_stem}_r{job.run_idx:03d}_gpu{gpu}")
            rec = launch_in_tmux(
                session=session,
                win_name=win,
                gpu=gpu,
                job=job,
                device=args.device,
                tm=args.tm,
                ns=args.ns,
                conda_env=args.conda_env,
                status_dir=args.status_dir,
            )
            running[gpu] = rec
            print(f"[start] gpu={gpu} cfg={cfg_stem} run_idx={job.run_idx} window={rec.win_name}", flush=True)

    def record_failure(gpu: int, rec: RunRec, rc: int, tag: str) -> None:
        # Shared bookkeeping for a job that failed or whose window vanished.
        fail.append((rec.cfg_path, rec.run_idx, rc))
        del running[gpu]
        print(f"{tag} gpu={gpu} cfg={Path(rec.cfg_path).stem} run_idx={rec.run_idx} rc={rc} window={rec.win_name}", flush=True)
        if args.stop_on_fail:
            print("Stopping because --stop_on_fail was set.", flush=True)
            raise SystemExit(2)

    fill_free_gpus()

    while pending or running:
        freed_any = False

        # Iterate over a snapshot because finished entries are deleted.
        for gpu, rec in list(running.items()):
            if os.path.exists(rec.ok_path):
                ok.append((rec.cfg_path, rec.run_idx))
                del running[gpu]
                freed_any = True
                print(f"[done]  gpu={gpu} cfg={Path(rec.cfg_path).stem} run_idx={rec.run_idx} OK window={rec.win_name}", flush=True)

            elif os.path.exists(rec.fail_path):
                with open(rec.fail_path, "r") as f:
                    rc_s = f.read().strip()
                # Unreadable/empty content falls back to a generic failure code.
                rc = int(rc_s) if rc_s.isdigit() else 1
                freed_any = True
                record_failure(gpu, rec, rc, "[fail]")

            # Window gone without a status file: it was closed/killed mid-run.
            elif not tmux_window_exists(session, rec.win_name):
                rc = 137  # conventionally 128 + SIGKILL
                try:
                    with open(rec.fail_path, "w") as f:
                        f.write(str(rc))
                except OSError:
                    pass  # best effort; the in-memory `fail` list is what's reported
                freed_any = True
                record_failure(gpu, rec, rc, "[killed]")

        if freed_any:
            fill_free_gpus()

        # Skip the sleep once nothing is left, so the script exits promptly.
        if pending or running:
            time.sleep(args.poll_s)

    ok.sort(key=lambda x: (x[0], x[1]))
    fail.sort(key=lambda x: (x[0], x[1]))

    print("\nAll jobs completed.")
    print(f"OK:   {len(ok)}")
    if fail:
        print(f"FAIL: {len(fail)}")
        for cfg_path, ridx, rc in fail[:50]:
            print(f"  - {Path(cfg_path).name} run_idx={ridx} rc={rc}")
        raise SystemExit(3)


if __name__ == "__main__":
    # Run the scheduler only when executed as a script, not on import.
    main()

    # Example invocation (must be started from inside a tmux session):
    #   python queue_tmux.py --configs config/runs/ --gpus 0,1,2,3 --poll_s 20 --conda_env /raid/work/abijuru/envs/jax
