# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from mmengine.model.weight_init import trunc_normal_

from mmpretrain.models.backbones.hivit import BlockWithRPE, HiViT, PatchMerge
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from ..utils import build_2d_sincos_position_embedding
from .base import BaseSelfSupervisor


@MODELS.register_module()
class iTPNHiViT(HiViT):
    """HiViT for iTPN pre-training.

    Args:
        arch (str): Architecture setting, passed through unchanged to
            ``HiViT``. Defaults to 'base'.
        img_size (int | tuple): Input image size. Defaults to 224.
        patch_size (int | tuple): The patch size. Defaults to 16.
        inner_patches (int): Inner patch. Defaults to 4.
        stem_mlp_ratio (float): Ratio of MLP hidden dim to embedding dim
            in the first two stages. Defaults to 3.
        mlp_ratio (float): Ratio of MLP hidden dim to embedding dim in
            the last stage. Defaults to 4.
        qkv_bias (bool): Enable bias for qkv projections if True.
            Defaults to True.
        qk_scale (float, optional): The number of divider after q@k.
            Defaults to None.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        attn_drop_rate (float): The drop out rate for attention output weights.
            Defaults to 0.
        drop_path_rate (float): stochastic depth rate. Defaults to 0.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN', eps=1e-6)``.
        ape (bool): If True, add absolute position embedding to
            the patch embedding. Defaults to True.
        rpe (bool): If True, add relative position embedding to
            the patch embedding. Defaults to False.
        layer_scale_init_value (float): Layer-scale init values. Defaults to 0.
        mask_ratio (float): The ratio of total number of patches to be masked.
            Only used by the 'pixel' forward path. Defaults to 0.75.
        reconstruction_type (str): The reconstruction of self-supervised
            learning, either 'pixel' or 'clip'. Defaults to 'pixel'.
        **kwargs: Extra keyword arguments forwarded to ``HiViT``.
    """

    def __init__(
        self,
        arch='base',
        img_size: int = 224,
        patch_size: int = 16,
        inner_patches: int = 4,
        stem_mlp_ratio: float = 3.,
        mlp_ratio: float = 4.,
        qkv_bias: bool = True,
        qk_scale: Optional[float] = None,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        norm_cfg: dict = dict(type='LN', eps=1e-6),
        ape: bool = True,
        rpe: bool = False,
        layer_scale_init_value: float = 0.0,
        mask_ratio: float = 0.75,
        reconstruction_type: str = 'pixel',
        **kwargs,
    ):
        super().__init__(
            arch=arch,
            img_size=img_size,
            patch_size=patch_size,
            inner_patches=inner_patches,
            stem_mlp_ratio=stem_mlp_ratio,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            norm_cfg=norm_cfg,
            ape=ape,
            rpe=rpe,
            layer_scale_init_value=layer_scale_init_value,
            **kwargs,
        )

        # The position embedding is excluded from gradient updates.
        self.pos_embed.requires_grad = False
        self.mask_ratio = mask_ratio

        assert reconstruction_type in ['pixel', 'clip'], \
            'iTPN method only support `pixel` and `clip`, ' \
            f'but got `{reconstruction_type}`.'
        self.reconstruction_type = reconstruction_type
        self.num_patches = self.patch_embed.num_patches

        # Only the 'clip' path replaces masked tokens with a learnable token.
        if reconstruction_type == 'clip':
            self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))

    def init_weights(self) -> None:
        """Initialize the weights of the backbone.

        ``self._init_weights`` is applied to every sub-module first. Then:

        - 'clip': the mask token is drawn from a truncated normal
          distribution (std 0.02) and the block projections are rescaled
          by :meth:`rescale_init_weight`.
        - 'pixel': the position embedding is overwritten with a 2D sin-cos
          embedding and the patch-embedding projection weight is
          re-initialized with Xavier uniform.
        """
        super().apply(self._init_weights)

        if self.reconstruction_type == 'clip':
            trunc_normal_(self.mask_token, std=0.02)
            self.rescale_init_weight()
        else:
            pos_embed = build_2d_sincos_position_embedding(
                int(self.num_patches**.5),
                self.pos_embed.shape[-1],
                cls_token=False)
            self.pos_embed.data.copy_(pos_embed.float())

            # Initialize the projection as if it were a 2D linear weight;
            # the view shares storage, so the init happens in place.
            w = self.patch_embed.proj.weight.data
            torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

    def rescale_init_weight(self) -> None:
        """Rescale the initialized weights.

        For every ``BlockWithRPE`` in ``self.blocks``, the attention output
        projection (when the block has attention) and the second MLP linear
        layer are divided in place by ``sqrt(2 * (block_index + 1))``.
        """

        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            if isinstance(layer, BlockWithRPE):
                # ``attn`` is ``None`` for blocks without attention.
                if layer.attn is not None:
                    rescale(layer.attn.proj.weight.data, layer_id + 1)
                rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def masking_id(
        self, batch_size: int, mask_ratio: float
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Draw a random per-sample mask over the patch tokens.

        Args:
            batch_size (int): Number of samples ``N`` in the batch.
            mask_ratio (float): Fraction of the ``L`` tokens to remove,
                where ``L`` is ``self.pos_embed.size(1)``.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

            - ``ids_keep``: indices of the kept tokens, of shape
              N x int(L * (1 - mask_ratio)).
            - ``ids_restore``: indices of shape N x L that undo the random
              shuffle.
            - ``mask``: binary mask of shape N x L in the original token
              order, 0 is keep and 1 is remove.
        """
        N, L = batch_size, self.pos_embed.size(1)
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(
            N, L, device=self.pos_embed.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(
            noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=self.pos_embed.device)
        mask[:, :ids_keep.size(1)] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return ids_keep, ids_restore, mask

    def forward_pixel(
        self,
        x: torch.Tensor,
        mask: Optional[bool] = True
    ) -> Tuple[Tuple, torch.Tensor, torch.Tensor]:
        """Generate features for randomly masked images.

        The function supports two kind of forward behaviors. If the ``mask``
        is ``True``, a random mask is generated with ``self.mask_ratio`` and
        only the visible patches are passed through the blocks, which means
        the function is executed as masked image modeling pre-training;
        if the ``mask`` is ``None`` or ``False``, the function calls
        ``super().forward()``, which extracts features from images without
        mask.

        Args:
            x (torch.Tensor): Input images, which is of shape B x C x H x W.
            mask (bool, optional): To indicate whether the forward function
                generating ``mask`` or not. Defaults to True.

        Returns:
            Tuple[Tuple, torch.Tensor, torch.Tensor]: Hidden features,
            mask and the ids to restore original image. When ``mask`` is
            ``None`` or ``False``, the output of ``super().forward()`` is
            returned instead.

            - ``outs`` (tuple): the features captured right before every
              ``PatchMerge`` block, followed by the output of the last
              block. Only the int(L * (1 - mask_ratio)) visible patches
              are kept.
            - ``mask`` (torch.Tensor): mask used to mask image, 0 is keep
              and 1 is remove.
            - ``ids_restore`` (torch.Tensor): ids to restore original image.
        """
        # ``mask is None or False`` would only test for ``None``; check the
        # ``False`` case explicitly so that it really disables masking.
        if mask is None or mask is False:
            return super().forward(x)

        B, _, H, W = x.shape
        ids_keep, ids_restore, mask = self.masking_id(B, self.mask_ratio)

        x = self.patch_embed(x)

        # Select the visible patches; the index is broadcast over all
        # trailing dimensions of the patch embedding.
        x = torch.gather(
            x,
            dim=1,
            index=ids_keep[:, :, None, None,
                           None].expand(-1, -1, *x.shape[2:]))

        outs = []
        for blk in self.blocks[:-self.num_main_blocks]:
            if isinstance(blk, PatchMerge):
                # Keep the feature map of each scale before it is merged.
                outs.append(x)
            x = blk(x)

        # Keep index 0 of the two inner dimensions
        # (presumably singleton after the merges -- TODO confirm).
        x = x[..., 0, 0, :]
        if self.ape:
            pos_embed = self.interpolate_pos_encoding(x, H, W)
            # Pick the position embeddings of the visible patches only.
            pos_embed = torch.gather(
                pos_embed.expand(B, -1, -1),
                dim=1,
                index=ids_keep[:, :, None].expand(-1, -1,
                                                  pos_embed.shape[2]),
            )
            x = x + pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks[-self.num_main_blocks:]:
            x = blk(x)

        outs.append(x)

        return (tuple(outs), mask, ids_restore)

    def forward_clip(
            self,
            x: torch.Tensor,
            mask: Optional[Union[bool, torch.Tensor]] = True) -> Tuple:
        """Generate features for images masked by a given mask.

        If ``mask`` is a tensor, the tokens selected by it are replaced with
        the learnable ``self.mask_token`` before the last
        ``num_main_blocks`` blocks; if the ``mask`` is ``None`` or
        ``False``, the function calls ``super().forward()``, which extracts
        features from images without mask.

        Args:
            x (torch.Tensor): Input images, which is of shape B x C x H x W.
            mask (torch.Tensor | bool, optional): The mask to apply, where
                non-zero entries mark the tokens to replace. It is
                flattened from dim 1 and must then provide one entry per
                token. ``None`` or ``False`` disables masking. The default
                ``True`` is kept for interface compatibility only; a
                tensor is required to run the masked path.

        Returns:
            tuple: The features captured right before every ``PatchMerge``
            block, followed by the output of the last block. When ``mask``
            is ``None`` or ``False``, the output of ``super().forward()``
            is returned instead.
        """
        # Identity checks keep this safe when ``mask`` is a tensor, whose
        # truth value would be ambiguous.
        if mask is None or mask is False:
            return super().forward(x)

        _, _, H, W = x.shape
        x = self.patch_embed(x)

        outs = []
        for blk in self.blocks[:-self.num_main_blocks]:
            if isinstance(blk, PatchMerge):
                # Keep the feature map of each scale before it is merged.
                outs.append(x)
            x = blk(x)

        # Keep index 0 of the two inner dimensions
        # (presumably singleton after the merges -- TODO confirm).
        x = x[..., 0, 0, :]
        B, L, _ = x.shape
        mask_token = self.mask_token.expand(B, L, -1)
        # Blend: masked positions take the mask token, others keep ``x``.
        w = mask.flatten(1).unsqueeze(-1).type_as(mask_token)
        x = x * (1. - w) + mask_token * w

        if self.ape:
            pos_embed = self.interpolate_pos_encoding(x, H, W)
            x = x + pos_embed
        x = self.pos_drop(x)

        rpe_index = True if self.rpe else None

        for blk in self.blocks[-self.num_main_blocks:]:
            x = blk(x, rpe_index)

        outs.append(x)

        return tuple(outs)

    def forward(self,
                x: torch.Tensor,
                mask: Optional[Union[bool, torch.Tensor]] = True) -> Tuple:
        """Generate features for masked images.

        Dispatches to :meth:`forward_pixel` when ``reconstruction_type`` is
        'pixel' and to :meth:`forward_clip` otherwise. In both cases, if the
        ``mask`` is ``None`` or ``False``, ``super().forward()`` is called,
        which extracts features from images without mask.

        Args:
            x (torch.Tensor): Input images, which is of shape B x C x H x W.
            mask (bool | torch.Tensor, optional): For 'pixel', a flag to
                indicate whether the forward function generating ``mask``
                or not. For 'clip', the mask tensor to apply.
                Defaults to True.

        Returns:
            tuple: For 'pixel', ``(outs, mask, ids_restore)``, i.e. the
            hidden features, the mask used to mask image and the ids to
            restore original image. For 'clip', only the tuple of hidden
            features.
        """

        if self.reconstruction_type == 'pixel':
            return self.forward_pixel(x, mask)
        return self.forward_clip(x, mask)


@MODELS.register_module()
class iTPN(BaseSelfSupervisor):
    """iTPN.

    Implementation of `iTPN: Integrally Pre-Trained Transformer Pyramid
    Networks <https://arxiv.org/abs/2211.12735>`_.
    """

    def extract_feat(self, inputs: torch.Tensor):
        """Run the backbone on ``inputs`` with ``mask=None``.

        Args:
            inputs (torch.Tensor): The input images.

        Returns:
            The output of the backbone called with ``mask=None``.
        """
        features = self.backbone(inputs, mask=None)
        return features

    def loss(self, inputs: torch.Tensor, data_samples: List[DataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """Compute the training loss.

        Args:
            inputs (torch.Tensor): The input images. For a backbone whose
                ``reconstruction_type`` is not 'pixel', ``inputs[0]`` is
                fed to the backbone and ``inputs[1]`` to the target
                generator.
            data_samples (List[DataSample]): All elements required
                during the forward function. Their ``mask`` attribute is
                only read when ``reconstruction_type`` is not 'pixel'.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components,
            holding the single key ``loss``.
        """
        if self.backbone.reconstruction_type == 'pixel':
            # The backbone draws the mask itself in this mode.
            backbone_out = self.backbone(inputs)
            latent, mask, ids_restore = backbone_out
            reconstruction = self.neck(latent, ids_restore)
            pixel_loss = self.head.loss(reconstruction, inputs, mask)
            return {'loss': pixel_loss}

        # Otherwise the mask comes with the data samples.
        per_sample_masks = [sample.mask for sample in data_samples]
        mask = torch.stack(per_sample_masks)

        img_latent = self.backbone(inputs[0], mask)

        # inputs[1] is the target image; no gradient flows into the target.
        with torch.no_grad():
            target = self.target_generator(inputs[1])[0].detach()

        # iTPN contains a neck module
        feats = self.neck(img_latent)
        # The first token of the target is dropped
        # (presumably the class token -- TODO confirm).
        token_loss = self.head.loss(feats, target[:, 1:, :], mask)
        return {'loss': token_loss}
