# Copyright (c) Facebook, Inc. and its affiliates.
import logging
from datetime import timedelta

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from openood.utils import comm

# Public names of this module (what ``from <module> import *`` exports).
__all__ = ['DEFAULT_TIMEOUT', 'launch']

# Default value of the ``timeout`` parameter of ``launch`` and
# ``_distributed_worker``; it ends up as the ``timeout`` argument of
# ``dist.init_process_group``.
DEFAULT_TIMEOUT = timedelta(minutes=30)


def _find_free_port():
    import socket

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Binding to port 0 will cause the OS to find an available port for us
    sock.bind(('', 0))
    port = sock.getsockname()[1]
    sock.close()
    # NOTE: there is still a chance the port could be taken by other processes.
    return port


def launch(
        main_func,
        num_gpus_per_machine,
        num_machines=1,
        machine_rank=0,
        dist_url=None,
        args=(),
        timeout=DEFAULT_TIMEOUT,
):
    """Launch multi-gpu or distributed training.

    This function must be called on all machines involved in the training.
    It will spawn child processes (defined by ``num_gpus_per_machine``) on
    each machine. When ``num_machines * num_gpus_per_machine`` is not
    greater than 1, nothing is spawned and ``main_func(*args)`` is called
    directly in the current process.

    Args:
        main_func: a function that will be called by `main_func(*args)`
        num_gpus_per_machine (int): number of GPUs per machine
        num_machines (int): the total number of machines
        machine_rank (int): the rank of this machine
        dist_url (str): url to connect to for distributed jobs,
            including protocol e.g. "tcp://127.0.0.1:8686".
            Can be set to "auto" to automatically select a free port on
            localhost (single-machine jobs only). ``None`` is forwarded
            unchanged as the ``init_method`` of the process group.
        args (tuple): arguments passed to main_func
        timeout (timedelta): timeout of the distributed workers
    """
    world_size = num_machines * num_gpus_per_machine
    if world_size <= 1:
        # Single process: no process group is needed.
        main_func(*args)
        return

    # https://github.com/pytorch/pytorch/pull/14391
    # TODO prctl in spawned processes

    if dist_url == 'auto':
        assert num_machines == 1, \
            'dist_url=auto not supported in multi-machine jobs.'
        port = _find_free_port()
        dist_url = f'tcp://127.0.0.1:{port}'

    # ``dist_url`` defaults to None, so guard before calling a str method;
    # otherwise multi-machine jobs relying on the default would crash here.
    if (num_machines > 1 and dist_url is not None
            and dist_url.startswith('file://')):
        logger = logging.getLogger(__name__)
        logger.warning(
            'file:// is not a reliable init_method in multi-machine jobs. '
            'Prefer tcp://'
        )

    # mp.spawn passes the process index (the local rank) as the first
    # positional argument of ``_distributed_worker``.
    mp.spawn(
        _distributed_worker,
        nprocs=num_gpus_per_machine,
        args=(
            main_func,
            world_size,
            num_gpus_per_machine,
            machine_rank,
            dist_url,
            args,
            timeout,
        ),
        daemon=False,
    )


def _distributed_worker(
    local_rank,
    main_func,
    world_size,
    num_gpus_per_machine,
    machine_rank,
    dist_url,
    args,
    timeout=DEFAULT_TIMEOUT,
):
    """Per-process entry point started by ``launch`` through ``mp.spawn``.

    Joins the global NCCL process group, creates one process group per
    machine (remembering this machine's group in
    ``comm._LOCAL_PROCESS_GROUP``), selects the CUDA device for this
    process, synchronizes, and finally runs ``main_func(*args)``.

    Args:
        local_rank (int): index of this process on the current machine
        main_func: a function that will be called by `main_func(*args)`
        world_size (int): total number of processes over all machines
        num_gpus_per_machine (int): number of processes/GPUs per machine
        machine_rank (int): the rank of this machine
        dist_url (str): ``init_method`` used for the process group
        args (tuple): arguments passed to main_func
        timeout (timedelta): timeout passed to ``init_process_group``
    """
    assert torch.cuda.is_available(), (
        'cuda is not available. Please check your installation.')

    # Rank of this process among all processes of the job.
    global_rank = local_rank + machine_rank * num_gpus_per_machine

    init_kwargs = dict(
        backend='NCCL',
        init_method=dist_url,
        world_size=world_size,
        rank=global_rank,
        timeout=timeout,
    )
    try:
        dist.init_process_group(**init_kwargs)
    except Exception as err:
        # Log the URL to make connection problems easier to diagnose,
        # then propagate the original error.
        logging.getLogger(__name__).error(
            'Process group URL: {}'.format(dist_url))
        raise err

    # Build the local process groups (ranks living on the same machine).
    # Every process creates the group of every machine, in the same order,
    # and keeps only the one that belongs to its own machine.
    assert comm._LOCAL_PROCESS_GROUP is None
    machine_count = world_size // num_gpus_per_machine
    for machine_idx in range(machine_count):
        first_rank = machine_idx * num_gpus_per_machine
        member_ranks = list(
            range(first_rank, first_rank + num_gpus_per_machine))
        group = dist.new_group(member_ranks)
        if machine_idx == machine_rank:
            comm._LOCAL_PROCESS_GROUP = group

    assert num_gpus_per_machine <= torch.cuda.device_count()
    torch.cuda.set_device(local_rank)

    # synchronize is needed here to prevent a possible timeout
    # after calling init_process_group
    # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
    comm.synchronize()

    main_func(*args)
