# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Apply monkey-patch function to models
"""

#### Open Source Models
#### transformers version < 4.48

import importlib.metadata
from functools import lru_cache

from packaging import version
from transformers import PretrainedConfig


def apply_monkey_patch_to_llama():
    """Replace ``LlamaFlashAttention2.forward`` with ``llama_flash_attn_forward``.

    The assignment is made on the class itself, so it applies process-wide to
    every ``LlamaFlashAttention2`` instance. Both imports are performed at call
    time rather than at module import time.
    """
    from transformers.models.llama.modeling_llama import LlamaFlashAttention2 as attention_cls

    from verl.models.transformers.llama import llama_flash_attn_forward as patched_forward

    # Class-level rebind: existing and future instances pick up the new forward.
    attention_cls.forward = patched_forward


def apply_monkey_patch_to_qwen2():
    """Replace ``Qwen2FlashAttention2.forward`` with ``qwen2_flash_attn_forward``.

    The assignment is made on the class itself, so it applies process-wide to
    every ``Qwen2FlashAttention2`` instance. Both imports are performed at call
    time rather than at module import time.
    """
    from transformers.models.qwen2.modeling_qwen2 import Qwen2FlashAttention2 as attention_cls

    from verl.models.transformers.qwen2 import qwen2_flash_attn_forward as patched_forward

    # Class-level rebind: existing and future instances pick up the new forward.
    attention_cls.forward = patched_forward


# Registry of supported model families: maps a model-type string (looked up via
# ``config.model_type`` in ``apply_monkey_patch``) to the zero-argument function
# that installs the patch for that family.
_PATCH_NAME_TO_FUNC = {
    "llama": apply_monkey_patch_to_llama,
    "qwen2": apply_monkey_patch_to_qwen2,
}


def apply_monkey_patch(config: PretrainedConfig, verbose=True):
    """Apply the monkey patch registered for ``config.model_type``.

    Args:
        config: Model configuration. Only ``config.model_type`` is read; it is
            the lookup key into ``_PATCH_NAME_TO_FUNC``.
        verbose: When truthy, print a one-line message after the patch has
            been applied.

    Returns:
        ``True`` once the patch has been applied. Every unsupported case raises
        instead of returning ``False``.

    Raises:
        AssertionError: If the installed ``transformers`` version is outside
            the inclusive range 4.45.0 - 4.47.1.
        NotImplementedError: If no patch is registered for
            ``config.model_type``.
    """
    if not is_transformers_version_in_range("4.45.0", "4.47.1"):
        raise AssertionError(
            "The installed `transformers` version doesn't support ulysses patch. "
            "Please install a version between 4.45.0 and 4.47.1 to use this ulysses feature."
        )

    patch_func = _PATCH_NAME_TO_FUNC.get(config.model_type)
    if patch_func is None:
        # Adjacent literals are used instead of a backslash continuation so the
        # source indentation does not end up inside the error message.
        raise NotImplementedError(
            f"Ulysses for model {config.model_type} is not implemented, "
            "please set `ulysses_sequence_parallel_size=1`"
        )

    patch_func()
    if verbose:
        print(f"Applying monkey patch to model {config.model_type}")

    return True


@lru_cache
def is_transformers_version_in_range(min_version: str, max_version: str) -> bool:
    """Return whether the installed ``transformers`` version is within a range.

    Both bounds are inclusive. Results are memoized by ``lru_cache`` keyed on
    the ``(min_version, max_version)`` argument pair.

    Args:
        min_version: Lowest accepted version string, e.g. ``"4.45.0"``.
        max_version: Highest accepted version string, e.g. ``"4.47.1"``.

    Returns:
        ``True`` if ``min_version <= installed <= max_version`` under
        ``packaging.version`` ordering, otherwise ``False``.

    Raises:
        ModuleNotFoundError: If no installed distribution named
            ``transformers`` can be found. The original
            ``PackageNotFoundError`` is attached as ``__cause__``.
    """
    try:
        # Read the version from package metadata.
        installed_version = importlib.metadata.version("transformers")
    except importlib.metadata.PackageNotFoundError as err:
        raise ModuleNotFoundError("The `transformers` package is not installed.") from err

    # Chained comparison on parsed versions; both ends are inclusive.
    return version.parse(min_version) <= version.parse(installed_version) <= version.parse(max_version)
