import pytest
import torch

from mmdet.models.utils import (LearnedPositionalEncoding,
                                SinePositionalEncoding)


def test_sine_positional_encoding(num_feats=16, batch_size=2):
    """Check ``SinePositionalEncoding`` argument validation and output shape.

    The test expects the encoding of a boolean ``(batch, h, w)`` mask to
    have ``2 * num_feats`` channels over the same ``h x w`` grid, both with
    the default settings and with ``normalize=True``.
    """
    # A 1-tuple ``scale`` is an invalid type; construction must assert.
    with pytest.raises(AssertionError):
        SinePositionalEncoding(num_feats, scale=(3., ), normalize=True)

    height, width = 10, 6
    expected_shape = (batch_size, 2 * num_feats, height, width)

    # Default construction: normalization should be disabled.
    encoder = SinePositionalEncoding(num_feats)
    bool_mask = torch.rand(batch_size, height, width) > 0.5
    assert not encoder.normalize
    assert encoder(bool_mask).shape == expected_shape

    # Enabling normalization must leave the output shape unchanged.
    encoder = SinePositionalEncoding(num_feats, normalize=True)
    assert encoder.normalize
    assert encoder(bool_mask).shape == expected_shape


def test_learned_positional_encoding(num_feats=16,
                                     row_num_embed=10,
                                     col_num_embed=10,
                                     batch_size=2):
    """Check ``LearnedPositionalEncoding`` embedding tables and output shape.

    The module is expected to own a ``row_embed`` and a ``col_embed`` table
    whose weights are ``(num_embed, num_feats)``. Its output for a boolean
    ``(batch, h, w)`` mask should have ``2 * num_feats`` channels over the
    same ``h x w`` grid.
    """
    encoder = LearnedPositionalEncoding(num_feats, row_num_embed,
                                        col_num_embed)

    # Checked in order: row table first, then column table.
    expected_tables = {
        'row_embed': (row_num_embed, num_feats),
        'col_embed': (col_num_embed, num_feats),
    }
    for attr_name, weight_shape in expected_tables.items():
        assert getattr(encoder, attr_name).weight.shape == weight_shape

    height, width = 10, 6
    bool_mask = torch.rand(batch_size, height, width) > 0.5
    encoded = encoder(bool_mask)
    assert encoded.shape == (batch_size, 2 * num_feats, height, width)
