# train two VAEs with initializations up to a rotation
import os

# Disable TensorFlow's native (C++) logging below ERROR level.
# TF_CPP_MIN_LOG_LEVEL only affects messages emitted after it is set, so it must
# be assigned before TensorFlow is loaded. The imports below (tensorboardX and
# the project modules) may load TensorFlow indirectly -- TODO confirm -- so the
# variable is set before any of them.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Only show errors and fatal messages

import jax
import jax.numpy as jnp
import models
import train
import utils as vae_utils
import ml_collections
from absl import app
from absl import flags
from tensorboardX import SummaryWriter
import time
import tensorflow as tf

# The environment variable above only governs the C++ logger; TensorFlow's
# Python-side logger is configured separately.
tf.get_logger().setLevel('ERROR')  # Set TensorFlow logger to only show errors

# FLAGS = flags.FLAGS
# ml_collections.config_flags.DEFINE_config_file(
#     'config',
#     None,
#     'File path to the training hyperparameter configuration.',
#     lock_config=True,
# )
# should appear as FLAGS.config

import jax.numpy as jnp

def rotation_matrix(theta, n1, n2):
    """Return the matrix that rotates by ``theta`` inside the plane span(n1, n2).

    The result is
        I + sin(theta) * (n2 n1^T - n1 n2^T) + (cos(theta) - 1) * (n1 n1^T + n2 n2^T).
    Any vector orthogonal to both n1 and n2 is mapped to itself. The matrix is
    a proper rotation only when n1 and n2 have unit length, which is not checked.

    Args:
        theta: rotation angle in radians.
        n1: 1-D vector spanning the rotation plane.
        n2: 1-D vector of the same length, (numerically) orthogonal to n1.

    Returns:
        A (dim, dim) array with dim == len(n1).
    """
    assert len(n1) == len(n2)
    dim = len(n1)
    # Orthogonality tolerance: |n1 . n2| must be below 1e-4.
    assert jnp.abs(jnp.dot(n1, n2)) < 1e-4

    skew = jnp.outer(n2, n1) - jnp.outer(n1, n2)  # generator of the in-plane rotation
    proj = jnp.outer(n1, n1) + jnp.outer(n2, n2)  # projector onto span(n1, n2)
    return jnp.eye(dim) + skew * jnp.sin(theta) + proj * (jnp.cos(theta) - 1)

def rotate_vector(theta, n1, n2, vec):
    """Rotate ``vec`` by ``theta`` inside the plane span(n1, n2).

    Mathematically equal to ``rotation_matrix(theta, n1, n2) @ vec``, but only
    the projections of ``vec`` onto n1 and n2 are computed, so no (dim, dim)
    matrix is materialised.

    Args:
        theta: rotation angle in radians.
        n1: 1-D vector spanning the rotation plane.
        n2: 1-D vector of the same length, (numerically) orthogonal to n1.
        vec: 1-D vector of the same length to rotate.

    Returns:
        The rotated vector, same length as ``vec``.
    """
    for other in (n2, vec):
        assert len(n1) == len(other)
    assert jnp.abs(jnp.dot(n1, n2)) < 1e-4

    along_n1 = jnp.dot(n1, vec)
    along_n2 = jnp.dot(n2, vec)
    sin_t = jnp.sin(theta)
    cos_t_minus_one = jnp.cos(theta) - 1

    in_plane_turn = (n2 * along_n1 - n1 * along_n2) * sin_t
    in_plane_shrink = (n1 * along_n1 + n2 * along_n2) * cos_t_minus_one
    return vec + in_plane_turn + in_plane_shrink

# Module-level rotation: a half-degree (pi / 360 rad) turn in the plane spanned
# by the standard basis vectors e_1 and e_2 of a 500-dimensional space.
n1, n2 = jnp.eye(500)[1:3]
theta = jnp.pi / 360
p = rotation_matrix(theta, n1, n2)

def _make_config(seed, variant, share_axis, num_groups):
    """Build an independent training config for one point of the sweep."""
    config = ml_collections.ConfigDict()
    config.act_fn = "colu"
    # config.dataset_name = 'mnist'
    config.dataset_name = 'binarized_mnist'
    # config.model_name = "mnist"
    config.learning_rate = 1e-3
    config.batch_size = 128
    config.variant = variant
    config.share_axis = share_axis
    config.num_groups = num_groups
    config.seed = seed
    config.num_epochs = 100
    # Latent size depends on whether the group axis is shared; exact meaning
    # is defined by train/models -- TODO confirm.
    config.latents = 2401 if share_axis else 2400 + num_groups
    return config


def main(argv):
    """Run the VAE training sweep.

    Trains one model for every combination of seed, variant, share_axis and
    num_groups (5 * 2 * 2 * 5 = 100 runs) via ``train.train_and_evaluate``.
    Each run receives its own freshly built config, so nothing written into a
    config by one run can leak into the next.

    Args:
        argv: positional arguments left over after absl flag parsing; only
            the program name is expected.

    Raises:
        app.UsageError: if extra positional arguments are given.
    """
    import itertools

    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    seeds = [0, 1, 2, 3, 4]
    variants = ["soft", "hard"]  # other variants: "softmax", "softapprox"
    share_axis_options = [True, False]
    # S - 1 = inf (trivial), 2400, 1200, 240, 48, 6, 3, 2, 1 (pointwise)
    num_groups_options = [0, 1, 4, 800, 2400]

    # each 100-epoch run takes <3 min
    for seed, variant, share_axis, num_groups in itertools.product(
            seeds, variants, share_axis_options, num_groups_options):
        config = _make_config(seed, variant, share_axis, num_groups)
        train.train_and_evaluate(config)
    # NOTE: an earlier version of this script also ran 500-epoch trainings
    # for seeds 0, 1 and 13579.

if __name__ == '__main__':
    # Hand control to absl: it parses command-line flags, then calls main(argv).
    app.run(main)