"""
ACE FlexAttention training-style benchmark (separate script).

- Initializes ACE (AmortizedConditioningEngine) with a MixtureGaussian head (2 components).
- Builds the ACE training mask (prefix | localized causal ctx+buf | chunked target-buffer), with
  the diagonal component disabled by default (no_diag=True).
- Times forward+backward (loss.backward) exactly like the other timing scripts.

Defaults: B=64,128,256; Nt=128,256,512; Nc=128,256,512,1024; runs=10.
"""

# NOTE: the module docstring must precede this line; a string placed after a
# ``from __future__`` import is just a discarded expression and leaves ``__doc__`` unset.
from __future__ import annotations

import os
import sys
import time
import math
import argparse
from pathlib import Path
from typing import Dict, List, Tuple, Any

import numpy as np
import torch
from torch.nn.attention.flex_attention import BlockMask, or_masks, create_block_mask

# Repo import setup. When this file runs as a plain script (no package context), walk up from
# the file until a directory containing ``scripts/``, ``pyproject.toml`` or ``.git`` is found,
# and put it on sys.path. This MUST run before any repo-local (``scripts.*`` / ``src.*``) import.
if __package__ is None or __package__ == "":
    try:
        base = Path(__file__).resolve()
    except NameError:  # e.g. interactive sessions without __file__
        base = Path.cwd()
    ROOT = None
    for cand in [base] + list(base.parents):
        if (cand / "scripts").exists() or (cand / "pyproject.toml").exists() or (cand / ".git").exists():
            ROOT = cand
            break
    ROOT = ROOT or base
    if str(ROOT) not in sys.path:
        sys.path.insert(0, str(ROOT))

from scripts.fast_times.common import (
    configure_torch_env,
    _ensure_dir,
    _atomic_write_json,
)
from src.models.ace import AmortizedConditioningEngine
from src.models.masks import (
    generate_training_mask_mod_runtime,  # repo training mask generator (uncompiled)
)
from src.models.modules.embedders import Embedder
from src.models.modules.backbones import Transformer
from src.models.modules.heads import NeuralProcessHead, MixtureGaussian
from src.utils import DataAttr


class SimpleGaussianHead(NeuralProcessHead):
    """Single diagonal-Gaussian prediction head.

    A two-layer MLP maps target embeddings ``zt`` to a mean and a log standard deviation for
    each of the ``dim_y`` output dimensions. (``build_ace`` in this script uses
    ``MixtureGaussian`` instead of this head.)

    Args:
        dim_model: Size of the last dimension of ``zt``.
        dim_feedforward: Hidden width of the MLP.
        dim_y: Number of output dimensions.
    """

    # log-sigma is clamped to this range (sigma in [exp(-7), exp(3)]) in BOTH forward() and
    # sample(), so scoring and sampling use the same predictive distribution and exp() cannot
    # overflow on extreme network outputs.
    _LOG_SIGMA_MIN: float = -7.0
    _LOG_SIGMA_MAX: float = 3.0

    def __init__(self, dim_model: int, dim_feedforward: int, dim_y: int = 1):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(dim_model, dim_feedforward),
            torch.nn.GELU(),
            torch.nn.Linear(dim_feedforward, 2 * dim_y),
        )
        self.dim_y = dim_y

    def _params(self, zt: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(mu, log_sigma)`` for ``zt``, with ``log_sigma`` clamped to the class bounds."""
        mu, log_sigma = torch.split(self.net(zt), self.dim_y, dim=-1)
        return mu, log_sigma.clamp(self._LOG_SIGMA_MIN, self._LOG_SIGMA_MAX)

    def forward(self, zt: torch.Tensor, yt: torch.Tensor | None = None, *, loss_mask=None, num_samples: int = 0):
        """Predict a Gaussian per target and, if ``yt`` is given, score it.

        ``loss_mask`` and ``num_samples`` are accepted for interface compatibility but are
        currently ignored.

        Returns:
            Dict with ``"mean"`` and ``"std"`` (last dim ``dim_y``), ``"log_likelihood"`` (the
            Gaussian log-density summed over the last dim) and ``"loss"`` (negative mean of that
            log-likelihood). The last two are ``None`` when ``yt`` is None.
        """
        mu, log_sigma = self._params(zt)
        sigma = torch.exp(log_sigma)
        if yt is None:
            return {"loss": None, "mean": mu, "std": sigma, "log_likelihood": None}
        # log N(y | mu, sigma) = -0.5 * (z^2 + 2*log(sigma) + log(2*pi)), z = (y - mu) / sigma
        ll = -0.5 * (((yt - mu) / sigma) ** 2 + 2 * log_sigma + math.log(2 * math.pi))
        ll = ll.sum(dim=-1)
        return {"loss": -ll.mean(), "mean": mu, "std": sigma, "log_likelihood": ll}

    def sample(self, zt: torch.Tensor, num_samples: int = 1) -> torch.Tensor:
        """Draw ``num_samples`` reparameterized samples per target.

        Returns a tensor shaped ``(zt.shape[0], zt.shape[1], num_samples, dim_y)``.
        """
        mu, log_sigma = self._params(zt)
        sigma = torch.exp(log_sigma)
        eps = torch.randn(zt.shape[0], zt.shape[1], num_samples, self.dim_y, device=zt.device, dtype=zt.dtype)
        return mu.unsqueeze(2) + sigma.unsqueeze(2) * eps


def build_training_mask_no_diag(total_len: int, *, context_len: int, buffer_len: int,
                                attending_chunks: int, q_block: int, kv_block: int,
                                device: str, no_diag: bool) -> BlockMask:
    """Build a FlexAttention ``BlockMask`` from the repo's ACE training mask.

    The element-wise rule comes from ``generate_training_mask_mod_runtime``. When ``no_diag``
    is true, entries with ``q_idx == kv_idx`` are additionally disallowed.

    Layout: [context | buffer | targets], total_len == context_len + buffer_len + target_len

    Args:
        total_len: Query and key/value sequence length of the mask.
        context_len, buffer_len, attending_chunks: Forwarded to the repo mask generator.
        q_block, kv_block: Block sizes used when materializing the block-sparse mask.
        device: Device on which the mask is created.
        no_diag: Whether to drop the diagonal on top of the training mask.

    Returns:
        A ``BlockMask`` shared across batch and heads (``B=None``, ``H=None``).
    """
    training_mod = generate_training_mask_mod_runtime(
        current_context_len=context_len,
        current_buffer_len=buffer_len,
        attending_chunks=attending_chunks,
    )

    mask_mod = training_mod
    if no_diag:
        def _drop_diagonal(b, h, q_idx, kv_idx):
            # Allowed by the training mask AND not a token attending to itself.
            return training_mod(b, h, q_idx, kv_idx) & (q_idx != kv_idx)

        _drop_diagonal.__name__ = f"train_nodiag_{context_len}_{buffer_len}_{attending_chunks}"
        mask_mod = _drop_diagonal

    block_size = (q_block, kv_block)
    # The upstream (uncompiled) create_block_mask avoids Dynamo/Inductor lowering of the mask.
    return create_block_mask(
        mask_mod,
        B=None,
        H=None,
        Q_LEN=total_len,
        KV_LEN=total_len,
        device=device,
        BLOCK_SIZE=block_size,
    )


def build_ace(dim_x: int, dim_y: int, d_model: int, n_heads: int, n_layers: int, d_ff: int, device: str,
              *, num_components: int = 2, max_buffer_size: int = 16, num_target_points: int = 512,
              targets_block_size_for_buffer_attend: int = 4) -> "AmortizedConditioningEngine":
    """Construct an ACE model with a Mixture-of-Gaussians head, in train mode, on ``device``.

    Args:
        dim_x, dim_y: Input / output feature dimensions.
        d_model: Embedding width shared by embedder, backbone and head.
        n_heads: Attention heads per backbone layer.
        n_layers: Number of backbone layers.
        d_ff: Feed-forward hidden width (also used as the embedder's hidden width).
        device: Device the modules are moved to.
        num_components: Mixture components of the ``MixtureGaussian`` head (default 2).
        max_buffer_size: Forwarded to the engine; presumably an upper bound on the buffer
            length (``main`` benchmarks with K=16) -- TODO confirm against the engine.
        num_target_points: Forwarded to the engine; presumably an upper bound on the target
            count (the default sweep tops out at Nt=512) -- TODO confirm against the engine.
        targets_block_size_for_buffer_attend: Forwarded to the engine unchanged.

    Returns:
        The ``AmortizedConditioningEngine`` in training mode.
    """
    emb = Embedder(dim_x, dim_y, hidden_dim=d_ff, out_dim=d_model, depth=2, pos_emb_init=False).to(device)
    backbone = Transformer(num_layers=n_layers, dim_model=d_model, num_head=n_heads, dim_feedforward=d_ff).to(device)
    head = MixtureGaussian(
        dim_y=dim_y, dim_model=d_model, dim_feedforward=d_ff, num_components=num_components,
    ).to(device)
    ace = AmortizedConditioningEngine(
        embedder=emb, backbone=backbone, head=head,
        max_buffer_size=max_buffer_size, num_target_points=num_target_points,
        targets_block_size_for_buffer_attend=targets_block_size_for_buffer_attend,
    ).to(device)
    ace.train()
    return ace


def time_fwd_bwd(ace: AmortizedConditioningEngine, Nc: int, B: int, *, T: int, runs: int, device: str,
                 d_model: int, dx: int, dy: int, use_amp: bool, K: int, bm: BlockMask) -> Tuple[np.ndarray, float, float]:
    """Time ``runs`` forward+backward passes of ``ace`` on one fixed random batch.

    A batch with ``Nc`` context, ``K`` buffer and ``T`` target points (``B`` tasks) is drawn once
    from a generator seeded with 0. After 2 untimed warmup passes, each timed pass is bracketed
    by CUDA events; gradient zeroing happens outside the timed region.

    Args:
        ace: Model to benchmark; called as ``ace.forward(batch, bm)`` and expected to return a
            dict with a ``"loss"`` entry.
        Nc, K, T: Number of context, buffer and target points.
        B: Batch size.
        runs: Number of timed iterations.
        device: CUDA device string for inputs and the RNG.
        d_model: Unused; kept for call-site compatibility.
        dx, dy: x / y feature dimensions.
        use_amp: If True, inputs are fp16, forward runs under fp16 autocast and the loss is
            scaled by a GradScaler before backward; otherwise everything is fp32.
        bm: FlexAttention block mask passed to the model.

    Returns:
        ``(times, mean, std)``: per-run time in seconds (float64 array), its mean and population
        standard deviation. If ``runs == 0`` the array is empty and mean/std are NaN.

    Raises:
        torch.cuda.OutOfMemoryError, RuntimeError: Propagated from the model's forward/backward.
    """
    g = torch.Generator(device=device).manual_seed(0)
    dtype = torch.float16 if use_amp else torch.float32

    def _randn(n: int, d: int) -> torch.Tensor:
        return torch.randn(B, n, d, device=device, dtype=dtype, generator=g)

    # Drawn in a fixed order (xc, yc, xb, yb, xt, yt) so the batch is reproducible.
    xc, yc = _randn(Nc, dx), _randn(Nc, dy)
    xb, yb = _randn(K, dx), _randn(K, dy)
    xt, yt = _randn(T, dx), _randn(T, dy)
    batch = DataAttr(xc=xc, yc=yc, xb=xb, yb=yb, xt=xt, yt=yt)

    # With enabled=False both autocast and GradScaler are pass-throughs (scale() returns the
    # loss unchanged), so one code path serves the AMP and the fp32 configuration.
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

    def _fwd_bwd() -> None:
        with torch.amp.autocast('cuda', dtype=torch.float16, enabled=use_amp):
            loss = ace.forward(batch, bm)["loss"]
        scaler.scale(loss).backward()
        # The forward's output dict is dropped on return instead of staying alive (as a
        # loop variable) while the next iteration allocates its activations.

    # Warmup (also triggers any lazy compilation).
    for _ in range(2):
        ace.zero_grad(set_to_none=True)
        _fwd_bwd()
        torch.cuda.synchronize()
        ace.zero_grad(set_to_none=True)
    torch.cuda.empty_cache()

    timings: List[float] = []
    for _ in range(runs):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        ace.zero_grad(set_to_none=True)
        start.record()
        _fwd_bwd()
        end.record()
        torch.cuda.synchronize()
        timings.append(start.elapsed_time(end) / 1000.0)  # elapsed_time is in ms
        # NOTE(review): emptying the allocator cache between runs makes every timed run pay for
        # fresh device allocations; presumably kept for parity with the other timing scripts.
        torch.cuda.empty_cache()

    t = np.array(timings, dtype=np.float64)
    if t.size == 0:
        # Avoid numpy "mean of empty slice" warnings; same NaN result.
        return t, float("nan"), float("nan")
    return t, float(t.mean()), float(t.std())


def _parse_int_list(s: str) -> Tuple[int, ...]:
    """Parse a comma-separated list of ints (e.g. ``"128,256"``); blank items are skipped."""
    return tuple(int(x) for x in s.split(",") if x.strip())


def _reset_compile_state() -> None:
    """Best-effort reset of Dynamo caches, then release unreferenced CUDA memory."""
    import gc

    try:
        import torch._dynamo as dynamo
        dynamo.reset()
    except Exception:  # best-effort: Dynamo may be unavailable in this torch build
        pass
    # Collect garbage first so tensors kept alive only by reference cycles are freed before
    # the caching allocator returns its free blocks to the driver.
    gc.collect()
    torch.cuda.empty_cache()


def _compile_flex_attention() -> None:
    """Install ``torch.compile(flex_attention, fullgraph=True)`` into the backbone module.

    The eager ``flex_attention`` is remembered on the module the first time this runs, so later
    calls compile that original function instead of re-wrapping the previously compiled
    wrapper (the per-Nc reset loop would otherwise nest one more wrapper per group).
    Failures are reported and do not abort the benchmark.
    """
    try:
        import src.models.modules.backbones as backbones

        eager = getattr(backbones, "_flex_attention_eager", None)
        if eager is None:
            eager = backbones.flex_attention
            backbones._flex_attention_eager = eager
        backbones.flex_attention = torch.compile(eager, fullgraph=True)
    except Exception as e:  # best-effort: keep whatever flex_attention is installed
        print(f"warning: could not compile flex_attention: {e}")


def main() -> None:
    """Run the (Nc, Nt, B) sweep and save per-configuration timings to JSON.

    For each context size Nc the model is rebuilt from scratch (Dynamo reset, fresh compile of
    ``flex_attention``). For each target count Nt one training BlockMask is built and reused for
    every batch size B. Configurations that raise a CUDA OOM or another ``RuntimeError`` are
    recorded with ``mean_time == std_time == -1.0`` and empty ``all_times`` instead of aborting.

    Raises:
        RuntimeError: If CUDA is not available.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for ACE flex mask benchmark.")
    configure_torch_env()

    ap = argparse.ArgumentParser(description="ACE FlexAttention fwd+bwd benchmark (MoG2 head)")
    ap.add_argument("--Nc", type=str, default="128,256,512,1024")
    ap.add_argument("--B", type=str, default="64,128,256")
    # The target-count axis is reported as "Nt"; accept --Nt as an alias for --T.
    ap.add_argument("--T", "--Nt", dest="T", type=str, default="128,256,512")
    ap.add_argument("--d_model", type=int, default=128)
    ap.add_argument("--n_heads", type=int, default=4)
    ap.add_argument("--n_layers", type=int, default=6)
    ap.add_argument("--d_ff", type=int, default=256)
    ap.add_argument("--q_block", type=int, default=128)
    ap.add_argument("--kv_block", type=int, default=128)
    ap.add_argument("--runs", type=int, default=10)
    ap.add_argument("--no_diag", dest="no_diag", action="store_true", help="Disallow diagonal (q==kv) in mask")
    ap.add_argument("--diag", dest="no_diag", action="store_false", help="Allow diagonal (q==kv) in mask")
    ap.add_argument("--no_amp", action="store_true", help="Disable autocast fp16 for timing")
    ap.add_argument("--out", type=str, default="outputs/fast_times/ace_flex_mask.json")
    ap.set_defaults(no_diag=True)
    args, unknown = ap.parse_known_args()
    if unknown:
        # Tolerated (e.g. launcher-injected flags) but surfaced so typos are not silently ignored.
        print(f"warning: ignoring unrecognized arguments: {' '.join(unknown)}")
    if args.runs < 1:
        ap.error("--runs must be >= 1")

    Nc_vals = _parse_int_list(args.Nc)
    B_vals = _parse_int_list(args.B)
    T_vals = _parse_int_list(args.T)
    device = "cuda"
    use_amp = not args.no_amp
    no_diag = bool(args.no_diag)
    K = 16  # buffer length; build_ace sizes the engine with max_buffer_size=16

    def _reset_and_build_model() -> AmortizedConditioningEngine:
        """Reset compile state, build a fresh ACE model and (re)compile flex_attention."""
        _reset_compile_state()
        model = build_ace(dim_x=1, dim_y=1, d_model=args.d_model, n_heads=args.n_heads,
                          n_layers=args.n_layers, d_ff=args.d_ff, device=device)
        _compile_flex_attention()
        return model

    # Label the series by the mask actually used, so a --diag run is not reported as "no-diag".
    label = "ACE (MoG2, no-diag train mask)" if no_diag else "ACE (MoG2, diag train mask)"
    series: Dict[str, List[Any]] = {
        "Nc": [], "B": [], "Nt": [], "mean_time": [], "std_time": [], "all_times": [],
    }
    methods: Dict[str, Dict[str, List[Any]]] = {label: series}

    def _record(Nc: int, B: int, T: int, mean_t: float, std_t: float, all_times: List[float]) -> None:
        """Append one configuration's result to the output series."""
        for key, value in (("Nc", Nc), ("B", B), ("Nt", T), ("mean_time", mean_t),
                           ("std_time", std_t), ("all_times", all_times)):
            series[key].append(value)

    total = len(Nc_vals) * len(B_vals) * len(T_vals)
    idx = 0
    ace = None
    for Nc in Nc_vals:
        # Full reset per Nc to avoid any stale compiled mask/artifacts. Drop the previous model
        # first so its GPU memory can be reclaimed before the new one is allocated.
        ace = None
        ace = _reset_and_build_model()
        for T in T_vals:
            # The mask depends only on (Nc, T): build it once and reuse it for every B.
            # One attending chunk per 32 targets (4/8/16 for Nt=128/256/512), at least 1.
            attending_chunks = max(1, T // 32)
            bm = build_training_mask_no_diag(
                Nc + K + T, context_len=Nc, buffer_len=K, attending_chunks=attending_chunks,
                q_block=args.q_block, kv_block=args.kv_block, device=device, no_diag=no_diag,
            )
            for B in B_vals:
                idx += 1
                print(f"[{idx}/{total}] Nc={Nc} Nt={T} B={B} (no_diag={no_diag})")
                try:
                    times, mean_t, std_t = time_fwd_bwd(
                        ace, Nc=Nc, B=B, T=T, runs=args.runs, device=device,
                        d_model=args.d_model, dx=1, dy=1, use_amp=use_amp, K=K, bm=bm,
                    )
                except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
                    print(f"   skip due to error: {str(e)[:120]}")
                    _record(Nc, B, T, -1.0, -1.0, [])
                    # Drop gradients left over from the failed step, then free cached blocks.
                    ace.zero_grad(set_to_none=True)
                    torch.cuda.empty_cache()
                else:
                    _record(Nc, B, T, mean_t, std_t, times.tolist())
                    print(f"   mean={mean_t:.6f}s ± {std_t:.6f}s")

    meta: Dict[str, Any] = {
        "script": "run_ace_flex_mask_bench",
        "config": {
            "Nc": list(Nc_vals), "B": list(B_vals), "Nt": list(T_vals),
            "d_model": args.d_model, "n_heads": args.n_heads, "n_layers": args.n_layers, "d_ff": args.d_ff,
            "q_block": args.q_block, "kv_block": args.kv_block, "runs": args.runs, "no_diag": no_diag,
            "amp": bool(use_amp)
        }
    }

    out = {"metadata": meta, "methods": methods}
    _ensure_dir(os.path.join(Path(args.out).parent, "_"))
    _atomic_write_json(args.out, out)
    print(f"Saved: {args.out}")


if __name__ == "__main__":
    # Script entry point. main() returns None, so the process exits with status 0 on success;
    # any exception raised by main() propagates exactly as before.
    sys.exit(main())
