import types

from .llama import llama_attention_forward as llama_attention_optimal_brain_cache 
from .llama import llama_flashattention2_forward as llama_flashattention2_optimal_brain_cache
from .llama import llama_attention_streaming_forward as llama_attention_optimal_brain_cache_streaming
from .llama import llama_attention_forward_autodiff as llama_attention_optimal_brain_cache_autodiff

# from .mistral import mistral_flashattention2_forward as mistral_flashattention2_optimal_brain_cache


from transformers.models.llama.modeling_llama import LlamaAttention, LlamaFlashAttention2
from transformers.models.mistral.modeling_mistral import MistralAttention, MistralFlashAttention2
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2FlashAttention2



def name2hfattention(model_name):
    """Return the Hugging Face attention class that matches ``model_name``.

    Matching is a case-insensitive substring test, tried in the order
    ``llama``, ``mistral``, ``qwen2``; the first hit wins.

    Args:
        model_name: Model identifier string.

    Returns:
        ``LlamaAttention``, ``MistralAttention`` or ``Qwen2Attention``.

    Raises:
        ValueError: If ``model_name`` contains none of the known family names.
    """
    lowered = model_name.lower()

    if 'llama' in lowered:
        return LlamaAttention
    if 'mistral' in lowered:
        return MistralAttention
    if 'qwen2' in lowered:
        return Qwen2Attention

    raise ValueError(f"Unsupported model name: {model_name}")


def name2modified_attention(model_name):
    """Return the replacement attention forward for ``model_name``.

    Matching is a case-insensitive substring test. Names containing ``llama``
    or ``qwen2`` both resolve to the Llama replacement forward (these two are
    checked before ``mistral``). If the name also contains ``autodiff`` the
    autodiff variant is returned instead of the regular one.

    Args:
        model_name: Model identifier string.

    Returns:
        ``llama_attention_optimal_brain_cache`` or
        ``llama_attention_optimal_brain_cache_autodiff``.

    Raises:
        NotImplementedError: If the name matches ``mistral`` (and neither
            ``llama`` nor ``qwen2``).
        ValueError: If the name matches no known family.
    """
    lowered = model_name.lower()

    # Llama and Qwen2 share the same replacement forwards.
    if 'llama' in lowered or 'qwen2' in lowered:
        if 'autodiff' in lowered:
            return llama_attention_optimal_brain_cache_autodiff
        return llama_attention_optimal_brain_cache

    if 'mistral' in lowered:
        raise NotImplementedError

    raise ValueError(f"Unsupported model name: {model_name}")


def name2hfflashattention(model_name):
    """Return the Hugging Face FlashAttention-2 class matching ``model_name``.

    Matching is a case-insensitive substring test, tried in the order
    ``llama``, ``mistral``, ``qwen2``; the first hit wins.

    Args:
        model_name: Model identifier string.

    Returns:
        ``LlamaFlashAttention2``, ``MistralFlashAttention2`` or
        ``Qwen2FlashAttention2``.

    Raises:
        ValueError: If ``model_name`` contains none of the known family names.
    """
    lowered = model_name.lower()

    if 'llama' in lowered:
        return LlamaFlashAttention2
    if 'mistral' in lowered:
        return MistralFlashAttention2
    if 'qwen2' in lowered:
        return Qwen2FlashAttention2

    raise ValueError(f"Unsupported model name: {model_name}")


def name2modified_flashattention(model_name):
    """Return the replacement FlashAttention-2 forward for ``model_name``.

    Matching is a case-insensitive substring test, tried in the order
    ``llama``, ``mistral``, ``qwen2``. Llama and Qwen2 both resolve to the
    Llama replacement forward.

    Args:
        model_name: Model identifier string.

    Returns:
        ``llama_flashattention2_optimal_brain_cache`` for Llama and Qwen2
        names.

    Raises:
        NotImplementedError: If the name matches ``mistral``; no Mistral
            replacement forward is imported in this module.
        ValueError: If the name matches no known family.
    """
    lowered = model_name.lower()

    if 'llama' in lowered:
        return llama_flashattention2_optimal_brain_cache
    if 'mistral' in lowered:
        # The Mistral-specific forward is not imported here, so returning it
        # would fail with NameError. Fail explicitly instead, consistent with
        # how name2modified_attention treats Mistral.
        raise NotImplementedError(
            f"Modified FlashAttention-2 forward is not available for Mistral models: {model_name}"
        )
    if 'qwen2' in lowered:
        return llama_flashattention2_optimal_brain_cache

    raise ValueError(f"Unsupported model name: {model_name}")


def enable_optimal_brain_kv(model, model_name):
    """Rebind ``forward`` on every matching attention submodule of ``model``.

    Walks ``model._modules`` in reverse registration order, recursing into any
    child that itself has children. Each child that is an instance of the
    class returned by ``name2hfattention(model_name)`` gets its ``forward``
    replaced by the function returned by ``name2modified_attention(model_name)``,
    bound to that child. The model is modified in place.

    Args:
        model: Object exposing a ``_modules`` mapping of child modules
            (presumably a ``torch.nn.Module`` — confirm against callers).
        model_name: Model identifier string used to pick the attention class
            and the replacement forward.

    Raises:
        ValueError: Propagated from the name lookups for unsupported names.
        NotImplementedError: Propagated from ``name2modified_attention``.
    """
    for child in reversed(model._modules.values()):
        # A child slot may hold None; there is nothing to recurse into or
        # patch, and calling .children() on it would raise AttributeError.
        if child is None:
            continue

        # Recurse first so nested attention layers are handled too.
        if any(True for _ in child.children()):
            enable_optimal_brain_kv(child, model_name)

        if isinstance(child, name2hfattention(model_name)):
            # Bind the replacement as an instance method so it receives the
            # attention module as ``self``.
            child.forward = types.MethodType(
                name2modified_attention(model_name), child
            )


def enable_optimal_brain_kv_flashattn2(model, model_name):
    """Rebind ``forward`` on every matching FlashAttention-2 submodule.

    Walks ``model._modules`` in reverse registration order, recursing into any
    child that itself has children. Each child that is an instance of the
    class returned by ``name2hfflashattention(model_name)`` gets its
    ``forward`` replaced by the function returned by
    ``name2modified_flashattention(model_name)``, bound to that child. The
    model is modified in place.

    Args:
        model: Object exposing a ``_modules`` mapping of child modules
            (presumably a ``torch.nn.Module`` — confirm against callers).
        model_name: Model identifier string used to pick the attention class
            and the replacement forward.

    Raises:
        ValueError: Propagated from the name lookups for unsupported names.
    """
    for child in reversed(model._modules.values()):
        # A child slot may hold None; there is nothing to recurse into or
        # patch, and calling .children() on it would raise AttributeError.
        if child is None:
            continue

        # Recurse first so nested attention layers are handled too.
        if any(True for _ in child.children()):
            enable_optimal_brain_kv_flashattn2(child, model_name)

        if isinstance(child, name2hfflashattention(model_name)):
            # Bind the replacement as an instance method so it receives the
            # attention module as ``self``.
            child.forward = types.MethodType(
                name2modified_flashattention(model_name), child
            )


def enable_optimal_brain_kv_streamingattn(model, model_name):
    """Rebind ``forward`` on matching attention submodules to the streaming variant.

    Walks ``model._modules`` in reverse registration order, recursing into any
    child that itself has children. Each child that is an instance of the
    class returned by ``name2hfattention(model_name)`` gets its ``forward``
    replaced by ``llama_attention_optimal_brain_cache_streaming``, bound to
    that child. The model is modified in place.

    NOTE(review): the Llama streaming forward is installed for every supported
    family, including Mistral and Qwen2 — confirm this is intended.

    Args:
        model: Object exposing a ``_modules`` mapping of child modules
            (presumably a ``torch.nn.Module`` — confirm against callers).
        model_name: Model identifier string used to pick the attention class.

    Raises:
        ValueError: Propagated from ``name2hfattention`` for unsupported names.
    """
    for child in reversed(model._modules.values()):
        # A child slot may hold None; there is nothing to recurse into or
        # patch, and calling .children() on it would raise AttributeError.
        if child is None:
            continue

        # Recurse first so nested attention layers are handled too.
        if any(True for _ in child.children()):
            enable_optimal_brain_kv_streamingattn(child, model_name)

        if isinstance(child, name2hfattention(model_name)):
            # Bind the replacement as an instance method so it receives the
            # attention module as ``self``.
            child.forward = types.MethodType(
                llama_attention_optimal_brain_cache_streaming, child
            )