# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS


@MODELS.register_module()
class SparKPretrainHead(BaseModule):
    """Pre-training head for SparK.

    Splits the predicted and target images into non-overlapping
    ``patch_size`` x ``patch_size`` patches. It optionally normalizes each
    target patch, then calls the configured loss with a mask that marks the
    patches which were *not* active.

    Args:
        loss (dict): Config of loss, built with ``MODELS.build``.
        norm_pix (bool): Whether or not to normalize each target patch by
            its own mean and variance. Defaults to True.
        patch_size (int): Patch size, equal to the downsample ratio of the
            backbone. Defaults to 32.
    """

    def __init__(self,
                 loss: dict,
                 norm_pix: bool = True,
                 patch_size: int = 32) -> None:
        super().__init__()
        self.norm_pix = norm_pix
        self.patch_size = patch_size
        self.loss = MODELS.build(loss)

    def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
        """Split images into non-overlapped patches.

        Args:
            imgs (torch.Tensor): A batch of images, of shape B x C x H x W.
                ``H`` and ``W`` must both be divisible by ``patch_size``.
        Returns:
            torch.Tensor: Patchified images of shape B x L x D, where
            ``L = (H // patch_size) * (W // patch_size)`` and
            ``D = patch_size**2 * C``. Inside each patch the values are
            laid out in (row, column, channel) order.
        """
        p = self.patch_size
        # Short-circuiting ``and`` avoids indexing shape[2]/shape[3] on
        # tensors with fewer than 4 dims.
        valid = (imgs.dim() == 4 and imgs.shape[2] % p == 0
                 and imgs.shape[3] % p == 0)
        assert valid, (
            f'Expected a 4-D tensor (B, C, H, W) with H and W divisible by '
            f'patch_size={p}, but got shape {tuple(imgs.shape)}')

        B, C, ori_h, ori_w = imgs.shape
        h = ori_h // p
        w = ori_w // p
        x = imgs.reshape(B, C, h, p, w, p)
        # Bring the patch-grid axes (h, w) to the front and the in-patch
        # axes (p, q) plus channels to the back.
        x = torch.einsum('bchpwq->bhwpqc', x)

        # (B, h*w, downsample_ratio*downsample_ratio*C)
        return x.reshape(B, h * w, p**2 * C)

    def construct_target(self, target: torch.Tensor) -> torch.Tensor:
        """Construct the reconstruction target.

        Besides splitting images into tokens, this method also normalizes
        every token by its own mean and variance when ``norm_pix`` is True.

        Args:
            target (torch.Tensor): Image with the shape of B x C x H x W
                (typically C = 3).
        Returns:
            torch.Tensor: Tokenized images with the shape of B x L x D.
        """
        target = self.patchify(target)
        if self.norm_pix:
            # Per-patch normalization over the flattened patch dimension;
            # the epsilon keeps constant patches from dividing by zero.
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

        return target

    def forward(self, pred: torch.Tensor, target: torch.Tensor,
                active_mask: torch.Tensor) -> torch.Tensor:
        """Forward function of SparK head.

        Args:
            pred (torch.Tensor): The reconstructed image, B x C x H x W.
            target (torch.Tensor): The target image, B x C x H x W.
            active_mask (torch.Tensor): Boolean mask of the active (visible)
                patches, presumably of shape B x 1 x f x f with
                ``f * f == L`` (the number of patches).
        Returns:
            torch.Tensor: The reconstruction loss returned by ``self.loss``.
        """
        # (B, C, H, W) -> (B, L, D), normalized per patch if ``norm_pix``
        target = self.construct_target(target)

        # (B, C, H, W) -> (B, L, D)
        pred = self.patchify(pred)

        # (B, 1, f, f) -> (B, L); 1 marks the non-active (masked) patches.
        # ``reshape`` instead of ``view`` so non-contiguous masks also work.
        non_active_mask = active_mask.logical_not().int().reshape(
            active_mask.shape[0], -1)

        # Reconstruction loss on the masked patches
        loss = self.loss(pred, target, non_active_mask)
        return loss
