# Adapted from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/precise_bn.py  # noqa: E501
# Original licence: Copyright (c) 2019 Facebook, Inc under the Apache License 2.0  # noqa: E501

import logging
import time

import mmcv
import torch
from mmcv.parallel import MMDistributedDataParallel
from mmcv.runner import Hook
from mmcv.utils import print_log
from torch.nn import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.instancenorm import _InstanceNorm
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.utils.data import DataLoader


def is_parallel_module(module):
    """Tell whether ``module`` is wrapped by a parallel container.

    Instances of ``DataParallel``, ``DistributedDataParallel`` and the
    deprecated ``MMDistributedDataParallel`` (including any subclass of
    these three) count as parallel modules.

    Args:
        module (nn.Module): The module to inspect.
    Returns:
        bool: True when ``module`` is one of the parallel wrappers above.
    """
    wrapper_types = (
        DataParallel,
        DistributedDataParallel,
        MMDistributedDataParallel,
    )
    # isinstance already yields a bool, so no explicit conversion is needed.
    return isinstance(module, wrapper_types)


@torch.no_grad()
def update_bn_stats(model, data_loader, num_iters=200, logger=None):
    """Recompute and update the batch norm stats to make them more precise.

    During training both BN stats and the weight are changing after every
    iteration, so the running average can not precisely reflect the actual
    stats of the current model.
    In this function, the BN stats are recomputed with fixed weights, to make
    the running average more precise. Specifically, it computes the true
    average of per-batch mean/variance instead of the running average.

    Note:
        ``model.train()`` is called at the start and the model is left in
        training mode on return. The original BN momentum values are
        restored even if an error is raised part-way through.

    Args:
        model (nn.Module): The model whose bn stats will be recomputed. It
            may be wrapped in a parallel module (see ``is_parallel_module``);
            the forward pass then goes through the wrapper, while the BN
            layers are looked up on ``model.module``.
        data_loader (Iterable): A sized iterable (e.g. a ``DataLoader``)
            whose items are mappings passed as keyword arguments to the
            model. It must yield at least ``num_iters`` items.
        num_iters (int): number of iterations to compute the stats.
            Default: 200.
        logger (:obj:`logging.Logger` | None): Logger for logging.
            Default: None.
    """
    model.train()

    assert len(data_loader) >= num_iters, (
        f'length of dataloader {len(data_loader)} must be greater than '
        f'iteration number {num_iters}')

    # Forward passes use the (possibly) parallel wrapper; BN layers are
    # collected from the wrapped module.
    parallel_module = model
    if is_parallel_module(model):
        model = model.module

    # Finds all the bn layers with training=True. Layers created with
    # track_running_stats=False have no running_mean/running_var buffers
    # (they are None), so there is nothing to recompute for them.
    bn_layers = [
        m for m in model.modules()
        if m.training and isinstance(m, _BatchNorm) and m.track_running_stats
    ]

    if not bn_layers:
        print_log('No BN found in model', logger=logger, level=logging.WARNING)
        return
    print_log(f'{len(bn_layers)} BN found', logger=logger)

    # Other norm layers in training mode also run during the forward passes
    # below; warn once rather than once per layer.
    if any(m.training and isinstance(m, (_InstanceNorm, GroupNorm))
           for m in model.modules()):
        print_log(
            'IN/GN stats will be updated like training.',
            logger=logger,
            level=logging.WARNING)

    # BN updates its buffers as
    #   running_mean = (1 - momentum) * running_mean + momentum * batch_mean
    # With momentum = 1.0 the buffers hold exactly the current batch's stats
    # after each forward pass, which are then averaged manually below.
    momentum_actual = [bn.momentum for bn in bn_layers]  # pyre-ignore
    for bn in bn_layers:
        bn.momentum = 1.0

    # Note that running_var actually means "running average of variance"
    running_mean = [torch.zeros_like(bn.running_mean) for bn in bn_layers]
    running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers]

    try:
        finish_before_loader = False
        # Only num_iters batches are consumed, so size the bar accordingly.
        prog_bar = mmcv.ProgressBar(num_iters)
        for ind, data in enumerate(data_loader):
            parallel_module(**data, return_loss=False)
            prog_bar.update()
            for i, bn in enumerate(bn_layers):
                # Incremental mean: after ind + 1 batches, running_mean[i] and
                # running_var[i] are the plain averages of the per-batch
                # mean / variance seen so far.
                running_mean[i] += (bn.running_mean -
                                    running_mean[i]) / (ind + 1)
                running_var[i] += (bn.running_var -
                                   running_var[i]) / (ind + 1)

            if (ind + 1) >= num_iters:
                finish_before_loader = True
                break
        assert finish_before_loader, 'Dataloader stopped before ' \
                                     f'iteration {num_iters}'

        for i, bn in enumerate(bn_layers):
            # Sets the precise bn stats.
            bn.running_mean = running_mean[i]
            bn.running_var = running_var[i]
    finally:
        # Restore the caller's momentum even when a forward pass or the
        # assertion above fails, so the model is not left with momentum 1.0.
        for bn, momentum in zip(bn_layers, momentum_actual):
            bn.momentum = momentum


class PreciseBNHook(Hook):
    """Precise BN hook.

    Every ``interval`` training epochs, recompute the BN running statistics
    of ``runner.model`` over ``num_iters`` batches by calling
    ``update_bn_stats``.

    Args:
        dataloader (DataLoader): A PyTorch dataloader.
        num_iters (int): Number of iterations to update the bn stats.
            Default: 200.
        interval (int): Perform precise bn interval (by epochs). Default: 1.

    Raises:
        TypeError: If ``dataloader`` is not a ``torch.utils.data.DataLoader``.
    """

    def __init__(self, dataloader, num_iters=200, interval=1):
        if not isinstance(dataloader, DataLoader):
            raise TypeError('dataloader must be a pytorch DataLoader, but got'
                            f' {type(dataloader)}')
        self.num_iters = num_iters
        self.interval = interval
        self.dataloader = dataloader

    def after_train_epoch(self, runner):
        """Refresh BN stats when the current epoch hits the interval."""
        if not self.every_n_epochs(runner, self.interval):
            return

        log = runner.logger
        # The original author pauses before and after the update "to avoid
        # possible deadlock"; the exact cause is not visible here.
        time.sleep(2.)
        print_log(
            f'Running Precise BN for {self.num_iters} iterations',
            logger=log)
        update_bn_stats(
            runner.model, self.dataloader, self.num_iters, logger=log)
        print_log('BN stats updated', logger=log)
        time.sleep(2.)
