import pytest
import torch
import torch.distributions
import XXX.uib.information_quantities as iq
from XXX.uib import gaussian_continuous_iq_loss
from XXX.uib.gaussian_continuous_iq_loss import diff_entropy
from XXX.uib.modules.continuous_encoding_summarizer import OldContinuousLatentLabelEntropiesSummarizer, \
    ContinuousLatentLabelEntropiesSummarizer


@pytest.mark.repeat(10)
def test_coverage():
    """Compare the summarizer's estimate of a random information quantity with the direct computation.

    The ``repeat`` marker runs this test 10 times, each with freshly drawn random features/labels.
    Features are (samples, 1, 5) doubles; labels are integers in [0, 10).
    """
    n = 1000
    latent = torch.randn(n, 1, 5, dtype=torch.double, requires_grad=False)
    targets = torch.randint(0, 10, (n,), requires_grad=False)

    # A random tensor with the same shape/dtype as iq.decoder_uncertainty, used as the quantity to evaluate.
    quantity = torch.randn_like(iq.decoder_uncertainty)

    # Debug output: diff_entropy of a 5x5 float32 identity matrix.
    print(diff_entropy(torch.eye(5, dtype=torch.float)))

    summarizer = ContinuousLatentLabelEntropiesSummarizer()

    # Repeat each label once per entry along latent's dim 1 so that, after flattening,
    # every feature row is paired with its own label.
    aligned_targets = targets[:, None].expand(-1, latent.shape[1]).flatten(0, 1)
    expected = gaussian_continuous_iq_loss.onehot_continuous_information(
        latent.flatten(0, 1), aligned_targets, quantity
    )

    summarizer.fit(latent, targets)
    estimate = summarizer.get_iq(quantity, 5)

    assert torch.isclose(estimate, expected, atol=1)


def test_entropies():
    """Check that the summarizer's H_Z agrees with its H_Z__X and with the direct H_Z computation.

    Features are (samples, 5, 5) standard-normal doubles drawn independently of the integer labels
    in [0, 10), so the two summarizer quantities are expected to be close (within ``atol=1``).
    """
    num_samples = 1000
    features = torch.randn(num_samples, 5, 5, dtype=torch.double, requires_grad=False)
    labels = torch.randint(0, 10, (num_samples,), requires_grad=False)

    information_quantity = iq.H_Z

    summarizer = ContinuousLatentLabelEntropiesSummarizer()

    # features.flatten(0, 1) has num_samples * features.shape[1] rows, so each label must be
    # repeated once per entry along dim 1 to stay aligned with its feature rows.
    # Passing the raw (num_samples,) labels here would mismatch the 5000 feature rows
    # (same pattern as test_coverage).
    expanded_labels = labels[:, None].expand(-1, features.shape[1])
    result = gaussian_continuous_iq_loss.onehot_continuous_information(
        features.flatten(0, 1), expanded_labels.flatten(0, 1), information_quantity
    )

    summarizer.fit(features, labels)

    # NOTE(review): test_coverage passes 5 as the second get_iq argument; its meaning is not
    # visible in this file — confirm 1 is intended here.
    h_yhat = summarizer.get_iq(iq.H_Z, 1)
    h_yhat__y = summarizer.get_iq(iq.H_Z__X, 1)

    print(h_yhat, h_yhat__y)

    assert torch.isclose(h_yhat, h_yhat__y, atol=1), (h_yhat, h_yhat__y)
    assert torch.isclose(h_yhat, result, atol=1), (h_yhat, result)
