import torch

from src.gpu_utils import roc_auc_gpu_safe


def test_multi_class():
    """Smoke-test ``roc_auc_gpu_safe`` on a tiny 3-class, one-vs-one problem.

    Each score row sums to 1.0, so the rows can be read as per-class
    probabilities. That is presumably what a multi-class ROC AUC expects;
    confirm against the helper's contract. Each class label (0, 1, 2)
    appears exactly once in the targets.
    """
    # One row per sample, one column per class.
    class_scores = torch.tensor(
        [
            [0.0, 0.2, 0.8],
            [0.9, 0.1, 0.0],
            [0.0, 1.0, 0.0],
        ]
    )
    true_labels = torch.tensor([2, 1, 0])

    # Targets are passed first and scores second, matching the original call order.
    auc = roc_auc_gpu_safe(true_labels, class_scores, multi_class="ovo")

    # Only checks that the helper produced a result rather than None.
    assert auc is not None
