import numpy as np
import torch
import torch.distributions as D
import torch.nn.functional as F
from hypothesis import given, settings
from hypothesis import strategies as st

from pdisvae import inference


class TestKLNormal:
    """Tests for ``inference.KLNormal``.

    Random inputs come from explicitly seeded ``torch.Generator`` objects rather
    than torch's global RNG, so a failing run can be reproduced exactly.
    """

    def test_analytical(self):
        """Test analytical KL divergence.

        Check that ``KLNormal.analytical`` with ``prior="normal"`` matches
        ``torch.distributions.kl.kl_divergence`` against a standard normal,
        summed over the last dimension.
        """
        # Fixed seed so the inputs (and any failure) are reproducible.
        generator = torch.Generator().manual_seed(0)
        z_pred_mean = torch.randn(10, 5, generator=generator)
        z_pred_log_std = torch.randn(10, 5, generator=generator)

        kl_normal = inference.KLNormal(prior="normal")

        kl = kl_normal.analytical(z_pred_mean, z_pred_log_std)
        kl_ref = D.kl_divergence(
            D.Normal(z_pred_mean, z_pred_log_std.exp()),
            D.Normal(torch.zeros_like(z_pred_mean), torch.ones_like(z_pred_log_std)),
        ).sum(dim=-1)

        torch.testing.assert_close(kl, kl_ref)

    @given(st.data())
    @settings(deadline=None)
    def test_decomposed(self, data: st.DataObject):
        """Test that the decomposed KL terms add up to the Monte Carlo KL estimate.

        The index-code mutual information, partial correlation and
        dimension-wise KL returned by ``KLNormal.decomposed`` are summed with
        unit weights (alpha = beta = gamma = 1). The sum is compared with the
        batch mean of ``ln q(z|x) - ln p(z)``, where each row of ``z`` is one
        Monte Carlo sample.
        """
        prior = data.draw(st.sampled_from(["normal", "logcosh"]))
        n_groups = data.draw(st.integers(1, 3))
        group_rank = data.draw(st.integers(1, 3))
        n_total_samples = data.draw(st.sampled_from([None, 100]))
        # Derive the tensors from a drawn seed. torch's global RNG is not part
        # of Hypothesis's draw, so unseeded tensors would make a failing example
        # impossible to replay or shrink.
        seed = data.draw(st.integers(0, 2**32 - 1))
        generator = torch.Generator().manual_seed(seed)

        batch_size = 10
        n_components = n_groups * group_rank

        z_pred_mean = torch.randn(batch_size, n_components, generator=generator)
        z_pred_log_std = torch.randn(batch_size, n_components, generator=generator)
        z = torch.randn(batch_size, n_components, generator=generator)

        kl_normal = inference.KLNormal(
            prior=prior,
            n_groups=n_groups,
            group_rank=group_rank,
            n_total_samples=n_total_samples,
        )

        index_code_mutual_information, partial_correlation, dimension_wise_kl = (
            kl_normal.decomposed(z_pred_mean, z, z_pred_log_std)
        )
        kl = index_code_mutual_information + partial_correlation + dimension_wise_kl

        # ln q(z|x): log density of each row of z under its own diagonal
        # Gaussian posterior. Use the exact D.Normal log-prob;
        # F.gaussian_nll_loss would silently clamp the variance to eps.
        ln_q_zgx = (
            D.Normal(z_pred_mean, z_pred_log_std.exp()).log_prob(z).sum(dim=-1)
        )  # (n_monte_carlo = batch_size,)
        ln_p_z = kl_normal.prior_log_prob(z)  # (n_monte_carlo = batch_size,)
        kl_ref = (ln_q_zgx - ln_p_z).mean()

        torch.testing.assert_close(kl, kl_ref)