import numpy as np
import torch
from loguru import logger

from pycle.sketching import MatrixFeatureMap, FeatureMap, computeSketch




def fourierSketchOfGaussian(mu, Sigma, Omega, xi=None, scst=None):
    """Closed-form complex-exponential sketch of a single Gaussian N(mu, Sigma).

    For every frequency column omega_l of Omega this evaluates
        exp(1j * mu^T omega_l - omega_l^T Sigma omega_l / 2),
    optionally multiplied by a dither phase exp(1j * xi) and by a constant.

    Parameters
    ----------
    mu: (d,)-numpy array, mean of the Gaussian
    Sigma: (d,d)-numpy array, covariance matrix of the Gaussian
    Omega: (d,m)-numpy array, frequency matrix (one frequency per column)
    xi: optional (m,)-array of dither phases
    scst: optional sketch constant (e.g. 1/sqrt(m)), applied last

    Returns
    -------
    (m,)-complex numpy array
    """
    omega_rows = Omega.T  # (m, d): one frequency per row
    # Row l of (Omega^T Sigma) dotted with omega_l gives omega_l^T Sigma omega_l.
    quad_form = np.einsum('ij,ij->i', np.dot(omega_rows, Sigma), omega_rows)
    exponent = 1j * (mu @ Omega) - quad_form / 2.
    sketch = np.exp(exponent)
    if xi is not None:
        sketch = sketch * np.exp(1j * xi)
    if scst is not None:  # Sketch constant, eg 1/sqrt(m)
        sketch = scst * sketch
    return sketch


def fourier_sketch_of_gaussianS(muS, SigmaS, Omega, xi=None, scst=None, use_torch=True):
    """Closed-form complex-exponential sketch of one or several Gaussians.

    For every frequency column omega_l of Omega this evaluates
        exp(1j * mu^T omega_l - 0.5 * omega_l^T Sigma omega_l)
    (Eq 15: the "left hand" mean term and the "right hand" covariance term),
    optionally multiplied by a dither term exp(1j * xi) and a sketch constant.

    Parameters
    ----------
    muS: (d,) mean of a single Gaussian, or (K, d) batch of means.
    SigmaS: covariance(s), either
        - diagonal: same number of dims as muS, i.e. (d,) or (K, d), holding variances;
        - full: one more dim than muS, i.e. (d, d) or (K, d, d).
    Omega: (d, m) frequency matrix.
    xi: optional (m,) dither; if given the result is multiplied by exp(1j * xi).
    scst: optional sketch constant (e.g. 1/sqrt(m)), multiplied in last.
    use_torch: compute with torch (True) or numpy (False). All array arguments
        must belong to the chosen backend and share a dtype.

    Returns
    -------
    Complex array/tensor of shape (m,) for a single Gaussian or (K, m) for a batch.
    """
    if use_torch:
        backend = torch
        dim_name = "dim"
    else:
        backend = np
        dim_name = "axis"
    reduce_over_d = {dim_name: -2}

    if SigmaS.ndim == muS.ndim + 1:
        # Full covariance: sum_{i,j} omega_il * Sigma_ij * omega_jl for each column l.
        quad = backend.sum((SigmaS @ Omega) * Omega, **reduce_over_d)
    else:
        # Diagonal covariance (variances): sum_j sigma_j * omega_jl^2 for each column l.
        quad = backend.sum(Omega * (SigmaS[..., np.newaxis] * Omega), **reduce_over_d)

    # Multiplication between the frequencies and the means (left hand part of Eq 15).
    left_hand = 1j * (muS @ Omega)
    right_hand = -0.5 * quad
    # Adding the contents of an exp is like multiplying the exps.
    result = backend.exp(left_hand + right_hand)

    if xi is not None:
        result = result * backend.exp(1j * xi)
    if scst is not None:  # Sketch constant, eg 1/sqrt(m)
        result = scst * result
    return result


def fourierSketchOfGMM(GMM, featureMap):
    """Returns the complex exponential sketch of a Gaussian Mixture Model

    The sketch is the weighted sum over the K components of the closed-form
    sketch of each Gaussian, computed by fourier_sketch_of_gaussianS with the
    torch backend.

    Parameters
    ----------
    GMM: (weights, means, covariances) tuple, the Gaussian Mixture Model, with
        - weights:     (K,) torch tensor containing the weighting factors of the Gaussians
        - means:       (K,d) torch tensor containing the means of the Gaussians
        - covariances: (K,d,d) torch tensor containing the covariance matrices of the Gaussians
        The weights are read only; they are not modified.
    featureMap: the sketch feature map (Phi), provided as either:
        - a MatrixFeatureMap object (its Omega, xi, _m and c_norm attributes are used)
        - (Omega, xi): tuple with the (d,m) Fourier projection matrix and the (m,) dither;
          the sketch constant is then 1.

    Returns
    -------
    z: (m,) complex torch tensor containing the sketch of the provided GMM

    Raises
    ------
    ValueError: if featureMap is neither a MatrixFeatureMap nor a tuple.
    """
    # Parse GMM input
    (w, mus, Sigmas) = GMM
    K = len(w)  # works for torch tensors and numpy arrays alike

    # Parse featureMap input
    if isinstance(featureMap, MatrixFeatureMap):
        Omega = featureMap.Omega
        xi = featureMap.xi
        m = featureMap._m
        scst = featureMap.c_norm  # Sketch normalization constant, e.g. 1/sqrt(m)
    elif isinstance(featureMap, tuple):
        (Omega, xi) = featureMap
        m = Omega.shape[1]
        scst = 1.  # This type of argument passing doesn't support different normalizations
    else:
        raise ValueError('The featureMap argument does not match one of the supported formats.')

    # The sketch is linear, so the mixture sketch is the weighted sum of the
    # component sketches; use the real weights w[k] (never overwrite them).
    z = 1j * torch.zeros(m)
    for k in range(K):
        z = z + w[k] * fourier_sketch_of_gaussianS(mus[k], Sigmas[k], Omega, xi, scst, use_torch=True)
    return z
