              
                                                      
                       

from typing import Optional

import torch
from torch import Tensor

from megatron.core import parallel_state
from megatron.core.utils import is_te_min_version
from megatron.core.models.common.embeddings.rope_utils import (
    _rotate_half,
    _apply_rotary_pos_emb_bshd,
    _apply_rotary_pos_emb_thd,
)

                                                                                                   
                                                                              
try:
    from apex.transformer.functional import fused_apply_rotary_pos_emb
except ImportError:
    try:
        from megatron.core.extensions.transformer_engine import fused_apply_rotary_pos_emb
    except:
        fused_apply_rotary_pos_emb = None

try:
    from megatron.core.extensions.transformer_engine import fused_apply_rotary_pos_emb_thd
except ImportError:
    try:
        from apex.transformer.functional import fused_apply_rotary_pos_emb_thd
    except ImportError:
        fused_apply_rotary_pos_emb_thd = None

try:
    from flash_attn.layers.rotary import apply_rotary_emb as apply_rotary_emb_flash
except ImportError:
    apply_rotary_emb_flash = None

from gpatch.core.transformer.transformer_config import GpatchTransformerConfig


def apply_rotary_pos_emb_bshd(
    t: Tensor,
    freqs: Tensor,
    rotary_interleaved: bool = False,
    mscale: float = 1.0,
    px_rope_variant: str = "rope",
) -> Tensor:
    """Apply rotary positional embedding to input tensor T.

    check https://kexue.fm/archives/8265 for detailed formulas

    Only the first ``freqs.shape[-1]`` channels of ``t`` are rotated; any
    remaining channels are passed through unchanged and re-appended at the end.

    Args:
        t (Tensor): Input tensor T is of shape [seq_length, ... , dim]
        freqs (Tensor): Rotary Positional embedding tensor freq is of shape [seq_length, ..., dim]
        rotary_interleaved (bool): Forwarded to ``_rotate_half`` to select the
            rotation layout.
        mscale (float): Multiplier applied to both cos(freqs) and sin(freqs).
        px_rope_variant (str): When ``"yarn_with_flip"``, the rotated channels are
            de-interleaved (even indices first, then odd indices) before the
            rotation is applied; any other value leaves them as-is.

    Returns:
        Tensor: The input tensor after applying RoPE
    """
    rot_dim = freqs.shape[-1]

    # Split off the channels that receive RoPE; the rest pass through untouched.
    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]

    if px_rope_variant == "yarn_with_flip":
        # Reorder [x0, x1, x2, x3, ...] -> [x0, x2, ..., x1, x3, ...] by viewing the
        # last dim as (pairs, 2) and swapping those two axes. Only the last dim is
        # touched, so this works for any leading rank (previously exactly 4 dims
        # were required); an odd rot_dim still raises inside reshape.
        *lead, hn = t.shape
        t = t.reshape(*lead, hn // 2, 2).transpose(-1, -2).reshape(*lead, hn).contiguous()

    # cos/sin are evaluated in freqs' dtype, scaled by mscale, then cast to t's dtype.
    cos_ = (torch.cos(freqs) * mscale).to(t.dtype)
    sin_ = (torch.sin(freqs) * mscale).to(t.dtype)

    t = (t * cos_) + (_rotate_half(t, rotary_interleaved) * sin_)
    return torch.cat((t, t_pass), dim=-1)


def apply_rotary_pos_emb_thd_packed_freqs(
    t: Tensor,
    cu_seqlens: Tensor,
    freqs: Tensor,
    rotary_interleaved: bool = False,
    mscale: float = 1.0,
    px_rope_variant: str = "rope",
) -> Tensor:
    """A baseline implementation of applying RoPE for `thd` format with pre-packed freqs.

    The packed tensor is not split by ``cu_seqlens``. The whole of ``t`` is
    rotated in one call to ``apply_rotary_pos_emb_bshd``, with ``freqs``
    broadcast directly against it.

    Args:
        t (Tensor): Input tensor T is of shape [t, h, d]
        cu_seqlens (Tensor): Cumulative sum of sequence lengths in a batch for `t`,
            with shape [b + 1] and dtype torch.int32. Accepted for signature
            parity with the other `thd` path; not used by this function.
        freqs (Tensor): Rotary positional embedding tensor. It must broadcast
            against ``t.unsqueeze(1)``, i.e. shape [t, 1, h, d]. Presumably it is
            [t, 1, 1, d] with one row per packed token — TODO confirm with callers.
        rotary_interleaved (bool): Rotation layout, forwarded to the bshd routine.
        mscale (float): Multiplier on cos/sin of ``freqs``, forwarded.
        px_rope_variant (str): RoPE variant, forwarded to the bshd routine.

    Returns:
        Tensor: Shape [t, h, d]. The input tensor after applying RoPE.
    """
    orig_dtype = t.dtype
    # Rotate in float32 (the bshd routine casts cos/sin to t's dtype) and cast back.
    output = apply_rotary_pos_emb_bshd(
        t.float().unsqueeze(1),
        freqs,
        # Previously accepted but dropped, so the interleaved layout was ignored.
        rotary_interleaved=rotary_interleaved,
        mscale=mscale,
        px_rope_variant=px_rope_variant,
    ).squeeze(1)
    return output.to(orig_dtype)


def apply_rotary_pos_emb(
    t: Tensor,
    freqs: Tensor,
    config: GpatchTransformerConfig,
    cu_seqlens: Optional[Tensor] = None,
    mscale: float = 1.0,
):
    """
    Reroute to the appropriate apply_rotary_pos_emb function depending on
    fused/unfused kernels, or bshd (conventional) / thd (packed seq) format

    Args:
        t (Tensor): Input tensor to rotate.
        freqs (Tensor): Rotary positional embedding frequencies.
        config (GpatchTransformerConfig): Supplies ``apply_rope_fusion``,
            ``rotary_interleaved``, ``multi_latent_attention`` and ``packed_freqs``.
        cu_seqlens (Optional[Tensor]): Cumulative sequence lengths. ``None``
            selects the bshd path; otherwise the thd (packed) path is used.
        mscale (float): Multiplier on cos/sin of ``freqs``. Only the unfused
            paths receive it; the fused kernels are called without it.

    Returns:
        Tensor: ``t`` with rotary embedding applied.
    """

    if config.apply_rope_fusion:
        if cu_seqlens is None:
            assert fused_apply_rotary_pos_emb is not None, "apply_rope_fusion is not available."
            return fused_apply_rotary_pos_emb(t, freqs, transpose_output_memory=True)
        else:
            assert fused_apply_rotary_pos_emb_thd is not None, "apply_rope_fusion is not available."
            cp_size = parallel_state.get_context_parallel_world_size()
            if cp_size > 1:
                if not is_te_min_version("1.11.0", check_equality=False):
                    raise ValueError("Only TE >= 1.12 supports RoPE fusion for THD format with CP.")
                return fused_apply_rotary_pos_emb_thd(
                    t,
                    cu_seqlens,
                    freqs,
                    cp_size=cp_size,
                    cp_rank=parallel_state.get_context_parallel_rank(),
                )
            else:
                return fused_apply_rotary_pos_emb_thd(t, cu_seqlens, freqs)
    else:
        if cu_seqlens is None:
            return _apply_rotary_pos_emb_bshd(
                t,
                freqs,
                rotary_interleaved=config.rotary_interleaved,
                multi_latent_attention=config.multi_latent_attention,
                mscale=mscale,
            )
        else:
            if config.packed_freqs:
                # NOTE(review): multi_latent_attention is not forwarded on this path;
                # the packed-freqs routine exposes px_rope_variant instead.
                return apply_rotary_pos_emb_thd_packed_freqs(
                    t,
                    cu_seqlens,
                    freqs,
                    rotary_interleaved=config.rotary_interleaved,
                    # Forward mscale like the other unfused branches do; it was
                    # previously dropped here, losing the scaling for packed freqs.
                    mscale=mscale,
                )
            return _apply_rotary_pos_emb_thd(
                t,
                cu_seqlens,
                freqs,
                rotary_interleaved=config.rotary_interleaved,
                multi_latent_attention=config.multi_latent_attention,
                mscale=mscale,
            )
