# Copyright (c) OpenMMLab. All rights reserved.

from typing import Optional

import torch
from mmengine.model import BaseModule
from torch import nn

from mmpretrain.registry import MODELS


@MODELS.register_module()
class CosineSimilarityLoss(BaseModule):
    """Cosine similarity loss function.

    Compute the cosine similarity between two features along their last
    dimension and turn it into a loss::

        loss = shift_factor - scale_factor * cos(pred, target)

    With a positive ``scale_factor``, minimizing the loss maximizes the
    similarity between ``pred`` and ``target``.

    Args:
        shift_factor (float): The shift factor of cosine similarity.
            Default: 0.0.
        scale_factor (float): The scale factor of cosine similarity.
            Default: 1.0.
    """

    def __init__(self,
                 shift_factor: float = 0.0,
                 scale_factor: float = 1.0) -> None:
        super().__init__()
        self.shift_factor = shift_factor
        self.scale_factor = scale_factor

    def forward(self,
                pred: torch.Tensor,
                target: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward function of cosine similarity loss.

        Args:
            pred (torch.Tensor): The predicted features. The similarity is
                computed over the last dimension.
            target (torch.Tensor): The target features. Must be
                broadcastable with ``pred``.
            mask (torch.Tensor, optional): Weights multiplied element-wise
                with the per-position loss (the loss after the last
                dimension has been reduced). When given, the result is the
                weighted sum divided by ``mask.sum()`` instead of a plain
                mean. Presumably a 0/1 mask with the same shape as
                ``pred.shape[:-1]`` - confirm against the caller.
                Defaults to None.

        Returns:
            torch.Tensor: The cosine similarity loss, reduced to a scalar.
            If ``mask`` sums to zero, the weighted sum is returned undivided
            (i.e. zero for an all-zero mask) rather than NaN.
        """
        # Normalizing both inputs over the last dimension makes their dot
        # product equal to the cosine similarity.
        pred_norm = nn.functional.normalize(pred, dim=-1)
        target_norm = nn.functional.normalize(target, dim=-1)
        loss = self.shift_factor - self.scale_factor * (
            pred_norm * target_norm).sum(dim=-1)

        if mask is None:
            return loss.mean()

        # A mask that sums to zero would make this a 0 / 0 division and
        # poison the loss (and every gradient) with NaN. Swap the divisor
        # for 1 in that case only; any non-zero sum is used unchanged, so
        # regular inputs produce exactly the same value as before. This is
        # done with a tensor op so no Python-level branch on the tensor
        # value is needed.
        mask_sum = mask.sum()
        divisor = torch.where(mask_sum == 0, torch.ones_like(mask_sum),
                              mask_sum)
        return (loss * mask).sum() / divisor
