{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "674d5c38",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy.random as npr\n",
    "import random\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "from tensorflow.keras import backend as K\n",
    "from keras.optimizers import Adam\n",
    "from keras_nlp.layers import PositionEmbedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5a29d33b",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 428\n",
    "\n",
    "np.random.seed(seed)\n",
    "tf.random.set_seed(seed)\n",
    "random.seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d96c3193",
   "metadata": {},
   "outputs": [],
   "source": [
    "def bert_module(query, key, value, embed_dim, num_head, i):\n",
    "    \n",
    "    # Multi headed self-attention\n",
    "    attention_output = layers.MultiHeadAttention(\n",
    "        num_heads=num_head,\n",
    "        key_dim=embed_dim // num_head,\n",
    "        name=\"encoder_{}/multiheadattention\".format(i)\n",
    "    )(query, key, value, use_causal_mask=True)\n",
    "    \n",
    "    # Add & Normalize\n",
    "    attention_output = layers.Add()([query, attention_output])  # Skip Connection\n",
    "    attention_output = layers.LayerNormalization(epsilon=1e-6)(attention_output)\n",
    "    \n",
    "    # Feedforward network\n",
    "    ff_net = keras.models.Sequential([\n",
    "        layers.Dense(2 * embed_dim, activation='relu', name=\"encoder_{}/ffn_dense_1\".format(i)),\n",
    "        layers.Dense(embed_dim, name=\"encoder_{}/ffn_dense_2\".format(i)),\n",
    "    ])\n",
    "\n",
    "    # Apply Feedforward network\n",
    "    ffn_output = ff_net(attention_output)\n",
    "\n",
    "    # Add & Normalize\n",
    "    ffn_output = layers.Add()([attention_output, ffn_output])  # Skip Connection\n",
    "    ffn_output = layers.LayerNormalization(epsilon=1e-6)(ffn_output)\n",
    "    \n",
    "    return ffn_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fd7afdea",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_sinusoidal_embeddings(sequence_length, embedding_dim):\n",
    "    position_enc = np.array([\n",
    "        [pos / np.power(10000, 2. * i / embedding_dim) for i in range(embedding_dim)]\n",
    "        if pos != 0 else np.zeros(embedding_dim)\n",
    "        for pos in range(sequence_length)\n",
    "    ])\n",
    "    position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i\n",
    "    position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # dim 2i+1\n",
    "    return tf.cast(position_enc, dtype=tf.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "95bab1fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 20 # vocab_size\n",
    "\n",
    "vocabs = ['word_' + str(i) for i in range(N)]\n",
    "\n",
    "vocab_map = {}\n",
    "for i in range(len(vocabs)):\n",
    "    vocab_map[vocabs[i]] = i\n",
    "    \n",
    "pairs = []\n",
    "\n",
    "for i in vocabs:\n",
    "    for j in vocabs:\n",
    "        for k in vocabs:\n",
    "            if i != j and i != k and j != k:\n",
    "                pairs.append((i,j,k))\n",
    "            \n",
    "indicator = np.random.choice([0, 1], size=len(pairs), p=[0.5, 0.5])\n",
    "\n",
    "pairs_train = [pairs[i] for i in range(len(indicator)) if indicator[i] == 1]\n",
    "pairs_test = [pairs[i] for i in range(len(indicator)) if indicator[i] == 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4f71cf58",
   "metadata": {},
   "outputs": [],
   "source": [
    "sentences_train = []\n",
    "sentences_number_train = []\n",
    "sentences_test = []\n",
    "sentences_number_test = []\n",
    "\n",
    "for pair in pairs_train:\n",
    "    sentences_train.append([pair[0], pair[1], pair[2], pair[0]])\n",
    "    sentences_number_train.append([vocab_map[pair[0]], vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    \n",
    "for pair in pairs_test:\n",
    "    sentences_test.append([pair[0], pair[1], pair[2], pair[0]])\n",
    "    sentences_number_test.append([vocab_map[pair[0]], vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    \n",
    "x_masked_train = []\n",
    "y_masked_labels_train = []\n",
    "x_masked_test = []\n",
    "y_masked_labels_test = []\n",
    "\n",
    "for pair in pairs_train:\n",
    "    x_masked_train.append([vocab_map[pair[0]], vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    y_masked_labels_train.append([vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    \n",
    "for pair in pairs_test:\n",
    "    x_masked_test.append([vocab_map[pair[0]], vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    y_masked_labels_test.append([vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    \n",
    "x_masked_train = np.array(x_masked_train)\n",
    "y_masked_labels_train = np.array(y_masked_labels_train)\n",
    "x_masked_test = np.array(x_masked_test)\n",
    "y_masked_labels_test = np.array(y_masked_labels_test)\n",
    "\n",
    "perm = np.random.permutation(len(x_masked_train))\n",
    "x_masked_train = x_masked_train[perm]\n",
    "y_masked_labels_train = y_masked_labels_train[perm]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3e737b36",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-05-22 07:04:43.714790: W tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory\n",
      "2024-05-22 07:04:43.714826: W tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: UNKNOWN ERROR (303)\n",
      "2024-05-22 07:04:43.714846: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (gl3035.arc-ts.umich.edu): /proc/driver/nvidia/version does not exist\n",
      "2024-05-22 07:04:43.715024: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA\n",
      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2000\n",
      "1/1 - 4s - loss: 3.2037 - val_loss: 3.1891 - 4s/epoch - 4s/step\n",
      "Epoch 2/2000\n",
      "1/1 - 0s - loss: 3.1765 - val_loss: 3.1696 - 100ms/epoch - 100ms/step\n",
      "Epoch 3/2000\n",
      "1/1 - 0s - loss: 3.1544 - val_loss: 3.1508 - 97ms/epoch - 97ms/step\n",
      "Epoch 4/2000\n",
      "1/1 - 0s - loss: 3.1341 - val_loss: 3.1334 - 102ms/epoch - 102ms/step\n",
      "Epoch 5/2000\n",
      "1/1 - 0s - loss: 3.1156 - val_loss: 3.1167 - 97ms/epoch - 97ms/step\n",
      "Epoch 6/2000\n",
      "1/1 - 0s - loss: 3.0982 - val_loss: 3.0996 - 97ms/epoch - 97ms/step\n",
      "Epoch 7/2000\n",
      "1/1 - 0s - loss: 3.0803 - val_loss: 3.0836 - 96ms/epoch - 96ms/step\n",
      "Epoch 8/2000\n",
      "1/1 - 0s - loss: 3.0632 - val_loss: 3.0682 - 97ms/epoch - 97ms/step\n",
      "Epoch 9/2000\n",
      "1/1 - 0s - loss: 3.0464 - val_loss: 3.0523 - 97ms/epoch - 97ms/step\n",
      "Epoch 10/2000\n",
      "1/1 - 0s - loss: 3.0296 - val_loss: 3.0360 - 98ms/epoch - 98ms/step\n",
      "Epoch 11/2000\n",
      "1/1 - 0s - loss: 3.0130 - val_loss: 3.0200 - 103ms/epoch - 103ms/step\n",
      "Epoch 12/2000\n",
      "1/1 - 0s - loss: 2.9967 - val_loss: 3.0047 - 94ms/epoch - 94ms/step\n",
      "Epoch 13/2000\n",
      "1/1 - 0s - loss: 2.9813 - val_loss: 2.9906 - 100ms/epoch - 100ms/step\n",
      "Epoch 14/2000\n",
      "1/1 - 0s - loss: 2.9673 - val_loss: 2.9777 - 97ms/epoch - 97ms/step\n",
      "Epoch 15/2000\n",
      "1/1 - 0s - loss: 2.9543 - val_loss: 2.9657 - 96ms/epoch - 96ms/step\n",
      "Epoch 16/2000\n",
      "1/1 - 0s - loss: 2.9425 - val_loss: 2.9538 - 100ms/epoch - 100ms/step\n",
      "Epoch 17/2000\n",
      "1/1 - 0s - loss: 2.9309 - val_loss: 2.9423 - 96ms/epoch - 96ms/step\n",
      "Epoch 18/2000\n",
      "1/1 - 0s - loss: 2.9197 - val_loss: 2.9309 - 97ms/epoch - 97ms/step\n",
      "Epoch 19/2000\n",
      "1/1 - 0s - loss: 2.9088 - val_loss: 2.9204 - 101ms/epoch - 101ms/step\n",
      "Epoch 20/2000\n",
      "1/1 - 0s - loss: 2.8990 - val_loss: 2.9109 - 106ms/epoch - 106ms/step\n",
      "Epoch 21/2000\n",
      "1/1 - 0s - loss: 2.8893 - val_loss: 2.9020 - 101ms/epoch - 101ms/step\n",
      "Epoch 22/2000\n",
      "1/1 - 0s - loss: 2.8798 - val_loss: 2.8928 - 102ms/epoch - 102ms/step\n",
      "Epoch 23/2000\n",
      "1/1 - 0s - loss: 2.8700 - val_loss: 2.8838 - 107ms/epoch - 107ms/step\n",
      "Epoch 24/2000\n",
      "1/1 - 0s - loss: 2.8603 - val_loss: 2.8754 - 102ms/epoch - 102ms/step\n",
      "Epoch 25/2000\n",
      "1/1 - 0s - loss: 2.8510 - val_loss: 2.8676 - 103ms/epoch - 103ms/step\n",
      "Epoch 26/2000\n",
      "1/1 - 0s - loss: 2.8420 - val_loss: 2.8603 - 113ms/epoch - 113ms/step\n",
      "Epoch 27/2000\n",
      "1/1 - 0s - loss: 2.8337 - val_loss: 2.8536 - 117ms/epoch - 117ms/step\n",
      "Epoch 28/2000\n",
      "1/1 - 0s - loss: 2.8266 - val_loss: 2.8465 - 125ms/epoch - 125ms/step\n",
      "Epoch 29/2000\n",
      "1/1 - 0s - loss: 2.8192 - val_loss: 2.8388 - 103ms/epoch - 103ms/step\n",
      "Epoch 30/2000\n",
      "1/1 - 0s - loss: 2.8112 - val_loss: 2.8306 - 105ms/epoch - 105ms/step\n",
      "Epoch 31/2000\n",
      "1/1 - 0s - loss: 2.8029 - val_loss: 2.8227 - 96ms/epoch - 96ms/step\n",
      "Epoch 32/2000\n",
      "1/1 - 0s - loss: 2.7952 - val_loss: 2.8156 - 98ms/epoch - 98ms/step\n",
      "Epoch 33/2000\n",
      "1/1 - 0s - loss: 2.7882 - val_loss: 2.8082 - 99ms/epoch - 99ms/step\n",
      "Epoch 34/2000\n",
      "1/1 - 0s - loss: 2.7815 - val_loss: 2.8015 - 97ms/epoch - 97ms/step\n",
      "Epoch 35/2000\n",
      "1/1 - 0s - loss: 2.7752 - val_loss: 2.7954 - 90ms/epoch - 90ms/step\n",
      "Epoch 36/2000\n",
      "1/1 - 0s - loss: 2.7691 - val_loss: 2.7891 - 97ms/epoch - 97ms/step\n",
      "Epoch 37/2000\n",
      "1/1 - 0s - loss: 2.7629 - val_loss: 2.7826 - 94ms/epoch - 94ms/step\n",
      "Epoch 38/2000\n",
      "1/1 - 0s - loss: 2.7567 - val_loss: 2.7765 - 98ms/epoch - 98ms/step\n",
      "Epoch 39/2000\n",
      "1/1 - 0s - loss: 2.7503 - val_loss: 2.7696 - 101ms/epoch - 101ms/step\n",
      "Epoch 40/2000\n",
      "1/1 - 0s - loss: 2.7437 - val_loss: 2.7635 - 125ms/epoch - 125ms/step\n",
      "Epoch 41/2000\n",
      "1/1 - 0s - loss: 2.7378 - val_loss: 2.7585 - 98ms/epoch - 98ms/step\n",
      "Epoch 42/2000\n",
      "1/1 - 0s - loss: 2.7325 - val_loss: 2.7522 - 98ms/epoch - 98ms/step\n",
      "Epoch 43/2000\n",
      "1/1 - 0s - loss: 2.7265 - val_loss: 2.7473 - 111ms/epoch - 111ms/step\n",
      "Epoch 44/2000\n",
      "1/1 - 0s - loss: 2.7205 - val_loss: 2.7415 - 99ms/epoch - 99ms/step\n",
      "Epoch 45/2000\n",
      "1/1 - 0s - loss: 2.7142 - val_loss: 2.7362 - 97ms/epoch - 97ms/step\n",
      "Epoch 46/2000\n",
      "1/1 - 0s - loss: 2.7079 - val_loss: 2.7311 - 97ms/epoch - 97ms/step\n",
      "Epoch 47/2000\n",
      "1/1 - 0s - loss: 2.7019 - val_loss: 2.7271 - 98ms/epoch - 98ms/step\n",
      "Epoch 48/2000\n",
      "1/1 - 0s - loss: 2.6979 - val_loss: 2.7302 - 88ms/epoch - 88ms/step\n",
      "Epoch 49/2000\n",
      "1/1 - 0s - loss: 2.7000 - val_loss: 2.7184 - 95ms/epoch - 95ms/step\n",
      "Epoch 50/2000\n",
      "1/1 - 0s - loss: 2.6882 - val_loss: 2.7120 - 103ms/epoch - 103ms/step\n",
      "Epoch 51/2000\n",
      "1/1 - 0s - loss: 2.6812 - val_loss: 2.7121 - 99ms/epoch - 99ms/step\n",
      "Epoch 52/2000\n",
      "1/1 - 0s - loss: 2.6809 - val_loss: 2.7004 - 96ms/epoch - 96ms/step\n",
      "Epoch 53/2000\n",
      "1/1 - 0s - loss: 2.6692 - val_loss: 2.7009 - 89ms/epoch - 89ms/step\n",
      "Epoch 54/2000\n",
      "1/1 - 0s - loss: 2.6704 - val_loss: 2.6925 - 95ms/epoch - 95ms/step\n",
      "Epoch 55/2000\n",
      "1/1 - 0s - loss: 2.6612 - val_loss: 2.6880 - 96ms/epoch - 96ms/step\n",
      "Epoch 56/2000\n",
      "1/1 - 0s - loss: 2.6568 - val_loss: 2.6841 - 96ms/epoch - 96ms/step\n",
      "Epoch 57/2000\n",
      "1/1 - 0s - loss: 2.6537 - val_loss: 2.6750 - 96ms/epoch - 96ms/step\n",
      "Epoch 58/2000\n",
      "1/1 - 0s - loss: 2.6440 - val_loss: 2.6737 - 99ms/epoch - 99ms/step\n",
      "Epoch 59/2000\n",
      "1/1 - 0s - loss: 2.6430 - val_loss: 2.6663 - 98ms/epoch - 98ms/step\n",
      "Epoch 60/2000\n",
      "1/1 - 0s - loss: 2.6353 - val_loss: 2.6616 - 111ms/epoch - 111ms/step\n",
      "Epoch 61/2000\n",
      "1/1 - 0s - loss: 2.6306 - val_loss: 2.6588 - 98ms/epoch - 98ms/step\n",
      "Epoch 62/2000\n",
      "1/1 - 0s - loss: 2.6281 - val_loss: 2.6530 - 93ms/epoch - 93ms/step\n",
      "Epoch 63/2000\n",
      "1/1 - 0s - loss: 2.6212 - val_loss: 2.6512 - 96ms/epoch - 96ms/step\n",
      "Epoch 64/2000\n",
      "1/1 - 0s - loss: 2.6196 - val_loss: 2.6457 - 97ms/epoch - 97ms/step\n",
      "Epoch 65/2000\n",
      "1/1 - 0s - loss: 2.6152 - val_loss: 2.6402 - 96ms/epoch - 96ms/step\n",
      "Epoch 66/2000\n",
      "1/1 - 0s - loss: 2.6086 - val_loss: 2.6377 - 100ms/epoch - 100ms/step\n",
      "Epoch 67/2000\n",
      "1/1 - 0s - loss: 2.6061 - val_loss: 2.6317 - 101ms/epoch - 101ms/step\n",
      "Epoch 68/2000\n",
      "1/1 - 0s - loss: 2.6012 - val_loss: 2.6253 - 96ms/epoch - 96ms/step\n",
      "Epoch 69/2000\n",
      "1/1 - 0s - loss: 2.5936 - val_loss: 2.6225 - 100ms/epoch - 100ms/step\n",
      "Epoch 70/2000\n",
      "1/1 - 0s - loss: 2.5905 - val_loss: 2.6175 - 96ms/epoch - 96ms/step\n",
      "Epoch 71/2000\n",
      "1/1 - 0s - loss: 2.5868 - val_loss: 2.6116 - 98ms/epoch - 98ms/step\n",
      "Epoch 72/2000\n",
      "1/1 - 0s - loss: 2.5799 - val_loss: 2.6101 - 99ms/epoch - 99ms/step\n",
      "Epoch 73/2000\n",
      "1/1 - 0s - loss: 2.5777 - val_loss: 2.6044 - 137ms/epoch - 137ms/step\n",
      "Epoch 74/2000\n",
      "1/1 - 0s - loss: 2.5729 - val_loss: 2.5993 - 99ms/epoch - 99ms/step\n",
      "Epoch 75/2000\n",
      "1/1 - 0s - loss: 2.5677 - val_loss: 2.5954 - 96ms/epoch - 96ms/step\n",
      "Epoch 76/2000\n",
      "1/1 - 0s - loss: 2.5631 - val_loss: 2.5922 - 98ms/epoch - 98ms/step\n",
      "Epoch 77/2000\n",
      "1/1 - 0s - loss: 2.5601 - val_loss: 2.5887 - 96ms/epoch - 96ms/step\n",
      "Epoch 78/2000\n",
      "1/1 - 0s - loss: 2.5567 - val_loss: 2.5837 - 98ms/epoch - 98ms/step\n",
      "Epoch 79/2000\n",
      "1/1 - 0s - loss: 2.5520 - val_loss: 2.5797 - 97ms/epoch - 97ms/step\n",
      "Epoch 80/2000\n",
      "1/1 - 0s - loss: 2.5469 - val_loss: 2.5737 - 97ms/epoch - 97ms/step\n",
      "Epoch 81/2000\n",
      "1/1 - 0s - loss: 2.5411 - val_loss: 2.5697 - 96ms/epoch - 96ms/step\n",
      "Epoch 82/2000\n",
      "1/1 - 0s - loss: 2.5370 - val_loss: 2.5667 - 98ms/epoch - 98ms/step\n",
      "Epoch 83/2000\n",
      "1/1 - 0s - loss: 2.5335 - val_loss: 2.5651 - 113ms/epoch - 113ms/step\n",
      "Epoch 84/2000\n",
      "1/1 - 0s - loss: 2.5317 - val_loss: 2.5667 - 92ms/epoch - 92ms/step\n",
      "Epoch 85/2000\n",
      "1/1 - 0s - loss: 2.5337 - val_loss: 2.5712 - 91ms/epoch - 91ms/step\n",
      "Epoch 86/2000\n",
      "1/1 - 0s - loss: 2.5383 - val_loss: 2.5602 - 96ms/epoch - 96ms/step\n",
      "Epoch 87/2000\n",
      "1/1 - 0s - loss: 2.5259 - val_loss: 2.5476 - 103ms/epoch - 103ms/step\n",
      "Epoch 88/2000\n",
      "1/1 - 0s - loss: 2.5130 - val_loss: 2.5526 - 92ms/epoch - 92ms/step\n",
      "Epoch 89/2000\n",
      "1/1 - 0s - loss: 2.5191 - val_loss: 2.5460 - 98ms/epoch - 98ms/step\n",
      "Epoch 90/2000\n",
      "1/1 - 0s - loss: 2.5110 - val_loss: 2.5369 - 96ms/epoch - 96ms/step\n",
      "Epoch 91/2000\n",
      "1/1 - 0s - loss: 2.5018 - val_loss: 2.5370 - 92ms/epoch - 92ms/step\n",
      "Epoch 92/2000\n",
      "1/1 - 0s - loss: 2.5032 - val_loss: 2.5326 - 97ms/epoch - 97ms/step\n",
      "Epoch 93/2000\n",
      "1/1 - 0s - loss: 2.4985 - val_loss: 2.5227 - 97ms/epoch - 97ms/step\n",
      "Epoch 94/2000\n",
      "1/1 - 0s - loss: 2.4887 - val_loss: 2.5211 - 97ms/epoch - 97ms/step\n",
      "Epoch 95/2000\n",
      "1/1 - 0s - loss: 2.4877 - val_loss: 2.5179 - 97ms/epoch - 97ms/step\n",
      "Epoch 96/2000\n",
      "1/1 - 0s - loss: 2.4846 - val_loss: 2.5105 - 101ms/epoch - 101ms/step\n",
      "Epoch 97/2000\n",
      "1/1 - 0s - loss: 2.4777 - val_loss: 2.5055 - 97ms/epoch - 97ms/step\n",
      "Epoch 98/2000\n",
      "1/1 - 0s - loss: 2.4727 - val_loss: 2.5029 - 99ms/epoch - 99ms/step\n",
      "Epoch 99/2000\n",
      "1/1 - 0s - loss: 2.4705 - val_loss: 2.4996 - 97ms/epoch - 97ms/step\n",
      "Epoch 100/2000\n",
      "1/1 - 0s - loss: 2.4690 - val_loss: 2.4929 - 99ms/epoch - 99ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 101/2000\n",
      "1/1 - 0s - loss: 2.4623 - val_loss: 2.4888 - 105ms/epoch - 105ms/step\n",
      "Epoch 102/2000\n",
      "1/1 - 0s - loss: 2.4577 - val_loss: 2.4849 - 97ms/epoch - 97ms/step\n",
      "Epoch 103/2000\n",
      "1/1 - 0s - loss: 2.4536 - val_loss: 2.4819 - 109ms/epoch - 109ms/step\n",
      "Epoch 104/2000\n",
      "1/1 - 0s - loss: 2.4506 - val_loss: 2.4780 - 101ms/epoch - 101ms/step\n",
      "Epoch 105/2000\n",
      "1/1 - 0s - loss: 2.4477 - val_loss: 2.4751 - 100ms/epoch - 100ms/step\n",
      "Epoch 106/2000\n",
      "1/1 - 0s - loss: 2.4446 - val_loss: 2.4706 - 100ms/epoch - 100ms/step\n",
      "Epoch 107/2000\n",
      "1/1 - 0s - loss: 2.4412 - val_loss: 2.4687 - 103ms/epoch - 103ms/step\n",
      "Epoch 108/2000\n",
      "1/1 - 0s - loss: 2.4390 - val_loss: 2.4658 - 104ms/epoch - 104ms/step\n",
      "Epoch 109/2000\n",
      "1/1 - 0s - loss: 2.4366 - val_loss: 2.4648 - 99ms/epoch - 99ms/step\n",
      "Epoch 110/2000\n",
      "1/1 - 0s - loss: 2.4357 - val_loss: 2.4613 - 96ms/epoch - 96ms/step\n",
      "Epoch 111/2000\n",
      "1/1 - 0s - loss: 2.4327 - val_loss: 2.4565 - 97ms/epoch - 97ms/step\n",
      "Epoch 112/2000\n",
      "1/1 - 0s - loss: 2.4269 - val_loss: 2.4494 - 98ms/epoch - 98ms/step\n",
      "Epoch 113/2000\n",
      "1/1 - 0s - loss: 2.4201 - val_loss: 2.4452 - 89ms/epoch - 89ms/step\n",
      "Epoch 114/2000\n",
      "1/1 - 0s - loss: 2.4159 - val_loss: 2.4437 - 93ms/epoch - 93ms/step\n",
      "Epoch 115/2000\n",
      "1/1 - 0s - loss: 2.4145 - val_loss: 2.4428 - 96ms/epoch - 96ms/step\n",
      "Epoch 116/2000\n",
      "1/1 - 0s - loss: 2.4142 - val_loss: 2.4427 - 136ms/epoch - 136ms/step\n",
      "Epoch 117/2000\n",
      "1/1 - 0s - loss: 2.4134 - val_loss: 2.4381 - 95ms/epoch - 95ms/step\n",
      "Epoch 118/2000\n",
      "1/1 - 0s - loss: 2.4099 - val_loss: 2.4325 - 96ms/epoch - 96ms/step\n",
      "Epoch 119/2000\n",
      "1/1 - 0s - loss: 2.4034 - val_loss: 2.4267 - 95ms/epoch - 95ms/step\n",
      "Epoch 120/2000\n",
      "1/1 - 0s - loss: 2.3980 - val_loss: 2.4241 - 99ms/epoch - 99ms/step\n",
      "Epoch 121/2000\n",
      "1/1 - 0s - loss: 2.3956 - val_loss: 2.4240 - 95ms/epoch - 95ms/step\n",
      "Epoch 122/2000\n",
      "1/1 - 0s - loss: 2.3952 - val_loss: 2.4214 - 96ms/epoch - 96ms/step\n",
      "Epoch 123/2000\n",
      "1/1 - 0s - loss: 2.3937 - val_loss: 2.4190 - 93ms/epoch - 93ms/step\n",
      "Epoch 124/2000\n",
      "1/1 - 0s - loss: 2.3908 - val_loss: 2.4138 - 96ms/epoch - 96ms/step\n",
      "Epoch 125/2000\n",
      "1/1 - 0s - loss: 2.3864 - val_loss: 2.4097 - 96ms/epoch - 96ms/step\n",
      "Epoch 126/2000\n",
      "1/1 - 0s - loss: 2.3818 - val_loss: 2.4056 - 96ms/epoch - 96ms/step\n",
      "Epoch 127/2000\n",
      "1/1 - 0s - loss: 2.3779 - val_loss: 2.4029 - 98ms/epoch - 98ms/step\n",
      "Epoch 128/2000\n",
      "1/1 - 0s - loss: 2.3756 - val_loss: 2.4019 - 97ms/epoch - 97ms/step\n",
      "Epoch 129/2000\n",
      "1/1 - 0s - loss: 2.3744 - val_loss: 2.4006 - 98ms/epoch - 98ms/step\n",
      "Epoch 130/2000\n",
      "1/1 - 0s - loss: 2.3741 - val_loss: 2.4041 - 90ms/epoch - 90ms/step\n",
      "Epoch 131/2000\n",
      "1/1 - 0s - loss: 2.3771 - val_loss: 2.4063 - 89ms/epoch - 89ms/step\n",
      "Epoch 132/2000\n",
      "1/1 - 0s - loss: 2.3813 - val_loss: 2.4064 - 90ms/epoch - 90ms/step\n",
      "Epoch 133/2000\n",
      "1/1 - 0s - loss: 2.3794 - val_loss: 2.3913 - 96ms/epoch - 96ms/step\n",
      "Epoch 134/2000\n",
      "1/1 - 0s - loss: 2.3652 - val_loss: 2.3854 - 97ms/epoch - 97ms/step\n",
      "Epoch 135/2000\n",
      "1/1 - 0s - loss: 2.3587 - val_loss: 2.3917 - 89ms/epoch - 89ms/step\n",
      "Epoch 136/2000\n",
      "1/1 - 0s - loss: 2.3645 - val_loss: 2.3871 - 90ms/epoch - 90ms/step\n",
      "Epoch 137/2000\n",
      "1/1 - 0s - loss: 2.3613 - val_loss: 2.3779 - 95ms/epoch - 95ms/step\n",
      "Epoch 138/2000\n",
      "1/1 - 0s - loss: 2.3507 - val_loss: 2.3774 - 96ms/epoch - 96ms/step\n",
      "Epoch 139/2000\n",
      "1/1 - 0s - loss: 2.3501 - val_loss: 2.3781 - 90ms/epoch - 90ms/step\n",
      "Epoch 140/2000\n",
      "1/1 - 0s - loss: 2.3522 - val_loss: 2.3734 - 101ms/epoch - 101ms/step\n",
      "Epoch 141/2000\n",
      "1/1 - 0s - loss: 2.3462 - val_loss: 2.3678 - 104ms/epoch - 104ms/step\n",
      "Epoch 142/2000\n",
      "1/1 - 0s - loss: 2.3406 - val_loss: 2.3685 - 105ms/epoch - 105ms/step\n",
      "Epoch 143/2000\n",
      "1/1 - 0s - loss: 2.3425 - val_loss: 2.3670 - 112ms/epoch - 112ms/step\n",
      "Epoch 144/2000\n",
      "1/1 - 0s - loss: 2.3400 - val_loss: 2.3598 - 95ms/epoch - 95ms/step\n",
      "Epoch 145/2000\n",
      "1/1 - 0s - loss: 2.3333 - val_loss: 2.3598 - 96ms/epoch - 96ms/step\n",
      "Epoch 146/2000\n",
      "1/1 - 0s - loss: 2.3336 - val_loss: 2.3607 - 91ms/epoch - 91ms/step\n",
      "Epoch 147/2000\n",
      "1/1 - 0s - loss: 2.3336 - val_loss: 2.3546 - 94ms/epoch - 94ms/step\n",
      "Epoch 148/2000\n",
      "1/1 - 0s - loss: 2.3282 - val_loss: 2.3512 - 99ms/epoch - 99ms/step\n",
      "Epoch 149/2000\n",
      "1/1 - 0s - loss: 2.3244 - val_loss: 2.3526 - 102ms/epoch - 102ms/step\n",
      "Epoch 150/2000\n",
      "1/1 - 0s - loss: 2.3253 - val_loss: 2.3485 - 96ms/epoch - 96ms/step\n",
      "Epoch 151/2000\n",
      "1/1 - 0s - loss: 2.3224 - val_loss: 2.3442 - 96ms/epoch - 96ms/step\n",
      "Epoch 152/2000\n",
      "1/1 - 0s - loss: 2.3174 - val_loss: 2.3433 - 98ms/epoch - 98ms/step\n",
      "Epoch 153/2000\n",
      "1/1 - 0s - loss: 2.3162 - val_loss: 2.3417 - 95ms/epoch - 95ms/step\n",
      "Epoch 154/2000\n",
      "1/1 - 0s - loss: 2.3158 - val_loss: 2.3405 - 95ms/epoch - 95ms/step\n",
      "Epoch 155/2000\n",
      "1/1 - 0s - loss: 2.3138 - val_loss: 2.3356 - 95ms/epoch - 95ms/step\n",
      "Epoch 156/2000\n",
      "1/1 - 0s - loss: 2.3091 - val_loss: 2.3331 - 95ms/epoch - 95ms/step\n",
      "Epoch 157/2000\n",
      "1/1 - 0s - loss: 2.3066 - val_loss: 2.3330 - 98ms/epoch - 98ms/step\n",
      "Epoch 158/2000\n",
      "1/1 - 0s - loss: 2.3063 - val_loss: 2.3298 - 96ms/epoch - 96ms/step\n",
      "Epoch 159/2000\n",
      "1/1 - 0s - loss: 2.3040 - val_loss: 2.3275 - 97ms/epoch - 97ms/step\n",
      "Epoch 160/2000\n",
      "1/1 - 0s - loss: 2.3008 - val_loss: 2.3238 - 96ms/epoch - 96ms/step\n",
      "Epoch 161/2000\n",
      "1/1 - 0s - loss: 2.2977 - val_loss: 2.3219 - 100ms/epoch - 100ms/step\n",
      "Epoch 162/2000\n",
      "1/1 - 0s - loss: 2.2961 - val_loss: 2.3216 - 97ms/epoch - 97ms/step\n",
      "Epoch 163/2000\n",
      "1/1 - 0s - loss: 2.2951 - val_loss: 2.3186 - 96ms/epoch - 96ms/step\n",
      "Epoch 164/2000\n",
      "1/1 - 0s - loss: 2.2933 - val_loss: 2.3168 - 97ms/epoch - 97ms/step\n",
      "Epoch 165/2000\n",
      "1/1 - 0s - loss: 2.2907 - val_loss: 2.3136 - 98ms/epoch - 98ms/step\n",
      "Epoch 166/2000\n",
      "1/1 - 0s - loss: 2.2878 - val_loss: 2.3111 - 102ms/epoch - 102ms/step\n",
      "Epoch 167/2000\n",
      "1/1 - 0s - loss: 2.2853 - val_loss: 2.3095 - 95ms/epoch - 95ms/step\n",
      "Epoch 168/2000\n",
      "1/1 - 0s - loss: 2.2837 - val_loss: 2.3077 - 93ms/epoch - 93ms/step\n",
      "Epoch 169/2000\n",
      "1/1 - 0s - loss: 2.2823 - val_loss: 2.3070 - 94ms/epoch - 94ms/step\n",
      "Epoch 170/2000\n",
      "1/1 - 0s - loss: 2.2810 - val_loss: 2.3044 - 94ms/epoch - 94ms/step\n",
      "Epoch 171/2000\n",
      "1/1 - 0s - loss: 2.2797 - val_loss: 2.3041 - 94ms/epoch - 94ms/step\n",
      "Epoch 172/2000\n",
      "1/1 - 0s - loss: 2.2782 - val_loss: 2.3011 - 93ms/epoch - 93ms/step\n",
      "Epoch 173/2000\n",
      "1/1 - 0s - loss: 2.2764 - val_loss: 2.3000 - 94ms/epoch - 94ms/step\n",
      "Epoch 174/2000\n",
      "1/1 - 0s - loss: 2.2742 - val_loss: 2.2961 - 95ms/epoch - 95ms/step\n",
      "Epoch 175/2000\n",
      "1/1 - 0s - loss: 2.2712 - val_loss: 2.2942 - 93ms/epoch - 93ms/step\n",
      "Epoch 176/2000\n",
      "1/1 - 0s - loss: 2.2685 - val_loss: 2.2912 - 94ms/epoch - 94ms/step\n",
      "Epoch 177/2000\n",
      "1/1 - 0s - loss: 2.2662 - val_loss: 2.2893 - 104ms/epoch - 104ms/step\n",
      "Epoch 178/2000\n",
      "1/1 - 0s - loss: 2.2640 - val_loss: 2.2872 - 99ms/epoch - 99ms/step\n",
      "Epoch 179/2000\n",
      "1/1 - 0s - loss: 2.2620 - val_loss: 2.2850 - 96ms/epoch - 96ms/step\n",
      "Epoch 180/2000\n",
      "1/1 - 0s - loss: 2.2603 - val_loss: 2.2837 - 95ms/epoch - 95ms/step\n",
      "Epoch 181/2000\n",
      "1/1 - 0s - loss: 2.2586 - val_loss: 2.2817 - 89ms/epoch - 89ms/step\n",
      "Epoch 182/2000\n",
      "1/1 - 0s - loss: 2.2571 - val_loss: 2.2816 - 81ms/epoch - 81ms/step\n",
      "Epoch 183/2000\n",
      "1/1 - 0s - loss: 2.2564 - val_loss: 2.2811 - 79ms/epoch - 79ms/step\n",
      "Epoch 184/2000\n",
      "1/1 - 0s - loss: 2.2575 - val_loss: 2.2891 - 74ms/epoch - 74ms/step\n",
      "Epoch 185/2000\n",
      "1/1 - 0s - loss: 2.2636 - val_loss: 2.2981 - 75ms/epoch - 75ms/step\n",
      "Epoch 186/2000\n",
      "1/1 - 0s - loss: 2.2766 - val_loss: 2.3119 - 73ms/epoch - 73ms/step\n",
      "Epoch 187/2000\n",
      "1/1 - 0s - loss: 2.2853 - val_loss: 2.2820 - 76ms/epoch - 76ms/step\n",
      "Epoch 188/2000\n",
      "1/1 - 0s - loss: 2.2579 - val_loss: 2.2781 - 81ms/epoch - 81ms/step\n",
      "Epoch 189/2000\n",
      "1/1 - 0s - loss: 2.2531 - val_loss: 2.2927 - 74ms/epoch - 74ms/step\n",
      "Epoch 190/2000\n",
      "1/1 - 0s - loss: 2.2656 - val_loss: 2.2741 - 87ms/epoch - 87ms/step\n",
      "Epoch 191/2000\n",
      "1/1 - 0s - loss: 2.2490 - val_loss: 2.2774 - 78ms/epoch - 78ms/step\n",
      "Epoch 192/2000\n",
      "1/1 - 0s - loss: 2.2525 - val_loss: 2.2842 - 71ms/epoch - 71ms/step\n",
      "Epoch 193/2000\n",
      "1/1 - 0s - loss: 2.2564 - val_loss: 2.2664 - 80ms/epoch - 80ms/step\n",
      "Epoch 194/2000\n",
      "1/1 - 0s - loss: 2.2403 - val_loss: 2.2744 - 76ms/epoch - 76ms/step\n",
      "Epoch 195/2000\n",
      "1/1 - 0s - loss: 2.2510 - val_loss: 2.2666 - 78ms/epoch - 78ms/step\n",
      "Epoch 196/2000\n",
      "1/1 - 0s - loss: 2.2397 - val_loss: 2.2693 - 76ms/epoch - 76ms/step\n",
      "Epoch 197/2000\n",
      "1/1 - 0s - loss: 2.2417 - val_loss: 2.2624 - 82ms/epoch - 82ms/step\n",
      "Epoch 198/2000\n",
      "1/1 - 0s - loss: 2.2367 - val_loss: 2.2601 - 82ms/epoch - 82ms/step\n",
      "Epoch 199/2000\n",
      "1/1 - 0s - loss: 2.2344 - val_loss: 2.2618 - 77ms/epoch - 77ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 200/2000\n",
      "1/1 - 0s - loss: 2.2347 - val_loss: 2.2549 - 81ms/epoch - 81ms/step\n",
      "Epoch 201/2000\n",
      "1/1 - 0s - loss: 2.2282 - val_loss: 2.2557 - 73ms/epoch - 73ms/step\n",
      "Epoch 202/2000\n",
      "1/1 - 0s - loss: 2.2313 - val_loss: 2.2497 - 83ms/epoch - 83ms/step\n",
      "Epoch 203/2000\n",
      "1/1 - 0s - loss: 2.2243 - val_loss: 2.2540 - 73ms/epoch - 73ms/step\n",
      "Epoch 204/2000\n",
      "1/1 - 0s - loss: 2.2269 - val_loss: 2.2474 - 82ms/epoch - 82ms/step\n",
      "Epoch 205/2000\n",
      "1/1 - 0s - loss: 2.2210 - val_loss: 2.2471 - 80ms/epoch - 80ms/step\n",
      "Epoch 206/2000\n",
      "1/1 - 0s - loss: 2.2221 - val_loss: 2.2441 - 79ms/epoch - 79ms/step\n",
      "Epoch 207/2000\n",
      "1/1 - 0s - loss: 2.2183 - val_loss: 2.2442 - 87ms/epoch - 87ms/step\n",
      "Epoch 208/2000\n",
      "1/1 - 0s - loss: 2.2180 - val_loss: 2.2400 - 89ms/epoch - 89ms/step\n",
      "Epoch 209/2000\n",
      "1/1 - 0s - loss: 2.2148 - val_loss: 2.2388 - 81ms/epoch - 81ms/step\n",
      "Epoch 210/2000\n",
      "1/1 - 0s - loss: 2.2143 - val_loss: 2.2375 - 80ms/epoch - 80ms/step\n",
      "Epoch 211/2000\n",
      "1/1 - 0s - loss: 2.2119 - val_loss: 2.2366 - 82ms/epoch - 82ms/step\n",
      "Epoch 212/2000\n",
      "1/1 - 0s - loss: 2.2104 - val_loss: 2.2341 - 84ms/epoch - 84ms/step\n",
      "Epoch 213/2000\n",
      "1/1 - 0s - loss: 2.2089 - val_loss: 2.2320 - 81ms/epoch - 81ms/step\n",
      "Epoch 214/2000\n",
      "1/1 - 0s - loss: 2.2068 - val_loss: 2.2315 - 80ms/epoch - 80ms/step\n",
      "Epoch 215/2000\n",
      "1/1 - 0s - loss: 2.2057 - val_loss: 2.2291 - 81ms/epoch - 81ms/step\n",
      "Epoch 216/2000\n",
      "1/1 - 0s - loss: 2.2035 - val_loss: 2.2274 - 82ms/epoch - 82ms/step\n",
      "Epoch 217/2000\n",
      "1/1 - 0s - loss: 2.2026 - val_loss: 2.2254 - 81ms/epoch - 81ms/step\n",
      "Epoch 218/2000\n",
      "1/1 - 0s - loss: 2.2005 - val_loss: 2.2248 - 81ms/epoch - 81ms/step\n",
      "Epoch 219/2000\n",
      "1/1 - 0s - loss: 2.1993 - val_loss: 2.2230 - 84ms/epoch - 84ms/step\n",
      "Epoch 220/2000\n",
      "1/1 - 0s - loss: 2.1977 - val_loss: 2.2211 - 83ms/epoch - 83ms/step\n",
      "Epoch 221/2000\n",
      "1/1 - 0s - loss: 2.1961 - val_loss: 2.2201 - 90ms/epoch - 90ms/step\n",
      "Epoch 222/2000\n",
      "1/1 - 0s - loss: 2.1948 - val_loss: 2.2182 - 88ms/epoch - 88ms/step\n",
      "Epoch 223/2000\n",
      "1/1 - 0s - loss: 2.1931 - val_loss: 2.2161 - 82ms/epoch - 82ms/step\n",
      "Epoch 224/2000\n",
      "1/1 - 0s - loss: 2.1918 - val_loss: 2.2145 - 83ms/epoch - 83ms/step\n",
      "Epoch 225/2000\n",
      "1/1 - 0s - loss: 2.1903 - val_loss: 2.2137 - 81ms/epoch - 81ms/step\n",
      "Epoch 226/2000\n",
      "1/1 - 0s - loss: 2.1889 - val_loss: 2.2123 - 81ms/epoch - 81ms/step\n",
      "Epoch 227/2000\n",
      "1/1 - 0s - loss: 2.1875 - val_loss: 2.2107 - 80ms/epoch - 80ms/step\n",
      "Epoch 228/2000\n",
      "1/1 - 0s - loss: 2.1860 - val_loss: 2.2094 - 83ms/epoch - 83ms/step\n",
      "Epoch 229/2000\n",
      "1/1 - 0s - loss: 2.1847 - val_loss: 2.2075 - 85ms/epoch - 85ms/step\n",
      "Epoch 230/2000\n",
      "1/1 - 0s - loss: 2.1832 - val_loss: 2.2060 - 80ms/epoch - 80ms/step\n",
      "Epoch 231/2000\n",
      "1/1 - 0s - loss: 2.1819 - val_loss: 2.2049 - 82ms/epoch - 82ms/step\n",
      "Epoch 232/2000\n",
      "1/1 - 0s - loss: 2.1805 - val_loss: 2.2037 - 80ms/epoch - 80ms/step\n",
      "Epoch 233/2000\n",
      "1/1 - 0s - loss: 2.1792 - val_loss: 2.2024 - 81ms/epoch - 81ms/step\n",
      "Epoch 234/2000\n",
      "1/1 - 0s - loss: 2.1779 - val_loss: 2.2008 - 80ms/epoch - 80ms/step\n",
      "Epoch 235/2000\n",
      "1/1 - 0s - loss: 2.1764 - val_loss: 2.1995 - 80ms/epoch - 80ms/step\n",
      "Epoch 236/2000\n",
      "1/1 - 0s - loss: 2.1752 - val_loss: 2.1980 - 80ms/epoch - 80ms/step\n",
      "Epoch 237/2000\n",
      "1/1 - 0s - loss: 2.1739 - val_loss: 2.1968 - 84ms/epoch - 84ms/step\n",
      "Epoch 238/2000\n",
      "1/1 - 0s - loss: 2.1725 - val_loss: 2.1958 - 87ms/epoch - 87ms/step\n",
      "Epoch 239/2000\n",
      "1/1 - 0s - loss: 2.1713 - val_loss: 2.1942 - 81ms/epoch - 81ms/step\n",
      "Epoch 240/2000\n",
      "1/1 - 0s - loss: 2.1700 - val_loss: 2.1930 - 84ms/epoch - 84ms/step\n",
      "Epoch 241/2000\n",
      "1/1 - 0s - loss: 2.1687 - val_loss: 2.1918 - 80ms/epoch - 80ms/step\n",
      "Epoch 242/2000\n",
      "1/1 - 0s - loss: 2.1675 - val_loss: 2.1905 - 84ms/epoch - 84ms/step\n",
      "Epoch 243/2000\n",
      "1/1 - 0s - loss: 2.1662 - val_loss: 2.1895 - 84ms/epoch - 84ms/step\n",
      "Epoch 244/2000\n",
      "1/1 - 0s - loss: 2.1649 - val_loss: 2.1882 - 82ms/epoch - 82ms/step\n",
      "Epoch 245/2000\n",
      "1/1 - 0s - loss: 2.1637 - val_loss: 2.1866 - 96ms/epoch - 96ms/step\n",
      "Epoch 246/2000\n",
      "1/1 - 0s - loss: 2.1625 - val_loss: 2.1857 - 85ms/epoch - 85ms/step\n",
      "Epoch 247/2000\n",
      "1/1 - 0s - loss: 2.1613 - val_loss: 2.1844 - 85ms/epoch - 85ms/step\n",
      "Epoch 248/2000\n",
      "1/1 - 0s - loss: 2.1600 - val_loss: 2.1831 - 82ms/epoch - 82ms/step\n",
      "Epoch 249/2000\n",
      "1/1 - 0s - loss: 2.1589 - val_loss: 2.1821 - 80ms/epoch - 80ms/step\n",
      "Epoch 250/2000\n",
      "1/1 - 0s - loss: 2.1577 - val_loss: 2.1806 - 81ms/epoch - 81ms/step\n",
      "Epoch 251/2000\n",
      "1/1 - 0s - loss: 2.1564 - val_loss: 2.1795 - 81ms/epoch - 81ms/step\n",
      "Epoch 252/2000\n",
      "1/1 - 0s - loss: 2.1553 - val_loss: 2.1785 - 86ms/epoch - 86ms/step\n",
      "Epoch 253/2000\n",
      "1/1 - 0s - loss: 2.1541 - val_loss: 2.1772 - 87ms/epoch - 87ms/step\n",
      "Epoch 254/2000\n",
      "1/1 - 0s - loss: 2.1529 - val_loss: 2.1761 - 80ms/epoch - 80ms/step\n",
      "Epoch 255/2000\n",
      "1/1 - 0s - loss: 2.1518 - val_loss: 2.1749 - 82ms/epoch - 82ms/step\n",
      "Epoch 256/2000\n",
      "1/1 - 0s - loss: 2.1506 - val_loss: 2.1738 - 90ms/epoch - 90ms/step\n",
      "Epoch 257/2000\n",
      "1/1 - 0s - loss: 2.1495 - val_loss: 2.1727 - 87ms/epoch - 87ms/step\n",
      "Epoch 258/2000\n",
      "1/1 - 0s - loss: 2.1484 - val_loss: 2.1715 - 83ms/epoch - 83ms/step\n",
      "Epoch 259/2000\n",
      "1/1 - 0s - loss: 2.1472 - val_loss: 2.1705 - 81ms/epoch - 81ms/step\n",
      "Epoch 260/2000\n",
      "1/1 - 0s - loss: 2.1461 - val_loss: 2.1693 - 88ms/epoch - 88ms/step\n",
      "Epoch 261/2000\n",
      "1/1 - 0s - loss: 2.1450 - val_loss: 2.1683 - 99ms/epoch - 99ms/step\n",
      "Epoch 262/2000\n",
      "1/1 - 0s - loss: 2.1439 - val_loss: 2.1671 - 85ms/epoch - 85ms/step\n",
      "Epoch 263/2000\n",
      "1/1 - 0s - loss: 2.1428 - val_loss: 2.1661 - 80ms/epoch - 80ms/step\n",
      "Epoch 264/2000\n",
      "1/1 - 0s - loss: 2.1417 - val_loss: 2.1651 - 81ms/epoch - 81ms/step\n",
      "Epoch 265/2000\n",
      "1/1 - 0s - loss: 2.1406 - val_loss: 2.1640 - 102ms/epoch - 102ms/step\n",
      "Epoch 266/2000\n",
      "1/1 - 0s - loss: 2.1395 - val_loss: 2.1630 - 81ms/epoch - 81ms/step\n",
      "Epoch 267/2000\n",
      "1/1 - 0s - loss: 2.1385 - val_loss: 2.1618 - 95ms/epoch - 95ms/step\n",
      "Epoch 268/2000\n",
      "1/1 - 0s - loss: 2.1374 - val_loss: 2.1610 - 81ms/epoch - 81ms/step\n",
      "Epoch 269/2000\n",
      "1/1 - 0s - loss: 2.1364 - val_loss: 2.1598 - 83ms/epoch - 83ms/step\n",
      "Epoch 270/2000\n",
      "1/1 - 0s - loss: 2.1354 - val_loss: 2.1592 - 83ms/epoch - 83ms/step\n",
      "Epoch 271/2000\n",
      "1/1 - 0s - loss: 2.1344 - val_loss: 2.1578 - 82ms/epoch - 82ms/step\n",
      "Epoch 272/2000\n",
      "1/1 - 0s - loss: 2.1335 - val_loss: 2.1578 - 75ms/epoch - 75ms/step\n",
      "Epoch 273/2000\n",
      "1/1 - 0s - loss: 2.1326 - val_loss: 2.1560 - 81ms/epoch - 81ms/step\n",
      "Epoch 274/2000\n",
      "1/1 - 0s - loss: 2.1319 - val_loss: 2.1566 - 85ms/epoch - 85ms/step\n",
      "Epoch 275/2000\n",
      "1/1 - 0s - loss: 2.1312 - val_loss: 2.1545 - 86ms/epoch - 86ms/step\n",
      "Epoch 276/2000\n",
      "1/1 - 0s - loss: 2.1305 - val_loss: 2.1554 - 92ms/epoch - 92ms/step\n",
      "Epoch 277/2000\n",
      "1/1 - 0s - loss: 2.1297 - val_loss: 2.1527 - 98ms/epoch - 98ms/step\n",
      "Epoch 278/2000\n",
      "1/1 - 0s - loss: 2.1285 - val_loss: 2.1527 - 86ms/epoch - 86ms/step\n",
      "Epoch 279/2000\n",
      "1/1 - 0s - loss: 2.1272 - val_loss: 2.1502 - 92ms/epoch - 92ms/step\n",
      "Epoch 280/2000\n",
      "1/1 - 0s - loss: 2.1257 - val_loss: 2.1496 - 84ms/epoch - 84ms/step\n",
      "Epoch 281/2000\n",
      "1/1 - 0s - loss: 2.1244 - val_loss: 2.1483 - 84ms/epoch - 84ms/step\n",
      "Epoch 282/2000\n",
      "1/1 - 0s - loss: 2.1234 - val_loss: 2.1472 - 81ms/epoch - 81ms/step\n",
      "Epoch 283/2000\n",
      "1/1 - 0s - loss: 2.1225 - val_loss: 2.1469 - 84ms/epoch - 84ms/step\n",
      "Epoch 284/2000\n",
      "1/1 - 0s - loss: 2.1217 - val_loss: 2.1454 - 85ms/epoch - 85ms/step\n",
      "Epoch 285/2000\n",
      "1/1 - 0s - loss: 2.1210 - val_loss: 2.1460 - 75ms/epoch - 75ms/step\n",
      "Epoch 286/2000\n",
      "1/1 - 0s - loss: 2.1204 - val_loss: 2.1442 - 82ms/epoch - 82ms/step\n",
      "Epoch 287/2000\n",
      "1/1 - 0s - loss: 2.1199 - val_loss: 2.1452 - 79ms/epoch - 79ms/step\n",
      "Epoch 288/2000\n",
      "1/1 - 0s - loss: 2.1193 - val_loss: 2.1427 - 86ms/epoch - 86ms/step\n",
      "Epoch 289/2000\n",
      "1/1 - 0s - loss: 2.1183 - val_loss: 2.1431 - 75ms/epoch - 75ms/step\n",
      "Epoch 290/2000\n",
      "1/1 - 0s - loss: 2.1172 - val_loss: 2.1404 - 82ms/epoch - 82ms/step\n",
      "Epoch 291/2000\n",
      "1/1 - 0s - loss: 2.1158 - val_loss: 2.1399 - 84ms/epoch - 84ms/step\n",
      "Epoch 292/2000\n",
      "1/1 - 0s - loss: 2.1144 - val_loss: 2.1385 - 86ms/epoch - 86ms/step\n",
      "Epoch 293/2000\n",
      "1/1 - 0s - loss: 2.1133 - val_loss: 2.1375 - 83ms/epoch - 83ms/step\n",
      "Epoch 294/2000\n",
      "1/1 - 0s - loss: 2.1126 - val_loss: 2.1377 - 75ms/epoch - 75ms/step\n",
      "Epoch 295/2000\n",
      "1/1 - 0s - loss: 2.1121 - val_loss: 2.1363 - 93ms/epoch - 93ms/step\n",
      "Epoch 296/2000\n",
      "1/1 - 0s - loss: 2.1117 - val_loss: 2.1377 - 75ms/epoch - 75ms/step\n",
      "Epoch 297/2000\n",
      "1/1 - 0s - loss: 2.1116 - val_loss: 2.1355 - 84ms/epoch - 84ms/step\n",
      "Epoch 298/2000\n",
      "1/1 - 0s - loss: 2.1111 - val_loss: 2.1365 - 75ms/epoch - 75ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 299/2000\n",
      "1/1 - 0s - loss: 2.1103 - val_loss: 2.1334 - 81ms/epoch - 81ms/step\n",
      "Epoch 300/2000\n",
      "1/1 - 0s - loss: 2.1087 - val_loss: 2.1328 - 88ms/epoch - 88ms/step\n",
      "Epoch 301/2000\n",
      "1/1 - 0s - loss: 2.1069 - val_loss: 2.1312 - 88ms/epoch - 88ms/step\n",
      "Epoch 302/2000\n",
      "1/1 - 0s - loss: 2.1057 - val_loss: 2.1303 - 85ms/epoch - 85ms/step\n",
      "Epoch 303/2000\n",
      "1/1 - 0s - loss: 2.1053 - val_loss: 2.1315 - 74ms/epoch - 74ms/step\n",
      "Epoch 304/2000\n",
      "1/1 - 0s - loss: 2.1054 - val_loss: 2.1300 - 81ms/epoch - 81ms/step\n",
      "Epoch 305/2000\n",
      "1/1 - 0s - loss: 2.1053 - val_loss: 2.1319 - 79ms/epoch - 79ms/step\n",
      "Epoch 306/2000\n",
      "1/1 - 0s - loss: 2.1051 - val_loss: 2.1284 - 91ms/epoch - 91ms/step\n",
      "Epoch 307/2000\n",
      "1/1 - 0s - loss: 2.1034 - val_loss: 2.1276 - 82ms/epoch - 82ms/step\n",
      "Epoch 308/2000\n",
      "1/1 - 0s - loss: 2.1015 - val_loss: 2.1260 - 82ms/epoch - 82ms/step\n",
      "Epoch 309/2000\n",
      "1/1 - 0s - loss: 2.1002 - val_loss: 2.1251 - 83ms/epoch - 83ms/step\n",
      "Epoch 310/2000\n",
      "1/1 - 0s - loss: 2.0997 - val_loss: 2.1263 - 76ms/epoch - 76ms/step\n",
      "Epoch 311/2000\n",
      "1/1 - 0s - loss: 2.0999 - val_loss: 2.1246 - 82ms/epoch - 82ms/step\n",
      "Epoch 312/2000\n",
      "1/1 - 0s - loss: 2.0995 - val_loss: 2.1255 - 78ms/epoch - 78ms/step\n",
      "Epoch 313/2000\n",
      "1/1 - 0s - loss: 2.0989 - val_loss: 2.1224 - 80ms/epoch - 80ms/step\n",
      "Epoch 314/2000\n",
      "1/1 - 0s - loss: 2.0972 - val_loss: 2.1218 - 82ms/epoch - 82ms/step\n",
      "Epoch 315/2000\n",
      "1/1 - 0s - loss: 2.0956 - val_loss: 2.1211 - 79ms/epoch - 79ms/step\n",
      "Epoch 316/2000\n",
      "1/1 - 0s - loss: 2.0949 - val_loss: 2.1201 - 83ms/epoch - 83ms/step\n",
      "Epoch 317/2000\n",
      "1/1 - 0s - loss: 2.0949 - val_loss: 2.1212 - 75ms/epoch - 75ms/step\n",
      "Epoch 318/2000\n",
      "1/1 - 0s - loss: 2.0946 - val_loss: 2.1189 - 84ms/epoch - 84ms/step\n",
      "Epoch 319/2000\n",
      "1/1 - 0s - loss: 2.0935 - val_loss: 2.1186 - 83ms/epoch - 83ms/step\n",
      "Epoch 320/2000\n",
      "1/1 - 0s - loss: 2.0922 - val_loss: 2.1172 - 82ms/epoch - 82ms/step\n",
      "Epoch 321/2000\n",
      "1/1 - 0s - loss: 2.0911 - val_loss: 2.1163 - 81ms/epoch - 81ms/step\n",
      "Epoch 322/2000\n",
      "1/1 - 0s - loss: 2.0907 - val_loss: 2.1170 - 80ms/epoch - 80ms/step\n",
      "Epoch 323/2000\n",
      "1/1 - 0s - loss: 2.0906 - val_loss: 2.1152 - 82ms/epoch - 82ms/step\n",
      "Epoch 324/2000\n",
      "1/1 - 0s - loss: 2.0899 - val_loss: 2.1152 - 82ms/epoch - 82ms/step\n",
      "Epoch 325/2000\n",
      "1/1 - 0s - loss: 2.0888 - val_loss: 2.1135 - 81ms/epoch - 81ms/step\n",
      "Epoch 326/2000\n",
      "1/1 - 0s - loss: 2.0876 - val_loss: 2.1129 - 81ms/epoch - 81ms/step\n",
      "Epoch 327/2000\n",
      "1/1 - 0s - loss: 2.0869 - val_loss: 2.1131 - 76ms/epoch - 76ms/step\n",
      "Epoch 328/2000\n",
      "1/1 - 0s - loss: 2.0865 - val_loss: 2.1117 - 82ms/epoch - 82ms/step\n",
      "Epoch 329/2000\n",
      "1/1 - 0s - loss: 2.0861 - val_loss: 2.1122 - 91ms/epoch - 91ms/step\n",
      "Epoch 330/2000\n",
      "1/1 - 0s - loss: 2.0855 - val_loss: 2.1101 - 103ms/epoch - 103ms/step\n",
      "Epoch 331/2000\n",
      "1/1 - 0s - loss: 2.0845 - val_loss: 2.1100 - 95ms/epoch - 95ms/step\n",
      "Epoch 332/2000\n",
      "1/1 - 0s - loss: 2.0835 - val_loss: 2.1090 - 91ms/epoch - 91ms/step\n",
      "Epoch 333/2000\n",
      "1/1 - 0s - loss: 2.0827 - val_loss: 2.1081 - 80ms/epoch - 80ms/step\n",
      "Epoch 334/2000\n",
      "1/1 - 0s - loss: 2.0822 - val_loss: 2.1085 - 77ms/epoch - 77ms/step\n",
      "Epoch 335/2000\n",
      "1/1 - 0s - loss: 2.0819 - val_loss: 2.1071 - 80ms/epoch - 80ms/step\n",
      "Epoch 336/2000\n",
      "1/1 - 0s - loss: 2.0814 - val_loss: 2.1077 - 75ms/epoch - 75ms/step\n",
      "Epoch 337/2000\n",
      "1/1 - 0s - loss: 2.0810 - val_loss: 2.1059 - 81ms/epoch - 81ms/step\n",
      "Epoch 338/2000\n",
      "1/1 - 0s - loss: 2.0802 - val_loss: 2.1060 - 82ms/epoch - 82ms/step\n",
      "Epoch 339/2000\n",
      "1/1 - 0s - loss: 2.0792 - val_loss: 2.1044 - 80ms/epoch - 80ms/step\n",
      "Epoch 340/2000\n",
      "1/1 - 0s - loss: 2.0783 - val_loss: 2.1040 - 82ms/epoch - 82ms/step\n",
      "Epoch 341/2000\n",
      "1/1 - 0s - loss: 2.0776 - val_loss: 2.1036 - 83ms/epoch - 83ms/step\n",
      "Epoch 342/2000\n",
      "1/1 - 0s - loss: 2.0771 - val_loss: 2.1026 - 82ms/epoch - 82ms/step\n",
      "Epoch 343/2000\n",
      "1/1 - 0s - loss: 2.0767 - val_loss: 2.1033 - 76ms/epoch - 76ms/step\n",
      "Epoch 344/2000\n",
      "1/1 - 0s - loss: 2.0763 - val_loss: 2.1016 - 82ms/epoch - 82ms/step\n",
      "Epoch 345/2000\n",
      "1/1 - 0s - loss: 2.0757 - val_loss: 2.1021 - 77ms/epoch - 77ms/step\n",
      "Epoch 346/2000\n",
      "1/1 - 0s - loss: 2.0750 - val_loss: 2.1004 - 83ms/epoch - 83ms/step\n",
      "Epoch 347/2000\n",
      "1/1 - 0s - loss: 2.0742 - val_loss: 2.1002 - 84ms/epoch - 84ms/step\n",
      "Epoch 348/2000\n",
      "1/1 - 0s - loss: 2.0734 - val_loss: 2.0994 - 81ms/epoch - 81ms/step\n",
      "Epoch 349/2000\n",
      "1/1 - 0s - loss: 2.0727 - val_loss: 2.0987 - 83ms/epoch - 83ms/step\n",
      "Epoch 350/2000\n",
      "1/1 - 0s - loss: 2.0722 - val_loss: 2.0988 - 73ms/epoch - 73ms/step\n",
      "Epoch 351/2000\n",
      "1/1 - 0s - loss: 2.0718 - val_loss: 2.0977 - 81ms/epoch - 81ms/step\n",
      "Epoch 352/2000\n",
      "1/1 - 0s - loss: 2.0714 - val_loss: 2.0988 - 76ms/epoch - 76ms/step\n",
      "Epoch 353/2000\n",
      "1/1 - 0s - loss: 2.0712 - val_loss: 2.0969 - 86ms/epoch - 86ms/step\n",
      "Epoch 354/2000\n",
      "1/1 - 0s - loss: 2.0709 - val_loss: 2.0984 - 77ms/epoch - 77ms/step\n",
      "Epoch 355/2000\n",
      "1/1 - 0s - loss: 2.0706 - val_loss: 2.0958 - 82ms/epoch - 82ms/step\n",
      "Epoch 356/2000\n",
      "1/1 - 0s - loss: 2.0698 - val_loss: 2.0962 - 76ms/epoch - 76ms/step\n",
      "Epoch 357/2000\n",
      "1/1 - 0s - loss: 2.0688 - val_loss: 2.0944 - 81ms/epoch - 81ms/step\n",
      "Epoch 358/2000\n",
      "1/1 - 0s - loss: 2.0677 - val_loss: 2.0939 - 80ms/epoch - 80ms/step\n",
      "Epoch 359/2000\n",
      "1/1 - 0s - loss: 2.0670 - val_loss: 2.0940 - 74ms/epoch - 74ms/step\n",
      "Epoch 360/2000\n",
      "1/1 - 0s - loss: 2.0667 - val_loss: 2.0929 - 80ms/epoch - 80ms/step\n",
      "Epoch 361/2000\n",
      "1/1 - 0s - loss: 2.0665 - val_loss: 2.0941 - 77ms/epoch - 77ms/step\n",
      "Epoch 362/2000\n",
      "1/1 - 0s - loss: 2.0664 - val_loss: 2.0922 - 82ms/epoch - 82ms/step\n",
      "Epoch 363/2000\n",
      "1/1 - 0s - loss: 2.0660 - val_loss: 2.0933 - 83ms/epoch - 83ms/step\n",
      "Epoch 364/2000\n",
      "1/1 - 0s - loss: 2.0654 - val_loss: 2.0910 - 88ms/epoch - 88ms/step\n",
      "Epoch 365/2000\n",
      "1/1 - 0s - loss: 2.0643 - val_loss: 2.0909 - 82ms/epoch - 82ms/step\n",
      "Epoch 366/2000\n",
      "1/1 - 0s - loss: 2.0634 - val_loss: 2.0902 - 82ms/epoch - 82ms/step\n",
      "Epoch 367/2000\n",
      "1/1 - 0s - loss: 2.0628 - val_loss: 2.0894 - 81ms/epoch - 81ms/step\n",
      "Epoch 368/2000\n",
      "1/1 - 0s - loss: 2.0625 - val_loss: 2.0903 - 76ms/epoch - 76ms/step\n",
      "Epoch 369/2000\n",
      "1/1 - 0s - loss: 2.0624 - val_loss: 2.0888 - 81ms/epoch - 81ms/step\n",
      "Epoch 370/2000\n",
      "1/1 - 0s - loss: 2.0622 - val_loss: 2.0901 - 76ms/epoch - 76ms/step\n",
      "Epoch 371/2000\n",
      "1/1 - 0s - loss: 2.0620 - val_loss: 2.0879 - 81ms/epoch - 81ms/step\n",
      "Epoch 372/2000\n",
      "1/1 - 0s - loss: 2.0612 - val_loss: 2.0883 - 73ms/epoch - 73ms/step\n",
      "Epoch 373/2000\n",
      "1/1 - 0s - loss: 2.0603 - val_loss: 2.0867 - 83ms/epoch - 83ms/step\n",
      "Epoch 374/2000\n",
      "1/1 - 0s - loss: 2.0594 - val_loss: 2.0862 - 81ms/epoch - 81ms/step\n",
      "Epoch 375/2000\n",
      "1/1 - 0s - loss: 2.0588 - val_loss: 2.0865 - 75ms/epoch - 75ms/step\n",
      "Epoch 376/2000\n",
      "1/1 - 0s - loss: 2.0585 - val_loss: 2.0855 - 80ms/epoch - 80ms/step\n",
      "Epoch 377/2000\n",
      "1/1 - 0s - loss: 2.0584 - val_loss: 2.0866 - 76ms/epoch - 76ms/step\n",
      "Epoch 378/2000\n",
      "1/1 - 0s - loss: 2.0583 - val_loss: 2.0847 - 83ms/epoch - 83ms/step\n",
      "Epoch 379/2000\n",
      "1/1 - 0s - loss: 2.0578 - val_loss: 2.0857 - 73ms/epoch - 73ms/step\n",
      "Epoch 380/2000\n",
      "1/1 - 0s - loss: 2.0572 - val_loss: 2.0835 - 83ms/epoch - 83ms/step\n",
      "Epoch 381/2000\n",
      "1/1 - 0s - loss: 2.0563 - val_loss: 2.0835 - 73ms/epoch - 73ms/step\n",
      "Epoch 382/2000\n",
      "1/1 - 0s - loss: 2.0556 - val_loss: 2.0829 - 81ms/epoch - 81ms/step\n",
      "Epoch 383/2000\n",
      "1/1 - 0s - loss: 2.0550 - val_loss: 2.0823 - 83ms/epoch - 83ms/step\n",
      "Epoch 384/2000\n",
      "1/1 - 0s - loss: 2.0547 - val_loss: 2.0830 - 74ms/epoch - 74ms/step\n",
      "Epoch 385/2000\n",
      "1/1 - 0s - loss: 2.0545 - val_loss: 2.0816 - 82ms/epoch - 82ms/step\n",
      "Epoch 386/2000\n",
      "1/1 - 0s - loss: 2.0542 - val_loss: 2.0823 - 73ms/epoch - 73ms/step\n",
      "Epoch 387/2000\n",
      "1/1 - 0s - loss: 2.0538 - val_loss: 2.0809 - 81ms/epoch - 81ms/step\n",
      "Epoch 388/2000\n",
      "1/1 - 0s - loss: 2.0533 - val_loss: 2.0813 - 76ms/epoch - 76ms/step\n",
      "Epoch 389/2000\n",
      "1/1 - 0s - loss: 2.0527 - val_loss: 2.0799 - 81ms/epoch - 81ms/step\n",
      "Epoch 390/2000\n",
      "1/1 - 0s - loss: 2.0521 - val_loss: 2.0800 - 76ms/epoch - 76ms/step\n",
      "Epoch 391/2000\n",
      "1/1 - 0s - loss: 2.0516 - val_loss: 2.0792 - 81ms/epoch - 81ms/step\n",
      "Epoch 392/2000\n",
      "1/1 - 0s - loss: 2.0511 - val_loss: 2.0789 - 78ms/epoch - 78ms/step\n",
      "Epoch 393/2000\n",
      "1/1 - 0s - loss: 2.0507 - val_loss: 2.0789 - 82ms/epoch - 82ms/step\n",
      "Epoch 394/2000\n",
      "1/1 - 0s - loss: 2.0503 - val_loss: 2.0779 - 81ms/epoch - 81ms/step\n",
      "Epoch 395/2000\n",
      "1/1 - 0s - loss: 2.0500 - val_loss: 2.0787 - 75ms/epoch - 75ms/step\n",
      "Epoch 396/2000\n",
      "1/1 - 0s - loss: 2.0498 - val_loss: 2.0773 - 88ms/epoch - 88ms/step\n",
      "Epoch 397/2000\n",
      "1/1 - 0s - loss: 2.0496 - val_loss: 2.0788 - 78ms/epoch - 78ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 398/2000\n",
      "1/1 - 0s - loss: 2.0495 - val_loss: 2.0769 - 84ms/epoch - 84ms/step\n",
      "Epoch 399/2000\n",
      "1/1 - 0s - loss: 2.0494 - val_loss: 2.0785 - 75ms/epoch - 75ms/step\n",
      "Epoch 400/2000\n",
      "1/1 - 0s - loss: 2.0490 - val_loss: 2.0762 - 83ms/epoch - 83ms/step\n",
      "Epoch 401/2000\n",
      "1/1 - 0s - loss: 2.0484 - val_loss: 2.0769 - 75ms/epoch - 75ms/step\n",
      "Epoch 402/2000\n",
      "1/1 - 0s - loss: 2.0475 - val_loss: 2.0751 - 81ms/epoch - 81ms/step\n",
      "Epoch 403/2000\n",
      "1/1 - 0s - loss: 2.0467 - val_loss: 2.0750 - 81ms/epoch - 81ms/step\n",
      "Epoch 404/2000\n",
      "1/1 - 0s - loss: 2.0462 - val_loss: 2.0750 - 81ms/epoch - 81ms/step\n",
      "Epoch 405/2000\n",
      "1/1 - 0s - loss: 2.0459 - val_loss: 2.0740 - 82ms/epoch - 82ms/step\n",
      "Epoch 406/2000\n",
      "1/1 - 0s - loss: 2.0458 - val_loss: 2.0752 - 76ms/epoch - 76ms/step\n",
      "Epoch 407/2000\n",
      "1/1 - 0s - loss: 2.0457 - val_loss: 2.0736 - 81ms/epoch - 81ms/step\n",
      "Epoch 408/2000\n",
      "1/1 - 0s - loss: 2.0456 - val_loss: 2.0750 - 74ms/epoch - 74ms/step\n",
      "Epoch 409/2000\n",
      "1/1 - 0s - loss: 2.0453 - val_loss: 2.0729 - 82ms/epoch - 82ms/step\n",
      "Epoch 410/2000\n",
      "1/1 - 0s - loss: 2.0447 - val_loss: 2.0735 - 75ms/epoch - 75ms/step\n",
      "Epoch 411/2000\n",
      "1/1 - 0s - loss: 2.0439 - val_loss: 2.0720 - 82ms/epoch - 82ms/step\n",
      "Epoch 412/2000\n",
      "1/1 - 0s - loss: 2.0433 - val_loss: 2.0718 - 82ms/epoch - 82ms/step\n",
      "Epoch 413/2000\n",
      "1/1 - 0s - loss: 2.0429 - val_loss: 2.0721 - 75ms/epoch - 75ms/step\n",
      "Epoch 414/2000\n",
      "1/1 - 0s - loss: 2.0427 - val_loss: 2.0711 - 81ms/epoch - 81ms/step\n",
      "Epoch 415/2000\n",
      "1/1 - 0s - loss: 2.0426 - val_loss: 2.0723 - 74ms/epoch - 74ms/step\n",
      "Epoch 416/2000\n",
      "1/1 - 0s - loss: 2.0424 - val_loss: 2.0705 - 81ms/epoch - 81ms/step\n",
      "Epoch 417/2000\n",
      "1/1 - 0s - loss: 2.0421 - val_loss: 2.0716 - 75ms/epoch - 75ms/step\n",
      "Epoch 418/2000\n",
      "1/1 - 0s - loss: 2.0416 - val_loss: 2.0698 - 82ms/epoch - 82ms/step\n",
      "Epoch 419/2000\n",
      "1/1 - 0s - loss: 2.0410 - val_loss: 2.0701 - 77ms/epoch - 77ms/step\n",
      "Epoch 420/2000\n",
      "1/1 - 0s - loss: 2.0405 - val_loss: 2.0695 - 85ms/epoch - 85ms/step\n",
      "Epoch 421/2000\n",
      "1/1 - 0s - loss: 2.0400 - val_loss: 2.0690 - 81ms/epoch - 81ms/step\n",
      "Epoch 422/2000\n",
      "1/1 - 0s - loss: 2.0397 - val_loss: 2.0695 - 78ms/epoch - 78ms/step\n",
      "Epoch 423/2000\n",
      "1/1 - 0s - loss: 2.0395 - val_loss: 2.0684 - 83ms/epoch - 83ms/step\n",
      "Epoch 424/2000\n",
      "1/1 - 0s - loss: 2.0395 - val_loss: 2.0697 - 77ms/epoch - 77ms/step\n",
      "Epoch 425/2000\n",
      "1/1 - 0s - loss: 2.0394 - val_loss: 2.0680 - 81ms/epoch - 81ms/step\n",
      "Epoch 426/2000\n",
      "1/1 - 0s - loss: 2.0393 - val_loss: 2.0699 - 76ms/epoch - 76ms/step\n",
      "Epoch 427/2000\n",
      "1/1 - 0s - loss: 2.0392 - val_loss: 2.0675 - 80ms/epoch - 80ms/step\n",
      "Epoch 428/2000\n",
      "1/1 - 0s - loss: 2.0390 - val_loss: 2.0689 - 74ms/epoch - 74ms/step\n",
      "Epoch 429/2000\n",
      "1/1 - 0s - loss: 2.0383 - val_loss: 2.0666 - 79ms/epoch - 79ms/step\n",
      "Epoch 430/2000\n",
      "1/1 - 0s - loss: 2.0375 - val_loss: 2.0667 - 74ms/epoch - 74ms/step\n",
      "Epoch 431/2000\n",
      "1/1 - 0s - loss: 2.0368 - val_loss: 2.0667 - 73ms/epoch - 73ms/step\n",
      "Epoch 432/2000\n",
      "1/1 - 0s - loss: 2.0365 - val_loss: 2.0656 - 81ms/epoch - 81ms/step\n",
      "Epoch 433/2000\n",
      "1/1 - 0s - loss: 2.0366 - val_loss: 2.0672 - 88ms/epoch - 88ms/step\n",
      "Epoch 434/2000\n",
      "1/1 - 0s - loss: 2.0366 - val_loss: 2.0655 - 81ms/epoch - 81ms/step\n",
      "Epoch 435/2000\n",
      "1/1 - 0s - loss: 2.0366 - val_loss: 2.0668 - 74ms/epoch - 74ms/step\n",
      "Epoch 436/2000\n",
      "1/1 - 0s - loss: 2.0360 - val_loss: 2.0647 - 92ms/epoch - 92ms/step\n",
      "Epoch 437/2000\n",
      "1/1 - 0s - loss: 2.0352 - val_loss: 2.0648 - 78ms/epoch - 78ms/step\n",
      "Epoch 438/2000\n",
      "1/1 - 0s - loss: 2.0345 - val_loss: 2.0649 - 73ms/epoch - 73ms/step\n",
      "Epoch 439/2000\n",
      "1/1 - 0s - loss: 2.0343 - val_loss: 2.0640 - 82ms/epoch - 82ms/step\n",
      "Epoch 440/2000\n",
      "1/1 - 0s - loss: 2.0345 - val_loss: 2.0655 - 75ms/epoch - 75ms/step\n",
      "Epoch 441/2000\n",
      "1/1 - 0s - loss: 2.0345 - val_loss: 2.0634 - 80ms/epoch - 80ms/step\n",
      "Epoch 442/2000\n",
      "1/1 - 0s - loss: 2.0342 - val_loss: 2.0644 - 75ms/epoch - 75ms/step\n",
      "Epoch 443/2000\n",
      "1/1 - 0s - loss: 2.0335 - val_loss: 2.0628 - 83ms/epoch - 83ms/step\n",
      "Epoch 444/2000\n",
      "1/1 - 0s - loss: 2.0328 - val_loss: 2.0627 - 84ms/epoch - 84ms/step\n",
      "Epoch 445/2000\n",
      "1/1 - 0s - loss: 2.0325 - val_loss: 2.0635 - 76ms/epoch - 76ms/step\n",
      "Epoch 446/2000\n",
      "1/1 - 0s - loss: 2.0325 - val_loss: 2.0621 - 82ms/epoch - 82ms/step\n",
      "Epoch 447/2000\n",
      "1/1 - 0s - loss: 2.0326 - val_loss: 2.0636 - 77ms/epoch - 77ms/step\n",
      "Epoch 448/2000\n",
      "1/1 - 0s - loss: 2.0323 - val_loss: 2.0615 - 81ms/epoch - 81ms/step\n",
      "Epoch 449/2000\n",
      "1/1 - 0s - loss: 2.0318 - val_loss: 2.0621 - 76ms/epoch - 76ms/step\n",
      "Epoch 450/2000\n",
      "1/1 - 0s - loss: 2.0311 - val_loss: 2.0613 - 81ms/epoch - 81ms/step\n",
      "Epoch 451/2000\n",
      "1/1 - 0s - loss: 2.0306 - val_loss: 2.0607 - 80ms/epoch - 80ms/step\n",
      "Epoch 452/2000\n",
      "1/1 - 0s - loss: 2.0305 - val_loss: 2.0619 - 75ms/epoch - 75ms/step\n",
      "Epoch 453/2000\n",
      "1/1 - 0s - loss: 2.0305 - val_loss: 2.0604 - 80ms/epoch - 80ms/step\n",
      "Epoch 454/2000\n",
      "1/1 - 0s - loss: 2.0304 - val_loss: 2.0614 - 73ms/epoch - 73ms/step\n",
      "Epoch 455/2000\n",
      "1/1 - 0s - loss: 2.0300 - val_loss: 2.0598 - 79ms/epoch - 79ms/step\n",
      "Epoch 456/2000\n",
      "1/1 - 0s - loss: 2.0295 - val_loss: 2.0601 - 74ms/epoch - 74ms/step\n",
      "Epoch 457/2000\n",
      "1/1 - 0s - loss: 2.0290 - val_loss: 2.0597 - 79ms/epoch - 79ms/step\n",
      "Epoch 458/2000\n",
      "1/1 - 0s - loss: 2.0287 - val_loss: 2.0592 - 78ms/epoch - 78ms/step\n",
      "Epoch 459/2000\n",
      "1/1 - 0s - loss: 2.0286 - val_loss: 2.0600 - 73ms/epoch - 73ms/step\n",
      "Epoch 460/2000\n",
      "1/1 - 0s - loss: 2.0285 - val_loss: 2.0586 - 81ms/epoch - 81ms/step\n",
      "Epoch 461/2000\n",
      "1/1 - 0s - loss: 2.0284 - val_loss: 2.0600 - 75ms/epoch - 75ms/step\n",
      "Epoch 462/2000\n",
      "1/1 - 0s - loss: 2.0282 - val_loss: 2.0582 - 80ms/epoch - 80ms/step\n",
      "Epoch 463/2000\n",
      "1/1 - 0s - loss: 2.0278 - val_loss: 2.0591 - 75ms/epoch - 75ms/step\n",
      "Epoch 464/2000\n",
      "1/1 - 0s - loss: 2.0274 - val_loss: 2.0580 - 82ms/epoch - 82ms/step\n",
      "Epoch 465/2000\n",
      "1/1 - 0s - loss: 2.0269 - val_loss: 2.0580 - 80ms/epoch - 80ms/step\n",
      "Epoch 466/2000\n",
      "1/1 - 0s - loss: 2.0266 - val_loss: 2.0578 - 84ms/epoch - 84ms/step\n",
      "Epoch 467/2000\n",
      "1/1 - 0s - loss: 2.0263 - val_loss: 2.0573 - 81ms/epoch - 81ms/step\n",
      "Epoch 468/2000\n",
      "1/1 - 0s - loss: 2.0261 - val_loss: 2.0577 - 76ms/epoch - 76ms/step\n",
      "Epoch 469/2000\n",
      "1/1 - 0s - loss: 2.0259 - val_loss: 2.0569 - 87ms/epoch - 87ms/step\n",
      "Epoch 470/2000\n",
      "1/1 - 0s - loss: 2.0258 - val_loss: 2.0578 - 76ms/epoch - 76ms/step\n",
      "Epoch 471/2000\n",
      "1/1 - 0s - loss: 2.0257 - val_loss: 2.0564 - 80ms/epoch - 80ms/step\n",
      "Epoch 472/2000\n",
      "1/1 - 0s - loss: 2.0256 - val_loss: 2.0580 - 76ms/epoch - 76ms/step\n",
      "Epoch 473/2000\n",
      "1/1 - 0s - loss: 2.0255 - val_loss: 2.0561 - 81ms/epoch - 81ms/step\n",
      "Epoch 474/2000\n",
      "1/1 - 0s - loss: 2.0253 - val_loss: 2.0577 - 77ms/epoch - 77ms/step\n",
      "Epoch 475/2000\n",
      "1/1 - 0s - loss: 2.0250 - val_loss: 2.0558 - 83ms/epoch - 83ms/step\n",
      "Epoch 476/2000\n",
      "1/1 - 0s - loss: 2.0247 - val_loss: 2.0566 - 84ms/epoch - 84ms/step\n",
      "Epoch 477/2000\n",
      "1/1 - 0s - loss: 2.0242 - val_loss: 2.0554 - 82ms/epoch - 82ms/step\n",
      "Epoch 478/2000\n",
      "1/1 - 0s - loss: 2.0237 - val_loss: 2.0555 - 77ms/epoch - 77ms/step\n",
      "Epoch 479/2000\n",
      "1/1 - 0s - loss: 2.0233 - val_loss: 2.0553 - 83ms/epoch - 83ms/step\n",
      "Epoch 480/2000\n",
      "1/1 - 0s - loss: 2.0231 - val_loss: 2.0548 - 82ms/epoch - 82ms/step\n",
      "Epoch 481/2000\n",
      "1/1 - 0s - loss: 2.0230 - val_loss: 2.0556 - 73ms/epoch - 73ms/step\n",
      "Epoch 482/2000\n",
      "1/1 - 0s - loss: 2.0229 - val_loss: 2.0543 - 81ms/epoch - 81ms/step\n",
      "Epoch 483/2000\n",
      "1/1 - 0s - loss: 2.0229 - val_loss: 2.0561 - 73ms/epoch - 73ms/step\n",
      "Epoch 484/2000\n",
      "1/1 - 0s - loss: 2.0229 - val_loss: 2.0540 - 83ms/epoch - 83ms/step\n",
      "Epoch 485/2000\n",
      "1/1 - 0s - loss: 2.0228 - val_loss: 2.0558 - 74ms/epoch - 74ms/step\n",
      "Epoch 486/2000\n",
      "1/1 - 0s - loss: 2.0225 - val_loss: 2.0539 - 79ms/epoch - 79ms/step\n",
      "Epoch 487/2000\n",
      "1/1 - 0s - loss: 2.0221 - val_loss: 2.0546 - 74ms/epoch - 74ms/step\n",
      "Epoch 488/2000\n",
      "1/1 - 0s - loss: 2.0217 - val_loss: 2.0535 - 81ms/epoch - 81ms/step\n",
      "Epoch 489/2000\n",
      "1/1 - 0s - loss: 2.0212 - val_loss: 2.0535 - 77ms/epoch - 77ms/step\n",
      "Epoch 490/2000\n",
      "1/1 - 0s - loss: 2.0208 - val_loss: 2.0534 - 83ms/epoch - 83ms/step\n",
      "Epoch 491/2000\n",
      "1/1 - 0s - loss: 2.0206 - val_loss: 2.0529 - 83ms/epoch - 83ms/step\n",
      "Epoch 492/2000\n",
      "1/1 - 0s - loss: 2.0206 - val_loss: 2.0538 - 76ms/epoch - 76ms/step\n",
      "Epoch 493/2000\n",
      "1/1 - 0s - loss: 2.0206 - val_loss: 2.0525 - 87ms/epoch - 87ms/step\n",
      "Epoch 494/2000\n",
      "1/1 - 0s - loss: 2.0206 - val_loss: 2.0546 - 83ms/epoch - 83ms/step\n",
      "Epoch 495/2000\n",
      "1/1 - 0s - loss: 2.0207 - val_loss: 2.0522 - 90ms/epoch - 90ms/step\n",
      "Epoch 496/2000\n",
      "1/1 - 0s - loss: 2.0208 - val_loss: 2.0547 - 85ms/epoch - 85ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 497/2000\n",
      "1/1 - 0s - loss: 2.0206 - val_loss: 2.0520 - 94ms/epoch - 94ms/step\n",
      "Epoch 498/2000\n",
      "1/1 - 0s - loss: 2.0201 - val_loss: 2.0528 - 85ms/epoch - 85ms/step\n",
      "Epoch 499/2000\n",
      "1/1 - 0s - loss: 2.0194 - val_loss: 2.0520 - 102ms/epoch - 102ms/step\n",
      "Epoch 500/2000\n",
      "1/1 - 0s - loss: 2.0189 - val_loss: 2.0514 - 91ms/epoch - 91ms/step\n",
      "Epoch 501/2000\n",
      "1/1 - 0s - loss: 2.0186 - val_loss: 2.0521 - 81ms/epoch - 81ms/step\n",
      "Epoch 502/2000\n",
      "1/1 - 0s - loss: 2.0186 - val_loss: 2.0511 - 91ms/epoch - 91ms/step\n",
      "Epoch 503/2000\n",
      "1/1 - 0s - loss: 2.0186 - val_loss: 2.0522 - 82ms/epoch - 82ms/step\n",
      "Epoch 504/2000\n",
      "1/1 - 0s - loss: 2.0185 - val_loss: 2.0508 - 87ms/epoch - 87ms/step\n",
      "Epoch 505/2000\n",
      "1/1 - 0s - loss: 2.0181 - val_loss: 2.0514 - 80ms/epoch - 80ms/step\n",
      "Epoch 506/2000\n",
      "1/1 - 0s - loss: 2.0177 - val_loss: 2.0504 - 91ms/epoch - 91ms/step\n",
      "Epoch 507/2000\n",
      "1/1 - 0s - loss: 2.0174 - val_loss: 2.0508 - 81ms/epoch - 81ms/step\n",
      "Epoch 508/2000\n",
      "1/1 - 0s - loss: 2.0172 - val_loss: 2.0508 - 82ms/epoch - 82ms/step\n",
      "Epoch 509/2000\n",
      "1/1 - 0s - loss: 2.0170 - val_loss: 2.0498 - 90ms/epoch - 90ms/step\n",
      "Epoch 510/2000\n",
      "1/1 - 0s - loss: 2.0169 - val_loss: 2.0511 - 89ms/epoch - 89ms/step\n",
      "Epoch 511/2000\n",
      "1/1 - 0s - loss: 2.0168 - val_loss: 2.0495 - 91ms/epoch - 91ms/step\n",
      "Epoch 512/2000\n",
      "1/1 - 0s - loss: 2.0166 - val_loss: 2.0509 - 86ms/epoch - 86ms/step\n",
      "Epoch 513/2000\n",
      "1/1 - 0s - loss: 2.0164 - val_loss: 2.0494 - 90ms/epoch - 90ms/step\n",
      "Epoch 514/2000\n",
      "1/1 - 0s - loss: 2.0161 - val_loss: 2.0501 - 79ms/epoch - 79ms/step\n",
      "Epoch 515/2000\n",
      "1/1 - 0s - loss: 2.0158 - val_loss: 2.0496 - 93ms/epoch - 93ms/step\n",
      "Epoch 516/2000\n",
      "1/1 - 0s - loss: 2.0156 - val_loss: 2.0494 - 105ms/epoch - 105ms/step\n",
      "Epoch 517/2000\n",
      "1/1 - 0s - loss: 2.0154 - val_loss: 2.0492 - 89ms/epoch - 89ms/step\n",
      "Epoch 518/2000\n",
      "1/1 - 0s - loss: 2.0151 - val_loss: 2.0490 - 103ms/epoch - 103ms/step\n",
      "Epoch 519/2000\n",
      "1/1 - 0s - loss: 2.0149 - val_loss: 2.0489 - 94ms/epoch - 94ms/step\n",
      "Epoch 520/2000\n",
      "1/1 - 0s - loss: 2.0147 - val_loss: 2.0488 - 103ms/epoch - 103ms/step\n",
      "Epoch 521/2000\n",
      "1/1 - 0s - loss: 2.0146 - val_loss: 2.0489 - 105ms/epoch - 105ms/step\n",
      "Epoch 522/2000\n",
      "1/1 - 0s - loss: 2.0145 - val_loss: 2.0483 - 102ms/epoch - 102ms/step\n",
      "Epoch 523/2000\n",
      "1/1 - 0s - loss: 2.0144 - val_loss: 2.0500 - 83ms/epoch - 83ms/step\n",
      "Epoch 524/2000\n",
      "1/1 - 0s - loss: 2.0146 - val_loss: 2.0483 - 81ms/epoch - 81ms/step\n",
      "Epoch 525/2000\n",
      "1/1 - 0s - loss: 2.0152 - val_loss: 2.0522 - 79ms/epoch - 79ms/step\n",
      "Epoch 526/2000\n",
      "1/1 - 0s - loss: 2.0162 - val_loss: 2.0494 - 80ms/epoch - 80ms/step\n",
      "Epoch 527/2000\n",
      "1/1 - 0s - loss: 2.0170 - val_loss: 2.0527 - 103ms/epoch - 103ms/step\n"
     ]
    }
   ],
   "source": [
    "embed_dim = 10\n",
    "num_head = 2\n",
    "\n",
    "callback = keras.callbacks.EarlyStopping(monitor = 'val_loss', patience = 5, restore_best_weights = True)\n",
    "inputs = layers.Input((x_masked_train.shape[1],), dtype=tf.int64)\n",
    "word_embeddings = layers.Embedding(N, embed_dim, name=\"word_embedding\")(inputs)\n",
    "sinusoidal_embeddings = get_sinusoidal_embeddings(len(x_masked_train[0]), embed_dim)\n",
    "encoder_output = word_embeddings + sinusoidal_embeddings\n",
    "\n",
    "for i in range(5):\n",
    "    encoder_output = bert_module(encoder_output, encoder_output, encoder_output, embed_dim, num_head, i)\n",
    "\n",
    "encoder_output = keras.layers.Lambda(lambda x: x[:,:-1,:], name='slice')(encoder_output)\n",
    "mlm_output = layers.Dense(N, name=\"mlm_cls\", activation=\"softmax\")(encoder_output)\n",
    "mlm_model = keras.Model(inputs = inputs, outputs = mlm_output)\n",
    "adam = Adam()\n",
    "mlm_model.compile(loss='sparse_categorical_crossentropy', optimizer=adam)\n",
    "history = mlm_model.fit(x_masked_train, y_masked_labels_train,\n",
    "                        validation_split = 0.5, callbacks = [callback], \n",
    "                        epochs=2000, batch_size=5000, \n",
    "                        verbose=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "dfa1251c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate prediction of the final token on a random sample of test sequences.\n",
    "# The sample size is clamped so a test set with fewer than 1000 rows does not raise.\n",
    "n_eval = min(1000, x_masked_test.shape[0])\n",
    "x_test_subset = x_masked_test[np.random.choice(x_masked_test.shape[0], size=n_eval, replace=False)]\n",
    "\n",
    "# One batched forward pass instead of rebuilding a backend function for every sentence.\n",
    "# The model output already excludes the last input position (see the 'slice' layer), so\n",
    "# index -1 here is the distribution produced at the second-to-last input position.\n",
    "last_step_probs = mlm_model.predict(x_test_subset, verbose=0)[:, -1, :]\n",
    "true_last_tokens = x_test_subset[:, -1]\n",
    "\n",
    "# acc: 1 when the most probable token equals the true final token, else 0.\n",
    "# prob: probability the model assigned to the true final token.\n",
    "acc = [int(hit) for hit in last_step_probs.argmax(axis=1) == true_last_tokens]\n",
    "prob = list(last_step_probs[np.arange(len(true_last_tokens)), true_last_tokens])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "a0d91607",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1.0, 0.9048236)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(np.mean(acc), np.mean(prob))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "bb91f983",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2000\n",
      "1/1 - 5s - loss: 3.7975 - val_loss: 3.5023 - 5s/epoch - 5s/step\n",
      "Epoch 2/2000\n",
      "1/1 - 0s - loss: 3.4826 - val_loss: 3.3289 - 477ms/epoch - 477ms/step\n",
      "Epoch 3/2000\n",
      "1/1 - 1s - loss: 3.2997 - val_loss: 3.2003 - 560ms/epoch - 560ms/step\n",
      "Epoch 4/2000\n",
      "1/1 - 1s - loss: 3.1685 - val_loss: 3.1501 - 596ms/epoch - 596ms/step\n",
      "Epoch 5/2000\n",
      "1/1 - 1s - loss: 3.1184 - val_loss: 3.1181 - 558ms/epoch - 558ms/step\n",
      "Epoch 6/2000\n",
      "1/1 - 1s - loss: 3.0889 - val_loss: 3.0813 - 597ms/epoch - 597ms/step\n",
      "Epoch 7/2000\n",
      "1/1 - 1s - loss: 3.0556 - val_loss: 3.0442 - 582ms/epoch - 582ms/step\n",
      "Epoch 8/2000\n",
      "1/1 - 1s - loss: 3.0219 - val_loss: 3.0110 - 600ms/epoch - 600ms/step\n",
      "Epoch 9/2000\n",
      "1/1 - 1s - loss: 2.9908 - val_loss: 2.9813 - 629ms/epoch - 629ms/step\n",
      "Epoch 10/2000\n",
      "1/1 - 0s - loss: 2.9611 - val_loss: 2.9534 - 478ms/epoch - 478ms/step\n",
      "Epoch 11/2000\n",
      "1/1 - 1s - loss: 2.9311 - val_loss: 2.9291 - 552ms/epoch - 552ms/step\n",
      "Epoch 12/2000\n",
      "1/1 - 0s - loss: 2.9033 - val_loss: 2.9102 - 464ms/epoch - 464ms/step\n",
      "Epoch 13/2000\n",
      "1/1 - 0s - loss: 2.8800 - val_loss: 2.8941 - 448ms/epoch - 448ms/step\n",
      "Epoch 14/2000\n",
      "1/1 - 1s - loss: 2.8587 - val_loss: 2.8718 - 531ms/epoch - 531ms/step\n",
      "Epoch 15/2000\n",
      "1/1 - 1s - loss: 2.8308 - val_loss: 2.8378 - 577ms/epoch - 577ms/step\n",
      "Epoch 16/2000\n",
      "1/1 - 0s - loss: 2.7908 - val_loss: 2.7953 - 483ms/epoch - 483ms/step\n",
      "Epoch 17/2000\n",
      "1/1 - 0s - loss: 2.7412 - val_loss: 2.7492 - 445ms/epoch - 445ms/step\n",
      "Epoch 18/2000\n",
      "1/1 - 1s - loss: 2.6870 - val_loss: 2.7030 - 524ms/epoch - 524ms/step\n",
      "Epoch 19/2000\n",
      "1/1 - 1s - loss: 2.6323 - val_loss: 2.6620 - 527ms/epoch - 527ms/step\n",
      "Epoch 20/2000\n",
      "1/1 - 1s - loss: 2.5830 - val_loss: 2.6309 - 565ms/epoch - 565ms/step\n",
      "Epoch 21/2000\n",
      "1/1 - 1s - loss: 2.5439 - val_loss: 2.6094 - 617ms/epoch - 617ms/step\n",
      "Epoch 22/2000\n",
      "1/1 - 1s - loss: 2.5158 - val_loss: 2.5952 - 564ms/epoch - 564ms/step\n",
      "Epoch 23/2000\n",
      "1/1 - 0s - loss: 2.4964 - val_loss: 2.5812 - 448ms/epoch - 448ms/step\n",
      "Epoch 24/2000\n",
      "1/1 - 1s - loss: 2.4787 - val_loss: 2.5632 - 533ms/epoch - 533ms/step\n",
      "Epoch 25/2000\n",
      "1/1 - 0s - loss: 2.4583 - val_loss: 2.5429 - 436ms/epoch - 436ms/step\n",
      "Epoch 26/2000\n",
      "1/1 - 1s - loss: 2.4376 - val_loss: 2.5241 - 624ms/epoch - 624ms/step\n",
      "Epoch 27/2000\n",
      "1/1 - 1s - loss: 2.4205 - val_loss: 2.5032 - 564ms/epoch - 564ms/step\n",
      "Epoch 28/2000\n",
      "1/1 - 1s - loss: 2.4023 - val_loss: 2.4778 - 605ms/epoch - 605ms/step\n",
      "Epoch 29/2000\n",
      "1/1 - 1s - loss: 2.3789 - val_loss: 2.4537 - 613ms/epoch - 613ms/step\n",
      "Epoch 30/2000\n",
      "1/1 - 0s - loss: 2.3558 - val_loss: 2.4314 - 435ms/epoch - 435ms/step\n",
      "Epoch 31/2000\n",
      "1/1 - 0s - loss: 2.3336 - val_loss: 2.4018 - 436ms/epoch - 436ms/step\n",
      "Epoch 32/2000\n",
      "1/1 - 0s - loss: 2.3041 - val_loss: 2.3690 - 489ms/epoch - 489ms/step\n",
      "Epoch 33/2000\n",
      "1/1 - 0s - loss: 2.2717 - val_loss: 2.3347 - 460ms/epoch - 460ms/step\n",
      "Epoch 34/2000\n",
      "1/1 - 1s - loss: 2.2378 - val_loss: 2.2964 - 516ms/epoch - 516ms/step\n",
      "Epoch 35/2000\n",
      "1/1 - 1s - loss: 2.1991 - val_loss: 2.2553 - 562ms/epoch - 562ms/step\n",
      "Epoch 36/2000\n",
      "1/1 - 0s - loss: 2.1590 - val_loss: 2.2107 - 475ms/epoch - 475ms/step\n",
      "Epoch 37/2000\n",
      "1/1 - 1s - loss: 2.1181 - val_loss: 2.1728 - 594ms/epoch - 594ms/step\n",
      "Epoch 38/2000\n",
      "1/1 - 1s - loss: 2.0844 - val_loss: 2.1410 - 507ms/epoch - 507ms/step\n",
      "Epoch 39/2000\n",
      "1/1 - 1s - loss: 2.0553 - val_loss: 2.1171 - 530ms/epoch - 530ms/step\n",
      "Epoch 40/2000\n",
      "1/1 - 1s - loss: 2.0353 - val_loss: 2.0938 - 581ms/epoch - 581ms/step\n",
      "Epoch 41/2000\n",
      "1/1 - 0s - loss: 2.0177 - val_loss: 2.0769 - 465ms/epoch - 465ms/step\n",
      "Epoch 42/2000\n",
      "1/1 - 0s - loss: 2.0056 - val_loss: 2.0652 - 465ms/epoch - 465ms/step\n",
      "Epoch 43/2000\n",
      "1/1 - 1s - loss: 1.9975 - val_loss: 2.0533 - 537ms/epoch - 537ms/step\n",
      "Epoch 44/2000\n",
      "1/1 - 1s - loss: 1.9890 - val_loss: 2.0495 - 645ms/epoch - 645ms/step\n",
      "Epoch 45/2000\n",
      "1/1 - 0s - loss: 1.9876 - val_loss: 2.0468 - 440ms/epoch - 440ms/step\n",
      "Epoch 46/2000\n",
      "1/1 - 1s - loss: 1.9856 - val_loss: 2.0431 - 668ms/epoch - 668ms/step\n",
      "Epoch 47/2000\n",
      "1/1 - 0s - loss: 1.9818 - val_loss: 2.0415 - 485ms/epoch - 485ms/step\n",
      "Epoch 48/2000\n",
      "1/1 - 0s - loss: 1.9793 - val_loss: 2.0412 - 455ms/epoch - 455ms/step\n",
      "Epoch 49/2000\n",
      "1/1 - 0s - loss: 1.9774 - val_loss: 2.0412 - 464ms/epoch - 464ms/step\n",
      "Epoch 50/2000\n",
      "1/1 - 1s - loss: 1.9752 - val_loss: 2.0429 - 586ms/epoch - 586ms/step\n",
      "Epoch 51/2000\n",
      "1/1 - 1s - loss: 1.9746 - val_loss: 2.0421 - 571ms/epoch - 571ms/step\n",
      "Epoch 52/2000\n",
      "1/1 - 1s - loss: 1.9720 - val_loss: 2.0420 - 509ms/epoch - 509ms/step\n",
      "Epoch 53/2000\n",
      "1/1 - 1s - loss: 1.9703 - val_loss: 2.0430 - 546ms/epoch - 546ms/step\n"
     ]
    }
   ],
   "source": [
    "embed_dim = 100\n",
    "num_head = 2\n",
    "\n",
    "callback = keras.callbacks.EarlyStopping(monitor = 'val_loss', patience = 5, restore_best_weights = True)\n",
    "inputs = layers.Input((x_masked_train.shape[1],), dtype=tf.int64)\n",
    "word_embeddings = layers.Embedding(N, embed_dim, name=\"word_embedding\")(inputs)\n",
    "sinusoidal_embeddings = get_sinusoidal_embeddings(len(x_masked_train[0]), embed_dim)\n",
    "encoder_output = word_embeddings + sinusoidal_embeddings\n",
    "\n",
    "for i in range(5):\n",
    "    encoder_output = bert_module(encoder_output, encoder_output, encoder_output, embed_dim, num_head, i)\n",
    "\n",
    "encoder_output = keras.layers.Lambda(lambda x: x[:,:-1,:], name='slice')(encoder_output)\n",
    "mlm_output = layers.Dense(N, name=\"mlm_cls\", activation=\"softmax\")(encoder_output)\n",
    "mlm_model = keras.Model(inputs = inputs, outputs = mlm_output)\n",
    "adam = Adam()\n",
    "mlm_model.compile(loss='sparse_categorical_crossentropy', optimizer=adam)\n",
    "history = mlm_model.fit(x_masked_train, y_masked_labels_train,\n",
    "                        validation_split = 0.5, callbacks = [callback], \n",
    "                        epochs=2000, batch_size=5000, \n",
    "                        verbose=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "81e26fe5",
   "metadata": {},
   "outputs": [],
   "source": [
    "acc = []\n",
    "prob = []\n",
    "x_test_subset = x_masked_test[np.random.choice(x_masked_test.shape[0], size=1000, replace=False)]\n",
    "\n",
    "for sentence_number in x_test_subset:\n",
    "    temp = keras.backend.function(inputs = mlm_model.layers[0].input, outputs = mlm_model.layers[-1].output) \\\n",
    "        (np.array(sentence_number).reshape(1,len(sentence_number)))\n",
    "    temp = temp[:,-1,:]\n",
    "    acc.append(1 if temp.argmax() == sentence_number[-1] else 0)\n",
    "    prob.append(temp[0][sentence_number[-1]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "267160c5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1.0, 0.98095894)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(np.mean(acc), np.mean(prob))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f4cab17",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
