import torch
from torch import nn

# Small numerical-stability constant shared by the losses in this module
# (e.g. added to both sides of a ratio so its denominator is never zero).
_EPS = 1e-7


class FocalLoss(nn.Module):
    """Sigmoid focal loss for binary segmentation logits.

    Element-wise, with ``p = sigmoid(pred)``::

        -alpha * mask * (1 - p)**gamma * log(p)
        - (1 - alpha) * (1 - mask) * p**gamma * log(1 - p)

    averaged over every loss element.

    Args:
        gamma: Focusing exponent; larger values down-weight easy examples.
        alpha: Weight of the positive term; the negative term gets ``1 - alpha``.
    """

    def __init__(self, gamma: float = 2.0, alpha: float = 0.25):
        super().__init__()
        self.gamma, self.alpha = gamma, alpha

    def forward(self, pred, mask):
        """Return the mean focal loss of logits ``pred`` against ``mask``.

        Args:
            pred: Raw (pre-sigmoid) logits.
            mask: Target, same shape as ``pred`` or broadcast-compatible with it;
                presumably 0/1 values — TODO confirm with callers.

        Returns:
            Scalar tensor: summed loss divided by the number of loss elements
            (0 for empty input).
        """
        p = torch.sigmoid(pred)
        # Compute log(p) and log(1 - p) in log-space. log(p + eps) saturates at
        # log(eps) and its gradient vanishes for confidently wrong logits;
        # logsigmoid stays exact and keeps a ~unit gradient there.
        log_p = nn.functional.logsigmoid(pred)
        log_not_p = nn.functional.logsigmoid(-pred)
        w_pos = (1 - p) ** self.gamma
        w_neg = p ** self.gamma
        pos_term = -self.alpha * mask * w_pos * log_p
        neg_term = -(1 - self.alpha) * (1 - mask) * w_neg * log_not_p
        loss = pos_term + neg_term
        # Normalize by the loss's own element count (not mask.numel()) so a
        # broadcast mask is averaged correctly; max(..., 1) makes empty input 0.
        return loss.sum() / max(loss.numel(), 1)


class DiceLoss(nn.Module):
    """Soft Dice loss computed on sigmoid probabilities.

    Returns ``1 - (2 * sum(P * M) + smooth) / (sum(P) + sum(M) + smooth)``
    with ``P = sigmoid(pred)`` and ``M = mask``, where every sum runs over
    all elements of the inputs.

    Args:
        smooth: Additive smoothing on numerator and denominator; keeps the
            ratio finite when the sums are zero (e.g. empty input).
    """

    def __init__(self, smooth: float = 1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, mask):
        probs = pred.sigmoid()
        overlap = torch.sum(probs * mask)
        total = torch.sum(probs) + torch.sum(mask)
        dice = (2 * overlap + self.smooth) / (total + self.smooth)
        return 1 - dice


class MaskIoULoss(nn.Module):
    """Mean squared error between a soft mask IoU and a predicted IoU score.

    The soft IoU uses ``sigmoid(pred_mask)`` against ``gt_mask`` with sums taken
    over *all* elements, so the whole input (every sample in a batch) is pooled
    into a single IoU value, which is then broadcast against ``pred_iou``.
    """

    def forward(self, pred_mask, gt_mask, pred_iou):
        probs = pred_mask.sigmoid()
        overlap = torch.sum(probs * gt_mask)
        union = torch.sum(probs) + torch.sum(gt_mask) - overlap
        # _EPS on both sides keeps the ratio defined when union is zero.
        soft_iou = (overlap + _EPS) / (union + _EPS)
        # NOTE(review): soft_iou is not detached, so this loss also sends
        # gradients into pred_mask, not only into pred_iou — confirm intended.
        return (soft_iou - pred_iou).pow(2).mean()


class FocalDiceloss_IoULoss(nn.Module):
    """Combined loss: weighted focal + Dice + scaled IoU-prediction regression.

    ``total = weight * focal(pred, mask) + dice(pred, mask)
    + iou_scale * maskiou(pred, mask, pred_iou)``

    Args:
        weight: Multiplier on the focal term (the Dice term is unweighted).
        iou_scale: Multiplier on the IoU-prediction loss.
    """

    def __init__(self, weight: float = 20.0, iou_scale: float = 1.0):
        super().__init__()
        self.weight = weight
        self.iou_scale = iou_scale
        # Sub-losses are built with their own default hyper-parameters.
        self.focal = FocalLoss()
        self.dice = DiceLoss()
        self.maskiou = MaskIoULoss()

    def forward(self, pred, mask, pred_iou):
        focal_term = self.weight * self.focal(pred, mask)
        seg_loss = focal_term + self.dice(pred, mask)
        iou_loss = self.maskiou(pred, mask, pred_iou)
        return seg_loss + self.iou_scale * iou_loss
