import torch.distributed as dist
import os

def world_info_from_env():
    """Return the distributed world size advertised by the launcher environment.

    The environment variables below are checked in order. The first one that
    is present with a non-blank value wins:

    ``WORLD_SIZE``, ``PMI_SIZE``, ``SLURM_NTASKS``, ``OMPI_COMM_WORLD_SIZE``

    A variable that exists but is empty or whitespace-only is treated as
    unset, so the lookup falls through to the next source instead of crashing
    on ``int("")``.

    Despite the name, only the world size is read here; local and global
    rank are not inspected.

    Returns:
        The world size as an ``int``, or ``1`` when none of the variables is
        set (single-process run).

    Raises:
        ValueError: if the first non-blank variable is not an integer.
    """
    for var in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'):
        value = os.environ.get(var, '')
        # Blank values are common when a launcher script exports a variable
        # unconditionally; skip them rather than failing in int().
        if value.strip():
            return int(value)
    return 1


def get_rank() -> int:
    """Return this process's rank in the default ``torch.distributed`` group.

    When ``torch.distributed`` is not available in this build, or the default
    process group has not been initialised yet, ``0`` is returned. This keeps
    the helper safe to call from single-process code.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0

# def dist_init(config):
#     dist.init_process_group(backend=config.train_dist_backend,
#                             init_method=config.train_dist_init_method,
#                             rank=config.train_dist_node_rank,
#                             world_size=config.train_dist_world_size)
#     torch.cuda.set_device(config.train_dist_local_rank)