# speed_bench_mamba.py
import os, json, math, time
import torch
import torch.utils.benchmark as benchmark
from einops import rearrange, repeat

# Import the fused Mamba chunk-scan kernel ------------------------------
# Requires the `mamba_ssm` package; install it first if it is missing.
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined

# Rebind the imported name to a torch.compile-wrapped version; every call
# further down in this file goes through this compiled wrapper rather than
# the raw imported function.
mamba_chunk_scan_combined = torch.compile(mamba_chunk_scan_combined)


# -----------------------------------------------------------------------
def make_inputs(
    B,
    L,
    model_dim=512,
    v_dim=512,
    nheads=1,
    ngroups=1,
    device="cuda",
    dtype=torch.float16,
):
    """
    Produce random but shape-correct tensors for mamba_chunk_scan_combined.

    Args:
        B: batch size.
        L: sequence length.
        model_dim: size of the last axis of ``x``.
        v_dim: size of the last axis of ``Bm`` and ``Cm``.
        nheads: number of heads.
        ngroups: number of groups for ``Bm`` / ``Cm``.
        device: device the tensors are created on.
        dtype: dtype of every returned tensor.

    Returns:
        Tuple ``(x, dt, A, Bm, Cm, D)`` with shapes:
        - x:  [B, L, nheads, model_dim] - input tensor
        - dt: [B, L, nheads] - time delta parameters, strictly positive
        - A:  [nheads] - SSM state matrix diagonal, values in [-16, -1]
        - Bm: [B, L, ngroups, v_dim] - SSM input projection
        - Cm: [B, L, ngroups, v_dim] - SSM output projection
        - D:  [nheads] - skip connection parameter (all ones)
    """

    # branch X  ----------------------------------------------------------
    # [b, l, nh, hd]
    x = torch.randn(B, L, nheads, model_dim, device=device, dtype=dtype)

    # continuous time steps ---------------------------------------------
    # torch.rand is already in [0, 1), so only the small offset is needed
    # to keep every step strictly positive.
    dt = torch.rand(B, L, nheads, device=device, dtype=dtype) + 1e-3

    # SSM parameters -----------------------------------------------------
    # Sample log|A| uniformly in [log 1, log 16] and exponentiate exactly
    # once, so A lies in [-16, -1]. Exponentiating twice would produce
    # magnitudes up to e^16, which is beyond the float16 maximum (65504)
    # and turns A into -inf. Sampling is done in float32 and cast afterwards
    # so the random draw itself does not depend on low-precision support.
    a_log = torch.empty(nheads, device=device, dtype=torch.float32).uniform_(
        math.log(1), math.log(16)
    )
    A = (-torch.exp(a_log)).to(dtype)

    # [b, l, ng, v_dim]
    Bm = torch.randn(B, L, ngroups, v_dim, device=device, dtype=dtype)
    # [b, l, ng, v_dim]
    Cm = torch.randn_like(Bm)

    D = torch.ones(nheads, device=device, dtype=dtype)  # skip param
    return x, dt, A, Bm, Cm, D


# -----------------------------------------------------------------------
def bench_single(config, n_repeat=20, chunk_size=256):
    """
    Time mamba_chunk_scan_combined for a single input configuration.

    Args:
        config: keyword arguments forwarded to make_inputs; must contain
            the keys "B" and "L", which are also used for the token count.
        n_repeat: number of runs passed to the benchmark Timer's timeit().
        chunk_size: chunk_size argument forwarded to the kernel.

    Returns:
        Tuple (mean_time, tokens_per_second), where mean_time is the
        ``mean`` of the Timer measurement and tokens_per_second is
        B * L divided by that mean.
    """
    x, dt, A, Bm, Cm, D = make_inputs(**config)

    # Untimed first call, followed by a device sync, so that one-time
    # compilation work does not end up in the measurement.
    warmup_out = mamba_chunk_scan_combined(
        x, dt, A, Bm, Cm, chunk_size=chunk_size, D=D
    )
    torch.cuda.synchronize()

    # measure ------------------------------------------------------------
    # The timed statement syncs the device after each kernel call.
    timed_stmt = (
        "mamba_chunk_scan_combined(x, dt, A, Bm, Cm, chunk_size=chunk_size, D=D);"
        "torch.cuda.synchronize()"
    )
    namespace = {
        "x": x,
        "dt": dt,
        "A": A,
        "Bm": Bm,
        "Cm": Cm,
        "D": D,
        "chunk_size": chunk_size,
        "mamba_chunk_scan_combined": mamba_chunk_scan_combined,
    }
    measurement = benchmark.Timer(stmt=timed_stmt, globals=namespace).timeit(n_repeat)

    n_tokens = config["B"] * config["L"]
    return measurement.mean, n_tokens / measurement.mean


# -----------------------------------------------------------------------
def sweep():
    """
    Benchmark the kernel over a grid of (B, L, q_dim, v_dim) settings.

    For every grid point this prints the throughput derived from the
    FLOP estimate 4 * B * L * q_dim * v_dim, and finally writes all
    timings to results/mamba_scan_bench.json.
    """
    batch_sizes = [4]
    seq_lens = [16384]
    q_dims = [128, 256, 512, 1024, 2048, 4096]
    v_dims = [512, 1024, 2048, 4096]
    chunk_size = 2048

    # Flattened view of the nested grid; iteration order is batch, then
    # sequence length, then q_dim, with v_dim varying fastest.
    grid = (
        (B, L, d_model, v_dim)
        for B in batch_sizes
        for L in seq_lens
        for d_model in q_dims
        for v_dim in v_dims
    )

    results = []
    for B, L, d_model, v_dim in grid:
        cfg = {"B": B, "L": L, "model_dim": d_model, "v_dim": v_dim}
        elapsed, tokens_per_sec = bench_single(cfg, chunk_size=chunk_size)

        total_flops = 4 * B * L * d_model * v_dim
        gflops = total_flops / elapsed / 1e9
        tflops = gflops / 1e3

        print(
            f"B={B:2d}  L={L:6d}  q_dim={d_model:4d}  v_dim={v_dim:4d} "
            f"{gflops:>8.1f} gflops   {tflops:>8.1f} tflops"
        )

        record = dict(cfg)
        record["tok_per_sec"] = tokens_per_sec
        record["time"] = elapsed
        results.append(record)

    # Persist every measurement, creating the output directory if needed.
    os.makedirs("results", exist_ok=True)
    with open("results/mamba_scan_bench.json", "w") as f:
        json.dump(results, f, indent=2)


def test_flops():
    """
    Print achieved throughput for a fixed batch/sequence size over a small
    grid of q_dim and v_dim values.

    Uses B=4, L=65536 and chunk_size=2048. For each (q_dim, v_dim) pair the
    kernel is timed via bench_single and the throughput is derived from the
    FLOP estimate 4 * B * L * q_dim * v_dim. Nothing is returned or saved;
    the results are only printed.
    """
    B = 4
    L = 65536

    Q_DIM_RANGE = [64, 128, 256]
    V_DIM_RANGE = [512, 1024, 2048, 3072]
    chunk_size = 2048

    for d_model in Q_DIM_RANGE:
        for v_dim in V_DIM_RANGE:
            cfg = dict(B=B, L=L, model_dim=d_model, v_dim=v_dim)
            # Only the mean time is needed here; tokens/sec is not reported.
            t, _ = bench_single(cfg, chunk_size=chunk_size)

            FLOPS = 4 * B * L * d_model * v_dim
            gflops = FLOPS / t / 1e9
            tflops = gflops / 1e3

            print(
                f"B={B:2d}  L={L:6d}  q_dim={d_model:4d}  v_dim={v_dim:4d} "
                f"{gflops:>8.1f} gflops   {tflops:>8.1f} tflops"
            )


def compute_fwd_iters_per_second():
    """
    Print the mean forward time per call over a grid of (D, H) sizes.

    Uses B=1, L=65536, chunk_size=2048 and n_repeat timed runs per
    configuration. For each pair, D is passed to bench_single as model_dim
    and H as v_dim. Nothing is returned; results are only printed.
    """
    B = 1
    L = 65536
    chunk_size = 2048

    n_repeat = 10

    Q_DIM_RANGE = [128, 256, 512, 1024]
    V_DIM_RANGE = [2048, 4096, 8192, 16384, 32768]

    for d_model in Q_DIM_RANGE:
        for v_dim in V_DIM_RANGE:
            cfg = dict(B=B, L=L, model_dim=d_model, v_dim=v_dim)
            # Forward n_repeat explicitly; otherwise bench_single falls back
            # to its own default and the value configured above is ignored.
            t, _ = bench_single(cfg, n_repeat=n_repeat, chunk_size=chunk_size)

            # NOTE(review): this is d_model * v_dim in millions of elements,
            # not bytes, although it is printed with an "MB" suffix - confirm
            # the intended unit.
            state_size = d_model * v_dim / 1e6  # MB

            print(
                f"Configuration: B: {B}, L: {L}, D: {d_model}, H: {v_dim}, chunk_size: {chunk_size}, state_size: {state_size} MB"
            )
            print(f"Time per call: {t*1e3:.2f} ms")


if __name__ == "__main__":
    # Script entry point: runs test_flops() by default. To run the
    # per-call timing benchmark instead, enable the commented-out call.
    # compute_fwd_iters_per_second()

    test_flops()
