from typing import List,Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
import pdb

@torch.jit.script
def F5_IoU_BCELoss(pred_mask, five_gt_masks):
    """Binary cross-entropy (with logits) between predicted and ground-truth masks.

    Despite the name, this is a plain BCE loss, not an IoU loss. Both inputs are
    flattened from dim 1 and passed to ``F.binary_cross_entropy_with_logits``
    with its default ``'mean'`` reduction. The result is therefore a scalar
    averaged over every element.

    Args:
        pred_mask: raw (pre-sigmoid) mask logits; must be a 4-D tensor,
            documented as [bs*5, 1, 224, 224].
        five_gt_masks: target masks; after flattening from dim 1 they must
            have the same shape as the flattened ``pred_mask``.

    Returns:
        Scalar tensor with the mean BCE loss.
    """
    assert pred_mask.dim() == 4
    logits = torch.flatten(pred_mask, start_dim=1)
    targets = torch.flatten(five_gt_masks, start_dim=1)
    return F.binary_cross_entropy_with_logits(logits, targets)

@torch.jit.script
def F5_Dice_Loss(pred_mask, five_gt_masks):
    """Soft Dice loss between sigmoid-activated predictions and target masks.

    Both tensors are flattened from dim 1, so each row is one sample. For each
    row, with ``p = sigmoid(pred_mask)`` and ``g = five_gt_masks``, it computes

        dice = 2 * sum(p * g) / ((sum(p * p) + 0.001) + (sum(g * g) + 0.001))

    The 0.001 terms guard against a zero denominator. The loss is
    ``1 - dice``, averaged over all rows.

    Args:
        pred_mask: raw (pre-sigmoid) mask logits.
        five_gt_masks: target masks, flattened from dim 1 to match
            ``pred_mask`` element-wise.

    Returns:
        Scalar tensor with the mean Dice loss.
    """
    probs = torch.sigmoid(pred_mask).flatten(1)
    target = five_gt_masks.flatten(1)
    intersection = (probs * target).sum(dim=-1)
    pred_energy = (probs * probs).sum(dim=-1) + 0.001
    gt_energy = (target * target).sum(dim=-1) + 0.001
    dice = (2 * intersection) / (pred_energy + gt_energy)
    return torch.mean(1 - dice)

class SegmentationLoss(nn.Module):
    """Weighted combination of the BCE mask loss and the soft Dice mask loss.

    Args:
        weight: multipliers for the two terms, read under the keys
            ``'iou_loss'`` (BCE term) and ``'dice_loss'`` (Dice term).
    """

    def __init__(self, weight: Dict):
        super().__init__()
        # Stored as a plain attribute (not a buffer/parameter); read on every forward.
        self.weight = weight

    def forward(self, pred_masks, gt_masks, aux: Dict):
        """Compute the weighted total loss and a per-term breakdown.

        :param pred_masks: mask logits, documented as [B*T, 1, 224, 224]
        :param gt_masks: ground-truth masks, documented as [B, 1, 1, 224, 224].
            NOTE(review): both inputs are flattened from dim 1 and the BCE term
            needs identical shapes afterwards. With the documented shapes that
            only holds when T == 1 — confirm against callers.
        :param aux: auxiliary outputs (documented to include 'aux_pred');
            not used by this loss.
        :return: ``(total_loss, loss_dict)``. ``total_loss`` is the tensor sum of
            the two weighted terms. ``loss_dict`` maps ``'iou_loss'`` and
            ``'dice_loss'`` to those weighted terms as Python numbers.
        """
        weights = self.weight
        weighted_bce = F5_IoU_BCELoss(pred_masks, gt_masks) * weights['iou_loss']
        weighted_dice = F5_Dice_Loss(pred_masks, gt_masks) * weights['dice_loss']
        combined = weighted_bce + weighted_dice
        # .item() turns each scalar tensor into a Python number for logging.
        breakdown = {'iou_loss': weighted_bce.item(), 'dice_loss': weighted_dice.item()}
        return combined, breakdown


if __name__ == "__main__":
    pdb.set_trace()
