# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import functools
import tempfile

import torch


def spawn_and_init(fn, world_size, args=None):
    """Spawn ``world_size`` processes that each set up torch.distributed and run ``fn``.

    Every child runs ``init_and_run(fn, args, rank, world_size, path)``. All
    children share one temporary rendezvous file.

    Args:
        fn: Callable run in each child as ``fn(rank, group, *args)``.
        world_size: Number of processes to spawn, one per rank.
        args: Optional tuple of extra positional arguments forwarded to ``fn``.
            Defaults to no extra arguments.
    """
    import contextlib
    import os

    if args is None:
        args = ()
    # The file-based rendezvous needs a fresh, empty file that outlives this
    # handle (hence delete=False). Close our handle before the children use it.
    with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
        tmp_path = tmp_file.name
    try:
        torch.multiprocessing.spawn(
            fn=functools.partial(init_and_run, fn, args),
            args=(
                world_size,
                tmp_path,
            ),
            nprocs=world_size,
            join=True,
        )
    finally:
        # torch may already have removed the file on a clean shutdown.
        # Otherwise (e.g. a child failed) remove it so temp files do not pile up.
        with contextlib.suppress(FileNotFoundError):
            os.remove(tmp_path)


def distributed_init(rank, world_size, tmp_file, backend="nccl"):
    """Join the default torch.distributed process group via a file rendezvous.

    Args:
        rank: Rank of the calling process.
        world_size: Total number of participating processes.
        tmp_file: Filesystem path of the rendezvous file shared by all ranks.
            It is passed to torch as ``file://<tmp_file>``.
        backend: ``torch.distributed`` backend name. Defaults to ``"nccl"``.
            A CPU backend such as ``"gloo"`` can be used where CUDA is unavailable.
    """
    # Bind this process to its own GPU *before* creating the process group, so
    # that collectives do not run on the default device 0 for every rank.
    # Skip this when CUDA is unavailable so CPU-only backends still work.
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)
    torch.distributed.init_process_group(
        backend=backend,
        init_method="file://{}".format(tmp_file),
        world_size=world_size,
        rank=rank,
    )


def init_and_run(fn, args, rank, world_size, tmp_file):
    """Per-process entry point: initialize distributed state, then run ``fn``.

    The leading ``fn`` and ``args`` are bound with ``functools.partial``.
    ``rank``, ``world_size`` and ``tmp_file`` are supplied per child process.

    Args:
        fn: Callable invoked as ``fn(rank, group, *args)``.
        args: Tuple of extra positional arguments for ``fn``.
        rank: Rank of this process.
        world_size: Total number of processes.
        tmp_file: Path of the shared rendezvous file.
    """
    distributed_init(rank, world_size, tmp_file)
    # ``new_group()`` without a ``ranks`` argument builds a group over all ranks.
    fn(rank, torch.distributed.new_group(), *args)


def objects_are_equal(a, b) -> bool:
    """Recursively compare two (possibly nested) objects for equality.

    Both objects must have exactly the same type at every nesting level.

    - Dicts are compared key by key.
    - Lists and tuples are compared element by element, in order.
    - Sets and frozensets are compared with set equality.
    - Tensors must match in size, dtype, device and element values.
    - Anything else falls back to ``a == b``.

    Returns:
        ``True`` if the objects are equal, ``False`` otherwise. The final
        fallback returns whatever ``a == b`` produces for that type.
    """
    if type(a) is not type(b):
        return False
    if isinstance(a, dict):
        if a.keys() != b.keys():
            return False
        return all(objects_are_equal(a[k], b[k]) for k in a)
    if isinstance(a, (set, frozenset)):
        # Sets are unordered: zipping two equal sets can pair up different
        # elements when their iteration orders differ, so use set equality.
        return a == b
    if isinstance(a, (list, tuple)):
        if len(a) != len(b):
            return False
        return all(objects_are_equal(x, y) for x, y in zip(a, b))
    if torch.is_tensor(a):
        # torch.equal yields a Python bool; torch.all(a == b) would return a
        # tensor, contradicting the ``-> bool`` annotation.
        return (
            a.size() == b.size()
            and a.dtype == b.dtype
            and a.device == b.device
            and torch.equal(a, b)
        )
    return a == b
