# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmpretrain.models.backbones import SEResNeXt
from mmpretrain.models.backbones.seresnext import SEBottleneck as SEBottleneckX


def test_bottleneck():
    """Check SEResNeXt bottleneck validation, layer layout and forward."""
    # An unsupported style (anything but 'pytorch' / 'caffe') must be rejected
    with pytest.raises(AssertionError):
        SEBottleneckX(64, 64, groups=32, width_per_group=4, style='tensorflow')

    # Keyword arguments shared by both structural checks below
    shared_kwargs = dict(width_per_group=4, stride=2, style='pytorch')

    # Structure with grouped convolution (32 groups)
    grouped = SEBottleneckX(64, 256, groups=32, **shared_kwargs)
    grouped_conv = grouped.conv2
    assert grouped.width_per_group == 4
    assert grouped_conv.stride == (2, 2)
    assert grouped_conv.groups == 32
    assert grouped_conv.out_channels == 128
    assert grouped_conv.out_channels == grouped.mid_channels

    # Structure with a single group (groups=1)
    ungrouped = SEBottleneckX(64, 256, groups=1, **shared_kwargs)
    ungrouped_conv = ungrouped.conv2
    assert ungrouped_conv.stride == (2, 2)
    assert ungrouped_conv.groups == 1
    assert ungrouped_conv.out_channels == 64
    assert ungrouped.mid_channels == 64
    assert ungrouped_conv.out_channels == ungrouped.mid_channels

    # Forward pass: the output is expected to have the same shape as the input
    forward_block = SEBottleneckX(
        64, 64, base_channels=16, groups=32, width_per_group=4)
    inputs = torch.randn(1, 64, 56, 56)
    outputs = forward_block(inputs)
    assert outputs.shape == torch.Size([1, 64, 56, 56])


def test_seresnext():
    """Check SEResNeXt depth validation and per-stage output shapes."""
    # SEResNeXt depth should be in [50, 101, 152]
    with pytest.raises(KeyError):
        SEResNeXt(depth=18)

    # Expected feature-map shape of each stage for a 1x3x224x224 input
    stage_shapes = (
        torch.Size([1, 256, 56, 56]),
        torch.Size([1, 512, 28, 28]),
        torch.Size([1, 1024, 14, 14]),
        torch.Size([1, 2048, 7, 7]),
    )

    # Group 32, width_per_group 4: first all four stages as outputs, then
    # only the last stage
    for out_indices in ((0, 1, 2, 3), (3, )):
        net = SEResNeXt(
            depth=50, groups=32, width_per_group=4, out_indices=out_indices)
        # Every bottleneck in the backbone must use the requested group count
        assert all(
            layer.conv2.groups == 32 for layer in net.modules()
            if isinstance(layer, SEBottleneckX))
        net.init_weights()
        net.train()

        batch = torch.randn(1, 3, 224, 224)
        outs = net(batch)
        # One feature map per requested stage, each with that stage's shape
        assert len(outs) == len(out_indices)
        for out, stage in zip(outs, out_indices):
            assert out.shape == stage_shapes[stage]
