from typing import Optional

from packaging import version
from torch import Tensor

from megatron.core import __version__
from megatron.core import tensor_parallel


def language_model_embedding_forward(
    self,
    input_ids: Tensor,
    position_ids: Tensor,
    tokentype_ids: Optional[Tensor] = None,
) -> Tensor:
    """Forward pass of the embedding module.

    Intended to be bound as the ``forward`` of an embedding module: ``self``
    must provide ``word_embeddings``, ``position_embeddings``,
    ``tokentype_embeddings``, ``embedding_dropout``, ``config`` and the
    boolean switches read below.

    Args:
        input_ids (Tensor): The input tokens
        position_ids (Tensor): The position id's used to calculate position
            embeddings. Only read when ``self.add_position_embedding`` is set.
        tokentype_ids (Tensor, optional): The token type ids, looked up in
            ``self.tokentype_embeddings``. Used when args.bert_binary_head is
            set to True. Defaults to None

    Returns:
        Tensor: The output embeddings

    Raises:
        AssertionError: If ``tokentype_ids`` is given but the module has no
            token-type embedding table, or the module has one but no
            ``tokentype_ids`` were passed.
    """
    word_embeddings = self.word_embeddings(input_ids)

    # Optional scaling of the word embeddings. The attribute may be missing
    # altogether or present but unset (None); in both cases scaling is skipped
    # instead of failing on `Tensor * None`.
    embed_scale = getattr(self.config, 'embed_scale', None)
    if embed_scale is not None:
        word_embeddings = word_embeddings * embed_scale

    if self.add_position_embedding:
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = word_embeddings + position_embeddings
    else:
        embeddings = word_embeddings

    if not self.reduce_scatter_embeddings:
        # Swap the first two dims once here so later code needs no explicit
        # transposes (presumably [b, s, h] -> [s, b, h] -- confirm layout).
        embeddings = embeddings.transpose(0, 1).contiguous()

    if tokentype_ids is not None:
        assert self.tokentype_embeddings is not None
        # Bring the token-type embedding into the same dim order as
        # `embeddings` (first two dims swapped) so the two can be added.
        tokentype_embedding = self.tokentype_embeddings(tokentype_ids).permute(1, 0, 2)
        embeddings = embeddings + tokentype_embedding
    else:
        assert self.tokentype_embeddings is None

    # If the fp32 residual connection flag is set, convert to float32.
    if self.config.fp32_residual_connection:
        embeddings = embeddings.float()

    # Dropout.
    if self.config.sequence_parallel:
        if not self.reduce_scatter_embeddings and self.scatter_to_sequence_parallel:
            # From megatron.core 0.13.0 on, the process group is passed
            # explicitly; older versions are called without it.
            extra_args = {}
            if version.parse(__version__) >= version.parse('0.13.0'):
                extra_args["group"] = self.tp_group
            embeddings = tensor_parallel.scatter_to_sequence_parallel_region(
                embeddings, **extra_args
            )

        # NOTE(review): presumably the scatter returns a view that keeps the
        # full tensor alive, and the optional clone lets it be freed at a small
        # runtime cost -- confirm against the scatter implementation.
        if self.config.clone_scatter_output_in_embedding and self.scatter_to_sequence_parallel:
            embeddings = embeddings.clone()

        # Run dropout under the forked CUDA RNG tracker state.
        with tensor_parallel.get_cuda_rng_tracker().fork():
            embeddings = self.embedding_dropout(embeddings)
    else:
        embeddings = self.embedding_dropout(embeddings)
    return embeddings
