"""
The KL divergence between a normal distribution and the prior distribution.
"""

import numpy as np
import torch
import torch.nn.functional as F

from pdisvae import utils


class KLNormal:
    """KL divergence between a normal distribution and the prior distribution.

    Parameters
    ----------
    prior : str
        Name of the prior distribution. ["normal", "logcosh"]
    alpha : int
        Weight of the Index-code mutual information term. Default: 1.
    beta : int
        Weight of the partial correlation term. Default: 1.
    gamma : int
        Weight of the dimension-wise KL divergence term. Default: 1.
    n_groups : int
        Number of groups. Default: 1.
    group_rank : int
        Rank of the group. Default: 1.
    n_total_samples : int | None
        Number of total samples. Default: None.

    Raises
    ------
    ValueError
        If the prior is not in ["normal", "logcosh"].
    """

    def __init__(
        self,
        prior: str = "normal",
        n_groups: int = 1,
        group_rank: int = 1,
        n_total_samples: int | None = None,
    ) -> None:

        self.prior = prior
        if prior == "normal":
            self.prior_log_prob = utils.normal_log_prob
        elif self.prior == "logcosh":
            self.prior_log_prob = utils.logcosh_log_prob
        else:
            raise ValueError(f"Unknown prior: {prior}")
        self.group_rank = group_rank
        self.n_groups = n_groups
        self.n_total_samples = n_total_samples

    def analytical(
        self, z_pred_mean: torch.Tensor, z_pred_log_std: torch.Tensor
    ) -> torch.Tensor:
        """Analytical KL divergence between a normal distribution and the normal prior distribution.

        KL(q||p) = 0.5 * (tr(Σ_p^-1 Σ_q) + (μ_p - μ_q)^T Σ_p^-1 (μ_p - μ_q) - k + log(|Σ_p|/|Σ_q|))

        When p is a standard normal distribution, the KL divergence simplifies to:
        KL(q||p) = 0.5 * (μ^T μ + tr(Σ) - k - log|Σ|)

        Here we also assume Σ is diagonal, so the KL divergence simplifies to:
        KL(q||p) = 0.5 * (μ^T μ + σ^T σ - k - 2 * sum(log(σ)))

        https://mr-easy.github.io/2020-04-16-kl-divergence-between-2-gaussian-distributions/

        Parameters
        ----------
        z_pred_mean : torch.Tensor of shape (*, n_components)
            The predicted mean of the latent variable.
        z_pred_log_std : torch.Tensor of shape (*, n_components)
            The predicted log standard deviation of the latent variable.

        Returns
        -------
        torch.Tensor of shape (*,)
            The analytical KL divergence.

        Raises
        ------
        ValueError
            If the prior is not Normal.
        """
        if self.prior != "normal":
            raise ValueError("Prior is not Normal, so no analytical KL divergence.")
        return (
            0.5 * (z_pred_mean**2 + z_pred_log_std.exp() ** 2 - 1 - 2 * z_pred_log_std)
        ).sum(dim=-1)

    def aggregated_posterior(
        self, z_pred_mean: torch.Tensor, z: torch.Tensor, z_pred_log_std: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Aggregated posterior jointly and dimension-wise.

        Parameters
        ----------
        z_pred_mean : torch.Tensor of shape (batch_size, n_components)
            The predicted mean of the latent variable.
        z : torch.Tensor of shape (batch_size, n_components)
            The sampled latent variable.
        z_pred_log_std : torch.Tensor of shape (batch_size, n_components)
            The predicted log standard deviation of the latent variable.

        Returns
        -------
        ln_q_z : torch.Tensor of shape (batch_size,)
            The joint aggregated posterior.
        ln_prod_q_zi : torch.Tensor of shape (batch_size,)
            The dimension-wise aggregated posterior.
        """
        batch_size, n_components = z.shape
        if self.n_total_samples is None:
            n_total_samples = batch_size
        else:
            n_total_samples = self.n_total_samples

        mat_ln_q_z = -F.gaussian_nll_loss(
            z_pred_mean.view((1, batch_size, self.n_groups, self.group_rank)),
            z.view((batch_size, 1, self.n_groups, self.group_rank)),
            (z_pred_log_std.exp() ** 2).view(
                (1, batch_size, self.n_groups, self.group_rank)
            ),
            full=True,
            reduction="none",
        )  # (n_monte_carlo = batch_size, batch_size, n_groups, group_rank)

        reweights = (
            torch.ones(batch_size, batch_size, device=z.device)
            / (batch_size - 1)
            * (n_total_samples - 1)
        )
        reweights[torch.arange(batch_size), torch.arange(batch_size)] = 1
        reweights = reweights.log()

        ln_q_z = torch.logsumexp(
            mat_ln_q_z.sum(dim=(2, 3)) + reweights, dim=1
        ) - np.log(n_total_samples)
        ln_prod_q_zi = (
            torch.logsumexp(mat_ln_q_z.sum(dim=3) + reweights[:, :, None], dim=1)
            - np.log(n_total_samples)
        ).sum(dim=1)
        return ln_q_z, ln_prod_q_zi

    def partial_correlation(
        self, z_pred_mean: torch.Tensor, z: torch.Tensor, z_pred_log_std: torch.Tensor
    ) -> torch.Tensor:
        ln_q_z, ln_prod_q_zi = self.aggregated_posterior(z_pred_mean, z, z_pred_log_std)
        return (ln_q_z - ln_prod_q_zi).mean()

    def decomposed(
        self,
        z_pred_mean: torch.Tensor,
        z: torch.Tensor,
        z_pred_log_std: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Decomposed KL divergence between a normal distribution and the normal prior distribution.

        https://github.com/YannDubs/disentangling-vae

        KL(q||p) =
        α * I_q(z; x)
        + β * KL(q(z)||prod_i q(z_i))
        + γ * sum_i KL(q(z_i)||p(z_i))
        = α * KL(q(z, n) || q(z)p(n))
        + β * KL(q(z) || prod_i q(z_i))
        + γ * sum_i KL(q(z_i) || p(z_i))


        Parameters
        ----------
        z_pred_mean : torch.Tensor of shape (batch_size, n_components)
            The predicted mean of the latent variable.
        z : torch.Tensor of shape (batch_size, n_components)
            The sampled latent variable.
        z_pred_log_std : torch.Tensor of shape (batch_size, n_components)
            The predicted log standard deviation of the latent variable.

        Returns
        -------
        index_code_mutual_information : torch.Tensor
            The index-code mutual information.
        partial_correlation : torch.Tensor
            The partial correlation.
        dimension_wise_kl : torch.Tensor
            The dimension-wise KL divergence.
        """

        ln_q_zgx = -F.gaussian_nll_loss(
            z_pred_mean, z, (z_pred_log_std.exp() ** 2), full=True, reduction="none"
        ).sum(
            dim=-1
        )  # (n_monte_carlo = batch_size,)
        ln_p_z = self.prior_log_prob(z).sum(dim=-1)  # (n_monte_carlo = batch_size,)

        ln_q_z, ln_prod_q_zi = self.aggregated_posterior(z_pred_mean, z, z_pred_log_std)

        index_code_mutual_information = (ln_q_zgx - ln_q_z).mean()
        partial_correlation = (ln_q_z - ln_prod_q_zi).mean()
        dimension_wise_kl = (ln_prod_q_zi - ln_p_z).mean()

        return index_code_mutual_information, partial_correlation, dimension_wise_kl
