# Copyright (c) OpenMMLab. All rights reserved.
import re
from typing import Tuple, Union

import torch.nn as nn
from mmengine import print_log
from mmengine.optim import (
    DefaultOptimWrapperConstructor,
    OptimWrapper,
    OptimWrapperDict,
    OPTIM_WRAPPER_CONSTRUCTORS,
    OPTIMIZERS,
)

from mmengine import OPTIM_WRAPPERS


@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class MultiOptimWrapperConstructor:
    """OptimizerConstructor for GAN models. This class constructs optimizers
    for the submodules of the model separately, and returns a
    :class:`mmengine.optim.OptimWrapperDict` or
    :class:`mmengine.optim.OptimWrapper`.

    Example 1: Build multi optimizers (e.g., GANs):
        >>> # build GAN model
        >>> model = dict(
        >>>     type='GANModel',
        >>>     num_classes=10,
        >>>     generator=dict(type='Generator'),
        >>>     discriminator=dict(type='Discriminator'))
        >>> gan_model = MODELS.build(model)
        >>> # build constructor
        >>> optim_wrapper = dict(
        >>>     generator=dict(
        >>>         type='OptimWrapper',
        >>>         accumulative_counts=1,
        >>>         optimizer=dict(type='Adam', lr=0.0002,
        >>>                        betas=(0.5, 0.999))),
        >>>     discriminator=dict(
        >>>         type='OptimWrapper',
        >>>         accumulative_counts=1,
        >>>         optimizer=dict(type='Adam', lr=0.0002,
        >>>                        betas=(0.5, 0.999))))
        >>> optim_dict_builder = MultiOptimWrapperConstructor(optim_wrapper)
        >>> # build optim wrapper dict
        >>> optim_wrapper_dict = optim_dict_builder(gan_model)

    Example 2: Build multi optimizers for specific submodules:
        >>> # build model
        >>> class GAN(nn.Module):
        >>>     def __init__(self) -> None:
        >>>         super().__init__()
        >>>         self.generator = nn.Conv2d(3, 3, 1)
        >>>         self.discriminator = nn.Conv2d(3, 3, 1)
        >>> class TextEncoder(nn.Module):
        >>>     def __init__(self):
        >>>         super().__init__()
        >>>         self.embedding = nn.Embedding(100, 100)
        >>> class ToyModel(nn.Module):
        >>>     def __init__(self) -> None:
        >>>         super().__init__()
        >>>         self.m1 = GAN()
        >>>         self.m2 = nn.Conv2d(3, 3, 1)
        >>>         self.m3 = nn.Linear(2, 2)
        >>>         self.text_encoder = TextEncoder()
        >>> model = ToyModel()
        >>> # build constructor
        >>> optim_wrapper = {
        >>>     '.*embedding': {
        >>>         'type': 'OptimWrapper',
        >>>         'optimizer': {
        >>>             'type': 'Adam',
        >>>             'lr': 1e-4,
        >>>             'betas': (0.9, 0.99)
        >>>         }
        >>>     },
        >>>     'm1.generator': {
        >>>         'type': 'OptimWrapper',
        >>>         'optimizer': {
        >>>             'type': 'Adam',
        >>>             'lr': 1e-5,
        >>>             'betas': (0.9, 0.99)
        >>>         }
        >>>     },
        >>>     'm2': {
        >>>         'type': 'OptimWrapper',
        >>>         'optimizer': {
        >>>             'type': 'Adam',
        >>>             'lr': 1e-5,
        >>>         }
        >>>     }
        >>> }
        >>> optim_dict_builder = MultiOptimWrapperConstructor(optim_wrapper)
        >>> # build optim wrapper dict
        >>> optim_wrapper_dict = optim_dict_builder(model)

    Example 3: Build a single optimizer for multi modules (e.g., DreamBooth):
        >>> # build StableDiffusion model
        >>> model = dict(
        >>>     type='StableDiffusion',
        >>>     unet=dict(type='unet'),
        >>>     vae=dict(type='vae'),
        >>>     text_encoder=dict(type='text_encoder'))
        >>> diffusion_model = MODELS.build(model)
        >>> # build constructor
        >>> optim_wrapper = dict(
        >>>     modules=['unet', 'text_encoder'],
        >>>     optimizer=dict(type='Adam', lr=0.0002),
        >>>     accumulative_counts=1)
        >>> optim_dict_builder = MultiOptimWrapperConstructor(optim_wrapper)
        >>> # build optim wrapper dict
        >>> optim_wrapper_dict = optim_dict_builder(diffusion_model)

    Args:
        optim_wrapper_cfg (dict): Config of the optimizer wrapper. Either a
            single wrapper config containing a ``modules`` entry (one
            optimizer over several submodules), or a mapping from a key to
            a per-optimizer wrapper config. In the latter form each entry
            may carry its own ``modules`` and ``paramwise_cfg``.
        paramwise_cfg (dict): Must be ``None``; parameter-wise settings are
            read from each optimizer's own config instead. Default: None.

    Raises:
        TypeError: If ``optim_wrapper_cfg`` is not a dict.
        AssertionError: If ``paramwise_cfg`` is not ``None``.
    """

    def __init__(self, optim_wrapper_cfg: dict, paramwise_cfg=None):
        if not isinstance(optim_wrapper_cfg, dict):
            # A single message string: passing two positional arguments to
            # the exception would render the message as a tuple.
            raise TypeError(
                "optimizer_cfg should be a dict, "
                f"but got {type(optim_wrapper_cfg)}"
            )
        assert (
            paramwise_cfg is None
        ), "paramwise_cfg should be set in each optimizer separately"
        self.optim_cfg = optim_wrapper_cfg

        if "modules" in optim_wrapper_cfg:
            # Single optimizer shared by several submodules: ``constructors``
            # is one constructor and ``modules`` is the name (or names) given
            # in the config.
            cfg_ = optim_wrapper_cfg.copy()
            self.modules = cfg_.pop("modules")
            paramwise_cfg_ = cfg_.pop("paramwise_cfg", None)
            self.constructors = DefaultOptimWrapperConstructor(cfg_, paramwise_cfg_)
        else:
            # One optimizer per key: ``constructors`` maps key -> constructor
            # and ``modules`` holds only the keys that declare ``modules``.
            self.constructors = {}
            self.modules = {}
            for key, cfg in self.optim_cfg.items():
                cfg_ = cfg.copy()
                if "modules" in cfg_:
                    self.modules[key] = cfg_.pop("modules")
                paramwise_cfg_ = cfg_.pop("paramwise_cfg", None)
                self.constructors[key] = DefaultOptimWrapperConstructor(
                    cfg_, paramwise_cfg_
                )

    def __call__(self, module: nn.Module) -> Union[OptimWrapperDict, OptimWrapper]:
        """Build the optimizer wrapper(s) for ``module``.

        Args:
            module (nn.Module): The model. If it has a ``module`` attribute
                (as model wrappers do), that attribute is used instead.

        Returns:
            Union[OptimWrapperDict, OptimWrapper]: An ``OptimWrapperDict``
            with one entry per configured key, or a single wrapper when the
            config was given in the single-optimizer (``modules``) form.
        """
        if hasattr(module, "module"):
            module = module.module

        if not isinstance(self.constructors, dict):
            return self._build_from_names(self.constructors, module, self.modules)

        optimizers = {}
        for key, constructor in self.constructors.items():
            # Entries without an explicit ``modules`` fall back to their own
            # key, even when other entries do declare ``modules``.
            module_names = self.modules.get(key, key)
            if isinstance(module_names, str) and module_names in module._modules:
                # A single first-level submodule: delegate entirely to the
                # default constructor (which also handles paramwise_cfg).
                optimizers[key] = constructor(module._modules[module_names])
                print_log(
                    f"Add to optimizer '{key}' "
                    f"({constructor.optimizer_cfg}): '{module_names}'.",
                    "current",
                )
            else:
                optimizers[key] = self._build_from_names(
                    constructor, module, module_names, key
                )

        return OptimWrapperDict(**optimizers)

    @staticmethod
    def _build_from_names(constructor, module, module_names, key=None):
        """Build one optimizer wrapper over parameters selected by name.

        Args:
            constructor: The constructor holding ``optimizer_cfg``,
                ``optim_wrapper_cfg`` and ``paramwise_cfg`` for this
                optimizer.
            module (nn.Module): The model to collect parameters from.
            module_names (Union[str, list]): Name(s) or regex pattern(s)
                forwarded to ``get_params_by_names``.
            key (str, optional): Key of the optimizer in the multi-optimizer
                config; ``None`` for the single-optimizer form. Only affects
                the assertion and log messages.

        Returns:
            The optimizer wrapper built by ``OPTIM_WRAPPERS``.
        """
        if key is None:
            assert (
                not constructor.paramwise_cfg
            ), "Do not support paramwise_cfg for multi parameters"
            label = ""
        else:
            assert not constructor.paramwise_cfg, (
                "Do not support paramwise_cfg for multi module " "optimizer."
            )
            label = f"'{key}' "

        params, found_names = get_params_by_names(module, module_names)

        # build optimizer
        optimizer_cfg = constructor.optimizer_cfg.copy()
        optimizer_cfg["params"] = params
        optimizer = OPTIMIZERS.build(optimizer_cfg)

        # build optimizer wrapper
        optim_wrapper_cfg = constructor.optim_wrapper_cfg.copy()
        optim_wrapper_cfg.setdefault("type", "OptimWrapper")
        optim_wrapper = OPTIM_WRAPPERS.build(
            optim_wrapper_cfg, default_args=dict(optimizer=optimizer)
        )

        for name in found_names:
            print_log(
                f"Add to optimizer {label}({constructor.optimizer_cfg}): "
                f"'{name}'.",
                "current",
            )

        return optim_wrapper


def get_params_by_names(
    module: nn.Module, names: Union[str, list, tuple]
) -> Tuple[list, list]:
    """Collect the parameters of submodules selected by name.

    Support two kinds of name matching, tried in this order for each name:
        1. matching name from **first-level** submodule.
        2. matching name by `re.fullmatch` against every name yielded by
           ``module.named_modules()``.

    Args:
        module (nn.Module): The module to get parameters.
        names (Union[str, list, tuple]): The name or a list/tuple of names
            (or regex patterns) of the submodules whose parameters are
            collected.

    Returns:
        Tuple[list, list]: The collected parameters, and the names of the
        matched submodules for logging. Each parameter appears at most once
        even when several matched submodules contain it (e.g. a pattern
        matching both a module and its children).
    """

    if isinstance(names, tuple):
        names = list(names)
    elif not isinstance(names, list):
        names = [names]

    params = []
    found_names = []
    # Identity of every parameter already collected; overlapping matches
    # would otherwise hand the optimizer the same parameter several times.
    seen_ids = set()

    def _collect(submodule):
        for param in submodule.parameters():
            if id(param) not in seen_ids:
                seen_ids.add(id(param))
                params.append(param)

    for name in names:
        if name in module._modules:
            _collect(module._modules[name])
            found_names.append(name)
            continue

        # Not a first-level submodule: treat the name as a regex pattern.
        pattern = re.compile(name)
        for sub_name, submodule in module.named_modules():
            if pattern.fullmatch(sub_name):
                _collect(submodule)
                found_names.append(sub_name)
    return params, found_names
