import torch

from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

from . import eager_attn
from . import memeffi_attn
from . import flex_attn

# Attention implementations provided by this package, keyed by the name that
# get_attention_interface() accepts. Names found here are returned directly;
# any other name is looked up in transformers' ALL_ATTENTION_FUNCTIONS and
# adapted through get_rnsa_wrapper().
RNSA_ATTENTION_IMPLEMENTATIONS = dict(
    # "rnsa_*" names select the RNSA forward of each backend module.
    rnsa_eager=eager_attn.rnsa_attention_forward,
    rnsa_memeffi=memeffi_attn.rnsa_attention_forward,
    rnsa_flex=flex_attn.rnsa_attention_forward,
    # Bare backend names select the other forward exported by the same module
    # (presumably the non-RNSA baseline — confirm in each backend).
    eager=eager_attn.eager_attention_forward,
    memeffi=memeffi_attn.attention_forward,
    flex=flex_attn.attention_forward,
)


def get_rnsa_wrapper(attn_impl: str):
    """Adapt a transformers attention function to this package's call convention.

    Looks up ``attn_impl`` in transformers' ``ALL_ATTENTION_FUNCTIONS`` and
    returns a wrapper that:

    * removes the ``forget_weights`` keyword argument (if supplied) before
      delegating, so the wrapped function never receives it, and
    * returns ``(attn_output, attn_weights, None)`` in place of the wrapped
      function's two-value result.

    NOTE(review): the trailing ``None`` presumably fills the slot of the extra
    value returned by the ``rnsa_*`` forwards — confirm against callers.

    Args:
        attn_impl: Name of an attention implementation registered in
            ``ALL_ATTENTION_FUNCTIONS``.

    Returns:
        A callable taking the same arguments as the wrapped function (plus an
        optional, ignored ``forget_weights`` keyword) and returning a 3-tuple.

    Raises:
        KeyError: If ``attn_impl`` is not registered in
            ``ALL_ATTENTION_FUNCTIONS``.
    """
    try:
        attn_fn = ALL_ATTENTION_FUNCTIONS[attn_impl]
    except KeyError:
        # Keep the KeyError type callers may rely on, but say what was searched
        # and which package-local implementations exist instead.
        raise KeyError(
            f"Attention implementation {attn_impl!r} is not registered in "
            "transformers' ALL_ATTENTION_FUNCTIONS; implementations provided by "
            f"this package: {sorted(RNSA_ATTENTION_IMPLEMENTATIONS)}"
        ) from None

    def attn_wrapper(*args, **kwargs):
        # `kwargs` is a fresh dict per call, so popping does not affect the caller.
        kwargs.pop("forget_weights", None)
        attn_output, attn_weights = attn_fn(*args, **kwargs)
        return attn_output, attn_weights, None

    return attn_wrapper

def get_attention_interface(attn_impl: str, compile=False):
    """Resolve an attention implementation name to a callable.

    Names present in ``RNSA_ATTENTION_IMPLEMENTATIONS`` are returned as-is.
    Any other name is resolved through ``get_rnsa_wrapper``, which adapts the
    matching entry of transformers' ``ALL_ATTENTION_FUNCTIONS``.

    Args:
        attn_impl: Attention implementation name.
        compile: When truthy, the resolved callable is wrapped with
            ``torch.compile``. (The name shadows the builtin but is kept for
            backward compatibility with keyword callers.)

    Returns:
        The attention callable, optionally compiled.

    Raises:
        KeyError: Propagated from ``get_rnsa_wrapper`` when ``attn_impl`` is
            unknown to both registries.
    """
    if attn_impl in RNSA_ATTENTION_IMPLEMENTATIONS:
        resolved = RNSA_ATTENTION_IMPLEMENTATIONS[attn_impl]
    else:
        resolved = get_rnsa_wrapper(attn_impl)

    return torch.compile(resolved) if compile else resolved
