

import torch
import torch.distributed

import vllm.model_executor.parallel_utils.parallel_state as ps


# Tensor model parallel group used by this rank for inference. Stays ``None``
# until initialize_model_parallel_from_megatron() assigns it; read back through
# get_tensor_model_parallel_group().
_TENSOR_MODEL_PARALLEL_GROUP = None


# Micro data parallel group containing this rank. Stays ``None`` until
# initialize_model_parallel_from_megatron() assigns it; read back through
# get_micro_data_parallel_group().
_MICRO_DATA_PARALLEL_GROUP = None


def initialize_model_parallel_from_megatron(
        tensor_model_parallel_size=None
) -> None:
    """Create the inference process groups from Megatron's tensor-parallel layout.

    Two kinds of groups are set up, in this order (every rank issues the same
    sequence of ``new_group`` calls):

    1. ``'micro_dp'`` groups: consecutive blocks of
       ``train_tp // tensor_model_parallel_size`` global ranks. The block that
       contains the calling rank is stored in ``_MICRO_DATA_PARALLEL_GROUP``.
    2. The inference tensor-parallel group, stored both in this module's
       ``_TENSOR_MODEL_PARALLEL_GROUP`` and in vLLM's
       ``ps._TENSOR_MODEL_PARALLEL_GROUP``. When the requested size equals
       Megatron's tensor-parallel world size, Megatron's own group is reused.
       Otherwise every block of ``train_tp`` ranks is split into strided
       ``'infer_tp'`` groups.

    Example with 8 ranks, Megatron TP = 4 and ``tensor_model_parallel_size=2``:
    micro-dp groups are ``[0, 1] [2, 3] [4, 5] [6, 7]`` and infer-tp groups are
    ``[0, 2] [1, 3] [4, 6] [5, 7]``.

    Args:
        tensor_model_parallel_size: Tensor-parallel size to use for inference.
            ``None`` means "same as Megatron's tensor-parallel world size".
            Must be an ``int`` that is no larger than, and evenly divides,
            Megatron's tensor-parallel world size.

    Raises:
        AssertionError: If ``torch.distributed`` is not initialized, the size
            argument is invalid, or any of the groups is already initialized.
    """
    from megatron.core import parallel_state as mpu
    from megatron.distributed import new_group

    global _TENSOR_MODEL_PARALLEL_GROUP
    global _MICRO_DATA_PARALLEL_GROUP

    assert torch.distributed.is_initialized()

    if tensor_model_parallel_size is None:
        tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
    else:
        assert isinstance(tensor_model_parallel_size, int)

    assert ps._TENSOR_MODEL_PARALLEL_GROUP is None, ("tensor model parallel group is already initialized")

    # Query Megatron once; every later use refers to the same training TP size.
    train_tp = mpu.get_tensor_model_parallel_world_size()

    assert tensor_model_parallel_size <= train_tp, 'Not implemented for infer_tp > train_tp'
    assert train_tp % tensor_model_parallel_size == 0

    # Number of inference TP groups carved out of one training TP group.
    micro_dp_size = train_tp // tensor_model_parallel_size

    world_size: int = torch.distributed.get_world_size()
    num_micro_dp_groups = world_size // micro_dp_size
    rank = torch.distributed.get_rank()

    assert _MICRO_DATA_PARALLEL_GROUP is None, ("micro data parallel group is already initialized")
    for group_idx in range(num_micro_dp_groups):
        ranks = range(group_idx * micro_dp_size, (group_idx + 1) * micro_dp_size)
        # new_group is called for every group on every rank, not only for the
        # group the caller belongs to.
        group = new_group(rank=rank, ranks=ranks, group_type='micro_dp')
        if rank in ranks:
            _MICRO_DATA_PARALLEL_GROUP = group

    if tensor_model_parallel_size == train_tp:
        # Same TP size for training and inference: share Megatron's group.
        megatron_tp_group = mpu.get_tensor_model_parallel_group()
        ps._TENSOR_MODEL_PARALLEL_GROUP = megatron_tp_group
        _TENSOR_MODEL_PARALLEL_GROUP = megatron_tp_group
        return

    # Inference TP is smaller than training TP: split each training TP block
    # into `micro_dp_size` strided groups.
    infer_groups_per_train_tp = micro_dp_size
    num_infer_tp_groups = world_size // tensor_model_parallel_size
    assert _TENSOR_MODEL_PARALLEL_GROUP is None, ("tensor model parallel group is already initialized")
    for train_group_idx in range(num_infer_tp_groups // infer_groups_per_train_tp):
        start = train_tp * train_group_idx
        end = train_tp * (train_group_idx + 1)
        for offset in range(infer_groups_per_train_tp):
            # Every `infer_groups_per_train_tp`-th rank of the block, shifted
            # by `offset`. A comprehension avoids reusing the outer loop index.
            ranks = [r + offset for r in range(start, end, infer_groups_per_train_tp)]

            group = new_group(rank=rank, ranks=ranks, group_type='infer_tp')
            if rank in ranks:
                _TENSOR_MODEL_PARALLEL_GROUP = group
                ps._TENSOR_MODEL_PARALLEL_GROUP = group




def get_tensor_model_parallel_group():
    """Return the tensor model parallel group recorded in this module.

    Raises:
        AssertionError: If the module-level group is still ``None``.
    """
    tp_group = _TENSOR_MODEL_PARALLEL_GROUP
    assert tp_group is not None, "tensor model parallel group is not initialized"
    return tp_group


def get_tensor_model_parallel_world_size():
    """Return the world size of the tensor model parallel group."""
    tp_group = get_tensor_model_parallel_group()
    return torch.distributed.get_world_size(group=tp_group)


def get_tensor_model_parallel_rank():
    """Return the calling process's rank within the tensor model parallel group."""
    tp_group = get_tensor_model_parallel_group()
    return torch.distributed.get_rank(group=tp_group)


def get_tensor_model_parallel_src_rank():
    """Return the global rank of the first member of the caller's tensor model parallel group.

    The rank is looked up in the group itself rather than derived as
    ``(global_rank // tp_size) * tp_size``. That arithmetic is only correct
    when every group is a contiguous, aligned block of global ranks. The
    ``'infer_tp'`` groups created by ``initialize_model_parallel_from_megatron``
    are strided (e.g. ``[0, 2]`` and ``[1, 3]``), so for rank 1 the arithmetic
    would yield rank 0, which is not a member of its group. For contiguous
    groups both approaches give the same answer.

    Raises:
        AssertionError: If the tensor model parallel group is not initialized.
    """
    tp_group = get_tensor_model_parallel_group()
    # Group rank 0 is the group's first member; translate it to a global rank.
    return torch.distributed.get_global_rank(tp_group, 0)





def get_micro_data_parallel_group():
    """Return the micro data parallel group recorded in this module.

    Raises:
        AssertionError: If the module-level group is still ``None``.
    """
    micro_dp_group = _MICRO_DATA_PARALLEL_GROUP
    assert micro_dp_group is not None
    return micro_dp_group


def get_micro_data_parallel_world_size():
    """Return the world size of the micro data parallel group."""
    micro_dp_group = get_micro_data_parallel_group()
    return torch.distributed.get_world_size(group=micro_dp_group)


def get_micro_data_parallel_rank():
    """Return the calling process's rank within the micro data parallel group."""
    micro_dp_group = get_micro_data_parallel_group()
    return torch.distributed.get_rank(group=micro_dp_group)
