"""Pretrain RealNVP on a prior dataset."""
import datetime
import functools
import logging
import os
import time

from absl import app
from absl import flags
import flax
from flax.training import checkpoints
import jax
import jax.numpy as jnp
from ml_collections.config_flags import config_flags
import numpy as np
import optax
from score_flow import utils as sutils
import tensorflow as tf

import datasets
import utils
from posterior_sampling.losses import clip_grad
from posterior_sampling import model_utils as mutils

# Command-line flag pointing at the ml_collections config file for the run.
_CONFIG = config_flags.DEFINE_config_file('config', None, 'DPI config.')
# Command-line flag naming the output directory; it must always be supplied.
_WORKDIR = flags.DEFINE_string(
    name='workdir',
    default=None,
    help='Directory to write checkpoints, samples, and TensorBoard logs to.',
    required=True,
)


def save_configs():
  """Write the string form of the run config to `<workdir>/config.txt`.

  Both the config and the working directory are read from the module-level
  flags. The directory itself is not created here.
  """
  config_path = os.path.join(_WORKDIR.value, 'config.txt')
  config_text = str(_CONFIG.value)
  with tf.io.gfile.GFile(config_path, 'w') as config_file:
    config_file.write(config_text)

def get_step_fn(model, optimizer, mode='train'):
  """Build the per-device train or eval step function.

  The loss is the mean negative log-likelihood under a standard-normal base
  distribution: `-logdet + 0.5 * sum(z**2 + log(2*pi))`, averaged over the
  batch, where `(z, logdet)` is what `model.apply(..., reverse=False)` returns.

  Both step functions call `jax.lax.pmean(..., axis_name='batch')`, so they
  must be run inside a `pmap` (or equivalent) that binds the 'batch' axis.

  Args:
    model: Module whose `apply(variables, x, reverse=False, mutable=...)`
      returns `((z, logdet), new_model_state)`.
    optimizer: Object with an optax-style `update(grad, opt_state, params)`
      method. Only used by the train step.
    mode: Either 'train' or 'eval'.

  Returns:
    For 'train': `fn(state, x) -> (new_state, loss)`.
    For 'eval': `fn(state, x) -> loss`.

  Raises:
    ValueError: If `mode` is neither 'train' nor 'eval'.
  """
  # Validate first so a typo fails loudly here instead of returning None and
  # crashing later when the caller tries to compile the step.
  if mode not in ('train', 'eval'):
    raise ValueError(f"mode must be 'train' or 'eval', got {mode!r}")

  # A value of -1 disables gradient clipping (see the check in the train step).
  grad_clip = _CONFIG.value.optim.grad_clip

  def loss_fn(params, model_state, x):
    """Returns (mean NLL, updated mutable model state)."""
    variables = {'params': params, **model_state}
    # Forward pass through the flow; every model-state collection is mutable.
    (z, logdet), new_model_state = model.apply(
        variables, x, reverse=False, mutable=list(model_state.keys()))
    # Negative log-likelihood per example, then averaged over the batch.
    log_prob_loss = -logdet + 0.5 * jnp.sum(
        jnp.square(z) + jnp.log(2 * jnp.pi), -1)
    log_prob_loss = jnp.mean(log_prob_loss)
    return log_prob_loss, new_model_state

  # Only the train step needs gradients; eval calls `loss_fn` directly so it
  # does not trace an unused backward pass.
  loss_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

  def train_step_fn(state, x):
    (loss, new_model_state), grad = loss_and_grad_fn(
        state.params, state.model_state, x)
    # Average gradients and loss across the 'batch' axis.
    grad = jax.lax.pmean(grad, axis_name='batch')
    loss = jax.lax.pmean(loss, axis_name='batch')

    if grad_clip != -1:
      grad = clip_grad(grad, grad_clip=grad_clip)

    # Apply updates.
    updates, new_opt_state = optimizer.update(
        grad, state.opt_state, state.params)
    new_params = optax.apply_updates(state.params, updates)

    new_state = state.replace(
        step=state.step + 1,
        opt_state=new_opt_state,
        params=new_params,
        model_state=new_model_state)

    return new_state, loss

  def eval_step_fn(state, x):
    # The updated model state is discarded: evaluation must not alter `state`.
    loss, _ = loss_fn(state.params, state.model_state, x)
    return jax.lax.pmean(loss, axis_name='batch')

  return train_step_fn if mode == 'train' else eval_step_fn


def _next_flat_batch(data_iter, image_dim):
  """Fetch the next batch from `data_iter` and flatten each image.

  The first two axes of the 'image' array are kept and everything after them
  is collapsed into a single axis of size `image_dim`.
  NOTE(review): the two leading axes are presumably (local device, per-device
  batch) as required by `pmap` -- confirm against `datasets.get_dataset`.
  """
  images = next(data_iter)['image'].numpy()
  return images.reshape(images.shape[0], images.shape[1], image_dim)


def main(_) -> None:
  """Train the model, periodically logging, sampling, saving, and evaluating.

  Reads the config and working directory from the module-level flags, resumes
  from the latest checkpoint in `<workdir>/checkpoints` if one exists, and
  writes samples to `<workdir>/samples` and TensorBoard summaries to
  `<workdir>/tensorboard/<timestamp>`. Directory creation, logging, summaries,
  sampling, and checkpointing only happen on the coordinator process.
  """
  config = _CONFIG.value
  workdir = _WORKDIR.value
  ckpt_dir = os.path.join(workdir, 'checkpoints')
  sample_dir = os.path.join(workdir, 'samples')

  image_size = config.data.image_size
  image_shape = (image_size, image_size, config.data.num_channels)
  image_dim = np.prod(image_shape)

  if utils.is_coordinator():
    tf.io.gfile.makedirs(ckpt_dir)
    tf.io.gfile.makedirs(sample_dir)
    save_configs()
    logging.info('[INFO] local device count = %d', jax.local_device_count())
    logging.info('[INFO] device count = %d', jax.device_count())
    logging.info('[INFO] process count = %d', jax.process_count())
    # Create summary writer. It exists only on the coordinator, so every use
    # below is guarded by `utils.is_coordinator()`.
    now = datetime.datetime.now()
    tb_dir = os.path.join(
        workdir, 'tensorboard', now.strftime('%Y%m%d-%H%M%S'))
    writer = tf.summary.create_file_writer(tb_dir)

  # Dataset.
  train_ds, eval_ds, _ = datasets.get_dataset(config, additional_dim=None)
  train_iter = iter(train_ds)
  eval_iter = iter(eval_ds)

  # Initialize model.
  model, model_state, params = mutils.get_model_and_init_params(
      config, train=True)
  optimizer = optax.adam(learning_rate=config.optim.learning_rate)

  # Initialize state.
  opt_state = optimizer.init(params)
  state = mutils.State(
      step=0,
      opt_state=opt_state,
      params=params,
      model_state=model_state,
      rng=jax.random.PRNGKey(config.seed + 1))

  # Load checkpoint (returns `state` unchanged when none is found).
  state = checkpoints.restore_checkpoint(ckpt_dir, state)
  # A restored step may come back as an array scalar; normalise it to a plain
  # Python int before using it as a loop bound.
  init_step = int(state.step)
  pstate = flax.jax_utils.replicate(state)

  # Training function.
  step_fn = get_step_fn(model, optimizer, mode='train')
  p_train_step = jax.pmap(
      jax.jit(step_fn), axis_name='batch', donate_argnums=0)

  # Eval function.
  eval_fn = get_step_fn(model, optimizer, mode='eval')
  p_eval_step = jax.pmap(jax.jit(eval_fn), axis_name='batch')

  # Keep track of best eval loss. Currently unused: these are only referenced
  # by the disabled best-checkpoint logic at the bottom of the loop.
  best_eval_loss = np.inf
  best_eval_ckpt_path = None

  # NOTE(review): the upper bound is inclusive of `n_iters`, so a fresh run
  # performs `n_iters + 1` updates and the last one is never checkpointed.
  # Confirm whether that extra iteration is intended.
  for step in range(init_step, config.training.n_iters + 1):
    # Track the amount of time it takes per step.
    step_start_time = time.perf_counter()

    batch = _next_flat_batch(train_iter, image_dim)

    pstate, ploss = p_train_step(pstate, batch)
    # `.item()` pulls the value to the host, which also waits for the step.
    loss = flax.jax_utils.unreplicate(ploss).item()

    if (((step + 1) % config.training.log_freq == 0 or step == init_step) and
        utils.is_coordinator()):
      step_time = time.perf_counter() - step_start_time
      logging.info('step %d: %.0f seconds', step + 1, step_time)
      with writer.as_default(step=step + 1):
        tf.summary.scalar('total', loss)

    if ((step + 1) % config.training.snapshot_freq == 0 and
        utils.is_coordinator()):
      # Pull a single copy of the state off the devices once and reuse it for
      # both sampling and checkpointing.
      host_state = flax.jax_utils.unreplicate(pstate)

      # Sample images. The PRNG key is fixed, so snapshots taken at different
      # steps start from the same random seed.
      sampling_fn = mutils.get_sampling_fn(
          model,
          host_state.params,
          host_state.model_state,
          train=True)
      (samples, _), _ = sampling_fn(jax.random.PRNGKey(0), (25, *image_shape))
      samples = samples.reshape(-1, *image_shape)

      # Save checkpoint. `keep=np.inf` means no old checkpoint is deleted.
      checkpoints.save_checkpoint(
          ckpt_dir,
          host_state,
          step=step + 1,
          keep=np.inf,
      )

      # Save .npy file of all samples.
      with tf.io.gfile.GFile(
          os.path.join(sample_dir, f'x_{step + 1:06}.npy'), 'wb'
      ) as f:
        np.save(f, samples)
      # Save .png file of the first nine samples as a 3x3 grid.
      with tf.io.gfile.GFile(
          os.path.join(sample_dir, f'x_{step + 1:06}.png'), 'wb'
      ) as f:
        utils.save_image_grid(samples[:9], f, nrow=3, padding=2)

    if ((step + 1) % config.training.eval_freq == 0 or step == init_step):
      # Eval step. Every process runs it (not just the coordinator) because
      # the eval function contains a cross-device `pmean`.
      eval_batch = _next_flat_batch(eval_iter, image_dim)

      peval_loss = p_eval_step(pstate, eval_batch)
      eval_loss = flax.jax_utils.unreplicate(peval_loss).item()
      if utils.is_coordinator():
        logging.info('[step %d] eval loss: %.5e', step + 1, eval_loss)
        with writer.as_default(step=step + 1):
          tf.summary.scalar('eval_loss', eval_loss)

      # if eval_loss < best_eval_loss and utils.is_coordinator():
      #   if best_eval_ckpt_path is not None:
      #     # We want to keep only the best checkpoint,
      #     # so delete the previous one.
      #     tf.io.gfile.remove(best_eval_ckpt_path)

      #   best_eval_loss = eval_loss
      #   best_eval_ckpt_path = checkpoints.save_checkpoint(
      #       ckpt_dir, flax.jax_utils.unreplicate(pstate), step=step + 1,
      #       prefix=f'best_eval_{eval_loss:.3f}_', overwrite=True)

  if utils.is_coordinator():
    # Push any buffered summaries to disk before the process exits.
    writer.flush()


if __name__ == '__main__':
  # `--config` defaults to None and is not declared required, so a missing
  # flag would otherwise surface later as an attribute error on None inside
  # `main`. Fail at flag-parsing time with a clear message instead.
  flags.mark_flags_as_required(['config'])
  app.run(main)