import torch
import numpy as np
from datetime import timedelta
from functools import partial
from itertools import cycle
from datetime import timedelta
from functools import partial
from itertools import cycle

from megatron.core import mpu
from megatron.core.parallel_state import (RankGenerator, create_group, default_embedding_ranks,
                                          default_position_embedding_ranks)
from megatron.training import get_args

                                                 
                                                
# Gloo (CPU) process groups populated by init_pg():
#   _GROUP_GLOO                        -- Gloo group over every rank of the world.
#   _MODEL_PARALLEL_GROUP_GLOO         -- this rank's Gloo tp-pp (model-parallel) group.
#   _MODEL_PARALLEL_GLOBAL_RANKS_GLOO  -- global ranks contained in that group.
_MODEL_PARALLEL_GROUP_GLOO = _MODEL_PARALLEL_GLOBAL_RANKS_GLOO = None
_GROUP_GLOO = None

# Model + context parallel groups (plain and expert-parallel variants) and
# their global rank lists, read by the get_model_and_context_parallel_* helpers.
# NOTE(review): nothing in the visible code assigns these; confirm where they are set.
_MODEL_AND_CONTEXT_PARALLEL_GROUP = _MODEL_EXPERT_AND_CONTEXT_PARALLEL_GROUP = None
_MODEL_AND_CONTEXT_PARALLEL_GLOBAL_RANKS = _MODEL_EXPERT_AND_CONTEXT_PARALLEL_GLOBAL_RANKS = None


def init_pg(distributed_timeout_minutes: int = 30):
    """Create the Gloo (CPU) process groups used by this module.

    Two things are built:

    * ``_GROUP_GLOO``: a Gloo group spanning every rank of the world
      (the default group used by ``cpu_barrier``).
    * ``_MODEL_PARALLEL_GROUP_GLOO`` / ``_MODEL_PARALLEL_GLOBAL_RANKS_GLOO``:
      the Gloo "tp-pp" group containing the calling rank, and its global ranks.

    Rank layouts come from megatron's ``RankGenerator``, driven by the parallel
    sizes in ``get_args()``. When an encoder is configured, encoder ranks are
    generated from offset 0 and decoder ranks start right after them.

    Args:
        distributed_timeout_minutes: timeout applied to the all-ranks Gloo group.
            NOTE(review): the model-parallel Gloo group uses
            ``args.distributed_timeout_minutes`` instead (unchanged from the
            original); confirm whether that divergence is intended.

    Raises:
        AssertionError: if torch.distributed is not initialized, or the
            model-parallel Gloo group was already created.
        RuntimeError: if the world size is not divisible by the combined
            encoder + decoder model-parallel size.
    """
    global _GROUP_GLOO
    global _MODEL_PARALLEL_GROUP_GLOO
    global _MODEL_PARALLEL_GLOBAL_RANKS_GLOO

    assert torch.distributed.is_initialized()
    # Validate before creating any group, so a repeated call fails fast instead
    # of first creating yet another all-ranks Gloo group.
    assert _MODEL_PARALLEL_GROUP_GLOO is None, 'model parallel group is already initialized'

    timeout = timedelta(minutes=distributed_timeout_minutes)
    world_size: int = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()

    # torch documents ``ranks`` as list[int]; pass plain Python ints, not numpy scalars.
    _GROUP_GLOO = torch.distributed.new_group(ranks=list(range(world_size)),
                                              timeout=timeout,
                                              backend='gloo')

    args = get_args()

    # Resolve the optional encoder parallel sizes.
    encoder_pipeline_model_parallel_size = args.encoder_pipeline_model_parallel_size
    if encoder_pipeline_model_parallel_size is None:
        encoder_pipeline_model_parallel_size = 0
    encoder_tensor_model_parallel_size = args.encoder_tensor_model_parallel_size
    if encoder_tensor_model_parallel_size == 0 and encoder_pipeline_model_parallel_size > 0:
        # An encoder with pipeline stages but no explicit TP size reuses the decoder TP size.
        encoder_tensor_model_parallel_size = args.tensor_model_parallel_size

    encoder_model_size = (encoder_tensor_model_parallel_size *
                          encoder_pipeline_model_parallel_size * args.context_parallel_size)
    decoder_model_size = (args.tensor_model_parallel_size * args.pipeline_model_parallel_size *
                          args.context_parallel_size)
    total_model_size = encoder_model_size + decoder_model_size
    if world_size % total_model_size != 0:
        raise RuntimeError(f'world_size ({world_size}) is not divisible by the total model '
                           f'parallel size ({total_model_size} = encoder {encoder_model_size} '
                           f'+ decoder {decoder_model_size})')
    data_parallel_size: int = world_size // total_model_size
    # Encoder ranks occupy [0, encoder_world_size); decoder ranks are offset past them.
    encoder_world_size = encoder_model_size * data_parallel_size
    order = 'tp-cp-ep-dp-pp' if not args.use_tp_pp_dp_mapping else 'tp-cp-ep-pp-dp'

    def generator_wrapper(group_type, is_expert=False, **kwargs):
        """Yield the global-rank list of every ``group_type`` group.

        Without an encoder, the decoder groups are yielded as-is. With an
        encoder, each yielded list is an encoder group followed by the matching
        decoder group (only supported for ``'tp-pp'``).
        """
        if is_expert:
            d_ranks = RankGenerator(
                tp=args.expert_tensor_parallel_size,
                ep=args.expert_model_parallel_size,
                dp=args.expert_data_parallel_size,
                pp=args.pipeline_model_parallel_size,
                cp=1,
                order=order,
                rank_offset=encoder_world_size,
            ).get_ranks(group_type, **kwargs)
        else:
            d_ranks = RankGenerator(
                tp=args.tensor_model_parallel_size,
                ep=1,
                dp=data_parallel_size,
                pp=args.pipeline_model_parallel_size,
                cp=args.context_parallel_size,
                order=order,
                rank_offset=encoder_world_size,
            ).get_ranks(group_type, **kwargs)

        if encoder_world_size <= 0:
            yield from d_ranks
            return

        e_ranks = RankGenerator(
            tp=encoder_tensor_model_parallel_size,
            ep=1,
            dp=data_parallel_size,
            pp=encoder_pipeline_model_parallel_size,
            cp=args.context_parallel_size,
            order=order,
            rank_offset=0,
        ).get_ranks(group_type, **kwargs)
        if group_type != 'tp-pp':
            # Previously this silently yielded nothing; fail loudly instead.
            raise NotImplementedError(
                f'group_type {group_type!r} is not supported when an encoder is configured')
        # Encoder and decoder yield one tp-pp group per (dp, cp) pair, so the
        # groups can be concatenated pairwise.
        assert len(e_ranks) == len(d_ranks)
        for encoder_group, decoder_group in zip(e_ranks, d_ranks):
            yield encoder_group + decoder_group

    model_parallel_timeout = timedelta(minutes=args.distributed_timeout_minutes)
    # Every rank creates every tp-pp group and keeps only the one it belongs to.
    for ranks in generator_wrapper('tp-pp'):
        group = create_group(
            ranks,
            timeout=model_parallel_timeout,
            backend="gloo",
            group_desc='_MODEL_PARALLEL_GROUP_GLOO',
        )
        if rank in ranks:
            _MODEL_PARALLEL_GROUP_GLOO = group
            _MODEL_PARALLEL_GLOBAL_RANKS_GLOO = ranks


def get_model_parallel_src_rank():
    """Return the lowest global rank in the caller's model-parallel (tp-pp) group.

    The world is viewed as a ``(pp, dp, cp, tp)`` grid of global ranks, with
    pipeline outermost and tensor innermost. Fixing the caller's data- and
    context-parallel coordinates leaves the ``(pp, tp)`` slice that forms its
    model-parallel group.

    The previous version reshaped to ``(pp, -1, tp)`` and indexed only by
    ``dp_rank``. With context parallelism > 1 the middle axis is ``dp * cp``,
    so that picked the wrong slice. Results are unchanged when ``cp == 1``.

    NOTE(review): assumes the default rank ordering; this does not account
    for ``args.use_tp_pp_dp_mapping``.
    """
    world_size = torch.distributed.get_world_size()
    tp_size = mpu.get_tensor_model_parallel_world_size()
    pp_size = mpu.get_pipeline_model_parallel_world_size()
    cp_size = mpu.get_context_parallel_world_size()
    all_ranks = np.arange(world_size).reshape(pp_size, -1, cp_size, tp_size)
    dp_rank = mpu.get_data_parallel_rank()
    cp_rank = mpu.get_context_parallel_rank()
    return all_ranks[:, dp_rank, cp_rank, :].min()


def get_tensor_and_data_parallel_src_rank():
    """Return the lowest global rank in the caller's pipeline stage.

    Global ranks are arranged into a ``(pp, -1, cp, tp)`` grid with the
    pipeline axis outermost. The caller's whole stage is the slice at its
    pipeline rank, and the smallest id in that slice is returned.
    """
    grid_shape = (
        mpu.get_pipeline_model_parallel_world_size(),
        -1,
        mpu.get_context_parallel_world_size(),
        mpu.get_tensor_model_parallel_world_size(),
    )
    rank_grid = np.arange(torch.distributed.get_world_size()).reshape(grid_shape)
    stage_ranks = rank_grid[mpu.get_pipeline_model_parallel_rank()]
    return stage_ranks.min()


def update_weights_gather_dst_rank():
    """Return the global rank that gathers weights for the caller.

    * Expert parallelism enabled (EP world size > 1): the caller's rank rounded
      down to a multiple of the tensor-parallel size.
    * Otherwise: the lowest global rank in the caller's pipeline stage, the same
      value as ``get_tensor_and_data_parallel_src_rank()``.
    """
    if mpu.get_expert_model_parallel_world_size() > 1:
        # NOTE(review): assumes tensor-parallel ranks are contiguous global ranks
        # (tp innermost), so flooring to a multiple of tp_size lands on TP rank 0.
        tp_size = mpu.get_tensor_model_parallel_world_size()
        return torch.distributed.get_rank() // tp_size * tp_size

    # Identical reshape-and-min as get_tensor_and_data_parallel_src_rank; reuse it
    # instead of duplicating the grid logic.
    return get_tensor_and_data_parallel_src_rank()


def get_model_and_context_parallel_group(with_expert_parallel=False):
    """Return the model + context parallel group the caller rank belongs to.

    Args:
        with_expert_parallel: if True, return the expert-parallel variant
            (``_MODEL_EXPERT_AND_CONTEXT_PARALLEL_GROUP``) instead.

    Raises:
        AssertionError: if the requested group has not been set.
    """
    if with_expert_parallel:
        assert _MODEL_EXPERT_AND_CONTEXT_PARALLEL_GROUP is not None, \
            'model, expert and context parallel group is not initialized'
        return _MODEL_EXPERT_AND_CONTEXT_PARALLEL_GROUP
    assert _MODEL_AND_CONTEXT_PARALLEL_GROUP is not None, \
        'model and context parallel group is not initialized'
    return _MODEL_AND_CONTEXT_PARALLEL_GROUP


def get_model_and_context_parallel_src_rank(with_expert_parallel=False):
    """Return the source rank of the caller's model + context parallel group.

    The source rank is the first entry of the group's stored global-rank list.

    Args:
        with_expert_parallel: if True, use the expert-parallel variant's rank
            list (``_MODEL_EXPERT_AND_CONTEXT_PARALLEL_GLOBAL_RANKS``).

    Raises:
        AssertionError: if the requested rank list has not been set.
    """
    if with_expert_parallel:
        assert _MODEL_EXPERT_AND_CONTEXT_PARALLEL_GLOBAL_RANKS is not None, \
            'model, expert and context parallel src rank is not initialized'
        return _MODEL_EXPERT_AND_CONTEXT_PARALLEL_GLOBAL_RANKS[0]
    assert _MODEL_AND_CONTEXT_PARALLEL_GLOBAL_RANKS is not None, \
        'model and context parallel src rank is not initialized'
    return _MODEL_AND_CONTEXT_PARALLEL_GLOBAL_RANKS[0]


def get_model_parallel_with_cp_src_rank():
    """Return the lowest global rank among the caller's model- and context-parallel peers.

    Global ranks are arranged into a ``(pp, -1, cp, tp)`` grid. The caller's
    data-parallel index on the second axis is fixed, and the minimum is taken
    over every pipeline, context and tensor position.
    """
    pp = mpu.get_pipeline_model_parallel_world_size()
    cp = mpu.get_context_parallel_world_size()
    tp = mpu.get_tensor_model_parallel_world_size()
    rank_grid = np.arange(torch.distributed.get_world_size()).reshape(pp, -1, cp, tp)
    peer_ranks = rank_grid[:, mpu.get_data_parallel_rank()]
    return peer_ranks.min()


def is_mp_head():
    """Return True on tensor-parallel rank 0 of the first pipeline stage."""
    if not mpu.is_pipeline_first_stage():
        return False
    return mpu.get_tensor_model_parallel_rank() == 0


def is_mp_and_cp_head():
    """Return True on TP rank 0 and CP rank 0 of the first pipeline stage."""
    if not mpu.is_pipeline_first_stage():
        return False
    if mpu.get_tensor_model_parallel_rank() != 0:
        return False
    return mpu.get_context_parallel_rank() == 0


def get_mp_and_cp_size():
    """Return the product of the tensor, pipeline and context parallel world sizes."""
    return (mpu.get_tensor_model_parallel_world_size()
            * mpu.get_pipeline_model_parallel_world_size()
            * mpu.get_context_parallel_world_size())


def is_tp_dp_and_cp_head(enable_expert_parallel=False):
    """Return True if the caller is a head rank across TP, DP and CP.

    Only TP rank 0 can be a head. Which of those ranks qualify depends on the flag:

    * ``enable_expert_parallel``: the caller's combined data+context parallel
      rank must be below the expert-parallel world size.
    * otherwise: the caller must be DP rank 0 and CP rank 0.
    """
    if mpu.get_tensor_model_parallel_rank() != 0:
        return False
    if enable_expert_parallel:
        dp_cp_rank = mpu.get_data_parallel_rank(with_context_parallel=True)
        return dp_cp_rank < mpu.get_expert_model_parallel_world_size()
    return mpu.get_data_parallel_rank() == 0 and mpu.get_context_parallel_rank() == 0


def is_update_weight_head(enable_expert_parallel=False):
    """Return True if the caller should act as a head rank for weight updates.

    * Without expert parallelism: TP rank 0, DP rank 0 and CP rank 0.
    * With expert parallelism and an expert TP size different from the dense
      TP size: expert-data-parallel rank 0.
    * With expert parallelism otherwise: TP rank 0 whose combined data+context
      parallel rank is below the expert-parallel world size.
    """
    if not enable_expert_parallel:
        return mpu.get_tensor_model_parallel_rank() == 0 \
            and mpu.get_data_parallel_rank() == 0 \
            and mpu.get_context_parallel_rank() == 0

    expert_tp_size = mpu.get_expert_tensor_parallel_world_size()
    if expert_tp_size != mpu.get_tensor_model_parallel_world_size():
        # Expert and dense TP layouts differ: pick heads along the expert DP axis.
        return mpu.get_expert_data_parallel_rank() == 0

    dp_cp_rank = mpu.get_data_parallel_rank(with_context_parallel=True)
    expert_parallel_size = mpu.get_expert_model_parallel_world_size()
    return mpu.get_tensor_model_parallel_rank() == 0 and dp_cp_rank < expert_parallel_size


def is_mp_cp_and_ep_head():
    """Placeholder for a model/context/expert-parallel head check.

    Raises:
        NotImplementedError: always; this check has not been implemented.
    """
    raise NotImplementedError('not implemented')


def get_mp_cp_and_ep_size():
    """Placeholder for the combined model/context/expert-parallel size.

    Raises:
        NotImplementedError: always; this size query has not been implemented.
    """
    raise NotImplementedError('not implemented')


def cpu_barrier(pg=None):
    """Run ``torch.distributed.barrier`` on ``pg``, or on ``_GROUP_GLOO`` when ``pg`` is None.

    ``_GROUP_GLOO`` is the all-ranks Gloo group created by ``init_pg``. If
    ``init_pg`` has not run it is still None, and torch then uses the default
    process group.
    """
    group = _GROUP_GLOO if pg is None else pg
    torch.distributed.barrier(group=group)


def get_model_parallel_group_gloo():
    """Return the Gloo model-parallel (tp-pp) group of the calling rank.

    Raises:
        AssertionError: if ``init_pg`` has not populated the group yet.
    """
    group = _MODEL_PARALLEL_GROUP_GLOO
    assert group is not None, 'model parallel gloo group is not initialized'
    return group


def get_model_parallel_src_rank_gloo():
    """Return the first global rank stored for the caller's Gloo model-parallel group.

    Raises:
        AssertionError: if ``init_pg`` has not recorded the group's ranks yet.
    """
    group_ranks = _MODEL_PARALLEL_GLOBAL_RANKS_GLOO
    assert group_ranks is not None, "Model parallel group is not initialized"
    return group_ranks[0]
