import pytest
import torch

from transfer.models.mlp import MLPModel


def create_sequence(batch_size=64, context_size=20, hidden_size=256):
    """Return a random ``(batch_size, context_size, hidden_size)`` tensor.

    A single ``(context_size, hidden_size)`` sample is drawn from a standard
    normal distribution and copied ``batch_size`` times, so every batch entry
    holds identical values (as independent copies, not views).
    """
    template = torch.randn(context_size, hidden_size)
    # Passing one more repeat factor than template.dim() makes repeat() treat
    # the tensor as if unsqueezed at the front, i.e. it prepends the batch axis.
    return template.repeat(batch_size, 1, 1)


@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("context_size", [10, 20])
@pytest.mark.parametrize("use_returns", [True, False])
@pytest.mark.parametrize("use_actions", [True, False])
def test_mask_sequence(hidden_size, context_size, use_returns, use_actions):
    """Check the per-step masks produced by ``MLPModel.mask_sequence``.

    The input holds ``context_size * mlp.n`` random tokens. For every step
    ``i`` in ``range(context_size)``, the test expects the output to contain a
    copy of the sequence with:

    - a leading prefix that is entirely zero, and
    - a trailing run of ``mlp.n * (i + 1) - int(use_actions)`` tokens that is
      entirely nonzero.
    """
    mlp = MLPModel(10, 10, hidden_size, 4, max_length=context_size, use_actions=use_actions, use_returns=use_returns)
    seq_len = context_size * mlp.n
    sequence = create_sequence(context_size=seq_len, hidden_size=hidden_size)
    masked_sequence = mlp.mask_sequence(sequence)
    assert masked_sequence.shape == (sequence.shape[0], context_size, seq_len, hidden_size)

    for i in range(context_size):
        # Number of trailing tokens expected to survive the mask at step i.
        n_kept = mlp.n * (i + 1) - int(use_actions)
        # Use a positive split index. A negative slice of ``-n_kept`` would
        # become ``-0 == 0`` when n_kept == 0 and silently invert the check.
        split_idx = seq_len - n_kept
        # Check the whole batch at once. This is equivalent to checking each
        # batch element separately, since per-element counts simply add up.
        masked = masked_sequence[:, i, :split_idx]
        kept = masked_sequence[:, i, split_idx:]
        assert torch.count_nonzero(masked) == 0
        assert torch.count_nonzero(kept) == kept.numel()
