# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union

import torch
from torch import nn

from mmpretrain.registry import MODELS
from ..backbones.vision_transformer import TransformerEncoderLayer
from ..utils import PromptMultiheadAttention
from .mae_neck import MAEPretrainDecoder


class PromptTransformerEncoderLayer(TransformerEncoderLayer):
    """Prompt Transformer Encoder Layer for MILAN.

    This module is specific for the prompt encoder in MILAN. It will not update
    the visible tokens from the encoder.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Defaults to 0.0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Defaults to 2.
        qkv_bias (bool): Enable bias for qkv if True. Defaults to True.
        act_cfg (dict): The activation config for FFNs.
            Defaults to ``dict(type='GELU')``.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        batch_first (bool): Key, Query and Value are shape of
            (batch, n, embed_dim)
            or (n, batch, embed_dim). Defaults to False.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims: int,
                 num_heads: int,
                 feedforward_channels=int,
                 drop_rate: float = 0.,
                 attn_drop_rate: float = 0.,
                 drop_path_rate: float = 0.,
                 num_fcs: int = 2,
                 qkv_bias: bool = True,
                 act_cfg: dict = dict(type='GELU'),
                 norm_cfg: dict = dict(type='LN'),
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(
            embed_dims=embed_dims,
            num_heads=num_heads,
            feedforward_channels=feedforward_channels,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            num_fcs=num_fcs,
            qkv_bias=qkv_bias,
            act_cfg=act_cfg,
            norm_cfg=norm_cfg,
            init_cfg=init_cfg)
        self.attn = PromptMultiheadAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            qkv_bias=qkv_bias)

    def forward(self, x: torch.Tensor, visible_tokens: torch.Tensor,
                ids_restore: torch.Tensor) -> torch.Tensor:
        """Forward function for `PromptMultiheadAttention`.

        Args:
            x (torch.Tensor): Mask token features with shape N x L_m x C.
            visible_tokens (torch.Tensor): The visible tokens features from
                encoder with shape N x L_v x C.
            ids_restore (torch.Tensor): The ids of all tokens in the original
                image with shape N x L.

        Returns:
            torch Tensor: Output features with shape N x L x C.
        """
        x = x + self.attn(self.norm1(x), visible_tokens, ids_restore)
        x = self.ffn(self.norm2(x), identity=x)
        return x


@MODELS.register_module()
class MILANPretrainDecoder(MAEPretrainDecoder):
    """Prompt decoder for MILAN.

    This decoder is used in MILAN pretraining, which will not update these
    visible tokens from the encoder.

    Args:
        num_patches (int): The number of total patches. Defaults to 196.
        patch_size (int): Image patch size. Defaults to 16.
        in_chans (int): The channel of input image. Defaults to 3.
        embed_dim (int): Encoder's embedding dimension. Defaults to 1024.
        decoder_embed_dim (int): Decoder's embedding dimension.
            Defaults to 512.
        decoder_depth (int): The depth of decoder. Defaults to 8.
        decoder_num_heads (int): Number of attention heads of decoder.
            Defaults to 16.
        predict_feature_dim (int): The dimension of the feature to be
            predicted. Defaults to 512.
        mlp_ratio (int): Ratio of mlp hidden dim to decoder's embedding dim.
            Defaults to 4.
        norm_cfg (dict): Normalization layer. Defaults to LayerNorm.
        init_cfg (Union[List[dict], dict], optional): Initialization config
            dict. Defaults to None.
    """

    def __init__(self,
                 num_patches: int = 196,
                 patch_size: int = 16,
                 in_chans: int = 3,
                 embed_dim: int = 1024,
                 decoder_embed_dim: int = 512,
                 decoder_depth: int = 8,
                 decoder_num_heads: int = 16,
                 predict_feature_dim: int = 512,
                 mlp_ratio: int = 4,
                 norm_cfg: dict = dict(type='LN', eps=1e-6),
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(
            num_patches=num_patches,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            decoder_embed_dim=decoder_embed_dim,
            decoder_depth=decoder_depth,
            decoder_num_heads=decoder_num_heads,
            mlp_ratio=mlp_ratio,
            norm_cfg=norm_cfg,
            init_cfg=init_cfg)

        # map the dim of features from decoder to the dim compatible with
        # that of CLIP
        self.decoder_pred = nn.Linear(
            decoder_embed_dim, predict_feature_dim, bias=True)

        # use prompt transformer encoder layer, instead of the conventional
        # transformer encoder layer
        self.decoder_blocks = nn.ModuleList([
            PromptTransformerEncoderLayer(
                decoder_embed_dim,
                decoder_num_heads,
                int(mlp_ratio * decoder_embed_dim),
                qkv_bias=True,
                norm_cfg=norm_cfg) for _ in range(decoder_depth)
        ])

    def forward(self, x: torch.Tensor, ids_restore: torch.Tensor,
                ids_keep: torch.Tensor,
                ids_dump: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            x (torch.Tensor): The input features, which is of shape (N, L, C).
            ids_restore (torch.Tensor): The indices to restore these tokens
                to the original image.
            ids_keep (torch.Tensor): The indices of tokens to be kept.
            ids_dump (torch.Tensor): The indices of tokens to be masked.

        Returns:
            torch.Tensor: The reconstructed features, which is of shape
            (N, L, C).
        """
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(
            x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
        x_ = torch.gather(
            x_,
            dim=1,
            index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
        x = torch.cat([x[:, :1, :], x_], dim=1)

        # add pos embed
        x = x + self.decoder_pos_embed

        # split mask tokens and visible tokens
        visible_tokens = torch.cat([
            x[:, :1, :],
            torch.gather(
                x[:, 1:, :],
                dim=1,
                index=ids_keep.unsqueeze(-1).repeat(1, 1, x.shape[-1]))
        ],
                                   dim=1)
        x = torch.gather(
            x[:, 1:, :],
            dim=1,
            index=ids_dump.unsqueeze(-1).repeat(1, 1, x.shape[-1]))

        for blk in self.decoder_blocks:
            x = blk(x, visible_tokens, ids_restore)

        # full sequence recovery
        x_ = torch.cat([visible_tokens[:, 1:, :], x], dim=1)
        x_ = torch.gather(
            x_,
            dim=1,
            index=ids_restore.unsqueeze(-1).repeat(1, 1,
                                                   x.shape[-1]))  # unshuffle
        x = torch.cat([visible_tokens[:, :1, :], x_], dim=1)

        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        return x
