import torch
import triton
from typing import Optional, Tuple
from .chunk_h import chunk_fwd_h, chunk_bwd_h
from .chunk_o import chunk_fwd_o, chunk_bwd_dqka, chunk_bwd_dvb
from ..common.l2norm import l2norm_bwd, l2norm_fwd
from ..common.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard


class ChunkParallelSGLA(torch.autograd.Function):
    """Chunk-parallel autograd function for scalar gated linear attention (SGLA).

    Forward optionally L2-normalizes ``q``/``k``, computes per-chunk states with
    ``chunk_fwd_h`` and the output with ``chunk_fwd_o``. Backward does not keep
    the chunk states from forward; it recomputes them with ``chunk_fwd_h`` and
    then calls the ``chunk_bwd_*`` kernels. No gradient is produced for
    ``initial_S`` (or any of the non-tensor arguments).
    """

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        alpha: torch.Tensor,
        beta: Optional[torch.Tensor],
        cu_seq_len: Optional[torch.Tensor] = None,
        initial_S: Optional[torch.Tensor] = None,
        output_final_state: bool = False,
        scale: Optional[float] = None,
        l2_qk_norm: bool = True,
        chunk_size: Optional[int] = None,
        state_in_fp32: bool = True,
    ):
        """Run the chunked SGLA forward pass.

        Args:
            q: Query tensor, unpacked as (batch_size, seq_len, num_heads, head_dim_qk).
            k: Key tensor; its last dimension must equal head_dim_qk.
            v: Value tensor; its last dimension is head_dim_v.
            alpha: Per-head log-gate of shape (batch_size, seq_len, num_heads).
            beta: Optional per-head scalar of shape (batch_size, seq_len, num_heads).
            cu_seq_len: Optional cumulative sequence offsets for packed
                variable-length input (N + 1 entries for N sequences). Requires
                batch_size == 1; N becomes the effective batch size.
            initial_S: Optional initial state of shape
                (batch_size, num_heads, head_dim_qk, head_dim_v), where
                batch_size is N when ``cu_seq_len`` is given.
            output_final_state: Forwarded to ``chunk_fwd_h``.
            scale: Output scale; defaults to head_dim_qk ** -0.5.
            l2_qk_norm: If True, L2-normalize q and k before the kernels.
            chunk_size: One of 32, 64, 128. When None it is derived from seq_len.
            state_in_fp32: Forwarded to the chunk-state kernels.

        Returns:
            Tuple ``(o, last_S)`` as produced by ``chunk_fwd_o`` / ``chunk_fwd_h``.
        """
        # shape check
        batch_size, seq_len, num_heads, head_dim_qk, head_dim_v = *q.shape, v.shape[-1]
        assert head_dim_qk == k.shape[-1]
        assert alpha.shape == (batch_size, seq_len, num_heads)
        if beta is not None:
            assert beta.shape == (batch_size, seq_len, num_heads)
        # default scale and chunk_size
        if scale is None:
            scale = head_dim_qk**-0.5
        if chunk_size is None:
            # NOTE(review): this can yield 16 for seq_len <= 16, which is outside
            # the {32, 64, 128} set accepted for explicit values — confirm the
            # kernels support 16.
            chunk_size = min(64, max(16, triton.next_power_of_2(seq_len)))
        else:
            assert chunk_size in {32, 64, 128}, "Chunk size must be 32, 64, or 128"
        # set cu_seq_len
        if cu_seq_len is not None:
            assert batch_size == 1, "Only support batch size 1 when using `cu_seq_len`."
            cu_seq_len = cu_seq_len.to(torch.int32)
            batch_size = cu_seq_len.shape[0] - 1
            seq_lens = cu_seq_len[1:] - cu_seq_len[:-1]
            cu_chunk_len = torch.zeros_like(cu_seq_len)
            # Exact integer ceil-division: float32 division + ceil can
            # undercount chunks once a length exceeds 2**24 tokens.
            cu_chunk_len[1:] = (seq_lens + chunk_size - 1) // chunk_size
            # Cumulative number of chunks per sequence (chunk-level offsets).
            cu_chunk_len = torch.cumsum(cu_chunk_len, dim=0)
        else:
            cu_seq_len = None
            cu_chunk_len = None
        # norm
        if l2_qk_norm:
            q, q_rstd = l2norm_fwd(q)
            k, k_rstd = l2norm_fwd(k)
        else:
            q_rstd, k_rstd = None, None
        # set initial state
        if initial_S is not None:
            assert initial_S.shape == (
                batch_size,
                num_heads,
                head_dim_qk,
                head_dim_v,
            )
        # compute chunk state
        chunk_S, last_S = chunk_fwd_h(
            k,
            v,
            alpha,
            beta,
            initial_S,
            cu_seq_len,
            cu_chunk_len,
            chunk_size,
            output_final_state,
            state_in_fp32,
        )
        # compute output from q, k, v and the chunk states
        o = chunk_fwd_o(
            q,
            k,
            v,
            alpha,
            beta,
            chunk_S,
            scale,
            cu_seq_len,
            cu_chunk_len,
            chunk_size,
        )
        # Save the (possibly normalized) inputs; chunk_S is recomputed in
        # backward instead of being stored.
        ctx.save_for_backward(
            q,
            q_rstd,
            k,
            k_rstd,
            v,
            alpha,
            beta,
            initial_S,
            cu_seq_len,
            cu_chunk_len,
        )
        ctx.chunk_size = chunk_size
        ctx.scale = scale
        ctx.state_in_fp32 = state_in_fp32
        ctx.l2_qk_norm = l2_qk_norm
        return o, last_S

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do, dls):
        """Compute gradients for q, k, v, alpha and beta.

        Args:
            do: Gradient w.r.t. the output ``o``.
            dls: Gradient w.r.t. the final state ``last_S``; presumably None when
                no final state was produced — passed straight to ``chunk_bwd_h``.

        Returns:
            One entry per ``forward`` input (12 total). Entries for
            ``cu_seq_len``, ``initial_S`` and the non-tensor options are None;
            the ``beta`` entry is None when ``beta`` was None.
        """
        (
            q,
            q_rstd,
            k,
            k_rstd,
            v,
            alpha,
            beta,
            initial_S,
            cu_seq_len,
            cu_chunk_len,
        ) = ctx.saved_tensors
        chunk_size, scale, state_in_fp32, l2_qk_norm = (
            ctx.chunk_size,
            ctx.scale,
            ctx.state_in_fp32,
            ctx.l2_qk_norm,
        )
        # recompute chunk state for S
        chunk_S, _ = chunk_fwd_h(
            k,
            v,
            alpha,
            beta,
            initial_S,
            cu_seq_len,
            cu_chunk_len,
            chunk_size,
            False,
            state_in_fp32,
        )
        # compute dS
        d_chunk_S = chunk_bwd_h(
            q,
            do,
            alpha,
            dls,
            scale,
            cu_seq_len,
            cu_chunk_len,
            chunk_size,
            state_in_fp32,
        )
        # compute gradient of q, k, alpha
        dq, dk, dalpha = chunk_bwd_dqka(
            q,
            k,
            v,
            alpha,
            beta,
            chunk_S,
            do,
            d_chunk_S,
            scale,
            cu_seq_len,
            cu_chunk_len,
            chunk_size,
        )
        # compute gradient of v and beta
        dv, dbeta = chunk_bwd_dvb(
            q,
            k,
            v,
            alpha,
            beta,
            do,
            d_chunk_S,
            scale,
            cu_seq_len,
            cu_chunk_len,
            chunk_size,
        )
        # bwd for l2 norm: q and k saved in forward are the normalized tensors
        if l2_qk_norm:
            dq = l2norm_bwd(q, q_rstd, dq)
            dk = l2norm_bwd(k, k_rstd, dk)
        # beta is optional; calling `beta.dtype` on None would raise.
        dbeta = None if beta is None else dbeta.to(beta.dtype)
        return (
            dq.to(q.dtype),
            dk.to(k.dtype),
            dv.to(v.dtype),
            dalpha.to(alpha.dtype),
            dbeta,
            None,  # cu_seq_len
            None,  # initial_S
            None,  # output_final_state
            None,  # scale
            None,  # l2_qk_norm
            None,  # chunk_size
            None,  # state_in_fp32
        )


@torch.compiler.disable
def sgla_prefill(
    q: torch.Tensor,  # shape: batch_size, total_seqlen, num_heads, head_dim_qk
    k: torch.Tensor,  # shape: batch_size, total_seqlen, num_heads, head_dim_qk
    v: torch.Tensor,  # shape: batch_size, total_seqlen, num_heads, head_dim_v
    alpha: torch.Tensor,  # shape: batch_size, total_seqlen, num_heads
    beta: Optional[torch.Tensor] = None,  # shape: batch_size, total_seqlen, num_heads
    cu_seqlens: Optional[torch.LongTensor] = None,
    initial_S: Optional[torch.Tensor] = None,
    output_final_state: bool = True,
    scale: Optional[float] = None,
    l2_qk_norm: bool = True,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Scalar Gate Linear Attention: S_t = exp(alpha_t) * S_{t-1} + beta_t * k_t^T v_t,
    where alpha_t and beta_t are head-wise scalars and each head's state S has
    shape (head_dim_qk, head_dim_v).

    Runs with a fixed chunk size of 64 and keeps the chunk state in fp32.

    Args:
        q (torch.Tensor): Query tensor of shape (batch_size, total_seqlen, num_heads, head_dim_qk).
        k (torch.Tensor): Key tensor of shape (batch_size, total_seqlen, num_heads, head_dim_qk).
        v (torch.Tensor): Value tensor of shape (batch_size, total_seqlen, num_heads, head_dim_v).
        alpha (torch.Tensor): Log of gate tensor of shape (batch_size, total_seqlen, num_heads),
            where each element is in [-inf, 0].
        beta (torch.Tensor): Learning rate tensor of shape (batch_size, total_seqlen, num_heads),
            where each element is in [0, 1]. Defaults to None, which means beta is 1.
        cu_seqlens (torch.LongTensor): Cumulative sequence offsets of shape (N + 1,) for N
            packed variable-length sequences, e.g. [0, len_0, len_0 + len_1, ...].
            Requires batch_size == 1. Defaults to None (no packing).
        initial_S (torch.Tensor): Initial state tensor of shape
            (N, num_heads, head_dim_qk, head_dim_v), where N is batch_size, or the number
            of sequences when cu_seqlens is given. Defaults to None, which means the
            initial state is 0.
        output_final_state (bool): Whether to output the final state. Defaults to True.
        scale (float): Scale factor. Defaults to None, which means scale is head_dim_qk**-0.5.
        l2_qk_norm (bool): Whether to L2-normalize the query and key tensors. Defaults to True.
        **kwargs: Accepted but ignored.

    Returns:
        torch.Tensor: Output tensor, expected shape (batch_size, total_seqlen, num_heads, head_dim_v).
        torch.Tensor: Final state as returned by the chunk-state kernel, presumably shaped like
            initial_S; may be None when output_final_state is False (confirm against chunk_fwd_h).
    """
    # chunk_size=64 and state_in_fp32=True are fixed for this entry point.
    o, final_state_S = ChunkParallelSGLA.apply(
        q,
        k,
        v,
        alpha,
        beta,
        cu_seqlens,
        initial_S,
        output_final_state,
        scale,
        l2_qk_norm,
        64,
        True,
    )
    return o, final_state_S
