import pytest
import torch
from XXX.uib import information_quantities as iq
from XXX.uib.losses.very_approx_regularizers import covariance_trace_given_Y, covariance_trace_given_X, \
    estimate_entropy_Z__Y


@pytest.mark.repeat(1)
def test_covariance_given_Y_k1():
    """Gradient-check ``covariance_trace_given_Y(stochastic=True)``.

    Random double-precision features of shape (200, 1, 5) are paired with
    integer labels drawn from ``[0, 10)``. ``gradcheck`` compares analytical
    gradients w.r.t. the features against finite differences and raises on
    mismatch. The labels do not require grad, so ``gradcheck`` does not check them.
    """
    # NOTE(review): the middle axis of size 1 presumably corresponds to the
    # ``_k1`` suffix (k samples per point) — confirm against the regularizer.
    n_points, k, dim, n_classes = 200, 1, 5, 10
    z = torch.randn(n_points, k, dim, dtype=torch.double, requires_grad=True)
    y = torch.randint(0, n_classes, (n_points,), requires_grad=False)

    regularizer = covariance_trace_given_Y(stochastic=True)
    # Double precision plus a loosened atol (1e-4) keeps finite differences stable.
    torch.autograd.gradcheck(
        regularizer,
        (z, y),
        eps=1e-6,
        atol=1e-4,
        raise_exception=True,
    )


@pytest.mark.repeat(1)
def test_cestimate_entropy_Z__Y_k1():
    """Gradient-check ``estimate_entropy_Z__Y(stochastic=True)``.

    The inputs are random double features of shape (200, 1, 5) and integer labels in
    ``[0, 10)``. ``gradcheck`` raises if the analytical and numerical
    gradients w.r.t. the features disagree beyond the given tolerances.
    """
    # NOTE(review): the size-1 middle axis presumably matches the ``_k1`` suffix
    # (one sample per point) — confirm against ``estimate_entropy_Z__Y``.
    feature_shape = (200, 1, 5)
    inputs = (
        torch.randn(*feature_shape, dtype=torch.double, requires_grad=True),
        torch.randint(0, 10, (feature_shape[0],), requires_grad=False),
    )
    entropy_estimator = estimate_entropy_Z__Y(stochastic=True)
    torch.autograd.gradcheck(entropy_estimator, inputs,
                             raise_exception=True, atol=1e-4, eps=1e-6)


@pytest.mark.repeat(1)
def test_covariance_given_X_k2():
    """Gradient-check ``covariance_trace_given_X(stochastic=True)``.

    Unlike the Y-conditioned tests, the features here have shape (200, 2, 5),
    i.e. a middle axis of size 2. The labels are integers in ``[0, 10)`` and
    do not require grad. ``gradcheck`` raises on a gradient mismatch.
    """
    # NOTE(review): size 2 on the middle axis presumably gives the X-conditioned
    # covariance at least two samples per point (``_k2``) — confirm.
    batch = 200
    sampled_features = torch.randn(batch, 2, 5, dtype=torch.double, requires_grad=True)
    class_labels = torch.randint(0, 10, (batch,), requires_grad=False)

    loss_fn = covariance_trace_given_X(stochastic=True)
    tolerances = dict(eps=1e-6, atol=1e-4)
    torch.autograd.gradcheck(loss_fn, (sampled_features, class_labels),
                             raise_exception=True, **tolerances)
