{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "from ipywidgets import *\n",
    "import warnings\n",
    "\n",
    "import torch\n",
    "from torchvision.utils import make_grid\n",
    "\n",
    "from models.semisup_vae import REVAE\n",
    "\n",
    "# Suppress every Python warning for the rest of the kernel session.\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "# Select matplotlib's inline backend so figures are embedded in the cell output.\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the model with its default constructor arguments.\n",
    "# NOTE(review): presumably REVAE() restores trained weights on its own - confirm in models.semisup_vae.\n",
    "revae = REVAE()\n",
    "# Load the object stored in ./data/celeba.pt.\n",
    "# NOTE: torch.load unpickles the file, so only point it at data files from a trusted source.\n",
    "data = torch.load('./data/celeba.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# draw 100 random indices (np.random.choice samples with replacement by default)\n",
    "# and gather the matching images into a 10-per-row grid\n",
    "sample_idx = np.random.choice(data.size(0), 100)\n",
    "batch = data[sample_idx]\n",
    "grid = make_grid(batch, nrow=10, padding=20, pad_value=1.)\n",
    "\n",
    "# draw the axes border in white and start from a cleared figure\n",
    "matplotlib.rc('axes', edgecolor=\"#ffffff\")\n",
    "plt.clf()\n",
    "fig = plt.figure(figsize=(14, 14))\n",
    "ax = fig.gca()\n",
    "\n",
    "# one tick per grid cell: columns are labelled 0-9 left to right,\n",
    "# rows are labelled 0-9 from top to bottom\n",
    "tick_positions = np.arange(0.48, 10, 1)\n",
    "ax.set_xticks(tick_positions)\n",
    "ax.set_yticks(tick_positions)\n",
    "ax.set_xticklabels(np.arange(10))\n",
    "ax.set_yticklabels(np.arange(10)[::-1])\n",
    "ax.tick_params(axis='both', which='both', labelsize=14,\n",
    "               left=False, right=False, bottom=False, top=False)\n",
    "\n",
    "# move the channel axis last for imshow and stretch the grid over a 10x10 extent\n",
    "plt.imshow(grid.permute(1, 2, 0), extent=[0, 10, 0, 10], origin='upper')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get image coordinate from user\n",
    "coords = input(\"Enter a comma seperated grid coordinate to obtain an image to intervene on, e.g 2,5: \")\n",
    "\n",
    "# parse 'row,col' by splitting on the comma instead of indexing single characters,\n",
    "# so input with spaces (e.g. '2, 5') is accepted and malformed input fails loudly\n",
    "# instead of silently selecting the wrong image\n",
    "try:\n",
    "    row, col = (int(part) for part in coords.split(','))\n",
    "except ValueError:\n",
    "    raise ValueError(f'Expected two comma separated integers such as 2,5 but got {coords!r}') from None\n",
    "\n",
    "# the batch is laid out as a 10x10 grid, so both coordinates must be in 0..9;\n",
    "# this also stops negative values from wrapping around to the end of the batch\n",
    "if not (0 <= row < 10 and 0 <= col < 10):\n",
    "    raise ValueError(f'Grid coordinates must lie between 0 and 9, got {row},{col}')\n",
    "\n",
    "# row-major lookup: 10 images per grid row\n",
    "img = batch[10 * row + col]\n",
    "\n",
    "# show selected image\n",
    "%matplotlib inline\n",
    "plt.clf()\n",
    "plt.figure(figsize=(2,2))\n",
    "plt.axis('off')\n",
    "print('Chosen Image:')\n",
    "plt.imshow(img.permute(1, 2, 0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "################################################################\n",
    "### You may need to run this cell twice before images appear ###\n",
    "################################################################\n",
    "\n",
    "\n",
    "# reconstruct the selected image once, without tracking gradients\n",
    "with torch.no_grad():\n",
    "    recon = revae.reconstruct_img(img.unsqueeze(0))[0]\n",
    "\n",
    "# switch to the interactive backend; the original goes on the left axes,\n",
    "# the reconstruction (redrawn by the sliders) on the right axes\n",
    "%matplotlib notebook\n",
    "fig = plt.figure(figsize = (10,4))\n",
    "ax1, ax2 = (fig.add_subplot(1, 2, position) for position in (1, 2))\n",
    "for axis in (ax1, ax2):\n",
    "    axis.axis('off')\n",
    "ax1.imshow(img.permute(1, 2, 0))\n",
    "im = ax2.imshow(recon.permute(1, 2, 0))\n",
    "\n",
    "# draw the starting latent value by sampling the object built from the encoder outputs\n",
    "z = revae._z_prior_fn(*revae.encoder_z(img.unsqueeze(0))).sample()\n",
    "\n",
    "# slider names, listed in the order of the first 18 latent entries they overwrite\n",
    "attribute_names = [\n",
    "    'Arched_Eyebrows', 'Bags_Under_Eyes', 'Bangs',\n",
    "    'Black_Hair', 'Blond_Hair', 'Brown_Hair',\n",
    "    'Bushy_Eyebrows', 'Chubby', 'Eyeglasses',\n",
    "    'Heavy_Makeup', 'Male', 'No_Beard',\n",
    "    'Pale_Skin', 'Receding_Hairline', 'Smiling',\n",
    "    'Wavy_Hair', 'Wearing_Necktie', 'Young',\n",
    "]\n",
    "\n",
    "# one slider per entry, bounded by revae.lims and initialised from the sampled latent\n",
    "spaced_vertical = Layout(display='flex', align_items='stretch', min_height='7ex', width='90%', margin='1ex')\n",
    "widgets = [\n",
    "    FloatSlider(min=revae.lims[i][0], max=revae.lims[i][1], step=1.0, value=z[0, i].item(),\n",
    "                layout=spaced_vertical, style={'description_width': '20%'})\n",
    "    for i in range(18)\n",
    "]\n",
    "\n",
    "\n",
    "def update(Arched_Eyebrows, Bags_Under_Eyes, Bangs,\n",
    "         Black_Hair, Blond_Hair, Brown_Hair,\n",
    "         Bushy_Eyebrows, Chubby, Eyeglasses,\n",
    "         Heavy_Makeup, Male, No_Beard,\n",
    "         Pale_Skin, Receding_Hairline, Smiling,\n",
    "         Wavy_Hair, Wearing_Necktie, Young):\n",
    "    \"\"\"\n",
    "    Slider callback: copy the 18 slider values into the first 18 entries of\n",
    "    the shared latent tensor ``z``, decode it, and redraw the right-hand image.\n",
    "    \"\"\"\n",
    "    slider_values = [\n",
    "        Arched_Eyebrows, Bags_Under_Eyes, Bangs,\n",
    "        Black_Hair, Blond_Hair, Brown_Hair,\n",
    "        Bushy_Eyebrows, Chubby, Eyeglasses,\n",
    "        Heavy_Makeup, Male, No_Beard,\n",
    "        Pale_Skin, Receding_Hairline, Smiling,\n",
    "        Wavy_Hair, Wearing_Necktie, Young,\n",
    "    ]\n",
    "    # in-place write: entries of z beyond the first 18 keep their sampled values\n",
    "    z[0, :18] = torch.tensor([slider_values])\n",
    "    with torch.no_grad():\n",
    "        decoded = revae.decoder(z).squeeze()\n",
    "    im.set_data(decoded.permute(1, 2, 0))\n",
    "    fig.canvas.draw()\n",
    "    fig.canvas.flush_events()\n",
    "\n",
    "\n",
    "# bind each slider to the update() parameter of the same name\n",
    "interact(update, **dict(zip(attribute_names, widgets)));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
