from importlib.metadata import version
import warnings
import transformers

from score_kv.fixed_mistral_hijack_score_lwhw import pyramidkv_mistral_flash_attn2_forward_score_lwhw, fixed_mistral_flash_attn2_forward_score_lwhw, fixed_MistralModel_forward_score_lwhw
from score_kv.fixed_mistral_hijack_score_lwhw import prepare_inputs_for_generation_mistral_score_lwhw as fixed_prepare_inputs_for_generation_mistral_score_lwhw
from score_kv.adaptive_mistral_hijack_score_lwhw import reason_mistral_flash_attn2_forward_score_lwhw, adaptive_mistral_flash_attn2_forward_score_lwhw, adaptive_MistralModel_forward_score_lwhw
from score_kv.adaptive_mistral_hijack_score_lwhw import prepare_inputs_for_generation_mistral_score_lwhw as ada_prepare_inputs_for_generation_mistral_score_lwhw

from score_kv.fixed_llama_hijack_score_lwhw import pyramidkv_llama_flash_attn2_forward_score_lwhw, fixed_llama_flash_attn2_forward_score_lwhw, fixed_LlamaModel_forward_score_lwhw
from score_kv.fixed_llama_hijack_score_lwhw import prepare_inputs_for_generation_llama_score_lwhw as fixed_prepare_inputs_for_generation_llama_score_lwhw
from score_kv.adaptive_llama_hijack_score_lwhw import adaptive_llama_flash_attn2_forward_score_lwhw, adaptive_LlamaModel_forward_score_lwhw
from score_kv.adaptive_llama_hijack_score_lwhw import prepare_inputs_for_generation_llama_score_lwhw as ada_prepare_inputs_for_generation_llama_score_lwhw

from score_kv.fixed_llama_hijack import pyramidkv_llama_flash_attn2_forward, fixed_llama_flash_attn2_forward, fixed_LlamaModel_forward
from score_kv.fixed_llama_hijack import prepare_inputs_for_generation_llama as fixed_prepare_inputs_for_generation_llama
from score_kv.adaptive_llama_hijack import adaptive_llama_flash_attn2_forward, adaptive_LlamaModel_forward
from score_kv.adaptive_llama_hijack import prepare_inputs_for_generation_llama as ada_prepare_inputs_for_generation_llama


def replace_mistral_fixed():
    """Monkey-patch transformers' Mistral classes with the fixed (non score_lwhw) hijack.

    Overrides, in order, ``MistralForCausalLM.prepare_inputs_for_generation``,
    ``MistralFlashAttention2.forward`` and ``MistralModel.forward``. The
    assignments are on the classes themselves, so every instance of these
    classes in the process is affected.
    """
    # These names are not imported at module level (only the score_lwhw Mistral
    # variants are), so without this import the function raised NameError.
    # Imported lazily so that importing this module does not require it.
    # NOTE(review): module path assumed by analogy with score_kv.fixed_llama_hijack
    # and score_kv.fixed_mistral_hijack_score_lwhw — confirm it exists.
    from score_kv.fixed_mistral_hijack import (
        fixed_MistralModel_forward,
        fixed_mistral_flash_attn2_forward,
        prepare_inputs_for_generation_mistral as fixed_prepare_inputs_for_generation_mistral,
    )

    mistral_mod = transformers.models.mistral.modeling_mistral
    mistral_mod.MistralForCausalLM.prepare_inputs_for_generation = fixed_prepare_inputs_for_generation_mistral
    mistral_mod.MistralFlashAttention2.forward = fixed_mistral_flash_attn2_forward
    mistral_mod.MistralModel.forward = fixed_MistralModel_forward

def replace_mistral_adaptive():
    """Monkey-patch transformers' Mistral classes with the adaptive (non score_lwhw) hijack.

    Overrides, in order, ``MistralForCausalLM.prepare_inputs_for_generation``,
    ``MistralFlashAttention2.forward`` and ``MistralModel.forward``. The
    assignments are on the classes themselves, so every instance of these
    classes in the process is affected.
    """
    # These names are not imported at module level (only the score_lwhw Mistral
    # variants are), so without this import the function raised NameError.
    # Imported lazily so that importing this module does not require it.
    # NOTE(review): module path assumed by analogy with score_kv.adaptive_llama_hijack
    # and score_kv.adaptive_mistral_hijack_score_lwhw — confirm it exists.
    from score_kv.adaptive_mistral_hijack import (
        adaptive_MistralModel_forward,
        adaptive_mistral_flash_attn2_forward,
        prepare_inputs_for_generation_mistral as ada_prepare_inputs_for_generation_mistral,
    )

    mistral_mod = transformers.models.mistral.modeling_mistral
    mistral_mod.MistralForCausalLM.prepare_inputs_for_generation = ada_prepare_inputs_for_generation_mistral
    mistral_mod.MistralFlashAttention2.forward = adaptive_mistral_flash_attn2_forward
    mistral_mod.MistralModel.forward = adaptive_MistralModel_forward

def replace_llama_fixed():
    """Monkey-patch transformers' Llama classes with the score_kv.fixed_llama_hijack versions.

    Overrides, in order, ``LlamaForCausalLM.prepare_inputs_for_generation``,
    ``LlamaFlashAttention2.forward`` and ``LlamaModel.forward``. The
    assignments are on the classes themselves, so every instance of these
    classes in the process is affected.
    """
    llama_mod = transformers.models.llama.modeling_llama
    llama_mod.LlamaForCausalLM.prepare_inputs_for_generation = fixed_prepare_inputs_for_generation_llama
    llama_mod.LlamaFlashAttention2.forward = fixed_llama_flash_attn2_forward
    llama_mod.LlamaModel.forward = fixed_LlamaModel_forward

def replace_llama_adaptive():
    """Monkey-patch transformers' Llama classes with the score_kv.adaptive_llama_hijack versions.

    Overrides, in order, ``LlamaForCausalLM.prepare_inputs_for_generation``,
    ``LlamaFlashAttention2.forward`` and ``LlamaModel.forward``. The
    assignments are on the classes themselves, so every instance of these
    classes in the process is affected.
    """
    llama_mod = transformers.models.llama.modeling_llama
    llama_mod.LlamaForCausalLM.prepare_inputs_for_generation = ada_prepare_inputs_for_generation_llama
    llama_mod.LlamaFlashAttention2.forward = adaptive_llama_flash_attn2_forward
    llama_mod.LlamaModel.forward = adaptive_LlamaModel_forward



def replace_mistral_score_lwhw(method):
    """Monkey-patch transformers' Mistral classes with the score_lwhw hijack for *method*.

    Args:
        method: Which implementation set to install:

            * ``"AdativeKV"`` (historical spelling, kept for existing callers)
              or ``"AdaptiveKV"``: the score_kv.adaptive_mistral_hijack_score_lwhw
              prepare-inputs, model forward and flash-attention forward.
            * ``"SnapKV"``: the score_kv.fixed_mistral_hijack_score_lwhw
              prepare-inputs, model forward and fixed flash-attention forward.
            * ``"PyramidKV"``: same as ``"SnapKV"`` but with the pyramidkv
              flash-attention forward.

            Any other value leaves the Mistral classes untouched and emits a
            ``UserWarning`` (previously this was a silent no-op).
    """
    if method in ("AdativeKV", "AdaptiveKV"):
        prepare = ada_prepare_inputs_for_generation_mistral_score_lwhw
        model_forward = adaptive_MistralModel_forward_score_lwhw
        attn_forward = adaptive_mistral_flash_attn2_forward_score_lwhw
    elif method in ("SnapKV", "PyramidKV"):
        # SnapKV and PyramidKV share the same prepare-inputs and model forward;
        # they differ only in the attention forward.
        prepare = fixed_prepare_inputs_for_generation_mistral_score_lwhw
        model_forward = fixed_MistralModel_forward_score_lwhw
        if method == "SnapKV":
            attn_forward = fixed_mistral_flash_attn2_forward_score_lwhw
        else:
            attn_forward = pyramidkv_mistral_flash_attn2_forward_score_lwhw
    else:
        warnings.warn(
            f"replace_mistral_score_lwhw: unknown method {method!r}; "
            "Mistral classes were not patched.",
            stacklevel=2,
        )
        return

    # Resolved only after validation so an unknown method never touches transformers.
    mistral_mod = transformers.models.mistral.modeling_mistral
    mistral_mod.MistralForCausalLM.prepare_inputs_for_generation = prepare
    mistral_mod.MistralModel.forward = model_forward
    mistral_mod.MistralFlashAttention2.forward = attn_forward


def replace_llama_score_lwhw(method):
    """Monkey-patch transformers' Llama classes with the score_lwhw hijack for *method*.

    Args:
        method: Which implementation set to install:

            * ``"AdativeKV"`` (historical spelling, kept for existing callers)
              or ``"AdaptiveKV"``: the score_kv.adaptive_llama_hijack_score_lwhw
              prepare-inputs, model forward and flash-attention forward.
            * ``"SnapKV"``: the score_kv.fixed_llama_hijack_score_lwhw
              prepare-inputs, model forward and fixed flash-attention forward.
            * ``"PyramidKV"``: same as ``"SnapKV"`` but with the pyramidkv
              flash-attention forward.

            Any other value leaves the Llama classes untouched and emits a
            ``UserWarning`` (previously this was a silent no-op).
    """
    if method in ("AdativeKV", "AdaptiveKV"):
        prepare = ada_prepare_inputs_for_generation_llama_score_lwhw
        model_forward = adaptive_LlamaModel_forward_score_lwhw
        attn_forward = adaptive_llama_flash_attn2_forward_score_lwhw
    elif method in ("SnapKV", "PyramidKV"):
        # SnapKV and PyramidKV share the same prepare-inputs and model forward;
        # they differ only in the attention forward.
        prepare = fixed_prepare_inputs_for_generation_llama_score_lwhw
        model_forward = fixed_LlamaModel_forward_score_lwhw
        if method == "SnapKV":
            attn_forward = fixed_llama_flash_attn2_forward_score_lwhw
        else:
            attn_forward = pyramidkv_llama_flash_attn2_forward_score_lwhw
    else:
        warnings.warn(
            f"replace_llama_score_lwhw: unknown method {method!r}; "
            "Llama classes were not patched.",
            stacklevel=2,
        )
        return

    # Resolved only after validation so an unknown method never touches transformers.
    llama_mod = transformers.models.llama.modeling_llama
    llama_mod.LlamaForCausalLM.prepare_inputs_for_generation = prepare
    llama_mod.LlamaModel.forward = model_forward
    llama_mod.LlamaFlashAttention2.forward = attn_forward