{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-09-25 20:13:07.250520: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "# choose a particular GPU\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "from tensorflow.keras.layers import Dense, Flatten\n",
    "from tensorflow.keras import Model\n",
    "from tensorflow.keras.utils import plot_model\n",
    "\n",
    "import tensorflow_probability as tfp\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import random\n",
    "\n",
    "seed = 1234\n",
    "tf.random.set_seed(seed)\n",
    "os.environ['TF_DETERMINISTIC_OPS'] = 'true'\n",
    "os.environ['PYTHONHASHSEED'] = f'{seed}'\n",
    "np.random.seed(seed)\n",
    "random.seed(seed)\n",
    "\n",
    "physical_devices = tf.config.list_physical_devices('GPU') \n",
    "tf.config.experimental.set_memory_growth(physical_devices[0], True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "stop = tf.stop_gradient\n",
    "log1mexp = tfp.math.log1mexp\n",
    "\n",
    "@tf.function\n",
    "def log_sigmoid(logits):\n",
    "    return tf.clip_by_value(tf.math.log_sigmoid(logits), clip_value_max=-1e-7, clip_value_min=-float('inf'))\n",
    "\n",
    "@tf.function\n",
    "def logaddexp(x1, x2):\n",
    "    delta = tf.where(x1 == x2, 0., x1 - x2)\n",
    "    return tf.math.maximum(x1, x2) + tf.math.softplus(-tf.math.abs(delta))\n",
    "\n",
     "@tf.function\n",
     "def topk_marginals(logprobs, k):\n",
     "    \"\"\"Log-space dynamic program over the count of 'on' variables, plus its gradient.\n",
     "\n",
     "    Treats each column of `logprobs` as an independent Bernoulli with log P(x_i = 1) =\n",
     "    logprobs[:, i] and tracks, for every prefix of variables, the log-probability of each\n",
     "    possible count. Static shapes are required: the initial state is built with NumPy\n",
     "    from logprobs.shape.\n",
     "\n",
     "    Args:\n",
     "        logprobs: log-probabilities, indexed as [batch, n].\n",
     "        k: target count.\n",
     "\n",
     "    Returns:\n",
     "        marginals: gradient of a[:, n, k+1] = log P(sum_i x_i = k) w.r.t. logprobs, same\n",
     "            shape as logprobs. The log(1 - p) branch is under stop_gradient, so this is\n",
     "            presumably P(x_i = 1 | sum = k) - TODO confirm.\n",
     "        a: DP table of shape (batch, n+1, k+2) with a[:, i, c] = log P(count of the first\n",
     "            i variables = c - 1); column 0 is an all -inf padding column.\n",
     "    \"\"\"\n",
     "    batch_size = logprobs.shape[0]\n",
     "    n = logprobs.shape[1]\n",
     "    \n",
     "    # Before any variable is seen the count is 0 with probability 1: log 1 = 0 in column 1.\n",
     "    state = np.ones((batch_size, k+2)) * -float('inf')\n",
     "    state[:, 1] = 0\n",
     "    state = tf.convert_to_tensor(state, dtype=tf.float32)\n",
     "\n",
     "    a = tf.TensorArray(tf.float32, size=n+1)\n",
     "    a = a.write(0, state)\n",
     "    \n",
     "    # Python loop: unrolled at trace time, one step per variable.\n",
     "    for i in range(1, n+1):\n",
     "        \n",
     "        # New column c (count c-1) = (count c-2 and x_i = 1) or (count c-1 and x_i = 0);\n",
     "        # column 0 is re-padded with -inf.\n",
     "        state = tf.concat([\n",
     "            tf.ones([batch_size, 1]) * -float('inf'), \n",
     "            logaddexp(\n",
     "                state[:, :-1] + logprobs[:, i-1:i], \n",
     "                state[:, 1:] + log1mexp(stop(logprobs[:, i-1:i]))\n",
     "            )\n",
     "        ], 1)\n",
     "        \n",
     "        a = a.write(i, state)\n",
     "    # (n+1, batch, k+2) -> (batch, n+1, k+2)\n",
     "    a = tf.transpose(a.stack(), perm=[1, 0, 2])\n",
     "    # tf.gradients is graph-mode only, which @tf.function provides.\n",
     "    marginals = tf.gradients(a[:, n, k+1:k+2], logprobs)\n",
     "    \n",
     "    return marginals[0], a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
     "@tf.function\n",
     "def sample(a, probs):\n",
     "    \"\"\"Draw one {0,1}^n sample per batch row from the count-constrained distribution in `a`.\n",
     "\n",
     "    Walks the DP table backwards (i = n .. 1). At each step it reads from `a` the\n",
     "    log-probability that x_i = 1 given the current count, draws a Bernoulli, and\n",
     "    decrements the count when x_i = 1.\n",
     "\n",
     "    Args:\n",
     "        a: DP table as returned by topk_marginals, shape (batch, n+1, k+2).\n",
     "        probs: per-variable log-probabilities log P(x_i = 1), shape (batch, n). They are\n",
     "            added to log-space entries of `a`, so despite the name they must be logs.\n",
     "\n",
     "    Returns:\n",
     "        float32 tensor of shape (batch, n). The assertion checks that each row sums to the\n",
     "        count encoded by the last column of `a`.\n",
     "    \"\"\"\n",
     "    \n",
     "    n = a.shape[-2] - 1\n",
     "    # Note: this local k is (target count + 1), i.e. the index of the last column of `a`.\n",
     "    k = a.shape[-1] - 1\n",
     "    bsz = a.shape[0]\n",
     "\n",
     "    # j is the current column of `a`; column j corresponds to a running count of j - 1.\n",
     "    j = tf.fill((bsz,), k)\n",
     "    samples = tf.TensorArray(tf.int32, size=n, clear_after_read=False)\n",
     "    for i in tf.range(n, 0, -1):\n",
     "        \n",
     "        # Gather a[i-1, j-1] (previous prefix, one fewer 'on') and a[i, j] (current prefix).\n",
     "        full = tf.fill((bsz,), i-1)\n",
     "        p_idx = tf.stack([full, j-1], axis=1)\n",
     "        z_idx = tf.stack([full + 1, j], axis=1)\n",
     "        \n",
     "        p = tf.gather_nd(batch_dims=1, indices=p_idx, params=a)\n",
     "        z = tf.gather_nd(batch_dims=1, indices=z_idx, params=a)\n",
     "        \n",
     "        # p: log P(x_i = 1 | current count); q: log P(x_i = 0 | current count)\n",
     "        p = (p + probs[:, i-1]) - z\n",
     "        q = log1mexp(p)\n",
     "\n",
     "        # Bernoulli with logit log(P1 / P0)\n",
     "        X = tfp.distributions.Bernoulli(logits=(p-q)).sample()\n",
     "\n",
     "        # Move to the predecessor column: one fewer 'on' variable remains when X = 1\n",
     "        j = tf.where(X>0, j - 1, j)\n",
     "\n",
     "        # Store x_i at position i-1\n",
     "        samples = samples.write(i-1, X)\n",
     "        \n",
     "    # (n, batch) -> (batch, n)\n",
     "    samples = tf.transpose(samples.stack(), perm=[1, 0])\n",
     "    \n",
     "    # Every sample must contain exactly k - 1 ones (the target count).\n",
     "    tf.debugging.assert_equal(tf.math.reduce_sum(samples, axis=-1), k-1)\n",
     "    \n",
     "    return tf.cast(samples, tf.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def xexpx(x):\n",
    "    expx = tf.exp(x)\n",
    "    return tf.where(expx == 0, expx, x*expx)\n",
    "\n",
    "@tf.function\n",
    "def xexpy(x,y):\n",
    "    expy = tf.exp(y)\n",
    "    return tf.where(expy == 0, expy, x*expy)\n",
    "\n",
    "@tf.function\n",
    "def entropy(a, logprobs):\n",
    "    entropy = tf.zeros((a.shape[0], a.shape[-1]))\n",
    "    for i in range(10, a.shape[-2]):\n",
    "        \n",
    "        p_left = (a[:, i-1, :-1] + logprobs[:, i-1:i]) - a[:, i, 1:]\n",
    "        p_right = (a[:, i-1, 1:] + log1mexp(logprobs[:, i-1:i])) - a[:, i, 1:]\n",
    "        \n",
    "        entropy = tf.concat([tf.zeros((a.shape[0], 1)),\n",
    "                             xexpx(p_left) + xexpx(p_right) +\\\n",
    "                             xexpy(entropy[:, :-1], p_left) + xexpy(entropy[:, 1:], p_right)\n",
    "                            ], 1)\n",
    "    return tf.clip_by_value(-entropy[:, -1], clip_value_max=float('inf'), clip_value_min=0)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class IMLESubsetkLayer(tf.keras.layers.Layer):\n",
    "  \n",
    "    def __init__(self, _k=1, _tau=10.0, _lambda=10.0):\n",
    "        super(IMLESubsetkLayer, self).__init__()\n",
    "        \n",
    "        self.k = _k\n",
    "        self._tau = _tau\n",
    "        self._lambda = _lambda\n",
    "        self.samples = None\n",
    "        self.gumbel_dist = tfp.distributions.Gumbel(loc=0.0, scale=1.0)\n",
    "        \n",
    "    @tf.function\n",
    "    def sample_gumbel(self, shape, eps=1e-20):\n",
    "        return self.gumbel_dist.sample(shape)\n",
    "    \n",
    "    @tf.function\n",
    "    def sample_gumbel_k(self, shape):\n",
    "        \n",
    "        s = tf.map_fn(fn=lambda t: tf.random.gamma(shape, 1.0/self.k,  t/self.k), \n",
    "                  elems=tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]))\n",
    "        # now add the samples\n",
    "        s = tf.reduce_sum(s, 0)\n",
    "        # the log(m) term\n",
    "        s = s - tf.math.log(10.0)\n",
    "        # divide by k --> each s[c] has k samples whose sum is distributed as Gumbel(0, 1)\n",
    "        s = self._tau * (s / self.k)\n",
    "\n",
    "        return s\n",
    "       \n",
    "    def imle_layer(self, logits, hard=False, evaluate=False):        \n",
    "        return gumbel_softmax(logits)\n",
    "\n",
    "    def call(self, logits, hard=True, evaluate=False):\n",
    "        return self.imle_layer(logits, hard, evaluate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
     "PARAMS = {\n",
     "    # samples per batch (keep equal to the training cell's BATCH_SIZE)\n",
     "    \"batch_size\": 100,\n",
     "    # flattened input size: 28 * 28 MNIST pixels\n",
     "    \"data_dim\": 784,\n",
     "    # classes per categorical latent variable\n",
     "    \"M\": 20,\n",
     "    # number of categorical latent variables\n",
     "    \"N\": 20,\n",
     "    # number of training epochs\n",
     "    \"nb_epoch\": 100, \n",
     "    # NOTE(review): not read by any code shown in this notebook\n",
     "    \"epsilon_std\": 0.01,\n",
     "    # Gumbel-softmax temperature schedule: decay rate, starting value, floor\n",
     "    \"anneal_rate\": 0.0003,\n",
     "    \"init_temperature\": 1.0,\n",
     "    \"min_temperature\": 0.5,\n",
     "    # base learning rate (also the default `init` of get_learning_rate)\n",
     "    \"learning_rate\": 1e-3,\n",
     "    # passed as `hard` to compute_gradients during training; also picks the checkpoint name\n",
     "    \"hard\": False,\n",
     "}\n",
    "\n",
    "class DiscreteVAE(tf.keras.Model):\n",
    "    \n",
    "    def __init__(self, params):\n",
    "        super(DiscreteVAE, self).__init__()\n",
    "        \n",
    "        self.params = params\n",
    "                \n",
    "        # encoder\n",
    "        self.enc_dense1 = tf.keras.layers.Dense(512, activation='relu')\n",
    "        self.enc_dense2 = tf.keras.layers.Dense(256, activation='relu')\n",
    "        self.enc_dense3 = tf.keras.layers.Dense(params[\"N\"]*params[\"M\"])\n",
    "        \n",
    "        # this is our new Gumbel layer\n",
    "        self.imleLayer = IMLESubsetkLayer(_k=1, _tau=10.0, _lambda=10.0)\n",
    "\n",
    "        # decoder\n",
    "        self.flatten = Flatten()\n",
    "        self.dec_dense1 = tf.keras.layers.Dense(256, activation='relu')\n",
    "        self.dec_dense2 = tf.keras.layers.Dense(512, activation='relu')\n",
    "        self.dec_dense3 = tf.keras.layers.Dense(params[\"data_dim\"])\n",
    "\n",
    "\n",
    "    def sample_gumbel(self, shape, eps=1e-20): \n",
    "        \"\"\"Sample from Gumbel(0, 1)\"\"\" \n",
    "        U = tf.random.uniform(shape, minval=0, maxval=1)\n",
    "        return -tf.math.log(-tf.math.log(U + eps) + eps)\n",
    "    \n",
    "    def gumbel_softmax_sample(self, logits, temperature): \n",
    "        \"\"\" Draw a sample from the Gumbel-Softmax distribution\"\"\"\n",
    "        # logits: [batch_size, n_class] unnormalized log-probs\n",
    "        y = logits + self.sample_gumbel(tf.shape(logits))\n",
    "        return tf.nn.softmax(y / temperature)  \n",
    "\n",
    "    def gumbel_softmax(self, logits, temperature, hard=True):\n",
    "        \"\"\"\n",
    "        logits: [batch_size, n_class] unnormalized log-probs\n",
    "        temperature: non-negative scalar\n",
    "        hard: if True, take argmax, but differentiate w.r.t. soft sample y\n",
    "        \"\"\"\n",
    "        y = self.gumbel_softmax_sample(logits, temperature)\n",
    "        if hard: \n",
    "            # \n",
    "            y_hard = tf.cast(tf.equal(y, tf.reduce_max(y, 1, keepdims=True)),y.dtype)\n",
    "            y = tf.stop_gradient(y_hard - y) + y\n",
    "        return y\n",
    "    \n",
    "    def SIMPLE(self, logits, temperature, hard=True):\n",
    "        \"\"\"\n",
    "        input: [*, n_class]\n",
    "        return: [*, n_class] an one-hot vector\n",
    "        \"\"\"\n",
    "        y = tf.nn.softmax(logits, dim=-1)\n",
    "        y_perturbed = tf.nn.softmax(logits + self.sample_gumbel(tf.shape(logits)))\n",
    "        y_hard = tf.cast(tf.equal(y_perturbed, tf.reduce_max(y_perturbed, 1, keepdims=True)),y_perturbed.dtype)\n",
    "        y = tf.stop_gradient(y_hard - y) + y\n",
    "        return y\n",
    "    \n",
    "    \n",
    "    def decoder(self, x):\n",
    "        # decoder\n",
    "        h = self.flatten(x)\n",
    "        h = self.dec_dense1(h)\n",
    "        h = self.dec_dense2(h)\n",
    "        h = self.dec_dense3(h)\n",
    "        return h\n",
    "\n",
    "    def call(self, x, tau, hard=False, evaluate=False):\n",
    "        N = self.params[\"N\"]\n",
    "        M = self.params[\"M\"]\n",
    "\n",
    "        # encoder\n",
    "        x = self.enc_dense1(x)\n",
    "        x = self.enc_dense2(x)\n",
    "        x = self.enc_dense3(x)   # (batch, N*M)\n",
    "        logits_y = tf.reshape(x, [-1, M])   # (batch*N, M)\n",
    "\n",
    "        ###################################################################\n",
    "        ## here we toggle between methods #################################\n",
    "        # here we can switch between traditional and our method\n",
    "        # \"traditional\" Gumbel Softmax trick\n",
    "        y = self.gumbel_softmax(logits=logits_y, temperature=1.0, hard=True)\n",
    "#         y = self.SIMPLE(logits=logits_y, temperature=tau, hard=True)\n",
    "        # IMLE approach -- note: we don't anneal so set temperature once at init\n",
    "#         y = self.imleLayer(logits=logits_y, hard=True, evaluate=evaluate)\n",
    "        ###################################################################\n",
    "        \n",
    "        assert y.shape == (self.params[\"batch_size\"]*N, M)\n",
    "        y = tf.reshape(y, [-1, N, M])\n",
    "        self.sample_y = y\n",
    "\n",
    "        # decoder\n",
    "        logits_x = self.decoder(y)\n",
    "        return logits_y, logits_x\n",
    "\n",
    "\n",
    "def gumbel_loss(model, x, tau, hard=True, evaluate=False):\n",
    "    M = 20\n",
    "    N = 20\n",
    "    data_dim = PARAMS['data_dim']\n",
    "    logits_y, logits_x = model(x, tau, hard, evaluate)\n",
    "    \n",
    "    # cross-entropy\n",
    "    cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=logits_x)\n",
    "    cross_ent = tf.math.reduce_sum(cross_ent, 1)\n",
    "    cross_ent = tf.math.reduce_mean(cross_ent, 0)\n",
    "    \n",
    "    if evaluate:\n",
    "        # KL loss\n",
    "        logprobs_q = log_sigmoid(logits_y)\n",
    "        marginals_q, a_q = topk_marginals(logprobs_q, 10)\n",
    "        a_q = tf.where(a_q == -float('inf'), -1000., a_q)\n",
    "        q_entropy = entropy(a_q, logprobs_q)\n",
    "        kl = tf.math.log(184756.) - tf.reshape(q_entropy, [-1,N])\n",
    "        kl = tf.math.reduce_sum(kl, 1)\n",
    "        KL_mean = tf.math.reduce_mean(kl)\n",
    "    else: \n",
    "        # KL loss\n",
    "        q_y = tf.nn.softmax(logits_y)   # (batshsize*N, M)  softmax\n",
    "        log_q_y = tf.math.log(q_y + 1e-20)   # (batshsize*N, M)  \n",
    "        kl_tmp = tf.reshape(q_y*(log_q_y-tf.math.log(1.0/M)), [-1,N,M])  # (batch_size,N,K)\n",
    "        KL = tf.math.reduce_sum(kl_tmp, [1, 2])    # shape=(batch_size, 1)\n",
    "\n",
    "        KL_mean = tf.math.reduce_mean(KL)\n",
    "    return cross_ent + KL_mean\n",
    "\n",
    "\n",
    "def compute_gradients(model, x, tau, hard):\n",
    "    with tf.GradientTape() as tape:\n",
    "        loss = gumbel_loss(model, x, tau, hard)\n",
    "    return tape.gradient(loss, model.trainable_variables), loss\n",
    "\n",
    "\n",
    "def apply_gradients(optimizer, gradients, variables):\n",
    "    optimizer.apply_gradients(zip(gradients, variables))\n",
    "\n",
    "\n",
    "def get_learning_rate(step, init=PARAMS[\"learning_rate\"]):\n",
    "    return tf.convert_to_tensor(init * pow(0.95, (step / 1000.)), dtype=tf.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-09-25 20:13:10.278269: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA\n",
      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2022-09-25 20:13:10.869974: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 16539 MB memory:  -> device: 0, name: NVIDIA TITAN RTX, pci bus id: 0000:5e:00.0, compute capability: 7.5\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1 , TRAIN loss: 167.57878 , Temperature: 0.9997000449955004\n",
      "Eval Loss: 275.96844 \n",
      "\n",
      "Epoch: 2 , TRAIN loss: 148.87381 , Temperature: 0.9991004048785274\n",
      "Eval Loss: 268.0066 \n",
      "\n",
      "Epoch: 3 , TRAIN loss: 141.92842 , Temperature: 0.9982016190284373\n",
      "Eval Loss: 263.31644 \n",
      "\n",
      "Epoch: 4 , TRAIN loss: 139.38332 , Temperature: 0.997004495503373\n",
      "Eval Loss: 260.47855 \n",
      "\n",
      "Epoch: 5 , TRAIN loss: 132.81133 , Temperature: 0.9955101098295706\n",
      "Eval Loss: 257.94675 \n",
      "\n",
      "Epoch: 6 , TRAIN loss: 131.08289 , Temperature: 0.9937198033910547\n",
      "Eval Loss: 256.41223 \n",
      "\n",
      "Epoch: 7 , TRAIN loss: 133.6204 , Temperature: 0.9916351814230984\n",
      "Eval Loss: 254.60652 \n",
      "\n",
      "Epoch: 8 , TRAIN loss: 128.53612 , Temperature: 0.9892581106136482\n",
      "Eval Loss: 254.405 \n",
      "\n",
      "Epoch: 9 , TRAIN loss: 125.53764 , Temperature: 0.9865907163177327\n",
      "Eval Loss: 254.13934 \n",
      "\n",
      "Epoch: 10 , TRAIN loss: 120.20479 , Temperature: 0.9836353793906725\n",
      "Eval Loss: 253.12982 \n",
      "\n",
      "Epoch: 11 , TRAIN loss: 125.92363 , Temperature: 0.9803947326466972\n",
      "Eval Loss: 253.01585 \n",
      "\n",
      "Epoch: 12 , TRAIN loss: 119.68844 , Temperature: 0.9768716569503434\n",
      "Eval Loss: 252.76945 \n",
      "\n",
      "Epoch: 13 , TRAIN loss: 120.12447 , Temperature: 0.9730692769487556\n",
      "Eval Loss: 252.32515 \n",
      "\n",
      "Epoch: 14 , TRAIN loss: 120.463806 , Temperature: 0.9689909564537397\n",
      "Eval Loss: 252.26645 \n",
      "\n",
      "Epoch: 15 , TRAIN loss: 118.15303 , Temperature: 0.964640293483123\n",
      "Eval Loss: 252.63042 \n",
      "\n",
      "Epoch: 16 , TRAIN loss: 118.21389 , Temperature: 0.9600211149716509\n",
      "Eval Loss: 251.00201 \n",
      "\n",
      "Epoch: 17 , TRAIN loss: 114.58075 , Temperature: 0.9551374711623026\n",
      "Eval Loss: 251.6128 \n",
      "\n",
      "Epoch: 18 , TRAIN loss: 113.67122 , Temperature: 0.9499936296895314\n",
      "Eval Loss: 250.36156 \n",
      "\n",
      "Epoch: 19 , TRAIN loss: 119.891754 , Temperature: 0.9445940693665233\n",
      "Eval Loss: 250.33348 \n",
      "\n",
      "Epoch: 20 , TRAIN loss: 113.923996 , Temperature: 0.9389434736891332\n",
      "Eval Loss: 249.9696 \n",
      "\n",
      "Epoch: 21 , TRAIN loss: 113.672 , Temperature: 0.9330467240696794\n",
      "Eval Loss: 249.1542 \n",
      "\n",
      "Epoch: 22 , TRAIN loss: 123.18997 , Temperature: 0.9269088928142736\n",
      "Eval Loss: 248.89816 \n",
      "\n",
      "Epoch: 23 , TRAIN loss: 121.415115 , Temperature: 0.9205352358578188\n",
      "Eval Loss: 248.8967 \n",
      "\n",
      "Epoch: 24 , TRAIN loss: 115.64214 , Temperature: 0.9139311852712282\n",
      "Eval Loss: 248.58832 \n",
      "\n",
      "Epoch: 25 , TRAIN loss: 111.44382 , Temperature: 0.9071023415558017\n",
      "Eval Loss: 247.68117 \n",
      "\n",
      "Epoch: 26 , TRAIN loss: 118.77994 , Temperature: 0.9000544657400421\n",
      "Eval Loss: 247.86798 \n",
      "\n",
      "Epoch: 27 , TRAIN loss: 115.846725 , Temperature: 0.8927934712944959\n",
      "Eval Loss: 248.46979 \n",
      "\n",
      "Epoch: 28 , TRAIN loss: 113.15054 , Temperature: 0.8853254158804752\n",
      "Eval Loss: 247.64012 \n",
      "\n",
      "Epoch: 29 , TRAIN loss: 112.81449 , Temperature: 0.8776564929487385\n",
      "Eval Loss: 247.53032 \n",
      "\n",
      "Epoch: 30 , TRAIN loss: 110.62661 , Temperature: 0.8697930232043984\n",
      "Eval Loss: 247.43657 \n",
      "\n",
      "Epoch: 31 , TRAIN loss: 119.77542 , Temperature: 0.8617414459544691\n",
      "Eval Loss: 247.37262 \n",
      "\n",
      "Epoch: 32 , TRAIN loss: 115.396675 , Temperature: 0.8535083103545701\n",
      "Eval Loss: 247.04443 \n",
      "\n",
      "Epoch: 33 , TRAIN loss: 113.031044 , Temperature: 0.8451002665713722\n",
      "Eval Loss: 246.84024 \n",
      "\n",
      "Epoch: 34 , TRAIN loss: 107.69496 , Temperature: 0.8365240568773926\n",
      "Eval Loss: 246.95438 \n",
      "\n",
      "Epoch: 35 , TRAIN loss: 111.42847 , Temperature: 0.8277865066947337\n",
      "Eval Loss: 246.73276 \n",
      "\n",
      "Epoch: 36 , TRAIN loss: 110.345024 , Temperature: 0.8188945156043044\n",
      "Eval Loss: 246.71602 \n",
      "\n",
      "Epoch: 37 , TRAIN loss: 112.8344 , Temperature: 0.8098550483369699\n",
      "Eval Loss: 246.2143 \n",
      "\n",
      "Epoch: 38 , TRAIN loss: 113.39542 , Temperature: 0.8006751257629464\n",
      "Eval Loss: 246.6559 \n",
      "\n",
      "Epoch: 39 , TRAIN loss: 112.12604 , Temperature: 0.791361815895584\n",
      "Eval Loss: 246.22522 \n",
      "\n",
      "Epoch: 40 , TRAIN loss: 110.71779 , Temperature: 0.7819222249254774\n",
      "Eval Loss: 245.8085 \n",
      "\n",
      "Epoch: 41 , TRAIN loss: 111.44801 , Temperature: 0.7723634883006051\n",
      "Eval Loss: 246.61232 \n",
      "\n",
      "Epoch: 42 , TRAIN loss: 109.40027 , Temperature: 0.7626927618679156\n",
      "Eval Loss: 246.31946 \n",
      "\n",
      "Epoch: 43 , TRAIN loss: 112.24536 , Temperature: 0.7529172130914742\n",
      "Eval Loss: 246.01541 \n",
      "\n",
      "Epoch: 44 , TRAIN loss: 116.43471 , Temperature: 0.74304401236194\n",
      "Eval Loss: 245.80774 \n",
      "\n",
      "Epoch: 45 , TRAIN loss: 109.36332 , Temperature: 0.7330803244117685\n",
      "Eval Loss: 246.08221 \n",
      "\n",
      "Epoch: 46 , TRAIN loss: 110.24271 , Temperature: 0.7230332998501351\n",
      "Eval Loss: 246.0701 \n",
      "\n",
      "Epoch: 47 , TRAIN loss: 109.77754 , Temperature: 0.7129100668311394\n",
      "Eval Loss: 246.16573 \n",
      "\n",
      "Epoch: 48 , TRAIN loss: 102.02682 , Temperature: 0.7027177228683977\n",
      "Eval Loss: 246.06557 \n",
      "\n",
      "Epoch: 49 , TRAIN loss: 104.21529 , Temperature: 0.6924633268086435\n",
      "Eval Loss: 245.67769 \n",
      "\n",
      "Epoch: 50 , TRAIN loss: 108.01788 , Temperature: 0.6821538909764523\n",
      "Eval Loss: 245.79924 \n",
      "\n",
      "Epoch: 51 , TRAIN loss: 108.547516 , Temperature: 0.6717963735016784\n",
      "Eval Loss: 246.0104 \n",
      "\n",
      "Epoch: 52 , TRAIN loss: 111.35535 , Temperature: 0.6613976708406429\n",
      "Eval Loss: 245.2125 \n",
      "\n",
      "Epoch: 53 , TRAIN loss: 102.299576 , Temperature: 0.6509646105015452\n",
      "Eval Loss: 245.44588 \n",
      "\n",
      "Epoch: 54 , TRAIN loss: 108.09459 , Temperature: 0.6405039439839885\n",
      "Eval Loss: 245.12508 \n",
      "\n",
      "Epoch: 55 , TRAIN loss: 104.77628 , Temperature: 0.6300223399419125\n",
      "Eval Loss: 245.69858 \n",
      "\n",
      "Epoch: 56 , TRAIN loss: 105.351685 , Temperature: 0.6195263775786136\n",
      "Eval Loss: 245.62045 \n",
      "\n",
      "Epoch: 57 , TRAIN loss: 105.846634 , Temperature: 0.609022540281914\n",
      "Eval Loss: 245.96759 \n",
      "\n",
      "Epoch: 58 , TRAIN loss: 107.42676 , Temperature: 0.5985172095069092\n",
      "Eval Loss: 245.26697 \n",
      "\n",
      "Epoch: 59 , TRAIN loss: 103.14381 , Temperature: 0.5880166589130854\n",
      "Eval Loss: 245.44359 \n",
      "\n",
      "Epoch: 60 , TRAIN loss: 111.09474 , Temperature: 0.5775270487619548\n",
      "Eval Loss: 245.33264 \n",
      "\n",
      "Epoch: 61 , TRAIN loss: 105.67869 , Temperature: 0.5670544205807092\n",
      "Eval Loss: 245.30548 \n",
      "\n",
      "Epoch: 62 , TRAIN loss: 105.16727 , Temperature: 0.556604692096744\n",
      "Eval Loss: 244.94482 \n",
      "\n",
      "Epoch: 63 , TRAIN loss: 104.64652 , Temperature: 0.5461836524472542\n",
      "Eval Loss: 245.26328 \n",
      "\n",
      "Epoch: 64 , TRAIN loss: 107.73546 , Temperature: 0.5357969576674562\n",
      "Eval Loss: 245.57875 \n",
      "\n",
      "Epoch: 65 , TRAIN loss: 106.4305 , Temperature: 0.5254501264603462\n",
      "Eval Loss: 244.74532 \n",
      "\n",
      "Epoch: 66 , TRAIN loss: 104.412476 , Temperature: 0.5151485362502642\n",
      "Eval Loss: 245.97902 \n",
      "\n",
      "Epoch: 67 , TRAIN loss: 112.7644 , Temperature: 0.5048974195219025\n",
      "Eval Loss: 245.34612 \n",
      "\n",
      "Epoch: 68 , TRAIN loss: 103.975876 , Temperature: 0.5\n",
      "Eval Loss: 245.55264 \n",
      "\n",
      "Epoch: 69 , TRAIN loss: 107.00073 , Temperature: 0.5\n",
      "Eval Loss: 245.55916 \n",
      "\n",
      "Epoch: 70 , TRAIN loss: 108.850266 , Temperature: 0.5\n",
      "Eval Loss: 245.6999 \n",
      "\n",
      "Epoch: 71 , TRAIN loss: 108.56678 , Temperature: 0.5\n",
      "Eval Loss: 245.30513 \n",
      "\n",
      "Epoch: 72 , TRAIN loss: 106.48784 , Temperature: 0.5\n",
      "Eval Loss: 245.422 \n",
      "\n",
      "Epoch: 73 , TRAIN loss: 108.406746 , Temperature: 0.5\n",
      "Eval Loss: 245.57533 \n",
      "\n",
      "Epoch: 74 , TRAIN loss: 104.590164 , Temperature: 0.5\n",
      "Eval Loss: 245.91678 \n",
      "\n",
      "Epoch: 75 , TRAIN loss: 112.19656 , Temperature: 0.5\n",
      "Eval Loss: 245.37482 \n",
      "\n",
      "Epoch: 76 , TRAIN loss: 105.566444 , Temperature: 0.5\n",
      "Eval Loss: 245.3584 \n",
      "\n",
      "Epoch: 77 , TRAIN loss: 102.401085 , Temperature: 0.5\n",
      "Eval Loss: 245.53618 \n",
      "\n",
      "Epoch: 78 , TRAIN loss: 108.70701 , Temperature: 0.5\n",
      "Eval Loss: 245.55301 \n",
      "\n",
      "Epoch: 79 , TRAIN loss: 108.226715 , Temperature: 0.5\n",
      "Eval Loss: 245.92863 \n",
      "\n",
      "Epoch: 80 , TRAIN loss: 104.48426 , Temperature: 0.5\n",
      "Eval Loss: 245.50975 \n",
      "\n",
      "Epoch: 81 , TRAIN loss: 104.50169 , Temperature: 0.5\n",
      "Eval Loss: 244.88603 \n",
      "\n",
      "Epoch: 82 , TRAIN loss: 110.760376 , Temperature: 0.5\n",
      "Eval Loss: 245.74776 \n",
      "\n",
      "Epoch: 83 , TRAIN loss: 106.4659 , Temperature: 0.5\n",
      "Eval Loss: 245.95004 \n",
      "\n",
      "Epoch: 84 , TRAIN loss: 104.75502 , Temperature: 0.5\n",
      "Eval Loss: 245.39891 \n",
      "\n",
      "Epoch: 85 , TRAIN loss: 101.44387 , Temperature: 0.5\n",
      "Eval Loss: 245.32735 \n",
      "\n",
      "Epoch: 86 , TRAIN loss: 105.50177 , Temperature: 0.5\n",
      "Eval Loss: 245.17278 \n",
      "\n",
      "Epoch: 87 , TRAIN loss: 104.82187 , Temperature: 0.5\n",
      "Eval Loss: 245.18301 \n",
      "\n",
      "Epoch: 88 , TRAIN loss: 104.77376 , Temperature: 0.5\n",
      "Eval Loss: 245.85846 \n",
      "\n",
      "Epoch: 89 , TRAIN loss: 104.11487 , Temperature: 0.5\n",
      "Eval Loss: 245.27525 \n",
      "\n",
      "Epoch: 90 , TRAIN loss: 102.89894 , Temperature: 0.5\n",
      "Eval Loss: 245.07867 \n",
      "\n",
      "Epoch: 91 , TRAIN loss: 105.77638 , Temperature: 0.5\n",
      "Eval Loss: 245.05559 \n",
      "\n",
      "Epoch: 92 , TRAIN loss: 106.55687 , Temperature: 0.5\n",
      "Eval Loss: 245.19058 \n",
      "\n",
      "Epoch: 93 , TRAIN loss: 108.10065 , Temperature: 0.5\n",
      "Eval Loss: 245.2606 \n",
      "\n",
      "Epoch: 94 , TRAIN loss: 109.73303 , Temperature: 0.5\n",
      "Eval Loss: 246.13712 \n",
      "\n",
      "Epoch: 95 , TRAIN loss: 107.79889 , Temperature: 0.5\n",
      "Eval Loss: 244.88623 \n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 96 , TRAIN loss: 107.25028 , Temperature: 0.5\n",
      "Eval Loss: 245.30112 \n",
      "\n",
      "Epoch: 97 , TRAIN loss: 104.02638 , Temperature: 0.5\n",
      "Eval Loss: 244.64043 \n",
      "\n",
      "Epoch: 98 , TRAIN loss: 101.71076 , Temperature: 0.5\n",
      "Eval Loss: 245.45157 \n",
      "\n",
      "Epoch: 99 , TRAIN loss: 104.31462 , Temperature: 0.5\n",
      "Eval Loss: 244.75468 \n",
      "\n",
      "Epoch: 100 , TRAIN loss: 102.49172 , Temperature: 0.5\n",
      "Eval Loss: 244.9499 \n",
      "\n",
      "CPU times: user 20min 13s, sys: 15.7 s, total: 20min 29s\n",
      "Wall time: 20min 10s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "model = DiscreteVAE(PARAMS)\n",
    "learning_rate = tf.Variable(PARAMS[\"learning_rate\"], trainable=False, name=\"LR\")\n",
    "\n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)\n",
    "\n",
    "# data\n",
    "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
    "x_train = x_train.astype('float32') / 255.\n",
    "x_test = x_test.astype('float32') / 255.\n",
    "x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))\n",
    "x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))\n",
    "\n",
    "TRAIN_BUF = 60000\n",
    "BATCH_SIZE = 100\n",
    "TEST_BUF = 10000\n",
    "\n",
    "train_dataset = tf.data.Dataset.from_tensor_slices(x_train).shuffle(TRAIN_BUF).batch(BATCH_SIZE)\n",
    "test_dataset = tf.data.Dataset.from_tensor_slices(x_test).shuffle(TEST_BUF).batch(BATCH_SIZE)\n",
    "\n",
    "# temperature\n",
    "tau = PARAMS[\"init_temperature\"]\n",
    "anneal_rate = PARAMS[\"anneal_rate\"]\n",
    "min_temperature = PARAMS[\"min_temperature\"]\n",
    "\n",
    "results = []\n",
    "\n",
    "# Train\n",
    "for epoch in range(1, PARAMS[\"nb_epoch\"] + 1):\n",
    "    \n",
    "    # this is only needed for the standard Gumbel softmax trick\n",
    "    tau = np.maximum(tau * np.exp(-anneal_rate*epoch), min_temperature)\n",
    "\n",
    "    for train_x in train_dataset:\n",
    "        gradients, loss = compute_gradients(model, train_x, tau, hard=PARAMS[\"hard\"])\n",
    "        apply_gradients(optimizer, gradients, model.trainable_variables)\n",
    "\n",
    "    print(\"Epoch:\", epoch, \", TRAIN loss:\", loss.numpy(), \", Temperature:\", tau)\n",
    "\n",
    "    if epoch % 1 == 0:\n",
    "        losses = []\n",
    "        for test_x in test_dataset:\n",
    "            losses.append(gumbel_loss(model, test_x, tau, hard=True, evaluate=True))\n",
    "        eval_loss = np.mean(losses)\n",
    "        results.append(eval_loss)\n",
    "        print(\"Eval Loss:\", eval_loss, \"\\n\")\n",
    "\n",
    "    if PARAMS['hard'] == True:\n",
    "        model.save_weights(\"model.h5\")\n",
    "    else:\n",
    "        model.save_weights(\"model_hard.h5\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
     "# Per-epoch test-set losses (\"Eval Loss\") recorded for comparing the two estimators.\n",
     "# SIMPLE_res: from a separate run using DiscreteVAE.SIMPLE - TODO record that run's settings.\n",
     "SIMPLE_res = [263.7302 ,251.18088 ,246.33345 ,243.87569 ,242.23654 ,240.63268 ,240.3769 ,238.67422, 238.54333 ,237.74467 ,237.18701 ,236.66383 ,237.05296 ,236.54428 ,236.6843 ,235.88643 ,235.71437 ,236.06479 ,235.67657 ,235.67119 ,235.66957 ,235.24947 ,235.27081 ,235.0676 ,235.25371 ,235.26814 ,234.78458 ,234.38338 ,234.84767 ,234.6623 ,234.63838 ,234.79012 ,234.6961 ,234.56339 ,234.71466 ,234.62645 ,234.55602 ,234.53557 ,234.48553 ,234.41893 ,234.4857 ,234.5761 ,234.43562 ,234.36523 ,234.36906 ,234.41887 ,234.56465 ,234.54298 ,234.51425 ,234.19354 ,234.50745 ,234.51003 ,234.2202 ,234.34398 ,234.38408 ,234.62442 ,234.2286 ,234.39273 ,234.55411 ,234.48737 ,234.60356 ,234.669 ,234.60332 ,234.96918 ,235.08226 ,234.69551 ,234.54205 ,234.5958 ,234.86174 ,234.7789 ,235.04431 ,234.75188 ,234.86469 ,234.9746 ,234.62251 ,234.81293 ,234.79703 ,235.23468 ,234.97758 ,234.77356 ,235.05292 ,234.9518 ,234.73944 ,235.00218 ,234.86794 ,235.16379 ,234.95787 ,234.67206 ,234.95192 ,235.25476 ,234.97649 ,235.29564 ,235.581 ,234.95891 ,235.65594 ,235.47629 ,235.41869 ,235.53271 ,235.37158 ,235.3255]\n",
     "# gumbel_softmax: the \"Eval Loss\" values printed by the training cell above.\n",
     "# NOTE(review): this rebinds a global named like the Gumbel-softmax helpers; consider\n",
     "# renaming to gumbel_softmax_res if no later cell depends on this name.\n",
     "gumbel_softmax = [275.96844 ,268.0066 ,263.31644 ,260.47855 ,257.94675 ,256.41223 ,254.60652 ,254.405 ,254.13934 ,253.12982 ,253.01585 ,252.76945 ,252.32515 ,252.26645 ,252.63042 ,251.00201 ,251.6128 ,250.36156 ,250.33348 ,249.9696 ,249.1542 ,248.89816 ,248.8967 ,248.58832 ,247.68117 ,247.86798 ,248.46979 ,247.64012 ,247.53032 ,247.43657 ,247.37262 ,247.04443 ,246.84024 ,246.95438 ,246.73276 ,246.71602 ,246.2143 ,246.6559 ,246.22522 ,245.8085 ,246.61232 ,246.31946 ,246.01541 ,245.80774 ,246.08221 ,246.0701 ,246.16573 ,246.06557 ,245.67769 ,245.79924 ,246.0104 ,245.2125 ,245.44588 ,245.12508 ,245.69858 ,245.62045 ,245.96759 ,245.26697 ,245.44359 ,245.33264 ,245.30548 ,244.94482 ,245.26328 ,245.57875 ,244.74532 ,245.97902 ,245.34612 ,245.55264 ,245.55916 ,245.6999 ,245.30513 ,245.422 ,245.57533 ,245.91678 ,245.37482 ,245.3584 ,245.53618 ,245.55301 ,245.92863 ,245.50975 ,244.88603 ,245.74776 ,245.95004 ,245.39891 ,245.32735 ,245.17278 ,245.18301 ,245.85846 ,245.27525 ,245.07867 ,245.05559 ,245.19058 ,245.2606 ,246.13712 ,244.88623 ,245.30112 ,244.64043 ,245.45157 ,244.75468 ,244.9499]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 150,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAR8AAACvCAYAAADE6tqMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAA6nElEQVR4nO2dd3hUZdqH7+mTSTLpFUhCMSGA1ACKoBJQUAnFVUAW4VMXdUWx7i7Ltx8oCIplURE3uouuri7uohAQVIqggvQmvSUkJKT3TCbTz/fHwCEhFQiZQN77unJdc+a855zfOZn5zVufRyFJkoRAIBC0MEpPCxAIBG0TYT4CgcAjCPMRCAQeQZiPQCDwCMJ8BAKBRxDmIxAIPIIwn+uQ2bNns2TJEk/LaJVkZWURFxeHw+HwtBRBIwjzaWUkJibSs2dP+vTpQ0JCAhMnTmTZsmW4XC65zNy5c5k+fXqLadq5cye33377FR9vs9mYMWMGiYmJxMXFsXPnzkaPSU5OJjExkT59+nD77bfz3HPPXfH1m4OHH36Y5cuX17nParWSkJDA9u3ba+1bsGABM2bMkLcXL15MXFwcBw8erFFuxYoVxMfH06dPnxp/eXl5zXsjrQi1pwUIapOcnMygQYOoqKhg165dzJ8/n4MHD/Laa69dk+s5HA7U6mv7Uejbty9TpkxpkomsXLmSVatW8c9//pOoqCgKCgrYtGnTNdV3Neh0Ou69915WrVrFrbfeKr/vdDpZu3Yt8+bNA0CSJFatWoW/vz8rV66kZ8+eNc7Tu3dvli1b1qLaPYmo+bRifH19GTZsGO+88w4rV67k5MmTAMycOZNFixYBUFxczBNPPEFCQgIDBgxg0qRJci0pJyeHp59+mltuuYWBAwcyd+5cwP0rO3HiRBYsWMCAAQNYvHgxNpuNhQsXcueddzJo0CBmz56NxWLBbDYzbdo08vPza/wau1wuPvroI4YPH87AgQN59tlnKS0trfM+tFot//M//0NCQgJKZeMfuUOHDjF48GCioqIACAkJYcKECfL+xMREtm3bJm8vXryYl156qcY5vv76awYPHszgwYP5+OOP5fcPHjzI/fffT9++fRk0aFANQz9w4AATJ04kISGB0aNHyzW0RYsWsWfPHubOnUufPn3k51idsWPHsm7dOqqqquT3tm7disvlkmuNe/bsIT8/n1mzZvHtt99is9kafRY3MsJ8rgN69uxJeHg4e/bsqbXvk08+ISwsjO3bt/PLL7/wwgsvoFAocDqdPPHEE0RGRrJp0yZ+/vln7r33Xvm4gwcP0qFDB7Zt28bvf/973nzzTc6cOUNKSgrr168nPz+fJUuWYDAY+Pvf/05oaCj79+9n//79hIWF8dlnn7Fx40Y+//xztmzZgp+fX51fyiuhV69erFq1in/84x8cOnQIp9N52efYuXMn69evZ+nSpXz00UeyWc2fP58pU6awb98+NmzYwD333ANAXl4eTzzxBL///e/ZtWsXf/rTn5gxYwbFxcU8//zzJCQkMHv2bPbv38/s2bNrXa9v376Ehoayfv16+b1Vq1YxatQouVa5cuVKhg4dKv8ffvzxx8u+rxsJYT7XCaGhoZSVldV6X61WU1BQQHZ2NhqNhoSEBBQKBQcPHiQ/P58//vGPGAwGdDodCQkJNc738MMPo1ar0el0LF++nFmzZuHv74+Pjw9PPPEEa9eurVfPf/7zH55//nnCw8PRarU8/fTTrFu3rlk6eseMGcNf/vIXtm7dysMPP8ygQYP46KOPLusc06dPx2AwEBcXx/3338+aNWsA9/M6e/YsxcXFeHt707t3b8BtFLfffjt33HEHSqWS2267jR49evDTTz9dlu5Vq1YBYDKZ+OGHHxg3bhwAVVVVfP/99yQlJaHRaBgxYgQrV66scfyvv/5KQkKC/Dd8+PDLuufrDdHnc52Ql5eHn59frfcfe+wx3n//fR599FEAJkyY
wOOPP05OTg6RkZH19uWEh4fLr4uLi6mqquL++++X35MkqUYn96VkZ2czffr0Gs0opVJJUVERYWFhTb6v7Oxs7rvvPnl7//79AIwePZrRo0djt9vZuHEjf/jDH4iPj2fIkCFNOm9ERIT8ul27dnKTdf78+bz33nvcc889tG/fnqeffpqhQ4eSnZ3N999/z+bNm+XjHA4HAwcObPK9jB07liVLlpCXl8eWLVuIioqiW7duAGzYsAG1Wi03wZKSknjkkUcoLi4mMDAQcNf42lKfjzCf64CDBw+Sl5dHv379au3z8fFh5syZzJw5k1OnTjFlyhRuvvlmIiIiyMnJqbczWaFQyK8DAgLQ6/WsXbu2TuOoXvYC4eHhLFiwoE5Nl0NkZKRsOHWh0Wi45557+Pvf/86pU6cYMmQIXl5eNfpWCgoKah2Xk5ND586dAbfBhYaGAhATE8Nf//pXXC4X69evZ8aMGezcuZOIiAjGjBnDq6++elX30q9fP1avXs3PP//M2LFj5X0pKSmYzWaGDh0KuM3dbrezZs0apkyZcsXXvJ4Rza5WjMlkYvPmzbzwwguMHj2auLi4WmU2b95MRkYGkiTh4+ODSqVCqVTSs2dPQkJCePvttzGbzVitVvbu3VvndZRKJQ8++CALFiygqKgIQP71BggKCqK0tJSKigr5mIceeoh33nmHc+fOAe7a08aNG+u9F5vNhtVqBcBut2O1WqkvmsuKFSv48ccfMZlMuFwufvrpJ06fPi2PDnXt2pVvv/0Wu93OoUOHWLduXa1zfPDBB1RVVXHq1ClWrFgh97OsWrWK4uJilEolRqMRAJVKxejRo9m8eTNbtmzB6XRitVrZuXMnubm5AAQHB5OZmVnv/V1g3LhxfPHFF+zfv5+kpCT5WW7fvp3k5GRSUlJISUlh1apVTJs2jZSUlEbPeaMiaj6tkCeffFI2kS5duvDII48wceLEOstmZGQwb948iouLMRqNPPTQQ3JTITk5mVdffVX+tU1KSqq3pvKHP/yBJUuWMH78eEpKSggLC+Ohhx5iyJAhdO7cmfvuu4/hw4fLw8dTpkxBkiQeffRR8vPzCQoK4t577623n2LkyJGyUT322GMA/PDDD7Rv375WWR8fH5KTk0lNTcXpdNKuXTtefvlluc/queee44UXXmDAgAH079+fpKSkWiNtAwYM4K677pI1Dh48GIAtW7bw+uuvY7FYiIyMZNGiReh0OiIiIvjggw948803efHFF2UDf/nllwGYMmUKM2fOZNmyZXKfVF2MGDGCuXPncsstt8i1rVWrVhEfHy9ruMDDDz/MJ598IjcJDxw4QJ8+fWqU+fTTT2sNyd8oKEQwMYFA4AlEs0sgEHgEYT4CgcAjCPMRCAQeQZiPQCDwCDfsaJfL5aKyshKNRlPnPBWBQHBtuDCHydvbu8G1fDes+VRWVspDmAKBoOWJjY3F19e33v03rPloNBrA/QC0Wm2DZQ8fPkyPHj1aQtYVIzQ2D0Jj89CQRpvNxsmTJ+XvYH3csOZzoaml1WrR6XSNlm9KGU8jNDYPQmPz0JjGxro7bljzaQqW4lLytu3Gy2FH6uNC0YRYMwKBoHlo09+28tR0bGXlaCqrMGXleFqOQNCmaNM1H4XyYrXQYTJ5UMn1g91uJysrC4vFUmufWq3m2LFjHlDVdITG5kGlUnHmzBnat2/faN9OfbRp81EbvOTXDnNVAyUFF8jKysLX15eYmJhabfrKykq8vb09pKxpCI3Ng8lkwmKxkJWVRceOHa/oHG262aX2umg+dmE+TcJisRAUFCTmTrVxFAoFQUFBddaAm0rbNh9R87kihPEI4Oo/B8J8ziPMRyBoWdp2n0+1ZpfTYkFyieH26w273U5ycjJr1qxBrVajVquJjo5mxowZdOnSpdmuk5iYSHJyMrGxsZd1XFxcHPv27avVh2Oz2ViwYAG7d+9GqVQiSRJPPPGEHP2wPkpKSvj9739PVVUVSUlJqNVq
kpKSCAoKuux78jRt2nwUKiUqvQ6nxQoSOCwWNAaDp2UJLoM///nPWCwWli9fjtFoRJIkvv/+e1JTU5vVfJqbzz77jNLSUlavXo1KpaKysrLOWNSXsn37doxGI19++SXgNsVBgwYJ87keURu83OaDu+klzOf6IT09nY0bN/LTTz/J8ZgVCoWciwvcCRZ79OjB5MmTAZgzZw69e/dm8uTJzJw5E61WS3p6OpmZmdx1110MHTqUxYsXk5uby9SpU5k6dap8rm+++YZ9+/aRn5/P1KlT5XOmpaWxYMECSkpKsNvtTJ06ld/85jcNas/NzSU4OBiVSgWAt7e3XDtyOp0sXLhQjqE9ZMgQXnrpJXbv3s0bb7yByWRizJgx3HXXXeTn5zNjxgx0Oh1vv/023333HWlpaZhMJtLT0+nevTuPP/44r7/+OtnZ2dx111386U9/AuDjjz9m7dq1OJ1OdDodL7/8MvHx8aSmpvLoo4/y73//m3bt2rF48WLS0tLkRJXNhTAfgxfW4lJA9PtcCSXHTlJ08BhSM+TruhSFWk1Qz3gC4utu6hw9epTo6Og6Uwo1lVOnTvHpp5/idDpJTEykoqKCzz//nIKCAkaOHMkDDzwgm0JhYSFffPEFhYWFjB07loSEBLp06cJLL73Em2++SefOnTGZTPzmN7+hd+/ecvaMunjwwQd57LHH2Llzp5yP/kL86xUrVnDs2DFWrFgBwLRp0/jPf/7DpEmTmDFjBj/++CPvvfeeXPa9996r0Rw8cuQIX3/9NQaDgXHjxvH222/zj3/8A4fDwbBhw5gwYQIxMTGMHTtWTrm0bds25syZw3//+186d+7M888/z/PPP8+MGTNYs2YNX3/99RU/4/oQ5iM6na+KkmOnronxAEgOByXHTtVrPpdy+vRpXnzxRSwWC0OGDKk3yHt1hg8fLi887tixo5w0MCwsDKPRSG5urmwiDzzwAODOZHHnnXeya9cu1Go1qampvPDCC/I57XY7aWlpDZpPXFwcP/zwA7t372bv3r3MmzePn3/+mblz57Jz507GjRsn67r//vvZuHEjkyZNatJzGDx4sLyaPC4ujq5du6LVatFqtXTs2JGzZ88SExPD4cOH+fDDDykrK0OhUJCeni6fY+zYsezYsYPp06fzxRdf4OPj06RrXw7CfKo1s4T5XD4B8Tdd05pPQPxN9e7v1q0bGRkZlJeXYzQa6dKlC6tWreLzzz/n8OHDgHsmbvXkhxfS91yg+uJIlUpVa7u+VM2SJKFQKJAkiYCAADlT6eWg0+nkfPJ33nknjzzyCHPnzpXPXZ3LGdZuyj3ZbDaeffZZPv/8c7p3705eXp6c0BDcHeKnTp3C19dXTqfU3Ajz8dLLr4X5XD4B8bFyzaSlZ+bGxMQwbNgw/vKXvzB//nz5195sNstloqKiOHToEAD5+fns2bPnsrKQVmflypX069eP4uJifv75Z6ZMmULHjh3R6/WkpKTISQJTU1MJCwtrsLawZ88eYmJiCA4OBtxNpQtphG655RZWrlwp912lpKRw991313keb2/vGvnUmorNZsPhcMiZXf/973/X2P/GG2/QvXt3Xn/9daZNm8aXX35ZI8ttcyDMRzS7rmtee+01PvjgAx544AHUajVGo5HQ0FAef/xxAMaPH8+MGTMYPXo0MTExVxUnJyIigkmTJlFQUMATTzwhJ3FMTk5mwYIFLF26FJfLRVBQEO+8806D58rKyuLVV1/FbrejVCoJCgrizTffBNzNrNzcXDnP++DBgxk/fnyd55kyZQqzZs1Cr9fz9ttvN/lefHx8mDFjBg888AARERE1aj0bN25k165dLF++HJ1Ox/Tp03nhhRf47LPP6k2/fSXcsHm7rFarHPCoobgjtvIKMr5ZD4Da20DHsffUW9aT7N2796pTEzcHx44dIz4+vs5918OaJKGxebigsa7PQ1O/e21+Rl2Nmk9VVb0pfAUCQfPS5s1HqVbjujCr2SXJc34E
AsG1pc2bD4CkUcmvRb+PQNAyCPMBXNU60YT5CAQtgzAfatZ87CKioUDQIgjzAZy6i6l1rCVlHlQiELQdhPkArurmUyrMRyBoCYT5ULPmYysrR3K6GigtaG189913jB07ljFjxjBy5EhefPFFwB1u4kLW2pkzZxIXF0dqaqp8XGZmJl27dmXGjBmAe+Jft27dGDNmDElJSUyYMEEO5L548WIWLlxY69orVqwgISGBMWPGyH9vvfXWtb7lG4I2P8MZAJUStY8Bh8kMLglbeTm6AH9PqxI0gfz8fF555RVWrlxJREQEkiRx/PjxOst269aNb775hp49ewLu5RLdunWrUcbX11dep/Xpp58ya9YsVq5c2aCGQYMGyavMBU1HmM95dP7+bvPB3e8jzKdpfP/vjaz6eC1Wc/PPj9IZdIx59D5GThpeb5nCwkLUajX+/v6AewFmfTOw77nnHv773/8yc+ZMlEol3377LePHj+fAgQN1lr/tttuaPYaN4CLCfM6jC/CjMisbAGtJKRDtUT3XC+u+/OGaGA+A1Wxl3Zc/NGg+Xbt2pWfPntx5550MHDiQvn37MmbMGAICAmqVNRgM9OzZk61bt6LT6YiNjZVNqy6+//77eo2sOtu2bWPMmDHy9uTJk3nwwQcbPa6tI8znPLqAiwGpxIhX0xkxcdg1rfmMmDiswTJKpZIPPviAkydPsnv3bjZu3MjSpUv55ptv6iyflJRESkoKWq2WcePGUVJSUmN/RUUFY8aMQZIkOnTowOuvv96oTtHsujLavPlcWMtVvZllLS2rM6aKoDYjJw2XayaeXBAZGxtLbGwsv/3tb7n33nvZtWtXneX69+/PG2+8gcPhYP78+axevbrG/up9PoJrS5s2n7Sj6fztL/9A76vj5o9uRqlR47I7cFltOKssNRadCloneXl5ZGdn06dPH8AdG7m4uFiOjXMpCoWCP//5zzgcjmYNDyG4fNr009+6ZhtFucWQCwe3HSHM3w9LgTtqm7WkVJjPdYDD4WDx4sWcO3cOvV6Py+XiueeeqzWKVZ3qsWsuhy+//JK1a9fK20899RRarbZWn0+PHj2YP3/+FV2jLdGm4/l89uYyflzpzhDw2xfGc3O0P2Un3fNAArrFEdznygNPNTcink/zIDQ2DyKez1Xi7XsxfnNlhRlDROjF7XM5npAkELQZ2rb5GKuZT7kZQ3goCpX7kdjKyrFViEWmAsG1oo2bz8WqbWV5JUq1GkN42MX3skTtpy5u0Ja64DK52s9BmzYfQ7Vml7nCPbvZu32E/J5oetVGr9dTVFQkDKiNI0kSRUVF6PX6xgvXQ4uMdpWUlPDHP/6Rs2fPotVqiY6OZu7cuQQGBmK1WlmwYAHbt29Hp9PRu3dv5s2bB8CZM2eYOXMmpaWl+Pv7s3DhQmJiYppNl0+1Zpep/Lz5tLtoPlX5hTitNlTVFp62ddq3b09WVladecVtNpuc6K61IjQ2D1arFT8/v3qnNDSFFjEfhULB7373Ozlf0sKFC3nrrbdYsGABb775JjqdjnXr1qFQKCgsLJSPmzNnDpMmTWLMmDGsWrWK2bNn89lnnzWbrkubXeDO46UPCsRSVAySRGV2LsaOUc12zesdjUZDx44d69y3d+9eevXq1cKKLg+hsXnYu3dvvZ+DptIizS5/f/8aidp69+5NdnY2lZWVpKSk8Oyzz8qziS8kUSsqKuLo0aOMGjUKgFGjRnH06FGKi4ubTVddzS6o2fQync1qtusJBIKLNKnmU1FRwT//+U9++eUXSkpKCAgIYNCgQUydOhU/P7/GT1ANl8vFsmXLSExMJDMzE39/f95//3127tyJt7c3zz77LAkJCeTk5BAWFoZK5Q5xqlKpCA0NJScnh8DAwCZf70La3Lqw2y6m+K0oM7Fnzx4UCgUKmx3f8++bsnLYu2sXqFR1n6QF2bt3r6clNIrQ2Dy0BY2Nmk9eXh4PPfQQ3t7e3H333YSEhFBQUMD69etJ
SUlh2bJlhIWFNXYamXnz5mEwGJg8eTJHjx4lMzOTbt268ac//Ylff/2VJ598kg0bNlzVTVWnoYlOkiTxsfYLHDYHLoeLm7vfjM7LXfZs6Q9Yi0tRADcFhWLsHNNsmq6E1jLJsCGExubhetd4YZJhYzTa7Hrrrbe45ZZbWL16Nc888wwTJ07kmWeeYfXq1QwcOJA33nijyYIXLlxIRkYG77zzDkqlksjISNRqtdy06tWrFwEBAZw5c4aIiAjy8vJwOp0AOJ1O8vPz5dzSzYFCoag10fACvtEd5NcVGaLpJRA0N42az5YtW3jxxRdrrfBWKBS88MILbN26tUkXWrRoEYcPH2bJkiVyT35gYCADBw7kl19+AdyjW0VFRURHRxMUFER8fDxr1qwBYM2aNcTHx19Wk6sp1NXpDOATfbEX35ybL5IJCgTNTKPNLqvViq+vb537jEYjNput0YucOnWK5ORkYmJimDhxIuAesl2yZAmvvPIKs2bNYuHChajVat544w2MRiMAL7/8MjNnzuSDDz7AaDTWGUP3aqk+y7l6p7PG24A+OAhLYRFIEhXpmfh37dLs1xcI2iqNmk+XLl3YsGED9913X619GzdupHPnzo1e5KabbuLEiRN17uvQoQP/+te/6tzXuXNnli9f3uj5rwZv34s1nwtzfS7gG9PebT5A8ZHjGDtHo9RorqkegaCt0Giz66mnnmL27Nl8/PHHnDt3DpvNxrlz51i6dClz5szhqaeeagmd14ya67sqa+wzdo6Rw2o4LVaKj9RtoAKB4PJptOYzdOhQ5s+fzxtvvMGbb74pvx8WFsbcuXNJTEy8pgKvNd71zPUBUKrVBPXuQd623QCUHjuFX5eOaHxad7gDgeB6oEnzfEaOHMnIkSNJTU2ltLSUgIAAOnXqdK21tQg1O5zNtfb7xnSg9PhprMUlSC4XxYeOEXZrQktKFAhuSC5rhnPnzp3p168fDoeDTz75pMkjXa0ZQwPNLnCP6oX0vVneLj9zFrupdjmBQHB5NGo+J0+eZOzYsdx88808/PDD7Nixg8mTJ7N27VqeeeaZZl1r5Qnqm+dTHa+wELxC3cs+kCRKjp1sCWkCwQ1No+Yzb948EhMTSUlJoW/fvjzzzDN88MEHfPXVV3z88cd88cUXLaHzmuHTSLPrAgE9usqvy0+nYykuEWElBIKroNE+n+PHj/Ppp5+iVCqZPn06S5cuJSHB3efRp0+fOkMrXE801uySy4WHogv0x1pciuRykfndJpQ6Ld6R4fh0aId3+wiRakcguAwarfk4nU6USncxrVaLwWBo5IjrixpD7fU0u8Dd9xN4c81A2S6rjYozZ8n5eTt52/eImpBAcBk0WvNxOBx8/fXX8hfLbrfz1VdfyfsvrL26XmlstKs6Pu0jCRvUH9PZc1gK3IHGLlBx5iwaH2+CetafskUgEFykUfPp1asXKSkp8naPHj1qZHTs2bPnNRHWUhi8L+bmqjJV4XK6UKrqrxAaO0Zh7BiFJElYi0spOXYS0/mFp8WHjlF5LgeNrw9eoSF4R4aJOUECQT00aj71LX24gMvlajYxnkCpUqLz0mKtctdizCYzPn4+jR6nUCjQBwUQPqg/2TY75pw8AKzFpViLSzFlZFEAGCLDCOrZHV2gPy6bjfLUdCxFJRg7RdcI2SoQtDWuKoyqzWajV69eHDt2rLn0eASdt042n8ryppnPBRRKJeFDBpK7ZadsQNUxZ+dhzs4D5fnOaJe7+WrKzCY66W60vk2/lkBwI3HVMZxvhE5WdwCxCgBM5ZU0PTSaG5VGQ7vEwTiqLNhNlViLSqjMzq1pRq5LnpMkUXzwKKED+1GVX4AuwB+115VnAhAIrjeu2nxuhOFlvffFTAGXru+6HNReetReerxCgvDv2gVrWbm7HygrB+l8x7zWz4itrByAivRMqgqLcJjMqPR6ou5NRO0l8sML2gYtkr2itaMzXAyzaiprvqUTOj8jEYPdgfNdDieS
y4VKqyH7p21yQkKHyW12TouFvB37iLxz0A1h6AJBYzRqPpMmTar3y3C9dzZfwOB7sbZRlNt82TGqo1SrAHcQ+qBe3evMhmrOzqVo/2F0QQHoA/3RiP4gwQ1Mo+bz4IMPNrh//PjxzSbGUwRGBMivs06fu+bX0/n7EdijK8WHj+MVFoLa4EXFmbMANdaNGSLCCOgWiyE89JprEghamkbN58iRI/zlL3+Rt5cvX17DkJ555hnGjRt3bdS1EEHtLsaFzky99uYD7tpPQI+uKFUqXA4n1uISbGUVNcqYc/Iw5+QR0r93i2gSCFqSRpdXrFixosZ29YBigBz8/XomMCJAblrmZeZjt9pb5LrK87nAlGoVkUMH49/1Jnw7RuEVHlKjXMHuA+jyS7BXXnlnuEDQ2mi05nPpUHpj29cjWp2GkHbB5GcV4HK6yE7PJTquQ+MHNiMabwMh/S7OFrdXmMjdthtLobsPSldcRnrKdyi1mvPlvfGJbo9vxw5obrD1doK2QaM1n7pS5jS0fb3SvnOk/DqrhZpeDaHx9aFd4hC8wmrWglw2Oy6bHWtJKUUHDpOxeh2mzIt6beUVZP+4jczvN1NVUNTSsgWCJtNozcfpdLJjxw65huNwOGps3ygjXu07t2PfT78CrcN8AJQaNe2GDqYsNZ2co8fQVNmQLnnektNF7tZdhA7og6PKQvHh4/KconM/bCHyzkGiw1rQKmnUfIKCgpg1a5a87e/vX2O7uZP4eYoOXdrJr1tixKupKFRK/GM7kVpRQp/evZHOzxcy5+RRfOgYdlMlkstF3o7aebMlp5Nzm7ei8/ND7WNAbfBC7eWFSqd1T4YMC0GpFlO9BJ6h0U/epk2bWkKHx2lfzXwyU7M9qKR+lCoVnO+kNnaKxisshKz1P+Ew1+yI1voZcdntOMxV4JKwlpRiLSmtfT61Gp/o9hgiw9AF+OO0WnFarOj8/cRqfME1R/zsnSckMhitXovNYqO8uJzy4gqMgXVnam0taLwNtBs+hLxtu3FYLOgDA/AKC8HYOQZnlYWcLTuxFpfUe7zL4aA8NZ3y1PRa+9TeBlQ6HUq1Cu92EfjFdcZWUkZFeibaAD+MMVEozocecZ8nA02ZCafVhkqnrXU+geBShPmcR6lU0q5TJGeOpgOQlXaOboFdGz6oFaD19aHDiKG13lf6eNNh5FCcFit2UyWOSjOOqiocVRZcVhtVhUXYy031ntdRacZxfmi/Kr+Q4qMncFULnlZ88BgB3ePwaR9B9k/bsRaX4AWkrViDITwM35gO+LSPRKlRI0kStvIKnOYqUChwOZzu2ppLQm3wQuPni9boW2PwQnJJSC6nx5qFkiTdMIMpdeG0WDHnFWCICEOl9UwWXmE+1ejQuZ1sPqcPptEtofWbT0MoFAp5sSshQTX2SZKEpbCYynM5VOUXYa+oQO2lR6nRYCkqkTutL1DdeAAcZjMFu/dTsHt/zYu6JMzZuZizc8lTKNAHBuC0WhtNN6QPCsS3YwecFiuW4lIsBUW4HHYC4mMJ6t2jlhE4bXaq8gtQKJRIkkTFmQzMOfkYIsIIHdgHlbbptS9JkkCSUCiV2E2VZG/+BafdTkjfnvjGuKdcSC4XkiThstqwlVcguVx4hQbXMEdJkrAUFbsTDBQV4xPVnoBusfJ8rurlHFVVlJ1Mw5R5Dq2fkcBuceiDAy8UoOz0Gcw5eVgKi1Hp9YT063kxg8oluBwOSo6cwG6uIqDrTegC/Bq8X1tZOVkbf8ZpsaL1N9JhROL55T8tizCfanTrH8fP37gnTe7csJukR+65YX/9FAoFXiFBeF1iSuAeQbOVVyA5nViKiik6eBSXzT3x0rtdBJbCYpxWa63jnDoNquoTNM9/GZuCpai4zrIlR09iN5lReelwVJrRBweh9tJTuP8QTkttDaazWVhLSvFuH4GttByFWoXGxxuVVotCpURVWSWbjTk3n4ozZzFlZqPUagi/bQBFBw5jK3fPNM/9ZRel
J9OwV5hwWiy1rqXSafGNicLlcGArN2ErK5OfE0Bx6VEqzpzFEBGK5JKwlZRhLS2rZez2chOVmdl4twsnJKE3hqx88k9kyPsd5iqyNvyET4d2aHy9URu80Pr7oTF44bTZyduxF1tpGQAV6WcJurkbvp2iasz/ctnt2E1mHGYzeTv2yc/OVlpO0a+HCYiPddeEzi/3qcoroOTEabxCgvGP64JC2fzfA2E+1eg1uCc6Lx3WKis5GXlknsoiKrZlJxu2BhQqpfzrqQ8OxDcmisrsXPSB/u7ObIeDslNplBw7hbPKAkoF4bf252RRPj3julKRnklFRpb8hQB357Yu0P/8+VWoDV4olArslWaqcgtqTSGojulslvy6rgW5l2KvMFF67FSd+7yBjG/W4XI43drP46xycm7jz7XKWwoK672O02qj9MTpRrWUVdTfvK1O5blcKrO/R13PvN3q87nqxSVR9OsRin49gkqvQ6FWITldNe71UkqPn6b0ZCq4JJRqNX6xnSg5fgpcEpWZ2Zhz8/CNicJeXoE+NBjviMuNeFU3wnyqodNr6Xt7L7av2wXAjvW726T5XIpKp8XYMUreVqrVBMTH4hfbmaq8AjS+Pu6IjEX5aHy8CezRlcAeXXFarVgKi1GoVOhDgmo1Py7gqLJQfvoMtnITam8vtL4+6EOCKDl2ivLTZ+rX5aWXR/a8ggPRGH0p3HeoVs3iUuwVjYdNqR53CQCFAoVCgUKtQmv0xWG21BplBFBqNedrKD6UHDmOy+6o+wJKBV7BQRg7x2DOyaMiPdP9fjXj8Y/rgk9UO0qOnqDyXG6DehVKJRpfnxqa66oZyuVVKrT+RqxF5wckzge7czkclBytmRRTjsZ5XnfM6JENamkqwnwuYeDd/WXz2blhDw88NVZOHSSoiVKlwjsyvN79Kp2uSXGq1V76WmmJAEIH9EFr9KUyOwd9QAAaX28qzmZhLSrFJ7odwX1urtW34xUaTOmJVFQ6LbpAfySnC0elGZfdgdNqpexMBorzXzSVXodvTAe8wkIp2L3fPTUB0IcE0X74HZhz8nCYzeiDA9H6+dVoekguF6bMbKwlpai99Gh8fND6G901uvNNdWPnaMzZebjsdpBAY/RBFxiASqet0Zw3dorG2DmGvG17cFS5NQTeHC9nQtGHBGEpLMZWWobTasNWYcJW4n6NErQ+PgT17o4uwJ+yU2mYMrOxFBbXNGGlAo23t7sP0NtAQNebUHnpObt2g5yFRaFWITkuHqPS62oZmEKpQqlpHtsQ5nMJ3fp3xcffB1OpiZKCUk4eOE3XvrGeltUmUSgUBMTfRED8TfJ7fjd1avAYnb8fYQP71rs/S+Wia2R7FCoVXqHBKM7/sOj8jeTv2ockSYTdmoBCqcC7Xf3GqlAq8Y1uj290+3rLqPV6jJ2iG9R7AUN4KFH3DaP8zFnO5uVyU7UUTA31z12Kf1wX/OO6ILlcOKosSC4XCtxTJxR1/Ih2GJlIRUYWXsGBaAP8yN+5D9PZc+hDgoi4/Vaq8gspOXIcpVaLV0gQxo5Rl9WZ3xDCfC5BrVbRP7Evm1e42//ffr5emM+NhFJZZ21M4+NNu8QhHhB0EZVOR0DXm0irLG+8cCMolEo03o0vONb4eBPYPU7ejhhyC06bHaVGjUKhwDeqHb5R7Ro4w5Uj2hN1kPibO+Qq9uEdRzmx/2QjRwgENw4qraZFRnmF+dRBu44RDBo5UN7+6m+rbojQIQJBa0KYTz2Meew+1Oc71lIPn+HA1kMeViQQ3FgI86mH4Igg7hx3sQ/g6+RVuJw3RvgQgaA1IMynAUZNHYne4E7kl30mh1++2+lhRQLBjUOLmE9JSQnTpk1jxIgRJCUl8fTTT1NcXHMq/fvvv09cXBwnT17s3D1z5gwTJkxgxIgRTJgwgfT09JaQK2MM8GXkb4fL2yn/+OaqkgoKBIKLtIj5KBQKfve737Fu3Tq++eYbOnTowFtvvSXvP3LkCAcOHCAy
MrLGcXPmzGHSpEmsW7eOSZMmMXv27JaQW4O7JyRiDDQCUJJfyrP3/YlFLy6hMEeEKBUIroYWMR9/f38GDrw4etS7d2+ys90Bu2w2G3PnzmXOnDk1hveKioo4evQoo0aNAmDUqFEcPXq0Vo3pWqM36Bn7u/vkbafDyaHtR/jXW1+2qA6B4Eajxft8XC4Xy5YtIzExEYB3332X0aNH06FDzTVUOTk5hIWFoTq/HkilUhEaGkpOTuMLC5ubO8YMZsofJxF108XZrIe2H+H0obQW1yIQ3Ci0+AznefPmYTAYmDx5Mvv37+fQoUO89NJL1+x6hw8fblK5vXtrx0Cujm8HL5JmjGDDpz9yck8qAP9a9G9GT7/nqjU2lcY0tgaExuahLWhsUfNZuHAhGRkZJCcno1Qq2b17N2lpaQwbNgyA3NxcHnvsMV577TXi4+PJy8vD6XSiUqlwOp3k5+cTEdH4QsXq9OjRA51O12CZvXv30q9fvyadr11Ie/530lwkl0Tm8WyO/HCCzjd3YsCwfmiuYUS4y9HoKYTG5uF612i1Wpv0o99iza5FixZx+PBhlixZgvb8wrTHH3+crVu3smnTJjZt2kR4eDhLly5l8ODBBAUFER8fz5o1awBYs2YN8fHxHs+WER4VVmP2848pW1k67zPeePodTOWNh2oQCARuWsR8Tp06RXJyMvn5+UycOJExY8Ywffr0Ro97+eWX+fzzzxkxYgSff/45r7zySguobZwxj92Hwderxnuph8+w8KlFFOW2bIe4QHC90iLNrptuuokTJ040Wu7SND2dO3dm+fLl10rWFRMcEcSCL1/m1K+nOX0ojXXLfgDgXFo2/zf5VR58aix3jB0s4gAJBA0gQmpcIcYAX/rd2Yd+d/YhKrYDH7/6GU6nC4vZwr/e+pJt3+/k3sl3s/enA6QeOUPCnX24b8oI9Ab9DZ8ZQSBoCsJ8moFbRwwgKDyQT177nLyz+YC7GbZ45odymbWfrWPLN9tQaVSUFpRxy939+Z+Zv0Wj80zaEoHA04h2QTMR26sLcz/9X0ZNHYmqnjQk5SUVlOSXIkkS29ftYtGLS6g6n00h7Wg6Kz5cza6Ne3A4Go5BLBDcCIiaTzOi0Wm4/4nRDLpnIF+++xUnDpym56DudLm5M2s/W0d5cc0Idcf3nWTGPX/EP9ivRkd1QKg/sb26uAOWKxSgAIUe4uO6YfDxuvSyAkGTkCSJ9GMZZGfk0mtQD3z8fEg7ms7ODXsIDg8kJj6agBB/9AYdpYXllBSU4B/sR0R0uBxepjkR5nMNCI8K47m3p9fo27l99G1knsrC2+jNvp/283XyasC9XOPSEbKS/FJ2bthT67z7Nx6iW0JXXC4XxgBfuvTshFKpJO1IOuUl5bicLvyD/blr/FBC24dc+xu9gTBXmCnMKcLb6E1QeNOmc9htdqxVVryN3pfVh3fmWAaHth/B22ig1203ExxxMTZzSUEpptJKXE4XStXFhonL5QIJlColVouNo7uOYamy0ql7DKHtQmpcP/9cIbt/2ENpYTmxvbsQ2i6YXRv3suuHvfJnzRhoZNA9A9jw5SacjYSKUalVGAN88fLxYuDwBJIeaZ6JtQrpBg3Rd2GiU3NPMmwutq/bxeqPvyUv091HpNFq6DmoO6d+TaW8pOKqzq1Sq0j8zR3cO/kutHodP6ZsISc9l4iYcKJuao/OS4fdZudcWjalhWX4+vsSGBZAWPsQAkIDyM8qoCC7EIOvAf8gI8ZAIz7+PqgbyWq5d+9eevXqjVKpaHSkT5IkJJdU4wt2JbhcLirLzXgbDQ1eU5IkKkpN7Ni6g+j2MajUSvQGPft+/pUtq7dRlHfxB6B7/64Mvf92OnaLwS/IiMVs5cyxDH7depDCnGJ8A3wwV1RxeNdRrGYrEdFh9LilO5XllRTmFFGYU0SVqYq4PrEMumcg5cXlnEvLRpIkCnOLObzjaA1toe1D6NQthrOnssg+414+pFIpCY8OJ7ZX
Z0qLyjmy6xhOh5PQdsEU55dirbqYVSIgxJ++d/TCP9iPfT/9ypljGVxLXvvvy2TlZTY6ybCx754wHzw7o7Sy3ExBdiGh7YIx+BqwW+0c23sCs6lKDl5WWWHmu3+vpzS/rJGz1UStUaM36DCVNc/kR5VKiUqjRu+lw8vHC2+jAR+jN3abg4oyE0W5RVSZLGh1GsKjwwmJDMLX3wetXockSdhtdqpMVRTlFnPuTA5VpioUSgUajRqNVoNPgA99b+9Nhy7t2LN5HxnHM3G5XGi0GuIT4uh5a3dM5WaK84qxmC2UFJRxbM9xKkpNBIUHMvzBoag1Ks6eyiLz/BdZpVbhbfSmosyE1Vx/Hqv6UCgUN1QIXYOvF0qVClPpxUSG0XEdiIyJIDP1HJVllVjMFoyBRgJC/CjMKa4RwSGyYwQvfzqLX389IMynPq4X82kqu3fvxlvhS0WJCZVKSc7ZPFIPnwFJIiY+mvAodxbJTSt+5vTBVA+rvf7QaDUEhrtrfZKr6V8JpUp52REuFQoFfe/ojd1q49jek9irpVjW6jSoNCqqTPVnGAWIiA4jKCKI1ENpVFXWLKtSKelxSzfadYrk4LbDlBaV07VvLINGDqDHwG6YK6r411tfcmjHEW4dMYCHnnsAra7+dDhVlRbMFWbsNjuh7UNQKpVNWl7R2HdP9PlcJyiVSrr169pouYF3JXBgy0HW/msdaUfSAQgKC+T2MbdRkF1IflYBdpsDhQIiosMJjgjCVGaiMKeI3Mx8SgvKCA4PJDw6jKpKC6WFZVSUVGAqr2zSl9ITNQWFUtEkbXqDHi9fHUGhQTidLirLKwkKC2RI0iAShvZBrVFTkF3Ijyu3cOpgKlmp2VjMFnQGHf7Bftw8sBude3TCbDLjckl07RtLcEQQh7Yf4VxaNn5BRoIjggiJDMbpdPLz6m2kHkojMDyAjvHR6PQ6UEBcn5uIjHGvUbTb7GScOEvGySz8g/3oMSCew0cP0z2+O6lH0kk7fAadl5abb+2Bf7Af+Vn5aPU6wqNCUSgUOB1OTh44zb6fD1BVaSE+IY7et/XE2+hOm/PA78fWeg7GQF+mL5iGy+Vq0kRYL289Xt76y/unNAFR8+H6qPlcrkZJkkg7kk5JQSm9BvW46vlEkiThdLpw2OxYzFbMJjOV5WZMZZVotGp8/HzIOJfO4Dtuo6qyiuz0XLdxlZqwW+0oFArUWjUGbz2+Ab606xSJf7AfLqcLh92BzWon9XAa29ftorSwjK59Y+mf2BeDr4HC7CJ2b9pH5uks/IP8CGkXjLfRgN6gp2N8DOFRofzy7Q4Obj+Ct6+BqNgORN3UnvZd2qFUKjGVm/D2Nbg7+/fta/JzlCQJl9NV79SJa8X1/nkUNZ82jkKhoHOPjs16PrVahVqtQm/Q4x/sV6tMUWUBKrUKHz8fYnt1adJ5VWoVKrUKnZeO3oN70ntwz1plAkMDiO3d8PkSf3MHib+5o859F2oBl4tCoWhx42lLiEmGAoHAI9ywNZ8LrUmbzdak8lbr5Y+EtDRCY/MgNDYP9Wm88J1rrEfnhu3zqaioqJEJQyAQtCyxsbH4+vrWu/+GNR+Xy0VlZSUaTcvknRYIBG4kScJut+Pt7d3gaNoNaz4CgaB1IzqcBQKBRxDmIxAIPIIwH4FA4BGE+QgEAo8gzEcgEHgEYT4CgcAjCPMRCAQeoU2bz5kzZ5gwYQIjRoxgwoQJpKene1oSJSUlTJs2jREjRpCUlMTTTz9NcbE7yl5r0/v+++8TFxcnzyRvbfqsVitz5szh7rvvJikpif/7v/9rdTo3b97M2LFjGTNmDElJSaxfv97jGhcuXEhiYmKN/21jmq5Ir9SGefjhh6WUlBRJkiQpJSVFevjhhz2sSJJKSkqkHTt2yNuvv/669Oc//1mSpNal9/Dhw9Jjjz0m3XnnndKJEydanT5JkqR58+ZJ8+fPl1wulyRJklRQUCBJUuvR6XK5pISEBPn5HTt2TOrdu7fk
dDo9qnH37t1Sdna2NHToUFmbJDX83K5Eb5s1n8LCQqlfv36Sw+GQJEmSHA6H1K9fP6moqMjDymry/fffS1OnTm1Veq1WqzR+/Hjp7Nmz8ge0NemTJEkymUxSv379JJPJVOP91qTT5XJJAwYMkPbs2SNJkiTt2rVLuvvuu1uNxurm05CmK9V7w65qb4ycnBzCwsJQqdzxWlQqFaGhoeTk5BAY2LTsBdcal8vFsmXLSExMbFV63333XUaPHk2HDh3k91qTPoDMzEz8/f15//332blzJ97e3jz77LPo9fpWo1OhUPDOO+/w1FNPYTAYqKys5MMPP2x1zxIa/v9KknRFett0n09rZ968eRgMBiZPnuxpKTL79+/n0KFDTJo0ydNSGsThcJCZmUm3bt1YsWIFL730Es888wxms9nT0mQcDgcffvghH3zwAZs3b+Zvf/sbzz//fKvSeC1ps+YTERFBXl4eTqc7O6jT6SQ/P5+IiAgPK3OzcOFCMjIyeOedd1Aqla1G7+7du0lLS2PYsGEkJiaSm5vLY489xtmzZ1uFvgtERkaiVqsZNWoUAL169SIgIAC9Xt9qdB47doz8/Hw5HGm/fv3w8vJCp9O1Go0XaOjzd6WfzTZrPkFBQcTHx7NmzRoA1qxZQ3x8fKtoci1atIjDhw+zZMkStFp3VoHWovfxxx9n69atbNq0iU2bNhEeHs7SpUu59957W4W+CwQGBjJw4EB++eUXwD0aU1RURExMTKvRGR4eTm5uLmlpaQCkpqZSWFhIdHR0q9F4gYY+f1f62WzTITVSU1OZOXMm5eXlGI1GFi5cSKdOnTyq6dSpU4waNYqYmBj0enfGgPbt27NkyZJWqTcxMZHk5GRiY2Nbnb7MzExmzZpFaWkparWa5557jjvuuKNV6Vy9ejV///vf5ZhTM2bMYPjw4R7V+Oqrr7J+/XoKCwsJCAjA39+ftWvXNqjpSvS2afMRCASeo802uwQCgWcR5iMQCDyCMB+BQOARhPkIBAKPIMxHIBB4BGE+guuSuLg4MjIyPC1DcBW02bVdguYlMTGRwsJCeX0PwLhx45g9e7YHVQlaM8J8BM1GcnIygwYN8rQMwXWCaHYJrikrVqxg4sSJzJs3j379+jFy5Ei2b98u78/Ly+PJJ59kwIAB3HXXXfz3v/+V9zmdTpKTkxk+fDh9+vTh/vvvJycnR96/bds27r77bvr3788rr7wi5wbPyMhg8uTJ9OvXj4EDB/Lcc8+12P0Kmo6o+QiuOQcPHmTkyJHs2LGDDRs28PTTT/PDDz/g7+/Piy++SJcuXdiyZQtpaWk88sgjdOjQgVtvvZVPPvmEtWvX8tFHH9GxY0dOnDghLzkB+PHHH/nqq68wmUzcf//9DB06lNtvv513332X2267jc8++wy73c6hQ4c8ePeC+hA1H0GzMX36dBISEuS/C7WYwMBApk6dikaj4d5776Vjx478+OOP5OTksHfvXl566SV0Oh3x8fE8+OCDrFq1CoDly5fz7LPP0qlTJxQKBV27diUgIEC+3rRp0zAajURGRjJw4ECOHz8OgFqtJjs7m/z8fHQ6HQkJCS3/MASNIsxH0GwsWbKEPXv2yH/jx48HICwsTF44Ce5wF/n5+eTn5+Pn54ePj0+NfXl5eQDk5uYSFRVV7/VCQkLk115eXlRWVgLwhz/8AUmSeOCBB7jvvvv46quvmvU+Bc2DaHYJrjl5eXlIkiQbUE5ODomJiYSGhlJWVobJZJIN6ELEPHCHnDh79iyxsbGXdb2QkBBeffVVAPbs2cMjjzxC//79iY6Obsa7ElwtouYjuOYUFxfL/S/fffcdqamp3HHHHURERNCnTx/++te/YrVaOX78OF999RVJSUkAPPjgg7z77rukp6cjSRLHjx+npKSk0et999135ObmAuDn54dCoUCpFB/11oao+QiajSeffLLGPJ9BgwYxbNgwevbsSUZGBrfccgvBwcG89957ct/NX//6V+bM
mcOQIUMwGo0888wz3HbbbQA88sgj2Gw2Hn30UUpKSujUqRNLlixpVMehQ4dYsGABJpOJoKAg/vd//7dGvGlB60DE8xFcU1asWMHy5ctZtmyZp6UIWhmiLioQCDyCMB+BQOARRLNLIBB4BFHzEQgEHkGYj0Ag8AjCfAQCgUcQ5iMQCDyCMB+BQOARhPkIBAKP8P+SeAk6hf6PegAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 288x144 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import seaborn as sns\n",
    "# sns.set_theme(style=\"white\")\n",
    "plt.figure(figsize=(4, 2))\n",
    "sns.set_theme(style=\"whitegrid\")\n",
    "colors = [sns.cubehelix_palette(as_cmap=False)[i] for i in range(6)]\n",
    "# colors = sns.cubehelix_palette(reverse=True)\n",
    "\n",
    "sns.set_theme(style=\"whitegrid\", palette='pastel')#palette=sns.cubehelix_palette(light=0.7, n_colors=2))#palette=sns.color_palette(\"husl\", 2))\n",
    "ax = sns.lineplot(gumbel_softmax, c=colors[1], label='Gumbel Softmax', lw=3)\n",
    "sns.lineplot(SIMPLE_res, c=colors[4], label='SIMPLE', lw=3)\n",
    "ax.set(xlabel='Epochs', ylabel='ELBO', title='Discrete 1-Subset VAE')\n",
    "plt.savefig('SIMPLE-GS.pdf',bbox_inches='tight')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
