from __future__ import annotations

import argparse
import os
import subprocess
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable


@dataclass(frozen=True)
class Experiment:
    """One training-script invocation managed by this runner.

    The dataclass is frozen, so attributes cannot be reassigned. ``extra_args``
    is still a plain ``list``, however, so its *contents* can be changed in place.
    Because the generated ``__hash__`` hashes every field, calling ``hash()`` on
    an instance raises ``TypeError`` (lists are unhashable).
    """

    # Short identifier; main() uses it as the per-experiment output sub-directory name.
    name: str
    # Python script launched for this experiment.
    script: Path
    # Command-line arguments placed after the script path when it is launched.
    extra_args: list[str]


def _stream_process(cmd: list[str], cwd: Path, log_path: Path) -> None:
    log_path.parent.mkdir(parents=True, exist_ok=True)
    print(f"\n[RUN] {' '.join(cmd)}")
    with log_path.open("w", encoding="utf-8", errors="replace") as log_file:
        proc = subprocess.Popen(
            cmd,
            cwd=str(cwd),
            env=os.environ.copy(),
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
        )
        assert proc.stdout is not None
        for line in proc.stdout:
            sys.stdout.write(line)
            log_file.write(line)
        rc = proc.wait()
        if rc != 0:
            raise RuntimeError(f"Command failed (exit={rc}): {' '.join(cmd)}")


def _limit_args(train_limit: int | None, test_limit: int | None) -> list[str]:
    """Return ``--train-limit`` / ``--test-limit`` flags for whichever limits are not None."""
    flags: list[str] = []
    if train_limit is not None:
        flags += ["--train-limit", str(train_limit)]
    if test_limit is not None:
        flags += ["--test-limit", str(test_limit)]
    return flags


def _experiment_suite(args: argparse.Namespace, repo_root: Path) -> list[Experiment]:
    """Build the ordered list of experiments selected by the parsed CLI options.

    ``args.suite`` selects the RNN experiments (``"rnn"``), the ConvRNN
    experiments (``"convrnn"``) or both (``"all"``). The ``include_*`` flags add
    optional experiments. Every experiment receives the shared flags (epochs,
    scan epochs, batch size, hidden size, seed) followed by its task-specific
    flags.

    Args:
        args: Namespace produced by the argument parser in ``main``.
        repo_root: Repository root; script paths are built relative to it.

    Returns:
        Experiments in execution order. Each one owns a freshly built
        ``extra_args`` list, so changing one never affects another.
    """
    common = [
        "--epochs", str(args.epochs),
        "--scan-epochs", str(args.scan_epochs),
        "--batch-size", str(args.batch_size),
        "--hidden", str(args.hidden),
        "--seed", str(args.seed),
    ]
    rnn_cls_dir = repo_root / "Compare_RNN" / "task" / "classification"
    rnn_lm_dir = repo_root / "Compare_RNN" / "task" / "lm"
    cnn_cls_dir = repo_root / "Compare_CNN" / "task" / "classification"

    exps: list[Experiment] = []

    if args.suite in {"rnn", "all"}:
        rnn_limits = _limit_args(args.rnn_train_limit, args.rnn_test_limit)
        # Flags shared by the character-level LM scripts. They are built once so
        # wikitext2 and PTB always receive identical settings, including the
        # optional --ptb-max-chars, without mutating an Experiment after creation.
        lm_args = [
            "--ptb-block-size", str(args.lm_block_size),
            "--ptb-steps-per-epoch", str(args.lm_steps_per_epoch),
            "--ptb-val-steps", str(args.lm_val_steps),
            "--eval-every", str(args.lm_eval_every),
        ]
        if args.lm_max_chars is not None:
            lm_args += ["--ptb-max-chars", str(args.lm_max_chars)]

        exps += [
            Experiment(
                name="rnn_row_mnist",
                script=rnn_cls_dir / "row_mnist.py",
                extra_args=common + rnn_limits,
            ),
            Experiment(
                name="rnn_row_cifar10",
                script=rnn_cls_dir / "row_cifar10.py",
                extra_args=common + rnn_limits,
            ),
            Experiment(
                name="rnn_wikitext2_char",
                script=rnn_lm_dir / "wikitext2_char.py",
                extra_args=common + lm_args,
            ),
        ]
        if args.include_ptb_char:
            exps.append(
                Experiment(
                    name="rnn_ptb_char",
                    script=rnn_lm_dir / "ptb_char.py",
                    extra_args=common + lm_args,
                )
            )
        if args.include_shd:
            exps.append(
                Experiment(
                    name="rnn_shd",
                    script=rnn_cls_dir / "shd.py",
                    extra_args=common
                    + _limit_args(args.shd_train_limit, args.shd_test_limit)
                    + ["--shd-time-bins", str(args.shd_time_bins)],
                )
            )

    if args.suite in {"convrnn", "all"}:
        # Pass --no-eprop unless E-Prop runs were requested with --include-eprop.
        cnn_common = common + ([] if args.include_eprop else ["--no-eprop"])
        cnn_limits = _limit_args(args.cnn_train_limit, args.cnn_test_limit)
        dvs_limits = _limit_args(args.dvs_train_limit, args.dvs_test_limit)
        exps += [
            Experiment(
                name="convrnn_fashion_mnist",
                script=cnn_cls_dir / "fashion_mnist.py",
                extra_args=cnn_common + cnn_limits,
            ),
            Experiment(
                name="convrnn_dvs_gesture",
                script=cnn_cls_dir / "dvs_gesture.py",
                extra_args=cnn_common
                + dvs_limits
                + ["--gesture-time-bins", str(args.dvs_time_bins)],
            ),
        ]
        if args.include_dvs_cifar10:
            exps.append(
                Experiment(
                    name="convrnn_dvs_cifar10",
                    script=cnn_cls_dir / "dvs_cifar10.py",
                    extra_args=cnn_common
                    + dvs_limits
                    + ["--dvs-time-bins", str(args.dvs_time_bins)],
                )
            )

    return exps


def main(argv: Iterable[str] | None = None) -> int:
    """Parse CLI options, build the selected experiment suite and run it.

    Each experiment is launched as ``<python> -u <script> <extra args>
    --plot-path <out_dir>/<name>/<tag>.png`` with the repository root as the
    working directory. Its output is saved alongside as
    ``<out_dir>/<name>/<tag>.log``.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``.

    Returns:
        Process exit status (0 when every selected command succeeded, or on a dry run).

    Raises:
        RuntimeError: propagated from ``_stream_process`` when a command exits non-zero.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--suite", choices=["rnn", "convrnn", "all"], default="all")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--scan-epochs", type=int, default=5)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--hidden", type=int, default=128)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--out-dir",
        type=str,
        default=str(Path("plots") / "runner"),
        help="Output directory for plots and logs; relative paths resolve against the current directory.",
    )
    parser.add_argument("--tag", type=str, default=None)
    parser.add_argument("--include-dvs-cifar10", action="store_true")
    parser.add_argument("--include-eprop", action="store_true", help="Include E-Prop runs for ConvRNN suite.")
    parser.add_argument("--include-ptb-char", action="store_true")
    parser.add_argument("--include-shd", action="store_true")

    parser.add_argument("--rnn-train-limit", type=int, default=None)
    parser.add_argument("--rnn-test-limit", type=int, default=None)

    parser.add_argument("--lm-block-size", type=int, default=80)
    parser.add_argument("--lm-steps-per-epoch", type=int, default=200)
    parser.add_argument("--lm-val-steps", type=int, default=60)
    parser.add_argument("--lm-max-chars", type=int, default=None)
    parser.add_argument("--lm-eval-every", type=int, default=5)

    parser.add_argument("--cnn-train-limit", type=int, default=None)
    parser.add_argument("--cnn-test-limit", type=int, default=None)
    parser.add_argument("--dvs-train-limit", type=int, default=None)
    parser.add_argument("--dvs-test-limit", type=int, default=None)
    parser.add_argument("--dvs-time-bins", type=int, default=10)
    parser.add_argument("--shd-train-limit", type=int, default=2000)
    parser.add_argument("--shd-test-limit", type=int, default=1000)
    parser.add_argument("--shd-time-bins", type=int, default=100)

    parser.add_argument("--dry-run", action="store_true")
    args = parser.parse_args(list(argv) if argv is not None else None)

    repo_root = Path(__file__).resolve().parents[1]
    # Children run with cwd=repo_root, but logs are written from this process's
    # cwd. Make the path absolute so --plot-path and the log agree on a directory.
    out_dir = Path(args.out_dir).resolve()
    tag = args.tag or time.strftime("%Y%m%d_%H%M%S")

    experiments = _experiment_suite(args, repo_root)
    if not experiments:
        print("[INFO] No experiments selected.")
        return 0

    for exp in experiments:
        plot_dir = out_dir / exp.name
        plot_path = plot_dir / f"{tag}.png"
        log_path = plot_dir / f"{tag}.log"
        cmd = [sys.executable, "-u", str(exp.script), *exp.extra_args, "--plot-path", str(plot_path)]
        if args.dry_run:
            print(f"[DRY] {exp.name}: {' '.join(cmd)}")
            continue
        _stream_process(cmd, cwd=repo_root, log_path=log_path)

    if args.dry_run:
        print(f"\n[DRY] {len(experiments)} command(s) listed; nothing was run.")
    else:
        print(f"\n[DONE] Plots/logs saved under: {out_dir}")
    return 0


if __name__ == "__main__":
    # Use main()'s integer return value as the process exit status.
    sys.exit(main())
