from operator import matmul
import numpy as np
import math
from scipy.optimize import fsolve
import numpy.matlib
import scipy.stats
import scipy.linalg as scilin
from scipy.stats import multivariate_normal
from scipy.special import logsumexp
import matplotlib.pyplot as plt
import plotly.graph_objs as go
import time
import random
from matplotlib.patches import Ellipse
from sklearn import mixture
import matplotlib.transforms as transforms
from functools import partial
import tensorflow as tf
from tensorflow.python.ops.numpy_ops import np_config
from scipy import optimize
np_config.enable_numpy_behavior()
import util
# import os
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
###################################################################################
###########################Outline#################################################
# 1) Generate a Gaussian Mixture Model Parameters
# 2) Numerically Integrate MI
#       a) Break into separate Entropy terms
# 3) Create General Variational approx. computations
#       a) Marginal and Conditional
# 4) Compute Moment Matching
# 5) Compute Gradient Ascent
# 6) Compute Barber and Agakov
# 8) Plot Samples and approximation
###################################################################################

def GaussianMixtureParams(M,Dx,Dy):
    """Randomly generate the parameters of an M-component GMM over (X, Y).

    The joint dimension is D = Dx + Dy. Weights come from a flat Dirichlet
    distribution. Each mean is drawn uniformly from [-5, 5]^D. Each covariance
    is A A^T, with A having i.i.d. U[0, 1) entries, so it is symmetric
    positive semi-definite.

    Inputs:
        M  - number of components
        Dx - number of dimensions for the latent variable X
        Dy - number of dimensions for the observation variable Y

    Outputs:
        w     - (M,) array of component weights (sums to 1)
        mu    - list of M (D, 1) mean arrays
        sigma - list of M (D, D) covariance arrays
    """
    total_dim = Dx + Dy
    weights = np.random.dirichlet(np.ones(M))
    means, covs = [], []
    for _ in range(M):
        # Draw the mean before the covariance factor so seeded runs reproduce
        # the same parameters.
        means.append(np.random.uniform(-5, 5, (total_dim, 1)))
        factor = np.random.rand(total_dim, total_dim)
        covs.append(np.dot(factor, factor.transpose()))
    return weights, means, covs

def SampleGMM(N,w,mu,sigma):
    """Draw N samples from a Gaussian mixture model.

    Inputs:
        N     - number of points to sample
        w     - component weights. They are normalised internally, so they only
                need to be non-negative with a positive sum.
        mu    - list of (D, 1) component means
        sigma - list of (D, D) component covariances

    Outputs:
        samples - (N, D) array; row j is the j-th sampled point

    Raises:
        ValueError - if the weights do not have a positive sum
    """
    w = np.asarray(w, dtype=float)
    # The cumulative weights are loop-invariant, so compute them once.
    # Dividing by the total makes the last threshold exactly 1. Round-off in w
    # therefore cannot leave r above every threshold, which previously fell
    # back to component 0.
    acc_pis = np.cumsum(w)
    if not acc_pis[-1] > 0:
        raise ValueError("SampleGMM requires weights with a positive sum")
    acc_pis = acc_pis / acc_pis[-1]
    last = len(acc_pis) - 1
    samples = np.zeros((N, len(mu[0])))
    for j in range(N):
        r = np.random.uniform(0, 1)
        # Pick the first component whose cumulative weight exceeds r.
        k = min(int(np.searchsorted(acc_pis, r, side='right')), last)
        samples[j, :] = np.random.multivariate_normal(np.ravel(mu[k]), sigma[k])
    return samples

# def MargEntGMM(N,L,Dx,w,mu,Sigma):
#     ###############################################################################
#     # Outline: Numerically Calculates Marginal Entropy
#     #
#     # Inputs:
#     #       samples - List of full sample set
#     #       Dx - Dimension of Latenat Variable, X
#     #       w - weights of components
#     #       mu - means of components
#     #       Sigma - Variance of components
#     #
#     # Outputs:
#     #       MargEnt - Marginal Entropy
#     ###############################################################################
#     M = len(w)
#     x = np.linspace(-L,L,N)
#     if Dx == 1:
#         X=x 
#     else:
#         X1, X2 = np.meshgrid(x,x)
#         X = np.vstack((X1.flatten(),X2.flatten()))

#     MargEntPart = np.zeros((M,len(X.T)))
#     for d in range(M):
#         MargEntPart[d,:] = multivariate_normal.logpdf(X.T,mu[d][0:Dx].T.tolist()[0],Sigma[d][0:Dx,0:Dx])+np.log(w[d])
#     if Dx == 1:
#         MargEnt = -1*sum(np.sum(np.exp(MargEntPart),axis=0)*logsumexp(MargEntPart,axis=0))*2*L/N
#     else:
#         MargEnt = -1*sum(np.sum(np.exp(MargEntPart),axis=0)*logsumexp(MargEntPart,axis=0))*(2*L/N)**2
#     return MargEnt

def MargEntGMM(N, L, Dx, w, mu, Sigma):
    """Numerically integrate the entropy of the X-marginal of a GMM.

    The marginal over the leading Dx coordinates is evaluated on a regular grid
    with N points per dimension spanning [-L, L]^Dx. The function returns
    H = -sum p log p * dV, where dV is the volume of one grid cell.

    Inputs:
        N     - number of grid points per dimension (N >= 2)
        L     - half-width of the integration box
        Dx    - dimension of the latent variable X (the leading Dx coordinates
                of each component)
        w     - weights of components
        mu    - list of (D, 1) component means, D >= Dx
        Sigma - list of (D, D) component covariances

    Outputs:
        MargEnt - marginal entropy (nats)

    Raises:
        ValueError - if N < 2 (the grid spacing would be undefined)
    """
    if N < 2:
        raise ValueError("MargEntGMM needs at least 2 grid points per dimension")
    M = len(w)
    x = np.linspace(-L, L, N)
    # linspace includes both endpoints, so the spacing is 2L/(N-1), not 2L/N.
    dx = 2 * L / (N - 1)
    X_mesh = np.meshgrid(*([x] * Dx))
    X = np.array(X_mesh).reshape(Dx, N**Dx)  # one column per grid point

    MargEntPart = np.zeros((M, X.shape[1]))
    for d in range(M):
        # A Gaussian's marginal over the leading Dx coordinates is obtained by
        # slicing its mean and covariance.
        mean_x = np.ravel(mu[d])[:Dx]
        cov_x = np.asarray(Sigma[d])[:Dx, :Dx]
        MargEntPart[d, :] = multivariate_normal.logpdf(X.T, mean_x, cov_x) + np.log(w[d])

    log_p = logsumexp(MargEntPart, axis=0)  # log of the mixture density per point
    MargEnt = -np.sum(np.exp(log_p) * log_p) * dx**Dx
    return MargEnt

def MargEntGMMTaylor(N,Dx,w,mu,Sigma):
    ###############################################################################
    # Outline: Approximates the GMM entropy H = -sum_i w_i E_i[log p(x)] with an
    #          order-N Taylor series of log p about the constant a = MaxConst:
    #              log p = log a + sum_{n=1}^{N} (-1)^(n-1)/(n a^n) (p - a)^n
    #          E_i[.] is the expectation under component i.
    #          (p - a)^n is expanded binomially in k. Each power p^(n-k) is
    #          expanded multinomially over the components. Every resulting
    #          term is a Gaussian product integral evaluated in closed form.
    #
    # Inputs:
    #       N - highest order of the series
    #       Dx - Dimension of Latent Variable, X (not used by this function)
    #       w - weights of components
    #       mu - means of components (assumed (D, 1) column vectors)
    #       Sigma - Variance of components
    #
    # Outputs:
    #       TaylorEnt - series estimate of the entropy. With (D, 1) means and
    #                   N >= 1 this is a (1, 1) array, because the terms come
    #                   from matrix products.
    ###############################################################################
    M = len(w)
    TaylorEnt = 0
    Scale = np.zeros((M,1))
    Sigma_inv = []
    for i in range(M):
        # Peak height of the weighted component: w_i * det(2 pi Sigma_i)^(-1/2)
        Scale[i] = w[i]*np.linalg.det(2*np.pi*Sigma[i])**(-1/2)
        Sigma_inv.append(np.linalg.inv(Sigma[i]))
    # Expansion point: a quarter of the summed peak heights.
    MaxConst = np.sum(Scale)/4
    for i in range(M):
        outer = 0 
        for n in range(1,N+1):
            middle = 0
            # Binomial expansion: (p - a)^n = sum_k C(n,k) p^(n-k) (-a)^k
            for k in range(n+1):
                # Rows of Nmatrix are exponent vectors (n_1..n_M) summing to n-k;
                # NCoef holds the matching multinomial coefficients.
                Nmatrix, NCoef = multinomial_expand(n-k,M)#
                # numtermscheck = math.comb(n-k+M-1,M-1)
                inner = 0
                for t in range(len(NCoef)):
                    # Multiplying by the identity makes a copy, so the in-place
                    # += below does not modify the cached Sigma_inv[i].
                    SumSigmaInv = np.matmul(Sigma_inv[i],np.eye(len(Sigma_inv[i])))
                    SumMu = np.matmul(Sigma_inv[i],mu[i])
                    SumInd = np.matmul(mu[i].T,np.matmul(Sigma_inv[i],mu[i]))
                    # Accumulate the precision, linear term and constant of the
                    # exponent of N_i * prod_j (Gaussian kernel j)^(n_j).
                    for j in range(M):
                        SumSigmaInv += Nmatrix[t,j]*Sigma_inv[j]
                        SumMu += Nmatrix[t,j]*np.matmul(Sigma_inv[j],mu[j])
                        SumInd += Nmatrix[t,j]*np.matmul(mu[j].T,np.matmul(Sigma_inv[j],mu[j]))
                    SumSigma = np.linalg.inv(SumSigmaInv)
                    SumMu = np.matmul(SumSigma,SumMu)
                    # Closed-form Gaussian integral of the combined exponent.
                    expTerm = np.exp(-.5*(-1*np.matmul(SumMu.T,np.matmul(SumSigmaInv,SumMu))+SumInd))
                    # Scale[i]/w[i] = det(2 pi Sigma_i)^(-1/2): normaliser of N_i.
                    inner += NCoef[t]*np.prod((Scale.T)**Nmatrix[t])*expTerm*(np.linalg.det(2*np.pi*SumSigma)**(1/2))*(Scale[i]/w[i])
                # inner is now E_i[p^(n-k)].
                combcoef = math.comb(n,k)
                middle += combcoef*inner*(-MaxConst)**k
            # middle is now E_i[(p - a)^n].
            outer += (-1)**(n-1)/(n*MaxConst**n)*middle
        TaylorEnt += -w[i]*(np.log(MaxConst)+outer)
    return TaylorEnt

def MargEntGMMLimit(N,Dx,w,mu,Sigma):
    """Taylor-series partial sums of the GMM entropy and an extrapolated limit.

    log p(x) is expanded about a = MaxConst = sum_i w_i det(2 pi Sigma_i)^(-1/2).
    This is the sum of the component peak heights, hence an upper bound on p:
        log p = log a + sum_{n>=1} (-1)^(n-1)/(n a^n) (p - a)^n
    E_i[(p - a)^n] (expectation under component i) is expanded binomially and
    multinomially into closed-form Gaussian product integrals.

    Inputs:
        N     - highest series order. Must be >= 2, because the limit uses
                three partial sums.
        Dx    - dimension of X (unused; kept for interface compatibility)
        w     - weights of components
        mu    - means of components (assumed (D, 1) column vectors)
        Sigma - covariances of components

    Outputs:
        TaylorEnt   - (1, N+1) array; entry n is the order-n entropy estimate
        TaylorLimit - Aitken delta-squared extrapolation of the last three
                      entries. It is the last entry itself when they are
                      collinear (zero second difference).
        UpperBound  - the order N-1 estimate, TaylorEnt[0, N-1]

    Raises:
        ValueError - if N < 2
    """
    if N < 2:
        raise ValueError("MargEntGMMLimit needs N >= 2 (Aitken extrapolation uses three partial sums)")
    M = len(w)
    TaylorEnt = np.zeros((1,N+1))
    Scale = np.zeros((M,1))
    Sigma_inv = []
    for i in range(M):
        # Peak height of the weighted component: w_i * det(2 pi Sigma_i)^(-1/2)
        Scale[i] = w[i]*np.linalg.det(2*np.pi*Sigma[i])**(-1/2)
        Sigma_inv.append(np.linalg.inv(Sigma[i]))
    MaxConst = np.sum(Scale)
    # Order 0: -sum_i w_i log(a), summed over every component.
    TaylorEnt[0,0] = -np.sum(w)*np.log(MaxConst)
    for i in range(M):
        outer = 0
        for n in range(1,N+1):
            middle = 0
            # Binomial expansion: (p - a)^n = sum_k C(n,k) p^(n-k) (-a)^k
            for k in range(n+1):
                Nmatrix, NCoef = multinomial_expand(n-k,M)
                inner = 0
                for t in range(len(NCoef)):
                    # Multiplying by the identity makes a copy, so the in-place
                    # += below does not modify the cached Sigma_inv[i].
                    SumSigmaInv = np.matmul(Sigma_inv[i],np.eye(len(Sigma_inv[i])))
                    SumMu = np.matmul(Sigma_inv[i],mu[i])
                    SumInd = np.matmul(mu[i].T,np.matmul(Sigma_inv[i],mu[i]))
                    for j in range(M):
                        SumSigmaInv += Nmatrix[t,j]*Sigma_inv[j]
                        SumMu += Nmatrix[t,j]*np.matmul(Sigma_inv[j],mu[j])
                        SumInd += Nmatrix[t,j]*np.matmul(mu[j].T,np.matmul(Sigma_inv[j],mu[j]))
                    SumSigma = np.linalg.inv(SumSigmaInv)
                    SumMu = np.matmul(SumSigma,SumMu)
                    # Closed-form Gaussian integral of the combined exponent.
                    expTerm = np.exp(-.5*(-1*np.matmul(SumMu.T,np.matmul(SumSigmaInv,SumMu))+SumInd))
                    inner += NCoef[t]*np.prod((Scale.T)**Nmatrix[t])*expTerm*(np.linalg.det(2*np.pi*SumSigma)**(1/2))*(Scale[i]/w[i])
                # inner is now E_i[p^(n-k)].
                combcoef = math.comb(n,k)
                middle += combcoef*inner*(-MaxConst)**k
            outer += (-1)**(n-1)/(n*MaxConst**n)*middle
            # outer is a size-1 matrix product; squeeze it before storing it in
            # a scalar slot (array-to-scalar conversion is deprecated in NumPy).
            TaylorEnt[0,n] += -w[i]*(np.log(MaxConst)+np.squeeze(outer))
    # Aitken delta-squared extrapolation from the last three partial sums.
    x0, x1, x2 = TaylorEnt[0,-3], TaylorEnt[0,-2], TaylorEnt[0,-1]
    denom = x2-2*x1+x0
    if denom == 0:
        TaylorLimit = x2
    else:
        TaylorLimit = x0-(x1-x0)**2/denom
    UpperBound = TaylorEnt[0,N-1]
    return TaylorEnt, TaylorLimit, UpperBound

def UpperBoundTest(N,a,w,mu,Sigma,X=None):
    """Sum an order-N remainder-style term of the log Taylor series over points.

    Let p(x) = sum_d w_d N(x; mu_d, Sigma_d) be the mixture density at each
    evaluation point. Points with p <= 1e-10 are dropped. The function then
    accumulates
        p * ((p - a) / a)^N   where p >= a,
        p * ((p - a) / p)^N   where p <  a,
    and returns the total divided by N.

    Inputs:
        N     - series order (N >= 1)
        a     - expansion point of the series
        w     - weights of components
        mu    - means of components; only the leading Dx coordinates are used
        Sigma - covariances of components; only the leading Dx block is used
        X     - (Dx, P) array of evaluation points, one per column (a 1-D array
                is treated as Dx = 1). Required.

    Outputs:
        Bound - the accumulated term described above

    Raises:
        ValueError - if X is not supplied

    NOTE(review): no grid-cell volume is applied, so this is a plain sum over
    points rather than an integral. Confirm this is the intent before relying on
    it as a bound.
    """
    if X is None:
        raise ValueError("UpperBoundTest requires the evaluation points X with shape (Dx, P)")
    X = np.atleast_2d(np.asarray(X, dtype=float))
    Dx = X.shape[0]
    M = len(w)
    prob = np.zeros(X.shape[1])
    for d in range(M):
        # Weighted component density over the leading Dx coordinates.
        prob += w[d]*multivariate_normal.pdf(X.T, np.ravel(mu[d])[:Dx], np.asarray(Sigma[d])[:Dx,:Dx])
    prob = prob[prob>.0000000001]
    greater = prob >= a
    less = ~greater
    first = prob[greater]*((prob[greater]-a)/a)**(N)
    second = prob[less]*((prob[less]-a)/prob[less])**(N)
    Bound = 1/N*(np.sum(first)+np.sum(second))
    return Bound

def multinomial_expand(pow,dim):
    """Enumerate the terms of (x_1 + ... + x_dim)^pow with their coefficients.

    Port of the MATLAB File Exchange multinomial-expansion routine:
    https://www.mathworks.com/matlabcentral/fileexchange/48215-multinomial-expansion

    Inputs:
        pow - non-negative integer power
        dim - number of summands

    Outputs:
        NMatrix - array whose rows are the exponent vectors (k_1, ..., k_dim)
                  produced by multinomial_powers_recursive
        NCoef   - matching multinomial coefficients pow! / (k_1! ... k_dim!),
                  rounded to the nearest integer (float dtype)
    """
    NMatrix = multinomial_powers_recursive(pow,dim)
    # gammaln(k + 1) = log(k!). Working in log space avoids overflow.
    # The numerator is the same for every row, so a scalar broadcast replaces
    # the deprecated numpy.matlib.repmat.
    log_numerator = scipy.special.gammaln(pow + 1)
    log_denominator = np.sum(scipy.special.gammaln(NMatrix + 1), 1)
    NCoef = np.floor(np.exp(log_numerator - log_denominator) + 0.5)
    return NMatrix, NCoef

def multinomial_powers_recursive(pow,dim):
    """Return every exponent vector of length ``dim`` whose entries sum to ``pow``.

    Rows are ordered by the first exponent (ascending), then recursively by the
    remaining ones. For example, pow=2, dim=2 gives [[0, 2], [1, 1], [2, 0]].

    Inputs:
        pow - integer total degree
        dim - number of exponents per row (>= 1)

    Outputs:
        Nmatrix - one row per exponent vector. For dim == 1 this is the integer
                  array [[pow]]. For dim > 1 the entries are floats. An empty
                  list is returned when dim > 1 and pow < 0.
    """
    if dim == 1:
        return np.array([[pow]])
    blocks = []
    for pow_on_x1 in range(pow+1):
        newsubterms = multinomial_powers_recursive(pow-pow_on_x1,dim-1)
        lead = pow_on_x1*np.ones((np.shape(newsubterms)[0],1))
        blocks.append(np.hstack((lead,newsubterms)))
    if not blocks:
        return []
    # Stack once at the end instead of copying the growing matrix every
    # iteration, which was quadratic in the number of rows.
    return np.vstack(blocks)

def MargEntGMMClosed(Dx,w,mu,Sigma):
    """Closed-form heuristic approximation of the GMM marginal entropy.

    Every density below is evaluated at the origin of R^Dx.

    For component i, let N_i = N(0; mu_i, Sigma_i) / max_j N(0; mu_j, Sigma_j).
    The contribution is
        (1 - N_i) log(a) + N_i [log p(0) - log max_j N(0; mu_j, Sigma_j)],
    where p is the mixture density and a = (1/2) sum_j w_j det(2 pi Sigma_j)^(-1/2)
    is half the sum of the component peak heights. The result is
    -sum_i w_i * contribution_i.

    Inputs:
        Dx    - dimension of X (the origin is evaluated in R^Dx)
        w     - weights of components
        mu    - means of components
        Sigma - covariances of components

    Outputs:
        ClosedTaylorEnt - scalar entropy approximation
    """
    M = len(w)
    origin = np.zeros(Dx)
    compPdf = np.zeros(M)      # unweighted component pdfs at the origin
    logWeighted = np.zeros(M)  # log(w_i N_i(0))
    Scale = np.zeros(M)        # peak heights w_i det(2 pi Sigma_i)^(-1/2)
    for i in range(M):
        mean_i = np.ravel(mu[i])
        compPdf[i] = multivariate_normal.pdf(origin, mean_i, Sigma[i])
        logWeighted[i] = multivariate_normal.logpdf(origin, mean_i, Sigma[i])+np.log(w[i])
        Scale[i] = w[i]*np.linalg.det(2*np.pi*Sigma[i])**(-1/2)
    maxNorm = np.max(compPdf)
    MaxConst = np.sum(Scale)/2
    # Loop-invariant logs, computed once.
    logMix = logsumexp(logWeighted)
    logMaxConst = np.log(MaxConst)
    logMaxNorm = np.log(maxNorm)
    ClosedTaylorEnt = 0
    for i in range(M):
        Ni = compPdf[i]/maxNorm
        first = (1-Ni)*logMaxConst
        second = Ni*logMix
        third = Ni*logMaxNorm
        ClosedTaylorEnt += -w[i]*(first+second-third)
    return ClosedTaylorEnt

def HuberTaylor0(w,mu,Sigma):
    """Zeroth-order Taylor (Huber) approximation of the GMM entropy.

    log p(x) is expanded about each component mean and only the constant term
    is kept:
        H ~= -sum_i w_i log p(mu_i),   p(x) = sum_j w_j N(x; mu_j, Sigma_j)

    Inputs:
        w     - weights of components
        mu    - means of components
        Sigma - covariances of components

    Outputs:
        TaylorEnt - scalar entropy approximation
    """
    n_comp = len(w)

    def log_mixture_at(point):
        # log p(point), evaluated stably in log space.
        terms = np.array([
            multivariate_normal.logpdf(point, mu[j].flatten(), Sigma[j]) + np.log(w[j])
            for j in range(n_comp)
        ])
        return logsumexp(terms)

    return sum(-1 * w[i] * log_mixture_at(mu[i].flatten()) for i in range(n_comp))

def HuberTaylor2(w,mu,Sigma):
    """Second-order Taylor (Huber) approximation of the GMM entropy.

    log p(x) is expanded to second order about each component mean mu_i:
        H ~= -sum_i w_i [ log p(mu_i) + 1/2 * sum(F(mu_i) * Sigma_i) ]
    F is the Hessian of log p:
        F(x) = 1/p(x) * sum_k w_k Sigma_k^{-1} ( (x - mu_k) grad p(x)^T / p(x)
               + (x - mu_k)(Sigma_k^{-1}(x - mu_k))^T - I ) N(x; mu_k, Sigma_k)

    Inputs:
        w     - weights of components
        mu    - means of components ((D, 1) column vectors)
        Sigma - covariances of components

    Outputs:
        TaylorEnt - scalar entropy approximation
    """
    n_comp = len(w)

    # Zeroth-order term: -sum_i w_i log p(mu_i).
    entropy = 0
    for i in range(n_comp):
        log_terms = np.array([
            multivariate_normal.logpdf(mu[i].flatten(), mu[j].flatten(), Sigma[j]) + np.log(w[j])
            for j in range(n_comp)
        ])
        entropy += -1 * w[i] * logsumexp(log_terms)

    # Second-order correction, with covariance inverses computed once.
    inv_covs = [np.linalg.inv(S) for S in Sigma]
    for i in range(n_comp):
        x = mu[i]
        diffs = [x - mu[j] for j in range(n_comp)]
        dens = [multivariate_normal.pdf(x.flatten(), mu[j].flatten(), Sigma[j]) for j in range(n_comp)]
        f = 0      # p(x)
        grad = 0   # grad p(x)
        for j in range(n_comp):
            f += w[j] * dens[j]
            grad += -np.matmul(inv_covs[j], diffs[j]) * w[j] * dens[j]
        identity = np.eye(len(Sigma[i]))
        hess = 0
        for k in range(n_comp):
            bracket = (1 / f) * np.matmul(diffs[k], grad.T) + diffs[k] * np.matmul(inv_covs[k], diffs[k]).T - identity
            hess += w[k] * np.matmul(inv_covs[k], bracket) * dens[k]
        hess = (1 / f) * hess
        entropy += -1 * (w[i] / 2) * np.sum(hess * Sigma[i])
    return entropy

def HuberTaylor0Splitting(w,mu,Sigma,w1,mu1,Sigma1):
    """Zeroth-order Huber entropy approximation with separate expansion points.

    The density p is the GMM (w, mu, Sigma). The expansion points and outer
    weights come from a second mixture (w1, mu1, Sigma1), presumably produced
    by splitting p's components:
        H ~= -sum_i w1_i log p(mu1_i)
    Sigma1 is not used at zeroth order.

    Outputs:
        TaylorEnt - scalar entropy approximation
    """
    n_comp = len(w)
    total = 0
    for i in range(len(w1)):
        point = mu1[i].flatten()
        log_terms = np.array([
            multivariate_normal.logpdf(point, mu[j].flatten(), Sigma[j]) + np.log(w[j])
            for j in range(n_comp)
        ])
        total += -1 * w1[i] * logsumexp(log_terms)
    return total

def HuberTaylor2Splitting(w,mu,Sigma,w1,mu1,Sigma1):
    """Second-order Huber entropy approximation with separate expansion points.

    The density p is the GMM (w, mu, Sigma). log p is expanded to second order
    about each mu1_i, and the quadratic term is averaged against Sigma1_i:
        H ~= -sum_i w1_i [ log p(mu1_i) + 1/2 * sum(F(mu1_i) * Sigma1_i) ]
    F is the Hessian of log p:
        F(x) = 1/p(x) * sum_k w_k Sigma_k^{-1} ( (x - mu_k) grad p(x)^T / p(x)
               + (x - mu_k)(Sigma_k^{-1}(x - mu_k))^T - I ) N(x; mu_k, Sigma_k)

    Outputs:
        TaylorEnt - scalar entropy approximation
    """
    n_comp = len(w)
    n_split = len(w1)

    # Zeroth-order term: -sum_i w1_i log p(mu1_i).
    entropy = 0
    for i in range(n_split):
        log_terms = np.array([
            multivariate_normal.logpdf(mu1[i].flatten(), mu[j].flatten(), Sigma[j]) + np.log(w[j])
            for j in range(n_comp)
        ])
        entropy += -1 * w1[i] * logsumexp(log_terms)

    # Second-order correction, with covariance inverses computed once.
    inv_covs = [np.linalg.inv(S) for S in Sigma]
    for i in range(n_split):
        x = mu1[i]
        diffs = [x - mu[j] for j in range(n_comp)]
        dens = [multivariate_normal.pdf(x.flatten(), mu[j].flatten(), Sigma[j]) for j in range(n_comp)]
        f = 0      # p(x)
        grad = 0   # grad p(x)
        for j in range(n_comp):
            f += w[j] * dens[j]
            grad += -np.matmul(inv_covs[j], diffs[j]) * w[j] * dens[j]
        identity = np.eye(len(Sigma1[i]))
        hess = 0
        for k in range(n_comp):
            bracket = (1 / f) * np.matmul(diffs[k], grad.T) + diffs[k] * np.matmul(inv_covs[k], diffs[k]).T - identity
            hess += w[k] * np.matmul(inv_covs[k], bracket) * dens[k]
        hess = (1 / f) * hess
        entropy += -1 * (w1[i] / 2) * np.sum(hess * Sigma1[i])
    return entropy

def HuberTaylorN(N,w,mu,Sigma):
    """Per-order terms of Huber's Taylor-series approximation of the GMM entropy.

    log p is expanded about each component mean mu_i. Component i's even
    central moments enter as (n-1)!! Sigma_i^(n/2):
        H ~= -sum_i w_i [ log p(mu_i)
                          + sum_{even n >= 2} (n-1)!!/n! * d^n log p(mu_i) * Sigma_i^(n/2) ]

    Inputs:
        N     - highest order. Derivatives come from logGMMautoGrad10, which
                returns orders 0-6, so an even order above 6 raises IndexError.
        w     - weights of components
        mu    - means of components
        Sigma - covariances of components

    Outputs:
        HuberTaylorEnt - (1, N+1) array. Column n holds only the order-n
                         contribution; odd orders stay 0. Take a cumulative
                         sum over columns to get partial sums.

    NOTE(review): Sigma[i]**(n/2) is an element-wise power. It matches the
    Gaussian moment factor only for 1-D (scalar) components; confirm before
    using it in higher dimensions.
    """
    M = len(w)
    HuberTaylorEnt = np.zeros((1,N+1))
    for i in range(M):
        # derivatives[n] is the n-th derivative of log p evaluated at mu[i].
        derivatives = logGMMautoGrad10(mu[i],w,mu,Sigma)
        for n in range(N+1):
            if n == 0:
                # np.sum collapses the size-1 array to a scalar (array-to-scalar
                # assignment is deprecated in NumPy).
                HuberTaylorEnt[0,n] += -w[i]*np.sum(derivatives[n])
            elif n % 2 == 0:
                # Odd central moments of a Gaussian vanish; only even n contribute.
                # math.factorial is used because the numpy.math alias was removed
                # in NumPy 2.0.
                HuberTaylorEnt[0,n] += -w[i]*(doublefactorial(n-1)/math.factorial(n))*np.sum(derivatives[n]*Sigma[i]**(n/2))
    return HuberTaylorEnt

def HuberTaylorNSplit(N,w,mu,Sigma,w1,mu1,Sigma1):
    """Per-order Huber Taylor terms using separate expansion points.

    Same as HuberTaylorN, except that log p (p is the GMM (w, mu, Sigma)) is
    expanded about the means mu1_i. The terms are weighted by w1_i and use the
    moments of N(mu1_i, Sigma1_i):
        H ~= -sum_i w1_i [ log p(mu1_i)
                           + sum_{even n >= 2} (n-1)!!/n! * d^n log p(mu1_i) * Sigma1_i^(n/2) ]

    Inputs:
        N     - highest order. logGMMautoGrad10 returns orders 0-6, so an even
                order above 6 raises IndexError.
        w     - weights of the GMM
        mu    - means of the GMM
        Sigma - covariances of the GMM
        w1    - outer weights, one per expansion point
        mu1   - expansion points
        Sigma1 - covariances paired with the expansion points

    Outputs:
        HuberTaylorEnt - (1, N+1) array of per-order contributions (odd orders 0)

    NOTE(review): Sigma1[i]**(n/2) is element-wise, so this is exact only for
    1-D components; confirm before using it in higher dimensions.
    """
    M1 = len(w1)
    HuberTaylorEnt = np.zeros((1,N+1))
    for i in range(M1):
        derivatives = logGMMautoGrad10(mu1[i],w,mu,Sigma)
        for n in range(N+1):
            if n == 0:
                HuberTaylorEnt[0,n] += -w1[i]*np.sum(derivatives[n])
            elif n % 2 == 0:
                # math.factorial replaces numpy.math.factorial, which was
                # removed in NumPy 2.0.
                HuberTaylorEnt[0,n] += -w1[i]*(doublefactorial(n-1)/math.factorial(n))*np.sum(derivatives[n]*Sigma1[i]**(n/2))
    return HuberTaylorEnt

def doublefactorial(n):
    """Return the double factorial n!! = n * (n - 2) * (n - 4) * ...

    Terms are taken while they are positive; n!! = 1 for n <= 0.
    """
    factors = []
    while n > 0:
        factors.append(n)
        n -= 2
    # Multiply from the smallest factor outward, like the nested recursion
    # n * (n-2)!! does.
    result = 1
    for factor in reversed(factors):
        result = factor * result
    return result

def logGMMautoGrad10(center,w,mu,Sigma):
    ###############################################################################
    # Outline: Evaluates log p(x) for the GMM p(x) = sum_j w_j N(x; mu_j, Sigma_j)
    #          at x = center, together with its derivatives of orders 1-6, using
    #          nested TensorFlow GradientTapes (float32 throughout).
    #
    # Inputs:
    #       center - evaluation point; presumably a (D, 1) column like mu[i]
    #                (TODO confirm with callers)
    #       w - weights of components
    #       mu - means of components
    #       Sigma - covariances of components
    #
    # Outputs:
    #       (y, dy, dy2, dy3, dy4, dy5, dy6) as numpy arrays: log p at center
    #       followed by derivatives of orders 1..6. Despite the "10" in the
    #       name, only orders 0-6 are returned (the deeper tapes are commented
    #       out).
    #
    # NOTE(review): GradientTape.gradient on a non-scalar target differentiates
    # the sum of that target. For D > 1, dy2..dy6 are therefore summed
    # derivatives rather than full derivative tensors; the values are exact
    # only for scalar (1-D) x. Confirm before using this in higher dimensions.
    ###############################################################################
    M =len(w)
    x = tf.Variable(center, dtype='float32')
    w = tf.convert_to_tensor(w, dtype=tf.float32)
    mu = tf.convert_to_tensor(mu, dtype=tf.float32)
    Sigma = tf.convert_to_tensor(Sigma, dtype=tf.float32)
    pi = tf.constant(np.pi)
    # with tf.GradientTape(persistent=True) as t16:
    #     with tf.GradientTape(persistent=True) as t15:
    #         with tf.GradientTape(persistent=True) as t14:
    #             with tf.GradientTape(persistent=True) as t13:
    #                 with tf.GradientTape(persistent=True) as t12:
    #                     with tf.GradientTape(persistent=True) as t11:
    #                         with tf.GradientTape(persistent=True) as t10:
    #                             with tf.GradientTape(persistent=True) as t9:
    #                                 with tf.GradientTape(persistent=True) as t8:
    # with tf.GradientTape(persistent=True) as t7:
    #     with tf.GradientTape(persistent=True) as t6:
    # Each outer tape records the gradient computed inside the inner tape, so
    # tape t_k lets us differentiate the order-(k+1) derivative once more.
    with tf.GradientTape(persistent=True) as t5:
        with tf.GradientTape(persistent=True) as t4:
            with tf.GradientTape(persistent=True) as t3:
                with tf.GradientTape(persistent=True) as t2:
                    with tf.GradientTape(persistent=True) as t1:
                        with tf.GradientTape(persistent=True) as t0:
                            y=0
                            for j in range(M):
                                # y += -.5*tf.math.log(tf.linalg.det(2*pi*Sigma[j]))-.5*tf.linalg.matmul((x-mu[j]),tf.linalg.matmul(Sigma[j],(x-mu[j])),transpose_a=True)+tf.math.log(w[j])
                                # Weighted Gaussian density of component j:
                                # w_j det(2 pi Sigma_j)^(-1/2) exp(-1/2 (x-mu_j)^T Sigma_j^{-1} (x-mu_j))
                                y += w[j]*tf.linalg.det(2*pi*Sigma[j])**(-.5)*tf.math.exp(-.5*tf.linalg.matmul((x-mu[j]),tf.linalg.matmul(tf.linalg.inv(Sigma[j]),(x-mu[j])),transpose_a=True))
                            y = tf.math.log(y)
                        dy = t0.gradient(y,x)
                    dy2 = t1.gradient(dy,x)
                dy3 = t2.gradient(dy2,x)
            dy4 = t3.gradient(dy3,x)
        dy5 = t4.gradient(dy4,x)
    dy6 = t5.gradient(dy5,x)
    #     dy7 = t6.gradient(dy6,x)
    # dy8 = t7.gradient(dy7,x)
    #                                 dy9 = t8.gradient(dy8,x)
    #                             dy10 = t9.gradient(dy9,x)
    #                         dy11 = t10.gradient(dy9,x)
    #                     dy12 = t11.gradient(dy9,x)
    #                 dy13 = t12.gradient(dy9,x)
    #             dy14 = t13.gradient(dy9,x)
    #         dy15 = t14.gradient(dy9,x)
    #     dy16 = t15.gradient(dy9,x)
    # dy17 = t16.gradient(dy9,x)
    return y.numpy(), dy.numpy(), dy2.numpy(), dy3.numpy(), dy4.numpy(), dy5.numpy(), dy6.numpy()#, dy7.numpy(), dy8.numpy()#, dy9.numpy()#, dy10.numpy(), dy11.numpy(), dy12.numpy(), dy13.numpy(), dy14.numpy(), dy15.numpy(), dy16.numpy(), dy17.numpy()

def HuberExample(K,Ns):
    """Sweep the fifth mean of the 2-D, five-component Huber example GMM.

    The fifth component mean is moved along the diagonal (c, c) for K values of
    c in [-3, 3]. For each configuration the entropy is estimated by:
      - grid integration (MargEntGMM), returned as TrueEnt;
      - the util Taylor and Legendre series for orders 0..Ns[-1];
      - the util Taylor-limit extrapolation;
      - the zeroth- and second-order Huber approximations.

    Inputs:
        K  - number of sweep positions
        Ns - series orders. Ns[-1] is the highest order computed; it must be
             >= 3 so that Huber columns 0-3 exist.

    Outputs (C = max(len(Ns), Ns[-1] + 1)):
        TrueEnt     - (K, 1) grid-integrated entropy
        TaylorEnt   - (K, C) Taylor-series estimates, column j = order j
        TaylorLimit - (K, 1) extrapolated Taylor limit
        HuberEnt    - (K, C) Huber estimates: columns 0-1 hold order 0,
                      columns 2-3 hold order 2
        SplitEnt    - (K, C) Legendre-series estimates, column j = order j
    """
    Dx = 2
    # Column j stores order j for j = 0..Ns[-1]. Sizing by Ns[-1] + 1 avoids
    # out-of-bounds writes when Ns is not exactly range(Ns[-1] + 1).
    ncols = max(len(Ns), Ns[-1]+1)
    TrueEnt = np.zeros((K,1))
    TaylorEnt = np.zeros((K,ncols))
    TaylorLimit = np.zeros((K,1))
    HuberEnt = np.zeros((K,ncols))
    SplitEnt = np.zeros((K,ncols))
    for i, c in enumerate(np.linspace(-3,3,K)):
        #Huber Example
        ws = np.array([0.2, 0.2, 0.2, 0.2 , 0.2])
        mus = [np.array([[0],[0]]), np.array([[3],[2]]), np.array([[1],[-.5]]), np.array([[2.5],[1.5]]),np.array([[c],[c]])]
        sigmas = [np.diag((.16,1)),np.diag((1,.16)),np.diag((.5,.5)),np.diag((.5,.5)),np.diag((.5,.5))]

        # ws1, mus1, sigmas1 = SplittingMethod(20,ws,mus,sigmas)
        # SplitEnt[i,0] = HuberTaylor0Splitting(ws,mus,sigmas,ws1,mus1,sigmas1)
        # SplitEnt[i,1] = SplitEnt[i,0]
        # SplitEnt[i,2] = HuberTaylor2Splitting(ws,mus,sigmas,ws1,mus1,sigmas1)
        # SplitEnt[i,3] = SplitEnt[i,2]
        # fig5 = plotpdfs(ws1,mus1,sigmas1,ws,mus,sigmas)
        # Alternative configuration with a dominant, narrow fifth component:
        # c1 = 1/(6*np.pi)
        # ws = np.array([0.05, 0.05, 0.05, 0.05 , 0.8])
        # mus = [np.array([[0],[0]]), np.array([[3],[2]]), np.array([[1],[-.5]]), np.array([[2.5],[1.5]]),np.array([[c],[c]])]
        # sigmas = [np.diag((.16,1)),np.diag((1,.16)),np.diag((.5,.5)),np.diag((.5,.5)),np.diag((c1,c1))]

        M = len(ws)
        # Sum of component peak heights w_m det(2 pi Sigma_m)^(-1/2); this is
        # an upper bound on the mixture density.
        MaxBound = 0
        for m in range(M):
            MaxBound += ws[m]*np.linalg.det(2*np.pi*sigmas[m])**(-1/2)
        Epp = util.gmm_power_expected_value_parallel(Ns[-1], ws, mus, sigmas, ws, mus, sigmas)

        for j in range(Ns[-1]+1):
            TaylorEnt[i,j] = -1*util.log_Taylor_series(j, MaxBound, Epp)
            SplitEnt[i,j] = -1*util.log_Legendre_series(j, MaxBound, Epp)

        TaylorLimit[i] = -1*util.log_Taylor_limit(Ns[-1], MaxBound, Epp)

        HuberEnt[i,0] = HuberTaylor0(ws,mus,sigmas)
        HuberEnt[i,1] = HuberEnt[i,0]
        HuberEnt[i,2] = HuberTaylor2(ws,mus,sigmas)
        HuberEnt[i,3] = HuberEnt[i,2]

        TrueEnt[i] = MargEntGMM(1000,50,Dx,ws,mus,sigmas)
    return TrueEnt, TaylorEnt, TaylorLimit, HuberEnt, SplitEnt

def HuberCounterExample(K,Ns):
    """Sweep the variance of the second component of a 1-D, two-component GMM.

    The mixture is 0.35 N(-2, 2) + 0.65 N(-1, c1), with c1 = (c + 3)/6 + 0.01
    for K values of c in [-3, 3]. c1 therefore runs from 0.01 to 1.01. For
    each configuration the entropy is estimated by:
      - grid integration (MargEntGMM), returned as TrueEnt;
      - the util Taylor and Legendre series for orders 0..Ns[-1];
      - the util Taylor-limit extrapolation;
      - the order-by-order Huber series (HuberTaylorN).

    Inputs:
        K  - number of sweep positions
        Ns - series orders. Ns[-1] is the highest order computed; it is limited
             by the derivative orders logGMMautoGrad10 returns (0-6).

    Outputs (C = max(len(Ns), Ns[-1] + 1)):
        TrueEnt     - (K, 1) grid-integrated entropy
        TaylorEnt   - (K, C) Taylor-series estimates, column j = order j
        TaylorLimit - (K, 1) extrapolated Taylor limit
        HuberEnt    - (K, C) cumulative Huber partial sums, column j = up to order j
        SplitEnt    - (K, C) Legendre-series estimates, column j = order j
    """
    Dx = 1
    # Column j stores order j for j = 0..Ns[-1]. Sizing by Ns[-1] + 1 avoids
    # out-of-bounds writes when Ns is not exactly range(Ns[-1] + 1).
    ncols = max(len(Ns), Ns[-1]+1)
    TrueEnt = np.zeros((K,1))
    TaylorEnt = np.zeros((K,ncols))
    TaylorLimit = np.zeros((K,1))
    HuberEnt = np.zeros((K,ncols))
    SplitEnt = np.zeros((K,ncols))
    for i, c in enumerate(np.linspace(-3,3,K)):
        c1 = ((c+3)/6)+.01
        ws = np.array([0.35, 0.65])
        mus = [np.array([[-2]]), np.array([[-1]])]
        sigmas = [np.array([[2]]),np.array([[c1]])]
        # fig1,fig2 = plotGMM(ws,mus,sigmas)
        # ws1, mus1, sigmas1 = SplittingMethod(10,ws,mus,sigmas)
        # SplitEnt[i,:] = HuberTaylorNSplit(Ns[-1],ws,mus,sigmas,ws1,mus1,sigmas1)

        M = len(ws)
        # Sum of component peak heights w_m det(2 pi Sigma_m)^(-1/2); this is
        # an upper bound on the mixture density.
        MaxBound = 0
        for m in range(M):
            MaxBound += ws[m]*np.linalg.det(2*np.pi*sigmas[m])**(-1/2)
        Epp = util.gmm_power_expected_value_parallel(Ns[-1], ws, mus, sigmas, ws, mus, sigmas)

        for j in range(Ns[-1]+1):
            TaylorEnt[i,j] = -1*util.log_Taylor_series(j, MaxBound, Epp)
            SplitEnt[i,j] = -1*util.log_Legendre_series(j, MaxBound, Epp)

        TaylorLimit[i] = -1*util.log_Taylor_limit(Ns[-1], MaxBound, Epp)

        # HuberTaylorN returns the per-order terms with shape (1, Ns[-1]+1).
        HuberEnt[i,:Ns[-1]+1] = HuberTaylorN(Ns[-1],ws,mus,sigmas)[0]
        TrueEnt[i] = MargEntGMM(1000,50,Dx,ws,mus,sigmas)
    # Turn the per-order Huber terms into cumulative partial sums.
    HuberEnt = np.cumsum(HuberEnt,axis=1)
    return TrueEnt, TaylorEnt, TaylorLimit, HuberEnt, SplitEnt


def HuberDiverge(K,Ns):
    ###############################################################################
    # Outline: Sweeps the mean of the 5th component of a fixed 5-component,
    #          2-D GMM along the diagonal c*[1,...,1] for c in linspace(-3,3,K),
    #          and evaluates several entropy approximations at every sweep point.
    #
    # Inputs:
    #       K  - Number of sweep points (rows of every output)
    #       Ns - Ascending list of approximation orders; order j is stored in
    #            column j of the per-order outputs for j = 0..Ns[-1]
    #
    # Outputs (W = max(len(Ns), Ns[-1]+1), which equals len(Ns) when
    # Ns == [0, 1, ..., N]):
    #       TrueEnt     - (K,1) reference value from MargEntGMM(1000,50,...)
    #       TaylorEnt   - (K,W) negated util.log_Taylor_series(j, ...) per order j
    #       TaylorLimit - (K,1) negated util.log_Taylor_limit(Ns[-1], ...)
    #       HuberEnt    - (K,W) HuberTaylor0 in columns 0-1, HuberTaylor2 in
    #                     columns 2-3; any further columns are left at zero
    #       SplitEnt    - (K,W) negated util.log_Legendre_series(j, ...) per order j
    ###############################################################################
    Dx = 2
    # One column per order 0..Ns[-1]; never fewer than len(Ns) so the output
    # shape is unchanged for the usual Ns == [0, 1, ..., N].
    width = max(len(Ns), Ns[-1]+1)
    TrueEnt = np.zeros((K,1))
    TaylorEnt = np.zeros((K,width))
    TaylorLimit = np.zeros((K,1))
    HuberEnt = np.zeros((K,width))
    SplitEnt = np.zeros((K,width))

    half_up = int(np.ceil(Dx/2))
    half_down = int(np.floor(Dx/2))
    for i, c in enumerate(np.linspace(-3,3,K)):
        # Four fixed components plus a 5th whose mean c*[1,...,1] is swept.
        ws = np.array([0.2, 0.2, 0.2, 0.2 , 0.2])
        mus = [np.zeros((Dx,1)),
               np.array([[3] * half_up + [2] * half_down]).T,
               np.array([[1] * half_up + [-.5] * half_down]).T,
               np.array([[2.5] * half_up + [1.5] * half_down]).T,
               c*np.ones((Dx,1))]
        sigmas = [.25*np.eye(Dx),3*np.eye(Dx),2*np.eye(Dx),2*np.eye(Dx),2*np.eye(Dx)]

        #Huber Example
        # ws = np.array([0.2, 0.2, 0.2, 0.2 , 0.2])
        # mus = [np.array([[0],[0]]), np.array([[3],[2]]), np.array([[1],[-.5]]), np.array([[2.5],[1.5]]),np.array([[c],[c]])]
        # sigmas = [np.diag((.16,1)),np.diag((1,.16)),np.diag((.5,.5)),np.diag((.5,.5)),np.diag((.5,.5))]

        # Sum of component peak heights w_m * det(2*pi*Sigma_m)^(-1/2). Every
        # Gaussian density is at most its peak, so this upper-bounds the mixture.
        MaxBound = 0
        for w_m, S_m in zip(ws, sigmas):
            MaxBound += w_m*np.linalg.det(2*np.pi*S_m)**(-1/2)
        # NOTE(review): presumably expectations of powers of the GMM density up
        # to order Ns[-1]; confirm against util.
        Epp = util.gmm_power_expected_value_parallel(Ns[-1], ws, mus, sigmas, ws, mus, sigmas)

        for j in range(Ns[-1]+1):
            TaylorEnt[i,j] = -1*util.log_Taylor_series(j, MaxBound, Epp)
            SplitEnt[i,j] = -1*util.log_Legendre_series(j, MaxBound, Epp)

        TaylorLimit[i] = -1*util.log_Taylor_limit(Ns[-1], MaxBound, Epp)

        # Only the 0th and 2nd order Huber et al. expansions are computed here;
        # each odd column repeats the preceding even order.
        h0 = HuberTaylor0(ws,mus,sigmas)
        h2 = HuberTaylor2(ws,mus,sigmas)
        for col, value in enumerate((h0, h0, h2, h2)[:width]):
            HuberEnt[i,col] = value

        TrueEnt[i] = MargEntGMM(1000,50,Dx,ws,mus,sigmas)
    return TrueEnt, TaylorEnt, TaylorLimit, HuberEnt, SplitEnt



def _style_gmm_figure(fig, y_title):
    # Shared cosmetic layout for the plotGMM figures: axis titles, large font,
    # white background, boxed axes with a light-grey grid.
    fig.update_xaxes(title_text="x")
    fig.update_yaxes(title_text=y_title)
    fig.update_layout(font=dict(size=25))
    fig.update_layout(plot_bgcolor='white')
    axis_style = dict(
        mirror=True,
        ticks='outside',
        showline=True,
        linecolor='black',
        gridcolor='lightgrey'
    )
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)


def plotGMM(w,mu,Sigma):
    ###############################################################################
    # Outline: Plots a GMM density on a 1-D grid x in [-7,3], together with
    #          cumulative Taylor expansions built from logGMMautoGrad10 around
    #          each component mean, and a second figure of the log-density.
    #
    # Inputs:
    #       w     - weights of components
    #       mu    - means of components (each flattened for the pdf; the grid
    #               is 1-D, so the mixture is assumed univariate)
    #       Sigma - variances of components
    #
    # Outputs:
    #       fig1 - density plus, per order 0..17, the component-averaged Taylor
    #              sum and the sums around the first two components
    #       fig2 - log-density
    ###############################################################################
    M = len(w)
    x = np.linspace(-7,3,1400)
    GMM = 0
    for i in range(M):
        GMM += w[i]*multivariate_normal.pdf(x,mu[i].flatten(),Sigma[i])
    logGMM = np.log(GMM)

    N = 18  # orders 0..N-1 are accumulated and plotted
    TaylorGMM = np.zeros((N,len(x)))
    TaylorGMMSeperate = np.zeros((N,len(x),M))
    for i in range(M):
        derivative = logGMMautoGrad10(mu[i],w,mu,Sigma)
        for n in range(N):
            # math.factorial: the np.math alias was removed in NumPy 2.0.
            term = ((derivative[n]/math.factorial(n))*(x-mu[i])**n).flatten()
            TaylorGMM[n,:] += term
            TaylorGMMSeperate[n,:,i] += term
    # Running sums over the order axis give the n-th order partial sums.
    logTaylorGMM = np.cumsum(TaylorGMM,axis=0)/M
    logTaylorGMMSeperate = np.cumsum(TaylorGMMSeperate,axis=0)

    # NOTE(review): fig1 overlays the expansion partial sums on the density GMM
    # itself (not logGMM); confirm that is the intended comparison.
    fig1 = go.Figure([
        go.Scatter(
            x=x,
            y=GMM.flatten(),
            line=dict(color='rgb(255,0,0)', width=3),
            mode='lines',
            name='True Ent'
        )], layout_yaxis_range=[-0.1,2])
    for n in range(N):
        shade = n*255/(N-1)
        fig1.add_trace(
                go.Scatter(
                    x=x,
                    y=logTaylorGMM[n,:].flatten(),
                    line=dict(color='rgb(0,%d,0)'%(shade), width=3),
                    mode='lines',
                    name='%d Order'%(n)))
        # Per-component partial sums for the first two components, when present.
        per_component = (
            (0, 'rgb(%d,%d,0)'%(shade,shade), '%d Order 1'%(n)),
            (1, 'rgb(0,0,%d)'%(shade), '%d Order 2'%(n)),
        )
        for comp, color, label in per_component:
            if comp >= M:
                continue
            fig1.add_trace(
                    go.Scatter(
                        x=x,
                        y=logTaylorGMMSeperate[n,:,comp].flatten(),
                        line=dict(color=color, width=3),
                        mode='lines',
                        name=label))
    _style_gmm_figure(fig1, "GMM")
    # fig1.write_image("LargeGMMTimeNew.pdf")
    fig1.show()

    fig2 = go.Figure([
        go.Scatter(
            x=x,
            y=logGMM.flatten(),
            line=dict(color='rgb(255,0,0)', width=3),
            mode='lines',
            name='True Ent'
        )], layout_yaxis_range=[-12,4])
    _style_gmm_figure(fig2, "logGMM")
    fig2.show()
    return fig1,fig2


def SplittingMethod(S,ws,mus,sigmas):
    ###############################################################################
    # Outline: Refines a GMM by repeatedly splitting one component into four.
    #
    #          Each of the S iterations picks the component k with the largest
    #          w_k * lambda_max(Sigma_k), takes that eigenpair (lambda, v), and
    #          replaces the component by four children:
    #              w_i     = wuhat[i] * w_k
    #              mu_i    = mu_k + sqrt(lambda) * muhat[i] * v
    #              Sigma_i = Sigma_k with lambda replaced by lambda * sigmahat**2
    #          Children are appended at the end; the split parent is removed.
    #
    # Inputs:
    #       S      - Number of splitting iterations
    #       ws     - weights of components
    #       mus    - list of component means (arrays, e.g. Dx1)
    #       sigmas - list of component covariances (symmetric DxD arrays)
    #
    # Outputs:
    #       ws     - weights of the refined mixture (K + 3*S entries)
    #       mus    - means of the refined mixture (a new list)
    #       sigmas - covariances of the refined mixture (a new list)
    #
    # The caller's lists are not modified.
    ###############################################################################
    # Four-component split of a standard normal: child weights, unit-variance
    # mean offsets, and the child standard-deviation factor.
    # NOTE(review): these look like precomputed split-library constants from
    # Huber et al.; confirm the source.
    wuhat = np.array([0.12738084098,0.37261915901,0.37261915901,0.12738084098])
    muhat = np.array([-1.4131205233,-0.44973059608,0.44973059608,1.4131205233])
    sigmahat = 0.51751260421

    # Work on copies; the loop appends to and removes from these containers.
    ws = np.asarray(ws)
    mus = list(mus)
    sigmas = list(sigmas)

    for _ in range(S):
        # Component with the largest weighted spread.
        # NOTE: weighting by ws[k] is a modelling choice ("maybe not with ws[k]").
        maxEig = 0
        maxEigcomp = 0
        for k in range(len(ws)):
            maxeig = ws[k]*np.max(np.linalg.eigvalsh(sigmas[k]))
            if maxeig>maxEig:
                maxEig = maxeig
                maxEigcomp = k

        # eigh: covariances are symmetric, so eigenpairs are real.
        eigValue, eigVector = np.linalg.eigh(sigmas[maxEigcomp])
        maxEigpos = np.argmax(eigValue)
        parent_mu = mus[maxEigcomp]
        # Children are spread along the principal eigenvector (not a coordinate
        # axis), matching the eigenbasis in which the covariance is shrunk.
        direction = eigVector[:, maxEigpos].reshape(np.shape(parent_mu))
        offset = np.sqrt(eigValue[maxEigpos])*direction

        Lambdai = eigValue.copy()
        Lambdai[maxEigpos] = Lambdai[maxEigpos]*sigmahat**2
        Sigmai = eigVector @ np.diag(Lambdai) @ eigVector.T

        ws = np.concatenate((ws, wuhat*ws[maxEigcomp]))
        for i in range(4):
            mus.append(parent_mu + muhat[i]*offset)
            sigmas.append(Sigmai.copy())

        ws = np.delete(ws, maxEigcomp)
        del mus[maxEigcomp]
        del sigmas[maxEigcomp]

    return ws, mus, sigmas

# ################################## Huber Example #####################################
# K = 50
# Ns = [0,1,2,3]#,7,8,9
# TrueEnt, TaylorEnt, TaylorLimit, HuberEnt, SplitEnt = HuberExample(K,Ns)
# x1 = np.linspace(-3,3,K)
# xs = x1
# fig = go.Figure([
#         go.Scatter(
#             x=xs,
#             y=TrueEnt.flatten(),
#             line=dict(color='rgb(0,0,0)', width=3),
#             mode='lines',
#             name='True Entropy'
#         # ),
#         # go.Scatter(
#         #     x=xs,
#         #     y=TaylorLimit.flatten(),
#         #     line=dict(color='rgb(255,0,255)', width=3),
#         #     mode='lines',
#         #     name='Our Approx. Limit'
#         )])

# for i in range(len(Ns)):
#     C = 'rgb(0,0,%d)'%(i*155/(len(Ns)-1)+100)
#     D = 'rgb(0,%d,0)'%(i*155/(len(Ns)-1)+100)
#     E = 'rgb(0,%d,%d)'%(i*155/(len(Ns)-1)+100,i*155/(len(Ns)-1)+100)
#     if i == Ns[-1]:
#         fig.add_trace(
#                 go.Scatter(
#                     x=xs,
#                     y=HuberEnt[:,i].flatten(),
#                     line=dict(color='rgb(255,0,0)', width=3),#C
#                     mode='lines',
#                     name='Huber et al.'))
#         fig.add_trace(
#                 go.Scatter(
#                     x=xs,
#                     y=TaylorEnt[:,i].flatten(),
#                     line=dict(color='rgb(0,255,255)', width=3),#D
#                     mode='lines',
#                     name='Taylor'))
#         fig.add_trace(
#                 go.Scatter(
#                     x=xs,
#                     y=TaylorLimit.flatten(),
#                     line=dict(color='rgb(0,0,255)', width=3),
#                     mode='lines',
#                     name='Taylor Limit'))
#         fig.add_trace(
#                 go.Scatter(
#                     x=xs,
#                     y=SplitEnt[:,i].flatten(),
#                     line=dict(color='rgb(0,255,0)', width=3),#E
#                     mode='lines',
#                     name='Legendre'))


#     # else:
#     #     fig.add_trace(
#     #             go.Scatter(
#     #                 x=xs,
#     #                 y=TaylorEnt[:,i].flatten(),
#     #                 line=dict(color=D, width=3),
#     #                 mode='lines',
#     #                 showlegend=False))
#     #     fig.add_trace(
#     #             go.Scatter(
#     #                 x=xs,
#     #                 y=HuberEnt[:,i].flatten(),
#     #                 line=dict(color=C, width=3),
#     #                 mode='lines',
#     #                 showlegend=False))
#     #     fig.add_trace(
#     #             go.Scatter(
#     #                 x=xs,
#     #                 y=SplitEnt[:,i].flatten(),
#     #                 line=dict(color=E, width=3),
#     #                 mode='lines',
#     #                 showlegend=False))
# fig.update_xaxes(title_text="5th Gaussian Component Mean")#, type="log", dtick = "D2"
# fig.update_yaxes(title_text="H(x)")#, type="log", dtick = 1
# #fig1.update_layout(paper_bgcolor='rgba(0,0,0,0)',plot_bgcolor='rgba(0,0,0,0)')
# fig.update_layout(font=dict(size=25),showlegend=False)#,legend=dict(orientation="h",xanchor="center",x=0.5,yanchor="bottom", y=0.1),showlegend=False
# fig.update_layout(plot_bgcolor='white')
# fig.update_xaxes(
#     mirror=True,
#     ticks='outside',
#     showline=True,
#     linecolor='black',
#     gridcolor='lightgrey'
# )
# fig.update_yaxes(
#     #range = [0,3.5],
#     mirror=True,
#     ticks='outside',
#     showline=True,
#     linecolor='black',
#     gridcolor='lightgrey'
# )
# # fig.write_image("HuberExample.pdf")
# fig.show()


# ################################## Huber Counter Example ###########################
# K = 50
# Ns = [0,1,2,3,4,5,6]#,7,8,9
# TrueEnt, TaylorEnt, TaylorLimit, HuberEnt, SplitEnt = HuberCounterExample(K,Ns) #EntTaylor,
# x1 = np.linspace(-3,3,K)
# x2 =np.linspace(0,1,K)+.01
# xs = x2
# fig1 = go.Figure([
#         go.Scatter(
#             x=xs,
#             y=TrueEnt.flatten(),
#             line=dict(color='rgb(0,0,0)', width=3),
#             mode='lines',
#             name='True Entropy'
#         # ),
#         # go.Scatter(
#         #     x=xs,
#         #     y=TaylorLimit.flatten(),
#         #     line=dict(color='rgb(255,0,255)', width=3),
#         #     mode='lines',
#         #     name='Our Approx. Limit'
#         )])

# for i in range(len(Ns)):#[::2]
#     C = 'rgb(%d,0,0)'%(i*180/(len(Ns)-1)+75)
#     D = 'rgb(0,%d,%d)'%(i*180/(len(Ns)-1)+75,i*180/(len(Ns)-1)+75)
#     E = 'rgb(0,%d,0)'%(i*180/(len(Ns)-1)+75)
#     if i == Ns[-3]:
#         fig1.add_trace(
#                 go.Scatter(
#                     x=xs,
#                     y=HuberEnt[:,i].flatten(),
#                     line=dict(color='rgb(255,0,0)', width=3),
#                     mode='lines',
#                     name='Huber et al.'))
#         fig1.add_trace(
#                 go.Scatter(
#                     x=xs,
#                     y=TaylorEnt[:,i].flatten(),
#                     line=dict(color='rgb(0,255,255)', width=3),
#                     mode='lines',
#                     name='Our Taylor'))
#         fig1.add_trace(
#                 go.Scatter(
#                     x=xs,
#                     y=TaylorLimit.flatten(),
#                     line=dict(color='rgb(0,0,255)', width=3),
#                     mode='lines',
#                     name='Taylor Limit'))
#         fig1.add_trace(
#                 go.Scatter(
#                     x=xs,
#                     y=SplitEnt[:,i].flatten(),
#                     line=dict(color='rgb(0,255,0)', width=3),
#                     mode='lines',
#                     name='Our Legendre'))

#     # else:
#     #     fig1.add_trace(
#     #             go.Scatter(
#     #                 x=xs,
#     #                 y=TaylorEnt[:,i].flatten(),
#     #                 line=dict(color=D, width=3),
#     #                 mode='lines',
#     #                 showlegend=False))
#     #     fig1.add_trace(
#     #             go.Scatter(
#     #                 x=xs,
#     #                 y=HuberEnt[:,i].flatten(),
#     #                 line=dict(color=C, width=3),
#     #                 mode='lines',
#     #                 showlegend=False))
#     #     fig1.add_trace(
#     #             go.Scatter(
#     #                 x=xs,
#     #                 y=SplitEnt[:,i].flatten(),
#     #                 line=dict(color=E, width=3),
#     #                 mode='lines',
#     #                 showlegend=False))
        
#         # fig1.add_trace(
#         #         go.Scatter(
#         #             x=xs,
#         #             y=SplitEnt[:,i].flatten(),
#         #             line=dict(color=E, width=3),
#         #             mode='lines',
#         #             name='OurSplit not final'))
#         # fig1.add_trace(
#         #         go.Scatter(
#         #             x=xs,
#         #             y=HuberEnt[:,i].flatten(),
#         #             line=dict(color=C, width=3),
#         #             mode='lines',
#         #             name='Huber not final'))
#         # fig1.add_trace(
#         #         go.Scatter(
#         #             x=xs,
#         #             y=TaylorEnt[:,i].flatten(),
#         #             line=dict(color=D, width=3),
#         #             mode='lines',
#         #             name='HuberSplit not final'))
# fig1.update_xaxes(title_text="2nd Gaussian Variance")#"5th Gaussian Component Mean", type="log", dtick = "D2"
# fig1.update_yaxes(title_text="H(x)")#, type="log", dtick = 1
# #fig1.update_layout(paper_bgcolor='rgba(0,0,0,0)',plot_bgcolor='rgba(0,0,0,0)')
# fig1.update_layout(font=dict(size=25),legend=dict(yanchor="bottom", y=0.01, xanchor="right", x=0.95))#,showlegend=False
# fig1.update_layout(plot_bgcolor='white')
# fig1.update_xaxes(
#     mirror=True,
#     ticks='outside',
#     showline=True,
#     linecolor='black',
#     gridcolor='lightgrey'
# )
# fig1.update_yaxes(
#     range = [-.5,1.75],
#     mirror=True,
#     ticks='outside',
#     showline=True,
#     linecolor='black',
#     gridcolor='lightgrey'
# )
# # fig1.write_image("HuberCounterExample.pdf")
# fig1.show()


################################## Huber Diverge #####################################
# Sweep the 5th component mean of the HuberDiverge mixture and plot the reference
# entropy against the highest-order result of each approximation method.
# NOTE(review): this runs at import time and writes HuberDiverge.pdf; consider
# moving it under `if __name__ == "__main__":`.
K = 100
Ns = [0,1,2,3]#,7,8,9
TrueEnt, TaylorEnt, TaylorLimit, HuberEnt, SplitEnt = HuberDiverge(K,Ns)
x1 = np.linspace(-3,3,K)
xs = x1

# HuberDiverge stores order j in column j, so the highest order is column Ns[-1].
_top = Ns[-1]
fig2 = go.Figure([
        go.Scatter(
            x=xs,
            y=TrueEnt.flatten(),
            line=dict(color='rgb(0,0,0)', width=3),
            mode='lines',
            name='True Entropy'
        )])
_approximations = (
    (HuberEnt[:,_top], 'rgb(255,0,0)', 'Huber et al.'),
    (TaylorEnt[:,_top], 'rgb(0,255,255)', 'Taylor'),
    (TaylorLimit, 'rgb(0,0,255)', 'Taylor Limit'),
    (SplitEnt[:,_top], 'rgb(0,255,0)', 'Legendre'),
)
for _y, _color, _name in _approximations:
    fig2.add_trace(
            go.Scatter(
                x=xs,
                y=_y.flatten(),
                line=dict(color=_color, width=3),
                mode='lines',
                name=_name))

fig2.update_xaxes(title_text="5th Gaussian Component Mean")#, type="log", dtick = "D2"
fig2.update_yaxes(title_text="H(x)")#, type="log", dtick = 1
fig2.update_layout(font=dict(size=25),showlegend=False)#,legend=dict(orientation="h",xanchor="center",x=0.5,yanchor="bottom", y=0.1)
fig2.update_layout(plot_bgcolor='white')
fig2.update_xaxes(
    mirror=True,
    ticks='outside',
    showline=True,
    linecolor='black',
    gridcolor='lightgrey'
)
fig2.update_yaxes(
    #range = [0,3.5],
    mirror=True,
    ticks='outside',
    showline=True,
    linecolor='black',
    gridcolor='lightgrey'
)
fig2.write_image("HuberDiverge.pdf")
fig2.show()