"""Train a score model on a known Gaussian distribution."""

import functools
import logging
import os
import time

from absl import app
from absl import flags
import flax
from flax.metrics import tensorboard
from flax.training import checkpoints
import jax
from ml_collections.config_flags import config_flags
import numpy as np
from score_flow import sampling
from sklearn.decomposition import PCA
import scipy.stats
import tensorflow as tf

import datasets
import utils
import losses

# Command-line flags: the config file for the score model and the base
# directory under which this script creates its output subdirectories.
_CONFIG = config_flags.DEFINE_config_file(
    'config',
    None,
    'Score-model config.',
)
_WORKDIR = flags.DEFINE_string(
    name='workdir',
    default='score_checkpoints/',
    help='Base working directory.',
)


def save_checkpoint(ckpt_dir, pstate, rng, **kwargs):
  """Save an unreplicated copy of `pstate`, with `rng` stored in it.

  Args:
    ckpt_dir: Directory handed to `checkpoints.save_checkpoint`.
    pstate: Replicated training state; it is passed through
      `flax.jax_utils.unreplicate` before saving.
    rng: Value written into the `rng` field of the saved state.
    **kwargs: Forwarded unchanged to `checkpoints.save_checkpoint`.

  Returns:
    The value returned by `checkpoints.save_checkpoint`.
  """
  single_state = flax.jax_utils.unreplicate(pstate)
  return checkpoints.save_checkpoint(
      ckpt_dir, single_state.replace(rng=rng), **kwargs)


def get_batch_dims(
    config, evaluation=False, additional_dim=None, device_batch=True
):
  """Return the leading batch dimensions for one batch of data.

  Args:
    config: Config object exposing `training.batch_size` and
      `eval.batch_size`.
    evaluation: If True, use `config.eval.batch_size`; otherwise use
      `config.training.batch_size`.
    additional_dim: Optional extra dimension placed between the device
      dimension and the per-device batch dimension (used when several steps
      are jitted together). Ignored when `device_batch` is False.
    device_batch: If False, return `[batch_size]` with no device dimension.

  Returns:
    A list of dimensions:
      * `[batch_size]` if `device_batch` is False;
      * `[local_device_count, per_device_batch_size]` if `additional_dim`
        is None;
      * `[local_device_count, additional_dim, per_device_batch_size]`
        otherwise.

  Raises:
    ValueError: If the batch size is not divisible by `jax.device_count()`.
  """
  batch_size = (
      config.eval.batch_size if evaluation else config.training.batch_size
  )

  n_devices = jax.device_count()
  if batch_size % n_devices != 0:
    raise ValueError(
        f'Batch size ({batch_size}) must be divisible by '
        f'the number of devices ({n_devices})'
    )

  # The per-device size is derived from the global device count, while the
  # leading dimension is the number of devices local to this process.
  per_device_batch_size = batch_size // n_devices

  if not device_batch:
    return [batch_size]
  if additional_dim is None:
    return [jax.local_device_count(), per_device_batch_size]
  # Extra data dimension for jitting multiple steps together.
  return [jax.local_device_count(), additional_dim, per_device_batch_size]


def main(_):
  """Train a score model on samples from a Gaussian built from PCA statistics.

  Reads the config and working directory from the command-line flags, creates
  the checkpoint / TensorBoard / sample subdirectories, builds a multivariate
  normal from the PCA arrays stored under a hard-coded directory, and runs the
  training loop with periodic logging, checkpointing, sampling and evaluation.

  Args:
    _: Unused positional argument supplied by `app.run`.
  """
  config = _CONFIG.value
  workdir = _WORKDIR.value

  # Create working directory and its subdirectories.
  ckpt_dir = os.path.join(workdir, 'checkpoints')
  tb_dir = os.path.join(workdir, 'tensorboard')
  sample_dir = os.path.join(workdir, 'samples')
  tf.io.gfile.makedirs(ckpt_dir)
  tf.io.gfile.makedirs(tb_dir)
  tf.io.gfile.makedirs(sample_dir)

  # Create TensorBoard writer.
  writer = tensorboard.SummaryWriter(tb_dir)

  if utils.is_coordinator():
    # Save config.
    with tf.io.gfile.GFile(os.path.join(workdir, 'config.txt'), 'w') as f:
      f.write(str(config))

  # Load the stored PCA statistics that define the Gaussian.
  gauss_dir = 'celeba_gauss/16x16_components=256'
  mean = np.load(os.path.join(gauss_dir, 'mean.npy'))
  components = np.load(os.path.join(gauss_dir, 'components.npy'))
  explained_variance = np.load(os.path.join(gauss_dir, 'explained_variance.npy'))
  noise_variance = np.load(os.path.join(gauss_dir, 'noise_variance.npy'))

  # Populate a PCA object by hand (no fitting) so that `get_covariance()` can
  # rebuild the covariance matrix from the stored arrays.
  pca = PCA(n_components=256)
  pca.n_components_ = 256
  pca.components_ = components
  pca.explained_variance_ = explained_variance
  pca.noise_variance_ = noise_variance
  pca.mean_ = mean

  cov = pca.get_covariance()
  # Pre-condition covariance matrix by adding 0.01 to every eigenvalue.
  eigvals, eigvecs = np.linalg.eigh(cov)
  eigvals += 0.01
  cov = eigvecs @ np.diag(eigvals) @ eigvecs.T

  gaussian = scipy.stats.multivariate_normal(
      mean=pca.mean_, cov=cov, seed=config.seed)

  batch_dims = get_batch_dims(
      config, additional_dim=config.training.n_jitted_steps)
  batch_shape = batch_dims + [
      config.data.image_size, config.data.image_size, config.data.num_channels]
  logging.info('[INFO] batch shape: %s', str(batch_shape))

  # Get data scalers.
  scaler = datasets.get_data_scaler(config)
  inverse_scaler = datasets.get_data_inverse_scaler(config)

  # Initialize training state.
  state, score_model, tx = utils.initialize_training_state(config)

  # Load checkpoint.
  state = checkpoints.restore_checkpoint(ckpt_dir, state)
  if utils.is_coordinator():
    logging.info('Starting training at step %d', state.step)

  # Get SDE.
  sde, t0_eps = utils.get_sde(config)

  # Build training and eval functions.
  optimize_fn = losses.optimization_manager(config)
  train_step_fn = losses.get_step_fn(
      sde,
      score_model,
      optimizer=tx,
      train=True,
      optimize_fn=optimize_fn,
      reduce_mean=config.training.reduce_mean,
      continuous=config.training.continuous,
      likelihood_weighting=config.training.likelihood_weighting)
  eval_step_fn = losses.get_step_fn(
      sde,
      score_model,
      optimizer=tx,
      train=False,
      optimize_fn=optimize_fn,
      reduce_mean=config.training.reduce_mean,
      continuous=config.training.continuous,
      likelihood_weighting=config.training.likelihood_weighting)

  # Build sampling function.
  sampling_shape = (
      int(config.training.batch_size // jax.device_count()),
      config.data.image_size, config.data.image_size,
      config.data.num_channels)
  logging.info('[INFO] sampling shape: %s', str(sampling_shape))
  sampling_fn = sampling.get_sampling_fn(
      config, sde, score_model, sampling_shape, inverse_scaler, t0_eps)

  # Pmap and JIT multiple training/eval steps together for faster running.
  p_train_step = jax.pmap(
      functools.partial(jax.lax.scan, train_step_fn), axis_name='batch',
      donate_argnums=1)
  p_eval_step = jax.pmap(
      functools.partial(jax.lax.scan, eval_step_fn), axis_name='batch',
      donate_argnums=1)

  # Replicate training state to run on multiple devices.
  pstate = flax.jax_utils.replicate(state)

  # Coerce to a plain int so `range` below does not depend on the type of the
  # restored step counter.
  init_step = int(state.step)
  n_steps = config.training.n_iters
  n_jitted_steps = config.training.n_jitted_steps
  # Create different random states for different processes in a
  # multi-host environment (e.g., TPU pods).
  rng = jax.random.fold_in(state.rng, jax.process_index())
  for step in range(init_step, n_steps + 1, n_jitted_steps):
    step_start_time = time.perf_counter()
    # Sample next batch of data. The second key of this split is unused (the
    # data comes from the SciPy generator), but the split is kept so that the
    # JAX RNG stream stays the same as before.
    rng, _ = jax.random.split(rng)
    data = gaussian.rvs(size=batch_dims)
    data = scaler(data).reshape(batch_shape)
    batch = {'image': data}

    # Train step.
    rng, step_rngs = utils.psplit(rng)
    (_, pstate), ploss = p_train_step((step_rngs, pstate), batch)
    loss = flax.jax_utils.unreplicate(ploss).mean()

    if step % config.training.log_freq == 0 and utils.is_coordinator():
      step_time = time.perf_counter() - step_start_time
      # Log training loss.
      logging.info('[step %d] training loss: %.5e', step, loss)
      logging.info('[step %d] step time: %.3f', step, step_time)
      writer.scalar('training_loss', loss, step)

    is_snapshot_step = (
        (step != 0 and step % config.training.snapshot_freq == 0)
        or step == n_steps)
    if is_snapshot_step and utils.is_coordinator():
      # Save model checkpoint.
      save_checkpoint(ckpt_dir, pstate, rng, step=step, keep=np.inf)

      # Save samples.
      # NOTE(review): sampling (and the `rng` split it consumes) happens only
      # on the coordinator; confirm this is intended for multi-host runs.
      rng, sample_rngs = utils.psplit(rng)
      sample, _ = sampling_fn(sample_rngs, pstate)
      image_grid = sample.reshape(-1, *sample.shape[2:])
      nrow = int(np.sqrt(image_grid.shape[0]))
      with tf.io.gfile.GFile(
          os.path.join(sample_dir, f'samples_{step}.png'), 'wb') as fout:
        utils.save_image_grid(image_grid, fout, nrow=nrow, padding=2)

    if step % config.training.eval_freq == 0:
      # Eval step. A fresh batch is drawn from the same Gaussian generator as
      # the training data on every evaluation (it is not a fixed batch), and
      # it has the same shape as a training batch.
      eval_data = gaussian.rvs(size=batch_dims)
      eval_data = scaler(eval_data).reshape(batch_shape)
      eval_batch = {'image': eval_data}

      rng, next_rngs = utils.psplit(rng)
      (_, _), peval_loss = p_eval_step((next_rngs, pstate), eval_batch)

      eval_loss = flax.jax_utils.unreplicate(peval_loss).mean()
      if utils.is_coordinator():
        logging.info('[step %d] eval loss: %.5e', step, eval_loss)
        writer.scalar('eval_loss', eval_loss, step)

# Script entry point: hand `main` to `app.run` when executed directly.
if __name__ == '__main__':
  app.run(main)