{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d1686d53",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy.random as npr\n",
    "import random\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "from tensorflow.keras import backend as K\n",
    "from keras.optimizers import Adam\n",
    "from keras_nlp.layers import PositionEmbedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "56c107e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 428\n",
    "\n",
    "np.random.seed(seed)\n",
    "tf.random.set_seed(seed)\n",
    "random.seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "504a342e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def bert_module(query, key, value, embed_dim, num_head, i):\n",
    "    \n",
    "    # Multi headed self-attention\n",
    "    attention_output = layers.MultiHeadAttention(\n",
    "        num_heads=num_head,\n",
    "        key_dim=embed_dim // num_head,\n",
    "        name=\"encoder_{}/multiheadattention\".format(i)\n",
    "    )(query, key, value, use_causal_mask=True)\n",
    "    \n",
    "    # Add & Normalize\n",
    "    attention_output = layers.Add()([query, attention_output])  # Skip Connection\n",
    "    attention_output = layers.LayerNormalization(epsilon=1e-6)(attention_output)\n",
    "    \n",
    "    #Feedforward network\n",
    "    ff_net = keras.models.Sequential([\n",
    "        layers.Dense(2 * embed_dim, activation='relu', name=\"encoder_{}/ffn_dense_1\".format(i)),\n",
    "        layers.Dense(embed_dim, name=\"encoder_{}/ffn_dense_2\".format(i)),\n",
    "    ])\n",
    "\n",
    "    # Apply Feedforward network\n",
    "    ffn_output = ff_net(attention_output)\n",
    "\n",
    "    # Add & Normalize\n",
    "    ffn_output = layers.Add()([attention_output, ffn_output])  # Skip Connection\n",
    "    ffn_output = layers.LayerNormalization(epsilon=1e-6)(ffn_output)\n",
    "    \n",
    "    return ffn_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fef2622a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_sinusoidal_embeddings(sequence_length, embedding_dim):\n",
    "    position_enc = np.array([\n",
    "        [pos / np.power(10000, 2. * i / embedding_dim) for i in range(embedding_dim)]\n",
    "        if pos != 0 else np.zeros(embedding_dim)\n",
    "        for pos in range(sequence_length)\n",
    "    ])\n",
    "    position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i\n",
    "    position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # dim 2i+1\n",
    "    return tf.cast(position_enc, dtype=tf.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3e5bfd69",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 20 # vocab_size\n",
    "\n",
    "vocabs = ['word_' + str(i) for i in range(N)]\n",
    "\n",
    "vocab_map = {}\n",
    "for i in range(len(vocabs)):\n",
    "    vocab_map[vocabs[i]] = i\n",
    "    \n",
    "pairs = []\n",
    "\n",
    "for i in vocabs:\n",
    "    for j in vocabs:\n",
    "        for k in vocabs:\n",
    "            if i != j and i != k and j != k:\n",
    "                pairs.append((i,j,k))\n",
    "            \n",
    "indicator = np.random.choice([0, 1], size=len(pairs), p=[0.5, 0.5])\n",
    "\n",
    "pairs_train = [pairs[i] for i in range(len(indicator)) if indicator[i] == 1]\n",
    "pairs_test = [pairs[i] for i in range(len(indicator)) if indicator[i] == 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d43093fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "sentences_train = []\n",
    "sentences_number_train = []\n",
    "sentences_test = []\n",
    "sentences_number_test = []\n",
    "\n",
    "for pair in pairs_train:\n",
    "    sentences_train.append([pair[0], pair[1], pair[2], pair[0]])\n",
    "    sentences_number_train.append([vocab_map[pair[0]], vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    \n",
    "for pair in pairs_test:\n",
    "    sentences_test.append([pair[0], pair[1], pair[2], pair[0]])\n",
    "    sentences_number_test.append([vocab_map[pair[0]], vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    \n",
    "x_masked_train = []\n",
    "y_masked_labels_train = []\n",
    "x_masked_test = []\n",
    "y_masked_labels_test = []\n",
    "\n",
    "for pair in pairs_train:\n",
    "    x_masked_train.append([vocab_map[pair[0]], vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    y_masked_labels_train.append([vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    \n",
    "for pair in pairs_test:\n",
    "    x_masked_test.append([vocab_map[pair[0]], vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    y_masked_labels_test.append([vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    \n",
    "x_masked_train = np.array(x_masked_train)\n",
    "y_masked_labels_train = np.array(y_masked_labels_train)\n",
    "x_masked_test = np.array(x_masked_test)\n",
    "y_masked_labels_test = np.array(y_masked_labels_test)\n",
    "\n",
    "perm = np.random.permutation(len(x_masked_train))\n",
    "x_masked_train = x_masked_train[perm]\n",
    "y_masked_labels_train = y_masked_labels_train[perm]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "13b40f89",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2000\n",
      "1/1 - 5s - loss: 3.2623 - val_loss: 3.2319 - 5s/epoch - 5s/step\n",
      "Epoch 2/2000\n",
      "1/1 - 0s - loss: 3.2253 - val_loss: 3.2067 - 82ms/epoch - 82ms/step\n",
      "Epoch 3/2000\n",
      "1/1 - 0s - loss: 3.1964 - val_loss: 3.1845 - 84ms/epoch - 84ms/step\n",
      "Epoch 4/2000\n",
      "1/1 - 0s - loss: 3.1718 - val_loss: 3.1658 - 81ms/epoch - 81ms/step\n",
      "Epoch 5/2000\n",
      "1/1 - 0s - loss: 3.1520 - val_loss: 3.1498 - 83ms/epoch - 83ms/step\n",
      "Epoch 6/2000\n",
      "1/1 - 0s - loss: 3.1358 - val_loss: 3.1357 - 85ms/epoch - 85ms/step\n",
      "Epoch 7/2000\n",
      "1/1 - 0s - loss: 3.1221 - val_loss: 3.1232 - 82ms/epoch - 82ms/step\n",
      "Epoch 8/2000\n",
      "1/1 - 0s - loss: 3.1100 - val_loss: 3.1122 - 83ms/epoch - 83ms/step\n",
      "Epoch 9/2000\n",
      "1/1 - 0s - loss: 3.0993 - val_loss: 3.1026 - 79ms/epoch - 79ms/step\n",
      "Epoch 10/2000\n",
      "1/1 - 0s - loss: 3.0899 - val_loss: 3.0941 - 84ms/epoch - 84ms/step\n",
      "Epoch 11/2000\n",
      "1/1 - 0s - loss: 3.0815 - val_loss: 3.0866 - 82ms/epoch - 82ms/step\n",
      "Epoch 12/2000\n",
      "1/1 - 0s - loss: 3.0742 - val_loss: 3.0800 - 82ms/epoch - 82ms/step\n",
      "Epoch 13/2000\n",
      "1/1 - 0s - loss: 3.0676 - val_loss: 3.0738 - 83ms/epoch - 83ms/step\n",
      "Epoch 14/2000\n",
      "1/1 - 0s - loss: 3.0615 - val_loss: 3.0683 - 79ms/epoch - 79ms/step\n",
      "Epoch 15/2000\n",
      "1/1 - 0s - loss: 3.0560 - val_loss: 3.0634 - 83ms/epoch - 83ms/step\n",
      "Epoch 16/2000\n",
      "1/1 - 0s - loss: 3.0509 - val_loss: 3.0588 - 82ms/epoch - 82ms/step\n",
      "Epoch 17/2000\n",
      "1/1 - 0s - loss: 3.0462 - val_loss: 3.0547 - 83ms/epoch - 83ms/step\n",
      "Epoch 18/2000\n",
      "1/1 - 0s - loss: 3.0419 - val_loss: 3.0509 - 84ms/epoch - 84ms/step\n",
      "Epoch 19/2000\n",
      "1/1 - 0s - loss: 3.0379 - val_loss: 3.0473 - 81ms/epoch - 81ms/step\n",
      "Epoch 20/2000\n",
      "1/1 - 0s - loss: 3.0340 - val_loss: 3.0438 - 87ms/epoch - 87ms/step\n",
      "Epoch 21/2000\n",
      "1/1 - 0s - loss: 3.0303 - val_loss: 3.0403 - 84ms/epoch - 84ms/step\n",
      "Epoch 22/2000\n",
      "1/1 - 0s - loss: 3.0266 - val_loss: 3.0369 - 85ms/epoch - 85ms/step\n",
      "Epoch 23/2000\n",
      "1/1 - 0s - loss: 3.0229 - val_loss: 3.0336 - 83ms/epoch - 83ms/step\n",
      "Epoch 24/2000\n",
      "1/1 - 0s - loss: 3.0193 - val_loss: 3.0303 - 81ms/epoch - 81ms/step\n",
      "Epoch 25/2000\n",
      "1/1 - 0s - loss: 3.0156 - val_loss: 3.0271 - 84ms/epoch - 84ms/step\n",
      "Epoch 26/2000\n",
      "1/1 - 0s - loss: 3.0120 - val_loss: 3.0241 - 79ms/epoch - 79ms/step\n",
      "Epoch 27/2000\n",
      "1/1 - 0s - loss: 3.0084 - val_loss: 3.0211 - 83ms/epoch - 83ms/step\n",
      "Epoch 28/2000\n",
      "1/1 - 0s - loss: 3.0049 - val_loss: 3.0183 - 83ms/epoch - 83ms/step\n",
      "Epoch 29/2000\n",
      "1/1 - 0s - loss: 3.0015 - val_loss: 3.0155 - 81ms/epoch - 81ms/step\n",
      "Epoch 30/2000\n",
      "1/1 - 0s - loss: 2.9981 - val_loss: 3.0129 - 85ms/epoch - 85ms/step\n",
      "Epoch 31/2000\n",
      "1/1 - 0s - loss: 2.9948 - val_loss: 3.0102 - 80ms/epoch - 80ms/step\n",
      "Epoch 32/2000\n",
      "1/1 - 0s - loss: 2.9914 - val_loss: 3.0077 - 83ms/epoch - 83ms/step\n",
      "Epoch 33/2000\n",
      "1/1 - 0s - loss: 2.9881 - val_loss: 3.0052 - 81ms/epoch - 81ms/step\n",
      "Epoch 34/2000\n",
      "1/1 - 0s - loss: 2.9847 - val_loss: 3.0028 - 79ms/epoch - 79ms/step\n",
      "Epoch 35/2000\n",
      "1/1 - 0s - loss: 2.9814 - val_loss: 3.0005 - 83ms/epoch - 83ms/step\n",
      "Epoch 36/2000\n",
      "1/1 - 0s - loss: 2.9780 - val_loss: 2.9982 - 80ms/epoch - 80ms/step\n",
      "Epoch 37/2000\n",
      "1/1 - 0s - loss: 2.9748 - val_loss: 2.9960 - 82ms/epoch - 82ms/step\n",
      "Epoch 38/2000\n",
      "1/1 - 0s - loss: 2.9717 - val_loss: 2.9937 - 82ms/epoch - 82ms/step\n",
      "Epoch 39/2000\n",
      "1/1 - 0s - loss: 2.9685 - val_loss: 2.9912 - 81ms/epoch - 81ms/step\n",
      "Epoch 40/2000\n",
      "1/1 - 0s - loss: 2.9653 - val_loss: 2.9884 - 85ms/epoch - 85ms/step\n",
      "Epoch 41/2000\n",
      "1/1 - 0s - loss: 2.9620 - val_loss: 2.9854 - 80ms/epoch - 80ms/step\n",
      "Epoch 42/2000\n",
      "1/1 - 0s - loss: 2.9588 - val_loss: 2.9822 - 83ms/epoch - 83ms/step\n",
      "Epoch 43/2000\n",
      "1/1 - 0s - loss: 2.9555 - val_loss: 2.9790 - 90ms/epoch - 90ms/step\n",
      "Epoch 44/2000\n",
      "1/1 - 0s - loss: 2.9522 - val_loss: 2.9757 - 86ms/epoch - 86ms/step\n",
      "Epoch 45/2000\n",
      "1/1 - 0s - loss: 2.9488 - val_loss: 2.9724 - 84ms/epoch - 84ms/step\n",
      "Epoch 46/2000\n",
      "1/1 - 0s - loss: 2.9454 - val_loss: 2.9691 - 80ms/epoch - 80ms/step\n",
      "Epoch 47/2000\n",
      "1/1 - 0s - loss: 2.9421 - val_loss: 2.9660 - 84ms/epoch - 84ms/step\n",
      "Epoch 48/2000\n",
      "1/1 - 0s - loss: 2.9387 - val_loss: 2.9630 - 80ms/epoch - 80ms/step\n",
      "Epoch 49/2000\n",
      "1/1 - 0s - loss: 2.9355 - val_loss: 2.9600 - 84ms/epoch - 84ms/step\n",
      "Epoch 50/2000\n",
      "1/1 - 0s - loss: 2.9324 - val_loss: 2.9570 - 84ms/epoch - 84ms/step\n",
      "Epoch 51/2000\n",
      "1/1 - 0s - loss: 2.9293 - val_loss: 2.9538 - 81ms/epoch - 81ms/step\n",
      "Epoch 52/2000\n",
      "1/1 - 0s - loss: 2.9262 - val_loss: 2.9507 - 84ms/epoch - 84ms/step\n",
      "Epoch 53/2000\n",
      "1/1 - 0s - loss: 2.9232 - val_loss: 2.9479 - 81ms/epoch - 81ms/step\n",
      "Epoch 54/2000\n",
      "1/1 - 0s - loss: 2.9203 - val_loss: 2.9452 - 82ms/epoch - 82ms/step\n",
      "Epoch 55/2000\n",
      "1/1 - 0s - loss: 2.9174 - val_loss: 2.9427 - 83ms/epoch - 83ms/step\n",
      "Epoch 56/2000\n",
      "1/1 - 0s - loss: 2.9146 - val_loss: 2.9403 - 80ms/epoch - 80ms/step\n",
      "Epoch 57/2000\n",
      "1/1 - 0s - loss: 2.9117 - val_loss: 2.9379 - 83ms/epoch - 83ms/step\n",
      "Epoch 58/2000\n",
      "1/1 - 0s - loss: 2.9088 - val_loss: 2.9355 - 80ms/epoch - 80ms/step\n",
      "Epoch 59/2000\n",
      "1/1 - 0s - loss: 2.9057 - val_loss: 2.9333 - 83ms/epoch - 83ms/step\n",
      "Epoch 60/2000\n",
      "1/1 - 0s - loss: 2.9027 - val_loss: 2.9312 - 82ms/epoch - 82ms/step\n",
      "Epoch 61/2000\n",
      "1/1 - 0s - loss: 2.8995 - val_loss: 2.9290 - 80ms/epoch - 80ms/step\n",
      "Epoch 62/2000\n",
      "1/1 - 0s - loss: 2.8963 - val_loss: 2.9268 - 84ms/epoch - 84ms/step\n",
      "Epoch 63/2000\n",
      "1/1 - 0s - loss: 2.8929 - val_loss: 2.9245 - 80ms/epoch - 80ms/step\n",
      "Epoch 64/2000\n",
      "1/1 - 0s - loss: 2.8896 - val_loss: 2.9221 - 83ms/epoch - 83ms/step\n",
      "Epoch 65/2000\n",
      "1/1 - 0s - loss: 2.8862 - val_loss: 2.9197 - 84ms/epoch - 84ms/step\n",
      "Epoch 66/2000\n",
      "1/1 - 0s - loss: 2.8829 - val_loss: 2.9173 - 80ms/epoch - 80ms/step\n",
      "Epoch 67/2000\n",
      "1/1 - 0s - loss: 2.8794 - val_loss: 2.9150 - 84ms/epoch - 84ms/step\n",
      "Epoch 68/2000\n",
      "1/1 - 0s - loss: 2.8759 - val_loss: 2.9127 - 80ms/epoch - 80ms/step\n",
      "Epoch 69/2000\n",
      "1/1 - 0s - loss: 2.8722 - val_loss: 2.9106 - 83ms/epoch - 83ms/step\n",
      "Epoch 70/2000\n",
      "1/1 - 0s - loss: 2.8687 - val_loss: 2.9086 - 82ms/epoch - 82ms/step\n",
      "Epoch 71/2000\n",
      "1/1 - 0s - loss: 2.8653 - val_loss: 2.9064 - 80ms/epoch - 80ms/step\n",
      "Epoch 72/2000\n",
      "1/1 - 0s - loss: 2.8619 - val_loss: 2.9041 - 83ms/epoch - 83ms/step\n",
      "Epoch 73/2000\n",
      "1/1 - 0s - loss: 2.8585 - val_loss: 2.9021 - 79ms/epoch - 79ms/step\n",
      "Epoch 74/2000\n",
      "1/1 - 0s - loss: 2.8549 - val_loss: 2.9000 - 83ms/epoch - 83ms/step\n",
      "Epoch 75/2000\n",
      "1/1 - 0s - loss: 2.8514 - val_loss: 2.8979 - 82ms/epoch - 82ms/step\n",
      "Epoch 76/2000\n",
      "1/1 - 0s - loss: 2.8478 - val_loss: 2.8959 - 80ms/epoch - 80ms/step\n",
      "Epoch 77/2000\n",
      "1/1 - 0s - loss: 2.8444 - val_loss: 2.8934 - 84ms/epoch - 84ms/step\n",
      "Epoch 78/2000\n",
      "1/1 - 0s - loss: 2.8408 - val_loss: 2.8909 - 79ms/epoch - 79ms/step\n",
      "Epoch 79/2000\n",
      "1/1 - 0s - loss: 2.8373 - val_loss: 2.8887 - 83ms/epoch - 83ms/step\n",
      "Epoch 80/2000\n",
      "1/1 - 0s - loss: 2.8339 - val_loss: 2.8866 - 84ms/epoch - 84ms/step\n",
      "Epoch 81/2000\n",
      "1/1 - 0s - loss: 2.8303 - val_loss: 2.8845 - 80ms/epoch - 80ms/step\n",
      "Epoch 82/2000\n",
      "1/1 - 0s - loss: 2.8271 - val_loss: 2.8819 - 83ms/epoch - 83ms/step\n",
      "Epoch 83/2000\n",
      "1/1 - 0s - loss: 2.8236 - val_loss: 2.8786 - 79ms/epoch - 79ms/step\n",
      "Epoch 84/2000\n",
      "1/1 - 0s - loss: 2.8201 - val_loss: 2.8758 - 83ms/epoch - 83ms/step\n",
      "Epoch 85/2000\n",
      "1/1 - 0s - loss: 2.8170 - val_loss: 2.8748 - 82ms/epoch - 82ms/step\n",
      "Epoch 86/2000\n",
      "1/1 - 0s - loss: 2.8148 - val_loss: 2.8708 - 80ms/epoch - 80ms/step\n",
      "Epoch 87/2000\n",
      "1/1 - 0s - loss: 2.8106 - val_loss: 2.8691 - 83ms/epoch - 83ms/step\n",
      "Epoch 88/2000\n",
      "1/1 - 0s - loss: 2.8080 - val_loss: 2.8671 - 80ms/epoch - 80ms/step\n",
      "Epoch 89/2000\n",
      "1/1 - 0s - loss: 2.8045 - val_loss: 2.8641 - 82ms/epoch - 82ms/step\n",
      "Epoch 90/2000\n",
      "1/1 - 0s - loss: 2.8019 - val_loss: 2.8614 - 82ms/epoch - 82ms/step\n",
      "Epoch 91/2000\n",
      "1/1 - 0s - loss: 2.7975 - val_loss: 2.8610 - 80ms/epoch - 80ms/step\n",
      "Epoch 92/2000\n",
      "1/1 - 0s - loss: 2.7953 - val_loss: 2.8577 - 83ms/epoch - 83ms/step\n",
      "Epoch 93/2000\n",
      "1/1 - 0s - loss: 2.7921 - val_loss: 2.8558 - 80ms/epoch - 80ms/step\n",
      "Epoch 94/2000\n",
      "1/1 - 0s - loss: 2.7877 - val_loss: 2.8549 - 82ms/epoch - 82ms/step\n",
      "Epoch 95/2000\n",
      "1/1 - 0s - loss: 2.7856 - val_loss: 2.8509 - 82ms/epoch - 82ms/step\n",
      "Epoch 96/2000\n",
      "1/1 - 0s - loss: 2.7818 - val_loss: 2.8487 - 79ms/epoch - 79ms/step\n",
      "Epoch 97/2000\n",
      "1/1 - 0s - loss: 2.7783 - val_loss: 2.8481 - 83ms/epoch - 83ms/step\n",
      "Epoch 98/2000\n",
      "1/1 - 0s - loss: 2.7760 - val_loss: 2.8442 - 80ms/epoch - 80ms/step\n",
      "Epoch 99/2000\n",
      "1/1 - 0s - loss: 2.7721 - val_loss: 2.8417 - 82ms/epoch - 82ms/step\n",
      "Epoch 100/2000\n",
      "1/1 - 0s - loss: 2.7690 - val_loss: 2.8401 - 82ms/epoch - 82ms/step\n",
      "Epoch 101/2000\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 - 0s - loss: 2.7664 - val_loss: 2.8363 - 80ms/epoch - 80ms/step\n",
      "Epoch 102/2000\n",
      "1/1 - 0s - loss: 2.7630 - val_loss: 2.8344 - 91ms/epoch - 91ms/step\n",
      "Epoch 103/2000\n",
      "1/1 - 0s - loss: 2.7597 - val_loss: 2.8331 - 80ms/epoch - 80ms/step\n",
      "Epoch 104/2000\n",
      "1/1 - 0s - loss: 2.7571 - val_loss: 2.8292 - 83ms/epoch - 83ms/step\n",
      "Epoch 105/2000\n",
      "1/1 - 0s - loss: 2.7538 - val_loss: 2.8263 - 82ms/epoch - 82ms/step\n",
      "Epoch 106/2000\n",
      "1/1 - 0s - loss: 2.7502 - val_loss: 2.8243 - 80ms/epoch - 80ms/step\n",
      "Epoch 107/2000\n",
      "1/1 - 0s - loss: 2.7474 - val_loss: 2.8217 - 83ms/epoch - 83ms/step\n",
      "Epoch 108/2000\n",
      "1/1 - 0s - loss: 2.7446 - val_loss: 2.8208 - 80ms/epoch - 80ms/step\n",
      "Epoch 109/2000\n",
      "1/1 - 0s - loss: 2.7413 - val_loss: 2.8178 - 83ms/epoch - 83ms/step\n",
      "Epoch 110/2000\n",
      "1/1 - 0s - loss: 2.7384 - val_loss: 2.8154 - 81ms/epoch - 81ms/step\n",
      "Epoch 111/2000\n",
      "1/1 - 0s - loss: 2.7350 - val_loss: 2.8138 - 80ms/epoch - 80ms/step\n",
      "Epoch 112/2000\n",
      "1/1 - 0s - loss: 2.7324 - val_loss: 2.8105 - 84ms/epoch - 84ms/step\n",
      "Epoch 113/2000\n",
      "1/1 - 0s - loss: 2.7302 - val_loss: 2.8104 - 80ms/epoch - 80ms/step\n",
      "Epoch 114/2000\n",
      "1/1 - 0s - loss: 2.7280 - val_loss: 2.8064 - 83ms/epoch - 83ms/step\n",
      "Epoch 115/2000\n",
      "1/1 - 0s - loss: 2.7253 - val_loss: 2.8050 - 81ms/epoch - 81ms/step\n",
      "Epoch 116/2000\n",
      "1/1 - 0s - loss: 2.7211 - val_loss: 2.8010 - 80ms/epoch - 80ms/step\n",
      "Epoch 117/2000\n",
      "1/1 - 0s - loss: 2.7173 - val_loss: 2.7978 - 83ms/epoch - 83ms/step\n",
      "Epoch 118/2000\n",
      "1/1 - 0s - loss: 2.7155 - val_loss: 2.7975 - 81ms/epoch - 81ms/step\n",
      "Epoch 119/2000\n",
      "1/1 - 0s - loss: 2.7136 - val_loss: 2.7944 - 85ms/epoch - 85ms/step\n",
      "Epoch 120/2000\n",
      "1/1 - 0s - loss: 2.7102 - val_loss: 2.7932 - 80ms/epoch - 80ms/step\n",
      "Epoch 121/2000\n",
      "1/1 - 0s - loss: 2.7070 - val_loss: 2.7906 - 81ms/epoch - 81ms/step\n",
      "Epoch 122/2000\n",
      "1/1 - 0s - loss: 2.7042 - val_loss: 2.7870 - 83ms/epoch - 83ms/step\n",
      "Epoch 123/2000\n",
      "1/1 - 0s - loss: 2.7023 - val_loss: 2.7856 - 79ms/epoch - 79ms/step\n",
      "Epoch 124/2000\n",
      "1/1 - 0s - loss: 2.6994 - val_loss: 2.7819 - 83ms/epoch - 83ms/step\n",
      "Epoch 125/2000\n",
      "1/1 - 0s - loss: 2.6965 - val_loss: 2.7796 - 81ms/epoch - 81ms/step\n",
      "Epoch 126/2000\n",
      "1/1 - 0s - loss: 2.6935 - val_loss: 2.7782 - 81ms/epoch - 81ms/step\n",
      "Epoch 127/2000\n",
      "1/1 - 0s - loss: 2.6915 - val_loss: 2.7763 - 83ms/epoch - 83ms/step\n",
      "Epoch 128/2000\n",
      "1/1 - 0s - loss: 2.6887 - val_loss: 2.7778 - 72ms/epoch - 72ms/step\n",
      "Epoch 129/2000\n",
      "1/1 - 0s - loss: 2.6879 - val_loss: 2.7741 - 83ms/epoch - 83ms/step\n",
      "Epoch 130/2000\n",
      "1/1 - 0s - loss: 2.6880 - val_loss: 2.7756 - 75ms/epoch - 75ms/step\n",
      "Epoch 131/2000\n",
      "1/1 - 0s - loss: 2.6850 - val_loss: 2.7689 - 80ms/epoch - 80ms/step\n",
      "Epoch 132/2000\n",
      "1/1 - 0s - loss: 2.6794 - val_loss: 2.7671 - 83ms/epoch - 83ms/step\n",
      "Epoch 133/2000\n",
      "1/1 - 0s - loss: 2.6783 - val_loss: 2.7729 - 73ms/epoch - 73ms/step\n",
      "Epoch 134/2000\n",
      "1/1 - 0s - loss: 2.6820 - val_loss: 2.7668 - 83ms/epoch - 83ms/step\n",
      "Epoch 135/2000\n",
      "1/1 - 0s - loss: 2.6800 - val_loss: 2.7639 - 83ms/epoch - 83ms/step\n",
      "Epoch 136/2000\n",
      "1/1 - 0s - loss: 2.6719 - val_loss: 2.7683 - 73ms/epoch - 73ms/step\n",
      "Epoch 137/2000\n",
      "1/1 - 0s - loss: 2.6748 - val_loss: 2.7593 - 85ms/epoch - 85ms/step\n",
      "Epoch 138/2000\n",
      "1/1 - 0s - loss: 2.6682 - val_loss: 2.7566 - 81ms/epoch - 81ms/step\n",
      "Epoch 139/2000\n",
      "1/1 - 0s - loss: 2.6680 - val_loss: 2.7591 - 75ms/epoch - 75ms/step\n",
      "Epoch 140/2000\n",
      "1/1 - 0s - loss: 2.6651 - val_loss: 2.7562 - 83ms/epoch - 83ms/step\n",
      "Epoch 141/2000\n",
      "1/1 - 0s - loss: 2.6611 - val_loss: 2.7537 - 80ms/epoch - 80ms/step\n",
      "Epoch 142/2000\n",
      "1/1 - 0s - loss: 2.6614 - val_loss: 2.7509 - 83ms/epoch - 83ms/step\n",
      "Epoch 143/2000\n",
      "1/1 - 0s - loss: 2.6561 - val_loss: 2.7537 - 75ms/epoch - 75ms/step\n",
      "Epoch 144/2000\n",
      "1/1 - 0s - loss: 2.6571 - val_loss: 2.7479 - 79ms/epoch - 79ms/step\n",
      "Epoch 145/2000\n",
      "1/1 - 0s - loss: 2.6528 - val_loss: 2.7472 - 83ms/epoch - 83ms/step\n",
      "Epoch 146/2000\n",
      "1/1 - 0s - loss: 2.6512 - val_loss: 2.7491 - 74ms/epoch - 74ms/step\n",
      "Epoch 147/2000\n",
      "1/1 - 0s - loss: 2.6503 - val_loss: 2.7444 - 83ms/epoch - 83ms/step\n",
      "Epoch 148/2000\n",
      "1/1 - 0s - loss: 2.6458 - val_loss: 2.7420 - 82ms/epoch - 82ms/step\n",
      "Epoch 149/2000\n",
      "1/1 - 0s - loss: 2.6447 - val_loss: 2.7407 - 80ms/epoch - 80ms/step\n",
      "Epoch 150/2000\n",
      "1/1 - 0s - loss: 2.6418 - val_loss: 2.7395 - 89ms/epoch - 89ms/step\n",
      "Epoch 151/2000\n",
      "1/1 - 0s - loss: 2.6398 - val_loss: 2.7365 - 84ms/epoch - 84ms/step\n",
      "Epoch 152/2000\n",
      "1/1 - 0s - loss: 2.6375 - val_loss: 2.7359 - 84ms/epoch - 84ms/step\n",
      "Epoch 153/2000\n",
      "1/1 - 0s - loss: 2.6358 - val_loss: 2.7339 - 82ms/epoch - 82ms/step\n",
      "Epoch 154/2000\n",
      "1/1 - 0s - loss: 2.6330 - val_loss: 2.7318 - 80ms/epoch - 80ms/step\n",
      "Epoch 155/2000\n",
      "1/1 - 0s - loss: 2.6312 - val_loss: 2.7301 - 83ms/epoch - 83ms/step\n",
      "Epoch 156/2000\n",
      "1/1 - 0s - loss: 2.6286 - val_loss: 2.7284 - 79ms/epoch - 79ms/step\n",
      "Epoch 157/2000\n",
      "1/1 - 0s - loss: 2.6266 - val_loss: 2.7259 - 84ms/epoch - 84ms/step\n",
      "Epoch 158/2000\n",
      "1/1 - 0s - loss: 2.6242 - val_loss: 2.7272 - 74ms/epoch - 74ms/step\n",
      "Epoch 159/2000\n",
      "1/1 - 0s - loss: 2.6229 - val_loss: 2.7246 - 80ms/epoch - 80ms/step\n",
      "Epoch 160/2000\n",
      "1/1 - 0s - loss: 2.6214 - val_loss: 2.7251 - 77ms/epoch - 77ms/step\n",
      "Epoch 161/2000\n",
      "1/1 - 0s - loss: 2.6183 - val_loss: 2.7223 - 79ms/epoch - 79ms/step\n",
      "Epoch 162/2000\n",
      "1/1 - 0s - loss: 2.6163 - val_loss: 2.7193 - 83ms/epoch - 83ms/step\n",
      "Epoch 163/2000\n",
      "1/1 - 0s - loss: 2.6133 - val_loss: 2.7182 - 82ms/epoch - 82ms/step\n",
      "Epoch 164/2000\n",
      "1/1 - 0s - loss: 2.6115 - val_loss: 2.7163 - 80ms/epoch - 80ms/step\n",
      "Epoch 165/2000\n",
      "1/1 - 0s - loss: 2.6094 - val_loss: 2.7178 - 76ms/epoch - 76ms/step\n",
      "Epoch 166/2000\n",
      "1/1 - 0s - loss: 2.6086 - val_loss: 2.7135 - 80ms/epoch - 80ms/step\n",
      "Epoch 167/2000\n",
      "1/1 - 0s - loss: 2.6107 - val_loss: 2.7174 - 75ms/epoch - 75ms/step\n",
      "Epoch 168/2000\n",
      "1/1 - 0s - loss: 2.6081 - val_loss: 2.7125 - 83ms/epoch - 83ms/step\n",
      "Epoch 169/2000\n",
      "1/1 - 0s - loss: 2.6039 - val_loss: 2.7101 - 79ms/epoch - 79ms/step\n",
      "Epoch 170/2000\n",
      "1/1 - 0s - loss: 2.6009 - val_loss: 2.7103 - 77ms/epoch - 77ms/step\n",
      "Epoch 171/2000\n",
      "1/1 - 0s - loss: 2.6002 - val_loss: 2.7035 - 82ms/epoch - 82ms/step\n",
      "Epoch 172/2000\n",
      "1/1 - 0s - loss: 2.5975 - val_loss: 2.7065 - 74ms/epoch - 74ms/step\n",
      "Epoch 173/2000\n",
      "1/1 - 0s - loss: 2.5954 - val_loss: 2.7031 - 82ms/epoch - 82ms/step\n",
      "Epoch 174/2000\n",
      "1/1 - 0s - loss: 2.5923 - val_loss: 2.7028 - 91ms/epoch - 91ms/step\n",
      "Epoch 175/2000\n",
      "1/1 - 0s - loss: 2.5907 - val_loss: 2.7019 - 83ms/epoch - 83ms/step\n",
      "Epoch 176/2000\n",
      "1/1 - 0s - loss: 2.5872 - val_loss: 2.7005 - 81ms/epoch - 81ms/step\n",
      "Epoch 177/2000\n",
      "1/1 - 0s - loss: 2.5866 - val_loss: 2.6959 - 82ms/epoch - 82ms/step\n",
      "Epoch 178/2000\n",
      "1/1 - 0s - loss: 2.5842 - val_loss: 2.6945 - 83ms/epoch - 83ms/step\n",
      "Epoch 179/2000\n",
      "1/1 - 0s - loss: 2.5809 - val_loss: 2.6958 - 73ms/epoch - 73ms/step\n",
      "Epoch 180/2000\n",
      "1/1 - 0s - loss: 2.5798 - val_loss: 2.6939 - 83ms/epoch - 83ms/step\n",
      "Epoch 181/2000\n",
      "1/1 - 0s - loss: 2.5778 - val_loss: 2.6889 - 80ms/epoch - 80ms/step\n",
      "Epoch 182/2000\n",
      "1/1 - 0s - loss: 2.5764 - val_loss: 2.6909 - 74ms/epoch - 74ms/step\n",
      "Epoch 183/2000\n",
      "1/1 - 0s - loss: 2.5750 - val_loss: 2.6897 - 77ms/epoch - 77ms/step\n",
      "Epoch 184/2000\n",
      "1/1 - 0s - loss: 2.5739 - val_loss: 2.6957 - 73ms/epoch - 73ms/step\n",
      "Epoch 185/2000\n",
      "1/1 - 0s - loss: 2.5779 - val_loss: 2.6858 - 83ms/epoch - 83ms/step\n",
      "Epoch 186/2000\n",
      "1/1 - 0s - loss: 2.5775 - val_loss: 2.6845 - 81ms/epoch - 81ms/step\n",
      "Epoch 187/2000\n",
      "1/1 - 0s - loss: 2.5730 - val_loss: 2.6825 - 83ms/epoch - 83ms/step\n",
      "Epoch 188/2000\n",
      "1/1 - 0s - loss: 2.5656 - val_loss: 2.6870 - 81ms/epoch - 81ms/step\n",
      "Epoch 189/2000\n",
      "1/1 - 0s - loss: 2.5690 - val_loss: 2.6818 - 79ms/epoch - 79ms/step\n",
      "Epoch 190/2000\n",
      "1/1 - 0s - loss: 2.5613 - val_loss: 2.6808 - 83ms/epoch - 83ms/step\n",
      "Epoch 191/2000\n",
      "1/1 - 0s - loss: 2.5619 - val_loss: 2.6759 - 82ms/epoch - 82ms/step\n",
      "Epoch 192/2000\n",
      "1/1 - 0s - loss: 2.5581 - val_loss: 2.6757 - 81ms/epoch - 81ms/step\n",
      "Epoch 193/2000\n",
      "1/1 - 0s - loss: 2.5574 - val_loss: 2.6721 - 84ms/epoch - 84ms/step\n",
      "Epoch 194/2000\n",
      "1/1 - 0s - loss: 2.5525 - val_loss: 2.6726 - 72ms/epoch - 72ms/step\n",
      "Epoch 195/2000\n",
      "1/1 - 0s - loss: 2.5541 - val_loss: 2.6757 - 76ms/epoch - 76ms/step\n",
      "Epoch 196/2000\n",
      "1/1 - 0s - loss: 2.5527 - val_loss: 2.6706 - 82ms/epoch - 82ms/step\n",
      "Epoch 197/2000\n",
      "1/1 - 0s - loss: 2.5523 - val_loss: 2.6651 - 80ms/epoch - 80ms/step\n",
      "Epoch 198/2000\n",
      "1/1 - 0s - loss: 2.5457 - val_loss: 2.6667 - 76ms/epoch - 76ms/step\n",
      "Epoch 199/2000\n",
      "1/1 - 0s - loss: 2.5464 - val_loss: 2.6638 - 79ms/epoch - 79ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 200/2000\n",
      "1/1 - 0s - loss: 2.5451 - val_loss: 2.6651 - 75ms/epoch - 75ms/step\n",
      "Epoch 201/2000\n",
      "1/1 - 0s - loss: 2.5411 - val_loss: 2.6617 - 83ms/epoch - 83ms/step\n",
      "Epoch 202/2000\n",
      "1/1 - 0s - loss: 2.5384 - val_loss: 2.6605 - 80ms/epoch - 80ms/step\n",
      "Epoch 203/2000\n",
      "1/1 - 0s - loss: 2.5382 - val_loss: 2.6616 - 76ms/epoch - 76ms/step\n",
      "Epoch 204/2000\n",
      "1/1 - 0s - loss: 2.5359 - val_loss: 2.6592 - 81ms/epoch - 81ms/step\n",
      "Epoch 205/2000\n",
      "1/1 - 0s - loss: 2.5343 - val_loss: 2.6572 - 81ms/epoch - 81ms/step\n",
      "Epoch 206/2000\n",
      "1/1 - 0s - loss: 2.5306 - val_loss: 2.6521 - 83ms/epoch - 83ms/step\n",
      "Epoch 207/2000\n",
      "1/1 - 0s - loss: 2.5281 - val_loss: 2.6497 - 79ms/epoch - 79ms/step\n",
      "Epoch 208/2000\n",
      "1/1 - 0s - loss: 2.5262 - val_loss: 2.6518 - 76ms/epoch - 76ms/step\n",
      "Epoch 209/2000\n",
      "1/1 - 0s - loss: 2.5249 - val_loss: 2.6490 - 82ms/epoch - 82ms/step\n",
      "Epoch 210/2000\n",
      "1/1 - 0s - loss: 2.5228 - val_loss: 2.6492 - 72ms/epoch - 72ms/step\n",
      "Epoch 211/2000\n",
      "1/1 - 0s - loss: 2.5210 - val_loss: 2.6462 - 83ms/epoch - 83ms/step\n",
      "Epoch 212/2000\n",
      "1/1 - 0s - loss: 2.5191 - val_loss: 2.6460 - 79ms/epoch - 79ms/step\n",
      "Epoch 213/2000\n",
      "1/1 - 0s - loss: 2.5160 - val_loss: 2.6430 - 82ms/epoch - 82ms/step\n",
      "Epoch 214/2000\n",
      "1/1 - 0s - loss: 2.5141 - val_loss: 2.6399 - 83ms/epoch - 83ms/step\n",
      "Epoch 215/2000\n",
      "1/1 - 0s - loss: 2.5110 - val_loss: 2.6394 - 79ms/epoch - 79ms/step\n",
      "Epoch 216/2000\n",
      "1/1 - 0s - loss: 2.5104 - val_loss: 2.6373 - 84ms/epoch - 84ms/step\n",
      "Epoch 217/2000\n",
      "1/1 - 0s - loss: 2.5093 - val_loss: 2.6430 - 73ms/epoch - 73ms/step\n",
      "Epoch 218/2000\n",
      "1/1 - 0s - loss: 2.5103 - val_loss: 2.6454 - 75ms/epoch - 75ms/step\n",
      "Epoch 219/2000\n",
      "1/1 - 0s - loss: 2.5193 - val_loss: 2.6517 - 78ms/epoch - 78ms/step\n",
      "Epoch 220/2000\n",
      "1/1 - 0s - loss: 2.5195 - val_loss: 2.6387 - 73ms/epoch - 73ms/step\n",
      "Epoch 221/2000\n",
      "1/1 - 0s - loss: 2.5173 - val_loss: 2.6294 - 83ms/epoch - 83ms/step\n",
      "Epoch 222/2000\n",
      "1/1 - 0s - loss: 2.5041 - val_loss: 2.6375 - 75ms/epoch - 75ms/step\n",
      "Epoch 223/2000\n",
      "1/1 - 0s - loss: 2.5076 - val_loss: 2.6391 - 73ms/epoch - 73ms/step\n",
      "Epoch 224/2000\n",
      "1/1 - 0s - loss: 2.5122 - val_loss: 2.6361 - 76ms/epoch - 76ms/step\n",
      "Epoch 225/2000\n",
      "1/1 - 0s - loss: 2.4997 - val_loss: 2.6316 - 73ms/epoch - 73ms/step\n",
      "Epoch 226/2000\n",
      "1/1 - 0s - loss: 2.4986 - val_loss: 2.6290 - 81ms/epoch - 81ms/step\n",
      "Epoch 227/2000\n",
      "1/1 - 0s - loss: 2.4999 - val_loss: 2.6247 - 83ms/epoch - 83ms/step\n",
      "Epoch 228/2000\n",
      "1/1 - 0s - loss: 2.4918 - val_loss: 2.6303 - 72ms/epoch - 72ms/step\n",
      "Epoch 229/2000\n",
      "1/1 - 0s - loss: 2.4947 - val_loss: 2.6271 - 76ms/epoch - 76ms/step\n",
      "Epoch 230/2000\n",
      "1/1 - 0s - loss: 2.4961 - val_loss: 2.6216 - 83ms/epoch - 83ms/step\n",
      "Epoch 231/2000\n",
      "1/1 - 0s - loss: 2.4868 - val_loss: 2.6257 - 73ms/epoch - 73ms/step\n",
      "Epoch 232/2000\n",
      "1/1 - 0s - loss: 2.4896 - val_loss: 2.6212 - 83ms/epoch - 83ms/step\n",
      "Epoch 233/2000\n",
      "1/1 - 0s - loss: 2.4899 - val_loss: 2.6143 - 80ms/epoch - 80ms/step\n",
      "Epoch 234/2000\n",
      "1/1 - 0s - loss: 2.4804 - val_loss: 2.6200 - 76ms/epoch - 76ms/step\n",
      "Epoch 235/2000\n",
      "1/1 - 0s - loss: 2.4838 - val_loss: 2.6176 - 76ms/epoch - 76ms/step\n",
      "Epoch 236/2000\n",
      "1/1 - 0s - loss: 2.4848 - val_loss: 2.6138 - 79ms/epoch - 79ms/step\n",
      "Epoch 237/2000\n",
      "1/1 - 0s - loss: 2.4763 - val_loss: 2.6129 - 83ms/epoch - 83ms/step\n",
      "Epoch 238/2000\n",
      "1/1 - 0s - loss: 2.4754 - val_loss: 2.6110 - 81ms/epoch - 81ms/step\n",
      "Epoch 239/2000\n",
      "1/1 - 0s - loss: 2.4755 - val_loss: 2.6104 - 82ms/epoch - 82ms/step\n",
      "Epoch 240/2000\n",
      "1/1 - 0s - loss: 2.4703 - val_loss: 2.6102 - 83ms/epoch - 83ms/step\n",
      "Epoch 241/2000\n",
      "1/1 - 0s - loss: 2.4693 - val_loss: 2.6058 - 79ms/epoch - 79ms/step\n",
      "Epoch 242/2000\n",
      "1/1 - 0s - loss: 2.4696 - val_loss: 2.6063 - 77ms/epoch - 77ms/step\n",
      "Epoch 243/2000\n",
      "1/1 - 0s - loss: 2.4668 - val_loss: 2.6038 - 81ms/epoch - 81ms/step\n",
      "Epoch 244/2000\n",
      "1/1 - 0s - loss: 2.4634 - val_loss: 2.6039 - 74ms/epoch - 74ms/step\n",
      "Epoch 245/2000\n",
      "1/1 - 0s - loss: 2.4615 - val_loss: 2.6015 - 83ms/epoch - 83ms/step\n",
      "Epoch 246/2000\n",
      "1/1 - 0s - loss: 2.4594 - val_loss: 2.5973 - 79ms/epoch - 79ms/step\n",
      "Epoch 247/2000\n",
      "1/1 - 0s - loss: 2.4578 - val_loss: 2.5988 - 76ms/epoch - 76ms/step\n",
      "Epoch 248/2000\n",
      "1/1 - 0s - loss: 2.4574 - val_loss: 2.5976 - 74ms/epoch - 74ms/step\n",
      "Epoch 249/2000\n",
      "1/1 - 0s - loss: 2.4605 - val_loss: 2.6103 - 84ms/epoch - 84ms/step\n",
      "Epoch 250/2000\n",
      "1/1 - 0s - loss: 2.4658 - val_loss: 2.6081 - 78ms/epoch - 78ms/step\n",
      "Epoch 251/2000\n",
      "1/1 - 0s - loss: 2.4735 - val_loss: 2.5952 - 80ms/epoch - 80ms/step\n",
      "Epoch 252/2000\n",
      "1/1 - 0s - loss: 2.4565 - val_loss: 2.6016 - 76ms/epoch - 76ms/step\n",
      "Epoch 253/2000\n",
      "1/1 - 0s - loss: 2.4597 - val_loss: 2.6108 - 75ms/epoch - 75ms/step\n",
      "Epoch 254/2000\n",
      "1/1 - 0s - loss: 2.4741 - val_loss: 2.5972 - 74ms/epoch - 74ms/step\n",
      "Epoch 255/2000\n",
      "1/1 - 0s - loss: 2.4511 - val_loss: 2.5985 - 76ms/epoch - 76ms/step\n",
      "Epoch 256/2000\n",
      "1/1 - 0s - loss: 2.4564 - val_loss: 2.5918 - 80ms/epoch - 80ms/step\n",
      "Epoch 257/2000\n",
      "1/1 - 0s - loss: 2.4532 - val_loss: 2.5921 - 74ms/epoch - 74ms/step\n",
      "Epoch 258/2000\n",
      "1/1 - 0s - loss: 2.4512 - val_loss: 2.5947 - 76ms/epoch - 76ms/step\n",
      "Epoch 259/2000\n",
      "1/1 - 0s - loss: 2.4480 - val_loss: 2.5888 - 81ms/epoch - 81ms/step\n",
      "Epoch 260/2000\n",
      "1/1 - 0s - loss: 2.4435 - val_loss: 2.5865 - 82ms/epoch - 82ms/step\n",
      "Epoch 261/2000\n",
      "1/1 - 0s - loss: 2.4424 - val_loss: 2.5901 - 75ms/epoch - 75ms/step\n",
      "Epoch 262/2000\n",
      "1/1 - 0s - loss: 2.4426 - val_loss: 2.5831 - 79ms/epoch - 79ms/step\n",
      "Epoch 263/2000\n",
      "1/1 - 0s - loss: 2.4377 - val_loss: 2.5857 - 76ms/epoch - 76ms/step\n",
      "Epoch 264/2000\n",
      "1/1 - 0s - loss: 2.4377 - val_loss: 2.5867 - 73ms/epoch - 73ms/step\n",
      "Epoch 265/2000\n",
      "1/1 - 0s - loss: 2.4351 - val_loss: 2.5789 - 82ms/epoch - 82ms/step\n",
      "Epoch 266/2000\n",
      "1/1 - 0s - loss: 2.4309 - val_loss: 2.5756 - 82ms/epoch - 82ms/step\n",
      "Epoch 267/2000\n",
      "1/1 - 0s - loss: 2.4308 - val_loss: 2.5736 - 79ms/epoch - 79ms/step\n",
      "Epoch 268/2000\n",
      "1/1 - 0s - loss: 2.4272 - val_loss: 2.5757 - 76ms/epoch - 76ms/step\n",
      "Epoch 269/2000\n",
      "1/1 - 0s - loss: 2.4263 - val_loss: 2.5785 - 74ms/epoch - 74ms/step\n",
      "Epoch 270/2000\n",
      "1/1 - 0s - loss: 2.4254 - val_loss: 2.5759 - 73ms/epoch - 73ms/step\n",
      "Epoch 271/2000\n",
      "1/1 - 0s - loss: 2.4227 - val_loss: 2.5734 - 83ms/epoch - 83ms/step\n",
      "Epoch 272/2000\n",
      "1/1 - 0s - loss: 2.4193 - val_loss: 2.5692 - 80ms/epoch - 80ms/step\n",
      "Epoch 273/2000\n",
      "1/1 - 0s - loss: 2.4185 - val_loss: 2.5663 - 83ms/epoch - 83ms/step\n",
      "Epoch 274/2000\n",
      "1/1 - 0s - loss: 2.4165 - val_loss: 2.5679 - 75ms/epoch - 75ms/step\n",
      "Epoch 275/2000\n",
      "1/1 - 0s - loss: 2.4156 - val_loss: 2.5686 - 72ms/epoch - 72ms/step\n",
      "Epoch 276/2000\n",
      "1/1 - 0s - loss: 2.4193 - val_loss: 2.5747 - 76ms/epoch - 76ms/step\n",
      "Epoch 277/2000\n",
      "1/1 - 0s - loss: 2.4207 - val_loss: 2.5774 - 74ms/epoch - 74ms/step\n",
      "Epoch 278/2000\n",
      "1/1 - 0s - loss: 2.4292 - val_loss: 2.5751 - 95ms/epoch - 95ms/step\n"
     ]
    }
   ],
   "source": [
    "# Model width and number of attention heads for this run.\n",
    "embed_dim, num_head = 10, 2\n",
    "\n",
    "# Integer ids -> learned embeddings of size embed_dim.\n",
    "token_ids = layers.Input((x_masked_train.shape[1],), dtype=tf.int64)\n",
    "hidden = layers.Embedding(N, embed_dim, name=\"word_embedding\")(token_ids)\n",
    "\n",
    "# Five stacked bert_module blocks; each one self-attends over the previous block's output.\n",
    "for block_idx in range(5):\n",
    "    hidden = bert_module(hidden, hidden, hidden, embed_dim, num_head, block_idx)\n",
    "\n",
    "# Drop the last time step, then map every remaining step to a softmax over N classes.\n",
    "hidden = keras.layers.Lambda(lambda seq: seq[:, :-1, :], name='slice')(hidden)\n",
    "token_probs = layers.Dense(N, name=\"mlm_cls\", activation=\"softmax\")(hidden)\n",
    "\n",
    "mlm_model = keras.Model(inputs=token_ids, outputs=token_probs)\n",
    "mlm_model.compile(loss='sparse_categorical_crossentropy', optimizer=Adam())\n",
    "\n",
    "# Halt once val_loss has gone 5 epochs without improving and restore the best weights.\n",
    "early_stopping = keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)\n",
    "history = mlm_model.fit(\n",
    "    x_masked_train,\n",
    "    y_masked_labels_train,\n",
    "    validation_split=0.5,\n",
    "    callbacks=[early_stopping],\n",
    "    epochs=2000,\n",
    "    batch_size=5000,\n",
    "    verbose=2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "aa33aaff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate last-token prediction on a random sample of test sequences.\n",
    "#   acc  : 1 when the arg-max of the last output step equals the sequence's last token, else 0\n",
    "#   prob : probability the model assigns to that last token\n",
    "acc = []\n",
    "prob = []\n",
    "\n",
    "# replace=False cannot draw more rows than exist, so cap the sample at the test-set size.\n",
    "sample_size = min(1000, x_masked_test.shape[0])\n",
    "x_test_subset = x_masked_test[np.random.choice(x_masked_test.shape[0], size=sample_size, replace=False)]\n",
    "\n",
    "# One batched forward pass for the whole sample instead of building a new\n",
    "# backend function for every sentence; keep only the last output step.\n",
    "last_step_probs = mlm_model.predict(x_test_subset, verbose=0)[:, -1, :]\n",
    "\n",
    "for sentence_number, token_probs in zip(x_test_subset, last_step_probs):\n",
    "    acc.append(1 if token_probs.argmax() == sentence_number[-1] else 0)\n",
    "    prob.append(token_probs[sentence_number[-1]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "fe4d2103",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.882, 0.35875502)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# (mean of the 0/1 hit list, mean probability given to the true last token) over the sample.\n",
    "(np.mean(acc), np.mean(prob))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "0005b730",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2000\n",
      "1/1 - 5s - loss: 3.7363 - val_loss: 3.2841 - 5s/epoch - 5s/step\n",
      "Epoch 2/2000\n",
      "1/1 - 0s - loss: 3.2838 - val_loss: 3.1190 - 307ms/epoch - 307ms/step\n",
      "Epoch 3/2000\n",
      "1/1 - 0s - loss: 3.0946 - val_loss: 3.0484 - 299ms/epoch - 299ms/step\n",
      "Epoch 4/2000\n",
      "1/1 - 0s - loss: 3.0104 - val_loss: 3.0134 - 296ms/epoch - 296ms/step\n",
      "Epoch 5/2000\n",
      "1/1 - 0s - loss: 2.9706 - val_loss: 2.9909 - 275ms/epoch - 275ms/step\n",
      "Epoch 6/2000\n",
      "1/1 - 0s - loss: 2.9489 - val_loss: 2.9801 - 284ms/epoch - 284ms/step\n",
      "Epoch 7/2000\n",
      "1/1 - 0s - loss: 2.9388 - val_loss: 2.9659 - 270ms/epoch - 270ms/step\n",
      "Epoch 8/2000\n",
      "1/1 - 0s - loss: 2.9226 - val_loss: 2.9514 - 274ms/epoch - 274ms/step\n",
      "Epoch 9/2000\n",
      "1/1 - 0s - loss: 2.9036 - val_loss: 2.9434 - 268ms/epoch - 268ms/step\n",
      "Epoch 10/2000\n",
      "1/1 - 0s - loss: 2.8898 - val_loss: 2.9388 - 273ms/epoch - 273ms/step\n",
      "Epoch 11/2000\n",
      "1/1 - 0s - loss: 2.8788 - val_loss: 2.9307 - 279ms/epoch - 279ms/step\n",
      "Epoch 12/2000\n",
      "1/1 - 0s - loss: 2.8632 - val_loss: 2.9181 - 266ms/epoch - 266ms/step\n",
      "Epoch 13/2000\n",
      "1/1 - 0s - loss: 2.8425 - val_loss: 2.9044 - 270ms/epoch - 270ms/step\n",
      "Epoch 14/2000\n",
      "1/1 - 0s - loss: 2.8198 - val_loss: 2.8912 - 271ms/epoch - 271ms/step\n",
      "Epoch 15/2000\n",
      "1/1 - 0s - loss: 2.7971 - val_loss: 2.8789 - 275ms/epoch - 275ms/step\n",
      "Epoch 16/2000\n",
      "1/1 - 0s - loss: 2.7741 - val_loss: 2.8672 - 269ms/epoch - 269ms/step\n",
      "Epoch 17/2000\n",
      "1/1 - 0s - loss: 2.7507 - val_loss: 2.8566 - 277ms/epoch - 277ms/step\n",
      "Epoch 18/2000\n",
      "1/1 - 0s - loss: 2.7279 - val_loss: 2.8471 - 264ms/epoch - 264ms/step\n",
      "Epoch 19/2000\n",
      "1/1 - 0s - loss: 2.7065 - val_loss: 2.8378 - 256ms/epoch - 256ms/step\n",
      "Epoch 20/2000\n",
      "1/1 - 0s - loss: 2.6865 - val_loss: 2.8278 - 261ms/epoch - 261ms/step\n",
      "Epoch 21/2000\n",
      "1/1 - 0s - loss: 2.6674 - val_loss: 2.8166 - 274ms/epoch - 274ms/step\n",
      "Epoch 22/2000\n",
      "1/1 - 0s - loss: 2.6486 - val_loss: 2.8035 - 261ms/epoch - 261ms/step\n",
      "Epoch 23/2000\n",
      "1/1 - 0s - loss: 2.6291 - val_loss: 2.7894 - 263ms/epoch - 263ms/step\n",
      "Epoch 24/2000\n",
      "1/1 - 0s - loss: 2.6095 - val_loss: 2.7743 - 261ms/epoch - 261ms/step\n",
      "Epoch 25/2000\n",
      "1/1 - 0s - loss: 2.5892 - val_loss: 2.7585 - 264ms/epoch - 264ms/step\n",
      "Epoch 26/2000\n",
      "1/1 - 0s - loss: 2.5683 - val_loss: 2.7437 - 260ms/epoch - 260ms/step\n",
      "Epoch 27/2000\n",
      "1/1 - 0s - loss: 2.5483 - val_loss: 2.7300 - 289ms/epoch - 289ms/step\n",
      "Epoch 28/2000\n",
      "1/1 - 0s - loss: 2.5293 - val_loss: 2.7164 - 270ms/epoch - 270ms/step\n",
      "Epoch 29/2000\n",
      "1/1 - 0s - loss: 2.5099 - val_loss: 2.7033 - 263ms/epoch - 263ms/step\n",
      "Epoch 30/2000\n",
      "1/1 - 0s - loss: 2.4902 - val_loss: 2.6911 - 267ms/epoch - 267ms/step\n",
      "Epoch 31/2000\n",
      "1/1 - 0s - loss: 2.4706 - val_loss: 2.6799 - 255ms/epoch - 255ms/step\n",
      "Epoch 32/2000\n",
      "1/1 - 0s - loss: 2.4513 - val_loss: 2.6696 - 268ms/epoch - 268ms/step\n",
      "Epoch 33/2000\n",
      "1/1 - 0s - loss: 2.4322 - val_loss: 2.6602 - 261ms/epoch - 261ms/step\n",
      "Epoch 34/2000\n",
      "1/1 - 0s - loss: 2.4135 - val_loss: 2.6513 - 263ms/epoch - 263ms/step\n",
      "Epoch 35/2000\n",
      "1/1 - 0s - loss: 2.3951 - val_loss: 2.6428 - 260ms/epoch - 260ms/step\n",
      "Epoch 36/2000\n",
      "1/1 - 0s - loss: 2.3767 - val_loss: 2.6340 - 266ms/epoch - 266ms/step\n",
      "Epoch 37/2000\n",
      "1/1 - 0s - loss: 2.3583 - val_loss: 2.6245 - 263ms/epoch - 263ms/step\n",
      "Epoch 38/2000\n",
      "1/1 - 0s - loss: 2.3395 - val_loss: 2.6152 - 269ms/epoch - 269ms/step\n",
      "Epoch 39/2000\n",
      "1/1 - 0s - loss: 2.3211 - val_loss: 2.6057 - 263ms/epoch - 263ms/step\n",
      "Epoch 40/2000\n",
      "1/1 - 0s - loss: 2.3029 - val_loss: 2.5966 - 260ms/epoch - 260ms/step\n",
      "Epoch 41/2000\n",
      "1/1 - 0s - loss: 2.2849 - val_loss: 2.5883 - 273ms/epoch - 273ms/step\n",
      "Epoch 42/2000\n",
      "1/1 - 0s - loss: 2.2671 - val_loss: 2.5806 - 264ms/epoch - 264ms/step\n",
      "Epoch 43/2000\n",
      "1/1 - 0s - loss: 2.2488 - val_loss: 2.5735 - 264ms/epoch - 264ms/step\n",
      "Epoch 44/2000\n",
      "1/1 - 0s - loss: 2.2303 - val_loss: 2.5670 - 262ms/epoch - 262ms/step\n",
      "Epoch 45/2000\n",
      "1/1 - 0s - loss: 2.2118 - val_loss: 2.5604 - 260ms/epoch - 260ms/step\n",
      "Epoch 46/2000\n",
      "1/1 - 0s - loss: 2.1930 - val_loss: 2.5538 - 258ms/epoch - 258ms/step\n",
      "Epoch 47/2000\n",
      "1/1 - 0s - loss: 2.1740 - val_loss: 2.5473 - 262ms/epoch - 262ms/step\n",
      "Epoch 48/2000\n",
      "1/1 - 0s - loss: 2.1547 - val_loss: 2.5407 - 270ms/epoch - 270ms/step\n",
      "Epoch 49/2000\n",
      "1/1 - 0s - loss: 2.1349 - val_loss: 2.5342 - 270ms/epoch - 270ms/step\n",
      "Epoch 50/2000\n",
      "1/1 - 0s - loss: 2.1150 - val_loss: 2.5267 - 260ms/epoch - 260ms/step\n",
      "Epoch 51/2000\n",
      "1/1 - 0s - loss: 2.0947 - val_loss: 2.5201 - 265ms/epoch - 265ms/step\n",
      "Epoch 52/2000\n",
      "1/1 - 0s - loss: 2.0749 - val_loss: 2.5194 - 267ms/epoch - 267ms/step\n",
      "Epoch 53/2000\n",
      "1/1 - 0s - loss: 2.0634 - val_loss: 2.5454 - 256ms/epoch - 256ms/step\n",
      "Epoch 54/2000\n",
      "1/1 - 0s - loss: 2.0799 - val_loss: 2.5442 - 257ms/epoch - 257ms/step\n",
      "Epoch 55/2000\n",
      "1/1 - 0s - loss: 2.0727 - val_loss: 2.4991 - 253ms/epoch - 253ms/step\n",
      "Epoch 56/2000\n",
      "1/1 - 0s - loss: 2.0066 - val_loss: 2.5249 - 260ms/epoch - 260ms/step\n",
      "Epoch 57/2000\n",
      "1/1 - 0s - loss: 2.0224 - val_loss: 2.4882 - 270ms/epoch - 270ms/step\n",
      "Epoch 58/2000\n",
      "1/1 - 0s - loss: 1.9677 - val_loss: 2.5053 - 261ms/epoch - 261ms/step\n",
      "Epoch 59/2000\n",
      "1/1 - 0s - loss: 1.9758 - val_loss: 2.4804 - 263ms/epoch - 263ms/step\n",
      "Epoch 60/2000\n",
      "1/1 - 0s - loss: 1.9340 - val_loss: 2.4888 - 261ms/epoch - 261ms/step\n",
      "Epoch 61/2000\n",
      "1/1 - 0s - loss: 1.9280 - val_loss: 2.4763 - 266ms/epoch - 266ms/step\n",
      "Epoch 62/2000\n",
      "1/1 - 0s - loss: 1.9025 - val_loss: 2.4706 - 260ms/epoch - 260ms/step\n",
      "Epoch 63/2000\n",
      "1/1 - 0s - loss: 1.8854 - val_loss: 2.4713 - 263ms/epoch - 263ms/step\n",
      "Epoch 64/2000\n",
      "1/1 - 0s - loss: 1.8721 - val_loss: 2.4702 - 266ms/epoch - 266ms/step\n",
      "Epoch 65/2000\n",
      "1/1 - 0s - loss: 1.8513 - val_loss: 2.4833 - 255ms/epoch - 255ms/step\n",
      "Epoch 66/2000\n",
      "1/1 - 0s - loss: 1.8523 - val_loss: 2.5165 - 257ms/epoch - 257ms/step\n",
      "Epoch 67/2000\n",
      "1/1 - 0s - loss: 1.8712 - val_loss: 2.5022 - 256ms/epoch - 256ms/step\n",
      "Epoch 68/2000\n",
      "1/1 - 0s - loss: 1.8572 - val_loss: 2.4698 - 270ms/epoch - 270ms/step\n",
      "Epoch 69/2000\n",
      "1/1 - 0s - loss: 1.8038 - val_loss: 2.4925 - 253ms/epoch - 253ms/step\n",
      "Epoch 70/2000\n",
      "1/1 - 0s - loss: 1.8112 - val_loss: 2.4769 - 256ms/epoch - 256ms/step\n",
      "Epoch 71/2000\n",
      "1/1 - 0s - loss: 1.7825 - val_loss: 2.4794 - 254ms/epoch - 254ms/step\n",
      "Epoch 72/2000\n",
      "1/1 - 0s - loss: 1.7751 - val_loss: 2.4809 - 255ms/epoch - 255ms/step\n",
      "Epoch 73/2000\n",
      "1/1 - 0s - loss: 1.7559 - val_loss: 2.4931 - 278ms/epoch - 278ms/step\n"
     ]
    }
   ],
   "source": [
    "# Model width and number of attention heads for this run.\n",
    "embed_dim, num_head = 100, 2\n",
    "\n",
    "# Integer ids -> learned embeddings of size embed_dim.\n",
    "token_ids = layers.Input((x_masked_train.shape[1],), dtype=tf.int64)\n",
    "hidden = layers.Embedding(N, embed_dim, name=\"word_embedding\")(token_ids)\n",
    "\n",
    "# Five stacked bert_module blocks; each one self-attends over the previous block's output.\n",
    "for block_idx in range(5):\n",
    "    hidden = bert_module(hidden, hidden, hidden, embed_dim, num_head, block_idx)\n",
    "\n",
    "# Drop the last time step, then map every remaining step to a softmax over N classes.\n",
    "hidden = keras.layers.Lambda(lambda seq: seq[:, :-1, :], name='slice')(hidden)\n",
    "token_probs = layers.Dense(N, name=\"mlm_cls\", activation=\"softmax\")(hidden)\n",
    "\n",
    "mlm_model = keras.Model(inputs=token_ids, outputs=token_probs)\n",
    "mlm_model.compile(loss='sparse_categorical_crossentropy', optimizer=Adam())\n",
    "\n",
    "# Halt once val_loss has gone 5 epochs without improving and restore the best weights.\n",
    "early_stopping = keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)\n",
    "history = mlm_model.fit(\n",
    "    x_masked_train,\n",
    "    y_masked_labels_train,\n",
    "    validation_split=0.5,\n",
    "    callbacks=[early_stopping],\n",
    "    epochs=2000,\n",
    "    batch_size=5000,\n",
    "    verbose=2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "96533a42",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[14, 19,  1, 14],\n",
       "       [ 4,  1,  0,  4],\n",
       "       [14,  1,  5, 14],\n",
       "       ...,\n",
       "       [15,  2, 12, 15],\n",
       "       [17,  4, 14, 17],\n",
       "       [ 5, 16,  1,  5]])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the training inputs that are passed to mlm_model.fit.\n",
    "x_masked_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "d6f921e3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[19,  1, 14],\n",
       "       [ 1,  0,  4],\n",
       "       [ 1,  5, 14],\n",
       "       ...,\n",
       "       [ 2, 12, 15],\n",
       "       [ 4, 14, 17],\n",
       "       [16,  1,  5]])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the training targets that are passed to mlm_model.fit.\n",
    "# NOTE(review): in the rows shown, each target row looks like the input row without its first token - confirm where it is built.\n",
    "y_masked_labels_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "beb07cf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate last-token prediction on a random sample of test sequences.\n",
    "#   acc  : 1 when the arg-max of the last output step equals the sequence's last token, else 0\n",
    "#   prob : probability the model assigns to that last token\n",
    "acc = []\n",
    "prob = []\n",
    "\n",
    "# replace=False cannot draw more rows than exist, so cap the sample at the test-set size.\n",
    "sample_size = min(1000, x_masked_test.shape[0])\n",
    "x_test_subset = x_masked_test[np.random.choice(x_masked_test.shape[0], size=sample_size, replace=False)]\n",
    "\n",
    "# One batched forward pass for the whole sample instead of building a new\n",
    "# backend function for every sentence; keep only the last output step.\n",
    "last_step_probs = mlm_model.predict(x_test_subset, verbose=0)[:, -1, :]\n",
    "\n",
    "for sentence_number, token_probs in zip(x_test_subset, last_step_probs):\n",
    "    acc.append(1 if token_probs.argmax() == sentence_number[-1] else 0)\n",
    "    prob.append(token_probs[sentence_number[-1]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "4ef61e71",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.967, 0.73480076)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# (mean of the 0/1 hit list, mean probability given to the true last token) over the sample.\n",
    "(np.mean(acc), np.mean(prob))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "9fbf1e0b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 5,  0,  2,  5],\n",
       "       [11,  1,  3, 11],\n",
       "       [16, 11,  3, 16],\n",
       "       ...,\n",
       "       [11, 17,  7, 11],\n",
       "       [14, 10,  8, 14],\n",
       "       [ 7, 15,  6,  7]])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the randomly sampled test rows used for the evaluation.\n",
    "x_test_subset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "331a0d43",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[1.5243081e-02, 1.6439389e-03, 1.9486870e-04, 1.3398707e-02,\n",
       "        1.5217075e-01, 7.3369056e-01, 4.5488998e-03, 3.8815632e-03,\n",
       "        3.1420204e-04, 1.2989054e-02, 2.4941228e-03, 1.7236115e-03,\n",
       "        2.4614645e-02, 5.1256698e-03, 8.7624453e-03, 2.0610983e-03,\n",
       "        1.5556569e-03, 3.9935755e-03, 5.7152491e-03, 5.8782739e-03]],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Probe the model with one hand-written sequence and display the probabilities\n",
    "# it outputs at the last step.\n",
    "# reshape(1, -1) takes the length from the sequence itself, so this cell does not\n",
    "# depend on a loop variable left over from another cell.\n",
    "probe = np.array([5, 0, 2, 5]).reshape(1, -1)\n",
    "temp = mlm_model.predict(probe, verbose=0)\n",
    "temp = temp[:,-1,:]\n",
    "temp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "399e0aa1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[8.7793839e-01, 3.9893654e-03, 3.1362392e-04, 2.5042465e-03,\n",
       "        3.1550489e-03, 1.3760317e-04, 4.9131678e-04, 2.6360264e-03,\n",
       "        1.9013104e-03, 3.9221297e-04, 1.6857702e-02, 4.8561595e-02,\n",
       "        4.8413621e-03, 1.7939642e-03, 4.3755607e-03, 1.3562806e-03,\n",
       "        6.2582488e-03, 2.1750858e-02, 1.2489436e-04, 6.2040123e-04]],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Probe the model with one hand-written sequence and display the probabilities\n",
    "# it outputs at the last step.\n",
    "# reshape(1, -1) takes the length from the sequence itself, so this cell does not\n",
    "# depend on a loop variable left over from another cell.\n",
    "probe = np.array([0, 5, 2, 0]).reshape(1, -1)\n",
    "temp = mlm_model.predict(probe, verbose=0)\n",
    "temp = temp[:,-1,:]\n",
    "temp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aea33c5b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "841310c9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b41c8a01",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
