import torch
import triton
import triton.language as tl
import torch.nn.functional as F

import time
from typing import List, Optional, Tuple, Union
from einops import rearrange, repeat

from fla.utils import contiguous
from fla.ops.utils.softmax import softmax_bwd
from fla.ops.gla import chunk_gla

import pdb

def sizeof_fmt(num, suffix='B'):
    """Render a byte count as a human-readable string with binary prefixes.

    The value is divided by 1024 until its magnitude drops below 1024 or the
    prefixes up to ``Zi`` are exhausted, in which case ``Yi`` is used.

    Args:
        num: Quantity to format (int or float).
        suffix: Unit suffix appended after the prefix.

    Returns:
        A string such as ``'1.5KiB'``.
    """
    prefixes = ('', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi')
    idx = 0
    # `not abs(num) < 1024.0` (rather than `>=`) keeps NaN walking through
    # every prefix, exactly like the comparison it replaces.
    while idx < len(prefixes) and not abs(num) < 1024.0:
        num /= 1024.0
        idx += 1
    if idx == len(prefixes):
        return f'{num:.1f}Yi{suffix}'
    return f'{num:3.1f}{prefixes[idx]}{suffix}'


def torch_impl(q2, k2, v2, gk2, eta):
    """Reference PyTorch implementation of the softmax / top-k masking step.

    NOTE: reads the module-level names ``nsp``, ``num_reader`` and
    ``num_writer``; they must be defined before this function is called.

    Args:
        q2, k2, v2, gk2: 4-D tensors laid out as ``[b, l, h, d]``; each one is
            repeated ``nsp`` times along a new axis to ``[b, l, n, h, d]``.
        eta: Router logits of shape ``[b, l, n]``.

    Returns:
        Tuple ``(q2, k2, v2, gk2, eta, mask_w, mask_r)`` where the first four
        are the repeated and masked ``[b, l, n, h, d]`` tensors (``q2`` and
        ``k2`` additionally scaled by the selected router weights), ``eta`` is
        the softmax of the input logits, and the masks are int tensors of
        shape ``[b, l, n, 1, 1]``.
    """
    # Replicate every input across the partition axis.
    q2, k2, v2, gk2 = (
        repeat(t.unsqueeze(2), "b l 1 h d -> b l n h d", n=nsp)
        for t in (q2, k2, v2, gk2)
    )

    eta = F.softmax(eta, dim=-1)
    reader_vals, reader_idx = torch.topk(eta, num_reader, dim=-1)
    writer_vals, writer_idx = torch.topk(eta, num_writer, dim=-1)

    # Router weights with everything outside the top-k zeroed out.
    eta_topk_r = torch.zeros_like(eta).scatter_(-1, reader_idx, reader_vals)
    eta_topk_w = torch.zeros_like(eta).scatter_(-1, writer_idx, writer_vals)

    # 0/1 masks (non-zero weight -> 1), broadcastable over (h, d).
    mask_w = eta_topk_w.bool().int()[..., None, None]
    mask_r = eta_topk_r.bool().int()[..., None, None]

    # Writers gate k / v / gk; readers gate q. q and k are then scaled by the
    # selected router weights.
    k2 = (k2 * mask_w) * eta_topk_w[..., None, None]
    v2 = v2 * mask_w
    gk2 = gk2 * mask_w
    q2 = (q2 * mask_r) * eta_topk_r[..., None, None]

    return q2, k2, v2, gk2, eta, mask_w, mask_r


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3]
    ],
    key=['BN'],
)
@triton.jit
def _fused_softmax_topk_fwd_kernel(
    e,
    e_o,
    mw,
    mr,
    stride_e_b,
    stride_e_l,
    B,
    T,
    N,
    NUM_WRITER: tl.constexpr,
    NUM_READER: tl.constexpr,
    BN: tl.constexpr,
):
    """Softmax over the last axis of ``e`` plus top-k writer/reader masks.

    One program handles one (program_id(0), program_id(1)) row of ``e``.
    ``e_o`` receives the softmax probabilities; ``mw`` / ``mr`` receive 0/1
    masks selecting the ``NUM_WRITER`` / ``NUM_READER`` largest probabilities.
    All four buffers are addressed with the same ``stride_e_b`` /
    ``stride_e_l`` strides and a unit stride along the last axis.
    ``BN`` is the padded block width (lanes with index >= ``N`` are masked).
    """
    i_b, i_t = tl.program_id(0), tl.program_id(1)

    offsets_n = tl.arange(0, BN)
    mask_n = offsets_n < N
    p_e = e + i_b * stride_e_b + i_t * stride_e_l + offsets_n
    p_e_o = e_o + i_b * stride_e_b + i_t * stride_e_l + offsets_n
    p_mw = mw + i_b * stride_e_b + i_t * stride_e_l + offsets_n
    p_mr = mr + i_b * stride_e_b + i_t * stride_e_l + offsets_n

    ### stable softmax and topk ###
    # Padded lanes load -inf so they contribute exp(-inf) = 0 to the softmax.
    b_e = tl.load(p_e, mask=mask_n, other=-float('inf')).to(tl.float32)
    b_m = tl.max(b_e, axis=0)
    b_e = tl.exp(b_e - b_m)
    b_p = b_e / tl.sum(b_e, axis=0)
    # Round to the storage dtype before ranking so the masks agree with the
    # values written to e_o; padded lanes get -inf so they sort last.
    b_p = tl.where(mask_n, b_p.to(p_e.dtype.element_ty), -float('inf'))
    b_ps = tl.sort(b_p, descending=True)
    tl.store(p_e_o, b_p.to(p_e_o.dtype.element_ty), mask=mask_n)

    mask_w = tl.full((BN,), 1, dtype=b_p.dtype)
    if NUM_WRITER < N:
        # Pick the NUM_WRITER-th largest value. tl.where (not a multiply by a
        # 0/1 selector) is required: padded lanes hold -inf and -inf * 0 is
        # NaN, which would poison the threshold whenever N < BN.
        threshold_w = tl.sum(tl.where(offsets_n == NUM_WRITER - 1, b_ps, 0.0))
        mask_w_gr = b_p > threshold_w
        # Entries tied with the threshold are admitted in index order until
        # exactly NUM_WRITER entries are selected.
        need = NUM_WRITER - tl.sum(mask_w_gr.to(tl.int32))
        mask_w_eq = b_p == threshold_w
        mask_w_eq_need = mask_w_eq & (tl.cumsum(mask_w_eq.to(tl.int32), axis=0) <= need)
        mask_w = mask_w_gr | mask_w_eq_need
        mask_w = mask_w.to(b_p.dtype)
    tl.store(p_mw, mask_w.to(p_mw.dtype.element_ty), mask=mask_n)

    mask_r = tl.full((BN,), 1, dtype=b_p.dtype)
    if NUM_READER < N:
        # Same selection as above, for the NUM_READER largest entries.
        threshold_r = tl.sum(tl.where(offsets_n == NUM_READER - 1, b_ps, 0.0))
        mask_r_gr = b_p > threshold_r
        need = NUM_READER - tl.sum(mask_r_gr.to(tl.int32))
        mask_r_eq = b_p == threshold_r
        mask_r_eq_need = mask_r_eq & (tl.cumsum(mask_r_eq.to(tl.int32), axis=0) <= need)
        mask_r = mask_r_gr | mask_r_eq_need
        mask_r = mask_r.to(b_p.dtype)
    tl.store(p_mr, mask_r.to(p_mr.dtype.element_ty), mask=mask_n)


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3]
    ],
    key=['BN', 'BK', 'BV'],
)
@triton.jit
def _fused_mask_fwd_kernel(
    q, k, v, g, e, mw, mr,
    q_o, k_o, v_o, g_o,
    stride_k_b, stride_k_l, stride_k_h,
    stride_v_b, stride_v_l, stride_v_h,
    stride_e_b, stride_e_l,
    B, T, N, H, K, V,
    BN: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,
):
    """Replicate q/k/g/v across N partitions and apply the router masks.

    One program handles one (program_id(0), program_id(1), program_id(2))
    triple, used as batch / time / head indices. For every partition n:
        q_o[n] = q * e[n] * mr[n]
        k_o[n] = k * e[n] * mw[n]
        g_o[n] = g * mw[n]
        v_o[n] = v * mw[n]
    q, k and g are all addressed with the ``stride_k_*`` strides; v uses the
    ``stride_v_*`` strides; e, mw and mr share ``stride_e_*``.
    """
    i_b, i_t, i_h = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    offsets_n = tl.arange(0, BN)
    offsets_k = tl.arange(0, BK)
    offsets_v = tl.arange(0, BV)
    # BN / BK / BV are padded block widths; lanes beyond N / K / V are masked.
    mask_n = offsets_n < N
    mask_k = offsets_k < K
    mask_v = offsets_v < V

    p_e = e + i_b * stride_e_b + i_t * stride_e_l + offsets_n
    p_mw = mw + i_b * stride_e_b + i_t * stride_e_l + offsets_n
    p_mr = mr + i_b * stride_e_b + i_t * stride_e_l + offsets_n

    p_q = q + i_b * stride_k_b + i_t * stride_k_l + i_h * stride_k_h + offsets_k
    p_k = k + i_b * stride_k_b + i_t * stride_k_l + i_h * stride_k_h + offsets_k
    p_g = g + i_b * stride_k_b + i_t * stride_k_l + i_h * stride_k_h + offsets_k
    p_v = v + i_b * stride_v_b + i_t * stride_v_l + i_h * stride_v_h + offsets_v
    # Output addressing: the input batch/time strides are scaled by N and the
    # partition axis advances by H * K (H * V for v), i.e. partitions sit
    # between the time and head axes.
    # NOTE(review): this presumably relies on the inputs and outputs being
    # contiguous -- confirm against the caller.
    p_q_o = q_o + i_b * stride_k_b * N + i_t * stride_k_l * N + i_h * stride_k_h \
        + offsets_n[:, None] * H * K + offsets_k[None, :]
    p_k_o = k_o + i_b * stride_k_b * N + i_t * stride_k_l * N + i_h * stride_k_h \
        + offsets_n[:, None] * H * K + offsets_k[None, :]
    p_g_o = g_o + i_b * stride_k_b * N + i_t * stride_k_l * N + i_h * stride_k_h \
        + offsets_n[:, None] * H * K + offsets_k[None, :]
    p_v_o = v_o + i_b * stride_v_b * N + i_t * stride_v_l * N + i_h * stride_v_h \
        + offsets_n[:, None] * H * V + offsets_v[None, :]

    # Router weights restricted to the selected writers / readers.
    b_e = tl.load(p_e, mask=mask_n, other=0.)
    mask_w = tl.load(p_mw, mask=mask_n, other=0.).to(b_e.dtype)
    mask_r = tl.load(p_mr, mask=mask_n, other=0.).to(b_e.dtype)
    b_e_topk_w = b_e * mask_w
    b_e_topk_r = b_e * mask_r

    ### mask qkvg ###
    # Each [K] / [V] vector is broadcast against the [N] weights -> [BN, BK|BV].
    b_q = tl.load(p_q, mask=mask_k, other=0.)
    b_q = b_q[None, :] * b_e_topk_r[:, None]

    b_k = tl.load(p_k, mask=mask_k, other=0.)
    b_k = b_k[None, :] * b_e_topk_w[:, None]

    # Gates and values are masked but not scaled by the router weight.
    b_g = tl.load(p_g, mask=mask_k, other=0.)
    b_g = b_g[None, :] * mask_w[:, None]
    
    b_v = tl.load(p_v, mask=mask_v, other=0.)
    b_v = b_v[None, :] * mask_w[:, None]

    mask_nk = mask_n[:, None] & mask_k[None, :]
    mask_nv = mask_n[:, None] & mask_v[None, :]
    tl.store(p_q_o, b_q.to(p_q_o.dtype.element_ty), mask=mask_nk)
    tl.store(p_k_o, b_k.to(p_k_o.dtype.element_ty), mask=mask_nk)
    tl.store(p_g_o, b_g.to(p_g_o.dtype.element_ty), mask=mask_nk)
    tl.store(p_v_o, b_v.to(p_v_o.dtype.element_ty), mask=mask_nv)


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3]
    ],
    key=['BN', 'BK', 'BV'],
)
@triton.jit
def _fused_mask_bwd_kernel(
    q, k, e, mw, mr,
    dq_o, dk_o, dv_o, dg_o,
    dq, dk, dv, dg, de,
    stride_k_b, stride_k_l, stride_k_h,
    stride_v_b, stride_v_l, stride_v_h,
    stride_e_b, stride_e_l,
    B, T, N, H, K, V,
    BN: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,
):
    """Backward of the repeat-and-mask step for one (batch, time, head) triple.

    Reduces the per-partition gradients ``dq_o`` / ``dk_o`` / ``dg_o`` /
    ``dv_o`` over the partition axis into ``dq`` / ``dk`` / ``dg`` / ``dv``
    and writes this head's contribution to the router-weight gradient into
    ``de``:
        dq    = sum_n dq_o[n] * e[n] * mr[n]
        dk    = sum_n dk_o[n] * e[n] * mw[n]
        dg    = sum_n dg_o[n] * mw[n]
        dv    = sum_n dv_o[n] * mw[n]
        de[n] = sum_k dq_o[n, k] * q[k] * mr[n] + dk_o[n, k] * k[k] * mw[n]
    ``de`` carries a trailing head axis of size H; summing over it is left to
    the caller.
    """
    i_b, i_t, i_h = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    offsets_n = tl.arange(0, BN)
    offsets_k = tl.arange(0, BK)
    offsets_v = tl.arange(0, BV)
    # BN / BK / BV are padded block widths; lanes beyond N / K / V are masked.
    mask_n = offsets_n < N
    mask_k = offsets_k < K
    mask_v = offsets_v < V

    p_e = e + i_b * stride_e_b + i_t * stride_e_l + offsets_n
    # de is addressed as the e offset times H plus the head index, i.e. with
    # the head as the innermost axis.
    p_de = de + (i_b * stride_e_b + i_t * stride_e_l + offsets_n) * H + i_h
    p_mw = mw + i_b * stride_e_b + i_t * stride_e_l + offsets_n
    p_mr = mr + i_b * stride_e_b + i_t * stride_e_l + offsets_n

    p_q = q + i_b * stride_k_b + i_t * stride_k_l + i_h * stride_k_h + offsets_k
    p_k = k + i_b * stride_k_b + i_t * stride_k_l + i_h * stride_k_h + offsets_k
    p_dq = dq + i_b * stride_k_b + i_t * stride_k_l + i_h * stride_k_h + offsets_k
    p_dk = dk + i_b * stride_k_b + i_t * stride_k_l + i_h * stride_k_h + offsets_k
    p_dg = dg + i_b * stride_k_b + i_t * stride_k_l + i_h * stride_k_h + offsets_k
    p_dv = dv + i_b * stride_v_b + i_t * stride_v_l + i_h * stride_v_h + offsets_v
    # Per-partition gradients use the same addressing as the forward outputs:
    # batch/time strides scaled by N, partition stride H * K (H * V for v).
    # NOTE(review): presumably relies on contiguous tensors -- confirm against
    # the caller.
    p_dq_o = dq_o + i_b * stride_k_b * N + i_t * stride_k_l * N + i_h * stride_k_h \
        + offsets_n[:, None] * H * K + offsets_k[None, :]
    p_dk_o = dk_o + i_b * stride_k_b * N + i_t * stride_k_l * N + i_h * stride_k_h \
        + offsets_n[:, None] * H * K + offsets_k[None, :]
    p_dg_o = dg_o + i_b * stride_k_b * N + i_t * stride_k_l * N + i_h * stride_k_h \
        + offsets_n[:, None] * H * K + offsets_k[None, :]
    p_dv_o = dv_o + i_b * stride_v_b * N + i_t * stride_v_l * N + i_h * stride_v_h \
        + offsets_n[:, None] * H * V + offsets_v[None, :]

    # Router weights restricted to the selected writers / readers.
    b_e = tl.load(p_e, mask=mask_n, other=0.)
    mask_w = tl.load(p_mw, mask=mask_n, other=0.).to(b_e.dtype)
    mask_r = tl.load(p_mr, mask=mask_n, other=0.).to(b_e.dtype)
    b_e_topk_w = b_e * mask_w
    b_e_topk_r = b_e * mask_r

    mask_nk = mask_n[:, None] & mask_k[None, :]
    mask_nv = mask_n[:, None] & mask_v[None, :]
    b_dq_o = tl.load(p_dq_o, mask=mask_nk, other=0.)
    b_dk_o = tl.load(p_dk_o, mask=mask_nk, other=0.)
    b_dg_o = tl.load(p_dg_o, mask=mask_nk, other=0.)
    b_dv_o = tl.load(p_dv_o, mask=mask_nv, other=0.)
    # Reduce over the partition axis; the products are cast to fp32 for the
    # summation and cast back to the gradient dtype afterwards.
    b_dq = tl.sum((b_dq_o * b_e_topk_r[:, None]).to(tl.float32), axis=0).to(b_dq_o.dtype)
    b_dk = tl.sum((b_dk_o * b_e_topk_w[:, None]).to(tl.float32), axis=0).to(b_dk_o.dtype)
    b_dg = tl.sum((b_dg_o * mask_w[:, None]).to(tl.float32), axis=0).to(b_dg_o.dtype)
    b_dv = tl.sum((b_dv_o * mask_w[:, None]).to(tl.float32), axis=0).to(b_dv_o.dtype)

    # Gradient w.r.t. the router weights: reduce over the feature axis (K).
    b_q = tl.load(p_q, mask=mask_k, other=0.)
    b_k = tl.load(p_k, mask=mask_k, other=0.)
    b_de = b_dq_o * b_q[None, :] * mask_r[:, None] + b_dk_o * b_k[None, :] * mask_w[:, None]
    b_de = tl.sum(b_de.to(tl.float32), axis=1).to(b_de.dtype)  

    tl.store(p_de, b_de.to(p_de.dtype.element_ty), mask=mask_n)
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), mask=mask_k)
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_k)
    tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), mask=mask_k)
    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_v)
        

class SoftmaxAndMask(torch.autograd.Function):
    r"""
    Fused router softmax, top-k mask generation and per-partition masking.

    The router logits are soft-maxed, the `num_writer` / `num_reader` largest
    probabilities are turned into writer / reader masks, and the inputs are
    replicated once per state partition: queries are scaled by the reader-masked
    router weights, keys by the writer-masked router weights, and values and
    gates are multiplied by the writer mask.

    Notation:
        B: batch size
        T: sequence length
        H: number of attention heads
        K: key/query head dimension
        V: value head dimension
        N: number of state partitions

    Args:
        q (torch.Tensor):
            Queries of shape `(B, T, H, K)`.
        k (torch.Tensor):
            Keys of shape `(B, T, H, K)`.
        v (torch.Tensor):
            Values of shape `(B, T, H, V)`.
        g (torch.Tensor):
            Gates of shape `(B, T, H, K)` (they are addressed with the key
            strides and the output is allocated with `K` as last dimension).
        e (torch.Tensor):
            Router weights before softmax of shape `(B, T, N)`.
        num_writer (int):
            Number of state partitions to write.
        num_reader (int):
            Number of state partitions to read.

    Returns:
        q_out (torch.Tensor):
            Repeated and masked queries of shape `(B, T, N * H, K)`.
        k_out (torch.Tensor):
            Repeated and masked keys of shape `(B, T, N * H, K)`.
        v_out (torch.Tensor):
            Repeated and masked values of shape `(B, T, N * H, V)`.
        g_out (torch.Tensor):
            Repeated and masked gates of shape `(B, T, N * H, K)`.
        e_out (torch.Tensor):
            Router weights after softmax of shape `(B, T, N)`.
        mask_w (torch.Tensor):
            Writer mask (int32) of shape `(B, T, N)`.
        mask_r (torch.Tensor):
            Reader mask (int32) of shape `(B, T, N)`.
    """

    @staticmethod
    @contiguous
    def forward(ctx, q, k, v, g, e, num_writer, num_reader):
        B, T, H, K = k.shape
        V, N = v.shape[-1], e.shape[-1]
        # Kernel block widths are the next powers of two of the true sizes.
        BN, BK, BV = (triton.next_power_of_2(dim) for dim in (N, K, V))

        # Per-partition outputs: the partition axis is folded into the heads.
        q_out = q.new_empty(B, T, N * H, K)
        k_out = k.new_empty(B, T, N * H, K)
        g_out = g.new_empty(B, T, N * H, K)
        v_out = v.new_empty(B, T, N * H, V)
        e_out = torch.empty_like(e)
        mask_w = torch.empty_like(e, dtype=torch.int32)
        mask_r = torch.empty_like(e, dtype=torch.int32)

        e_strides = (e.stride(0), e.stride(1))
        k_strides = (k.stride(0), k.stride(1), k.stride(2))
        v_strides = (v.stride(0), v.stride(1), v.stride(2))

        # Stage 1: softmax over partitions + writer/reader top-k masks.
        _fused_softmax_topk_fwd_kernel[(B, T)](
            e, e_out, mask_w, mask_r,
            *e_strides,
            B, T, N,
            NUM_WRITER=num_writer,
            NUM_READER=num_reader,
            BN=BN,
        )

        # Stage 2: replicate q/k/v/g per partition and apply masks / weights.
        _fused_mask_fwd_kernel[(B, T, H)](
            q, k, v, g, e_out, mask_w, mask_r,
            q_out, k_out, v_out, g_out,
            *k_strides,
            *v_strides,
            *e_strides,
            B, T, N, H, K, V,
            BN=BN, BK=BK, BV=BV,
        )

        ctx.save_for_backward(q, k, v, g, e_out, mask_w, mask_r)
        ctx.num_writer = num_writer
        ctx.num_reader = num_reader
        return q_out, k_out, v_out, g_out, e_out, mask_w, mask_r

    @staticmethod
    @contiguous
    def backward(ctx, dq_out, dk_out, dv_out, dg_out, de_out, dmask_w, dmask_r):
        q, k, v, g, e_out, mask_w, mask_r = ctx.saved_tensors

        B, T, H, K = k.shape
        V, N = v.shape[-1], e_out.shape[-1]
        BN, BK, BV = (triton.next_power_of_2(dim) for dim in (N, K, V))

        dq, dk, dv, dg = (torch.empty_like(t) for t in (q, k, v, g))
        # The kernel emits one router gradient per head; heads are summed below.
        de_per_head = g.new_empty(B, T, N, H)

        _fused_mask_bwd_kernel[(B, T, H)](
            q, k, e_out, mask_w, mask_r,
            dq_out, dk_out, dv_out, dg_out,
            dq, dk, dv, dg, de_per_head,
            k.stride(0), k.stride(1), k.stride(2),
            v.stride(0), v.stride(1), v.stride(2),
            e_out.stride(0), e_out.stride(1),
            B, T, N, H, K, V,
            BN=BN, BK=BK, BV=BV,
        )

        # Add the gradient that flowed directly into e_out, then backprop
        # through the softmax.
        de = de_per_head.sum(dim=-1).add_(de_out)
        de = softmax_bwd(e_out, de, dtype=de.dtype)

        # The two integer arguments receive no gradient.
        return dq.to(q), dk.to(k), dv.to(v), dg.to(g), de.to(e_out), None, None


softmax_and_mask = SoftmaxAndMask.apply


def get_abs_err(x, y):
    """Return the largest element-wise absolute difference between x and y."""
    delta = torch.flatten(x.detach() - y.detach())
    return torch.max(torch.abs(delta)).item()


def get_err_ratio(x, y):
    """Return the RMS of ``x - y`` divided by the RMS of ``x``.

    A small epsilon (1e-8) is added to the denominator so an all-zero ``x``
    does not divide by zero.
    """
    ref = x.detach()
    delta = ref - y.detach()
    err = torch.sqrt(torch.mean(torch.square(torch.flatten(delta)))).item()
    base = torch.sqrt(torch.mean(torch.square(torch.flatten(ref)))).item()
    return err / (base + 1e-8)


def assert_close(prefix, ref, tri, ratio, err_atol=1e-6):
    """Check that ``tri`` matches ``ref`` and return a summary message.

    The check passes when the maximum absolute difference is at most
    ``err_atol``, or otherwise when the relative RMS error is below ``ratio``.

    Args:
        prefix: Label placed at the start of the message.
        ref: Reference tensor.
        tri: Tensor under test.
        ratio: Exclusive upper bound on the relative RMS error.
        err_atol: Absolute difference below which the check always passes.

    Returns:
        A string of the form ``"<prefix> diff: <abs> ratio: <rel>"``.

    Raises:
        AssertionError: If both the absolute and the relative check fail.
    """
    abs_err = get_abs_err(ref, tri)
    # Computed once and reused for both the message and the check.
    err_ratio = get_err_ratio(ref, tri)
    msg = f"{prefix} diff: {abs_err:.6f} ratio: {err_ratio:.6f}"
    if abs_err <= err_atol:
        return msg
    # Raised explicitly (rather than via `assert`) so the check is not
    # stripped under `python -O`; a NaN ratio also fails here.
    if not err_ratio < ratio:
        raise AssertionError(msg)
    return msg



if __name__ == "__main__":
    # Benchmark / correctness harness; requires a CUDA device.
    device = "cuda"
    # Alternative precisions: torch.bfloat16, torch.float16
    dtype = torch.float32

    B, H, T, D = 1, 18, 8192, 128
    # Number of state partitions and how many of them are read / written.
    nsp, num_reader, num_writer = 8, 3, 2
    shape = (B, H, T, D)

    # Generated in the order q1, q2, k1, k2, v.
    q1, q2, k1, k2, v = (torch.randn(*shape, dtype=dtype) for _ in range(5))
    # Gates are small negative log-sigmoid values created directly on `device`.
    gk1, gk2 = (
        F.logsigmoid(torch.rand(shape, dtype=dtype, device=device)) / 16
        for _ in range(2)
    )
    eta = torch.randn(B, T, nsp, dtype=dtype)

    # Move everything to the device and track gradients.
    q1, q2, k1, k2, v, gk1, gk2, eta = (
        t.to(device).requires_grad_(True)
        for t in (q1, q2, k1, k2, v, gk1, gk2, eta)
    )

    # Upstream gradient for the output, laid out as (B, T, H, D).
    do = torch.rand_like(v).transpose(1, 2)


    def gla_process(q1, k1, gk1, v1, q2, k2, v2, gk2):
        """Run ``chunk_gla`` on the two branches concatenated along the head axis.

        Each ``*1`` tensor is concatenated with its ``*2`` counterpart on
        dim -2, the forward pass is timed (5 warm-up + 20 measured runs) and
        the final output is split into ``nsp + 1`` groups that are summed.

        Returns:
            The summed output tensor.
        """
        q = torch.cat((q1, q2), dim=-2)
        k = torch.cat((k1, k2), dim=-2)
        gk = torch.cat((gk1, gk2), dim=-2)
        v = torch.cat((v1, v2), dim=-2)

        def run_gla():
            # Single forward call shared by warm-up, timing and the final run.
            return chunk_gla(
                q,
                k,
                v,
                gk,
                initial_state=None,
                output_final_state=True,
                cu_seqlens=None,
                head_first=False,
            )

        n_warmup, n_repeat = 5, 20

        # Warm-up (also triggers Triton autotuning) before measuring.
        torch.cuda.empty_cache()
        for _ in range(n_warmup):
            o, recurrent_state = run_gla()
        torch.cuda.reset_peak_memory_stats()

        timings_us, peak_mem = [], []
        for _ in range(n_repeat):
            torch.cuda.synchronize()
            t0 = time.time()
            o, recurrent_state = run_gla()
            torch.cuda.synchronize()
            timings_us.append((time.time() - t0) * 1000 * 1000)
            peak_mem.append(torch.cuda.max_memory_allocated())
        avg_time = sum(timings_us) / len(timings_us)
        avg_mem = sum(peak_mem) / len(peak_mem)
        print(f"[GLA] Total fwd time: {avg_time:.0f}us")
        print(f"[GLA] Max memory used: {sizeof_fmt(avg_mem)}")
        torch.cuda.reset_peak_memory_stats()

        # Final forward whose output is returned to the caller.
        o, recurrent_state = run_gla()

        # Split the folded head axis into (nsp + 1) groups and sum them.
        o = rearrange(o, "b l (n h) d -> b l n h d", n=nsp+1)
        o = o.sum(2)
        print("o shape:", o.shape)
        return o


    def fwd_bwd_test(q1, q2, k1, k2, v, gk1, gk2, eta, do, num_writer, num_reader):
        """Benchmark and cross-check ``softmax_and_mask`` against ``torch_impl``.

        Forward: times both implementations and prints their differences.
        Backward: feeds each implementation's outputs through ``gla_process``,
        adds an auxiliary load-balancing loss, times the backward pass and
        compares the gradients of all inputs.

        Args:
            q1, q2, k1, k2, v, gk1, gk2: Tensors laid out as ``(b, h, l, d)``.
            eta: Router logits of shape ``(b, l, n)``.
            do: Upstream gradient multiplied with the output.
            num_writer, num_reader: Top-k sizes passed to ``softmax_and_mask``.

        Raises:
            AssertionError: If any ``assert_close`` comparison fails.
        """
        warmup_iters = 5
        repeat_iters = 20

        def report(tag, phase, speed, mem):
            # Print the mean time (us) and mean peak memory of a benchmark.
            print(f"[{tag}] Total {phase} time: {sum(speed) / len(speed):.0f}us")
            print(f"[{tag}] Max memory used: {sizeof_fmt(sum(mem) / len(mem))}")

        print("\n=== fwd test ===\n")

        # (b, h, l, d) -> (b, l, h, d); both branches share the same values.
        q1, k1, gk1, v1 = [rearrange(src, 'b h l d -> b l h d') for src in [q1, k1, gk1, v]]
        q2, k2, gk2, v2 = [rearrange(src, 'b h l d -> b l h d') for src in [q2, k2, gk2, v]]

        # Keys are normalised with a softmax computed in fp32.
        k1 = F.softmax(k1.float(), dim=-1).to(v)
        k2 = F.softmax(k2.float(), dim=-1).to(v)

        # Tensors whose gradients are compared between the two implementations.
        grad_tensors = (q1, k1, gk1, v, q2, k2, gk2, eta)
        grad_names = ('dq1', 'dk1', 'dg1', ' dv', 'dq2', 'dk2', 'dg2', ' de')

        def clear_grads():
            # Drop accumulated gradients so every backward starts from scratch.
            for t in grad_tensors:
                t.grad = None

        # warm-up
        torch.cuda.empty_cache()
        for _ in range(warmup_iters):
            q2_ref, k2_ref, v2_ref, gk2_ref, eta_ref, mask_w_ref, mask_r_ref = torch_impl(q2, k2, v2, gk2, eta)
            q2_tri, k2_tri, v2_tri, gk2_tri, eta_tri, mask_w_tri, mask_r_tri = softmax_and_mask(q2, k2, v2, gk2, eta, num_writer, num_reader)
        torch.cuda.reset_peak_memory_stats()

        # benchmark fwd (torch reference)
        speed, mem = [], []
        for _ in range(repeat_iters):
            torch.cuda.synchronize()
            start = time.time()
            q2_ref, k2_ref, v2_ref, gk2_ref, eta_ref, mask_w_ref, mask_r_ref = torch_impl(q2, k2, v2, gk2, eta)
            torch.cuda.synchronize()
            elapsed = time.time() - start
            speed.append(elapsed * 1000 * 1000)
            mem.append(torch.cuda.max_memory_allocated())
        report("Torch", "fwd", speed, mem)
        torch.cuda.reset_peak_memory_stats()

        # benchmark fwd (triton)
        speed, mem = [], []
        for _ in range(repeat_iters):
            torch.cuda.synchronize()
            start = time.time()
            q2_tri, k2_tri, v2_tri, gk2_tri, eta_tri, mask_w_tri, mask_r_tri = softmax_and_mask(q2, k2, v2, gk2, eta, num_writer, num_reader)
            torch.cuda.synchronize()
            elapsed = time.time() - start
            speed.append(elapsed * 1000 * 1000)
            mem.append(torch.cuda.max_memory_allocated())
        report("Triton", "fwd", speed, mem)

        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Fold the partition axis of the reference into the head axis so it
        # matches the (b, l, n * h, d) layout of the Triton outputs.
        q2_ref, k2_ref, v2_ref, gk2_ref = [
            rearrange(x, "b l n h d -> b l (n h) d") for x in [q2_ref, k2_ref, v2_ref, gk2_ref]
        ]

        print("=== fwd diff test ===")
        print(assert_close('e', eta_ref, eta_tri, 0.005))
        print(assert_close('q', q2_ref, q2_tri, 0.005))
        print(assert_close('k', k2_ref, k2_tri, 0.005))
        print("v diff num:", (v2_ref != v2_tri).sum())
        print("g diff num:", (gk2_ref != gk2_tri).sum())
        print("mw diff num:", (mask_w_ref.squeeze() != mask_w_tri.squeeze()).sum())
        print("mr diff num:", (mask_r_ref.squeeze() != mask_r_tri.squeeze()).sum())
        print("=== ✅ fwd pass ===")

        # Several of these are non-leaf tensors (rearranged / softmaxed), so
        # their gradients must be retained explicitly.
        for t in grad_tensors:
            t.retain_grad()

        print("\n=== bwd test ===\n")

        w1, w2 = 1.0, 0.01

        # ----- reference backward -----
        o_ref = gla_process(q1, k1, gk1, v1, q2_ref, k2_ref, v2_ref, gk2_ref)
        # Auxiliary loss: mean router probability times mean writer usage.
        p_ref = torch.mean(eta_ref, dim=(0, 1))
        f_ref = torch.mean(mask_w_ref.squeeze(-1).squeeze(-1).float(), dim=(0, 1))
        aux_loss_ref = torch.sum(p_ref * f_ref) * nsp / num_writer
        torch.cuda.empty_cache()
        # warm-up
        for _ in range(warmup_iters):
            clear_grads()
            (o_ref * do * w1 + aux_loss_ref * w2).sum().backward(retain_graph=True)
        torch.cuda.reset_peak_memory_stats()
        # benchmark bwd
        speed, mem = [], []
        for _ in range(repeat_iters):
            torch.cuda.synchronize()
            clear_grads()
            start = time.time()
            (o_ref * do * w1 + aux_loss_ref * w2).sum().backward(retain_graph=True)
            torch.cuda.synchronize()
            elapsed = time.time() - start
            speed.append(elapsed * 1000 * 1000)
            mem.append(torch.cuda.max_memory_allocated())
        report("Torch", "bwd", speed, mem)
        torch.cuda.reset_peak_memory_stats()
        ref_grads = [t.grad for t in grad_tensors]
        clear_grads()

        # ----- triton backward -----
        o_tri = gla_process(q1, k1, gk1, v1, q2_tri, k2_tri, v2_tri, gk2_tri)
        p_tri = torch.mean(eta_tri, dim=(0, 1))
        f_tri = torch.mean(mask_w_tri.float(), dim=(0, 1))
        aux_loss_tri = torch.sum(p_tri * f_tri) * nsp / num_writer
        # warm-up
        torch.cuda.empty_cache()
        for _ in range(warmup_iters):
            clear_grads()
            (o_tri * do * w1 + aux_loss_tri * w2).sum().backward(retain_graph=True)
        torch.cuda.reset_peak_memory_stats()
        # benchmark bwd
        speed, mem = [], []
        for _ in range(repeat_iters):
            torch.cuda.synchronize()
            clear_grads()
            start = time.time()
            (o_tri * do * w1 + aux_loss_tri * w2).sum().backward(retain_graph=True)
            torch.cuda.synchronize()
            elapsed = time.time() - start
            speed.append(elapsed * 1000 * 1000)
            mem.append(torch.cuda.max_memory_allocated())
        report("Triton", "bwd", speed, mem)
        torch.cuda.reset_peak_memory_stats()
        tri_grads = [t.grad for t in grad_tensors]

        print("=== bwd diff test ===")
        print(assert_close('aux', aux_loss_ref, aux_loss_tri, 0.005))
        print(assert_close('  o', o_ref, o_tri, 0.005))
        for name, ref_grad, tri_grad in zip(grad_names, ref_grads, tri_grads):
            # de gets a looser tolerance because of grad accumulation (bf16).
            tol = 0.1 if name == ' de' else 0.005
            print(assert_close(name, ref_grad, tri_grad, tol))
        print("=== ✅ bwd pass ===")


    fwd_bwd_test(q1, q2, k1, k2, v, gk1, gk2, eta, do, num_writer, num_reader)
