# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: skip-file
# pytype: skip-file
"""Various sampling methods."""
import functools

import jax
import jax.numpy as jnp
import jax.random as random
import abc
import flax
import numpy as np

from models.utils import get_score_fn
from scipy import integrate
import sde_lib
from utils import batch_mul, from_flattened_numpy, to_flattened_numpy
from utils_qm9 import assert_correctly_masked, sample_center_gravity_zero_gaussian_with_mask, sample_gaussian_with_mask, assert_mean_zero_with_mask, remove_mean_with_mask
import logging

from models import utils as mutils
from sde_lib import VPSDE, RFSDE

# Registry mapping a predictor name to its class. Filled by `register_predictor`
# and read by `get_predictor`.
_PREDICTORS = {}


def register_predictor(cls=None, *, name=None):
  """Register a predictor class in the module-level predictor registry.

  Works both as a bare decorator (`@register_predictor`) and as a
  parameterized one (`@register_predictor(name='my_name')`).

  Args:
    cls: The class to register when used as a bare decorator; `None` when the
      decorator is called with keyword arguments only.
    name: Registry key. When `None`, the class's `__name__` is used.

  Returns:
    `cls` itself (after registering it) if `cls` is given; otherwise a
    decorator that registers and returns the class it is applied to.

  Raises:
    ValueError: If the resolved name is already present in the registry.
  """

  def _add_to_registry(target):
    key = target.__name__ if name is None else name
    if key in _PREDICTORS:
      raise ValueError(f'Already registered model with name: {key}')
    _PREDICTORS[key] = target
    return target

  return _add_to_registry if cls is None else _add_to_registry(cls)


def get_predictor(name):
  """Look up a registered predictor class by name.

  Args:
    name: The key the predictor was registered under.

  Returns:
    The predictor class stored under `name`.

  Raises:
    KeyError: If nothing is registered under `name`. The message lists the
      names that are available, which makes configuration typos easy to spot.
  """
  try:
    return _PREDICTORS[name]
  except KeyError:
    raise KeyError(
      f'Unknown predictor {name!r}. Registered predictors: {sorted(_PREDICTORS)}') from None


def get_sampling_qm9_fn(config, sde, eps, deq):
  """Create the QM9 sampling function selected by `config.sampling.method`.

  Args:
    config: Configuration object. Reads `config.sampling.method` and `config.model.aug_dim`,
      plus `config.sampling.tol` for the ODE sampler, or `config.sampling.predictor` and
      `config.eval.save_trajectory` for the PC sampler.
    sde: The SDE object, forwarded to the sampler factory.
    eps: Forwarded to the sampler factory as `eps`.
    deq: Not used by this function; kept so the signature stays stable for callers.

  Returns:
    The sampling function built by `get_ode_sampler_qm9` or `get_pc_sampler_qm9`.

  Raises:
    ValueError: If `config.sampling.method` is neither 'ode' nor 'pc' (case-insensitive).
  """
  sampler_name = config.sampling.method
  # Validation happens through the explicit ValueError below rather than an `assert`,
  # so it still applies when Python runs with `-O`.
  method = sampler_name.lower()
  if method == 'ode':
    sampling_fn = get_ode_sampler_qm9(sde=sde,
                                      rtol=config.sampling.tol,
                                      atol=config.sampling.tol,
                                      eps=eps,
                                      aug_dim=config.model.aug_dim)
  elif method == 'pc':
    predictor = get_predictor(config.sampling.predictor.lower())
    sampling_fn = get_pc_sampler_qm9(sde=sde,
                                     predictor=predictor,
                                     eps=eps,
                                     save_trajectory=config.eval.save_trajectory,
                                     aug_dim=config.model.aug_dim)
  else:
    raise ValueError(f"Sampler name {sampler_name} unknown.")

  return sampling_fn


def get_ode_sampler_qm9(sde, rtol, atol, eps, aug_dim):
  """Create a probability flow ODE sampler driven by a black-box ODE solver.

  Args:
    sde: The forward SDE object. Its `reverse`, `prior_sampling` and `T` members are used.
    rtol: Relative tolerance passed to `scipy.integrate.solve_ivp`.
    atol: Absolute tolerance passed to `scipy.integrate.solve_ivp`.
    eps: End time of the integration; the ODE is solved from `sde.T` to `eps`.
    aug_dim: Currently not used by this function.

  Returns:
    The `ode_sampler` function.

  NOTE(review): `model` and `shape` are read below but are neither parameters of this
  function nor defined in this module, so the sampler raises NameError when called unless
  they are provided at module level. Confirm where they are supposed to come from.
  """

  # This must be the only `drift_fn` in this scope: `ode_func` looks the name up when it is
  # called, so any later re-definition (such as an empty stub) would silently replace it.
  @jax.pmap
  def drift_fn(state, x, t):
    """Get the drift function of the reverse-time SDE."""
    # NOTE(review): `model` is not defined in this scope (see the function docstring).
    score_fn = get_score_fn(sde, model, state.params_ema, state.model_state, train=False, continuous=True)
    rsde = sde.reverse(score_fn, probability_flow=True)
    return rsde.sde(x, t)[0]

  def ode_sampler(prng, pstate, z=None):
    """The probability flow ODE sampler with black-box ODE solver.

    Args:
      prng: An array of random state. The leading dimension equals the number of devices.
      pstate: Replicated training state for running on multiple devices.
      z: If present, generate samples from latent code `z`.
    Returns:
      Samples, and the number of function evaluations.
    """
    # Initial sample
    rng = flax.jax_utils.unreplicate(prng)
    rng, step_rng = random.split(rng)
    if z is None:
      # NOTE(review): `shape` is not defined in this scope (see the function docstring).
      x = sde.prior_sampling(step_rng, (jax.local_device_count(),) + shape)
    else:
      x = z

    def ode_func(t, x):
      # The solver works on a flat numpy vector; restore the per-device layout for pmap.
      x = from_flattened_numpy(x, (jax.local_device_count(),) + shape)
      vec_t = jnp.ones((x.shape[0], x.shape[1])) * t
      drift = drift_fn(pstate, x, vec_t)
      return to_flattened_numpy(drift)

    # Black-box ODE solver for the probability flow ODE
    solution = integrate.solve_ivp(ode_func, (sde.T, eps), to_flattened_numpy(x),
                                   rtol=rtol, atol=atol, method='RK45')
    nfe = solution.nfev
    # The last column of `solution.y` is the state at the final time point.
    x = jnp.asarray(solution.y[:, -1]).reshape((jax.local_device_count(),) + shape)

    return x, nfe

  return ode_sampler

def get_pc_sampler_qm9(sde, predictor, eps, save_trajectory, aug_dim):
  """Create the pmapped predictor-corrector sampler for QM9.

  NOTE(review): the inner `sampler` is still a stub whose body is `pass`, and none of the
  arguments of this factory (`sde`, `predictor`, `eps`, `save_trajectory`, `aug_dim`) are
  used yet.

  Returns:
    `sampler` wrapped in `jax.pmap` with `axis_name='batch'`.
  """
  def sampler(rng, state, cond_image=None, **kwargs):
    """Placeholder sampler; not implemented yet (the body is `pass`, so it returns `None`).

    Intended arguments:
      rng: RNG
      state: state
      cond_image: input for conditional_generation

    Intended return:
      x: output
      stat: output statistics, including trajectory
    """
    pass

  return jax.pmap(sampler, axis_name='batch')



def get_sampling_fn(config, sde, shape, inverse_scaler, eps, **kwargs):
  """Create a sampling function.

  Args:
    config: A `ml_collections.ConfigDict` object that contains all configuration information.
    sde: A `sde_lib.SDE` object that represents the forward SDE.
    shape: A sequence of integers representing the expected shape of a single sample.
    inverse_scaler: The inverse data normalizer function.
    eps: A `float` number. The reverse-time SDE is only integrated to `eps` for numerical stability.
    **kwargs: Optional settings. Only `gen_reflow` (default `False`) is read; it is forwarded
      to the ODE sampler.

  Returns:
    A function that takes random states and a replicated training state and outputs samples with the
      trailing dimensions matching `shape`.

  Raises:
    KeyError: If `config.sampling.predictor` (lower-cased) is not a registered predictor.
    ValueError: If `config.sampling.method` is neither 'ode' nor 'pc', or if the predictor is
      registered but is not one of 'rf_solver' / 'rf_solver_heun'.
  """
  gen_reflow = kwargs.get('gen_reflow', False)

  sampler_name = config.sampling.method
  method = sampler_name.lower()
  # Probability flow ODE sampling with black-box ODE solvers
  if method == 'ode':
    # NOTE(review): `get_ode_sampler` only exists as commented-out code in this module, so this
    # branch raises NameError unless the function is provided elsewhere.
    sampling_fn = get_ode_sampler(sde=sde,
                                  shape=shape,
                                  inverse_scaler=inverse_scaler,
                                  rtol=config.sampling.tol,
                                  atol=config.sampling.tol,
                                  eps=eps,
                                  gen_reflow=gen_reflow,
                                  reflow_t=config.training.reflow_t if 'reflow_t' in config.training else 1,
                                  adaptive_interval=config.training.adaptive_interval,
                                  vp_to_rf=config.training.vp_to_rf)
  # Predictor-Corrector sampling. Predictor-only and Corrector-only samplers are special cases.
  elif method == 'pc':
    # Normalize the name once so the registry lookup and the support check agree on casing.
    predictor_name = config.sampling.predictor.lower()
    predictor = get_predictor(predictor_name)
    # Only the rectified-flow solvers are accepted here.
    # NOTE(review): an unused `nfe_multiplier` (1 for 'rf_solver', 2 for 'rf_solver_heun') used
    # to be computed at this point; `get_pc_sampler` takes no such argument, so it was dropped.
    if predictor_name not in ('rf_solver', 'rf_solver_heun'):
      raise ValueError(
        f"Predictor {config.sampling.predictor} is not supported; "
        "expected 'rf_solver' or 'rf_solver_heun'.")
    sampling_fn = get_pc_sampler(sde=sde,
                                 shape=shape,
                                 predictor=predictor,
                                 inverse_scaler=inverse_scaler,
                                 probability_flow=config.sampling.probability_flow,
                                 eps=eps,
                                 save_trajectory=config.eval.save_trajectory,
                                 aug_dim=config.model.aug_dim)
  else:
    raise ValueError(f"Sampler name {sampler_name} unknown.")

  return sampling_fn


class Predictor(abc.ABC):
  """The abstract class for a predictor algorithm.

  Attributes set by `__init__`:
    sde: The forward SDE object.
    score_fn: The function the predictor evaluates at `(x, t)`.
    probability_flow: The flag passed by the caller.
    rsde: `sde.reverse(score_fn, probability_flow)` when `sde` is a `sde_lib.VPSDE`,
      otherwise `None`.
  """

  def __init__(self, sde, score_fn, probability_flow=False):
    super().__init__()
    self.sde = sde
    self.score_fn = score_fn
    self.probability_flow = probability_flow
    # Compute the reverse SDE/ODE. Only VP-SDEs get one; for every other SDE the attribute is
    # set to `None` so that it always exists instead of raising AttributeError on access.
    if isinstance(sde, sde_lib.VPSDE):
      self.rsde = sde.reverse(score_fn, probability_flow)
    else:
      self.rsde = None

  @abc.abstractmethod
  def update_fn(self, rng, x, t):
    """One update of the predictor.

    Args:
      rng: A JAX random state.
      x: A JAX array representing the current state
      t: A JAX array representing the current time step.

    Returns:
      x: A JAX array of the next state.
      x_mean: A JAX array. The next state without random noise. Useful for denoising.
    """
    pass


@register_predictor(name='rf_solver')
class RFPredictor(Predictor):
  """Single-evaluation (Euler) solver, registered as 'rf_solver'."""

  def __init__(self, sde, score_fn, probability_flow=True):
    # Same as the base constructor, except that `probability_flow` defaults to True.
    super().__init__(sde, score_fn, probability_flow)

  def update_fn(self, rng, x, t):
    """Advance `x` by one Euler step.

    Args:
      rng: Unused; present to match the `Predictor` interface.
      x: The current state.
      t: A pair `(current_t, next_t)` of time arrays.

    Returns:
      The next state, twice (the step adds no noise, so state and mean coincide).
    """
    t_now, t_next = t

    velocity = self.score_fn(x, t_now)
    step = t_next - t_now
    x_next = x + batch_mul(velocity, step)
    return x_next, x_next


@register_predictor(name='rf_solver_heun')
class RFHeunPredictor(Predictor):
  """Heun (explicit trapezoidal) solver, registered as 'rf_solver_heun'.

  Each update evaluates `score_fn` twice: once at the current point and once at the
  Euler-predicted end point of the step.
  """

  def __init__(self, sde, score_fn, probability_flow=True):
    # Same as the base constructor, except that `probability_flow` defaults to True.
    super().__init__(sde, score_fn, probability_flow)

  def update_fn(self, rng, x, t):
    """Advance `x` by one Heun step.

    Args:
      rng: Unused; present to match the `Predictor` interface.
      x: The current state.
      t: A pair `(current_t, next_t)` of time arrays.

    Returns:
      The next state, twice (the step adds no noise, so state and mean coincide).
    """
    current_t, next_t = t
    intvl = next_t - current_t

    # Predictor: a full Euler step to `next_t`. The second slope is evaluated at time
    # `next_t`, so the state it is evaluated on must also be the estimate at `next_t`;
    # a half-step state paired with `next_t` degrades the scheme to first order.
    score = self.score_fn(x, current_t)
    x_pred = x + batch_mul(score, intvl)

    # Corrector: average the slopes at both ends of the interval (trapezoidal rule).
    score_next = self.score_fn(x_pred, next_t)
    x = x + (batch_mul(score_next, intvl / 2) + batch_mul(score, intvl / 2))
    return x, x
#############################################################################################################
def shared_predictor_update_fn(rng, state, x, t, sde, predictor, probability_flow, eps=None):
  """Build a predictor for the given training state and run one update with it.

  Args:
    rng: Random state forwarded to the predictor's `update_fn`.
    state: Training state; `state.opt_state_ema.ema` is handed to `mutils.get_score_fn`.
    x: Current sample, forwarded to `update_fn`.
    t: Time argument, forwarded to `update_fn`.
    sde: The SDE object, used for both the score function and the predictor.
    predictor: The predictor class to instantiate.
    probability_flow: Forwarded to the predictor constructor.
    eps: Forwarded to `mutils.get_score_fn`.

  Returns:
    Whatever the predictor's `update_fn(rng, x, t)` returns.
  """
  ema_params = state.opt_state_ema.ema
  score_fn = mutils.get_score_fn(sde, state, ema_params, train=False, eps=eps)
  return predictor(sde, score_fn, probability_flow).update_fn(rng, x, t)


def get_pc_sampler(sde, shape, predictor, inverse_scaler, probability_flow=False, denoise=True, eps=1e-3,
                   save_trajectory=False, aug_dim=0):
  """Create a Predictor-Corrector (PC) sampler.

  Only a predictor is applied in this implementation; no corrector step is run.

  Args:
    sde: An `sde_lib.SDE` object representing the forward SDE. `sde.N` is used as the number
      of steps and `sde.prior_sampling` draws the initial sample.
    shape: A sequence of integers. The expected shape of a single sample (the inline comments
      below treat it as (B, H, W, C) per device).
    predictor: A subclass of `sampling.Predictor` representing the predictor algorithm.
    inverse_scaler: The inverse data normalizer.
    probability_flow: Forwarded to the predictor constructor.
    denoise: Not used by this function.
    eps: A `float` number, forwarded to `shared_predictor_update_fn`.
    save_trajectory: If `True`, every intermediate state is kept and straightness statistics
      are added to the returned `stats`.
    aug_dim: Number of extra Gaussian-noise channels concatenated to the initial sample along
      the last axis. They are removed again before the samples are returned.

  Returns:
    A pmapped function `rf_sampler(rng, state, cond_image=None)` that returns
    `((inverse_scaler(x), inverse_scaler(initial_image)), stats)`, where `stats['nfe']` is the
    number of steps and, with `save_trajectory`, `stats` also holds the trajectory and
    straightness values.
  """
  # Bind the sampler-wide settings so that only (rng, state, x, t) vary per step.
  predictor_update_fn = functools.partial(shared_predictor_update_fn,
                                          sde=sde,
                                          predictor=predictor,
                                          probability_flow=probability_flow,
                                          eps=eps)

  def rf_sampler(rng, state, cond_image=None):
    """Run `sde.N` predictor steps, stepping time from i / sde.N to (i + 1) / sde.N."""
    # NOTE(review): this `step_rng` is overwritten by a later split before it is ever used;
    # the split is kept because it advances `rng` and therefore affects the samples drawn.
    rng, step_rng = random.split(rng)
    stats = dict()

    if save_trajectory:
      def loop_body(i, val):
        rng, x, x_mean, x_all = val
        # `vec` is the pair (current time, next time), one entry per batch element.
        vec = jnp.full((x.shape[0],), i / sde.N), jnp.full((x.shape[0],), (i + 1) / sde.N)
        rng, step_rng = random.split(rng)
        new_x, x_mean = predictor_update_fn(step_rng, state, x, vec)
        # `x_all` keeps a fixed length (so the loop carry keeps its structure): append the
        # newest state at the end and drop the oldest entry from the front.
        x_all.append(jnp.expand_dims(new_x, axis=0))
        x_all.pop(0)
        return rng, new_x, x_mean, x_all

    else:
      def loop_body(i, val):
        rng, x, x_mean = val
        # `vec` is the pair (current time, next time), one entry per batch element.
        vec = jnp.full((x.shape[0],), i / sde.N), jnp.full((x.shape[0],), (i + 1) / sde.N)
        rng, step_rng = random.split(rng)
        new_x, x_mean = predictor_update_fn(step_rng, state, x, vec)
        return rng, new_x, x_mean

    # get initial image
    if cond_image is None: # generative modeling case
      rng, step_rng = random.split(rng)
      initial_image = sde.prior_sampling(step_rng, shape)
    else: # distribution matching
      initial_image = cond_image

    # Augment some dimensions
    rng, step_rng = random.split(rng)
    initial_y0 = random.normal(step_rng, shape[:-1] + (aug_dim,))
    initial_image = jnp.concatenate([initial_image, initial_y0], axis=-1) # (B, H, W, C + Cy)

    if save_trajectory:
      # get x and trajectory (x_all)
      # Start with sde.N zero placeholders followed by the initial image; after sde.N loop
      # steps the placeholders have all been shifted out.
      x_all = [jnp.zeros([1, *initial_image.shape])] * (sde.N + 1) # Dummy definition of all x
      x_all.append(jnp.expand_dims(initial_image, axis=0))
      x_all.pop(0)
      _, x, _, x_all = jax.lax.fori_loop(0, sde.N, loop_body, (rng, initial_image, initial_image, x_all))

      # analyze trajectory
      x_all = jnp.concatenate(x_all, axis=0) # array of x_t, (sde.N + 1, B, H, W, C + Cy)

      # Calculate straightness
      diff = x - initial_image # (B, H, W, C + Cy)
      # Finite-difference velocity per step: the step size is 1 / sde.N, hence the factor sde.N.
      curv = (x_all[1:] - x_all[:-1]) * sde.N # (sde.N, B, H, W, C + Cy)
      Cx = shape[-1]

      # Squared deviation summed over axes 2 and 3, leaving (sde.N, B, C + Cy).
      sum_str = jnp.sum(jnp.square(diff - curv), axis=(2, 3))
      # straightness (all dim)
      straightness_by_t = jnp.mean(jnp.sum(sum_str, axis=-1), axis=1) # || (x1 - x0) - d/dt xt ||_2^2 (sde.N,)
      straightness = jnp.mean(straightness_by_t) # (1,)
      # straightness (x): only the first Cx (original, non-augmented) channels
      x_straightness_by_t = jnp.mean(jnp.sum(sum_str[..., :Cx], axis=-1), axis=1) # || (x1 - x0) - d/dt xt ||_2^2 (sde.N,)
      x_straightness = jnp.mean(x_straightness_by_t) # (1,)
      if aug_dim > 0:
        # straightness (y): only the augmented channels
        y_straightness_by_t = jnp.mean(jnp.sum(sum_str[..., Cx:], axis=-1), axis=1) # || (x1 - x0) - d/dt xt ||_2^2 (sde.N,)
        y_straightness = jnp.mean(y_straightness_by_t) # (1,)

      stats['straightness'] = straightness           # (1,)
      stats['x_straightness'] = x_straightness       # (1,)
      if aug_dim > 0:
        stats['y_straightness'] = y_straightness       # (1,)
      stats['straightness_by_t'] = straightness_by_t # (sde.N,)
      stats['trajectory'] = x_all                    # (sde.N + 1, B, H, W, C + Cy)

    else:
      _, x, _ = jax.lax.fori_loop(0, sde.N, loop_body, (rng, initial_image, initial_image))

    # NOTE(review): this counts one evaluation per step, but `RFHeunPredictor.update_fn` calls
    # the score function twice per step; confirm whether the count should be scaled.
    stats['nfe'] = sde.N

    # Drop the augmented channels so the outputs have `shape[-1]` channels again.
    x = x[..., 0:shape[-1]]
    initial_image = initial_image[..., :shape[-1]]
    return (inverse_scaler(x), inverse_scaler(initial_image)), stats

  return jax.pmap(rf_sampler, axis_name='batch')

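# TODO: simplify get_ode_sampler and re-enable it. It is commented out below, but get_sampling_fn still calls it for the 'ode' method.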
# def get_ode_sampler(sde, shape, inverse_scaler, rtol=1e-5, atol=1e-5, method='RK45', eps=1e-3,
#                     gen_reflow=False, reflow_t=1, adaptive_interval=False, vp_to_rf=False):
#   """Probability flow ODE sampler with the black-box ODE solver.

#   Args:
#     sde: An `sde_lib.SDE` object that represents the forward SDE.
#     model: A `flax.linen.Module` object that represents the architecture of the score-based model.
#     shape: A sequence of integers. The expected shape of a single sample.
#     inverse_scaler: The inverse data normalizer.
#     denoise: If `True`, add one-step denoising to final samples.
#     rtol: A `float` number. The relative tolerance level of the ODE solver.
#     atol: A `float` number. The absolute tolerance level of the ODE solver.
#     method: A `str`. The algorithm used for the black-box ODE solver.
#       See the documentation of `scipy.integrate.solve_ivp`.
#     eps: A `float` number. The reverse-time SDE/ODE will be integrated to `eps` for numerical stability.

#   Returns:
#     A sampling function that takes random states, and a replicated training state and returns samples
#     as well as the number of function evaluations during sampling.

#   Note: the ODE sampler does not support generating randomized intervals in SeqRF.
#   """
#   if adaptive_interval:
#     optimal_interval = np.load('assets/c10_adaptive_index_500.npy', allow_pickle=True)[()][reflow_t]
#     optimal_interval = 1 - np.array(optimal_interval) / optimal_interval[-1] # 1: noise, 0: data.
#     assert optimal_interval.shape[0] == reflow_t + 1
#     timestep_dict = get_timestep(reflow_t, sde.N, optimal_interval)
#   else:
#     if isinstance(sde, sde_lib.RFSDE):
#       timestep_dict = get_timestep(reflow_t, sde.N, vp_to_rf=vp_to_rf)
#     elif isinstance(sde, sde_lib.VPSDE):
#       timestep_dict = get_timestep(reflow_t, sde.N, sde='vp')

#   if isinstance(sde, sde_lib.RFSDE):
#     @jax.pmap
#     def drift_fn(state, x, t, index):
#       """Get the drift function of the reverse-time SDE."""
#       score_fn = get_score_fn(sde, state, state.opt_state_ema.ema, train=False, eps=eps, vp_to_rf=vp_to_rf)
#       # return score_fn(x, t, index) # (sde.T - eps) multiplied for bug in the original RF code.
#       return - score_fn(x, t, index) # (sde.T - eps) multiplied for bug in the original RF code. (vp_to_rf case)
#   elif isinstance(sde, sde_lib.VPSDE):
#     @jax.pmap
#     def drift_fn(state, x, t, index):
#       """Get the drift function of the reverse-time SDE."""
#       score_fn = get_score_fn(sde, state, state.opt_state_ema.ema, train=False, eps=eps)
#       rsde = sde.reverse(score_fn, probability_flow=True)
#       return rsde.sde(x, t, index)[0]

#   def ode_sampler(prng, pstate, cond_image=None, **kwargs):
#     """The probability flow ODE sampler with black-box ODE solver.

#     Args:
#       prng: An array of random state. The leading dimension equals the number of devices.
#       pstate: Replicated training state for running on multiple devices.
#       cond_image: conditional image for generating initial `noisy image`. shape (n_tpu, B // n_tpu, H, W, C)
#     Returns:
#       Samples, and the number of function evaluations.
#     """
#     def ode_func_(t, x, index):
#       """
#       Input
#         t:     time, float.
#         x:     flattened numpy array, shape (B * H * W * C,)
#         index: head index, int.

#       Return
#         Flattened drift function output, shape (B * H * W * C,)
#       """
#       x = from_flattened_numpy(x, (jax.local_device_count(),) + shape) # (n_tpu, B // n_tpu, H, W, C)
#       if isinstance(sde, sde_lib.RFSDE) and (not vp_to_rf):
#         t = rescale_time(t, to='diffusion') # (n_tpu, B), (eps -> 1) scale to (1 -> 0) scale.
#       else: # VP-SDE: Do not rescale time coefficients
#         pass
#       vec_ = jnp.ones((x.shape[0], x.shape[1])) * t # (n_tpu, B // n_tpu)
#       drift = drift_fn(pstate, x, vec_, jnp.ones_like(vec_, jnp.int32) * index)
#       return to_flattened_numpy(drift)

#     ode_func = []
#     for idx in range(reflow_t):
#       ode_func.append(functools.partial(ode_func_, index=idx))

#     # Initial sample
#     rng = flax.jax_utils.unreplicate(prng)
#     rng, step_rng = random.split(rng)
#     if cond_image is None:
#       # generate from noise.
#       initial_image = sde.prior_sampling(step_rng, (jax.local_device_count(),) + shape) # shape (n_tpu, B // n_tpu, H, W, C)
#     else:
#       # initial image is linear interpolation between noise and image. (e.g. noisy image) shape: (n_tpu, B // n_tpu, H, W, C)
#       initial_image = jnp.reshape(cond_image, (jax.local_device_count(),) + shape)

#     if gen_reflow:
#       assert kwargs['batch_idx'] is not None
#       assert isinstance(kwargs['batch_idx'], int)
#       assert kwargs['batch_idx'] in jnp.arange(reflow_t)
#       start_t_batch = timestep_dict['interval'][kwargs['batch_idx']]
#       end_t_batch = timestep_dict['interval'][kwargs['batch_idx'] + 1]
#       rng, step_rng = random.split(rng)
#       noise = sde.prior_sampling(step_rng, (jax.local_device_count(),) + shape) # shape (n_tpu, B // n_tpu, H, W, C)
#       if isinstance(sde, sde_lib.RFSDE):
#         initial_image = (1. - start_t_batch) * cond_image + start_t_batch * noise
#       else: # VP-SDE
#         mean_vp, std_vp = jax.pmap(sde.marginal_prob)(cond_image, jnp.full((cond_image.shape[0], cond_image.shape[1]), start_t_batch))
#         initial_image = jax.pmap(lambda mean, std, noise: mean + batch_mul(std, noise))(mean_vp, std_vp, noise)
#     else:
#       rng, step_rng = random.split(rng)
#       initial_image = sde.prior_sampling(step_rng, (jax.local_device_count(),) + shape) # shape (n_tpu, B // n_tpu, H, W, C)

#     # Black-box ODE solver for the probability flow ODE
#     solution = []
#     mid_images = [initial_image]
#     n_points_per_seq = []
#     if isinstance(sde, sde_lib.RFSDE) and (not vp_to_rf):
#       interval_resized = rescale_time(timestep_dict['interval'], to='rf') # (1 --> 0) to (0.001 --> 1) scale.
#     else: # VP-SDE
#       interval_resized = timestep_dict['interval']
#     if gen_reflow: # generate reflow dataset
#       solution.append(integrate.solve_ivp(ode_func[kwargs['batch_idx']], (interval_resized[kwargs['batch_idx']], interval_resized[kwargs['batch_idx'] + 1]),
#                                           to_flattened_numpy(initial_image), rtol=rtol, atol=atol, method=method))
#       n_points_per_seq.append(solution[-1].t.shape[0]) # (1,)
#     else: # sample
#       current_image = to_flattened_numpy(initial_image)
#       for rf_div in range(reflow_t):
#         solution.append(integrate.solve_ivp(ode_func[rf_div], (interval_resized[rf_div], interval_resized[rf_div + 1]),
#                                             current_image, rtol=rtol, atol=atol, method=method))

#         current_image = solution[-1].y[:, -1]
#         mid_images.append(from_flattened_numpy(current_image, (jax.local_device_count(),) + shape)) # (n_tpu, B // n_tpu, H, W, C)
#         n_points_per_seq.append(solution[-1].t.shape[0])

#     if not gen_reflow:
#       nfe = jnp.sum(jnp.asarray([s.nfev for s in solution]))
#       n_points_per_seq = jnp.asarray(n_points_per_seq) # (1,) or (reflow_t,), resp.
      
#       mid_images = [jnp.expand_dims(m, axis=-1) for m in mid_images]
#       mid_images = jnp.concatenate(mid_images, axis=-1) # (n_tpu, B // n_tpu, H, W, C, reflow_t + 1)
#       t = jnp.asarray(jnp.concatenate([s.t[:-1] for s in solution])) # (n_points,)
#       t = jnp.concatenate([t, jnp.expand_dims(solution[-1].t[-1], -1)])
#       x_all = jnp.asarray(jnp.concatenate([s.y[:, :-1] for s in solution], axis=-1)) # (B * H * W * C, n_points)
#       x_all = jnp.concatenate([x_all, solution[-1].y[:, -1:]], axis=-1)
#       x_all = x_all.reshape((jax.local_device_count(),) + shape + (x_all.shape[1],)) # x_all: jnp.array of the trajectory, shape (n_tpu, B // n_tpu, H, W, C, n_points)
#       x = x_all[:, :, :, :, :, -1] # shape (n_tpu, B, H, W, C)

#       # Calculate statistics using curvature statistics, if required.
#       dt = t[1:] - t[:-1]                                                                          # tnext - t, (rf scale) (n_points - 1,)
#       x_all = jnp.transpose(x_all, (5, 0, 1, 2, 3, 4))
#       x_all = jnp.reshape(x_all,
#         (x_all.shape[0], x_all.shape[1] * x_all.shape[2]) + x_all.shape[3:])                       # (n_points, B, H, W, C)

#     else:
#       x = solution[-1].y[:, -1:]
#       t = solution[-1].t
#       del solution
#       x = jnp.reshape(x, initial_image.shape)

#     stats = dict()
#     # In the following cases, B <- n_tpu * B.
#     if not gen_reflow:
#       curv_diff = x_all[1:] - x_all[:-1]                                                         # x{tnext} - xt,  (n_points - 1, B, H, W, C)
#       curv_derivative = batch_mul(curv_diff, 1. / dt)                                            # (x{tnext} - xt) / (tnext - t), (n_points - 1, B, H, W, C)
      
#       # Calculate (marginal) straightness.
#       marginal_diff = x - initial_image                                                          # x0 - x1, (n_tpu, B, H, W, C)
#       marginal_diff = jnp.reshape(marginal_diff, 
#         (marginal_diff.shape[0] * marginal_diff.shape[1],) + marginal_diff.shape[2:])            # x0 - x1, (B, H, W, C)
#       straightness_gap = jnp.sum(jnp.square(marginal_diff - curv_derivative), axis=(2, 3, 4))    # l2sq((x0 - x1) - d/dt x_t), (n_points - 1, B) # x0 <- x at time 1, x1 <- x at time eps
#       straightness = jnp.mean(jnp.sum(batch_mul(dt, straightness_gap), axis=0))                  # (1,)
#       straightness_by_t = jnp.mean(straightness_gap, axis=1)                                     # (n_points - 1,)

#       # Calculate sequential straightness.
#       mid_images = jnp.transpose(mid_images, (5, 0, 1, 2, 3, 4))
#       mid_images = jnp.reshape(mid_images,
#         (mid_images.shape[0], mid_images.shape[1] * mid_images.shape[2]) + mid_images.shape[3:]) # (reflow_t + 1, B, H, W, C)
#       seq_diff = mid_images[1:] - mid_images[:-1]                                                # (reflow_t, B, H, W, C)
#       seq_straightness_gap = jnp.zeros((0, jax.local_device_count() * shape[0]))
#       current_index = 0
#       if isinstance(sde, sde_lib.RFSDE):
#         interval_resized = rescale_time(timestep_dict['interval'], 'rf')
#       else:
#         interval_resized = timestep_dict['interval']
#       for r in range(reflow_t):
#         next_index = current_index + n_points_per_seq[r]
#         mid_marginal_diff = seq_diff[r] / (interval_resized[r + 1] - interval_resized[r])       # (B, H, W, C)
#         mid_curv_derivative = curv_derivative[current_index:next_index]                         # (div[r], B, H, W, C)
#         part_gap = jnp.sum(
#           jnp.square(mid_marginal_diff - mid_curv_derivative), axis=(2, 3, 4))                  # (n_div[r], B)
#         seq_straightness_gap = jnp.concatenate([seq_straightness_gap, part_gap], axis=0)        # finally (n_points - 1, B)
#         current_index = next_index

#       seq_straightness = jnp.mean(jnp.sum(batch_mul(dt, seq_straightness_gap), axis=0))         # (1,)
#       seq_straightness_by_t = jnp.mean(seq_straightness_gap, axis=1)                            # (n_points,)

#       # For RK45 solver, NFE = (n_points - 1) * 6.
#       stats['straightness'] = straightness                   # (1,) (c.f. ODE solver: (n_tpu,) after passing pmap)
#       stats['straightness_by_t'] = straightness_by_t         # (n_points - 1,)
#       stats['nfe'] = nfe                                     # (1,)
#       stats['interval'] = t                                  # (n_points,)

#       # save trajectory if needed
#       if ('save_trajectory' in kwargs) and kwargs['save_trajectory']:
#         stats['trajectory'] = x_all

#       del curv_diff, dt, curv_derivative, marginal_diff, mid_images, seq_diff, seq_straightness_gap, part_gap, mid_marginal_diff, mid_curv_derivative, x_all

#     # x, initial_image: shape (n_tpu, B // n_tpu, H, W, C)
#     # times: shape (n_tpu, B // n_tpu)
#     return ((inverse_scaler(x), inverse_scaler(initial_image)), \
#       (jnp.ones((x.shape[0], x.shape[1])) * rescale_time(t[-1], "diffusion"), jnp.ones((x.shape[0], x.shape[1])) * rescale_time(t[0], "diffusion")), stats)

#   return ode_sampler

def sample_qm9(config, rng, n_samples, sde, dequantizer, dataset_info, nodes_per_sample=10):
  """Draw masked Gaussian noise for QM9 molecules and decode it into positions and features.

  Args:
    config: Configuration object; only `config.data.include_charges` is read.
    rng: A JAX random state.
    n_samples: Number of samples; the leading dimension of the noise that is drawn.
    sde: Not used by this function.
    dequantizer: Forwarded to `sample_fn`.
    dataset_info: Mapping; only `dataset_info['max_n_nodes']` is read.
    nodes_per_sample: Number of real (unmasked) nodes per sample. Either a sequence with one
      entry per sample, or a single integer that is applied to all `n_samples` samples.

  Returns:
    A dict with keys 'one_hot', 'charges', 'x' and 'node_mask'.
  """
  N = dataset_info['max_n_nodes']
  # A scalar (including the default value) means "the same size for every sample". Without
  # this, `len()` below would raise TypeError on an integer.
  if np.ndim(nodes_per_sample) == 0:
    nodes_per_sample = [int(nodes_per_sample)] * n_samples
  B = len(nodes_per_sample)

  # Set node mask: the first `nodes_per_sample[i]` nodes of sample i are real, the rest padding.
  node_mask = jnp.zeros((B, N))
  for i, n_nodes in enumerate(nodes_per_sample):
    node_mask = node_mask.at[i, 0:n_nodes].set(1)

  # Set edge mask: pairs of real nodes, excluding self-pairs on the diagonal.
  # NOTE(review): `edge_mask` is not used after this point; presumably it is meant to be
  # passed to the model call -- confirm, or remove it.
  edge_mask = jnp.expand_dims(node_mask, axis=1) * jnp.expand_dims(node_mask, axis=2)
  diag_mask = jnp.expand_dims(~jnp.eye(edge_mask.shape[2], dtype=bool), axis=0)
  edge_mask *= diag_mask
  edge_mask = jnp.reshape(edge_mask, (B * N * N, 1))
  node_mask = jnp.expand_dims(node_mask, axis=2)

  # Positions (3 coordinates) and node features (5 channels) are drawn from separate keys.
  rng, step_rng = jax.random.split(rng)
  noise_x = sample_center_gravity_zero_gaussian_with_mask(
    step_rng,
    shape=(n_samples, N, 3),
    node_mask=node_mask,
  )
  rng, step_rng = jax.random.split(rng)
  noise_h = sample_gaussian_with_mask(
    step_rng,
    shape=(n_samples, N, 5),
    node_mask=node_mask,
  )
  noise = jnp.concatenate([noise_x, noise_h], axis=2)
  assert_mean_zero_with_mask(noise_x, node_mask)

  # sample p(x, h | z0 = noise).
  # NOTE(review): `sample_fn` is not defined in this function or anywhere visible in this
  # module; the original inline note points at `self.sample_p_xh_given_z0(dequantizer, z_, node_mask)`.
  x, h = sample_fn(dequantizer, noise, node_mask)
  assert_mean_zero_with_mask(x, node_mask)

  # max_cog: check if x is biased.
  max_cog = jnp.max(jnp.abs(jnp.sum(x, axis=1, keepdims=True)))
  if max_cog > 5e-2:
    logging.info(f"Warning cog drift with error {max_cog:.3f}. Projecting the positions down.")
    x = remove_mean_with_mask(x, node_mask)

  # Now x, h are sampled.

  one_hot, charges = h['categorical'], h['integer']
  if config.data.include_charges:
    assert_correctly_masked(charges, node_mask)

  return {
    'one_hot'  : one_hot,
    'charges'  : charges,
    'x'        : x,
    'node_mask': node_mask,
  }
