{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d1686d53",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy.random as npr\n",
    "import random\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "from tensorflow.keras import backend as K\n",
    "from keras.optimizers import Adam\n",
    "from keras_nlp.layers import PositionEmbedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "56c107e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 428\n",
    "\n",
    "np.random.seed(seed)\n",
    "tf.random.set_seed(seed)\n",
    "random.seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "504a342e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def bert_module(query, key, value, embed_dim, num_head, i):\n",
    "    \n",
    "    # Multi headed self-attention\n",
    "    attention_output = layers.MultiHeadAttention(\n",
    "        num_heads=num_head,\n",
    "        key_dim=embed_dim // num_head,\n",
    "        name=\"encoder_{}/multiheadattention\".format(i)\n",
    "    )(query, key, value, use_causal_mask=True)\n",
    "    \n",
    "    # Add & Normalize\n",
    "    attention_output = layers.Add()([query, attention_output])  # Skip Connection\n",
    "    attention_output = layers.LayerNormalization(epsilon=1e-6)(attention_output)\n",
    "    \n",
    "    # Feedforward network\n",
    "    ff_net = keras.models.Sequential([\n",
    "        layers.Dense(2 * embed_dim, activation='relu', name=\"encoder_{}/ffn_dense_1\".format(i)),\n",
    "        layers.Dense(embed_dim, name=\"encoder_{}/ffn_dense_2\".format(i)),\n",
    "    ])\n",
    "\n",
    "    # Apply Feedforward network\n",
    "    ffn_output = ff_net(attention_output)\n",
    "\n",
    "    # Add & Normalize\n",
    "    ffn_output = layers.Add()([attention_output, ffn_output])  # Skip Connection\n",
    "    ffn_output = layers.LayerNormalization(epsilon=1e-6)(ffn_output)\n",
    "    \n",
    "    return ffn_output"
   ]
  },
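  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3f1a9c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical sanity check (not part of the original run): the block\n",
    "# should preserve the (batch, seq_len, embed_dim) shape of its input.\n",
    "demo = tf.random.normal((2, 4, 10))\n",
    "out = bert_module(demo, demo, demo, embed_dim=10, num_head=2, i=\"demo\")\n",
    "print(out.shape)  # expected: (2, 4, 10)"
   ]
  },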
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fef2622a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_sinusoidal_embeddings(sequence_length, embedding_dim):\n",
    "    position_enc = np.array([\n",
    "        [pos / np.power(10000, 2. * i / embedding_dim) for i in range(embedding_dim)]\n",
    "        if pos != 0 else np.zeros(embedding_dim)\n",
    "        for pos in range(sequence_length)\n",
    "    ])\n",
    "    position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i\n",
    "    position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # dim 2i+1\n",
    "    return tf.cast(position_enc, dtype=tf.float32)"
   ]
  },
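  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4d2e8f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick inspection (hypothetical): the sinusoidal table is defined above but\n",
    "# never added to the model below, so the network receives no position signal.\n",
    "pe = get_sinusoidal_embeddings(4, 10)\n",
    "print(pe.shape)       # (4, 10)\n",
    "print(pe.numpy()[0])  # position 0 is the all-zero vector by construction"
   ]
  },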
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3e5bfd69",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 20 # vocab_size\n",
    "\n",
    "vocabs = ['word_' + str(i) for i in range(N)]\n",
    "\n",
    "vocab_map = {}\n",
    "for i in range(len(vocabs)):\n",
    "    vocab_map[vocabs[i]] = i\n",
    "    \n",
    "pairs = []\n",
    "\n",
    "for i in vocabs:\n",
    "    for j in vocabs:\n",
    "        for k in vocabs:\n",
    "            if i != j and i != k and j != k:\n",
    "                pairs.append((i,j,k))\n",
    "            \n",
    "#indicator = np.random.choice([0, 1], size=len(pairs), p=[0.5, 0.5])\n",
    "\n",
    "# pairs_train = [pairs[i] for i in range(len(indicator)) if indicator[i] == 1]\n",
    "# pairs_test = [pairs[i] for i in range(len(indicator)) if indicator[i] == 0]\n",
    "\n",
    "pairs_train = [x for x in pairs if int(x[0].split('_')[-1]) <= 9]\n",
    "pairs_test = [x for x in pairs if int(x[0].split('_')[-1]) >= 10]"
   ]
  },
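  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7a90b12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split sanity check (hypothetical): 20*19*18 = 6840 ordered triples in all,\n",
    "# partitioned evenly by first token into 3420 train and 3420 test.\n",
    "print(len(pairs), len(pairs_train), len(pairs_test))"
   ]
  },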
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d43093fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "sentences_train = []\n",
    "sentences_number_train = []\n",
    "sentences_test = []\n",
    "sentences_number_test = []\n",
    "\n",
    "for pair in pairs_train:\n",
    "    sentences_train.append([pair[0], pair[1], pair[2], pair[0]])\n",
    "    sentences_number_train.append([vocab_map[pair[0]], vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    \n",
    "for pair in pairs_test:\n",
    "    sentences_test.append([pair[0], pair[1], pair[2], pair[0]])\n",
    "    sentences_number_test.append([vocab_map[pair[0]], vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    \n",
    "x_masked_train = []\n",
    "y_masked_labels_train = []\n",
    "x_masked_test = []\n",
    "y_masked_labels_test = []\n",
    "\n",
    "for pair in pairs_train:\n",
    "    x_masked_train.append([vocab_map[pair[0]], vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    y_masked_labels_train.append([vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    \n",
    "for pair in pairs_test:\n",
    "    x_masked_test.append([vocab_map[pair[0]], vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    y_masked_labels_test.append([vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    \n",
    "x_masked_train = np.array(x_masked_train)\n",
    "y_masked_labels_train = np.array(y_masked_labels_train)\n",
    "x_masked_test = np.array(x_masked_test)\n",
    "y_masked_labels_test = np.array(y_masked_labels_test)\n",
    "\n",
    "perm = np.random.permutation(len(x_masked_train))\n",
    "x_masked_train = x_masked_train[perm]\n",
    "y_masked_labels_train = y_masked_labels_train[perm]"
   ]
  },
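  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0c3d5a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape check (hypothetical): inputs are the 4-token sequences [a, b, c, a],\n",
    "# labels the 3 next-token targets [b, c, a].\n",
    "print(x_masked_train.shape, y_masked_labels_train.shape)  # (3420, 4) (3420, 3)"
   ]
  },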
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "13b40f89",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-05-22 07:07:27.968308: W tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory\n",
      "2024-05-22 07:07:27.968357: W tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: UNKNOWN ERROR (303)\n",
      "2024-05-22 07:07:27.968376: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (gl3035.arc-ts.umich.edu): /proc/driver/nvidia/version does not exist\n",
      "2024-05-22 07:07:27.968579: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA\n",
      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2000\n",
      "1/1 - 4s - loss: 3.1853 - val_loss: 3.1523 - 4s/epoch - 4s/step\n",
      "Epoch 2/2000\n",
      "1/1 - 0s - loss: 3.1355 - val_loss: 3.1253 - 155ms/epoch - 155ms/step\n",
      "Epoch 3/2000\n",
      "1/1 - 0s - loss: 3.1051 - val_loss: 3.1015 - 183ms/epoch - 183ms/step\n",
      "Epoch 4/2000\n",
      "1/1 - 0s - loss: 3.0807 - val_loss: 3.0812 - 158ms/epoch - 158ms/step\n",
      "Epoch 5/2000\n",
      "1/1 - 0s - loss: 3.0590 - val_loss: 3.0637 - 160ms/epoch - 160ms/step\n",
      "Epoch 6/2000\n",
      "1/1 - 0s - loss: 3.0404 - val_loss: 3.0481 - 146ms/epoch - 146ms/step\n",
      "Epoch 7/2000\n",
      "1/1 - 0s - loss: 3.0248 - val_loss: 3.0354 - 139ms/epoch - 139ms/step\n",
      "Epoch 8/2000\n",
      "1/1 - 0s - loss: 3.0119 - val_loss: 3.0251 - 150ms/epoch - 150ms/step\n",
      "Epoch 9/2000\n",
      "1/1 - 0s - loss: 3.0010 - val_loss: 3.0164 - 156ms/epoch - 156ms/step\n",
      "Epoch 10/2000\n",
      "1/1 - 0s - loss: 2.9915 - val_loss: 3.0081 - 166ms/epoch - 166ms/step\n",
      "Epoch 11/2000\n",
      "1/1 - 0s - loss: 2.9822 - val_loss: 2.9999 - 153ms/epoch - 153ms/step\n",
      "Epoch 12/2000\n",
      "1/1 - 0s - loss: 2.9730 - val_loss: 2.9920 - 146ms/epoch - 146ms/step\n",
      "Epoch 13/2000\n",
      "1/1 - 0s - loss: 2.9643 - val_loss: 2.9847 - 122ms/epoch - 122ms/step\n",
      "Epoch 14/2000\n",
      "1/1 - 0s - loss: 2.9566 - val_loss: 2.9783 - 161ms/epoch - 161ms/step\n",
      "Epoch 15/2000\n",
      "1/1 - 0s - loss: 2.9502 - val_loss: 2.9725 - 159ms/epoch - 159ms/step\n",
      "Epoch 16/2000\n",
      "1/1 - 0s - loss: 2.9445 - val_loss: 2.9666 - 149ms/epoch - 149ms/step\n",
      "Epoch 17/2000\n",
      "1/1 - 0s - loss: 2.9389 - val_loss: 2.9605 - 154ms/epoch - 154ms/step\n",
      "Epoch 18/2000\n",
      "1/1 - 0s - loss: 2.9332 - val_loss: 2.9542 - 177ms/epoch - 177ms/step\n",
      "Epoch 19/2000\n",
      "1/1 - 0s - loss: 2.9274 - val_loss: 2.9482 - 156ms/epoch - 156ms/step\n",
      "Epoch 20/2000\n",
      "1/1 - 0s - loss: 2.9221 - val_loss: 2.9427 - 153ms/epoch - 153ms/step\n",
      "Epoch 21/2000\n",
      "1/1 - 0s - loss: 2.9170 - val_loss: 2.9373 - 136ms/epoch - 136ms/step\n",
      "Epoch 22/2000\n",
      "1/1 - 0s - loss: 2.9120 - val_loss: 2.9318 - 161ms/epoch - 161ms/step\n",
      "Epoch 23/2000\n",
      "1/1 - 0s - loss: 2.9066 - val_loss: 2.9263 - 139ms/epoch - 139ms/step\n",
      "Epoch 24/2000\n",
      "1/1 - 0s - loss: 2.9010 - val_loss: 2.9209 - 142ms/epoch - 142ms/step\n",
      "Epoch 25/2000\n",
      "1/1 - 0s - loss: 2.8954 - val_loss: 2.9160 - 127ms/epoch - 127ms/step\n",
      "Epoch 26/2000\n",
      "1/1 - 0s - loss: 2.8903 - val_loss: 2.9113 - 158ms/epoch - 158ms/step\n",
      "Epoch 27/2000\n",
      "1/1 - 0s - loss: 2.8854 - val_loss: 2.9064 - 172ms/epoch - 172ms/step\n",
      "Epoch 28/2000\n",
      "1/1 - 0s - loss: 2.8802 - val_loss: 2.9015 - 155ms/epoch - 155ms/step\n",
      "Epoch 29/2000\n",
      "1/1 - 0s - loss: 2.8749 - val_loss: 2.8967 - 155ms/epoch - 155ms/step\n",
      "Epoch 30/2000\n",
      "1/1 - 0s - loss: 2.8697 - val_loss: 2.8920 - 162ms/epoch - 162ms/step\n",
      "Epoch 31/2000\n",
      "1/1 - 0s - loss: 2.8646 - val_loss: 2.8875 - 157ms/epoch - 157ms/step\n",
      "Epoch 32/2000\n",
      "1/1 - 0s - loss: 2.8596 - val_loss: 2.8831 - 142ms/epoch - 142ms/step\n",
      "Epoch 33/2000\n",
      "1/1 - 0s - loss: 2.8546 - val_loss: 2.8788 - 155ms/epoch - 155ms/step\n",
      "Epoch 34/2000\n",
      "1/1 - 0s - loss: 2.8498 - val_loss: 2.8745 - 134ms/epoch - 134ms/step\n",
      "Epoch 35/2000\n",
      "1/1 - 0s - loss: 2.8451 - val_loss: 2.8701 - 142ms/epoch - 142ms/step\n",
      "Epoch 36/2000\n",
      "1/1 - 0s - loss: 2.8403 - val_loss: 2.8655 - 130ms/epoch - 130ms/step\n",
      "Epoch 37/2000\n",
      "1/1 - 0s - loss: 2.8353 - val_loss: 2.8610 - 146ms/epoch - 146ms/step\n",
      "Epoch 38/2000\n",
      "1/1 - 0s - loss: 2.8304 - val_loss: 2.8566 - 143ms/epoch - 143ms/step\n",
      "Epoch 39/2000\n",
      "1/1 - 0s - loss: 2.8258 - val_loss: 2.8523 - 134ms/epoch - 134ms/step\n",
      "Epoch 40/2000\n",
      "1/1 - 0s - loss: 2.8212 - val_loss: 2.8481 - 150ms/epoch - 150ms/step\n",
      "Epoch 41/2000\n",
      "1/1 - 0s - loss: 2.8167 - val_loss: 2.8439 - 146ms/epoch - 146ms/step\n",
      "Epoch 42/2000\n",
      "1/1 - 0s - loss: 2.8123 - val_loss: 2.8401 - 134ms/epoch - 134ms/step\n",
      "Epoch 43/2000\n",
      "1/1 - 0s - loss: 2.8081 - val_loss: 2.8365 - 149ms/epoch - 149ms/step\n",
      "Epoch 44/2000\n",
      "1/1 - 0s - loss: 2.8042 - val_loss: 2.8329 - 102ms/epoch - 102ms/step\n",
      "Epoch 45/2000\n",
      "1/1 - 0s - loss: 2.8003 - val_loss: 2.8293 - 89ms/epoch - 89ms/step\n",
      "Epoch 46/2000\n",
      "1/1 - 0s - loss: 2.7963 - val_loss: 2.8258 - 89ms/epoch - 89ms/step\n",
      "Epoch 47/2000\n",
      "1/1 - 0s - loss: 2.7924 - val_loss: 2.8225 - 92ms/epoch - 92ms/step\n",
      "Epoch 48/2000\n",
      "1/1 - 0s - loss: 2.7886 - val_loss: 2.8194 - 97ms/epoch - 97ms/step\n",
      "Epoch 49/2000\n",
      "1/1 - 0s - loss: 2.7849 - val_loss: 2.8163 - 91ms/epoch - 91ms/step\n",
      "Epoch 50/2000\n",
      "1/1 - 0s - loss: 2.7813 - val_loss: 2.8131 - 92ms/epoch - 92ms/step\n",
      "Epoch 51/2000\n",
      "1/1 - 0s - loss: 2.7777 - val_loss: 2.8101 - 116ms/epoch - 116ms/step\n",
      "Epoch 52/2000\n",
      "1/1 - 0s - loss: 2.7745 - val_loss: 2.8070 - 98ms/epoch - 98ms/step\n",
      "Epoch 53/2000\n",
      "1/1 - 0s - loss: 2.7713 - val_loss: 2.8039 - 92ms/epoch - 92ms/step\n",
      "Epoch 54/2000\n",
      "1/1 - 0s - loss: 2.7679 - val_loss: 2.8007 - 93ms/epoch - 93ms/step\n",
      "Epoch 55/2000\n",
      "1/1 - 0s - loss: 2.7647 - val_loss: 2.7971 - 89ms/epoch - 89ms/step\n",
      "Epoch 56/2000\n",
      "1/1 - 0s - loss: 2.7614 - val_loss: 2.7937 - 110ms/epoch - 110ms/step\n",
      "Epoch 57/2000\n",
      "1/1 - 0s - loss: 2.7583 - val_loss: 2.7905 - 91ms/epoch - 91ms/step\n",
      "Epoch 58/2000\n",
      "1/1 - 0s - loss: 2.7550 - val_loss: 2.7875 - 96ms/epoch - 96ms/step\n",
      "Epoch 59/2000\n",
      "1/1 - 0s - loss: 2.7520 - val_loss: 2.7843 - 90ms/epoch - 90ms/step\n",
      "Epoch 60/2000\n",
      "1/1 - 0s - loss: 2.7488 - val_loss: 2.7813 - 94ms/epoch - 94ms/step\n",
      "Epoch 61/2000\n",
      "1/1 - 0s - loss: 2.7458 - val_loss: 2.7785 - 92ms/epoch - 92ms/step\n",
      "Epoch 62/2000\n",
      "1/1 - 0s - loss: 2.7428 - val_loss: 2.7754 - 114ms/epoch - 114ms/step\n",
      "Epoch 63/2000\n",
      "1/1 - 0s - loss: 2.7398 - val_loss: 2.7721 - 104ms/epoch - 104ms/step\n",
      "Epoch 64/2000\n",
      "1/1 - 0s - loss: 2.7367 - val_loss: 2.7689 - 102ms/epoch - 102ms/step\n",
      "Epoch 65/2000\n",
      "1/1 - 0s - loss: 2.7337 - val_loss: 2.7659 - 96ms/epoch - 96ms/step\n",
      "Epoch 66/2000\n",
      "1/1 - 0s - loss: 2.7308 - val_loss: 2.7629 - 99ms/epoch - 99ms/step\n",
      "Epoch 67/2000\n",
      "1/1 - 0s - loss: 2.7280 - val_loss: 2.7599 - 112ms/epoch - 112ms/step\n",
      "Epoch 68/2000\n",
      "1/1 - 0s - loss: 2.7251 - val_loss: 2.7570 - 99ms/epoch - 99ms/step\n",
      "Epoch 69/2000\n",
      "1/1 - 0s - loss: 2.7224 - val_loss: 2.7539 - 105ms/epoch - 105ms/step\n",
      "Epoch 70/2000\n",
      "1/1 - 0s - loss: 2.7195 - val_loss: 2.7510 - 95ms/epoch - 95ms/step\n",
      "Epoch 71/2000\n",
      "1/1 - 0s - loss: 2.7167 - val_loss: 2.7478 - 82ms/epoch - 82ms/step\n",
      "Epoch 72/2000\n",
      "1/1 - 0s - loss: 2.7138 - val_loss: 2.7450 - 83ms/epoch - 83ms/step\n",
      "Epoch 73/2000\n",
      "1/1 - 0s - loss: 2.7110 - val_loss: 2.7423 - 84ms/epoch - 84ms/step\n",
      "Epoch 74/2000\n",
      "1/1 - 0s - loss: 2.7082 - val_loss: 2.7393 - 92ms/epoch - 92ms/step\n",
      "Epoch 75/2000\n",
      "1/1 - 0s - loss: 2.7053 - val_loss: 2.7364 - 100ms/epoch - 100ms/step\n",
      "Epoch 76/2000\n",
      "1/1 - 0s - loss: 2.7025 - val_loss: 2.7336 - 97ms/epoch - 97ms/step\n",
      "Epoch 77/2000\n",
      "1/1 - 0s - loss: 2.6996 - val_loss: 2.7308 - 93ms/epoch - 93ms/step\n",
      "Epoch 78/2000\n",
      "1/1 - 0s - loss: 2.6967 - val_loss: 2.7280 - 82ms/epoch - 82ms/step\n",
      "Epoch 79/2000\n",
      "1/1 - 0s - loss: 2.6938 - val_loss: 2.7255 - 90ms/epoch - 90ms/step\n",
      "Epoch 80/2000\n",
      "1/1 - 0s - loss: 2.6909 - val_loss: 2.7227 - 83ms/epoch - 83ms/step\n",
      "Epoch 81/2000\n",
      "1/1 - 0s - loss: 2.6880 - val_loss: 2.7200 - 95ms/epoch - 95ms/step\n",
      "Epoch 82/2000\n",
      "1/1 - 0s - loss: 2.6850 - val_loss: 2.7171 - 84ms/epoch - 84ms/step\n",
      "Epoch 83/2000\n",
      "1/1 - 0s - loss: 2.6820 - val_loss: 2.7140 - 88ms/epoch - 88ms/step\n",
      "Epoch 84/2000\n",
      "1/1 - 0s - loss: 2.6791 - val_loss: 2.7109 - 84ms/epoch - 84ms/step\n",
      "Epoch 85/2000\n",
      "1/1 - 0s - loss: 2.6761 - val_loss: 2.7079 - 83ms/epoch - 83ms/step\n",
      "Epoch 86/2000\n",
      "1/1 - 0s - loss: 2.6731 - val_loss: 2.7049 - 82ms/epoch - 82ms/step\n",
      "Epoch 87/2000\n",
      "1/1 - 0s - loss: 2.6701 - val_loss: 2.7020 - 83ms/epoch - 83ms/step\n",
      "Epoch 88/2000\n",
      "1/1 - 0s - loss: 2.6670 - val_loss: 2.6988 - 85ms/epoch - 85ms/step\n",
      "Epoch 89/2000\n",
      "1/1 - 0s - loss: 2.6639 - val_loss: 2.6962 - 84ms/epoch - 84ms/step\n",
      "Epoch 90/2000\n",
      "1/1 - 0s - loss: 2.6609 - val_loss: 2.6929 - 82ms/epoch - 82ms/step\n",
      "Epoch 91/2000\n",
      "1/1 - 0s - loss: 2.6580 - val_loss: 2.6906 - 82ms/epoch - 82ms/step\n",
      "Epoch 92/2000\n",
      "1/1 - 0s - loss: 2.6549 - val_loss: 2.6869 - 83ms/epoch - 83ms/step\n",
      "Epoch 93/2000\n",
      "1/1 - 0s - loss: 2.6518 - val_loss: 2.6849 - 84ms/epoch - 84ms/step\n",
      "Epoch 94/2000\n",
      "1/1 - 0s - loss: 2.6488 - val_loss: 2.6811 - 86ms/epoch - 86ms/step\n",
      "Epoch 95/2000\n",
      "1/1 - 0s - loss: 2.6460 - val_loss: 2.6795 - 82ms/epoch - 82ms/step\n",
      "Epoch 96/2000\n",
      "1/1 - 0s - loss: 2.6430 - val_loss: 2.6753 - 83ms/epoch - 83ms/step\n",
      "Epoch 97/2000\n",
      "1/1 - 0s - loss: 2.6401 - val_loss: 2.6733 - 83ms/epoch - 83ms/step\n",
      "Epoch 98/2000\n",
      "1/1 - 0s - loss: 2.6367 - val_loss: 2.6683 - 84ms/epoch - 84ms/step\n",
      "Epoch 99/2000\n",
      "1/1 - 0s - loss: 2.6328 - val_loss: 2.6652 - 82ms/epoch - 82ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 100/2000\n",
      "1/1 - 0s - loss: 2.6294 - val_loss: 2.6625 - 83ms/epoch - 83ms/step\n",
      "Epoch 101/2000\n",
      "1/1 - 0s - loss: 2.6264 - val_loss: 2.6590 - 82ms/epoch - 82ms/step\n",
      "Epoch 102/2000\n",
      "1/1 - 0s - loss: 2.6237 - val_loss: 2.6589 - 85ms/epoch - 85ms/step\n",
      "Epoch 103/2000\n",
      "1/1 - 0s - loss: 2.6217 - val_loss: 2.6551 - 81ms/epoch - 81ms/step\n",
      "Epoch 104/2000\n",
      "1/1 - 0s - loss: 2.6202 - val_loss: 2.6541 - 100ms/epoch - 100ms/step\n",
      "Epoch 105/2000\n",
      "1/1 - 0s - loss: 2.6167 - val_loss: 2.6474 - 87ms/epoch - 87ms/step\n",
      "Epoch 106/2000\n",
      "1/1 - 0s - loss: 2.6118 - val_loss: 2.6439 - 84ms/epoch - 84ms/step\n",
      "Epoch 107/2000\n",
      "1/1 - 0s - loss: 2.6074 - val_loss: 2.6438 - 84ms/epoch - 84ms/step\n",
      "Epoch 108/2000\n",
      "1/1 - 0s - loss: 2.6062 - val_loss: 2.6407 - 83ms/epoch - 83ms/step\n",
      "Epoch 109/2000\n",
      "1/1 - 0s - loss: 2.6045 - val_loss: 2.6372 - 83ms/epoch - 83ms/step\n",
      "Epoch 110/2000\n",
      "1/1 - 0s - loss: 2.5994 - val_loss: 2.6327 - 81ms/epoch - 81ms/step\n",
      "Epoch 111/2000\n",
      "1/1 - 0s - loss: 2.5953 - val_loss: 2.6305 - 83ms/epoch - 83ms/step\n",
      "Epoch 112/2000\n",
      "1/1 - 0s - loss: 2.5935 - val_loss: 2.6303 - 82ms/epoch - 82ms/step\n",
      "Epoch 113/2000\n",
      "1/1 - 0s - loss: 2.5916 - val_loss: 2.6253 - 95ms/epoch - 95ms/step\n",
      "Epoch 114/2000\n",
      "1/1 - 0s - loss: 2.5880 - val_loss: 2.6209 - 84ms/epoch - 84ms/step\n",
      "Epoch 115/2000\n",
      "1/1 - 0s - loss: 2.5833 - val_loss: 2.6186 - 83ms/epoch - 83ms/step\n",
      "Epoch 116/2000\n",
      "1/1 - 0s - loss: 2.5808 - val_loss: 2.6168 - 84ms/epoch - 84ms/step\n",
      "Epoch 117/2000\n",
      "1/1 - 0s - loss: 2.5795 - val_loss: 2.6160 - 82ms/epoch - 82ms/step\n",
      "Epoch 118/2000\n",
      "1/1 - 0s - loss: 2.5774 - val_loss: 2.6099 - 82ms/epoch - 82ms/step\n",
      "Epoch 119/2000\n",
      "1/1 - 0s - loss: 2.5725 - val_loss: 2.6070 - 86ms/epoch - 86ms/step\n",
      "Epoch 120/2000\n",
      "1/1 - 0s - loss: 2.5688 - val_loss: 2.6061 - 88ms/epoch - 88ms/step\n",
      "Epoch 121/2000\n",
      "1/1 - 0s - loss: 2.5670 - val_loss: 2.6032 - 82ms/epoch - 82ms/step\n",
      "Epoch 122/2000\n",
      "1/1 - 0s - loss: 2.5656 - val_loss: 2.6007 - 81ms/epoch - 81ms/step\n",
      "Epoch 123/2000\n",
      "1/1 - 0s - loss: 2.5613 - val_loss: 2.5952 - 81ms/epoch - 81ms/step\n",
      "Epoch 124/2000\n",
      "1/1 - 0s - loss: 2.5569 - val_loss: 2.5922 - 88ms/epoch - 88ms/step\n",
      "Epoch 125/2000\n",
      "1/1 - 0s - loss: 2.5545 - val_loss: 2.5927 - 75ms/epoch - 75ms/step\n",
      "Epoch 126/2000\n",
      "1/1 - 0s - loss: 2.5529 - val_loss: 2.5901 - 82ms/epoch - 82ms/step\n",
      "Epoch 127/2000\n",
      "1/1 - 0s - loss: 2.5516 - val_loss: 2.5941 - 75ms/epoch - 75ms/step\n",
      "Epoch 128/2000\n",
      "1/1 - 0s - loss: 2.5535 - val_loss: 2.5861 - 94ms/epoch - 94ms/step\n",
      "Epoch 129/2000\n",
      "1/1 - 0s - loss: 2.5474 - val_loss: 2.5813 - 83ms/epoch - 83ms/step\n",
      "Epoch 130/2000\n",
      "1/1 - 0s - loss: 2.5419 - val_loss: 2.5817 - 78ms/epoch - 78ms/step\n",
      "Epoch 131/2000\n",
      "1/1 - 0s - loss: 2.5419 - val_loss: 2.5785 - 82ms/epoch - 82ms/step\n",
      "Epoch 132/2000\n",
      "1/1 - 0s - loss: 2.5397 - val_loss: 2.5749 - 84ms/epoch - 84ms/step\n",
      "Epoch 133/2000\n",
      "1/1 - 0s - loss: 2.5343 - val_loss: 2.5731 - 82ms/epoch - 82ms/step\n",
      "Epoch 134/2000\n",
      "1/1 - 0s - loss: 2.5322 - val_loss: 2.5684 - 79ms/epoch - 79ms/step\n",
      "Epoch 135/2000\n",
      "1/1 - 0s - loss: 2.5288 - val_loss: 2.5663 - 80ms/epoch - 80ms/step\n",
      "Epoch 136/2000\n",
      "1/1 - 0s - loss: 2.5252 - val_loss: 2.5647 - 83ms/epoch - 83ms/step\n",
      "Epoch 137/2000\n",
      "1/1 - 0s - loss: 2.5231 - val_loss: 2.5615 - 81ms/epoch - 81ms/step\n",
      "Epoch 138/2000\n",
      "1/1 - 0s - loss: 2.5218 - val_loss: 2.5614 - 85ms/epoch - 85ms/step\n",
      "Epoch 139/2000\n",
      "1/1 - 0s - loss: 2.5196 - val_loss: 2.5575 - 83ms/epoch - 83ms/step\n",
      "Epoch 140/2000\n",
      "1/1 - 0s - loss: 2.5154 - val_loss: 2.5547 - 85ms/epoch - 85ms/step\n",
      "Epoch 141/2000\n",
      "1/1 - 0s - loss: 2.5127 - val_loss: 2.5548 - 89ms/epoch - 89ms/step\n",
      "Epoch 142/2000\n",
      "1/1 - 0s - loss: 2.5115 - val_loss: 2.5490 - 86ms/epoch - 86ms/step\n",
      "Epoch 143/2000\n",
      "1/1 - 0s - loss: 2.5077 - val_loss: 2.5464 - 92ms/epoch - 92ms/step\n",
      "Epoch 144/2000\n",
      "1/1 - 0s - loss: 2.5041 - val_loss: 2.5451 - 105ms/epoch - 105ms/step\n",
      "Epoch 145/2000\n",
      "1/1 - 0s - loss: 2.5025 - val_loss: 2.5415 - 96ms/epoch - 96ms/step\n",
      "Epoch 146/2000\n",
      "1/1 - 0s - loss: 2.4993 - val_loss: 2.5405 - 97ms/epoch - 97ms/step\n",
      "Epoch 147/2000\n",
      "1/1 - 0s - loss: 2.4966 - val_loss: 2.5376 - 96ms/epoch - 96ms/step\n",
      "Epoch 148/2000\n",
      "1/1 - 0s - loss: 2.4938 - val_loss: 2.5345 - 97ms/epoch - 97ms/step\n",
      "Epoch 149/2000\n",
      "1/1 - 0s - loss: 2.4907 - val_loss: 2.5330 - 97ms/epoch - 97ms/step\n",
      "Epoch 150/2000\n",
      "1/1 - 0s - loss: 2.4886 - val_loss: 2.5299 - 102ms/epoch - 102ms/step\n",
      "Epoch 151/2000\n",
      "1/1 - 0s - loss: 2.4857 - val_loss: 2.5275 - 99ms/epoch - 99ms/step\n",
      "Epoch 152/2000\n",
      "1/1 - 0s - loss: 2.4830 - val_loss: 2.5249 - 96ms/epoch - 96ms/step\n",
      "Epoch 153/2000\n",
      "1/1 - 0s - loss: 2.4808 - val_loss: 2.5234 - 94ms/epoch - 94ms/step\n",
      "Epoch 154/2000\n",
      "1/1 - 0s - loss: 2.4782 - val_loss: 2.5214 - 91ms/epoch - 91ms/step\n",
      "Epoch 155/2000\n",
      "1/1 - 0s - loss: 2.4770 - val_loss: 2.5212 - 85ms/epoch - 85ms/step\n",
      "Epoch 156/2000\n",
      "1/1 - 0s - loss: 2.4747 - val_loss: 2.5177 - 85ms/epoch - 85ms/step\n",
      "Epoch 157/2000\n",
      "1/1 - 0s - loss: 2.4734 - val_loss: 2.5221 - 78ms/epoch - 78ms/step\n",
      "Epoch 158/2000\n",
      "1/1 - 0s - loss: 2.4742 - val_loss: 2.5183 - 76ms/epoch - 76ms/step\n",
      "Epoch 159/2000\n",
      "1/1 - 0s - loss: 2.4743 - val_loss: 2.5145 - 85ms/epoch - 85ms/step\n",
      "Epoch 160/2000\n",
      "1/1 - 0s - loss: 2.4675 - val_loss: 2.5065 - 84ms/epoch - 84ms/step\n",
      "Epoch 161/2000\n",
      "1/1 - 0s - loss: 2.4604 - val_loss: 2.5098 - 76ms/epoch - 76ms/step\n",
      "Epoch 162/2000\n",
      "1/1 - 0s - loss: 2.4650 - val_loss: 2.5182 - 75ms/epoch - 75ms/step\n",
      "Epoch 163/2000\n",
      "1/1 - 0s - loss: 2.4686 - val_loss: 2.5046 - 83ms/epoch - 83ms/step\n",
      "Epoch 164/2000\n",
      "1/1 - 0s - loss: 2.4592 - val_loss: 2.5024 - 83ms/epoch - 83ms/step\n",
      "Epoch 165/2000\n",
      "1/1 - 0s - loss: 2.4575 - val_loss: 2.5013 - 81ms/epoch - 81ms/step\n",
      "Epoch 166/2000\n",
      "1/1 - 0s - loss: 2.4540 - val_loss: 2.4968 - 83ms/epoch - 83ms/step\n",
      "Epoch 167/2000\n",
      "1/1 - 0s - loss: 2.4495 - val_loss: 2.4946 - 83ms/epoch - 83ms/step\n",
      "Epoch 168/2000\n",
      "1/1 - 0s - loss: 2.4498 - val_loss: 2.4917 - 97ms/epoch - 97ms/step\n",
      "Epoch 169/2000\n",
      "1/1 - 0s - loss: 2.4444 - val_loss: 2.4900 - 103ms/epoch - 103ms/step\n",
      "Epoch 170/2000\n",
      "1/1 - 0s - loss: 2.4429 - val_loss: 2.4859 - 97ms/epoch - 97ms/step\n",
      "Epoch 171/2000\n",
      "1/1 - 0s - loss: 2.4402 - val_loss: 2.4830 - 97ms/epoch - 97ms/step\n",
      "Epoch 172/2000\n",
      "1/1 - 0s - loss: 2.4367 - val_loss: 2.4838 - 89ms/epoch - 89ms/step\n",
      "Epoch 173/2000\n",
      "1/1 - 0s - loss: 2.4344 - val_loss: 2.4803 - 96ms/epoch - 96ms/step\n",
      "Epoch 174/2000\n",
      "1/1 - 0s - loss: 2.4314 - val_loss: 2.4754 - 97ms/epoch - 97ms/step\n",
      "Epoch 175/2000\n",
      "1/1 - 0s - loss: 2.4285 - val_loss: 2.4745 - 95ms/epoch - 95ms/step\n",
      "Epoch 176/2000\n",
      "1/1 - 0s - loss: 2.4263 - val_loss: 2.4714 - 94ms/epoch - 94ms/step\n",
      "Epoch 177/2000\n",
      "1/1 - 0s - loss: 2.4235 - val_loss: 2.4679 - 100ms/epoch - 100ms/step\n",
      "Epoch 178/2000\n",
      "1/1 - 0s - loss: 2.4206 - val_loss: 2.4669 - 99ms/epoch - 99ms/step\n",
      "Epoch 179/2000\n",
      "1/1 - 0s - loss: 2.4180 - val_loss: 2.4643 - 81ms/epoch - 81ms/step\n",
      "Epoch 180/2000\n",
      "1/1 - 0s - loss: 2.4146 - val_loss: 2.4617 - 84ms/epoch - 84ms/step\n",
      "Epoch 181/2000\n",
      "1/1 - 0s - loss: 2.4130 - val_loss: 2.4594 - 90ms/epoch - 90ms/step\n",
      "Epoch 182/2000\n",
      "1/1 - 0s - loss: 2.4100 - val_loss: 2.4552 - 83ms/epoch - 83ms/step\n",
      "Epoch 183/2000\n",
      "1/1 - 0s - loss: 2.4065 - val_loss: 2.4530 - 82ms/epoch - 82ms/step\n",
      "Epoch 184/2000\n",
      "1/1 - 0s - loss: 2.4043 - val_loss: 2.4527 - 81ms/epoch - 81ms/step\n",
      "Epoch 185/2000\n",
      "1/1 - 0s - loss: 2.4025 - val_loss: 2.4480 - 81ms/epoch - 81ms/step\n",
      "Epoch 186/2000\n",
      "1/1 - 0s - loss: 2.3999 - val_loss: 2.4460 - 95ms/epoch - 95ms/step\n",
      "Epoch 187/2000\n",
      "1/1 - 0s - loss: 2.3969 - val_loss: 2.4427 - 82ms/epoch - 82ms/step\n",
      "Epoch 188/2000\n",
      "1/1 - 0s - loss: 2.3940 - val_loss: 2.4404 - 86ms/epoch - 86ms/step\n",
      "Epoch 189/2000\n",
      "1/1 - 0s - loss: 2.3913 - val_loss: 2.4410 - 77ms/epoch - 77ms/step\n",
      "Epoch 190/2000\n",
      "1/1 - 0s - loss: 2.3900 - val_loss: 2.4355 - 86ms/epoch - 86ms/step\n",
      "Epoch 191/2000\n",
      "1/1 - 0s - loss: 2.3870 - val_loss: 2.4361 - 77ms/epoch - 77ms/step\n",
      "Epoch 192/2000\n",
      "1/1 - 0s - loss: 2.3855 - val_loss: 2.4305 - 89ms/epoch - 89ms/step\n",
      "Epoch 193/2000\n",
      "1/1 - 0s - loss: 2.3824 - val_loss: 2.4314 - 85ms/epoch - 85ms/step\n",
      "Epoch 194/2000\n",
      "1/1 - 0s - loss: 2.3805 - val_loss: 2.4264 - 85ms/epoch - 85ms/step\n",
      "Epoch 195/2000\n",
      "1/1 - 0s - loss: 2.3771 - val_loss: 2.4241 - 84ms/epoch - 84ms/step\n",
      "Epoch 196/2000\n",
      "1/1 - 0s - loss: 2.3736 - val_loss: 2.4202 - 82ms/epoch - 82ms/step\n",
      "Epoch 197/2000\n",
      "1/1 - 0s - loss: 2.3702 - val_loss: 2.4181 - 83ms/epoch - 83ms/step\n",
      "Epoch 198/2000\n",
      "1/1 - 0s - loss: 2.3681 - val_loss: 2.4177 - 86ms/epoch - 86ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 199/2000\n",
      "1/1 - 0s - loss: 2.3660 - val_loss: 2.4139 - 83ms/epoch - 83ms/step\n",
      "Epoch 200/2000\n",
      "1/1 - 0s - loss: 2.3633 - val_loss: 2.4145 - 75ms/epoch - 75ms/step\n",
      "Epoch 201/2000\n",
      "1/1 - 0s - loss: 2.3625 - val_loss: 2.4145 - 77ms/epoch - 77ms/step\n",
      "Epoch 202/2000\n",
      "1/1 - 0s - loss: 2.3671 - val_loss: 2.4329 - 76ms/epoch - 76ms/step\n",
      "Epoch 203/2000\n",
      "1/1 - 0s - loss: 2.3777 - val_loss: 2.4325 - 76ms/epoch - 76ms/step\n",
      "Epoch 204/2000\n",
      "1/1 - 0s - loss: 2.3853 - val_loss: 2.4368 - 97ms/epoch - 97ms/step\n"
     ]
    }
   ],
   "source": [
    "embed_dim = 10\n",
    "num_head = 2\n",
    "\n",
    "callback = keras.callbacks.EarlyStopping(monitor = 'val_loss', patience = 5, restore_best_weights = True)\n",
    "inputs = layers.Input((x_masked_train.shape[1],), dtype=tf.int64)\n",
    "word_embeddings = layers.Embedding(N, embed_dim, name=\"word_embedding\")(inputs)\n",
    "encoder_output = word_embeddings\n",
    "\n",
    "for i in range(5):\n",
    "    encoder_output = bert_module(encoder_output, encoder_output, encoder_output, embed_dim, num_head, i)\n",
    "\n",
    "encoder_output = keras.layers.Lambda(lambda x: x[:,:-1,:], name='slice')(encoder_output)\n",
    "mlm_output = layers.Dense(N, name=\"mlm_cls\", activation=\"softmax\")(encoder_output)\n",
    "mlm_model = keras.Model(inputs = inputs, outputs = mlm_output)\n",
    "adam = Adam()\n",
    "mlm_model.compile(loss='sparse_categorical_crossentropy', optimizer=adam)\n",
    "history = mlm_model.fit(x_masked_train, y_masked_labels_train,\n",
    "                        validation_split = 0.5, callbacks = [callback], \n",
    "                        epochs=2000, batch_size=5000, \n",
    "                        verbose=2)"
   ]
  },
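  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8b61d43",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional (hypothetical): plot the loss curves recorded by fit() above.\n",
    "plt.plot(history.history['loss'], label='train')\n",
    "plt.plot(history.history['val_loss'], label='val')\n",
    "plt.xlabel('epoch')\n",
    "plt.ylabel('loss')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },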
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "aa33aaff",
   "metadata": {},
   "outputs": [],
   "source": [
    "acc = []\n",
    "prob = []\n",
    "x_test_subset = x_masked_test[np.random.choice(x_masked_test.shape[0], size=1000, replace=False)]\n",
    "\n",
    "for sentence_number in x_test_subset:\n",
    "    temp = keras.backend.function(inputs = mlm_model.layers[0].input, outputs = mlm_model.layers[-1].output) \\\n",
    "        (np.array(sentence_number).reshape(1,len(sentence_number)))\n",
    "    temp = temp[:,-1,:]\n",
    "    acc.append(1 if temp.argmax() == sentence_number[-1] else 0)\n",
    "    prob.append(temp[0][sentence_number[-1]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "fe4d2103",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.0, 0.0171412)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(np.mean(acc), np.mean(prob))"
   ]
  },
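  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9e4f6b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference point (hypothetical): a uniform guesser over the N = 20 tokens\n",
    "# scores 1/N = 0.05 on both accuracy and assigned probability, so the\n",
    "# held-out numbers above are below chance.\n",
    "print(1 / N)"
   ]
  },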
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "0005b730",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2000\n",
      "1/1 - 4s - loss: 3.8087 - val_loss: 3.2956 - 4s/epoch - 4s/step\n",
      "Epoch 2/2000\n",
      "1/1 - 0s - loss: 3.2599 - val_loss: 3.1145 - 306ms/epoch - 306ms/step\n",
      "Epoch 3/2000\n",
      "1/1 - 0s - loss: 3.0763 - val_loss: 3.0461 - 282ms/epoch - 282ms/step\n",
      "Epoch 4/2000\n",
      "1/1 - 0s - loss: 3.0114 - val_loss: 2.9794 - 286ms/epoch - 286ms/step\n",
      "Epoch 5/2000\n",
      "1/1 - 0s - loss: 2.9468 - val_loss: 2.9573 - 264ms/epoch - 264ms/step\n",
      "Epoch 6/2000\n",
      "1/1 - 0s - loss: 2.9252 - val_loss: 2.9383 - 272ms/epoch - 272ms/step\n",
      "Epoch 7/2000\n",
      "1/1 - 0s - loss: 2.9075 - val_loss: 2.9163 - 275ms/epoch - 275ms/step\n",
      "Epoch 8/2000\n",
      "1/1 - 0s - loss: 2.8870 - val_loss: 2.9022 - 274ms/epoch - 274ms/step\n",
      "Epoch 9/2000\n",
      "1/1 - 0s - loss: 2.8732 - val_loss: 2.8965 - 254ms/epoch - 254ms/step\n",
      "Epoch 10/2000\n",
      "1/1 - 0s - loss: 2.8666 - val_loss: 2.8921 - 247ms/epoch - 247ms/step\n",
      "Epoch 11/2000\n",
      "1/1 - 0s - loss: 2.8602 - val_loss: 2.8821 - 265ms/epoch - 265ms/step\n",
      "Epoch 12/2000\n",
      "1/1 - 0s - loss: 2.8474 - val_loss: 2.8655 - 277ms/epoch - 277ms/step\n",
      "Epoch 13/2000\n",
      "1/1 - 0s - loss: 2.8273 - val_loss: 2.8453 - 291ms/epoch - 291ms/step\n",
      "Epoch 14/2000\n",
      "1/1 - 0s - loss: 2.8026 - val_loss: 2.8237 - 285ms/epoch - 285ms/step\n",
      "Epoch 15/2000\n",
      "1/1 - 0s - loss: 2.7754 - val_loss: 2.8021 - 267ms/epoch - 267ms/step\n",
      "Epoch 16/2000\n",
      "1/1 - 0s - loss: 2.7467 - val_loss: 2.7813 - 265ms/epoch - 265ms/step\n",
      "Epoch 17/2000\n",
      "1/1 - 0s - loss: 2.7173 - val_loss: 2.7613 - 263ms/epoch - 263ms/step\n",
      "Epoch 18/2000\n",
      "1/1 - 0s - loss: 2.6884 - val_loss: 2.7443 - 246ms/epoch - 246ms/step\n",
      "Epoch 19/2000\n",
      "1/1 - 0s - loss: 2.6630 - val_loss: 2.7313 - 292ms/epoch - 292ms/step\n",
      "Epoch 20/2000\n",
      "1/1 - 0s - loss: 2.6433 - val_loss: 2.7167 - 265ms/epoch - 265ms/step\n",
      "Epoch 21/2000\n",
      "1/1 - 0s - loss: 2.6244 - val_loss: 2.6977 - 279ms/epoch - 279ms/step\n",
      "Epoch 22/2000\n",
      "1/1 - 0s - loss: 2.6028 - val_loss: 2.6770 - 308ms/epoch - 308ms/step\n",
      "Epoch 23/2000\n",
      "1/1 - 0s - loss: 2.5805 - val_loss: 2.6573 - 263ms/epoch - 263ms/step\n",
      "Epoch 24/2000\n",
      "1/1 - 0s - loss: 2.5601 - val_loss: 2.6405 - 316ms/epoch - 316ms/step\n",
      "Epoch 25/2000\n",
      "1/1 - 0s - loss: 2.5429 - val_loss: 2.6251 - 259ms/epoch - 259ms/step\n",
      "Epoch 26/2000\n",
      "1/1 - 0s - loss: 2.5264 - val_loss: 2.6092 - 248ms/epoch - 248ms/step\n",
      "Epoch 27/2000\n",
      "1/1 - 0s - loss: 2.5087 - val_loss: 2.5943 - 258ms/epoch - 258ms/step\n",
      "Epoch 28/2000\n",
      "1/1 - 0s - loss: 2.4912 - val_loss: 2.5812 - 248ms/epoch - 248ms/step\n",
      "Epoch 29/2000\n",
      "1/1 - 0s - loss: 2.4758 - val_loss: 2.5691 - 275ms/epoch - 275ms/step\n",
      "Epoch 30/2000\n",
      "1/1 - 0s - loss: 2.4615 - val_loss: 2.5573 - 307ms/epoch - 307ms/step\n",
      "Epoch 31/2000\n",
      "1/1 - 0s - loss: 2.4478 - val_loss: 2.5462 - 279ms/epoch - 279ms/step\n",
      "Epoch 32/2000\n",
      "1/1 - 0s - loss: 2.4347 - val_loss: 2.5364 - 271ms/epoch - 271ms/step\n",
      "Epoch 33/2000\n",
      "1/1 - 0s - loss: 2.4222 - val_loss: 2.5287 - 267ms/epoch - 267ms/step\n",
      "Epoch 34/2000\n",
      "1/1 - 0s - loss: 2.4107 - val_loss: 2.5218 - 274ms/epoch - 274ms/step\n",
      "Epoch 35/2000\n",
      "1/1 - 0s - loss: 2.3997 - val_loss: 2.5151 - 278ms/epoch - 278ms/step\n",
      "Epoch 36/2000\n",
      "1/1 - 0s - loss: 2.3885 - val_loss: 2.5091 - 279ms/epoch - 279ms/step\n",
      "Epoch 37/2000\n",
      "1/1 - 0s - loss: 2.3778 - val_loss: 2.5039 - 273ms/epoch - 273ms/step\n",
      "Epoch 38/2000\n",
      "1/1 - 0s - loss: 2.3675 - val_loss: 2.4984 - 301ms/epoch - 301ms/step\n",
      "Epoch 39/2000\n",
      "1/1 - 0s - loss: 2.3565 - val_loss: 2.4927 - 261ms/epoch - 261ms/step\n",
      "Epoch 40/2000\n",
      "1/1 - 0s - loss: 2.3449 - val_loss: 2.4867 - 264ms/epoch - 264ms/step\n",
      "Epoch 41/2000\n",
      "1/1 - 0s - loss: 2.3328 - val_loss: 2.4802 - 253ms/epoch - 253ms/step\n",
      "Epoch 42/2000\n",
      "1/1 - 0s - loss: 2.3201 - val_loss: 2.4732 - 256ms/epoch - 256ms/step\n",
      "Epoch 43/2000\n",
      "1/1 - 0s - loss: 2.3066 - val_loss: 2.4658 - 247ms/epoch - 247ms/step\n",
      "Epoch 44/2000\n",
      "1/1 - 0s - loss: 2.2924 - val_loss: 2.4575 - 253ms/epoch - 253ms/step\n",
      "Epoch 45/2000\n",
      "1/1 - 0s - loss: 2.2772 - val_loss: 2.4475 - 249ms/epoch - 249ms/step\n",
      "Epoch 46/2000\n",
      "1/1 - 0s - loss: 2.2605 - val_loss: 2.4358 - 249ms/epoch - 249ms/step\n",
      "Epoch 47/2000\n",
      "1/1 - 0s - loss: 2.2426 - val_loss: 2.4230 - 255ms/epoch - 255ms/step\n",
      "Epoch 48/2000\n",
      "1/1 - 0s - loss: 2.2236 - val_loss: 2.4095 - 262ms/epoch - 262ms/step\n",
      "Epoch 49/2000\n",
      "1/1 - 0s - loss: 2.2039 - val_loss: 2.3947 - 254ms/epoch - 254ms/step\n",
      "Epoch 50/2000\n",
      "1/1 - 0s - loss: 2.1833 - val_loss: 2.3792 - 249ms/epoch - 249ms/step\n",
      "Epoch 51/2000\n",
      "1/1 - 0s - loss: 2.1621 - val_loss: 2.3652 - 251ms/epoch - 251ms/step\n",
      "Epoch 52/2000\n",
      "1/1 - 0s - loss: 2.1425 - val_loss: 2.3647 - 267ms/epoch - 267ms/step\n",
      "Epoch 53/2000\n",
      "1/1 - 0s - loss: 2.1386 - val_loss: 2.3810 - 249ms/epoch - 249ms/step\n",
      "Epoch 54/2000\n",
      "1/1 - 0s - loss: 2.1505 - val_loss: 2.3532 - 258ms/epoch - 258ms/step\n",
      "Epoch 55/2000\n",
      "1/1 - 0s - loss: 2.1155 - val_loss: 2.3167 - 254ms/epoch - 254ms/step\n",
      "Epoch 56/2000\n",
      "1/1 - 0s - loss: 2.0732 - val_loss: 2.3242 - 250ms/epoch - 250ms/step\n",
      "Epoch 57/2000\n",
      "1/1 - 0s - loss: 2.0759 - val_loss: 2.2921 - 256ms/epoch - 256ms/step\n",
      "Epoch 58/2000\n",
      "1/1 - 0s - loss: 2.0378 - val_loss: 2.2992 - 244ms/epoch - 244ms/step\n",
      "Epoch 59/2000\n",
      "1/1 - 0s - loss: 2.0372 - val_loss: 2.2775 - 254ms/epoch - 254ms/step\n",
      "Epoch 60/2000\n",
      "1/1 - 0s - loss: 2.0081 - val_loss: 2.2781 - 235ms/epoch - 235ms/step\n",
      "Epoch 61/2000\n",
      "1/1 - 0s - loss: 2.0031 - val_loss: 2.2595 - 321ms/epoch - 321ms/step\n",
      "Epoch 62/2000\n",
      "1/1 - 0s - loss: 1.9777 - val_loss: 2.2649 - 356ms/epoch - 356ms/step\n",
      "Epoch 63/2000\n",
      "1/1 - 0s - loss: 1.9752 - val_loss: 2.2516 - 266ms/epoch - 266ms/step\n",
      "Epoch 64/2000\n",
      "1/1 - 0s - loss: 1.9530 - val_loss: 2.2541 - 261ms/epoch - 261ms/step\n",
      "Epoch 65/2000\n",
      "1/1 - 0s - loss: 1.9482 - val_loss: 2.2433 - 255ms/epoch - 255ms/step\n",
      "Epoch 66/2000\n",
      "1/1 - 0s - loss: 1.9317 - val_loss: 2.2427 - 272ms/epoch - 272ms/step\n",
      "Epoch 67/2000\n",
      "1/1 - 0s - loss: 1.9245 - val_loss: 2.2421 - 257ms/epoch - 257ms/step\n",
      "Epoch 68/2000\n",
      "1/1 - 0s - loss: 1.9133 - val_loss: 2.2403 - 256ms/epoch - 256ms/step\n",
      "Epoch 69/2000\n",
      "1/1 - 0s - loss: 1.9042 - val_loss: 2.2389 - 275ms/epoch - 275ms/step\n",
      "Epoch 70/2000\n",
      "1/1 - 0s - loss: 1.8962 - val_loss: 2.2421 - 267ms/epoch - 267ms/step\n",
      "Epoch 71/2000\n",
      "1/1 - 0s - loss: 1.8891 - val_loss: 2.2473 - 250ms/epoch - 250ms/step\n",
      "Epoch 72/2000\n",
      "1/1 - 0s - loss: 1.8811 - val_loss: 2.2506 - 241ms/epoch - 241ms/step\n",
      "Epoch 73/2000\n",
      "1/1 - 0s - loss: 1.8749 - val_loss: 2.2489 - 261ms/epoch - 261ms/step\n",
      "Epoch 74/2000\n",
      "1/1 - 0s - loss: 1.8693 - val_loss: 2.2540 - 279ms/epoch - 279ms/step\n"
     ]
    }
   ],
   "source": [
    "embed_dim = 100\n",
    "num_head = 2\n",
    "\n",
    "callback = keras.callbacks.EarlyStopping(monitor = 'val_loss', patience = 5, restore_best_weights = True)\n",
    "inputs = layers.Input((x_masked_train.shape[1],), dtype=tf.int64)\n",
    "word_embeddings = layers.Embedding(N, embed_dim, name=\"word_embedding\")(inputs)\n",
    "encoder_output = word_embeddings\n",
    "\n",
    "for i in range(5):\n",
    "    encoder_output = bert_module(encoder_output, encoder_output, encoder_output, embed_dim, num_head, i)\n",
    "\n",
    "encoder_output = keras.layers.Lambda(lambda x: x[:,:-1,:], name='slice')(encoder_output)\n",
    "mlm_output = layers.Dense(N, name=\"mlm_cls\", activation=\"softmax\")(encoder_output)\n",
    "mlm_model = keras.Model(inputs = inputs, outputs = mlm_output)\n",
    "adam = Adam()\n",
    "mlm_model.compile(loss='sparse_categorical_crossentropy', optimizer=adam)\n",
    "history = mlm_model.fit(x_masked_train, y_masked_labels_train,\n",
    "                        validation_split = 0.5, callbacks = [callback], \n",
    "                        epochs=2000, batch_size=5000, \n",
    "                        verbose=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "beb07cf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "acc = []\n",
    "prob = []\n",
    "x_test_subset = x_masked_test[np.random.choice(x_masked_test.shape[0], size=1000, replace=False)]\n",
    "\n",
    "for sentence_number in x_test_subset:\n",
    "    temp = keras.backend.function(inputs = mlm_model.layers[0].input, outputs = mlm_model.layers[-1].output) \\\n",
    "        (np.array(sentence_number).reshape(1,len(sentence_number)))\n",
    "    temp = temp[:,-1,:]\n",
    "    acc.append(1 if temp.argmax() == sentence_number[-1] else 0)\n",
    "    prob.append(temp[0][sentence_number[-1]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "dbdd4a69",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.0, 0.0019487024)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(np.mean(acc), np.mean(prob))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "54f31bac",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[10,  0,  1, 10],\n",
       "       [10,  0,  2, 10],\n",
       "       [10,  0,  3, 10],\n",
       "       ...,\n",
       "       [19, 18, 15, 19],\n",
       "       [19, 18, 16, 19],\n",
       "       [19, 18, 17, 19]])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x_masked_test"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
