# Copyright (c) OpenMMLab. All rights reserved.
import platform

import pytest
import torch
import torch.nn as nn
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm

from mmaction.models import ResNet3d, ResNet3dLayer
from mmaction.testing import check_norm_state, generate_backbone_demo_inputs


@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_resnet3d_backbone():
    """Test resnet3d backbone.

    Covers argument validation, ``norm_eval`` and ``frozen_stages`` handling,
    ``zero_init_residual``, inflation of the torchvision ResNet-50 checkpoint
    into 3D weights, forward output shapes for several depths, styles and
    inflate styles, and insertion of non-local blocks.
    """

    def _assert_forward_shape(model, inputs, expected_shape):
        """Run ``model`` on ``inputs`` and assert the output shape.

        parrots 3dconv is only implemented on gpu, so under parrots the
        forward pass runs on CUDA when it is available and is skipped
        otherwise.
        """
        if torch.__version__ == 'parrots':
            if not torch.cuda.is_available():
                return
            model = model.cuda()
            inputs = inputs.cuda()
        feat = model(inputs)
        assert feat.shape == torch.Size(expected_shape)

    def _assert_frozen(model, frozen_stages):
        """Assert the stem and the first ``frozen_stages`` residual layers
        are frozen: BN layers in eval mode and no parameter requires grad."""
        assert model.conv1.bn.training is False
        for param in model.conv1.parameters():
            assert param.requires_grad is False
        for i in range(1, frozen_stages + 1):
            layer = getattr(model, f'layer{i}')
            for mod in layer.modules():
                if isinstance(mod, _BatchNorm):
                    assert mod.training is False
            for param in layer.parameters():
                assert param.requires_grad is False

    # Invalid constructor arguments, checked for both depths.
    for depth in (34, 50):
        for num_stages in (0, 5):
            with pytest.raises(AssertionError):
                # In ResNet3d: 1 <= num_stages <= 4
                ResNet3d(depth, None, num_stages=num_stages)

        with pytest.raises(AssertionError):
            # len(spatial_strides) == len(temporal_strides)
            # == len(dilations) == num_stages
            ResNet3d(
                depth,
                None,
                spatial_strides=(1, ),
                temporal_strides=(1, 1),
                dilations=(1, 1, 1),
                num_stages=4)

        with pytest.raises(TypeError):
            # pretrain must be str or None.
            resnet3d = ResNet3d(depth, ['resnet', 'bninception'])
            resnet3d.init_weights()

    # No pretrained weights: after train(), BN layers are in eval mode
    # exactly when norm_eval is True.
    for depth in (34, 50):
        for norm_eval in (True, False):
            resnet3d = ResNet3d(
                depth, None, pretrained2d=False, norm_eval=norm_eval)
            resnet3d.init_weights()
            resnet3d.train()
            assert check_norm_state(resnet3d.modules(), not norm_eval)

    # resnet3d with depth 50, pretrained2d, norm_eval True
    resnet3d_50_pretrain = ResNet3d(
        50, 'torchvision://resnet50', norm_eval=True)
    resnet3d_50_pretrain.init_weights()
    resnet3d_50_pretrain.train()
    assert check_norm_state(resnet3d_50_pretrain.modules(), False)
    from mmengine.runner.checkpoint import _load_checkpoint
    chkp_2d = _load_checkpoint('torchvision://resnet50')
    for name, module in resnet3d_50_pretrain.named_modules():
        parts = name.split('.')
        if len(parts) != 4:
            continue
        # layer.block.module.submodule, e.g. ``layer1.0.conv1.bn``
        layer_name, block_idx, module_type, submodule_type = parts
        if module_type == 'downsample':
            # ``downsample.conv`` / ``downsample.bn`` correspond to the 2D
            # checkpoint keys ``downsample.0`` / ``downsample.1``.
            name2d = name.replace('conv', '0').replace('bn', '1')
        else:
            # ``conv1.conv`` -> ``conv1``, ``conv1.bn`` -> ``bn1``
            layer_id = module_type[-1]
            name2d = f'{layer_name}.{block_idx}.{submodule_type}{layer_id}'

        if isinstance(module, nn.Conv3d):
            # The 3D kernel must equal the 2D kernel repeated along the
            # temporal axis and divided by the temporal kernel size.
            conv2d_weight = chkp_2d[name2d + '.weight']
            conv3d_weight = module.weight.data
            assert torch.equal(
                conv3d_weight,
                conv2d_weight.data.unsqueeze(2).expand_as(conv3d_weight) /
                conv3d_weight.shape[2])
            if module.bias is not None:
                conv2d_bias = chkp_2d[name2d + '.bias']
                assert torch.equal(conv2d_bias, module.bias.data)

        elif isinstance(module, nn.BatchNorm3d):
            # Affine parameters and running statistics must all match the
            # 2D checkpoint; the assert is inside the loop so every tensor
            # is compared.
            for pname in ['weight', 'bias', 'running_mean', 'running_var']:
                param_2d = chkp_2d[name2d + '.' + pname]
                param_3d = getattr(module, pname).data
                assert torch.equal(param_2d, param_3d)

    conv3d = resnet3d_50_pretrain.conv1.conv
    assert torch.equal(
        conv3d.weight,
        chkp_2d['conv1.weight'].unsqueeze(2).expand_as(conv3d.weight) /
        conv3d.weight.shape[2])
    conv3d = resnet3d_50_pretrain.layer3[2].conv2.conv
    assert torch.equal(
        conv3d.weight, chkp_2d['layer3.2.conv2.weight'].unsqueeze(2).expand_as(
            conv3d.weight) / conv3d.weight.shape[2])

    # resnet3d with depth 34, no pretrained, frozen_stages, norm_eval False
    frozen_stages = 1
    resnet3d_34_frozen = ResNet3d(
        34, None, pretrained2d=False, frozen_stages=frozen_stages)
    resnet3d_34_frozen.init_weights()
    resnet3d_34_frozen.train()
    _assert_frozen(resnet3d_34_frozen, frozen_stages)
    # test zero_init_residual: the BN following conv2 starts at zero
    for m in resnet3d_34_frozen.modules():
        if hasattr(m, 'conv2'):
            assert torch.equal(m.conv2.bn.weight,
                               torch.zeros_like(m.conv2.bn.weight))
            assert torch.equal(m.conv2.bn.bias,
                               torch.zeros_like(m.conv2.bn.bias))

    # resnet3d with depth 50, no pretrained, frozen_stages, norm_eval False
    resnet3d_50_frozen = ResNet3d(
        50, None, pretrained2d=False, frozen_stages=frozen_stages)
    resnet3d_50_frozen.init_weights()
    resnet3d_50_frozen.train()
    _assert_frozen(resnet3d_50_frozen, frozen_stages)
    # test zero_init_residual: the BN following conv3 starts at zero
    for m in resnet3d_50_frozen.modules():
        if hasattr(m, 'conv3'):
            assert torch.equal(m.conv3.bn.weight,
                               torch.zeros_like(m.conv3.bn.weight))
            assert torch.equal(m.conv3.bn.bias,
                               torch.zeros_like(m.conv3.bn.bias))

    # Inference with the frozen models.
    imgs = generate_backbone_demo_inputs((1, 3, 6, 64, 64))
    _assert_forward_shape(resnet3d_34_frozen, imgs, (1, 512, 3, 2, 2))
    _assert_forward_shape(resnet3d_50_frozen, imgs, (1, 2048, 3, 2, 2))

    # Inference in caffe style and with the 3x3x3 inflate style.
    out_channels = {34: 512, 50: 2048}
    for variant_cfg in (dict(style='caffe'), dict(inflate_style='3x3x3')):
        for depth in (50, 34):
            resnet3d = ResNet3d(
                depth, None, pretrained2d=False, **variant_cfg)
            resnet3d.init_weights()
            resnet3d.train()
            _assert_forward_shape(resnet3d, imgs,
                                  (1, out_channels[depth], 3, 2, 2))

    # resnet3d with non-local module
    non_local_cfg = dict(
        sub_sample=True,
        use_scale=False,
        norm_cfg=dict(type='BN3d', requires_grad=True),
        mode='embedded_gaussian')
    non_local = ((0, 0, 0), (1, 0, 1, 0), (1, 0, 1, 0, 1, 0), (0, 0, 0))
    resnet3d_nonlocal = ResNet3d(
        50,
        None,
        pretrained2d=False,
        non_local=non_local,
        non_local_cfg=non_local_cfg)
    resnet3d_nonlocal.init_weights()
    # A block has a ``non_local_block`` exactly when its flag in
    # ``non_local`` is set.
    for stage_idx, stage_flags in enumerate(non_local):
        layer = getattr(resnet3d_nonlocal, f'layer{stage_idx + 1}')
        assert len(layer) == len(stage_flags)
        for block, flag in zip(layer, stage_flags):
            assert hasattr(block, 'non_local_block') == bool(flag)

    _assert_forward_shape(resnet3d_nonlocal, imgs, (1, 2048, 3, 2, 2))


def test_resnet3d_layer():
    """Test ResNet3dLayer.

    Covers argument validation, the effect of ``norm_eval`` and
    ``all_frozen`` after ``train()``, and the forward output shape of
    stage 3 of a depth-50 network.
    """
    input_shape = (1, 1024, 1, 4, 4)

    def _assert_forward_shape(model, expected_shape):
        """Forward fresh demo inputs through ``model`` and check the shape.

        parrots 3dconv is only implemented on gpu, so under parrots the
        forward pass runs on CUDA when it is available and is skipped
        otherwise.
        """
        imgs = generate_backbone_demo_inputs(input_shape)
        if torch.__version__ == 'parrots':
            if not torch.cuda.is_available():
                return
            model = model.cuda()
            imgs = imgs.cuda()
        feat = model(imgs)
        assert feat.shape == torch.Size(expected_shape)

    with pytest.raises(AssertionError):
        # depth 22 is expected to be rejected
        ResNet3dLayer(22, None)

    with pytest.raises(AssertionError):
        # stage index 4 is expected to be rejected
        ResNet3dLayer(50, None, stage=4)

    # norm_eval=True: BN layers must stay in eval mode after train().
    res_layer = ResNet3dLayer(50, None, stage=3, norm_eval=True)
    res_layer.init_weights()
    res_layer.train()
    assert check_norm_state(res_layer.modules(), False)
    _assert_forward_shape(res_layer, (1, 2048, 1, 2, 2))

    # all_frozen=True: BN layers in eval mode and no trainable parameters.
    res_layer = ResNet3dLayer(
        50, 'torchvision://resnet50', stage=3, all_frozen=True)
    res_layer.init_weights()
    res_layer.train()
    assert check_norm_state(res_layer.modules(), False)
    for param in res_layer.parameters():
        assert param.requires_grad is False
    _assert_forward_shape(res_layer, (1, 2048, 1, 2, 2))
