import pytest
import torch
import torch.distributions
import XXX.uib.information_quantities as iq
from XXX.uib import gaussian_continuous_iq_loss
from XXX.uib.utils import torch_kraskov_entropy
from XXX.uib.utils import nonparametric_mutual_info
import numpy as np

@pytest.mark.repeat(10)
def test_coverage():
    """Check the torch Kraskov entropy against the NumPy reference estimator.

    Draws 1000 standard-normal 5-dimensional samples (float64, gradients
    enabled). It then asserts that ``torch_kraskov_entropy.entropy`` agrees
    with ``nonparametric_mutual_info.entropy`` within ``np.isclose`` default
    tolerances.
    """
    n_points = 1000
    samples = torch.randn(n_points, 5, dtype=torch.double, requires_grad=True)

    # The reference implementation consumes a plain NumPy array, so drop autograd first.
    reference = nonparametric_mutual_info.entropy(samples.detach().numpy())
    estimate = torch_kraskov_entropy.entropy(samples).detach().numpy()

    assert np.isclose(estimate, reference)


def test_entropies():
    """Compare the H(Z) value from the one-hot continuous loss with the Kraskov entropy.

    Uses 1000 standard-normal 5-dimensional samples (float64) and random
    integer labels in ``[0, 10)``. The two quantities come from different
    estimators, so only rough agreement (absolute tolerance 1) is required.
    """
    n_points = 1000
    samples = torch.randn(n_points, 5, dtype=torch.double, requires_grad=True)
    targets = torch.randint(0, 10, (n_points,), requires_grad=False)

    loss_h_z = gaussian_continuous_iq_loss.onehot_continuous_information(
        samples, targets, iq.H_Z
    )
    kraskov_h_z = torch_kraskov_entropy.entropy(samples)

    assert torch.isclose(kraskov_h_z, loss_h_z, atol=1)


@pytest.mark.repeat(5)
def test_split_knn():
    """Chunked k-NN distances must match the naive computation.

    Runs on 1000 random 5-dimensional float64 points with k=10 and
    ``split_size=100``. The assertion uses absolute tolerance 0.1.
    """
    k = 10
    points = torch.randn(1000, 5, dtype=torch.double, requires_grad=True)

    expected = torch_kraskov_entropy.naive_knn_dist(points, k)
    chunked = torch_kraskov_entropy.split_knn_dist(points, k, split_size=100)

    assert torch.allclose(expected, chunked, atol=0.1)


@pytest.mark.repeat(5)
def test_split_knn_duplicates_ignored():
    """Exact duplicate rows must not change the chunked k-NN distances.

    Every one of 50 random points appears twice in the input. The test
    asserts that ``split_knn_dist`` (k=10, ``split_size=50``) reports, for
    the first copy of each point, the same distances that ``naive_knn_dist``
    gives on the de-duplicated data. The split size equals the original row
    count, presumably placing the two copies in separate chunks.
    """
    n_unique = 50
    k = 10
    points = torch.randn(n_unique, 5, dtype=torch.double, requires_grad=True)

    # Tile the rows: [p_0 .. p_49, p_0 .. p_49], same as concatenating along dim 0.
    doubled = points.repeat(2, 1)

    expected = torch_kraskov_entropy.naive_knn_dist(points, k)
    chunked = torch_kraskov_entropy.split_knn_dist(doubled, k, split_size=50)
    first_copy = chunked[:n_unique]

    assert torch.allclose(expected, first_copy, atol=0.1)


@pytest.mark.repeat(5)
def test_split_knn_duplicates_ignored_2():
    """Same duplicate-row check as above, but with ``split_size=75``.

    With 100 tiled rows and a split size of 75, a chunk boundary presumably
    falls part-way through the second copy. The test asserts that the first
    50 rows of ``split_knn_dist`` (k=10) match ``naive_knn_dist`` on the
    de-duplicated points, within absolute tolerance 0.1.
    """
    n_unique = 50
    k = 10
    points = torch.randn(n_unique, 5, dtype=torch.double, requires_grad=True)

    # Tile the rows: [p_0 .. p_49, p_0 .. p_49], same as concatenating along dim 0.
    doubled = points.repeat(2, 1)

    expected = torch_kraskov_entropy.naive_knn_dist(points, k)
    chunked = torch_kraskov_entropy.split_knn_dist(doubled, k, split_size=75)
    first_copy = chunked[:n_unique]

    assert torch.allclose(expected, first_copy, atol=0.1)


@pytest.mark.repeat(5)
def test_split_squared_knn_duplicates_ignored_2():
    """Squared-distance variant of the duplicate-row check (``split_size=75``).

    Asserts that ``split_knn_squared_dist`` (k=10) on the tiled points gives,
    for the first 50 rows, the square of the naive k-NN distances on the
    de-duplicated points, within absolute tolerance 0.1.
    """
    n_unique = 50
    k = 10
    points = torch.randn(n_unique, 5, dtype=torch.double, requires_grad=True)

    # Tile the rows: [p_0 .. p_49, p_0 .. p_49], same as concatenating along dim 0.
    doubled = points.repeat(2, 1)

    expected = torch_kraskov_entropy.naive_knn_dist(points, k)
    chunked_sq = torch_kraskov_entropy.split_knn_squared_dist(doubled, k, split_size=75)
    first_copy = chunked_sq[:n_unique]

    assert torch.allclose(expected.pow(2), first_copy, atol=0.1)