## Run the score-based generative model on 2D GMM data (comparison baseline)

import os
import jax
import jax.numpy as np
from jax import random, grad, vmap, pmap, jit, lax, devices
from jax.example_libraries import stax, optimizers
from jax.example_libraries.stax import Dense, Gelu
from jax.flatten_util import ravel_pytree
import numpy as onp
import itertools
from functools import partial
from torch.utils import data
from tqdm.auto import trange
import matplotlib.pyplot as plt
from scipy import integrate
import matplotlib.patches as patches

# Ask XLA not to preallocate accelerator memory up front.
# NOTE(review): this is assigned after `import jax`; presumably it still takes
# effect because no JAX computation has run yet at this point -- confirm, or
# move the assignment above the JAX imports.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']="False"
def marginal_prob_std(t, sigma):
  r"""Compute the standard deviation of $p_{0t}(x(t) | x(0))$.

  Evaluates ``sqrt((sigma**(2 t) - 1) / (2 log(sigma)))``. Only the standard
  deviation is returned; no mean is computed here.

  Args:
    t: A vector of time steps.
    sigma: The $\sigma$ in our SDE.

  Returns:
    The standard deviation.
  """
  # The docstring is a raw string so that "\sigma" is not parsed as an
  # (invalid) escape sequence.
  return np.sqrt((sigma**(2 * t) - 1.) / 2. / np.log(sigma))

def diffusion_coeff(t, sigma):
  r"""Compute the diffusion coefficient of our SDE.

  Evaluates ``sigma**t`` elementwise.

  Args:
    t: A vector of time steps.
    sigma: The $\sigma$ in our SDE.

  Returns:
    The vector of diffusion coefficients.
  """
  # The docstring is a raw string so that "\sigma" is not parsed as an
  # (invalid) escape sequence.
  return sigma**t

# Data generator
class DataGenerator(data.Dataset):
    """Endless source of random mini-batches drawn from a fixed sample array.

    Every ``__getitem__`` call ignores the requested index, advances the
    internal PRNG key and returns ``batch_size`` distinct rows of ``x``.
    No ``__len__`` is defined and ``__getitem__`` never raises ``IndexError``,
    so ``iter(generator)`` keeps yielding batches indefinitely.

    Args:
        x: Array of samples, indexed as ``x[idx, :]`` (one sample per row).
        batch_size: Number of rows per batch. Rows are drawn without
            replacement, so this should not exceed ``x.shape[0]``.
        rng_key: JAX PRNG key seeding the batch selection. ``None`` (the
            default) means ``random.PRNGKey(1234)``.
    """

    def __init__(self, x, batch_size=256, rng_key=None):
        'Initialization'
        self.x = x # location
        self.N = x.shape[0]
        self.batch_size = batch_size
        # The default key is built here rather than in the signature so that
        # merely defining the class does not run a JAX computation.
        self.key = random.PRNGKey(1234) if rng_key is None else rng_key

    def __getitem__(self, index):
        'Generate one batch of data'
        # `index` is intentionally unused: each call draws a fresh batch.
        self.key, subkey = random.split(self.key)
        inputs = self.__data_generation(subkey)
        return inputs

    def __data_generation(self, key):
        'Generates data containing batch_size samples'
        # Pick batch_size distinct row indices out of the N available rows.
        idx = random.choice(key, self.N, (self.batch_size,), replace=False)
        x = self.x[idx,:]
        return x


    
def init_NN(Q, activation=Gelu):
    """Build a fully connected stax network from a list of layer widths.

    Args:
        Q: Sequence of layer widths ``[d_in, h_1, ..., h_k, d_out]``. ``Q[0]``
            is not read here; the input width is supplied through the
            ``input_shape`` passed to the returned init function.
        activation: stax layer inserted after every hidden ``Dense`` layer
            (not after the output layer).

    Returns:
        A ``(net_init, net_apply)`` pair following the stax convention:
        ``net_init(rng, input_shape) -> (output_shape, params)`` and
        ``net_apply(params, inputs)``. When ``Q`` has fewer than two entries
        there is nothing to build and an identity network is returned.
    """
    if len(Q) < 2:
        # stax.serial() with no layers fails while unpacking its empty layer
        # list, so the degenerate case is spelled out as an explicit identity.
        def net_init(rng, input_shape):
            return input_shape, ()

        def net_apply(params, inputs, **kwargs):
            return inputs

        return net_init, net_apply

    layers = []
    for width in Q[1:-1]:
        # Hidden layer followed by its nonlinearity.
        layers.extend((Dense(width), activation))
    # Linear output layer.
    layers.append(Dense(Q[-1]))
    net_init, net_apply = stax.serial(*layers)
    return net_init, net_apply


def MLP(layers):
    """Create a tanh MLP with two linear output heads.

    Args:
        layers: Layer widths ``[d_in, h_1, ..., h_k, d_out]``.

    Returns:
        ``(init, apply)`` where ``init(rng_key)`` returns a list of ``(W, b)``
        pairs (hidden layers first, then the two heads) and
        ``apply(params, inputs)`` returns ``(mu, Sigma)``. ``mu`` is produced
        by the last parameter pair and ``Sigma`` by the second-to-last.
    """
    def init(rng_key):
        def init_layer(key, d_in, d_out):
            # Only the first half of the split is consumed; the split itself
            # is kept so the drawn weights match the established key stream.
            weight_key, _ = random.split(key)
            scale = 1. / np.sqrt((d_in + d_out) / 2.)
            weights = scale * random.normal(weight_key, (d_in, d_out))
            bias = np.zeros(d_out)
            return weights, bias

        head_key, *hidden_keys = random.split(rng_key, len(layers))
        # zip stops at the shortest input, i.e. one entry per hidden layer.
        params = [
            init_layer(k, n_in, n_out)
            for k, n_in, n_out in zip(hidden_keys, layers[:-2], layers[1:-1])
        ]
        # Two output heads sharing the same input/output widths.
        for k in random.split(head_key):
            params.append(init_layer(k, layers[-2], layers[-1]))
        return params

    def apply(params, inputs):
        hidden = inputs
        for W, b in params[:-2]:
            hidden = np.tanh(np.dot(hidden, W) + b)
        W_mu, b_mu = params[-1]
        W_sigma, b_sigma = params[-2]
        mu = np.dot(hidden, W_mu) + b_mu
        Sigma = np.dot(hidden, W_sigma) + b_sigma
        return mu, Sigma

    return init, apply

# Define the model
class Scalenet:
    """Two-headed MLP regressor trained with Adam on a fixed dataset.

    The network built by ``MLP(layers)`` maps an input to a pair of outputs,
    which are fitted to target ``mean`` and ``std`` arrays with a summed
    squared-error loss.
    """

    def __init__(self, layers):
        # Network initialization and evaluation functions
        self.init, self.apply = MLP(layers)

        # Draw the initial parameters with a fixed seed
        initial_params = self.init(random.PRNGKey(1234))

        # Adam with an exponentially decaying step size
        schedule = optimizers.exponential_decay(1e-3,
                                                decay_steps=1000,
                                                decay_rate=0.99)
        self.opt_init, self.opt_update, self.get_params = optimizers.adam(schedule)
        self.opt_state = self.opt_init(initial_params)

        # Used to restore the trained model parameters
        _, self.unravel_params = ravel_pytree(initial_params)

        # Global optimizer step counter (drives the learning-rate schedule)
        self.itercount = itertools.count()

        # Loggers
        self.loss_log = []

    def loss_operator(self, params, dataset):
        """Mean over samples of the squared mean error plus squared std error.

        ``dataset`` is unpacked as ``(u, mean, std)``; the network is applied
        to each row of ``u``.
        """
        u, mean, std = dataset
        batched_apply = vmap(self.apply, (None, 0))
        mean_pred, std_pred = batched_apply(params, u)
        mean_err = mean_pred.flatten() - mean.flatten()
        std_err = std_pred.flatten() - std.flatten()
        return np.mean(mean_err ** 2 + std_err ** 2)

    @partial(jit, static_argnums=(0,))
    def step(self, i, opt_state, batch):
        """Run one optimizer update and return the new optimizer state."""
        current = self.get_params(opt_state)
        gradients = grad(self.loss_operator)(current, batch)
        return self.opt_update(i, gradients, opt_state)

    def train(self, dataset, nIter = 100000):
        """Run ``nIter`` full-batch updates on ``dataset``, logging every 10th."""
        progress = trange(nIter)
        for it in progress:
            # The same dataset is used for every update
            step_index = next(self.itercount)
            self.opt_state = self.step(step_index, self.opt_state, dataset)

            if it % 10 != 0:
                continue

            # Evaluate, store and display the current loss
            current = self.get_params(self.opt_state)
            loss_value = self.loss_operator(current, dataset)
            self.loss_log.append(loss_value)
            progress.set_postfix({'Loss': loss_value})

# Define the model
class ScoreNet:
    """Time-conditioned score network trained by denoising score matching.

    The network built by ``init_NN(layers)`` receives the concatenation of a
    sample ``x`` and a time ``t`` and outputs a vector used as the score.
    The loss reads the module-level ``marginal_prob_std_fn``.
    """

    def __init__(self, layers):
        # Network initialization and evaluation functions
        self.init, self.apply = init_NN(layers, activation=Gelu)

        # Initialize with a fixed seed; the input width is layers[0]
        in_shape = (-1, layers[0])
        out_shape, initial_params = self.init(random.PRNGKey(1234), in_shape)

        # Adam with an exponentially decaying step size
        schedule = optimizers.exponential_decay(1e-3,
                                                decay_steps=1000,
                                                decay_rate=0.99)
        self.opt_init, self.opt_update, self.get_params = optimizers.adam(schedule)
        self.opt_state = self.opt_init(initial_params)

        # Used to restore the trained model parameters
        _, self.unravel_params = ravel_pytree(initial_params)

        # Global optimizer step counter (drives the learning-rate schedule)
        self.itercount = itertools.count()

        # Loggers
        self.loss_log = []

    def operator_net(self, params, x, t):
        """Evaluate the network on one sample ``x`` at time ``t``."""
        # t is flattened to 1-D so it can be appended to x
        net_input = np.concatenate([x, t.reshape(-1)])
        return self.apply(params, net_input)

    # compute the loss fun using the method in Song et al.
    def loss_operator(self, params, x, rng):
        """Denoising score-matching loss for a batch ``x`` of shape (batch, dim).

        One time per sample is drawn uniformly from [1e-5, 1), Gaussian noise
        scaled by ``marginal_prob_std_fn(t)`` is added to ``x``, and the loss
        is the batch mean of ``sum((score * std + z)**2)``.
        """
        rng, time_key = random.split(rng)
        t = random.uniform(time_key, (x.shape[0],), minval=1e-5, maxval=1.)
        rng, noise_key = random.split(rng)
        z = random.normal(noise_key, x.shape)

        std = marginal_prob_std_fn(t)
        noise_scale = std[:,None]
        perturbed_x = x + z * noise_scale

        batched_score = vmap(self.operator_net, (None,0,0))
        score = batched_score(params, perturbed_x, t)
        residual = score * noise_scale + z
        return np.mean(np.sum(residual**2, axis=1))

    @partial(jit, static_argnums=(0,))
    def step(self, i, opt_state, batch, rng):
        """Run one optimizer update and return the new optimizer state."""
        current = self.get_params(opt_state)
        gradients = grad(self.loss_operator)(current, batch, rng)
        return self.opt_update(i, gradients, opt_state)

    def train(self, dataset, nIter = 10000):
        """Run ``nIter`` mini-batch updates, logging the loss every 10th step."""
        batches = iter(dataset)
        rng = random.PRNGKey(2)
        progress = trange(nIter)
        for it in progress:
            rng, subrng = random.split(rng, 2)
            # Fetch data and update the parameters
            batch = next(batches)
            step_index = next(self.itercount)
            self.opt_state = self.step(step_index, self.opt_state, batch, subrng)

            if it % 10 != 0:
                continue

            # Loss on the batch just used, with the same noise key
            current = self.get_params(self.opt_state)
            loss_value = self.loss_operator(current, batch, subrng)
            self.loss_log.append(loss_value)
            progress.set_postfix({'Loss': loss_value})


     
# Dimension of each data sample
dim = 2       

# NOTE(review): this key is not read before `key` is reassigned further down
# in the file -- confirm whether it is still needed.
key = random.PRNGKey(1512)
# Create data set
batch_size = 256
# NOTE(review): the variable is called x_train but the file name says "x_test"
# -- confirm this is the intended file.
x_train = np.load("data/2DGMM_x_test_1.npy")
num_samples = x_train.shape[0]
print(x_train.shape)

# Normalize each coordinate to zero mean and unit standard deviation; the
# statistics are kept so generated samples can be mapped back later
x_train_mean, x_train_std = np.mean(x_train, axis=0), np.std(x_train, axis=0)
x_train = (x_train - x_train_mean[None,:])/x_train_std[None,:]


# Bind the SDE's sigma once; ScoreNet.loss_operator and ode_sampler_exact read
# these module-level functions by name
sigma_SDE = 2
marginal_prob_std_fn = partial(marginal_prob_std, sigma = sigma_SDE)
diffusion_coeff_fn = partial(diffusion_coeff, sigma = sigma_SDE)
dataset = DataGenerator(x_train, batch_size)
# Input width is dim+1 because the network receives x concatenated with t
layers =  [dim+1, 500, 500,  500, 500,  500, dim]
model = ScoreNet(layers)
model.train(dataset, nIter=120000)
# Trained parameters extracted from the final optimizer state
params = model.get_params(model.opt_state)

"""Generating samples from the trained score model with a black-box ODE solver"""
from scipy import integrate
def ode_sampler_exact(rng, params,
                      dim,
                      batch_size,
                      atol=1e-5,
                      rtol=1e-5,
                      z=None,
                      eps=1e-5):
  """Generate samples by integrating an ODE driven by the score network.

  Integrates ``dx/dt = -0.5 * g(t)**2 * score(x, t)`` from ``t = 1`` down to
  ``t = eps`` with SciPy's RK45 solver. ``g`` is the module-level
  ``diffusion_coeff_fn`` and the score is the module-level
  ``model.operator_net`` evaluated with ``params``.

  Args:
    rng: JAX PRNG key, used only when ``z`` is None.
    params: Network parameters passed to ``model.operator_net``.
    dim: Dimension of each sample.
    batch_size: Number of samples integrated together.
    atol: Absolute tolerance forwarded to the solver.
    rtol: Relative tolerance forwarded to the solver.
    z: Optional initial state. When None, standard normal noise scaled by
      ``marginal_prob_std_fn(1.)`` is used; otherwise ``z`` is used as given.
    eps: Final integration time.

  Returns:
    Array of shape ``(batch_size, dim)`` with the solver state at the last
    time point.
  """
  sample_shape = (batch_size, dim)
  time_shape = (batch_size,)

  # Pick the initial state: user-supplied, or freshly drawn latent noise
  if z is not None:
    init_x = z
  else:
    rng, step_rng = jax.random.split(rng)
    z = jax.random.normal(step_rng, sample_shape)
    init_x = z * marginal_prob_std_fn(1.)

  batched_score = vmap(model.operator_net, (None, 0, 0))

  def score_eval_wrapper(sample, time_steps):
    """Evaluate the score on the solver's flat float64 state vector."""
    sample = np.asarray(sample, dtype=np.float32).reshape(sample_shape)
    time_steps = np.asarray(time_steps, dtype=np.float32).reshape(time_shape)
    score = batched_score(params, sample, time_steps)
    # Back to a flat float64 NumPy vector, as the solver expects
    return onp.asarray(score).reshape((-1,)).astype(onp.float64)

  def ode_func(t, x):
    """Right-hand side of the ODE for the solver."""
    # Every sample in the batch shares the same time t
    time_steps = onp.full(time_shape, t, dtype=onp.float64)
    g = diffusion_coeff_fn(t)
    drift_scale = -0.5 * (g**2)
    return drift_scale * score_eval_wrapper(x, time_steps)

  # Run the black-box ODE solver on the flattened initial state
  y0 = onp.asarray(init_x).reshape(-1).astype(onp.float64)
  res = integrate.solve_ivp(ode_func, (1., eps), y0,
                            rtol=rtol, atol=atol, method='RK45')
  print(f"Number of function evaluations: {res.nfev}")

  # Last column of res.y is the state at the final time
  return np.asarray(res.y[:, -1]).reshape(sample_shape)



# PRNG key for drawing the sampler's initial noise
key = random.PRNGKey(242)


# generate new samples with the trained score model via the ODE sampler;
# one generated sample per loaded data sample (num_samples is the batch size)
samples_new = ode_sampler_exact(key, 
                params,
                dim,
                num_samples, 
                atol=1e-5, 
                rtol=1e-5,                 
                z=None,
                eps=1e-5)


# Undo the normalization applied to the data before training
samples_new = samples_new * x_train_std[None,:] +  x_train_mean[None,:]
print(samples_new.shape)
np.save('data/GMM_score_test_new_'+str(2)+'.npy',samples_new)

# Hexbin density plot of the generated samples on a fixed [-7.5, 7.5] window
plt.figure()
hb = plt.hexbin(samples_new[:, 0], samples_new[:, 1], gridsize=50, cmap='Greens', mincnt=1, alpha=0.7)
plt.gca().set_facecolor('#009688')
plt.xlim([-7.5, 7.5]);
plt.ylim([-7.5, 7.5]);
plt.subplots_adjust(left=0.02, right=0.98, top=0.98, bottom=0.02)
plt.axis('off')
plt.gca().set_aspect('equal')
# NOTE(review): savefig presumably requires the ./samples/ directory to exist
# already -- confirm.
plt.savefig('./samples/'+'GMM_score_test_new_'+str(2)+'.png', bbox_inches='tight', pad_inches=0, facecolor='#009688') 
