import torch
import torch.nn.functional as F
import numpy as np

from math import pi


class GaussianMixture(torch.nn.Module):
    """
    Fits a mixture of k=1,..,K Gaussians with diagonal covariance to the input data. Input tensors are expected to be flat with dimensions (n: number of samples, d: number of features).
    The model then extends them to (n, k: number of components, d).
    The model parametrization (mu, var) is stored as (1, k, d), and probabilities are shaped (n, k, 1) if they relate to an individual sample, or (1, k, 1) if they assign membership probabilities to one of the mixture components.
    """
    def __init__(self, n_components, mu_init=None, var_init=None, eps=1.e-6, n_features=1):
        """
        Initializes the model and brings all tensors into their required shape. The class expects data to be fed as a flat tensor in (n, d). The class owns:
            mu:             torch.nn.Parameter (1, k, d)
            var:            torch.nn.Parameter (1, k, d)
            pi:             torch.nn.Parameter (1, k, 1)
            eps:            float
            n_components:   int
            n_features:     int
            score:          float / 0-d torch.Tensor (log-likelihood of the last fit)
        args:
            n_components:   int
            mu_init:        torch.Tensor (1, k, d), optional; standard-normal samples if omitted
            var_init:       torch.Tensor (1, k, d), optional; ones if omitted
            eps:            float, added to denominators for numerical stability
            n_features:     int, number of features d (default 1, which forward() requires)
        """

        super(GaussianMixture, self).__init__()
        self.eps = eps
        self.n_components = n_components
        self.n_features = n_features
        self.score = -np.inf

        if mu_init is not None:
            assert mu_init.size() == (1, n_components, n_features), "Input mu_init does not have required tensor dimensions (1, %i, %i)" % (n_components, n_features)
            # (1, k, d)
            self.mu = torch.nn.Parameter(mu_init, requires_grad=False)
        else:
            self.mu = torch.nn.Parameter(torch.randn(1, n_components, n_features), requires_grad=False)

        if var_init is not None:
            assert var_init.size() == (1, n_components, n_features), "Input var_init does not have required tensor dimensions (1, %i, %i)" % (n_components, n_features)
            # (1, k, d)
            self.var = torch.nn.Parameter(var_init, requires_grad=False)
        else:
            self.var = torch.nn.Parameter(torch.ones(1, n_components, n_features), requires_grad=False)

        # (1, k, 1): start from uniform component weights
        self.pi = torch.nn.Parameter(torch.full((1, n_components, 1), 1. / n_components), requires_grad=False)

    def forward(self, like):
        """
        Broadcasts the mixture parameters over a batch. Only meaningful for n_features == 1, because the feature axis is squeezed away.
        args:
            like:       torch.Tensor whose first dimension gives the batch size n
        returns:
            pi:         torch.Tensor (n, k), softmax over the stored component weights
            mu:         torch.Tensor (n, k)
            sigma:      torch.Tensor (n, k), square root of the stored variances
        """
        # NOTE(review): self.pi already holds normalized probabilities after fit(); applying
        # softmax on top of them flattens the weights — confirm that pi is meant as logits here.
        pi = F.softmax(self.pi.squeeze(-1), dim=-1).expand(like.shape[0], self.n_components)
        mu = self.mu.squeeze(-1).expand(like.shape[0], self.n_components)
        sigma = torch.sqrt(self.var).squeeze(-1).expand(like.shape[0], self.n_components)
        return pi, mu, sigma

    def fit(self, x, n_iter=1000, delta=1e-8):
        """
        Public method that fits data to the model with the EM algorithm.
        Iterates until at most n_iter iterations have run or the log-likelihood improves by less than delta.
        args:
            x:          torch.Tensor (n, d) or (n, k, d)
            n_iter:     int, maximum number of EM iterations
            delta:      float, convergence threshold on the log-likelihood improvement
        """

        if len(x.size()) == 2:
            # (n, d) --> (n, k, d)
            x = x.unsqueeze(1).expand(x.size(0), self.n_components, x.size(1))

        i = 0
        j = np.inf

        while (i < n_iter) and (j >= delta):

            old_score = self.score
            # Snapshot by value: the __update_* helpers write through `.data`, so keeping
            # plain references would alias the freshly updated parameters.
            old_pi = self.pi.data.clone()
            old_mu = self.mu.data.clone()
            old_var = self.var.data.clone()

            self.__em(x)
            self.score = self.__score(self.pi, self.__p_k(x, self.mu, self.var))
            i += 1

            if not torch.isfinite(self.score):
                # The log-likelihood became inf/nan (e.g. a collapsed variance):
                # restart from freshly initialized parameters and keep iterating.
                self.__reset_params()
                self.score = -np.inf
                j = np.inf
                continue

            j = self.score - old_score

            if j < 0:
                # The score decreased: revert to the previous parameters and score.
                self.__update_pi(old_pi)
                self.__update_mu(old_mu)
                self.__update_var(old_var)
                self.score = old_score


    def predict(self, x, probs=False):
        """
        Assigns input data to one of the mixture components by evaluating the likelihood under each. If probs=True returns normalized probabilities of class membership instead.
        args:
            x:          torch.Tensor (n, d) or (n, k, d)
            probs:      bool
        returns:
            y:          torch.LongTensor (n), or torch.Tensor (n, k, 1) of probabilities if probs=True
        """

        if len(x.size()) == 2:
            # (n, d) --> (n, k, d)
            x = x.unsqueeze(1).expand(x.size(0), self.n_components, x.size(1))

        p_k = self.__p_k(x, self.mu, self.var)
        if probs:
            return p_k / (p_k.sum(1, keepdim=True) + self.eps)
        # Indices come back as (n, 1); squeeze only that axis so n == 1 still yields shape (1,).
        _, predictions = torch.max(p_k, 1)
        return predictions.squeeze(1).type(torch.LongTensor)


    def __p_k(self, x, mu, var):
        """
        Returns a tensor with dimensions (n, k, 1) indicating the likelihood of data belonging to the k-th Gaussian.
        args:
            x:      torch.Tensor (n, k, d)
            mu:     torch.Tensor (1, k, d)
            var:    torch.Tensor (1, k, d)
        returns:
            p_k:    torch.Tensor (n, k, 1)
        """

        # (1, k, d) --> (n, k, d)
        mu = mu.expand(x.size(0), self.n_components, self.n_features)
        var = var.expand(x.size(0), self.n_components, self.n_features)

        # (n, k, d) --> (n, k, 1)
        exponent = torch.exp(-.5 * torch.sum((x - mu) * (x - mu) / var, 2, keepdim=True))
        # (n, k, d) --> (n, k, 1); `pi` here is the math constant, not the mixture weights
        prefactor = torch.rsqrt(((2. * pi) ** self.n_features) * torch.prod(var, dim=2, keepdim=True) + self.eps)

        return prefactor * exponent


    def __e_step(self, pi, p_k):
        """
        Computes weights that indicate the probabilistic belief that a data point was generated by one of the k mixture components. This is the so-called expectation step of the EM-algorithm.
        args:
            pi:         torch.Tensor (1, k, 1)
            p_k:        torch.Tensor (n, k, 1)
        returns:
            weights:    torch.Tensor (n, k, 1)
        """

        weights = pi * p_k
        return torch.div(weights, torch.sum(weights, 1, keepdim=True) + self.eps)


    def __m_step(self, x, weights):
        """
        Updates the model's parameters. This is the maximization step of the EM-algorithm.
        args:
            x:          torch.Tensor (n, k, d)
            weights:    torch.Tensor (n, k, 1)
        returns:
            pi_new:     torch.Tensor (1, k, 1)
            mu_new:     torch.Tensor (1, k, d)
            var_new:    torch.Tensor (1, k, d)
        """

        # (n, k, 1) --> (1, k, 1)
        n_k = torch.sum(weights, 0, keepdim=True)
        pi_new = torch.div(n_k, torch.sum(n_k, 1, keepdim=True) + self.eps)
        # (n, k, d) --> (1, k, d)
        mu_new = torch.div(torch.sum(weights * x, 0, keepdim=True), n_k + self.eps)
        # (n, k, d) --> (1, k, d)
        var_new = torch.div(torch.sum(weights * (x - mu_new) * (x - mu_new), 0, keepdim=True), n_k + self.eps)

        return pi_new, mu_new, var_new


    def __em(self, x):
        """
        Performs one iteration of the expectation-maximization algorithm by calling the respective subroutines.
        args:
            x:          torch.Tensor (n, k, d)
        """

        weights = self.__e_step(self.pi, self.__p_k(x, self.mu, self.var))
        pi_new, mu_new, var_new = self.__m_step(x, weights)

        self.__update_pi(pi_new)
        self.__update_mu(mu_new)
        self.__update_var(var_new)


    def __score(self, pi, p_k):
        """
        Computes the log-likelihood of the data under the model.
        args:
            pi:         torch.Tensor (1, k, 1)
            p_k:        torch.Tensor (n, k, 1)
        returns:
            score:      0-d torch.Tensor
        """

        weights = pi * p_k
        return torch.sum(torch.log(torch.sum(weights, 1) + self.eps))


    def __reset_params(self):
        """
        Re-draws mu from a standard normal and resets var to ones and pi to uniform, in place.
        Shapes, dtypes and devices of the existing parameters are preserved.
        """

        self.mu.data = torch.randn_like(self.mu)
        self.var.data = torch.ones_like(self.var)
        self.pi.data = torch.full_like(self.pi, 1. / self.n_components)


    def __update_mu(self, mu):
        """
        Updates mean to the provided value.
        args:
            mu:         torch.FloatTensor (k, d) or (1, k, d)
        """

        assert mu.size() in [(self.n_components, self.n_features), (1, self.n_components, self.n_features)], "Input mu does not have required tensor dimensions (%i, %i) or (1, %i, %i)" % (self.n_components, self.n_features, self.n_components, self.n_features)

        # Write through `.data`: assigning a plain tensor to a registered Parameter raises TypeError.
        if mu.size() == (self.n_components, self.n_features):
            self.mu.data = mu.unsqueeze(0)
        elif mu.size() == (1, self.n_components, self.n_features):
            self.mu.data = mu


    def __update_var(self, var):
        """
        Updates variance to the provided value.
        args:
            var:        torch.FloatTensor (k, d) or (1, k, d)
        """

        assert var.size() in [(self.n_components, self.n_features), (1, self.n_components, self.n_features)], "Input var does not have required tensor dimensions (%i, %i) or (1, %i, %i)" % (self.n_components, self.n_features, self.n_components, self.n_features)

        # Write through `.data`: assigning a plain tensor to a registered Parameter raises TypeError.
        if var.size() == (self.n_components, self.n_features):
            self.var.data = var.unsqueeze(0)
        elif var.size() == (1, self.n_components, self.n_features):
            self.var.data = var


    def __update_pi(self, pi):
        """
        Updates pi to the provided value.
        args:
            pi:         torch.FloatTensor (1, k, 1)
        """

        assert pi.size() in [(1, self.n_components, 1)], "Input pi does not have required tensor dimensions (%i, %i, %i)" % (1, self.n_components, 1)

        self.pi.data = pi