import math
import os
from typing import Any

import torch
import triton
import triton.language as tl

from ..utils.set_device import SetDevice
from .full_forward import attn_flash_triton
from ..common import calc_dims


@triton.jit
def mm(a, b, ta: tl.constexpr = False, tb: tl.constexpr = False):
    """Matrix product of two 2-D blocks with optional transposition.

    :param a: left operand
    :param b: right operand
    :param ta: compile-time flag; when true, `a` is transposed before the product
    :param tb: compile-time flag; when true, `b` is transposed before the product
    :return: `tl.dot` of the (possibly transposed) operands
    """
    # `ta` / `tb` are constexpr, so each branch below is resolved at compile time.
    lhs = a
    rhs = b
    if ta:
        lhs = tl.trans(a)
    if tb:
        rhs = tl.trans(b)
    return tl.dot(lhs, rhs)


@triton.jit
def bmm(a, b, ta: tl.constexpr = False, tb: tl.constexpr = False):
    """Batched matrix product where the batch axis is the LAST axis.

    The batch is unrolled at compile time into independent `mm` calls by
    repeatedly splitting the trailing axis, so only batch sizes 1, 2 and 4
    are handled.

    :param a: left operand of shape (A0, A1, BSZ)
    :param b: right operand of shape (B0, B1, BSZ)
    :param ta: compile-time flag; transpose every `a` slice before multiplying
    :param tb: compile-time flag; transpose every `b` slice before multiplying
    :return: block of shape (C0, C1, BSZ) cast to the dtype of `a`, where
             C0 is A0 (or A1 if `ta`) and C1 is B1 (or B0 if `tb`)
    """
    A0: tl.constexpr = a.shape[0]
    A1: tl.constexpr = a.shape[1]
    B0: tl.constexpr = b.shape[0]
    B1: tl.constexpr = b.shape[1]
    BSZ: tl.constexpr = a.shape[2]
    # Shape of each per-batch product after the optional transposes.
    C0: tl.constexpr = A0 if not ta else A1
    C1: tl.constexpr = B1 if not tb else B0

    if BSZ == 1:
        # Single batch entry: drop the trailing axis, multiply, and restore it.
        a = a.reshape(A0, A1)
        b = b.reshape(B0, B1)
        c = mm(a, b, ta, tb)
        c = c.reshape(C0, C1, 1)

    elif BSZ == 2:
        # tl.split separates the two entries of the trailing axis.
        a0, a1 = tl.split(a)
        b0, b1 = tl.split(b)
        c0 = mm(a0, b0, ta, tb)
        c1 = mm(a1, b1, ta, tb)
        c = tl.join(c0, c1)

    elif BSZ == 4:
        # View the batch axis as 2 x 2 so that two rounds of split reach the
        # individual slices; the joins below mirror the splits, so the final
        # reshape restores the original batch ordering.
        a = a.reshape(A0, A1, 2, 2)
        b = b.reshape(B0, B1, 2, 2)
        a0, a1 = tl.split(a)
        b0, b1 = tl.split(b)

        a00, a01 = tl.split(a0)
        b00, b01 = tl.split(b0)
        c00 = mm(a00, b00, ta, tb)
        c01 = mm(a01, b01, ta, tb)
        c0 = tl.join(c00, c01)

        a10, a11 = tl.split(a1)
        b10, b11 = tl.split(b1)
        c10 = mm(a10, b10, ta, tb)
        c11 = mm(a11, b11, ta, tb)
        c1 = tl.join(c10, c11)

        c = tl.join(c0, c1)
        c = c.reshape(C0, C1, 4)

    else:
        tl.static_assert(False, "Only batch size 1, 2 or 4 supported for now")

    return c.to(a.dtype)


@triton.jit
def pbmm(a, b):
    """Batched product of `a` with a single `b` shared by every batch entry.

    The batch axis of `a` (its last axis) is folded into the row axis so that
    one `tl.dot` covers all batch entries, then unfolded again.

    :param a: block of shape (A0, A1, BSZ)
    :param b: block of shape (B0, B1) without a batch axis
    :return: block of shape (A0, B1, BSZ) cast to the dtype of `a`
    """
    ROWS: tl.constexpr = a.shape[0]
    INNER: tl.constexpr = a.shape[1]
    BATCH: tl.constexpr = a.shape[2]
    COLS: tl.constexpr = b.shape[1]

    # (A0, A1, BSZ) -> (A0, BSZ, A1) -> (A0 * BSZ, A1)
    stacked = a.permute(0, 2, 1).reshape(ROWS * BATCH, INNER)
    product = tl.dot(stacked, b)
    # (A0 * BSZ, B1) -> (A0, BSZ, B1) -> (A0, B1, BSZ)
    unstacked = product.reshape(ROWS, BATCH, COLS).permute(0, 2, 1)
    return unstacked.to(a.dtype)


@triton.jit
def sparse_attn_kernel(
        q_buffer, q_stride_n, q_stride_h, q_stride_i, q_stride_d,  # (bsz, num_heads, q_len, head_dim)
        k_buffer, k_stride_n, k_stride_h, k_stride_i, k_stride_d,  # (bsz, num_heads, k_len, head_dim)
        m_buffer, m_stride_n, m_stride_h, m_stride_i, m_stride_t,  # (bsz, num_heads, q_blocks, top_k)
        v_buffer, v_stride_n, v_stride_h, v_stride_i, v_stride_v,  # (bsz, num_heads, k_len, value_dim)
        r_buffer, r_stride_n, r_stride_h, r_stride_i, r_stride_v,  # (bsz, num_heads, q_len, value_dim)
        l_buffer, l_stride_n, l_stride_h, l_stride_i,  # (bsz, num_heads, q_len)
        mult_factor: float,
        bsz, num_heads, q_len, k_len, head_dim: tl.constexpr, value_dim: tl.constexpr, top_k: tl.constexpr,
        block_size_q: tl.constexpr, block_size_k: tl.constexpr, query_offset, start_sink_tokens, end_sink_tokens,
        B_r: tl.constexpr, B_c: tl.constexpr, return_l: tl.constexpr) -> Any:
    """Block-sparse causal attention with an online-softmax accumulation.

    One program handles one tile of `B_r` query rows (program axis 0) for one
    batch entry (axis 1) and one head (axis 2). The tile is treated as
    `B_r // block_size_q` query blocks; that count is kept as the trailing
    "batch" axis of every intermediate block, because each query block attends
    to its own set of keys. Three groups of keys are accumulated in turn:

    1. the leading `start_sink_tokens` keys, shared by all query blocks,
    2. a window of `end_sink_tokens` keys whose start is derived per query
       block (`sink_end_start_indices`),
    3. the `top_k` key blocks listed in `m_buffer` for each query block; a
       negative entry marks a block to be skipped.

    :param q_buffer, k_buffer, v_buffer: query / key / value tensors with the
        strides that follow them (layouts as noted on the signature lines)
    :param m_buffer: per-query-block key-block indices, (bsz, num_heads, q_blocks, top_k)
    :param r_buffer: output tensor, (bsz, num_heads, q_len, value_dim)
    :param l_buffer: log-sum-exp output, (bsz, num_heads, q_len); written only if `return_l`
    :param mult_factor: scale applied to the queries before the dot products
    :param query_offset: position of the first query row; rows are addressed
        relative to it when loading and storing
    :param B_r: query rows per program (multiple of `block_size_q`)
    :param B_c: key columns per loop iteration (multiple of `block_size_k`)
    :param return_l: compile-time flag enabling the `l_buffer` store
    """

    n = tl.program_id(1)
    h = tl.program_id(2)
    i = tl.program_id(0)

    # Query rows are aligned to block_size_q in absolute position, so the first
    # block may start with `q_start_padding` rows that precede the real queries.
    q_block_offset = query_offset // block_size_q
    q_start_padding = query_offset - q_block_offset * block_size_q
    q_blocks = tl.cdiv(q_len + query_offset, block_size_q) - q_block_offset

    T_c = tl.cdiv(top_k * block_size_k, B_c)

    d = tl.arange(0, head_dim)
    v = tl.arange(0, value_dim)

    r_block_begin = i * B_r // block_size_q
    r_block_end = min(q_blocks, r_block_begin + B_r // block_size_q)
    r_indices = r_block_begin + tl.arange(0, B_r // block_size_q)  # sparse-block-level indices

    q = (
        tl.arange(0, B_r // block_size_q)[None, :] * block_size_q
        + tl.arange(0, block_size_q)[:, None]
    )  # token-level indices
    q_positions = (q_block_offset + r_block_begin) * block_size_q + q  # (block_size_q, B_r // block_size_q)
    q_indices = r_block_begin * block_size_q + q

    # Rows outside [query_offset, query_offset + q_len) are padding and load as zero.
    Q_i = tl.load(
        q_buffer
        + n * q_stride_n
        + h * q_stride_h
        + (q_indices[:, None, :] - q_start_padding) * q_stride_i
        + d[None, :, None] * q_stride_d,
        mask=(
            (query_offset <= q_positions[:, None, :])
            & (q_positions[:, None, :] < query_offset + q_len)
        ),
        other=0.0
    )  # (block_size_q, head_dim, B_r // block_size_q)
    Q_i = (Q_i * mult_factor).to(Q_i.dtype)

    # Per query block: end of the block rounded up to a multiple of
    # block_size_k, minus end_sink_tokens. This is the first key of the
    # end-sink window and the exclusive upper bound for the other key groups.
    sink_end_start_indices = tl.cdiv(
        (q_block_offset + r_block_begin + 1 + tl.arange(0, B_r // block_size_q)) * block_size_q,
        block_size_k
    ) * block_size_k - end_sink_tokens

    # Running online-softmax state: output accumulator, normaliser and row maximum.
    dtype = Q_i.dtype
    NEG_INF = -10000.0
    O_ij = tl.zeros((block_size_q, value_dim, B_r // block_size_q), dtype=dtype)
    l_ij = tl.full((block_size_q, B_r // block_size_q), 1.0, dtype=dtype)
    m_ij = tl.full((block_size_q, B_r // block_size_q), NEG_INF, dtype=dtype)

    # Handle sink tokens
    for j in range(tl.cdiv(start_sink_tokens, B_c)):
        c_block_begin = j * B_c // block_size_k
        # Floor division: only whole key blocks of start-sink tokens are visited.
        c_block_end = min(start_sink_tokens // block_size_k, c_block_begin + B_c // block_size_k)
        c_indices = c_block_begin + tl.arange(0, B_c // block_size_k)  # sparse-block-level indices

        k_indices = c_indices[:, None] * block_size_k + tl.arange(0, block_size_k)[None, :]  # (B_c // block_size_k, block_size_k)
        K_j = tl.load(
            k_buffer
            + n * k_stride_n
            + h * k_stride_h
            + k_indices[None, :, :] * k_stride_i
            + d[:, None, None] * k_stride_d,
            mask=(
                (c_indices[None, :, None] < c_block_end)
                & (k_indices[None, :, :] < k_len)
            ),
            other=0
        )  # (head_dim, B_c // block_size_k, block_size_k)
        K_j = K_j.reshape(head_dim, B_c)

        V_j = tl.load(
            v_buffer
            + n * v_stride_n
            + h * v_stride_h
            + k_indices[:, :, None] * v_stride_i
            + v[None, None, :] * v_stride_v,
            mask=(
                (c_indices[:, None, None] < c_block_end)
                & (k_indices[:, :, None] < k_len)
            ),
            other=0
        )  # (B_c // block_size_k, block_size_k, value_dim)
        V_j = V_j.reshape(B_c, value_dim)

        # The same keys serve every query block, hence the shared-operand product.
        S_ij = pbmm(Q_i, K_j)  # (block_size_q, B_c, B_r // block_size_q)

        # Apply causal mask
        k_positions = k_indices[None, :, :, None]  # (1, B_c // block_size_k, block_size_k, 1)
        valid_mask = (
            (k_positions <= q_positions[:, None, None, :])
            & (k_positions < sink_end_start_indices[None, None, None, :])
            & (c_indices[None, :, None, None] < c_block_end)
        )
        valid_mask = valid_mask.reshape(block_size_q, B_c, B_r // block_size_q)

        S_ij = tl.where(valid_mask, S_ij, tl.full((), NEG_INF, dtype=dtype))

        # Online-softmax update: rescale the previous state to the new row maximum.
        m_ijm1 = m_ij
        m_ij = tl.maximum(m_ij, tl.max(S_ij, axis=1)).to(dtype)  # (block_size_q, B_r // block_size_q)
        P_tilde_ij = tl.exp((S_ij - m_ij[:, None, :]).to(tl.float32)).to(dtype)  # (block_size_q, B_c, B_r // block_size_q)
        m_diff_exp = tl.exp((m_ijm1 - m_ij).to(tl.float32)).to(dtype)
        l_ij = l_ij * m_diff_exp + tl.sum(P_tilde_ij, axis=1).to(dtype)  # (block_size_q, B_r // block_size_q)
        O_ij = O_ij * m_diff_exp[:, None, :] + pbmm(P_tilde_ij, V_j)  # (block_size_q, value_dim, B_r // block_size_q)

    # Handle the end-sink window, whose key positions differ per query block
    for j in range(tl.cdiv(end_sink_tokens, B_c)):
        c_block_begin = j * B_c // block_size_k
        c_block_end = min(end_sink_tokens // block_size_k, c_block_begin + B_c // block_size_k)
        c_indices = c_block_begin + tl.arange(0, B_c // block_size_k)

        k_block_indices = sink_end_start_indices[None, :] // block_size_k + c_indices[:, None]
        # The window start can be negative for early query blocks; such key blocks are skipped.
        k_block_indices_valid = (k_block_indices >= 0)  # (B_c // block_size_k, B_r // block_size_q)
        k_indices = (
            k_block_indices[:, None, :] * block_size_k + tl.arange(0, block_size_k)[None, :, None]
        )  # (B_c // block_size_k, block_size_k, B_r // block_size_q)

        K_j = tl.load(
            k_buffer
            + n * k_stride_n
            + h * k_stride_h
            + k_indices[None, :, :, :] * k_stride_i
            + d[:, None, None, None] * k_stride_d,
            mask=(
                (k_indices[None, :, :, :] < k_len)
                & (r_indices[None, None, None, :] < r_block_end)
                & (c_indices[None, :, None, None] < c_block_end)
                & k_block_indices_valid[None, :, None, :]
            ),
            other=0
        )  # (head_dim, B_c // block_size_k, block_size_k, B_r // block_size_q)
        K_j = K_j.reshape(head_dim, B_c, B_r // block_size_q)

        V_j = tl.load(
            v_buffer
            + n * v_stride_n
            + h * v_stride_h
            + k_indices[:, :, None, :] * v_stride_i
            + v[None, None, :, None] * v_stride_v,
            mask=(
                (k_indices[:, :, None, :] < k_len)
                & (r_indices[None, None, None, :] < r_block_end)
                & (c_indices[:, None, None, None] < c_block_end)
                & k_block_indices_valid[:, None, None, :]
            ),
            other=0
        )  # (B_c // block_size_k, block_size_k, value_dim, B_r // block_size_q)
        V_j = V_j.reshape(B_c, value_dim, B_r // block_size_q)

        S_ij = bmm(Q_i, K_j)  # (block_size_q, B_c, B_r // block_size_q)

        # Apply causal mask
        k_positions = k_indices[None, :, :, :]  # (1, B_c // block_size_k, block_size_k, B_r // block_size_q)
        valid_mask = (
            (k_positions <= q_positions[:, None, None, :])
            & (c_indices[None, :, None, None] < c_block_end)
            & k_block_indices_valid[None, :, None, :]
        )

        valid_mask = valid_mask.reshape(block_size_q, B_c, B_r // block_size_q)
        S_ij = tl.where(valid_mask, S_ij, tl.full((), NEG_INF, dtype=dtype))

        m_ijm1 = m_ij
        m_ij = tl.maximum(m_ij, tl.max(S_ij, axis=1)).to(dtype)  # (block_size_q, B_r // block_size_q)
        P_tilde_ij = tl.exp((S_ij - m_ij[:, None, :]).to(tl.float32)).to(dtype)  # (block_size_q, B_c, B_r // block_size_q)
        m_diff_exp = tl.exp((m_ijm1 - m_ij).to(tl.float32)).to(dtype)
        l_ij = l_ij * m_diff_exp + tl.sum(P_tilde_ij, axis=1).to(dtype)  # (block_size_q, B_r // block_size_q)
        O_ij = O_ij * m_diff_exp[:, None, :] + bmm(P_tilde_ij, V_j)  # (block_size_q, value_dim, B_r // block_size_q)

    # Handle top-k
    for j in range(T_c):
        c_block_begin = j * B_c // block_size_k
        c_block_end = min(top_k, c_block_begin + B_c // block_size_k)
        c_indices = c_block_begin + tl.arange(0, B_c // block_size_k)  # sparse-block-level indices

        sparse_indices = tl.load(
            m_buffer
            + n * m_stride_n
            + h * m_stride_h
            + r_indices[None, :] * m_stride_i
            + c_indices[:, None] * m_stride_t,
            mask=(
                (r_indices[None, :] < r_block_end)
                & (c_indices[:, None] < c_block_end)
            ),
            other=0
        )  # sparse-block-level indices, (B_c // block_size_k, B_r // block_size_q)

        # Selected block indices are relative to the end of the start-sink region.
        # NOTE: this must be a tensor expression (no trailing comma inside the
        # parentheses), since it is subscripted and used in arithmetic below.
        k_indices = (
            start_sink_tokens + sparse_indices[:, None, :] * block_size_k
            + tl.arange(0, block_size_k)[None, :, None]
        )  # (B_c // block_size_k, block_size_k, B_r // block_size_q)

        # NOTE(review): unlike the two loops above, these loads are not masked
        # with `k_indices < k_len`; presumably m_buffer only holds in-range
        # blocks - confirm against the producer of the indices.
        K_j = tl.load(
            k_buffer
            + n * k_stride_n
            + h * k_stride_h
            + k_indices[None, :, :, :] * k_stride_i
            + d[:, None, None, None] * k_stride_d,
            mask=(
                (r_indices[None, None, None, :] < r_block_end)
                & (c_indices[None, :, None, None] < c_block_end)
                & (sparse_indices[None, :, None, :] >= 0)
            ),
            other=0
        )  # (head_dim, B_c // block_size_k, block_size_k, B_r // block_size_q)
        K_j = K_j.reshape(head_dim, B_c, B_r // block_size_q)

        V_j = tl.load(
            v_buffer
            + n * v_stride_n
            + h * v_stride_h
            + k_indices[:, :, None, :] * v_stride_i
            + v[None, None, :, None] * v_stride_v,
            mask=(
                (r_indices[None, None, None, :] < r_block_end)
                & (c_indices[:, None, None, None] < c_block_end)
                & (sparse_indices[:, None, None, :] >= 0)
            ),
            other=0
        )  # (B_c // block_size_k, block_size_k, value_dim, B_r // block_size_q)
        V_j = V_j.reshape(B_c, value_dim, B_r // block_size_q)

        S_ij = bmm(Q_i, K_j)  # (block_size_q, B_c, B_r // block_size_q)

        # Apply causal mask
        k_positions = k_indices[None, :, :, :]  # (1, B_c // block_size_k, block_size_k, B_r // block_size_q)
        valid_mask = (
            (k_positions <= q_positions[:, None, None, :])
            & (k_positions < sink_end_start_indices[None, None, None, :])
            & (c_indices[None, :, None, None] < c_block_end)
        )

        # Skip length 0 blocks
        valid_mask &= (sparse_indices >= 0)[None, :, None, :]

        valid_mask = valid_mask.reshape(block_size_q, B_c, B_r // block_size_q)
        S_ij = tl.where(valid_mask, S_ij, tl.full((), NEG_INF, dtype=dtype))

        m_ijm1 = m_ij
        m_ij = tl.maximum(m_ij, tl.max(S_ij, axis=1)).to(dtype)  # (block_size_q, B_r // block_size_q)
        P_tilde_ij = tl.exp((S_ij - m_ij[:, None, :]).to(tl.float32)).to(dtype)  # (block_size_q, B_c, B_r // block_size_q)
        m_diff_exp = tl.exp((m_ijm1 - m_ij).to(tl.float32)).to(dtype)
        l_ij = l_ij * m_diff_exp + tl.sum(P_tilde_ij, axis=1).to(dtype)  # (block_size_q, B_r // block_size_q)
        O_ij = O_ij * m_diff_exp[:, None, :] + bmm(P_tilde_ij, V_j)  # (block_size_q, value_dim, B_r // block_size_q)

    # Normalise the accumulator and form the per-row log-sum-exp.
    O_i = O_ij / l_ij[:, None, :]  # (block_size_q, value_dim, B_r // block_size_q)
    L_i = m_ij + tl.log(l_ij.to(tl.float32)).to(dtype)

    # Only real query rows are written back (same mask as the Q_i load).
    tl.store(
        r_buffer
        + n * r_stride_n
        + h * r_stride_h
        + (q_indices[:, None, :] - q_start_padding) * r_stride_i
        + v[None, :, None] * r_stride_v,
        O_i,
        mask=(
            (query_offset <= q_positions[:, None, :])
            & (q_positions[:, None, :] < query_offset + q_len)
        ),
    )

    if return_l:
        tl.store(
            l_buffer
            + n * l_stride_n
            + h * l_stride_h
            + (q_indices - q_start_padding) * l_stride_i,
            L_i,
            mask=(
                (query_offset <= q_positions)
                & (q_positions < query_offset + q_len)
            ),
        )


def sparse_attn_triton_impl(
        query_states, key_states, sparse_indices, value_states,
        block_size_q: int, block_size_k: int, query_offset: int,
        start_sink_tokens: int, end_sink_tokens: int,
        B_r: int = None, B_c: int = None, out=None, return_l=True):
    """
    Launch the sparse flash-attention kernel.
    :param query_states: (bsz, num_heads, q_len, head_dim)
    :param key_states:   (bsz, num_heads, k_len, head_dim)
    :param sparse_indices: (bsz, num_heads, q_blocks, top_k)
    :param value_states: (bsz, num_heads, k_len, value_dim)
    :param block_size_q: query block size
    :param block_size_k: key block size
    :param query_offset: offset of query
    :param start_sink_tokens: number of sink tokens at the start of the key
    :param end_sink_tokens: number of sink tokens at the end of the key
    :param B_r: rows per kernel tile; must be a multiple of block_size_q.
                Defaults to max(QE_SA_BLOCK_BQ env var or 32, block_size_q).
    :param B_c: columns per kernel tile; must be a multiple of block_size_k.
                Defaults to min(QE_SA_BLOCK_BK env var or 64, top_k).
    :param out: optional tensor the kernel writes the output into; a zero tensor
                is allocated when omitted
    :param return_l: whether to compute and return logsumexp value for backward pass
    :return: (output, l, compile_info) if return_l, otherwise (output, compile_info).
             output is of shape (bsz, num_heads, q_len, value_dim),
             l is of shape (bsz, num_heads, q_len).
    """
    bsz, num_heads, q_len, head_dim = query_states.size()
    _, _, k_len, value_dim = value_states.size()
    _, _, _, top_k = sparse_indices.size()

    # TODO: Find better default values for B_r and B_c
    if B_r is None:
        env_rows = int(os.environ.get('QE_SA_BLOCK_BQ', 32))
        B_r = max(env_rows, block_size_q)
    if B_c is None:
        # NOTE(review): this caps a token count by `top_k`, which is a block
        # count - possibly meant to be top_k * block_size_k; confirm.
        env_cols = int(os.environ.get('QE_SA_BLOCK_BK', 64))
        B_c = min(env_cols, top_k)

    assert B_r % block_size_q == 0
    assert B_c % block_size_k == 0

    # Only the number of query blocks is needed on the host side.
    _, q_blocks, _, _, _ = calc_dims(
        q_len, k_len, block_size_q, block_size_k, query_offset
    )

    tensor_opts = dict(dtype=query_states.dtype, device=query_states.device)
    result = torch.zeros(bsz, num_heads, q_len, value_dim, **tensor_opts) if out is None else out

    # The logsumexp buffer and its strides stay None when it is not requested.
    lse = None
    lse_strides = (None, None, None)
    if return_l:
        lse = torch.zeros(bsz, num_heads, q_len, **tensor_opts)
        lse_strides = lse.stride()

    # One program per (row tile, batch entry, head).
    num_row_tiles = triton.cdiv(q_blocks * block_size_q, B_r)
    grid = (num_row_tiles, bsz, num_heads)

    with SetDevice(query_states.device):
        compile_info = sparse_attn_kernel[grid](
            q_buffer=query_states,
            q_stride_n=query_states.stride(0), q_stride_h=query_states.stride(1),
            q_stride_i=query_states.stride(2), q_stride_d=query_states.stride(3),
            k_buffer=key_states,
            k_stride_n=key_states.stride(0), k_stride_h=key_states.stride(1),
            k_stride_i=key_states.stride(2), k_stride_d=key_states.stride(3),
            m_buffer=sparse_indices,
            m_stride_n=sparse_indices.stride(0), m_stride_h=sparse_indices.stride(1),
            m_stride_i=sparse_indices.stride(2), m_stride_t=sparse_indices.stride(3),
            v_buffer=value_states,
            v_stride_n=value_states.stride(0), v_stride_h=value_states.stride(1),
            v_stride_i=value_states.stride(2), v_stride_v=value_states.stride(3),
            r_buffer=result,
            r_stride_n=result.stride(0), r_stride_h=result.stride(1),
            r_stride_i=result.stride(2), r_stride_v=result.stride(3),
            l_buffer=lse,
            l_stride_n=lse_strides[0], l_stride_h=lse_strides[1], l_stride_i=lse_strides[2],
            mult_factor=1.0 / math.sqrt(head_dim),
            bsz=bsz, num_heads=num_heads, q_len=q_len, k_len=k_len,
            head_dim=head_dim, value_dim=value_dim, top_k=top_k,
            block_size_q=block_size_q, block_size_k=block_size_k, query_offset=query_offset,
            start_sink_tokens=start_sink_tokens, end_sink_tokens=end_sink_tokens,
            B_r=B_r, B_c=B_c, return_l=return_l,
            num_stages=1,
            num_warps=16,
        )

    return (result, lse, compile_info) if return_l else (result, compile_info)


def sparse_attn_triton(
        query_states, key_states, sparse_indices, value_states,
        block_size_q: int, block_size_k: int, query_offset: int,
        start_sink_tokens: int, end_sink_tokens: int,
        B_r: int = None, B_c: int = None):
    """
    Sparse flashattention with a dense prefix.

    The first `cutoff` queries go through `attn_flash_triton`; the remaining
    queries go through `sparse_attn_triton_impl`. The two results are
    concatenated along the query axis.
    :param query_states: (bsz, num_heads, q_len, head_dim)
    :param key_states:   (bsz, num_heads, k_len, head_dim)
    :param sparse_indices: (bsz, num_heads, q_blocks, top_k), where the size of
                           dim 2 must equal the number of sparsely handled query blocks
    :param value_states: (bsz, num_heads, k_len, value_dim)
    :param block_size_q: query block size
    :param block_size_k: key block size
    :param query_offset: offset of query
    :param start_sink_tokens: number of sink tokens at the start of the key
    :param end_sink_tokens: number of sink tokens at the end of the key;
                            must be at least block_size_q
    :param B_r: number of rows in the flash block (sparse part only)
    :param B_c: number of columns in the flash block (sparse part only)
    :return: tuple of (output, l), each the dense and sparse results
             concatenated along dim 2.
    """
    _, _, q_len, _ = query_states.size()
    _, _, k_len, _ = value_states.size()
    _, _, _, top_k = sparse_indices.size()

    assert end_sink_tokens >= block_size_q
    q_block_offset, q_blocks, q_start_padding, _, _ = calc_dims(
        q_len, k_len, block_size_q, block_size_k, query_offset
    )

    # Index of the first query block handled sparsely, clamped to [0, q_blocks].
    attended_tokens = start_sink_tokens + top_k * block_size_k + end_sink_tokens
    first_sparse_block = triton.cdiv(attended_tokens, block_size_q) - 1 - q_block_offset
    first_sparse_block = max(0, min(q_blocks, first_sparse_block))
    assert sparse_indices.size(2) == q_blocks - first_sparse_block

    # Convert the block boundary to a row index into query_states.
    cutoff = max(0, first_sparse_block * block_size_q - q_start_padding)

    # Perform dense attention for the initial part; the dense kernel is given
    # the queries with an extra axis at dim 3, which is squeezed out afterwards.
    dense_queries = query_states[:, :, :cutoff].unsqueeze(3)
    dense_out, dense_l, _ = attn_flash_triton(
        dense_queries, key_states, value_states,
        B_r=32, B_c=32, query_offset=query_offset,
    )
    dense_out = dense_out.squeeze(3)
    dense_l = dense_l.squeeze(3)

    # Sparse attention for the remainder, whose first query sits `cutoff` rows later.
    sparse_queries = query_states[:, :, cutoff:]
    sparse_out, sparse_l, _ = sparse_attn_triton_impl(
        sparse_queries, key_states, sparse_indices, value_states,
        block_size_q, block_size_k, cutoff + query_offset,
        start_sink_tokens, end_sink_tokens,
        B_r, B_c,
    )

    merged_out = torch.cat([dense_out, sparse_out], dim=2)
    merged_l = torch.cat([dense_l, sparse_l], dim=2)
    return merged_out, merged_l
