from functools import partial

import torch
from torch import Tensor
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


@torch.compile
def dot_product_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    mask: Tensor,
    ssmax: Tensor | None = None,
    bias: Tensor | None = None,
) -> Tensor:
    """Attention with optional scalable-softmax and ALiBi bias.

    ---
    Args:
        query: Query tensor.
            Shape of [batch_size, n_heads, n_target, embed_dim].
        key: Key tensor.
            Shape of [batch_size, n_heads, n_source, embed_dim].
        value: Value tensor.
            Shape of [batch_size, n_heads, n_source, embed_dim].
        mask: Attention mask.
            Shape of [batch_size, n_target, n_source].
        ssmax: Scalable-softmax scalings.
            Shape of [batch_size, n_target].
        bias: ALiBi bias to apply to the QK^T scores.
            Shape of [batch_size, n_heads, n_target, n_source].

    ---
    Returns:
        The dot-product attention.
            Shape of [batch_size, n_heads, n_target, embed_dim].

    ---
    Note:
        `embed_dim` must be a power of 2.
        `n_target` and `n_source` must be a multiple of 128.
    """
    batch_size, _, n_target, _ = query.shape
    _, _, n_source, _ = key.shape

    def alibi_mod(bias: Tensor, s: Tensor, b: Tensor, h: Tensor, q: Tensor, k: Tensor) -> Tensor:
        return s + bias[b, h, q, k]

    def mask_mod(mask: Tensor, b: Tensor, h: Tensor, q: Tensor, k: Tensor) -> Tensor:
        return mask[b, q, k]

    # It seems like torch compiler don't like if more than one tensor is closed in the score_mod
    # function. The scalable-softmax can be applied outside the attention so we profit from this and
    # do the rescaling there.
    if ssmax is not None:
        query = torch.einsum("bhle,bl->bhle", query, ssmax)

    if ssmax is not None and bias is not None:
        bias = torch.einsum("bhls,bl->bhls", bias, ssmax)

    score_mod = partial(alibi_mod, bias) if bias is not None else None
    block_mask = create_block_mask(
        partial(mask_mod, mask),
        B=batch_size,
        H=None,
        Q_LEN=n_target,
        KV_LEN=n_source,
        device=query.device,
    )
    return flex_attention(query, key, value, score_mod, block_mask)


def chunked_attention(
    chunk_size: int,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    mask: Tensor,
    ssmax: Tensor | None = None,
    bias: Tensor | None = None,
):
    """Split the attention computation into multiple chunks to save intermediate memory.

    ---
    Args:
        chunk_size: Size of the chunks. Must be positive.
        query: Query tensor.
            Shape of [batch_size, n_heads, n_target, embed_dim].
        key: Key tensor.
            Shape of [batch_size, n_heads, n_source, embed_dim].
        value: Value tensor.
            Shape of [batch_size, n_heads, n_source, embed_dim].
        mask: Attention mask.
            Shape of [batch_size, n_target, n_source].
        ssmax: Scalable-softmax scalings.
            Shape of [batch_size, n_target].
        bias: ALiBi bias to apply to the QK^T scores.
            Shape of [batch_size, n_heads, n_target, n_source].

    ---
    Returns:
        The dot-product attention.
            Shape of [batch_size, n_heads, n_target, embed_dim].

    ---
    Note:
        `embed_dim` must be a power of 2.
        `n_target` and `n_source` must be a multiple of 128.
    """
    assert chunk_size > 0

    _, _, seq_len, _ = query.shape
    out = torch.empty(query.shape, device=query.device, dtype=query.dtype)

    for start in range(0, seq_len, chunk_size):
        end = min(start + chunk_size, seq_len)
        query_chunk = query[:, :, start:end, :]
        mask_chunk = mask[:, start:end, :]
        ssmax_chunk = ssmax[:, start:end] if ssmax is not None else None
        bias_chunk = bias[:, :, start:end, :] if bias is not None else None
        out[:, :, start:end, :] = dot_product_attention(
            query_chunk, key, value, mask_chunk, ssmax_chunk, bias_chunk
        )

    return out
