from importlib.metadata import version
import warnings
import transformers
from score_kv.fixed_mistral_hijack import pyramidkv_mistral_flash_attn2_forward, fixed_mistral_flash_attn2_forward, fixed_MistralModel_forward
from score_kv.fixed_mistral_hijack import prepare_inputs_for_generation_mistral as fixed_prepare_inputs_for_generation_mistral
from score_kv.adaptive_mistral_hijack import adaptive_mistral_flash_attn2_forward, adaptive_MistralModel_forward
from score_kv.adaptive_mistral_hijack import prepare_inputs_for_generation_mistral as ada_prepare_inputs_for_generation_mistral

from score_kv.fixed_llama_hijack import pyramidkv_llama_flash_attn2_forward, fixed_llama_flash_attn2_forward, fixed_LlamaModel_forward
from score_kv.fixed_llama_hijack import prepare_inputs_for_generation_llama as fixed_prepare_inputs_for_generation_llama
from score_kv.adaptive_llama_hijack import adaptive_llama_flash_attn2_forward,adaptive_LlamaModel_forward
from score_kv.adaptive_llama_hijack import prepare_inputs_for_generation_llama as ada_prepare_inputs_for_generation_llama

from score_kv.fixed_qwen_hijack import pyramidkv_qwen_flash_attn2_forward, fixed_qwen_flash_attn2_forward, fixed_qwenModel_forward
from score_kv.fixed_qwen_hijack import prepare_inputs_for_generation_qwen as fixed_prepare_inputs_for_generation_qwen
from score_kv.adaptive_qwen_hijack import adaptive_qwen_flash_attn2_forward, adaptive_qwenModel_forward
from score_kv.adaptive_qwen_hijack import prepare_inputs_for_generation_qwen as ada_prepare_inputs_for_generation_qwen


def check_version():
    """Warn if the installed ``transformers`` release is not a tested one.

    The installed version string is looked up via ``importlib.metadata``
    and its ``major.minor`` prefix is compared against the list of tested
    releases. A ``UserWarning`` is emitted on mismatch. If the version
    cannot be determined, a message is printed and the check is skipped.

    Returns:
        None
    """
    try:
        transformers_version = version("transformers")
    except Exception as e:
        print(f"Transformers not installed: {e}")
        # There is no version to compare; return instead of falling through
        # to code that would read the unbound name and raise NameError.
        return

    version_list = ['4.37']
    # Compare on the "major.minor" prefix rather than by substring, so a
    # version such as "14.37.0" is not mistaken for "4.37".
    major_minor = ".".join(transformers_version.split(".")[:2])
    if major_minor not in version_list:
        warnings.warn(f"Transformers version {transformers_version} might not be compatible with SnapKV. SnapKV is tested with Transformers version {version_list}.")


def replace_mistral_fixed():
    """Install the ``fixed_*`` Mistral overrides from ``score_kv.fixed_mistral_hijack``.

    Runs the transformers version check, then rebinds
    ``MistralForCausalLM.prepare_inputs_for_generation``,
    ``MistralFlashAttention2.forward`` and ``MistralModel.forward``.
    """
    check_version()
    modeling = transformers.models.mistral.modeling_mistral
    modeling.MistralForCausalLM.prepare_inputs_for_generation = (
        fixed_prepare_inputs_for_generation_mistral
    )
    modeling.MistralFlashAttention2.forward = fixed_mistral_flash_attn2_forward
    modeling.MistralModel.forward = fixed_MistralModel_forward

def replace_mistral_adaptive():
    """Install the ``adaptive_*`` Mistral overrides from ``score_kv.adaptive_mistral_hijack``.

    Runs the transformers version check, then rebinds
    ``MistralForCausalLM.prepare_inputs_for_generation``,
    ``MistralFlashAttention2.forward`` and ``MistralModel.forward``.
    """
    check_version()
    modeling = transformers.models.mistral.modeling_mistral
    modeling.MistralForCausalLM.prepare_inputs_for_generation = (
        ada_prepare_inputs_for_generation_mistral
    )
    modeling.MistralFlashAttention2.forward = adaptive_mistral_flash_attn2_forward
    modeling.MistralModel.forward = adaptive_MistralModel_forward

def replace_llama_fixed():
    """Install the ``fixed_*`` Llama overrides from ``score_kv.fixed_llama_hijack``.

    Runs the transformers version check, then rebinds
    ``LlamaForCausalLM.prepare_inputs_for_generation``,
    ``LlamaFlashAttention2.forward`` and ``LlamaModel.forward``.
    """
    check_version()
    modeling = transformers.models.llama.modeling_llama
    modeling.LlamaForCausalLM.prepare_inputs_for_generation = (
        fixed_prepare_inputs_for_generation_llama
    )
    modeling.LlamaFlashAttention2.forward = fixed_llama_flash_attn2_forward
    modeling.LlamaModel.forward = fixed_LlamaModel_forward

def replace_llama_adaptive():
    """Install the ``adaptive_*`` Llama overrides from ``score_kv.adaptive_llama_hijack``.

    Runs the transformers version check, then rebinds
    ``LlamaForCausalLM.prepare_inputs_for_generation``,
    ``LlamaFlashAttention2.forward`` and ``LlamaModel.forward``.
    """
    check_version()
    modeling = transformers.models.llama.modeling_llama
    modeling.LlamaForCausalLM.prepare_inputs_for_generation = (
        ada_prepare_inputs_for_generation_llama
    )
    modeling.LlamaFlashAttention2.forward = adaptive_llama_flash_attn2_forward
    modeling.LlamaModel.forward = adaptive_LlamaModel_forward

def replace_qwen_fixed():
    """Install the ``fixed_*`` Qwen2 overrides from ``score_kv.fixed_qwen_hijack``.

    Runs the transformers version check, then rebinds
    ``Qwen2ForCausalLM.prepare_inputs_for_generation``,
    ``Qwen2FlashAttention2.forward`` and ``Qwen2Model.forward``.
    """
    check_version()
    modeling = transformers.models.qwen2.modeling_qwen2
    modeling.Qwen2ForCausalLM.prepare_inputs_for_generation = (
        fixed_prepare_inputs_for_generation_qwen
    )
    modeling.Qwen2FlashAttention2.forward = fixed_qwen_flash_attn2_forward
    modeling.Qwen2Model.forward = fixed_qwenModel_forward

def replace_qwen_adaptive():
    """Install the ``adaptive_*`` Qwen2 overrides from ``score_kv.adaptive_qwen_hijack``.

    Runs the transformers version check, then rebinds
    ``Qwen2ForCausalLM.prepare_inputs_for_generation``,
    ``Qwen2FlashAttention2.forward`` and ``Qwen2Model.forward``.
    """
    check_version()
    modeling = transformers.models.qwen2.modeling_qwen2
    modeling.Qwen2ForCausalLM.prepare_inputs_for_generation = (
        ada_prepare_inputs_for_generation_qwen
    )
    modeling.Qwen2FlashAttention2.forward = adaptive_qwen_flash_attn2_forward
    modeling.Qwen2Model.forward = adaptive_qwenModel_forward




def replace_mistral(method):
    """Patch the transformers Mistral classes according to ``method``.

    Prints ``method``, runs the transformers version check (as the Llama
    and Qwen dispatchers do), and then rebinds
    ``MistralForCausalLM.prepare_inputs_for_generation``,
    ``MistralModel.forward`` and ``MistralFlashAttention2.forward``.

    Args:
        method: One of ``"AdativeKV"`` (matched with exactly this
            spelling), ``'SnapKV'``, ``'PyramidKV'`` or ``'Score'``.
            Any other value leaves the Mistral classes untouched.

    Returns:
        None
    """
    print(method)
    # Same compatibility warning the other replace_* functions issue.
    check_version()

    if method == "AdativeKV":
        prepare_fn, model_fwd, attn_fwd = (
            ada_prepare_inputs_for_generation_mistral,
            adaptive_MistralModel_forward,
            adaptive_mistral_flash_attn2_forward,
        )
    elif method == 'SnapKV':
        prepare_fn, model_fwd, attn_fwd = (
            fixed_prepare_inputs_for_generation_mistral,
            fixed_MistralModel_forward,
            fixed_mistral_flash_attn2_forward,
        )
    elif method == 'PyramidKV':
        # Shares the fixed prepare/model overrides; only attention differs.
        prepare_fn, model_fwd, attn_fwd = (
            fixed_prepare_inputs_for_generation_mistral,
            fixed_MistralModel_forward,
            pyramidkv_mistral_flash_attn2_forward,
        )
    elif method == 'Score':
        # Imported lazily so the score module is only loaded when requested.
        from score_kv.adaptive_mistral_hijack_score_lwhw import prepare_inputs_for_generation_mistral_score_lwhw as ada_prepare_inputs_for_generation_mistral_score_lwhw
        from score_kv.adaptive_mistral_hijack_score_lwhw import reason_mistral_flash_attn2_forward_score_lwhw, adaptive_mistral_flash_attn2_forward_score_lwhw, adaptive_MistralModel_forward_score_lwhw
        prepare_fn, model_fwd, attn_fwd = (
            ada_prepare_inputs_for_generation_mistral_score_lwhw,
            adaptive_MistralModel_forward_score_lwhw,
            adaptive_mistral_flash_attn2_forward_score_lwhw,
        )
    else:
        # Unrecognized method: nothing is patched.
        return

    modeling = transformers.models.mistral.modeling_mistral
    modeling.MistralForCausalLM.prepare_inputs_for_generation = prepare_fn
    modeling.MistralModel.forward = model_fwd
    modeling.MistralFlashAttention2.forward = attn_fwd


def replace_llama(method):
    """Patch the transformers Llama classes according to ``method``.

    Runs the transformers version check, then rebinds
    ``LlamaForCausalLM.prepare_inputs_for_generation``,
    ``LlamaModel.forward`` and ``LlamaFlashAttention2.forward``.

    Args:
        method: One of ``"AdativeKV"`` (matched with exactly this
            spelling), ``'SnapKV'``, ``'PyramidKV'`` or ``'Score'``.
            Any other value leaves the Llama classes untouched.
    """
    check_version()

    if method == "AdativeKV":
        prepare_fn, model_fwd, attn_fwd = (
            ada_prepare_inputs_for_generation_llama,
            adaptive_LlamaModel_forward,
            adaptive_llama_flash_attn2_forward,
        )
    elif method == 'SnapKV':
        prepare_fn, model_fwd, attn_fwd = (
            fixed_prepare_inputs_for_generation_llama,
            fixed_LlamaModel_forward,
            fixed_llama_flash_attn2_forward,
        )
    elif method == 'PyramidKV':
        # Shares the fixed prepare/model overrides; only attention differs.
        prepare_fn, model_fwd, attn_fwd = (
            fixed_prepare_inputs_for_generation_llama,
            fixed_LlamaModel_forward,
            pyramidkv_llama_flash_attn2_forward,
        )
    elif method == 'Score':
        # Imported lazily so the score module is only loaded when requested.
        from score_kv.adaptive_llama_hijack_score_lwhw import (
            prepare_inputs_for_generation_llama_score_lwhw,
            reason_llama_flash_attn2_forward_score_lwhw,
            adaptive_llama_flash_attn2_forward_score_lwhw,
            adaptive_LlamaModel_forward_score_lwhw,
        )
        prepare_fn, model_fwd, attn_fwd = (
            prepare_inputs_for_generation_llama_score_lwhw,
            adaptive_LlamaModel_forward_score_lwhw,
            adaptive_llama_flash_attn2_forward_score_lwhw,
        )
    else:
        # Unrecognized method: nothing is patched.
        return

    modeling = transformers.models.llama.modeling_llama
    modeling.LlamaForCausalLM.prepare_inputs_for_generation = prepare_fn
    modeling.LlamaModel.forward = model_fwd
    modeling.LlamaFlashAttention2.forward = attn_fwd

def replace_qwen(method):
    """Patch the transformers Qwen2 classes according to ``method``.

    Runs the transformers version check, then rebinds
    ``Qwen2ForCausalLM.prepare_inputs_for_generation``,
    ``Qwen2Model.forward`` and ``Qwen2FlashAttention2.forward``.

    Args:
        method: One of ``"AdativeKV"`` (matched with exactly this
            spelling), ``'SnapKV'``, ``'PyramidKV'`` or ``'Score'``.
            Any other value leaves the Qwen2 classes untouched.
    """
    check_version()

    if method == "AdativeKV":
        prepare_fn, model_fwd, attn_fwd = (
            ada_prepare_inputs_for_generation_qwen,
            adaptive_qwenModel_forward,
            adaptive_qwen_flash_attn2_forward,
        )
    elif method == 'SnapKV':
        prepare_fn, model_fwd, attn_fwd = (
            fixed_prepare_inputs_for_generation_qwen,
            fixed_qwenModel_forward,
            fixed_qwen_flash_attn2_forward,
        )
    elif method == 'PyramidKV':
        # Shares the fixed prepare/model overrides; only attention differs.
        prepare_fn, model_fwd, attn_fwd = (
            fixed_prepare_inputs_for_generation_qwen,
            fixed_qwenModel_forward,
            pyramidkv_qwen_flash_attn2_forward,
        )
    elif method == 'Score':
        # Imported lazily so the score module is only loaded when requested.
        from score_kv.adaptive_qwen_hijack_score_lwhw import (
            prepare_inputs_for_generation_qwen_score_lwhw,
            adaptive_qwen_flash_attn2_forward_score_lwhw,
            adaptive_qwenModel_forward_score_lwhw,
        )
        prepare_fn, model_fwd, attn_fwd = (
            prepare_inputs_for_generation_qwen_score_lwhw,
            adaptive_qwenModel_forward_score_lwhw,
            adaptive_qwen_flash_attn2_forward_score_lwhw,
        )
    else:
        # Unrecognized method: nothing is patched.
        return

    modeling = transformers.models.qwen2.modeling_qwen2
    modeling.Qwen2ForCausalLM.prepare_inputs_for_generation = prepare_fn
    modeling.Qwen2Model.forward = model_fwd
    modeling.Qwen2FlashAttention2.forward = attn_fwd