import gc
import io
import os
import time

import numpy as np
import tensorflow as tf
import tensorflow_gan as tfgan
import logging
# Keep the import below for registering all model definitions
# from models import ddpm, ncsnv2, ncsnpp
import losses
import sampling
# from models import utils as mutils
# from models.ema import ExponentialMovingAverage
import datasets
import evaluation
import likelihood
import sde_lib
from absl import flags
import torch
from torch.utils import tensorboard
from torchvision.utils import make_grid, save_image
from utils import save_checkpoint, restore_checkpoint
from tqdm import trange
import sys
from prdc import compute_prdc

def eval_cond_prdc(config, eval_dir):
    """Compute class-conditional PRDC metrics and print them.

    For each class index ``i`` in ``range(config.data.num_classes)``, the
    ``pool_3`` features stored in every ``statistics_*.npz`` file under
    ``eval_dir/<i>/`` are loaded, concatenated, truncated to at most
    ``config.eval.num_samples`` rows, and compared against the ``pool_3``
    features returned by ``evaluation.load_dataset_stats_cond(config, i)``
    using ``compute_prdc`` with ``nearest_k=3``.

    Per-class results are printed only when there are at most 10 classes; the
    mean of each metric over all classes is always printed.

    Args:
        config: configuration object; reads ``config.data.num_classes`` and
            ``config.eval.num_samples`` (and whatever
            ``evaluation.load_dataset_stats_cond`` reads).
        eval_dir: directory containing one sub-directory per class index.

    Raises:
        ValueError: if ``config.data.num_classes`` is not positive, or if a
            class directory contains no statistics files.
    """
    num_classes = config.data.num_classes
    num_samples = config.eval.num_samples
    if num_classes <= 0:
        # Without this check the final average would divide by zero.
        raise ValueError(f"config.data.num_classes must be positive, got {num_classes}.")

    precision, recall, density, coverage = [], [], [], []

    for i in trange(num_classes):
        # Presumably frees the previous class's feature arrays before the next
        # class is loaded -- kept from the original implementation.
        gc.collect()

        # Sorted so that truncation to num_samples picks the same rows on
        # every run, independent of the filesystem's listing order.
        stat_files = sorted(
            tf.io.gfile.glob(os.path.join(eval_dir, str(i), "statistics_*.npz")))
        if not stat_files:
            raise ValueError(f"No statistics files found for class {i} under {eval_dir}.")

        fake_pools = []
        for stat_file in stat_files:
            with tf.io.gfile.GFile(stat_file, "rb") as fin:
                with np.load(fin) as stat:
                    # Only the pool_3 features are needed for PRDC; the stored
                    # logits are deliberately not read.
                    fake_pools.append(stat["pool_3"])
        fake_pools = np.concatenate(fake_pools, axis=0)[:num_samples]

        data_pools = evaluation.load_dataset_stats_cond(config, i)["pool_3"]

        prdc = compute_prdc(real_features=data_pools, fake_features=fake_pools, nearest_k=3)

        precision.append(prdc['precision'])
        recall.append(prdc['recall'])
        density.append(prdc['density'])
        coverage.append(prdc['coverage'])
        # Per-class lines are only printed for small label sets to keep output short.
        if num_classes <= 10:
            print(f"Class {i}: Precision = {precision[-1]:.6e}, Recall = {recall[-1]:.6e}, Density = {density[-1]:.6e}, Coverage = {coverage[-1]:.6e}")

    print(f"Average: Precision = {sum(precision) / len(precision):.6e}, Recall = {sum(recall) / len(recall):.6e}, Density = {sum(density) / len(density):.6e}, Coverage = {sum(coverage) / len(coverage):.6e}")

def eval_uncond_prdc(config, eval_dir):
    """Compute unconditional PRDC metrics over samples pooled from all classes.

    For each class index ``i`` in ``range(config.data.num_classes)``, the
    ``pool_3`` features stored in every ``statistics_*.npz`` file under
    ``eval_dir/<i>/`` are loaded, and exactly the first
    ``config.eval.num_samples`` rows are kept. The per-class features are then
    concatenated and compared against the ``pool_3`` features returned by
    ``evaluation.load_dataset_stats(config)`` using ``compute_prdc`` with
    ``nearest_k=3``. The four metrics are printed.

    Args:
        config: configuration object; reads ``config.data.num_classes`` and
            ``config.eval.num_samples`` (and whatever
            ``evaluation.load_dataset_stats`` reads).
        eval_dir: directory containing one sub-directory per class index.

    Raises:
        ValueError: if any class has fewer than ``config.eval.num_samples``
            stored samples, or if no features were loaded at all.
    """
    num_samples = config.eval.num_samples
    per_class_pools = []
    for i in range(config.data.num_classes):
        # Sorted so the choice of which rows to keep is deterministic.
        stat_files = sorted(
            tf.io.gfile.glob(os.path.join(eval_dir, str(i), "statistics_*.npz")))
        class_pools = []
        for stat_file in stat_files:
            with tf.io.gfile.GFile(stat_file, "rb") as fin:
                with np.load(fin) as stat:
                    # Logits are not used by PRDC, so only pool_3 is read.
                    class_pools.append(stat["pool_3"])

        if sum(len(chunk) for chunk in class_pools) < num_samples:
            raise ValueError(f'Not enough samples for class {i}.')
        if class_pools:
            # Truncate after concatenating: the surplus can span several files,
            # so trimming only the last file's chunk would keep too many rows.
            per_class_pools.append(np.concatenate(class_pools, axis=0)[:num_samples])

    if not per_class_pools:
        raise ValueError(f"No generated sample statistics found under {eval_dir}.")
    all_pools = np.concatenate(per_class_pools, axis=0)

    data_stats = evaluation.load_dataset_stats(config)
    data_pools = data_stats["pool_3"]

    prdc = compute_prdc(real_features=data_pools, fake_features=all_pools, nearest_k=3)

    print(f"Precision = {prdc['precision']:.6e}, Recall = {prdc['recall']:.6e}, Density = {prdc['density']:.6e}, Coverage = {prdc['coverage']:.6e}")


def eval_PRDC(config, eval_dir):
    """Run both PRDC evaluations on ``eval_dir``.

    The class-conditional evaluation (``eval_cond_prdc``) runs first, followed
    by the unconditional one (``eval_uncond_prdc``). Each callee prints its own
    results; this function returns ``None``.
    """
    for evaluate in (eval_cond_prdc, eval_uncond_prdc):
        evaluate(config, eval_dir)