import os
import torch
import torch.distributed as dist
import builtins as __builtin__

def ddp_setup():
    """Initialise distributed training from launcher environment variables.

    If both ``RANK`` and ``WORLD_SIZE`` are set, this function:

    * binds the process to the CUDA device given by ``LOCAL_RANK``;
    * initialises the default ``nccl`` process group and waits on a barrier;
    * replaces the builtin ``print`` so only rank 0 prints.

    Any rank can still print by passing ``force=True``.

    If either variable is missing, no process group is created. The
    single-process defaults ``(0, 1, 0)`` are returned instead of failing.

    Returns:
        tuple[int, int, int]: ``(rank, world_size, local_gpu)``.

    Raises:
        KeyError: if ``RANK`` and ``WORLD_SIZE`` are set but ``LOCAL_RANK``
            is not.
    """
    # Only fall back to localhost; keep any address a launcher already set
    # (overwriting it would break multi-node runs).
    # NOTE(review): MASTER_PORT is not set here; presumably the launcher
    # exports it.
    os.environ.setdefault('MASTER_ADDR', 'localhost')

    if 'RANK' not in os.environ or 'WORLD_SIZE' not in os.environ:
        # Not running under a distributed launcher. Previously this fell
        # through to the final return and raised UnboundLocalError.
        return 0, 1, 0

    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    local_gpu = int(os.environ['LOCAL_RANK'])

    torch.cuda.set_device(local_gpu)

    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    dist.barrier()

    builtin_print = __builtin__.print

    def rank_aware_print(*args, **kwargs):
        # Pop 'force' so it is never forwarded to the real print().
        force = kwargs.pop('force', False)
        if rank == 0 or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = rank_aware_print
    return rank, world_size, local_gpu


def ddp_cleanup():
    """Destroy the default process group if one is currently initialised.

    Calling ``dist.destroy_process_group()`` with no default group raises.
    Guarding first makes this safe to call unconditionally, including after
    a non-distributed run where no group was ever created.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def is_dist_avail_and_initialized():
    """Report whether torch.distributed is compiled in and has a default group.

    Returns:
        bool: True only when ``dist.is_available()`` and
        ``dist.is_initialized()`` are both true.
    """
    # Short-circuit: is_initialized() is only consulted when the package is available.
    return dist.is_available() and dist.is_initialized()


def get_world_size():
    """Return the number of processes in the default group, or 1 without one."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1


def get_rank():
    """Return this process's rank in the default group, or 0 without one."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0


def is_main_process():
    """Return True when this process has rank 0 (always true when not distributed)."""
    current_rank = get_rank()
    return current_rank == 0


def save_on_master(*args, **kwargs):
    """Forward all arguments to ``torch.save``, but only on the rank-0 process.

    Every other rank returns immediately without writing anything.
    """
    if not is_main_process():
        return
    torch.save(*args, **kwargs)

