{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-09-25 20:14:42.766366: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "# choose a particular GPU\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "from tensorflow.keras.layers import Dense, Flatten\n",
    "from tensorflow.keras import Model\n",
    "from tensorflow.keras.utils import plot_model\n",
    "\n",
    "import tensorflow_probability as tfp\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import random\n",
    "\n",
    "seed = 1234\n",
    "tf.random.set_seed(seed)\n",
    "os.environ['TF_DETERMINISTIC_OPS'] = 'true'\n",
    "os.environ['PYTHONHASHSEED'] = f'{seed}'\n",
    "np.random.seed(seed)\n",
    "random.seed(seed)\n",
    "\n",
    "physical_devices = tf.config.list_physical_devices('GPU') \n",
    "tf.config.experimental.set_memory_growth(physical_devices[0], True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "stop = tf.stop_gradient\n",
    "log1mexp = tfp.math.log1mexp\n",
    "\n",
    "@tf.function\n",
    "def log_sigmoid(logits):\n",
    "    return tf.clip_by_value(tf.math.log_sigmoid(logits), clip_value_max=-1e-7, clip_value_min=-float('inf'))\n",
    "\n",
    "@tf.function\n",
    "def logaddexp(x1, x2):\n",
    "    delta = tf.where(x1 == x2, 0., x1 - x2)\n",
    "    return tf.math.maximum(x1, x2) + tf.math.softplus(-tf.math.abs(delta))\n",
    "\n",
    "@tf.function\n",
    "def topk_marginals(logprobs, k):\n",
    "    batch_size = logprobs.shape[0]\n",
    "    n = logprobs.shape[1]\n",
    "    \n",
    "    state = np.ones((batch_size, k+2)) * -float('inf')\n",
    "    state[:, 1] = 0\n",
    "    state = tf.convert_to_tensor(state, dtype=tf.float32)\n",
    "\n",
    "    a = tf.TensorArray(tf.float32, size=n+1)\n",
    "    a = a.write(0, state)\n",
    "    \n",
    "    for i in range(1, n+1):\n",
    "        \n",
    "        state = tf.concat([\n",
    "            tf.ones([batch_size, 1]) * -float('inf'), \n",
    "            logaddexp(\n",
    "                state[:, :-1] + logprobs[:, i-1:i], \n",
    "                state[:, 1:] + log1mexp(stop(logprobs[:, i-1:i]))\n",
    "            )\n",
    "        ], 1)\n",
    "        \n",
    "        a = a.write(i, state)\n",
    "    a = tf.transpose(a.stack(), perm=[1, 0, 2])\n",
    "    marginals = tf.gradients(a[:, n, k+1:k+2], logprobs)\n",
    "    \n",
    "    return marginals[0], a"
   ]
  },
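  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The DP above tracks, in log space, the probability that exactly $j$ of the first $i$ Bernoulli variables are on; thanks to the `stop_gradient` on the off branch, differentiating $\\log P(\\sum_i X_i = k)$ recovers the exact marginals of the $k$-subset distribution. As a sanity check (an illustrative addition, not part of the original experiment), we can compare the DP marginals against brute-force enumeration for a small $n$:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "\n",
    "logits = tf.constant([[0.3, -1.2, 0.8, 0.1, -0.5]])\n",
    "logp = log_sigmoid(logits)\n",
    "k = 2\n",
    "marg, _ = topk_marginals(logp, k)\n",
    "\n",
    "# brute force: enumerate every k-subset of the n = 5 variables\n",
    "p = tf.sigmoid(logits).numpy()[0]\n",
    "n = p.shape[0]\n",
    "weights = {S: np.prod([p[i] if i in S else 1 - p[i] for i in range(n)])\n",
    "           for S in itertools.combinations(range(n), k)}\n",
    "Z = sum(weights.values())\n",
    "brute = np.array([sum(w for S, w in weights.items() if i in S) / Z for i in range(n)])\n",
    "\n",
    "print(np.allclose(marg.numpy()[0], brute, atol=1e-4))  # expect True\n",
    "print(marg.numpy().sum())  # marginals of a k-subset distribution sum to k"
   ]
  },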
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def sample(a, probs):\n",
    "    \n",
    "    n = a.shape[-2] - 1\n",
    "    k = a.shape[-1] - 1\n",
    "    bsz = a.shape[0]\n",
    "\n",
    "    j = tf.fill((bsz,), k)\n",
    "    samples = tf.TensorArray(tf.int32, size=n, clear_after_read=False)\n",
    "    #probs_ = tf.TensorArray(tf.float32, size=n, clear_after_read=False)\n",
    "    for i in tf.range(n, 0, -1):\n",
    "        \n",
    "        # Unnormalized probabilities of Xi and -Xi\n",
    "        full = tf.fill((bsz,), i-1)\n",
    "        p_idx = tf.stack([full, j-1], axis=1)\n",
    "        z_idx = tf.stack([full + 1, j], axis=1)\n",
    "        \n",
    "        p = tf.gather_nd(batch_dims=1, indices=p_idx, params=a)\n",
    "        z = tf.gather_nd(batch_dims=1, indices=z_idx, params=a)\n",
    "        \n",
    "        p = (p + probs[:, i-1]) - z\n",
    "        q = log1mexp(p)\n",
    "\n",
    "        # Sample according to normalized dist.\n",
    "        X = tfp.distributions.Bernoulli(logits=(p-q)).sample()\n",
    "\n",
    "        # Pick next state based on value of sample\n",
    "        j = tf.where(X>0, j - 1, j)\n",
    "\n",
    "        # Concatenate to samples\n",
    "        samples = samples.write(i-1, X)\n",
    "        #probs_ = probs_.write(i-1, p)\n",
    "        \n",
    "    samples = tf.transpose(samples.stack(), perm=[1, 0])\n",
    "    #probs_ = probs_.stack()\n",
    "    \n",
    "    # Our samples should always satisfy the constraint\n",
    "    tf.debugging.assert_equal(tf.math.reduce_sum(samples, axis=-1), k-1)\n",
    "    \n",
    "    return tf.cast(samples, tf.float32)#, probs"
   ]
  },
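  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`sample` walks the DP table backwards: at step $i$ it draws $X_i$ from its conditional distribution given how many ones are still required, so every sample satisfies the cardinality constraint by construction. A small usage sketch (an illustrative addition):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logits = tf.random.normal([4, 8], seed=0)\n",
    "logp = log_sigmoid(logits)\n",
    "_, a_table = topk_marginals(logp, 3)\n",
    "X = sample(a_table, logp)\n",
    "print(tf.reduce_sum(X, -1))  # every row sums to exactly k = 3"
   ]
  },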
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def xexpx(x):\n",
    "    expx = tf.exp(x)\n",
    "    return tf.where(expx == 0, expx, x*expx)\n",
    "\n",
    "@tf.function\n",
    "def xexpy(x,y):\n",
    "    expy = tf.exp(y)\n",
    "    return tf.where(expy == 0, expy, x*expy)\n",
    "\n",
    "@tf.function\n",
    "def entropy(a, logprobs):\n",
    "    entropy = tf.zeros((a.shape[0], a.shape[-1]))\n",
    "    for i in range(10, a.shape[-2]):\n",
    "        \n",
    "        p_left = (a[:, i-1, :-1] + logprobs[:, i-1:i]) - a[:, i, 1:]\n",
    "        p_right = (a[:, i-1, 1:] + log1mexp(logprobs[:, i-1:i])) - a[:, i, 1:]\n",
    "        \n",
    "        entropy = tf.concat([tf.zeros((a.shape[0], 1)),\n",
    "                             xexpx(p_left) + xexpx(p_right) +\\\n",
    "                             xexpy(entropy[:, :-1], p_left) + xexpy(entropy[:, 1:], p_right)\n",
    "                            ], 1)\n",
    "    return tf.clip_by_value(-entropy[:, -1], clip_value_max=float('inf'), clip_value_min=0)"
   ]
  },
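  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same table supports an exact entropy computation. By the chain rule, $H(X_{1:i} \\mid N_i = j)$ equals the entropy of the transition at step $i$ plus the expected entropy of the shorter prefix, which is exactly the recursion above. A brute-force cross-check on a small instance (an illustrative addition):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "\n",
    "logits = tf.constant([[0.3, -1.2, 0.8, 0.1, -0.5]])\n",
    "logp = log_sigmoid(logits)\n",
    "_, a_table = topk_marginals(logp, 2)\n",
    "a_table = tf.where(a_table == -float('inf'), -1000., a_table)\n",
    "H_dp = entropy(a_table, logp)\n",
    "\n",
    "# brute force: normalized weights over all 2-subsets of the 5 variables\n",
    "p = tf.sigmoid(logits).numpy()[0]\n",
    "w = np.array([np.prod([p[i] if i in S else 1 - p[i] for i in range(5)])\n",
    "              for S in itertools.combinations(range(5), 2)])\n",
    "w = w / w.sum()\n",
    "print(H_dp.numpy()[0], -(w * np.log(w)).sum())  # should agree closely"
   ]
  },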
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class IMLESubsetkLayer(tf.keras.layers.Layer):\n",
    "  \n",
    "    def __init__(self, _k=1, _tau=10.0, _lambda=10.0):\n",
    "        super(IMLESubsetkLayer, self).__init__()\n",
    "        \n",
    "        self.k = _k\n",
    "        self._tau = _tau\n",
    "        self._lambda = _lambda\n",
    "        self.samples = None\n",
    "        self.gumbel_dist = tfp.distributions.Gumbel(loc=0.0, scale=1.0)\n",
    "        \n",
    "    @tf.function\n",
    "    def sample_gumbel(self, shape, eps=1e-20):\n",
    "        return self.gumbel_dist.sample(shape)\n",
    "    \n",
    "    @tf.function\n",
    "    def sample_gumbel_k(self, shape):\n",
    "        \n",
    "        s = tf.map_fn(fn=lambda t: tf.random.gamma(shape, 1.0/self.k,  t/self.k), \n",
    "                  elems=tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]))\n",
    "        # now add the samples\n",
    "        s = tf.reduce_sum(s, 0)\n",
    "        # the log(m) term\n",
    "        s = s - tf.math.log(10.0)\n",
    "        # divide by k --> each s[c] has k samples whose sum is distributed as Gumbel(0, 1)\n",
    "        s = self._tau * (s / self.k)\n",
    "\n",
    "        return s\n",
    "       \n",
    "    def imle_layer(self, logits, hard=False, evaluate=False):\n",
    "        return gumbel_softmax(logits)\n",
    "\n",
    "    def call(self, logits, hard=True, evaluate=False):\n",
    "        return self.imle_layer(logits, hard, evaluate)"
   ]
  },
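  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The perturbation used here is Sum-of-Gamma noise: a truncated sum of $m = 10$ Gamma$(1/k,\\, k/t)$ draws, shifted by $\\log m$ and scaled by $\\tau / k$, approximates (the mean of) $k$ independent Gumbel$(0, 1)$ variables. A quick empirical moment check (an illustrative addition; the $m = 10$ truncation leaves a small bias):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "layer = IMLESubsetkLayer(_k=1, _tau=1.0)\n",
    "noise = layer.sample_gumbel_k((100000,))\n",
    "# Gumbel(0, 1) has mean ~0.577 and variance ~1.645\n",
    "print(tf.reduce_mean(noise).numpy(), tf.math.reduce_variance(noise).numpy())"
   ]
  },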
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "PARAMS = {\n",
    "    \"batch_size\": 100,\n",
    "    \"data_dim\": 784,\n",
    "    \"M\": 20,\n",
    "    \"N\": 20,\n",
    "    \"nb_epoch\": 100, \n",
    "    \"epsilon_std\": 0.01,\n",
    "    \"anneal_rate\": 0.0003,\n",
    "    \"init_temperature\": 1.0,\n",
    "    \"min_temperature\": 0.5,\n",
    "    \"learning_rate\": 1e-3,\n",
    "    \"hard\": False,\n",
    "}\n",
    "\n",
    "class DiscreteVAE(tf.keras.Model):\n",
    "    \n",
    "    def __init__(self, params):\n",
    "        super(DiscreteVAE, self).__init__()\n",
    "        \n",
    "        self.params = params\n",
    "                \n",
    "        # encoder\n",
    "        self.enc_dense1 = tf.keras.layers.Dense(512, activation='relu')\n",
    "        self.enc_dense2 = tf.keras.layers.Dense(256, activation='relu')\n",
    "        self.enc_dense3 = tf.keras.layers.Dense(params[\"N\"]*params[\"M\"])\n",
    "        \n",
    "        # this is our new Gumbel layer\n",
    "        self.imleLayer = IMLESubsetkLayer(_k=1, _tau=10.0, _lambda=10.0)\n",
    "\n",
    "        # decoder\n",
    "        self.flatten = Flatten()\n",
    "        self.dec_dense1 = tf.keras.layers.Dense(256, activation='relu')\n",
    "        self.dec_dense2 = tf.keras.layers.Dense(512, activation='relu')\n",
    "        self.dec_dense3 = tf.keras.layers.Dense(params[\"data_dim\"])\n",
    "\n",
    "\n",
    "    def sample_gumbel(self, shape, eps=1e-20): \n",
    "        \"\"\"Sample from Gumbel(0, 1)\"\"\" \n",
    "        U = tf.random.uniform(shape, minval=0, maxval=1)\n",
    "        return -tf.math.log(-tf.math.log(U + eps) + eps)\n",
    "    \n",
    "    def gumbel_softmax_sample(self, logits, temperature): \n",
    "        \"\"\" Draw a sample from the Gumbel-Softmax distribution\"\"\"\n",
    "        # logits: [batch_size, n_class] unnormalized log-probs\n",
    "        y = logits + self.sample_gumbel(tf.shape(logits))\n",
    "        return tf.nn.softmax(y / temperature)  \n",
    "\n",
    "    def gumbel_softmax(self, logits, temperature, hard=True):\n",
    "        \"\"\"\n",
    "        logits: [batch_size, n_class] unnormalized log-probs\n",
    "        temperature: non-negative scalar\n",
    "        hard: if True, take argmax, but differentiate w.r.t. soft sample y\n",
    "        \"\"\"\n",
    "        y = self.gumbel_softmax_sample(logits, temperature)\n",
    "        if hard: \n",
    "            # \n",
    "            y_hard = tf.cast(tf.equal(y, tf.reduce_max(y, 1, keepdims=True)),y.dtype)\n",
    "            y = tf.stop_gradient(y_hard - y) + y\n",
    "        return y\n",
    "    \n",
    "    def SIMPLE(self, logits, temperature, hard=True):\n",
    "        \"\"\"\n",
    "        input: [*, n_class]\n",
    "        return: [*, n_class] an one-hot vector\n",
    "        \"\"\"\n",
    "        y = tf.nn.softmax(logits)\n",
    "        y_perturbed = tf.nn.softmax(logits + self.sample_gumbel(tf.shape(logits)))\n",
    "        y_hard = tf.cast(tf.equal(y_perturbed, tf.reduce_max(y_perturbed, 1, keepdims=True)),y_perturbed.dtype)\n",
    "        y = tf.stop_gradient(y_hard - y) + y\n",
    "        return y\n",
    "    \n",
    "    def decoder(self, x):\n",
    "        # decoder\n",
    "        h = self.flatten(x)\n",
    "        h = self.dec_dense1(h)\n",
    "        h = self.dec_dense2(h)\n",
    "        h = self.dec_dense3(h)\n",
    "        return h\n",
    "\n",
    "    def call(self, x, tau, hard=False, evaluate=False):\n",
    "        N = self.params[\"N\"]\n",
    "        M = self.params[\"M\"]\n",
    "\n",
    "        # encoder\n",
    "        x = self.enc_dense1(x)\n",
    "        x = self.enc_dense2(x)\n",
    "        x = self.enc_dense3(x)   # (batch, N*M)\n",
    "        logits_y = tf.reshape(x, [-1, M])   # (batch*N, M)\n",
    "\n",
    "        ###################################################################\n",
    "        ## here we toggle between methods #################################\n",
    "        # here we can switch between traditional and our method\n",
    "        # \"traditional\" Gumbel Softmax trick\n",
    "#         y = self.gumbel_softmax(logits=logits_y, temperature=tau, hard=True)\n",
    "        y = self.SIMPLE(logits=logits_y, temperature=tau, hard=True)\n",
    "        # IMLE approach -- note: we don't anneal so set temperature once at init\n",
    "#         y = self.imleLayer(logits=logits_y, hard=True, evaluate=evaluate)\n",
    "        ###################################################################\n",
    "        \n",
    "        assert y.shape == (self.params[\"batch_size\"]*N, M)\n",
    "        y = tf.reshape(y, [-1, N, M])\n",
    "        self.sample_y = y\n",
    "\n",
    "        # decoder\n",
    "        logits_x = self.decoder(y)\n",
    "        return logits_y, logits_x\n",
    "\n",
    "\n",
    "def gumbel_loss(model, x, tau, hard=True, evaluate=False):\n",
    "    M = 20\n",
    "    N = 20\n",
    "    data_dim = PARAMS['data_dim']\n",
    "    logits_y, logits_x = model(x, tau, hard, evaluate)\n",
    "    \n",
    "    # cross-entropy\n",
    "    cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=logits_x)\n",
    "    cross_ent = tf.math.reduce_sum(cross_ent, 1)\n",
    "    cross_ent = tf.math.reduce_mean(cross_ent, 0)\n",
    "    \n",
    "    if evaluate:\n",
    "        # KL loss\n",
    "        logprobs_q = log_sigmoid(logits_y)\n",
    "        marginals_q, a_q = topk_marginals(logprobs_q, 10)\n",
    "        a_q = tf.where(a_q == -float('inf'), -1000., a_q)\n",
    "        q_entropy = entropy(a_q, logprobs_q)\n",
    "        kl = tf.math.log(184756.) - tf.reshape(q_entropy, [-1,N])\n",
    "        kl = tf.math.reduce_sum(kl, 1)\n",
    "        KL_mean = tf.math.reduce_mean(kl)\n",
    "    else: \n",
    "        # KL loss\n",
    "        q_y = tf.nn.softmax(logits_y)   # (batshsize*N, M)  softmax\n",
    "        log_q_y = tf.math.log(q_y + 1e-20)   # (batshsize*N, M)  \n",
    "        kl_tmp = tf.reshape(q_y*(log_q_y-tf.math.log(1.0/M)), [-1,N,M])  # (batch_size,N,K)\n",
    "        KL = tf.math.reduce_sum(kl_tmp, [1, 2])    # shape=(batch_size, 1)\n",
    "\n",
    "        KL_mean = tf.math.reduce_mean(KL)\n",
    "    return cross_ent + KL_mean\n",
    "\n",
    "\n",
    "def compute_gradients(model, x, tau, hard):\n",
    "    with tf.GradientTape() as tape:\n",
    "        loss = gumbel_loss(model, x, tau, hard)\n",
    "    return tape.gradient(loss, model.trainable_variables), loss\n",
    "\n",
    "\n",
    "def apply_gradients(optimizer, gradients, variables):\n",
    "    optimizer.apply_gradients(zip(gradients, variables))\n",
    "\n",
    "\n",
    "def get_learning_rate(step, init=PARAMS[\"learning_rate\"]):\n",
    "    return tf.convert_to_tensor(init * pow(0.95, (step / 1000.)), dtype=tf.float32)"
   ]
  },
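  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "At evaluation time the KL term is taken against a uniform prior over 10-subsets of the $M = 20$ categories, so $\\mathrm{KL}(q \\,\\|\\, \\text{prior}) = \\log \\binom{20}{10} - H(q)$. The constant 184756 is exactly $\\binom{20}{10}$ (a quick illustrative check):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import comb, log\n",
    "\n",
    "print(comb(20, 10), log(comb(20, 10)))  # 184756, ~12.127"
   ]
  },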
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-09-25 20:14:45.922203: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA\n",
      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2022-09-25 20:14:46.583663: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 15492 MB memory:  -> device: 0, name: NVIDIA TITAN RTX, pci bus id: 0000:5e:00.0, compute capability: 7.5\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1 , TRAIN loss: 161.66273 , Temperature: 0.9997000449955004\n",
      "Eval Loss: 263.7302 \n",
      "\n",
      "Epoch: 2 , TRAIN loss: 147.49097 , Temperature: 0.9991004048785274\n",
      "Eval Loss: 251.18088 \n",
      "\n",
      "Epoch: 3 , TRAIN loss: 143.87224 , Temperature: 0.9982016190284373\n",
      "Eval Loss: 246.33345 \n",
      "\n",
      "Epoch: 4 , TRAIN loss: 144.10754 , Temperature: 0.997004495503373\n",
      "Eval Loss: 243.87569 \n",
      "\n",
      "Epoch: 5 , TRAIN loss: 132.15274 , Temperature: 0.9955101098295706\n",
      "Eval Loss: 242.23654 \n",
      "\n",
      "Epoch: 6 , TRAIN loss: 133.83615 , Temperature: 0.9937198033910547\n",
      "Eval Loss: 240.63268 \n",
      "\n",
      "Epoch: 7 , TRAIN loss: 137.77516 , Temperature: 0.9916351814230984\n",
      "Eval Loss: 240.3769 \n",
      "\n",
      "Epoch: 8 , TRAIN loss: 131.7594 , Temperature: 0.9892581106136482\n",
      "Eval Loss: 238.67422 \n",
      "\n",
      "Epoch: 9 , TRAIN loss: 128.65381 , Temperature: 0.9865907163177327\n",
      "Eval Loss: 238.54333 \n",
      "\n",
      "Epoch: 10 , TRAIN loss: 124.486336 , Temperature: 0.9836353793906725\n",
      "Eval Loss: 237.74467 \n",
      "\n",
      "Epoch: 11 , TRAIN loss: 131.06973 , Temperature: 0.9803947326466972\n",
      "Eval Loss: 237.18701 \n",
      "\n",
      "Epoch: 12 , TRAIN loss: 122.16926 , Temperature: 0.9768716569503434\n",
      "Eval Loss: 236.66383 \n",
      "\n",
      "Epoch: 13 , TRAIN loss: 124.377426 , Temperature: 0.9730692769487556\n",
      "Eval Loss: 237.05296 \n",
      "\n",
      "Epoch: 14 , TRAIN loss: 125.13043 , Temperature: 0.9689909564537397\n",
      "Eval Loss: 236.54428 \n",
      "\n",
      "Epoch: 15 , TRAIN loss: 124.72078 , Temperature: 0.964640293483123\n",
      "Eval Loss: 236.6843 \n",
      "\n",
      "Epoch: 16 , TRAIN loss: 126.43543 , Temperature: 0.9600211149716509\n",
      "Eval Loss: 235.88643 \n",
      "\n",
      "Epoch: 17 , TRAIN loss: 121.166626 , Temperature: 0.9551374711623026\n",
      "Eval Loss: 235.71437 \n",
      "\n",
      "Epoch: 18 , TRAIN loss: 121.2038 , Temperature: 0.9499936296895314\n",
      "Eval Loss: 236.06479 \n",
      "\n",
      "Epoch: 19 , TRAIN loss: 125.93919 , Temperature: 0.9445940693665233\n",
      "Eval Loss: 235.67657 \n",
      "\n",
      "Epoch: 20 , TRAIN loss: 121.90171 , Temperature: 0.9389434736891332\n",
      "Eval Loss: 235.67119 \n",
      "\n",
      "Epoch: 21 , TRAIN loss: 122.79502 , Temperature: 0.9330467240696794\n",
      "Eval Loss: 235.66957 \n",
      "\n",
      "Epoch: 22 , TRAIN loss: 129.34578 , Temperature: 0.9269088928142736\n",
      "Eval Loss: 235.24947 \n",
      "\n",
      "Epoch: 23 , TRAIN loss: 131.70184 , Temperature: 0.9205352358578188\n",
      "Eval Loss: 235.27081 \n",
      "\n",
      "Epoch: 24 , TRAIN loss: 123.24499 , Temperature: 0.9139311852712282\n",
      "Eval Loss: 235.0676 \n",
      "\n",
      "Epoch: 25 , TRAIN loss: 122.731575 , Temperature: 0.9071023415558017\n",
      "Eval Loss: 235.25371 \n",
      "\n",
      "Epoch: 26 , TRAIN loss: 131.13188 , Temperature: 0.9000544657400421\n",
      "Eval Loss: 235.26814 \n",
      "\n",
      "Epoch: 27 , TRAIN loss: 126.959595 , Temperature: 0.8927934712944959\n",
      "Eval Loss: 234.78458 \n",
      "\n",
      "Epoch: 28 , TRAIN loss: 120.213745 , Temperature: 0.8853254158804752\n",
      "Eval Loss: 234.38338 \n",
      "\n",
      "Epoch: 29 , TRAIN loss: 123.54049 , Temperature: 0.8776564929487385\n",
      "Eval Loss: 234.84767 \n",
      "\n",
      "Epoch: 30 , TRAIN loss: 119.91982 , Temperature: 0.8697930232043984\n",
      "Eval Loss: 234.6623 \n",
      "\n",
      "Epoch: 31 , TRAIN loss: 128.00171 , Temperature: 0.8617414459544691\n",
      "Eval Loss: 234.63838 \n",
      "\n",
      "Epoch: 32 , TRAIN loss: 127.167274 , Temperature: 0.8535083103545701\n",
      "Eval Loss: 234.79012 \n",
      "\n",
      "Epoch: 33 , TRAIN loss: 125.618286 , Temperature: 0.8451002665713722\n",
      "Eval Loss: 234.6961 \n",
      "\n",
      "Epoch: 34 , TRAIN loss: 116.33229 , Temperature: 0.8365240568773926\n",
      "Eval Loss: 234.56339 \n",
      "\n",
      "Epoch: 35 , TRAIN loss: 122.542496 , Temperature: 0.8277865066947337\n",
      "Eval Loss: 234.71466 \n",
      "\n",
      "Epoch: 36 , TRAIN loss: 126.26574 , Temperature: 0.8188945156043044\n",
      "Eval Loss: 234.62645 \n",
      "\n",
      "Epoch: 37 , TRAIN loss: 126.539474 , Temperature: 0.8098550483369699\n",
      "Eval Loss: 234.55602 \n",
      "\n",
      "Epoch: 38 , TRAIN loss: 126.58678 , Temperature: 0.8006751257629464\n",
      "Eval Loss: 234.53557 \n",
      "\n",
      "Epoch: 39 , TRAIN loss: 121.77014 , Temperature: 0.791361815895584\n",
      "Eval Loss: 234.48553 \n",
      "\n",
      "Epoch: 40 , TRAIN loss: 119.51841 , Temperature: 0.7819222249254774\n",
      "Eval Loss: 234.41893 \n",
      "\n",
      "Epoch: 41 , TRAIN loss: 127.748276 , Temperature: 0.7723634883006051\n",
      "Eval Loss: 234.4857 \n",
      "\n",
      "Epoch: 42 , TRAIN loss: 122.05701 , Temperature: 0.7626927618679156\n",
      "Eval Loss: 234.5761 \n",
      "\n",
      "Epoch: 43 , TRAIN loss: 125.222015 , Temperature: 0.7529172130914742\n",
      "Eval Loss: 234.43562 \n",
      "\n",
      "Epoch: 44 , TRAIN loss: 130.77121 , Temperature: 0.74304401236194\n",
      "Eval Loss: 234.36523 \n",
      "\n",
      "Epoch: 45 , TRAIN loss: 118.390915 , Temperature: 0.7330803244117685\n",
      "Eval Loss: 234.36906 \n",
      "\n",
      "Epoch: 46 , TRAIN loss: 122.875015 , Temperature: 0.7230332998501351\n",
      "Eval Loss: 234.41887 \n",
      "\n",
      "Epoch: 47 , TRAIN loss: 123.31408 , Temperature: 0.7129100668311394\n",
      "Eval Loss: 234.56465 \n",
      "\n",
      "Epoch: 48 , TRAIN loss: 114.99292 , Temperature: 0.7027177228683977\n",
      "Eval Loss: 234.54298 \n",
      "\n",
      "Epoch: 49 , TRAIN loss: 117.94737 , Temperature: 0.6924633268086435\n",
      "Eval Loss: 234.51425 \n",
      "\n",
      "Epoch: 50 , TRAIN loss: 119.18484 , Temperature: 0.6821538909764523\n",
      "Eval Loss: 234.19354 \n",
      "\n",
      "Epoch: 51 , TRAIN loss: 124.506454 , Temperature: 0.6717963735016784\n",
      "Eval Loss: 234.50745 \n",
      "\n",
      "Epoch: 52 , TRAIN loss: 125.58152 , Temperature: 0.6613976708406429\n",
      "Eval Loss: 234.51003 \n",
      "\n",
      "Epoch: 53 , TRAIN loss: 114.84353 , Temperature: 0.6509646105015452\n",
      "Eval Loss: 234.2202 \n",
      "\n",
      "Epoch: 54 , TRAIN loss: 120.87513 , Temperature: 0.6405039439839885\n",
      "Eval Loss: 234.34398 \n",
      "\n",
      "Epoch: 55 , TRAIN loss: 119.25325 , Temperature: 0.6300223399419125\n",
      "Eval Loss: 234.38408 \n",
      "\n",
      "Epoch: 56 , TRAIN loss: 122.872116 , Temperature: 0.6195263775786136\n",
      "Eval Loss: 234.62442 \n",
      "\n",
      "Epoch: 57 , TRAIN loss: 122.53606 , Temperature: 0.609022540281914\n",
      "Eval Loss: 234.2286 \n",
      "\n",
      "Epoch: 58 , TRAIN loss: 122.69198 , Temperature: 0.5985172095069092\n",
      "Eval Loss: 234.39273 \n",
      "\n",
      "Epoch: 59 , TRAIN loss: 115.22655 , Temperature: 0.5880166589130854\n",
      "Eval Loss: 234.55411 \n",
      "\n",
      "Epoch: 60 , TRAIN loss: 125.73427 , Temperature: 0.5775270487619548\n",
      "Eval Loss: 234.48737 \n",
      "\n",
      "Epoch: 61 , TRAIN loss: 121.66902 , Temperature: 0.5670544205807092\n",
      "Eval Loss: 234.60356 \n",
      "\n",
      "Epoch: 62 , TRAIN loss: 117.01724 , Temperature: 0.556604692096744\n",
      "Eval Loss: 234.669 \n",
      "\n",
      "Epoch: 63 , TRAIN loss: 117.442825 , Temperature: 0.5461836524472542\n",
      "Eval Loss: 234.60332 \n",
      "\n",
      "Epoch: 64 , TRAIN loss: 122.17252 , Temperature: 0.5357969576674562\n",
      "Eval Loss: 234.96918 \n",
      "\n",
      "Epoch: 65 , TRAIN loss: 123.33334 , Temperature: 0.5254501264603462\n",
      "Eval Loss: 235.08226 \n",
      "\n",
      "Epoch: 66 , TRAIN loss: 117.605 , Temperature: 0.5151485362502642\n",
      "Eval Loss: 234.69551 \n",
      "\n",
      "Epoch: 67 , TRAIN loss: 128.5358 , Temperature: 0.5048974195219025\n",
      "Eval Loss: 234.54205 \n",
      "\n",
      "Epoch: 68 , TRAIN loss: 118.623314 , Temperature: 0.5\n",
      "Eval Loss: 234.5958 \n",
      "\n",
      "Epoch: 69 , TRAIN loss: 120.091446 , Temperature: 0.5\n",
      "Eval Loss: 234.86174 \n",
      "\n",
      "Epoch: 70 , TRAIN loss: 125.51135 , Temperature: 0.5\n",
      "Eval Loss: 234.7789 \n",
      "\n",
      "Epoch: 71 , TRAIN loss: 122.129074 , Temperature: 0.5\n",
      "Eval Loss: 235.04431 \n",
      "\n",
      "Epoch: 72 , TRAIN loss: 117.69849 , Temperature: 0.5\n",
      "Eval Loss: 234.75188 \n",
      "\n",
      "Epoch: 73 , TRAIN loss: 123.43442 , Temperature: 0.5\n",
      "Eval Loss: 234.86469 \n",
      "\n",
      "Epoch: 74 , TRAIN loss: 121.836105 , Temperature: 0.5\n",
      "Eval Loss: 234.9746 \n",
      "\n",
      "Epoch: 75 , TRAIN loss: 125.69815 , Temperature: 0.5\n",
      "Eval Loss: 234.62251 \n",
      "\n",
      "Epoch: 76 , TRAIN loss: 118.741936 , Temperature: 0.5\n",
      "Eval Loss: 234.81293 \n",
      "\n",
      "Epoch: 77 , TRAIN loss: 116.130585 , Temperature: 0.5\n",
      "Eval Loss: 234.79703 \n",
      "\n",
      "Epoch: 78 , TRAIN loss: 125.80801 , Temperature: 0.5\n",
      "Eval Loss: 235.23468 \n",
      "\n",
      "Epoch: 79 , TRAIN loss: 125.10715 , Temperature: 0.5\n",
      "Eval Loss: 234.97758 \n",
      "\n",
      "Epoch: 80 , TRAIN loss: 121.005806 , Temperature: 0.5\n",
      "Eval Loss: 234.77356 \n",
      "\n",
      "Epoch: 81 , TRAIN loss: 120.50459 , Temperature: 0.5\n",
      "Eval Loss: 235.05292 \n",
      "\n",
      "Epoch: 82 , TRAIN loss: 125.01185 , Temperature: 0.5\n",
      "Eval Loss: 234.9518 \n",
      "\n",
      "Epoch: 83 , TRAIN loss: 119.228096 , Temperature: 0.5\n",
      "Eval Loss: 234.73944 \n",
      "\n",
      "Epoch: 84 , TRAIN loss: 119.98285 , Temperature: 0.5\n",
      "Eval Loss: 235.00218 \n",
      "\n",
      "Epoch: 85 , TRAIN loss: 117.47981 , Temperature: 0.5\n",
      "Eval Loss: 234.86794 \n",
      "\n",
      "Epoch: 86 , TRAIN loss: 120.17398 , Temperature: 0.5\n",
      "Eval Loss: 235.16379 \n",
      "\n",
      "Epoch: 87 , TRAIN loss: 118.850204 , Temperature: 0.5\n",
      "Eval Loss: 234.95787 \n",
      "\n",
      "Epoch: 88 , TRAIN loss: 121.57574 , Temperature: 0.5\n",
      "Eval Loss: 234.67206 \n",
      "\n",
      "Epoch: 89 , TRAIN loss: 120.33585 , Temperature: 0.5\n",
      "Eval Loss: 234.95192 \n",
      "\n",
      "Epoch: 90 , TRAIN loss: 115.28362 , Temperature: 0.5\n",
      "Eval Loss: 235.25476 \n",
      "\n",
      "Epoch: 91 , TRAIN loss: 118.52444 , Temperature: 0.5\n",
      "Eval Loss: 234.97649 \n",
      "\n",
      "Epoch: 92 , TRAIN loss: 120.14989 , Temperature: 0.5\n",
      "Eval Loss: 235.29564 \n",
      "\n",
      "Epoch: 93 , TRAIN loss: 123.79265 , Temperature: 0.5\n",
      "Eval Loss: 235.581 \n",
      "\n",
      "Epoch: 94 , TRAIN loss: 123.93092 , Temperature: 0.5\n",
      "Eval Loss: 234.95891 \n",
      "\n",
      "Epoch: 95 , TRAIN loss: 123.49895 , Temperature: 0.5\n",
      "Eval Loss: 235.65594 \n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 96 , TRAIN loss: 123.94199 , Temperature: 0.5\n",
      "Eval Loss: 235.47629 \n",
      "\n",
      "Epoch: 97 , TRAIN loss: 118.7987 , Temperature: 0.5\n",
      "Eval Loss: 235.41869 \n",
      "\n",
      "Epoch: 98 , TRAIN loss: 115.97012 , Temperature: 0.5\n",
      "Eval Loss: 235.53271 \n",
      "\n",
      "Epoch: 99 , TRAIN loss: 119.211975 , Temperature: 0.5\n",
      "Eval Loss: 235.37158 \n",
      "\n",
      "Epoch: 100 , TRAIN loss: 119.041626 , Temperature: 0.5\n",
      "Eval Loss: 235.3255 \n",
      "\n",
      "CPU times: user 20min 5s, sys: 14.8 s, total: 20min 20s\n",
      "Wall time: 20min 1s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "\n",
    "model = DiscreteVAE(PARAMS)\n",
    "learning_rate = tf.Variable(PARAMS[\"learning_rate\"], trainable=False, name=\"LR\")\n",
    "\n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)\n",
    "\n",
    "# data\n",
    "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
    "x_train = x_train.astype('float32') / 255.\n",
    "x_test = x_test.astype('float32') / 255.\n",
    "x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))\n",
    "x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))\n",
    "\n",
    "TRAIN_BUF = 60000\n",
    "BATCH_SIZE = 100\n",
    "TEST_BUF = 10000\n",
    "\n",
    "train_dataset = tf.data.Dataset.from_tensor_slices(x_train).shuffle(TRAIN_BUF).batch(BATCH_SIZE)\n",
    "test_dataset = tf.data.Dataset.from_tensor_slices(x_test).shuffle(TEST_BUF).batch(BATCH_SIZE)\n",
    "\n",
    "# temperature\n",
    "tau = PARAMS[\"init_temperature\"]\n",
    "anneal_rate = PARAMS[\"anneal_rate\"]\n",
    "min_temperature = PARAMS[\"min_temperature\"]\n",
    "\n",
    "results = []\n",
    "\n",
    "# Train\n",
    "for epoch in range(1, PARAMS[\"nb_epoch\"] + 1):\n",
    "    \n",
    "    # this is only needed for the standard Gumbel softmax trick\n",
    "    tau = np.maximum(tau * np.exp(-anneal_rate*epoch), min_temperature)\n",
    "\n",
    "    for train_x in train_dataset:\n",
    "        gradients, loss = compute_gradients(model, train_x, tau, hard=PARAMS[\"hard\"])\n",
    "        apply_gradients(optimizer, gradients, model.trainable_variables)\n",
    "\n",
    "    print(\"Epoch:\", epoch, \", TRAIN loss:\", loss.numpy(), \", Temperature:\", tau)\n",
    "\n",
    "    if epoch % 1 == 0:\n",
    "        losses = []\n",
    "        for test_x in test_dataset:\n",
    "            losses.append(gumbel_loss(model, test_x, tau, hard=True, evaluate=True))\n",
    "        eval_loss = np.mean(losses)\n",
    "        results.append(eval_loss)\n",
    "        print(\"Eval Loss:\", eval_loss, \"\\n\")\n",
    "\n",
    "    if PARAMS['hard'] == True:\n",
    "        model.save_weights(\"model.h5\")\n",
    "    else:\n",
    "        model.save_weights(\"model_hard.h5\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
