import os
import torch
import torch.distributed as dist
import datetime


def setup_multi_gpu(opt):
    """Prepare the current process for distributed multi-GPU training.

    The function is a no-op unless ``opt.world_size`` is not ``None`` and
    ``opt.local_rank`` is greater than -1. When both hold it:

    * sets the ``TOKENIZERS_PARALLELISM`` and ``NCCL_BLOCKING_WAIT``
      environment variables,
    * overwrites ``opt.world_size`` and ``opt.local_rank`` with the integer
      values of the ``WORLD_SIZE`` and ``LOCAL_RANK`` environment variables,
    * calls ``set_up_distributed_training_multi_gpu(opt)``.

    Args:
        opt: Options object exposing ``world_size`` and ``local_rank``
            attributes; it is mutated in place.

    Raises:
        KeyError: If ``WORLD_SIZE`` or ``LOCAL_RANK`` is missing from the
            environment when distributed mode is requested.
    """
    distributed_requested = opt.world_size is not None and opt.local_rank > -1
    if not distributed_requested:
        return

    env = os.environ
    env.update({
        "TOKENIZERS_PARALLELISM": "true",
        "NCCL_BLOCKING_WAIT": '1',
    })

    # The launcher-provided environment takes precedence over whatever values
    # were already stored on ``opt``.
    opt.world_size = int(env['WORLD_SIZE'])
    opt.local_rank = int(env['LOCAL_RANK'])
    set_up_distributed_training_multi_gpu(opt)


def set_up_distributed_training_multi_gpu(opt, backend='nccl'):
    """Select this process's CUDA device and initialise the process group.

    ``opt.local_rank`` is used as the CUDA device index; the same value is
    stored on ``opt`` as ``device_id`` and ``distributed_rank``. The default
    process group is then created with the ``env://`` init method and a
    7200-second timeout.

    Args:
        opt: Options object exposing ``local_rank`` and ``world_size``;
            ``device_id`` and ``distributed_rank`` are written to it.
        backend: Name of the ``torch.distributed`` backend to use.
    """
    opt.device_id = opt.local_rank
    device_index = opt.device_id
    torch.cuda.set_device(device_index)
    opt.distributed_rank = device_index

    group_timeout = datetime.timedelta(seconds=7200)
    dist.init_process_group(
        backend=backend,
        init_method='env://',
        world_size=opt.world_size,
        timeout=group_timeout,
    )


def gather(data, pad_value=-1):
    """All-gather a 1-D or 2-D tensor whose first dimension varies per rank.

    Every rank first exchanges the length of its first dimension, then pads
    its tensor along that dimension up to the largest length so that all
    ranks contribute equally shaped tensors to ``dist.all_gather``.

    Args:
        data: Tensor with one or two dimensions. Its first dimension may
            differ between ranks.
        pad_value: Fill value for the rows/elements added by padding.
            Defaults to -1, the value that was previously hard-coded.

    Returns:
        A list with one tensor per rank (in rank order), each of shape
        ``(max_len,) + data.shape[1:]``. Entries beyond a rank's original
        length hold ``pad_value``; the per-rank lengths are not returned, so
        callers have to strip the padding themselves.

    Raises:
        AssertionError: If ``data`` is not 1-D or 2-D.
    """
    # Validate before issuing any collective so a bad input fails with a
    # clear message on every rank. AssertionError is kept so callers see the
    # same exception type as before, but it is raised explicitly so the check
    # is not stripped under ``python -O``. 0-D tensors are rejected as well:
    # they have no first dimension to gather along.
    if data.dim() not in (1, 2):
        raise AssertionError(
            f"gather expects a 1-D or 2-D tensor, got {data.dim()}-D")

    local_len = data.shape[0]
    lengths = [None] * dist.get_world_size()
    dist.all_gather_object(lengths, local_len)
    max_len = max(lengths)

    # F.pad takes (left, right) pairs starting from the LAST dimension, so any
    # trailing dimension gets (0, 0) and the first dimension is padded at its
    # end only.
    pad = (0, 0) * (data.dim() - 1) + (0, max_len - local_len)
    padded = torch.nn.functional.pad(data, pad, value=pad_value)
    # With zero padding F.pad can hand back a tensor that keeps the input's
    # strides; collectives expect contiguous buffers.
    padded = padded.contiguous()

    gathered = [torch.zeros_like(padded) for _ in lengths]
    dist.all_gather(gathered, padded)
    return gathered