# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmengine.model import BaseModule

from mmpretrain.models.backbones.beit import BEiTTransformerEncoderLayer
from mmpretrain.registry import MODELS


@MODELS.register_module()
class BEiTV2Neck(BaseModule):
    """Neck for BEiTV2 Pre-training.

    This module construct the decoder for the final prediction.

    Args:
        num_layers (int): Number of encoder layers of neck. Defaults to 2.
        early_layers (int): The layer index of the early output from the
            backbone. Defaults to 9.
        backbone_arch (str): Vision Transformer architecture. Defaults to base.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): stochastic depth rate. Defaults to 0.
        layer_scale_init_value (float): The initialization value for the
            learnable scaling of attention and FFN. Defaults to 0.1.
        use_rel_pos_bias (bool): Whether to use unique relative position bias,
            if False, use shared relative position bias defined in backbone.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """
    arch_zoo = {
        **dict.fromkeys(
            ['b', 'base'], {
                'embed_dims': 768,
                'depth': 12,
                'num_heads': 12,
                'feedforward_channels': 3072,
            }),
        **dict.fromkeys(
            ['l', 'large'], {
                'embed_dims': 1024,
                'depth': 24,
                'num_heads': 16,
                'feedforward_channels': 4096,
            }),
    }

    def __init__(
        self,
        num_layers: int = 2,
        early_layers: int = 9,
        backbone_arch: str = 'base',
        drop_rate: float = 0.,
        drop_path_rate: float = 0.,
        layer_scale_init_value: float = 0.1,
        use_rel_pos_bias: bool = False,
        norm_cfg: dict = dict(type='LN', eps=1e-6),
        init_cfg: Optional[Union[dict, List[dict]]] = dict(
            type='TruncNormal', layer='Linear', std=0.02, bias=0)
    ) -> None:
        super().__init__(init_cfg=init_cfg)

        if isinstance(backbone_arch, str):
            backbone_arch = backbone_arch.lower()
            assert backbone_arch in set(self.arch_zoo), \
                (f'Arch {backbone_arch} is not in default archs '
                 f'{set(self.arch_zoo)}')
            self.arch_settings = self.arch_zoo[backbone_arch]
        else:
            essential_keys = {
                'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels'
            }
            assert isinstance(backbone_arch, dict) and essential_keys <= set(
                backbone_arch
            ), f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = backbone_arch

        # stochastic depth decay rule
        self.early_layers = early_layers
        depth = self.arch_settings['depth']
        dpr = np.linspace(0, drop_path_rate,
                          max(depth, early_layers + num_layers))

        self.patch_aggregation = nn.ModuleList()
        for i in range(early_layers, early_layers + num_layers):
            _layer_cfg = dict(
                embed_dims=self.arch_settings['embed_dims'],
                num_heads=self.arch_settings['num_heads'],
                feedforward_channels=self.
                arch_settings['feedforward_channels'],
                drop_rate=drop_rate,
                drop_path_rate=dpr[i],
                norm_cfg=norm_cfg,
                layer_scale_init_value=layer_scale_init_value,
                window_size=None,
                use_rel_pos_bias=use_rel_pos_bias)
            self.patch_aggregation.append(
                BEiTTransformerEncoderLayer(**_layer_cfg))

        self.rescale_patch_aggregation_init_weight()

        embed_dims = self.arch_settings['embed_dims']
        _, norm = build_norm_layer(norm_cfg, embed_dims)
        self.add_module('norm', norm)

    def rescale_patch_aggregation_init_weight(self):
        """Rescale the initialized weights."""

        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.patch_aggregation):
            rescale(layer.attn.proj.weight.data,
                    self.early_layers + layer_id + 1)
            rescale(layer.ffn.layers[1].weight.data,
                    self.early_layers + layer_id + 1)

    def forward(self, inputs: Tuple[torch.Tensor], rel_pos_bias: torch.Tensor,
                **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get the latent prediction and final prediction.

        Args:
            x (Tuple[torch.Tensor]): Features of tokens.
            rel_pos_bias (torch.Tensor): Shared relative position bias table.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
              - ``x``: The final layer features from backbone, which are normed
                in ``BEiTV2Neck``.
              - ``x_cls_pt``: The early state features from backbone, which are
                consist of final layer cls_token and early state patch_tokens
                from backbone and sent to PatchAggregation layers in the neck.
        """

        early_states, x = inputs[0], inputs[1]
        x_cls_pt = torch.cat([x[:, [0]], early_states[:, 1:]], dim=1)
        for layer in self.patch_aggregation:
            x_cls_pt = layer(x_cls_pt, rel_pos_bias=rel_pos_bias)

        # shared norm
        x, x_cls_pt = self.norm(x), self.norm(x_cls_pt)

        # remove cls_token
        x = x[:, 1:]
        x_cls_pt = x_cls_pt[:, 1:]
        return x, x_cls_pt
