# Copyright (c) OpenMMLab. All rights reserved.
import random
from typing import Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import functional as F

from mmpretrain.models.backbones import MixMIMTransformer
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from ..utils import build_2d_sincos_position_embedding
from .base import BaseSelfSupervisor


@MODELS.register_module()
class MixMIMPretrainTransformer(MixMIMTransformer):
    """MixMIM backbone for MixMIM pre-training.

    A PyTorch implementation of `MixMIM: Mixed and Masked Image
    Modeling for Efficient Visual Representation Learning
    <https://arxiv.org/abs/2205.13137>`_

    Args:
        arch (str | dict): MixMIM architecture. If use string,
            choose from 'base','large' and 'huge'.
            If use dict, it should have below keys:

            - **embed_dims** (int): The dimensions of embedding.
            - **depths** (int): The number of transformer encoder layers.
            - **num_heads** (int): The number of heads in attention modules.

            Defaults to 'base'.
        mlp_ratio (int): The mlp ratio in FFN.  Defaults to 4.
        img_size (int | tuple): The expected input image shape. Because we
            support dynamic input shape, just set the argument to
            the most common input image shape. Defaults to 224.
        patch_size (int | tuple): The patch size in patch embedding.
            Defaults to 4.
        in_channels (int): The num of input channels. Defaults to 3.
        window_size (list): The height and width of the window.
            Defaults to ``[14, 14, 14, 7]``.
        qkv_bias (bool): Whether to add bias for qkv in attention modules.
            Defaults to True.
        patch_cfg (dict): Extra config dict for patch embedding.
            Defaults to an empty dict.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        attn_drop_rate (float): Attention drop rate. Defaults to 0.
        use_checkpoint (bool): Whether use the checkpoint to reduce GPU memory
            cost. Defaults to False.
        mask_ratio (float): The base ratio of total number of patches to be
            masked. Defaults to 0.5.
        range_mask_ratio (float): The range of mask ratio. A value drawn
            uniformly from ``[0, range_mask_ratio]`` is added to
            ``mask_ratio`` every time a mask is generated. Defaults to 0.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    # NOTE: ``window_size``, ``patch_cfg`` and ``norm_cfg`` use mutable
    # default objects that are shared between calls. They are forwarded to
    # the parent constructor untouched and never modified in this class.
    def __init__(self,
                 arch: Union[str, dict] = 'base',
                 mlp_ratio: float = 4,
                 img_size: int = 224,
                 patch_size: int = 4,
                 in_channels: int = 3,
                 window_size: List = [14, 14, 14, 7],
                 qkv_bias: bool = True,
                 patch_cfg: dict = dict(),
                 norm_cfg: dict = dict(type='LN'),
                 drop_rate: float = 0.0,
                 drop_path_rate: float = 0.0,
                 attn_drop_rate: float = 0.0,
                 use_checkpoint: bool = False,
                 mask_ratio: float = 0.5,
                 range_mask_ratio: float = 0.0,
                 init_cfg: Optional[dict] = None) -> None:

        super().__init__(
            arch=arch,
            mlp_ratio=mlp_ratio,
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            window_size=window_size,
            qkv_bias=qkv_bias,
            patch_cfg=patch_cfg,
            norm_cfg=norm_cfg,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            attn_drop_rate=attn_drop_rate,
            use_checkpoint=use_checkpoint,
            init_cfg=init_cfg)

        # Masking hyper-parameters only used by the pre-training forward.
        self.mask_ratio = mask_ratio
        self.range_mask_ratio = range_mask_ratio

    def init_weights(self):
        """Initialize position embedding, patch embedding."""
        # ``super(MixMIMTransformer, self)`` starts the method lookup *after*
        # ``MixMIMTransformer``, so that class's own ``init_weights`` is
        # intentionally bypassed in favour of its parent's.
        super(MixMIMTransformer, self).init_weights()

        # Overwrite the absolute position embedding with a 2D sin-cos
        # embedding whose grid side is ``int(sqrt(num_patches))``.
        pos_embed = build_2d_sincos_position_embedding(
            int(self.num_patches**.5),
            self.absolute_pos_embed.shape[-1],
            cls_token=False)
        self.absolute_pos_embed.data.copy_(pos_embed.float())

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize a single sub-module; used with ``self.apply``.

        ``nn.Linear`` weights get Xavier-uniform init and zero bias,
        ``nn.LayerNorm`` gets unit weight and zero bias. Parameters that are
        ``None`` (e.g. ``bias=False`` or ``elementwise_affine=False``) are
        skipped. Other module types are left untouched.

        Args:
            m (nn.Module): The module to initialize.
        """
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def random_masking(
        self,
        x: torch.Tensor,
        mask_ratio: float = 0.5
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generate the mask for MixMIM Pretraining.

        A single random binary mask is drawn at the coarsest resolution
        (``H // self.encoder_stride`` x ``W // self.encoder_stride``) and
        shared by every image in the batch. It is then upsampled with
        nearest-neighbour interpolation by factors 2, 4 and 8 to obtain the
        masks for the finer stages.

        Args:
            x (torch.Tensor): Image with data augmentation applied, which is
                of shape B x C x H x W. Only its spatial size and device
                are used.
            mask_ratio (float): The mask ratio of total patches. A random
                value in ``[0, self.range_mask_ratio]`` is added to it.
                Defaults to 0.5.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
                Each mask has shape ``(1, L_i, 1)`` with values in {0, 1}.

                - mask_s1 (torch.Tensor): mask with stride of
                  self.encoder_stride // 8.
                - mask_s2 (torch.Tensor): mask with stride of
                  self.encoder_stride // 4.
                - mask_s3 (torch.Tensor): mask with stride of
                  self.encoder_stride // 2.
                - mask (torch.Tensor): mask with stride of
                  self.encoder_stride.
        """

        # Unpacking four values also enforces a 4-D input.
        _, _, H, W = x.shape
        out_H = H // self.encoder_stride
        out_W = W // self.encoder_stride
        s3_H, s3_W = out_H * 2, out_W * 2
        s2_H, s2_W = out_H * 4, out_W * 4
        s1_H, s1_W = out_H * 8, out_W * 8

        seq_l = out_H * out_W
        # use a shared mask for a batch images
        mask = torch.zeros([1, 1, seq_l], device=x.device)

        mask_ratio = mask_ratio + random.uniform(0.0, self.range_mask_ratio)
        noise = torch.rand(1, 1, seq_l, device=x.device)  # noise in [0, 1]
        # Sort ascending and set the ``int(seq_l * mask_ratio)`` positions
        # with the smallest noise to 1; all other positions stay 0.
        mask_idx = torch.argsort(noise, dim=2)[:, :, :int(seq_l * mask_ratio)]
        mask.scatter_(2, mask_idx, 1)
        mask = mask.reshape(1, 1, out_H, out_W)
        mask_s1 = F.interpolate(mask, size=(s1_H, s1_W), mode='nearest')
        mask_s2 = F.interpolate(mask, size=(s2_H, s2_W), mode='nearest')
        mask_s3 = F.interpolate(mask, size=(s3_H, s3_W), mode='nearest')

        # Flatten the spatial grid so the masks broadcast over (B, L, C).
        mask = mask.reshape(1, out_H * out_W, 1).contiguous()
        mask_s1 = mask_s1.reshape(1, s1_H * s1_W, 1).contiguous()
        mask_s2 = mask_s2.reshape(1, s2_H * s2_W, 1).contiguous()
        mask_s3 = mask_s3.reshape(1, s3_H * s3_W, 1).contiguous()

        return mask_s1, mask_s2, mask_s3, mask

    def forward(self,
                x: torch.Tensor,
                mask: Optional[bool] = True) -> Tuple[torch.Tensor]:
        """Generate features for masked images.

        This function generates mask and masks some patches randomly and get
        the hidden features for visible patches.

        Args:
            x (torch.Tensor): Input images, which is of shape B x C x H x W.
            mask (bool, optional): To indicate whether the forward containing
                ``mask`` or not. If ``None`` or ``False``, the forward of the
                parent class is used and its output is returned unchanged.
                Defaults to True.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: When masking is enabled:

              - x (torch.Tensor): hidden features, which is of shape
                B x L x C.
              - mask_s4 (torch.Tensor): the mask tensor for the last layer.
        """
        # ``mask is None or False`` would only ever test for ``None``;
        # compare against ``False`` explicitly so that ``mask=False`` also
        # disables masking.
        if mask is None or mask is False:
            return super().forward(x)

        mask_s1, mask_s2, mask_s3, mask_s4 = self.random_masking(
            x, self.mask_ratio)

        x, _ = self.patch_embed(x)

        # Mix each sample with the batch-reversed sample: positions where
        # ``mask_s1`` is 1 take their token from ``x.flip(0)``.
        x = x * (1. - mask_s1) + x.flip(0) * mask_s1
        x = x + self.absolute_pos_embed
        x = self.drop_after_pos(x)

        # Pair the first four stages with the mask of matching resolution.
        stage_masks = (mask_s1, mask_s2, mask_s3, mask_s4)
        for layer, attn_mask in zip(self.layers, stage_masks):
            x = layer(x, attn_mask=attn_mask)

        x = self.norm(x)

        return x, mask_s4


@MODELS.register_module()
class MixMIM(BaseSelfSupervisor):
    """MixMIM.

    Implementation of `MixMIM: Mixed and Masked Image Modeling for Efficient
    Visual Representation Learning. <https://arxiv.org/abs/2205.13137>`_.

    Args:
        backbone (dict): Config dict of the backbone.
        neck (dict, optional): Config dict of the neck. When both ``neck``
            and ``head`` are dicts, ``neck['encoder_stride']`` is passed to
            the head as ``patch_size``. Defaults to None.
        head (dict, optional): Config dict of the head. Defaults to None.
        pretrained (str, optional): Forwarded to the base class.
            Defaults to None.
        data_preprocessor (dict | nn.Module, optional): Forwarded to the
            base class. Defaults to None.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 backbone: dict,
                 neck: Optional[dict] = None,
                 head: Optional[dict] = None,
                 pretrained: Optional[str] = None,
                 data_preprocessor: Optional[Union[dict, nn.Module]] = None,
                 init_cfg: Optional[dict] = None):

        # ``neck`` and ``head`` are optional, so only inject ``patch_size``
        # when both are config dicts. Work on a shallow copy so the config
        # object passed in by the caller is not modified.
        if isinstance(head, dict) and isinstance(neck, dict):
            head = head.copy()
            head.update(dict(patch_size=neck['encoder_stride']))

        super().__init__(
            backbone=backbone,
            neck=neck,
            head=head,
            pretrained=pretrained,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)

    def extract_feat(self, inputs: torch.Tensor):
        """Extract features by calling the backbone with ``mask=None``.

        Args:
            inputs (torch.Tensor): The input images.

        Returns:
            The output of ``self.backbone(inputs, mask=None)``.
        """
        return self.backbone(inputs, mask=None)

    def loss(self, inputs: torch.Tensor, data_samples: List[DataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """The forward function in training.

        Args:
            inputs (torch.Tensor): The input images.
            data_samples (List[DataSample]): All elements required
                during the forward function. Not used by this method.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components.
        """
        # The backbone returns the latent features and the mask it applied.
        latent, mask = self.backbone(inputs)
        x_rec = self.neck(latent, mask)
        # The head computes the loss from the reconstruction, the original
        # inputs and the mask.
        loss = self.head.loss(x_rec, inputs, mask)
        losses = dict(loss=loss)
        return losses
