import torch

from XXX.uib.modules.latent_label_chain import LatentLabelChain
from XXX.uib.modules.summarizer import Summarizer

import pingouin as pg


class IsGaussianSummarizer(Summarizer):
    """Test whether latent encodings look multivariate Gaussian, overall and per label.

    This summarizer keeps no state of its own: ``reset`` and ``fit`` are no-ops, and every
    query reads the current latents and labels straight from the wrapped ``LatentLabelChain``.
    """

    # Significance level (alpha) passed to ``pg.multivariate_normality``.
    significance_level: float = 0.01

    def __init__(self, encodings_label_chain):
        self.latent_label_chain: LatentLabelChain = encodings_label_chain

    def reset(self):
        """No-op: there is no accumulated state to clear."""

    def fit(self, encodings_x_k_Z: torch.Tensor, labels_x: torch.Tensor):
        """No-op: the query methods read their data from ``latent_label_chain`` instead."""

    def is_gaussian(self, samples_x_z):
        """Return pingouin's multivariate-normality verdict for ``samples_x_z``.

        Args:
            samples_x_z: 2-D samples, one row per sample (assumed layout — pingouin expects
                (n_samples, n_features); TODO confirm for all callers). Torch tensors are
                detached and moved to CPU before being handed to pingouin.

        Returns:
            The ``normal`` flag reported by pingouin at ``self.significance_level``, as a bool.
        """
        samples = samples_x_z
        if isinstance(samples, torch.Tensor):
            # pingouin converts its input to a NumPy array, which fails for CUDA tensors and
            # for tensors that require grad.
            samples = samples.detach().cpu().numpy()

        result = pg.multivariate_normality(samples, alpha=self.significance_level)
        # Depending on the pingouin version the result is either a plain ``(normal, pval)``
        # tuple (the shape this code was originally written against) or a namedtuple with a
        # ``normal`` field; read the flag in a way that works for both.
        normal = getattr(result, "normal", result[0])
        return bool(normal)

    def is_gaussian_Z__Y(self):
        """Return the label-frequency-weighted fraction of labels whose latents pass the test.

        For each distinct label, the latents of the samples carrying that label are pooled over
        their first two dims and tested with ``is_gaussian``. Each passing label contributes its
        empirical frequency in ``labels_x``.

        Returns:
            A Python float in [0, 1]; ``0.0`` when no label passes (or there are no labels).
        """
        # ``.get()`` is assumed to return the current tensor stored on the chain — TODO confirm.
        labels_x = self.latent_label_chain.labels_x.get()
        latent_x_k_z = self.latent_label_chain.latent_x_k_Z.get()

        labels, counts = torch.unique(labels_x, return_counts=True)
        p_Y = counts.float() / counts.sum()

        weighted_is_gaussian = 0.0
        for label, p_y in zip(labels, p_Y):
            # Select this label's rows along dim 0, then merge the first two dims so pingouin
            # receives a 2-D sample matrix (assumes latents are at least 3-D, e.g. (x, k, z)).
            latent__Y_xk_z = latent_x_k_z[labels_x == label].flatten(0, 1)

            if self.is_gaussian(latent__Y_xk_z):
                # ``.item()`` keeps the return type a Python float regardless of outcome.
                weighted_is_gaussian += p_y.item()

        return weighted_is_gaussian

    def is_gaussian_Z(self):
        """Return whether all latents, pooled over their first two dims, pass the normality test."""
        latent_xk_z = self.latent_label_chain.latent_x_k_Z.get().flatten(0, 1)
        return self.is_gaussian(latent_xk_z)
