import torch
from mmcv.utils import Registry, build_from_cfg

# Registry of hook-provider classes. Classes are added to it with the
# ``@MODULE_HOOKS.register_module()`` decorator, and ``register_module_hooks``
# instantiates them from hook configs via ``build_from_cfg``.
MODULE_HOOKS = Registry('module_hooks')


def register_module_hooks(Module, module_hooks_list):
    """Register hooks described by configs on sub-modules of ``Module``.

    Each config may contain two control keys:

    * ``hooked_module``: attribute name of the sub-module to hook. Defaults
      to ``'backbone'``.
    * ``hook_pos``: one of ``'forward_pre'``, ``'forward'`` or
      ``'backward'``. Defaults to ``'forward_pre'``.

    All remaining keys are passed to ``build_from_cfg`` with the
    ``MODULE_HOOKS`` registry. The value returned by the built object's
    ``hook_func()`` is registered as the hook.

    The caller's configs are not modified, so the same list can be used to
    register hooks again (for example on a rebuilt model).

    Args:
        Module: Model whose sub-modules are hooked.
        module_hooks_list (list[dict]): Hook configs.

    Returns:
        list: Handles of the registered hooks, in config order.

    Raises:
        ValueError: If ``Module`` has no attribute named by
            ``hooked_module``, or if ``hook_pos`` is not recognised.
            Whenever registration fails, hooks already registered by this
            call are removed before the error propagates.
    """
    handles = []
    try:
        for module_hook_cfg in module_hooks_list:
            # Work on a shallow copy. Popping from the caller's dict would
            # strip ``hooked_module``/``hook_pos`` and break re-registration.
            hook_cfg = dict(module_hook_cfg)
            hooked_module_name = hook_cfg.pop('hooked_module', 'backbone')
            if not hasattr(Module, hooked_module_name):
                raise ValueError(
                    f'{Module.__class__} has no {hooked_module_name}!')
            hooked_module = getattr(Module, hooked_module_name)
            hook_pos = hook_cfg.pop('hook_pos', 'forward_pre')

            # Resolve the registration method before building the hook
            # object, so an invalid ``hook_pos`` never triggers a build.
            if hook_pos == 'forward_pre':
                register = hooked_module.register_forward_pre_hook
            elif hook_pos == 'forward':
                register = hooked_module.register_forward_hook
            elif hook_pos == 'backward':
                # NOTE(review): newer PyTorch deprecates
                # ``register_backward_hook`` in favour of
                # ``register_full_backward_hook``. It is kept here for
                # behavioural compatibility.
                register = hooked_module.register_backward_hook
            else:
                raise ValueError(
                    f'hook_pos must be `forward_pre`, `forward` or `backward`, '
                    f'but get {hook_pos}')

            hook = build_from_cfg(hook_cfg, MODULE_HOOKS).hook_func()
            handles.append(register(hook))
    except Exception:
        # Do not leave half of the requested hooks attached with no handle
        # returned to the caller.
        for handle in handles:
            handle.remove()
        raise
    return handles


@MODULE_HOOKS.register_module()
class GPUNormalize:
    """Normalize images with the given mean and std value on GPUs.

    Calling the member function ``hook_func`` returns a forward pre-hook
    function for module registration. The hook converts the first
    positional input of the hooked module from ``uint8`` to ``float`` and
    computes ``(x - mean) / std`` on the input's device.

    GPU normalization, rather than CPU normalization, is more recommended in
    the case of a model running on GPUs with strong compute capacity such as
    Tesla V100.

    Args:
        input_format (str): Layout of the input tensor. Must be one of
            ``'NCTHW'``, ``'NCHW'``, ``'NCHW_Flow'`` or ``'NPTCHW'``. It
            decides which axis ``mean`` and ``std`` are aligned with (the
            ``C`` axis).
        mean (Sequence[float]): Mean values of different channels.
        std (Sequence[float]): Std values of different channels.

    Raises:
        ValueError: If ``input_format`` is not a supported layout.
    """

    # Index tuples that turn a 1-D per-channel tensor into one that
    # broadcasts against the given layout. ``slice(None)`` (i.e. ``:``) keeps
    # the channel axis, and each ``None`` inserts a size-1 axis.
    _BROADCAST_INDEX = {
        'NCTHW': (None, slice(None), None, None, None),
        'NCHW': (None, slice(None), None, None),
        'NCHW_Flow': (None, slice(None), None, None),
        'NPTCHW': (None, None, None, slice(None), None, None),
    }

    def __init__(self, input_format, mean, std):
        # The isinstance guard keeps unhashable inputs on the ValueError
        # path instead of a TypeError from the dict lookup.
        if (not isinstance(input_format, str)
                or input_format not in self._BROADCAST_INDEX):
            raise ValueError(f'The input format {input_format} is invalid.')
        self.input_format = input_format
        index = self._BROADCAST_INDEX[input_format]
        self._mean = torch.tensor(mean)[index]
        self._std = torch.tensor(std)[index]

    def hook_func(self):
        """Return a forward pre-hook that normalizes the module's first input.

        Returns:
            callable: ``hook(module, inputs)``. It returns ``inputs`` with
            the first element replaced by its normalized ``float`` tensor.
        """

        def normalize_hook(module, inputs):
            x = inputs[0]
            # Upstream augmentation is expected to keep uint8 data, so the
            # float conversion happens here on the input's device.
            assert x.dtype == torch.uint8, (
                f'The previous augmentation should use uint8 data type to '
                f'speed up computation, but get {x.dtype}')

            # The statistics were created with ``torch.tensor`` and may live
            # on a different device than ``x``, so move them per call.
            mean = self._mean.to(x.device)
            std = self._std.to(x.device)

            with torch.no_grad():
                # ``x`` is uint8 here, so ``x.float()`` is a fresh tensor.
                # The in-place ops therefore never touch the caller's input.
                x = x.float().sub_(mean).div_(std)

            return (x, *inputs[1:])

        return normalize_hook
