#!/usr/bin/env python3
from __future__ import annotations

import argparse
import os
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
import queue
import re
import signal
import subprocess
import sys
import threading
from typing import Iterable


@dataclass(frozen=True)
class Job:
    """One pipeline invocation: a target file plus the log that captures its output."""

    # Path handed to the pipeline script via ``--only-file``.
    only_file: str
    # File that receives the echoed command line and the job's output.
    log_path: Path


def _repo_root() -> Path:
    return Path(__file__).resolve().parents[1]


def _sanitize_filename(s: str) -> str:
    s = s.strip().lstrip("./")
    s = s.replace(os.sep, "__")
    s = re.sub(r"[^A-Za-z0-9._-]+", "_", s)
    return s or "unknown"


def _now_run_id() -> str:
    return datetime.now().strftime("%Y%m%d_%H%M%S")


def _iter_unique(items: Iterable[str]) -> list[str]:
    seen: set[str] = set()
    out: list[str] = []
    for it in items:
        if it not in seen:
            seen.add(it)
            out.append(it)
    return out


def _print_prefixed(prefix: str, line: str) -> None:
    sys.stdout.write(f"[{prefix}] {line}")
    sys.stdout.flush()


def _run_one(
    job: Job,
    cmd: list[str],
    *,
    print_to_terminal: bool,
    stop_event: threading.Event,
    running_procs: "queue.Queue[subprocess.Popen[bytes]]",
) -> int:
    """Run ``cmd`` for ``job`` and stream its output into the job's log file.

    The log starts with a ``$ <command>`` header. The child runs with the
    directory returned by ``_repo_root()`` as its working directory and with
    stderr merged into stdout. Each output line is written to the log as raw
    bytes and flushed right away; when ``print_to_terminal`` is true it is
    also echoed to stdout, prefixed with the base name of ``job.only_file``.

    Args:
        job: The file being processed and the log path to write to.
        cmd: Full argument vector to execute.
        print_to_terminal: Whether to mirror output lines to stdout.
        stop_event: When set, streaming stops and the child is terminated.
        running_procs: Queue on which the started process is registered so
            that another party (the signal handler) can terminate it.

    Returns:
        The child's exit status, or 127 if the command could not be started.
    """
    job.log_path.parent.mkdir(parents=True, exist_ok=True)
    with job.log_path.open("wb") as log_fp:
        log_fp.write(f"$ {' '.join(cmd)}\n\n".encode("utf-8", errors="replace"))
        log_fp.flush()

        try:
            proc = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                cwd=str(_repo_root()),
            )
        except OSError as exc:
            # Missing or non-executable script, bad cwd, etc. Record the reason
            # and report a failing status instead of letting the exception
            # escape into the calling thread, where the job would silently
            # vanish from the results.
            message = f"failed to start command: {exc}\n"
            log_fp.write(message.encode("utf-8", errors="replace"))
            if print_to_terminal:
                _print_prefixed(Path(job.only_file).name, message)
            return 127

        running_procs.put(proc)
        assert proc.stdout is not None  # guaranteed by stdout=PIPE; narrows the type

        # A stop request may have been handled just before this process was
        # registered, in which case nobody terminated it. Do it here rather
        # than waiting for the child to print a line before noticing.
        if stop_event.is_set():
            try:
                proc.terminate()
            except OSError:
                pass

        try:
            for raw in iter(proc.stdout.readline, b""):
                if stop_event.is_set():
                    break
                log_fp.write(raw)
                log_fp.flush()
                if print_to_terminal:
                    _print_prefixed(Path(job.only_file).name, raw.decode("utf-8", errors="replace"))
        finally:
            # Best-effort: failing to close the pipe must not mask the exit status.
            try:
                proc.stdout.close()
            except (OSError, ValueError):
                pass

        if stop_event.is_set():
            # Best-effort: the process may already be gone.
            try:
                proc.terminate()
            except OSError:
                pass

        return proc.wait()


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the parallel pipeline runner."""
    parser = argparse.ArgumentParser(
        description=(
            "Run scripts/run_final_pipeline.sh on multiple Lean files in parallel, "
            "writing each job's live output to a fixed log directory."
        )
    )
    parser.add_argument(
        "--project",
        type=str,
        default=os.getenv("FORMAL_PROJECT", ""),
        help="Optional formalization project name (sets FORMAL_PROJECT for each job).",
    )
    parser.add_argument(
        "--only-file",
        action="append",
        default=[],
        help="Lean file path (repeatable). Example: Question_bench/FateH/93_new.lean",
    )
    parser.add_argument(
        "--jobs",
        type=int,
        default=4,
        help="Max parallel jobs (default: 4).",
    )
    parser.add_argument(
        "--log-dir",
        type=Path,
        default=None,
        help="Directory to write per-file logs (default: log/default/parallel_logs/final_pipeline).",
    )
    parser.add_argument(
        "--run-id",
        default="",
        help="Optional run id to include in log filenames (default: timestamp).",
    )
    parser.add_argument(
        "--no-terminal",
        action="store_true",
        help="Do not stream prefixed output to terminal (still writes logs).",
    )
    parser.add_argument(
        "--no-isolate-progress-history",
        action="store_true",
        help="Do not pass --job-name; allows sharing FINAL progress/history across jobs (not recommended).",
    )
    parser.add_argument(
        "--script",
        default=str(_repo_root() / "scripts" / "run_final_pipeline.sh"),
        help="Pipeline script to run (default: scripts/run_final_pipeline.sh).",
    )
    parser.add_argument(
        "extra_args",
        nargs=argparse.REMAINDER,
        help="Extra args passed to run_final_pipeline.sh (prefix with `--`).",
    )
    return parser


def main() -> int:
    """Run the pipeline script once per ``--only-file``, several at a time.

    Each distinct file becomes a job with its own log file named
    ``<sanitized file>__<run id>.log`` inside the log directory. Up to
    ``--jobs`` worker threads pull jobs from a queue and run them. SIGINT and
    SIGTERM set a stop flag and terminate every process registered so far;
    workers then stop picking up new jobs.

    Returns:
        0 when every job ran and exited with status 0, 1 when any job failed
        or was never run, and 2 when no ``--only-file`` was given.
    """
    args = _build_arg_parser().parse_args()
    only_files = _iter_unique(args.only_file)
    if not only_files:
        print("No --only-file provided.", file=sys.stderr)
        return 2

    jobs = max(1, args.jobs)
    run_id = args.run_id.strip() or _now_run_id()
    project = (args.project or "").strip()
    if args.log_dir is None:
        base = _repo_root() / "log"
        if project:
            # Keep the project name usable as a single directory component.
            safe = re.sub(r"[^A-Za-z0-9._-]+", "_", project).strip("_") or "project"
            log_dir = base / safe / "parallel_logs" / "final_pipeline"
        else:
            log_dir = base / "default" / "parallel_logs" / "final_pipeline"
    else:
        log_dir = args.log_dir.expanduser().resolve()
    log_dir.mkdir(parents=True, exist_ok=True)

    # argparse.REMAINDER keeps the literal "--" separator; drop it.
    extra_args: list[str] = list(args.extra_args)
    if extra_args and extra_args[0] == "--":
        extra_args = extra_args[1:]

    stop_event = threading.Event()
    running_procs: "queue.Queue[subprocess.Popen[bytes]]" = queue.Queue()

    def _handle_signal(_signum: int, _frame: object) -> None:
        """Request a stop and terminate every process registered so far."""
        stop_event.set()
        while True:
            try:
                p = running_procs.get_nowait()
            except queue.Empty:
                break
            # Deliberately broad: nothing may escape from a signal handler
            # into whatever the main thread happened to be executing.
            try:
                p.terminate()
            except Exception:
                pass

    signal.signal(signal.SIGINT, _handle_signal)
    signal.signal(signal.SIGTERM, _handle_signal)

    work_q: "queue.Queue[Job]" = queue.Queue()
    for f in only_files:
        log_name = f"{_sanitize_filename(f)}__{run_id}.log"
        work_q.put(Job(only_file=f, log_path=log_dir / log_name))

    results_lock = threading.Lock()
    results: dict[str, int] = {}

    def worker() -> None:
        """Pull jobs off the queue until it is empty or a stop is requested."""
        while not stop_event.is_set():
            try:
                job = work_q.get_nowait()
            except queue.Empty:
                return
            cmd = [args.script]
            if project:
                cmd += ["--project", project]
            if not args.no_isolate_progress_history:
                cmd += ["--job-name", _sanitize_filename(job.only_file)]
            cmd += ["--only-file", job.only_file, *extra_args]
            try:
                rc = _run_one(
                    job,
                    cmd,
                    print_to_terminal=not args.no_terminal,
                    stop_event=stop_event,
                    running_procs=running_procs,
                )
            except Exception as exc:
                # A job that could not be launched or logged must still show
                # up as failed, and must not kill this worker and strand the
                # rest of the queue.
                print(
                    f"[{Path(job.only_file).name}] job failed to run: {exc!r}",
                    file=sys.stderr,
                )
                rc = 1
            with results_lock:
                results[job.only_file] = rc
            work_q.task_done()

    threads = [threading.Thread(target=worker, daemon=True) for _ in range(min(jobs, len(only_files)))]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    failed = {f: rc for f, rc in results.items() if rc != 0}
    # Jobs without a result were never run (e.g. still queued when a stop was
    # requested); they must not be mistaken for successes.
    not_run = [f for f in only_files if f not in results]
    ok = len(results) - len(failed)
    print(f"\nLogs: {log_dir}")
    print(f"Done: {ok}/{len(only_files)} succeeded.")
    if failed:
        print("Failed:")
        for f, rc in sorted(failed.items()):
            print(f"  - {f}: exit {rc}")
    if not_run:
        print("Not run:")
        for f in not_run:
            print(f"  - {f}")
    return 1 if failed or not_run else 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
