import torch
import triton
from typing import Optional, Tuple
from .chunk_h import chunk_fwd_h, chunk_bwd_h, chunk_sq_sk_bwd_h
from .chunk_o import chunk_fwd_o, chunk_bwd_dqka, chunk_bwd_dvb
from .chunk_sq_sk import (
    chunk_fwd_sq_sk,
    chunk_sq_sk_bwd_dqka,
    chunk_sq_sk_bwd_dvb,
)
from ..common.l2norm import l2norm_bwd, l2norm_fwd
from ..common.clip import clip_fwd, clip_bwd
from ..common.merge import merge_output_fwd, merge_output_bwd
from ..common.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard


class ChunkParallelRLA(torch.autograd.Function):
    """Chunk-parallel forward/backward for Residual Linear Attention.

    The forward pass runs two chunked state passes that share the same keys
    and gates: a base pass over ``v`` (weighted by ``beta``) and a residual
    pass over the clipped error ``r = clip(v - sk, -rclip, rclip)`` (weighted
    by ``gamma``, or by ``beta`` when ``gamma`` is None). The two read-outs
    are combined by ``merge_output_fwd``.

    The backward pass does not keep the per-chunk states or ``r``; it
    recomputes them from the saved inputs before calling the gradient
    helpers.
    """

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        alpha: torch.Tensor,
        beta: Optional[torch.Tensor],
        gamma: Optional[torch.Tensor],
        cu_seq_len: Optional[torch.Tensor] = None,
        initial_S: Optional[torch.Tensor] = None,
        initial_R: Optional[torch.Tensor] = None,
        output_final_state: bool = False,
        scale: Optional[float] = None,
        rclip: float = 1.0,
        l2_qk_norm: bool = True,
        chunk_size: Optional[int] = None,
        state_in_fp32: bool = True,
    ):
        """Run the forward pass.

        Args:
            q: Queries, 4-D; unpacked as (batch, seq_len, heads, head_dim_qk).
            k: Keys; last dim must equal ``head_dim_qk``.
            v: Values; last dim is ``head_dim_v``.
            alpha: Gate tensor of shape (batch, seq_len, heads).
            beta: Optional weight for the base state update; when given it
                must have shape (batch, seq_len, heads).
            gamma: Optional weight for the residual pass and the output
                merge. When None, ``beta`` is used in its place.
            cu_seq_len: Optional cumulative sequence lengths for packed
                input. Requires ``q.shape[0] == 1``; the number of sequences
                is ``cu_seq_len.shape[0] - 1``.
            initial_S: Optional initial base state of shape
                (num_sequences, heads, head_dim_qk, head_dim_v).
            initial_R: Optional initial residual state, same shape as
                ``initial_S``. The two states may be supplied independently.
            output_final_state: Forwarded to ``chunk_fwd_h``.
            scale: Forwarded to the output merge; defaults to
                ``head_dim_qk ** -0.5``.
            rclip: The residual is clipped to ``[-rclip, rclip]``.
            l2_qk_norm: Whether to apply ``l2norm_fwd`` to ``q`` and ``k``.
            chunk_size: One of 32, 64, 128. When None it is derived from
                ``seq_len`` and lies in [16, 64].
            state_in_fp32: Forwarded to the state kernels.

        Returns:
            Tuple ``(o, last_S, last_R)``: the merged output and the second
            return value of each ``chunk_fwd_h`` call.
        """
        # shape check
        batch_size, seq_len, num_heads, head_dim_qk, head_dim_v = *q.shape, v.shape[-1]
        assert head_dim_qk == k.shape[-1]
        assert alpha.shape == (batch_size, seq_len, num_heads)
        if beta is not None:
            assert beta.shape == (batch_size, seq_len, num_heads)
        # default scale and chunk_size
        if scale is None:
            scale = head_dim_qk**-0.5
        if chunk_size is None:
            chunk_size = min(64, max(16, triton.next_power_of_2(seq_len)))
        else:
            assert chunk_size in {32, 64, 128}, "Chunk size must be 32, 64, or 128"
        # set cu_seq_len
        cu_chunk_len = None
        if cu_seq_len is not None:
            assert batch_size == 1, "Only support batch size 1 when using `cu_seq_len`."
            cu_seq_len = cu_seq_len.to(torch.int32)
            # From here on, batch_size counts packed sequences, not rows of q.
            batch_size = cu_seq_len.shape[0] - 1
            # Cumulative number of chunks per sequence (ceil(len / chunk_size)).
            cu_chunk_len = torch.zeros_like(cu_seq_len)
            cu_chunk_len[1:] = torch.ceil(
                (cu_seq_len[1:] - cu_seq_len[:-1]) / chunk_size
            ).to(torch.int32)
            cu_chunk_len = torch.cumsum(cu_chunk_len, dim=0)
        # norm
        if l2_qk_norm:
            q, q_rstd = l2norm_fwd(q)
            k, k_rstd = l2norm_fwd(k)
        else:
            q_rstd, k_rstd = None, None
        # Validate each initial state on its own: either may be omitted, and
        # each is consumed by a separate chunk_fwd_h call below.
        state_shape = (batch_size, num_heads, head_dim_qk, head_dim_v)
        if initial_S is not None:
            assert initial_S.shape == state_shape
        if initial_R is not None:
            assert initial_R.shape == state_shape
        # gamma falls back to beta for the residual pass and the merge
        gamma_eff = gamma if gamma is not None else beta
        # compute chunk state
        chunk_S, last_S = chunk_fwd_h(
            k,
            v,
            alpha,
            beta,
            initial_S,
            cu_seq_len,
            cu_chunk_len,
            chunk_size,
            output_final_state,
            state_in_fp32,
        )
        # compute output, S_{t-1}q_{t} and S_{t-1}k_{t}
        # NOTE(review): the literal 1 is presumably the helper's own scale
        # argument, with the real scale applied in the merge -- confirm.
        sq, sk = chunk_fwd_sq_sk(
            q,
            k,
            v,
            alpha,
            beta,
            chunk_S,
            1,
            cu_seq_len,
            cu_chunk_len,
            chunk_size,
        )
        # compute clipped residual error
        r, _ = clip_fwd(v, sk, min_val=-rclip, max_val=rclip, return_mask=False)
        # fit residual
        chunk_R, last_R = chunk_fwd_h(
            k,
            r,
            alpha,
            gamma_eff,
            initial_R,
            cu_seq_len,
            cu_chunk_len,
            chunk_size,
            output_final_state,
            state_in_fp32,
        )
        # compute output, H_{t}q_{t}
        rq = chunk_fwd_o(
            q,
            k,
            r,
            alpha,
            gamma_eff,
            chunk_R,
            1,
            cu_seq_len,
            cu_chunk_len,
            chunk_size,
        )
        # merge output, o_{t} = alpha_{t} * S_{t-1}q_{t} + gamma_{t} * H_{t}q_{t}
        o = merge_output_fwd(sq, rq, alpha, gamma_eff, scale)
        # save for backward (q and k are the normalized versions when
        # l2_qk_norm is set; optional tensors are saved as None)
        ctx.save_for_backward(
            q,
            q_rstd,
            k,
            k_rstd,
            v,
            sk,
            sq,
            rq,
            alpha,
            beta,
            gamma,
            initial_S,
            initial_R,
            cu_seq_len,
            cu_chunk_len,
        )
        ctx.chunk_size = chunk_size
        ctx.scale = scale
        ctx.rclip = rclip
        ctx.state_in_fp32 = state_in_fp32
        ctx.l2_qk_norm = l2_qk_norm
        return o, last_S, last_R

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do, dls, dlr):
        """Run the backward pass.

        Args:
            do: Gradient w.r.t. the merged output ``o``.
            dls: Gradient w.r.t. ``last_S`` (forwarded to the dS helper).
            dlr: Gradient w.r.t. ``last_R`` (forwarded to the dR helper).

        Returns:
            One entry per ``forward`` argument: gradients for ``q``, ``k``,
            ``v``, ``alpha``, ``beta``, ``gamma`` followed by None for the
            nine non-tensor/non-differentiated arguments. The ``beta`` and
            ``gamma`` entries are None when the corresponding input was None.
        """
        (
            q,
            q_rstd,
            k,
            k_rstd,
            v,
            sk,
            sq,
            rq,
            alpha,
            beta,
            gamma,
            initial_S,
            initial_R,
            cu_seq_len,
            cu_chunk_len,
        ) = ctx.saved_tensors
        chunk_size, scale, rclip, state_in_fp32, l2_qk_norm = (
            ctx.chunk_size,
            ctx.scale,
            ctx.rclip,
            ctx.state_in_fp32,
            ctx.l2_qk_norm,
        )
        # same fallback as in forward: gamma defaults to beta
        gamma_eff = gamma if gamma is not None else beta
        # backward for merge output
        dsq, drq, dalpha, dgamma = merge_output_bwd(
            do, sq, rq, alpha, gamma_eff, scale
        )
        del do
        if gamma is None:
            # beta played the role of gamma, so its gradient starts here
            dbeta = dgamma
            dgamma = None
        # recompute residual error
        r, clip_mask = clip_fwd(v, sk, min_val=-rclip, max_val=rclip, return_mask=True)
        # recompute chunk state for R
        chunk_R, _ = chunk_fwd_h(
            k,
            r,
            alpha,
            gamma_eff,
            initial_R,
            cu_seq_len,
            cu_chunk_len,
            chunk_size,
            False,
            state_in_fp32,
        )
        # compute dR
        d_chunk_R = chunk_bwd_h(
            q,
            drq,
            alpha,
            dlr,
            1,
            cu_seq_len,
            cu_chunk_len,
            chunk_size,
            state_in_fp32,
        )
        # compute gradient of q, k, alpha
        dq, dk, dalpha_tmp = chunk_bwd_dqka(
            q,
            k,
            r,
            alpha,
            gamma_eff,
            chunk_R,
            drq,
            d_chunk_R,
            1,
            cu_seq_len,
            cu_chunk_len,
            chunk_size,
        )
        dalpha.add_(dalpha_tmp)
        # compute gradient of v
        dr, dgamma_tmp = chunk_bwd_dvb(
            q,
            k,
            r,
            alpha,
            gamma_eff,
            drq,
            d_chunk_R,
            1,
            cu_seq_len,
            cu_chunk_len,
            chunk_size,
        )
        if gamma is not None:
            dgamma.add_(dgamma_tmp)
        else:
            dbeta.add_(dgamma_tmp)
        # bwd for r=clip(v-sk)
        dv, dsk = clip_bwd(dr, clip_mask)
        # release residual-pass intermediates before the base-pass recompute
        del r, dr, clip_mask, drq, chunk_R, d_chunk_R, dalpha_tmp, dgamma_tmp
        # recompute chunk state for S
        chunk_S, _ = chunk_fwd_h(
            k,
            v,
            alpha,
            beta,
            initial_S,
            cu_seq_len,
            cu_chunk_len,
            chunk_size,
            False,
            state_in_fp32,
        )
        # compute dS
        d_chunk_S = chunk_sq_sk_bwd_h(
            q,
            k,
            dsq,
            dsk,
            alpha,
            dls,
            1,
            cu_seq_len,
            cu_chunk_len,
            chunk_size,
            state_in_fp32,
        )
        # compute gradient of q, k, alpha
        dq_tmp, dk_tmp, dk2_tmp, dalpha_tmp = chunk_sq_sk_bwd_dqka(
            q,
            k,
            v,
            alpha,
            beta,
            chunk_S,
            dsq,
            dsk,
            d_chunk_S,
            1,
            cu_seq_len,
            cu_chunk_len,
            chunk_size,
        )
        dq.add_(dq_tmp)
        dk.add_(dk_tmp)
        dk.add_(dk2_tmp)
        dalpha.add_(dalpha_tmp)
        # compute gradient of v and beta
        dv_tmp, dbeta_tmp = chunk_sq_sk_bwd_dvb(
            q,
            k,
            v,
            alpha,
            beta,
            dsq,
            dsk,
            d_chunk_S,
            1,
            cu_seq_len,
            cu_chunk_len,
            chunk_size,
        )
        dv.add_(dv_tmp)
        if gamma is not None:
            dbeta = dbeta_tmp
        else:
            dbeta.add_(dbeta_tmp)
        # bwd for l2 norm
        if l2_qk_norm:
            dq = l2norm_bwd(q, q_rstd, dq)
            dk = l2norm_bwd(k, k_rstd, dk)
        # Optional inputs that were None must receive a None gradient;
        # calling .to() on them (or on a None gradient) would raise.
        dbeta_out = None
        if beta is not None and dbeta is not None:
            dbeta_out = dbeta.to(beta.dtype)
        dgamma_out = None
        if gamma is not None and dgamma is not None:
            dgamma_out = dgamma.to(gamma.dtype)
        return (
            dq.to(q.dtype),
            dk.to(k.dtype),
            dv.to(v.dtype),
            dalpha.to(alpha.dtype),
            dbeta_out,
            dgamma_out,
            None,  # cu_seq_len
            None,  # initial_S
            None,  # initial_R
            None,  # output_final_state
            None,  # scale
            None,  # rclip
            None,  # l2_qk_norm
            None,  # chunk_size
            None,  # state_in_fp32
        )


@torch.compiler.disable
def rla_prefill(
    q: torch.Tensor,  # shape: batch_size, total_seqlen, num_heads, head_dim
    k: torch.Tensor,  # shape: batch_size, total_seqlen, num_heads, head_dim
    v: torch.Tensor,  # shape: batch_size, total_seqlen, num_heads, head_dim_v
    alpha: torch.Tensor,  # shape: batch_size, total_seqlen, num_heads
    beta: torch.Tensor,  # shape: batch_size, total_seqlen, num_heads
    gamma: Optional[torch.Tensor] = None,  # shape: batch_size, total_seqlen, num_heads
    cu_seqlens: Optional[torch.LongTensor] = None,
    initial_S: Optional[torch.Tensor] = None,
    initial_R: Optional[torch.Tensor] = None,
    output_final_state: bool = True,
    scale: Optional[float] = None,
    rclip: float = 1.0,
    l2_qk_norm: bool = True,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Residual Linear Attention
    S_t = exp(alpha_t) * S_{t-1} + beta_t * v_t @ k_t.T
    r_t = clip(v_t - S_{t-1} @ k_t, -rclip, rclip)
    R_t = exp(alpha_t) * R_{t-1} + gamma_t * r_t @ k_t.T
    o_t = exp(alpha_t) * S_{t-1} @ q_t + gamma_t * R_t @ q_t

    This wrapper always runs `ChunkParallelRLA` with a chunk size of 64 and
    fp32 states.

    Args:
        q (torch.Tensor): Query tensor of shape (batch_size, total_seqlen, num_heads, head_dim)
        k (torch.Tensor): Key tensor of shape (batch_size, total_seqlen, num_heads, head_dim)
        v (torch.Tensor): Value tensor of shape (batch_size, total_seqlen, num_heads, head_dim_v). Its last dimension may differ from that of q and k.
        alpha (torch.Tensor): Log of gate tensor of shape (batch_size, total_seqlen, num_heads), where each element is in [-inf, 0]
        beta (torch.Tensor): Learning rate tensor of shape (batch_size, total_seqlen, num_heads), where each element is in [0, 1].
        gamma (torch.Tensor): Correction strength tensor of shape (batch_size, total_seqlen, num_heads), where each element is in [0, 1]. Default to None, which means gamma is equal to beta.
        cu_seqlens (torch.LongTensor): Cumulative sequence length tensor of shape (num_sequences + 1,) for packed variable-length input. When given, batch_size must be 1 and the initial states are indexed by sequence instead of by batch row.
        initial_S (torch.Tensor): Initial base state tensor of shape (batch_size, num_heads, head_dim, head_dim_v), with batch_size replaced by num_sequences when `cu_seqlens` is given. Default to None, which means initial state is 0.
        initial_R (torch.Tensor): Initial residual state tensor, same shape as `initial_S`. Default to None, which means initial state is 0.
        output_final_state (bool): Whether to output the final state. Default to True.
        scale (float): Scale factor. Default to None, which means scale is head_dim**-0.5
        rclip (float): Clip residual error to [-rclip, rclip]. Default to 1.0. Set to inf to disable clipping.
        l2_qk_norm (bool): Whether to normalize the query and key tensors. Default to True.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        torch.Tensor: Output tensor as produced by the output merge (presumably (batch_size, total_seqlen, num_heads, head_dim_v) -- confirm against `merge_output_fwd`)
        torch.Tensor: Final base state tensor (presumably the same shape as `initial_S`, or None when `output_final_state` is False -- confirm against `chunk_fwd_h`)
        torch.Tensor: Final residual state tensor (same caveat as the final base state)
    """
    o, final_state_S, final_state_R = ChunkParallelRLA.apply(
        q,
        k,
        v,
        alpha,
        beta,
        gamma,
        cu_seqlens,
        initial_S,
        initial_R,
        output_final_state,
        scale,
        rclip,
        l2_qk_norm,
        64,  # chunk_size
        True,  # state_in_fp32
    )
    return o, final_state_S, final_state_R
