import numpy as np
import numpy.matlib
import math
import plotly.graph_objs as go
from scipy.special import gammaln, logsumexp
from scipy.stats import multivariate_normal, multivariate_t
import numba
from functools import lru_cache
from joblib import Parallel, delayed
import multiprocessing
from numba import jit, njit, prange
import time
import threading
from numpy.polynomial.hermite import hermgauss

############################### log Functions ##################################
def log_Taylor_series(N, a, xn):
    """Approximate log(x) with the order-N Taylor expansion about ``a``.

    The series ``log(a) + sum_{n=1}^{N} (-1)**(n-1) / (n * a**n) * (x - a)**n``
    is rewritten in powers of x by expanding ``(x - a)**n`` with the binomial
    theorem, so only the rows ``xn[k, :]`` (x**k, or its expectation) are used.

    Parameters
    ----------
    N : int
        Expansion order.
    a : float
        Expansion point; must be positive for ``np.log(a)``.
    xn : numpy.ndarray
        Row k holds x**k (or E[x**k]) for k = 0..N.

    Returns
    -------
    ``np.log(a)`` when N == 0, otherwise an array shaped like ``xn[k, :]``.
    """
    approx = np.log(a)
    for order in range(1, N + 1):
        # Taylor coefficient multiplying (x - a)**order.
        taylor_coef = ((-1) ** (order - 1)) / (order * (a ** order))
        for power in range(order + 1):
            # Binomial piece C(order, power) * (-a)**(order - power) * x**power.
            approx += taylor_coef * math.comb(order, power) * ((-a) ** (order - power) * xn[power, :])
    return approx

def log_Taylor_limit(N, a, xn):
    """Approximate log(x) by a Shanks (Aitken delta-squared) accelerated Taylor series.

    The partial sums S_0..S_N of the Taylor expansion of log about ``a``
    (the same series as ``log_Taylor_series``) are combined as
    ``S_{N-2} - (S_{N-1} - S_{N-2})**2 / (S_N - 2 S_{N-1} + S_{N-2})``.

    Parameters
    ----------
    N : int
        Expansion order. Values below 2 fall back to the plain Taylor series,
        because the transform needs three partial sums.
    a : float
        Expansion point; must be positive.
    xn : array_like, shape (>= N+1, M)
        Row k holds x**k (or E[x**k]) for k = 0..N. Each column is treated as
        an independent evaluation.

    Returns
    -------
    A scalar when ``xn[k, :]`` holds a single value (as before), otherwise an
    array shaped like ``xn[k, :]``. Where the Shanks denominator is zero or the
    transform is not finite, the last partial sum S_N is returned instead of
    NaN. This happens when the partial sums have stopped changing, e.g. x == a.
    """
    if N < 2:
        # print("Need at least order 2 approximation. Using just Taylor")
        return log_Taylor_series(N, a, xn)

    xn = np.asarray(xn)
    # One row per Taylor order, one column per evaluation in xn.
    TaylorEnt = np.zeros((N + 1,) + xn.shape[1:])
    TaylorEnt[0] = np.log(a)
    for n in range(1, N + 1):
        coef = ((-1) ** (n - 1)) / (n * (a ** n))
        for k in range(n + 1):
            # (x - a)**n expanded binomially into the powers x**k.
            TaylorEnt[n] += coef * math.comb(n, k) * ((-a) ** (n - k) * xn[k, :])
    partial = np.cumsum(TaylorEnt, axis=0)

    s0, s1, s2 = partial[-3], partial[-2], partial[-1]
    denom = s2 - 2 * s1 + s0
    with np.errstate(divide="ignore", invalid="ignore"):
        shanks = s0 - (s1 - s0) ** 2 / denom
    # A vanishing second difference means the series has already settled; S_N is the answer.
    result = np.where((denom != 0) & np.isfinite(shanks), shanks, s2)
    if result.size == 1:
        # Keep the scalar return of the single-evaluation case.
        return result.reshape(-1)[0]
    return result

def log_Legendre_series(N, a, xn):
    """Approximate log(x) on (0, a] by an order-N shifted-Legendre expansion.

    With t = x / a, log(x) = sum_n c_n P~_n(t), where
    P~_n(t) = sum_l (-1)**(n+l) (n+l)! / ((n-l)! (l!)**2) t**l and
    c_n = (2n+1) * int_0^1 log(a t) P~_n(t) dt. Each P~_n is evaluated from
    the power rows ``xn[l, :]``.

    Parameters
    ----------
    N : int
        Highest Legendre degree.
    a : float
        Right end of the approximation interval (positive).
    xn : numpy.ndarray
        Row l holds x**l (or E[x**l]) for l = 0..N.

    Returns
    -------
    An array shaped like ``xn[l, :]``.
    """
    log_a = np.log(a)
    result = 0
    for n in range(N + 1):
        # (n+k)! / ((n-k)! (k!)**2) == comb(n+k, k) * comb(n, k). Exact integers
        # avoid the removed np.math alias and huge factorial intermediates.
        # int_0^1 log(a t) t**k dt = ((k+1) log a - 1) / (k+1)**2.
        cn = 0
        for k in range(n + 1):
            mono = (-1) ** (n + k) * math.comb(n + k, k) * math.comb(n, k)
            cn += (mono / (k + 1) ** 2) * ((k + 1) * log_a - 1)
        cn = (2 * n + 1) * cn
        # P~_n(x / a) assembled from the power rows.
        Pan = 0
        for l in range(n + 1):
            mono = (-1) ** (n + l) * math.comb(n + l, l) * math.comb(n, l)
            Pan += (mono / a ** l) * xn[l, :]
        result += cn * Pan
    return result

def log_Chebyshev_series(N, a, xn):
    """Approximate log(x) on (0, a] by polynomial interpolation at Chebyshev nodes.

    log is sampled at the N+1 Chebyshev-Gauss nodes mapped from [-1, 1] onto
    (0, a). The monomial coefficients of the degree-N interpolant are found
    from the Vandermonde system. The interpolant is then evaluated from the
    power rows ``xn[n, :]``.

    Parameters
    ----------
    N : int
        Interpolation degree (N+1 nodes).
    a : float
        Right end of the interval (positive).
    xn : numpy.ndarray
        Row n holds x**n (or E[x**n]) for n = 0..N.

    Returns
    -------
    An array shaped like ``xn[n, :]``.
    """
    j = np.arange(N + 1)
    # Nodes cos(pi (2j+1) / (2(N+1))) mapped from [-1, 1] to (0, a).
    x_Cheb = (a / 2) * (np.cos(np.pi * (2 * j + 1) / (2 * (N + 1))) + 1)
    y_Cheb = np.log(x_Cheb)
    # Solve V c = y directly; explicitly inverting V is slower and less accurate.
    Vander = np.vander(x_Cheb, increasing=True)
    c = np.linalg.solve(Vander, y_Cheb)
    result = 0
    for n in range(N + 1):
        result += c[n] * xn[n, :]
    return result
################################################################################

############################### Entropy Functions ##############################
def multivariate_gauss_hermite_quad_gmm(N,m, ws, mus, sigmas, f):
    """
    Approximate E[f(X)**n], n = 0..N, for X drawn from a Gaussian mixture,
    using tensor-product Gauss-Hermite quadrature.

    Parameters
    ----------
    N : int
        Highest power of ``f`` whose expectation is computed.
    m : int
        The number of points in each dimension (m**D nodes in total).
    ws : numpy.ndarray, shape=(K,)
        The mixture weights of the Gaussian mixture model.
    mus : list of numpy.ndarray, each of shape (D, 1) or (D,)
        The means of the Gaussian components.
    sigmas : list of numpy.ndarray, each of shape (D, D)
        The covariance matrices of the Gaussian components.
    f : callable
        Maps an (m**D, D) array of points to an array of m**D values.

    Returns
    -------
    numpy.ndarray, shape (N+1, 1)
        Row n holds sum_i ws[i] * sum_j w_j * f(x_ij)**n. Row 0 is the sum of
        the mixture weights times the quadrature weight total.
    """
    D = np.shape(mus[0])[0]
    integral = np.zeros((N + 1, 1))
    nodes_1d, weights_1d = hermgauss(m)
    # Tensor-product grid of nodes and the matching products of 1-D weights.
    x = np.array(np.meshgrid(*[nodes_1d] * D)).T.reshape(-1, D)
    w = np.prod(np.array(np.meshgrid(*[weights_1d] * D)).T.reshape(-1, D), axis=1)
    # hermgauss integrates against exp(-|t|^2); dividing by pi^(D/2) turns the
    # weights into a rule for the change of variables below.
    w = w / np.sqrt(np.pi) ** D
    for i in range(len(ws)):
        mean = np.reshape(mus[i], (D, 1))
        # x = sqrt(2) L t + mu maps exp(-|t|^2) onto N(mu, Sigma) with Sigma = L L^T.
        xi = np.sqrt(2) * np.linalg.cholesky(sigmas[i]) @ x.T + mean
        # Evaluate f once per component; every power n reuses these values.
        fx = f(xi.T) if N >= 1 else None
        for n in range(N + 1):
            if n == 0:
                integral[n] += ws[i] * np.sum(w)
            else:
                integral[n] += ws[i] * np.sum((fx ** n) * w)
    return integral


def evaluate_gmm(points, ws, mus, sigmas):
    """
    Return the Gaussian-mixture density at each row of ``points``.

    Arguments:
    points -- array of shape (N, D); one D-dimensional point per row.
    ws -- array of shape (K,) with the mixture weights.
    mus -- K means; each is flattened with ``ravel`` before use.
    sigmas -- K covariance matrices, each of shape (D, D).

    Returns:
    Array of shape (N,) holding sum_k ws[k] * N(point; mus[k], sigmas[k]).
    """
    density = np.zeros(points.shape[0])
    for k, weight in enumerate(ws):
        component = multivariate_normal(mean=mus[k].ravel(), cov=sigmas[k])
        density += weight * component.pdf(points)
    return density

def evaluate_mixture_t(points, means, scales, df, weights):
    """
    Evaluate a mixture of multivariate Student-t densities at a batch of points.

    Parameters:
    - points (ndarray): array of shape (P, D), one point per row.
    - means (ndarray): array of shape (K, D) with each component's location.
    - scales (ndarray): array of shape (K, D, D) with each component's shape matrix.
    - df (ndarray): array of shape (K,) with each component's degrees of freedom.
    - weights (ndarray): array of shape (K,) with the mixture weights.

    Returns:
    - ndarray of shape (P,): the mixture density at every point.
    """
    mixture = np.zeros(points.shape[0])
    n_components = means.shape[0]
    for comp in range(n_components):
        dist = multivariate_t(means[comp], scales[comp], df[comp])
        mixture += weights[comp] * dist.pdf(points)
    return mixture

def evaluate_mixture_t_n(points,n, means, scales, df, weights):
    """
    Evaluate the n-th power of a multivariate Student-t mixture density.

    Parameters:
    - points (ndarray): array of shape (P, D), one point per row.
    - n (number): exponent applied to the mixture density.
    - means (ndarray): array of shape (K, D) with each component's location.
    - scales (ndarray): array of shape (K, D, D) with each component's shape matrix.
    - df (ndarray): array of shape (K,) with each component's degrees of freedom.
    - weights (ndarray): array of shape (K,) with the mixture weights.

    Returns:
    - ndarray of shape (P,): (mixture density)**n at every point.
    """
    mixture = np.zeros(points.shape[0])
    n_components = means.shape[0]
    for comp in range(n_components):
        dist = multivariate_t(means[comp], scales[comp], df[comp])
        mixture += weights[comp] * dist.pdf(points)
    powered = mixture ** n
    return powered

def evaluate_log_mixture_t(points, means, scales, df, weights):
    """
    Evaluate the log-density of a multivariate Student-t mixture, computed stably.

    Per-component terms log(weight) + log t-density are combined with
    ``logsumexp`` rather than summing densities directly.

    Parameters:
    - points (ndarray): array of shape (P, D), one point per row.
    - means (ndarray): array of shape (K, D) with each component's location.
    - scales (ndarray): array of shape (K, D, D) with each component's shape matrix.
    - df (ndarray): array of shape (K,) with each component's degrees of freedom.
    - weights (ndarray): array of shape (K,) with the mixture weights.

    Returns:
    - ndarray of shape (P,): log of the mixture density at every point.
    """
    n_components = means.shape[0]
    log_terms = np.zeros((points.shape[0], n_components))
    for comp in range(n_components):
        dist = multivariate_t(means[comp], scales[comp], df[comp])
        log_terms[:, comp] = np.log(weights[comp]) + dist.logpdf(points)
    return logsumexp(log_terms, axis=1)


def gmm_power_expected_value_gpu(N, wp, mup, sigmap, wq, muq, sigmaq):
    """
    Compute E_p[q(x)**n] for every n = 0..N, where p and q are Gaussian mixtures.

    q(x)**n is expanded with the multinomial theorem (``multinomial_expand_memoized``).
    For each component of p, the expansion terms are evaluated in closed form by
    ``for_loop_gmm_epp_gpu`` and summed.

    Parameters
    ----------
    N : int
        Highest power of q.
    wp, mup, sigmap : sequences
        Weights, means and covariance matrices of p's components.
    wq, muq, sigmaq : sequences
        Weights, means and covariance matrices of q's components. The means are
        presumably (D, 1) column vectors -- TODO confirm against callers.

    Returns
    -------
    numpy.ndarray, shape (N+1, 1)
        Row n holds E_p[q(x)**n].

    NOTE(review): despite the ``_gpu`` suffix, this is plain NumPy/SciPy code.
    """
    num_p = len(wp)
    num_q = len(wq)
    Epq = np.zeros((N + 1, 1))
    D = np.shape(muq[0])[0]

    # Quantities of q shared by every term: precision matrices, precision @ mean,
    # and log(w_m * N(0; mu_m, Sigma_m)).
    SInv = np.zeros(np.shape(sigmaq))
    SInvMu = np.zeros(np.shape(muq))
    logcomp = np.zeros(num_q)
    for q_idx in range(num_q):
        SInv[q_idx] = np.linalg.inv(sigmaq[q_idx])
        SInvMu[q_idx] = SInv[q_idx] @ muq[q_idx]
        logcomp[q_idx] = np.log(wq[q_idx]) + multivariate_normal.logpdf(
            0, mean=muq[q_idx].ravel(), cov=sigmaq[q_idx]
        )

    # log(w_p * N(0; mu_p, Sigma_p)) for each component of p.
    lognormmp = np.zeros(num_p)
    for p_idx in range(num_p):
        lognormmp[p_idx] = np.log(wp[p_idx]) + multivariate_normal.logpdf(
            0, mean=mup[p_idx].ravel(), cov=sigmap[p_idx]
        )

    for power in range(N + 1):
        Nmatrix, NCoef = multinomial_expand_memoized(power, num_q)
        n_terms = len(NCoef)
        for p_idx in range(num_p):
            term_values = for_loop_gmm_epp_gpu(
                np.zeros((n_terms, 1)), n_terms, mup[p_idx], sigmap[p_idx],
                SInv, SInvMu, Nmatrix, NCoef, logcomp, lognormmp[p_idx], D,
            )
            Epq[power] += np.sum(term_values)
    return Epq

# @njit(parallel=True)
def for_loop_gmm_epp_gpu(results,loop_length,mup,sigmap,SInv,SInvMu,Nmatrix,NCoef,logcomp,lognormmp,D):
    """Fill ``results`` with one full-covariance multinomial-term contribution per row.

    Row ``t`` (0 <= t < loop_length) receives
    ``calculate_comb_gmm_gpu(t, ..., Nmatrix[t, :], NCoef[t], ...)``. The array
    passed in is modified in place and also returned. ``prange`` behaves like
    ``range`` while the njit decorator above is commented out.
    """
    for term in prange(loop_length):
        exponents = Nmatrix[term, :]
        results[term] = calculate_comb_gmm_gpu(
            term, mup, sigmap, SInv, SInvMu, exponents, NCoef[term], logcomp, lognormmp, D
        )
    return results

# @njit#(parallel=True)
def calculate_comb_gmm_gpu(j,mup,sigmap,SInv,SInvMu,Nmatrix,NCoef,logcomp,lognormmp,D):
    """Contribution of one multinomial term to E_p[q(x)**n], full-covariance case.

    The integrand w_p N(x; mup, sigmap) * prod_m (w_m N(x; mu_m, Sigma_m))**Nmatrix[m]
    is a scaled Gaussian. Its integral is exp(lognormmp + lognormProd - lognormSum),
    where lognormSum = log N(0; SumMu, SumSigma) is taken for the product Gaussian.
    The result is multiplied by the multinomial coefficient ``NCoef``.

    Parameters
    ----------
    j : int
        Term index (unused; kept for the loop driver's call signature).
    mup : ndarray, shape (D, 1)
        Mean of the p component.
    sigmap : ndarray, shape (D, D)
        Covariance of the p component.
    SInv : ndarray, shape (Mq, D, D)
        Precision matrices of q's components.
    SInvMu : ndarray, shape (Mq, D, 1)
        Precision-weighted means of q's components.
    Nmatrix : ndarray, shape (Mq,)
        Exponents of this multinomial term.
    NCoef : float
        Multinomial coefficient of this term.
    logcomp : ndarray, shape (Mq,)
        log(w_m) + log N(0; mu_m, Sigma_m) for each q component.
    lognormmp : float
        log(w_p) + log N(0; mup, sigmap).
    D : int
        Unused; the dimension is read from the precision matrix.

    Returns
    -------
    ndarray of shape (1, 1) when mup is (D, 1).

    Uses slogdet/solve instead of det/inv. This avoids det under/overflow in
    higher dimensions and the second matrix inversion. Assumes the combined
    precision is positive definite, which holds for valid covariance inputs.
    """
    prec_p = np.linalg.inv(sigmap)
    # sum_m Nmatrix[m] * SInv[m]; the transposes align Nmatrix with the component axis.
    SumSigmaInv = prec_p + sum((Nmatrix * SInv.T).T)
    # Precision-weighted mean of the product: prec_p @ mup + sum_m Nmatrix[m] * SInvMu[m].
    SumMu1 = prec_p @ mup + np.sum((Nmatrix * SInvMu.T)[0], axis=1).reshape(np.shape(mup))
    lognormProd = sum(Nmatrix * logcomp)
    # Mean of the product Gaussian: SumSigma @ SumMu1, computed with a solve.
    SumMu = np.linalg.solve(SumSigmaInv, SumMu1)
    # log N(0; SumMu, SumSigma) = -0.5 (D log 2pi - log det SumSigmaInv) - 0.5 SumMu' SumSigmaInv SumMu,
    # and SumSigmaInv @ SumMu == SumMu1.
    _, logdet_prec = np.linalg.slogdet(SumSigmaInv)
    dim = SumSigmaInv.shape[0]
    lognormSum = -0.5 * (dim * np.log(2 * np.pi) - logdet_prec) - 0.5 * SumMu.T @ SumMu1
    Epqj = NCoef * np.exp(lognormmp + lognormProd - lognormSum)
    return Epqj


# from concurrent.futures import ThreadPoolExecutor
# ########################## Working threading for parallel comb ########################################
# def gmm_power_expected_value_threading(N, wp, mup, sigmap, wq, muq, sigmaq):
#     Mp = len(wp)
#     Mq = len(wq)
#     Epq = np.zeros((N+1,1))
#     D = np.shape(muq[0])[0]
    
#     SInv = np.zeros(np.shape(sigmaq))
#     SInvMu = np.zeros(np.shape(muq))
#     logcomp = np.zeros(Mq)
#     for mq in range(Mq):
#         SInv[mq] = np.linalg.inv(sigmaq[mq])
#         SInvMu[mq] = SInv[mq] @ muq[mq]
#         logcomp[mq] = np.log(wq[mq])+multivariate_normal.logpdf(0, mean=muq[mq].ravel(), cov=sigmaq[mq])
    
#     lognormmp = np.zeros(Mp)
#     for mp in range(Mp):
#         lognormmp[mp] = np.log(wp[mp])+multivariate_normal.logpdf(0, mean=mup[mp].ravel(), cov=sigmap[mp])
        
#     for n in range(N+1):
#         Nmatrix, NCoef = multinomial_expand_memoized(n,Mq)        
#         for mp in range(Mp):
#             num_threads = 1000  # example number of threads

#             results = []
#             with ThreadPoolExecutor(max_workers=num_threads) as executor:
#                 results = list(executor.map(calculate_comb, range(len(NCoef)), [mup[mp]]*(len(NCoef)), [sigmap[mp]]*(len(NCoef)), [SInv]*(len(NCoef)), [SInvMu]*(len(NCoef)), Nmatrix, NCoef, [logcomp]*(len(NCoef)),[lognormmp[mp]]*(len(NCoef)), [D]*(len(NCoef))))
#             Epq[n] +=  np.sum(np.array(results))
#     return Epq

# def calculate_comb(j,mup,sigmap,SInv,SInvMu,Nmatrix,NCoef,logcomp,lognormmp,D):
#     SumSigmaInv = np.linalg.inv(sigmap)+sum((Nmatrix*SInv.T).T)#[j,:]
#     SumMu1 = np.linalg.inv(sigmap)@mup+np.sum((Nmatrix*SInvMu.T)[0],axis=1).reshape(np.shape(mup))
#     lognormProd = sum(Nmatrix*logcomp)          
#     SumSigma = np.linalg.inv(SumSigmaInv)
#     SumMu = SumSigma@SumMu1
#     lognormSum = -0.5*np.log(np.linalg.det(2*np.pi*SumSigma)) - 0.5*SumMu.T@SumMu1#SumSigmaInv@SumMu
#     Epqj = NCoef*np.exp(lognormmp+lognormProd-lognormSum)
#     return Epqj

def gmm_power_expected_value_parallel(N, wp, mup, sigmap, wq, muq, sigmaq):
    """
    Compute E_p[q(x)**n] for n = 0..N (Gaussian mixtures p and q) with joblib.

    Same expansion as ``gmm_power_expected_value_gpu``. The multinomial terms
    for each (n, p-component) pair are dispatched through ``joblib.Parallel``
    to ``calculate_comb_parallel``. The job count is fixed at 1 (threads backend).

    Returns
    -------
    numpy.ndarray, shape (N+1, 1)
        Row n holds E_p[q(x)**n].
    """
    num_p = len(wp)
    num_q = len(wq)
    Epq = np.zeros((N + 1, 1))
    D = np.shape(muq[0])[0]  # not used below

    # Per-component quantities of q shared by every term.
    SInv = np.zeros(np.shape(sigmaq))
    SInvMu = np.zeros(np.shape(muq))
    logcomp = np.zeros(num_q)
    for q_idx in range(num_q):
        SInv[q_idx] = np.linalg.inv(sigmaq[q_idx])
        SInvMu[q_idx] = SInv[q_idx] @ muq[q_idx]
        logcomp[q_idx] = np.log(wq[q_idx]) + multivariate_normal.logpdf(
            0, mean=muq[q_idx].ravel(), cov=sigmaq[q_idx]
        )

    # log(w_p * N(0; mu_p, Sigma_p)) for each component of p.
    lognormmp = np.zeros(num_p)
    for p_idx in range(num_p):
        lognormmp[p_idx] = np.log(wp[p_idx]) + multivariate_normal.logpdf(
            0, mean=mup[p_idx].ravel(), cov=sigmap[p_idx]
        )

    n_jobs = 1  # threading.active_count() / multiprocessing.cpu_count() are alternatives
    for power in range(N + 1):
        Nmatrix, NCoef = multinomial_expand_memoized(power, num_q)
        for p_idx in range(num_p):
            tasks = (
                delayed(calculate_comb_parallel)(
                    mup[p_idx], sigmap[p_idx], SInv, SInvMu,
                    Nmatrix[t], NCoef[t], logcomp, lognormmp[p_idx],
                )
                for t in range(len(NCoef))
            )
            term_values = Parallel(n_jobs=n_jobs, prefer="threads")(tasks)
            Epq[power] += np.sum(term_values)
    return Epq

def calculate_comb_parallel(mup,sigmap,SInv,SInvMu,Nmatrix,NCoef,logcomp,lognormmp):
    """Contribution of one multinomial term to E_p[q(x)**n], full-covariance case.

    The integrand w_p N(x; mup, sigmap) * prod_m (w_m N(x; mu_m, Sigma_m))**Nmatrix[m]
    is a scaled Gaussian. Its integral is exp(lognormmp + lognormProd - lognormSum),
    where lognormSum = log N(0; SumMu, SumSigma) is taken for the product Gaussian.
    The result is multiplied by ``NCoef``.

    Parameters
    ----------
    mup : ndarray, shape (D, 1)
        Mean of the p component.
    sigmap : ndarray, shape (D, D)
        Covariance of the p component.
    SInv : ndarray, shape (Mq, D, D)
        Precision matrices of q's components.
    SInvMu : ndarray, shape (Mq, D, 1)
        Precision-weighted means of q's components.
    Nmatrix : ndarray, shape (Mq,)
        Exponents of this multinomial term.
    NCoef : float
        Multinomial coefficient of this term.
    logcomp : ndarray, shape (Mq,)
        log(w_m) + log N(0; mu_m, Sigma_m) for each q component.
    lognormmp : float
        log(w_p) + log N(0; mup, sigmap).

    Returns
    -------
    ndarray of shape (1, 1) when mup is (D, 1).

    Uses slogdet/solve instead of det/inv. This avoids det under/overflow and a
    second inversion. Assumes the combined precision is positive definite.
    """
    prec_p = np.linalg.inv(sigmap)
    # sum_m Nmatrix[m] * SInv[m]; the transposes align Nmatrix with the component axis.
    SumSigmaInv = prec_p + sum((Nmatrix * SInv.T).T)
    SumMu1 = prec_p @ mup + np.sum((Nmatrix * SInvMu.T)[0], axis=1).reshape(np.shape(mup))
    lognormProd = sum(Nmatrix * logcomp)
    # Product mean SumSigma @ SumMu1 without forming SumSigma explicitly.
    SumMu = np.linalg.solve(SumSigmaInv, SumMu1)
    _, logdet_prec = np.linalg.slogdet(SumSigmaInv)
    dim = SumSigmaInv.shape[0]
    # SumSigmaInv @ SumMu == SumMu1, so the quadratic form is SumMu' @ SumMu1.
    lognormSum = -0.5 * (dim * np.log(2 * np.pi) - logdet_prec) - 0.5 * SumMu.T @ SumMu1
    Epqj = NCoef * np.exp(lognormmp + lognormProd - lognormSum)
    return Epqj

################## Working Parallel for Comb #########################################################
# def isotropic_gmm_power_expected_value_test(N, wp, mup, sigmap, wq, muq, sigmaq):
#     Mp = len(wp)
#     Mq = len(wq)
#     Epq = np.zeros((N+1,1))
#     D = np.shape(muq[0])[0]
#     def calculate_comb1(j):
#         Epqj = 0
#         SumSigmaInv = 1/sigmap[mp]+sum(Nmatrix[j,:]*SInv)
#         SumMu = (1/sigmap[mp])*mup[mp]+np.sum((Nmatrix[j,:]*SInvMu.T)[0],axis=1).reshape(np.shape(mup[0]))
#         normProd = np.exp(sum(Nmatrix[j,:]*logcomp))             
#         SumSigma = 1/SumSigmaInv
#         SumMu = SumSigma*SumMu
#         normSum = multivariate_normal.pdf(0, mean=SumMu.ravel(), cov=SumSigma*np.eye(D))
#         Epqj = NCoef[j]*normmp*normProd/normSum  
#         return Epqj
    
#     SInv = np.zeros(Mq)
#     SInvMu = np.zeros(np.shape(muq))
#     logcomp = np.zeros(Mq)
#     for mq in range(Mq):
#         SInv[mq] = 1/sigmaq[mq]
#         SInvMu[mq] = (1/sigmaq[mq])*muq[mq]
#         logcomp[mq] = np.log(wq[mq])+multivariate_normal.logpdf(0, mean=muq[mq].ravel(), cov=sigmaq[mq]*np.eye(D))
        
#     for n in range(N+1):
#         Nmatrix, NCoef = multinomial_expand_memoized(n,Mq)        
#         for mp in range(Mp):
#             normmp = wp[mp]*multivariate_normal.pdf(0, mean=mup[mp].ravel(), cov=sigmap[mp]*np.eye(D))
#             num_cores = 1#multiprocessing.cpu_count()
#             results = Parallel(n_jobs=num_cores, prefer="threads")(delayed(calculate_comb1)(j) for j in range(len(NCoef)))
#             Epq[n] +=  np.sum(np.array(results))
#     return Epq

#################### Isotropic GPU EPP #######################################
def isotropic_gmm_power_expected_value_test(N, wp, mup, sigmap, wq, muq, sigmaq):
    """Compute E_p[q(x)**n] for n = 0..N when every mixture component is isotropic.

    ``sigmap[m]`` and ``sigmaq[m]`` are scalar variances; each component's
    covariance is sigma * I. q(x)**n is expanded with the multinomial theorem,
    and the terms are evaluated by the numba kernel ``for_loop_test``.

    Returns
    -------
    numpy.ndarray, shape (N+1, 1)
        Row n holds E_p[q(x)**n].
    """
    Mp = len(wp)
    Mq = len(wq)
    Epq = np.zeros((N+1,1))
    D = np.shape(muq[0])[0]

    # Scalar precisions, precision-weighted means and log(w_m N(0; mu_m, s_m I)) of q.
    SInv = np.zeros(Mq)
    SInvMu = np.zeros(np.shape(muq))
    logcomp = np.zeros(Mq)
    for mq in range(Mq):
        SInv[mq] = 1/sigmaq[mq]
        SInvMu[mq] = SInv[mq]*muq[mq]
        logcomp[mq] = np.log(wq[mq])+multivariate_normal.logpdf(0, mean=muq[mq].ravel(), cov=sigmaq[mq]*np.eye(D))

    # log(w_p N(0; mu_p, s_p I)) depends only on the p component, not on the power n.
    # Compute it once per component instead of once per (n, component).
    lognormmp = np.zeros(Mp)
    for mp in range(Mp):
        lognormmp[mp] = np.log(wp[mp])+multivariate_normal.logpdf(0, mean=mup[mp].ravel(), cov=sigmap[mp]*np.eye(D))

    for n in range(N+1):
        Nmatrix, NCoef = multinomial_expand_memoized(n,Mq)
        for mp in range(Mp):
            hold = np.zeros((len(NCoef),1))
            results = for_loop_test(hold,len(NCoef),mup[mp],sigmap[mp],SInv,SInvMu,Nmatrix,NCoef,logcomp,lognormmp[mp],D)
            Epq[n] +=  np.sum(results)
    return Epq

@njit(parallel=True)
def for_loop_test(results,loop_length,mup,sigmap,SInv,SInvMu,Nmatrix,NCoef,logcomp,lognormmp,D):
    """Fill ``results`` with one isotropic multinomial-term contribution per row.

    Row ``t`` (0 <= t < loop_length) receives
    ``calculate_comb_gpu(t, ..., Nmatrix[t, :], NCoef[t], ...)``. Each iteration
    writes only its own row, so ``prange`` lets numba run them in parallel.
    The array passed in is filled in place and returned.
    """
    for term in prange(loop_length):
        exponents = Nmatrix[term, :]
        results[term] = calculate_comb_gpu(term, mup, sigmap, SInv, SInvMu, exponents, NCoef[term], logcomp, lognormmp, D)
    return results

@njit#(parallel=True)
def calculate_comb_gpu(j,mup,sigmap,SInv,SInvMu,Nmatrix,NCoef,logcomp,lognormmp,D):
    """Contribution of one multinomial term to E_p[q(x)**n], isotropic case.

    The integrand w_p N(x; mup, sigmap I) * prod_m (w_m N(x; mu_m, s_m I))**Nmatrix[m]
    is a product of isotropic Gaussians. Its integral equals
    exp(lognormmp + lognormProd - lognormSum), where lognormSum is the log
    density at 0 of the product Gaussian N(x; SumMu, SumSigma I). The result is
    scaled by the multinomial coefficient NCoef.

    Parameters (shapes inferred from the caller in this file -- TODO confirm):
    j         -- term index; not used in the body.
    mup       -- mean of the p component (presumably shape (D, 1)).
    sigmap    -- scalar variance of the p component.
    SInv      -- (Mq,) scalar precisions 1/s_m of the q components.
    SInvMu    -- precision-weighted means of q, indexed as an (Mq, D, 1) array.
    Nmatrix   -- (Mq,) exponents of this multinomial term.
    NCoef     -- multinomial coefficient of this term.
    logcomp   -- (Mq,) values log(w_m) + log N(0; mu_m, s_m I).
    lognormmp -- log(w_p) + log N(0; mup, sigmap I).
    D         -- overwritten below with len(SumMu).

    Returns NCoef * exp(...), a (1, 1) array when mup has shape (D, 1).
    """
    Epqj = 0
    # Scalar precision of the product Gaussian: 1/sigmap + sum_m Nmatrix[m] / s_m.
    SumSigmaInv = 1/sigmap+sum(Nmatrix*SInv)#[j,:]
    # Precision-weighted mean: mup/sigmap + sum_m Nmatrix[m] * SInvMu[m], reshaped like mup.
    SumMu = (1/sigmap)*mup+np.sum((Nmatrix*SInvMu.T)[0],axis=1).reshape(np.shape(mup))
    # log prod_m (w_m N(0; mu_m, s_m I))**Nmatrix[m].
    lognormProd = sum(Nmatrix*logcomp)
    SumSigma = 1/SumSigmaInv
    # Mean of the product Gaussian.
    SumMu = SumSigma*SumMu
    D = len(SumMu)
    # log N(0; SumMu, SumSigma I) = -D/2 log(2 pi SumSigma) - |SumMu|^2 / (2 SumSigma).
    lognormSum = -0.5*D*np.log(2*np.pi*SumSigma) - 0.5*SumSigmaInv*SumMu.T.dot(SumMu)
    # lognormSum = multivariate_normal.logpdf(0, mean=SumMu.ravel(), cov=SumSigma*np.eye(D))
    Epqj = NCoef*np.exp(lognormmp+lognormProd-lognormSum)
    return Epqj
####################################################################################

@lru_cache(maxsize=None)
def multinomial_powers_recursive_memoized(pow, dim):
    """Enumerate every exponent vector of length ``dim`` whose entries sum to ``pow``.

    Rows are ordered by the exponent of the first variable (0..pow), then
    recursively by the remaining variables. ``dim == 1`` yields the integer
    array [[pow]]. Otherwise the result is a float array of shape
    (comb(pow + dim - 1, dim - 1), dim).

    Results are cached, and the same array object is returned for repeated
    arguments, so callers must not modify it in place.
    """
    if dim == 1:
        return np.array([[pow]])
    blocks = []
    for pow_on_x1 in range(pow + 1):
        newsubterms = multinomial_powers_recursive_memoized(pow - pow_on_x1, dim - 1)
        # Prefix every sub-solution with the exponent chosen for the first variable.
        blocks.append(np.hstack((pow_on_x1 * np.ones((np.shape(newsubterms)[0], 1)), newsubterms)))
    # Stack once; re-stacking the growing matrix each iteration was quadratic.
    return np.vstack(blocks)


def multinomial_expand_memoized(pow, dim):
    """Return the exponents and coefficients of the expansion of (x_1 + ... + x_dim)**pow.

    Returns
    -------
    NMatrix : numpy.ndarray
        One exponent vector per row (see ``multinomial_powers_recursive_memoized``).
    NCoef : numpy.ndarray of float
        The matching multinomial coefficients pow! / prod(k_i!).
    """
    NMatrix = multinomial_powers_recursive_memoized(pow, dim)
    # pow! / prod(k_i!) via log-gamma; rounding to the nearest integer removes
    # the exp/gammaln error while that error stays below 0.5, i.e. for
    # moderately sized coefficients.
    NCoef = np.floor(np.exp(gammaln(pow + 1) - np.sum(gammaln(NMatrix + 1), 1)) + 0.5)
    return NMatrix, NCoef

def generate_vandermonde(x, m):
    """Return the (m+1, len(x)) matrix whose row k holds x**k for k = 0..m."""
    powers_by_column = np.vander(x, N=m + 1, increasing=True)
    return np.transpose(powers_by_column)
################################################################################

################## Likely Not Needed #######################################
# def multinomial_expand(pow,dim):
#     ############ https://www.mathworks.com/matlabcentral/fileexchange/48215-multinomial-expansion
#     NMatrix = multinomial_powers_recursive(pow,dim)
#     powvec = np.matlib.repmat(pow,np.shape(NMatrix)[0],1)
#     NCoef = np.floor(np.exp(gammaln(powvec+1).flatten() - np.sum(gammaln(NMatrix+1),1))+0.5)
#     return NMatrix, NCoef

# def multinomial_powers_recursive(pow,dim):
#     if dim == 1:
#         Nmatrix = np.array([[pow]])
#     else:
#         Nmatrix = []
#         for pow_on_x1 in range(pow+1):
#             newsubterms = multinomial_powers_recursive(pow-pow_on_x1,dim-1)
#             new = np.hstack((pow_on_x1*np.ones((np.shape(newsubterms)[0],1)),newsubterms))
#             if len(Nmatrix)==0:#Nmatrix == []:
#                 Nmatrix = new
#             else:
#                 Nmatrix =np.vstack((Nmatrix, new))
#             # Nmatrix = [Nmatrix; [pow_on_x1*ones(np.shape(newsubterms,1),1) , newsubterms] ]
#     return Nmatrix

############### Working threading for parallel N #########################
# def isotropic_gmm_power_expected_value_test(N, wp, mup, sigmap, wq, muq, sigmaq):
#     Mp = len(wp)
#     Mq = len(wq)
#     Epq = np.zeros((N+1,1))
#     D = np.shape(muq[0])[0]
#     num_threads = 100  # example number of threads

#     results = []
#     with ThreadPoolExecutor(max_workers=num_threads) as executor:
#         results = list(executor.map(calculate_epq, range(N+1), [wp]*(N+1), [mup]*(N+1), [sigmap]*(N+1), [wq]*(N+1), [muq]*(N+1), [sigmaq]*(N+1), [Mp]*(N+1), [Mq]*(N+1), [D]*(N+1)))
#         Epq = np.array(results).reshape((N+1,1))
#     return Epq

# def gmm_power_expected_value(N, wp, mup, sigmap, wq, muq, sigmaq):
#     Mp = len(wp)
#     Mq = len(wq)
#     Epq = np.zeros((N+1,1))
#     for n in range(N+1):
#         for mp in range(Mp):
#             normmp = wp[mp]*multivariate_normal.pdf(0, mean=mup[mp].ravel(), cov=sigmap[mp])
#             Nmatrix, NCoef = multinomial_expand(n,Mq)
#             for j in range(len(NCoef)):
#                 SumSigmaInv = np.linalg.inv(sigmap[mp])
#                 SumMu = np.matmul(np.linalg.inv(sigmap[mp]),mup[mp])
#                 normProd = 1             
#                 for mq in range(Mq):
#                     SumSigmaInv += Nmatrix[j,mq]*np.linalg.inv(sigmaq[mq])
#                     SumMu += Nmatrix[j,mq]*np.matmul(np.linalg.inv(sigmaq[mq]),muq[mq])
#                     normProd *= (wq[mq]*multivariate_normal.pdf(0, mean=muq[mq].ravel(), cov=sigmaq[mq]))**Nmatrix[j,mq]
#                 SumSigma = np.linalg.inv(SumSigmaInv)
#                 SumMu = np.matmul(SumSigma,SumMu)
#                 normSum = multivariate_normal.pdf(0, mean=SumMu.ravel(), cov=SumSigma)
#                 Epq[n]+=NCoef[j]*normmp*normProd/normSum       
#     return Epq

# def multivariate_gauss_hermite_quad_gmm_old(N,m, ws, mus, sigmas, f):
#     """
#     Perform multivariate Gauss-Hermite quadrature over a Gaussian mixture model.

#     Parameters
#     ----------
#     m : int
#         The number of points in each dimension.
#     ws : numpy.ndarray, shape=(K,)
#         The mixture weights of the Gaussian mixture model.
#     mus : list of numpy.ndarray, each of shape (D, 1)
#         The means of the Gaussian components.
#     sigmas : list of numpy.ndarray, each of shape (D, D)
#         The covariance matrices of the Gaussian components.
#     f : callable
#         The function to integrate.

#     Returns
#     -------
#     float
#         An approximation of the integral of f against the Gaussian mixture model using Gauss-Hermite quadrature.
#     """
    
#     integral = np.zeros((N+1,1))
#     for n in range(N+1):
#         for i in range(len(ws)):
#             p, w = hermgauss(m)
#             x = np.array(np.meshgrid(*[p]*mus[i].shape[0])).T.reshape(-1, mus[i].shape[0])
#             w = np.prod(np.array(np.meshgrid(*[w]*mus[i].shape[0])).T.reshape(-1, mus[i].shape[0]), axis=1)

#             xi = np.sqrt(2) * np.linalg.cholesky(sigmas[i]) @ x.T + mus[i]
#             w = w / np.sqrt(np.pi)**mus[i].shape[0]
#             # Compute the quadrature approximation using the function values and weights
#             integral[n] += ws[i] * np.sum((f(xi.T)**n)* w)
    
#     return integral

# def gmm_power_expected_value1(N, wp, mup, sigmap, wq, muq, sigmaq):
#     Mp = len(wp)
#     Mq = len(wq)
#     Epq = np.zeros((N+1,1))
#     for n in range(N+1):
#         for mp in range(Mp):
#             normmp = wp[mp]*multivariate_normal.pdf(0, mean=mup[mp].ravel(), cov=sigmap[mp])
#             Nmatrix, NCoef = multinomial_expand_memoized(n,Mq)
#             for j in range(len(NCoef)):
#                 SumSigmaInv = np.linalg.inv(sigmap[mp])
#                 SumMu = np.matmul(np.linalg.inv(sigmap[mp]),mup[mp])
#                 normProd = 1             
#                 for mq in range(Mq):
#                     SumSigmaInv += Nmatrix[j,mq]*np.linalg.inv(sigmaq[mq])
#                     SumMu += Nmatrix[j,mq]*np.matmul(np.linalg.inv(sigmaq[mq]),muq[mq])
#                     normProd *= (wq[mq]*multivariate_normal.pdf(0, mean=muq[mq].ravel(), cov=sigmaq[mq]))**Nmatrix[j,mq]
#                 SumSigma = np.linalg.inv(SumSigmaInv)
#                 SumMu = np.matmul(SumSigma,SumMu)
#                 normSum = multivariate_normal.pdf(0, mean=SumMu.ravel(), cov=SumSigma)
#                 Epq[n]+=NCoef[j]*normmp*normProd/normSum       
#     return Epq

################### BEST SO FAR ###############################################
# def isotropic_gmm_power_expected_value(N, wp, mup, sigmap, wq, muq, sigmaq):
#     Mp = len(wp)
#     Mq = len(wq)
#     Epq = np.zeros((N+1,1))
#     D = np.shape(muq[0])[0]
#     num_cores = 1#threading.active_count()#multiprocessing.cpu_count()
#     results = Parallel(n_jobs=num_cores, prefer="threads")(delayed(calculate_epq)(n,wp, mup, sigmap, wq, muq, sigmaq,Mp,Mq,D) for n in range(N+1))
#     Epq = np.array(results).reshape((N+1,1))
#     return Epq

# def calculate_epq(n,wp, mup, sigmap, wq, muq, sigmaq,Mp,Mq,D):
#     epq = 0
#     SInv = np.zeros(Mq)
#     SInvMu = np.zeros(np.shape(muq))
#     logcomp = np.zeros(Mq)
#     for mq in range(Mq):
#         SInv[mq] = 1/sigmaq[mq]
#         SInvMu[mq] = (1/sigmaq[mq])*muq[mq]
#         logcomp[mq] = np.log(wq[mq])+multivariate_normal.logpdf(0, mean=muq[mq].ravel(), cov=sigmaq[mq]*np.eye(D))
#     Nmatrix, NCoef = multinomial_expand_memoized(n,Mq)            
#     for mp in range(Mp):
#         lognormmp = np.log(wp[mp])+multivariate_normal.logpdf(0, mean=mup[mp].ravel(), cov=sigmap[mp]*np.eye(D))
#         for j in range(len(NCoef)):
#             SumSigmaInv = 1/sigmap[mp]+sum(Nmatrix[j,:]*SInv)
#             SumMu = (1/sigmap[mp])*mup[mp]+np.sum((Nmatrix[j,:]*SInvMu.T)[0],axis=1).reshape(np.shape(mup[0]))
#             lognormProd = sum(Nmatrix[j,:]*logcomp)             
#             SumSigma = 1/SumSigmaInv
#             SumMu = SumSigma*SumMu
#             lognormSum = multivariate_normal.logpdf(0, mean=SumMu.ravel(), cov=SumSigma*np.eye(D))
#             epq += NCoef[j]*np.exp(lognormmp+lognormProd-lognormSum)  #### Need to change to log     
#     return epq