import gc
import io
import os
import time

import numpy as np
import tensorflow as tf
import tensorflow_gan as tfgan
import logging
# Keep the import below for registering all model definitions
# from models import ddpm, ncsnv2, ncsnpp
import losses
import sampling
# from models import utils as mutils
# from models.ema import ExponentialMovingAverage
import datasets
import evaluation
import likelihood
import sde_lib
import torch
from torch.utils import tensorboard
from torchvision.utils import make_grid, save_image
from utils import save_checkpoint, restore_checkpoint
from tqdm import trange
import sys

def _load_class_stats(stat_files):
    """Load the ``logits`` and ``pool_3`` arrays from each ``.npz`` file in ``stat_files``.

    Returns:
      A pair ``(logits, pools)`` of lists, one array per file, in the order
      of ``stat_files``.
    """
    logits, pools = [], []
    for stat_file in stat_files:
        with tf.io.gfile.GFile(stat_file, "rb") as fin:
            # Close the NpzFile explicitly; indexing it reads each member fully
            # into memory, so the arrays stay valid after the archive is closed.
            with np.load(fin) as stat:
                logits.append(stat["logits"])
                pools.append(stat["pool_3"])
    return logits, pools


def eval_FID(config, eval_dir):
    """Compute and print FID, Inception Score and KID for per-class sample statistics.

    For each class index ``i`` in ``range(config.data.num_classes)``, every
    ``statistics_*.npz`` file under ``eval_dir/<i>/`` is loaded; each must
    contain ``"logits"`` and ``"pool_3"`` arrays. Exactly
    ``config.eval.num_samples`` samples are kept per class: the files are read
    in sorted filename order and any surplus is dropped from the end. The kept
    samples of all classes are pooled and compared against the dataset
    statistics returned by ``evaluation.load_dataset_stats(config)``.

    Args:
      config: configuration object; this function reads
        ``config.data.num_classes`` and ``config.eval.num_samples`` and passes
        it on to ``evaluation.load_dataset_stats``.
      eval_dir: directory holding one sub-directory per class index.

    Raises:
      ValueError: if some class has fewer than ``config.eval.num_samples``
        samples.
    """
    num_samples = config.eval.num_samples
    all_logits = []
    all_pools = []
    for i in range(config.data.num_classes):
        # Sort so the samples kept after truncation do not depend on glob order.
        stat_files = sorted(
            tf.io.gfile.glob(os.path.join(eval_dir, str(i), "statistics_*.npz")))
        class_logits, class_pools = _load_class_stats(stat_files)
        total = sum(len(pool) for pool in class_pools)
        if total < num_samples:
            raise ValueError(f'Not enough samples for class {i}.')
        if not class_pools:
            # Only reachable when num_samples == 0 and no files exist.
            continue
        # Concatenate before truncating so the surplus is removed correctly even
        # when it spans more than the last file of the class.
        all_logits.append(np.concatenate(class_logits, axis=0)[:num_samples])
        all_pools.append(np.concatenate(class_pools, axis=0)[:num_samples])

    all_logits = np.concatenate(all_logits, axis=0)
    all_pools = np.concatenate(all_pools, axis=0)
    data_stats = evaluation.load_dataset_stats(config)
    data_pools = data_stats["pool_3"]

    inception_score = tfgan.eval.classifier_score_from_logits(all_logits)
    fid = tfgan.eval.frechet_classifier_distance_from_activations(data_pools, all_pools)
    # Hack to get tfgan KID work for eager execution.
    tf_data_pools = tf.convert_to_tensor(data_pools)
    tf_all_pools = tf.convert_to_tensor(all_pools)
    kid = tfgan.eval.kernel_classifier_distance_from_activations(tf_data_pools, tf_all_pools).numpy()
    del tf_data_pools, tf_all_pools

    print(f"FID = {fid:.6e}, Inception Score = {inception_score:.6e}, KID = {kid:.6e}")
