{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-09-06 13:35:09.149772: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# choose a particular GPU\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "gpus = tf.config.experimental.list_physical_devices('GPU')\n",
    "tf.config.experimental.set_visible_devices(gpus[0], 'GPU')\n",
    "tf.config.experimental.set_memory_growth(gpus[0], True)\n",
    "\n",
    "from tensorflow.keras.layers import Dense, Flatten\n",
    "from tensorflow.keras import Model\n",
    "from tensorflow.keras.utils import plot_model\n",
    "\n",
    "import tensorflow_probability as tfp\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "import time\n",
    "\n",
    "seed = 1234\n",
    "tf.random.set_seed(seed)\n",
    "os.environ['TF_DETERMINISTIC_OPS'] = 'true'\n",
    "os.environ['PYTHONHASHSEED'] = f'{seed}'\n",
    "np.random.seed(seed)\n",
    "random.seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "stop = tf.stop_gradient\n",
    "log1mexp = tfp.math.log1mexp\n",
    "\n",
    "@tf.function\n",
    "def log_sigmoid(logits):\n",
    "    return tf.clip_by_value(tf.math.log_sigmoid(logits), clip_value_max=-1e-7, clip_value_min=-float('inf'))\n",
    "\n",
    "@tf.function\n",
    "def logaddexp(x1, x2):\n",
    "    delta = tf.where(x1 == x2, 0., x1 - x2)\n",
    "    return tf.math.maximum(x1, x2) + tf.math.softplus(-tf.math.abs(delta))\n",
    "\n",
    "@tf.function\n",
    "def log_pr_exactly_k(logp, logq, k):\n",
    "    \n",
    "    batch_size = logp.shape[0]\n",
    "    n = logp.shape[1]\n",
    "    \n",
    "    state = np.ones((batch_size, k+2)) * -float('inf')\n",
    "    state[:, 1] = 0\n",
    "    state = tf.convert_to_tensor(state, dtype=tf.float32)\n",
    "\n",
    "    a = tf.TensorArray(tf.float32, size=n+1)\n",
    "    a = a.write(0, state)\n",
    "    \n",
    "    for i in range(1, n+1):\n",
    "        \n",
    "        state = tf.concat([\n",
    "            tf.ones([batch_size, 1]) * -float('inf'), \n",
    "            logaddexp(\n",
    "                state[:, :-1] + logp[:, i-1:i], \n",
    "                state[:, 1:] + logq[:, i-1:i]\n",
    "            )\n",
    "        ], 1)\n",
    "        \n",
    "        a = a.write(i, state)\n",
    "    a = tf.transpose(a.stack(), perm=[1, 0, 2])\n",
    "    return a\n",
    "\n",
    "# @tf.function\n",
    "def marginals(theta, k):\n",
    "    log_p = log_sigmoid(theta) \n",
    "    log_p_complement = log1mexp(log_p) \n",
    "    with tf.GradientTape() as tape:\n",
    "        tape.watch(log_p)\n",
    "        a = log_pr_exactly_k(log_p, log_p_complement, 10)\n",
    "        log_pr = a[:, -1, k+1:k+2]\n",
    "    return tape.gradient(log_pr, log_p), a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def sample(a, probs):\n",
    "    \n",
    "    n = a.shape[-2] - 1\n",
    "    k = a.shape[-1] - 1\n",
    "    bsz = a.shape[0]\n",
    "\n",
    "    j = tf.fill((bsz,), k)\n",
    "    samples = tf.TensorArray(tf.int32, size=n, clear_after_read=False)\n",
    "    \n",
    "    for i in tf.range(n, 0, -1):\n",
    "        \n",
    "        # Unnormalized probabilities of Xi and -Xi\n",
    "        full = tf.fill((bsz,), i-1)\n",
    "        p_idx = tf.stack([full, j-1], axis=1)\n",
    "        z_idx = tf.stack([full + 1, j], axis=1)\n",
    "        \n",
    "        p = tf.gather_nd(batch_dims=1, indices=p_idx, params=a)\n",
    "        z = tf.gather_nd(batch_dims=1, indices=z_idx, params=a)\n",
    "        \n",
    "        p = (p + probs[:, i-1]) - z\n",
    "        q = log1mexp(p)\n",
    "\n",
    "        # Sample according to normalized dist.\n",
    "        X = tfp.distributions.Bernoulli(logits=(p-q)).sample()\n",
    "\n",
    "        # Pick next state based on value of sample\n",
    "        j = tf.where(X>0, j - 1, j)\n",
    "\n",
    "        # Concatenate to samples\n",
    "        samples = samples.write(i-1, X)\n",
    "        \n",
    "    samples = tf.transpose(samples.stack(), perm=[1, 0])\n",
    "    \n",
    "    # Our samples should always satisfy the constraint\n",
    "    tf.debugging.assert_equal(tf.math.reduce_sum(samples, axis=-1), k-1)\n",
    "    \n",
    "    return tf.cast(samples, tf.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def xexpx(x):\n",
    "    expx = tf.exp(x)\n",
    "    return tf.where(expx == 0, expx, x*expx)\n",
    "\n",
    "@tf.function\n",
    "def xexpy(x,y):\n",
    "    expy = tf.exp(y)\n",
    "    return tf.where(expy == 0, expy, x*expy)\n",
    "\n",
    "@tf.function\n",
    "def entropy(a, logprobs):\n",
    "    entropy = tf.zeros((a.shape[0], a.shape[-1]))\n",
    "    for i in range(10, a.shape[-2]):\n",
    "        \n",
    "        p_left = (a[:, i-1, :-1] + logprobs[:, i-1:i]) - a[:, i, 1:]\n",
    "        p_right = (a[:, i-1, 1:] + log1mexp(logprobs[:, i-1:i])) - a[:, i, 1:]\n",
    "        \n",
    "        entropy = tf.concat([tf.zeros((a.shape[0], 1)),\n",
    "                             xexpx(p_left) + xexpx(p_right) +\\\n",
    "                             xexpy(entropy[:, :-1], p_left) + xexpy(entropy[:, 1:], p_right)\n",
    "                            ], 1)\n",
    "    return tf.clip_by_value(-entropy[:, -1], clip_value_max=float('inf'), clip_value_min=0)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class IMLESubsetkLayer(tf.keras.layers.Layer):\n",
    "  \n",
    "    def __init__(self, _k=10, _tau=1.0, _lambda=1.0):\n",
    "        super(IMLESubsetkLayer, self).__init__()\n",
    "        \n",
    "        self.k = _k\n",
    "        self._tau = _tau\n",
    "        self._lambda = _lambda\n",
    "        self.samples = None\n",
    "        self.gumbel_dist = tfp.distributions.Gumbel(loc=0.0, scale=1.0)\n",
    "        \n",
    "    @tf.function\n",
    "    def sample_gumbel(self, shape, eps=1e-20):\n",
    "        return self.gumbel_dist.sample(shape)\n",
    "    \n",
    "    @tf.function\n",
    "    def sample_gumbel_k(self, shape):\n",
    "        \n",
    "        s = tf.map_fn(fn=lambda t: tf.random.gamma(shape, 1.0/self.k,  t/self.k), \n",
    "                  elems=tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]))\n",
    "        # now add the samples\n",
    "        s = tf.reduce_sum(s, 0)\n",
    "        # the log(m) term\n",
    "        s = s - tf.math.log(10.0)\n",
    "        # divide by k --> each s[c] has k samples whose sum is distributed as Gumbel(0, 1)\n",
    "        s = self._tau * (s / self.k)\n",
    "\n",
    "        return s\n",
    "    \n",
    "\n",
    "    @tf.custom_gradient\n",
    "    def imle_layer(self, logits, hard=False):\n",
    "        \n",
    "        # ZK: Should be exact sampling: we're going to pass it to the decoder on the forward pass\n",
    "        logp = log_sigmoid(logits)\n",
    "        logq = log1mexp(logp)\n",
    "        \n",
    "        a = log_pr_exactly_k(logp, logq, self.k)\n",
    "        samples_p = sample(a, logp)\n",
    "\n",
    "        def custom_grad(dy):\n",
    "            \n",
    "            with tf.autodiff.ForwardAccumulator(logits, dy) as accumulate:\n",
    "                y = marginals(logits, self.k)[0]\n",
    "            return accumulate.jvp(y), hard\n",
    "\n",
    "            return grad, hard\n",
    "\n",
    "        return samples_p, custom_grad\n",
    "\n",
    "    def call(self, logits, hard=False):\n",
    "        return self.imle_layer(logits, hard)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "PARAMS = {\n",
    "    \"batch_size\": 100,\n",
    "    \"data_dim\": 784,\n",
    "    \"M\": 20,\n",
    "    \"N\": 20,\n",
    "    \"nb_epoch\": 200, \n",
    "    \"epsilon_std\": 0.01,\n",
    "    \"anneal_rate\": 0.0003,\n",
    "    \"init_temperature\": 1.0,\n",
    "    \"min_temperature\": 0.5,\n",
    "    \"learning_rate\": 5e-4,\n",
    "    \"hard\": False,\n",
    "}\n",
    "\n",
    "class DiscreteVAE(tf.keras.Model):\n",
    "    \n",
    "    def __init__(self, params):\n",
    "        super(DiscreteVAE, self).__init__()\n",
    "        \n",
    "        self.params = params\n",
    "                \n",
    "        # encoder\n",
    "        self.enc_dense1 = tf.keras.layers.Dense(512, activation='relu')\n",
    "        self.enc_dense2 = tf.keras.layers.Dense(256, activation='relu')\n",
    "        self.enc_dense3 = tf.keras.layers.Dense(params[\"N\"]*params[\"M\"])\n",
    "        \n",
    "        # this is our new Gumbel layer\n",
    "        self.imleLayer = IMLESubsetkLayer(_k=10, _tau=1.0, _lambda=10.0)\n",
    "\n",
    "        # decoder\n",
    "        self.flatten = Flatten()\n",
    "        self.dec_dense1 = tf.keras.layers.Dense(256, activation='relu')\n",
    "        self.dec_dense2 = tf.keras.layers.Dense(512, activation='relu')\n",
    "        self.dec_dense3 = tf.keras.layers.Dense(params[\"data_dim\"])\n",
    "\n",
    "\n",
    "    def sample_gumbel(self, shape, eps=1e-20): \n",
    "        \"\"\"Sample from Gumbel(0, 1)\"\"\" \n",
    "        U = tf.random.uniform(shape, minval=0, maxval=1)\n",
    "        return -tf.math.log(-tf.math.log(U + eps) + eps)\n",
    "    \n",
    "    def gumbel_softmax_sample(self, logits, temperature): \n",
    "        \"\"\" Draw a sample from the Gumbel-Softmax distribution\"\"\"\n",
    "        # logits: [batch_size, n_class] unnormalized log-probs\n",
    "        y = logits + self.sample_gumbel(tf.shape(logits))\n",
    "        return tf.nn.softmax(y / temperature)  \n",
    "\n",
    "    def gumbel_softmax(self, logits, temperature, hard=True):\n",
    "        \"\"\"\n",
    "        logits: [batch_size, n_class] unnormalized log-probs\n",
    "        temperature: non-negative scalar\n",
    "        hard: if True, take argmax, but differentiate w.r.t. soft sample y\n",
    "        \"\"\"\n",
    "        y = self.gumbel_softmax_sample(logits, temperature)\n",
    "        if hard: \n",
    "            # \n",
    "            y_hard = tf.cast(tf.equal(y, tf.reduce_max(y, 1, keepdims=True)),y.dtype)\n",
    "            y = tf.stop_gradient(y_hard - y) + y\n",
    "        return y\n",
    "    \n",
    "    def decoder(self, x):\n",
    "        # decoder\n",
    "        h = self.flatten(x)\n",
    "        h = self.dec_dense1(h)\n",
    "        h = self.dec_dense2(h)\n",
    "        h = self.dec_dense3(h)\n",
    "        return h\n",
    "\n",
    "    def call(self, x, tau, hard=False):\n",
    "        N = self.params[\"N\"]\n",
    "        M = self.params[\"M\"]\n",
    "\n",
    "        # encoder\n",
    "        x = self.enc_dense1(x)\n",
    "        x = self.enc_dense2(x)\n",
    "        x = self.enc_dense3(x)   # (batch, N*M)\n",
    "        logits_y = tf.reshape(x, [-1, M])   # (batch*N, M)\n",
    "\n",
    "        ###################################################################\n",
    "        ## here we toggle between methods #################################\n",
    "        # here we can switch between traditional and our method\n",
    "        # \"traditional\" Gumbel Softmax trick\n",
    "        #y = self.gumbel_softmax(logits=logits_y, temperature=tau, hard=False)\n",
    "        # IMLE approach -- note: we don't anneal so set temperature once at init\n",
    "        y = self.imleLayer(logits=logits_y, hard=True)\n",
    "        ###################################################################\n",
    "        \n",
    "        assert y.shape == (self.params[\"batch_size\"]*N, M)\n",
    "        y = tf.reshape(y, [-1, N, M])\n",
    "        self.sample_y = y\n",
    "\n",
    "        # decoder\n",
    "        logits_x = self.decoder(y)\n",
    "        return logits_y, logits_x\n",
    "\n",
    "def gumbel_loss(model, x, tau, hard=True):\n",
    "    M = 20\n",
    "    N = 20\n",
    "    data_dim = PARAMS['data_dim']\n",
    "    logits_y, logits_x = model(x, tau, hard)\n",
    "    \n",
    "    # cross-entropy\n",
    "    cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=logits_x)\n",
    "    cross_ent = tf.math.reduce_sum(cross_ent, 1)\n",
    "    cross_ent = tf.math.reduce_mean(cross_ent, 0)\n",
    "    \n",
    "    # KL loss\n",
    "    logprobs_q = log_sigmoid(logits_y)\n",
    "    marginals_q, a_q = marginals(logits_y, 10)\n",
    "    a_q = tf.where(a_q == -float('inf'), -1000., a_q)\n",
    "    q_entropy = entropy(a_q, logprobs_q)\n",
    "    kl = tf.math.log(184756.) - tf.reshape(q_entropy, [-1,N])\n",
    "    kl = tf.math.reduce_sum(kl, 1)\n",
    "    kl = tf.math.reduce_mean(kl)\n",
    "\n",
    "    return cross_ent + kl\n",
    "\n",
    "\n",
    "def compute_gradients(model, x, tau, hard):\n",
    "    with tf.GradientTape() as tape:\n",
    "        loss = gumbel_loss(model, x, tau, hard)\n",
    "    return tape.gradient(loss, model.trainable_variables), loss\n",
    "\n",
    "\n",
    "def apply_gradients(optimizer, gradients, variables):\n",
    "    optimizer.apply_gradients(zip(gradients, variables))\n",
    "\n",
    "\n",
    "def get_learning_rate(step, init=PARAMS[\"learning_rate\"]):\n",
    "    return tf.convert_to_tensor(init * pow(0.95, (step / 1000.)), dtype=tf.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-09-06 13:35:11.900861: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA\n",
      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2022-09-06 13:35:12.437097: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 7025 MB memory:  -> device: 0, name: NVIDIA TITAN RTX, pci bus id: 0000:5e:00.0, compute capability: 7.5\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1 , TRAIN loss: 222.86719 , Temperature: 0.9997000449955004\n",
      "Eval Loss: 219.2978 \n",
      "\n",
      "Epoch: 2 , TRAIN loss: 204.12341 , Temperature: 0.9991004048785274\n",
      "Eval Loss: 204.57222 \n",
      "\n",
      "Epoch: 3 , TRAIN loss: 198.81891 , Temperature: 0.9982016190284373\n",
      "Eval Loss: 197.22336 \n",
      "\n",
      "Epoch: 4 , TRAIN loss: 197.72275 , Temperature: 0.997004495503373\n",
      "Eval Loss: 193.98055 \n",
      "\n",
      "Epoch: 5 , TRAIN loss: 189.35 , Temperature: 0.9955101098295706\n",
      "Eval Loss: 191.29666 \n",
      "\n",
      "Epoch: 6 , TRAIN loss: 190.1796 , Temperature: 0.9937198033910547\n",
      "Eval Loss: 189.26395 \n",
      "\n",
      "Epoch: 7 , TRAIN loss: 191.48802 , Temperature: 0.9916351814230984\n",
      "Eval Loss: 186.94223 \n",
      "\n",
      "Epoch: 8 , TRAIN loss: 188.79102 , Temperature: 0.9892581106136482\n",
      "Eval Loss: 185.77408 \n",
      "\n",
      "Epoch: 9 , TRAIN loss: 184.5834 , Temperature: 0.9865907163177327\n",
      "Eval Loss: 184.91791 \n",
      "\n",
      "Epoch: 10 , TRAIN loss: 180.92511 , Temperature: 0.9836353793906725\n",
      "Eval Loss: 184.20523 \n",
      "\n",
      "Epoch: 11 , TRAIN loss: 185.68341 , Temperature: 0.9803947326466972\n",
      "Eval Loss: 183.14648 \n",
      "\n",
      "Epoch: 12 , TRAIN loss: 180.33151 , Temperature: 0.9768716569503434\n",
      "Eval Loss: 182.39667 \n",
      "\n",
      "Epoch: 13 , TRAIN loss: 180.16121 , Temperature: 0.9730692769487556\n",
      "Eval Loss: 181.93253 \n",
      "\n",
      "Epoch: 14 , TRAIN loss: 181.99124 , Temperature: 0.9689909564537397\n",
      "Eval Loss: 181.53888 \n",
      "\n",
      "Epoch: 15 , TRAIN loss: 181.05328 , Temperature: 0.964640293483123\n",
      "Eval Loss: 181.02055 \n",
      "\n",
      "Epoch: 16 , TRAIN loss: 181.8434 , Temperature: 0.9600211149716509\n",
      "Eval Loss: 180.84465 \n",
      "\n",
      "Epoch: 17 , TRAIN loss: 178.90036 , Temperature: 0.9551374711623026\n",
      "Eval Loss: 180.33397 \n",
      "\n",
      "Epoch: 18 , TRAIN loss: 177.15692 , Temperature: 0.9499936296895314\n",
      "Eval Loss: 180.08871 \n",
      "\n",
      "Epoch: 20 , TRAIN loss: 176.72379 , Temperature: 0.9389434736891332\n",
      "Eval Loss: 179.35037 \n",
      "\n",
      "Epoch: 21 , TRAIN loss: 178.48834 , Temperature: 0.9330467240696794\n",
      "Eval Loss: 178.98863 \n",
      "\n",
      "Epoch: 22 , TRAIN loss: 185.26244 , Temperature: 0.9269088928142736\n",
      "Eval Loss: 178.76521 \n",
      "\n",
      "Epoch: 23 , TRAIN loss: 184.23013 , Temperature: 0.9205352358578188\n",
      "Eval Loss: 178.37082 \n",
      "\n",
      "Epoch: 24 , TRAIN loss: 178.35878 , Temperature: 0.9139311852712282\n",
      "Eval Loss: 178.32088 \n",
      "\n",
      "Epoch: 25 , TRAIN loss: 175.60944 , Temperature: 0.9071023415558017\n",
      "Eval Loss: 178.1059 \n",
      "\n",
      "Eval Loss: 178.1059 \n",
      "\n",
      "Eval Loss: 178.1059 \n",
      "\n",
      "Epoch: 26 , TRAIN loss: 181.47766 , Temperature: 0.9000544657400421\n",
      "Epoch: 26 , TRAIN loss: 181.47766 , Temperature: 0.9000544657400421\n",
      "Epoch: 26 , TRAIN loss: 181.47766 , Temperature: 0.9000544657400421\n",
      "Eval Loss: 177.9834 \n",
      "\n",
      "Eval Loss: 177.9834 \n",
      "\n",
      "Eval Loss: 177.9834 \n",
      "\n",
      "Epoch: 27 , TRAIN loss: 180.16612 , Temperature: 0.8927934712944959\n",
      "Epoch: 27 , TRAIN loss: 180.16612 , Temperature: 0.8927934712944959\n",
      "Epoch: 27 , TRAIN loss: 180.16612 , Temperature: 0.8927934712944959\n",
      "Eval Loss: 177.54549 \n",
      "\n",
      "Eval Loss: 177.54549 \n",
      "\n",
      "Eval Loss: 177.54549 \n",
      "\n",
      "Epoch: 28 , TRAIN loss: 175.469 , Temperature: 0.8853254158804752\n",
      "Epoch: 28 , TRAIN loss: 175.469 , Temperature: 0.8853254158804752\n",
      "Epoch: 28 , TRAIN loss: 175.469 , Temperature: 0.8853254158804752\n",
      "Eval Loss: 177.57256 \n",
      "\n",
      "Eval Loss: 177.57256 \n",
      "\n",
      "Eval Loss: 177.57256 \n",
      "\n",
      "Epoch: 29 , TRAIN loss: 177.90222 , Temperature: 0.8776564929487385\n",
      "Epoch: 29 , TRAIN loss: 177.90222 , Temperature: 0.8776564929487385\n",
      "Epoch: 29 , TRAIN loss: 177.90222 , Temperature: 0.8776564929487385\n",
      "Eval Loss: 176.818 \n",
      "\n",
      "Eval Loss: 176.818 \n",
      "\n",
      "Eval Loss: 176.818 \n",
      "\n",
      "Epoch: 30 , TRAIN loss: 175.24109 , Temperature: 0.8697930232043984\n",
      "Epoch: 30 , TRAIN loss: 175.24109 , Temperature: 0.8697930232043984\n",
      "Epoch: 30 , TRAIN loss: 175.24109 , Temperature: 0.8697930232043984\n",
      "Eval Loss: 176.91985 \n",
      "\n",
      "Eval Loss: 176.91985 \n",
      "\n",
      "Eval Loss: 176.91985 \n",
      "\n",
      "Epoch: 31 , TRAIN loss: 179.36661 , Temperature: 0.8617414459544691\n",
      "Epoch: 31 , TRAIN loss: 179.36661 , Temperature: 0.8617414459544691\n",
      "Epoch: 31 , TRAIN loss: 179.36661 , Temperature: 0.8617414459544691\n",
      "Eval Loss: 176.81981 \n",
      "\n",
      "Eval Loss: 176.81981 \n",
      "\n",
      "Eval Loss: 176.81981 \n",
      "\n",
      "Epoch: 32 , TRAIN loss: 180.45859 , Temperature: 0.8535083103545701\n",
      "Epoch: 32 , TRAIN loss: 180.45859 , Temperature: 0.8535083103545701\n",
      "Epoch: 32 , TRAIN loss: 180.45859 , Temperature: 0.8535083103545701\n",
      "Eval Loss: 176.64006 \n",
      "\n",
      "Eval Loss: 176.64006 \n",
      "\n",
      "Eval Loss: 176.64006 \n",
      "\n",
      "Epoch: 33 , TRAIN loss: 179.03955 , Temperature: 0.8451002665713722\n",
      "Epoch: 33 , TRAIN loss: 179.03955 , Temperature: 0.8451002665713722\n",
      "Epoch: 33 , TRAIN loss: 179.03955 , Temperature: 0.8451002665713722\n",
      "Eval Loss: 176.76805 \n",
      "\n",
      "Eval Loss: 176.76805 \n",
      "\n",
      "Eval Loss: 176.76805 \n",
      "\n",
      "Epoch: 34 , TRAIN loss: 173.22406 , Temperature: 0.8365240568773926\n",
      "Epoch: 34 , TRAIN loss: 173.22406 , Temperature: 0.8365240568773926\n",
      "Epoch: 34 , TRAIN loss: 173.22406 , Temperature: 0.8365240568773926\n",
      "Eval Loss: 176.29471 \n",
      "\n",
      "Eval Loss: 176.29471 \n",
      "\n",
      "Eval Loss: 176.29471 \n",
      "\n",
      "Epoch: 35 , TRAIN loss: 176.1395 , Temperature: 0.8277865066947337\n",
      "Epoch: 35 , TRAIN loss: 176.1395 , Temperature: 0.8277865066947337\n",
      "Epoch: 35 , TRAIN loss: 176.1395 , Temperature: 0.8277865066947337\n",
      "Eval Loss: 176.22055 \n",
      "\n",
      "Eval Loss: 176.22055 \n",
      "\n",
      "Eval Loss: 176.22055 \n",
      "\n",
      "Epoch: 36 , TRAIN loss: 176.95535 , Temperature: 0.8188945156043044\n",
      "Epoch: 36 , TRAIN loss: 176.95535 , Temperature: 0.8188945156043044\n",
      "Epoch: 36 , TRAIN loss: 176.95535 , Temperature: 0.8188945156043044\n",
      "Eval Loss: 176.10684 \n",
      "\n",
      "Eval Loss: 176.10684 \n",
      "\n",
      "Eval Loss: 176.10684 \n",
      "\n",
      "Epoch: 37 , TRAIN loss: 177.27069 , Temperature: 0.8098550483369699\n",
      "Epoch: 37 , TRAIN loss: 177.27069 , Temperature: 0.8098550483369699\n",
      "Epoch: 37 , TRAIN loss: 177.27069 , Temperature: 0.8098550483369699\n",
      "Eval Loss: 176.40402 \n",
      "\n",
      "Eval Loss: 176.40402 \n",
      "\n",
      "Eval Loss: 176.40402 \n",
      "\n",
      "Epoch: 38 , TRAIN loss: 177.05414 , Temperature: 0.8006751257629464\n",
      "Epoch: 38 , TRAIN loss: 177.05414 , Temperature: 0.8006751257629464\n",
      "Epoch: 38 , TRAIN loss: 177.05414 , Temperature: 0.8006751257629464\n",
      "Eval Loss: 176.01904 \n",
      "\n",
      "Eval Loss: 176.01904 \n",
      "\n",
      "Eval Loss: 176.01904 \n",
      "\n",
      "Epoch: 39 , TRAIN loss: 175.79575 , Temperature: 0.791361815895584\n",
      "Epoch: 39 , TRAIN loss: 175.79575 , Temperature: 0.791361815895584\n",
      "Epoch: 39 , TRAIN loss: 175.79575 , Temperature: 0.791361815895584\n",
      "Eval Loss: 175.71777 \n",
      "\n",
      "Eval Loss: 175.71777 \n",
      "\n",
      "Eval Loss: 175.71777 \n",
      "\n",
      "Epoch: 40 , TRAIN loss: 176.91483 , Temperature: 0.7819222249254774\n",
      "Epoch: 40 , TRAIN loss: 176.91483 , Temperature: 0.7819222249254774\n",
      "Epoch: 40 , TRAIN loss: 176.91483 , Temperature: 0.7819222249254774\n",
      "Eval Loss: 175.4921 \n",
      "\n",
      "Eval Loss: 175.4921 \n",
      "\n",
      "Eval Loss: 175.4921 \n",
      "\n",
      "Epoch: 41 , TRAIN loss: 176.57095 , Temperature: 0.7723634883006051\n",
      "Epoch: 41 , TRAIN loss: 176.57095 , Temperature: 0.7723634883006051\n",
      "Epoch: 41 , TRAIN loss: 176.57095 , Temperature: 0.7723634883006051\n",
      "Eval Loss: 175.45439 \n",
      "\n",
      "Eval Loss: 175.45439 \n",
      "\n",
      "Eval Loss: 175.45439 \n",
      "\n",
      "Epoch: 42 , TRAIN loss: 176.366 , Temperature: 0.7626927618679156\n",
      "Epoch: 42 , TRAIN loss: 176.366 , Temperature: 0.7626927618679156\n",
      "Epoch: 42 , TRAIN loss: 176.366 , Temperature: 0.7626927618679156\n",
      "Eval Loss: 175.36203 \n",
      "\n",
      "Eval Loss: 175.36203 \n",
      "\n",
      "Eval Loss: 175.36203 \n",
      "\n",
      "Epoch: 43 , TRAIN loss: 176.39609 , Temperature: 0.7529172130914742\n",
      "Epoch: 43 , TRAIN loss: 176.39609 , Temperature: 0.7529172130914742\n",
      "Epoch: 43 , TRAIN loss: 176.39609 , Temperature: 0.7529172130914742\n",
      "Eval Loss: 175.241 \n",
      "\n",
      "Eval Loss: 175.241 \n",
      "\n",
      "Eval Loss: 175.241 \n",
      "\n",
      "Epoch: 44 , TRAIN loss: 179.29848 , Temperature: 0.74304401236194\n",
      "Epoch: 44 , TRAIN loss: 179.29848 , Temperature: 0.74304401236194\n",
      "Epoch: 44 , TRAIN loss: 179.29848 , Temperature: 0.74304401236194\n",
      "Eval Loss: 175.32228 \n",
      "\n",
      "Eval Loss: 175.32228 \n",
      "\n",
      "Eval Loss: 175.32228 \n",
      "\n",
      "Epoch: 45 , TRAIN loss: 172.85764 , Temperature: 0.7330803244117685\n",
      "Epoch: 45 , TRAIN loss: 172.85764 , Temperature: 0.7330803244117685\n",
      "Epoch: 45 , TRAIN loss: 172.85764 , Temperature: 0.7330803244117685\n",
      "Eval Loss: 174.77113 \n",
      "\n",
      "Eval Loss: 174.77113 \n",
      "\n",
      "Eval Loss: 174.77113 \n",
      "\n",
      "Epoch: 46 , TRAIN loss: 174.27164 , Temperature: 0.7230332998501351\n",
      "Epoch: 46 , TRAIN loss: 174.27164 , Temperature: 0.7230332998501351\n",
      "Epoch: 46 , TRAIN loss: 174.27164 , Temperature: 0.7230332998501351\n",
      "Eval Loss: 175.218 \n",
      "\n",
      "Eval Loss: 175.218 \n",
      "\n",
      "Eval Loss: 175.218 \n",
      "\n",
      "Epoch: 47 , TRAIN loss: 176.36441 , Temperature: 0.7129100668311394\n",
      "Epoch: 47 , TRAIN loss: 176.36441 , Temperature: 0.7129100668311394\n",
      "Epoch: 47 , TRAIN loss: 176.36441 , Temperature: 0.7129100668311394\n",
      "Eval Loss: 174.86433 \n",
      "\n",
      "Eval Loss: 174.86433 \n",
      "\n",
      "Eval Loss: 174.86433 \n",
      "\n",
      "Epoch: 48 , TRAIN loss: 169.47008 , Temperature: 0.7027177228683977\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 48 , TRAIN loss: 169.47008 , Temperature: 0.7027177228683977\n",
      "Epoch: 48 , TRAIN loss: 169.47008 , Temperature: 0.7027177228683977\n",
      "Eval Loss: 174.8649 \n",
      "\n",
      "Eval Loss: 174.8649 \n",
      "\n",
      "Eval Loss: 174.8649 \n",
      "\n",
      "Epoch: 49 , TRAIN loss: 171.39128 , Temperature: 0.6924633268086435\n",
      "Epoch: 49 , TRAIN loss: 171.39128 , Temperature: 0.6924633268086435\n",
      "Epoch: 49 , TRAIN loss: 171.39128 , Temperature: 0.6924633268086435\n",
      "Eval Loss: 174.74059 \n",
      "\n",
      "Eval Loss: 174.74059 \n",
      "\n",
      "Eval Loss: 174.74059 \n",
      "\n",
      "Epoch: 50 , TRAIN loss: 174.15659 , Temperature: 0.6821538909764523\n",
      "Epoch: 50 , TRAIN loss: 174.15659 , Temperature: 0.6821538909764523\n",
      "Epoch: 50 , TRAIN loss: 174.15659 , Temperature: 0.6821538909764523\n",
      "Eval Loss: 174.45892 \n",
      "\n",
      "Eval Loss: 174.45892 \n",
      "\n",
      "Eval Loss: 174.45892 \n",
      "\n",
      "Epoch: 51 , TRAIN loss: 174.13226 , Temperature: 0.6717963735016784\n",
      "Epoch: 51 , TRAIN loss: 174.13226 , Temperature: 0.6717963735016784\n",
      "Epoch: 51 , TRAIN loss: 174.13226 , Temperature: 0.6717963735016784\n",
      "Eval Loss: 174.35344 \n",
      "\n",
      "Eval Loss: 174.35344 \n",
      "\n",
      "Eval Loss: 174.35344 \n",
      "\n",
      "Epoch: 52 , TRAIN loss: 175.11208 , Temperature: 0.6613976708406429\n",
      "Epoch: 52 , TRAIN loss: 175.11208 , Temperature: 0.6613976708406429\n",
      "Epoch: 52 , TRAIN loss: 175.11208 , Temperature: 0.6613976708406429\n",
      "Eval Loss: 174.1819 \n",
      "\n",
      "Eval Loss: 174.1819 \n",
      "\n",
      "Eval Loss: 174.1819 \n",
      "\n",
      "Epoch: 53 , TRAIN loss: 170.14043 , Temperature: 0.6509646105015452\n",
      "Epoch: 53 , TRAIN loss: 170.14043 , Temperature: 0.6509646105015452\n",
      "Epoch: 53 , TRAIN loss: 170.14043 , Temperature: 0.6509646105015452\n",
      "Eval Loss: 174.61488 \n",
      "\n",
      "Eval Loss: 174.61488 \n",
      "\n",
      "Eval Loss: 174.61488 \n",
      "\n",
      "Epoch: 54 , TRAIN loss: 175.18777 , Temperature: 0.6405039439839885\n",
      "Epoch: 54 , TRAIN loss: 175.18777 , Temperature: 0.6405039439839885\n",
      "Epoch: 54 , TRAIN loss: 175.18777 , Temperature: 0.6405039439839885\n",
      "Eval Loss: 174.39209 \n",
      "\n",
      "Eval Loss: 174.39209 \n",
      "\n",
      "Eval Loss: 174.39209 \n",
      "\n",
      "Epoch: 55 , TRAIN loss: 171.03917 , Temperature: 0.6300223399419125\n",
      "Epoch: 55 , TRAIN loss: 171.03917 , Temperature: 0.6300223399419125\n",
      "Epoch: 55 , TRAIN loss: 171.03917 , Temperature: 0.6300223399419125\n",
      "Eval Loss: 174.3662 \n",
      "\n",
      "Eval Loss: 174.3662 \n",
      "\n",
      "Eval Loss: 174.3662 \n",
      "\n",
      "Epoch: 56 , TRAIN loss: 171.82672 , Temperature: 0.6195263775786136\n",
      "Epoch: 56 , TRAIN loss: 171.82672 , Temperature: 0.6195263775786136\n",
      "Epoch: 56 , TRAIN loss: 171.82672 , Temperature: 0.6195263775786136\n",
      "Eval Loss: 174.14452 \n",
      "\n",
      "Eval Loss: 174.14452 \n",
      "\n",
      "Eval Loss: 174.14452 \n",
      "\n",
      "Epoch: 57 , TRAIN loss: 173.78212 , Temperature: 0.609022540281914\n",
      "Epoch: 57 , TRAIN loss: 173.78212 , Temperature: 0.609022540281914\n",
      "Epoch: 57 , TRAIN loss: 173.78212 , Temperature: 0.609022540281914\n",
      "Eval Loss: 173.98553 \n",
      "\n",
      "Eval Loss: 173.98553 \n",
      "\n",
      "Eval Loss: 173.98553 \n",
      "\n",
      "Epoch: 58 , TRAIN loss: 173.27303 , Temperature: 0.5985172095069092\n",
      "Epoch: 58 , TRAIN loss: 173.27303 , Temperature: 0.5985172095069092\n",
      "Epoch: 58 , TRAIN loss: 173.27303 , Temperature: 0.5985172095069092\n",
      "Eval Loss: 173.95375 \n",
      "\n",
      "Eval Loss: 173.95375 \n",
      "\n",
      "Eval Loss: 173.95375 \n",
      "\n",
      "Epoch: 59 , TRAIN loss: 169.38342 , Temperature: 0.5880166589130854\n",
      "Epoch: 59 , TRAIN loss: 169.38342 , Temperature: 0.5880166589130854\n",
      "Epoch: 59 , TRAIN loss: 169.38342 , Temperature: 0.5880166589130854\n",
      "Eval Loss: 174.06482 \n",
      "\n",
      "Eval Loss: 174.06482 \n",
      "\n",
      "Eval Loss: 174.06482 \n",
      "\n",
      "Epoch: 60 , TRAIN loss: 174.96268 , Temperature: 0.5775270487619548\n",
      "Epoch: 60 , TRAIN loss: 174.96268 , Temperature: 0.5775270487619548\n",
      "Epoch: 60 , TRAIN loss: 174.96268 , Temperature: 0.5775270487619548\n",
      "Eval Loss: 173.57787 \n",
      "\n",
      "Eval Loss: 173.57787 \n",
      "\n",
      "Eval Loss: 173.57787 \n",
      "\n",
      "Epoch: 61 , TRAIN loss: 172.9725 , Temperature: 0.5670544205807092\n",
      "Epoch: 61 , TRAIN loss: 172.9725 , Temperature: 0.5670544205807092\n",
      "Epoch: 61 , TRAIN loss: 172.9725 , Temperature: 0.5670544205807092\n",
      "Eval Loss: 174.04526 \n",
      "\n",
      "Eval Loss: 174.04526 \n",
      "\n",
      "Eval Loss: 174.04526 \n",
      "\n",
      "Epoch: 62 , TRAIN loss: 171.10132 , Temperature: 0.556604692096744\n",
      "Epoch: 62 , TRAIN loss: 171.10132 , Temperature: 0.556604692096744\n",
      "Epoch: 62 , TRAIN loss: 171.10132 , Temperature: 0.556604692096744\n",
      "Eval Loss: 173.58028 \n",
      "\n",
      "Eval Loss: 173.58028 \n",
      "\n",
      "Eval Loss: 173.58028 \n",
      "\n",
      "Epoch: 63 , TRAIN loss: 172.89413 , Temperature: 0.5461836524472542\n",
      "Epoch: 63 , TRAIN loss: 172.89413 , Temperature: 0.5461836524472542\n",
      "Epoch: 63 , TRAIN loss: 172.89413 , Temperature: 0.5461836524472542\n",
      "Eval Loss: 173.60333 \n",
      "\n",
      "Eval Loss: 173.60333 \n",
      "\n",
      "Eval Loss: 173.60333 \n",
      "\n",
      "Epoch: 64 , TRAIN loss: 174.63405 , Temperature: 0.5357969576674562\n",
      "Epoch: 64 , TRAIN loss: 174.63405 , Temperature: 0.5357969576674562\n",
      "Epoch: 64 , TRAIN loss: 174.63405 , Temperature: 0.5357969576674562\n",
      "Eval Loss: 173.67941 \n",
      "\n",
      "Eval Loss: 173.67941 \n",
      "\n",
      "Eval Loss: 173.67941 \n",
      "\n",
      "Epoch: 65 , TRAIN loss: 173.0586 , Temperature: 0.5254501264603462\n",
      "Epoch: 65 , TRAIN loss: 173.0586 , Temperature: 0.5254501264603462\n",
      "Epoch: 65 , TRAIN loss: 173.0586 , Temperature: 0.5254501264603462\n",
      "Eval Loss: 173.6269 \n",
      "\n",
      "Eval Loss: 173.6269 \n",
      "\n",
      "Eval Loss: 173.6269 \n",
      "\n",
      "Epoch: 66 , TRAIN loss: 171.43977 , Temperature: 0.5151485362502642\n",
      "Epoch: 66 , TRAIN loss: 171.43977 , Temperature: 0.5151485362502642\n",
      "Epoch: 66 , TRAIN loss: 171.43977 , Temperature: 0.5151485362502642\n",
      "Eval Loss: 173.32936 \n",
      "\n",
      "Eval Loss: 173.32936 \n",
      "\n",
      "Eval Loss: 173.32936 \n",
      "\n",
      "Epoch: 67 , TRAIN loss: 178.52104 , Temperature: 0.5048974195219025\n",
      "Epoch: 67 , TRAIN loss: 178.52104 , Temperature: 0.5048974195219025\n",
      "Epoch: 67 , TRAIN loss: 178.52104 , Temperature: 0.5048974195219025\n",
      "Eval Loss: 173.46825 \n",
      "\n",
      "Eval Loss: 173.46825 \n",
      "\n",
      "Eval Loss: 173.46825 \n",
      "\n",
      "Epoch: 68 , TRAIN loss: 172.51602 , Temperature: 0.5\n",
      "Epoch: 68 , TRAIN loss: 172.51602 , Temperature: 0.5\n",
      "Epoch: 68 , TRAIN loss: 172.51602 , Temperature: 0.5\n",
      "Eval Loss: 173.41962 \n",
      "\n",
      "Eval Loss: 173.41962 \n",
      "\n",
      "Eval Loss: 173.41962 \n",
      "\n",
      "Epoch: 69 , TRAIN loss: 172.0632 , Temperature: 0.5\n",
      "Epoch: 69 , TRAIN loss: 172.0632 , Temperature: 0.5\n",
      "Epoch: 69 , TRAIN loss: 172.0632 , Temperature: 0.5\n",
      "Eval Loss: 173.48105 \n",
      "\n",
      "Eval Loss: 173.48105 \n",
      "\n",
      "Eval Loss: 173.48105 \n",
      "\n",
      "Epoch: 70 , TRAIN loss: 174.34991 , Temperature: 0.5\n",
      "Epoch: 70 , TRAIN loss: 174.34991 , Temperature: 0.5\n",
      "Epoch: 70 , TRAIN loss: 174.34991 , Temperature: 0.5\n",
      "Eval Loss: 173.27168 \n",
      "\n",
      "Eval Loss: 173.27168 \n",
      "\n",
      "Eval Loss: 173.27168 \n",
      "\n",
      "Epoch: 71 , TRAIN loss: 175.90549 , Temperature: 0.5\n",
      "Epoch: 71 , TRAIN loss: 175.90549 , Temperature: 0.5\n",
      "Epoch: 71 , TRAIN loss: 175.90549 , Temperature: 0.5\n",
      "Eval Loss: 173.16269 \n",
      "\n",
      "Eval Loss: 173.16269 \n",
      "\n",
      "Eval Loss: 173.16269 \n",
      "\n",
      "Epoch: 72 , TRAIN loss: 171.69522 , Temperature: 0.5\n",
      "Epoch: 72 , TRAIN loss: 171.69522 , Temperature: 0.5\n",
      "Epoch: 72 , TRAIN loss: 171.69522 , Temperature: 0.5\n",
      "Eval Loss: 173.14604 \n",
      "\n",
      "Eval Loss: 173.14604 \n",
      "\n",
      "Eval Loss: 173.14604 \n",
      "\n",
      "Epoch: 73 , TRAIN loss: 174.16461 , Temperature: 0.5\n",
      "Epoch: 73 , TRAIN loss: 174.16461 , Temperature: 0.5\n",
      "Epoch: 73 , TRAIN loss: 174.16461 , Temperature: 0.5\n",
      "Eval Loss: 173.11865 \n",
      "\n",
      "Eval Loss: 173.11865 \n",
      "\n",
      "Eval Loss: 173.11865 \n",
      "\n",
      "Epoch: 74 , TRAIN loss: 172.67331 , Temperature: 0.5\n",
      "Epoch: 74 , TRAIN loss: 172.67331 , Temperature: 0.5\n",
      "Epoch: 74 , TRAIN loss: 172.67331 , Temperature: 0.5\n",
      "Eval Loss: 173.05666 \n",
      "\n",
      "Eval Loss: 173.05666 \n",
      "\n",
      "Eval Loss: 173.05666 \n",
      "\n",
      "Epoch: 75 , TRAIN loss: 176.88733 , Temperature: 0.5\n",
      "Epoch: 75 , TRAIN loss: 176.88733 , Temperature: 0.5\n",
      "Epoch: 75 , TRAIN loss: 176.88733 , Temperature: 0.5\n",
      "Eval Loss: 173.06396 \n",
      "\n",
      "Eval Loss: 173.06396 \n",
      "\n",
      "Eval Loss: 173.06396 \n",
      "\n",
      "Epoch: 76 , TRAIN loss: 171.48334 , Temperature: 0.5\n",
      "Epoch: 76 , TRAIN loss: 171.48334 , Temperature: 0.5\n",
      "Epoch: 76 , TRAIN loss: 171.48334 , Temperature: 0.5\n",
      "Eval Loss: 172.81145 \n",
      "\n",
      "Eval Loss: 172.81145 \n",
      "\n",
      "Eval Loss: 172.81145 \n",
      "\n",
      "Epoch: 77 , TRAIN loss: 168.52563 , Temperature: 0.5\n",
      "Epoch: 77 , TRAIN loss: 168.52563 , Temperature: 0.5\n",
      "Epoch: 77 , TRAIN loss: 168.52563 , Temperature: 0.5\n",
      "Eval Loss: 172.88803 \n",
      "\n",
      "Eval Loss: 172.88803 \n",
      "\n",
      "Eval Loss: 172.88803 \n",
      "\n",
      "Epoch: 78 , TRAIN loss: 173.72249 , Temperature: 0.5\n",
      "Epoch: 78 , TRAIN loss: 173.72249 , Temperature: 0.5\n",
      "Epoch: 78 , TRAIN loss: 173.72249 , Temperature: 0.5\n",
      "Eval Loss: 172.80031 \n",
      "\n",
      "Eval Loss: 172.80031 \n",
      "\n",
      "Eval Loss: 172.80031 \n",
      "\n",
      "Epoch: 79 , TRAIN loss: 175.05774 , Temperature: 0.5\n",
      "Epoch: 79 , TRAIN loss: 175.05774 , Temperature: 0.5\n",
      "Epoch: 79 , TRAIN loss: 175.05774 , Temperature: 0.5\n",
      "Eval Loss: 172.64458 \n",
      "\n",
      "Eval Loss: 172.64458 \n",
      "\n",
      "Eval Loss: 172.64458 \n",
      "\n",
      "Epoch: 80 , TRAIN loss: 172.05139 , Temperature: 0.5\n",
      "Epoch: 80 , TRAIN loss: 172.05139 , Temperature: 0.5\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 80 , TRAIN loss: 172.05139 , Temperature: 0.5\n",
      "Eval Loss: 173.15005 \n",
      "\n",
      "Eval Loss: 173.15005 \n",
      "\n",
      "Eval Loss: 173.15005 \n",
      "\n",
      "Epoch: 81 , TRAIN loss: 172.23923 , Temperature: 0.5\n",
      "Epoch: 81 , TRAIN loss: 172.23923 , Temperature: 0.5\n",
      "Epoch: 81 , TRAIN loss: 172.23923 , Temperature: 0.5\n",
      "Eval Loss: 172.63911 \n",
      "\n",
      "Eval Loss: 172.63911 \n",
      "\n",
      "Eval Loss: 172.63911 \n",
      "\n",
      "Epoch: 83 , TRAIN loss: 172.92651 , Temperature: 0.5\n",
      "Eval Loss: 172.52286 \n",
      "\n",
      "Epoch: 84 , TRAIN loss: 170.39667 , Temperature: 0.5\n",
      "Eval Loss: 172.58745 \n",
      "\n",
      "Epoch: 85 , TRAIN loss: 169.82507 , Temperature: 0.5\n",
      "Eval Loss: 172.43915 \n",
      "\n",
      "Epoch: 86 , TRAIN loss: 172.76514 , Temperature: 0.5\n",
      "Eval Loss: 172.3853 \n",
      "\n",
      "Epoch: 87 , TRAIN loss: 171.57738 , Temperature: 0.5\n",
      "Eval Loss: 172.30728 \n",
      "\n",
      "Epoch: 88 , TRAIN loss: 172.87787 , Temperature: 0.5\n",
      "Eval Loss: 172.42871 \n",
      "\n",
      "Epoch: 89 , TRAIN loss: 171.5992 , Temperature: 0.5\n",
      "Eval Loss: 172.26154 \n",
      "\n",
      "Epoch: 90 , TRAIN loss: 168.21303 , Temperature: 0.5\n",
      "Eval Loss: 172.19038 \n",
      "\n",
      "Epoch: 91 , TRAIN loss: 171.00125 , Temperature: 0.5\n",
      "Eval Loss: 172.0547 \n",
      "\n",
      "Epoch: 92 , TRAIN loss: 173.19826 , Temperature: 0.5\n",
      "Eval Loss: 172.22894 \n",
      "\n",
      "Epoch: 93 , TRAIN loss: 173.17082 , Temperature: 0.5\n",
      "Eval Loss: 172.3416 \n",
      "\n",
      "Epoch: 94 , TRAIN loss: 173.84677 , Temperature: 0.5\n",
      "Eval Loss: 171.98016 \n",
      "\n",
      "Epoch: 95 , TRAIN loss: 175.77722 , Temperature: 0.5\n",
      "Eval Loss: 171.86363 \n",
      "\n",
      "Epoch: 96 , TRAIN loss: 172.18658 , Temperature: 0.5\n",
      "Eval Loss: 171.57275 \n",
      "\n",
      "Epoch: 97 , TRAIN loss: 169.5189 , Temperature: 0.5\n",
      "Eval Loss: 171.79364 \n",
      "\n",
      "Epoch: 98 , TRAIN loss: 169.80823 , Temperature: 0.5\n",
      "Eval Loss: 171.75197 \n",
      "\n",
      "Epoch: 99 , TRAIN loss: 172.32935 , Temperature: 0.5\n",
      "Eval Loss: 171.8062 \n",
      "\n",
      "Epoch: 100 , TRAIN loss: 169.79044 , Temperature: 0.5\n",
      "Eval Loss: 171.65459 \n",
      "\n",
      "Epoch: 101 , TRAIN loss: 171.969 , Temperature: 0.5\n",
      "Eval Loss: 171.97168 \n",
      "\n",
      "Epoch: 102 , TRAIN loss: 170.96527 , Temperature: 0.5\n",
      "Eval Loss: 171.4877 \n",
      "\n",
      "Epoch: 103 , TRAIN loss: 168.14055 , Temperature: 0.5\n",
      "Eval Loss: 171.63568 \n",
      "\n",
      "Epoch: 104 , TRAIN loss: 170.14471 , Temperature: 0.5\n",
      "Eval Loss: 171.67625 \n",
      "\n",
      "Epoch: 105 , TRAIN loss: 172.2569 , Temperature: 0.5\n",
      "Eval Loss: 171.44037 \n",
      "\n",
      "Epoch: 106 , TRAIN loss: 169.6297 , Temperature: 0.5\n",
      "Eval Loss: 171.81468 \n",
      "\n",
      "Epoch: 107 , TRAIN loss: 170.8016 , Temperature: 0.5\n",
      "Eval Loss: 171.37704 \n",
      "\n",
      "Epoch: 108 , TRAIN loss: 171.62402 , Temperature: 0.5\n",
      "Eval Loss: 171.61223 \n",
      "\n",
      "Epoch: 109 , TRAIN loss: 168.9326 , Temperature: 0.5\n",
      "Eval Loss: 171.2698 \n",
      "\n",
      "Epoch: 110 , TRAIN loss: 168.3006 , Temperature: 0.5\n",
      "Eval Loss: 171.31754 \n",
      "\n",
      "Epoch: 111 , TRAIN loss: 175.29173 , Temperature: 0.5\n",
      "Eval Loss: 171.34674 \n",
      "\n",
      "Epoch: 112 , TRAIN loss: 169.33163 , Temperature: 0.5\n",
      "Eval Loss: 171.29346 \n",
      "\n",
      "Epoch: 113 , TRAIN loss: 170.48087 , Temperature: 0.5\n",
      "Eval Loss: 171.32281 \n",
      "\n",
      "Epoch: 114 , TRAIN loss: 168.71515 , Temperature: 0.5\n",
      "Eval Loss: 171.37836 \n",
      "\n",
      "Epoch: 115 , TRAIN loss: 171.1796 , Temperature: 0.5\n",
      "Eval Loss: 171.13441 \n",
      "\n",
      "Epoch: 116 , TRAIN loss: 173.3055 , Temperature: 0.5\n",
      "Eval Loss: 171.07916 \n",
      "\n",
      "Epoch: 117 , TRAIN loss: 169.52194 , Temperature: 0.5\n",
      "Eval Loss: 171.22803 \n",
      "\n",
      "Epoch: 118 , TRAIN loss: 171.73375 , Temperature: 0.5\n",
      "Eval Loss: 171.11284 \n",
      "\n",
      "Epoch: 119 , TRAIN loss: 170.52943 , Temperature: 0.5\n",
      "Eval Loss: 171.07472 \n",
      "\n",
      "Epoch: 120 , TRAIN loss: 175.2594 , Temperature: 0.5\n",
      "Eval Loss: 170.99855 \n",
      "\n",
      "Epoch: 121 , TRAIN loss: 168.57071 , Temperature: 0.5\n",
      "Eval Loss: 171.07616 \n",
      "\n",
      "Epoch: 122 , TRAIN loss: 170.13214 , Temperature: 0.5\n",
      "Eval Loss: 170.97832 \n",
      "\n",
      "Epoch: 123 , TRAIN loss: 170.14641 , Temperature: 0.5\n",
      "Eval Loss: 171.07822 \n",
      "\n",
      "Epoch: 124 , TRAIN loss: 170.36018 , Temperature: 0.5\n",
      "Eval Loss: 170.90332 \n",
      "\n",
      "Epoch: 125 , TRAIN loss: 171.58618 , Temperature: 0.5\n",
      "Eval Loss: 171.03633 \n",
      "\n",
      "Epoch: 126 , TRAIN loss: 170.00186 , Temperature: 0.5\n",
      "Eval Loss: 170.78802 \n",
      "\n",
      "Epoch: 127 , TRAIN loss: 169.2663 , Temperature: 0.5\n",
      "Eval Loss: 170.89006 \n",
      "\n",
      "Epoch: 128 , TRAIN loss: 169.8675 , Temperature: 0.5\n",
      "Eval Loss: 170.74904 \n",
      "\n",
      "Epoch: 129 , TRAIN loss: 164.47319 , Temperature: 0.5\n",
      "Eval Loss: 170.69568 \n",
      "\n",
      "Epoch: 130 , TRAIN loss: 166.60895 , Temperature: 0.5\n",
      "Eval Loss: 170.52203 \n",
      "\n",
      "Epoch: 131 , TRAIN loss: 167.27783 , Temperature: 0.5\n",
      "Eval Loss: 170.69928 \n",
      "\n",
      "Epoch: 132 , TRAIN loss: 168.0321 , Temperature: 0.5\n",
      "Eval Loss: 170.69049 \n",
      "\n",
      "Epoch: 133 , TRAIN loss: 171.77798 , Temperature: 0.5\n",
      "Eval Loss: 170.58992 \n",
      "\n",
      "Epoch: 134 , TRAIN loss: 169.65994 , Temperature: 0.5\n",
      "Eval Loss: 170.75653 \n",
      "\n",
      "Epoch: 135 , TRAIN loss: 170.10175 , Temperature: 0.5\n"
     ]
    }
   ],
   "source": [
    "# %%time\n",
    "\n",
    "np.set_printoptions(precision=4,linewidth=200)\n",
    "model = DiscreteVAE(PARAMS)\n",
    "learning_rate = tf.Variable(PARAMS[\"learning_rate\"], trainable=False, name=\"LR\")\n",
    "\n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=5e-4)\n",
    "\n",
    "# data\n",
    "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
    "x_train = x_train.astype('float32') / 255.\n",
    "x_test = x_test.astype('float32') / 255.\n",
    "x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))\n",
    "x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))\n",
    "\n",
    "TRAIN_BUF = 60000\n",
    "BATCH_SIZE = 100\n",
    "TEST_BUF = 10000\n",
    "\n",
    "train_dataset = tf.data.Dataset.from_tensor_slices(x_train).shuffle(TRAIN_BUF).batch(BATCH_SIZE)\n",
    "test_dataset = tf.data.Dataset.from_tensor_slices(x_test).shuffle(TEST_BUF).batch(BATCH_SIZE)\n",
    "\n",
    "# temperature\n",
    "tau = PARAMS[\"init_temperature\"]\n",
    "anneal_rate = PARAMS[\"anneal_rate\"]\n",
    "min_temperature = PARAMS[\"min_temperature\"]\n",
    "\n",
    "results = []\n",
    "\n",
    "# Train\n",
    "for epoch in range(1, PARAMS[\"nb_epoch\"] + 1):\n",
    "    \n",
    "    # this is only needed for the standard Gumbel softmax trick\n",
    "    tau = np.maximum(tau * np.exp(-anneal_rate*epoch), min_temperature)\n",
    "\n",
    "    for train_x in train_dataset:\n",
    "        gradients, loss = compute_gradients(model, train_x, tau, hard=PARAMS[\"hard\"])\n",
    "        apply_gradients(optimizer, gradients, model.trainable_variables)\n",
    "\n",
    "    print(\"Epoch:\", epoch, \", TRAIN loss:\", loss.numpy(), \", Temperature:\", tau)\n",
    "\n",
    "    if epoch % 1 == 0:\n",
    "        losses = []\n",
    "        for test_x in test_dataset:\n",
    "            losses.append(gumbel_loss(model, test_x, tau, hard=True))\n",
    "        eval_loss = np.mean(losses)\n",
    "        results.append(eval_loss)\n",
    "        print(\"Eval Loss:\", eval_loss, \"\\n\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
