import gc
import io
import os
import time

import numpy as np
import tensorflow as tf
import tensorflow_gan as tfgan
import logging
# Keep the import below for registering all model definitions
from models import ddpm, ncsnv2, ncsnpp
import losses
import sampling
from models import utils as mutils
from models.ema import ExponentialMovingAverage
import datasets
import evaluation
import likelihood
import sde_lib
from absl import flags
import torch
from torch.utils import tensorboard
from torchvision.utils import make_grid, save_image
from utils import save_checkpoint, restore_checkpoint
from absl import app
from absl import flags
from ml_collections.config_flags import config_flags
import logging
import os
import tensorflow as tf
from tqdm import tqdm, trange
import tensorflow_datasets as tfds

# Shared absl flag registry; main() reads the parsed configuration from it.
FLAGS = flags.FLAGS

# Register the --config flag. lock_config=False leaves the loaded config
# unlocked, which main() relies on when it overrides the batch sizes.
config_flags.DEFINE_config_file(
  "config",
  None,
  "Training configuration.",
  lock_config=False,
)

def _make_preprocess_fn(image_size):
  """Build the tf.data map function that prepares a single example.

  Args:
    image_size: Target height and width of the resized image.

  Returns:
    A function that maps an example dict with 'image' and 'label' entries to
    a dict holding the float32-converted, resized image and the same label.
  """

  def preprocess_fn(d):
    img = tf.image.convert_image_dtype(d['image'], tf.float32)
    img = tf.image.resize(img, [image_size, image_size], antialias=True)
    return dict(image=img, label=d['label'])

  return preprocess_fn


def _label_equals(class_idx):
  """Return a filter predicate that keeps examples whose label is `class_idx`.

  A factory is used (instead of a lambda inside the loop) so the class index
  is bound explicitly and does not depend on when the predicate is evaluated.
  """

  def predicate(example):
    return example['label'] == class_idx

  return predicate


def _create_class_datasets(dataset_builder, split, preprocess_fn,
                           num_classes, batch_size):
  """Create one batched dataset per class label.

  Args:
    dataset_builder: Either a `tfds.core.DatasetBuilder` (prepared and read
      with `split`) or an object offering `with_options`, used as-is.
    split: Name of the split to read when `dataset_builder` is a TFDS builder.
    preprocess_fn: Function mapped over every example.
    num_classes: Number of class labels; labels 0..num_classes-1 are used.
    batch_size: Batch size of each per-class dataset (the last batch may be
      smaller because `drop_remainder=False`).

  Returns:
    A list of `num_classes` datasets; entry `i` only yields examples whose
    'label' equals `i`.
  """
  dataset_options = tf.data.Options()
  dataset_options.experimental_optimization.map_parallelization = True
  dataset_options.experimental_threading.private_threadpool_size = 48
  dataset_options.experimental_threading.max_intra_op_parallelism = 1
  read_config = tfds.ReadConfig(options=dataset_options)
  if isinstance(dataset_builder, tfds.core.DatasetBuilder):
    dataset_builder.download_and_prepare()
    ds = dataset_builder.as_dataset(
      split=split, shuffle_files=True, read_config=read_config)
  else:
    ds = dataset_builder.with_options(dataset_options)
  ds = ds.map(preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  # NOTE: every per-class dataset filters the full split, so iterating all of
  # them scans the data once per class.
  return [
    ds.filter(_label_equals(i)).batch(batch_size, drop_remainder=False)
    for i in range(num_classes)
  ]


def _save_stats(path, pool_3, logits):
  """Write `pool_3` and `logits` to `path` as a compressed .npz archive."""
  with open(path, "wb") as fout:
    # Stream straight into the file; no intermediate in-memory copy needed.
    np.savez_compressed(fout, pool_3=pool_3, logits=logits)


def main(argv):
  """Compute Inception statistics for every class of CIFAR10 / CIFAR100.

  For each class label the training split is filtered to that class, the
  images are run through the Inception model from `evaluation`, and the
  resulting "pool_3" and "logits" arrays are saved to
  `assets/stats/<dataset>_class<i>_stats.npz`. The arrays of all classes,
  concatenated in class order, are saved to `assets/stats/<dataset>_stats.npz`.

  Args:
    argv: Positional command-line arguments passed by `absl.app`; unused.

  Raises:
    ValueError: If `config.data.dataset` is not 'CIFAR10' or 'CIFAR100', or
      if a class yields fewer batches than `config.data.per_class` requires.
  """
  del argv  # Unused; part of the absl.app entry-point signature.

  config = FLAGS.config
  # Kept from the original script. Nothing in this function reads these
  # values afterwards — presumably they matter to code sharing the config.
  config.training.batch_size = 1
  config.eval.batch_size = 1

  tfds_names = {'CIFAR10': 'cifar10', 'CIFAR100': 'cifar100'}
  if config.data.dataset not in tfds_names:
    # Fail loudly instead of exiting with a success status.
    raise ValueError(
      f"Unsupported dataset {config.data.dataset!r}; "
      f"expected one of {sorted(tfds_names)}.")
  ds_name = tfds_names[config.data.dataset]
  dataset_builder = tfds.builder(ds_name)
  train_split_name = 'train'

  num_classes = config.data.num_classes
  per_class = config.data.per_class
  bs = 512
  class_datasets = _create_class_datasets(
    dataset_builder, train_split_name,
    _make_preprocess_fn(config.data.image_size), num_classes, bs)

  inceptionv3 = config.data.image_size >= 256
  inception_model = evaluation.get_inception_model(inceptionv3=inceptionv3)

  stats_dir = "assets/stats"
  # open(..., "wb") does not create missing parent directories.
  os.makedirs(stats_dir, exist_ok=True)

  result_pool3 = []
  result_logits = []
  for i in trange(num_classes):
    batches = iter(class_datasets[i])
    class_pool3 = []
    class_logits = []
    # NOTE: whole batches are consumed, so when per_class is not a multiple
    # of bs the final batch can push the count above per_class.
    for _ in range(0, per_class, bs):
      try:
        batch = next(batches)
      except StopIteration:
        raise ValueError(
          f"Class {i} of split {train_split_name!r} ran out of examples "
          f"before config.data.per_class={per_class} images were read."
        ) from None
      samples = batch['image'].numpy()
      # Images are float32 after preprocessing; rescale to uint8 pixels.
      samples = np.clip(samples * 255., 0, 255).astype(np.uint8)
      latents = evaluation.run_inception_distributed(
        samples, inception_model, inceptionv3=inceptionv3)
      class_pool3.append(latents["pool_3"])
      class_logits.append(latents["logits"])

    result_pool3.append(np.concatenate(class_pool3))
    result_logits.append(np.concatenate(class_logits))
    _save_stats(f"{stats_dir}/{ds_name}_class{i}_stats.npz",
                result_pool3[i], result_logits[i])

  _save_stats(f"{stats_dir}/{ds_name}_stats.npz",
              np.concatenate(result_pool3), np.concatenate(result_logits))
    

# Script entry point: hand main() to absl's app runner.
if __name__ == "__main__":
  app.run(main)
