import math
from typing import Union, Optional
from packaging import version

import torch.nn.functional as F

from megatron.core import __version__
from megatron.core.models.common.embeddings import (
    RotaryEmbedding,
    YarnRotaryEmbedding,
    _yarn_get_mscale,
)
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.spec_utils import build_module
from megatron.core.transformer.transformer_config import MLATransformerConfig
from megatron.core.transformer.multi_latent_attention import MLASelfAttentionSubmodules
from megatron.core.transformer.multi_latent_attention import MultiLatentAttention


def multi_latent_attention_init_lt_0_12(
    self,
    config: MLATransformerConfig,
    submodules: MLASelfAttentionSubmodules,
    layer_number: int,
    attn_mask_type: AttnMaskType,
    attention_type: str,
    cp_comm_type: Optional[str] = None,
) -> None:
    """Initialize a ``MultiLatentAttention`` layer (compatibility replacement).

    Skips ``MultiLatentAttention.__init__`` itself and runs the next initializer
    in the MRO. It then sets up the MLA-specific pieces: head sizes, softmax
    scale, the rotary embedding, core attention and the output projection.
    A few settings are gated on whether the installed Megatron-Core is
    >= 0.12.0.

    Args:
        config: MLA transformer config (head dims, RoPE settings, etc.).
        submodules: module specs; ``core_attention`` and ``linear_proj`` are
            built here, and the whole object is also passed to the base
            initializer.
        layer_number: forwarded to the base initializer.
        attn_mask_type: forwarded to the base initializer.
        attention_type: forwarded to the base initializer.
        cp_comm_type: forwarded to the core attention module; ``None`` by
            default.

    Raises:
        ValueError: if ``config.rope_type`` is neither ``"rope"`` nor ``"yarn"``.
        AssertionError: on Megatron-Core >= 0.12 when YaRN RoPE is combined
            with ``config.apply_rope_fusion``.
    """
    # Deliberately bypass MultiLatentAttention.__init__ and run its parent's.
    super(MultiLatentAttention, self).__init__(
        config=config,
        submodules=submodules,
        layer_number=layer_number,
        attention_type=attention_type,
        attn_mask_type=attn_mask_type,
    )

    # Input width of the output projection: v_head_dim for every head.
    self.query_projection_size = self.config.v_head_dim * self.config.num_attention_heads

    # Q/K head dim = non-positional part + rotary (positional) part.
    self.q_head_dim = self.config.qk_head_dim + self.config.qk_pos_emb_head_dim

    # Per-head key/value widths differ in MLA. NOTE(review): presumably these
    # override the base class's KV sizes (e.g. for inference caching); confirm.
    self.key_hidden_size = self.q_head_dim
    self.val_hidden_size = self.config.v_head_dim

    # NOTE(review): the >= 0.12 branches suggest this initializer is also used
    # on newer Megatron-Core despite its name; confirm with the patch call site.
    mcore_version_ge_0_12 = version.parse(__version__) >= version.parse('0.12.0')
    if mcore_version_ge_0_12:
        # Selective recompute of the MLA up-projection ("mla_up_proj").
        self.recompute_up_proj = (
            self.config.recompute_granularity == 'selective'
            and "mla_up_proj" in self.config.recompute_modules
        )
        self.qkv_up_checkpoint = None

    # softmax_scale = mscale**2 / sqrt(q_head_dim)
    mscale = _yarn_get_mscale(self.config.rotary_scaling_factor, self.config.mscale)
    self.softmax_scale = mscale * mscale / math.sqrt(self.q_head_dim)

    # The rotary embedding is sized by the positional slice only.
    if self.config.rope_type == "rope":
        self.rotary_pos_emb = RotaryEmbedding(
            self.config.qk_pos_emb_head_dim,
            rotary_percent=self.config.rotary_percent,
            rotary_base=self.config.rotary_base,
        )
    elif self.config.rope_type == "yarn":
        if mcore_version_ge_0_12:
            assert not self.config.apply_rope_fusion, "MLA Yarn RoPE does not support RoPE fusion"
        self.rotary_pos_emb = YarnRotaryEmbedding(
            self.config.qk_pos_emb_head_dim,
            rotary_base=self.config.rotary_base,
            scaling_factor=self.config.rotary_scaling_factor,
            original_max_position_embeddings=self.config.max_position_embeddings,
            beta_fast=self.config.beta_fast,
            beta_slow=self.config.beta_slow,
            mscale=self.config.mscale,
            mscale_all_dim=self.config.mscale_all_dim,
        )
    else:
        raise ValueError(
            f"Unsupported RoPE type: {self.config.rope_type}, supported types are "
            "'rope' and 'yarn'"
        )

    # On < 0.12 the companion forward in this module zero-pads value heads up
    # to q_head_dim, so core attention is sized with q_head_dim value channels.
    value_channels = self.config.v_head_dim if mcore_version_ge_0_12 else self.q_head_dim

    self.core_attention = build_module(
        submodules.core_attention,
        config=self.config,
        layer_number=self.layer_number,
        attn_mask_type=self.attn_mask_type,
        attention_type=self.attention_type,
        softmax_scale=self.softmax_scale,
        k_channels=self.q_head_dim,
        v_channels=value_channels,
        cp_comm_type=cp_comm_type,
    )

    # Output projection; its input arrives already partitioned across ranks.
    self.linear_proj = build_module(
        submodules.linear_proj,
        self.query_projection_size,
        self.config.hidden_size,
        config=self.config,
        init_method=self.config.output_layer_init_method,
        bias=self.config.add_bias_linear,
        input_is_parallel=True,
        skip_bias_add=True,
        is_expert=False,
        tp_comm_buffer_name='proj',
    )


def multi_latent_attention_forward_lt_0_12(
    self,
    hidden_states,
    attention_mask,
    key_value_states=None,
    inference_params=None,
    rotary_pos_emb=None,
    rotary_pos_cos=None,
    rotary_pos_sin=None,
    attention_bias=None,
    packed_seq_params=None,
    position_ids=None,
    sequence_len_offset=None,
):
    """Forward pass for multi-latent attention (Megatron-Core < 0.12 path).

    When the query/key head dim (``self.q_head_dim``) differs from the value
    head dim (``config.v_head_dim``):

    - ``value`` is zero-padded on its last axis up to ``q_head_dim`` before
      core attention.
    - After core attention, the padding is sliced off every head.

    Args:
        hidden_states: layer input; forwarded to ``get_query_key_value_tensors``.
        attention_mask: forwarded to core attention.
        key_value_states: forwarded to ``get_query_key_value_tensors``.
        inference_params: forwarded to QKV construction and KV adjustment.
        rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, attention_bias: must be
            ``None``; MLA uses its own ``self.rotary_pos_emb``.
        packed_seq_params: packed-sequence metadata; when given, the attention
            output is reshaped to ``(dim0, 1, -1)`` before the projection.
        position_ids: forwarded to ``get_query_key_value_tensors``.
        sequence_len_offset: accepted for signature compatibility; unused.

    Returns:
        ``(output, bias)`` exactly as returned by ``self.linear_proj``.
    """
    assert rotary_pos_emb is None, "Rotary position embeddings should not be passed into MLA."
    assert attention_bias is None, "Attention bias should not be passed into MLA."
    assert (
        rotary_pos_cos is None and rotary_pos_sin is None
    ), "MLA does not support Flash Decoding"

    # Query, key and value (self- or cross-attention, per key_value_states).
    query, key, value = self.get_query_key_value_tensors(
        hidden_states,
        key_value_states,
        position_ids,
        packed_seq_params,
        inference_params=inference_params,
    )

    # Adjust key/value for inference; rotary embedding is already applied.
    query, key, value, _, attn_mask_type = self._adjust_key_value_for_inference(
        inference_params, query, key, value, rotary_pos_emb=None
    )

    needs_value_padding = self.q_head_dim != self.config.v_head_dim
    if needs_value_padding:
        # Pad value heads so that K and V have the same head dim.
        value = F.pad(value, [0, self.q_head_dim - self.config.v_head_dim])

    # Core attention (optionally activation-checkpointed during training).
    if self.checkpoint_core_attention and self.training:
        core_attn_out = self._checkpointed_attention_forward(
            query, key, value, attention_mask, packed_seq_params=packed_seq_params
        )
    else:
        core_attn_out = self.core_attention(
            query,
            key,
            value,
            attention_mask,
            packed_seq_params=packed_seq_params,
            attn_mask_type=attn_mask_type,
        )

    if needs_value_padding:
        # The last axis holds q_head_dim channels for each head held by *this*
        # rank. linear_proj is built with input_is_parallel=True, so under tensor
        # parallelism that is only a share of config.num_attention_heads.
        # Derive the head count and the leading dims from the tensor itself
        # instead of assuming the global head count.
        lead_shape = core_attn_out.shape[:-1]
        per_head = core_attn_out.reshape(*lead_shape, -1, self.q_head_dim)
        core_attn_out = per_head[..., : self.config.v_head_dim].reshape(*lead_shape, -1)

    if packed_seq_params is not None:
        # Match the unpacked output layout: (t, ...) -> (t, 1, -1), where the
        # size-1 axis is a dummy batch dimension.
        core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1)

    output, bias = self.linear_proj(core_attn_out)
    return output, bias
