import transformers
from src.llama_model import llama_flash_attn2_forward_SnapKV, llama_model_forward, prepare_inputs_for_generation_llama

def replace_llama():
    """Monkey-patch the Transformers LLaMA classes with this project's overrides.

    Rebinds three class attributes in
    ``transformers.models.llama.modeling_llama`` to the functions imported
    from ``src.llama_model``:

    * ``LlamaModel.forward`` -> ``llama_model_forward``
    * ``LlamaFlashAttention2.forward`` -> ``llama_flash_attn2_forward_SnapKV``
    * ``LlamaForCausalLM.prepare_inputs_for_generation``
      -> ``prepare_inputs_for_generation_llama``

    The assignments are made on the classes themselves, so the change is
    process-wide. Each class is looked up by attribute access, so an
    ``AttributeError`` propagates if the installed ``transformers`` module
    does not expose one of them.

    Returns:
        None.
    """
    # Resolve the target module once instead of repeating the dotted path.
    llama_module = transformers.models.llama.modeling_llama

    llama_module.LlamaModel.forward = llama_model_forward
    # NOTE(review): presumably the SnapKV variant of the flash-attention
    # forward pass, judging by the name; confirm against src.llama_model.
    llama_module.LlamaFlashAttention2.forward = llama_flash_attn2_forward_SnapKV
    llama_module.LlamaForCausalLM.prepare_inputs_for_generation = (
        prepare_inputs_for_generation_llama
    )