from xtuner.parallel.sequence import sequence_parallel_wrapper
from .utils import upad_qkv

SUPPORT_FLASH2 = False

try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import pad_input
    SUPPORT_FLASH2 = True
except ImportError:
    pass


@sequence_parallel_wrapper
def flash_attn_wo_mask(
        query_states,
        key_states,
        value_states,
        dropout_p=0.0,
        softmax_scale=None,
        causal=True,
        window_size=(-1, -1),  # -1 means infinite context window
):
    attn_output = flash_attn_func(
        query_states,
        key_states,
        value_states,
        dropout_p=dropout_p,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size=window_size)
    return attn_output


@sequence_parallel_wrapper
def flash_attn_w_mask(
        query_states,  # bs, q_len, nhead, h_dim
        key_states,
        value_states,
        attention_mask,
        causal=True,
        dropout_p=0.0,
        window_size=(-1, -1),  # -1 means infinite context window
):
    batch_size, q_len = query_states.shape[:2]
    query_states, key_states, value_states, indices_q, \
        cu_seq_lens, max_seq_lens = upad_qkv(
            query_states, key_states, value_states, attention_mask, q_len)

    cu_seqlens_q, cu_seqlens_k = cu_seq_lens
    max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
    attn_output_unpad = flash_attn_varlen_func(
        query_states,
        key_states,
        value_states,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_in_batch_q,
        max_seqlen_k=max_seqlen_in_batch_k,
        dropout_p=dropout_p,
        causal=causal,
        window_size=window_size)
    attn_output = pad_input(attn_output_unpad, indices_q, batch_size, q_len)
    return attn_output


@sequence_parallel_wrapper
def varlen_flash_attn(
        query_states,
        key_states,
        value_states,
        cumulative_len,
        max_seqlen,
        dropout_p=0.,
        causal=True,
        window_size=(-1, -1),  # -1 means infinite context window
):
    q_unpad, k_unpad, v_unpad = query_states.flatten(0, 1), key_states.flatten(
        0, 1), value_states.flatten(0, 1)
    attn_output = flash_attn_varlen_func(
        q_unpad,
        k_unpad,
        v_unpad,
        cumulative_len,
        cumulative_len,
        max_seqlen,
        max_seqlen,
        dropout_p=dropout_p,
        return_attn_probs=False,
        causal=causal,
        window_size=window_size)
    attn_output = attn_output.unsqueeze(0)
    return attn_output
