# running the score neural operator on 2D-GMMs datasets
import os
import jax
import jax.numpy as np
from jax import random, grad, vmap, pmap, jit, lax, devices
from jax.example_libraries import stax, optimizers
from jax.example_libraries.stax import Dense, Gelu
from jax.flatten_util import ravel_pytree
import numpy as onp
import itertools
from functools import partial
from torch.utils import data
from tqdm.auto import trange
import matplotlib.pyplot as plt
from scipy import integrate
import matplotlib.patches as patches



# Tell JAX/XLA not to preallocate most of the GPU memory at startup (see the
# JAX GPU memory allocation docs). This has to happen before the JAX backend
# is initialized; no JAX computation has run at this point of the script.
os.environ.update(XLA_PYTHON_CLIENT_PREALLOCATE="False")
def marginal_prob_std(t, sigma):
  r"""Standard deviation of the perturbation kernel $p_{0t}(x(t) | x(0))$.

  With diffusion coefficient g(t) = sigma**t (see `diffusion_coeff`), the
  kernel's variance is (sigma**(2t) - 1) / (2 log sigma); this function
  returns its square root. Only the standard deviation is returned.

  Args:
    t: Time step(s); a scalar or an array.
    sigma: The $\sigma$ of the SDE (positive and different from 1).

  Returns:
    The standard deviation, broadcast over `t` and `sigma`.
  """
  variance = (sigma**(2 * t) - 1.) / 2. / np.log(sigma)
  return np.sqrt(variance)

def diffusion_coeff(t, sigma):
  r"""Diffusion coefficient g(t) = sigma**t of the SDE.

  Args:
    t: Time step(s); a scalar or an array.
    sigma: The $\sigma$ of the SDE.

  Returns:
    `sigma` raised to the power `t` (elementwise for arrays).
  """
  g = sigma**t
  return g



# Data generator
class DataGenerator(data.Dataset):
    """Random mini-batch sampler over (embedding, point cloud, sigma) triples.

    ``dataset[index]`` ignores ``index`` and returns a fresh random batch,
    advancing the internal PRNG key each call. Because ``__getitem__`` never
    raises, ``iter(dataset)`` yields batches indefinitely.

    Args:
        u: Per-example embeddings, indexed as ``u[idx, :]``.
        x: Per-example point clouds of shape ``(N, N_sample, dim)``.
        sigma_SDE: Per-example SDE sigma values, indexed along axis 0.
        batch_size1: Examples per batch, drawn without replacement.
        batch_size2: Points per example, drawn without replacement (the same
            point indices are used for every example in the batch).
        rng_key: Initial JAX PRNG key.
    """
    def __init__(self, u, x, sigma_SDE,
                 batch_size1=64, batch_size2=256, rng_key=random.PRNGKey(1234)):
        'Initialization'
        self.u = u # input sample
        self.x = x # location
        self.N_sample = x.shape[1]
        self.sigma_SDE = sigma_SDE
        self.N = u.shape[0]
        self.batch_size1 = batch_size1
        self.batch_size2 = batch_size2
        self.key = rng_key


    def __getitem__(self, index):
        'Generate one batch of data'
        self.key, subkey = random.split(self.key)
        inputs = self.__data_generation(subkey)
        return inputs

    def __data_generation(self, key):
        """Draw one batch ``(u, x, sigma_SDE)`` using ``key``.

        Returns ``u`` rows, ``x`` of shape ``(batch_size1, batch_size2, dim)``
        and the matching ``sigma_SDE`` entries.
        """
        # Two independent subkeys: never reuse `key` itself after splitting it.
        key_examples, key_points = random.split(key)
        idx = random.choice(key_examples, self.N, (self.batch_size1,), replace=False)
        idx_2 = random.choice(key_points, self.N_sample, (self.batch_size2,), replace=False)

        # Broadcast the two index vectors so the (examples, points) sub-grid is
        # gathered in one step without materializing x[idx] first.
        x = self.x[idx[:, None], idx_2[None, :], :]
        u = self.u[idx, :]
        sigma_SDE = self.sigma_SDE[idx]
        inputs = (u, x, sigma_SDE)
        return inputs


    
def init_NN(Q, activation=Gelu):
    layers = []
    num_layers = len(Q)
    if num_layers < 2:
        net_init, net_apply = stax.serial()
    else:
        for i in range(0, num_layers-2):
            layers.append(Dense(Q[i+1]))
            layers.append(activation)
        layers.append(Dense(Q[-1]))
        net_init, net_apply = stax.serial(*layers)
    return net_init, net_apply


# Define the neural net
def Encoder(layers):
    """Plain MLP with ReLU hidden activations and a linear output layer.

    Returns an ``(init, apply)`` pair. ``init(rng_key)`` builds one
    ``(W, b)`` tuple per consecutive width pair in ``layers``, with
    Glorot-normal weights and zero biases. ``apply(params, inputs)`` runs the
    forward pass.
    """
    def init(rng_key):
        def make_layer(layer_key, fan_in, fan_out):
            w_key, _ = random.split(layer_key)
            scale = 1. / np.sqrt((fan_in + fan_out) / 2.)
            return scale * random.normal(w_key, (fan_in, fan_out)), np.zeros(fan_out)
        # The first of the len(layers) keys is discarded; one key per layer remains.
        layer_keys = random.split(rng_key, len(layers))[1:]
        return [make_layer(k, d_in, d_out)
                for k, d_in, d_out in zip(layer_keys, layers[:-1], layers[1:])]

    def apply(params, inputs):
        hidden = inputs
        for weight, bias in params[:-1]:
            hidden = jax.nn.relu(np.dot(hidden, weight) + bias)
        weight_out, bias_out = params[-1]
        return np.dot(hidden, weight_out) + bias_out

    return init, apply


def Decoder(layers):
    """MLP with tanh hidden layers and two linear output heads.

    ``init(rng_key)`` returns ``(W, b)`` pairs for the hidden layers, followed
    by two head layers of shape ``(layers[-2], layers[-1])``. ``apply(params,
    inputs)`` returns ``(mu, Sigma)``. ``mu`` uses the last parameter pair and
    ``Sigma`` uses the second-to-last.
    """
    def init(rng_key):
        def make_layer(layer_key, fan_in, fan_out):
            w_key, _ = random.split(layer_key)
            scale = 1. / np.sqrt((fan_in + fan_out) / 2.)
            return scale * random.normal(w_key, (fan_in, fan_out)), np.zeros(fan_out)
        head_key, *layer_keys = random.split(rng_key, len(layers))
        params = [make_layer(k, d_in, d_out)
                  for k, d_in, d_out in zip(layer_keys, layers[:-2], layers[1:-1])]
        # Two heads sharing the last hidden representation.
        sigma_head_key, mu_head_key = random.split(head_key)
        params.append(make_layer(sigma_head_key, layers[-2], layers[-1]))
        params.append(make_layer(mu_head_key, layers[-2], layers[-1]))
        return params

    def apply(params, inputs):
        hidden = inputs
        for weight, bias in params[:-2]:
            hidden = np.tanh(np.dot(hidden, weight) + bias)
        weight_sigma, bias_sigma = params[-2]
        weight_mu, bias_mu = params[-1]
        return np.dot(hidden, weight_mu) + bias_mu, np.dot(hidden, weight_sigma) + bias_sigma

    return init, apply


# Define the model
class NOMAD:
    """Score operator: branch net (embedding ``u``) x trunk net (Fourier features of ``(x, t)``).

    The branch net encodes the distribution embedding ``u``. The trunk net
    encodes random Fourier features of the concatenated ``(x, t)``. Their
    elementwise product (passed through tanh) is decoded by the output net
    into the score at ``x``. Parameters are trained with Adam on the
    denoising score-matching loss in ``loss_operator``.

    Args:
        branch_layers, trunk_layers, out_layers: Layer widths passed to ``init_NN``.
        embed_dim: Number of Fourier features fed to the trunk net
            (``embed_dim // 2`` cosines plus as many sines); should match
            ``trunk_layers[0]``.
        scale: Multiplier of the standard-normal Fourier frequencies.
        dim: Dimension of a point ``x``; the frequency matrix has ``dim + 1``
            rows (``x`` plus time). Defaults to 2.
    """
    def __init__(self, branch_layers, trunk_layers, out_layers, embed_dim = 256, scale=10, dim=2):
        # Network initialization and evaluation functions
        self.branch_init, self.branch_apply = init_NN(branch_layers, activation=Gelu)
        self.trunk_init, self.trunk_apply = init_NN(trunk_layers, activation=Gelu)
        self.out_init, self.out_apply = init_NN(out_layers, activation=Gelu)
        # Initialize with fixed seeds
        in_shape = (-1, branch_layers[0])
        _, branch_params = self.branch_init(random.PRNGKey(1234), in_shape)

        in_shape = (-1, trunk_layers[0])
        _, trunk_params = self.trunk_init(random.PRNGKey(4321), in_shape)

        in_shape = (-1, out_layers[0])
        _, out_params = self.out_init(random.PRNGKey(4325), in_shape)

        params = (branch_params, trunk_params, out_params)

        # Adam with an exponentially decaying step size (1e-3, x0.99 per 1000 steps)
        self.opt_init, \
        self.opt_update, \
        self.get_params = optimizers.adam(optimizers.exponential_decay(1e-3,
                                                                      decay_steps=1000,
                                                                      decay_rate=0.99))
        self.opt_state = self.opt_init(params)
        # Fixed (untrained) Fourier frequencies for the concatenated (x, t) input.
        self.W = scale * random.normal(random.PRNGKey(1225), (dim + 1, embed_dim//2))
        # Used to restore the trained model parameters
        _, self.unravel_params = ravel_pytree(params)

        self.itercount = itertools.count()

        # Loggers
        self.loss_log = []


    def operator_net(self, params, u, x, t):
        """Score estimate at one point ``x`` and time ``t`` for the distribution embedded by ``u``.

        Args:
            params: ``(branch_params, trunk_params, out_params)``.
            u: Embedding fed to the branch net.
            x: A single point, shape ``(dim,)``.
            t: Scalar time (reshaped to ``(1,)``).

        Returns:
            The output-net prediction, shape ``(out_layers[-1],)``.
        """
        branch_params, trunk_params, out_params = params
        t = t.reshape(-1)
        y = np.concatenate([x, t])
        z = np.concatenate([np.cos(np.dot(y,self.W)), np.sin(np.dot(y,self.W))])
        B = self.branch_apply(branch_params, u).reshape(-1)

        C = self.trunk_apply(trunk_params, z).reshape(-1)
        D = B * C
        D = np.tanh(D)
        outputs = self.out_apply(out_params, D)
        return  outputs


    def loss_operator(self, params, batch, rng):
        """Denoising score-matching loss (Song et al.) for one mini-batch.

        For every point, draws t ~ U(1e-5, 1) and z ~ N(0, I), perturbs
        x -> x + std(t) * z, and returns the mean over examples of
        sum over points and coordinates of (score * std + z)**2.
        """
        # u: (batch_size1, embedding size); x: (batch_size1, batch_size2, dim);
        # sigma_SDE: (batch_size1,) or (batch_size1, 1); t, std: (batch_size1, batch_size2)
        u, x, sigma_SDE = batch
        sigma_SDE = sigma_SDE.reshape(-1)
        rng, step_rng = random.split(rng)
        t = random.uniform(step_rng, (x.shape[0],x.shape[1]), minval=1e-5, maxval=1.)
        rng, step_rng = random.split(rng)
        z = random.normal(step_rng, x.shape)
        std = marginal_prob_std(t, sigma_SDE[:,None])
        perturbed_x = x + z * std[:,:,None]
        # Inner vmap: points of one example; outer vmap: examples (with their own u, t rows).
        score =  vmap(vmap(self.operator_net, (None, None, 0, 0)),(None,0,0,0))(params, u, perturbed_x, t)
        loss = np.mean(np.sum((score * std[:,:,None] + z)**2, axis=(1,2)))
        return loss

    @partial(jit, static_argnums=(0,))
    def step(self, i, opt_state, batch, rng):
        """One Adam update on ``batch``; returns the new optimizer state."""
        params = self.get_params(opt_state)
        g = grad(self.loss_operator)(params, batch, rng)
        return self.opt_update(i, g, opt_state)

    # Optimize parameters in a loop
    def train(self, dataset, nIter = 10000):
        """Run ``nIter`` optimization steps, taking one batch per step from ``iter(dataset)``.

        Every 10 iterations the loss on the current batch (with the updated
        parameters) is appended to ``self.loss_log`` and shown on the progress bar.
        """
        # Define data iterators
        data_iterator = iter(dataset)
        rng = random.PRNGKey(2)
        # Compile the logging-only loss once instead of running it op-by-op.
        loss_fn = jit(self.loss_operator)
        pbar = trange(nIter)
        # Main training loop
        for it in pbar:
            rng, subrng = random.split(rng, 2)
            # Fetch data
            batch = next(data_iterator)
            self.opt_state = self.step(next(self.itercount), self.opt_state, batch, subrng)

            if it % 10 == 0:
                params = self.get_params(self.opt_state)

                # Compute loss
                loss_value = loss_fn(params, batch, subrng)

                # Store loss
                self.loss_log.append(loss_value)
                # Print loss
                pbar.set_postfix({'Loss': loss_value})





def k_fn(x, y, sigma=1.2):
    """Kernel value exp(-0.5 * sum(|x - y| / sigma)) for two points.

    NOTE(review): the exponent uses the L1 distance (sum of absolute scaled
    differences), i.e. a Laplacian-type kernel, not a squared-distance
    Gaussian RBF. Confirm this is intended.
    """
    scaled_l1 = np.sum(abs((x - y) / sigma))
    return np.exp(-0.5 * scaled_l1)


# Pairwise kernel matrix: kernel(X, Y)[i, j] = k_fn(X[i], Y[j]).
kernel = vmap(vmap(k_fn, in_axes=(None, 0)), in_axes=(0, None))


@jit
def kme(x, y):
    """Mean of the pairwise kernel matrix between sample sets ``x`` and ``y``."""
    gram = kernel(x, y)
    return np.mean(gram)

def euclid_distance(x,y):
    """Euclidean distance ||x - y||_2 between two vectors.

    Computed from the difference directly. The expansion
    ||x||^2 + ||y||^2 - 2<x, y> cancels catastrophically when x and y are
    close relative to their norms, and can even go negative.
    """
    diff = x - y
    return np.sqrt(np.sum(diff * diff))

# Pairwise distance matrix: distance(X, Y)[i, j] = euclid_distance(X[i], Y[j]).
distance = vmap(vmap(euclid_distance, in_axes=(None,0)), in_axes=(0,None))





def generate_mu_train_1(key, num_mixtures):
    """Sample centres laid out as a row on the left and a column on the right.

    The first half of the centres share one y drawn from ``arange(-6, 7, 2.4)``
    and take distinct negative x values from ``arange(-6, 0, 2.4)``. The
    second half share one positive x from ``arange(1.2, 7, 2.4)`` and take
    distinct y values from ``arange(-6, 7, 2.4)``.

    Args:
        key: JAX PRNG key.
        num_mixtures: Total number of centres; expected to be even, with
            ``num_mixtures // 2`` at most 3 (the size of the half-plane grids).

    Returns:
        Array of shape ``(2 * (num_mixtures // 2), 2)``, left-row centres first.
    """
    # left row right col
    half = num_mixtures // 2
    keys = random.split(key, 4)
    y = random.choice(keys[0], np.arange(-6,7,2.4), (1,), replace=False)
    y = np.tile(y, (half, 1))
    # `half` distinct values so both coordinate columns have matching lengths.
    x = random.choice(keys[1], np.arange(-6,0,2.4), (half, 1), replace=False)
    left_row = np.concatenate([x,y],axis=1)
    x = random.choice(keys[2], np.arange(1.2,7,2.4), (1,), replace=False)
    x = np.tile(x, (half, 1))
    y = random.choice(keys[3], np.arange(-6,7,2.4), (half, 1), replace=False)
    right_col = np.concatenate([x,y],axis=1)
    return np.concatenate([left_row, right_col],axis=0)

def generate_mu_train_2(key, num_mixtures):
    """Sample centres laid out as a column on the left and a row on the right.

    The right half share one y from ``arange(-6, 7, 2.4)`` with distinct
    positive x values from ``arange(1.2, 7, 2.4)``. The left half share one
    negative x from ``arange(-6, 0, 2.4)`` with distinct y values from
    ``arange(-6, 7, 2.4)``.

    Args:
        key: JAX PRNG key.
        num_mixtures: Total number of centres; expected to be even, with
            ``num_mixtures // 2`` at most 3.

    Returns:
        Array of shape ``(2 * (num_mixtures // 2), 2)``, left-column centres first.
    """
    # left col right row
    half = num_mixtures // 2
    keys = random.split(key, 4)
    y = random.choice(keys[0], np.arange(-6,7,2.4), (1,), replace=False)
    y = np.tile(y, (half, 1))
    # `half` distinct values so both coordinate columns have matching lengths.
    x = random.choice(keys[1], np.arange(1.2,7,2.4), (half, 1), replace=False)
    right_row = np.concatenate([x,y],axis=1)
    x = random.choice(keys[2], np.arange(-6,0,2.4), (1,), replace=False)
    x = np.tile(x, (half, 1))
    y = random.choice(keys[3], np.arange(-6,7,2.4), (half, 1), replace=False)
    left_col = np.concatenate([x,y],axis=1)
    return np.concatenate([left_col, right_row],axis=0)

def generate_mu_test_1(key, num_mixtures):
    """Sample centres laid out as a row on the left and a row on the right.

    Each half shares one y from ``arange(-6, 7, 2.4)``. The right half takes
    distinct x values from ``arange(1.2, 7, 2.4)``; the left half takes
    distinct x values from ``arange(-6, 0, 2.4)``.

    Args:
        key: JAX PRNG key.
        num_mixtures: Total number of centres; expected to be even, with
            ``num_mixtures // 2`` at most 3.

    Returns:
        Array of shape ``(2 * (num_mixtures // 2), 2)``, left-row centres first.
    """
    # left row right row
    half = num_mixtures // 2
    keys = random.split(key, 4)
    y = random.choice(keys[0], np.arange(-6,7,2.4), (1,), replace=False)
    y = np.tile(y, (half, 1))
    # `half` distinct values so both coordinate columns have matching lengths.
    x = random.choice(keys[1], np.arange(1.2,7,2.4), (half, 1), replace=False)
    right_row = np.concatenate([x,y],axis=1)
    y = random.choice(keys[2], np.arange(-6,7,2.4), (1,), replace=False)
    y = np.tile(y, (half, 1))
    x = random.choice(keys[3], np.arange(-6,0,2.4), (half, 1), replace=False)
    left_row = np.concatenate([x,y],axis=1)
    return np.concatenate([left_row, right_row],axis=0)

def generate_mu_test_2(key, num_mixtures):
    """Sample centres laid out as a column on the left and a column on the right.

    The left half shares one x from ``arange(-6, 0, 2.4)`` and the right half
    one x from ``arange(1.2, 7, 2.4)``. Each half takes distinct y values
    from ``arange(-6, 7, 2.4)``.

    Args:
        key: JAX PRNG key.
        num_mixtures: Total number of centres; expected to be even, with
            ``num_mixtures // 2`` at most 6.

    Returns:
        Array of shape ``(2 * (num_mixtures // 2), 2)``, right-column centres first.
    """
    # left col right col
    half = num_mixtures // 2
    keys = random.split(key, 4)
    x = random.choice(keys[0], np.arange(-6,0,2.4), (1,), replace=False)
    x = np.tile(x, (half, 1))
    # `half` distinct values so both coordinate columns have matching lengths.
    y = random.choice(keys[1], np.arange(-6,7,2.4), (half, 1), replace=False)
    left_col = np.concatenate([x,y],axis=1)
    x = random.choice(keys[2], np.arange(1.2,7,2.4), (1,), replace=False)
    x = np.tile(x, (half, 1))
    y = random.choice(keys[3], np.arange(-6,7,2.4), (half, 1), replace=False)
    right_col = np.concatenate([x,y],axis=1)
    return np.concatenate([right_col, left_col],axis=0)

def generate_one_Gaussian_mixture_train_1(key, num_samples, num_mixtures, dim):
    """Sample one point cloud whose centres come from ``generate_mu_train_1``.

    Despite the name, each component is a uniform distribution on the square
    ``[mu - 1.2, mu + 1.2]^2`` around its centre, not a Gaussian. Every
    component contributes ``num_samples // num_mixtures`` points, so the result
    has ``num_samples`` rows only when ``num_mixtures`` divides it.

    Returns:
        Array of shape ``(num_mixtures * (num_samples // num_mixtures), 2)``.
    """
    key, mu_key = random.split(key)
    centres = generate_mu_train_1(mu_key, num_mixtures).reshape(num_mixtures, dim)
    per_component = num_samples // num_mixtures
    chunks = []
    for centre in centres:
        key, box_key = random.split(key)
        chunks.append(random.uniform(box_key, (per_component, 2),
                                     minval=centre - 1.2, maxval=centre + 1.2))
    return np.array(chunks).reshape(-1, 2)

def generate_one_Gaussian_mixture_train_2(key, num_samples, num_mixtures, dim):
    """Sample one point cloud whose centres come from ``generate_mu_train_2``.

    Despite the name, each component is a uniform distribution on the square
    ``[mu - 1.2, mu + 1.2]^2`` around its centre, not a Gaussian. Every
    component contributes ``num_samples // num_mixtures`` points.

    Returns:
        Array of shape ``(num_mixtures * (num_samples // num_mixtures), 2)``.
    """
    key, mu_key = random.split(key)
    centres = generate_mu_train_2(mu_key, num_mixtures).reshape(num_mixtures, dim)
    per_component = num_samples // num_mixtures
    chunks = []
    for centre in centres:
        key, box_key = random.split(key)
        chunks.append(random.uniform(box_key, (per_component, 2),
                                     minval=centre - 1.2, maxval=centre + 1.2))
    return np.array(chunks).reshape(-1, 2)

def generate_one_Gaussian_mixture_test_1(key, num_samples, num_mixtures, dim):
    """Sample one point cloud whose centres come from ``generate_mu_test_1``.

    Despite the name, each component is a uniform distribution on the square
    ``[mu - 1.2, mu + 1.2]^2`` around its centre, not a Gaussian. Every
    component contributes ``num_samples // num_mixtures`` points.

    Returns:
        Array of shape ``(num_mixtures * (num_samples // num_mixtures), 2)``.
    """
    key, mu_key = random.split(key)
    centres = generate_mu_test_1(mu_key, num_mixtures).reshape(num_mixtures, dim)
    per_component = num_samples // num_mixtures
    chunks = []
    for centre in centres:
        key, box_key = random.split(key)
        chunks.append(random.uniform(box_key, (per_component, 2),
                                     minval=centre - 1.2, maxval=centre + 1.2))
    return np.array(chunks).reshape(-1, 2)

def generate_one_Gaussian_mixture_test_2(key, num_samples, num_mixtures, dim):
    """Sample one point cloud whose centres come from ``generate_mu_test_2``.

    Despite the name, each component is a uniform distribution on the square
    ``[mu - 1.2, mu + 1.2]^2`` around its centre, not a Gaussian. Every
    component contributes ``num_samples // num_mixtures`` points.

    Returns:
        Array of shape ``(num_mixtures * (num_samples // num_mixtures), 2)``.
    """
    key, mu_key = random.split(key)
    centres = generate_mu_test_2(mu_key, num_mixtures).reshape(num_mixtures, dim)
    per_component = num_samples // num_mixtures
    chunks = []
    for centre in centres:
        key, box_key = random.split(key)
        chunks.append(random.uniform(box_key, (per_component, 2),
                                     minval=centre - 1.2, maxval=centre + 1.2))
    return np.array(chunks).reshape(-1, 2)

def generate_training_data(key, num_examples, num_samples, num_mixtures, dim):
    """Generate ``num_examples`` training point clouds, half from each training layout.

    The first ``num_examples // 2`` clouds use ``generate_one_Gaussian_mixture_train_1``
    and the remaining ones use ``generate_one_Gaussian_mixture_train_2``.

    Returns:
        Array of shape ``(num_examples, num_samples, dim)``.
    """
    keys = random.split(key, num_examples)
    half = num_examples // 2
    first_keys, second_keys = keys[:half], keys[half:]
    in_axes = (0, None, None, None)
    # Reshape each half by its actual key count: for odd num_examples the
    # second half holds one more example than num_examples // 2.
    x_train_1 = vmap(generate_one_Gaussian_mixture_train_1, in_axes)(
        first_keys, num_samples, num_mixtures, dim).reshape(first_keys.shape[0], num_samples, dim)
    x_train_2 = vmap(generate_one_Gaussian_mixture_train_2, in_axes)(
        second_keys, num_samples, num_mixtures, dim).reshape(second_keys.shape[0], num_samples, dim)
    x_train = np.concatenate([x_train_1,x_train_2], axis=0)
    return x_train

def generate_testing_data(key, num_examples, num_samples, num_mixtures, dim):
    """Generate ``num_examples`` test point clouds, half from each test layout.

    The first ``num_examples // 2`` clouds use ``generate_one_Gaussian_mixture_test_1``
    and the remaining ones use ``generate_one_Gaussian_mixture_test_2``.

    Returns:
        Array of shape ``(num_examples, num_samples, dim)``.
    """
    keys = random.split(key, num_examples)
    half = num_examples // 2
    first_keys, second_keys = keys[:half], keys[half:]
    in_axes = (0, None, None, None)
    # Reshape each half by its actual key count: for odd num_examples the
    # second half holds one more example than num_examples // 2.
    x_test_1 = vmap(generate_one_Gaussian_mixture_test_1, in_axes)(
        first_keys, num_samples, num_mixtures, dim).reshape(first_keys.shape[0], num_samples, dim)
    x_test_2 = vmap(generate_one_Gaussian_mixture_test_2, in_axes)(
        second_keys, num_samples, num_mixtures, dim).reshape(second_keys.shape[0], num_samples, dim)
    x_test = np.concatenate([x_test_1,x_test_2], axis=0)
    return x_test


num_devices = len(devices())
num_examples = 2000
num_samples = 2000
num_mixtures = 4
num_sensors = 10   # probability embedding u dimension
dim = 2         # dim of training examples (dataset)


key = random.PRNGKey(1512)
# Create data set
batch_size1 = 64   # distributions per training mini-batch
batch_size2 = 256  # points per distribution per training mini-batch

pretrained = False
if pretrained:
    x_train = np.load("data/2DGMM_x_train.npy")
else:
    # Generate raw (unnormalized) training clouds only when they are not
    # loaded from disk; the work is split evenly across all devices.
    if num_examples % num_devices:
        raise ValueError('num_examples (%d) must be divisible by the number of devices (%d).'
                         % (num_examples, num_devices))
    keys = random.split(key, num_devices)
    num_examples_per_device = num_examples // num_devices
    gen_fn = lambda key: generate_training_data(key, num_examples_per_device, num_samples, num_mixtures, dim)
    x_train = pmap(gen_fn, axis_name='i')(keys)
    x_train = x_train.reshape(num_examples, num_samples, dim)
    os.makedirs("data", exist_ok=True)
    np.save("data/2DGMM_x_train.npy", x_train)

print(x_train.shape)

# Normalize
# Each training cloud is standardized with its own per-coordinate mean/std
# (axis=1 is the sample axis). x_train at this point is the raw, unnormalized
# data, whether freshly generated or loaded from 2DGMM_x_train.npy.
x_train_mean, x_train_std = np.mean(x_train, axis=1), np.std(x_train, axis=1)
x_train = (x_train - x_train_mean[:,None,:])/x_train_std[:,None,:]

if not pretrained:
    # Gram matrix of kernel mean embeddings: cov[i, j] = kme(x_train[i], x_train[j]).
    # Only the upper triangle (incl. the diagonal) is computed, then mirrored.
    rows, cols = np.triu_indices(num_examples)
    idx = np.stack([rows, cols]).T
    num_entries = idx.shape[0]
    # The (i, j) pairs are split evenly over devices, then into batches of
    # `batch_size` pairs that lax.scan processes one after another per device.
    entries_per_device = num_entries // num_devices
    batch_size = 5
    num_iter = entries_per_device / batch_size
    # NOTE(review): this only checks divisibility by batch_size; if
    # num_entries % num_devices != 0 the reshape below raises instead.
    if num_iter.is_integer():
        print('Number of devices: %d' % (num_devices))
        print('Total number of entries: %d' % (num_entries))
        print('Entries per device: %d' % (entries_per_device))
        print('Entries per device per batch: %d' % (batch_size))
        print('Number of iterations per device: %d' % (num_iter))
    else: 
        raise ValueError('Entries per device not divisble by batch-size. Try choosing a different batch-size.')

    # Compute entries
    # Layout (num_devices, num_iter, batch_size, 2): pmap maps axis 0, lax.scan
    # walks axis 1, vmap covers axis 2. Row-major order is preserved, so the
    # flattened result lines up with (rows, cols) below.
    idx = idx.reshape(num_devices, int(num_iter), batch_size, 2)
    # One (i, j) pair -> kme between training clouds i and j.
    kme_fn = lambda idx: kme(x_train[idx[0],:,:], x_train[idx[1],:,:])
    def body_fn(carry, idx):
        # The carry is never read; it is replaced by this batch's output, which
        # has the same shape as the zeros(batch_size) initial carry.
        out = vmap(kme_fn)(idx)
        return out, out

    # [1] keeps the stacked per-step outputs: shape (num_iter, batch_size) per device.
    scan_fn = lambda idx: lax.scan(body_fn, np.zeros(batch_size), idx)[1]
    entries = pmap(scan_fn, axis_name='i')(idx)

    # Construct covariance matrix
    # Scatter the upper triangle, mirror it, and subtract the diagonal once
    # because cov + cov.T counts it twice.
    cov = np.zeros((num_examples, num_examples))
    cov = cov.at[rows,cols].set(entries.reshape(num_entries,))
    cov = cov + cov.T - np.diag(np.diag(cov))
    np.save("data/2DGMM_cov.npy", cov)
else:
    cov = np.load("data/2DGMM_cov.npy")
    print(cov.shape)



# Every training distribution uses the same SDE sigma (2).
sigma_SDE = 2*np.ones(num_examples,)
# Centered covariance matrix
# Double-centering: cov_ij - colmean_i - colmean_j + mean_all (cov is symmetric,
# so column and row means coincide). row_mean holds the *negated* column means
# of the uncentered matrix. It and matrix_mean are computed before cov is
# overwritten and are reused later to center kernel columns of test clouds.
matrix_mean = np.mean(cov)
row_mean = -np.mean(cov, axis=0)
row_mean_tile = np.tile(row_mean,(num_examples,1))
cov = cov + matrix_mean + row_mean_tile + row_mean_tile.T

# Compute eigendecomposition
# eigh returns eigenvalues in ascending order; reorder by decreasing magnitude.
evals, evecs = np.linalg.eigh(cov)
idx = np.abs(evals).argsort()[::-1]
evals = evals[idx]
evecs = evecs[:,idx]

# take first num_sensors largest evals
evals_t = evals[:num_sensors]
evecs_t = evecs[:,:num_sensors]
# u_train[i, k] = evals[k] * evecs[i, k], which equals evecs_t.T @ cov[:, i]:
# the projection of the i-th centered kernel column onto the retained
# eigenvectors, the same projection later applied to test clouds.
u_train = evecs_t * evals_t[None,:]
np.save('data/2DGMM_u_train.npy',u_train)
# Share of the eigenvalue sum captured by the retained components (meaningful
# as "variance explained" only if the eigenvalues are non-negative).
ratio = np.sum(evals_t)/np.sum(evals)
print("Variance Explained by first {} components: {}".format(num_sensors,ratio))


dataset = DataGenerator(u_train, x_train, sigma_SDE, batch_size1, batch_size2)

branch_layers = [num_sensors, 500, 500,  500, 500,  500, 256]
trunk_layers =  [256, 500, 500,  500, 500,  500, 256]
out_layers =  [256, 500, 500,  500, 500,  500, 2]
model = NOMAD(branch_layers, trunk_layers, out_layers)

if not pretrained:
    # Create the output folder before the long training run, so a missing
    # directory cannot discard the trained weights at save time.
    os.makedirs('weights', exist_ok=True)
    model.train(dataset, nIter=120000)
    params = model.get_params(model.opt_state)
    flat_params, _ = ravel_pytree(params)
    np.save('weights/2DGMM_params.npy',flat_params)
else:
    flat_params = np.load('weights/2DGMM_params.npy')
    params = model.unravel_params(flat_params)
   
"""Generating samples using NOMAD"""
from scipy import integrate
def ode_sampler_exact(rng, params, u, dim, sigma, batch_size,
                      atol=1e-5, rtol=1e-5, z=None, eps=1e-5):
  r"""Generate samples by integrating the probability-flow ODE backwards in time.

  Solves dx/dt = -0.5 * g(t)**2 * score(x, t) from t = 1 down to t = eps with
  scipy's RK45. Here g is `diffusion_coeff`, and score is the module-level
  `model.operator_net` conditioned on the embedding `u`.

  Args:
    rng: JAX PRNG key; only used when `z` is None.
    params: Parameters for `model.operator_net`.
    u: Embedding of the target distribution fed to the branch net.
    dim: Dimension of each sample.
    sigma: The $\sigma$ of the SDE.
    batch_size: Number of samples to generate.
    atol, rtol: Tolerances forwarded to `scipy.integrate.solve_ivp`.
    z: Optional initial state of shape (batch_size, dim), used as-is. When
      None, standard normal noise scaled by marginal_prob_std(1., sigma) is used.
    eps: Final integration time (stops short of 0).

  Returns:
    Array of shape (batch_size, dim) holding the state at t = eps.
  """
  sample_shape = (batch_size, dim)
  time_shape = (batch_size,)

  if z is not None:
    init_x = z
  else:
    _, noise_rng = jax.random.split(rng)
    latent = jax.random.normal(noise_rng, sample_shape)
    init_x = latent * marginal_prob_std(1., sigma)

  batched_score = vmap(model.operator_net, (None, None, 0, 0))

  def score_eval_wrapper(sample, time_steps):
    """Evaluate the score on the solver's flat float64 state; return flat float64."""
    points = np.asarray(sample, dtype=np.float32).reshape(sample_shape)
    times = np.asarray(time_steps, dtype=np.float32).reshape(time_shape)
    score = batched_score(params, u, points, times)
    return onp.asarray(score).reshape((-1,)).astype(onp.float64)

  def ode_func(t, x):
    """Right-hand side of the probability-flow ODE, in scipy's (t, y) convention."""
    g = diffusion_coeff(t, sigma)
    drift = score_eval_wrapper(x, onp.ones(batch_size) * t)
    return -0.5 * (g**2) * drift

  x0 = onp.asarray(init_x).reshape(-1).astype(onp.float64)
  res = integrate.solve_ivp(ode_func, (1., eps), x0,
                            rtol=rtol, atol=atol, method='RK45')
  print(f"Number of function evaluations: {res.nfev}")
  return np.asarray(res.y[:, -1]).reshape(sample_shape)


def MMD(x, y):
  """Unbiased estimate of the squared maximum mean discrepancy between `x` and `y`.

  The within-set terms sum `kernel` over all pairs, minus the diagonal
  (i == j), and divide by m(m-1) and n(n-1). The cross term is 2 * kme(x, y).
  """
  m = x.shape[0]
  n = y.shape[0]
  diagonal_k = vmap(k_fn, in_axes=(0, 0))
  xx_term = (np.sum(kernel(x, x)) - np.sum(diagonal_k(x, x))) / m / (m - 1)
  yy_term = (np.sum(kernel(y, y)) - np.sum(diagonal_k(y, y))) / n / (n - 1)
  return xx_term + yy_term - 2 * kme(x, y)

## apply trained SNO on training sets

MMD_list_train = []

# Make sure the plot folder exists before the first (expensive) sampling run.
os.makedirs('./samples', exist_ok=True)

for kk in range(num_examples):
    # x_train is stored normalized; its per-example stats undo that below.
    samples = x_train[kk].reshape(num_samples,dim)
    samples_mean, samples_std = x_train_mean[kk], x_train_std[kk]
    sigmaSDE = 2  # same sigma as used for training (sigma_SDE)

    u = u_train[kk]
    # A 5-way split keeps the original key sequence; only the first subkey is used.
    key, subkey1, *_ = random.split(key,5)
    # generate new samples using NOMAD
    samples_new = ode_sampler_exact(subkey1, params, u, dim, sigmaSDE, num_samples,
                                    atol=1e-5, rtol=1e-5, z=None, eps=1e-5)

    samples_new = samples_new * samples_std + samples_mean
    samples = samples * samples_std + samples_mean

    mmd = MMD(samples_new, samples)
    print("X_train{}, Maximum Mean Discrepancy {}".format(kk,mmd))

    MMD_list_train.append(mmd)

    # Same hexbin figure for the reference and the generated cloud.
    for points, stem in ((samples, 'GMM_train' + str(kk)),
                         (samples_new, 'GMM_train_new_' + str(kk))):
        plt.figure()
        plt.hexbin(points[:, 0], points[:, 1], gridsize=50, cmap='Greens', mincnt=1, alpha=0.7)
        plt.gca().set_facecolor('#2B0042')
        plt.xlim([-7.5, 7.5])
        plt.ylim([-7.5, 7.5])
        plt.subplots_adjust(left=0.02, right=0.98, top=0.98, bottom=0.02)
        plt.axis('off')
        plt.gca().set_aspect('equal')
        plt.savefig('./samples/' + stem + '.png', bbox_inches='tight', pad_inches=0, facecolor='#2B0042')
        plt.close()

print('Train: averaged MMD over {} generated distributions is {}'.format(num_examples, np.mean(np.array(MMD_list_train))))
    

## apply SNO on testing sets

key = random.PRNGKey(242)

num_testing_examples  = 1000
MMD_list_test = []

if num_testing_examples % num_devices:
    raise ValueError('num_testing_examples (%d) must be divisible by the number of devices (%d).'
                     % (num_testing_examples, num_devices))
keys = random.split(key, num_devices)
num_examples_per_device = num_testing_examples // num_devices
gen_fn = lambda key: generate_testing_data(key, num_examples_per_device, num_samples, num_mixtures, dim)
x_test = pmap(gen_fn, axis_name='i')(keys).reshape(num_testing_examples, num_samples, dim)
num_splits = 10

# kme between every training cloud of a chunk and the single test cloud -> (chunk, 1).
batched_kme = vmap(vmap(kme, (None, 0)), (0, None))
# Chunk boundaries over the training set (presumably to bound peak memory);
# they cover every example even when num_splits does not divide num_examples.
bounds = [k * num_examples // num_splits for k in range(num_splits + 1)]
# Make sure the plot folder exists before the first (expensive) sampling run.
os.makedirs('./samples', exist_ok=True)

for kk in range(num_testing_examples):
    samples = x_test[kk].reshape(num_samples,dim)
    # Normalize with this test cloud's own stats, as each training cloud was.
    samples_mean, samples_std = samples.mean(0), samples.std(0)
    samples = (samples - samples_mean)/samples_std

    sigmaSDE = 2  # same sigma as used for training (sigma_SDE)

    # k_i = kme(x_train[i], test cloud) for every training cloud i: shape (num_examples, 1).
    test_cloud = samples.reshape(-1, num_samples, dim)
    samples_inner_product = np.concatenate(
        [batched_kme(x_train[lo:hi], test_cloud) for lo, hi in zip(bounds[:-1], bounds[1:])],
        axis=0)

    # Center the kernel column consistently with the double-centered training
    # matrix (row_mean already holds negated column means), then project onto
    # the retained eigenvectors -> embedding u of shape (1, num_sensors).
    test_total = (-np.mean(samples_inner_product.flatten())*np.ones((num_examples,1))
                  + row_mean.reshape(num_examples,1)
                  + matrix_mean * np.ones((num_examples,1))
                  + samples_inner_product)
    u = (evecs_t.T @ test_total).T
    # A 5-way split keeps the original key sequence; only the first subkey is used.
    key, subkey1, *_ = random.split(key,5)
    # generate new samples using NOMAD
    samples_new = ode_sampler_exact(subkey1, params, u, dim, sigmaSDE, num_samples,
                                    atol=1e-5, rtol=1e-5, z=None, eps=1e-5)

    samples_new = samples_new * samples_std + samples_mean
    samples = samples * samples_std + samples_mean

    mmd = MMD(samples_new, samples)
    print("X_test{}, Maximum Mean Discrepancy {}".format(kk,mmd))
    MMD_list_test.append(mmd)

    # Same hexbin figure for the reference and the generated cloud.
    for points, stem in ((samples, 'GMM_test' + str(kk)),
                         (samples_new, 'GMM_test_new_' + str(kk))):
        plt.figure()
        plt.hexbin(points[:, 0], points[:, 1], gridsize=50, cmap='Greens', mincnt=1, alpha=0.7)
        plt.gca().set_facecolor('#2B0042')
        plt.xlim([-7.5, 7.5])
        plt.ylim([-7.5, 7.5])
        plt.subplots_adjust(left=0.02, right=0.98, top=0.98, bottom=0.02)
        plt.axis('off')
        plt.gca().set_aspect('equal')
        plt.savefig('./samples/' + stem + '.png', bbox_inches='tight', pad_inches=0, facecolor='#2B0042')
        plt.close()

print('Test: averaged MMD over {} generated distributions is {}'.format(num_testing_examples, np.mean(np.array(MMD_list_test))))




