import torch.distributed as dist

# Process groups populated by init_seq_dp_group(); each stays None until that
# function has run in this process.
_DATA_PARALLEL_GROUP = None  # data-parallel group (default backend, or WORLD)
_CPU_DATA_PARALLEL_GROUP = None  # gloo-backed data-parallel group
_SEQUENCE_PARALLEL_GROUP = None  # sequence-parallel group (default backend)
_CPU_SEQUENCE_PARALLEL_GROUP = None  # gloo-backed sequence-parallel group
# Name -> group registry; init_seq_dp_group() fills "dp", "dp_cpu", "sp", "sp_cpu".
DistGroups = {}


def _new_groups_for_rank(rank_lists, global_rank):
    """Create a default-backend and a gloo group per rank list; return the pair holding ``global_rank``.

    ``dist.new_group`` must be entered by every process, in the same order,
    even for groups the process will not belong to, so every rank makes every
    call here and only keeps the pair it is a member of.

    Args:
        rank_lists: Sequence of global-rank lists, one per group to create.
        global_rank: Global rank of the calling process.

    Returns:
        ``(group, cpu_group)`` for the list containing ``global_rank``, or
        ``(None, None)`` if no list contains it.
    """
    mine = (None, None)
    for ranks in rank_lists:
        group = dist.new_group(ranks)
        cpu_group = dist.new_group(ranks, backend="gloo")
        if global_rank in ranks:
            mine = (group, cpu_group)
    return mine


def init_seq_dp_group(sp_size, is_all_dp=False, *, verbose=True):
    """Create the data-parallel (DP) and sequence-parallel (SP) process groups.

    With ``W`` ranks and ``dp_size = W // sp_size``:

    * SP groups are contiguous blocks of ``sp_size`` ranks:
      ``[i * sp_size, ..., (i + 1) * sp_size - 1]`` for ``i`` in ``range(dp_size)``.
    * DP groups stride across those blocks: ``[i, i + sp_size, i + 2 * sp_size, ...]``
      for ``i`` in ``range(sp_size)``. If ``is_all_dp`` is set, the DP group is
      ``dist.group.WORLD`` instead.

    Each group also gets a gloo-backed twin (the ``*_cpu`` entries). The groups
    containing the calling rank are stored in the module-level globals and in
    ``DistGroups`` under ``"dp"``, ``"dp_cpu"``, ``"sp"`` and ``"sp_cpu"``.

    Every rank must call this function, because ``dist.new_group`` requires
    all processes to participate in the same order.

    Args:
        sp_size: Ranks per sequence-parallel group; must be a positive divisor
            of the world size.
        is_all_dp: If True, use the whole world as the DP group.
        verbose: If True (the default, preserving the previous behaviour),
            print each ``DistGroups`` entry with its member global ranks.

    Raises:
        ValueError: If ``sp_size`` is not a positive divisor of the world size.
    """
    global _DATA_PARALLEL_GROUP
    global _CPU_DATA_PARALLEL_GROUP
    global _SEQUENCE_PARALLEL_GROUP
    global _CPU_SEQUENCE_PARALLEL_GROUP

    whole_world_size = dist.get_world_size()
    # Explicit check instead of ``assert``: survives ``python -O`` and also
    # rejects sp_size <= 0 (0 used to hit ZeroDivisionError, and a negative
    # divisor passed the modulo test but silently created no DP groups).
    if sp_size <= 0 or whole_world_size % sp_size != 0:
        raise ValueError(
            f"sp_size must be a positive divisor of the world size "
            f"({whole_world_size}), got {sp_size}"
        )
    dp_size = whole_world_size // sp_size
    global_rank = dist.get_rank()

    if is_all_dp:
        _DATA_PARALLEL_GROUP = dist.group.WORLD
        _CPU_DATA_PARALLEL_GROUP = dist.new_group(list(range(whole_world_size)), backend="gloo")
    else:
        # Strided lists: same ranks, same order as [i + j * sp_size for j in range(dp_size)].
        dp_rank_lists = [list(range(i, whole_world_size, sp_size)) for i in range(sp_size)]
        _DATA_PARALLEL_GROUP, _CPU_DATA_PARALLEL_GROUP = _new_groups_for_rank(dp_rank_lists, global_rank)
    DistGroups["dp"] = _DATA_PARALLEL_GROUP
    DistGroups["dp_cpu"] = _CPU_DATA_PARALLEL_GROUP

    sp_rank_lists = [list(range(i * sp_size, (i + 1) * sp_size)) for i in range(dp_size)]
    _SEQUENCE_PARALLEL_GROUP, _CPU_SEQUENCE_PARALLEL_GROUP = _new_groups_for_rank(sp_rank_lists, global_rank)
    DistGroups["sp"] = _SEQUENCE_PARALLEL_GROUP
    DistGroups["sp_cpu"] = _CPU_SEQUENCE_PARALLEL_GROUP

    if verbose:
        for name, group in DistGroups.items():
            print(f"On Rank{global_rank}-{name}: {get_all_ranks(group)}")


def get_all_ranks(group):
    """Return the global ranks of every member of ``group``, in group-rank order.

    Args:
        group: A process group from ``dist.new_group``, ``dist.group.WORLD``,
            or ``None`` (treated as the default world group).

    Returns:
        list[int]: ``result[i]`` is the global rank of group rank ``i``. Empty
        when the calling process is not a member of ``group``.
    """
    # Bound the lookup by the group size instead of probing until an exception:
    # the rank-mapping helper raises RuntimeError on some torch releases and
    # ValueError on others, so the old ``except RuntimeError`` probe could crash.
    # get_world_size returns -1 for non-members, giving an empty range.
    group_size = dist.get_world_size(group)
    if group is None or group is dist.group.WORLD:
        # The world group maps group rank i to global rank i. Depending on the
        # torch release, the rank-mapping helper either rejects WORLD (the old
        # probe returned []) or echoes any rank back (the old probe never ended).
        return list(range(group_size))
    c10d = dist.distributed_c10d
    # Prefer the public API; fall back to the private helper on older torch.
    to_global = getattr(c10d, "get_global_rank", None) or c10d._get_global_rank
    return [to_global(group, group_rank) for group_rank in range(group_size)]
