import pytest
import torch
import torch.distributions
import XXX.uib.information_quantities as iq
from XXX.uib import gaussian_continuous_iq_loss


@pytest.mark.repeat(1)
def test_onehot_continuous_information():
    """Gradient-check ``onehot_continuous_information`` on freshly drawn random inputs.

    Builds 200 random 5-dimensional double-precision feature rows (the only
    input with ``requires_grad=True``), 200 integer labels in ``[0, 10)``, and
    a random tensor shaped like ``iq.H_Z__X`` scaled by 5. It then compares
    analytical against finite-difference gradients with
    ``torch.autograd.gradcheck``. Because ``raise_exception=True``, any mismatch
    raises and fails the test.

    NOTE(review): inputs are not seeded, so each run (and each repetition
    under ``pytest.mark.repeat``) uses different data.
    """
    num_samples = 200
    num_dims = 5
    num_classes = 10

    # gradcheck's finite differences (eps=1e-6) are intended for double-precision inputs.
    features = torch.randn(num_samples, num_dims, dtype=torch.double, requires_grad=True)
    # Integer class labels; they do not require grad, so gradcheck does not perturb them.
    labels = torch.randint(0, num_classes, (num_samples,), requires_grad=False)
    # randn_like does not require grad by default, so this tensor is also held fixed.
    information_quantity = 5 * torch.randn_like(iq.H_Z__X)

    function_under_test = gaussian_continuous_iq_loss.onehot_continuous_information
    gradcheck_inputs = (features, labels, information_quantity)
    torch.autograd.gradcheck(
        function_under_test,
        gradcheck_inputs,
        eps=1e-6,
        atol=1e-4,
        raise_exception=True,
    )
