import torch
import torch.nn as nn
import torch.nn.functional as F

class BinaryFocalLoss(nn.Module):
    """
    Focal Loss for binary classification.
    Used for our router loss calculation.

    Computes ``alpha * (1 - pt) ** gamma * BCE`` element-wise on raw logits,
    where ``pt`` is the predicted probability of the true class.

    Note: ``alpha`` is applied uniformly to every element; it is NOT the
    class-balanced ``alpha_t`` (alpha for positives, 1 - alpha for negatives).

    Args:
        alpha: Scalar multiplier applied to every element's loss.
        gamma: Focusing exponent applied to ``(1 - pt)``.
        reduction: 'mean' or 'sum'; any other value returns the
            element-wise loss unreduced.
    """
    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        Compute the focal loss.

        Args:
            inputs: Raw logits (pre-sigmoid).
            targets: Binary labels with the same shape as ``inputs``.
                Integer or bool labels are cast to ``inputs.dtype``.

        Returns:
            A scalar tensor for 'mean'/'sum' reduction, otherwise a tensor
            with the same shape as ``inputs``.
        """
        # binary_cross_entropy_with_logits needs floating-point targets;
        # integer/bool labels would otherwise raise a dtype error.
        if not torch.is_floating_point(targets):
            targets = targets.to(inputs.dtype)

        bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        # BCE = -log(pt), hence pt = exp(-BCE): p if y=1 else 1-p
        pt = torch.exp(-bce)
        focal = self.alpha * (1 - pt) ** self.gamma * bce

        if self.reduction == 'mean':
            return torch.mean(focal)
        elif self.reduction == 'sum':
            return torch.sum(focal)
        else:
            return focal

class MultiClassFocalLoss(nn.Module):
    """
    Focal Loss for multi-class classification.
    Used for our final classification loss (loss_cls).

    Computes ``-alpha_t * (1 - pt) ** gamma * log(pt)`` per sample, where
    ``pt`` is the softmax probability assigned to the target class.

    Args:
        alpha: Optional per-class weights indexed by class id. May be a
            sequence of floats (e.g., [0.1, 0.2, 0.7]) or a 1-D tensor.
            ``None`` disables class weighting.
        gamma: Focusing exponent applied to ``(1 - pt)``.
        reduction: 'mean' or 'sum'; any other value returns the
            per-sample losses unreduced.
    """
    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        Compute the focal loss.

        Args:
            inputs: 2-D logits with classes along dim 1.
            targets: Integer class indices, one per row of ``inputs``
                (``torch.gather`` expects an int64 index tensor).

        Returns:
            A scalar tensor for 'mean'/'sum' reduction, otherwise a 1-D
            tensor of per-sample losses.
        """
        # reshape (not view) so non-contiguous target tensors are accepted
        index = targets.reshape(-1)

        log_probs = F.log_softmax(inputs, dim=1)
        # log-probability of the correct class for each sample
        log_pt = log_probs.gather(1, index.unsqueeze(1)).squeeze(1)
        pt = torch.exp(log_pt)

        # focal modulation down-weights well-classified samples
        loss = -1 * (1 - pt) ** self.gamma * log_pt

        if self.alpha is not None:
            alpha = self.alpha
            # accept plain Python sequences; a tensor alpha keeps its dtype
            if not torch.is_tensor(alpha):
                alpha = torch.as_tensor(alpha, dtype=loss.dtype)
            # gather needs alpha on the same device as the index tensor
            alpha_t = alpha.to(loss.device).gather(0, index)
            loss = alpha_t * loss

        if self.reduction == 'mean':
            return torch.mean(loss)
        elif self.reduction == 'sum':
            return torch.sum(loss)
        else:
            return loss