{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d1686d53",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy.random as npr\n",
    "import random\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "from tensorflow.keras import backend as K\n",
    "from keras.optimizers import Adam\n",
    "from keras_nlp.layers import PositionEmbedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "56c107e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One seed shared by every random number generator this notebook relies on.\n",
    "seed = 428\n",
    "\n",
    "# Seed NumPy's global RNG, TensorFlow's global generator, then Python's `random`.\n",
    "for _seed_rng in (np.random.seed, tf.random.set_seed, random.seed):\n",
    "    _seed_rng(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "504a342e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def bert_module(query, key, value, embed_dim, num_head, i):\n",
    "    \"\"\"Build one Transformer block with causally masked multi-head attention.\n",
    "\n",
    "    The block chains: multi-head attention -> residual add + layer norm ->\n",
    "    two-layer feed-forward network -> residual add + layer norm.\n",
    "\n",
    "    Args:\n",
    "        query: Attention query tensor; also the input of the first skip connection.\n",
    "        key: Attention key tensor.\n",
    "        value: Attention value tensor.\n",
    "        embed_dim: Output width of the feed-forward network; its hidden layer is\n",
    "            2 * embed_dim wide and each head uses embed_dim // num_head key dims.\n",
    "        num_head: Number of attention heads.\n",
    "        i: Block index, interpolated into the explicitly named layers.\n",
    "\n",
    "    Returns:\n",
    "        Output tensor of the final layer normalization.\n",
    "    \"\"\"\n",
    "    # Multi headed attention. Keras' MultiHeadAttention is called as\n",
    "    # (query, value, key): value comes BEFORE key. Passing (query, key, value)\n",
    "    # positionally would silently swap the two whenever they differ.\n",
    "    attention_output = layers.MultiHeadAttention(\n",
    "        num_heads=num_head,\n",
    "        key_dim=embed_dim // num_head,\n",
    "        name=\"encoder_{}/multiheadattention\".format(i)\n",
    "    )(query, value, key, use_causal_mask=True)\n",
    "\n",
    "    # Add & Normalize\n",
    "    attention_output = layers.Add()([query, attention_output])  # Skip Connection\n",
    "    attention_output = layers.LayerNormalization(epsilon=1e-6)(attention_output)\n",
    "\n",
    "    # Feedforward network: expand to 2 * embed_dim with ReLU, project back to embed_dim\n",
    "    ff_net = keras.models.Sequential([\n",
    "        layers.Dense(2 * embed_dim, activation='relu', name=\"encoder_{}/ffn_dense_1\".format(i)),\n",
    "        layers.Dense(embed_dim, name=\"encoder_{}/ffn_dense_2\".format(i)),\n",
    "    ])\n",
    "\n",
    "    # Apply Feedforward network\n",
    "    ffn_output = ff_net(attention_output)\n",
    "\n",
    "    # Add & Normalize\n",
    "    ffn_output = layers.Add()([attention_output, ffn_output])  # Skip Connection\n",
    "    ffn_output = layers.LayerNormalization(epsilon=1e-6)(ffn_output)\n",
    "\n",
    "    return ffn_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fef2622a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_sinusoidal_embeddings(sequence_length, embedding_dim):\n",
    "    \"\"\"Return a fixed sinusoidal position-encoding table.\n",
    "\n",
    "    For positions 1 .. sequence_length - 1, column 2k holds\n",
    "    sin(pos / 10000 ** (2k / d)) and column 2k + 1 holds\n",
    "    cos(pos / 10000 ** (2k / d)), with d = embedding_dim, so each sin/cos\n",
    "    pair shares one frequency. Row 0 is deliberately all zeros.\n",
    "\n",
    "    Args:\n",
    "        sequence_length: Number of positions (rows); 0 yields an empty table.\n",
    "        embedding_dim: Width of the encoding (columns).\n",
    "\n",
    "    Returns:\n",
    "        A float32 tensor of shape (sequence_length, embedding_dim).\n",
    "    \"\"\"\n",
    "    positions = np.arange(sequence_length, dtype=np.float64)[:, np.newaxis]\n",
    "    # Columns 2k and 2k+1 must share the exponent 2k/d; using the raw column\n",
    "    # index instead would give every column its own frequency.\n",
    "    pair_index = np.arange(embedding_dim) // 2\n",
    "    position_enc = positions / np.power(10000, 2. * pair_index / embedding_dim)\n",
    "    position_enc[:, 0::2] = np.sin(position_enc[:, 0::2])  # dim 2i\n",
    "    position_enc[:, 1::2] = np.cos(position_enc[:, 1::2])  # dim 2i+1\n",
    "    position_enc[:1] = 0.  # position 0 stays an all-zero row (no cos(0) = 1 entries)\n",
    "    return tf.cast(position_enc, dtype=tf.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3e5bfd69",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 20 # vocab_size\n",
    "\n",
    "vocabs = ['word_' + str(i) for i in range(N)]\n",
    "\n",
    "vocab_map = {}\n",
    "for i in range(len(vocabs)):\n",
    "    vocab_map[vocabs[i]] = i\n",
    "    \n",
    "pairs = []\n",
    "\n",
    "for i in vocabs:\n",
    "    for j in vocabs:\n",
    "        for k in vocabs:\n",
    "            if i != j and i != k and j != k:\n",
    "                pairs.append((i,j,k))\n",
    "            \n",
    "#indicator = np.random.choice([0, 1], size=len(pairs), p=[0.5, 0.5])\n",
    "\n",
    "# pairs_train = [pairs[i] for i in range(len(indicator)) if indicator[i] == 1]\n",
    "# pairs_test = [pairs[i] for i in range(len(indicator)) if indicator[i] == 0]\n",
    "\n",
    "pairs_train = [x for x in pairs if int(x[0].split('_')[-1]) <= 9]\n",
    "pairs_test = [x for x in pairs if int(x[0].split('_')[-1]) >= 10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d43093fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each triple (a, b, c) becomes the 4-token sentence [a, b, c, a].\n",
    "sentences_train = [[a, b, c, a] for a, b, c in pairs_train]\n",
    "sentences_test = [[a, b, c, a] for a, b, c in pairs_test]\n",
    "\n",
    "# The same sentences with every token replaced by its integer id.\n",
    "sentences_number_train = [[vocab_map[tok] for tok in sent] for sent in sentences_train]\n",
    "sentences_number_test = [[vocab_map[tok] for tok in sent] for sent in sentences_test]\n",
    "\n",
    "# Inputs are the full 4-id sequences; labels are the same sequences without\n",
    "# their first id, i.e. the ids of [b, c, a].\n",
    "x_masked_train = np.array(sentences_number_train)\n",
    "y_masked_labels_train = np.array([ids[1:] for ids in sentences_number_train])\n",
    "x_masked_test = np.array(sentences_number_test)\n",
    "y_masked_labels_test = np.array([ids[1:] for ids in sentences_number_test])\n",
    "\n",
    "# Shuffle the training examples, applying one permutation to inputs and labels alike.\n",
    "perm = np.random.permutation(len(x_masked_train))\n",
    "x_masked_train = x_masked_train[perm]\n",
    "y_masked_labels_train = y_masked_labels_train[perm]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "13b40f89",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-05-22 07:06:20.584513: W tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory\n",
      "2024-05-22 07:06:20.584555: W tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: UNKNOWN ERROR (303)\n",
      "2024-05-22 07:06:20.584576: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (gl3035.arc-ts.umich.edu): /proc/driver/nvidia/version does not exist\n",
      "2024-05-22 07:06:20.584787: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA\n",
      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2000\n",
      "1/1 - 1s - loss: 3.3615 - val_loss: 3.3031 - 1s/epoch - 1s/step\n",
      "Epoch 2/2000\n",
      "1/1 - 0s - loss: 3.3175 - val_loss: 3.2610 - 54ms/epoch - 54ms/step\n",
      "Epoch 3/2000\n",
      "1/1 - 0s - loss: 3.2739 - val_loss: 3.2230 - 49ms/epoch - 49ms/step\n",
      "Epoch 4/2000\n",
      "1/1 - 0s - loss: 3.2332 - val_loss: 3.1907 - 53ms/epoch - 53ms/step\n",
      "Epoch 5/2000\n",
      "1/1 - 0s - loss: 3.1976 - val_loss: 3.1645 - 62ms/epoch - 62ms/step\n",
      "Epoch 6/2000\n",
      "1/1 - 0s - loss: 3.1674 - val_loss: 3.1434 - 57ms/epoch - 57ms/step\n",
      "Epoch 7/2000\n",
      "1/1 - 0s - loss: 3.1424 - val_loss: 3.1264 - 53ms/epoch - 53ms/step\n",
      "Epoch 8/2000\n",
      "1/1 - 0s - loss: 3.1220 - val_loss: 3.1128 - 62ms/epoch - 62ms/step\n",
      "Epoch 9/2000\n",
      "1/1 - 0s - loss: 3.1053 - val_loss: 3.1016 - 53ms/epoch - 53ms/step\n",
      "Epoch 10/2000\n",
      "1/1 - 0s - loss: 3.0915 - val_loss: 3.0921 - 79ms/epoch - 79ms/step\n",
      "Epoch 11/2000\n",
      "1/1 - 0s - loss: 3.0798 - val_loss: 3.0836 - 60ms/epoch - 60ms/step\n",
      "Epoch 12/2000\n",
      "1/1 - 0s - loss: 3.0698 - val_loss: 3.0760 - 64ms/epoch - 64ms/step\n",
      "Epoch 13/2000\n",
      "1/1 - 0s - loss: 3.0608 - val_loss: 3.0689 - 59ms/epoch - 59ms/step\n",
      "Epoch 14/2000\n",
      "1/1 - 0s - loss: 3.0527 - val_loss: 3.0621 - 56ms/epoch - 56ms/step\n",
      "Epoch 15/2000\n",
      "1/1 - 0s - loss: 3.0451 - val_loss: 3.0556 - 47ms/epoch - 47ms/step\n",
      "Epoch 16/2000\n",
      "1/1 - 0s - loss: 3.0380 - val_loss: 3.0493 - 50ms/epoch - 50ms/step\n",
      "Epoch 17/2000\n",
      "1/1 - 0s - loss: 3.0313 - val_loss: 3.0432 - 55ms/epoch - 55ms/step\n",
      "Epoch 18/2000\n",
      "1/1 - 0s - loss: 3.0248 - val_loss: 3.0372 - 50ms/epoch - 50ms/step\n",
      "Epoch 19/2000\n",
      "1/1 - 0s - loss: 3.0186 - val_loss: 3.0315 - 57ms/epoch - 57ms/step\n",
      "Epoch 20/2000\n",
      "1/1 - 0s - loss: 3.0127 - val_loss: 3.0260 - 52ms/epoch - 52ms/step\n",
      "Epoch 21/2000\n",
      "1/1 - 0s - loss: 3.0071 - val_loss: 3.0208 - 50ms/epoch - 50ms/step\n",
      "Epoch 22/2000\n",
      "1/1 - 0s - loss: 3.0019 - val_loss: 3.0160 - 65ms/epoch - 65ms/step\n",
      "Epoch 23/2000\n",
      "1/1 - 0s - loss: 2.9970 - val_loss: 3.0115 - 65ms/epoch - 65ms/step\n",
      "Epoch 24/2000\n",
      "1/1 - 0s - loss: 2.9924 - val_loss: 3.0073 - 49ms/epoch - 49ms/step\n",
      "Epoch 25/2000\n",
      "1/1 - 0s - loss: 2.9882 - val_loss: 3.0035 - 64ms/epoch - 64ms/step\n",
      "Epoch 26/2000\n",
      "1/1 - 0s - loss: 2.9843 - val_loss: 3.0000 - 81ms/epoch - 81ms/step\n",
      "Epoch 27/2000\n",
      "1/1 - 0s - loss: 2.9807 - val_loss: 2.9967 - 73ms/epoch - 73ms/step\n",
      "Epoch 28/2000\n",
      "1/1 - 0s - loss: 2.9773 - val_loss: 2.9937 - 64ms/epoch - 64ms/step\n",
      "Epoch 29/2000\n",
      "1/1 - 0s - loss: 2.9742 - val_loss: 2.9909 - 97ms/epoch - 97ms/step\n",
      "Epoch 30/2000\n",
      "1/1 - 0s - loss: 2.9712 - val_loss: 2.9882 - 46ms/epoch - 46ms/step\n",
      "Epoch 31/2000\n",
      "1/1 - 0s - loss: 2.9685 - val_loss: 2.9857 - 72ms/epoch - 72ms/step\n",
      "Epoch 32/2000\n",
      "1/1 - 0s - loss: 2.9658 - val_loss: 2.9833 - 72ms/epoch - 72ms/step\n",
      "Epoch 33/2000\n",
      "1/1 - 0s - loss: 2.9633 - val_loss: 2.9809 - 65ms/epoch - 65ms/step\n",
      "Epoch 34/2000\n",
      "1/1 - 0s - loss: 2.9608 - val_loss: 2.9786 - 78ms/epoch - 78ms/step\n",
      "Epoch 35/2000\n",
      "1/1 - 0s - loss: 2.9583 - val_loss: 2.9763 - 134ms/epoch - 134ms/step\n",
      "Epoch 36/2000\n",
      "1/1 - 0s - loss: 2.9559 - val_loss: 2.9741 - 76ms/epoch - 76ms/step\n",
      "Epoch 37/2000\n",
      "1/1 - 0s - loss: 2.9536 - val_loss: 2.9719 - 126ms/epoch - 126ms/step\n",
      "Epoch 38/2000\n",
      "1/1 - 0s - loss: 2.9513 - val_loss: 2.9698 - 79ms/epoch - 79ms/step\n",
      "Epoch 39/2000\n",
      "1/1 - 0s - loss: 2.9491 - val_loss: 2.9677 - 50ms/epoch - 50ms/step\n",
      "Epoch 40/2000\n",
      "1/1 - 0s - loss: 2.9469 - val_loss: 2.9658 - 53ms/epoch - 53ms/step\n",
      "Epoch 41/2000\n",
      "1/1 - 0s - loss: 2.9449 - val_loss: 2.9638 - 74ms/epoch - 74ms/step\n",
      "Epoch 42/2000\n",
      "1/1 - 0s - loss: 2.9429 - val_loss: 2.9620 - 63ms/epoch - 63ms/step\n",
      "Epoch 43/2000\n",
      "1/1 - 0s - loss: 2.9410 - val_loss: 2.9603 - 54ms/epoch - 54ms/step\n",
      "Epoch 44/2000\n",
      "1/1 - 0s - loss: 2.9391 - val_loss: 2.9586 - 56ms/epoch - 56ms/step\n",
      "Epoch 45/2000\n",
      "1/1 - 0s - loss: 2.9374 - val_loss: 2.9570 - 46ms/epoch - 46ms/step\n",
      "Epoch 46/2000\n",
      "1/1 - 0s - loss: 2.9357 - val_loss: 2.9554 - 53ms/epoch - 53ms/step\n",
      "Epoch 47/2000\n",
      "1/1 - 0s - loss: 2.9340 - val_loss: 2.9540 - 77ms/epoch - 77ms/step\n",
      "Epoch 48/2000\n",
      "1/1 - 0s - loss: 2.9324 - val_loss: 2.9526 - 51ms/epoch - 51ms/step\n",
      "Epoch 49/2000\n",
      "1/1 - 0s - loss: 2.9309 - val_loss: 2.9512 - 51ms/epoch - 51ms/step\n",
      "Epoch 50/2000\n",
      "1/1 - 0s - loss: 2.9294 - val_loss: 2.9499 - 93ms/epoch - 93ms/step\n",
      "Epoch 51/2000\n",
      "1/1 - 0s - loss: 2.9279 - val_loss: 2.9486 - 68ms/epoch - 68ms/step\n",
      "Epoch 52/2000\n",
      "1/1 - 0s - loss: 2.9264 - val_loss: 2.9473 - 62ms/epoch - 62ms/step\n",
      "Epoch 53/2000\n",
      "1/1 - 0s - loss: 2.9250 - val_loss: 2.9461 - 76ms/epoch - 76ms/step\n",
      "Epoch 54/2000\n",
      "1/1 - 0s - loss: 2.9236 - val_loss: 2.9450 - 70ms/epoch - 70ms/step\n",
      "Epoch 55/2000\n",
      "1/1 - 0s - loss: 2.9223 - val_loss: 2.9438 - 53ms/epoch - 53ms/step\n",
      "Epoch 56/2000\n",
      "1/1 - 0s - loss: 2.9209 - val_loss: 2.9427 - 71ms/epoch - 71ms/step\n",
      "Epoch 57/2000\n",
      "1/1 - 0s - loss: 2.9196 - val_loss: 2.9415 - 99ms/epoch - 99ms/step\n",
      "Epoch 58/2000\n",
      "1/1 - 0s - loss: 2.9183 - val_loss: 2.9404 - 75ms/epoch - 75ms/step\n",
      "Epoch 59/2000\n",
      "1/1 - 0s - loss: 2.9170 - val_loss: 2.9393 - 75ms/epoch - 75ms/step\n",
      "Epoch 60/2000\n",
      "1/1 - 0s - loss: 2.9158 - val_loss: 2.9382 - 57ms/epoch - 57ms/step\n",
      "Epoch 61/2000\n",
      "1/1 - 0s - loss: 2.9145 - val_loss: 2.9372 - 64ms/epoch - 64ms/step\n",
      "Epoch 62/2000\n",
      "1/1 - 0s - loss: 2.9133 - val_loss: 2.9361 - 43ms/epoch - 43ms/step\n",
      "Epoch 63/2000\n",
      "1/1 - 0s - loss: 2.9121 - val_loss: 2.9351 - 76ms/epoch - 76ms/step\n",
      "Epoch 64/2000\n",
      "1/1 - 0s - loss: 2.9109 - val_loss: 2.9341 - 61ms/epoch - 61ms/step\n",
      "Epoch 65/2000\n",
      "1/1 - 0s - loss: 2.9097 - val_loss: 2.9331 - 64ms/epoch - 64ms/step\n",
      "Epoch 66/2000\n",
      "1/1 - 0s - loss: 2.9085 - val_loss: 2.9320 - 75ms/epoch - 75ms/step\n",
      "Epoch 67/2000\n",
      "1/1 - 0s - loss: 2.9073 - val_loss: 2.9310 - 83ms/epoch - 83ms/step\n",
      "Epoch 68/2000\n",
      "1/1 - 0s - loss: 2.9061 - val_loss: 2.9300 - 67ms/epoch - 67ms/step\n",
      "Epoch 69/2000\n",
      "1/1 - 0s - loss: 2.9049 - val_loss: 2.9290 - 64ms/epoch - 64ms/step\n",
      "Epoch 70/2000\n",
      "1/1 - 0s - loss: 2.9038 - val_loss: 2.9280 - 49ms/epoch - 49ms/step\n",
      "Epoch 71/2000\n",
      "1/1 - 0s - loss: 2.9026 - val_loss: 2.9270 - 48ms/epoch - 48ms/step\n",
      "Epoch 72/2000\n",
      "1/1 - 0s - loss: 2.9014 - val_loss: 2.9260 - 64ms/epoch - 64ms/step\n",
      "Epoch 73/2000\n",
      "1/1 - 0s - loss: 2.9002 - val_loss: 2.9250 - 45ms/epoch - 45ms/step\n",
      "Epoch 74/2000\n",
      "1/1 - 0s - loss: 2.8989 - val_loss: 2.9240 - 47ms/epoch - 47ms/step\n",
      "Epoch 75/2000\n",
      "1/1 - 0s - loss: 2.8977 - val_loss: 2.9230 - 59ms/epoch - 59ms/step\n",
      "Epoch 76/2000\n",
      "1/1 - 0s - loss: 2.8965 - val_loss: 2.9220 - 62ms/epoch - 62ms/step\n",
      "Epoch 77/2000\n",
      "1/1 - 0s - loss: 2.8953 - val_loss: 2.9210 - 73ms/epoch - 73ms/step\n",
      "Epoch 78/2000\n",
      "1/1 - 0s - loss: 2.8941 - val_loss: 2.9201 - 48ms/epoch - 48ms/step\n",
      "Epoch 79/2000\n",
      "1/1 - 0s - loss: 2.8928 - val_loss: 2.9191 - 54ms/epoch - 54ms/step\n",
      "Epoch 80/2000\n",
      "1/1 - 0s - loss: 2.8916 - val_loss: 2.9181 - 49ms/epoch - 49ms/step\n",
      "Epoch 81/2000\n",
      "1/1 - 0s - loss: 2.8904 - val_loss: 2.9171 - 46ms/epoch - 46ms/step\n",
      "Epoch 82/2000\n",
      "1/1 - 0s - loss: 2.8891 - val_loss: 2.9161 - 64ms/epoch - 64ms/step\n",
      "Epoch 83/2000\n",
      "1/1 - 0s - loss: 2.8878 - val_loss: 2.9151 - 82ms/epoch - 82ms/step\n",
      "Epoch 84/2000\n",
      "1/1 - 0s - loss: 2.8865 - val_loss: 2.9141 - 91ms/epoch - 91ms/step\n",
      "Epoch 85/2000\n",
      "1/1 - 0s - loss: 2.8852 - val_loss: 2.9131 - 78ms/epoch - 78ms/step\n",
      "Epoch 86/2000\n",
      "1/1 - 0s - loss: 2.8839 - val_loss: 2.9120 - 98ms/epoch - 98ms/step\n",
      "Epoch 87/2000\n",
      "1/1 - 0s - loss: 2.8826 - val_loss: 2.9110 - 49ms/epoch - 49ms/step\n",
      "Epoch 88/2000\n",
      "1/1 - 0s - loss: 2.8813 - val_loss: 2.9100 - 50ms/epoch - 50ms/step\n",
      "Epoch 89/2000\n",
      "1/1 - 0s - loss: 2.8800 - val_loss: 2.9089 - 64ms/epoch - 64ms/step\n",
      "Epoch 90/2000\n",
      "1/1 - 0s - loss: 2.8786 - val_loss: 2.9079 - 79ms/epoch - 79ms/step\n",
      "Epoch 91/2000\n",
      "1/1 - 0s - loss: 2.8773 - val_loss: 2.9068 - 64ms/epoch - 64ms/step\n",
      "Epoch 92/2000\n",
      "1/1 - 0s - loss: 2.8759 - val_loss: 2.9058 - 64ms/epoch - 64ms/step\n",
      "Epoch 93/2000\n",
      "1/1 - 0s - loss: 2.8745 - val_loss: 2.9047 - 62ms/epoch - 62ms/step\n",
      "Epoch 94/2000\n",
      "1/1 - 0s - loss: 2.8731 - val_loss: 2.9036 - 63ms/epoch - 63ms/step\n",
      "Epoch 95/2000\n",
      "1/1 - 0s - loss: 2.8717 - val_loss: 2.9025 - 47ms/epoch - 47ms/step\n",
      "Epoch 96/2000\n",
      "1/1 - 0s - loss: 2.8703 - val_loss: 2.9014 - 73ms/epoch - 73ms/step\n",
      "Epoch 97/2000\n",
      "1/1 - 0s - loss: 2.8689 - val_loss: 2.9003 - 64ms/epoch - 64ms/step\n",
      "Epoch 98/2000\n",
      "1/1 - 0s - loss: 2.8675 - val_loss: 2.8991 - 74ms/epoch - 74ms/step\n",
      "Epoch 99/2000\n",
      "1/1 - 0s - loss: 2.8661 - val_loss: 2.8980 - 71ms/epoch - 71ms/step\n",
      "Epoch 100/2000\n",
      "1/1 - 0s - loss: 2.8647 - val_loss: 2.8970 - 46ms/epoch - 46ms/step\n",
      "Epoch 101/2000\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 - 0s - loss: 2.8634 - val_loss: 2.8959 - 80ms/epoch - 80ms/step\n",
      "Epoch 102/2000\n",
      "1/1 - 0s - loss: 2.8620 - val_loss: 2.8949 - 80ms/epoch - 80ms/step\n",
      "Epoch 103/2000\n",
      "1/1 - 0s - loss: 2.8606 - val_loss: 2.8938 - 59ms/epoch - 59ms/step\n",
      "Epoch 104/2000\n",
      "1/1 - 0s - loss: 2.8592 - val_loss: 2.8928 - 49ms/epoch - 49ms/step\n",
      "Epoch 105/2000\n",
      "1/1 - 0s - loss: 2.8579 - val_loss: 2.8918 - 51ms/epoch - 51ms/step\n",
      "Epoch 106/2000\n",
      "1/1 - 0s - loss: 2.8566 - val_loss: 2.8908 - 62ms/epoch - 62ms/step\n",
      "Epoch 107/2000\n",
      "1/1 - 0s - loss: 2.8552 - val_loss: 2.8898 - 62ms/epoch - 62ms/step\n",
      "Epoch 108/2000\n",
      "1/1 - 0s - loss: 2.8539 - val_loss: 2.8888 - 52ms/epoch - 52ms/step\n",
      "Epoch 109/2000\n",
      "1/1 - 0s - loss: 2.8526 - val_loss: 2.8877 - 81ms/epoch - 81ms/step\n",
      "Epoch 110/2000\n",
      "1/1 - 0s - loss: 2.8513 - val_loss: 2.8866 - 68ms/epoch - 68ms/step\n",
      "Epoch 111/2000\n",
      "1/1 - 0s - loss: 2.8500 - val_loss: 2.8856 - 50ms/epoch - 50ms/step\n",
      "Epoch 112/2000\n",
      "1/1 - 0s - loss: 2.8488 - val_loss: 2.8845 - 59ms/epoch - 59ms/step\n",
      "Epoch 113/2000\n",
      "1/1 - 0s - loss: 2.8475 - val_loss: 2.8835 - 53ms/epoch - 53ms/step\n",
      "Epoch 114/2000\n",
      "1/1 - 0s - loss: 2.8462 - val_loss: 2.8824 - 55ms/epoch - 55ms/step\n",
      "Epoch 115/2000\n",
      "1/1 - 0s - loss: 2.8449 - val_loss: 2.8814 - 50ms/epoch - 50ms/step\n",
      "Epoch 116/2000\n",
      "1/1 - 0s - loss: 2.8436 - val_loss: 2.8804 - 65ms/epoch - 65ms/step\n",
      "Epoch 117/2000\n",
      "1/1 - 0s - loss: 2.8423 - val_loss: 2.8793 - 51ms/epoch - 51ms/step\n",
      "Epoch 118/2000\n",
      "1/1 - 0s - loss: 2.8411 - val_loss: 2.8782 - 48ms/epoch - 48ms/step\n",
      "Epoch 119/2000\n",
      "1/1 - 0s - loss: 2.8398 - val_loss: 2.8771 - 55ms/epoch - 55ms/step\n",
      "Epoch 120/2000\n",
      "1/1 - 0s - loss: 2.8384 - val_loss: 2.8760 - 53ms/epoch - 53ms/step\n",
      "Epoch 121/2000\n",
      "1/1 - 0s - loss: 2.8371 - val_loss: 2.8749 - 49ms/epoch - 49ms/step\n",
      "Epoch 122/2000\n",
      "1/1 - 0s - loss: 2.8358 - val_loss: 2.8738 - 56ms/epoch - 56ms/step\n",
      "Epoch 123/2000\n",
      "1/1 - 0s - loss: 2.8345 - val_loss: 2.8727 - 49ms/epoch - 49ms/step\n",
      "Epoch 124/2000\n",
      "1/1 - 0s - loss: 2.8331 - val_loss: 2.8716 - 51ms/epoch - 51ms/step\n",
      "Epoch 125/2000\n",
      "1/1 - 0s - loss: 2.8318 - val_loss: 2.8704 - 56ms/epoch - 56ms/step\n",
      "Epoch 126/2000\n",
      "1/1 - 0s - loss: 2.8304 - val_loss: 2.8693 - 53ms/epoch - 53ms/step\n",
      "Epoch 127/2000\n",
      "1/1 - 0s - loss: 2.8291 - val_loss: 2.8680 - 50ms/epoch - 50ms/step\n",
      "Epoch 128/2000\n",
      "1/1 - 0s - loss: 2.8278 - val_loss: 2.8668 - 48ms/epoch - 48ms/step\n",
      "Epoch 129/2000\n",
      "1/1 - 0s - loss: 2.8265 - val_loss: 2.8655 - 47ms/epoch - 47ms/step\n",
      "Epoch 130/2000\n",
      "1/1 - 0s - loss: 2.8252 - val_loss: 2.8642 - 60ms/epoch - 60ms/step\n",
      "Epoch 131/2000\n",
      "1/1 - 0s - loss: 2.8239 - val_loss: 2.8629 - 60ms/epoch - 60ms/step\n",
      "Epoch 132/2000\n",
      "1/1 - 0s - loss: 2.8226 - val_loss: 2.8616 - 48ms/epoch - 48ms/step\n",
      "Epoch 133/2000\n",
      "1/1 - 0s - loss: 2.8213 - val_loss: 2.8601 - 57ms/epoch - 57ms/step\n",
      "Epoch 134/2000\n",
      "1/1 - 0s - loss: 2.8200 - val_loss: 2.8587 - 56ms/epoch - 56ms/step\n",
      "Epoch 135/2000\n",
      "1/1 - 0s - loss: 2.8187 - val_loss: 2.8573 - 47ms/epoch - 47ms/step\n",
      "Epoch 136/2000\n",
      "1/1 - 0s - loss: 2.8174 - val_loss: 2.8559 - 55ms/epoch - 55ms/step\n",
      "Epoch 137/2000\n",
      "1/1 - 0s - loss: 2.8160 - val_loss: 2.8546 - 76ms/epoch - 76ms/step\n",
      "Epoch 138/2000\n",
      "1/1 - 0s - loss: 2.8147 - val_loss: 2.8532 - 74ms/epoch - 74ms/step\n",
      "Epoch 139/2000\n",
      "1/1 - 0s - loss: 2.8134 - val_loss: 2.8519 - 85ms/epoch - 85ms/step\n",
      "Epoch 140/2000\n",
      "1/1 - 0s - loss: 2.8121 - val_loss: 2.8506 - 94ms/epoch - 94ms/step\n",
      "Epoch 141/2000\n",
      "1/1 - 0s - loss: 2.8107 - val_loss: 2.8492 - 100ms/epoch - 100ms/step\n",
      "Epoch 142/2000\n",
      "1/1 - 0s - loss: 2.8094 - val_loss: 2.8478 - 71ms/epoch - 71ms/step\n",
      "Epoch 143/2000\n",
      "1/1 - 0s - loss: 2.8080 - val_loss: 2.8464 - 72ms/epoch - 72ms/step\n",
      "Epoch 144/2000\n",
      "1/1 - 0s - loss: 2.8066 - val_loss: 2.8451 - 64ms/epoch - 64ms/step\n",
      "Epoch 145/2000\n",
      "1/1 - 0s - loss: 2.8052 - val_loss: 2.8438 - 76ms/epoch - 76ms/step\n",
      "Epoch 146/2000\n",
      "1/1 - 0s - loss: 2.8038 - val_loss: 2.8424 - 79ms/epoch - 79ms/step\n",
      "Epoch 147/2000\n",
      "1/1 - 0s - loss: 2.8024 - val_loss: 2.8411 - 72ms/epoch - 72ms/step\n",
      "Epoch 148/2000\n",
      "1/1 - 0s - loss: 2.8011 - val_loss: 2.8398 - 80ms/epoch - 80ms/step\n",
      "Epoch 149/2000\n",
      "1/1 - 0s - loss: 2.7997 - val_loss: 2.8384 - 52ms/epoch - 52ms/step\n",
      "Epoch 150/2000\n",
      "1/1 - 0s - loss: 2.7983 - val_loss: 2.8371 - 103ms/epoch - 103ms/step\n",
      "Epoch 151/2000\n",
      "1/1 - 0s - loss: 2.7969 - val_loss: 2.8356 - 69ms/epoch - 69ms/step\n",
      "Epoch 152/2000\n",
      "1/1 - 0s - loss: 2.7956 - val_loss: 2.8342 - 63ms/epoch - 63ms/step\n",
      "Epoch 153/2000\n",
      "1/1 - 0s - loss: 2.7943 - val_loss: 2.8327 - 56ms/epoch - 56ms/step\n",
      "Epoch 154/2000\n",
      "1/1 - 0s - loss: 2.7930 - val_loss: 2.8313 - 76ms/epoch - 76ms/step\n",
      "Epoch 155/2000\n",
      "1/1 - 0s - loss: 2.7917 - val_loss: 2.8300 - 54ms/epoch - 54ms/step\n",
      "Epoch 156/2000\n",
      "1/1 - 0s - loss: 2.7904 - val_loss: 2.8286 - 53ms/epoch - 53ms/step\n",
      "Epoch 157/2000\n",
      "1/1 - 0s - loss: 2.7891 - val_loss: 2.8273 - 54ms/epoch - 54ms/step\n",
      "Epoch 158/2000\n",
      "1/1 - 0s - loss: 2.7878 - val_loss: 2.8260 - 46ms/epoch - 46ms/step\n",
      "Epoch 159/2000\n",
      "1/1 - 0s - loss: 2.7865 - val_loss: 2.8246 - 74ms/epoch - 74ms/step\n",
      "Epoch 160/2000\n",
      "1/1 - 0s - loss: 2.7853 - val_loss: 2.8233 - 50ms/epoch - 50ms/step\n",
      "Epoch 161/2000\n",
      "1/1 - 0s - loss: 2.7840 - val_loss: 2.8220 - 80ms/epoch - 80ms/step\n",
      "Epoch 162/2000\n",
      "1/1 - 0s - loss: 2.7827 - val_loss: 2.8208 - 51ms/epoch - 51ms/step\n",
      "Epoch 163/2000\n",
      "1/1 - 0s - loss: 2.7814 - val_loss: 2.8195 - 62ms/epoch - 62ms/step\n",
      "Epoch 164/2000\n",
      "1/1 - 0s - loss: 2.7802 - val_loss: 2.8183 - 62ms/epoch - 62ms/step\n",
      "Epoch 165/2000\n",
      "1/1 - 0s - loss: 2.7789 - val_loss: 2.8170 - 56ms/epoch - 56ms/step\n",
      "Epoch 166/2000\n",
      "1/1 - 0s - loss: 2.7777 - val_loss: 2.8157 - 58ms/epoch - 58ms/step\n",
      "Epoch 167/2000\n",
      "1/1 - 0s - loss: 2.7764 - val_loss: 2.8144 - 51ms/epoch - 51ms/step\n",
      "Epoch 168/2000\n",
      "1/1 - 0s - loss: 2.7752 - val_loss: 2.8131 - 64ms/epoch - 64ms/step\n",
      "Epoch 169/2000\n",
      "1/1 - 0s - loss: 2.7740 - val_loss: 2.8118 - 67ms/epoch - 67ms/step\n",
      "Epoch 170/2000\n",
      "1/1 - 0s - loss: 2.7727 - val_loss: 2.8106 - 52ms/epoch - 52ms/step\n",
      "Epoch 171/2000\n",
      "1/1 - 0s - loss: 2.7715 - val_loss: 2.8094 - 57ms/epoch - 57ms/step\n",
      "Epoch 172/2000\n",
      "1/1 - 0s - loss: 2.7703 - val_loss: 2.8083 - 88ms/epoch - 88ms/step\n",
      "Epoch 173/2000\n",
      "1/1 - 0s - loss: 2.7690 - val_loss: 2.8072 - 56ms/epoch - 56ms/step\n",
      "Epoch 174/2000\n",
      "1/1 - 0s - loss: 2.7678 - val_loss: 2.8061 - 54ms/epoch - 54ms/step\n",
      "Epoch 175/2000\n",
      "1/1 - 0s - loss: 2.7666 - val_loss: 2.8048 - 59ms/epoch - 59ms/step\n",
      "Epoch 176/2000\n",
      "1/1 - 0s - loss: 2.7654 - val_loss: 2.8036 - 47ms/epoch - 47ms/step\n",
      "Epoch 177/2000\n",
      "1/1 - 0s - loss: 2.7641 - val_loss: 2.8023 - 50ms/epoch - 50ms/step\n",
      "Epoch 178/2000\n",
      "1/1 - 0s - loss: 2.7629 - val_loss: 2.8012 - 50ms/epoch - 50ms/step\n",
      "Epoch 179/2000\n",
      "1/1 - 0s - loss: 2.7618 - val_loss: 2.8001 - 43ms/epoch - 43ms/step\n",
      "Epoch 180/2000\n",
      "1/1 - 0s - loss: 2.7606 - val_loss: 2.7991 - 48ms/epoch - 48ms/step\n",
      "Epoch 181/2000\n",
      "1/1 - 0s - loss: 2.7595 - val_loss: 2.7980 - 51ms/epoch - 51ms/step\n",
      "Epoch 182/2000\n",
      "1/1 - 0s - loss: 2.7583 - val_loss: 2.7970 - 53ms/epoch - 53ms/step\n",
      "Epoch 183/2000\n",
      "1/1 - 0s - loss: 2.7572 - val_loss: 2.7961 - 54ms/epoch - 54ms/step\n",
      "Epoch 184/2000\n",
      "1/1 - 0s - loss: 2.7561 - val_loss: 2.7950 - 81ms/epoch - 81ms/step\n",
      "Epoch 185/2000\n",
      "1/1 - 0s - loss: 2.7550 - val_loss: 2.7940 - 59ms/epoch - 59ms/step\n",
      "Epoch 186/2000\n",
      "1/1 - 0s - loss: 2.7539 - val_loss: 2.7931 - 45ms/epoch - 45ms/step\n",
      "Epoch 187/2000\n",
      "1/1 - 0s - loss: 2.7529 - val_loss: 2.7921 - 47ms/epoch - 47ms/step\n",
      "Epoch 188/2000\n",
      "1/1 - 0s - loss: 2.7518 - val_loss: 2.7912 - 48ms/epoch - 48ms/step\n",
      "Epoch 189/2000\n",
      "1/1 - 0s - loss: 2.7507 - val_loss: 2.7903 - 48ms/epoch - 48ms/step\n",
      "Epoch 190/2000\n",
      "1/1 - 0s - loss: 2.7496 - val_loss: 2.7894 - 60ms/epoch - 60ms/step\n",
      "Epoch 191/2000\n",
      "1/1 - 0s - loss: 2.7485 - val_loss: 2.7884 - 57ms/epoch - 57ms/step\n",
      "Epoch 192/2000\n",
      "1/1 - 0s - loss: 2.7474 - val_loss: 2.7875 - 66ms/epoch - 66ms/step\n",
      "Epoch 193/2000\n",
      "1/1 - 0s - loss: 2.7463 - val_loss: 2.7865 - 57ms/epoch - 57ms/step\n",
      "Epoch 194/2000\n",
      "1/1 - 0s - loss: 2.7452 - val_loss: 2.7855 - 49ms/epoch - 49ms/step\n",
      "Epoch 195/2000\n",
      "1/1 - 0s - loss: 2.7441 - val_loss: 2.7845 - 57ms/epoch - 57ms/step\n",
      "Epoch 196/2000\n",
      "1/1 - 0s - loss: 2.7430 - val_loss: 2.7835 - 59ms/epoch - 59ms/step\n",
      "Epoch 197/2000\n",
      "1/1 - 0s - loss: 2.7419 - val_loss: 2.7825 - 52ms/epoch - 52ms/step\n",
      "Epoch 198/2000\n",
      "1/1 - 0s - loss: 2.7408 - val_loss: 2.7816 - 67ms/epoch - 67ms/step\n",
      "Epoch 199/2000\n",
      "1/1 - 0s - loss: 2.7397 - val_loss: 2.7806 - 64ms/epoch - 64ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 200/2000\n",
      "1/1 - 0s - loss: 2.7386 - val_loss: 2.7796 - 72ms/epoch - 72ms/step\n",
      "Epoch 201/2000\n",
      "1/1 - 0s - loss: 2.7375 - val_loss: 2.7786 - 88ms/epoch - 88ms/step\n",
      "Epoch 202/2000\n",
      "1/1 - 0s - loss: 2.7364 - val_loss: 2.7775 - 71ms/epoch - 71ms/step\n",
      "Epoch 203/2000\n",
      "1/1 - 0s - loss: 2.7354 - val_loss: 2.7765 - 79ms/epoch - 79ms/step\n",
      "Epoch 204/2000\n",
      "1/1 - 0s - loss: 2.7343 - val_loss: 2.7754 - 55ms/epoch - 55ms/step\n",
      "Epoch 205/2000\n",
      "1/1 - 0s - loss: 2.7332 - val_loss: 2.7744 - 57ms/epoch - 57ms/step\n",
      "Epoch 206/2000\n",
      "1/1 - 0s - loss: 2.7321 - val_loss: 2.7734 - 51ms/epoch - 51ms/step\n",
      "Epoch 207/2000\n",
      "1/1 - 0s - loss: 2.7311 - val_loss: 2.7725 - 97ms/epoch - 97ms/step\n",
      "Epoch 208/2000\n",
      "1/1 - 0s - loss: 2.7300 - val_loss: 2.7717 - 63ms/epoch - 63ms/step\n",
      "Epoch 209/2000\n",
      "1/1 - 0s - loss: 2.7289 - val_loss: 2.7709 - 63ms/epoch - 63ms/step\n",
      "Epoch 210/2000\n",
      "1/1 - 0s - loss: 2.7278 - val_loss: 2.7699 - 100ms/epoch - 100ms/step\n",
      "Epoch 211/2000\n",
      "1/1 - 0s - loss: 2.7268 - val_loss: 2.7689 - 63ms/epoch - 63ms/step\n",
      "Epoch 212/2000\n",
      "1/1 - 0s - loss: 2.7257 - val_loss: 2.7679 - 56ms/epoch - 56ms/step\n",
      "Epoch 213/2000\n",
      "1/1 - 0s - loss: 2.7246 - val_loss: 2.7670 - 51ms/epoch - 51ms/step\n",
      "Epoch 214/2000\n",
      "1/1 - 0s - loss: 2.7236 - val_loss: 2.7660 - 79ms/epoch - 79ms/step\n",
      "Epoch 215/2000\n",
      "1/1 - 0s - loss: 2.7225 - val_loss: 2.7651 - 75ms/epoch - 75ms/step\n",
      "Epoch 216/2000\n",
      "1/1 - 0s - loss: 2.7214 - val_loss: 2.7642 - 93ms/epoch - 93ms/step\n",
      "Epoch 217/2000\n",
      "1/1 - 0s - loss: 2.7204 - val_loss: 2.7633 - 60ms/epoch - 60ms/step\n",
      "Epoch 218/2000\n",
      "1/1 - 0s - loss: 2.7193 - val_loss: 2.7624 - 85ms/epoch - 85ms/step\n",
      "Epoch 219/2000\n",
      "1/1 - 0s - loss: 2.7183 - val_loss: 2.7614 - 47ms/epoch - 47ms/step\n",
      "Epoch 220/2000\n",
      "1/1 - 0s - loss: 2.7173 - val_loss: 2.7605 - 58ms/epoch - 58ms/step\n",
      "Epoch 221/2000\n",
      "1/1 - 0s - loss: 2.7162 - val_loss: 2.7597 - 84ms/epoch - 84ms/step\n",
      "Epoch 222/2000\n",
      "1/1 - 0s - loss: 2.7152 - val_loss: 2.7589 - 63ms/epoch - 63ms/step\n",
      "Epoch 223/2000\n",
      "1/1 - 0s - loss: 2.7142 - val_loss: 2.7581 - 67ms/epoch - 67ms/step\n",
      "Epoch 224/2000\n",
      "1/1 - 0s - loss: 2.7132 - val_loss: 2.7572 - 75ms/epoch - 75ms/step\n",
      "Epoch 225/2000\n",
      "1/1 - 0s - loss: 2.7122 - val_loss: 2.7564 - 66ms/epoch - 66ms/step\n",
      "Epoch 226/2000\n",
      "1/1 - 0s - loss: 2.7112 - val_loss: 2.7557 - 86ms/epoch - 86ms/step\n",
      "Epoch 227/2000\n",
      "1/1 - 0s - loss: 2.7102 - val_loss: 2.7548 - 83ms/epoch - 83ms/step\n",
      "Epoch 228/2000\n",
      "1/1 - 0s - loss: 2.7092 - val_loss: 2.7539 - 77ms/epoch - 77ms/step\n",
      "Epoch 229/2000\n",
      "1/1 - 0s - loss: 2.7082 - val_loss: 2.7531 - 52ms/epoch - 52ms/step\n",
      "Epoch 230/2000\n",
      "1/1 - 0s - loss: 2.7072 - val_loss: 2.7523 - 47ms/epoch - 47ms/step\n",
      "Epoch 231/2000\n",
      "1/1 - 0s - loss: 2.7063 - val_loss: 2.7516 - 49ms/epoch - 49ms/step\n",
      "Epoch 232/2000\n",
      "1/1 - 0s - loss: 2.7054 - val_loss: 2.7509 - 50ms/epoch - 50ms/step\n",
      "Epoch 233/2000\n",
      "1/1 - 0s - loss: 2.7044 - val_loss: 2.7503 - 58ms/epoch - 58ms/step\n",
      "Epoch 234/2000\n",
      "1/1 - 0s - loss: 2.7035 - val_loss: 2.7495 - 76ms/epoch - 76ms/step\n",
      "Epoch 235/2000\n",
      "1/1 - 0s - loss: 2.7026 - val_loss: 2.7487 - 54ms/epoch - 54ms/step\n",
      "Epoch 236/2000\n",
      "1/1 - 0s - loss: 2.7016 - val_loss: 2.7478 - 64ms/epoch - 64ms/step\n",
      "Epoch 237/2000\n",
      "1/1 - 0s - loss: 2.7008 - val_loss: 2.7472 - 49ms/epoch - 49ms/step\n",
      "Epoch 238/2000\n",
      "1/1 - 0s - loss: 2.6998 - val_loss: 2.7464 - 56ms/epoch - 56ms/step\n",
      "Epoch 239/2000\n",
      "1/1 - 0s - loss: 2.6989 - val_loss: 2.7457 - 47ms/epoch - 47ms/step\n",
      "Epoch 240/2000\n",
      "1/1 - 0s - loss: 2.6980 - val_loss: 2.7450 - 58ms/epoch - 58ms/step\n",
      "Epoch 241/2000\n",
      "1/1 - 0s - loss: 2.6971 - val_loss: 2.7442 - 48ms/epoch - 48ms/step\n",
      "Epoch 242/2000\n",
      "1/1 - 0s - loss: 2.6963 - val_loss: 2.7434 - 46ms/epoch - 46ms/step\n",
      "Epoch 243/2000\n",
      "1/1 - 0s - loss: 2.6954 - val_loss: 2.7426 - 66ms/epoch - 66ms/step\n",
      "Epoch 244/2000\n",
      "1/1 - 0s - loss: 2.6945 - val_loss: 2.7419 - 64ms/epoch - 64ms/step\n",
      "Epoch 245/2000\n",
      "1/1 - 0s - loss: 2.6936 - val_loss: 2.7413 - 101ms/epoch - 101ms/step\n",
      "Epoch 246/2000\n",
      "1/1 - 0s - loss: 2.6927 - val_loss: 2.7405 - 55ms/epoch - 55ms/step\n",
      "Epoch 247/2000\n",
      "1/1 - 0s - loss: 2.6918 - val_loss: 2.7398 - 51ms/epoch - 51ms/step\n",
      "Epoch 248/2000\n",
      "1/1 - 0s - loss: 2.6909 - val_loss: 2.7391 - 45ms/epoch - 45ms/step\n",
      "Epoch 249/2000\n",
      "1/1 - 0s - loss: 2.6901 - val_loss: 2.7385 - 47ms/epoch - 47ms/step\n",
      "Epoch 250/2000\n",
      "1/1 - 0s - loss: 2.6892 - val_loss: 2.7379 - 44ms/epoch - 44ms/step\n",
      "Epoch 251/2000\n",
      "1/1 - 0s - loss: 2.6883 - val_loss: 2.7372 - 52ms/epoch - 52ms/step\n",
      "Epoch 252/2000\n",
      "1/1 - 0s - loss: 2.6875 - val_loss: 2.7365 - 61ms/epoch - 61ms/step\n",
      "Epoch 253/2000\n",
      "1/1 - 0s - loss: 2.6866 - val_loss: 2.7357 - 54ms/epoch - 54ms/step\n",
      "Epoch 254/2000\n",
      "1/1 - 0s - loss: 2.6858 - val_loss: 2.7350 - 58ms/epoch - 58ms/step\n",
      "Epoch 255/2000\n",
      "1/1 - 0s - loss: 2.6850 - val_loss: 2.7343 - 64ms/epoch - 64ms/step\n",
      "Epoch 256/2000\n",
      "1/1 - 0s - loss: 2.6841 - val_loss: 2.7338 - 71ms/epoch - 71ms/step\n",
      "Epoch 257/2000\n",
      "1/1 - 0s - loss: 2.6833 - val_loss: 2.7332 - 65ms/epoch - 65ms/step\n",
      "Epoch 258/2000\n",
      "1/1 - 0s - loss: 2.6825 - val_loss: 2.7325 - 68ms/epoch - 68ms/step\n",
      "Epoch 259/2000\n",
      "1/1 - 0s - loss: 2.6817 - val_loss: 2.7319 - 53ms/epoch - 53ms/step\n",
      "Epoch 260/2000\n",
      "1/1 - 0s - loss: 2.6809 - val_loss: 2.7312 - 57ms/epoch - 57ms/step\n",
      "Epoch 261/2000\n",
      "1/1 - 0s - loss: 2.6801 - val_loss: 2.7305 - 79ms/epoch - 79ms/step\n",
      "Epoch 262/2000\n",
      "1/1 - 0s - loss: 2.6793 - val_loss: 2.7299 - 58ms/epoch - 58ms/step\n",
      "Epoch 263/2000\n",
      "1/1 - 0s - loss: 2.6785 - val_loss: 2.7292 - 51ms/epoch - 51ms/step\n",
      "Epoch 264/2000\n",
      "1/1 - 0s - loss: 2.6777 - val_loss: 2.7287 - 49ms/epoch - 49ms/step\n",
      "Epoch 265/2000\n",
      "1/1 - 0s - loss: 2.6769 - val_loss: 2.7280 - 59ms/epoch - 59ms/step\n",
      "Epoch 266/2000\n",
      "1/1 - 0s - loss: 2.6762 - val_loss: 2.7273 - 56ms/epoch - 56ms/step\n",
      "Epoch 267/2000\n",
      "1/1 - 0s - loss: 2.6754 - val_loss: 2.7265 - 55ms/epoch - 55ms/step\n",
      "Epoch 268/2000\n",
      "1/1 - 0s - loss: 2.6746 - val_loss: 2.7259 - 53ms/epoch - 53ms/step\n",
      "Epoch 269/2000\n",
      "1/1 - 0s - loss: 2.6738 - val_loss: 2.7253 - 70ms/epoch - 70ms/step\n",
      "Epoch 270/2000\n",
      "1/1 - 0s - loss: 2.6731 - val_loss: 2.7247 - 68ms/epoch - 68ms/step\n",
      "Epoch 271/2000\n",
      "1/1 - 0s - loss: 2.6723 - val_loss: 2.7241 - 59ms/epoch - 59ms/step\n",
      "Epoch 272/2000\n",
      "1/1 - 0s - loss: 2.6716 - val_loss: 2.7234 - 65ms/epoch - 65ms/step\n",
      "Epoch 273/2000\n",
      "1/1 - 0s - loss: 2.6708 - val_loss: 2.7228 - 48ms/epoch - 48ms/step\n",
      "Epoch 274/2000\n",
      "1/1 - 0s - loss: 2.6701 - val_loss: 2.7222 - 49ms/epoch - 49ms/step\n",
      "Epoch 275/2000\n",
      "1/1 - 0s - loss: 2.6693 - val_loss: 2.7216 - 47ms/epoch - 47ms/step\n",
      "Epoch 276/2000\n",
      "1/1 - 0s - loss: 2.6686 - val_loss: 2.7211 - 50ms/epoch - 50ms/step\n",
      "Epoch 277/2000\n",
      "1/1 - 0s - loss: 2.6679 - val_loss: 2.7205 - 51ms/epoch - 51ms/step\n",
      "Epoch 278/2000\n",
      "1/1 - 0s - loss: 2.6671 - val_loss: 2.7198 - 49ms/epoch - 49ms/step\n",
      "Epoch 279/2000\n",
      "1/1 - 0s - loss: 2.6664 - val_loss: 2.7191 - 50ms/epoch - 50ms/step\n",
      "Epoch 280/2000\n",
      "1/1 - 0s - loss: 2.6656 - val_loss: 2.7186 - 56ms/epoch - 56ms/step\n",
      "Epoch 281/2000\n",
      "1/1 - 0s - loss: 2.6649 - val_loss: 2.7181 - 60ms/epoch - 60ms/step\n",
      "Epoch 282/2000\n",
      "1/1 - 0s - loss: 2.6642 - val_loss: 2.7175 - 73ms/epoch - 73ms/step\n",
      "Epoch 283/2000\n",
      "1/1 - 0s - loss: 2.6635 - val_loss: 2.7168 - 68ms/epoch - 68ms/step\n",
      "Epoch 284/2000\n",
      "1/1 - 0s - loss: 2.6627 - val_loss: 2.7161 - 57ms/epoch - 57ms/step\n",
      "Epoch 285/2000\n",
      "1/1 - 0s - loss: 2.6620 - val_loss: 2.7155 - 62ms/epoch - 62ms/step\n",
      "Epoch 286/2000\n",
      "1/1 - 0s - loss: 2.6613 - val_loss: 2.7150 - 54ms/epoch - 54ms/step\n",
      "Epoch 287/2000\n",
      "1/1 - 0s - loss: 2.6606 - val_loss: 2.7145 - 64ms/epoch - 64ms/step\n",
      "Epoch 288/2000\n",
      "1/1 - 0s - loss: 2.6599 - val_loss: 2.7140 - 62ms/epoch - 62ms/step\n",
      "Epoch 289/2000\n",
      "1/1 - 0s - loss: 2.6591 - val_loss: 2.7134 - 85ms/epoch - 85ms/step\n",
      "Epoch 290/2000\n",
      "1/1 - 0s - loss: 2.6584 - val_loss: 2.7127 - 62ms/epoch - 62ms/step\n",
      "Epoch 291/2000\n",
      "1/1 - 0s - loss: 2.6577 - val_loss: 2.7119 - 81ms/epoch - 81ms/step\n",
      "Epoch 292/2000\n",
      "1/1 - 0s - loss: 2.6570 - val_loss: 2.7112 - 62ms/epoch - 62ms/step\n",
      "Epoch 293/2000\n",
      "1/1 - 0s - loss: 2.6563 - val_loss: 2.7108 - 48ms/epoch - 48ms/step\n",
      "Epoch 294/2000\n",
      "1/1 - 0s - loss: 2.6556 - val_loss: 2.7103 - 51ms/epoch - 51ms/step\n",
      "Epoch 295/2000\n",
      "1/1 - 0s - loss: 2.6549 - val_loss: 2.7097 - 66ms/epoch - 66ms/step\n",
      "Epoch 296/2000\n",
      "1/1 - 0s - loss: 2.6541 - val_loss: 2.7090 - 114ms/epoch - 114ms/step\n",
      "Epoch 297/2000\n",
      "1/1 - 0s - loss: 2.6534 - val_loss: 2.7084 - 80ms/epoch - 80ms/step\n",
      "Epoch 298/2000\n",
      "1/1 - 0s - loss: 2.6527 - val_loss: 2.7078 - 76ms/epoch - 76ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 299/2000\n",
      "1/1 - 0s - loss: 2.6520 - val_loss: 2.7074 - 65ms/epoch - 65ms/step\n",
      "Epoch 300/2000\n",
      "1/1 - 0s - loss: 2.6513 - val_loss: 2.7068 - 63ms/epoch - 63ms/step\n",
      "Epoch 301/2000\n",
      "1/1 - 0s - loss: 2.6505 - val_loss: 2.7060 - 65ms/epoch - 65ms/step\n",
      "Epoch 302/2000\n",
      "1/1 - 0s - loss: 2.6498 - val_loss: 2.7054 - 54ms/epoch - 54ms/step\n",
      "Epoch 303/2000\n",
      "1/1 - 0s - loss: 2.6491 - val_loss: 2.7050 - 55ms/epoch - 55ms/step\n",
      "Epoch 304/2000\n",
      "1/1 - 0s - loss: 2.6484 - val_loss: 2.7045 - 60ms/epoch - 60ms/step\n",
      "Epoch 305/2000\n",
      "1/1 - 0s - loss: 2.6476 - val_loss: 2.7039 - 67ms/epoch - 67ms/step\n",
      "Epoch 306/2000\n",
      "1/1 - 0s - loss: 2.6469 - val_loss: 2.7032 - 66ms/epoch - 66ms/step\n",
      "Epoch 307/2000\n",
      "1/1 - 0s - loss: 2.6462 - val_loss: 2.7025 - 82ms/epoch - 82ms/step\n",
      "Epoch 308/2000\n",
      "1/1 - 0s - loss: 2.6454 - val_loss: 2.7020 - 55ms/epoch - 55ms/step\n",
      "Epoch 309/2000\n",
      "1/1 - 0s - loss: 2.6447 - val_loss: 2.7015 - 53ms/epoch - 53ms/step\n",
      "Epoch 310/2000\n",
      "1/1 - 0s - loss: 2.6439 - val_loss: 2.7010 - 49ms/epoch - 49ms/step\n",
      "Epoch 311/2000\n",
      "1/1 - 0s - loss: 2.6432 - val_loss: 2.7005 - 46ms/epoch - 46ms/step\n",
      "Epoch 312/2000\n",
      "1/1 - 0s - loss: 2.6424 - val_loss: 2.6999 - 46ms/epoch - 46ms/step\n",
      "Epoch 313/2000\n",
      "1/1 - 0s - loss: 2.6417 - val_loss: 2.6991 - 74ms/epoch - 74ms/step\n",
      "Epoch 314/2000\n",
      "1/1 - 0s - loss: 2.6410 - val_loss: 2.6985 - 72ms/epoch - 72ms/step\n",
      "Epoch 315/2000\n",
      "1/1 - 0s - loss: 2.6402 - val_loss: 2.6979 - 57ms/epoch - 57ms/step\n",
      "Epoch 316/2000\n",
      "1/1 - 0s - loss: 2.6394 - val_loss: 2.6975 - 58ms/epoch - 58ms/step\n",
      "Epoch 317/2000\n",
      "1/1 - 0s - loss: 2.6387 - val_loss: 2.6970 - 95ms/epoch - 95ms/step\n",
      "Epoch 318/2000\n",
      "1/1 - 0s - loss: 2.6379 - val_loss: 2.6962 - 92ms/epoch - 92ms/step\n",
      "Epoch 319/2000\n",
      "1/1 - 0s - loss: 2.6371 - val_loss: 2.6955 - 97ms/epoch - 97ms/step\n",
      "Epoch 320/2000\n",
      "1/1 - 0s - loss: 2.6364 - val_loss: 2.6950 - 74ms/epoch - 74ms/step\n",
      "Epoch 321/2000\n",
      "1/1 - 0s - loss: 2.6356 - val_loss: 2.6945 - 53ms/epoch - 53ms/step\n",
      "Epoch 322/2000\n",
      "1/1 - 0s - loss: 2.6348 - val_loss: 2.6938 - 75ms/epoch - 75ms/step\n",
      "Epoch 323/2000\n",
      "1/1 - 0s - loss: 2.6341 - val_loss: 2.6932 - 53ms/epoch - 53ms/step\n",
      "Epoch 324/2000\n",
      "1/1 - 0s - loss: 2.6333 - val_loss: 2.6926 - 47ms/epoch - 47ms/step\n",
      "Epoch 325/2000\n",
      "1/1 - 0s - loss: 2.6325 - val_loss: 2.6918 - 48ms/epoch - 48ms/step\n",
      "Epoch 326/2000\n",
      "1/1 - 0s - loss: 2.6318 - val_loss: 2.6911 - 45ms/epoch - 45ms/step\n",
      "Epoch 327/2000\n",
      "1/1 - 0s - loss: 2.6311 - val_loss: 2.6904 - 47ms/epoch - 47ms/step\n",
      "Epoch 328/2000\n",
      "1/1 - 0s - loss: 2.6303 - val_loss: 2.6899 - 45ms/epoch - 45ms/step\n",
      "Epoch 329/2000\n",
      "1/1 - 0s - loss: 2.6296 - val_loss: 2.6892 - 64ms/epoch - 64ms/step\n",
      "Epoch 330/2000\n",
      "1/1 - 0s - loss: 2.6288 - val_loss: 2.6887 - 60ms/epoch - 60ms/step\n",
      "Epoch 331/2000\n",
      "1/1 - 0s - loss: 2.6281 - val_loss: 2.6878 - 63ms/epoch - 63ms/step\n",
      "Epoch 332/2000\n",
      "1/1 - 0s - loss: 2.6272 - val_loss: 2.6869 - 43ms/epoch - 43ms/step\n",
      "Epoch 333/2000\n",
      "1/1 - 0s - loss: 2.6265 - val_loss: 2.6862 - 48ms/epoch - 48ms/step\n",
      "Epoch 334/2000\n",
      "1/1 - 0s - loss: 2.6257 - val_loss: 2.6852 - 43ms/epoch - 43ms/step\n",
      "Epoch 335/2000\n",
      "1/1 - 0s - loss: 2.6250 - val_loss: 2.6846 - 66ms/epoch - 66ms/step\n",
      "Epoch 336/2000\n",
      "1/1 - 0s - loss: 2.6241 - val_loss: 2.6840 - 69ms/epoch - 69ms/step\n",
      "Epoch 337/2000\n",
      "1/1 - 0s - loss: 2.6233 - val_loss: 2.6833 - 101ms/epoch - 101ms/step\n",
      "Epoch 338/2000\n",
      "1/1 - 0s - loss: 2.6225 - val_loss: 2.6829 - 126ms/epoch - 126ms/step\n",
      "Epoch 339/2000\n",
      "1/1 - 0s - loss: 2.6217 - val_loss: 2.6816 - 81ms/epoch - 81ms/step\n",
      "Epoch 340/2000\n",
      "1/1 - 0s - loss: 2.6209 - val_loss: 2.6808 - 70ms/epoch - 70ms/step\n",
      "Epoch 341/2000\n",
      "1/1 - 0s - loss: 2.6200 - val_loss: 2.6800 - 66ms/epoch - 66ms/step\n",
      "Epoch 342/2000\n",
      "1/1 - 0s - loss: 2.6192 - val_loss: 2.6795 - 89ms/epoch - 89ms/step\n",
      "Epoch 343/2000\n",
      "1/1 - 0s - loss: 2.6184 - val_loss: 2.6790 - 86ms/epoch - 86ms/step\n",
      "Epoch 344/2000\n",
      "1/1 - 0s - loss: 2.6176 - val_loss: 2.6783 - 121ms/epoch - 121ms/step\n",
      "Epoch 345/2000\n",
      "1/1 - 0s - loss: 2.6168 - val_loss: 2.6771 - 65ms/epoch - 65ms/step\n",
      "Epoch 346/2000\n",
      "1/1 - 0s - loss: 2.6159 - val_loss: 2.6760 - 87ms/epoch - 87ms/step\n",
      "Epoch 347/2000\n",
      "1/1 - 0s - loss: 2.6151 - val_loss: 2.6752 - 62ms/epoch - 62ms/step\n",
      "Epoch 348/2000\n",
      "1/1 - 0s - loss: 2.6142 - val_loss: 2.6748 - 71ms/epoch - 71ms/step\n",
      "Epoch 349/2000\n",
      "1/1 - 0s - loss: 2.6134 - val_loss: 2.6744 - 99ms/epoch - 99ms/step\n",
      "Epoch 350/2000\n",
      "1/1 - 0s - loss: 2.6126 - val_loss: 2.6737 - 62ms/epoch - 62ms/step\n",
      "Epoch 351/2000\n",
      "1/1 - 0s - loss: 2.6118 - val_loss: 2.6726 - 63ms/epoch - 63ms/step\n",
      "Epoch 352/2000\n",
      "1/1 - 0s - loss: 2.6110 - val_loss: 2.6718 - 51ms/epoch - 51ms/step\n",
      "Epoch 353/2000\n",
      "1/1 - 0s - loss: 2.6102 - val_loss: 2.6709 - 44ms/epoch - 44ms/step\n",
      "Epoch 354/2000\n",
      "1/1 - 0s - loss: 2.6094 - val_loss: 2.6712 - 43ms/epoch - 43ms/step\n",
      "Epoch 355/2000\n",
      "1/1 - 0s - loss: 2.6087 - val_loss: 2.6703 - 66ms/epoch - 66ms/step\n",
      "Epoch 356/2000\n",
      "1/1 - 0s - loss: 2.6082 - val_loss: 2.6704 - 63ms/epoch - 63ms/step\n",
      "Epoch 357/2000\n",
      "1/1 - 0s - loss: 2.6074 - val_loss: 2.6684 - 52ms/epoch - 52ms/step\n",
      "Epoch 358/2000\n",
      "1/1 - 0s - loss: 2.6064 - val_loss: 2.6681 - 54ms/epoch - 54ms/step\n",
      "Epoch 359/2000\n",
      "1/1 - 0s - loss: 2.6054 - val_loss: 2.6668 - 58ms/epoch - 58ms/step\n",
      "Epoch 360/2000\n",
      "1/1 - 0s - loss: 2.6044 - val_loss: 2.6661 - 64ms/epoch - 64ms/step\n",
      "Epoch 361/2000\n",
      "1/1 - 0s - loss: 2.6037 - val_loss: 2.6663 - 54ms/epoch - 54ms/step\n",
      "Epoch 362/2000\n",
      "1/1 - 0s - loss: 2.6031 - val_loss: 2.6651 - 44ms/epoch - 44ms/step\n",
      "Epoch 363/2000\n",
      "1/1 - 0s - loss: 2.6025 - val_loss: 2.6650 - 49ms/epoch - 49ms/step\n",
      "Epoch 364/2000\n",
      "1/1 - 0s - loss: 2.6016 - val_loss: 2.6632 - 59ms/epoch - 59ms/step\n",
      "Epoch 365/2000\n",
      "1/1 - 0s - loss: 2.6005 - val_loss: 2.6626 - 59ms/epoch - 59ms/step\n",
      "Epoch 366/2000\n",
      "1/1 - 0s - loss: 2.5997 - val_loss: 2.6624 - 46ms/epoch - 46ms/step\n",
      "Epoch 367/2000\n",
      "1/1 - 0s - loss: 2.5991 - val_loss: 2.6614 - 94ms/epoch - 94ms/step\n",
      "Epoch 368/2000\n",
      "1/1 - 0s - loss: 2.5985 - val_loss: 2.6616 - 96ms/epoch - 96ms/step\n",
      "Epoch 369/2000\n",
      "1/1 - 0s - loss: 2.5977 - val_loss: 2.6595 - 57ms/epoch - 57ms/step\n",
      "Epoch 370/2000\n",
      "1/1 - 0s - loss: 2.5968 - val_loss: 2.6587 - 51ms/epoch - 51ms/step\n",
      "Epoch 371/2000\n",
      "1/1 - 0s - loss: 2.5958 - val_loss: 2.6580 - 56ms/epoch - 56ms/step\n",
      "Epoch 372/2000\n",
      "1/1 - 0s - loss: 2.5949 - val_loss: 2.6571 - 59ms/epoch - 59ms/step\n",
      "Epoch 373/2000\n",
      "1/1 - 0s - loss: 2.5943 - val_loss: 2.6570 - 65ms/epoch - 65ms/step\n",
      "Epoch 374/2000\n",
      "1/1 - 0s - loss: 2.5937 - val_loss: 2.6552 - 54ms/epoch - 54ms/step\n",
      "Epoch 375/2000\n",
      "1/1 - 0s - loss: 2.5930 - val_loss: 2.6554 - 50ms/epoch - 50ms/step\n",
      "Epoch 376/2000\n",
      "1/1 - 0s - loss: 2.5922 - val_loss: 2.6536 - 52ms/epoch - 52ms/step\n",
      "Epoch 377/2000\n",
      "1/1 - 0s - loss: 2.5912 - val_loss: 2.6530 - 49ms/epoch - 49ms/step\n",
      "Epoch 378/2000\n",
      "1/1 - 0s - loss: 2.5902 - val_loss: 2.6516 - 63ms/epoch - 63ms/step\n",
      "Epoch 379/2000\n",
      "1/1 - 0s - loss: 2.5894 - val_loss: 2.6500 - 58ms/epoch - 58ms/step\n",
      "Epoch 380/2000\n",
      "1/1 - 0s - loss: 2.5888 - val_loss: 2.6508 - 46ms/epoch - 46ms/step\n",
      "Epoch 381/2000\n",
      "1/1 - 0s - loss: 2.5882 - val_loss: 2.6487 - 59ms/epoch - 59ms/step\n",
      "Epoch 382/2000\n",
      "1/1 - 0s - loss: 2.5873 - val_loss: 2.6484 - 79ms/epoch - 79ms/step\n",
      "Epoch 383/2000\n",
      "1/1 - 0s - loss: 2.5863 - val_loss: 2.6470 - 51ms/epoch - 51ms/step\n",
      "Epoch 384/2000\n",
      "1/1 - 0s - loss: 2.5853 - val_loss: 2.6458 - 77ms/epoch - 77ms/step\n",
      "Epoch 385/2000\n",
      "1/1 - 0s - loss: 2.5847 - val_loss: 2.6464 - 62ms/epoch - 62ms/step\n",
      "Epoch 386/2000\n",
      "1/1 - 0s - loss: 2.5840 - val_loss: 2.6450 - 57ms/epoch - 57ms/step\n",
      "Epoch 387/2000\n",
      "1/1 - 0s - loss: 2.5833 - val_loss: 2.6450 - 69ms/epoch - 69ms/step\n",
      "Epoch 388/2000\n",
      "1/1 - 0s - loss: 2.5824 - val_loss: 2.6432 - 57ms/epoch - 57ms/step\n",
      "Epoch 389/2000\n",
      "1/1 - 0s - loss: 2.5814 - val_loss: 2.6427 - 71ms/epoch - 71ms/step\n",
      "Epoch 390/2000\n",
      "1/1 - 0s - loss: 2.5806 - val_loss: 2.6428 - 48ms/epoch - 48ms/step\n",
      "Epoch 391/2000\n",
      "1/1 - 0s - loss: 2.5800 - val_loss: 2.6410 - 58ms/epoch - 58ms/step\n",
      "Epoch 392/2000\n",
      "1/1 - 0s - loss: 2.5793 - val_loss: 2.6410 - 61ms/epoch - 61ms/step\n",
      "Epoch 393/2000\n",
      "1/1 - 0s - loss: 2.5785 - val_loss: 2.6389 - 71ms/epoch - 71ms/step\n",
      "Epoch 394/2000\n",
      "1/1 - 0s - loss: 2.5776 - val_loss: 2.6392 - 44ms/epoch - 44ms/step\n",
      "Epoch 395/2000\n",
      "1/1 - 0s - loss: 2.5767 - val_loss: 2.6379 - 62ms/epoch - 62ms/step\n",
      "Epoch 396/2000\n",
      "1/1 - 0s - loss: 2.5758 - val_loss: 2.6368 - 51ms/epoch - 51ms/step\n",
      "Epoch 397/2000\n",
      "1/1 - 0s - loss: 2.5750 - val_loss: 2.6364 - 56ms/epoch - 56ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 398/2000\n",
      "1/1 - 0s - loss: 2.5743 - val_loss: 2.6349 - 61ms/epoch - 61ms/step\n",
      "Epoch 399/2000\n",
      "1/1 - 0s - loss: 2.5737 - val_loss: 2.6363 - 65ms/epoch - 65ms/step\n",
      "Epoch 400/2000\n",
      "1/1 - 0s - loss: 2.5732 - val_loss: 2.6339 - 43ms/epoch - 43ms/step\n",
      "Epoch 401/2000\n",
      "1/1 - 0s - loss: 2.5730 - val_loss: 2.6353 - 42ms/epoch - 42ms/step\n",
      "Epoch 402/2000\n",
      "1/1 - 0s - loss: 2.5719 - val_loss: 2.6321 - 49ms/epoch - 49ms/step\n",
      "Epoch 403/2000\n",
      "1/1 - 0s - loss: 2.5706 - val_loss: 2.6318 - 58ms/epoch - 58ms/step\n",
      "Epoch 404/2000\n",
      "1/1 - 0s - loss: 2.5697 - val_loss: 2.6322 - 63ms/epoch - 63ms/step\n",
      "Epoch 405/2000\n",
      "1/1 - 0s - loss: 2.5691 - val_loss: 2.6307 - 57ms/epoch - 57ms/step\n",
      "Epoch 406/2000\n",
      "1/1 - 0s - loss: 2.5690 - val_loss: 2.6320 - 47ms/epoch - 47ms/step\n",
      "Epoch 407/2000\n",
      "1/1 - 0s - loss: 2.5679 - val_loss: 2.6288 - 62ms/epoch - 62ms/step\n",
      "Epoch 408/2000\n",
      "1/1 - 0s - loss: 2.5668 - val_loss: 2.6283 - 48ms/epoch - 48ms/step\n",
      "Epoch 409/2000\n",
      "1/1 - 0s - loss: 2.5661 - val_loss: 2.6304 - 59ms/epoch - 59ms/step\n",
      "Epoch 410/2000\n",
      "1/1 - 0s - loss: 2.5657 - val_loss: 2.6280 - 58ms/epoch - 58ms/step\n",
      "Epoch 411/2000\n",
      "1/1 - 0s - loss: 2.5654 - val_loss: 2.6283 - 45ms/epoch - 45ms/step\n",
      "Epoch 412/2000\n",
      "1/1 - 0s - loss: 2.5637 - val_loss: 2.6269 - 45ms/epoch - 45ms/step\n",
      "Epoch 413/2000\n",
      "1/1 - 0s - loss: 2.5629 - val_loss: 2.6250 - 44ms/epoch - 44ms/step\n",
      "Epoch 414/2000\n",
      "1/1 - 0s - loss: 2.5629 - val_loss: 2.6273 - 43ms/epoch - 43ms/step\n",
      "Epoch 415/2000\n",
      "1/1 - 0s - loss: 2.5622 - val_loss: 2.6245 - 44ms/epoch - 44ms/step\n",
      "Epoch 416/2000\n",
      "1/1 - 0s - loss: 2.5608 - val_loss: 2.6233 - 65ms/epoch - 65ms/step\n",
      "Epoch 417/2000\n",
      "1/1 - 0s - loss: 2.5599 - val_loss: 2.6248 - 79ms/epoch - 79ms/step\n",
      "Epoch 418/2000\n",
      "1/1 - 0s - loss: 2.5601 - val_loss: 2.6227 - 50ms/epoch - 50ms/step\n",
      "Epoch 419/2000\n",
      "1/1 - 0s - loss: 2.5600 - val_loss: 2.6235 - 45ms/epoch - 45ms/step\n",
      "Epoch 420/2000\n",
      "1/1 - 0s - loss: 2.5582 - val_loss: 2.6227 - 50ms/epoch - 50ms/step\n",
      "Epoch 421/2000\n",
      "1/1 - 0s - loss: 2.5576 - val_loss: 2.6192 - 48ms/epoch - 48ms/step\n",
      "Epoch 422/2000\n",
      "1/1 - 0s - loss: 2.5571 - val_loss: 2.6198 - 46ms/epoch - 46ms/step\n",
      "Epoch 423/2000\n",
      "1/1 - 0s - loss: 2.5560 - val_loss: 2.6188 - 46ms/epoch - 46ms/step\n",
      "Epoch 424/2000\n",
      "1/1 - 0s - loss: 2.5549 - val_loss: 2.6183 - 58ms/epoch - 58ms/step\n",
      "Epoch 425/2000\n",
      "1/1 - 0s - loss: 2.5546 - val_loss: 2.6184 - 58ms/epoch - 58ms/step\n",
      "Epoch 426/2000\n",
      "1/1 - 0s - loss: 2.5534 - val_loss: 2.6166 - 52ms/epoch - 52ms/step\n",
      "Epoch 427/2000\n",
      "1/1 - 0s - loss: 2.5528 - val_loss: 2.6160 - 46ms/epoch - 46ms/step\n",
      "Epoch 428/2000\n",
      "1/1 - 0s - loss: 2.5520 - val_loss: 2.6161 - 60ms/epoch - 60ms/step\n",
      "Epoch 429/2000\n",
      "1/1 - 0s - loss: 2.5512 - val_loss: 2.6149 - 63ms/epoch - 63ms/step\n",
      "Epoch 430/2000\n",
      "1/1 - 0s - loss: 2.5505 - val_loss: 2.6143 - 61ms/epoch - 61ms/step\n",
      "Epoch 431/2000\n",
      "1/1 - 0s - loss: 2.5495 - val_loss: 2.6131 - 55ms/epoch - 55ms/step\n",
      "Epoch 432/2000\n",
      "1/1 - 0s - loss: 2.5490 - val_loss: 2.6122 - 61ms/epoch - 61ms/step\n",
      "Epoch 433/2000\n",
      "1/1 - 0s - loss: 2.5480 - val_loss: 2.6121 - 47ms/epoch - 47ms/step\n",
      "Epoch 434/2000\n",
      "1/1 - 0s - loss: 2.5473 - val_loss: 2.6109 - 55ms/epoch - 55ms/step\n",
      "Epoch 435/2000\n",
      "1/1 - 0s - loss: 2.5465 - val_loss: 2.6099 - 48ms/epoch - 48ms/step\n",
      "Epoch 436/2000\n",
      "1/1 - 0s - loss: 2.5457 - val_loss: 2.6087 - 64ms/epoch - 64ms/step\n",
      "Epoch 437/2000\n",
      "1/1 - 0s - loss: 2.5450 - val_loss: 2.6093 - 46ms/epoch - 46ms/step\n",
      "Epoch 438/2000\n",
      "1/1 - 0s - loss: 2.5441 - val_loss: 2.6082 - 48ms/epoch - 48ms/step\n",
      "Epoch 439/2000\n",
      "1/1 - 0s - loss: 2.5435 - val_loss: 2.6071 - 52ms/epoch - 52ms/step\n",
      "Epoch 440/2000\n",
      "1/1 - 0s - loss: 2.5426 - val_loss: 2.6061 - 45ms/epoch - 45ms/step\n",
      "Epoch 441/2000\n",
      "1/1 - 0s - loss: 2.5419 - val_loss: 2.6066 - 57ms/epoch - 57ms/step\n",
      "Epoch 442/2000\n",
      "1/1 - 0s - loss: 2.5411 - val_loss: 2.6047 - 66ms/epoch - 66ms/step\n",
      "Epoch 443/2000\n",
      "1/1 - 0s - loss: 2.5405 - val_loss: 2.6053 - 43ms/epoch - 43ms/step\n",
      "Epoch 444/2000\n",
      "1/1 - 0s - loss: 2.5398 - val_loss: 2.6031 - 47ms/epoch - 47ms/step\n",
      "Epoch 445/2000\n",
      "1/1 - 0s - loss: 2.5393 - val_loss: 2.6045 - 44ms/epoch - 44ms/step\n",
      "Epoch 446/2000\n",
      "1/1 - 0s - loss: 2.5384 - val_loss: 2.6019 - 48ms/epoch - 48ms/step\n",
      "Epoch 447/2000\n",
      "1/1 - 0s - loss: 2.5376 - val_loss: 2.6017 - 65ms/epoch - 65ms/step\n",
      "Epoch 448/2000\n",
      "1/1 - 0s - loss: 2.5369 - val_loss: 2.6013 - 61ms/epoch - 61ms/step\n",
      "Epoch 449/2000\n",
      "1/1 - 0s - loss: 2.5359 - val_loss: 2.5996 - 49ms/epoch - 49ms/step\n",
      "Epoch 450/2000\n",
      "1/1 - 0s - loss: 2.5351 - val_loss: 2.5998 - 46ms/epoch - 46ms/step\n",
      "Epoch 451/2000\n",
      "1/1 - 0s - loss: 2.5343 - val_loss: 2.5983 - 46ms/epoch - 46ms/step\n",
      "Epoch 452/2000\n",
      "1/1 - 0s - loss: 2.5340 - val_loss: 2.5998 - 43ms/epoch - 43ms/step\n",
      "Epoch 453/2000\n",
      "1/1 - 0s - loss: 2.5333 - val_loss: 2.5975 - 55ms/epoch - 55ms/step\n",
      "Epoch 454/2000\n",
      "1/1 - 0s - loss: 2.5324 - val_loss: 2.5967 - 44ms/epoch - 44ms/step\n",
      "Epoch 455/2000\n",
      "1/1 - 0s - loss: 2.5313 - val_loss: 2.5952 - 58ms/epoch - 58ms/step\n",
      "Epoch 456/2000\n",
      "1/1 - 0s - loss: 2.5304 - val_loss: 2.5954 - 56ms/epoch - 56ms/step\n",
      "Epoch 457/2000\n",
      "1/1 - 0s - loss: 2.5298 - val_loss: 2.5936 - 51ms/epoch - 51ms/step\n",
      "Epoch 458/2000\n",
      "1/1 - 0s - loss: 2.5291 - val_loss: 2.5936 - 62ms/epoch - 62ms/step\n",
      "Epoch 459/2000\n",
      "1/1 - 0s - loss: 2.5283 - val_loss: 2.5926 - 54ms/epoch - 54ms/step\n",
      "Epoch 460/2000\n",
      "1/1 - 0s - loss: 2.5274 - val_loss: 2.5927 - 57ms/epoch - 57ms/step\n",
      "Epoch 461/2000\n",
      "1/1 - 0s - loss: 2.5268 - val_loss: 2.5918 - 45ms/epoch - 45ms/step\n",
      "Epoch 462/2000\n",
      "1/1 - 0s - loss: 2.5263 - val_loss: 2.5901 - 42ms/epoch - 42ms/step\n",
      "Epoch 463/2000\n",
      "1/1 - 0s - loss: 2.5257 - val_loss: 2.5923 - 42ms/epoch - 42ms/step\n",
      "Epoch 464/2000\n",
      "1/1 - 0s - loss: 2.5258 - val_loss: 2.5903 - 40ms/epoch - 40ms/step\n",
      "Epoch 465/2000\n",
      "1/1 - 0s - loss: 2.5272 - val_loss: 2.5924 - 43ms/epoch - 43ms/step\n",
      "Epoch 466/2000\n",
      "1/1 - 0s - loss: 2.5248 - val_loss: 2.5894 - 57ms/epoch - 57ms/step\n",
      "Epoch 467/2000\n",
      "1/1 - 0s - loss: 2.5229 - val_loss: 2.5878 - 75ms/epoch - 75ms/step\n",
      "Epoch 468/2000\n",
      "1/1 - 0s - loss: 2.5246 - val_loss: 2.5921 - 50ms/epoch - 50ms/step\n",
      "Epoch 469/2000\n",
      "1/1 - 0s - loss: 2.5244 - val_loss: 2.5896 - 48ms/epoch - 48ms/step\n",
      "Epoch 470/2000\n",
      "1/1 - 0s - loss: 2.5240 - val_loss: 2.5902 - 44ms/epoch - 44ms/step\n",
      "Epoch 471/2000\n",
      "1/1 - 0s - loss: 2.5223 - val_loss: 2.5925 - 52ms/epoch - 52ms/step\n",
      "Epoch 472/2000\n",
      "1/1 - 0s - loss: 2.5231 - val_loss: 2.5858 - 44ms/epoch - 44ms/step\n",
      "Epoch 473/2000\n",
      "1/1 - 0s - loss: 2.5208 - val_loss: 2.5843 - 57ms/epoch - 57ms/step\n",
      "Epoch 474/2000\n",
      "1/1 - 0s - loss: 2.5196 - val_loss: 2.5864 - 59ms/epoch - 59ms/step\n",
      "Epoch 475/2000\n",
      "1/1 - 0s - loss: 2.5196 - val_loss: 2.5847 - 42ms/epoch - 42ms/step\n",
      "Epoch 476/2000\n",
      "1/1 - 0s - loss: 2.5187 - val_loss: 2.5853 - 45ms/epoch - 45ms/step\n",
      "Epoch 477/2000\n",
      "1/1 - 0s - loss: 2.5172 - val_loss: 2.5854 - 42ms/epoch - 42ms/step\n",
      "Epoch 478/2000\n",
      "1/1 - 0s - loss: 2.5171 - val_loss: 2.5811 - 52ms/epoch - 52ms/step\n",
      "Epoch 479/2000\n",
      "1/1 - 0s - loss: 2.5153 - val_loss: 2.5813 - 42ms/epoch - 42ms/step\n",
      "Epoch 480/2000\n",
      "1/1 - 0s - loss: 2.5152 - val_loss: 2.5817 - 57ms/epoch - 57ms/step\n",
      "Epoch 481/2000\n",
      "1/1 - 0s - loss: 2.5142 - val_loss: 2.5796 - 45ms/epoch - 45ms/step\n",
      "Epoch 482/2000\n",
      "1/1 - 0s - loss: 2.5136 - val_loss: 2.5798 - 55ms/epoch - 55ms/step\n",
      "Epoch 483/2000\n",
      "1/1 - 0s - loss: 2.5125 - val_loss: 2.5811 - 42ms/epoch - 42ms/step\n",
      "Epoch 484/2000\n",
      "1/1 - 0s - loss: 2.5121 - val_loss: 2.5766 - 46ms/epoch - 46ms/step\n",
      "Epoch 485/2000\n",
      "1/1 - 0s - loss: 2.5109 - val_loss: 2.5755 - 51ms/epoch - 51ms/step\n",
      "Epoch 486/2000\n",
      "1/1 - 0s - loss: 2.5105 - val_loss: 2.5767 - 42ms/epoch - 42ms/step\n",
      "Epoch 487/2000\n",
      "1/1 - 0s - loss: 2.5095 - val_loss: 2.5755 - 43ms/epoch - 43ms/step\n",
      "Epoch 488/2000\n",
      "1/1 - 0s - loss: 2.5088 - val_loss: 2.5746 - 45ms/epoch - 45ms/step\n",
      "Epoch 489/2000\n",
      "1/1 - 0s - loss: 2.5080 - val_loss: 2.5742 - 58ms/epoch - 58ms/step\n",
      "Epoch 490/2000\n",
      "1/1 - 0s - loss: 2.5073 - val_loss: 2.5717 - 56ms/epoch - 56ms/step\n",
      "Epoch 491/2000\n",
      "1/1 - 0s - loss: 2.5066 - val_loss: 2.5718 - 46ms/epoch - 46ms/step\n",
      "Epoch 492/2000\n",
      "1/1 - 0s - loss: 2.5057 - val_loss: 2.5721 - 48ms/epoch - 48ms/step\n",
      "Epoch 493/2000\n",
      "1/1 - 0s - loss: 2.5049 - val_loss: 2.5709 - 50ms/epoch - 50ms/step\n",
      "Epoch 494/2000\n",
      "1/1 - 0s - loss: 2.5044 - val_loss: 2.5708 - 57ms/epoch - 57ms/step\n",
      "Epoch 495/2000\n",
      "1/1 - 0s - loss: 2.5034 - val_loss: 2.5698 - 60ms/epoch - 60ms/step\n",
      "Epoch 496/2000\n",
      "1/1 - 0s - loss: 2.5027 - val_loss: 2.5684 - 45ms/epoch - 45ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 497/2000\n",
      "1/1 - 0s - loss: 2.5021 - val_loss: 2.5689 - 48ms/epoch - 48ms/step\n",
      "Epoch 498/2000\n",
      "1/1 - 0s - loss: 2.5012 - val_loss: 2.5679 - 46ms/epoch - 46ms/step\n",
      "Epoch 499/2000\n",
      "1/1 - 0s - loss: 2.5005 - val_loss: 2.5676 - 64ms/epoch - 64ms/step\n",
      "Epoch 500/2000\n",
      "1/1 - 0s - loss: 2.4998 - val_loss: 2.5661 - 42ms/epoch - 42ms/step\n",
      "Epoch 501/2000\n",
      "1/1 - 0s - loss: 2.4993 - val_loss: 2.5665 - 41ms/epoch - 41ms/step\n",
      "Epoch 502/2000\n",
      "1/1 - 0s - loss: 2.4989 - val_loss: 2.5660 - 46ms/epoch - 46ms/step\n",
      "Epoch 503/2000\n",
      "1/1 - 0s - loss: 2.4987 - val_loss: 2.5670 - 46ms/epoch - 46ms/step\n",
      "Epoch 504/2000\n",
      "1/1 - 0s - loss: 2.4989 - val_loss: 2.5653 - 50ms/epoch - 50ms/step\n",
      "Epoch 505/2000\n",
      "1/1 - 0s - loss: 2.4983 - val_loss: 2.5650 - 43ms/epoch - 43ms/step\n",
      "Epoch 506/2000\n",
      "1/1 - 0s - loss: 2.4964 - val_loss: 2.5636 - 47ms/epoch - 47ms/step\n",
      "Epoch 507/2000\n",
      "1/1 - 0s - loss: 2.4951 - val_loss: 2.5638 - 41ms/epoch - 41ms/step\n",
      "Epoch 508/2000\n",
      "1/1 - 0s - loss: 2.4954 - val_loss: 2.5653 - 48ms/epoch - 48ms/step\n",
      "Epoch 509/2000\n",
      "1/1 - 0s - loss: 2.4977 - val_loss: 2.5636 - 50ms/epoch - 50ms/step\n",
      "Epoch 510/2000\n",
      "1/1 - 0s - loss: 2.4966 - val_loss: 2.5622 - 66ms/epoch - 66ms/step\n",
      "Epoch 511/2000\n",
      "1/1 - 0s - loss: 2.4934 - val_loss: 2.5609 - 43ms/epoch - 43ms/step\n",
      "Epoch 512/2000\n",
      "1/1 - 0s - loss: 2.4925 - val_loss: 2.5620 - 43ms/epoch - 43ms/step\n",
      "Epoch 513/2000\n",
      "1/1 - 0s - loss: 2.4939 - val_loss: 2.5614 - 44ms/epoch - 44ms/step\n",
      "Epoch 514/2000\n",
      "1/1 - 0s - loss: 2.4924 - val_loss: 2.5601 - 46ms/epoch - 46ms/step\n",
      "Epoch 515/2000\n",
      "1/1 - 0s - loss: 2.4903 - val_loss: 2.5597 - 55ms/epoch - 55ms/step\n",
      "Epoch 516/2000\n",
      "1/1 - 0s - loss: 2.4907 - val_loss: 2.5603 - 56ms/epoch - 56ms/step\n",
      "Epoch 517/2000\n",
      "1/1 - 0s - loss: 2.4912 - val_loss: 2.5581 - 52ms/epoch - 52ms/step\n",
      "Epoch 518/2000\n",
      "1/1 - 0s - loss: 2.4891 - val_loss: 2.5569 - 46ms/epoch - 46ms/step\n",
      "Epoch 519/2000\n",
      "1/1 - 0s - loss: 2.4881 - val_loss: 2.5587 - 43ms/epoch - 43ms/step\n",
      "Epoch 520/2000\n",
      "1/1 - 0s - loss: 2.4882 - val_loss: 2.5573 - 43ms/epoch - 43ms/step\n",
      "Epoch 521/2000\n",
      "1/1 - 0s - loss: 2.4875 - val_loss: 2.5550 - 100ms/epoch - 100ms/step\n",
      "Epoch 522/2000\n",
      "1/1 - 0s - loss: 2.4865 - val_loss: 2.5585 - 57ms/epoch - 57ms/step\n",
      "Epoch 523/2000\n",
      "1/1 - 0s - loss: 2.4866 - val_loss: 2.5553 - 53ms/epoch - 53ms/step\n",
      "Epoch 524/2000\n",
      "1/1 - 0s - loss: 2.4862 - val_loss: 2.5555 - 45ms/epoch - 45ms/step\n",
      "Epoch 525/2000\n",
      "1/1 - 0s - loss: 2.4851 - val_loss: 2.5558 - 47ms/epoch - 47ms/step\n",
      "Epoch 526/2000\n",
      "1/1 - 0s - loss: 2.4846 - val_loss: 2.5531 - 46ms/epoch - 46ms/step\n",
      "Epoch 527/2000\n",
      "1/1 - 0s - loss: 2.4848 - val_loss: 2.5552 - 44ms/epoch - 44ms/step\n",
      "Epoch 528/2000\n",
      "1/1 - 0s - loss: 2.4835 - val_loss: 2.5531 - 42ms/epoch - 42ms/step\n",
      "Epoch 529/2000\n",
      "1/1 - 0s - loss: 2.4827 - val_loss: 2.5516 - 44ms/epoch - 44ms/step\n",
      "Epoch 530/2000\n",
      "1/1 - 0s - loss: 2.4823 - val_loss: 2.5532 - 40ms/epoch - 40ms/step\n",
      "Epoch 531/2000\n",
      "1/1 - 0s - loss: 2.4816 - val_loss: 2.5508 - 41ms/epoch - 41ms/step\n",
      "Epoch 532/2000\n",
      "1/1 - 0s - loss: 2.4810 - val_loss: 2.5511 - 40ms/epoch - 40ms/step\n",
      "Epoch 533/2000\n",
      "1/1 - 0s - loss: 2.4807 - val_loss: 2.5520 - 39ms/epoch - 39ms/step\n",
      "Epoch 534/2000\n",
      "1/1 - 0s - loss: 2.4805 - val_loss: 2.5525 - 41ms/epoch - 41ms/step\n",
      "Epoch 535/2000\n",
      "1/1 - 0s - loss: 2.4802 - val_loss: 2.5486 - 36ms/epoch - 36ms/step\n",
      "Epoch 536/2000\n",
      "1/1 - 0s - loss: 2.4802 - val_loss: 2.5508 - 34ms/epoch - 34ms/step\n",
      "Epoch 537/2000\n",
      "1/1 - 0s - loss: 2.4791 - val_loss: 2.5490 - 35ms/epoch - 35ms/step\n",
      "Epoch 538/2000\n",
      "1/1 - 0s - loss: 2.4776 - val_loss: 2.5493 - 35ms/epoch - 35ms/step\n",
      "Epoch 539/2000\n",
      "1/1 - 0s - loss: 2.4772 - val_loss: 2.5501 - 33ms/epoch - 33ms/step\n",
      "Epoch 540/2000\n",
      "1/1 - 0s - loss: 2.4772 - val_loss: 2.5464 - 35ms/epoch - 35ms/step\n",
      "Epoch 541/2000\n",
      "1/1 - 0s - loss: 2.4770 - val_loss: 2.5470 - 34ms/epoch - 34ms/step\n",
      "Epoch 542/2000\n",
      "1/1 - 0s - loss: 2.4757 - val_loss: 2.5480 - 33ms/epoch - 33ms/step\n",
      "Epoch 543/2000\n",
      "1/1 - 0s - loss: 2.4753 - val_loss: 2.5450 - 35ms/epoch - 35ms/step\n",
      "Epoch 544/2000\n",
      "1/1 - 0s - loss: 2.4747 - val_loss: 2.5463 - 34ms/epoch - 34ms/step\n",
      "Epoch 545/2000\n",
      "1/1 - 0s - loss: 2.4740 - val_loss: 2.5462 - 32ms/epoch - 32ms/step\n",
      "Epoch 546/2000\n",
      "1/1 - 0s - loss: 2.4736 - val_loss: 2.5453 - 33ms/epoch - 33ms/step\n",
      "Epoch 547/2000\n",
      "1/1 - 0s - loss: 2.4738 - val_loss: 2.5472 - 33ms/epoch - 33ms/step\n",
      "Epoch 548/2000\n",
      "1/1 - 0s - loss: 2.4747 - val_loss: 2.5464 - 44ms/epoch - 44ms/step\n"
     ]
    }
   ],
   "source": [
    "embed_dim = 10\n",
    "num_head = 2\n",
    "\n",
    "callback = keras.callbacks.EarlyStopping(monitor = 'val_loss', patience = 5, restore_best_weights = True)\n",
    "inputs = layers.Input((x_masked_train.shape[1],), dtype=tf.int64)\n",
    "word_embeddings = layers.Embedding(N, embed_dim, name=\"word_embedding\")(inputs)\n",
    "encoder_output = word_embeddings\n",
    "\n",
    "for i in range(1):\n",
    "    encoder_output = bert_module(encoder_output, encoder_output, encoder_output, embed_dim, num_head, i)\n",
    "\n",
    "encoder_output = keras.layers.Lambda(lambda x: x[:,:-1,:], name='slice')(encoder_output)\n",
    "mlm_output = layers.Dense(N, name=\"mlm_cls\", activation=\"softmax\")(encoder_output)\n",
    "mlm_model = keras.Model(inputs = inputs, outputs = mlm_output)\n",
    "adam = Adam()\n",
    "mlm_model.compile(loss='sparse_categorical_crossentropy', optimizer=adam)\n",
    "history = mlm_model.fit(x_masked_train, y_masked_labels_train,\n",
    "                        validation_split = 0.5, callbacks = [callback], \n",
    "                        epochs=2000, batch_size=5000, \n",
    "                        verbose=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "aa33aaff",
   "metadata": {},
   "outputs": [],
   "source": [
    "acc = []\n",
    "prob = []\n",
    "x_test_subset = x_masked_test[np.random.choice(x_masked_test.shape[0], size=1000, replace=False)]\n",
    "\n",
    "for sentence_number in x_test_subset:\n",
    "    temp = keras.backend.function(inputs = mlm_model.layers[0].input, outputs = mlm_model.layers[-1].output) \\\n",
    "        (np.array(sentence_number).reshape(1,len(sentence_number)))\n",
    "    temp = temp[:,-1,:]\n",
    "    acc.append(1 if temp.argmax() == sentence_number[-1] else 0)\n",
    "    prob.append(temp[0][sentence_number[-1]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "fe4d2103",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.0, 0.013756235)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(np.mean(acc), np.mean(prob))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "0005b730",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2000\n",
      "1/1 - 1s - loss: 3.6980 - val_loss: 3.4224 - 1s/epoch - 1s/step\n",
      "Epoch 2/2000\n",
      "1/1 - 0s - loss: 3.3798 - val_loss: 3.2408 - 136ms/epoch - 136ms/step\n",
      "Epoch 3/2000\n",
      "1/1 - 0s - loss: 3.1896 - val_loss: 3.1299 - 109ms/epoch - 109ms/step\n",
      "Epoch 4/2000\n",
      "1/1 - 0s - loss: 3.0727 - val_loss: 3.0596 - 145ms/epoch - 145ms/step\n",
      "Epoch 5/2000\n",
      "1/1 - 0s - loss: 2.9989 - val_loss: 3.0090 - 129ms/epoch - 129ms/step\n",
      "Epoch 6/2000\n",
      "1/1 - 0s - loss: 2.9465 - val_loss: 2.9705 - 104ms/epoch - 104ms/step\n",
      "Epoch 7/2000\n",
      "1/1 - 0s - loss: 2.9068 - val_loss: 2.9417 - 123ms/epoch - 123ms/step\n",
      "Epoch 8/2000\n",
      "1/1 - 0s - loss: 2.8766 - val_loss: 2.9195 - 109ms/epoch - 109ms/step\n",
      "Epoch 9/2000\n",
      "1/1 - 0s - loss: 2.8526 - val_loss: 2.9008 - 109ms/epoch - 109ms/step\n",
      "Epoch 10/2000\n",
      "1/1 - 0s - loss: 2.8317 - val_loss: 2.8837 - 137ms/epoch - 137ms/step\n",
      "Epoch 11/2000\n",
      "1/1 - 0s - loss: 2.8122 - val_loss: 2.8683 - 120ms/epoch - 120ms/step\n",
      "Epoch 12/2000\n",
      "1/1 - 0s - loss: 2.7943 - val_loss: 2.8546 - 127ms/epoch - 127ms/step\n",
      "Epoch 13/2000\n",
      "1/1 - 0s - loss: 2.7785 - val_loss: 2.8420 - 109ms/epoch - 109ms/step\n",
      "Epoch 14/2000\n",
      "1/1 - 0s - loss: 2.7640 - val_loss: 2.8291 - 106ms/epoch - 106ms/step\n",
      "Epoch 15/2000\n",
      "1/1 - 0s - loss: 2.7496 - val_loss: 2.8158 - 108ms/epoch - 108ms/step\n",
      "Epoch 16/2000\n",
      "1/1 - 0s - loss: 2.7347 - val_loss: 2.8025 - 114ms/epoch - 114ms/step\n",
      "Epoch 17/2000\n",
      "1/1 - 0s - loss: 2.7197 - val_loss: 2.7900 - 134ms/epoch - 134ms/step\n",
      "Epoch 18/2000\n",
      "1/1 - 0s - loss: 2.7051 - val_loss: 2.7787 - 127ms/epoch - 127ms/step\n",
      "Epoch 19/2000\n",
      "1/1 - 0s - loss: 2.6915 - val_loss: 2.7691 - 164ms/epoch - 164ms/step\n",
      "Epoch 20/2000\n",
      "1/1 - 0s - loss: 2.6796 - val_loss: 2.7611 - 130ms/epoch - 130ms/step\n",
      "Epoch 21/2000\n",
      "1/1 - 0s - loss: 2.6692 - val_loss: 2.7543 - 149ms/epoch - 149ms/step\n",
      "Epoch 22/2000\n",
      "1/1 - 0s - loss: 2.6601 - val_loss: 2.7485 - 127ms/epoch - 127ms/step\n",
      "Epoch 23/2000\n",
      "1/1 - 0s - loss: 2.6521 - val_loss: 2.7438 - 115ms/epoch - 115ms/step\n",
      "Epoch 24/2000\n",
      "1/1 - 0s - loss: 2.6452 - val_loss: 2.7402 - 114ms/epoch - 114ms/step\n",
      "Epoch 25/2000\n",
      "1/1 - 0s - loss: 2.6392 - val_loss: 2.7376 - 145ms/epoch - 145ms/step\n",
      "Epoch 26/2000\n",
      "1/1 - 0s - loss: 2.6340 - val_loss: 2.7358 - 143ms/epoch - 143ms/step\n",
      "Epoch 27/2000\n",
      "1/1 - 0s - loss: 2.6294 - val_loss: 2.7343 - 115ms/epoch - 115ms/step\n",
      "Epoch 28/2000\n",
      "1/1 - 0s - loss: 2.6251 - val_loss: 2.7328 - 130ms/epoch - 130ms/step\n",
      "Epoch 29/2000\n",
      "1/1 - 0s - loss: 2.6210 - val_loss: 2.7312 - 116ms/epoch - 116ms/step\n",
      "Epoch 30/2000\n",
      "1/1 - 0s - loss: 2.6171 - val_loss: 2.7292 - 166ms/epoch - 166ms/step\n",
      "Epoch 31/2000\n",
      "1/1 - 0s - loss: 2.6133 - val_loss: 2.7272 - 169ms/epoch - 169ms/step\n",
      "Epoch 32/2000\n",
      "1/1 - 0s - loss: 2.6096 - val_loss: 2.7253 - 113ms/epoch - 113ms/step\n",
      "Epoch 33/2000\n",
      "1/1 - 0s - loss: 2.6061 - val_loss: 2.7237 - 128ms/epoch - 128ms/step\n",
      "Epoch 34/2000\n",
      "1/1 - 0s - loss: 2.6027 - val_loss: 2.7225 - 118ms/epoch - 118ms/step\n",
      "Epoch 35/2000\n",
      "1/1 - 0s - loss: 2.5995 - val_loss: 2.7215 - 163ms/epoch - 163ms/step\n",
      "Epoch 36/2000\n",
      "1/1 - 0s - loss: 2.5964 - val_loss: 2.7209 - 168ms/epoch - 168ms/step\n",
      "Epoch 37/2000\n",
      "1/1 - 0s - loss: 2.5936 - val_loss: 2.7204 - 122ms/epoch - 122ms/step\n",
      "Epoch 38/2000\n",
      "1/1 - 0s - loss: 2.5910 - val_loss: 2.7199 - 123ms/epoch - 123ms/step\n",
      "Epoch 39/2000\n",
      "1/1 - 0s - loss: 2.5886 - val_loss: 2.7193 - 119ms/epoch - 119ms/step\n",
      "Epoch 40/2000\n",
      "1/1 - 0s - loss: 2.5863 - val_loss: 2.7186 - 133ms/epoch - 133ms/step\n",
      "Epoch 41/2000\n",
      "1/1 - 0s - loss: 2.5842 - val_loss: 2.7177 - 140ms/epoch - 140ms/step\n",
      "Epoch 42/2000\n",
      "1/1 - 0s - loss: 2.5820 - val_loss: 2.7167 - 143ms/epoch - 143ms/step\n",
      "Epoch 43/2000\n",
      "1/1 - 0s - loss: 2.5798 - val_loss: 2.7158 - 110ms/epoch - 110ms/step\n",
      "Epoch 44/2000\n",
      "1/1 - 0s - loss: 2.5776 - val_loss: 2.7149 - 110ms/epoch - 110ms/step\n",
      "Epoch 45/2000\n",
      "1/1 - 0s - loss: 2.5755 - val_loss: 2.7142 - 135ms/epoch - 135ms/step\n",
      "Epoch 46/2000\n",
      "1/1 - 0s - loss: 2.5734 - val_loss: 2.7137 - 146ms/epoch - 146ms/step\n",
      "Epoch 47/2000\n",
      "1/1 - 0s - loss: 2.5714 - val_loss: 2.7132 - 134ms/epoch - 134ms/step\n",
      "Epoch 48/2000\n",
      "1/1 - 0s - loss: 2.5694 - val_loss: 2.7128 - 151ms/epoch - 151ms/step\n",
      "Epoch 49/2000\n",
      "1/1 - 0s - loss: 2.5673 - val_loss: 2.7125 - 128ms/epoch - 128ms/step\n",
      "Epoch 50/2000\n",
      "1/1 - 0s - loss: 2.5653 - val_loss: 2.7123 - 152ms/epoch - 152ms/step\n",
      "Epoch 51/2000\n",
      "1/1 - 0s - loss: 2.5633 - val_loss: 2.7122 - 131ms/epoch - 131ms/step\n",
      "Epoch 52/2000\n",
      "1/1 - 0s - loss: 2.5613 - val_loss: 2.7121 - 154ms/epoch - 154ms/step\n",
      "Epoch 53/2000\n",
      "1/1 - 0s - loss: 2.5593 - val_loss: 2.7119 - 137ms/epoch - 137ms/step\n",
      "Epoch 54/2000\n",
      "1/1 - 0s - loss: 2.5573 - val_loss: 2.7116 - 188ms/epoch - 188ms/step\n",
      "Epoch 55/2000\n",
      "1/1 - 0s - loss: 2.5552 - val_loss: 2.7113 - 143ms/epoch - 143ms/step\n",
      "Epoch 56/2000\n",
      "1/1 - 0s - loss: 2.5532 - val_loss: 2.7111 - 142ms/epoch - 142ms/step\n",
      "Epoch 57/2000\n",
      "1/1 - 0s - loss: 2.5511 - val_loss: 2.7109 - 123ms/epoch - 123ms/step\n",
      "Epoch 58/2000\n",
      "1/1 - 0s - loss: 2.5490 - val_loss: 2.7107 - 117ms/epoch - 117ms/step\n",
      "Epoch 59/2000\n",
      "1/1 - 0s - loss: 2.5469 - val_loss: 2.7105 - 132ms/epoch - 132ms/step\n",
      "Epoch 60/2000\n",
      "1/1 - 0s - loss: 2.5447 - val_loss: 2.7102 - 177ms/epoch - 177ms/step\n",
      "Epoch 61/2000\n",
      "1/1 - 0s - loss: 2.5425 - val_loss: 2.7099 - 126ms/epoch - 126ms/step\n",
      "Epoch 62/2000\n",
      "1/1 - 0s - loss: 2.5403 - val_loss: 2.7096 - 115ms/epoch - 115ms/step\n",
      "Epoch 63/2000\n",
      "1/1 - 0s - loss: 2.5380 - val_loss: 2.7093 - 114ms/epoch - 114ms/step\n",
      "Epoch 64/2000\n",
      "1/1 - 0s - loss: 2.5358 - val_loss: 2.7090 - 152ms/epoch - 152ms/step\n",
      "Epoch 65/2000\n",
      "1/1 - 0s - loss: 2.5334 - val_loss: 2.7089 - 121ms/epoch - 121ms/step\n",
      "Epoch 66/2000\n",
      "1/1 - 0s - loss: 2.5311 - val_loss: 2.7086 - 140ms/epoch - 140ms/step\n",
      "Epoch 67/2000\n",
      "1/1 - 0s - loss: 2.5287 - val_loss: 2.7085 - 160ms/epoch - 160ms/step\n",
      "Epoch 68/2000\n",
      "1/1 - 0s - loss: 2.5263 - val_loss: 2.7083 - 126ms/epoch - 126ms/step\n",
      "Epoch 69/2000\n",
      "1/1 - 0s - loss: 2.5238 - val_loss: 2.7081 - 143ms/epoch - 143ms/step\n",
      "Epoch 70/2000\n",
      "1/1 - 0s - loss: 2.5213 - val_loss: 2.7080 - 167ms/epoch - 167ms/step\n",
      "Epoch 71/2000\n",
      "1/1 - 0s - loss: 2.5187 - val_loss: 2.7078 - 130ms/epoch - 130ms/step\n",
      "Epoch 72/2000\n",
      "1/1 - 0s - loss: 2.5161 - val_loss: 2.7078 - 118ms/epoch - 118ms/step\n",
      "Epoch 73/2000\n",
      "1/1 - 0s - loss: 2.5134 - val_loss: 2.7077 - 125ms/epoch - 125ms/step\n",
      "Epoch 74/2000\n",
      "1/1 - 0s - loss: 2.5107 - val_loss: 2.7075 - 153ms/epoch - 153ms/step\n",
      "Epoch 75/2000\n",
      "1/1 - 0s - loss: 2.5079 - val_loss: 2.7074 - 133ms/epoch - 133ms/step\n",
      "Epoch 76/2000\n",
      "1/1 - 0s - loss: 2.5050 - val_loss: 2.7072 - 129ms/epoch - 129ms/step\n",
      "Epoch 77/2000\n",
      "1/1 - 0s - loss: 2.5021 - val_loss: 2.7070 - 144ms/epoch - 144ms/step\n",
      "Epoch 78/2000\n",
      "1/1 - 0s - loss: 2.4991 - val_loss: 2.7068 - 167ms/epoch - 167ms/step\n",
      "Epoch 79/2000\n",
      "1/1 - 0s - loss: 2.4961 - val_loss: 2.7065 - 118ms/epoch - 118ms/step\n",
      "Epoch 80/2000\n",
      "1/1 - 0s - loss: 2.4930 - val_loss: 2.7061 - 123ms/epoch - 123ms/step\n",
      "Epoch 81/2000\n",
      "1/1 - 0s - loss: 2.4897 - val_loss: 2.7056 - 143ms/epoch - 143ms/step\n",
      "Epoch 82/2000\n",
      "1/1 - 0s - loss: 2.4864 - val_loss: 2.7050 - 175ms/epoch - 175ms/step\n",
      "Epoch 83/2000\n",
      "1/1 - 0s - loss: 2.4830 - val_loss: 2.7044 - 116ms/epoch - 116ms/step\n",
      "Epoch 84/2000\n",
      "1/1 - 0s - loss: 2.4795 - val_loss: 2.7036 - 113ms/epoch - 113ms/step\n",
      "Epoch 85/2000\n",
      "1/1 - 0s - loss: 2.4759 - val_loss: 2.7029 - 115ms/epoch - 115ms/step\n",
      "Epoch 86/2000\n",
      "1/1 - 0s - loss: 2.4722 - val_loss: 2.7021 - 110ms/epoch - 110ms/step\n",
      "Epoch 87/2000\n",
      "1/1 - 0s - loss: 2.4685 - val_loss: 2.7014 - 124ms/epoch - 124ms/step\n",
      "Epoch 88/2000\n",
      "1/1 - 0s - loss: 2.4646 - val_loss: 2.7007 - 114ms/epoch - 114ms/step\n",
      "Epoch 89/2000\n",
      "1/1 - 0s - loss: 2.4607 - val_loss: 2.6999 - 152ms/epoch - 152ms/step\n",
      "Epoch 90/2000\n",
      "1/1 - 0s - loss: 2.4567 - val_loss: 2.6992 - 113ms/epoch - 113ms/step\n",
      "Epoch 91/2000\n",
      "1/1 - 0s - loss: 2.4525 - val_loss: 2.6984 - 124ms/epoch - 124ms/step\n",
      "Epoch 92/2000\n",
      "1/1 - 0s - loss: 2.4482 - val_loss: 2.6975 - 162ms/epoch - 162ms/step\n",
      "Epoch 93/2000\n",
      "1/1 - 0s - loss: 2.4439 - val_loss: 2.6966 - 125ms/epoch - 125ms/step\n",
      "Epoch 94/2000\n",
      "1/1 - 0s - loss: 2.4394 - val_loss: 2.6955 - 111ms/epoch - 111ms/step\n",
      "Epoch 95/2000\n",
      "1/1 - 0s - loss: 2.4349 - val_loss: 2.6943 - 111ms/epoch - 111ms/step\n",
      "Epoch 96/2000\n",
      "1/1 - 0s - loss: 2.4302 - val_loss: 2.6930 - 118ms/epoch - 118ms/step\n",
      "Epoch 97/2000\n",
      "1/1 - 0s - loss: 2.4254 - val_loss: 2.6915 - 112ms/epoch - 112ms/step\n",
      "Epoch 98/2000\n",
      "1/1 - 0s - loss: 2.4205 - val_loss: 2.6899 - 172ms/epoch - 172ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 99/2000\n",
      "1/1 - 0s - loss: 2.4154 - val_loss: 2.6882 - 146ms/epoch - 146ms/step\n",
      "Epoch 100/2000\n",
      "1/1 - 0s - loss: 2.4102 - val_loss: 2.6864 - 123ms/epoch - 123ms/step\n",
      "Epoch 101/2000\n",
      "1/1 - 0s - loss: 2.4047 - val_loss: 2.6844 - 169ms/epoch - 169ms/step\n",
      "Epoch 102/2000\n",
      "1/1 - 0s - loss: 2.3992 - val_loss: 2.6823 - 106ms/epoch - 106ms/step\n",
      "Epoch 103/2000\n",
      "1/1 - 0s - loss: 2.3935 - val_loss: 2.6801 - 109ms/epoch - 109ms/step\n",
      "Epoch 104/2000\n",
      "1/1 - 0s - loss: 2.3876 - val_loss: 2.6778 - 127ms/epoch - 127ms/step\n",
      "Epoch 105/2000\n",
      "1/1 - 0s - loss: 2.3816 - val_loss: 2.6752 - 118ms/epoch - 118ms/step\n",
      "Epoch 106/2000\n",
      "1/1 - 0s - loss: 2.3754 - val_loss: 2.6724 - 115ms/epoch - 115ms/step\n",
      "Epoch 107/2000\n",
      "1/1 - 0s - loss: 2.3691 - val_loss: 2.6695 - 140ms/epoch - 140ms/step\n",
      "Epoch 108/2000\n",
      "1/1 - 0s - loss: 2.3626 - val_loss: 2.6664 - 121ms/epoch - 121ms/step\n",
      "Epoch 109/2000\n",
      "1/1 - 0s - loss: 2.3560 - val_loss: 2.6632 - 111ms/epoch - 111ms/step\n",
      "Epoch 110/2000\n",
      "1/1 - 0s - loss: 2.3492 - val_loss: 2.6598 - 117ms/epoch - 117ms/step\n",
      "Epoch 111/2000\n",
      "1/1 - 0s - loss: 2.3423 - val_loss: 2.6562 - 131ms/epoch - 131ms/step\n",
      "Epoch 112/2000\n",
      "1/1 - 0s - loss: 2.3352 - val_loss: 2.6525 - 140ms/epoch - 140ms/step\n",
      "Epoch 113/2000\n",
      "1/1 - 0s - loss: 2.3278 - val_loss: 2.6487 - 151ms/epoch - 151ms/step\n",
      "Epoch 114/2000\n",
      "1/1 - 0s - loss: 2.3203 - val_loss: 2.6449 - 177ms/epoch - 177ms/step\n",
      "Epoch 115/2000\n",
      "1/1 - 0s - loss: 2.3126 - val_loss: 2.6411 - 132ms/epoch - 132ms/step\n",
      "Epoch 116/2000\n",
      "1/1 - 0s - loss: 2.3047 - val_loss: 2.6369 - 109ms/epoch - 109ms/step\n",
      "Epoch 117/2000\n",
      "1/1 - 0s - loss: 2.2965 - val_loss: 2.6328 - 128ms/epoch - 128ms/step\n",
      "Epoch 118/2000\n",
      "1/1 - 0s - loss: 2.2883 - val_loss: 2.6285 - 118ms/epoch - 118ms/step\n",
      "Epoch 119/2000\n",
      "1/1 - 0s - loss: 2.2798 - val_loss: 2.6240 - 198ms/epoch - 198ms/step\n",
      "Epoch 120/2000\n",
      "1/1 - 0s - loss: 2.2711 - val_loss: 2.6193 - 223ms/epoch - 223ms/step\n",
      "Epoch 121/2000\n",
      "1/1 - 0s - loss: 2.2623 - val_loss: 2.6144 - 143ms/epoch - 143ms/step\n",
      "Epoch 122/2000\n",
      "1/1 - 0s - loss: 2.2532 - val_loss: 2.6093 - 113ms/epoch - 113ms/step\n",
      "Epoch 123/2000\n",
      "1/1 - 0s - loss: 2.2440 - val_loss: 2.6042 - 136ms/epoch - 136ms/step\n",
      "Epoch 124/2000\n",
      "1/1 - 0s - loss: 2.2348 - val_loss: 2.5987 - 155ms/epoch - 155ms/step\n",
      "Epoch 125/2000\n",
      "1/1 - 0s - loss: 2.2253 - val_loss: 2.5931 - 117ms/epoch - 117ms/step\n",
      "Epoch 126/2000\n",
      "1/1 - 0s - loss: 2.2158 - val_loss: 2.5873 - 148ms/epoch - 148ms/step\n",
      "Epoch 127/2000\n",
      "1/1 - 0s - loss: 2.2061 - val_loss: 2.5816 - 140ms/epoch - 140ms/step\n",
      "Epoch 128/2000\n",
      "1/1 - 0s - loss: 2.1963 - val_loss: 2.5761 - 155ms/epoch - 155ms/step\n",
      "Epoch 129/2000\n",
      "1/1 - 0s - loss: 2.1865 - val_loss: 2.5708 - 228ms/epoch - 228ms/step\n",
      "Epoch 130/2000\n",
      "1/1 - 0s - loss: 2.1767 - val_loss: 2.5659 - 211ms/epoch - 211ms/step\n",
      "Epoch 131/2000\n",
      "1/1 - 0s - loss: 2.1670 - val_loss: 2.5613 - 174ms/epoch - 174ms/step\n",
      "Epoch 132/2000\n",
      "1/1 - 0s - loss: 2.1573 - val_loss: 2.5569 - 174ms/epoch - 174ms/step\n",
      "Epoch 133/2000\n",
      "1/1 - 0s - loss: 2.1477 - val_loss: 2.5530 - 187ms/epoch - 187ms/step\n",
      "Epoch 134/2000\n",
      "1/1 - 0s - loss: 2.1382 - val_loss: 2.5492 - 113ms/epoch - 113ms/step\n",
      "Epoch 135/2000\n",
      "1/1 - 0s - loss: 2.1287 - val_loss: 2.5455 - 164ms/epoch - 164ms/step\n",
      "Epoch 136/2000\n",
      "1/1 - 0s - loss: 2.1192 - val_loss: 2.5423 - 117ms/epoch - 117ms/step\n",
      "Epoch 137/2000\n",
      "1/1 - 0s - loss: 2.1099 - val_loss: 2.5396 - 156ms/epoch - 156ms/step\n",
      "Epoch 138/2000\n",
      "1/1 - 0s - loss: 2.1007 - val_loss: 2.5374 - 154ms/epoch - 154ms/step\n",
      "Epoch 139/2000\n",
      "1/1 - 0s - loss: 2.0916 - val_loss: 2.5353 - 137ms/epoch - 137ms/step\n",
      "Epoch 140/2000\n",
      "1/1 - 0s - loss: 2.0827 - val_loss: 2.5343 - 151ms/epoch - 151ms/step\n",
      "Epoch 141/2000\n",
      "1/1 - 0s - loss: 2.0739 - val_loss: 2.5338 - 148ms/epoch - 148ms/step\n",
      "Epoch 142/2000\n",
      "1/1 - 0s - loss: 2.0654 - val_loss: 2.5335 - 119ms/epoch - 119ms/step\n",
      "Epoch 143/2000\n",
      "1/1 - 0s - loss: 2.0571 - val_loss: 2.5345 - 111ms/epoch - 111ms/step\n",
      "Epoch 144/2000\n",
      "1/1 - 0s - loss: 2.0496 - val_loss: 2.5375 - 122ms/epoch - 122ms/step\n",
      "Epoch 145/2000\n",
      "1/1 - 0s - loss: 2.0443 - val_loss: 2.5435 - 146ms/epoch - 146ms/step\n",
      "Epoch 146/2000\n",
      "1/1 - 0s - loss: 2.0420 - val_loss: 2.5429 - 133ms/epoch - 133ms/step\n",
      "Epoch 147/2000\n",
      "1/1 - 0s - loss: 2.0341 - val_loss: 2.5373 - 168ms/epoch - 168ms/step\n"
     ]
    }
   ],
   "source": [
    "embed_dim = 100\n",
    "num_head = 2\n",
    "\n",
    "callback = keras.callbacks.EarlyStopping(monitor = 'val_loss', patience = 5, restore_best_weights = True)\n",
    "inputs = layers.Input((x_masked_train.shape[1],), dtype=tf.int64)\n",
    "word_embeddings = layers.Embedding(N, embed_dim, name=\"word_embedding\")(inputs)\n",
    "encoder_output = word_embeddings\n",
    "\n",
    "for i in range(1):\n",
    "    encoder_output = bert_module(encoder_output, encoder_output, encoder_output, embed_dim, num_head, i)\n",
    "\n",
    "encoder_output = keras.layers.Lambda(lambda x: x[:,:-1,:], name='slice')(encoder_output)\n",
    "mlm_output = layers.Dense(N, name=\"mlm_cls\", activation=\"softmax\")(encoder_output)\n",
    "mlm_model = keras.Model(inputs = inputs, outputs = mlm_output)\n",
    "adam = Adam()\n",
    "mlm_model.compile(loss='sparse_categorical_crossentropy', optimizer=adam)\n",
    "history = mlm_model.fit(x_masked_train, y_masked_labels_train,\n",
    "                        validation_split = 0.5, callbacks = [callback], \n",
    "                        epochs=2000, batch_size=5000, \n",
    "                        verbose=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "beb07cf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "acc = []\n",
    "prob = []\n",
    "x_test_subset = x_masked_test[np.random.choice(x_masked_test.shape[0], size=1000, replace=False)]\n",
    "\n",
    "for sentence_number in x_test_subset:\n",
    "    temp = keras.backend.function(inputs = mlm_model.layers[0].input, outputs = mlm_model.layers[-1].output) \\\n",
    "        (np.array(sentence_number).reshape(1,len(sentence_number)))\n",
    "    temp = temp[:,-1,:]\n",
    "    acc.append(1 if temp.argmax() == sentence_number[-1] else 0)\n",
    "    prob.append(temp[0][sentence_number[-1]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "dbdd4a69",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.0, 0.0035523782)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(np.mean(acc), np.mean(prob))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "54f31bac",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[10,  0,  1, 10],\n",
       "       [10,  0,  2, 10],\n",
       "       [10,  0,  3, 10],\n",
       "       ...,\n",
       "       [19, 18, 15, 19],\n",
       "       [19, 18, 16, 19],\n",
       "       [19, 18, 17, 19]])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the test-set input array (NumPy elides the middle rows in the repr).\n",
    "x_masked_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "458ae205",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
