# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------------------------------
# Modified from:
# https://github.com/facebookresearch/detr/blob/main/models/segmentation.py
# https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/focal_loss.py
# ------------------------------------------------------------------------------------------------

import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import weight_reduce_loss


def sigmoid_focal_loss(
    preds,
    targets,
    weight=None,
    alpha: float = 0.25,
    gamma: float = 2,
    reduction: str = "mean",
    avg_factor: int = None,
):
    """
    Compute the focal loss from raw (un-activated) logits.

    Focal loss was introduced with RetinaNet for dense detection:
    https://arxiv.org/abs/1708.02002. The sigmoid is applied to ``preds``
    inside this function.

    Args:
        preds (torch.Tensor): Logits of arbitrary shape. Cast to float32
            before any computation.
        targets (torch.Tensor): Binary labels with the same shape as ``preds``
            (0 for the negative class, 1 for the positive class). Cast to
            float32 before any computation.
        weight (torch.Tensor, optional): Loss weight handed to
            ``weight_reduce_loss``. When given, it must have the same number
            of dimensions as the element-wise loss. Default: None.
        alpha (float, optional): Balancing factor between positive and
            negative examples. A negative value disables the balancing term.
            Default: 0.25.
        gamma (float): Exponent of the modulating factor ``(1 - p_t)`` that
            down-weights easy examples. Default: 2.
        reduction (str): Reduction mode forwarded to ``weight_reduce_loss``;
            expected to be one of 'none' | 'mean' | 'sum'. Default: 'mean'.
        avg_factor (int): Average factor forwarded to ``weight_reduce_loss``.
            Default: None.

    Returns:
        torch.Tensor: The value returned by ``weight_reduce_loss`` for the
        element-wise focal loss.
    """
    logits = preds.float()
    labels = targets.float()
    neg_labels = 1 - labels

    prob = torch.sigmoid(logits)
    bce = F.binary_cross_entropy_with_logits(logits, labels, reduction="none")

    # Probability the model assigns to the correct side of each binary decision.
    prob_true = prob * labels + (1 - prob) * neg_labels
    # Confidently-correct elements (prob_true near 1) get scaled towards zero.
    focal = bce * torch.pow(1 - prob_true, gamma)

    if alpha >= 0:
        # alpha for positives, (1 - alpha) for negatives.
        focal = (alpha * labels + (1 - alpha) * neg_labels) * focal

    assert weight is None or weight.ndim == focal.ndim

    return weight_reduce_loss(focal, weight, reduction, avg_factor)


def focal_loss_with_prob(
    preds,
    targets,
    weight=None,
    alpha=0.25,
    gamma=2.0,
    reduction="mean",
    avg_factor=None,
):
    """Compute `Focal Loss <https://arxiv.org/abs/1708.02002>`_ from probabilities.

    Unlike `sigmoid_focal_loss`, no sigmoid is applied here: ``preds`` must
    already be probabilities.

    Args:
        preds (torch.Tensor): Predicted probabilities with shape (N, C),
            where C is the number of classes.
        targets (torch.Tensor): Integer class indices for each prediction.
            They are one-hot encoded over ``C + 1`` classes and the last
            column is dropped, so index ``C`` produces an all-zero row.
        weight (torch.Tensor, optional): Loss weight handed to
            ``weight_reduce_loss``. When given, it must have the same number
            of dimensions as the element-wise loss.
        alpha (float, optional): Balancing factor between positive and
            negative entries. A negative value disables the balancing term.
            Defaults to 0.25.
        gamma (float, optional): Exponent of the modulating factor.
            Defaults to 2.0.
        reduction (str, optional): Reduction mode forwarded to
            ``weight_reduce_loss``. Defaults to 'mean'.
        avg_factor (int, optional): Average factor forwarded to
            ``weight_reduce_loss``. Defaults to None.

    Returns:
        torch.Tensor: The value returned by ``weight_reduce_loss`` for the
        element-wise focal loss.
    """
    n_cls = preds.shape[1]
    # Encode with one spare class, then drop it: label ``n_cls`` becomes all zeros.
    one_hot = F.one_hot(targets, num_classes=n_cls + 1)[:, :n_cls].type_as(preds)
    inv_one_hot = 1 - one_hot

    # Probability the model assigns to the correct side of each binary decision.
    prob_true = preds * one_hot + (1 - preds) * inv_one_hot
    bce = F.binary_cross_entropy(preds, one_hot, reduction="none")
    focal = bce * torch.pow(1 - prob_true, gamma)

    if alpha >= 0:
        # alpha for positives, (1 - alpha) for negatives.
        focal = (alpha * one_hot + (1 - alpha) * inv_one_hot) * focal

    assert weight is None or weight.ndim == focal.ndim

    return weight_reduce_loss(focal, weight, reduction, avg_factor)


class FocalLoss(nn.Module):
    """Module wrapper around `Focal Loss <https://arxiv.org/abs/1708.02002>`_.

    Args:
        alpha (float, optional): Balancing factor between positive and
            negative entries. Defaults to 0.25.
        gamma (float, optional): Exponent of the modulating factor.
            Defaults to 2.0.
        reduction (str, optional): How the loss is reduced. Defaults to
            'mean'. Options are "none", "mean" and "sum".
        loss_weight (float, optional): Scalar multiplied onto the final
            loss. Defaults to 1.0.
        activated (bool, optional): If True, ``forward`` dispatches to
            ``focal_loss_with_prob`` (predictions are probabilities).
            If False, it dispatches to ``sigmoid_focal_loss`` (predictions
            are logits). Defaults to False.
    """

    def __init__(
        self,
        alpha=0.25,
        gamma=2.0,
        reduction="mean",
        loss_weight=1.0,
        activated=False,
    ):
        super().__init__()
        # Plain hyper-parameters; none of them are registered as tensors.
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.activated = activated

    def forward(
        self,
        preds,
        targets,
        weight=None,
        avg_factor=None,
    ):
        """Compute the weighted focal loss.

        Args:
            preds (torch.Tensor): Predictions with shape ``(N, C)``, where C
                is the number of classes. Probabilities when
                ``self.activated`` is True, logits otherwise.
            targets (torch.Tensor): Integer class indices for each
                prediction.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.

        Returns:
            torch.Tensor: The focal loss multiplied by ``self.loss_weight``.
        """
        shared_kwargs = dict(
            alpha=self.alpha,
            gamma=self.gamma,
            reduction=self.reduction,
            avg_factor=avg_factor,
        )

        if self.activated:
            # focal_loss_with_prob one-hot encodes the labels itself.
            raw_loss = focal_loss_with_prob(preds, targets, weight, **shared_kwargs)
        else:
            # sigmoid_focal_loss expects element-wise binary targets, so encode
            # here; the spare last column is dropped so label C becomes all zeros.
            n_cls = preds.size(1)
            one_hot = F.one_hot(targets, num_classes=n_cls + 1)[:, :n_cls]
            raw_loss = sigmoid_focal_loss(preds, one_hot, weight, **shared_kwargs)

        return self.loss_weight * raw_loss
