# Copyright (c) OpenMMLab. All rights reserved.
# migrate from mmdetection with modifications
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..registry import LOSSES
from .utils import weight_reduce_loss, convert_to_one_hot


def seesaw_ce_loss(cls_score,
                   labels,
                   weight=None,
                   cum_samples=256,
                   p=0.8,
                   q=2.0,
                   eps=1e-2,
                   reduction='mean',
                   avg_factor=None):
    """Calculate the Seesaw CrossEntropy loss.

    Args:
        cls_score (torch.Tensor): The prediction with shape (N, C),
             C is the number of classes.
        labels (torch.Tensor): The learning label of the prediction. Either
             class indices with N elements, or a one-hot tensor with the
             same shape as ``cls_score`` (reduced to indices with argmax).
        weight (torch.Tensor, optional): Sample-wise loss weight.
        cum_samples (torch.Tensor | int | float): Cumulative samples for
             each category, with C elements. A scalar is treated as the
             same count for every category, which makes the mitigation
             factor equal to 1 everywhere.
        p (float): The ``p`` in the mitigation factor.
        q (float): The ``q`` in the compensation factor.
        eps (float): The minimal value of divisor to smooth
             the computation of compensation factor
        reduction (str, optional): The method used to reduce the loss.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.

    Returns:
        torch.Tensor: The calculated loss
    """
    num_classes = cls_score.size(-1)

    # A scalar count (e.g. the signature default) cannot be indexed per
    # class, so expand it to one identical count per category.
    if not torch.is_tensor(cum_samples) or cum_samples.dim() == 0:
        cum_samples = cls_score.new_full((num_classes, ), float(cum_samples))
    assert len(cum_samples) == num_classes, \
        f'`cum_samples` has {len(cum_samples)} entries but `cls_score` ' \
        f'has {num_classes} classes.'

    # The factors below need class indices (to pick rows and the target
    # score) as well as a one-hot mask (to leave the target logit untouched),
    # so normalise whichever label format was given into both.
    if labels.shape == cls_score.shape:
        label_inds = labels.argmax(dim=-1)
    else:
        label_inds = labels.reshape(-1).long()
    assert label_inds.numel() == cls_score.size(0), \
        f'Expected {cls_score.size(0)} labels, got {label_inds.numel()}.'
    onehot_labels = F.one_hot(label_inds, num_classes).to(cls_score.dtype)

    seesaw_weights = cls_score.new_ones(onehot_labels.size())

    # mitigation factor
    if p > 0:
        # sample_ratio_matrix[i, j] = N_j / N_i (counts clamped to >= 1)
        sample_ratio_matrix = cum_samples[None, :].clamp(
            min=1) / cum_samples[:, None].clamp(min=1)
        index = (sample_ratio_matrix < 1.0).float()
        # M_{ij}: ratio ** p where class j is rarer than class i, else 1
        sample_weights = sample_ratio_matrix.pow(p) * index + (1 - index)
        mitigation_factor = sample_weights[label_inds, :]
        seesaw_weights = seesaw_weights * mitigation_factor

    # compensation factor
    if q > 0:
        # detached: the factor only rescales logits, no gradient through it
        scores = F.softmax(cls_score.detach(), dim=1)
        self_scores = scores[
            torch.arange(len(scores), device=scores.device), label_inds]
        score_matrix = scores / self_scores[:, None].clamp(min=eps)
        index = (score_matrix > 1.0).float()
        # (score ratio) ** q where another class outscores the target, else 1
        compensation_factor = score_matrix.pow(q) * index + (1 - index)
        seesaw_weights = seesaw_weights * compensation_factor

    # Shift only the non-target logits; the one-hot mask zeroes the shift
    # on the ground-truth class.
    cls_score = cls_score + (seesaw_weights.log() * (1 - onehot_labels))

    loss = F.cross_entropy(
        cls_score, label_inds, weight=None, reduction='none')

    if weight is not None:
        weight = weight.float()
    loss = weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
    return loss


@LOSSES.register_module()
class SeesawLoss(nn.Module):
    """Implementation of seesaw loss.

    Seesaw Loss for Long-Tailed Instance Segmentation. In CVPR, 2021.
    <https://arxiv.org/abs/2008.10032>

    The module keeps a running per-class sample count in the ``cum_samples``
    buffer, which is updated on every call to :meth:`forward`.

    Args:
        use_sigmoid (bool): Whether the prediction uses sigmoid or softmax.
             Only False is supported. Defaults to False.
        p (float): The ``p`` in the mitigation factor.
             Defaults to 0.8.
        q (float): The ``q`` in the compensation factor.
             Defaults to 2.0.
        num_classes (int): The number of classes.
             Default to 1000 for the ImageNet dataset.
        eps (float): The minimal value of divisor to smooth
             the computation of compensation factor, default to 1e-2.
        reduction (str): The method used to reduce the loss into
            a scalar.
        loss_weight (float): Weight of loss. Defaults to 1.0.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 use_sigmoid=False,
                 p=0.8,
                 q=2.0,
                 num_classes=1000,
                 eps=1e-2,
                 reduction='mean',
                 loss_weight=1.0,
                 **kwargs):
        super(SeesawLoss, self).__init__()
        assert not use_sigmoid, '`use_sigmoid` is not supported'
        self.use_sigmoid = False
        self.p = p
        self.q = q
        self.num_classes = num_classes
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight

        self.criterion = seesaw_ce_loss

        # only the softmax formulation is supported (see the assert above)
        self.post_process = 'softmax'

        # cumulative samples for each category
        self.register_buffer('cum_samples',
                             torch.zeros(self.num_classes, dtype=torch.float))

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction with shape (N, C).
            target (torch.Tensor): The class index of each prediction,
                 N elements in total.
            weight (torch.Tensor, optional): Sample-wise loss weight.
            avg_factor (int, optional): Average factor that is used to average
                 the loss. Defaults to None.
            reduction_override (str, optional): The method used to reduce
                 the loss, replacing the one set at construction.
                 Options are "none", "mean" and "sum".
        Returns:
            torch.Tensor: The calculated loss
        """
        assert reduction_override in (None, 'none', 'mean', 'sum'), \
            'The `reduction_override` should be one of (None, "none", ' \
            f'"mean", "sum"), but got "{reduction_override}".'
        reduction = (
            reduction_override if reduction_override else self.reduction)

        # Validate before touching `cum_samples`, so a rejected call does
        # not leave the running counts modified.
        assert self.num_classes == pred.size(-1), \
            f'Expected `pred` to have {self.num_classes} classes, ' \
            f'got {pred.size(-1)}.'
        label_inds = target.reshape(-1).long()
        assert pred.size(0) == label_inds.size(0), \
            f'Expected {pred.size(0)} labels, got {label_inds.size(0)}.'

        # accumulate the samples for each category
        # NOTE: this runs on every call, regardless of `self.training`.
        seen_classes, counts = label_inds.unique(return_counts=True)
        self.cum_samples[seen_classes] += counts.to(self.cum_samples.dtype)

        if weight is not None:
            weight = weight.float()
        else:
            # one weight per sample, matching the unreduced loss
            weight = label_inds.new_ones(label_inds.size(), dtype=torch.float)

        loss_cls = self.loss_weight * self.criterion(
            pred,
            label_inds,
            weight,
            self.cum_samples,
            self.p,
            self.q,
            self.eps,
            reduction=reduction,
            avg_factor=avg_factor)
        return loss_cls
