from importlib.metadata import version
import warnings
import transformers
from score_kv.fixed_mistral_hijack import pyramidkv_mistral_flash_attn2_forward, fixed_mistral_flash_attn2_forward, fixed_MistralModel_forward
from score_kv.fixed_mistral_hijack import prepare_inputs_for_generation_mistral as fixed_prepare_inputs_for_generation_mistral
from score_kv.adaptive_mistral_hijack import reason_mistral_flash_attn2_forward, adaptive_mistral_flash_attn2_forward, adaptive_MistralModel_forward
from score_kv.adaptive_mistral_hijack import prepare_inputs_for_generation_mistral as ada_prepare_inputs_for_generation_mistral

from score_kv.fixed_llama_hijack_score import pyramidkv_llama_flash_attn2_forward_score, fixed_llama_flash_attn2_forward_score, fixed_LlamaModel_forward_score
from score_kv.fixed_llama_hijack_score import prepare_inputs_for_generation_llama_score as fixed_prepare_inputs_for_generation_llama_score
from score_kv.adaptive_llama_hijack_score import reason_llama_flash_attn2_forward_score, adaptive_llama_flash_attn2_forward_score, adaptive_LlamaModel_forward_score
from score_kv.adaptive_llama_hijack_score import prepare_inputs_for_generation_llama_score as ada_prepare_inputs_for_generation_llama_score

from score_kv.fixed_llama_hijack import pyramidkv_llama_flash_attn2_forward, fixed_llama_flash_attn2_forward, fixed_LlamaModel_forward
from score_kv.fixed_llama_hijack import prepare_inputs_for_generation_llama as fixed_prepare_inputs_for_generation_llama
from score_kv.adaptive_llama_hijack import reason_llama_flash_attn2_forward, adaptive_llama_flash_attn2_forward, adaptive_LlamaModel_forward
from score_kv.adaptive_llama_hijack import prepare_inputs_for_generation_llama as ada_prepare_inputs_for_generation_llama


def replace_mistral_fixed():
    """Install the ``score_kv.fixed_mistral_hijack`` overrides on Mistral.

    Rebinds three class attributes inside
    ``transformers.models.mistral.modeling_mistral``; because the classes
    themselves are modified, every instance of them is affected.
    """
    modeling = transformers.models.mistral.modeling_mistral
    modeling.MistralForCausalLM.prepare_inputs_for_generation = (
        fixed_prepare_inputs_for_generation_mistral
    )
    modeling.MistralFlashAttention2.forward = fixed_mistral_flash_attn2_forward
    modeling.MistralModel.forward = fixed_MistralModel_forward

def replace_mistral_adaptive():
    """Install the ``score_kv.adaptive_mistral_hijack`` overrides on Mistral.

    Rebinds three class attributes inside
    ``transformers.models.mistral.modeling_mistral``; because the classes
    themselves are modified, every instance of them is affected.
    """
    modeling = transformers.models.mistral.modeling_mistral
    modeling.MistralForCausalLM.prepare_inputs_for_generation = (
        ada_prepare_inputs_for_generation_mistral
    )
    modeling.MistralFlashAttention2.forward = adaptive_mistral_flash_attn2_forward
    modeling.MistralModel.forward = adaptive_MistralModel_forward

def replace_llama_fixed():
    """Install the ``score_kv.fixed_llama_hijack`` overrides on Llama.

    Rebinds three class attributes inside
    ``transformers.models.llama.modeling_llama``; because the classes
    themselves are modified, every instance of them is affected.
    """
    modeling = transformers.models.llama.modeling_llama
    modeling.LlamaForCausalLM.prepare_inputs_for_generation = (
        fixed_prepare_inputs_for_generation_llama
    )
    modeling.LlamaFlashAttention2.forward = fixed_llama_flash_attn2_forward
    modeling.LlamaModel.forward = fixed_LlamaModel_forward

def replace_llama_adaptive():
    """Install the ``score_kv.adaptive_llama_hijack`` overrides on Llama.

    Rebinds three class attributes inside
    ``transformers.models.llama.modeling_llama``; because the classes
    themselves are modified, every instance of them is affected.
    """
    modeling = transformers.models.llama.modeling_llama
    modeling.LlamaForCausalLM.prepare_inputs_for_generation = (
        ada_prepare_inputs_for_generation_llama
    )
    modeling.LlamaFlashAttention2.forward = adaptive_llama_flash_attn2_forward
    modeling.LlamaModel.forward = adaptive_LlamaModel_forward



def replace_mistral(method):
    """Monkey-patch the transformers Mistral classes for the given method.

    Depending on ``method``, three class attributes in
    ``transformers.models.mistral.modeling_mistral`` are rebound:
    ``MistralForCausalLM.prepare_inputs_for_generation``,
    ``MistralModel.forward`` and ``MistralFlashAttention2.forward``.

    Recognised values of ``method``:

    * ``"AdativeKV"`` (historical spelling, kept for existing callers) or
      ``"AdaptiveKV"`` -- the ``adaptive_*`` functions.
    * ``"ReasonKV"`` -- the adaptive input preparation and model forward,
      with ``reason_mistral_flash_attn2_forward`` as the attention forward.
    * ``"SnapKV"`` -- the ``fixed_*`` functions.
    * ``"PyramidKV"`` -- the fixed input preparation and model forward,
      with ``pyramidkv_mistral_flash_attn2_forward`` as the attention
      forward.

    Any other value leaves transformers untouched and emits a
    ``UserWarning`` so a misspelled method name does not go unnoticed.

    Args:
        method: Name of the method whose overrides should be installed.

    Returns:
        None. The transformers classes are modified in place.
    """
    if method in ("AdativeKV", "AdaptiveKV"):
        prepare_inputs = ada_prepare_inputs_for_generation_mistral
        model_forward = adaptive_MistralModel_forward
        attn_forward = adaptive_mistral_flash_attn2_forward
    elif method == "ReasonKV":
        prepare_inputs = ada_prepare_inputs_for_generation_mistral
        model_forward = adaptive_MistralModel_forward
        attn_forward = reason_mistral_flash_attn2_forward
    elif method == "SnapKV":
        prepare_inputs = fixed_prepare_inputs_for_generation_mistral
        model_forward = fixed_MistralModel_forward
        attn_forward = fixed_mistral_flash_attn2_forward
    elif method == "PyramidKV":
        prepare_inputs = fixed_prepare_inputs_for_generation_mistral
        model_forward = fixed_MistralModel_forward
        attn_forward = pyramidkv_mistral_flash_attn2_forward
    else:
        # Warn instead of raising: callers may rely on unknown names being a
        # no-op, but silently skipping the patch hides typos.
        warnings.warn(
            f"replace_mistral: unrecognised method {method!r}; "
            "transformers Mistral classes were left unpatched.",
            stacklevel=2,
        )
        return

    modeling = transformers.models.mistral.modeling_mistral
    modeling.MistralForCausalLM.prepare_inputs_for_generation = prepare_inputs
    modeling.MistralModel.forward = model_forward
    modeling.MistralFlashAttention2.forward = attn_forward


def replace_llama_score(method):
    """Monkey-patch the transformers Llama classes with the ``*_score`` overrides.

    Depending on ``method``, three class attributes in
    ``transformers.models.llama.modeling_llama`` are rebound:
    ``LlamaForCausalLM.prepare_inputs_for_generation``,
    ``LlamaModel.forward`` and ``LlamaFlashAttention2.forward``.

    Recognised values of ``method``:

    * ``"AdativeKV"`` (historical spelling, kept for existing callers) or
      ``"AdaptiveKV"`` -- the ``adaptive_*_score`` functions.
    * ``"ReasonKV"`` -- the adaptive input preparation and model forward,
      with ``reason_llama_flash_attn2_forward_score`` as the attention
      forward.
    * ``"SnapKV"`` -- the ``fixed_*_score`` functions.
    * ``"PyramidKV"`` -- the fixed input preparation and model forward,
      with ``pyramidkv_llama_flash_attn2_forward_score`` as the attention
      forward.

    Any other value leaves transformers untouched and emits a
    ``UserWarning`` so a misspelled method name does not go unnoticed.

    Args:
        method: Name of the method whose overrides should be installed.

    Returns:
        None. The transformers classes are modified in place.
    """
    if method in ("AdativeKV", "AdaptiveKV"):
        prepare_inputs = ada_prepare_inputs_for_generation_llama_score
        model_forward = adaptive_LlamaModel_forward_score
        attn_forward = adaptive_llama_flash_attn2_forward_score
    elif method == "ReasonKV":
        prepare_inputs = ada_prepare_inputs_for_generation_llama_score
        model_forward = adaptive_LlamaModel_forward_score
        attn_forward = reason_llama_flash_attn2_forward_score
    elif method == "SnapKV":
        prepare_inputs = fixed_prepare_inputs_for_generation_llama_score
        model_forward = fixed_LlamaModel_forward_score
        attn_forward = fixed_llama_flash_attn2_forward_score
    elif method == "PyramidKV":
        prepare_inputs = fixed_prepare_inputs_for_generation_llama_score
        model_forward = fixed_LlamaModel_forward_score
        attn_forward = pyramidkv_llama_flash_attn2_forward_score
    else:
        # Warn instead of raising: callers may rely on unknown names being a
        # no-op, but silently skipping the patch hides typos.
        warnings.warn(
            f"replace_llama_score: unrecognised method {method!r}; "
            "transformers Llama classes were left unpatched.",
            stacklevel=2,
        )
        return

    modeling = transformers.models.llama.modeling_llama
    modeling.LlamaForCausalLM.prepare_inputs_for_generation = prepare_inputs
    modeling.LlamaModel.forward = model_forward
    modeling.LlamaFlashAttention2.forward = attn_forward