"""Train score model in learned mirror space."""

import logging
import os
import time

from absl import app
from absl import flags
import flax
from flax.training.early_stopping import EarlyStopping
from flax.training import checkpoints
import jax
from ml_collections.config_flags import config_flags
import numpy as np
import tensorflow as tf
tf.config.experimental.set_visible_devices([], 'GPU')  # use CPU-only

import losses
import model_utils as lmm_mutils
from score_flow import datasets
from score_flow import losses as score_losses
from score_flow import sampling
from score_flow import utils
from score_flow.models import utils as score_mutils
from score_flow.models import ddpm, ncsnpp, ncsnv2  # pylint: disable=unused-import, g-multiple-import 


# --- Configuration files -----------------------------------------------------
_SCORE_CONFIG = config_flags.DEFINE_config_file(
  'score_config',
  None,
  'Score-model config.',
)
_LMM_CONFIG = config_flags.DEFINE_config_file(
  'lmm_config',
  None,
  'Learned mirror map (LMM) config.',
)

# --- Paths -------------------------------------------------------------------
_WORKDIR = flags.DEFINE_string(
  'workdir',
  None,
  'Base working directory that includes lmm/ subdirectory.',
)
_LMM_CKPT = flags.DEFINE_string(
  'lmm_ckpt',
  None,
  ('Base name of LMM checkpoint path in lmm/checkpoints. '
   'If None, the latest checkpoint is used.'),
)

# --- Fine-tuning -------------------------------------------------------------
# In this script, `max_sigma` and `update_fwd` are only used to choose which
# run-directory suffix is stripped from the checkpoint path to locate the
# pretrained MDM checkpoint.
_FINETUNE = flags.DEFINE_bool(
  'finetune',
  False,
  'Whether this is a fine-tuning run.',
)
_MAX_SIGMA = flags.DEFINE_float(
  'max_sigma',
  0.1,
  'Max sigma used for finetuning.',
)
_UPDATE_FWD = flags.DEFINE_bool(
  'update_fwd',
  False,
  'Whether to update forward mirror map during finetuning.',
)
_FINETUNE_FROM_CKPT = flags.DEFINE_string(
  'finetune_from_ckpt',
  '',
  'Checkpoint to start fine-tuning from.',
)


def _load_progress(progress_dir):
  """Load previously saved per-epoch training curves, if present.

  Args:
    progress_dir: Directory containing `epoch_times.npy`, `losses_score.npy`
      and `losses_score_val.npy`.

  Returns:
    A tuple `(epoch_times, losses_score, losses_val)` of lists. All three are
    empty when `losses_score.npy` does not exist in `progress_dir`.
  """
  if not os.path.exists(os.path.join(progress_dir, 'losses_score.npy')):
    return [], [], []
  epoch_times = list(np.load(os.path.join(progress_dir, 'epoch_times.npy')))
  losses_score = list(np.load(os.path.join(progress_dir, 'losses_score.npy')))
  losses_val = list(np.load(os.path.join(progress_dir, 'losses_score_val.npy')))
  return epoch_times, losses_score, losses_val


def _save_progress(progress_dir, epoch_times, losses_score, losses_val):
  """Write the per-epoch training curves to `.npy` files in `progress_dir`."""
  np.save(os.path.join(progress_dir, 'epoch_times.npy'), epoch_times)
  np.save(os.path.join(progress_dir, 'losses_score.npy'), losses_score)
  np.save(os.path.join(progress_dir, 'losses_score_val.npy'), losses_val)


def _save_score_checkpoint(ckpt_dir, score_pstate, epoch, **save_kwargs):
  """Save the unreplicated score state, recording `epoch` as completed.

  Args:
    ckpt_dir: Directory to save the checkpoint into.
    score_pstate: Device-replicated score-model state.
    epoch: Number of completed epochs. It is stored in the state's `epoch`
      field and used as the checkpoint step.
    **save_kwargs: Forwarded to `checkpoints.save_checkpoint` (e.g. `keep`,
      `prefix`).
  """
  score_state = flax.jax_utils.unreplicate(score_pstate)
  score_state = score_state.replace(epoch=epoch)
  checkpoints.save_checkpoint(ckpt_dir, score_state, step=epoch, **save_kwargs)


def main(_):
  """Train (or fine-tune) a score model in the space of a trained LMM.

  Reads the score-model and LMM configs from flags and restores the LMM from
  `<workdir>/lmm/checkpoints`. It then trains the score model, writing
  checkpoints and loss curves under `<workdir>/mdm`.

  Raises:
    ValueError: If the LMM checkpoint cannot be found. Also raised when
      fine-tuning and the pretrained MDM checkpoint cannot be found, or when
      the training data does not satisfy a non-'count' constraint.
  """
  score_config = _SCORE_CONFIG.value
  lmm_config = _LMM_CONFIG.value

  # Whether to convert Kolmogorov representation from volume to image.
  volume_to_image = (
      'kolmogorov' in lmm_config.constraint.type
      and lmm_config.data.kolmogorov_representation == 'image')
  if volume_to_image and utils.is_coordinator():
    logging.info('Using image representation of Kolmogorov data for forward mirror map')

  # Copy certain values from LMM config to score-model config.
  score_config.constraint = lmm_config.constraint
  score_config.data.num_kolmogorov_states = lmm_config.data.num_kolmogorov_states
  score_config.data.num_kolmogorov_states_per_row = lmm_config.data.num_kolmogorov_states_per_row

  # Copy certain values from score-model config to LMM config.
  lmm_config.data.height = score_config.data.height
  lmm_config.data.width = score_config.data.width
  lmm_config.data.num_channels = score_config.data.num_channels

  workdir = os.path.join(_WORKDIR.value, 'mdm')
  lmm_ckpt_path = os.path.join(_WORKDIR.value, 'lmm', 'checkpoints')
  if _LMM_CKPT.value is not None:
    lmm_ckpt_path = os.path.join(lmm_ckpt_path, _LMM_CKPT.value)

  # `checkpoints.restore_checkpoint` silently returns the freshly initialized
  # target when no checkpoint is found. Fail fast instead of training the score
  # model against an untrained mirror map.
  if _LMM_CKPT.value is None:
    lmm_ckpt_found = (
        tf.io.gfile.isdir(lmm_ckpt_path)
        and checkpoints.latest_checkpoint(lmm_ckpt_path) is not None)
  else:
    lmm_ckpt_found = tf.io.gfile.exists(lmm_ckpt_path)
  if not lmm_ckpt_found:
    raise ValueError(f'LMM checkpoint not found: {lmm_ckpt_path}')

  # Create working directory and its subdirectories.
  ckpt_dir = os.path.join(workdir, 'checkpoints')
  progress_dir = os.path.join(workdir, 'progress')
  tf.io.gfile.makedirs(ckpt_dir)
  tf.io.gfile.makedirs(progress_dir)

  if utils.is_coordinator():
    # Save config.
    with tf.io.gfile.GFile(os.path.join(workdir, 'config.txt'), 'w') as f:
      f.write(str(score_config))

  # Get data (`score_config.constraint` was already copied from the LMM config).
  train_ds, eval_ds, _ = datasets.get_dataset(
    score_config, additional_dim=None,
    uniform_dequantization=score_config.data.uniform_dequantization)

  # Get LMM model and checkpoint.
  lmm_model = lmm_mutils.get_model(lmm_config, volume_to_image=volume_to_image)
  lmm_state = lmm_mutils.init_state(lmm_config, lmm_model, volume_to_image=volume_to_image)
  lmm_state, _, _ = lmm_mutils.init_optimizer(lmm_config, lmm_state)
  lmm_state = checkpoints.restore_checkpoint(lmm_ckpt_path, lmm_state)
  if utils.is_coordinator():
    # Keep a copy of the LMM used for this run alongside the MDM outputs.
    checkpoints.save_checkpoint(
      workdir,
      lmm_state,
      step=lmm_state.epoch,
      prefix='lmm_checkpoint_',
      overwrite=True)

  # Initialize score model and training state.
  rng = jax.random.PRNGKey(score_config.seed)
  rng, step_rng = jax.random.split(rng)
  score_model, score_init_model_state, score_init_params = score_mutils.init_model(
    step_rng, score_config)
  score_tx = score_losses.get_optimizer(score_config)
  score_opt_state = score_tx.init(score_init_params)
  score_state = score_mutils.State(
    step=0,
    epoch=0,
    model_state=score_init_model_state,
    opt_state=score_opt_state,
    ema_rate=score_config.model.ema_rate,
    params=score_init_params,
    params_ema=score_init_params,
    rng=rng)
  if _FINETUNE.value:
    # Load pretrained weights of score model. The pretrained run's checkpoint
    # directory is this run's directory with the fine-tuning suffix removed.
    if _UPDATE_FWD.value:
      pretrained_ckpt_dir = ckpt_dir.replace(f'/finetune_maxsig={_MAX_SIGMA.value}', '')
    else:
      pretrained_ckpt_dir = ckpt_dir.replace(f'/finetune_no_fwd_update_maxsig={_MAX_SIGMA.value}', '')
    score_state = checkpoints.restore_checkpoint(
      os.path.join(pretrained_ckpt_dir, _FINETUNE_FROM_CKPT.value),
      score_state)
    if score_state.step == 0:
      raise ValueError('Pretrained MDM checkpoint not found')
    # score_state = score_state.replace(step=0, epoch=0)

  # Get SDE.
  sde, t0_eps = utils.get_sde(score_config)

  # Build sampling function.
  # NOTE(review): `sde_sampling_fn` is only consumed by `mdm_psampling_fn`
  # below, which this script never calls.
  image_shape = (score_config.data.height, score_config.data.width, score_config.data.num_channels)
  input_shape = (score_config.training.batch_size // jax.local_device_count(), *image_shape)
  inverse_scaler = datasets.get_data_inverse_scaler(score_config)
  sde_sampling_fn = sampling.get_sampling_fn_without_pmap(
      score_config, sde, score_model, input_shape, inverse_scaler, t0_eps)

  # Build training function.
  score_optimize_fn = score_losses.optimization_manager(score_config)
  score_step_fn = losses.get_score_step_fn(
    sde, score_model, score_tx, lmm_model, lmm_state,
    train=True,
    optimize_fn=score_optimize_fn,
    reduce_mean=score_config.training.reduce_mean,
    continuous=score_config.training.continuous,
    likelihood_weighting=score_config.training.likelihood_weighting)
  score_pstep_fn = jax.pmap(score_step_fn, axis_name='batch', donate_argnums=1)

  # Get eval function.
  score_eval_fn = losses.get_score_step_fn(
    sde, score_model, score_tx, lmm_model, lmm_state,
    train=False,
    optimize_fn=score_optimize_fn,
    reduce_mean=score_config.training.reduce_mean,
    continuous=score_config.training.continuous,
    likelihood_weighting=score_config.training.likelihood_weighting)
  score_peval_fn = jax.pmap(score_eval_fn, axis_name='batch', donate_argnums=1)

  # Get sampling function (not used below; kept for interactive use).
  mdm_sampling_fn = losses.get_mdm_sampling_fn(
    lmm_model,
    sde_sampling_fn,
    apply_inverse=True)
  mdm_psampling_fn = jax.pmap(mdm_sampling_fn)  # pylint: disable=unused-variable

  # Resume from this run's latest checkpoint, if any. Use gfile for
  # consistency with how `ckpt_dir` was created.
  if tf.io.gfile.listdir(ckpt_dir):
    score_state = checkpoints.restore_checkpoint(ckpt_dir, score_state)
  logging.info('Starting training at epoch %d (step %d)', score_state.epoch, score_state.step)
  epoch_times, losses_score, losses_val = _load_progress(progress_dir)

  # Replicate training state to run on multiple devices.
  score_pstate = flax.jax_utils.replicate(score_state)

  # Check data constraint on one training batch.
  batch = next(iter(train_ds))['image']._numpy().reshape(-1, *image_shape)
  constraint_losses_fn = losses.get_constraint_losses_fn(lmm_config)
  constraint_losses = constraint_losses_fn(batch)
  logging.info('Constraint losses: %s', constraint_losses)
  # Explicit raise rather than `assert`, so the check survives `python -O`.
  if (lmm_config.constraint.type != 'count'
      and not np.allclose(constraint_losses, 0.0, atol=1e-2)):
    raise ValueError(
        f'Training data does not satisfy the constraint: {constraint_losses}')

  # TODO: necessary to initialize `early_stop` based on checkpoint?
  early_stop = EarlyStopping(min_delta=1e-3, patience=2)
  saved_converged_ckpt = False

  # Create different random states for different processes in a
  # multi-host environment (e.g., TPU pods).
  rng = jax.random.fold_in(score_state.rng, jax.process_index())

  for epoch in range(score_state.epoch, score_config.training.n_epochs):
    # Training.
    epoch_losses = []
    epoch_time = 0
    for step, item in enumerate(train_ds):
      s = time.perf_counter()

      batch = item['image']._numpy()

      rng, step_rngs = utils.psplit(rng)
      (_, score_pstate), ploss, _ = score_pstep_fn(
        (step_rngs, score_pstate), batch)

      loss = flax.jax_utils.unreplicate(ploss).mean()

      # NOTE(review): nothing here blocks on `loss`. Under JAX's asynchronous
      # dispatch, `t` may reflect dispatch time rather than device compute.
      t = time.perf_counter() - s
      epoch_time += t

      epoch_losses.append(loss)

      if ((step + 1) % score_config.training.log_freq == 0) and utils.is_coordinator():
        logging.info('[epoch %03d, step %03d] %.3f sec; training loss: %.5e', epoch, step + 1, t, loss)

    # Update training curve.
    epoch_loss = np.mean(epoch_losses)
    epoch_times.append(epoch_time)
    losses_score.append(epoch_loss)

    # Update early stopping on the mean training loss of this epoch. This runs
    # on every process; it assumes the loss is identical across hosts (e.g. it
    # is pmean'd in the step fn) — TODO confirm.
    _, early_stop = early_stop.update(epoch_loss)

    # Validation.
    epoch_val_losses = []
    for step, item in enumerate(eval_ds):
      s = time.perf_counter()

      val_batch = item['image']._numpy()

      rng, eval_rngs = utils.psplit(rng)
      peval_loss = score_peval_fn((eval_rngs, score_pstate), val_batch)

      eval_loss = flax.jax_utils.unreplicate(peval_loss).mean()
      epoch_val_losses.append(eval_loss)

      if ((step + 1) % score_config.training.log_freq == 0) and utils.is_coordinator():
        t = time.perf_counter() - s
        logging.info('[epoch %03d, step %03d] %.3f sec; val loss: %.5e', epoch, step + 1, t, eval_loss)

    # Update validation curve.
    losses_val.append(np.mean(epoch_val_losses))

    # Save progress snapshot.
    if ((epoch + 1) % score_config.training.snapshot_epoch_freq == 0
        and utils.is_coordinator()):
      _save_progress(progress_dir, epoch_times, losses_score, losses_val)

    # Save checkpoint.
    if ((epoch + 1) % score_config.training.ckpt_epoch_freq == 0
        and utils.is_coordinator()):
      _save_score_checkpoint(ckpt_dir, score_pstate, epoch + 1, keep=np.inf)

    if early_stop.should_stop and not saved_converged_ckpt:
      # Only the coordinator writes the converged checkpoint...
      if utils.is_coordinator():
        logging.info('Met convergence criterion at epoch %d', epoch)
        _save_score_checkpoint(
          ckpt_dir, score_pstate, epoch + 1,
          keep=1, prefix='converged_checkpoint_')
      saved_converged_ckpt = True
      # ...but every process must leave the loop. Otherwise non-coordinator
      # hosts keep running pmapped steps without the coordinator.
      if score_config.training.early_stop:
        if utils.is_coordinator():
          logging.info('Convergence criterion met. Exiting training loop')
        break


if __name__ == '__main__':
  # `main` dereferences these flags unconditionally. Require them up front so
  # a missing flag produces a clear usage error instead of a crash deep inside
  # `main`.
  flags.mark_flags_as_required(['score_config', 'lmm_config', 'workdir'])
  app.run(main)