import datetime
import time

from absl import app
from absl import flags
from absl import logging
import tensorflow as tf

import clevrtex_2.data as data_utils
import clevrtex_2.encoder_resnet as model_utils_encoder
import clevrtex_2.decoder_resnet as model_utils_decoder
import clevrtex_2.utils as utils
import numpy as np

FLAGS = flags.FLAGS

# Command-line hyperparameters. The defaults reproduce our experiments.
# Lower bounds reject values that would break the training loop at parse time
# instead of mid-run (e.g. decay_steps=0 divides by zero in the learning-rate
# schedule and yields an inf/nan learning rate).
flags.DEFINE_string(
    "model_dir", "model_dir",
    "Prefix for the checkpoint directories; encoder and decoder checkpoints "
    "are saved to <model_dir>_enc and <model_dir>_dec.")
flags.DEFINE_integer("seed", 51, "Random seed.")
flags.DEFINE_integer("batch_size", 32, "Batch size for the model.",
                     lower_bound=1)
flags.DEFINE_integer("num_slots", 11, "Number of slots in Slot Attention.",
                     lower_bound=1)
flags.DEFINE_integer("num_iterations", 3, "Number of attention iterations.")
flags.DEFINE_float("learning_rate", 0.0002, "Learning rate.", lower_bound=0.0)
flags.DEFINE_integer("num_train_steps", 250000, "Number of training steps.",
                     lower_bound=0)
flags.DEFINE_integer("warmup_steps", 50000,
                     "Number of warmup steps for the learning rate.",
                     lower_bound=0)
flags.DEFINE_float("decay_rate", 0.5, "Rate for the learning rate decay.",
                   lower_bound=0.0)
flags.DEFINE_integer("decay_steps", 100000,
                     "Number of steps for the learning rate decay.",
                     lower_bound=1)

@tf.function
def train_step(batch, encoder, decoder, optimizer):
  """Runs one optimization step on the encoder/decoder pair.

  The encoder's three outputs are passed to the decoder as a list. Only the
  decoder's first output (the combined reconstruction) enters the loss, which
  is `utils.l2_loss` against `batch["image_rgbsv"]`.

  NOTE(review): an earlier comment said only the RGB reconstruction is
  trained, but the target key is "image_rgbsv" — presumably RGB plus extra
  channels; confirm against the data pipeline.

  Args:
    batch: Mapping providing at least the "image" and "image_rgbsv" entries.
    encoder: Model whose trainable weights are updated.
    decoder: Model whose trainable weights are updated.
    optimizer: Optimizer that applies the computed gradients.

  Returns:
    The loss computed on this batch, before the update is applied.
  """
  with tf.GradientTape() as tape:
    enc_slots, enc_sp, enc_ss = encoder(batch["image"], training=True)
    # Masks, per-slot reconstructions and decoder slots are not used here.
    recon_combined, _, _, _ = decoder(
        [enc_slots, enc_sp, enc_ss], training=True)
    loss_value = utils.l2_loss(batch["image_rgbsv"], recon_combined)
  params = encoder.trainable_weights + decoder.trainable_weights
  grads = tape.gradient(loss_value, params)
  optimizer.apply_gradients(zip(grads, params))
  return loss_value

def _restore_or_initialize(checkpoint, manager):
  """Restores `checkpoint` from the newest checkpoint tracked by `manager`.

  Args:
    checkpoint: `tf.train.Checkpoint` to restore into.
    manager: `tf.train.CheckpointManager` that owns `checkpoint`.

  Returns:
    True if a checkpoint was found and restored, False if the tracked objects
    keep their freshly initialized values.
  """
  latest = manager.latest_checkpoint
  # `Checkpoint.restore(None)` does nothing, so this is safe on a fresh dir.
  checkpoint.restore(latest)
  if latest:
    logging.info("Restored from %s", latest)
    return True
  logging.info("Initializing from scratch")
  return False


def _learning_rate_for_step(global_step, base_learning_rate, warmup_steps,
                            decay_rate, decay_steps):
  """Computes the learning rate to use at `global_step`.

  Linear warmup from 0 to `base_learning_rate` over `warmup_steps`, multiplied
  by `decay_rate ** (global_step / decay_steps)`. The decay factor is applied
  from step 0, i.e. also during warmup. The arithmetic is done in tf.float32
  so results match the original inline schedule exactly.

  Returns:
    The learning rate as a NumPy scalar.
  """
  step = tf.cast(global_step, tf.float32)
  if global_step < warmup_steps:
    learning_rate = base_learning_rate * step / tf.cast(
        warmup_steps, tf.float32)
  else:
    learning_rate = base_learning_rate
  learning_rate = learning_rate * (
      decay_rate ** (step / tf.cast(decay_steps, tf.float32)))
  return learning_rate.numpy()


def _save_checkpoints(*managers):
  """Saves a checkpoint through each manager and logs the written path."""
  for manager in managers:
    saved_ckpt = manager.save()
    logging.info("Saved checkpoint: %s", saved_ckpt)


def main(argv):
  """Trains the encoder/decoder pair on the ClevrTex training split.

  Hyperparameters come from the command-line flags. Leave them at their
  defaults to reproduce our experiments. Checkpoints are written every 1000
  steps, plus once at the end if the final step is not a multiple of 1000.
  """
  del argv
  batch_size = FLAGS.batch_size
  num_slots = FLAGS.num_slots
  num_iterations = FLAGS.num_iterations
  base_learning_rate = FLAGS.learning_rate
  num_train_steps = FLAGS.num_train_steps
  warmup_steps = FLAGS.warmup_steps
  decay_rate = FLAGS.decay_rate
  decay_steps = FLAGS.decay_steps
  tf.random.set_seed(FLAGS.seed)
  resolution = (128, 128)

  # Load the ClevrTex training data.
  data_iterator = data_utils.build_clevrtex_iterator(
      batch_size, split="train")
  # One optimizer is shared by encoder and decoder. Its learning rate is
  # overwritten every step by the warmup/decay schedule.
  optimizer = tf.keras.optimizers.legacy.Adam(base_learning_rate, epsilon=1e-08)

  encoder = model_utils_encoder.build_model(
      resolution, batch_size, num_slots, num_iterations,
      model_type="object_discovery")
  # Decoder producing 5 output channels.
  decoder = model_utils_decoder.build_model(
      resolution, batch_size, num_slots, num_iterations, num_channels=5,
      model_type="object_discovery")

  # Resume if training was interrupted. Encoder and decoder are checkpointed
  # into separate directories; both checkpoints also track the shared
  # optimizer and the global step.
  global_step = tf.Variable(
      0, trainable=False, name="global_step", dtype=tf.int64)
  ckpt_enc = tf.train.Checkpoint(
      network=encoder, optimizer=optimizer, global_step=global_step)
  ckpt_manager_enc = tf.train.CheckpointManager(
      checkpoint=ckpt_enc, directory=FLAGS.model_dir + "_enc", max_to_keep=200)
  restored_enc = _restore_or_initialize(ckpt_enc, ckpt_manager_enc)

  ckpt_dec = tf.train.Checkpoint(
      network=decoder, optimizer=optimizer, global_step=global_step)
  ckpt_manager_dec = tf.train.CheckpointManager(
      checkpoint=ckpt_dec, directory=FLAGS.model_dir + "_dec", max_to_keep=200)
  restored_dec = _restore_or_initialize(ckpt_dec, ckpt_manager_dec)

  if restored_enc != restored_dec:
    # Otherwise one network would silently restart from scratch while the
    # step count (and hence the learning-rate schedule) resumes mid-run.
    logging.warning(
        "Only one of the encoder/decoder checkpoints was found; the other "
        "network starts from scratch at global step %d.",
        int(global_step.numpy()))

  start = time.time()
  remaining_steps = num_train_steps - int(global_step.numpy())
  # Train until num_train_steps is reached.
  for _ in range(remaining_steps):
    batch = next(data_iterator)
    optimizer.lr = _learning_rate_for_step(
        global_step, base_learning_rate, warmup_steps, decay_rate, decay_steps)
    loss_value = train_step(batch, encoder, decoder, optimizer)
    global_step.assign_add(1)
    if not global_step % 100:
      logging.info("Step: %s, Loss: %.6f, Time: %s",
                   global_step.numpy(), loss_value,
                   datetime.timedelta(seconds=time.time() - start))
    if not global_step % 1000:
      _save_checkpoints(ckpt_manager_enc, ckpt_manager_dec)

  # Persist the final state when the last step did not land on a save
  # boundary; otherwise up to 999 steps of training would be lost.
  if remaining_steps > 0 and int(global_step.numpy()) % 1000:
    _save_checkpoints(ckpt_manager_enc, ckpt_manager_dec)


if __name__ == "__main__":
  app.run(main)
