import jax
import jax.numpy as jnp
import vae
import utils as vae_utils
import ml_collections
from absl import app
from tensorboardX import SummaryWriter
import time
import os
import tensorflow as tf

# Quiet TensorFlow: TF_CPP_MIN_LOG_LEVEL='2' drops native INFO and WARNING
# messages (errors and fatal messages still show), and the Python-side TF
# logger is raised to ERROR.
# NOTE(review): the native log level is typically read when TensorFlow's
# runtime first logs; since `tensorflow` is already imported above, output
# emitted during that import may not be suppressed. Setting the variable
# before `import tensorflow` would be more reliable — confirm.
os.environ.update(TF_CPP_MIN_LOG_LEVEL='2')
tf.get_logger().setLevel('ERROR')

# from absl import flags
# FLAGS = flags.FLAGS
# ml_collections.config_flags.DEFINE_config_file(
#     'config',
#     None,
#     'File path to the training hyperparameter configuration.',
#     lock_config=True,
# )
# The loaded configuration would then be available as FLAGS.config.

# Hyperparameter grid swept by `main`; every combination is trained once.
_SEEDS = (0, 1, 2, 3, 4)
_VARIANTS = ("soft", "hard")  # previously also considered: "softmax", "softapprox"
_SHARE_AXIS = (True, False)
# NOTE(review): an earlier comment listed group counts
# "S - 1 = inf (trivial), 2400, 1200, 240, 48, 6, 3, 2, 1 (pointwise)", which
# does not match the values swept here; confirm which grid is intended.
_NUM_GROUPS = (0, 1, 4, 800, 2400)


def _make_config(seed, variant, share_axis, num_groups, num_epochs=100):
    """Build a fresh training config for one point of the sweep.

    A new ConfigDict is created for every run so that nothing
    `vae.train_and_evaluate` might mutate or retain carries over into
    later runs.

    Args:
      seed: Value stored as `config.seed`.
      variant: Value stored as `config.variant` (the sweep uses "soft"/"hard").
      share_axis: Value stored as `config.share_axis`; also selects the
        latent size below.
      num_groups: Value stored as `config.num_groups`.
      num_epochs: Value stored as `config.num_epochs`.

    Returns:
      An ml_collections.ConfigDict with the fields this script sets.
    """
    config = ml_collections.ConfigDict()
    config.act_fn = "colu"
    config.dataset_name = 'binarized_mnist'
    config.learning_rate = 1e-3
    config.batch_size = 128
    config.variant = variant
    config.share_axis = share_axis
    config.num_groups = num_groups
    config.seed = seed
    config.num_epochs = num_epochs
    # With share_axis the latent size is fixed at 2401; otherwise it grows
    # with the number of groups.
    config.latents = 2401 if share_axis else 2400 + num_groups
    return config


def main(argv):
    """Train and evaluate the VAE over the full hyperparameter grid.

    Runs `vae.train_and_evaluate` once per combination of seed, variant,
    share_axis and num_groups (seed outermost, num_groups innermost). After
    each run it prints the value the run returned and its wall-clock time.
    Returns None so that `app.run` exits with status 0.
    """
    import itertools

    del argv  # Unused; the sweep is configured in code.

    grid = itertools.product(_SEEDS, _VARIANTS, _SHARE_AXIS, _NUM_GROUPS)
    for seed, variant, share_axis, num_groups in grid:
        config = _make_config(seed, variant, share_axis, num_groups)
        start = time.perf_counter()
        loss = vae.train_and_evaluate(config)
        elapsed = time.perf_counter() - start
        print(
            f"seed={seed} variant={variant} share_axis={share_axis} "
            f"num_groups={num_groups}: loss={loss} ({elapsed:.1f}s)",
            flush=True,
        )

    # Disabled alternative sweep from earlier experiments: 500 epochs over
    # seeds (0, 1, 13579), calling `train.train_and_evaluate(config)`.
    # NOTE(review): `train` is not imported above; presumably
    # `vae.train_and_evaluate` is meant if this is revived.

if __name__ == '__main__':
    # Hand control to absl, which parses command-line flags and then calls main(argv).
    app.run(main=main)