import math
from typing import Any

import triton
import torch
import triton.language as tl

from ..utils.set_device import SetDevice


@triton.jit
def _online_softmax_update(S_ij, V_j, m_i, l_i, O_i):
    """
    One online-softmax (FlashAttention) step over a tile of scores.

    :param S_ij: (B_r, B_c) float32 scores, already scaled; masked entries are -inf
    :param V_j: (B_c, value_dim) values for the columns of ``S_ij``
    :param m_i: (B_r,) float32 running row maximum
    :param l_i: (B_r,) float32 running softmax denominator, relative to ``m_i``
    :param O_i: (B_r, value_dim) float32 running un-normalised output, relative to ``m_i``
    :return: updated ``(m_i, l_i, O_i)``
    """
    m_new = tl.maximum(m_i, tl.max(S_ij, axis=1))
    # Rescales everything accumulated so far from the old maximum to the new one.
    alpha = tl.exp(m_i - m_new)
    P_tilde = tl.exp(S_ij - m_new[:, None])
    l_new = l_i * alpha + tl.sum(P_tilde, axis=1)
    # tl.dot needs matching operand dtypes; its result is accumulated in float32.
    O_new = O_i * alpha[:, None] + tl.dot(P_tilde.to(V_j.dtype), V_j)
    return m_new, l_new, O_new


@triton.jit
def flash_attn_kernel(
        q_buffer, q_stride_n, q_stride_h, q_stride_i, q_stride_e, q_stride_d,
        # (bsz, num_heads, q_len, num_extra_tokens, head_dim)
        k_buffer, k_stride_n, k_stride_h, k_stride_i, k_stride_d,  # (bsz, num_heads, k_len, head_dim)
        ka_buffer, ka_stride_n, ka_stride_h, ka_stride_i, ka_stride_e, ka_stride_d,
        # (bsz, num_heads, q_len, num_extra_tokens, head_dim) or None
        v_buffer, v_stride_n, v_stride_h, v_stride_i, v_stride_v,  # (bsz, num_heads, k_len, value_dim)
        va_buffer, va_stride_n, va_stride_h, va_stride_i, va_stride_e, va_stride_v,
        # (bsz, num_heads, q_len, num_extra_tokens, value_dim) or None
        r_buffer, r_stride_n, r_stride_h, r_stride_i, r_stride_e, r_stride_v,
        # (bsz, num_heads, q_len, num_extra_tokens, value_dim)
        l_buffer, l_stride_n, l_stride_h, l_stride_i, l_stride_e,
        # (bsz, num_heads, q_len, num_extra_tokens) or None
        mult_factor: float,
        q_len, k_len, head_dim: tl.constexpr, value_dim: tl.constexpr, num_extra_tokens: tl.constexpr,
        query_offset, B_r: tl.constexpr, B_c: tl.constexpr, return_l: tl.constexpr) -> Any:
    """
    Causal FlashAttention forward kernel for one tile of query rows.

    Program ids: axis 0 selects a block of ``B_r`` query rows, axis 1 the batch
    entry ``n`` and axis 2 a flattened ``(head h, extra token e)`` pair.

    Query row ``r`` has absolute position ``p = query_offset + r`` and attends to:
      * key/value rows ``c <= p`` (and ``c < k_len``) of ``k_buffer``/``v_buffer``;
      * if ``ka_buffer`` is given, the single row ``c == p - 1`` of
        ``ka_buffer``/``va_buffer`` for extra token ``e``.
    Both sets share one softmax. Queries are scaled by ``mult_factor``.
    Dot-product operands are bfloat16. The running max, the denominator and the
    output accumulator are float32.

    Writes the normalised output to ``r_buffer``. If ``return_l``, also writes the
    log-sum-exp of the scores per row to ``l_buffer``. Stores cast to the buffers'
    element types.
    """
    i = tl.program_id(0)
    n = tl.program_id(1)
    h = tl.program_id(2) // num_extra_tokens
    e = tl.program_id(2) % num_extra_tokens

    T_c = tl.cdiv(k_len, B_c)

    d = tl.arange(0, head_dim)
    v = tl.arange(0, value_dim)

    r_block_begin = i * B_r
    r_block_end = min(q_len, r_block_begin + B_r)
    r_indices = r_block_begin + tl.arange(0, B_r)
    r_valid = r_indices < r_block_end  # (B_r,)
    q_positions = query_offset + r_indices  # (B_r,) absolute query positions

    # Operand dtype for tl.dot.
    dtype = tl.bfloat16

    Q_i = tl.load(
        q_buffer
        + n * q_stride_n
        + h * q_stride_h
        + r_indices[:, None] * q_stride_i
        + e * q_stride_e
        + d[None, :] * q_stride_d,
        mask=r_valid[:, None],
        other=0.0
    )
    # Scale in float32 before rounding, so mult_factor itself is not rounded to bf16.
    Q_i = (Q_i.to(tl.float32) * mult_factor).to(dtype)  # (B_r, head_dim)

    O_ij = tl.zeros((B_r, value_dim), dtype=tl.float32)
    # The initial 1.0 is multiplied by exp(-inf) = 0 on the first unmasked update.
    # If no key block is processed (k_len == 0), the output stays 0 instead of 0/0.
    l_ij = tl.full((B_r,), 1.0, dtype=tl.float32)
    m_ij = tl.full((B_r,), float('-inf'), dtype=tl.float32)

    for j in range(T_c):
        c_block_begin = j * B_c
        c_block_end = min(k_len, c_block_begin + B_c)
        c_indices = c_block_begin + tl.arange(0, B_c)
        c_valid = c_indices < c_block_end  # (B_c,)

        K_j = tl.load(
            k_buffer
            + n * k_stride_n
            + h * k_stride_h
            + c_indices[:, None] * k_stride_i
            + d[None, :] * k_stride_d,
            mask=c_valid[:, None],
            other=0.0,
        ).to(dtype)  # (B_c, head_dim)

        V_j = tl.load(
            v_buffer
            + n * v_stride_n
            + h * v_stride_h
            + c_indices[:, None] * v_stride_i
            + v[None, :] * v_stride_v,
            mask=c_valid[:, None],
            other=0.0,
        ).to(dtype)  # (B_c, value_dim)

        S_ij = tl.dot(Q_i, tl.trans(K_j))  # (B_r, B_c), float32

        # Causal mask. Padded key columns (c >= k_len) are zero-filled and would
        # otherwise score 0 instead of -inf, so exclude them explicitly.
        causal_mask = c_valid[None, :] & (c_indices[None, :] <= q_positions[:, None])
        S_ij = tl.where(causal_mask, S_ij, float('-inf'))
        m_ij, l_ij, O_ij = _online_softmax_update(S_ij, V_j, m_ij, l_ij, O_ij)

        # Extra tokens (resolved at compile time: ka_buffer is None when not given)
        if ka_buffer is not None:
            # NOTE(review): ka/va are indexed by the key index c, but per the wrapper
            # they have q_len rows. If k_len > q_len these masked loads can read past
            # the end of ka/va. Confirm the intended indexing.
            K_j = tl.load(
                ka_buffer
                + n * ka_stride_n
                + h * ka_stride_h
                + c_indices[:, None] * ka_stride_i
                + e * ka_stride_e
                + d[None, :] * ka_stride_d,
                mask=c_valid[:, None],
                other=0.0,
            ).to(dtype)  # (B_c, head_dim)

            V_j = tl.load(
                va_buffer
                + n * va_stride_n
                + h * va_stride_h
                + c_indices[:, None] * va_stride_i
                + e * va_stride_e
                + v[None, :] * va_stride_v,
                mask=c_valid[:, None],
                other=0.0,
            ).to(dtype)  # (B_c, value_dim)

            S_ij = tl.dot(Q_i, tl.trans(K_j))  # (B_r, B_c), float32

            # NOTE: 'equal to' is correct here: each query sees exactly one extra row, c == p - 1.
            extra_mask = c_valid[None, :] & ((c_indices[None, :] + 1) == q_positions[:, None])
            S_ij = tl.where(extra_mask, S_ij, float('-inf'))
            m_ij, l_ij, O_ij = _online_softmax_update(S_ij, V_j, m_ij, l_ij, O_ij)

    O_ij = O_ij / l_ij[:, None]

    tl.store(
        r_buffer
        + n * r_stride_n
        + h * r_stride_h
        + r_indices[:, None] * r_stride_i
        + e * r_stride_e
        + v[None, :] * r_stride_v,
        O_ij,
        mask=r_valid[:, None],
    )

    if return_l:
        L_i = m_ij + tl.log(l_ij)  # log-sum-exp of the (scaled, masked) scores
        tl.store(
            l_buffer
            + n * l_stride_n
            + h * l_stride_h
            + r_indices * l_stride_i
            + e * l_stride_e,
            L_i,
            mask=r_valid,
        )


def attn_flash_triton(
        query_states, key_states, value_states,
        query_offset: int,
        key_states_extra=None, value_states_extra=None,
        B_r: int = None, B_c: int = None,
        out=None, return_l=True):
    """
    Causal FlashAttention forward in triton, with optional per-query extra key/value tokens.

    :param query_states: (bsz, num_heads, q_len, num_extra_tokens, head_dim)
    :param key_states:   (bsz, num_heads, k_len, head_dim)
    :param value_states: (bsz, num_heads, k_len, value_dim)
    :param query_offset: absolute position of the first query row; query row ``r``
                         attends to key rows ``<= query_offset + r``
    :param key_states_extra: (bsz, num_heads, q_len, num_extra_tokens, head_dim) or None
    :param value_states_extra: (bsz, num_heads, q_len, num_extra_tokens, value_dim) or None;
                               must be given together with ``key_states_extra``
    :param B_r: number of query rows per flash block (default 32)
    :param B_c: number of key columns per flash block (default 64)
    :param out: optional preallocated output tensor with the output shape below
    :param return_l: whether to compute and return the log-sum-exp values (e.g. for a backward pass)
    :return: ``(output, l, compile_info)`` if ``return_l`` else ``(output, compile_info)``
             output (bsz, num_heads, q_len, num_extra_tokens, value_dim), dtype of ``query_states``
             l (bsz, num_heads, q_len, num_extra_tokens), dtype of ``query_states``
             compile_info is the value returned by the triton kernel launch
    :raises ValueError: if only one of ``key_states_extra`` / ``value_states_extra`` is given,
                        or if ``out`` does not have the output shape
    """
    # Only one of the two would either crash on None.stride or be silently ignored.
    if (key_states_extra is None) != (value_states_extra is None):
        raise ValueError("key_states_extra and value_states_extra must be passed together")

    # TODO: Find better default values for B_r and B_c
    if B_r is None:
        B_r = 32
    if B_c is None:
        B_c = 64

    device = query_states.device
    dtype = query_states.dtype

    bsz, num_heads, q_len, num_extra_tokens, head_dim = query_states.size()
    _, _, k_len, value_dim = value_states.size()
    out_shape = (bsz, num_heads, q_len, num_extra_tokens, value_dim)

    if out is None:
        # The kernel writes every (row, column) of every (batch, head, extra token),
        # so no zero-fill is needed.
        r = torch.empty(out_shape, dtype=dtype, device=device)
    elif tuple(out.shape) != out_shape:
        # A mis-shaped buffer would be written through its strides out of bounds.
        raise ValueError(f"out has shape {tuple(out.shape)}, expected {out_shape}")
    else:
        r = out

    l, l_stride_n, l_stride_h, l_stride_i, l_stride_e = None, None, None, None, None
    if return_l:
        l = torch.empty(out_shape[:-1], dtype=dtype, device=device)
        l_stride_n, l_stride_h, l_stride_i, l_stride_e = l.stride(0), l.stride(1), l.stride(2), l.stride(3)

    q, k, v = query_states, key_states, value_states
    ka, va = None, None
    ka_stride_n, ka_stride_h, ka_stride_i, ka_stride_e, ka_stride_d = None, None, None, None, None
    va_stride_n, va_stride_h, va_stride_i, va_stride_e, va_stride_v = None, None, None, None, None
    if key_states_extra is not None:
        ka, va = key_states_extra, value_states_extra
        ka_stride_n, ka_stride_h, ka_stride_i, ka_stride_e, ka_stride_d = (
            ka.stride(0), ka.stride(1), ka.stride(2), ka.stride(3), ka.stride(4))
        va_stride_n, va_stride_h, va_stride_i, va_stride_e, va_stride_v = (
            va.stride(0), va.stride(1), va.stride(2), va.stride(3), va.stride(4))

    # One program per (query row block, batch entry, (head, extra token) pair).
    grid = (triton.cdiv(q_len, B_r), bsz, num_heads * num_extra_tokens)
    with SetDevice(device):
        compile_info = flash_attn_kernel[grid](
            q_buffer=q, q_stride_n=q.stride(0), q_stride_h=q.stride(1), q_stride_i=q.stride(2), q_stride_e=q.stride(3),
            q_stride_d=q.stride(4),
            k_buffer=k, k_stride_n=k.stride(0), k_stride_h=k.stride(1), k_stride_i=k.stride(2), k_stride_d=k.stride(3),
            ka_buffer=ka, ka_stride_n=ka_stride_n, ka_stride_h=ka_stride_h, ka_stride_i=ka_stride_i,
            ka_stride_e=ka_stride_e, ka_stride_d=ka_stride_d,
            v_buffer=v, v_stride_n=v.stride(0), v_stride_h=v.stride(1), v_stride_i=v.stride(2), v_stride_v=v.stride(3),
            va_buffer=va, va_stride_n=va_stride_n, va_stride_h=va_stride_h, va_stride_i=va_stride_i,
            va_stride_e=va_stride_e, va_stride_v=va_stride_v,
            r_buffer=r, r_stride_n=r.stride(0), r_stride_h=r.stride(1), r_stride_i=r.stride(2), r_stride_e=r.stride(3),
            r_stride_v=r.stride(4),
            l_buffer=l, l_stride_n=l_stride_n, l_stride_h=l_stride_h, l_stride_i=l_stride_i, l_stride_e=l_stride_e,
            mult_factor=1.0 / math.sqrt(head_dim),
            q_len=q_len, k_len=k_len, head_dim=head_dim, value_dim=value_dim, num_extra_tokens=num_extra_tokens,
            query_offset=query_offset, B_r=B_r, B_c=B_c, return_l=return_l,
            num_stages=1,
            num_warps=16,
        )

    if return_l:
        return r, l, compile_info
    return r, compile_info
