import os, torch, random
import torch.distributed as dist
import torch.nn as nn
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

# Module-level DDP state shared by the helpers below.
# Rank of the current process; overwritten by setup().
ddp_rank = 0
# Number of participating processes; overwritten by setup().
ddp_world_size = 1
# Set to True by spawn() when it runs the entry function in-process
# instead of launching worker processes.
ddp_disabled = False


def setup(rank, world_size, port):
    """Join the NCCL process group and record this process's rank.

    Args:
        rank: Index of the calling process within the group.
        world_size: Total number of processes in the group.
        port: TCP port used for the rendezvous on ``localhost``.

    Side effects:
        Sets ``MASTER_ADDR`` / ``MASTER_PORT`` in ``os.environ`` and, once the
        group is up, updates the module globals ``ddp_rank`` and
        ``ddp_world_size``.
    """
    global ddp_rank, ddp_world_size

    # Rendezvous endpoint read by the process-group initialisation.
    os.environ.update(MASTER_ADDR='localhost', MASTER_PORT=str(port))

    # NOTE(review): the current CUDA device is not selected here; presumably
    # the entry function calls torch.cuda.set_device(rank) itself -- confirm.
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

    # Only publish the values after initialisation succeeded.
    ddp_rank, ddp_world_size = rank, world_size


def cleanup():
    """Destroy the default process group, if one exists.

    This is a no-op when ``torch.distributed`` is unavailable or when no
    process group has been initialised (for example when the entry function
    ran without calling ``setup``), so it can be called unconditionally and
    more than once.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def spawn(entry, args=(), n_gpus=None, join=True):
    """
    Run `entry` on up to `n_gpus` CUDA devices, one process per device.

    Entry function should be `entry(rank, world_size, ddp_port, *args)`

    Args:
        entry: Callable invoked as described above.
        args: Extra positional arguments appended after the port.
        n_gpus: Upper bound on the number of devices to use. ``None`` means
            "all visible devices"; the value is always capped at
            ``torch.cuda.device_count()``.
        join: Forwarded to ``torch.multiprocessing.spawn``. Not used on the
            single-device path, where `entry` runs in the current process.

    Returns:
        The value returned by ``torch.multiprocessing.spawn`` on the
        multi-device path (so a caller passing ``join=False`` can keep the
        handle and join later); ``None`` on the single-device path.

    Raises:
        RuntimeError: If the resulting device count is below one.
    """
    global ddp_disabled

    available = torch.cuda.device_count()
    n_gpus = available if n_gpus is None else min(n_gpus, available)
    if n_gpus < 1:
        raise RuntimeError('no gpu')

    # NOTE(review): the port is drawn from the global `random` state and is
    # not checked for availability -- a seeded RNG or a busy port can collide.
    port = random.randint(32000, 37000)

    if n_gpus == 1:
        print('DDP: No need to DDP, using single process')
        ddp_disabled = True
        entry(0, 1, port, *args)
        return None

    print(f'DDP: Setup DDP with {n_gpus} devices')
    # mp.spawn passes the process index as the first argument to `entry`.
    return mp.spawn(
        entry,
        args=(n_gpus, port) + tuple(args),
        nprocs=n_gpus,
        daemon=False,
        join=join,
    )


def barrier():
    """Synchronise all processes in the default process group.

    When ``torch.distributed`` is unavailable or no process group has been
    initialised there is nothing to wait for, so the call returns ``None``
    immediately instead of raising.

    Returns:
        Whatever ``dist.barrier()`` returns when a group exists, else ``None``.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return None
    return dist.barrier()


def printable():
    """Tell whether the current process should emit console output.

    Returns a truthy value when there is only one process, when this process
    is rank 0, or when DDP has been disabled by ``spawn``.
    """
    # Reads of module globals need no `global` declaration.
    if ddp_world_size == 1:
        return True
    if ddp_rank == 0:
        return True
    return ddp_disabled


def wrap_model(model, find_unused_paramters=False):
    """Wrap `model` for the current run mode.

    Args:
        model: The module to wrap.
        find_unused_paramters: Forwarded to DDP as ``find_unused_parameters``
            (the spelling of this argument is part of the public signature).

    Returns:
        A ``MimicDDP`` around `model` when the recorded world size is 1,
        otherwise a ``DistributedDataParallel`` bound to device index
        ``ddp_rank``.
    """
    if ddp_world_size != 1:
        print('DDP: Model wrapped', ddp_rank, find_unused_paramters)
        return DDP(
            model,
            device_ids=[ddp_rank],
            find_unused_parameters=find_unused_paramters,
        )
    # Single process: no gradient synchronisation needed.
    return MimicDDP(model)


class MimicDDP(nn.Module):
    """Pass-through wrapper that stores the wrapped model as ``.module``.

    Calling the wrapper simply calls the wrapped module with the same
    arguments; no distributed behaviour is added.
    """

    def __init__(self, module) -> None:
        """Register `module` as the ``module`` submodule."""
        super().__init__()
        self.module = module

    def forward(self, *inputs, **options):
        """Delegate the call to the wrapped module and return its result."""
        inner = self.module
        return inner(*inputs, **options)
