from importlib.metadata import version
import transformers

from restkv.llama_model import llama_flash_attn2_forward_restkv
from restkv.llama_model import llama_attn_forward_restkv
from restkv.llama_model import llama_sdpa_attn_forward_restkv

from restkv.mistral_model import mistral_flash_attn2_forward_restkv
from restkv.mistral_model import mistral_attn_forward_restkv
from restkv.mistral_model import mistral_sdpa_attn_forward_restkv

from restkv.qwen_model import qwen_flash_attn2_forward_restkv
from restkv.gemma_model import gemma_flash_attn2_forward_restkv

from restkv.llama_model import prepare_inputs_for_generation_llama_new
from restkv.mistral_model import prepare_inputs_for_generation_mistral_new
from restkv.qwen_model import prepare_inputs_for_generation_qwen_new
from restkv.gemma_model import prepare_inputs_for_generation_gemma_new

def replace_llama(method, model_name=None):
    """Patch the ``transformers`` Llama classes for the given KV-cache method.

    Args:
        method: ``"fullkv"`` leaves every class untouched. ``"restkv"`` prints
            a notice and replaces ``forward`` on ``LlamaAttention``,
            ``LlamaFlashAttention2`` and ``LlamaSdpaAttention`` with the
            RestKV implementations. Every value other than ``"fullkv"``
            (including ``"restkv"``) also installs
            ``prepare_inputs_for_generation_llama_new`` on ``LlamaForCausalLM``.
        model_name: Not used by this function.
    """
    if method == "fullkv":
        # Baseline: keep the stock transformers implementation.
        return

    if method == "restkv":
        print("Using restkv!")
        modeling = transformers.models.llama.modeling_llama
        modeling.LlamaAttention.forward = llama_attn_forward_restkv
        modeling.LlamaFlashAttention2.forward = llama_flash_attn2_forward_restkv
        modeling.LlamaSdpaAttention.forward = llama_sdpa_attn_forward_restkv

    causal_lm = transformers.models.llama.modeling_llama.LlamaForCausalLM
    causal_lm.prepare_inputs_for_generation = prepare_inputs_for_generation_llama_new


def replace_qwen(method, model_name=None):
    """Patch the ``transformers`` Qwen2 classes for the given KV-cache method.

    Args:
        method: ``"fullkv"`` leaves every class untouched. ``"restkv"`` prints
            a notice and replaces ``Qwen2FlashAttention2.forward`` with
            ``qwen_flash_attn2_forward_restkv``. Every value other than
            ``"fullkv"`` (including ``"restkv"``) also installs
            ``prepare_inputs_for_generation_qwen_new`` on ``Qwen2ForCausalLM``.
        model_name: Not used by this function.
    """
    if method == "fullkv":
        # Baseline: keep the stock transformers implementation.
        return

    if method == "restkv":
        print("Using restkv!")
        qwen2 = transformers.models.qwen2.modeling_qwen2
        qwen2.Qwen2FlashAttention2.forward = qwen_flash_attn2_forward_restkv

    causal_lm = transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM
    causal_lm.prepare_inputs_for_generation = prepare_inputs_for_generation_qwen_new
        
def replace_gemma(method, model_name=None):
    """Patch the ``transformers`` Gemma classes for the given KV-cache method.

    Args:
        method: ``"fullkv"`` leaves every class untouched. ``"restkv"`` prints
            a notice and replaces ``GemmaFlashAttention2.forward`` with
            ``gemma_flash_attn2_forward_restkv``. Every value other than
            ``"fullkv"`` (including ``"restkv"``) also installs
            ``prepare_inputs_for_generation_gemma_new`` on ``GemmaForCausalLM``.
        model_name: Not used by this function.
    """
    if method == "fullkv":
        # Baseline: keep the stock transformers implementation.
        return

    if method == "restkv":
        print("Using restkv!")
        gemma = transformers.models.gemma.modeling_gemma
        gemma.GemmaFlashAttention2.forward = gemma_flash_attn2_forward_restkv

    causal_lm = transformers.models.gemma.modeling_gemma.GemmaForCausalLM
    causal_lm.prepare_inputs_for_generation = prepare_inputs_for_generation_gemma_new

def replace_mistral(method, model_name=None):
    """Patch the ``transformers`` Mistral classes for the given KV-cache method.

    Args:
        method: ``"fullkv"`` leaves every class untouched. ``"restkv"`` prints
            a notice and replaces ``forward`` on ``MistralAttention``,
            ``MistralFlashAttention2`` and ``MistralSdpaAttention`` with the
            RestKV implementations. Every value other than ``"fullkv"``
            (including ``"restkv"``) also installs
            ``prepare_inputs_for_generation_mistral_new`` on
            ``MistralForCausalLM``.
        model_name: Not used by this function. Accepted (defaulting to
            ``None``) so that this function shares the
            ``(method, model_name=None)`` signature of ``replace_llama``,
            ``replace_qwen`` and ``replace_gemma``, and callers can dispatch
            to any of them uniformly.
    """
    if method == "restkv":
        print("Using restkv!")
        modeling = transformers.models.mistral.modeling_mistral
        modeling.MistralAttention.forward = mistral_attn_forward_restkv
        modeling.MistralFlashAttention2.forward = mistral_flash_attn2_forward_restkv
        modeling.MistralSdpaAttention.forward = mistral_sdpa_attn_forward_restkv

    # "fullkv" is the unmodified baseline; every other method needs the
    # custom generation-input preparation.
    if method != "fullkv":
        causal_lm = transformers.models.mistral.modeling_mistral.MistralForCausalLM
        causal_lm.prepare_inputs_for_generation = prepare_inputs_for_generation_mistral_new
