# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
import torch.nn as nn
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm

from mmaction.models import ResNet
from mmaction.testing import check_norm_state, generate_backbone_demo_inputs


def test_resnet_backbone():
    """Test resnet backbone."""
    with pytest.raises(KeyError):
        # ResNet depth should be in [18, 34, 50, 101, 152]
        ResNet(20)

    with pytest.raises(AssertionError):
        # In ResNet: 1 <= num_stages <= 4
        ResNet(50, num_stages=0)

    with pytest.raises(AssertionError):
        # In ResNet: 1 <= num_stages <= 4
        ResNet(50, num_stages=5)

    with pytest.raises(AssertionError):
        # len(strides) == len(dilations) == num_stages
        ResNet(50, strides=(1, ), dilations=(1, 1), num_stages=3)

    with pytest.raises(TypeError):
        # pretrain must be a str
        resnet50 = ResNet(50, pretrained=0)
        resnet50.init_weights()

    with pytest.raises(AssertionError):
        # style must be in ['pytorch', 'caffe']
        ResNet(18, style='tensorflow')

    with pytest.raises(AssertionError):
        # assert not with_cp
        ResNet(18, with_cp=True)

    # resnet with depth 18, norm_eval False, initial weights
    resnet18 = ResNet(18)
    resnet18.init_weights()

    # resnet with depth 50, norm_eval True
    resnet50 = ResNet(50, norm_eval=True)
    resnet50.init_weights()
    resnet50.train()
    assert check_norm_state(resnet50.modules(), False)

    # resnet with depth 50, norm_eval True, pretrained
    resnet50_pretrain = ResNet(
        pretrained='torchvision://resnet50', depth=50, norm_eval=True)
    resnet50_pretrain.init_weights()
    resnet50_pretrain.train()
    assert check_norm_state(resnet50_pretrain.modules(), False)

    # resnet with depth 50, norm_eval True, frozen_stages 1
    frozen_stages = 1
    resnet50_frozen = ResNet(50, frozen_stages=frozen_stages)
    resnet50_frozen.init_weights()
    resnet50_frozen.train()
    assert resnet50_frozen.conv1.bn.training is False
    for layer in resnet50_frozen.conv1.modules():
        for param in layer.parameters():
            assert param.requires_grad is False
    for i in range(1, frozen_stages + 1):
        layer = getattr(resnet50_frozen, f'layer{i}')
        for mod in layer.modules():
            if isinstance(mod, _BatchNorm):
                assert mod.training is False
        for param in layer.parameters():
            assert param.requires_grad is False

    # resnet with depth 50, partial batchnorm
    resnet_pbn = ResNet(50, partial_bn=True)
    resnet_pbn.train()
    count_bn = 0
    for m in resnet_pbn.modules():
        if isinstance(m, nn.BatchNorm2d):
            count_bn += 1
            if count_bn >= 2:
                assert m.weight.requires_grad is False
                assert m.bias.requires_grad is False
                assert m.training is False
            else:
                assert m.weight.requires_grad is True
                assert m.bias.requires_grad is True
                assert m.training is True

    input_shape = (1, 3, 64, 64)
    imgs = generate_backbone_demo_inputs(input_shape)

    # resnet with depth 18 inference
    resnet18 = ResNet(18, norm_eval=False)
    resnet18.init_weights()
    resnet18.train()
    feat = resnet18(imgs)
    assert feat.shape == torch.Size([1, 512, 2, 2])

    # resnet with depth 50 inference
    resnet50 = ResNet(50, norm_eval=False)
    resnet50.init_weights()
    resnet50.train()
    feat = resnet50(imgs)
    assert feat.shape == torch.Size([1, 2048, 2, 2])

    # resnet with depth 50 in caffe style inference
    resnet50_caffe = ResNet(50, style='caffe', norm_eval=False)
    resnet50_caffe.init_weights()
    resnet50_caffe.train()
    feat = resnet50_caffe(imgs)
    assert feat.shape == torch.Size([1, 2048, 2, 2])

    resnet50_flow = ResNet(
        depth=50, pretrained='torchvision://resnet50', in_channels=10)
    input_shape = (1, 10, 64, 64)
    imgs = generate_backbone_demo_inputs(input_shape)
    feat = resnet50_flow(imgs)
    assert feat.shape == torch.Size([1, 2048, 2, 2])

    resnet50 = ResNet(
        depth=50, pretrained='torchvision://resnet50', in_channels=3)
    input_shape = (1, 3, 64, 64)
    imgs = generate_backbone_demo_inputs(input_shape)
    feat = resnet50(imgs)
    assert feat.shape == torch.Size([1, 2048, 2, 2])
