import pytest
import torch
import torch.distributions
import XXX.uib.information_quantities as iq
from XXX.uib import gaussian_continuous_iq_loss
from XXX.uib import kraskov_continuous_iq_loss
from XXX.uib.modules.continuous_encoding_summarizer import OldContinuousLatentLabelEntropiesSummarizer


@pytest.mark.repeat(10)
def test_coverage():
    """Cross-check the Gaussian and Kraskov estimators on a random information quantity.

    Random double-precision features and integer labels are fed to both
    ``gaussian_continuous_iq_loss.onehot_continuous_information`` and the
    callable returned by ``kraskov_continuous_iq_loss.iq_loss``. The test
    requires the two scalars to agree within an absolute tolerance of 1.

    The information quantity is a random tensor shaped like ``iq.H_Z__X``
    (scaled by 5). Presumably this represents random coefficients over the
    basic information terms; TODO confirm against the ``iq`` module.
    ``pytest.mark.repeat`` (assumed to come from the pytest-repeat plugin)
    re-runs the test so that several random draws are covered.
    """
    n = 1000

    # The draw order (features, then labels, then the quantity) is kept
    # identical to the original so the RNG stream is consumed the same way.
    latent = torch.randn(n, 5, dtype=torch.double, requires_grad=True)
    targets = torch.randint(0, 10, (n,), requires_grad=False)
    quantity = 5 * torch.randn_like(iq.H_Z__X)

    gaussian_value = gaussian_continuous_iq_loss.onehot_continuous_information(
        latent, targets, quantity
    )
    kraskov_loss = kraskov_continuous_iq_loss.iq_loss(quantity)
    kraskov_value = kraskov_loss(latent, targets)

    # torch.isclose is asymmetric in rtol; keep the Kraskov value as `input`
    # and the Gaussian value as `other`, as in the original.
    assert torch.isclose(kraskov_value, gaussian_value, atol=1)


def test_entropies():
    """Compare the Gaussian H(Z) estimate with the old summarizer's entropies.

    Standard-normal double-precision features and random integer labels in
    ``[0, 10)`` are drawn. ``H_Z`` is computed via
    ``gaussian_continuous_iq_loss.onehot_continuous_information``, and
    ``OldContinuousLatentLabelEntropiesSummarizer`` is fitted on the same data.

    The test then checks two things:

    * the summarizer's H(Z) is not smaller than its H(Z|Y), since
      conditioning should not increase entropy;
    * the summarizer's H(Z) matches the Gaussian estimate within ``atol=1``.
    """
    num_samples = 10000
    features = torch.randn(num_samples, 5, dtype=torch.double, requires_grad=True)
    labels = torch.randint(0, 10, (num_samples,), requires_grad=False)

    information_quantity = iq.H_Z

    summarizer = OldContinuousLatentLabelEntropiesSummarizer()

    result = gaussian_continuous_iq_loss.onehot_continuous_information(features, labels, information_quantity)

    summarizer.fit(features, labels)

    # NOTE(review): the meaning of the second argument (1) to get_iq is defined
    # by the summarizer and is not visible here; confirm against its docs.
    h_yhat = summarizer.get_iq(iq.H_Z, 1)
    h_yhat__y = summarizer.get_iq(iq.H_Z__Y, 1)

    # The values go into the assertion messages instead of a debug print, so
    # pytest reports them exactly when an assertion fails.
    assert h_yhat >= h_yhat__y, f"expected H(Z) >= H(Z|Y), got H(Z)={h_yhat}, H(Z|Y)={h_yhat__y}"
    assert torch.isclose(h_yhat, result, atol=1), (
        f"summarizer H(Z)={h_yhat} differs from Gaussian H(Z)={result} by more than atol=1"
    )
