# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from typing import Union

import torch
import torch.nn.functional as F
from mmengine.structures import InstanceData
from torch import Tensor

from mmseg.registry import TASK_UTILS


class BaseMatchCost:
    """Common parent of the match cost calculators.

    The base class only stores the scalar ``weight``; computing the actual
    cost matrix is left to ``__call__`` of each subclass.

    Note:
        ``__call__`` is decorated with ``abstractmethod`` but this class
        does not use ``abc.ABCMeta``, so the decorator is informational
        only and does not prevent instantiation.

    Args:
        weight (Union[float, int]): Cost weight. Defaults to 1.
    """

    def __init__(self, weight: Union[float, int] = 1.) -> None:
        # Kept as a plain attribute so subclasses can scale their result.
        self.weight = weight

    @abstractmethod
    def __call__(self, pred_instances: InstanceData,
                 gt_instances: InstanceData, **kwargs) -> Tensor:
        """Compute the match cost between predictions and ground truth.

        Subclasses are expected to override this method; the base
        implementation has no body and therefore returns ``None``.

        Args:
            pred_instances (InstanceData): Instances of model predictions.
                It often includes "labels" and "scores".
            gt_instances (InstanceData): Ground truth of instance
                annotations. It usually includes "labels".

        Returns:
            Tensor: Match Cost matrix of shape (num_preds, num_gts).
        """


@TASK_UTILS.register_module()
class ClassificationCost(BaseMatchCost):
    """Classification match cost based on softmax probabilities.

    The cost of pairing a prediction with a ground truth is the negated
    softmax probability that the prediction assigns to the ground truth's
    class, multiplied by ``weight``.

    Args:
        weight (Union[float, int]): Cost weight. Defaults to 1.

    Examples:
        >>> from mmengine.structures import InstanceData
        >>> from mmseg.models.assigners import ClassificationCost
        >>> import torch
        >>> self = ClassificationCost()
        >>> pred_instances = InstanceData(scores=torch.rand(4, 3))
        >>> gt_instances = InstanceData(labels=torch.tensor([0, 1, 2]))
        >>> self(pred_instances, gt_instances).shape
        torch.Size([4, 3])
    """

    def __init__(self, weight: Union[float, int] = 1) -> None:
        super().__init__(weight=weight)

    def __call__(self, pred_instances: InstanceData,
                 gt_instances: InstanceData, **kwargs) -> Tensor:
        """Compute match cost.

        Args:
            pred_instances (InstanceData): "scores" inside is
                predicted classification logits, of shape
                (num_queries, num_class).
            gt_instances (InstanceData): "labels" inside should have
                shape (num_gt, ).

        Returns:
            Tensor: Match Cost matrix of shape (num_preds, num_gts).
        """
        assert hasattr(pred_instances, 'scores'), \
            "pred_instances must contain 'scores'"
        assert hasattr(gt_instances, 'labels'), \
            "gt_instances must contain 'labels'"
        logits = pred_instances.scores
        labels = gt_instances.labels

        # Normalise the logits over the class dimension.
        probs = torch.softmax(logits, dim=-1)
        # One column per ground truth: the probability of that gt's class.
        # It is negated so that a more confident prediction costs less.
        return -probs[:, labels] * self.weight


@TASK_UTILS.register_module()
class DiceCost(BaseMatchCost):
    """Mask match cost derived from the dice loss.

    Args:
        pred_act (bool): Whether to apply sigmoid to the predicted masks
            before computing the cost. Defaults to False.
        eps (float): Smoothing term added to both the numerator and the
            denominator of the dice ratio. Defaults to 1e-3.
        naive_dice (bool): If True, use the naive dice loss
            in which the power of the number in the denominator is
            the first power. If False, use the second power that
            is adopted by K-Net and SOLO. Defaults to True.
        weight (Union[float, int]): Cost weight. Defaults to 1.
    """

    def __init__(self,
                 pred_act: bool = False,
                 eps: float = 1e-3,
                 naive_dice: bool = True,
                 weight: Union[float, int] = 1.) -> None:
        super().__init__(weight=weight)
        self.pred_act = pred_act
        self.eps = eps
        self.naive_dice = naive_dice

    def _binary_mask_dice_loss(self, mask_preds: Tensor,
                               gt_masks: Tensor) -> Tensor:
        """Compute the pairwise dice loss of predictions and targets.

        Args:
            mask_preds (Tensor): Mask prediction in shape (num_queries, *).
            gt_masks (Tensor): Ground truth in shape (num_gt, *)
                store 0 or 1, 0 for negative class and 1 for
                positive class.

        Returns:
            Tensor: Dice cost matrix in shape (num_queries, num_gt).
        """
        preds = mask_preds.flatten(1)
        targets = gt_masks.flatten(1).float()

        # Twice the pairwise inner product of every (prediction, gt) pair.
        overlap = 2 * torch.einsum('nc,mc->nm', preds, targets)

        # Per-mask "size" term: plain sum (naive) or sum of squares.
        if self.naive_dice:
            pred_size = preds.sum(-1)
            target_size = targets.sum(-1)
        else:
            pred_size = preds.pow(2).sum(1)
            target_size = targets.pow(2).sum(1)
        # Broadcast to a (num_queries, num_gt) grid of summed sizes.
        total = pred_size[:, None] + target_size[None, :]

        return 1 - (overlap + self.eps) / (total + self.eps)

    def __call__(self, pred_instances: InstanceData,
                 gt_instances: InstanceData, **kwargs) -> Tensor:
        """Compute match cost.

        Args:
            pred_instances (InstanceData): Predicted instances which
                must contain "masks".
            gt_instances (InstanceData): Ground truth which must contain
                "masks".

        Returns:
            Tensor: Match Cost matrix of shape (num_preds, num_gts).
        """
        assert hasattr(pred_instances, 'masks'), \
            "pred_instances must contain 'masks'"
        assert hasattr(gt_instances, 'masks'), \
            "gt_instances must contain 'masks'"
        preds, targets = pred_instances.masks, gt_instances.masks

        # Optionally squash the predictions with a sigmoid first.
        if self.pred_act:
            preds = torch.sigmoid(preds)
        return self._binary_mask_dice_loss(preds, targets) * self.weight


@TASK_UTILS.register_module()
class CrossEntropyLossCost(BaseMatchCost):
    """Mask match cost based on binary cross entropy.

    Args:
        use_sigmoid (bool): Whether the prediction uses sigmoid
            or softmax. Only the sigmoid variant is implemented;
            calling the cost with ``use_sigmoid=False`` raises
            ``NotImplementedError``. Defaults to True.
        weight (Union[float, int]): Cost weight. Defaults to 1.
    """

    def __init__(self,
                 use_sigmoid: bool = True,
                 weight: Union[float, int] = 1.) -> None:
        super().__init__(weight=weight)
        self.use_sigmoid = use_sigmoid

    def _binary_cross_entropy(self, cls_pred: Tensor,
                              gt_labels: Tensor) -> Tensor:
        """Compute the pairwise binary cross entropy cost.

        Args:
            cls_pred (Tensor): The prediction with shape (num_queries, 1, *) or
                (num_queries, *).
            gt_labels (Tensor): The learning label of prediction with
                shape (num_gt, *).

        Returns:
            Tensor: Cross entropy cost matrix in shape (num_queries, num_gt).
        """
        logits = cls_pred.flatten(1).float()
        targets = gt_labels.flatten(1).float()
        num_elements = logits.shape[1]

        # Element-wise loss of each logit against an all-ones and an
        # all-zeros target; the gt masks then pick which one applies.
        loss_if_pos = F.binary_cross_entropy_with_logits(
            logits, torch.ones_like(logits), reduction='none')
        loss_if_neg = F.binary_cross_entropy_with_logits(
            logits, torch.zeros_like(logits), reduction='none')

        pos_term = torch.einsum('nc,mc->nm', loss_if_pos, targets)
        neg_term = torch.einsum('nc,mc->nm', loss_if_neg, 1 - targets)
        # Average over the flattened dimension.
        return (pos_term + neg_term) / num_elements

    def __call__(self, pred_instances: InstanceData,
                 gt_instances: InstanceData, **kwargs) -> Tensor:
        """Compute match cost.

        Args:
            pred_instances (:obj:`InstanceData`): Predicted instances which
                must contain ``masks``.
            gt_instances (:obj:`InstanceData`): Ground truth which must contain
                ``masks``.

        Returns:
            Tensor: Match Cost matrix of shape (num_preds, num_gts).

        Raises:
            NotImplementedError: If ``use_sigmoid`` is False.
        """
        assert hasattr(pred_instances, 'masks'), \
            "pred_instances must contain 'masks'"
        assert hasattr(gt_instances, 'masks'), \
            "gt_instances must contain 'masks'"
        preds = pred_instances.masks
        targets = gt_instances.masks

        # Only the sigmoid (binary) formulation exists so far.
        if not self.use_sigmoid:
            raise NotImplementedError

        return self._binary_cross_entropy(preds, targets) * self.weight
