# Copyright (c) OpenMMLab. All rights reserved.
from collections import defaultdict
from typing import Callable, List, Optional

from torch import nn
from torch.nn import GroupNorm, LayerNorm
from mmengine.logging import MMLogger
from mmengine.optim import DefaultOptimWrapperConstructor
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm
from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS


@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class CustomOptimWrapperConstructor(DefaultOptimWrapperConstructor):
    """Optimizer wrapper constructor with layer-wise learning rate decay.

    By default, every parameter shares the same optimizer settings. The
    ``paramwise_cfg`` argument specifies parameter-wise settings. It is a dict
    and may contain the following fields:

    - ``layer_decay_rate`` (float): The learning rate of a parameter is
      multiplied by ``layer_decay_rate ** (max_id - layer_id - 1)``, where
      ``(layer_id, max_id)`` is returned by the model's ``get_layer_depth``
      for that parameter. Usually it is less than 1, so that earlier layers
      (smaller ``layer_id``) get a lower learning rate. Defaults to 1.
    - ``bias_decay_mult`` (float): It will be multiplied to the weight
      decay for all bias parameters (except for those in normalization layers).
    - ``norm_decay_mult`` (float): It will be multiplied to the weight
      decay for all weight and bias parameters of normalization layers.
    - ``flat_decay_mult`` (float): It will be multiplied to the weight
      decay for all one-dimensional parameters.
    - ``custom_keys`` (dict): Parameter-wise settings specified by keys. If
      one of the keys in ``custom_keys`` is a substring of the name of a
      parameter, the setting of that parameter is taken from
      ``custom_keys[key]`` and other settings like ``bias_decay_mult`` are
      ignored. Each value should be a dict and may contain ``decay_mult``.
      (``lr_mult`` has no effect in this constructor: the learning rate is
      always derived from the layer-wise decay.)
    - ``bypass_duplicate`` (bool): If true, duplicate parameters are not
      added to the optimizer. Defaults to False.
    - ``custom_layer_decay_keys`` (dict): Parameter-wise layer decay rates
      specified by keys. If one of the keys in ``custom_layer_decay_keys`` is
      a substring of the name of a parameter, the layer-wise decay rate of
      that parameter is ``custom_layer_decay_keys[key]`` instead of
      ``layer_decay_rate``.

    Example:

    In the config file, you can use this constructor as below:

    .. code:: python

        optim_wrapper = dict(
            optimizer=dict(
                type='AdamW',
                lr=4e-3,
                weight_decay=0.05,
                eps=1e-8,
                betas=(0.9, 0.999)),
            constructor='CustomOptimWrapperConstructor',
            paramwise_cfg=dict(
                layer_decay_rate=0.75,  # layer-wise lr decay factor
                norm_decay_mult=0.,
                flat_decay_mult=0.,
                custom_keys={
                    '.cls_token': dict(decay_mult=0.0),
                    '.pos_embed': dict(decay_mult=0.0)
                }))
    """

    @staticmethod
    def _resolve_get_layer_depth(module: nn.Module) -> Callable:
        """Return the layer-depth function of ``module`` or its backbone."""
        if hasattr(module, 'get_layer_depth'):
            return module.get_layer_depth

        backbone = getattr(module, 'backbone', None)
        if backbone is None or not hasattr(backbone, 'get_layer_depth'):
            raise NotImplementedError('The layer-wise learning rate decay needs'
                                      f' the model {type(module)} to have'
                                      ' either `get_layer_depth` method'
                                      ' or `backbone.get_layer_depth` method.')

        def _backbone_layer_depth(param_name):
            # Full parameter names start with 'backbone.'; pass that prefix so
            # the backbone can interpret names relative to itself.
            return backbone.get_layer_depth(param_name, 'backbone.')

        return _backbone_layer_depth

    @staticmethod
    def _log_layer_summary(params: List[dict], logger) -> None:
        """Debug-log each layer's lr, lr_scale and per-parameter weight decay."""
        groups_by_layer = defaultdict(list)
        for group in params:
            # Only groups built by ``add_params`` carry layer information.
            if 'layer_id' in group:
                groups_by_layer[group['layer_id']].append(group)

        for layer_id, groups in groups_by_layer.items():
            header_lr = groups[0]['lr']
            header_scale = groups[0]['lr_scale']
            msg = [
                f'layer {layer_id} params '
                f'(lr={header_lr:.3g}, lr_scale={header_scale:.3g}):'
            ]
            for group in groups:
                weight_decay = group.get('weight_decay')
                # 'weight_decay' is only set when the optimizer config defines
                # one; otherwise the optimizer's own default applies.
                wd_str = ('default' if weight_decay is None else
                          f'{weight_decay:.3g}')
                line = f'\t{group["param_name"]}: weight_decay={wd_str}'
                # custom_layer_decay_keys can give parameters of the same layer
                # a different lr than the one shown in the header.
                if group['lr'] != header_lr:
                    line += f', lr={group["lr"]:.3g}'
                msg.append(line)
            logger.debug('\n'.join(msg))

    def add_params(self,
                   params: List[dict],
                   module: nn.Module,
                   prefix: str = '',
                   get_layer_depth: Optional[Callable] = None,
                   **kwargs) -> None:
        """Add all parameters of module to the params list.

        The parameters of the given module will be added to the list of param
        groups, with specific rules defined by paramwise_cfg. Each parameter
        gets its own group. Child modules are processed recursively. When the
        top-level call (``prefix == ''``) finishes, a per-layer summary is
        written to the debug log.

        Args:
            params (List[dict]): A list of param groups, it will be modified
                in place.
            module (nn.Module): The module to be added.
            prefix (str): The prefix of the module. Defaults to ''.
            get_layer_depth (Callable, optional): Function mapping a full
                parameter name to ``(layer_id, max_id)``. If None, it is taken
                from ``module.get_layer_depth`` or, failing that,
                ``module.backbone.get_layer_depth``. Defaults to None.
            **kwargs: Unused; accepted for interface compatibility.

        Raises:
            NotImplementedError: If ``get_layer_depth`` is None and neither
                ``module`` nor ``module.backbone`` provides one.
        """
        # get param-wise options
        custom_keys = self.paramwise_cfg.get('custom_keys', {})
        # first sort with alphabet order and then sort with reversed len of str
        # so that the longest (most specific) matching key wins
        sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)
        logger = MMLogger.get_current_instance()

        # The model (or its backbone) must provide `get_layer_depth`.
        if get_layer_depth is None:
            get_layer_depth = self._resolve_get_layer_depth(module)

        bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', None)
        norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', None)
        flat_decay_mult = self.paramwise_cfg.get('flat_decay_mult', None)
        decay_rate = self.paramwise_cfg.get('layer_decay_rate', 1.0)
        bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False)

        custom_layer_decay_keys = self.paramwise_cfg.get(
            'custom_layer_decay_keys', {})
        sorted_layer_decay_keys = sorted(
            sorted(custom_layer_decay_keys.keys()), key=len, reverse=True)

        # special rule for normalization layers
        is_norm = isinstance(module,
                             (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))

        for name, param in module.named_parameters(recurse=False):
            param_group = {'params': [param]}
            param_name = prefix + name
            if bypass_duplicate and self._is_in(param_group, params):
                logger.warning(
                    f'{param_name} is duplicate. It is skipped since '
                    f'bypass_duplicate={bypass_duplicate}')
                continue
            if not param.requires_grad:
                continue

            if self.base_wd is not None:
                base_wd = self.base_wd
                custom_key = next(
                    (k for k in sorted_keys if k in param_name), None)
                # custom parameters decay
                if custom_key is not None:
                    custom_cfg = custom_keys[custom_key].copy()
                    decay_mult = custom_cfg.pop('decay_mult', 1.)

                    param_group['weight_decay'] = base_wd * decay_mult
                    # add custom settings to param_group
                    param_group.update(custom_cfg)
                # norm decay
                elif is_norm and norm_decay_mult is not None:
                    param_group['weight_decay'] = base_wd * norm_decay_mult
                # bias decay
                elif name == 'bias' and bias_decay_mult is not None:
                    param_group['weight_decay'] = base_wd * bias_decay_mult
                # flatten parameters decay
                elif param.ndim == 1 and flat_decay_mult is not None:
                    param_group['weight_decay'] = base_wd * flat_decay_mult
                else:
                    param_group['weight_decay'] = base_wd

            layer_decay_key = next(
                (k for k in sorted_layer_decay_keys if k in param_name), None)
            rate = (custom_layer_decay_keys[layer_decay_key]
                    if layer_decay_key is not None else decay_rate)

            layer_id, max_id = get_layer_depth(param_name)
            scale = rate**(max_id - layer_id - 1)
            # lr set here overrides any 'lr' that came from custom_keys.
            param_group['lr'] = self.base_lr * scale
            param_group['lr_scale'] = scale
            param_group['layer_id'] = layer_id
            param_group['param_name'] = param_name

            params.append(param_group)

        for child_name, child_mod in module.named_children():
            self.add_params(
                params,
                child_mod,
                prefix=f'{prefix}{child_name}.',
                get_layer_depth=get_layer_depth,
            )

        if prefix == '':
            self._log_layer_summary(params, logger)
