import torch

# Resolve the reduction-op enum under whichever name this torch build exposes,
# trying the current spelling first and the older locations afterwards.
try:
    ReduceOp = torch.distributed.ReduceOp
except AttributeError:
    try:
        ReduceOp = torch.distributed.reduce_op
    except AttributeError:
        # Last resort: the location under ``torch.distributed.deprecated``.
        ReduceOp = torch.distributed.deprecated.reduce_op

from .distributed import DistributedDataParallel, Reducer
# SyncBatchNorm has to be exposed under the same name whether the cuda-enabled
# implementation or the python fallback ends up being used, and a failure to
# import the optimized path must not be silently discarded.
try:
    # Presumably the compiled extension the optimized implementation relies on
    # (TODO confirm); any ImportError here or on the next line selects the
    # fallback below.
    import syncbn
    from .optimized_sync_batchnorm import SyncBatchNorm
except ImportError as err:
    from .sync_batchnorm import SyncBatchNorm
    # Keep the original ImportError on the fallback class so the reason the
    # optimized version was unavailable can still be inspected later.
    SyncBatchNorm.syncbn_import_error = err

def convert_syncbn_model(module, process_group=None, channel_last=False):
    '''
    Recursively traverse module and its children to replace all instances of
    ``torch.nn.modules.batchnorm._BatchNorm`` with :class:`apex.parallel.SyncBatchNorm`.

    All ``torch.nn.BatchNorm*N*d`` wrap around
    ``torch.nn.modules.batchnorm._BatchNorm``, so this function lets you easily switch
    to use sync BN.

    Instance-norm layers (``torch.nn.modules.instancenorm._InstanceNorm``) are
    returned untouched. Modules that are not batch-norm layers are kept and
    modified in place: their children are re-registered with the converted
    versions.

    Each replacement layer shares the running statistics buffers of the layer
    it replaces, receives a copy of its affine parameters (including their
    ``requires_grad`` flags) and inherits its ``training`` flag.

    Args:
        module (torch.nn.Module): input module
        process_group: passed through to every :class:`apex.parallel.SyncBatchNorm`
            that is created. Default: ``None``
        channel_last (bool): passed through to every
            :class:`apex.parallel.SyncBatchNorm` that is created. Default: ``False``

    Returns:
        The converted module. This is a new :class:`apex.parallel.SyncBatchNorm`
        when ``module`` itself is a batch-norm layer, otherwise ``module``.

    Example::

        >>> # model is an instance of torch.nn.Module
        >>> import apex
        >>> sync_bn_model = apex.parallel.convert_syncbn_model(model)
    '''
    from apex import deprecated_warning
    # Warn once per conversion; the recursion runs in the private helper so the
    # warning is not re-issued for every submodule.
    deprecated_warning("apex.parallel.convert_syncbn_model is deprecated and will be removed by the end of February 2023. Use `torch.nn.SyncBatchNorm.convert_sync_batchnorm`.")
    return _convert_syncbn_module(module, process_group, channel_last)


def _convert_syncbn_module(module, process_group, channel_last):
    '''Recursive worker for :func:`convert_syncbn_model` (no deprecation warning).'''
    # Checked before the _BatchNorm test so instance-norm layers are never
    # converted (presumably because some torch versions derive _InstanceNorm
    # from _BatchNorm -- TODO confirm).
    if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm):
        return module

    mod = module
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        mod = SyncBatchNorm(module.num_features, module.eps, module.momentum, module.affine, module.track_running_stats, process_group, channel_last=channel_last)
        # Running statistics are shared with the original layer, not copied.
        mod.running_mean = module.running_mean
        mod.running_var = module.running_var
        mod.num_batches_tracked = module.num_batches_tracked
        if module.affine:
            mod.weight.data = module.weight.data.clone().detach()
            mod.bias.data = module.bias.data.clone().detach()
            # Keep frozen affine parameters frozen after conversion.
            mod.weight.requires_grad = module.weight.requires_grad
            mod.bias.requires_grad = module.bias.requires_grad
        # A freshly built layer starts in training mode; mirror the source
        # layer so converting a model in eval() mode does not flip it.
        mod.training = module.training

    for name, child in module.named_children():
        mod.add_module(name, _convert_syncbn_module(child, process_group, channel_last))
    return mod

def create_syncbn_process_group(group_size):
    '''
    Creates process groups to be used for syncbn of a given ``group_size`` and returns
    the process group that the current GPU participates in.

    ``group_size`` must divide the total number of GPUs (world_size).

    ``group_size`` of 0 would be considered as =world_size. In this case ``None`` will be returned.

    ``group_size`` of 1 would be equivalent to using non-sync bn, but will still carry the overhead.

    Every process has to call this function, because each one must take part
    in the creation of all subgroups, not only its own.

    Args:
        group_size (int): number of GPU's to collaborate for sync bn

    Returns:
        ``None`` when ``group_size`` is 0, otherwise the process group that
        contains the calling rank.

    Raises:
        AssertionError: if ``group_size`` is negative, larger than the world
            size, or does not divide the world size.

    Example::

        >>> # model is an instance of torch.nn.Module
        >>> import apex
        >>> group = apex.parallel.create_syncbn_process_group(group_size)
    '''

    if group_size == 0:
        return None

    world_size = torch.distributed.get_world_size()
    # Raised explicitly (instead of using ``assert``) so the checks still run
    # under ``python -O``; AssertionError is kept for backward compatibility.
    if not 0 < group_size <= world_size:
        raise AssertionError(
            "group_size ({}) must be between 1 and world_size ({})".format(group_size, world_size))
    if world_size % group_size != 0:
        raise AssertionError(
            "group_size ({}) must divide world_size ({})".format(group_size, world_size))

    # Index of the subgroup the calling rank belongs to (loop-invariant).
    own_group_num = torch.distributed.get_rank() // group_size

    group = None
    for group_num in range(world_size // group_size):
        # new_group documents ``ranks`` as a list of ints, so materialize the range.
        group_ids = list(range(group_num * group_size, (group_num + 1) * group_size))
        cur_group = torch.distributed.new_group(ranks=group_ids)
        if group_num == own_group_num:
            group = cur_group
            # can not drop out and return here, every process must go through creation of all subgroups

    if group is None:
        raise AssertionError("current rank is not a member of any syncbn process group")
    return group
