import pytest
import torch
import torch.distributions
import XXX.uib.information_quantities as iq
from XXX.uib import multilabel_categorical_iq_loss


@pytest.mark.repeat(10)
def test_iq_loss():
    """Gradient-check ``iq_loss`` built from a random weighting shaped like ``iq.H_Z__X``.

    ``torch.autograd.gradcheck`` compares the analytic gradient of the loss with
    respect to ``logits_x_yhat`` (the only input with ``requires_grad=True``)
    against a finite-difference estimate. All tensors are sampled without a
    seed, so each of the repeats exercises a different random input.
    """
    # NOTE(review): the names suggest 200 samples, logits over 5 predicted
    # classes, and target probabilities over 10 classes — confirm with iq_loss.
    logits_x_yhat = torch.randn(200, 5, dtype=torch.double, requires_grad=True)
    # Row-wise softmax of random scores; not differentiated by gradcheck.
    prob_x_y = torch.nn.functional.softmax(torch.randn(200, 10, dtype=torch.double), dim=1)
    inputs = (logits_x_yhat, prob_x_y)

    # Random coefficients with the same shape as iq.H_Z__X, scaled by 5.
    weights = torch.randn_like(iq.H_Z__X) * 5
    loss_fn = multilabel_categorical_iq_loss.iq_loss(weights)

    # Double-precision inputs are what gradcheck expects; with
    # raise_exception=True it raises on a mismatch and returns True otherwise.
    assert torch.autograd.gradcheck(loss_fn, inputs, eps=1e-6, atol=1e-4, raise_exception=True)


def test_decoder_uncertainty():
    """Gradient-check ``iq_loss(iq.decoder_uncertainty)``.

    ``torch.autograd.gradcheck`` compares the analytic gradient of the loss with
    respect to ``logits_x_yhat`` (the only input with ``requires_grad=True``)
    against a finite-difference estimate on random, unseeded inputs.
    """
    # NOTE(review): the names suggest 200 samples, logits over 5 predicted
    # classes, and target probabilities over 10 classes — confirm with iq_loss.
    logits_x_yhat = torch.randn(200, 5, dtype=torch.double, requires_grad=True)
    # Row-wise softmax of random scores; not differentiated by gradcheck.
    prob_x_y = torch.nn.functional.softmax(torch.randn(200, 10, dtype=torch.double), dim=1)
    inputs = (logits_x_yhat, prob_x_y)

    loss_fn = multilabel_categorical_iq_loss.iq_loss(iq.decoder_uncertainty)

    # With raise_exception=True gradcheck raises on a mismatch and returns True otherwise.
    assert torch.autograd.gradcheck(loss_fn, inputs, eps=1e-6, atol=1e-4, raise_exception=True)


def test_entropy_distance():
    """Gradient-check ``iq_loss(iq.entropy_distance)``.

    ``torch.autograd.gradcheck`` compares the analytic gradient of the loss with
    respect to ``logits_x_yhat`` (the only input with ``requires_grad=True``)
    against a finite-difference estimate on random, unseeded inputs.
    """
    # NOTE(review): the names suggest 200 samples, logits over 5 predicted
    # classes, and target probabilities over 10 classes — confirm with iq_loss.
    logits_x_yhat = torch.randn(200, 5, dtype=torch.double, requires_grad=True)
    # Row-wise softmax of random scores; not differentiated by gradcheck.
    prob_x_y = torch.nn.functional.softmax(torch.randn(200, 10, dtype=torch.double), dim=1)
    inputs = (logits_x_yhat, prob_x_y)

    loss_fn = multilabel_categorical_iq_loss.iq_loss(iq.entropy_distance)

    # With raise_exception=True gradcheck raises on a mismatch and returns True otherwise.
    assert torch.autograd.gradcheck(loss_fn, inputs, eps=1e-6, atol=1e-4, raise_exception=True)


def test_reverse_decoder_uncertainty():
    """Gradient-check ``iq_loss(iq.reverse_decoder_uncertainty)``.

    ``torch.autograd.gradcheck`` compares the analytic gradient of the loss with
    respect to ``logits_x_yhat`` (the only input with ``requires_grad=True``)
    against a finite-difference estimate on random, unseeded inputs.
    """
    # NOTE(review): the names suggest 200 samples, logits over 5 predicted
    # classes, and target probabilities over 10 classes — confirm with iq_loss.
    logits_x_yhat = torch.randn(200, 5, dtype=torch.double, requires_grad=True)
    # Row-wise softmax of random scores; not differentiated by gradcheck.
    prob_x_y = torch.nn.functional.softmax(torch.randn(200, 10, dtype=torch.double), dim=1)
    inputs = (logits_x_yhat, prob_x_y)

    loss_fn = multilabel_categorical_iq_loss.iq_loss(iq.reverse_decoder_uncertainty)

    # With raise_exception=True gradcheck raises on a mismatch and returns True otherwise.
    assert torch.autograd.gradcheck(loss_fn, inputs, eps=1e-6, atol=1e-4, raise_exception=True)

def test_semantics():
    """Check that ``iq_loss(iq.H_Z)`` matches the entropy of a known distribution.

    Both rows of the batch carry the two-class distribution ``[p, 1 - p]``, and
    the loss value is compared against ``Bernoulli(p).entropy()``. That entropy
    equals the entropy of a categorical ``[p, 1 - p]``.
    """
    # Single source of truth for the probability, so the reference distribution
    # and the input tensor cannot drift apart.
    p = 0.4
    bernoulli = torch.distributions.Bernoulli(probs=p)

    # log-probabilities used as logits: softmax recovers [p, 1 - p] per row.
    logits = torch.log(torch.tensor([[p, 1 - p], [p, 1 - p]]))
    prob_x_y = torch.nn.functional.softmax(logits, dim=1)

    actual = multilabel_categorical_iq_loss.iq_loss(iq.H_Z)(logits, prob_x_y).float()
    expected = bernoulli.entropy()
    assert torch.isclose(actual, expected), f"iq_loss(H_Z) = {actual}, expected {expected}"

