# Assumes vector (flat) inputs. By default the rate is in bits per sample (multiply by np.log(2) to convert to nats); --nats reports it in nats and --rpd divides it by data_dim.

import argparse
import os
import sys

import numpy as np
import tensorflow as tf
import tensorflow_compression as tfc
from absl import app
from absl.flags import argparse_flags

import tensorflow_probability as tfp

from rdvae.utils import softplus_inv_1

tfd = tfp.distributions
tfb = tfp.bijectors

# VAE implementation inspired by
# https://github.com/tensorflow/probability/blob/9d4fc05d16a0401aa2b4a9653320e37248892e8a/tensorflow_probability/examples/vae.py#L271

from rdvae.nn_models import get_activation, make_mlp
import rdvae.tfc_utils as tfc_utils
from common import ntc_sources

def af_transform(base_distribution, mades, permute=True, iaf=False):
  """
  Apply a cascade of autoregressive transforms to a base distribution. Default is MAF.

  Each network in `mades` parameterizes one tfb.MaskedAutoregressiveFlow; with
  `iaf=True` every flow is wrapped in tfb.Invert (inverse autoregressive flow).
  With `permute=True`, the k-th flow (0-indexed) is chained with a tfb.Permute that
  circularly shifts the event dimensions by k positions; in tfb.Chain the last
  bijector is applied first, so the flow runs before the permutation.

  :param base_distribution: distribution with a rank-1 event shape (its event_shape[0]
    is read when permute=True).
  :param mades: sequence of shift_and_log_scale_fn networks, one per flow stack.
  :param permute: whether to interleave circular-shift permutations.
  :param iaf: invert each flow (IAF) instead of using it directly (MAF).
  :return: a (possibly nested) tfd.TransformedDistribution, or `base_distribution`
    itself when `mades` is empty.
  """
  event_dims = np.arange(base_distribution.event_shape[0]) if permute else None
  transformed = base_distribution
  for shift, made in enumerate(mades):
    flow = tfb.MaskedAutoregressiveFlow(shift_and_log_scale_fn=made)
    if iaf:
      flow = tfb.Invert(flow)
    bijector = (tfb.Chain([tfb.Permute(np.roll(event_dims, shift)), flow])
                if permute else flow)
    transformed = tfd.TransformedDistribution(distribution=transformed, bijector=bijector)
  return transformed


def check_no_decoder(decoder_units):
  """Return True iff `decoder_units` encodes "no decoder network at all".

  An empty list still means "build a decoder" (a network mapping latent_dim to
  data_dim). By convention, a list holding exactly one non-positive entry, such as
  [0], disables the decoder entirely.
  """
  if len(decoder_units) != 1:
    return False
  return decoder_units[0] <= 0


class Model(tf.keras.Model):
  """Rate-distortion VAE for flat (vector) inputs; loosely based on toy_sources.ntc.NTCModel.

  Every constructor keyword argument is copied onto the instance, so the model is
  normally built from a parsed argparse namespace via `create_model`. Keyword
  arguments read by this class:
    data_dim, latent_dim: data and latent dimensionality.
    encoder_units, decoder_units: hidden-layer widths. decoder_units holding a single
      non-positive entry (e.g. [0]) means "no decoder" and requires
      latent_dim == data_dim.
    encoder_activation, decoder_activation, ar_activation: activation names passed to
      get_activation.
    ar_hidden_units: hidden widths of the AutoregressiveNetworks ('maf' prior and
      'iaf' posterior only).
    maf_stacks, iaf_stacks: number of flow stacks for the 'maf' prior / 'iaf' posterior.
    prior_type: 'deep', 'std_gaussian', 'maf', or a per-dimension scalar mixture
      'gsm_K', 'gmm_K', 'lsm_K', 'lmm_K' with K components ('g' = Normal,
      'l' = Logistic; 's' = zero-location scale mixture, 'm' = learned locations).
    posterior_type: 'gaussian', 'iaf', or 'uniform'.
    scale_lb: optional constant added to the Gaussian posterior scale.
    lmbda: weight of the MSE term in the loss.
    nats: if true, rates stay in nats instead of being converted to bits.
    rpd: if true, the mean rate is divided by data_dim.
  """

  def __init__(self, **kwargs):
    """Builds the encoder, optional decoder, and prior from keyword arguments.

    :raises ValueError: if `prior_type` is not recognized.
    """
    super().__init__()
    # Copy all hyperparameters onto the instance. Writing to __dict__ directly skips
    # Keras' __setattr__ tracking, which is fine for plain Python values.
    self.__dict__.update(kwargs)
    dtype = self.dtype
    posterior_type = self.posterior_type
    latent_dim = self.latent_dim
    ar_activation = get_activation(self.ar_activation, dtype)

    # borrowed from get_ntc_mlp_model
    if posterior_type in ('gaussian', 'iaf'):
      # The encoder emits loc and (pre-softplus) scale of a diagonal Gaussian;
      # IAF currently uses that x-conditioned Gaussian as its base distribution.
      encoder_output_dim = latent_dim * 2
      if posterior_type == 'iaf':
        self._iaf_mades = [
          tfb.AutoregressiveNetwork(params=2, activation=ar_activation, hidden_units=self.ar_hidden_units)
          for _ in range(self.iaf_stacks)]
    else:
      encoder_output_dim = latent_dim

    # An encoder is always required to produce the variational distribution Q(Y|X);
    # encoder_units = [] gives the minimal network.
    encoder = make_mlp(
      units=self.encoder_units + [encoder_output_dim],
      activation=get_activation(self.encoder_activation, dtype),
      name="encoder",
      input_shape=[self.data_dim],
      dtype=dtype,
    )

    # A decoder is optional when dim(Y) == dim(X). decoder_units = [] (default) still
    # builds a decoder mapping latent_dim to data_dim; decoder_units = [0] means
    # "no decoder network at all".
    if check_no_decoder(self.decoder_units):
      decoder = None
      assert self.data_dim == latent_dim
      print('Not using decoder')
    else:
      decoder = make_mlp(
        self.decoder_units + [self.data_dim],
        get_activation(self.decoder_activation, dtype),
        "decoder",
        [latent_dim],
        dtype,
      )

    self.encoder = encoder
    self.decoder = decoder

    self._prior = None
    if self.prior_type == "deep":
      self._prior = tfc_utils.MyDeepFactorized(
        batch_shape=[self.latent_dim], dtype=self.dtype)
    elif self.prior_type == 'std_gaussian':  # use 'gmm_1' for gaussian prior with learned mean/scale
      self._prior = tfd.MultivariateNormalDiag(loc=tf.zeros([self.latent_dim], dtype=self.dtype),
                                               scale_diag=tf.ones([self.latent_dim], dtype=self.dtype))
    elif self.prior_type == 'maf':
      # see https://www.tensorflow.org/probability/api_docs/python/tfp/bijectors/MaskedAutoregressiveFlow
      # and https://www.tensorflow.org/probability/api_docs/python/tfp/bijectors/AutoregressiveNetwork
      self._maf_mades = [
        tfb.AutoregressiveNetwork(params=2, activation=ar_activation, hidden_units=self.ar_hidden_units)
        for _ in range(self.maf_stacks)]
      base_distribution = tfd.MultivariateNormalDiag(loc=tf.zeros([self.latent_dim], dtype=self.dtype),
                                                     scale_diag=tf.ones([self.latent_dim], dtype=self.dtype))
      self._prior = af_transform(base_distribution, self._maf_mades, permute=True, iaf=False)
    elif self.prior_type[:4] in ("gsm_", "gmm_", "lsm_", "lmm_"):  # mixture prior; specified like 'gmm_2'
      # This only implements a scalar mixture for each dimension; the dimensions
      # themselves are still fully factorized, just like tfc.DeepFactorized.
      components = int(self.prior_type[4:])
      shape = (self.latent_dim, components)
      self.logits = tf.Variable(tf.random.normal(shape, dtype=self.dtype))
      self.log_scale = tf.Variable(
        tf.random.normal(shape, mean=2., dtype=self.dtype))
      if "s" in self.prior_type:  # scale mixture: all components located at 0
        self.loc = 0.
      else:
        self.loc = tf.Variable(tf.random.normal(shape, dtype=self.dtype))
    else:
      raise ValueError(f"Unknown prior_type: '{self.prior_type}'.")

    self.build([None, self.data_dim])

  def prior(self, conv_unoise=False):
    """Returns the prior distribution over the latent Y.

    :param conv_unoise: if True, wrap the prior in tfc_utils.MyUniformNoiseAdapter
      (convolution with uniform noise, for the NTC compression model).
    :raises ValueError: if `prior_type` is not recognized.
    """
    if self._prior is not None:
      prior = self._prior
    elif self.prior_type[:4] in ("gsm_", "gmm_", "lsm_", "lmm_"):
      # Mixtures are rebuilt from their trainable variables on every call.
      cls = tfd.Normal if self.prior_type.startswith("g") else tfd.Logistic
      prior = tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(logits=self.logits),
        components_distribution=cls(
          # Despite its name, `log_scale` is mapped to a scale through softplus, not exp.
          loc=self.loc, scale=tf.math.softplus(self.log_scale)),
      )
    else:
      raise ValueError(f"Unknown prior_type: '{self.prior_type}'.")
    if conv_unoise:  # convolve with uniform noise for NTC compression model
      prior = tfc_utils.MyUniformNoiseAdapter(prior)
    return prior

  def call(self, x, training):
    """Given a batch of inputs, performs a full inference -> generative pass through the model.

    :param x: batch of flat inputs (the model is built with input shape [None, data_dim]).
    :param training: with posterior_type 'uniform', False selects hard quantization
      and True selects additive uniform noise.
    :return: dict with 'loss' (rate + lmbda * mse), 'rate' (mean rate), 'rates'
      (per-example rates), 'mse', and 'y_tilde' (the decoder output, or the latent
      sample when there is no decoder).
    :raises NotImplementedError: for an unknown posterior_type.
    """
    if self.posterior_type in ('gaussian', 'iaf'):
      encoder_res = self.encoder(x)
      qy_loc = encoder_res[..., :self.latent_dim]
      # softplus_inv_1 is presumably softplus^{-1}(1), so a zero pre-activation
      # gives unit scale — confirm in rdvae.utils.
      qy_scale = tf.nn.softplus(encoder_res[..., self.latent_dim:] + softplus_inv_1)
      if self.scale_lb:
        qy_scale = qy_scale + self.scale_lb
      y_dist = tfd.MultivariateNormalDiag(loc=qy_loc, scale_diag=qy_scale, name="q_y")
      if self.posterior_type == 'iaf':
        y_dist = af_transform(y_dist, self._iaf_mades, permute=True, iaf=True)

      # Single-sample estimate Y ~ Q(Y|X); IWAE is not used as it doesn't exactly
      # correspond to the R-D objective here.
      y_tilde = y_dist.sample()
      log_q_tilde = y_dist.log_prob(y_tilde)  # [batch_size]
      prior = self.prior(conv_unoise=False)
    elif self.posterior_type == 'uniform':
      encoder_res = self.encoder(x)
      prior = self.prior(conv_unoise=True)  # Balle VAE
      if not training:  # Hard quantization, done the proper way (possibly with an offset) using tfc.
        entropy_model = tfc.ContinuousBatchedEntropyModel(prior, coding_rank=1, compression=False)
        y_tilde = entropy_model.quantize(encoder_res)
      else:
        y_dist = tfd.Uniform(low=encoder_res - 0.5, high=encoder_res + 0.5, name="q_y")
        y_tilde = y_dist.sample()  # Y ~ Q(Y|X); batch_size by latent_dim
      # A unit-width uniform has density 1 on its support, hence log q = 0.
      log_q_tilde = 0.
    else:
      raise NotImplementedError(f'unknown posterior_type={self.posterior_type}')

    log_probs = prior.log_prob(y_tilde)
    if log_probs.shape.rank == y_tilde.shape.rank:
      # Fully factorized prior with scalar events ('deep', mixtures): one log-density
      # per latent dimension, so sum across latent_dim.
      log_prior = tf.reduce_sum(log_probs, axis=-1)
    else:
      # Prior with a vector event ('maf', 'std_gaussian'): log_prob already returns one
      # number per example. Summing again would wrongly sum over the batch.
      log_prior = log_probs
    rates = log_q_tilde - log_prior

    if self.decoder is not None:
      y_tilde = self.decoder(y_tilde)

    # Compute losses.
    mse = tf.reduce_mean(tf.math.squared_difference(x, y_tilde))
    if not self.nats:
      rates = (rates / tf.cast(tf.math.log(2.), self.dtype))  # convert to bits
    if self.rpd:  # normalize by number of data dimensions
      rate = tf.reduce_mean(rates) / float(self.data_dim)
    else:
      rate = tf.reduce_mean(rates)
    loss = rate + self.lmbda * mse
    return dict(loss=loss, rate=rate, rates=rates, mse=mse, y_tilde=y_tilde)

  def train_step(self, x):
    """Runs one optimization step on batch `x`; returns running metric means and the learning rate."""
    with tf.GradientTape() as tape:
      res = self(x, training=True)
    variables = self.trainable_variables
    gradients = tape.gradient(res['loss'], variables)
    self.optimizer.apply_gradients(zip(gradients, variables))
    for m in self.my_metrics:
      m.update_state(res[m.name])
    retval = {m.name: m.result() for m in self.my_metrics}
    retval['lr'] = self.optimizer.learning_rate
    return retval

  def test_step(self, x):
    """Evaluates batch `x` with training=False (hard quantization for the uniform posterior)."""
    res = self(x, training=False)
    for m in self.my_metrics:
      m.update_state(res[m.name])
    return {m.name: m.result() for m in self.my_metrics}

  def predict_step(self, x):
    raise NotImplementedError("Prediction API is not supported.")

  def compile(self, **kwargs):
    """Compiles without Keras-managed losses/metrics; loss/rate/mse are tracked with Mean metrics."""
    super().compile(
      loss=None,
      metrics=None,
      loss_weights=None,
      weighted_metrics=None,
      **kwargs,
    )
    self.metric_names = ('loss', 'rate', 'mse')
    # Stored as `my_metrics` because the name `metrics` is taken by a Keras Model property.
    self.my_metrics = [tf.keras.metrics.Mean(name=name) for name in self.metric_names]

  @classmethod
  def create_model(cls, args):
    """Builds a model from an argparse namespace (all of its fields become kwargs)."""
    return cls(**vars(args))

  def fit(self, *args, **kwargs):
    """Thin pass-through to tf.keras.Model.fit."""
    retval = super().fit(*args, **kwargs)
    return retval

  def sample(self, num_samples):
    """
    Draw samples from the compression model.

    With the 'deep' prior and 'uniform' posterior, latents come from the discretized
    prior (as would be used in actual entropy coding). Otherwise they come from the
    continuous prior. Latents are passed through the decoder when one exists.

    :param num_samples: int
    :return: a [num_samples, data_dim] tensor.
    """
    if self.prior_type == 'deep' and self.posterior_type == 'uniform':
      # The prior is not actually convolved with uniform noise here; the adapter is
      # only used to get quantized samples.
      prior = self.prior(conv_unoise=True)
      samples = prior.sample(num_samples, quantized=True)
    else:
      samples = self.prior(conv_unoise=False).sample(num_samples)
    if self.decoder is not None:
      samples = self.decoder(samples)
    return samples


def get_runname(args):
  """Build a run name from selected hyperparameters, prefixed with this script's file stem.

  :param args: argparse namespace; its fields are passed to config_dict_to_str.
  :return: whatever config_dict_to_str produces for the recorded keys.
  """
  from rdvae.utils import config_dict_to_str
  keys_to_record = ('data_dim', 'latent_dim', 'lmbda', 'encoder_units', 'decoder_units',
                    'prior_type', 'posterior_type',
                    'maf_stacks', 'iaf_stacks')
  script_stem, _ = os.path.splitext(os.path.basename(__file__))
  return config_dict_to_str(vars(args), record_keys=keys_to_record, prefix=script_stem)


def gen_dataset(dataset_spec: str, data_dim: int, batchsize: int, dtype='float32', **kwargs):
  """
  Return an 'infinite' batched dataset for training.

  Every element of the dataset is a freshly drawn batch of `batchsize` samples.
  If only one batch of data is desired, you can create a Python iterator like
  'batched_iterator = iter(dataset)' and then call 'batch = next(batched_iterator)';
  see https://github.com/tensorflow/tensorflow/issues/40285

  :param dataset_spec: name of the source, either 'gaussian' or 'banana'.
  :param data_dim: data dimensionality. For 'gaussian' without 'gparams_path' it sets
    the number of i.i.d. standard-normal dimensions. For 'banana', values other than 2
    embed the 2-D banana source in data_dim dimensions.
  :param batchsize: number of samples per batch.
  :param dtype: numpy dtype for the Gaussian loc/scale parameters.
  :param kwargs: optional 'gparams_path' (file loadable by np.load exposing 'loc' and
    'scale'; 'gaussian' only), 'embedder_depth' (required for 'banana' when
    data_dim != 2) and 'seed' (for the banana embedding; default 0).
  :return: a tf.data.Dataset that yields batches indefinitely.
  :raises NotImplementedError: if `dataset_spec` is not recognized.
  """
  if dataset_spec == 'gaussian':
    if kwargs.get('gparams_path', None):
      gparams = np.load(kwargs['gparams_path'])
      loc = gparams['loc'].astype(dtype)
      scale = gparams['scale'].astype(dtype)
    else:
      loc = np.zeros(data_dim, dtype=dtype)
      scale = np.ones(data_dim, dtype=dtype)
    source = tfd.Normal(loc=loc, scale=scale)
    map_sample_fun = lambda _: source.sample(batchsize)
  elif dataset_spec == 'banana':
    if kwargs.get('gparams_path', None):
      print('Did you mean to run with --dataset gaussian instead?')
    source = ntc_sources.get_banana()  # a tfp.distributions.TransformedDistribution object
    if data_dim == 2:
      map_sample_fun = lambda _: source.sample(batchsize)
    else:
      from common.ntc_sources import get_nd_banana
      map_sample_fun, _ = get_nd_banana(data_dim, kwargs['embedder_depth'], batchsize, kwargs.get('seed', 0))
  else:
    raise NotImplementedError(f"Unknown dataset_spec: '{dataset_spec}'; expected 'gaussian' or 'banana'.")
  # Map the sampling function over an endlessly repeated dummy element so that every
  # dataset element is a new batch. Trick from
  # https://github.com/tensorflow/compression/blob/66228f0faf9f500ffba9a99d5f3ad97689595ef8/models/toy_sources/compression_model.py#L121
  dataset = tf.data.Dataset.from_tensors([])
  dataset = dataset.repeat()
  dataset = dataset.map(map_sample_fun)
  return dataset


def get_lr_scheduler(learning_rate, epochs, decay_factor=0.1, warmup_epochs=0):
  """Returns a learning rate scheduler function for the given configuration.

  The returned `scheduler(epoch, lr)` ignores `lr` and depends only on `epoch`:
    * during the first `warmup_epochs` epochs: learning_rate * 10 ** (epoch - warmup_epochs);
    * before 1/2 of `epochs`: learning_rate;
    * before 3/4 of `epochs`: learning_rate * decay_factor;
    * before 7/8 of `epochs`: learning_rate * decay_factor ** 2;
    * afterwards: learning_rate * decay_factor ** 3.
  """
  # Fractions of the total epoch budget after which one more decay is applied.
  later_milestones = (3 / 4, 7 / 8)

  def scheduler(epoch, lr):
    del lr  # unused; the schedule is a pure function of the epoch index
    if epoch < warmup_epochs:
      return learning_rate * 10. ** (epoch - warmup_epochs)
    if epoch < 1 / 2 * epochs:
      return learning_rate
    for num_decays, fraction in enumerate(later_milestones, start=1):
      if epoch < fraction * epochs:
        return learning_rate * decay_factor ** num_decays
    return learning_rate * decay_factor ** 3

  return scheduler


def train(args):
  """Instantiates and trains the model.

  Creates `<checkpoint_dir>/<runname>/` and writes into it: the parsed args (plus the
  reconstructed command line) as `args-<time>.json`, a copy of this script, a JSONL
  training log, and per-epoch weight checkpoints. Optionally resumes from a
  checkpoint selected with --cont.

  :param args: parsed command-line namespace (see parse_args); `args.cmdline` is set
    as a side effect.
  :return: the History object returned by `model.fit`.
  :raises FileNotFoundError: if --cont names a directory that contains no checkpoint.
  """
  if args.check_numerics:
    tf.debugging.enable_check_numerics()

  model = Model.create_model(args)
  model.compile(
    run_eagerly=args.eager,
    optimizer=tf.keras.optimizers.Adam(learning_rate=args.lr),
  )

  if args.dataset.endswith(('.npy', '.npz')):
    from rdvae.utils import get_np_datasets
    train_dataset, validation_dataset = get_np_datasets(args.dataset, args.batchsize)
  else:
    # Training and validation data are two separate sample streams from the same source spec.
    dataset_kwargs = dict(dataset_spec=args.dataset, data_dim=args.data_dim, batchsize=args.batchsize,
                          gparams_path=args.gparams_path, embedder_depth=args.embedder_depth)
    train_dataset = gen_dataset(**dataset_kwargs)
    validation_dataset = gen_dataset(**dataset_kwargs)

  # Keras crashes without this (it would be using an infinite validation set).
  validation_dataset = validation_dataset.take(args.max_validation_steps)

  ##################### BEGIN: Good old bookkeeping #########################
  runname = get_runname(args)
  save_dir = os.path.join(args.checkpoint_dir, runname)
  os.makedirs(save_dir, exist_ok=True)
  import json
  from shutil import copy2
  from common.common_utils import get_time_str
  time_str = get_time_str()
  args_file_name = f'args-{time_str}.json'
  # Attempt to reconstruct the original cmdline; not reliable (e.g., loses quotes).
  args.cmdline = " ".join(sys.argv)
  with open(os.path.join(save_dir, args_file_name), 'w') as f:  # will overwrite existing
    json.dump(vars(args), f, indent=4, sort_keys=True)
  # Save a copy of the source code. Locate it via __file__ (not the current working
  # directory) so this works no matter where the script is launched from.
  script_name = os.path.splitext(os.path.basename(__file__))[0]
  copied_path = copy2(__file__, os.path.join(save_dir, f'{script_name}-{time_str}.py'))
  if args.verbose:
    print('Saved a copy of %s.py to %s' % (script_name, copied_path))

  # log to file during training
  log_file_path = os.path.join(save_dir, f'record-{time_str}.jsonl')
  from rdvae.utils import get_json_logging_callback
  file_log_callback = get_json_logging_callback(log_file_path)
  print(f'Logging to {log_file_path}')
  ##################### END: Good old bookkeeping #########################

  #### BEGIN: boilerplate for periodic checkpointing and optionally resuming training ####
  initial_epoch = 0
  if args.cont is not None:
    if args.cont == '':  # flag given without a value: use the latest ckpt in save_dir
      restore_ckpt_path = tf.train.latest_checkpoint(save_dir)
      if not restore_ckpt_path:
        print(f'No checkpoints found in {save_dir}; training from scratch!')
    elif os.path.isdir(args.cont):  # a directory: use its latest checkpoint
      ckpt_dir = args.cont
      restore_ckpt_path = tf.train.latest_checkpoint(ckpt_dir)
      if restore_ckpt_path is None:
        raise FileNotFoundError(f'No checkpoints found in {ckpt_dir}')
    else:  # assuming this is a checkpoint name
      restore_ckpt_path = args.cont
    if restore_ckpt_path:
      model.load_weights(restore_ckpt_path).expect_partial()
      print('Loaded model weights from', restore_ckpt_path)
      # grab epoch number from checkpoint name
      from common.common_utils import parse_runname
      initial_epoch = int(parse_runname(restore_ckpt_path)['epoch'])
  from rdvae.keras_utils import MyModelCheckpointCallback
  model_checkpoint_callback = MyModelCheckpointCallback(
    filepath=save_dir + '/ckpt-lmbda=%g-epoch={epoch}-loss={loss:.3f}' % args.lmbda,
    max_to_keep=1,
    verbose=0,
    save_weights_only=True,
    save_best_only=False,
    save_freq='epoch'
  )
  #### END: boilerplate for periodic checkpointing and optionally resuming training ####
  lr_scheduler = get_lr_scheduler(args.lr, args.epochs, decay_factor=0.2)
  hist = model.fit(
    train_dataset.prefetch(tf.data.AUTOTUNE),
    epochs=args.epochs,
    steps_per_epoch=args.steps_per_epoch,
    validation_data=validation_dataset.cache(),
    validation_freq=1,
    verbose=int(args.verbose),
    callbacks=[
      tf.keras.callbacks.TerminateOnNaN(),
      model_checkpoint_callback,
      file_log_callback,
      tf.keras.callbacks.LearningRateScheduler(lr_scheduler),
    ],
    initial_epoch=initial_epoch
  )
  return hist


def parse_args(argv):
  """Parses command line arguments.

  :param argv: the full argument vector; argv[0] (the program name) is skipped.
  :return: the parsed argparse.Namespace. If no subcommand is given, prints usage and
    exits with status 2.
  """
  parser = argparse_flags.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)

  # High-level options.
  parser.add_argument(
    "--verbose", "-V", action="store_true",
    help="Report progress and metrics when training.")
  parser.add_argument(
    "--seed", type=int, default=0,
    help="Random seed for NumPy and TensorFlow.")
  parser.add_argument(
    "--check_numerics", action="store_true",
    help="Enable TF support for catching NaN and Inf in tensors.")
  parser.add_argument(
    "--eager", default=False, action="store_true",
    help="Whether to run eagerly (helpful for debugging).")
  parser.add_argument(
    "--checkpoint_dir", default="./checkpoints",
    help="Directory where to save/load model checkpoints.")

  # Specifying dataset
  parser.add_argument("--data_dim", type=int, default=None, help="Data dimensionality.")
  parser.add_argument(
    "--dataset", type=str, default="banana",
    help="Dataset name ('gaussian' or 'banana') or path to a .npy/.npz file.")
  parser.add_argument(
    "--gparams_path", type=str, default=None,
    help="Path to Gaussian loc/scale params (only used with --dataset gaussian).")
  parser.add_argument(
    "--embedder_depth", type=int, default=1,
    help="Depth of MLP used to embed banana source in higher dimensions "
         "(only has effect when data_dim != 2 for banana source).")

  # Model specific args
  parser.add_argument(
    "--scale_lb", type=float, default=None,
    help="Constant added to the q(y|x) scale, so the scale is bounded below by it, to avoid KL blowup "
         "(only used with the 'gaussian' and 'iaf' posteriors).")
  parser.add_argument(
    "--encoder_units", type=lambda s: [int(i) for i in s.split(',')], default=[],
    help="A comma delimited list, specifying the number of units per hidden layer in the encoder.")
  parser.add_argument(
    "--decoder_units", type=lambda s: [int(i) for i in s.split(',')], default=[],
    help="A comma delimited list, specifying the number of units per hidden layer in the decoder; "
         "set to 0 to not use decoder (for quantization experiments).")
  parser.add_argument("--encoder_activation", type=str, default="softplus", help="Activation in encoder MLP")
  parser.add_argument("--decoder_activation", type=str, default="softplus", help="Activation in decoder MLP")

  parser.add_argument(
    "--latent_dim", type=int,
    help="Latent space dimensionality. "
         "Will be automatically set to be the same as data_dim if decoder_units=0.")
  parser.add_argument(
    "--posterior_type", type=str, default='gaussian',
    help="Posterior type: 'gaussian', 'iaf', or 'uniform'.")
  parser.add_argument(
    "--prior_type", type=str, default='deep',
    help="Prior type: 'deep', 'std_gaussian', 'maf', or a per-dimension scalar mixture "
         "'gsm_K', 'gmm_K', 'lsm_K', 'lmm_K' with K components.")
  parser.add_argument(
    "--ar_hidden_units", type=lambda s: [int(i) for i in s.split(',')], default=[10, 10],
    help="A comma delimited list, specifying the number of hidden units per MLP layer in the "
         "AutoregressiveNetworks for normalizing flow.")
  parser.add_argument(
    "--ar_activation", type=str, default=None,
    help="Activation function to use in the AutoregressiveNetworks for normalizing flow. "
         "No need to worry about output activation as tfb.MaskedAutoregressiveFlow operates on "
         "log_scale outputted by the AutoregressiveNetworks.")
  parser.add_argument(
    "--maf_stacks", type=int, default=0, help="Number of stacks of transforms to use for MAF prior.")
  parser.add_argument(
    "--iaf_stacks", type=int, default=0, help="Number of stacks of transforms to use for IAF posterior.")
  parser.add_argument(
    "--lambda", type=float, default=0.01, dest="lmbda",
    help="Lambda for rate-distortion tradeoff.")
  parser.add_argument(
    "--rpd", default=False, action='store_true',
    help="Whether to normalize the rate (per sample) by the number of data dimensions; default is False, i.e., bits/nats per sample.")
  parser.add_argument(
    "--nats", default=False, action='store_true',
    help="Whether to compute rate in terms of nats (instead of bits)")

  subparsers = parser.add_subparsers(
    title="commands", dest="command",
    help="What to do: 'train' loads or generates training data and trains (or "
         "continues to train) a new model. Invoke '<command> -h' for more "
         "information.")

  # 'train' subcommand.
  train_cmd = subparsers.add_parser(
    "train",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    description="Trains (or continues to train) a new model. With a synthetic "
                "source ('gaussian' or 'banana'), the model trains on a "
                "continuous stream of freshly drawn samples, and validation "
                "uses a separate sample stream from the same source, capped at "
                "--max_validation_steps batches. An epoch is always defined as "
                "the same number of batches given by --steps_per_epoch. "
                "With posterior_type 'uniform', validation evaluates the "
                "rate-distortion performance of the model using actual "
                "quantization rather than the differentiable proxy loss.")

  train_cmd.add_argument(
    "--batchsize", type=int, default=1024,
    help="Batch size for training and validation.")
  train_cmd.add_argument(
    "--lr", type=float, default=1e-3,
    help="Learning rate.")
  train_cmd.add_argument(
    "--epochs", type=int, default=100,
    help="Train up to this number of epochs. (One epoch is here defined as "
         "the number of steps given by --steps_per_epoch, not iterations "
         "over the full training dataset.)")
  train_cmd.add_argument(
    "--steps_per_epoch", type=int, default=1000,
    help="Perform validation and produce logs after this many batches.")
  train_cmd.add_argument(
    "--max_validation_steps", type=int, default=10,
    help="Maximum number of batches to use for validation.")
  train_cmd.add_argument(
    "--cont", nargs='?', default=None, const='',  # see https://docs.python.org/3/library/argparse.html#nargs
    help="Path to the checkpoint (either the directory containing the checkpoint (will use the latest), or "
         "full checkpoint name (should not have the .index extension)) to continue training from; "
         "if no path is given, will try to use the latest ckpt in the run dir.")

  # Parse arguments.
  args = parser.parse_args(argv[1:])
  if args.command is None:
    parser.print_usage()
    sys.exit(2)
  return args


def main(args):
  """Entry point: normalize the latent size, seed the RNGs, then dispatch the subcommand.

  :param args: namespace produced by parse_args. `args.latent_dim` is overwritten with
    `args.data_dim` when the decoder is disabled.
  """
  if check_no_decoder(args.decoder_units):
    # Without a decoder the latent space is the data space itself.
    print(f'Using Z=Y; resetting latent_dim={args.latent_dim} to data_dim={args.data_dim}')
    args.latent_dim = args.data_dim

  for seed_rng in (np.random.seed, tf.random.set_seed):
    seed_rng(args.seed)

  if args.command == "train":
    train(args)


if __name__ == "__main__":
  app.run(main, flags_parser=parse_args)
