{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d1686d53",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy.random as npr\n",
    "import random\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "from tensorflow.keras import backend as K\n",
    "from keras.optimizers import Adam\n",
    "from keras_nlp.layers import PositionEmbedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "56c107e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed every random number generator the notebook uses so that the shuffle below\n",
    "# and TensorFlow's random ops are reproducible from run to run.\n",
    "seed = 428\n",
    "\n",
    "random.seed(seed)         # Python's built-in `random` module\n",
    "np.random.seed(seed)      # NumPy's global RNG (used by np.random.permutation)\n",
    "tf.random.set_seed(seed)  # TensorFlow's global seed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "504a342e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def bert_module(query, key, value, embed_dim, num_head, i, use_causal_mask=True, ffn_multiplier=2):\n",
    "    \"\"\"Build one Transformer block: self-attention and a feed-forward network,\n",
    "    each wrapped in a residual connection followed by layer normalization.\n",
    "\n",
    "    Despite the BERT-style name, attention is causally masked by default, so each\n",
    "    position attends only to itself and earlier positions (decoder/GPT-style).\n",
    "\n",
    "    Args:\n",
    "        query: Attention query tensor; also the residual input, so its last\n",
    "            dimension must equal `embed_dim`.\n",
    "        key: Attention key tensor.\n",
    "        value: Attention value tensor.\n",
    "        embed_dim: Model width; also the output width of the feed-forward network.\n",
    "        num_head: Number of attention heads; each head uses\n",
    "            `embed_dim // num_head` as its key dimension.\n",
    "        i: Block index, used to prefix layer names (`encoder_{i}/...`); must differ\n",
    "            between blocks of the same model so the names stay unique.\n",
    "        use_causal_mask: Apply a causal (look-ahead) mask in self-attention.\n",
    "            Pass False for a bidirectional, BERT-style encoder.\n",
    "        ffn_multiplier: Hidden width of the feed-forward network as a multiple\n",
    "            of `embed_dim`.\n",
    "\n",
    "    Returns:\n",
    "        The block output tensor, with the same shape as `query`.\n",
    "    \"\"\"\n",
    "    prefix = \"encoder_{}\".format(i)\n",
    "\n",
    "    # Multi-headed self-attention. Keras' MultiHeadAttention is called as\n",
    "    # (query, value, key), so value/key are passed by keyword to avoid swapping them.\n",
    "    attention_output = layers.MultiHeadAttention(\n",
    "        num_heads=num_head,\n",
    "        key_dim=embed_dim // num_head,\n",
    "        name=\"{}/multiheadattention\".format(prefix),\n",
    "    )(query, value=value, key=key, use_causal_mask=use_causal_mask)\n",
    "\n",
    "    # Add & Normalize (residual connection around attention)\n",
    "    attention_output = layers.Add()([query, attention_output])\n",
    "    attention_output = layers.LayerNormalization(\n",
    "        epsilon=1e-6, name=\"{}/att_layernormalization\".format(prefix)\n",
    "    )(attention_output)\n",
    "\n",
    "    # Position-wise feed-forward network\n",
    "    ff_net = keras.models.Sequential([\n",
    "        layers.Dense(ffn_multiplier * embed_dim, activation='relu', name=\"{}/ffn_dense_1\".format(prefix)),\n",
    "        layers.Dense(embed_dim, name=\"{}/ffn_dense_2\".format(prefix)),\n",
    "    ])\n",
    "    ffn_output = ff_net(attention_output)\n",
    "\n",
    "    # Add & Normalize (residual connection around the feed-forward network)\n",
    "    ffn_output = layers.Add()([attention_output, ffn_output])\n",
    "    ffn_output = layers.LayerNormalization(\n",
    "        epsilon=1e-6, name=\"{}/ffn_layernormalization\".format(prefix)\n",
    "    )(ffn_output)\n",
    "\n",
    "    return ffn_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fef2622a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_sinusoidal_embeddings(sequence_length, embedding_dim):\n",
    "    position_enc = np.array([\n",
    "        [pos / np.power(10000, 2. * i / embedding_dim) for i in range(embedding_dim)]\n",
    "        if pos != 0 else np.zeros(embedding_dim)\n",
    "        for pos in range(sequence_length)\n",
    "    ])\n",
    "    position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i\n",
    "    position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # dim 2i+1\n",
    "    return tf.cast(position_enc, dtype=tf.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3e5bfd69",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 20 # vocab_size\n",
    "\n",
    "vocabs = ['word_' + str(i) for i in range(N)]\n",
    "\n",
    "vocab_map = {}\n",
    "for i in range(len(vocabs)):\n",
    "    vocab_map[vocabs[i]] = i\n",
    "    \n",
    "pairs = []\n",
    "\n",
    "for i in vocabs:\n",
    "    for j in vocabs:\n",
    "        for k in vocabs:\n",
    "            if i != j and i != k and j != k:\n",
    "                pairs.append((i,j,k))\n",
    "            \n",
    "#indicator = np.random.choice([0, 1], size=len(pairs), p=[0.5, 0.5])\n",
    "\n",
    "# pairs_train = [pairs[i] for i in range(len(indicator)) if indicator[i] == 1]\n",
    "# pairs_test = [pairs[i] for i in range(len(indicator)) if indicator[i] == 0]\n",
    "\n",
    "pairs_train = [x for x in pairs if int(x[0].split('_')[-1]) <= 9]\n",
    "pairs_test = [x for x in pairs if int(x[0].split('_')[-1]) >= 10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d43093fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def encode_triples(triples):\n",
    "    \"\"\"Expand (a, b, c) word triples into sentences [a, b, c, a] and their token ids.\n",
    "\n",
    "    Returns:\n",
    "        (sentences, sentence_ids): parallel lists of 4-word lists and of the\n",
    "        matching 4-int id lists looked up in `vocab_map`.\n",
    "    \"\"\"\n",
    "    sentences = [[a, b, c, a] for a, b, c in triples]\n",
    "    sentence_ids = [[vocab_map[word] for word in sentence] for sentence in sentences]\n",
    "    return sentences, sentence_ids\n",
    "\n",
    "\n",
    "sentences_train, sentences_number_train = encode_triples(pairs_train)\n",
    "sentences_test, sentences_number_test = encode_triples(pairs_test)\n",
    "\n",
    "# x holds the full 4-token id sequences; y holds the same ids shifted left by one\n",
    "# (tokens 2-4) -- presumably next-token targets, confirm against the model cell.\n",
    "# reshape(-1, 4) keeps x 2-D even when a split is empty.\n",
    "x_masked_train = np.array(sentences_number_train).reshape(-1, 4)\n",
    "y_masked_labels_train = x_masked_train[:, 1:].copy()\n",
    "x_masked_test = np.array(sentences_number_test).reshape(-1, 4)\n",
    "y_masked_labels_test = x_masked_test[:, 1:].copy()\n",
    "\n",
    "# Shuffle the training set, applying the same permutation to inputs and labels.\n",
    "perm = np.random.permutation(len(x_masked_train))\n",
    "x_masked_train = x_masked_train[perm]\n",
    "y_masked_labels_train = y_masked_labels_train[perm]\n",
    "\n",
    "assert x_masked_train.shape == (len(pairs_train), 4)\n",
    "assert y_masked_labels_train.shape == (len(pairs_train), 3)\n",
    "assert x_masked_test.shape == (len(pairs_test), 4)\n",
    "assert y_masked_labels_test.shape == (len(pairs_test), 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "13b40f89",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2000\n",
      "1/1 - 1s - loss: 3.4102 - val_loss: 3.3886 - 1s/epoch - 1s/step\n",
      "Epoch 2/2000\n",
      "1/1 - 0s - loss: 3.3858 - val_loss: 3.3677 - 66ms/epoch - 66ms/step\n",
      "Epoch 3/2000\n",
      "1/1 - 0s - loss: 3.3657 - val_loss: 3.3479 - 74ms/epoch - 74ms/step\n",
      "Epoch 4/2000\n",
      "1/1 - 0s - loss: 3.3465 - val_loss: 3.3291 - 69ms/epoch - 69ms/step\n",
      "Epoch 5/2000\n",
      "1/1 - 0s - loss: 3.3280 - val_loss: 3.3110 - 78ms/epoch - 78ms/step\n",
      "Epoch 6/2000\n",
      "1/1 - 0s - loss: 3.3100 - val_loss: 3.2939 - 73ms/epoch - 73ms/step\n",
      "Epoch 7/2000\n",
      "1/1 - 0s - loss: 3.2928 - val_loss: 3.2775 - 67ms/epoch - 67ms/step\n",
      "Epoch 8/2000\n",
      "1/1 - 0s - loss: 3.2764 - val_loss: 3.2618 - 69ms/epoch - 69ms/step\n",
      "Epoch 9/2000\n",
      "1/1 - 0s - loss: 3.2606 - val_loss: 3.2467 - 77ms/epoch - 77ms/step\n",
      "Epoch 10/2000\n",
      "1/1 - 0s - loss: 3.2453 - val_loss: 3.2323 - 75ms/epoch - 75ms/step\n",
      "Epoch 11/2000\n",
      "1/1 - 0s - loss: 3.2308 - val_loss: 3.2187 - 80ms/epoch - 80ms/step\n",
      "Epoch 12/2000\n",
      "1/1 - 0s - loss: 3.2169 - val_loss: 3.2058 - 77ms/epoch - 77ms/step\n",
      "Epoch 13/2000\n",
      "1/1 - 0s - loss: 3.2037 - val_loss: 3.1935 - 80ms/epoch - 80ms/step\n",
      "Epoch 14/2000\n",
      "1/1 - 0s - loss: 3.1911 - val_loss: 3.1817 - 77ms/epoch - 77ms/step\n",
      "Epoch 15/2000\n",
      "1/1 - 0s - loss: 3.1790 - val_loss: 3.1705 - 79ms/epoch - 79ms/step\n",
      "Epoch 16/2000\n",
      "1/1 - 0s - loss: 3.1675 - val_loss: 3.1598 - 74ms/epoch - 74ms/step\n",
      "Epoch 17/2000\n",
      "1/1 - 0s - loss: 3.1566 - val_loss: 3.1497 - 76ms/epoch - 76ms/step\n",
      "Epoch 18/2000\n",
      "1/1 - 0s - loss: 3.1461 - val_loss: 3.1400 - 76ms/epoch - 76ms/step\n",
      "Epoch 19/2000\n",
      "1/1 - 0s - loss: 3.1361 - val_loss: 3.1307 - 85ms/epoch - 85ms/step\n",
      "Epoch 20/2000\n",
      "1/1 - 0s - loss: 3.1265 - val_loss: 3.1218 - 73ms/epoch - 73ms/step\n",
      "Epoch 21/2000\n",
      "1/1 - 0s - loss: 3.1173 - val_loss: 3.1132 - 74ms/epoch - 74ms/step\n",
      "Epoch 22/2000\n",
      "1/1 - 0s - loss: 3.1084 - val_loss: 3.1050 - 69ms/epoch - 69ms/step\n",
      "Epoch 23/2000\n",
      "1/1 - 0s - loss: 3.0999 - val_loss: 3.0970 - 80ms/epoch - 80ms/step\n",
      "Epoch 24/2000\n",
      "1/1 - 0s - loss: 3.0917 - val_loss: 3.0893 - 80ms/epoch - 80ms/step\n",
      "Epoch 25/2000\n",
      "1/1 - 0s - loss: 3.0838 - val_loss: 3.0819 - 75ms/epoch - 75ms/step\n",
      "Epoch 26/2000\n",
      "1/1 - 0s - loss: 3.0762 - val_loss: 3.0749 - 71ms/epoch - 71ms/step\n",
      "Epoch 27/2000\n",
      "1/1 - 0s - loss: 3.0689 - val_loss: 3.0681 - 66ms/epoch - 66ms/step\n",
      "Epoch 28/2000\n",
      "1/1 - 0s - loss: 3.0620 - val_loss: 3.0616 - 74ms/epoch - 74ms/step\n",
      "Epoch 29/2000\n",
      "1/1 - 0s - loss: 3.0554 - val_loss: 3.0553 - 87ms/epoch - 87ms/step\n",
      "Epoch 30/2000\n",
      "1/1 - 0s - loss: 3.0491 - val_loss: 3.0493 - 87ms/epoch - 87ms/step\n",
      "Epoch 31/2000\n",
      "1/1 - 0s - loss: 3.0430 - val_loss: 3.0436 - 80ms/epoch - 80ms/step\n",
      "Epoch 32/2000\n",
      "1/1 - 0s - loss: 3.0372 - val_loss: 3.0382 - 84ms/epoch - 84ms/step\n",
      "Epoch 33/2000\n",
      "1/1 - 0s - loss: 3.0316 - val_loss: 3.0330 - 78ms/epoch - 78ms/step\n",
      "Epoch 34/2000\n",
      "1/1 - 0s - loss: 3.0262 - val_loss: 3.0280 - 85ms/epoch - 85ms/step\n",
      "Epoch 35/2000\n",
      "1/1 - 0s - loss: 3.0209 - val_loss: 3.0233 - 88ms/epoch - 88ms/step\n",
      "Epoch 36/2000\n",
      "1/1 - 0s - loss: 3.0160 - val_loss: 3.0188 - 86ms/epoch - 86ms/step\n",
      "Epoch 37/2000\n",
      "1/1 - 0s - loss: 3.0113 - val_loss: 3.0145 - 77ms/epoch - 77ms/step\n",
      "Epoch 38/2000\n",
      "1/1 - 0s - loss: 3.0067 - val_loss: 3.0102 - 75ms/epoch - 75ms/step\n",
      "Epoch 39/2000\n",
      "1/1 - 0s - loss: 3.0023 - val_loss: 3.0061 - 80ms/epoch - 80ms/step\n",
      "Epoch 40/2000\n",
      "1/1 - 0s - loss: 2.9980 - val_loss: 3.0021 - 69ms/epoch - 69ms/step\n",
      "Epoch 41/2000\n",
      "1/1 - 0s - loss: 2.9939 - val_loss: 2.9983 - 67ms/epoch - 67ms/step\n",
      "Epoch 42/2000\n",
      "1/1 - 0s - loss: 2.9900 - val_loss: 2.9947 - 71ms/epoch - 71ms/step\n",
      "Epoch 43/2000\n",
      "1/1 - 0s - loss: 2.9864 - val_loss: 2.9912 - 62ms/epoch - 62ms/step\n",
      "Epoch 44/2000\n",
      "1/1 - 0s - loss: 2.9830 - val_loss: 2.9879 - 65ms/epoch - 65ms/step\n",
      "Epoch 45/2000\n",
      "1/1 - 0s - loss: 2.9796 - val_loss: 2.9847 - 63ms/epoch - 63ms/step\n",
      "Epoch 46/2000\n",
      "1/1 - 0s - loss: 2.9765 - val_loss: 2.9817 - 63ms/epoch - 63ms/step\n",
      "Epoch 47/2000\n",
      "1/1 - 0s - loss: 2.9735 - val_loss: 2.9789 - 62ms/epoch - 62ms/step\n",
      "Epoch 48/2000\n",
      "1/1 - 0s - loss: 2.9707 - val_loss: 2.9763 - 66ms/epoch - 66ms/step\n",
      "Epoch 49/2000\n",
      "1/1 - 0s - loss: 2.9680 - val_loss: 2.9738 - 65ms/epoch - 65ms/step\n",
      "Epoch 50/2000\n",
      "1/1 - 0s - loss: 2.9655 - val_loss: 2.9714 - 64ms/epoch - 64ms/step\n",
      "Epoch 51/2000\n",
      "1/1 - 0s - loss: 2.9630 - val_loss: 2.9692 - 91ms/epoch - 91ms/step\n",
      "Epoch 52/2000\n",
      "1/1 - 0s - loss: 2.9607 - val_loss: 2.9672 - 72ms/epoch - 72ms/step\n",
      "Epoch 53/2000\n",
      "1/1 - 0s - loss: 2.9585 - val_loss: 2.9653 - 77ms/epoch - 77ms/step\n",
      "Epoch 54/2000\n",
      "1/1 - 0s - loss: 2.9564 - val_loss: 2.9635 - 92ms/epoch - 92ms/step\n",
      "Epoch 55/2000\n",
      "1/1 - 0s - loss: 2.9544 - val_loss: 2.9618 - 73ms/epoch - 73ms/step\n",
      "Epoch 56/2000\n",
      "1/1 - 0s - loss: 2.9525 - val_loss: 2.9602 - 77ms/epoch - 77ms/step\n",
      "Epoch 57/2000\n",
      "1/1 - 0s - loss: 2.9507 - val_loss: 2.9586 - 60ms/epoch - 60ms/step\n",
      "Epoch 58/2000\n",
      "1/1 - 0s - loss: 2.9490 - val_loss: 2.9572 - 62ms/epoch - 62ms/step\n",
      "Epoch 59/2000\n",
      "1/1 - 0s - loss: 2.9473 - val_loss: 2.9558 - 59ms/epoch - 59ms/step\n",
      "Epoch 60/2000\n",
      "1/1 - 0s - loss: 2.9458 - val_loss: 2.9545 - 62ms/epoch - 62ms/step\n",
      "Epoch 61/2000\n",
      "1/1 - 0s - loss: 2.9443 - val_loss: 2.9532 - 61ms/epoch - 61ms/step\n",
      "Epoch 62/2000\n",
      "1/1 - 0s - loss: 2.9428 - val_loss: 2.9520 - 62ms/epoch - 62ms/step\n",
      "Epoch 63/2000\n",
      "1/1 - 0s - loss: 2.9415 - val_loss: 2.9508 - 56ms/epoch - 56ms/step\n",
      "Epoch 64/2000\n",
      "1/1 - 0s - loss: 2.9401 - val_loss: 2.9498 - 52ms/epoch - 52ms/step\n",
      "Epoch 65/2000\n",
      "1/1 - 0s - loss: 2.9389 - val_loss: 2.9488 - 60ms/epoch - 60ms/step\n",
      "Epoch 66/2000\n",
      "1/1 - 0s - loss: 2.9377 - val_loss: 2.9478 - 61ms/epoch - 61ms/step\n",
      "Epoch 67/2000\n",
      "1/1 - 0s - loss: 2.9366 - val_loss: 2.9470 - 61ms/epoch - 61ms/step\n",
      "Epoch 68/2000\n",
      "1/1 - 0s - loss: 2.9355 - val_loss: 2.9462 - 65ms/epoch - 65ms/step\n",
      "Epoch 69/2000\n",
      "1/1 - 0s - loss: 2.9345 - val_loss: 2.9454 - 66ms/epoch - 66ms/step\n",
      "Epoch 70/2000\n",
      "1/1 - 0s - loss: 2.9335 - val_loss: 2.9447 - 76ms/epoch - 76ms/step\n",
      "Epoch 71/2000\n",
      "1/1 - 0s - loss: 2.9325 - val_loss: 2.9440 - 83ms/epoch - 83ms/step\n",
      "Epoch 72/2000\n",
      "1/1 - 0s - loss: 2.9316 - val_loss: 2.9434 - 122ms/epoch - 122ms/step\n",
      "Epoch 73/2000\n",
      "1/1 - 0s - loss: 2.9308 - val_loss: 2.9427 - 79ms/epoch - 79ms/step\n",
      "Epoch 74/2000\n",
      "1/1 - 0s - loss: 2.9299 - val_loss: 2.9421 - 55ms/epoch - 55ms/step\n",
      "Epoch 75/2000\n",
      "1/1 - 0s - loss: 2.9291 - val_loss: 2.9415 - 55ms/epoch - 55ms/step\n",
      "Epoch 76/2000\n",
      "1/1 - 0s - loss: 2.9284 - val_loss: 2.9410 - 50ms/epoch - 50ms/step\n",
      "Epoch 77/2000\n",
      "1/1 - 0s - loss: 2.9276 - val_loss: 2.9404 - 53ms/epoch - 53ms/step\n",
      "Epoch 78/2000\n",
      "1/1 - 0s - loss: 2.9269 - val_loss: 2.9399 - 64ms/epoch - 64ms/step\n",
      "Epoch 79/2000\n",
      "1/1 - 0s - loss: 2.9262 - val_loss: 2.9394 - 71ms/epoch - 71ms/step\n",
      "Epoch 80/2000\n",
      "1/1 - 0s - loss: 2.9256 - val_loss: 2.9389 - 70ms/epoch - 70ms/step\n",
      "Epoch 81/2000\n",
      "1/1 - 0s - loss: 2.9249 - val_loss: 2.9384 - 57ms/epoch - 57ms/step\n",
      "Epoch 82/2000\n",
      "1/1 - 0s - loss: 2.9243 - val_loss: 2.9380 - 70ms/epoch - 70ms/step\n",
      "Epoch 83/2000\n",
      "1/1 - 0s - loss: 2.9237 - val_loss: 2.9376 - 56ms/epoch - 56ms/step\n",
      "Epoch 84/2000\n",
      "1/1 - 0s - loss: 2.9231 - val_loss: 2.9372 - 69ms/epoch - 69ms/step\n",
      "Epoch 85/2000\n",
      "1/1 - 0s - loss: 2.9226 - val_loss: 2.9368 - 71ms/epoch - 71ms/step\n",
      "Epoch 86/2000\n",
      "1/1 - 0s - loss: 2.9220 - val_loss: 2.9365 - 55ms/epoch - 55ms/step\n",
      "Epoch 87/2000\n",
      "1/1 - 0s - loss: 2.9215 - val_loss: 2.9361 - 57ms/epoch - 57ms/step\n",
      "Epoch 88/2000\n",
      "1/1 - 0s - loss: 2.9210 - val_loss: 2.9358 - 64ms/epoch - 64ms/step\n",
      "Epoch 89/2000\n",
      "1/1 - 0s - loss: 2.9205 - val_loss: 2.9354 - 72ms/epoch - 72ms/step\n",
      "Epoch 90/2000\n",
      "1/1 - 0s - loss: 2.9200 - val_loss: 2.9351 - 58ms/epoch - 58ms/step\n",
      "Epoch 91/2000\n",
      "1/1 - 0s - loss: 2.9195 - val_loss: 2.9348 - 64ms/epoch - 64ms/step\n",
      "Epoch 92/2000\n",
      "1/1 - 0s - loss: 2.9191 - val_loss: 2.9344 - 66ms/epoch - 66ms/step\n",
      "Epoch 93/2000\n",
      "1/1 - 0s - loss: 2.9186 - val_loss: 2.9341 - 65ms/epoch - 65ms/step\n",
      "Epoch 94/2000\n",
      "1/1 - 0s - loss: 2.9182 - val_loss: 2.9338 - 67ms/epoch - 67ms/step\n",
      "Epoch 95/2000\n",
      "1/1 - 0s - loss: 2.9178 - val_loss: 2.9335 - 57ms/epoch - 57ms/step\n",
      "Epoch 96/2000\n",
      "1/1 - 0s - loss: 2.9174 - val_loss: 2.9331 - 55ms/epoch - 55ms/step\n",
      "Epoch 97/2000\n",
      "1/1 - 0s - loss: 2.9170 - val_loss: 2.9328 - 52ms/epoch - 52ms/step\n",
      "Epoch 98/2000\n",
      "1/1 - 0s - loss: 2.9166 - val_loss: 2.9325 - 55ms/epoch - 55ms/step\n",
      "Epoch 99/2000\n",
      "1/1 - 0s - loss: 2.9162 - val_loss: 2.9323 - 71ms/epoch - 71ms/step\n",
      "Epoch 100/2000\n",
      "1/1 - 0s - loss: 2.9158 - val_loss: 2.9320 - 67ms/epoch - 67ms/step\n",
      "Epoch 101/2000\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 - 0s - loss: 2.9154 - val_loss: 2.9317 - 69ms/epoch - 69ms/step\n",
      "Epoch 102/2000\n",
      "1/1 - 0s - loss: 2.9150 - val_loss: 2.9314 - 68ms/epoch - 68ms/step\n",
      "Epoch 103/2000\n",
      "1/1 - 0s - loss: 2.9147 - val_loss: 2.9312 - 55ms/epoch - 55ms/step\n",
      "Epoch 104/2000\n",
      "1/1 - 0s - loss: 2.9143 - val_loss: 2.9309 - 74ms/epoch - 74ms/step\n",
      "Epoch 105/2000\n",
      "1/1 - 0s - loss: 2.9140 - val_loss: 2.9306 - 64ms/epoch - 64ms/step\n",
      "Epoch 106/2000\n",
      "1/1 - 0s - loss: 2.9136 - val_loss: 2.9304 - 73ms/epoch - 73ms/step\n",
      "Epoch 107/2000\n",
      "1/1 - 0s - loss: 2.9132 - val_loss: 2.9301 - 68ms/epoch - 68ms/step\n",
      "Epoch 108/2000\n",
      "1/1 - 0s - loss: 2.9129 - val_loss: 2.9299 - 57ms/epoch - 57ms/step\n",
      "Epoch 109/2000\n",
      "1/1 - 0s - loss: 2.9126 - val_loss: 2.9296 - 73ms/epoch - 73ms/step\n",
      "Epoch 110/2000\n",
      "1/1 - 0s - loss: 2.9122 - val_loss: 2.9294 - 59ms/epoch - 59ms/step\n",
      "Epoch 111/2000\n",
      "1/1 - 0s - loss: 2.9119 - val_loss: 2.9291 - 69ms/epoch - 69ms/step\n",
      "Epoch 112/2000\n",
      "1/1 - 0s - loss: 2.9115 - val_loss: 2.9289 - 68ms/epoch - 68ms/step\n",
      "Epoch 113/2000\n",
      "1/1 - 0s - loss: 2.9112 - val_loss: 2.9286 - 63ms/epoch - 63ms/step\n",
      "Epoch 114/2000\n",
      "1/1 - 0s - loss: 2.9108 - val_loss: 2.9284 - 60ms/epoch - 60ms/step\n",
      "Epoch 115/2000\n",
      "1/1 - 0s - loss: 2.9105 - val_loss: 2.9282 - 59ms/epoch - 59ms/step\n",
      "Epoch 116/2000\n",
      "1/1 - 0s - loss: 2.9102 - val_loss: 2.9279 - 56ms/epoch - 56ms/step\n",
      "Epoch 117/2000\n",
      "1/1 - 0s - loss: 2.9098 - val_loss: 2.9277 - 56ms/epoch - 56ms/step\n",
      "Epoch 118/2000\n",
      "1/1 - 0s - loss: 2.9095 - val_loss: 2.9274 - 56ms/epoch - 56ms/step\n",
      "Epoch 119/2000\n",
      "1/1 - 0s - loss: 2.9091 - val_loss: 2.9272 - 57ms/epoch - 57ms/step\n",
      "Epoch 120/2000\n",
      "1/1 - 0s - loss: 2.9088 - val_loss: 2.9269 - 56ms/epoch - 56ms/step\n",
      "Epoch 121/2000\n",
      "1/1 - 0s - loss: 2.9084 - val_loss: 2.9267 - 59ms/epoch - 59ms/step\n",
      "Epoch 122/2000\n",
      "1/1 - 0s - loss: 2.9081 - val_loss: 2.9264 - 59ms/epoch - 59ms/step\n",
      "Epoch 123/2000\n",
      "1/1 - 0s - loss: 2.9077 - val_loss: 2.9261 - 58ms/epoch - 58ms/step\n",
      "Epoch 124/2000\n",
      "1/1 - 0s - loss: 2.9073 - val_loss: 2.9259 - 58ms/epoch - 58ms/step\n",
      "Epoch 125/2000\n",
      "1/1 - 0s - loss: 2.9070 - val_loss: 2.9256 - 61ms/epoch - 61ms/step\n",
      "Epoch 126/2000\n",
      "1/1 - 0s - loss: 2.9066 - val_loss: 2.9253 - 75ms/epoch - 75ms/step\n",
      "Epoch 127/2000\n",
      "1/1 - 0s - loss: 2.9062 - val_loss: 2.9250 - 67ms/epoch - 67ms/step\n",
      "Epoch 128/2000\n",
      "1/1 - 0s - loss: 2.9059 - val_loss: 2.9247 - 64ms/epoch - 64ms/step\n",
      "Epoch 129/2000\n",
      "1/1 - 0s - loss: 2.9055 - val_loss: 2.9245 - 59ms/epoch - 59ms/step\n",
      "Epoch 130/2000\n",
      "1/1 - 0s - loss: 2.9051 - val_loss: 2.9241 - 63ms/epoch - 63ms/step\n",
      "Epoch 131/2000\n",
      "1/1 - 0s - loss: 2.9047 - val_loss: 2.9238 - 60ms/epoch - 60ms/step\n",
      "Epoch 132/2000\n",
      "1/1 - 0s - loss: 2.9043 - val_loss: 2.9235 - 67ms/epoch - 67ms/step\n",
      "Epoch 133/2000\n",
      "1/1 - 0s - loss: 2.9039 - val_loss: 2.9232 - 63ms/epoch - 63ms/step\n",
      "Epoch 134/2000\n",
      "1/1 - 0s - loss: 2.9034 - val_loss: 2.9229 - 63ms/epoch - 63ms/step\n",
      "Epoch 135/2000\n",
      "1/1 - 0s - loss: 2.9030 - val_loss: 2.9226 - 62ms/epoch - 62ms/step\n",
      "Epoch 136/2000\n",
      "1/1 - 0s - loss: 2.9026 - val_loss: 2.9223 - 61ms/epoch - 61ms/step\n",
      "Epoch 137/2000\n",
      "1/1 - 0s - loss: 2.9021 - val_loss: 2.9219 - 65ms/epoch - 65ms/step\n",
      "Epoch 138/2000\n",
      "1/1 - 0s - loss: 2.9017 - val_loss: 2.9216 - 75ms/epoch - 75ms/step\n",
      "Epoch 139/2000\n",
      "1/1 - 0s - loss: 2.9012 - val_loss: 2.9212 - 71ms/epoch - 71ms/step\n",
      "Epoch 140/2000\n",
      "1/1 - 0s - loss: 2.9007 - val_loss: 2.9208 - 66ms/epoch - 66ms/step\n",
      "Epoch 141/2000\n",
      "1/1 - 0s - loss: 2.9002 - val_loss: 2.9204 - 63ms/epoch - 63ms/step\n",
      "Epoch 142/2000\n",
      "1/1 - 0s - loss: 2.8996 - val_loss: 2.9200 - 65ms/epoch - 65ms/step\n",
      "Epoch 143/2000\n",
      "1/1 - 0s - loss: 2.8990 - val_loss: 2.9195 - 61ms/epoch - 61ms/step\n",
      "Epoch 144/2000\n",
      "1/1 - 0s - loss: 2.8984 - val_loss: 2.9190 - 63ms/epoch - 63ms/step\n",
      "Epoch 145/2000\n",
      "1/1 - 0s - loss: 2.8978 - val_loss: 2.9185 - 71ms/epoch - 71ms/step\n",
      "Epoch 146/2000\n",
      "1/1 - 0s - loss: 2.8971 - val_loss: 2.9179 - 68ms/epoch - 68ms/step\n",
      "Epoch 147/2000\n",
      "1/1 - 0s - loss: 2.8963 - val_loss: 2.9173 - 69ms/epoch - 69ms/step\n",
      "Epoch 148/2000\n",
      "1/1 - 0s - loss: 2.8954 - val_loss: 2.9166 - 69ms/epoch - 69ms/step\n",
      "Epoch 149/2000\n",
      "1/1 - 0s - loss: 2.8945 - val_loss: 2.9159 - 69ms/epoch - 69ms/step\n",
      "Epoch 150/2000\n",
      "1/1 - 0s - loss: 2.8936 - val_loss: 2.9153 - 67ms/epoch - 67ms/step\n",
      "Epoch 151/2000\n",
      "1/1 - 0s - loss: 2.8928 - val_loss: 2.9147 - 75ms/epoch - 75ms/step\n",
      "Epoch 152/2000\n",
      "1/1 - 0s - loss: 2.8919 - val_loss: 2.9140 - 74ms/epoch - 74ms/step\n",
      "Epoch 153/2000\n",
      "1/1 - 0s - loss: 2.8910 - val_loss: 2.9133 - 75ms/epoch - 75ms/step\n",
      "Epoch 154/2000\n",
      "1/1 - 0s - loss: 2.8900 - val_loss: 2.9126 - 76ms/epoch - 76ms/step\n",
      "Epoch 155/2000\n",
      "1/1 - 0s - loss: 2.8892 - val_loss: 2.9118 - 76ms/epoch - 76ms/step\n",
      "Epoch 156/2000\n",
      "1/1 - 0s - loss: 2.8883 - val_loss: 2.9110 - 76ms/epoch - 76ms/step\n",
      "Epoch 157/2000\n",
      "1/1 - 0s - loss: 2.8873 - val_loss: 2.9101 - 69ms/epoch - 69ms/step\n",
      "Epoch 158/2000\n",
      "1/1 - 0s - loss: 2.8863 - val_loss: 2.9091 - 70ms/epoch - 70ms/step\n",
      "Epoch 159/2000\n",
      "1/1 - 0s - loss: 2.8852 - val_loss: 2.9080 - 69ms/epoch - 69ms/step\n",
      "Epoch 160/2000\n",
      "1/1 - 0s - loss: 2.8841 - val_loss: 2.9069 - 70ms/epoch - 70ms/step\n",
      "Epoch 161/2000\n",
      "1/1 - 0s - loss: 2.8830 - val_loss: 2.9057 - 71ms/epoch - 71ms/step\n",
      "Epoch 162/2000\n",
      "1/1 - 0s - loss: 2.8818 - val_loss: 2.9045 - 71ms/epoch - 71ms/step\n",
      "Epoch 163/2000\n",
      "1/1 - 0s - loss: 2.8805 - val_loss: 2.9032 - 71ms/epoch - 71ms/step\n",
      "Epoch 164/2000\n",
      "1/1 - 0s - loss: 2.8792 - val_loss: 2.9019 - 72ms/epoch - 72ms/step\n",
      "Epoch 165/2000\n",
      "1/1 - 0s - loss: 2.8778 - val_loss: 2.9006 - 72ms/epoch - 72ms/step\n",
      "Epoch 166/2000\n",
      "1/1 - 0s - loss: 2.8764 - val_loss: 2.8992 - 72ms/epoch - 72ms/step\n",
      "Epoch 167/2000\n",
      "1/1 - 0s - loss: 2.8749 - val_loss: 2.8977 - 67ms/epoch - 67ms/step\n",
      "Epoch 168/2000\n",
      "1/1 - 0s - loss: 2.8732 - val_loss: 2.8963 - 64ms/epoch - 64ms/step\n",
      "Epoch 169/2000\n",
      "1/1 - 0s - loss: 2.8715 - val_loss: 2.8947 - 70ms/epoch - 70ms/step\n",
      "Epoch 170/2000\n",
      "1/1 - 0s - loss: 2.8697 - val_loss: 2.8931 - 68ms/epoch - 68ms/step\n",
      "Epoch 171/2000\n",
      "1/1 - 0s - loss: 2.8677 - val_loss: 2.8914 - 62ms/epoch - 62ms/step\n",
      "Epoch 172/2000\n",
      "1/1 - 0s - loss: 2.8657 - val_loss: 2.8897 - 62ms/epoch - 62ms/step\n",
      "Epoch 173/2000\n",
      "1/1 - 0s - loss: 2.8636 - val_loss: 2.8878 - 69ms/epoch - 69ms/step\n",
      "Epoch 174/2000\n",
      "1/1 - 0s - loss: 2.8613 - val_loss: 2.8858 - 83ms/epoch - 83ms/step\n",
      "Epoch 175/2000\n",
      "1/1 - 0s - loss: 2.8589 - val_loss: 2.8837 - 68ms/epoch - 68ms/step\n",
      "Epoch 176/2000\n",
      "1/1 - 0s - loss: 2.8563 - val_loss: 2.8814 - 65ms/epoch - 65ms/step\n",
      "Epoch 177/2000\n",
      "1/1 - 0s - loss: 2.8537 - val_loss: 2.8788 - 69ms/epoch - 69ms/step\n",
      "Epoch 178/2000\n",
      "1/1 - 0s - loss: 2.8508 - val_loss: 2.8761 - 64ms/epoch - 64ms/step\n",
      "Epoch 179/2000\n",
      "1/1 - 0s - loss: 2.8478 - val_loss: 2.8733 - 68ms/epoch - 68ms/step\n",
      "Epoch 180/2000\n",
      "1/1 - 0s - loss: 2.8446 - val_loss: 2.8702 - 73ms/epoch - 73ms/step\n",
      "Epoch 181/2000\n",
      "1/1 - 0s - loss: 2.8413 - val_loss: 2.8669 - 78ms/epoch - 78ms/step\n",
      "Epoch 182/2000\n",
      "1/1 - 0s - loss: 2.8377 - val_loss: 2.8635 - 83ms/epoch - 83ms/step\n",
      "Epoch 183/2000\n",
      "1/1 - 0s - loss: 2.8339 - val_loss: 2.8599 - 69ms/epoch - 69ms/step\n",
      "Epoch 184/2000\n",
      "1/1 - 0s - loss: 2.8299 - val_loss: 2.8560 - 62ms/epoch - 62ms/step\n",
      "Epoch 185/2000\n",
      "1/1 - 0s - loss: 2.8257 - val_loss: 2.8521 - 65ms/epoch - 65ms/step\n",
      "Epoch 186/2000\n",
      "1/1 - 0s - loss: 2.8213 - val_loss: 2.8479 - 63ms/epoch - 63ms/step\n",
      "Epoch 187/2000\n",
      "1/1 - 0s - loss: 2.8166 - val_loss: 2.8435 - 74ms/epoch - 74ms/step\n",
      "Epoch 188/2000\n",
      "1/1 - 0s - loss: 2.8117 - val_loss: 2.8390 - 61ms/epoch - 61ms/step\n",
      "Epoch 189/2000\n",
      "1/1 - 0s - loss: 2.8066 - val_loss: 2.8343 - 67ms/epoch - 67ms/step\n",
      "Epoch 190/2000\n",
      "1/1 - 0s - loss: 2.8013 - val_loss: 2.8295 - 79ms/epoch - 79ms/step\n",
      "Epoch 191/2000\n",
      "1/1 - 0s - loss: 2.7959 - val_loss: 2.8246 - 67ms/epoch - 67ms/step\n",
      "Epoch 192/2000\n",
      "1/1 - 0s - loss: 2.7903 - val_loss: 2.8196 - 61ms/epoch - 61ms/step\n",
      "Epoch 193/2000\n",
      "1/1 - 0s - loss: 2.7846 - val_loss: 2.8146 - 70ms/epoch - 70ms/step\n",
      "Epoch 194/2000\n",
      "1/1 - 0s - loss: 2.7788 - val_loss: 2.8094 - 62ms/epoch - 62ms/step\n",
      "Epoch 195/2000\n",
      "1/1 - 0s - loss: 2.7728 - val_loss: 2.8041 - 64ms/epoch - 64ms/step\n",
      "Epoch 196/2000\n",
      "1/1 - 0s - loss: 2.7668 - val_loss: 2.7988 - 77ms/epoch - 77ms/step\n",
      "Epoch 197/2000\n",
      "1/1 - 0s - loss: 2.7608 - val_loss: 2.7954 - 61ms/epoch - 61ms/step\n",
      "Epoch 198/2000\n",
      "1/1 - 0s - loss: 2.7563 - val_loss: 2.8121 - 65ms/epoch - 65ms/step\n",
      "Epoch 199/2000\n",
      "1/1 - 0s - loss: 2.7746 - val_loss: 2.7929 - 74ms/epoch - 74ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 200/2000\n",
      "1/1 - 0s - loss: 2.7533 - val_loss: 2.7865 - 65ms/epoch - 65ms/step\n",
      "Epoch 201/2000\n",
      "1/1 - 0s - loss: 2.7456 - val_loss: 2.7822 - 58ms/epoch - 58ms/step\n",
      "Epoch 202/2000\n",
      "1/1 - 0s - loss: 2.7410 - val_loss: 2.7761 - 56ms/epoch - 56ms/step\n",
      "Epoch 203/2000\n",
      "1/1 - 0s - loss: 2.7348 - val_loss: 2.7713 - 75ms/epoch - 75ms/step\n",
      "Epoch 204/2000\n",
      "1/1 - 0s - loss: 2.7302 - val_loss: 2.7666 - 59ms/epoch - 59ms/step\n",
      "Epoch 205/2000\n",
      "1/1 - 0s - loss: 2.7255 - val_loss: 2.7601 - 83ms/epoch - 83ms/step\n",
      "Epoch 206/2000\n",
      "1/1 - 0s - loss: 2.7190 - val_loss: 2.7541 - 61ms/epoch - 61ms/step\n",
      "Epoch 207/2000\n",
      "1/1 - 0s - loss: 2.7131 - val_loss: 2.7491 - 60ms/epoch - 60ms/step\n",
      "Epoch 208/2000\n",
      "1/1 - 0s - loss: 2.7083 - val_loss: 2.7433 - 61ms/epoch - 61ms/step\n",
      "Epoch 209/2000\n",
      "1/1 - 0s - loss: 2.7023 - val_loss: 2.7373 - 61ms/epoch - 61ms/step\n",
      "Epoch 210/2000\n",
      "1/1 - 0s - loss: 2.6958 - val_loss: 2.7329 - 60ms/epoch - 60ms/step\n",
      "Epoch 211/2000\n",
      "1/1 - 0s - loss: 2.6908 - val_loss: 2.7283 - 59ms/epoch - 59ms/step\n",
      "Epoch 212/2000\n",
      "1/1 - 0s - loss: 2.6861 - val_loss: 2.7227 - 60ms/epoch - 60ms/step\n",
      "Epoch 213/2000\n",
      "1/1 - 0s - loss: 2.6807 - val_loss: 2.7168 - 56ms/epoch - 56ms/step\n",
      "Epoch 214/2000\n",
      "1/1 - 0s - loss: 2.6757 - val_loss: 2.7104 - 67ms/epoch - 67ms/step\n",
      "Epoch 215/2000\n",
      "1/1 - 0s - loss: 2.6705 - val_loss: 2.7038 - 57ms/epoch - 57ms/step\n",
      "Epoch 216/2000\n",
      "1/1 - 0s - loss: 2.6649 - val_loss: 2.6985 - 55ms/epoch - 55ms/step\n",
      "Epoch 217/2000\n",
      "1/1 - 0s - loss: 2.6602 - val_loss: 2.6936 - 56ms/epoch - 56ms/step\n",
      "Epoch 218/2000\n",
      "1/1 - 0s - loss: 2.6555 - val_loss: 2.6879 - 55ms/epoch - 55ms/step\n",
      "Epoch 219/2000\n",
      "1/1 - 0s - loss: 2.6497 - val_loss: 2.6829 - 55ms/epoch - 55ms/step\n",
      "Epoch 220/2000\n",
      "1/1 - 0s - loss: 2.6446 - val_loss: 2.6782 - 56ms/epoch - 56ms/step\n",
      "Epoch 221/2000\n",
      "1/1 - 0s - loss: 2.6400 - val_loss: 2.6725 - 56ms/epoch - 56ms/step\n",
      "Epoch 222/2000\n",
      "1/1 - 0s - loss: 2.6346 - val_loss: 2.6664 - 56ms/epoch - 56ms/step\n",
      "Epoch 223/2000\n",
      "1/1 - 0s - loss: 2.6295 - val_loss: 2.6604 - 55ms/epoch - 55ms/step\n",
      "Epoch 224/2000\n",
      "1/1 - 0s - loss: 2.6244 - val_loss: 2.6544 - 64ms/epoch - 64ms/step\n",
      "Epoch 225/2000\n",
      "1/1 - 0s - loss: 2.6191 - val_loss: 2.6492 - 56ms/epoch - 56ms/step\n",
      "Epoch 226/2000\n",
      "1/1 - 0s - loss: 2.6142 - val_loss: 2.6443 - 55ms/epoch - 55ms/step\n",
      "Epoch 227/2000\n",
      "1/1 - 0s - loss: 2.6093 - val_loss: 2.6389 - 55ms/epoch - 55ms/step\n",
      "Epoch 228/2000\n",
      "1/1 - 0s - loss: 2.6044 - val_loss: 2.6329 - 50ms/epoch - 50ms/step\n",
      "Epoch 229/2000\n",
      "1/1 - 0s - loss: 2.5994 - val_loss: 2.6265 - 51ms/epoch - 51ms/step\n",
      "Epoch 230/2000\n",
      "1/1 - 0s - loss: 2.5941 - val_loss: 2.6207 - 47ms/epoch - 47ms/step\n",
      "Epoch 231/2000\n",
      "1/1 - 0s - loss: 2.5892 - val_loss: 2.6152 - 50ms/epoch - 50ms/step\n",
      "Epoch 232/2000\n",
      "1/1 - 0s - loss: 2.5839 - val_loss: 2.6101 - 62ms/epoch - 62ms/step\n",
      "Epoch 233/2000\n",
      "1/1 - 0s - loss: 2.5788 - val_loss: 2.6049 - 55ms/epoch - 55ms/step\n",
      "Epoch 234/2000\n",
      "1/1 - 0s - loss: 2.5740 - val_loss: 2.5991 - 81ms/epoch - 81ms/step\n",
      "Epoch 235/2000\n",
      "1/1 - 0s - loss: 2.5689 - val_loss: 2.5934 - 55ms/epoch - 55ms/step\n",
      "Epoch 236/2000\n",
      "1/1 - 0s - loss: 2.5640 - val_loss: 2.5879 - 54ms/epoch - 54ms/step\n",
      "Epoch 237/2000\n",
      "1/1 - 0s - loss: 2.5590 - val_loss: 2.5831 - 55ms/epoch - 55ms/step\n",
      "Epoch 238/2000\n",
      "1/1 - 0s - loss: 2.5542 - val_loss: 2.5784 - 55ms/epoch - 55ms/step\n",
      "Epoch 239/2000\n",
      "1/1 - 0s - loss: 2.5496 - val_loss: 2.5732 - 59ms/epoch - 59ms/step\n",
      "Epoch 240/2000\n",
      "1/1 - 0s - loss: 2.5449 - val_loss: 2.5678 - 57ms/epoch - 57ms/step\n",
      "Epoch 241/2000\n",
      "1/1 - 0s - loss: 2.5402 - val_loss: 2.5626 - 56ms/epoch - 56ms/step\n",
      "Epoch 242/2000\n",
      "1/1 - 0s - loss: 2.5356 - val_loss: 2.5577 - 55ms/epoch - 55ms/step\n",
      "Epoch 243/2000\n",
      "1/1 - 0s - loss: 2.5309 - val_loss: 2.5529 - 62ms/epoch - 62ms/step\n",
      "Epoch 244/2000\n",
      "1/1 - 0s - loss: 2.5263 - val_loss: 2.5480 - 54ms/epoch - 54ms/step\n",
      "Epoch 245/2000\n",
      "1/1 - 0s - loss: 2.5218 - val_loss: 2.5432 - 54ms/epoch - 54ms/step\n",
      "Epoch 246/2000\n",
      "1/1 - 0s - loss: 2.5173 - val_loss: 2.5385 - 52ms/epoch - 52ms/step\n",
      "Epoch 247/2000\n",
      "1/1 - 0s - loss: 2.5128 - val_loss: 2.5342 - 59ms/epoch - 59ms/step\n",
      "Epoch 248/2000\n",
      "1/1 - 0s - loss: 2.5084 - val_loss: 2.5299 - 59ms/epoch - 59ms/step\n",
      "Epoch 249/2000\n",
      "1/1 - 0s - loss: 2.5040 - val_loss: 2.5256 - 45ms/epoch - 45ms/step\n",
      "Epoch 250/2000\n",
      "1/1 - 0s - loss: 2.4996 - val_loss: 2.5211 - 49ms/epoch - 49ms/step\n",
      "Epoch 251/2000\n",
      "1/1 - 0s - loss: 2.4954 - val_loss: 2.5167 - 58ms/epoch - 58ms/step\n",
      "Epoch 252/2000\n",
      "1/1 - 0s - loss: 2.4911 - val_loss: 2.5126 - 51ms/epoch - 51ms/step\n",
      "Epoch 253/2000\n",
      "1/1 - 0s - loss: 2.4870 - val_loss: 2.5083 - 47ms/epoch - 47ms/step\n",
      "Epoch 254/2000\n",
      "1/1 - 0s - loss: 2.4828 - val_loss: 2.5039 - 44ms/epoch - 44ms/step\n",
      "Epoch 255/2000\n",
      "1/1 - 0s - loss: 2.4786 - val_loss: 2.4998 - 55ms/epoch - 55ms/step\n",
      "Epoch 256/2000\n",
      "1/1 - 0s - loss: 2.4745 - val_loss: 2.4958 - 54ms/epoch - 54ms/step\n",
      "Epoch 257/2000\n",
      "1/1 - 0s - loss: 2.4703 - val_loss: 2.4915 - 53ms/epoch - 53ms/step\n",
      "Epoch 258/2000\n",
      "1/1 - 0s - loss: 2.4661 - val_loss: 2.4875 - 55ms/epoch - 55ms/step\n",
      "Epoch 259/2000\n",
      "1/1 - 0s - loss: 2.4620 - val_loss: 2.4836 - 54ms/epoch - 54ms/step\n",
      "Epoch 260/2000\n",
      "1/1 - 0s - loss: 2.4579 - val_loss: 2.4796 - 53ms/epoch - 53ms/step\n",
      "Epoch 261/2000\n",
      "1/1 - 0s - loss: 2.4539 - val_loss: 2.4755 - 54ms/epoch - 54ms/step\n",
      "Epoch 262/2000\n",
      "1/1 - 0s - loss: 2.4498 - val_loss: 2.4712 - 60ms/epoch - 60ms/step\n",
      "Epoch 263/2000\n",
      "1/1 - 0s - loss: 2.4458 - val_loss: 2.4672 - 62ms/epoch - 62ms/step\n",
      "Epoch 264/2000\n",
      "1/1 - 0s - loss: 2.4417 - val_loss: 2.4631 - 61ms/epoch - 61ms/step\n",
      "Epoch 265/2000\n",
      "1/1 - 0s - loss: 2.4377 - val_loss: 2.4589 - 59ms/epoch - 59ms/step\n",
      "Epoch 266/2000\n",
      "1/1 - 0s - loss: 2.4337 - val_loss: 2.4548 - 55ms/epoch - 55ms/step\n",
      "Epoch 267/2000\n",
      "1/1 - 0s - loss: 2.4297 - val_loss: 2.4507 - 56ms/epoch - 56ms/step\n",
      "Epoch 268/2000\n",
      "1/1 - 0s - loss: 2.4257 - val_loss: 2.4466 - 60ms/epoch - 60ms/step\n",
      "Epoch 269/2000\n",
      "1/1 - 0s - loss: 2.4216 - val_loss: 2.4424 - 60ms/epoch - 60ms/step\n",
      "Epoch 270/2000\n",
      "1/1 - 0s - loss: 2.4177 - val_loss: 2.4386 - 73ms/epoch - 73ms/step\n",
      "Epoch 271/2000\n",
      "1/1 - 0s - loss: 2.4139 - val_loss: 2.4347 - 61ms/epoch - 61ms/step\n",
      "Epoch 272/2000\n",
      "1/1 - 0s - loss: 2.4106 - val_loss: 2.4315 - 73ms/epoch - 73ms/step\n",
      "Epoch 273/2000\n",
      "1/1 - 0s - loss: 2.4072 - val_loss: 2.4266 - 63ms/epoch - 63ms/step\n",
      "Epoch 274/2000\n",
      "1/1 - 0s - loss: 2.4032 - val_loss: 2.4219 - 63ms/epoch - 63ms/step\n",
      "Epoch 275/2000\n",
      "1/1 - 0s - loss: 2.3980 - val_loss: 2.4181 - 63ms/epoch - 63ms/step\n",
      "Epoch 276/2000\n",
      "1/1 - 0s - loss: 2.3946 - val_loss: 2.4146 - 67ms/epoch - 67ms/step\n",
      "Epoch 277/2000\n",
      "1/1 - 0s - loss: 2.3917 - val_loss: 2.4101 - 58ms/epoch - 58ms/step\n",
      "Epoch 278/2000\n",
      "1/1 - 0s - loss: 2.3870 - val_loss: 2.4058 - 64ms/epoch - 64ms/step\n",
      "Epoch 279/2000\n",
      "1/1 - 0s - loss: 2.3830 - val_loss: 2.4026 - 64ms/epoch - 64ms/step\n",
      "Epoch 280/2000\n",
      "1/1 - 0s - loss: 2.3800 - val_loss: 2.3983 - 66ms/epoch - 66ms/step\n",
      "Epoch 281/2000\n",
      "1/1 - 0s - loss: 2.3757 - val_loss: 2.3938 - 68ms/epoch - 68ms/step\n",
      "Epoch 282/2000\n",
      "1/1 - 0s - loss: 2.3715 - val_loss: 2.3903 - 64ms/epoch - 64ms/step\n",
      "Epoch 283/2000\n",
      "1/1 - 0s - loss: 2.3682 - val_loss: 2.3865 - 64ms/epoch - 64ms/step\n",
      "Epoch 284/2000\n",
      "1/1 - 0s - loss: 2.3646 - val_loss: 2.3823 - 66ms/epoch - 66ms/step\n",
      "Epoch 285/2000\n",
      "1/1 - 0s - loss: 2.3604 - val_loss: 2.3781 - 61ms/epoch - 61ms/step\n",
      "Epoch 286/2000\n",
      "1/1 - 0s - loss: 2.3565 - val_loss: 2.3746 - 60ms/epoch - 60ms/step\n",
      "Epoch 287/2000\n",
      "1/1 - 0s - loss: 2.3531 - val_loss: 2.3706 - 59ms/epoch - 59ms/step\n",
      "Epoch 288/2000\n",
      "1/1 - 0s - loss: 2.3495 - val_loss: 2.3667 - 60ms/epoch - 60ms/step\n",
      "Epoch 289/2000\n",
      "1/1 - 0s - loss: 2.3455 - val_loss: 2.3623 - 68ms/epoch - 68ms/step\n",
      "Epoch 290/2000\n",
      "1/1 - 0s - loss: 2.3416 - val_loss: 2.3593 - 67ms/epoch - 67ms/step\n",
      "Epoch 291/2000\n",
      "1/1 - 0s - loss: 2.3383 - val_loss: 2.3557 - 67ms/epoch - 67ms/step\n",
      "Epoch 292/2000\n",
      "1/1 - 0s - loss: 2.3352 - val_loss: 2.3523 - 69ms/epoch - 69ms/step\n",
      "Epoch 293/2000\n",
      "1/1 - 0s - loss: 2.3319 - val_loss: 2.3472 - 75ms/epoch - 75ms/step\n",
      "Epoch 294/2000\n",
      "1/1 - 0s - loss: 2.3269 - val_loss: 2.3429 - 71ms/epoch - 71ms/step\n",
      "Epoch 295/2000\n",
      "1/1 - 0s - loss: 2.3227 - val_loss: 2.3394 - 62ms/epoch - 62ms/step\n",
      "Epoch 296/2000\n",
      "1/1 - 0s - loss: 2.3197 - val_loss: 2.3361 - 65ms/epoch - 65ms/step\n",
      "Epoch 297/2000\n",
      "1/1 - 0s - loss: 2.3164 - val_loss: 2.3320 - 85ms/epoch - 85ms/step\n",
      "Epoch 298/2000\n",
      "1/1 - 0s - loss: 2.3125 - val_loss: 2.3272 - 84ms/epoch - 84ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 299/2000\n",
      "1/1 - 0s - loss: 2.3080 - val_loss: 2.3234 - 76ms/epoch - 76ms/step\n",
      "Epoch 300/2000\n",
      "1/1 - 0s - loss: 2.3043 - val_loss: 2.3203 - 74ms/epoch - 74ms/step\n",
      "Epoch 301/2000\n",
      "1/1 - 0s - loss: 2.3013 - val_loss: 2.3166 - 76ms/epoch - 76ms/step\n",
      "Epoch 302/2000\n",
      "1/1 - 0s - loss: 2.2981 - val_loss: 2.3130 - 82ms/epoch - 82ms/step\n",
      "Epoch 303/2000\n",
      "1/1 - 0s - loss: 2.2943 - val_loss: 2.3084 - 85ms/epoch - 85ms/step\n",
      "Epoch 304/2000\n",
      "1/1 - 0s - loss: 2.2900 - val_loss: 2.3048 - 81ms/epoch - 81ms/step\n",
      "Epoch 305/2000\n",
      "1/1 - 0s - loss: 2.2864 - val_loss: 2.3027 - 78ms/epoch - 78ms/step\n",
      "Epoch 306/2000\n",
      "1/1 - 0s - loss: 2.2846 - val_loss: 2.3037 - 79ms/epoch - 79ms/step\n",
      "Epoch 307/2000\n",
      "1/1 - 0s - loss: 2.2856 - val_loss: 2.3064 - 104ms/epoch - 104ms/step\n",
      "Epoch 308/2000\n",
      "1/1 - 0s - loss: 2.2894 - val_loss: 2.2949 - 69ms/epoch - 69ms/step\n",
      "Epoch 309/2000\n",
      "1/1 - 0s - loss: 2.2767 - val_loss: 2.2899 - 74ms/epoch - 74ms/step\n",
      "Epoch 310/2000\n",
      "1/1 - 0s - loss: 2.2726 - val_loss: 2.2908 - 68ms/epoch - 68ms/step\n",
      "Epoch 311/2000\n",
      "1/1 - 0s - loss: 2.2741 - val_loss: 2.2819 - 81ms/epoch - 81ms/step\n",
      "Epoch 312/2000\n",
      "1/1 - 0s - loss: 2.2645 - val_loss: 2.2827 - 67ms/epoch - 67ms/step\n",
      "Epoch 313/2000\n",
      "1/1 - 0s - loss: 2.2661 - val_loss: 2.2768 - 64ms/epoch - 64ms/step\n",
      "Epoch 314/2000\n",
      "1/1 - 0s - loss: 2.2607 - val_loss: 2.2740 - 65ms/epoch - 65ms/step\n",
      "Epoch 315/2000\n",
      "1/1 - 0s - loss: 2.2572 - val_loss: 2.2722 - 77ms/epoch - 77ms/step\n",
      "Epoch 316/2000\n",
      "1/1 - 0s - loss: 2.2559 - val_loss: 2.2656 - 73ms/epoch - 73ms/step\n",
      "Epoch 317/2000\n",
      "1/1 - 0s - loss: 2.2501 - val_loss: 2.2656 - 68ms/epoch - 68ms/step\n",
      "Epoch 318/2000\n",
      "1/1 - 0s - loss: 2.2494 - val_loss: 2.2611 - 72ms/epoch - 72ms/step\n",
      "Epoch 319/2000\n",
      "1/1 - 0s - loss: 2.2449 - val_loss: 2.2579 - 65ms/epoch - 65ms/step\n",
      "Epoch 320/2000\n",
      "1/1 - 0s - loss: 2.2428 - val_loss: 2.2551 - 69ms/epoch - 69ms/step\n",
      "Epoch 321/2000\n",
      "1/1 - 0s - loss: 2.2394 - val_loss: 2.2527 - 66ms/epoch - 66ms/step\n",
      "Epoch 322/2000\n",
      "1/1 - 0s - loss: 2.2366 - val_loss: 2.2494 - 66ms/epoch - 66ms/step\n",
      "Epoch 323/2000\n",
      "1/1 - 0s - loss: 2.2342 - val_loss: 2.2456 - 63ms/epoch - 63ms/step\n",
      "Epoch 324/2000\n",
      "1/1 - 0s - loss: 2.2303 - val_loss: 2.2449 - 63ms/epoch - 63ms/step\n",
      "Epoch 325/2000\n",
      "1/1 - 0s - loss: 2.2289 - val_loss: 2.2403 - 63ms/epoch - 63ms/step\n",
      "Epoch 326/2000\n",
      "1/1 - 0s - loss: 2.2249 - val_loss: 2.2378 - 70ms/epoch - 70ms/step\n",
      "Epoch 327/2000\n",
      "1/1 - 0s - loss: 2.2232 - val_loss: 2.2359 - 67ms/epoch - 67ms/step\n",
      "Epoch 328/2000\n",
      "1/1 - 0s - loss: 2.2203 - val_loss: 2.2325 - 82ms/epoch - 82ms/step\n",
      "Epoch 329/2000\n",
      "1/1 - 0s - loss: 2.2172 - val_loss: 2.2301 - 65ms/epoch - 65ms/step\n",
      "Epoch 330/2000\n",
      "1/1 - 0s - loss: 2.2157 - val_loss: 2.2269 - 62ms/epoch - 62ms/step\n",
      "Epoch 331/2000\n",
      "1/1 - 0s - loss: 2.2120 - val_loss: 2.2253 - 64ms/epoch - 64ms/step\n",
      "Epoch 332/2000\n",
      "1/1 - 0s - loss: 2.2103 - val_loss: 2.2220 - 64ms/epoch - 64ms/step\n",
      "Epoch 333/2000\n",
      "1/1 - 0s - loss: 2.2078 - val_loss: 2.2190 - 70ms/epoch - 70ms/step\n",
      "Epoch 334/2000\n",
      "1/1 - 0s - loss: 2.2046 - val_loss: 2.2180 - 64ms/epoch - 64ms/step\n",
      "Epoch 335/2000\n",
      "1/1 - 0s - loss: 2.2031 - val_loss: 2.2143 - 70ms/epoch - 70ms/step\n",
      "Epoch 336/2000\n",
      "1/1 - 0s - loss: 2.2003 - val_loss: 2.2119 - 67ms/epoch - 67ms/step\n",
      "Epoch 337/2000\n",
      "1/1 - 0s - loss: 2.1977 - val_loss: 2.2107 - 67ms/epoch - 67ms/step\n",
      "Epoch 338/2000\n",
      "1/1 - 0s - loss: 2.1962 - val_loss: 2.2081 - 66ms/epoch - 66ms/step\n",
      "Epoch 339/2000\n",
      "1/1 - 0s - loss: 2.1940 - val_loss: 2.2071 - 60ms/epoch - 60ms/step\n",
      "Epoch 340/2000\n",
      "1/1 - 0s - loss: 2.1934 - val_loss: 2.2091 - 61ms/epoch - 61ms/step\n",
      "Epoch 341/2000\n",
      "1/1 - 0s - loss: 2.1948 - val_loss: 2.2081 - 58ms/epoch - 58ms/step\n",
      "Epoch 342/2000\n",
      "1/1 - 0s - loss: 2.1945 - val_loss: 2.2036 - 65ms/epoch - 65ms/step\n",
      "Epoch 343/2000\n",
      "1/1 - 0s - loss: 2.1897 - val_loss: 2.1978 - 67ms/epoch - 67ms/step\n",
      "Epoch 344/2000\n",
      "1/1 - 0s - loss: 2.1834 - val_loss: 2.1973 - 66ms/epoch - 66ms/step\n",
      "Epoch 345/2000\n",
      "1/1 - 0s - loss: 2.1833 - val_loss: 2.1962 - 65ms/epoch - 65ms/step\n",
      "Epoch 346/2000\n",
      "1/1 - 0s - loss: 2.1824 - val_loss: 2.1909 - 67ms/epoch - 67ms/step\n",
      "Epoch 347/2000\n",
      "1/1 - 0s - loss: 2.1767 - val_loss: 2.1909 - 68ms/epoch - 68ms/step\n",
      "Epoch 348/2000\n",
      "1/1 - 0s - loss: 2.1767 - val_loss: 2.1888 - 72ms/epoch - 72ms/step\n",
      "Epoch 349/2000\n",
      "1/1 - 0s - loss: 2.1750 - val_loss: 2.1842 - 58ms/epoch - 58ms/step\n",
      "Epoch 350/2000\n",
      "1/1 - 0s - loss: 2.1705 - val_loss: 2.1845 - 61ms/epoch - 61ms/step\n",
      "Epoch 351/2000\n",
      "1/1 - 0s - loss: 2.1703 - val_loss: 2.1824 - 61ms/epoch - 61ms/step\n",
      "Epoch 352/2000\n",
      "1/1 - 0s - loss: 2.1684 - val_loss: 2.1790 - 61ms/epoch - 61ms/step\n",
      "Epoch 353/2000\n",
      "1/1 - 0s - loss: 2.1652 - val_loss: 2.1780 - 73ms/epoch - 73ms/step\n",
      "Epoch 354/2000\n",
      "1/1 - 0s - loss: 2.1635 - val_loss: 2.1770 - 62ms/epoch - 62ms/step\n",
      "Epoch 355/2000\n",
      "1/1 - 0s - loss: 2.1625 - val_loss: 2.1739 - 61ms/epoch - 61ms/step\n",
      "Epoch 356/2000\n",
      "1/1 - 0s - loss: 2.1600 - val_loss: 2.1716 - 55ms/epoch - 55ms/step\n",
      "Epoch 357/2000\n",
      "1/1 - 0s - loss: 2.1572 - val_loss: 2.1711 - 58ms/epoch - 58ms/step\n",
      "Epoch 358/2000\n",
      "1/1 - 0s - loss: 2.1567 - val_loss: 2.1684 - 60ms/epoch - 60ms/step\n",
      "Epoch 359/2000\n",
      "1/1 - 0s - loss: 2.1548 - val_loss: 2.1660 - 63ms/epoch - 63ms/step\n",
      "Epoch 360/2000\n",
      "1/1 - 0s - loss: 2.1519 - val_loss: 2.1651 - 63ms/epoch - 63ms/step\n",
      "Epoch 361/2000\n",
      "1/1 - 0s - loss: 2.1508 - val_loss: 2.1631 - 67ms/epoch - 67ms/step\n",
      "Epoch 362/2000\n",
      "1/1 - 0s - loss: 2.1495 - val_loss: 2.1612 - 57ms/epoch - 57ms/step\n",
      "Epoch 363/2000\n",
      "1/1 - 0s - loss: 2.1473 - val_loss: 2.1596 - 60ms/epoch - 60ms/step\n",
      "Epoch 364/2000\n",
      "1/1 - 0s - loss: 2.1454 - val_loss: 2.1574 - 67ms/epoch - 67ms/step\n",
      "Epoch 365/2000\n",
      "1/1 - 0s - loss: 2.1438 - val_loss: 2.1562 - 72ms/epoch - 72ms/step\n",
      "Epoch 366/2000\n",
      "1/1 - 0s - loss: 2.1424 - val_loss: 2.1549 - 59ms/epoch - 59ms/step\n",
      "Epoch 367/2000\n",
      "1/1 - 0s - loss: 2.1409 - val_loss: 2.1526 - 70ms/epoch - 70ms/step\n",
      "Epoch 368/2000\n",
      "1/1 - 0s - loss: 2.1389 - val_loss: 2.1507 - 71ms/epoch - 71ms/step\n",
      "Epoch 369/2000\n",
      "1/1 - 0s - loss: 2.1369 - val_loss: 2.1496 - 76ms/epoch - 76ms/step\n",
      "Epoch 370/2000\n",
      "1/1 - 0s - loss: 2.1357 - val_loss: 2.1481 - 61ms/epoch - 61ms/step\n",
      "Epoch 371/2000\n",
      "1/1 - 0s - loss: 2.1345 - val_loss: 2.1466 - 61ms/epoch - 61ms/step\n",
      "Epoch 372/2000\n",
      "1/1 - 0s - loss: 2.1329 - val_loss: 2.1453 - 61ms/epoch - 61ms/step\n",
      "Epoch 373/2000\n",
      "1/1 - 0s - loss: 2.1314 - val_loss: 2.1432 - 61ms/epoch - 61ms/step\n",
      "Epoch 374/2000\n",
      "1/1 - 0s - loss: 2.1297 - val_loss: 2.1417 - 59ms/epoch - 59ms/step\n",
      "Epoch 375/2000\n",
      "1/1 - 0s - loss: 2.1279 - val_loss: 2.1400 - 59ms/epoch - 59ms/step\n",
      "Epoch 376/2000\n",
      "1/1 - 0s - loss: 2.1263 - val_loss: 2.1385 - 62ms/epoch - 62ms/step\n",
      "Epoch 377/2000\n",
      "1/1 - 0s - loss: 2.1249 - val_loss: 2.1376 - 58ms/epoch - 58ms/step\n",
      "Epoch 378/2000\n",
      "1/1 - 0s - loss: 2.1237 - val_loss: 2.1365 - 58ms/epoch - 58ms/step\n",
      "Epoch 379/2000\n",
      "1/1 - 0s - loss: 2.1230 - val_loss: 2.1388 - 65ms/epoch - 65ms/step\n",
      "Epoch 380/2000\n",
      "1/1 - 0s - loss: 2.1250 - val_loss: 2.1528 - 51ms/epoch - 51ms/step\n",
      "Epoch 381/2000\n",
      "1/1 - 0s - loss: 2.1392 - val_loss: 2.1813 - 49ms/epoch - 49ms/step\n",
      "Epoch 382/2000\n",
      "1/1 - 0s - loss: 2.1681 - val_loss: 2.1905 - 50ms/epoch - 50ms/step\n",
      "Epoch 383/2000\n",
      "1/1 - 0s - loss: 2.1780 - val_loss: 2.1412 - 68ms/epoch - 68ms/step\n"
     ]
    }
   ],
   "source": [
    "# Train a small causal transformer language model; embed_dim / num_head / num_layers are the knobs.\n",
    "embed_dim = 10\n",
    "num_head = 2\n",
    "num_layers = 1  # number of stacked bert_module blocks\n",
    "seq_len = x_masked_train.shape[1]  # single source of truth for the input length\n",
    "\n",
    "# Stop once val_loss has not improved for 5 epochs, restoring the best weights seen.\n",
    "callback = keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)\n",
    "\n",
    "inputs = layers.Input((seq_len,), dtype=tf.int64)\n",
    "word_embeddings = layers.Embedding(N, embed_dim, name=\"word_embedding\")(inputs)\n",
    "# NOTE(review): assumes get_sinusoidal_embeddings returns a (seq_len, embed_dim) array that\n",
    "# broadcasts over the batch axis -- confirm against its definition.\n",
    "sinusoidal_embeddings = get_sinusoidal_embeddings(seq_len, embed_dim)\n",
    "encoder_output = word_embeddings + sinusoidal_embeddings\n",
    "\n",
    "for i in range(num_layers):\n",
    "    encoder_output = bert_module(encoder_output, encoder_output, encoder_output, embed_dim, num_head, i)\n",
    "\n",
    "# Drop the final position's output before the classifier.\n",
    "# NOTE(review): presumably y_masked_labels_train holds one target per remaining position\n",
    "# (seq_len - 1 of them; attention in bert_module is causal) -- confirm the label shape.\n",
    "encoder_output = keras.layers.Lambda(lambda x: x[:, :-1, :], name='slice')(encoder_output)\n",
    "mlm_output = layers.Dense(N, name=\"mlm_cls\", activation=\"softmax\")(encoder_output)\n",
    "mlm_model = keras.Model(inputs=inputs, outputs=mlm_output)\n",
    "\n",
    "adam = Adam()\n",
    "mlm_model.compile(loss='sparse_categorical_crossentropy', optimizer=adam)\n",
    "# NOTE(review): validation_split holds out the *last* 50% of the arrays (taken before any\n",
    "# shuffling), so the split is only representative if the rows are already in random order.\n",
    "history = mlm_model.fit(x_masked_train, y_masked_labels_train,\n",
    "                        validation_split=0.5, callbacks=[callback],\n",
    "                        epochs=2000, batch_size=5000,\n",
    "                        verbose=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "aa33aaff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Next-token accuracy / probability at the last position of 1000 randomly sampled test sequences.\n",
    "x_test_subset = x_masked_test[np.random.choice(x_masked_test.shape[0], size=1000, replace=False)]\n",
    "\n",
    "# One batched forward pass instead of building a new keras.backend.function per sequence.\n",
    "# The model's last output position is scored against each sequence's final token.\n",
    "last_step_probs = mlm_model.predict(x_test_subset, verbose=0)[:, -1, :]\n",
    "targets = x_test_subset[:, -1]\n",
    "\n",
    "# NOTE(review): the target is the last token of the *masked* input. If that slot holds the\n",
    "# mask id, accuracy is ~0 by construction (the saved run reports exactly 0.0); the true token\n",
    "# probably lives in the test label array -- confirm.\n",
    "acc = (last_step_probs.argmax(axis=-1) == targets).astype(int).tolist()\n",
    "prob = list(last_step_probs[np.arange(len(targets)), targets])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "fe4d2103",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.0, 0.0027104116)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tuple(np.mean(values) for values in (acc, prob))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "0005b730",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2000\n",
      "1/1 - 1s - loss: 3.7518 - val_loss: 3.3174 - 1s/epoch - 1s/step\n",
      "Epoch 2/2000\n",
      "1/1 - 0s - loss: 3.3064 - val_loss: 3.1280 - 181ms/epoch - 181ms/step\n",
      "Epoch 3/2000\n",
      "1/1 - 0s - loss: 3.1152 - val_loss: 3.0649 - 187ms/epoch - 187ms/step\n",
      "Epoch 4/2000\n",
      "1/1 - 0s - loss: 3.0495 - val_loss: 3.0413 - 172ms/epoch - 172ms/step\n",
      "Epoch 5/2000\n",
      "1/1 - 0s - loss: 3.0222 - val_loss: 3.0265 - 175ms/epoch - 175ms/step\n",
      "Epoch 6/2000\n",
      "1/1 - 0s - loss: 3.0036 - val_loss: 3.0128 - 196ms/epoch - 196ms/step\n",
      "Epoch 7/2000\n",
      "1/1 - 0s - loss: 2.9867 - val_loss: 2.9943 - 206ms/epoch - 206ms/step\n",
      "Epoch 8/2000\n",
      "1/1 - 0s - loss: 2.9660 - val_loss: 2.9721 - 190ms/epoch - 190ms/step\n",
      "Epoch 9/2000\n",
      "1/1 - 0s - loss: 2.9425 - val_loss: 2.9519 - 161ms/epoch - 161ms/step\n",
      "Epoch 10/2000\n",
      "1/1 - 0s - loss: 2.9212 - val_loss: 2.9382 - 123ms/epoch - 123ms/step\n",
      "Epoch 11/2000\n",
      "1/1 - 0s - loss: 2.9065 - val_loss: 2.9314 - 137ms/epoch - 137ms/step\n",
      "Epoch 12/2000\n",
      "1/1 - 0s - loss: 2.8984 - val_loss: 2.9286 - 136ms/epoch - 136ms/step\n",
      "Epoch 13/2000\n",
      "1/1 - 0s - loss: 2.8942 - val_loss: 2.9267 - 168ms/epoch - 168ms/step\n",
      "Epoch 14/2000\n",
      "1/1 - 0s - loss: 2.8910 - val_loss: 2.9231 - 134ms/epoch - 134ms/step\n",
      "Epoch 15/2000\n",
      "1/1 - 0s - loss: 2.8862 - val_loss: 2.9170 - 129ms/epoch - 129ms/step\n",
      "Epoch 16/2000\n",
      "1/1 - 0s - loss: 2.8788 - val_loss: 2.9094 - 131ms/epoch - 131ms/step\n",
      "Epoch 17/2000\n",
      "1/1 - 0s - loss: 2.8700 - val_loss: 2.9017 - 136ms/epoch - 136ms/step\n",
      "Epoch 18/2000\n",
      "1/1 - 0s - loss: 2.8614 - val_loss: 2.8951 - 115ms/epoch - 115ms/step\n",
      "Epoch 19/2000\n",
      "1/1 - 0s - loss: 2.8542 - val_loss: 2.8898 - 154ms/epoch - 154ms/step\n",
      "Epoch 20/2000\n",
      "1/1 - 0s - loss: 2.8486 - val_loss: 2.8856 - 118ms/epoch - 118ms/step\n",
      "Epoch 21/2000\n",
      "1/1 - 0s - loss: 2.8443 - val_loss: 2.8815 - 157ms/epoch - 157ms/step\n",
      "Epoch 22/2000\n",
      "1/1 - 0s - loss: 2.8403 - val_loss: 2.8770 - 117ms/epoch - 117ms/step\n",
      "Epoch 23/2000\n",
      "1/1 - 0s - loss: 2.8359 - val_loss: 2.8719 - 150ms/epoch - 150ms/step\n",
      "Epoch 24/2000\n",
      "1/1 - 0s - loss: 2.8309 - val_loss: 2.8664 - 119ms/epoch - 119ms/step\n",
      "Epoch 25/2000\n",
      "1/1 - 0s - loss: 2.8252 - val_loss: 2.8608 - 119ms/epoch - 119ms/step\n",
      "Epoch 26/2000\n",
      "1/1 - 0s - loss: 2.8190 - val_loss: 2.8552 - 130ms/epoch - 130ms/step\n",
      "Epoch 27/2000\n",
      "1/1 - 0s - loss: 2.8127 - val_loss: 2.8500 - 128ms/epoch - 128ms/step\n",
      "Epoch 28/2000\n",
      "1/1 - 0s - loss: 2.8065 - val_loss: 2.8449 - 117ms/epoch - 117ms/step\n",
      "Epoch 29/2000\n",
      "1/1 - 0s - loss: 2.8004 - val_loss: 2.8396 - 121ms/epoch - 121ms/step\n",
      "Epoch 30/2000\n",
      "1/1 - 0s - loss: 2.7942 - val_loss: 2.8343 - 128ms/epoch - 128ms/step\n",
      "Epoch 31/2000\n",
      "1/1 - 0s - loss: 2.7881 - val_loss: 2.8292 - 117ms/epoch - 117ms/step\n",
      "Epoch 32/2000\n",
      "1/1 - 0s - loss: 2.7819 - val_loss: 2.8242 - 109ms/epoch - 109ms/step\n",
      "Epoch 33/2000\n",
      "1/1 - 0s - loss: 2.7756 - val_loss: 2.8194 - 121ms/epoch - 121ms/step\n",
      "Epoch 34/2000\n",
      "1/1 - 0s - loss: 2.7694 - val_loss: 2.8142 - 121ms/epoch - 121ms/step\n",
      "Epoch 35/2000\n",
      "1/1 - 0s - loss: 2.7633 - val_loss: 2.8090 - 112ms/epoch - 112ms/step\n",
      "Epoch 36/2000\n",
      "1/1 - 0s - loss: 2.7574 - val_loss: 2.8040 - 128ms/epoch - 128ms/step\n",
      "Epoch 37/2000\n",
      "1/1 - 0s - loss: 2.7518 - val_loss: 2.7992 - 120ms/epoch - 120ms/step\n",
      "Epoch 38/2000\n",
      "1/1 - 0s - loss: 2.7466 - val_loss: 2.7938 - 118ms/epoch - 118ms/step\n",
      "Epoch 39/2000\n",
      "1/1 - 0s - loss: 2.7413 - val_loss: 2.7884 - 128ms/epoch - 128ms/step\n",
      "Epoch 40/2000\n",
      "1/1 - 0s - loss: 2.7363 - val_loss: 2.7830 - 127ms/epoch - 127ms/step\n",
      "Epoch 41/2000\n",
      "1/1 - 0s - loss: 2.7313 - val_loss: 2.7774 - 118ms/epoch - 118ms/step\n",
      "Epoch 42/2000\n",
      "1/1 - 0s - loss: 2.7264 - val_loss: 2.7725 - 124ms/epoch - 124ms/step\n",
      "Epoch 43/2000\n",
      "1/1 - 0s - loss: 2.7219 - val_loss: 2.7680 - 124ms/epoch - 124ms/step\n",
      "Epoch 44/2000\n",
      "1/1 - 0s - loss: 2.7176 - val_loss: 2.7639 - 139ms/epoch - 139ms/step\n",
      "Epoch 45/2000\n",
      "1/1 - 0s - loss: 2.7133 - val_loss: 2.7601 - 136ms/epoch - 136ms/step\n",
      "Epoch 46/2000\n",
      "1/1 - 0s - loss: 2.7092 - val_loss: 2.7566 - 147ms/epoch - 147ms/step\n",
      "Epoch 47/2000\n",
      "1/1 - 0s - loss: 2.7049 - val_loss: 2.7522 - 130ms/epoch - 130ms/step\n",
      "Epoch 48/2000\n",
      "1/1 - 0s - loss: 2.6997 - val_loss: 2.7474 - 127ms/epoch - 127ms/step\n",
      "Epoch 49/2000\n",
      "1/1 - 0s - loss: 2.6939 - val_loss: 2.7422 - 128ms/epoch - 128ms/step\n",
      "Epoch 50/2000\n",
      "1/1 - 0s - loss: 2.6878 - val_loss: 2.7361 - 139ms/epoch - 139ms/step\n",
      "Epoch 51/2000\n",
      "1/1 - 0s - loss: 2.6810 - val_loss: 2.7291 - 133ms/epoch - 133ms/step\n",
      "Epoch 52/2000\n",
      "1/1 - 0s - loss: 2.6734 - val_loss: 2.7212 - 141ms/epoch - 141ms/step\n",
      "Epoch 53/2000\n",
      "1/1 - 0s - loss: 2.6640 - val_loss: 2.7125 - 148ms/epoch - 148ms/step\n",
      "Epoch 54/2000\n",
      "1/1 - 0s - loss: 2.6542 - val_loss: 2.7027 - 127ms/epoch - 127ms/step\n",
      "Epoch 55/2000\n",
      "1/1 - 0s - loss: 2.6439 - val_loss: 2.6918 - 123ms/epoch - 123ms/step\n",
      "Epoch 56/2000\n",
      "1/1 - 0s - loss: 2.6311 - val_loss: 2.6768 - 126ms/epoch - 126ms/step\n",
      "Epoch 57/2000\n",
      "1/1 - 0s - loss: 2.6158 - val_loss: 2.6617 - 137ms/epoch - 137ms/step\n",
      "Epoch 58/2000\n",
      "1/1 - 0s - loss: 2.5997 - val_loss: 2.6466 - 128ms/epoch - 128ms/step\n",
      "Epoch 59/2000\n",
      "1/1 - 0s - loss: 2.5836 - val_loss: 2.6315 - 125ms/epoch - 125ms/step\n",
      "Epoch 60/2000\n",
      "1/1 - 0s - loss: 2.5691 - val_loss: 2.6191 - 153ms/epoch - 153ms/step\n",
      "Epoch 61/2000\n",
      "1/1 - 0s - loss: 2.5547 - val_loss: 2.5955 - 128ms/epoch - 128ms/step\n",
      "Epoch 62/2000\n",
      "1/1 - 0s - loss: 2.5323 - val_loss: 2.5740 - 126ms/epoch - 126ms/step\n",
      "Epoch 63/2000\n",
      "1/1 - 0s - loss: 2.5099 - val_loss: 2.5647 - 164ms/epoch - 164ms/step\n",
      "Epoch 64/2000\n",
      "1/1 - 0s - loss: 2.4989 - val_loss: 2.5416 - 195ms/epoch - 195ms/step\n",
      "Epoch 65/2000\n",
      "1/1 - 0s - loss: 2.4768 - val_loss: 2.5181 - 220ms/epoch - 220ms/step\n",
      "Epoch 66/2000\n",
      "1/1 - 0s - loss: 2.4525 - val_loss: 2.5067 - 173ms/epoch - 173ms/step\n",
      "Epoch 67/2000\n",
      "1/1 - 0s - loss: 2.4390 - val_loss: 2.4820 - 196ms/epoch - 196ms/step\n",
      "Epoch 68/2000\n",
      "1/1 - 0s - loss: 2.4151 - val_loss: 2.4559 - 173ms/epoch - 173ms/step\n",
      "Epoch 69/2000\n",
      "1/1 - 0s - loss: 2.3893 - val_loss: 2.4383 - 182ms/epoch - 182ms/step\n",
      "Epoch 70/2000\n",
      "1/1 - 0s - loss: 2.3704 - val_loss: 2.4108 - 165ms/epoch - 165ms/step\n",
      "Epoch 71/2000\n",
      "1/1 - 0s - loss: 2.3436 - val_loss: 2.3812 - 151ms/epoch - 151ms/step\n",
      "Epoch 72/2000\n",
      "1/1 - 0s - loss: 2.3149 - val_loss: 2.3595 - 184ms/epoch - 184ms/step\n",
      "Epoch 73/2000\n",
      "1/1 - 0s - loss: 2.2930 - val_loss: 2.3324 - 141ms/epoch - 141ms/step\n",
      "Epoch 74/2000\n",
      "1/1 - 0s - loss: 2.2665 - val_loss: 2.3046 - 155ms/epoch - 155ms/step\n",
      "Epoch 75/2000\n",
      "1/1 - 0s - loss: 2.2388 - val_loss: 2.2827 - 127ms/epoch - 127ms/step\n",
      "Epoch 76/2000\n",
      "1/1 - 0s - loss: 2.2165 - val_loss: 2.2590 - 123ms/epoch - 123ms/step\n",
      "Epoch 77/2000\n",
      "1/1 - 0s - loss: 2.1928 - val_loss: 2.2350 - 124ms/epoch - 124ms/step\n",
      "Epoch 78/2000\n",
      "1/1 - 0s - loss: 2.1684 - val_loss: 2.2160 - 151ms/epoch - 151ms/step\n",
      "Epoch 79/2000\n",
      "1/1 - 0s - loss: 2.1488 - val_loss: 2.1973 - 148ms/epoch - 148ms/step\n",
      "Epoch 80/2000\n",
      "1/1 - 0s - loss: 2.1290 - val_loss: 2.1799 - 156ms/epoch - 156ms/step\n",
      "Epoch 81/2000\n",
      "1/1 - 0s - loss: 2.1107 - val_loss: 2.1666 - 150ms/epoch - 150ms/step\n",
      "Epoch 82/2000\n",
      "1/1 - 0s - loss: 2.0967 - val_loss: 2.1543 - 151ms/epoch - 151ms/step\n",
      "Epoch 83/2000\n",
      "1/1 - 0s - loss: 2.0835 - val_loss: 2.1442 - 153ms/epoch - 153ms/step\n",
      "Epoch 84/2000\n",
      "1/1 - 0s - loss: 2.0734 - val_loss: 2.1330 - 170ms/epoch - 170ms/step\n",
      "Epoch 85/2000\n",
      "1/1 - 0s - loss: 2.0636 - val_loss: 2.1237 - 151ms/epoch - 151ms/step\n",
      "Epoch 86/2000\n",
      "1/1 - 0s - loss: 2.0553 - val_loss: 2.1151 - 150ms/epoch - 150ms/step\n",
      "Epoch 87/2000\n",
      "1/1 - 0s - loss: 2.0470 - val_loss: 2.1075 - 163ms/epoch - 163ms/step\n",
      "Epoch 88/2000\n",
      "1/1 - 0s - loss: 2.0401 - val_loss: 2.0992 - 147ms/epoch - 147ms/step\n",
      "Epoch 89/2000\n",
      "1/1 - 0s - loss: 2.0334 - val_loss: 2.0927 - 159ms/epoch - 159ms/step\n",
      "Epoch 90/2000\n",
      "1/1 - 0s - loss: 2.0280 - val_loss: 2.0874 - 154ms/epoch - 154ms/step\n",
      "Epoch 91/2000\n",
      "1/1 - 0s - loss: 2.0229 - val_loss: 2.0822 - 152ms/epoch - 152ms/step\n",
      "Epoch 92/2000\n",
      "1/1 - 0s - loss: 2.0185 - val_loss: 2.0767 - 146ms/epoch - 146ms/step\n",
      "Epoch 93/2000\n",
      "1/1 - 0s - loss: 2.0144 - val_loss: 2.0719 - 146ms/epoch - 146ms/step\n",
      "Epoch 94/2000\n",
      "1/1 - 0s - loss: 2.0103 - val_loss: 2.0680 - 141ms/epoch - 141ms/step\n",
      "Epoch 95/2000\n",
      "1/1 - 0s - loss: 2.0068 - val_loss: 2.0636 - 135ms/epoch - 135ms/step\n",
      "Epoch 96/2000\n",
      "1/1 - 0s - loss: 2.0030 - val_loss: 2.0601 - 141ms/epoch - 141ms/step\n",
      "Epoch 97/2000\n",
      "1/1 - 0s - loss: 1.9998 - val_loss: 2.0571 - 177ms/epoch - 177ms/step\n",
      "Epoch 98/2000\n",
      "1/1 - 0s - loss: 1.9966 - val_loss: 2.0545 - 176ms/epoch - 176ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 99/2000\n",
      "1/1 - 0s - loss: 1.9941 - val_loss: 2.0517 - 176ms/epoch - 176ms/step\n",
      "Epoch 100/2000\n",
      "1/1 - 0s - loss: 1.9916 - val_loss: 2.0493 - 180ms/epoch - 180ms/step\n",
      "Epoch 101/2000\n",
      "1/1 - 0s - loss: 1.9896 - val_loss: 2.0473 - 130ms/epoch - 130ms/step\n",
      "Epoch 102/2000\n",
      "1/1 - 0s - loss: 1.9877 - val_loss: 2.0452 - 143ms/epoch - 143ms/step\n",
      "Epoch 103/2000\n",
      "1/1 - 0s - loss: 1.9861 - val_loss: 2.0430 - 140ms/epoch - 140ms/step\n",
      "Epoch 104/2000\n",
      "1/1 - 0s - loss: 1.9846 - val_loss: 2.0411 - 129ms/epoch - 129ms/step\n",
      "Epoch 105/2000\n",
      "1/1 - 0s - loss: 1.9833 - val_loss: 2.0397 - 138ms/epoch - 138ms/step\n",
      "Epoch 106/2000\n",
      "1/1 - 0s - loss: 1.9821 - val_loss: 2.0383 - 150ms/epoch - 150ms/step\n",
      "Epoch 107/2000\n",
      "1/1 - 0s - loss: 1.9810 - val_loss: 2.0372 - 159ms/epoch - 159ms/step\n",
      "Epoch 108/2000\n",
      "1/1 - 0s - loss: 1.9801 - val_loss: 2.0364 - 128ms/epoch - 128ms/step\n",
      "Epoch 109/2000\n",
      "1/1 - 0s - loss: 1.9792 - val_loss: 2.0358 - 131ms/epoch - 131ms/step\n",
      "Epoch 110/2000\n",
      "1/1 - 0s - loss: 1.9784 - val_loss: 2.0350 - 130ms/epoch - 130ms/step\n",
      "Epoch 111/2000\n",
      "1/1 - 0s - loss: 1.9777 - val_loss: 2.0344 - 136ms/epoch - 136ms/step\n",
      "Epoch 112/2000\n",
      "1/1 - 0s - loss: 1.9770 - val_loss: 2.0338 - 129ms/epoch - 129ms/step\n",
      "Epoch 113/2000\n",
      "1/1 - 0s - loss: 1.9764 - val_loss: 2.0331 - 126ms/epoch - 126ms/step\n",
      "Epoch 114/2000\n",
      "1/1 - 0s - loss: 1.9758 - val_loss: 2.0326 - 129ms/epoch - 129ms/step\n",
      "Epoch 115/2000\n",
      "1/1 - 0s - loss: 1.9753 - val_loss: 2.0321 - 132ms/epoch - 132ms/step\n",
      "Epoch 116/2000\n",
      "1/1 - 0s - loss: 1.9748 - val_loss: 2.0315 - 130ms/epoch - 130ms/step\n",
      "Epoch 117/2000\n",
      "1/1 - 0s - loss: 1.9743 - val_loss: 2.0311 - 124ms/epoch - 124ms/step\n",
      "Epoch 118/2000\n",
      "1/1 - 0s - loss: 1.9739 - val_loss: 2.0308 - 130ms/epoch - 130ms/step\n",
      "Epoch 119/2000\n",
      "1/1 - 0s - loss: 1.9735 - val_loss: 2.0304 - 178ms/epoch - 178ms/step\n",
      "Epoch 120/2000\n",
      "1/1 - 0s - loss: 1.9730 - val_loss: 2.0301 - 171ms/epoch - 171ms/step\n",
      "Epoch 121/2000\n",
      "1/1 - 0s - loss: 1.9727 - val_loss: 2.0298 - 134ms/epoch - 134ms/step\n",
      "Epoch 122/2000\n",
      "1/1 - 0s - loss: 1.9723 - val_loss: 2.0295 - 142ms/epoch - 142ms/step\n",
      "Epoch 123/2000\n",
      "1/1 - 0s - loss: 1.9719 - val_loss: 2.0293 - 128ms/epoch - 128ms/step\n",
      "Epoch 124/2000\n",
      "1/1 - 0s - loss: 1.9716 - val_loss: 2.0290 - 182ms/epoch - 182ms/step\n",
      "Epoch 125/2000\n",
      "1/1 - 0s - loss: 1.9712 - val_loss: 2.0287 - 139ms/epoch - 139ms/step\n",
      "Epoch 126/2000\n",
      "1/1 - 0s - loss: 1.9709 - val_loss: 2.0284 - 189ms/epoch - 189ms/step\n",
      "Epoch 127/2000\n",
      "1/1 - 0s - loss: 1.9705 - val_loss: 2.0282 - 145ms/epoch - 145ms/step\n",
      "Epoch 128/2000\n",
      "1/1 - 0s - loss: 1.9702 - val_loss: 2.0280 - 179ms/epoch - 179ms/step\n",
      "Epoch 129/2000\n",
      "1/1 - 0s - loss: 1.9699 - val_loss: 2.0279 - 180ms/epoch - 180ms/step\n",
      "Epoch 130/2000\n",
      "1/1 - 0s - loss: 1.9696 - val_loss: 2.0278 - 181ms/epoch - 181ms/step\n",
      "Epoch 131/2000\n",
      "1/1 - 0s - loss: 1.9693 - val_loss: 2.0277 - 216ms/epoch - 216ms/step\n",
      "Epoch 132/2000\n",
      "1/1 - 0s - loss: 1.9690 - val_loss: 2.0276 - 222ms/epoch - 222ms/step\n",
      "Epoch 133/2000\n",
      "1/1 - 0s - loss: 1.9687 - val_loss: 2.0276 - 249ms/epoch - 249ms/step\n",
      "Epoch 134/2000\n",
      "1/1 - 0s - loss: 1.9684 - val_loss: 2.0275 - 214ms/epoch - 214ms/step\n",
      "Epoch 135/2000\n",
      "1/1 - 0s - loss: 1.9681 - val_loss: 2.0274 - 311ms/epoch - 311ms/step\n",
      "Epoch 136/2000\n",
      "1/1 - 0s - loss: 1.9679 - val_loss: 2.0274 - 298ms/epoch - 298ms/step\n",
      "Epoch 137/2000\n",
      "1/1 - 0s - loss: 1.9676 - val_loss: 2.0273 - 158ms/epoch - 158ms/step\n",
      "Epoch 138/2000\n",
      "1/1 - 0s - loss: 1.9674 - val_loss: 2.0272 - 168ms/epoch - 168ms/step\n",
      "Epoch 139/2000\n",
      "1/1 - 0s - loss: 1.9671 - val_loss: 2.0272 - 226ms/epoch - 226ms/step\n",
      "Epoch 140/2000\n",
      "1/1 - 0s - loss: 1.9669 - val_loss: 2.0272 - 242ms/epoch - 242ms/step\n",
      "Epoch 141/2000\n",
      "1/1 - 0s - loss: 1.9666 - val_loss: 2.0272 - 279ms/epoch - 279ms/step\n",
      "Epoch 142/2000\n",
      "1/1 - 0s - loss: 1.9664 - val_loss: 2.0272 - 317ms/epoch - 317ms/step\n",
      "Epoch 143/2000\n",
      "1/1 - 0s - loss: 1.9661 - val_loss: 2.0273 - 269ms/epoch - 269ms/step\n",
      "Epoch 144/2000\n",
      "1/1 - 0s - loss: 1.9659 - val_loss: 2.0274 - 259ms/epoch - 259ms/step\n",
      "Epoch 145/2000\n",
      "1/1 - 0s - loss: 1.9657 - val_loss: 2.0274 - 296ms/epoch - 296ms/step\n",
      "Epoch 146/2000\n",
      "1/1 - 0s - loss: 1.9654 - val_loss: 2.0275 - 359ms/epoch - 359ms/step\n"
     ]
    }
   ],
   "source": [
    "# Model hyperparameters\n",
    "embed_dim = 100\n",
    "num_head = 2\n",
    "num_layers = 1  # number of stacked bert_module blocks\n",
    "\n",
    "seq_len = x_masked_train.shape[1]\n",
    "callback = keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)\n",
    "\n",
    "# Token embeddings plus fixed sinusoidal position encodings\n",
    "inputs = layers.Input((seq_len,), dtype=tf.int64)\n",
    "word_embeddings = layers.Embedding(N, embed_dim, name=\"word_embedding\")(inputs)\n",
    "sinusoidal_embeddings = get_sinusoidal_embeddings(seq_len, embed_dim)\n",
    "encoder_output = word_embeddings + sinusoidal_embeddings\n",
    "\n",
    "for i in range(num_layers):\n",
    "    encoder_output = bert_module(encoder_output, encoder_output, encoder_output, embed_dim, num_head, i)\n",
    "\n",
    "# bert_module applies a causal mask, so position t only attends to tokens <= t.\n",
    "# Drop the last position: outputs 0..seq_len-2 are scored against y_masked_labels_train\n",
    "# (presumably the inputs shifted left by one -- TODO confirm against the data-prep cell).\n",
    "encoder_output = keras.layers.Lambda(lambda x: x[:, :-1, :], name='slice')(encoder_output)\n",
    "mlm_output = layers.Dense(N, name=\"mlm_cls\", activation=\"softmax\")(encoder_output)\n",
    "mlm_model = keras.Model(inputs=inputs, outputs=mlm_output)\n",
    "adam = Adam()\n",
    "mlm_model.compile(loss='sparse_categorical_crossentropy', optimizer=adam)\n",
    "\n",
    "# Keras' validation_split takes the *last* fraction of the arrays before any shuffling.\n",
    "# If x_masked_train is ordered like x_masked_test (sorted by token), that half would be a\n",
    "# biased slice, so shuffle once with a dedicated seeded generator (leaves np.random untouched).\n",
    "shuffle_idx = np.random.default_rng(seed).permutation(len(x_masked_train))\n",
    "x_train_shuffled = np.asarray(x_masked_train)[shuffle_idx]\n",
    "y_train_shuffled = np.asarray(y_masked_labels_train)[shuffle_idx]\n",
    "\n",
    "history = mlm_model.fit(x_train_shuffled, y_train_shuffled,\n",
    "                        validation_split=0.5, callbacks=[callback],\n",
    "                        epochs=2000, batch_size=5000,\n",
    "                        verbose=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "beb07cf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate last-token prediction on a random subset of the test sequences.\n",
    "# Cap the sample size so a test set smaller than 1000 rows does not make np.random.choice raise.\n",
    "n_eval = min(1000, x_masked_test.shape[0])\n",
    "x_test_subset = x_masked_test[np.random.choice(x_masked_test.shape[0], size=n_eval, replace=False)]\n",
    "\n",
    "# One batched forward pass instead of building a new backend function per sentence.\n",
    "# Model output is (n_eval, seq_len - 1, N); its last position predicts each sequence's final token.\n",
    "last_token_probs = mlm_model.predict(x_test_subset, verbose=0)[:, -1, :]\n",
    "last_tokens = x_test_subset[:, -1]\n",
    "\n",
    "# acc: 1 where the argmax equals the true final token, else 0\n",
    "# prob: probability the model assigns to the true final token\n",
    "acc = (last_token_probs.argmax(axis=1) == last_tokens).astype(int)\n",
    "prob = last_token_probs[np.arange(n_eval), last_tokens]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "dbdd4a69",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.0, 0.003079601)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# (last-token accuracy, mean probability assigned to the true last token)\n",
    "np.mean(acc), np.mean(prob)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "54f31bac",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[10,  0,  1, 10],\n",
       "       [10,  0,  2, 10],\n",
       "       [10,  0,  3, 10],\n",
       "       ...,\n",
       "       [19, 18, 15, 19],\n",
       "       [19, 18, 16, 19],\n",
       "       [19, 18, 17, 19]])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Peek at the held-out sequences; numpy's repr elides the middle rows\n",
    "x_masked_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "458ae205",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
