from __future__ import annotations

"""
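Print forward+backward timings for the TNPD, TNPA and TNP-ND baselines on CPU.

Usage (the values shown are the defaults):
  uv run python -m scripts.fast_times.print_fwd_bwd_cpu \
      --B 32,64,128 --Nc 128,256,512 --Nt 64,128,256 --runs 3

By default an extra "OURS" entry (TNPD run with K additional context points,
K set by --ours-k, default 16) is also timed; pass --ours-k 0 to disable it.

No files are written; results are printed to stdout.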
"""

import os
import sys
import argparse
from pathlib import Path

# Allow running as a file or pasted into a notebook: add the repo root to
# sys.path so the absolute ``scripts.fast_times`` imports below resolve.
if __package__ is None or __package__ == "":
    try:
        _here = Path(__file__).resolve()
    except NameError:
        # No __file__ (interactive session): start the search from the CWD.
        _here = Path.cwd()
    # Search from a directory, never from the file itself: a file cannot
    # contain the markers, and a file path is useless as a sys.path entry.
    _start = _here if _here.is_dir() else _here.parent
    _markers = ("scripts", "pyproject.toml", ".git")
    ROOT = next(
        (
            cand
            for cand in (_start, *_start.parents)
            if any((cand / marker).exists() for marker in _markers)
        ),
        _start,  # fallback when no ancestor carries any marker
    )
    if str(ROOT) not in sys.path:
        sys.path.insert(0, str(ROOT))

# Force CPU OK for fast_times config import
os.environ.setdefault("FAST_TIMES_CPU_OK", "1")

from scripts.fast_times.run_backward_baselines import (
    BackwardConfig,
    BackwardGrid,
    build_baseline_models,
    run_grid,
)
from scripts.fast_times.config import ModelDims, Runtime


def _parse_int_list(s: str) -> tuple[int, ...]:
    s = s.strip()
    if not s:
        return tuple()
    return tuple(int(x) for x in s.split(","))


def _build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser (options, defaults and help texts)."""
    ap = argparse.ArgumentParser(description="Print forward+backward times on CPU for TNPD/TNPA/TNP-ND")
    ap.add_argument("--B", type=str, default="32,64,128", help="Comma-separated batch sizes")
    ap.add_argument("--Nc", type=str, default="128,256,512", help="Comma-separated context sizes")
    ap.add_argument("--Nt", type=str, default="64,128,256", help="Comma-separated target sizes")
    ap.add_argument("--runs", type=int, default=3, help="Timed iterations per config")
    ap.add_argument("--ours-k", type=int, default=16, help="Add an 'OURS (TNPD +K ctx)' entry by using Nc+K for TNPD")
    ap.add_argument("--d_model", type=int, default=32)
    ap.add_argument("--n_heads", type=int, default=2)
    ap.add_argument("--n_layers", type=int, default=1, help="Encoder/decoder layers")
    ap.add_argument("--d_ff", type=int, default=64)
    return ap


def _print_results(results: dict) -> None:
    """Print one ``[label]`` section per timed model, one line per grid config.

    NOTE(review): assumes each value in ``results`` maps "B", "Nc", "Nt" and
    "mean_fwd_bwd" to parallel sequences; ``zip`` silently stops at the
    shortest one. Times are printed with an "s" (seconds) suffix.
    """
    print("\nForward+Backward timings (CPU, print-only):")
    for label, d in results.items():
        print(f"\n[{label}]")
        rows = zip(d["B"], d["Nc"], d["Nt"], d["mean_fwd_bwd"])
        for B, Nc, Nt, mean in rows:
            print(f"  B={B:>4} Nc={Nc:>4} Nt={Nt:>4}  fwd+bwd={mean:.6f}s")


def main() -> None:
    """Parse CLI options, time forward+backward passes on CPU, print the results.

    Builds the TNPD / TNPA / TNP-ND baselines via ``build_baseline_models``,
    times the (B, Nc, Nt) grid with ``run_grid`` and, when ``--ours-k`` > 0,
    adds an extra TNPD entry via ``add_ours_tnpd``. Output goes to stdout.

    Raises:
        SystemExit: If a size list is empty, malformed or non-positive, or if
            ``--runs`` is less than 1.
    """
    ap = _build_parser()
    # Accept unknown args (e.g., Jupyter passes "-f <kernel.json>"), but say so,
    # so that typos in option names don't go unnoticed.
    args, unknown = ap.parse_known_args()
    if unknown:
        print(f"Ignoring unrecognized arguments: {' '.join(unknown)}", file=sys.stderr)

    try:
        B_vals = _parse_int_list(args.B)
        Nc_vals = _parse_int_list(args.Nc)
        Nt_vals = _parse_int_list(args.Nt)
    except ValueError as err:
        raise SystemExit(f"Invalid size list: {err}") from err
    if not (B_vals and Nc_vals and Nt_vals):
        raise SystemExit("Please provide non-empty B, Nc, Nt lists")
    if any(v <= 0 for v in (*B_vals, *Nc_vals, *Nt_vals)):
        raise SystemExit("B, Nc and Nt values must all be positive integers")
    if args.runs < 1:
        raise SystemExit("--runs must be at least 1")

    cfg = BackwardConfig(
        grid=BackwardGrid(B_values=B_vals, Nc_values=Nc_vals, Nt_values=Nt_vals, num_runs=args.runs),
        dims=ModelDims(dx=1, dy=1, d_model=args.d_model, n_heads=args.n_heads,
                       n_layers_enc=args.n_layers, n_layers_dec=args.n_layers, d_ff=args.d_ff),
        runtime=Runtime(device="cpu", dtype="auto"),
        out_dir="outputs/fast_times/test",
        name="print_only",
    )

    # Build models and run grid
    m_tnpd, m_tnpa, m_tnpnd = build_baseline_models(cfg)
    models = {"TNPD": m_tnpd, "TNPA": m_tnpa, "TNP-ND": m_tnpnd}
    results = run_grid(models, cfg)
    if args.ours_k > 0:
        # Imported only when the extra "OURS" entry is requested.
        from scripts.fast_times.run_backward_baselines import add_ours_tnpd
        add_ours_tnpd(results, models["TNPD"], cfg, K=args.ours_k)

    _print_results(results)


if __name__ == "__main__":
    # Entry point when executed directly, e.g.
    # ``python -m scripts.fast_times.print_fwd_bwd_cpu``.
    main()
