from typing import Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
import pdb

@torch.jit.script
def F10_IoU_BCELoss(pred_mask: torch.Tensor, ten_gt_masks: torch.Tensor,
                    gt_temporal_mask_flag: torch.Tensor) -> torch.Tensor:
    """
    Frame-masked cross entropy loss over all ten frames for multiple sound
    source segmentation.

    Despite the name, the loss computed here is a multi-class cross entropy
    (``F.cross_entropy`` on the class logits), not a binary cross entropy and
    not an IoU.

    Args:
    pred_mask: predicted class logits for a batch of data, must be 4-D,
        documented shape: [bs*10, N_CLASSES, 224, 224]
    ten_gt_masks: ground truth class-index mask of the ten frames,
        shape [bs*10, 224, 224] or [bs*10, 1, 224, 224]; cast to long here
    gt_temporal_mask_flag: per-frame weight, shape [bs*10]; frames whose flag
        is 0 do not contribute to the loss

    Returns:
    scalar tensor: sum of the flag-weighted per-frame losses divided by the
    sum of the flags. When the flags sum to zero the result is 0 rather
    than NaN.
    """
    assert pred_mask.dim() == 4
    # Only drop the channel axis of a 4-D target; a 3-D target whose height
    # happens to be 1 must not lose its spatial axis.
    if ten_gt_masks.dim() == 4 and ten_gt_masks.shape[1] == 1:
        ten_gt_masks = ten_gt_masks.squeeze(1)  # [bs*10, 224, 224]
    per_pixel = F.cross_entropy(pred_mask, ten_gt_masks.long(), reduction='none')  # [bs*10, 224, 224]
    per_frame = per_pixel.mean([-2, -1])  # [bs*10]
    weighted_sum = torch.sum(per_frame * gt_temporal_mask_flag)
    flag_sum = torch.sum(gt_temporal_mask_flag).to(weighted_sum.dtype)
    # With no flagged frame the numerator is 0; clamping the denominator turns
    # the former 0/0 = NaN into a zero loss that cannot poison the optimizer.
    return weighted_sum / torch.clamp(flag_sum, min=1e-6)
@torch.jit.script
def F10_Dice_Loss(mask_feature: torch.Tensor, ten_gt_masks: torch.Tensor,
                  gt_temporal_mask_flag: torch.Tensor) -> torch.Tensor:
    """
    dice loss for total ten frames for multiple sound source segmentation

    The feature map is averaged over its channel axis, bilinearly resized to
    the ground-truth resolution and passed through a sigmoid; the ground truth
    is binarised (every value > 0 becomes 1, other values are kept as they are).

    :param mask_feature: [bs*10, C, h, w] feature map (any spatial size)
    :param ten_gt_masks: [bs*10,1, 224, 224]
    :param gt_temporal_mask_flag: [bs*10], frames with flag 0 are ignored
    :return: scalar tensor, flag-weighted sum of the per-frame dice losses
        divided by the sum of the flags; 0 (not NaN) when the flags sum to zero
    """
    prob = torch.mean(mask_feature, dim=1, keepdim=True)
    prob = F.interpolate(
        prob, ten_gt_masks.shape[-2:], mode='bilinear', align_corners=False)
    prob = torch.sigmoid(prob).flatten(1)

    # Collapse every positive label to 1 so all foreground classes count alike.
    target = torch.where(ten_gt_masks > 0, torch.ones_like(ten_gt_masks), ten_gt_masks)
    target = target.flatten(1)

    intersection = (prob * target).sum(-1)
    # The 0.001 terms keep the ratio defined when both masks are empty.
    prob_energy = (prob * prob).sum(-1) + 0.001
    target_energy = (target * target).sum(-1) + 0.001
    per_frame = 1 - (2 * intersection) / (prob_energy + target_energy)

    weighted_sum = torch.sum(per_frame * gt_temporal_mask_flag)
    flag_sum = torch.sum(gt_temporal_mask_flag).to(weighted_sum.dtype)
    # With no flagged frame the numerator is 0; clamping the denominator turns
    # the former 0/0 = NaN into a zero loss.
    return weighted_sum / torch.clamp(flag_sum, min=1e-6)


class SegmentationLoss(nn.Module):
    """Weighted sum of the frame-masked cross entropy loss and the dice loss.

    ``weight`` is a mapping that must provide the multipliers under the keys
    ``'iou_loss'`` and ``'aux_loss'``.
    """

    def __init__(self, weight: Dict):
        super().__init__()
        self.weight = weight

    def forward(self, pred_masks, gt_masks, gt_temporal_mask_flag, aux: Dict):
        """
        :param pred_masks: 4-D prediction tensor, documented as [B*T,C,224,224]
        :param gt_masks: ground truth masks; a 5-D input has its dim 1 squeezed
            (documented as [B,1,1,224,224] -> [B,1,224,224])
        :param gt_temporal_mask_flag: per-frame flag, frames with flag 0 are
            excluded from both losses
        :param aux: dict that must contain 'mask_feature', the feature map fed
            to the dice loss
        :return: (total_loss, loss_dict) where total_loss is the differentiable
            weighted sum and loss_dict holds the two weighted terms as Python
            floats under 'iou_loss' and 'aux_loss'
        """
        assert len(pred_masks.shape) == 4
        # Sanity check for training only: under torch.no_grad() (validation)
        # predictions legitimately carry no gradient, so the check is skipped.
        if torch.is_grad_enabled():
            assert pred_masks.requires_grad, "Error when indexing predicted masks"
        if len(gt_masks.shape) == 5:
            gt_masks = gt_masks.squeeze(1)  # [bs, 1, 224, 224]

        iou_loss = F10_IoU_BCELoss(
            pred_masks, gt_masks, gt_temporal_mask_flag) * self.weight['iou_loss']
        aux_loss = F10_Dice_Loss(
            aux['mask_feature'], gt_masks, gt_temporal_mask_flag) * self.weight['aux_loss']

        # Out-of-place sum: an in-place "+=" would also overwrite iou_loss,
        # because both names would refer to the same tensor.
        total_loss = iou_loss + aux_loss
        loss_dict = {
            'iou_loss': iou_loss.item(),
            'aux_loss': aux_loss.item(),
        }
        return total_loss, loss_dict


if __name__ == "__main__":
    # Running this file as a script only opens the interactive pdb debugger;
    # no loss is computed here.
    # NOTE(review): looks like a debugging leftover — confirm before removing.
    pdb.set_trace()
