# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmaction.models import MobileNetV2TSM
from mmaction.testing import generate_backbone_demo_inputs


def test_mobilenetv2_tsm_backbone():
    """Test mobilenetv2_tsm backbone.

    Two things are verified:

    1. After ``init_weights()``, every ``InvertedResidual`` block that has
       three conv layers and uses a residual connection has a
       ``TemporalShift`` as its first conv, carrying the backbone's
       ``num_segments`` / ``shift_div`` and wrapping a ``ConvModule``.
    2. The forward output shape for the default model and for
       ``widen_factor`` values of 0.5 and 1.5.
    """
    from mmcv.cnn import ConvModule

    from mmaction.models.backbones.mobilenet_v2 import InvertedResidual
    from mmaction.models.backbones.resnet_tsm import TemporalShift

    input_shape = (8, 3, 64, 64)
    imgs = generate_backbone_demo_inputs(input_shape)

    # mobilenetv2_tsm built with default arguments
    mobilenetv2_tsm = MobileNetV2TSM()
    mobilenetv2_tsm.init_weights()

    num_checked = 0
    for cur_module in mobilenetv2_tsm.modules():
        if not isinstance(cur_module, InvertedResidual):
            continue
        if len(cur_module.conv) != 3 or not cur_module.use_res_connect:
            continue
        shift = cur_module.conv[0]
        assert isinstance(shift, TemporalShift)
        assert shift.num_segments == mobilenetv2_tsm.num_segments
        assert shift.shift_div == mobilenetv2_tsm.shift_div
        assert isinstance(shift.net, ConvModule)
        num_checked += 1
    # Without this guard the loop above passes vacuously when no block
    # matches the filter, so a missing temporal shift would go unnoticed.
    assert num_checked > 0, \
        'no InvertedResidual block with a temporal shift was checked'

    # mobilenetv2_tsm with default arguments: forward
    feat = mobilenetv2_tsm(imgs)
    assert feat.shape == torch.Size([8, 1280, 2, 2])

    # mobilenetv2_tsm with widen_factor = 0.5 and 1.5: forward
    for widen_factor, out_channels in ((0.5, 1280), (1.5, 1920)):
        model = MobileNetV2TSM(widen_factor=widen_factor)
        model.init_weights()
        feat = model(imgs)
        assert feat.shape == torch.Size([8, out_channels, 2, 2])
