

import importlib.metadata
import sys
from functools import lru_cache
from typing import Optional

import torch
from packaging import version
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from transformers.modeling_utils import PreTrainedModel

from verl.utils.import_utils import is_trl_available
from verl.utils.ulysses import (
    gather_heads_scatter_seq,
    gather_seq_scatter_heads,
    get_ulysses_sequence_parallel_group,
    get_ulysses_sequence_parallel_world_size,
    slice_input_tensor,
)

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each head of a 4-D tensor ``n_rep`` times along dim 2.

    ``hidden_states`` is unpacked as ``(batch, seq_len, num_kv_heads, head_dim)``. The result has
    ``num_kv_heads * n_rep`` heads, and copies of the same source head are placed next to each other.
    That layout is the same as ``torch.repeat_interleave(hidden_states, n_rep, dim=2)``.
    When ``n_rep == 1`` the input tensor itself is returned unchanged.
    """
    # Unpack first so that a non-4-D input is rejected even when n_rep == 1.
    bsz, seq_len, n_kv_heads, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Add a repeat axis after the head axis, broadcast it with expand (no copy),
    # then fold it into the head axis with reshape.
    expanded = hidden_states.unsqueeze(3).expand(bsz, seq_len, n_kv_heads, n_rep, dim)
    return expanded.reshape(bsz, seq_len, n_kv_heads * n_rep, dim)

def _ulysses_flash_attention_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    *args,
    position_ids: Optional[torch.Tensor] = None,
    **kwargs,
):
    """Wrap transformers' ``_flash_attention_forward`` with Ulysses sequence-parallel handling.

    If the Ulysses SP world size is 1, the call is passed straight through to ``_flash_attention_forward``.
    Otherwise the steps are:

    1. KV heads are repeated by ``max(sp_size // num_kv_heads, 1)``. The head count is read from dim 2.
    2. q/k/v pass through ``gather_seq_scatter_heads`` with ``seq_dim=1`` and ``head_dim=2``.
    3. ``position_ids`` is all-gathered over the Ulysses SP group and concatenated on the last dim.
    4. After attention, the output goes back through ``gather_heads_scatter_seq``.

    Any extra positional or keyword arguments are forwarded unchanged.

    Raises:
        AssertionError: if SP is active and ``position_ids`` is None.
    """
    sp_size = get_ulysses_sequence_parallel_world_size()
    use_sp = sp_size > 1

    if use_sp:
        assert position_ids is not None, "position_ids is required for Ulysses sequence parallelism"

        # Replicate KV heads when there are fewer of them than SP ranks. Presumably this is so the
        # head dimension can be split across the SP ranks; confirm against gather_seq_scatter_heads.
        kv_repeat = max(sp_size // key_states.size(2), 1)
        key_states, value_states = repeat_kv(key_states, kv_repeat), repeat_kv(value_states, kv_repeat)

        # Judging by the helper name, this gathers the full sequence (dim 1) and splits heads (dim 2) across ranks.
        # The order (q, k, v) is kept because these are collective calls.
        query_states, key_states, value_states = (
            gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2),
            gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2),
            gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2),
        )

        # Each rank only holds its own sequence shard of position_ids; rebuild the full sequence.
        gathered_ids = [torch.empty_like(position_ids) for _ in range(sp_size)]
        sp_group = get_ulysses_sequence_parallel_group()
        torch.distributed.all_gather(gathered_ids, position_ids, group=sp_group)
        position_ids = torch.concat(gathered_ids, dim=-1)

    attn_output = _flash_attention_forward(
        query_states, key_states, value_states, *args, position_ids=position_ids, **kwargs
    )

    # Undo the earlier layout change so each rank is left with its own sequence shard.
    return gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2) if use_sp else attn_output

def patch_vlm_for_ulysses_input_slicing(model_class: type):
    """Wrap ``model_class.forward`` so ``inputs_embeds`` is sliced along dim 1 for Ulysses SP.

    The wrapper only acts when all of these hold:

    - ``inputs_embeds`` is passed as a keyword argument and is not None;
    - the Ulysses SP world size is greater than 1;
    - ``self._needs_initial_slice`` is truthy (treated as True when unset).

    When it acts, it replaces ``inputs_embeds`` with
    ``slice_input_tensor(inputs_embeds, dim=1, padding=False)`` and calls the original forward.
    During that call it sets ``self._needs_initial_slice`` to False, and a ``finally`` block restores
    it to True afterwards. As a result, a nested call to the same forward on the same instance does
    not slice a second time.

    The wrapper is built with ``functools.wraps``. That keeps the original ``__name__``, ``__doc__``
    and ``__wrapped__``, so signature-based introspection still sees the real parameters.
    Patching is idempotent: calling this again on an already-patched class does nothing.

    Args:
        model_class: class whose ``forward`` is replaced in place (affects all of its instances).
    """
    from functools import wraps

    original_forward = model_class.forward
    if getattr(original_forward, "_ulysses_input_slicing_patched", False):
        # Already wrapped. Another layer would be inert, because the outer wrapper clears
        # _needs_initial_slice before the inner one runs, so skip stacking it.
        return

    @wraps(original_forward)
    def ulysses_wrapped_decoder_forward(self, *args, **kwargs):
        inputs_embeds = kwargs.get("inputs_embeds")
        current_ulysses_sp_size = get_ulysses_sequence_parallel_world_size()

        slice_now = (
            inputs_embeds is not None
            and current_ulysses_sp_size > 1
            and getattr(self, "_needs_initial_slice", True)
        )
        if not slice_now:
            return original_forward(self, *args, **kwargs)

        call_kwargs = dict(kwargs)
        call_kwargs["inputs_embeds"] = slice_input_tensor(inputs_embeds, dim=1, padding=False)
        self._needs_initial_slice = False
        try:
            return original_forward(self, *args, **call_kwargs)
        finally:
            # Re-arm slicing for the next top-level call, even if forward raised.
            self._needs_initial_slice = True

    ulysses_wrapped_decoder_forward._ulysses_input_slicing_patched = True
    model_class.forward = ulysses_wrapped_decoder_forward
    print(f"Monkey patch {model_class.__name__}.forward for Ulysses SP input slicing.")

def patch_forward_with_backends(
    model: PreTrainedModel,
    use_fused_kernels: bool = False,
    fused_kernels_backend: Optional[str] = None,
):
    """Replace ``model.__class__.forward`` with verl's fused-kernel forward for the chosen backend.

    The replacement comes from ``verl.models.transformers.qwen2_5_vl`` or ``...qwen2_vl`` for those
    model types, and from ``...dense_common`` for every other model type.

    Args:
        model: model whose class ``forward`` is patched (affects all instances of that class).
        use_fused_kernels: when False, nothing is patched.
        fused_kernels_backend: ``"triton"`` or ``"torch"``. Any other value (including None) skips
            patching and prints a message.
    """
    if not use_fused_kernels or fused_kernels_backend not in ["triton", "torch"]:
        print(
            f"Skipping monkey patch for {model.__class__.__name__} as use_fused_kernels is "
            f"{use_fused_kernels} or fused_kernels_backend is {fused_kernels_backend}"
        )
        return

    model_type = model.config.model_type
    if model_type == "qwen2_5_vl":
        from verl.models.transformers.qwen2_5_vl import forward_with_torch_backend, forward_with_triton_backend
    elif model_type == "qwen2_vl":
        from verl.models.transformers.qwen2_vl import forward_with_torch_backend, forward_with_triton_backend
    else:
        from verl.models.transformers.dense_common import forward_with_torch_backend, forward_with_triton_backend

    # The guard above admits only these two keys, so this lookup cannot fail.
    backends = {
        "triton": ("Triton", forward_with_triton_backend),
        "torch": ("Torch", forward_with_torch_backend),
    }
    label, forward_fn = backends[fused_kernels_backend]
    model.__class__.forward = forward_fn
    print(f"Using {label} backend for fused kernels in {model.__class__.__name__}")

def apply_monkey_patch(
    model: PreTrainedModel,
    ulysses_sp_size: int = 1,
    use_remove_padding: bool = True,
    use_fused_kernels: bool = False,
    fused_kernels_backend: Optional[str] = None,
):
    """Patch ``model``'s attention / forward code for Ulysses SP, remove-padding and fused kernels.

    What gets patched depends on ``model.config.model_type``:

    * ``qwen2_5_vl`` / ``qwen2_vl``: when remove-padding or SP is enabled, the FlashAttention2 class
      ``forward`` is replaced with ``verl.models.transformers.qwen2_vl.ulysses_flash_attn_forward``.
      When ``ulysses_sp_size > 1``, the text/decoder model ``forward`` is also wrapped so that
      ``inputs_embeds`` is sliced.
    * ``kimi_vl``: ``DeepseekV3FlashAttention2.forward`` and ``DeepseekV3ForCausalLM.forward`` are
      patched in the model's own module, and the function returns early without fused kernels.
    * any other type: when remove-padding or SP is enabled, ``_flash_attention_forward`` is replaced
      with ``_ulysses_flash_attention_forward``. The replacement happens in the model's module, or in
      ``transformers.integrations.flash_attention`` if the module does not define that name.

    Every model type except ``kimi_vl`` then goes through ``patch_forward_with_backends``. If ``trl``
    is available, ``AutoModelForCausalLMWithValueHead.state_dict`` is also replaced with
    ``torch.nn.Module.state_dict``. All patches are made on classes or modules, so they affect every
    instance, not only ``model``.

    Args:
        model: the HF model to patch.
        ulysses_sp_size: Ulysses sequence-parallel world size.
        use_remove_padding: patch attention even when ``ulysses_sp_size == 1``.
        use_fused_kernels: forwarded to ``patch_forward_with_backends``.
        fused_kernels_backend: forwarded to ``patch_forward_with_backends``.

    Raises:
        AssertionError: if the attention / KV head counts are incompatible with ``ulysses_sp_size``.
        ValueError: for ``qwen2_5_vl`` / ``qwen2_vl`` when transformers >= 4.53 is installed.
    """
    module = sys.modules[model.__module__]

    # VLM configs keep the language-model head counts under ``text_config``.
    try:
        num_attention_heads, num_key_value_heads = model.config.num_attention_heads, model.config.num_key_value_heads
    except AttributeError:
        num_attention_heads, num_key_value_heads = (
            model.config.text_config.num_attention_heads,
            model.config.text_config.num_key_value_heads,
        )

    assert num_attention_heads % ulysses_sp_size == 0, (
        f"num_attention_heads {num_attention_heads} must be divisible by ulysses_sp_size {ulysses_sp_size}"
    )
    assert num_key_value_heads % ulysses_sp_size == 0 or ulysses_sp_size % num_key_value_heads == 0, (
        f"num_key_value_heads {num_key_value_heads} must be divisible by ulysses_sp_size "
        f"{ulysses_sp_size} or vice versa. Upon ulysses_sp_size % num_key_value_heads == 0, "
        f"kv heads are repeated to ensure correctness."
    )

    if is_trl_available():
        from trl import AutoModelForCausalLMWithValueHead

        def state_dict(self, *args, **kwargs):
            return torch.nn.Module.state_dict(self, *args, **kwargs)

        AutoModelForCausalLMWithValueHead.state_dict = state_dict
        print("Monkey patch state_dict in AutoModelForCausalLMWithValueHead. ")

    if model.config.model_type == "qwen2_5_vl":
        if is_transformers_version_in_range(min_version="4.53.0"):
            # This patch path is explicitly unsupported on transformers >= 4.53.
            raise ValueError("Transformers 4.53 is not supported")
        from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
            Qwen2_5_VLFlashAttention2 as Qwen2_5_VLAttention,
        )

        if use_remove_padding or ulysses_sp_size > 1:
            from verl.models.transformers.qwen2_vl import ulysses_flash_attn_forward

            Qwen2_5_VLAttention.forward = ulysses_flash_attn_forward
            print("Monkey patch FlashAttention2.forward in Qwen2.5VL")

        if ulysses_sp_size > 1:
            # The decoder class that receives inputs_embeds differs by transformers version.
            if is_transformers_version_in_range(min_version="4.52.0"):
                from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLTextModel

                patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLTextModel)
            else:
                from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel

                patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLModel)

    elif model.config.model_type == "qwen2_vl":
        if is_transformers_version_in_range(min_version="4.53.0"):
            # This patch path is explicitly unsupported on transformers >= 4.53.
            raise ValueError("Transformers 4.53 is not supported")
        from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2 as Qwen2VLAttention

        if use_remove_padding or ulysses_sp_size > 1:
            from verl.models.transformers.qwen2_vl import ulysses_flash_attn_forward

            Qwen2VLAttention.forward = ulysses_flash_attn_forward
            print("Monkey patch FlashAttention2.forward in Qwen2VL")

        if ulysses_sp_size > 1:
            # The decoder class that receives inputs_embeds differs by transformers version.
            if is_transformers_version_in_range(min_version="4.52.0"):
                from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLTextModel

                patch_vlm_for_ulysses_input_slicing(Qwen2VLTextModel)
            else:
                from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel

                patch_vlm_for_ulysses_input_slicing(Qwen2VLModel)

    elif model.config.model_type == "kimi_vl":
        if use_remove_padding or ulysses_sp_size > 1:
            # Kimi-VL ships its modeling code with the checkpoint, so patch the classes in the model's own module.
            from verl.models.transformers.kimi_vl import _ulysses_flash_attn_forward

            module.DeepseekV3FlashAttention2.forward = _ulysses_flash_attn_forward
            print("Monkey patch FlashAttention2.forward in KimiVL")

        if ulysses_sp_size > 1:
            patch_vlm_for_ulysses_input_slicing(module.DeepseekV3ForCausalLM)

        if use_fused_kernels:
            print("Not support fused kernels for KimiVL")

        return

    if use_remove_padding or ulysses_sp_size > 1:
        if hasattr(module, "_flash_attention_forward"):
            module._flash_attention_forward = _ulysses_flash_attention_forward
            print(f"Monkey patch _flash_attention_forward in {model.__module__}")
        else:
            # Fall back to the shared flash-attention integration module when the model module
            # does not import _flash_attention_forward itself.
            from transformers.integrations import flash_attention

            flash_attention._flash_attention_forward = _ulysses_flash_attention_forward
            print(f"Monkey patch _flash_attention_forward in {flash_attention.__name__}")

    patch_forward_with_backends(model, use_fused_kernels=use_fused_kernels, fused_kernels_backend=fused_kernels_backend)

@lru_cache
def is_transformers_version_in_range(min_version: Optional[str] = None, max_version: Optional[str] = None) -> bool:
    """Return whether the installed ``transformers`` version lies in ``[min_version, max_version]``.

    Both bounds are inclusive and optional; a bound left as ``None`` is not checked. Both given
    bounds are always parsed, so an unparsable bound string raises even if the other check already fails.
    Results are memoized per argument combination by ``lru_cache``.

    Raises:
        ModuleNotFoundError: if the ``transformers`` distribution is not installed.
    """
    try:
        raw_version = importlib.metadata.version("transformers")
    except importlib.metadata.PackageNotFoundError as err:
        raise ModuleNotFoundError("The `transformers` package is not installed.") from err

    installed = version.parse(raw_version)
    above_min = min_version is None or version.parse(min_version) <= installed
    below_max = max_version is None or installed <= version.parse(max_version)
    return above_min and below_max
