import torch

from memoria.utils import super_unique


def _assert_tensor_equal(actual: torch.Tensor, expected: torch.Tensor) -> None:
    """Assert that two tensors have the same shape and identical elements.

    A bare ``(actual == expected).all()`` broadcasts, so a result with the wrong
    shape (e.g. an extra or size-1 axis) could still pass. Checking the shape first
    rules that out. Dtypes are deliberately not compared: the expected literals are
    built with torch's default integer dtype (int64) while the inputs are int32.
    """
    assert actual.shape == expected.shape, f"shape mismatch: {tuple(actual.shape)} != {tuple(expected.shape)}"
    assert (actual == expected).all(), f"value mismatch:\n{actual}\n!= expected\n{expected}"


def test_super_unique():
    """Check ``super_unique`` against hand-computed outputs for a 3-D and a 2-D input.

    In the expected tensors below, every 1-D slice taken along ``dim`` holds the
    distinct values of the corresponding input slice, followed by ``-1`` padding
    up to the original length of that axis.

    NOTE(review): the distinct values are not in sorted order in the expected
    literals, so these cases pin the current implementation's ordering. TODO
    confirm whether that ordering is part of ``super_unique``'s contract.
    """
    # Case 1: 3-D input (shape (3, 4, 5)), unique values taken along dim=1.
    x = torch.tensor(
        [
            [[0, 1, 0, 2, 1], [1, 4, 1, 2, 1], [3, 4, 1, 4, 3], [2, 0, 0, 3, 2]],
            [[0, 1, 4, 2, 4], [1, 4, 3, 2, 0], [2, 3, 4, 2, 0], [4, 4, 4, 3, 0]],
            [[2, 3, 2, 3, 3], [2, 0, 1, 3, 0], [1, 3, 2, 0, 0], [1, 2, 3, 0, 0]],
        ],
        dtype=torch.int32,
    )
    expected = torch.tensor(
        [
            [[1, 1, 1, 2, 1], [3, 4, 0, 4, 3], [2, 0, -1, 3, 2], [0, -1, -1, -1, -1]],
            [[2, 1, 4, 2, 4], [4, 4, 3, 3, 0], [0, 3, -1, -1, -1], [1, -1, -1, -1, -1]],
            [[1, 2, 1, 0, 0], [2, 3, 3, 3, 3], [-1, 0, 2, -1, -1], [-1, -1, -1, -1, -1]],
        ]
    )
    _assert_tensor_equal(super_unique(x, dim=1), expected)

    # Case 2: 2-D input (shape (4, 5)), unique values taken along dim=0 (down each column).
    x = torch.tensor([[2, 3, 4, 3, 0], [3, 1, 3, 1, 0], [4, 3, 2, 2, 4], [2, 2, 2, 0, 3]], dtype=torch.int32)
    expected = torch.tensor([[2, 1, 2, 1, 0], [4, 3, 4, 3, 3], [3, 2, 3, 2, 4], [-1, -1, -1, 0, -1]])
    _assert_tensor_equal(super_unique(x, dim=0), expected)
