{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'2.2.0'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os, sys\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2\"\n",
    "# os.environ[\"TF_XLA_FLAGS\"] =\"--tf_xla_auto_jit=2\"\n",
    "sys.path.append(\"/home/Developer/NCSN-TF2.0/\")\n",
    "\n",
    "import PIL\n",
    "import utils, configs\n",
    "import argparse\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import tensorflow_datasets as tfds\n",
    "import tensorflow_probability as tfp\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from tensorflow.compat.v1 import ConfigProto\n",
    "from tensorflow.compat.v1 import InteractiveSession\n",
    "from helper import plot_confusion_matrix, metrics\n",
    "\n",
    "from datasets.dataset_loader import  *\n",
    "from tqdm import tqdm\n",
    "from sklearn.metrics import classification_report, average_precision_score\n",
    "from sklearn.metrics import roc_auc_score, precision_recall_curve, auc\n",
    "\n",
    "from sklearn.mixture import GaussianMixture\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.pipeline import Pipeline\n",
    "\n",
    "from PIL import Image\n",
    "from IPython.display import display\n",
    "from matplotlib.pyplot import imshow\n",
    "from datetime import datetime\n",
    "\n",
    "import seaborn as sns\n",
    "import plotly as py\n",
    "import plotly.graph_objs as go\n",
    "\n",
    "\n",
    "# example of calculating the frechet inception distance in Keras\n",
    "import numpy\n",
    "from numpy import cov\n",
    "from numpy import trace\n",
    "from numpy import iscomplexobj\n",
    "from numpy import asarray\n",
    "from numpy.random import randint\n",
    "from scipy import linalg\n",
    "from scipy.linalg import sqrtm\n",
    "\n",
    "# from keras.applications.inception_v3 import InceptionV3\n",
    "# from keras.applications.inception_v3 import preprocess_input\n",
    "# from keras.datasets.mnist import load_data\n",
    "\n",
    "from tensorflow.keras.applications.inception_v3 import InceptionV3\n",
    "from tensorflow.keras.applications.inception_v3 import preprocess_input\n",
    "from tensorflow.keras.datasets.mnist import load_data\n",
    "\n",
    "from skimage.transform import resize\n",
    "\n",
    "sns.set(style=\"darkgrid\")\n",
    "plt.rcParams['axes.titlesize'] = 18\n",
    "plt.rcParams['axes.labelsize'] = 16\n",
    "\n",
    "seed=42\n",
    "tf.random.set_seed(seed)\n",
    "np.random.seed(seed)\n",
    "\n",
    "tf.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Num GPUs Available:  1\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(\"Num GPUs Available: \", len(tf.config.experimental.list_physical_devices('GPU')))\n",
    "tf.config.experimental.list_physical_devices('GPU')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.compat.v1 import ConfigProto\n",
    "from tensorflow.compat.v1 import InteractiveSession\n",
    "\n",
    "config = ConfigProto()\n",
    "config.gpu_options.allow_growth = True\n",
    "session = InteractiveSession(config=config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# scale an array of images to a new size\n",
    "def scale_images(images, new_shape):\n",
    "    images_list = list()\n",
    "    for image in images:\n",
    "        # resize with nearest neighbor interpolation\n",
    "        new_image = resize(image, new_shape, 0)\n",
    "        # store\n",
    "        images_list.append(new_image)\n",
    "    return asarray(images_list)\n",
    "\n",
    "# calculate frechet inception distance\n",
    "def calculate_fid(model, images1, images2):\n",
    "    # calculate activations\n",
    "    act1 = model.predict(images1)\n",
    "    act2 = model.predict(images2)\n",
    "    # calculate mean and covariance statistics\n",
    "    mu1, sigma1 = act1.mean(axis=0), cov(act1, rowvar=False)\n",
    "    mu2, sigma2 = act2.mean(axis=0), cov(act2, rowvar=False)\n",
    "    # calculate sum squared difference between means\n",
    "    ssdiff = numpy.sum((mu1 - mu2)**2.0)\n",
    "    # calculate sqrt of product between cov\n",
    "    covmean = sqrtm(sigma1.dot(sigma2))\n",
    "    # check and correct imaginary numbers from sqrt\n",
    "    if iscomplexobj(covmean):\n",
    "        covmean = covmean.real\n",
    "    # calculate score\n",
    "    fid = ssdiff + trace(sigma1 + sigma2 - 2.0 * covmean)\n",
    "    return fid\n",
    "\n",
    "def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):\n",
    "    \"\"\"Numpy implementation of the Frechet Distance.\n",
    "    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)\n",
    "    and X_2 ~ N(mu_2, C_2) is\n",
    "            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).\n",
    "\n",
    "    Stable version by Dougal J. Sutherland.\n",
    "\n",
    "    Params:\n",
    "    -- mu1 : Numpy array containing the activations of the pool_3 layer of the\n",
    "             inception net ( like returned by the function 'get_predictions')\n",
    "             for generated samples.\n",
    "    -- mu2   : The sample mean over activations of the pool_3 layer, precalcualted\n",
    "               on an representive data set.\n",
    "    -- sigma1: The covariance matrix over activations of the pool_3 layer for\n",
    "               generated samples.\n",
    "    -- sigma2: The covariance matrix over activations of the pool_3 layer,\n",
    "               precalcualted on an representive data set.\n",
    "\n",
    "    Returns:\n",
    "    --   : The Frechet Distance.\n",
    "    \"\"\"\n",
    "\n",
    "    mu1 = np.atleast_1d(mu1)\n",
    "    mu2 = np.atleast_1d(mu2)\n",
    "\n",
    "    sigma1 = np.atleast_2d(sigma1)\n",
    "    sigma2 = np.atleast_2d(sigma2)\n",
    "\n",
    "    assert mu1.shape == mu2.shape, \"Training and test mean vectors have different lengths\"\n",
    "    assert sigma1.shape == sigma2.shape, \"Training and test covariances have different dimensions\"\n",
    "\n",
    "    diff = mu1 - mu2\n",
    "\n",
    "    # product might be almost singular\n",
    "    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)\n",
    "    if not np.isfinite(covmean).all():\n",
    "        msg = \"fid calculation produces singular product; adding %s to diagonal of cov estimates\" % eps\n",
    "        warnings.warn(msg)\n",
    "        offset = np.eye(sigma1.shape[0]) * eps\n",
    "        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))\n",
    "\n",
    "    # numerical error might give slight imaginary component\n",
    "    if np.iscomplexobj(covmean):\n",
    "        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):\n",
    "            m = np.max(np.abs(covmean.imag))\n",
    "            raise ValueError(\"Imaginary component {}\".format(m))\n",
    "        covmean = covmean.real\n",
    "\n",
    "    tr_covmean = np.trace(covmean)\n",
    "\n",
    "    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean\n",
    "\n",
    "# calculate frechet inception distance\n",
    "# expects a list of image batches\n",
    "def calculate_fid_stable(model, images1, images2):\n",
    "    # calculate activations\n",
    "    act1 = np.concatenate([model.predict(batch) for batch in images1])\n",
    "    act2 = np.concatenate([model.predict(batch) for batch in images2])\n",
    "\n",
    "    # calculate mean and covariance statistics\n",
    "    mu1, sigma1 = act1.mean(axis=0), cov(act1, rowvar=False)\n",
    "    mu2, sigma2 = act2.mean(axis=0), cov(act2, rowvar=False)\n",
    "   \n",
    "    return calculate_frechet_distance(mu1, sigma1, mu2, sigma2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # @tf.function\n",
    "def preprocess_images_tf(images, shape, batch_sz=256):    \n",
    "    imgs = []\n",
    "    \n",
    "    # Scale to shape + Inception preprocessing\n",
    "    for i in range(0, images.shape[0], batch_sz):\n",
    "        _img = tf.image.resize(images[i:i+batch_sz], shape)\n",
    "        _img = preprocess_input(_img)\n",
    "        imgs.append(_img)\n",
    "    return imgs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# prepare the inception v3 model\n",
    "model = InceptionV3(include_top=False, pooling='avg', input_shape=(299,299,3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prepared (100, 32, 32, 3) (100, 32, 32, 3)\n"
     ]
    }
   ],
   "source": [
    "# define two fake collections of images\n",
    "images1 = randint(0, 255, 100*32*32*3)\n",
    "images1 = images1.reshape((100,32,32,3))\n",
    "images2 = randint(0, 255, 100*32*32*3)\n",
    "images2 = images2.reshape((100,32,32,3))\n",
    "print('Prepared', images1.shape, images2.shape)\n",
    "# convert integer to floating point values\n",
    "images1 = images1.astype('float32')\n",
    "images2 = images2.astype('float32')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "FID (same): -0.000\n",
      "Stable FID (different): 19.305\n",
      "CPU times: user 6min 36s, sys: 9min 37s, total: 16min 14s\n",
      "Wall time: 36.1 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "# resize + pre-process images\n",
    "images1 = preprocess_images_tf(images1, (299,299), 10)\n",
    "images2 = preprocess_images_tf(images2, (299,299), 10)\n",
    "\n",
    "# fid between images1 and images1\n",
    "fid = calculate_fid_stable(model, images1, images1)\n",
    "print('FID (same): %.3f' % fid)\n",
    "\n",
    "\n",
    "# fid between images1 and images2\n",
    "fid = calculate_fid_stable(model, images1, images2)\n",
    "print('Stable FID (different): %.3f' % fid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded (50000, 32, 32, 3) (10000, 32, 32, 3)\n",
      "CPU times: user 2min 28s, sys: 9min, total: 11min 29s\n",
      "Wall time: 1min 59s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "from numpy.random import shuffle\n",
    "from keras.datasets import cifar10\n",
    "\n",
    "# load cifar10 images\n",
    "with tf.device(\"cpu\"):\n",
    "    (images1, _), (images2, _) = cifar10.load_data()\n",
    "    # shuffle(images1)\n",
    "    # images1 = images1[:10000]\n",
    "    print('Loaded', images1.shape, images2.shape)\n",
    "    # convert integer to floating point values\n",
    "    images1 = images1.astype('float32')\n",
    "    images2 = images2.astype('float32')\n",
    "\n",
    "    # resize images\n",
    "    images1 = preprocess_images_tf(images1, (299,299), batch_sz=256)\n",
    "    images2 = preprocess_images_tf(images2, (299,299), batch_sz=256)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Calculate Activations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 2min 36s, sys: 8min 23s, total: 10min 59s\n",
      "Wall time: 1min 49s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "with tf.device(\"gpu\"):\n",
    "    act1 = np.concatenate([model.predict(batch) for batch in images1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 38.3 s, sys: 1min 44s, total: 2min 23s\n",
      "Wall time: 25.4 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "with tf.device(\"gpu\"):\n",
    "    act2 = np.concatenate([model.predict(batch) for batch in images2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 38.2 s, sys: 23.6 s, total: 1min 1s\n",
      "Wall time: 2.9 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# calculate mean and covariance statistics\n",
    "mu1, sigma1 = act1.mean(axis=0), cov(act1, rowvar=False)\n",
    "mu2, sigma2 = act2.mean(axis=0), cov(act2, rowvar=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "FID: 3.235\n",
      "CPU times: user 4min 34s, sys: 7min 37s, total: 12min 11s\n",
      "Wall time: 19.1 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# calculate fid\n",
    "fid = calculate_frechet_distance(mu1, sigma1, mu2, sigma2)\n",
    "print('FID: %.3f' % fid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "FID (same): -0.000\n",
      "CPU times: user 4min 13s, sys: 7min 11s, total: 11min 24s\n",
      "Wall time: 17.9 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "fid = calculate_frechet_distance(mu1, sigma1, mu1, sigma1)\n",
    "print('FID (same): %.3f' % fid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = \"./statistics/new_fid_stats_cifar10_train.npz\"\n",
    "# np.savez(path, mu=mu1, sigma=sigma1)\n",
    "\n",
    "with np.load(path) as f:\n",
    "    m, s = f['mu'][:], f['sigma'][:]\n",
    "\n",
    "# assert np.allclose(m, mu1)\n",
    "# assert np.allclose(s, sigma1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Getting files for a checkpoint\n",
 ### Use evaluation experiment">
    "> ### Run the evaluation experiment first to generate the samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================\n",
      "Parameters: \n",
      "\n",
      "experiment: train\n",
      "dataset: cifar10\n",
      "model: refinenet\n",
      "filters: 128\n",
      "num_L: 10\n",
      "sigma_low: 0.01\n",
      "sigma_high: 1.0\n",
      "sigma_sequence: geometric\n",
      "steps: 200000\n",
      "learning_rate: 0.001\n",
      "batch_size: 128\n",
      "samples_dir: ./samples/\n",
      "checkpoint_dir: longleaf_models/\n",
      "checkpoint_freq: 5000\n",
      "resume: True\n",
      "resume_from: -1\n",
      "init_samples: \n",
      "k: 10\n",
      "eval_setting: sample\n",
      "ocnn: False\n",
      "y_cond: False\n",
      "max_to_keep: 2\n",
      "split: ['95', '5']\n",
      "====================\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "('longleaf_models/refinenet128_cifar10_L10_SH1e+00_SL1e-02/train_95_5/',\n",
       " 'refinenet128_cifar10_L10_SH1e+00_SL1e-02/train_95_5')"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import utils \n",
    "\n",
    "def get_command_line_args(_args):\n",
    "    parser = utils._build_parser()\n",
    "\n",
    "    parser = parser.parse_args(_args)\n",
    "\n",
    "    utils.check_args_validity(parser)\n",
    "\n",
    "    print(\"=\" * 20 + \"\\nParameters: \\n\")\n",
    "    for key in parser.__dict__:\n",
    "        print(key + ': ' + str(parser.__dict__[key]))\n",
    "    print(\"=\" * 20 + \"\\n\")\n",
    "    return parser\n",
    "\n",
    "# Make sure these params match the model you want to evaluate\n",
    "args = get_command_line_args([\"--checkpoint_dir=longleaf_models/\",\n",
    "                              \"--filters=128\",\n",
    "                              \"--dataset=cifar10\",\n",
    "                              \"--sigma_low=0.01\",\n",
    "                              \"--sigma_high=1\",\n",
    "                              \"--split=95,5\"\n",
    "                             ])\n",
    "configs.config_values = args\n",
    "dir_statistics = './statistics'\n",
    "save_dir, complete_model_name = utils.get_savemodel_dir()\n",
    "save_dir, complete_model_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Ckpt::   5%|▌         | 1/20 [00:29<09:22, 29.61s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ckpt-1: 82.658\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Ckpt::  10%|█         | 2/20 [00:47<07:52, 26.24s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ckpt-2: 136.799\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Ckpt::  15%|█▌        | 3/20 [01:05<06:41, 23.62s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ckpt-3: 134.650\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Ckpt::  20%|██        | 4/20 [01:21<05:42, 21.38s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ckpt-4: 77.444\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Ckpt::  25%|██▌       | 5/20 [01:39<05:04, 20.33s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ckpt-5: 70.214\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Ckpt::  30%|███       | 6/20 [01:56<04:31, 19.36s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ckpt-6: 78.043\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Ckpt::  35%|███▌      | 7/20 [02:12<03:59, 18.39s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ckpt-7: 71.017\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Ckpt::  40%|████      | 8/20 [02:28<03:30, 17.54s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ckpt-8: 70.920\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Ckpt::  45%|████▌     | 9/20 [02:45<03:12, 17.52s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ckpt-9: 99.364\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Ckpt::  50%|█████     | 10/20 [03:03<02:54, 17.50s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ckpt-10: 94.916\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Ckpt::  55%|█████▌    | 11/20 [03:20<02:36, 17.36s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ckpt-11: 87.868\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Ckpt::  60%|██████    | 12/20 [03:37<02:18, 17.28s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ckpt-12: 120.422\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Ckpt::  65%|██████▌   | 13/20 [03:54<02:00, 17.23s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ckpt-13: 84.668\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Ckpt::  70%|███████   | 14/20 [04:11<01:42, 17.13s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ckpt-14: 111.054\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Ckpt::  75%|███████▌  | 15/20 [04:27<01:24, 16.96s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ckpt-15: 111.211\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Ckpt::  80%|████████  | 16/20 [04:45<01:09, 17.26s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ckpt-16: 101.450\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Ckpt::  85%|████████▌ | 17/20 [05:03<00:52, 17.38s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ckpt-17: 87.838\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Ckpt::  90%|█████████ | 18/20 [05:21<00:34, 17.49s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ckpt-18: 113.988\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Ckpt::  95%|█████████▌| 19/20 [05:39<00:17, 17.72s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ckpt-19: 113.356\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Ckpt:: 100%|██████████| 20/20 [05:57<00:00, 17.85s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ckpt-20: 115.152\n",
      "CPU times: user 1h 6min 2s, sys: 1h 32min 45s, total: 2h 38min 47s\n",
      "Wall time: 5min 57s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "import pathlib\n",
    "from imageio import imread\n",
    "\n",
    "multiple = 10000\n",
    "i = 1\n",
    "step_ckpt = 0\n",
    "fids = []\n",
    "for i in tqdm(range(1, 21), desc=\"Ckpt:\"):\n",
    "    step_ckpt = i * multiple\n",
    "    save_directory = '{}/{}/step{}/samples/'.format(dir_statistics, complete_model_name, step_ckpt)\n",
    "    \n",
    "    path = pathlib.Path(save_directory)\n",
    "    files = list(path.glob('*.jpg')) + list(path.glob('*.png'))\n",
    "    x = np.array([imread(str(fn)).astype(np.float32) for fn in files])\n",
    "\n",
    "    with tf.device(\"cpu\"):\n",
    "        generated_images = preprocess_images_tf(x, (299,299), batch_sz=256)\n",
    "\n",
    "    with tf.device(\"gpu\"):\n",
    "        sample_acts = np.concatenate([model.predict(batch) for batch in generated_images])\n",
    "\n",
    "    mus, sigmas = sample_acts.mean(axis=0), cov(sample_acts, rowvar=False)\n",
    "\n",
    "    # calculate fid\n",
    "    fid = calculate_frechet_distance(m, s, mus, sigmas)\n",
    "    fids.append(fid)\n",
    "    \n",
    "    print('Ckpt-%d: %.3f' % (i, fid))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5, 70.21384937160167)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "min_idx = np.argmin(fids)\n",
    "min_idx+1, fids[min_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
