from typing import Optional
import torch
import triton
import triton.language as tl
from einops import einsum


def torch_sgla_decode(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    alpha: torch.Tensor,
    beta: torch.Tensor,
    state_S: torch.Tensor,
    slot_ids: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    l2_qk_norm: bool = True,
) -> torch.Tensor:
    """Pure PyTorch reference for one scalar-gated linear attention decode step.

    For every (batch, head) pair, with the state laid out as
    ``(head_dim_qk, head_dim_v)``::

        S[dk, dv] = exp(alpha) * S[dk, dv] + beta * k[dk] * v[dv]
        o[dv]     = scale * sum_dk(S[dk, dv] * q[dk])

    The updated state is written back into ``state_S`` in place.

    Args:
        q: Query of shape (batch_size, num_heads, head_dim_qk).
        k: Key of shape (batch_size, num_heads, head_dim_qk).
        v: Value of shape (batch_size, num_heads, head_dim_v).
        alpha: Log-gate of shape (batch_size, num_heads); it is exponentiated
            before being applied to the previous state.
        beta: Write strength of shape (batch_size, num_heads).
        state_S: Cached state of shape
            (max_batch_size, num_heads, head_dim_qk, head_dim_v). Mutated.
        slot_ids: Optional index tensor of shape (batch_size,) selecting the
            rows of ``state_S`` to read and update. When ``None`` the first
            ``batch_size`` rows are used.
        scale: Output scale. ``None`` means ``head_dim_qk ** -0.5``.
        l2_qk_norm: Whether to L2-normalize ``q`` and ``k`` along the last
            dimension before use.

    Returns:
        Output of shape (batch_size, num_heads, head_dim_v) in ``v``'s dtype.

    Raises:
        AssertionError: If the argument shapes are inconsistent.
    """
    # shape check
    batch_size, num_heads, head_dim_qk, head_dim_v = *q.shape, v.shape[-1]
    assert q.shape == k.shape
    assert q.shape[:-1] == v.shape[:-1]
    assert alpha.shape == (batch_size, num_heads)
    assert alpha.shape == beta.shape
    assert state_S.shape[1:] == (num_heads, head_dim_qk, head_dim_v)
    if slot_ids is not None:
        assert slot_ids.shape == (batch_size,)
    if scale is None:
        scale = head_dim_qk**-0.5

    out_dtype = v.dtype
    state_dtype = state_S.dtype
    q = q.to(state_dtype)
    k = k.to(state_dtype)
    v = v.to(state_dtype)

    # l2 norm
    if l2_qk_norm:
        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

    # gather the per-request state rows
    if slot_ids is not None:
        S = state_S[slot_ids, ...]
    else:
        S = state_S[:batch_size, ...]

    # decay the previous state and add the gated outer product k v^T;
    # the result dtype follows normal type promotion with alpha / beta
    decay = alpha[..., None, None].exp()
    write = beta[..., None, None] * torch.einsum("bhv,bhk->bhkv", v, k)
    S = decay * S + write

    # read out with the query; q is moved to the working dtype of S because
    # the gates may have promoted S above the state dtype
    o = torch.einsum("bhkv,bhk->bhv", S, q.to(S.dtype))
    o = o * scale

    # update state in place; cast back first so both write paths accept S
    # even when the gates promoted it (index_put_ needs matching dtypes)
    S = S.to(state_dtype)
    if slot_ids is not None:
        state_S.index_put_((slot_ids,), S)
    else:
        state_S[:batch_size, ...] = S

    return o.to(out_dtype)


@triton.heuristics(
    {
        "BLOCK_SIZE_DK": lambda args: triton.next_power_of_2(args["head_dim_qk"]),
        "BLOCK_SIZE_DV": lambda args: triton.next_power_of_2(args["head_dim_v"]),
        "USE_SLOT_IDS": lambda args: args["slot_ids"] is not None,
    }
)
@triton.jit
def _sgla_decode_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    o_ptr,
    alpha_ptr,
    beta_ptr,
    S_ptr,
    slot_ids,
    scale,
    l2_qk_norm: tl.constexpr,
    max_batch_size,
    head_dim_qk,
    head_dim_v,
    stride_q_b,
    stride_q_h,
    stride_q_d,
    stride_k_b,
    stride_k_h,
    stride_k_d,
    stride_v_b,
    stride_v_h,
    stride_v_d,
    stride_o_b,
    stride_o_h,
    stride_o_d,
    stride_a_b,
    stride_a_h,
    stride_b_b,
    stride_b_h,
    stride_S_b,
    stride_S_h,
    stride_S_dk,
    stride_S_dv,
    BLOCK_SIZE_DK: tl.constexpr,
    BLOCK_SIZE_DV: tl.constexpr,
    USE_SLOT_IDS: tl.constexpr,
):
    # One program per (batch, head) pair: it loads the whole
    # (head_dim_qk, head_dim_v) state tile, applies
    #   S = exp(alpha) * S + beta * k v^T,   o = scale * (S^T q),
    # then stores both o and the updated S (the state is updated in place).
    # BLOCK_SIZE_DK / BLOCK_SIZE_DV are the head dims rounded up to a power
    # of two; the padding lanes are masked out on every load and store.
    pid_b, pid_h = tl.program_id(0), tl.program_id(1)
    off_dk = tl.arange(0, BLOCK_SIZE_DK)
    off_dv = tl.arange(0, BLOCK_SIZE_DV)
    q_ptrs = q_ptr + pid_b * stride_q_b + pid_h * stride_q_h + off_dk * stride_q_d
    k_ptrs = k_ptr + pid_b * stride_k_b + pid_h * stride_k_h + off_dk * stride_k_d
    v_ptrs = v_ptr + pid_b * stride_v_b + pid_h * stride_v_h + off_dv * stride_v_d
    o_ptrs = o_ptr + pid_b * stride_o_b + pid_h * stride_o_h + off_dv * stride_o_d
    alpha_ptr = alpha_ptr + pid_b * stride_a_b + pid_h * stride_a_h
    beta_ptr = beta_ptr + pid_b * stride_b_b + pid_h * stride_b_h

    # Pick the state row. The row index is kept in int64 so that
    # row * stride_S_b cannot wrap around for large state caches.
    if USE_SLOT_IDS:
        # slot_ids is read with unit stride (element pid_b).
        state_row = tl.load(slot_ids + pid_b).to(tl.int64)
        # negative ids count from the end of the state cache
        state_row = tl.where(state_row < 0, max_batch_size + state_row, state_row)
    else:
        state_row = pid_b.to(tl.int64)
    S_ptrs = (
        S_ptr
        + state_row * stride_S_b
        + pid_h * stride_S_h
        + off_dk[:, None] * stride_S_dk
        + off_dv[None, :] * stride_S_dv
    )

    # load
    mask_dk = off_dk < head_dim_qk
    mask_dv = off_dv < head_dim_v
    mask_S = mask_dk[:, None] & mask_dv[None, :]
    q = tl.load(q_ptrs, mask=mask_dk, other=0).to(tl.float32)
    k = tl.load(k_ptrs, mask=mask_dk, other=0).to(tl.float32)
    v = tl.load(v_ptrs, mask=mask_dv, other=0).to(tl.float32)
    # Gates and state are promoted to fp32 like q/k/v: the recurrence then
    # accumulates in fp32 whatever the storage dtypes are, and low-precision
    # gates can be passed to tl.exp.
    alpha = tl.exp(tl.load(alpha_ptr).to(tl.float32))
    beta = tl.load(beta_ptr).to(tl.float32)
    S = tl.load(S_ptrs, mask=mask_S, other=0).to(tl.float32)

    # qk norm: divide by (L2 norm + 1e-5); the epsilon avoids a division by
    # zero for all-zero vectors
    if l2_qk_norm:
        q_rstd = 1 / (tl.sqrt(tl.sum(q * q, axis=0)) + 1e-5)
        k_rstd = 1 / (tl.sqrt(tl.sum(k * k, axis=0)) + 1e-5)
        q = q * q_rstd
        k = k * k_rstd

    # update S_{t} = exp(alpha_{t}) * S_{t-1} + beta_{t} * k_{t} v_{t}^{T}
    S = alpha * S + beta * k[:, None] * v[None, :]
    # compute o_{t}[dv] = sum_dk S_{t}[dk, dv] * q_{t}[dk]
    o = tl.sum(S * q[:, None], axis=0)
    # scale
    o = o * scale

    # store, casting back to the storage dtypes of the output and the state
    tl.store(o_ptrs, o.to(o_ptr.dtype.element_ty), mask=mask_dv)
    tl.store(S_ptrs, S.to(S_ptr.dtype.element_ty), mask=mask_S)


def sgla_decode(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    alpha: torch.Tensor,
    beta: torch.Tensor,
    state_S: torch.Tensor,
    slot_ids: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    l2_qk_norm: bool = True,
) -> torch.Tensor:
    """
    Scalar Gate Linear Attention Decoding (Triton kernel launcher).

    With the state stored as (head_dim_qk, head_dim_v) per (batch, head):
        S_t = exp(alpha_t) * S_{t-1} + beta_t * k_t v_t^T
        o_t = scale * S_t^T q_t

    The updated state S_t is written back into ``state_S`` in place.

    Args:
        q (torch.Tensor): Query tensor of shape (batch_size, num_heads, head_dim_qk)
        k (torch.Tensor): Key tensor of shape (batch_size, num_heads, head_dim_qk)
        v (torch.Tensor): Value tensor of shape (batch_size, num_heads, head_dim_v)
        alpha (torch.Tensor): Log of gate tensor of shape (batch_size, num_heads), where each element is in [-inf, 0]
        beta (torch.Tensor): Learning rate tensor of shape (batch_size, num_heads), where each element is in [0, 1].
        state_S (torch.Tensor): Cached state tensor of shape (max_batch_size, num_heads, head_dim_qk, head_dim_v).
            Modified in place.
        slot_ids (torch.Tensor): Slot ids tensor of shape (batch_size,). Default to None, which means use the
            first batch_size elements of state_S. Negative ids are counted from the end of state_S.
            NOTE(review): one kernel program is launched per batch element and each writes its own state row,
            so duplicate slot ids would have several programs update the same row — presumably callers pass
            unique ids; confirm.
        scale (float): Scale factor. Default to None, which means scale is head_dim_qk**-0.5
        l2_qk_norm (bool): Whether to normalize the query and key tensors. Default to True.

    Returns:
        torch.Tensor: Output tensor of shape (batch_size, num_heads, head_dim_v)

    Raises:
        AssertionError: If the argument shapes are inconsistent.
    """
    # shape check
    batch_size, num_heads, head_dim_qk, head_dim_v = *q.shape, v.shape[-1]
    assert q.shape == k.shape
    assert q.shape[:-1] == v.shape[:-1]
    assert alpha.shape == (batch_size, num_heads)
    assert alpha.shape == beta.shape
    assert state_S.shape[1:] == (num_heads, head_dim_qk, head_dim_v)
    if slot_ids is not None:
        assert slot_ids.shape == (batch_size,)
        # The kernel indexes slot_ids with an implicit unit stride (no stride
        # argument is passed), so a strided view must be materialized first.
        # This is a no-op for tensors that are already contiguous.
        slot_ids = slot_ids.contiguous()
    if scale is None:
        scale = head_dim_qk**-0.5
    max_batch_size = state_S.shape[0]
    o = torch.empty_like(v)

    # one program per (batch element, head)
    grid = (batch_size, num_heads)

    _sgla_decode_kernel[grid](
        q,
        k,
        v,
        o,
        alpha,
        beta,
        state_S,
        slot_ids,
        scale,
        l2_qk_norm,
        max_batch_size,
        head_dim_qk,
        head_dim_v,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        o.stride(0),
        o.stride(1),
        o.stride(2),
        alpha.stride(0),
        alpha.stride(1),
        beta.stride(0),
        beta.stride(1),
        state_S.stride(0),
        state_S.stride(1),
        state_S.stride(2),
        state_S.stride(3),
        num_warps=4,
        num_stages=3,
    )
    return o
