from operator import matmul
import numpy as np
import math
from scipy.optimize import fsolve
import numpy.matlib
import scipy.stats
import scipy.linalg as scilin
from scipy.stats import multivariate_normal
from scipy.special import logsumexp
import matplotlib.pyplot as plt
import plotly.graph_objs as go
import time
import random
from matplotlib.patches import Ellipse
from sklearn import mixture
import matplotlib.transforms as transforms
from functools import partial
import tensorflow as tf
from tensorflow.python.ops.numpy_ops import np_config
from scipy import optimize
np_config.enable_numpy_behavior()
import util
# import os
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
###################################################################################
###########################Outline#################################################
# 1) Generate a Gaussian Mixture Model Parameters
# 2) Numerically Integrate MI
#       a) Break into separate Entropy terms
# 3) Create General Variational approx. computations
#       a) Marginal and Conditional
# 4) Compute Moment Matching
# 5) Compute Gradient Ascent
# 6) Compute Barber and Agakov
# 7) Plot samples and approximations
###################################################################################

def GaussianMixtureParams(M,Dx,Dy):
    ###############################################################################
    # Outline: Randomly draws the parameters of an M-component GMM over (X, Y)
    #
    # Inputs:
    #       M - Number of components
    #       Dx - Number of dimensions for Latent Variable, X
    #       Dy - Number of dimensions for Observation Variable, Y
    #
    # Outputs:
    #       w - (M,) weights drawn from a flat Dirichlet
    #       mu - list of M (Dx+Dy, 1) means, uniform in [-5, 5]
    #       Sigma - list of M (Dx+Dy, Dx+Dy) covariances A A^T, A ~ U[0, 1)
    ###############################################################################
    dim = Dx + Dy
    weights = np.random.dirichlet(np.ones(M))

    def _draw_component():
        # Draw order (mean first, then factor) matters for seeded reproducibility.
        mean = np.random.uniform(-5, 5, (dim, 1))
        factor = np.random.rand(dim, dim)
        # A A^T is symmetric positive semi-definite by construction.
        return mean, np.dot(factor, factor.T)

    components = [_draw_component() for _ in range(M)]
    means = [mean for mean, _ in components]
    covariances = [cov for _, cov in components]
    return weights, means, covariances

def SampleGMM(N,w,mu,sigma):
    ###############################################################################
    # Outline: Samples Points from a GMM
    #
    # Inputs:
    #       N - Number of points to sample
    #       w - weights of GMM components (need not be exactly normalised)
    #       mu - means of GMM components (column vectors or 1-D arrays)
    #       sigma - covariances of GMM components
    #
    # Outputs:
    #       samples - (N, D) coordinates of sampled points
    ###############################################################################
    samples = np.zeros((N, len(mu[0])))
    # Cumulative weights are loop-invariant: compute them once, not per sample.
    cum_w = np.cumsum(w)
    last = len(w) - 1
    for j in range(N):
        # Draw on [0, total weight) so rounding in sum(w) (or unnormalised w)
        # cannot push r past the last threshold and silently select component 0.
        r = np.random.uniform(0, cum_w[-1])
        # First component whose cumulative weight exceeds r.
        k = min(int(np.searchsorted(cum_w, r, side='right')), last)
        samples[j, :] = np.random.multivariate_normal(np.asarray(mu[k]).flatten(), sigma[k])
    return samples

def MargEntGMM(N,L,Dx,w,mu,Sigma):
    ###############################################################################
    # Outline: Numerically integrates the marginal entropy of X under a GMM,
    #          H(X) = -int p(x) log p(x) dx, with a Riemann sum on a uniform
    #          tensor grid over [-L, L]^Dx.
    #
    # Inputs:
    #       N - number of grid points per axis (>= 2); the grid has N**Dx points
    #       L - half-width of the integration box
    #       Dx - Dimension of Latent Variable, X (first Dx coordinates are used)
    #       w - weights of components
    #       mu - means of components (first Dx entries are used)
    #       Sigma - covariances of components (leading Dx x Dx block is used)
    #
    # Outputs:
    #       MargEnt - Marginal Entropy
    ###############################################################################
    if N < 2:
        raise ValueError("MargEntGMM needs at least 2 grid points per axis")
    M = len(w)
    x = np.linspace(-L, L, N)
    # The spacing of linspace(-L, L, N) is 2L/(N-1). The old weight 2L/N scaled
    # the result by a spurious factor (N-1)/N per axis.
    dx = x[1] - x[0]
    # Tensor grid stored as a (Dx, N**Dx) array of points (any Dx, not just 1 or 2).
    X = np.vstack([g.flatten() for g in np.meshgrid(*([x] * Dx))])

    log_terms = np.empty((M, X.shape[1]))
    for d in range(M):
        mean_d = np.asarray(mu[d])[0:Dx].flatten()
        log_terms[d, :] = multivariate_normal.logpdf(X.T, mean_d, Sigma[d][0:Dx, 0:Dx]) + np.log(w[d])
    # log p(x) via logsumexp; p(x) is derived from it so both stay consistent.
    log_p = logsumexp(log_terms, axis=0)
    p = np.exp(log_p)
    # Points where p underflows to 0 contribute 0 (avoids 0 * -inf = nan).
    MargEnt = -np.sum(p * log_p, where=p > 0) * dx ** Dx
    return MargEnt

def MargEntGMMTaylor(N,Dx,w,mu,Sigma):
    ###############################################################################
    # Outline: Taylor-series approximation of the marginal entropy of a GMM,
    #          truncated at order N (single value, no extrapolation).
    #
    #   H(g) = -sum_i w_i E_{N_i}[log g],   g(x) = sum_j w_j N(x; mu_j, Sigma_j)
    #   log g ~= log a + sum_{n=1}^{N} (-1)^(n-1) / (n a^n) * (g - a)^n
    #
    # with expansion point a = MaxConst = sum_j w_j |2 pi Sigma_j|^(-1/2) / 4.
    # NOTE(review): MargEntGMMLimit uses the same constant without the /4;
    # confirm which expansion point is intended.
    # Each E_{N_i}[(g - a)^n] is expanded binomially in g and -a, and each
    # power g^(n-k) multinomially (multinomial_expand), so every term is a
    # closed-form integral of a product of Gaussians.
    #
    # Inputs:
    #       N - truncation order of the series
    #       Dx - dimension of X (not used in the body)
    #       w - weights of components
    #       mu - means of components (used as column vectors in matmuls)
    #       Sigma - covariances of components
    #
    # Outputs:
    #       TaylorEnt - order-N Taylor approximation of the marginal entropy
    #                   (accumulates matmul results, so presumably a (1,1)
    #                   array rather than a scalar -- confirm with callers)
    ###############################################################################
    M = len(w)
    TaylorEnt = 0
    Scale = np.zeros((M,1))
    Sigma_inv = []
    # Scale[i] = w_i |2 pi Sigma_i|^(-1/2), the peak value of w_i N_i.
    for i in range(M):
        Scale[i] = w[i]*np.linalg.det(2*np.pi*Sigma[i])**(-1/2)
        Sigma_inv.append(np.linalg.inv(Sigma[i]))
    # Expansion point a of the log series (see header).
    MaxConst = np.sum(Scale)/4
    for i in range(M):
        outer = 0 
        for n in range(1,N+1):
            middle = 0
            # Binomial expansion: (g - a)^n = sum_k C(n,k) g^(n-k) (-a)^k.
            for k in range(n+1):
                # Exponent rows / coefficients of (sum_j w_j N_j)^(n-k).
                Nmatrix, NCoef = multinomial_expand(n-k,M)#
                # numtermscheck = math.comb(n-k+M-1,M-1)
                # inner accumulates E_{N_i}[g^(n-k)] term by term.
                inner = 0
                for t in range(len(NCoef)):
                    # Precision, precision-weighted mean and quadratic constant of
                    # the product N_i * prod_j N_j^(Nmatrix[t, j]).
                    SumSigmaInv = np.matmul(Sigma_inv[i],np.eye(len(Sigma_inv[i])))
                    SumMu = np.matmul(Sigma_inv[i],mu[i])
                    SumInd = np.matmul(mu[i].T,np.matmul(Sigma_inv[i],mu[i]))
                    for j in range(M):
                        SumSigmaInv += Nmatrix[t,j]*Sigma_inv[j]
                        SumMu += Nmatrix[t,j]*np.matmul(Sigma_inv[j],mu[j])
                        SumInd += Nmatrix[t,j]*np.matmul(mu[j].T,np.matmul(Sigma_inv[j],mu[j]))
                    SumSigma = np.linalg.inv(SumSigmaInv)
                    SumMu = np.matmul(SumSigma,SumMu)
                    # Gaussian integral of the product: exp of completed-square
                    # remainder times |2 pi SumSigma|^(1/2), times the component
                    # normalisers (Scale^Nmatrix and Scale[i]/w[i] = |2 pi Sigma_i|^(-1/2)).
                    expTerm = np.exp(-.5*(-1*np.matmul(SumMu.T,np.matmul(SumSigmaInv,SumMu))+SumInd))
                    inner += NCoef[t]*np.prod((Scale.T)**Nmatrix[t])*expTerm*(np.linalg.det(2*np.pi*SumSigma)**(1/2))*(Scale[i]/w[i])
                combcoef = math.comb(n,k)
                middle += combcoef*inner*(-MaxConst)**k
            # Coefficient of the log-series term of order n.
            outer += (-1)**(n-1)/(n*MaxConst**n)*middle
        TaylorEnt += -w[i]*(np.log(MaxConst)+outer)
    return TaylorEnt

def MargEntGMMLimit(N,Dx,w,mu,Sigma):
    ###############################################################################
    # Outline: Taylor-series approximation of the marginal entropy of a GMM,
    #          returning every partial sum plus an extrapolated limit.
    #
    #   H(g) = -sum_i w_i E_{N_i}[log g],   g(x) = sum_j w_j N(x; mu_j, Sigma_j)
    #
    # log g is expanded around a = MaxConst = sum_j w_j |2 pi Sigma_j|^(-1/2),
    # which is an upper bound on g:
    #   log g = log a + sum_{n>=1} (-1)^(n-1) / (n a^n) * (g - a)^n
    # Each E_{N_i}[(g - a)^n] is evaluated exactly via a binomial expansion in
    # (g, -a) and a multinomial expansion of g^(n-k) into Gaussian-product
    # integrals.
    #
    # Inputs:
    #       N - truncation order (>= 2, needed for the extrapolation)
    #       Dx - dimension of X (unused; kept for interface compatibility)
    #       w - weights of components
    #       mu - means of components (column vectors)
    #       Sigma - covariances of components
    #
    # Outputs:
    #       TaylorEnt - (1, N+1) array; TaylorEnt[0, n] is the estimate truncated
    #                   after the order-n term
    #       TaylorLimit - Aitken delta-squared extrapolation of the last three
    #                     partial sums (inf/nan if their second difference is 0)
    #       UpperBound - TaylorEnt[0, N-1]
    ###############################################################################
    if N < 2:
        raise ValueError("MargEntGMMLimit needs N >= 2 for the Aitken extrapolation")
    M = len(w)
    TaylorEnt = np.zeros((1, N + 1))
    Scale = np.zeros((M, 1))
    Sigma_inv = []
    for i in range(M):
        # Peak value of w_i N_i.
        Scale[i] = w[i] * np.linalg.det(2 * np.pi * Sigma[i]) ** (-1 / 2)
        Sigma_inv.append(np.linalg.inv(Sigma[i]))
    MaxConst = np.sum(Scale)
    # Loop-invariant per-component pieces: Sigma_j^{-1} mu_j and mu_j^T Sigma_j^{-1} mu_j.
    PrecMu = [np.matmul(Sigma_inv[j], mu[j]) for j in range(M)]
    QuadMu = [np.matmul(mu[j].T, PrecMu[j]) for j in range(M)]

    # Zeroth-order partial sum: -sum_i w_i log(a). The previous code used only
    # the last component's weight here (stale loop index), which disagreed with
    # the higher-order partial sums below.
    TaylorEnt[0, 0] = -np.sum(w) * np.log(MaxConst)

    for i in range(M):
        outer = 0
        for n in range(1, N + 1):
            middle = 0
            # Binomial expansion: (g - a)^n = sum_k C(n,k) g^(n-k) (-a)^k.
            for k in range(n + 1):
                Nmatrix, NCoef = multinomial_expand(n - k, M)
                # inner accumulates E_{N_i}[g^(n-k)] term by term.
                inner = 0
                for t in range(len(NCoef)):
                    # Precision, precision-weighted mean and quadratic constant of
                    # N_i * prod_j N_j^(Nmatrix[t, j]).
                    SumSigmaInv = Sigma_inv[i].copy()
                    SumMu = PrecMu[i].copy()
                    SumInd = QuadMu[i].copy()
                    for j in range(M):
                        SumSigmaInv += Nmatrix[t, j] * Sigma_inv[j]
                        SumMu += Nmatrix[t, j] * PrecMu[j]
                        SumInd += Nmatrix[t, j] * QuadMu[j]
                    SumSigma = np.linalg.inv(SumSigmaInv)
                    SumMu = np.matmul(SumSigma, SumMu)
                    # Closed-form Gaussian integral of the product.
                    expTerm = np.exp(-.5 * (-1 * np.matmul(SumMu.T, np.matmul(SumSigmaInv, SumMu)) + SumInd))
                    inner += (NCoef[t] * np.prod((Scale.T) ** Nmatrix[t]) * expTerm
                              * (np.linalg.det(2 * np.pi * SumSigma) ** (1 / 2)) * (Scale[i] / w[i]))
                middle += math.comb(n, k) * inner * (-MaxConst) ** k
            outer += (-1) ** (n - 1) / (n * MaxConst ** n) * middle
            # outer is a (1,1) array; np.sum collapses it to a scalar for the
            # element assignment (array-to-scalar conversion is deprecated).
            TaylorEnt[0, n] += -w[i] * (np.log(MaxConst) + np.sum(outer))

    # Aitken's delta-squared process on the last three partial sums.
    T0, T1, T2 = TaylorEnt[0, -3], TaylorEnt[0, -2], TaylorEnt[0, -1]
    TaylorLimit = T0 - (T1 - T0) ** 2 / (T2 - 2 * T1 + T0)
    UpperBound = TaylorEnt[0, N - 1]
    return TaylorEnt, TaylorLimit, UpperBound

def UpperBoundTest(N,a,w,mu,Sigma,X=None,Dx=2,L=10.0,n_grid=201):
    ###############################################################################
    # Outline: Evaluates the grid-sum diagnostic
    #   Bound = (1/N) * [ sum_{p>=a} p ((p-a)/a)^N + sum_{p<a} p ((p-a)/p)^N ]
    # where p = g(x) is the GMM density at the evaluation points.
    #
    # Inputs:
    #       N - series order
    #       a - expansion constant (e.g. MaxConst from MargEntGMMLimit)
    #       w - weights of components
    #       mu - means of components (column vectors; first Dx entries used)
    #       Sigma - covariances of components (leading Dx x Dx block used)
    #       X - optional (Dx, n_points) evaluation points. The previous version
    #           read an undefined global X (NameError). If None, a uniform grid
    #           over [-L, L]^Dx with n_grid points per axis is used.
    #       Dx - number of leading coordinates used (default 2, as before)
    #       L, n_grid - extent and resolution of the default grid
    #
    # Outputs:
    #       Bound - the value above; densities <= 1e-10 are discarded first
    ###############################################################################
    M = len(w)
    if X is None:
        x = np.linspace(-L, L, n_grid)
        X = np.vstack([g.flatten() for g in np.meshgrid(*([x] * Dx))])
    X = np.asarray(X)
    # One density value per evaluation point. The old (M,1) accumulator broadcast
    # to (M, n_points). The weights w[d] are included so p is the mixture density.
    prob = np.zeros(X.shape[1])
    for d in range(M):
        mean_d = np.asarray(mu[d])[0:Dx].flatten()
        prob += w[d] * multivariate_normal.pdf(X.T, mean_d, Sigma[d][0:Dx, 0:Dx])
    prob = prob[prob > .0000000001]
    greater = prob >= a
    less = prob < a
    first = prob[greater] * ((prob[greater] - a) / a) ** N
    second = prob[less] * ((prob[less] - a) / prob[less]) ** N
    Bound = 1 / N * (np.sum(first) + np.sum(second))
    return Bound

def multinomial_expand(pow,dim):
    ############ https://www.mathworks.com/matlabcentral/fileexchange/48215-multinomial-expansion
    # Outline: Terms of the multinomial expansion (x_1 + ... + x_dim)^pow.
    #
    # Outputs:
    #       NMatrix - (n_terms, dim) exponent vectors, one row per term
    #       NCoef - (n_terms,) coefficients pow! / prod_j k_j!
    NMatrix = multinomial_powers_recursive(pow, dim)
    # Coefficients are computed in log space (gammaln) to avoid factorial overflow,
    # then rounded to the nearest integer. A scalar broadcast replaces the
    # deprecated numpy.matlib.repmat column with identical values.
    log_coef = scipy.special.gammaln(pow + 1) - np.sum(scipy.special.gammaln(NMatrix + 1), 1)
    NCoef = np.floor(np.exp(log_coef) + 0.5)
    return NMatrix, NCoef

def multinomial_powers_recursive(pow,dim):
    # Outline: All exponent vectors (k_1, ..., k_dim) of non-negative integers
    #          with sum k_j == pow, one per row, ordered by k_1 ascending and then
    #          recursively by the remaining entries.
    #
    # Returns an int array [[pow]] when dim == 1, otherwise a float array.
    if dim == 1:
        return np.array([[pow]])
    # Collect the row blocks and stack once. The previous version re-stacked the
    # growing matrix on every iteration, which is quadratic in the number of rows.
    blocks = []
    for pow_on_x1 in range(pow + 1):
        tail = multinomial_powers_recursive(pow - pow_on_x1, dim - 1)
        lead = pow_on_x1 * np.ones((np.shape(tail)[0], 1))
        blocks.append(np.hstack((lead, tail)))
    if not blocks:
        # Only reachable for pow < 0: there are no valid exponent vectors.
        return np.empty((0, dim))
    return np.vstack(blocks)

def MargEntGMMClosed(Dx,w,mu,Sigma):
    ###############################################################################
    # Outline: Heuristic closed-form marginal-entropy approximation for a GMM,
    #          with every quantity anchored at the origin x = 0.
    #
    # For each component i, let Ni = N(0; mu_i, Sigma_i) / max_j N(0; mu_j, Sigma_j).
    # The terms are:
    #   first  = (1 - Ni) * log(MaxConst),
    #            where MaxConst = sum_j w_j |2 pi Sigma_j|^(-1/2) / 2
    #   second = Ni * log sum_j w_j N(0; mu_j, Sigma_j)
    #   third  = Ni * log max_j N(0; mu_j, Sigma_j)
    # The result is -sum_i w_i (first + second - third).
    #
    # Inputs:
    #       Dx - dimension of the origin point; mu[i] (flattened) and Sigma[i]
    #            are used whole, so they must be Dx-dimensional
    #       w - weights of components
    #       mu - means of components
    #       Sigma - covariances of components
    #
    # Outputs:
    #       ClosedTaylorEnt - approximate marginal entropy
    ###############################################################################
    M = len(w)
    origin = np.zeros((Dx, 1)).flatten()
    # Component densities at the origin, each evaluated once.
    pdf0 = np.array([multivariate_normal.pdf(origin, mu[i].flatten(), Sigma[i]) for i in range(M)])
    log_weighted0 = np.array([multivariate_normal.logpdf(origin, mu[i].flatten(), Sigma[i]) + np.log(w[i])
                              for i in range(M)])
    MaxConst = np.sum([w[i] * np.linalg.det(2 * np.pi * Sigma[i]) ** (-1 / 2) for i in range(M)]) / 2
    maxNorm = np.max(pdf0)
    # Loop-invariant logs, hoisted out of the component loop.
    log_mix0 = logsumexp(log_weighted0)
    log_max = np.log(maxNorm)
    log_const = np.log(MaxConst)
    ClosedTaylorEnt = 0
    for i in range(M):
        Ni = pdf0[i] / maxNorm
        first = (1 - Ni) * log_const
        second = Ni * log_mix0
        third = Ni * log_max
        ClosedTaylorEnt += -w[i] * (first + second - third)
    return ClosedTaylorEnt

def HuberTaylor0(w,mu,Sigma):
    ###############################################################################
    # Outline: Zeroth-order Taylor approximation of GMM marginal entropy
    #          (Huber et al.): log g is evaluated only at each component mean,
    #
    #   H0 = -sum_i w_i log( sum_j w_j N(mu_i; mu_j, Sigma_j) )
    #
    # Inputs:
    #       w - weights of components
    #       mu - means of components (column vectors)
    #       Sigma - covariances of components
    #
    # Outputs:
    #       TaylorEnt - zeroth-order entropy approximation
    ###############################################################################
    n_comp = len(w)
    total = 0
    for i in range(n_comp):
        point = mu[i].flatten()
        # log(w_j N(mu_i; mu_j, Sigma_j)) for every component j, as an (M, 1) column.
        log_weighted = np.array([
            [multivariate_normal.logpdf(point, mu[j].flatten(), Sigma[j]) + np.log(w[j])]
            for j in range(n_comp)
        ])
        total -= w[i] * logsumexp(log_weighted)
    return total

def HuberTaylor2(w,mu,Sigma):
    ###############################################################################
    # Outline: Second-order Taylor approximation of GMM marginal entropy
    #          (Huber et al.):
    #
    #   H2 = H0 - sum_i (w_i / 2) * sum( F(mu_i) o Sigma_i )
    #
    # where "o" is the element-wise product and F(x) is the Hessian of log g(x),
    # with g(x) = sum_k w_k N_k(x):
    #   F(x) = (1/g) sum_k w_k Sigma_k^{-1} [ (1/g)(x - mu_k) grad_g^T
    #              + (x - mu_k)(Sigma_k^{-1}(x - mu_k))^T - I ] N_k(x)
    #
    # Inputs:
    #       w - weights of components
    #       mu - means of components (column vectors)
    #       Sigma - covariances of components
    #
    # Outputs:
    #       TaylorEnt - second-order entropy approximation (includes HuberTaylor0)
    ###############################################################################
    M = len(w)
    TaylorEnt = HuberTaylor0(w, mu, Sigma)
    # Precisions are loop-invariant; invert each covariance once.
    Sigma_inv = [np.linalg.inv(Sigma[k]) for k in range(M)]
    for i in range(M):
        dim = len(Sigma[i])
        # N(mu_i; mu_k, Sigma_k) and mu_i - mu_k, each computed once per i.
        dens = [multivariate_normal.pdf(mu[i].flatten(), mu[k].flatten(), Sigma[k]) for k in range(M)]
        diffs = [mu[i] - mu[k] for k in range(M)]
        # g(mu_i) and its gradient.
        f = 0
        df = 0
        for j in range(M):
            f += w[j] * dens[j]
            df += -np.matmul(Sigma_inv[j], diffs[j]) * w[j] * dens[j]
        # Hessian of log g at mu_i.
        F = 0
        for k in range(M):
            bracket = ((1 / f) * np.matmul(diffs[k], df.T)
                       + diffs[k] * np.matmul(Sigma_inv[k], diffs[k]).T
                       - np.eye(dim))
            F += w[k] * np.matmul(Sigma_inv[k], bracket) * dens[k]
        F = (1 / f) * F
        TaylorEnt += -1 * (w[i] / 2) * np.sum(F * Sigma[i])
    return TaylorEnt

def HuberTaylorN(N,w,mu,Sigma):
    ###############################################################################
    # Outline: Per-order terms of Huber et al.'s Taylor-series entropy
    #          approximation, using autodiff derivatives of log g at each mean:
    #
    #   term_n = -sum_i w_i (n-1)!!/n! * d^n log g(mu_i) * Sigma_i^(n/2)
    #
    # Odd central moments of a Gaussian vanish, so odd orders contribute 0.
    # Sigma_i ** (n/2) is element-wise, so the moment is exact only for 1-D
    # components (as used by HuberCounterExample).
    #
    # Inputs:
    #       N - highest order, 0 <= N <= 9 (logGMMautoGrad10 returns log g and
    #           its derivatives up to order 9)
    #       w - weights of components
    #       mu - means of components
    #       Sigma - covariances of components
    #
    # Outputs:
    #       HuberTaylorEnt - (1, N+1) array of per-order contributions (not
    #                        cumulative; callers cumsum for partial sums)
    ###############################################################################
    if not 0 <= N <= 9:
        raise ValueError("HuberTaylorN supports 0 <= N <= 9 (logGMMautoGrad10 provides derivatives up to order 9)")
    M = len(w)
    HuberTaylorEnt = np.zeros((1, N + 1))
    for i in range(M):
        derivatives = logGMMautoGrad10(mu[i], w, mu, Sigma)
        # Order 0: -w_i log g(mu_i). np.sum collapses the (1,1) array to a scalar
        # (array-to-scalar element assignment is deprecated in NumPy).
        HuberTaylorEnt[0, 0] += -w[i] * np.sum(derivatives[0])
        for n in range(2, N + 1, 2):
            # math.factorial: np.math was removed in NumPy 2.0.
            moment_coef = doublefactorial(n - 1) / math.factorial(n)
            HuberTaylorEnt[0, n] += -w[i] * moment_coef * np.sum(derivatives[n] * Sigma[i] ** (n / 2))
    return HuberTaylorEnt

def doublefactorial(n):
    # Double factorial n!! = n * (n-2) * (n-4) * ..., taken as 1 for n <= 0.
    factors = []
    current = n
    while current > 0:
        factors.append(current)
        current -= 2
    # Multiply from the smallest factor outward, matching the recursive
    # n * (n-2)!! evaluation order.
    result = 1
    for factor in reversed(factors):
        result = factor * result
    return result

def logGMMautoGrad10(center,w,mu,Sigma):
    """Return log g(center) and nine successive tape gradients of it.

    g(x) = sum_j w_j |2 pi Sigma_j|^(-1/2) exp(-0.5 (x - mu_j)^T Sigma_j^{-1} (x - mu_j))
    is the GMM density, evaluated in float32 at x = center.

    Args:
        center: evaluation point, wrapped as the tf.Variable x. Presumably a
            column vector shaped like mu[j] (e.g. (D, 1)) -- TODO confirm.
        w, mu, Sigma: GMM weights, means and covariances. They are converted to
            float32 tensors, so all mu[j] (and all Sigma[j]) must share one shape.

    Returns:
        Tuple of 10 numpy arrays (y, dy, dy2, ..., dy9). y is log g(center);
        dyk is the k-th nested tape gradient with respect to x.

    NOTE(review): GradientTape.gradient sums over the elements of a non-scalar
    target and returns a tensor shaped like x. So for D > 1 each dyk is a
    summed vector, not the full k-th derivative tensor. The values are exact
    derivatives only for 1-D x.
    """
    M =len(w)
    x = tf.Variable(center, dtype='float32')
    w = tf.convert_to_tensor(w, dtype=tf.float32)
    mu = tf.convert_to_tensor(mu, dtype=tf.float32)
    Sigma = tf.convert_to_tensor(Sigma, dtype=tf.float32)
    pi = tf.constant(np.pi)
    # Nine nested persistent tapes: tape t_k records the computation of the
    # k-th gradient so that t_{k+1} can differentiate it once more.
    # with tf.GradientTape(persistent=True) as t16:
    #     with tf.GradientTape(persistent=True) as t15:
    #         with tf.GradientTape(persistent=True) as t14:
    #             with tf.GradientTape(persistent=True) as t13:
    # with tf.GradientTape(persistent=True) as t12:
    #     with tf.GradientTape(persistent=True) as t11:
    #         with tf.GradientTape(persistent=True) as t10:
    #             with tf.GradientTape(persistent=True) as t9:
    with tf.GradientTape(persistent=True) as t8:
        with tf.GradientTape(persistent=True) as t7:
            with tf.GradientTape(persistent=True) as t6:
                with tf.GradientTape(persistent=True) as t5:
                    with tf.GradientTape(persistent=True) as t4:
                        with tf.GradientTape(persistent=True) as t3:
                            with tf.GradientTape(persistent=True) as t2:
                                with tf.GradientTape(persistent=True) as t1:
                                    with tf.GradientTape(persistent=True) as t0:
                                        # g(x): weighted sum of Gaussian densities; the matmul makes y (1, 1).
                                        y=0
                                        for j in range(M):
                                            # y += -.5*tf.math.log(tf.linalg.det(2*pi*Sigma[j]))-.5*tf.linalg.matmul((x-mu[j]),tf.linalg.matmul(Sigma[j],(x-mu[j])),transpose_a=True)+tf.math.log(w[j])
                                            y += w[j]*tf.linalg.det(2*pi*Sigma[j])**(-.5)*tf.math.exp(-.5*tf.linalg.matmul((x-mu[j]),tf.linalg.matmul(tf.linalg.inv(Sigma[j]),(x-mu[j])),transpose_a=True))
                                        y = tf.math.log(y)
                        #             dy = t0.jacobian(y,x)
                        #         dy2 = t1.jacobian(dy,x)
                        #     dy3 = t2.jacobian(dy2,x)
                        # dy4 = t3.jacobian(dy3,x)
                                    # Each gradient is taken inside the next outer tape so it stays differentiable.
                                    dy = t0.gradient(y,x)
                                dy2 = t1.gradient(dy,x)
                            dy3 = t2.gradient(dy2,x)
                        dy4 = t3.gradient(dy3,x)
                    dy5 = t4.gradient(dy4,x)
                dy6 = t5.gradient(dy5,x)
            dy7 = t6.gradient(dy6,x)
        dy8 = t7.gradient(dy7,x)
    dy9 = t8.gradient(dy8,x)
    #             dy10 = t9.gradient(dy9,x)
    #         dy11 = t10.gradient(dy9,x)
    #     dy12 = t11.gradient(dy9,x)
    # dy13 = t12.gradient(dy9,x)
    #             dy14 = t13.gradient(dy9,x)
    #         dy15 = t14.gradient(dy9,x)
    #     dy16 = t15.gradient(dy9,x)
    # dy17 = t16.gradient(dy9,x)
    # Index n of the returned tuple holds the order-n quantity (n = 0..9).
    return y.numpy(), dy.numpy(), dy2.numpy(), dy3.numpy(), dy4.numpy(), dy5.numpy(), dy6.numpy(), dy7.numpy(), dy8.numpy(), dy9.numpy()#, dy10.numpy(), dy11.numpy(), dy12.numpy(), dy13.numpy()#, dy14.numpy(), dy15.numpy(), dy16.numpy(), dy17.numpy()

def HuberExample(K,Ns):
    ###############################################################################
    # Outline: Sweep of the 2-D five-component GMM example of Huber et al.
    #          The fifth component's mean moves along the diagonal (c, c) for
    #          c in linspace(-3, 3, K), and marginal-entropy estimates are compared.
    #
    # Inputs:
    #       K - number of sweep positions
    #       Ns - orders; only len(Ns) and Ns[-1] are used. Requires len(Ns) >= 4
    #            and len(Ns) == Ns[-1] + 1 (one column per Taylor partial sum).
    #
    # Outputs:
    #       TrueEnt - (K,1) grid-integrated entropy (MargEntGMM, N=1000, L=50)
    #       TaylorEnt - (K,len(Ns)) Taylor partial sums (MargEntGMMLimit)
    #       TaylorLimit - (K,1) extrapolated Taylor limit
    #       HuberEnt - (K,len(Ns)); columns 0..3 hold the Huber approximation
    #                  truncated at orders 0..3 (each odd order repeats the
    #                  preceding even order); remaining columns stay 0
    ###############################################################################
    Dx = 2
    TrueEnt = np.zeros((K, 1))
    TaylorEnt = np.zeros((K, len(Ns)))
    TaylorLimit = np.zeros((K, 1))
    HuberEnt = np.zeros((K, len(Ns)))
    # Fixed part of Huber et al.'s mixture: weights and covariances.
    ws = np.array([0.2, 0.2, 0.2, 0.2, 0.2])
    sigmas = [np.diag((.16, 1)), np.diag((1, .16)), np.diag((.5, .5)), np.diag((.5, .5)), np.diag((.5, .5))]
    for i, c in enumerate(np.linspace(-3, 3, K)):
        mus = [np.array([[0], [0]]), np.array([[3], [2]]), np.array([[1], [-.5]]),
               np.array([[2.5], [1.5]]), np.array([[c], [c]])]

        HuberEnt[i, 0] = HuberTaylor0(ws, mus, sigmas)
        HuberEnt[i, 1] = HuberEnt[i, 0]
        HuberEnt[i, 2] = HuberTaylor2(ws, mus, sigmas)
        HuberEnt[i, 3] = HuberEnt[i, 2]
        # The UpperBound output of MargEntGMMLimit is not used here.
        TaylorEnt[i, :], TaylorLimit[i], _ = MargEntGMMLimit(Ns[-1], Dx, ws, mus, sigmas)
        TrueEnt[i] = MargEntGMM(1000, 50, Dx, ws, mus, sigmas)
    return TrueEnt, TaylorEnt, TaylorLimit, HuberEnt

def HuberCounterExample(K,Ns):
    ###############################################################################
    # Outline: 1-D two-component GMM counter-example for Huber et al.'s series.
    #          w = (0.35, 0.65), means -2 and -1, variances 2 and c1, where
    #          c1 = ((c+3)/6) + 0.01 for c in linspace(-3, 3, K), i.e. 0.01..1.01.
    #
    # Inputs:
    #       K - number of sweep positions
    #       Ns - orders; only len(Ns) and Ns[-1] are used. Requires
    #            len(Ns) == Ns[-1] + 1 and Ns[-1] <= 9 (HuberTaylorN limit).
    #
    # Outputs:
    #       TrueEnt - (K,1) grid-integrated entropy (MargEntGMM, N=1000, L=50)
    #       TaylorEnt - (K,len(Ns)) Taylor partial sums (MargEntGMMLimit)
    #       TaylorLimit - (K,1) extrapolated Taylor limit
    #       HuberEnt - (K,len(Ns)) Huber partial sums up to each order
    ###############################################################################
    Dx = 1
    TrueEnt = np.zeros((K, 1))
    TaylorEnt = np.zeros((K, len(Ns)))
    TaylorLimit = np.zeros((K, 1))
    HuberEnt = np.zeros((K, len(Ns)))
    # Weights and means are fixed; only the second variance is swept.
    ws = np.array([0.35, 0.65])
    mus = [np.array([[-2]]), np.array([[-1]])]
    for i, c in enumerate(np.linspace(-3, 3, K)):
        c1 = ((c + 3) / 6) + .01
        sigmas = [np.array([[2]]), np.array([[c1]])]

        HuberEnt[i, :] = HuberTaylorN(Ns[-1], ws, mus, sigmas)
        # The UpperBound output of MargEntGMMLimit is not used here.
        TaylorEnt[i, :], TaylorLimit[i], _ = MargEntGMMLimit(Ns[-1], Dx, ws, mus, sigmas)
        TrueEnt[i] = MargEntGMM(1000, 50, Dx, ws, mus, sigmas)
    # HuberTaylorN returns per-order contributions; accumulate into partial sums.
    HuberEnt = np.cumsum(HuberEnt, axis=1)
    return TrueEnt, TaylorEnt, TaylorLimit, HuberEnt


def plotGMMTaylor(N,w,mu,Sigma,filename="GMMDiverge.pdf",show=True):
    ###############################################################################
    # Outline: Plot a 1-D GMM density g(x) on x in [-7, 3] together with three
    #          order-(N-1) series approximations of log g, exponentiated back:
    #            * Huber et al.: Taylor series of log g about mu[0]
    #              (derivatives from logGMMautoGrad10, hence N <= 10);
    #            * Our Taylor: log y = log a + sum_{n>=1} (-1)^(n-1)/(n a^n) (y-a)^n
    #              evaluated at y = g(x), with a = sum_i w_i |2 pi Sigma_i|^(-1/2);
    #            * Our Legendre: util.log_Taylor_series applied to
    #              util.generate_vandermonde(g, n) (helper semantics not visible here).
    #
    # Inputs:
    #       N - number of series terms (orders 0..N-1), 1 <= N <= 10
    #       w - weights of a 1-D GMM
    #       mu - (1,1) means
    #       Sigma - (1,1) variances
    #       filename - output path passed to fig.write_image
    #       show - whether to call fig.show()
    #
    # Outputs:
    #       fig2 - the plotly Figure
    ###############################################################################
    if not 1 <= N <= 10:
        raise ValueError("plotGMMTaylor supports 1 <= N <= 10 (autodiff provides derivatives up to order 9)")
    M = len(w)
    x = np.linspace(-7, 3, 1400)

    ######################## True g(x) ###################################
    GMM = 0
    for i in range(M):
        GMM += w[i] * multivariate_normal.pdf(x, mu[i].flatten(), Sigma[i])

    ####################### Huber Approx. ################################
    # Taylor expansion of log g about the first component mean (the only one
    # plotted). Row n holds the order-n term; cumsum + exp gives the density
    # approximation truncated at each order.
    derivative = logGMMautoGrad10(mu[0], w, mu, Sigma)
    HuberTerms = np.zeros((N, len(x)))
    for n in range(N):
        # math.factorial: np.math was removed in NumPy 2.0.
        HuberTerms[n, :] = ((derivative[n] / math.factorial(n)) * (x - mu[0]) ** n).flatten()
    HuberApprox = np.exp(np.cumsum(HuberTerms, axis=0))

    #################### Our Taylor Approx. ##############################
    # Series for log(y) about y = MaxConst, evaluated at y = g(x). The order-n
    # coefficient is (-1)^(n-1) / (n * MaxConst^n), as in MargEntGMMTaylor. The
    # previous code divided by n! instead of n, which is wrong from order 3 onward.
    MaxConst = np.sum([w[i] * np.linalg.det(2 * np.pi * Sigma[i]) ** (-1 / 2) for i in range(M)])
    TaylorTerms = np.zeros((N, len(x)))
    TaylorTerms[0, :] = np.log(MaxConst)
    for n in range(1, N):
        TaylorTerms[n, :] = ((-1) ** (n - 1) / (n * MaxConst ** n) * (GMM - MaxConst) ** n).flatten()
    TaylorApprox = np.exp(np.cumsum(TaylorTerms, axis=0))

    #################### Our Legendre Approx. ############################
    LegendreApprox = np.zeros((N, len(x)))
    for n in range(N):
        LegendreApprox[n, :] = util.log_Taylor_series(n, MaxConst, util.generate_vandermonde(GMM.T, n)).flatten()
    LegendreApprox = np.exp(LegendreApprox)

    fig2 = go.Figure([
        go.Scatter(
            x=x,
            y=GMM.flatten(),
            line=dict(color='rgb(0,0,0)', width=3),
            mode='lines',
            name='GMM PDF'
        )], layout_yaxis_range=[0, .6])
    # Only the highest order (row N-1) of each approximation is drawn. These
    # colours are the end points of the old ramp i*180/(N-1)+75 at i = N-1,
    # which divided by zero for N == 1.
    for approx, color, label in (
            (HuberApprox, 'rgb(255,0,0)', 'Huber et al.'),
            (TaylorApprox, 'rgb(0,255,255)', 'Our Taylor'),
            (LegendreApprox, 'rgb(0,255,0)', 'Our Legendre')):
        fig2.add_trace(
            go.Scatter(
                x=x,
                y=approx[N - 1, :].flatten(),
                line=dict(color=color, width=3),
                mode='lines',
                name=label))
    fig2.update_xaxes(title_text="x")
    fig2.update_yaxes(title_text="GMM")
    fig2.update_layout(font=dict(size=25), legend=dict(yanchor="top", y=0.95, xanchor="left", x=0.01))
    fig2.update_layout(plot_bgcolor='white')
    fig2.update_xaxes(
        mirror=True,
        ticks='outside',
        showline=True,
        linecolor='black',
        gridcolor='lightgrey'
    )
    fig2.update_yaxes(
        mirror=True,
        ticks='outside',
        showline=True,
        linecolor='black',
        gridcolor='lightgrey'
    )
    fig2.write_image(filename)
    if show:
        fig2.show()
    return fig2

# Demo: two-component 1-D GMM, approximations compared at order N - 1 = 7.
# (HuberCounterExample uses means -2 and -1 with c1 = ((c+3)/6)+.01 instead.)
# Guarded so that importing this module does not build, save and show a plot.
if __name__ == "__main__":
    c1 = .3
    ws = np.array([0.35, 0.65])
    mus = [np.array([[-3]]), np.array([[0]])]
    Sigmas = [np.array([[2]]), np.array([[c1]])]

    N = 8
    fig = plotGMMTaylor(N, ws, mus, Sigmas)