import torch
import torch.autograd

from .full_forward import attn_flash_triton
from .full_backward import attn_flash_bwd_triton


class FlashAttention(torch.autograd.Function):
    """Autograd bridge for the ``attn_flash_triton`` / ``attn_flash_bwd_triton`` kernels.

    ``forward`` runs the forward kernel (requesting its auxiliary ``L`` tensor) and
    saves the inputs, the output and ``L`` for ``backward``, which hands them to the
    backward kernel. The optional ``*_extra`` tensors take part only when
    ``key_states_extra`` is not None.
    """

    @staticmethod
    def forward(  # noqa
            ctx, query_states, key_states, value_states,
            query_offset: int,
            key_states_extra=None, value_states_extra=None):
        """Run the forward kernel and stash what ``backward`` needs.

        :param query_offset: plain int, stored on ``ctx`` (not saved as a tensor).
        :return: the attention output produced by ``attn_flash_triton``.
        """
        # ``L`` is an auxiliary tensor returned because of ``return_l=True``
        # (presumably per-row softmax statistics -- TODO confirm); the backward
        # kernel consumes it.
        output, L, _ = attn_flash_triton(
            query_states, key_states, value_states,
            query_offset=query_offset,
            key_states_extra=key_states_extra,
            value_states_extra=value_states_extra,
            return_l=True
        )
        if key_states_extra is not None:
            ctx.save_for_backward(query_states, key_states, value_states,
                                  key_states_extra, value_states_extra, output, L)
        else:
            ctx.save_for_backward(query_states, key_states, value_states, output, L)
        ctx.query_offset = query_offset
        # Records which layout of saved tensors backward must unpack.
        ctx.using_extra = key_states_extra is not None
        return output

    @staticmethod
    # The backward kernel's results are (presumably, being a Triton kernel) not
    # recorded by autograd; without this decorator a ``create_graph=True`` double
    # backward would silently treat them as constants. ``once_differentiable``
    # makes such a second differentiation raise instead of returning wrong grads.
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_output):  # noqa
        """Return gradients in the same order as ``forward``'s inputs.

        The slot for ``query_offset`` is ``None`` (non-tensor input); the extra
        gradients are passed through from ``attn_flash_bwd_triton`` as-is.
        """
        if ctx.using_extra:
            query_states, key_states, value_states, key_states_extra, value_states_extra, output, L = ctx.saved_tensors
        else:
            query_states, key_states, value_states, output, L = ctx.saved_tensors
            key_states_extra = value_states_extra = None
        query_offset = ctx.query_offset

        (grad_query, grad_key, grad_value, grad_key_extra, grad_value_extra), _ = attn_flash_bwd_triton(
            query_states, key_states, value_states,
            output, grad_output, L,
            query_offset=query_offset,
            key_states_extra=key_states_extra,
            value_states_extra=value_states_extra,
        )

        return grad_query, grad_key, grad_value, None, grad_key_extra, grad_value_extra


def flash_attn(query_states, key_states, value_states, query_offset: int,
               key_states_extra=None, value_states_extra=None):
    """
    Flash attention with optional extra key/value tokens.

    :param query_states: (bsz, num_heads, q_len, head_dim)
                         or (bsz, num_heads, q_len, num_extra_tokens, head_dim)
    :param key_states:   (bsz, num_heads, k_len, head_dim)
    :param value_states: (bsz, num_heads, k_len, value_dim)
    :param query_offset: offset of the query
    :param key_states_extra: (bsz, num_heads, q_len, num_extra_tokens, head_dim);
                             must be given together with ``value_states_extra``
    :param value_states_extra: (bsz, num_heads, q_len, num_extra_tokens, value_dim);
                               must be given together with ``key_states_extra``
    :return: (bsz, num_heads, q_len, value_dim) for a 4-D ``query_states``,
             otherwise (bsz, num_heads, q_len, num_extra_tokens, value_dim)
    :raises ValueError: if only one of ``key_states_extra`` / ``value_states_extra``
                        is provided, or if a 5-D ``query_states`` is given without
                        the extra tensors.
    """
    has_key_extra = key_states_extra is not None
    has_value_extra = value_states_extra is not None

    # FlashAttention decides whether to use the extra path from key_states_extra
    # alone and drops value_states_extra in backward when it is None, so a lone
    # extra tensor would be handled inconsistently; reject it up front.
    if has_key_extra != has_value_extra:
        raise ValueError(
            "key_states_extra and value_states_extra must be provided together")

    is_vanilla = query_states.ndim == 4

    if is_vanilla:
        # Insert a singleton axis at dim 3 so the autograd function always
        # receives a 5-D query; it is squeezed back out below.
        query_states = query_states.unsqueeze(3)
    elif not has_key_extra:
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        raise ValueError(
            "a 5-D query_states requires key_states_extra and value_states_extra")

    output = FlashAttention.apply(
        query_states, key_states, value_states, query_offset,
        key_states_extra, value_states_extra)

    if is_vanilla:
        output = output.squeeze(3)

    return output
