{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pathlib\n",
    "import sys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_path = pathlib.Path(os.getcwd())\n",
    "base_path = str(base_path.parent)\n",
    "sys.path = [base_path] + sys.path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import random as python_random\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import backend as K\n",
    "from tensorflow.keras import layers\n",
    "from tensorflow.keras.utils import model_to_dot\n",
    "from tensorflow.python.keras.utils import tf_utils\n",
    "from IPython.display import SVG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.set(context=\"notebook\", style=\"darkgrid\", palette=\"deep\", font=\"sans-serif\", font_scale=1.0, color_codes=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs(\"./img/\", exist_ok=True)\n",
    "os.makedirs(\"./model/\", exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 6902\n",
    "\n",
    "# The below is necessary for starting Numpy generated random numbers\n",
    "# in a well-defined initial state.\n",
    "np.random.seed(seed)\n",
    "\n",
    "# The below is necessary for starting core Python generated random numbers\n",
    "# in a well-defined state.\n",
    "python_random.seed(seed)\n",
    "\n",
    "# The below set_seed() will make random number generation\n",
    "# in the TensorFlow backend have a well-defined initial state.\n",
    "# For further details, see:\n",
    "# https://www.tensorflow.org/api_docs/python/tf/random/set_seed\n",
    "tf.random.set_seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Driver Version: b'419.17'\n",
      "Device 0 : b'GeForce GTX 1070 Ti'\n"
     ]
    }
   ],
   "source": [
    "from pynvml import *\n",
    "\n",
    "try:\n",
    "    nvmlInit()\n",
    "    print(\"Driver Version:\", nvmlSystemGetDriverVersion())\n",
    "    deviceCount = nvmlDeviceGetCount()\n",
    "    for i in range(deviceCount):\n",
    "        handle = nvmlDeviceGetHandleByIndex(i)\n",
    "        print(\"Device\", i, \":\", nvmlDeviceGetName(handle))\n",
    "    nvmlShutdown()\n",
    "except NVMLError as error:\n",
    "    print(error)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "python_version: 3.6.10.final.0 (64 bit)\n",
      "cpuinfo_version: [7, 0, 0]\n",
      "cpuinfo_version_string: 7.0.0\n",
      "arch: X86_64\n",
      "bits: 64\n",
      "count: 12\n",
      "arch_string_raw: AMD64\n",
      "vendor_id_raw: GenuineIntel\n",
      "brand_raw: Intel(R) Core(TM) i7-8700 CPU @ 3.20GHz\n",
      "hz_advertised_friendly: 3.2000 GHz\n",
      "hz_actual_friendly: 3.1920 GHz\n",
      "hz_advertised: [3200000000, 0]\n",
      "hz_actual: [3192000000, 0]\n",
      "l2_cache_size: 1572864\n",
      "stepping: 10\n",
      "model: 158\n",
      "family: 6\n",
      "l3_cache_size: 12582912\n",
      "flags: ['3dnow', '3dnowprefetch', 'abm', 'acpi', 'adx', 'aes', 'apic', 'avx', 'avx2', 'bmi1', 'bmi2', 'clflush', 'clflushopt', 'cmov', 'cx16', 'cx8', 'de', 'dtes64', 'dts', 'erms', 'est', 'f16c', 'fma', 'fpu', 'fxsr', 'hle', 'ht', 'hypervisor', 'ia64', 'invpcid', 'lahf_lm', 'mca', 'mce', 'mmx', 'movbe', 'mpx', 'msr', 'mtrr', 'osxsave', 'pae', 'pat', 'pbe', 'pcid', 'pclmulqdq', 'pdcm', 'pge', 'pni', 'popcnt', 'pse', 'pse36', 'rdrnd', 'rdseed', 'rtm', 'sep', 'serial', 'smap', 'smep', 'ss', 'sse', 'sse2', 'sse4_1', 'sse4_2', 'ssse3', 'tm', 'tm2', 'tsc', 'vme', 'x2apic', 'xsave', 'xtpr']\n",
      "l2_cache_line_size: 256\n",
      "l2_cache_associativity: 6\n"
     ]
    }
   ],
   "source": [
    "from cpuinfo import get_cpu_info\n",
    "\n",
    "for key, value in get_cpu_info().items():\n",
    "    print(\"{0}: {1}\".format(key, value))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "memory: 15.92GB\n"
     ]
    }
   ],
   "source": [
    "import psutil \n",
    "\n",
    "mem = psutil.virtual_memory() \n",
    "print(\"memory: {0:.2f}GB\".format(mem.total / 1024**3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "SKIP_TRAIN = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Args(object):\n",
    "    model_type = \"medium-tied\"\n",
    "    seq_length = 35\n",
    "    batch_size = 20\n",
    "    lr = 1.0\n",
    "    lr_decay = 1.2\n",
    "    lr_epoch = 6\n",
    "    max_grad_norm = 5.0\n",
    "    hidden_size = 650\n",
    "    weight_decay = 1.0e-7\n",
    "    drop_rate_i = 0.35  # dropout rate for lstm x\n",
    "    drop_rate_o = 0.35  # dropout rate for dense\n",
    "    drop_rate_x = 0.20  # dropout rate for embeddings\n",
    "    drop_rate_h = 0.20  # dropout rate for lstm h\n",
    "    implementation = 2  # tied\n",
    "    max_epochs = 39\n",
    "    save_model = \"model/{}.h5\".format(model_type)\n",
    "    \n",
    "# class Args(object):\n",
    "#     model_type = \"medium-untied\"\n",
    "#     seq_length = 35\n",
    "#     batch_size = 20\n",
    "#     lr = 1.0\n",
    "#     lr_decay = 1.2\n",
    "#     lr_epoch = 6\n",
    "#     max_grad_norm = 5.0\n",
    "#     hidden_size = 650\n",
    "#     weight_decay = 1.0e-7\n",
    "#     drop_rate_i = 0.35  # dropout rate for lstm x\n",
    "#     drop_rate_o = 0.35  # dropout rate for dense\n",
    "#     drop_rate_x = 0.20  # dropout rate for embeddings\n",
    "#     drop_rate_h = 0.20  # dropout rate for lstm h\n",
    "#     implementation = 1  # untied\n",
    "#     max_epochs = 39    \n",
    "#     save_model = \"model/{}.h5\".format(model_type)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "class EmbeddingDropout(keras.layers.Embedding):\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        input_dim,\n",
    "        output_dim,\n",
    "        rate,\n",
    "        embeddings_initializer=\"uniform\",\n",
    "        embeddings_regularizer=None,\n",
    "        activity_regularizer=None,\n",
    "        embeddings_constraint=None,\n",
    "        mask_zero=False,\n",
    "        input_length=None,\n",
    "        seed=None,\n",
    "        **kwargs\n",
    "    ):\n",
    "        super(EmbeddingDropout, self).__init__(\n",
    "            input_dim,\n",
    "            output_dim,\n",
    "            embeddings_initializer=embeddings_initializer,\n",
    "            embeddings_regularizer=embeddings_regularizer,\n",
    "            activity_regularizer=activity_regularizer,\n",
    "            embeddings_constraint=embeddings_constraint,\n",
    "            mask_zero=mask_zero,\n",
    "            input_length=input_length,\n",
    "            **kwargs\n",
    "        )\n",
    "        self.rate = min(1.0, max(0.0, rate))\n",
    "        self.seed = seed\n",
    "\n",
    "    def call(self, inputs, training=None):\n",
    "        if training is None:\n",
    "            training = K.learning_phase()\n",
    "\n",
    "        retain_prob = 1.0 - self.rate\n",
    "\n",
    "        dtype = K.dtype(inputs)\n",
    "        if dtype != \"int32\" and dtype != \"int64\":\n",
    "            inputs = tf.cast(inputs, \"int32\")\n",
    "\n",
    "        def original_embedding():\n",
    "            def in_train():\n",
    "                binary = K.random_binomial(\n",
    "                    (self.input_dim, 1),\n",
    "                    p=retain_prob,\n",
    "                    dtype=K.dtype(self.embeddings),\n",
    "                    seed=self.seed,\n",
    "                )\n",
    "                dropped_embeddings = binary * self.embeddings / retain_prob\n",
    "                out = tf.nn.embedding_lookup(dropped_embeddings, inputs)\n",
    "                return out\n",
    "\n",
    "            def in_test():\n",
    "                out = tf.nn.embedding_lookup(self.embeddings, inputs)\n",
    "                return out\n",
    "\n",
    "            return tf_utils.smart_cond(training, in_train, in_test)\n",
    "\n",
    "        return original_embedding()\n",
    "\n",
    "    def get_config(self):\n",
    "        config = {\"rate\": self.rate, \"seed\": self.seed}\n",
    "        base_config = super(EmbeddingDropout, self).get_config()\n",
    "        return dict(list(base_config.items()) + list(config.items()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cross_sentence_preproc(text):\n",
    "    return text.replace(\"\\n\", \" <eos> \")\n",
    "\n",
    "\n",
    "def tokenize_txt_cross_sentence(textfile):\n",
    "    # open the file as read only\n",
    "    file = open(textfile, \"r\", encoding=\"UTF-8\")\n",
    "    # read all text\n",
    "    text = file.read()\n",
    "    file.close()\n",
    "    print(\"\\nSample of original txt:\\n\\n\", text[:300])\n",
    "\n",
    "    # run text preprocessing\n",
    "    text_proc = cross_sentence_preproc(text)\n",
    "    print(\"\\nSample of processed txt:\\n\\n\", text_proc[:300])\n",
    "    print(\"\\nTotal tokens in text: %d\" % len(text_proc.split()))\n",
    "    print(\"Unique tokens in text: %d\" % len(set(text_proc.split())))\n",
    "\n",
    "    # fit tokenizer\n",
    "    tokenizer = keras.preprocessing.text.Tokenizer(filters=\"\", lower=False)\n",
    "    tokenizer.fit_on_texts([text_proc])\n",
    "    # saving tokenizer\n",
    "    with open(\n",
    "        \"data/cs-{0}.pickle\".format(os.path.splitext(os.path.basename(textfile))[0]), \"wb\",\n",
    "    ) as handle:\n",
    "        pickle.dump(tokenizer, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
    "\n",
    "\n",
    "def txt_to_tensor_cross_sent_LSTM(textfile, tokenizer, batch_size=20, seq_length=35):\n",
    "    PAD_VALUE = 0\n",
    "\n",
    "    # open the file as read only\n",
    "    file = open(textfile, \"r\", encoding=\"UTF-8\")\n",
    "\n",
    "    # read all text\n",
    "    text = file.read()\n",
    "    file.close()\n",
    "    print(\"\\nSample of original txt:\\n\\n\", text[:300])\n",
    "\n",
    "    # run text preprocessing\n",
    "    text_proc = cross_sentence_preproc(text)\n",
    "    print(\"\\nSample of processed txt:\\n\\n\", text_proc[:300])\n",
    "    print(\"\\nTotal tokens in text: %d\" % len(text_proc.split()))\n",
    "    print(\"Unique tokens in text: %d\" % len(set(text_proc.split())))\n",
    "\n",
    "    # load tokenizer\n",
    "    with open(tokenizer, \"rb\") as handle:\n",
    "        tokenizer = pickle.load(handle)\n",
    "\n",
    "    # vocabulary size\n",
    "    vocab_size = len(tokenizer.word_index) + 1\n",
    "    print(\"Vocab size: %d\" % vocab_size)\n",
    "\n",
    "    # coding text\n",
    "    text_coded = tokenizer.texts_to_sequences([text_proc])\n",
    "    text_coded_len = len(text_coded[0])\n",
    "    print(\"Coded text length:\", text_coded_len)\n",
    "\n",
    "    # pad according to current batch size and seq length (stateful training requirement!)\n",
    "    padding_length = (\n",
    "        batch_size * seq_length * ((text_coded_len // (batch_size * seq_length)) + 1)\n",
    "    )\n",
    "    input_array = keras.preprocessing.sequence.pad_sequences(\n",
    "        text_coded, padding=\"post\", maxlen=padding_length, value=PAD_VALUE\n",
    "    )[0, :]\n",
    "    print(\"Padded input array shape:\", input_array.shape)\n",
    "\n",
    "    # creat target array from input array\n",
    "    target_array = input_array.copy()\n",
    "    target_array[0:-1] = input_array[1:]\n",
    "    target_array[-1] = PAD_VALUE  # value for padding\n",
    "\n",
    "    # reshaping input and target array to fit stateful training\n",
    "    # reshaping according to batch_size\n",
    "    input_array = input_array.reshape((batch_size, -1))\n",
    "    target_array = target_array.reshape((batch_size, -1))\n",
    "    # creating list of batches (link: ...)\n",
    "    x_batches = np.split(input_array, input_array.shape[1] // seq_length, axis=1)\n",
    "    y_batches = np.split(target_array, target_array.shape[1] // seq_length, axis=1)\n",
    "    assert len(x_batches) == len(y_batches)\n",
    "\n",
    "    # concatenting list of batches (fit instead of fit generator)\n",
    "    X = np.concatenate(x_batches)\n",
    "    y = np.concatenate(y_batches)\n",
    "    # additional rank for y array (Keras requirement)\n",
    "    y = y.reshape(y.shape[0], y.shape[1], 1)\n",
    "\n",
    "    print(\"Input tensor shape:\", X.shape)\n",
    "    print(\"Target tensor shape:\", y.shape)\n",
    "\n",
    "    return X, y, vocab_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_bayes_model(\n",
    "    input_shape,\n",
    "    batch_size,\n",
    "    vocab_size,\n",
    "    hidden_size,\n",
    "    weight_decay,\n",
    "    drop_rate_i,\n",
    "    drop_rate_o,\n",
    "    drop_rate_x,\n",
    "    drop_rate_h,\n",
    "    implementation,\n",
    "):\n",
    "    inputs = keras.Input(shape=input_shape, batch_size=batch_size)\n",
    "    embedding = EmbeddingDropout(\n",
    "        vocab_size,\n",
    "        hidden_size,\n",
    "        drop_rate_x,\n",
    "        embeddings_regularizer=keras.regularizers.l2(weight_decay),\n",
    "    )\n",
    "    lstm1 = layers.LSTM(\n",
    "        hidden_size,\n",
    "        activation=\"tanh\",\n",
    "        recurrent_activation=\"sigmoid\",\n",
    "        dropout=drop_rate_i,\n",
    "        recurrent_dropout=drop_rate_h,\n",
    "        kernel_regularizer=keras.regularizers.l2(weight_decay),\n",
    "        recurrent_regularizer=keras.regularizers.l2(weight_decay),\n",
    "        return_sequences=True,\n",
    "        stateful=True,\n",
    "        implementation=implementation,\n",
    "    )\n",
    "    lstm2 = layers.LSTM(\n",
    "        hidden_size,\n",
    "        activation=\"tanh\",\n",
    "        recurrent_activation=\"sigmoid\",\n",
    "        dropout=drop_rate_i,\n",
    "        recurrent_dropout=drop_rate_h,\n",
    "        kernel_regularizer=keras.regularizers.l2(weight_decay),\n",
    "        recurrent_regularizer=keras.regularizers.l2(weight_decay),\n",
    "        return_sequences=True,\n",
    "        stateful=True,\n",
    "        implementation=implementation,\n",
    "    )\n",
    "    drop = layers.Dropout(rate=drop_rate_o)\n",
    "    dense = layers.Dense(\n",
    "        vocab_size,\n",
    "        kernel_regularizer=keras.regularizers.l2(weight_decay),\n",
    "        activation=\"softmax\",\n",
    "    )\n",
    "    dist = layers.TimeDistributed(dense)\n",
    "\n",
    "    x = embedding(inputs)\n",
    "    x = lstm1(x)\n",
    "    x = lstm2(x)\n",
    "    x = drop(x)\n",
    "    output = dist(x)\n",
    "\n",
    "    model = keras.Model(inputs=inputs, outputs=output)\n",
    "\n",
    "    return model\n",
    "\n",
    "\n",
    "class gSGD(keras.optimizers.SGD):\n",
    "    def __init__(\n",
    "        self,\n",
    "        learning_rate=0.01,\n",
    "        momentum=0.0,\n",
    "        nesterov=False,\n",
    "        clipglobalnorm=5.0,\n",
    "        name=\"SGD\",\n",
    "        **kwargs\n",
    "    ):\n",
    "        super(gSGD, self).__init__(\n",
    "            learning_rate=learning_rate,\n",
    "            momentum=momentum,\n",
    "            nesterov=nesterov,\n",
    "            name=name,\n",
    "            **kwargs\n",
    "        )\n",
    "        self.clipglobalnorm = clipglobalnorm\n",
    "\n",
    "    def apply_gradients(self, grads_and_vars, name=None):\n",
    "        \"\"\"Apply gradients to variables.\n",
    "        This is the second part of `minimize()`. It returns an `Operation` that\n",
    "        applies gradients.\n",
    "        Args:\n",
    "        grads_and_vars: List of (gradient, variable) pairs.\n",
    "        name: Optional name for the returned operation.  Default to the name\n",
    "            passed to the `Optimizer` constructor.\n",
    "        Returns:\n",
    "        An `Operation` that applies the specified gradients. The `iterations`\n",
    "        will be automatically increased by 1.\n",
    "        Raises:\n",
    "        TypeError: If `grads_and_vars` is malformed.\n",
    "        ValueError: If none of the variables have gradients.\n",
    "        \"\"\"\n",
    "        import functools\n",
    "        from tensorflow.python.keras.optimizer_v2.optimizer_v2 import _filter_grads\n",
    "        from tensorflow.python.distribute import (\n",
    "            distribution_strategy_context as distribute_ctx,\n",
    "        )\n",
    "\n",
    "        grads_and_vars = _filter_grads(grads_and_vars)\n",
    "        var_list = [v for (_, v) in grads_and_vars]\n",
    "        grads = [g for (g, _) in grads_and_vars]\n",
    "        grads, _ = tf.clip_by_global_norm(grads, self.clipglobalnorm)\n",
    "        grads_and_vars = list(zip(grads, var_list))\n",
    "\n",
    "        with K.name_scope(self._name):\n",
    "            # Create iteration if necessary.\n",
    "            with tf.init_scope():\n",
    "                _ = self.iterations\n",
    "                self._create_hypers()\n",
    "                self._create_slots(var_list)\n",
    "\n",
    "            if not grads_and_vars:\n",
    "                # Distribution strategy does not support reducing an empty list of\n",
    "                # gradients\n",
    "                return tf.no_op()\n",
    "            apply_state = self._prepare(var_list)\n",
    "            return distribute_ctx.get_replica_context().merge_call(\n",
    "                functools.partial(self._distributed_apply, apply_state=apply_state),\n",
    "                args=(grads_and_vars,),\n",
    "                kwargs={\"name\": name},\n",
    "            )\n",
    "\n",
    "\n",
    "def sparse_categorical_crossentropy_pad0(y_true, y_pred):\n",
    "    loss = keras.losses.sparse_categorical_crossentropy(y_true, y_pred)\n",
    "\n",
    "    y_true = K.cast(y_true, tf.int64)\n",
    "    not_pad_mask = K.greater(y_true, 0)\n",
    "    not_pad_mask = K.cast(not_pad_mask, K.dtype(loss))\n",
    "    not_pad_mask = K.reshape(not_pad_mask, K.shape(loss))\n",
    "    loss = not_pad_mask * loss\n",
    "\n",
    "    return K.sum(K.mean(loss, axis=0))\n",
    "\n",
    "\n",
    "class SparseCategoricalCrossentropyPad0(keras.metrics.Metric):\n",
    "    def __init__(self, name=\"SparseCategoricalCrossentropyPad0\", **kwargs):\n",
    "        super(SparseCategoricalCrossentropyPad0, self).__init__(name=name, **kwargs)\n",
    "        self.loss = self.add_weight(name=\"loss\", initializer=\"zeros\")\n",
    "        self.num_not_pad = self.add_weight(name=\"num_not_pad\", initializer=\"zeros\")\n",
    "\n",
    "    def update_state(self, y_true, y_pred, sample_weight=None):\n",
    "        loss = keras.losses.sparse_categorical_crossentropy(y_true, y_pred)\n",
    "\n",
    "        y_true = K.cast(y_true, tf.int64)\n",
    "        not_pad_mask = K.greater(y_true, 0)\n",
    "        not_pad_mask = K.cast(not_pad_mask, K.dtype(loss))\n",
    "        not_pad_mask = K.reshape(not_pad_mask, K.shape(loss))\n",
    "        loss = not_pad_mask * loss\n",
    "\n",
    "        self.loss.assign_add(K.sum(loss))\n",
    "        self.num_not_pad.assign_add(K.sum(not_pad_mask))\n",
    "\n",
    "    def result(self):\n",
    "        return self.loss / K.maximum(1.0, self.num_not_pad)\n",
    "\n",
    "    def reset_states(self):\n",
    "        # The state of the metric will be reset at the start of each epoch.\n",
    "        self.loss.assign(0.0)\n",
    "        self.num_not_pad.assign(0.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "args = Args()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Sample of original txt:\n",
      "\n",
      "  aer banknote berlitz calloway centrust cluett fromstein gitano guterman hydro-quebec ipo kia memotec mlx nahb punts rake regatta rubens sim snack-food ssangyong swapo wachter \n",
      " pierre <unk> N years old will join the board as a nonexecutive director nov. N \n",
      " mr. <unk> is chairman of <unk> n.v. the d\n",
      "\n",
      "Sample of processed txt:\n",
      "\n",
      "  aer banknote berlitz calloway centrust cluett fromstein gitano guterman hydro-quebec ipo kia memotec mlx nahb punts rake regatta rubens sim snack-food ssangyong swapo wachter  <eos>  pierre <unk> N years old will join the board as a nonexecutive director nov. N  <eos>  mr. <unk> is chairman of <unk\n",
      "\n",
      "Total tokens in text: 929589\n",
      "Unique tokens in text: 10000\n",
      "\n",
      "Sample of original txt:\n",
      "\n",
      "  aer banknote berlitz calloway centrust cluett fromstein gitano guterman hydro-quebec ipo kia memotec mlx nahb punts rake regatta rubens sim snack-food ssangyong swapo wachter \n",
      " pierre <unk> N years old will join the board as a nonexecutive director nov. N \n",
      " mr. <unk> is chairman of <unk> n.v. the d\n",
      "\n",
      "Sample of processed txt:\n",
      "\n",
      "  aer banknote berlitz calloway centrust cluett fromstein gitano guterman hydro-quebec ipo kia memotec mlx nahb punts rake regatta rubens sim snack-food ssangyong swapo wachter  <eos>  pierre <unk> N years old will join the board as a nonexecutive director nov. N  <eos>  mr. <unk> is chairman of <unk\n",
      "\n",
      "Total tokens in text: 929589\n",
      "Unique tokens in text: 10000\n",
      "Vocab size: 10001\n",
      "Coded text length: 929589\n",
      "Padded input array shape: (929600,)\n",
      "Input tensor shape: (26560, 35)\n",
      "Target tensor shape: (26560, 35, 1)\n",
      "\n",
      "Sample of original txt:\n",
      "\n",
      "  consumers may want to move their telephones a little closer to the tv set \n",
      " <unk> <unk> watching abc 's monday night football can now vote during <unk> for the greatest play in N years from among four or five <unk> <unk> \n",
      " two weeks ago viewers of several nbc <unk> consumer segments started calling\n",
      "\n",
      "Sample of processed txt:\n",
      "\n",
      "  consumers may want to move their telephones a little closer to the tv set  <eos>  <unk> <unk> watching abc 's monday night football can now vote during <unk> for the greatest play in N years from among four or five <unk> <unk>  <eos>  two weeks ago viewers of several nbc <unk> consumer segments sta\n",
      "\n",
      "Total tokens in text: 73760\n",
      "Unique tokens in text: 6022\n",
      "Vocab size: 10001\n",
      "Coded text length: 73760\n",
      "Padded input array shape: (74200,)\n",
      "Input tensor shape: (2120, 35)\n",
      "Target tensor shape: (2120, 35, 1)\n",
      "\n",
      "Sample of original txt:\n",
      "\n",
      "  no it was n't black monday \n",
      " but while the new york stock exchange did n't fall apart friday as the dow jones industrial average plunged N points most of it in the final hour it barely managed to stay this side of chaos \n",
      " some circuit breakers installed after the october N crash failed their first \n",
      "\n",
      "Sample of processed txt:\n",
      "\n",
      "  no it was n't black monday  <eos>  but while the new york stock exchange did n't fall apart friday as the dow jones industrial average plunged N points most of it in the final hour it barely managed to stay this side of chaos  <eos>  some circuit breakers installed after the october N crash failed \n",
      "\n",
      "Total tokens in text: 82430\n",
      "Unique tokens in text: 6049\n",
      "Vocab size: 10001\n",
      "Coded text length: 82430\n",
      "Padded input array shape: (82600,)\n",
      "Input tensor shape: (2360, 35)\n",
      "Target tensor shape: (2360, 35, 1)\n"
     ]
    }
   ],
   "source": [
    "train_txt = \"data/ptb.train.txt\"\n",
    "val_txt = \"data/ptb.valid.txt\"\n",
    "test_txt = \"data/ptb.test.txt\"\n",
    "tokenize_txt_cross_sentence(train_txt)\n",
    "tokenizer = \"data/cs-ptb.train.pickle\"\n",
    "\n",
    "x_train, y_train, vocab_size = txt_to_tensor_cross_sent_LSTM(\n",
    "    train_txt, tokenizer, args.batch_size, args.seq_length\n",
    ")\n",
    "x_val, y_val, _ = txt_to_tensor_cross_sent_LSTM(\n",
    "    val_txt, tokenizer, args.batch_size, args.seq_length\n",
    ")\n",
    "x_test, y_test, _ = txt_to_tensor_cross_sent_LSTM(\n",
    "    test_txt, tokenizer, args.batch_size, args.seq_length\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = create_bayes_model(\n",
    "    (args.seq_length,),\n",
    "    args.batch_size,\n",
    "    vocab_size,\n",
    "    args.hidden_size,\n",
    "    args.weight_decay,\n",
    "    args.drop_rate_i,\n",
    "    args.drop_rate_o,\n",
    "    args.drop_rate_x,\n",
    "    args.drop_rate_h,\n",
    "    args.implementation,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/039\n",
      "learning_rate: 1.0\n",
      "Train on 26560 samples\n",
      "26560/26560 - 211s - loss: 216.8371 - SparseCategoricalCrossentropyPad0: 6.1954 - sparse_categorical_accuracy: 0.1126\n",
      "val_loss: 5.447317123413086\n",
      "val_perplexity: 232.13453674316406\n",
      "Epoch: 002/039\n",
      "learning_rate: 1.0\n",
      "Train on 26560 samples\n",
      "26560/26560 - 211s - loss: 191.2524 - SparseCategoricalCrossentropyPad0: 5.4643 - sparse_categorical_accuracy: 0.1657\n",
      "val_loss: 5.1063032150268555\n",
      "val_perplexity: 165.0590362548828\n",
      "Epoch: 003/039\n",
      "learning_rate: 1.0\n",
      "Train on 26560 samples\n",
      "26560/26560 - 211s - loss: 182.5297 - SparseCategoricalCrossentropyPad0: 5.2150 - sparse_categorical_accuracy: 0.1844\n",
      "val_loss: 4.922767162322998\n",
      "val_perplexity: 137.3822479248047\n",
      "Epoch: 004/039\n",
      "learning_rate: 1.0\n",
      "Train on 26560 samples\n",
      "26560/26560 - 211s - loss: 176.7400 - SparseCategoricalCrossentropyPad0: 5.0495 - sparse_categorical_accuracy: 0.1964\n",
      "val_loss: 4.803992748260498\n",
      "val_perplexity: 121.99655151367188\n",
      "Epoch: 005/039\n",
      "learning_rate: 1.0\n",
      "Train on 26560 samples\n",
      "26560/26560 - 210s - loss: 172.3135 - SparseCategoricalCrossentropyPad0: 4.9229 - sparse_categorical_accuracy: 0.2052\n",
      "val_loss: 4.722590446472168\n",
      "val_perplexity: 112.45919799804688\n",
      "Epoch: 006/039\n",
      "learning_rate: 1.0\n",
      "Train on 26560 samples\n",
      "26560/26560 - 211s - loss: 169.0647 - SparseCategoricalCrossentropyPad0: 4.8300 - sparse_categorical_accuracy: 0.2112\n",
      "val_loss: 4.671114444732666\n",
      "val_perplexity: 106.81671905517578\n",
      "Epoch: 007/039\n",
      "learning_rate: 0.8333333134651184\n",
      "Train on 26560 samples\n",
      "26560/26560 - 211s - loss: 165.1531 - SparseCategoricalCrossentropyPad0: 4.7181 - sparse_categorical_accuracy: 0.2203\n",
      "val_loss: 4.600442409515381\n",
      "val_perplexity: 99.52833557128906\n",
      "Epoch: 008/039\n",
      "learning_rate: 0.6944444179534912\n",
      "Train on 26560 samples\n",
      "26560/26560 - 211s - loss: 161.9577 - SparseCategoricalCrossentropyPad0: 4.6268 - sparse_categorical_accuracy: 0.2274\n",
      "val_loss: 4.553650379180908\n",
      "val_perplexity: 94.97848510742188\n",
      "Epoch: 009/039\n",
      "learning_rate: 0.5787037014961243\n",
      "Train on 26560 samples\n",
      "26560/26560 - 210s - loss: 159.2770 - SparseCategoricalCrossentropyPad0: 4.5502 - sparse_categorical_accuracy: 0.2329\n",
      "val_loss: 4.526878356933594\n",
      "val_perplexity: 92.46945190429688\n",
      "Epoch: 010/039\n",
      "learning_rate: 0.4822530746459961\n",
      "Train on 26560 samples\n",
      "26560/26560 - 211s - loss: 157.0754 - SparseCategoricalCrossentropyPad0: 4.4872 - sparse_categorical_accuracy: 0.2379\n",
      "val_loss: 4.505649566650391\n",
      "val_perplexity: 90.52713012695312\n",
      "Epoch: 011/039\n",
      "learning_rate: 0.4018775522708893\n",
      "Train on 26560 samples\n",
      "26560/26560 - 211s - loss: 155.1796 - SparseCategoricalCrossentropyPad0: 4.4331 - sparse_categorical_accuracy: 0.2419\n",
      "val_loss: 4.48442268371582\n",
      "val_perplexity: 88.62577056884766\n",
      "Epoch: 012/039\n",
      "learning_rate: 0.3348979651927948\n",
      "Train on 26560 samples\n",
      "26560/26560 - 211s - loss: 153.7168 - SparseCategoricalCrossentropyPad0: 4.3912 - sparse_categorical_accuracy: 0.2454\n",
      "val_loss: 4.47651481628418\n",
      "val_perplexity: 87.92769622802734\n",
      "Epoch: 013/039\n",
      "learning_rate: 0.27908164262771606\n",
      "Train on 26560 samples\n",
      "26560/26560 - 210s - loss: 152.2055 - SparseCategoricalCrossentropyPad0: 4.3480 - sparse_categorical_accuracy: 0.2488\n",
      "val_loss: 4.466259479522705\n",
      "val_perplexity: 87.03057098388672\n",
      "Epoch: 014/039\n",
      "learning_rate: 0.23256804049015045\n",
      "Train on 26560 samples\n",
      "26560/26560 - 210s - loss: 151.1743 - SparseCategoricalCrossentropyPad0: 4.3186 - sparse_categorical_accuracy: 0.2506\n",
      "val_loss: 4.456142425537109\n",
      "val_perplexity: 86.1545181274414\n",
      "Epoch: 015/039\n",
      "learning_rate: 0.1938067078590393\n",
      "Train on 26560 samples\n",
      "26560/26560 - 210s - loss: 150.1184 - SparseCategoricalCrossentropyPad0: 4.2884 - sparse_categorical_accuracy: 0.2537\n",
      "val_loss: 4.445494174957275\n",
      "val_perplexity: 85.24198913574219\n",
      "Epoch: 016/039\n",
      "learning_rate: 0.1615055948495865\n",
      "Train on 26560 samples\n",
      "26560/26560 - 211s - loss: 149.3523 - SparseCategoricalCrossentropyPad0: 4.2665 - sparse_categorical_accuracy: 0.2555\n",
      "val_loss: 4.441019058227539\n",
      "val_perplexity: 84.86137390136719\n",
      "Epoch: 017/039\n",
      "learning_rate: 0.13458800315856934\n",
      "Train on 26560 samples\n",
      "26560/26560 - 211s - loss: 148.7161 - SparseCategoricalCrossentropyPad0: 4.2483 - sparse_categorical_accuracy: 0.2567\n",
      "val_loss: 4.437992572784424\n",
      "val_perplexity: 84.60493469238281\n",
      "Epoch: 018/039\n",
      "learning_rate: 0.11215666681528091\n",
      "Train on 26560 samples\n",
      "26560/26560 - 211s - loss: 148.1839 - SparseCategoricalCrossentropyPad0: 4.2331 - sparse_categorical_accuracy: 0.2578\n",
      "val_loss: 4.432490825653076\n",
      "val_perplexity: 84.14073944091797\n",
      "Epoch: 019/039\n",
      "learning_rate: 0.09346389025449753\n",
      "Train on 26560 samples\n",
      "26560/26560 - 211s - loss: 147.4874 - SparseCategoricalCrossentropyPad0: 4.2132 - sparse_categorical_accuracy: 0.2600\n",
      "val_loss: 4.428966522216797\n",
      "val_perplexity: 83.84471893310547\n",
      "Epoch: 020/039\n",
      "learning_rate: 0.07788657397031784\n",
      "Train on 26560 samples\n",
      "26560/26560 - 211s - loss: 147.1315 - SparseCategoricalCrossentropyPad0: 4.2030 - sparse_categorical_accuracy: 0.2600\n",
      "val_loss: 4.428220272064209\n",
      "val_perplexity: 83.78217315673828\n",
      "Epoch: 021/039\n",
      "learning_rate: 0.06490547955036163\n",
      "Train on 26560 samples\n",
      "26560/26560 - 211s - loss: 146.8169 - SparseCategoricalCrossentropyPad0: 4.1941 - sparse_categorical_accuracy: 0.2606\n",
      "val_loss: 4.427041053771973\n",
      "val_perplexity: 83.68343353271484\n",
      "Epoch: 022/039\n",
      "learning_rate: 0.05408789962530136\n",
      "Train on 26560 samples\n",
      "26560/26560 - 211s - loss: 146.6104 - SparseCategoricalCrossentropyPad0: 4.1882 - sparse_categorical_accuracy: 0.2611\n",
      "val_loss: 4.4236297607421875\n",
      "val_perplexity: 83.39845275878906\n",
      "Epoch: 023/039\n",
      "learning_rate: 0.0450732484459877\n",
      "Train on 26560 samples\n",
      "26560/26560 - 211s - loss: 146.3768 - SparseCategoricalCrossentropyPad0: 4.1815 - sparse_categorical_accuracy: 0.2621\n",
      "val_loss: 4.422729969024658\n",
      "val_perplexity: 83.32344818115234\n",
      "Epoch: 024/039\n",
      "learning_rate: 0.03756104037165642\n",
      "Train on 26560 samples\n",
      "26560/26560 - 210s - loss: 146.0927 - SparseCategoricalCrossentropyPad0: 4.1734 - sparse_categorical_accuracy: 0.2627\n",
      "val_loss: 4.421693801879883\n",
      "val_perplexity: 83.23715209960938\n",
      "Epoch: 025/039\n",
      "learning_rate: 0.0313008651137352\n",
      "Train on 26560 samples\n",
      "26560/26560 - 211s - loss: 145.9399 - SparseCategoricalCrossentropyPad0: 4.1690 - sparse_categorical_accuracy: 0.2629\n",
      "val_loss: 4.4203057289123535\n",
      "val_perplexity: 83.12169647216797\n",
      "Epoch: 026/039\n",
      "learning_rate: 0.026084054261446\n",
      "Train on 26560 samples\n",
      "26560/26560 - 210s - loss: 145.8594 - SparseCategoricalCrossentropyPad0: 4.1667 - sparse_categorical_accuracy: 0.2631\n",
      "val_loss: 4.418887138366699\n",
      "val_perplexity: 83.00386047363281\n",
      "Epoch: 027/039\n",
      "learning_rate: 0.021736711263656616\n",
      "Train on 26560 samples\n",
      "26560/26560 - 211s - loss: 145.6604 - SparseCategoricalCrossentropyPad0: 4.1610 - sparse_categorical_accuracy: 0.2636\n",
      "val_loss: 4.4189839363098145\n",
      "val_perplexity: 83.01189422607422\n",
      "Epoch: 028/039\n",
      "learning_rate: 0.01811392605304718\n",
      "Train on 26560 samples\n",
      "26560/26560 - 211s - loss: 145.5279 - SparseCategoricalCrossentropyPad0: 4.1572 - sparse_categorical_accuracy: 0.2643\n",
      "val_loss: 4.417514801025391\n",
      "val_perplexity: 82.89002990722656\n",
      "Epoch: 029/039\n",
      "learning_rate: 0.015094938687980175\n",
      "Train on 26560 samples\n",
      "26560/26560 - 211s - loss: 145.6436 - SparseCategoricalCrossentropyPad0: 4.1605 - sparse_categorical_accuracy: 0.2627\n",
      "val_loss: 4.416733264923096\n",
      "val_perplexity: 82.82527160644531\n",
      "Epoch: 030/039\n",
      "learning_rate: 0.012579115107655525\n",
      "Train on 26560 samples\n",
      "26560/26560 - 211s - loss: 145.3596 - SparseCategoricalCrossentropyPad0: 4.1524 - sparse_categorical_accuracy: 0.2642\n",
      "val_loss: 4.415800094604492\n",
      "val_perplexity: 82.7480239868164\n",
      "Epoch: 031/039\n",
      "learning_rate: 0.01048259623348713\n",
      "Train on 26560 samples\n",
      "26560/26560 - 210s - loss: 145.3574 - SparseCategoricalCrossentropyPad0: 4.1523 - sparse_categorical_accuracy: 0.2646\n",
      "val_loss: 4.416037082672119\n",
      "val_perplexity: 82.76763153076172\n",
      "Epoch: 032/039\n",
      "learning_rate: 0.008735496550798416\n",
      "Train on 26560 samples\n",
      "26560/26560 - 211s - loss: 145.1933 - SparseCategoricalCrossentropyPad0: 4.1477 - sparse_categorical_accuracy: 0.2654\n",
      "val_loss: 4.415794849395752\n",
      "val_perplexity: 82.74758911132812\n",
      "Epoch: 033/039\n",
      "learning_rate: 0.00727958045899868\n",
      "Train on 26560 samples\n",
      "26560/26560 - 211s - loss: 145.3494 - SparseCategoricalCrossentropyPad0: 4.1521 - sparse_categorical_accuracy: 0.2645\n",
      "val_loss: 4.41514253616333\n",
      "val_perplexity: 82.6936264038086\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 034/039\n",
      "learning_rate: 0.006066317204385996\n",
      "Train on 26560 samples\n",
      "26560/26560 - 211s - loss: 145.1414 - SparseCategoricalCrossentropyPad0: 4.1462 - sparse_categorical_accuracy: 0.2651\n",
      "val_loss: 4.415287971496582\n",
      "val_perplexity: 82.70565795898438\n",
      "Epoch: 035/039\n",
      "learning_rate: 0.0050552645698189735\n",
      "Train on 26560 samples\n",
      "26560/26560 - 210s - loss: 145.2493 - SparseCategoricalCrossentropyPad0: 4.1493 - sparse_categorical_accuracy: 0.2647\n",
      "val_loss: 4.415131568908691\n",
      "val_perplexity: 82.69271850585938\n",
      "Epoch: 036/039\n",
      "learning_rate: 0.0042127203196287155\n",
      "Train on 26560 samples\n",
      "26560/26560 - 210s - loss: 145.1941 - SparseCategoricalCrossentropyPad0: 4.1477 - sparse_categorical_accuracy: 0.2650\n",
      "val_loss: 4.414928436279297\n",
      "val_perplexity: 82.6759262084961\n",
      "Epoch: 037/039\n",
      "learning_rate: 0.0035106001887470484\n",
      "Train on 26560 samples\n",
      "26560/26560 - 211s - loss: 145.2218 - SparseCategoricalCrossentropyPad0: 4.1485 - sparse_categorical_accuracy: 0.2646\n",
      "val_loss: 4.414916038513184\n",
      "val_perplexity: 82.67489624023438\n",
      "Epoch: 038/039\n",
      "learning_rate: 0.002925500273704529\n",
      "Train on 26560 samples\n",
      "26560/26560 - 211s - loss: 145.0890 - SparseCategoricalCrossentropyPad0: 4.1447 - sparse_categorical_accuracy: 0.2653\n",
      "val_loss: 4.4146857261657715\n",
      "val_perplexity: 82.6558609008789\n",
      "Epoch: 039/039\n",
      "learning_rate: 0.0024379168171435595\n",
      "Train on 26560 samples\n",
      "26560/26560 - 210s - loss: 145.2307 - SparseCategoricalCrossentropyPad0: 4.1487 - sparse_categorical_accuracy: 0.2642\n",
      "val_loss: 4.414400100708008\n",
      "val_perplexity: 82.63225555419922\n"
     ]
    }
   ],
   "source": [
    "if not SKIP_TRAIN:\n",
    "    optimizer = gSGD(learning_rate=args.lr, clipglobalnorm=args.max_grad_norm)\n",
    "    model.compile(\n",
    "        loss=sparse_categorical_crossentropy_pad0,\n",
    "        optimizer=optimizer,\n",
    "        metrics=[SparseCategoricalCrossentropyPad0(), \"sparse_categorical_accuracy\"],\n",
    "    )\n",
    "\n",
    "    # History of per-epoch validation losses (one entry per epoch).\n",
    "    val_losses = []\n",
    "    # Best validation loss seen so far. A checkpoint is written only on strict\n",
    "    # improvement. `loss < best_val_loss` is False for NaN, so a diverged epoch can\n",
    "    # never overwrite the best checkpoint (np.argmin propagates NaN and would\n",
    "    # report the first NaN epoch as the minimum, saving the broken model).\n",
    "    best_val_loss = np.inf\n",
    "    for epoch in range(args.max_epochs):\n",
    "        print(\"Epoch: {0:0=3}/{1:0=3}\".format(epoch + 1, args.max_epochs))\n",
    "        # From epoch index lr_epoch onward, divide the learning rate by lr_decay every epoch.\n",
    "        if epoch >= args.lr_epoch:\n",
    "            learning_rate = K.get_value(optimizer.learning_rate) / args.lr_decay\n",
    "            K.set_value(optimizer.learning_rate, learning_rate)\n",
    "        print(\"learning_rate: {0}\".format(K.get_value(optimizer.learning_rate)))\n",
    "        # Clear recurrent state before each pass (presumably the LSTMs are stateful -- confirm).\n",
    "        model.reset_states()\n",
    "        model.fit(\n",
    "            x_train,\n",
    "            y_train,\n",
    "            batch_size=args.batch_size,\n",
    "            epochs=1,\n",
    "            shuffle=False,  # keep batch order; presumably required for carried-over state\n",
    "            verbose=2,\n",
    "        )\n",
    "        model.reset_states()\n",
    "        results = model.evaluate(x_val, y_val, batch_size=args.batch_size, verbose=0)\n",
    "        # results = [compiled loss, *metrics]; index 1 is the SparseCategoricalCrossentropyPad0\n",
    "        # metric, whose exp() is reported as the validation perplexity.\n",
    "        loss = results[1]\n",
    "        perplexity = np.exp(loss)\n",
    "        val_losses.append(loss)\n",
    "        print(\"val_loss: {0}\".format(loss))\n",
    "        print(\"val_perplexity: {0}\".format(perplexity))\n",
    "\n",
    "        if loss < best_val_loss:\n",
    "            best_val_loss = loss\n",
    "            model.save(args.save_model, include_optimizer=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:No training configuration found in save file: the model was *not* compiled. Compile it manually.\n"
     ]
    }
   ],
   "source": [
    "# Reload the saved checkpoint. EmbeddingDropout is a custom layer, so Keras must be\n",
    "# handed the class explicitly to deserialize it.\n",
    "custom_objects = dict(EmbeddingDropout=EmbeddingDropout)\n",
    "model = keras.models.load_model(\n",
    "    args.save_model,\n",
    "    custom_objects=custom_objects,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<svg height=\"470pt\" viewBox=\"0.00 0.00 453.00 470.00\" width=\"453pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 466)\">\n",
       "<title>G</title>\n",
       "<polygon fill=\"white\" points=\"-4,4 -4,-466 449,-466 449,4 -4,4\" stroke=\"none\"/>\n",
       "<!-- 2691269049368 -->\n",
       "<g class=\"node\" id=\"node1\"><title>2691269049368</title>\n",
       "<polygon fill=\"none\" points=\"97,-415.5 97,-461.5 348,-461.5 348,-415.5 97,-415.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"160\" y=\"-434.8\">input_1: InputLayer</text>\n",
       "<polyline fill=\"none\" points=\"223,-415.5 223,-461.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"251\" y=\"-446.3\">input:</text>\n",
       "<polyline fill=\"none\" points=\"223,-438.5 279,-438.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"251\" y=\"-423.3\">output:</text>\n",
       "<polyline fill=\"none\" points=\"279,-415.5 279,-461.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"313.5\" y=\"-446.3\">[(20, 35)]</text>\n",
       "<polyline fill=\"none\" points=\"279,-438.5 348,-438.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"313.5\" y=\"-423.3\">[(20, 35)]</text>\n",
       "</g>\n",
       "<!-- 2691269049312 -->\n",
       "<g class=\"node\" id=\"node2\"><title>2691269049312</title>\n",
       "<polygon fill=\"none\" points=\"26.5,-332.5 26.5,-378.5 418.5,-378.5 418.5,-332.5 26.5,-332.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"150.5\" y=\"-351.8\">embedding_dropout: EmbeddingDropout</text>\n",
       "<polyline fill=\"none\" points=\"274.5,-332.5 274.5,-378.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"302.5\" y=\"-363.3\">input:</text>\n",
       "<polyline fill=\"none\" points=\"274.5,-355.5 330.5,-355.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"302.5\" y=\"-340.3\">output:</text>\n",
       "<polyline fill=\"none\" points=\"330.5,-332.5 330.5,-378.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"374.5\" y=\"-363.3\">(20, 35)</text>\n",
       "<polyline fill=\"none\" points=\"330.5,-355.5 418.5,-355.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"374.5\" y=\"-340.3\">(20, 35, 650)</text>\n",
       "</g>\n",
       "<!-- 2691269049368&#45;&gt;2691269049312 -->\n",
       "<g class=\"edge\" id=\"edge1\"><title>2691269049368-&gt;2691269049312</title>\n",
       "<path d=\"M222.5,-415.366C222.5,-407.152 222.5,-397.658 222.5,-388.725\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"226,-388.607 222.5,-378.607 219,-388.607 226,-388.607\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 2691269048864 -->\n",
       "<g class=\"node\" id=\"node3\"><title>2691269048864</title>\n",
       "<polygon fill=\"none\" points=\"108.5,-249.5 108.5,-295.5 336.5,-295.5 336.5,-249.5 108.5,-249.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"150.5\" y=\"-268.8\">lstm: LSTM</text>\n",
       "<polyline fill=\"none\" points=\"192.5,-249.5 192.5,-295.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"220.5\" y=\"-280.3\">input:</text>\n",
       "<polyline fill=\"none\" points=\"192.5,-272.5 248.5,-272.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"220.5\" y=\"-257.3\">output:</text>\n",
       "<polyline fill=\"none\" points=\"248.5,-249.5 248.5,-295.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"292.5\" y=\"-280.3\">(20, 35, 650)</text>\n",
       "<polyline fill=\"none\" points=\"248.5,-272.5 336.5,-272.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"292.5\" y=\"-257.3\">(20, 35, 650)</text>\n",
       "</g>\n",
       "<!-- 2691269049312&#45;&gt;2691269048864 -->\n",
       "<g class=\"edge\" id=\"edge2\"><title>2691269049312-&gt;2691269048864</title>\n",
       "<path d=\"M222.5,-332.366C222.5,-324.152 222.5,-314.658 222.5,-305.725\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"226,-305.607 222.5,-295.607 219,-305.607 226,-305.607\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 2691285999800 -->\n",
       "<g class=\"node\" id=\"node4\"><title>2691285999800</title>\n",
       "<polygon fill=\"none\" points=\"101.5,-166.5 101.5,-212.5 343.5,-212.5 343.5,-166.5 101.5,-166.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"150.5\" y=\"-185.8\">lstm_1: LSTM</text>\n",
       "<polyline fill=\"none\" points=\"199.5,-166.5 199.5,-212.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"227.5\" y=\"-197.3\">input:</text>\n",
       "<polyline fill=\"none\" points=\"199.5,-189.5 255.5,-189.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"227.5\" y=\"-174.3\">output:</text>\n",
       "<polyline fill=\"none\" points=\"255.5,-166.5 255.5,-212.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"299.5\" y=\"-197.3\">(20, 35, 650)</text>\n",
       "<polyline fill=\"none\" points=\"255.5,-189.5 343.5,-189.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"299.5\" y=\"-174.3\">(20, 35, 650)</text>\n",
       "</g>\n",
       "<!-- 2691269048864&#45;&gt;2691285999800 -->\n",
       "<g class=\"edge\" id=\"edge3\"><title>2691269048864-&gt;2691285999800</title>\n",
       "<path d=\"M222.5,-249.366C222.5,-241.152 222.5,-231.658 222.5,-222.725\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"226,-222.607 222.5,-212.607 219,-222.607 226,-222.607\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 2691286001312 -->\n",
       "<g class=\"node\" id=\"node5\"><title>2691286001312</title>\n",
       "<polygon fill=\"none\" points=\"92.5,-83.5 92.5,-129.5 352.5,-129.5 352.5,-83.5 92.5,-83.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"150.5\" y=\"-102.8\">dropout: Dropout</text>\n",
       "<polyline fill=\"none\" points=\"208.5,-83.5 208.5,-129.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"236.5\" y=\"-114.3\">input:</text>\n",
       "<polyline fill=\"none\" points=\"208.5,-106.5 264.5,-106.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"236.5\" y=\"-91.3\">output:</text>\n",
       "<polyline fill=\"none\" points=\"264.5,-83.5 264.5,-129.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"308.5\" y=\"-114.3\">(20, 35, 650)</text>\n",
       "<polyline fill=\"none\" points=\"264.5,-106.5 352.5,-106.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"308.5\" y=\"-91.3\">(20, 35, 650)</text>\n",
       "</g>\n",
       "<!-- 2691285999800&#45;&gt;2691286001312 -->\n",
       "<g class=\"edge\" id=\"edge4\"><title>2691285999800-&gt;2691286001312</title>\n",
       "<path d=\"M222.5,-166.366C222.5,-158.152 222.5,-148.658 222.5,-139.725\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"226,-139.607 222.5,-129.607 219,-139.607 226,-139.607\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 2691269048752 -->\n",
       "<g class=\"node\" id=\"node6\"><title>2691269048752</title>\n",
       "<polygon fill=\"none\" points=\"0,-0.5 0,-46.5 445,-46.5 445,-0.5 0,-0.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"144\" y=\"-19.8\">time_distributed(dense): TimeDistributed(Dense)</text>\n",
       "<polyline fill=\"none\" points=\"288,-0.5 288,-46.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"316\" y=\"-31.3\">input:</text>\n",
       "<polyline fill=\"none\" points=\"288,-23.5 344,-23.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"316\" y=\"-8.3\">output:</text>\n",
       "<polyline fill=\"none\" points=\"344,-0.5 344,-46.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"394.5\" y=\"-31.3\">(20, 35, 650)</text>\n",
       "<polyline fill=\"none\" points=\"344,-23.5 445,-23.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"394.5\" y=\"-8.3\">(20, 35, 10001)</text>\n",
       "</g>\n",
       "<!-- 2691286001312&#45;&gt;2691269048752 -->\n",
       "<g class=\"edge\" id=\"edge5\"><title>2691286001312-&gt;2691269048752</title>\n",
       "<path d=\"M222.5,-83.3664C222.5,-75.1516 222.5,-65.6579 222.5,-56.7252\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"226,-56.6068 222.5,-46.6068 219,-56.6069 226,-56.6068\" stroke=\"black\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>"
      ],
      "text/plain": [
       "<IPython.core.display.SVG object>"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Export the architecture diagram to PDF ...\n",
    "model_to_dot(model, show_shapes=True).write_pdf(\"img/ptb_architecture-nn.pdf\")\n",
    "# ... and render the same graph inline as SVG (last expression is displayed).\n",
    "SVG(\n",
    "    model_to_dot(\n",
    "        model,\n",
    "        show_shapes=True,\n",
    "        dpi=72,\n",
    "    ).create(prog='dot', format='svg')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<svg height=\"636pt\" viewBox=\"0.00 0.00 720.00 636.00\" width=\"720pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 632)\">\n",
       "<title>G</title>\n",
       "<polygon fill=\"white\" points=\"-4,4 -4,-632 716,-632 716,4 -4,4\" stroke=\"none\"/>\n",
       "<!-- 2691269049200 -->\n",
       "<g class=\"node\" id=\"node1\"><title>2691269049200</title>\n",
       "<polygon fill=\"none\" points=\"233,-581.5 233,-627.5 484,-627.5 484,-581.5 233,-581.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"296\" y=\"-600.8\">input_1: InputLayer</text>\n",
       "<polyline fill=\"none\" points=\"359,-581.5 359,-627.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"387\" y=\"-612.3\">input:</text>\n",
       "<polyline fill=\"none\" points=\"359,-604.5 415,-604.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"387\" y=\"-589.3\">output:</text>\n",
       "<polyline fill=\"none\" points=\"415,-581.5 415,-627.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"449.5\" y=\"-612.3\">[(20, 35)]</text>\n",
       "<polyline fill=\"none\" points=\"415,-604.5 484,-604.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"449.5\" y=\"-589.3\">[(20, 35)]</text>\n",
       "</g>\n",
       "<!-- 2692561563376 -->\n",
       "<g class=\"node\" id=\"node2\"><title>2692561563376</title>\n",
       "<polygon fill=\"none\" points=\"177,-498.5 177,-544.5 400,-544.5 400,-498.5 177,-498.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"230.5\" y=\"-517.8\">zero_1: Lambda</text>\n",
       "<polyline fill=\"none\" points=\"284,-498.5 284,-544.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"312\" y=\"-529.3\">input:</text>\n",
       "<polyline fill=\"none\" points=\"284,-521.5 340,-521.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"312\" y=\"-506.3\">output:</text>\n",
       "<polyline fill=\"none\" points=\"340,-498.5 340,-544.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"370\" y=\"-529.3\">(20, 35)</text>\n",
       "<polyline fill=\"none\" points=\"340,-521.5 400,-521.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"370\" y=\"-506.3\">(20, 35)</text>\n",
       "</g>\n",
       "<!-- 2691269049200&#45;&gt;2692561563376 -->\n",
       "<g class=\"edge\" id=\"edge1\"><title>2691269049200-&gt;2692561563376</title>\n",
       "<path d=\"M339.372,-581.366C331.656,-572.437 322.633,-561.997 314.347,-552.409\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"316.791,-549.884 307.605,-544.607 311.495,-554.462 316.791,-549.884\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 2691268907696 -->\n",
       "<g class=\"node\" id=\"node3\"><title>2691268907696</title>\n",
       "<polygon fill=\"none\" points=\"94,-415.5 94,-461.5 623,-461.5 623,-415.5 94,-415.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"242.5\" y=\"-434.8\">embedding_dropout: VarianceEmbeddingDropout</text>\n",
       "<polyline fill=\"none\" points=\"391,-415.5 391,-461.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"419\" y=\"-446.3\">input:</text>\n",
       "<polyline fill=\"none\" points=\"391,-438.5 447,-438.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"419\" y=\"-423.3\">output:</text>\n",
       "<polyline fill=\"none\" points=\"447,-415.5 447,-461.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"535\" y=\"-446.3\">[(20, 35), (20, 35)]</text>\n",
       "<polyline fill=\"none\" points=\"447,-438.5 623,-438.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"535\" y=\"-423.3\">[(20, 35, 650), (20, 35, 650)]</text>\n",
       "</g>\n",
       "<!-- 2691269049200&#45;&gt;2691268907696 -->\n",
       "<g class=\"edge\" id=\"edge2\"><title>2691269049200-&gt;2691268907696</title>\n",
       "<path d=\"M384.539,-581.389C394.336,-571.4 404.347,-558.753 409.5,-545 416.829,-525.439 416.829,-517.561 409.5,-498 405.555,-487.471 398.762,-477.589 391.397,-469.053\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"393.889,-466.592 384.539,-461.611 388.742,-471.336 393.889,-466.592\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 2692561563376&#45;&gt;2691268907696 -->\n",
       "<g class=\"edge\" id=\"edge3\"><title>2692561563376-&gt;2691268907696</title>\n",
       "<path d=\"M307.628,-498.366C315.344,-489.437 324.367,-478.997 332.653,-469.409\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"335.505,-471.462 339.395,-461.607 330.209,-466.884 335.505,-471.462\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 2691295482656 -->\n",
       "<g class=\"node\" id=\"node4\"><title>2691295482656</title>\n",
       "<polygon fill=\"none\" points=\"176,-332.5 176,-378.5 541,-378.5 541,-332.5 176,-332.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"242.5\" y=\"-351.8\">lstm: VarianceLSTM</text>\n",
       "<polyline fill=\"none\" points=\"309,-332.5 309,-378.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"337\" y=\"-363.3\">input:</text>\n",
       "<polyline fill=\"none\" points=\"309,-355.5 365,-355.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"337\" y=\"-340.3\">output:</text>\n",
       "<polyline fill=\"none\" points=\"365,-332.5 365,-378.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"453\" y=\"-363.3\">[(20, 35, 650), (20, 35, 650)]</text>\n",
       "<polyline fill=\"none\" points=\"365,-355.5 541,-355.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"453\" y=\"-340.3\">[(20, 35, 650), (20, 35, 650)]</text>\n",
       "</g>\n",
       "<!-- 2691268907696&#45;&gt;2691295482656 -->\n",
       "<g class=\"edge\" id=\"edge4\"><title>2691268907696-&gt;2691295482656</title>\n",
       "<path d=\"M358.5,-415.366C358.5,-407.152 358.5,-397.658 358.5,-388.725\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"362,-388.607 358.5,-378.607 355,-388.607 362,-388.607\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 2691295484728 -->\n",
       "<g class=\"node\" id=\"node5\"><title>2691295484728</title>\n",
       "<polygon fill=\"none\" points=\"169,-249.5 169,-295.5 548,-295.5 548,-249.5 169,-249.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"242.5\" y=\"-268.8\">lstm_1: VarianceLSTM</text>\n",
       "<polyline fill=\"none\" points=\"316,-249.5 316,-295.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"344\" y=\"-280.3\">input:</text>\n",
       "<polyline fill=\"none\" points=\"316,-272.5 372,-272.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"344\" y=\"-257.3\">output:</text>\n",
       "<polyline fill=\"none\" points=\"372,-249.5 372,-295.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"460\" y=\"-280.3\">[(20, 35, 650), (20, 35, 650)]</text>\n",
       "<polyline fill=\"none\" points=\"372,-272.5 548,-272.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"460\" y=\"-257.3\">[(20, 35, 650), (20, 35, 650)]</text>\n",
       "</g>\n",
       "<!-- 2691295482656&#45;&gt;2691295484728 -->\n",
       "<g class=\"edge\" id=\"edge5\"><title>2691295482656-&gt;2691295484728</title>\n",
       "<path d=\"M358.5,-332.366C358.5,-324.152 358.5,-314.658 358.5,-305.725\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"362,-305.607 358.5,-295.607 355,-305.607 362,-305.607\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 2691295592744 -->\n",
       "<g class=\"node\" id=\"node6\"><title>2691295592744</title>\n",
       "<polygon fill=\"none\" points=\"160.5,-166.5 160.5,-212.5 556.5,-212.5 556.5,-166.5 160.5,-166.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"242.5\" y=\"-185.8\">dropout: VarianceDropout</text>\n",
       "<polyline fill=\"none\" points=\"324.5,-166.5 324.5,-212.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"352.5\" y=\"-197.3\">input:</text>\n",
       "<polyline fill=\"none\" points=\"324.5,-189.5 380.5,-189.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"352.5\" y=\"-174.3\">output:</text>\n",
       "<polyline fill=\"none\" points=\"380.5,-166.5 380.5,-212.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"468.5\" y=\"-197.3\">[(20, 35, 650), (20, 35, 650)]</text>\n",
       "<polyline fill=\"none\" points=\"380.5,-189.5 556.5,-189.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"468.5\" y=\"-174.3\">[(20, 35, 650), (20, 35, 650)]</text>\n",
       "</g>\n",
       "<!-- 2691295484728&#45;&gt;2691295592744 -->\n",
       "<g class=\"edge\" id=\"edge6\"><title>2691295484728-&gt;2691295592744</title>\n",
       "<path d=\"M358.5,-249.366C358.5,-241.152 358.5,-231.658 358.5,-222.725\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"362,-222.607 358.5,-212.607 355,-222.607 362,-222.607\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 2691295594032 -->\n",
       "<g class=\"node\" id=\"node7\"><title>2691295594032</title>\n",
       "<polygon fill=\"none\" points=\"103,-83.5 103,-129.5 614,-129.5 614,-83.5 103,-83.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"229\" y=\"-102.8\">time_distributed: VarianceTimeDistributed</text>\n",
       "<polyline fill=\"none\" points=\"355,-83.5 355,-129.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"383\" y=\"-114.3\">input:</text>\n",
       "<polyline fill=\"none\" points=\"355,-106.5 411,-106.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"383\" y=\"-91.3\">output:</text>\n",
       "<polyline fill=\"none\" points=\"411,-83.5 411,-129.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"512.5\" y=\"-114.3\">[(20, 35, 650), (20, 35, 650)]</text>\n",
       "<polyline fill=\"none\" points=\"411,-106.5 614,-106.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"512.5\" y=\"-91.3\">[(20, 35, 10001), (20, 35, 10001)]</text>\n",
       "</g>\n",
       "<!-- 2691295592744&#45;&gt;2691295594032 -->\n",
       "<g class=\"edge\" id=\"edge7\"><title>2691295592744-&gt;2691295594032</title>\n",
       "<path d=\"M358.5,-166.366C358.5,-158.152 358.5,-148.658 358.5,-139.725\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"362,-139.607 358.5,-129.607 355,-139.607 362,-139.607\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 2691295594816 -->\n",
       "<g class=\"node\" id=\"node8\"><title>2691295594816</title>\n",
       "<polygon fill=\"none\" points=\"0,-0.5 0,-46.5 353,-46.5 353,-0.5 0,-0.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"98\" y=\"-19.8\">mean_time_distributed: Lambda</text>\n",
       "<polyline fill=\"none\" points=\"196,-0.5 196,-46.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"224\" y=\"-31.3\">input:</text>\n",
       "<polyline fill=\"none\" points=\"196,-23.5 252,-23.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"224\" y=\"-8.3\">output:</text>\n",
       "<polyline fill=\"none\" points=\"252,-0.5 252,-46.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"302.5\" y=\"-31.3\">(20, 35, 10001)</text>\n",
       "<polyline fill=\"none\" points=\"252,-23.5 353,-23.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"302.5\" y=\"-8.3\">(20, 35, 10001)</text>\n",
       "</g>\n",
       "<!-- 2691295594032&#45;&gt;2691295594816 -->\n",
       "<g class=\"edge\" id=\"edge8\"><title>2691295594032-&gt;2691295594816</title>\n",
       "<path d=\"M308.768,-83.3664C286.097,-73.2765 259.09,-61.2572 235.409,-50.7177\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"236.731,-47.4753 226.172,-46.6068 233.885,-53.8706 236.731,-47.4753\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 2691295595208 -->\n",
       "<g class=\"node\" id=\"node9\"><title>2691295595208</title>\n",
       "<polygon fill=\"none\" points=\"371,-0.5 371,-46.5 712,-46.5 712,-0.5 371,-0.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"463\" y=\"-19.8\">var_time_distributed: Lambda</text>\n",
       "<polyline fill=\"none\" points=\"555,-0.5 555,-46.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"583\" y=\"-31.3\">input:</text>\n",
       "<polyline fill=\"none\" points=\"555,-23.5 611,-23.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"583\" y=\"-8.3\">output:</text>\n",
       "<polyline fill=\"none\" points=\"611,-0.5 611,-46.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"661.5\" y=\"-31.3\">(20, 35, 10001)</text>\n",
       "<polyline fill=\"none\" points=\"611,-23.5 712,-23.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"661.5\" y=\"-8.3\">(20, 35, 10001)</text>\n",
       "</g>\n",
       "<!-- 2691295594032&#45;&gt;2691295595208 -->\n",
       "<g class=\"edge\" id=\"edge9\"><title>2691295594032-&gt;2691295595208</title>\n",
       "<path d=\"M408.506,-83.3664C431.301,-73.2765 458.456,-61.2572 482.267,-50.7177\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"483.827,-53.8549 491.555,-46.6068 480.994,-47.4539 483.827,-53.8549\" stroke=\"black\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>"
      ],
      "text/plain": [
       "<IPython.core.display.SVG object>"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from vpbnn.models import nn2vpbnn\n",
    "vmodel = nn2vpbnn(model, variance_mode=2)\n",
    "model_to_dot(vmodel, show_shapes=True).write_pdf(\"img/ptb_architecture-vpbnn.pdf\")\n",
    "SVG(model_to_dot(vmodel, show_shapes=True, dpi=72).create(prog='dot', format='svg'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2120/2120 [==============================] - 5s 2ms/sample - loss: 153.6127 - SparseCategoricalCrossentropyPad0: 4.4144 - sparse_categorical_accuracy: 0.2754 3s - loss: 155.3295 - SparseCategoricalCrossentropyPad0: 4.4372 - sp\n",
      "2360/2360 [==============================] - 5s 2ms/sample - loss: 152.9348 - SparseCategoricalCrossentropyPad0: 4.3779 - sparse_categorical_accuracy: 0.2742\n",
      "Normal-mode\n",
      "val_loss: 4.414400100708008, val_perplexity: 82.63225555419922\n",
      "test_loss: 4.37785530090332, test_perplexity: 79.66698455810547\n"
     ]
    }
   ],
   "source": [
    "model.compile(\n",
    "    loss=sparse_categorical_crossentropy_pad0,\n",
    "    metrics=[SparseCategoricalCrossentropyPad0(), \"sparse_categorical_accuracy\"],\n",
    ")\n",
    "\n",
    "model.reset_states()\n",
    "results = model.evaluate(x_val, y_val, batch_size=args.batch_size, verbose=1)\n",
    "val_loss = results[1]\n",
    "val_perplexity = np.exp(val_loss)\n",
    "\n",
    "model.reset_states()\n",
    "results = model.evaluate(x_test, y_test, batch_size=args.batch_size, verbose=1)\n",
    "test_loss = results[1]\n",
    "test_perplexity = np.exp(test_loss)\n",
    "\n",
    "print(\"Normal-mode\")\n",
    "print(\"val_loss: {0}, val_perplexity: {1}\".format(val_loss, val_perplexity))\n",
    "print(\"test_loss: {0}, test_perplexity: {1}\".format(test_loss, test_perplexity))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2120/2120 [==============================] - 8s 4ms/sample - loss: 153.6127 - SparseCategoricalCrossentropyPad0: 4.4144 - sparse_categorical_accuracy: 0.2754\n",
      "2360/2360 [==============================] - 8s 3ms/sample - loss: 152.9348 - SparseCategoricalCrossentropyPad0: 4.3779 - sparse_categorical_accuracy: 0.2742 4s - loss: 154.5474 - SparseC\n",
      "Linear-mode\n",
      "val_loss: 4.414400100708008, val_perplexity: 82.63225555419922\n",
      "test_loss: 4.37785530090332, test_perplexity: 79.66698455810547\n"
     ]
    }
   ],
   "source": [
    "vmodel = nn2vpbnn(model, variance_mode=1)\n",
    "vmodel.layers[-3].variance_mode = 0  # Dense.variance_mode=0\n",
    "vmodel = keras.Model(inputs=vmodel.input, outputs=vmodel.output[0])\n",
    "\n",
    "vmodel.compile(\n",
    "    loss=sparse_categorical_crossentropy_pad0,\n",
    "    metrics=[SparseCategoricalCrossentropyPad0(), \"sparse_categorical_accuracy\"],\n",
    ")\n",
    "\n",
    "vmodel.reset_states()\n",
    "results = vmodel.evaluate(x_val, y_val, batch_size=args.batch_size, verbose=1)\n",
    "val_loss = results[1]\n",
    "val_perplexity = np.exp(val_loss)\n",
    "\n",
    "vmodel.reset_states()\n",
    "results = vmodel.evaluate(x_test, y_test, batch_size=args.batch_size, verbose=1)\n",
    "test_loss = results[1]\n",
    "test_perplexity = np.exp(test_loss)\n",
    "\n",
    "print(\"Linear-mode\")\n",
    "print(\"val_loss: {0}, val_perplexity: {1}\".format(val_loss, val_perplexity))\n",
    "print(\"test_loss: {0}, test_perplexity: {1}\".format(test_loss, test_perplexity))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2120/2120 [==============================] - 9s 4ms/sample - loss: 152.9578 - SparseCategoricalCrossentropyPad0: 4.3956 - sparse_categorical_accuracy: 0.2747\n",
      "2360/2360 [==============================] - 9s 4ms/sample - loss: 152.2874 - SparseCategoricalCrossentropyPad0: 4.3593 - sparse_categorical_accuracy: 0.2735\n",
      "Independent-mode\n",
      "val_loss: 4.395575523376465, val_perplexity: 81.0912857055664\n",
      "test_loss: 4.359323024749756, test_perplexity: 78.20417785644531\n"
     ]
    }
   ],
   "source": [
    "vmodel = nn2vpbnn(model, variance_mode=2)\n",
    "vmodel.layers[-3].variance_mode = 0  # Dense.variance_mode=0\n",
    "vmodel = keras.Model(inputs=vmodel.input, outputs=vmodel.output[0])\n",
    "\n",
    "vmodel.compile(\n",
    "    loss=sparse_categorical_crossentropy_pad0,\n",
    "    metrics=[SparseCategoricalCrossentropyPad0(), \"sparse_categorical_accuracy\"],\n",
    ")\n",
    "\n",
    "vmodel.reset_states()\n",
    "results = vmodel.evaluate(x_val, y_val, batch_size=args.batch_size, verbose=1)\n",
    "val_loss = results[1]\n",
    "val_perplexity = np.exp(val_loss)\n",
    "\n",
    "vmodel.reset_states()\n",
    "results = vmodel.evaluate(x_test, y_test, batch_size=args.batch_size, verbose=1)\n",
    "test_loss = results[1]\n",
    "test_perplexity = np.exp(test_loss)\n",
    "\n",
    "print(\"Independent-mode\")\n",
    "print(\"val_loss: {0}, val_perplexity: {1}\".format(val_loss, val_perplexity))\n",
    "print(\"test_loss: {0}, test_perplexity: {1}\".format(test_loss, test_perplexity))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2120/2120 [==============================] - 10s 4ms/sample - loss: 236.0720 - SparseCategoricalCrossentropyPad0: 6.7845 - sparse_categorical_accuracy: 0.0470\n",
      "2360/2360 [==============================] - 10s 4ms/sample - loss: 234.2057 - SparseCategoricalCrossentropyPad0: 6.7047 - sparse_categorical_accuracy: 0.0580\n",
      "Upper-mode\n",
      "val_loss: 6.784465312957764, val_perplexity: 884.0072631835938\n",
      "test_loss: 6.704699516296387, test_perplexity: 816.2327270507812\n"
     ]
    }
   ],
   "source": [
    "vmodel = nn2vpbnn(model, variance_mode=3)\n",
    "vmodel.layers[-3].variance_mode = 0  # Dense.variance_mode=0\n",
    "vmodel = keras.Model(inputs=vmodel.input, outputs=vmodel.output[0])\n",
    "\n",
    "vmodel.compile(\n",
    "    loss=sparse_categorical_crossentropy_pad0,\n",
    "    metrics=[SparseCategoricalCrossentropyPad0(), \"sparse_categorical_accuracy\"],\n",
    ")\n",
    "\n",
    "vmodel.reset_states()\n",
    "results = vmodel.evaluate(x_val, y_val, batch_size=args.batch_size, verbose=1)\n",
    "val_loss = results[1]\n",
    "val_perplexity = np.exp(val_loss)\n",
    "\n",
    "vmodel.reset_states()\n",
    "results = vmodel.evaluate(x_test, y_test, batch_size=args.batch_size, verbose=1)\n",
    "test_loss = results[1]\n",
    "test_perplexity = np.exp(test_loss)\n",
    "\n",
    "print(\"Upper-mode\")\n",
    "print(\"val_loss: {0}, val_perplexity: {1}\".format(val_loss, val_perplexity))\n",
    "print(\"test_loss: {0}, test_perplexity: {1}\".format(test_loss, test_perplexity))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "References consulted for this implementation:\n",
    "- https://github.com/btarjan/stateful-LSTM-LM\n",
    "- https://github.com/tensorflow/models/tree/v1.13.0/tutorials/rnn/ptb\n",
    "- https://www.tensorflow.org/tutorials/text/image_captioning\n",
    "- https://www.tensorflow.org/tutorials/text/nmt_with_attention\n",
    "- https://github.com/yaringal/BayesianRNN\n",
    "- https://github.com/wojzaremba/lstm\n",
    "- https://github.com/ahmetumutdurmus/zaremba\n",
    "- https://github.com/jarfo/kchar\n",
    "- https://qiita.com/keisuke-nakata/items/1e43c6698df800ecad73"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
