{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d1686d53",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy.random as npr\n",
    "import random\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "from tensorflow.keras import backend as K\n",
    "from keras.optimizers import Adam\n",
    "from keras_nlp.layers import PositionEmbedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "56c107e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 428\n",
    "\n",
    "np.random.seed(seed)\n",
    "tf.random.set_seed(seed)\n",
    "random.seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "id": "504a342e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def bert_module(query, key, value, embed_dim, num_head, i):\n",
    "    \n",
    "    # Multi headed self-attention\n",
    "    attention_output = layers.MultiHeadAttention(\n",
    "        num_heads=num_head,\n",
    "        key_dim=embed_dim // num_head,\n",
    "        name=\"encoder_{}/multiheadattention\".format(i)\n",
    "    )(query, key, value, use_causal_mask=True)\n",
    "    \n",
    "    # Add & Normalize\n",
    "    attention_output = layers.Add()([query, attention_output])  # Skip Connection\n",
    "    attention_output = layers.LayerNormalization(epsilon=1e-6)(attention_output)\n",
    "    \n",
    "    #Feedforward network\n",
    "    ff_net = keras.models.Sequential([\n",
    "        layers.Dense(2 * embed_dim, activation='relu', name=\"encoder_{}/ffn_dense_1\".format(i)),\n",
    "        layers.Dense(embed_dim, name=\"encoder_{}/ffn_dense_2\".format(i)),\n",
    "    ])\n",
    "\n",
    "    # Apply Feedforward network\n",
    "    ffn_output = ff_net(attention_output)\n",
    "\n",
    "    # Add & Normalize\n",
    "    ffn_output = layers.Add()([attention_output, ffn_output])  # Skip Connection\n",
    "    ffn_output = layers.LayerNormalization(epsilon=1e-6)(ffn_output)\n",
    "    \n",
    "    return ffn_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "id": "fef2622a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_sinusoidal_embeddings(sequence_length, embedding_dim):\n",
    "    position_enc = np.array([\n",
    "        [pos / np.power(10000, 2. * i / embedding_dim) for i in range(embedding_dim)]\n",
    "        if pos != 0 else np.zeros(embedding_dim)\n",
    "        for pos in range(sequence_length)\n",
    "    ])\n",
    "    position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i\n",
    "    position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # dim 2i+1\n",
    "    return tf.cast(position_enc, dtype=tf.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "id": "3e5bfd69",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 20 # vocab_size\n",
    "\n",
    "vocabs = ['word_' + str(i) for i in range(N)]\n",
    "\n",
    "vocab_map = {}\n",
    "for i in range(len(vocabs)):\n",
    "    vocab_map[vocabs[i]] = i\n",
    "    \n",
    "pairs = []\n",
    "\n",
    "for i in vocabs:\n",
    "    for j in vocabs:\n",
    "        for k in vocabs:\n",
    "            if i != j and i != k and j != k:\n",
    "                pairs.append((i,j,k))\n",
    "            \n",
    "indicator = np.random.choice([0, 1], size=len(pairs), p=[0.5, 0.5])\n",
    "\n",
    "pairs_train = [pairs[i] for i in range(len(indicator)) if indicator[i] == 1]\n",
    "pairs_test = [pairs[i] for i in range(len(indicator)) if indicator[i] == 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "id": "d43093fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "sentences_train = []\n",
    "sentences_number_train = []\n",
    "sentences_test = []\n",
    "sentences_number_test = []\n",
    "\n",
    "for pair in pairs_train:\n",
    "    sentences_train.append([pair[0], pair[1], pair[2], pair[0]])\n",
    "    sentences_number_train.append([vocab_map[pair[0]], vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    \n",
    "for pair in pairs_test:\n",
    "    sentences_test.append([pair[0], pair[1], pair[2], pair[0]])\n",
    "    sentences_number_test.append([vocab_map[pair[0]], vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    \n",
    "x_masked_train = []\n",
    "y_masked_labels_train = []\n",
    "x_masked_test = []\n",
    "y_masked_labels_test = []\n",
    "\n",
    "for pair in pairs_train:\n",
    "    x_masked_train.append([vocab_map[pair[0]], vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    y_masked_labels_train.append([vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    \n",
    "for pair in pairs_test:\n",
    "    x_masked_test.append([vocab_map[pair[0]], vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    y_masked_labels_test.append([vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    \n",
    "x_masked_train = np.array(x_masked_train)\n",
    "y_masked_labels_train = np.array(y_masked_labels_train)\n",
    "x_masked_test = np.array(x_masked_test)\n",
    "y_masked_labels_test = np.array(y_masked_labels_test)\n",
    "\n",
    "perm = np.random.permutation(len(x_masked_train))\n",
    "x_masked_train = x_masked_train[perm]\n",
    "y_masked_labels_train = y_masked_labels_train[perm]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "id": "13b40f89",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2000\n",
      "1/1 - 1s - loss: 3.2522 - val_loss: 3.2352 - 1s/epoch - 1s/step\n",
      "Epoch 2/2000\n",
      "1/1 - 0s - loss: 3.2310 - val_loss: 3.2201 - 37ms/epoch - 37ms/step\n",
      "Epoch 3/2000\n",
      "1/1 - 0s - loss: 3.2127 - val_loss: 3.2063 - 36ms/epoch - 36ms/step\n",
      "Epoch 4/2000\n",
      "1/1 - 0s - loss: 3.1960 - val_loss: 3.1936 - 38ms/epoch - 38ms/step\n",
      "Epoch 5/2000\n",
      "1/1 - 0s - loss: 3.1808 - val_loss: 3.1818 - 40ms/epoch - 40ms/step\n",
      "Epoch 6/2000\n",
      "1/1 - 0s - loss: 3.1669 - val_loss: 3.1708 - 40ms/epoch - 40ms/step\n",
      "Epoch 7/2000\n",
      "1/1 - 0s - loss: 3.1541 - val_loss: 3.1606 - 37ms/epoch - 37ms/step\n",
      "Epoch 8/2000\n",
      "1/1 - 0s - loss: 3.1423 - val_loss: 3.1512 - 37ms/epoch - 37ms/step\n",
      "Epoch 9/2000\n",
      "1/1 - 0s - loss: 3.1314 - val_loss: 3.1425 - 37ms/epoch - 37ms/step\n",
      "Epoch 10/2000\n",
      "1/1 - 0s - loss: 3.1211 - val_loss: 3.1343 - 39ms/epoch - 39ms/step\n",
      "Epoch 11/2000\n",
      "1/1 - 0s - loss: 3.1116 - val_loss: 3.1267 - 39ms/epoch - 39ms/step\n",
      "Epoch 12/2000\n",
      "1/1 - 0s - loss: 3.1027 - val_loss: 3.1197 - 38ms/epoch - 38ms/step\n",
      "Epoch 13/2000\n",
      "1/1 - 0s - loss: 3.0945 - val_loss: 3.1133 - 36ms/epoch - 36ms/step\n",
      "Epoch 14/2000\n",
      "1/1 - 0s - loss: 3.0869 - val_loss: 3.1073 - 36ms/epoch - 36ms/step\n",
      "Epoch 15/2000\n",
      "1/1 - 0s - loss: 3.0799 - val_loss: 3.1017 - 39ms/epoch - 39ms/step\n",
      "Epoch 16/2000\n",
      "1/1 - 0s - loss: 3.0733 - val_loss: 3.0965 - 40ms/epoch - 40ms/step\n",
      "Epoch 17/2000\n",
      "1/1 - 0s - loss: 3.0672 - val_loss: 3.0916 - 38ms/epoch - 38ms/step\n",
      "Epoch 18/2000\n",
      "1/1 - 0s - loss: 3.0616 - val_loss: 3.0869 - 36ms/epoch - 36ms/step\n",
      "Epoch 19/2000\n",
      "1/1 - 0s - loss: 3.0563 - val_loss: 3.0825 - 36ms/epoch - 36ms/step\n",
      "Epoch 20/2000\n",
      "1/1 - 0s - loss: 3.0515 - val_loss: 3.0783 - 39ms/epoch - 39ms/step\n",
      "Epoch 21/2000\n",
      "1/1 - 0s - loss: 3.0469 - val_loss: 3.0743 - 39ms/epoch - 39ms/step\n",
      "Epoch 22/2000\n",
      "1/1 - 0s - loss: 3.0428 - val_loss: 3.0706 - 38ms/epoch - 38ms/step\n",
      "Epoch 23/2000\n",
      "1/1 - 0s - loss: 3.0389 - val_loss: 3.0671 - 36ms/epoch - 36ms/step\n",
      "Epoch 24/2000\n",
      "1/1 - 0s - loss: 3.0353 - val_loss: 3.0638 - 36ms/epoch - 36ms/step\n",
      "Epoch 25/2000\n",
      "1/1 - 0s - loss: 3.0320 - val_loss: 3.0608 - 38ms/epoch - 38ms/step\n",
      "Epoch 26/2000\n",
      "1/1 - 0s - loss: 3.0289 - val_loss: 3.0578 - 39ms/epoch - 39ms/step\n",
      "Epoch 27/2000\n",
      "1/1 - 0s - loss: 3.0259 - val_loss: 3.0551 - 39ms/epoch - 39ms/step\n",
      "Epoch 28/2000\n",
      "1/1 - 0s - loss: 3.0231 - val_loss: 3.0524 - 36ms/epoch - 36ms/step\n",
      "Epoch 29/2000\n",
      "1/1 - 0s - loss: 3.0204 - val_loss: 3.0499 - 36ms/epoch - 36ms/step\n",
      "Epoch 30/2000\n",
      "1/1 - 0s - loss: 3.0178 - val_loss: 3.0474 - 38ms/epoch - 38ms/step\n",
      "Epoch 31/2000\n",
      "1/1 - 0s - loss: 3.0152 - val_loss: 3.0451 - 39ms/epoch - 39ms/step\n",
      "Epoch 32/2000\n",
      "1/1 - 0s - loss: 3.0128 - val_loss: 3.0428 - 39ms/epoch - 39ms/step\n",
      "Epoch 33/2000\n",
      "1/1 - 0s - loss: 3.0104 - val_loss: 3.0406 - 37ms/epoch - 37ms/step\n",
      "Epoch 34/2000\n",
      "1/1 - 0s - loss: 3.0080 - val_loss: 3.0384 - 36ms/epoch - 36ms/step\n",
      "Epoch 35/2000\n",
      "1/1 - 0s - loss: 3.0057 - val_loss: 3.0363 - 36ms/epoch - 36ms/step\n",
      "Epoch 36/2000\n",
      "1/1 - 0s - loss: 3.0035 - val_loss: 3.0343 - 39ms/epoch - 39ms/step\n",
      "Epoch 37/2000\n",
      "1/1 - 0s - loss: 3.0013 - val_loss: 3.0323 - 41ms/epoch - 41ms/step\n",
      "Epoch 38/2000\n",
      "1/1 - 0s - loss: 2.9992 - val_loss: 3.0304 - 37ms/epoch - 37ms/step\n",
      "Epoch 39/2000\n",
      "1/1 - 0s - loss: 2.9971 - val_loss: 3.0286 - 36ms/epoch - 36ms/step\n",
      "Epoch 40/2000\n",
      "1/1 - 0s - loss: 2.9950 - val_loss: 3.0268 - 36ms/epoch - 36ms/step\n",
      "Epoch 41/2000\n",
      "1/1 - 0s - loss: 2.9930 - val_loss: 3.0251 - 39ms/epoch - 39ms/step\n",
      "Epoch 42/2000\n",
      "1/1 - 0s - loss: 2.9910 - val_loss: 3.0235 - 39ms/epoch - 39ms/step\n",
      "Epoch 43/2000\n",
      "1/1 - 0s - loss: 2.9891 - val_loss: 3.0219 - 37ms/epoch - 37ms/step\n",
      "Epoch 44/2000\n",
      "1/1 - 0s - loss: 2.9873 - val_loss: 3.0204 - 36ms/epoch - 36ms/step\n",
      "Epoch 45/2000\n",
      "1/1 - 0s - loss: 2.9854 - val_loss: 3.0189 - 36ms/epoch - 36ms/step\n",
      "Epoch 46/2000\n",
      "1/1 - 0s - loss: 2.9836 - val_loss: 3.0175 - 39ms/epoch - 39ms/step\n",
      "Epoch 47/2000\n",
      "1/1 - 0s - loss: 2.9819 - val_loss: 3.0162 - 39ms/epoch - 39ms/step\n",
      "Epoch 48/2000\n",
      "1/1 - 0s - loss: 2.9801 - val_loss: 3.0150 - 38ms/epoch - 38ms/step\n",
      "Epoch 49/2000\n",
      "1/1 - 0s - loss: 2.9784 - val_loss: 3.0138 - 37ms/epoch - 37ms/step\n",
      "Epoch 50/2000\n",
      "1/1 - 0s - loss: 2.9767 - val_loss: 3.0127 - 37ms/epoch - 37ms/step\n",
      "Epoch 51/2000\n",
      "1/1 - 0s - loss: 2.9751 - val_loss: 3.0115 - 39ms/epoch - 39ms/step\n",
      "Epoch 52/2000\n",
      "1/1 - 0s - loss: 2.9734 - val_loss: 3.0103 - 40ms/epoch - 40ms/step\n",
      "Epoch 53/2000\n",
      "1/1 - 0s - loss: 2.9719 - val_loss: 3.0091 - 39ms/epoch - 39ms/step\n",
      "Epoch 54/2000\n",
      "1/1 - 0s - loss: 2.9703 - val_loss: 3.0079 - 36ms/epoch - 36ms/step\n",
      "Epoch 55/2000\n",
      "1/1 - 0s - loss: 2.9687 - val_loss: 3.0067 - 36ms/epoch - 36ms/step\n",
      "Epoch 56/2000\n",
      "1/1 - 0s - loss: 2.9672 - val_loss: 3.0055 - 38ms/epoch - 38ms/step\n",
      "Epoch 57/2000\n",
      "1/1 - 0s - loss: 2.9657 - val_loss: 3.0043 - 40ms/epoch - 40ms/step\n",
      "Epoch 58/2000\n",
      "1/1 - 0s - loss: 2.9642 - val_loss: 3.0030 - 38ms/epoch - 38ms/step\n",
      "Epoch 59/2000\n",
      "1/1 - 0s - loss: 2.9627 - val_loss: 3.0017 - 37ms/epoch - 37ms/step\n",
      "Epoch 60/2000\n",
      "1/1 - 0s - loss: 2.9612 - val_loss: 3.0004 - 36ms/epoch - 36ms/step\n",
      "Epoch 61/2000\n",
      "1/1 - 0s - loss: 2.9597 - val_loss: 2.9991 - 38ms/epoch - 38ms/step\n",
      "Epoch 62/2000\n",
      "1/1 - 0s - loss: 2.9583 - val_loss: 2.9978 - 40ms/epoch - 40ms/step\n",
      "Epoch 63/2000\n",
      "1/1 - 0s - loss: 2.9569 - val_loss: 2.9965 - 39ms/epoch - 39ms/step\n",
      "Epoch 64/2000\n",
      "1/1 - 0s - loss: 2.9554 - val_loss: 2.9953 - 37ms/epoch - 37ms/step\n",
      "Epoch 65/2000\n",
      "1/1 - 0s - loss: 2.9540 - val_loss: 2.9941 - 36ms/epoch - 36ms/step\n",
      "Epoch 66/2000\n",
      "1/1 - 0s - loss: 2.9526 - val_loss: 2.9928 - 37ms/epoch - 37ms/step\n",
      "Epoch 67/2000\n",
      "1/1 - 0s - loss: 2.9513 - val_loss: 2.9916 - 40ms/epoch - 40ms/step\n",
      "Epoch 68/2000\n",
      "1/1 - 0s - loss: 2.9499 - val_loss: 2.9903 - 39ms/epoch - 39ms/step\n",
      "Epoch 69/2000\n",
      "1/1 - 0s - loss: 2.9486 - val_loss: 2.9890 - 37ms/epoch - 37ms/step\n",
      "Epoch 70/2000\n",
      "1/1 - 0s - loss: 2.9472 - val_loss: 2.9878 - 36ms/epoch - 36ms/step\n",
      "Epoch 71/2000\n",
      "1/1 - 0s - loss: 2.9459 - val_loss: 2.9867 - 37ms/epoch - 37ms/step\n",
      "Epoch 72/2000\n",
      "1/1 - 0s - loss: 2.9445 - val_loss: 2.9856 - 39ms/epoch - 39ms/step\n",
      "Epoch 73/2000\n",
      "1/1 - 0s - loss: 2.9432 - val_loss: 2.9845 - 41ms/epoch - 41ms/step\n",
      "Epoch 74/2000\n",
      "1/1 - 0s - loss: 2.9419 - val_loss: 2.9833 - 37ms/epoch - 37ms/step\n",
      "Epoch 75/2000\n",
      "1/1 - 0s - loss: 2.9406 - val_loss: 2.9822 - 36ms/epoch - 36ms/step\n",
      "Epoch 76/2000\n",
      "1/1 - 0s - loss: 2.9392 - val_loss: 2.9810 - 36ms/epoch - 36ms/step\n",
      "Epoch 77/2000\n",
      "1/1 - 0s - loss: 2.9379 - val_loss: 2.9799 - 41ms/epoch - 41ms/step\n",
      "Epoch 78/2000\n",
      "1/1 - 0s - loss: 2.9366 - val_loss: 2.9789 - 39ms/epoch - 39ms/step\n",
      "Epoch 79/2000\n",
      "1/1 - 0s - loss: 2.9353 - val_loss: 2.9778 - 37ms/epoch - 37ms/step\n",
      "Epoch 80/2000\n",
      "1/1 - 0s - loss: 2.9340 - val_loss: 2.9766 - 36ms/epoch - 36ms/step\n",
      "Epoch 81/2000\n",
      "1/1 - 0s - loss: 2.9327 - val_loss: 2.9755 - 36ms/epoch - 36ms/step\n",
      "Epoch 82/2000\n",
      "1/1 - 0s - loss: 2.9314 - val_loss: 2.9745 - 39ms/epoch - 39ms/step\n",
      "Epoch 83/2000\n",
      "1/1 - 0s - loss: 2.9302 - val_loss: 2.9736 - 40ms/epoch - 40ms/step\n",
      "Epoch 84/2000\n",
      "1/1 - 0s - loss: 2.9289 - val_loss: 2.9727 - 38ms/epoch - 38ms/step\n",
      "Epoch 85/2000\n",
      "1/1 - 0s - loss: 2.9277 - val_loss: 2.9718 - 36ms/epoch - 36ms/step\n",
      "Epoch 86/2000\n",
      "1/1 - 0s - loss: 2.9264 - val_loss: 2.9708 - 37ms/epoch - 37ms/step\n",
      "Epoch 87/2000\n",
      "1/1 - 0s - loss: 2.9252 - val_loss: 2.9699 - 39ms/epoch - 39ms/step\n",
      "Epoch 88/2000\n",
      "1/1 - 0s - loss: 2.9240 - val_loss: 2.9691 - 39ms/epoch - 39ms/step\n",
      "Epoch 89/2000\n",
      "1/1 - 0s - loss: 2.9228 - val_loss: 2.9683 - 38ms/epoch - 38ms/step\n",
      "Epoch 90/2000\n",
      "1/1 - 0s - loss: 2.9216 - val_loss: 2.9675 - 37ms/epoch - 37ms/step\n",
      "Epoch 91/2000\n",
      "1/1 - 0s - loss: 2.9204 - val_loss: 2.9666 - 36ms/epoch - 36ms/step\n",
      "Epoch 92/2000\n",
      "1/1 - 0s - loss: 2.9192 - val_loss: 2.9658 - 39ms/epoch - 39ms/step\n",
      "Epoch 93/2000\n",
      "1/1 - 0s - loss: 2.9180 - val_loss: 2.9649 - 39ms/epoch - 39ms/step\n",
      "Epoch 94/2000\n",
      "1/1 - 0s - loss: 2.9168 - val_loss: 2.9641 - 39ms/epoch - 39ms/step\n",
      "Epoch 95/2000\n",
      "1/1 - 0s - loss: 2.9157 - val_loss: 2.9633 - 36ms/epoch - 36ms/step\n",
      "Epoch 96/2000\n",
      "1/1 - 0s - loss: 2.9145 - val_loss: 2.9625 - 37ms/epoch - 37ms/step\n",
      "Epoch 97/2000\n",
      "1/1 - 0s - loss: 2.9134 - val_loss: 2.9618 - 38ms/epoch - 38ms/step\n",
      "Epoch 98/2000\n",
      "1/1 - 0s - loss: 2.9123 - val_loss: 2.9611 - 39ms/epoch - 39ms/step\n",
      "Epoch 99/2000\n",
      "1/1 - 0s - loss: 2.9112 - val_loss: 2.9604 - 39ms/epoch - 39ms/step\n",
      "Epoch 100/2000\n",
      "1/1 - 0s - loss: 2.9101 - val_loss: 2.9597 - 36ms/epoch - 36ms/step\n",
      "Epoch 101/2000\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 - 0s - loss: 2.9090 - val_loss: 2.9591 - 36ms/epoch - 36ms/step\n",
      "Epoch 102/2000\n",
      "1/1 - 0s - loss: 2.9079 - val_loss: 2.9584 - 37ms/epoch - 37ms/step\n",
      "Epoch 103/2000\n",
      "1/1 - 0s - loss: 2.9068 - val_loss: 2.9576 - 39ms/epoch - 39ms/step\n",
      "Epoch 104/2000\n",
      "1/1 - 0s - loss: 2.9057 - val_loss: 2.9568 - 41ms/epoch - 41ms/step\n",
      "Epoch 105/2000\n",
      "1/1 - 0s - loss: 2.9046 - val_loss: 2.9561 - 37ms/epoch - 37ms/step\n",
      "Epoch 106/2000\n",
      "1/1 - 0s - loss: 2.9035 - val_loss: 2.9554 - 36ms/epoch - 36ms/step\n",
      "Epoch 107/2000\n",
      "1/1 - 0s - loss: 2.9025 - val_loss: 2.9546 - 37ms/epoch - 37ms/step\n",
      "Epoch 108/2000\n",
      "1/1 - 0s - loss: 2.9014 - val_loss: 2.9539 - 41ms/epoch - 41ms/step\n",
      "Epoch 109/2000\n",
      "1/1 - 0s - loss: 2.9004 - val_loss: 2.9532 - 39ms/epoch - 39ms/step\n",
      "Epoch 110/2000\n",
      "1/1 - 0s - loss: 2.8993 - val_loss: 2.9526 - 38ms/epoch - 38ms/step\n",
      "Epoch 111/2000\n",
      "1/1 - 0s - loss: 2.8983 - val_loss: 2.9519 - 37ms/epoch - 37ms/step\n",
      "Epoch 112/2000\n",
      "1/1 - 0s - loss: 2.8973 - val_loss: 2.9512 - 36ms/epoch - 36ms/step\n",
      "Epoch 113/2000\n",
      "1/1 - 0s - loss: 2.8963 - val_loss: 2.9506 - 39ms/epoch - 39ms/step\n",
      "Epoch 114/2000\n",
      "1/1 - 0s - loss: 2.8954 - val_loss: 2.9500 - 40ms/epoch - 40ms/step\n",
      "Epoch 115/2000\n",
      "1/1 - 0s - loss: 2.8944 - val_loss: 2.9493 - 38ms/epoch - 38ms/step\n",
      "Epoch 116/2000\n",
      "1/1 - 0s - loss: 2.8935 - val_loss: 2.9486 - 36ms/epoch - 36ms/step\n",
      "Epoch 117/2000\n",
      "1/1 - 0s - loss: 2.8925 - val_loss: 2.9480 - 36ms/epoch - 36ms/step\n",
      "Epoch 118/2000\n",
      "1/1 - 0s - loss: 2.8916 - val_loss: 2.9474 - 39ms/epoch - 39ms/step\n",
      "Epoch 119/2000\n",
      "1/1 - 0s - loss: 2.8907 - val_loss: 2.9469 - 39ms/epoch - 39ms/step\n",
      "Epoch 120/2000\n",
      "1/1 - 0s - loss: 2.8898 - val_loss: 2.9465 - 39ms/epoch - 39ms/step\n",
      "Epoch 121/2000\n",
      "1/1 - 0s - loss: 2.8889 - val_loss: 2.9461 - 36ms/epoch - 36ms/step\n",
      "Epoch 122/2000\n",
      "1/1 - 0s - loss: 2.8881 - val_loss: 2.9457 - 36ms/epoch - 36ms/step\n",
      "Epoch 123/2000\n",
      "1/1 - 0s - loss: 2.8872 - val_loss: 2.9452 - 37ms/epoch - 37ms/step\n",
      "Epoch 124/2000\n",
      "1/1 - 0s - loss: 2.8864 - val_loss: 2.9447 - 39ms/epoch - 39ms/step\n",
      "Epoch 125/2000\n",
      "1/1 - 0s - loss: 2.8855 - val_loss: 2.9442 - 39ms/epoch - 39ms/step\n",
      "Epoch 126/2000\n",
      "1/1 - 0s - loss: 2.8847 - val_loss: 2.9437 - 36ms/epoch - 36ms/step\n",
      "Epoch 127/2000\n",
      "1/1 - 0s - loss: 2.8840 - val_loss: 2.9432 - 36ms/epoch - 36ms/step\n",
      "Epoch 128/2000\n",
      "1/1 - 0s - loss: 2.8832 - val_loss: 2.9428 - 37ms/epoch - 37ms/step\n",
      "Epoch 129/2000\n",
      "1/1 - 0s - loss: 2.8824 - val_loss: 2.9424 - 39ms/epoch - 39ms/step\n",
      "Epoch 130/2000\n",
      "1/1 - 0s - loss: 2.8817 - val_loss: 2.9421 - 39ms/epoch - 39ms/step\n",
      "Epoch 131/2000\n",
      "1/1 - 0s - loss: 2.8810 - val_loss: 2.9417 - 37ms/epoch - 37ms/step\n",
      "Epoch 132/2000\n",
      "1/1 - 0s - loss: 2.8803 - val_loss: 2.9412 - 36ms/epoch - 36ms/step\n",
      "Epoch 133/2000\n",
      "1/1 - 0s - loss: 2.8797 - val_loss: 2.9408 - 37ms/epoch - 37ms/step\n",
      "Epoch 134/2000\n",
      "1/1 - 0s - loss: 2.8790 - val_loss: 2.9404 - 40ms/epoch - 40ms/step\n",
      "Epoch 135/2000\n",
      "1/1 - 0s - loss: 2.8784 - val_loss: 2.9398 - 39ms/epoch - 39ms/step\n",
      "Epoch 136/2000\n",
      "1/1 - 0s - loss: 2.8778 - val_loss: 2.9400 - 35ms/epoch - 35ms/step\n",
      "Epoch 137/2000\n",
      "1/1 - 0s - loss: 2.8772 - val_loss: 2.9390 - 38ms/epoch - 38ms/step\n",
      "Epoch 138/2000\n",
      "1/1 - 0s - loss: 2.8767 - val_loss: 2.9391 - 34ms/epoch - 34ms/step\n",
      "Epoch 139/2000\n",
      "1/1 - 0s - loss: 2.8759 - val_loss: 2.9386 - 39ms/epoch - 39ms/step\n",
      "Epoch 140/2000\n",
      "1/1 - 0s - loss: 2.8752 - val_loss: 2.9379 - 39ms/epoch - 39ms/step\n",
      "Epoch 141/2000\n",
      "1/1 - 0s - loss: 2.8748 - val_loss: 2.9381 - 36ms/epoch - 36ms/step\n",
      "Epoch 142/2000\n",
      "1/1 - 0s - loss: 2.8741 - val_loss: 2.9376 - 36ms/epoch - 36ms/step\n",
      "Epoch 143/2000\n",
      "1/1 - 0s - loss: 2.8733 - val_loss: 2.9371 - 36ms/epoch - 36ms/step\n",
      "Epoch 144/2000\n",
      "1/1 - 0s - loss: 2.8728 - val_loss: 2.9374 - 36ms/epoch - 36ms/step\n",
      "Epoch 145/2000\n",
      "1/1 - 0s - loss: 2.8722 - val_loss: 2.9369 - 39ms/epoch - 39ms/step\n",
      "Epoch 146/2000\n",
      "1/1 - 0s - loss: 2.8714 - val_loss: 2.9364 - 39ms/epoch - 39ms/step\n",
      "Epoch 147/2000\n",
      "1/1 - 0s - loss: 2.8709 - val_loss: 2.9367 - 34ms/epoch - 34ms/step\n",
      "Epoch 148/2000\n",
      "1/1 - 0s - loss: 2.8703 - val_loss: 2.9362 - 36ms/epoch - 36ms/step\n",
      "Epoch 149/2000\n",
      "1/1 - 0s - loss: 2.8696 - val_loss: 2.9358 - 37ms/epoch - 37ms/step\n",
      "Epoch 150/2000\n",
      "1/1 - 0s - loss: 2.8690 - val_loss: 2.9360 - 37ms/epoch - 37ms/step\n",
      "Epoch 151/2000\n",
      "1/1 - 0s - loss: 2.8684 - val_loss: 2.9354 - 39ms/epoch - 39ms/step\n",
      "Epoch 152/2000\n",
      "1/1 - 0s - loss: 2.8678 - val_loss: 2.9352 - 37ms/epoch - 37ms/step\n",
      "Epoch 153/2000\n",
      "1/1 - 0s - loss: 2.8672 - val_loss: 2.9353 - 34ms/epoch - 34ms/step\n",
      "Epoch 154/2000\n",
      "1/1 - 0s - loss: 2.8666 - val_loss: 2.9349 - 36ms/epoch - 36ms/step\n",
      "Epoch 155/2000\n",
      "1/1 - 0s - loss: 2.8660 - val_loss: 2.9348 - 39ms/epoch - 39ms/step\n",
      "Epoch 156/2000\n",
      "1/1 - 0s - loss: 2.8654 - val_loss: 2.9348 - 39ms/epoch - 39ms/step\n",
      "Epoch 157/2000\n",
      "1/1 - 0s - loss: 2.8648 - val_loss: 2.9345 - 38ms/epoch - 38ms/step\n",
      "Epoch 158/2000\n",
      "1/1 - 0s - loss: 2.8643 - val_loss: 2.9348 - 34ms/epoch - 34ms/step\n",
      "Epoch 159/2000\n",
      "1/1 - 0s - loss: 2.8638 - val_loss: 2.9346 - 34ms/epoch - 34ms/step\n",
      "Epoch 160/2000\n",
      "1/1 - 0s - loss: 2.8633 - val_loss: 2.9347 - 36ms/epoch - 36ms/step\n",
      "Epoch 161/2000\n",
      "1/1 - 0s - loss: 2.8627 - val_loss: 2.9341 - 39ms/epoch - 39ms/step\n",
      "Epoch 162/2000\n",
      "1/1 - 0s - loss: 2.8622 - val_loss: 2.9343 - 38ms/epoch - 38ms/step\n",
      "Epoch 163/2000\n",
      "1/1 - 0s - loss: 2.8616 - val_loss: 2.9339 - 36ms/epoch - 36ms/step\n",
      "Epoch 164/2000\n",
      "1/1 - 0s - loss: 2.8611 - val_loss: 2.9339 - 34ms/epoch - 34ms/step\n",
      "Epoch 165/2000\n",
      "1/1 - 0s - loss: 2.8606 - val_loss: 2.9337 - 37ms/epoch - 37ms/step\n",
      "Epoch 166/2000\n",
      "1/1 - 0s - loss: 2.8601 - val_loss: 2.9334 - 40ms/epoch - 40ms/step\n",
      "Epoch 167/2000\n",
      "1/1 - 0s - loss: 2.8597 - val_loss: 2.9338 - 37ms/epoch - 37ms/step\n",
      "Epoch 168/2000\n",
      "1/1 - 0s - loss: 2.8592 - val_loss: 2.9331 - 37ms/epoch - 37ms/step\n",
      "Epoch 169/2000\n",
      "1/1 - 0s - loss: 2.8587 - val_loss: 2.9338 - 34ms/epoch - 34ms/step\n",
      "Epoch 170/2000\n",
      "1/1 - 0s - loss: 2.8583 - val_loss: 2.9330 - 36ms/epoch - 36ms/step\n",
      "Epoch 171/2000\n",
      "1/1 - 0s - loss: 2.8579 - val_loss: 2.9341 - 37ms/epoch - 37ms/step\n",
      "Epoch 172/2000\n",
      "1/1 - 0s - loss: 2.8577 - val_loss: 2.9331 - 39ms/epoch - 39ms/step\n",
      "Epoch 173/2000\n",
      "1/1 - 0s - loss: 2.8571 - val_loss: 2.9331 - 38ms/epoch - 38ms/step\n",
      "Epoch 174/2000\n",
      "1/1 - 0s - loss: 2.8565 - val_loss: 2.9332 - 34ms/epoch - 34ms/step\n",
      "Epoch 175/2000\n",
      "1/1 - 0s - loss: 2.8558 - val_loss: 2.9321 - 36ms/epoch - 36ms/step\n",
      "Epoch 176/2000\n",
      "1/1 - 0s - loss: 2.8555 - val_loss: 2.9337 - 36ms/epoch - 36ms/step\n",
      "Epoch 177/2000\n",
      "1/1 - 0s - loss: 2.8553 - val_loss: 2.9326 - 37ms/epoch - 37ms/step\n",
      "Epoch 178/2000\n",
      "1/1 - 0s - loss: 2.8546 - val_loss: 2.9323 - 37ms/epoch - 37ms/step\n",
      "Epoch 179/2000\n",
      "1/1 - 0s - loss: 2.8540 - val_loss: 2.9333 - 35ms/epoch - 35ms/step\n",
      "Epoch 180/2000\n",
      "1/1 - 0s - loss: 2.8537 - val_loss: 2.9317 - 36ms/epoch - 36ms/step\n",
      "Epoch 181/2000\n",
      "1/1 - 0s - loss: 2.8533 - val_loss: 2.9328 - 35ms/epoch - 35ms/step\n",
      "Epoch 182/2000\n",
      "1/1 - 0s - loss: 2.8527 - val_loss: 2.9318 - 37ms/epoch - 37ms/step\n",
      "Epoch 183/2000\n",
      "1/1 - 0s - loss: 2.8520 - val_loss: 2.9312 - 39ms/epoch - 39ms/step\n",
      "Epoch 184/2000\n",
      "1/1 - 0s - loss: 2.8517 - val_loss: 2.9326 - 36ms/epoch - 36ms/step\n",
      "Epoch 185/2000\n",
      "1/1 - 0s - loss: 2.8514 - val_loss: 2.9310 - 37ms/epoch - 37ms/step\n",
      "Epoch 186/2000\n",
      "1/1 - 0s - loss: 2.8507 - val_loss: 2.9310 - 36ms/epoch - 36ms/step\n",
      "Epoch 187/2000\n",
      "1/1 - 0s - loss: 2.8503 - val_loss: 2.9320 - 37ms/epoch - 37ms/step\n",
      "Epoch 188/2000\n",
      "1/1 - 0s - loss: 2.8500 - val_loss: 2.9309 - 39ms/epoch - 39ms/step\n",
      "Epoch 189/2000\n",
      "1/1 - 0s - loss: 2.8495 - val_loss: 2.9309 - 38ms/epoch - 38ms/step\n",
      "Epoch 190/2000\n",
      "1/1 - 0s - loss: 2.8489 - val_loss: 2.9313 - 34ms/epoch - 34ms/step\n",
      "Epoch 191/2000\n",
      "1/1 - 0s - loss: 2.8485 - val_loss: 2.9302 - 36ms/epoch - 36ms/step\n",
      "Epoch 192/2000\n",
      "1/1 - 0s - loss: 2.8482 - val_loss: 2.9315 - 36ms/epoch - 36ms/step\n",
      "Epoch 193/2000\n",
      "1/1 - 0s - loss: 2.8479 - val_loss: 2.9304 - 37ms/epoch - 37ms/step\n",
      "Epoch 194/2000\n",
      "1/1 - 0s - loss: 2.8475 - val_loss: 2.9306 - 37ms/epoch - 37ms/step\n",
      "Epoch 195/2000\n",
      "1/1 - 0s - loss: 2.8474 - val_loss: 2.9311 - 35ms/epoch - 35ms/step\n",
      "Epoch 196/2000\n",
      "1/1 - 0s - loss: 2.8476 - val_loss: 2.9304 - 37ms/epoch - 37ms/step\n"
     ]
    }
   ],
   "source": [
    "embed_dim = 10\n",
    "num_head = 2\n",
    "\n",
    "callback = keras.callbacks.EarlyStopping(monitor = 'val_loss', patience = 5, restore_best_weights = True)\n",
    "inputs = layers.Input((x_masked_train.shape[1],), dtype=tf.int64)\n",
    "word_embeddings = layers.Embedding(N, embed_dim, name=\"word_embedding\")(inputs)\n",
    "encoder_output = word_embeddings\n",
    "\n",
    "for i in range(1):\n",
    "    encoder_output = bert_module(encoder_output, encoder_output, encoder_output, embed_dim, num_head, i)\n",
    "\n",
    "encoder_output = keras.layers.Lambda(lambda x: x[:,:-1,:], name='slice')(encoder_output)\n",
    "mlm_output = layers.Dense(N, name=\"mlm_cls\", activation=\"softmax\")(encoder_output)\n",
    "mlm_model = keras.Model(inputs = inputs, outputs = mlm_output)\n",
    "adam = Adam()\n",
    "mlm_model.compile(loss='sparse_categorical_crossentropy', optimizer=adam)\n",
    "history = mlm_model.fit(x_masked_train, y_masked_labels_train,\n",
    "                        validation_split = 0.5, callbacks = [callback], \n",
    "                        epochs=2000, batch_size=5000, \n",
    "                        verbose=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "id": "aa33aaff",
   "metadata": {},
   "outputs": [],
   "source": [
    "acc = []\n",
    "prob = []\n",
    "x_test_subset = x_masked_test[np.random.choice(x_masked_test.shape[0], size=1000, replace=False)]\n",
    "\n",
    "for sentence_number in x_test_subset:\n",
    "    temp = keras.backend.function(inputs = mlm_model.layers[0].input, outputs = mlm_model.layers[-1].output) \\\n",
    "        (np.array(sentence_number).reshape(1,len(sentence_number)))\n",
    "    temp = temp[:,-1,:]\n",
    "    acc.append(1 if temp.argmax() == sentence_number[-1] else 0)\n",
    "    prob.append(temp[0][sentence_number[-1]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "id": "fe4d2103",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.295, 0.097396016)"
      ]
     },
     "execution_count": 91,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(np.mean(acc), np.mean(prob))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "id": "0005b730",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2000\n",
      "1/1 - 1s - loss: 3.6705 - val_loss: 3.4535 - 1s/epoch - 1s/step\n",
      "Epoch 2/2000\n",
      "1/1 - 0s - loss: 3.4047 - val_loss: 3.3180 - 80ms/epoch - 80ms/step\n",
      "Epoch 3/2000\n",
      "1/1 - 0s - loss: 3.2608 - val_loss: 3.2204 - 76ms/epoch - 76ms/step\n",
      "Epoch 4/2000\n",
      "1/1 - 0s - loss: 3.1572 - val_loss: 3.1491 - 72ms/epoch - 72ms/step\n",
      "Epoch 5/2000\n",
      "1/1 - 0s - loss: 3.0823 - val_loss: 3.0953 - 79ms/epoch - 79ms/step\n",
      "Epoch 6/2000\n",
      "1/1 - 0s - loss: 3.0266 - val_loss: 3.0547 - 76ms/epoch - 76ms/step\n",
      "Epoch 7/2000\n",
      "1/1 - 0s - loss: 2.9849 - val_loss: 3.0265 - 74ms/epoch - 74ms/step\n",
      "Epoch 8/2000\n",
      "1/1 - 0s - loss: 2.9558 - val_loss: 3.0070 - 74ms/epoch - 74ms/step\n",
      "Epoch 9/2000\n",
      "1/1 - 0s - loss: 2.9353 - val_loss: 2.9918 - 82ms/epoch - 82ms/step\n",
      "Epoch 10/2000\n",
      "1/1 - 0s - loss: 2.9187 - val_loss: 2.9788 - 73ms/epoch - 73ms/step\n",
      "Epoch 11/2000\n",
      "1/1 - 0s - loss: 2.9038 - val_loss: 2.9677 - 75ms/epoch - 75ms/step\n",
      "Epoch 12/2000\n",
      "1/1 - 0s - loss: 2.8905 - val_loss: 2.9585 - 75ms/epoch - 75ms/step\n",
      "Epoch 13/2000\n",
      "1/1 - 0s - loss: 2.8789 - val_loss: 2.9506 - 70ms/epoch - 70ms/step\n",
      "Epoch 14/2000\n",
      "1/1 - 0s - loss: 2.8686 - val_loss: 2.9429 - 74ms/epoch - 74ms/step\n",
      "Epoch 15/2000\n",
      "1/1 - 0s - loss: 2.8587 - val_loss: 2.9352 - 75ms/epoch - 75ms/step\n",
      "Epoch 16/2000\n",
      "1/1 - 0s - loss: 2.8488 - val_loss: 2.9276 - 72ms/epoch - 72ms/step\n",
      "Epoch 17/2000\n",
      "1/1 - 0s - loss: 2.8390 - val_loss: 2.9201 - 77ms/epoch - 77ms/step\n",
      "Epoch 18/2000\n",
      "1/1 - 0s - loss: 2.8293 - val_loss: 2.9124 - 73ms/epoch - 73ms/step\n",
      "Epoch 19/2000\n",
      "1/1 - 0s - loss: 2.8193 - val_loss: 2.9042 - 70ms/epoch - 70ms/step\n",
      "Epoch 20/2000\n",
      "1/1 - 0s - loss: 2.8087 - val_loss: 2.8960 - 78ms/epoch - 78ms/step\n",
      "Epoch 21/2000\n",
      "1/1 - 0s - loss: 2.7979 - val_loss: 2.8883 - 73ms/epoch - 73ms/step\n",
      "Epoch 22/2000\n",
      "1/1 - 0s - loss: 2.7873 - val_loss: 2.8812 - 78ms/epoch - 78ms/step\n",
      "Epoch 23/2000\n",
      "1/1 - 0s - loss: 2.7772 - val_loss: 2.8742 - 76ms/epoch - 76ms/step\n",
      "Epoch 24/2000\n",
      "1/1 - 0s - loss: 2.7672 - val_loss: 2.8671 - 78ms/epoch - 78ms/step\n",
      "Epoch 25/2000\n",
      "1/1 - 0s - loss: 2.7572 - val_loss: 2.8601 - 69ms/epoch - 69ms/step\n",
      "Epoch 26/2000\n",
      "1/1 - 0s - loss: 2.7474 - val_loss: 2.8532 - 80ms/epoch - 80ms/step\n",
      "Epoch 27/2000\n",
      "1/1 - 0s - loss: 2.7380 - val_loss: 2.8463 - 74ms/epoch - 74ms/step\n",
      "Epoch 28/2000\n",
      "1/1 - 0s - loss: 2.7291 - val_loss: 2.8396 - 71ms/epoch - 71ms/step\n",
      "Epoch 29/2000\n",
      "1/1 - 0s - loss: 2.7204 - val_loss: 2.8333 - 76ms/epoch - 76ms/step\n",
      "Epoch 30/2000\n",
      "1/1 - 0s - loss: 2.7122 - val_loss: 2.8275 - 74ms/epoch - 74ms/step\n",
      "Epoch 31/2000\n",
      "1/1 - 0s - loss: 2.7045 - val_loss: 2.8223 - 72ms/epoch - 72ms/step\n",
      "Epoch 32/2000\n",
      "1/1 - 0s - loss: 2.6971 - val_loss: 2.8177 - 77ms/epoch - 77ms/step\n",
      "Epoch 33/2000\n",
      "1/1 - 0s - loss: 2.6900 - val_loss: 2.8137 - 73ms/epoch - 73ms/step\n",
      "Epoch 34/2000\n",
      "1/1 - 0s - loss: 2.6834 - val_loss: 2.8100 - 74ms/epoch - 74ms/step\n",
      "Epoch 35/2000\n",
      "1/1 - 0s - loss: 2.6771 - val_loss: 2.8065 - 76ms/epoch - 76ms/step\n",
      "Epoch 36/2000\n",
      "1/1 - 0s - loss: 2.6711 - val_loss: 2.8033 - 72ms/epoch - 72ms/step\n",
      "Epoch 37/2000\n",
      "1/1 - 0s - loss: 2.6655 - val_loss: 2.8006 - 73ms/epoch - 73ms/step\n",
      "Epoch 38/2000\n",
      "1/1 - 0s - loss: 2.6603 - val_loss: 2.7983 - 76ms/epoch - 76ms/step\n",
      "Epoch 39/2000\n",
      "1/1 - 0s - loss: 2.6553 - val_loss: 2.7965 - 71ms/epoch - 71ms/step\n",
      "Epoch 40/2000\n",
      "1/1 - 0s - loss: 2.6505 - val_loss: 2.7950 - 73ms/epoch - 73ms/step\n",
      "Epoch 41/2000\n",
      "1/1 - 0s - loss: 2.6458 - val_loss: 2.7939 - 78ms/epoch - 78ms/step\n",
      "Epoch 42/2000\n",
      "1/1 - 0s - loss: 2.6413 - val_loss: 2.7928 - 72ms/epoch - 72ms/step\n",
      "Epoch 43/2000\n",
      "1/1 - 0s - loss: 2.6369 - val_loss: 2.7917 - 71ms/epoch - 71ms/step\n",
      "Epoch 44/2000\n",
      "1/1 - 0s - loss: 2.6326 - val_loss: 2.7905 - 76ms/epoch - 76ms/step\n",
      "Epoch 45/2000\n",
      "1/1 - 0s - loss: 2.6284 - val_loss: 2.7893 - 73ms/epoch - 73ms/step\n",
      "Epoch 46/2000\n",
      "1/1 - 0s - loss: 2.6244 - val_loss: 2.7883 - 76ms/epoch - 76ms/step\n",
      "Epoch 47/2000\n",
      "1/1 - 0s - loss: 2.6207 - val_loss: 2.7876 - 77ms/epoch - 77ms/step\n",
      "Epoch 48/2000\n",
      "1/1 - 0s - loss: 2.6170 - val_loss: 2.7870 - 73ms/epoch - 73ms/step\n",
      "Epoch 49/2000\n",
      "1/1 - 0s - loss: 2.6134 - val_loss: 2.7865 - 74ms/epoch - 74ms/step\n",
      "Epoch 50/2000\n",
      "1/1 - 0s - loss: 2.6099 - val_loss: 2.7861 - 77ms/epoch - 77ms/step\n",
      "Epoch 51/2000\n",
      "1/1 - 0s - loss: 2.6066 - val_loss: 2.7856 - 72ms/epoch - 72ms/step\n",
      "Epoch 52/2000\n",
      "1/1 - 0s - loss: 2.6033 - val_loss: 2.7852 - 76ms/epoch - 76ms/step\n",
      "Epoch 53/2000\n",
      "1/1 - 0s - loss: 2.6001 - val_loss: 2.7850 - 79ms/epoch - 79ms/step\n",
      "Epoch 54/2000\n",
      "1/1 - 0s - loss: 2.5970 - val_loss: 2.7849 - 72ms/epoch - 72ms/step\n",
      "Epoch 55/2000\n",
      "1/1 - 0s - loss: 2.5940 - val_loss: 2.7850 - 72ms/epoch - 72ms/step\n",
      "Epoch 56/2000\n",
      "1/1 - 0s - loss: 2.5911 - val_loss: 2.7853 - 74ms/epoch - 74ms/step\n",
      "Epoch 57/2000\n",
      "1/1 - 0s - loss: 2.5882 - val_loss: 2.7857 - 70ms/epoch - 70ms/step\n",
      "Epoch 58/2000\n",
      "1/1 - 0s - loss: 2.5853 - val_loss: 2.7859 - 72ms/epoch - 72ms/step\n",
      "Epoch 59/2000\n",
      "1/1 - 0s - loss: 2.5824 - val_loss: 2.7859 - 76ms/epoch - 76ms/step\n"
     ]
    }
   ],
   "source": [
    "embed_dim = 100\n",
    "num_head = 2\n",
    "\n",
    "callback = keras.callbacks.EarlyStopping(monitor = 'val_loss', patience = 5, restore_best_weights = True)\n",
    "inputs = layers.Input((x_masked_train.shape[1],), dtype=tf.int64)\n",
    "word_embeddings = layers.Embedding(N, embed_dim, name=\"word_embedding\")(inputs)\n",
    "encoder_output = word_embeddings\n",
    "\n",
    "for i in range(1):\n",
    "    encoder_output = bert_module(encoder_output, encoder_output, encoder_output, embed_dim, num_head, i)\n",
    "\n",
    "encoder_output = keras.layers.Lambda(lambda x: x[:,:-1,:], name='slice')(encoder_output)\n",
    "mlm_output = layers.Dense(N, name=\"mlm_cls\", activation=\"softmax\")(encoder_output)\n",
    "mlm_model = keras.Model(inputs = inputs, outputs = mlm_output)\n",
    "adam = Adam()\n",
    "mlm_model.compile(loss='sparse_categorical_crossentropy', optimizer=adam)\n",
    "history = mlm_model.fit(x_masked_train, y_masked_labels_train,\n",
    "                        validation_split = 0.5, callbacks = [callback], \n",
    "                        epochs=2000, batch_size=5000, \n",
    "                        verbose=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "id": "96533a42",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0, 10,  2,  0],\n",
       "       [18, 16, 19, 18],\n",
       "       [16, 14, 13, 16],\n",
       "       ...,\n",
       "       [ 1, 11,  0,  1],\n",
       "       [18, 11, 16, 18],\n",
       "       [ 9,  1, 19,  9]])"
      ]
     },
     "execution_count": 93,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x_masked_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "id": "d6f921e3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[10,  2,  0],\n",
       "       [16, 19, 18],\n",
       "       [14, 13, 16],\n",
       "       ...,\n",
       "       [11,  0,  1],\n",
       "       [11, 16, 18],\n",
       "       [ 1, 19,  9]])"
      ]
     },
     "execution_count": 94,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_masked_labels_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "id": "beb07cf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "acc = []\n",
    "prob = []\n",
    "x_test_subset = x_masked_test[np.random.choice(x_masked_test.shape[0], size=1000, replace=False)]\n",
    "\n",
    "for sentence_number in x_test_subset:\n",
    "    temp = keras.backend.function(inputs = mlm_model.layers[0].input, outputs = mlm_model.layers[-1].output) \\\n",
    "        (np.array(sentence_number).reshape(1,len(sentence_number)))\n",
    "    temp = temp[:,-1,:]\n",
    "    acc.append(1 if temp.argmax() == sentence_number[-1] else 0)\n",
    "    prob.append(temp[0][sentence_number[-1]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "id": "4ef61e71",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.481, 0.17966855)"
      ]
     },
     "execution_count": 96,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(np.mean(acc), np.mean(prob))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "id": "9fbf1e0b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[15,  0, 13, 15],\n",
       "       [17, 11,  9, 17],\n",
       "       [10,  0,  3, 10],\n",
       "       ...,\n",
       "       [12,  2,  7, 12],\n",
       "       [ 9, 16,  3,  9],\n",
       "       [12,  6, 18, 12]])"
      ]
     },
     "execution_count": 97,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x_test_subset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "id": "331a0d43",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.18322888, 0.04183327, 0.00042343, 0.04988674, 0.02165989,\n",
       "        0.23305799, 0.051962  , 0.03700745, 0.05091651, 0.0318071 ,\n",
       "        0.02159526, 0.02821101, 0.02173642, 0.02830644, 0.02710385,\n",
       "        0.05953328, 0.02446126, 0.02872639, 0.04173755, 0.01680526]],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 98,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "temp = keras.backend.function(inputs = mlm_model.layers[0].input, outputs = mlm_model.layers[-1].output) \\\n",
    "        (np.array([5, 0, 2, 5]).reshape(1,len(sentence_number)))\n",
    "temp = temp[:,-1,:]\n",
    "temp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "id": "399e0aa1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.18322888, 0.04183327, 0.00042343, 0.04988674, 0.02165989,\n",
       "        0.23305799, 0.051962  , 0.03700745, 0.05091651, 0.0318071 ,\n",
       "        0.02159526, 0.02821101, 0.02173642, 0.02830644, 0.02710385,\n",
       "        0.05953328, 0.02446126, 0.02872639, 0.04173755, 0.01680526]],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 99,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "temp = keras.backend.function(inputs = mlm_model.layers[0].input, outputs = mlm_model.layers[-1].output) \\\n",
    "        (np.array([0, 5, 2, 0]).reshape(1,len(sentence_number)))\n",
    "temp = temp[:,-1,:]\n",
    "temp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aea33c5b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "841310c9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b41c8a01",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
