# Copyright (c) OpenMMLab. All rights reserved.
import copy

import pytest
import torch
import torch.nn as nn

from mmpretrain.models.backbones.twins import (PCPVT, SVT,
                                               GlobalSubsampledAttention,
                                               LocallyGroupedSelfAttention)


def test_LSA_module():
    """LocallyGroupedSelfAttention must keep the (1, 3136, 32) token shape.

    The 56x56 token grid is not a multiple of ``window_size=3``, so the grid
    does not split evenly into windows; the output shape must still match
    the input shape.
    """
    hw_shape = (56, 56)
    num_tokens = hw_shape[0] * hw_shape[1]  # 3136 tokens
    attn = LocallyGroupedSelfAttention(embed_dims=32, window_size=3)
    tokens = torch.randn(1, num_tokens, 32)
    result = attn(tokens, hw_shape)
    assert result.shape == torch.Size([1, num_tokens, 32])


def test_GSA_module():
    """GlobalSubsampledAttention must keep the (1, 3136, 32) token shape."""
    grid = (56, 56)
    seq_len = grid[0] * grid[1]  # 3136 tokens
    attention = GlobalSubsampledAttention(embed_dims=32, num_heads=8)
    x = torch.randn(1, seq_len, 32)
    y = attention(x, grid)
    assert y.shape == torch.Size([1, seq_len, 32])


def test_pcpvt():
    """Test the PCPVT backbone.

    Covers ``init_cfg`` handling (missing checkpoint file, unsupported type),
    output shapes for the ``'small'`` preset and for a custom arch dict,
    rejection of malformed arch dicts, and the ``norm_after_stage`` option
    given either as a single bool or as a per-stage list of bools.
    """
    # test init
    path = 'PATH_THAT_DO_NOT_EXIST'

    # init_cfg points the pretrained checkpoint at a non-existent file
    model = PCPVT('s', init_cfg=dict(type='Pretrained', checkpoint=path))
    assert model.init_cfg == dict(type='Pretrained', checkpoint=path)

    # loading a checkpoint from a non-existent file must fail with an
    # OSError (FileNotFoundError is a subclass of OSError)
    with pytest.raises(OSError):
        model.init_weights()

    # init_cfg=123, whose type is unsupported
    model = PCPVT('s', init_cfg=123)
    with pytest.raises(TypeError):
        model.init_weights()

    H, W = (64, 64)
    temp = torch.randn((1, 3, H, W))

    # without out_indices only the last stage's feature map is returned
    model = PCPVT('small')
    model.init_weights()
    outs = model(temp)
    assert len(outs) == 1
    assert outs[-1].shape == (1, 512, H // 32, W // 32)

    # test with multi outputs; stage i has an overall stride of 4 * 2**i
    model = PCPVT('small', out_indices=(0, 1, 2, 3))
    model.init_weights()
    outs = model(temp)
    assert len(outs) == 4
    assert outs[0].shape == (1, 64, H // 4, W // 4)
    assert outs[1].shape == (1, 128, H // 8, W // 8)
    assert outs[2].shape == (1, 320, H // 16, W // 16)
    assert outs[3].shape == (1, 512, H // 32, W // 32)

    # test with arch given as a dict; ``arch`` is a template that is
    # deep-copied before every use so no case can leak into another
    arch = {
        'embed_dims': [64, 128, 320, 512],
        'depths': [3, 4, 18, 3],
        'num_heads': [1, 2, 5, 8],
        'patch_sizes': [4, 2, 2, 2],
        'strides': [4, 2, 2, 2],
        'mlp_ratios': [8, 8, 4, 4],
        'sr_ratios': [8, 4, 2, 1]
    }

    pcpvt_arch = copy.deepcopy(arch)
    model = PCPVT(pcpvt_arch, out_indices=(0, 1, 2, 3))
    model.init_weights()
    outs = model(temp)
    assert len(outs) == 4
    assert outs[0].shape == (1, 64, H // 4, W // 4)
    assert outs[1].shape == (1, 128, H // 8, W // 8)
    assert outs[2].shape == (1, 320, H // 16, W // 16)
    assert outs[3].shape == (1, 512, H // 32, W // 32)

    # malformed arch values must be rejected at construction time:
    # a list whose length differs from the other values, an int instead of
    # a list, and a string instead of a list
    for bad_sr_ratios in ([8, 4, 2], 1, '1, 2, 3, 4'):
        pcpvt_arch = copy.deepcopy(arch)
        pcpvt_arch['sr_ratios'] = bad_sr_ratios
        with pytest.raises(AssertionError):
            PCPVT(pcpvt_arch, out_indices=(0, 1, 2, 3))

    # an arch dict missing one of the essential keys must be rejected
    pcpvt_arch = copy.deepcopy(arch)
    del pcpvt_arch['sr_ratios']
    with pytest.raises(AssertionError):
        PCPVT(pcpvt_arch, out_indices=(0, 1, 2, 3))

    # test norm_after_stage given as bool True
    model = PCPVT('small', norm_after_stage=True, norm_cfg=dict(type='LN'))
    for i in range(model.num_stage):
        assert hasattr(model, f'norm_after_stage{i}')
        assert isinstance(getattr(model, f'norm_after_stage{i}'), nn.LayerNorm)

    # test norm_after_stage given as bool False
    model = PCPVT('small', norm_after_stage=False)
    for i in range(model.num_stage):
        assert hasattr(model, f'norm_after_stage{i}')
        assert isinstance(getattr(model, f'norm_after_stage{i}'), nn.Identity)

    # test norm_after_stage given as a per-stage list of bools
    norm_after_stage = [False, True, False, True]
    model = PCPVT('small', norm_after_stage=norm_after_stage)
    assert len(norm_after_stage) == model.num_stage
    for i, use_norm in enumerate(norm_after_stage):
        assert hasattr(model, f'norm_after_stage{i}')
        norm_layer = getattr(model, f'norm_after_stage{i}')
        expected_type = nn.LayerNorm if use_norm else nn.Identity
        assert isinstance(norm_layer, expected_type)

    # a list containing a non-bool entry must be rejected
    norm_after_stage = [False, 'True', False, True]
    with pytest.raises(AssertionError):
        PCPVT('small', norm_after_stage=norm_after_stage)


def test_svt():
    """Test the SVT backbone.

    Covers ``init_cfg`` handling (missing checkpoint file, unsupported type),
    output shapes for the ``'s'``/``'small'`` preset and for a custom arch
    dict, rejection of malformed arch dicts (including a missing
    ``window_sizes`` key), and the ``norm_after_stage`` option given either
    as a single bool or as a per-stage list of bools.
    """
    # test init
    path = 'PATH_THAT_DO_NOT_EXIST'

    # init_cfg points the pretrained checkpoint at a non-existent file
    model = SVT('s', init_cfg=dict(type='Pretrained', checkpoint=path))
    assert model.init_cfg == dict(type='Pretrained', checkpoint=path)

    # loading a checkpoint from a non-existent file must fail with an
    # OSError (FileNotFoundError is a subclass of OSError)
    with pytest.raises(OSError):
        model.init_weights()

    # init_cfg=123, whose type is unsupported
    model = SVT('s', init_cfg=123)
    with pytest.raises(TypeError):
        model.init_weights()

    # Test feature map output
    H, W = (64, 64)
    temp = torch.randn((1, 3, H, W))

    # without out_indices only the last stage's feature map is returned
    model = SVT('s')
    model.init_weights()
    outs = model(temp)
    assert len(outs) == 1
    assert outs[-1].shape == (1, 512, H // 32, W // 32)

    # test with multi outputs; stage i has an overall stride of 4 * 2**i
    model = SVT('small', out_indices=(0, 1, 2, 3))
    model.init_weights()
    outs = model(temp)
    assert len(outs) == 4
    assert outs[0].shape == (1, 64, H // 4, W // 4)
    assert outs[1].shape == (1, 128, H // 8, W // 8)
    assert outs[2].shape == (1, 256, H // 16, W // 16)
    assert outs[3].shape == (1, 512, H // 32, W // 32)

    # test with arch given as a dict; ``arch`` is a template that is
    # deep-copied before every use so no case can leak into another
    arch = {
        'embed_dims': [96, 192, 384, 768],
        'depths': [2, 2, 18, 2],
        'num_heads': [3, 6, 12, 24],
        'patch_sizes': [4, 2, 2, 2],
        'strides': [4, 2, 2, 2],
        'mlp_ratios': [4, 4, 4, 4],
        'sr_ratios': [8, 4, 2, 1],
        'window_sizes': [7, 7, 7, 7]
    }
    # pass a copy (as test_pcpvt does) so the constructor cannot alter the
    # template that the invalid-arch cases below are derived from
    model = SVT(copy.deepcopy(arch), out_indices=(0, 1, 2, 3))
    model.init_weights()
    outs = model(temp)
    assert len(outs) == 4
    assert outs[0].shape == (1, 96, H // 4, W // 4)
    assert outs[1].shape == (1, 192, H // 8, W // 8)
    assert outs[2].shape == (1, 384, H // 16, W // 16)
    assert outs[3].shape == (1, 768, H // 32, W // 32)

    # malformed arch values must be rejected at construction time:
    # a list whose length differs from the other values, an int instead of
    # a list, and a string instead of a list
    for bad_sr_ratios in ([8, 4, 2], 1, '1, 2, 3, 4'):
        svt_arch = copy.deepcopy(arch)
        svt_arch['sr_ratios'] = bad_sr_ratios
        with pytest.raises(AssertionError):
            SVT(svt_arch, out_indices=(0, 1, 2, 3))

    # an arch dict missing the SVT-specific ``window_sizes`` key must be
    # rejected
    svt_arch = copy.deepcopy(arch)
    del svt_arch['window_sizes']
    with pytest.raises(AssertionError):
        SVT(svt_arch, out_indices=(0, 1, 2, 3))

    # test norm_after_stage given as bool True
    model = SVT('small', norm_after_stage=True, norm_cfg=dict(type='LN'))
    for i in range(model.num_stage):
        assert hasattr(model, f'norm_after_stage{i}')
        assert isinstance(getattr(model, f'norm_after_stage{i}'), nn.LayerNorm)

    # test norm_after_stage given as bool False
    model = SVT('small', norm_after_stage=False)
    for i in range(model.num_stage):
        assert hasattr(model, f'norm_after_stage{i}')
        assert isinstance(getattr(model, f'norm_after_stage{i}'), nn.Identity)

    # test norm_after_stage given as a per-stage list of bools
    norm_after_stage = [False, True, False, True]
    model = SVT('small', norm_after_stage=norm_after_stage)
    assert len(norm_after_stage) == model.num_stage
    for i, use_norm in enumerate(norm_after_stage):
        assert hasattr(model, f'norm_after_stage{i}')
        norm_layer = getattr(model, f'norm_after_stage{i}')
        expected_type = nn.LayerNorm if use_norm else nn.Identity
        assert isinstance(norm_layer, expected_type)

    # a list containing a non-bool entry must be rejected
    norm_after_stage = [False, 'True', False, True]
    with pytest.raises(AssertionError):
        SVT('small', norm_after_stage=norm_after_stage)
