# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import torch

from mmdet.models.roi_heads.mask_heads import (DynamicMaskHead, FCNMaskHead,
                                               MaskIoUHead)
from .utils import _dummy_bbox_sampling


def test_mask_head_loss():
    """Test mask-head losses with a one-GT (non-empty) mask target.

    One proposal and one ground-truth box/mask are turned into sampling
    results. The following heads are then exercised on the same dummy RoI
    features:

    * ``FCNMaskHead``: forward, ``get_targets`` and ``loss``; the mask loss
      must be strictly positive.
    * ``MaskIoUHead``: forward on the FCN mask logits selected by each
      positive's GT label, then ``get_targets`` and ``loss``; the loss must
      be non-negative.
    * ``DynamicMaskHead``: forward with per-RoI proposal features, then
      ``get_targets`` and ``loss``; the loss must be non-negative.
    """
    import numpy as np

    from mmdet.core import BitmapMasks

    # Dimensions shared by all three heads; they must agree with the dummy
    # feature tensors built below.
    num_classes = 8
    feat_channels = 8
    roi_feat_size = 6

    # Local RNGs keep the dummy inputs reproducible without touching the
    # global numpy / torch random state used by other tests.
    np_rng = np.random.RandomState(0)
    torch_gen = torch.Generator().manual_seed(0)

    mask_head = FCNMaskHead(
        num_convs=1,
        roi_feat_size=roi_feat_size,
        in_channels=feat_channels,
        conv_out_channels=feat_channels,
        num_classes=num_classes)

    # Dummy proposals
    proposal_list = [
        torch.Tensor([[23.6667, 23.8757, 228.6326, 153.8874]]),
    ]

    gt_bboxes = [
        torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]),
    ]
    gt_labels = [torch.LongTensor([2])]
    sampling_results = _dummy_bbox_sampling(proposal_list, gt_bboxes,
                                            gt_labels)

    # Random binary GT mask for a single instance (height 160, width 240).
    gt_mask_array = np_rng.randint(0, 2, (1, 160, 240), dtype=np.uint8)
    gt_masks = [BitmapMasks(gt_mask_array, 160, 240)]

    # Mask targets are rasterized at 12x12 and binarized at 0.5.
    train_cfg = mmcv.Config(dict(mask_size=12, mask_thr_binary=0.5))

    # Mask heads only train on positive RoIs: size the dummy RoI features by
    # the positives so they line up with ``pos_gt_labels`` and the targets
    # (sizing by all sampled boxes breaks as soon as a negative is sampled).
    num_pos = sum(len(res.pos_bboxes) for res in sampling_results)
    roi_feats = torch.rand(
        num_pos,
        feat_channels,
        roi_feat_size,
        roi_feat_size,
        generator=torch_gen)
    pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])

    # --- FCNMaskHead ---
    fcn_mask_pred = mask_head(roi_feats)
    mask_targets = mask_head.get_targets(sampling_results, gt_masks,
                                         train_cfg)
    fcn_losses = mask_head.loss(fcn_mask_pred, mask_targets, pos_labels)
    # ``.sum()`` (not builtin ``sum``) also works if the loss is 0-d.
    fcn_loss = fcn_losses['loss_mask'].sum()
    assert fcn_loss.item() > 0, 'mask loss should be non-zero'

    # --- MaskIoUHead ---
    mask_iou_head = MaskIoUHead(
        num_convs=1,
        num_fcs=1,
        roi_feat_size=roi_feat_size,
        in_channels=feat_channels,
        conv_out_channels=feat_channels,
        fc_out_channels=feat_channels,
        num_classes=num_classes)

    # For each positive RoI, keep only the mask logits of its GT class.
    pos_index = range(fcn_mask_pred.size(0))
    pos_mask_pred = fcn_mask_pred[pos_index, pos_labels]
    mask_iou_pred = mask_iou_head(roi_feats, pos_mask_pred)
    pos_mask_iou_pred = mask_iou_pred[range(mask_iou_pred.size(0)),
                                      pos_labels]

    mask_iou_targets = mask_iou_head.get_targets(sampling_results, gt_masks,
                                                 pos_mask_pred, mask_targets,
                                                 train_cfg)
    mask_iou_losses = mask_iou_head.loss(pos_mask_iou_pred, mask_iou_targets)
    mask_iou_loss = mask_iou_losses['loss_mask_iou'].sum()
    assert mask_iou_loss.item() >= 0, 'mask IoU loss should be non-negative'

    # --- DynamicMaskHead ---
    proposal_feats = torch.rand(num_pos, feat_channels, generator=torch_gen)
    dynamic_mask_head = DynamicMaskHead(
        dynamic_conv_cfg=dict(
            type='DynamicConv',
            in_channels=feat_channels,
            feat_channels=feat_channels,
            out_channels=feat_channels,
            input_feat_shape=roi_feat_size,
            with_proj=False,
            act_cfg=dict(type='ReLU', inplace=True),
            norm_cfg=dict(type='LN')),
        num_convs=1,
        num_classes=num_classes,
        in_channels=feat_channels,
        roi_feat_size=roi_feat_size)

    dynamic_mask_pred = dynamic_mask_head(roi_feats, proposal_feats)
    dynamic_mask_targets = dynamic_mask_head.get_targets(
        sampling_results, gt_masks, train_cfg)
    dynamic_losses = dynamic_mask_head.loss(dynamic_mask_pred,
                                            dynamic_mask_targets, pos_labels)
    dynamic_loss = dynamic_losses['loss_mask'].sum()
    assert dynamic_loss.item() >= 0, 'dynamic mask loss should be non-negative'
