# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: skip-file
"""Training and evaluation for score-based generative models. """

import gc
import io
import os
import time
from typing import Any
import copy
import timeit

import flax
import flax.jax_utils as flax_utils
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf
import tensorflow_gan as tfgan
import logging
import functools
import scipy

from flax.training import orbax_utils
import orbax.checkpoint
# Keep the import below for registering all model definitions
from models import ddpm, ncsnv2, ncsnpp, iddpm
import losses
import sampling
import utils
from models import utils as mutils
import datasets
import evaluation
# import likelihood
import sde_lib
from absl import flags
import datetime
import wandb
import matplotlib.pyplot as plt
import jax_smi

from flax.traverse_util import flatten_dict, unflatten_dict
from flax.training import orbax_utils, checkpoints
from configs.datasets_config import get_dataset_info, get_args
import ml_collections
from qm9.dataset import retrieve_dataloaders
import utils_qm9

FLAGS = flags.FLAGS


def _save_checkpoint(manager, step, pstate):
  """Saves the un-replicated train state through an orbax checkpoint manager.

  Args:
    manager: `orbax.checkpoint.CheckpointManager` that receives the snapshot.
    step: Integer id under which the snapshot is stored.
    pstate: Device-replicated train state; it is un-replicated and deep-copied
      before being handed to the manager.
  """
  state_dict = {
    'model': copy.deepcopy(flax_utils.unreplicate(pstate)),
  }
  save_args = orbax_utils.save_args_from_target(state_dict)
  manager.save(step, state_dict, save_kwargs={'save_args': save_args})


def train(config, workdir, log_name):
  """Runs the training pipeline.

  Two code paths share this entry point: the image models (any
  `config.model.name` other than 'egnn'), which are trained step-by-step, and
  the EGNN model on QM9, which is trained epoch-by-epoch.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.
    log_name: Name of the wandb run. If None, a name is built from the model
      name and the current time.

  Raises:
    NotImplementedError: If `config.training.sde` is neither 'rfsde' nor
      'vpsde'.
  """
  # ====================================================================================================== #
  # Get logger
  jax_smi.initialise_tracking()

  # wandb_dir: Directory of wandb summaries
  current_time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
  run_name = f"{config.model.name}-{current_time}" if log_name is None else log_name
  wandb.init(project="anonymous-repo", name=run_name, entity="anonymous", resume="allow")
  wandb_dir = os.path.join(workdir, "wandb")
  tf.io.gfile.makedirs(wandb_dir)
  wandb.config = config

  # Create directories for experimental logs
  sample_dir = os.path.join(workdir, "samples")
  tf.io.gfile.makedirs(sample_dir)
  rng = jax.random.PRNGKey(config.seed)

  is_egnn = config.model.name == 'egnn'
  # ====================================================================================================== #
  # Initialize model.
  rng, step_rng = jax.random.split(rng)
  if not is_egnn:
    state = mutils.init_train_state(step_rng, config)
  else:
    # QM9: Dataset information
    dataset_info = get_dataset_info(config.data.dataset, config.data.remove_h)
    args = get_args()
    dataloaders, charge_scale = retrieve_dataloaders(args, config, evaluation=False, jit=False)
    dummy_loader = next(iter(dataloaders['train']))
    dummy_input_batch, _ = utils_qm9.preprocess_batch(dummy_loader, config, evaluation=False, jit=False)
    # un-pmap: merge the two leading axes of every leaf into one.
    dummy_input_batch = jax.tree_util.tree_map(
      lambda x: jnp.reshape(x, (x.shape[0] * x.shape[1],) + x.shape[2:]), dummy_input_batch)
    rng, step_rng = jax.random.split(rng)
    all_transformed_data = utils_qm9.transform_data(step_rng, config, dummy_input_batch)

    # Pad the node features `h` with `aug_dim` zero columns so that the model is
    # initialised with the augmented feature width. The two-name unpacking also
    # checks that `h` is rank 2.
    num_nodes, _ = all_transformed_data['egnn']['h'].shape
    all_transformed_data['egnn']['h'] = jnp.concatenate(
      [all_transformed_data['egnn']['h'], jnp.zeros((num_nodes, config.model.aug_dim))], axis=-1)

    state = mutils.init_train_state(step_rng, config, dummy_batch=all_transformed_data['egnn'])

  checkpoint_dir = os.path.join(workdir, "checkpoints")
  checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta")
  state_dict = {
    'model': state
  }

  # get manager options, and restore checkpoints.
  mgr_options = orbax.checkpoint.CheckpointManagerOptions(
    save_interval_steps=config.training.snapshot_freq,
    create=True)
  ckpt_mgr = orbax.checkpoint.CheckpointManager(
    checkpoint_dir,
    orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()), mgr_options)

  # Resume training when intermediate checkpoints are detected
  mgr_meta_options = orbax.checkpoint.CheckpointManagerOptions(
    save_interval_steps=config.training.snapshot_freq_for_preemption,
    max_to_keep=1,
    create=True)
  ckpt_meta_mgr = orbax.checkpoint.CheckpointManager(
    checkpoint_meta_dir,
    orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()), mgr_meta_options)
  latest_meta_step = ckpt_meta_mgr.latest_step()
  if latest_meta_step is not None:
    logging.info(f"Restore checkpoint-meta from step {latest_meta_step}.")
    state_dict = checkpoints.restore_checkpoint(
      ckpt_dir=os.path.join(checkpoint_meta_dir, f"{latest_meta_step}", "default"), target=state_dict)
    state = state_dict['model']

  # `state.step` is JAX integer on the GPU/TPU devices
  initial_step = int(state.step)

  # Define dense_state: a model with single linear layer for random projection of auxiliary variables.
  if not is_egnn:
    Cin, Cout = config.data.num_channels, config.model.aug_dim
  else:
    # EGNN case
    Cin, Cout = 3 + len(dataset_info['atom_decoder']) + int(config.model.include_charges), config.model.aug_dim

  # The projection kernel is drawn once and persisted next to the checkpoints,
  # so that later runs in the same workdir reuse the same projection.
  dense_kernel_path = os.path.join(checkpoint_dir, "dense.npy")
  if os.path.exists(dense_kernel_path):
    kernel_arr = np.load(dense_kernel_path)
  else:
    rng, step_rng = jax.random.split(rng)
    kernel_arr = jax.random.normal(step_rng, shape=(Cin, Cout))
    # NOTE(review): this line was left in a merge conflict. One side normalised by
    # the Frobenius norm (kept here), the other divided by `jnp.sqrt(Cout)`.
    # Confirm which scaling is intended before starting a fresh run; existing
    # workdirs are unaffected because they load the saved `dense.npy`.
    kernel_arr /= scipy.linalg.norm(kernel_arr, 'fro') # normalize w.r.t. Frobenius norm
    np.save(dense_kernel_path, kernel_arr)
  assert kernel_arr.shape == (Cin, Cout)

  rng, step_rng = jax.random.split(rng)
  if Cout > 0:
    dense_state = nn.Dense(features=Cout, use_bias=False)
    if not is_egnn:
      dummy_dense_input = jnp.ones(
        config.data.data_shape[:-1] + (config.data.data_shape[-1] + config.model.aug_dim,))
    else:
      dummy_h, dummy_x = all_transformed_data['egnn']['h'], all_transformed_data['egnn']['x']
      dummy_dense_input = jnp.concatenate([dummy_x, dummy_h], axis=-1)
    # The parameters returned by `init` are discarded; the fixed `kernel_arr` is
    # bound as the layer's kernel instead.
    dense_state.init(step_rng, dummy_dense_input)
    dense_fn = functools.partial(dense_state.apply, {'params': {'kernel': kernel_arr}}) # usage: y1 = lambda1 * y0 + lambda2 * dense_fn(x0 + x1)
  else:
    # No auxiliary dimensions: there is nothing to project.
    # NOTE(review): assumes `losses.get_step_fn` accepts `dense_fn=None` -- confirm.
    dense_fn = None
  # ====================================================================================================== #
  # Build data iterators
  if not is_egnn:
    train_ds, eval_ds, _ = datasets.get_dataset(config,
                                                additional_dim=config.training.n_jitted_steps,
                                                uniform_dequantization=config.data.uniform_dequantization)
    train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types
    eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types
    # Create data normalizer and its inverse
    scaler = datasets.get_data_scaler(config)
    inverse_scaler = datasets.get_data_inverse_scaler(config)
  else:
    # NOTE(review): none of the names bound in this branch are read by the EGNN
    # epoch loop below, which rebuilds the dataloaders at the start of every epoch.
    dataloaders, charge_scale = retrieve_dataloaders(args, config, evaluation=False, jit=True)
    dummy_loader = next(iter(dataloaders['train']))
    train_ds = dataloaders['train']
    eval_ds = dataloaders['valid']
    test_ds = dataloaders['test']
    train_iter = iter(train_ds)
    eval_iter = iter(eval_ds)
  # ====================================================================================================== #
  # Setup SDEs
  sde_name = config.training.sde.lower()
  if sde_name == 'rfsde':
    sde = sde_lib.RFSDE(N=config.eval.num_scales)
    sampling_eps = 1e-3
  elif sde_name == 'vpsde':
    sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.eval.num_scales)
    sampling_eps = 1e-3
  else:
    raise NotImplementedError(f"SDE {config.training.sde} unknown.")
  # ====================================================================================================== #
  # Build one-step training and evaluation functions
  optimize_fn = losses.optimization_manager(config)
  train_step_fn = losses.get_step_fn(config, sde, state, train=True, optimize_fn=optimize_fn, eps=sampling_eps, dense_fn=dense_fn)
  # Pmap (and jit-compile) multiple training steps together for faster running
  p_train_step = jax.pmap(functools.partial(jax.lax.scan, train_step_fn), axis_name='batch', donate_argnums=1)

  eval_step_fn = losses.get_step_fn(config, sde, state, train=False, optimize_fn=optimize_fn, eps=sampling_eps, dense_fn=dense_fn)
  # Pmap (and jit-compile) multiple evaluation steps together for faster running
  p_eval_step = jax.pmap(functools.partial(jax.lax.scan, eval_step_fn), axis_name='batch', donate_argnums=1)

  # Building sampling functions
  if not is_egnn:
    if config.training.snapshot_sampling:
      sampling_shape = (config.eval.batch_size // jax.local_device_count(), config.data.image_size,
                        config.data.image_size, config.data.num_channels)
      sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, sampling_eps)
    else:
      logging.info("Sampling function is not required.")
  # TODO: snapshot sampling is not wired in for EGNN/QM9 yet
  # (candidate: sampling.get_sampling_qm9_fn(config, sde, sampling_eps)).

  pstate = flax_utils.replicate(state)
  num_train_steps = config.training.n_iters
  # ====================================================================================================== #
  # In case there are multiple hosts (e.g., TPU pods), only log to host 0
  if jax.process_index() == 0:
    logging.info("Starting training loop at step %d." % (initial_step,))
  rng = jax.random.fold_in(rng, jax.process_index())

  # JIT multiple training steps together for faster training
  n_jitted_steps = config.training.n_jitted_steps
  # Must be divisible by the number of steps jitted together
  assert config.training.log_freq % n_jitted_steps == 0 and \
         config.training.snapshot_freq_for_preemption % n_jitted_steps == 0 and \
         config.training.eval_freq % n_jitted_steps == 0 and \
         config.training.snapshot_freq % n_jitted_steps == 0, "Missing logs or checkpoints!"

  if not is_egnn:
    # ====================================================================================================== #
    # Main training or generation part
    for step in range(initial_step, num_train_steps + 1, n_jitted_steps):
      # Use batch
      batch = jax.tree_util.tree_map(lambda x: scaler(x._numpy()), next(train_iter))  # pylint: disable=protected-access

      # Snapshot sampling runs every `snapshot_freq` steps and at the final step;
      # the step-0 snapshot is only taken when `zero_snapshot` is set.
      is_snapshot_step = step % config.training.snapshot_freq == 0 or step == num_train_steps
      skip_zero_snapshot = (not config.training.zero_snapshot) and step == 0
      if is_snapshot_step and not skip_zero_snapshot and config.training.snapshot_sampling:
        # Generate and save one batch of samples
        rng, *sample_rng = jax.random.split(rng, jax.local_device_count() + 1)
        sample_rng = jnp.asarray(sample_rng)
        (sample, _), _ = sampling_fn(sample_rng, pstate)
        image_grid = jnp.reshape(sample, (-1, *sample.shape[-3:]))

        # Draw snapshot figure
        this_sample_dir = os.path.join(sample_dir)
        tf.io.gfile.makedirs(this_sample_dir)
        utils.draw_figure_grid(image_grid[0:64], this_sample_dir, f"sample_{step}")

        # Get statistics
        stats = utils.get_samples_and_statistics(config, rng, sampling_fn, pstate, this_sample_dir, sampling_shape, mode='train', current_step=step)
        logging.info(f"FID = {stats['fid']}")
        logging.info(f"KID = {stats['kid']}")
        logging.info(f"Inception_score = {stats['is']}")
        logging.info(f"NFE (Number of function evaluations) = {stats['nfe']}")
        logging.info(f"straightness = {stats['straightness']}")
        logging.info(f"straightness_x = {stats['x_straightness']}")
        wandb_statistics_dict = {
          'fid': float(stats['fid']),
          'kid': float(stats['kid']),
          'inception_score': float(stats['is']),
          'nfe': float(stats['nfe']),
          'step': int(step),
          'n_data': int(config.training.snapshot_fid_sample),
          'x_straightness': float(stats['x_straightness']),
          'straighteness': float(stats['straightness'])
        }
        # The auxiliary (y) straightness is only read when auxiliary dimensions exist.
        if config.model.aug_dim > 0:
          logging.info(f"straightness_y = {stats['y_straightness']}")
          wandb_statistics_dict['y_straightness'] = float(stats['y_straightness'])
        wandb.log(wandb_statistics_dict, step=step)
      # ====================================================================================================== #
      rng, step_rng = jax.random.split(rng)
      batch_noise = jax.random.normal(step_rng, batch['image'].shape)

      rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
      next_rng = jnp.asarray(next_rng)
      (_, pstate), ploss = p_train_step((next_rng, pstate), (batch_noise, batch['image']))

      # Calculate loss and save
      loss = flax.jax_utils.unreplicate(ploss).mean()
      wandb_log_dict = {'train/loss': float(loss)}
      # Log to console, file and tensorboard on host 0
      if jax.process_index() == 0 and step % config.training.log_freq == 0:
        logging.info("step: %d, training_loss: %.5e" % (step, loss))
        wandb.log(wandb_log_dict, step=step)

      if step % config.training.snapshot_freq_for_preemption == 0:
        # Save a temporary checkpoint to resume training after pre-emption periodically
        _save_checkpoint(ckpt_meta_mgr, step, pstate)

      if step % config.training.snapshot_freq == 0:
        # Save a regular snapshot checkpoint
        _save_checkpoint(ckpt_mgr, step, pstate)

      # ====================================================================================================== #
      # Report the loss on an evaluation dataset periodically only in train_baseline case
      if step % config.training.eval_freq == 0:
        eval_batch = jax.tree_util.tree_map(lambda x: scaler(x._numpy()), next(eval_iter))  # pylint: disable=protected-access
        rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
        next_rng = jnp.asarray(next_rng)

        # Eval loss at the baseline.
        rng, step_rng = jax.random.split(rng)
        eval_batch_noise = jax.random.normal(step_rng, eval_batch['image'].shape)
        (_, _), peval_loss = p_eval_step((next_rng, pstate), (eval_batch_noise, eval_batch['image']))

        eval_loss = flax.jax_utils.unreplicate(peval_loss).mean()
        wandb_log_dict = {'eval/loss': float(eval_loss)}
        if jax.process_index() == 0:
          logging.info("step: %d, eval_loss: %.5e" % (step, eval_loss))
          wandb.log(wandb_log_dict, step=step)
      # ====================================================================================================== #

  else: # EGNN
    # NOTE(review): this loop always starts at epoch 1; `initial_step` restored
    # from a checkpoint is not used to skip already-trained epochs.
    for n_epochs in range(1, config.training.n_epochs + 1):
      logging.info(f'Epoch {n_epochs}')
      dataloaders, charge_scale = retrieve_dataloaders(args, config, evaluation=False, jit=True)
      step_per_epoch = len(dataloaders['train'])
      for step, batch_loader in enumerate(dataloaders['train']):
        # Get batch
        start = timeit.default_timer()

        # Skip batches too small to fill every jitted step (e.g. the last batch).
        if jnp.array(batch_loader['positions']).shape[0] < config.training.n_jitted_steps * config.training.batch_size:
          continue

        batch, _ = utils_qm9.preprocess_batch(batch_loader, config, evaluation=False, jit=True)
        n_devices, n_jit, B, N, D = batch['x'].shape
        # un-pmap: merge the three leading axes of every leaf into one.
        batch = jax.tree_util.tree_map(lambda x: jnp.reshape(x, (n_devices * n_jit * B,) + x.shape[3:]), batch)

        rng, step_rng = jax.random.split(rng)
        batch = utils_qm9.transform_data(step_rng, config, batch)

        # pmap and jit all of them: restore the (n_devices, n_jit, ...) leading axes.
        batch = batch['egnn']
        batch['edge_mask'] = jnp.reshape(batch['edge_mask'], (n_devices, n_jit, B * N * N, 1))
        batch['node_mask'] = jnp.reshape(batch['node_mask'], (n_devices, n_jit, B * N, -1))
        batch['edges'] = jax.tree_util.tree_map(lambda x: jnp.reshape(x, (n_devices, n_jit, B * N * N)), batch['edges'])
        for per_node_key in ('h', 'x', 'noise_h', 'noise_x', 't'):
          batch[per_node_key] = jnp.reshape(batch[per_node_key], (n_devices, n_jit, B * N, -1))

        # Set the train step
        rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
        next_rng = jnp.asarray(next_rng)
        (_, pstate), ploss = p_train_step((next_rng, pstate), batch)

        # Calculate loss and save
        loss = flax.jax_utils.unreplicate(ploss).mean()
        wandb_log_dict = {'train/loss': float(loss)}
        # Log to console, file and tensorboard on host 0
        if jax.process_index() == 0:
          logging.info("Epoch %d\tstep %d\ttraining_loss %.5e\tTime %.3f" % (n_epochs, step + 1, loss, timeit.default_timer() - start))
          wandb.log(wandb_log_dict, step=(n_epochs - 1) * step_per_epoch + (step + 1))

      # After each epoch, hand the state to both checkpoint managers, keyed by
      # the epoch number.
      _save_checkpoint(ckpt_meta_mgr, n_epochs, pstate)
      _save_checkpoint(ckpt_mgr, n_epochs, pstate)

      # Sample molecules and evaluate.



# Simplified CIFAR-10 sampling code
def evaluate(config, workdir, log_name, eval_folder="eval"):
  """Evaluate trained models.

  Restores the checkpoint saved at `config.eval.begin_step`, generates samples
  and logs FID / KID / Inception score. When `config.eval.enable_sampling` is
  False there is nothing to compute, so the function returns after restoring
  the checkpoint.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints.
    log_name: Name of the wandb run. If None, a name is built from the model
      name and the current time.
    eval_folder: The subfolder for storing evaluation results. Default to
      "eval".

  Raises:
    NotImplementedError: If `config.training.sde` is not 'rfsde'.
  """
  # ====================================================================================================== #
  # Get logger
  jax_smi.initialise_tracking()

  # wandb_dir: Directory of wandb summaries
  current_time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
  run_name = f"{config.model.name}-{current_time}" if log_name is None else log_name
  wandb.init(project="anonymous-repo", name=run_name, entity="anonymous", resume="allow")
  wandb_dir = os.path.join(workdir, "wandb")
  tf.io.gfile.makedirs(wandb_dir)
  wandb.config = config

  # Create directory to eval_folder
  eval_dir = os.path.join(workdir, eval_folder)
  tf.io.gfile.makedirs(eval_dir)
  rng = jax.random.PRNGKey(config.seed)

  # Create data normalizer and its inverse
  scaler = datasets.get_data_scaler(config)
  inverse_scaler = datasets.get_data_inverse_scaler(config)
  # ====================================================================================================== #
  # Initialize model
  rng, step_rng = jax.random.split(rng)
  state = mutils.init_train_state(step_rng, config)

  # Generate placeholder.
  state_dict = {
    'model': copy.deepcopy(state),
  }
  checkpoint_dir = os.path.join(workdir, "checkpoints")
  # ====================================================================================================== #
  # Setup SDEs
  if config.training.sde.lower() == 'rfsde':
    sde = sde_lib.RFSDE(N=config.eval.num_scales)
    sampling_eps = 1e-3 # Not used.
  else:
    raise NotImplementedError(f"SDE {config.training.sde} unknown.")
  # ====================================================================================================== #
  # Build the sampling function when sampling is enabled
  sampling_enabled = config.eval.enable_sampling
  if sampling_enabled:
    sampling_shape = (config.eval.batch_size // jax.local_device_count(),
                      config.data.image_size, config.data.image_size,
                      config.data.num_channels)
    sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, sampling_eps)
  # ====================================================================================================== #
  # Add additional task for evaluation (for example, get gradient statistics) here.
  # ====================================================================================================== #
  # Restore the checkpoint saved at `config.eval.begin_step`.
  state_dict = checkpoints.restore_checkpoint(ckpt_dir=os.path.join(checkpoint_dir, f"{config.eval.begin_step}", "default"), target=state_dict)
  state = state_dict['model']
  # Create different random states for different hosts in a multi-host environment (e.g., TPU pods)
  rng = jax.random.fold_in(rng, jax.process_index())
  pstate = flax.jax_utils.replicate(state)
  # ====================================================================================================== #
  # Generate samples and compute IS/FID/KID when enabled
  if not sampling_enabled:
    # `sampling_fn` / `sampling_shape` are only built when sampling is enabled;
    # without this guard the call below would fail with a NameError.
    logging.warning("config.eval.enable_sampling is False; skipping sample generation and IS/FID/KID computation.")
    return

  state = jax.device_put(state)
  # Run sample generation for multiple rounds to create enough samples
  # Designed to be pre-emption safe. Automatically resumes when interrupted
  if jax.process_index() == 0:
    logging.info("Sampling -- baseline")
  this_sample_dir = os.path.join(eval_dir, f"baseline_host_{jax.process_index()}")
  stats = utils.get_samples_and_statistics(config, rng, sampling_fn, pstate, this_sample_dir, sampling_shape, mode='eval')
  logging.info(f"FID = {stats['fid']}")
  logging.info(f"KID = {stats['kid']}")
  logging.info(f"Inception_score = {stats['is']}")
  # NOTE: these statistics are only written to the Python log; nothing is sent
  # to wandb from this function.


# TODO
# def evaluate(config, workdir, log_name, eval_folder="eval"):
#   """Evaluate trained models.

#   Args:
#     config: Configuration to use.
#     workdir: Working directory for checkpoints.
#     eval_folder: The subfolder for storing evaluation results. Default to
#       "eval".
#   """
#   # ====================================================================================================== #
#   # Get logger
#   jax_smi.initialise_tracking()

#   # wandb_dir: Directory of wandb summaries
#   current_time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
#   if log_name is None:
#     wandb.init(project="seq-rf", name=f"{config.model.name}-{current_time}", entity="anonymous", resume="allow")
#   else:
#     wandb.init(project="seq-rf", name=log_name, entity="anonymous", resume="allow")
#   wandb_dir = os.path.join(workdir, "wandb")
#   tf.io.gfile.makedirs(wandb_dir)
#   wandb.config = config

#   # Create directory to eval_folder
#   eval_dir = os.path.join(workdir, eval_folder)
#   tf.io.gfile.makedirs(eval_dir)
#   rng = jax.random.PRNGKey(config.seed)

#   # Create data normalizer and its inverse
#   scaler = datasets.get_data_scaler(config)
#   inverse_scaler = datasets.get_data_inverse_scaler(config)
#   # ====================================================================================================== #
#   # Get modes
#   """
#     eval_mode
#       'eval_baseline': Evaluate baseline model. (Original RF)
#       'eval_reflow': Evaluate reflow (n-SeqRF) model.
#       'eval_distill': Evaluate distill (n-SeqRF-distill) model.
#       'eval_rf_distill': distill from baseline reflow model.
#   """
#   if config.model.rf_phase == 1:
#     eval_mode = 'eval_baseline'
#   else:
#     eval_mode = config.eval.reflow_mode
#     assert eval_mode in ['eval_reflow', 'eval_distill', 'eval_rf_distill']
#     if eval_mode == 'eval_rf_distill':
#       assert config.eval.num_scales % config.training.reflow_t == 0
#   # ====================================================================================================== #
#   # Initialize model
#   rng, step_rng = jax.random.split(rng)
#   state = mutils.init_train_state(step_rng, config)

#   # Generate placeholder.
#   state_dict = {
#     'model': copy.deepcopy(state),
#   }

#   if eval_mode == 'eval_reflow':
#     checkpoint_dir = os.path.join(workdir, f"{config.model.rf_phase}_rf", "checkpoints")
#   elif eval_mode == 'eval_distill':
#     checkpoint_dir = os.path.join(workdir, f"{config.model.rf_phase}_distill", "checkpoints")
#   elif eval_mode == 'eval_baseline':
#     # checkpoint_dir will not be used
#     checkpoint_dir = None
#   elif eval_mode == 'eval_rf_distill':
#     checkpoint_dir = os.path.join(workdir, f"{config.model.rf_phase}_{config.training.reflow_t}rf_{config.eval.num_scales}distill", "checkpoints")
#   else:
#     raise NotImplementedError()
#   # ====================================================================================================== #
#   # Setup SDEs
#   if config.training.sde.lower() == 'rfsde':
#     sde = sde_lib.RFSDE(N=config.eval.num_scales)
#     sampling_eps = 1e-3 # Not used.
#   else:
#     raise NotImplementedError(f"SDE {config.training.sde} unknown.")
#   # ====================================================================================================== #
#   # Build the sampling function when sampling is enabled
#   if config.eval.enable_sampling:
#     sampling_shape = (config.eval.batch_size // jax.local_device_count(),
#                       config.data.image_size, config.data.image_size,
#                       config.data.num_channels)
#     sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, sampling_eps)
#   # ====================================================================================================== #
#   # Add additional task for evaluation (for example, get gradient statistics) here.
#   # ====================================================================================================== #
#   # Create different random states for different hosts in a multi-host environment (e.g., TPU pods)
#   rng = jax.random.fold_in(rng, jax.process_index())

#   # TODO: Custom experiment. (Playground)
#   if config.eval.custom:
#     rng, *sample_rng = jax.random.split(rng, jax.local_device_count() + 1)
#     sample_rng = jnp.asarray(sample_rng) # use the same seed
#     sampling_shape = (config.eval.batch_size // jax.local_device_count(),
#                       config.data.image_size, config.data.image_size,
#                       config.data.num_channels)
#     if eval_mode == 'eval_baseline':
#       logging.info("Import from pretrained pytorch model.")
#       assert config.training.import_torch != "none"
#       state = mutils.torch_to_flax_ckpt(config.training.import_torch, state, config.training.reflow_t,
#                                         config.model.initial_count, config.model.embedding_type)
#     else:
#       state_dict = checkpoints.restore_checkpoint(ckpt_dir=os.path.join(checkpoint_dir, f"100000", "default"), target=state_dict)
#       state = state_dict['model']
#     pstate = flax.jax_utils.replicate(state)
#     logging.info("Evaluate global truncation error, compared to 480-step Euler solver.")
#     config.eval.save_trajectory = True

#     nfe_set = [480, 240, 120, 48, 24, 12, 8, 6, 4, 2]
#     for nfe in nfe_set:
#       max_nfe = nfe_set[0]
#       # for nfe in [12]:
#       if nfe % config.training.reflow_t != 0:
#         continue

#       logging.info(f"Run for NFE={nfe}.")
#       sde = sde_lib.RFSDE(N=nfe)
#       config.eval.num_scales = nfe
#       sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, sampling_eps)
#       (samples, z), _, straightness = sampling_fn(sample_rng, pstate) # fix sample_rng
#       samples = samples.reshape((-1, *samples.shape[-3:]))
#       utils.draw_figure_grid(samples[0:64], os.path.join("test"), f"{nfe}")
#       np.save(f"test/{nfe}.npy", np.array(straightness))
#       logging.info(f"{nfe} done.")
#       # exit()
#       # sample_dict[nfe] = samples
#       # trajectory_dict[nfe] = straightness['all_trajectory']

#       # sample_dict[nfe] = np.array(jnp.reshape(sample_dict[nfe], (128, 32, 32, 3)))
#       # trajectory_dict[nfe] = np.array(jnp.reshape(jnp.transpose(trajectory_dict[nfe], (1, 0, 2, 3, 4, 5)), (nfe, 128, 32, 32, 3)))
#       # del samples, straightness

#       # if nfe < max_nfe:
#       #   total_diff = sample_dict[nfe] - sample_dict[max_nfe]
#       #   assert trajectory_dict[nfe].shape[0] == nfe
#       #   assert trajectory_dict[max_nfe].shape[0] == max_nfe

#       #   current_diff = []
#       #   total_gap = np.zeros((nfe,))
#       #   for i in range(nfe):
#       #     current_gap = trajectory_dict[nfe][i + 1 - 1] - trajectory_dict[max_nfe][int(max_nfe // nfe) * (i + 1) - 1] # (128, 32, 32, 3)
#       #     current_gap = np.mean(np.sqrt(np.mean(current_gap ** 2, (1, 2, 3))))
#       #     total_gap[i] = current_gap
#       #   print(total_gap)
#       #   np.save(f"test/{nfe}.npy", total_gap)
      
#       #   current_diff = jnp.concatenate([jnp.expand_dims(c, 0) for c in current_diff], axis=0)
#       #   print(current_diff.shape)
#       #   exit()
      
#   elif config.eval.custom_two:
#     """
#     Two experiments:
#     (1) Verify the variance reduction effect.
#     (2) Lipschitz constant effect.

#     Procedure:
#     (1) For each t in [0, 1], sample (x_t, t) randomly from the training set.
#     (2) Run the neural network with value_and_grad.
#     (3) Calculate the squared norm of the gradient.
#     (4) Calculate the average flow matching loss.
#      --> Multiply (3) and (4) to obtain the upper bound of the variance.
#     (5) Save the flow function f(x_t, t).

#     (6) Slightly (and randomly) perturb x_t to get (x_t', t).
#     (7) Obtain the flow function f(x_t', t) of the perturbed data.
#      --> Get the average, 99% highest, and maximum of M(t) w.r.t. time t, and save them.
#     """
#     if eval_mode == 'eval_baseline':
#       # cast torch checkpoint if necessary.
#       logging.info("Import from pretrained pytorch model.")
#       assert config.training.import_torch != "none"
#       state = mutils.torch_to_flax_ckpt(config.training.import_torch, state, config.training.reflow_t,
#                                         config.model.initial_count, config.model.embedding_type)
#       pstate = flax.jax_utils.replicate(state)
#     elif eval_mode == 'eval_reflow': # This playground always restores the step-100000 checkpoint.
#       logging.info(f"Restore checkpoint from step 100000.")
#       state_dict = checkpoints.restore_checkpoint(ckpt_dir=os.path.join(checkpoint_dir, f"100000", "default"), target=state_dict)
#       state = state_dict['model']
#       pstate = flax.jax_utils.replicate(state)
#     else:
#       raise ValueError("eval_mode should be eval_baseline or eval_reflow.")


#     _, eval_ds, _ = datasets.get_dataset(config,
#                                          additional_dim=config.training.n_jitted_steps,
#                                          uniform_dequantization=False,
#                                          gen_reflow=False)
#     eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types

#     # The model state has now been imported.
#     t_ls = np.linspace(0.05, 0.95, 19)
#     n_batch = 1
#     loss_jcfm = np.zeros_like(t_ls)
#     grad_sq_norm_arr = np.zeros_like(t_ls)
#     avg_lipschitzness = np.zeros_like(t_ls)
#     top_99_lipschitzness = np.zeros_like(t_ls)
#     max_lipschitzness = np.zeros_like(t_ls)


    
#     # Eval loss at the baseline.

#     for i in range(len(t_ls)):
#       t=t_ls[i]
#       eval_step_fn = losses.get_step_fn_playground(sde, state, train=False, reflow_t=config.training.reflow_t, t=t)
#       p_step = jax.pmap(functools.partial(jax.lax.scan, eval_step_fn), axis_name='batch', donate_argnums=1)
#       grad_sq_norm_sum = 0
#       lipschitz_list = []
#       for idx in range(n_batch):
#         eval_batch = jax.tree_map(lambda x: scaler(x._numpy()), next(eval_iter))
#         rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
#         next_rng = jnp.asarray(next_rng)
        
#         batch = eval_batch['image']
#         # (_, _), peval_loss = p_eval_step((next_rng, pstate), eval_batch)
#         (_, _), (lipschitz, grad_sq_norm) = p_step((next_rng, pstate), batch)
#         # print(lipschitz.shape) # (8, 5, 16)
#         # print(grad_sq_norm.shape) # (8, 5)
#         grad_sq_norm_sum += jnp.sum(grad_sq_norm)
#         lipschitz_list.append(jnp.expand_dims(lipschitz, 0))

#         logging.info(f"{idx + 1} step done.")
#       lipschitz_list = np.reshape(jnp.concatenate(lipschitz_list, 0), (-1,))
#       grad_sq_norm_sum /= n_batch
#       grad_sq_norm_arr[i] = np.array(grad_sq_norm_sum)
#       avg_lipschitzness[i] = np.sum(lipschitz_list) / (40 * n_batch)
#       max_lipschitzness[i] = np.max(lipschitz_list)
#       print(t, grad_sq_norm_arr[i], avg_lipschitzness[i], max_lipschitzness[i], flush=True)
#     exit()

#   elif eval_mode == 'eval_baseline':
#     # cast torch checkpoint if necessary.
#     logging.info("Import from pretrained pytorch model.")
#     assert config.training.import_torch != "none"
#     state = mutils.torch_to_flax_ckpt(config.training.import_torch, state, config.training.reflow_t,
#                                       config.model.initial_count, config.model.embedding_type)
#     pstate = flax.jax_utils.replicate(state)
#     # ====================================================================================================== #
#     # Generate samples and compute IS/FID/KID when enabled
#     if config.eval.enable_sampling:
#       state = jax.device_put(state)
#       # Run sample generation for multiple rounds to create enough samples
#       # Designed to be pre-emption safe. Automatically resumes when interrupted
#       if jax.process_index() == 0:
#         logging.info("Sampling -- baseline")
#       this_sample_dir = os.path.join(eval_dir, f"baseline_host_{jax.process_index()}")
#       stats = utils.get_samples_and_statistics(config, rng, sampling_fn, pstate, this_sample_dir, sampling_shape, mode='eval', current_step=0)
#       logging.info(f"FID = {stats['fid']}")
#       logging.info(f"KID = {stats['kid']}")
#       logging.info(f"Inception_score = {stats['is']}")
#       wandb_statistics_dict = {
#         'fid': float(stats['fid']),
#         'kid': float(stats['kid']),
#         'inception_score': float(stats['is']),
#         'sample': wandb.Image(os.path.join(this_sample_dir, "sample.png")),
#         'nfe': stats['nfe']
#       }

#       logging.info(f"straightness = {stats['straightness']['straightness']}")
#       logging.info(f"sequential straightness = {stats['straightness']['seq_straightness']}")

#       wandb.log(wandb_statistics_dict, step=0)
#     # ====================================================================================================== #
#   elif eval_mode in ['eval_distill', 'eval_reflow', 'eval_rf_distill']:
#     for ckpt in range(config.eval.begin_step, config.eval.end_step + 1, config.eval.interval_step):
#       logging.info(f"Restore checkpoint from step {ckpt}.")
#       state_dict = checkpoints.restore_checkpoint(ckpt_dir=os.path.join(checkpoint_dir, f"{ckpt}", "default"), target=state_dict)
#       state = state_dict['model']
#       pstate = flax.jax_utils.replicate(state)
#       # ====================================================================================================== #
#       # Generate samples and compute IS/FID/KID when enabled
#       if config.eval.enable_sampling:
#         state = jax.device_put(state)
#         # Run sample generation for multiple rounds to create enough samples
#         # Designed to be pre-emption safe. Automatically resumes when interrupted
#         if jax.process_index() == 0:
#           logging.info("Sampling -- checkpoint step: %d" % (ckpt,))
#         this_sample_dir = os.path.join(
#           eval_dir, f"step_{ckpt}_host_{jax.process_index()}")
        
#         stats = utils.get_samples_and_statistics(config, rng, sampling_fn, pstate, this_sample_dir, sampling_shape,
#                                                  mode='eval', current_step=ckpt)
#         straightness_dir = os.path.join(this_sample_dir, "straightness")
#         logging.info(f"FID = {stats['fid']}")
#         logging.info(f"KID = {stats['kid']}")
#         logging.info(f"Inception_score = {stats['is']}")
#         logging.info(f"NFE (Number of function evaluations) = {stats['nfe']}")
#         logging.info(f"straightness = {stats['straightness']['straightness']}")
#         logging.info(f"sequential straightness = {stats['straightness']['seq_straightness']}")
#         wandb_statistics_dict = {
#           'fid': float(stats['fid']),
#           'kid': float(stats['kid']),
#           'inception_score': float(stats['is']),
#           'sample': wandb.Image(os.path.join(this_sample_dir, "sample.png")),
#           'nfe': float(stats['nfe']),
#           'straightness': float(stats['straightness']['straightness']),
#           'seq_straightness': float(stats['straightness']['seq_straightness']),
#           'str_fig': wandb.Image(os.path.join(straightness_dir, f'straightness_{ckpt}.png')),
#           'seq_str_fig': wandb.Image(os.path.join(straightness_dir, f'seq_straightness_{ckpt}.png')),
#           'step': int(ckpt),
#           'n_data': int(config.eval.num_samples),
#         }
#         wandb.log(wandb_statistics_dict, step=ckpt)
#       # ====================================================================================================== #
#   else:
#     raise NotImplementedError("TODO: eval_ada_distill, eval_ada_reflow, rf_distill")
#   # ====================================================================================================== #
