import torch
from torch import cdist
from torch.distributions import Categorical, MixtureSameFamily, \
    MultivariateNormal, Independent, Normal
from torch.nn.functional import one_hot
from torch_geometric.utils import degree
from torch_scatter import scatter

from M_x import M_x


def ccns(
    edge_index: torch.Tensor,
    class_labels: torch.Tensor,
    num_nodes: torch.Tensor,
) -> torch.Tensor:
    """
    Computes the CCNS of https://arxiv.org/abs/2106.06134.
    It assumes source nodes are stored in `edge_index[0]` and destination
    nodes in `edge_index[1]`. It uses L2 distance to compute discrepancy.

    For every node v, the class histogram of its in-neighbors (the sources
    of the edges pointing to v) is normalized by v's in-degree to obtain a
    distribution over classes. CCNS(c, c') is the mean L2 distance between
    these distributions over all pairs (u, v) with class(u) = c and
    class(v) = c'.

    :param edge_index: the connectivity of the graph in PyG format
    :param class_labels: a 1D vector of (non-negative integer) class labels
        for the nodes
    :param num_nodes: the number of nodes in the graph
    :return: a tensor of size CxC containing the empirical CCNS, where
        C = ``class_labels.max() + 1``. Entries involving a class index that
        no node carries are NaN (an average over zero pairs).
    :raises AssertionError: (when assertions are enabled) if some node has
        no incoming edge, since its neighborhood distribution is 0/0.
    """
    src_index = edge_index[0]
    dst_index = edge_index[1]

    deg = degree(index=dst_index, num_nodes=num_nodes).unsqueeze(1)  # Nx1

    classes_one_hot = one_hot(class_labels)  # NxC
    num_classes = classes_one_hot.shape[1]

    # Class histogram of the in-neighbors of each node, normalized by the
    # in-degree to get a probability distribution. `dim_size` forces one row
    # per node: without it `scatter` only produces max(dst_index) + 1 rows,
    # which does not line up with `deg` when the last nodes have no
    # incoming edge.
    dist = scatter(
        src=classes_one_hot[src_index],
        index=dst_index,
        dim=0,
        dim_size=deg.size(0),
    ) / deg  # NxC

    # Every node belongs to exactly one class, so checking all rows at once
    # is equivalent to checking every class block separately.
    expected_ones = torch.ones(dist.shape[0], dtype=dist.dtype,
                               device=dist.device)
    assert torch.allclose(dist.sum(1), expected_ones), \
        'every node must have at least one incoming edge'

    # Neighborhood distributions grouped by class, computed once instead of
    # once per (c, c') pair.
    dist_per_class = [dist[class_labels == c] for c in range(num_classes)]

    ccns = torch.zeros(num_classes, num_classes, device=dist.device)

    for c, dist_c in enumerate(dist_per_class):  # NVC x num_classes
        for c_prime, dist_c_prime in enumerate(dist_per_class):  # NVC' x C
            # Exact (non matrix-multiplication based) pairwise L2 distances
            # between nodes of class c and nodes of class c'.
            pairwise = cdist(
                dist_c,
                dist_c_prime,
                p=2.0,
                compute_mode="donot_use_mm_for_euclid_dist",
            )
            # Average over the NVC * NVC' pairs (NaN if either class is
            # empty, as in sum / 0).
            ccns[c, c_prime] = pairwise.mean()

    return ccns



def monte_carlo_ccns(num_classes: int,
                     num_samples: int,
                     epsilon: float,
                     p_c: torch.Tensor,
                     p_m_given_c: torch.Tensor,
                     gaussian_mean: torch.Tensor,
                     gaussian_std: torch.Tensor) -> torch.Tensor:
    """
    It estimates the expectation of Equation 1 using a q_c and q_c' computed
    using the normalized version of Proposition 3.3 (i.e. `M_x` called with
    `normalize=True`).

    For every pair (c, c'), `num_samples` points are drawn from p(x|c) and
    `num_samples` from p(x|c'). The i-th point of one batch is paired with
    the i-th point of the other, and the L2 distances between their q
    vectors are averaged.

    :param num_classes: the value C of the number of classes
    :param num_samples: the number of samples **for each class** to use to
        approximate CCNS(c,c').
    :param epsilon: the length of the hypercube centered at x
    :param p_c: the prior class distribution p(c) estimated from the dataset.
        It is a vector of shape (C,), where C is the number of classes
    :param p_m_given_c: the weight vector of the class-conditional mixture
        p(m|c). It is a tensor of shape (C,M), where M is the number of
        mixtures.
    :param gaussian_mean: a tensor of shape (C,M,D) containing the means of
        the gaussian distributions associated with the different classes,
        mixtures and features
    :param gaussian_std: a tensor of shape (C,M,D) containing the standard
        deviations of the gaussian distributions associated with the different
        classes, mixtures and features

    :return: an estimate of the ccns as a tensor of size CxC
    """
    assert len(gaussian_std.shape) < 4, 'the theory of Section 3.1 assumes a ' \
                                        'diagonal covariance matrix'

    def sample(class_id: int) -> torch.Tensor:
        """Draws `num_samples` points from the class-conditional GMM p(x|c)."""
        # With a diagonal covariance, the D features of a mixture component
        # are independent Normals. A product of univariate Normals over the
        # last dimension is therefore exactly the Gaussian with covariance
        # diag(std ** 2). `scale` is the standard deviation, not the variance.
        component = Independent(Normal(loc=gaussian_mean[class_id],
                                       scale=gaussian_std[class_id]), 1)
        gmm = MixtureSameFamily(Categorical(p_m_given_c[class_id]), component)
        return gmm.sample((num_samples,))  # num_samples x D

    def q(X: torch.Tensor) -> torch.Tensor:
        """Normalized class distribution of Proposition 3.3 for points X."""
        # NOTE(review): presumably returns a (num_samples, C) tensor; the
        # norm over dim=1 below relies on rows being per-sample.
        return M_x(X=X,
                   epsilon=epsilon,
                   p_c=p_c,
                   p_m_given_c=p_m_given_c,
                   gaussian_mean=gaussian_mean,
                   gaussian_std=gaussian_std,
                   normalize=True)

    ccns_monte_carlo = torch.zeros(num_classes, num_classes)

    for c in range(num_classes):
        for c_prime in range(num_classes):
            # Fresh, independent batches for every pair. On the diagonal
            # (c == c') this compares two independent batches of the same
            # class rather than a batch with itself.
            q_c = q(sample(c))
            q_c_prime = q(sample(c_prime))

            ccns_monte_carlo[c, c_prime] += \
                torch.norm(q_c - q_c_prime, p=2, dim=1).sum() / num_samples

    return ccns_monte_carlo