from __future__ import annotations

"""
Sampling-time benchmarks with Triton CA fast-path wiring.

Methods:
- M2: TNP-D AR (re-encode)
- M3: Ours AR buffer (the only method this script currently benchmarks)
- M4: TNPA AR (masked re-encode)

Note: `main` requires Triton and raises RuntimeError when it is unavailable.
The Triton-wired adapter is intended to fall back to SDPA, with identical call
semantics, when its kernels are not compiled.
"""

import os
import sys
from pathlib import Path
from typing import Dict

import torch

# Allow running as a file or pasted into a notebook: add the repo root to
# sys.path so the absolute `scripts.fast_times...` imports below resolve.
if __package__ is None or __package__ == "":
    try:
        # Search upward from the directory that contains this file. The file
        # path itself can never be a repo root, and it is useless as a
        # sys.path entry if we fall back to it.
        base = Path(__file__).resolve().parent
    except NameError:
        # No __file__ (e.g. code pasted into a notebook): search from the CWD.
        base = Path.cwd()
    ROOT = None
    for cand in [base, *base.parents]:
        # A repo root is recognised by any of these markers.
        if (cand / "scripts").exists() or (cand / "pyproject.toml").exists() or (cand / ".git").exists():
            ROOT = cand
            break
    # No marker found: fall back to the starting directory.
    ROOT = ROOT or base
    if str(ROOT) not in sys.path:
        sys.path.insert(0, str(ROOT))

from scripts.fast_times.config import BenchConfig, DEFAULT
from scripts.fast_times.common import (
    configure_torch_env,
    _ensure_dir,
    _atomic_write_json,
    run_benchmark_grid,
    package_for_plot,
)
from scripts.fast_times.triton_m3_core import (
    ContextEmbedder,
    TargetEmbedder,
    BufferEmbedder,
    ContextEncoder,
    Decoder,
    GaussianHead,
    CompiledM3Adapter,
    HAS_TRITON,
    build_modules_for_m3,
)


def build_modules(cfg: BenchConfig):
    """Instantiate the M3 module set described by ``cfg``.

    Forwards the model dimensions from ``cfg.dims`` and the target device from
    ``cfg.runtime.device`` to ``build_modules_for_m3``. Returns that call's
    result unchanged; ``main`` unpacks it into six modules.
    """
    dims = cfg.dims
    arch = {
        "dx": dims.dx,
        "dy": dims.dy,
        "d_model": dims.d_model,
        "n_heads": dims.n_heads,
        "n_layers_enc": dims.n_layers_enc,
        "n_layers_dec": dims.n_layers_dec,
        "d_ff": dims.d_ff,
    }
    return build_modules_for_m3(device=cfg.runtime.device, **arch)


def main(cfg: BenchConfig = DEFAULT) -> None:
    """Benchmark autoregressive sampling time of the Triton-backed M3 adapter.

    Builds the M3 modules from ``cfg`` and wraps them in
    ``CompiledM3Adapter``. It then runs ``run_benchmark_grid`` over the
    configured ``Nc`` x ``num_samples`` grid and writes the packaged results
    to ``<cfg.out_dir>/triton_sampling.json``.

    Args:
        cfg: Benchmark configuration; defaults to ``DEFAULT``.

    Raises:
        RuntimeError: If Triton is unavailable (``HAS_TRITON`` is false).
    """
    if not HAS_TRITON:
        raise RuntimeError("Triton is required for run_triton_sampling; please install Triton on a CUDA system.")
    configure_torch_env()

    # Debugging toggles mirrored from the notebook. ``setdefault`` only writes
    # a variable that is currently unset, so values exported by the caller win.
    # NOTE(review): torch is already imported at this point. TORCH_LOGS (and
    # possibly TORCHDYNAMO_VERBOSE) may be read at import time and therefore
    # have no effect when set here — confirm.
    debug_env = (
        ("TORCHINDUCTOR_CACHE_DIR", os.path.expanduser("~/.cache/torchinductor")),
        ("TORCHDYNAMO_VERBOSE", "1"),
        ("TORCH_LOGS", "+dynamo"),
    )
    for var_name, var_value in debug_env:
        os.environ.setdefault(var_name, var_value)

    emb_ctx, emb_tgt, emb_buf, head, ctx_enc, dec = build_modules(cfg)

    # Our method (M3): compiled adapter using the Triton CA path.
    adapter = CompiledM3Adapter(ctx_enc, dec, emb_ctx, emb_tgt, emb_buf, head)
    m3 = adapter.to(cfg.runtime.device)

    grid = cfg.grid
    nc_grid = tuple(grid.Nc_values)
    samples_grid = tuple(grid.num_samples_values)
    n_targets = grid.Nt
    n_runs = grid.num_runs
    dims = cfg.dims
    dx, dy = dims.dx, dims.dy

    print("\n=== TRITON M3 (AR buffer) ===")
    results: Dict[str, dict] = {
        "TR M3 Ours AR Buffer": run_benchmark_grid(
            m3,
            Nt=n_targets,
            num_runs=n_runs,
            method="autoregressive",
            Nc_values=nc_grid,
            num_samples_values=samples_grid,
            dx=dx,
            dy=dy,
        ),
    }

    packaged = package_for_plot(
        results,
        meta={"script": "triton_sampling", "config": cfg.to_dict()},
    )

    target_dir = cfg.out_dir
    # NOTE(review): the dummy "_" leaf suggests _ensure_dir creates the parent
    # directory of the path it is given — confirm against its definition.
    _ensure_dir(os.path.join(target_dir, "_"))
    json_path = os.path.join(target_dir, "triton_sampling.json")
    _atomic_write_json(json_path, packaged)
    print(f"Saved: {json_path}")


if __name__ == "__main__":
    # Script entry point: run the benchmark with the default configuration.
    main(DEFAULT)
