{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d1686d53",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy.random as npr\n",
    "import random\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "from tensorflow.keras import backend as K\n",
    "from keras.optimizers import Adam\n",
    "from keras_nlp.layers import PositionEmbedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "56c107e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pin every random number generator this notebook touches so reruns are repeatable.\n",
    "seed = 428\n",
    "\n",
    "random.seed(seed)         # Python's built-in `random` module\n",
    "np.random.seed(seed)      # NumPy's global generator\n",
    "tf.random.set_seed(seed)  # TensorFlow's global seed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "504a342e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def bert_module(query, key, value, embed_dim, num_head, i):\n",
    "    \"\"\"Build one transformer block and return its output tensor.\n",
    "\n",
    "    Pipeline: causal multi-head attention -> residual add + layer norm ->\n",
    "    two-layer feed-forward network -> residual add + layer norm.\n",
    "\n",
    "    Args:\n",
    "        query: Attention query tensor; also the residual input of the first\n",
    "            skip connection.\n",
    "        key: Attention key tensor.\n",
    "        value: Attention value tensor.\n",
    "        embed_dim: Number of units of the second feed-forward layer; the first\n",
    "            feed-forward layer has 2 * embed_dim units and each attention head\n",
    "            uses key_dim = embed_dim // num_head.\n",
    "        num_head: Number of attention heads.\n",
    "        i: Block index, formatted into the attention and dense layer names.\n",
    "\n",
    "    Returns:\n",
    "        The output of the final layer normalization.\n",
    "    \"\"\"\n",
    "\n",
    "    # Multi headed self-attention\n",
    "    # Keras' MultiHeadAttention.call takes (query, value, key) in that order, so\n",
    "    # value and key are passed by keyword; passing them positionally as\n",
    "    # (query, key, value) would silently swap the two.\n",
    "    attention_output = layers.MultiHeadAttention(\n",
    "        num_heads=num_head,\n",
    "        key_dim=embed_dim // num_head,\n",
    "        name=\"encoder_{}/multiheadattention\".format(i)\n",
    "    )(query, value=value, key=key, use_causal_mask=True)\n",
    "\n",
    "    # Add & Normalize\n",
    "    attention_output = layers.Add()([query, attention_output])  # Skip Connection\n",
    "    attention_output = layers.LayerNormalization(epsilon=1e-6)(attention_output)\n",
    "\n",
    "    # Feedforward network: expand to 2 * embed_dim with ReLU, then project back.\n",
    "    ff_net = keras.models.Sequential([\n",
    "        layers.Dense(2 * embed_dim, activation='relu', name=\"encoder_{}/ffn_dense_1\".format(i)),\n",
    "        layers.Dense(embed_dim, name=\"encoder_{}/ffn_dense_2\".format(i)),\n",
    "    ])\n",
    "\n",
    "    # Apply Feedforward network\n",
    "    ffn_output = ff_net(attention_output)\n",
    "\n",
    "    # Add & Normalize\n",
    "    ffn_output = layers.Add()([attention_output, ffn_output])  # Skip Connection\n",
    "    ffn_output = layers.LayerNormalization(epsilon=1e-6)(ffn_output)\n",
    "\n",
    "    return ffn_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fef2622a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_sinusoidal_embeddings(sequence_length, embedding_dim):\n",
    "    \"\"\"Return a fixed sinusoidal position-encoding table.\n",
    "\n",
    "    For position pos >= 1 and column j, the angle is\n",
    "    pos / 10000 ** (2 * (j // 2) / embedding_dim); even columns hold its sine\n",
    "    and odd columns its cosine, so each sin/cos column pair shares one\n",
    "    frequency. Row 0 (pos == 0) is left as all zeros.\n",
    "\n",
    "    Args:\n",
    "        sequence_length: Number of positions (rows) in the table.\n",
    "        embedding_dim: Number of encoding dimensions (columns) in the table.\n",
    "\n",
    "    Returns:\n",
    "        A float32 tensor of shape (sequence_length, embedding_dim).\n",
    "    \"\"\"\n",
    "    positions = np.arange(sequence_length, dtype=np.float64)[:, np.newaxis]\n",
    "    columns = np.arange(embedding_dim)\n",
    "    # Columns 2k and 2k+1 must share the exponent 2k / embedding_dim; using the\n",
    "    # raw column index here would give the sine and cosine different frequencies.\n",
    "    angles = positions / np.power(10000.0, 2.0 * (columns // 2) / embedding_dim)\n",
    "\n",
    "    position_enc = np.zeros((sequence_length, embedding_dim), dtype=np.float64)\n",
    "    position_enc[1:, 0::2] = np.sin(angles[1:, 0::2])  # dim 2i\n",
    "    position_enc[1:, 1::2] = np.cos(angles[1:, 1::2])  # dim 2i+1\n",
    "    return tf.cast(position_enc, dtype=tf.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3e5bfd69",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 20  # vocab_size\n",
    "\n",
    "# Synthetic vocabulary word_0 ... word_{N-1} and the token -> integer id lookup.\n",
    "vocabs = ['word_' + str(idx) for idx in range(N)]\n",
    "vocab_map = {token: idx for idx, token in enumerate(vocabs)}\n",
    "\n",
    "# Every ordered triple of three mutually distinct tokens (the name says\n",
    "# \"pairs\", but each entry is a 3-tuple).\n",
    "pairs = [\n",
    "    (first, second, third)\n",
    "    for first in vocabs\n",
    "    for second in vocabs\n",
    "    for third in vocabs\n",
    "    if first != second and first != third and second != third\n",
    "]\n",
    "\n",
    "# One independent fair coin flip per triple: 1 -> training set, 0 -> test set.\n",
    "indicator = np.random.choice([0, 1], size=len(pairs), p=[0.5, 0.5])\n",
    "\n",
    "pairs_train = [triple for triple, flag in zip(pairs, indicator) if flag == 1]\n",
    "pairs_test = [triple for triple, flag in zip(pairs, indicator) if flag == 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d43093fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _encode_triples(triples, token_ids):\n",
    "    \"\"\"Turn (a, b, c) token triples into 4-token sentences and their id rows.\n",
    "\n",
    "    Each triple becomes the sentence [a, b, c, a], i.e. the first token is\n",
    "    repeated at the end. Returns `(sentences, id_rows)`, where `id_rows` holds\n",
    "    the same sentences with every token replaced by `token_ids[token]`.\n",
    "    \"\"\"\n",
    "    sentences = [[a, b, c, a] for a, b, c in triples]\n",
    "    id_rows = [[token_ids[token] for token in sentence] for sentence in sentences]\n",
    "    return sentences, id_rows\n",
    "\n",
    "\n",
    "sentences_train, sentences_number_train = _encode_triples(pairs_train, vocab_map)\n",
    "sentences_test, sentences_number_test = _encode_triples(pairs_test, vocab_map)\n",
    "\n",
    "# Inputs are the full 4-token id rows; labels are the same rows without their\n",
    "# first token, i.e. the token that follows each of the first three positions.\n",
    "x_masked_train = np.array(sentences_number_train)\n",
    "y_masked_labels_train = np.array([row[1:] for row in sentences_number_train])\n",
    "x_masked_test = np.array(sentences_number_test)\n",
    "y_masked_labels_test = np.array([row[1:] for row in sentences_number_test])\n",
    "\n",
    "# Shuffle the training examples with a single permutation applied to both\n",
    "# inputs and labels so they stay aligned; the test arrays keep generation order.\n",
    "perm = np.random.permutation(len(x_masked_train))\n",
    "x_masked_train = x_masked_train[perm]\n",
    "y_masked_labels_train = y_masked_labels_train[perm]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "13b40f89",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-05-22 07:04:06.397947: W tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory\n",
      "2024-05-22 07:04:06.397986: W tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: UNKNOWN ERROR (303)\n",
      "2024-05-22 07:04:06.398006: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (gl3035.arc-ts.umich.edu): /proc/driver/nvidia/version does not exist\n",
      "2024-05-22 07:04:06.398190: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA\n",
      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2000\n",
      "1/1 - 4s - loss: 3.2596 - val_loss: 3.2190 - 4s/epoch - 4s/step\n",
      "Epoch 2/2000\n",
      "1/1 - 0s - loss: 3.2195 - val_loss: 3.1927 - 89ms/epoch - 89ms/step\n",
      "Epoch 3/2000\n",
      "1/1 - 0s - loss: 3.1926 - val_loss: 3.1738 - 84ms/epoch - 84ms/step\n",
      "Epoch 4/2000\n",
      "1/1 - 0s - loss: 3.1731 - val_loss: 3.1591 - 82ms/epoch - 82ms/step\n",
      "Epoch 5/2000\n",
      "1/1 - 0s - loss: 3.1580 - val_loss: 3.1468 - 97ms/epoch - 97ms/step\n",
      "Epoch 6/2000\n",
      "1/1 - 0s - loss: 3.1449 - val_loss: 3.1362 - 94ms/epoch - 94ms/step\n",
      "Epoch 7/2000\n",
      "1/1 - 0s - loss: 3.1334 - val_loss: 3.1273 - 84ms/epoch - 84ms/step\n",
      "Epoch 8/2000\n",
      "1/1 - 0s - loss: 3.1236 - val_loss: 3.1202 - 82ms/epoch - 82ms/step\n",
      "Epoch 9/2000\n",
      "1/1 - 0s - loss: 3.1155 - val_loss: 3.1145 - 81ms/epoch - 81ms/step\n",
      "Epoch 10/2000\n",
      "1/1 - 0s - loss: 3.1089 - val_loss: 3.1101 - 81ms/epoch - 81ms/step\n",
      "Epoch 11/2000\n",
      "1/1 - 0s - loss: 3.1036 - val_loss: 3.1062 - 82ms/epoch - 82ms/step\n",
      "Epoch 12/2000\n",
      "1/1 - 0s - loss: 3.0988 - val_loss: 3.1026 - 81ms/epoch - 81ms/step\n",
      "Epoch 13/2000\n",
      "1/1 - 0s - loss: 3.0943 - val_loss: 3.0987 - 81ms/epoch - 81ms/step\n",
      "Epoch 14/2000\n",
      "1/1 - 0s - loss: 3.0897 - val_loss: 3.0945 - 81ms/epoch - 81ms/step\n",
      "Epoch 15/2000\n",
      "1/1 - 0s - loss: 3.0850 - val_loss: 3.0903 - 82ms/epoch - 82ms/step\n",
      "Epoch 16/2000\n",
      "1/1 - 0s - loss: 3.0804 - val_loss: 3.0862 - 82ms/epoch - 82ms/step\n",
      "Epoch 17/2000\n",
      "1/1 - 0s - loss: 3.0759 - val_loss: 3.0824 - 83ms/epoch - 83ms/step\n",
      "Epoch 18/2000\n",
      "1/1 - 0s - loss: 3.0719 - val_loss: 3.0789 - 82ms/epoch - 82ms/step\n",
      "Epoch 19/2000\n",
      "1/1 - 0s - loss: 3.0682 - val_loss: 3.0757 - 83ms/epoch - 83ms/step\n",
      "Epoch 20/2000\n",
      "1/1 - 0s - loss: 3.0648 - val_loss: 3.0727 - 83ms/epoch - 83ms/step\n",
      "Epoch 21/2000\n",
      "1/1 - 0s - loss: 3.0617 - val_loss: 3.0698 - 81ms/epoch - 81ms/step\n",
      "Epoch 22/2000\n",
      "1/1 - 0s - loss: 3.0587 - val_loss: 3.0670 - 83ms/epoch - 83ms/step\n",
      "Epoch 23/2000\n",
      "1/1 - 0s - loss: 3.0559 - val_loss: 3.0644 - 83ms/epoch - 83ms/step\n",
      "Epoch 24/2000\n",
      "1/1 - 0s - loss: 3.0531 - val_loss: 3.0618 - 84ms/epoch - 84ms/step\n",
      "Epoch 25/2000\n",
      "1/1 - 0s - loss: 3.0504 - val_loss: 3.0592 - 83ms/epoch - 83ms/step\n",
      "Epoch 26/2000\n",
      "1/1 - 0s - loss: 3.0478 - val_loss: 3.0568 - 82ms/epoch - 82ms/step\n",
      "Epoch 27/2000\n",
      "1/1 - 0s - loss: 3.0453 - val_loss: 3.0545 - 94ms/epoch - 94ms/step\n",
      "Epoch 28/2000\n",
      "1/1 - 0s - loss: 3.0430 - val_loss: 3.0522 - 84ms/epoch - 84ms/step\n",
      "Epoch 29/2000\n",
      "1/1 - 0s - loss: 3.0406 - val_loss: 3.0499 - 80ms/epoch - 80ms/step\n",
      "Epoch 30/2000\n",
      "1/1 - 0s - loss: 3.0383 - val_loss: 3.0476 - 82ms/epoch - 82ms/step\n",
      "Epoch 31/2000\n",
      "1/1 - 0s - loss: 3.0361 - val_loss: 3.0454 - 81ms/epoch - 81ms/step\n",
      "Epoch 32/2000\n",
      "1/1 - 0s - loss: 3.0339 - val_loss: 3.0432 - 83ms/epoch - 83ms/step\n",
      "Epoch 33/2000\n",
      "1/1 - 0s - loss: 3.0317 - val_loss: 3.0410 - 82ms/epoch - 82ms/step\n",
      "Epoch 34/2000\n",
      "1/1 - 0s - loss: 3.0296 - val_loss: 3.0389 - 82ms/epoch - 82ms/step\n",
      "Epoch 35/2000\n",
      "1/1 - 0s - loss: 3.0276 - val_loss: 3.0369 - 83ms/epoch - 83ms/step\n",
      "Epoch 36/2000\n",
      "1/1 - 0s - loss: 3.0256 - val_loss: 3.0349 - 82ms/epoch - 82ms/step\n",
      "Epoch 37/2000\n",
      "1/1 - 0s - loss: 3.0237 - val_loss: 3.0329 - 88ms/epoch - 88ms/step\n",
      "Epoch 38/2000\n",
      "1/1 - 0s - loss: 3.0217 - val_loss: 3.0309 - 81ms/epoch - 81ms/step\n",
      "Epoch 39/2000\n",
      "1/1 - 0s - loss: 3.0198 - val_loss: 3.0290 - 81ms/epoch - 81ms/step\n",
      "Epoch 40/2000\n",
      "1/1 - 0s - loss: 3.0179 - val_loss: 3.0270 - 82ms/epoch - 82ms/step\n",
      "Epoch 41/2000\n",
      "1/1 - 0s - loss: 3.0160 - val_loss: 3.0251 - 82ms/epoch - 82ms/step\n",
      "Epoch 42/2000\n",
      "1/1 - 0s - loss: 3.0141 - val_loss: 3.0231 - 82ms/epoch - 82ms/step\n",
      "Epoch 43/2000\n",
      "1/1 - 0s - loss: 3.0123 - val_loss: 3.0212 - 81ms/epoch - 81ms/step\n",
      "Epoch 44/2000\n",
      "1/1 - 0s - loss: 3.0104 - val_loss: 3.0193 - 84ms/epoch - 84ms/step\n",
      "Epoch 45/2000\n",
      "1/1 - 0s - loss: 3.0086 - val_loss: 3.0173 - 84ms/epoch - 84ms/step\n",
      "Epoch 46/2000\n",
      "1/1 - 0s - loss: 3.0067 - val_loss: 3.0153 - 82ms/epoch - 82ms/step\n",
      "Epoch 47/2000\n",
      "1/1 - 0s - loss: 3.0048 - val_loss: 3.0131 - 81ms/epoch - 81ms/step\n",
      "Epoch 48/2000\n",
      "1/1 - 0s - loss: 3.0028 - val_loss: 3.0109 - 82ms/epoch - 82ms/step\n",
      "Epoch 49/2000\n",
      "1/1 - 0s - loss: 3.0007 - val_loss: 3.0086 - 83ms/epoch - 83ms/step\n",
      "Epoch 50/2000\n",
      "1/1 - 0s - loss: 2.9985 - val_loss: 3.0062 - 82ms/epoch - 82ms/step\n",
      "Epoch 51/2000\n",
      "1/1 - 0s - loss: 2.9961 - val_loss: 3.0035 - 94ms/epoch - 94ms/step\n",
      "Epoch 52/2000\n",
      "1/1 - 0s - loss: 2.9936 - val_loss: 3.0007 - 94ms/epoch - 94ms/step\n",
      "Epoch 53/2000\n",
      "1/1 - 0s - loss: 2.9909 - val_loss: 2.9977 - 84ms/epoch - 84ms/step\n",
      "Epoch 54/2000\n",
      "1/1 - 0s - loss: 2.9880 - val_loss: 2.9945 - 83ms/epoch - 83ms/step\n",
      "Epoch 55/2000\n",
      "1/1 - 0s - loss: 2.9849 - val_loss: 2.9910 - 82ms/epoch - 82ms/step\n",
      "Epoch 56/2000\n",
      "1/1 - 0s - loss: 2.9815 - val_loss: 2.9872 - 81ms/epoch - 81ms/step\n",
      "Epoch 57/2000\n",
      "1/1 - 0s - loss: 2.9778 - val_loss: 2.9831 - 81ms/epoch - 81ms/step\n",
      "Epoch 58/2000\n",
      "1/1 - 0s - loss: 2.9739 - val_loss: 2.9787 - 95ms/epoch - 95ms/step\n",
      "Epoch 59/2000\n",
      "1/1 - 0s - loss: 2.9696 - val_loss: 2.9740 - 83ms/epoch - 83ms/step\n",
      "Epoch 60/2000\n",
      "1/1 - 0s - loss: 2.9651 - val_loss: 2.9690 - 82ms/epoch - 82ms/step\n",
      "Epoch 61/2000\n",
      "1/1 - 0s - loss: 2.9604 - val_loss: 2.9639 - 84ms/epoch - 84ms/step\n",
      "Epoch 62/2000\n",
      "1/1 - 0s - loss: 2.9554 - val_loss: 2.9585 - 79ms/epoch - 79ms/step\n",
      "Epoch 63/2000\n",
      "1/1 - 0s - loss: 2.9504 - val_loss: 2.9531 - 81ms/epoch - 81ms/step\n",
      "Epoch 64/2000\n",
      "1/1 - 0s - loss: 2.9454 - val_loss: 2.9476 - 86ms/epoch - 86ms/step\n",
      "Epoch 65/2000\n",
      "1/1 - 0s - loss: 2.9404 - val_loss: 2.9423 - 97ms/epoch - 97ms/step\n",
      "Epoch 66/2000\n",
      "1/1 - 0s - loss: 2.9356 - val_loss: 2.9373 - 85ms/epoch - 85ms/step\n",
      "Epoch 67/2000\n",
      "1/1 - 0s - loss: 2.9310 - val_loss: 2.9324 - 81ms/epoch - 81ms/step\n",
      "Epoch 68/2000\n",
      "1/1 - 0s - loss: 2.9264 - val_loss: 2.9272 - 88ms/epoch - 88ms/step\n",
      "Epoch 69/2000\n",
      "1/1 - 0s - loss: 2.9214 - val_loss: 2.9216 - 102ms/epoch - 102ms/step\n",
      "Epoch 70/2000\n",
      "1/1 - 0s - loss: 2.9159 - val_loss: 2.9157 - 84ms/epoch - 84ms/step\n",
      "Epoch 71/2000\n",
      "1/1 - 0s - loss: 2.9098 - val_loss: 2.9094 - 84ms/epoch - 84ms/step\n",
      "Epoch 72/2000\n",
      "1/1 - 0s - loss: 2.9034 - val_loss: 2.9032 - 83ms/epoch - 83ms/step\n",
      "Epoch 73/2000\n",
      "1/1 - 0s - loss: 2.8969 - val_loss: 2.8970 - 81ms/epoch - 81ms/step\n",
      "Epoch 74/2000\n",
      "1/1 - 0s - loss: 2.8906 - val_loss: 2.8909 - 82ms/epoch - 82ms/step\n",
      "Epoch 75/2000\n",
      "1/1 - 0s - loss: 2.8844 - val_loss: 2.8849 - 80ms/epoch - 80ms/step\n",
      "Epoch 76/2000\n",
      "1/1 - 0s - loss: 2.8784 - val_loss: 2.8789 - 80ms/epoch - 80ms/step\n",
      "Epoch 77/2000\n",
      "1/1 - 0s - loss: 2.8722 - val_loss: 2.8728 - 81ms/epoch - 81ms/step\n",
      "Epoch 78/2000\n",
      "1/1 - 0s - loss: 2.8660 - val_loss: 2.8666 - 94ms/epoch - 94ms/step\n",
      "Epoch 79/2000\n",
      "1/1 - 0s - loss: 2.8595 - val_loss: 2.8603 - 80ms/epoch - 80ms/step\n",
      "Epoch 80/2000\n",
      "1/1 - 0s - loss: 2.8529 - val_loss: 2.8540 - 82ms/epoch - 82ms/step\n",
      "Epoch 81/2000\n",
      "1/1 - 0s - loss: 2.8461 - val_loss: 2.8476 - 81ms/epoch - 81ms/step\n",
      "Epoch 82/2000\n",
      "1/1 - 0s - loss: 2.8392 - val_loss: 2.8413 - 79ms/epoch - 79ms/step\n",
      "Epoch 83/2000\n",
      "1/1 - 0s - loss: 2.8320 - val_loss: 2.8350 - 82ms/epoch - 82ms/step\n",
      "Epoch 84/2000\n",
      "1/1 - 0s - loss: 2.8250 - val_loss: 2.8293 - 93ms/epoch - 93ms/step\n",
      "Epoch 85/2000\n",
      "1/1 - 0s - loss: 2.8184 - val_loss: 2.8242 - 82ms/epoch - 82ms/step\n",
      "Epoch 86/2000\n",
      "1/1 - 0s - loss: 2.8125 - val_loss: 2.8195 - 82ms/epoch - 82ms/step\n",
      "Epoch 87/2000\n",
      "1/1 - 0s - loss: 2.8073 - val_loss: 2.8147 - 83ms/epoch - 83ms/step\n",
      "Epoch 88/2000\n",
      "1/1 - 0s - loss: 2.8021 - val_loss: 2.8095 - 82ms/epoch - 82ms/step\n",
      "Epoch 89/2000\n",
      "1/1 - 0s - loss: 2.7968 - val_loss: 2.8043 - 84ms/epoch - 84ms/step\n",
      "Epoch 90/2000\n",
      "1/1 - 0s - loss: 2.7914 - val_loss: 2.7990 - 82ms/epoch - 82ms/step\n",
      "Epoch 91/2000\n",
      "1/1 - 0s - loss: 2.7862 - val_loss: 2.7937 - 83ms/epoch - 83ms/step\n",
      "Epoch 92/2000\n",
      "1/1 - 0s - loss: 2.7811 - val_loss: 2.7884 - 82ms/epoch - 82ms/step\n",
      "Epoch 93/2000\n",
      "1/1 - 0s - loss: 2.7761 - val_loss: 2.7849 - 80ms/epoch - 80ms/step\n",
      "Epoch 94/2000\n",
      "1/1 - 0s - loss: 2.7728 - val_loss: 2.7845 - 80ms/epoch - 80ms/step\n",
      "Epoch 95/2000\n",
      "1/1 - 0s - loss: 2.7717 - val_loss: 2.7741 - 82ms/epoch - 82ms/step\n",
      "Epoch 96/2000\n",
      "1/1 - 0s - loss: 2.7623 - val_loss: 2.7692 - 81ms/epoch - 81ms/step\n",
      "Epoch 97/2000\n",
      "1/1 - 0s - loss: 2.7572 - val_loss: 2.7671 - 83ms/epoch - 83ms/step\n",
      "Epoch 98/2000\n",
      "1/1 - 0s - loss: 2.7538 - val_loss: 2.7568 - 82ms/epoch - 82ms/step\n",
      "Epoch 99/2000\n",
      "1/1 - 0s - loss: 2.7438 - val_loss: 2.7547 - 81ms/epoch - 81ms/step\n",
      "Epoch 100/2000\n",
      "1/1 - 0s - loss: 2.7416 - val_loss: 2.7477 - 80ms/epoch - 80ms/step\n",
      "Epoch 101/2000\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 - 0s - loss: 2.7334 - val_loss: 2.7417 - 81ms/epoch - 81ms/step\n",
      "Epoch 102/2000\n",
      "1/1 - 0s - loss: 2.7270 - val_loss: 2.7382 - 82ms/epoch - 82ms/step\n",
      "Epoch 103/2000\n",
      "1/1 - 0s - loss: 2.7234 - val_loss: 2.7299 - 82ms/epoch - 82ms/step\n",
      "Epoch 104/2000\n",
      "1/1 - 0s - loss: 2.7143 - val_loss: 2.7253 - 79ms/epoch - 79ms/step\n",
      "Epoch 105/2000\n",
      "1/1 - 0s - loss: 2.7090 - val_loss: 2.7208 - 82ms/epoch - 82ms/step\n",
      "Epoch 106/2000\n",
      "1/1 - 0s - loss: 2.7046 - val_loss: 2.7129 - 83ms/epoch - 83ms/step\n",
      "Epoch 107/2000\n",
      "1/1 - 0s - loss: 2.6950 - val_loss: 2.7055 - 84ms/epoch - 84ms/step\n",
      "Epoch 108/2000\n",
      "1/1 - 0s - loss: 2.6871 - val_loss: 2.7003 - 81ms/epoch - 81ms/step\n",
      "Epoch 109/2000\n",
      "1/1 - 0s - loss: 2.6823 - val_loss: 2.6959 - 81ms/epoch - 81ms/step\n",
      "Epoch 110/2000\n",
      "1/1 - 0s - loss: 2.6757 - val_loss: 2.6860 - 80ms/epoch - 80ms/step\n",
      "Epoch 111/2000\n",
      "1/1 - 0s - loss: 2.6670 - val_loss: 2.6787 - 83ms/epoch - 83ms/step\n",
      "Epoch 112/2000\n",
      "1/1 - 0s - loss: 2.6590 - val_loss: 2.6742 - 83ms/epoch - 83ms/step\n",
      "Epoch 113/2000\n",
      "1/1 - 0s - loss: 2.6534 - val_loss: 2.6673 - 81ms/epoch - 81ms/step\n",
      "Epoch 114/2000\n",
      "1/1 - 0s - loss: 2.6475 - val_loss: 2.6605 - 82ms/epoch - 82ms/step\n",
      "Epoch 115/2000\n",
      "1/1 - 0s - loss: 2.6392 - val_loss: 2.6523 - 81ms/epoch - 81ms/step\n",
      "Epoch 116/2000\n",
      "1/1 - 0s - loss: 2.6313 - val_loss: 2.6462 - 82ms/epoch - 82ms/step\n",
      "Epoch 117/2000\n",
      "1/1 - 0s - loss: 2.6252 - val_loss: 2.6419 - 87ms/epoch - 87ms/step\n",
      "Epoch 118/2000\n",
      "1/1 - 0s - loss: 2.6198 - val_loss: 2.6343 - 79ms/epoch - 79ms/step\n",
      "Epoch 119/2000\n",
      "1/1 - 0s - loss: 2.6131 - val_loss: 2.6279 - 81ms/epoch - 81ms/step\n",
      "Epoch 120/2000\n",
      "1/1 - 0s - loss: 2.6056 - val_loss: 2.6208 - 93ms/epoch - 93ms/step\n",
      "Epoch 121/2000\n",
      "1/1 - 0s - loss: 2.5988 - val_loss: 2.6146 - 94ms/epoch - 94ms/step\n",
      "Epoch 122/2000\n",
      "1/1 - 0s - loss: 2.5931 - val_loss: 2.6103 - 80ms/epoch - 80ms/step\n",
      "Epoch 123/2000\n",
      "1/1 - 0s - loss: 2.5880 - val_loss: 2.6031 - 82ms/epoch - 82ms/step\n",
      "Epoch 124/2000\n",
      "1/1 - 0s - loss: 2.5821 - val_loss: 2.5974 - 83ms/epoch - 83ms/step\n",
      "Epoch 125/2000\n",
      "1/1 - 0s - loss: 2.5754 - val_loss: 2.5909 - 80ms/epoch - 80ms/step\n",
      "Epoch 126/2000\n",
      "1/1 - 0s - loss: 2.5691 - val_loss: 2.5855 - 84ms/epoch - 84ms/step\n",
      "Epoch 127/2000\n",
      "1/1 - 0s - loss: 2.5640 - val_loss: 2.5812 - 82ms/epoch - 82ms/step\n",
      "Epoch 128/2000\n",
      "1/1 - 0s - loss: 2.5588 - val_loss: 2.5744 - 83ms/epoch - 83ms/step\n",
      "Epoch 129/2000\n",
      "1/1 - 0s - loss: 2.5528 - val_loss: 2.5688 - 82ms/epoch - 82ms/step\n",
      "Epoch 130/2000\n",
      "1/1 - 0s - loss: 2.5468 - val_loss: 2.5640 - 80ms/epoch - 80ms/step\n",
      "Epoch 131/2000\n",
      "1/1 - 0s - loss: 2.5417 - val_loss: 2.5584 - 81ms/epoch - 81ms/step\n",
      "Epoch 132/2000\n",
      "1/1 - 0s - loss: 2.5367 - val_loss: 2.5534 - 80ms/epoch - 80ms/step\n",
      "Epoch 133/2000\n",
      "1/1 - 0s - loss: 2.5311 - val_loss: 2.5483 - 81ms/epoch - 81ms/step\n",
      "Epoch 134/2000\n",
      "1/1 - 0s - loss: 2.5259 - val_loss: 2.5435 - 83ms/epoch - 83ms/step\n",
      "Epoch 135/2000\n",
      "1/1 - 0s - loss: 2.5211 - val_loss: 2.5388 - 83ms/epoch - 83ms/step\n",
      "Epoch 136/2000\n",
      "1/1 - 0s - loss: 2.5161 - val_loss: 2.5339 - 81ms/epoch - 81ms/step\n",
      "Epoch 137/2000\n",
      "1/1 - 0s - loss: 2.5111 - val_loss: 2.5292 - 78ms/epoch - 78ms/step\n",
      "Epoch 138/2000\n",
      "1/1 - 0s - loss: 2.5065 - val_loss: 2.5248 - 80ms/epoch - 80ms/step\n",
      "Epoch 139/2000\n",
      "1/1 - 0s - loss: 2.5017 - val_loss: 2.5202 - 80ms/epoch - 80ms/step\n",
      "Epoch 140/2000\n",
      "1/1 - 0s - loss: 2.4970 - val_loss: 2.5157 - 80ms/epoch - 80ms/step\n",
      "Epoch 141/2000\n",
      "1/1 - 0s - loss: 2.4925 - val_loss: 2.5115 - 81ms/epoch - 81ms/step\n",
      "Epoch 142/2000\n",
      "1/1 - 0s - loss: 2.4880 - val_loss: 2.5067 - 82ms/epoch - 82ms/step\n",
      "Epoch 143/2000\n",
      "1/1 - 0s - loss: 2.4832 - val_loss: 2.5022 - 83ms/epoch - 83ms/step\n",
      "Epoch 144/2000\n",
      "1/1 - 0s - loss: 2.4787 - val_loss: 2.4981 - 81ms/epoch - 81ms/step\n",
      "Epoch 145/2000\n",
      "1/1 - 0s - loss: 2.4743 - val_loss: 2.4935 - 82ms/epoch - 82ms/step\n",
      "Epoch 146/2000\n",
      "1/1 - 0s - loss: 2.4699 - val_loss: 2.4890 - 80ms/epoch - 80ms/step\n",
      "Epoch 147/2000\n",
      "1/1 - 0s - loss: 2.4654 - val_loss: 2.4845 - 80ms/epoch - 80ms/step\n",
      "Epoch 148/2000\n",
      "1/1 - 0s - loss: 2.4610 - val_loss: 2.4798 - 80ms/epoch - 80ms/step\n",
      "Epoch 149/2000\n",
      "1/1 - 0s - loss: 2.4566 - val_loss: 2.4753 - 79ms/epoch - 79ms/step\n",
      "Epoch 150/2000\n",
      "1/1 - 0s - loss: 2.4523 - val_loss: 2.4707 - 79ms/epoch - 79ms/step\n",
      "Epoch 151/2000\n",
      "1/1 - 0s - loss: 2.4480 - val_loss: 2.4663 - 79ms/epoch - 79ms/step\n",
      "Epoch 152/2000\n",
      "1/1 - 0s - loss: 2.4438 - val_loss: 2.4622 - 81ms/epoch - 81ms/step\n",
      "Epoch 153/2000\n",
      "1/1 - 0s - loss: 2.4397 - val_loss: 2.4580 - 82ms/epoch - 82ms/step\n",
      "Epoch 154/2000\n",
      "1/1 - 0s - loss: 2.4356 - val_loss: 2.4540 - 83ms/epoch - 83ms/step\n",
      "Epoch 155/2000\n",
      "1/1 - 0s - loss: 2.4316 - val_loss: 2.4500 - 80ms/epoch - 80ms/step\n",
      "Epoch 156/2000\n",
      "1/1 - 0s - loss: 2.4276 - val_loss: 2.4460 - 81ms/epoch - 81ms/step\n",
      "Epoch 157/2000\n",
      "1/1 - 0s - loss: 2.4237 - val_loss: 2.4422 - 80ms/epoch - 80ms/step\n",
      "Epoch 158/2000\n",
      "1/1 - 0s - loss: 2.4198 - val_loss: 2.4383 - 92ms/epoch - 92ms/step\n",
      "Epoch 159/2000\n",
      "1/1 - 0s - loss: 2.4161 - val_loss: 2.4346 - 82ms/epoch - 82ms/step\n",
      "Epoch 160/2000\n",
      "1/1 - 0s - loss: 2.4124 - val_loss: 2.4306 - 81ms/epoch - 81ms/step\n",
      "Epoch 161/2000\n",
      "1/1 - 0s - loss: 2.4087 - val_loss: 2.4269 - 81ms/epoch - 81ms/step\n",
      "Epoch 162/2000\n",
      "1/1 - 0s - loss: 2.4049 - val_loss: 2.4230 - 79ms/epoch - 79ms/step\n",
      "Epoch 163/2000\n",
      "1/1 - 0s - loss: 2.4013 - val_loss: 2.4195 - 80ms/epoch - 80ms/step\n",
      "Epoch 164/2000\n",
      "1/1 - 0s - loss: 2.3977 - val_loss: 2.4155 - 81ms/epoch - 81ms/step\n",
      "Epoch 165/2000\n",
      "1/1 - 0s - loss: 2.3940 - val_loss: 2.4121 - 81ms/epoch - 81ms/step\n",
      "Epoch 166/2000\n",
      "1/1 - 0s - loss: 2.3905 - val_loss: 2.4083 - 81ms/epoch - 81ms/step\n",
      "Epoch 167/2000\n",
      "1/1 - 0s - loss: 2.3871 - val_loss: 2.4058 - 80ms/epoch - 80ms/step\n",
      "Epoch 168/2000\n",
      "1/1 - 0s - loss: 2.3841 - val_loss: 2.4029 - 80ms/epoch - 80ms/step\n",
      "Epoch 169/2000\n",
      "1/1 - 0s - loss: 2.3819 - val_loss: 2.4026 - 93ms/epoch - 93ms/step\n",
      "Epoch 170/2000\n",
      "1/1 - 0s - loss: 2.3805 - val_loss: 2.3979 - 80ms/epoch - 80ms/step\n",
      "Epoch 171/2000\n",
      "1/1 - 0s - loss: 2.3769 - val_loss: 2.3922 - 82ms/epoch - 82ms/step\n",
      "Epoch 172/2000\n",
      "1/1 - 0s - loss: 2.3706 - val_loss: 2.3874 - 80ms/epoch - 80ms/step\n",
      "Epoch 173/2000\n",
      "1/1 - 0s - loss: 2.3659 - val_loss: 2.3857 - 82ms/epoch - 82ms/step\n",
      "Epoch 174/2000\n",
      "1/1 - 0s - loss: 2.3646 - val_loss: 2.3839 - 80ms/epoch - 80ms/step\n",
      "Epoch 175/2000\n",
      "1/1 - 0s - loss: 2.3619 - val_loss: 2.3781 - 80ms/epoch - 80ms/step\n",
      "Epoch 176/2000\n",
      "1/1 - 0s - loss: 2.3568 - val_loss: 2.3740 - 85ms/epoch - 85ms/step\n",
      "Epoch 177/2000\n",
      "1/1 - 0s - loss: 2.3527 - val_loss: 2.3728 - 82ms/epoch - 82ms/step\n",
      "Epoch 178/2000\n",
      "1/1 - 0s - loss: 2.3511 - val_loss: 2.3693 - 82ms/epoch - 82ms/step\n",
      "Epoch 179/2000\n",
      "1/1 - 0s - loss: 2.3481 - val_loss: 2.3646 - 81ms/epoch - 81ms/step\n",
      "Epoch 180/2000\n",
      "1/1 - 0s - loss: 2.3432 - val_loss: 2.3612 - 81ms/epoch - 81ms/step\n",
      "Epoch 181/2000\n",
      "1/1 - 0s - loss: 2.3400 - val_loss: 2.3588 - 80ms/epoch - 80ms/step\n",
      "Epoch 182/2000\n",
      "1/1 - 0s - loss: 2.3380 - val_loss: 2.3556 - 83ms/epoch - 83ms/step\n",
      "Epoch 183/2000\n",
      "1/1 - 0s - loss: 2.3344 - val_loss: 2.3512 - 81ms/epoch - 81ms/step\n",
      "Epoch 184/2000\n",
      "1/1 - 0s - loss: 2.3305 - val_loss: 2.3483 - 82ms/epoch - 82ms/step\n",
      "Epoch 185/2000\n",
      "1/1 - 0s - loss: 2.3276 - val_loss: 2.3461 - 90ms/epoch - 90ms/step\n",
      "Epoch 186/2000\n",
      "1/1 - 0s - loss: 2.3251 - val_loss: 2.3425 - 79ms/epoch - 79ms/step\n",
      "Epoch 187/2000\n",
      "1/1 - 0s - loss: 2.3220 - val_loss: 2.3391 - 79ms/epoch - 79ms/step\n",
      "Epoch 188/2000\n",
      "1/1 - 0s - loss: 2.3183 - val_loss: 2.3361 - 79ms/epoch - 79ms/step\n",
      "Epoch 189/2000\n",
      "1/1 - 0s - loss: 2.3154 - val_loss: 2.3335 - 81ms/epoch - 81ms/step\n",
      "Epoch 190/2000\n",
      "1/1 - 0s - loss: 2.3132 - val_loss: 2.3312 - 84ms/epoch - 84ms/step\n",
      "Epoch 191/2000\n",
      "1/1 - 0s - loss: 2.3104 - val_loss: 2.3273 - 83ms/epoch - 83ms/step\n",
      "Epoch 192/2000\n",
      "1/1 - 0s - loss: 2.3071 - val_loss: 2.3242 - 83ms/epoch - 83ms/step\n",
      "Epoch 193/2000\n",
      "1/1 - 0s - loss: 2.3039 - val_loss: 2.3214 - 98ms/epoch - 98ms/step\n",
      "Epoch 194/2000\n",
      "1/1 - 0s - loss: 2.3011 - val_loss: 2.3187 - 99ms/epoch - 99ms/step\n",
      "Epoch 195/2000\n",
      "1/1 - 0s - loss: 2.2987 - val_loss: 2.3167 - 95ms/epoch - 95ms/step\n",
      "Epoch 196/2000\n",
      "1/1 - 0s - loss: 2.2964 - val_loss: 2.3137 - 95ms/epoch - 95ms/step\n",
      "Epoch 197/2000\n",
      "1/1 - 0s - loss: 2.2937 - val_loss: 2.3109 - 90ms/epoch - 90ms/step\n",
      "Epoch 198/2000\n",
      "1/1 - 0s - loss: 2.2907 - val_loss: 2.3077 - 86ms/epoch - 86ms/step\n",
      "Epoch 199/2000\n",
      "1/1 - 0s - loss: 2.2877 - val_loss: 2.3051 - 90ms/epoch - 90ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 200/2000\n",
      "1/1 - 0s - loss: 2.2850 - val_loss: 2.3026 - 84ms/epoch - 84ms/step\n",
      "Epoch 201/2000\n",
      "1/1 - 0s - loss: 2.2825 - val_loss: 2.3002 - 83ms/epoch - 83ms/step\n",
      "Epoch 202/2000\n",
      "1/1 - 0s - loss: 2.2802 - val_loss: 2.2985 - 85ms/epoch - 85ms/step\n",
      "Epoch 203/2000\n",
      "1/1 - 0s - loss: 2.2782 - val_loss: 2.2961 - 82ms/epoch - 82ms/step\n",
      "Epoch 204/2000\n",
      "1/1 - 0s - loss: 2.2762 - val_loss: 2.2951 - 81ms/epoch - 81ms/step\n",
      "Epoch 205/2000\n",
      "1/1 - 0s - loss: 2.2745 - val_loss: 2.2935 - 79ms/epoch - 79ms/step\n",
      "Epoch 206/2000\n",
      "1/1 - 0s - loss: 2.2735 - val_loss: 2.2928 - 80ms/epoch - 80ms/step\n",
      "Epoch 207/2000\n",
      "1/1 - 0s - loss: 2.2720 - val_loss: 2.2902 - 91ms/epoch - 91ms/step\n",
      "Epoch 208/2000\n",
      "1/1 - 0s - loss: 2.2700 - val_loss: 2.2861 - 87ms/epoch - 87ms/step\n",
      "Epoch 209/2000\n",
      "1/1 - 0s - loss: 2.2653 - val_loss: 2.2819 - 82ms/epoch - 82ms/step\n",
      "Epoch 210/2000\n",
      "1/1 - 0s - loss: 2.2614 - val_loss: 2.2804 - 83ms/epoch - 83ms/step\n",
      "Epoch 211/2000\n",
      "1/1 - 0s - loss: 2.2598 - val_loss: 2.2800 - 84ms/epoch - 84ms/step\n",
      "Epoch 212/2000\n",
      "1/1 - 0s - loss: 2.2591 - val_loss: 2.2775 - 96ms/epoch - 96ms/step\n",
      "Epoch 213/2000\n",
      "1/1 - 0s - loss: 2.2570 - val_loss: 2.2743 - 86ms/epoch - 86ms/step\n",
      "Epoch 214/2000\n",
      "1/1 - 0s - loss: 2.2534 - val_loss: 2.2714 - 84ms/epoch - 84ms/step\n",
      "Epoch 215/2000\n",
      "1/1 - 0s - loss: 2.2506 - val_loss: 2.2701 - 81ms/epoch - 81ms/step\n",
      "Epoch 216/2000\n",
      "1/1 - 0s - loss: 2.2495 - val_loss: 2.2694 - 82ms/epoch - 82ms/step\n",
      "Epoch 217/2000\n",
      "1/1 - 0s - loss: 2.2482 - val_loss: 2.2668 - 94ms/epoch - 94ms/step\n",
      "Epoch 218/2000\n",
      "1/1 - 0s - loss: 2.2459 - val_loss: 2.2640 - 81ms/epoch - 81ms/step\n",
      "Epoch 219/2000\n",
      "1/1 - 0s - loss: 2.2428 - val_loss: 2.2617 - 81ms/epoch - 81ms/step\n",
      "Epoch 220/2000\n",
      "1/1 - 0s - loss: 2.2406 - val_loss: 2.2604 - 82ms/epoch - 82ms/step\n",
      "Epoch 221/2000\n",
      "1/1 - 0s - loss: 2.2393 - val_loss: 2.2592 - 82ms/epoch - 82ms/step\n",
      "Epoch 222/2000\n",
      "1/1 - 0s - loss: 2.2377 - val_loss: 2.2566 - 84ms/epoch - 84ms/step\n",
      "Epoch 223/2000\n",
      "1/1 - 0s - loss: 2.2352 - val_loss: 2.2542 - 82ms/epoch - 82ms/step\n",
      "Epoch 224/2000\n",
      "1/1 - 0s - loss: 2.2327 - val_loss: 2.2527 - 81ms/epoch - 81ms/step\n",
      "Epoch 225/2000\n",
      "1/1 - 0s - loss: 2.2311 - val_loss: 2.2512 - 81ms/epoch - 81ms/step\n",
      "Epoch 226/2000\n",
      "1/1 - 0s - loss: 2.2297 - val_loss: 2.2498 - 90ms/epoch - 90ms/step\n",
      "Epoch 227/2000\n",
      "1/1 - 0s - loss: 2.2280 - val_loss: 2.2473 - 82ms/epoch - 82ms/step\n",
      "Epoch 228/2000\n",
      "1/1 - 0s - loss: 2.2257 - val_loss: 2.2453 - 82ms/epoch - 82ms/step\n",
      "Epoch 229/2000\n",
      "1/1 - 0s - loss: 2.2235 - val_loss: 2.2435 - 82ms/epoch - 82ms/step\n",
      "Epoch 230/2000\n",
      "1/1 - 0s - loss: 2.2217 - val_loss: 2.2419 - 80ms/epoch - 80ms/step\n",
      "Epoch 231/2000\n",
      "1/1 - 0s - loss: 2.2203 - val_loss: 2.2406 - 80ms/epoch - 80ms/step\n",
      "Epoch 232/2000\n",
      "1/1 - 0s - loss: 2.2187 - val_loss: 2.2385 - 80ms/epoch - 80ms/step\n",
      "Epoch 233/2000\n",
      "1/1 - 0s - loss: 2.2168 - val_loss: 2.2367 - 88ms/epoch - 88ms/step\n",
      "Epoch 234/2000\n",
      "1/1 - 0s - loss: 2.2148 - val_loss: 2.2348 - 90ms/epoch - 90ms/step\n",
      "Epoch 235/2000\n",
      "1/1 - 0s - loss: 2.2129 - val_loss: 2.2332 - 94ms/epoch - 94ms/step\n",
      "Epoch 236/2000\n",
      "1/1 - 0s - loss: 2.2112 - val_loss: 2.2319 - 90ms/epoch - 90ms/step\n",
      "Epoch 237/2000\n",
      "1/1 - 0s - loss: 2.2097 - val_loss: 2.2303 - 88ms/epoch - 88ms/step\n",
      "Epoch 238/2000\n",
      "1/1 - 0s - loss: 2.2082 - val_loss: 2.2290 - 93ms/epoch - 93ms/step\n",
      "Epoch 239/2000\n",
      "1/1 - 0s - loss: 2.2067 - val_loss: 2.2272 - 92ms/epoch - 92ms/step\n",
      "Epoch 240/2000\n",
      "1/1 - 0s - loss: 2.2051 - val_loss: 2.2258 - 92ms/epoch - 92ms/step\n",
      "Epoch 241/2000\n",
      "1/1 - 0s - loss: 2.2034 - val_loss: 2.2239 - 129ms/epoch - 129ms/step\n",
      "Epoch 242/2000\n",
      "1/1 - 0s - loss: 2.2016 - val_loss: 2.2224 - 89ms/epoch - 89ms/step\n",
      "Epoch 243/2000\n",
      "1/1 - 0s - loss: 2.1999 - val_loss: 2.2206 - 94ms/epoch - 94ms/step\n",
      "Epoch 244/2000\n",
      "1/1 - 0s - loss: 2.1981 - val_loss: 2.2191 - 102ms/epoch - 102ms/step\n",
      "Epoch 245/2000\n",
      "1/1 - 0s - loss: 2.1965 - val_loss: 2.2176 - 88ms/epoch - 88ms/step\n",
      "Epoch 246/2000\n",
      "1/1 - 0s - loss: 2.1949 - val_loss: 2.2161 - 84ms/epoch - 84ms/step\n",
      "Epoch 247/2000\n",
      "1/1 - 0s - loss: 2.1933 - val_loss: 2.2146 - 89ms/epoch - 89ms/step\n",
      "Epoch 248/2000\n",
      "1/1 - 0s - loss: 2.1918 - val_loss: 2.2131 - 83ms/epoch - 83ms/step\n",
      "Epoch 249/2000\n",
      "1/1 - 0s - loss: 2.1902 - val_loss: 2.2117 - 82ms/epoch - 82ms/step\n",
      "Epoch 250/2000\n",
      "1/1 - 0s - loss: 2.1887 - val_loss: 2.2103 - 86ms/epoch - 86ms/step\n",
      "Epoch 251/2000\n",
      "1/1 - 0s - loss: 2.1873 - val_loss: 2.2092 - 80ms/epoch - 80ms/step\n",
      "Epoch 252/2000\n",
      "1/1 - 0s - loss: 2.1859 - val_loss: 2.2080 - 87ms/epoch - 87ms/step\n",
      "Epoch 253/2000\n",
      "1/1 - 0s - loss: 2.1848 - val_loss: 2.2080 - 90ms/epoch - 90ms/step\n",
      "Epoch 254/2000\n",
      "1/1 - 0s - loss: 2.1844 - val_loss: 2.2090 - 83ms/epoch - 83ms/step\n",
      "Epoch 255/2000\n",
      "1/1 - 0s - loss: 2.1858 - val_loss: 2.2147 - 82ms/epoch - 82ms/step\n",
      "Epoch 256/2000\n",
      "1/1 - 0s - loss: 2.1904 - val_loss: 2.2190 - 88ms/epoch - 88ms/step\n",
      "Epoch 257/2000\n",
      "1/1 - 0s - loss: 2.1956 - val_loss: 2.2145 - 80ms/epoch - 80ms/step\n",
      "Epoch 258/2000\n",
      "1/1 - 0s - loss: 2.1900 - val_loss: 2.2011 - 88ms/epoch - 88ms/step\n",
      "Epoch 259/2000\n",
      "1/1 - 0s - loss: 2.1771 - val_loss: 2.2074 - 78ms/epoch - 78ms/step\n",
      "Epoch 260/2000\n",
      "1/1 - 0s - loss: 2.1833 - val_loss: 2.2026 - 81ms/epoch - 81ms/step\n",
      "Epoch 261/2000\n",
      "1/1 - 0s - loss: 2.1781 - val_loss: 2.1996 - 88ms/epoch - 88ms/step\n",
      "Epoch 262/2000\n",
      "1/1 - 0s - loss: 2.1748 - val_loss: 2.2014 - 90ms/epoch - 90ms/step\n",
      "Epoch 263/2000\n",
      "1/1 - 0s - loss: 2.1769 - val_loss: 2.1947 - 98ms/epoch - 98ms/step\n",
      "Epoch 264/2000\n",
      "1/1 - 0s - loss: 2.1700 - val_loss: 2.1983 - 79ms/epoch - 79ms/step\n",
      "Epoch 265/2000\n",
      "1/1 - 0s - loss: 2.1728 - val_loss: 2.1931 - 79ms/epoch - 79ms/step\n",
      "Epoch 266/2000\n",
      "1/1 - 0s - loss: 2.1681 - val_loss: 2.1927 - 79ms/epoch - 79ms/step\n",
      "Epoch 267/2000\n",
      "1/1 - 0s - loss: 2.1679 - val_loss: 2.1920 - 80ms/epoch - 80ms/step\n",
      "Epoch 268/2000\n",
      "1/1 - 0s - loss: 2.1668 - val_loss: 2.1887 - 81ms/epoch - 81ms/step\n",
      "Epoch 269/2000\n",
      "1/1 - 0s - loss: 2.1636 - val_loss: 2.1889 - 74ms/epoch - 74ms/step\n",
      "Epoch 270/2000\n",
      "1/1 - 0s - loss: 2.1640 - val_loss: 2.1857 - 82ms/epoch - 82ms/step\n",
      "Epoch 271/2000\n",
      "1/1 - 0s - loss: 2.1607 - val_loss: 2.1863 - 74ms/epoch - 74ms/step\n",
      "Epoch 272/2000\n",
      "1/1 - 0s - loss: 2.1609 - val_loss: 2.1837 - 80ms/epoch - 80ms/step\n",
      "Epoch 273/2000\n",
      "1/1 - 0s - loss: 2.1583 - val_loss: 2.1830 - 83ms/epoch - 83ms/step\n",
      "Epoch 274/2000\n",
      "1/1 - 0s - loss: 2.1578 - val_loss: 2.1815 - 100ms/epoch - 100ms/step\n",
      "Epoch 275/2000\n",
      "1/1 - 0s - loss: 2.1559 - val_loss: 2.1807 - 83ms/epoch - 83ms/step\n",
      "Epoch 276/2000\n",
      "1/1 - 0s - loss: 2.1549 - val_loss: 2.1790 - 79ms/epoch - 79ms/step\n",
      "Epoch 277/2000\n",
      "1/1 - 0s - loss: 2.1535 - val_loss: 2.1776 - 81ms/epoch - 81ms/step\n",
      "Epoch 278/2000\n",
      "1/1 - 0s - loss: 2.1522 - val_loss: 2.1767 - 84ms/epoch - 84ms/step\n",
      "Epoch 279/2000\n",
      "1/1 - 0s - loss: 2.1509 - val_loss: 2.1755 - 80ms/epoch - 80ms/step\n",
      "Epoch 280/2000\n",
      "1/1 - 0s - loss: 2.1497 - val_loss: 2.1740 - 80ms/epoch - 80ms/step\n",
      "Epoch 281/2000\n",
      "1/1 - 0s - loss: 2.1484 - val_loss: 2.1730 - 79ms/epoch - 79ms/step\n",
      "Epoch 282/2000\n",
      "1/1 - 0s - loss: 2.1473 - val_loss: 2.1718 - 81ms/epoch - 81ms/step\n",
      "Epoch 283/2000\n",
      "1/1 - 0s - loss: 2.1459 - val_loss: 2.1708 - 81ms/epoch - 81ms/step\n",
      "Epoch 284/2000\n",
      "1/1 - 0s - loss: 2.1449 - val_loss: 2.1694 - 84ms/epoch - 84ms/step\n",
      "Epoch 285/2000\n",
      "1/1 - 0s - loss: 2.1435 - val_loss: 2.1685 - 79ms/epoch - 79ms/step\n",
      "Epoch 286/2000\n",
      "1/1 - 0s - loss: 2.1425 - val_loss: 2.1674 - 83ms/epoch - 83ms/step\n",
      "Epoch 287/2000\n",
      "1/1 - 0s - loss: 2.1413 - val_loss: 2.1662 - 82ms/epoch - 82ms/step\n",
      "Epoch 288/2000\n",
      "1/1 - 0s - loss: 2.1400 - val_loss: 2.1651 - 84ms/epoch - 84ms/step\n",
      "Epoch 289/2000\n",
      "1/1 - 0s - loss: 2.1390 - val_loss: 2.1639 - 81ms/epoch - 81ms/step\n",
      "Epoch 290/2000\n",
      "1/1 - 0s - loss: 2.1377 - val_loss: 2.1630 - 81ms/epoch - 81ms/step\n",
      "Epoch 291/2000\n",
      "1/1 - 0s - loss: 2.1368 - val_loss: 2.1616 - 89ms/epoch - 89ms/step\n",
      "Epoch 292/2000\n",
      "1/1 - 0s - loss: 2.1355 - val_loss: 2.1608 - 82ms/epoch - 82ms/step\n",
      "Epoch 293/2000\n",
      "1/1 - 0s - loss: 2.1346 - val_loss: 2.1596 - 81ms/epoch - 81ms/step\n",
      "Epoch 294/2000\n",
      "1/1 - 0s - loss: 2.1333 - val_loss: 2.1587 - 83ms/epoch - 83ms/step\n",
      "Epoch 295/2000\n",
      "1/1 - 0s - loss: 2.1324 - val_loss: 2.1576 - 82ms/epoch - 82ms/step\n",
      "Epoch 296/2000\n",
      "1/1 - 0s - loss: 2.1312 - val_loss: 2.1566 - 84ms/epoch - 84ms/step\n",
      "Epoch 297/2000\n",
      "1/1 - 0s - loss: 2.1302 - val_loss: 2.1556 - 82ms/epoch - 82ms/step\n",
      "Epoch 298/2000\n",
      "1/1 - 0s - loss: 2.1292 - val_loss: 2.1546 - 82ms/epoch - 82ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 299/2000\n",
      "1/1 - 0s - loss: 2.1281 - val_loss: 2.1536 - 82ms/epoch - 82ms/step\n",
      "Epoch 300/2000\n",
      "1/1 - 0s - loss: 2.1271 - val_loss: 2.1525 - 91ms/epoch - 91ms/step\n",
      "Epoch 301/2000\n",
      "1/1 - 0s - loss: 2.1260 - val_loss: 2.1517 - 80ms/epoch - 80ms/step\n",
      "Epoch 302/2000\n",
      "1/1 - 0s - loss: 2.1250 - val_loss: 2.1507 - 81ms/epoch - 81ms/step\n",
      "Epoch 303/2000\n",
      "1/1 - 0s - loss: 2.1240 - val_loss: 2.1497 - 83ms/epoch - 83ms/step\n",
      "Epoch 304/2000\n",
      "1/1 - 0s - loss: 2.1230 - val_loss: 2.1488 - 80ms/epoch - 80ms/step\n",
      "Epoch 305/2000\n",
      "1/1 - 0s - loss: 2.1220 - val_loss: 2.1479 - 82ms/epoch - 82ms/step\n",
      "Epoch 306/2000\n",
      "1/1 - 0s - loss: 2.1210 - val_loss: 2.1469 - 82ms/epoch - 82ms/step\n",
      "Epoch 307/2000\n",
      "1/1 - 0s - loss: 2.1200 - val_loss: 2.1460 - 79ms/epoch - 79ms/step\n",
      "Epoch 308/2000\n",
      "1/1 - 0s - loss: 2.1191 - val_loss: 2.1451 - 81ms/epoch - 81ms/step\n",
      "Epoch 309/2000\n",
      "1/1 - 0s - loss: 2.1181 - val_loss: 2.1442 - 81ms/epoch - 81ms/step\n",
      "Epoch 310/2000\n",
      "1/1 - 0s - loss: 2.1171 - val_loss: 2.1432 - 81ms/epoch - 81ms/step\n",
      "Epoch 311/2000\n",
      "1/1 - 0s - loss: 2.1162 - val_loss: 2.1423 - 91ms/epoch - 91ms/step\n",
      "Epoch 312/2000\n",
      "1/1 - 0s - loss: 2.1152 - val_loss: 2.1415 - 103ms/epoch - 103ms/step\n",
      "Epoch 313/2000\n",
      "1/1 - 0s - loss: 2.1143 - val_loss: 2.1405 - 86ms/epoch - 86ms/step\n",
      "Epoch 314/2000\n",
      "1/1 - 0s - loss: 2.1133 - val_loss: 2.1396 - 91ms/epoch - 91ms/step\n",
      "Epoch 315/2000\n",
      "1/1 - 0s - loss: 2.1124 - val_loss: 2.1388 - 88ms/epoch - 88ms/step\n",
      "Epoch 316/2000\n",
      "1/1 - 0s - loss: 2.1115 - val_loss: 2.1380 - 86ms/epoch - 86ms/step\n",
      "Epoch 317/2000\n",
      "1/1 - 0s - loss: 2.1106 - val_loss: 2.1370 - 116ms/epoch - 116ms/step\n",
      "Epoch 318/2000\n",
      "1/1 - 0s - loss: 2.1097 - val_loss: 2.1362 - 91ms/epoch - 91ms/step\n",
      "Epoch 319/2000\n",
      "1/1 - 0s - loss: 2.1087 - val_loss: 2.1354 - 96ms/epoch - 96ms/step\n",
      "Epoch 320/2000\n",
      "1/1 - 0s - loss: 2.1079 - val_loss: 2.1345 - 99ms/epoch - 99ms/step\n",
      "Epoch 321/2000\n",
      "1/1 - 0s - loss: 2.1070 - val_loss: 2.1337 - 88ms/epoch - 88ms/step\n",
      "Epoch 322/2000\n",
      "1/1 - 0s - loss: 2.1061 - val_loss: 2.1329 - 87ms/epoch - 87ms/step\n",
      "Epoch 323/2000\n",
      "1/1 - 0s - loss: 2.1052 - val_loss: 2.1321 - 88ms/epoch - 88ms/step\n",
      "Epoch 324/2000\n",
      "1/1 - 0s - loss: 2.1043 - val_loss: 2.1313 - 89ms/epoch - 89ms/step\n",
      "Epoch 325/2000\n",
      "1/1 - 0s - loss: 2.1035 - val_loss: 2.1305 - 92ms/epoch - 92ms/step\n",
      "Epoch 326/2000\n",
      "1/1 - 0s - loss: 2.1026 - val_loss: 2.1297 - 86ms/epoch - 86ms/step\n",
      "Epoch 327/2000\n",
      "1/1 - 0s - loss: 2.1018 - val_loss: 2.1288 - 90ms/epoch - 90ms/step\n",
      "Epoch 328/2000\n",
      "1/1 - 0s - loss: 2.1009 - val_loss: 2.1280 - 91ms/epoch - 91ms/step\n",
      "Epoch 329/2000\n",
      "1/1 - 0s - loss: 2.1001 - val_loss: 2.1273 - 87ms/epoch - 87ms/step\n",
      "Epoch 330/2000\n",
      "1/1 - 0s - loss: 2.0992 - val_loss: 2.1265 - 102ms/epoch - 102ms/step\n",
      "Epoch 331/2000\n",
      "1/1 - 0s - loss: 2.0984 - val_loss: 2.1257 - 134ms/epoch - 134ms/step\n",
      "Epoch 332/2000\n",
      "1/1 - 0s - loss: 2.0976 - val_loss: 2.1250 - 103ms/epoch - 103ms/step\n",
      "Epoch 333/2000\n",
      "1/1 - 0s - loss: 2.0968 - val_loss: 2.1243 - 88ms/epoch - 88ms/step\n",
      "Epoch 334/2000\n",
      "1/1 - 0s - loss: 2.0959 - val_loss: 2.1235 - 95ms/epoch - 95ms/step\n",
      "Epoch 335/2000\n",
      "1/1 - 0s - loss: 2.0951 - val_loss: 2.1227 - 97ms/epoch - 97ms/step\n",
      "Epoch 336/2000\n",
      "1/1 - 0s - loss: 2.0943 - val_loss: 2.1220 - 90ms/epoch - 90ms/step\n",
      "Epoch 337/2000\n",
      "1/1 - 0s - loss: 2.0935 - val_loss: 2.1213 - 88ms/epoch - 88ms/step\n",
      "Epoch 338/2000\n",
      "1/1 - 0s - loss: 2.0927 - val_loss: 2.1205 - 92ms/epoch - 92ms/step\n",
      "Epoch 339/2000\n",
      "1/1 - 0s - loss: 2.0920 - val_loss: 2.1199 - 86ms/epoch - 86ms/step\n",
      "Epoch 340/2000\n",
      "1/1 - 0s - loss: 2.0912 - val_loss: 2.1191 - 95ms/epoch - 95ms/step\n",
      "Epoch 341/2000\n",
      "1/1 - 0s - loss: 2.0904 - val_loss: 2.1184 - 91ms/epoch - 91ms/step\n",
      "Epoch 342/2000\n",
      "1/1 - 0s - loss: 2.0896 - val_loss: 2.1177 - 88ms/epoch - 88ms/step\n",
      "Epoch 343/2000\n",
      "1/1 - 0s - loss: 2.0889 - val_loss: 2.1170 - 92ms/epoch - 92ms/step\n",
      "Epoch 344/2000\n",
      "1/1 - 0s - loss: 2.0881 - val_loss: 2.1163 - 80ms/epoch - 80ms/step\n",
      "Epoch 345/2000\n",
      "1/1 - 0s - loss: 2.0874 - val_loss: 2.1156 - 81ms/epoch - 81ms/step\n",
      "Epoch 346/2000\n",
      "1/1 - 0s - loss: 2.0866 - val_loss: 2.1149 - 81ms/epoch - 81ms/step\n",
      "Epoch 347/2000\n",
      "1/1 - 0s - loss: 2.0859 - val_loss: 2.1143 - 79ms/epoch - 79ms/step\n",
      "Epoch 348/2000\n",
      "1/1 - 0s - loss: 2.0851 - val_loss: 2.1136 - 81ms/epoch - 81ms/step\n",
      "Epoch 349/2000\n",
      "1/1 - 0s - loss: 2.0844 - val_loss: 2.1130 - 78ms/epoch - 78ms/step\n",
      "Epoch 350/2000\n",
      "1/1 - 0s - loss: 2.0837 - val_loss: 2.1124 - 80ms/epoch - 80ms/step\n",
      "Epoch 351/2000\n",
      "1/1 - 0s - loss: 2.0830 - val_loss: 2.1117 - 79ms/epoch - 79ms/step\n",
      "Epoch 352/2000\n",
      "1/1 - 0s - loss: 2.0823 - val_loss: 2.1111 - 79ms/epoch - 79ms/step\n",
      "Epoch 353/2000\n",
      "1/1 - 0s - loss: 2.0815 - val_loss: 2.1105 - 95ms/epoch - 95ms/step\n",
      "Epoch 354/2000\n",
      "1/1 - 0s - loss: 2.0808 - val_loss: 2.1098 - 81ms/epoch - 81ms/step\n",
      "Epoch 355/2000\n",
      "1/1 - 0s - loss: 2.0801 - val_loss: 2.1092 - 80ms/epoch - 80ms/step\n",
      "Epoch 356/2000\n",
      "1/1 - 0s - loss: 2.0794 - val_loss: 2.1086 - 85ms/epoch - 85ms/step\n",
      "Epoch 357/2000\n",
      "1/1 - 0s - loss: 2.0787 - val_loss: 2.1079 - 81ms/epoch - 81ms/step\n",
      "Epoch 358/2000\n",
      "1/1 - 0s - loss: 2.0781 - val_loss: 2.1073 - 81ms/epoch - 81ms/step\n",
      "Epoch 359/2000\n",
      "1/1 - 0s - loss: 2.0774 - val_loss: 2.1067 - 80ms/epoch - 80ms/step\n",
      "Epoch 360/2000\n",
      "1/1 - 0s - loss: 2.0767 - val_loss: 2.1061 - 79ms/epoch - 79ms/step\n",
      "Epoch 361/2000\n",
      "1/1 - 0s - loss: 2.0760 - val_loss: 2.1055 - 86ms/epoch - 86ms/step\n",
      "Epoch 362/2000\n",
      "1/1 - 0s - loss: 2.0754 - val_loss: 2.1050 - 80ms/epoch - 80ms/step\n",
      "Epoch 363/2000\n",
      "1/1 - 0s - loss: 2.0747 - val_loss: 2.1043 - 79ms/epoch - 79ms/step\n",
      "Epoch 364/2000\n",
      "1/1 - 0s - loss: 2.0740 - val_loss: 2.1038 - 80ms/epoch - 80ms/step\n",
      "Epoch 365/2000\n",
      "1/1 - 0s - loss: 2.0734 - val_loss: 2.1032 - 80ms/epoch - 80ms/step\n",
      "Epoch 366/2000\n",
      "1/1 - 0s - loss: 2.0727 - val_loss: 2.1026 - 79ms/epoch - 79ms/step\n",
      "Epoch 367/2000\n",
      "1/1 - 0s - loss: 2.0721 - val_loss: 2.1021 - 81ms/epoch - 81ms/step\n",
      "Epoch 368/2000\n",
      "1/1 - 0s - loss: 2.0714 - val_loss: 2.1015 - 79ms/epoch - 79ms/step\n",
      "Epoch 369/2000\n",
      "1/1 - 0s - loss: 2.0708 - val_loss: 2.1010 - 80ms/epoch - 80ms/step\n",
      "Epoch 370/2000\n",
      "1/1 - 0s - loss: 2.0702 - val_loss: 2.1004 - 80ms/epoch - 80ms/step\n",
      "Epoch 371/2000\n",
      "1/1 - 0s - loss: 2.0696 - val_loss: 2.0999 - 81ms/epoch - 81ms/step\n",
      "Epoch 372/2000\n",
      "1/1 - 0s - loss: 2.0689 - val_loss: 2.0994 - 80ms/epoch - 80ms/step\n",
      "Epoch 373/2000\n",
      "1/1 - 0s - loss: 2.0683 - val_loss: 2.0988 - 80ms/epoch - 80ms/step\n",
      "Epoch 374/2000\n",
      "1/1 - 0s - loss: 2.0677 - val_loss: 2.0983 - 82ms/epoch - 82ms/step\n",
      "Epoch 375/2000\n",
      "1/1 - 0s - loss: 2.0671 - val_loss: 2.0978 - 80ms/epoch - 80ms/step\n",
      "Epoch 376/2000\n",
      "1/1 - 0s - loss: 2.0665 - val_loss: 2.0973 - 81ms/epoch - 81ms/step\n",
      "Epoch 377/2000\n",
      "1/1 - 0s - loss: 2.0659 - val_loss: 2.0968 - 93ms/epoch - 93ms/step\n",
      "Epoch 378/2000\n",
      "1/1 - 0s - loss: 2.0653 - val_loss: 2.0963 - 118ms/epoch - 118ms/step\n",
      "Epoch 379/2000\n",
      "1/1 - 0s - loss: 2.0647 - val_loss: 2.0957 - 90ms/epoch - 90ms/step\n",
      "Epoch 380/2000\n",
      "1/1 - 0s - loss: 2.0641 - val_loss: 2.0953 - 91ms/epoch - 91ms/step\n",
      "Epoch 381/2000\n",
      "1/1 - 0s - loss: 2.0635 - val_loss: 2.0948 - 88ms/epoch - 88ms/step\n",
      "Epoch 382/2000\n",
      "1/1 - 0s - loss: 2.0629 - val_loss: 2.0943 - 91ms/epoch - 91ms/step\n",
      "Epoch 383/2000\n",
      "1/1 - 0s - loss: 2.0623 - val_loss: 2.0938 - 88ms/epoch - 88ms/step\n",
      "Epoch 384/2000\n",
      "1/1 - 0s - loss: 2.0618 - val_loss: 2.0933 - 89ms/epoch - 89ms/step\n",
      "Epoch 385/2000\n",
      "1/1 - 0s - loss: 2.0612 - val_loss: 2.0928 - 90ms/epoch - 90ms/step\n",
      "Epoch 386/2000\n",
      "1/1 - 0s - loss: 2.0606 - val_loss: 2.0923 - 87ms/epoch - 87ms/step\n",
      "Epoch 387/2000\n",
      "1/1 - 0s - loss: 2.0601 - val_loss: 2.0918 - 87ms/epoch - 87ms/step\n",
      "Epoch 388/2000\n",
      "1/1 - 0s - loss: 2.0595 - val_loss: 2.0914 - 92ms/epoch - 92ms/step\n",
      "Epoch 389/2000\n",
      "1/1 - 0s - loss: 2.0590 - val_loss: 2.0909 - 87ms/epoch - 87ms/step\n",
      "Epoch 390/2000\n",
      "1/1 - 0s - loss: 2.0584 - val_loss: 2.0904 - 99ms/epoch - 99ms/step\n",
      "Epoch 391/2000\n",
      "1/1 - 0s - loss: 2.0579 - val_loss: 2.0900 - 92ms/epoch - 92ms/step\n",
      "Epoch 392/2000\n",
      "1/1 - 0s - loss: 2.0573 - val_loss: 2.0896 - 92ms/epoch - 92ms/step\n",
      "Epoch 393/2000\n",
      "1/1 - 0s - loss: 2.0568 - val_loss: 2.0891 - 90ms/epoch - 90ms/step\n",
      "Epoch 394/2000\n",
      "1/1 - 0s - loss: 2.0563 - val_loss: 2.0886 - 89ms/epoch - 89ms/step\n",
      "Epoch 395/2000\n",
      "1/1 - 0s - loss: 2.0557 - val_loss: 2.0882 - 98ms/epoch - 98ms/step\n",
      "Epoch 396/2000\n",
      "1/1 - 0s - loss: 2.0552 - val_loss: 2.0877 - 98ms/epoch - 98ms/step\n",
      "Epoch 397/2000\n",
      "1/1 - 0s - loss: 2.0547 - val_loss: 2.0873 - 86ms/epoch - 86ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 398/2000\n",
      "1/1 - 0s - loss: 2.0542 - val_loss: 2.0869 - 85ms/epoch - 85ms/step\n",
      "Epoch 399/2000\n",
      "1/1 - 0s - loss: 2.0537 - val_loss: 2.0864 - 86ms/epoch - 86ms/step\n",
      "Epoch 400/2000\n",
      "1/1 - 0s - loss: 2.0532 - val_loss: 2.0859 - 85ms/epoch - 85ms/step\n",
      "Epoch 401/2000\n",
      "1/1 - 0s - loss: 2.0526 - val_loss: 2.0855 - 89ms/epoch - 89ms/step\n",
      "Epoch 402/2000\n",
      "1/1 - 0s - loss: 2.0521 - val_loss: 2.0851 - 88ms/epoch - 88ms/step\n",
      "Epoch 403/2000\n",
      "1/1 - 0s - loss: 2.0516 - val_loss: 2.0847 - 87ms/epoch - 87ms/step\n",
      "Epoch 404/2000\n",
      "1/1 - 0s - loss: 2.0512 - val_loss: 2.0842 - 99ms/epoch - 99ms/step\n",
      "Epoch 405/2000\n",
      "1/1 - 0s - loss: 2.0507 - val_loss: 2.0838 - 111ms/epoch - 111ms/step\n",
      "Epoch 406/2000\n",
      "1/1 - 0s - loss: 2.0502 - val_loss: 2.0834 - 88ms/epoch - 88ms/step\n",
      "Epoch 407/2000\n",
      "1/1 - 0s - loss: 2.0497 - val_loss: 2.0830 - 112ms/epoch - 112ms/step\n",
      "Epoch 408/2000\n",
      "1/1 - 0s - loss: 2.0492 - val_loss: 2.0826 - 88ms/epoch - 88ms/step\n",
      "Epoch 409/2000\n",
      "1/1 - 0s - loss: 2.0487 - val_loss: 2.0822 - 88ms/epoch - 88ms/step\n",
      "Epoch 410/2000\n",
      "1/1 - 0s - loss: 2.0483 - val_loss: 2.0818 - 100ms/epoch - 100ms/step\n",
      "Epoch 411/2000\n",
      "1/1 - 0s - loss: 2.0478 - val_loss: 2.0815 - 88ms/epoch - 88ms/step\n",
      "Epoch 412/2000\n",
      "1/1 - 0s - loss: 2.0473 - val_loss: 2.0811 - 88ms/epoch - 88ms/step\n",
      "Epoch 413/2000\n",
      "1/1 - 0s - loss: 2.0469 - val_loss: 2.0807 - 128ms/epoch - 128ms/step\n",
      "Epoch 414/2000\n",
      "1/1 - 0s - loss: 2.0464 - val_loss: 2.0803 - 87ms/epoch - 87ms/step\n",
      "Epoch 415/2000\n",
      "1/1 - 0s - loss: 2.0459 - val_loss: 2.0800 - 103ms/epoch - 103ms/step\n",
      "Epoch 416/2000\n",
      "1/1 - 0s - loss: 2.0455 - val_loss: 2.0796 - 88ms/epoch - 88ms/step\n",
      "Epoch 417/2000\n",
      "1/1 - 0s - loss: 2.0450 - val_loss: 2.0792 - 89ms/epoch - 89ms/step\n",
      "Epoch 418/2000\n",
      "1/1 - 0s - loss: 2.0446 - val_loss: 2.0789 - 90ms/epoch - 90ms/step\n",
      "Epoch 419/2000\n",
      "1/1 - 0s - loss: 2.0441 - val_loss: 2.0785 - 88ms/epoch - 88ms/step\n",
      "Epoch 420/2000\n",
      "1/1 - 0s - loss: 2.0437 - val_loss: 2.0781 - 89ms/epoch - 89ms/step\n",
      "Epoch 421/2000\n",
      "1/1 - 0s - loss: 2.0433 - val_loss: 2.0778 - 87ms/epoch - 87ms/step\n",
      "Epoch 422/2000\n",
      "1/1 - 0s - loss: 2.0428 - val_loss: 2.0775 - 88ms/epoch - 88ms/step\n",
      "Epoch 423/2000\n",
      "1/1 - 0s - loss: 2.0424 - val_loss: 2.0771 - 96ms/epoch - 96ms/step\n",
      "Epoch 424/2000\n",
      "1/1 - 0s - loss: 2.0420 - val_loss: 2.0768 - 104ms/epoch - 104ms/step\n",
      "Epoch 425/2000\n",
      "1/1 - 0s - loss: 2.0415 - val_loss: 2.0764 - 106ms/epoch - 106ms/step\n",
      "Epoch 426/2000\n",
      "1/1 - 0s - loss: 2.0411 - val_loss: 2.0762 - 86ms/epoch - 86ms/step\n",
      "Epoch 427/2000\n",
      "1/1 - 0s - loss: 2.0407 - val_loss: 2.0758 - 86ms/epoch - 86ms/step\n",
      "Epoch 428/2000\n",
      "1/1 - 0s - loss: 2.0403 - val_loss: 2.0755 - 94ms/epoch - 94ms/step\n",
      "Epoch 429/2000\n",
      "1/1 - 0s - loss: 2.0399 - val_loss: 2.0752 - 89ms/epoch - 89ms/step\n",
      "Epoch 430/2000\n",
      "1/1 - 0s - loss: 2.0395 - val_loss: 2.0748 - 89ms/epoch - 89ms/step\n",
      "Epoch 431/2000\n",
      "1/1 - 0s - loss: 2.0390 - val_loss: 2.0745 - 108ms/epoch - 108ms/step\n",
      "Epoch 432/2000\n",
      "1/1 - 0s - loss: 2.0386 - val_loss: 2.0742 - 102ms/epoch - 102ms/step\n",
      "Epoch 433/2000\n",
      "1/1 - 0s - loss: 2.0382 - val_loss: 2.0738 - 98ms/epoch - 98ms/step\n",
      "Epoch 434/2000\n",
      "1/1 - 0s - loss: 2.0378 - val_loss: 2.0736 - 136ms/epoch - 136ms/step\n",
      "Epoch 435/2000\n",
      "1/1 - 0s - loss: 2.0374 - val_loss: 2.0732 - 97ms/epoch - 97ms/step\n",
      "Epoch 436/2000\n",
      "1/1 - 0s - loss: 2.0371 - val_loss: 2.0729 - 99ms/epoch - 99ms/step\n",
      "Epoch 437/2000\n",
      "1/1 - 0s - loss: 2.0367 - val_loss: 2.0726 - 96ms/epoch - 96ms/step\n",
      "Epoch 438/2000\n",
      "1/1 - 0s - loss: 2.0363 - val_loss: 2.0723 - 96ms/epoch - 96ms/step\n",
      "Epoch 439/2000\n",
      "1/1 - 0s - loss: 2.0359 - val_loss: 2.0721 - 103ms/epoch - 103ms/step\n",
      "Epoch 440/2000\n",
      "1/1 - 0s - loss: 2.0355 - val_loss: 2.0717 - 97ms/epoch - 97ms/step\n",
      "Epoch 441/2000\n",
      "1/1 - 0s - loss: 2.0351 - val_loss: 2.0715 - 102ms/epoch - 102ms/step\n",
      "Epoch 442/2000\n",
      "1/1 - 0s - loss: 2.0348 - val_loss: 2.0712 - 102ms/epoch - 102ms/step\n",
      "Epoch 443/2000\n",
      "1/1 - 0s - loss: 2.0344 - val_loss: 2.0708 - 98ms/epoch - 98ms/step\n",
      "Epoch 444/2000\n",
      "1/1 - 0s - loss: 2.0340 - val_loss: 2.0706 - 99ms/epoch - 99ms/step\n",
      "Epoch 445/2000\n",
      "1/1 - 0s - loss: 2.0336 - val_loss: 2.0703 - 108ms/epoch - 108ms/step\n",
      "Epoch 446/2000\n",
      "1/1 - 0s - loss: 2.0333 - val_loss: 2.0700 - 143ms/epoch - 143ms/step\n",
      "Epoch 447/2000\n",
      "1/1 - 0s - loss: 2.0329 - val_loss: 2.0698 - 100ms/epoch - 100ms/step\n",
      "Epoch 448/2000\n",
      "1/1 - 0s - loss: 2.0326 - val_loss: 2.0694 - 99ms/epoch - 99ms/step\n",
      "Epoch 449/2000\n",
      "1/1 - 0s - loss: 2.0322 - val_loss: 2.0692 - 104ms/epoch - 104ms/step\n",
      "Epoch 450/2000\n",
      "1/1 - 0s - loss: 2.0319 - val_loss: 2.0690 - 97ms/epoch - 97ms/step\n",
      "Epoch 451/2000\n",
      "1/1 - 0s - loss: 2.0315 - val_loss: 2.0686 - 107ms/epoch - 107ms/step\n",
      "Epoch 452/2000\n",
      "1/1 - 0s - loss: 2.0312 - val_loss: 2.0684 - 103ms/epoch - 103ms/step\n",
      "Epoch 453/2000\n",
      "1/1 - 0s - loss: 2.0308 - val_loss: 2.0681 - 114ms/epoch - 114ms/step\n",
      "Epoch 454/2000\n",
      "1/1 - 0s - loss: 2.0305 - val_loss: 2.0677 - 131ms/epoch - 131ms/step\n",
      "Epoch 455/2000\n",
      "1/1 - 0s - loss: 2.0301 - val_loss: 2.0676 - 116ms/epoch - 116ms/step\n",
      "Epoch 456/2000\n",
      "1/1 - 0s - loss: 2.0298 - val_loss: 2.0674 - 107ms/epoch - 107ms/step\n",
      "Epoch 457/2000\n",
      "1/1 - 0s - loss: 2.0294 - val_loss: 2.0670 - 129ms/epoch - 129ms/step\n",
      "Epoch 458/2000\n",
      "1/1 - 0s - loss: 2.0291 - val_loss: 2.0669 - 123ms/epoch - 123ms/step\n",
      "Epoch 459/2000\n",
      "1/1 - 0s - loss: 2.0288 - val_loss: 2.0666 - 138ms/epoch - 138ms/step\n",
      "Epoch 460/2000\n",
      "1/1 - 0s - loss: 2.0284 - val_loss: 2.0663 - 97ms/epoch - 97ms/step\n",
      "Epoch 461/2000\n",
      "1/1 - 0s - loss: 2.0281 - val_loss: 2.0661 - 101ms/epoch - 101ms/step\n",
      "Epoch 462/2000\n",
      "1/1 - 0s - loss: 2.0278 - val_loss: 2.0657 - 110ms/epoch - 110ms/step\n",
      "Epoch 463/2000\n",
      "1/1 - 0s - loss: 2.0275 - val_loss: 2.0656 - 141ms/epoch - 141ms/step\n",
      "Epoch 464/2000\n",
      "1/1 - 0s - loss: 2.0271 - val_loss: 2.0653 - 97ms/epoch - 97ms/step\n",
      "Epoch 465/2000\n",
      "1/1 - 0s - loss: 2.0268 - val_loss: 2.0652 - 135ms/epoch - 135ms/step\n",
      "Epoch 466/2000\n",
      "1/1 - 0s - loss: 2.0265 - val_loss: 2.0650 - 120ms/epoch - 120ms/step\n",
      "Epoch 467/2000\n",
      "1/1 - 0s - loss: 2.0262 - val_loss: 2.0647 - 110ms/epoch - 110ms/step\n",
      "Epoch 468/2000\n",
      "1/1 - 0s - loss: 2.0259 - val_loss: 2.0645 - 114ms/epoch - 114ms/step\n",
      "Epoch 469/2000\n",
      "1/1 - 0s - loss: 2.0255 - val_loss: 2.0643 - 117ms/epoch - 117ms/step\n",
      "Epoch 470/2000\n",
      "1/1 - 0s - loss: 2.0252 - val_loss: 2.0639 - 101ms/epoch - 101ms/step\n",
      "Epoch 471/2000\n",
      "1/1 - 0s - loss: 2.0249 - val_loss: 2.0639 - 101ms/epoch - 101ms/step\n",
      "Epoch 472/2000\n",
      "1/1 - 0s - loss: 2.0246 - val_loss: 2.0636 - 99ms/epoch - 99ms/step\n",
      "Epoch 473/2000\n",
      "1/1 - 0s - loss: 2.0243 - val_loss: 2.0635 - 98ms/epoch - 98ms/step\n",
      "Epoch 474/2000\n",
      "1/1 - 0s - loss: 2.0240 - val_loss: 2.0632 - 100ms/epoch - 100ms/step\n",
      "Epoch 475/2000\n",
      "1/1 - 0s - loss: 2.0237 - val_loss: 2.0631 - 137ms/epoch - 137ms/step\n",
      "Epoch 476/2000\n",
      "1/1 - 0s - loss: 2.0234 - val_loss: 2.0628 - 105ms/epoch - 105ms/step\n",
      "Epoch 477/2000\n",
      "1/1 - 0s - loss: 2.0231 - val_loss: 2.0626 - 97ms/epoch - 97ms/step\n",
      "Epoch 478/2000\n",
      "1/1 - 0s - loss: 2.0228 - val_loss: 2.0624 - 100ms/epoch - 100ms/step\n",
      "Epoch 479/2000\n",
      "1/1 - 0s - loss: 2.0225 - val_loss: 2.0621 - 133ms/epoch - 133ms/step\n",
      "Epoch 480/2000\n",
      "1/1 - 0s - loss: 2.0222 - val_loss: 2.0620 - 96ms/epoch - 96ms/step\n",
      "Epoch 481/2000\n",
      "1/1 - 0s - loss: 2.0219 - val_loss: 2.0617 - 104ms/epoch - 104ms/step\n",
      "Epoch 482/2000\n",
      "1/1 - 0s - loss: 2.0217 - val_loss: 2.0616 - 156ms/epoch - 156ms/step\n",
      "Epoch 483/2000\n",
      "1/1 - 0s - loss: 2.0214 - val_loss: 2.0613 - 100ms/epoch - 100ms/step\n",
      "Epoch 484/2000\n",
      "1/1 - 0s - loss: 2.0211 - val_loss: 2.0611 - 100ms/epoch - 100ms/step\n",
      "Epoch 485/2000\n",
      "1/1 - 0s - loss: 2.0208 - val_loss: 2.0611 - 137ms/epoch - 137ms/step\n",
      "Epoch 486/2000\n",
      "1/1 - 0s - loss: 2.0205 - val_loss: 2.0607 - 97ms/epoch - 97ms/step\n",
      "Epoch 487/2000\n",
      "1/1 - 0s - loss: 2.0203 - val_loss: 2.0606 - 139ms/epoch - 139ms/step\n",
      "Epoch 488/2000\n",
      "1/1 - 0s - loss: 2.0199 - val_loss: 2.0605 - 136ms/epoch - 136ms/step\n",
      "Epoch 489/2000\n",
      "1/1 - 0s - loss: 2.0197 - val_loss: 2.0602 - 107ms/epoch - 107ms/step\n",
      "Epoch 490/2000\n",
      "1/1 - 0s - loss: 2.0194 - val_loss: 2.0602 - 111ms/epoch - 111ms/step\n",
      "Epoch 491/2000\n",
      "1/1 - 0s - loss: 2.0191 - val_loss: 2.0599 - 135ms/epoch - 135ms/step\n",
      "Epoch 492/2000\n",
      "1/1 - 0s - loss: 2.0188 - val_loss: 2.0598 - 105ms/epoch - 105ms/step\n",
      "Epoch 493/2000\n",
      "1/1 - 0s - loss: 2.0186 - val_loss: 2.0596 - 97ms/epoch - 97ms/step\n",
      "Epoch 494/2000\n",
      "1/1 - 0s - loss: 2.0183 - val_loss: 2.0594 - 100ms/epoch - 100ms/step\n",
      "Epoch 495/2000\n",
      "1/1 - 0s - loss: 2.0180 - val_loss: 2.0592 - 110ms/epoch - 110ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 496/2000\n",
      "1/1 - 0s - loss: 2.0178 - val_loss: 2.0591 - 97ms/epoch - 97ms/step\n",
      "Epoch 497/2000\n",
      "1/1 - 0s - loss: 2.0175 - val_loss: 2.0591 - 95ms/epoch - 95ms/step\n",
      "Epoch 498/2000\n",
      "1/1 - 0s - loss: 2.0173 - val_loss: 2.0588 - 96ms/epoch - 96ms/step\n",
      "Epoch 499/2000\n",
      "1/1 - 0s - loss: 2.0170 - val_loss: 2.0587 - 96ms/epoch - 96ms/step\n",
      "Epoch 500/2000\n",
      "1/1 - 0s - loss: 2.0167 - val_loss: 2.0585 - 97ms/epoch - 97ms/step\n",
      "Epoch 501/2000\n",
      "1/1 - 0s - loss: 2.0165 - val_loss: 2.0583 - 111ms/epoch - 111ms/step\n",
      "Epoch 502/2000\n",
      "1/1 - 0s - loss: 2.0162 - val_loss: 2.0581 - 98ms/epoch - 98ms/step\n",
      "Epoch 503/2000\n",
      "1/1 - 0s - loss: 2.0160 - val_loss: 2.0580 - 135ms/epoch - 135ms/step\n",
      "Epoch 504/2000\n",
      "1/1 - 0s - loss: 2.0157 - val_loss: 2.0578 - 103ms/epoch - 103ms/step\n",
      "Epoch 505/2000\n",
      "1/1 - 0s - loss: 2.0155 - val_loss: 2.0577 - 99ms/epoch - 99ms/step\n",
      "Epoch 506/2000\n",
      "1/1 - 0s - loss: 2.0152 - val_loss: 2.0576 - 98ms/epoch - 98ms/step\n",
      "Epoch 507/2000\n",
      "1/1 - 0s - loss: 2.0150 - val_loss: 2.0573 - 102ms/epoch - 102ms/step\n",
      "Epoch 508/2000\n",
      "1/1 - 0s - loss: 2.0147 - val_loss: 2.0573 - 102ms/epoch - 102ms/step\n",
      "Epoch 509/2000\n",
      "1/1 - 0s - loss: 2.0145 - val_loss: 2.0570 - 104ms/epoch - 104ms/step\n",
      "Epoch 510/2000\n",
      "1/1 - 0s - loss: 2.0142 - val_loss: 2.0569 - 99ms/epoch - 99ms/step\n",
      "Epoch 511/2000\n",
      "1/1 - 0s - loss: 2.0140 - val_loss: 2.0568 - 103ms/epoch - 103ms/step\n",
      "Epoch 512/2000\n",
      "1/1 - 0s - loss: 2.0138 - val_loss: 2.0567 - 97ms/epoch - 97ms/step\n",
      "Epoch 513/2000\n",
      "1/1 - 0s - loss: 2.0135 - val_loss: 2.0564 - 96ms/epoch - 96ms/step\n",
      "Epoch 514/2000\n",
      "1/1 - 0s - loss: 2.0133 - val_loss: 2.0564 - 130ms/epoch - 130ms/step\n",
      "Epoch 515/2000\n",
      "1/1 - 0s - loss: 2.0131 - val_loss: 2.0563 - 98ms/epoch - 98ms/step\n",
      "Epoch 516/2000\n",
      "1/1 - 0s - loss: 2.0128 - val_loss: 2.0561 - 93ms/epoch - 93ms/step\n",
      "Epoch 517/2000\n",
      "1/1 - 0s - loss: 2.0126 - val_loss: 2.0561 - 96ms/epoch - 96ms/step\n",
      "Epoch 518/2000\n",
      "1/1 - 0s - loss: 2.0124 - val_loss: 2.0558 - 104ms/epoch - 104ms/step\n",
      "Epoch 519/2000\n",
      "1/1 - 0s - loss: 2.0122 - val_loss: 2.0557 - 99ms/epoch - 99ms/step\n",
      "Epoch 520/2000\n",
      "1/1 - 0s - loss: 2.0119 - val_loss: 2.0558 - 91ms/epoch - 91ms/step\n",
      "Epoch 521/2000\n",
      "1/1 - 0s - loss: 2.0117 - val_loss: 2.0554 - 99ms/epoch - 99ms/step\n",
      "Epoch 522/2000\n",
      "1/1 - 0s - loss: 2.0115 - val_loss: 2.0555 - 92ms/epoch - 92ms/step\n",
      "Epoch 523/2000\n",
      "1/1 - 0s - loss: 2.0112 - val_loss: 2.0554 - 93ms/epoch - 93ms/step\n",
      "Epoch 524/2000\n",
      "1/1 - 0s - loss: 2.0110 - val_loss: 2.0552 - 137ms/epoch - 137ms/step\n",
      "Epoch 525/2000\n",
      "1/1 - 0s - loss: 2.0108 - val_loss: 2.0550 - 107ms/epoch - 107ms/step\n",
      "Epoch 526/2000\n",
      "1/1 - 0s - loss: 2.0106 - val_loss: 2.0551 - 93ms/epoch - 93ms/step\n",
      "Epoch 527/2000\n",
      "1/1 - 0s - loss: 2.0104 - val_loss: 2.0549 - 102ms/epoch - 102ms/step\n",
      "Epoch 528/2000\n",
      "1/1 - 0s - loss: 2.0101 - val_loss: 2.0547 - 105ms/epoch - 105ms/step\n",
      "Epoch 529/2000\n",
      "1/1 - 0s - loss: 2.0099 - val_loss: 2.0547 - 131ms/epoch - 131ms/step\n",
      "Epoch 530/2000\n",
      "1/1 - 0s - loss: 2.0097 - val_loss: 2.0546 - 107ms/epoch - 107ms/step\n",
      "Epoch 531/2000\n",
      "1/1 - 0s - loss: 2.0095 - val_loss: 2.0545 - 101ms/epoch - 101ms/step\n",
      "Epoch 532/2000\n",
      "1/1 - 0s - loss: 2.0093 - val_loss: 2.0543 - 100ms/epoch - 100ms/step\n",
      "Epoch 533/2000\n",
      "1/1 - 0s - loss: 2.0091 - val_loss: 2.0543 - 99ms/epoch - 99ms/step\n",
      "Epoch 534/2000\n",
      "1/1 - 0s - loss: 2.0089 - val_loss: 2.0541 - 139ms/epoch - 139ms/step\n",
      "Epoch 535/2000\n",
      "1/1 - 0s - loss: 2.0087 - val_loss: 2.0539 - 96ms/epoch - 96ms/step\n",
      "Epoch 536/2000\n",
      "1/1 - 0s - loss: 2.0085 - val_loss: 2.0541 - 128ms/epoch - 128ms/step\n",
      "Epoch 537/2000\n",
      "1/1 - 0s - loss: 2.0083 - val_loss: 2.0538 - 109ms/epoch - 109ms/step\n",
      "Epoch 538/2000\n",
      "1/1 - 0s - loss: 2.0081 - val_loss: 2.0535 - 98ms/epoch - 98ms/step\n",
      "Epoch 539/2000\n",
      "1/1 - 0s - loss: 2.0079 - val_loss: 2.0539 - 91ms/epoch - 91ms/step\n",
      "Epoch 540/2000\n",
      "1/1 - 0s - loss: 2.0077 - val_loss: 2.0535 - 97ms/epoch - 97ms/step\n",
      "Epoch 541/2000\n",
      "1/1 - 0s - loss: 2.0074 - val_loss: 2.0533 - 102ms/epoch - 102ms/step\n",
      "Epoch 542/2000\n",
      "1/1 - 0s - loss: 2.0072 - val_loss: 2.0536 - 148ms/epoch - 148ms/step\n",
      "Epoch 543/2000\n",
      "1/1 - 0s - loss: 2.0071 - val_loss: 2.0532 - 97ms/epoch - 97ms/step\n",
      "Epoch 544/2000\n",
      "1/1 - 0s - loss: 2.0068 - val_loss: 2.0532 - 137ms/epoch - 137ms/step\n",
      "Epoch 545/2000\n",
      "1/1 - 0s - loss: 2.0067 - val_loss: 2.0532 - 129ms/epoch - 129ms/step\n",
      "Epoch 546/2000\n",
      "1/1 - 0s - loss: 2.0065 - val_loss: 2.0530 - 97ms/epoch - 97ms/step\n",
      "Epoch 547/2000\n",
      "1/1 - 0s - loss: 2.0063 - val_loss: 2.0529 - 98ms/epoch - 98ms/step\n",
      "Epoch 548/2000\n",
      "1/1 - 0s - loss: 2.0061 - val_loss: 2.0529 - 96ms/epoch - 96ms/step\n",
      "Epoch 549/2000\n",
      "1/1 - 0s - loss: 2.0059 - val_loss: 2.0527 - 96ms/epoch - 96ms/step\n",
      "Epoch 550/2000\n",
      "1/1 - 0s - loss: 2.0057 - val_loss: 2.0525 - 139ms/epoch - 139ms/step\n",
      "Epoch 551/2000\n",
      "1/1 - 0s - loss: 2.0055 - val_loss: 2.0529 - 93ms/epoch - 93ms/step\n",
      "Epoch 552/2000\n",
      "1/1 - 0s - loss: 2.0053 - val_loss: 2.0525 - 96ms/epoch - 96ms/step\n",
      "Epoch 553/2000\n",
      "1/1 - 0s - loss: 2.0051 - val_loss: 2.0524 - 146ms/epoch - 146ms/step\n",
      "Epoch 554/2000\n",
      "1/1 - 0s - loss: 2.0049 - val_loss: 2.0525 - 132ms/epoch - 132ms/step\n",
      "Epoch 555/2000\n",
      "1/1 - 0s - loss: 2.0048 - val_loss: 2.0522 - 99ms/epoch - 99ms/step\n",
      "Epoch 556/2000\n",
      "1/1 - 0s - loss: 2.0046 - val_loss: 2.0521 - 137ms/epoch - 137ms/step\n",
      "Epoch 557/2000\n",
      "1/1 - 0s - loss: 2.0044 - val_loss: 2.0520 - 97ms/epoch - 97ms/step\n",
      "Epoch 558/2000\n",
      "1/1 - 0s - loss: 2.0042 - val_loss: 2.0521 - 99ms/epoch - 99ms/step\n",
      "Epoch 559/2000\n",
      "1/1 - 0s - loss: 2.0040 - val_loss: 2.0518 - 97ms/epoch - 97ms/step\n",
      "Epoch 560/2000\n",
      "1/1 - 0s - loss: 2.0038 - val_loss: 2.0521 - 94ms/epoch - 94ms/step\n",
      "Epoch 561/2000\n",
      "1/1 - 0s - loss: 2.0036 - val_loss: 2.0519 - 90ms/epoch - 90ms/step\n",
      "Epoch 562/2000\n",
      "1/1 - 0s - loss: 2.0035 - val_loss: 2.0516 - 100ms/epoch - 100ms/step\n",
      "Epoch 563/2000\n",
      "1/1 - 0s - loss: 2.0033 - val_loss: 2.0517 - 137ms/epoch - 137ms/step\n",
      "Epoch 564/2000\n",
      "1/1 - 0s - loss: 2.0031 - val_loss: 2.0514 - 98ms/epoch - 98ms/step\n",
      "Epoch 565/2000\n",
      "1/1 - 0s - loss: 2.0029 - val_loss: 2.0514 - 97ms/epoch - 97ms/step\n",
      "Epoch 566/2000\n",
      "1/1 - 0s - loss: 2.0028 - val_loss: 2.0512 - 138ms/epoch - 138ms/step\n",
      "Epoch 567/2000\n",
      "1/1 - 0s - loss: 2.0026 - val_loss: 2.0513 - 129ms/epoch - 129ms/step\n",
      "Epoch 568/2000\n",
      "1/1 - 0s - loss: 2.0025 - val_loss: 2.0511 - 139ms/epoch - 139ms/step\n",
      "Epoch 569/2000\n",
      "1/1 - 0s - loss: 2.0023 - val_loss: 2.0513 - 93ms/epoch - 93ms/step\n",
      "Epoch 570/2000\n",
      "1/1 - 0s - loss: 2.0021 - val_loss: 2.0510 - 99ms/epoch - 99ms/step\n",
      "Epoch 571/2000\n",
      "1/1 - 0s - loss: 2.0019 - val_loss: 2.0510 - 136ms/epoch - 136ms/step\n",
      "Epoch 572/2000\n",
      "1/1 - 0s - loss: 2.0018 - val_loss: 2.0512 - 91ms/epoch - 91ms/step\n",
      "Epoch 573/2000\n",
      "1/1 - 0s - loss: 2.0016 - val_loss: 2.0509 - 98ms/epoch - 98ms/step\n",
      "Epoch 574/2000\n",
      "1/1 - 0s - loss: 2.0014 - val_loss: 2.0511 - 91ms/epoch - 91ms/step\n",
      "Epoch 575/2000\n",
      "1/1 - 0s - loss: 2.0012 - val_loss: 2.0509 - 97ms/epoch - 97ms/step\n",
      "Epoch 576/2000\n",
      "1/1 - 0s - loss: 2.0011 - val_loss: 2.0507 - 97ms/epoch - 97ms/step\n",
      "Epoch 577/2000\n",
      "1/1 - 0s - loss: 2.0009 - val_loss: 2.0505 - 99ms/epoch - 99ms/step\n",
      "Epoch 578/2000\n",
      "1/1 - 0s - loss: 2.0007 - val_loss: 2.0506 - 94ms/epoch - 94ms/step\n",
      "Epoch 579/2000\n",
      "1/1 - 0s - loss: 2.0006 - val_loss: 2.0507 - 100ms/epoch - 100ms/step\n",
      "Epoch 580/2000\n",
      "1/1 - 0s - loss: 2.0004 - val_loss: 2.0502 - 101ms/epoch - 101ms/step\n",
      "Epoch 581/2000\n",
      "1/1 - 0s - loss: 2.0003 - val_loss: 2.0509 - 95ms/epoch - 95ms/step\n",
      "Epoch 582/2000\n",
      "1/1 - 0s - loss: 2.0001 - val_loss: 2.0503 - 130ms/epoch - 130ms/step\n",
      "Epoch 583/2000\n",
      "1/1 - 0s - loss: 1.9999 - val_loss: 2.0504 - 92ms/epoch - 92ms/step\n",
      "Epoch 584/2000\n",
      "1/1 - 0s - loss: 1.9998 - val_loss: 2.0503 - 92ms/epoch - 92ms/step\n",
      "Epoch 585/2000\n",
      "1/1 - 0s - loss: 1.9996 - val_loss: 2.0502 - 110ms/epoch - 110ms/step\n"
     ]
    }
   ],
   "source": [
    "# Embedding width and number of attention heads passed to every encoder block\n",
    "embed_dim = 10\n",
    "num_head = 2\n",
    "\n",
    "# Stop once val_loss has gone 5 epochs without improving, then restore the best weights\n",
    "callback = keras.callbacks.EarlyStopping(\n",
    "    monitor='val_loss',\n",
    "    patience=5,\n",
    "    restore_best_weights=True,\n",
    ")\n",
    "\n",
    "# Integer token ids -> word embeddings plus position embeddings\n",
    "inputs = layers.Input((x_masked_train.shape[1],), dtype=tf.int64)\n",
    "word_embeddings = layers.Embedding(N, embed_dim, name=\"word_embedding\")(inputs)\n",
    "position_embeddings = PositionEmbedding(\n",
    "    sequence_length=len(x_masked_train[0])\n",
    ")(word_embeddings)\n",
    "encoder_output = word_embeddings + position_embeddings\n",
    "\n",
    "# Five stacked encoder blocks; each one self-attends over the previous block's output\n",
    "for i in range(5):\n",
    "    encoder_output = bert_module(\n",
    "        encoder_output, encoder_output, encoder_output, embed_dim, num_head, i\n",
    "    )\n",
    "\n",
    "# Drop the last time step, then map every remaining position to a softmax over N classes\n",
    "encoder_output = keras.layers.Lambda(lambda x: x[:,:-1,:], name='slice')(encoder_output)\n",
    "mlm_output = layers.Dense(N, name=\"mlm_cls\", activation=\"softmax\")(encoder_output)\n",
    "mlm_model = keras.Model(inputs=inputs, outputs=mlm_output)\n",
    "\n",
    "# Adam with its default settings; labels are integer class ids (sparse cross-entropy)\n",
    "adam = Adam()\n",
    "mlm_model.compile(loss='sparse_categorical_crossentropy', optimizer=adam)\n",
    "\n",
    "# Half of the training arrays are held out for validation; early stopping ends the run\n",
    "history = mlm_model.fit(\n",
    "    x_masked_train,\n",
    "    y_masked_labels_train,\n",
    "    validation_split=0.5,\n",
    "    callbacks=[callback],\n",
    "    epochs=2000,\n",
    "    batch_size=5000,\n",
    "    verbose=2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "aa33aaff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the model on a random sample of test sequences: for each one, read the\n",
    "# distribution at the model's last output position and compare it with the\n",
    "# final token id of that input sequence.\n",
    "acc = []   # 1 when the argmax equals the final token id, else 0\n",
    "prob = []  # probability the model assigns to the final token id\n",
    "\n",
    "# Cap the sample at the test-set size: with replace=False, np.random.choice raises\n",
    "# a ValueError when asked for more rows than exist.\n",
    "n_eval = min(1000, x_masked_test.shape[0])\n",
    "x_test_subset = x_masked_test[np.random.choice(x_masked_test.shape[0], size=n_eval, replace=False)]\n",
    "\n",
    "# Build the input -> output function once; it does not depend on the sentence,\n",
    "# so rebuilding it on every loop iteration is wasted work.\n",
    "predict_fn = keras.backend.function(inputs = mlm_model.layers[0].input, outputs = mlm_model.layers[-1].output)\n",
    "\n",
    "for sentence_number in x_test_subset:\n",
    "    # Add a batch axis of size 1 before calling the model\n",
    "    temp = predict_fn(np.array(sentence_number).reshape(1,len(sentence_number)))\n",
    "    temp = temp[:,-1,:]  # keep only the last output position\n",
    "    acc.append(1 if temp.argmax() == sentence_number[-1] else 0)\n",
    "    prob.append(temp[0][sentence_number[-1]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "fe4d2103",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1.0, 0.918796)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean top-1 accuracy and mean probability given to the true final token\n",
    "np.mean(acc), np.mean(prob)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "0005b730",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2000\n",
      "1/1 - 4s - loss: 4.0062 - val_loss: 3.5385 - 4s/epoch - 4s/step\n",
      "Epoch 2/2000\n",
      "1/1 - 0s - loss: 3.5311 - val_loss: 3.3683 - 375ms/epoch - 375ms/step\n",
      "Epoch 3/2000\n",
      "1/1 - 0s - loss: 3.3743 - val_loss: 3.2384 - 325ms/epoch - 325ms/step\n",
      "Epoch 4/2000\n",
      "1/1 - 0s - loss: 3.2523 - val_loss: 3.1908 - 308ms/epoch - 308ms/step\n",
      "Epoch 5/2000\n",
      "1/1 - 0s - loss: 3.2024 - val_loss: 3.1443 - 325ms/epoch - 325ms/step\n",
      "Epoch 6/2000\n",
      "1/1 - 0s - loss: 3.1461 - val_loss: 3.0952 - 311ms/epoch - 311ms/step\n",
      "Epoch 7/2000\n",
      "1/1 - 0s - loss: 3.0865 - val_loss: 3.0677 - 399ms/epoch - 399ms/step\n",
      "Epoch 8/2000\n",
      "1/1 - 0s - loss: 3.0521 - val_loss: 3.0483 - 327ms/epoch - 327ms/step\n",
      "Epoch 9/2000\n",
      "1/1 - 0s - loss: 3.0308 - val_loss: 3.0379 - 320ms/epoch - 320ms/step\n",
      "Epoch 10/2000\n",
      "1/1 - 0s - loss: 3.0223 - val_loss: 3.0383 - 372ms/epoch - 372ms/step\n",
      "Epoch 11/2000\n",
      "1/1 - 0s - loss: 3.0258 - val_loss: 3.0377 - 401ms/epoch - 401ms/step\n",
      "Epoch 12/2000\n",
      "1/1 - 1s - loss: 3.0276 - val_loss: 3.0303 - 541ms/epoch - 541ms/step\n",
      "Epoch 13/2000\n",
      "1/1 - 1s - loss: 3.0217 - val_loss: 3.0218 - 567ms/epoch - 567ms/step\n",
      "Epoch 14/2000\n",
      "1/1 - 1s - loss: 3.0138 - val_loss: 3.0159 - 547ms/epoch - 547ms/step\n",
      "Epoch 15/2000\n",
      "1/1 - 1s - loss: 3.0074 - val_loss: 3.0120 - 563ms/epoch - 563ms/step\n",
      "Epoch 16/2000\n",
      "1/1 - 1s - loss: 3.0025 - val_loss: 3.0089 - 650ms/epoch - 650ms/step\n",
      "Epoch 17/2000\n",
      "1/1 - 1s - loss: 2.9983 - val_loss: 3.0067 - 551ms/epoch - 551ms/step\n",
      "Epoch 18/2000\n",
      "1/1 - 0s - loss: 2.9952 - val_loss: 3.0056 - 448ms/epoch - 448ms/step\n",
      "Epoch 19/2000\n",
      "1/1 - 1s - loss: 2.9936 - val_loss: 3.0053 - 574ms/epoch - 574ms/step\n",
      "Epoch 20/2000\n",
      "1/1 - 1s - loss: 2.9932 - val_loss: 3.0055 - 554ms/epoch - 554ms/step\n",
      "Epoch 21/2000\n",
      "1/1 - 1s - loss: 2.9936 - val_loss: 3.0057 - 592ms/epoch - 592ms/step\n",
      "Epoch 22/2000\n",
      "1/1 - 1s - loss: 2.9943 - val_loss: 3.0055 - 549ms/epoch - 549ms/step\n",
      "Epoch 23/2000\n",
      "1/1 - 1s - loss: 2.9944 - val_loss: 3.0044 - 555ms/epoch - 555ms/step\n",
      "Epoch 24/2000\n",
      "1/1 - 1s - loss: 2.9935 - val_loss: 3.0028 - 532ms/epoch - 532ms/step\n",
      "Epoch 25/2000\n",
      "1/1 - 1s - loss: 2.9916 - val_loss: 3.0013 - 608ms/epoch - 608ms/step\n",
      "Epoch 26/2000\n",
      "1/1 - 1s - loss: 2.9894 - val_loss: 3.0004 - 527ms/epoch - 527ms/step\n",
      "Epoch 27/2000\n",
      "1/1 - 1s - loss: 2.9875 - val_loss: 3.0003 - 563ms/epoch - 563ms/step\n",
      "Epoch 28/2000\n",
      "1/1 - 1s - loss: 2.9862 - val_loss: 3.0006 - 569ms/epoch - 569ms/step\n",
      "Epoch 29/2000\n",
      "1/1 - 1s - loss: 2.9854 - val_loss: 3.0010 - 538ms/epoch - 538ms/step\n",
      "Epoch 30/2000\n",
      "1/1 - 1s - loss: 2.9850 - val_loss: 3.0010 - 670ms/epoch - 670ms/step\n",
      "Epoch 31/2000\n",
      "1/1 - 1s - loss: 2.9847 - val_loss: 3.0006 - 606ms/epoch - 606ms/step\n",
      "Epoch 32/2000\n",
      "1/1 - 0s - loss: 2.9844 - val_loss: 2.9998 - 483ms/epoch - 483ms/step\n",
      "Epoch 33/2000\n",
      "1/1 - 1s - loss: 2.9840 - val_loss: 2.9986 - 616ms/epoch - 616ms/step\n",
      "Epoch 34/2000\n",
      "1/1 - 1s - loss: 2.9834 - val_loss: 2.9970 - 516ms/epoch - 516ms/step\n",
      "Epoch 35/2000\n",
      "1/1 - 1s - loss: 2.9824 - val_loss: 2.9953 - 554ms/epoch - 554ms/step\n",
      "Epoch 36/2000\n",
      "1/1 - 1s - loss: 2.9810 - val_loss: 2.9937 - 553ms/epoch - 553ms/step\n",
      "Epoch 37/2000\n",
      "1/1 - 1s - loss: 2.9794 - val_loss: 2.9925 - 544ms/epoch - 544ms/step\n",
      "Epoch 38/2000\n",
      "1/1 - 1s - loss: 2.9780 - val_loss: 2.9917 - 615ms/epoch - 615ms/step\n",
      "Epoch 39/2000\n",
      "1/1 - 1s - loss: 2.9766 - val_loss: 2.9910 - 535ms/epoch - 535ms/step\n",
      "Epoch 40/2000\n",
      "1/1 - 1s - loss: 2.9753 - val_loss: 2.9902 - 524ms/epoch - 524ms/step\n",
      "Epoch 41/2000\n",
      "1/1 - 1s - loss: 2.9739 - val_loss: 2.9893 - 565ms/epoch - 565ms/step\n",
      "Epoch 42/2000\n",
      "1/1 - 1s - loss: 2.9723 - val_loss: 2.9880 - 540ms/epoch - 540ms/step\n",
      "Epoch 43/2000\n",
      "1/1 - 1s - loss: 2.9705 - val_loss: 2.9861 - 622ms/epoch - 622ms/step\n",
      "Epoch 44/2000\n",
      "1/1 - 1s - loss: 2.9681 - val_loss: 2.9836 - 528ms/epoch - 528ms/step\n",
      "Epoch 45/2000\n",
      "1/1 - 1s - loss: 2.9651 - val_loss: 2.9805 - 562ms/epoch - 562ms/step\n",
      "Epoch 46/2000\n",
      "1/1 - 1s - loss: 2.9614 - val_loss: 2.9770 - 606ms/epoch - 606ms/step\n",
      "Epoch 47/2000\n",
      "1/1 - 1s - loss: 2.9570 - val_loss: 2.9732 - 563ms/epoch - 563ms/step\n",
      "Epoch 48/2000\n",
      "1/1 - 1s - loss: 2.9518 - val_loss: 2.9685 - 598ms/epoch - 598ms/step\n",
      "Epoch 49/2000\n",
      "1/1 - 1s - loss: 2.9455 - val_loss: 2.9625 - 548ms/epoch - 548ms/step\n",
      "Epoch 50/2000\n",
      "1/1 - 0s - loss: 2.9372 - val_loss: 2.9545 - 477ms/epoch - 477ms/step\n",
      "Epoch 51/2000\n",
      "1/1 - 0s - loss: 2.9265 - val_loss: 2.9436 - 472ms/epoch - 472ms/step\n",
      "Epoch 52/2000\n",
      "1/1 - 1s - loss: 2.9122 - val_loss: 2.9276 - 549ms/epoch - 549ms/step\n",
      "Epoch 53/2000\n",
      "1/1 - 1s - loss: 2.8922 - val_loss: 2.9036 - 532ms/epoch - 532ms/step\n",
      "Epoch 54/2000\n",
      "1/1 - 0s - loss: 2.8632 - val_loss: 2.8661 - 497ms/epoch - 497ms/step\n",
      "Epoch 55/2000\n",
      "1/1 - 1s - loss: 2.8203 - val_loss: 2.8111 - 528ms/epoch - 528ms/step\n",
      "Epoch 56/2000\n",
      "1/1 - 1s - loss: 2.7583 - val_loss: 2.7333 - 549ms/epoch - 549ms/step\n",
      "Epoch 57/2000\n",
      "1/1 - 1s - loss: 2.6720 - val_loss: 2.6312 - 511ms/epoch - 511ms/step\n",
      "Epoch 58/2000\n",
      "1/1 - 1s - loss: 2.5616 - val_loss: 2.5235 - 501ms/epoch - 501ms/step\n",
      "Epoch 59/2000\n",
      "1/1 - 1s - loss: 2.4481 - val_loss: 2.4162 - 633ms/epoch - 633ms/step\n",
      "Epoch 60/2000\n",
      "1/1 - 0s - loss: 2.3367 - val_loss: 2.3243 - 490ms/epoch - 490ms/step\n",
      "Epoch 61/2000\n",
      "1/1 - 0s - loss: 2.2481 - val_loss: 2.2440 - 340ms/epoch - 340ms/step\n",
      "Epoch 62/2000\n",
      "1/1 - 0s - loss: 2.1729 - val_loss: 2.1858 - 349ms/epoch - 349ms/step\n",
      "Epoch 63/2000\n",
      "1/1 - 0s - loss: 2.1205 - val_loss: 2.1442 - 318ms/epoch - 318ms/step\n",
      "Epoch 64/2000\n",
      "1/1 - 0s - loss: 2.0837 - val_loss: 2.1158 - 341ms/epoch - 341ms/step\n",
      "Epoch 65/2000\n",
      "1/1 - 0s - loss: 2.0605 - val_loss: 2.0933 - 333ms/epoch - 333ms/step\n",
      "Epoch 66/2000\n",
      "1/1 - 0s - loss: 2.0412 - val_loss: 2.0793 - 359ms/epoch - 359ms/step\n",
      "Epoch 67/2000\n",
      "1/1 - 0s - loss: 2.0292 - val_loss: 2.0666 - 350ms/epoch - 350ms/step\n",
      "Epoch 68/2000\n",
      "1/1 - 0s - loss: 2.0182 - val_loss: 2.0537 - 330ms/epoch - 330ms/step\n",
      "Epoch 69/2000\n",
      "1/1 - 0s - loss: 2.0080 - val_loss: 2.0452 - 322ms/epoch - 322ms/step\n",
      "Epoch 70/2000\n",
      "1/1 - 0s - loss: 2.0015 - val_loss: 2.0405 - 325ms/epoch - 325ms/step\n",
      "Epoch 71/2000\n",
      "1/1 - 0s - loss: 1.9962 - val_loss: 2.0381 - 356ms/epoch - 356ms/step\n",
      "Epoch 72/2000\n",
      "1/1 - 0s - loss: 1.9921 - val_loss: 2.0353 - 342ms/epoch - 342ms/step\n",
      "Epoch 73/2000\n",
      "1/1 - 0s - loss: 1.9887 - val_loss: 2.0328 - 344ms/epoch - 344ms/step\n",
      "Epoch 74/2000\n",
      "1/1 - 0s - loss: 1.9860 - val_loss: 2.0310 - 389ms/epoch - 389ms/step\n",
      "Epoch 75/2000\n",
      "1/1 - 0s - loss: 1.9832 - val_loss: 2.0298 - 342ms/epoch - 342ms/step\n",
      "Epoch 76/2000\n",
      "1/1 - 0s - loss: 1.9806 - val_loss: 2.0289 - 367ms/epoch - 367ms/step\n",
      "Epoch 77/2000\n",
      "1/1 - 0s - loss: 1.9783 - val_loss: 2.0287 - 348ms/epoch - 348ms/step\n",
      "Epoch 78/2000\n",
      "1/1 - 0s - loss: 1.9762 - val_loss: 2.0294 - 344ms/epoch - 344ms/step\n",
      "Epoch 79/2000\n",
      "1/1 - 0s - loss: 1.9739 - val_loss: 2.0309 - 339ms/epoch - 339ms/step\n",
      "Epoch 80/2000\n",
      "1/1 - 0s - loss: 1.9717 - val_loss: 2.0331 - 334ms/epoch - 334ms/step\n",
      "Epoch 81/2000\n",
      "1/1 - 0s - loss: 1.9697 - val_loss: 2.0350 - 328ms/epoch - 328ms/step\n",
      "Epoch 82/2000\n",
      "1/1 - 0s - loss: 1.9676 - val_loss: 2.0364 - 393ms/epoch - 393ms/step\n"
     ]
    }
   ],
   "source": [
    "# Embedding width and number of attention heads passed to every encoder block\n",
    "embed_dim = 100\n",
    "num_head = 2\n",
    "\n",
    "# Stop once val_loss has gone 5 epochs without improving, then restore the best weights\n",
    "callback = keras.callbacks.EarlyStopping(\n",
    "    monitor='val_loss',\n",
    "    patience=5,\n",
    "    restore_best_weights=True,\n",
    ")\n",
    "\n",
    "# Integer token ids -> word embeddings plus position embeddings\n",
    "inputs = layers.Input((x_masked_train.shape[1],), dtype=tf.int64)\n",
    "word_embeddings = layers.Embedding(N, embed_dim, name=\"word_embedding\")(inputs)\n",
    "position_embeddings = PositionEmbedding(\n",
    "    sequence_length=len(x_masked_train[0])\n",
    ")(word_embeddings)\n",
    "encoder_output = word_embeddings + position_embeddings\n",
    "\n",
    "# Five stacked encoder blocks; each one self-attends over the previous block's output\n",
    "for i in range(5):\n",
    "    encoder_output = bert_module(\n",
    "        encoder_output, encoder_output, encoder_output, embed_dim, num_head, i\n",
    "    )\n",
    "\n",
    "# Drop the last time step, then map every remaining position to a softmax over N classes\n",
    "encoder_output = keras.layers.Lambda(lambda x: x[:,:-1,:], name='slice')(encoder_output)\n",
    "mlm_output = layers.Dense(N, name=\"mlm_cls\", activation=\"softmax\")(encoder_output)\n",
    "mlm_model = keras.Model(inputs=inputs, outputs=mlm_output)\n",
    "\n",
    "# Adam with its default settings; labels are integer class ids (sparse cross-entropy)\n",
    "adam = Adam()\n",
    "mlm_model.compile(loss='sparse_categorical_crossentropy', optimizer=adam)\n",
    "\n",
    "# Half of the training arrays are held out for validation; early stopping ends the run\n",
    "history = mlm_model.fit(\n",
    "    x_masked_train,\n",
    "    y_masked_labels_train,\n",
    "    validation_split=0.5,\n",
    "    callbacks=[callback],\n",
    "    epochs=2000,\n",
    "    batch_size=5000,\n",
    "    verbose=2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "beb07cf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the model on a random sample of test sequences: for each one, read the\n",
    "# distribution at the model's last output position and compare it with the\n",
    "# final token id of that input sequence.\n",
    "acc = []   # 1 when the argmax equals the final token id, else 0\n",
    "prob = []  # probability the model assigns to the final token id\n",
    "\n",
    "# Cap the sample at the test-set size: with replace=False, np.random.choice raises\n",
    "# a ValueError when asked for more rows than exist.\n",
    "n_eval = min(1000, x_masked_test.shape[0])\n",
    "x_test_subset = x_masked_test[np.random.choice(x_masked_test.shape[0], size=n_eval, replace=False)]\n",
    "\n",
    "# Build the input -> output function once; it does not depend on the sentence,\n",
    "# so rebuilding it on every loop iteration is wasted work.\n",
    "predict_fn = keras.backend.function(inputs = mlm_model.layers[0].input, outputs = mlm_model.layers[-1].output)\n",
    "\n",
    "for sentence_number in x_test_subset:\n",
    "    # Add a batch axis of size 1 before calling the model\n",
    "    temp = predict_fn(np.array(sentence_number).reshape(1,len(sentence_number)))\n",
    "    temp = temp[:,-1,:]  # keep only the last output position\n",
    "    acc.append(1 if temp.argmax() == sentence_number[-1] else 0)\n",
    "    prob.append(temp[0][sentence_number[-1]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "dbdd4a69",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1.0, 0.985382)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean top-1 accuracy and mean probability given to the true final token\n",
    "np.mean(acc), np.mean(prob)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54f31bac",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
