from importlib.metadata import version
import warnings
import transformers
from score_kv.fixed_mistral_hijack import pyramidkv_mistral_flash_attn2_forward, fixed_mistral_flash_attn2_forward, fixed_MistralModel_forward
from score_kv.fixed_mistral_hijack import prepare_inputs_for_generation_mistral as fixed_prepare_inputs_for_generation_mistral
from score_kv.adaptive_mistral_hijack import reason_mistral_flash_attn2_forward, adaptive_mistral_flash_attn2_forward, adaptive_MistralModel_forward
from score_kv.adaptive_mistral_hijack import prepare_inputs_for_generation_mistral as ada_prepare_inputs_for_generation_mistral

from score_kv.fixed_llama_hijack_score_lwhw_2 import pyramidkv_llama_flash_attn2_forward_score_lwhw_2, fixed_llama_flash_attn2_forward_score_lwhw_2, fixed_LlamaModel_forward_score_lwhw_2
from score_kv.fixed_llama_hijack_score_lwhw_2 import prepare_inputs_for_generation_llama_score_lwhw_2 as fixed_prepare_inputs_for_generation_llama_score_lwhw_2
from score_kv.adaptive_llama_hijack_score_lwhw_2 import adaptive_llama_flash_attn2_forward_score_lwhw_2, adaptive_LlamaModel_forward_score_lwhw_2
from score_kv.adaptive_llama_hijack_score_lwhw_2 import prepare_inputs_for_generation_llama_score_lwhw_2 as ada_prepare_inputs_for_generation_llama_score_lwhw_2

from score_kv.fixed_llama_hijack import pyramidkv_llama_flash_attn2_forward, fixed_llama_flash_attn2_forward, fixed_LlamaModel_forward
from score_kv.fixed_llama_hijack import prepare_inputs_for_generation_llama as fixed_prepare_inputs_for_generation_llama
from score_kv.adaptive_llama_hijack import adaptive_llama_flash_attn2_forward, adaptive_LlamaModel_forward
from score_kv.adaptive_llama_hijack import prepare_inputs_for_generation_llama as ada_prepare_inputs_for_generation_llama


def replace_mistral_fixed():
    """Install the ``fixed_*`` Mistral hijacks from ``score_kv.fixed_mistral_hijack``.

    Overwrites three class attributes in
    ``transformers.models.mistral.modeling_mistral``, in this order:

    * ``MistralForCausalLM.prepare_inputs_for_generation``
    * ``MistralFlashAttention2.forward``
    * ``MistralModel.forward``

    These are class-level assignments, so they affect every instance of those
    classes in the current process.
    """
    mistral_mod = transformers.models.mistral.modeling_mistral
    mistral_mod.MistralForCausalLM.prepare_inputs_for_generation = fixed_prepare_inputs_for_generation_mistral
    mistral_mod.MistralFlashAttention2.forward = fixed_mistral_flash_attn2_forward
    mistral_mod.MistralModel.forward = fixed_MistralModel_forward

def replace_mistral_adaptive():
    """Install the adaptive Mistral hijacks from ``score_kv.adaptive_mistral_hijack``.

    Overwrites three class attributes in
    ``transformers.models.mistral.modeling_mistral``, in this order:

    * ``MistralForCausalLM.prepare_inputs_for_generation``
    * ``MistralFlashAttention2.forward``
    * ``MistralModel.forward``

    These are class-level assignments, so they affect every instance of those
    classes in the current process.
    """
    mistral_mod = transformers.models.mistral.modeling_mistral
    mistral_mod.MistralForCausalLM.prepare_inputs_for_generation = ada_prepare_inputs_for_generation_mistral
    mistral_mod.MistralFlashAttention2.forward = adaptive_mistral_flash_attn2_forward
    mistral_mod.MistralModel.forward = adaptive_MistralModel_forward

def replace_llama_fixed():
    """Install the ``fixed_*`` Llama hijacks from ``score_kv.fixed_llama_hijack``.

    Overwrites three class attributes in
    ``transformers.models.llama.modeling_llama``, in this order:

    * ``LlamaForCausalLM.prepare_inputs_for_generation``
    * ``LlamaFlashAttention2.forward``
    * ``LlamaModel.forward``

    These are class-level assignments, so they affect every instance of those
    classes in the current process.
    """
    llama_mod = transformers.models.llama.modeling_llama
    llama_mod.LlamaForCausalLM.prepare_inputs_for_generation = fixed_prepare_inputs_for_generation_llama
    llama_mod.LlamaFlashAttention2.forward = fixed_llama_flash_attn2_forward
    llama_mod.LlamaModel.forward = fixed_LlamaModel_forward

def replace_llama_adaptive():
    """Install the adaptive Llama hijacks from ``score_kv.adaptive_llama_hijack``.

    Overwrites three class attributes in
    ``transformers.models.llama.modeling_llama``, in this order:

    * ``LlamaForCausalLM.prepare_inputs_for_generation``
    * ``LlamaFlashAttention2.forward``
    * ``LlamaModel.forward``

    These are class-level assignments, so they affect every instance of those
    classes in the current process.
    """
    llama_mod = transformers.models.llama.modeling_llama
    llama_mod.LlamaForCausalLM.prepare_inputs_for_generation = ada_prepare_inputs_for_generation_llama
    llama_mod.LlamaFlashAttention2.forward = adaptive_llama_flash_attn2_forward
    llama_mod.LlamaModel.forward = adaptive_LlamaModel_forward




def replace_mistral(method):
    """Monkey-patch the transformers Mistral classes for a KV-cache method.

    Replaces three class attributes in
    ``transformers.models.mistral.modeling_mistral``, in this order:
    ``MistralForCausalLM.prepare_inputs_for_generation``,
    ``MistralModel.forward`` and ``MistralFlashAttention2.forward``.
    These are class-level assignments, so they affect every instance in the
    current process.

    Args:
        method: Name of the method to install:

            * ``"AdativeKV"``: the historical (misspelled) name, kept so existing
              callers keep working.
            * ``"AdaptiveKV"``: the correctly spelled alias of ``"AdativeKV"``.
            * ``"ReasonKV"``: the adaptive prepare/model hooks plus
              ``reason_mistral_flash_attn2_forward``.
            * ``"SnapKV"``: the fixed prepare/model hooks plus
              ``fixed_mistral_flash_attn2_forward``.
            * ``"PyramidKV"``: the fixed prepare/model hooks plus
              ``pyramidkv_mistral_flash_attn2_forward``.

    Any other value leaves the classes untouched, as before. It now also emits
    a ``UserWarning``, so a typo no longer silently runs the unpatched model.

    NOTE(review): this assumes a transformers release that still defines
    ``MistralFlashAttention2``. Confirm against the pinned version.
    """
    adaptive = (
        ada_prepare_inputs_for_generation_mistral,
        adaptive_MistralModel_forward,
        adaptive_mistral_flash_attn2_forward,
    )
    # Each entry is (prepare_inputs_for_generation, MistralModel.forward,
    # MistralFlashAttention2.forward).
    patches = {
        "AdativeKV": adaptive,
        "AdaptiveKV": adaptive,
        "ReasonKV": (
            ada_prepare_inputs_for_generation_mistral,
            adaptive_MistralModel_forward,
            reason_mistral_flash_attn2_forward,
        ),
        "SnapKV": (
            fixed_prepare_inputs_for_generation_mistral,
            fixed_MistralModel_forward,
            fixed_mistral_flash_attn2_forward,
        ),
        "PyramidKV": (
            fixed_prepare_inputs_for_generation_mistral,
            fixed_MistralModel_forward,
            pyramidkv_mistral_flash_attn2_forward,
        ),
    }

    selected = patches.get(method)
    if selected is None:
        # Previously a silent no-op. Keep that (no exception) but make it visible.
        warnings.warn(
            f"replace_mistral: unknown method {method!r}; "
            f"expected one of {sorted(patches)}. Mistral classes were not patched.",
            stacklevel=2,
        )
        return

    prepare_fn, model_forward, attn_forward = selected
    mistral_mod = transformers.models.mistral.modeling_mistral
    # Same assignment order as the original if/elif branches.
    mistral_mod.MistralForCausalLM.prepare_inputs_for_generation = prepare_fn
    mistral_mod.MistralModel.forward = model_forward
    mistral_mod.MistralFlashAttention2.forward = attn_forward


def replace_llama_score_lwhw_2(method):
    """Monkey-patch the transformers Llama classes with the ``*_score_lwhw_2`` hijacks.

    Replaces three class attributes in
    ``transformers.models.llama.modeling_llama``, in this order:
    ``LlamaForCausalLM.prepare_inputs_for_generation``,
    ``LlamaModel.forward`` and ``LlamaFlashAttention2.forward``.
    These are class-level assignments, so they affect every instance in the
    current process.

    Args:
        method: Name of the method to install:

            * ``"AdativeKV"``: the historical (misspelled) name, kept so existing
              callers keep working.
            * ``"AdaptiveKV"``: the correctly spelled alias of ``"AdativeKV"``.
            * ``"SnapKV"``: the fixed prepare/model hooks plus
              ``fixed_llama_flash_attn2_forward_score_lwhw_2``.
            * ``"PyramidKV"``: the fixed prepare/model hooks plus
              ``pyramidkv_llama_flash_attn2_forward_score_lwhw_2``.

    Any other value, including ``"ReasonKV"`` (which has no Llama variant
    here), leaves the classes untouched, as before. It now also emits a
    ``UserWarning``, so a typo no longer silently runs the unpatched model.

    NOTE(review): this assumes a transformers release that still defines
    ``LlamaFlashAttention2``. Confirm against the pinned version.
    """
    adaptive = (
        ada_prepare_inputs_for_generation_llama_score_lwhw_2,
        adaptive_LlamaModel_forward_score_lwhw_2,
        adaptive_llama_flash_attn2_forward_score_lwhw_2,
    )
    # Each entry is (prepare_inputs_for_generation, LlamaModel.forward,
    # LlamaFlashAttention2.forward).
    patches = {
        "AdativeKV": adaptive,
        "AdaptiveKV": adaptive,
        "SnapKV": (
            fixed_prepare_inputs_for_generation_llama_score_lwhw_2,
            fixed_LlamaModel_forward_score_lwhw_2,
            fixed_llama_flash_attn2_forward_score_lwhw_2,
        ),
        "PyramidKV": (
            fixed_prepare_inputs_for_generation_llama_score_lwhw_2,
            fixed_LlamaModel_forward_score_lwhw_2,
            pyramidkv_llama_flash_attn2_forward_score_lwhw_2,
        ),
    }

    selected = patches.get(method)
    if selected is None:
        # Previously a silent no-op. Keep that (no exception) but make it visible.
        warnings.warn(
            f"replace_llama_score_lwhw_2: unknown method {method!r}; "
            f"expected one of {sorted(patches)}. Llama classes were not patched.",
            stacklevel=2,
        )
        return

    prepare_fn, model_forward, attn_forward = selected
    llama_mod = transformers.models.llama.modeling_llama
    # Same assignment order as the original if/elif branches.
    llama_mod.LlamaForCausalLM.prepare_inputs_for_generation = prepare_fn
    llama_mod.LlamaModel.forward = model_forward
    llama_mod.LlamaFlashAttention2.forward = attn_forward