import torch
import torch.distributed as dist


def setup_dist():
    """
    Setup a distributed process group and pick this process's device.

    The backend is "nccl" when CUDA is available and "gloo" otherwise. The
    process group is created with the default rendezvous settings of
    ``dist.init_process_group`` and is only created once; repeated calls
    leave the existing group (and the currently selected CUDA device) alone.

    On CUDA machines the GPU index is taken from the ``LOCAL_RANK``
    environment variable when it is set (as exported by launchers such as
    torchrun); otherwise it falls back to the global rank modulo the number
    of visible GPUs. Using the global rank directly would select a
    non-existent device on any node other than the first.

    :return: the ``torch.device`` this process should use.
    :raises ValueError: if ``LOCAL_RANK`` is set but is not an integer.
    """
    import os  # local import: keeps this block independent of the file header

    use_gpu = torch.cuda.is_available()

    if not dist.is_available():
        # This torch build has no distributed support: single-process mode.
        return torch.device("cuda" if use_gpu else "cpu")

    if dist.is_initialized():
        # Already set up earlier; do not touch device selection, just report
        # the device that is currently active.
        if use_gpu:
            return torch.device("cuda", torch.cuda.current_device())
        return torch.device("cpu")

    backend = "nccl" if use_gpu else "gloo"
    dist.init_process_group(backend=backend)

    if not use_gpu:
        return torch.device("cpu")

    local_rank = os.environ.get("LOCAL_RANK")
    if local_rank is not None:
        gpu_index = int(local_rank)
    else:
        # No launcher-provided local rank: wrap the global rank onto the
        # GPUs visible to this process.
        gpu_index = dist.get_rank() % torch.cuda.device_count()

    torch.cuda.set_device(gpu_index)
    return torch.device(f"cuda:{gpu_index}")


def dev():
    """
    Pick the torch device for the current process.

    :return: the generic "cuda" device (no explicit index) when CUDA is
             available, otherwise the "cpu" device.
    """
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(device_type)


def sync_params(params):
    """
    Synchronize a sequence of Tensors across ranks from rank 0.

    Each tensor is overwritten in place with rank 0's values via
    ``dist.broadcast``, with gradient tracking disabled. When
    torch.distributed is unavailable or no process group has been
    initialized (e.g. a plain single-process run), there are no peers to
    synchronize with, so the tensors are left untouched.

    :param params: an iterable of Tensors to broadcast from rank 0.
    """
    # Broadcasting without a process group raises; treat that case as a
    # no-op so single-process runs work.
    if not (dist.is_available() and dist.is_initialized()):
        return
    with torch.no_grad():
        for p in params:
            dist.broadcast(p, 0)
