from absl import app, flags
from ml_collections.config_flags import config_flags
import logging
import os
import io
import tensorflow as tf
import gc
import numpy as np
import evaluation
from eval_intra_FID import eval_intra_FID
from eval_FID import eval_FID
from eval_acc import eval_acc
from eval_PRDC import eval_PRDC
from eval_CAS import eval_CAS
from eval_calib_error import eval_calib_error
from tqdm import trange


# Shared absl flag registry; main() reads the parsed values from here.
FLAGS = flags.FLAGS

# Path to the ml_collections config file (locked, matching lock_config=True).
config_flags.DEFINE_config_file(
    name="config",
    default=None,
    help_string="Training configuration.",
    lock_config=True,
)
flags.DEFINE_string(
    name="eval_folder",
    default="eval",
    help="The folder name storing the evaluation results.",
)
flags.DEFINE_string(
    name="eval_metrics",
    default="intra-FID",
    help="Evaluation metrics separated using comma. (FID, intra-FID, Acc, PRDC, CAS, Calib-Error, or all)",
)
flags.DEFINE_bool(
    name="run_inception",
    default=False,
    help="Specify 'True' to run inception model.",
)
# NOTE(review): "eval_folder" already has a non-None default ("eval"), so the
# required check presumably only has an effect for "config" -- confirm.
flags.mark_flags_as_required(["config", "eval_folder"])


def _is_cifar10_dataset(dataset):
    """Return True if `dataset` is 'CIFAR10' or contains 'CIFAR10_SEMI'."""
    return dataset == 'CIFAR10' or 'CIFAR10_SEMI' in dataset


def _write_inception_statistics(config, eval_dir):
    """Run the inception model over every sample shard and save its latents.

    For each class index in ``range(config.data.num_classes)`` the shards in
    ``<eval_dir>/<class>/`` are looked up as ``sample_*.npz`` and, if none
    match, as ``samples_*.npz``. Shard ``r`` is read from its ``'samples'``
    array and the resulting ``pool_3`` and ``logits`` latents are written to
    ``<eval_dir>/<class>/statistics_<r>.npz``.

    Args:
        config: Experiment configuration; only ``config.data.num_classes``
            is read here.
        eval_dir: Root folder holding one sub-folder per class index.
    """
    inception_model = evaluation.get_inception_model(inceptionv3=False)
    print('Getting latent of inception model')
    for class_idx in trange(config.data.num_classes):
        class_dir = os.path.join(eval_dir, str(class_idx))

        # Two naming schemes are accepted; "sample_" takes precedence.
        prefix = "sample"
        shards = tf.io.gfile.glob(os.path.join(class_dir, "sample_*.npz"))
        if not shards:
            prefix = "samples"
            shards = tf.io.gfile.glob(os.path.join(class_dir, "samples_*.npz"))

        # The glob is only used to count shards: files are opened by index,
        # so they are expected to be numbered contiguously from 0.
        for r in range(len(shards)):
            # Context manager closes the underlying .npz file handle.
            with np.load(os.path.join(class_dir, f"{prefix}_{r}.npz")) as data:
                samples = data['samples']

            latents = evaluation.run_inception_distributed(
                samples, inception_model, inceptionv3=False)
            gc.collect()
            with tf.io.gfile.GFile(
                    os.path.join(class_dir, f"statistics_{r}.npz"), "wb") as fout:
                # Serialize into memory first, then write the bytes through GFile.
                io_buffer = io.BytesIO()
                np.savez_compressed(
                    io_buffer, pool_3=latents["pool_3"], logits=latents["logits"])
                fout.write(io_buffer.getvalue())


def main(argv):
    """Run the evaluation metrics selected via ``--eval_metrics``.

    Metric names are comma separated; surrounding whitespace is ignored and
    ``all`` selects every metric. When ``--run_inception`` is set and FID,
    intra-FID or PRDC is requested, inception statistics are (re)generated
    from the stored sample shards before any metric is evaluated.

    Args:
        argv: Positional command-line arguments passed in by ``app.run``;
            unused.

    Raises:
        ValueError: If 'Acc' or 'CAS' is requested for a dataset that is
            neither 'CIFAR10' nor contains 'CIFAR10_SEMI'.
    """
    del argv  # Unused.

    config = FLAGS.config
    eval_dir = FLAGS.eval_folder
    # Strip whitespace so that "FID, Acc" is treated like "FID,Acc".
    metrics = {name.strip() for name in FLAGS.eval_metrics.split(',')}
    run_all = 'all' in metrics

    def requested(name):
        return run_all or name in metrics

    needs_inception = (
        requested('FID') or requested('intra-FID') or requested('PRDC'))
    if FLAGS.run_inception and needs_inception:
        _write_inception_statistics(config, eval_dir)

    if requested('FID'):
        eval_FID(config, eval_dir)
    if requested('intra-FID'):
        eval_intra_FID(config, eval_dir)
    if requested('Acc'):
        if not _is_cifar10_dataset(config.data.dataset):
            raise ValueError(f'Accuracy evaluation not supported on dataset {config.data.dataset}.')
        eval_acc(config, eval_dir)
    if requested('PRDC'):
        eval_PRDC(config, eval_dir)
    if requested('CAS'):
        if not _is_cifar10_dataset(config.data.dataset):
            raise ValueError(f'CAS evaluation not supported on dataset {config.data.dataset}.')
        eval_CAS(config, eval_dir)
    if requested('Calib-Error'):
        eval_calib_error(config, eval_dir)


if __name__ == "__main__":
    # Script entry point: hand control to absl, which invokes main().
    app.run(main)
