import gc
import io
import os
import time

import numpy as np
import tensorflow as tf
import tensorflow_gan as tfgan
import logging
# Keep the import below for registering all model definitions
# from models import ddpm, ncsnv2, ncsnpp
import losses
import sampling
# from models import utils as mutils
# from models.ema import ExponentialMovingAverage
import datasets
import evaluation
import likelihood
import sde_lib
import torch
from torch.utils import tensorboard
from torchvision.utils import make_grid, save_image
from utils import save_checkpoint, restore_checkpoint
from tqdm import trange
import sys

def _load_class_statistics(eval_dir, class_idx, num_samples):
    """Load and concatenate the saved sample statistics for one class.

    Reads every ``statistics_*.npz`` file under ``<eval_dir>/<class_idx>/`` and
    stacks their ``logits`` and ``pool_3`` arrays along axis 0, keeping at most
    ``num_samples`` rows of each.

    Args:
      eval_dir: Directory containing one sub-directory per class index.
      class_idx: Integer class index whose statistics should be loaded.
      num_samples: Maximum number of rows to keep from the concatenated arrays.

    Returns:
      A ``(logits, pools)`` tuple of NumPy arrays.

    Raises:
      ValueError: If no statistics files exist for the class.
    """
    pattern = os.path.join(eval_dir, str(class_idx), "statistics_*.npz")
    # Sort so that truncating to ``num_samples`` below keeps the same samples on
    # every run; glob order is not guaranteed to be stable.
    stat_files = sorted(tf.io.gfile.glob(pattern))
    if not stat_files:
        raise ValueError(
            f"No statistics files matching {pattern!r} for class {class_idx}.")

    logits, pools = [], []
    for stat_file in stat_files:
        with tf.io.gfile.GFile(stat_file, "rb") as fin:
            # The NpzFile reads members from ``fin`` on access, so pull both
            # arrays out before either handle is closed.
            with np.load(fin) as stat:
                logits.append(stat["logits"])
                pools.append(stat["pool_3"])

    logits = np.concatenate(logits, axis=0)[:num_samples]
    pools = np.concatenate(pools, axis=0)[:num_samples]
    if num_samples is not None and logits.shape[0] < num_samples:
        logging.warning(
            "Class %d: only %d samples found, fewer than num_samples=%d.",
            class_idx, logits.shape[0], num_samples)
    return logits, pools


def eval_intra_FID(config, eval_dir):
    """Compute per-class (intra-class) FID, Inception Score and KID.

    For each class index in ``range(config.data.num_classes)``, loads the
    generated-sample statistics saved under ``<eval_dir>/<class>/``. Their
    ``pool_3`` activations are compared against the ``pool_3`` entry returned
    by ``evaluation.load_dataset_stats_cond(config, class)``. The Inception
    Score is computed from the sample logits. Per-class scores are printed
    when there are at most 10 classes; the class-averaged scores are always
    printed.

    Args:
      config: Config object; reads ``config.data.num_classes`` and
        ``config.eval.num_samples``.
      eval_dir: Directory holding one sub-directory of ``statistics_*.npz``
        files per class index.

    Returns:
      A dict with the class-averaged ``"intra_fid"``, ``"inception_score"``
      and ``"intra_kid"`` values as floats.

    Raises:
      ValueError: If ``config.data.num_classes`` is not positive, or if a class
        has no statistics files.
    """
    num_classes = config.data.num_classes
    if num_classes <= 0:
        raise ValueError(
            f"config.data.num_classes must be positive, got {num_classes}.")

    fids = []
    incep_scores = []
    kids = []
    for i in trange(num_classes):
        all_logits, all_pools = _load_class_statistics(
            eval_dir, i, config.eval.num_samples)
        data_pools = evaluation.load_dataset_stats_cond(config, i)["pool_3"]

        inception_score = tfgan.eval.classifier_score_from_logits(all_logits)
        fid = tfgan.eval.frechet_classifier_distance_from_activations(
            data_pools, all_pools)
        # Hack to get tfgan KID work for eager execution.
        tf_data_pools = tf.convert_to_tensor(data_pools)
        tf_all_pools = tf.convert_to_tensor(all_pools)
        kid = tfgan.eval.kernel_classifier_distance_from_activations(
            tf_data_pools, tf_all_pools).numpy()
        del tf_data_pools, tf_all_pools

        # Convert once so printing and accumulation use plain Python floats.
        fid, inception_score, kid = float(fid), float(inception_score), float(kid)
        if num_classes <= 10:
            print(f"Class {i}: Intra-FID = {fid:.6e}, Inception Score = {inception_score:.6e}, Intra-KID = {kid:.6e}")
        fids.append(fid)
        incep_scores.append(inception_score)
        kids.append(kid)

    mean_fid = sum(fids) / len(fids)
    mean_is = sum(incep_scores) / len(incep_scores)
    mean_kid = sum(kids) / len(kids)
    print(f"Average: Intra-FID = {mean_fid:.6e}, Inception Score = {mean_is:.6e}, Intra-KID = {mean_kid:.6e}")
    return {"intra_fid": mean_fid, "inception_score": mean_is, "intra_kid": mean_kid}
