from typing import Dict, List
import torch
import torch.nn as nn
import torch.nn.functional as F


@torch.jit.script
def F1_Dice_loss(pred_masks, first_gt_mask):
    """Soft Dice loss between mask logits and a ground-truth mask (used as aux loss).

    Args:
        pred_masks (Tensor): raw mask logits; sigmoid is applied here.
            Documented as (bs, 1, h, w).
        first_gt_mask (Tensor): ground-truth mask, documented as (bs, 1, h, w).
            Only the per-sample element count has to match ``pred_masks``,
            since both are flattened from dim 1.

    Returns:
        Tensor: scalar, the batch mean of ``1 - dice`` where per sample
        ``dice = 2*sum(p*g) / ((sum(p*p) + 0.001) + (sum(g*g) + 0.001))``.
    """
    smooth = 0.001
    probs = torch.sigmoid(pred_masks).flatten(1)
    target = first_gt_mask.flatten(1)

    overlap = torch.sum(probs * target, dim=-1)
    # Squared-sum ("energy") form of the Dice denominator, each side smoothed
    # so an empty prediction and an empty mask never divide by zero.
    pred_energy = torch.sum(probs * probs, dim=-1) + smooth
    gt_energy = torch.sum(target * target, dim=-1) + smooth

    dice = (2 * overlap) / (pred_energy + gt_energy)
    return torch.mean(1 - dice)


@torch.jit.script
def F1_IoU_BCELoss(pred_masks, first_gt_mask):
    """Mean binary cross-entropy between mask logits and the first-frame mask.

    Despite the "IoU" in its name, this computes
    ``binary_cross_entropy_with_logits`` (mean reduction); no IoU term is
    involved.

    Args:
        pred_masks: predicted mask logits, documented as [bs, 1, 224, 224].
        first_gt_mask: ground-truth mask of the first frame, documented as
            [bs, 1, 224, 224].

    Returns:
        Scalar tensor with the mean BCE over all elements.
    """
    # Flatten both sides from dim 1 so that tensors with the same per-sample
    # element count (e.g. [bs, 1, h, w] vs [bs, h, w]) line up for the loss.
    flat_logits = torch.flatten(pred_masks, start_dim=1)
    flat_target = torch.flatten(first_gt_mask, start_dim=1)
    bce = F.binary_cross_entropy_with_logits(flat_logits, flat_target)
    return bce


class SegmentationLoss(nn.Module):
    """Weighted sum of a BCE-with-logits term and a soft-Dice term on masks.

    The BCE term is reported under the key ``'iou_loss'`` (it is computed by
    ``F1_IoU_BCELoss``, which is a binary cross-entropy, not an IoU) and is
    scaled by ``weight['iou_loss']``; the Dice term is scaled by
    ``weight['dice_loss']``.
    """

    def __init__(self, weight: Dict):
        """
        Args:
            weight: mapping that must provide the multipliers under the keys
                ``'iou_loss'`` and ``'dice_loss'``.
        """
        super().__init__()
        self.weight = weight

    def forward(self, pred_masks, gt_masks, aux: Dict):
        """Compute the weighted segmentation loss.

        :param pred_masks: mask logits; must be 4-D (documented as [B*T,1,224,224]).
        :param gt_masks: ground-truth masks, 4-D or 5-D (documented as
            [B,1,1,224,224]); for 5-D input dim 1 is squeezed, which only has
            an effect when that dim has size 1.
        :param aux: currently unused by this loss; documented as a dict with
            lists of B*T,C,224,224 tensors, including 'aux_pred'. Kept for
            interface compatibility.
        :return: ``(total_loss, loss_dict)``. ``total_loss`` is the
            differentiable sum of both weighted terms. ``loss_dict`` maps
            ``'iou_loss'`` and ``'dice_loss'`` to the weighted values as Python
            floats (for logging).
        """
        assert pred_masks.dim() == 4
        # Under torch.no_grad()/inference_mode() (e.g. validation) no tensor
        # requires grad, so an unconditional check would always fail there.
        # The detached-prediction sanity check only makes sense while autograd
        # is recording.
        if torch.is_grad_enabled():
            assert pred_masks.requires_grad, "Error when indexing predicted masks"
        if gt_masks.dim() == 5:
            gt_masks = gt_masks.squeeze(1)  # [bs, 1, 224, 224]
        # BCE-with-logits needs a floating-point target, so integer/bool masks
        # are cast here; .to() is a no-op when the dtype already matches.
        gt_masks = gt_masks.to(pred_masks.dtype)

        iou_loss = F1_IoU_BCELoss(pred_masks, gt_masks) * self.weight['iou_loss']
        dice_loss = F1_Dice_loss(pred_masks, gt_masks) * self.weight['dice_loss']
        total_loss = iou_loss + dice_loss
        loss_dict = {'iou_loss': iou_loss.item(), 'dice_loss': dice_loss.item()}

        return total_loss, loss_dict


if __name__ == "__main__":
    # Smoke test when run as a script (replaces a leftover breakpoint() that
    # dropped every run into the debugger): one forward/backward pass on
    # random data with a 5-D ground-truth tensor, as described in
    # SegmentationLoss.forward.
    criterion = SegmentationLoss({'iou_loss': 1.0, 'dice_loss': 1.0})
    demo_pred = torch.randn(2, 1, 224, 224, requires_grad=True)
    demo_gt = (torch.rand(2, 1, 1, 224, 224) > 0.5).float()
    demo_total, demo_parts = criterion(demo_pred, demo_gt, {})
    demo_total.backward()
    print(f"total={demo_total.item():.4f} parts={demo_parts}")
