from typing import Optional, Tuple
import torch
from einops import einsum
from .gdn_prefill import gdn_prefill


def _prefill_prev_state_product(
    x: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    alpha: torch.Tensor,
    beta: torch.Tensor,
    initial_state: Optional[torch.Tensor],
    cu_seq_len: Optional[torch.LongTensor],
    output_final_state: bool,
    l2_qk_norm: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Return ``(y, final_state)`` with ``y[:, t] = S_{t-1} x_t``.

    ``S_t`` is the state ``gdn_prefill`` builds from ``(k, v, alpha, beta)``;
    ``S_{-1}`` is the initial state of the sequence token ``t`` belongs to
    (zero when ``initial_state`` is None). Assumes, like the shift trick
    below, that ``gdn_prefill``'s output at slot ``t`` is the *post-update*
    state applied to the query at slot ``t`` — TODO confirm against
    ``gdn_prefill``.
    """
    pad = torch.nn.functional.pad
    # Use x_{t+1} as the query at slot t so gdn_prefill yields S_t x_{t+1};
    # shifting the result right by one slot then gives S_{t-1} x_t.
    y, final_state = gdn_prefill(
        pad(x[:, 1:], (0, 0, 0, 0, 0, 1)),
        k,
        v,
        alpha,
        beta,
        initial_S=initial_state,
        scale=1,
        cu_seq_len=cu_seq_len,
        output_final_state=output_final_state,
        l2_qk_norm=l2_qk_norm,
    )
    y = pad(y[:, :-1], (0, 0, 0, 0, 1, 0))
    if initial_state is None and cu_seq_len is None:
        # Slot 0 is zero-padded, which already equals S_{-1} x_0 for S_{-1} = 0.
        return y, final_state

    # The shift leaves every sequence's first slot wrong: it holds zero instead
    # of S_{-1} x_start, and with packed (varlen) input it holds the previous
    # sequence's last state applied to x_start. Recompute those slots.
    if cu_seq_len is None:
        starts = torch.zeros(1, dtype=torch.long, device=x.device)
        non_empty = torch.ones(1, dtype=torch.bool, device=x.device)
    else:
        bounds = cu_seq_len.to(device=x.device, dtype=torch.long)
        starts = bounds[:-1]
        # Empty sequences own no slot; their start coincides with the next one.
        non_empty = bounds[1:] > starts
    fix_idx = starts[non_empty]
    if initial_state is None:
        return y.index_fill(1, fix_idx, 0), final_state

    # One no-op token per sequence: alpha = 0 (decay exp(0) = 1) and beta = 0
    # (no write) leave the state equal to initial_state, so gdn_prefill's
    # output for that token is S_{-1} x_start.
    first = x[:, starts.clamp(max=x.shape[1] - 1)]
    lead = first.shape[:-1]
    dummy_k = k.new_zeros(lead + (k.shape[-1],))
    # Unit key so any l2 normalisation inside gdn_prefill stays well-defined.
    dummy_k[..., 0] = 1
    dummy_v = v.new_zeros(lead + (v.shape[-1],))
    dummy_alpha = alpha.new_zeros(lead)
    dummy_beta = beta.new_zeros(lead)
    dummy_cu = (
        None
        if cu_seq_len is None
        else torch.arange(
            starts.numel() + 1, dtype=cu_seq_len.dtype, device=cu_seq_len.device
        )
    )
    s0_x, _ = gdn_prefill(
        first,
        dummy_k,
        dummy_v,
        dummy_alpha,
        dummy_beta,
        initial_S=initial_state,
        scale=1,
        cu_seq_len=dummy_cu,
        output_final_state=False,
        l2_qk_norm=l2_qk_norm,
    )
    return y.index_copy(1, fix_idx, s0_x[:, non_empty].to(y.dtype)), final_state


def naive_rdn_prefill(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    alpha: torch.Tensor,
    beta: torch.Tensor,
    gamma: Optional[torch.Tensor],
    cu_seq_len: Optional[torch.LongTensor] = None,
    initial_S: Optional[torch.Tensor] = None,
    initial_R: Optional[torch.Tensor] = None,
    output_final_state: bool = True,
    scale: Optional[float] = None,
    rclip: float = 1.0,
    l2_qk_norm: bool = True,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Reference RDN prefill built from three ``gdn_prefill`` calls.

    Computes ``o_t = scale * (exp(alpha_t) * S_{t-1} q_t + g_t * R_t q_t)`` where
    ``S`` is the state ``gdn_prefill`` builds from ``(k, v, alpha, beta)``, ``R``
    is the state it builds from the clipped residual
    ``clamp(v_t - S_{t-1} k_t, -rclip, rclip)`` with gate ``g = gamma`` (or
    ``beta`` when ``gamma`` is None), and ``S_{-1}`` / ``R_{-1}`` are
    ``initial_S`` / ``initial_R``.

    Layout: the sequence axis is dim 1 and heads are dim -2 (padding below is
    applied to dim -3 of q/k, and ``alpha[..., None]`` broadcasts against the
    outputs); presumably q/k/v are (batch, seq, heads, dim) and
    alpha/beta/gamma are (batch, seq, heads) — TODO confirm against gdn_prefill.

    Args:
        q, k, v: Query/key/value tensors, all of the same non-float32 dtype.
        alpha: Log-decay; the output path multiplies by ``exp(alpha)``.
        beta: Write strength for ``S`` (and for ``R`` when ``gamma`` is None).
        gamma: Optional write strength for ``R``; must match ``beta``'s shape.
        cu_seq_len: Optional cumulative sequence lengths for packed input.
        initial_S, initial_R: Optional initial states forwarded to gdn_prefill.
        output_final_state: Whether to request final states from gdn_prefill.
        scale: Output scale; defaults to ``k.shape[-1] ** -0.5``.
        rclip: Symmetric clip applied to the residual ``v - S_{t-1} k``.
        l2_qk_norm: Forwarded to gdn_prefill.
        **kwargs: Accepted and ignored.

    Returns:
        ``(o, final_state_S, final_state_R)`` with ``o`` cast to ``v.dtype``.
    """
    assert q.dtype == k.dtype == v.dtype
    assert q.dtype != torch.float32
    if scale is None:
        scale = k.shape[-1] ** -0.5
    if gamma is not None:
        assert gamma.shape == beta.shape
    r_gate = beta if gamma is None else gamma
    # S_{t-1} q_{t}
    sq, final_state_S = _prefill_prev_state_product(
        q, k, v, alpha, beta, initial_S, cu_seq_len, output_final_state, l2_qk_norm
    )
    # S_{t-1} k_{t}
    sk, _ = _prefill_prev_state_product(
        k, k, v, alpha, beta, initial_S, cu_seq_len, False, l2_qk_norm
    )
    # v_{t} - S_{t-1} k_{t}, clipped
    vr = torch.clamp(v - sk, min=-rclip, max=rclip)
    # fit: R_{t} -> v_{t} - S_{t-1} k_{t}; rq_{t} = R_{t} q_{t}
    rq, final_state_R = gdn_prefill(
        q,
        k,
        vr,
        alpha,
        r_gate,
        initial_S=initial_R,
        scale=1,
        cu_seq_len=cu_seq_len,
        output_final_state=output_final_state,
        l2_qk_norm=l2_qk_norm,
    )
    # o_{t} = exp(alpha_{t}) * S_{t-1} q_{t} + g_{t} * R_{t} q_{t}
    o = alpha.exp()[..., None] * sq + r_gate[..., None] * rq
    o = (o * scale).to(v.dtype)
    return o, final_state_S, final_state_R


def naive_rdn_decode(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    alpha: torch.Tensor,
    beta: torch.Tensor,
    gamma: Optional[torch.Tensor],
    state_S: torch.Tensor,
    state_R: torch.Tensor,
    slot_ids: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    l2_qk_norm: bool = True,
    rclip: float = 1.0,
) -> torch.Tensor:
    """Single-step RDN decode with plain outer-product state writes.

    Shapes (enforced by the asserts below):
        q, k: (batch, heads, dk); v: (batch, heads, dv);
        alpha, beta, gamma: (batch, heads);
        state_S, state_R: (num_slots, heads, dk, dv), i.e. stored as dk x dv,
        so "S q" is a contraction over dk.

    Per step (``g = gamma`` or ``beta`` when ``gamma`` is None):
        r   = clamp(v - S_{t-1} k, -rclip, rclip)
        S_t = exp(alpha) * S_{t-1} + beta * v k^T
        R_t = exp(alpha) * R_{t-1} + g * r k^T
        o   = scale * (exp(alpha) * S_{t-1} q + g * R_t q)

    NOTE(review): unlike ``torch_rdn_decode``, neither state gets a delta-rule
    correction here; confirm this variant is intended to differ.

    The selected rows of ``state_S``/``state_R`` (``slot_ids``, or the first
    ``batch`` rows) are updated in place. Arithmetic runs in the state dtype;
    the output is cast back to ``v.dtype``.

    Args:
        rclip: Symmetric clip for the residual (default 1.0, the previous
            hard-coded value).
    """
    # shape check
    batch_size, num_heads, head_dim_qk, head_dim_v = *q.shape, v.shape[-1]
    assert q.shape == k.shape
    assert q.shape[:-1] == v.shape[:-1]
    assert alpha.shape == (batch_size, num_heads)
    assert alpha.shape == beta.shape
    if gamma is not None:
        assert gamma.shape == alpha.shape
    assert (
        state_S.shape[1:] == state_R.shape[1:] == (num_heads, head_dim_qk, head_dim_v)
    )
    if slot_ids is not None:
        assert slot_ids.shape == (batch_size,)
    if scale is None:
        scale = head_dim_qk**-0.5
    dtype = v.dtype
    q = q.to(state_S.dtype)
    k = k.to(state_S.dtype)
    v = v.to(state_S.dtype)
    # l2 norm
    if l2_qk_norm:
        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)
    # gather the cached states for this batch
    if slot_ids is not None:
        S = state_S[slot_ids, ...]
        R = state_R[slot_ids, ...]
    else:
        S = state_S[:batch_size, ...]
        R = state_R[:batch_size, ...]
    r_gate = beta if gamma is None else gamma
    decay = alpha.exp()
    # S_{t-1} q_{t}, S_{t-1} k_{t}
    sq = torch.einsum("bhkv,bhk->bhv", S, q)
    sk = torch.einsum("bhkv,bhk->bhv", S, k)
    # clipped residual r_{t} = v_{t} - S_{t-1} k_{t}
    r = torch.clamp(v - sk, min=-rclip, max=rclip)
    # S_{t} = exp(alpha_{t}) * S_{t-1} + beta_{t} * v_{t} k_{t}^{T}
    S = decay[..., None, None] * S + beta[..., None, None] * torch.einsum(
        "bhv,bhk->bhkv", v, k
    )
    # R_{t} = exp(alpha_{t}) * R_{t-1} + g_{t} * r_{t} k_{t}^{T}
    R = decay[..., None, None] * R + r_gate[..., None, None] * torch.einsum(
        "bhv,bhk->bhkv", r, k
    )
    # R_{t} q_{t}
    rq = torch.einsum("bhkv,bhk->bhv", R, q)
    # o_{t} = exp(alpha_{t}) * S_{t-1} q_{t} + g_{t} * R_{t} q_{t}
    o = decay[..., None] * sq + r_gate[..., None] * rq
    o = o * scale
    # write the new states back in place
    if slot_ids is not None:
        state_S.index_put_((slot_ids,), S)
        state_R.index_put_((slot_ids,), R)
    else:
        state_S[:batch_size, ...] = S
        state_R[:batch_size, ...] = R
    return o.to(dtype)


def torch_rdn_decode(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    alpha: torch.Tensor,
    beta: torch.Tensor,
    gamma: Optional[torch.Tensor],
    state_S: torch.Tensor,
    state_R: torch.Tensor,
    slot_ids: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    rclip: float = 1.0,
    l2_qk_norm: bool = True,
) -> torch.Tensor:
    """Single-step RDN decode with gated delta-rule state updates.

    Shapes (enforced by the asserts below):
        q, k: (batch, heads, dk); v: (batch, heads, dv);
        alpha, beta, gamma: (batch, heads);
        state_S, state_R: (num_slots, heads, dk, dv), i.e. stored as dk x dv,
        so "S k" is a contraction over dk and "u k^T" is built as k (x) u.

    Per step (``g = gamma`` or ``beta`` when ``gamma`` is None):
        r   = clamp(v - S_{t-1} k, -rclip, rclip)
        S_t = exp(alpha) * S_{t-1} + beta * (v - exp(alpha) S_{t-1} k) k^T
        R_t = exp(alpha) * R_{t-1} + g * (r - exp(alpha) R_{t-1} k) k^T
        o   = scale * (exp(alpha) * S_{t-1} q + g * R_t q)

    The selected rows of ``state_S``/``state_R`` (``slot_ids``, or the first
    ``batch`` rows) are updated in place. Arithmetic runs in the state dtype;
    the output is cast back to ``v.dtype``.
    """
    # shape check
    batch_size, num_heads, head_dim_qk, head_dim_v = *q.shape, v.shape[-1]
    assert q.shape == k.shape
    assert q.shape[:-1] == v.shape[:-1]
    assert alpha.shape == (batch_size, num_heads)
    assert alpha.shape == beta.shape
    if gamma is not None:
        assert gamma.shape == alpha.shape
    assert (
        state_S.shape[1:] == state_R.shape[1:] == (num_heads, head_dim_qk, head_dim_v)
    )
    if slot_ids is not None:
        assert slot_ids.shape == (batch_size,)
    if scale is None:
        scale = head_dim_qk**-0.5
    dtype = v.dtype
    q = q.to(state_S.dtype)
    k = k.to(state_S.dtype)
    v = v.to(state_S.dtype)
    # l2 norm
    if l2_qk_norm:
        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)
    # gather the cached states for this batch
    if slot_ids is not None:
        S = state_S[slot_ids, ...]
        R = state_R[slot_ids, ...]
    else:
        S = state_S[:batch_size, ...]
        R = state_R[:batch_size, ...]
    r_gate = beta if gamma is None else gamma
    decay = alpha.exp()
    # S_{t-1} q_{t}, S_{t-1} k_{t}
    sq = torch.einsum("bhkv,bhk->bhv", S, q)
    sk = torch.einsum("bhkv,bhk->bhv", S, k)
    # clipped residual r_{t} = v_{t} - S_{t-1} k_{t}
    r = torch.clamp(v - sk, min=-rclip, max=rclip)
    # Gated delta rule:
    # S_{t} = exp(alpha_{t}) * S_{t-1} (I - beta_{t} k_{t} k_{t}^{T}) + beta_{t} v_{t} k_{t}^{T}
    #       = exp(alpha_{t}) * S_{t-1} + beta_{t} * (v_{t} - exp(alpha_{t}) S_{t-1} k_{t}) k_{t}^{T}
    us = v - decay[..., None] * sk
    S = decay[..., None, None] * S + beta[..., None, None] * torch.einsum(
        "bhv,bhk->bhkv", us, k
    )
    # Same rule for R with target r_{t} and gate g_{t}:
    # R_{t} = exp(alpha_{t}) * R_{t-1} + g_{t} * (r_{t} - exp(alpha_{t}) R_{t-1} k_{t}) k_{t}^{T}
    ur = r - torch.einsum("bh,bhkv,bhk->bhv", decay, R, k)
    R = decay[..., None, None] * R + r_gate[..., None, None] * torch.einsum(
        "bhv,bhk->bhkv", ur, k
    )
    # R_{t} q_{t}
    rq = torch.einsum("bhkv,bhk->bhv", R, q)
    # o_{t} = exp(alpha_{t}) * S_{t-1} q_{t} + g_{t} * R_{t} q_{t}
    o = decay[..., None] * sq + r_gate[..., None] * rq
    o = o * scale
    # write the new states back in place
    if slot_ids is not None:
        state_S.index_put_((slot_ids,), S)
        state_R.index_put_((slot_ids,), R)
    else:
        state_S[:batch_size, ...] = S
        state_R[:batch_size, ...] = R
    return o.to(dtype)
