# -*- coding: utf-8 -*-
# This code is based on https://github.com/fla-org/flash-linear-attention
import torch
import triton
import triton.language as tl
from fla.ops.utils import contiguous
import math 
from fla.modules.l2norm import l2norm_bwd, l2norm_fwd

# ----------------------------------------------------------------------------
# Forward kernel 
# ----------------------------------------------------------------------------
# In the kernel for the Mamba backbone we had BK = D_K, since d_state = 16 was preferable.
# Here the D_K dimension is bigger, so we have to iterate over BK blocks.
# Not having to iterate over BK saved 2 * num_seq_blocks write operations.
# Sadly, we cannot afford that here.
# ---------------------------------------------------------------------------
# ----------------------------------------------------------------------------
# bma_heads_dt_fwd_kernel: chunked (BT time steps at a time) forward pass of
# the per-head recurrence below.
#
# Per time step t (read off the chunked formulas in this kernel), with
#   F_t        = exp(-sigmoid(dt_t) / N)                         (decay)
#   beta_eff_t = sigmoid(dt_t) * act(b_t),  act = exp if beta_exp else sigmoid
# the per-head (D_V x D_K) states evolve as
#   M_t[v, k] = F_t * M_{t-1}[v, k] + LRM * beta_eff_t[v] * v_t[v] * k_t[k]
#   I_t[v, k] = F_t * I_{t-1}[v, k] + LRI * beta_eff_t[v] * k_t[k]^2 + (1 - F_t) * I0
#   o_t[v]    = sum_k (M_t / I_t)[v, k] * q_t[k]
# The stored state is the mean S = M / I together with the importance I.
# It is initialised to S = 0 and I = 1 (note: not to I0).
#
# Written outputs:
#   - o;
#   - final_state / final_state_importance: the scratchpad during the loop,
#     holding the state after the last block on exit;
#   - intermediate_state / intermediate_state_importance: written every
#     `step` sequence blocks.
# ----------------------------------------------------------------------------
# Single (BT, BK, BV) tile; autotuning only explores num_stages / num_warps.
@triton.autotune(
    configs=[
        triton.Config({'BT': BT, 'BK': BK, 'BV': BV}, num_stages=st, num_warps=wp)
        for BT in [8]
        for BK in [8]
        for BV in [16]
        for st in [4,5,6]
        for wp in [2,4,8]
    ],
    key=['L', 'D_K', 'D_V']
)
@triton.jit
def bma_heads_dt_fwd_kernel(
    q,                                 # query [B, L, H, D_K]
    k,                                 # key   [B, L, H, D_K]
    v,                                 # value [B, L, H, D_V]
    b,                                 # beta before activation [B, L, H, D_V]; activated with exp or sigmoid (see beta_exp)
    dt,                                # step size before sigmoid; sigmoid(dt) scales beta (beta_eff = sigmoid(dt) * act(beta)) and sets the decay. Indexed via s_dt_b / s_dt_l / s_dt_h (presumably [B, L, H] — confirm)
    log_N,                             # [H] per-head log timescale, not input dependent; N = exp(log_N[h]) * N_init
    log_I0,                            # [H] per-head log of I0, the level the importance relaxes towards when there is no input
    log_LRM,                           # [H] per-head log of LRM, the learning rate of the mean
    log_LRI,                           # [H] per-head log of LRI, the learning rate of the importances
    N_init,                            # scalar multiplier applied to exp(log_N)
    o,                                 # output [B, L, H, D_V] (direct output, no extra NK dim)
    intermediate_state,                # checkpointed mean states [NL, B, H, D_V, D_K], kept for precise gradients in the backward pass
    intermediate_state_importance,     # checkpointed importance states [NL, B, H, D_V, D_K]
    final_state,                       # mean-state scratchpad [B, H, D_V, D_K]; holds the final mean state on exit
    final_state_importance,            # importance-state scratchpad [B, H, D_V, D_K]; holds the final importance on exit
    s_int_l,                           # stride for intermediate state along NL 
    s_int_b,                           # stride for intermediate state along batch 
    s_int_h,                           # stride for intermediate state along head
    s_int_v,                           # stride for intermediate state along d_inner dimension 
    s_int_k,                           # stride for intermediate state along d_state dimension
    s_state_b,                         # stride for state along batch 
    s_state_h,                         # stride for state along head 
    s_state_v,                         # stride for state along d_inner dimension 
    s_state_k,                         # stride for state along d_state dimension 
    s_qk_b,                            # stride for q,k along batch
    s_qk_l,                            # stride for q,k along sequence dimension
    s_qk_h,                            # stride for q,k along head
    s_qk_d,                            # stride for q,k along d dimension (usually 1)
    s_vo_b,                            # stride for v,b,o along batch
    s_vo_l,                            # stride for v,b,o along sequence dimension
    s_vo_h,                            # stride for v,b,o along head
    s_vo_d,                            # stride for v,b,o along d dimension (usually 1)
    s_dt_b,                            # stride for dt along batch
    s_dt_l,                            # stride for dt along sequence dimension
    s_dt_h,                            # stride for dt along head dimension
    L,                                 # sequence length
    H,                                 # number of heads
    beta_exp,                          # boolean: if True beta is activated with exp(beta), else with sigmoid(beta)
    CHECKPOINTS_LENGHT,                # per-head [H] tensor (loaded at + i_h): number of time steps between two saved intermediate states (converted to sequence blocks via // BT)
    BT: tl.constexpr,                  # BLOCK SIZE along sequence (chunk size)
    BK: tl.constexpr,                  # BLOCK SIZE along key dimension
    BV: tl.constexpr,                  # BLOCK SIZE along value dimension
    D_K: tl.constexpr,                 # query/key dimension
    D_V: tl.constexpr,                 # value/beta/output dimension
):
    # ---------------------------------------------------------------------
    # Each kernel instance is launched with grid (NV, B * H)
    # (presumably NV = cdiv(D_V, BV) — set by the wrapper, confirm).
    # i_v: index of the V (value) block; i_bh: flattened (batch, head) index.
    # ---------------------------------------------------------------------
    i_v,  i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    # ---------------------------------------------------------------------
    # Determine how many sequence (time) blocks there are.
    # Determine how many BK blocks there are.
    # ---------------------------------------------------------------------
    num_seq_blocks = tl.cdiv(L, BT)
    NK = tl.cdiv(D_K, BK)
    # ---------------------------------------------------------------------
    # Offset every tensor pointer to this program's (batch, head) slice.
    # ---------------------------------------------------------------------
    # I/O -----------------------------------------------------------------
    q = q + s_qk_h * i_h + s_qk_b * i_b
    k = k + s_qk_h * i_h + s_qk_b * i_b
    v = v + s_vo_h * i_h + s_vo_b * i_b
    o = o + s_vo_h * i_h + s_vo_b * i_b
    b = b + s_vo_h * i_h + s_vo_b * i_b
    dt = dt + s_dt_h * i_h + s_dt_b * i_b
    # States --------------------------------------------------------------
    final_state = final_state + i_h * s_state_h + i_b * s_state_b
    final_state_importance = final_state_importance + i_h * s_state_h + i_b * s_state_b
    intermediate_state = intermediate_state + i_h * s_int_h + i_b * s_int_b
    intermediate_state_importance = intermediate_state_importance + i_h * s_int_h + i_b * s_int_b
    # Per-head scalars ----------------------------------------------------------------------------------
    # Disabled parameterisation experiments kept by the author (not executed):
    # N_init = N_init + 3 * N_init * i_h / H #between N_init and 4 times N_init 
    # Iprior = 0.25 * (N_init / D_K)
    # Iinf = Iprior + (N_init / D_K)
    # LRI_BOOST = D_K / 8 
    # The per-head parameters are stored in log space and exponentiated in float32.
    N     = tl.exp(tl.load(log_N     + i_h).to(tl.float32)) * N_init     # N   [H] → scalar
    I0    = tl.exp(tl.load(log_I0    + i_h).to(tl.float32))         # I0  [H] → scalar
    LRM   = tl.exp(tl.load(log_LRM   + i_h).to(tl.float32))              # LRM [H] → scalar
    LRI   = tl.exp(tl.load(log_LRI   + i_h).to(tl.float32))           # LRI [H] → scalar
    # NOTE(review): cast to int16 — values above 32767 would wrap; confirm the range used by the wrapper.
    CHECKPOINT_LENGHT = tl.load(CHECKPOINTS_LENGHT   + i_h).to(tl.int16) # CHECKPOINT_LENGHT [H] → scalar
    # ----------------------------------------------------------------------------------------------------
    # Initialize the scratchpad for state.
    # For each BK block, we write the initial state (mean 0, importance 1)
    # into the final_state buffers, which are then used as scratchpad for
    # the sequential state updates.
    # ---------------------------------------------------------------------
    for nk in range(NK):
        p_state = tl.make_block_ptr(
            final_state,
            shape=(D_V, D_K), strides=(s_state_v, s_state_k),
            offsets=(i_v * BV, nk * BK),
            block_shape=(BV, BK), order=(1, 0)
        )
        p_state_imp = tl.make_block_ptr(
            final_state_importance,
            shape=(D_V, D_K), strides=(s_state_v, s_state_k),
            offsets=(i_v * BV, nk * BK),
            block_shape=(BV, BK), order=(1, 0)
        )
        
        b_s = tl.zeros([BV, BK], dtype=tl.float32)
        b_si = tl.full([BV, BK], 1 , dtype=tl.float32) # importance starts at 1 (not I0). Author's note: expected Iinf = Ip + 0.5 * N_init * LRI / D_K = 1 + 0.5 * 16 * D_K / D_K / 8
        # ---------------------------------------------------------------------
        # Write initial state into scratchpad.
        # b_s stands for block state (mean S = M / I)
        # b_si stands for block state importance (I)
        # ---------------------------------------------------------------------
        tl.store(p_state, b_s.to(p_state.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_state_imp, b_si.to(p_state_imp.dtype.element_ty), boundary_check=(0, 1))
    # ---------------------------------------------------------------------
    # Process one sequence (time) block at a time, in order.
    # ---------------------------------------------------------------------
    for seq_blk in range(num_seq_blocks):
        # -------------------------------------------------------------------------
        # Pointer for output/value/beta for the current time block, and value block.
        # -------------------------------------------------------------------------
        p_o = tl.make_block_ptr(
            o,
            shape=(L, D_V), strides=(s_vo_l, s_vo_d),
            offsets=(seq_blk * BT, i_v * BV),
            block_shape=(BT, BV), order=(1, 0)
        )
        p_v = tl.make_block_ptr(
            v,
            shape=(L, D_V), strides=(s_vo_l, s_vo_d),
            offsets=(seq_blk * BT, i_v * BV),
            block_shape=(BT, BV), order=(1, 0)
        )
        p_b = tl.make_block_ptr(
            b,
            shape=(L, D_V), strides=(s_vo_l, s_vo_d),
            offsets=(seq_blk * BT, i_v * BV),
            block_shape=(BT, BV), order=(1, 0)
        )
        # dt is viewed with length L + 1 so that the shifted load below (one
        # element ahead) stays inside the declared shape at the sequence end.
        # NOTE(review): this makes index L of dt readable — confirm the wrapper
        # allocates/pads dt so that this address is valid memory.
        p_dt = tl.make_block_ptr(
            dt,
            shape=(L + 1,),        
            strides=(s_dt_l,),          
            offsets=(seq_blk * BT,),  
            block_shape=(BT,),     
            order=(0,) #1D            
        )
        # ---------------------------------------------------------------------
        # Load value/beta/dt for the current time block and value block.
        # Sigmoid / exp are computed in float32; beta_eff is then cast back
        # to the dtype of v.
        # ---------------------------------------------------------------------
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_b = tl.load(p_b, boundary_check=(0, 1))
        b_dt = tl.load(p_dt, boundary_check=(0,))
        b_dt = tl.sigmoid(b_dt.to(tl.float32))
        if beta_exp:
            b_b_eff  = (b_dt[:,None] * tl.exp(b_b.to(tl.float32))).to(b_v.dtype)
        else:
            b_b_eff  = (b_dt[:,None] * tl.sigmoid(b_b.to(tl.float32))).to(b_v.dtype)
        # ---------------------------------------------------------------------
        # Shifted dt: b_dt_shifted[i] = sigmoid(dt[t0 + i + 1]) for i < BT - 1.
        # The last column would be the first dt of the NEXT block, so it is
        # forced to 0 via `valid`.
        # ---------------------------------------------------------------------
        col   = tl.arange(0, BT)
        valid = col < BT - 1
        p_shift = tl.advance(p_dt, (1,))
        b_dt_shifted = tl.load(p_shift, boundary_check=(0,))
        b_dt_shifted = tl.sigmoid(b_dt_shifted.to(tl.float32))
        b_dt_shifted = tl.where(valid, b_dt_shifted, 0.0)
        # Output accumulator for this block, summed over all BK blocks.
        out_accum = tl.zeros([BT, BV], dtype=tl.float32)
        # ---------------------------------------------------------------------
        # Decay factors within the chunk (F_j = exp(-sigmoid(dt_j) / N)):
        #   b_a_shifted[t]  = prod_{j = t+1 .. BT-1} F_j  (decay from t to chunk end)
        #   b_a_reversed[t] = prod_{j = 0 .. t} F_j       (decay from chunk start to t)
        #   prod_F          = prod_{j = 0 .. BT-1} F_j    (decay over the whole chunk)
        # Sigmoid has to be done in float32, and so does the exponential.
        # ---------------------------------------------------------------------
        b_a_shifted = tl.cumsum(b_dt_shifted, axis=0, reverse=True)
        b_a_shifted = tl.exp(-b_a_shifted/N)
        b_a_reversed = tl.cumsum(b_dt, axis=0)
        sum_log_F = tl.sum(b_dt/N, axis=0)
        b_a_reversed = tl.exp(-b_a_reversed/N)
        prod_F = tl.exp(-sum_log_F)
        # NOTE(review): in a partial last block (L % BT != 0) the rows past the
        # sequence end are not masked. Index L is read from memory, and the rows
        # beyond it come from boundary_check padding (no padding_option is set;
        # even zero padding gives sigmoid(0) = 0.5, not 0).
        # These rows enter prod_F and b_a_shifted, so the state written after
        # the last block includes extra decay and extra (1 - F) * I0 refill.
        # Outputs for real rows should be unaffected: the extra factors cancel
        # in the ratios below, provided the padded values are finite.
        # Confirm this matches the backward kernel's assumptions.
        # -----------------------------------------------------------------
        # Loop over each key (NK) block.
        # For each BK block: load its current state from the scratchpad,
        # process the current time block, update the state, and accumulate
        # the output.
        # -----------------------------------------------------------------
        for nk in range(NK):
            # -----------------------------------------------------------------
            # Reload state for the current BK block from scratchpad.
            # NOTE(review): this reads values this same program stored in the
            # previous sequence block, with no explicit barrier in between.
            # Confirm this is safe for the layouts Triton picks.
            # -----------------------------------------------------------------
            p_state = tl.make_block_ptr(
                final_state,
                shape=(D_V, D_K), strides=(s_state_v, s_state_k),
                offsets=(i_v * BV, nk * BK),
                block_shape=(BV, BK), order=(1, 0)
            )
            p_state_imp = tl.make_block_ptr(
                final_state_importance,
                shape=(D_V, D_K), strides=(s_state_v, s_state_k),
                offsets=(i_v * BV, nk * BK),
                block_shape=(BV, BK), order=(1, 0)
            )
            b_s = tl.load(p_state, boundary_check=(0, 1)).to(tl.float32)
            b_si = tl.load(p_state_imp, boundary_check=(0, 1)).to(tl.float32)
            # ---------------------------------------------------------------------
            # Strong forgetting degrades the re-computation of previous states,
            # so for precise gradients some intermediate states are saved.
            # CHECKPOINT_LENGHT should be at most 4N.
            # step = number of sequence blocks between two checkpoints (>= 1).
            # The checkpoint holds the state at the START of block seq_blk
            # (before this block's update), at index seq_blk // step along NL.
            # (step is loop-invariant; it is recomputed for every nk.)
            # ---------------------------------------------------------------------
            step = tl.maximum(1, CHECKPOINT_LENGHT // BT)          # ← guard
            offset_int = seq_blk // step
            if seq_blk % step == 0:
                p_int_state = tl.make_block_ptr(
                    intermediate_state + offset_int*s_int_l,
                    shape=(D_V, D_K), strides=(s_int_v, s_int_k),
                    offsets=(i_v * BV, nk * BK),
                    block_shape=(BV, BK), order=(1, 0)
                )
                p_int_state_imp = tl.make_block_ptr(
                    intermediate_state_importance + offset_int*s_int_l,
                    shape=(D_V, D_K), strides=(s_int_v, s_int_k),
                    offsets=(i_v * BV, nk * BK),
                    block_shape=(BV, BK), order=(1, 0)
                )
                tl.store(p_int_state, b_s.to(p_int_state.dtype.element_ty), boundary_check=(0, 1))
                tl.store(p_int_state_imp, b_si.to(p_int_state_imp.dtype.element_ty), boundary_check=(0, 1))
            # ---------------------------------------------------------------------
            # Pointer for key/query for the current time block, and BK block.
            # ---------------------------------------------------------------------
            p_q = tl.make_block_ptr(
                q,
                shape=(L, D_K), strides=(s_qk_l, s_qk_d),
                offsets=(seq_blk * BT, nk * BK),
                block_shape=(BT, BK), order=(1, 0)
            )
            p_k = tl.make_block_ptr(
                k,
                shape=(L, D_K), strides=(s_qk_l, s_qk_d),
                offsets=(seq_blk * BT, nk * BK),
                block_shape=(BT, BK), order=(1, 0)
            )
            # ---------------------------------------------------------------------
            # Load key/query for the current time block, and BK block.
            # ---------------------------------------------------------------------
            b_q = tl.load(p_q, boundary_check=(0, 1))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            # ---------------------------------------------------------------------
            # Do as much work as possible on the 2D tiles before the outer
            # product. b_ba already includes the decay to the chunk end
            # (b_a_shifted).
            # ---------------------------------------------------------------------
            b_kk = b_k * b_k
            b_ba = b_b_eff * b_a_shifted[:,None]
            b_vba = b_v * b_ba
            # ---------------------------------------------------------------------
            # Outer products → [BT, BV, BK]:
            #   input_M[t, v, k] = LRM * k[t, k]   * v[t, v] * beta_eff[t, v] * b_a_shifted[t]
            #   input_I[t, v, k] = LRI * k[t, k]^2 *           beta_eff[t, v] * b_a_shifted[t]
            # ---------------------------------------------------------------------
            input_M  = LRM * b_k[:, None, :]  * b_vba[:, :, None]
            input_I  = LRI * b_kk[:, None, :] * b_ba[:, :, None]
            # ---------------------------------------------------------------------
            # Per-step M/I inside the chunk, taking the forgetting into account
            # (M/I --> [BT, BV, BK]):
            #   cumsum(input)[t] / b_a_shifted[t]: inputs up to t, decayed to t
            #   b_si * b_s * b_a_reversed[t]: initial M = I * S, decayed to t
            #   (1 - b_a_reversed[t]) * I0: relaxation of I towards I0
            # ---------------------------------------------------------------------
            M  = tl.cumsum(input_M, axis=0) / b_a_shifted[:, None, None] + b_si[None] * b_s[None] * b_a_reversed[:,None,None]
            I  = tl.cumsum(input_I, axis=0) / b_a_shifted[:, None, None] + b_si[None] * b_a_reversed[:,None,None] + (1 - b_a_reversed[:,None,None]) * I0
            Mu = M / I
            # ---------------------------------------------------------------------
            # Compute output contribution: 
            # Multiply Mu by q and sum over the key (D_K) dimension of this BK block.
            # ---------------------------------------------------------------------
            contribution = tl.sum(Mu * b_q[:, None, :], axis=2)
            # Accumulate the contribution from this BK block.
            out_accum = out_accum + contribution
            # ---------------------------------------------------------------------
            # Compute the state at the end of the chunk in float32.
            # Careful: input_M / input_I are already multiplied by the decay to
            # the chunk end (b_a_shifted), so a plain sum over time suffices.
            # The stored mean is M / I; the stored importance is I.
            # ---------------------------------------------------------------------
            new_b_s = prod_F * b_si * b_s + tl.sum(input_M.to(b_si.dtype), axis=0)
            new_b_si = tl.sum(input_I.to(b_si.dtype), axis=0) + (1-prod_F)*I0 + prod_F * b_si
            new_b_s = new_b_s / new_b_si
            # ---------------------------------------------------------------------
            # Write the updated state back to the scratchpad.
            # ---------------------------------------------------------------------
            tl.store(p_state, new_b_s.to(p_state.dtype.element_ty), boundary_check=(0, 1))
            tl.store(p_state_imp, new_b_si.to(p_state_imp.dtype.element_ty), boundary_check=(0, 1))
        # ---------------------------------------------------------------------
        # End loop over NK
        # Store the accumulated output for the current sequence block.
        # ---------------------------------------------------------------------
        tl.store(p_o, out_accum.to(p_o.dtype.element_ty), boundary_check=(0, 1))
    # ---------------------------------------------------------------------
    # End loop over sequence blocks
    # The final states are already in final_state / final_state_importance
    # (last scratchpad write), so we are done.
    # ---------------------------------------------------------------------


# ----------------------------------------------------------------------------
# Backward kernel 
# ----------------------------------------------------------------------------
# BE CAREFUL: MAKE SURE THE MINIMUM BV HERE IS ABOVE BV_MIN IN THE WRAPPER.
# ----------------------------------------------------------------------------
@triton.autotune(
    configs=[
        triton.Config({'BT': BT, 'BK': BK, 'BV': BV}, num_stages=st, num_warps=wp)
        for BT in [8]
        for BK in [8]
        for BV in [16]
        for st in [4,5,6]
        for wp in [2,4,8]
    ],
    key=['L', 'D_K', 'D_V']
)
@triton.jit
def bma_heads_dt_bwd_kernel(
    q,                                    # query [B, L, H, D_K]
    k,                                    # key   [B, L, H, D_K]
    v,                                    # value [B, L, H, D_V]
    b,                                    # beta  [B, L, H, D_V] For happy synapses importances --> have non scalar beta ;)
    dt,                                   # related to alpha, and beta_eff = dt * beta  input dep [B, L]
    log_N,                                # N [H] scalar float per head, related to alpha
    log_I0,                               # I0 [H] scalar float per head
    log_LRM,                              # LRI [H] scalar per head learning rate of the mean
    log_LRI,                              # LRI [H] scalar per head learning rate of the importances
    N_init,                               # # N_init scalar initial value of N
    do,                                   # gradient of output [B, L, H, D_V]
    dq,                                   # gradient of query [NV, B, L, H, D_K]
    dk,                                   # gradient of key   [NV, B, L, H, D_K]
    dv,                                   # gradient of value [B, L, H, D_V]
    db,                                   # gradient for b [B, L, H, D_V]
    ddt,                                  # gradient for dt [NV, B, L]
    dN ,                                  # gradient for A [NV, B, H]
    dM,                                   # used to stored intermediate step [B, H, D_V, D_K] 
    dI,                                   # used to stored intermediate step [B, H, D_V, D_K] 
    dI0,                                  # gradient of I0 [NV, B, H]
    dLRM,                                 # gradient of learning rate over state mean [NV, B, H]
    dLRI,                                 # gradient of learning rate over state importance [NV, B, H]
    intermediate_state,                   # safegard [NL, B, H, D_V, D_K]
    intermediate_state_importance,        # safeguard [NL, B, H, D_V, D_K]
    final_state,                          # final state [B, H, D_V, D_K]
    final_state_importance,               # final state importance [B, H, D_V, D_K]
    s_scalar_v,                           # stride for state along NV dimension
    s_scalar_b,                           # stride for state along batch dimension
    s_scalar_h,                           # stride for state along head dimension
    s_int_l,                              # stride for state along NL dimension
    s_int_b,                              # stride for state along batch 
    s_int_h,                              # stride for state along heads
    s_int_v,                              # stride for intermediate state along d_inner dimension 
    s_int_k,                              # stride for intermediate state along d_state dimension
    s_state_b,                            # stride for state along batch 
    s_state_h,                            # stride for state along heads
    s_state_v,                            # stride for state along d_inner dimension 
    s_state_k,                            # stride for state along d_state dimension 
    s_dqk_v,                              # stride for dq, dk along along NV dimension
    s_dqk_b,                              # stride for dq, dk along along batch
    s_dqk_l,                              # stride for dq, dk along sequence dimension
    s_dqk_h,                              # stride for dq, dk along along heads 
    s_dqk_d,                              # stride for dq, dk along along d_state dimension 
    s_ddt_v,                              # stride for dt along NV dimension
    s_ddt_b,                              # stride for dt along batch
    s_ddt_l,                              # stride for dt along sequence dimension
    s_ddt_h,                              # stride for dt along heads dimension
    s_qk_b,                               # stride for q, k along batch
    s_qk_l,                               # stride for q, k along sequence dimension
    s_qk_h,                               # stride for q, k along heads
    s_qk_d,                               # stride for q, k along d dimension
    s_vo_b,                               # stride for v, b, do, dv, db along batch
    s_vo_l,                               # stride for v, b, do, dv, db along sequence dimension
    s_vo_h,                               # stride for v, b, do, dv, db along heads
    s_vo_d,                               # stride for v, b, do, dv, db along d dimensiom
    s_dt_b,                               # stride for dt along batch
    s_dt_l,                               # stride for dt along sequence dimension
    s_dt_h,                               # stride for dt along heads dimension
    L,                                    # sequence length
    H,                                    # Head size
    beta_exp,                             # beta_exp is a boolean, if True, beta is exp(beta), else sigmoid(beta)
    CHECKPOINTS_LENGHT,                   # number of iterations between the saving of two intermediate states.
    BT: tl.constexpr,                     # block size along sequence dimension
    BK: tl.constexpr,                     # block size  along the K dimension
    BV: tl.constexpr,                     # block size  along the V dimension
    D_K: tl.constexpr,                    # query/key dimension
    D_V: tl.constexpr,                    # value dimension
):
    # ---------------------------------------------------------------------
    # Each kernel instance is launched with grid (NV, batch)
    # i_v: index for the V (value) block; i_b: batch index.
    # ---------------------------------------------------------------------
    i_v,  i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    # ---------------------------------------------------------------------
    # Determine how many sequence (time) blocks there are.
    # Determine how many BK blocks there are.
    # ---------------------------------------------------------------------
    num_seq_blocks = tl.cdiv(L, BT)
    NK = tl.cdiv(D_K, BK)
    # -----------------------------------------------------------------------
    # Add the base to all tensor pointers 
    # -----------------------------------------------------------------------
    # I/O -------------------------------------------------------------------
    q  = q + s_qk_h * i_h + s_qk_b * i_b
    k  = k + s_qk_h * i_h + s_qk_b * i_b
    v  = v + s_vo_h * i_h + s_vo_b * i_b
    b  = b + s_vo_h * i_h + s_vo_b * i_b
    dt = dt + s_dt_h * i_h + s_dt_b * i_b
    # I/O gadients ----------------------------------------------------------
    dq = dq + s_dqk_h * i_h + s_dqk_b * i_b + s_dqk_v * i_v #(carefull: dq not same shape as q)
    dk = dk + s_dqk_h * i_h + s_dqk_b * i_b + s_dqk_v * i_v #(carefull: dk not same shape as k)
    dv = dv + s_vo_h * i_h + s_vo_b * i_b
    db = db + s_vo_h * i_h + s_vo_b * i_b
    do = do + s_vo_h * i_h + s_vo_b * i_b
    ddt = ddt+ s_ddt_h * i_h + s_ddt_b * i_b + s_ddt_v * i_v #(carefull: ddt not same shape as dt)
    # States ----------------------------------------------------------------
    final_state = final_state + i_h * s_state_h + i_b * s_state_b
    final_state_importance = final_state_importance + i_h * s_state_h + i_b * s_state_b
    dM = dM + i_h * s_state_h + i_b * s_state_b
    dI = dI + i_h * s_state_h + i_b * s_state_b
    intermediate_state = intermediate_state + i_h * s_int_h + i_b * s_int_b
    intermediate_state_importance = intermediate_state_importance + i_h * s_int_h + i_b * s_int_b
    # Per-head scalars -----------------------------------------------------------------------------------
    # N_init = N_init + 3 * N_init * i_h / H #between N_init and 4 times N_init 
    # Iprior = 0.25 * (N_init / D_K)
    # LRI_BOOST = D_K / 8
    N     = tl.exp(tl.load(log_N     + i_h).to(tl.float32)) * N_init     # N   [H] → scalar
    I0    = tl.exp(tl.load(log_I0    + i_h).to(tl.float32))            # I0  [H] → scalar
    LRM   = tl.exp(tl.load(log_LRM   + i_h).to(tl.float32))              # LRM [H] → scalar
    LRI   = tl.exp(tl.load(log_LRI   + i_h).to(tl.float32))            # LRI [H] → scalar
    CHECKPOINT_LENGHT = tl.load(CHECKPOINTS_LENGHT   + i_h).to(tl.int16) # CHECKPOINT_LENGHT [H] → scalar
    # ----------------------------------------------------------------------------------------------------
    # Per-head scalars -----------------------------------------------------
    dN     = dN     + i_h * s_scalar_h + i_v * s_scalar_v + i_b * s_scalar_b      
    dI0    = dI0    + i_h * s_scalar_h + i_v * s_scalar_v + i_b * s_scalar_b  
    dLRM   = dLRM   + i_h * s_scalar_h + i_v * s_scalar_v + i_b * s_scalar_b  
    dLRI   = dLRI   + i_h * s_scalar_h + i_v * s_scalar_v + i_b * s_scalar_b  
    # ---------------------------------------------------------------------
    # Prepare accumulator for the grad.
    # Accumulators stay in float32
    # ---------------------------------------------------------------------
    grad_N   = tl.zeros([1], dtype=tl.float32)
    grad_LRI = tl.zeros([1], dtype=tl.float32)
    grad_LRM = tl.zeros([1], dtype=tl.float32)
    grad_I0  = tl.zeros([1], dtype=tl.float32)
    # Process blocks in reverse order.
    for seq_blk in range(num_seq_blocks - 1, -1, -1):
        # -------------------------------------------------------------------------------
        # Pointer for grad output/value/beta for the current time block, and value block.
        # ------------------------------------------------------------------------------- 
        offset = seq_blk * BT
        p_v = tl.make_block_ptr(v, (L, D_V), (s_vo_l, s_vo_d),
                                  (offset, i_v * BV), (BT, BV), (1, 0))
        p_b = tl.make_block_ptr(b, (L, D_V), (s_vo_l, s_vo_d),
                                (offset, i_v * BV), (BT, BV), (1, 0))
        p_do = tl.make_block_ptr(do, (L, D_V), (s_vo_l, s_vo_d),
                                (offset, i_v * BV), (BT, BV), (1, 0))
        p_dt = tl.make_block_ptr(dt, (L + 1,), (s_dt_l,), 
                                 (seq_blk * BT,), (BT,), order=(0,))
        # ---------------------------------------------------------------------------
        # load grad output/value/beta/dt for the current time block, and value block.
        # Integration time influences input strenght.
        # ---------------------------------------------------------------------------
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_b = tl.load(p_b, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_dt = tl.load(p_dt, boundary_check=(0,))
        b_dt = tl.sigmoid(b_dt.to(tl.float32))
        # Options for beta activate --------------------------------------------
        if beta_exp:
            b_act = tl.exp(b_b.to(tl.float32)).to(b_v.dtype)
            b_grad_r = tl.full([BT, BV], 1, dtype=b_v.dtype)
        else:
            b_act = tl.sigmoid(b_b.to(tl.float32)).to(b_v.dtype)
            b_grad_r = (1-b_act)
        b_b_eff  = (b_dt[:,None] * b_act).to(b_v.dtype)  
        col   = tl.arange(0, BT)         
        valid = col < BT - 1 
        p_shift = tl.advance(p_dt, (1,))                                      
        b_dt_shifted = tl.load(p_shift, boundary_check=(0,))
        b_dt_shifted = tl.sigmoid(b_dt_shifted.to(tl.float32))
        b_dt_shifted = tl.where(valid, b_dt_shifted, 0.0)
        # ---------------------------------------------------------------------
        # Compute the relevant tensor for forgetting.
        # Sigmoid has to be done in float 32, same for exponential.
        # Adavance one step in time to get dt_shifted
        # The very last dt might be anything, it's not an allocated memory.
        # It must be set to 0.
        # ---------------------------------------------------------------------
        b_a_shifted = tl.cumsum(b_dt_shifted, axis=0, reverse=True)
        b_a_shifted = tl.exp(-b_a_shifted/N)
        b_a_chunk = tl.exp(-b_dt/N)
        b_a_reversed = tl.cumsum(b_dt, axis=0)
        b_a_reversed = tl.exp(-b_a_reversed/N)
        sum_log_F = tl.sum(b_dt, axis=0)
        prod_F = tl.exp(-sum_log_F/N)
        # ---------------------------------------------------------------------
        # Prepare accumulator for value and beta grads.
        # Accumulators stay in float32
        # ---------------------------------------------------------------------
        grad_dt = tl.zeros([BT], dtype=tl.float32)
        grad_v = tl.zeros([BT, BV], dtype=tl.float32)
        grad_b = tl.zeros([BT, BV], dtype=tl.float32)       
        for nk in range(NK):
            # ---------------------------------------------------------------------
            # Load state for the current BK block from scratchpad.
            # Load grad accumulators for the current BK block from scratchpad.
            # ---------------------------------------------------------------------
            p_state = tl.make_block_ptr(
                final_state,
                shape=(D_V, D_K), strides=(s_state_v, s_state_k),
                offsets=(i_v * BV, nk * BK),
                block_shape=(BV, BK), order=(1, 0)
            )
            p_state_imp = tl.make_block_ptr(
                final_state_importance,
                shape=(D_V, D_K), strides=(s_state_v, s_state_k),
                offsets=(i_v * BV, nk * BK),
                block_shape=(BV, BK), order=(1, 0)
            )
            b_s = tl.load(p_state, boundary_check=(0, 1)).to(tl.float32)
            b_si = tl.load(p_state_imp, boundary_check=(0, 1)).to(tl.float32)
            p_dM = tl.make_block_ptr(
                dM,
                shape=(D_V, D_K), strides=(s_state_v, s_state_k),
                offsets=(i_v * BV, nk * BK),
                block_shape=(BV, BK), order=(1, 0)
            )
            p_dI = tl.make_block_ptr(
                dI,
                shape=(D_V, D_K), strides=(s_state_v, s_state_k),
                offsets=(i_v * BV, nk * BK),
                block_shape=(BV, BK), order=(1, 0)
            )
            current_dM = tl.load(p_dM, boundary_check=(0, 1)).to(tl.float32)
            current_dI = tl.load(p_dI, boundary_check=(0, 1)).to(tl.float32)
            # ---------------------------------------------------------------------
            # Load query/key for the current time and BK block 
            # ---------------------------------------------------------------------
            p_q = tl.make_block_ptr(q, (L, D_K), (s_qk_l, s_qk_d),
                                  (offset, nk * BK), (BT, BK), (1, 0))
            p_k = tl.make_block_ptr(k, (L, D_K), (s_qk_l, s_qk_d),
                                    (offset, nk * BK), (BT, BK), (1, 0))
            # ---------------------------------------------------------------------
            # Load key/query for the current time and BK block.
            # ---------------------------------------------------------------------
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_q = tl.load(p_q, boundary_check=(0, 1))
            # ---------------------------------------------------------------------
            # Do maximum operation possible before the outer product.
            # Here we do not multiply by alpha, as it would imply to divide by it for A,LRM, and LRI grad.
            # ---------------------------------------------------------------------
            b_kk = b_k * b_k
            b_vb = b_v * b_b_eff
            # ---------------------------------------------------------------------
            # Do the outer product.
            # ---------------------------------------------------------------------
            input_M = LRM * b_k[:, None, :] * b_vb[:, :, None]
            input_I  = LRI * b_kk[:, None, :] * b_b_eff[:, :, None]
            # ---------------------------------------------------------------------
            # LOAD INTERMEDIATE STATE OR RE-COMPUTE
            # ---------------------------------------------------------------------
            # Strong forgetting deteriorate the re-computation of previous state.
            # To have precise gradients we need to save some intermediate states.
            # CHECKPOINT_LENGHT should be at most 4N.
            # The base depends also on time indexes for the intermediate states.
            # ---------------------------------------------------------------------
            step = tl.maximum(1, CHECKPOINT_LENGHT // BT)          # ← guard
            offset_int = seq_blk // step
            if seq_blk % step == 0:
                p_int_state = tl.make_block_ptr(
                    intermediate_state + offset_int * s_int_l,            
                    shape=(D_V, D_K), strides=(s_int_v, s_int_k),
                    offsets=(i_v * BV, nk * BK),
                    block_shape=(BV, BK), order=(1, 0)
                )
                p_int_state_imp = tl.make_block_ptr(
                    intermediate_state_importance + offset_int * s_int_l, 
                    shape=(D_V, D_K), strides=(s_int_v, s_int_k),
                    offsets=(i_v * BV, nk * BK),
                    block_shape=(BV, BK), order=(1, 0)
                )     
                b_s_pre = tl.load(p_int_state, boundary_check=(0, 1)).to(tl.float32)
                b_si_pre = tl.load(p_int_state_imp, boundary_check=(0, 1)).to(tl.float32)
            else:
                # ---------------------------------------------------------------------
                # Make sure previous states are recomputed in float32
                # ---------------------------------------------------------------------
                b_si_pre = (b_si - tl.sum((input_I* b_a_shifted[:, None, None]).to(b_si.dtype), axis=0) -  (1-prod_F) * I0) /  prod_F
                b_si_pre = tl.maximum(b_si_pre, I0)
                b_s_pre = (b_s * b_si - tl.sum((input_M* b_a_shifted[:, None, None]).to(b_si.dtype), axis=0) ) / (b_si_pre * prod_F)
            # ---------------------------------------------------------------------
            # Compute cumulative sums over the time block.
            # Taking into account the forgetting. 
            # Be carefull compared to the forward pass input_M/I have not been multiply by alphas yet.
            # M/I --> BT,BV,D_K
            # ---------------------------------------------------------------------
            M = tl.cumsum(input_M * b_a_shifted[:, None, None], axis=0) / b_a_shifted[:, None, None] + b_si_pre[None] * b_s_pre[None] * b_a_reversed[:,None,None]
            I = tl.cumsum(input_I * b_a_shifted[:, None, None], axis=0) / b_a_shifted[:, None, None] + (1-b_a_reversed[:,None,None]) * I0  + b_a_reversed[:,None,None] * b_si_pre[None]
            Mu = M / I
            # ---------------------------------------------------------------------
            # Compute the state one step erlier for grad A.
            # M_shift should be divide by I_shift...
            # However the grad would be multiply after by I_shift so might has well do nothing.
            # ---------------------------------------------------------------------
            M_shift = (M - input_M)/b_a_chunk[:,None,None]
            I_shift = (I - input_I - (1-b_a_chunk[:,None,None])*I0) / b_a_chunk[:,None,None] 
            # ---------------------------------------------------------------------
            # Compute the grad over q (could have been done in forward mode)
            # ---------------------------------------------------------------------
            grad_q = tl.sum(b_do[:, :, None] * Mu, axis=1)
            # ---------------------------------------------------------------------
            # Compute the local grad over input_M and input_I.
            # ---------------------------------------------------------------------
            dmu = b_do[:, :, None] * b_q[:, None, :]
            dM_local = dmu / I
            dI_local = -dmu * M / (I * I)
            # ---------------------------------------------------------------------
            # We don't forget to add the grad contribution coming from future state
            # They are accumulated in current_dM and current_dI
            # ---------------------------------------------------------------------
            dinput_M =  tl.cumsum(dM_local * b_a_reversed[:,None,None], axis=0, reverse=True)/b_a_reversed[:,None,None] + current_dM * b_a_shifted[:,None,None]
            dinput_I =  tl.cumsum(dI_local * b_a_reversed[:,None,None], axis=0, reverse=True)/b_a_reversed[:,None,None] + current_dI * b_a_shifted[:,None,None]
            # ---------------------------------------------------------------------
            # Compute the local grad over inputs and LRs
            # Carefull that accumulators gets tl.float32
            # ---------------------------------------------------------------------
            grad_v   += LRM * tl.sum(dinput_M * b_k[:, None, :], axis=2) * b_b_eff
            grad_b_a  = LRM * tl.sum(dinput_M * b_k[:, None, :], axis=2) * b_v  #actually grad_b_eff
            grad_b_a += LRI * tl.sum(dinput_I * b_k[:, None, :] * b_k[:, None, :], axis=2) #actually grad_b_eff
            grad_dt  += tl.sum(grad_b_a * b_act, axis = 1) * b_dt * (1-b_dt)  
            grad_b   += grad_b_a * b_b_eff * b_grad_r 
            grad_k    =  LRM * tl.sum(dinput_M * b_v[:, :, None] * b_b_eff[:, :, None], axis=1)
            grad_k   += LRI * 2 * tl.sum(dinput_I * b_b_eff[:, :, None], axis=1) * b_k
            grad_LRI += tl.sum((dinput_I * input_I).to(b_si.dtype))
            grad_LRM += tl.sum((dinput_M * input_M).to(b_si.dtype))
            # ---------------------------------------------------------------------
            # Compute the local grad for the forgetting
            # grad_a_M would be multiply by I_shift if we did divide before by I_shift
            # Carefull that accumulators gets tl.float32
            # ---------------------------------------------------------------------
            grad_a_M = M_shift * dinput_M
            grad_a_I = (I_shift - I0) * dinput_I
            grad_a_M = tl.sum(tl.sum(grad_a_M, axis=2),axis=1)
            grad_a_I = tl.sum(tl.sum(grad_a_I, axis=2),axis=1)
            grad_a = grad_a_M + grad_a_I
            grad_N += tl.sum((grad_a * b_a_chunk * b_dt / N).to(b_si.dtype), axis=0)
            grad_I0 +=  tl.sum(dI_local.to(b_si.dtype)) * I0   
            grad_dt += - grad_a * b_a_chunk * b_dt * (1-b_dt) / N   
            # ---------------------------------------------------------------------
            # All grad are computed we can update futur grad accumulator and states
            # ---------------------------------------------------------------------
            current_dM = tl.sum((dM_local*b_a_reversed[:,None,None]).to(b_si.dtype), axis=0) + current_dM * prod_F 
            current_dI = tl.sum((dI_local*b_a_reversed[:,None,None]).to(b_si.dtype), axis=0) + current_dI * prod_F
            b_s = b_s_pre
            b_si = b_si_pre
            # ---------------------------------------------------------------------
            # We save gradients that have a time and D_K dimension 
            # ---------------------------------------------------------------------
            p_dq = tl.make_block_ptr(dq, (L, D_K), (s_dqk_l, s_dqk_d),
                                   (offset, nk * BK), (BT, BK), (1, 0))
            p_dk = tl.make_block_ptr(dk, (L, D_K), (s_dqk_l, s_dqk_d),
                                    (offset, nk * BK), (BT, BK), (1, 0))
            tl.store(p_dq, grad_q.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
            tl.store(p_dk, grad_k.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
            tl.store(p_state, b_s_pre.to(p_state.dtype.element_ty), boundary_check=(0, 1))
            tl.store(p_state_imp, b_si_pre.to(p_state_imp.dtype.element_ty), boundary_check=(0, 1))
            tl.store(p_dM, current_dM.to(p_dM.dtype.element_ty), boundary_check=(0, 1))
            tl.store(p_dI, current_dI.to(p_dI.dtype.element_ty), boundary_check=(0, 1))
        # ---------------------------------------------------------------------
        # We save gradients that have a time dimension 
        # ---------------------------------------------------------------------
        p_dv = tl.make_block_ptr(dv, (L, D_V), (s_vo_l, s_vo_d),
                                    (offset, i_v * BV), (BT, BV), (1, 0))
        p_db = tl.make_block_ptr(db, (L, D_V), (s_vo_l, s_vo_d),
                                (offset, i_v * BV), (BT, BV), (1, 0))
        tl.store(p_dv, grad_v.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_db, grad_b.to(p_db.dtype.element_ty), boundary_check=(0, 1))
        # dt grads -----------------------------------------------------------
        p_ddt = tl.make_block_ptr(ddt, (L,), (s_ddt_l,), 
                                (seq_blk * BT,), (BT,), order=(0,))
        tl.store(p_ddt, grad_dt.to(p_ddt.dtype.element_ty), boundary_check=(0,))
    # ---------------------------------------------------------------------
    # The time lopp is over.
    # We save gradients that doesnt have a time dimension
    # ---------------------------------------------------------------------
    p_dN= tl.make_block_ptr(dN,
                            shape=(1,), strides=(1,),
                            offsets=(0,), block_shape=(1,), order=(0,))
    tl.store(p_dN, grad_N.to(p_dN.dtype.element_ty), boundary_check=(0,)) 
    p_dLRM = tl.make_block_ptr(dLRM,
                            shape=(1,), strides=(1,),
                            offsets=(0,), block_shape=(1,), order=(0,))
    tl.store(p_dLRM, grad_LRM.to(p_dLRM.dtype.element_ty), boundary_check=(0,)) 

    p_dLRI = tl.make_block_ptr(dLRI,
                            shape=(1,), strides=(1,),
                            offsets=(0,), block_shape=(1,), order=(0,))
    tl.store(p_dLRI, grad_LRI.to(p_dLRI.dtype.element_ty), boundary_check=(0,)) 
    p_dI0 = tl.make_block_ptr(dI0,
                            shape=(1,), strides=(1,),
                            offsets=(0,), block_shape=(1,), order=(0,))
    tl.store(p_dI0, grad_I0.to(p_dI0.dtype.element_ty), boundary_check=(0,)) 
    

class BayesianMetaplasticAttentionHeadsDtFunction(torch.autograd.Function):
    """Autograd bridge between PyTorch and the Triton BMA (heads, dt) kernels.

    Shapes, as implied by the allocations in ``forward``/``backward``:
        q, k : [B, L, H, D_K]
        v, b : [B, L, H, D_V]
        dt   : [B, L, H]
    ``log_N``, ``log_I0``, ``log_LRM`` and ``log_LRI`` are per-head parameters
    (presumably shape [H] -- TODO confirm against the caller).
    """

    @staticmethod
    @contiguous
    @torch.autocast(device_type="cuda")
    def forward(ctx, q, k, v, b, dt, log_N, log_I0, log_LRM, log_LRI, N_init, beta_exp, use_qk_l2norm_in_kernel, output_final_state):
        """Run the forward kernel.

        Returns ``o`` with shape [B, L, H, D_V]. When ``output_final_state`` is
        true, returns ``(o, final_state, final_state_importance)`` instead. Both
        states are float32 tensors of shape [B, H, D_V, D_K].

        Side effect: ``log_N`` is clamped in place to [-ln(10), ln(10)].
        """
        batch_size, L, H, D_K = q.shape
        D_V = v.shape[-1]
        # -----------------------------------------------------------------
        # Gated delta net style: optionally L2-normalise q and k.
        # The rstd tensors are kept for l2norm_bwd in the backward pass.
        # -----------------------------------------------------------------
        if use_qk_l2norm_in_kernel:
            q, q_rstd = l2norm_fwd(q)
            k, k_rstd = l2norm_fwd(k)
        else:
            q_rstd, k_rstd = None, None
        # -----------------------------------------------------------------
        # Allocate output tensor directly with the right shape.
        # -----------------------------------------------------------------
        o = q.new_empty(batch_size, L, H, D_V)
        # -----------------------------------------------------------------
        # The kernels also load dt one step ahead ("dt_shifted"), so append one
        # zero time step. new_zeros keeps dt's dtype and device. A plain
        # torch.zeros would use the global default dtype whatever dt's dtype is.
        # -----------------------------------------------------------------
        dt = torch.cat([dt, dt.new_zeros(batch_size, 1, H)], dim=1)
        # -----------------------------------------------------------------
        # Clip what goes inside an exponential (currently only log_N).
        # -----------------------------------------------------------------
        bound = math.log(10)
        with torch.no_grad():
            log_N.clamp_(-bound, bound)
            # log_I0.clamp_(-bound, bound)
            # log_LRM.clamp_(-bound, bound)
            # log_LRI.clamp_(-bound, bound)
        # -----------------------------------------------------------------
        # final_state and final_state_importance are also used by the bwd pass.
        # States stay in float32.
        # -----------------------------------------------------------------
        final_state = q.new_empty(batch_size, H, D_V, D_K, dtype=torch.float32, requires_grad=False)
        final_state_importance = q.new_empty(batch_size, H, D_V, D_K, dtype=torch.float32, requires_grad=False)
        # -----------------------------------------------------------------
        # Checkpoint length per head: 4N rounded up to a multiple of BT_MIN,
        # capped at 128. A smaller N therefore gives more intermediate states.
        # Intermediate states stay in float32.
        # -----------------------------------------------------------------
        N = N_init * torch.exp(log_N)
        BT_MIN = 8
        CHECKPOINTS_LENGHT = torch.minimum(BT_MIN * torch.ceil(4 * N / BT_MIN), 128 * torch.ones_like(N))
        # int() forces a device -> host sync to size the checkpoint buffers.
        NC = int(torch.max(triton.cdiv(L, CHECKPOINTS_LENGHT)))
        intermediate_state = q.new_empty(NC, batch_size, H, D_V, D_K, dtype=torch.float32, requires_grad=False)
        intermediate_state_importance = q.new_empty(NC, batch_size, H, D_V, D_K, dtype=torch.float32, requires_grad=False)
        # -----------------------------------------------------------------
        # Grid: (NV, batch_size * H); NV depends on the autotuned BV.
        # -----------------------------------------------------------------
        grid = lambda META: (triton.cdiv(D_V, META['BV']), batch_size * H)
        bma_heads_dt_fwd_kernel[grid](
            q, k, v, b, dt, log_N, log_I0, log_LRM, log_LRI, N_init,
            o, intermediate_state, intermediate_state_importance, final_state, final_state_importance,
            intermediate_state.stride(0), intermediate_state.stride(1), intermediate_state.stride(2), intermediate_state.stride(3), intermediate_state.stride(4),
            final_state.stride(0), final_state.stride(1), final_state.stride(2), final_state.stride(3),
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            dt.stride(0), dt.stride(1), dt.stride(2),
            L, H, beta_exp, CHECKPOINTS_LENGHT, D_K=D_K, D_V=D_V,
        )
        # -----------------------------------------------------------------
        # Save tensors and scalars for bwd.
        # -----------------------------------------------------------------
        ctx.save_for_backward(q, q_rstd, k, k_rstd, v, b, dt, log_N, log_I0, log_LRM, log_LRI, intermediate_state, intermediate_state_importance, final_state, final_state_importance)
        ctx.N_init = N_init
        ctx.beta_exp = beta_exp
        ctx.CHECKPOINTS_LENGHT = CHECKPOINTS_LENGHT
        ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
        return (o.to(q.dtype), final_state, final_state_importance) if output_final_state else o.to(q.dtype)

    @staticmethod
    @contiguous
    @torch.autocast(device_type="cuda")
    def backward(ctx, do, d_final_state=None, d_final_state_importance=None):
        """Run the backward kernel and return one gradient per forward input.

        NOTE(review): ``d_final_state`` and ``d_final_state_importance`` are not
        used, so any gradient w.r.t. the returned final states is dropped.
        """
        (q, q_rstd, k, k_rstd, v, b, dt, log_N, log_I0, log_LRM, log_LRI,
         intermediate_state, intermediate_state_importance,
         final_state, final_state_importance) = ctx.saved_tensors
        # ----------------------------------------------------------------------------
        # BV_min must be <= the smallest BV in the autotune configs, so that
        # NV_max is at least the number of BV blocks the kernel launches.
        # ----------------------------------------------------------------------------
        BV_min = 16
        batch_size, L, H, D_K = q.shape
        D_V = v.shape[-1]
        NV_max = triton.cdiv(D_V, BV_min)
        # ----------------------------------------------------------------------------
        # The backward kernel walks time in reverse. It stores the recomputed
        # earlier states back into the state buffers it receives. Give it
        # copies, so that:
        # - the saved tensors survive a second backward (retain_graph=True);
        # - the final states returned with output_final_state=True are not
        #   overwritten.
        # NOTE(review): if bma_heads_dt_bwd_kernel is autotuned over several
        # configs, benchmarking re-runs it on these mutable buffers (state, dM,
        # dI). Confirm its autotune restores/resets them.
        # ----------------------------------------------------------------------------
        state = final_state.clone()
        state_importance = final_state_importance.clone()
        # ----------------------------------------------------------------------------
        # Allocate the tensors for the grads. The leading NV_max dim holds
        # per-BV-block partials that are summed below.
        # These buffers are small and kept in float32:
        # - the per-head scalar grads;
        # - the dt grads;
        # - the dM/dI carries, which accumulate across every time block.
        # ----------------------------------------------------------------------------
        dq = torch.zeros(NV_max, batch_size, L, H, D_K, dtype=q.dtype, device=q.device)
        dk = torch.zeros(NV_max, batch_size, L, H, D_K, dtype=q.dtype, device=q.device)
        dN = torch.zeros(NV_max, batch_size, H, dtype=torch.float32, device=q.device)
        dLRM = torch.zeros_like(dN)
        dLRI = torch.zeros_like(dN)
        dI0 = torch.zeros_like(dN)
        ddt = torch.zeros(NV_max, batch_size, L, H, dtype=torch.float32, device=q.device)
        dM = torch.zeros(batch_size, H, D_V, D_K, dtype=torch.float32, device=q.device)
        dI = torch.zeros_like(dM)
        dv = q.new_empty(batch_size, L, H, D_V, dtype=q.dtype)
        db = q.new_empty(batch_size, L, H, D_V, dtype=q.dtype)
        # -----------------------------------------------------------------
        # Grid: (NV, batch_size * H); NV depends on the autotuned BV.
        # -----------------------------------------------------------------
        grid = lambda META: (triton.cdiv(D_V, META['BV']), batch_size * H)
        bma_heads_dt_bwd_kernel[grid](
            q, k, v, b, dt, log_N, log_I0, log_LRM, log_LRI, ctx.N_init,
            do, dq, dk, dv, db, ddt, dN, dM, dI, dI0, dLRM, dLRI,
            intermediate_state, intermediate_state_importance, state, state_importance,
            dN.stride(0), dN.stride(1), dN.stride(2),
            intermediate_state.stride(0), intermediate_state.stride(1), intermediate_state.stride(2), intermediate_state.stride(3), intermediate_state.stride(4),
            state.stride(0), state.stride(1), state.stride(2), state.stride(3),
            dq.stride(0), dq.stride(1), dq.stride(2), dq.stride(3), dq.stride(4),
            ddt.stride(0), ddt.stride(1), ddt.stride(2), ddt.stride(3),
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            dt.stride(0), dt.stride(1), dt.stride(2),
            L, H, ctx.beta_exp, ctx.CHECKPOINTS_LENGHT,
            D_K=D_K, D_V=D_V,
        )
        # -----------------------------------------------------------------
        # Sum the per-BV-block partials over the NV dimension.
        # -----------------------------------------------------------------
        dk = dk.sum(0)
        dq = dq.sum(0)
        ddt = ddt.sum(0)
        dN = dN.sum(0)
        dLRM = dLRM.sum(0)
        dLRI = dLRI.sum(0)
        dI0 = dI0.sum(0)
        if ctx.use_qk_l2norm_in_kernel:
            dq = l2norm_bwd(q, q_rstd, dq)
            dk = l2norm_bwd(k, k_rstd, dk)
        # Cast each gradient to the dtype of the input it belongs to, instead of
        # routing every gradient through v's dtype.
        return (
            dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), db.to(b.dtype),
            ddt.to(dt.dtype), dN.to(log_N.dtype), dI0.to(log_I0.dtype),
            dLRM.to(log_LRM.dtype), dLRI.to(log_LRI.dtype),
            None, None, None, None,
        )
    

def bayesian_metaplastic_attention_heads_dt(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    b: torch.Tensor,
    dt: torch.Tensor,
    log_N: torch.Tensor,
    log_I0: torch.Tensor,
    log_LRM: torch.Tensor,
    log_LRI: torch.Tensor,
    N_init: float,
    beta_exp: bool,
    use_qk_l2norm_in_kernel: bool,
    output_final_state: bool = False,
):
    """Apply Bayesian metaplastic attention through its autograd Function.

    Args:
        q, k: query/key tensors, shape [B, L, H, D_K].
        v, b: value/beta tensors, shape [B, L, H, D_V].
        dt: per-step, per-head tensor, shape [B, L, H].
        log_N, log_I0, log_LRM, log_LRI: per-head parameters. Note that
            ``log_N`` is clamped in place by the forward pass.
        N_init: multiplier applied to ``exp(log_N)``.
        beta_exp: select the exp (instead of sigmoid) beta activation.
        use_qk_l2norm_in_kernel: L2-normalise q and k before the kernel.
        output_final_state: also return the final states.

    Returns:
        ``o`` with shape [B, L, H, D_V]. When ``output_final_state`` is true,
        returns ``(o, final_state, final_state_importance)``; both states are
        float32 tensors of shape [B, H, D_V, D_K].
    """
    result = BayesianMetaplasticAttentionHeadsDtFunction.apply(
        q, k, v, b, dt,
        log_N, log_I0, log_LRM, log_LRI,
        N_init, beta_exp, use_qk_l2norm_in_kernel, output_final_state,
    )
    if not output_final_state:
        return result
    o, final_state, final_state_importance = result
    return o, final_state, final_state_importance
