{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9e7e7d7c",
   "metadata": {},
   "source": [
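    "# Experiment 6\n",
    "\n",
    "Sentences have the form `a b c a`: three distinct words followed by a repeat of the first word.\n",
    "\n",
    "- Embedding dimension: 10 or 100\n",
    "- Number of words (vocabulary size): 20\n",
    "- Positional encodings: learnable or sinusoidal\n",
    "\n",
    "Training sentences start with `word_0` to `word_9`; test sentences start with `word_10` to `word_19`."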
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d1686d53",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy.random as npr\n",
    "import random\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "from tensorflow.keras import backend as K\n",
    "from keras.optimizers import Adam\n",
    "from keras_nlp.layers import PositionEmbedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "56c107e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed every RNG used in this notebook (Python, NumPy, TensorFlow) so runs are reproducible.\n",
    "seed = 428\n",
    "\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "tf.random.set_seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "504a342e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def bert_module(query, key, value, embed_dim, num_head, i, use_causal_mask=True, ff_dim=None):\n",
    "    \"\"\"Build one post-LayerNorm transformer block: self-attention, then a feed-forward net.\n",
    "\n",
    "    Despite the name, attention is causally masked by default, so each position\n",
    "    attends only to itself and earlier positions (decoder / GPT-style block).\n",
    "\n",
    "    Args:\n",
    "        query, key, value: Keras tensors for MultiHeadAttention. `query` is also the\n",
    "            residual input, so its last dimension must equal `embed_dim`.\n",
    "        embed_dim: model width; the block output has this last dimension.\n",
    "        num_head: number of attention heads (per-head size is embed_dim // num_head).\n",
    "        i: block index, used only to give every layer a unique name.\n",
    "        use_causal_mask: forwarded to MultiHeadAttention; True (the default) keeps the\n",
    "            original decoder-style behaviour.\n",
    "        ff_dim: hidden width of the feed-forward net; defaults to 2 * embed_dim.\n",
    "\n",
    "    Returns:\n",
    "        Keras tensor with the same shape as `query`.\n",
    "    \"\"\"\n",
    "    if ff_dim is None:\n",
    "        ff_dim = 2 * embed_dim\n",
    "    prefix = \"encoder_{}\".format(i)\n",
    "\n",
    "    # Multi-head self-attention. MultiHeadAttention's call signature is\n",
    "    # (query, value, key=None, ...), so value/key are passed by keyword; passing\n",
    "    # (query, key, value) positionally would silently swap key and value.\n",
    "    attention_output = layers.MultiHeadAttention(\n",
    "        num_heads=num_head,\n",
    "        key_dim=embed_dim // num_head,\n",
    "        name=prefix + \"/multiheadattention\",\n",
    "    )(query, value=value, key=key, use_causal_mask=use_causal_mask)\n",
    "\n",
    "    # Add & Normalize: residual connection around the attention sub-layer\n",
    "    attention_output = layers.Add(name=prefix + \"/att_add\")([query, attention_output])\n",
    "    attention_output = layers.LayerNormalization(\n",
    "        epsilon=1e-6, name=prefix + \"/att_layernormalization\"\n",
    "    )(attention_output)\n",
    "\n",
    "    # Position-wise feed-forward network\n",
    "    ff_net = keras.models.Sequential(\n",
    "        [\n",
    "            layers.Dense(ff_dim, activation='relu', name=prefix + \"/ffn_dense_1\"),\n",
    "            layers.Dense(embed_dim, name=prefix + \"/ffn_dense_2\"),\n",
    "        ],\n",
    "        name=prefix + \"/ffn\",\n",
    "    )\n",
    "    ffn_output = ff_net(attention_output)\n",
    "\n",
    "    # Add & Normalize: residual connection around the feed-forward sub-layer\n",
    "    ffn_output = layers.Add(name=prefix + \"/ffn_add\")([attention_output, ffn_output])\n",
    "    ffn_output = layers.LayerNormalization(\n",
    "        epsilon=1e-6, name=prefix + \"/ffn_layernormalization\"\n",
    "    )(ffn_output)\n",
    "\n",
    "    return ffn_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fef2622a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_sinusoidal_embeddings(sequence_length, embedding_dim, base=10000):\n",
    "    \"\"\"Return the fixed sinusoidal position-encoding table of Vaswani et al. (2017).\n",
    "\n",
    "        PE[pos, 2k]     = sin(pos / base ** (2k / embedding_dim))\n",
    "        PE[pos, 2k + 1] = cos(pos / base ** (2k / embedding_dim))\n",
    "\n",
    "    Columns 2k and 2k + 1 share one frequency, and row 0 holds the genuine encoding\n",
    "    of position 0 (0 in even columns, 1 in odd columns) rather than all zeros,\n",
    "    because no position in these sentences is padding.\n",
    "\n",
    "    Args:\n",
    "        sequence_length: number of positions (rows of the table).\n",
    "        embedding_dim: encoding width (columns); odd widths are allowed.\n",
    "        base: wavelength scale; 10000 matches the original paper.\n",
    "\n",
    "    Returns:\n",
    "        tf.float32 tensor of shape (sequence_length, embedding_dim).\n",
    "    \"\"\"\n",
    "    positions = np.arange(sequence_length, dtype=np.float64)[:, np.newaxis]\n",
    "    pair_index = np.arange(embedding_dim) // 2  # columns 2k and 2k+1 -> frequency index k\n",
    "    inv_freq = 1.0 / np.power(float(base), 2.0 * pair_index / embedding_dim)\n",
    "    angles = positions * inv_freq[np.newaxis, :]  # (sequence_length, embedding_dim)\n",
    "    angles[:, 0::2] = np.sin(angles[:, 0::2])  # even dims: sine\n",
    "    angles[:, 1::2] = np.cos(angles[:, 1::2])  # odd dims: cosine\n",
    "    return tf.cast(angles, dtype=tf.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3e5bfd69",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 20  # vocab_size\n",
    "N_TRAIN_FIRST = N // 2  # triples whose first-word index is below this go to training\n",
    "\n",
    "vocabs = ['word_' + str(i) for i in range(N)]\n",
    "vocab_map = {word: idx for idx, word in enumerate(vocabs)}\n",
    "\n",
    "# All ordered triples (a, b, c) of three distinct words, in the same lexicographic\n",
    "# order as three nested loops over `vocabs`. Despite its name, `pairs` holds 3-tuples.\n",
    "pairs = [\n",
    "    (a, b, c)\n",
    "    for a in vocabs\n",
    "    for b in vocabs\n",
    "    for c in vocabs\n",
    "    if a != b and a != c and b != c\n",
    "]\n",
    "assert len(pairs) == N * (N - 1) * (N - 2)\n",
    "\n",
    "# Split on the index of the first word (an earlier version used a random 50/50 split):\n",
    "# training triples start with word_0 .. word_{N_TRAIN_FIRST - 1}, test triples with the\n",
    "# remaining words. Those remaining words still occur in the 2nd/3rd slots of training triples.\n",
    "pairs_train = [t for t in pairs if vocab_map[t[0]] < N_TRAIN_FIRST]\n",
    "pairs_test = [t for t in pairs if vocab_map[t[0]] >= N_TRAIN_FIRST]\n",
    "assert pairs_train and pairs_test, \"both splits must be non-empty\"\n",
    "assert len(pairs_train) + len(pairs_test) == len(pairs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d43093fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each triple (a, b, c) becomes the sentence \"a b c a\" (the first word repeated at the end).\n",
    "sentences_train = [[a, b, c, a] for (a, b, c) in pairs_train]\n",
    "sentences_test = [[a, b, c, a] for (a, b, c) in pairs_test]\n",
    "\n",
    "# The same sentences as vocabulary indices.\n",
    "sentences_number_train = [[vocab_map[w] for w in s] for s in sentences_train]\n",
    "sentences_number_test = [[vocab_map[w] for w in s] for s in sentences_test]\n",
    "\n",
    "# Inputs are the full 4-token index sentences; labels are the sentence shifted left by\n",
    "# one token, i.e. (b, c, a) - presumably next-token targets for the causal model.\n",
    "x_masked_train = np.array(sentences_number_train)\n",
    "y_masked_labels_train = np.array([s[1:] for s in sentences_number_train])\n",
    "x_masked_test = np.array(sentences_number_test)\n",
    "y_masked_labels_test = np.array([s[1:] for s in sentences_number_test])\n",
    "\n",
    "# Shuffle the training set once, applying the same permutation to inputs and labels.\n",
    "perm = np.random.permutation(len(x_masked_train))\n",
    "x_masked_train, y_masked_labels_train = x_masked_train[perm], y_masked_labels_train[perm]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "13b40f89",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2000\n",
      "1/1 - 6s - loss: 3.3491 - val_loss: 3.3085 - 6s/epoch - 6s/step\n",
      "Epoch 2/2000\n",
      "1/1 - 0s - loss: 3.3033 - val_loss: 3.2764 - 155ms/epoch - 155ms/step\n",
      "Epoch 3/2000\n",
      "1/1 - 0s - loss: 3.2703 - val_loss: 3.2499 - 181ms/epoch - 181ms/step\n",
      "Epoch 4/2000\n",
      "1/1 - 0s - loss: 3.2426 - val_loss: 3.2273 - 181ms/epoch - 181ms/step\n",
      "Epoch 5/2000\n",
      "1/1 - 0s - loss: 3.2184 - val_loss: 3.2073 - 147ms/epoch - 147ms/step\n",
      "Epoch 6/2000\n",
      "1/1 - 0s - loss: 3.1969 - val_loss: 3.1901 - 179ms/epoch - 179ms/step\n",
      "Epoch 7/2000\n",
      "1/1 - 0s - loss: 3.1780 - val_loss: 3.1755 - 168ms/epoch - 168ms/step\n",
      "Epoch 8/2000\n",
      "1/1 - 0s - loss: 3.1617 - val_loss: 3.1625 - 232ms/epoch - 232ms/step\n",
      "Epoch 9/2000\n",
      "1/1 - 0s - loss: 3.1474 - val_loss: 3.1515 - 190ms/epoch - 190ms/step\n",
      "Epoch 10/2000\n",
      "1/1 - 0s - loss: 3.1352 - val_loss: 3.1428 - 167ms/epoch - 167ms/step\n",
      "Epoch 11/2000\n",
      "1/1 - 0s - loss: 3.1255 - val_loss: 3.1360 - 173ms/epoch - 173ms/step\n",
      "Epoch 12/2000\n",
      "1/1 - 0s - loss: 3.1179 - val_loss: 3.1296 - 172ms/epoch - 172ms/step\n",
      "Epoch 13/2000\n",
      "1/1 - 0s - loss: 3.1110 - val_loss: 3.1232 - 152ms/epoch - 152ms/step\n",
      "Epoch 14/2000\n",
      "1/1 - 0s - loss: 3.1044 - val_loss: 3.1168 - 153ms/epoch - 153ms/step\n",
      "Epoch 15/2000\n",
      "1/1 - 0s - loss: 3.0978 - val_loss: 3.1104 - 170ms/epoch - 170ms/step\n",
      "Epoch 16/2000\n",
      "1/1 - 0s - loss: 3.0912 - val_loss: 3.1037 - 162ms/epoch - 162ms/step\n",
      "Epoch 17/2000\n",
      "1/1 - 0s - loss: 3.0843 - val_loss: 3.0967 - 174ms/epoch - 174ms/step\n",
      "Epoch 18/2000\n",
      "1/1 - 0s - loss: 3.0773 - val_loss: 3.0894 - 153ms/epoch - 153ms/step\n",
      "Epoch 19/2000\n",
      "1/1 - 0s - loss: 3.0700 - val_loss: 3.0817 - 177ms/epoch - 177ms/step\n",
      "Epoch 20/2000\n",
      "1/1 - 0s - loss: 3.0625 - val_loss: 3.0735 - 162ms/epoch - 162ms/step\n",
      "Epoch 21/2000\n",
      "1/1 - 0s - loss: 3.0546 - val_loss: 3.0646 - 159ms/epoch - 159ms/step\n",
      "Epoch 22/2000\n",
      "1/1 - 0s - loss: 3.0462 - val_loss: 3.0551 - 162ms/epoch - 162ms/step\n",
      "Epoch 23/2000\n",
      "1/1 - 0s - loss: 3.0374 - val_loss: 3.0456 - 151ms/epoch - 151ms/step\n",
      "Epoch 24/2000\n",
      "1/1 - 0s - loss: 3.0288 - val_loss: 3.0363 - 153ms/epoch - 153ms/step\n",
      "Epoch 25/2000\n",
      "1/1 - 0s - loss: 3.0205 - val_loss: 3.0265 - 144ms/epoch - 144ms/step\n",
      "Epoch 26/2000\n",
      "1/1 - 0s - loss: 3.0118 - val_loss: 3.0161 - 145ms/epoch - 145ms/step\n",
      "Epoch 27/2000\n",
      "1/1 - 0s - loss: 3.0023 - val_loss: 3.0053 - 144ms/epoch - 144ms/step\n",
      "Epoch 28/2000\n",
      "1/1 - 0s - loss: 2.9923 - val_loss: 2.9942 - 154ms/epoch - 154ms/step\n",
      "Epoch 29/2000\n",
      "1/1 - 0s - loss: 2.9816 - val_loss: 2.9836 - 214ms/epoch - 214ms/step\n",
      "Epoch 30/2000\n",
      "1/1 - 0s - loss: 2.9711 - val_loss: 2.9731 - 167ms/epoch - 167ms/step\n",
      "Epoch 31/2000\n",
      "1/1 - 0s - loss: 2.9604 - val_loss: 2.9639 - 168ms/epoch - 168ms/step\n",
      "Epoch 32/2000\n",
      "1/1 - 0s - loss: 2.9510 - val_loss: 2.9553 - 166ms/epoch - 166ms/step\n",
      "Epoch 33/2000\n",
      "1/1 - 0s - loss: 2.9424 - val_loss: 2.9448 - 188ms/epoch - 188ms/step\n",
      "Epoch 34/2000\n",
      "1/1 - 0s - loss: 2.9325 - val_loss: 2.9322 - 164ms/epoch - 164ms/step\n",
      "Epoch 35/2000\n",
      "1/1 - 0s - loss: 2.9208 - val_loss: 2.9190 - 171ms/epoch - 171ms/step\n",
      "Epoch 36/2000\n",
      "1/1 - 0s - loss: 2.9084 - val_loss: 2.9067 - 164ms/epoch - 164ms/step\n",
      "Epoch 37/2000\n",
      "1/1 - 0s - loss: 2.8965 - val_loss: 2.8954 - 164ms/epoch - 164ms/step\n",
      "Epoch 38/2000\n",
      "1/1 - 0s - loss: 2.8853 - val_loss: 2.8844 - 164ms/epoch - 164ms/step\n",
      "Epoch 39/2000\n",
      "1/1 - 0s - loss: 2.8741 - val_loss: 2.8737 - 177ms/epoch - 177ms/step\n",
      "Epoch 40/2000\n",
      "1/1 - 0s - loss: 2.8634 - val_loss: 2.8630 - 198ms/epoch - 198ms/step\n",
      "Epoch 41/2000\n",
      "1/1 - 0s - loss: 2.8527 - val_loss: 2.8521 - 165ms/epoch - 165ms/step\n",
      "Epoch 42/2000\n",
      "1/1 - 0s - loss: 2.8420 - val_loss: 2.8411 - 171ms/epoch - 171ms/step\n",
      "Epoch 43/2000\n",
      "1/1 - 0s - loss: 2.8311 - val_loss: 2.8298 - 178ms/epoch - 178ms/step\n",
      "Epoch 44/2000\n",
      "1/1 - 0s - loss: 2.8203 - val_loss: 2.8204 - 161ms/epoch - 161ms/step\n",
      "Epoch 45/2000\n",
      "1/1 - 0s - loss: 2.8116 - val_loss: 2.8121 - 163ms/epoch - 163ms/step\n",
      "Epoch 46/2000\n",
      "1/1 - 0s - loss: 2.8038 - val_loss: 2.8023 - 164ms/epoch - 164ms/step\n",
      "Epoch 47/2000\n",
      "1/1 - 0s - loss: 2.7943 - val_loss: 2.7924 - 165ms/epoch - 165ms/step\n",
      "Epoch 48/2000\n",
      "1/1 - 0s - loss: 2.7844 - val_loss: 2.7846 - 168ms/epoch - 168ms/step\n",
      "Epoch 49/2000\n",
      "1/1 - 0s - loss: 2.7763 - val_loss: 2.7776 - 170ms/epoch - 170ms/step\n",
      "Epoch 50/2000\n",
      "1/1 - 0s - loss: 2.7692 - val_loss: 2.7693 - 163ms/epoch - 163ms/step\n",
      "Epoch 51/2000\n",
      "1/1 - 0s - loss: 2.7606 - val_loss: 2.7603 - 159ms/epoch - 159ms/step\n",
      "Epoch 52/2000\n",
      "1/1 - 0s - loss: 2.7514 - val_loss: 2.7521 - 170ms/epoch - 170ms/step\n",
      "Epoch 53/2000\n",
      "1/1 - 0s - loss: 2.7432 - val_loss: 2.7448 - 164ms/epoch - 164ms/step\n",
      "Epoch 54/2000\n",
      "1/1 - 0s - loss: 2.7355 - val_loss: 2.7382 - 158ms/epoch - 158ms/step\n",
      "Epoch 55/2000\n",
      "1/1 - 0s - loss: 2.7283 - val_loss: 2.7315 - 165ms/epoch - 165ms/step\n",
      "Epoch 56/2000\n",
      "1/1 - 0s - loss: 2.7209 - val_loss: 2.7239 - 163ms/epoch - 163ms/step\n",
      "Epoch 57/2000\n",
      "1/1 - 0s - loss: 2.7127 - val_loss: 2.7171 - 160ms/epoch - 160ms/step\n",
      "Epoch 58/2000\n",
      "1/1 - 0s - loss: 2.7053 - val_loss: 2.7102 - 153ms/epoch - 153ms/step\n",
      "Epoch 59/2000\n",
      "1/1 - 0s - loss: 2.6979 - val_loss: 2.7031 - 173ms/epoch - 173ms/step\n",
      "Epoch 60/2000\n",
      "1/1 - 0s - loss: 2.6906 - val_loss: 2.6950 - 168ms/epoch - 168ms/step\n",
      "Epoch 61/2000\n",
      "1/1 - 0s - loss: 2.6826 - val_loss: 2.6878 - 163ms/epoch - 163ms/step\n",
      "Epoch 62/2000\n",
      "1/1 - 0s - loss: 2.6756 - val_loss: 2.6808 - 157ms/epoch - 157ms/step\n",
      "Epoch 63/2000\n",
      "1/1 - 0s - loss: 2.6685 - val_loss: 2.6743 - 163ms/epoch - 163ms/step\n",
      "Epoch 64/2000\n",
      "1/1 - 0s - loss: 2.6619 - val_loss: 2.6675 - 142ms/epoch - 142ms/step\n",
      "Epoch 65/2000\n",
      "1/1 - 0s - loss: 2.6551 - val_loss: 2.6607 - 149ms/epoch - 149ms/step\n",
      "Epoch 66/2000\n",
      "1/1 - 0s - loss: 2.6480 - val_loss: 2.6536 - 142ms/epoch - 142ms/step\n",
      "Epoch 67/2000\n",
      "1/1 - 0s - loss: 2.6405 - val_loss: 2.6466 - 147ms/epoch - 147ms/step\n",
      "Epoch 68/2000\n",
      "1/1 - 0s - loss: 2.6330 - val_loss: 2.6401 - 141ms/epoch - 141ms/step\n",
      "Epoch 69/2000\n",
      "1/1 - 0s - loss: 2.6260 - val_loss: 2.6332 - 149ms/epoch - 149ms/step\n",
      "Epoch 70/2000\n",
      "1/1 - 0s - loss: 2.6185 - val_loss: 2.6260 - 145ms/epoch - 145ms/step\n",
      "Epoch 71/2000\n",
      "1/1 - 0s - loss: 2.6110 - val_loss: 2.6188 - 144ms/epoch - 144ms/step\n",
      "Epoch 72/2000\n",
      "1/1 - 0s - loss: 2.6037 - val_loss: 2.6104 - 147ms/epoch - 147ms/step\n",
      "Epoch 73/2000\n",
      "1/1 - 0s - loss: 2.5954 - val_loss: 2.6022 - 141ms/epoch - 141ms/step\n",
      "Epoch 74/2000\n",
      "1/1 - 0s - loss: 2.5874 - val_loss: 2.5942 - 144ms/epoch - 144ms/step\n",
      "Epoch 75/2000\n",
      "1/1 - 0s - loss: 2.5800 - val_loss: 2.5878 - 143ms/epoch - 143ms/step\n",
      "Epoch 76/2000\n",
      "1/1 - 0s - loss: 2.5734 - val_loss: 2.5858 - 137ms/epoch - 137ms/step\n",
      "Epoch 77/2000\n",
      "1/1 - 0s - loss: 2.5718 - val_loss: 2.5743 - 132ms/epoch - 132ms/step\n",
      "Epoch 78/2000\n",
      "1/1 - 0s - loss: 2.5598 - val_loss: 2.5638 - 129ms/epoch - 129ms/step\n",
      "Epoch 79/2000\n",
      "1/1 - 0s - loss: 2.5489 - val_loss: 2.5618 - 135ms/epoch - 135ms/step\n",
      "Epoch 80/2000\n",
      "1/1 - 0s - loss: 2.5463 - val_loss: 2.5576 - 134ms/epoch - 134ms/step\n",
      "Epoch 81/2000\n",
      "1/1 - 0s - loss: 2.5411 - val_loss: 2.5463 - 142ms/epoch - 142ms/step\n",
      "Epoch 82/2000\n",
      "1/1 - 0s - loss: 2.5301 - val_loss: 2.5407 - 131ms/epoch - 131ms/step\n",
      "Epoch 83/2000\n",
      "1/1 - 0s - loss: 2.5240 - val_loss: 2.5343 - 132ms/epoch - 132ms/step\n",
      "Epoch 84/2000\n",
      "1/1 - 0s - loss: 2.5174 - val_loss: 2.5266 - 138ms/epoch - 138ms/step\n",
      "Epoch 85/2000\n",
      "1/1 - 0s - loss: 2.5093 - val_loss: 2.5213 - 142ms/epoch - 142ms/step\n",
      "Epoch 86/2000\n",
      "1/1 - 0s - loss: 2.5040 - val_loss: 2.5155 - 137ms/epoch - 137ms/step\n",
      "Epoch 87/2000\n",
      "1/1 - 0s - loss: 2.4982 - val_loss: 2.5073 - 135ms/epoch - 135ms/step\n",
      "Epoch 88/2000\n",
      "1/1 - 0s - loss: 2.4903 - val_loss: 2.4989 - 134ms/epoch - 134ms/step\n",
      "Epoch 89/2000\n",
      "1/1 - 0s - loss: 2.4820 - val_loss: 2.4931 - 130ms/epoch - 130ms/step\n",
      "Epoch 90/2000\n",
      "1/1 - 0s - loss: 2.4763 - val_loss: 2.4866 - 129ms/epoch - 129ms/step\n",
      "Epoch 91/2000\n",
      "1/1 - 0s - loss: 2.4704 - val_loss: 2.4802 - 118ms/epoch - 118ms/step\n",
      "Epoch 92/2000\n",
      "1/1 - 0s - loss: 2.4641 - val_loss: 2.4737 - 121ms/epoch - 121ms/step\n",
      "Epoch 93/2000\n",
      "1/1 - 0s - loss: 2.4577 - val_loss: 2.4684 - 123ms/epoch - 123ms/step\n",
      "Epoch 94/2000\n",
      "1/1 - 0s - loss: 2.4523 - val_loss: 2.4619 - 121ms/epoch - 121ms/step\n",
      "Epoch 95/2000\n",
      "1/1 - 0s - loss: 2.4464 - val_loss: 2.4554 - 129ms/epoch - 129ms/step\n",
      "Epoch 96/2000\n",
      "1/1 - 0s - loss: 2.4408 - val_loss: 2.4508 - 120ms/epoch - 120ms/step\n",
      "Epoch 97/2000\n",
      "1/1 - 0s - loss: 2.4364 - val_loss: 2.4486 - 121ms/epoch - 121ms/step\n",
      "Epoch 98/2000\n",
      "1/1 - 0s - loss: 2.4344 - val_loss: 2.4446 - 152ms/epoch - 152ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 99/2000\n",
      "1/1 - 0s - loss: 2.4304 - val_loss: 2.4346 - 118ms/epoch - 118ms/step\n",
      "Epoch 100/2000\n",
      "1/1 - 0s - loss: 2.4213 - val_loss: 2.4264 - 124ms/epoch - 124ms/step\n",
      "Epoch 101/2000\n",
      "1/1 - 0s - loss: 2.4132 - val_loss: 2.4237 - 123ms/epoch - 123ms/step\n",
      "Epoch 102/2000\n",
      "1/1 - 0s - loss: 2.4101 - val_loss: 2.4248 - 110ms/epoch - 110ms/step\n",
      "Epoch 103/2000\n",
      "1/1 - 0s - loss: 2.4108 - val_loss: 2.4218 - 121ms/epoch - 121ms/step\n",
      "Epoch 104/2000\n",
      "1/1 - 0s - loss: 2.4077 - val_loss: 2.4103 - 118ms/epoch - 118ms/step\n",
      "Epoch 105/2000\n",
      "1/1 - 0s - loss: 2.3975 - val_loss: 2.4025 - 120ms/epoch - 120ms/step\n",
      "Epoch 106/2000\n",
      "1/1 - 0s - loss: 2.3898 - val_loss: 2.4043 - 121ms/epoch - 121ms/step\n",
      "Epoch 107/2000\n",
      "1/1 - 0s - loss: 2.3906 - val_loss: 2.4009 - 124ms/epoch - 124ms/step\n",
      "Epoch 108/2000\n",
      "1/1 - 0s - loss: 2.3881 - val_loss: 2.3907 - 126ms/epoch - 126ms/step\n",
      "Epoch 109/2000\n",
      "1/1 - 0s - loss: 2.3778 - val_loss: 2.3875 - 130ms/epoch - 130ms/step\n",
      "Epoch 110/2000\n",
      "1/1 - 0s - loss: 2.3746 - val_loss: 2.3870 - 132ms/epoch - 132ms/step\n",
      "Epoch 111/2000\n",
      "1/1 - 0s - loss: 2.3748 - val_loss: 2.3802 - 124ms/epoch - 124ms/step\n",
      "Epoch 112/2000\n",
      "1/1 - 0s - loss: 2.3672 - val_loss: 2.3741 - 145ms/epoch - 145ms/step\n",
      "Epoch 113/2000\n",
      "1/1 - 0s - loss: 2.3613 - val_loss: 2.3736 - 152ms/epoch - 152ms/step\n",
      "Epoch 114/2000\n",
      "1/1 - 0s - loss: 2.3611 - val_loss: 2.3699 - 153ms/epoch - 153ms/step\n",
      "Epoch 115/2000\n",
      "1/1 - 0s - loss: 2.3570 - val_loss: 2.3630 - 157ms/epoch - 157ms/step\n",
      "Epoch 116/2000\n",
      "1/1 - 0s - loss: 2.3507 - val_loss: 2.3600 - 151ms/epoch - 151ms/step\n",
      "Epoch 117/2000\n",
      "1/1 - 0s - loss: 2.3478 - val_loss: 2.3593 - 157ms/epoch - 157ms/step\n",
      "Epoch 118/2000\n",
      "1/1 - 0s - loss: 2.3466 - val_loss: 2.3548 - 165ms/epoch - 165ms/step\n",
      "Epoch 119/2000\n",
      "1/1 - 0s - loss: 2.3427 - val_loss: 2.3494 - 165ms/epoch - 165ms/step\n",
      "Epoch 120/2000\n",
      "1/1 - 0s - loss: 2.3372 - val_loss: 2.3470 - 157ms/epoch - 157ms/step\n",
      "Epoch 121/2000\n",
      "1/1 - 0s - loss: 2.3347 - val_loss: 2.3456 - 166ms/epoch - 166ms/step\n",
      "Epoch 122/2000\n",
      "1/1 - 0s - loss: 2.3336 - val_loss: 2.3430 - 172ms/epoch - 172ms/step\n",
      "Epoch 123/2000\n",
      "1/1 - 0s - loss: 2.3306 - val_loss: 2.3380 - 172ms/epoch - 172ms/step\n",
      "Epoch 124/2000\n",
      "1/1 - 0s - loss: 2.3260 - val_loss: 2.3340 - 179ms/epoch - 179ms/step\n",
      "Epoch 125/2000\n",
      "1/1 - 0s - loss: 2.3220 - val_loss: 2.3323 - 173ms/epoch - 173ms/step\n",
      "Epoch 126/2000\n",
      "1/1 - 0s - loss: 2.3202 - val_loss: 2.3311 - 167ms/epoch - 167ms/step\n",
      "Epoch 127/2000\n",
      "1/1 - 0s - loss: 2.3194 - val_loss: 2.3290 - 177ms/epoch - 177ms/step\n",
      "Epoch 128/2000\n",
      "1/1 - 0s - loss: 2.3166 - val_loss: 2.3241 - 175ms/epoch - 175ms/step\n",
      "Epoch 129/2000\n",
      "1/1 - 0s - loss: 2.3122 - val_loss: 2.3201 - 163ms/epoch - 163ms/step\n",
      "Epoch 130/2000\n",
      "1/1 - 0s - loss: 2.3082 - val_loss: 2.3184 - 165ms/epoch - 165ms/step\n",
      "Epoch 131/2000\n",
      "1/1 - 0s - loss: 2.3064 - val_loss: 2.3177 - 170ms/epoch - 170ms/step\n",
      "Epoch 132/2000\n",
      "1/1 - 0s - loss: 2.3060 - val_loss: 2.3159 - 168ms/epoch - 168ms/step\n",
      "Epoch 133/2000\n",
      "1/1 - 0s - loss: 2.3036 - val_loss: 2.3114 - 165ms/epoch - 165ms/step\n",
      "Epoch 134/2000\n",
      "1/1 - 0s - loss: 2.2996 - val_loss: 2.3076 - 167ms/epoch - 167ms/step\n",
      "Epoch 135/2000\n",
      "1/1 - 0s - loss: 2.2957 - val_loss: 2.3059 - 166ms/epoch - 166ms/step\n",
      "Epoch 136/2000\n",
      "1/1 - 0s - loss: 2.2937 - val_loss: 2.3055 - 165ms/epoch - 165ms/step\n",
      "Epoch 137/2000\n",
      "1/1 - 0s - loss: 2.2936 - val_loss: 2.3034 - 165ms/epoch - 165ms/step\n",
      "Epoch 138/2000\n",
      "1/1 - 0s - loss: 2.2912 - val_loss: 2.3000 - 168ms/epoch - 168ms/step\n",
      "Epoch 139/2000\n",
      "1/1 - 0s - loss: 2.2881 - val_loss: 2.2962 - 164ms/epoch - 164ms/step\n",
      "Epoch 140/2000\n",
      "1/1 - 0s - loss: 2.2841 - val_loss: 2.2941 - 167ms/epoch - 167ms/step\n",
      "Epoch 141/2000\n",
      "1/1 - 0s - loss: 2.2820 - val_loss: 2.2935 - 160ms/epoch - 160ms/step\n",
      "Epoch 142/2000\n",
      "1/1 - 0s - loss: 2.2815 - val_loss: 2.2916 - 165ms/epoch - 165ms/step\n",
      "Epoch 143/2000\n",
      "1/1 - 0s - loss: 2.2792 - val_loss: 2.2889 - 165ms/epoch - 165ms/step\n",
      "Epoch 144/2000\n",
      "1/1 - 0s - loss: 2.2768 - val_loss: 2.2854 - 162ms/epoch - 162ms/step\n",
      "Epoch 145/2000\n",
      "1/1 - 0s - loss: 2.2732 - val_loss: 2.2833 - 159ms/epoch - 159ms/step\n",
      "Epoch 146/2000\n",
      "1/1 - 0s - loss: 2.2710 - val_loss: 2.2824 - 159ms/epoch - 159ms/step\n",
      "Epoch 147/2000\n",
      "1/1 - 0s - loss: 2.2702 - val_loss: 2.2806 - 185ms/epoch - 185ms/step\n",
      "Epoch 148/2000\n",
      "1/1 - 0s - loss: 2.2681 - val_loss: 2.2781 - 180ms/epoch - 180ms/step\n",
      "Epoch 149/2000\n",
      "1/1 - 0s - loss: 2.2658 - val_loss: 2.2751 - 180ms/epoch - 180ms/step\n",
      "Epoch 150/2000\n",
      "1/1 - 0s - loss: 2.2628 - val_loss: 2.2730 - 176ms/epoch - 176ms/step\n",
      "Epoch 151/2000\n",
      "1/1 - 0s - loss: 2.2606 - val_loss: 2.2715 - 177ms/epoch - 177ms/step\n",
      "Epoch 152/2000\n",
      "1/1 - 0s - loss: 2.2593 - val_loss: 2.2698 - 171ms/epoch - 171ms/step\n",
      "Epoch 153/2000\n",
      "1/1 - 0s - loss: 2.2573 - val_loss: 2.2673 - 173ms/epoch - 173ms/step\n",
      "Epoch 154/2000\n",
      "1/1 - 0s - loss: 2.2548 - val_loss: 2.2648 - 169ms/epoch - 169ms/step\n",
      "Epoch 155/2000\n",
      "1/1 - 0s - loss: 2.2523 - val_loss: 2.2628 - 168ms/epoch - 168ms/step\n",
      "Epoch 156/2000\n",
      "1/1 - 0s - loss: 2.2505 - val_loss: 2.2611 - 169ms/epoch - 169ms/step\n",
      "Epoch 157/2000\n",
      "1/1 - 0s - loss: 2.2489 - val_loss: 2.2592 - 177ms/epoch - 177ms/step\n",
      "Epoch 158/2000\n",
      "1/1 - 0s - loss: 2.2467 - val_loss: 2.2570 - 174ms/epoch - 174ms/step\n",
      "Epoch 159/2000\n",
      "1/1 - 0s - loss: 2.2445 - val_loss: 2.2550 - 166ms/epoch - 166ms/step\n",
      "Epoch 160/2000\n",
      "1/1 - 0s - loss: 2.2425 - val_loss: 2.2530 - 164ms/epoch - 164ms/step\n",
      "Epoch 161/2000\n",
      "1/1 - 0s - loss: 2.2405 - val_loss: 2.2512 - 163ms/epoch - 163ms/step\n",
      "Epoch 162/2000\n",
      "1/1 - 0s - loss: 2.2389 - val_loss: 2.2493 - 163ms/epoch - 163ms/step\n",
      "Epoch 163/2000\n",
      "1/1 - 0s - loss: 2.2369 - val_loss: 2.2474 - 161ms/epoch - 161ms/step\n",
      "Epoch 164/2000\n",
      "1/1 - 0s - loss: 2.2348 - val_loss: 2.2456 - 161ms/epoch - 161ms/step\n",
      "Epoch 165/2000\n",
      "1/1 - 0s - loss: 2.2329 - val_loss: 2.2434 - 162ms/epoch - 162ms/step\n",
      "Epoch 166/2000\n",
      "1/1 - 0s - loss: 2.2309 - val_loss: 2.2415 - 155ms/epoch - 155ms/step\n",
      "Epoch 167/2000\n",
      "1/1 - 0s - loss: 2.2291 - val_loss: 2.2398 - 155ms/epoch - 155ms/step\n",
      "Epoch 168/2000\n",
      "1/1 - 0s - loss: 2.2273 - val_loss: 2.2379 - 149ms/epoch - 149ms/step\n",
      "Epoch 169/2000\n",
      "1/1 - 0s - loss: 2.2254 - val_loss: 2.2362 - 151ms/epoch - 151ms/step\n",
      "Epoch 170/2000\n",
      "1/1 - 0s - loss: 2.2236 - val_loss: 2.2343 - 165ms/epoch - 165ms/step\n",
      "Epoch 171/2000\n",
      "1/1 - 0s - loss: 2.2217 - val_loss: 2.2325 - 153ms/epoch - 153ms/step\n",
      "Epoch 172/2000\n",
      "1/1 - 0s - loss: 2.2199 - val_loss: 2.2310 - 176ms/epoch - 176ms/step\n",
      "Epoch 173/2000\n",
      "1/1 - 0s - loss: 2.2185 - val_loss: 2.2295 - 185ms/epoch - 185ms/step\n",
      "Epoch 174/2000\n",
      "1/1 - 0s - loss: 2.2169 - val_loss: 2.2284 - 136ms/epoch - 136ms/step\n",
      "Epoch 175/2000\n",
      "1/1 - 0s - loss: 2.2158 - val_loss: 2.2267 - 146ms/epoch - 146ms/step\n",
      "Epoch 176/2000\n",
      "1/1 - 0s - loss: 2.2140 - val_loss: 2.2271 - 124ms/epoch - 124ms/step\n",
      "Epoch 177/2000\n",
      "1/1 - 0s - loss: 2.2144 - val_loss: 2.2236 - 133ms/epoch - 133ms/step\n",
      "Epoch 178/2000\n",
      "1/1 - 0s - loss: 2.2110 - val_loss: 2.2199 - 136ms/epoch - 136ms/step\n",
      "Epoch 179/2000\n",
      "1/1 - 0s - loss: 2.2074 - val_loss: 2.2176 - 134ms/epoch - 134ms/step\n",
      "Epoch 180/2000\n",
      "1/1 - 0s - loss: 2.2050 - val_loss: 2.2159 - 147ms/epoch - 147ms/step\n",
      "Epoch 181/2000\n",
      "1/1 - 0s - loss: 2.2033 - val_loss: 2.2150 - 132ms/epoch - 132ms/step\n",
      "Epoch 182/2000\n",
      "1/1 - 0s - loss: 2.2022 - val_loss: 2.2133 - 132ms/epoch - 132ms/step\n",
      "Epoch 183/2000\n",
      "1/1 - 0s - loss: 2.2005 - val_loss: 2.2109 - 127ms/epoch - 127ms/step\n",
      "Epoch 184/2000\n",
      "1/1 - 0s - loss: 2.1982 - val_loss: 2.2088 - 130ms/epoch - 130ms/step\n",
      "Epoch 185/2000\n",
      "1/1 - 0s - loss: 2.1961 - val_loss: 2.2064 - 128ms/epoch - 128ms/step\n",
      "Epoch 186/2000\n",
      "1/1 - 0s - loss: 2.1936 - val_loss: 2.2046 - 135ms/epoch - 135ms/step\n",
      "Epoch 187/2000\n",
      "1/1 - 0s - loss: 2.1918 - val_loss: 2.2028 - 135ms/epoch - 135ms/step\n",
      "Epoch 188/2000\n",
      "1/1 - 0s - loss: 2.1900 - val_loss: 2.2010 - 132ms/epoch - 132ms/step\n",
      "Epoch 189/2000\n",
      "1/1 - 0s - loss: 2.1883 - val_loss: 2.1994 - 130ms/epoch - 130ms/step\n",
      "Epoch 190/2000\n",
      "1/1 - 0s - loss: 2.1867 - val_loss: 2.1978 - 131ms/epoch - 131ms/step\n",
      "Epoch 191/2000\n",
      "1/1 - 0s - loss: 2.1852 - val_loss: 2.1960 - 127ms/epoch - 127ms/step\n",
      "Epoch 192/2000\n",
      "1/1 - 0s - loss: 2.1834 - val_loss: 2.1937 - 125ms/epoch - 125ms/step\n",
      "Epoch 193/2000\n",
      "1/1 - 0s - loss: 2.1811 - val_loss: 2.1911 - 137ms/epoch - 137ms/step\n",
      "Epoch 194/2000\n",
      "1/1 - 0s - loss: 2.1786 - val_loss: 2.1889 - 124ms/epoch - 124ms/step\n",
      "Epoch 195/2000\n",
      "1/1 - 0s - loss: 2.1765 - val_loss: 2.1871 - 128ms/epoch - 128ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 196/2000\n",
      "1/1 - 0s - loss: 2.1748 - val_loss: 2.1857 - 128ms/epoch - 128ms/step\n",
      "Epoch 197/2000\n",
      "1/1 - 0s - loss: 2.1733 - val_loss: 2.1844 - 130ms/epoch - 130ms/step\n",
      "Epoch 198/2000\n",
      "1/1 - 0s - loss: 2.1720 - val_loss: 2.1827 - 124ms/epoch - 124ms/step\n",
      "Epoch 199/2000\n",
      "1/1 - 0s - loss: 2.1703 - val_loss: 2.1820 - 126ms/epoch - 126ms/step\n",
      "Epoch 200/2000\n",
      "1/1 - 0s - loss: 2.1696 - val_loss: 2.1807 - 123ms/epoch - 123ms/step\n",
      "Epoch 201/2000\n",
      "1/1 - 0s - loss: 2.1683 - val_loss: 2.1785 - 125ms/epoch - 125ms/step\n",
      "Epoch 202/2000\n",
      "1/1 - 0s - loss: 2.1661 - val_loss: 2.1748 - 140ms/epoch - 140ms/step\n",
      "Epoch 203/2000\n",
      "1/1 - 0s - loss: 2.1625 - val_loss: 2.1722 - 148ms/epoch - 148ms/step\n",
      "Epoch 204/2000\n",
      "1/1 - 0s - loss: 2.1601 - val_loss: 2.1704 - 144ms/epoch - 144ms/step\n",
      "Epoch 205/2000\n",
      "1/1 - 0s - loss: 2.1583 - val_loss: 2.1688 - 120ms/epoch - 120ms/step\n",
      "Epoch 206/2000\n",
      "1/1 - 0s - loss: 2.1566 - val_loss: 2.1673 - 124ms/epoch - 124ms/step\n",
      "Epoch 207/2000\n",
      "1/1 - 0s - loss: 2.1553 - val_loss: 2.1662 - 126ms/epoch - 126ms/step\n",
      "Epoch 208/2000\n",
      "1/1 - 0s - loss: 2.1542 - val_loss: 2.1649 - 123ms/epoch - 123ms/step\n",
      "Epoch 209/2000\n",
      "1/1 - 0s - loss: 2.1528 - val_loss: 2.1627 - 146ms/epoch - 146ms/step\n",
      "Epoch 210/2000\n",
      "1/1 - 0s - loss: 2.1506 - val_loss: 2.1607 - 152ms/epoch - 152ms/step\n",
      "Epoch 211/2000\n",
      "1/1 - 0s - loss: 2.1487 - val_loss: 2.1581 - 150ms/epoch - 150ms/step\n",
      "Epoch 212/2000\n",
      "1/1 - 0s - loss: 2.1462 - val_loss: 2.1558 - 130ms/epoch - 130ms/step\n",
      "Epoch 213/2000\n",
      "1/1 - 0s - loss: 2.1439 - val_loss: 2.1542 - 129ms/epoch - 129ms/step\n",
      "Epoch 214/2000\n",
      "1/1 - 0s - loss: 2.1422 - val_loss: 2.1525 - 126ms/epoch - 126ms/step\n",
      "Epoch 215/2000\n",
      "1/1 - 0s - loss: 2.1406 - val_loss: 2.1514 - 128ms/epoch - 128ms/step\n",
      "Epoch 216/2000\n",
      "1/1 - 0s - loss: 2.1394 - val_loss: 2.1501 - 134ms/epoch - 134ms/step\n",
      "Epoch 217/2000\n",
      "1/1 - 0s - loss: 2.1382 - val_loss: 2.1496 - 132ms/epoch - 132ms/step\n",
      "Epoch 218/2000\n",
      "1/1 - 0s - loss: 2.1376 - val_loss: 2.1478 - 135ms/epoch - 135ms/step\n",
      "Epoch 219/2000\n",
      "1/1 - 0s - loss: 2.1359 - val_loss: 2.1466 - 131ms/epoch - 131ms/step\n",
      "Epoch 220/2000\n",
      "1/1 - 0s - loss: 2.1347 - val_loss: 2.1434 - 134ms/epoch - 134ms/step\n",
      "Epoch 221/2000\n",
      "1/1 - 0s - loss: 2.1317 - val_loss: 2.1410 - 132ms/epoch - 132ms/step\n",
      "Epoch 222/2000\n",
      "1/1 - 0s - loss: 2.1293 - val_loss: 2.1391 - 139ms/epoch - 139ms/step\n",
      "Epoch 223/2000\n",
      "1/1 - 0s - loss: 2.1273 - val_loss: 2.1379 - 147ms/epoch - 147ms/step\n",
      "Epoch 224/2000\n",
      "1/1 - 0s - loss: 2.1263 - val_loss: 2.1372 - 158ms/epoch - 158ms/step\n",
      "Epoch 225/2000\n",
      "1/1 - 0s - loss: 2.1255 - val_loss: 2.1356 - 163ms/epoch - 163ms/step\n",
      "Epoch 226/2000\n",
      "1/1 - 0s - loss: 2.1240 - val_loss: 2.1342 - 162ms/epoch - 162ms/step\n",
      "Epoch 227/2000\n",
      "1/1 - 0s - loss: 2.1224 - val_loss: 2.1317 - 162ms/epoch - 162ms/step\n",
      "Epoch 228/2000\n",
      "1/1 - 0s - loss: 2.1200 - val_loss: 2.1298 - 164ms/epoch - 164ms/step\n",
      "Epoch 229/2000\n",
      "1/1 - 0s - loss: 2.1182 - val_loss: 2.1284 - 181ms/epoch - 181ms/step\n",
      "Epoch 230/2000\n",
      "1/1 - 0s - loss: 2.1168 - val_loss: 2.1274 - 176ms/epoch - 176ms/step\n",
      "Epoch 231/2000\n",
      "1/1 - 0s - loss: 2.1158 - val_loss: 2.1264 - 173ms/epoch - 173ms/step\n",
      "Epoch 232/2000\n",
      "1/1 - 0s - loss: 2.1147 - val_loss: 2.1248 - 181ms/epoch - 181ms/step\n",
      "Epoch 233/2000\n",
      "1/1 - 0s - loss: 2.1132 - val_loss: 2.1232 - 163ms/epoch - 163ms/step\n",
      "Epoch 234/2000\n",
      "1/1 - 0s - loss: 2.1116 - val_loss: 2.1215 - 162ms/epoch - 162ms/step\n",
      "Epoch 235/2000\n",
      "1/1 - 0s - loss: 2.1099 - val_loss: 2.1200 - 174ms/epoch - 174ms/step\n",
      "Epoch 236/2000\n",
      "1/1 - 0s - loss: 2.1084 - val_loss: 2.1188 - 174ms/epoch - 174ms/step\n",
      "Epoch 237/2000\n",
      "1/1 - 0s - loss: 2.1073 - val_loss: 2.1176 - 179ms/epoch - 179ms/step\n",
      "Epoch 238/2000\n",
      "1/1 - 0s - loss: 2.1062 - val_loss: 2.1167 - 179ms/epoch - 179ms/step\n",
      "Epoch 239/2000\n",
      "1/1 - 0s - loss: 2.1052 - val_loss: 2.1155 - 207ms/epoch - 207ms/step\n",
      "Epoch 240/2000\n",
      "1/1 - 0s - loss: 2.1041 - val_loss: 2.1143 - 177ms/epoch - 177ms/step\n",
      "Epoch 241/2000\n",
      "1/1 - 0s - loss: 2.1028 - val_loss: 2.1129 - 187ms/epoch - 187ms/step\n",
      "Epoch 242/2000\n",
      "1/1 - 0s - loss: 2.1015 - val_loss: 2.1116 - 180ms/epoch - 180ms/step\n",
      "Epoch 243/2000\n",
      "1/1 - 0s - loss: 2.1002 - val_loss: 2.1104 - 161ms/epoch - 161ms/step\n",
      "Epoch 244/2000\n",
      "1/1 - 0s - loss: 2.0990 - val_loss: 2.1093 - 177ms/epoch - 177ms/step\n",
      "Epoch 245/2000\n",
      "1/1 - 0s - loss: 2.0979 - val_loss: 2.1084 - 169ms/epoch - 169ms/step\n",
      "Epoch 246/2000\n",
      "1/1 - 0s - loss: 2.0969 - val_loss: 2.1074 - 170ms/epoch - 170ms/step\n",
      "Epoch 247/2000\n",
      "1/1 - 0s - loss: 2.0959 - val_loss: 2.1065 - 162ms/epoch - 162ms/step\n",
      "Epoch 248/2000\n",
      "1/1 - 0s - loss: 2.0950 - val_loss: 2.1055 - 163ms/epoch - 163ms/step\n",
      "Epoch 249/2000\n",
      "1/1 - 0s - loss: 2.0940 - val_loss: 2.1046 - 166ms/epoch - 166ms/step\n",
      "Epoch 250/2000\n",
      "1/1 - 0s - loss: 2.0931 - val_loss: 2.1035 - 166ms/epoch - 166ms/step\n",
      "Epoch 251/2000\n",
      "1/1 - 0s - loss: 2.0920 - val_loss: 2.1026 - 190ms/epoch - 190ms/step\n",
      "Epoch 252/2000\n",
      "1/1 - 0s - loss: 2.0910 - val_loss: 2.1014 - 163ms/epoch - 163ms/step\n",
      "Epoch 253/2000\n",
      "1/1 - 0s - loss: 2.0900 - val_loss: 2.1005 - 162ms/epoch - 162ms/step\n",
      "Epoch 254/2000\n",
      "1/1 - 0s - loss: 2.0890 - val_loss: 2.0994 - 204ms/epoch - 204ms/step\n",
      "Epoch 255/2000\n",
      "1/1 - 0s - loss: 2.0880 - val_loss: 2.0985 - 166ms/epoch - 166ms/step\n",
      "Epoch 256/2000\n",
      "1/1 - 0s - loss: 2.0870 - val_loss: 2.0975 - 165ms/epoch - 165ms/step\n",
      "Epoch 257/2000\n",
      "1/1 - 0s - loss: 2.0860 - val_loss: 2.0966 - 165ms/epoch - 165ms/step\n",
      "Epoch 258/2000\n",
      "1/1 - 0s - loss: 2.0851 - val_loss: 2.0956 - 166ms/epoch - 166ms/step\n",
      "Epoch 259/2000\n",
      "1/1 - 0s - loss: 2.0842 - val_loss: 2.0948 - 169ms/epoch - 169ms/step\n",
      "Epoch 260/2000\n",
      "1/1 - 0s - loss: 2.0833 - val_loss: 2.0940 - 161ms/epoch - 161ms/step\n",
      "Epoch 261/2000\n",
      "1/1 - 0s - loss: 2.0824 - val_loss: 2.0931 - 170ms/epoch - 170ms/step\n",
      "Epoch 262/2000\n",
      "1/1 - 0s - loss: 2.0816 - val_loss: 2.0922 - 166ms/epoch - 166ms/step\n",
      "Epoch 263/2000\n",
      "1/1 - 0s - loss: 2.0807 - val_loss: 2.0914 - 166ms/epoch - 166ms/step\n",
      "Epoch 264/2000\n",
      "1/1 - 0s - loss: 2.0799 - val_loss: 2.0907 - 165ms/epoch - 165ms/step\n",
      "Epoch 265/2000\n",
      "1/1 - 0s - loss: 2.0791 - val_loss: 2.0899 - 163ms/epoch - 163ms/step\n",
      "Epoch 266/2000\n",
      "1/1 - 0s - loss: 2.0784 - val_loss: 2.0893 - 158ms/epoch - 158ms/step\n",
      "Epoch 267/2000\n",
      "1/1 - 0s - loss: 2.0778 - val_loss: 2.0888 - 166ms/epoch - 166ms/step\n",
      "Epoch 268/2000\n",
      "1/1 - 0s - loss: 2.0773 - val_loss: 2.0891 - 147ms/epoch - 147ms/step\n",
      "Epoch 269/2000\n",
      "1/1 - 0s - loss: 2.0775 - val_loss: 2.0897 - 147ms/epoch - 147ms/step\n",
      "Epoch 270/2000\n",
      "1/1 - 0s - loss: 2.0781 - val_loss: 2.0938 - 148ms/epoch - 148ms/step\n",
      "Epoch 271/2000\n",
      "1/1 - 0s - loss: 2.0818 - val_loss: 2.0929 - 147ms/epoch - 147ms/step\n",
      "Epoch 272/2000\n",
      "1/1 - 0s - loss: 2.0813 - val_loss: 2.0944 - 182ms/epoch - 182ms/step\n"
     ]
    }
   ],
   "source": [
    "# Sinusoidal positional encodings, embedding dimension 10.\n",
    "embed_dim = 10\n",
    "num_head = 2\n",
    "num_layers = 5  # number of stacked bert_module encoder blocks\n",
    "seq_len = x_masked_train.shape[1]\n",
    "\n",
    "callback = keras.callbacks.EarlyStopping(monitor = 'val_loss', patience = 5, restore_best_weights = True)\n",
    "inputs = layers.Input((seq_len,), dtype=tf.int64)\n",
    "word_embeddings = layers.Embedding(N, embed_dim, name=\"word_embedding\")(inputs)\n",
    "sinusoidal_embeddings = get_sinusoidal_embeddings(seq_len, embed_dim)\n",
    "encoder_output = word_embeddings + sinusoidal_embeddings\n",
    "\n",
    "for i in range(num_layers):\n",
    "    encoder_output = bert_module(encoder_output, encoder_output, encoder_output, embed_dim, num_head, i)\n",
    "\n",
    "# Drop the last sequence position (equivalent to x[:, :-1, :]). Cropping1D is used\n",
    "# instead of a Lambda around a Python lambda because it serializes portably.\n",
    "# NOTE(review): with use_causal_mask=True, the kept positions presumably predict the\n",
    "# next token; the evaluation cell scores the last kept position against the final token.\n",
    "encoder_output = layers.Cropping1D(cropping=(0, 1), name='slice')(encoder_output)\n",
    "mlm_output = layers.Dense(N, name=\"mlm_cls\", activation=\"softmax\")(encoder_output)\n",
    "mlm_model = keras.Model(inputs = inputs, outputs = mlm_output)\n",
    "adam = Adam()\n",
    "mlm_model.compile(loss='sparse_categorical_crossentropy', optimizer=adam)\n",
    "# NOTE(review): validation_split uses the *last* 50% of rows, taken before any shuffling.\n",
    "# x_masked_test is sorted by its first token; if x_masked_train is ordered the same way,\n",
    "# the validation half is not a random sample - confirm.\n",
    "history = mlm_model.fit(x_masked_train, y_masked_labels_train,\n",
    "                        validation_split = 0.5, callbacks = [callback], \n",
    "                        epochs=2000, batch_size=5000, \n",
    "                        verbose=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "aa33aaff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the final token of randomly chosen held-out sentences: the model's last\n",
    "# output position is compared against each sentence's last token.\n",
    "n_eval = 1000\n",
    "x_test_subset = x_masked_test[np.random.choice(x_masked_test.shape[0], size=n_eval, replace=False)]\n",
    "last_tokens = x_test_subset[:, -1]\n",
    "\n",
    "# One batched forward pass. The previous version built a new backend function and\n",
    "# ran it on a single sentence in every loop iteration.\n",
    "last_step_probs = mlm_model.predict(x_test_subset, verbose=0)[:, -1, :]\n",
    "\n",
    "acc = (last_step_probs.argmax(axis=-1) == last_tokens).astype(int)  # 1 where the top-1 prediction is correct\n",
    "prob = last_step_probs[np.arange(n_eval), last_tokens]  # probability assigned to the true final token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "fe4d2103",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.0, 0.040708642)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# (top-1 accuracy on the final token, mean probability given to the true final token)\n",
    "tuple(map(np.mean, (acc, prob)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "0005b730",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2000\n",
      "1/1 - 8s - loss: 3.6666 - val_loss: 3.2988 - 8s/epoch - 8s/step\n",
      "Epoch 2/2000\n",
      "1/1 - 1s - loss: 3.2762 - val_loss: 3.1187 - 801ms/epoch - 801ms/step\n",
      "Epoch 3/2000\n",
      "1/1 - 1s - loss: 3.0928 - val_loss: 3.0142 - 768ms/epoch - 768ms/step\n",
      "Epoch 4/2000\n",
      "1/1 - 1s - loss: 2.9847 - val_loss: 2.9598 - 783ms/epoch - 783ms/step\n",
      "Epoch 5/2000\n",
      "1/1 - 1s - loss: 2.9298 - val_loss: 2.9000 - 809ms/epoch - 809ms/step\n",
      "Epoch 6/2000\n",
      "1/1 - 1s - loss: 2.8705 - val_loss: 2.8506 - 802ms/epoch - 802ms/step\n",
      "Epoch 7/2000\n",
      "1/1 - 1s - loss: 2.8221 - val_loss: 2.8194 - 688ms/epoch - 688ms/step\n",
      "Epoch 8/2000\n",
      "1/1 - 1s - loss: 2.7915 - val_loss: 2.7655 - 639ms/epoch - 639ms/step\n",
      "Epoch 9/2000\n",
      "1/1 - 1s - loss: 2.7369 - val_loss: 2.6987 - 605ms/epoch - 605ms/step\n",
      "Epoch 10/2000\n",
      "1/1 - 1s - loss: 2.6682 - val_loss: 2.6356 - 608ms/epoch - 608ms/step\n",
      "Epoch 11/2000\n",
      "1/1 - 1s - loss: 2.6027 - val_loss: 2.5779 - 542ms/epoch - 542ms/step\n",
      "Epoch 12/2000\n",
      "1/1 - 1s - loss: 2.5427 - val_loss: 2.5304 - 543ms/epoch - 543ms/step\n",
      "Epoch 13/2000\n",
      "1/1 - 1s - loss: 2.4935 - val_loss: 2.4869 - 541ms/epoch - 541ms/step\n",
      "Epoch 14/2000\n",
      "1/1 - 1s - loss: 2.4493 - val_loss: 2.4416 - 543ms/epoch - 543ms/step\n",
      "Epoch 15/2000\n",
      "1/1 - 1s - loss: 2.4034 - val_loss: 2.4014 - 569ms/epoch - 569ms/step\n",
      "Epoch 16/2000\n",
      "1/1 - 1s - loss: 2.3629 - val_loss: 2.3710 - 547ms/epoch - 547ms/step\n",
      "Epoch 17/2000\n",
      "1/1 - 1s - loss: 2.3311 - val_loss: 2.3339 - 501ms/epoch - 501ms/step\n",
      "Epoch 18/2000\n",
      "1/1 - 0s - loss: 2.2904 - val_loss: 2.2961 - 495ms/epoch - 495ms/step\n",
      "Epoch 19/2000\n",
      "1/1 - 1s - loss: 2.2500 - val_loss: 2.2434 - 521ms/epoch - 521ms/step\n",
      "Epoch 20/2000\n",
      "1/1 - 1s - loss: 2.1975 - val_loss: 2.1978 - 536ms/epoch - 536ms/step\n",
      "Epoch 21/2000\n",
      "1/1 - 0s - loss: 2.1547 - val_loss: 2.1610 - 498ms/epoch - 498ms/step\n",
      "Epoch 22/2000\n",
      "1/1 - 1s - loss: 2.1214 - val_loss: 2.1349 - 559ms/epoch - 559ms/step\n",
      "Epoch 23/2000\n",
      "1/1 - 1s - loss: 2.0974 - val_loss: 2.1133 - 624ms/epoch - 624ms/step\n",
      "Epoch 24/2000\n",
      "1/1 - 1s - loss: 2.0774 - val_loss: 2.0939 - 683ms/epoch - 683ms/step\n",
      "Epoch 25/2000\n",
      "1/1 - 1s - loss: 2.0595 - val_loss: 2.0779 - 723ms/epoch - 723ms/step\n",
      "Epoch 26/2000\n",
      "1/1 - 1s - loss: 2.0447 - val_loss: 2.0635 - 923ms/epoch - 923ms/step\n",
      "Epoch 27/2000\n",
      "1/1 - 1s - loss: 2.0314 - val_loss: 2.0517 - 792ms/epoch - 792ms/step\n",
      "Epoch 28/2000\n",
      "1/1 - 1s - loss: 2.0204 - val_loss: 2.0435 - 673ms/epoch - 673ms/step\n",
      "Epoch 29/2000\n",
      "1/1 - 1s - loss: 2.0127 - val_loss: 2.0364 - 663ms/epoch - 663ms/step\n",
      "Epoch 30/2000\n",
      "1/1 - 1s - loss: 2.0059 - val_loss: 2.0310 - 826ms/epoch - 826ms/step\n",
      "Epoch 31/2000\n",
      "1/1 - 1s - loss: 2.0009 - val_loss: 2.0267 - 646ms/epoch - 646ms/step\n",
      "Epoch 32/2000\n",
      "1/1 - 1s - loss: 1.9969 - val_loss: 2.0228 - 607ms/epoch - 607ms/step\n",
      "Epoch 33/2000\n",
      "1/1 - 1s - loss: 1.9935 - val_loss: 2.0205 - 610ms/epoch - 610ms/step\n",
      "Epoch 34/2000\n",
      "1/1 - 1s - loss: 1.9916 - val_loss: 2.0189 - 605ms/epoch - 605ms/step\n",
      "Epoch 35/2000\n",
      "1/1 - 1s - loss: 1.9903 - val_loss: 2.0176 - 613ms/epoch - 613ms/step\n",
      "Epoch 36/2000\n",
      "1/1 - 1s - loss: 1.9891 - val_loss: 2.0166 - 629ms/epoch - 629ms/step\n",
      "Epoch 37/2000\n",
      "1/1 - 1s - loss: 1.9878 - val_loss: 2.0157 - 654ms/epoch - 654ms/step\n",
      "Epoch 38/2000\n",
      "1/1 - 1s - loss: 1.9863 - val_loss: 2.0149 - 579ms/epoch - 579ms/step\n",
      "Epoch 39/2000\n",
      "1/1 - 1s - loss: 1.9847 - val_loss: 2.0145 - 613ms/epoch - 613ms/step\n",
      "Epoch 40/2000\n",
      "1/1 - 1s - loss: 1.9834 - val_loss: 2.0143 - 596ms/epoch - 596ms/step\n",
      "Epoch 41/2000\n",
      "1/1 - 1s - loss: 1.9824 - val_loss: 2.0141 - 564ms/epoch - 564ms/step\n",
      "Epoch 42/2000\n",
      "1/1 - 1s - loss: 1.9813 - val_loss: 2.0142 - 511ms/epoch - 511ms/step\n",
      "Epoch 43/2000\n",
      "1/1 - 1s - loss: 1.9802 - val_loss: 2.0146 - 522ms/epoch - 522ms/step\n",
      "Epoch 44/2000\n",
      "1/1 - 1s - loss: 1.9792 - val_loss: 2.0152 - 588ms/epoch - 588ms/step\n",
      "Epoch 45/2000\n",
      "1/1 - 1s - loss: 1.9782 - val_loss: 2.0161 - 593ms/epoch - 593ms/step\n",
      "Epoch 46/2000\n",
      "1/1 - 1s - loss: 1.9772 - val_loss: 2.0171 - 584ms/epoch - 584ms/step\n"
     ]
    }
   ],
   "source": [
    "# Sinusoidal positional encodings, embedding dimension 100.\n",
    "embed_dim = 100\n",
    "num_head = 2\n",
    "num_layers = 5  # number of stacked bert_module encoder blocks\n",
    "seq_len = x_masked_train.shape[1]\n",
    "\n",
    "callback = keras.callbacks.EarlyStopping(monitor = 'val_loss', patience = 5, restore_best_weights = True)\n",
    "inputs = layers.Input((seq_len,), dtype=tf.int64)\n",
    "word_embeddings = layers.Embedding(N, embed_dim, name=\"word_embedding\")(inputs)\n",
    "sinusoidal_embeddings = get_sinusoidal_embeddings(seq_len, embed_dim)\n",
    "encoder_output = word_embeddings + sinusoidal_embeddings\n",
    "\n",
    "for i in range(num_layers):\n",
    "    encoder_output = bert_module(encoder_output, encoder_output, encoder_output, embed_dim, num_head, i)\n",
    "\n",
    "# Drop the last sequence position (equivalent to x[:, :-1, :]). Cropping1D is used\n",
    "# instead of a Lambda around a Python lambda because it serializes portably.\n",
    "# NOTE(review): with use_causal_mask=True, the kept positions presumably predict the\n",
    "# next token; the evaluation cell scores the last kept position against the final token.\n",
    "encoder_output = layers.Cropping1D(cropping=(0, 1), name='slice')(encoder_output)\n",
    "mlm_output = layers.Dense(N, name=\"mlm_cls\", activation=\"softmax\")(encoder_output)\n",
    "mlm_model = keras.Model(inputs = inputs, outputs = mlm_output)\n",
    "adam = Adam()\n",
    "mlm_model.compile(loss='sparse_categorical_crossentropy', optimizer=adam)\n",
    "# NOTE(review): validation_split uses the *last* 50% of rows, taken before any shuffling.\n",
    "# x_masked_test is sorted by its first token; if x_masked_train is ordered the same way,\n",
    "# the validation half is not a random sample - confirm.\n",
    "history = mlm_model.fit(x_masked_train, y_masked_labels_train,\n",
    "                        validation_split = 0.5, callbacks = [callback], \n",
    "                        epochs=2000, batch_size=5000, \n",
    "                        verbose=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "beb07cf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the final token of randomly chosen held-out sentences: the model's last\n",
    "# output position is compared against each sentence's last token.\n",
    "n_eval = 1000\n",
    "x_test_subset = x_masked_test[np.random.choice(x_masked_test.shape[0], size=n_eval, replace=False)]\n",
    "last_tokens = x_test_subset[:, -1]\n",
    "\n",
    "# One batched forward pass. The previous version built a new backend function and\n",
    "# ran it on a single sentence in every loop iteration.\n",
    "last_step_probs = mlm_model.predict(x_test_subset, verbose=0)[:, -1, :]\n",
    "\n",
    "acc = (last_step_probs.argmax(axis=-1) == last_tokens).astype(int)  # 1 where the top-1 prediction is correct\n",
    "prob = last_step_probs[np.arange(n_eval), last_tokens]  # probability assigned to the true final token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "dbdd4a69",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.0, 0.0068704057)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# (top-1 accuracy on the final token, mean probability given to the true final token)\n",
    "tuple(map(np.mean, (acc, prob)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "54f31bac",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[10,  0,  1, 10],\n",
       "       [10,  0,  2, 10],\n",
       "       [10,  0,  3, 10],\n",
       "       ...,\n",
       "       [19, 18, 15, 19],\n",
       "       [19, 18, 16, 19],\n",
       "       [19, 18, 17, 19]])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity-check the held-out set against the 'a b c a' sentence pattern: the final\n",
    "# token (the one the evaluation cells score) must repeat the first token.\n",
    "assert (x_masked_test[:, 0] == x_masked_test[:, -1]).all(), \"test sentences do not follow the 'a b c a' pattern\"\n",
    "x_masked_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "458ae205",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
