"""Interface to various attention backends"""

import torch
from functools import partial
from typing import Callable, Optional

from .pytorch import attention_computation_sdpa
from .amd import attention as aotriton_attention
from .openai import attention as triton_tutorial_attention
from .mosaic import flash_attn_func as mpt_7b_attention
from .cuda_flash_attention import attention_computation_flash

"""Notes:

All interfaces standardize on the following input shape:
 q: (batch_size, seqlen_q, nheads, headdim)

However, the openai, sdpa and amd backends convert to
(BATCH, N_HEAD, N_CTX, HEAD_DIM) internally.
"""


def _skip_attention(q, k, v, mask=None):
    """Stand-in backend that bypasses attention for debugging/benchmarking.

    ``q``, ``k`` and ``mask`` are accepted only so the signature lines up with
    the other backends; they are never read. The result is ``v.clone()``.
    """
    passthrough = v.clone()
    return passthrough


def select_attention_implementation(
    provider="sdpa", center=False, debias=False
) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
    """Return the attention callable registered under ``provider``.

    Registered providers and the callable each one yields:

    * ``"sdpa"``: ``attention_computation_sdpa`` with ``center``/``debias`` bound.
    * ``"amd"``: ``aotriton_attention`` with ``causal=True, b=None,
      sm_scale=None, dropout_p=0.0, autotune=True`` bound.
    * ``"openai"``: ``triton_tutorial_attention`` with ``causal=True,
      sm_scale=None`` bound.
    * ``"mosaic"``: ``mpt_7b_attention`` with ``bias=None, causal=True,
      softmax_scale=None`` bound.
    * ``"tridao"``: ``attention_computation_flash`` with ``center``/``debias``
      bound.
    * ``"debug-skip"``: ``_skip_attention`` itself, unwrapped.

    Args:
        provider: Name of the backend to select.
        center: Forwarded to the backend; only accepted for ``"sdpa"``.
        debias: Forwarded to the backend; only accepted for ``"sdpa"``.

    Returns:
        The selected attention callable.

    Raises:
        ValueError: If ``provider`` is not registered, or if ``center`` or
            ``debias`` is truthy for any provider other than ``"sdpa"``.
    """
    # Each entry is a zero-argument factory, so only the selected backend's
    # callable is looked up and bound.
    factories = {
        "sdpa": lambda: partial(attention_computation_sdpa, center=center, debias=debias),
        "amd": lambda: partial(
            aotriton_attention, causal=True, b=None, sm_scale=None, dropout_p=0.0, autotune=True
        ),
        "openai": lambda: partial(triton_tutorial_attention, causal=True, sm_scale=None),
        "mosaic": lambda: partial(mpt_7b_attention, bias=None, causal=True, softmax_scale=None),
        # The guard below rejects truthy center/debias for every non-sdpa
        # provider, so the values forwarded here are always falsy.
        "tridao": lambda: partial(attention_computation_flash, center=center, debias=debias),
        "debug-skip": lambda: _skip_attention,
    }

    # Resolve the provider first so an unknown name is reported as
    # unregistered instead of as a missing centering/debias feature.
    try:
        build = factories[provider]
    except (KeyError, TypeError):  # TypeError: unhashable provider value
        raise ValueError(f"Attention implementation provider {provider} not registered.") from None

    if provider != "sdpa":
        if center:
            raise ValueError("Centering not implemented for this provider.")
        if debias:
            raise ValueError("debias not implemented for this provider.")

    return build()
