# Copyright (c) OpenMMLab. All rights reserved.
import re
from typing import Dict, Optional

import torch.distributed as dist

from mmcv.runner.optimizer.builder import OPTIMIZER_BUILDERS, OPTIMIZERS
from mmcv.utils import build_from_cfg, print_log


@OPTIMIZER_BUILDERS.register_module(force=True)
class DefaultOptimizerConstructor:
    """Rewrote default constructor for optimizers. By default each parameter
    share the same optimizer settings, and we provide an argument
    ``paramwise_cfg`` to specify parameter-wise settings. It is a dict and may
    contain the following fields:

    Args:
        model (:obj:`nn.Module`): The model with parameters to be optimized.
        optimizer_cfg (dict): The config dict of the optimizer.
            Positional fields are
                - `type`: class name of the optimizer.
            Optional fields are
                - any arguments of the corresponding optimizer type, e.g.,
                  lr, weight_decay, momentum, etc.
        paramwise_cfg (dict, optional): Parameter-wise options.
            Defaults to None.

    Example 1:
        >>> model = torch.nn.modules.Conv1d(1, 1, 1)
        >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
        >>>                      weight_decay=0.0001)
        >>> paramwise_cfg = dict('bias': dict(weight_decay=0., \
                                 lars_exclude=True))
        >>> optim_builder = DefaultOptimizerConstructor(
        >>>     optimizer_cfg, paramwise_cfg)
        >>> optimizer = optim_builder(model)

    Example 2:
        >>> # assume model have attribute model.backbone and model.cls_head
        >>> optimizer_cfg = dict(type='SGD', lr=0.01, weight_decay=0.95)
        >>> paramwise_cfg = dict(custom_keys={
                'backbone': dict(lr_mult=0.1, decay_mult=0.9)})
        >>> optim_builder = DefaultOptimizerConstructor(
        >>>     optimizer_cfg, paramwise_cfg)
        >>> optimizer = optim_builder(model)
        >>> # Then the `lr` and `weight_decay` for model.backbone is
        >>> # (0.01 * 0.1, 0.95 * 0.9). `lr` and `weight_decay` for
        >>> # model.cls_head is (0.01, 0.95).
    """

    def __init__(self,
                 optimizer_cfg: Dict,
                 paramwise_cfg: Optional[Dict] = None):
        if not isinstance(optimizer_cfg, dict):
            raise TypeError('optimizer_cfg should be a dict',
                            f'but got {type(optimizer_cfg)}')
        self.optimizer_cfg = optimizer_cfg
        self.paramwise_cfg = {} if paramwise_cfg is None else paramwise_cfg
        self.base_lr = optimizer_cfg.get('lr', None)
        self.base_wd = optimizer_cfg.get('weight_decay', None)
        self._validate_cfg()

    def _validate_cfg(self):
        if not isinstance(self.paramwise_cfg, dict):
            raise TypeError('paramwise_cfg should be None or a dict, '
                            f'but got {type(self.paramwise_cfg)}')

    def __call__(self, model):
        """ add `paramwise_cfg` option for `DefaultOptimizerConstructor` """
        if hasattr(model, 'module'):
            model = model.module
        optimizer_cfg = self.optimizer_cfg.copy()
        paramwise_options = self.paramwise_cfg

        # if no paramwise option is specified, just use the global setting
        if paramwise_options is None:
            optimizer_cfg['params'] = model.parameters()
            return build_from_cfg(optimizer_cfg, OPTIMIZERS)
        else:
            assert isinstance(paramwise_options, dict)
            params = []
            for name, param in model.named_parameters():
                param_group = {'params': [param]}
                if not param.requires_grad:
                    params.append(param_group)
                    continue

                for regexp, options in paramwise_options.items():
                    if re.search(regexp, name):
                        for key, value in options.items():
                            if key.endswith('_mult'):  # is a multiplier
                                key = key[:-5]
                                assert key in optimizer_cfg, \
                                    f'{key} not in optimizer_cfg'
                                value = optimizer_cfg[key] * value
                            param_group[key] = value
                            if not dist.is_initialized() or \
                                    dist.get_rank() == 0:
                                print_log(f'paramwise_options -- {name}: {key}={value}')

                # otherwise use the global settings
                params.append(param_group)

            optimizer_cfg['params'] = params
            return build_from_cfg(optimizer_cfg, OPTIMIZERS)


@OPTIMIZER_BUILDERS.register_module()
class TransformerFinetuneConstructor:
    """Rewrote default constructor for optimizers.

    By default each parameter share the same optimizer settings, and we
    provide an argument ``paramwise_cfg`` to specify parameter-wise settings.
    In addition, we provide two optional parameters, ``model_type`` and
    ``layer_decay`` to set the commonly used layer-wise learning rate decay
    schedule. Currently, we only support layer-wise learning rate schedule
    for swin and vit.

    Args:
        optimizer_cfg (dict): The config dict of the optimizer.
            Positional fields are
                - `type`: class name of the optimizer.
            Optional fields are
                - any arguments of the corresponding optimizer type, e.g.,
                  lr, weight_decay, momentum, model_type, layer_decay, etc.
        paramwise_cfg (dict, optional): Parameter-wise options.
            Defaults to None.


    Example 1:
        >>> model = torch.nn.modules.Conv1d(1, 1, 1)
        >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
        >>>                      weight_decay=0.0001, model_type='vit')
        >>> paramwise_cfg = dict('bias': dict(weight_decay=0., \
                                 lars_exclude=True))
        >>> optim_builder = TransformerFinetuneConstructor(
        >>>     optimizer_cfg, paramwise_cfg)
        >>> optimizer = optim_builder(model)
    """

    def __init__(self, optimizer_cfg, paramwise_cfg=None):
        if not isinstance(optimizer_cfg, dict):
            raise TypeError('optimizer_cfg should be a dict',
                            f'but got {type(optimizer_cfg)}')
        self.optimizer_cfg = optimizer_cfg
        self.paramwise_cfg = {} if paramwise_cfg is None else paramwise_cfg
        self.layer_decay = self.optimizer_cfg.pop('layer_decay', 0.0)
        # Choose which type of layer-wise lr decay to use. Currently, we only
        # support ViT and Swin.
        self.model_type = self.optimizer_cfg.pop('model_type', None)

    def __call__(self, model):
        if hasattr(model, 'module'):
            model = model.module
        optimizer_cfg = self.optimizer_cfg.copy()
        paramwise_options = self.paramwise_cfg

        # generate layer-wise lr decay
        if self.layer_decay > 0:
            if self.model_type == 'convnext':
                self._generate_convnext_layer_wise_lr_decay(model, paramwise_options)
            elif self.model_type == 'swin':
                self._generate_swin_layer_wise_lr_decay(model, paramwise_options)
            elif self.model_type == 'vit':
                self._generate_vit_layer_wise_lr_decay(model, paramwise_options)
            else:
                raise NotImplementedError(
                    f'Currently, we do not support layer-wise lr decay for {self.model_type}')

        # if no paramwise option is specified, just use the global setting
        if paramwise_options is None:
            optimizer_cfg['params'] = model.parameters()
            return build_from_cfg(optimizer_cfg, OPTIMIZERS)
        else:
            assert isinstance(paramwise_options, dict)
            params = []
            for name, param in model.named_parameters():
                param_group = {'params': [param]}
                if not param.requires_grad:
                    params.append(param_group)
                    continue

                for regexp, options in paramwise_options.items():
                    if re.search(regexp, name):
                        for key, value in options.items():
                            if key.endswith('_mult'):  # is a multiplier
                                key = key[:-5]
                                assert key in optimizer_cfg, \
                                    f'{key} not in optimizer_cfg'
                                value = optimizer_cfg[key] * value
                            param_group[key] = value
                            if not dist.is_initialized() or \
                                    dist.get_rank() == 0:
                                print_log(f'paramwise_options -- {name}: {key}={value}')

                # otherwise use the global settings
                params.append(param_group)

            optimizer_cfg['params'] = params
            return build_from_cfg(optimizer_cfg, OPTIMIZERS)

    def _generate_convnext_layer_wise_lr_decay(self, model, paramwise_options):
        """Generate layer-wise learning rate decay for ConvNeXt."""
        num_layers = 12 if len(model.backbone.stages[-2]) > 9 else 6
        layer_scales = list(self.layer_decay ** (num_layers + 1 - i)
                            for i in range(num_layers + 2))

        for name, _ in model.named_parameters():
            layer_id = self._get_convnext_layer(name, num_layers=num_layers)
            paramwise_options[name] = dict(lr_mult=layer_scales[layer_id])

    def _get_convnext_layer(self, name, num_layers=12):
        """
        Divide [3, 3, 27, 3] layers into 12 groups; each group is three 
        consecutive blocks, including possible neighboring downsample layers;
        adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py
        """
        if name.startswith("backbone.downsample_layers"):
            stage_id = int(name.split('.')[2])
            if stage_id == 0:
                layer_id = 0
            elif stage_id == 1 or stage_id == 2:
                layer_id = stage_id + 1
            elif stage_id == 3:
                layer_id = num_layers
            return layer_id
        elif name.startswith("backbone.stages"):
            stage_id = int(name.split('.')[2])
            block_id = int(name.split('.')[3])
            if stage_id == 0 or stage_id == 1:
                layer_id = stage_id + 1
            elif stage_id == 2:
                layer_id = 3 + block_id // 3 
            elif stage_id == 3:
                layer_id = num_layers
            return layer_id
        else:
            return num_layers + 1  # after backbone

    def _generate_swin_layer_wise_lr_decay(self, model, paramwise_options):
        """Generate layer-wise learning rate decay for Swin Transformer."""
        num_layers = sum(model.backbone.depths) + 2
        layer_scales = list(self.layer_decay**i
                            for i in reversed(range(num_layers)))

        for name, _ in model.named_parameters():

            layer_id = self._get_swin_layer(name, num_layers,
                                            model.backbone.depths)
            paramwise_options[name] = dict(lr_mult=layer_scales[layer_id])

    def _get_swin_layer(self, name, num_layers, depths):
        if 'mask_token' in name:
            return 0
        elif 'patch_embed' in name:
            return 0
        elif name.startswith('backbone.stages'):
            layer_id = int(name.split('.')[2])
            block_id = name.split('.')[4]
            if block_id == 'reduction' or block_id == 'norm':
                return sum(depths[:layer_id + 1])
            layer_id = sum(depths[:layer_id]) + int(block_id)
            return layer_id + 1
        else:
            return num_layers - 1

    def _generate_vit_layer_wise_lr_decay(self, model, paramwise_options):
        """Generate layer-wise learning rate decay for Vision Transformer."""
        num_layers = len(model.backbone.layers) + 1
        layer_scales = list(self.layer_decay**(num_layers - i)
                            for i in range(num_layers + 1))

        if 'pos_embed' in paramwise_options:
            paramwise_options['pos_embed'].update(
                dict(lr_mult=layer_scales[0]))
        else:
            paramwise_options['pos_embed'] = dict(lr_mult=layer_scales[0])

        if 'cls_token' in paramwise_options:
            paramwise_options['cls_token'].update(
                dict(lr_mult=layer_scales[0]))
        else:
            paramwise_options['cls_token'] = dict(lr_mult=layer_scales[0])

        if 'patch_embed' in paramwise_options:
            paramwise_options['patch_embed'].update(
                dict(lr_mult=layer_scales[0]))
        else:
            paramwise_options['patch_embed'] = dict(lr_mult=layer_scales[0])

        for i in range(num_layers - 1):
            paramwise_options[f'backbone\\.layers\\.{i}\\.'] = dict(
                lr_mult=layer_scales[i + 1])
