"""
Contains the CLOMP_GMM class to solve the compressive GMM problem.
"""
from typing import NoReturn, Union

import numpy as np
import torch

from pycle.compressive_learning.CLOMP_G import CLOMP_G
from pycle.sketching import FeatureMap
from pycle.utils.projectors import Projector, ProjectorClip, ProjectorNoProjection



def fourierSketchOfGaussian(mu, Sigma, Omega, xi=None, scst=None):
    """
    Evaluate the characteristic function of the Gaussian N(mu, Sigma) at the frequencies in Omega.

    For each frequency omega_j (column j of Omega) this computes

        exp(1j * omega_j^T mu - omega_j^T Sigma omega_j / 2),

    optionally multiplied by exp(1j * xi_j) and by the constant `scst`.

    Parameters
    ----------
    mu
        (d,)-shaped tensor, mean of the Gaussian.
    Sigma
        (d, d)-shaped tensor, covariance matrix of the Gaussian.
    Omega
        (d, m)-shaped tensor whose m columns are the frequency vectors.
    xi
        Optional (m,)-shaped tensor of phase offsets; entry j is multiplied by exp(1j * xi_j).
    scst
        Optional sketch constant (e.g. 1/sqrt(m)) multiplying the whole result.

    Returns
    -------
        (m,)-shaped complex tensor.
    """
    phase = torch.matmul(mu, Omega)
    # Only the diagonal of Omega^T Sigma Omega is needed (entry j is omega_j^T Sigma omega_j).
    # The einsum computes it row by row without materializing the full (m, m) matrix.
    quad = torch.einsum('ij,ij->i', torch.matmul(Omega.T, Sigma), Omega.T)
    res = torch.exp(1j * phase - quad / 2.)
    if xi is not None:
        res = res * torch.exp(1j * xi)
    if scst is not None:  # Sketch constant, eg 1/sqrt(m)
        res = scst * res
    return res


class CLOMP_dGMM(CLOMP_G):
    """
    Instantiate a CLOMP solver for compressive learning of Gaussian mixture models with diagonal
    covariance matrices.

    Each mixture component is a Gaussian N(mu_k, diag(sigma2_k)) in dimension d = phi.d.

    The definitions are specialized, more precise, less abstract, now:

    - The parameters of the component k of the mixture model are denoted theta_k. theta_k is the
      concatenation (mu_k, sigma2_k) of the mean (first d entries) and of the diagonal of the
      covariance matrix (last d entries).
    - The dimension of each parameter vector is D = 2 * d, where d is the dimension of the data.
    - The sketched component k of the mixture model is denoted phi_theta_k. It is computed in closed
      form from the frequencies `phi._Omega`, the phase offsets `phi.xi` and the normalization
      `phi.c_norm` of the feature map (see `sketch_of_mixture_components`).
    - The size of the mixture is K, the number of Gaussian components.
    - The alphas are the weights of the mixture, i.e. the importance of each component.
    - The solution is the pair (alphas, thetas), that is, all the parameters of the mixture.
    """

    # Range in which the diagonal covariance entries are drawn at random initialization.
    init_variance_low = 0.001
    init_variance_high = 1.0

    def __init__(self, phi: FeatureMap, centroid_projector: Projector = ProjectorNoProjection(), *args, **kwargs):
        """
        Parameters
        ----------
        phi
            The feature map used in the sketching operator. Its attributes `d`, `_Omega`, `xi` and
            `c_norm` are used to compute the sketch of each Gaussian component.
        centroid_projector:
            A callback projector object to call on the parameters theta at the end of each iteration
            in order to enforce some constraints. See :mod:`pycle.utils.projectors`
        args
        kwargs
            Forwarded to the `CLOMP_G` constructor.
        """

        # Lower and upper bounds are for random initialization
        self.lower_bounds = None
        self.upper_bounds = None

        # Each parameter vector holds a mean and a diagonal covariance, hence D = 2 * d.
        super().__init__(phi, thetas_dimension_D=2 * phi.d, *args, **kwargs)

        assert isinstance(centroid_projector, Projector)
        self.centroid_projector = centroid_projector
        if isinstance(self.centroid_projector, ProjectorClip):
            # note that bounds for projection and bounds for initialization are different entities
            self.centroid_projector.lower_bound = self.centroid_projector.lower_bound.to(self.real_dtype).to(self.device)
            self.centroid_projector.upper_bound = self.centroid_projector.upper_bound.to(self.real_dtype).to(self.device)

    def sketch_of_mixture_components(self, thetas: torch.Tensor) -> torch.Tensor:
        """
        Computes and returns phi(theta_k) for each Gaussian component in thetas.

        D = 2d is the dimension of a parameter vector, M is the dimension of a sketch.
        For a component with mean mu and diagonal covariance sigma2, entry j of its atom is

            c_norm * exp(1j * xi_j) * exp(1j * omega_j^T mu - sum_i omega_ij^2 * sigma2_i / 2),

        where omega_j is column j of `phi._Omega`. The `exp(1j * xi_j)` factor is applied only when
        `phi.xi` is not None, and the `c_norm` factor only when `phi.c_norm` is not None.

        Parameters
        ----------
        thetas
            (D,) or (current_size_mixture, D)-shaped tensor. In each row, the first d entries are the
            mean and the last d entries are the diagonal of the covariance. Extra leading batch
            dimensions are also accepted.

        Returns
        -------
            (M,) or (current_size_mixture, M)-shaped complex128 tensor containing the M-dimensional
            sketches of the components, e.g. the atoms. It lives on the same device as the inputs.
        """
        assert thetas.size()[-1] == self.thetas_dimension_D
        d = self.phi.d
        omega = self.phi._Omega  # (d, M) frequency matrix
        mus = thetas[..., :d]
        sigmas2 = thetas[..., -d:]

        # With a diagonal covariance, omega_j^T diag(sigma2) omega_j = sum_i omega_ij^2 * sigma2_i.
        # All components are therefore handled with two matrix products instead of a Python loop,
        # and the result stays on the input device (the former loop filled a CPU buffer).
        phase = torch.matmul(mus, omega)
        quad = torch.matmul(sigmas2, omega ** 2)
        atoms = torch.exp(1j * phase - quad / 2.)
        if self.phi.xi is not None:
            atoms = atoms * torch.exp(1j * self.phi.xi)
        if self.phi.c_norm is not None:
            atoms = self.phi.c_norm * atoms
        # The atoms have always been returned as complex128; keep that dtype for callers.
        return atoms.to(torch.complex128)

    def set_bounds_thetas(self, bounds: torch.Tensor) -> None:
        """
        Set the bounds where the parameters theta can be found.

        These bounds can be used for initializing new components
        and for setting bounds to the optimization procedure.
        `randomly_initialize_several_mixture_components` draws the means between the first d
        entries of `bounds[0]` and `bounds[1]`.

        Parameters
        ----------
        bounds
            (2, n)-shaped tensor containing the lower bounds in position 0 and upper bounds in
            position 1, with n >= d.
            NOTE(review): the historical contract states n = D, while the attribute is commented as
            "data bounds" (n = d). Confirm what callers pass.
        """
        assert len(bounds) == 2
        self.lower_bounds = bounds[0].to(self.real_dtype).to(self.device)
        self.upper_bounds = bounds[1].to(self.real_dtype).to(self.device)
        self.bounds = bounds  # data bounds
        # self.bounds_atom =
        # [[lowerbound_1, upperbound_1],
        #  ...,
        #  [lowerbound_n, upperbound_n]]
        self.bounds_atom = bounds.T.tolist()

    def randomly_initialize_several_mixture_components(self, nb_mixture_components: int) -> torch.Tensor:
        """
        Randomly initialize several Gaussian components.

        Means are drawn uniformly between the first d lower and upper bounds given to
        `set_bounds_thetas`. The diagonal covariance entries are drawn uniformly in
        [init_variance_low, init_variance_high).

        The torch random generator is used, so `torch.manual_seed` makes the initialization
        reproducible.

        Parameters
        ----------
        nb_mixture_components
            The number of components to initialize.

        Returns
        -------
            (nb_mixture_components, D)-shaped tensor of dtype `self.real_dtype` on `self.device`.

        Raises
        ------
        RuntimeError
            If the bounds have not been set with `set_bounds_thetas`.
        """
        bounds = getattr(self, "bounds", None)
        if bounds is None:
            raise RuntimeError("Bounds must be set with set_bounds_thetas before random initialization.")

        d_2 = self.thetas_dimension_D // 2
        tensor_kwargs = {"dtype": self.real_dtype, "device": self.device}
        lower = torch.as_tensor(bounds[0][:d_2], **tensor_kwargs)
        upper = torch.as_tensor(bounds[1][:d_2], **tensor_kwargs)

        mus = lower + (upper - lower) * torch.rand(nb_mixture_components, d_2, **tensor_kwargs)
        variance_span = self.init_variance_high - self.init_variance_low
        sigmas2 = self.init_variance_low + variance_span * torch.rand(nb_mixture_components, d_2, **tensor_kwargs)
        return torch.cat([mus, sigmas2], dim=1)

    def projection_step(self, thetas: torch.Tensor) -> None:
        """
        Project a parameter vector theta (or a set of thetas) on the constraint specified
        by self.centroid_projector of class `Projector`.

        The modification is made in place.

        Parameters
        ----------
        thetas
            The parameters to project
        """
        if self.centroid_projector is not None:
            self.centroid_projector.project(thetas)

    def get_centroids(self, return_numpy=False) -> Union[torch.Tensor, np.ndarray]:
        """
        Return the parameters of the mixture components.

        Each row is the full parameter vector theta_k. It holds the mean (first d entries) followed by
        the diagonal covariance (last d entries), not only the mean.

        Parameters
        ----------
        return_numpy
            If True, return as numpy array (default: False).

        Returns
        -------
            (current_size_mixture, D) shaped tensor or ndarray containing the component parameters.
        """
        centroids = self.thetas[:, 0:self.thetas_dimension_D]
        if return_numpy:
            # detach() so the conversion also works when thetas is tracked by autograd.
            return centroids.detach().cpu().numpy()
        return centroids
