{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d1686d53",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy.random as npr\n",
    "import random\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "from tensorflow.keras import backend as K\n",
    "from keras.optimizers import Adam\n",
    "from keras_nlp.layers import PositionEmbedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "56c107e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 428\n",
    "\n",
    "np.random.seed(seed)\n",
    "tf.random.set_seed(seed)\n",
    "random.seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "504a342e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def bert_module(query, key, value, embed_dim, num_head, i):\n",
    "    \n",
    "    # Multi headed self-attention\n",
    "    attention_output = layers.MultiHeadAttention(\n",
    "        num_heads=num_head,\n",
    "        key_dim=embed_dim // num_head,\n",
    "        name=\"encoder_{}/multiheadattention\".format(i)\n",
    "    )(query, key, value, use_causal_mask=True)\n",
    "    \n",
    "    # Add & Normalize\n",
    "    attention_output = layers.Add()([query, attention_output])  # Skip Connection\n",
    "    attention_output = layers.LayerNormalization(epsilon=1e-6)(attention_output)\n",
    "    \n",
    "    # Feedforward network\n",
    "    ff_net = keras.models.Sequential([\n",
    "        layers.Dense(2 * embed_dim, activation='relu', name=\"encoder_{}/ffn_dense_1\".format(i)),\n",
    "        layers.Dense(embed_dim, name=\"encoder_{}/ffn_dense_2\".format(i)),\n",
    "    ])\n",
    "\n",
    "    # Apply Feedforward network\n",
    "    ffn_output = ff_net(attention_output)\n",
    "\n",
    "    # Add & Normalize\n",
    "    ffn_output = layers.Add()([attention_output, ffn_output])  # Skip Connection\n",
    "    ffn_output = layers.LayerNormalization(epsilon=1e-6)(ffn_output)\n",
    "    \n",
    "    return ffn_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "fef2622a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_sinusoidal_embeddings(sequence_length, embedding_dim):\n",
    "    position_enc = np.array([\n",
    "        [pos / np.power(10000, 2. * i / embedding_dim) for i in range(embedding_dim)]\n",
    "        if pos != 0 else np.zeros(embedding_dim)\n",
    "        for pos in range(sequence_length)\n",
    "    ])\n",
    "    position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i\n",
    "    position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # dim 2i+1\n",
    "    return tf.cast(position_enc, dtype=tf.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3e5bfd69",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 20 # vocab_size\n",
    "\n",
    "vocabs = ['word_' + str(i) for i in range(N)]\n",
    "\n",
    "vocab_map = {}\n",
    "for i in range(len(vocabs)):\n",
    "    vocab_map[vocabs[i]] = i\n",
    "    \n",
    "pairs = []\n",
    "\n",
    "for i in vocabs:\n",
    "    for j in vocabs:\n",
    "        for k in vocabs:\n",
    "            if i != j and i != k and j != k:\n",
    "                pairs.append((i,j,k))\n",
    "            \n",
    "indicator = np.random.choice([0, 1], size=len(pairs), p=[0.5, 0.5])\n",
    "\n",
    "pairs_train = [pairs[i] for i in range(len(indicator)) if indicator[i] == 1]\n",
    "pairs_test = [pairs[i] for i in range(len(indicator)) if indicator[i] == 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d43093fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "sentences_train = []\n",
    "sentences_number_train = []\n",
    "sentences_test = []\n",
    "sentences_number_test = []\n",
    "\n",
    "for pair in pairs_train:\n",
    "    sentences_train.append([pair[0], pair[1], pair[2], pair[0]])\n",
    "    sentences_number_train.append([vocab_map[pair[0]], vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    \n",
    "for pair in pairs_test:\n",
    "    sentences_test.append([pair[0], pair[1], pair[2], pair[0]])\n",
    "    sentences_number_test.append([vocab_map[pair[0]], vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    \n",
    "x_masked_train = []\n",
    "y_masked_labels_train = []\n",
    "x_masked_test = []\n",
    "y_masked_labels_test = []\n",
    "\n",
    "for pair in pairs_train:\n",
    "    x_masked_train.append([vocab_map[pair[0]], vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    y_masked_labels_train.append([vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    \n",
    "for pair in pairs_test:\n",
    "    x_masked_test.append([vocab_map[pair[0]], vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    y_masked_labels_test.append([vocab_map[pair[1]], vocab_map[pair[2]], vocab_map[pair[0]]])\n",
    "    \n",
    "x_masked_train = np.array(x_masked_train)\n",
    "y_masked_labels_train = np.array(y_masked_labels_train)\n",
    "x_masked_test = np.array(x_masked_test)\n",
    "y_masked_labels_test = np.array(y_masked_labels_test)\n",
    "\n",
    "perm = np.random.permutation(len(x_masked_train))\n",
    "x_masked_train = x_masked_train[perm]\n",
    "y_masked_labels_train = y_masked_labels_train[perm]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "13b40f89",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-05-02 17:57:37.181665: W tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory\n",
      "2024-05-02 17:57:37.181755: W tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: UNKNOWN ERROR (303)\n",
      "2024-05-02 17:57:37.181780: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (gl3384.arc-ts.umich.edu): /proc/driver/nvidia/version does not exist\n",
      "2024-05-02 17:57:37.182089: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA\n",
      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2000\n",
      "1/1 - 5s - loss: 3.2596 - val_loss: 3.2190 - 5s/epoch - 5s/step\n",
      "Epoch 2/2000\n",
      "1/1 - 0s - loss: 3.2195 - val_loss: 3.1927 - 147ms/epoch - 147ms/step\n",
      "Epoch 3/2000\n",
      "1/1 - 0s - loss: 3.1926 - val_loss: 3.1738 - 149ms/epoch - 149ms/step\n",
      "Epoch 4/2000\n",
      "1/1 - 0s - loss: 3.1731 - val_loss: 3.1591 - 135ms/epoch - 135ms/step\n",
      "Epoch 5/2000\n",
      "1/1 - 0s - loss: 3.1580 - val_loss: 3.1468 - 138ms/epoch - 138ms/step\n",
      "Epoch 6/2000\n",
      "1/1 - 0s - loss: 3.1449 - val_loss: 3.1362 - 160ms/epoch - 160ms/step\n",
      "Epoch 7/2000\n",
      "1/1 - 0s - loss: 3.1334 - val_loss: 3.1273 - 141ms/epoch - 141ms/step\n",
      "Epoch 8/2000\n",
      "1/1 - 0s - loss: 3.1236 - val_loss: 3.1202 - 144ms/epoch - 144ms/step\n",
      "Epoch 9/2000\n",
      "1/1 - 0s - loss: 3.1155 - val_loss: 3.1145 - 144ms/epoch - 144ms/step\n",
      "Epoch 10/2000\n",
      "1/1 - 0s - loss: 3.1089 - val_loss: 3.1101 - 137ms/epoch - 137ms/step\n",
      "Epoch 11/2000\n",
      "1/1 - 0s - loss: 3.1036 - val_loss: 3.1062 - 135ms/epoch - 135ms/step\n",
      "Epoch 12/2000\n",
      "1/1 - 0s - loss: 3.0988 - val_loss: 3.1026 - 143ms/epoch - 143ms/step\n",
      "Epoch 13/2000\n",
      "1/1 - 0s - loss: 3.0943 - val_loss: 3.0987 - 141ms/epoch - 141ms/step\n",
      "Epoch 14/2000\n",
      "1/1 - 0s - loss: 3.0897 - val_loss: 3.0945 - 132ms/epoch - 132ms/step\n",
      "Epoch 15/2000\n",
      "1/1 - 0s - loss: 3.0850 - val_loss: 3.0903 - 131ms/epoch - 131ms/step\n",
      "Epoch 16/2000\n",
      "1/1 - 0s - loss: 3.0804 - val_loss: 3.0862 - 125ms/epoch - 125ms/step\n",
      "Epoch 17/2000\n",
      "1/1 - 0s - loss: 3.0759 - val_loss: 3.0824 - 147ms/epoch - 147ms/step\n",
      "Epoch 18/2000\n",
      "1/1 - 0s - loss: 3.0719 - val_loss: 3.0789 - 162ms/epoch - 162ms/step\n",
      "Epoch 19/2000\n",
      "1/1 - 0s - loss: 3.0682 - val_loss: 3.0757 - 134ms/epoch - 134ms/step\n",
      "Epoch 20/2000\n",
      "1/1 - 0s - loss: 3.0648 - val_loss: 3.0727 - 159ms/epoch - 159ms/step\n",
      "Epoch 21/2000\n",
      "1/1 - 0s - loss: 3.0617 - val_loss: 3.0698 - 125ms/epoch - 125ms/step\n",
      "Epoch 22/2000\n",
      "1/1 - 0s - loss: 3.0587 - val_loss: 3.0670 - 124ms/epoch - 124ms/step\n",
      "Epoch 23/2000\n",
      "1/1 - 0s - loss: 3.0559 - val_loss: 3.0644 - 123ms/epoch - 123ms/step\n",
      "Epoch 24/2000\n",
      "1/1 - 0s - loss: 3.0531 - val_loss: 3.0618 - 135ms/epoch - 135ms/step\n",
      "Epoch 25/2000\n",
      "1/1 - 0s - loss: 3.0504 - val_loss: 3.0592 - 124ms/epoch - 124ms/step\n",
      "Epoch 26/2000\n",
      "1/1 - 0s - loss: 3.0478 - val_loss: 3.0568 - 127ms/epoch - 127ms/step\n",
      "Epoch 27/2000\n",
      "1/1 - 0s - loss: 3.0453 - val_loss: 3.0545 - 122ms/epoch - 122ms/step\n",
      "Epoch 28/2000\n",
      "1/1 - 0s - loss: 3.0430 - val_loss: 3.0522 - 123ms/epoch - 123ms/step\n",
      "Epoch 29/2000\n",
      "1/1 - 0s - loss: 3.0406 - val_loss: 3.0499 - 132ms/epoch - 132ms/step\n",
      "Epoch 30/2000\n",
      "1/1 - 0s - loss: 3.0383 - val_loss: 3.0476 - 151ms/epoch - 151ms/step\n",
      "Epoch 31/2000\n",
      "1/1 - 0s - loss: 3.0361 - val_loss: 3.0454 - 132ms/epoch - 132ms/step\n",
      "Epoch 32/2000\n",
      "1/1 - 0s - loss: 3.0339 - val_loss: 3.0432 - 135ms/epoch - 135ms/step\n",
      "Epoch 33/2000\n",
      "1/1 - 0s - loss: 3.0317 - val_loss: 3.0410 - 122ms/epoch - 122ms/step\n",
      "Epoch 34/2000\n",
      "1/1 - 0s - loss: 3.0296 - val_loss: 3.0389 - 126ms/epoch - 126ms/step\n",
      "Epoch 35/2000\n",
      "1/1 - 0s - loss: 3.0276 - val_loss: 3.0369 - 121ms/epoch - 121ms/step\n",
      "Epoch 36/2000\n",
      "1/1 - 0s - loss: 3.0256 - val_loss: 3.0349 - 140ms/epoch - 140ms/step\n",
      "Epoch 37/2000\n",
      "1/1 - 0s - loss: 3.0237 - val_loss: 3.0329 - 129ms/epoch - 129ms/step\n",
      "Epoch 38/2000\n",
      "1/1 - 0s - loss: 3.0217 - val_loss: 3.0309 - 122ms/epoch - 122ms/step\n",
      "Epoch 39/2000\n",
      "1/1 - 0s - loss: 3.0198 - val_loss: 3.0290 - 156ms/epoch - 156ms/step\n",
      "Epoch 40/2000\n",
      "1/1 - 0s - loss: 3.0179 - val_loss: 3.0270 - 151ms/epoch - 151ms/step\n",
      "Epoch 41/2000\n",
      "1/1 - 0s - loss: 3.0160 - val_loss: 3.0251 - 154ms/epoch - 154ms/step\n",
      "Epoch 42/2000\n",
      "1/1 - 0s - loss: 3.0141 - val_loss: 3.0231 - 164ms/epoch - 164ms/step\n",
      "Epoch 43/2000\n",
      "1/1 - 0s - loss: 3.0123 - val_loss: 3.0212 - 122ms/epoch - 122ms/step\n",
      "Epoch 44/2000\n",
      "1/1 - 0s - loss: 3.0104 - val_loss: 3.0193 - 160ms/epoch - 160ms/step\n",
      "Epoch 45/2000\n",
      "1/1 - 0s - loss: 3.0086 - val_loss: 3.0173 - 139ms/epoch - 139ms/step\n",
      "Epoch 46/2000\n",
      "1/1 - 0s - loss: 3.0067 - val_loss: 3.0153 - 138ms/epoch - 138ms/step\n",
      "Epoch 47/2000\n",
      "1/1 - 0s - loss: 3.0048 - val_loss: 3.0131 - 145ms/epoch - 145ms/step\n",
      "Epoch 48/2000\n",
      "1/1 - 0s - loss: 3.0028 - val_loss: 3.0109 - 139ms/epoch - 139ms/step\n",
      "Epoch 49/2000\n",
      "1/1 - 0s - loss: 3.0007 - val_loss: 3.0086 - 125ms/epoch - 125ms/step\n",
      "Epoch 50/2000\n",
      "1/1 - 0s - loss: 2.9985 - val_loss: 3.0062 - 146ms/epoch - 146ms/step\n",
      "Epoch 51/2000\n",
      "1/1 - 0s - loss: 2.9961 - val_loss: 3.0035 - 140ms/epoch - 140ms/step\n",
      "Epoch 52/2000\n",
      "1/1 - 0s - loss: 2.9936 - val_loss: 3.0007 - 164ms/epoch - 164ms/step\n",
      "Epoch 53/2000\n",
      "1/1 - 0s - loss: 2.9909 - val_loss: 2.9977 - 124ms/epoch - 124ms/step\n",
      "Epoch 54/2000\n",
      "1/1 - 0s - loss: 2.9880 - val_loss: 2.9945 - 149ms/epoch - 149ms/step\n",
      "Epoch 55/2000\n",
      "1/1 - 0s - loss: 2.9849 - val_loss: 2.9910 - 159ms/epoch - 159ms/step\n",
      "Epoch 56/2000\n",
      "1/1 - 0s - loss: 2.9815 - val_loss: 2.9872 - 144ms/epoch - 144ms/step\n",
      "Epoch 57/2000\n",
      "1/1 - 0s - loss: 2.9778 - val_loss: 2.9831 - 197ms/epoch - 197ms/step\n",
      "Epoch 58/2000\n",
      "1/1 - 0s - loss: 2.9739 - val_loss: 2.9787 - 165ms/epoch - 165ms/step\n",
      "Epoch 59/2000\n",
      "1/1 - 0s - loss: 2.9696 - val_loss: 2.9740 - 175ms/epoch - 175ms/step\n",
      "Epoch 60/2000\n",
      "1/1 - 0s - loss: 2.9651 - val_loss: 2.9690 - 155ms/epoch - 155ms/step\n",
      "Epoch 61/2000\n",
      "1/1 - 0s - loss: 2.9604 - val_loss: 2.9639 - 160ms/epoch - 160ms/step\n",
      "Epoch 62/2000\n",
      "1/1 - 0s - loss: 2.9554 - val_loss: 2.9585 - 142ms/epoch - 142ms/step\n",
      "Epoch 63/2000\n",
      "1/1 - 0s - loss: 2.9504 - val_loss: 2.9531 - 176ms/epoch - 176ms/step\n",
      "Epoch 64/2000\n",
      "1/1 - 0s - loss: 2.9454 - val_loss: 2.9476 - 133ms/epoch - 133ms/step\n",
      "Epoch 65/2000\n",
      "1/1 - 0s - loss: 2.9404 - val_loss: 2.9424 - 130ms/epoch - 130ms/step\n",
      "Epoch 66/2000\n",
      "1/1 - 0s - loss: 2.9356 - val_loss: 2.9373 - 133ms/epoch - 133ms/step\n",
      "Epoch 67/2000\n",
      "1/1 - 0s - loss: 2.9310 - val_loss: 2.9324 - 139ms/epoch - 139ms/step\n",
      "Epoch 68/2000\n",
      "1/1 - 0s - loss: 2.9264 - val_loss: 2.9272 - 142ms/epoch - 142ms/step\n",
      "Epoch 69/2000\n",
      "1/1 - 0s - loss: 2.9214 - val_loss: 2.9217 - 141ms/epoch - 141ms/step\n",
      "Epoch 70/2000\n",
      "1/1 - 0s - loss: 2.9159 - val_loss: 2.9157 - 141ms/epoch - 141ms/step\n",
      "Epoch 71/2000\n",
      "1/1 - 0s - loss: 2.9098 - val_loss: 2.9094 - 143ms/epoch - 143ms/step\n",
      "Epoch 72/2000\n",
      "1/1 - 0s - loss: 2.9034 - val_loss: 2.9032 - 151ms/epoch - 151ms/step\n",
      "Epoch 73/2000\n",
      "1/1 - 0s - loss: 2.8969 - val_loss: 2.8970 - 148ms/epoch - 148ms/step\n",
      "Epoch 74/2000\n",
      "1/1 - 0s - loss: 2.8906 - val_loss: 2.8909 - 149ms/epoch - 149ms/step\n",
      "Epoch 75/2000\n",
      "1/1 - 0s - loss: 2.8844 - val_loss: 2.8849 - 153ms/epoch - 153ms/step\n",
      "Epoch 76/2000\n",
      "1/1 - 0s - loss: 2.8783 - val_loss: 2.8788 - 152ms/epoch - 152ms/step\n",
      "Epoch 77/2000\n",
      "1/1 - 0s - loss: 2.8721 - val_loss: 2.8727 - 170ms/epoch - 170ms/step\n",
      "Epoch 78/2000\n",
      "1/1 - 0s - loss: 2.8658 - val_loss: 2.8664 - 177ms/epoch - 177ms/step\n",
      "Epoch 79/2000\n",
      "1/1 - 0s - loss: 2.8593 - val_loss: 2.8601 - 163ms/epoch - 163ms/step\n",
      "Epoch 80/2000\n",
      "1/1 - 0s - loss: 2.8526 - val_loss: 2.8538 - 169ms/epoch - 169ms/step\n",
      "Epoch 81/2000\n",
      "1/1 - 0s - loss: 2.8458 - val_loss: 2.8474 - 163ms/epoch - 163ms/step\n",
      "Epoch 82/2000\n",
      "1/1 - 0s - loss: 2.8388 - val_loss: 2.8410 - 161ms/epoch - 161ms/step\n",
      "Epoch 83/2000\n",
      "1/1 - 0s - loss: 2.8317 - val_loss: 2.8348 - 154ms/epoch - 154ms/step\n",
      "Epoch 84/2000\n",
      "1/1 - 0s - loss: 2.8247 - val_loss: 2.8290 - 152ms/epoch - 152ms/step\n",
      "Epoch 85/2000\n",
      "1/1 - 0s - loss: 2.8181 - val_loss: 2.8240 - 154ms/epoch - 154ms/step\n",
      "Epoch 86/2000\n",
      "1/1 - 0s - loss: 2.8123 - val_loss: 2.8193 - 186ms/epoch - 186ms/step\n",
      "Epoch 87/2000\n",
      "1/1 - 0s - loss: 2.8072 - val_loss: 2.8145 - 193ms/epoch - 193ms/step\n",
      "Epoch 88/2000\n",
      "1/1 - 0s - loss: 2.8020 - val_loss: 2.8093 - 189ms/epoch - 189ms/step\n",
      "Epoch 89/2000\n",
      "1/1 - 0s - loss: 2.7966 - val_loss: 2.8040 - 176ms/epoch - 176ms/step\n",
      "Epoch 90/2000\n",
      "1/1 - 0s - loss: 2.7912 - val_loss: 2.7988 - 115ms/epoch - 115ms/step\n",
      "Epoch 91/2000\n",
      "1/1 - 0s - loss: 2.7861 - val_loss: 2.7935 - 132ms/epoch - 132ms/step\n",
      "Epoch 92/2000\n",
      "1/1 - 0s - loss: 2.7812 - val_loss: 2.7899 - 132ms/epoch - 132ms/step\n",
      "Epoch 93/2000\n",
      "1/1 - 0s - loss: 2.7777 - val_loss: 2.7864 - 170ms/epoch - 170ms/step\n",
      "Epoch 94/2000\n",
      "1/1 - 0s - loss: 2.7741 - val_loss: 2.7783 - 189ms/epoch - 189ms/step\n",
      "Epoch 95/2000\n",
      "1/1 - 0s - loss: 2.7665 - val_loss: 2.7732 - 187ms/epoch - 187ms/step\n",
      "Epoch 96/2000\n",
      "1/1 - 0s - loss: 2.7613 - val_loss: 2.7712 - 203ms/epoch - 203ms/step\n",
      "Epoch 97/2000\n",
      "1/1 - 0s - loss: 2.7584 - val_loss: 2.7627 - 162ms/epoch - 162ms/step\n",
      "Epoch 98/2000\n",
      "1/1 - 0s - loss: 2.7504 - val_loss: 2.7571 - 174ms/epoch - 174ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 99/2000\n",
      "1/1 - 0s - loss: 2.7444 - val_loss: 2.7546 - 154ms/epoch - 154ms/step\n",
      "Epoch 100/2000\n",
      "1/1 - 0s - loss: 2.7405 - val_loss: 2.7467 - 152ms/epoch - 152ms/step\n",
      "Epoch 101/2000\n",
      "1/1 - 0s - loss: 2.7330 - val_loss: 2.7402 - 151ms/epoch - 151ms/step\n",
      "Epoch 102/2000\n",
      "1/1 - 0s - loss: 2.7257 - val_loss: 2.7379 - 163ms/epoch - 163ms/step\n",
      "Epoch 103/2000\n",
      "1/1 - 0s - loss: 2.7224 - val_loss: 2.7326 - 174ms/epoch - 174ms/step\n",
      "Epoch 104/2000\n",
      "1/1 - 0s - loss: 2.7173 - val_loss: 2.7229 - 153ms/epoch - 153ms/step\n",
      "Epoch 105/2000\n",
      "1/1 - 0s - loss: 2.7067 - val_loss: 2.7200 - 153ms/epoch - 153ms/step\n",
      "Epoch 106/2000\n",
      "1/1 - 0s - loss: 2.7025 - val_loss: 2.7160 - 157ms/epoch - 157ms/step\n",
      "Epoch 107/2000\n",
      "1/1 - 0s - loss: 2.6992 - val_loss: 2.7059 - 144ms/epoch - 144ms/step\n",
      "Epoch 108/2000\n",
      "1/1 - 0s - loss: 2.6871 - val_loss: 2.7001 - 140ms/epoch - 140ms/step\n",
      "Epoch 109/2000\n",
      "1/1 - 0s - loss: 2.6807 - val_loss: 2.6956 - 139ms/epoch - 139ms/step\n",
      "Epoch 110/2000\n",
      "1/1 - 0s - loss: 2.6774 - val_loss: 2.6866 - 144ms/epoch - 144ms/step\n",
      "Epoch 111/2000\n",
      "1/1 - 0s - loss: 2.6664 - val_loss: 2.6805 - 143ms/epoch - 143ms/step\n",
      "Epoch 112/2000\n",
      "1/1 - 0s - loss: 2.6599 - val_loss: 2.6746 - 172ms/epoch - 172ms/step\n",
      "Epoch 113/2000\n",
      "1/1 - 0s - loss: 2.6555 - val_loss: 2.6661 - 150ms/epoch - 150ms/step\n",
      "Epoch 114/2000\n",
      "1/1 - 0s - loss: 2.6452 - val_loss: 2.6603 - 144ms/epoch - 144ms/step\n",
      "Epoch 115/2000\n",
      "1/1 - 0s - loss: 2.6389 - val_loss: 2.6546 - 142ms/epoch - 142ms/step\n",
      "Epoch 116/2000\n",
      "1/1 - 0s - loss: 2.6343 - val_loss: 2.6461 - 158ms/epoch - 158ms/step\n",
      "Epoch 117/2000\n",
      "1/1 - 0s - loss: 2.6247 - val_loss: 2.6412 - 217ms/epoch - 217ms/step\n",
      "Epoch 118/2000\n",
      "1/1 - 0s - loss: 2.6193 - val_loss: 2.6353 - 203ms/epoch - 203ms/step\n",
      "Epoch 119/2000\n",
      "1/1 - 0s - loss: 2.6144 - val_loss: 2.6278 - 197ms/epoch - 197ms/step\n",
      "Epoch 120/2000\n",
      "1/1 - 0s - loss: 2.6057 - val_loss: 2.6222 - 137ms/epoch - 137ms/step\n",
      "Epoch 121/2000\n",
      "1/1 - 0s - loss: 2.5999 - val_loss: 2.6166 - 182ms/epoch - 182ms/step\n",
      "Epoch 122/2000\n",
      "1/1 - 0s - loss: 2.5955 - val_loss: 2.6104 - 175ms/epoch - 175ms/step\n",
      "Epoch 123/2000\n",
      "1/1 - 0s - loss: 2.5881 - val_loss: 2.6036 - 182ms/epoch - 182ms/step\n",
      "Epoch 124/2000\n",
      "1/1 - 0s - loss: 2.5817 - val_loss: 2.5982 - 177ms/epoch - 177ms/step\n",
      "Epoch 125/2000\n",
      "1/1 - 0s - loss: 2.5772 - val_loss: 2.5923 - 164ms/epoch - 164ms/step\n",
      "Epoch 126/2000\n",
      "1/1 - 0s - loss: 2.5703 - val_loss: 2.5866 - 183ms/epoch - 183ms/step\n",
      "Epoch 127/2000\n",
      "1/1 - 0s - loss: 2.5646 - val_loss: 2.5813 - 159ms/epoch - 159ms/step\n",
      "Epoch 128/2000\n",
      "1/1 - 0s - loss: 2.5601 - val_loss: 2.5756 - 166ms/epoch - 166ms/step\n",
      "Epoch 129/2000\n",
      "1/1 - 0s - loss: 2.5536 - val_loss: 2.5707 - 162ms/epoch - 162ms/step\n",
      "Epoch 130/2000\n",
      "1/1 - 0s - loss: 2.5485 - val_loss: 2.5650 - 164ms/epoch - 164ms/step\n",
      "Epoch 131/2000\n",
      "1/1 - 0s - loss: 2.5436 - val_loss: 2.5595 - 164ms/epoch - 164ms/step\n",
      "Epoch 132/2000\n",
      "1/1 - 0s - loss: 2.5376 - val_loss: 2.5555 - 164ms/epoch - 164ms/step\n",
      "Epoch 133/2000\n",
      "1/1 - 0s - loss: 2.5331 - val_loss: 2.5497 - 165ms/epoch - 165ms/step\n",
      "Epoch 134/2000\n",
      "1/1 - 0s - loss: 2.5278 - val_loss: 2.5447 - 171ms/epoch - 171ms/step\n",
      "Epoch 135/2000\n",
      "1/1 - 0s - loss: 2.5225 - val_loss: 2.5409 - 171ms/epoch - 171ms/step\n",
      "Epoch 136/2000\n",
      "1/1 - 0s - loss: 2.5181 - val_loss: 2.5354 - 168ms/epoch - 168ms/step\n",
      "Epoch 137/2000\n",
      "1/1 - 0s - loss: 2.5129 - val_loss: 2.5310 - 199ms/epoch - 199ms/step\n",
      "Epoch 138/2000\n",
      "1/1 - 0s - loss: 2.5084 - val_loss: 2.5265 - 162ms/epoch - 162ms/step\n",
      "Epoch 139/2000\n",
      "1/1 - 0s - loss: 2.5036 - val_loss: 2.5219 - 159ms/epoch - 159ms/step\n",
      "Epoch 140/2000\n",
      "1/1 - 0s - loss: 2.4989 - val_loss: 2.5174 - 160ms/epoch - 160ms/step\n",
      "Epoch 141/2000\n",
      "1/1 - 0s - loss: 2.4946 - val_loss: 2.5129 - 165ms/epoch - 165ms/step\n",
      "Epoch 142/2000\n",
      "1/1 - 0s - loss: 2.4898 - val_loss: 2.5088 - 167ms/epoch - 167ms/step\n",
      "Epoch 143/2000\n",
      "1/1 - 0s - loss: 2.4855 - val_loss: 2.5039 - 159ms/epoch - 159ms/step\n",
      "Epoch 144/2000\n",
      "1/1 - 0s - loss: 2.4809 - val_loss: 2.4994 - 152ms/epoch - 152ms/step\n",
      "Epoch 145/2000\n",
      "1/1 - 0s - loss: 2.4762 - val_loss: 2.4954 - 148ms/epoch - 148ms/step\n",
      "Epoch 146/2000\n",
      "1/1 - 0s - loss: 2.4720 - val_loss: 2.4906 - 145ms/epoch - 145ms/step\n",
      "Epoch 147/2000\n",
      "1/1 - 0s - loss: 2.4675 - val_loss: 2.4861 - 142ms/epoch - 142ms/step\n",
      "Epoch 148/2000\n",
      "1/1 - 0s - loss: 2.4631 - val_loss: 2.4819 - 141ms/epoch - 141ms/step\n",
      "Epoch 149/2000\n",
      "1/1 - 0s - loss: 2.4588 - val_loss: 2.4770 - 143ms/epoch - 143ms/step\n",
      "Epoch 150/2000\n",
      "1/1 - 0s - loss: 2.4544 - val_loss: 2.4726 - 140ms/epoch - 140ms/step\n",
      "Epoch 151/2000\n",
      "1/1 - 0s - loss: 2.4502 - val_loss: 2.4683 - 145ms/epoch - 145ms/step\n",
      "Epoch 152/2000\n",
      "1/1 - 0s - loss: 2.4460 - val_loss: 2.4639 - 180ms/epoch - 180ms/step\n",
      "Epoch 153/2000\n",
      "1/1 - 0s - loss: 2.4418 - val_loss: 2.4598 - 141ms/epoch - 141ms/step\n",
      "Epoch 154/2000\n",
      "1/1 - 0s - loss: 2.4379 - val_loss: 2.4558 - 136ms/epoch - 136ms/step\n",
      "Epoch 155/2000\n",
      "1/1 - 0s - loss: 2.4338 - val_loss: 2.4516 - 136ms/epoch - 136ms/step\n",
      "Epoch 156/2000\n",
      "1/1 - 0s - loss: 2.4297 - val_loss: 2.4476 - 137ms/epoch - 137ms/step\n",
      "Epoch 157/2000\n",
      "1/1 - 0s - loss: 2.4258 - val_loss: 2.4437 - 140ms/epoch - 140ms/step\n",
      "Epoch 158/2000\n",
      "1/1 - 0s - loss: 2.4219 - val_loss: 2.4397 - 143ms/epoch - 143ms/step\n",
      "Epoch 159/2000\n",
      "1/1 - 0s - loss: 2.4180 - val_loss: 2.4358 - 147ms/epoch - 147ms/step\n",
      "Epoch 160/2000\n",
      "1/1 - 0s - loss: 2.4142 - val_loss: 2.4321 - 151ms/epoch - 151ms/step\n",
      "Epoch 161/2000\n",
      "1/1 - 0s - loss: 2.4105 - val_loss: 2.4281 - 146ms/epoch - 146ms/step\n",
      "Epoch 162/2000\n",
      "1/1 - 0s - loss: 2.4068 - val_loss: 2.4243 - 124ms/epoch - 124ms/step\n",
      "Epoch 163/2000\n",
      "1/1 - 0s - loss: 2.4030 - val_loss: 2.4205 - 166ms/epoch - 166ms/step\n",
      "Epoch 164/2000\n",
      "1/1 - 0s - loss: 2.3993 - val_loss: 2.4167 - 121ms/epoch - 121ms/step\n",
      "Epoch 165/2000\n",
      "1/1 - 0s - loss: 2.3957 - val_loss: 2.4132 - 148ms/epoch - 148ms/step\n",
      "Epoch 166/2000\n",
      "1/1 - 0s - loss: 2.3921 - val_loss: 2.4093 - 153ms/epoch - 153ms/step\n",
      "Epoch 167/2000\n",
      "1/1 - 0s - loss: 2.3885 - val_loss: 2.4057 - 152ms/epoch - 152ms/step\n",
      "Epoch 168/2000\n",
      "1/1 - 0s - loss: 2.3848 - val_loss: 2.4020 - 149ms/epoch - 149ms/step\n",
      "Epoch 169/2000\n",
      "1/1 - 0s - loss: 2.3812 - val_loss: 2.3984 - 147ms/epoch - 147ms/step\n",
      "Epoch 170/2000\n",
      "1/1 - 0s - loss: 2.3777 - val_loss: 2.3949 - 141ms/epoch - 141ms/step\n",
      "Epoch 171/2000\n",
      "1/1 - 0s - loss: 2.3741 - val_loss: 2.3912 - 147ms/epoch - 147ms/step\n",
      "Epoch 172/2000\n",
      "1/1 - 0s - loss: 2.3706 - val_loss: 2.3879 - 144ms/epoch - 144ms/step\n",
      "Epoch 173/2000\n",
      "1/1 - 0s - loss: 2.3672 - val_loss: 2.3842 - 169ms/epoch - 169ms/step\n",
      "Epoch 174/2000\n",
      "1/1 - 0s - loss: 2.3638 - val_loss: 2.3812 - 170ms/epoch - 170ms/step\n",
      "Epoch 175/2000\n",
      "1/1 - 0s - loss: 2.3605 - val_loss: 2.3776 - 143ms/epoch - 143ms/step\n",
      "Epoch 176/2000\n",
      "1/1 - 0s - loss: 2.3573 - val_loss: 2.3752 - 176ms/epoch - 176ms/step\n",
      "Epoch 177/2000\n",
      "1/1 - 0s - loss: 2.3544 - val_loss: 2.3721 - 171ms/epoch - 171ms/step\n",
      "Epoch 178/2000\n",
      "1/1 - 0s - loss: 2.3519 - val_loss: 2.3707 - 147ms/epoch - 147ms/step\n",
      "Epoch 179/2000\n",
      "1/1 - 0s - loss: 2.3496 - val_loss: 2.3673 - 178ms/epoch - 178ms/step\n",
      "Epoch 180/2000\n",
      "1/1 - 0s - loss: 2.3473 - val_loss: 2.3641 - 188ms/epoch - 188ms/step\n",
      "Epoch 181/2000\n",
      "1/1 - 0s - loss: 2.3431 - val_loss: 2.3584 - 167ms/epoch - 167ms/step\n",
      "Epoch 182/2000\n",
      "1/1 - 0s - loss: 2.3382 - val_loss: 2.3546 - 167ms/epoch - 167ms/step\n",
      "Epoch 183/2000\n",
      "1/1 - 0s - loss: 2.3343 - val_loss: 2.3527 - 196ms/epoch - 196ms/step\n",
      "Epoch 184/2000\n",
      "1/1 - 0s - loss: 2.3320 - val_loss: 2.3501 - 187ms/epoch - 187ms/step\n",
      "Epoch 185/2000\n",
      "1/1 - 0s - loss: 2.3301 - val_loss: 2.3477 - 201ms/epoch - 201ms/step\n",
      "Epoch 186/2000\n",
      "1/1 - 0s - loss: 2.3269 - val_loss: 2.3430 - 171ms/epoch - 171ms/step\n",
      "Epoch 187/2000\n",
      "1/1 - 0s - loss: 2.3229 - val_loss: 2.3394 - 196ms/epoch - 196ms/step\n",
      "Epoch 188/2000\n",
      "1/1 - 0s - loss: 2.3190 - val_loss: 2.3366 - 197ms/epoch - 197ms/step\n",
      "Epoch 189/2000\n",
      "1/1 - 0s - loss: 2.3162 - val_loss: 2.3340 - 195ms/epoch - 195ms/step\n",
      "Epoch 190/2000\n",
      "1/1 - 0s - loss: 2.3140 - val_loss: 2.3322 - 191ms/epoch - 191ms/step\n",
      "Epoch 191/2000\n",
      "1/1 - 0s - loss: 2.3116 - val_loss: 2.3285 - 162ms/epoch - 162ms/step\n",
      "Epoch 192/2000\n",
      "1/1 - 0s - loss: 2.3087 - val_loss: 2.3254 - 189ms/epoch - 189ms/step\n",
      "Epoch 193/2000\n",
      "1/1 - 0s - loss: 2.3050 - val_loss: 2.3216 - 160ms/epoch - 160ms/step\n",
      "Epoch 194/2000\n",
      "1/1 - 0s - loss: 2.3017 - val_loss: 2.3188 - 162ms/epoch - 162ms/step\n",
      "Epoch 195/2000\n",
      "1/1 - 0s - loss: 2.2990 - val_loss: 2.3169 - 163ms/epoch - 163ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 196/2000\n",
      "1/1 - 0s - loss: 2.2968 - val_loss: 2.3140 - 190ms/epoch - 190ms/step\n",
      "Epoch 197/2000\n",
      "1/1 - 0s - loss: 2.2945 - val_loss: 2.3117 - 165ms/epoch - 165ms/step\n",
      "Epoch 198/2000\n",
      "1/1 - 0s - loss: 2.2916 - val_loss: 2.3082 - 154ms/epoch - 154ms/step\n",
      "Epoch 199/2000\n",
      "1/1 - 0s - loss: 2.2886 - val_loss: 2.3055 - 138ms/epoch - 138ms/step\n",
      "Epoch 200/2000\n",
      "1/1 - 0s - loss: 2.2856 - val_loss: 2.3028 - 163ms/epoch - 163ms/step\n",
      "Epoch 201/2000\n",
      "1/1 - 0s - loss: 2.2829 - val_loss: 2.3003 - 129ms/epoch - 129ms/step\n",
      "Epoch 202/2000\n",
      "1/1 - 0s - loss: 2.2806 - val_loss: 2.2985 - 135ms/epoch - 135ms/step\n",
      "Epoch 203/2000\n",
      "1/1 - 0s - loss: 2.2784 - val_loss: 2.2963 - 132ms/epoch - 132ms/step\n",
      "Epoch 204/2000\n",
      "1/1 - 0s - loss: 2.2767 - val_loss: 2.2959 - 133ms/epoch - 133ms/step\n",
      "Epoch 205/2000\n",
      "1/1 - 0s - loss: 2.2754 - val_loss: 2.2950 - 123ms/epoch - 123ms/step\n",
      "Epoch 206/2000\n",
      "1/1 - 0s - loss: 2.2754 - val_loss: 2.2945 - 125ms/epoch - 125ms/step\n",
      "Epoch 207/2000\n",
      "1/1 - 0s - loss: 2.2737 - val_loss: 2.2914 - 125ms/epoch - 125ms/step\n",
      "Epoch 208/2000\n",
      "1/1 - 0s - loss: 2.2716 - val_loss: 2.2857 - 127ms/epoch - 127ms/step\n",
      "Epoch 209/2000\n",
      "1/1 - 0s - loss: 2.2653 - val_loss: 2.2831 - 139ms/epoch - 139ms/step\n",
      "Epoch 210/2000\n",
      "1/1 - 0s - loss: 2.2625 - val_loss: 2.2827 - 157ms/epoch - 157ms/step\n",
      "Epoch 211/2000\n",
      "1/1 - 0s - loss: 2.2627 - val_loss: 2.2803 - 148ms/epoch - 148ms/step\n",
      "Epoch 212/2000\n",
      "1/1 - 0s - loss: 2.2596 - val_loss: 2.2766 - 124ms/epoch - 124ms/step\n",
      "Epoch 213/2000\n",
      "1/1 - 0s - loss: 2.2562 - val_loss: 2.2743 - 130ms/epoch - 130ms/step\n",
      "Epoch 214/2000\n",
      "1/1 - 0s - loss: 2.2540 - val_loss: 2.2735 - 133ms/epoch - 133ms/step\n",
      "Epoch 215/2000\n",
      "1/1 - 0s - loss: 2.2527 - val_loss: 2.2713 - 160ms/epoch - 160ms/step\n",
      "Epoch 216/2000\n",
      "1/1 - 0s - loss: 2.2508 - val_loss: 2.2681 - 155ms/epoch - 155ms/step\n",
      "Epoch 217/2000\n",
      "1/1 - 0s - loss: 2.2475 - val_loss: 2.2660 - 153ms/epoch - 153ms/step\n",
      "Epoch 218/2000\n",
      "1/1 - 0s - loss: 2.2452 - val_loss: 2.2650 - 126ms/epoch - 126ms/step\n",
      "Epoch 219/2000\n",
      "1/1 - 0s - loss: 2.2444 - val_loss: 2.2626 - 158ms/epoch - 158ms/step\n",
      "Epoch 220/2000\n",
      "1/1 - 0s - loss: 2.2418 - val_loss: 2.2598 - 123ms/epoch - 123ms/step\n",
      "Epoch 221/2000\n",
      "1/1 - 0s - loss: 2.2391 - val_loss: 2.2586 - 126ms/epoch - 126ms/step\n",
      "Epoch 222/2000\n",
      "1/1 - 0s - loss: 2.2377 - val_loss: 2.2568 - 118ms/epoch - 118ms/step\n",
      "Epoch 223/2000\n",
      "1/1 - 0s - loss: 2.2358 - val_loss: 2.2545 - 122ms/epoch - 122ms/step\n",
      "Epoch 224/2000\n",
      "1/1 - 0s - loss: 2.2337 - val_loss: 2.2527 - 150ms/epoch - 150ms/step\n",
      "Epoch 225/2000\n",
      "1/1 - 0s - loss: 2.2317 - val_loss: 2.2508 - 153ms/epoch - 153ms/step\n",
      "Epoch 226/2000\n",
      "1/1 - 0s - loss: 2.2297 - val_loss: 2.2490 - 163ms/epoch - 163ms/step\n",
      "Epoch 227/2000\n",
      "1/1 - 0s - loss: 2.2282 - val_loss: 2.2474 - 157ms/epoch - 157ms/step\n",
      "Epoch 228/2000\n",
      "1/1 - 0s - loss: 2.2262 - val_loss: 2.2450 - 115ms/epoch - 115ms/step\n",
      "Epoch 229/2000\n",
      "1/1 - 0s - loss: 2.2240 - val_loss: 2.2433 - 115ms/epoch - 115ms/step\n",
      "Epoch 230/2000\n",
      "1/1 - 0s - loss: 2.2224 - val_loss: 2.2419 - 124ms/epoch - 124ms/step\n",
      "Epoch 231/2000\n",
      "1/1 - 0s - loss: 2.2207 - val_loss: 2.2399 - 133ms/epoch - 133ms/step\n",
      "Epoch 232/2000\n",
      "1/1 - 0s - loss: 2.2188 - val_loss: 2.2382 - 109ms/epoch - 109ms/step\n",
      "Epoch 233/2000\n",
      "1/1 - 0s - loss: 2.2171 - val_loss: 2.2364 - 110ms/epoch - 110ms/step\n",
      "Epoch 234/2000\n",
      "1/1 - 0s - loss: 2.2151 - val_loss: 2.2347 - 146ms/epoch - 146ms/step\n",
      "Epoch 235/2000\n",
      "1/1 - 0s - loss: 2.2135 - val_loss: 2.2333 - 150ms/epoch - 150ms/step\n",
      "Epoch 236/2000\n",
      "1/1 - 0s - loss: 2.2119 - val_loss: 2.2315 - 151ms/epoch - 151ms/step\n",
      "Epoch 237/2000\n",
      "1/1 - 0s - loss: 2.2102 - val_loss: 2.2298 - 142ms/epoch - 142ms/step\n",
      "Epoch 238/2000\n",
      "1/1 - 0s - loss: 2.2084 - val_loss: 2.2281 - 140ms/epoch - 140ms/step\n",
      "Epoch 239/2000\n",
      "1/1 - 0s - loss: 2.2067 - val_loss: 2.2264 - 142ms/epoch - 142ms/step\n",
      "Epoch 240/2000\n",
      "1/1 - 0s - loss: 2.2050 - val_loss: 2.2250 - 142ms/epoch - 142ms/step\n",
      "Epoch 241/2000\n",
      "1/1 - 0s - loss: 2.2034 - val_loss: 2.2233 - 169ms/epoch - 169ms/step\n",
      "Epoch 242/2000\n",
      "1/1 - 0s - loss: 2.2019 - val_loss: 2.2219 - 140ms/epoch - 140ms/step\n",
      "Epoch 243/2000\n",
      "1/1 - 0s - loss: 2.2003 - val_loss: 2.2203 - 147ms/epoch - 147ms/step\n",
      "Epoch 244/2000\n",
      "1/1 - 0s - loss: 2.1987 - val_loss: 2.2189 - 158ms/epoch - 158ms/step\n",
      "Epoch 245/2000\n",
      "1/1 - 0s - loss: 2.1972 - val_loss: 2.2172 - 195ms/epoch - 195ms/step\n",
      "Epoch 246/2000\n",
      "1/1 - 0s - loss: 2.1956 - val_loss: 2.2159 - 170ms/epoch - 170ms/step\n",
      "Epoch 247/2000\n",
      "1/1 - 0s - loss: 2.1941 - val_loss: 2.2144 - 162ms/epoch - 162ms/step\n",
      "Epoch 248/2000\n",
      "1/1 - 0s - loss: 2.1926 - val_loss: 2.2132 - 165ms/epoch - 165ms/step\n",
      "Epoch 249/2000\n",
      "1/1 - 0s - loss: 2.1912 - val_loss: 2.2119 - 162ms/epoch - 162ms/step\n",
      "Epoch 250/2000\n",
      "1/1 - 0s - loss: 2.1900 - val_loss: 2.2113 - 165ms/epoch - 165ms/step\n",
      "Epoch 251/2000\n",
      "1/1 - 0s - loss: 2.1890 - val_loss: 2.2105 - 171ms/epoch - 171ms/step\n",
      "Epoch 252/2000\n",
      "1/1 - 0s - loss: 2.1886 - val_loss: 2.2114 - 160ms/epoch - 160ms/step\n",
      "Epoch 253/2000\n",
      "1/1 - 0s - loss: 2.1888 - val_loss: 2.2116 - 158ms/epoch - 158ms/step\n",
      "Epoch 254/2000\n",
      "1/1 - 0s - loss: 2.1895 - val_loss: 2.2116 - 156ms/epoch - 156ms/step\n",
      "Epoch 255/2000\n",
      "1/1 - 0s - loss: 2.1887 - val_loss: 2.2068 - 167ms/epoch - 167ms/step\n",
      "Epoch 256/2000\n",
      "1/1 - 0s - loss: 2.1843 - val_loss: 2.2020 - 167ms/epoch - 167ms/step\n",
      "Epoch 257/2000\n",
      "1/1 - 0s - loss: 2.1794 - val_loss: 2.2030 - 159ms/epoch - 159ms/step\n",
      "Epoch 258/2000\n",
      "1/1 - 0s - loss: 2.1800 - val_loss: 2.2030 - 163ms/epoch - 163ms/step\n",
      "Epoch 259/2000\n",
      "1/1 - 0s - loss: 2.1803 - val_loss: 2.1990 - 169ms/epoch - 169ms/step\n",
      "Epoch 260/2000\n",
      "1/1 - 0s - loss: 2.1759 - val_loss: 2.1972 - 187ms/epoch - 187ms/step\n",
      "Epoch 261/2000\n",
      "1/1 - 0s - loss: 2.1739 - val_loss: 2.1979 - 164ms/epoch - 164ms/step\n",
      "Epoch 262/2000\n",
      "1/1 - 0s - loss: 2.1748 - val_loss: 2.1960 - 169ms/epoch - 169ms/step\n",
      "Epoch 263/2000\n",
      "1/1 - 0s - loss: 2.1725 - val_loss: 2.1928 - 166ms/epoch - 166ms/step\n",
      "Epoch 264/2000\n",
      "1/1 - 0s - loss: 2.1695 - val_loss: 2.1924 - 174ms/epoch - 174ms/step\n",
      "Epoch 265/2000\n",
      "1/1 - 0s - loss: 2.1690 - val_loss: 2.1921 - 170ms/epoch - 170ms/step\n",
      "Epoch 266/2000\n",
      "1/1 - 0s - loss: 2.1683 - val_loss: 2.1893 - 168ms/epoch - 168ms/step\n",
      "Epoch 267/2000\n",
      "1/1 - 0s - loss: 2.1657 - val_loss: 2.1881 - 169ms/epoch - 169ms/step\n",
      "Epoch 268/2000\n",
      "1/1 - 0s - loss: 2.1644 - val_loss: 2.1881 - 166ms/epoch - 166ms/step\n",
      "Epoch 269/2000\n",
      "1/1 - 0s - loss: 2.1639 - val_loss: 2.1857 - 155ms/epoch - 155ms/step\n",
      "Epoch 270/2000\n",
      "1/1 - 0s - loss: 2.1618 - val_loss: 2.1843 - 149ms/epoch - 149ms/step\n",
      "Epoch 271/2000\n",
      "1/1 - 0s - loss: 2.1603 - val_loss: 2.1840 - 148ms/epoch - 148ms/step\n",
      "Epoch 272/2000\n",
      "1/1 - 0s - loss: 2.1597 - val_loss: 2.1823 - 149ms/epoch - 149ms/step\n",
      "Epoch 273/2000\n",
      "1/1 - 0s - loss: 2.1582 - val_loss: 2.1807 - 149ms/epoch - 149ms/step\n",
      "Epoch 274/2000\n",
      "1/1 - 0s - loss: 2.1564 - val_loss: 2.1800 - 147ms/epoch - 147ms/step\n",
      "Epoch 275/2000\n",
      "1/1 - 0s - loss: 2.1555 - val_loss: 2.1788 - 152ms/epoch - 152ms/step\n",
      "Epoch 276/2000\n",
      "1/1 - 0s - loss: 2.1546 - val_loss: 2.1773 - 156ms/epoch - 156ms/step\n",
      "Epoch 277/2000\n",
      "1/1 - 0s - loss: 2.1529 - val_loss: 2.1761 - 156ms/epoch - 156ms/step\n",
      "Epoch 278/2000\n",
      "1/1 - 0s - loss: 2.1516 - val_loss: 2.1752 - 160ms/epoch - 160ms/step\n",
      "Epoch 279/2000\n",
      "1/1 - 0s - loss: 2.1508 - val_loss: 2.1741 - 125ms/epoch - 125ms/step\n",
      "Epoch 280/2000\n",
      "1/1 - 0s - loss: 2.1494 - val_loss: 2.1726 - 124ms/epoch - 124ms/step\n",
      "Epoch 281/2000\n",
      "1/1 - 0s - loss: 2.1480 - val_loss: 2.1716 - 128ms/epoch - 128ms/step\n",
      "Epoch 282/2000\n",
      "1/1 - 0s - loss: 2.1470 - val_loss: 2.1708 - 124ms/epoch - 124ms/step\n",
      "Epoch 283/2000\n",
      "1/1 - 0s - loss: 2.1460 - val_loss: 2.1694 - 127ms/epoch - 127ms/step\n",
      "Epoch 284/2000\n",
      "1/1 - 0s - loss: 2.1446 - val_loss: 2.1682 - 129ms/epoch - 129ms/step\n",
      "Epoch 285/2000\n",
      "1/1 - 0s - loss: 2.1434 - val_loss: 2.1674 - 124ms/epoch - 124ms/step\n",
      "Epoch 286/2000\n",
      "1/1 - 0s - loss: 2.1424 - val_loss: 2.1661 - 125ms/epoch - 125ms/step\n",
      "Epoch 287/2000\n",
      "1/1 - 0s - loss: 2.1413 - val_loss: 2.1650 - 126ms/epoch - 126ms/step\n",
      "Epoch 288/2000\n",
      "1/1 - 0s - loss: 2.1401 - val_loss: 2.1639 - 125ms/epoch - 125ms/step\n",
      "Epoch 289/2000\n",
      "1/1 - 0s - loss: 2.1389 - val_loss: 2.1629 - 126ms/epoch - 126ms/step\n",
      "Epoch 290/2000\n",
      "1/1 - 0s - loss: 2.1379 - val_loss: 2.1619 - 128ms/epoch - 128ms/step\n",
      "Epoch 291/2000\n",
      "1/1 - 0s - loss: 2.1369 - val_loss: 2.1607 - 122ms/epoch - 122ms/step\n",
      "Epoch 292/2000\n",
      "1/1 - 0s - loss: 2.1357 - val_loss: 2.1596 - 126ms/epoch - 126ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 293/2000\n",
      "1/1 - 0s - loss: 2.1346 - val_loss: 2.1587 - 125ms/epoch - 125ms/step\n",
      "Epoch 294/2000\n",
      "1/1 - 0s - loss: 2.1336 - val_loss: 2.1576 - 120ms/epoch - 120ms/step\n",
      "Epoch 295/2000\n",
      "1/1 - 0s - loss: 2.1326 - val_loss: 2.1566 - 122ms/epoch - 122ms/step\n",
      "Epoch 296/2000\n",
      "1/1 - 0s - loss: 2.1314 - val_loss: 2.1555 - 121ms/epoch - 121ms/step\n",
      "Epoch 297/2000\n",
      "1/1 - 0s - loss: 2.1304 - val_loss: 2.1545 - 119ms/epoch - 119ms/step\n",
      "Epoch 298/2000\n",
      "1/1 - 0s - loss: 2.1294 - val_loss: 2.1537 - 154ms/epoch - 154ms/step\n",
      "Epoch 299/2000\n",
      "1/1 - 0s - loss: 2.1284 - val_loss: 2.1526 - 115ms/epoch - 115ms/step\n",
      "Epoch 300/2000\n",
      "1/1 - 0s - loss: 2.1274 - val_loss: 2.1517 - 118ms/epoch - 118ms/step\n",
      "Epoch 301/2000\n",
      "1/1 - 0s - loss: 2.1263 - val_loss: 2.1506 - 114ms/epoch - 114ms/step\n",
      "Epoch 302/2000\n",
      "1/1 - 0s - loss: 2.1253 - val_loss: 2.1497 - 112ms/epoch - 112ms/step\n",
      "Epoch 303/2000\n",
      "1/1 - 0s - loss: 2.1243 - val_loss: 2.1488 - 168ms/epoch - 168ms/step\n",
      "Epoch 304/2000\n",
      "1/1 - 0s - loss: 2.1233 - val_loss: 2.1478 - 152ms/epoch - 152ms/step\n",
      "Epoch 305/2000\n",
      "1/1 - 0s - loss: 2.1223 - val_loss: 2.1469 - 154ms/epoch - 154ms/step\n",
      "Epoch 306/2000\n",
      "1/1 - 0s - loss: 2.1213 - val_loss: 2.1459 - 118ms/epoch - 118ms/step\n",
      "Epoch 307/2000\n",
      "1/1 - 0s - loss: 2.1203 - val_loss: 2.1450 - 122ms/epoch - 122ms/step\n",
      "Epoch 308/2000\n",
      "1/1 - 0s - loss: 2.1194 - val_loss: 2.1441 - 121ms/epoch - 121ms/step\n",
      "Epoch 309/2000\n",
      "1/1 - 0s - loss: 2.1184 - val_loss: 2.1432 - 123ms/epoch - 123ms/step\n",
      "Epoch 310/2000\n",
      "1/1 - 0s - loss: 2.1175 - val_loss: 2.1424 - 150ms/epoch - 150ms/step\n",
      "Epoch 311/2000\n",
      "1/1 - 0s - loss: 2.1165 - val_loss: 2.1414 - 144ms/epoch - 144ms/step\n",
      "Epoch 312/2000\n",
      "1/1 - 0s - loss: 2.1156 - val_loss: 2.1406 - 122ms/epoch - 122ms/step\n",
      "Epoch 313/2000\n",
      "1/1 - 0s - loss: 2.1147 - val_loss: 2.1397 - 125ms/epoch - 125ms/step\n",
      "Epoch 314/2000\n",
      "1/1 - 0s - loss: 2.1137 - val_loss: 2.1389 - 151ms/epoch - 151ms/step\n",
      "Epoch 315/2000\n",
      "1/1 - 0s - loss: 2.1128 - val_loss: 2.1379 - 127ms/epoch - 127ms/step\n",
      "Epoch 316/2000\n",
      "1/1 - 0s - loss: 2.1119 - val_loss: 2.1371 - 119ms/epoch - 119ms/step\n",
      "Epoch 317/2000\n",
      "1/1 - 0s - loss: 2.1109 - val_loss: 2.1362 - 124ms/epoch - 124ms/step\n",
      "Epoch 318/2000\n",
      "1/1 - 0s - loss: 2.1100 - val_loss: 2.1354 - 129ms/epoch - 129ms/step\n",
      "Epoch 319/2000\n",
      "1/1 - 0s - loss: 2.1091 - val_loss: 2.1346 - 128ms/epoch - 128ms/step\n",
      "Epoch 320/2000\n",
      "1/1 - 0s - loss: 2.1082 - val_loss: 2.1337 - 137ms/epoch - 137ms/step\n",
      "Epoch 321/2000\n",
      "1/1 - 0s - loss: 2.1074 - val_loss: 2.1330 - 158ms/epoch - 158ms/step\n",
      "Epoch 322/2000\n",
      "1/1 - 0s - loss: 2.1065 - val_loss: 2.1321 - 144ms/epoch - 144ms/step\n",
      "Epoch 323/2000\n",
      "1/1 - 0s - loss: 2.1056 - val_loss: 2.1315 - 146ms/epoch - 146ms/step\n",
      "Epoch 324/2000\n",
      "1/1 - 0s - loss: 2.1048 - val_loss: 2.1306 - 171ms/epoch - 171ms/step\n",
      "Epoch 325/2000\n",
      "1/1 - 0s - loss: 2.1040 - val_loss: 2.1301 - 157ms/epoch - 157ms/step\n",
      "Epoch 326/2000\n",
      "1/1 - 0s - loss: 2.1032 - val_loss: 2.1291 - 138ms/epoch - 138ms/step\n",
      "Epoch 327/2000\n",
      "1/1 - 0s - loss: 2.1024 - val_loss: 2.1289 - 124ms/epoch - 124ms/step\n",
      "Epoch 328/2000\n",
      "1/1 - 0s - loss: 2.1017 - val_loss: 2.1279 - 128ms/epoch - 128ms/step\n",
      "Epoch 329/2000\n",
      "1/1 - 0s - loss: 2.1011 - val_loss: 2.1281 - 128ms/epoch - 128ms/step\n",
      "Epoch 330/2000\n",
      "1/1 - 0s - loss: 2.1007 - val_loss: 2.1269 - 133ms/epoch - 133ms/step\n",
      "Epoch 331/2000\n",
      "1/1 - 0s - loss: 2.1001 - val_loss: 2.1267 - 125ms/epoch - 125ms/step\n",
      "Epoch 332/2000\n",
      "1/1 - 0s - loss: 2.0993 - val_loss: 2.1250 - 123ms/epoch - 123ms/step\n",
      "Epoch 333/2000\n",
      "1/1 - 0s - loss: 2.0979 - val_loss: 2.1240 - 163ms/epoch - 163ms/step\n",
      "Epoch 334/2000\n",
      "1/1 - 0s - loss: 2.0966 - val_loss: 2.1230 - 125ms/epoch - 125ms/step\n",
      "Epoch 335/2000\n",
      "1/1 - 0s - loss: 2.0956 - val_loss: 2.1224 - 125ms/epoch - 125ms/step\n",
      "Epoch 336/2000\n",
      "1/1 - 0s - loss: 2.0951 - val_loss: 2.1225 - 124ms/epoch - 124ms/step\n",
      "Epoch 337/2000\n",
      "1/1 - 0s - loss: 2.0947 - val_loss: 2.1217 - 132ms/epoch - 132ms/step\n",
      "Epoch 338/2000\n",
      "1/1 - 0s - loss: 2.0944 - val_loss: 2.1223 - 122ms/epoch - 122ms/step\n",
      "Epoch 339/2000\n",
      "1/1 - 0s - loss: 2.0941 - val_loss: 2.1203 - 138ms/epoch - 138ms/step\n",
      "Epoch 340/2000\n",
      "1/1 - 0s - loss: 2.0929 - val_loss: 2.1192 - 150ms/epoch - 150ms/step\n",
      "Epoch 341/2000\n",
      "1/1 - 0s - loss: 2.0913 - val_loss: 2.1182 - 148ms/epoch - 148ms/step\n",
      "Epoch 342/2000\n",
      "1/1 - 0s - loss: 2.0902 - val_loss: 2.1176 - 152ms/epoch - 152ms/step\n",
      "Epoch 343/2000\n",
      "1/1 - 0s - loss: 2.0898 - val_loss: 2.1180 - 142ms/epoch - 142ms/step\n",
      "Epoch 344/2000\n",
      "1/1 - 0s - loss: 2.0896 - val_loss: 2.1167 - 139ms/epoch - 139ms/step\n",
      "Epoch 345/2000\n",
      "1/1 - 0s - loss: 2.0889 - val_loss: 2.1159 - 142ms/epoch - 142ms/step\n",
      "Epoch 346/2000\n",
      "1/1 - 0s - loss: 2.0875 - val_loss: 2.1146 - 160ms/epoch - 160ms/step\n",
      "Epoch 347/2000\n",
      "1/1 - 0s - loss: 2.0864 - val_loss: 2.1141 - 145ms/epoch - 145ms/step\n",
      "Epoch 348/2000\n",
      "1/1 - 0s - loss: 2.0859 - val_loss: 2.1140 - 140ms/epoch - 140ms/step\n",
      "Epoch 349/2000\n",
      "1/1 - 0s - loss: 2.0854 - val_loss: 2.1130 - 146ms/epoch - 146ms/step\n",
      "Epoch 350/2000\n",
      "1/1 - 0s - loss: 2.0848 - val_loss: 2.1128 - 137ms/epoch - 137ms/step\n",
      "Epoch 351/2000\n",
      "1/1 - 0s - loss: 2.0840 - val_loss: 2.1114 - 150ms/epoch - 150ms/step\n",
      "Epoch 352/2000\n",
      "1/1 - 0s - loss: 2.0828 - val_loss: 2.1107 - 142ms/epoch - 142ms/step\n",
      "Epoch 353/2000\n",
      "1/1 - 0s - loss: 2.0822 - val_loss: 2.1106 - 125ms/epoch - 125ms/step\n",
      "Epoch 354/2000\n",
      "1/1 - 0s - loss: 2.0817 - val_loss: 2.1096 - 139ms/epoch - 139ms/step\n",
      "Epoch 355/2000\n",
      "1/1 - 0s - loss: 2.0811 - val_loss: 2.1094 - 131ms/epoch - 131ms/step\n",
      "Epoch 356/2000\n",
      "1/1 - 0s - loss: 2.0804 - val_loss: 2.1083 - 139ms/epoch - 139ms/step\n",
      "Epoch 357/2000\n",
      "1/1 - 0s - loss: 2.0795 - val_loss: 2.1076 - 146ms/epoch - 146ms/step\n",
      "Epoch 358/2000\n",
      "1/1 - 0s - loss: 2.0786 - val_loss: 2.1072 - 133ms/epoch - 133ms/step\n",
      "Epoch 359/2000\n",
      "1/1 - 0s - loss: 2.0780 - val_loss: 2.1066 - 152ms/epoch - 152ms/step\n",
      "Epoch 360/2000\n",
      "1/1 - 0s - loss: 2.0775 - val_loss: 2.1064 - 149ms/epoch - 149ms/step\n",
      "Epoch 361/2000\n",
      "1/1 - 0s - loss: 2.0769 - val_loss: 2.1053 - 152ms/epoch - 152ms/step\n",
      "Epoch 362/2000\n",
      "1/1 - 0s - loss: 2.0761 - val_loss: 2.1048 - 148ms/epoch - 148ms/step\n",
      "Epoch 363/2000\n",
      "1/1 - 0s - loss: 2.0753 - val_loss: 2.1041 - 181ms/epoch - 181ms/step\n",
      "Epoch 364/2000\n",
      "1/1 - 0s - loss: 2.0746 - val_loss: 2.1034 - 144ms/epoch - 144ms/step\n",
      "Epoch 365/2000\n",
      "1/1 - 0s - loss: 2.0740 - val_loss: 2.1032 - 146ms/epoch - 146ms/step\n",
      "Epoch 366/2000\n",
      "1/1 - 0s - loss: 2.0735 - val_loss: 2.1024 - 151ms/epoch - 151ms/step\n",
      "Epoch 367/2000\n",
      "1/1 - 0s - loss: 2.0729 - val_loss: 2.1020 - 143ms/epoch - 143ms/step\n",
      "Epoch 368/2000\n",
      "1/1 - 0s - loss: 2.0722 - val_loss: 2.1012 - 139ms/epoch - 139ms/step\n",
      "Epoch 369/2000\n",
      "1/1 - 0s - loss: 2.0714 - val_loss: 2.1007 - 143ms/epoch - 143ms/step\n",
      "Epoch 370/2000\n",
      "1/1 - 0s - loss: 2.0708 - val_loss: 2.1001 - 146ms/epoch - 146ms/step\n",
      "Epoch 371/2000\n",
      "1/1 - 0s - loss: 2.0702 - val_loss: 2.0994 - 146ms/epoch - 146ms/step\n",
      "Epoch 372/2000\n",
      "1/1 - 0s - loss: 2.0696 - val_loss: 2.0992 - 131ms/epoch - 131ms/step\n",
      "Epoch 373/2000\n",
      "1/1 - 0s - loss: 2.0691 - val_loss: 2.0984 - 141ms/epoch - 141ms/step\n",
      "Epoch 374/2000\n",
      "1/1 - 0s - loss: 2.0685 - val_loss: 2.0981 - 153ms/epoch - 153ms/step\n",
      "Epoch 375/2000\n",
      "1/1 - 0s - loss: 2.0678 - val_loss: 2.0974 - 131ms/epoch - 131ms/step\n",
      "Epoch 376/2000\n",
      "1/1 - 0s - loss: 2.0672 - val_loss: 2.0970 - 131ms/epoch - 131ms/step\n",
      "Epoch 377/2000\n",
      "1/1 - 0s - loss: 2.0665 - val_loss: 2.0963 - 132ms/epoch - 132ms/step\n",
      "Epoch 378/2000\n",
      "1/1 - 0s - loss: 2.0659 - val_loss: 2.0959 - 133ms/epoch - 133ms/step\n",
      "Epoch 379/2000\n",
      "1/1 - 0s - loss: 2.0653 - val_loss: 2.0954 - 134ms/epoch - 134ms/step\n",
      "Epoch 380/2000\n",
      "1/1 - 0s - loss: 2.0647 - val_loss: 2.0948 - 135ms/epoch - 135ms/step\n",
      "Epoch 381/2000\n",
      "1/1 - 0s - loss: 2.0642 - val_loss: 2.0944 - 138ms/epoch - 138ms/step\n",
      "Epoch 382/2000\n",
      "1/1 - 0s - loss: 2.0636 - val_loss: 2.0938 - 138ms/epoch - 138ms/step\n",
      "Epoch 383/2000\n",
      "1/1 - 0s - loss: 2.0631 - val_loss: 2.0935 - 126ms/epoch - 126ms/step\n",
      "Epoch 384/2000\n",
      "1/1 - 0s - loss: 2.0625 - val_loss: 2.0929 - 128ms/epoch - 128ms/step\n",
      "Epoch 385/2000\n",
      "1/1 - 0s - loss: 2.0620 - val_loss: 2.0926 - 136ms/epoch - 136ms/step\n",
      "Epoch 386/2000\n",
      "1/1 - 0s - loss: 2.0614 - val_loss: 2.0919 - 130ms/epoch - 130ms/step\n",
      "Epoch 387/2000\n",
      "1/1 - 0s - loss: 2.0608 - val_loss: 2.0917 - 130ms/epoch - 130ms/step\n",
      "Epoch 388/2000\n",
      "1/1 - 0s - loss: 2.0603 - val_loss: 2.0910 - 126ms/epoch - 126ms/step\n",
      "Epoch 389/2000\n",
      "1/1 - 0s - loss: 2.0597 - val_loss: 2.0906 - 130ms/epoch - 130ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 390/2000\n",
      "1/1 - 0s - loss: 2.0591 - val_loss: 2.0901 - 130ms/epoch - 130ms/step\n",
      "Epoch 391/2000\n",
      "1/1 - 0s - loss: 2.0586 - val_loss: 2.0896 - 127ms/epoch - 127ms/step\n",
      "Epoch 392/2000\n",
      "1/1 - 0s - loss: 2.0580 - val_loss: 2.0892 - 140ms/epoch - 140ms/step\n",
      "Epoch 393/2000\n",
      "1/1 - 0s - loss: 2.0575 - val_loss: 2.0887 - 131ms/epoch - 131ms/step\n",
      "Epoch 394/2000\n",
      "1/1 - 0s - loss: 2.0570 - val_loss: 2.0883 - 121ms/epoch - 121ms/step\n",
      "Epoch 395/2000\n",
      "1/1 - 0s - loss: 2.0565 - val_loss: 2.0878 - 126ms/epoch - 126ms/step\n",
      "Epoch 396/2000\n",
      "1/1 - 0s - loss: 2.0560 - val_loss: 2.0875 - 132ms/epoch - 132ms/step\n",
      "Epoch 397/2000\n",
      "1/1 - 0s - loss: 2.0555 - val_loss: 2.0869 - 131ms/epoch - 131ms/step\n",
      "Epoch 398/2000\n",
      "1/1 - 0s - loss: 2.0550 - val_loss: 2.0868 - 128ms/epoch - 128ms/step\n",
      "Epoch 399/2000\n",
      "1/1 - 0s - loss: 2.0545 - val_loss: 2.0862 - 125ms/epoch - 125ms/step\n",
      "Epoch 400/2000\n",
      "1/1 - 0s - loss: 2.0541 - val_loss: 2.0861 - 136ms/epoch - 136ms/step\n",
      "Epoch 401/2000\n",
      "1/1 - 0s - loss: 2.0536 - val_loss: 2.0854 - 151ms/epoch - 151ms/step\n",
      "Epoch 402/2000\n",
      "1/1 - 0s - loss: 2.0532 - val_loss: 2.0856 - 125ms/epoch - 125ms/step\n",
      "Epoch 403/2000\n",
      "1/1 - 0s - loss: 2.0528 - val_loss: 2.0846 - 127ms/epoch - 127ms/step\n",
      "Epoch 404/2000\n",
      "1/1 - 0s - loss: 2.0523 - val_loss: 2.0846 - 125ms/epoch - 125ms/step\n",
      "Epoch 405/2000\n",
      "1/1 - 0s - loss: 2.0518 - val_loss: 2.0837 - 128ms/epoch - 128ms/step\n",
      "Epoch 406/2000\n",
      "1/1 - 0s - loss: 2.0511 - val_loss: 2.0834 - 118ms/epoch - 118ms/step\n",
      "Epoch 407/2000\n",
      "1/1 - 0s - loss: 2.0505 - val_loss: 2.0828 - 119ms/epoch - 119ms/step\n",
      "Epoch 408/2000\n",
      "1/1 - 0s - loss: 2.0500 - val_loss: 2.0825 - 113ms/epoch - 113ms/step\n",
      "Epoch 409/2000\n",
      "1/1 - 0s - loss: 2.0495 - val_loss: 2.0823 - 113ms/epoch - 113ms/step\n",
      "Epoch 410/2000\n",
      "1/1 - 0s - loss: 2.0491 - val_loss: 2.0818 - 111ms/epoch - 111ms/step\n",
      "Epoch 411/2000\n",
      "1/1 - 0s - loss: 2.0487 - val_loss: 2.0819 - 106ms/epoch - 106ms/step\n",
      "Epoch 412/2000\n",
      "1/1 - 0s - loss: 2.0484 - val_loss: 2.0811 - 119ms/epoch - 119ms/step\n",
      "Epoch 413/2000\n",
      "1/1 - 0s - loss: 2.0480 - val_loss: 2.0812 - 122ms/epoch - 122ms/step\n",
      "Epoch 414/2000\n",
      "1/1 - 0s - loss: 2.0475 - val_loss: 2.0803 - 134ms/epoch - 134ms/step\n",
      "Epoch 415/2000\n",
      "1/1 - 0s - loss: 2.0469 - val_loss: 2.0800 - 135ms/epoch - 135ms/step\n",
      "Epoch 416/2000\n",
      "1/1 - 0s - loss: 2.0463 - val_loss: 2.0796 - 141ms/epoch - 141ms/step\n",
      "Epoch 417/2000\n",
      "1/1 - 0s - loss: 2.0458 - val_loss: 2.0791 - 129ms/epoch - 129ms/step\n",
      "Epoch 418/2000\n",
      "1/1 - 0s - loss: 2.0454 - val_loss: 2.0790 - 128ms/epoch - 128ms/step\n",
      "Epoch 419/2000\n",
      "1/1 - 0s - loss: 2.0451 - val_loss: 2.0785 - 140ms/epoch - 140ms/step\n",
      "Epoch 420/2000\n",
      "1/1 - 0s - loss: 2.0447 - val_loss: 2.0785 - 131ms/epoch - 131ms/step\n",
      "Epoch 421/2000\n",
      "1/1 - 0s - loss: 2.0443 - val_loss: 2.0778 - 125ms/epoch - 125ms/step\n",
      "Epoch 422/2000\n",
      "1/1 - 0s - loss: 2.0438 - val_loss: 2.0777 - 119ms/epoch - 119ms/step\n",
      "Epoch 423/2000\n",
      "1/1 - 0s - loss: 2.0433 - val_loss: 2.0771 - 123ms/epoch - 123ms/step\n",
      "Epoch 424/2000\n",
      "1/1 - 0s - loss: 2.0428 - val_loss: 2.0768 - 130ms/epoch - 130ms/step\n",
      "Epoch 425/2000\n",
      "1/1 - 0s - loss: 2.0423 - val_loss: 2.0765 - 137ms/epoch - 137ms/step\n",
      "Epoch 426/2000\n",
      "1/1 - 0s - loss: 2.0419 - val_loss: 2.0761 - 135ms/epoch - 135ms/step\n",
      "Epoch 427/2000\n",
      "1/1 - 0s - loss: 2.0415 - val_loss: 2.0760 - 136ms/epoch - 136ms/step\n",
      "Epoch 428/2000\n",
      "1/1 - 0s - loss: 2.0412 - val_loss: 2.0755 - 137ms/epoch - 137ms/step\n",
      "Epoch 429/2000\n",
      "1/1 - 0s - loss: 2.0408 - val_loss: 2.0754 - 135ms/epoch - 135ms/step\n",
      "Epoch 430/2000\n",
      "1/1 - 0s - loss: 2.0404 - val_loss: 2.0748 - 142ms/epoch - 142ms/step\n",
      "Epoch 431/2000\n",
      "1/1 - 0s - loss: 2.0400 - val_loss: 2.0748 - 136ms/epoch - 136ms/step\n",
      "Epoch 432/2000\n",
      "1/1 - 0s - loss: 2.0396 - val_loss: 2.0741 - 144ms/epoch - 144ms/step\n",
      "Epoch 433/2000\n",
      "1/1 - 0s - loss: 2.0391 - val_loss: 2.0739 - 145ms/epoch - 145ms/step\n",
      "Epoch 434/2000\n",
      "1/1 - 0s - loss: 2.0387 - val_loss: 2.0736 - 129ms/epoch - 129ms/step\n",
      "Epoch 435/2000\n",
      "1/1 - 0s - loss: 2.0383 - val_loss: 2.0732 - 140ms/epoch - 140ms/step\n",
      "Epoch 436/2000\n",
      "1/1 - 0s - loss: 2.0379 - val_loss: 2.0731 - 138ms/epoch - 138ms/step\n",
      "Epoch 437/2000\n",
      "1/1 - 0s - loss: 2.0375 - val_loss: 2.0727 - 153ms/epoch - 153ms/step\n",
      "Epoch 438/2000\n",
      "1/1 - 0s - loss: 2.0372 - val_loss: 2.0726 - 134ms/epoch - 134ms/step\n",
      "Epoch 439/2000\n",
      "1/1 - 0s - loss: 2.0369 - val_loss: 2.0721 - 138ms/epoch - 138ms/step\n",
      "Epoch 440/2000\n",
      "1/1 - 0s - loss: 2.0365 - val_loss: 2.0722 - 145ms/epoch - 145ms/step\n",
      "Epoch 441/2000\n",
      "1/1 - 0s - loss: 2.0362 - val_loss: 2.0715 - 159ms/epoch - 159ms/step\n",
      "Epoch 442/2000\n",
      "1/1 - 0s - loss: 2.0358 - val_loss: 2.0716 - 140ms/epoch - 140ms/step\n",
      "Epoch 443/2000\n",
      "1/1 - 0s - loss: 2.0354 - val_loss: 2.0710 - 144ms/epoch - 144ms/step\n",
      "Epoch 444/2000\n",
      "1/1 - 0s - loss: 2.0350 - val_loss: 2.0707 - 180ms/epoch - 180ms/step\n",
      "Epoch 445/2000\n",
      "1/1 - 0s - loss: 2.0345 - val_loss: 2.0704 - 157ms/epoch - 157ms/step\n",
      "Epoch 446/2000\n",
      "1/1 - 0s - loss: 2.0341 - val_loss: 2.0702 - 130ms/epoch - 130ms/step\n",
      "Epoch 447/2000\n",
      "1/1 - 0s - loss: 2.0337 - val_loss: 2.0700 - 143ms/epoch - 143ms/step\n",
      "Epoch 448/2000\n",
      "1/1 - 0s - loss: 2.0335 - val_loss: 2.0697 - 132ms/epoch - 132ms/step\n",
      "Epoch 449/2000\n",
      "1/1 - 0s - loss: 2.0332 - val_loss: 2.0698 - 127ms/epoch - 127ms/step\n",
      "Epoch 450/2000\n",
      "1/1 - 0s - loss: 2.0329 - val_loss: 2.0692 - 135ms/epoch - 135ms/step\n",
      "Epoch 451/2000\n",
      "1/1 - 0s - loss: 2.0326 - val_loss: 2.0695 - 164ms/epoch - 164ms/step\n",
      "Epoch 452/2000\n",
      "1/1 - 0s - loss: 2.0323 - val_loss: 2.0687 - 146ms/epoch - 146ms/step\n",
      "Epoch 453/2000\n",
      "1/1 - 0s - loss: 2.0319 - val_loss: 2.0685 - 152ms/epoch - 152ms/step\n",
      "Epoch 454/2000\n",
      "1/1 - 0s - loss: 2.0314 - val_loss: 2.0681 - 159ms/epoch - 159ms/step\n",
      "Epoch 455/2000\n",
      "1/1 - 0s - loss: 2.0310 - val_loss: 2.0678 - 151ms/epoch - 151ms/step\n",
      "Epoch 456/2000\n",
      "1/1 - 0s - loss: 2.0306 - val_loss: 2.0676 - 155ms/epoch - 155ms/step\n",
      "Epoch 457/2000\n",
      "1/1 - 0s - loss: 2.0303 - val_loss: 2.0674 - 161ms/epoch - 161ms/step\n",
      "Epoch 458/2000\n",
      "1/1 - 0s - loss: 2.0300 - val_loss: 2.0674 - 151ms/epoch - 151ms/step\n",
      "Epoch 459/2000\n",
      "1/1 - 0s - loss: 2.0298 - val_loss: 2.0669 - 150ms/epoch - 150ms/step\n",
      "Epoch 460/2000\n",
      "1/1 - 0s - loss: 2.0295 - val_loss: 2.0673 - 149ms/epoch - 149ms/step\n",
      "Epoch 461/2000\n",
      "1/1 - 0s - loss: 2.0293 - val_loss: 2.0664 - 157ms/epoch - 157ms/step\n",
      "Epoch 462/2000\n",
      "1/1 - 0s - loss: 2.0288 - val_loss: 2.0662 - 138ms/epoch - 138ms/step\n",
      "Epoch 463/2000\n",
      "1/1 - 0s - loss: 2.0284 - val_loss: 2.0659 - 133ms/epoch - 133ms/step\n",
      "Epoch 464/2000\n",
      "1/1 - 0s - loss: 2.0280 - val_loss: 2.0655 - 129ms/epoch - 129ms/step\n",
      "Epoch 465/2000\n",
      "1/1 - 0s - loss: 2.0276 - val_loss: 2.0655 - 116ms/epoch - 116ms/step\n",
      "Epoch 466/2000\n",
      "1/1 - 0s - loss: 2.0274 - val_loss: 2.0652 - 121ms/epoch - 121ms/step\n",
      "Epoch 467/2000\n",
      "1/1 - 0s - loss: 2.0271 - val_loss: 2.0653 - 117ms/epoch - 117ms/step\n",
      "Epoch 468/2000\n",
      "1/1 - 0s - loss: 2.0269 - val_loss: 2.0648 - 136ms/epoch - 136ms/step\n",
      "Epoch 469/2000\n",
      "1/1 - 0s - loss: 2.0265 - val_loss: 2.0649 - 124ms/epoch - 124ms/step\n",
      "Epoch 470/2000\n",
      "1/1 - 0s - loss: 2.0262 - val_loss: 2.0643 - 135ms/epoch - 135ms/step\n",
      "Epoch 471/2000\n",
      "1/1 - 0s - loss: 2.0258 - val_loss: 2.0641 - 148ms/epoch - 148ms/step\n",
      "Epoch 472/2000\n",
      "1/1 - 0s - loss: 2.0254 - val_loss: 2.0639 - 132ms/epoch - 132ms/step\n",
      "Epoch 473/2000\n",
      "1/1 - 0s - loss: 2.0251 - val_loss: 2.0635 - 137ms/epoch - 137ms/step\n",
      "Epoch 474/2000\n",
      "1/1 - 0s - loss: 2.0249 - val_loss: 2.0636 - 129ms/epoch - 129ms/step\n",
      "Epoch 475/2000\n",
      "1/1 - 0s - loss: 2.0247 - val_loss: 2.0632 - 160ms/epoch - 160ms/step\n",
      "Epoch 476/2000\n",
      "1/1 - 0s - loss: 2.0243 - val_loss: 2.0630 - 133ms/epoch - 133ms/step\n",
      "Epoch 477/2000\n",
      "1/1 - 0s - loss: 2.0240 - val_loss: 2.0627 - 162ms/epoch - 162ms/step\n",
      "Epoch 478/2000\n",
      "1/1 - 0s - loss: 2.0236 - val_loss: 2.0625 - 132ms/epoch - 132ms/step\n",
      "Epoch 479/2000\n",
      "1/1 - 0s - loss: 2.0233 - val_loss: 2.0625 - 116ms/epoch - 116ms/step\n",
      "Epoch 480/2000\n",
      "1/1 - 0s - loss: 2.0231 - val_loss: 2.0623 - 126ms/epoch - 126ms/step\n",
      "Epoch 481/2000\n",
      "1/1 - 0s - loss: 2.0228 - val_loss: 2.0624 - 113ms/epoch - 113ms/step\n",
      "Epoch 482/2000\n",
      "1/1 - 0s - loss: 2.0226 - val_loss: 2.0619 - 123ms/epoch - 123ms/step\n",
      "Epoch 483/2000\n",
      "1/1 - 0s - loss: 2.0222 - val_loss: 2.0618 - 109ms/epoch - 109ms/step\n",
      "Epoch 484/2000\n",
      "1/1 - 0s - loss: 2.0219 - val_loss: 2.0614 - 108ms/epoch - 108ms/step\n",
      "Epoch 485/2000\n",
      "1/1 - 0s - loss: 2.0216 - val_loss: 2.0611 - 122ms/epoch - 122ms/step\n",
      "Epoch 486/2000\n",
      "1/1 - 0s - loss: 2.0214 - val_loss: 2.0613 - 110ms/epoch - 110ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 487/2000\n",
      "1/1 - 0s - loss: 2.0211 - val_loss: 2.0608 - 122ms/epoch - 122ms/step\n",
      "Epoch 488/2000\n",
      "1/1 - 0s - loss: 2.0209 - val_loss: 2.0609 - 103ms/epoch - 103ms/step\n",
      "Epoch 489/2000\n",
      "1/1 - 0s - loss: 2.0206 - val_loss: 2.0606 - 114ms/epoch - 114ms/step\n",
      "Epoch 490/2000\n",
      "1/1 - 0s - loss: 2.0203 - val_loss: 2.0604 - 134ms/epoch - 134ms/step\n",
      "Epoch 491/2000\n",
      "1/1 - 0s - loss: 2.0200 - val_loss: 2.0602 - 122ms/epoch - 122ms/step\n",
      "Epoch 492/2000\n",
      "1/1 - 0s - loss: 2.0197 - val_loss: 2.0600 - 125ms/epoch - 125ms/step\n",
      "Epoch 493/2000\n",
      "1/1 - 0s - loss: 2.0194 - val_loss: 2.0599 - 147ms/epoch - 147ms/step\n",
      "Epoch 494/2000\n",
      "1/1 - 0s - loss: 2.0192 - val_loss: 2.0597 - 126ms/epoch - 126ms/step\n",
      "Epoch 495/2000\n",
      "1/1 - 0s - loss: 2.0189 - val_loss: 2.0597 - 116ms/epoch - 116ms/step\n",
      "Epoch 496/2000\n",
      "1/1 - 0s - loss: 2.0187 - val_loss: 2.0593 - 119ms/epoch - 119ms/step\n",
      "Epoch 497/2000\n",
      "1/1 - 0s - loss: 2.0184 - val_loss: 2.0593 - 147ms/epoch - 147ms/step\n",
      "Epoch 498/2000\n",
      "1/1 - 0s - loss: 2.0181 - val_loss: 2.0590 - 114ms/epoch - 114ms/step\n",
      "Epoch 499/2000\n",
      "1/1 - 0s - loss: 2.0178 - val_loss: 2.0588 - 141ms/epoch - 141ms/step\n",
      "Epoch 500/2000\n",
      "1/1 - 0s - loss: 2.0176 - val_loss: 2.0587 - 122ms/epoch - 122ms/step\n",
      "Epoch 501/2000\n",
      "1/1 - 0s - loss: 2.0173 - val_loss: 2.0584 - 129ms/epoch - 129ms/step\n",
      "Epoch 502/2000\n",
      "1/1 - 0s - loss: 2.0171 - val_loss: 2.0584 - 119ms/epoch - 119ms/step\n",
      "Epoch 503/2000\n",
      "1/1 - 0s - loss: 2.0168 - val_loss: 2.0581 - 124ms/epoch - 124ms/step\n",
      "Epoch 504/2000\n",
      "1/1 - 0s - loss: 2.0165 - val_loss: 2.0580 - 163ms/epoch - 163ms/step\n",
      "Epoch 505/2000\n",
      "1/1 - 0s - loss: 2.0163 - val_loss: 2.0578 - 165ms/epoch - 165ms/step\n",
      "Epoch 506/2000\n",
      "1/1 - 0s - loss: 2.0160 - val_loss: 2.0576 - 163ms/epoch - 163ms/step\n",
      "Epoch 507/2000\n",
      "1/1 - 0s - loss: 2.0158 - val_loss: 2.0575 - 162ms/epoch - 162ms/step\n",
      "Epoch 508/2000\n",
      "1/1 - 0s - loss: 2.0155 - val_loss: 2.0573 - 169ms/epoch - 169ms/step\n",
      "Epoch 509/2000\n",
      "1/1 - 0s - loss: 2.0153 - val_loss: 2.0571 - 162ms/epoch - 162ms/step\n",
      "Epoch 510/2000\n",
      "1/1 - 0s - loss: 2.0150 - val_loss: 2.0570 - 157ms/epoch - 157ms/step\n",
      "Epoch 511/2000\n",
      "1/1 - 0s - loss: 2.0148 - val_loss: 2.0569 - 158ms/epoch - 158ms/step\n",
      "Epoch 512/2000\n",
      "1/1 - 0s - loss: 2.0146 - val_loss: 2.0566 - 164ms/epoch - 164ms/step\n",
      "Epoch 513/2000\n",
      "1/1 - 0s - loss: 2.0143 - val_loss: 2.0568 - 148ms/epoch - 148ms/step\n",
      "Epoch 514/2000\n",
      "1/1 - 0s - loss: 2.0141 - val_loss: 2.0564 - 153ms/epoch - 153ms/step\n",
      "Epoch 515/2000\n",
      "1/1 - 0s - loss: 2.0139 - val_loss: 2.0564 - 153ms/epoch - 153ms/step\n",
      "Epoch 516/2000\n",
      "1/1 - 0s - loss: 2.0136 - val_loss: 2.0562 - 124ms/epoch - 124ms/step\n",
      "Epoch 517/2000\n",
      "1/1 - 0s - loss: 2.0134 - val_loss: 2.0560 - 120ms/epoch - 120ms/step\n",
      "Epoch 518/2000\n",
      "1/1 - 0s - loss: 2.0131 - val_loss: 2.0559 - 120ms/epoch - 120ms/step\n",
      "Epoch 519/2000\n",
      "1/1 - 0s - loss: 2.0129 - val_loss: 2.0557 - 122ms/epoch - 122ms/step\n",
      "Epoch 520/2000\n",
      "1/1 - 0s - loss: 2.0126 - val_loss: 2.0556 - 116ms/epoch - 116ms/step\n",
      "Epoch 521/2000\n",
      "1/1 - 0s - loss: 2.0124 - val_loss: 2.0555 - 129ms/epoch - 129ms/step\n",
      "Epoch 522/2000\n",
      "1/1 - 0s - loss: 2.0122 - val_loss: 2.0555 - 132ms/epoch - 132ms/step\n",
      "Epoch 523/2000\n",
      "1/1 - 0s - loss: 2.0120 - val_loss: 2.0554 - 133ms/epoch - 133ms/step\n",
      "Epoch 524/2000\n",
      "1/1 - 0s - loss: 2.0118 - val_loss: 2.0554 - 125ms/epoch - 125ms/step\n",
      "Epoch 525/2000\n",
      "1/1 - 0s - loss: 2.0115 - val_loss: 2.0552 - 127ms/epoch - 127ms/step\n",
      "Epoch 526/2000\n",
      "1/1 - 0s - loss: 2.0113 - val_loss: 2.0552 - 122ms/epoch - 122ms/step\n",
      "Epoch 527/2000\n",
      "1/1 - 0s - loss: 2.0111 - val_loss: 2.0550 - 133ms/epoch - 133ms/step\n",
      "Epoch 528/2000\n",
      "1/1 - 0s - loss: 2.0109 - val_loss: 2.0548 - 131ms/epoch - 131ms/step\n",
      "Epoch 529/2000\n",
      "1/1 - 0s - loss: 2.0106 - val_loss: 2.0547 - 144ms/epoch - 144ms/step\n",
      "Epoch 530/2000\n",
      "1/1 - 0s - loss: 2.0104 - val_loss: 2.0547 - 147ms/epoch - 147ms/step\n",
      "Epoch 531/2000\n",
      "1/1 - 0s - loss: 2.0102 - val_loss: 2.0545 - 186ms/epoch - 186ms/step\n",
      "Epoch 532/2000\n",
      "1/1 - 0s - loss: 2.0100 - val_loss: 2.0546 - 134ms/epoch - 134ms/step\n",
      "Epoch 533/2000\n",
      "1/1 - 0s - loss: 2.0098 - val_loss: 2.0544 - 138ms/epoch - 138ms/step\n",
      "Epoch 534/2000\n",
      "1/1 - 0s - loss: 2.0096 - val_loss: 2.0543 - 164ms/epoch - 164ms/step\n",
      "Epoch 535/2000\n",
      "1/1 - 0s - loss: 2.0093 - val_loss: 2.0543 - 149ms/epoch - 149ms/step\n",
      "Epoch 536/2000\n",
      "1/1 - 0s - loss: 2.0092 - val_loss: 2.0540 - 134ms/epoch - 134ms/step\n",
      "Epoch 537/2000\n",
      "1/1 - 0s - loss: 2.0089 - val_loss: 2.0540 - 127ms/epoch - 127ms/step\n",
      "Epoch 538/2000\n",
      "1/1 - 0s - loss: 2.0087 - val_loss: 2.0538 - 127ms/epoch - 127ms/step\n",
      "Epoch 539/2000\n",
      "1/1 - 0s - loss: 2.0085 - val_loss: 2.0537 - 161ms/epoch - 161ms/step\n",
      "Epoch 540/2000\n",
      "1/1 - 0s - loss: 2.0083 - val_loss: 2.0537 - 119ms/epoch - 119ms/step\n",
      "Epoch 541/2000\n",
      "1/1 - 0s - loss: 2.0081 - val_loss: 2.0533 - 159ms/epoch - 159ms/step\n",
      "Epoch 542/2000\n",
      "1/1 - 0s - loss: 2.0079 - val_loss: 2.0533 - 124ms/epoch - 124ms/step\n",
      "Epoch 543/2000\n",
      "1/1 - 0s - loss: 2.0077 - val_loss: 2.0532 - 162ms/epoch - 162ms/step\n",
      "Epoch 544/2000\n",
      "1/1 - 0s - loss: 2.0075 - val_loss: 2.0531 - 122ms/epoch - 122ms/step\n",
      "Epoch 545/2000\n",
      "1/1 - 0s - loss: 2.0073 - val_loss: 2.0530 - 165ms/epoch - 165ms/step\n",
      "Epoch 546/2000\n",
      "1/1 - 0s - loss: 2.0070 - val_loss: 2.0530 - 136ms/epoch - 136ms/step\n",
      "Epoch 547/2000\n",
      "1/1 - 0s - loss: 2.0069 - val_loss: 2.0529 - 136ms/epoch - 136ms/step\n",
      "Epoch 548/2000\n",
      "1/1 - 0s - loss: 2.0067 - val_loss: 2.0527 - 119ms/epoch - 119ms/step\n",
      "Epoch 549/2000\n",
      "1/1 - 0s - loss: 2.0065 - val_loss: 2.0526 - 128ms/epoch - 128ms/step\n",
      "Epoch 550/2000\n",
      "1/1 - 0s - loss: 2.0062 - val_loss: 2.0524 - 124ms/epoch - 124ms/step\n",
      "Epoch 551/2000\n",
      "1/1 - 0s - loss: 2.0060 - val_loss: 2.0525 - 114ms/epoch - 114ms/step\n",
      "Epoch 552/2000\n",
      "1/1 - 0s - loss: 2.0059 - val_loss: 2.0523 - 126ms/epoch - 126ms/step\n",
      "Epoch 553/2000\n",
      "1/1 - 0s - loss: 2.0057 - val_loss: 2.0524 - 154ms/epoch - 154ms/step\n",
      "Epoch 554/2000\n",
      "1/1 - 0s - loss: 2.0055 - val_loss: 2.0523 - 115ms/epoch - 115ms/step\n",
      "Epoch 555/2000\n",
      "1/1 - 0s - loss: 2.0053 - val_loss: 2.0522 - 125ms/epoch - 125ms/step\n",
      "Epoch 556/2000\n",
      "1/1 - 0s - loss: 2.0051 - val_loss: 2.0522 - 123ms/epoch - 123ms/step\n",
      "Epoch 557/2000\n",
      "1/1 - 0s - loss: 2.0049 - val_loss: 2.0520 - 125ms/epoch - 125ms/step\n",
      "Epoch 558/2000\n",
      "1/1 - 0s - loss: 2.0047 - val_loss: 2.0518 - 132ms/epoch - 132ms/step\n",
      "Epoch 559/2000\n",
      "1/1 - 0s - loss: 2.0045 - val_loss: 2.0519 - 106ms/epoch - 106ms/step\n",
      "Epoch 560/2000\n",
      "1/1 - 0s - loss: 2.0043 - val_loss: 2.0518 - 117ms/epoch - 117ms/step\n",
      "Epoch 561/2000\n",
      "1/1 - 0s - loss: 2.0041 - val_loss: 2.0518 - 117ms/epoch - 117ms/step\n",
      "Epoch 562/2000\n",
      "1/1 - 0s - loss: 2.0039 - val_loss: 2.0518 - 112ms/epoch - 112ms/step\n",
      "Epoch 563/2000\n",
      "1/1 - 0s - loss: 2.0038 - val_loss: 2.0516 - 115ms/epoch - 115ms/step\n",
      "Epoch 564/2000\n",
      "1/1 - 0s - loss: 2.0036 - val_loss: 2.0516 - 113ms/epoch - 113ms/step\n",
      "Epoch 565/2000\n",
      "1/1 - 0s - loss: 2.0034 - val_loss: 2.0516 - 145ms/epoch - 145ms/step\n",
      "Epoch 566/2000\n",
      "1/1 - 0s - loss: 2.0032 - val_loss: 2.0514 - 119ms/epoch - 119ms/step\n",
      "Epoch 567/2000\n",
      "1/1 - 0s - loss: 2.0031 - val_loss: 2.0515 - 110ms/epoch - 110ms/step\n",
      "Epoch 568/2000\n",
      "1/1 - 0s - loss: 2.0029 - val_loss: 2.0514 - 128ms/epoch - 128ms/step\n",
      "Epoch 569/2000\n",
      "1/1 - 0s - loss: 2.0027 - val_loss: 2.0513 - 131ms/epoch - 131ms/step\n",
      "Epoch 570/2000\n",
      "1/1 - 0s - loss: 2.0025 - val_loss: 2.0514 - 114ms/epoch - 114ms/step\n",
      "Epoch 571/2000\n",
      "1/1 - 0s - loss: 2.0023 - val_loss: 2.0512 - 121ms/epoch - 121ms/step\n",
      "Epoch 572/2000\n",
      "1/1 - 0s - loss: 2.0022 - val_loss: 2.0512 - 112ms/epoch - 112ms/step\n",
      "Epoch 573/2000\n",
      "1/1 - 0s - loss: 2.0020 - val_loss: 2.0510 - 114ms/epoch - 114ms/step\n",
      "Epoch 574/2000\n",
      "1/1 - 0s - loss: 2.0018 - val_loss: 2.0510 - 116ms/epoch - 116ms/step\n",
      "Epoch 575/2000\n",
      "1/1 - 0s - loss: 2.0016 - val_loss: 2.0511 - 107ms/epoch - 107ms/step\n",
      "Epoch 576/2000\n",
      "1/1 - 0s - loss: 2.0014 - val_loss: 2.0509 - 123ms/epoch - 123ms/step\n",
      "Epoch 577/2000\n",
      "1/1 - 0s - loss: 2.0013 - val_loss: 2.0511 - 110ms/epoch - 110ms/step\n",
      "Epoch 578/2000\n",
      "1/1 - 0s - loss: 2.0011 - val_loss: 2.0508 - 118ms/epoch - 118ms/step\n",
      "Epoch 579/2000\n",
      "1/1 - 0s - loss: 2.0009 - val_loss: 2.0508 - 135ms/epoch - 135ms/step\n",
      "Epoch 580/2000\n",
      "1/1 - 0s - loss: 2.0007 - val_loss: 2.0508 - 115ms/epoch - 115ms/step\n",
      "Epoch 581/2000\n",
      "1/1 - 0s - loss: 2.0005 - val_loss: 2.0508 - 108ms/epoch - 108ms/step\n",
      "Epoch 582/2000\n",
      "1/1 - 0s - loss: 2.0005 - val_loss: 2.0509 - 107ms/epoch - 107ms/step\n",
      "Epoch 583/2000\n",
      "1/1 - 0s - loss: 2.0002 - val_loss: 2.0508 - 113ms/epoch - 113ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 584/2000\n",
      "1/1 - 0s - loss: 2.0000 - val_loss: 2.0509 - 108ms/epoch - 108ms/step\n",
      "Epoch 585/2000\n",
      "1/1 - 0s - loss: 1.9998 - val_loss: 2.0510 - 104ms/epoch - 104ms/step\n",
      "Epoch 586/2000\n",
      "1/1 - 0s - loss: 1.9997 - val_loss: 2.0508 - 114ms/epoch - 114ms/step\n",
      "Epoch 587/2000\n",
      "1/1 - 0s - loss: 1.9995 - val_loss: 2.0508 - 104ms/epoch - 104ms/step\n",
      "Epoch 588/2000\n",
      "1/1 - 0s - loss: 1.9993 - val_loss: 2.0504 - 124ms/epoch - 124ms/step\n",
      "Epoch 589/2000\n",
      "1/1 - 0s - loss: 1.9991 - val_loss: 2.0504 - 115ms/epoch - 115ms/step\n",
      "Epoch 590/2000\n",
      "1/1 - 0s - loss: 1.9990 - val_loss: 2.0506 - 112ms/epoch - 112ms/step\n",
      "Epoch 591/2000\n",
      "1/1 - 0s - loss: 1.9988 - val_loss: 2.0504 - 129ms/epoch - 129ms/step\n",
      "Epoch 592/2000\n",
      "1/1 - 0s - loss: 1.9988 - val_loss: 2.0508 - 125ms/epoch - 125ms/step\n",
      "Epoch 593/2000\n",
      "1/1 - 0s - loss: 1.9987 - val_loss: 2.0505 - 135ms/epoch - 135ms/step\n",
      "Epoch 594/2000\n",
      "1/1 - 0s - loss: 1.9984 - val_loss: 2.0503 - 141ms/epoch - 141ms/step\n",
      "Epoch 595/2000\n",
      "1/1 - 0s - loss: 1.9981 - val_loss: 2.0506 - 162ms/epoch - 162ms/step\n",
      "Epoch 596/2000\n",
      "1/1 - 0s - loss: 1.9980 - val_loss: 2.0505 - 132ms/epoch - 132ms/step\n",
      "Epoch 597/2000\n",
      "1/1 - 0s - loss: 1.9980 - val_loss: 2.0508 - 136ms/epoch - 136ms/step\n",
      "Epoch 598/2000\n",
      "1/1 - 0s - loss: 1.9978 - val_loss: 2.0505 - 135ms/epoch - 135ms/step\n",
      "Epoch 599/2000\n",
      "1/1 - 0s - loss: 1.9975 - val_loss: 2.0503 - 169ms/epoch - 169ms/step\n"
     ]
    }
   ],
   "source": [
    "# Hyperparameters for this run: embedding width 10, split over 2 attention heads\n",
    "# (bert_module uses key_dim = embed_dim // num_head).\n",
    "embed_dim = 10\n",
    "num_head = 2\n",
    "\n",
    "# Stop once val_loss has not improved for 5 epochs in a row and restore the best weights seen.\n",
    "callback = keras.callbacks.EarlyStopping(\n",
    "    monitor='val_loss',\n",
    "    patience=5,\n",
    "    restore_best_weights=True,\n",
    ")\n",
    "\n",
    "# Integer token ids -> word embedding + position embedding.\n",
    "inputs = layers.Input((x_masked_train.shape[1],), dtype=tf.int64)\n",
    "word_embeddings = layers.Embedding(N, embed_dim, name=\"word_embedding\")(inputs)\n",
    "position_embeddings = PositionEmbedding(sequence_length=len(x_masked_train[0]))(word_embeddings)\n",
    "encoder_output = word_embeddings + position_embeddings\n",
    "\n",
    "# One encoder block; query, key and value are all the same tensor (self-attention).\n",
    "for i in range(1):\n",
    "    encoder_output = bert_module(encoder_output, encoder_output, encoder_output, embed_dim, num_head, i)\n",
    "\n",
    "# Drop the last sequence position, then project every remaining position to a softmax over N units.\n",
    "encoder_output = layers.Lambda(lambda x: x[:, :-1, :], name='slice')(encoder_output)\n",
    "mlm_output = layers.Dense(N, name=\"mlm_cls\", activation=\"softmax\")(encoder_output)\n",
    "mlm_model = keras.Model(inputs=inputs, outputs=mlm_output)\n",
    "\n",
    "adam = Adam()\n",
    "mlm_model.compile(loss='sparse_categorical_crossentropy', optimizer=adam)\n",
    "\n",
    "# validation_split=0.5 holds half of the training arrays out for the val_loss that drives early stopping.\n",
    "history = mlm_model.fit(\n",
    "    x_masked_train,\n",
    "    y_masked_labels_train,\n",
    "    validation_split=0.5,\n",
    "    callbacks=[callback],\n",
    "    epochs=2000,\n",
    "    batch_size=5000,\n",
    "    verbose=2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "aa33aaff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate on a random sample of 1000 test sequences: compare the prediction at the last\n",
    "# output position with the final token of each input sequence.\n",
    "x_test_subset = x_masked_test[np.random.choice(x_masked_test.shape[0], size=1000, replace=False)]\n",
    "\n",
    "# Build the inference function a single time and push the whole sample through it as one\n",
    "# batch; constructing it inside a per-sentence loop is needlessly expensive.\n",
    "predict_fn = keras.backend.function(inputs=mlm_model.layers[0].input, outputs=mlm_model.layers[-1].output)\n",
    "last_step_probs = predict_fn(x_test_subset)[:, -1, :]  # one row of class probabilities per sentence\n",
    "true_last_tokens = x_test_subset[:, -1]\n",
    "\n",
    "# acc: 1 if the arg-max class equals the true final token, else 0 (one entry per sentence).\n",
    "# prob: probability the model assigned to the true final token (one entry per sentence).\n",
    "acc = [int(hit) for hit in last_step_probs.argmax(axis=1) == true_last_tokens]\n",
    "prob = list(last_step_probs[np.arange(len(x_test_subset)), true_last_tokens])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "fe4d2103",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1.0, 0.92229855)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# (mean top-1 accuracy, mean probability assigned to the true final token)\n",
    "np.mean(acc), np.mean(prob)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "0005b730",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2000\n",
      "1/1 - 1s - loss: 3.7903 - val_loss: 3.3893 - 1s/epoch - 1s/step\n",
      "Epoch 2/2000\n",
      "1/1 - 0s - loss: 3.3982 - val_loss: 3.2298 - 76ms/epoch - 76ms/step\n",
      "Epoch 3/2000\n",
      "1/1 - 0s - loss: 3.2282 - val_loss: 3.1684 - 78ms/epoch - 78ms/step\n",
      "Epoch 4/2000\n",
      "1/1 - 0s - loss: 3.1586 - val_loss: 3.1380 - 81ms/epoch - 81ms/step\n",
      "Epoch 5/2000\n",
      "1/1 - 0s - loss: 3.1217 - val_loss: 3.1117 - 74ms/epoch - 74ms/step\n",
      "Epoch 6/2000\n",
      "1/1 - 0s - loss: 3.0913 - val_loss: 3.0886 - 75ms/epoch - 75ms/step\n",
      "Epoch 7/2000\n",
      "1/1 - 0s - loss: 3.0656 - val_loss: 3.0736 - 78ms/epoch - 78ms/step\n",
      "Epoch 8/2000\n",
      "1/1 - 0s - loss: 3.0494 - val_loss: 3.0661 - 72ms/epoch - 72ms/step\n",
      "Epoch 9/2000\n",
      "1/1 - 0s - loss: 3.0413 - val_loss: 3.0600 - 80ms/epoch - 80ms/step\n",
      "Epoch 10/2000\n",
      "1/1 - 0s - loss: 3.0347 - val_loss: 3.0518 - 79ms/epoch - 79ms/step\n",
      "Epoch 11/2000\n",
      "1/1 - 0s - loss: 3.0261 - val_loss: 3.0427 - 71ms/epoch - 71ms/step\n",
      "Epoch 12/2000\n",
      "1/1 - 0s - loss: 3.0164 - val_loss: 3.0350 - 75ms/epoch - 75ms/step\n",
      "Epoch 13/2000\n",
      "1/1 - 0s - loss: 3.0079 - val_loss: 3.0295 - 75ms/epoch - 75ms/step\n",
      "Epoch 14/2000\n",
      "1/1 - 0s - loss: 3.0013 - val_loss: 3.0261 - 71ms/epoch - 71ms/step\n",
      "Epoch 15/2000\n",
      "1/1 - 0s - loss: 2.9963 - val_loss: 3.0240 - 85ms/epoch - 85ms/step\n",
      "Epoch 16/2000\n",
      "1/1 - 0s - loss: 2.9923 - val_loss: 3.0226 - 73ms/epoch - 73ms/step\n",
      "Epoch 17/2000\n",
      "1/1 - 0s - loss: 2.9886 - val_loss: 3.0210 - 72ms/epoch - 72ms/step\n",
      "Epoch 18/2000\n",
      "1/1 - 0s - loss: 2.9844 - val_loss: 3.0188 - 77ms/epoch - 77ms/step\n",
      "Epoch 19/2000\n",
      "1/1 - 0s - loss: 2.9796 - val_loss: 3.0156 - 79ms/epoch - 79ms/step\n",
      "Epoch 20/2000\n",
      "1/1 - 0s - loss: 2.9740 - val_loss: 3.0120 - 76ms/epoch - 76ms/step\n",
      "Epoch 21/2000\n",
      "1/1 - 0s - loss: 2.9684 - val_loss: 3.0089 - 79ms/epoch - 79ms/step\n",
      "Epoch 22/2000\n",
      "1/1 - 0s - loss: 2.9635 - val_loss: 3.0068 - 76ms/epoch - 76ms/step\n",
      "Epoch 23/2000\n",
      "1/1 - 0s - loss: 2.9600 - val_loss: 3.0059 - 75ms/epoch - 75ms/step\n",
      "Epoch 24/2000\n",
      "1/1 - 0s - loss: 2.9580 - val_loss: 3.0056 - 77ms/epoch - 77ms/step\n",
      "Epoch 25/2000\n",
      "1/1 - 0s - loss: 2.9567 - val_loss: 3.0050 - 74ms/epoch - 74ms/step\n",
      "Epoch 26/2000\n",
      "1/1 - 0s - loss: 2.9553 - val_loss: 3.0037 - 74ms/epoch - 74ms/step\n",
      "Epoch 27/2000\n",
      "1/1 - 0s - loss: 2.9532 - val_loss: 3.0019 - 76ms/epoch - 76ms/step\n",
      "Epoch 28/2000\n",
      "1/1 - 0s - loss: 2.9505 - val_loss: 2.9998 - 71ms/epoch - 71ms/step\n",
      "Epoch 29/2000\n",
      "1/1 - 0s - loss: 2.9475 - val_loss: 2.9978 - 76ms/epoch - 76ms/step\n",
      "Epoch 30/2000\n",
      "1/1 - 0s - loss: 2.9446 - val_loss: 2.9962 - 84ms/epoch - 84ms/step\n",
      "Epoch 31/2000\n",
      "1/1 - 0s - loss: 2.9419 - val_loss: 2.9950 - 77ms/epoch - 77ms/step\n",
      "Epoch 32/2000\n",
      "1/1 - 0s - loss: 2.9395 - val_loss: 2.9941 - 78ms/epoch - 78ms/step\n",
      "Epoch 33/2000\n",
      "1/1 - 0s - loss: 2.9372 - val_loss: 2.9932 - 77ms/epoch - 77ms/step\n",
      "Epoch 34/2000\n",
      "1/1 - 0s - loss: 2.9348 - val_loss: 2.9921 - 71ms/epoch - 71ms/step\n",
      "Epoch 35/2000\n",
      "1/1 - 0s - loss: 2.9320 - val_loss: 2.9906 - 75ms/epoch - 75ms/step\n",
      "Epoch 36/2000\n",
      "1/1 - 0s - loss: 2.9288 - val_loss: 2.9888 - 74ms/epoch - 74ms/step\n",
      "Epoch 37/2000\n",
      "1/1 - 0s - loss: 2.9251 - val_loss: 2.9868 - 72ms/epoch - 72ms/step\n",
      "Epoch 38/2000\n",
      "1/1 - 0s - loss: 2.9212 - val_loss: 2.9847 - 77ms/epoch - 77ms/step\n",
      "Epoch 39/2000\n",
      "1/1 - 0s - loss: 2.9171 - val_loss: 2.9825 - 78ms/epoch - 78ms/step\n",
      "Epoch 40/2000\n",
      "1/1 - 0s - loss: 2.9129 - val_loss: 2.9798 - 74ms/epoch - 74ms/step\n",
      "Epoch 41/2000\n",
      "1/1 - 0s - loss: 2.9084 - val_loss: 2.9767 - 80ms/epoch - 80ms/step\n",
      "Epoch 42/2000\n",
      "1/1 - 0s - loss: 2.9035 - val_loss: 2.9730 - 75ms/epoch - 75ms/step\n",
      "Epoch 43/2000\n",
      "1/1 - 0s - loss: 2.8979 - val_loss: 2.9686 - 72ms/epoch - 72ms/step\n",
      "Epoch 44/2000\n",
      "1/1 - 0s - loss: 2.8916 - val_loss: 2.9635 - 76ms/epoch - 76ms/step\n",
      "Epoch 45/2000\n",
      "1/1 - 0s - loss: 2.8845 - val_loss: 2.9575 - 75ms/epoch - 75ms/step\n",
      "Epoch 46/2000\n",
      "1/1 - 0s - loss: 2.8763 - val_loss: 2.9505 - 73ms/epoch - 73ms/step\n",
      "Epoch 47/2000\n",
      "1/1 - 0s - loss: 2.8669 - val_loss: 2.9422 - 78ms/epoch - 78ms/step\n",
      "Epoch 48/2000\n",
      "1/1 - 0s - loss: 2.8561 - val_loss: 2.9323 - 73ms/epoch - 73ms/step\n",
      "Epoch 49/2000\n",
      "1/1 - 0s - loss: 2.8435 - val_loss: 2.9205 - 74ms/epoch - 74ms/step\n",
      "Epoch 50/2000\n",
      "1/1 - 0s - loss: 2.8289 - val_loss: 2.9065 - 78ms/epoch - 78ms/step\n",
      "Epoch 51/2000\n",
      "1/1 - 0s - loss: 2.8119 - val_loss: 2.8896 - 74ms/epoch - 74ms/step\n",
      "Epoch 52/2000\n",
      "1/1 - 0s - loss: 2.7920 - val_loss: 2.8694 - 73ms/epoch - 73ms/step\n",
      "Epoch 53/2000\n",
      "1/1 - 0s - loss: 2.7687 - val_loss: 2.8452 - 77ms/epoch - 77ms/step\n",
      "Epoch 54/2000\n",
      "1/1 - 0s - loss: 2.7412 - val_loss: 2.8161 - 70ms/epoch - 70ms/step\n",
      "Epoch 55/2000\n",
      "1/1 - 0s - loss: 2.7090 - val_loss: 2.7816 - 76ms/epoch - 76ms/step\n",
      "Epoch 56/2000\n",
      "1/1 - 0s - loss: 2.6714 - val_loss: 2.7409 - 79ms/epoch - 79ms/step\n",
      "Epoch 57/2000\n",
      "1/1 - 0s - loss: 2.6278 - val_loss: 2.6934 - 72ms/epoch - 72ms/step\n",
      "Epoch 58/2000\n",
      "1/1 - 0s - loss: 2.5776 - val_loss: 2.6387 - 76ms/epoch - 76ms/step\n",
      "Epoch 59/2000\n",
      "1/1 - 0s - loss: 2.5204 - val_loss: 2.5776 - 76ms/epoch - 76ms/step\n",
      "Epoch 60/2000\n",
      "1/1 - 0s - loss: 2.4574 - val_loss: 2.5112 - 75ms/epoch - 75ms/step\n",
      "Epoch 61/2000\n",
      "1/1 - 0s - loss: 2.3903 - val_loss: 2.4419 - 76ms/epoch - 76ms/step\n",
      "Epoch 62/2000\n",
      "1/1 - 0s - loss: 2.3220 - val_loss: 2.3732 - 76ms/epoch - 76ms/step\n",
      "Epoch 63/2000\n",
      "1/1 - 0s - loss: 2.2562 - val_loss: 2.3094 - 71ms/epoch - 71ms/step\n",
      "Epoch 64/2000\n",
      "1/1 - 0s - loss: 2.1967 - val_loss: 2.2545 - 76ms/epoch - 76ms/step\n",
      "Epoch 65/2000\n",
      "1/1 - 0s - loss: 2.1465 - val_loss: 2.2103 - 78ms/epoch - 78ms/step\n",
      "Epoch 66/2000\n",
      "1/1 - 0s - loss: 2.1073 - val_loss: 2.1766 - 72ms/epoch - 72ms/step\n",
      "Epoch 67/2000\n",
      "1/1 - 0s - loss: 2.0784 - val_loss: 2.1522 - 91ms/epoch - 91ms/step\n",
      "Epoch 68/2000\n",
      "1/1 - 0s - loss: 2.0581 - val_loss: 2.1346 - 74ms/epoch - 74ms/step\n",
      "Epoch 69/2000\n",
      "1/1 - 0s - loss: 2.0437 - val_loss: 2.1214 - 74ms/epoch - 74ms/step\n",
      "Epoch 70/2000\n",
      "1/1 - 0s - loss: 2.0333 - val_loss: 2.1111 - 81ms/epoch - 81ms/step\n",
      "Epoch 71/2000\n",
      "1/1 - 0s - loss: 2.0253 - val_loss: 2.1026 - 75ms/epoch - 75ms/step\n",
      "Epoch 72/2000\n",
      "1/1 - 0s - loss: 2.0186 - val_loss: 2.0954 - 73ms/epoch - 73ms/step\n",
      "Epoch 73/2000\n",
      "1/1 - 0s - loss: 2.0129 - val_loss: 2.0893 - 78ms/epoch - 78ms/step\n",
      "Epoch 74/2000\n",
      "1/1 - 0s - loss: 2.0079 - val_loss: 2.0838 - 73ms/epoch - 73ms/step\n",
      "Epoch 75/2000\n",
      "1/1 - 0s - loss: 2.0034 - val_loss: 2.0789 - 75ms/epoch - 75ms/step\n",
      "Epoch 76/2000\n",
      "1/1 - 0s - loss: 1.9993 - val_loss: 2.0745 - 80ms/epoch - 80ms/step\n",
      "Epoch 77/2000\n",
      "1/1 - 0s - loss: 1.9956 - val_loss: 2.0705 - 73ms/epoch - 73ms/step\n",
      "Epoch 78/2000\n",
      "1/1 - 0s - loss: 1.9921 - val_loss: 2.0669 - 76ms/epoch - 76ms/step\n",
      "Epoch 79/2000\n",
      "1/1 - 0s - loss: 1.9889 - val_loss: 2.0637 - 79ms/epoch - 79ms/step\n",
      "Epoch 80/2000\n",
      "1/1 - 0s - loss: 1.9860 - val_loss: 2.0609 - 74ms/epoch - 74ms/step\n",
      "Epoch 81/2000\n",
      "1/1 - 0s - loss: 1.9832 - val_loss: 2.0585 - 76ms/epoch - 76ms/step\n",
      "Epoch 82/2000\n",
      "1/1 - 0s - loss: 1.9807 - val_loss: 2.0564 - 75ms/epoch - 75ms/step\n",
      "Epoch 83/2000\n",
      "1/1 - 0s - loss: 1.9783 - val_loss: 2.0544 - 73ms/epoch - 73ms/step\n",
      "Epoch 84/2000\n",
      "1/1 - 0s - loss: 1.9760 - val_loss: 2.0527 - 77ms/epoch - 77ms/step\n",
      "Epoch 85/2000\n",
      "1/1 - 0s - loss: 1.9738 - val_loss: 2.0512 - 74ms/epoch - 74ms/step\n",
      "Epoch 86/2000\n",
      "1/1 - 0s - loss: 1.9717 - val_loss: 2.0499 - 71ms/epoch - 71ms/step\n",
      "Epoch 87/2000\n",
      "1/1 - 0s - loss: 1.9697 - val_loss: 2.0488 - 78ms/epoch - 78ms/step\n",
      "Epoch 88/2000\n",
      "1/1 - 0s - loss: 1.9677 - val_loss: 2.0478 - 76ms/epoch - 76ms/step\n",
      "Epoch 89/2000\n",
      "1/1 - 0s - loss: 1.9659 - val_loss: 2.0471 - 69ms/epoch - 69ms/step\n",
      "Epoch 90/2000\n",
      "1/1 - 0s - loss: 1.9641 - val_loss: 2.0465 - 77ms/epoch - 77ms/step\n",
      "Epoch 91/2000\n",
      "1/1 - 0s - loss: 1.9623 - val_loss: 2.0461 - 74ms/epoch - 74ms/step\n",
      "Epoch 92/2000\n",
      "1/1 - 0s - loss: 1.9606 - val_loss: 2.0458 - 74ms/epoch - 74ms/step\n",
      "Epoch 93/2000\n",
      "1/1 - 0s - loss: 1.9590 - val_loss: 2.0458 - 77ms/epoch - 77ms/step\n",
      "Epoch 94/2000\n",
      "1/1 - 0s - loss: 1.9574 - val_loss: 2.0460 - 71ms/epoch - 71ms/step\n",
      "Epoch 95/2000\n",
      "1/1 - 0s - loss: 1.9559 - val_loss: 2.0465 - 70ms/epoch - 70ms/step\n",
      "Epoch 96/2000\n",
      "1/1 - 0s - loss: 1.9544 - val_loss: 2.0472 - 74ms/epoch - 74ms/step\n",
      "Epoch 97/2000\n",
      "1/1 - 0s - loss: 1.9529 - val_loss: 2.0480 - 74ms/epoch - 74ms/step\n",
      "Epoch 98/2000\n",
      "1/1 - 0s - loss: 1.9515 - val_loss: 2.0487 - 73ms/epoch - 73ms/step\n"
     ]
    }
   ],
   "source": [
    "# Hyperparameters for this run: embedding width 100, split over 2 attention heads\n",
    "# (bert_module uses key_dim = embed_dim // num_head).\n",
    "embed_dim = 100\n",
    "num_head = 2\n",
    "\n",
    "# Stop once val_loss has not improved for 5 epochs in a row and restore the best weights seen.\n",
    "callback = keras.callbacks.EarlyStopping(\n",
    "    monitor='val_loss',\n",
    "    patience=5,\n",
    "    restore_best_weights=True,\n",
    ")\n",
    "\n",
    "# Integer token ids -> word embedding + position embedding.\n",
    "inputs = layers.Input((x_masked_train.shape[1],), dtype=tf.int64)\n",
    "word_embeddings = layers.Embedding(N, embed_dim, name=\"word_embedding\")(inputs)\n",
    "position_embeddings = PositionEmbedding(sequence_length=len(x_masked_train[0]))(word_embeddings)\n",
    "encoder_output = word_embeddings + position_embeddings\n",
    "\n",
    "# One encoder block; query, key and value are all the same tensor (self-attention).\n",
    "for i in range(1):\n",
    "    encoder_output = bert_module(encoder_output, encoder_output, encoder_output, embed_dim, num_head, i)\n",
    "\n",
    "# Drop the last sequence position, then project every remaining position to a softmax over N units.\n",
    "encoder_output = layers.Lambda(lambda x: x[:, :-1, :], name='slice')(encoder_output)\n",
    "mlm_output = layers.Dense(N, name=\"mlm_cls\", activation=\"softmax\")(encoder_output)\n",
    "mlm_model = keras.Model(inputs=inputs, outputs=mlm_output)\n",
    "\n",
    "adam = Adam()\n",
    "mlm_model.compile(loss='sparse_categorical_crossentropy', optimizer=adam)\n",
    "\n",
    "# validation_split=0.5 holds half of the training arrays out for the val_loss that drives early stopping.\n",
    "history = mlm_model.fit(\n",
    "    x_masked_train,\n",
    "    y_masked_labels_train,\n",
    "    validation_split=0.5,\n",
    "    callbacks=[callback],\n",
    "    epochs=2000,\n",
    "    batch_size=5000,\n",
    "    verbose=2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "beb07cf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate on a random sample of 1000 test sequences: compare the prediction at the last\n",
    "# output position with the final token of each input sequence.\n",
    "x_test_subset = x_masked_test[np.random.choice(x_masked_test.shape[0], size=1000, replace=False)]\n",
    "\n",
    "# Build the inference function a single time and push the whole sample through it as one\n",
    "# batch; constructing it inside a per-sentence loop is needlessly expensive.\n",
    "predict_fn = keras.backend.function(inputs=mlm_model.layers[0].input, outputs=mlm_model.layers[-1].output)\n",
    "last_step_probs = predict_fn(x_test_subset)[:, -1, :]  # one row of class probabilities per sentence\n",
    "true_last_tokens = x_test_subset[:, -1]\n",
    "\n",
    "# acc: 1 if the arg-max class equals the true final token, else 0 (one entry per sentence).\n",
    "# prob: probability the model assigned to the true final token (one entry per sentence).\n",
    "acc = [int(hit) for hit in last_step_probs.argmax(axis=1) == true_last_tokens]\n",
    "prob = list(last_step_probs[np.arange(len(x_test_subset)), true_last_tokens])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "dbdd4a69",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1.0, 0.97649634)"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# (mean top-1 accuracy, mean probability assigned to the true final token)\n",
    "np.mean(acc), np.mean(prob)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54f31bac",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
