import math
from typing import Any, Optional

import torch
import triton
import triton.language as tl

from ..utils.set_device import SetDevice


@triton.jit
def flash_attn_bwd_kernel(
        q_buffer, q_stride_n, q_stride_h, q_stride_i, q_stride_e, q_stride_d,
        # (bsz, num_heads, q_len, num_extra_tokens, head_dim)
        k_buffer, k_stride_n, k_stride_h, k_stride_i, k_stride_d,  # (bsz, num_heads, k_len, head_dim)
        ka_buffer, ka_stride_n, ka_stride_h, ka_stride_i, ka_stride_e, ka_stride_d,
        # (bsz, num_heads, q_len, num_extra_tokens, head_dim)
        v_buffer, v_stride_n, v_stride_h, v_stride_i, v_stride_v,  # (bsz, num_heads, k_len, value_dim)
        va_buffer, va_stride_n, va_stride_h, va_stride_i, va_stride_e, va_stride_v,
        # (bsz, num_heads, q_len, num_extra_tokens, value_dim)
        do_buffer, do_stride_n, do_stride_h, do_stride_i, do_stride_e, do_stride_v,
        # (bsz, num_heads, q_len, num_extra_tokens, value_dim)
        l_buffer, l_stride_n, l_stride_h, l_stride_i, l_stride_e,  # (bsz, num_heads, q_len, num_extra_tokens)
        d_buffer, d_stride_n, d_stride_h, d_stride_i, d_stride_e,  # (bsz, num_heads, q_len, num_extra_tokens)
        dq_buffer, dq_stride_n, dq_stride_h, dq_stride_i, dq_stride_e, dq_stride_d,
        # (bsz, num_heads, q_len, num_extra_tokens, head_dim)
        dk_buffer, dk_stride_n, dk_stride_h, dk_stride_i, dk_stride_d,  # (bsz, num_heads, k_len, head_dim)
        dka_buffer, dka_stride_n, dka_stride_h, dka_stride_i, dka_stride_e, dka_stride_d,
        # (bsz, num_heads, q_len, num_extra_tokens, head_dim)
        dv_buffer, dv_stride_n, dv_stride_h, dv_stride_i, dv_stride_v,  # (bsz, num_heads, k_len, value_dim)
        dva_buffer, dva_stride_n, dva_stride_h, dva_stride_i, dva_stride_e, dva_stride_v,
        # (bsz, num_heads, q_len, num_extra_tokens, value_dim)
        mult_factor: float,
        q_len, k_len, head_dim: tl.constexpr, value_dim: tl.constexpr, num_extra_tokens: tl.constexpr,
        B_r: tl.constexpr, B_c: tl.constexpr, query_offset, begin_offset) -> Any:
    """Flash-attention backward pass for one block of key/value rows.

    Program ids select the batch index ``n`` (axis 0), the head ``h`` (axis 1)
    and the key block ``j`` (axis 2).  The block covers key rows
    ``[j * B_c, min(k_len, j * B_c + B_c))``.

    Pass 1 (regular keys/values): for every extra-token slot ``e`` and every
    query row block ``i`` the scores ``S = (Q * mult_factor) @ K^T`` are
    recomputed, masked with ``k_position <= q_position`` and turned into
    ``P = exp(S - L)``.  The kernel accumulates ``dV += P^T @ dO`` and
    ``dK += dS^T @ (Q * mult_factor)`` with ``dS = P * (dO @ V^T - D)``, and adds
    ``(dS @ K) * mult_factor`` into ``dq_buffer`` via ``tl.atomic_add``.
    ``dK``/``dV`` are written to ``dk_buffer``/``dv_buffer`` after both loops.

    Pass 2 (only when ``ka_buffer is not None``): the same computation is
    repeated per slot ``e`` with keys/values read from ``ka_buffer``/``va_buffer``
    at ``(row, e)`` and the mask ``k_position + 1 == q_position``.  The results
    are written to ``dka_buffer``/``dva_buffer`` and dQ is again updated
    atomically.

    Tensor layouts are the ones given by the comments in the parameter list.
    Each ``*_stride_*`` argument is the stride of the named buffer along the
    named axis (n=batch, h=head, i=row, e=extra-token slot, d=head_dim,
    v=value_dim); the launcher in this file passes ``tensor.stride(axis)``.

    ``l_buffer`` holds the per-row value subtracted from the scores before the
    exponential (presumably the log-sum-exp saved by the forward pass -- TODO
    confirm against the forward kernel).  ``d_buffer`` holds the per-row sum of
    ``grad_output * output`` over the value dimension, as computed by the
    launcher in this file.

    All block math is carried out in ``tl.bfloat16`` (see ``dtype`` below), apart
    from the exponential, which is evaluated in float32.
    """

    n = tl.program_id(0)
    h = tl.program_id(1)
    j = tl.program_id(2)

    T_r = tl.cdiv(q_len, B_r)  # number of query row blocks to iterate over
    T_c = tl.cdiv(k_len, B_c)  # NOTE(review): computed but not used below

    # Column offsets inside a row for the key/query and value feature axes.
    d = tl.arange(0, head_dim)
    v = tl.arange(0, value_dim)

    # Key rows handled by this program; the last block may be partial, so every
    # load/store on these rows is masked with `c_indices < c_block_end`.
    c_block_begin = j * B_c
    c_block_end = min(k_len, c_block_begin + B_c)
    c_indices = c_block_begin + tl.arange(0, B_c)

    # Compute dtype for every tile, independent of the dtype of the buffers.
    dtype = tl.bfloat16

    K_j = tl.load(
        k_buffer
        + n * k_stride_n
        + h * k_stride_h
        + c_indices[:, None] * k_stride_i
        + d[None, :] * k_stride_d,
        mask=(c_indices[:, None] < c_block_end),
        other=0.0
    ).to(dtype)

    V_j = tl.load(
        v_buffer
        + n * v_stride_n
        + h * v_stride_h
        + c_indices[:, None] * v_stride_i
        + v[None, :] * v_stride_v,
        mask=(c_indices[:, None] < c_block_end),
        other=0.0
    ).to(dtype)

    # Gradient accumulators for this key block, summed over all slots `e` and
    # all query row blocks `i`.
    # NOTE(review): these accumulate in bf16; a float32 accumulator may be more
    # accurate for long sequences -- confirm whether this is intentional.
    dK_j = tl.zeros((B_c, head_dim), dtype=dtype)
    dV_j = tl.zeros((B_c, value_dim), dtype=dtype)

    for e in range(num_extra_tokens):
        for i in range(T_r):
            # Query rows of this block; the last block may be partial.
            r_block_begin = i * B_r
            r_block_end = min(q_len, r_block_begin + B_r)
            r_indices = r_block_begin + tl.arange(0, B_r)

            Q_i = tl.load(
                q_buffer
                + n * q_stride_n
                + h * q_stride_h
                + r_indices[:, None] * q_stride_i
                + e * q_stride_e
                + d[None, :] * q_stride_d,
                mask=(r_indices[:, None] < r_block_end),
                other=0.0
            ).to(dtype)
            # Fold the score scale into Q so that S below is already scaled.
            Q_i *= mult_factor.to(dtype)

            dO_i = tl.load(
                do_buffer
                + n * do_stride_n
                + h * do_stride_h
                + r_indices[:, None] * do_stride_i
                + e * do_stride_e
                + v[None, :] * do_stride_v,
                mask=(r_indices[:, None] < r_block_end),
                other=0.0
            ).to(dtype)

            L_i = tl.load(
                l_buffer
                + n * l_stride_n
                + h * l_stride_h
                + r_indices * l_stride_i
                + e * l_stride_e,
                mask=(r_indices < r_block_end),
                other=0.0
            ).to(dtype)

            D_i = tl.load(
                d_buffer
                + n * d_stride_n
                + h * d_stride_h
                + r_indices * d_stride_i
                + e * d_stride_e,
                mask=(r_indices < r_block_end),
                other=0.0
            ).to(dtype)

            S_ij = tl.dot(Q_i, tl.trans(K_j)).to(dtype)  # r x c

            # Keep only (query, key) pairs with key position <= query position;
            # masked scores become -inf so that P is exactly 0 there.
            k_positions = begin_offset + c_indices
            q_positions = begin_offset + query_offset + r_indices
            causal_mask = k_positions[None, :] <= q_positions[:, None]
            S_ij = tl.where(causal_mask, S_ij, tl.full((), float('-inf'), dtype=dtype))

            # The exponential is evaluated in float32 and cast back to bf16.
            P_ij = tl.exp((S_ij - L_i[:, None]).to(tl.float32)).to(dtype)  # r x c
            dV_j += tl.dot(tl.trans(P_ij), dO_i).to(dtype)  # c x v
            dP_ij = tl.dot(dO_i, tl.trans(V_j)).to(dtype)  # r x c
            dS_ij = P_ij * (dP_ij - D_i[:, None])  # r x c
            dQ_i_update = tl.dot(dS_ij, K_j).to(dtype)  # r x d
            # Every key-block program (different j) contributes to the same dQ
            # rows, hence the atomic accumulation. The extra mult_factor undoes
            # the scaling that was folded into Q_i above.
            tl.atomic_add(
                dq_buffer
                + n * dq_stride_n
                + h * dq_stride_h
                + r_indices[:, None] * dq_stride_i
                + e * dq_stride_e
                + d[None, :] * dq_stride_d,
                dQ_i_update * mult_factor.to(dtype),
                mask=(r_indices[:, None] < r_block_end),
            )
            # Q_i is already scaled, so dK picks up mult_factor implicitly.
            dK_j += tl.dot(tl.trans(dS_ij), Q_i).to(dtype)  # c x d

    # Each program owns its key rows exclusively, so plain masked stores suffice.
    tl.store(
        dk_buffer
        + n * dk_stride_n
        + h * dk_stride_h
        + c_indices[:, None] * dk_stride_i
        + d[None, :] * dk_stride_d,
        dK_j,
        mask=(c_indices[:, None] < c_block_end),
    )
    tl.store(
        dv_buffer
        + n * dv_stride_n
        + h * dv_stride_h
        + c_indices[:, None] * dv_stride_i
        + v[None, :] * dv_stride_v,
        dV_j,
        mask=(c_indices[:, None] < c_block_end),
    )

    # Extra tokens
    # Second pass over the per-slot extra keys/values. K_j, V_j, dK_j and dV_j
    # are reused as names, but here they are specific to one slot `e`.
    # NOTE(review): the extra buffers are indexed with the key-block rows
    # (c_indices, bounded by k_len), while the signature comments give them
    # q_len rows -- confirm that k_len <= q_len (or the real row count) for
    # callers that pass extra keys/values.
    if ka_buffer is not None:
        for e in range(num_extra_tokens):
            K_j = tl.load(
                ka_buffer
                + n * ka_stride_n
                + h * ka_stride_h
                + c_indices[:, None] * ka_stride_i
                + e * ka_stride_e
                + d[None, :] * ka_stride_d,
                mask=(c_indices[:, None] < c_block_end),
                other=0.0
            ).to(dtype)

            V_j = tl.load(
                va_buffer
                + n * va_stride_n
                + h * va_stride_h
                + c_indices[:, None] * va_stride_i
                + e * va_stride_e
                + v[None, :] * va_stride_v,
                mask=(c_indices[:, None] < c_block_end),
                other=0.0
            ).to(dtype)

            # Fresh accumulators per slot: the extra gradients are stored per
            # (row, e) rather than summed over e.
            dK_j = tl.zeros((B_c, head_dim), dtype=dtype)
            dV_j = tl.zeros((B_c, value_dim), dtype=dtype)

            for i in range(T_r):
                r_block_begin = i * B_r
                r_block_end = min(q_len, r_block_begin + B_r)
                r_indices = r_block_begin + tl.arange(0, B_r)

                Q_i = tl.load(
                    q_buffer
                    + n * q_stride_n
                    + h * q_stride_h
                    + r_indices[:, None] * q_stride_i
                    + e * q_stride_e
                    + d[None, :] * q_stride_d,
                    mask=(r_indices[:, None] < r_block_end),
                    other=0.0
                ).to(dtype)
                # Fold the score scale into Q, as in the first pass.
                Q_i *= mult_factor.to(dtype)

                dO_i = tl.load(
                    do_buffer
                    + n * do_stride_n
                    + h * do_stride_h
                    + r_indices[:, None] * do_stride_i
                    + e * do_stride_e
                    + v[None, :] * do_stride_v,
                    mask=(r_indices[:, None] < r_block_end),
                    other=0.0
                ).to(dtype)

                L_i = tl.load(
                    l_buffer
                    + n * l_stride_n
                    + h * l_stride_h
                    + r_indices * l_stride_i
                    + e * l_stride_e,
                    mask=(r_indices < r_block_end),
                    other=0.0
                ).to(dtype)

                D_i = tl.load(
                    d_buffer
                    + n * d_stride_n
                    + h * d_stride_h
                    + r_indices * d_stride_i
                    + e * d_stride_e,
                    mask=(r_indices < r_block_end),
                    other=0.0
                ).to(dtype)

                S_ij = tl.dot(Q_i, tl.trans(K_j)).to(dtype)  # r x c

                # An extra key at row c is visible only to the query whose
                # position is exactly one past that key's position.
                k_positions = begin_offset + c_indices + 1
                q_positions = begin_offset + query_offset + r_indices
                causal_mask = k_positions[None, :] == q_positions[:, None]  # NOTE: 'equal to' is correct here
                S_ij = tl.where(causal_mask, S_ij, tl.full((), float('-inf'), dtype=dtype))

                P_ij = tl.exp((S_ij - L_i[:, None]).to(tl.float32)).to(dtype)  # r x c
                dV_j += tl.dot(tl.trans(P_ij), dO_i).to(dtype)  # c x v
                dP_ij = tl.dot(dO_i, tl.trans(V_j)).to(dtype)  # r x c
                dS_ij = P_ij * (dP_ij - D_i[:, None])  # r x c
                dQ_i_update = tl.dot(dS_ij, K_j).to(dtype)  # r x d
                # Adds on top of the dQ contributions of the first pass.
                tl.atomic_add(
                    dq_buffer
                    + n * dq_stride_n
                    + h * dq_stride_h
                    + r_indices[:, None] * dq_stride_i
                    + e * dq_stride_e
                    + d[None, :] * dq_stride_d,
                    dQ_i_update * mult_factor.to(dtype),
                    mask=(r_indices[:, None] < r_block_end),
                )
                dK_j += tl.dot(tl.trans(dS_ij), Q_i).to(dtype)  # c x d

            # Store the gradients of slot `e` for this block of rows.
            tl.store(
                dka_buffer
                + n * dka_stride_n
                + h * dka_stride_h
                + c_indices[:, None] * dka_stride_i
                + e * dka_stride_e
                + d[None, :] * dka_stride_d,
                dK_j,
                mask=(c_indices[:, None] < c_block_end),
            )
            tl.store(
                dva_buffer
                + n * dva_stride_n
                + h * dva_stride_h
                + c_indices[:, None] * dva_stride_i
                + e * dva_stride_e
                + v[None, :] * dva_stride_v,
                dV_j,
                mask=(c_indices[:, None] < c_block_end),
            )


def attn_flash_bwd_triton(
        query_states, key_states, value_states, output, grad_output, L,
        query_offset: int,
        key_states_extra=None, value_states_extra=None,
        begin_offset: int = 0,
        B_r: Optional[int] = None, B_c: Optional[int] = None):
    """
    Flashattention backward pass using Triton
    :param query_states: (bsz, num_heads, q_len, num_extra_tokens, head_dim)
    :param key_states:   (bsz, num_heads, k_len, head_dim)
    :param value_states: (bsz, num_heads, k_len, value_dim)
    :param output: (bsz, num_heads, q_len, num_extra_tokens, value_dim)
    :param grad_output: (bsz, num_heads, q_len, num_extra_tokens, value_dim)
    :param L: (bsz, num_heads, q_len, num_extra_tokens)
    :param query_offset: offset of the query
    :param key_states_extra: (bsz, num_heads, q_len, num_extra_tokens, head_dim), or None
    :param value_states_extra: (bsz, num_heads, q_len, num_extra_tokens, value_dim);
        required whenever key_states_extra is given, ignored otherwise
    :param begin_offset: offset of both the query and key
    :param B_r: number of rows in the flash block (defaults to 64)
    :param B_c: number of columns in the flash block (defaults to 64)
    :return: ((dQ, dK, dV, dK_extra, dV_extra), compile_info) where dQ has the shape
        of query_states, dK/dV have the shapes of key_states/value_states,
        dK_extra/dV_extra are (bsz, num_heads, q_len, num_extra_tokens, head_dim/value_dim)
        or None when key_states_extra is None, and compile_info is the object
        returned by the kernel launch
    :raises ValueError: if key_states_extra is given without value_states_extra
    """
    if B_r is None:
        B_r = 64
    if B_c is None:
        B_c = 64

    assert query_states.ndim == 5
    assert key_states.ndim == 4
    assert value_states.ndim == 4
    assert output.ndim == 5
    assert grad_output.ndim == 5
    assert L.ndim == 4

    # The kernel reads the extra values whenever extra keys are present, so a
    # missing value tensor must be rejected before anything is allocated.
    has_extra = key_states_extra is not None
    if has_extra and value_states_extra is None:
        raise ValueError("value_states_extra must be provided when key_states_extra is given")

    device = query_states.device
    dtype = query_states.dtype

    bsz, num_heads, q_len, num_extra_tokens, head_dim = query_states.size()
    _, _, k_len, value_dim = value_states.size()

    # One kernel program per (batch, head, key block).
    T_c = triton.cdiv(k_len, B_c)

    D = (grad_output * output).sum(dim=-1)  # (bsz, num_heads, q_len, num_extra_tokens)

    q, k, v, do = query_states, key_states, value_states, grad_output

    # dQ is accumulated with atomic adds inside the kernel.
    acc_dtype = dtype
    if dtype == torch.bfloat16:  # workaround for bf16 atomic_add not supported
        acc_dtype = torch.float32
    dq = torch.zeros(bsz, num_heads, q_len, num_extra_tokens, head_dim, device=device, dtype=acc_dtype)
    dk = torch.zeros(bsz, num_heads, k_len, head_dim, device=device, dtype=dtype)
    dv = torch.zeros(bsz, num_heads, k_len, value_dim, device=device, dtype=dtype)

    # Without extra keys/values the kernel receives None for the extra buffers
    # and all of their strides, and skips its second pass.
    ka = va = dka = dva = None
    ka_strides = va_strides = dka_strides = dva_strides = (None,) * 5
    if has_extra:
        ka, va = key_states_extra, value_states_extra
        # NOTE(review): the kernel addresses the extra buffers with key-block rows
        # (bounded by k_len) while these gradients have q_len rows -- confirm that
        # callers using extra keys/values never have k_len > q_len.
        dka = torch.zeros(bsz, num_heads, q_len, num_extra_tokens, head_dim, device=device, dtype=dtype)
        dva = torch.zeros(bsz, num_heads, q_len, num_extra_tokens, value_dim, device=device, dtype=dtype)
        ka_strides, va_strides, dka_strides, dva_strides = (
            tuple(t.stride(axis) for axis in range(5)) for t in (ka, va, dka, dva))

    ka_stride_n, ka_stride_h, ka_stride_i, ka_stride_e, ka_stride_d = ka_strides
    va_stride_n, va_stride_h, va_stride_i, va_stride_e, va_stride_v = va_strides
    dka_stride_n, dka_stride_h, dka_stride_i, dka_stride_e, dka_stride_d = dka_strides
    dva_stride_n, dva_stride_h, dva_stride_i, dva_stride_e, dva_stride_v = dva_strides

    grid = (bsz, num_heads, T_c)
    with SetDevice(device):
        compile_info = flash_attn_bwd_kernel[grid](
            q_buffer=q, q_stride_n=q.stride(0), q_stride_h=q.stride(1), q_stride_i=q.stride(2), q_stride_e=q.stride(3),
            q_stride_d=q.stride(4),

            k_buffer=k, k_stride_n=k.stride(0), k_stride_h=k.stride(1), k_stride_i=k.stride(2), k_stride_d=k.stride(3),
            ka_buffer=ka, ka_stride_n=ka_stride_n, ka_stride_h=ka_stride_h, ka_stride_i=ka_stride_i,
            ka_stride_e=ka_stride_e, ka_stride_d=ka_stride_d,

            v_buffer=v, v_stride_n=v.stride(0), v_stride_h=v.stride(1), v_stride_i=v.stride(2), v_stride_v=v.stride(3),
            va_buffer=va, va_stride_n=va_stride_n, va_stride_h=va_stride_h, va_stride_i=va_stride_i,
            va_stride_e=va_stride_e, va_stride_v=va_stride_v,

            do_buffer=do, do_stride_n=do.stride(0), do_stride_h=do.stride(1), do_stride_i=do.stride(2),
            do_stride_e=do.stride(3), do_stride_v=do.stride(4),

            l_buffer=L, l_stride_n=L.stride(0), l_stride_h=L.stride(1), l_stride_i=L.stride(2), l_stride_e=L.stride(3),
            d_buffer=D, d_stride_n=D.stride(0), d_stride_h=D.stride(1), d_stride_i=D.stride(2), d_stride_e=D.stride(3),

            dq_buffer=dq, dq_stride_n=dq.stride(0), dq_stride_h=dq.stride(1), dq_stride_i=dq.stride(2),
            dq_stride_e=dq.stride(3), dq_stride_d=dq.stride(4),

            dk_buffer=dk, dk_stride_n=dk.stride(0), dk_stride_h=dk.stride(1), dk_stride_i=dk.stride(2),
            dk_stride_d=dk.stride(3),
            dka_buffer=dka, dka_stride_n=dka_stride_n, dka_stride_h=dka_stride_h, dka_stride_i=dka_stride_i,
            dka_stride_e=dka_stride_e, dka_stride_d=dka_stride_d,

            dv_buffer=dv, dv_stride_n=dv.stride(0), dv_stride_h=dv.stride(1), dv_stride_i=dv.stride(2),
            dv_stride_v=dv.stride(3),
            dva_buffer=dva, dva_stride_n=dva_stride_n, dva_stride_h=dva_stride_h, dva_stride_i=dva_stride_i,
            dva_stride_e=dva_stride_e, dva_stride_v=dva_stride_v,

            mult_factor=1.0 / math.sqrt(head_dim),
            q_len=q_len, k_len=k_len, head_dim=head_dim, value_dim=value_dim, num_extra_tokens=num_extra_tokens,
            B_r=B_r, B_c=B_c, query_offset=query_offset, begin_offset=begin_offset,
            num_stages=1,
            num_warps=16,
        )

    # Cast the accumulation buffer back to the input dtype (a no-op unless the
    # float32 workaround above was active).
    dq = dq.to(dtype)

    return (dq, dk, dv, dka, dva), compile_info
