{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Notes\n",
    "\n",
    "#This file contains the code used for the simulations in Figure 1 and in Appendix E.2.\n",
    "\n",
    "#(This file contains the code for our algorithm and for GDA, for the CIFAR dataset.  For our algorithm with acceptance rate 1/2, set the \"rate\" parameter to \"rate = 2\".  For GDA, set the \"rate\" parameter to \"rate = 1\".)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/ec2-user/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/__init__.py:1473: The name tf.estimator.inputs is deprecated. Please use tf.compat.v1.estimator.inputs instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# example of calculating the frechet inception distance in Keras for cifar10\n",
    "import numpy\n",
    "from numpy import cov\n",
    "from numpy import trace\n",
    "from numpy import iscomplexobj\n",
    "from numpy import asarray\n",
    "from numpy.random import shuffle\n",
    "from scipy.linalg import sqrtm\n",
    "from keras.applications.inception_v3 import InceptionV3\n",
    "from keras.applications.inception_v3 import preprocess_input\n",
    "from keras.datasets.mnist import load_data\n",
    "from skimage.transform import resize\n",
    "from keras.datasets import cifar10\n",
    "import time\n",
    "\n",
    "\n",
    "# %load_ext line_profiler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/ec2-user/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:74: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/ec2-user/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/ec2-user/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:4138: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/ec2-user/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:174: The name tf.get_default_session is deprecated. Please use tf.compat.v1.get_default_session instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/ec2-user/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:181: The name tf.ConfigProto is deprecated. Please use tf.compat.v1.ConfigProto instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/ec2-user/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:186: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/ec2-user/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:190: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/ec2-user/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:199: The name tf.is_variable_initialized is deprecated. Please use tf.compat.v1.is_variable_initialized instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/ec2-user/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:206: The name tf.variables_initializer is deprecated. Please use tf.compat.v1.variables_initializer instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/ec2-user/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:1834: The name tf.nn.fused_batch_norm is deprecated. Please use tf.compat.v1.nn.fused_batch_norm instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/ec2-user/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:133: The name tf.placeholder_with_default is deprecated. Please use tf.compat.v1.placeholder_with_default instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/ec2-user/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:3976: The name tf.nn.max_pool is deprecated. Please use tf.nn.max_pool2d instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/ec2-user/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:3980: The name tf.nn.avg_pool is deprecated. Please use tf.nn.avg_pool2d instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# InceptionV3 without its classification head; pooling='avg' applies global average pooling to the last conv block\n",
    "model = InceptionV3(input_shape=(299, 299, 3), pooling='avg', include_top=False)\n",
    "# CIFAR-10 train/test images; the label arrays are discarded\n",
    "(train_images, _), (test_images, _) = cifar10.load_data()\n",
    "# in-place shuffle along the first axis, train split first and then test split\n",
    "for _image_split in (train_images, test_images):\n",
    "    shuffle(_image_split)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# scale an array of images to a new size\n",
    "def scale_images(images, new_shape):\n",
    "\t\"\"\"Resize every image in `images` to `new_shape` and stack the results into one array.\n",
    "\n",
    "\tEach image goes through skimage's `resize` with interpolation order 0\n",
    "\t(nearest neighbour); the resized images are returned via `asarray`.\n",
    "\t\"\"\"\n",
    "\t# the interpolation order (0) is passed positionally, as the third argument\n",
    "\treturn asarray([resize(image, new_shape, 0) for image in images])\n",
    "\n",
    "# calculate frechet inception distance\n",
    "def calculate_fid(model, images1, images2):\n",
    "\t\"\"\"Compute the Frechet Inception Distance between two batches of images.\n",
    "\n",
    "\tReturns the tuple (mu1, mu2, sigma1, sigma2, fid): the mean and covariance of\n",
    "\t`model.predict` on each batch, followed by the distance itself.\n",
    "\t\"\"\"\n",
    "\t# activations of the feature extractor for each batch\n",
    "\tfeats_a = model.predict(images1)\n",
    "\tfeats_b = model.predict(images2)\n",
    "\t# first moments (mean over axis 0)\n",
    "\tmu1 = feats_a.mean(axis=0)\n",
    "\tmu2 = feats_b.mean(axis=0)\n",
    "\t# second moments; rowvar=False treats each column as one variable\n",
    "\tsigma1 = cov(feats_a, rowvar=False)\n",
    "\tsigma2 = cov(feats_b, rowvar=False)\n",
    "\t# squared Euclidean distance between the two means\n",
    "\tmean_gap = mu1 - mu2\n",
    "\tssdiff = numpy.sum(mean_gap**2.0)\n",
    "\t# matrix square root of the covariance product; keep only the real part if it came back complex\n",
    "\tcovmean = sqrtm(sigma1.dot(sigma2))\n",
    "\tcovmean = covmean.real if iscomplexobj(covmean) else covmean\n",
    "\t# final score\n",
    "\tfid = ssdiff + trace(sigma1 + sigma2 - 2.0 * covmean)\n",
    "\treturn mu1, mu2, sigma1, sigma2, fid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _frechet_distance(mu1, sigma1, mu2, sigma2):\n",
    "    \"\"\"Frechet distance between two Gaussians described by (mean, covariance).\"\"\"\n",
    "    # squared difference between the means\n",
    "    ssdiff = numpy.sum((mu1 - mu2)**2.0)\n",
    "    # matrix square root of the covariance product\n",
    "    covmean = sqrtm(sigma1.dot(sigma2))\n",
    "    # sqrtm may return a complex matrix; only the real part is used\n",
    "    if iscomplexobj(covmean):\n",
    "        covmean = covmean.real\n",
    "    return ssdiff + trace(sigma1 + sigma2 - 2.0 * covmean)\n",
    "\n",
    "\n",
    "def _streamed_activation_stats(model, images, new_shape, batch_size):\n",
    "    \"\"\"Mean and unbiased covariance of `model.predict` over `images`, accumulated batch by batch.\n",
    "\n",
    "    Images are resized (nearest neighbour, order 0) and passed through\n",
    "    `preprocess_input` one batch at a time, so the full set of upscaled images is\n",
    "    never held in memory. Returns (mu, sigma) with mu of shape (1, d) and sigma of\n",
    "    shape (d, d), where d is the width of the model output.\n",
    "    \"\"\"\n",
    "    n = images.shape[0]\n",
    "    if n < 2:\n",
    "        # the covariance is divided by (n - 1); fail loudly instead of returning nan/inf\n",
    "        raise ValueError('at least two images are required to estimate a covariance, got %d' % n)\n",
    "\n",
    "    feature_sum = None\n",
    "    outer_sum = None\n",
    "    for start in range(0, n, batch_size):\n",
    "        batch = asarray([preprocess_input(resize(image, new_shape, 0))\n",
    "                         for image in images[start:start + batch_size]])\n",
    "        # accumulate in float64 to limit rounding error in the one-pass covariance\n",
    "        act = asarray(model.predict(batch), dtype='float64')\n",
    "        act = act.reshape(act.shape[0], -1)\n",
    "        if feature_sum is None:\n",
    "            # size the accumulators from the model output instead of hard-coding 2048\n",
    "            feature_sum = numpy.zeros((1, act.shape[1]))\n",
    "            outer_sum = numpy.zeros((act.shape[1], act.shape[1]))\n",
    "        feature_sum += act.sum(axis=0)\n",
    "        # sum of per-image outer products for the whole batch\n",
    "        outer_sum += act.T.dot(act)\n",
    "\n",
    "    mu = feature_sum / n\n",
    "    sigma = (outer_sum - n * numpy.outer(mu, mu)) / (n - 1.0)\n",
    "    return mu, sigma\n",
    "\n",
    "\n",
    "def scale_and_calculate_FID(model, images1, images2, new_shape=(299,299,3), batch_size=32):\n",
    "    \"\"\"Resize two image sets, run them through `model`, and return their FID.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    model : object with a `predict` method returning one feature row per image\n",
    "    images1, images2 : numpy arrays indexed by image along the first axis\n",
    "    new_shape : shape each image is resized to before `preprocess_input`\n",
    "    batch_size : number of images resized and predicted together (default 32);\n",
    "        only affects speed and memory, not the statistics being estimated\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The Frechet distance between the activation statistics of the two sets.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If `batch_size` is smaller than 1 or either set has fewer than two images.\n",
    "    \"\"\"\n",
    "    if batch_size < 1:\n",
    "        raise ValueError('batch_size must be at least 1, got %r' % (batch_size,))\n",
    "\n",
    "    images1 = images1.astype('float32')\n",
    "    images2 = images2.astype('float32')\n",
    "\n",
    "    mu1, sigma1 = _streamed_activation_stats(model, images1, new_shape, batch_size)\n",
    "    mu2, sigma2 = _streamed_activation_stats(model, images2, new_shape, batch_size)\n",
    "\n",
    "    return _frechet_distance(mu1, sigma1, mu2, sigma2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "# %load_ext line_profiler\n",
    "\n",
    "import random\n",
    "\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from keras.optimizers import Adam#, SGD\n",
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "from utils_CIFAR import *       # utils file has the filler code and helper functions\n",
    "\n",
    "from tqdm.notebook import tqdm\n",
    "from functools import partial"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# shape of a single image passed to the generator/discriminator builders (32x32 pixels, 3 channels)\n",
    "IMAGE_SHAPE = (32, 32, 3)\n",
    "# presumably the length of the generator's noise vector; not referenced elsewhere in this notebook -- TODO confirm against utils_CIFAR\n",
    "NOISE_SIZE = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "#set filter = False to include the entire CIFAR-10 dataset\n",
    "# NOTE(review): `load_data` is first imported from keras.datasets.mnist and is presumably\n",
    "# rebound by the later `from utils_CIFAR import *`; the `filter` keyword suggests the\n",
    "# utils_CIFAR version is the one intended here -- confirm.\n",
    "# Only the first of the four returned values is kept (as X); the other three are discarded.\n",
    "X, _, _, _ = load_data(filter=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adam optimizer instance passed to the generator, discriminator and combined-GAN builders.\n",
    "# NOTE(review): `keras` is never imported by name in this notebook -- presumably it is provided by\n",
    "# `from utils_CIFAR import *`, while the explicitly imported `Adam` goes unused. Confirm.\n",
    "adam_optimizer = keras.optimizers.Adam(lr=0.0002, beta_1=0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pre-bind the training data and the optimizer so the helpers can be passed around as plain callables.\n",
    "# discriminator update with X_train fixed to X and k=1\n",
    "take_discriminator_steps_2 = partial(take_discriminator_steps, X_train=X, k=1)\n",
    "# loss evaluation with X_train fixed to X\n",
    "getLoss2 = partial(getLoss, X_train=X)\n",
    "# GAN builder with opt fixed to the shared Adam optimizer\n",
    "create_gan2 = partial(create_gan, opt=adam_optimizer)\n",
    "\n",
    "def create_GAN_player():\n",
    "    \"\"\"Build and return a new `Players` object around a freshly created generator and discriminator.\n",
    "\n",
    "    Positional arguments handed to `Players`, in order: generator, discriminator,\n",
    "    GAN builder, generator step, discriminator step, `change_network` (twice),\n",
    "    and `perturb_generator`.\n",
    "    \"\"\"\n",
    "    # the generator is created before the discriminator; both use the shared optimizer\n",
    "    generator = create_generator(IMAGE_SHAPE, opt=adam_optimizer)\n",
    "    discriminator = create_discriminator(INPUT_SHAPE=IMAGE_SHAPE, opt=adam_optimizer)\n",
    "    return Players(\n",
    "        generator,\n",
    "        discriminator,\n",
    "        create_gan2,\n",
    "        take_generator_steps,\n",
    "        take_discriminator_steps_2,\n",
    "        change_network,\n",
    "        change_network,\n",
    "        perturb_generator,\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def training_gd(create_player_function, create_player_function2, T=50010, rate=2):\n",
    "    \"\"\"Train a GAN with alternating gradient steps and a periodic accept/reject test.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    create_player_function : callable\n",
    "        Zero-argument factory for the player object that is trained.\n",
    "    create_player_function2 : callable\n",
    "        Zero-argument factory for a second player object, used only to hold a\n",
    "        snapshot of both networks so that a rejected step can be undone.\n",
    "    T : int\n",
    "        Number of iterations (default 50010).\n",
    "    rate : int\n",
    "        The accept/reject test runs on every iteration j with j % rate != 0.\n",
    "        rate=1 never runs it (plain GDA); rate=2 runs it on every other iteration.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The trained player.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If `rate` is smaller than 1.\n",
    "\n",
    "    Side effects\n",
    "    ------------\n",
    "    Saves generated-image plots through `plot_generated_images`, and writes the loss\n",
    "    history and FID scores with `np.save` under the 'results_supplementary' folder.\n",
    "    Reads the notebook-level `model` and `train_images` for the FID evaluation.\n",
    "    \"\"\"\n",
    "    import os  # local import: only needed to make sure the output folder exists\n",
    "\n",
    "    if rate < 1:\n",
    "        # j % rate would raise ZeroDivisionError for 0 and is meaningless for negatives\n",
    "        raise ValueError('rate must be a positive integer, got %r' % (rate,))\n",
    "\n",
    "    # output locations, defined once up front so every branch below can rely on them\n",
    "    folder_name = 'results_supplementary'\n",
    "    filename = '/results'\n",
    "    os.makedirs(folder_name, exist_ok=True)\n",
    "\n",
    "    # `player` is trained; `player2` only stores a copy of the networks for roll-back\n",
    "    player = create_player_function()\n",
    "    player2 = create_player_function2()\n",
    "\n",
    "    Loss = []\n",
    "    FID_scores = []\n",
    "\n",
    "    # one dedicated axes for the loss curve; it is cleared and redrawn instead of\n",
    "    # stacking a new line on the current figure every 10 iterations\n",
    "    loss_fig, loss_ax = plt.subplots()\n",
    "\n",
    "    # one discriminator update before the main loop\n",
    "    player.update_y()\n",
    "\n",
    "    for j in tqdm(range(T)):\n",
    "\n",
    "        # whether this iteration is subject to the accept/reject test\n",
    "        test_this_step = (j % rate != 0)\n",
    "\n",
    "        if test_this_step:\n",
    "            print(\"\\nIteration \", j)\n",
    "\n",
    "            # snapshot both networks so the step can be rolled back\n",
    "            player2.change_x(player.get_x())\n",
    "            player2.change_y(player.get_y())\n",
    "\n",
    "        if j > 0:\n",
    "            # loss before this iteration's updates; test_this_step implies j > 0,\n",
    "            # so loss_old is always bound when the test below reads it\n",
    "            loss_old = player.value(getLoss2)\n",
    "            print(\"Old Loss: \", loss_old)\n",
    "            Loss.append(loss_old)\n",
    "\n",
    "        # perform one gradient update for the generator and k gradient updates for the\n",
    "        # discriminator (we only use \"k=1\" discriminator gradient steps for CIFAR)\n",
    "        player.update_x()\n",
    "        k = 1\n",
    "        for s in range(k):\n",
    "            player.update_y()\n",
    "\n",
    "        # Accept/reject step: undo the update if the loss went up\n",
    "        if test_this_step:\n",
    "            loss_new = player.value(getLoss2)\n",
    "\n",
    "            if loss_new > loss_old:\n",
    "                print(\"Reject\")\n",
    "                player.change_x(player2.get_x())\n",
    "                player.change_y(player2.get_y())\n",
    "            else:\n",
    "                print(\"Accept\")\n",
    "\n",
    "        # image snapshots: every 100 iterations up to 3000, then every 1000\n",
    "        if (j % 100 == 0 and j < 3001) or j % 1000 == 0:\n",
    "            loss = player.value(getLoss2)\n",
    "            print(\"Ending Loss:\", loss)\n",
    "            # plot the generated images\n",
    "            plot_generated_images(j,\n",
    "                                  player.get_x(),\n",
    "                                  folder=folder_name,\n",
    "                                  save=True,\n",
    "                                  image_shape=IMAGE_SHAPE,\n",
    "                                  name=filename+' %d.png')\n",
    "\n",
    "        # refresh the loss curve and persist the loss history every 10 iterations\n",
    "        if j > 0 and j % 10 == 0:\n",
    "            loss_ax.clear()\n",
    "            loss_ax.plot(Loss)\n",
    "            loss_ax.set_xlabel('iteration')\n",
    "            loss_ax.set_ylabel('loss')\n",
    "            np.save(folder_name + filename + '_loss_values', Loss)\n",
    "\n",
    "        # compute FID scores every 2500 iterations\n",
    "        if j > 0 and j % 2500 == 0:\n",
    "            FID_sample_size = 10000\n",
    "\n",
    "            # randint's upper bound is exclusive: use the full length so the last\n",
    "            # training image can be drawn too (sampling is with replacement)\n",
    "            real_indices = np.random.randint(train_images.shape[0], size=FID_sample_size)\n",
    "            images1_a = train_images[real_indices]\n",
    "\n",
    "            FakeImages = generate_fake_FID_image_input(generator=player.get_x(), examples=FID_sample_size, image_shape=IMAGE_SHAPE)\n",
    "            # assumes the generated images lie in [0, 1] so that 255* matches the real pixel range -- TODO confirm\n",
    "            images2_a = 255*FakeImages\n",
    "\n",
    "            start = time.time()\n",
    "            fid_a = scale_and_calculate_FID(model, images1_a, images2_a)\n",
    "            end = time.time()\n",
    "            # seconds spent on this FID evaluation\n",
    "            print(end - start)\n",
    "            FID_scores.append(fid_a)\n",
    "            print('FID_scores')\n",
    "            print(FID_scores)\n",
    "            np.save(folder_name + filename + '_FID_scores', FID_scores)\n",
    "\n",
    "    return player"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/ec2-user/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/keras/optimizers.py:790: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/ec2-user/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:3376: The name tf.log is deprecated. Please use tf.math.log instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/ec2-user/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/ops/nn_impl.py:183: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n",
      "WARNING:tensorflow:From /home/ec2-user/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:3445: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.\n",
      "WARNING:tensorflow:From /home/ec2-user/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:986: The name tf.assign_add is deprecated. Please use tf.compat.v1.assign_add instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/ec2-user/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:973: The name tf.assign is deprecated. Please use tf.compat.v1.assign instead.\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "998b0a7cd26e4bf08f13123c9885dc06",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=50010.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ending Loss: -1.2932832\n",
      "\n",
      "Iteration  1\n",
      "Old Loss:  -1.2905293\n",
      "Reject\n",
      "Old Loss:  -1.2895181\n",
      "\n",
      "Iteration  3\n",
      "Old Loss:  -1.2325437\n",
      "Reject\n",
      "Old Loss:  -1.2278309\n",
      "\n",
      "Iteration  5\n",
      "Old Loss:  -1.1441116\n",
      "Reject\n",
      "Old Loss:  -1.1467028\n",
      "\n",
      "Iteration  7\n",
      "Old Loss:  -1.0965085\n",
      "Accept\n",
      "Old Loss:  -1.1214663\n",
      "\n",
      "Iteration  9\n",
      "Old Loss:  -1.0599692\n",
      "Reject\n",
      "Old Loss:  -1.0832694\n",
      "\n",
      "Iteration  11\n",
      "Old Loss:  -0.9607671\n",
      "Reject\n",
      "Old Loss:  -0.93141186\n",
      "\n",
      "Iteration  13\n",
      "Old Loss:  -0.79709804\n",
      "Reject\n",
      "Old Loss:  -0.7663938\n",
      "\n",
      "Iteration  15\n",
      "Old Loss:  -0.77199113\n",
      "Reject\n",
      "Old Loss:  -0.7431572\n",
      "\n",
      "Iteration  17\n",
      "Old Loss:  -0.74580073\n",
      "Accept\n",
      "Old Loss:  -0.76233447\n",
      "\n",
      "Iteration  19\n",
      "Old Loss:  -0.78657377\n",
      "Accept\n",
      "Old Loss:  -0.81262165\n",
      "\n",
      "Iteration  21\n",
      "Old Loss:  -0.9446621\n",
      "Accept\n",
      "Old Loss:  -1.0516528\n",
      "\n",
      "Iteration  23\n",
      "Old Loss:  -0.99243045\n",
      "Reject\n",
      "Old Loss:  -0.95154834\n",
      "\n",
      "Iteration  25\n",
      "Old Loss:  -0.81944823\n",
      "Reject\n",
      "Old Loss:  -0.8163549\n",
      "\n",
      "Iteration  27\n",
      "Old Loss:  -0.70802903\n",
      "Reject\n",
      "Old Loss:  -0.7330277\n",
      "\n",
      "Iteration  29\n",
      "Old Loss:  -0.6686849\n",
      "Accept\n",
      "Old Loss:  -0.66517204\n",
      "\n",
      "Iteration  31\n",
      "Old Loss:  -0.7659133\n",
      "Accept\n",
      "Old Loss:  -0.84482\n",
      "\n",
      "Iteration  33\n",
      "Old Loss:  -0.8903661\n",
      "Reject\n",
      "Old Loss:  -0.9280906\n",
      "\n",
      "Iteration  35\n",
      "Old Loss:  -0.83057594\n",
      "Reject\n",
      "Old Loss:  -0.9515212\n",
      "\n",
      "Iteration  37\n",
      "Old Loss:  -0.8720163\n",
      "Accept\n",
      "Old Loss:  -0.93919945\n",
      "\n",
      "Iteration  39\n",
      "Old Loss:  -1.0037209\n",
      "Accept\n",
      "Old Loss:  -1.0683174\n",
      "\n",
      "Iteration  41\n",
      "Old Loss:  -1.0882444\n"
     ]
    }
   ],
   "source": [
    "# To profile instead, prefix the call below with: %lprun -f training_gd\n",
    "# The same factory builds both the trained player and the snapshot holder.\n",
    "training_gd(create_player_function=create_GAN_player, create_player_function2=create_GAN_player)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Environment (conda_tensorflow_p36)",
   "language": "python",
   "name": "conda_tensorflow_p36"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
