# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS


@MODELS.register_module()
class PixelReconstructionLoss(BaseModule):
    """Loss for the reconstruction of pixels in Masked Image Modeling.

    This module measures the distance between the target image and the
    reconstructed image and computes the loss used to optimize the model.
    Currently, only L1 and L2 losses are provided to penalize the
    reconstruction error. In addition, a mask can be passed to ``forward`` so
    that the loss is only applied where the mask is non-zero (e.g. on the
    masked patches, as in MAE).

    Args:
        criterion (str): The loss used to penalize the reconstruction error.
            Currently, only ``'L1'`` and ``'L2'`` are supported.
        channel (int, optional): The number of channels used to average the
            masked reconstruction loss. If not None, the masked loss is
            divided by ``channel``; it has no effect when ``forward`` is
            called without a mask. Defaults to None.

    Raises:
        NotImplementedError: If ``criterion`` is neither ``'L1'`` nor
            ``'L2'``.
    """

    def __init__(self, criterion: str, channel: Optional[int] = None) -> None:
        super().__init__()

        if criterion == 'L1':
            self.penalty = torch.nn.L1Loss(reduction='none')
        elif criterion == 'L2':
            self.penalty = torch.nn.MSELoss(reduction='none')
        else:
            # Implicit string concatenation keeps source indentation out of
            # the message (a backslash continuation inside the literal would
            # embed the next line's leading whitespace).
            raise NotImplementedError(
                'Currently, PixelReconstructionLoss only supports L1 and L2 '
                f'loss, but got {criterion}')

        # Without an explicit channel count the masked loss is divided by 1,
        # i.e. not rescaled at all.
        self.channel = channel if channel is not None else 1

    def forward(self,
                pred: torch.Tensor,
                target: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward function to compute the reconstruction loss.

        Args:
            pred (torch.Tensor): The reconstructed image.
            target (torch.Tensor): The target image, typically of the same
                shape as ``pred``.
            mask (torch.Tensor, optional): Weights selecting where the loss
                is applied. Must be broadcastable to the element-wise loss
                (after the last dim has been averaged away for 3-D inputs).
                Defaults to None, in which case the loss is averaged over
                all elements.

        Returns:
            torch.Tensor: The reconstruction loss as a scalar tensor.
        """
        loss = self.penalty(pred, target)

        # For 3-D losses, average over the last dim first, so the mask is
        # applied to the remaining two dims. Presumably (B, L, D) patch
        # tokens with a (B, L) mask -- TODO confirm against callers.
        if loss.dim() == 3:
            loss = loss.mean(dim=-1)

        if mask is None:
            return loss.mean()

        mask_sum = mask.sum()
        # An all-zero mask would otherwise give 0 / 0 = NaN; divide by one
        # instead so the loss is zero. torch.where keeps the value on-device
        # (no host sync) and preserves the dtype of the sum, so results for
        # non-empty masks are unchanged.
        mask_sum = torch.where(mask_sum == 0, torch.ones_like(mask_sum),
                               mask_sum)
        return (loss * mask).sum() / mask_sum / self.channel
