import datetime
import time
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
import clevrtex_2.data_inf as data_utils
import clevrtex_2.encoder_resnet as model_utils_encoder
import clevrtex_2.decoder_resnet as model_utils_decoder
import clevrtex_2.utils as utils
import numpy as np

# Command-line configuration. The training-related flags are declared here
# too so that the script accepts the same flag set as the training script.
FLAGS = flags.FLAGS

flags.DEFINE_string(
    name="model_dir",
    default="rgb_rgb_1",
    help="Where to save the checkpoints.",
)
flags.DEFINE_integer(name="seed", default=12, help="Random seed.")
flags.DEFINE_integer(
    name="batch_size", default=32, help="Batch size for the model.")
flags.DEFINE_integer(
    name="num_slots", default=11, help="Number of slots in Slot Attention.")
flags.DEFINE_integer(
    name="num_iterations", default=3, help="Number of attention iterations.")
flags.DEFINE_float(name="learning_rate", default=0.0002, help="Learning rate.")
flags.DEFINE_integer(
    name="num_train_steps", default=500000, help="Number of training steps.")
flags.DEFINE_integer(
    name="warmup_steps",
    default=10000,
    help="Number of warmup steps for the learning rate.",
)
flags.DEFINE_float(
    name="decay_rate", default=0.5, help="Rate for the learning rate decay.")
flags.DEFINE_integer(
    name="decay_steps",
    default=100000,
    help="Number of steps for the learning rate decay.",
)




def _restore_latest(network, optimizer, global_step, directory):
  """Restore the newest checkpoint found in `directory` into `network`.

  The checkpoint is built with the same three entries (`network`,
  `optimizer`, `global_step`) that the original inline code used, so the
  object graph matches whatever was saved under those names.

  Args:
    network: Keras model whose weights should be restored.
    optimizer: Optimizer instance registered in the checkpoint.
    global_step: Step counter variable registered in the checkpoint.
    directory: Directory searched for the latest checkpoint.
  """
  ckpt = tf.train.Checkpoint(
      network=network, optimizer=optimizer, global_step=global_step)
  manager = tf.train.CheckpointManager(
      checkpoint=ckpt, directory=directory, max_to_keep=200)
  latest = manager.latest_checkpoint
  status = ckpt.restore(latest)

  if latest:
    # No training step runs here, so optimizer state stored in the checkpoint
    # is never consumed; declare the partial restore as intentional.
    status.expect_partial()
    logging.info("Restored from %s", latest)
  else:
    # Inference then runs on freshly initialised weights, so make this
    # visible at WARNING level instead of INFO.
    logging.warning("Initializing from scratch")


def main(argv):
  """Run encoder/decoder inference on the test split and save the results.

  Reads the model configuration from the command-line flags, restores the
  encoder from `<model_dir>_enc` and the decoder from `<model_dir>_dec`,
  processes a fixed number of test batches and writes the collected arrays
  to `.npy` files in the current working directory.

  Args:
    argv: Positional command-line arguments passed by `app.run`; unused.
  """
  del argv
  # Only the values needed to re-instantiate the models and the checkpoint
  # objects are read; the remaining training flags are not used at inference.
  batch_size = FLAGS.batch_size
  num_slots = FLAGS.num_slots
  num_iterations = FLAGS.num_iterations
  base_learning_rate = FLAGS.learning_rate
  model_dir = FLAGS.model_dir
  tf.random.set_seed(FLAGS.seed)
  resolution = (128, 128)
  # batch_size * num_batches samples are collected (4992 with the default
  # batch size of 32, i.e. roughly 5k).
  num_batches = 156

  # Load the test dataset.
  data_iterator = data_utils.build_clevrtex_iterator(
      batch_size, split="testing")
  # The optimizer is never stepped; it is created only because the
  # checkpoints are built with an `optimizer` entry.
  optimizer = tf.keras.optimizers.legacy.Adam(base_learning_rate, epsilon=1e-08)

  # Define the slot attention encoder and decoder.
  encoder = model_utils_encoder.build_model(
      resolution, batch_size, num_slots, num_iterations,
      model_type="object_discovery")
  decoder = model_utils_decoder.build_model(
      resolution, batch_size, num_slots, num_iterations,
      num_channels=5, model_type="object_discovery")
  global_step = tf.Variable(
      0, trainable=False, name="global_step", dtype=tf.int64)

  # Load the encoder weights, then the decoder weights.
  _restore_latest(encoder, optimizer, global_step, model_dir + "_enc")
  _restore_latest(decoder, optimizer, global_step, model_dir + "_dec")

  # Run inference batch by batch and collect inputs, targets and outputs.
  all_pred_images = []
  all_gt_masks = []
  all_pred_masks = []
  all_slots = []
  all_gt_images = []

  for _ in range(num_batches):
    batch = next(data_iterator)
    # The encoder returns the slots plus two auxiliary outputs that are fed
    # unchanged into the decoder.
    slots, s_p, s_s = encoder.predict(batch["image"])
    predictions, _recons, masks, slots = decoder.predict([slots, s_p, s_s])
    all_gt_images.append(np.array(batch["image"]))
    all_gt_masks.append(np.array(batch["mask"]))
    all_pred_masks.append(np.array(masks))
    all_pred_images.append(np.array(predictions))
    all_slots.append(np.array(slots))

  # Save the samples to disk with numpy; each file stacks the per-batch
  # arrays along a new leading axis. They can be used later for the
  # visualisation in "evaluate.py".
  np.save("all_pred_images.npy", np.array(all_pred_images))
  np.save("all_gt_masks.npy", np.array(all_gt_masks))
  np.save("all_pred_masks.npy", np.array(all_pred_masks))
  np.save("all_slots.npy", np.array(all_slots))
  np.save("all_gt_images.npy", np.array(all_gt_images))

if __name__ == "__main__":
  # Let absl parse the command-line flags before invoking main.
  app.run(main)
