import cv2
import os
import warnings

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def setup_multi_processes(cfg):
    """Placeholder hook for multi-processing environment setup.

    The function is currently a no-op: ``cfg`` is accepted for interface
    compatibility but is never read, and no global state is modified.

    Args:
        cfg: Configuration object. Unused at the moment.

    Returns:
        None
    """
    # Intentionally empty. The steps that are disabled here were:
    #   * forcing the multiprocessing start method (``mp_start_method`` from
    #     the config, ``fork`` by default) on non-Windows platforms;
    #   * limiting OpenCV threads via ``cv2.setNumThreads`` using
    #     ``opencv_num_threads`` from the config (default 0);
    #   * defaulting ``OMP_NUM_THREADS`` and ``MKL_NUM_THREADS`` to 1 when
    #     they are unset and ``cfg.data.workers_per_gpu > 1``.
    return None
    
def get_dist_info():
    """Return the rank and world size of the current process.

    Returns:
        tuple[int, int]: ``(rank, world_size)`` as reported by
        ``torch.distributed`` when it is available and the default process
        group has been initialised; ``(0, 1)`` otherwise.
    """
    # Guard clause: outside of a distributed run, behave as a single process.
    if not (dist.is_available() and dist.is_initialized()):
        return 0, 1
    return dist.get_rank(), dist.get_world_size()

def auto_select_device():
    """Pick the device type to use.

    Returns:
        str: ``'cuda'`` when ``torch.cuda.is_available()`` is true,
        ``'cpu'`` otherwise.
    """
    return 'cuda' if torch.cuda.is_available() else 'cpu'

def sync_random_seed(seed=None, device=None):
    """Make sure different ranks share the same seed.

    All workers must call this function, otherwise it will deadlock.
    This method is generally used in `DistributedSampler`,
    because the seed should be identical across all processes
    in the distributed group.

    In distributed sampling, different ranks should sample non-overlapped
    data in the dataset. Therefore, this function is used to make sure that
    each rank shuffles the data indices in the same order based
    on the same seed. Then different ranks could use different indices
    to select non-overlapped data from the same data list.

    Args:
        seed (int, Optional): The seed. When None, a random seed in
            ``[0, 2**31)`` is drawn from NumPy's global random state.
            Default to None.
        device (str, Optional): The device where the seed tensor used for
            the broadcast will be put on. When None, the result of
            ``auto_select_device()`` is used. Default to None.

    Returns:
        int: Seed to be used. With a world size of 1 this is the local
        seed; otherwise it is the seed broadcast from rank 0.
    """
    if device is None:
        device = auto_select_device()
    if seed is None:
        # Imported here because the module does not import numpy at file
        # level; without it this branch raised NameError.
        import numpy as np

        # int() keeps the isinstance check below independent of the scalar
        # type NumPy returns.
        seed = int(np.random.randint(2**31))
    assert isinstance(seed, int)

    rank, world_size = get_dist_info()

    # Nothing to synchronise in a single-process run.
    if world_size == 1:
        return seed

    # Rank 0 provides the value; every other rank allocates a placeholder
    # that the broadcast overwrites.
    if rank == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()


def init_dist(backend: str, **kwargs) -> None:
    """Initialise the default process group and bind the process to a GPU.

    When at least one CUDA device is visible, the current CUDA device is set
    to ``rank % num_gpus`` before the process group is created. On machines
    without CUDA devices the device binding is skipped, so CPU-only backends
    can be initialised instead of failing with a division by zero.

    Args:
        backend (str): Backend name forwarded to
            ``dist.init_process_group``.
        **kwargs: Forwarded unchanged to ``dist.init_process_group``.
            ``rank`` must be present whenever a CUDA device is visible,
            because it is used to choose the device.

    Raises:
        KeyError: If a CUDA device is visible and ``rank`` is missing from
            ``kwargs``.
    """
    # TODO: use local_rank instead of rank % num_gpus
    num_gpus = torch.cuda.device_count()
    if num_gpus > 0:
        torch.cuda.set_device(kwargs['rank'] % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)

