from importlib.metadata import version
import warnings
import transformers



def replace_llama(method):
    """Monkey-patch the Hugging Face LLaMA classes for the given method.

    Args:
        method: Name of the selected method. ``"PyramidKV"`` replaces
            ``LlamaFlashAttention2.forward``; any value other than
            ``"FullKV"`` additionally replaces
            ``LlamaForCausalLM.prepare_inputs_for_generation``.

    ``LlamaModel.forward`` is replaced unconditionally, whatever ``method``
    is. The patches are applied to the classes themselves, so they affect
    every instance in the current process.
    """
    # NOTE(review): the replacement callables used below are expected to be
    # defined or imported elsewhere in this module -- confirm they are in scope.
    llama = transformers.models.llama.modeling_llama

    # This must be a plain ``if``: it is the first branch of the chain, and a
    # leading ``elif`` is a SyntaxError that prevents the module from loading.
    if method == "PyramidKV":
        llama.LlamaFlashAttention2.forward = llama_flash_attn2_forward_pyramid

    # Every method except "FullKV" gets the custom generation-input hook.
    if method not in ["FullKV"]:
        llama.LlamaForCausalLM.prepare_inputs_for_generation = prepare_inputs_for_generation_llama

    llama.LlamaModel.forward = llama_model_forward
    

def replace_mistral(method):
    """Monkey-patch the Hugging Face Mistral classes for the given method.

    Args:
        method: Name of the selected method. ``"PyramidKV"`` replaces
            ``MistralFlashAttention2.forward``; any value other than
            ``"FullKV"`` additionally replaces
            ``MistralForCausalLM.prepare_inputs_for_generation``.

    ``MistralModel.forward`` is replaced unconditionally, whatever ``method``
    is. The patches are applied to the classes themselves, so they affect
    every instance in the current process.
    """
    # NOTE(review): the replacement callables used below are expected to be
    # defined or imported elsewhere in this module -- confirm they are in scope.
    mistral = transformers.models.mistral.modeling_mistral

    # The two checks are deliberately independent. Chaining the "PyramidKV"
    # test as an ``elif`` of the "not FullKV" test would make it unreachable,
    # because "PyramidKV" already satisfies ``method not in ["FullKV"]``.
    if method == "PyramidKV":
        mistral.MistralFlashAttention2.forward = mistral_flash_attn2_forward_pyramid

    # Every method except "FullKV" gets the custom generation-input hook.
    if method not in ["FullKV"]:
        mistral.MistralForCausalLM.prepare_inputs_for_generation = prepare_inputs_for_generation_mistral

    mistral.MistralModel.forward = mistral_model_forward
