import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """Binary focal loss computed from raw logits, averaged over all elements.

    The per-element binary cross-entropy is scaled by ``(1 - pt) ** gamma``,
    where ``pt = exp(-BCE)``, so that confidently-correct predictions
    contribute less to the final value.

    Args:
        gamma: Exponent of the modulating factor ``(1 - pt)``.
        alpha: Optional class weighting. ``None`` disables weighting. A
            scalar ``a`` is expanded to the table ``[a, 1 - a]``. A
            two-element sequence or tensor is used as the table directly.
            The table is indexed by the integer label, so entry 0 weights
            label 0 and entry 1 weights label 1.

    Note:
        With a scalar ``alpha`` the *negative* label (0) is weighted by
        ``alpha`` and the positive label (1) by ``1 - alpha``. This is the
        reverse of torchvision's ``sigmoid_focal_loss``; it is kept as is so
        existing callers see unchanged results.
    """

    def __init__(self, gamma=2.0, alpha=0.25):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha

    def _alpha_table(self, device):
        """Return the two-entry class-weight table located on ``device``."""
        if isinstance(self.alpha, (float, int)):
            return torch.tensor([self.alpha, 1 - self.alpha], device=device)

        # Sequence / tensor alpha: previously this path left the table
        # undefined and crashed with UnboundLocalError.
        table = torch.as_tensor(self.alpha, device=device)
        if table.numel() != 2:
            raise ValueError(
                "alpha must be None, a scalar, or contain exactly 2 weights; "
                f"got {table.numel()} values"
            )
        return table.reshape(2)

    def forward(self, inputs, targets):
        """Compute the mean focal loss.

        Args:
            inputs: Raw (pre-sigmoid) logits.
            targets: Binary labels with the same shape as ``inputs``. Integer
                or bool labels are accepted and cast to the dtype of
                ``inputs``.

        Returns:
            A scalar tensor holding the mean focal loss.

        Raises:
            ValueError: If a non-scalar ``alpha`` does not hold exactly two
                weights.
        """
        # binary_cross_entropy_with_logits needs floating-point targets;
        # accept class-index style labels as well.
        if not targets.is_floating_point():
            targets = targets.to(inputs.dtype)

        bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-bce)  # exp of the negative per-element BCE
        focal = (1 - pt) ** self.gamma * bce

        if self.alpha is not None:
            # .long() truncates toward zero, so a fractional target in (0, 1)
            # selects entry 0 of the table.
            alpha_t = self._alpha_table(inputs.device)[targets.long()]
            focal = alpha_t * focal

        return focal.mean()



class KDLoss(nn.Module):
    """Blend a task loss with an MSE penalty between student and teacher logits.

    Args:
        dweight: Multiplier applied to the distillation (MSE) term.
        sweight: Multiplier applied to the supplied ``student_loss``.
    """

    def __init__(self, dweight=0.5, sweight=0.5):
        super().__init__()
        self.dweight = dweight
        self.sweight = sweight
        self.mse = nn.MSELoss()

    def forward(self, student_logits, teacher_logits, student_loss, is_teacher):
        """Return the loss to optimise for the current model.

        Args:
            student_logits: Logits of the model being trained.
            teacher_logits: Reference logits; used exactly as given (they are
                not detached here).
            student_loss: Task loss that has already been computed.
            is_teacher: When truthy, ``student_loss`` is returned untouched
                and no distillation term is computed.

        Returns:
            ``student_loss`` itself when ``is_teacher`` is truthy, otherwise
            ``dweight * MSE(student_logits, teacher_logits)
            + sweight * student_loss``.
        """
        # Guard clause: the teacher is trained on the plain task loss only.
        if is_teacher:
            return student_loss

        logit_gap = self.mse(student_logits, teacher_logits)
        return self.dweight * logit_gap + self.sweight * student_loss
