{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "674d5c38",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy.random as npr\n",
    "import random\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "from tensorflow.keras import backend as K\n",
    "from keras.optimizers import Adam\n",
    "from keras_nlp.layers import PositionEmbedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5a29d33b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single seed value shared by the three random number generators seeded here.\n",
    "seed = 428\n",
    "\n",
    "# Seed Python's `random`, NumPy's global generator and TensorFlow's global generator.\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "tf.random.set_seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d96c3193",
   "metadata": {},
   "outputs": [],
   "source": [
    "def bert_module(query, key, value, embed_dim, num_head, i):\n",
    "    \"\"\"Build one attention + feed-forward block and return its output tensor.\n",
    "\n",
    "    The block applies causally masked multi-head attention, adds the result\n",
    "    to ``query`` and layer-normalises it, then runs a two-layer dense network\n",
    "    followed by a second residual add and layer normalisation.\n",
    "\n",
    "    Args:\n",
    "        query: Tensor passed as the attention query; also the first residual input.\n",
    "        key: Tensor forwarded to the attention layer as its second call argument.\n",
    "        value: Tensor forwarded to the attention layer as its third call argument.\n",
    "        embed_dim: Output width of the feed-forward network (its hidden layer\n",
    "            is twice as wide). Presumably equal to the last dimension of\n",
    "            ``query`` so the residual adds line up -- confirm against the caller.\n",
    "        num_head: Number of attention heads; each head uses\n",
    "            ``embed_dim // num_head`` as its key dimension.\n",
    "        i: Block index, used only to build the layer names.\n",
    "\n",
    "    Returns:\n",
    "        The output of the final layer normalisation.\n",
    "    \"\"\"\n",
    "    prefix = f\"encoder_{i}\"\n",
    "\n",
    "    # Sub-layer 1: causal multi-head attention, residual add, layer norm.\n",
    "    attention_layer = layers.MultiHeadAttention(\n",
    "        num_heads=num_head,\n",
    "        key_dim=embed_dim // num_head,\n",
    "        name=f\"{prefix}/multiheadattention\",\n",
    "    )\n",
    "    attended = attention_layer(query, key, value, use_causal_mask=True)\n",
    "    attended_residual = layers.Add()([query, attended])\n",
    "    attended_normed = layers.LayerNormalization(epsilon=1e-6)(attended_residual)\n",
    "\n",
    "    # Sub-layer 2: dense expand (ReLU) -> dense project, residual add, layer norm.\n",
    "    expand = layers.Dense(2 * embed_dim, activation='relu', name=f\"{prefix}/ffn_dense_1\")\n",
    "    project = layers.Dense(embed_dim, name=f\"{prefix}/ffn_dense_2\")\n",
    "    feed_forward = keras.models.Sequential([expand, project])\n",
    "    transformed = feed_forward(attended_normed)\n",
    "    transformed_residual = layers.Add()([attended_normed, transformed])\n",
    "\n",
    "    return layers.LayerNormalization(epsilon=1e-6)(transformed_residual)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "fd7afdea",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_sinusoidal_embeddings(sequence_length, embedding_dim):\n",
    "    \"\"\"Return a fixed sinusoidal position-encoding table as a float32 tensor.\n",
    "\n",
    "    For every position ``pos >= 1`` and channel-pair index ``i``::\n",
    "\n",
    "        table[pos, 2i]     = sin(pos / 10000 ** (2i / embedding_dim))\n",
    "        table[pos, 2i + 1] = cos(pos / 10000 ** (2i / embedding_dim))\n",
    "\n",
    "    Both channels of a pair share one frequency. Row 0 is left as all zeros,\n",
    "    so its cosine channels are 0 rather than cos(0) = 1.\n",
    "\n",
    "    Args:\n",
    "        sequence_length: Number of positions (rows) in the table.\n",
    "        embedding_dim: Number of channels (columns) in the table.\n",
    "\n",
    "    Returns:\n",
    "        A ``tf.float32`` tensor of shape ``(sequence_length, embedding_dim)``.\n",
    "    \"\"\"\n",
    "    positions = np.arange(sequence_length, dtype=np.float64)[:, np.newaxis]\n",
    "    channels = np.arange(embedding_dim)\n",
    "\n",
    "    # Channels 2i and 2i+1 must use the same exponent 2i/d, hence the floor\n",
    "    # division; using the raw channel index gives each channel its own frequency.\n",
    "    wavelengths = np.power(10000.0, 2.0 * (channels // 2) / embedding_dim)\n",
    "    angles = positions / wavelengths\n",
    "\n",
    "    # Start from zeros so that row 0 stays untouched, then fill rows 1..end.\n",
    "    position_enc = np.zeros((sequence_length, embedding_dim))\n",
    "    position_enc[1:, 0::2] = np.sin(angles[1:, 0::2])  # dim 2i\n",
    "    position_enc[1:, 1::2] = np.cos(angles[1:, 1::2])  # dim 2i+1\n",
    "    return tf.cast(position_enc, dtype=tf.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "95bab1fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 20 # vocab_size\n",
    "\n",
    "vocabs = ['word_' + str(i) for i in range(N)]\n",
    "\n",
    "vocab_map = {}\n",
    "for i in range(len(vocabs)):\n",
    "    vocab_map[vocabs[i]] = i\n",
    "    \n",
    "pairs = []\n",
    "\n",
    "for i in vocabs:\n",
    "    for j in vocabs:\n",
    "        for k in vocabs:\n",
    "            if i != j and i != k and j != k:\n",
    "                pairs.append((i,j,k))\n",
    "            \n",
    "indicator = np.random.choice([0, 1], size=len(pairs), p=[0.5, 0.5])\n",
    "\n",
    "pairs_train = [pairs[i] for i in range(len(indicator)) if indicator[i] == 1]\n",
    "pairs_test = [pairs[i] for i in range(len(indicator)) if indicator[i] == 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "4f71cf58",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each triple (a, b, c) becomes the four-token sequence [a, b, c, a],\n",
    "# stored once as words and once as vocabulary ids.\n",
    "sentences_train = [[a, b, c, a] for (a, b, c) in pairs_train]\n",
    "sentences_test = [[a, b, c, a] for (a, b, c) in pairs_test]\n",
    "sentences_number_train = [[vocab_map[word] for word in sent] for sent in sentences_train]\n",
    "sentences_number_test = [[vocab_map[word] for word in sent] for sent in sentences_test]\n",
    "\n",
    "# Inputs are the full id sequences; labels are the same sequences with the\n",
    "# first token dropped, i.e. the ids of [b, c, a].\n",
    "x_masked_train = np.array(sentences_number_train)\n",
    "y_masked_labels_train = np.array([ids[1:] for ids in sentences_number_train])\n",
    "x_masked_test = np.array(sentences_number_test)\n",
    "y_masked_labels_test = np.array([ids[1:] for ids in sentences_number_test])\n",
    "\n",
    "# Shuffle the training rows only, using one permutation for inputs and labels\n",
    "# so that each input stays aligned with its label row.\n",
    "perm = np.random.permutation(len(x_masked_train))\n",
    "x_masked_train = x_masked_train[perm]\n",
    "y_masked_labels_train = y_masked_labels_train[perm]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "3e737b36",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2000\n",
      "1/1 - 8s - loss: 3.2032 - val_loss: 3.1657 - 8s/epoch - 8s/step\n",
      "Epoch 2/2000\n",
      "1/1 - 0s - loss: 3.1614 - val_loss: 3.1427 - 148ms/epoch - 148ms/step\n",
      "Epoch 3/2000\n",
      "1/1 - 0s - loss: 3.1373 - val_loss: 3.1242 - 152ms/epoch - 152ms/step\n",
      "Epoch 4/2000\n",
      "1/1 - 0s - loss: 3.1174 - val_loss: 3.1082 - 154ms/epoch - 154ms/step\n",
      "Epoch 5/2000\n",
      "1/1 - 0s - loss: 3.1000 - val_loss: 3.0936 - 142ms/epoch - 142ms/step\n",
      "Epoch 6/2000\n",
      "1/1 - 0s - loss: 3.0841 - val_loss: 3.0807 - 143ms/epoch - 143ms/step\n",
      "Epoch 7/2000\n",
      "1/1 - 0s - loss: 3.0698 - val_loss: 3.0697 - 142ms/epoch - 142ms/step\n",
      "Epoch 8/2000\n",
      "1/1 - 0s - loss: 3.0573 - val_loss: 3.0597 - 156ms/epoch - 156ms/step\n",
      "Epoch 9/2000\n",
      "1/1 - 0s - loss: 3.0455 - val_loss: 3.0505 - 145ms/epoch - 145ms/step\n",
      "Epoch 10/2000\n",
      "1/1 - 0s - loss: 3.0346 - val_loss: 3.0416 - 161ms/epoch - 161ms/step\n",
      "Epoch 11/2000\n",
      "1/1 - 0s - loss: 3.0242 - val_loss: 3.0318 - 146ms/epoch - 146ms/step\n",
      "Epoch 12/2000\n",
      "1/1 - 0s - loss: 3.0133 - val_loss: 3.0208 - 138ms/epoch - 138ms/step\n",
      "Epoch 13/2000\n",
      "1/1 - 0s - loss: 3.0016 - val_loss: 3.0096 - 144ms/epoch - 144ms/step\n",
      "Epoch 14/2000\n",
      "1/1 - 0s - loss: 2.9899 - val_loss: 2.9977 - 146ms/epoch - 146ms/step\n",
      "Epoch 15/2000\n",
      "1/1 - 0s - loss: 2.9775 - val_loss: 2.9851 - 150ms/epoch - 150ms/step\n",
      "Epoch 16/2000\n",
      "1/1 - 0s - loss: 2.9644 - val_loss: 2.9734 - 149ms/epoch - 149ms/step\n",
      "Epoch 17/2000\n",
      "1/1 - 0s - loss: 2.9516 - val_loss: 2.9619 - 145ms/epoch - 145ms/step\n",
      "Epoch 18/2000\n",
      "1/1 - 0s - loss: 2.9386 - val_loss: 2.9499 - 141ms/epoch - 141ms/step\n",
      "Epoch 19/2000\n",
      "1/1 - 0s - loss: 2.9253 - val_loss: 2.9394 - 140ms/epoch - 140ms/step\n",
      "Epoch 20/2000\n",
      "1/1 - 0s - loss: 2.9135 - val_loss: 2.9304 - 137ms/epoch - 137ms/step\n",
      "Epoch 21/2000\n",
      "1/1 - 0s - loss: 2.9035 - val_loss: 2.9219 - 139ms/epoch - 139ms/step\n",
      "Epoch 22/2000\n",
      "1/1 - 0s - loss: 2.8940 - val_loss: 2.9126 - 136ms/epoch - 136ms/step\n",
      "Epoch 23/2000\n",
      "1/1 - 0s - loss: 2.8839 - val_loss: 2.9032 - 140ms/epoch - 140ms/step\n",
      "Epoch 24/2000\n",
      "1/1 - 0s - loss: 2.8740 - val_loss: 2.8936 - 140ms/epoch - 140ms/step\n",
      "Epoch 25/2000\n",
      "1/1 - 0s - loss: 2.8645 - val_loss: 2.8841 - 143ms/epoch - 143ms/step\n",
      "Epoch 26/2000\n",
      "1/1 - 0s - loss: 2.8558 - val_loss: 2.8741 - 135ms/epoch - 135ms/step\n",
      "Epoch 27/2000\n",
      "1/1 - 0s - loss: 2.8468 - val_loss: 2.8644 - 135ms/epoch - 135ms/step\n",
      "Epoch 28/2000\n",
      "1/1 - 0s - loss: 2.8380 - val_loss: 2.8543 - 140ms/epoch - 140ms/step\n",
      "Epoch 29/2000\n",
      "1/1 - 0s - loss: 2.8284 - val_loss: 2.8444 - 137ms/epoch - 137ms/step\n",
      "Epoch 30/2000\n",
      "1/1 - 0s - loss: 2.8186 - val_loss: 2.8366 - 218ms/epoch - 218ms/step\n",
      "Epoch 31/2000\n",
      "1/1 - 0s - loss: 2.8106 - val_loss: 2.8296 - 138ms/epoch - 138ms/step\n",
      "Epoch 32/2000\n",
      "1/1 - 0s - loss: 2.8041 - val_loss: 2.8217 - 148ms/epoch - 148ms/step\n",
      "Epoch 33/2000\n",
      "1/1 - 0s - loss: 2.7962 - val_loss: 2.8148 - 144ms/epoch - 144ms/step\n",
      "Epoch 34/2000\n",
      "1/1 - 0s - loss: 2.7894 - val_loss: 2.8082 - 135ms/epoch - 135ms/step\n",
      "Epoch 35/2000\n",
      "1/1 - 0s - loss: 2.7832 - val_loss: 2.8009 - 132ms/epoch - 132ms/step\n",
      "Epoch 36/2000\n",
      "1/1 - 0s - loss: 2.7762 - val_loss: 2.7940 - 126ms/epoch - 126ms/step\n",
      "Epoch 37/2000\n",
      "1/1 - 0s - loss: 2.7692 - val_loss: 2.7873 - 124ms/epoch - 124ms/step\n",
      "Epoch 38/2000\n",
      "1/1 - 0s - loss: 2.7625 - val_loss: 2.7800 - 124ms/epoch - 124ms/step\n",
      "Epoch 39/2000\n",
      "1/1 - 0s - loss: 2.7551 - val_loss: 2.7722 - 126ms/epoch - 126ms/step\n",
      "Epoch 40/2000\n",
      "1/1 - 0s - loss: 2.7472 - val_loss: 2.7646 - 131ms/epoch - 131ms/step\n",
      "Epoch 41/2000\n",
      "1/1 - 0s - loss: 2.7396 - val_loss: 2.7564 - 149ms/epoch - 149ms/step\n",
      "Epoch 42/2000\n",
      "1/1 - 0s - loss: 2.7314 - val_loss: 2.7484 - 129ms/epoch - 129ms/step\n",
      "Epoch 43/2000\n",
      "1/1 - 0s - loss: 2.7235 - val_loss: 2.7423 - 130ms/epoch - 130ms/step\n",
      "Epoch 44/2000\n",
      "1/1 - 0s - loss: 2.7170 - val_loss: 2.7369 - 133ms/epoch - 133ms/step\n",
      "Epoch 45/2000\n",
      "1/1 - 0s - loss: 2.7112 - val_loss: 2.7313 - 135ms/epoch - 135ms/step\n",
      "Epoch 46/2000\n",
      "1/1 - 0s - loss: 2.7055 - val_loss: 2.7254 - 139ms/epoch - 139ms/step\n",
      "Epoch 47/2000\n",
      "1/1 - 0s - loss: 2.6999 - val_loss: 2.7203 - 132ms/epoch - 132ms/step\n",
      "Epoch 48/2000\n",
      "1/1 - 0s - loss: 2.6954 - val_loss: 2.7145 - 138ms/epoch - 138ms/step\n",
      "Epoch 49/2000\n",
      "1/1 - 0s - loss: 2.6899 - val_loss: 2.7090 - 131ms/epoch - 131ms/step\n",
      "Epoch 50/2000\n",
      "1/1 - 0s - loss: 2.6840 - val_loss: 2.7032 - 126ms/epoch - 126ms/step\n",
      "Epoch 51/2000\n",
      "1/1 - 0s - loss: 2.6782 - val_loss: 2.6969 - 131ms/epoch - 131ms/step\n",
      "Epoch 52/2000\n",
      "1/1 - 0s - loss: 2.6721 - val_loss: 2.6908 - 133ms/epoch - 133ms/step\n",
      "Epoch 53/2000\n",
      "1/1 - 0s - loss: 2.6658 - val_loss: 2.6847 - 138ms/epoch - 138ms/step\n",
      "Epoch 54/2000\n",
      "1/1 - 0s - loss: 2.6598 - val_loss: 2.6781 - 135ms/epoch - 135ms/step\n",
      "Epoch 55/2000\n",
      "1/1 - 0s - loss: 2.6532 - val_loss: 2.6720 - 134ms/epoch - 134ms/step\n",
      "Epoch 56/2000\n",
      "1/1 - 0s - loss: 2.6467 - val_loss: 2.6653 - 133ms/epoch - 133ms/step\n",
      "Epoch 57/2000\n",
      "1/1 - 0s - loss: 2.6404 - val_loss: 2.6582 - 138ms/epoch - 138ms/step\n",
      "Epoch 58/2000\n",
      "1/1 - 0s - loss: 2.6334 - val_loss: 2.6511 - 136ms/epoch - 136ms/step\n",
      "Epoch 59/2000\n",
      "1/1 - 0s - loss: 2.6265 - val_loss: 2.6447 - 136ms/epoch - 136ms/step\n",
      "Epoch 60/2000\n",
      "1/1 - 0s - loss: 2.6202 - val_loss: 2.6384 - 142ms/epoch - 142ms/step\n",
      "Epoch 61/2000\n",
      "1/1 - 0s - loss: 2.6141 - val_loss: 2.6321 - 147ms/epoch - 147ms/step\n",
      "Epoch 62/2000\n",
      "1/1 - 0s - loss: 2.6078 - val_loss: 2.6254 - 148ms/epoch - 148ms/step\n",
      "Epoch 63/2000\n",
      "1/1 - 0s - loss: 2.6012 - val_loss: 2.6183 - 186ms/epoch - 186ms/step\n",
      "Epoch 64/2000\n",
      "1/1 - 0s - loss: 2.5947 - val_loss: 2.6116 - 192ms/epoch - 192ms/step\n",
      "Epoch 65/2000\n",
      "1/1 - 0s - loss: 2.5882 - val_loss: 2.6040 - 156ms/epoch - 156ms/step\n",
      "Epoch 66/2000\n",
      "1/1 - 0s - loss: 2.5821 - val_loss: 2.5998 - 160ms/epoch - 160ms/step\n",
      "Epoch 67/2000\n",
      "1/1 - 0s - loss: 2.5773 - val_loss: 2.5930 - 161ms/epoch - 161ms/step\n",
      "Epoch 68/2000\n",
      "1/1 - 0s - loss: 2.5729 - val_loss: 2.5853 - 160ms/epoch - 160ms/step\n",
      "Epoch 69/2000\n",
      "1/1 - 0s - loss: 2.5638 - val_loss: 2.5783 - 154ms/epoch - 154ms/step\n",
      "Epoch 70/2000\n",
      "1/1 - 0s - loss: 2.5569 - val_loss: 2.5752 - 159ms/epoch - 159ms/step\n",
      "Epoch 71/2000\n",
      "1/1 - 0s - loss: 2.5546 - val_loss: 2.5699 - 160ms/epoch - 160ms/step\n",
      "Epoch 72/2000\n",
      "1/1 - 0s - loss: 2.5479 - val_loss: 2.5658 - 162ms/epoch - 162ms/step\n",
      "Epoch 73/2000\n",
      "1/1 - 0s - loss: 2.5438 - val_loss: 2.5610 - 161ms/epoch - 161ms/step\n",
      "Epoch 74/2000\n",
      "1/1 - 0s - loss: 2.5403 - val_loss: 2.5545 - 166ms/epoch - 166ms/step\n",
      "Epoch 75/2000\n",
      "1/1 - 0s - loss: 2.5333 - val_loss: 2.5495 - 199ms/epoch - 199ms/step\n",
      "Epoch 76/2000\n",
      "1/1 - 0s - loss: 2.5284 - val_loss: 2.5437 - 183ms/epoch - 183ms/step\n",
      "Epoch 77/2000\n",
      "1/1 - 0s - loss: 2.5237 - val_loss: 2.5374 - 180ms/epoch - 180ms/step\n",
      "Epoch 78/2000\n",
      "1/1 - 0s - loss: 2.5172 - val_loss: 2.5336 - 176ms/epoch - 176ms/step\n",
      "Epoch 79/2000\n",
      "1/1 - 0s - loss: 2.5136 - val_loss: 2.5278 - 179ms/epoch - 179ms/step\n",
      "Epoch 80/2000\n",
      "1/1 - 0s - loss: 2.5085 - val_loss: 2.5229 - 173ms/epoch - 173ms/step\n",
      "Epoch 81/2000\n",
      "1/1 - 0s - loss: 2.5036 - val_loss: 2.5188 - 176ms/epoch - 176ms/step\n",
      "Epoch 82/2000\n",
      "1/1 - 0s - loss: 2.4996 - val_loss: 2.5130 - 177ms/epoch - 177ms/step\n",
      "Epoch 83/2000\n",
      "1/1 - 0s - loss: 2.4943 - val_loss: 2.5078 - 174ms/epoch - 174ms/step\n",
      "Epoch 84/2000\n",
      "1/1 - 0s - loss: 2.4894 - val_loss: 2.5040 - 180ms/epoch - 180ms/step\n",
      "Epoch 85/2000\n",
      "1/1 - 0s - loss: 2.4851 - val_loss: 2.4986 - 178ms/epoch - 178ms/step\n",
      "Epoch 86/2000\n",
      "1/1 - 0s - loss: 2.4799 - val_loss: 2.4940 - 173ms/epoch - 173ms/step\n",
      "Epoch 87/2000\n",
      "1/1 - 0s - loss: 2.4753 - val_loss: 2.4906 - 180ms/epoch - 180ms/step\n",
      "Epoch 88/2000\n",
      "1/1 - 0s - loss: 2.4713 - val_loss: 2.4856 - 178ms/epoch - 178ms/step\n",
      "Epoch 89/2000\n",
      "1/1 - 0s - loss: 2.4668 - val_loss: 2.4823 - 178ms/epoch - 178ms/step\n",
      "Epoch 90/2000\n",
      "1/1 - 0s - loss: 2.4629 - val_loss: 2.4793 - 174ms/epoch - 174ms/step\n",
      "Epoch 91/2000\n",
      "1/1 - 0s - loss: 2.4607 - val_loss: 2.4766 - 176ms/epoch - 176ms/step\n",
      "Epoch 92/2000\n",
      "1/1 - 0s - loss: 2.4577 - val_loss: 2.4707 - 203ms/epoch - 203ms/step\n",
      "Epoch 93/2000\n",
      "1/1 - 0s - loss: 2.4526 - val_loss: 2.4633 - 174ms/epoch - 174ms/step\n",
      "Epoch 94/2000\n",
      "1/1 - 0s - loss: 2.4450 - val_loss: 2.4626 - 195ms/epoch - 195ms/step\n",
      "Epoch 95/2000\n",
      "1/1 - 0s - loss: 2.4446 - val_loss: 2.4587 - 176ms/epoch - 176ms/step\n",
      "Epoch 96/2000\n",
      "1/1 - 0s - loss: 2.4418 - val_loss: 2.4504 - 170ms/epoch - 170ms/step\n",
      "Epoch 97/2000\n",
      "1/1 - 0s - loss: 2.4331 - val_loss: 2.4515 - 157ms/epoch - 157ms/step\n",
      "Epoch 98/2000\n",
      "1/1 - 0s - loss: 2.4342 - val_loss: 2.4457 - 172ms/epoch - 172ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 99/2000\n",
      "1/1 - 0s - loss: 2.4294 - val_loss: 2.4383 - 167ms/epoch - 167ms/step\n",
      "Epoch 100/2000\n",
      "1/1 - 0s - loss: 2.4218 - val_loss: 2.4402 - 162ms/epoch - 162ms/step\n",
      "Epoch 101/2000\n",
      "1/1 - 0s - loss: 2.4236 - val_loss: 2.4312 - 195ms/epoch - 195ms/step\n",
      "Epoch 102/2000\n",
      "1/1 - 0s - loss: 2.4148 - val_loss: 2.4286 - 164ms/epoch - 164ms/step\n",
      "Epoch 103/2000\n",
      "1/1 - 0s - loss: 2.4126 - val_loss: 2.4255 - 178ms/epoch - 178ms/step\n",
      "Epoch 104/2000\n",
      "1/1 - 0s - loss: 2.4095 - val_loss: 2.4192 - 162ms/epoch - 162ms/step\n",
      "Epoch 105/2000\n",
      "1/1 - 0s - loss: 2.4031 - val_loss: 2.4169 - 158ms/epoch - 158ms/step\n",
      "Epoch 106/2000\n",
      "1/1 - 0s - loss: 2.4015 - val_loss: 2.4116 - 156ms/epoch - 156ms/step\n",
      "Epoch 107/2000\n",
      "1/1 - 0s - loss: 2.3964 - val_loss: 2.4092 - 149ms/epoch - 149ms/step\n",
      "Epoch 108/2000\n",
      "1/1 - 0s - loss: 2.3935 - val_loss: 2.4051 - 148ms/epoch - 148ms/step\n",
      "Epoch 109/2000\n",
      "1/1 - 0s - loss: 2.3897 - val_loss: 2.4016 - 141ms/epoch - 141ms/step\n",
      "Epoch 110/2000\n",
      "1/1 - 0s - loss: 2.3869 - val_loss: 2.3982 - 141ms/epoch - 141ms/step\n",
      "Epoch 111/2000\n",
      "1/1 - 0s - loss: 2.3827 - val_loss: 2.3953 - 143ms/epoch - 143ms/step\n",
      "Epoch 112/2000\n",
      "1/1 - 0s - loss: 2.3798 - val_loss: 2.3916 - 142ms/epoch - 142ms/step\n",
      "Epoch 113/2000\n",
      "1/1 - 0s - loss: 2.3771 - val_loss: 2.3875 - 136ms/epoch - 136ms/step\n",
      "Epoch 114/2000\n",
      "1/1 - 0s - loss: 2.3723 - val_loss: 2.3857 - 141ms/epoch - 141ms/step\n",
      "Epoch 115/2000\n",
      "1/1 - 0s - loss: 2.3702 - val_loss: 2.3813 - 153ms/epoch - 153ms/step\n",
      "Epoch 116/2000\n",
      "1/1 - 0s - loss: 2.3666 - val_loss: 2.3786 - 147ms/epoch - 147ms/step\n",
      "Epoch 117/2000\n",
      "1/1 - 0s - loss: 2.3638 - val_loss: 2.3753 - 133ms/epoch - 133ms/step\n",
      "Epoch 118/2000\n",
      "1/1 - 0s - loss: 2.3600 - val_loss: 2.3719 - 142ms/epoch - 142ms/step\n",
      "Epoch 119/2000\n",
      "1/1 - 0s - loss: 2.3567 - val_loss: 2.3690 - 144ms/epoch - 144ms/step\n",
      "Epoch 120/2000\n",
      "1/1 - 0s - loss: 2.3543 - val_loss: 2.3658 - 147ms/epoch - 147ms/step\n",
      "Epoch 121/2000\n",
      "1/1 - 0s - loss: 2.3511 - val_loss: 2.3639 - 136ms/epoch - 136ms/step\n",
      "Epoch 122/2000\n",
      "1/1 - 0s - loss: 2.3485 - val_loss: 2.3604 - 137ms/epoch - 137ms/step\n",
      "Epoch 123/2000\n",
      "1/1 - 0s - loss: 2.3454 - val_loss: 2.3576 - 137ms/epoch - 137ms/step\n",
      "Epoch 124/2000\n",
      "1/1 - 0s - loss: 2.3425 - val_loss: 2.3551 - 136ms/epoch - 136ms/step\n",
      "Epoch 125/2000\n",
      "1/1 - 0s - loss: 2.3397 - val_loss: 2.3526 - 138ms/epoch - 138ms/step\n",
      "Epoch 126/2000\n",
      "1/1 - 0s - loss: 2.3370 - val_loss: 2.3497 - 139ms/epoch - 139ms/step\n",
      "Epoch 127/2000\n",
      "1/1 - 0s - loss: 2.3347 - val_loss: 2.3486 - 139ms/epoch - 139ms/step\n",
      "Epoch 128/2000\n",
      "1/1 - 0s - loss: 2.3330 - val_loss: 2.3482 - 133ms/epoch - 133ms/step\n",
      "Epoch 129/2000\n",
      "1/1 - 0s - loss: 2.3329 - val_loss: 2.3522 - 124ms/epoch - 124ms/step\n",
      "Epoch 130/2000\n",
      "1/1 - 0s - loss: 2.3361 - val_loss: 2.3458 - 129ms/epoch - 129ms/step\n",
      "Epoch 131/2000\n",
      "1/1 - 0s - loss: 2.3305 - val_loss: 2.3377 - 137ms/epoch - 137ms/step\n",
      "Epoch 132/2000\n",
      "1/1 - 0s - loss: 2.3218 - val_loss: 2.3354 - 134ms/epoch - 134ms/step\n",
      "Epoch 133/2000\n",
      "1/1 - 0s - loss: 2.3195 - val_loss: 2.3365 - 126ms/epoch - 126ms/step\n",
      "Epoch 134/2000\n",
      "1/1 - 0s - loss: 2.3214 - val_loss: 2.3370 - 126ms/epoch - 126ms/step\n",
      "Epoch 135/2000\n",
      "1/1 - 0s - loss: 2.3204 - val_loss: 2.3279 - 140ms/epoch - 140ms/step\n",
      "Epoch 136/2000\n",
      "1/1 - 0s - loss: 2.3119 - val_loss: 2.3267 - 136ms/epoch - 136ms/step\n",
      "Epoch 137/2000\n",
      "1/1 - 0s - loss: 2.3110 - val_loss: 2.3301 - 137ms/epoch - 137ms/step\n",
      "Epoch 138/2000\n",
      "1/1 - 0s - loss: 2.3137 - val_loss: 2.3217 - 138ms/epoch - 138ms/step\n",
      "Epoch 139/2000\n",
      "1/1 - 0s - loss: 2.3061 - val_loss: 2.3193 - 139ms/epoch - 139ms/step\n",
      "Epoch 140/2000\n",
      "1/1 - 0s - loss: 2.3032 - val_loss: 2.3205 - 129ms/epoch - 129ms/step\n",
      "Epoch 141/2000\n",
      "1/1 - 0s - loss: 2.3039 - val_loss: 2.3145 - 171ms/epoch - 171ms/step\n",
      "Epoch 142/2000\n",
      "1/1 - 0s - loss: 2.2988 - val_loss: 2.3120 - 134ms/epoch - 134ms/step\n",
      "Epoch 143/2000\n",
      "1/1 - 0s - loss: 2.2960 - val_loss: 2.3135 - 132ms/epoch - 132ms/step\n",
      "Epoch 144/2000\n",
      "1/1 - 0s - loss: 2.2963 - val_loss: 2.3078 - 139ms/epoch - 139ms/step\n",
      "Epoch 145/2000\n",
      "1/1 - 0s - loss: 2.2914 - val_loss: 2.3055 - 143ms/epoch - 143ms/step\n",
      "Epoch 146/2000\n",
      "1/1 - 0s - loss: 2.2897 - val_loss: 2.3056 - 135ms/epoch - 135ms/step\n",
      "Epoch 147/2000\n",
      "1/1 - 0s - loss: 2.2890 - val_loss: 2.3015 - 136ms/epoch - 136ms/step\n",
      "Epoch 148/2000\n",
      "1/1 - 0s - loss: 2.2850 - val_loss: 2.2997 - 170ms/epoch - 170ms/step\n",
      "Epoch 149/2000\n",
      "1/1 - 0s - loss: 2.2832 - val_loss: 2.2989 - 165ms/epoch - 165ms/step\n",
      "Epoch 150/2000\n",
      "1/1 - 0s - loss: 2.2821 - val_loss: 2.2953 - 152ms/epoch - 152ms/step\n",
      "Epoch 151/2000\n",
      "1/1 - 0s - loss: 2.2790 - val_loss: 2.2932 - 130ms/epoch - 130ms/step\n",
      "Epoch 152/2000\n",
      "1/1 - 0s - loss: 2.2765 - val_loss: 2.2930 - 136ms/epoch - 136ms/step\n",
      "Epoch 153/2000\n",
      "1/1 - 0s - loss: 2.2757 - val_loss: 2.2897 - 168ms/epoch - 168ms/step\n",
      "Epoch 154/2000\n",
      "1/1 - 0s - loss: 2.2734 - val_loss: 2.2871 - 132ms/epoch - 132ms/step\n",
      "Epoch 155/2000\n",
      "1/1 - 0s - loss: 2.2706 - val_loss: 2.2853 - 147ms/epoch - 147ms/step\n",
      "Epoch 156/2000\n",
      "1/1 - 0s - loss: 2.2687 - val_loss: 2.2837 - 173ms/epoch - 173ms/step\n",
      "Epoch 157/2000\n",
      "1/1 - 0s - loss: 2.2672 - val_loss: 2.2829 - 158ms/epoch - 158ms/step\n",
      "Epoch 158/2000\n",
      "1/1 - 0s - loss: 2.2658 - val_loss: 2.2797 - 128ms/epoch - 128ms/step\n",
      "Epoch 159/2000\n",
      "1/1 - 0s - loss: 2.2630 - val_loss: 2.2778 - 135ms/epoch - 135ms/step\n",
      "Epoch 160/2000\n",
      "1/1 - 0s - loss: 2.2609 - val_loss: 2.2762 - 146ms/epoch - 146ms/step\n",
      "Epoch 161/2000\n",
      "1/1 - 0s - loss: 2.2592 - val_loss: 2.2742 - 170ms/epoch - 170ms/step\n",
      "Epoch 162/2000\n",
      "1/1 - 0s - loss: 2.2577 - val_loss: 2.2733 - 140ms/epoch - 140ms/step\n",
      "Epoch 163/2000\n",
      "1/1 - 0s - loss: 2.2562 - val_loss: 2.2707 - 137ms/epoch - 137ms/step\n",
      "Epoch 164/2000\n",
      "1/1 - 0s - loss: 2.2541 - val_loss: 2.2694 - 140ms/epoch - 140ms/step\n",
      "Epoch 165/2000\n",
      "1/1 - 0s - loss: 2.2522 - val_loss: 2.2667 - 146ms/epoch - 146ms/step\n",
      "Epoch 166/2000\n",
      "1/1 - 0s - loss: 2.2501 - val_loss: 2.2651 - 163ms/epoch - 163ms/step\n",
      "Epoch 167/2000\n",
      "1/1 - 0s - loss: 2.2481 - val_loss: 2.2627 - 170ms/epoch - 170ms/step\n",
      "Epoch 168/2000\n",
      "1/1 - 0s - loss: 2.2461 - val_loss: 2.2610 - 161ms/epoch - 161ms/step\n",
      "Epoch 169/2000\n",
      "1/1 - 0s - loss: 2.2443 - val_loss: 2.2592 - 178ms/epoch - 178ms/step\n",
      "Epoch 170/2000\n",
      "1/1 - 0s - loss: 2.2425 - val_loss: 2.2574 - 166ms/epoch - 166ms/step\n",
      "Epoch 171/2000\n",
      "1/1 - 0s - loss: 2.2407 - val_loss: 2.2556 - 212ms/epoch - 212ms/step\n",
      "Epoch 172/2000\n",
      "1/1 - 0s - loss: 2.2389 - val_loss: 2.2538 - 204ms/epoch - 204ms/step\n",
      "Epoch 173/2000\n",
      "1/1 - 0s - loss: 2.2372 - val_loss: 2.2528 - 200ms/epoch - 200ms/step\n",
      "Epoch 174/2000\n",
      "1/1 - 0s - loss: 2.2357 - val_loss: 2.2509 - 184ms/epoch - 184ms/step\n",
      "Epoch 175/2000\n",
      "1/1 - 0s - loss: 2.2346 - val_loss: 2.2521 - 179ms/epoch - 179ms/step\n",
      "Epoch 176/2000\n",
      "1/1 - 0s - loss: 2.2348 - val_loss: 2.2518 - 170ms/epoch - 170ms/step\n",
      "Epoch 177/2000\n",
      "1/1 - 0s - loss: 2.2358 - val_loss: 2.2588 - 220ms/epoch - 220ms/step\n",
      "Epoch 178/2000\n",
      "1/1 - 0s - loss: 2.2409 - val_loss: 2.2497 - 212ms/epoch - 212ms/step\n",
      "Epoch 179/2000\n",
      "1/1 - 0s - loss: 2.2335 - val_loss: 2.2433 - 219ms/epoch - 219ms/step\n",
      "Epoch 180/2000\n",
      "1/1 - 0s - loss: 2.2262 - val_loss: 2.2446 - 212ms/epoch - 212ms/step\n",
      "Epoch 181/2000\n",
      "1/1 - 0s - loss: 2.2269 - val_loss: 2.2447 - 185ms/epoch - 185ms/step\n",
      "Epoch 182/2000\n",
      "1/1 - 0s - loss: 2.2285 - val_loss: 2.2451 - 164ms/epoch - 164ms/step\n",
      "Epoch 183/2000\n",
      "1/1 - 0s - loss: 2.2275 - val_loss: 2.2365 - 175ms/epoch - 175ms/step\n",
      "Epoch 184/2000\n",
      "1/1 - 0s - loss: 2.2197 - val_loss: 2.2400 - 163ms/epoch - 163ms/step\n",
      "Epoch 185/2000\n",
      "1/1 - 0s - loss: 2.2234 - val_loss: 2.2431 - 221ms/epoch - 221ms/step\n",
      "Epoch 186/2000\n",
      "1/1 - 0s - loss: 2.2245 - val_loss: 2.2327 - 287ms/epoch - 287ms/step\n",
      "Epoch 187/2000\n",
      "1/1 - 0s - loss: 2.2156 - val_loss: 2.2350 - 191ms/epoch - 191ms/step\n",
      "Epoch 188/2000\n",
      "1/1 - 0s - loss: 2.2183 - val_loss: 2.2366 - 196ms/epoch - 196ms/step\n",
      "Epoch 189/2000\n",
      "1/1 - 0s - loss: 2.2181 - val_loss: 2.2271 - 170ms/epoch - 170ms/step\n",
      "Epoch 190/2000\n",
      "1/1 - 0s - loss: 2.2096 - val_loss: 2.2328 - 160ms/epoch - 160ms/step\n",
      "Epoch 191/2000\n",
      "1/1 - 0s - loss: 2.2161 - val_loss: 2.2258 - 174ms/epoch - 174ms/step\n",
      "Epoch 192/2000\n",
      "1/1 - 0s - loss: 2.2081 - val_loss: 2.2255 - 272ms/epoch - 272ms/step\n",
      "Epoch 193/2000\n",
      "1/1 - 0s - loss: 2.2075 - val_loss: 2.2248 - 210ms/epoch - 210ms/step\n",
      "Epoch 194/2000\n",
      "1/1 - 0s - loss: 2.2080 - val_loss: 2.2188 - 209ms/epoch - 209ms/step\n",
      "Epoch 195/2000\n",
      "1/1 - 0s - loss: 2.2013 - val_loss: 2.2212 - 214ms/epoch - 214ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 196/2000\n",
      "1/1 - 0s - loss: 2.2029 - val_loss: 2.2157 - 208ms/epoch - 208ms/step\n",
      "Epoch 197/2000\n",
      "1/1 - 0s - loss: 2.1983 - val_loss: 2.2156 - 238ms/epoch - 238ms/step\n",
      "Epoch 198/2000\n",
      "1/1 - 0s - loss: 2.1981 - val_loss: 2.2147 - 198ms/epoch - 198ms/step\n",
      "Epoch 199/2000\n",
      "1/1 - 0s - loss: 2.1963 - val_loss: 2.2110 - 201ms/epoch - 201ms/step\n",
      "Epoch 200/2000\n",
      "1/1 - 0s - loss: 2.1936 - val_loss: 2.2101 - 209ms/epoch - 209ms/step\n",
      "Epoch 201/2000\n",
      "1/1 - 0s - loss: 2.1932 - val_loss: 2.2080 - 208ms/epoch - 208ms/step\n",
      "Epoch 202/2000\n",
      "1/1 - 0s - loss: 2.1902 - val_loss: 2.2071 - 246ms/epoch - 246ms/step\n",
      "Epoch 203/2000\n",
      "1/1 - 0s - loss: 2.1891 - val_loss: 2.2048 - 206ms/epoch - 206ms/step\n",
      "Epoch 204/2000\n",
      "1/1 - 0s - loss: 2.1878 - val_loss: 2.2025 - 217ms/epoch - 217ms/step\n",
      "Epoch 205/2000\n",
      "1/1 - 0s - loss: 2.1853 - val_loss: 2.2022 - 196ms/epoch - 196ms/step\n",
      "Epoch 206/2000\n",
      "1/1 - 0s - loss: 2.1844 - val_loss: 2.1997 - 162ms/epoch - 162ms/step\n",
      "Epoch 207/2000\n",
      "1/1 - 0s - loss: 2.1825 - val_loss: 2.1978 - 154ms/epoch - 154ms/step\n",
      "Epoch 208/2000\n",
      "1/1 - 0s - loss: 2.1805 - val_loss: 2.1974 - 159ms/epoch - 159ms/step\n",
      "Epoch 209/2000\n",
      "1/1 - 0s - loss: 2.1795 - val_loss: 2.1946 - 152ms/epoch - 152ms/step\n",
      "Epoch 210/2000\n",
      "1/1 - 0s - loss: 2.1776 - val_loss: 2.1929 - 156ms/epoch - 156ms/step\n",
      "Epoch 211/2000\n",
      "1/1 - 0s - loss: 2.1759 - val_loss: 2.1925 - 156ms/epoch - 156ms/step\n",
      "Epoch 212/2000\n",
      "1/1 - 0s - loss: 2.1748 - val_loss: 2.1901 - 192ms/epoch - 192ms/step\n",
      "Epoch 213/2000\n",
      "1/1 - 0s - loss: 2.1730 - val_loss: 2.1884 - 154ms/epoch - 154ms/step\n",
      "Epoch 214/2000\n",
      "1/1 - 0s - loss: 2.1714 - val_loss: 2.1879 - 160ms/epoch - 160ms/step\n",
      "Epoch 215/2000\n",
      "1/1 - 0s - loss: 2.1705 - val_loss: 2.1854 - 164ms/epoch - 164ms/step\n",
      "Epoch 216/2000\n",
      "1/1 - 0s - loss: 2.1684 - val_loss: 2.1842 - 177ms/epoch - 177ms/step\n",
      "Epoch 217/2000\n",
      "1/1 - 0s - loss: 2.1673 - val_loss: 2.1835 - 151ms/epoch - 151ms/step\n",
      "Epoch 218/2000\n",
      "1/1 - 0s - loss: 2.1660 - val_loss: 2.1816 - 178ms/epoch - 178ms/step\n",
      "Epoch 219/2000\n",
      "1/1 - 0s - loss: 2.1647 - val_loss: 2.1798 - 143ms/epoch - 143ms/step\n",
      "Epoch 220/2000\n",
      "1/1 - 0s - loss: 2.1628 - val_loss: 2.1787 - 153ms/epoch - 153ms/step\n",
      "Epoch 221/2000\n",
      "1/1 - 0s - loss: 2.1616 - val_loss: 2.1773 - 146ms/epoch - 146ms/step\n",
      "Epoch 222/2000\n",
      "1/1 - 0s - loss: 2.1604 - val_loss: 2.1760 - 159ms/epoch - 159ms/step\n",
      "Epoch 223/2000\n",
      "1/1 - 0s - loss: 2.1589 - val_loss: 2.1744 - 154ms/epoch - 154ms/step\n",
      "Epoch 224/2000\n",
      "1/1 - 0s - loss: 2.1575 - val_loss: 2.1731 - 231ms/epoch - 231ms/step\n",
      "Epoch 225/2000\n",
      "1/1 - 0s - loss: 2.1560 - val_loss: 2.1723 - 142ms/epoch - 142ms/step\n",
      "Epoch 226/2000\n",
      "1/1 - 0s - loss: 2.1549 - val_loss: 2.1712 - 138ms/epoch - 138ms/step\n",
      "Epoch 227/2000\n",
      "1/1 - 0s - loss: 2.1541 - val_loss: 2.1706 - 174ms/epoch - 174ms/step\n",
      "Epoch 228/2000\n",
      "1/1 - 0s - loss: 2.1533 - val_loss: 2.1684 - 139ms/epoch - 139ms/step\n",
      "Epoch 229/2000\n",
      "1/1 - 0s - loss: 2.1515 - val_loss: 2.1670 - 137ms/epoch - 137ms/step\n",
      "Epoch 230/2000\n",
      "1/1 - 0s - loss: 2.1498 - val_loss: 2.1659 - 136ms/epoch - 136ms/step\n",
      "Epoch 231/2000\n",
      "1/1 - 0s - loss: 2.1485 - val_loss: 2.1649 - 137ms/epoch - 137ms/step\n",
      "Epoch 232/2000\n",
      "1/1 - 0s - loss: 2.1476 - val_loss: 2.1640 - 138ms/epoch - 138ms/step\n",
      "Epoch 233/2000\n",
      "1/1 - 0s - loss: 2.1465 - val_loss: 2.1623 - 131ms/epoch - 131ms/step\n",
      "Epoch 234/2000\n",
      "1/1 - 0s - loss: 2.1450 - val_loss: 2.1611 - 127ms/epoch - 127ms/step\n",
      "Epoch 235/2000\n",
      "1/1 - 0s - loss: 2.1436 - val_loss: 2.1595 - 133ms/epoch - 133ms/step\n",
      "Epoch 236/2000\n",
      "1/1 - 0s - loss: 2.1424 - val_loss: 2.1582 - 125ms/epoch - 125ms/step\n",
      "Epoch 237/2000\n",
      "1/1 - 0s - loss: 2.1411 - val_loss: 2.1573 - 119ms/epoch - 119ms/step\n",
      "Epoch 238/2000\n",
      "1/1 - 0s - loss: 2.1400 - val_loss: 2.1560 - 123ms/epoch - 123ms/step\n",
      "Epoch 239/2000\n",
      "1/1 - 0s - loss: 2.1389 - val_loss: 2.1551 - 121ms/epoch - 121ms/step\n",
      "Epoch 240/2000\n",
      "1/1 - 0s - loss: 2.1379 - val_loss: 2.1541 - 118ms/epoch - 118ms/step\n",
      "Epoch 241/2000\n",
      "1/1 - 0s - loss: 2.1369 - val_loss: 2.1533 - 115ms/epoch - 115ms/step\n",
      "Epoch 242/2000\n",
      "1/1 - 0s - loss: 2.1359 - val_loss: 2.1522 - 111ms/epoch - 111ms/step\n",
      "Epoch 243/2000\n",
      "1/1 - 0s - loss: 2.1351 - val_loss: 2.1518 - 111ms/epoch - 111ms/step\n",
      "Epoch 244/2000\n",
      "1/1 - 0s - loss: 2.1345 - val_loss: 2.1502 - 110ms/epoch - 110ms/step\n",
      "Epoch 245/2000\n",
      "1/1 - 0s - loss: 2.1332 - val_loss: 2.1487 - 107ms/epoch - 107ms/step\n",
      "Epoch 246/2000\n",
      "1/1 - 0s - loss: 2.1316 - val_loss: 2.1471 - 107ms/epoch - 107ms/step\n",
      "Epoch 247/2000\n",
      "1/1 - 0s - loss: 2.1301 - val_loss: 2.1459 - 102ms/epoch - 102ms/step\n",
      "Epoch 248/2000\n",
      "1/1 - 0s - loss: 2.1289 - val_loss: 2.1452 - 108ms/epoch - 108ms/step\n",
      "Epoch 249/2000\n",
      "1/1 - 0s - loss: 2.1282 - val_loss: 2.1446 - 106ms/epoch - 106ms/step\n",
      "Epoch 250/2000\n",
      "1/1 - 0s - loss: 2.1277 - val_loss: 2.1437 - 119ms/epoch - 119ms/step\n",
      "Epoch 251/2000\n",
      "1/1 - 0s - loss: 2.1266 - val_loss: 2.1423 - 152ms/epoch - 152ms/step\n",
      "Epoch 252/2000\n",
      "1/1 - 0s - loss: 2.1254 - val_loss: 2.1410 - 157ms/epoch - 157ms/step\n",
      "Epoch 253/2000\n",
      "1/1 - 0s - loss: 2.1239 - val_loss: 2.1399 - 155ms/epoch - 155ms/step\n",
      "Epoch 254/2000\n",
      "1/1 - 0s - loss: 2.1228 - val_loss: 2.1388 - 113ms/epoch - 113ms/step\n",
      "Epoch 255/2000\n",
      "1/1 - 0s - loss: 2.1218 - val_loss: 2.1382 - 115ms/epoch - 115ms/step\n",
      "Epoch 256/2000\n",
      "1/1 - 0s - loss: 2.1211 - val_loss: 2.1374 - 126ms/epoch - 126ms/step\n",
      "Epoch 257/2000\n",
      "1/1 - 0s - loss: 2.1206 - val_loss: 2.1382 - 117ms/epoch - 117ms/step\n",
      "Epoch 258/2000\n",
      "1/1 - 0s - loss: 2.1209 - val_loss: 2.1375 - 114ms/epoch - 114ms/step\n",
      "Epoch 259/2000\n",
      "1/1 - 0s - loss: 2.1211 - val_loss: 2.1382 - 113ms/epoch - 113ms/step\n",
      "Epoch 260/2000\n",
      "1/1 - 0s - loss: 2.1203 - val_loss: 2.1345 - 121ms/epoch - 121ms/step\n",
      "Epoch 261/2000\n",
      "1/1 - 0s - loss: 2.1177 - val_loss: 2.1328 - 126ms/epoch - 126ms/step\n",
      "Epoch 262/2000\n",
      "1/1 - 0s - loss: 2.1158 - val_loss: 2.1334 - 121ms/epoch - 121ms/step\n",
      "Epoch 263/2000\n",
      "1/1 - 0s - loss: 2.1157 - val_loss: 2.1324 - 131ms/epoch - 131ms/step\n",
      "Epoch 264/2000\n",
      "1/1 - 0s - loss: 2.1158 - val_loss: 2.1313 - 145ms/epoch - 145ms/step\n",
      "Epoch 265/2000\n",
      "1/1 - 0s - loss: 2.1143 - val_loss: 2.1283 - 142ms/epoch - 142ms/step\n",
      "Epoch 266/2000\n",
      "1/1 - 0s - loss: 2.1122 - val_loss: 2.1275 - 143ms/epoch - 143ms/step\n",
      "Epoch 267/2000\n",
      "1/1 - 0s - loss: 2.1113 - val_loss: 2.1281 - 133ms/epoch - 133ms/step\n",
      "Epoch 268/2000\n",
      "1/1 - 0s - loss: 2.1112 - val_loss: 2.1272 - 139ms/epoch - 139ms/step\n",
      "Epoch 269/2000\n",
      "1/1 - 0s - loss: 2.1108 - val_loss: 2.1263 - 170ms/epoch - 170ms/step\n",
      "Epoch 270/2000\n",
      "1/1 - 0s - loss: 2.1093 - val_loss: 2.1241 - 166ms/epoch - 166ms/step\n",
      "Epoch 271/2000\n",
      "1/1 - 0s - loss: 2.1075 - val_loss: 2.1236 - 164ms/epoch - 164ms/step\n",
      "Epoch 272/2000\n",
      "1/1 - 0s - loss: 2.1069 - val_loss: 2.1240 - 157ms/epoch - 157ms/step\n",
      "Epoch 273/2000\n",
      "1/1 - 0s - loss: 2.1065 - val_loss: 2.1227 - 167ms/epoch - 167ms/step\n",
      "Epoch 274/2000\n",
      "1/1 - 0s - loss: 2.1057 - val_loss: 2.1215 - 171ms/epoch - 171ms/step\n",
      "Epoch 275/2000\n",
      "1/1 - 0s - loss: 2.1043 - val_loss: 2.1202 - 198ms/epoch - 198ms/step\n",
      "Epoch 276/2000\n",
      "1/1 - 0s - loss: 2.1032 - val_loss: 2.1197 - 258ms/epoch - 258ms/step\n",
      "Epoch 277/2000\n",
      "1/1 - 0s - loss: 2.1028 - val_loss: 2.1196 - 169ms/epoch - 169ms/step\n",
      "Epoch 278/2000\n",
      "1/1 - 0s - loss: 2.1024 - val_loss: 2.1181 - 166ms/epoch - 166ms/step\n",
      "Epoch 279/2000\n",
      "1/1 - 0s - loss: 2.1014 - val_loss: 2.1173 - 167ms/epoch - 167ms/step\n",
      "Epoch 280/2000\n",
      "1/1 - 0s - loss: 2.1002 - val_loss: 2.1162 - 178ms/epoch - 178ms/step\n",
      "Epoch 281/2000\n",
      "1/1 - 0s - loss: 2.0992 - val_loss: 2.1155 - 230ms/epoch - 230ms/step\n",
      "Epoch 282/2000\n",
      "1/1 - 0s - loss: 2.0987 - val_loss: 2.1153 - 201ms/epoch - 201ms/step\n",
      "Epoch 283/2000\n",
      "1/1 - 0s - loss: 2.0980 - val_loss: 2.1139 - 168ms/epoch - 168ms/step\n",
      "Epoch 284/2000\n",
      "1/1 - 0s - loss: 2.0971 - val_loss: 2.1130 - 244ms/epoch - 244ms/step\n",
      "Epoch 285/2000\n",
      "1/1 - 0s - loss: 2.0961 - val_loss: 2.1120 - 192ms/epoch - 192ms/step\n",
      "Epoch 286/2000\n",
      "1/1 - 0s - loss: 2.0953 - val_loss: 2.1115 - 154ms/epoch - 154ms/step\n",
      "Epoch 287/2000\n",
      "1/1 - 0s - loss: 2.0946 - val_loss: 2.1110 - 157ms/epoch - 157ms/step\n",
      "Epoch 288/2000\n",
      "1/1 - 0s - loss: 2.0939 - val_loss: 2.1099 - 158ms/epoch - 158ms/step\n",
      "Epoch 289/2000\n",
      "1/1 - 0s - loss: 2.0931 - val_loss: 2.1092 - 161ms/epoch - 161ms/step\n",
      "Epoch 290/2000\n",
      "1/1 - 0s - loss: 2.0923 - val_loss: 2.1083 - 188ms/epoch - 188ms/step\n",
      "Epoch 291/2000\n",
      "1/1 - 0s - loss: 2.0916 - val_loss: 2.1076 - 161ms/epoch - 161ms/step\n",
      "Epoch 292/2000\n",
      "1/1 - 0s - loss: 2.0909 - val_loss: 2.1073 - 157ms/epoch - 157ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 293/2000\n",
      "1/1 - 0s - loss: 2.0902 - val_loss: 2.1063 - 154ms/epoch - 154ms/step\n",
      "Epoch 294/2000\n",
      "1/1 - 0s - loss: 2.0895 - val_loss: 2.1057 - 157ms/epoch - 157ms/step\n",
      "Epoch 295/2000\n",
      "1/1 - 0s - loss: 2.0887 - val_loss: 2.1047 - 154ms/epoch - 154ms/step\n",
      "Epoch 296/2000\n",
      "1/1 - 0s - loss: 2.0880 - val_loss: 2.1041 - 157ms/epoch - 157ms/step\n",
      "Epoch 297/2000\n",
      "1/1 - 0s - loss: 2.0873 - val_loss: 2.1034 - 154ms/epoch - 154ms/step\n",
      "Epoch 298/2000\n",
      "1/1 - 0s - loss: 2.0866 - val_loss: 2.1027 - 159ms/epoch - 159ms/step\n",
      "Epoch 299/2000\n",
      "1/1 - 0s - loss: 2.0859 - val_loss: 2.1021 - 160ms/epoch - 160ms/step\n",
      "Epoch 300/2000\n",
      "1/1 - 0s - loss: 2.0852 - val_loss: 2.1014 - 165ms/epoch - 165ms/step\n",
      "Epoch 301/2000\n",
      "1/1 - 0s - loss: 2.0846 - val_loss: 2.1008 - 166ms/epoch - 166ms/step\n",
      "Epoch 302/2000\n",
      "1/1 - 0s - loss: 2.0839 - val_loss: 2.1001 - 167ms/epoch - 167ms/step\n",
      "Epoch 303/2000\n",
      "1/1 - 0s - loss: 2.0832 - val_loss: 2.0993 - 175ms/epoch - 175ms/step\n",
      "Epoch 304/2000\n",
      "1/1 - 0s - loss: 2.0826 - val_loss: 2.0988 - 166ms/epoch - 166ms/step\n",
      "Epoch 305/2000\n",
      "1/1 - 0s - loss: 2.0819 - val_loss: 2.0981 - 165ms/epoch - 165ms/step\n",
      "Epoch 306/2000\n",
      "1/1 - 0s - loss: 2.0813 - val_loss: 2.0977 - 164ms/epoch - 164ms/step\n",
      "Epoch 307/2000\n",
      "1/1 - 0s - loss: 2.0807 - val_loss: 2.0968 - 165ms/epoch - 165ms/step\n",
      "Epoch 308/2000\n",
      "1/1 - 0s - loss: 2.0801 - val_loss: 2.0966 - 163ms/epoch - 163ms/step\n",
      "Epoch 309/2000\n",
      "1/1 - 0s - loss: 2.0796 - val_loss: 2.0955 - 162ms/epoch - 162ms/step\n",
      "Epoch 310/2000\n",
      "1/1 - 0s - loss: 2.0790 - val_loss: 2.0957 - 153ms/epoch - 153ms/step\n",
      "Epoch 311/2000\n",
      "1/1 - 0s - loss: 2.0785 - val_loss: 2.0945 - 162ms/epoch - 162ms/step\n",
      "Epoch 312/2000\n",
      "1/1 - 0s - loss: 2.0779 - val_loss: 2.0946 - 157ms/epoch - 157ms/step\n",
      "Epoch 313/2000\n",
      "1/1 - 0s - loss: 2.0774 - val_loss: 2.0932 - 165ms/epoch - 165ms/step\n",
      "Epoch 314/2000\n",
      "1/1 - 0s - loss: 2.0771 - val_loss: 2.0943 - 155ms/epoch - 155ms/step\n",
      "Epoch 315/2000\n",
      "1/1 - 0s - loss: 2.0769 - val_loss: 2.0932 - 168ms/epoch - 168ms/step\n",
      "Epoch 316/2000\n",
      "1/1 - 0s - loss: 2.0769 - val_loss: 2.0939 - 155ms/epoch - 155ms/step\n",
      "Epoch 317/2000\n",
      "1/1 - 0s - loss: 2.0764 - val_loss: 2.0918 - 163ms/epoch - 163ms/step\n",
      "Epoch 318/2000\n",
      "1/1 - 0s - loss: 2.0756 - val_loss: 2.0913 - 157ms/epoch - 157ms/step\n",
      "Epoch 319/2000\n",
      "1/1 - 0s - loss: 2.0740 - val_loss: 2.0899 - 154ms/epoch - 154ms/step\n",
      "Epoch 320/2000\n",
      "1/1 - 0s - loss: 2.0729 - val_loss: 2.0897 - 156ms/epoch - 156ms/step\n",
      "Epoch 321/2000\n",
      "1/1 - 0s - loss: 2.0727 - val_loss: 2.0904 - 155ms/epoch - 155ms/step\n",
      "Epoch 322/2000\n",
      "1/1 - 0s - loss: 2.0731 - val_loss: 2.0904 - 176ms/epoch - 176ms/step\n",
      "Epoch 323/2000\n",
      "1/1 - 0s - loss: 2.0746 - val_loss: 2.0941 - 142ms/epoch - 142ms/step\n",
      "Epoch 324/2000\n",
      "1/1 - 0s - loss: 2.0756 - val_loss: 2.0902 - 139ms/epoch - 139ms/step\n",
      "Epoch 325/2000\n",
      "1/1 - 0s - loss: 2.0737 - val_loss: 2.0893 - 138ms/epoch - 138ms/step\n",
      "Epoch 326/2000\n",
      "1/1 - 0s - loss: 2.0712 - val_loss: 2.0875 - 137ms/epoch - 137ms/step\n",
      "Epoch 327/2000\n",
      "1/1 - 0s - loss: 2.0702 - val_loss: 2.0881 - 130ms/epoch - 130ms/step\n",
      "Epoch 328/2000\n",
      "1/1 - 0s - loss: 2.0722 - val_loss: 2.0929 - 127ms/epoch - 127ms/step\n",
      "Epoch 329/2000\n",
      "1/1 - 0s - loss: 2.0743 - val_loss: 2.0871 - 137ms/epoch - 137ms/step\n",
      "Epoch 330/2000\n",
      "1/1 - 0s - loss: 2.0707 - val_loss: 2.0854 - 136ms/epoch - 136ms/step\n",
      "Epoch 331/2000\n",
      "1/1 - 0s - loss: 2.0677 - val_loss: 2.0873 - 130ms/epoch - 130ms/step\n",
      "Epoch 332/2000\n",
      "1/1 - 0s - loss: 2.0687 - val_loss: 2.0845 - 132ms/epoch - 132ms/step\n",
      "Epoch 333/2000\n",
      "1/1 - 0s - loss: 2.0677 - val_loss: 2.0836 - 133ms/epoch - 133ms/step\n",
      "Epoch 334/2000\n",
      "1/1 - 0s - loss: 2.0661 - val_loss: 2.0837 - 128ms/epoch - 128ms/step\n",
      "Epoch 335/2000\n",
      "1/1 - 0s - loss: 2.0657 - val_loss: 2.0820 - 133ms/epoch - 133ms/step\n",
      "Epoch 336/2000\n",
      "1/1 - 0s - loss: 2.0651 - val_loss: 2.0819 - 135ms/epoch - 135ms/step\n",
      "Epoch 337/2000\n",
      "1/1 - 0s - loss: 2.0642 - val_loss: 2.0819 - 135ms/epoch - 135ms/step\n",
      "Epoch 338/2000\n",
      "1/1 - 0s - loss: 2.0639 - val_loss: 2.0809 - 133ms/epoch - 133ms/step\n",
      "Epoch 339/2000\n",
      "1/1 - 0s - loss: 2.0633 - val_loss: 2.0807 - 129ms/epoch - 129ms/step\n",
      "Epoch 340/2000\n",
      "1/1 - 0s - loss: 2.0628 - val_loss: 2.0805 - 168ms/epoch - 168ms/step\n",
      "Epoch 341/2000\n",
      "1/1 - 0s - loss: 2.0624 - val_loss: 2.0796 - 136ms/epoch - 136ms/step\n",
      "Epoch 342/2000\n",
      "1/1 - 0s - loss: 2.0620 - val_loss: 2.0792 - 166ms/epoch - 166ms/step\n",
      "Epoch 343/2000\n",
      "1/1 - 0s - loss: 2.0616 - val_loss: 2.0789 - 162ms/epoch - 162ms/step\n",
      "Epoch 344/2000\n",
      "1/1 - 0s - loss: 2.0608 - val_loss: 2.0778 - 136ms/epoch - 136ms/step\n",
      "Epoch 345/2000\n",
      "1/1 - 0s - loss: 2.0602 - val_loss: 2.0776 - 140ms/epoch - 140ms/step\n",
      "Epoch 346/2000\n",
      "1/1 - 0s - loss: 2.0598 - val_loss: 2.0778 - 132ms/epoch - 132ms/step\n",
      "Epoch 347/2000\n",
      "1/1 - 0s - loss: 2.0597 - val_loss: 2.0768 - 142ms/epoch - 142ms/step\n",
      "Epoch 348/2000\n",
      "1/1 - 0s - loss: 2.0593 - val_loss: 2.0763 - 143ms/epoch - 143ms/step\n",
      "Epoch 349/2000\n",
      "1/1 - 0s - loss: 2.0586 - val_loss: 2.0760 - 138ms/epoch - 138ms/step\n",
      "Epoch 350/2000\n",
      "1/1 - 0s - loss: 2.0579 - val_loss: 2.0754 - 122ms/epoch - 122ms/step\n",
      "Epoch 351/2000\n",
      "1/1 - 0s - loss: 2.0576 - val_loss: 2.0755 - 141ms/epoch - 141ms/step\n",
      "Epoch 352/2000\n",
      "1/1 - 0s - loss: 2.0577 - val_loss: 2.0755 - 162ms/epoch - 162ms/step\n",
      "Epoch 353/2000\n",
      "1/1 - 0s - loss: 2.0577 - val_loss: 2.0744 - 128ms/epoch - 128ms/step\n",
      "Epoch 354/2000\n",
      "1/1 - 0s - loss: 2.0569 - val_loss: 2.0738 - 119ms/epoch - 119ms/step\n",
      "Epoch 355/2000\n",
      "1/1 - 0s - loss: 2.0557 - val_loss: 2.0740 - 115ms/epoch - 115ms/step\n",
      "Epoch 356/2000\n",
      "1/1 - 0s - loss: 2.0557 - val_loss: 2.0739 - 105ms/epoch - 105ms/step\n",
      "Epoch 357/2000\n",
      "1/1 - 0s - loss: 2.0559 - val_loss: 2.0732 - 110ms/epoch - 110ms/step\n",
      "Epoch 358/2000\n",
      "1/1 - 0s - loss: 2.0556 - val_loss: 2.0722 - 111ms/epoch - 111ms/step\n",
      "Epoch 359/2000\n",
      "1/1 - 0s - loss: 2.0545 - val_loss: 2.0718 - 113ms/epoch - 113ms/step\n",
      "Epoch 360/2000\n",
      "1/1 - 0s - loss: 2.0539 - val_loss: 2.0721 - 102ms/epoch - 102ms/step\n",
      "Epoch 361/2000\n",
      "1/1 - 0s - loss: 2.0539 - val_loss: 2.0726 - 104ms/epoch - 104ms/step\n",
      "Epoch 362/2000\n",
      "1/1 - 0s - loss: 2.0542 - val_loss: 2.0713 - 114ms/epoch - 114ms/step\n",
      "Epoch 363/2000\n",
      "1/1 - 0s - loss: 2.0536 - val_loss: 2.0701 - 107ms/epoch - 107ms/step\n",
      "Epoch 364/2000\n",
      "1/1 - 0s - loss: 2.0523 - val_loss: 2.0701 - 109ms/epoch - 109ms/step\n",
      "Epoch 365/2000\n",
      "1/1 - 0s - loss: 2.0520 - val_loss: 2.0700 - 106ms/epoch - 106ms/step\n",
      "Epoch 366/2000\n",
      "1/1 - 0s - loss: 2.0519 - val_loss: 2.0698 - 145ms/epoch - 145ms/step\n",
      "Epoch 367/2000\n",
      "1/1 - 0s - loss: 2.0514 - val_loss: 2.0687 - 146ms/epoch - 146ms/step\n",
      "Epoch 368/2000\n",
      "1/1 - 0s - loss: 2.0508 - val_loss: 2.0686 - 109ms/epoch - 109ms/step\n",
      "Epoch 369/2000\n",
      "1/1 - 0s - loss: 2.0502 - val_loss: 2.0679 - 108ms/epoch - 108ms/step\n",
      "Epoch 370/2000\n",
      "1/1 - 0s - loss: 2.0498 - val_loss: 2.0676 - 107ms/epoch - 107ms/step\n",
      "Epoch 371/2000\n",
      "1/1 - 0s - loss: 2.0494 - val_loss: 2.0697 - 99ms/epoch - 99ms/step\n",
      "Epoch 372/2000\n",
      "1/1 - 0s - loss: 2.0502 - val_loss: 2.0680 - 101ms/epoch - 101ms/step\n",
      "Epoch 373/2000\n",
      "1/1 - 0s - loss: 2.0506 - val_loss: 2.0687 - 102ms/epoch - 102ms/step\n",
      "Epoch 374/2000\n",
      "1/1 - 0s - loss: 2.0497 - val_loss: 2.0666 - 116ms/epoch - 116ms/step\n",
      "Epoch 375/2000\n",
      "1/1 - 0s - loss: 2.0483 - val_loss: 2.0663 - 115ms/epoch - 115ms/step\n",
      "Epoch 376/2000\n",
      "1/1 - 0s - loss: 2.0490 - val_loss: 2.0691 - 116ms/epoch - 116ms/step\n",
      "Epoch 377/2000\n",
      "1/1 - 0s - loss: 2.0495 - val_loss: 2.0648 - 124ms/epoch - 124ms/step\n",
      "Epoch 378/2000\n",
      "1/1 - 0s - loss: 2.0472 - val_loss: 2.0654 - 118ms/epoch - 118ms/step\n",
      "Epoch 379/2000\n",
      "1/1 - 0s - loss: 2.0479 - val_loss: 2.0679 - 135ms/epoch - 135ms/step\n",
      "Epoch 380/2000\n",
      "1/1 - 0s - loss: 2.0483 - val_loss: 2.0646 - 153ms/epoch - 153ms/step\n",
      "Epoch 381/2000\n",
      "1/1 - 0s - loss: 2.0465 - val_loss: 2.0650 - 157ms/epoch - 157ms/step\n",
      "Epoch 382/2000\n",
      "1/1 - 0s - loss: 2.0469 - val_loss: 2.0661 - 172ms/epoch - 172ms/step\n",
      "Epoch 383/2000\n",
      "1/1 - 0s - loss: 2.0469 - val_loss: 2.0632 - 151ms/epoch - 151ms/step\n",
      "Epoch 384/2000\n",
      "1/1 - 0s - loss: 2.0459 - val_loss: 2.0628 - 152ms/epoch - 152ms/step\n",
      "Epoch 385/2000\n",
      "1/1 - 0s - loss: 2.0449 - val_loss: 2.0639 - 156ms/epoch - 156ms/step\n",
      "Epoch 386/2000\n",
      "1/1 - 0s - loss: 2.0453 - val_loss: 2.0630 - 162ms/epoch - 162ms/step\n",
      "Epoch 387/2000\n",
      "1/1 - 0s - loss: 2.0450 - val_loss: 2.0627 - 156ms/epoch - 156ms/step\n",
      "Epoch 388/2000\n",
      "1/1 - 0s - loss: 2.0439 - val_loss: 2.0629 - 181ms/epoch - 181ms/step\n",
      "Epoch 389/2000\n",
      "1/1 - 0s - loss: 2.0437 - val_loss: 2.0615 - 159ms/epoch - 159ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 390/2000\n",
      "1/1 - 0s - loss: 2.0435 - val_loss: 2.0611 - 160ms/epoch - 160ms/step\n",
      "Epoch 391/2000\n",
      "1/1 - 0s - loss: 2.0425 - val_loss: 2.0605 - 155ms/epoch - 155ms/step\n",
      "Epoch 392/2000\n",
      "1/1 - 0s - loss: 2.0423 - val_loss: 2.0601 - 156ms/epoch - 156ms/step\n",
      "Epoch 393/2000\n",
      "1/1 - 0s - loss: 2.0419 - val_loss: 2.0608 - 145ms/epoch - 145ms/step\n",
      "Epoch 394/2000\n",
      "1/1 - 0s - loss: 2.0415 - val_loss: 2.0606 - 144ms/epoch - 144ms/step\n",
      "Epoch 395/2000\n",
      "1/1 - 0s - loss: 2.0412 - val_loss: 2.0601 - 144ms/epoch - 144ms/step\n",
      "Epoch 396/2000\n",
      "1/1 - 0s - loss: 2.0409 - val_loss: 2.0599 - 157ms/epoch - 157ms/step\n",
      "Epoch 397/2000\n",
      "1/1 - 0s - loss: 2.0405 - val_loss: 2.0587 - 155ms/epoch - 155ms/step\n",
      "Epoch 398/2000\n",
      "1/1 - 0s - loss: 2.0402 - val_loss: 2.0586 - 157ms/epoch - 157ms/step\n",
      "Epoch 399/2000\n",
      "1/1 - 0s - loss: 2.0397 - val_loss: 2.0592 - 141ms/epoch - 141ms/step\n",
      "Epoch 400/2000\n",
      "1/1 - 0s - loss: 2.0398 - val_loss: 2.0586 - 145ms/epoch - 145ms/step\n",
      "Epoch 401/2000\n",
      "1/1 - 0s - loss: 2.0404 - val_loss: 2.0598 - 151ms/epoch - 151ms/step\n",
      "Epoch 402/2000\n",
      "1/1 - 0s - loss: 2.0399 - val_loss: 2.0576 - 150ms/epoch - 150ms/step\n",
      "Epoch 403/2000\n",
      "1/1 - 0s - loss: 2.0385 - val_loss: 2.0576 - 149ms/epoch - 149ms/step\n",
      "Epoch 404/2000\n",
      "1/1 - 0s - loss: 2.0392 - val_loss: 2.0611 - 156ms/epoch - 156ms/step\n",
      "Epoch 405/2000\n",
      "1/1 - 0s - loss: 2.0403 - val_loss: 2.0579 - 151ms/epoch - 151ms/step\n",
      "Epoch 406/2000\n",
      "1/1 - 0s - loss: 2.0398 - val_loss: 2.0578 - 153ms/epoch - 153ms/step\n",
      "Epoch 407/2000\n",
      "1/1 - 0s - loss: 2.0382 - val_loss: 2.0579 - 170ms/epoch - 170ms/step\n"
     ]
    }
   ],
   "source": [
    "# Build, compile and fit a single-block attention model on the training sequences.\n",
    "# Hyperparameters: embedding width and number of attention heads.\n",
    "embed_dim = 10\n",
    "num_head = 2\n",
    "\n",
    "# Early stopping on val_loss with patience 5; the best weights seen are restored at the end.\n",
    "callback = keras.callbacks.EarlyStopping(monitor = 'val_loss', patience = 5, restore_best_weights = True)\n",
    "# Integer-id input whose length equals the second dimension of x_masked_train.\n",
    "inputs = layers.Input((x_masked_train.shape[1],), dtype=tf.int64)\n",
    "# N is defined earlier in the notebook (presumably the vocabulary size - TODO confirm).\n",
    "word_embeddings = layers.Embedding(N, embed_dim, name=\"word_embedding\")(inputs)\n",
    "# get_sinusoidal_embeddings is defined elsewhere; presumably it returns positional encodings\n",
    "# that broadcast against (batch, seq_len, embed_dim) - TODO confirm.\n",
    "sinusoidal_embeddings = get_sinusoidal_embeddings(len(x_masked_train[0]), embed_dim)\n",
    "# Sum of token embeddings and positional embeddings is the input to the attention block.\n",
    "encoder_output = word_embeddings + sinusoidal_embeddings\n",
    "\n",
    "# A single bert_module block (range(1)); bert_module applies self-attention with use_causal_mask=True.\n",
    "for i in range(1):\n",
    "    encoder_output = bert_module(encoder_output, encoder_output, encoder_output, embed_dim, num_head, i)\n",
    "\n",
    "# Drop the last sequence position, so the model emits seq_len - 1 output positions.\n",
    "encoder_output = keras.layers.Lambda(lambda x: x[:,:-1,:], name='slice')(encoder_output)\n",
    "# Per-position softmax over N classes.\n",
    "mlm_output = layers.Dense(N, name=\"mlm_cls\", activation=\"softmax\")(encoder_output)\n",
    "mlm_model = keras.Model(inputs = inputs, outputs = mlm_output)\n",
    "# Adam with its default settings; the sparse loss expects integer class labels rather than one-hot vectors.\n",
    "adam = Adam()\n",
    "mlm_model.compile(loss='sparse_categorical_crossentropy', optimizer=adam)\n",
    "# Half of the training data is held out for validation (validation_split = 0.5).\n",
    "# The recorded log shows 1/1 steps per epoch, i.e. the training half fits in one batch of 5000.\n",
    "history = mlm_model.fit(x_masked_train, y_masked_labels_train,\n",
    "                        validation_split = 0.5, callbacks = [callback], \n",
    "                        epochs=2000, batch_size=5000, \n",
    "                        verbose=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "dfa1251c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the trained model on a random sample of test sequences.\n",
    "# For each sequence, the prediction at the last output position is compared\n",
    "# with the last token of that input sequence.\n",
    "acc = []   # 1 when the arg-max class equals the target token, else 0\n",
    "prob = []  # model output value (softmax probability) at the target token's index\n",
    "\n",
    "# Sample up to 1000 distinct rows. Capping at the number of rows keeps\n",
    "# replace=False valid when the test set is smaller than 1000.\n",
    "n_eval = min(1000, x_masked_test.shape[0])\n",
    "x_test_subset = x_masked_test[np.random.choice(x_masked_test.shape[0], size=n_eval, replace=False)]\n",
    "\n",
    "# Build the input-to-output backend function once: it is the same for every\n",
    "# sentence, so recreating it inside the loop only adds overhead.\n",
    "predict_fn = keras.backend.function(inputs = mlm_model.layers[0].input, outputs = mlm_model.layers[-1].output)\n",
    "\n",
    "for sentence_number in x_test_subset:\n",
    "    # Despite its name, sentence_number is one row of token ids; run it as a batch of size 1.\n",
    "    temp = predict_fn(np.array(sentence_number).reshape(1,len(sentence_number)))\n",
    "    # Keep only the distribution at the last output position.\n",
    "    temp = temp[:,-1,:]\n",
    "    # The target is the final token of the input sequence.\n",
    "    acc.append(1 if temp.argmax() == sentence_number[-1] else 0)\n",
    "    prob.append(temp[0][sentence_number[-1]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "a0d91607",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1.0, 0.8643297)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Summary over the evaluated subset: (mean of the 0/1 hit flags, mean output probability at the target token).\n",
    "(np.mean(acc), np.mean(prob))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "bb91f983",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2000\n",
      "1/1 - 2s - loss: 3.5863 - val_loss: 3.3188 - 2s/epoch - 2s/step\n",
      "Epoch 2/2000\n",
      "1/1 - 0s - loss: 3.3004 - val_loss: 3.2026 - 208ms/epoch - 208ms/step\n",
      "Epoch 3/2000\n",
      "1/1 - 0s - loss: 3.1773 - val_loss: 3.1451 - 199ms/epoch - 199ms/step\n",
      "Epoch 4/2000\n",
      "1/1 - 0s - loss: 3.1153 - val_loss: 3.1187 - 202ms/epoch - 202ms/step\n",
      "Epoch 5/2000\n",
      "1/1 - 0s - loss: 3.0840 - val_loss: 3.1055 - 221ms/epoch - 221ms/step\n",
      "Epoch 6/2000\n",
      "1/1 - 0s - loss: 3.0647 - val_loss: 3.1001 - 227ms/epoch - 227ms/step\n",
      "Epoch 7/2000\n",
      "1/1 - 0s - loss: 3.0529 - val_loss: 3.0925 - 227ms/epoch - 227ms/step\n",
      "Epoch 8/2000\n",
      "1/1 - 0s - loss: 3.0395 - val_loss: 3.0780 - 227ms/epoch - 227ms/step\n",
      "Epoch 9/2000\n",
      "1/1 - 0s - loss: 3.0203 - val_loss: 3.0621 - 211ms/epoch - 211ms/step\n",
      "Epoch 10/2000\n",
      "1/1 - 0s - loss: 3.0009 - val_loss: 3.0505 - 231ms/epoch - 231ms/step\n",
      "Epoch 11/2000\n",
      "1/1 - 0s - loss: 2.9868 - val_loss: 3.0449 - 223ms/epoch - 223ms/step\n",
      "Epoch 12/2000\n",
      "1/1 - 0s - loss: 2.9793 - val_loss: 3.0441 - 211ms/epoch - 211ms/step\n",
      "Epoch 13/2000\n",
      "1/1 - 0s - loss: 2.9768 - val_loss: 3.0452 - 229ms/epoch - 229ms/step\n",
      "Epoch 14/2000\n",
      "1/1 - 0s - loss: 2.9766 - val_loss: 3.0462 - 203ms/epoch - 203ms/step\n",
      "Epoch 15/2000\n",
      "1/1 - 0s - loss: 2.9764 - val_loss: 3.0462 - 204ms/epoch - 204ms/step\n",
      "Epoch 16/2000\n",
      "1/1 - 0s - loss: 2.9750 - val_loss: 3.0451 - 209ms/epoch - 209ms/step\n",
      "Epoch 17/2000\n",
      "1/1 - 0s - loss: 2.9727 - val_loss: 3.0438 - 213ms/epoch - 213ms/step\n",
      "Epoch 18/2000\n",
      "1/1 - 0s - loss: 2.9699 - val_loss: 3.0426 - 212ms/epoch - 212ms/step\n",
      "Epoch 19/2000\n",
      "1/1 - 0s - loss: 2.9671 - val_loss: 3.0416 - 217ms/epoch - 217ms/step\n",
      "Epoch 20/2000\n",
      "1/1 - 0s - loss: 2.9643 - val_loss: 3.0407 - 228ms/epoch - 228ms/step\n",
      "Epoch 21/2000\n",
      "1/1 - 0s - loss: 2.9617 - val_loss: 3.0400 - 214ms/epoch - 214ms/step\n",
      "Epoch 22/2000\n",
      "1/1 - 0s - loss: 2.9594 - val_loss: 3.0395 - 213ms/epoch - 213ms/step\n",
      "Epoch 23/2000\n",
      "1/1 - 0s - loss: 2.9574 - val_loss: 3.0395 - 213ms/epoch - 213ms/step\n",
      "Epoch 24/2000\n",
      "1/1 - 0s - loss: 2.9561 - val_loss: 3.0397 - 209ms/epoch - 209ms/step\n",
      "Epoch 25/2000\n",
      "1/1 - 0s - loss: 2.9553 - val_loss: 3.0401 - 215ms/epoch - 215ms/step\n",
      "Epoch 26/2000\n",
      "1/1 - 0s - loss: 2.9549 - val_loss: 3.0402 - 209ms/epoch - 209ms/step\n",
      "Epoch 27/2000\n",
      "1/1 - 0s - loss: 2.9546 - val_loss: 3.0400 - 205ms/epoch - 205ms/step\n",
      "Epoch 28/2000\n",
      "1/1 - 0s - loss: 2.9541 - val_loss: 3.0393 - 206ms/epoch - 206ms/step\n",
      "Epoch 29/2000\n",
      "1/1 - 0s - loss: 2.9532 - val_loss: 3.0381 - 233ms/epoch - 233ms/step\n",
      "Epoch 30/2000\n",
      "1/1 - 0s - loss: 2.9521 - val_loss: 3.0368 - 206ms/epoch - 206ms/step\n",
      "Epoch 31/2000\n",
      "1/1 - 0s - loss: 2.9509 - val_loss: 3.0357 - 201ms/epoch - 201ms/step\n",
      "Epoch 32/2000\n",
      "1/1 - 0s - loss: 2.9498 - val_loss: 3.0348 - 202ms/epoch - 202ms/step\n",
      "Epoch 33/2000\n",
      "1/1 - 0s - loss: 2.9489 - val_loss: 3.0341 - 201ms/epoch - 201ms/step\n",
      "Epoch 34/2000\n",
      "1/1 - 0s - loss: 2.9482 - val_loss: 3.0337 - 199ms/epoch - 199ms/step\n",
      "Epoch 35/2000\n",
      "1/1 - 0s - loss: 2.9475 - val_loss: 3.0334 - 215ms/epoch - 215ms/step\n",
      "Epoch 36/2000\n",
      "1/1 - 0s - loss: 2.9468 - val_loss: 3.0332 - 194ms/epoch - 194ms/step\n",
      "Epoch 37/2000\n",
      "1/1 - 0s - loss: 2.9461 - val_loss: 3.0331 - 205ms/epoch - 205ms/step\n",
      "Epoch 38/2000\n",
      "1/1 - 0s - loss: 2.9453 - val_loss: 3.0329 - 191ms/epoch - 191ms/step\n",
      "Epoch 39/2000\n",
      "1/1 - 0s - loss: 2.9445 - val_loss: 3.0326 - 193ms/epoch - 193ms/step\n",
      "Epoch 40/2000\n",
      "1/1 - 0s - loss: 2.9436 - val_loss: 3.0321 - 191ms/epoch - 191ms/step\n",
      "Epoch 41/2000\n",
      "1/1 - 0s - loss: 2.9426 - val_loss: 3.0316 - 172ms/epoch - 172ms/step\n",
      "Epoch 42/2000\n",
      "1/1 - 0s - loss: 2.9416 - val_loss: 3.0310 - 184ms/epoch - 184ms/step\n",
      "Epoch 43/2000\n",
      "1/1 - 0s - loss: 2.9407 - val_loss: 3.0303 - 196ms/epoch - 196ms/step\n",
      "Epoch 44/2000\n",
      "1/1 - 0s - loss: 2.9397 - val_loss: 3.0296 - 171ms/epoch - 171ms/step\n",
      "Epoch 45/2000\n",
      "1/1 - 0s - loss: 2.9387 - val_loss: 3.0286 - 182ms/epoch - 182ms/step\n",
      "Epoch 46/2000\n",
      "1/1 - 0s - loss: 2.9375 - val_loss: 3.0276 - 182ms/epoch - 182ms/step\n",
      "Epoch 47/2000\n",
      "1/1 - 0s - loss: 2.9364 - val_loss: 3.0265 - 185ms/epoch - 185ms/step\n",
      "Epoch 48/2000\n",
      "1/1 - 0s - loss: 2.9351 - val_loss: 3.0252 - 205ms/epoch - 205ms/step\n",
      "Epoch 49/2000\n",
      "1/1 - 0s - loss: 2.9338 - val_loss: 3.0238 - 202ms/epoch - 202ms/step\n",
      "Epoch 50/2000\n",
      "1/1 - 0s - loss: 2.9323 - val_loss: 3.0222 - 202ms/epoch - 202ms/step\n",
      "Epoch 51/2000\n",
      "1/1 - 0s - loss: 2.9307 - val_loss: 3.0206 - 179ms/epoch - 179ms/step\n",
      "Epoch 52/2000\n",
      "1/1 - 0s - loss: 2.9291 - val_loss: 3.0189 - 313ms/epoch - 313ms/step\n",
      "Epoch 53/2000\n",
      "1/1 - 0s - loss: 2.9274 - val_loss: 3.0172 - 189ms/epoch - 189ms/step\n",
      "Epoch 54/2000\n",
      "1/1 - 0s - loss: 2.9256 - val_loss: 3.0154 - 207ms/epoch - 207ms/step\n",
      "Epoch 55/2000\n",
      "1/1 - 0s - loss: 2.9237 - val_loss: 3.0135 - 222ms/epoch - 222ms/step\n",
      "Epoch 56/2000\n",
      "1/1 - 0s - loss: 2.9216 - val_loss: 3.0115 - 211ms/epoch - 211ms/step\n",
      "Epoch 57/2000\n",
      "1/1 - 0s - loss: 2.9195 - val_loss: 3.0094 - 180ms/epoch - 180ms/step\n",
      "Epoch 58/2000\n",
      "1/1 - 0s - loss: 2.9171 - val_loss: 3.0072 - 188ms/epoch - 188ms/step\n",
      "Epoch 59/2000\n",
      "1/1 - 0s - loss: 2.9147 - val_loss: 3.0050 - 166ms/epoch - 166ms/step\n",
      "Epoch 60/2000\n",
      "1/1 - 0s - loss: 2.9121 - val_loss: 3.0028 - 168ms/epoch - 168ms/step\n",
      "Epoch 61/2000\n",
      "1/1 - 0s - loss: 2.9093 - val_loss: 3.0006 - 161ms/epoch - 161ms/step\n",
      "Epoch 62/2000\n",
      "1/1 - 0s - loss: 2.9064 - val_loss: 2.9981 - 169ms/epoch - 169ms/step\n",
      "Epoch 63/2000\n",
      "1/1 - 0s - loss: 2.9033 - val_loss: 2.9954 - 147ms/epoch - 147ms/step\n",
      "Epoch 64/2000\n",
      "1/1 - 0s - loss: 2.9000 - val_loss: 2.9924 - 143ms/epoch - 143ms/step\n",
      "Epoch 65/2000\n",
      "1/1 - 0s - loss: 2.8963 - val_loss: 2.9890 - 136ms/epoch - 136ms/step\n",
      "Epoch 66/2000\n",
      "1/1 - 0s - loss: 2.8924 - val_loss: 2.9854 - 129ms/epoch - 129ms/step\n",
      "Epoch 67/2000\n",
      "1/1 - 0s - loss: 2.8882 - val_loss: 2.9814 - 163ms/epoch - 163ms/step\n",
      "Epoch 68/2000\n",
      "1/1 - 0s - loss: 2.8837 - val_loss: 2.9771 - 119ms/epoch - 119ms/step\n",
      "Epoch 69/2000\n",
      "1/1 - 0s - loss: 2.8788 - val_loss: 2.9722 - 120ms/epoch - 120ms/step\n",
      "Epoch 70/2000\n",
      "1/1 - 0s - loss: 2.8734 - val_loss: 2.9668 - 117ms/epoch - 117ms/step\n",
      "Epoch 71/2000\n",
      "1/1 - 0s - loss: 2.8675 - val_loss: 2.9610 - 116ms/epoch - 116ms/step\n",
      "Epoch 72/2000\n",
      "1/1 - 0s - loss: 2.8612 - val_loss: 2.9546 - 121ms/epoch - 121ms/step\n",
      "Epoch 73/2000\n",
      "1/1 - 0s - loss: 2.8543 - val_loss: 2.9475 - 114ms/epoch - 114ms/step\n",
      "Epoch 74/2000\n",
      "1/1 - 0s - loss: 2.8467 - val_loss: 2.9396 - 112ms/epoch - 112ms/step\n",
      "Epoch 75/2000\n",
      "1/1 - 0s - loss: 2.8383 - val_loss: 2.9310 - 108ms/epoch - 108ms/step\n",
      "Epoch 76/2000\n",
      "1/1 - 0s - loss: 2.8290 - val_loss: 2.9216 - 108ms/epoch - 108ms/step\n",
      "Epoch 77/2000\n",
      "1/1 - 0s - loss: 2.8187 - val_loss: 2.9112 - 109ms/epoch - 109ms/step\n",
      "Epoch 78/2000\n",
      "1/1 - 0s - loss: 2.8074 - val_loss: 2.8993 - 93ms/epoch - 93ms/step\n",
      "Epoch 79/2000\n",
      "1/1 - 0s - loss: 2.7946 - val_loss: 2.8860 - 93ms/epoch - 93ms/step\n",
      "Epoch 80/2000\n",
      "1/1 - 0s - loss: 2.7803 - val_loss: 2.8708 - 101ms/epoch - 101ms/step\n",
      "Epoch 81/2000\n",
      "1/1 - 0s - loss: 2.7641 - val_loss: 2.8537 - 96ms/epoch - 96ms/step\n",
      "Epoch 82/2000\n",
      "1/1 - 0s - loss: 2.7457 - val_loss: 2.8345 - 105ms/epoch - 105ms/step\n",
      "Epoch 83/2000\n",
      "1/1 - 0s - loss: 2.7250 - val_loss: 2.8128 - 107ms/epoch - 107ms/step\n",
      "Epoch 84/2000\n",
      "1/1 - 0s - loss: 2.7015 - val_loss: 2.7877 - 101ms/epoch - 101ms/step\n",
      "Epoch 85/2000\n",
      "1/1 - 0s - loss: 2.6747 - val_loss: 2.7588 - 105ms/epoch - 105ms/step\n",
      "Epoch 86/2000\n",
      "1/1 - 0s - loss: 2.6440 - val_loss: 2.7257 - 105ms/epoch - 105ms/step\n",
      "Epoch 87/2000\n",
      "1/1 - 0s - loss: 2.6088 - val_loss: 2.6882 - 107ms/epoch - 107ms/step\n",
      "Epoch 88/2000\n",
      "1/1 - 0s - loss: 2.5686 - val_loss: 2.6463 - 115ms/epoch - 115ms/step\n",
      "Epoch 89/2000\n",
      "1/1 - 0s - loss: 2.5233 - val_loss: 2.5992 - 136ms/epoch - 136ms/step\n",
      "Epoch 90/2000\n",
      "1/1 - 0s - loss: 2.4732 - val_loss: 2.5496 - 121ms/epoch - 121ms/step\n",
      "Epoch 91/2000\n",
      "1/1 - 0s - loss: 2.4211 - val_loss: 2.5018 - 140ms/epoch - 140ms/step\n",
      "Epoch 92/2000\n",
      "1/1 - 0s - loss: 2.3716 - val_loss: 2.4564 - 158ms/epoch - 158ms/step\n",
      "Epoch 93/2000\n",
      "1/1 - 0s - loss: 2.3266 - val_loss: 2.4144 - 161ms/epoch - 161ms/step\n",
      "Epoch 94/2000\n",
      "1/1 - 0s - loss: 2.2873 - val_loss: 2.3781 - 174ms/epoch - 174ms/step\n",
      "Epoch 95/2000\n",
      "1/1 - 0s - loss: 2.2541 - val_loss: 2.3483 - 173ms/epoch - 173ms/step\n",
      "Epoch 96/2000\n",
      "1/1 - 0s - loss: 2.2271 - val_loss: 2.3236 - 175ms/epoch - 175ms/step\n",
      "Epoch 97/2000\n",
      "1/1 - 0s - loss: 2.2036 - val_loss: 2.3015 - 179ms/epoch - 179ms/step\n",
      "Epoch 98/2000\n",
      "1/1 - 0s - loss: 2.1824 - val_loss: 2.2833 - 182ms/epoch - 182ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 99/2000\n",
      "1/1 - 0s - loss: 2.1640 - val_loss: 2.2683 - 168ms/epoch - 168ms/step\n",
      "Epoch 100/2000\n",
      "1/1 - 0s - loss: 2.1487 - val_loss: 2.2545 - 162ms/epoch - 162ms/step\n",
      "Epoch 101/2000\n",
      "1/1 - 0s - loss: 2.1344 - val_loss: 2.2416 - 155ms/epoch - 155ms/step\n",
      "Epoch 102/2000\n",
      "1/1 - 0s - loss: 2.1199 - val_loss: 2.2287 - 157ms/epoch - 157ms/step\n",
      "Epoch 103/2000\n",
      "1/1 - 0s - loss: 2.1056 - val_loss: 2.2146 - 166ms/epoch - 166ms/step\n",
      "Epoch 104/2000\n",
      "1/1 - 0s - loss: 2.0917 - val_loss: 2.2002 - 165ms/epoch - 165ms/step\n",
      "Epoch 105/2000\n",
      "1/1 - 0s - loss: 2.0774 - val_loss: 2.1863 - 167ms/epoch - 167ms/step\n",
      "Epoch 106/2000\n",
      "1/1 - 0s - loss: 2.0633 - val_loss: 2.1727 - 160ms/epoch - 160ms/step\n",
      "Epoch 107/2000\n",
      "1/1 - 0s - loss: 2.0500 - val_loss: 2.1599 - 173ms/epoch - 173ms/step\n",
      "Epoch 108/2000\n",
      "1/1 - 0s - loss: 2.0377 - val_loss: 2.1492 - 176ms/epoch - 176ms/step\n",
      "Epoch 109/2000\n",
      "1/1 - 0s - loss: 2.0271 - val_loss: 2.1402 - 173ms/epoch - 173ms/step\n",
      "Epoch 110/2000\n",
      "1/1 - 0s - loss: 2.0185 - val_loss: 2.1323 - 168ms/epoch - 168ms/step\n",
      "Epoch 111/2000\n",
      "1/1 - 0s - loss: 2.0114 - val_loss: 2.1257 - 180ms/epoch - 180ms/step\n",
      "Epoch 112/2000\n",
      "1/1 - 0s - loss: 2.0054 - val_loss: 2.1206 - 177ms/epoch - 177ms/step\n",
      "Epoch 113/2000\n",
      "1/1 - 0s - loss: 2.0004 - val_loss: 2.1160 - 168ms/epoch - 168ms/step\n",
      "Epoch 114/2000\n",
      "1/1 - 0s - loss: 1.9959 - val_loss: 2.1118 - 160ms/epoch - 160ms/step\n",
      "Epoch 115/2000\n",
      "1/1 - 0s - loss: 1.9917 - val_loss: 2.1078 - 156ms/epoch - 156ms/step\n",
      "Epoch 116/2000\n",
      "1/1 - 0s - loss: 1.9879 - val_loss: 2.1039 - 174ms/epoch - 174ms/step\n",
      "Epoch 117/2000\n",
      "1/1 - 0s - loss: 1.9843 - val_loss: 2.0999 - 163ms/epoch - 163ms/step\n",
      "Epoch 118/2000\n",
      "1/1 - 0s - loss: 1.9810 - val_loss: 2.0958 - 183ms/epoch - 183ms/step\n",
      "Epoch 119/2000\n",
      "1/1 - 0s - loss: 1.9777 - val_loss: 2.0924 - 177ms/epoch - 177ms/step\n",
      "Epoch 120/2000\n",
      "1/1 - 0s - loss: 1.9748 - val_loss: 2.0898 - 188ms/epoch - 188ms/step\n",
      "Epoch 121/2000\n",
      "1/1 - 0s - loss: 1.9725 - val_loss: 2.0874 - 181ms/epoch - 181ms/step\n",
      "Epoch 122/2000\n",
      "1/1 - 0s - loss: 1.9704 - val_loss: 2.0851 - 178ms/epoch - 178ms/step\n",
      "Epoch 123/2000\n",
      "1/1 - 0s - loss: 1.9685 - val_loss: 2.0829 - 184ms/epoch - 184ms/step\n",
      "Epoch 124/2000\n",
      "1/1 - 0s - loss: 1.9667 - val_loss: 2.0810 - 182ms/epoch - 182ms/step\n",
      "Epoch 125/2000\n",
      "1/1 - 0s - loss: 1.9651 - val_loss: 2.0793 - 182ms/epoch - 182ms/step\n",
      "Epoch 126/2000\n",
      "1/1 - 0s - loss: 1.9638 - val_loss: 2.0780 - 158ms/epoch - 158ms/step\n",
      "Epoch 127/2000\n",
      "1/1 - 0s - loss: 1.9626 - val_loss: 2.0770 - 151ms/epoch - 151ms/step\n",
      "Epoch 128/2000\n",
      "1/1 - 0s - loss: 1.9616 - val_loss: 2.0762 - 154ms/epoch - 154ms/step\n",
      "Epoch 129/2000\n",
      "1/1 - 0s - loss: 1.9606 - val_loss: 2.0754 - 149ms/epoch - 149ms/step\n",
      "Epoch 130/2000\n",
      "1/1 - 0s - loss: 1.9597 - val_loss: 2.0746 - 154ms/epoch - 154ms/step\n",
      "Epoch 131/2000\n",
      "1/1 - 0s - loss: 1.9588 - val_loss: 2.0737 - 159ms/epoch - 159ms/step\n",
      "Epoch 132/2000\n",
      "1/1 - 0s - loss: 1.9579 - val_loss: 2.0730 - 164ms/epoch - 164ms/step\n",
      "Epoch 133/2000\n",
      "1/1 - 0s - loss: 1.9571 - val_loss: 2.0724 - 157ms/epoch - 157ms/step\n",
      "Epoch 134/2000\n",
      "1/1 - 0s - loss: 1.9562 - val_loss: 2.0720 - 151ms/epoch - 151ms/step\n",
      "Epoch 135/2000\n",
      "1/1 - 0s - loss: 1.9555 - val_loss: 2.0716 - 153ms/epoch - 153ms/step\n",
      "Epoch 136/2000\n",
      "1/1 - 0s - loss: 1.9547 - val_loss: 2.0712 - 155ms/epoch - 155ms/step\n",
      "Epoch 137/2000\n",
      "1/1 - 0s - loss: 1.9540 - val_loss: 2.0710 - 168ms/epoch - 168ms/step\n",
      "Epoch 138/2000\n",
      "1/1 - 0s - loss: 1.9533 - val_loss: 2.0708 - 159ms/epoch - 159ms/step\n",
      "Epoch 139/2000\n",
      "1/1 - 0s - loss: 1.9527 - val_loss: 2.0706 - 150ms/epoch - 150ms/step\n",
      "Epoch 140/2000\n",
      "1/1 - 0s - loss: 1.9521 - val_loss: 2.0703 - 146ms/epoch - 146ms/step\n",
      "Epoch 141/2000\n",
      "1/1 - 0s - loss: 1.9516 - val_loss: 2.0700 - 144ms/epoch - 144ms/step\n",
      "Epoch 142/2000\n",
      "1/1 - 0s - loss: 1.9511 - val_loss: 2.0698 - 148ms/epoch - 148ms/step\n",
      "Epoch 143/2000\n",
      "1/1 - 0s - loss: 1.9506 - val_loss: 2.0696 - 146ms/epoch - 146ms/step\n",
      "Epoch 144/2000\n",
      "1/1 - 0s - loss: 1.9501 - val_loss: 2.0694 - 149ms/epoch - 149ms/step\n",
      "Epoch 145/2000\n",
      "1/1 - 0s - loss: 1.9496 - val_loss: 2.0693 - 145ms/epoch - 145ms/step\n",
      "Epoch 146/2000\n",
      "1/1 - 0s - loss: 1.9491 - val_loss: 2.0691 - 142ms/epoch - 142ms/step\n",
      "Epoch 147/2000\n",
      "1/1 - 0s - loss: 1.9487 - val_loss: 2.0689 - 139ms/epoch - 139ms/step\n",
      "Epoch 148/2000\n",
      "1/1 - 0s - loss: 1.9482 - val_loss: 2.0687 - 143ms/epoch - 143ms/step\n",
      "Epoch 149/2000\n",
      "1/1 - 0s - loss: 1.9478 - val_loss: 2.0684 - 141ms/epoch - 141ms/step\n",
      "Epoch 150/2000\n",
      "1/1 - 0s - loss: 1.9473 - val_loss: 2.0681 - 142ms/epoch - 142ms/step\n",
      "Epoch 151/2000\n",
      "1/1 - 0s - loss: 1.9469 - val_loss: 2.0678 - 140ms/epoch - 140ms/step\n",
      "Epoch 152/2000\n",
      "1/1 - 0s - loss: 1.9465 - val_loss: 2.0677 - 146ms/epoch - 146ms/step\n",
      "Epoch 153/2000\n",
      "1/1 - 0s - loss: 1.9461 - val_loss: 2.0677 - 181ms/epoch - 181ms/step\n",
      "Epoch 154/2000\n",
      "1/1 - 0s - loss: 1.9457 - val_loss: 2.0677 - 172ms/epoch - 172ms/step\n",
      "Epoch 155/2000\n",
      "1/1 - 0s - loss: 1.9453 - val_loss: 2.0677 - 167ms/epoch - 167ms/step\n",
      "Epoch 156/2000\n",
      "1/1 - 0s - loss: 1.9450 - val_loss: 2.0677 - 150ms/epoch - 150ms/step\n",
      "Epoch 157/2000\n",
      "1/1 - 0s - loss: 1.9446 - val_loss: 2.0677 - 135ms/epoch - 135ms/step\n",
      "Epoch 158/2000\n",
      "1/1 - 0s - loss: 1.9442 - val_loss: 2.0677 - 150ms/epoch - 150ms/step\n",
      "Epoch 159/2000\n",
      "1/1 - 0s - loss: 1.9439 - val_loss: 2.0677 - 136ms/epoch - 136ms/step\n",
      "Epoch 160/2000\n",
      "1/1 - 0s - loss: 1.9435 - val_loss: 2.0678 - 131ms/epoch - 131ms/step\n",
      "Epoch 161/2000\n",
      "1/1 - 0s - loss: 1.9431 - val_loss: 2.0678 - 152ms/epoch - 152ms/step\n",
      "Epoch 162/2000\n",
      "1/1 - 0s - loss: 1.9428 - val_loss: 2.0678 - 183ms/epoch - 183ms/step\n"
     ]
    }
   ],
   "source": [
    "# Encoder hyperparameters.\n",
    "embed_dim = 100\n",
    "num_head = 2\n",
    "\n",
    "# Halt once val_loss has gone 5 epochs without improving, then restore the best weights seen.\n",
    "callback = keras.callbacks.EarlyStopping(\n",
    "    monitor='val_loss',\n",
    "    patience=5,\n",
    "    restore_best_weights=True,\n",
    ")\n",
    "\n",
    "# Integer token ids -> learned embeddings, summed with the output of get_sinusoidal_embeddings.\n",
    "inputs = layers.Input((x_masked_train.shape[1],), dtype=tf.int64)\n",
    "word_embeddings = layers.Embedding(N, embed_dim, name=\"word_embedding\")(inputs)\n",
    "sinusoidal_embeddings = get_sinusoidal_embeddings(len(x_masked_train[0]), embed_dim)\n",
    "encoder_output = word_embeddings + sinusoidal_embeddings\n",
    "\n",
    "# A single bert_module block; query, key and value are all the same tensor.\n",
    "for i in range(1):\n",
    "    encoder_output = bert_module(\n",
    "        query=encoder_output,\n",
    "        key=encoder_output,\n",
    "        value=encoder_output,\n",
    "        embed_dim=embed_dim,\n",
    "        num_head=num_head,\n",
    "        i=i,\n",
    "    )\n",
    "\n",
    "# Discard the output at the final sequence position before the softmax classifier over N classes.\n",
    "encoder_output = layers.Lambda(lambda t: t[:, :-1, :], name='slice')(encoder_output)\n",
    "mlm_output = layers.Dense(N, activation=\"softmax\", name=\"mlm_cls\")(encoder_output)\n",
    "mlm_model = keras.Model(inputs=inputs, outputs=mlm_output)\n",
    "\n",
    "adam = Adam()\n",
    "mlm_model.compile(optimizer=adam, loss='sparse_categorical_crossentropy')\n",
    "\n",
    "# validation_split=0.5 holds out half of the supplied training arrays for validation.\n",
    "history = mlm_model.fit(\n",
    "    x_masked_train,\n",
    "    y_masked_labels_train,\n",
    "    batch_size=5000,\n",
    "    epochs=2000,\n",
    "    validation_split=0.5,\n",
    "    callbacks=[callback],\n",
    "    verbose=2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "81e26fe5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the model on a random subset of the test sequences.\n",
    "# Cap the subset size at the number of available rows: np.random.choice with\n",
    "# replace=False raises if asked for more samples than exist.\n",
    "n_eval = min(1000, x_masked_test.shape[0])\n",
    "x_test_subset = x_masked_test[np.random.choice(x_masked_test.shape[0], size=n_eval, replace=False)]\n",
    "\n",
    "# Build the backend function once (it was previously rebuilt for every sentence)\n",
    "# and push the whole subset through it as a single batch.\n",
    "predict_fn = keras.backend.function(inputs = mlm_model.layers[0].input, outputs = mlm_model.layers[-1].output)\n",
    "last_step_probs = predict_fn(np.array(x_test_subset))[:, -1, :]\n",
    "\n",
    "# The last output step of each sequence is compared with that sequence's last input token.\n",
    "targets = np.array(x_test_subset)[:, -1]\n",
    "\n",
    "# acc: 1 where the arg-max class equals the target token, else 0.\n",
    "acc = (last_step_probs.argmax(axis=1) == targets).astype(int).tolist()\n",
    "# prob: probability the model assigned to the target token.\n",
    "prob = list(last_step_probs[np.arange(len(targets)), targets])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "267160c5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1.0, 0.9771991)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean of the 0/1 hit indicators in acc, and mean of the per-sentence values in prob.\n",
    "np.mean(acc), np.mean(prob)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f4cab17",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
