"""
Tiny smoke-run for baseline forward+backward timings. Runs a 1x1x1 grid on CPU.

Run as a script, it prints the timings to stdout and does not save them itself.
``tiny_config`` still sets ``out_dir="outputs/fast_times/test"``;
NOTE(review): confirm whether ``run_grid`` writes anything there.
"""

from __future__ import annotations

import sys
from pathlib import Path

# Files/directories whose presence marks the repository root: the first
# directory (walking upward) that contains any of them is treated as the root.
_ROOT_MARKERS = ("scripts", "pyproject.toml", ".git")

# Allow running as a file or pasted into a notebook: add repo root to sys.path.
# Skipped when this module is imported as part of a package (``python -m ...``),
# where the package machinery already makes ``scripts`` importable.
if not __package__:
    try:
        # Start from the directory holding this file (not the file itself) so
        # that the fallback below is always a directory.
        base = Path(__file__).resolve().parent
    except NameError:
        # No __file__ (e.g. notebook / interactive session): use the cwd.
        base = Path.cwd()
    ROOT = None
    for cand in (base, *base.parents):
        if any((cand / marker).exists() for marker in _ROOT_MARKERS):
            ROOT = cand
            break
    # No marker found anywhere upward: fall back to the starting directory.
    ROOT = ROOT or base
    if str(ROOT) not in sys.path:
        sys.path.insert(0, str(ROOT))

import os

# Allow CPU import path for fast_times config (which normally requires CUDA).
# Set before the ``scripts.fast_times`` imports that follow, presumably because
# the config checks it at import time; setdefault keeps any value the caller
# already exported.
os.environ.setdefault("FAST_TIMES_CPU_OK", "1")

from scripts.fast_times.run_backward_baselines import (
    BackwardConfig,
    BackwardGrid,
    build_baseline_models,
    run_grid,
)
from scripts.fast_times.config import ModelDims, Runtime


def tiny_config() -> BackwardConfig:
    """Build the smallest backward-timing config: a single grid point on CPU.

    Grid: B=1, Nc=16, Nt=4, timed with ``num_runs=1``.
    Model: dx=dy=1, d_model=32, 2 heads, one encoder and one decoder layer,
    d_ff=64. Runtime: device "cpu" with dtype "auto".
    Results are labelled ``backward_baselines_tiny`` with
    ``out_dir="outputs/fast_times/test"``.
    """
    grid = BackwardGrid(
        B_values=(1,),
        Nc_values=(16,),
        Nt_values=(4,),
        num_runs=1,
    )
    dims = ModelDims(
        dx=1,
        dy=1,
        d_model=32,
        n_heads=2,
        n_layers_enc=1,
        n_layers_dec=1,
        d_ff=64,
    )
    runtime = Runtime(device="cpu", dtype="auto")
    return BackwardConfig(
        grid=grid,
        dims=dims,
        runtime=runtime,
        out_dir="outputs/fast_times/test",
        name="backward_baselines_tiny",
    )


if __name__ == "__main__":
    # Build the three baseline models for the tiny CPU config, time them over
    # its grid, and print the per-point timings. This script has no explicit
    # save step of its own.
    cfg = tiny_config()
    tnpd, tnpa, tnpnd = build_baseline_models(cfg)
    results = run_grid({"TNPD": tnpd, "TNPA": tnpa, "TNP-ND": tnpnd}, cfg)

    print("\nTiny CPU forward+backward timings (prints only):")
    for label, table in results.items():
        print(f"\n[{label}]")
        # Columns of the per-model results are walked in lockstep, one row per
        # grid point.
        rows = zip(table["B"], table["Nc"], table["Nt"], table["mean_fwd_bwd"])
        for b, nc, nt, mean_s in rows:
            print(f"  B={b:>2} Nc={nc:>4} Nt={nt:>3}  fwd+bwd={mean_s:.6f}s")
