from typing import List, Optional, Union

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
try:
    from torch.distributed._tensor import DTensor, distribute_tensor
    HAVE_DTENSOR = True
except ImportError:
    HAVE_DTENSOR = False

from megatron.core import parallel_state
from megatron.core.distributed.finalize_model_grads import _unshard_if_dtensor, _reshard_if_dtensor
from megatron.core.transformer.moe.moe_utils import get_updated_expert_bias
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import get_attr_wrapped_model, get_model_config


def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
    """
    All-reduce word embedding grads.

    All-reduces the gradient of the shared embedding / output weight over the embedding
    process group, so that the word_embeddings parameters held on the first and last
    pipeline stages stay in sync.

    Args:
        model: Model chunks held by this rank. The last chunk is used on a last pipeline
            stage that is not also the first stage; chunk 0 is used otherwise.
        config: Transformer configuration. Not read by this function.
    """
    # Nothing to do unless this rank sits in an embedding group with more than one member.
    if not parallel_state.is_rank_in_embedding_group(ignore_virtual=True):
        return
    embedding_group_size = torch.distributed.get_world_size(parallel_state.get_embedding_group())
    if embedding_group_size <= 1:
        return

    # Short-circuit keeps the original query order: the last-stage check is only made
    # when this rank is not the first stage.
    last_stage_only = not parallel_state.is_pipeline_first_stage(
        ignore_virtual=True
    ) and parallel_state.is_pipeline_last_stage(ignore_virtual=True)
    chunk = model[-1] if last_stage_only else model[0]

    # presumably unwraps DDP/precision wrappers down to the module that exposes
    # `pre_process` -- verify against get_attr_wrapped_model.
    unwrapped = get_attr_wrapped_model(chunk, 'pre_process', return_model_obj=True)
    if not unwrapped.share_embeddings_and_output_weights:
        return

    shared_weight = unwrapped.shared_embedding_or_output_weight()
    # Prefer `main_grad` when the weight carries one, otherwise fall back to `.grad`.
    grad_name = "grad"
    if hasattr(shared_weight, "main_grad"):
        grad_name = "main_grad"
    stored_grad = getattr(shared_weight, grad_name)
    dense_grad = _unshard_if_dtensor(stored_grad)

    # No gradient to reduce on this rank.
    if dense_grad is None:
        return

    torch.distributed.all_reduce(dense_grad, group=parallel_state.get_embedding_group())
    # Store the reduced grad back under the same attribute it was read from.
    setattr(shared_weight, grad_name, _reshard_if_dtensor(dense_grad, stored_grad))

def _allreduce_layernorm_grads(model: List[torch.nn.Module], config: TransformerConfig):
    """
    All-reduce layernorm grads (for sequence parallelism).

    Gathers the gradient of every trainable parameter that is either flagged with a truthy
    ``sequence_parallel`` attribute or whose name contains ``q_layernorm`` / ``k_layernorm``,
    all-reduces them over the tensor-model-parallel group as one flattened buffer, and
    writes the reduced values back onto the parameters.

    Parameters whose gradient is ``None`` are skipped rather than raising
    ``AttributeError`` on ``None.data``.
    NOTE(review): this assumes a missing gradient is missing on every rank of the
    tensor-model-parallel group, so the flattened buffers have equal sizes -- confirm.

    Args:
        model: Model chunks held by this rank; all chunks are scanned.
        config: Transformer configuration; ``sequence_parallel`` and ``qk_layernorm``
            decide whether any reduction is performed.
    """
    # Only relevant with tensor parallelism and at least one of the two features enabled.
    if parallel_state.get_tensor_model_parallel_world_size() <= 1:
        return
    if not (config.sequence_parallel or config.qk_layernorm):
        return

    # One (param, grad attribute name, stored grad, dense grad buffer) record per
    # parameter to reduce, so the write-back loop does not have to recompute them.
    pending = []
    for model_chunk in model:
        for name, param in get_attr_wrapped_model(model_chunk, 'named_parameters')():
            if not param.requires_grad:
                continue
            if not (
                getattr(param, 'sequence_parallel', False)
                or 'q_layernorm' in name
                or 'k_layernorm' in name
            ):
                continue
            # Prefer `main_grad` when the parameter carries one, otherwise use `.grad`.
            grad_attr = "main_grad" if hasattr(param, "main_grad") else "grad"
            orig_grad = getattr(param, grad_attr)
            grad = _unshard_if_dtensor(orig_grad)
            if grad is None:
                # No gradient was produced for this parameter; nothing to reduce.
                continue
            pending.append((param, grad_attr, orig_grad, grad.data))

    if not pending:
        return

    grads = [buf for _, _, _, buf in pending]
    # Coalesce into a single buffer so only one collective is issued.
    coalesced = _flatten_dense_tensors(grads)
    torch.distributed.all_reduce(
        coalesced, group=parallel_state.get_tensor_model_parallel_group()
    )
    synced_grads = _unflatten_dense_tensors(coalesced, grads)
    for (param, grad_attr, orig_grad, buf), synced in zip(pending, synced_grads):
        # Copy the reduced values into the gathered grad buffer, then store it back
        # under the attribute it was read from.
        buf.copy_(synced)
        setattr(param, grad_attr, _reshard_if_dtensor(buf, orig_grad))