import torch
from typing import Tuple, List, Optional

from torch.distributions import (
    MixtureSameFamily,
    Categorical,
    Independent,
    Normal, MultivariateNormal,
)
from torch.nn.functional import softplus, normalize
from torch.nn.parameter import Parameter

from SED import SED, SED_full
from M_c import M_c
from mixture import sum_of_mixtures, mixture_times_constant


# Floor added to every |std| in ClassSeparator.get_parameters so Gaussian
# scales stay strictly positive; the `remove_std_min` options subtract it back.
STD_MIN = 1.0


class ClassSeparator(torch.nn.Module):
    """
    Learnable per-class Gaussian mixture model over an input space X, plus an
    embedding space H whose components are widened by a learned isotropic
    term sigma^2 / (relu(k) + 1).
    """

    def __init__(
        self,
        num_classes: int,
        num_mixtures: int,
        num_features: int,
        use_full_covariance: bool,
    ):
        r"""
        Instantiates a model that tries to maximize, over the
        (|C|*(|C|+1))/2 - |C| pairs of distinct classes c,c', the quantity
            \sum_{c,c'} SED(p(h|c), p(h|c')) - SED(p(x|c), p(x|c'))

        :param num_classes: the parameter |C| in the paper
        :param num_mixtures: the parameter |M| in the paper
        :param num_features: the parameter D in the paper
        :param use_full_covariance: whether each mixture component uses a
            full DxD covariance (parameterized as std @ std^T) instead of a
            diagonal one
        """
        super().__init__()
        self.C = num_classes
        self.M = num_mixtures
        self.D = num_features

        self.use_full_covariance = use_full_covariance

        self.k = Parameter(
            torch.ones(1, dtype=torch.double), requires_grad=True
        )

        self.epsilon = Parameter(
            torch.ones(1, dtype=torch.double), requires_grad=True
        )  # initialized to 1

        self.sigma = Parameter(
            torch.ones(1, dtype=torch.double), requires_grad=True
        )  # initialized to 1

        # Learned class prior p(c): random init, normalized to sum to 1.
        # (An earlier variant kept it fixed and uniform.)
        self.prior = Parameter(
            torch.rand(self.C, dtype=torch.double), requires_grad=True
        )
        self.prior.data = self.prior.data / self.prior.data.sum(
            0, keepdim=True
        )

        # Mixture weights p(m|c): one row per class, each row sums to 1.
        self.weights = Parameter(
            torch.rand(self.C, self.M, dtype=torch.double)
        )
        self.weights.data = self.weights.data / self.weights.data.sum(
            1, keepdim=True
        )

        # Component means, initialized uniformly in [0, 100).
        self.mean = Parameter(
            torch.rand(self.C, self.M, self.D, dtype=torch.double) * 100.0
        )

        if not self.use_full_covariance:
            # Per-dimension standard deviations (diagonal covariance).
            self.std = Parameter(
                torch.rand(self.C, self.M, self.D, dtype=torch.double) * 100.0
            )
        else:
            # One DxD factor per component; the covariance is std @ std^T.
            self.std = Parameter(
                torch.rand(self.C, self.M, self.D, self.D, dtype=torch.double)
                * 100.0
            )

    def get_parameters(
        self,
    ) -> Tuple[
        torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
    ]:
        """
        Maps the raw parameters to valid distribution parameters.

        :return: (prior, weights, mean, std, epsilon) where prior and every
            row of weights are made non-negative and L1-normalized, std is
            |self.std| + STD_MIN and epsilon is |self.epsilon|
        """
        epsilon = torch.abs(self.epsilon)
        prior = normalize(torch.abs(self.prior), p=1, dim=-1)
        weights = normalize(torch.abs(self.weights), p=1, dim=-1)
        mean = self.mean
        std = torch.abs(self.std) + STD_MIN
        return prior, weights, mean, std, epsilon

    def _lambda_epsilon(self) -> torch.Tensor:
        """
        Returns a (C, M, D) tensor filled with sigma^2 / (relu(k) + 1): the
        per-dimension term that widens the X components into H components.
        Shared by the SED and likelihood code so both use the same H.
        """
        return (
            torch.ones(
                self.C,
                self.M,
                self.D,
                dtype=torch.double,
                device=self.sigma.device,
            )
            * torch.pow(self.sigma, 2)
            / (torch.relu(self.k) + 1)
        )

    @staticmethod
    def _covariance(std: torch.Tensor) -> torch.Tensor:
        """Builds covariances std @ std^T from (..., D, D) factors."""
        return torch.matmul(std, std.transpose(-2, -1))

    def compute_SED_X(self, c, c_prime) -> torch.Tensor:
        """
        Compute SED(X_c,X_c'), the SED between the input-space mixtures of
        two classes.

        :param c: the first class
        :param c_prime: the second class
        :return: a tensor with a single value representing SED(X_c,X_c')
        """
        _, weights, mean, std, _ = self.get_parameters()

        if not self.use_full_covariance:
            return SED(
                mixture_weights_x=weights[c],
                gaussian_mean_x=mean[c],
                gaussian_std_x=std[c],
                mixture_weights_y=weights[c_prime],
                gaussian_mean_y=mean[c_prime],
                gaussian_std_y=std[c_prime],
            )

        sigma = self._covariance(std)
        return SED_full(
            mixture_weights_x=weights[c],
            gaussian_mean_x=mean[c],
            gaussian_sigma_x=sigma[c],
            mixture_weights_y=weights[c_prime],
            gaussian_mean_y=mean[c_prime],
            gaussian_sigma_y=sigma[c_prime],
        )

    def compute_SED_H(self, c, c_prime) -> torch.Tensor:
        """
        Compute SED(H_c,H_c'), the SED between the embedding-space mixtures
        of two classes.

        :param c: the first class
        :param c_prime: the second class
        :return: a tensor with a single value representing SED(H_c,H_c')
        """
        _, weights, mean, std, _ = self.get_parameters()
        lambda_epsilon = self._lambda_epsilon()

        if not self.use_full_covariance:
            # NOTE(review): this branch adds sqrt(lambda) to the standard
            # deviation, while the full-covariance branch adds lambda to the
            # variance; confirm which one the paper prescribes.
            h_std = std + torch.sqrt(lambda_epsilon)
            return SED(
                mixture_weights_x=weights[c],
                gaussian_mean_x=mean[c],
                gaussian_std_x=h_std[c],
                mixture_weights_y=weights[c_prime],
                gaussian_mean_y=mean[c_prime],
                gaussian_std_y=h_std[c_prime],
            )

        h_sigma = self._covariance(std) + torch.diag_embed(lambda_epsilon)
        return SED_full(
            mixture_weights_x=weights[c],
            gaussian_mean_x=mean[c],
            gaussian_sigma_x=h_sigma[c],
            mixture_weights_y=weights[c_prime],
            gaussian_mean_y=mean[c_prime],
            gaussian_sigma_y=h_sigma[c_prime],
        )

    def compute_CCNS_LB(
        self, remove_std_min: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the lower bound of the CCNS for every pair of classes c,c'.

        :param remove_std_min: whether or not to remove the STD_MIN
            contribution from std
        :return: a CxC matrix with the lower bound of CCNS between pairs of
            classes and a CxC matrix with M_c'(c)
        """
        prior, weights, mean, std, epsilon = self.get_parameters()

        if remove_std_min:
            # Out-of-place, so the autograd graph is not tampered with.
            std = std - STD_MIN

        # compute normalized M_c'(c) for every c'
        # NOTE(review): forward() only calls this in the diagonal case; with
        # use_full_covariance std is (C, M, D, D) -- confirm M_c accepts it.
        m_cc = M_c(
            epsilon=epsilon,
            prior=prior,
            p_m_given_c=weights,
            gaussian_mean=mean,
            gaussian_std=std,
            normalize=True,
        )
        lb = torch.cdist(m_cc, m_cc, p=2)  # |C|x|C|
        return lb, m_cc

    def forward(
        self,
    ) -> Tuple[
        List[Tuple[torch.Tensor, int, int, torch.Tensor, torch.Tensor]],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
    ]:
        """
        Computes the two terms of the inequality in Th. 3.7, so that we can
        optimize wrt their difference.

        :return: a tuple (sed_tuples, lb_ccns, m_cc). sed_tuples contains,
            for every pair of classes c_prime < c, the tuple
            (SED(h_c,h_c') - SED(x_c,x_c'), c_prime, c, SED_h, SED_x), where
            the last two are detached copies on the CPU. lb_ccns and m_cc
            come from compute_CCNS_LB, or are None with full covariance.
        """
        sed_tuples = []

        if not self.use_full_covariance:
            lb_ccns, m_cc = self.compute_CCNS_LB()
        else:
            lb_ccns, m_cc = None, None

        for c_prime in range(self.C):
            for c in range(c_prime + 1, self.C):
                sed_x = self.compute_SED_X(c, c_prime)
                sed_h = self.compute_SED_H(c, c_prime)

                sed_tuples.append(
                    (
                        sed_h - sed_x,
                        c_prime,
                        c,
                        sed_h.detach().cpu(),
                        sed_x.detach().cpu(),
                    )
                )

        return sed_tuples, lb_ccns, m_cc

    def _gmm_log_likelihood(
        self,
        prior: torch.Tensor,
        weights: torch.Tensor,
        mean: torch.Tensor,
        std: torch.Tensor,
        input_tensor: torch.Tensor,
        class_id: Optional[int],
    ) -> torch.Tensor:
        """
        Log-likelihood of input_tensor under diagonal-covariance per-class
        GMMs with the given parameters.

        :param prior: class prior, shape (C,)
        :param weights: mixture weights, shape (C, M)
        :param mean: component means, shape (C, M, D)
        :param std: component standard deviations, shape (C, M, D)
        :param input_tensor: points to score, presumably (N, D)
        :param class_id: if given, score under that class's GMM only
        :return: the log-likelihood of each input point
        """

        def class_gmm(c: int) -> MixtureSameFamily:
            return MixtureSameFamily(
                Categorical(weights[c]),
                Independent(Normal(loc=mean[c], scale=std[c]), 1),
            )

        if class_id is not None:
            return class_gmm(class_id).log_prob(input_tensor)

        # log p(x) = logsumexp_c [log p(c) + log p(x|c)]
        per_class = torch.stack(
            [
                prior[c].log() + class_gmm(c).log_prob(input_tensor)
                for c in range(self.C)
            ],
            dim=1,
        )
        return torch.logsumexp(per_class, dim=1)

    def get_log_likelihood_original_distribution(
        self, input_tensor: torch.Tensor, class_id: Optional[int] = None
    ) -> torch.Tensor:
        """
        Given a set of points, it returns the likelihood of the HNB. If a class
        id is specified, it returns the likelihood of the GMM associated with
        a specific class only (thus ignoring the learned prior distribution)

        :param input_tensor: the data wrt which compute the likelihood
        :param class_id: Optional parameter to specify one of the classes of
            the HNB
        :return: a tensor containing the log-likelihood for each input point
        :raises NotImplementedError: if the model uses full covariances
        """
        if self.use_full_covariance:
            # TODO: build the GMMs with MultivariateNormal components.
            raise NotImplementedError("Full Covariance case to be implemented")

        prior, weights, mean, std, _ = self.get_parameters()
        return self._gmm_log_likelihood(
            prior, weights, mean, std, input_tensor, class_id
        )

    def get_log_likelihood_embedding_distribution(
        self, input_tensor: torch.Tensor, class_id: Optional[int] = None
    ) -> torch.Tensor:
        """
        Given a set of points, it returns the likelihood of the HNB
        corresponding to the embedding space of a DGN. If a class
        id is specified, it returns the likelihood of the GMM associated with
        a specific class only (thus ignoring the learned prior distribution).
        The embedding components are built exactly as in compute_SED_H.

        :param input_tensor: the data wrt which compute the likelihood
        :param class_id: optional parameter to specify to which class the
            neighboring embedding distribution is supposed to belong to
        :return: a tensor containing the log-likelihood for each input point
        :raises NotImplementedError: if the model uses full covariances
        """
        if self.use_full_covariance:
            # TODO: build the GMMs with MultivariateNormal components.
            raise NotImplementedError("Full Covariance case to be implemented")

        prior, weights, mean, std, _ = self.get_parameters()
        h_std = std + torch.sqrt(self._lambda_epsilon())
        return self._gmm_log_likelihood(
            prior, weights, mean, h_std, input_tensor, class_id
        )

    def sample_data(
        self,
        num_samples: int = 10000,
        remove_std_min: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Samples data points from the learned distribution.

        :param num_samples: the number of total points to sample
        :param remove_std_min: whether or not to remove the STD_MIN
            contribution from std
        :return: tuple with X (num_samples, D) and y (num_samples,) tensors
        """
        prior, weights, mean, std, _ = self.get_parameters()

        if remove_std_min:
            # Out-of-place, so the autograd graph is not tampered with.
            std = std - STD_MIN

        y = Categorical(prior).sample(sample_shape=(num_samples,))

        if self.use_full_covariance and self.D > 1:
            sigma = self._covariance(std)
            components = MultivariateNormal(
                loc=mean[y], covariance_matrix=sigma[y]
            )
        else:
            # With full covariance and D == 1, std is (C, M, 1, 1) and
            # std @ std^T == std**2 with std >= 0, so the entry itself is the
            # scale; drop the trailing axis to match mean's (C, M, 1) shape.
            scale = std[..., 0] if self.use_full_covariance else std
            components = Independent(Normal(loc=mean[y], scale=scale[y]), 1)

        gmm = MixtureSameFamily(Categorical(weights[y]), components)
        X = gmm.sample()

        return X, y
