import json
import pickle
import functools
from datetime import datetime
from time import time
from typing import Tuple, Any, Mapping, Iterable, Callable

import flax.typing
import jax
import jax.numpy as jnp
import numpy as np
import optax
import orbax.checkpoint
from absl import app, flags, logging
from flax.training import orbax_utils
from flax.training.train_state import TrainState
from flax.training.common_utils import shard, shard_prng_key
from flax.jax_utils import replicate, unreplicate
from matplotlib import pyplot as plt

from ol import models
from ol.dataset.dataset import Dataset, Batch, Stats
from ol.experiments import DIR_EXPERIMENTS
from ol.graph.graphbuilder import GraphBuilder
from ol.metrics import BatchMetrics, Metrics, EvalMetrics
from ol.metrics import rel_lp_loss
from ol.metrics import mse_error, rel_lp_error_mean, recall_tol_mean, chamfer_mean
from ol.models.common import AbstractOperator, Inputs
from ol.stepping import OutputStepper
from ol.utils import disable_logging, Array, Tree
from ol.test import RECALL_TOLERANCE, EXCLUDE_PERCENTILE


# Number of devices attached to this process; training and evaluation batches are sharded across them
NUM_DEVICES = jax.local_device_count()
FLAGS = flags.FLAGS
# Value passed as `exclude_percentile` when computing the dataset statistics
EXCLUDE_PERCENTILE_STATS = 0.4
# Value passed as `q` to the relative Lp training loss (`rel_lp_loss`)
EXCLUDE_PERCENTILE_LOSS = 0.2

def define_flags():
  """Registers all command-line flags of the training script with absl.

  Must be called once before `app.run` parses the command line.
  """
  # FLAGS::general
  flags.DEFINE_string(name='exp', default='000', required=False,
    help='Name of the experiment'
  )
  flags.DEFINE_string(name='datetime', default=None, required=False,
    help='A string representing the current datetime'
  )
  flags.DEFINE_string(name='datadir', default=None, required=True,
    help='Path of the folder containing the datasets'
  )
  flags.DEFINE_string(name='datapath', default=None, required=True,
    help='Relative path inside the data directory'
  )
  flags.DEFINE_string(name='params', default=None, required=False,
    help='Path of the previous experiment containing the initial parameters'
  )
  flags.DEFINE_integer(name='seed', default=45, required=False,
    help='Seed for random number generator'
  )
  flags.DEFINE_float(name='space_downsample_factor', default=1.0, required=False,
    help='Factor for downsampling the space resolution (including the boundaries)'
  )
  flags.DEFINE_float(name='boundary_downsample_factor', default=1.0, required=False,
    help='Factor for downsampling the resolution of the boundary functions'
  )
  flags.DEFINE_boolean(name='use_extensions', default=False, required=False,
    help='Whether to use pre-computed extensions as input domain functions'
  )

  # FLAGS::training
  flags.DEFINE_integer(name='batch_size', default=2, required=False,
    help='Size of a batch of training samples per device and per optimization step (input batch size)'
  )
  flags.DEFINE_integer(name='superbatch_size', default=1, required=False,
    help='Number of batches to load and process together'
  )
  flags.DEFINE_integer(name='superbatch_repeats', default=1, required=False,
    help='Number of repetitions of each super batch before moving on to the next'
  )
  flags.DEFINE_integer(name='epochs', default=20, required=False,
    help='Number of training epochs'
  )
  flags.DEFINE_integer(name='long_epoch_gap', default=5, required=False,
    help='Number of training epochs to wait before each evaluation'
  )
  flags.DEFINE_float(name='lr_init', default=1e-05, required=False,
    help='Initial learning rate in the onecycle scheduler'
  )
  flags.DEFINE_float(name='lr_peak', default=2e-04, required=False,
    help='Peak learning rate in the onecycle scheduler'
  )
  flags.DEFINE_float(name='lr_base', default=1e-05, required=False,
    help='Final learning rate in the onecycle scheduler'
  )
  flags.DEFINE_float(name='lr_lowr', default=1e-06, required=False,
    help='Final learning rate in the exponential decay'
  )
  flags.DEFINE_float(name='gclip', default=5e-01, required=False,
    help='Norm for adaptive gradient clipping'
  )
  flags.DEFINE_float(name='weight_decay', default=1e-04, required=False,
    help='Weight decay for the optimizer'
  )
  flags.DEFINE_integer(name='n_train', default=(2**4), required=False,
    help='Number of training samples'
  )
  flags.DEFINE_integer(name='n_valid', default=(2**4), required=False,
    help='Number of validation samples'
  )

  # FLAGS::core
  flags.DEFINE_string(name='core_name', default='XRIGNO', required=False,
    help='Name of the core neural operator architecture'
  )
  flags.DEFINE_float(name='pmesh_subsample_factor', default=16.0, required=False,
    help='Factor for random subsampling of the physical mesh (ignored if rmesh_gridres is not None)'
  )
  flags.DEFINE_integer(name='rmesh_gridres', default=0, required=False,
    help='Grid resolution of rmesh (overrides pmesh_subsample_factor)'
  )
  flags.DEFINE_float(name='rmesh_subsample_factor', default=1.2, required=False,
    help='Factor for random subsampling of hierarchical regional meshes'
  )
  flags.DEFINE_float(name='overlap_factor_p2r', default=1.0, required=False,
    help='Overlap factor for p2r edges (encoder)'
  )
  flags.DEFINE_float(name='overlap_factor_r2p', default=1.5, required=False,
    help='Overlap factor for r2p edges (decoder)'
  )
  flags.DEFINE_integer(name='rmesh_levels', default=40, required=False,
    help='Number of multimesh connection levels (processor)'
  )
  flags.DEFINE_integer(name='node_coordinate_freqs', default=4, required=False,
    help='Number of frequencies for encoding periodic node coordinates'
  )
  flags.DEFINE_integer(name='node_latent_size', default=128, required=False,
    help='Size of latent node features'
  )
  flags.DEFINE_integer(name='edge_latent_size', default=128, required=False,
    help='Size of latent edge features'
  )
  flags.DEFINE_integer(name='processor_steps', default=12, required=False,
    help='Number of message-passing steps in the processor'
  )
  flags.DEFINE_float(name='p_edge_masking', default=0.5, required=False,
    help='Probability for random masking of the edges'
  )

  # FLAGS::model::CrossAttentionExtender
  flags.DEFINE_boolean(name='extender', default=True, required=False,
    help='Whether to use the extender module or not'
  )
  flags.DEFINE_boolean(name='independent', default=False, required=False,
    help='Whether to extend the boundary functions independently or not'
  )
  flags.DEFINE_integer(name='ext_depth', default=4, required=False,
    help='Number of attention blocks in the extender'
  )
  flags.DEFINE_integer(name='ext_vars', default=16, required=False,
    help='Number of output domain functions by the extender'
  )
  flags.DEFINE_integer(name='ext_latent_size', default=128, required=False,
    help='Dimension of the latent features after the initial extender feed-forward blocks'
  )
  flags.DEFINE_integer(name='ext_heads', default=4, required=False,
    help='Number of attention heads in each block of the extender'
  )
  flags.DEFINE_float(name='ext_p_masking', default=0.3, required=False,
    help='Probability for random masking of the boundary nodes'
  )
  flags.DEFINE_float(name='ext_dropout', default=0.0, required=False,
    help='Dropout ratios in the MLPs of the extender'
  )

def _concat_stats(stats: Mapping[str, Stats], keys: Iterable[str]) -> Stats:
  """Concatenates the statistics of several variables along the last (feature) axis."""
  keys = list(keys)
  return Stats(
    mean=jnp.concatenate([stats[key].mean for key in keys], axis=-1),
    std=jnp.concatenate([stats[key].std for key in keys], axis=-1),
  )

def _shard_batch(batch: Batch) -> Batch:
  """Splits every field of a batch between the local devices.

  [NUM_DEVICES*n, ...] -> [NUM_DEVICES, n, ...]
  """
  return Batch(
    x=shard(batch.x),
    t=shard(batch.t),
    g=shard(batch.g),
    functions=shard(batch.functions),
  )

def _reduce_metrics(metrics: list, reducer: Callable) -> Metrics:
  """Concatenates the per-sample metrics of several batches and reduces each field over the samples.

  Args:
    metrics: Batch metrics with the fields ``mse``, ``l1``, ``l2``, ``recall`` and ``chamfer``.
    reducer: A ``jnp`` reduction accepting an ``axis`` keyword (e.g. ``jnp.median``).

  Returns:
    A ``Metrics`` object with one Python scalar (``.item()``) per field.
  """
  return Metrics(**{
    name: reducer(jnp.concatenate([getattr(m, name) for m in metrics]), axis=0).item()
    for name in ('mse', 'l1', 'l2', 'recall', 'chamfer')
  })

def _format_epoch_log(
  epoch: int,
  epochs: int,
  lr: float,
  time_tot: float,
  grad: float,
  loss: float,
  metrics_val: EvalMetrics = None,
  metrics_trn: EvalMetrics = None,
) -> str:
  """Builds the tab-separated progress line that is logged after an epoch.

  The error, recall and Chamfer columns are only added when both evaluation metrics are given.
  """
  fields = [
    f'EPCH: {epoch : 04d}/{epochs : 04d}',
    f'LR: {lr : .2e}',
    f'TIME: {time_tot : 06.1f}s',
    f'GRAD: {grad : .2e}',
    f'LOSS: {loss : .2e}',
  ]
  if (metrics_val is not None) and (metrics_trn is not None):
    fields += [
      f'ERR-VAL (ERR-TRN): {metrics_val.median.l2 : .2%} ({metrics_trn.median.l2 : .2%})',
      f'RCL-VAL (RCL-TRN): {metrics_val.median.recall : .2%} ({metrics_trn.median.recall : .2%})',
      f'CHD-VAL (CHD-TRN): {metrics_val.median.chamfer : .2f} ({metrics_trn.median.chamfer : .2f})',
    ]
  return '\t'.join(fields)

def _plot_metrics_history(checkpointed_metrics: list, directory) -> None:
  """Scatter-plots the history of the checkpointed metrics into ``directory / 'metrics/plots'``.

  Writes ``optimization.pdf`` (training gradients and loss) and ``error.pdf`` (training and
  validation errors), each with two log-scaled panels that share the step axis.
  """
  metrics_to_plot = {
    'optimization': (
      {'label': 'Training gradients', 'values': lambda m: m['grad']},
      {'label': 'Training loss', 'values': lambda m: m['loss']},
    ),
    'error': (
      {'label': 'Training error [%]', 'values': lambda m: m['train']['median']['l2'] * 100},
      {'label': 'Validation error [%]', 'values': lambda m: m['valid']['median']['l2'] * 100}
    ),
  }
  steps = [m['step'] for m in checkpointed_metrics]
  for filename, items in metrics_to_plot.items():
    fig, axs = plt.subplots(
      ncols=2,
      figsize=(10, 3),
      sharex=True,
      sharey=(filename != 'optimization'),
      tight_layout=True,
    )
    for ax, item in zip(axs, items):
      values = [item['values'](m) for m in checkpointed_metrics]
      ax.scatter(steps, values, s=10, color='black', zorder=3)
      ax.set(ylabel=item['label'], yscale='log')
      ax.grid(which='both')
    fig.savefig(directory / 'metrics/plots' / f'{filename}.pdf', dpi=100, bbox_inches='tight')
    plt.close(fig)

def train(
  rngkey: flax.typing.PRNGKey,
  model: AbstractOperator,
  state: TrainState,
  dataset: Dataset,
  graph_builder: GraphBuilder,
  epochs: int,
  loss_fn: Callable,
  epochs_before: int = 0,
) -> TrainState:
  """Trains a model and returns the state.

  The state and the normalization statistics are replicated over the local devices and every
  training batch is sharded between them. One outer iteration trains on
  ``FLAGS.superbatch_repeats`` repetitions of every super batch and therefore counts as that
  many epochs. The model is evaluated on the training and validation splits before training,
  every ``FLAGS.long_epoch_gap`` epochs, and after the last iteration. Each evaluation is
  checkpointed, dumped as JSON, and plotted under the experiment directory.

  Args:
    rngkey: PRNG key for graph re-construction, batch shuffling and stochastic layers.
    model: The operator to train.
    state: Unreplicated training state holding the initial parameters and optimizer.
    dataset: Dataset with a training split (0) and a validation split (1).
    graph_builder: Builder used to construct the graphs of every batch.
    epochs: Total number of epochs; ``epochs // FLAGS.superbatch_repeats`` outer iterations are run.
    loss_fn: Called as ``loss_fn(*stepper.get_loss_inputs(...))`` to get the training loss.
    epochs_before: Number of epochs already trained; offsets the logged and checkpointed steps.

  Returns:
    The unreplicated training state after the last epoch.

  Raises:
    ValueError: If the number of training samples is not a multiple of the big batch size.
  """

  # Set constants
  num_samples_trn = (dataset.splits[0][1] - dataset.splits[0][0])
  big_batch_size = FLAGS.superbatch_size * NUM_DEVICES * FLAGS.batch_size
  if num_samples_trn % big_batch_size != 0:
    raise ValueError(
      f'The number of training samples ({num_samples_trn}) must be a multiple of '
      f'superbatch_size * NUM_DEVICES * batch_size ({big_batch_size}).')
  # Used for averaging the loss and gradients of an epoch over its big batches
  num_big_batches = num_samples_trn / big_batch_size
  # Evaluation batches are not grouped in super batches
  eval_batch_size = FLAGS.batch_size * NUM_DEVICES

  # Store the initial time
  time_int_pre = time()

  # Set stepper
  stepper = OutputStepper(operator=model)

  # Build input and output variable statistics
  _stats = dataset.stats
  stats = {
    'x': _stats['x'],
    't': _stats['t'],
    'geo': _concat_stats(_stats, dataset.metadata.geo),
    'dom': _concat_stats(_stats, dataset.metadata.dom),
    'seg': {key: _stats[key] for key in dataset.metadata.seg},
    'out': _concat_stats(_stats, dataset.metadata.out),
  }
  # Replicate state and stats
  # NOTE: Internally uses jax.device_put_replicate
  state = replicate(state)
  stats = replicate(stats)

  def _current_lr(state: TrainState) -> float:
    """Reads the learning rate injected in the (replicated) optimizer state."""
    return state.opt_state[-1].hyperparams['learning_rate'][0].item()

  @functools.partial(jax.pmap, axis_name='device')
  def _train_one_batch(rngkey: flax.typing.PRNGKey, state: TrainState, stats: dict, batch: Batch) -> Tuple[TrainState, Array, Array]:
    """Shuffles a device-local batch, updates the state once per sub-batch, and returns it.

    Returns the updated state, plus the loss and the gradients averaged over the sub-batches.
    """

    # Split the rngkey
    subrngkeys = jax.random.split(rngkey, num=3)

    # Define helper function for updating the state per sub-batch
    # NOTE: Each batch of trajectories (time-dependent) can result in multiple sub-batches of input-output pairs
    def _update_state_per_subbatch(state: TrainState,
      q: Array, u: Array, h: Array, m: Array, x: Array, t: Array, tau: Array, r: Array, g,
    ) -> Tuple[TrainState, Array, Tree]:
      # NOTE: INPUT SHAPES [superbatch_size*batch_size, ...]

      # Get random number generators
      subsubrngkey = jax.random.fold_in(subrngkeys[0], data=state.step)
      subsubsubrngkeys = jax.random.split(subsubrngkey, num=3)
      rngs = {'dropout': subsubsubrngkeys[0], 'masking': subsubsubrngkeys[1], 'other': subsubsubrngkeys[2]}

      def _compute_loss(params: flax.typing.Collection,
        q: Array, u: Array, h: Array, m: Array, x: Array, t: Array, tau: Array, r: Array, g,
      ) -> Array:
        """Computes the prediction of the model and returns its loss."""
        inputs = Inputs(s=q, a=u, q=h, m=m, x_inp=x, x_out=x, t=t, tau=tau)
        _loss_inputs = stepper.get_loss_inputs(
          variables={'params': params},
          stats=stats,
          u_tgt=r,
          inputs=inputs,
          graphs=graph_builder.build_graphs(g),
          deterministic=False,
          rngs=rngs,
        )
        return loss_fn(*_loss_inputs)

      # Get loss and gradients w.r.t. the parameters
      _loss, _grads = jax.value_and_grad(_compute_loss)(state.params, q, u, h, m, x, t, tau, r, g)
      # Synchronize loss and gradients
      loss = jax.lax.pmean(_loss, axis_name='device')
      grads = jax.lax.pmean(_grads, axis_name='device')
      # Apply gradients
      state = state.apply_gradients(grads=grads)

      return state, loss, grads

    # Shuffle along the space axis
    batch = batch.shuffled(rngkey=subrngkeys[1])

    # Prepare input-output pairs
    # -> [1, superbatch_size*batch_size, ...]
    s_batch = jnp.concatenate([batch.functions[key].values[None] for key in dataset.metadata.geo], axis=-1)
    a_batch = jnp.concatenate([batch.functions[key].values[None] for key in dataset.metadata.dom], axis=-1)
    q_batch = {key: batch.functions[key].values[None] for key in dataset.metadata.seg}
    m_batch = {key: batch.functions[key].mask[None] for key in dataset.metadata.seg}
    x_batch = batch.x[None]
    r_batch = jnp.concatenate([batch.functions[key].values[None] for key in dataset.metadata.out], axis=-1)
    g_batch = jax.tree.map(lambda arr: arr[None], batch.g)
    # NOTE: Placeholders for time-dependent (all2all) training
    t_batch = None
    tau_batch = None

    # Number of sub-batches, i.e., of optimizer steps, for this big batch
    num_valid_pairs = FLAGS.superbatch_size * FLAGS.superbatch_repeats

    def _split_big_batch(arr: Array) -> Array:
      # [1, superbatch_size*batch_size, ...] -> [superbatch_size, batch_size, ...]
      return jnp.reshape(arr, shape=(FLAGS.superbatch_size, FLAGS.batch_size, *arr.shape[2:]))

    def _repeat_batches(arr: Array) -> Array:
      # Repeat the super batch in the same epoch for faster training
      # NOTE: It is vital to repeat them in tiles or randomly (so that the SGD step sees different samples every time)
      return jnp.tile(arr, reps=(FLAGS.superbatch_repeats, *([1]*(arr.ndim-1))))

    def _shuffle_batches(arr: Array) -> Array:
      # Permute the samples across all sub-batches and restore the sub-batch layout
      # NOTE: The same key is used for every array so that all of them are permuted identically
      return jax.random.permutation(subrngkeys[2], arr.reshape(-1, *arr.shape[2:])).reshape(*arr.shape)

    pairs = (s_batch, a_batch, q_batch, m_batch, x_batch, r_batch, g_batch)
    pairs = jax.tree.map(_split_big_batch, pairs)
    pairs = jax.tree.map(_repeat_batches, pairs)
    pairs = jax.tree.map(_shuffle_batches, pairs)
    s_batch, a_batch, q_batch, m_batch, x_batch, r_batch, g_batch = pairs

    # Define helper function for updating the state
    def _update_state(i, carry):
      # Update state, loss, and gradients
      _state, _loss_carried, _grads_carried = carry
      _state, _loss_subbatch, _grads_subbatch = _update_state_per_subbatch(
        state=_state,
        q=s_batch[i],
        u=a_batch[i],
        h=jax.tree.map(lambda h: h[i], q_batch),
        m=jax.tree.map(lambda m: m[i], m_batch),
        x=x_batch[i],
        t=(t_batch[i] if (t_batch is not None) else None),
        tau=(tau_batch[i] if (tau_batch is not None) else None),
        r=r_batch[i],
        g=jax.tree.map(lambda arr: arr[i], g_batch)
      )
      # Update the carried (running-average) loss and gradients of the sub-batches
      _loss_updated = _loss_carried + _loss_subbatch / num_valid_pairs
      _grads_updated = jax.tree.map(
        lambda g_old, g_new: (g_old + g_new / num_valid_pairs),
        _grads_carried, _grads_subbatch,
      )

      return _state, _loss_updated, _grads_updated

    # Loop over the pairs
    _init_state = state
    _init_loss = 0.
    _init_grads = jax.tree.map(lambda p: jnp.zeros_like(p), state.params)
    state, loss, grads = jax.lax.fori_loop(
      lower=0,
      upper=num_valid_pairs,
      body_fun=_update_state,
      init_val=(_init_state, _init_loss, _init_grads)
    )

    return state, loss, grads

  def train_one_epoch(rngkey: flax.typing.PRNGKey, state: TrainState, batches: Iterable[Batch]) -> Tuple[TrainState, Array, Array]:
    """Updates the state on every big batch and returns it with the epoch-averaged loss and mean |gradient|."""

    # Split the rngkey
    subrngkeys = jax.random.split(rngkey, num=1)

    # Loop over the batches
    loss_epoch = 0.
    grad_epoch = 0.
    for i, batch in enumerate(batches):
      # [superbatch_size*NUM_DEVICES*batch_size, ...] -> [NUM_DEVICES, superbatch_size*batch_size, ...]
      batch = _shard_batch(batch)

      # Get loss and updated state
      subsubrngkey = shard_prng_key(jax.random.fold_in(subrngkeys[0], data=i))
      state, loss, grads = _train_one_batch(subsubrngkey, state, stats, batch)
      # NOTE: Using the first element of replicated loss and grads
      loss_epoch += loss[0] / num_big_batches
      grad_epoch += np.mean(jax.tree.leaves(
        jax.tree.map(lambda g: jnp.mean(jnp.abs(g[0])), grads))) / num_big_batches

    return state, loss_epoch, grad_epoch

  @jax.pmap
  def _evaluate_one_batch(state: TrainState, stats: Mapping[str, Stats], batch: Batch) -> Mapping:
    """Predicts the outputs of a device-local batch (first time index) and returns its metrics as a dict."""

    u_tgt = jnp.concatenate([batch.functions[key].values[:, [0]] for key in dataset.metadata.out], axis=-1)
    u_prd = stepper.apply(
      variables={'params': state.params},
      stats=stats,
      inputs=Inputs(
        s=jnp.concatenate([batch.functions[key].values[:, [0]] for key in dataset.metadata.geo], axis=-1),
        a=jnp.concatenate([batch.functions[key].values[:, [0]] for key in dataset.metadata.dom], axis=-1),
        q={key: batch.functions[key].values[:, [0]] for key in dataset.metadata.seg},
        m={key: batch.functions[key].mask[:, [0]] for key in dataset.metadata.seg},
        x_inp=batch.x,
        x_out=batch.x,
      ),
      graphs=graph_builder.build_graphs(batch.g),
      deterministic=True,
    )

    # Calculate the errors
    bbox_x = jnp.array(dataset.metadata.bbox_x)
    char_length = jnp.linalg.norm(bbox_x[1, :]-bbox_x[0, :])
    tol = RECALL_TOLERANCE * char_length
    batch_metrics = BatchMetrics(
      mse=mse_error(u_tgt, u_prd),
      l1=rel_lp_error_mean(u_tgt, u_prd, p=1, exclude_percentile=EXCLUDE_PERCENTILE),
      l2=rel_lp_error_mean(u_tgt, u_prd, p=2, exclude_percentile=EXCLUDE_PERCENTILE),
      recall=recall_tol_mean(batch.x, u_tgt, u_prd, q=EXCLUDE_PERCENTILE, tol=tol),
      chamfer=chamfer_mean(batch.x, u_tgt, u_prd, q=EXCLUDE_PERCENTILE),
    )

    return batch_metrics.__dict__

  def evaluate(state: TrainState, batches: Iterable[Batch]) -> EvalMetrics:
    """Evaluates the model on a dataset split and aggregates the metrics over all samples."""

    metrics: list[BatchMetrics] = []
    for batch in batches:
      # [NUM_DEVICES*batch_size, ...] -> [NUM_DEVICES, batch_size, ...]
      batch = _shard_batch(batch)
      batch_metrics = BatchMetrics(**_evaluate_one_batch(state, stats, batch))
      # Re-arrange the sub-batches gotten from each device
      # NOTE(review): the return value is discarded, so this presumably reshapes in place — confirm in ol.metrics
      batch_metrics.reshape(shape=(FLAGS.batch_size * NUM_DEVICES, 1))
      metrics.append(batch_metrics)

    # Aggregate over the batch dimension
    return EvalMetrics(
      median=_reduce_metrics(metrics, jnp.median),
      std=_reduce_metrics(metrics, jnp.std),
      maximum=_reduce_metrics(metrics, jnp.max),
    )

  # Evaluate before training
  metrics_trn = evaluate(state=state, batches=dataset.batches(split=0, batch_size=eval_batch_size))
  metrics_val = evaluate(state=state, batches=dataset.batches(split=1, batch_size=eval_batch_size))
  # Report the initial evaluations
  time_tot_pre = time() - time_int_pre
  logging.info(_format_epoch_log(
    epochs_before, epochs, _current_lr(state), time_tot_pre, 0., 0.,
    metrics_val=metrics_val, metrics_trn=metrics_trn,
  ))

  # Set the checkpoint manager up
  DIR = DIR_EXPERIMENTS / f'E{FLAGS.exp}' / FLAGS.datapath / FLAGS.datetime
  with disable_logging(level=logging.FATAL):
    (DIR / 'metrics').mkdir(exist_ok=True)
    (DIR / 'metrics/plots').mkdir(exist_ok=True)
    checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    checkpointer_options = orbax.checkpoint.CheckpointManagerOptions(
      max_to_keep=1,
      keep_period=None,
      best_fn=(lambda metrics: metrics['valid']['median']['l2']),
      best_mode='min',
      create=True,)
    checkpointer_save_args = orbax_utils.save_args_from_target(target={'state': state})
    checkpoint_manager = orbax.checkpoint.CheckpointManager(
      (DIR / 'checkpoints'), checkpointer, checkpointer_options)

  # Loop over epochs
  # NOTE: Each outer iteration covers FLAGS.superbatch_repeats epochs
  checkpointed_metrics = []
  _epochs = epochs // FLAGS.superbatch_repeats
  for _epoch in range(1, _epochs+1):
    # Get rngkey of the epoch
    subrngkeys = jax.random.split(jax.random.fold_in(rngkey, data=_epoch), num=3)

    # Store the initial time
    epoch = _epoch * FLAGS.superbatch_repeats
    time_int = time()

    # Re-construct the graphs with a new PRNG key
    # NOTE: In order to prevent training with the same regional nodes
    if dataset.metadata.fix and ((epoch % FLAGS.long_epoch_gap) == 0):
      dataset.build_graphs(builder=graph_builder, batch_size=FLAGS.batch_size, rngkey=subrngkeys[0])

    # Train one epoch
    state, loss, grad = train_one_epoch(
      rngkey=subrngkeys[1],
      state=state,
      batches=dataset.batches(split=0, batch_size=big_batch_size, rngkey=subrngkeys[2]),
    )
    # Measure the time before evaluating so that TIME reports the same quantity in every log line
    time_tot = (time() - time_int) / FLAGS.superbatch_repeats
    # NOTE: float() accepts NumPy/JAX scalars as well as the plain float of an untouched accumulator
    loss = float(loss)
    grad = float(grad)
    lr = _current_lr(state)

    # Only report the loss if this is not an evaluation epoch
    if not (((epoch % FLAGS.long_epoch_gap) == 0) or (_epoch == _epochs)):
      logging.info(_format_epoch_log(epochs_before + epoch, epochs, lr, time_tot, grad, loss))
      continue

    # Evaluate on training and validation datasets
    metrics_trn = evaluate(state=state, batches=dataset.batches(split=0, batch_size=eval_batch_size))
    metrics_val = evaluate(state=state, batches=dataset.batches(split=1, batch_size=eval_batch_size))
    # Log the results
    logging.info(_format_epoch_log(
      epochs_before + epoch, epochs, lr, time_tot, grad, loss,
      metrics_val=metrics_val, metrics_trn=metrics_trn,
    ))
    # Checkpoint
    with disable_logging(level=logging.FATAL):
      step = epochs_before + epoch
      checkpoint_metrics = {
        'step': step,
        'grad': grad,
        'loss': loss,
        'lr': lr,
        'train': metrics_trn.to_dict(),
        'valid': metrics_val.to_dict(),
      }
      checkpointed_metrics.append(checkpoint_metrics)
      # Store the state and the metrics
      checkpoint_manager.save(
        step=step,
        items={'state': jax.device_get(unreplicate(state)),},
        metrics=checkpoint_metrics,
        save_kwargs={'save_args': checkpointer_save_args}
      )
      with open(DIR / 'metrics' / f'{str(step)}.json', 'w') as f:
        json.dump(checkpoint_metrics, f)
      # Plot the history of the metrics
      _plot_metrics_history(checkpointed_metrics, DIR)

  return unreplicate(state)

def get_operator(model_configs: Mapping[str, Any], dataset: Dataset) -> AbstractOperator:
  """Builds the model based on the given configurations.

  Args:
    model_configs: Keyword arguments for the model class named by ``--core_name``. If ``None``,
      they are assembled from the command-line flags and the dataset metadata.
    dataset: Dataset used to infer the number of outputs, the time dependency and the boundary size.

  Returns:
    The instantiated operator ``models.<core_name>(**model_configs)``.

  Raises:
    ValueError: If ``model_configs`` is ``None`` and ``--core_name`` is not supported.
  """

  # Set model kwargs
  if model_configs is None:
    # Total number of output channels over all output variables
    num_outputs = sum(dataset.sample.functions[key].values.shape[-1] for key in dataset.metadata.out)
    if FLAGS.core_name == 'XRIGNO':
      configs_core = dict(
        num_outputs=num_outputs,
        processor_steps=FLAGS.processor_steps,
        node_latent_size=FLAGS.node_latent_size,
        edge_latent_size=FLAGS.edge_latent_size,
        mlp_hidden_layers=1,
        p_edge_masking=FLAGS.p_edge_masking,
        tdep=dataset.time_dependent,
      )
    elif FLAGS.core_name == 'XGAOT':
      configs_core = dict(
        num_outputs=num_outputs,
        gridres=FLAGS.rmesh_gridres,
        patch_size=2,
        transformer_hidden_size=FLAGS.node_latent_size,
        processor_steps=FLAGS.processor_steps,
        processor_attn_heads=4,
        latent_size=FLAGS.node_latent_size,
        mlp_hidden_layers=1,
        p_edge_masking=FLAGS.p_edge_masking,
        tdep=dataset.time_dependent,
      )
    else:
      # Previously this fell through and crashed with an UnboundLocalError on configs_core
      raise ValueError(f'Unsupported core name: {FLAGS.core_name}')

    model_configs = {
      'configs_core': configs_core,
      'configs_extender': dict(
        depth=FLAGS.ext_depth,
        out_dim=FLAGS.ext_vars,
        latent_dim=FLAGS.ext_latent_size,
        n_heads=FLAGS.ext_heads,
        ff_mult=1,
        p_masking=FLAGS.ext_p_masking,
        attn_dropout=FLAGS.ext_dropout,
        ff_dropout=FLAGS.ext_dropout,
      ),
      'use_extender': FLAGS.extender,
      'independent': FLAGS.independent,
      'boundary_size': dataset.metadata.boundary_size,
      'unify': True,
      'even': False,
    }

  # NOTE(review): with loaded configs the class is still chosen by the current --core_name flag
  model = models.__dict__[FLAGS.core_name](**model_configs)

  return model

def main(argv):
  """Validates the flags, loads the data, builds the model and the optimizer, and trains.

  Args:
    argv: Positional command-line arguments left over by absl; only the program name is allowed.

  Raises:
    app.UsageError: If extra positional arguments are given.
    ValueError: If more than one JAX process is running, the flag combination is invalid,
      or ``--params`` points to an experiment without checkpoints.
  """
  # Check the number of arguments
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  # Check the available devices
  with disable_logging():
    process_index = jax.process_index()
    process_count = jax.process_count()
    local_devices = jax.local_devices()
  logging.info('JAX host: %d / %d', process_index, process_count)
  logging.info('JAX local devices: %r', local_devices)
  # We only support single-host training
  # NOTE: Explicit checks instead of asserts so that they are not stripped with `python -O`
  if process_count != 1:
    raise ValueError(f'Only single-host training is supported (got {process_count} processes).')
  # Check the inputs
  if not FLAGS.datetime:
    FLAGS.datetime = datetime.now().strftime('%Y%m%d-%H%M%S')
  if (FLAGS.epochs % FLAGS.superbatch_repeats) != 0:
    raise ValueError('--epochs must be a multiple of --superbatch_repeats.')
  if FLAGS.core_name == 'XRIGNO':
    if FLAGS.rmesh_gridres != 0:
      raise ValueError('--rmesh_gridres must be 0 for XRIGNO.')
    FLAGS.rmesh_gridres = None
  elif FLAGS.core_name == 'XGAOT':
    if (FLAGS.rmesh_gridres <= 0) or (FLAGS.rmesh_gridres % 2 != 0):
      raise ValueError('--rmesh_gridres must be a positive even integer for XGAOT.')
    # Expand to a (gridres, gridres) pair
    FLAGS.rmesh_gridres = (FLAGS.rmesh_gridres, FLAGS.rmesh_gridres)
  else:
    raise ValueError(FLAGS.core_name)

  # Initialize the random keys
  rngkey = jax.random.key(FLAGS.seed)
  subrngkeys = jax.random.split(rngkey, num=3)

  # Read the dataset
  dataset = Dataset(
    dir=FLAGS.datadir,
    name=FLAGS.datapath,
    file='train.nc',
    space_downsample_factor=FLAGS.space_downsample_factor,
    boundary_downsample_factor=FLAGS.boundary_downsample_factor,
    splits=[(0, FLAGS.n_train), (FLAGS.n_train, FLAGS.n_train + FLAGS.n_valid)],
    replace_nan='maxabs',
    rngkey=subrngkeys[0],
  )
  # Add geometric features to the input domain functions
  dataset.metadata.dom += dataset.metadata.geo
  # Add pre-computed extensions to the input domain functions
  if FLAGS.use_extensions:
    dataset.metadata.dom += dataset.metadata.ext
  # Compute dataset statistics
  dataset.compute_stats(split=0, exclude_percentile=EXCLUDE_PERCENTILE_STATS, batch_size=FLAGS.batch_size)

  # Read the checkpoint
  if FLAGS.params:
    DIR_OLD_EXPERIMENT = DIR_EXPERIMENTS / FLAGS.params
    orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    step = orbax.checkpoint.CheckpointManager(DIR_OLD_EXPERIMENT / 'checkpoints', orbax_checkpointer).latest_step()
    # latest_step() is None without checkpoints; fail clearly instead of restoring '.../None/default'
    if step is None:
      raise ValueError(f'No checkpoint found in {DIR_OLD_EXPERIMENT / "checkpoints"}.')
    ckpt = orbax_checkpointer.restore(directory=(DIR_OLD_EXPERIMENT / 'checkpoints' / str(step) / 'default'))
    params = ckpt['state']['params']
    with open(DIR_OLD_EXPERIMENT / 'configs.json', 'rb') as f:
      model_configs = json.load(f)['model_configs']
  else:
    params = None
    model_configs = None
  # Get the model
  model = get_operator(model_configs, dataset)

  # Store the configurations
  DIR = DIR_EXPERIMENTS / f'E{FLAGS.exp}' / FLAGS.datapath / FLAGS.datetime
  DIR.mkdir(parents=True)
  logging.info(f'Experiment stored in {DIR.relative_to(DIR_EXPERIMENTS).as_posix()}')
  # NOTE: Not named `flags`, which would shadow the absl flags module
  flag_values = {f: FLAGS.get_flag_value(f, default=None) for f in FLAGS}
  with open(DIR / 'configs.json', 'w') as f:
    json.dump(fp=f,
      obj={'flags': flag_values, 'model_configs': model.configs},
      indent=2,
    )
  # Store the statistics
  with open(DIR / 'stats.pkl', 'wb') as f:
    pickle.dump(file=f, obj=dataset.stats)

  # Construct the graphs
  logging.info('Constructing the graphs for all dataset samples...')
  graph_builder = GraphBuilder(
    pmesh_subsample_factor=FLAGS.pmesh_subsample_factor,
    overlap_factor_p2r=FLAGS.overlap_factor_p2r,
    overlap_factor_r2p=FLAGS.overlap_factor_r2p,
    rmesh_levels=FLAGS.rmesh_levels,
    rmesh_subsample_factor=FLAGS.rmesh_subsample_factor,
    periodic=dataset.metadata.periodic,
    node_coordinate_freqs=FLAGS.node_coordinate_freqs,
    gridres=FLAGS.rmesh_gridres,
  )
  dataset.build_graphs(builder=graph_builder, batch_size=FLAGS.batch_size)
  num_edges_p2r = dataset.rigs.p2r_edge_indices.shape[1]
  num_edges_r2r = dataset.rigs.r2r_edge_indices.shape[1]
  num_edges_r2p = (dataset.rigs.r2p_edge_indices.shape[1]
    if (dataset.rigs.r2p_edge_indices is not None) else dataset.rigs.p2r_edge_indices.shape[1])
  logging.info(f'Constructed {len(dataset.rigs)} graph(s).')
  logging.info(f'Each graph has [ {num_edges_p2r / 1000 : .1f}k + '
    f'{num_edges_r2r / 1000 : .1f}k + {num_edges_r2p / 1000 : .1f}k ] edges')

  # Initialize the model or use the loaded parameters
  if params is None:
    # Small, cheap graphs are enough for shape inference during initialization
    dummy_graph_builder = GraphBuilder(
      pmesh_subsample_factor=16,
      overlap_factor_p2r=.01,
      overlap_factor_r2p=.01,
      rmesh_levels=1,
      rmesh_subsample_factor=4,
      periodic=dataset.metadata.periodic,
      node_coordinate_freqs=FLAGS.node_coordinate_freqs,
      gridres=FLAGS.rmesh_gridres,
    )
    dummy_graphs = dummy_graph_builder.build_graphs(
        dummy_graph_builder.build_metadata(
        x_inp=dataset.sample.x[0, 0],
        x_out=dataset.sample.x[0, 0],
        x_bnd=dataset.sample.x[0, 0, :100],
        z_bnd=dataset.sample.functions['sdfgrad'].values[0, 0, :100],
        bbox=np.array(dataset.metadata.bbox_x),
      )
    )
    dummy_graphs = jax.tree.map(lambda v: jnp.repeat(v, repeats=FLAGS.batch_size, axis=0), dummy_graphs)

    def _num_channels(keys) -> int:
      # Total number of feature channels of the given variables
      return sum(dataset.sample.functions[key].values.shape[-1] for key in keys)

    num_pnodes = dataset.sample.x.shape[2]
    dummy_inputs = Inputs(
      s=jnp.ones(shape=(FLAGS.batch_size, 1, num_pnodes, _num_channels(dataset.metadata.geo))),
      a=jnp.ones(shape=(FLAGS.batch_size, 1, num_pnodes, _num_channels(dataset.metadata.dom))),
      q={key: jnp.ones(shape=(FLAGS.batch_size, 1, num_pnodes, dataset.sample.functions[key].values.shape[-1])) for key in dataset.metadata.seg},
      m={key: jnp.ones(shape=(FLAGS.batch_size, 1, num_pnodes), dtype=bool) for key in dataset.metadata.seg},
      x_inp=jnp.tile(dataset.sample.x, reps=(FLAGS.batch_size, 1, 1, 1)),
      x_out=jnp.tile(dataset.sample.x, reps=(FLAGS.batch_size, 1, 1, 1)),
      t=None,
      tau=None,
    )
    variables = model.init(rngs=subrngkeys[1], inputs=dummy_inputs, graphs=dummy_graphs, deterministic=True)
    params = variables['params']

  # Report the total number of parameters
  n_model_parameters = sum([arr.size for arr in jax.tree.leaves(params)])
  logging.info(f'Training a {model.__class__.__name__} with {n_model_parameters} parameters')

  # Set optimizer transition steps
  # NOTE: Each big batch gives superbatch_size * superbatch_repeats optimizer steps per outer iteration,
  #   and there are epochs / superbatch_repeats outer iterations
  num_big_batches = (dataset.splits[0][1] - dataset.splits[0][0]) // (FLAGS.batch_size * NUM_DEVICES * FLAGS.superbatch_size)
  transition_steps = FLAGS.epochs * num_big_batches * FLAGS.superbatch_size
  # Set learning rate and optimizer
  pct_start = .05  # Warmup cosine onecycle
  pct_final = .0   # Final exponential decay
  lr = optax.join_schedules(
    schedules=[
      optax.cosine_onecycle_schedule(
        transition_steps=((1 - pct_final) * transition_steps),
        peak_value=FLAGS.lr_peak,
        pct_start=(pct_start / (1 - pct_final)),
        div_factor=(FLAGS.lr_peak / FLAGS.lr_init),
        final_div_factor=(FLAGS.lr_init / FLAGS.lr_base),
      ),
      optax.exponential_decay(
        transition_steps=(pct_final * transition_steps),
        init_value=FLAGS.lr_base,
        decay_rate=(FLAGS.lr_lowr / FLAGS.lr_base) if FLAGS.lr_lowr else 1,
      ),
    ],
    boundaries=[int((1 - pct_final) * transition_steps),],
  )
  tx = optax.chain(
    optax.adaptive_grad_clip(clipping=FLAGS.gclip),
    optax.inject_hyperparams(optax.adamw)(learning_rate=lr, weight_decay=FLAGS.weight_decay)
  )
  state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)

  # Train the model
  loss_fn = lambda gtr, prd: rel_lp_loss(gtr, prd, p=2, q=EXCLUDE_PERCENTILE_LOSS)
  state = train(
    rngkey=subrngkeys[2],
    model=model,
    state=state,
    dataset=dataset,
    graph_builder=graph_builder,
    epochs=FLAGS.epochs,
    epochs_before=0,
    loss_fn=loss_fn,
  )

if __name__ == '__main__':
  # Log at INFO level and register the flags before app.run parses the command line and calls main
  logging.set_verbosity('info')
  define_flags()
  app.run(main)
