from typing import Optional
import torch
import triton
import triton.language as tl


@triton.heuristics(
    {
        "BLOCK_SIZE_DK": lambda args: triton.next_power_of_2(args["head_dim_qk"]),
        "BLOCK_SIZE_DV": lambda args: triton.next_power_of_2(args["head_dim_v"]),
        "USE_GAMMA": lambda args: args["gamma_ptr"] is not None,
        "USE_SLOT_IDS": lambda args: args["slot_ids"] is not None,
    }
)
@triton.jit
def _rla_decode_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    o_ptr,
    alpha_ptr,
    beta_ptr,
    gamma_ptr,
    S_ptr,
    R_ptr,
    slot_ids,
    scale,
    rclip,
    l2_qk_norm: tl.constexpr,
    max_batch_size,
    head_dim_qk,
    head_dim_v,
    stride_q_b,
    stride_q_h,
    stride_q_d,
    stride_k_b,
    stride_k_h,
    stride_k_d,
    stride_v_b,
    stride_v_h,
    stride_v_d,
    stride_o_b,
    stride_o_h,
    stride_o_d,
    stride_a_b,
    stride_a_h,
    stride_b_b,
    stride_b_h,
    stride_g_b,
    stride_g_h,
    stride_S_b,
    stride_S_h,
    stride_S_dk,
    stride_S_dv,
    stride_R_b,
    stride_R_h,
    stride_R_dk,
    stride_R_dv,
    BLOCK_SIZE_DK: tl.constexpr,
    BLOCK_SIZE_DV: tl.constexpr,
    USE_SLOT_IDS: tl.constexpr,
    USE_GAMMA: tl.constexpr,
):
    """Single-token residual linear attention step for one (batch, head) pair.

    Each program handles grid position (batch index, head index). It reads the
    q/k/v vectors and the scalar gates, loads the (head_dim_qk, head_dim_v)
    tiles of the S and R states, updates both states in place and writes the
    output vector of length head_dim_v.

    When ``slot_ids`` is given, the state row is taken from ``slot_ids[batch]``
    (negative values wrap around ``max_batch_size``); otherwise the batch index
    itself is used. When ``gamma_ptr`` is None, beta is used in place of gamma.
    """
    pid_b, pid_h = tl.program_id(0), tl.program_id(1)
    off_dk = tl.arange(0, BLOCK_SIZE_DK)
    off_dv = tl.arange(0, BLOCK_SIZE_DV)
    # Block sizes are rounded up to a power of two, so mask the padding lanes.
    mask_dk = off_dk < head_dim_qk
    mask_dv = off_dv < head_dim_v
    mask_state = mask_dk[:, None] & mask_dv[None, :]

    q_ptrs = q_ptr + pid_b * stride_q_b + pid_h * stride_q_h + off_dk * stride_q_d
    k_ptrs = k_ptr + pid_b * stride_k_b + pid_h * stride_k_h + off_dk * stride_k_d
    v_ptrs = v_ptr + pid_b * stride_v_b + pid_h * stride_v_h + off_dv * stride_v_d
    o_ptrs = o_ptr + pid_b * stride_o_b + pid_h * stride_o_h + off_dv * stride_o_d
    alpha_ptr = alpha_ptr + pid_b * stride_a_b + pid_h * stride_a_h
    beta_ptr = beta_ptr + pid_b * stride_b_b + pid_h * stride_b_h

    # Pick the state row. The index is kept in int64 because
    # `index * stride_*_b` can exceed the int32 range for large state caches.
    if USE_SLOT_IDS:
        state_idx = tl.load(slot_ids + pid_b).to(tl.int64)
        # Negative slot ids count from the end of the cache.
        state_idx = tl.where(state_idx < 0, max_batch_size + state_idx, state_idx)
    else:
        state_idx = pid_b.to(tl.int64)
    S_ptrs = (
        S_ptr
        + state_idx * stride_S_b
        + pid_h * stride_S_h
        + off_dk[:, None] * stride_S_dk
        + off_dv[None, :] * stride_S_dv
    )
    R_ptrs = (
        R_ptr
        + state_idx * stride_R_b
        + pid_h * stride_R_h
        + off_dk[:, None] * stride_R_dk
        + off_dv[None, :] * stride_R_dv
    )

    # load
    q = tl.load(q_ptrs, mask=mask_dk, other=0).to(tl.float32)
    k = tl.load(k_ptrs, mask=mask_dk, other=0).to(tl.float32)
    if l2_qk_norm:
        q_rstd = 1 / (tl.sqrt(tl.sum(q * q, axis=0)) + 1e-5)
        k_rstd = 1 / (tl.sqrt(tl.sum(k * k, axis=0)) + 1e-5)
        q = q * q_rstd
        k = k * k_rstd
    v = tl.load(v_ptrs, mask=mask_dv, other=0).to(tl.float32)
    # Gates are promoted to fp32 so that tl.exp accepts half-precision inputs
    # and the state update is carried out in fp32 like q/k/v.
    alpha = tl.exp(tl.load(alpha_ptr).to(tl.float32))
    beta = tl.load(beta_ptr).to(tl.float32)
    if USE_GAMMA:
        gamma = tl.load(gamma_ptr + pid_b * stride_g_b + pid_h * stride_g_h).to(
            tl.float32
        )
    else:
        # No separate correction strength supplied: reuse beta.
        gamma = beta
    S = tl.load(S_ptrs, mask=mask_state, other=0)
    R = tl.load(R_ptrs, mask=mask_state, other=0)

    # Products with the *previous* base state: S_{t-1} q_t and S_{t-1} k_t.
    sq = tl.sum(S * q[:, None], axis=0)
    sk = tl.sum(S * k[:, None], axis=0)
    # Clipped residual r_t = clip(v_t - S_{t-1} k_t, -rclip, rclip).
    r = tl.clamp(v - sk, min=-rclip, max=rclip)
    # S_t = exp(alpha_t) * S_{t-1} + beta_t * (k_t outer v_t), tile is (dk, dv).
    S = alpha * S + beta * k[:, None] * v[None, :]
    # R_t = exp(alpha_t) * R_{t-1} + gamma_t * (k_t outer r_t).
    R = alpha * R + gamma * k[:, None] * r[None, :]
    # Product with the *updated* residual state: R_t q_t.
    rq = tl.sum(R * q[:, None], axis=0)
    # o_t = scale * (exp(alpha_t) * S_{t-1} q_t + gamma_t * R_t q_t).
    o = (alpha * sq + gamma * rq) * scale

    # store
    tl.store(o_ptrs, o.to(o_ptr.dtype.element_ty), mask=mask_dv)
    tl.store(S_ptrs, S.to(S_ptr.dtype.element_ty), mask=mask_state)
    tl.store(R_ptrs, R.to(R_ptr.dtype.element_ty), mask=mask_state)


def rla_decode(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    alpha: torch.Tensor,
    beta: torch.Tensor,
    gamma: Optional[torch.Tensor],
    state_S: torch.Tensor,
    state_R: torch.Tensor,
    slot_ids: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    rclip: float = 1.0,
    l2_qk_norm: bool = True,
) -> torch.Tensor:
    """
    Residual Linear Attention Decoding (single token per sequence).

    S_t = exp(alpha_t) * S_{t-1} + beta_t * v_t @ k_t.T
    r_t = clip(v_t - S_{t-1} @ k_t, -rclip, rclip)
    R_t = exp(alpha_t) * R_{t-1} + gamma_t * r_t @ k_t.T
    o_t = scale * (exp(alpha_t) * S_{t-1} @ q_t + gamma_t * R_t @ q_t)

    ``state_S`` and ``state_R`` are updated in place with S_t and R_t. They are
    stored with layout (head_dim_qk, head_dim_v) per head, i.e. transposed with
    respect to the formulas above.

    Args:
        q (torch.Tensor): Query tensor of shape (batch_size, num_heads, head_dim_qk)
        k (torch.Tensor): Key tensor of shape (batch_size, num_heads, head_dim_qk)
        v (torch.Tensor): Value tensor of shape (batch_size, num_heads, head_dim_v)
        alpha (torch.Tensor): Log of gate tensor of shape (batch_size, num_heads), where each element is in [-inf, 0]
        beta (torch.Tensor): Learning rate tensor of shape (batch_size, num_heads), where each element is in [0, 1].
        gamma (torch.Tensor): Correction strength tensor of shape (batch_size, num_heads), where each element is in [0, 1]. If None, gamma equals beta.
        state_S (torch.Tensor): Cached base state tensor of shape (max_batch_size, num_heads, head_dim_qk, head_dim_v). Modified in place.
        state_R (torch.Tensor): Cached residual state tensor of shape (max_batch_size, num_heads, head_dim_qk, head_dim_v). Modified in place.
        slot_ids (torch.Tensor): Slot ids tensor of shape (batch_size,). Negative ids count from the end of the cache. Values are not range-checked here. Default to None, which means use the first batch_size elements of state_S and state_R.
        scale (float): Scale factor. Default to None, which means scale is head_dim_qk**-0.5
        rclip (float): Clip value. Default to 1.0.
        l2_qk_norm (bool): Whether to L2-normalize the query and key vectors. Default to True.

    Returns:
        torch.Tensor: Output tensor of shape (batch_size, num_heads, head_dim_v)

    Raises:
        ValueError: If state_S and state_R have different cache sizes, or if
            slot_ids is None and batch_size exceeds the cache size.
    """
    # shape check
    batch_size, num_heads, head_dim_qk, head_dim_v = *q.shape, v.shape[-1]
    assert q.shape == k.shape
    assert q.shape[:-1] == v.shape[:-1]
    assert alpha.shape == (batch_size, num_heads)
    assert alpha.shape == beta.shape
    if gamma is not None:
        assert gamma.shape == alpha.shape
    assert (
        state_S.shape[1:] == state_R.shape[1:] == (num_heads, head_dim_qk, head_dim_v)
    )

    # The checks below guard against out-of-bounds device memory access, so
    # they raise explicitly instead of using `assert` (stripped under -O).
    max_batch_size = state_S.shape[0]
    if state_R.shape[0] != max_batch_size:
        raise ValueError(
            "state_S and state_R must have the same max_batch_size, got "
            f"{max_batch_size} and {state_R.shape[0]}"
        )
    if slot_ids is not None:
        assert slot_ids.shape == (batch_size,)
        # The kernel reads slot_ids with unit stride; no-op if already contiguous.
        slot_ids = slot_ids.contiguous()
    elif batch_size > max_batch_size:
        raise ValueError(
            f"batch_size ({batch_size}) exceeds the state cache size "
            f"({max_batch_size}) and no slot_ids were given"
        )

    if scale is None:
        scale = head_dim_qk**-0.5

    o = torch.empty_like(v)

    # One program per (batch, head) pair.
    def grid(meta):
        return (batch_size, num_heads)

    # gamma strides are never read by the kernel when gamma is None.
    gamma_strides = (gamma.stride(0), gamma.stride(1)) if gamma is not None else (0, 0)

    _rla_decode_kernel[grid](
        q,
        k,
        v,
        o,
        alpha,
        beta,
        gamma,
        state_S,
        state_R,
        slot_ids,
        scale,
        rclip,
        l2_qk_norm,
        max_batch_size,
        head_dim_qk,
        head_dim_v,
        *q.stride(),
        *k.stride(),
        *v.stride(),
        *o.stride(),
        *alpha.stride(),
        *beta.stride(),
        *gamma_strides,
        *state_S.stride(),
        *state_R.stride(),
        num_warps=8,
        num_stages=1,
    )
    return o
