import pytest
import torch
import torch.distributions
import XXX.uib.information_quantities as iq
import XXX.uib.categorical_iq_loss as categorical_entropies

@pytest.mark.repeat(10)
def test_iq_loss():
    """Gradient-check ``iq_loss`` built from a random tensor shaped like ``iq.H_Z__X``.

    Each repetition draws a fresh random tensor with the shape of
    ``iq.H_Z__X`` (scaled by 5), builds the loss from it, and compares the
    analytical gradient w.r.t. the logits against finite differences.
    """
    # gradcheck expects double-precision inputs for reliable finite
    # differences; only tensors with requires_grad=True are checked, so the
    # integer labels are passed through without being differentiated.
    logits = torch.randn(200, 5, dtype=torch.double, requires_grad=True)
    # NOTE(review): labels span [0, 10) while the logits have 5 columns;
    # presumably the label and latent cardinalities are independent -- confirm.
    labels = torch.randint(0, 10, (200,))

    # Random stand-in for a named quantity such as iq.H_Z__X.
    random_quantity = torch.randn_like(iq.H_Z__X) * 5

    assert torch.autograd.gradcheck(
        categorical_entropies.iq_loss(random_quantity),
        (logits, labels),
        eps=1e-6,
        atol=1e-4,
        raise_exception=True,
    )


def test_decoder_uncertainty():
    """Gradient-check the ``iq_loss`` built from ``iq.decoder_uncertainty``.

    Compares the analytical gradient w.r.t. random logits against finite
    differences.
    """
    # gradcheck expects double-precision inputs for reliable finite
    # differences; the integer labels are not differentiated.
    logits = torch.randn(200, 5, dtype=torch.double, requires_grad=True)
    # NOTE(review): labels span [0, 10) while the logits have 5 columns;
    # presumably the label and latent cardinalities are independent -- confirm.
    labels = torch.randint(0, 10, (200,))

    assert torch.autograd.gradcheck(
        categorical_entropies.iq_loss(iq.decoder_uncertainty),
        (logits, labels),
        eps=1e-6,
        atol=1e-4,
        raise_exception=True,
    )


def test_entropy_distance():
    """Gradient-check the ``iq_loss`` built from ``iq.entropy_distance``.

    Compares the analytical gradient w.r.t. random logits against finite
    differences.
    """
    # gradcheck expects double-precision inputs for reliable finite
    # differences; the integer labels are not differentiated.
    logits = torch.randn(200, 5, dtype=torch.double, requires_grad=True)
    # NOTE(review): labels span [0, 10) while the logits have 5 columns;
    # presumably the label and latent cardinalities are independent -- confirm.
    labels = torch.randint(0, 10, (200,))

    assert torch.autograd.gradcheck(
        categorical_entropies.iq_loss(iq.entropy_distance),
        (logits, labels),
        eps=1e-6,
        atol=1e-4,
        raise_exception=True,
    )


def test_reverse_decoder_uncertainty():
    """Gradient-check the ``iq_loss`` built from ``iq.reverse_decoder_uncertainty``.

    Compares the analytical gradient w.r.t. random logits against finite
    differences.
    """
    # gradcheck expects double-precision inputs for reliable finite
    # differences; the integer labels are not differentiated.
    logits = torch.randn(200, 5, dtype=torch.double, requires_grad=True)
    # NOTE(review): labels span [0, 10) while the logits have 5 columns;
    # presumably the label and latent cardinalities are independent -- confirm.
    labels = torch.randint(0, 10, (200,))

    assert torch.autograd.gradcheck(
        categorical_entropies.iq_loss(iq.reverse_decoder_uncertainty),
        (logits, labels),
        eps=1e-6,
        atol=1e-4,
        raise_exception=True,
    )

def test_semantics():
    """Check ``iq_loss(iq.H_Z)`` against a known closed-form value.

    Both rows of logits encode the probabilities (0.4, 0.6). The loss value
    is expected to match the entropy of a Bernoulli(0.4) distribution.
    """
    probabilities = torch.tensor([[0.4, 0.6], [0.4, 0.6]])
    logits = probabilities.log()
    labels = torch.tensor([1, 1])

    expected = torch.distributions.Bernoulli(0.4).entropy()
    # Cast to float32 so the comparison dtype matches the Bernoulli entropy.
    actual = categorical_entropies.iq_loss(iq.H_Z)(logits, labels).float()

    # torch.isclose is asymmetric (rtol scales |other|), so keep expected second.
    assert torch.isclose(actual, expected)

