{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d1686d53",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy.random as npr\n",
    "import random\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "from tensorflow.keras import backend as K\n",
    "from keras.optimizers import Adam\n",
    "from keras_nlp.layers import PositionEmbedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "56c107e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 428\n",
    "\n",
    "np.random.seed(seed)\n",
    "tf.random.set_seed(seed)\n",
    "random.seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "504a342e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def bert_module(query, key, value, embed_dim, num_head, i):\n",
    "    \n",
    "    # Multi headed self-attention\n",
    "    attention_output = layers.MultiHeadAttention(\n",
    "        num_heads=num_head,\n",
    "        key_dim=embed_dim // num_head,\n",
    "        name=\"encoder_{}/multiheadattention\".format(i)\n",
    "    )(query, key, value, use_causal_mask=True)\n",
    "    \n",
    "    # Add & Normalize\n",
    "    attention_output = layers.Add()([query, attention_output])  # Skip Connection\n",
    "    attention_output = layers.LayerNormalization(epsilon=1e-6)(attention_output)\n",
    "    \n",
    "    # Feedforward network\n",
    "    ff_net = keras.models.Sequential([\n",
    "        layers.Dense(2 * embed_dim, activation='relu', name=\"encoder_{}/ffn_dense_1\".format(i)),\n",
    "        layers.Dense(embed_dim, name=\"encoder_{}/ffn_dense_2\".format(i)),\n",
    "    ])\n",
    "\n",
    "    # Apply Feedforward network\n",
    "    ffn_output = ff_net(attention_output)\n",
    "\n",
    "    # Add & Normalize\n",
    "    ffn_output = layers.Add()([attention_output, ffn_output])  # Skip Connection\n",
    "    ffn_output = layers.LayerNormalization(epsilon=1e-6)(ffn_output)\n",
    "    \n",
    "    return ffn_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fef2622a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_sinusoidal_embeddings(sequence_length, embedding_dim):\n",
    "    position_enc = np.array([\n",
    "        [pos / np.power(10000, 2. * i / embedding_dim) for i in range(embedding_dim)]\n",
    "        if pos != 0 else np.zeros(embedding_dim)\n",
    "        for pos in range(sequence_length)\n",
    "    ])\n",
    "    position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i\n",
    "    position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # dim 2i+1\n",
    "    return tf.cast(position_enc, dtype=tf.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3e5bfd69",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 20 # vocab_size\n",
    "\n",
    "vocabs = ['word_' + str(i) for i in range(N)]\n",
    "\n",
    "vocab_map = {}\n",
    "for i in range(len(vocabs)):\n",
    "    vocab_map[vocabs[i]] = i\n",
    "    \n",
    "pairs = []\n",
    "\n",
    "for i in vocabs:\n",
    "    for j in vocabs:\n",
    "        for k in vocabs:\n",
    "            if i != j and i != k and j != k:\n",
    "                pairs.append((i,j,k))\n",
    "            \n",
    "#indicator = np.random.choice([0, 1], size=len(pairs), p=[0.5, 0.5])\n",
    "\n",
    "# pairs_train = [pairs[i] for i in range(len(indicator)) if indicator[i] == 1]\n",
    "# pairs_test = [pairs[i] for i in range(len(indicator)) if indicator[i] == 0]\n",
    "\n",
    "pairs_train = [x for x in pairs if int(x[0].split('_')[-1]) <= 9]\n",
    "pairs_test = [x for x in pairs if int(x[0].split('_')[-1]) >= 10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d43093fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "sentences_train = []\n",
    "sentences_number_train = []\n",
    "sentences_test = []\n",
    "sentences_number_test = []\n",
    "\n",
    "for pair in pairs_train:\n",
    "    sentences_train.append([pair[0], pair[1], pair[2], pair[0]])\n",
    "    sentences_number_train.append([vocab_map[pair[0]], vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    \n",
    "for pair in pairs_test:\n",
    "    sentences_test.append([pair[0], pair[1], pair[2], pair[0]])\n",
    "    sentences_number_test.append([vocab_map[pair[0]], vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    \n",
    "x_masked_train = []\n",
    "y_masked_labels_train = []\n",
    "x_masked_test = []\n",
    "y_masked_labels_test = []\n",
    "\n",
    "for pair in pairs_train:\n",
    "    x_masked_train.append([vocab_map[pair[0]], vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    y_masked_labels_train.append([vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    \n",
    "for pair in pairs_test:\n",
    "    x_masked_test.append([vocab_map[pair[0]], vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    y_masked_labels_test.append([vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    \n",
    "x_masked_train = np.array(x_masked_train)\n",
    "y_masked_labels_train = np.array(y_masked_labels_train)\n",
    "x_masked_test = np.array(x_masked_test)\n",
    "y_masked_labels_test = np.array(y_masked_labels_test)\n",
    "\n",
    "perm = np.random.permutation(len(x_masked_train))\n",
    "x_masked_train = x_masked_train[perm]\n",
    "y_masked_labels_train = y_masked_labels_train[perm]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "13b40f89",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2000\n",
      "1/1 - 6s - loss: 3.2925 - val_loss: 3.2452 - 6s/epoch - 6s/step\n",
      "Epoch 2/2000\n",
      "1/1 - 0s - loss: 3.2646 - val_loss: 3.2214 - 191ms/epoch - 191ms/step\n",
      "Epoch 3/2000\n",
      "1/1 - 0s - loss: 3.2415 - val_loss: 3.2011 - 182ms/epoch - 182ms/step\n",
      "Epoch 4/2000\n",
      "1/1 - 0s - loss: 3.2217 - val_loss: 3.1833 - 226ms/epoch - 226ms/step\n",
      "Epoch 5/2000\n",
      "1/1 - 0s - loss: 3.2041 - val_loss: 3.1673 - 175ms/epoch - 175ms/step\n",
      "Epoch 6/2000\n",
      "1/1 - 0s - loss: 3.1881 - val_loss: 3.1528 - 179ms/epoch - 179ms/step\n",
      "Epoch 7/2000\n",
      "1/1 - 0s - loss: 3.1735 - val_loss: 3.1398 - 174ms/epoch - 174ms/step\n",
      "Epoch 8/2000\n",
      "1/1 - 0s - loss: 3.1602 - val_loss: 3.1273 - 210ms/epoch - 210ms/step\n",
      "Epoch 9/2000\n",
      "1/1 - 0s - loss: 3.1477 - val_loss: 3.1155 - 213ms/epoch - 213ms/step\n",
      "Epoch 10/2000\n",
      "1/1 - 0s - loss: 3.1357 - val_loss: 3.1043 - 204ms/epoch - 204ms/step\n",
      "Epoch 11/2000\n",
      "1/1 - 0s - loss: 3.1244 - val_loss: 3.0938 - 224ms/epoch - 224ms/step\n",
      "Epoch 12/2000\n",
      "1/1 - 0s - loss: 3.1137 - val_loss: 3.0841 - 210ms/epoch - 210ms/step\n",
      "Epoch 13/2000\n",
      "1/1 - 0s - loss: 3.1036 - val_loss: 3.0750 - 184ms/epoch - 184ms/step\n",
      "Epoch 14/2000\n",
      "1/1 - 0s - loss: 3.0940 - val_loss: 3.0666 - 188ms/epoch - 188ms/step\n",
      "Epoch 15/2000\n",
      "1/1 - 0s - loss: 3.0848 - val_loss: 3.0588 - 212ms/epoch - 212ms/step\n",
      "Epoch 16/2000\n",
      "1/1 - 0s - loss: 3.0762 - val_loss: 3.0515 - 162ms/epoch - 162ms/step\n",
      "Epoch 17/2000\n",
      "1/1 - 0s - loss: 3.0679 - val_loss: 3.0445 - 156ms/epoch - 156ms/step\n",
      "Epoch 18/2000\n",
      "1/1 - 0s - loss: 3.0597 - val_loss: 3.0377 - 164ms/epoch - 164ms/step\n",
      "Epoch 19/2000\n",
      "1/1 - 0s - loss: 3.0515 - val_loss: 3.0314 - 162ms/epoch - 162ms/step\n",
      "Epoch 20/2000\n",
      "1/1 - 0s - loss: 3.0438 - val_loss: 3.0257 - 211ms/epoch - 211ms/step\n",
      "Epoch 21/2000\n",
      "1/1 - 0s - loss: 3.0366 - val_loss: 3.0203 - 195ms/epoch - 195ms/step\n",
      "Epoch 22/2000\n",
      "1/1 - 0s - loss: 3.0300 - val_loss: 3.0151 - 210ms/epoch - 210ms/step\n",
      "Epoch 23/2000\n",
      "1/1 - 0s - loss: 3.0236 - val_loss: 3.0101 - 185ms/epoch - 185ms/step\n",
      "Epoch 24/2000\n",
      "1/1 - 0s - loss: 3.0174 - val_loss: 3.0053 - 175ms/epoch - 175ms/step\n",
      "Epoch 25/2000\n",
      "1/1 - 0s - loss: 3.0115 - val_loss: 3.0009 - 178ms/epoch - 178ms/step\n",
      "Epoch 26/2000\n",
      "1/1 - 0s - loss: 3.0062 - val_loss: 2.9967 - 169ms/epoch - 169ms/step\n",
      "Epoch 27/2000\n",
      "1/1 - 0s - loss: 3.0012 - val_loss: 2.9928 - 140ms/epoch - 140ms/step\n",
      "Epoch 28/2000\n",
      "1/1 - 0s - loss: 2.9964 - val_loss: 2.9890 - 142ms/epoch - 142ms/step\n",
      "Epoch 29/2000\n",
      "1/1 - 0s - loss: 2.9918 - val_loss: 2.9853 - 140ms/epoch - 140ms/step\n",
      "Epoch 30/2000\n",
      "1/1 - 0s - loss: 2.9874 - val_loss: 2.9815 - 169ms/epoch - 169ms/step\n",
      "Epoch 31/2000\n",
      "1/1 - 0s - loss: 2.9833 - val_loss: 2.9777 - 184ms/epoch - 184ms/step\n",
      "Epoch 32/2000\n",
      "1/1 - 0s - loss: 2.9793 - val_loss: 2.9739 - 181ms/epoch - 181ms/step\n",
      "Epoch 33/2000\n",
      "1/1 - 0s - loss: 2.9754 - val_loss: 2.9703 - 179ms/epoch - 179ms/step\n",
      "Epoch 34/2000\n",
      "1/1 - 0s - loss: 2.9716 - val_loss: 2.9669 - 187ms/epoch - 187ms/step\n",
      "Epoch 35/2000\n",
      "1/1 - 0s - loss: 2.9681 - val_loss: 2.9639 - 204ms/epoch - 204ms/step\n",
      "Epoch 36/2000\n",
      "1/1 - 0s - loss: 2.9648 - val_loss: 2.9611 - 208ms/epoch - 208ms/step\n",
      "Epoch 37/2000\n",
      "1/1 - 0s - loss: 2.9616 - val_loss: 2.9586 - 187ms/epoch - 187ms/step\n",
      "Epoch 38/2000\n",
      "1/1 - 0s - loss: 2.9585 - val_loss: 2.9562 - 194ms/epoch - 194ms/step\n",
      "Epoch 39/2000\n",
      "1/1 - 0s - loss: 2.9555 - val_loss: 2.9539 - 172ms/epoch - 172ms/step\n",
      "Epoch 40/2000\n",
      "1/1 - 0s - loss: 2.9526 - val_loss: 2.9518 - 157ms/epoch - 157ms/step\n",
      "Epoch 41/2000\n",
      "1/1 - 0s - loss: 2.9499 - val_loss: 2.9498 - 150ms/epoch - 150ms/step\n",
      "Epoch 42/2000\n",
      "1/1 - 0s - loss: 2.9472 - val_loss: 2.9477 - 165ms/epoch - 165ms/step\n",
      "Epoch 43/2000\n",
      "1/1 - 0s - loss: 2.9445 - val_loss: 2.9457 - 184ms/epoch - 184ms/step\n",
      "Epoch 44/2000\n",
      "1/1 - 0s - loss: 2.9419 - val_loss: 2.9436 - 200ms/epoch - 200ms/step\n",
      "Epoch 45/2000\n",
      "1/1 - 0s - loss: 2.9393 - val_loss: 2.9415 - 143ms/epoch - 143ms/step\n",
      "Epoch 46/2000\n",
      "1/1 - 0s - loss: 2.9367 - val_loss: 2.9392 - 196ms/epoch - 196ms/step\n",
      "Epoch 47/2000\n",
      "1/1 - 0s - loss: 2.9341 - val_loss: 2.9368 - 162ms/epoch - 162ms/step\n",
      "Epoch 48/2000\n",
      "1/1 - 0s - loss: 2.9313 - val_loss: 2.9343 - 145ms/epoch - 145ms/step\n",
      "Epoch 49/2000\n",
      "1/1 - 0s - loss: 2.9285 - val_loss: 2.9318 - 153ms/epoch - 153ms/step\n",
      "Epoch 50/2000\n",
      "1/1 - 0s - loss: 2.9256 - val_loss: 2.9294 - 139ms/epoch - 139ms/step\n",
      "Epoch 51/2000\n",
      "1/1 - 0s - loss: 2.9229 - val_loss: 2.9269 - 130ms/epoch - 130ms/step\n",
      "Epoch 52/2000\n",
      "1/1 - 0s - loss: 2.9202 - val_loss: 2.9245 - 134ms/epoch - 134ms/step\n",
      "Epoch 53/2000\n",
      "1/1 - 0s - loss: 2.9176 - val_loss: 2.9220 - 151ms/epoch - 151ms/step\n",
      "Epoch 54/2000\n",
      "1/1 - 0s - loss: 2.9149 - val_loss: 2.9193 - 124ms/epoch - 124ms/step\n",
      "Epoch 55/2000\n",
      "1/1 - 0s - loss: 2.9120 - val_loss: 2.9165 - 145ms/epoch - 145ms/step\n",
      "Epoch 56/2000\n",
      "1/1 - 0s - loss: 2.9090 - val_loss: 2.9135 - 141ms/epoch - 141ms/step\n",
      "Epoch 57/2000\n",
      "1/1 - 0s - loss: 2.9058 - val_loss: 2.9104 - 149ms/epoch - 149ms/step\n",
      "Epoch 58/2000\n",
      "1/1 - 0s - loss: 2.9026 - val_loss: 2.9073 - 154ms/epoch - 154ms/step\n",
      "Epoch 59/2000\n",
      "1/1 - 0s - loss: 2.8992 - val_loss: 2.9039 - 128ms/epoch - 128ms/step\n",
      "Epoch 60/2000\n",
      "1/1 - 0s - loss: 2.8957 - val_loss: 2.9003 - 113ms/epoch - 113ms/step\n",
      "Epoch 61/2000\n",
      "1/1 - 0s - loss: 2.8918 - val_loss: 2.8963 - 109ms/epoch - 109ms/step\n",
      "Epoch 62/2000\n",
      "1/1 - 0s - loss: 2.8875 - val_loss: 2.8918 - 111ms/epoch - 111ms/step\n",
      "Epoch 63/2000\n",
      "1/1 - 0s - loss: 2.8827 - val_loss: 2.8869 - 115ms/epoch - 115ms/step\n",
      "Epoch 64/2000\n",
      "1/1 - 0s - loss: 2.8775 - val_loss: 2.8816 - 140ms/epoch - 140ms/step\n",
      "Epoch 65/2000\n",
      "1/1 - 0s - loss: 2.8719 - val_loss: 2.8755 - 150ms/epoch - 150ms/step\n",
      "Epoch 66/2000\n",
      "1/1 - 0s - loss: 2.8656 - val_loss: 2.8685 - 119ms/epoch - 119ms/step\n",
      "Epoch 67/2000\n",
      "1/1 - 0s - loss: 2.8585 - val_loss: 2.8606 - 163ms/epoch - 163ms/step\n",
      "Epoch 68/2000\n",
      "1/1 - 0s - loss: 2.8504 - val_loss: 2.8519 - 167ms/epoch - 167ms/step\n",
      "Epoch 69/2000\n",
      "1/1 - 0s - loss: 2.8414 - val_loss: 2.8421 - 171ms/epoch - 171ms/step\n",
      "Epoch 70/2000\n",
      "1/1 - 0s - loss: 2.8312 - val_loss: 2.8310 - 134ms/epoch - 134ms/step\n",
      "Epoch 71/2000\n",
      "1/1 - 0s - loss: 2.8196 - val_loss: 2.8184 - 157ms/epoch - 157ms/step\n",
      "Epoch 72/2000\n",
      "1/1 - 0s - loss: 2.8066 - val_loss: 2.8045 - 152ms/epoch - 152ms/step\n",
      "Epoch 73/2000\n",
      "1/1 - 0s - loss: 2.7924 - val_loss: 2.7894 - 175ms/epoch - 175ms/step\n",
      "Epoch 74/2000\n",
      "1/1 - 0s - loss: 2.7771 - val_loss: 2.7736 - 184ms/epoch - 184ms/step\n",
      "Epoch 75/2000\n",
      "1/1 - 0s - loss: 2.7612 - val_loss: 2.7577 - 178ms/epoch - 178ms/step\n",
      "Epoch 76/2000\n",
      "1/1 - 0s - loss: 2.7455 - val_loss: 2.7423 - 179ms/epoch - 179ms/step\n",
      "Epoch 77/2000\n",
      "1/1 - 0s - loss: 2.7303 - val_loss: 2.7269 - 193ms/epoch - 193ms/step\n",
      "Epoch 78/2000\n",
      "1/1 - 0s - loss: 2.7153 - val_loss: 2.7117 - 171ms/epoch - 171ms/step\n",
      "Epoch 79/2000\n",
      "1/1 - 0s - loss: 2.7005 - val_loss: 2.6970 - 170ms/epoch - 170ms/step\n",
      "Epoch 80/2000\n",
      "1/1 - 0s - loss: 2.6860 - val_loss: 2.6829 - 183ms/epoch - 183ms/step\n",
      "Epoch 81/2000\n",
      "1/1 - 0s - loss: 2.6721 - val_loss: 2.6703 - 193ms/epoch - 193ms/step\n",
      "Epoch 82/2000\n",
      "1/1 - 0s - loss: 2.6594 - val_loss: 2.6593 - 181ms/epoch - 181ms/step\n",
      "Epoch 83/2000\n",
      "1/1 - 0s - loss: 2.6480 - val_loss: 2.6494 - 192ms/epoch - 192ms/step\n",
      "Epoch 84/2000\n",
      "1/1 - 0s - loss: 2.6378 - val_loss: 2.6396 - 197ms/epoch - 197ms/step\n",
      "Epoch 85/2000\n",
      "1/1 - 0s - loss: 2.6279 - val_loss: 2.6300 - 197ms/epoch - 197ms/step\n",
      "Epoch 86/2000\n",
      "1/1 - 0s - loss: 2.6179 - val_loss: 2.6208 - 198ms/epoch - 198ms/step\n",
      "Epoch 87/2000\n",
      "1/1 - 0s - loss: 2.6081 - val_loss: 2.6112 - 209ms/epoch - 209ms/step\n",
      "Epoch 88/2000\n",
      "1/1 - 0s - loss: 2.5986 - val_loss: 2.6029 - 214ms/epoch - 214ms/step\n",
      "Epoch 89/2000\n",
      "1/1 - 0s - loss: 2.5898 - val_loss: 2.5949 - 195ms/epoch - 195ms/step\n",
      "Epoch 90/2000\n",
      "1/1 - 0s - loss: 2.5820 - val_loss: 2.5890 - 204ms/epoch - 204ms/step\n",
      "Epoch 91/2000\n",
      "1/1 - 0s - loss: 2.5753 - val_loss: 2.5804 - 198ms/epoch - 198ms/step\n",
      "Epoch 92/2000\n",
      "1/1 - 0s - loss: 2.5670 - val_loss: 2.5715 - 217ms/epoch - 217ms/step\n",
      "Epoch 93/2000\n",
      "1/1 - 0s - loss: 2.5581 - val_loss: 2.5643 - 210ms/epoch - 210ms/step\n",
      "Epoch 94/2000\n",
      "1/1 - 0s - loss: 2.5510 - val_loss: 2.5570 - 198ms/epoch - 198ms/step\n",
      "Epoch 95/2000\n",
      "1/1 - 0s - loss: 2.5438 - val_loss: 2.5484 - 185ms/epoch - 185ms/step\n",
      "Epoch 96/2000\n",
      "1/1 - 0s - loss: 2.5352 - val_loss: 2.5401 - 224ms/epoch - 224ms/step\n",
      "Epoch 97/2000\n",
      "1/1 - 0s - loss: 2.5269 - val_loss: 2.5326 - 201ms/epoch - 201ms/step\n",
      "Epoch 98/2000\n",
      "1/1 - 0s - loss: 2.5196 - val_loss: 2.5254 - 197ms/epoch - 197ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 99/2000\n",
      "1/1 - 0s - loss: 2.5123 - val_loss: 2.5174 - 185ms/epoch - 185ms/step\n",
      "Epoch 100/2000\n",
      "1/1 - 0s - loss: 2.5043 - val_loss: 2.5095 - 190ms/epoch - 190ms/step\n",
      "Epoch 101/2000\n",
      "1/1 - 0s - loss: 2.4962 - val_loss: 2.5029 - 201ms/epoch - 201ms/step\n",
      "Epoch 102/2000\n",
      "1/1 - 0s - loss: 2.4896 - val_loss: 2.4964 - 181ms/epoch - 181ms/step\n",
      "Epoch 103/2000\n",
      "1/1 - 0s - loss: 2.4831 - val_loss: 2.4893 - 184ms/epoch - 184ms/step\n",
      "Epoch 104/2000\n",
      "1/1 - 0s - loss: 2.4760 - val_loss: 2.4815 - 178ms/epoch - 178ms/step\n",
      "Epoch 105/2000\n",
      "1/1 - 0s - loss: 2.4684 - val_loss: 2.4746 - 180ms/epoch - 180ms/step\n",
      "Epoch 106/2000\n",
      "1/1 - 0s - loss: 2.4615 - val_loss: 2.4701 - 182ms/epoch - 182ms/step\n",
      "Epoch 107/2000\n",
      "1/1 - 0s - loss: 2.4568 - val_loss: 2.4699 - 180ms/epoch - 180ms/step\n",
      "Epoch 108/2000\n",
      "1/1 - 0s - loss: 2.4570 - val_loss: 2.4605 - 181ms/epoch - 181ms/step\n",
      "Epoch 109/2000\n",
      "1/1 - 0s - loss: 2.4476 - val_loss: 2.4488 - 183ms/epoch - 183ms/step\n",
      "Epoch 110/2000\n",
      "1/1 - 0s - loss: 2.4360 - val_loss: 2.4491 - 167ms/epoch - 167ms/step\n",
      "Epoch 111/2000\n",
      "1/1 - 0s - loss: 2.4365 - val_loss: 2.4372 - 167ms/epoch - 167ms/step\n",
      "Epoch 112/2000\n",
      "1/1 - 0s - loss: 2.4244 - val_loss: 2.4353 - 168ms/epoch - 168ms/step\n",
      "Epoch 113/2000\n",
      "1/1 - 0s - loss: 2.4226 - val_loss: 2.4254 - 166ms/epoch - 166ms/step\n",
      "Epoch 114/2000\n",
      "1/1 - 0s - loss: 2.4129 - val_loss: 2.4218 - 167ms/epoch - 167ms/step\n",
      "Epoch 115/2000\n",
      "1/1 - 0s - loss: 2.4096 - val_loss: 2.4151 - 163ms/epoch - 163ms/step\n",
      "Epoch 116/2000\n",
      "1/1 - 0s - loss: 2.4031 - val_loss: 2.4104 - 160ms/epoch - 160ms/step\n",
      "Epoch 117/2000\n",
      "1/1 - 0s - loss: 2.3986 - val_loss: 2.4035 - 159ms/epoch - 159ms/step\n",
      "Epoch 118/2000\n",
      "1/1 - 0s - loss: 2.3919 - val_loss: 2.3993 - 160ms/epoch - 160ms/step\n",
      "Epoch 119/2000\n",
      "1/1 - 0s - loss: 2.3878 - val_loss: 2.3939 - 161ms/epoch - 161ms/step\n",
      "Epoch 120/2000\n",
      "1/1 - 0s - loss: 2.3825 - val_loss: 2.3893 - 153ms/epoch - 153ms/step\n",
      "Epoch 121/2000\n",
      "1/1 - 0s - loss: 2.3781 - val_loss: 2.3834 - 152ms/epoch - 152ms/step\n",
      "Epoch 122/2000\n",
      "1/1 - 0s - loss: 2.3724 - val_loss: 2.3790 - 190ms/epoch - 190ms/step\n",
      "Epoch 123/2000\n",
      "1/1 - 0s - loss: 2.3682 - val_loss: 2.3742 - 161ms/epoch - 161ms/step\n",
      "Epoch 124/2000\n",
      "1/1 - 0s - loss: 2.3636 - val_loss: 2.3697 - 161ms/epoch - 161ms/step\n",
      "Epoch 125/2000\n",
      "1/1 - 0s - loss: 2.3592 - val_loss: 2.3653 - 187ms/epoch - 187ms/step\n",
      "Epoch 126/2000\n",
      "1/1 - 0s - loss: 2.3548 - val_loss: 2.3606 - 164ms/epoch - 164ms/step\n",
      "Epoch 127/2000\n",
      "1/1 - 0s - loss: 2.3502 - val_loss: 2.3567 - 195ms/epoch - 195ms/step\n",
      "Epoch 128/2000\n",
      "1/1 - 0s - loss: 2.3465 - val_loss: 2.3520 - 140ms/epoch - 140ms/step\n",
      "Epoch 129/2000\n",
      "1/1 - 0s - loss: 2.3418 - val_loss: 2.3487 - 129ms/epoch - 129ms/step\n",
      "Epoch 130/2000\n",
      "1/1 - 0s - loss: 2.3385 - val_loss: 2.3442 - 107ms/epoch - 107ms/step\n",
      "Epoch 131/2000\n",
      "1/1 - 0s - loss: 2.3341 - val_loss: 2.3415 - 106ms/epoch - 106ms/step\n",
      "Epoch 132/2000\n",
      "1/1 - 0s - loss: 2.3317 - val_loss: 2.3372 - 121ms/epoch - 121ms/step\n",
      "Epoch 133/2000\n",
      "1/1 - 0s - loss: 2.3273 - val_loss: 2.3339 - 106ms/epoch - 106ms/step\n",
      "Epoch 134/2000\n",
      "1/1 - 0s - loss: 2.3242 - val_loss: 2.3288 - 111ms/epoch - 111ms/step\n",
      "Epoch 135/2000\n",
      "1/1 - 0s - loss: 2.3190 - val_loss: 2.3248 - 112ms/epoch - 112ms/step\n",
      "Epoch 136/2000\n",
      "1/1 - 0s - loss: 2.3153 - val_loss: 2.3203 - 106ms/epoch - 106ms/step\n",
      "Epoch 137/2000\n",
      "1/1 - 0s - loss: 2.3108 - val_loss: 2.3178 - 113ms/epoch - 113ms/step\n",
      "Epoch 138/2000\n",
      "1/1 - 0s - loss: 2.3082 - val_loss: 2.3144 - 110ms/epoch - 110ms/step\n",
      "Epoch 139/2000\n",
      "1/1 - 0s - loss: 2.3051 - val_loss: 2.3117 - 114ms/epoch - 114ms/step\n",
      "Epoch 140/2000\n",
      "1/1 - 0s - loss: 2.3023 - val_loss: 2.3074 - 123ms/epoch - 123ms/step\n",
      "Epoch 141/2000\n",
      "1/1 - 0s - loss: 2.2981 - val_loss: 2.3033 - 143ms/epoch - 143ms/step\n",
      "Epoch 142/2000\n",
      "1/1 - 0s - loss: 2.2940 - val_loss: 2.2995 - 145ms/epoch - 145ms/step\n",
      "Epoch 143/2000\n",
      "1/1 - 0s - loss: 2.2904 - val_loss: 2.2961 - 134ms/epoch - 134ms/step\n",
      "Epoch 144/2000\n",
      "1/1 - 0s - loss: 2.2870 - val_loss: 2.2936 - 131ms/epoch - 131ms/step\n",
      "Epoch 145/2000\n",
      "1/1 - 0s - loss: 2.2844 - val_loss: 2.2905 - 130ms/epoch - 130ms/step\n",
      "Epoch 146/2000\n",
      "1/1 - 0s - loss: 2.2816 - val_loss: 2.2882 - 135ms/epoch - 135ms/step\n",
      "Epoch 147/2000\n",
      "1/1 - 0s - loss: 2.2792 - val_loss: 2.2844 - 131ms/epoch - 131ms/step\n",
      "Epoch 148/2000\n",
      "1/1 - 0s - loss: 2.2755 - val_loss: 2.2808 - 131ms/epoch - 131ms/step\n",
      "Epoch 149/2000\n",
      "1/1 - 0s - loss: 2.2718 - val_loss: 2.2769 - 133ms/epoch - 133ms/step\n",
      "Epoch 150/2000\n",
      "1/1 - 0s - loss: 2.2680 - val_loss: 2.2736 - 140ms/epoch - 140ms/step\n",
      "Epoch 151/2000\n",
      "1/1 - 0s - loss: 2.2648 - val_loss: 2.2711 - 182ms/epoch - 182ms/step\n",
      "Epoch 152/2000\n",
      "1/1 - 0s - loss: 2.2622 - val_loss: 2.2682 - 204ms/epoch - 204ms/step\n",
      "Epoch 153/2000\n",
      "1/1 - 0s - loss: 2.2594 - val_loss: 2.2655 - 145ms/epoch - 145ms/step\n",
      "Epoch 154/2000\n",
      "1/1 - 0s - loss: 2.2567 - val_loss: 2.2621 - 159ms/epoch - 159ms/step\n",
      "Epoch 155/2000\n",
      "1/1 - 0s - loss: 2.2534 - val_loss: 2.2589 - 167ms/epoch - 167ms/step\n",
      "Epoch 156/2000\n",
      "1/1 - 0s - loss: 2.2501 - val_loss: 2.2560 - 174ms/epoch - 174ms/step\n",
      "Epoch 157/2000\n",
      "1/1 - 0s - loss: 2.2472 - val_loss: 2.2532 - 193ms/epoch - 193ms/step\n",
      "Epoch 158/2000\n",
      "1/1 - 0s - loss: 2.2445 - val_loss: 2.2508 - 190ms/epoch - 190ms/step\n",
      "Epoch 159/2000\n",
      "1/1 - 0s - loss: 2.2420 - val_loss: 2.2481 - 166ms/epoch - 166ms/step\n",
      "Epoch 160/2000\n",
      "1/1 - 0s - loss: 2.2394 - val_loss: 2.2453 - 189ms/epoch - 189ms/step\n",
      "Epoch 161/2000\n",
      "1/1 - 0s - loss: 2.2366 - val_loss: 2.2426 - 159ms/epoch - 159ms/step\n",
      "Epoch 162/2000\n",
      "1/1 - 0s - loss: 2.2339 - val_loss: 2.2397 - 154ms/epoch - 154ms/step\n",
      "Epoch 163/2000\n",
      "1/1 - 0s - loss: 2.2311 - val_loss: 2.2369 - 156ms/epoch - 156ms/step\n",
      "Epoch 164/2000\n",
      "1/1 - 0s - loss: 2.2283 - val_loss: 2.2343 - 152ms/epoch - 152ms/step\n",
      "Epoch 165/2000\n",
      "1/1 - 0s - loss: 2.2257 - val_loss: 2.2316 - 154ms/epoch - 154ms/step\n",
      "Epoch 166/2000\n",
      "1/1 - 0s - loss: 2.2231 - val_loss: 2.2291 - 152ms/epoch - 152ms/step\n",
      "Epoch 167/2000\n",
      "1/1 - 0s - loss: 2.2206 - val_loss: 2.2268 - 156ms/epoch - 156ms/step\n",
      "Epoch 168/2000\n",
      "1/1 - 0s - loss: 2.2183 - val_loss: 2.2245 - 180ms/epoch - 180ms/step\n",
      "Epoch 169/2000\n",
      "1/1 - 0s - loss: 2.2160 - val_loss: 2.2224 - 159ms/epoch - 159ms/step\n",
      "Epoch 170/2000\n",
      "1/1 - 0s - loss: 2.2138 - val_loss: 2.2205 - 155ms/epoch - 155ms/step\n",
      "Epoch 171/2000\n",
      "1/1 - 0s - loss: 2.2119 - val_loss: 2.2190 - 158ms/epoch - 158ms/step\n",
      "Epoch 172/2000\n",
      "1/1 - 0s - loss: 2.2105 - val_loss: 2.2177 - 164ms/epoch - 164ms/step\n",
      "Epoch 173/2000\n",
      "1/1 - 0s - loss: 2.2092 - val_loss: 2.2164 - 164ms/epoch - 164ms/step\n",
      "Epoch 174/2000\n",
      "1/1 - 0s - loss: 2.2079 - val_loss: 2.2134 - 165ms/epoch - 165ms/step\n",
      "Epoch 175/2000\n",
      "1/1 - 0s - loss: 2.2049 - val_loss: 2.2089 - 154ms/epoch - 154ms/step\n",
      "Epoch 176/2000\n",
      "1/1 - 0s - loss: 2.2005 - val_loss: 2.2058 - 160ms/epoch - 160ms/step\n",
      "Epoch 177/2000\n",
      "1/1 - 0s - loss: 2.1974 - val_loss: 2.2047 - 179ms/epoch - 179ms/step\n",
      "Epoch 178/2000\n",
      "1/1 - 0s - loss: 2.1963 - val_loss: 2.2033 - 177ms/epoch - 177ms/step\n",
      "Epoch 179/2000\n",
      "1/1 - 0s - loss: 2.1948 - val_loss: 2.2002 - 174ms/epoch - 174ms/step\n",
      "Epoch 180/2000\n",
      "1/1 - 0s - loss: 2.1918 - val_loss: 2.1972 - 182ms/epoch - 182ms/step\n",
      "Epoch 181/2000\n",
      "1/1 - 0s - loss: 2.1888 - val_loss: 2.1957 - 182ms/epoch - 182ms/step\n",
      "Epoch 182/2000\n",
      "1/1 - 0s - loss: 2.1873 - val_loss: 2.1943 - 179ms/epoch - 179ms/step\n",
      "Epoch 183/2000\n",
      "1/1 - 0s - loss: 2.1858 - val_loss: 2.1916 - 180ms/epoch - 180ms/step\n",
      "Epoch 184/2000\n",
      "1/1 - 0s - loss: 2.1832 - val_loss: 2.1891 - 183ms/epoch - 183ms/step\n",
      "Epoch 185/2000\n",
      "1/1 - 0s - loss: 2.1807 - val_loss: 2.1876 - 178ms/epoch - 178ms/step\n",
      "Epoch 186/2000\n",
      "1/1 - 0s - loss: 2.1792 - val_loss: 2.1860 - 169ms/epoch - 169ms/step\n",
      "Epoch 187/2000\n",
      "1/1 - 0s - loss: 2.1776 - val_loss: 2.1835 - 169ms/epoch - 169ms/step\n",
      "Epoch 188/2000\n",
      "1/1 - 0s - loss: 2.1751 - val_loss: 2.1814 - 160ms/epoch - 160ms/step\n",
      "Epoch 189/2000\n",
      "1/1 - 0s - loss: 2.1731 - val_loss: 2.1800 - 128ms/epoch - 128ms/step\n",
      "Epoch 190/2000\n",
      "1/1 - 0s - loss: 2.1717 - val_loss: 2.1781 - 126ms/epoch - 126ms/step\n",
      "Epoch 191/2000\n",
      "1/1 - 0s - loss: 2.1698 - val_loss: 2.1759 - 119ms/epoch - 119ms/step\n",
      "Epoch 192/2000\n",
      "1/1 - 0s - loss: 2.1676 - val_loss: 2.1741 - 118ms/epoch - 118ms/step\n",
      "Epoch 193/2000\n",
      "1/1 - 0s - loss: 2.1659 - val_loss: 2.1726 - 132ms/epoch - 132ms/step\n",
      "Epoch 194/2000\n",
      "1/1 - 0s - loss: 2.1643 - val_loss: 2.1707 - 129ms/epoch - 129ms/step\n",
      "Epoch 195/2000\n",
      "1/1 - 0s - loss: 2.1624 - val_loss: 2.1687 - 131ms/epoch - 131ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 196/2000\n",
      "1/1 - 0s - loss: 2.1605 - val_loss: 2.1672 - 163ms/epoch - 163ms/step\n",
      "Epoch 197/2000\n",
      "1/1 - 0s - loss: 2.1589 - val_loss: 2.1656 - 136ms/epoch - 136ms/step\n",
      "Epoch 198/2000\n",
      "1/1 - 0s - loss: 2.1573 - val_loss: 2.1637 - 154ms/epoch - 154ms/step\n",
      "Epoch 199/2000\n",
      "1/1 - 0s - loss: 2.1555 - val_loss: 2.1620 - 143ms/epoch - 143ms/step\n",
      "Epoch 200/2000\n",
      "1/1 - 0s - loss: 2.1538 - val_loss: 2.1605 - 123ms/epoch - 123ms/step\n",
      "Epoch 201/2000\n",
      "1/1 - 0s - loss: 2.1523 - val_loss: 2.1589 - 118ms/epoch - 118ms/step\n",
      "Epoch 202/2000\n",
      "1/1 - 0s - loss: 2.1507 - val_loss: 2.1572 - 125ms/epoch - 125ms/step\n",
      "Epoch 203/2000\n",
      "1/1 - 0s - loss: 2.1490 - val_loss: 2.1555 - 113ms/epoch - 113ms/step\n",
      "Epoch 204/2000\n",
      "1/1 - 0s - loss: 2.1474 - val_loss: 2.1541 - 117ms/epoch - 117ms/step\n",
      "Epoch 205/2000\n",
      "1/1 - 0s - loss: 2.1459 - val_loss: 2.1526 - 123ms/epoch - 123ms/step\n",
      "Epoch 206/2000\n",
      "1/1 - 0s - loss: 2.1444 - val_loss: 2.1510 - 116ms/epoch - 116ms/step\n",
      "Epoch 207/2000\n",
      "1/1 - 0s - loss: 2.1428 - val_loss: 2.1494 - 117ms/epoch - 117ms/step\n",
      "Epoch 208/2000\n",
      "1/1 - 0s - loss: 2.1413 - val_loss: 2.1480 - 122ms/epoch - 122ms/step\n",
      "Epoch 209/2000\n",
      "1/1 - 0s - loss: 2.1399 - val_loss: 2.1465 - 147ms/epoch - 147ms/step\n",
      "Epoch 210/2000\n",
      "1/1 - 0s - loss: 2.1384 - val_loss: 2.1451 - 135ms/epoch - 135ms/step\n",
      "Epoch 211/2000\n",
      "1/1 - 0s - loss: 2.1370 - val_loss: 2.1436 - 132ms/epoch - 132ms/step\n",
      "Epoch 212/2000\n",
      "1/1 - 0s - loss: 2.1355 - val_loss: 2.1422 - 130ms/epoch - 130ms/step\n",
      "Epoch 213/2000\n",
      "1/1 - 0s - loss: 2.1341 - val_loss: 2.1408 - 113ms/epoch - 113ms/step\n",
      "Epoch 214/2000\n",
      "1/1 - 0s - loss: 2.1327 - val_loss: 2.1394 - 108ms/epoch - 108ms/step\n",
      "Epoch 215/2000\n",
      "1/1 - 0s - loss: 2.1314 - val_loss: 2.1381 - 138ms/epoch - 138ms/step\n",
      "Epoch 216/2000\n",
      "1/1 - 0s - loss: 2.1300 - val_loss: 2.1367 - 140ms/epoch - 140ms/step\n",
      "Epoch 217/2000\n",
      "1/1 - 0s - loss: 2.1286 - val_loss: 2.1353 - 113ms/epoch - 113ms/step\n",
      "Epoch 218/2000\n",
      "1/1 - 0s - loss: 2.1273 - val_loss: 2.1340 - 108ms/epoch - 108ms/step\n",
      "Epoch 219/2000\n",
      "1/1 - 0s - loss: 2.1260 - val_loss: 2.1327 - 106ms/epoch - 106ms/step\n",
      "Epoch 220/2000\n",
      "1/1 - 0s - loss: 2.1247 - val_loss: 2.1315 - 117ms/epoch - 117ms/step\n",
      "Epoch 221/2000\n",
      "1/1 - 0s - loss: 2.1234 - val_loss: 2.1302 - 152ms/epoch - 152ms/step\n",
      "Epoch 222/2000\n",
      "1/1 - 0s - loss: 2.1222 - val_loss: 2.1289 - 120ms/epoch - 120ms/step\n",
      "Epoch 223/2000\n",
      "1/1 - 0s - loss: 2.1209 - val_loss: 2.1277 - 130ms/epoch - 130ms/step\n",
      "Epoch 224/2000\n",
      "1/1 - 0s - loss: 2.1196 - val_loss: 2.1265 - 118ms/epoch - 118ms/step\n",
      "Epoch 225/2000\n",
      "1/1 - 0s - loss: 2.1184 - val_loss: 2.1253 - 156ms/epoch - 156ms/step\n",
      "Epoch 226/2000\n",
      "1/1 - 0s - loss: 2.1172 - val_loss: 2.1241 - 121ms/epoch - 121ms/step\n",
      "Epoch 227/2000\n",
      "1/1 - 0s - loss: 2.1160 - val_loss: 2.1229 - 133ms/epoch - 133ms/step\n",
      "Epoch 228/2000\n",
      "1/1 - 0s - loss: 2.1148 - val_loss: 2.1217 - 135ms/epoch - 135ms/step\n",
      "Epoch 229/2000\n",
      "1/1 - 0s - loss: 2.1137 - val_loss: 2.1206 - 163ms/epoch - 163ms/step\n",
      "Epoch 230/2000\n",
      "1/1 - 0s - loss: 2.1125 - val_loss: 2.1194 - 138ms/epoch - 138ms/step\n",
      "Epoch 231/2000\n",
      "1/1 - 0s - loss: 2.1114 - val_loss: 2.1183 - 146ms/epoch - 146ms/step\n",
      "Epoch 232/2000\n",
      "1/1 - 0s - loss: 2.1103 - val_loss: 2.1172 - 136ms/epoch - 136ms/step\n",
      "Epoch 233/2000\n",
      "1/1 - 0s - loss: 2.1091 - val_loss: 2.1161 - 146ms/epoch - 146ms/step\n",
      "Epoch 234/2000\n",
      "1/1 - 0s - loss: 2.1080 - val_loss: 2.1150 - 150ms/epoch - 150ms/step\n",
      "Epoch 235/2000\n",
      "1/1 - 0s - loss: 2.1070 - val_loss: 2.1140 - 151ms/epoch - 151ms/step\n",
      "Epoch 236/2000\n",
      "1/1 - 0s - loss: 2.1059 - val_loss: 2.1129 - 173ms/epoch - 173ms/step\n",
      "Epoch 237/2000\n",
      "1/1 - 0s - loss: 2.1049 - val_loss: 2.1119 - 148ms/epoch - 148ms/step\n",
      "Epoch 238/2000\n",
      "1/1 - 0s - loss: 2.1038 - val_loss: 2.1110 - 153ms/epoch - 153ms/step\n",
      "Epoch 239/2000\n",
      "1/1 - 0s - loss: 2.1029 - val_loss: 2.1101 - 154ms/epoch - 154ms/step\n",
      "Epoch 240/2000\n",
      "1/1 - 0s - loss: 2.1020 - val_loss: 2.1093 - 148ms/epoch - 148ms/step\n",
      "Epoch 241/2000\n",
      "1/1 - 0s - loss: 2.1013 - val_loss: 2.1090 - 169ms/epoch - 169ms/step\n",
      "Epoch 242/2000\n",
      "1/1 - 0s - loss: 2.1008 - val_loss: 2.1087 - 173ms/epoch - 173ms/step\n",
      "Epoch 243/2000\n",
      "1/1 - 0s - loss: 2.1007 - val_loss: 2.1094 - 136ms/epoch - 136ms/step\n",
      "Epoch 244/2000\n",
      "1/1 - 0s - loss: 2.1013 - val_loss: 2.1086 - 143ms/epoch - 143ms/step\n",
      "Epoch 245/2000\n",
      "1/1 - 0s - loss: 2.1006 - val_loss: 2.1074 - 134ms/epoch - 134ms/step\n",
      "Epoch 246/2000\n",
      "1/1 - 0s - loss: 2.0993 - val_loss: 2.1040 - 147ms/epoch - 147ms/step\n",
      "Epoch 247/2000\n",
      "1/1 - 0s - loss: 2.0958 - val_loss: 2.1026 - 163ms/epoch - 163ms/step\n",
      "Epoch 248/2000\n",
      "1/1 - 0s - loss: 2.0944 - val_loss: 2.1032 - 169ms/epoch - 169ms/step\n",
      "Epoch 249/2000\n",
      "1/1 - 0s - loss: 2.0951 - val_loss: 2.1017 - 176ms/epoch - 176ms/step\n",
      "Epoch 250/2000\n",
      "1/1 - 0s - loss: 2.0936 - val_loss: 2.0997 - 151ms/epoch - 151ms/step\n",
      "Epoch 251/2000\n",
      "1/1 - 0s - loss: 2.0915 - val_loss: 2.0993 - 158ms/epoch - 158ms/step\n",
      "Epoch 252/2000\n",
      "1/1 - 0s - loss: 2.0911 - val_loss: 2.0986 - 156ms/epoch - 156ms/step\n",
      "Epoch 253/2000\n",
      "1/1 - 0s - loss: 2.0904 - val_loss: 2.0971 - 144ms/epoch - 144ms/step\n",
      "Epoch 254/2000\n",
      "1/1 - 0s - loss: 2.0888 - val_loss: 2.0964 - 171ms/epoch - 171ms/step\n",
      "Epoch 255/2000\n",
      "1/1 - 0s - loss: 2.0881 - val_loss: 2.0956 - 138ms/epoch - 138ms/step\n",
      "Epoch 256/2000\n",
      "1/1 - 0s - loss: 2.0873 - val_loss: 2.0944 - 138ms/epoch - 138ms/step\n",
      "Epoch 257/2000\n",
      "1/1 - 0s - loss: 2.0860 - val_loss: 2.0937 - 130ms/epoch - 130ms/step\n",
      "Epoch 258/2000\n",
      "1/1 - 0s - loss: 2.0854 - val_loss: 2.0929 - 133ms/epoch - 133ms/step\n",
      "Epoch 259/2000\n",
      "1/1 - 0s - loss: 2.0846 - val_loss: 2.0918 - 141ms/epoch - 141ms/step\n",
      "Epoch 260/2000\n",
      "1/1 - 0s - loss: 2.0835 - val_loss: 2.0912 - 138ms/epoch - 138ms/step\n",
      "Epoch 261/2000\n",
      "1/1 - 0s - loss: 2.0828 - val_loss: 2.0903 - 134ms/epoch - 134ms/step\n",
      "Epoch 262/2000\n",
      "1/1 - 0s - loss: 2.0820 - val_loss: 2.0894 - 130ms/epoch - 130ms/step\n",
      "Epoch 263/2000\n",
      "1/1 - 0s - loss: 2.0810 - val_loss: 2.0888 - 131ms/epoch - 131ms/step\n",
      "Epoch 264/2000\n",
      "1/1 - 0s - loss: 2.0804 - val_loss: 2.0879 - 134ms/epoch - 134ms/step\n",
      "Epoch 265/2000\n",
      "1/1 - 0s - loss: 2.0795 - val_loss: 2.0870 - 130ms/epoch - 130ms/step\n",
      "Epoch 266/2000\n",
      "1/1 - 0s - loss: 2.0786 - val_loss: 2.0864 - 139ms/epoch - 139ms/step\n",
      "Epoch 267/2000\n",
      "1/1 - 0s - loss: 2.0780 - val_loss: 2.0856 - 131ms/epoch - 131ms/step\n",
      "Epoch 268/2000\n",
      "1/1 - 0s - loss: 2.0772 - val_loss: 2.0848 - 149ms/epoch - 149ms/step\n",
      "Epoch 269/2000\n",
      "1/1 - 0s - loss: 2.0764 - val_loss: 2.0842 - 130ms/epoch - 130ms/step\n",
      "Epoch 270/2000\n",
      "1/1 - 0s - loss: 2.0757 - val_loss: 2.0834 - 127ms/epoch - 127ms/step\n",
      "Epoch 271/2000\n",
      "1/1 - 0s - loss: 2.0750 - val_loss: 2.0827 - 130ms/epoch - 130ms/step\n",
      "Epoch 272/2000\n",
      "1/1 - 0s - loss: 2.0742 - val_loss: 2.0820 - 128ms/epoch - 128ms/step\n",
      "Epoch 273/2000\n",
      "1/1 - 0s - loss: 2.0736 - val_loss: 2.0813 - 117ms/epoch - 117ms/step\n",
      "Epoch 274/2000\n",
      "1/1 - 0s - loss: 2.0728 - val_loss: 2.0806 - 126ms/epoch - 126ms/step\n",
      "Epoch 275/2000\n",
      "1/1 - 0s - loss: 2.0721 - val_loss: 2.0799 - 147ms/epoch - 147ms/step\n",
      "Epoch 276/2000\n",
      "1/1 - 0s - loss: 2.0714 - val_loss: 2.0792 - 147ms/epoch - 147ms/step\n",
      "Epoch 277/2000\n",
      "1/1 - 0s - loss: 2.0707 - val_loss: 2.0786 - 147ms/epoch - 147ms/step\n",
      "Epoch 278/2000\n",
      "1/1 - 0s - loss: 2.0700 - val_loss: 2.0779 - 149ms/epoch - 149ms/step\n",
      "Epoch 279/2000\n",
      "1/1 - 0s - loss: 2.0694 - val_loss: 2.0773 - 146ms/epoch - 146ms/step\n",
      "Epoch 280/2000\n",
      "1/1 - 0s - loss: 2.0687 - val_loss: 2.0766 - 145ms/epoch - 145ms/step\n",
      "Epoch 281/2000\n",
      "1/1 - 0s - loss: 2.0681 - val_loss: 2.0760 - 152ms/epoch - 152ms/step\n",
      "Epoch 282/2000\n",
      "1/1 - 0s - loss: 2.0674 - val_loss: 2.0754 - 129ms/epoch - 129ms/step\n",
      "Epoch 283/2000\n",
      "1/1 - 0s - loss: 2.0668 - val_loss: 2.0748 - 120ms/epoch - 120ms/step\n",
      "Epoch 284/2000\n",
      "1/1 - 0s - loss: 2.0661 - val_loss: 2.0742 - 117ms/epoch - 117ms/step\n",
      "Epoch 285/2000\n",
      "1/1 - 0s - loss: 2.0655 - val_loss: 2.0736 - 125ms/epoch - 125ms/step\n",
      "Epoch 286/2000\n",
      "1/1 - 0s - loss: 2.0649 - val_loss: 2.0730 - 124ms/epoch - 124ms/step\n",
      "Epoch 287/2000\n",
      "1/1 - 0s - loss: 2.0643 - val_loss: 2.0724 - 116ms/epoch - 116ms/step\n",
      "Epoch 288/2000\n",
      "1/1 - 0s - loss: 2.0637 - val_loss: 2.0718 - 116ms/epoch - 116ms/step\n",
      "Epoch 289/2000\n",
      "1/1 - 0s - loss: 2.0631 - val_loss: 2.0712 - 156ms/epoch - 156ms/step\n",
      "Epoch 290/2000\n",
      "1/1 - 0s - loss: 2.0625 - val_loss: 2.0707 - 133ms/epoch - 133ms/step\n",
      "Epoch 291/2000\n",
      "1/1 - 0s - loss: 2.0619 - val_loss: 2.0701 - 116ms/epoch - 116ms/step\n",
      "Epoch 292/2000\n",
      "1/1 - 0s - loss: 2.0613 - val_loss: 2.0696 - 123ms/epoch - 123ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 293/2000\n",
      "1/1 - 0s - loss: 2.0608 - val_loss: 2.0690 - 128ms/epoch - 128ms/step\n",
      "Epoch 294/2000\n",
      "1/1 - 0s - loss: 2.0602 - val_loss: 2.0685 - 118ms/epoch - 118ms/step\n",
      "Epoch 295/2000\n",
      "1/1 - 0s - loss: 2.0596 - val_loss: 2.0680 - 114ms/epoch - 114ms/step\n",
      "Epoch 296/2000\n",
      "1/1 - 0s - loss: 2.0591 - val_loss: 2.0674 - 130ms/epoch - 130ms/step\n",
      "Epoch 297/2000\n",
      "1/1 - 0s - loss: 2.0585 - val_loss: 2.0669 - 139ms/epoch - 139ms/step\n",
      "Epoch 298/2000\n",
      "1/1 - 0s - loss: 2.0580 - val_loss: 2.0664 - 113ms/epoch - 113ms/step\n",
      "Epoch 299/2000\n",
      "1/1 - 0s - loss: 2.0575 - val_loss: 2.0659 - 130ms/epoch - 130ms/step\n",
      "Epoch 300/2000\n",
      "1/1 - 0s - loss: 2.0569 - val_loss: 2.0654 - 103ms/epoch - 103ms/step\n",
      "Epoch 301/2000\n",
      "1/1 - 0s - loss: 2.0564 - val_loss: 2.0649 - 109ms/epoch - 109ms/step\n",
      "Epoch 302/2000\n",
      "1/1 - 0s - loss: 2.0559 - val_loss: 2.0644 - 111ms/epoch - 111ms/step\n",
      "Epoch 303/2000\n",
      "1/1 - 0s - loss: 2.0554 - val_loss: 2.0639 - 144ms/epoch - 144ms/step\n",
      "Epoch 304/2000\n",
      "1/1 - 0s - loss: 2.0549 - val_loss: 2.0634 - 113ms/epoch - 113ms/step\n",
      "Epoch 305/2000\n",
      "1/1 - 0s - loss: 2.0544 - val_loss: 2.0629 - 113ms/epoch - 113ms/step\n",
      "Epoch 306/2000\n",
      "1/1 - 0s - loss: 2.0539 - val_loss: 2.0625 - 108ms/epoch - 108ms/step\n",
      "Epoch 307/2000\n",
      "1/1 - 0s - loss: 2.0534 - val_loss: 2.0620 - 130ms/epoch - 130ms/step\n",
      "Epoch 308/2000\n",
      "1/1 - 0s - loss: 2.0529 - val_loss: 2.0615 - 133ms/epoch - 133ms/step\n",
      "Epoch 309/2000\n",
      "1/1 - 0s - loss: 2.0524 - val_loss: 2.0611 - 134ms/epoch - 134ms/step\n",
      "Epoch 310/2000\n",
      "1/1 - 0s - loss: 2.0519 - val_loss: 2.0606 - 128ms/epoch - 128ms/step\n",
      "Epoch 311/2000\n",
      "1/1 - 0s - loss: 2.0514 - val_loss: 2.0602 - 130ms/epoch - 130ms/step\n",
      "Epoch 312/2000\n",
      "1/1 - 0s - loss: 2.0510 - val_loss: 2.0597 - 127ms/epoch - 127ms/step\n",
      "Epoch 313/2000\n",
      "1/1 - 0s - loss: 2.0505 - val_loss: 2.0593 - 134ms/epoch - 134ms/step\n",
      "Epoch 314/2000\n",
      "1/1 - 0s - loss: 2.0500 - val_loss: 2.0588 - 130ms/epoch - 130ms/step\n",
      "Epoch 315/2000\n",
      "1/1 - 0s - loss: 2.0496 - val_loss: 2.0584 - 139ms/epoch - 139ms/step\n",
      "Epoch 316/2000\n",
      "1/1 - 0s - loss: 2.0491 - val_loss: 2.0580 - 128ms/epoch - 128ms/step\n",
      "Epoch 317/2000\n",
      "1/1 - 0s - loss: 2.0487 - val_loss: 2.0575 - 127ms/epoch - 127ms/step\n",
      "Epoch 318/2000\n",
      "1/1 - 0s - loss: 2.0483 - val_loss: 2.0571 - 133ms/epoch - 133ms/step\n",
      "Epoch 319/2000\n",
      "1/1 - 0s - loss: 2.0478 - val_loss: 2.0567 - 128ms/epoch - 128ms/step\n",
      "Epoch 320/2000\n",
      "1/1 - 0s - loss: 2.0474 - val_loss: 2.0563 - 145ms/epoch - 145ms/step\n",
      "Epoch 321/2000\n",
      "1/1 - 0s - loss: 2.0470 - val_loss: 2.0559 - 135ms/epoch - 135ms/step\n",
      "Epoch 322/2000\n",
      "1/1 - 0s - loss: 2.0465 - val_loss: 2.0555 - 139ms/epoch - 139ms/step\n",
      "Epoch 323/2000\n",
      "1/1 - 0s - loss: 2.0461 - val_loss: 2.0551 - 137ms/epoch - 137ms/step\n",
      "Epoch 324/2000\n",
      "1/1 - 0s - loss: 2.0457 - val_loss: 2.0547 - 149ms/epoch - 149ms/step\n",
      "Epoch 325/2000\n",
      "1/1 - 0s - loss: 2.0453 - val_loss: 2.0543 - 137ms/epoch - 137ms/step\n",
      "Epoch 326/2000\n",
      "1/1 - 0s - loss: 2.0449 - val_loss: 2.0539 - 132ms/epoch - 132ms/step\n",
      "Epoch 327/2000\n",
      "1/1 - 0s - loss: 2.0445 - val_loss: 2.0536 - 139ms/epoch - 139ms/step\n",
      "Epoch 328/2000\n",
      "1/1 - 0s - loss: 2.0441 - val_loss: 2.0532 - 136ms/epoch - 136ms/step\n",
      "Epoch 329/2000\n",
      "1/1 - 0s - loss: 2.0437 - val_loss: 2.0528 - 131ms/epoch - 131ms/step\n",
      "Epoch 330/2000\n",
      "1/1 - 0s - loss: 2.0433 - val_loss: 2.0525 - 128ms/epoch - 128ms/step\n",
      "Epoch 331/2000\n",
      "1/1 - 0s - loss: 2.0429 - val_loss: 2.0521 - 139ms/epoch - 139ms/step\n",
      "Epoch 332/2000\n",
      "1/1 - 0s - loss: 2.0425 - val_loss: 2.0517 - 148ms/epoch - 148ms/step\n",
      "Epoch 333/2000\n",
      "1/1 - 0s - loss: 2.0421 - val_loss: 2.0514 - 145ms/epoch - 145ms/step\n",
      "Epoch 334/2000\n",
      "1/1 - 0s - loss: 2.0417 - val_loss: 2.0510 - 149ms/epoch - 149ms/step\n",
      "Epoch 335/2000\n",
      "1/1 - 0s - loss: 2.0414 - val_loss: 2.0507 - 181ms/epoch - 181ms/step\n",
      "Epoch 336/2000\n",
      "1/1 - 0s - loss: 2.0410 - val_loss: 2.0503 - 147ms/epoch - 147ms/step\n",
      "Epoch 337/2000\n",
      "1/1 - 0s - loss: 2.0406 - val_loss: 2.0500 - 139ms/epoch - 139ms/step\n",
      "Epoch 338/2000\n",
      "1/1 - 0s - loss: 2.0403 - val_loss: 2.0496 - 186ms/epoch - 186ms/step\n",
      "Epoch 339/2000\n",
      "1/1 - 0s - loss: 2.0399 - val_loss: 2.0493 - 149ms/epoch - 149ms/step\n",
      "Epoch 340/2000\n",
      "1/1 - 0s - loss: 2.0395 - val_loss: 2.0490 - 146ms/epoch - 146ms/step\n",
      "Epoch 341/2000\n",
      "1/1 - 0s - loss: 2.0392 - val_loss: 2.0487 - 153ms/epoch - 153ms/step\n",
      "Epoch 342/2000\n",
      "1/1 - 0s - loss: 2.0388 - val_loss: 2.0483 - 172ms/epoch - 172ms/step\n",
      "Epoch 343/2000\n",
      "1/1 - 0s - loss: 2.0385 - val_loss: 2.0480 - 149ms/epoch - 149ms/step\n",
      "Epoch 344/2000\n",
      "1/1 - 0s - loss: 2.0381 - val_loss: 2.0477 - 146ms/epoch - 146ms/step\n",
      "Epoch 345/2000\n",
      "1/1 - 0s - loss: 2.0378 - val_loss: 2.0474 - 152ms/epoch - 152ms/step\n",
      "Epoch 346/2000\n",
      "1/1 - 0s - loss: 2.0375 - val_loss: 2.0471 - 159ms/epoch - 159ms/step\n",
      "Epoch 347/2000\n",
      "1/1 - 0s - loss: 2.0372 - val_loss: 2.0468 - 185ms/epoch - 185ms/step\n",
      "Epoch 348/2000\n",
      "1/1 - 0s - loss: 2.0369 - val_loss: 2.0466 - 168ms/epoch - 168ms/step\n",
      "Epoch 349/2000\n",
      "1/1 - 0s - loss: 2.0366 - val_loss: 2.0465 - 144ms/epoch - 144ms/step\n",
      "Epoch 350/2000\n",
      "1/1 - 0s - loss: 2.0365 - val_loss: 2.0466 - 150ms/epoch - 150ms/step\n",
      "Epoch 351/2000\n",
      "1/1 - 0s - loss: 2.0366 - val_loss: 2.0470 - 143ms/epoch - 143ms/step\n",
      "Epoch 352/2000\n",
      "1/1 - 0s - loss: 2.0370 - val_loss: 2.0485 - 124ms/epoch - 124ms/step\n",
      "Epoch 353/2000\n",
      "1/1 - 0s - loss: 2.0384 - val_loss: 2.0498 - 151ms/epoch - 151ms/step\n",
      "Epoch 354/2000\n",
      "1/1 - 0s - loss: 2.0397 - val_loss: 2.0502 - 173ms/epoch - 173ms/step\n"
     ]
    }
   ],
   "source": [
    "# Model width for this run; each attention head gets embed_dim // num_head dims.\n",
    "embed_dim = 10\n",
    "num_head = 2\n",
    "num_encoder_blocks = 5\n",
    "seq_len = x_masked_train.shape[1]\n",
    "\n",
    "# Stop once val_loss has not improved for 5 epochs, then roll back to the best weights.\n",
    "early_stopping = keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)\n",
    "\n",
    "# Token ids -> word embedding plus learned position embedding.\n",
    "inputs = layers.Input((seq_len,), dtype=tf.int64)\n",
    "word_embeddings = layers.Embedding(N, embed_dim, name=\"word_embedding\")(inputs)\n",
    "position_embeddings = PositionEmbedding(sequence_length=seq_len)(word_embeddings)\n",
    "hidden = word_embeddings + position_embeddings\n",
    "\n",
    "# Stacked self-attention blocks (query = key = value = previous block's output).\n",
    "for block_idx in range(num_encoder_blocks):\n",
    "    hidden = bert_module(hidden, hidden, hidden, embed_dim, num_head, block_idx)\n",
    "\n",
    "# Drop the last time step so the head emits seq_len - 1 positions\n",
    "# (assumed to match the width of y_masked_labels_train -- confirm).\n",
    "hidden = keras.layers.Lambda(lambda x: x[:, :-1, :], name='slice')(hidden)\n",
    "mlm_output = layers.Dense(N, name=\"mlm_cls\", activation=\"softmax\")(hidden)\n",
    "mlm_model = keras.Model(inputs=inputs, outputs=mlm_output)\n",
    "mlm_model.compile(loss='sparse_categorical_crossentropy', optimizer=Adam())\n",
    "\n",
    "# NOTE(review): validation_split holds out the *last* 50% of rows (Keras picks them\n",
    "# before any shuffling). If x_masked_train is stored in sorted order, train and\n",
    "# validation cover different token ranges -- consider a shuffled split.\n",
    "history = mlm_model.fit(\n",
    "    x_masked_train, y_masked_labels_train,\n",
    "    validation_split=0.5,\n",
    "    callbacks=[early_stopping],\n",
    "    epochs=2000,\n",
    "    batch_size=5000,\n",
    "    verbose=2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "aa33aaff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score 1000 randomly chosen test sequences on their final token.\n",
    "n_eval = 1000\n",
    "eval_idx = np.random.choice(x_masked_test.shape[0], size=n_eval, replace=False)\n",
    "x_test_subset = x_masked_test[eval_idx]\n",
    "\n",
    "# One batched forward pass instead of building a new backend function per sequence.\n",
    "# The model's last output position (after its 'slice' layer) is scored against\n",
    "# each sequence's final input token.\n",
    "last_step_probs = mlm_model.predict(x_test_subset, verbose=0)[:, -1, :]\n",
    "targets = x_test_subset[:, -1]\n",
    "\n",
    "# acc: 1 where the top-scoring token equals the target, else 0.\n",
    "# prob: probability the model assigns to the target token.\n",
    "acc = (last_step_probs.argmax(axis=-1) == targets).astype(int)\n",
    "prob = last_step_probs[np.arange(len(targets)), targets]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "fe4d2103",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.0, 0.018784668)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# (final-token accuracy, mean probability assigned to the final token)\n",
    "np.mean(acc), np.mean(prob)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "0005b730",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2000\n",
      "1/1 - 6s - loss: 3.5969 - val_loss: 3.2553 - 6s/epoch - 6s/step\n",
      "Epoch 2/2000\n",
      "1/1 - 0s - loss: 3.2400 - val_loss: 3.0435 - 442ms/epoch - 442ms/step\n",
      "Epoch 3/2000\n",
      "1/1 - 0s - loss: 3.0411 - val_loss: 2.9438 - 432ms/epoch - 432ms/step\n",
      "Epoch 4/2000\n",
      "1/1 - 0s - loss: 2.9469 - val_loss: 2.9176 - 431ms/epoch - 431ms/step\n",
      "Epoch 5/2000\n",
      "1/1 - 1s - loss: 2.9201 - val_loss: 2.8815 - 513ms/epoch - 513ms/step\n",
      "Epoch 6/2000\n",
      "1/1 - 1s - loss: 2.8821 - val_loss: 2.8532 - 506ms/epoch - 506ms/step\n",
      "Epoch 7/2000\n",
      "1/1 - 0s - loss: 2.8509 - val_loss: 2.8274 - 471ms/epoch - 471ms/step\n",
      "Epoch 8/2000\n",
      "1/1 - 0s - loss: 2.8212 - val_loss: 2.8148 - 495ms/epoch - 495ms/step\n",
      "Epoch 9/2000\n",
      "1/1 - 0s - loss: 2.8052 - val_loss: 2.8164 - 466ms/epoch - 466ms/step\n",
      "Epoch 10/2000\n",
      "1/1 - 0s - loss: 2.8044 - val_loss: 2.8180 - 451ms/epoch - 451ms/step\n",
      "Epoch 11/2000\n",
      "1/1 - 0s - loss: 2.8050 - val_loss: 2.8127 - 457ms/epoch - 457ms/step\n",
      "Epoch 12/2000\n",
      "1/1 - 0s - loss: 2.7998 - val_loss: 2.8032 - 483ms/epoch - 483ms/step\n",
      "Epoch 13/2000\n",
      "1/1 - 0s - loss: 2.7911 - val_loss: 2.7920 - 482ms/epoch - 482ms/step\n",
      "Epoch 14/2000\n",
      "1/1 - 0s - loss: 2.7809 - val_loss: 2.7819 - 456ms/epoch - 456ms/step\n",
      "Epoch 15/2000\n",
      "1/1 - 0s - loss: 2.7717 - val_loss: 2.7757 - 420ms/epoch - 420ms/step\n",
      "Epoch 16/2000\n",
      "1/1 - 0s - loss: 2.7667 - val_loss: 2.7741 - 451ms/epoch - 451ms/step\n",
      "Epoch 17/2000\n",
      "1/1 - 0s - loss: 2.7659 - val_loss: 2.7735 - 414ms/epoch - 414ms/step\n",
      "Epoch 18/2000\n",
      "1/1 - 0s - loss: 2.7656 - val_loss: 2.7716 - 416ms/epoch - 416ms/step\n",
      "Epoch 19/2000\n",
      "1/1 - 0s - loss: 2.7631 - val_loss: 2.7684 - 425ms/epoch - 425ms/step\n",
      "Epoch 20/2000\n",
      "1/1 - 0s - loss: 2.7586 - val_loss: 2.7652 - 434ms/epoch - 434ms/step\n",
      "Epoch 21/2000\n",
      "1/1 - 0s - loss: 2.7538 - val_loss: 2.7626 - 447ms/epoch - 447ms/step\n",
      "Epoch 22/2000\n",
      "1/1 - 0s - loss: 2.7497 - val_loss: 2.7596 - 456ms/epoch - 456ms/step\n",
      "Epoch 23/2000\n",
      "1/1 - 0s - loss: 2.7457 - val_loss: 2.7553 - 436ms/epoch - 436ms/step\n",
      "Epoch 24/2000\n",
      "1/1 - 0s - loss: 2.7406 - val_loss: 2.7493 - 475ms/epoch - 475ms/step\n",
      "Epoch 25/2000\n",
      "1/1 - 0s - loss: 2.7340 - val_loss: 2.7418 - 446ms/epoch - 446ms/step\n",
      "Epoch 26/2000\n",
      "1/1 - 0s - loss: 2.7260 - val_loss: 2.7325 - 432ms/epoch - 432ms/step\n",
      "Epoch 27/2000\n",
      "1/1 - 0s - loss: 2.7162 - val_loss: 2.7205 - 442ms/epoch - 442ms/step\n",
      "Epoch 28/2000\n",
      "1/1 - 0s - loss: 2.7036 - val_loss: 2.7033 - 447ms/epoch - 447ms/step\n",
      "Epoch 29/2000\n",
      "1/1 - 0s - loss: 2.6855 - val_loss: 2.6783 - 451ms/epoch - 451ms/step\n",
      "Epoch 30/2000\n",
      "1/1 - 0s - loss: 2.6590 - val_loss: 2.6427 - 464ms/epoch - 464ms/step\n",
      "Epoch 31/2000\n",
      "1/1 - 0s - loss: 2.6220 - val_loss: 2.5909 - 466ms/epoch - 466ms/step\n",
      "Epoch 32/2000\n",
      "1/1 - 0s - loss: 2.5687 - val_loss: 2.5164 - 462ms/epoch - 462ms/step\n",
      "Epoch 33/2000\n",
      "1/1 - 1s - loss: 2.4923 - val_loss: 2.4199 - 550ms/epoch - 550ms/step\n",
      "Epoch 34/2000\n",
      "1/1 - 0s - loss: 2.3937 - val_loss: 2.3193 - 474ms/epoch - 474ms/step\n",
      "Epoch 35/2000\n",
      "1/1 - 1s - loss: 2.2933 - val_loss: 2.2293 - 522ms/epoch - 522ms/step\n",
      "Epoch 36/2000\n",
      "1/1 - 1s - loss: 2.2038 - val_loss: 2.1627 - 558ms/epoch - 558ms/step\n",
      "Epoch 37/2000\n",
      "1/1 - 0s - loss: 2.1368 - val_loss: 2.1303 - 468ms/epoch - 468ms/step\n",
      "Epoch 38/2000\n",
      "1/1 - 0s - loss: 2.1012 - val_loss: 2.1098 - 461ms/epoch - 461ms/step\n",
      "Epoch 39/2000\n",
      "1/1 - 0s - loss: 2.0801 - val_loss: 2.0896 - 437ms/epoch - 437ms/step\n",
      "Epoch 40/2000\n",
      "1/1 - 0s - loss: 2.0600 - val_loss: 2.0733 - 422ms/epoch - 422ms/step\n",
      "Epoch 41/2000\n",
      "1/1 - 0s - loss: 2.0441 - val_loss: 2.0629 - 435ms/epoch - 435ms/step\n",
      "Epoch 42/2000\n",
      "1/1 - 0s - loss: 2.0345 - val_loss: 2.0536 - 435ms/epoch - 435ms/step\n",
      "Epoch 43/2000\n",
      "1/1 - 0s - loss: 2.0267 - val_loss: 2.0435 - 425ms/epoch - 425ms/step\n",
      "Epoch 44/2000\n",
      "1/1 - 0s - loss: 2.0187 - val_loss: 2.0340 - 409ms/epoch - 409ms/step\n",
      "Epoch 45/2000\n",
      "1/1 - 0s - loss: 2.0115 - val_loss: 2.0272 - 442ms/epoch - 442ms/step\n",
      "Epoch 46/2000\n",
      "1/1 - 0s - loss: 2.0069 - val_loss: 2.0227 - 486ms/epoch - 486ms/step\n",
      "Epoch 47/2000\n",
      "1/1 - 1s - loss: 2.0045 - val_loss: 2.0191 - 511ms/epoch - 511ms/step\n",
      "Epoch 48/2000\n",
      "1/1 - 0s - loss: 2.0026 - val_loss: 2.0161 - 473ms/epoch - 473ms/step\n",
      "Epoch 49/2000\n",
      "1/1 - 0s - loss: 2.0007 - val_loss: 2.0139 - 451ms/epoch - 451ms/step\n",
      "Epoch 50/2000\n",
      "1/1 - 0s - loss: 1.9991 - val_loss: 2.0126 - 468ms/epoch - 468ms/step\n",
      "Epoch 51/2000\n",
      "1/1 - 0s - loss: 1.9975 - val_loss: 2.0117 - 431ms/epoch - 431ms/step\n",
      "Epoch 52/2000\n",
      "1/1 - 0s - loss: 1.9960 - val_loss: 2.0113 - 406ms/epoch - 406ms/step\n",
      "Epoch 53/2000\n",
      "1/1 - 0s - loss: 1.9944 - val_loss: 2.0111 - 417ms/epoch - 417ms/step\n",
      "Epoch 54/2000\n",
      "1/1 - 0s - loss: 1.9931 - val_loss: 2.0109 - 425ms/epoch - 425ms/step\n",
      "Epoch 55/2000\n",
      "1/1 - 0s - loss: 1.9919 - val_loss: 2.0105 - 432ms/epoch - 432ms/step\n",
      "Epoch 56/2000\n",
      "1/1 - 0s - loss: 1.9908 - val_loss: 2.0100 - 482ms/epoch - 482ms/step\n",
      "Epoch 57/2000\n",
      "1/1 - 1s - loss: 1.9899 - val_loss: 2.0095 - 505ms/epoch - 505ms/step\n",
      "Epoch 58/2000\n",
      "1/1 - 1s - loss: 1.9892 - val_loss: 2.0092 - 553ms/epoch - 553ms/step\n",
      "Epoch 59/2000\n",
      "1/1 - 1s - loss: 1.9886 - val_loss: 2.0091 - 607ms/epoch - 607ms/step\n",
      "Epoch 60/2000\n",
      "1/1 - 0s - loss: 1.9880 - val_loss: 2.0091 - 477ms/epoch - 477ms/step\n",
      "Epoch 61/2000\n",
      "1/1 - 0s - loss: 1.9873 - val_loss: 2.0092 - 469ms/epoch - 469ms/step\n",
      "Epoch 62/2000\n",
      "1/1 - 1s - loss: 1.9865 - val_loss: 2.0094 - 503ms/epoch - 503ms/step\n",
      "Epoch 63/2000\n",
      "1/1 - 0s - loss: 1.9857 - val_loss: 2.0096 - 439ms/epoch - 439ms/step\n",
      "Epoch 64/2000\n",
      "1/1 - 0s - loss: 1.9849 - val_loss: 2.0096 - 426ms/epoch - 426ms/step\n",
      "Epoch 65/2000\n",
      "1/1 - 0s - loss: 1.9840 - val_loss: 2.0095 - 459ms/epoch - 459ms/step\n"
     ]
    }
   ],
   "source": [
    "# Model width for this run; each attention head gets embed_dim // num_head dims.\n",
    "embed_dim = 100\n",
    "num_head = 2\n",
    "num_encoder_blocks = 5\n",
    "seq_len = x_masked_train.shape[1]\n",
    "\n",
    "# Stop once val_loss has not improved for 5 epochs, then roll back to the best weights.\n",
    "early_stopping = keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)\n",
    "\n",
    "# Token ids -> word embedding plus learned position embedding.\n",
    "inputs = layers.Input((seq_len,), dtype=tf.int64)\n",
    "word_embeddings = layers.Embedding(N, embed_dim, name=\"word_embedding\")(inputs)\n",
    "position_embeddings = PositionEmbedding(sequence_length=seq_len)(word_embeddings)\n",
    "hidden = word_embeddings + position_embeddings\n",
    "\n",
    "# Stacked self-attention blocks (query = key = value = previous block's output).\n",
    "for block_idx in range(num_encoder_blocks):\n",
    "    hidden = bert_module(hidden, hidden, hidden, embed_dim, num_head, block_idx)\n",
    "\n",
    "# Drop the last time step so the head emits seq_len - 1 positions\n",
    "# (assumed to match the width of y_masked_labels_train -- confirm).\n",
    "hidden = keras.layers.Lambda(lambda x: x[:, :-1, :], name='slice')(hidden)\n",
    "mlm_output = layers.Dense(N, name=\"mlm_cls\", activation=\"softmax\")(hidden)\n",
    "mlm_model = keras.Model(inputs=inputs, outputs=mlm_output)\n",
    "mlm_model.compile(loss='sparse_categorical_crossentropy', optimizer=Adam())\n",
    "\n",
    "# NOTE(review): validation_split holds out the *last* 50% of rows (Keras picks them\n",
    "# before any shuffling). If x_masked_train is stored in sorted order, train and\n",
    "# validation cover different token ranges -- consider a shuffled split.\n",
    "history = mlm_model.fit(\n",
    "    x_masked_train, y_masked_labels_train,\n",
    "    validation_split=0.5,\n",
    "    callbacks=[early_stopping],\n",
    "    epochs=2000,\n",
    "    batch_size=5000,\n",
    "    verbose=2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "beb07cf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score 1000 randomly chosen test sequences on their final token.\n",
    "n_eval = 1000\n",
    "eval_idx = np.random.choice(x_masked_test.shape[0], size=n_eval, replace=False)\n",
    "x_test_subset = x_masked_test[eval_idx]\n",
    "\n",
    "# One batched forward pass instead of building a new backend function per sequence.\n",
    "# The model's last output position (after its 'slice' layer) is scored against\n",
    "# each sequence's final input token.\n",
    "last_step_probs = mlm_model.predict(x_test_subset, verbose=0)[:, -1, :]\n",
    "targets = x_test_subset[:, -1]\n",
    "\n",
    "# acc: 1 where the top-scoring token equals the target, else 0.\n",
    "# prob: probability the model assigns to the target token.\n",
    "acc = (last_step_probs.argmax(axis=-1) == targets).astype(int)\n",
    "prob = last_step_probs[np.arange(len(targets)), targets]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "dbdd4a69",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.0, 0.007304905)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# (final-token accuracy, mean probability assigned to the final token)\n",
    "np.mean(acc), np.mean(prob)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "54f31bac",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[10,  0,  1, 10],\n",
       "       [10,  0,  2, 10],\n",
       "       [10,  0,  3, 10],\n",
       "       ...,\n",
       "       [19, 18, 15, 19],\n",
       "       [19, 18, 16, 19],\n",
       "       [19, 18, 17, 19]])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Peek at the held-out sequences (numpy elides the middle rows of large arrays).\n",
    "x_masked_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "458ae205",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
