# coding=utf-8
# Copyright 2023.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""

Uncertainty-Aware Information Bottleneck Wide ResNet 28-10 on CIFAR-10/100.

Hyperparameters differ slightly from the original paper's code
(https://github.com/szagoruyko/wide-residual-networks) as TensorFlow uses, for
example, l2 instead of weight decay, and a different parameterization for SGD's
momentum.
"""

import os
import time

from absl import app
from absl import flags
import tensorflow as tf
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
from absl import logging
import robustness_metrics as rm

import tensorflow_datasets as tfds
import datasets

import models
import utils.ood_utils as ood_utils
import utils.calibration_utils as calibration_utils
import utils.uaib_utils as uaib_utils
import utils.run_utils as run_utils
import utils.schedules as schedules
from tensorboard.plugins.hparams import api as hp
import tensorflow_probability as tfp

FLAGS = flags.FLAGS


def main(argv):
    """Train and evaluate the uncertainty-aware IB Wide ResNet 28-10.

    Configures logging, the summary writer, the distribution strategy and the
    train / test / OOD input pipelines, then builds the model and runs the
    per-epoch train, cluster and evaluation loop. All settings come from
    ``FLAGS``.

    Args:
        argv: Positional command-line arguments passed by ``app.run``; unused.
    """
    del argv  # unused arg

    ####################     set-up directories      ####################

    fmt = "[%(filename)s:%(lineno)s] %(message)s"
    logging.get_absl_handler().setFormatter(logging.PythonFormatter(fmt))

    # Outputs (summaries) are written next to this script.
    dir_path = os.path.dirname(os.path.realpath(__file__))
    tf.io.gfile.makedirs(dir_path)
    logging.info("Saving checkpoints at %s", dir_path)
    tf.random.set_seed(FLAGS.seed)

    summary_dir = os.path.join(
        dir_path,
        f'summaries_uaib/{FLAGS.uaib_dim}_{FLAGS.codebook_size}_seed_{FLAGS.seed}',
    )
    summary_writer = tf.summary.create_file_writer(summary_dir)

    ##################     set-up distributed training      ##################

    data_dir = FLAGS.data_dir
    if FLAGS.use_gpu:
        logging.info("Use GPU")
        strategy = tf.distribute.MirroredStrategy()
    else:
        # TODO: check TPU
        logging.info("Use TPU at %s", FLAGS.tpu if FLAGS.tpu is not None else "local")
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.TPUStrategy(resolver)

    ##################     set-up datasets                      ##################

    ds_info = tfds.builder(FLAGS.dataset).info
    batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
    logging.info("Size of the dataset %s", ds_info.splits["train"].num_examples)
    logging.info("Train proportion %s", FLAGS.train_proportion)

    # Note that stateless_{fold_in,split} may incur a performance cost, but a
    # quick side-by-side test seemed to imply this was minimal.
    # seeds[0] feeds the training input pipeline, seeds[1] the model.
    seeds = tf.random.experimental.stateless_split([FLAGS.seed, FLAGS.seed + 1], 2)[:, 0]
    train_builder = datasets.get(
        FLAGS.dataset,
        data_dir=data_dir,
        download_data=FLAGS.download_data,
        split=tfds.Split.TRAIN,
        seed=seeds[0],
        shuffle_buffer_size=FLAGS.shuffle_buffer_size,
        validation_percent=0.0,
    )
    train_dataset = train_builder.load(batch_size=batch_size)

    clean_test_builder = datasets.get(
        FLAGS.dataset,
        split=tfds.Split.TEST,
        data_dir=data_dir,
        drop_remainder=FLAGS.drop_remainder_for_eval,
    )
    clean_test_dataset = clean_test_builder.load(batch_size=batch_size)
    test_datasets = {"clean": strategy.experimental_distribute_dataset(clean_test_dataset),}

    train_dataset = strategy.experimental_distribute_dataset(train_dataset)

    # Step counts come from the builders that actually feed the loops; they
    # are computed once here so the logged value is the one that is used.
    steps_per_epoch = train_builder.num_examples // batch_size
    steps_per_eval = clean_test_builder.num_examples // batch_size
    num_classes = 100 if FLAGS.dataset == "cifar100" else 10
    logging.info("Steps per epoch %s", steps_per_epoch)

    if FLAGS.eval_on_ood:
        ood_dataset_names = FLAGS.ood_dataset
        ood_ds, steps_per_ood = ood_utils.load_ood_datasets(
            ood_dataset_names,
            clean_test_builder,
            1.0 - FLAGS.train_proportion,
            batch_size,
            drop_remainder=FLAGS.drop_remainder_for_eval,
        )
        ood_datasets = {
            name: strategy.experimental_distribute_dataset(ds)
            for name, ds in ood_ds.items()
        }

    
    
    with strategy.scope():
        # Everything that owns variables (model, optimizers, metrics) is
        # created inside the strategy scope.

        ##################     build model     ##################

        logging.info("Building ResNet model")
        hyperparameters = run_utils._extract_hyperparameter_dictionary()
        model = models.wide_resnet_uaib(
            input_shape=(32, 32, 3),
            depth=28,
            width_multiplier=10,
            num_classes=num_classes,
            uaib_dim=FLAGS.uaib_dim,
            codebook_size=FLAGS.codebook_size,
            uaib_tau=FLAGS.uaib_tau,
            l2=FLAGS.l2,
            hps=hyperparameters,
            seed=seeds[1],
        )
        for description, value in (
            ("Model input shape: %s", model.input_shape),
            ("Model output shape: %s", model.output_shape),
            ("Model number of weights: %s", model.count_params()),
        ):
            logging.info(description, value)

        ##################     set-up optimizers    ##################

        # The learning rate scales linearly with the batch size (reference
        # batch size 128); decay epochs are given for a 200-epoch run and are
        # rescaled to the configured number of epochs.
        base_lr = FLAGS.base_learning_rate * batch_size / 128
        train_epochs = FLAGS.train_epochs
        lr_decay_epochs = [
            (int(epoch_str) * train_epochs) // 200
            for epoch_str in FLAGS.lr_decay_epochs
        ]

        # Encoder / decoder: SGD with Nesterov momentum on a warm-up,
        # piecewise-constant schedule.
        lr_schedule = schedules.WarmUpPiecewiseConstantSchedule(
            steps_per_epoch,
            base_lr,
            decay_ratio=FLAGS.lr_decay_ratio,
            decay_epochs=lr_decay_epochs,
            warmup_epochs=FLAGS.lr_warmup_epochs,
        )
        momentum = 1.0 - FLAGS.one_minus_momentum
        optimizer = tf.keras.optimizers.SGD(lr_schedule, momentum=momentum, nesterov=True)

        # Centroids: Adam with a fixed learning rate.
        optimizer_c = tf.keras.optimizers.Adam(FLAGS.cluster_base_learning_rate)

        ##################     create metrics    ##################

        metrics = {}
        metrics["train/negative_log_likelihood"] = tf.keras.metrics.Mean()
        metrics["train/accuracy"] = tf.keras.metrics.SparseCategoricalAccuracy()
        metrics["train/loss"] = tf.keras.metrics.Mean()
        metrics["train/ece"] = rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins)
        metrics["test/negative_log_likelihood"] = tf.keras.metrics.Mean()
        metrics["test/accuracy"] = tf.keras.metrics.SparseCategoricalAccuracy()
        metrics["test/ece"] = rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins)

        # Cluster metrics live in their own dict; one group per enabled
        # labelling (ground-truth labels and/or predicted labels).
        cluster_metrics = {}
        for enabled, label_kind in (
            (FLAGS.eval_clusters_true_label, "true_label"),
            (FLAGS.eval_clusters_predicted_label, "predicted_label"),
        ):
            if enabled:
                cluster_metrics.update(
                    uaib_utils.create_cluster_metrics(FLAGS.codebook_size, num_classes, label_kind)
                )

        if FLAGS.eval_on_ood:
            # Scores used to separate in- from out-of-distribution inputs.
            ood_scores = {'cluster_distance', 'onempmax', 'entropy'}
            ood_metrics = ood_utils.create_ood_metrics(ood_dataset_names, ood_scores)
            metrics.update(ood_metrics)

        # TODO: support method-specific metrics
        if FLAGS.eval_calibration:
            uncertainty_scores = {'cluster_distance'}
            calibration_metrics = calibration_utils.create_calibration_metrics(
                uncertainty_scores, FLAGS.calibration_num_buckets
            )
            metrics.update(calibration_metrics)

        # TODO: save / restore checkpoints (needed for eval-only runs).
        # Sketch of the intended logic, kept from the original draft:
        #   checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
        #   latest_checkpoint = tf.train.latest_checkpoint(dir_path)
        #   if latest_checkpoint:
        #     # checkpoint.restore must be within a strategy.scope() so that
        #     # optimizer slot variables are mirrored.
        #     checkpoint.restore(latest_checkpoint)
        #     logging.info('Loaded checkpoint %s', latest_checkpoint)
        #     initial_epoch = optimizer.iterations.numpy() // steps_per_epoch
        #   if FLAGS.saved_model_dir:
        #     logging.info('Saved model dir : %s', FLAGS.saved_model_dir)
        #     latest_checkpoint = tf.train.latest_checkpoint(FLAGS.saved_model_dir)
        #     checkpoint.restore(latest_checkpoint)
        #     logging.info('Loaded checkpoint %s', latest_checkpoint)
            
    ######################################################
    
    ##################          learning algorithm        ################## 
    
    ######################################################

    @tf.function
    def cluster_step(iterator):
        """Run one clustering epoch over the training data.

        Two passes of ``steps_per_epoch`` batches each are drawn from
        ``iterator`` (so one call consumes ``2 * steps_per_epoch`` batches):

        1. Centroid pass: variables whose name contains ``'centroid'`` are
           updated with ``optimizer_c`` to minimise the mean of the model's
           last output column. Per the original author's notes the centroid
           means are learned by gradient descent while the (full rank)
           covariances are computed analytically -- presumably inside the
           ``dense_uaib`` layer; confirm there.
        2. Assignment pass: forward passes in training mode, during which the
           layer is expected to update its centroid probabilities (by Bayes
           rule, per the original notes), while the closest centroid of every
           example is recorded.

        Args:
            iterator: Iterator over the distributed training dataset.

        Returns:
            The result of ``uaib_utils.cluster_majority_class`` applied to the
            recorded cluster assignments and labels.
        """

        def centroid_step_fn(inputs):
            """Per-replica gradient step on the centroid variables only."""
            with tf.GradientTape() as tape:
                outputs = model(inputs["features"], training=True)
                # The last output column is the uncertainty term.
                _, uncertainty = tf.split(outputs, [num_classes, 1], axis=-1)
                # Scale the loss given the TPUStrategy will reduce sum all gradients.
                scaled_loss = tf.reduce_mean(uncertainty) / strategy.num_replicas_in_sync

            variables = model.trainable_variables
            gradients = tape.gradient(scaled_loss, variables)
            centroid_updates = [
                (gradient, variable)
                for gradient, variable in zip(gradients, variables)
                if 'centroid' in variable.name
            ]
            optimizer_c.apply_gradients(centroid_updates)

        def centroid_probs_step_fn(inputs):
            """Per-replica forward pass returning (closest centroid, label).

            The forward pass runs in training mode; the assignment
            probabilities are updated batch-wise within the layer.
            """
            model(inputs["features"], training=True)

            # The layer exposes the per-centroid distances through model.losses.
            cluster_distances = [
                loss_tensor for loss_tensor in model.losses
                if 'cluster_distances' in loss_tensor.name
            ][0]
            closest = tf.math.argmin(cluster_distances, axis=-1)
            return tf.cast(closest, tf.int32), tf.cast(inputs["labels"], tf.int32)

        uaib_layer = model.get_layer('dense_uaib')
        num_steps = tf.cast(steps_per_epoch, tf.int32)

        # Flag the centroids as initialized before they are trained.
        initialized_centroids_op = uaib_layer.initialized.assign(tf.constant(True, dtype=tf.bool))
        uaib_layer.add_update(initialized_centroids_op)

        # M-step: train centroids
        uaib_layer.reset_centroids()
        for _ in tf.range(num_steps):
            strategy.run(centroid_step_fn, args=(next(iterator),))
        uaib_layer.set_centroids()

        # M-step: compute prior centroid probabilities
        labels_all_ls = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
        clusters_all_ls = tf.TensorArray(tf.int32, size=0, dynamic_size=True)

        uaib_layer.reset_centroid_probs()
        for step in tf.range(num_steps):
            batch_clusters, batch_labels = strategy.run(
                centroid_probs_step_fn, args=(next(iterator),)
            )

            # With several replicas the results are per-replica containers;
            # merge them into one batch.
            if strategy.num_replicas_in_sync > 1:
                batch_clusters = tf.concat(batch_clusters.values, axis=0)
                batch_labels = tf.concat(batch_labels.values, axis=0)

            clusters_all_ls = clusters_all_ls.write(step, batch_clusters)
            labels_all_ls = labels_all_ls.write(step, batch_labels)
        uaib_layer.set_centroid_probs()

        # Flatten (steps, batch) into one vector per quantity.
        flat_clusters = tf.reshape(clusters_all_ls.stack(), [-1])
        flat_labels = tf.reshape(labels_all_ls.stack(), [-1])

        return uaib_utils.cluster_majority_class(flat_clusters, flat_labels, FLAGS.codebook_size)
        



    @tf.function
    def train_step(iterator):
        """Run one epoch of encoder / decoder training.

        For ``steps_per_epoch`` batches, every variable whose name does not
        contain ``'centroid'`` is updated with ``optimizer`` on the loss

            NLL + regularization losses + FLAGS.beta * mean(uncertainty)

        where the uncertainty is the model's last output column. The original
        author's notes describe this as training the encoder p(z|x) and the
        decoder p(y|z), with an E-step computing the conditional centroid
        probabilities needed for the regularization term -- presumably
        performed inside the model's forward pass; confirm in the layer.

        Args:
            iterator: Iterator over the distributed training dataset.
        """

        def step_fn(inputs):
            """Per-Replica StepFn."""
            features = inputs["features"]
            labels = inputs["labels"]

            with tf.GradientTape() as tape:
                outputs = model(features, training=True)
                logits, uncertainty = tf.split(outputs, [num_classes, 1], axis=-1)

                per_example_nll = tf.keras.losses.sparse_categorical_crossentropy(
                    labels, logits, from_logits=True
                )
                negative_log_likelihood = tf.reduce_mean(per_example_nll)

                # model.losses also carries the layer's 'cluster_distances'
                # tensor, which is not a regularizer and is left out here.
                l2_loss = sum(
                    [
                        loss_tensor for loss_tensor in model.losses
                        if "cluster_distances" not in loss_tensor.name
                    ]
                )

                loss = negative_log_likelihood + l2_loss + FLAGS.beta * tf.reduce_mean(uncertainty)
                # Scale the loss given the TPUStrategy will reduce sum all gradients.
                scaled_loss = loss / strategy.num_replicas_in_sync

            variables = model.trainable_variables
            gradients = tape.gradient(scaled_loss, variables)

            # Centroid variables are trained separately in cluster_step.
            network_updates = [
                (gradient, variable)
                for gradient, variable in zip(gradients, variables)
                if 'centroid' not in variable.name
            ]
            optimizer.apply_gradients(network_updates)

            probs = tf.nn.softmax(logits)
            metrics["train/ece"].add_batch(probs, label=labels)
            metrics["train/loss"].update_state(loss)
            metrics["train/negative_log_likelihood"].update_state(negative_log_likelihood)
            metrics["train/accuracy"].update_state(labels, logits)

        for _ in tf.range(tf.cast(steps_per_epoch, tf.int32)):
            strategy.run(step_fn, args=(next(iterator),))
            
     
            

            
    ######################################################
    
    ##################          evaluation methods        ################## 
    
    ######################################################

    @tf.function(experimental_autograph_options=tf.autograph.experimental.Feature.LISTS)
    def test_ood_step(iterator, dataset_split, dataset_name, num_steps):
        """Compute out-of-distribution detection scores over a dataset.

        Args:
            iterator: Iterator over a distributed dataset whose batches carry
                "features" and "is_in_distribution".
            dataset_split: Unused; kept so existing call sites keep working.
            dataset_name: Unused; kept so existing call sites keep working.
            num_steps: Number of batches to draw from ``iterator``.

        Returns:
            Dict with key 'ood_labels' (1 - is_in_distribution) plus one entry
            per name in ``ood_scores``; each value is the concatenation of the
            per-batch results over all steps and replicas.
        """

        def step_fn(inputs):
            """Per-Replica StepFn."""
            images = inputs["features"]

            outputs = model(images, training=False)
            logits, cluster_distance = tf.split(outputs, [num_classes, 1], axis=-1)

            ood_scores_dict = {}
            ood_scores_dict['ood_labels'] = 1 - inputs["is_in_distribution"]

            # The entropy and max-probability reductions already return one
            # value per example. They are deliberately not passed through an
            # axis-less tf.squeeze: that would also drop the batch axis when
            # the per-replica batch holds a single example, producing scalars
            # that tf.concat / TensorArray.concat cannot join.
            if 'entropy' in ood_scores:
                pred_dist = tfp.distributions.Categorical(logits=logits)
                ood_scores_dict['entropy'] = tf.cast(pred_dist.entropy(), tf.float32)

            if 'onempmax' in ood_scores:
                probs = tf.nn.softmax(logits)
                ood_scores_dict['onempmax'] = tf.cast(
                    1 - tf.reduce_max(probs, axis=-1), tf.float32
                )

            if 'cluster_distance' in ood_scores:
                # NOTE(review): keeps the size-1 trailing axis left by
                # tf.split, unlike test_step which squeezes it -- confirm the
                # consumer (ood_utils.eval_on_ood) accepts this shape.
                ood_scores_dict['cluster_distance'] = cluster_distance

            return ood_scores_dict

        # One fixed-size TensorArray per returned quantity, indexed by step.
        all_return_values_dict = {}
        all_return_values_dict['ood_labels'] = tf.TensorArray(tf.int32, size=num_steps, dynamic_size=False)
        for name in ood_scores:
            all_return_values_dict[name] = tf.TensorArray(tf.float32, size=num_steps, dynamic_size=False)

        for i in tf.range(tf.cast(num_steps, tf.int32)):
            batch_return_values_dict = strategy.run(step_fn, args=(next(iterator),))

            for name, val in batch_return_values_dict.items():
                # With several replicas the result is a per-replica container;
                # merge it into a single batch.
                if strategy.num_replicas_in_sync > 1:
                    all_values = tf.concat(val.values, axis=0)
                else:
                    all_values = val

                all_return_values_dict[name] = all_return_values_dict[name].write(i, all_values)

        # Join the per-step batches along the first axis.
        for name, array in all_return_values_dict.items():
            all_return_values_dict[name] = array.concat()

        return all_return_values_dict


    @tf.function(experimental_autograph_options=tf.autograph.experimental.Feature.LISTS)
    def test_step(iterator, dataset_split, dataset_name, num_steps,cluster_classes=None):
        """Evaluate the model on labelled data and collect per-example outputs.

        Updates ``metrics[f"{dataset_split}/negative_log_likelihood"]``,
        ``.../accuracy`` and ``.../ece`` for every batch.

        Args:
            iterator: Iterator over a distributed dataset whose batches carry
                "features" and "labels".
            dataset_split: Prefix of the metric names to update (e.g. "test").
            dataset_name: Unused; kept so existing call sites keep working.
            num_steps: Number of batches to draw from ``iterator``.
            cluster_classes: Unused; kept so existing call sites keep working.

        Returns:
            Dict of tensors concatenated over all steps and replicas. Present
            keys depend on the flags:
              * FLAGS.eval_calibration: 'matches' and every name in
                ``uncertainty_scores``;
              * FLAGS.eval_clusters_true_label / _predicted_label: 'clusters'
                plus 'labels' and / or 'predicted_labels'.
        """

        def step_fn(inputs):
            """Per-Replica StepFn."""
            images = inputs["features"]
            labels = inputs["labels"]
            outputs = model(images, training=False)

            logits, cluster_distance = tf.split(outputs, [num_classes, 1], axis=-1)

            # Drop only the size-1 trailing axis left by tf.split. An axis-less
            # squeeze would also drop the batch axis when the per-replica
            # batch holds a single example.
            cluster_distance = tf.squeeze(cluster_distance, axis=-1)

            # update accuracy metrics
            probs = tf.nn.softmax(logits)
            negative_log_likelihood = tf.reduce_mean(
                tf.keras.losses.sparse_categorical_crossentropy(labels, probs)
            )
            metrics[f"{dataset_split}/negative_log_likelihood"].update_state(negative_log_likelihood)
            metrics[f"{dataset_split}/accuracy"].update_state(labels, probs)
            metrics[f"{dataset_split}/ece"].add_batch(probs, label=labels)

            return_values_dict = {}

            # prepare outputs for calibration evaluation
            if FLAGS.eval_calibration:
                # The cast pins the dtype of the int32 TensorArray that
                # collects these values below.
                return_values_dict['matches'] = tf.cast(
                    calibration_utils.sparse_categorical_matches(labels, logits), tf.int32
                )
                if 'cluster_distance' in uncertainty_scores:
                    return_values_dict['cluster_distance'] = cluster_distance

            # prepare outputs for cluster evaluation
            if FLAGS.eval_clusters_true_label or FLAGS.eval_clusters_predicted_label:
                # The layer exposes the per-centroid distances through
                # model.losses; the closest centroid is the assigned cluster.
                cluster_distances = [l for l in model.losses if 'cluster_distances' in l.name][0]
                clusters = tf.math.argmin(cluster_distances, axis=-1)
                return_values_dict['clusters'] = tf.cast(clusters, tf.int32)

                if FLAGS.eval_clusters_true_label:
                    return_values_dict['labels'] = tf.cast(labels, tf.int32)

                if FLAGS.eval_clusters_predicted_label:
                    predicted_labels = tf.math.argmax(logits, axis=-1)
                    return_values_dict['predicted_labels'] = tf.cast(predicted_labels, tf.int32)

            return return_values_dict

        # One fixed-size TensorArray per returned quantity, indexed by step.
        # The keys must mirror exactly what step_fn returns for the same flags.
        all_return_values_dict = {}

        if FLAGS.eval_calibration:
            all_return_values_dict['matches'] = tf.TensorArray(tf.int32, size=num_steps, dynamic_size=False)
            for name in uncertainty_scores:
                all_return_values_dict[name] = tf.TensorArray(tf.float32, size=num_steps, dynamic_size=False)

        if FLAGS.eval_clusters_true_label or FLAGS.eval_clusters_predicted_label:
            all_return_values_dict['clusters'] = tf.TensorArray(tf.int32, size=num_steps, dynamic_size=False)
            if FLAGS.eval_clusters_true_label:
                all_return_values_dict['labels'] = tf.TensorArray(tf.int32, size=num_steps, dynamic_size=False)
            if FLAGS.eval_clusters_predicted_label:
                all_return_values_dict['predicted_labels'] = tf.TensorArray(tf.int32, size=num_steps, dynamic_size=False)

        for i in tf.range(tf.cast(num_steps, tf.int32)):
            batch_return_values_dict = strategy.run(step_fn, args=(next(iterator),))

            for name, val in batch_return_values_dict.items():
                # With several replicas the result is a per-replica container;
                # merge it into a single batch.
                if strategy.num_replicas_in_sync > 1:
                    all_values = tf.concat(val.values, axis=0)
                else:
                    all_values = val

                all_return_values_dict[name] = all_return_values_dict[name].write(i, all_values)

        # Join the per-step batches along the first axis.
        for name, array in all_return_values_dict.items():
            all_return_values_dict[name] = array.concat()

        return all_return_values_dict


# =============================================================================
        
       
    metrics.update({"test/ms_per_example": tf.keras.metrics.Mean()})
    metrics.update({"train/ms_per_example": tf.keras.metrics.Mean()})

    train_iterator = iter(train_dataset)
    start_time = time.time()

    # Result of the most recent clustering pass. It has to exist before the
    # loop: in eval-only mode no clustering pass runs, yet test_step() below
    # is still called with this value.
    majority_classes = None

    ######################################################

    ##################          training loop                 ##################

    ######################################################

    for epoch in range(FLAGS.train_epochs):

        logging.info("Starting to run epoch: %s", epoch)
        if not FLAGS.eval_only:
            train_start_time = time.time()

            # train encoder/ decoder
            train_step(train_iterator)

            # train centroids
            logging.info("Starting to cluster epoch: %s", epoch)
            majority_classes = cluster_step(train_iterator)

            # NOTE(review): this is the wall time of the whole epoch (training
            # plus clustering) in microseconds divided by the batch size, not
            # milliseconds per example -- left unchanged so values stay
            # comparable with earlier runs; confirm the intended unit.
            ms_per_example = (time.time() - train_start_time) * 1e6 / batch_size
            metrics["train/ms_per_example"].update_state(ms_per_example)

            # Progress / ETA report.
            current_step = (epoch + 1) * steps_per_epoch
            max_steps = steps_per_epoch * FLAGS.train_epochs
            time_elapsed = time.time() - start_time
            steps_per_sec = float(current_step) / time_elapsed
            eta_seconds = (max_steps - current_step) / steps_per_sec
            message = (
                "{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. "
                "ETA: {:.0f} min. Time elapsed: {:.0f} min".format(
                    current_step / max_steps,
                    epoch + 1,
                    FLAGS.train_epochs,
                    steps_per_sec,
                    eta_seconds / 60,
                    time_elapsed / 60,
                )
            )
            logging.info(message)

            logging.info(
                "Train Loss: %.4f, Accuracy: %.2f%%",
                metrics["train/loss"].result(),
                metrics["train/accuracy"].result() * 100,
            )

        # evaluate on in-distribution, test data
        dataset_name, test_dataset = "clean", test_datasets["clean"]
        test_iterator = iter(test_dataset)
        logging.info("Testing on dataset %s", dataset_name)
        logging.info("Starting to run eval at epoch!!!!!!1: %s", epoch)

        test_start_time = time.time()
        return_values = test_step(test_iterator, "test", dataset_name, steps_per_eval, majority_classes)
        ms_per_example = (time.time() - test_start_time) * 1e6 / batch_size
        metrics["test/ms_per_example"].update_state(ms_per_example)
        logging.info(
            "Test NLL: %.4f, Accuracy from logits: %.2f%%",
            metrics["test/negative_log_likelihood"].result(),
            metrics["test/accuracy"].result() * 100,
        )

        if FLAGS.eval_calibration:
            uncertainty_dict = {
                uncertainty_score: return_values[uncertainty_score]
                for uncertainty_score in uncertainty_scores
            }
            calibration_utils.eval_calibration(
                strategy, calibration_metrics, return_values['matches'], uncertainty_dict
            )

        if FLAGS.eval_clusters_true_label:
            uaib_utils.eval_clusters(
                strategy, cluster_metrics, epoch,
                return_values['clusters'], return_values['labels'], 'true_label',
            )
        if FLAGS.eval_clusters_predicted_label:
            uaib_utils.eval_clusters(
                strategy, cluster_metrics, epoch,
                return_values['clusters'], return_values['predicted_labels'], 'predicted_label',
            )

        logging.info("Done with testing on %s", dataset_name)

        # evaluate on out-of-distribution data
        if FLAGS.eval_on_ood:
            for ood_dataset_name, ood_dataset in ood_datasets.items():
                ood_iterator = iter(ood_dataset)

                return_values = test_ood_step(
                    ood_iterator, "test", ood_dataset_name, steps_per_ood[ood_dataset_name],
                )
                ood_dict = {ood_score: return_values[ood_score] for ood_score in ood_scores}
                ood_utils.eval_on_ood(
                    strategy, ood_metrics, return_values['ood_labels'], ood_dict, ood_dataset_name
                )

        # TODO: check eval only mode from restored weights
        # Eval-only runs evaluate once and stop before summaries are written.
        if FLAGS.eval_only:
            break

        # update metrics
        # Metrics from Robustness Metrics (like ECE) will return a dict with a
        # single key/value, instead of a scalar.
        total_results = {name: metric.result() for name, metric in metrics.items()}
        total_results = {
            k: (list(v.values())[0] if isinstance(v, dict) else v)
            for k, v in total_results.items()
        }
        with summary_writer.as_default():
            for name, result in total_results.items():
                tf.summary.scalar(name, result, step=epoch + 1)

        for metric in metrics.values():
            metric.reset_states()

        # TODO: periodic checkpointing, e.g.
        #   if (FLAGS.checkpoint_interval > 0 and
        #       (epoch + 1) % FLAGS.checkpoint_interval == 0):
        #     checkpoint_name = checkpoint.save(
        #         os.path.join(FLAGS.output_dir, 'checkpoint'))
        #     logging.info('Saved checkpoint to %s', checkpoint_name)

    # TODO: save a final checkpoint, e.g.
    #   final_checkpoint_name = checkpoint.save(
    #       os.path.join(FLAGS.output_dir, 'checkpoint'))
    #   logging.info('Saved last checkpoint to %s', final_checkpoint_name)

    # Record the main hyperparameters for the TensorBoard HParams plugin.
    with summary_writer.as_default():
        hp.hparams(
            {
                "base_learning_rate": FLAGS.base_learning_rate,
                "one_minus_momentum": FLAGS.one_minus_momentum,
                "l2": FLAGS.l2,
            }
        )


# Script entry point: hand control to absl, which calls main() with the
# remaining command-line arguments.
if __name__ == "__main__":
    app.run(main)
