import os

import torch


def setup_ddp(rank, world_size, port=12357):
    """Initialise a single-host NCCL process group for this worker process.

    Args:
        rank: Rank of this process in the group. It is also used directly as
            the CUDA device index, which only makes sense because rendezvous
            is on ``localhost`` (all ranks share one machine).
        world_size: Total number of participating processes.
        port: TCP port on localhost used for the rendezvous.
    """
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(port)

    # Bind this process to its own GPU *before* creating the NCCL process
    # group, so the communicator and the barrier below run on this rank's
    # device instead of defaulting to cuda:0.
    torch.cuda.set_device(rank)

    # initialize the process group; with no init_method given, rendezvous
    # uses MASTER_ADDR / MASTER_PORT set above.
    torch.distributed.init_process_group(
        "nccl",
        rank=rank,
        world_size=world_size,
    )
    # Block until every rank has joined before returning to the caller.
    torch.distributed.barrier()


def cleanup_ddp():
    """Destroy the default torch.distributed process group for this process."""
    dist = torch.distributed
    dist.destroy_process_group()


def is_main_process():
    """Return True if this process is rank 0 of the default process group.

    If torch.distributed is unavailable in this build, or no process group has
    been initialised (a plain single-process run), the current process is
    treated as the main one and True is returned instead of raising.
    """
    # Short-circuit keeps is_initialized() from being touched on builds
    # without distributed support.
    if not torch.distributed.is_available() or not torch.distributed.is_initialized():
        return True
    return torch.distributed.get_rank() == 0


def distribute_loader(loader):
    """Rebuild ``loader`` so that each rank iterates over its own data shard.

    The returned DataLoader wraps the same dataset in a ``DistributedSampler``
    and divides ``loader.batch_size`` by the world size (integer division), so
    the combined batch across all ranks matches the original batch size
    whenever it is divisible by the world size.

    Ordering: a loader that iterated sequentially (``shuffle=False``, i.e. a
    ``SequentialSampler``) stays unshuffled; any other sampler is replaced by
    a shuffled distributed sampler. The original loader's ``collate_fn``,
    ``drop_last``, ``worker_init_fn``, ``timeout``, ``persistent_workers``,
    ``num_workers`` and ``pin_memory`` settings are carried over.

    Note: call ``returned_loader.sampler.set_epoch(epoch)`` at the start of
    every epoch, otherwise each epoch reuses the same shuffle order.

    Args:
        loader: A ``torch.utils.data.DataLoader`` with an integer
            ``batch_size``. A default process group must already be
            initialised.

    Returns:
        A new DataLoader over the same dataset, sharded for this rank.

    Raises:
        ValueError: If ``loader.batch_size`` is smaller than the world size,
            which would give every rank an empty batch.
    """
    world_size = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()

    per_rank_batch_size = loader.batch_size // world_size
    if per_rank_batch_size < 1:
        raise ValueError(
            f"batch_size={loader.batch_size} is smaller than the world size "
            f"({world_size}); each rank would receive an empty batch"
        )

    # DistributedSampler shuffles by default; keep a sequential (e.g. eval)
    # loader in order instead of silently shuffling it.
    shuffle = not isinstance(loader.sampler, torch.utils.data.SequentialSampler)

    sampler = torch.utils.data.distributed.DistributedSampler(
        loader.dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=shuffle,
    )
    return torch.utils.data.DataLoader(
        loader.dataset,
        batch_size=per_rank_batch_size,
        sampler=sampler,
        num_workers=loader.num_workers,
        pin_memory=loader.pin_memory,
        # Carry over settings that would otherwise silently revert to defaults.
        collate_fn=loader.collate_fn,
        drop_last=loader.drop_last,
        timeout=loader.timeout,
        worker_init_fn=loader.worker_init_fn,
        persistent_workers=loader.persistent_workers,
    )