import pytest
import torch
import XXX.uib.categorical_iq_loss as categorical_entropies
import XXX.uib.information_quantities as iq
from XXX.uib.modules.categorical_entropies_summarizer import CategoricalEntropiesSummarizer


@pytest.mark.repeat(10)
def test_moving_categorical_entropies():
    """Check that the streaming summarizer reproduces the batch IQ loss.

    A random weighting ``quantity_type`` is evaluated once on the full dataset
    via ``categorical_entropies.iq_loss`` and once by feeding the same data in
    mini-batches to ``CategoricalEntropiesSummarizer.fit`` and reading
    ``get_iq``. Both results must agree, and must still agree after the
    summarizer has been ``reset()`` and refit on the same batches.
    """
    num_samples = 500
    num_k = 2  # size of the middle ("k") axis of the logits
    num_z = 5  # size of the last ("z") axis of the logits
    num_y = 10  # number of distinct label values
    batch_size = 10

    logits_x_k_z = torch.randn(num_samples, num_k, num_z, dtype=torch.double, requires_grad=True)
    # One label per input row. Previously 1000 labels were drawn for 500 rows,
    # so the batch and streaming computations did not see the same data
    # (zip() below silently dropped the surplus label batches).
    labels_x = torch.randint(0, num_y, (num_samples,))

    # NOTE(review): assumes iq.H_Z__X is a tensor encoding of an information
    # quantity, so randn_like yields a random linear combination — confirm.
    quantity_type = torch.randn_like(iq.H_Z__X) * 5
    expected = categorical_entropies.iq_loss(quantity_type)(logits_x_k_z, labels_x)

    # NOTE(review): constructor args assumed to be (num_z, num_y), matching the
    # original literal values (5, 10) — confirm against the summarizer's signature.
    summarizer = CategoricalEntropiesSummarizer(num_z, num_y)
    logit_batches = torch.split(logits_x_k_z, batch_size)
    label_batches = torch.split(labels_x, batch_size)
    # Guard against zip() silently truncating mismatched inputs.
    assert len(logit_batches) == len(label_batches)

    def fit_all_and_check():
        # Stream every mini-batch into the summarizer, then compare its
        # accumulated estimate against the full-batch reference value.
        for logits, labels in zip(logit_batches, label_batches):
            summarizer.fit(logits, labels)
        assert torch.isclose(expected, summarizer.get_iq(quantity_type))

    fit_all_and_check()

    # After reset() the summarizer must behave like a fresh instance.
    summarizer.reset()
    fit_all_and_check()


