{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import tensorflow_hub as hub\n",
    "import os\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "def reload_encoder_decoder_dislib(result_path, image_shape, dataset, model_name, repetition):\n",
    "    \"\"\"\n",
    "    Reload a pre-trained encoder/decoder pair exported by disentanglement_lib as a TF-Hub module.\n",
    "\n",
    "    The module is loaded from\n",
    "    ``<result_path>/<dataset>_<model_name>_<repetition>/<model_name>/<model_name>/tfhub``.\n",
    "\n",
    "    Args:\n",
    "        result_path: Root directory containing the disentanglement_lib results.\n",
    "        image_shape: Shape of one input image, excluding the batch dimension.\n",
    "        dataset: Dataset name, used to build the run folder name.\n",
    "        model_name: Model name, used in the run folder name and its sub-folders.\n",
    "        repetition: Repetition index, used in the run folder name.\n",
    "\n",
    "    Returns:\n",
    "        (encoder, decoder): two tf.keras Models. The encoder maps images to the \"mean\"\n",
    "        output of the hub \"gaussian_encoder\" signature; the decoder maps latent codes\n",
    "        to the \"images\" output of the hub \"decoder\" signature.\n",
    "    \"\"\"\n",
    "    # Join path components with os.path.join rather than hard-coded Windows\n",
    "    # separators so the same code also works on Linux/macOS.\n",
    "    run_dir = \"{}_{}_{}\".format(dataset, model_name, repetition)\n",
    "    model_path = os.path.join(result_path, run_dir, model_name, model_name, \"tfhub\")\n",
    "\n",
    "    # Define the encoder: images -> mean output of the Gaussian encoder signature\n",
    "    input_layer = tf.keras.layers.Input(image_shape)\n",
    "    encoder_outputs = hub.KerasLayer(model_path, signature=\"gaussian_encoder\", signature_outputs_as_dict=True)(input_layer)\n",
    "    encoder = tf.keras.models.Model(input_layer, encoder_outputs[\"mean\"])\n",
    "\n",
    "    # Define the decoder: latent codes -> images.\n",
    "    # Keras Input expects a shape tuple, so wrap the latent size explicitly.\n",
    "    latent_dim = encoder_outputs[\"mean\"].shape[-1]\n",
    "    latent_input_layer = tf.keras.layers.Input((latent_dim,))\n",
    "    decoder_outputs = hub.KerasLayer(model_path, signature=\"decoder\", signature_outputs_as_dict=True)(latent_input_layer)\n",
    "    decoder = tf.keras.models.Model(latent_input_layer, decoder_outputs[\"images\"])\n",
    "    return encoder, decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the trained model. NOTE(review): result_path, image_shape, dataset, model_name and\n",
    "# repetition are not defined anywhere in this notebook -- define them (e.g. in a configuration\n",
    "# cell after the imports) before running this cell, otherwise it fails on a fresh kernel.\n",
    "encoder, decoder = reload_encoder_decoder_dislib(\n",
    "    result_path=result_path,\n",
    "    image_shape=image_shape,\n",
    "    dataset=dataset,\n",
    "    model_name=model_name,\n",
    "    repetition=repetition,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a batch of constant test images (every pixel set to one).\n",
    "# list(...) lets image_shape be either a list or a tuple; float32 matches the\n",
    "# default dtype of the Keras Input layer used by the encoder.\n",
    "images_per_batch = 10\n",
    "test_images = np.ones([images_per_batch] + list(image_shape), dtype=np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the test batch through the model: encode to latent codes, then decode them back\n",
    "latent_codes = encoder.predict(x=test_images)\n",
    "reconstructions = decoder.predict(x=latent_codes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x2814613e748>"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD7CAYAAACscuKmAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAThUlEQVR4nO3df5BW1XkH8O83CP6E6IIQIiiiRKW2gt2ghvxA0AQN0TYNqU5siWGyaWNanSYlazrTTDLjlCS2k8ykdkKMhcYES9WItZ0o3YTppLEoRBAQAX8gEDasgkTzQyL49I/3bvPwwPtj3733vS+c72eGeZ97z9n3Pezus/fce+49h2YGETn2vansBohIayjZRRKhZBdJhJJdJBFKdpFEKNlFEjGoZCc5m+Rmks+Q7M6rUSKSPzY7zk5yCIAtAK4EsBPA4wCuN7On8mueiOTluEF87TQAz5jZcwBA8h4A1wKomuwkdQePSMHMjEfaP5hu/BkAdrjtndk+EWlDgzmyH+mvx2FHbpJdALoG8TkikoPBJPtOAOPd9jgAu2IlM1sEYBGgbrxImQbTjX8cwCSSZ5McBuA6AA/m0ywRyVvTR3YzO0DyUwAeBjAEwF1mtjG3lolIrpoeemvqw9SNFylcEVfjReQoomQXSYSSXSQRSnaRRCjZRRKhZBdJhJJdJBFKdpFEKNlFEqFkF0mEkl0kEUp2kUQo2UUSoWQXSYSSXSQRSnaRRCjZRRKhZBdJhJJdJBFKdpFEKNlFEqFkF0mEkl0kEUp2kUQo2UUSUTfZSd5Fso/kBrevg+QKkluz19OKbaaIDFYjR/bFAGaHfd0AesxsEoCebFtE2ljdZDez/wawN+y+FsCSLF4C4A9ybpeI5KzZc/YxZtYLANnr6PyaJCJFaHrJ5kaR7ALQVfTniEhtzR7Zd5McCwDZa1+1ima2yMw6zayzyc8SkRw0m+wPApiXxfMALM+nOSJSFJpZ7QrkUgAzAIwCsBvA5wE8AGAZgDMBbAcw18ziRbwjvVftDxORQTMzHml/3WTPk5JdpHjVkl130IkkQskukgglu0giCh9nlzbl/8z/YSj7uIsXuvj4UG+ti3fn0Sgpko7sIolQsoskQskukgiNs6fi7WF7rIvfX6Pusy5+JdQb5uJfh7ItLr69buskRxpnF0mckl0kEerGpyJOL3KNi/84lL3s4jNqvOcOF78Qys5y8Usu/lSo9+Ma7y9NUTdeJHFKdpFE6A66VDwQtn/fxfFK+osu3uPicaHecVViANjsYn91/39CvY+4+LuQAunILpIIJbtIIpTsIonQ0JsAnwjbN7j4+y6OT715c8P2f7jYD+XFtYNudPE5oWxfjc+TqjT0JpI4JbtIItSNl8PNd/GZLt4V6vnhtimh7Dcu3uniOMznVxGME2DMgzRB3XiRxCnZRRKhZBdJhM7ZpbbrXBwnubjQxXFI7c0u9pNcxOG03ir1gEOfiPuzag2UqOlzdpLjSf6Q5CaSG0nenO3vILmC5NbsNf64RaSNNNKNPwDg02Z2AYBLAdxEcjKAbgA9ZjYJQE+2LSJtasDdeJLLAXw9+zfDzHqzZZtXmtl5db5W3fij2fCwfZWLvxzK/DDdSBfHp+P2u/jZUHaii313/0+qNVCAnIbeSE4AMBXAKgBjzKw3e/NeAKMH10QRKVLDz7OTPAXAfQBuMbNXyCP+8TjS13UB6GqueSKSl4aO7CSHopLo3zGz+7Pdu7PuO7LXviN9rZktMrNOM+vMo8Ei0py65+ysHMKXANhrZre4/V8BsMfMFpLsBtBhZgvqvJfO2Y9VHwrb/jDib4M9P9TzT70tDmVPudjftnsw1Ftar3FpqXbO3kg3fjoql0TWk+xfyu9zqCz5t4zkfADbcfhDjiLSRuomu5n9CEC1E/RZ+TZHRIqiO+ikXH/q4hdDmb8K5J+WuyTU2+Di7aEswaWk9dSb
SOKU7CKJ0LzxUry3ujhetX+bi38Tyt5w8cVV9sftiaHML221ploD06Aju0gilOwiiVCyiyRCQ2/SWqPC9jtcvC2U+fPtK1x8Uqj3vIsvDmV+yel3ufhHVdp3DNDQm0jilOwiiVA3Xsrlu+S/qlHvL1w8P5T9m4t/Gcr8acP7XPx2HLPUjRdJnJJdJBFKdpFE6JxdyuUnpYhPrHknu/i2UPaai+N6dL9wsT9nj5Nbfq7GZx9ldM4ukjglu0gi9NSblOs9Ln4hlA118btd/I5Q7ywXx8PXXhdvcfH7Qr0OFx+jS03pyC6SCCW7SCJ0Nb4Ib3Hxz0prRW3Xuzjedeb5qZ8/Fsp+4uK4OuujLv52jff3k1f8PJT5+Yp9dzx2s/2dd+eGMn8q4JeQiqsc+J/Z/aFsIY4quhovkjglu0gilOwiidDQ20BMdfG1Lv7fUM/f0fU7ocyfAz8eyh5rsl3VLHbxk6HMXz15ayjz5+Z3uPi+UO8UF18fyuL3pJotNcq+XmX/gbDtz1BfC2XTXeyXmF4d6vlz+LWh7FQXx2sTR5G6R3aSJ5B8jOQ6khtJfiHb30FyBcmt2etpxTdXRJrVSDd+P4CZZnYRgCkAZpO8FEA3gB4zmwSgJ9sWkTbVyFpvht8+TjA0+2eodGRnZPuXAFgJ4LO5t7BoF4TtOS6O84z7hzZOd/F7Qj2/jFGcL+0cF8cu/ngX+67pOaHeJhc/HMouc/FkF8d+lx8aez2U+dOLSS6Ow4gjXPxEKLvcxf+CfP06bPu54neGMv+z8D/PEaGeP+zFxcUnuPieeo1rX42uzz4kW8G1D8AKM1sFYIyZ9QJA9jq6uGaKyGA1lOxmdtDMpgAYB2AayQsb/QCSXSRXk4yXRESkhQY09GZm+1Dprs8GsJvkWADIXuM9Sf1fs8jMOs0sdo5EpIXq3i5L8nQAr5vZPpInAngEwJdQOVPdY2YLSXYD6DCzBXXeqz1ul/V/dr4Uyvx5dLx900+E4Idgzg/1/NBQXL/M39r5tlDmP88vNfxyqOeHq3aEMj/U5P8v42q08aVQ5n9KY1GdX2Mtnivvd/HdLv5ijfeLjnfxR138wVDP9zPj0Ji/ZuL/zxNCPf91cZnnV118TSiLy0y3gWq3yzYyzj4WwBKSQ1DpCSwzs4dIPgpgGcn5qMwxMrfWm4hIuRq5Gv8kDr2dpH//HgCzimiUiOQvzTvo/J+os0PZVhfHYTl/x5h/mur5UM8PBR0fyg66OHbBz6xSLw6N+S5zbL8/NfBfF7vZfujwhFDmu7S+i98R6vnubfxN8hNRxFOZRvm54t/v4reEev5nNj6U+f+b/7/EOejGuDgOl653cRt22xule+NFEqFkF0mEJq/4RNj2XdO44qi/xvkBF98V6vkHJ4aGsj9y8b2hbLiLh7n4w6He1S6Od5P5so+4+L9CPd/9HxnKZrrY36H3RqjnV0J9Vyhb6uIH0ZghYdt/73yX/qZQ799dHE95/K1e/q6+eCecH/HYGso2ungD2p4mrxBJnJJdJBFKdpFEpDn05n0jbPuhmwmhzE96uNLF60I9f+55MJT5c8qnQ5mfeMFfO4h/kv1PLT6J5q8f+HPxeDOzHxIcHsr8E2xvdvF/hnp3VImbFb9Xfttf35gc6vlrJLXuflvl4vg99d+ff63WwKObjuwiiVCyiyRCQ2+RH/Kq1a0sgr+rzd+gHO9ca2YChTg05k8F/jKU+a7wZhfH05UfNNGOWq4I2xe52A8VxqWb4lzxnv+N+6mLXw31/B11H8BRTUNvIolTsoskQskukgids0u5/MQfe0PZx13sJ92ME0g85+I42Ya/NvGKi9eHere7uNE579uUztlFEqdkF0mEuvFSrk+6OD496O8o/HMX/3Wo5+el/0Uo80OpV7r4Q6HenmoNPPqoGy+SOCW7SCL0IIyUy89Pd0ko80tULXdxXKLEP9QTu/F+UopvuvgY6rY3Skd2kUQo2UUSoWQXSYTO2aV4/rfs
86Hsd10cB4z8nO9+QpC4lJV/WvBAKPMTcn6mWgPT0PCRPVu2+QmSD2XbHSRXkNyavcYVwEWkjQykG38zgE1uuxtAj5lNAtCTbYtIm2roDjqS4wAsAXAbgL8yszkkNwOYYWa92ZLNK83svDrvozvoUuGH0b7m4jjfnR822xLK/IQVJ7o4dtX98FrsX/rhtrhi7zFqsHfQfRXAAhy6TMAYM+vN3rwXh07HLyJtpm6yk5wDoM/M1jTzASS7SK4mubqZrxeRfDRyNX46gGtIXo3K9dERJO8GsJvkWNeNj5MVAwDMbBGARYC68SJlGtBTbyRnAPhMds7+FQB7zGwhyW4AHWa2oM7XK9mPVVeF7TNc7IfGzg/13uviO0PZyS72677FCScXu3h7KLsbySniqbeFAK4kuRWVhwcXDuK9RKRgA7qpxsxWIlsLxcz2AJiVf5NEpAi6g07ycUPY3u/iJ13cG+r5J9tGhLKXq8QnhXp+uaZnIVXo3niRRCjZRRKhOeikcR8N2/6KzQWhzE8w4eeS24/qYpmf2MJfZf9mqLe0xnsmSHPQiSROyS6SCCW7SCJ0zi61vdPF8Q43P3Abn0Tz598jXfxaqOfPLn8cyra52D+xFieVlEPonF0kcUp2kUToDjo5nJ9QosPFcXmmqS6eHcq+7eKfuzj+xs1x8R2h7FFIjnRkF0mEkl0kEUp2kURo6E0O5+d2P9fF8YmyMS4eFcr2ufinLj4Y6v3KxX/fUOukDg29iSROyS6SCA29CfDJsD3TxdNcHO9+O1ijzM8P7yel2BrqfbZu6yQnOrKLJELJLpIIXY1P1RddfGMo81fI97r4zFDP/zTjgzDPuNgvyfTBUO+Fag2UZulqvEjilOwiiVCyiyRCQ2+p+jsXHx/K/HLLfvmkU0K9E1x8bii73cVPD6xpUoyGkp3kNgCvojKyesDMOkl2oDI9/wRU5hT5sJm9XO09RKRcA+nGX25mU8ysM9vuBtBjZpMA9GTbItKmGhp6y47snWb2ktu3GcAMt2TzSjM7r877aOitHQ0P2x9zsR9Ciyukri+mOTI4gx16MwCPkFxDsivbN8bMerM37wUwevDNFJGiNHqBbrqZ7SI5GsAKkg1fcsn+OHTVrSgihWroyG5mu7LXPgDfQ+XxiN1Z9x3Za1+Vr11kZp3uXF9ESlD3nJ3kyQDeZGavZvEKVG62nAVgj5ktJNkNoMPMFtR5L52zixSs2jl7I8k+EZWjOVDp9n/XzG4jORLAMlTumN4OYK6Z7a3yNv3vpWQXKVjTyZ4nJbtI8fQgjEjilOwiiVCyiyRCyS6SCCW7SCKU7CKJULKLJELJLpIIJbtIIpTsIolQsoskQskukgglu0gilOwiiVCyiyRCyS6SCCW7SCKU7CKJULKLJELJLpIIJbtIIpTsIolQsoskQskukgglu0giGkp2kqeSvJfk0yQ3kbyMZAfJFSS3Zq+nFd1YEWleo0f2rwH4vpmdD+AiAJsAdAPoMbNJAHqybRFpU40s7DgCwDoAE81VJrkZwAwz682WbF5pZufVeS+t9SZSsMGs9TYRwIsA/pnkEyTvzJZuHmNmvdmb9wIYnVtrRSR3jST7cQAuBvBPZjYVwC8xgC47yS6Sq0mubrKNIpKDRpJ9J4CdZrYq274XleTfnXXfkb32HemLzWyRmXWaWWceDRaR5tRNdjP7GYAdJPvPx2cBeArAgwDmZfvmAVheSAtFJBd1L9ABAMkpAO4EMAzAcwBuROUPxTIAZwLYDmCume2t8z66QCdSsGoX6BpK9rwo2UWKN5ir8SJyDFCyiyRCyS6SCCW7SCKU7CKJULKLJELJLpKI41r8eS8BeAHAqCwum9pxKLXjUO3QjoG24axqBS29qeb/P5Rc3Q73yqsdake7tyPPNqgbL5IIJbtIIspK9kUlfW6kdhxK7ThUO7QjtzaUcs4uIq2nbrxIIlqa7CRnk9xM8hmSLZuNluRdJPtIbnD7Wj4VNsnxJH+Y
Tce9keTNZbSF5AkkHyO5LmvHF8poh2vPkGx+w4fKagfJbSTXk1zbP4VaSe0obNr2liU7ySEA/hHAVQAmA7ie5OQWffxiALPDvjKmwj4A4NNmdgGASwHclH0PWt2W/QBmmtlFAKYAmE3y0hLa0e9mVKYn71dWOy43syluqKuMdhQ3bbuZteQfgMsAPOy2bwVwaws/fwKADW57M4CxWTwWwOZWtcW1YTmAK8tsC4CTAPwEwCVltAPAuOwXeCaAh8r62QDYBmBU2NfSdgAYAeB5ZNfS8m5HK7vxZwDY4bZ3ZvvKUupU2CQnAJgKYFUZbcm6zmtRmSh0hVUmFC3je/JVAAsAvOH2ldEOA/AIyTUku0pqR6HTtrcy2Y80VU6SQwEkTwFwH4BbzOyVMtpgZgfNbAoqR9ZpJC9sdRtIzgHQZ2ZrWv3ZRzDdzC5G5TTzJpLvLqENg5q2vZ5WJvtOAOPd9jgAu1r4+VFDU2HnjeRQVBL9O2Z2f5ltAQAz2wdgJSrXNFrdjukAriG5DcA9AGaSvLuEdsDMdmWvfQC+B2BaCe0Y1LTt9bQy2R8HMInk2SSHAbgOlemoy9LyqbBJEsC3AGwys38oqy0kTyd5ahafCOAKAE+3uh1mdquZjTOzCaj8PvzAzG5odTtInkxyeH8M4L0ANrS6HVb0tO1FX/gIFxquBrAFwLMA/qaFn7sUQC+A11H56zkfwEhULgxtzV47WtCOd6Jy6vIkgLXZv6tb3RYAvwfgiawdGwD8bba/5d8T16YZ+O0FulZ/Pyaisp7hOgAb+383S/odmQJgdfazeQDAaXm1Q3fQiSRCd9CJJELJLpIIJbtIIpTsIolQsoskQskukgglu0gilOwiifg/qRFJ1tPoJKAAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Plot one reconstruction. Values are clipped to [0, 1] explicitly (imshow would\n",
    "# otherwise clip float RGB data itself and emit a warning); single-channel images are\n",
    "# squeezed to 2-D because imshow rejects an (H, W, 1) array.\n",
    "# NOTE(review): if the decoder returns logits rather than pixel intensities, apply a\n",
    "# sigmoid instead of clipping -- confirm against the exported model.\n",
    "image_index = 2\n",
    "image = np.clip(reconstructions[image_index], 0.0, 1.0)\n",
    "if image.ndim == 3 and image.shape[-1] == 1:\n",
    "    image = image[..., 0]\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(image, cmap=\"gray\" if image.ndim == 2 else None)\n",
    "ax.set_title(\"Reconstruction of test image {}\".format(image_index))\n",
    "ax.axis(\"off\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
