from __future__ import annotations

import argparse
import json
import os
import subprocess
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable


@dataclass(frozen=True)
class Job:
    """A single experiment run: a label, the script to launch, and its CLI arguments.

    ``frozen=True`` prevents re-assigning the fields after construction. The
    ``args`` list itself remains mutable, and because it is a list the
    generated ``__hash__`` cannot hash an instance.
    """

    # Short identifier; matched against ``--only`` and embedded in the log file name.
    name: str
    # Path of the Python script that is executed for this job.
    script: Path
    # Command-line arguments appended after the script path, already as strings.
    args: list[str]


def _timestamp() -> str:
    """Return the current local time formatted as ``YYYYMMDD_HHMMSS``."""
    now = time.localtime()
    return time.strftime("%Y%m%d_%H%M%S", now)


def _repo_root() -> Path:
    """Return the parent of the directory that contains this (resolved) file."""
    this_file = Path(__file__).resolve()
    # parents[0] is the containing directory; parents[1] is one level above it.
    return this_file.parents[1]


def _build_jobs(repo_root: Path) -> list[Job]:
    """Assemble the fixed, ordered list of experiment jobs.

    Every job receives, in this order: ``--epochs``, ``--scan-epochs 5``,
    ``--batch-size``, optionally ``--hidden``, optionally the encoder trio
    (``--enc-channels``, ``--steps 12``, ``--kernel-size 3``), ``--seed``,
    any job-specific extra flags, and finally ``--plot-path``.

    Args:
        repo_root: Directory under which the job scripts are located.

    Returns:
        A new list of ``Job`` objects; each job owns its own ``args`` list.
    """
    seed = "42"

    # Script directories shared by several jobs.
    rnn_tasks = repo_root / "Compare_RNN" / "task"
    cnn_cls = repo_root / "Compare_CNN" / "task" / "classification"
    eprop = repo_root / "run_CNN_eprop"

    # Flag groups that are repeated verbatim across jobs.
    permute_flags = ("--permute-seed", "1234")
    dvs_cifar_flags = ("--dvs-root", "data/dvs_cifar10", "--dvs-time-bins", "10")
    dvs_gesture_flags = ("--gesture-root", "data/dvs_gesture", "--gesture-time-bins", "20")

    def make(name, script, plot, *, epochs, batch="64", hidden=None, enc=None, extra=()):
        # Build a fresh argument list per job so that jobs never share state.
        argv = ["--epochs", epochs, "--scan-epochs", "5", "--batch-size", batch]
        if hidden is not None:
            argv.extend(["--hidden", hidden])
        if enc is not None:
            argv.extend(["--enc-channels", enc, "--steps", "12", "--kernel-size", "3"])
        argv.extend(["--seed", seed])
        argv.extend(extra)
        argv.extend(["--plot-path", plot])
        return Job(name=name, script=script, args=argv)

    return [
        # -------------------------
        # RNN (regression)
        # -------------------------
        make(
            "rnn_add",
            rnn_tasks / "regression" / "adding_task.py",
            "plots/rnn_adding_task",
            epochs="30",
            batch="128",
            hidden="128",
        ),
        make(
            "rnn_lorenz_image",
            rnn_tasks / "regression" / "lorenz_image.py",
            "plots/rnn_lorenz_image",
            epochs="30",
        ),
        # -------------------------
        # RNN (classification)
        # -------------------------
        make(
            "rnn_seq_mnist",
            rnn_tasks / "classification" / "row_mnist.py",
            "plots/rnn_row_mnist",
            epochs="50",
            hidden="128",
        ),
        make(
            "rnn_pixel_mnist",
            rnn_tasks / "classification" / "pixel_mnist.py",
            "plots/rnn_pixel_mnist",
            epochs="50",
            hidden="256",
        ),
        make(
            "rnn_seq_cifar",
            rnn_tasks / "classification" / "seq_cifar10.py",
            "plots/rnn_seq_cifar10",
            epochs="100",
            hidden="256",
        ),
        make(
            "rnn_uci",
            rnn_tasks / "classification" / "uci_har.py",
            "plots/rnn_uci_har",
            epochs="50",
            hidden="128",
            extra=("--har-root", "data/uci_har"),
        ),
        # -------------------------
        # RNN (language modeling)
        # -------------------------
        make(
            "rnn_ptb",
            rnn_tasks / "lm" / "ptb_char.py",
            "plots/rnn_ptb_char",
            epochs="50",
            hidden="128",
            extra=(
                "--ptb-path",
                "data/ptb",
                "--ptb-block-size",
                "80",
                "--ptb-steps-per-epoch",
                "300",
                "--ptb-val-steps",
                "60",
            ),
        ),
        make(
            "rnn_wikitext",
            rnn_tasks / "lm" / "wikitext2_char.py",
            "plots/rnn_wikitext2_char",
            epochs="50",
            hidden="128",
            extra=(
                "--wikitext-path",
                "data/wikitext2_raw",
                "--ptb-block-size",
                "80",
                "--ptb-steps-per-epoch",
                "200",
                "--ptb-val-steps",
                "60",
            ),
        ),
        # -------------------------
        # CNN / ConvRNN (compare)
        # -------------------------
        make(
            "cnn_permuted_mnist",
            cnn_cls / "permute_mnist.py",
            "plots/cnn_permute_mnist",
            epochs="50",
            hidden="128",
            enc="16,32",
            extra=permute_flags,
        ),
        make(
            "cnn_fashion_mnist",
            cnn_cls / "fashion_mnist.py",
            "plots/cnn_fashion_mnist",
            epochs="50",
            hidden="128",
            enc="16,32",
        ),
        make(
            "cnn_dvs_cifar10",
            cnn_cls / "dvs_cifar10.py",
            "plots/cnn_dvs_cifar10",
            epochs="50",
            hidden="256",
            enc="32,64",
            extra=dvs_cifar_flags,
        ),
        make(
            "cnn_dvs_gesture",
            cnn_cls / "dvs_gesture.py",
            "plots/cnn_dvs_gesture",
            epochs="100",
            hidden="256",
            enc="32,64",
            extra=dvs_gesture_flags,
        ),
        # -------------------------
        # CNN / ConvRNN (E-Prop only)
        # -------------------------
        make(
            "cnn_eprop_permuted_mnist",
            eprop / "permute_mnist.py",
            "plots/cnn_eprop_permute",
            epochs="50",
            hidden="128",
            enc="16,32",
            extra=permute_flags,
        ),
        make(
            "cnn_eprop_fashion_mnist",
            eprop / "fashion_mnist.py",
            "plots/cnn_eprop_fashion",
            epochs="50",
            hidden="128",
            enc="16,32",
        ),
        make(
            "cnn_eprop_dvs_cifar10",
            eprop / "dvs_cifar10.py",
            "plots/cnn_eprop_dvs_cifar10",
            epochs="50",
            hidden="256",
            enc="32,64",
            extra=dvs_cifar_flags,
        ),
        make(
            "cnn_eprop_dvs_gesture",
            eprop / "dvs_gesture.py",
            "plots/cnn_eprop_dvs_gesture",
            epochs="100",
            hidden="256",
            enc="32,64",
            extra=dvs_gesture_flags,
        ),
    ]


def _stream_process(cmd: list[str], cwd: Path, log_path: Path) -> int:
    """Run *cmd* and tee its combined stdout/stderr to the console and a log file.

    Args:
        cmd: Full argument vector; executed directly, without a shell.
        cwd: Working directory for the child process.
        log_path: Log file to (over)write; missing parent directories are created.

    Returns:
        The exit code of the child process.

    Raises:
        OSError: If the log file cannot be opened or the process cannot be started.
    """
    log_path.parent.mkdir(parents=True, exist_ok=True)
    print(f"\n[RUN] {' '.join(cmd)}")
    # Monotonic clock: the measured duration is immune to wall-clock adjustments.
    started = time.monotonic()
    with log_path.open("w", encoding="utf-8", errors="replace") as log_file:
        # Popen as a context manager closes the pipe and reaps the child on
        # every exit path, including exceptions raised while streaming.
        with subprocess.Popen(
            cmd,
            cwd=str(cwd),
            env=os.environ.copy(),
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            # A stray undecodable byte in a child's output must not abort the
            # whole batch; mirror the log file's lenient error handling.
            errors="replace",
        ) as proc:
            assert proc.stdout is not None
            try:
                for line in proc.stdout:
                    sys.stdout.write(line)
                    log_file.write(line)
                    # Flush per line so the log can be followed while the job runs.
                    log_file.flush()
            except BaseException:
                # Interrupted (e.g. Ctrl-C) or failed mid-stream: do not leave
                # the child running, and do not block forever waiting for it.
                proc.kill()
                raise
            rc = proc.wait()
    elapsed = time.monotonic() - started
    print(f"[DONE] exit={rc} elapsed_sec={elapsed:.1f} log={log_path}")
    return rc


def _select_jobs(jobs: list[Job], only: str | None) -> list[Job]:
    """Filter *jobs* down to the names listed in *only*.

    Args:
        jobs: All available jobs, in execution order.
        only: Comma-separated job names, or ``None``/empty for no filtering.

    Returns:
        *jobs* itself when no filter is given; otherwise a new list holding the
        requested jobs in their original order.

    Raises:
        SystemExit: If *only* names a job that does not exist.
    """
    if not only:
        return jobs

    # Normalise the request: trim whitespace and ignore empty entries.
    requested = set()
    for token in only.split(","):
        cleaned = token.strip()
        if cleaned:
            requested.add(cleaned)

    picked = []
    found = set()
    for job in jobs:
        if job.name in requested:
            picked.append(job)
            found.add(job.name)

    unknown = sorted(requested - found)
    if unknown:
        raise SystemExit(f"Unknown --only job(s): {', '.join(unknown)}")
    return picked


def main(argv: Iterable[str] | None = None) -> int:
    """Run the selected jobs one after another and record a JSON summary.

    The summary is written to ``<log-dir>/<tag>/summary.json`` before the first
    job, after every executed job, and once more at the end. With ``--dry-run``
    the commands are only printed, not executed.

    Args:
        argv: Command-line arguments; ``None`` lets argparse read ``sys.argv``.

    Returns:
        ``1`` if any job exited with a non-zero code, otherwise ``0``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--tag", type=str, default=_timestamp())
    cli.add_argument("--only", type=str, default=None, help="Comma-separated job names to run.")
    cli.add_argument("--stop-on-failure", action="store_true", help="Stop at first failure (default: continue).")
    cli.add_argument("--log-dir", type=str, default=str(Path("logs") / "tonight"))
    cli.add_argument("--dry-run", action="store_true")
    opts = cli.parse_args(None if argv is None else list(argv))

    root = _repo_root()
    selected = _select_jobs(_build_jobs(root), opts.only)

    run_dir = Path(opts.log_dir) / opts.tag
    run_dir.mkdir(parents=True, exist_ok=True)
    summary_file = run_dir / "summary.json"

    job_records: list[dict] = []
    summary = {
        "tag": opts.tag,
        "repo_root": str(root),
        "python": sys.version,
        "executable": sys.executable,
        "start_time": time.strftime("%Y-%m-%d %H:%M:%S"),
        "jobs": job_records,
    }

    def save_summary() -> None:
        # Rewrite the whole file each time so it always reflects current progress.
        summary_file.write_text(json.dumps(summary, indent=2, ensure_ascii=False), encoding="utf-8")

    save_summary()

    failed: list[dict] = []
    for number, job in enumerate(selected, start=1):
        # "-u" runs the child unbuffered so its output arrives line by line.
        command = [sys.executable, "-u", str(job.script)] + job.args
        log_file = run_dir / f"{number:02d}_{job.name}.log"

        entry = {
            "index": number,
            "name": job.name,
            "script": str(job.script),
            "args": job.args,
            "log": str(log_file),
            "start_ts": time.time(),
        }

        if opts.dry_run:
            print(f"[DRY] {job.name}: {' '.join(command)}")
            entry["exit_code"] = None
            entry["elapsed_sec"] = None
            job_records.append(entry)
            continue

        exit_code = _stream_process(command, cwd=root, log_path=log_file)
        entry["exit_code"] = exit_code
        entry["elapsed_sec"] = float(max(0.0, time.time() - float(entry["start_ts"])))
        job_records.append(entry)
        save_summary()

        if exit_code == 0:
            continue
        failed.append(entry)
        if opts.stop_on_failure:
            break

    summary["end_time"] = time.strftime("%Y-%m-%d %H:%M:%S")
    summary["failures"] = failed
    save_summary()

    if not failed:
        print(f"\n[OK] All jobs completed successfully (see {summary_file})")
        return 0
    print(f"\n[WARN] Completed with failures: {len(failed)} (see {summary_file})")
    return 1


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())
