"""
Compiled sampling-time benchmark for M3 (AR buffer) using the Triton CA kernel.

This mirrors run_compiled_sampling.py but runs only the M3 method, and
patches the decoder with Triton (Lq==1, no attn_mask) while keeping the
embedders and K/V handling identical to the SDPA path, for an
apples-to-apples kernel comparison.
"""

# The docstring must precede this future import to become the module __doc__.
from __future__ import annotations

import os
import sys
import copy
from pathlib import Path
from typing import Dict

import torch

# Allow running as a file or pasted into a notebook: add the repo root to
# sys.path so the absolute ``scripts.fast_times`` imports below resolve.
if __package__ is None or __package__ == "":
    try:
        # Start from the script's directory (not the file itself) so the
        # fallback below is always a directory that can go on sys.path.
        base = Path(__file__).resolve().parent
    except NameError:
        # No __file__ (e.g. interactive / notebook session): use the cwd.
        base = Path.cwd()
    ROOT = None
    # Walk upwards until a directory that looks like the repo root is found.
    for cand in [base, *base.parents]:
        if (cand / "scripts").exists() or (cand / "pyproject.toml").exists() or (cand / ".git").exists():
            ROOT = cand
            break
    ROOT = ROOT or base
    if str(ROOT) not in sys.path:
        sys.path.insert(0, str(ROOT))

from scripts.fast_times.config import BenchConfig, DEFAULT
from scripts.fast_times.common import (
    configure_torch_env,
    _ensure_dir,
    _atomic_write_json,
    run_benchmark_grid,
    package_for_plot,
)
from scripts.fast_times.compiled_core import (
    ContextEmbedder, TargetEmbedder, BufferEmbedder,
    ContextEncoder, Decoder, GaussianHead, MvGaussianHead,
    CompiledFiveMethodAdapter, patch_decoder_with_triton,
)


def _resolve_dtype(device) -> torch.dtype:
    """Return fp16 for an available CUDA device, fp32 otherwise.

    ``device`` may be a string (``"cuda"``, ``"cuda:1"``, ``"cpu"``) or a
    ``torch.device``. Only the device *type* is compared, so indexed CUDA
    devices also get fp16.
    """
    device_type = str(device).split(":", 1)[0]
    if device_type == "cuda" and torch.cuda.is_available():
        return torch.float16
    return torch.float32


def _env_flag(name: str, default: bool) -> bool:
    """Read a boolean environment variable.

    Returns ``default`` when ``name`` is unset. Otherwise it returns False for
    ``0``/``false``/``no``/``off`` (case-insensitive, whitespace ignored) and
    True for anything else.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() not in {"0", "false", "no", "off"}


def build_modules(cfg: BenchConfig):
    """Instantiate the benchmark modules on ``cfg.runtime.device``.

    Every module is moved to the configured device and cast to fp16 on an
    available CUDA device (fp32 otherwise).

    Compilation:
    - The embedders and the Gaussian head are wrapped with ``torch.compile``
      unless ``FAST_TIMES_COMPILE_MODULES`` is falsy (0/false/no/off).
    - The decoder is always wrapped.
    - ``mv_head`` and ``ctx_enc`` are never compiled here.

    Returns:
        ``(emb_ctx, emb_tgt, emb_buf, head, mv_head, ctx_enc, dec)``
    """
    device = cfg.runtime.device
    dtype = _resolve_dtype(device)
    D = cfg.dims

    def _place(mod: torch.nn.Module) -> torch.nn.Module:
        return mod.to(device=device, dtype=dtype)

    emb_ctx = _place(ContextEmbedder(D.dx + D.dy, D.d_model))
    emb_tgt = _place(TargetEmbedder(D.dx, D.d_model))
    emb_buf = _place(BufferEmbedder(D.dx + D.dy, D.d_model))
    head = _place(GaussianHead(D.d_model, D.dy))
    mv_head = _place(MvGaussianHead(d_model=D.d_model, dy=D.dy, n_heads=D.n_heads, d_ff=D.d_ff,
                                    n_std_layers=2, prj_dim=8, bound_diag=True, min_diag=0.05))
    ctx_enc = _place(ContextEncoder(D.d_model, D.n_heads, D.n_layers_enc, D.d_ff))
    dec = _place(Decoder(D.d_model, D.n_heads, D.n_layers_dec, D.d_ff))

    def _compile_ok(mod: torch.nn.Module) -> torch.nn.Module:
        # Best-effort: fall back to the eager module if wrapping fails. Do it
        # loudly, because a silent fallback would skew the timing comparison.
        # torch.compile is lazy: tracing errors surface on the first forward
        # call, not here.
        try:
            return torch.compile(mod, mode="default", fullgraph=False, dynamic=True)
        except Exception as exc:
            print(
                f"[build_modules] torch.compile failed for {type(mod).__name__}; "
                f"using eager module ({exc!r})",
                file=sys.stderr,
            )
            return mod

    if _env_flag("FAST_TIMES_COMPILE_MODULES", True):
        emb_ctx = _compile_ok(emb_ctx)
        emb_tgt = _compile_ok(emb_tgt)
        emb_buf = _compile_ok(emb_buf)
        head = _compile_ok(head)
    # Always compile the decoder (surrounding FFN/LN); the Triton CA runs inside.
    # NOTE(review): callers may deep-copy and Triton-patch this compile-wrapped
    # decoder. Confirm the patch reaches the wrapped module before the first call.
    dec = _compile_ok(dec)
    return emb_ctx, emb_tgt, emb_buf, head, mv_head, ctx_enc, dec


def main(cfg: BenchConfig = DEFAULT) -> None:
    """Run the M3 (AR buffer) sampling benchmark with the Triton-patched decoder.

    Steps:
    1. Build the modules described by ``cfg``.
    2. Patch a deep copy of the decoder with ``patch_decoder_with_triton``.
    3. Time ``run_benchmark_grid`` over the configured Nc / num_samples grid.
    4. Write the packaged results to
       ``<cfg.out_dir>/compiled_sampling_m3_triton.json``.
    """
    configure_torch_env()
    emb_ctx, emb_tgt, emb_buf, head, _unused_mv_head, ctx_enc, base_dec = build_modules(cfg)

    # Patch a copy so the decoder returned by build_modules is left as built.
    triton_dec = copy.deepcopy(base_dec)
    patch_decoder_with_triton(triton_dec, verbose=False, force_triton=True)
    adapter = CompiledFiveMethodAdapter(
        "m3", ctx_enc, triton_dec, emb_ctx, emb_tgt, emb_buf, head
    )

    grid, dims = cfg.grid, cfg.dims
    grid_kwargs = dict(
        Nt=grid.Nt,
        num_runs=grid.num_runs,
        Nc_values=tuple(grid.Nc_values),
        num_samples_values=tuple(grid.num_samples_values),
        dx=dims.dx,
        dy=dims.dy,
    )

    print("\n=== TRITON M3 (AR buffer; compiled_core) ===")
    methods: Dict[str, dict] = {
        "TR M3 Ours AR Buffer": run_benchmark_grid(
            adapter, method="autoregressive", **grid_kwargs
        ),
    }

    payload = package_for_plot(
        methods,
        meta={"script": "compiled_sampling_m3_triton", "config": cfg.to_dict()},
    )

    out_dir = cfg.out_dir
    # NOTE(review): passes a dummy child path "_"; presumably _ensure_dir
    # creates the parent directory of the given path. Confirm in common.py.
    _ensure_dir(os.path.join(out_dir, "_"))
    out_path = os.path.join(out_dir, "compiled_sampling_m3_triton.json")
    _atomic_write_json(out_path, payload)
    print(f"Saved: {out_path}")


if __name__ == "__main__":
    # Script entry point: run the benchmark with the default configuration.
    main(cfg=DEFAULT)
