# Attention backend built on Tri Dao's FlashAttention (the optional `flash_attn` package, imported lazily inside the function).
import torch


def attention_computation_flash(q, k, v, mask=None, center=False, debias=False):
    """Compute causal self-attention with Tri Dao's ``flash_attn_func``.

    Args:
        q, k, v: Tensors laid out as (B, S, nh, hs). This is the layout
            ``flash_attn_func`` consumes directly, so no transpose happens
            before the kernel call.
        mask: Ignored. ``flash_attn_func`` takes no arbitrary mask; the call
            below always applies a causal mask. Presumably kept so the
            signature matches other attention backends -- confirm with callers.
        center: If True, add ``v`` (transposed to (B, nh, S, hs)) to the
            attention output.
        debias: If True, subtract from each position ``t`` the running mean of
            the output over positions ``0..t`` (inclusive) along the sequence
            axis. The mean is accumulated in float32 and the result is cast
            back to the output dtype.

    Returns:
        Tensor of shape (B, nh, S, hs).

    Raises:
        ImportError: If the ``flash_attn`` package cannot be imported.
    """
    try:
        from flash_attn import flash_attn_func  # type: ignore
    except ImportError as err:
        raise ImportError(
            "This backend can only be used if you managed to compile the flash_attn repo on your machine."
        ) from err

    # Unpacking also asserts q is 4-D; only the sequence length is needed.
    _, S, _, _ = q.shape

    # Pass (B, S, nh, hs) straight through. Transposing to (B, nh, S, hs)
    # first, as torch SDPA requires, would make flash_attn_func attend over
    # the head axis instead of the sequence axis.
    y = flash_attn_func(
        q,
        k,
        v,
        dropout_p=0.0,
        softmax_scale=None,  # None -> flash_attn default of 1/sqrt(hs)
        causal=True,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=False,
    )
    # flash_attn_func returns (B, S, nh, hs); convert to the (B, nh, S, hs)
    # layout this backend returns.
    y = y.transpose(1, 2)

    if center:
        y = y + v.transpose(1, 2)

    if debias:
        out_dtype = y.dtype
        # Position t has t + 1 elements in its running sum, so divide by
        # 1..S (not 0..S-1, which divides by zero at t = 0). Shape (S, 1)
        # makes the divisor broadcast along the sequence axis, not along hs.
        counts = torch.arange(1, S + 1, device=y.device, dtype=torch.float32).view(S, 1)
        y_f = y.float()
        y = (y_f - y_f.cumsum(dim=2) / counts).to(out_dtype)

    return y
