"""Training script for the TDVAE40n model from the paper.
"""

import numpy as np
import tensorflow as tf

import training


def main(unused_argv):
  """Train the TDVAE40n model on the 'textures' dataset with a fixed config.

  Builds a per-step beta_y annealing schedule and passes it, together with the
  full model/optimizer configuration, to `training.run_training`. The values
  returned by `run_training` are not used.

  Args:
    unused_argv: Ignored (the script entry point passes an empty list).
  """
  # Dataset / batching configuration kept in one place. steps_per_epoch is
  # derived from these, so the epoch-based schedules below stay consistent
  # with the batch size actually passed to run_training.
  # NOTE: the dataset path below also encodes 640000 and 40px; keep it in sync.
  n_train_examples = 640000  # presumably the number of images in the .pkl
  batch_size = 128
  crop_dim = 40
  steps_per_epoch = n_train_examples // batch_size

  # beta_y schedule, one value per training step: constant 0.1 for the first
  # `warmup_epochs`, then a linear ramp from 0.1 to 1.0 over `ramp_epochs`.
  # NOTE(review): the schedule spans 1400 epochs while n_steps spans
  # `n_epochs` (2250); confirm run_training handles steps past the end of
  # beta_y_evo (e.g. by holding the last value).
  warmup_epochs = 500
  ramp_epochs = 900
  beta_y = np.concatenate((
      np.full(warmup_epochs * steps_per_epoch, 0.1),
      np.linspace(0.1, 1.0, ramp_epochs * steps_per_epoch)))

  n_epochs = 2250
  report_every_epochs = 150

  training.run_training(
      dataset='textures',
      dataset_params={
          'batch_size': batch_size,
          'test_batch_size': batch_size,
          'train_every': 2,
          'test_every': 1,
          'crop_dim': crop_dim,
          'path': '../datasets/fakelabeled_natural_commonfiltered_640000_40px.pkl',
          'offset': 0.0},
      output_type='normal',
      output_sd=0.4,
      n_y=250,
      n_y_samples=1,
      n_y_samples_reconstr=1,
      beta_y_evo=beta_y,
      n_z=1800,
      beta_z_evo=1.0,
      lr_init=5e-5,
      lr_factor=1.,
      lr_schedule=[1],
      n_steps=steps_per_epoch * n_epochs,
      report_interval=steps_per_epoch * report_every_epochs,
      random_seed=None,
      encoder_kwargs={
          'encoder_type': 'mlp',
          'n_enc': [2000],
          'enc_strides': [1]
      },
      cluster_encoder_kwargs={
          'encoder_type': 'mlp',
          'n_enc': [1000, 500, 250]
      },
      latent_y_to_concat_encoder_kwargs={
          'y_to_concat_encoder_type': 'mlp',
          'y_to_concat_n_enc': [250, 500, 1000, 2000]
      },
      latent_concat_to_z_encoder_kwargs={
          'concat_to_z_encoder_type': 'mlp',
          'concat_to_z_n_enc': [2000]
      },
      latent_decoder_kwargs={
          'decoder_type': 'mlp',
          'n_dec': [2000]
      },
      decoder_kwargs={
          'decoder_type': 'mlp',
          'n_dec': [],
          'dec_up_strides': None
      },
      z1_distr_kwargs={
          'distr': 'normal',
          'sigma_nonlin': 'exp',
          'sigma_param': 'var'
      },
      z2_distr_kwargs={
          'distr': 'normal',
          'sigma_nonlin': 'exp',
          'sigma_param': 'var'
      },
      # L2 regularization disabled.
      l2_lambda_w=0.0,
      l2_lambda_b=0.0,
      gradskip_threshold=1e10,
      gradclip_threshold=1e9,
      save_dir='log_TDVAE40n',
      restore_from=None,
      tb_dir=None,
      activation=tf.math.softplus
  )


if __name__ == '__main__':
  # Script entry point: run training with no extra command-line arguments.
  main(unused_argv=[])
