import jax
from jax import numpy as jnp, lax
from jax.numpy import einsum
from jax.scipy.special import logsumexp
from einshape import jax_einshape as einshape
from functools import partial
from jax.experimental import pallas as pl
from jax.experimental.pallas import triton as plgpu
import dataclasses
from time import time
import math
import numpy as np

#from .pallas_official_flash_attention import mha as mha_acausal_official

# Large finite negative score substituted for masked (out-of-window) positions
# in the forward kernel. 0.7x the float32 max keeps it finite with headroom
# (presumably so arithmetic on it inside the online softmax cannot overflow).
DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.float32).max)

@dataclasses.dataclass(frozen=True)
class BlockSizes:
    """Tile sizes used by the Pallas attention kernels.

    Frozen (and therefore hashable) so an instance can be passed as a static
    argument to ``jax.jit``.
    """

    # Forward kernel: query rows per program / key rows per loop step.
    block_q: int = 64
    block_kv: int = 64
    # dK/dV backward kernel: query rows per loop step / key rows per program.
    block_q_dkv: int = 64
    block_kv_dkv: int = 64
    # dQ backward kernel: query rows per program / key rows per loop step.
    block_q_dq: int = 128
    block_kv_dq: int = 64

def mha_acausal_cudnn(q, k, v):
    """Unmasked attention through ``jax.nn.dot_product_attention`` on cuDNN.

    Only the attention output is obtained from the cuDNN call, so an all-zero
    float32 array of shape ``q.shape[:-1]`` stands in for the log-sum-exp to
    keep the ``(lse, out)`` return convention used elsewhere in this file.
    """
    attended = jax.nn.dot_product_attention(q, k, v, scale=1.0, implementation="cudnn")
    placeholder_lse = jnp.zeros(q.shape[:-1], dtype=jnp.float32)
    return placeholder_lse, attended

def mha_windowed_cudnn(q, k, v, window):
    """Sliding-window causal attention through cuDNN.

    ``window`` is forwarded as ``local_window_size`` together with
    ``is_causal=True``. As with `mha_acausal_cudnn`, the returned lse is an
    all-zero float32 placeholder of shape ``q.shape[:-1]``.
    """
    attended = jax.nn.dot_product_attention(
        q, k, v,
        local_window_size=window,
        scale=1.0,
        implementation="cudnn",
        is_causal=True,
    )
    placeholder_lse = jnp.zeros(q.shape[:-1], dtype=jnp.float32)
    return placeholder_lse, attended

def reference_acausal(q, k, v):
    """Dense, unmasked softmax attention used as a correctness reference.

    No softmax scale is applied: scores are the raw dot products ``q @ k.T``.

    Args:
        q: Queries, shape (NQ, D).
        k: Keys, shape (NK, D).
        v: Values, shape (NK, V).

    Returns:
        (lse, output): per-query log-sum-exp of the scores, shape (NQ,), and
        the attention output, shape (NQ, V).
    """
    scores = einsum('qd,kd->qk', q, k)
    row_lse = logsumexp(scores, axis=-1)
    probs = jax.nn.softmax(scores, axis=-1)
    return row_lse, einsum('qk,kv->qv', probs, v)

def reference_windowed(q, k, v, window):
    """Dense sliding-window (banded) softmax attention used as a reference.

    Query ``i`` attends key ``j`` iff ``i - j <= window[0]`` and
    ``j - i <= window[1]``. All other scores are set to ``-inf`` before the
    softmax. No softmax scale is applied.

    Args:
        q: Queries, shape (NQ, D).
        k: Keys, shape (NK, D).
        v: Values, shape (NK, V).
        window: (left, right) band widths. These may be Python ints or traced
            integer scalars (e.g. when differentiated through ``jax.vjp``).

    Returns:
        (lse, output): per-query log-sum-exp over the visible scores, shape
        (NQ,), and the attention output, shape (NQ, V).
    """
    attn_scores = einsum('qd,kd->qk', q, k)
    qindex = jnp.arange(q.shape[0])[:, None]
    kindex = jnp.arange(k.shape[0])[None, :]
    # Single band mask (the mask was previously computed twice).
    mask = (qindex - kindex <= window[0]) & (kindex - qindex <= window[1])
    attn_scores = jnp.where(mask, attn_scores, -jnp.inf)
    attn_weights = jax.nn.softmax(attn_scores, axis=-1)
    output = einsum('qk,kv->qv', attn_weights, v)
    lse = logsumexp(attn_scores, axis=-1)
    return lse, output

def manual_acausal_bwd(q, k, v, lse, o, g_lse, g_o):
    """Hand-derived VJP of `reference_acausal` (unmasked, unscaled attention).

    Args:
        q: Queries, shape (NQ, D).
        k: Keys, shape (NK, D).
        v: Values, shape (NK, V).
        lse: Forward log-sum-exp, shape (NQ,). Accepted for signature
            symmetry; the probabilities are recomputed from q and k instead.
        o: Forward output, shape (NQ, V).
        g_lse: Cotangent of lse, shape (NQ,).
        g_o: Cotangent of the output, shape (NQ, V).

    Returns:
        (d_q, d_k, d_v) with shapes (NQ, D), (NK, D), (NK, V).
    """
    probs = jax.nn.softmax(einsum('qd,kd->qk', q, k), axis=-1)

    # dL/dV = P^T dO.
    d_v = einsum('qv,qk->kv', g_o, probs)

    # dL/dS = P * (dO V^T + g_lse - rowsum(dO * O)); the g_lse term comes from
    # d(lse)/dS = P.
    d_probs = einsum('qv,kv->qk', g_o, v)
    row_term = g_lse - einsum('qv,qv->q', g_o, o)
    d_scores = probs * (d_probs + row_term[:, None])

    d_q = einsum('qk,kd->qd', d_scores, k)
    d_k = einsum('qd,qk->kd', q, d_scores)
    return d_q, d_k, d_v

def attn_kernel(q_ref, k_ref, v_ref, o_ref, lse_ref, *, window: tuple[int, int] | None = None, block_sizes: BlockSizes, sm_scale: float = 1.0):
    """Pallas forward kernel: online-softmax attention for one query block.

    As wired up by `attn`'s BlockSpecs, ``q_ref`` is this program's
    (block_q, D) query tile, ``k_ref`` (NK, D) and ``v_ref`` (NK, V) are the
    full key/value arrays, and ``o_ref`` (block_q, V) / ``lse_ref`` (block_q,)
    are the matching output tiles.

    Args:
        window: Optional (left, right) band. Query i attends key j iff
            ``-window[1] <= i - j <= window[0]``. ``None`` means unmasked.
        block_sizes: Only ``block_kv`` is read (clamped to NK). The query tile
            size is taken from ``q_ref`` so it always matches the BlockSpec,
            including when the caller clamped it to a short sequence.
        sm_scale: Multiplier applied to q @ k.T before the softmax.
    """
    N = k_ref.shape[0]
    V = v_ref.shape[1]
    block_q = q_ref.shape[0]
    block_kv = min(block_sizes.block_kv, N)

    # Scores are kept in the log2 domain so the loop can use exp2.
    qk_scale = math.log2(math.e)
    if sm_scale != 1.0:
        qk_scale = qk_scale * sm_scale

    start_q = pl.program_id(axis=0)
    # Online-softmax state: output accumulator, running row max and running
    # row sum of exp2(score - max).
    o = jnp.zeros((block_q, V), dtype=jnp.float32)
    m = jnp.zeros(block_q, dtype=jnp.float32) - jnp.inf
    l = jnp.zeros(block_q, dtype=jnp.float32)

    q = pl.load(q_ref, (slice(None), slice(None)))

    def body(start_k, carry):
        o_prev, m_prev, l_prev = carry
        curr_k_slice = pl.dslice(start_k * block_kv, block_kv)

        k = pl.load(k_ref, (curr_k_slice, slice(None)))
        qk = pl.dot(q, k, trans_b=True, allow_tf32=True) * qk_scale
        if window is not None:
            span_q = start_q * block_q + jnp.arange(block_q)
            span_k = start_k * block_kv + jnp.arange(block_kv)
            forward_index = span_q[:, None] - span_k[None, :]
            mask = (forward_index <= window[0]) & (forward_index >= -window[1])
            qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE)
        m_curr = jnp.max(qk, axis=-1)
        m_next = jnp.maximum(m_prev, m_curr)
        # Rescale what was accumulated so far to the new running max.
        correction = jnp.exp2(m_prev - m_next)
        s_curr = jnp.exp2(qk - m_next[:, None])
        l_next = l_prev * correction + jnp.sum(s_curr, axis=-1)
        v = pl.load(v_ref, (curr_k_slice, slice(None)))
        o_curr = pl.dot(s_curr.astype(v.dtype), v, allow_tf32=True)
        o_next = correction[:, None] * o_prev + o_curr
        return o_next, m_next, l_next

    if window is None:
        lower_bound = 0
        upper_bound = pl.cdiv(N, block_kv)
    else:
        # Key blocks holding keys start_q*block_q - window[0] through
        # (start_q+1)*block_q - 1 + window[1]. This is correct for any
        # block_q/block_kv ratio.
        lower_bound = lax.max(0, (start_q * block_q - window[0]) // block_kv)
        upper_bound = lax.min(
            pl.cdiv(N, block_kv),
            pl.cdiv((start_q + 1) * block_q + window[1], block_kv),
        )

    o, m, l = lax.fori_loop(lower_bound, upper_bound, body, (o, m, l))

    o = o / l[:, None]

    pl.store(o_ref, (slice(None), slice(None)), o.astype(o_ref.dtype))
    # m and l live in the log2 domain: ln(sum e^x) = ln(2) * m + ln(l).
    lse_ref[...] = math.log(2) * m + jnp.log(l)

@partial(jax.custom_vjp, nondiff_argnums=(3,4,5))
@partial(jax.jit, static_argnames=("window", "block_sizes", "sm_scale"))
def attn(q, k, v,
         window: tuple[int, int] | None = None,
         block_sizes: BlockSizes = BlockSizes(),
         sm_scale: float = 1.0,
        ):
    """Flash-attention forward pass (Pallas/Triton) with optional window.

    ``window``, ``block_sizes`` and ``sm_scale`` are static and
    non-differentiable. Gradients come from the custom VJP rules registered
    via ``attn.defvjp``.

    Args:
        q: Queries, shape (NQ, D).
        k: Keys, shape (NK, D).
        v: Values, shape (NK, V).
        window: Optional (left, right) band. Query i attends key j iff
            ``-window[1] <= i - j <= window[0]``.
        block_sizes: Tile sizes. ``block_q``/``block_kv`` are clamped to
            NQ/NK.
        sm_scale: Multiplier applied to q @ k.T before the softmax.

    Returns:
        (lse, out): float32 log-sum-exp of shape (NQ,) and the output of shape
        (NQ, V) in ``v.dtype``.
    """
    NQ, D = q.shape
    NK, V = v.shape
    block_q = min(block_sizes.block_q, NQ)
    block_kv = min(block_sizes.block_kv, NK)
    # Give the kernel the clamped sizes so its loop bounds and masks agree with
    # the BlockSpecs below.
    kernel_block_sizes = dataclasses.replace(block_sizes, block_q=block_q, block_kv=block_kv)
    kernel = partial(attn_kernel, window=window, block_sizes=kernel_block_sizes, sm_scale=sm_scale)
    grid = (pl.cdiv(NQ, block_q),)
    in_specs = [
        pl.BlockSpec((block_q, D), lambda i: (i, 0)),
        pl.BlockSpec((NK, D), lambda _: (0, 0)),
        pl.BlockSpec((NK, V), lambda _: (0, 0)),
    ]
    out_shape = [
        jax.ShapeDtypeStruct((NQ, V), v.dtype),
        jax.ShapeDtypeStruct((NQ,), jnp.float32),
    ]
    out_specs = [
        pl.BlockSpec((block_q, V), lambda i: (i, 0)),
        pl.BlockSpec((block_q,), lambda i: (i,)),
    ]
    out, lse = pl.pallas_call(
        kernel,
        out_shape=out_shape,
        grid=grid,
        in_specs=in_specs,
        out_specs=out_specs,
        compiler_params=plgpu.CompilerParams(
            num_warps=4,
            num_stages=1,
        ),
        name="attn_forward",
    )(q, k, v)
    return lse, out

def attn_dq_kernel(q_ref, k_ref, v_ref, lse_ref, do_ref, dlse_ref, dq_ref, *,
                   window: tuple[int, int] | None = None, block_sizes: BlockSizes, sm_scale: float = 1.0):
    """Pallas backward kernel computing dQ for one block of queries.

    The refs are the full arrays (no BlockSpecs). Each program slices query
    block ``pl.program_id(-1)`` of size ``block_sizes.block_q_dq`` and walks the
    key blocks of size ``block_sizes.block_kv_dq`` that its window can reach.

    Args:
        q_ref, k_ref, v_ref: q (NQ, D), k (NK, D), v (NK, V).
        lse_ref: Forward log-sum-exp (natural log), shape (NQ,).
        do_ref: Output cotangent, shape (NQ, V).
        dlse_ref: Per-row term ``dlse - sum(o * do, -1)`` as prepared by
            `attn_dqkv`, shape (NQ,).
        dq_ref: Output dQ, shape (NQ, D).
        window: Optional (left, right) band. Query i sees key j iff
            ``-window[1] <= i - j <= window[0]``.
        sm_scale: Multiplier applied to q @ k.T in the forward pass.
    """
    N, D = k_ref.shape
    block_q = block_sizes.block_q_dq
    block_kv = block_sizes.block_kv_dq

    # Scores are recomputed in the log2 domain, matching the forward kernel.
    qk_scale = math.log2(math.e)
    if sm_scale != 1.0:
        qk_scale = qk_scale * sm_scale

    start_q = pl.program_id(axis=-1)
    curr_q_slice = pl.dslice(start_q * block_q, block_q)
    q = q_ref[curr_q_slice, :]
    do = do_ref[curr_q_slice, :]
    lse = math.log2(math.e) * lse_ref[curr_q_slice]  # convert to log2 scale
    dlse = dlse_ref[curr_q_slice]

    def body(start_k, dq_prev):
        curr_k_slice = pl.dslice(start_k * block_kv, block_kv)
        v = v_ref[curr_k_slice, :]
        # dP plus the per-row correction term.
        dov = pl.dot(do, v, trans_b=True, allow_tf32=True) + dlse[:, None]
        k = k_ref[curr_k_slice, :]
        qk = pl.dot(q, k, trans_b=True, allow_tf32=True) * qk_scale
        # Recompute the softmax probabilities from the saved lse.
        s = jnp.exp2(qk - lse[:, None])
        if window is not None:
            span_q = start_q * block_q + jnp.arange(block_q)
            span_k = start_k * block_kv + jnp.arange(block_kv)
            forward_index = span_q[:, None] - span_k[None, :]
            mask = (forward_index <= window[0]) & (forward_index >= -window[1])
            s = jnp.where(mask, s, 0.0)
        dscore = s * dov
        dq_curr = pl.dot(dscore.astype(k.dtype), k, allow_tf32=True) * sm_scale
        return dq_prev + dq_curr

    if window is None:
        lower_bound = 0
        upper_bound = pl.cdiv(N, block_kv)
    else:
        # Key blocks holding keys start_q*block_q - window[0] through
        # (start_q+1)*block_q - 1 + window[1]. The exclusive end is exactly
        # this cdiv, so no trailing fully-masked block is visited.
        lower_bound = lax.max(0, (start_q * block_q - window[0]) // block_kv)
        upper_bound = lax.min(
            pl.cdiv(N, block_kv),
            pl.cdiv((start_q + 1) * block_q + window[1], block_kv),
        )

    dq = jnp.zeros((block_q, D), dtype=jnp.float32)
    dq = lax.fori_loop(lower_bound, upper_bound, body, dq)
    dq_ref[curr_q_slice] = dq.astype(dq_ref.dtype)

def attn_dkv_kernel(q_ref, k_ref, v_ref, lse_ref, do_ref, dlse_ref, dk_ref, dv_ref, *,
                    window: tuple[int, int] | None = None, block_sizes: BlockSizes, sm_scale: float = 1.0):
    """Pallas backward kernel computing dK and dV for one block of keys.

    The refs are the full arrays (no BlockSpecs). Each program slices key block
    ``pl.program_id(-1)`` of size ``block_sizes.block_kv_dkv`` and walks the
    query blocks of size ``block_sizes.block_q_dkv`` whose windows can reach it.

    Args:
        q_ref, k_ref, v_ref: q (NQ, D), k (NK, D), v (NK, V).
        lse_ref: Forward log-sum-exp (natural log), shape (NQ,).
        do_ref: Output cotangent, shape (NQ, V).
        dlse_ref: Per-row term ``dlse - sum(o * do, -1)`` as prepared by
            `attn_dqkv`, shape (NQ,).
        dk_ref, dv_ref: Outputs dK (NK, D) and dV (NK, V).
        window: Optional (left, right) band. Query i sees key j iff
            ``-window[1] <= i - j <= window[0]``.
        sm_scale: Multiplier applied to q @ k.T in the forward pass.
    """
    N, D = q_ref.shape
    V = v_ref.shape[1]
    block_q = block_sizes.block_q_dkv
    block_kv = block_sizes.block_kv_dkv

    # Scores are recomputed in the log2 domain, matching the forward kernel.
    qk_scale = math.log2(math.e)
    if sm_scale != 1.0:
        qk_scale = qk_scale * sm_scale

    start_k = pl.program_id(axis=-1)
    curr_k_slice = pl.dslice(start_k * block_kv, block_kv)
    k = k_ref[curr_k_slice, :]
    v = v_ref[curr_k_slice, :]

    def body(start_q, carry):
        dk_prev, dv_prev = carry
        curr_q_slice = pl.dslice(start_q * block_q, block_q)
        do = do_ref[curr_q_slice, :]
        dlse = dlse_ref[curr_q_slice]
        # dP plus the per-row correction term.
        dov = pl.dot(do, v, trans_b=True, allow_tf32=True) + dlse[:, None]
        q = q_ref[curr_q_slice, :]
        qk = pl.dot(q, k, trans_b=True, allow_tf32=True) * qk_scale
        lse = lse_ref[curr_q_slice] * math.log2(math.e)  # convert to log2 scale
        # Recompute the softmax probabilities from the saved lse.
        s = jnp.exp2(qk - lse[:, None])
        if window is not None:
            span_q = start_q * block_q + jnp.arange(block_q)
            span_k = start_k * block_kv + jnp.arange(block_kv)
            forward_index = span_q[:, None] - span_k[None, :]
            mask = (forward_index <= window[0]) & (forward_index >= -window[1])
            s = jnp.where(mask, s, 0.0)
        dv_next = dv_prev + pl.dot(s.astype(v.dtype), do, trans_a=True, allow_tf32=True)
        dscore = s * dov
        dk_curr = pl.dot(dscore.astype(q.dtype), q, trans_a=True, allow_tf32=True) * sm_scale
        return dk_prev + dk_curr, dv_next

    if window is None:
        lower_bound = 0
        upper_bound = pl.cdiv(N, block_q)
    else:
        # Query blocks holding queries start_k*block_kv - window[1] through
        # (start_k+1)*block_kv - 1 + window[0]. The exclusive end is exactly
        # this cdiv, so no trailing fully-masked block is visited.
        lower_bound = lax.max(0, (start_k * block_kv - window[1]) // block_q)
        upper_bound = lax.min(
            pl.cdiv(N, block_q),
            pl.cdiv((start_k + 1) * block_kv + window[0], block_q),
        )

    dk = jnp.zeros((block_kv, D), dtype=jnp.float32)
    dv = jnp.zeros((block_kv, V), dtype=jnp.float32)
    dk, dv = lax.fori_loop(lower_bound, upper_bound, body, (dk, dv))
    dk_ref[curr_k_slice] = dk.astype(dk_ref.dtype)
    dv_ref[curr_k_slice] = dv.astype(dv_ref.dtype)

def attn_dqkv(q, k, v, lse, o, dlse, do, *,
            window: tuple[int, int] | None = None,
            block_sizes: BlockSizes = BlockSizes(),
            sm_scale: float = 1.0,
        ):
    """Backward pass of `attn`: gradients of (lse, out) w.r.t. q, k and v.

    Args:
        q: Queries, shape (NQ, D).
        k: Keys, shape (NK, D).
        v: Values, shape (NK, V).
        lse: Forward log-sum-exp, shape (NQ,).
        o: Forward output, shape (NQ, V).
        dlse: Cotangent of lse, shape (NQ,).
        do: Cotangent of the output, shape (NQ, V).
        window, block_sizes, sm_scale: As for `attn`.

    Returns:
        (dq, dk, dv) with the shapes and dtypes of (q, k, v).

    Raises:
        ValueError: If a (clamped) backward tile size does not divide its
            sequence length. The kernels slice without bounds masking.
    """
    NQ = q.shape[0]
    NK = v.shape[0]
    # Clamp tiles to the sequence lengths, as `attn` does for the forward pass.
    block_sizes = dataclasses.replace(
        block_sizes,
        block_q_dq=min(block_sizes.block_q_dq, NQ),
        block_kv_dq=min(block_sizes.block_kv_dq, NK),
        block_q_dkv=min(block_sizes.block_q_dkv, NQ),
        block_kv_dkv=min(block_sizes.block_kv_dkv, NK),
    )
    # Both the grid tiles and the tiles stepped over by each kernel's inner
    # loop are sliced unmasked, so all four must divide evenly.
    for seq_name, seq_len, tile_name, tile in (
        ("NQ", NQ, "block_q_dq", block_sizes.block_q_dq),
        ("NK", NK, "block_kv_dq", block_sizes.block_kv_dq),
        ("NQ", NQ, "block_q_dkv", block_sizes.block_q_dkv),
        ("NK", NK, "block_kv_dkv", block_sizes.block_kv_dkv),
    ):
        if seq_len % tile != 0:
            raise ValueError(f"{seq_name} {seq_len} must be divisible by {tile_name} {tile}")

    # Per-row term dlse - rowsum(o * do), computed in float32. Forming the
    # product and sum in the inputs' dtype (e.g. bfloat16) and casting
    # afterwards would lose precision that feeds every recomputed dscore.
    delta = jnp.sum(o.astype(jnp.float32) * do.astype(jnp.float32), axis=-1)
    dlse = dlse.astype(jnp.float32) - delta
    lse = lse.astype(jnp.float32)

    dq = pl.pallas_call(
        partial(attn_dq_kernel, window=window, block_sizes=block_sizes, sm_scale=sm_scale),
        out_shape=q,
        grid=(NQ // block_sizes.block_q_dq,),
        compiler_params=plgpu.CompilerParams(num_warps=4, num_stages=1),
        name="attn_backward_q",
    )(q, k, v, lse, do, dlse)
    dk, dv = pl.pallas_call(
        partial(attn_dkv_kernel, window=window, block_sizes=block_sizes, sm_scale=sm_scale),
        out_shape=(k, v),
        grid=(NK // block_sizes.block_kv_dkv,),
        compiler_params=plgpu.CompilerParams(num_warps=4, num_stages=4),
        name="attn_backward_kv",
    )(q, k, v, lse, do, dlse)
    return dq, dk, dv

def _attn_forward(q, k, v, window: tuple[int, int] = None, block_sizes: BlockSizes = BlockSizes(), sm_scale: float = 1.0):
    """custom_vjp forward rule for `attn`.

    Runs `attn` and saves (q, k, v, lse, out) as residuals for
    `_attn_backward`.
    """
    lse, out = attn(q, k, v, window=window, block_sizes=block_sizes, sm_scale=sm_scale)
    residuals = (q, k, v, lse, out)
    primal_out = (lse, out)
    return primal_out, residuals

def _attn_backward(window: tuple[int, int] | None, block_sizes: BlockSizes, sm_scale: float, res, grads):
    """custom_vjp backward rule for `attn`.

    The non-differentiable arguments come first (per ``nondiff_argnums``),
    then the residuals from `_attn_forward` and the (dlse, dout) cotangents.
    Returns the (dq, dk, dv) computed by `attn_dqkv`.
    """
    q, k, v, lse, o = res
    dlse, do = grads
    grad_q, grad_k, grad_v = attn_dqkv(
        q, k, v, lse, o, dlse, do,
        window=window, block_sizes=block_sizes, sm_scale=sm_scale,
    )
    return grad_q, grad_k, grad_v

# Register the Pallas forward/backward rules as attn's custom VJP.
attn.defvjp(fwd=_attn_forward, bwd=_attn_backward, optimize_remat=True)





def sq_err(ref, x):
    """Relative mean-squared error of ``x`` against ``ref``.

    Returns ``mean((x - ref)**2) / mean(ref**2)``, or ``inf`` when ``ref`` is
    identically zero.
    """
    ref_power = jnp.square(ref).mean()
    err_power = jnp.square(x - ref).mean()
    return jnp.where(ref_power > 0., err_power / ref_power, jnp.inf)

def benchmark_fn(fn, args, ref_out, *, warmup_iters=100, iters=100):
    """JIT-compile ``fn``, report its error against ``ref_out``, and time it.

    Args:
        fn: Function to benchmark, called as ``fn(*args)``.
        args: Positional arguments for ``fn``.
        ref_out: Reference pytree with the same structure as ``fn``'s output.
            Each leaf pair is compared with `sq_err`.
        warmup_iters: Untimed calls made after the first (compiling) call.
        iters: Timed calls averaged for the reported runtime. Must be positive.

    Raises:
        ValueError: If ``iters`` is not positive. This is checked up front
            rather than dividing by zero after all the warmup work.
    """
    # perf_counter is monotonic and high-resolution. time.time() follows
    # wall-clock adjustments, which makes it unsuitable for measuring durations.
    from time import perf_counter

    if iters <= 0:
        raise ValueError(f"iters must be positive, got {iters}")

    # The first call includes tracing and compilation.
    start_time = perf_counter()
    jit_fn = jax.jit(fn)
    fn_out = jax.block_until_ready(jit_fn(*args))
    end_time = perf_counter()
    sq_errs = jax.tree.map(lambda ref, x: float(sq_err(ref, x)), ref_out, fn_out)
    print(f"First call took {end_time - start_time:.4f} seconds.")
    print(f"Output errors: {sq_errs}")
    # Warmup
    for _ in range(warmup_iters):
        jax.block_until_ready(jit_fn(*args))
    # Benchmark
    start_time = perf_counter()
    for _ in range(iters):
        jax.block_until_ready(jit_fn(*args))
    end_time = perf_counter()
    print(f"Average time over {iters} iterations: {1e3 * (end_time - start_time) / iters:.4f} ms.")





def main():
    """Compare and time every attention implementation in this file.

    Builds random bfloat16 inputs of shape (B, N, H, D) and computes float32
    dense references and their VJPs. It then runs `benchmark_fn` on the
    reference, cuDNN and Pallas/Triton variants, covering forward,
    backward-only and full-gradient paths, both unmasked and with a sliding
    ``window``.
    """
    # Optional comparison against JAX's bundled Pallas flash-attention kernel.
    # Without this import the name was undefined and main() crashed with a
    # NameError at the first benchmark.
    try:
        from jax.experimental.pallas.ops.gpu.attention import mha as mha_acausal_official
    except ImportError:
        mha_acausal_official = None

    B = 4
    H = 8
    N = 2**12
    D = 64
    V = D
    window = (1024, 0)

    dtype = jnp.bfloat16
    scale = jnp.array(1.0 / (D ** 0.25), dtype=dtype)

    # Every implementation here runs with a unit softmax scale. The usual
    # 1/sqrt(D) factor is therefore folded into the inputs by multiplying
    # both q and k by D**-0.25. Dividing by `scale` would instead inflate
    # the logits by sqrt(D).
    q = jax.random.normal(jax.random.PRNGKey(0), (B, N, H, D), dtype=dtype) * scale
    k = jax.random.normal(jax.random.PRNGKey(1), (B, N, H, D), dtype=dtype) * scale
    v = jax.random.normal(jax.random.PRNGKey(2), (B, N, H, V), dtype=dtype)
    g_o = jax.random.normal(jax.random.PRNGKey(3), (B, N, H, V), dtype=dtype)
    # Zeroed lse cotangent (presumably because the cuDNN wrappers only return a
    # placeholder lse).
    g_lse = jax.random.normal(jax.random.PRNGKey(4), (B, N, H), dtype=dtype) * 0.

    mha_acausal_ref = jax.vmap(jax.vmap(reference_acausal, in_axes=1, out_axes=1))
    mha_windowed_ref = jax.vmap(jax.vmap(reference_windowed, in_axes=(1, 1, 1, None), out_axes=1), in_axes=(0, 0, 0, None))
    mha_acausal_attn = jax.vmap(jax.vmap(attn, in_axes=1, out_axes=1))
    mha_windowed_attn = jax.vmap(jax.vmap(partial(attn, window=window), in_axes=(1, 1, 1), out_axes=1), in_axes=(0, 0, 0))
    fp32_ref_out, mha_acausal_ref_bwd = jax.vjp(mha_acausal_ref, q.astype(jnp.float32), k.astype(jnp.float32), v.astype(jnp.float32))
    fp32_win_out, mha_windowed_ref_bwd = jax.vjp(mha_windowed_ref, q.astype(jnp.float32), k.astype(jnp.float32), v.astype(jnp.float32), window)
    mha_acausal_man_bwd = jax.vmap(jax.vmap(manual_acausal_bwd, in_axes=1, out_axes=1))
    mha_acausal_qkv_bwd = jax.vmap(jax.vmap(attn_dqkv, in_axes=1, out_axes=1))
    mha_windowed_qkv_bwd = jax.vmap(jax.vmap(partial(attn_dqkv, window=window), in_axes=1, out_axes=1))
    _, mha_acausal_cudnn_bwd = jax.vjp(mha_acausal_cudnn, q, k, v)
    _, mha_windowed_cudnn_bwd = jax.vjp(partial(mha_windowed_cudnn, window=window), q, k, v)

    @partial(jax.grad, argnums=(0, 1, 2))
    def mha_acausal_cudnn_grad(q, k, v, g_lse, g_o):
        lse, out = mha_acausal_cudnn(q, k, v)
        loss = jnp.sum(g_lse * lse) + jnp.sum(g_o * out)
        return loss

    @partial(jax.grad, argnums=(0, 1, 2))
    def mha_acausal_attn_grad(q, k, v, g_lse, g_o):
        lse, out = mha_acausal_attn(q, k, v)
        loss = jnp.sum(g_lse * lse) + jnp.sum(g_o * out)
        return loss

    @partial(jax.grad, argnums=(0, 1, 2))
    def mha_acausal_official_grad(q, k, v, g_lse, g_o):
        out = mha_acausal_official(q, k, v, segment_ids=None, sm_scale=1.0, causal=False, num_stages=1)
        loss = jnp.sum(g_o * out)
        return loss

    @partial(jax.grad, argnums=(0, 1, 2))
    def mha_windowed_cudnn_grad(q, k, v, g_lse, g_o):
        lse, out = mha_windowed_cudnn(q, k, v, window)
        loss = jnp.sum(g_lse * lse) + jnp.sum(g_o * out)
        return loss

    @partial(jax.grad, argnums=(0, 1, 2))
    def mha_windowed_attn_grad(q, k, v, g_lse, g_o):
        lse, out = mha_windowed_attn(q, k, v)
        loss = jnp.sum(g_lse * lse) + jnp.sum(g_o * out)
        return loss

    # Compute reference acausal attention
    ref_out = mha_acausal_ref(q, k, v)
    ref_bwd_out = mha_acausal_ref_bwd((g_lse.astype(jnp.float32), g_o.astype(jnp.float32)))
    # Drop the (integer) cotangent for the `window` argument.
    ref_win_bwd_out = mha_windowed_ref_bwd((g_lse.astype(jnp.float32), g_o.astype(jnp.float32)))[:-1]

    if mha_acausal_official is None:
        print("Skipping official gradient benchmark (jax.experimental.pallas.ops.gpu.attention unavailable).")
    else:
        print("Benchmarking official gradient...")
        benchmark_fn(mha_acausal_official_grad, (q, k, v, g_lse, g_o), ref_bwd_out)

    # Benchmark
    print("Benchmarking reference implementation...")
    benchmark_fn(mha_acausal_ref, (q, k, v), fp32_ref_out)
    print("Benchmarking cudnn implementation...")
    benchmark_fn(mha_acausal_cudnn, (q, k, v), fp32_ref_out)
    print("Benchmarking Triton implementation...")
    benchmark_fn(mha_acausal_attn, (q, k, v), fp32_ref_out)

    print("Benchmarking windowed reference implementation...")
    benchmark_fn(mha_windowed_ref, (q, k, v, window), fp32_win_out)
    print("Benchmarking windowed cudnn implementation...")
    benchmark_fn(partial(mha_windowed_cudnn, window=window), (q, k, v), fp32_win_out)
    print("Benchmarking windowed Triton implementation...")
    benchmark_fn(mha_windowed_attn, (q, k, v), fp32_win_out)

    print("Benchmarking manual backward pass...")
    benchmark_fn(mha_acausal_man_bwd, (q, k, v, ref_out[0], ref_out[1], g_lse, g_o), ref_bwd_out)
    print("Benchmarking cudnn backward pass...")
    benchmark_fn(mha_acausal_cudnn_bwd, ((g_lse.astype(jnp.float32), g_o),), ref_bwd_out)
    print("Benchmarking Triton dqkv pass...")
    benchmark_fn(mha_acausal_qkv_bwd, (q, k, v, fp32_ref_out[0], fp32_ref_out[1].astype(dtype), g_lse, g_o), ref_bwd_out)

    print("Benchmarking windowed cudnn backward pass...")
    benchmark_fn(mha_windowed_cudnn_bwd, ((g_lse.astype(jnp.float32), g_o),), ref_win_bwd_out)
    print("Benchmarking windowed Triton dqkv pass...")
    benchmark_fn(mha_windowed_qkv_bwd, (q, k, v, fp32_win_out[0], fp32_win_out[1].astype(dtype), g_lse, g_o), ref_win_bwd_out)

    print("Benchmarking cudnn gradient...")
    benchmark_fn(mha_acausal_cudnn_grad, (q, k, v, g_lse, g_o), ref_bwd_out)
    print("Benchmarking Triton gradient...")
    benchmark_fn(mha_acausal_attn_grad, (q, k, v, g_lse, g_o), ref_bwd_out)
    print("Benchmarking windowed cudnn gradient...")
    benchmark_fn(mha_windowed_cudnn_grad, (q, k, v, g_lse, g_o), ref_win_bwd_out)
    print("Benchmarking windowed Triton gradient...")
    benchmark_fn(mha_windowed_attn_grad, (q, k, v, g_lse, g_o), ref_win_bwd_out)


# Run the benchmark suite only when executed as a script, not on import.
if __name__ == "__main__":
    main()

