{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d1686d53",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy.random as npr\n",
    "import random\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "from tensorflow.keras import backend as K\n",
    "from keras.optimizers import Adam\n",
    "from keras_nlp.layers import PositionEmbedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "56c107e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 428\n",
    "\n",
    "# Seed Python's `random`, NumPy's global RNG and TensorFlow's global seed\n",
    "# from the single `seed` value above so runs are reproducible.\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "tf.random.set_seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "504a342e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def bert_module(query, key, value, embed_dim, num_head, i,\n",
    "                use_causal_mask=True, ff_multiplier=2):\n",
    "    \"\"\"Build one post-norm Transformer encoder block with the Keras functional API.\n",
    "\n",
    "    Args:\n",
    "        query, key, value: symbolic tensors for the attention layer (pass the same\n",
    "            tensor three times for self-attention). The last dimension of `query`\n",
    "            must equal `embed_dim`, because the feed-forward output (width\n",
    "            embed_dim) is added back onto the attention output.\n",
    "        embed_dim: model width and output width of the feed-forward network.\n",
    "        num_head: number of attention heads; each head uses embed_dim // num_head dims.\n",
    "        i: block index, used only to make the layer names unique.\n",
    "        use_causal_mask: if True (the default) each position can attend only to\n",
    "            itself and earlier positions; the original BERT encoder is\n",
    "            bidirectional, i.e. False.\n",
    "        ff_multiplier: hidden width of the feed-forward network as a multiple of\n",
    "            embed_dim (default 2).\n",
    "\n",
    "    Returns:\n",
    "        The output tensor of the second Add & LayerNormalization step.\n",
    "    \"\"\"\n",
    "    # NOTE(review): Keras 3 rejects '/' in layer names; these names assume tf.keras.\n",
    "\n",
    "    # Multi-headed attention. Keras' MultiHeadAttention is called as\n",
    "    # (query, value, key=...), so `key` is passed by keyword: a positional\n",
    "    # (query, key, value) call would silently swap key and value when they differ.\n",
    "    attention_output = layers.MultiHeadAttention(\n",
    "        num_heads=num_head,\n",
    "        key_dim=embed_dim // num_head,\n",
    "        name=\"encoder_{}/multiheadattention\".format(i)\n",
    "    )(query, value, key=key, use_causal_mask=use_causal_mask)\n",
    "\n",
    "    # Add & Normalize: residual connection around the attention sub-layer\n",
    "    attention_output = layers.Add()([query, attention_output])\n",
    "    attention_output = layers.LayerNormalization(epsilon=1e-6)(attention_output)\n",
    "\n",
    "    # Position-wise feed-forward network\n",
    "    ff_net = keras.models.Sequential([\n",
    "        layers.Dense(ff_multiplier * embed_dim, activation='relu', name=\"encoder_{}/ffn_dense_1\".format(i)),\n",
    "        layers.Dense(embed_dim, name=\"encoder_{}/ffn_dense_2\".format(i)),\n",
    "    ])\n",
    "    ffn_output = ff_net(attention_output)\n",
    "\n",
    "    # Add & Normalize: residual connection around the feed-forward sub-layer\n",
    "    ffn_output = layers.Add()([attention_output, ffn_output])\n",
    "    ffn_output = layers.LayerNormalization(epsilon=1e-6)(ffn_output)\n",
    "\n",
    "    return ffn_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fef2622a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_sinusoidal_embeddings(sequence_length, embedding_dim):\n",
    "    \"\"\"Return a tf.float32 table of sinusoidal position encodings, shape (sequence_length, embedding_dim).\n",
    "\n",
    "    Uses the Transformer formulation (Vaswani et al., 2017):\n",
    "        PE[pos, 2k]   = sin(pos / 10000 ** (2k / embedding_dim))\n",
    "        PE[pos, 2k+1] = cos(pos / 10000 ** (2k / embedding_dim))\n",
    "    i.e. each sin/cos pair of dimensions shares one frequency. Row 0 is left\n",
    "    all-zero (presumably reserved for a padding position -- confirm with callers).\n",
    "\n",
    "    Args:\n",
    "        sequence_length: number of positions (rows); 0 gives an empty table.\n",
    "        embedding_dim: encoding width (columns).\n",
    "    \"\"\"\n",
    "    positions = np.arange(sequence_length)[:, np.newaxis]\n",
    "    dims = np.arange(embedding_dim)[np.newaxis, :]\n",
    "    # dims // 2 makes dimensions 2k and 2k+1 share a frequency; using the raw index\n",
    "    # would give every dimension its own, far too low, frequency.\n",
    "    angles = positions / np.power(10000, 2. * (dims // 2) / embedding_dim)\n",
    "    position_enc = np.zeros((sequence_length, embedding_dim))\n",
    "    position_enc[:, 0::2] = np.sin(angles[:, 0::2])  # dim 2k\n",
    "    position_enc[:, 1::2] = np.cos(angles[:, 1::2])  # dim 2k+1\n",
    "    position_enc[:1] = 0.  # keep position 0 all-zero (no-op when sequence_length == 0)\n",
    "    return tf.cast(position_enc, dtype=tf.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a158c53a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): `pairs_train` is not defined by any cell above, so this cell raises\n",
    "# NameError on Restart & Run All; move it below the cell that defines it.\n",
    "pairs_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d8a1bc28",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('word_0', 'word_1', 'word_2'),\n",
       " ('word_0', 'word_1', 'word_3'),\n",
       " ('word_0', 'word_1', 'word_4'),\n",
       " ('word_0', 'word_1', 'word_5'),\n",
       " ('word_0', 'word_1', 'word_6'),\n",
       " ('word_0', 'word_1', 'word_7'),\n",
       " ('word_0', 'word_1', 'word_8'),\n",
       " ('word_0', 'word_1', 'word_9'),\n",
       " ('word_0', 'word_1', 'word_10'),\n",
       " ('word_0', 'word_1', 'word_11'),\n",
       " ('word_0', 'word_1', 'word_12'),\n",
       " ('word_0', 'word_1', 'word_13'),\n",
       " ('word_0', 'word_1', 'word_14'),\n",
       " ('word_0', 'word_1', 'word_15'),\n",
       " ('word_0', 'word_1', 'word_16'),\n",
       " ('word_0', 'word_1', 'word_17'),\n",
       " ('word_0', 'word_1', 'word_18'),\n",
       " ('word_0', 'word_1', 'word_19'),\n",
       " ('word_0', 'word_2', 'word_1'),\n",
       " ('word_0', 'word_2', 'word_3'),\n",
       " ('word_0', 'word_2', 'word_4'),\n",
       " ('word_0', 'word_2', 'word_5'),\n",
       " ('word_0', 'word_2', 'word_6'),\n",
       " ('word_0', 'word_2', 'word_7'),\n",
       " ('word_0', 'word_2', 'word_8'),\n",
       " ('word_0', 'word_2', 'word_9'),\n",
       " ('word_0', 'word_2', 'word_10'),\n",
       " ('word_0', 'word_2', 'word_11'),\n",
       " ('word_0', 'word_2', 'word_12'),\n",
       " ('word_0', 'word_2', 'word_13'),\n",
       " ('word_0', 'word_2', 'word_14'),\n",
       " ('word_0', 'word_2', 'word_15'),\n",
       " ('word_0', 'word_2', 'word_16'),\n",
       " ('word_0', 'word_2', 'word_17'),\n",
       " ('word_0', 'word_2', 'word_18'),\n",
       " ('word_0', 'word_2', 'word_19'),\n",
       " ('word_0', 'word_3', 'word_1'),\n",
       " ('word_0', 'word_3', 'word_2'),\n",
       " ('word_0', 'word_3', 'word_4'),\n",
       " ('word_0', 'word_3', 'word_5'),\n",
       " ('word_0', 'word_3', 'word_6'),\n",
       " ('word_0', 'word_3', 'word_7'),\n",
       " ('word_0', 'word_3', 'word_8'),\n",
       " ('word_0', 'word_3', 'word_9'),\n",
       " ('word_0', 'word_3', 'word_10'),\n",
       " ('word_0', 'word_3', 'word_11'),\n",
       " ('word_0', 'word_3', 'word_12'),\n",
       " ('word_0', 'word_3', 'word_13'),\n",
       " ('word_0', 'word_3', 'word_14'),\n",
       " ('word_0', 'word_3', 'word_15'),\n",
       " ('word_0', 'word_3', 'word_16'),\n",
       " ('word_0', 'word_3', 'word_17'),\n",
       " ('word_0', 'word_3', 'word_18'),\n",
       " ('word_0', 'word_3', 'word_19'),\n",
       " ('word_0', 'word_4', 'word_1'),\n",
       " ('word_0', 'word_4', 'word_2'),\n",
       " ('word_0', 'word_4', 'word_3'),\n",
       " ('word_0', 'word_4', 'word_5'),\n",
       " ('word_0', 'word_4', 'word_6'),\n",
       " ('word_0', 'word_4', 'word_7'),\n",
       " ('word_0', 'word_4', 'word_8'),\n",
       " ('word_0', 'word_4', 'word_9'),\n",
       " ('word_0', 'word_4', 'word_10'),\n",
       " ('word_0', 'word_4', 'word_11'),\n",
       " ('word_0', 'word_4', 'word_12'),\n",
       " ('word_0', 'word_4', 'word_13'),\n",
       " ('word_0', 'word_4', 'word_14'),\n",
       " ('word_0', 'word_4', 'word_15'),\n",
       " ('word_0', 'word_4', 'word_16'),\n",
       " ('word_0', 'word_4', 'word_17'),\n",
       " ('word_0', 'word_4', 'word_18'),\n",
       " ('word_0', 'word_4', 'word_19'),\n",
       " ('word_0', 'word_5', 'word_1'),\n",
       " ('word_0', 'word_5', 'word_2'),\n",
       " ('word_0', 'word_5', 'word_3'),\n",
       " ('word_0', 'word_5', 'word_4'),\n",
       " ('word_0', 'word_5', 'word_6'),\n",
       " ('word_0', 'word_5', 'word_7'),\n",
       " ('word_0', 'word_5', 'word_8'),\n",
       " ('word_0', 'word_5', 'word_9'),\n",
       " ('word_0', 'word_5', 'word_10'),\n",
       " ('word_0', 'word_5', 'word_11'),\n",
       " ('word_0', 'word_5', 'word_12'),\n",
       " ('word_0', 'word_5', 'word_13'),\n",
       " ('word_0', 'word_5', 'word_14'),\n",
       " ('word_0', 'word_5', 'word_15'),\n",
       " ('word_0', 'word_5', 'word_16'),\n",
       " ('word_0', 'word_5', 'word_17'),\n",
       " ('word_0', 'word_5', 'word_18'),\n",
       " ('word_0', 'word_5', 'word_19'),\n",
       " ('word_0', 'word_6', 'word_1'),\n",
       " ('word_0', 'word_6', 'word_2'),\n",
       " ('word_0', 'word_6', 'word_3'),\n",
       " ('word_0', 'word_6', 'word_4'),\n",
       " ('word_0', 'word_6', 'word_5'),\n",
       " ('word_0', 'word_6', 'word_7'),\n",
       " ('word_0', 'word_6', 'word_8'),\n",
       " ('word_0', 'word_6', 'word_9'),\n",
       " ('word_0', 'word_6', 'word_10'),\n",
       " ('word_0', 'word_6', 'word_11'),\n",
       " ('word_0', 'word_6', 'word_12'),\n",
       " ('word_0', 'word_6', 'word_13'),\n",
       " ('word_0', 'word_6', 'word_14'),\n",
       " ('word_0', 'word_6', 'word_15'),\n",
       " ('word_0', 'word_6', 'word_16'),\n",
       " ('word_0', 'word_6', 'word_17'),\n",
       " ('word_0', 'word_6', 'word_18'),\n",
       " ('word_0', 'word_6', 'word_19'),\n",
       " ('word_0', 'word_7', 'word_1'),\n",
       " ('word_0', 'word_7', 'word_2'),\n",
       " ('word_0', 'word_7', 'word_3'),\n",
       " ('word_0', 'word_7', 'word_4'),\n",
       " ('word_0', 'word_7', 'word_5'),\n",
       " ('word_0', 'word_7', 'word_6'),\n",
       " ('word_0', 'word_7', 'word_8'),\n",
       " ('word_0', 'word_7', 'word_9'),\n",
       " ('word_0', 'word_7', 'word_10'),\n",
       " ('word_0', 'word_7', 'word_11'),\n",
       " ('word_0', 'word_7', 'word_12'),\n",
       " ('word_0', 'word_7', 'word_13'),\n",
       " ('word_0', 'word_7', 'word_14'),\n",
       " ('word_0', 'word_7', 'word_15'),\n",
       " ('word_0', 'word_7', 'word_16'),\n",
       " ('word_0', 'word_7', 'word_17'),\n",
       " ('word_0', 'word_7', 'word_18'),\n",
       " ('word_0', 'word_7', 'word_19'),\n",
       " ('word_0', 'word_8', 'word_1'),\n",
       " ('word_0', 'word_8', 'word_2'),\n",
       " ('word_0', 'word_8', 'word_3'),\n",
       " ('word_0', 'word_8', 'word_4'),\n",
       " ('word_0', 'word_8', 'word_5'),\n",
       " ('word_0', 'word_8', 'word_6'),\n",
       " ('word_0', 'word_8', 'word_7'),\n",
       " ('word_0', 'word_8', 'word_9'),\n",
       " ('word_0', 'word_8', 'word_10'),\n",
       " ('word_0', 'word_8', 'word_11'),\n",
       " ('word_0', 'word_8', 'word_12'),\n",
       " ('word_0', 'word_8', 'word_13'),\n",
       " ('word_0', 'word_8', 'word_14'),\n",
       " ('word_0', 'word_8', 'word_15'),\n",
       " ('word_0', 'word_8', 'word_16'),\n",
       " ('word_0', 'word_8', 'word_17'),\n",
       " ('word_0', 'word_8', 'word_18'),\n",
       " ('word_0', 'word_8', 'word_19'),\n",
       " ('word_0', 'word_9', 'word_1'),\n",
       " ('word_0', 'word_9', 'word_2'),\n",
       " ('word_0', 'word_9', 'word_3'),\n",
       " ('word_0', 'word_9', 'word_4'),\n",
       " ('word_0', 'word_9', 'word_5'),\n",
       " ('word_0', 'word_9', 'word_6'),\n",
       " ('word_0', 'word_9', 'word_7'),\n",
       " ('word_0', 'word_9', 'word_8'),\n",
       " ('word_0', 'word_9', 'word_10'),\n",
       " ('word_0', 'word_9', 'word_11'),\n",
       " ('word_0', 'word_9', 'word_12'),\n",
       " ('word_0', 'word_9', 'word_13'),\n",
       " ('word_0', 'word_9', 'word_14'),\n",
       " ('word_0', 'word_9', 'word_15'),\n",
       " ('word_0', 'word_9', 'word_16'),\n",
       " ('word_0', 'word_9', 'word_17'),\n",
       " ('word_0', 'word_9', 'word_18'),\n",
       " ('word_0', 'word_9', 'word_19'),\n",
       " ('word_0', 'word_10', 'word_1'),\n",
       " ('word_0', 'word_10', 'word_2'),\n",
       " ('word_0', 'word_10', 'word_3'),\n",
       " ('word_0', 'word_10', 'word_4'),\n",
       " ('word_0', 'word_10', 'word_5'),\n",
       " ('word_0', 'word_10', 'word_6'),\n",
       " ('word_0', 'word_10', 'word_7'),\n",
       " ('word_0', 'word_10', 'word_8'),\n",
       " ('word_0', 'word_10', 'word_9'),\n",
       " ('word_0', 'word_10', 'word_11'),\n",
       " ('word_0', 'word_10', 'word_12'),\n",
       " ('word_0', 'word_10', 'word_13'),\n",
       " ('word_0', 'word_10', 'word_14'),\n",
       " ('word_0', 'word_10', 'word_15'),\n",
       " ('word_0', 'word_10', 'word_16'),\n",
       " ('word_0', 'word_10', 'word_17'),\n",
       " ('word_0', 'word_10', 'word_18'),\n",
       " ('word_0', 'word_10', 'word_19'),\n",
       " ('word_0', 'word_11', 'word_1'),\n",
       " ('word_0', 'word_11', 'word_2'),\n",
       " ('word_0', 'word_11', 'word_3'),\n",
       " ('word_0', 'word_11', 'word_4'),\n",
       " ('word_0', 'word_11', 'word_5'),\n",
       " ('word_0', 'word_11', 'word_6'),\n",
       " ('word_0', 'word_11', 'word_7'),\n",
       " ('word_0', 'word_11', 'word_8'),\n",
       " ('word_0', 'word_11', 'word_9'),\n",
       " ('word_0', 'word_11', 'word_10'),\n",
       " ('word_0', 'word_11', 'word_12'),\n",
       " ('word_0', 'word_11', 'word_13'),\n",
       " ('word_0', 'word_11', 'word_14'),\n",
       " ('word_0', 'word_11', 'word_15'),\n",
       " ('word_0', 'word_11', 'word_16'),\n",
       " ('word_0', 'word_11', 'word_17'),\n",
       " ('word_0', 'word_11', 'word_18'),\n",
       " ('word_0', 'word_11', 'word_19'),\n",
       " ('word_0', 'word_12', 'word_1'),\n",
       " ('word_0', 'word_12', 'word_2'),\n",
       " ('word_0', 'word_12', 'word_3'),\n",
       " ('word_0', 'word_12', 'word_4'),\n",
       " ('word_0', 'word_12', 'word_5'),\n",
       " ('word_0', 'word_12', 'word_6'),\n",
       " ('word_0', 'word_12', 'word_7'),\n",
       " ('word_0', 'word_12', 'word_8'),\n",
       " ('word_0', 'word_12', 'word_9'),\n",
       " ('word_0', 'word_12', 'word_10'),\n",
       " ('word_0', 'word_12', 'word_11'),\n",
       " ('word_0', 'word_12', 'word_13'),\n",
       " ('word_0', 'word_12', 'word_14'),\n",
       " ('word_0', 'word_12', 'word_15'),\n",
       " ('word_0', 'word_12', 'word_16'),\n",
       " ('word_0', 'word_12', 'word_17'),\n",
       " ('word_0', 'word_12', 'word_18'),\n",
       " ('word_0', 'word_12', 'word_19'),\n",
       " ('word_0', 'word_13', 'word_1'),\n",
       " ('word_0', 'word_13', 'word_2'),\n",
       " ('word_0', 'word_13', 'word_3'),\n",
       " ('word_0', 'word_13', 'word_4'),\n",
       " ('word_0', 'word_13', 'word_5'),\n",
       " ('word_0', 'word_13', 'word_6'),\n",
       " ('word_0', 'word_13', 'word_7'),\n",
       " ('word_0', 'word_13', 'word_8'),\n",
       " ('word_0', 'word_13', 'word_9'),\n",
       " ('word_0', 'word_13', 'word_10'),\n",
       " ('word_0', 'word_13', 'word_11'),\n",
       " ('word_0', 'word_13', 'word_12'),\n",
       " ('word_0', 'word_13', 'word_14'),\n",
       " ('word_0', 'word_13', 'word_15'),\n",
       " ('word_0', 'word_13', 'word_16'),\n",
       " ('word_0', 'word_13', 'word_17'),\n",
       " ('word_0', 'word_13', 'word_18'),\n",
       " ('word_0', 'word_13', 'word_19'),\n",
       " ('word_0', 'word_14', 'word_1'),\n",
       " ('word_0', 'word_14', 'word_2'),\n",
       " ('word_0', 'word_14', 'word_3'),\n",
       " ('word_0', 'word_14', 'word_4'),\n",
       " ('word_0', 'word_14', 'word_5'),\n",
       " ('word_0', 'word_14', 'word_6'),\n",
       " ('word_0', 'word_14', 'word_7'),\n",
       " ('word_0', 'word_14', 'word_8'),\n",
       " ('word_0', 'word_14', 'word_9'),\n",
       " ('word_0', 'word_14', 'word_10'),\n",
       " ('word_0', 'word_14', 'word_11'),\n",
       " ('word_0', 'word_14', 'word_12'),\n",
       " ('word_0', 'word_14', 'word_13'),\n",
       " ('word_0', 'word_14', 'word_15'),\n",
       " ('word_0', 'word_14', 'word_16'),\n",
       " ('word_0', 'word_14', 'word_17'),\n",
       " ('word_0', 'word_14', 'word_18'),\n",
       " ('word_0', 'word_14', 'word_19'),\n",
       " ('word_0', 'word_15', 'word_1'),\n",
       " ('word_0', 'word_15', 'word_2'),\n",
       " ('word_0', 'word_15', 'word_3'),\n",
       " ('word_0', 'word_15', 'word_4'),\n",
       " ('word_0', 'word_15', 'word_5'),\n",
       " ('word_0', 'word_15', 'word_6'),\n",
       " ('word_0', 'word_15', 'word_7'),\n",
       " ('word_0', 'word_15', 'word_8'),\n",
       " ('word_0', 'word_15', 'word_9'),\n",
       " ('word_0', 'word_15', 'word_10'),\n",
       " ('word_0', 'word_15', 'word_11'),\n",
       " ('word_0', 'word_15', 'word_12'),\n",
       " ('word_0', 'word_15', 'word_13'),\n",
       " ('word_0', 'word_15', 'word_14'),\n",
       " ('word_0', 'word_15', 'word_16'),\n",
       " ('word_0', 'word_15', 'word_17'),\n",
       " ('word_0', 'word_15', 'word_18'),\n",
       " ('word_0', 'word_15', 'word_19'),\n",
       " ('word_0', 'word_16', 'word_1'),\n",
       " ('word_0', 'word_16', 'word_2'),\n",
       " ('word_0', 'word_16', 'word_3'),\n",
       " ('word_0', 'word_16', 'word_4'),\n",
       " ('word_0', 'word_16', 'word_5'),\n",
       " ('word_0', 'word_16', 'word_6'),\n",
       " ('word_0', 'word_16', 'word_7'),\n",
       " ('word_0', 'word_16', 'word_8'),\n",
       " ('word_0', 'word_16', 'word_9'),\n",
       " ('word_0', 'word_16', 'word_10'),\n",
       " ('word_0', 'word_16', 'word_11'),\n",
       " ('word_0', 'word_16', 'word_12'),\n",
       " ('word_0', 'word_16', 'word_13'),\n",
       " ('word_0', 'word_16', 'word_14'),\n",
       " ('word_0', 'word_16', 'word_15'),\n",
       " ('word_0', 'word_16', 'word_17'),\n",
       " ('word_0', 'word_16', 'word_18'),\n",
       " ('word_0', 'word_16', 'word_19'),\n",
       " ('word_0', 'word_17', 'word_1'),\n",
       " ('word_0', 'word_17', 'word_2'),\n",
       " ('word_0', 'word_17', 'word_3'),\n",
       " ('word_0', 'word_17', 'word_4'),\n",
       " ('word_0', 'word_17', 'word_5'),\n",
       " ('word_0', 'word_17', 'word_6'),\n",
       " ('word_0', 'word_17', 'word_7'),\n",
       " ('word_0', 'word_17', 'word_8'),\n",
       " ('word_0', 'word_17', 'word_9'),\n",
       " ('word_0', 'word_17', 'word_10'),\n",
       " ('word_0', 'word_17', 'word_11'),\n",
       " ('word_0', 'word_17', 'word_12'),\n",
       " ('word_0', 'word_17', 'word_13'),\n",
       " ('word_0', 'word_17', 'word_14'),\n",
       " ('word_0', 'word_17', 'word_15'),\n",
       " ('word_0', 'word_17', 'word_16'),\n",
       " ('word_0', 'word_17', 'word_18'),\n",
       " ('word_0', 'word_17', 'word_19'),\n",
       " ('word_0', 'word_18', 'word_1'),\n",
       " ('word_0', 'word_18', 'word_2'),\n",
       " ('word_0', 'word_18', 'word_3'),\n",
       " ('word_0', 'word_18', 'word_4'),\n",
       " ('word_0', 'word_18', 'word_5'),\n",
       " ('word_0', 'word_18', 'word_6'),\n",
       " ('word_0', 'word_18', 'word_7'),\n",
       " ('word_0', 'word_18', 'word_8'),\n",
       " ('word_0', 'word_18', 'word_9'),\n",
       " ('word_0', 'word_18', 'word_10'),\n",
       " ('word_0', 'word_18', 'word_11'),\n",
       " ('word_0', 'word_18', 'word_12'),\n",
       " ('word_0', 'word_18', 'word_13'),\n",
       " ('word_0', 'word_18', 'word_14'),\n",
       " ('word_0', 'word_18', 'word_15'),\n",
       " ('word_0', 'word_18', 'word_16'),\n",
       " ('word_0', 'word_18', 'word_17'),\n",
       " ('word_0', 'word_18', 'word_19'),\n",
       " ('word_0', 'word_19', 'word_1'),\n",
       " ('word_0', 'word_19', 'word_2'),\n",
       " ('word_0', 'word_19', 'word_3'),\n",
       " ('word_0', 'word_19', 'word_4'),\n",
       " ('word_0', 'word_19', 'word_5'),\n",
       " ('word_0', 'word_19', 'word_6'),\n",
       " ('word_0', 'word_19', 'word_7'),\n",
       " ('word_0', 'word_19', 'word_8'),\n",
       " ('word_0', 'word_19', 'word_9'),\n",
       " ('word_0', 'word_19', 'word_10'),\n",
       " ('word_0', 'word_19', 'word_11'),\n",
       " ('word_0', 'word_19', 'word_12'),\n",
       " ('word_0', 'word_19', 'word_13'),\n",
       " ('word_0', 'word_19', 'word_14'),\n",
       " ('word_0', 'word_19', 'word_15'),\n",
       " ('word_0', 'word_19', 'word_16'),\n",
       " ('word_0', 'word_19', 'word_17'),\n",
       " ('word_0', 'word_19', 'word_18'),\n",
       " ('word_1', 'word_0', 'word_2'),\n",
       " ('word_1', 'word_0', 'word_3'),\n",
       " ('word_1', 'word_0', 'word_4'),\n",
       " ('word_1', 'word_0', 'word_5'),\n",
       " ('word_1', 'word_0', 'word_6'),\n",
       " ('word_1', 'word_0', 'word_7'),\n",
       " ('word_1', 'word_0', 'word_8'),\n",
       " ('word_1', 'word_0', 'word_9'),\n",
       " ('word_1', 'word_0', 'word_10'),\n",
       " ('word_1', 'word_0', 'word_11'),\n",
       " ('word_1', 'word_0', 'word_12'),\n",
       " ('word_1', 'word_0', 'word_13'),\n",
       " ('word_1', 'word_0', 'word_14'),\n",
       " ('word_1', 'word_0', 'word_15'),\n",
       " ('word_1', 'word_0', 'word_16'),\n",
       " ('word_1', 'word_0', 'word_17'),\n",
       " ('word_1', 'word_0', 'word_18'),\n",
       " ('word_1', 'word_0', 'word_19'),\n",
       " ('word_1', 'word_2', 'word_0'),\n",
       " ('word_1', 'word_2', 'word_3'),\n",
       " ('word_1', 'word_2', 'word_4'),\n",
       " ('word_1', 'word_2', 'word_5'),\n",
       " ('word_1', 'word_2', 'word_6'),\n",
       " ('word_1', 'word_2', 'word_7'),\n",
       " ('word_1', 'word_2', 'word_8'),\n",
       " ('word_1', 'word_2', 'word_9'),\n",
       " ('word_1', 'word_2', 'word_10'),\n",
       " ('word_1', 'word_2', 'word_11'),\n",
       " ('word_1', 'word_2', 'word_12'),\n",
       " ('word_1', 'word_2', 'word_13'),\n",
       " ('word_1', 'word_2', 'word_14'),\n",
       " ('word_1', 'word_2', 'word_15'),\n",
       " ('word_1', 'word_2', 'word_16'),\n",
       " ('word_1', 'word_2', 'word_17'),\n",
       " ('word_1', 'word_2', 'word_18'),\n",
       " ('word_1', 'word_2', 'word_19'),\n",
       " ('word_1', 'word_3', 'word_0'),\n",
       " ('word_1', 'word_3', 'word_2'),\n",
       " ('word_1', 'word_3', 'word_4'),\n",
       " ('word_1', 'word_3', 'word_5'),\n",
       " ('word_1', 'word_3', 'word_6'),\n",
       " ('word_1', 'word_3', 'word_7'),\n",
       " ('word_1', 'word_3', 'word_8'),\n",
       " ('word_1', 'word_3', 'word_9'),\n",
       " ('word_1', 'word_3', 'word_10'),\n",
       " ('word_1', 'word_3', 'word_11'),\n",
       " ('word_1', 'word_3', 'word_12'),\n",
       " ('word_1', 'word_3', 'word_13'),\n",
       " ('word_1', 'word_3', 'word_14'),\n",
       " ('word_1', 'word_3', 'word_15'),\n",
       " ('word_1', 'word_3', 'word_16'),\n",
       " ('word_1', 'word_3', 'word_17'),\n",
       " ('word_1', 'word_3', 'word_18'),\n",
       " ('word_1', 'word_3', 'word_19'),\n",
       " ('word_1', 'word_4', 'word_0'),\n",
       " ('word_1', 'word_4', 'word_2'),\n",
       " ('word_1', 'word_4', 'word_3'),\n",
       " ('word_1', 'word_4', 'word_5'),\n",
       " ('word_1', 'word_4', 'word_6'),\n",
       " ('word_1', 'word_4', 'word_7'),\n",
       " ('word_1', 'word_4', 'word_8'),\n",
       " ('word_1', 'word_4', 'word_9'),\n",
       " ('word_1', 'word_4', 'word_10'),\n",
       " ('word_1', 'word_4', 'word_11'),\n",
       " ('word_1', 'word_4', 'word_12'),\n",
       " ('word_1', 'word_4', 'word_13'),\n",
       " ('word_1', 'word_4', 'word_14'),\n",
       " ('word_1', 'word_4', 'word_15'),\n",
       " ('word_1', 'word_4', 'word_16'),\n",
       " ('word_1', 'word_4', 'word_17'),\n",
       " ('word_1', 'word_4', 'word_18'),\n",
       " ('word_1', 'word_4', 'word_19'),\n",
       " ('word_1', 'word_5', 'word_0'),\n",
       " ('word_1', 'word_5', 'word_2'),\n",
       " ('word_1', 'word_5', 'word_3'),\n",
       " ('word_1', 'word_5', 'word_4'),\n",
       " ('word_1', 'word_5', 'word_6'),\n",
       " ('word_1', 'word_5', 'word_7'),\n",
       " ('word_1', 'word_5', 'word_8'),\n",
       " ('word_1', 'word_5', 'word_9'),\n",
       " ('word_1', 'word_5', 'word_10'),\n",
       " ('word_1', 'word_5', 'word_11'),\n",
       " ('word_1', 'word_5', 'word_12'),\n",
       " ('word_1', 'word_5', 'word_13'),\n",
       " ('word_1', 'word_5', 'word_14'),\n",
       " ('word_1', 'word_5', 'word_15'),\n",
       " ('word_1', 'word_5', 'word_16'),\n",
       " ('word_1', 'word_5', 'word_17'),\n",
       " ('word_1', 'word_5', 'word_18'),\n",
       " ('word_1', 'word_5', 'word_19'),\n",
       " ('word_1', 'word_6', 'word_0'),\n",
       " ('word_1', 'word_6', 'word_2'),\n",
       " ('word_1', 'word_6', 'word_3'),\n",
       " ('word_1', 'word_6', 'word_4'),\n",
       " ('word_1', 'word_6', 'word_5'),\n",
       " ('word_1', 'word_6', 'word_7'),\n",
       " ('word_1', 'word_6', 'word_8'),\n",
       " ('word_1', 'word_6', 'word_9'),\n",
       " ('word_1', 'word_6', 'word_10'),\n",
       " ('word_1', 'word_6', 'word_11'),\n",
       " ('word_1', 'word_6', 'word_12'),\n",
       " ('word_1', 'word_6', 'word_13'),\n",
       " ('word_1', 'word_6', 'word_14'),\n",
       " ('word_1', 'word_6', 'word_15'),\n",
       " ('word_1', 'word_6', 'word_16'),\n",
       " ('word_1', 'word_6', 'word_17'),\n",
       " ('word_1', 'word_6', 'word_18'),\n",
       " ('word_1', 'word_6', 'word_19'),\n",
       " ('word_1', 'word_7', 'word_0'),\n",
       " ('word_1', 'word_7', 'word_2'),\n",
       " ('word_1', 'word_7', 'word_3'),\n",
       " ('word_1', 'word_7', 'word_4'),\n",
       " ('word_1', 'word_7', 'word_5'),\n",
       " ('word_1', 'word_7', 'word_6'),\n",
       " ('word_1', 'word_7', 'word_8'),\n",
       " ('word_1', 'word_7', 'word_9'),\n",
       " ('word_1', 'word_7', 'word_10'),\n",
       " ('word_1', 'word_7', 'word_11'),\n",
       " ('word_1', 'word_7', 'word_12'),\n",
       " ('word_1', 'word_7', 'word_13'),\n",
       " ('word_1', 'word_7', 'word_14'),\n",
       " ('word_1', 'word_7', 'word_15'),\n",
       " ('word_1', 'word_7', 'word_16'),\n",
       " ('word_1', 'word_7', 'word_17'),\n",
       " ('word_1', 'word_7', 'word_18'),\n",
       " ('word_1', 'word_7', 'word_19'),\n",
       " ('word_1', 'word_8', 'word_0'),\n",
       " ('word_1', 'word_8', 'word_2'),\n",
       " ('word_1', 'word_8', 'word_3'),\n",
       " ('word_1', 'word_8', 'word_4'),\n",
       " ('word_1', 'word_8', 'word_5'),\n",
       " ('word_1', 'word_8', 'word_6'),\n",
       " ('word_1', 'word_8', 'word_7'),\n",
       " ('word_1', 'word_8', 'word_9'),\n",
       " ('word_1', 'word_8', 'word_10'),\n",
       " ('word_1', 'word_8', 'word_11'),\n",
       " ('word_1', 'word_8', 'word_12'),\n",
       " ('word_1', 'word_8', 'word_13'),\n",
       " ('word_1', 'word_8', 'word_14'),\n",
       " ('word_1', 'word_8', 'word_15'),\n",
       " ('word_1', 'word_8', 'word_16'),\n",
       " ('word_1', 'word_8', 'word_17'),\n",
       " ('word_1', 'word_8', 'word_18'),\n",
       " ('word_1', 'word_8', 'word_19'),\n",
       " ('word_1', 'word_9', 'word_0'),\n",
       " ('word_1', 'word_9', 'word_2'),\n",
       " ('word_1', 'word_9', 'word_3'),\n",
       " ('word_1', 'word_9', 'word_4'),\n",
       " ('word_1', 'word_9', 'word_5'),\n",
       " ('word_1', 'word_9', 'word_6'),\n",
       " ('word_1', 'word_9', 'word_7'),\n",
       " ('word_1', 'word_9', 'word_8'),\n",
       " ('word_1', 'word_9', 'word_10'),\n",
       " ('word_1', 'word_9', 'word_11'),\n",
       " ('word_1', 'word_9', 'word_12'),\n",
       " ('word_1', 'word_9', 'word_13'),\n",
       " ('word_1', 'word_9', 'word_14'),\n",
       " ('word_1', 'word_9', 'word_15'),\n",
       " ('word_1', 'word_9', 'word_16'),\n",
       " ('word_1', 'word_9', 'word_17'),\n",
       " ('word_1', 'word_9', 'word_18'),\n",
       " ('word_1', 'word_9', 'word_19'),\n",
       " ('word_1', 'word_10', 'word_0'),\n",
       " ('word_1', 'word_10', 'word_2'),\n",
       " ('word_1', 'word_10', 'word_3'),\n",
       " ('word_1', 'word_10', 'word_4'),\n",
       " ('word_1', 'word_10', 'word_5'),\n",
       " ('word_1', 'word_10', 'word_6'),\n",
       " ('word_1', 'word_10', 'word_7'),\n",
       " ('word_1', 'word_10', 'word_8'),\n",
       " ('word_1', 'word_10', 'word_9'),\n",
       " ('word_1', 'word_10', 'word_11'),\n",
       " ('word_1', 'word_10', 'word_12'),\n",
       " ('word_1', 'word_10', 'word_13'),\n",
       " ('word_1', 'word_10', 'word_14'),\n",
       " ('word_1', 'word_10', 'word_15'),\n",
       " ('word_1', 'word_10', 'word_16'),\n",
       " ('word_1', 'word_10', 'word_17'),\n",
       " ('word_1', 'word_10', 'word_18'),\n",
       " ('word_1', 'word_10', 'word_19'),\n",
       " ('word_1', 'word_11', 'word_0'),\n",
       " ('word_1', 'word_11', 'word_2'),\n",
       " ('word_1', 'word_11', 'word_3'),\n",
       " ('word_1', 'word_11', 'word_4'),\n",
       " ('word_1', 'word_11', 'word_5'),\n",
       " ('word_1', 'word_11', 'word_6'),\n",
       " ('word_1', 'word_11', 'word_7'),\n",
       " ('word_1', 'word_11', 'word_8'),\n",
       " ('word_1', 'word_11', 'word_9'),\n",
       " ('word_1', 'word_11', 'word_10'),\n",
       " ('word_1', 'word_11', 'word_12'),\n",
       " ('word_1', 'word_11', 'word_13'),\n",
       " ('word_1', 'word_11', 'word_14'),\n",
       " ('word_1', 'word_11', 'word_15'),\n",
       " ('word_1', 'word_11', 'word_16'),\n",
       " ('word_1', 'word_11', 'word_17'),\n",
       " ('word_1', 'word_11', 'word_18'),\n",
       " ('word_1', 'word_11', 'word_19'),\n",
       " ('word_1', 'word_12', 'word_0'),\n",
       " ('word_1', 'word_12', 'word_2'),\n",
       " ('word_1', 'word_12', 'word_3'),\n",
       " ('word_1', 'word_12', 'word_4'),\n",
       " ('word_1', 'word_12', 'word_5'),\n",
       " ('word_1', 'word_12', 'word_6'),\n",
       " ('word_1', 'word_12', 'word_7'),\n",
       " ('word_1', 'word_12', 'word_8'),\n",
       " ('word_1', 'word_12', 'word_9'),\n",
       " ('word_1', 'word_12', 'word_10'),\n",
       " ('word_1', 'word_12', 'word_11'),\n",
       " ('word_1', 'word_12', 'word_13'),\n",
       " ('word_1', 'word_12', 'word_14'),\n",
       " ('word_1', 'word_12', 'word_15'),\n",
       " ('word_1', 'word_12', 'word_16'),\n",
       " ('word_1', 'word_12', 'word_17'),\n",
       " ('word_1', 'word_12', 'word_18'),\n",
       " ('word_1', 'word_12', 'word_19'),\n",
       " ('word_1', 'word_13', 'word_0'),\n",
       " ('word_1', 'word_13', 'word_2'),\n",
       " ('word_1', 'word_13', 'word_3'),\n",
       " ('word_1', 'word_13', 'word_4'),\n",
       " ('word_1', 'word_13', 'word_5'),\n",
       " ('word_1', 'word_13', 'word_6'),\n",
       " ('word_1', 'word_13', 'word_7'),\n",
       " ('word_1', 'word_13', 'word_8'),\n",
       " ('word_1', 'word_13', 'word_9'),\n",
       " ('word_1', 'word_13', 'word_10'),\n",
       " ('word_1', 'word_13', 'word_11'),\n",
       " ('word_1', 'word_13', 'word_12'),\n",
       " ('word_1', 'word_13', 'word_14'),\n",
       " ('word_1', 'word_13', 'word_15'),\n",
       " ('word_1', 'word_13', 'word_16'),\n",
       " ('word_1', 'word_13', 'word_17'),\n",
       " ('word_1', 'word_13', 'word_18'),\n",
       " ('word_1', 'word_13', 'word_19'),\n",
       " ('word_1', 'word_14', 'word_0'),\n",
       " ('word_1', 'word_14', 'word_2'),\n",
       " ('word_1', 'word_14', 'word_3'),\n",
       " ('word_1', 'word_14', 'word_4'),\n",
       " ('word_1', 'word_14', 'word_5'),\n",
       " ('word_1', 'word_14', 'word_6'),\n",
       " ('word_1', 'word_14', 'word_7'),\n",
       " ('word_1', 'word_14', 'word_8'),\n",
       " ('word_1', 'word_14', 'word_9'),\n",
       " ('word_1', 'word_14', 'word_10'),\n",
       " ('word_1', 'word_14', 'word_11'),\n",
       " ('word_1', 'word_14', 'word_12'),\n",
       " ('word_1', 'word_14', 'word_13'),\n",
       " ('word_1', 'word_14', 'word_15'),\n",
       " ('word_1', 'word_14', 'word_16'),\n",
       " ('word_1', 'word_14', 'word_17'),\n",
       " ('word_1', 'word_14', 'word_18'),\n",
       " ('word_1', 'word_14', 'word_19'),\n",
       " ('word_1', 'word_15', 'word_0'),\n",
       " ('word_1', 'word_15', 'word_2'),\n",
       " ('word_1', 'word_15', 'word_3'),\n",
       " ('word_1', 'word_15', 'word_4'),\n",
       " ('word_1', 'word_15', 'word_5'),\n",
       " ('word_1', 'word_15', 'word_6'),\n",
       " ('word_1', 'word_15', 'word_7'),\n",
       " ('word_1', 'word_15', 'word_8'),\n",
       " ('word_1', 'word_15', 'word_9'),\n",
       " ('word_1', 'word_15', 'word_10'),\n",
       " ('word_1', 'word_15', 'word_11'),\n",
       " ('word_1', 'word_15', 'word_12'),\n",
       " ('word_1', 'word_15', 'word_13'),\n",
       " ('word_1', 'word_15', 'word_14'),\n",
       " ('word_1', 'word_15', 'word_16'),\n",
       " ('word_1', 'word_15', 'word_17'),\n",
       " ('word_1', 'word_15', 'word_18'),\n",
       " ('word_1', 'word_15', 'word_19'),\n",
       " ('word_1', 'word_16', 'word_0'),\n",
       " ('word_1', 'word_16', 'word_2'),\n",
       " ('word_1', 'word_16', 'word_3'),\n",
       " ('word_1', 'word_16', 'word_4'),\n",
       " ('word_1', 'word_16', 'word_5'),\n",
       " ('word_1', 'word_16', 'word_6'),\n",
       " ('word_1', 'word_16', 'word_7'),\n",
       " ('word_1', 'word_16', 'word_8'),\n",
       " ('word_1', 'word_16', 'word_9'),\n",
       " ('word_1', 'word_16', 'word_10'),\n",
       " ('word_1', 'word_16', 'word_11'),\n",
       " ('word_1', 'word_16', 'word_12'),\n",
       " ('word_1', 'word_16', 'word_13'),\n",
       " ('word_1', 'word_16', 'word_14'),\n",
       " ('word_1', 'word_16', 'word_15'),\n",
       " ('word_1', 'word_16', 'word_17'),\n",
       " ('word_1', 'word_16', 'word_18'),\n",
       " ('word_1', 'word_16', 'word_19'),\n",
       " ('word_1', 'word_17', 'word_0'),\n",
       " ('word_1', 'word_17', 'word_2'),\n",
       " ('word_1', 'word_17', 'word_3'),\n",
       " ('word_1', 'word_17', 'word_4'),\n",
       " ('word_1', 'word_17', 'word_5'),\n",
       " ('word_1', 'word_17', 'word_6'),\n",
       " ('word_1', 'word_17', 'word_7'),\n",
       " ('word_1', 'word_17', 'word_8'),\n",
       " ('word_1', 'word_17', 'word_9'),\n",
       " ('word_1', 'word_17', 'word_10'),\n",
       " ('word_1', 'word_17', 'word_11'),\n",
       " ('word_1', 'word_17', 'word_12'),\n",
       " ('word_1', 'word_17', 'word_13'),\n",
       " ('word_1', 'word_17', 'word_14'),\n",
       " ('word_1', 'word_17', 'word_15'),\n",
       " ('word_1', 'word_17', 'word_16'),\n",
       " ('word_1', 'word_17', 'word_18'),\n",
       " ('word_1', 'word_17', 'word_19'),\n",
       " ('word_1', 'word_18', 'word_0'),\n",
       " ('word_1', 'word_18', 'word_2'),\n",
       " ('word_1', 'word_18', 'word_3'),\n",
       " ('word_1', 'word_18', 'word_4'),\n",
       " ('word_1', 'word_18', 'word_5'),\n",
       " ('word_1', 'word_18', 'word_6'),\n",
       " ('word_1', 'word_18', 'word_7'),\n",
       " ('word_1', 'word_18', 'word_8'),\n",
       " ('word_1', 'word_18', 'word_9'),\n",
       " ('word_1', 'word_18', 'word_10'),\n",
       " ('word_1', 'word_18', 'word_11'),\n",
       " ('word_1', 'word_18', 'word_12'),\n",
       " ('word_1', 'word_18', 'word_13'),\n",
       " ('word_1', 'word_18', 'word_14'),\n",
       " ('word_1', 'word_18', 'word_15'),\n",
       " ('word_1', 'word_18', 'word_16'),\n",
       " ('word_1', 'word_18', 'word_17'),\n",
       " ('word_1', 'word_18', 'word_19'),\n",
       " ('word_1', 'word_19', 'word_0'),\n",
       " ('word_1', 'word_19', 'word_2'),\n",
       " ('word_1', 'word_19', 'word_3'),\n",
       " ('word_1', 'word_19', 'word_4'),\n",
       " ('word_1', 'word_19', 'word_5'),\n",
       " ('word_1', 'word_19', 'word_6'),\n",
       " ('word_1', 'word_19', 'word_7'),\n",
       " ('word_1', 'word_19', 'word_8'),\n",
       " ('word_1', 'word_19', 'word_9'),\n",
       " ('word_1', 'word_19', 'word_10'),\n",
       " ('word_1', 'word_19', 'word_11'),\n",
       " ('word_1', 'word_19', 'word_12'),\n",
       " ('word_1', 'word_19', 'word_13'),\n",
       " ('word_1', 'word_19', 'word_14'),\n",
       " ('word_1', 'word_19', 'word_15'),\n",
       " ('word_1', 'word_19', 'word_16'),\n",
       " ('word_1', 'word_19', 'word_17'),\n",
       " ('word_1', 'word_19', 'word_18'),\n",
       " ('word_2', 'word_0', 'word_1'),\n",
       " ('word_2', 'word_0', 'word_3'),\n",
       " ('word_2', 'word_0', 'word_4'),\n",
       " ('word_2', 'word_0', 'word_5'),\n",
       " ('word_2', 'word_0', 'word_6'),\n",
       " ('word_2', 'word_0', 'word_7'),\n",
       " ('word_2', 'word_0', 'word_8'),\n",
       " ('word_2', 'word_0', 'word_9'),\n",
       " ('word_2', 'word_0', 'word_10'),\n",
       " ('word_2', 'word_0', 'word_11'),\n",
       " ('word_2', 'word_0', 'word_12'),\n",
       " ('word_2', 'word_0', 'word_13'),\n",
       " ('word_2', 'word_0', 'word_14'),\n",
       " ('word_2', 'word_0', 'word_15'),\n",
       " ('word_2', 'word_0', 'word_16'),\n",
       " ('word_2', 'word_0', 'word_17'),\n",
       " ('word_2', 'word_0', 'word_18'),\n",
       " ('word_2', 'word_0', 'word_19'),\n",
       " ('word_2', 'word_1', 'word_0'),\n",
       " ('word_2', 'word_1', 'word_3'),\n",
       " ('word_2', 'word_1', 'word_4'),\n",
       " ('word_2', 'word_1', 'word_5'),\n",
       " ('word_2', 'word_1', 'word_6'),\n",
       " ('word_2', 'word_1', 'word_7'),\n",
       " ('word_2', 'word_1', 'word_8'),\n",
       " ('word_2', 'word_1', 'word_9'),\n",
       " ('word_2', 'word_1', 'word_10'),\n",
       " ('word_2', 'word_1', 'word_11'),\n",
       " ('word_2', 'word_1', 'word_12'),\n",
       " ('word_2', 'word_1', 'word_13'),\n",
       " ('word_2', 'word_1', 'word_14'),\n",
       " ('word_2', 'word_1', 'word_15'),\n",
       " ('word_2', 'word_1', 'word_16'),\n",
       " ('word_2', 'word_1', 'word_17'),\n",
       " ('word_2', 'word_1', 'word_18'),\n",
       " ('word_2', 'word_1', 'word_19'),\n",
       " ('word_2', 'word_3', 'word_0'),\n",
       " ('word_2', 'word_3', 'word_1'),\n",
       " ('word_2', 'word_3', 'word_4'),\n",
       " ('word_2', 'word_3', 'word_5'),\n",
       " ('word_2', 'word_3', 'word_6'),\n",
       " ('word_2', 'word_3', 'word_7'),\n",
       " ('word_2', 'word_3', 'word_8'),\n",
       " ('word_2', 'word_3', 'word_9'),\n",
       " ('word_2', 'word_3', 'word_10'),\n",
       " ('word_2', 'word_3', 'word_11'),\n",
       " ('word_2', 'word_3', 'word_12'),\n",
       " ('word_2', 'word_3', 'word_13'),\n",
       " ('word_2', 'word_3', 'word_14'),\n",
       " ('word_2', 'word_3', 'word_15'),\n",
       " ('word_2', 'word_3', 'word_16'),\n",
       " ('word_2', 'word_3', 'word_17'),\n",
       " ('word_2', 'word_3', 'word_18'),\n",
       " ('word_2', 'word_3', 'word_19'),\n",
       " ('word_2', 'word_4', 'word_0'),\n",
       " ('word_2', 'word_4', 'word_1'),\n",
       " ('word_2', 'word_4', 'word_3'),\n",
       " ('word_2', 'word_4', 'word_5'),\n",
       " ('word_2', 'word_4', 'word_6'),\n",
       " ('word_2', 'word_4', 'word_7'),\n",
       " ('word_2', 'word_4', 'word_8'),\n",
       " ('word_2', 'word_4', 'word_9'),\n",
       " ('word_2', 'word_4', 'word_10'),\n",
       " ('word_2', 'word_4', 'word_11'),\n",
       " ('word_2', 'word_4', 'word_12'),\n",
       " ('word_2', 'word_4', 'word_13'),\n",
       " ('word_2', 'word_4', 'word_14'),\n",
       " ('word_2', 'word_4', 'word_15'),\n",
       " ('word_2', 'word_4', 'word_16'),\n",
       " ('word_2', 'word_4', 'word_17'),\n",
       " ('word_2', 'word_4', 'word_18'),\n",
       " ('word_2', 'word_4', 'word_19'),\n",
       " ('word_2', 'word_5', 'word_0'),\n",
       " ('word_2', 'word_5', 'word_1'),\n",
       " ('word_2', 'word_5', 'word_3'),\n",
       " ('word_2', 'word_5', 'word_4'),\n",
       " ('word_2', 'word_5', 'word_6'),\n",
       " ('word_2', 'word_5', 'word_7'),\n",
       " ('word_2', 'word_5', 'word_8'),\n",
       " ('word_2', 'word_5', 'word_9'),\n",
       " ('word_2', 'word_5', 'word_10'),\n",
       " ('word_2', 'word_5', 'word_11'),\n",
       " ('word_2', 'word_5', 'word_12'),\n",
       " ('word_2', 'word_5', 'word_13'),\n",
       " ('word_2', 'word_5', 'word_14'),\n",
       " ('word_2', 'word_5', 'word_15'),\n",
       " ('word_2', 'word_5', 'word_16'),\n",
       " ('word_2', 'word_5', 'word_17'),\n",
       " ('word_2', 'word_5', 'word_18'),\n",
       " ('word_2', 'word_5', 'word_19'),\n",
       " ('word_2', 'word_6', 'word_0'),\n",
       " ('word_2', 'word_6', 'word_1'),\n",
       " ('word_2', 'word_6', 'word_3'),\n",
       " ('word_2', 'word_6', 'word_4'),\n",
       " ('word_2', 'word_6', 'word_5'),\n",
       " ('word_2', 'word_6', 'word_7'),\n",
       " ('word_2', 'word_6', 'word_8'),\n",
       " ('word_2', 'word_6', 'word_9'),\n",
       " ('word_2', 'word_6', 'word_10'),\n",
       " ('word_2', 'word_6', 'word_11'),\n",
       " ('word_2', 'word_6', 'word_12'),\n",
       " ('word_2', 'word_6', 'word_13'),\n",
       " ('word_2', 'word_6', 'word_14'),\n",
       " ('word_2', 'word_6', 'word_15'),\n",
       " ('word_2', 'word_6', 'word_16'),\n",
       " ('word_2', 'word_6', 'word_17'),\n",
       " ('word_2', 'word_6', 'word_18'),\n",
       " ('word_2', 'word_6', 'word_19'),\n",
       " ('word_2', 'word_7', 'word_0'),\n",
       " ('word_2', 'word_7', 'word_1'),\n",
       " ('word_2', 'word_7', 'word_3'),\n",
       " ('word_2', 'word_7', 'word_4'),\n",
       " ('word_2', 'word_7', 'word_5'),\n",
       " ('word_2', 'word_7', 'word_6'),\n",
       " ('word_2', 'word_7', 'word_8'),\n",
       " ('word_2', 'word_7', 'word_9'),\n",
       " ('word_2', 'word_7', 'word_10'),\n",
       " ('word_2', 'word_7', 'word_11'),\n",
       " ('word_2', 'word_7', 'word_12'),\n",
       " ('word_2', 'word_7', 'word_13'),\n",
       " ('word_2', 'word_7', 'word_14'),\n",
       " ('word_2', 'word_7', 'word_15'),\n",
       " ('word_2', 'word_7', 'word_16'),\n",
       " ('word_2', 'word_7', 'word_17'),\n",
       " ('word_2', 'word_7', 'word_18'),\n",
       " ('word_2', 'word_7', 'word_19'),\n",
       " ('word_2', 'word_8', 'word_0'),\n",
       " ('word_2', 'word_8', 'word_1'),\n",
       " ('word_2', 'word_8', 'word_3'),\n",
       " ('word_2', 'word_8', 'word_4'),\n",
       " ('word_2', 'word_8', 'word_5'),\n",
       " ('word_2', 'word_8', 'word_6'),\n",
       " ('word_2', 'word_8', 'word_7'),\n",
       " ('word_2', 'word_8', 'word_9'),\n",
       " ('word_2', 'word_8', 'word_10'),\n",
       " ('word_2', 'word_8', 'word_11'),\n",
       " ('word_2', 'word_8', 'word_12'),\n",
       " ('word_2', 'word_8', 'word_13'),\n",
       " ('word_2', 'word_8', 'word_14'),\n",
       " ('word_2', 'word_8', 'word_15'),\n",
       " ('word_2', 'word_8', 'word_16'),\n",
       " ('word_2', 'word_8', 'word_17'),\n",
       " ('word_2', 'word_8', 'word_18'),\n",
       " ('word_2', 'word_8', 'word_19'),\n",
       " ('word_2', 'word_9', 'word_0'),\n",
       " ('word_2', 'word_9', 'word_1'),\n",
       " ('word_2', 'word_9', 'word_3'),\n",
       " ('word_2', 'word_9', 'word_4'),\n",
       " ('word_2', 'word_9', 'word_5'),\n",
       " ('word_2', 'word_9', 'word_6'),\n",
       " ('word_2', 'word_9', 'word_7'),\n",
       " ('word_2', 'word_9', 'word_8'),\n",
       " ('word_2', 'word_9', 'word_10'),\n",
       " ('word_2', 'word_9', 'word_11'),\n",
       " ('word_2', 'word_9', 'word_12'),\n",
       " ('word_2', 'word_9', 'word_13'),\n",
       " ('word_2', 'word_9', 'word_14'),\n",
       " ('word_2', 'word_9', 'word_15'),\n",
       " ('word_2', 'word_9', 'word_16'),\n",
       " ('word_2', 'word_9', 'word_17'),\n",
       " ('word_2', 'word_9', 'word_18'),\n",
       " ('word_2', 'word_9', 'word_19'),\n",
       " ('word_2', 'word_10', 'word_0'),\n",
       " ('word_2', 'word_10', 'word_1'),\n",
       " ('word_2', 'word_10', 'word_3'),\n",
       " ('word_2', 'word_10', 'word_4'),\n",
       " ('word_2', 'word_10', 'word_5'),\n",
       " ('word_2', 'word_10', 'word_6'),\n",
       " ('word_2', 'word_10', 'word_7'),\n",
       " ('word_2', 'word_10', 'word_8'),\n",
       " ('word_2', 'word_10', 'word_9'),\n",
       " ('word_2', 'word_10', 'word_11'),\n",
       " ('word_2', 'word_10', 'word_12'),\n",
       " ('word_2', 'word_10', 'word_13'),\n",
       " ('word_2', 'word_10', 'word_14'),\n",
       " ('word_2', 'word_10', 'word_15'),\n",
       " ('word_2', 'word_10', 'word_16'),\n",
       " ('word_2', 'word_10', 'word_17'),\n",
       " ('word_2', 'word_10', 'word_18'),\n",
       " ('word_2', 'word_10', 'word_19'),\n",
       " ('word_2', 'word_11', 'word_0'),\n",
       " ('word_2', 'word_11', 'word_1'),\n",
       " ('word_2', 'word_11', 'word_3'),\n",
       " ('word_2', 'word_11', 'word_4'),\n",
       " ('word_2', 'word_11', 'word_5'),\n",
       " ('word_2', 'word_11', 'word_6'),\n",
       " ('word_2', 'word_11', 'word_7'),\n",
       " ('word_2', 'word_11', 'word_8'),\n",
       " ('word_2', 'word_11', 'word_9'),\n",
       " ('word_2', 'word_11', 'word_10'),\n",
       " ('word_2', 'word_11', 'word_12'),\n",
       " ('word_2', 'word_11', 'word_13'),\n",
       " ('word_2', 'word_11', 'word_14'),\n",
       " ('word_2', 'word_11', 'word_15'),\n",
       " ('word_2', 'word_11', 'word_16'),\n",
       " ('word_2', 'word_11', 'word_17'),\n",
       " ('word_2', 'word_11', 'word_18'),\n",
       " ('word_2', 'word_11', 'word_19'),\n",
       " ('word_2', 'word_12', 'word_0'),\n",
       " ('word_2', 'word_12', 'word_1'),\n",
       " ('word_2', 'word_12', 'word_3'),\n",
       " ('word_2', 'word_12', 'word_4'),\n",
       " ('word_2', 'word_12', 'word_5'),\n",
       " ('word_2', 'word_12', 'word_6'),\n",
       " ('word_2', 'word_12', 'word_7'),\n",
       " ('word_2', 'word_12', 'word_8'),\n",
       " ('word_2', 'word_12', 'word_9'),\n",
       " ('word_2', 'word_12', 'word_10'),\n",
       " ('word_2', 'word_12', 'word_11'),\n",
       " ('word_2', 'word_12', 'word_13'),\n",
       " ('word_2', 'word_12', 'word_14'),\n",
       " ('word_2', 'word_12', 'word_15'),\n",
       " ('word_2', 'word_12', 'word_16'),\n",
       " ('word_2', 'word_12', 'word_17'),\n",
       " ('word_2', 'word_12', 'word_18'),\n",
       " ('word_2', 'word_12', 'word_19'),\n",
       " ('word_2', 'word_13', 'word_0'),\n",
       " ('word_2', 'word_13', 'word_1'),\n",
       " ('word_2', 'word_13', 'word_3'),\n",
       " ('word_2', 'word_13', 'word_4'),\n",
       " ('word_2', 'word_13', 'word_5'),\n",
       " ('word_2', 'word_13', 'word_6'),\n",
       " ('word_2', 'word_13', 'word_7'),\n",
       " ('word_2', 'word_13', 'word_8'),\n",
       " ('word_2', 'word_13', 'word_9'),\n",
       " ('word_2', 'word_13', 'word_10'),\n",
       " ('word_2', 'word_13', 'word_11'),\n",
       " ('word_2', 'word_13', 'word_12'),\n",
       " ('word_2', 'word_13', 'word_14'),\n",
       " ('word_2', 'word_13', 'word_15'),\n",
       " ('word_2', 'word_13', 'word_16'),\n",
       " ('word_2', 'word_13', 'word_17'),\n",
       " ('word_2', 'word_13', 'word_18'),\n",
       " ('word_2', 'word_13', 'word_19'),\n",
       " ('word_2', 'word_14', 'word_0'),\n",
       " ('word_2', 'word_14', 'word_1'),\n",
       " ('word_2', 'word_14', 'word_3'),\n",
       " ('word_2', 'word_14', 'word_4'),\n",
       " ('word_2', 'word_14', 'word_5'),\n",
       " ('word_2', 'word_14', 'word_6'),\n",
       " ('word_2', 'word_14', 'word_7'),\n",
       " ('word_2', 'word_14', 'word_8'),\n",
       " ('word_2', 'word_14', 'word_9'),\n",
       " ('word_2', 'word_14', 'word_10'),\n",
       " ('word_2', 'word_14', 'word_11'),\n",
       " ('word_2', 'word_14', 'word_12'),\n",
       " ('word_2', 'word_14', 'word_13'),\n",
       " ('word_2', 'word_14', 'word_15'),\n",
       " ('word_2', 'word_14', 'word_16'),\n",
       " ('word_2', 'word_14', 'word_17'),\n",
       " ('word_2', 'word_14', 'word_18'),\n",
       " ('word_2', 'word_14', 'word_19'),\n",
       " ('word_2', 'word_15', 'word_0'),\n",
       " ('word_2', 'word_15', 'word_1'),\n",
       " ('word_2', 'word_15', 'word_3'),\n",
       " ('word_2', 'word_15', 'word_4'),\n",
       " ('word_2', 'word_15', 'word_5'),\n",
       " ('word_2', 'word_15', 'word_6'),\n",
       " ('word_2', 'word_15', 'word_7'),\n",
       " ('word_2', 'word_15', 'word_8'),\n",
       " ('word_2', 'word_15', 'word_9'),\n",
       " ('word_2', 'word_15', 'word_10'),\n",
       " ('word_2', 'word_15', 'word_11'),\n",
       " ('word_2', 'word_15', 'word_12'),\n",
       " ('word_2', 'word_15', 'word_13'),\n",
       " ('word_2', 'word_15', 'word_14'),\n",
       " ('word_2', 'word_15', 'word_16'),\n",
       " ('word_2', 'word_15', 'word_17'),\n",
       " ('word_2', 'word_15', 'word_18'),\n",
       " ('word_2', 'word_15', 'word_19'),\n",
       " ('word_2', 'word_16', 'word_0'),\n",
       " ('word_2', 'word_16', 'word_1'),\n",
       " ('word_2', 'word_16', 'word_3'),\n",
       " ('word_2', 'word_16', 'word_4'),\n",
       " ('word_2', 'word_16', 'word_5'),\n",
       " ('word_2', 'word_16', 'word_6'),\n",
       " ('word_2', 'word_16', 'word_7'),\n",
       " ('word_2', 'word_16', 'word_8'),\n",
       " ('word_2', 'word_16', 'word_9'),\n",
       " ('word_2', 'word_16', 'word_10'),\n",
       " ('word_2', 'word_16', 'word_11'),\n",
       " ('word_2', 'word_16', 'word_12'),\n",
       " ('word_2', 'word_16', 'word_13'),\n",
       " ('word_2', 'word_16', 'word_14'),\n",
       " ('word_2', 'word_16', 'word_15'),\n",
       " ('word_2', 'word_16', 'word_17'),\n",
       " ('word_2', 'word_16', 'word_18'),\n",
       " ('word_2', 'word_16', 'word_19'),\n",
       " ('word_2', 'word_17', 'word_0'),\n",
       " ('word_2', 'word_17', 'word_1'),\n",
       " ('word_2', 'word_17', 'word_3'),\n",
       " ('word_2', 'word_17', 'word_4'),\n",
       " ('word_2', 'word_17', 'word_5'),\n",
       " ('word_2', 'word_17', 'word_6'),\n",
       " ('word_2', 'word_17', 'word_7'),\n",
       " ('word_2', 'word_17', 'word_8'),\n",
       " ('word_2', 'word_17', 'word_9'),\n",
       " ('word_2', 'word_17', 'word_10'),\n",
       " ('word_2', 'word_17', 'word_11'),\n",
       " ('word_2', 'word_17', 'word_12'),\n",
       " ('word_2', 'word_17', 'word_13'),\n",
       " ('word_2', 'word_17', 'word_14'),\n",
       " ('word_2', 'word_17', 'word_15'),\n",
       " ('word_2', 'word_17', 'word_16'),\n",
       " ('word_2', 'word_17', 'word_18'),\n",
       " ('word_2', 'word_17', 'word_19'),\n",
       " ('word_2', 'word_18', 'word_0'),\n",
       " ('word_2', 'word_18', 'word_1'),\n",
       " ('word_2', 'word_18', 'word_3'),\n",
       " ('word_2', 'word_18', 'word_4'),\n",
       " ('word_2', 'word_18', 'word_5'),\n",
       " ('word_2', 'word_18', 'word_6'),\n",
       " ('word_2', 'word_18', 'word_7'),\n",
       " ('word_2', 'word_18', 'word_8'),\n",
       " ('word_2', 'word_18', 'word_9'),\n",
       " ('word_2', 'word_18', 'word_10'),\n",
       " ...]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pairs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "3e5bfd69",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 20  # vocab_size\n",
    "\n",
    "vocabs = ['word_' + str(i) for i in range(N)]\n",
    "\n",
    "# token string -> integer id, e.g. 'word_7' -> 7\n",
    "vocab_map = {word: idx for idx, word in enumerate(vocabs)}\n",
    "\n",
    "# Every ordered triple of three distinct words, in the same order as a\n",
    "# nested i/j/k loop over `vocabs`. Kept under the name `pairs` because\n",
    "# other cells refer to it.\n",
    "pairs = [\n",
    "    (i, j, k)\n",
    "    for i in vocabs\n",
    "    for j in vocabs\n",
    "    for k in vocabs\n",
    "    if i != j and i != k and j != k\n",
    "]\n",
    "\n",
    "# Split on the first word of each triple: first words from the lower half\n",
    "# of the vocabulary go to train, the upper half to test, so no test triple\n",
    "# starts with a word that starts a training triple.\n",
    "# With N = 20 this is ids 0-9 for train and 10-19 for test.\n",
    "SPLIT_ID = N // 2\n",
    "pairs_train = [x for x in pairs if vocab_map[x[0]] < SPLIT_ID]\n",
    "pairs_test = [x for x in pairs if vocab_map[x[0]] >= SPLIT_ID]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "d43093fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _encode_tokens(words):\n",
    "    \"\"\"Map a sequence of vocabulary tokens to their integer ids via vocab_map.\"\"\"\n",
    "    return [vocab_map[w] for w in words]\n",
    "\n",
    "\n",
    "def _build_sequences(triples):\n",
    "    \"\"\"Turn (a, b, c) triples into `a b c a` sequences and shifted targets.\n",
    "\n",
    "    Returns (sentences, sentence_ids, x, y):\n",
    "      sentences    -- token lists [a, b, c, a]\n",
    "      sentence_ids -- the same sequences as lists of integer ids\n",
    "      x            -- np.array of sentence_ids\n",
    "      y            -- np.array of the ids shifted left by one, i.e. ids of [b, c, a]\n",
    "    \"\"\"\n",
    "    sentences = [[a, b, c, a] for a, b, c in triples]\n",
    "    sentence_ids = [_encode_tokens(s) for s in sentences]\n",
    "    x = np.array(sentence_ids)\n",
    "    y = np.array([ids[1:] for ids in sentence_ids])\n",
    "    return sentences, sentence_ids, x, y\n",
    "\n",
    "\n",
    "(sentences_train, sentences_number_train,\n",
    " x_masked_train, y_masked_labels_train) = _build_sequences(pairs_train)\n",
    "(sentences_test, sentences_number_test,\n",
    " x_masked_test, y_masked_labels_test) = _build_sequences(pairs_test)\n",
    "\n",
    "# Shuffle training inputs and targets with one shared permutation so rows stay aligned.\n",
    "perm = np.random.permutation(len(x_masked_train))\n",
    "x_masked_train = x_masked_train[perm]\n",
    "y_masked_labels_train = y_masked_labels_train[perm]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "13b40f89",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2000\n",
      "1/1 - 1s - loss: 3.3022 - val_loss: 3.2990 - 1s/epoch - 1s/step\n",
      "Epoch 2/2000\n",
      "1/1 - 0s - loss: 3.2886 - val_loss: 3.2856 - 41ms/epoch - 41ms/step\n",
      "Epoch 3/2000\n",
      "1/1 - 0s - loss: 3.2754 - val_loss: 3.2725 - 41ms/epoch - 41ms/step\n",
      "Epoch 4/2000\n",
      "1/1 - 0s - loss: 3.2625 - val_loss: 3.2598 - 43ms/epoch - 43ms/step\n",
      "Epoch 5/2000\n",
      "1/1 - 0s - loss: 3.2500 - val_loss: 3.2474 - 41ms/epoch - 41ms/step\n",
      "Epoch 6/2000\n",
      "1/1 - 0s - loss: 3.2378 - val_loss: 3.2355 - 41ms/epoch - 41ms/step\n",
      "Epoch 7/2000\n",
      "1/1 - 0s - loss: 3.2260 - val_loss: 3.2239 - 42ms/epoch - 42ms/step\n",
      "Epoch 8/2000\n",
      "1/1 - 0s - loss: 3.2146 - val_loss: 3.2126 - 42ms/epoch - 42ms/step\n",
      "Epoch 9/2000\n",
      "1/1 - 0s - loss: 3.2035 - val_loss: 3.2015 - 42ms/epoch - 42ms/step\n",
      "Epoch 10/2000\n",
      "1/1 - 0s - loss: 3.1926 - val_loss: 3.1907 - 41ms/epoch - 41ms/step\n",
      "Epoch 11/2000\n",
      "1/1 - 0s - loss: 3.1820 - val_loss: 3.1802 - 41ms/epoch - 41ms/step\n",
      "Epoch 12/2000\n",
      "1/1 - 0s - loss: 3.1717 - val_loss: 3.1701 - 43ms/epoch - 43ms/step\n",
      "Epoch 13/2000\n",
      "1/1 - 0s - loss: 3.1617 - val_loss: 3.1602 - 41ms/epoch - 41ms/step\n",
      "Epoch 14/2000\n",
      "1/1 - 0s - loss: 3.1520 - val_loss: 3.1506 - 41ms/epoch - 41ms/step\n",
      "Epoch 15/2000\n",
      "1/1 - 0s - loss: 3.1426 - val_loss: 3.1413 - 41ms/epoch - 41ms/step\n",
      "Epoch 16/2000\n",
      "1/1 - 0s - loss: 3.1335 - val_loss: 3.1324 - 43ms/epoch - 43ms/step\n",
      "Epoch 17/2000\n",
      "1/1 - 0s - loss: 3.1246 - val_loss: 3.1236 - 41ms/epoch - 41ms/step\n",
      "Epoch 18/2000\n",
      "1/1 - 0s - loss: 3.1160 - val_loss: 3.1152 - 41ms/epoch - 41ms/step\n",
      "Epoch 19/2000\n",
      "1/1 - 0s - loss: 3.1076 - val_loss: 3.1070 - 42ms/epoch - 42ms/step\n",
      "Epoch 20/2000\n",
      "1/1 - 0s - loss: 3.0995 - val_loss: 3.0990 - 43ms/epoch - 43ms/step\n",
      "Epoch 21/2000\n",
      "1/1 - 0s - loss: 3.0917 - val_loss: 3.0914 - 41ms/epoch - 41ms/step\n",
      "Epoch 22/2000\n",
      "1/1 - 0s - loss: 3.0841 - val_loss: 3.0841 - 40ms/epoch - 40ms/step\n",
      "Epoch 23/2000\n",
      "1/1 - 0s - loss: 3.0768 - val_loss: 3.0770 - 45ms/epoch - 45ms/step\n",
      "Epoch 24/2000\n",
      "1/1 - 0s - loss: 3.0698 - val_loss: 3.0702 - 43ms/epoch - 43ms/step\n",
      "Epoch 25/2000\n",
      "1/1 - 0s - loss: 3.0630 - val_loss: 3.0636 - 41ms/epoch - 41ms/step\n",
      "Epoch 26/2000\n",
      "1/1 - 0s - loss: 3.0564 - val_loss: 3.0573 - 40ms/epoch - 40ms/step\n",
      "Epoch 27/2000\n",
      "1/1 - 0s - loss: 3.0502 - val_loss: 3.0513 - 41ms/epoch - 41ms/step\n",
      "Epoch 28/2000\n",
      "1/1 - 0s - loss: 3.0441 - val_loss: 3.0455 - 43ms/epoch - 43ms/step\n",
      "Epoch 29/2000\n",
      "1/1 - 0s - loss: 3.0383 - val_loss: 3.0400 - 41ms/epoch - 41ms/step\n",
      "Epoch 30/2000\n",
      "1/1 - 0s - loss: 3.0327 - val_loss: 3.0346 - 40ms/epoch - 40ms/step\n",
      "Epoch 31/2000\n",
      "1/1 - 0s - loss: 3.0273 - val_loss: 3.0295 - 41ms/epoch - 41ms/step\n",
      "Epoch 32/2000\n",
      "1/1 - 0s - loss: 3.0221 - val_loss: 3.0245 - 43ms/epoch - 43ms/step\n",
      "Epoch 33/2000\n",
      "1/1 - 0s - loss: 3.0171 - val_loss: 3.0198 - 42ms/epoch - 42ms/step\n",
      "Epoch 34/2000\n",
      "1/1 - 0s - loss: 3.0123 - val_loss: 3.0152 - 40ms/epoch - 40ms/step\n",
      "Epoch 35/2000\n",
      "1/1 - 0s - loss: 3.0077 - val_loss: 3.0108 - 41ms/epoch - 41ms/step\n",
      "Epoch 36/2000\n",
      "1/1 - 0s - loss: 3.0032 - val_loss: 3.0066 - 43ms/epoch - 43ms/step\n",
      "Epoch 37/2000\n",
      "1/1 - 0s - loss: 2.9989 - val_loss: 3.0025 - 42ms/epoch - 42ms/step\n",
      "Epoch 38/2000\n",
      "1/1 - 0s - loss: 2.9947 - val_loss: 2.9986 - 40ms/epoch - 40ms/step\n",
      "Epoch 39/2000\n",
      "1/1 - 0s - loss: 2.9907 - val_loss: 2.9948 - 41ms/epoch - 41ms/step\n",
      "Epoch 40/2000\n",
      "1/1 - 0s - loss: 2.9869 - val_loss: 2.9912 - 42ms/epoch - 42ms/step\n",
      "Epoch 41/2000\n",
      "1/1 - 0s - loss: 2.9832 - val_loss: 2.9877 - 43ms/epoch - 43ms/step\n",
      "Epoch 42/2000\n",
      "1/1 - 0s - loss: 2.9796 - val_loss: 2.9843 - 41ms/epoch - 41ms/step\n",
      "Epoch 43/2000\n",
      "1/1 - 0s - loss: 2.9762 - val_loss: 2.9811 - 40ms/epoch - 40ms/step\n",
      "Epoch 44/2000\n",
      "1/1 - 0s - loss: 2.9729 - val_loss: 2.9780 - 42ms/epoch - 42ms/step\n",
      "Epoch 45/2000\n",
      "1/1 - 0s - loss: 2.9697 - val_loss: 2.9751 - 43ms/epoch - 43ms/step\n",
      "Epoch 46/2000\n",
      "1/1 - 0s - loss: 2.9667 - val_loss: 2.9722 - 41ms/epoch - 41ms/step\n",
      "Epoch 47/2000\n",
      "1/1 - 0s - loss: 2.9638 - val_loss: 2.9695 - 40ms/epoch - 40ms/step\n",
      "Epoch 48/2000\n",
      "1/1 - 0s - loss: 2.9610 - val_loss: 2.9669 - 42ms/epoch - 42ms/step\n",
      "Epoch 49/2000\n",
      "1/1 - 0s - loss: 2.9583 - val_loss: 2.9644 - 43ms/epoch - 43ms/step\n",
      "Epoch 50/2000\n",
      "1/1 - 0s - loss: 2.9557 - val_loss: 2.9620 - 44ms/epoch - 44ms/step\n",
      "Epoch 51/2000\n",
      "1/1 - 0s - loss: 2.9532 - val_loss: 2.9597 - 40ms/epoch - 40ms/step\n",
      "Epoch 52/2000\n",
      "1/1 - 0s - loss: 2.9508 - val_loss: 2.9575 - 41ms/epoch - 41ms/step\n",
      "Epoch 53/2000\n",
      "1/1 - 0s - loss: 2.9485 - val_loss: 2.9554 - 43ms/epoch - 43ms/step\n",
      "Epoch 54/2000\n",
      "1/1 - 0s - loss: 2.9463 - val_loss: 2.9534 - 42ms/epoch - 42ms/step\n",
      "Epoch 55/2000\n",
      "1/1 - 0s - loss: 2.9441 - val_loss: 2.9515 - 39ms/epoch - 39ms/step\n",
      "Epoch 56/2000\n",
      "1/1 - 0s - loss: 2.9421 - val_loss: 2.9496 - 41ms/epoch - 41ms/step\n",
      "Epoch 57/2000\n",
      "1/1 - 0s - loss: 2.9401 - val_loss: 2.9478 - 43ms/epoch - 43ms/step\n",
      "Epoch 58/2000\n",
      "1/1 - 0s - loss: 2.9381 - val_loss: 2.9461 - 43ms/epoch - 43ms/step\n",
      "Epoch 59/2000\n",
      "1/1 - 0s - loss: 2.9363 - val_loss: 2.9444 - 40ms/epoch - 40ms/step\n",
      "Epoch 60/2000\n",
      "1/1 - 0s - loss: 2.9344 - val_loss: 2.9428 - 41ms/epoch - 41ms/step\n",
      "Epoch 61/2000\n",
      "1/1 - 0s - loss: 2.9327 - val_loss: 2.9412 - 43ms/epoch - 43ms/step\n",
      "Epoch 62/2000\n",
      "1/1 - 0s - loss: 2.9310 - val_loss: 2.9397 - 43ms/epoch - 43ms/step\n",
      "Epoch 63/2000\n",
      "1/1 - 0s - loss: 2.9293 - val_loss: 2.9382 - 40ms/epoch - 40ms/step\n",
      "Epoch 64/2000\n",
      "1/1 - 0s - loss: 2.9277 - val_loss: 2.9368 - 40ms/epoch - 40ms/step\n",
      "Epoch 65/2000\n",
      "1/1 - 0s - loss: 2.9261 - val_loss: 2.9354 - 43ms/epoch - 43ms/step\n",
      "Epoch 66/2000\n",
      "1/1 - 0s - loss: 2.9245 - val_loss: 2.9340 - 43ms/epoch - 43ms/step\n",
      "Epoch 67/2000\n",
      "1/1 - 0s - loss: 2.9229 - val_loss: 2.9327 - 41ms/epoch - 41ms/step\n",
      "Epoch 68/2000\n",
      "1/1 - 0s - loss: 2.9214 - val_loss: 2.9313 - 41ms/epoch - 41ms/step\n",
      "Epoch 69/2000\n",
      "1/1 - 0s - loss: 2.9199 - val_loss: 2.9300 - 43ms/epoch - 43ms/step\n",
      "Epoch 70/2000\n",
      "1/1 - 0s - loss: 2.9185 - val_loss: 2.9287 - 43ms/epoch - 43ms/step\n",
      "Epoch 71/2000\n",
      "1/1 - 0s - loss: 2.9170 - val_loss: 2.9275 - 41ms/epoch - 41ms/step\n",
      "Epoch 72/2000\n",
      "1/1 - 0s - loss: 2.9155 - val_loss: 2.9262 - 40ms/epoch - 40ms/step\n",
      "Epoch 73/2000\n",
      "1/1 - 0s - loss: 2.9141 - val_loss: 2.9249 - 42ms/epoch - 42ms/step\n",
      "Epoch 74/2000\n",
      "1/1 - 0s - loss: 2.9126 - val_loss: 2.9237 - 44ms/epoch - 44ms/step\n",
      "Epoch 75/2000\n",
      "1/1 - 0s - loss: 2.9112 - val_loss: 2.9225 - 42ms/epoch - 42ms/step\n",
      "Epoch 76/2000\n",
      "1/1 - 0s - loss: 2.9098 - val_loss: 2.9212 - 40ms/epoch - 40ms/step\n",
      "Epoch 77/2000\n",
      "1/1 - 0s - loss: 2.9084 - val_loss: 2.9200 - 41ms/epoch - 41ms/step\n",
      "Epoch 78/2000\n",
      "1/1 - 0s - loss: 2.9069 - val_loss: 2.9187 - 44ms/epoch - 44ms/step\n",
      "Epoch 79/2000\n",
      "1/1 - 0s - loss: 2.9055 - val_loss: 2.9174 - 43ms/epoch - 43ms/step\n",
      "Epoch 80/2000\n",
      "1/1 - 0s - loss: 2.9040 - val_loss: 2.9162 - 40ms/epoch - 40ms/step\n",
      "Epoch 81/2000\n",
      "1/1 - 0s - loss: 2.9026 - val_loss: 2.9149 - 41ms/epoch - 41ms/step\n",
      "Epoch 82/2000\n",
      "1/1 - 0s - loss: 2.9011 - val_loss: 2.9136 - 43ms/epoch - 43ms/step\n",
      "Epoch 83/2000\n",
      "1/1 - 0s - loss: 2.8996 - val_loss: 2.9123 - 43ms/epoch - 43ms/step\n",
      "Epoch 84/2000\n",
      "1/1 - 0s - loss: 2.8981 - val_loss: 2.9109 - 40ms/epoch - 40ms/step\n",
      "Epoch 85/2000\n",
      "1/1 - 0s - loss: 2.8966 - val_loss: 2.9096 - 40ms/epoch - 40ms/step\n",
      "Epoch 86/2000\n",
      "1/1 - 0s - loss: 2.8951 - val_loss: 2.9082 - 42ms/epoch - 42ms/step\n",
      "Epoch 87/2000\n",
      "1/1 - 0s - loss: 2.8936 - val_loss: 2.9068 - 43ms/epoch - 43ms/step\n",
      "Epoch 88/2000\n",
      "1/1 - 0s - loss: 2.8920 - val_loss: 2.9054 - 41ms/epoch - 41ms/step\n",
      "Epoch 89/2000\n",
      "1/1 - 0s - loss: 2.8904 - val_loss: 2.9040 - 40ms/epoch - 40ms/step\n",
      "Epoch 90/2000\n",
      "1/1 - 0s - loss: 2.8888 - val_loss: 2.9025 - 43ms/epoch - 43ms/step\n",
      "Epoch 91/2000\n",
      "1/1 - 0s - loss: 2.8871 - val_loss: 2.9010 - 43ms/epoch - 43ms/step\n",
      "Epoch 92/2000\n",
      "1/1 - 0s - loss: 2.8855 - val_loss: 2.8995 - 42ms/epoch - 42ms/step\n",
      "Epoch 93/2000\n",
      "1/1 - 0s - loss: 2.8838 - val_loss: 2.8979 - 39ms/epoch - 39ms/step\n",
      "Epoch 94/2000\n",
      "1/1 - 0s - loss: 2.8820 - val_loss: 2.8963 - 42ms/epoch - 42ms/step\n",
      "Epoch 95/2000\n",
      "1/1 - 0s - loss: 2.8802 - val_loss: 2.8946 - 42ms/epoch - 42ms/step\n",
      "Epoch 96/2000\n",
      "1/1 - 0s - loss: 2.8784 - val_loss: 2.8929 - 43ms/epoch - 43ms/step\n",
      "Epoch 97/2000\n",
      "1/1 - 0s - loss: 2.8765 - val_loss: 2.8912 - 39ms/epoch - 39ms/step\n",
      "Epoch 98/2000\n",
      "1/1 - 0s - loss: 2.8746 - val_loss: 2.8894 - 40ms/epoch - 40ms/step\n",
      "Epoch 99/2000\n",
      "1/1 - 0s - loss: 2.8726 - val_loss: 2.8875 - 43ms/epoch - 43ms/step\n",
      "Epoch 100/2000\n",
      "1/1 - 0s - loss: 2.8706 - val_loss: 2.8856 - 43ms/epoch - 43ms/step\n",
      "Epoch 101/2000\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 - 0s - loss: 2.8685 - val_loss: 2.8836 - 40ms/epoch - 40ms/step\n",
      "Epoch 102/2000\n",
      "1/1 - 0s - loss: 2.8664 - val_loss: 2.8816 - 40ms/epoch - 40ms/step\n",
      "Epoch 103/2000\n",
      "1/1 - 0s - loss: 2.8642 - val_loss: 2.8795 - 43ms/epoch - 43ms/step\n",
      "Epoch 104/2000\n",
      "1/1 - 0s - loss: 2.8619 - val_loss: 2.8772 - 43ms/epoch - 43ms/step\n",
      "Epoch 105/2000\n",
      "1/1 - 0s - loss: 2.8595 - val_loss: 2.8750 - 40ms/epoch - 40ms/step\n",
      "Epoch 106/2000\n",
      "1/1 - 0s - loss: 2.8571 - val_loss: 2.8726 - 39ms/epoch - 39ms/step\n",
      "Epoch 107/2000\n",
      "1/1 - 0s - loss: 2.8545 - val_loss: 2.8701 - 43ms/epoch - 43ms/step\n",
      "Epoch 108/2000\n",
      "1/1 - 0s - loss: 2.8519 - val_loss: 2.8675 - 43ms/epoch - 43ms/step\n",
      "Epoch 109/2000\n",
      "1/1 - 0s - loss: 2.8491 - val_loss: 2.8648 - 42ms/epoch - 42ms/step\n",
      "Epoch 110/2000\n",
      "1/1 - 0s - loss: 2.8462 - val_loss: 2.8620 - 40ms/epoch - 40ms/step\n",
      "Epoch 111/2000\n",
      "1/1 - 0s - loss: 2.8432 - val_loss: 2.8590 - 42ms/epoch - 42ms/step\n",
      "Epoch 112/2000\n",
      "1/1 - 0s - loss: 2.8401 - val_loss: 2.8559 - 43ms/epoch - 43ms/step\n",
      "Epoch 113/2000\n",
      "1/1 - 0s - loss: 2.8368 - val_loss: 2.8526 - 42ms/epoch - 42ms/step\n",
      "Epoch 114/2000\n",
      "1/1 - 0s - loss: 2.8333 - val_loss: 2.8492 - 40ms/epoch - 40ms/step\n",
      "Epoch 115/2000\n",
      "1/1 - 0s - loss: 2.8298 - val_loss: 2.8457 - 41ms/epoch - 41ms/step\n",
      "Epoch 116/2000\n",
      "1/1 - 0s - loss: 2.8260 - val_loss: 2.8419 - 42ms/epoch - 42ms/step\n",
      "Epoch 117/2000\n",
      "1/1 - 0s - loss: 2.8222 - val_loss: 2.8380 - 42ms/epoch - 42ms/step\n",
      "Epoch 118/2000\n",
      "1/1 - 0s - loss: 2.8181 - val_loss: 2.8339 - 40ms/epoch - 40ms/step\n",
      "Epoch 119/2000\n",
      "1/1 - 0s - loss: 2.8139 - val_loss: 2.8296 - 40ms/epoch - 40ms/step\n",
      "Epoch 120/2000\n",
      "1/1 - 0s - loss: 2.8094 - val_loss: 2.8251 - 43ms/epoch - 43ms/step\n",
      "Epoch 121/2000\n",
      "1/1 - 0s - loss: 2.8048 - val_loss: 2.8205 - 43ms/epoch - 43ms/step\n",
      "Epoch 122/2000\n",
      "1/1 - 0s - loss: 2.8000 - val_loss: 2.8156 - 40ms/epoch - 40ms/step\n",
      "Epoch 123/2000\n",
      "1/1 - 0s - loss: 2.7951 - val_loss: 2.8106 - 40ms/epoch - 40ms/step\n",
      "Epoch 124/2000\n",
      "1/1 - 0s - loss: 2.7899 - val_loss: 2.8053 - 42ms/epoch - 42ms/step\n",
      "Epoch 125/2000\n",
      "1/1 - 0s - loss: 2.7846 - val_loss: 2.8000 - 42ms/epoch - 42ms/step\n",
      "Epoch 126/2000\n",
      "1/1 - 0s - loss: 2.7791 - val_loss: 2.7945 - 43ms/epoch - 43ms/step\n",
      "Epoch 127/2000\n",
      "1/1 - 0s - loss: 2.7735 - val_loss: 2.7888 - 39ms/epoch - 39ms/step\n",
      "Epoch 128/2000\n",
      "1/1 - 0s - loss: 2.7677 - val_loss: 2.7832 - 42ms/epoch - 42ms/step\n",
      "Epoch 129/2000\n",
      "1/1 - 0s - loss: 2.7619 - val_loss: 2.7775 - 46ms/epoch - 46ms/step\n",
      "Epoch 130/2000\n",
      "1/1 - 0s - loss: 2.7561 - val_loss: 2.7718 - 43ms/epoch - 43ms/step\n",
      "Epoch 131/2000\n",
      "1/1 - 0s - loss: 2.7502 - val_loss: 2.7661 - 39ms/epoch - 39ms/step\n",
      "Epoch 132/2000\n",
      "1/1 - 0s - loss: 2.7443 - val_loss: 2.7604 - 42ms/epoch - 42ms/step\n",
      "Epoch 133/2000\n",
      "1/1 - 0s - loss: 2.7385 - val_loss: 2.7547 - 43ms/epoch - 43ms/step\n",
      "Epoch 134/2000\n",
      "1/1 - 0s - loss: 2.7326 - val_loss: 2.7489 - 42ms/epoch - 42ms/step\n",
      "Epoch 135/2000\n",
      "1/1 - 0s - loss: 2.7267 - val_loss: 2.7431 - 39ms/epoch - 39ms/step\n",
      "Epoch 136/2000\n",
      "1/1 - 0s - loss: 2.7209 - val_loss: 2.7371 - 42ms/epoch - 42ms/step\n",
      "Epoch 137/2000\n",
      "1/1 - 0s - loss: 2.7149 - val_loss: 2.7310 - 43ms/epoch - 43ms/step\n",
      "Epoch 138/2000\n",
      "1/1 - 0s - loss: 2.7088 - val_loss: 2.7248 - 42ms/epoch - 42ms/step\n",
      "Epoch 139/2000\n",
      "1/1 - 0s - loss: 2.7027 - val_loss: 2.7184 - 40ms/epoch - 40ms/step\n",
      "Epoch 140/2000\n",
      "1/1 - 0s - loss: 2.6964 - val_loss: 2.7118 - 41ms/epoch - 41ms/step\n",
      "Epoch 141/2000\n",
      "1/1 - 0s - loss: 2.6900 - val_loss: 2.7051 - 43ms/epoch - 43ms/step\n",
      "Epoch 142/2000\n",
      "1/1 - 0s - loss: 2.6835 - val_loss: 2.6981 - 43ms/epoch - 43ms/step\n",
      "Epoch 143/2000\n",
      "1/1 - 0s - loss: 2.6767 - val_loss: 2.6910 - 40ms/epoch - 40ms/step\n",
      "Epoch 144/2000\n",
      "1/1 - 0s - loss: 2.6698 - val_loss: 2.6837 - 41ms/epoch - 41ms/step\n",
      "Epoch 145/2000\n",
      "1/1 - 0s - loss: 2.6627 - val_loss: 2.6763 - 42ms/epoch - 42ms/step\n",
      "Epoch 146/2000\n",
      "1/1 - 0s - loss: 2.6555 - val_loss: 2.6687 - 43ms/epoch - 43ms/step\n",
      "Epoch 147/2000\n",
      "1/1 - 0s - loss: 2.6482 - val_loss: 2.6610 - 40ms/epoch - 40ms/step\n",
      "Epoch 148/2000\n",
      "1/1 - 0s - loss: 2.6407 - val_loss: 2.6531 - 40ms/epoch - 40ms/step\n",
      "Epoch 149/2000\n",
      "1/1 - 0s - loss: 2.6332 - val_loss: 2.6451 - 42ms/epoch - 42ms/step\n",
      "Epoch 150/2000\n",
      "1/1 - 0s - loss: 2.6256 - val_loss: 2.6371 - 42ms/epoch - 42ms/step\n",
      "Epoch 151/2000\n",
      "1/1 - 0s - loss: 2.6180 - val_loss: 2.6291 - 41ms/epoch - 41ms/step\n",
      "Epoch 152/2000\n",
      "1/1 - 0s - loss: 2.6104 - val_loss: 2.6213 - 40ms/epoch - 40ms/step\n",
      "Epoch 153/2000\n",
      "1/1 - 0s - loss: 2.6030 - val_loss: 2.6137 - 42ms/epoch - 42ms/step\n",
      "Epoch 154/2000\n",
      "1/1 - 0s - loss: 2.5958 - val_loss: 2.6063 - 43ms/epoch - 43ms/step\n",
      "Epoch 155/2000\n",
      "1/1 - 0s - loss: 2.5888 - val_loss: 2.5992 - 41ms/epoch - 41ms/step\n",
      "Epoch 156/2000\n",
      "1/1 - 0s - loss: 2.5820 - val_loss: 2.5924 - 40ms/epoch - 40ms/step\n",
      "Epoch 157/2000\n",
      "1/1 - 0s - loss: 2.5755 - val_loss: 2.5858 - 41ms/epoch - 41ms/step\n",
      "Epoch 158/2000\n",
      "1/1 - 0s - loss: 2.5692 - val_loss: 2.5795 - 43ms/epoch - 43ms/step\n",
      "Epoch 159/2000\n",
      "1/1 - 0s - loss: 2.5632 - val_loss: 2.5734 - 42ms/epoch - 42ms/step\n",
      "Epoch 160/2000\n",
      "1/1 - 0s - loss: 2.5574 - val_loss: 2.5675 - 40ms/epoch - 40ms/step\n",
      "Epoch 161/2000\n",
      "1/1 - 0s - loss: 2.5519 - val_loss: 2.5618 - 41ms/epoch - 41ms/step\n",
      "Epoch 162/2000\n",
      "1/1 - 0s - loss: 2.5465 - val_loss: 2.5564 - 43ms/epoch - 43ms/step\n",
      "Epoch 163/2000\n",
      "1/1 - 0s - loss: 2.5412 - val_loss: 2.5510 - 41ms/epoch - 41ms/step\n",
      "Epoch 164/2000\n",
      "1/1 - 0s - loss: 2.5362 - val_loss: 2.5459 - 40ms/epoch - 40ms/step\n",
      "Epoch 165/2000\n",
      "1/1 - 0s - loss: 2.5313 - val_loss: 2.5409 - 41ms/epoch - 41ms/step\n",
      "Epoch 166/2000\n",
      "1/1 - 0s - loss: 2.5264 - val_loss: 2.5360 - 42ms/epoch - 42ms/step\n",
      "Epoch 167/2000\n",
      "1/1 - 0s - loss: 2.5217 - val_loss: 2.5311 - 42ms/epoch - 42ms/step\n",
      "Epoch 168/2000\n",
      "1/1 - 0s - loss: 2.5171 - val_loss: 2.5264 - 40ms/epoch - 40ms/step\n",
      "Epoch 169/2000\n",
      "1/1 - 0s - loss: 2.5125 - val_loss: 2.5217 - 41ms/epoch - 41ms/step\n",
      "Epoch 170/2000\n",
      "1/1 - 0s - loss: 2.5080 - val_loss: 2.5171 - 42ms/epoch - 42ms/step\n",
      "Epoch 171/2000\n",
      "1/1 - 0s - loss: 2.5036 - val_loss: 2.5125 - 42ms/epoch - 42ms/step\n",
      "Epoch 172/2000\n",
      "1/1 - 0s - loss: 2.4993 - val_loss: 2.5079 - 40ms/epoch - 40ms/step\n",
      "Epoch 173/2000\n",
      "1/1 - 0s - loss: 2.4949 - val_loss: 2.5034 - 41ms/epoch - 41ms/step\n",
      "Epoch 174/2000\n",
      "1/1 - 0s - loss: 2.4907 - val_loss: 2.4990 - 41ms/epoch - 41ms/step\n",
      "Epoch 175/2000\n",
      "1/1 - 0s - loss: 2.4865 - val_loss: 2.4946 - 43ms/epoch - 43ms/step\n",
      "Epoch 176/2000\n",
      "1/1 - 0s - loss: 2.4823 - val_loss: 2.4902 - 41ms/epoch - 41ms/step\n",
      "Epoch 177/2000\n",
      "1/1 - 0s - loss: 2.4781 - val_loss: 2.4859 - 40ms/epoch - 40ms/step\n",
      "Epoch 178/2000\n",
      "1/1 - 0s - loss: 2.4740 - val_loss: 2.4816 - 41ms/epoch - 41ms/step\n",
      "Epoch 179/2000\n",
      "1/1 - 0s - loss: 2.4698 - val_loss: 2.4773 - 42ms/epoch - 42ms/step\n",
      "Epoch 180/2000\n",
      "1/1 - 0s - loss: 2.4657 - val_loss: 2.4730 - 41ms/epoch - 41ms/step\n",
      "Epoch 181/2000\n",
      "1/1 - 0s - loss: 2.4616 - val_loss: 2.4686 - 41ms/epoch - 41ms/step\n",
      "Epoch 182/2000\n",
      "1/1 - 0s - loss: 2.4574 - val_loss: 2.4643 - 41ms/epoch - 41ms/step\n",
      "Epoch 183/2000\n",
      "1/1 - 0s - loss: 2.4533 - val_loss: 2.4600 - 43ms/epoch - 43ms/step\n",
      "Epoch 184/2000\n",
      "1/1 - 0s - loss: 2.4491 - val_loss: 2.4557 - 41ms/epoch - 41ms/step\n",
      "Epoch 185/2000\n",
      "1/1 - 0s - loss: 2.4449 - val_loss: 2.4513 - 41ms/epoch - 41ms/step\n",
      "Epoch 186/2000\n",
      "1/1 - 0s - loss: 2.4407 - val_loss: 2.4470 - 41ms/epoch - 41ms/step\n",
      "Epoch 187/2000\n",
      "1/1 - 0s - loss: 2.4365 - val_loss: 2.4426 - 43ms/epoch - 43ms/step\n",
      "Epoch 188/2000\n",
      "1/1 - 0s - loss: 2.4323 - val_loss: 2.4382 - 41ms/epoch - 41ms/step\n",
      "Epoch 189/2000\n",
      "1/1 - 0s - loss: 2.4281 - val_loss: 2.4338 - 40ms/epoch - 40ms/step\n",
      "Epoch 190/2000\n",
      "1/1 - 0s - loss: 2.4238 - val_loss: 2.4294 - 41ms/epoch - 41ms/step\n",
      "Epoch 191/2000\n",
      "1/1 - 0s - loss: 2.4195 - val_loss: 2.4250 - 42ms/epoch - 42ms/step\n",
      "Epoch 192/2000\n",
      "1/1 - 0s - loss: 2.4152 - val_loss: 2.4206 - 41ms/epoch - 41ms/step\n",
      "Epoch 193/2000\n",
      "1/1 - 0s - loss: 2.4109 - val_loss: 2.4162 - 41ms/epoch - 41ms/step\n",
      "Epoch 194/2000\n",
      "1/1 - 0s - loss: 2.4066 - val_loss: 2.4118 - 41ms/epoch - 41ms/step\n",
      "Epoch 195/2000\n",
      "1/1 - 0s - loss: 2.4023 - val_loss: 2.4074 - 42ms/epoch - 42ms/step\n",
      "Epoch 196/2000\n",
      "1/1 - 0s - loss: 2.3980 - val_loss: 2.4030 - 41ms/epoch - 41ms/step\n",
      "Epoch 197/2000\n",
      "1/1 - 0s - loss: 2.3937 - val_loss: 2.3985 - 41ms/epoch - 41ms/step\n",
      "Epoch 198/2000\n",
      "1/1 - 0s - loss: 2.3893 - val_loss: 2.3941 - 41ms/epoch - 41ms/step\n",
      "Epoch 199/2000\n",
      "1/1 - 0s - loss: 2.3850 - val_loss: 2.3897 - 42ms/epoch - 42ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 200/2000\n",
      "1/1 - 0s - loss: 2.3807 - val_loss: 2.3852 - 42ms/epoch - 42ms/step\n",
      "Epoch 201/2000\n",
      "1/1 - 0s - loss: 2.3764 - val_loss: 2.3808 - 41ms/epoch - 41ms/step\n",
      "Epoch 202/2000\n",
      "1/1 - 0s - loss: 2.3720 - val_loss: 2.3763 - 41ms/epoch - 41ms/step\n",
      "Epoch 203/2000\n",
      "1/1 - 0s - loss: 2.3677 - val_loss: 2.3719 - 42ms/epoch - 42ms/step\n",
      "Epoch 204/2000\n",
      "1/1 - 0s - loss: 2.3634 - val_loss: 2.3675 - 42ms/epoch - 42ms/step\n",
      "Epoch 205/2000\n",
      "1/1 - 0s - loss: 2.3590 - val_loss: 2.3631 - 41ms/epoch - 41ms/step\n",
      "Epoch 206/2000\n",
      "1/1 - 0s - loss: 2.3547 - val_loss: 2.3587 - 42ms/epoch - 42ms/step\n",
      "Epoch 207/2000\n",
      "1/1 - 0s - loss: 2.3504 - val_loss: 2.3543 - 41ms/epoch - 41ms/step\n",
      "Epoch 208/2000\n",
      "1/1 - 0s - loss: 2.3462 - val_loss: 2.3499 - 42ms/epoch - 42ms/step\n",
      "Epoch 209/2000\n",
      "1/1 - 0s - loss: 2.3419 - val_loss: 2.3456 - 41ms/epoch - 41ms/step\n",
      "Epoch 210/2000\n",
      "1/1 - 0s - loss: 2.3377 - val_loss: 2.3413 - 42ms/epoch - 42ms/step\n",
      "Epoch 211/2000\n",
      "1/1 - 0s - loss: 2.3335 - val_loss: 2.3370 - 41ms/epoch - 41ms/step\n",
      "Epoch 212/2000\n",
      "1/1 - 0s - loss: 2.3294 - val_loss: 2.3328 - 43ms/epoch - 43ms/step\n",
      "Epoch 213/2000\n",
      "1/1 - 0s - loss: 2.3253 - val_loss: 2.3286 - 41ms/epoch - 41ms/step\n",
      "Epoch 214/2000\n",
      "1/1 - 0s - loss: 2.3212 - val_loss: 2.3245 - 41ms/epoch - 41ms/step\n",
      "Epoch 215/2000\n",
      "1/1 - 0s - loss: 2.3172 - val_loss: 2.3204 - 41ms/epoch - 41ms/step\n",
      "Epoch 216/2000\n",
      "1/1 - 0s - loss: 2.3132 - val_loss: 2.3163 - 42ms/epoch - 42ms/step\n",
      "Epoch 217/2000\n",
      "1/1 - 0s - loss: 2.3093 - val_loss: 2.3123 - 41ms/epoch - 41ms/step\n",
      "Epoch 218/2000\n",
      "1/1 - 0s - loss: 2.3054 - val_loss: 2.3084 - 41ms/epoch - 41ms/step\n",
      "Epoch 219/2000\n",
      "1/1 - 0s - loss: 2.3016 - val_loss: 2.3045 - 41ms/epoch - 41ms/step\n",
      "Epoch 220/2000\n",
      "1/1 - 0s - loss: 2.2978 - val_loss: 2.3006 - 42ms/epoch - 42ms/step\n",
      "Epoch 221/2000\n",
      "1/1 - 0s - loss: 2.2941 - val_loss: 2.2968 - 41ms/epoch - 41ms/step\n",
      "Epoch 222/2000\n",
      "1/1 - 0s - loss: 2.2904 - val_loss: 2.2931 - 41ms/epoch - 41ms/step\n",
      "Epoch 223/2000\n",
      "1/1 - 0s - loss: 2.2868 - val_loss: 2.2895 - 42ms/epoch - 42ms/step\n",
      "Epoch 224/2000\n",
      "1/1 - 0s - loss: 2.2832 - val_loss: 2.2858 - 41ms/epoch - 41ms/step\n",
      "Epoch 225/2000\n",
      "1/1 - 0s - loss: 2.2797 - val_loss: 2.2822 - 41ms/epoch - 41ms/step\n",
      "Epoch 226/2000\n",
      "1/1 - 0s - loss: 2.2762 - val_loss: 2.2787 - 42ms/epoch - 42ms/step\n",
      "Epoch 227/2000\n",
      "1/1 - 0s - loss: 2.2727 - val_loss: 2.2752 - 41ms/epoch - 41ms/step\n",
      "Epoch 228/2000\n",
      "1/1 - 0s - loss: 2.2693 - val_loss: 2.2717 - 41ms/epoch - 41ms/step\n",
      "Epoch 229/2000\n",
      "1/1 - 0s - loss: 2.2660 - val_loss: 2.2684 - 41ms/epoch - 41ms/step\n",
      "Epoch 230/2000\n",
      "1/1 - 0s - loss: 2.2626 - val_loss: 2.2650 - 41ms/epoch - 41ms/step\n",
      "Epoch 231/2000\n",
      "1/1 - 0s - loss: 2.2594 - val_loss: 2.2617 - 41ms/epoch - 41ms/step\n",
      "Epoch 232/2000\n",
      "1/1 - 0s - loss: 2.2561 - val_loss: 2.2585 - 41ms/epoch - 41ms/step\n",
      "Epoch 233/2000\n",
      "1/1 - 0s - loss: 2.2530 - val_loss: 2.2553 - 41ms/epoch - 41ms/step\n",
      "Epoch 234/2000\n",
      "1/1 - 0s - loss: 2.2498 - val_loss: 2.2521 - 42ms/epoch - 42ms/step\n",
      "Epoch 235/2000\n",
      "1/1 - 0s - loss: 2.2467 - val_loss: 2.2490 - 42ms/epoch - 42ms/step\n",
      "Epoch 236/2000\n",
      "1/1 - 0s - loss: 2.2437 - val_loss: 2.2460 - 44ms/epoch - 44ms/step\n",
      "Epoch 237/2000\n",
      "1/1 - 0s - loss: 2.2407 - val_loss: 2.2429 - 42ms/epoch - 42ms/step\n",
      "Epoch 238/2000\n",
      "1/1 - 0s - loss: 2.2377 - val_loss: 2.2400 - 41ms/epoch - 41ms/step\n",
      "Epoch 239/2000\n",
      "1/1 - 0s - loss: 2.2348 - val_loss: 2.2370 - 42ms/epoch - 42ms/step\n",
      "Epoch 240/2000\n",
      "1/1 - 0s - loss: 2.2319 - val_loss: 2.2341 - 41ms/epoch - 41ms/step\n",
      "Epoch 241/2000\n",
      "1/1 - 0s - loss: 2.2290 - val_loss: 2.2313 - 41ms/epoch - 41ms/step\n",
      "Epoch 242/2000\n",
      "1/1 - 0s - loss: 2.2262 - val_loss: 2.2284 - 41ms/epoch - 41ms/step\n",
      "Epoch 243/2000\n",
      "1/1 - 0s - loss: 2.2234 - val_loss: 2.2256 - 43ms/epoch - 43ms/step\n",
      "Epoch 244/2000\n",
      "1/1 - 0s - loss: 2.2207 - val_loss: 2.2229 - 41ms/epoch - 41ms/step\n",
      "Epoch 245/2000\n",
      "1/1 - 0s - loss: 2.2179 - val_loss: 2.2201 - 41ms/epoch - 41ms/step\n",
      "Epoch 246/2000\n",
      "1/1 - 0s - loss: 2.2152 - val_loss: 2.2174 - 41ms/epoch - 41ms/step\n",
      "Epoch 247/2000\n",
      "1/1 - 0s - loss: 2.2126 - val_loss: 2.2148 - 43ms/epoch - 43ms/step\n",
      "Epoch 248/2000\n",
      "1/1 - 0s - loss: 2.2099 - val_loss: 2.2122 - 41ms/epoch - 41ms/step\n",
      "Epoch 249/2000\n",
      "1/1 - 0s - loss: 2.2074 - val_loss: 2.2096 - 41ms/epoch - 41ms/step\n",
      "Epoch 250/2000\n",
      "1/1 - 0s - loss: 2.2048 - val_loss: 2.2071 - 41ms/epoch - 41ms/step\n",
      "Epoch 251/2000\n",
      "1/1 - 0s - loss: 2.2023 - val_loss: 2.2046 - 43ms/epoch - 43ms/step\n",
      "Epoch 252/2000\n",
      "1/1 - 0s - loss: 2.1998 - val_loss: 2.2021 - 41ms/epoch - 41ms/step\n",
      "Epoch 253/2000\n",
      "1/1 - 0s - loss: 2.1974 - val_loss: 2.1997 - 40ms/epoch - 40ms/step\n",
      "Epoch 254/2000\n",
      "1/1 - 0s - loss: 2.1950 - val_loss: 2.1973 - 41ms/epoch - 41ms/step\n",
      "Epoch 255/2000\n",
      "1/1 - 0s - loss: 2.1926 - val_loss: 2.1950 - 43ms/epoch - 43ms/step\n",
      "Epoch 256/2000\n",
      "1/1 - 0s - loss: 2.1902 - val_loss: 2.1926 - 41ms/epoch - 41ms/step\n",
      "Epoch 257/2000\n",
      "1/1 - 0s - loss: 2.1879 - val_loss: 2.1904 - 40ms/epoch - 40ms/step\n",
      "Epoch 258/2000\n",
      "1/1 - 0s - loss: 2.1856 - val_loss: 2.1881 - 40ms/epoch - 40ms/step\n",
      "Epoch 259/2000\n",
      "1/1 - 0s - loss: 2.1834 - val_loss: 2.1859 - 43ms/epoch - 43ms/step\n",
      "Epoch 260/2000\n",
      "1/1 - 0s - loss: 2.1812 - val_loss: 2.1837 - 41ms/epoch - 41ms/step\n",
      "Epoch 261/2000\n",
      "1/1 - 0s - loss: 2.1790 - val_loss: 2.1815 - 41ms/epoch - 41ms/step\n",
      "Epoch 262/2000\n",
      "1/1 - 0s - loss: 2.1768 - val_loss: 2.1794 - 41ms/epoch - 41ms/step\n",
      "Epoch 263/2000\n",
      "1/1 - 0s - loss: 2.1747 - val_loss: 2.1773 - 42ms/epoch - 42ms/step\n",
      "Epoch 264/2000\n",
      "1/1 - 0s - loss: 2.1725 - val_loss: 2.1752 - 43ms/epoch - 43ms/step\n",
      "Epoch 265/2000\n",
      "1/1 - 0s - loss: 2.1704 - val_loss: 2.1731 - 41ms/epoch - 41ms/step\n",
      "Epoch 266/2000\n",
      "1/1 - 0s - loss: 2.1684 - val_loss: 2.1711 - 40ms/epoch - 40ms/step\n",
      "Epoch 267/2000\n",
      "1/1 - 0s - loss: 2.1663 - val_loss: 2.1690 - 42ms/epoch - 42ms/step\n",
      "Epoch 268/2000\n",
      "1/1 - 0s - loss: 2.1643 - val_loss: 2.1671 - 42ms/epoch - 42ms/step\n",
      "Epoch 269/2000\n",
      "1/1 - 0s - loss: 2.1623 - val_loss: 2.1651 - 41ms/epoch - 41ms/step\n",
      "Epoch 270/2000\n",
      "1/1 - 0s - loss: 2.1604 - val_loss: 2.1632 - 40ms/epoch - 40ms/step\n",
      "Epoch 271/2000\n",
      "1/1 - 0s - loss: 2.1584 - val_loss: 2.1613 - 41ms/epoch - 41ms/step\n",
      "Epoch 272/2000\n",
      "1/1 - 0s - loss: 2.1565 - val_loss: 2.1594 - 43ms/epoch - 43ms/step\n",
      "Epoch 273/2000\n",
      "1/1 - 0s - loss: 2.1546 - val_loss: 2.1575 - 41ms/epoch - 41ms/step\n",
      "Epoch 274/2000\n",
      "1/1 - 0s - loss: 2.1528 - val_loss: 2.1557 - 40ms/epoch - 40ms/step\n",
      "Epoch 275/2000\n",
      "1/1 - 0s - loss: 2.1509 - val_loss: 2.1539 - 41ms/epoch - 41ms/step\n",
      "Epoch 276/2000\n",
      "1/1 - 0s - loss: 2.1491 - val_loss: 2.1521 - 43ms/epoch - 43ms/step\n",
      "Epoch 277/2000\n",
      "1/1 - 0s - loss: 2.1473 - val_loss: 2.1503 - 42ms/epoch - 42ms/step\n",
      "Epoch 278/2000\n",
      "1/1 - 0s - loss: 2.1455 - val_loss: 2.1486 - 40ms/epoch - 40ms/step\n",
      "Epoch 279/2000\n",
      "1/1 - 0s - loss: 2.1438 - val_loss: 2.1469 - 41ms/epoch - 41ms/step\n",
      "Epoch 280/2000\n",
      "1/1 - 0s - loss: 2.1421 - val_loss: 2.1453 - 43ms/epoch - 43ms/step\n",
      "Epoch 281/2000\n",
      "1/1 - 0s - loss: 2.1404 - val_loss: 2.1436 - 42ms/epoch - 42ms/step\n",
      "Epoch 282/2000\n",
      "1/1 - 0s - loss: 2.1387 - val_loss: 2.1420 - 40ms/epoch - 40ms/step\n",
      "Epoch 283/2000\n",
      "1/1 - 0s - loss: 2.1370 - val_loss: 2.1403 - 41ms/epoch - 41ms/step\n",
      "Epoch 284/2000\n",
      "1/1 - 0s - loss: 2.1354 - val_loss: 2.1387 - 42ms/epoch - 42ms/step\n",
      "Epoch 285/2000\n",
      "1/1 - 0s - loss: 2.1338 - val_loss: 2.1371 - 43ms/epoch - 43ms/step\n",
      "Epoch 286/2000\n",
      "1/1 - 0s - loss: 2.1322 - val_loss: 2.1356 - 40ms/epoch - 40ms/step\n",
      "Epoch 287/2000\n",
      "1/1 - 0s - loss: 2.1306 - val_loss: 2.1341 - 40ms/epoch - 40ms/step\n",
      "Epoch 288/2000\n",
      "1/1 - 0s - loss: 2.1291 - val_loss: 2.1326 - 42ms/epoch - 42ms/step\n",
      "Epoch 289/2000\n",
      "1/1 - 0s - loss: 2.1276 - val_loss: 2.1311 - 42ms/epoch - 42ms/step\n",
      "Epoch 290/2000\n",
      "1/1 - 0s - loss: 2.1261 - val_loss: 2.1296 - 41ms/epoch - 41ms/step\n",
      "Epoch 291/2000\n",
      "1/1 - 0s - loss: 2.1246 - val_loss: 2.1282 - 40ms/epoch - 40ms/step\n",
      "Epoch 292/2000\n",
      "1/1 - 0s - loss: 2.1232 - val_loss: 2.1267 - 42ms/epoch - 42ms/step\n",
      "Epoch 293/2000\n",
      "1/1 - 0s - loss: 2.1217 - val_loss: 2.1253 - 43ms/epoch - 43ms/step\n",
      "Epoch 294/2000\n",
      "1/1 - 0s - loss: 2.1203 - val_loss: 2.1239 - 42ms/epoch - 42ms/step\n",
      "Epoch 295/2000\n",
      "1/1 - 0s - loss: 2.1189 - val_loss: 2.1225 - 39ms/epoch - 39ms/step\n",
      "Epoch 296/2000\n",
      "1/1 - 0s - loss: 2.1175 - val_loss: 2.1212 - 42ms/epoch - 42ms/step\n",
      "Epoch 297/2000\n",
      "1/1 - 0s - loss: 2.1162 - val_loss: 2.1199 - 43ms/epoch - 43ms/step\n",
      "Epoch 298/2000\n",
      "1/1 - 0s - loss: 2.1148 - val_loss: 2.1186 - 42ms/epoch - 42ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 299/2000\n",
      "1/1 - 0s - loss: 2.1135 - val_loss: 2.1173 - 40ms/epoch - 40ms/step\n",
      "Epoch 300/2000\n",
      "1/1 - 0s - loss: 2.1123 - val_loss: 2.1161 - 41ms/epoch - 41ms/step\n",
      "Epoch 301/2000\n",
      "1/1 - 0s - loss: 2.1110 - val_loss: 2.1148 - 43ms/epoch - 43ms/step\n",
      "Epoch 302/2000\n",
      "1/1 - 0s - loss: 2.1097 - val_loss: 2.1136 - 43ms/epoch - 43ms/step\n",
      "Epoch 303/2000\n",
      "1/1 - 0s - loss: 2.1085 - val_loss: 2.1123 - 40ms/epoch - 40ms/step\n",
      "Epoch 304/2000\n",
      "1/1 - 0s - loss: 2.1072 - val_loss: 2.1111 - 40ms/epoch - 40ms/step\n",
      "Epoch 305/2000\n",
      "1/1 - 0s - loss: 2.1060 - val_loss: 2.1099 - 43ms/epoch - 43ms/step\n",
      "Epoch 306/2000\n",
      "1/1 - 0s - loss: 2.1048 - val_loss: 2.1087 - 42ms/epoch - 42ms/step\n",
      "Epoch 307/2000\n",
      "1/1 - 0s - loss: 2.1036 - val_loss: 2.1076 - 40ms/epoch - 40ms/step\n",
      "Epoch 308/2000\n",
      "1/1 - 0s - loss: 2.1025 - val_loss: 2.1064 - 39ms/epoch - 39ms/step\n",
      "Epoch 309/2000\n",
      "1/1 - 0s - loss: 2.1014 - val_loss: 2.1053 - 43ms/epoch - 43ms/step\n",
      "Epoch 310/2000\n",
      "1/1 - 0s - loss: 2.1002 - val_loss: 2.1042 - 43ms/epoch - 43ms/step\n",
      "Epoch 311/2000\n",
      "1/1 - 0s - loss: 2.0991 - val_loss: 2.1031 - 41ms/epoch - 41ms/step\n",
      "Epoch 312/2000\n",
      "1/1 - 0s - loss: 2.0981 - val_loss: 2.1020 - 39ms/epoch - 39ms/step\n",
      "Epoch 313/2000\n",
      "1/1 - 0s - loss: 2.0970 - val_loss: 2.1010 - 41ms/epoch - 41ms/step\n",
      "Epoch 314/2000\n",
      "1/1 - 0s - loss: 2.0959 - val_loss: 2.0999 - 43ms/epoch - 43ms/step\n",
      "Epoch 315/2000\n",
      "1/1 - 0s - loss: 2.0948 - val_loss: 2.0988 - 42ms/epoch - 42ms/step\n",
      "Epoch 316/2000\n",
      "1/1 - 0s - loss: 2.0938 - val_loss: 2.0978 - 39ms/epoch - 39ms/step\n",
      "Epoch 317/2000\n",
      "1/1 - 0s - loss: 2.0927 - val_loss: 2.0968 - 41ms/epoch - 41ms/step\n",
      "Epoch 318/2000\n",
      "1/1 - 0s - loss: 2.0917 - val_loss: 2.0958 - 43ms/epoch - 43ms/step\n",
      "Epoch 319/2000\n",
      "1/1 - 0s - loss: 2.0907 - val_loss: 2.0948 - 43ms/epoch - 43ms/step\n",
      "Epoch 320/2000\n",
      "1/1 - 0s - loss: 2.0898 - val_loss: 2.0939 - 39ms/epoch - 39ms/step\n",
      "Epoch 321/2000\n",
      "1/1 - 0s - loss: 2.0888 - val_loss: 2.0929 - 39ms/epoch - 39ms/step\n",
      "Epoch 322/2000\n",
      "1/1 - 0s - loss: 2.0878 - val_loss: 2.0920 - 44ms/epoch - 44ms/step\n",
      "Epoch 323/2000\n",
      "1/1 - 0s - loss: 2.0869 - val_loss: 2.0911 - 43ms/epoch - 43ms/step\n",
      "Epoch 324/2000\n",
      "1/1 - 0s - loss: 2.0860 - val_loss: 2.0902 - 40ms/epoch - 40ms/step\n",
      "Epoch 325/2000\n",
      "1/1 - 0s - loss: 2.0851 - val_loss: 2.0893 - 39ms/epoch - 39ms/step\n",
      "Epoch 326/2000\n",
      "1/1 - 0s - loss: 2.0842 - val_loss: 2.0883 - 43ms/epoch - 43ms/step\n",
      "Epoch 327/2000\n",
      "1/1 - 0s - loss: 2.0833 - val_loss: 2.0875 - 42ms/epoch - 42ms/step\n",
      "Epoch 328/2000\n",
      "1/1 - 0s - loss: 2.0824 - val_loss: 2.0866 - 42ms/epoch - 42ms/step\n",
      "Epoch 329/2000\n",
      "1/1 - 0s - loss: 2.0815 - val_loss: 2.0857 - 40ms/epoch - 40ms/step\n",
      "Epoch 330/2000\n",
      "1/1 - 0s - loss: 2.0806 - val_loss: 2.0849 - 41ms/epoch - 41ms/step\n",
      "Epoch 331/2000\n",
      "1/1 - 0s - loss: 2.0798 - val_loss: 2.0840 - 44ms/epoch - 44ms/step\n",
      "Epoch 332/2000\n",
      "1/1 - 0s - loss: 2.0789 - val_loss: 2.0832 - 42ms/epoch - 42ms/step\n",
      "Epoch 333/2000\n",
      "1/1 - 0s - loss: 2.0781 - val_loss: 2.0824 - 40ms/epoch - 40ms/step\n",
      "Epoch 334/2000\n",
      "1/1 - 0s - loss: 2.0773 - val_loss: 2.0816 - 41ms/epoch - 41ms/step\n",
      "Epoch 335/2000\n",
      "1/1 - 0s - loss: 2.0765 - val_loss: 2.0809 - 42ms/epoch - 42ms/step\n",
      "Epoch 336/2000\n",
      "1/1 - 0s - loss: 2.0758 - val_loss: 2.0801 - 43ms/epoch - 43ms/step\n",
      "Epoch 337/2000\n",
      "1/1 - 0s - loss: 2.0750 - val_loss: 2.0793 - 40ms/epoch - 40ms/step\n",
      "Epoch 338/2000\n",
      "1/1 - 0s - loss: 2.0742 - val_loss: 2.0785 - 39ms/epoch - 39ms/step\n",
      "Epoch 339/2000\n",
      "1/1 - 0s - loss: 2.0734 - val_loss: 2.0777 - 43ms/epoch - 43ms/step\n",
      "Epoch 340/2000\n",
      "1/1 - 0s - loss: 2.0726 - val_loss: 2.0770 - 43ms/epoch - 43ms/step\n",
      "Epoch 341/2000\n",
      "1/1 - 0s - loss: 2.0719 - val_loss: 2.0763 - 41ms/epoch - 41ms/step\n",
      "Epoch 342/2000\n",
      "1/1 - 0s - loss: 2.0711 - val_loss: 2.0755 - 39ms/epoch - 39ms/step\n",
      "Epoch 343/2000\n",
      "1/1 - 0s - loss: 2.0704 - val_loss: 2.0748 - 43ms/epoch - 43ms/step\n",
      "Epoch 344/2000\n",
      "1/1 - 0s - loss: 2.0697 - val_loss: 2.0741 - 43ms/epoch - 43ms/step\n",
      "Epoch 345/2000\n",
      "1/1 - 0s - loss: 2.0690 - val_loss: 2.0735 - 42ms/epoch - 42ms/step\n",
      "Epoch 346/2000\n",
      "1/1 - 0s - loss: 2.0683 - val_loss: 2.0728 - 39ms/epoch - 39ms/step\n",
      "Epoch 347/2000\n",
      "1/1 - 0s - loss: 2.0677 - val_loss: 2.0721 - 43ms/epoch - 43ms/step\n",
      "Epoch 348/2000\n",
      "1/1 - 0s - loss: 2.0670 - val_loss: 2.0714 - 43ms/epoch - 43ms/step\n",
      "Epoch 349/2000\n",
      "1/1 - 0s - loss: 2.0663 - val_loss: 2.0707 - 42ms/epoch - 42ms/step\n",
      "Epoch 350/2000\n",
      "1/1 - 0s - loss: 2.0656 - val_loss: 2.0700 - 40ms/epoch - 40ms/step\n",
      "Epoch 351/2000\n",
      "1/1 - 0s - loss: 2.0649 - val_loss: 2.0694 - 41ms/epoch - 41ms/step\n",
      "Epoch 352/2000\n",
      "1/1 - 0s - loss: 2.0643 - val_loss: 2.0688 - 43ms/epoch - 43ms/step\n",
      "Epoch 353/2000\n",
      "1/1 - 0s - loss: 2.0637 - val_loss: 2.0681 - 43ms/epoch - 43ms/step\n",
      "Epoch 354/2000\n",
      "1/1 - 0s - loss: 2.0631 - val_loss: 2.0675 - 40ms/epoch - 40ms/step\n",
      "Epoch 355/2000\n",
      "1/1 - 0s - loss: 2.0625 - val_loss: 2.0669 - 40ms/epoch - 40ms/step\n",
      "Epoch 356/2000\n",
      "1/1 - 0s - loss: 2.0619 - val_loss: 2.0663 - 43ms/epoch - 43ms/step\n",
      "Epoch 357/2000\n",
      "1/1 - 0s - loss: 2.0613 - val_loss: 2.0657 - 43ms/epoch - 43ms/step\n",
      "Epoch 358/2000\n",
      "1/1 - 0s - loss: 2.0607 - val_loss: 2.0651 - 41ms/epoch - 41ms/step\n",
      "Epoch 359/2000\n",
      "1/1 - 0s - loss: 2.0601 - val_loss: 2.0645 - 40ms/epoch - 40ms/step\n",
      "Epoch 360/2000\n",
      "1/1 - 0s - loss: 2.0595 - val_loss: 2.0639 - 42ms/epoch - 42ms/step\n",
      "Epoch 361/2000\n",
      "1/1 - 0s - loss: 2.0589 - val_loss: 2.0634 - 42ms/epoch - 42ms/step\n",
      "Epoch 362/2000\n",
      "1/1 - 0s - loss: 2.0583 - val_loss: 2.0628 - 44ms/epoch - 44ms/step\n",
      "Epoch 363/2000\n",
      "1/1 - 0s - loss: 2.0578 - val_loss: 2.0623 - 41ms/epoch - 41ms/step\n",
      "Epoch 364/2000\n",
      "1/1 - 0s - loss: 2.0572 - val_loss: 2.0617 - 42ms/epoch - 42ms/step\n",
      "Epoch 365/2000\n",
      "1/1 - 0s - loss: 2.0567 - val_loss: 2.0612 - 43ms/epoch - 43ms/step\n",
      "Epoch 366/2000\n",
      "1/1 - 0s - loss: 2.0562 - val_loss: 2.0607 - 41ms/epoch - 41ms/step\n",
      "Epoch 367/2000\n",
      "1/1 - 0s - loss: 2.0557 - val_loss: 2.0602 - 39ms/epoch - 39ms/step\n",
      "Epoch 368/2000\n",
      "1/1 - 0s - loss: 2.0552 - val_loss: 2.0597 - 41ms/epoch - 41ms/step\n",
      "Epoch 369/2000\n",
      "1/1 - 0s - loss: 2.0547 - val_loss: 2.0592 - 42ms/epoch - 42ms/step\n",
      "Epoch 370/2000\n",
      "1/1 - 0s - loss: 2.0542 - val_loss: 2.0586 - 42ms/epoch - 42ms/step\n",
      "Epoch 371/2000\n",
      "1/1 - 0s - loss: 2.0536 - val_loss: 2.0581 - 39ms/epoch - 39ms/step\n",
      "Epoch 372/2000\n",
      "1/1 - 0s - loss: 2.0531 - val_loss: 2.0576 - 41ms/epoch - 41ms/step\n",
      "Epoch 373/2000\n",
      "1/1 - 0s - loss: 2.0526 - val_loss: 2.0571 - 42ms/epoch - 42ms/step\n",
      "Epoch 374/2000\n",
      "1/1 - 0s - loss: 2.0521 - val_loss: 2.0567 - 42ms/epoch - 42ms/step\n",
      "Epoch 375/2000\n",
      "1/1 - 0s - loss: 2.0517 - val_loss: 2.0562 - 41ms/epoch - 41ms/step\n",
      "Epoch 376/2000\n",
      "1/1 - 0s - loss: 2.0512 - val_loss: 2.0558 - 40ms/epoch - 40ms/step\n",
      "Epoch 377/2000\n",
      "1/1 - 0s - loss: 2.0508 - val_loss: 2.0553 - 43ms/epoch - 43ms/step\n",
      "Epoch 378/2000\n",
      "1/1 - 0s - loss: 2.0503 - val_loss: 2.0548 - 42ms/epoch - 42ms/step\n",
      "Epoch 379/2000\n",
      "1/1 - 0s - loss: 2.0499 - val_loss: 2.0544 - 40ms/epoch - 40ms/step\n",
      "Epoch 380/2000\n",
      "1/1 - 0s - loss: 2.0494 - val_loss: 2.0539 - 40ms/epoch - 40ms/step\n",
      "Epoch 381/2000\n",
      "1/1 - 0s - loss: 2.0490 - val_loss: 2.0535 - 42ms/epoch - 42ms/step\n",
      "Epoch 382/2000\n",
      "1/1 - 0s - loss: 2.0485 - val_loss: 2.0530 - 42ms/epoch - 42ms/step\n",
      "Epoch 383/2000\n",
      "1/1 - 0s - loss: 2.0481 - val_loss: 2.0526 - 41ms/epoch - 41ms/step\n",
      "Epoch 384/2000\n",
      "1/1 - 0s - loss: 2.0476 - val_loss: 2.0522 - 40ms/epoch - 40ms/step\n",
      "Epoch 385/2000\n",
      "1/1 - 0s - loss: 2.0472 - val_loss: 2.0517 - 42ms/epoch - 42ms/step\n",
      "Epoch 386/2000\n",
      "1/1 - 0s - loss: 2.0468 - val_loss: 2.0513 - 43ms/epoch - 43ms/step\n",
      "Epoch 387/2000\n",
      "1/1 - 0s - loss: 2.0464 - val_loss: 2.0509 - 41ms/epoch - 41ms/step\n",
      "Epoch 388/2000\n",
      "1/1 - 0s - loss: 2.0460 - val_loss: 2.0505 - 40ms/epoch - 40ms/step\n",
      "Epoch 389/2000\n",
      "1/1 - 0s - loss: 2.0456 - val_loss: 2.0501 - 41ms/epoch - 41ms/step\n",
      "Epoch 390/2000\n",
      "1/1 - 0s - loss: 2.0453 - val_loss: 2.0497 - 43ms/epoch - 43ms/step\n",
      "Epoch 391/2000\n",
      "1/1 - 0s - loss: 2.0449 - val_loss: 2.0493 - 41ms/epoch - 41ms/step\n",
      "Epoch 392/2000\n",
      "1/1 - 0s - loss: 2.0445 - val_loss: 2.0489 - 40ms/epoch - 40ms/step\n",
      "Epoch 393/2000\n",
      "1/1 - 0s - loss: 2.0441 - val_loss: 2.0485 - 41ms/epoch - 41ms/step\n",
      "Epoch 394/2000\n",
      "1/1 - 0s - loss: 2.0437 - val_loss: 2.0481 - 43ms/epoch - 43ms/step\n",
      "Epoch 395/2000\n",
      "1/1 - 0s - loss: 2.0433 - val_loss: 2.0478 - 41ms/epoch - 41ms/step\n",
      "Epoch 396/2000\n",
      "1/1 - 0s - loss: 2.0430 - val_loss: 2.0474 - 40ms/epoch - 40ms/step\n",
      "Epoch 397/2000\n",
      "1/1 - 0s - loss: 2.0426 - val_loss: 2.0470 - 40ms/epoch - 40ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 398/2000\n",
      "1/1 - 0s - loss: 2.0422 - val_loss: 2.0467 - 42ms/epoch - 42ms/step\n",
      "Epoch 399/2000\n",
      "1/1 - 0s - loss: 2.0419 - val_loss: 2.0464 - 41ms/epoch - 41ms/step\n",
      "Epoch 400/2000\n",
      "1/1 - 0s - loss: 2.0415 - val_loss: 2.0460 - 41ms/epoch - 41ms/step\n",
      "Epoch 401/2000\n",
      "1/1 - 0s - loss: 2.0412 - val_loss: 2.0457 - 42ms/epoch - 42ms/step\n",
      "Epoch 402/2000\n",
      "1/1 - 0s - loss: 2.0408 - val_loss: 2.0454 - 42ms/epoch - 42ms/step\n",
      "Epoch 403/2000\n",
      "1/1 - 0s - loss: 2.0405 - val_loss: 2.0450 - 43ms/epoch - 43ms/step\n",
      "Epoch 404/2000\n",
      "1/1 - 0s - loss: 2.0402 - val_loss: 2.0447 - 40ms/epoch - 40ms/step\n",
      "Epoch 405/2000\n",
      "1/1 - 0s - loss: 2.0398 - val_loss: 2.0443 - 41ms/epoch - 41ms/step\n",
      "Epoch 406/2000\n",
      "1/1 - 0s - loss: 2.0395 - val_loss: 2.0440 - 42ms/epoch - 42ms/step\n",
      "Epoch 407/2000\n",
      "1/1 - 0s - loss: 2.0392 - val_loss: 2.0437 - 42ms/epoch - 42ms/step\n",
      "Epoch 408/2000\n",
      "1/1 - 0s - loss: 2.0389 - val_loss: 2.0434 - 41ms/epoch - 41ms/step\n",
      "Epoch 409/2000\n",
      "1/1 - 0s - loss: 2.0385 - val_loss: 2.0430 - 41ms/epoch - 41ms/step\n",
      "Epoch 410/2000\n",
      "1/1 - 0s - loss: 2.0382 - val_loss: 2.0427 - 42ms/epoch - 42ms/step\n",
      "Epoch 411/2000\n",
      "1/1 - 0s - loss: 2.0379 - val_loss: 2.0424 - 42ms/epoch - 42ms/step\n",
      "Epoch 412/2000\n",
      "1/1 - 0s - loss: 2.0376 - val_loss: 2.0421 - 41ms/epoch - 41ms/step\n",
      "Epoch 413/2000\n",
      "1/1 - 0s - loss: 2.0373 - val_loss: 2.0418 - 41ms/epoch - 41ms/step\n",
      "Epoch 414/2000\n",
      "1/1 - 0s - loss: 2.0370 - val_loss: 2.0416 - 42ms/epoch - 42ms/step\n",
      "Epoch 415/2000\n",
      "1/1 - 0s - loss: 2.0368 - val_loss: 2.0413 - 42ms/epoch - 42ms/step\n",
      "Epoch 416/2000\n",
      "1/1 - 0s - loss: 2.0365 - val_loss: 2.0412 - 41ms/epoch - 41ms/step\n",
      "Epoch 417/2000\n",
      "1/1 - 0s - loss: 2.0363 - val_loss: 2.0410 - 41ms/epoch - 41ms/step\n",
      "Epoch 418/2000\n",
      "1/1 - 0s - loss: 2.0363 - val_loss: 2.0409 - 41ms/epoch - 41ms/step\n",
      "Epoch 419/2000\n",
      "1/1 - 0s - loss: 2.0361 - val_loss: 2.0404 - 42ms/epoch - 42ms/step\n",
      "Epoch 420/2000\n",
      "1/1 - 0s - loss: 2.0357 - val_loss: 2.0399 - 40ms/epoch - 40ms/step\n",
      "Epoch 421/2000\n",
      "1/1 - 0s - loss: 2.0351 - val_loss: 2.0396 - 41ms/epoch - 41ms/step\n",
      "Epoch 422/2000\n",
      "1/1 - 0s - loss: 2.0348 - val_loss: 2.0395 - 41ms/epoch - 41ms/step\n",
      "Epoch 423/2000\n",
      "1/1 - 0s - loss: 2.0347 - val_loss: 2.0392 - 42ms/epoch - 42ms/step\n",
      "Epoch 424/2000\n",
      "1/1 - 0s - loss: 2.0344 - val_loss: 2.0387 - 40ms/epoch - 40ms/step\n",
      "Epoch 425/2000\n",
      "1/1 - 0s - loss: 2.0340 - val_loss: 2.0385 - 41ms/epoch - 41ms/step\n",
      "Epoch 426/2000\n",
      "1/1 - 0s - loss: 2.0337 - val_loss: 2.0384 - 41ms/epoch - 41ms/step\n",
      "Epoch 427/2000\n",
      "1/1 - 0s - loss: 2.0336 - val_loss: 2.0380 - 43ms/epoch - 43ms/step\n",
      "Epoch 428/2000\n",
      "1/1 - 0s - loss: 2.0332 - val_loss: 2.0376 - 41ms/epoch - 41ms/step\n",
      "Epoch 429/2000\n",
      "1/1 - 0s - loss: 2.0329 - val_loss: 2.0375 - 41ms/epoch - 41ms/step\n",
      "Epoch 430/2000\n",
      "1/1 - 0s - loss: 2.0327 - val_loss: 2.0373 - 41ms/epoch - 41ms/step\n",
      "Epoch 431/2000\n",
      "1/1 - 0s - loss: 2.0325 - val_loss: 2.0369 - 42ms/epoch - 42ms/step\n",
      "Epoch 432/2000\n",
      "1/1 - 0s - loss: 2.0321 - val_loss: 2.0367 - 41ms/epoch - 41ms/step\n",
      "Epoch 433/2000\n",
      "1/1 - 0s - loss: 2.0319 - val_loss: 2.0365 - 41ms/epoch - 41ms/step\n",
      "Epoch 434/2000\n",
      "1/1 - 0s - loss: 2.0317 - val_loss: 2.0362 - 41ms/epoch - 41ms/step\n",
      "Epoch 435/2000\n",
      "1/1 - 0s - loss: 2.0314 - val_loss: 2.0359 - 42ms/epoch - 42ms/step\n",
      "Epoch 436/2000\n",
      "1/1 - 0s - loss: 2.0312 - val_loss: 2.0357 - 41ms/epoch - 41ms/step\n",
      "Epoch 437/2000\n",
      "1/1 - 0s - loss: 2.0310 - val_loss: 2.0355 - 42ms/epoch - 42ms/step\n",
      "Epoch 438/2000\n",
      "1/1 - 0s - loss: 2.0307 - val_loss: 2.0352 - 42ms/epoch - 42ms/step\n",
      "Epoch 439/2000\n",
      "1/1 - 0s - loss: 2.0305 - val_loss: 2.0350 - 42ms/epoch - 42ms/step\n",
      "Epoch 440/2000\n",
      "1/1 - 0s - loss: 2.0303 - val_loss: 2.0348 - 42ms/epoch - 42ms/step\n",
      "Epoch 441/2000\n",
      "1/1 - 0s - loss: 2.0300 - val_loss: 2.0345 - 41ms/epoch - 41ms/step\n",
      "Epoch 442/2000\n",
      "1/1 - 0s - loss: 2.0298 - val_loss: 2.0343 - 41ms/epoch - 41ms/step\n",
      "Epoch 443/2000\n",
      "1/1 - 0s - loss: 2.0296 - val_loss: 2.0341 - 42ms/epoch - 42ms/step\n",
      "Epoch 444/2000\n",
      "1/1 - 0s - loss: 2.0294 - val_loss: 2.0339 - 41ms/epoch - 41ms/step\n",
      "Epoch 445/2000\n",
      "1/1 - 0s - loss: 2.0291 - val_loss: 2.0337 - 42ms/epoch - 42ms/step\n",
      "Epoch 446/2000\n",
      "1/1 - 0s - loss: 2.0289 - val_loss: 2.0335 - 41ms/epoch - 41ms/step\n",
      "Epoch 447/2000\n",
      "1/1 - 0s - loss: 2.0287 - val_loss: 2.0332 - 41ms/epoch - 41ms/step\n",
      "Epoch 448/2000\n",
      "1/1 - 0s - loss: 2.0285 - val_loss: 2.0330 - 41ms/epoch - 41ms/step\n",
      "Epoch 449/2000\n",
      "1/1 - 0s - loss: 2.0283 - val_loss: 2.0328 - 41ms/epoch - 41ms/step\n",
      "Epoch 450/2000\n",
      "1/1 - 0s - loss: 2.0281 - val_loss: 2.0326 - 42ms/epoch - 42ms/step\n",
      "Epoch 451/2000\n",
      "1/1 - 0s - loss: 2.0279 - val_loss: 2.0324 - 42ms/epoch - 42ms/step\n",
      "Epoch 452/2000\n",
      "1/1 - 0s - loss: 2.0277 - val_loss: 2.0322 - 41ms/epoch - 41ms/step\n",
      "Epoch 453/2000\n",
      "1/1 - 0s - loss: 2.0274 - val_loss: 2.0320 - 41ms/epoch - 41ms/step\n",
      "Epoch 454/2000\n",
      "1/1 - 0s - loss: 2.0272 - val_loss: 2.0318 - 42ms/epoch - 42ms/step\n",
      "Epoch 455/2000\n",
      "1/1 - 0s - loss: 2.0270 - val_loss: 2.0316 - 42ms/epoch - 42ms/step\n",
      "Epoch 456/2000\n",
      "1/1 - 0s - loss: 2.0268 - val_loss: 2.0314 - 44ms/epoch - 44ms/step\n",
      "Epoch 457/2000\n",
      "1/1 - 0s - loss: 2.0266 - val_loss: 2.0312 - 41ms/epoch - 41ms/step\n",
      "Epoch 458/2000\n",
      "1/1 - 0s - loss: 2.0265 - val_loss: 2.0310 - 42ms/epoch - 42ms/step\n",
      "Epoch 459/2000\n",
      "1/1 - 0s - loss: 2.0263 - val_loss: 2.0308 - 41ms/epoch - 41ms/step\n",
      "Epoch 460/2000\n",
      "1/1 - 0s - loss: 2.0261 - val_loss: 2.0306 - 42ms/epoch - 42ms/step\n",
      "Epoch 461/2000\n",
      "1/1 - 0s - loss: 2.0259 - val_loss: 2.0305 - 41ms/epoch - 41ms/step\n",
      "Epoch 462/2000\n",
      "1/1 - 0s - loss: 2.0257 - val_loss: 2.0303 - 42ms/epoch - 42ms/step\n",
      "Epoch 463/2000\n",
      "1/1 - 0s - loss: 2.0255 - val_loss: 2.0301 - 41ms/epoch - 41ms/step\n",
      "Epoch 464/2000\n",
      "1/1 - 0s - loss: 2.0253 - val_loss: 2.0299 - 41ms/epoch - 41ms/step\n",
      "Epoch 465/2000\n",
      "1/1 - 0s - loss: 2.0251 - val_loss: 2.0297 - 41ms/epoch - 41ms/step\n",
      "Epoch 466/2000\n",
      "1/1 - 0s - loss: 2.0250 - val_loss: 2.0295 - 43ms/epoch - 43ms/step\n",
      "Epoch 467/2000\n",
      "1/1 - 0s - loss: 2.0248 - val_loss: 2.0294 - 42ms/epoch - 42ms/step\n",
      "Epoch 468/2000\n",
      "1/1 - 0s - loss: 2.0246 - val_loss: 2.0292 - 41ms/epoch - 41ms/step\n",
      "Epoch 469/2000\n",
      "1/1 - 0s - loss: 2.0244 - val_loss: 2.0290 - 41ms/epoch - 41ms/step\n",
      "Epoch 470/2000\n",
      "1/1 - 0s - loss: 2.0243 - val_loss: 2.0289 - 43ms/epoch - 43ms/step\n",
      "Epoch 471/2000\n",
      "1/1 - 0s - loss: 2.0241 - val_loss: 2.0287 - 42ms/epoch - 42ms/step\n",
      "Epoch 472/2000\n",
      "1/1 - 0s - loss: 2.0239 - val_loss: 2.0285 - 41ms/epoch - 41ms/step\n",
      "Epoch 473/2000\n",
      "1/1 - 0s - loss: 2.0237 - val_loss: 2.0283 - 41ms/epoch - 41ms/step\n",
      "Epoch 474/2000\n",
      "1/1 - 0s - loss: 2.0236 - val_loss: 2.0282 - 42ms/epoch - 42ms/step\n",
      "Epoch 475/2000\n",
      "1/1 - 0s - loss: 2.0234 - val_loss: 2.0280 - 41ms/epoch - 41ms/step\n",
      "Epoch 476/2000\n",
      "1/1 - 0s - loss: 2.0232 - val_loss: 2.0278 - 40ms/epoch - 40ms/step\n",
      "Epoch 477/2000\n",
      "1/1 - 0s - loss: 2.0231 - val_loss: 2.0277 - 41ms/epoch - 41ms/step\n",
      "Epoch 478/2000\n",
      "1/1 - 0s - loss: 2.0229 - val_loss: 2.0275 - 42ms/epoch - 42ms/step\n",
      "Epoch 479/2000\n",
      "1/1 - 0s - loss: 2.0228 - val_loss: 2.0274 - 42ms/epoch - 42ms/step\n",
      "Epoch 480/2000\n",
      "1/1 - 0s - loss: 2.0226 - val_loss: 2.0272 - 41ms/epoch - 41ms/step\n",
      "Epoch 481/2000\n",
      "1/1 - 0s - loss: 2.0224 - val_loss: 2.0271 - 41ms/epoch - 41ms/step\n",
      "Epoch 482/2000\n",
      "1/1 - 0s - loss: 2.0223 - val_loss: 2.0269 - 43ms/epoch - 43ms/step\n",
      "Epoch 483/2000\n",
      "1/1 - 0s - loss: 2.0221 - val_loss: 2.0268 - 42ms/epoch - 42ms/step\n",
      "Epoch 484/2000\n",
      "1/1 - 0s - loss: 2.0220 - val_loss: 2.0266 - 41ms/epoch - 41ms/step\n",
      "Epoch 485/2000\n",
      "1/1 - 0s - loss: 2.0218 - val_loss: 2.0265 - 41ms/epoch - 41ms/step\n",
      "Epoch 486/2000\n",
      "1/1 - 0s - loss: 2.0217 - val_loss: 2.0263 - 43ms/epoch - 43ms/step\n",
      "Epoch 487/2000\n",
      "1/1 - 0s - loss: 2.0215 - val_loss: 2.0262 - 41ms/epoch - 41ms/step\n",
      "Epoch 488/2000\n",
      "1/1 - 0s - loss: 2.0214 - val_loss: 2.0260 - 40ms/epoch - 40ms/step\n",
      "Epoch 489/2000\n",
      "1/1 - 0s - loss: 2.0212 - val_loss: 2.0259 - 41ms/epoch - 41ms/step\n",
      "Epoch 490/2000\n",
      "1/1 - 0s - loss: 2.0211 - val_loss: 2.0257 - 42ms/epoch - 42ms/step\n",
      "Epoch 491/2000\n",
      "1/1 - 0s - loss: 2.0209 - val_loss: 2.0256 - 43ms/epoch - 43ms/step\n",
      "Epoch 492/2000\n",
      "1/1 - 0s - loss: 2.0208 - val_loss: 2.0254 - 41ms/epoch - 41ms/step\n",
      "Epoch 493/2000\n",
      "1/1 - 0s - loss: 2.0206 - val_loss: 2.0253 - 40ms/epoch - 40ms/step\n",
      "Epoch 494/2000\n",
      "1/1 - 0s - loss: 2.0205 - val_loss: 2.0252 - 42ms/epoch - 42ms/step\n",
      "Epoch 495/2000\n",
      "1/1 - 0s - loss: 2.0204 - val_loss: 2.0250 - 42ms/epoch - 42ms/step\n",
      "Epoch 496/2000\n",
      "1/1 - 0s - loss: 2.0202 - val_loss: 2.0249 - 41ms/epoch - 41ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 497/2000\n",
      "1/1 - 0s - loss: 2.0201 - val_loss: 2.0247 - 40ms/epoch - 40ms/step\n",
      "Epoch 498/2000\n",
      "1/1 - 0s - loss: 2.0199 - val_loss: 2.0246 - 41ms/epoch - 41ms/step\n",
      "Epoch 499/2000\n",
      "1/1 - 0s - loss: 2.0198 - val_loss: 2.0245 - 43ms/epoch - 43ms/step\n",
      "Epoch 500/2000\n",
      "1/1 - 0s - loss: 2.0197 - val_loss: 2.0243 - 41ms/epoch - 41ms/step\n",
      "Epoch 501/2000\n",
      "1/1 - 0s - loss: 2.0195 - val_loss: 2.0242 - 40ms/epoch - 40ms/step\n",
      "Epoch 502/2000\n",
      "1/1 - 0s - loss: 2.0194 - val_loss: 2.0241 - 41ms/epoch - 41ms/step\n",
      "Epoch 503/2000\n",
      "1/1 - 0s - loss: 2.0193 - val_loss: 2.0239 - 43ms/epoch - 43ms/step\n",
      "Epoch 504/2000\n",
      "1/1 - 0s - loss: 2.0191 - val_loss: 2.0238 - 41ms/epoch - 41ms/step\n",
      "Epoch 505/2000\n",
      "1/1 - 0s - loss: 2.0190 - val_loss: 2.0237 - 40ms/epoch - 40ms/step\n",
      "Epoch 506/2000\n",
      "1/1 - 0s - loss: 2.0189 - val_loss: 2.0236 - 41ms/epoch - 41ms/step\n",
      "Epoch 507/2000\n",
      "1/1 - 0s - loss: 2.0188 - val_loss: 2.0234 - 43ms/epoch - 43ms/step\n",
      "Epoch 508/2000\n",
      "1/1 - 0s - loss: 2.0186 - val_loss: 2.0233 - 42ms/epoch - 42ms/step\n",
      "Epoch 509/2000\n",
      "1/1 - 0s - loss: 2.0185 - val_loss: 2.0232 - 40ms/epoch - 40ms/step\n",
      "Epoch 510/2000\n",
      "1/1 - 0s - loss: 2.0184 - val_loss: 2.0231 - 41ms/epoch - 41ms/step\n",
      "Epoch 511/2000\n",
      "1/1 - 0s - loss: 2.0182 - val_loss: 2.0229 - 42ms/epoch - 42ms/step\n",
      "Epoch 512/2000\n",
      "1/1 - 0s - loss: 2.0181 - val_loss: 2.0228 - 43ms/epoch - 43ms/step\n",
      "Epoch 513/2000\n",
      "1/1 - 0s - loss: 2.0180 - val_loss: 2.0227 - 40ms/epoch - 40ms/step\n",
      "Epoch 514/2000\n",
      "1/1 - 0s - loss: 2.0179 - val_loss: 2.0226 - 44ms/epoch - 44ms/step\n",
      "Epoch 515/2000\n",
      "1/1 - 0s - loss: 2.0178 - val_loss: 2.0225 - 43ms/epoch - 43ms/step\n",
      "Epoch 516/2000\n",
      "1/1 - 0s - loss: 2.0176 - val_loss: 2.0224 - 43ms/epoch - 43ms/step\n",
      "Epoch 517/2000\n",
      "1/1 - 0s - loss: 2.0175 - val_loss: 2.0222 - 41ms/epoch - 41ms/step\n",
      "Epoch 518/2000\n",
      "1/1 - 0s - loss: 2.0174 - val_loss: 2.0221 - 39ms/epoch - 39ms/step\n",
      "Epoch 519/2000\n",
      "1/1 - 0s - loss: 2.0173 - val_loss: 2.0220 - 43ms/epoch - 43ms/step\n",
      "Epoch 520/2000\n",
      "1/1 - 0s - loss: 2.0172 - val_loss: 2.0219 - 43ms/epoch - 43ms/step\n",
      "Epoch 521/2000\n",
      "1/1 - 0s - loss: 2.0171 - val_loss: 2.0218 - 41ms/epoch - 41ms/step\n",
      "Epoch 522/2000\n",
      "1/1 - 0s - loss: 2.0169 - val_loss: 2.0217 - 39ms/epoch - 39ms/step\n",
      "Epoch 523/2000\n",
      "1/1 - 0s - loss: 2.0168 - val_loss: 2.0216 - 43ms/epoch - 43ms/step\n",
      "Epoch 524/2000\n",
      "1/1 - 0s - loss: 2.0167 - val_loss: 2.0215 - 43ms/epoch - 43ms/step\n",
      "Epoch 525/2000\n",
      "1/1 - 0s - loss: 2.0166 - val_loss: 2.0213 - 42ms/epoch - 42ms/step\n",
      "Epoch 526/2000\n",
      "1/1 - 0s - loss: 2.0165 - val_loss: 2.0212 - 40ms/epoch - 40ms/step\n",
      "Epoch 527/2000\n",
      "1/1 - 0s - loss: 2.0164 - val_loss: 2.0211 - 41ms/epoch - 41ms/step\n",
      "Epoch 528/2000\n",
      "1/1 - 0s - loss: 2.0163 - val_loss: 2.0210 - 43ms/epoch - 43ms/step\n",
      "Epoch 529/2000\n",
      "1/1 - 0s - loss: 2.0162 - val_loss: 2.0209 - 43ms/epoch - 43ms/step\n",
      "Epoch 530/2000\n",
      "1/1 - 0s - loss: 2.0161 - val_loss: 2.0208 - 39ms/epoch - 39ms/step\n",
      "Epoch 531/2000\n",
      "1/1 - 0s - loss: 2.0159 - val_loss: 2.0207 - 40ms/epoch - 40ms/step\n",
      "Epoch 532/2000\n",
      "1/1 - 0s - loss: 2.0158 - val_loss: 2.0206 - 43ms/epoch - 43ms/step\n",
      "Epoch 533/2000\n",
      "1/1 - 0s - loss: 2.0157 - val_loss: 2.0205 - 43ms/epoch - 43ms/step\n",
      "Epoch 534/2000\n",
      "1/1 - 0s - loss: 2.0156 - val_loss: 2.0204 - 40ms/epoch - 40ms/step\n",
      "Epoch 535/2000\n",
      "1/1 - 0s - loss: 2.0155 - val_loss: 2.0203 - 39ms/epoch - 39ms/step\n",
      "Epoch 536/2000\n",
      "1/1 - 0s - loss: 2.0154 - val_loss: 2.0202 - 42ms/epoch - 42ms/step\n",
      "Epoch 537/2000\n",
      "1/1 - 0s - loss: 2.0153 - val_loss: 2.0201 - 43ms/epoch - 43ms/step\n",
      "Epoch 538/2000\n",
      "1/1 - 0s - loss: 2.0152 - val_loss: 2.0201 - 41ms/epoch - 41ms/step\n",
      "Epoch 539/2000\n",
      "1/1 - 0s - loss: 2.0152 - val_loss: 2.0200 - 39ms/epoch - 39ms/step\n",
      "Epoch 540/2000\n",
      "1/1 - 0s - loss: 2.0151 - val_loss: 2.0200 - 43ms/epoch - 43ms/step\n",
      "Epoch 541/2000\n",
      "1/1 - 0s - loss: 2.0151 - val_loss: 2.0199 - 42ms/epoch - 42ms/step\n",
      "Epoch 542/2000\n",
      "1/1 - 0s - loss: 2.0150 - val_loss: 2.0199 - 40ms/epoch - 40ms/step\n",
      "Epoch 543/2000\n",
      "1/1 - 0s - loss: 2.0150 - val_loss: 2.0198 - 40ms/epoch - 40ms/step\n",
      "Epoch 544/2000\n",
      "1/1 - 0s - loss: 2.0149 - val_loss: 2.0196 - 41ms/epoch - 41ms/step\n",
      "Epoch 545/2000\n",
      "1/1 - 0s - loss: 2.0147 - val_loss: 2.0194 - 43ms/epoch - 43ms/step\n",
      "Epoch 546/2000\n",
      "1/1 - 0s - loss: 2.0145 - val_loss: 2.0192 - 43ms/epoch - 43ms/step\n",
      "Epoch 547/2000\n",
      "1/1 - 0s - loss: 2.0143 - val_loss: 2.0192 - 40ms/epoch - 40ms/step\n",
      "Epoch 548/2000\n",
      "1/1 - 0s - loss: 2.0143 - val_loss: 2.0192 - 40ms/epoch - 40ms/step\n",
      "Epoch 549/2000\n",
      "1/1 - 0s - loss: 2.0143 - val_loss: 2.0191 - 43ms/epoch - 43ms/step\n",
      "Epoch 550/2000\n",
      "1/1 - 0s - loss: 2.0142 - val_loss: 2.0189 - 43ms/epoch - 43ms/step\n",
      "Epoch 551/2000\n",
      "1/1 - 0s - loss: 2.0140 - val_loss: 2.0188 - 40ms/epoch - 40ms/step\n",
      "Epoch 552/2000\n",
      "1/1 - 0s - loss: 2.0139 - val_loss: 2.0188 - 39ms/epoch - 39ms/step\n",
      "Epoch 553/2000\n",
      "1/1 - 0s - loss: 2.0138 - val_loss: 2.0187 - 42ms/epoch - 42ms/step\n",
      "Epoch 554/2000\n",
      "1/1 - 0s - loss: 2.0138 - val_loss: 2.0186 - 44ms/epoch - 44ms/step\n",
      "Epoch 555/2000\n",
      "1/1 - 0s - loss: 2.0136 - val_loss: 2.0185 - 42ms/epoch - 42ms/step\n",
      "Epoch 556/2000\n",
      "1/1 - 0s - loss: 2.0135 - val_loss: 2.0184 - 39ms/epoch - 39ms/step\n",
      "Epoch 557/2000\n",
      "1/1 - 0s - loss: 2.0134 - val_loss: 2.0183 - 42ms/epoch - 42ms/step\n",
      "Epoch 558/2000\n",
      "1/1 - 0s - loss: 2.0133 - val_loss: 2.0182 - 43ms/epoch - 43ms/step\n",
      "Epoch 559/2000\n",
      "1/1 - 0s - loss: 2.0133 - val_loss: 2.0181 - 42ms/epoch - 42ms/step\n",
      "Epoch 560/2000\n",
      "1/1 - 0s - loss: 2.0132 - val_loss: 2.0180 - 39ms/epoch - 39ms/step\n",
      "Epoch 561/2000\n",
      "1/1 - 0s - loss: 2.0130 - val_loss: 2.0179 - 41ms/epoch - 41ms/step\n",
      "Epoch 562/2000\n",
      "1/1 - 0s - loss: 2.0130 - val_loss: 2.0179 - 41ms/epoch - 41ms/step\n",
      "Epoch 563/2000\n",
      "1/1 - 0s - loss: 2.0129 - val_loss: 2.0178 - 43ms/epoch - 43ms/step\n",
      "Epoch 564/2000\n",
      "1/1 - 0s - loss: 2.0128 - val_loss: 2.0177 - 40ms/epoch - 40ms/step\n",
      "Epoch 565/2000\n",
      "1/1 - 0s - loss: 2.0127 - val_loss: 2.0176 - 39ms/epoch - 39ms/step\n",
      "Epoch 566/2000\n",
      "1/1 - 0s - loss: 2.0126 - val_loss: 2.0175 - 43ms/epoch - 43ms/step\n",
      "Epoch 567/2000\n",
      "1/1 - 0s - loss: 2.0125 - val_loss: 2.0175 - 43ms/epoch - 43ms/step\n",
      "Epoch 568/2000\n",
      "1/1 - 0s - loss: 2.0125 - val_loss: 2.0174 - 40ms/epoch - 40ms/step\n",
      "Epoch 569/2000\n",
      "1/1 - 0s - loss: 2.0124 - val_loss: 2.0173 - 39ms/epoch - 39ms/step\n",
      "Epoch 570/2000\n",
      "1/1 - 0s - loss: 2.0123 - val_loss: 2.0172 - 42ms/epoch - 42ms/step\n",
      "Epoch 571/2000\n",
      "1/1 - 0s - loss: 2.0122 - val_loss: 2.0171 - 57ms/epoch - 57ms/step\n",
      "Epoch 572/2000\n",
      "1/1 - 0s - loss: 2.0121 - val_loss: 2.0171 - 40ms/epoch - 40ms/step\n",
      "Epoch 573/2000\n",
      "1/1 - 0s - loss: 2.0120 - val_loss: 2.0170 - 40ms/epoch - 40ms/step\n",
      "Epoch 574/2000\n",
      "1/1 - 0s - loss: 2.0120 - val_loss: 2.0169 - 43ms/epoch - 43ms/step\n",
      "Epoch 575/2000\n",
      "1/1 - 0s - loss: 2.0119 - val_loss: 2.0168 - 43ms/epoch - 43ms/step\n",
      "Epoch 576/2000\n",
      "1/1 - 0s - loss: 2.0118 - val_loss: 2.0167 - 42ms/epoch - 42ms/step\n",
      "Epoch 577/2000\n",
      "1/1 - 0s - loss: 2.0117 - val_loss: 2.0167 - 39ms/epoch - 39ms/step\n",
      "Epoch 578/2000\n",
      "1/1 - 0s - loss: 2.0116 - val_loss: 2.0166 - 42ms/epoch - 42ms/step\n",
      "Epoch 579/2000\n",
      "1/1 - 0s - loss: 2.0116 - val_loss: 2.0165 - 43ms/epoch - 43ms/step\n",
      "Epoch 580/2000\n",
      "1/1 - 0s - loss: 2.0115 - val_loss: 2.0164 - 42ms/epoch - 42ms/step\n",
      "Epoch 581/2000\n",
      "1/1 - 0s - loss: 2.0114 - val_loss: 2.0164 - 39ms/epoch - 39ms/step\n",
      "Epoch 582/2000\n",
      "1/1 - 0s - loss: 2.0113 - val_loss: 2.0163 - 41ms/epoch - 41ms/step\n",
      "Epoch 583/2000\n",
      "1/1 - 0s - loss: 2.0112 - val_loss: 2.0162 - 44ms/epoch - 44ms/step\n",
      "Epoch 584/2000\n",
      "1/1 - 0s - loss: 2.0112 - val_loss: 2.0162 - 44ms/epoch - 44ms/step\n",
      "Epoch 585/2000\n",
      "1/1 - 0s - loss: 2.0111 - val_loss: 2.0161 - 40ms/epoch - 40ms/step\n",
      "Epoch 586/2000\n",
      "1/1 - 0s - loss: 2.0110 - val_loss: 2.0160 - 40ms/epoch - 40ms/step\n",
      "Epoch 587/2000\n",
      "1/1 - 0s - loss: 2.0109 - val_loss: 2.0159 - 42ms/epoch - 42ms/step\n",
      "Epoch 588/2000\n",
      "1/1 - 0s - loss: 2.0109 - val_loss: 2.0159 - 43ms/epoch - 43ms/step\n",
      "Epoch 589/2000\n",
      "1/1 - 0s - loss: 2.0108 - val_loss: 2.0158 - 40ms/epoch - 40ms/step\n",
      "Epoch 590/2000\n",
      "1/1 - 0s - loss: 2.0107 - val_loss: 2.0157 - 41ms/epoch - 41ms/step\n",
      "Epoch 591/2000\n",
      "1/1 - 0s - loss: 2.0107 - val_loss: 2.0157 - 42ms/epoch - 42ms/step\n",
      "Epoch 592/2000\n",
      "1/1 - 0s - loss: 2.0106 - val_loss: 2.0156 - 43ms/epoch - 43ms/step\n",
      "Epoch 593/2000\n",
      "1/1 - 0s - loss: 2.0105 - val_loss: 2.0155 - 41ms/epoch - 41ms/step\n",
      "Epoch 594/2000\n",
      "1/1 - 0s - loss: 2.0104 - val_loss: 2.0155 - 39ms/epoch - 39ms/step\n",
      "Epoch 595/2000\n",
      "1/1 - 0s - loss: 2.0104 - val_loss: 2.0154 - 42ms/epoch - 42ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 596/2000\n",
      "1/1 - 0s - loss: 2.0103 - val_loss: 2.0153 - 43ms/epoch - 43ms/step\n",
      "Epoch 597/2000\n",
      "1/1 - 0s - loss: 2.0102 - val_loss: 2.0153 - 42ms/epoch - 42ms/step\n",
      "Epoch 598/2000\n",
      "1/1 - 0s - loss: 2.0102 - val_loss: 2.0152 - 40ms/epoch - 40ms/step\n",
      "Epoch 599/2000\n",
      "1/1 - 0s - loss: 2.0101 - val_loss: 2.0151 - 42ms/epoch - 42ms/step\n",
      "Epoch 600/2000\n",
      "1/1 - 0s - loss: 2.0100 - val_loss: 2.0151 - 43ms/epoch - 43ms/step\n",
      "Epoch 601/2000\n",
      "1/1 - 0s - loss: 2.0099 - val_loss: 2.0150 - 41ms/epoch - 41ms/step\n",
      "Epoch 602/2000\n",
      "1/1 - 0s - loss: 2.0099 - val_loss: 2.0150 - 40ms/epoch - 40ms/step\n",
      "Epoch 603/2000\n",
      "1/1 - 0s - loss: 2.0098 - val_loss: 2.0149 - 40ms/epoch - 40ms/step\n",
      "Epoch 604/2000\n",
      "1/1 - 0s - loss: 2.0097 - val_loss: 2.0148 - 43ms/epoch - 43ms/step\n",
      "Epoch 605/2000\n",
      "1/1 - 0s - loss: 2.0097 - val_loss: 2.0148 - 42ms/epoch - 42ms/step\n",
      "Epoch 606/2000\n",
      "1/1 - 0s - loss: 2.0096 - val_loss: 2.0147 - 40ms/epoch - 40ms/step\n",
      "Epoch 607/2000\n",
      "1/1 - 0s - loss: 2.0095 - val_loss: 2.0146 - 41ms/epoch - 41ms/step\n",
      "Epoch 608/2000\n",
      "1/1 - 0s - loss: 2.0095 - val_loss: 2.0146 - 43ms/epoch - 43ms/step\n",
      "Epoch 609/2000\n",
      "1/1 - 0s - loss: 2.0094 - val_loss: 2.0145 - 41ms/epoch - 41ms/step\n",
      "Epoch 610/2000\n",
      "1/1 - 0s - loss: 2.0093 - val_loss: 2.0145 - 41ms/epoch - 41ms/step\n",
      "Epoch 611/2000\n",
      "1/1 - 0s - loss: 2.0093 - val_loss: 2.0144 - 41ms/epoch - 41ms/step\n",
      "Epoch 612/2000\n",
      "1/1 - 0s - loss: 2.0092 - val_loss: 2.0143 - 42ms/epoch - 42ms/step\n",
      "Epoch 613/2000\n",
      "1/1 - 0s - loss: 2.0091 - val_loss: 2.0143 - 42ms/epoch - 42ms/step\n",
      "Epoch 614/2000\n",
      "1/1 - 0s - loss: 2.0091 - val_loss: 2.0142 - 41ms/epoch - 41ms/step\n",
      "Epoch 615/2000\n",
      "1/1 - 0s - loss: 2.0090 - val_loss: 2.0142 - 41ms/epoch - 41ms/step\n",
      "Epoch 616/2000\n",
      "1/1 - 0s - loss: 2.0090 - val_loss: 2.0141 - 42ms/epoch - 42ms/step\n",
      "Epoch 617/2000\n",
      "1/1 - 0s - loss: 2.0089 - val_loss: 2.0141 - 43ms/epoch - 43ms/step\n",
      "Epoch 618/2000\n",
      "1/1 - 0s - loss: 2.0088 - val_loss: 2.0140 - 40ms/epoch - 40ms/step\n",
      "Epoch 619/2000\n",
      "1/1 - 0s - loss: 2.0088 - val_loss: 2.0140 - 40ms/epoch - 40ms/step\n",
      "Epoch 620/2000\n",
      "1/1 - 0s - loss: 2.0087 - val_loss: 2.0139 - 41ms/epoch - 41ms/step\n",
      "Epoch 621/2000\n",
      "1/1 - 0s - loss: 2.0086 - val_loss: 2.0138 - 42ms/epoch - 42ms/step\n",
      "Epoch 622/2000\n",
      "1/1 - 0s - loss: 2.0086 - val_loss: 2.0138 - 41ms/epoch - 41ms/step\n",
      "Epoch 623/2000\n",
      "1/1 - 0s - loss: 2.0085 - val_loss: 2.0137 - 40ms/epoch - 40ms/step\n",
      "Epoch 624/2000\n",
      "1/1 - 0s - loss: 2.0085 - val_loss: 2.0137 - 42ms/epoch - 42ms/step\n",
      "Epoch 625/2000\n",
      "1/1 - 0s - loss: 2.0084 - val_loss: 2.0136 - 42ms/epoch - 42ms/step\n",
      "Epoch 626/2000\n",
      "1/1 - 0s - loss: 2.0083 - val_loss: 2.0136 - 42ms/epoch - 42ms/step\n",
      "Epoch 627/2000\n",
      "1/1 - 0s - loss: 2.0083 - val_loss: 2.0135 - 40ms/epoch - 40ms/step\n",
      "Epoch 628/2000\n",
      "1/1 - 0s - loss: 2.0082 - val_loss: 2.0135 - 41ms/epoch - 41ms/step\n",
      "Epoch 629/2000\n",
      "1/1 - 0s - loss: 2.0082 - val_loss: 2.0134 - 42ms/epoch - 42ms/step\n",
      "Epoch 630/2000\n",
      "1/1 - 0s - loss: 2.0081 - val_loss: 2.0134 - 41ms/epoch - 41ms/step\n",
      "Epoch 631/2000\n",
      "1/1 - 0s - loss: 2.0081 - val_loss: 2.0133 - 41ms/epoch - 41ms/step\n",
      "Epoch 632/2000\n",
      "1/1 - 0s - loss: 2.0080 - val_loss: 2.0133 - 41ms/epoch - 41ms/step\n",
      "Epoch 633/2000\n",
      "1/1 - 0s - loss: 2.0080 - val_loss: 2.0133 - 41ms/epoch - 41ms/step\n",
      "Epoch 634/2000\n",
      "1/1 - 0s - loss: 2.0080 - val_loss: 2.0132 - 41ms/epoch - 41ms/step\n",
      "Epoch 635/2000\n",
      "1/1 - 0s - loss: 2.0079 - val_loss: 2.0132 - 41ms/epoch - 41ms/step\n",
      "Epoch 636/2000\n",
      "1/1 - 0s - loss: 2.0079 - val_loss: 2.0131 - 41ms/epoch - 41ms/step\n",
      "Epoch 637/2000\n",
      "1/1 - 0s - loss: 2.0078 - val_loss: 2.0130 - 43ms/epoch - 43ms/step\n",
      "Epoch 638/2000\n",
      "1/1 - 0s - loss: 2.0077 - val_loss: 2.0130 - 42ms/epoch - 42ms/step\n",
      "Epoch 639/2000\n",
      "1/1 - 0s - loss: 2.0076 - val_loss: 2.0130 - 41ms/epoch - 41ms/step\n",
      "Epoch 640/2000\n",
      "1/1 - 0s - loss: 2.0076 - val_loss: 2.0130 - 39ms/epoch - 39ms/step\n",
      "Epoch 641/2000\n",
      "1/1 - 0s - loss: 2.0076 - val_loss: 2.0130 - 43ms/epoch - 43ms/step\n",
      "Epoch 642/2000\n",
      "1/1 - 0s - loss: 2.0076 - val_loss: 2.0129 - 41ms/epoch - 41ms/step\n",
      "Epoch 643/2000\n",
      "1/1 - 0s - loss: 2.0075 - val_loss: 2.0129 - 41ms/epoch - 41ms/step\n",
      "Epoch 644/2000\n",
      "1/1 - 0s - loss: 2.0075 - val_loss: 2.0128 - 41ms/epoch - 41ms/step\n",
      "Epoch 645/2000\n",
      "1/1 - 0s - loss: 2.0074 - val_loss: 2.0127 - 41ms/epoch - 41ms/step\n",
      "Epoch 646/2000\n",
      "1/1 - 0s - loss: 2.0073 - val_loss: 2.0126 - 42ms/epoch - 42ms/step\n",
      "Epoch 647/2000\n",
      "1/1 - 0s - loss: 2.0072 - val_loss: 2.0125 - 41ms/epoch - 41ms/step\n",
      "Epoch 648/2000\n",
      "1/1 - 0s - loss: 2.0071 - val_loss: 2.0125 - 41ms/epoch - 41ms/step\n",
      "Epoch 649/2000\n",
      "1/1 - 0s - loss: 2.0070 - val_loss: 2.0125 - 42ms/epoch - 42ms/step\n",
      "Epoch 650/2000\n",
      "1/1 - 0s - loss: 2.0070 - val_loss: 2.0125 - 42ms/epoch - 42ms/step\n",
      "Epoch 651/2000\n",
      "1/1 - 0s - loss: 2.0070 - val_loss: 2.0124 - 42ms/epoch - 42ms/step\n",
      "Epoch 652/2000\n",
      "1/1 - 0s - loss: 2.0070 - val_loss: 2.0123 - 41ms/epoch - 41ms/step\n",
      "Epoch 653/2000\n",
      "1/1 - 0s - loss: 2.0069 - val_loss: 2.0123 - 41ms/epoch - 41ms/step\n",
      "Epoch 654/2000\n",
      "1/1 - 0s - loss: 2.0068 - val_loss: 2.0122 - 43ms/epoch - 43ms/step\n",
      "Epoch 655/2000\n",
      "1/1 - 0s - loss: 2.0067 - val_loss: 2.0122 - 41ms/epoch - 41ms/step\n",
      "Epoch 656/2000\n",
      "1/1 - 0s - loss: 2.0067 - val_loss: 2.0121 - 41ms/epoch - 41ms/step\n",
      "Epoch 657/2000\n",
      "1/1 - 0s - loss: 2.0066 - val_loss: 2.0121 - 41ms/epoch - 41ms/step\n",
      "Epoch 658/2000\n",
      "1/1 - 0s - loss: 2.0066 - val_loss: 2.0121 - 42ms/epoch - 42ms/step\n",
      "Epoch 659/2000\n",
      "1/1 - 0s - loss: 2.0066 - val_loss: 2.0120 - 41ms/epoch - 41ms/step\n",
      "Epoch 660/2000\n",
      "1/1 - 0s - loss: 2.0065 - val_loss: 2.0119 - 41ms/epoch - 41ms/step\n",
      "Epoch 661/2000\n",
      "1/1 - 0s - loss: 2.0064 - val_loss: 2.0119 - 41ms/epoch - 41ms/step\n",
      "Epoch 662/2000\n",
      "1/1 - 0s - loss: 2.0064 - val_loss: 2.0118 - 43ms/epoch - 43ms/step\n",
      "Epoch 663/2000\n",
      "1/1 - 0s - loss: 2.0063 - val_loss: 2.0118 - 41ms/epoch - 41ms/step\n",
      "Epoch 664/2000\n",
      "1/1 - 0s - loss: 2.0063 - val_loss: 2.0118 - 41ms/epoch - 41ms/step\n",
      "Epoch 665/2000\n",
      "1/1 - 0s - loss: 2.0062 - val_loss: 2.0117 - 41ms/epoch - 41ms/step\n",
      "Epoch 666/2000\n",
      "1/1 - 0s - loss: 2.0062 - val_loss: 2.0117 - 43ms/epoch - 43ms/step\n",
      "Epoch 667/2000\n",
      "1/1 - 0s - loss: 2.0061 - val_loss: 2.0116 - 41ms/epoch - 41ms/step\n",
      "Epoch 668/2000\n",
      "1/1 - 0s - loss: 2.0061 - val_loss: 2.0116 - 41ms/epoch - 41ms/step\n",
      "Epoch 669/2000\n",
      "1/1 - 0s - loss: 2.0060 - val_loss: 2.0115 - 41ms/epoch - 41ms/step\n",
      "Epoch 670/2000\n",
      "1/1 - 0s - loss: 2.0060 - val_loss: 2.0115 - 42ms/epoch - 42ms/step\n",
      "Epoch 671/2000\n",
      "1/1 - 0s - loss: 2.0059 - val_loss: 2.0115 - 41ms/epoch - 41ms/step\n",
      "Epoch 672/2000\n",
      "1/1 - 0s - loss: 2.0059 - val_loss: 2.0114 - 42ms/epoch - 42ms/step\n",
      "Epoch 673/2000\n",
      "1/1 - 0s - loss: 2.0058 - val_loss: 2.0114 - 41ms/epoch - 41ms/step\n",
      "Epoch 674/2000\n",
      "1/1 - 0s - loss: 2.0058 - val_loss: 2.0114 - 41ms/epoch - 41ms/step\n",
      "Epoch 675/2000\n",
      "1/1 - 0s - loss: 2.0057 - val_loss: 2.0113 - 41ms/epoch - 41ms/step\n",
      "Epoch 676/2000\n",
      "1/1 - 0s - loss: 2.0057 - val_loss: 2.0113 - 41ms/epoch - 41ms/step\n",
      "Epoch 677/2000\n",
      "1/1 - 0s - loss: 2.0056 - val_loss: 2.0112 - 41ms/epoch - 41ms/step\n",
      "Epoch 678/2000\n",
      "1/1 - 0s - loss: 2.0056 - val_loss: 2.0112 - 42ms/epoch - 42ms/step\n",
      "Epoch 679/2000\n",
      "1/1 - 0s - loss: 2.0055 - val_loss: 2.0112 - 41ms/epoch - 41ms/step\n",
      "Epoch 680/2000\n",
      "1/1 - 0s - loss: 2.0055 - val_loss: 2.0111 - 42ms/epoch - 42ms/step\n",
      "Epoch 681/2000\n",
      "1/1 - 0s - loss: 2.0054 - val_loss: 2.0111 - 41ms/epoch - 41ms/step\n",
      "Epoch 682/2000\n",
      "1/1 - 0s - loss: 2.0054 - val_loss: 2.0110 - 41ms/epoch - 41ms/step\n",
      "Epoch 683/2000\n",
      "1/1 - 0s - loss: 2.0054 - val_loss: 2.0110 - 41ms/epoch - 41ms/step\n",
      "Epoch 684/2000\n",
      "1/1 - 0s - loss: 2.0053 - val_loss: 2.0110 - 41ms/epoch - 41ms/step\n",
      "Epoch 685/2000\n",
      "1/1 - 0s - loss: 2.0053 - val_loss: 2.0109 - 42ms/epoch - 42ms/step\n",
      "Epoch 686/2000\n",
      "1/1 - 0s - loss: 2.0052 - val_loss: 2.0109 - 42ms/epoch - 42ms/step\n",
      "Epoch 687/2000\n",
      "1/1 - 0s - loss: 2.0052 - val_loss: 2.0109 - 41ms/epoch - 41ms/step\n",
      "Epoch 688/2000\n",
      "1/1 - 0s - loss: 2.0051 - val_loss: 2.0108 - 41ms/epoch - 41ms/step\n",
      "Epoch 689/2000\n",
      "1/1 - 0s - loss: 2.0051 - val_loss: 2.0108 - 42ms/epoch - 42ms/step\n",
      "Epoch 690/2000\n",
      "1/1 - 0s - loss: 2.0050 - val_loss: 2.0107 - 41ms/epoch - 41ms/step\n",
      "Epoch 691/2000\n",
      "1/1 - 0s - loss: 2.0050 - val_loss: 2.0107 - 41ms/epoch - 41ms/step\n",
      "Epoch 692/2000\n",
      "1/1 - 0s - loss: 2.0049 - val_loss: 2.0107 - 41ms/epoch - 41ms/step\n",
      "Epoch 693/2000\n",
      "1/1 - 0s - loss: 2.0049 - val_loss: 2.0106 - 42ms/epoch - 42ms/step\n",
      "Epoch 694/2000\n",
      "1/1 - 0s - loss: 2.0049 - val_loss: 2.0106 - 41ms/epoch - 41ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 695/2000\n",
      "1/1 - 0s - loss: 2.0048 - val_loss: 2.0106 - 41ms/epoch - 41ms/step\n",
      "Epoch 696/2000\n",
      "1/1 - 0s - loss: 2.0048 - val_loss: 2.0105 - 41ms/epoch - 41ms/step\n",
      "Epoch 697/2000\n",
      "1/1 - 0s - loss: 2.0047 - val_loss: 2.0105 - 43ms/epoch - 43ms/step\n",
      "Epoch 698/2000\n",
      "1/1 - 0s - loss: 2.0047 - val_loss: 2.0105 - 41ms/epoch - 41ms/step\n",
      "Epoch 699/2000\n",
      "1/1 - 0s - loss: 2.0046 - val_loss: 2.0104 - 41ms/epoch - 41ms/step\n",
      "Epoch 700/2000\n",
      "1/1 - 0s - loss: 2.0046 - val_loss: 2.0104 - 41ms/epoch - 41ms/step\n",
      "Epoch 701/2000\n",
      "1/1 - 0s - loss: 2.0046 - val_loss: 2.0104 - 43ms/epoch - 43ms/step\n",
      "Epoch 702/2000\n",
      "1/1 - 0s - loss: 2.0045 - val_loss: 2.0103 - 41ms/epoch - 41ms/step\n",
      "Epoch 703/2000\n",
      "1/1 - 0s - loss: 2.0045 - val_loss: 2.0103 - 41ms/epoch - 41ms/step\n",
      "Epoch 704/2000\n",
      "1/1 - 0s - loss: 2.0044 - val_loss: 2.0103 - 41ms/epoch - 41ms/step\n",
      "Epoch 705/2000\n",
      "1/1 - 0s - loss: 2.0044 - val_loss: 2.0103 - 43ms/epoch - 43ms/step\n",
      "Epoch 706/2000\n",
      "1/1 - 0s - loss: 2.0044 - val_loss: 2.0102 - 41ms/epoch - 41ms/step\n",
      "Epoch 707/2000\n",
      "1/1 - 0s - loss: 2.0043 - val_loss: 2.0102 - 40ms/epoch - 40ms/step\n",
      "Epoch 708/2000\n",
      "1/1 - 0s - loss: 2.0043 - val_loss: 2.0102 - 41ms/epoch - 41ms/step\n",
      "Epoch 709/2000\n",
      "1/1 - 0s - loss: 2.0043 - val_loss: 2.0102 - 43ms/epoch - 43ms/step\n",
      "Epoch 710/2000\n",
      "1/1 - 0s - loss: 2.0043 - val_loss: 2.0102 - 40ms/epoch - 40ms/step\n",
      "Epoch 711/2000\n",
      "1/1 - 0s - loss: 2.0043 - val_loss: 2.0102 - 38ms/epoch - 38ms/step\n",
      "Epoch 712/2000\n",
      "1/1 - 0s - loss: 2.0043 - val_loss: 2.0103 - 39ms/epoch - 39ms/step\n",
      "Epoch 713/2000\n",
      "1/1 - 0s - loss: 2.0043 - val_loss: 2.0103 - 40ms/epoch - 40ms/step\n",
      "Epoch 714/2000\n",
      "1/1 - 0s - loss: 2.0043 - val_loss: 2.0103 - 58ms/epoch - 58ms/step\n"
     ]
    }
   ],
   "source": [
    "# Hyperparameters for the encoder below.\n",
    "# bert_module splits embed_dim across heads: key_dim = embed_dim // num_head (= 5 here).\n",
    "embed_dim = 10\n",
    "num_head = 2\n",
    "\n",
    "# Stop once val_loss has not improved for 5 consecutive epochs and restore the best weights.\n",
    "callback = keras.callbacks.EarlyStopping(monitor = 'val_loss', patience = 5, restore_best_weights = True)\n",
    "# Integer token ids; the sequence length is taken from the training array.\n",
    "inputs = layers.Input((x_masked_train.shape[1],), dtype=tf.int64)\n",
    "# NOTE(review): N comes from an earlier cell - presumably the vocabulary size; confirm.\n",
    "word_embeddings = layers.Embedding(N, embed_dim, name=\"word_embedding\")(inputs)\n",
    "# len(x_masked_train[0]) is the same sequence length as x_masked_train.shape[1] above.\n",
    "position_embeddings = PositionEmbedding(sequence_length=len(x_masked_train[0]))(word_embeddings)\n",
    "encoder_output = word_embeddings + position_embeddings\n",
    "\n",
    "# Encoder stack (range(1) => a single block). bert_module calls MultiHeadAttention with\n",
    "# use_causal_mask=True, so each position attends only to itself and earlier positions:\n",
    "# despite the 'mlm' names, this is a left-to-right prediction setup.\n",
    "for i in range(1):\n",
    "    encoder_output = bert_module(encoder_output, encoder_output, encoder_output, embed_dim, num_head, i)\n",
    "\n",
    "# Drop the last position's output, leaving seq_len - 1 predictions per sentence.\n",
    "# NOTE(review): assumes y_masked_labels_train holds one integer target per remaining\n",
    "# position (e.g. the input shifted left by one) - confirm against the data-prep cell.\n",
    "encoder_output = keras.layers.Lambda(lambda x: x[:,:-1,:], name='slice')(encoder_output)\n",
    "# Softmax over N classes at every remaining position.\n",
    "mlm_output = layers.Dense(N, name=\"mlm_cls\", activation=\"softmax\")(encoder_output)\n",
    "mlm_model = keras.Model(inputs = inputs, outputs = mlm_output)\n",
    "adam = Adam()\n",
    "# The sparse cross-entropy loss expects integer class ids as targets, not one-hot vectors.\n",
    "mlm_model.compile(loss='sparse_categorical_crossentropy', optimizer=adam)\n",
    "# validation_split=0.5 holds out half of x_masked_train for validation; with\n",
    "# batch_size=5000 the log shows a single step (1/1) per epoch.\n",
    "# NOTE(review): Keras takes validation_split samples from the end of the arrays,\n",
    "# before shuffling - confirm the rows are not ordered in a way that biases the split.\n",
    "history = mlm_model.fit(x_masked_train, y_masked_labels_train,\n",
    "                        validation_split = 0.5, callbacks = [callback], \n",
    "                        epochs=2000, batch_size=5000, \n",
    "                        verbose=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "aa33aaff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the final token of every sampled test sequence with one batched forward\n",
    "# pass, instead of compiling a new backend function for every row.\n",
    "# The last output position is the model's prediction for the last input token,\n",
    "# so each row is scored against its own final element.\n",
    "n_eval = min(1000, x_masked_test.shape[0])\n",
    "# Fixed-seed generator so every model is scored on the same subset, regardless\n",
    "# of how much global NumPy randomness earlier cells consumed.\n",
    "eval_rng = np.random.default_rng(seed)\n",
    "x_test_subset = x_masked_test[eval_rng.choice(x_masked_test.shape[0], size=n_eval, replace=False)]\n",
    "\n",
    "last_token_probs = mlm_model.predict(x_test_subset, verbose=0)[:, -1, :]  # (n_eval, N)\n",
    "targets = x_test_subset[:, -1]\n",
    "acc = (last_token_probs.argmax(axis=1) == targets).astype(int)  # 1 = top-1 hit\n",
    "prob = last_token_probs[np.arange(n_eval), targets]  # probability of the true token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "fe4d2103",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.0, 0.002540875)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Top-1 accuracy on the final token, and the mean probability given to the true final token.\n",
    "np.mean(acc), np.mean(prob)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "0005b730",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2000\n",
      "1/1 - 4s - loss: 3.6708 - val_loss: 3.1908 - 4s/epoch - 4s/step\n",
      "Epoch 2/2000\n",
      "1/1 - 0s - loss: 3.1555 - val_loss: 3.0720 - 230ms/epoch - 230ms/step\n",
      "Epoch 3/2000\n",
      "1/1 - 0s - loss: 3.0556 - val_loss: 2.9671 - 234ms/epoch - 234ms/step\n",
      "Epoch 4/2000\n",
      "1/1 - 0s - loss: 2.9613 - val_loss: 2.9044 - 228ms/epoch - 228ms/step\n",
      "Epoch 5/2000\n",
      "1/1 - 0s - loss: 2.8997 - val_loss: 2.8759 - 230ms/epoch - 230ms/step\n",
      "Epoch 6/2000\n",
      "1/1 - 0s - loss: 2.8680 - val_loss: 2.8542 - 227ms/epoch - 227ms/step\n",
      "Epoch 7/2000\n",
      "1/1 - 0s - loss: 2.8421 - val_loss: 2.8374 - 224ms/epoch - 224ms/step\n",
      "Epoch 8/2000\n",
      "1/1 - 0s - loss: 2.8214 - val_loss: 2.8259 - 222ms/epoch - 222ms/step\n",
      "Epoch 9/2000\n",
      "1/1 - 0s - loss: 2.8076 - val_loss: 2.8151 - 227ms/epoch - 227ms/step\n",
      "Epoch 10/2000\n",
      "1/1 - 0s - loss: 2.7961 - val_loss: 2.8058 - 225ms/epoch - 225ms/step\n",
      "Epoch 11/2000\n",
      "1/1 - 0s - loss: 2.7876 - val_loss: 2.7992 - 229ms/epoch - 229ms/step\n",
      "Epoch 12/2000\n",
      "1/1 - 0s - loss: 2.7819 - val_loss: 2.7937 - 227ms/epoch - 227ms/step\n",
      "Epoch 13/2000\n",
      "1/1 - 0s - loss: 2.7767 - val_loss: 2.7892 - 223ms/epoch - 223ms/step\n",
      "Epoch 14/2000\n",
      "1/1 - 0s - loss: 2.7717 - val_loss: 2.7843 - 229ms/epoch - 229ms/step\n",
      "Epoch 15/2000\n",
      "1/1 - 0s - loss: 2.7660 - val_loss: 2.7773 - 228ms/epoch - 228ms/step\n",
      "Epoch 16/2000\n",
      "1/1 - 0s - loss: 2.7582 - val_loss: 2.7680 - 234ms/epoch - 234ms/step\n",
      "Epoch 17/2000\n",
      "1/1 - 0s - loss: 2.7486 - val_loss: 2.7577 - 229ms/epoch - 229ms/step\n",
      "Epoch 18/2000\n",
      "1/1 - 0s - loss: 2.7387 - val_loss: 2.7486 - 226ms/epoch - 226ms/step\n",
      "Epoch 19/2000\n",
      "1/1 - 0s - loss: 2.7306 - val_loss: 2.7409 - 225ms/epoch - 225ms/step\n",
      "Epoch 20/2000\n",
      "1/1 - 0s - loss: 2.7238 - val_loss: 2.7317 - 225ms/epoch - 225ms/step\n",
      "Epoch 21/2000\n",
      "1/1 - 0s - loss: 2.7150 - val_loss: 2.7191 - 229ms/epoch - 229ms/step\n",
      "Epoch 22/2000\n",
      "1/1 - 0s - loss: 2.7019 - val_loss: 2.7039 - 233ms/epoch - 233ms/step\n",
      "Epoch 23/2000\n",
      "1/1 - 0s - loss: 2.6855 - val_loss: 2.6852 - 227ms/epoch - 227ms/step\n",
      "Epoch 24/2000\n",
      "1/1 - 0s - loss: 2.6652 - val_loss: 2.6595 - 228ms/epoch - 228ms/step\n",
      "Epoch 25/2000\n",
      "1/1 - 0s - loss: 2.6380 - val_loss: 2.6244 - 226ms/epoch - 226ms/step\n",
      "Epoch 26/2000\n",
      "1/1 - 0s - loss: 2.6015 - val_loss: 2.5788 - 227ms/epoch - 227ms/step\n",
      "Epoch 27/2000\n",
      "1/1 - 0s - loss: 2.5539 - val_loss: 2.5162 - 227ms/epoch - 227ms/step\n",
      "Epoch 28/2000\n",
      "1/1 - 0s - loss: 2.4888 - val_loss: 2.4348 - 235ms/epoch - 235ms/step\n",
      "Epoch 29/2000\n",
      "1/1 - 0s - loss: 2.4049 - val_loss: 2.3402 - 226ms/epoch - 226ms/step\n",
      "Epoch 30/2000\n",
      "1/1 - 0s - loss: 2.3091 - val_loss: 2.2446 - 226ms/epoch - 226ms/step\n",
      "Epoch 31/2000\n",
      "1/1 - 0s - loss: 2.2137 - val_loss: 2.1667 - 227ms/epoch - 227ms/step\n",
      "Epoch 32/2000\n",
      "1/1 - 0s - loss: 2.1367 - val_loss: 2.1173 - 228ms/epoch - 228ms/step\n",
      "Epoch 33/2000\n",
      "1/1 - 0s - loss: 2.0881 - val_loss: 2.0932 - 224ms/epoch - 224ms/step\n",
      "Epoch 34/2000\n",
      "1/1 - 0s - loss: 2.0639 - val_loss: 2.0787 - 227ms/epoch - 227ms/step\n",
      "Epoch 35/2000\n",
      "1/1 - 0s - loss: 2.0482 - val_loss: 2.0671 - 232ms/epoch - 232ms/step\n",
      "Epoch 36/2000\n",
      "1/1 - 0s - loss: 2.0357 - val_loss: 2.0574 - 223ms/epoch - 223ms/step\n",
      "Epoch 37/2000\n",
      "1/1 - 0s - loss: 2.0253 - val_loss: 2.0497 - 234ms/epoch - 234ms/step\n",
      "Epoch 38/2000\n",
      "1/1 - 0s - loss: 2.0175 - val_loss: 2.0442 - 221ms/epoch - 221ms/step\n",
      "Epoch 39/2000\n",
      "1/1 - 0s - loss: 2.0120 - val_loss: 2.0406 - 220ms/epoch - 220ms/step\n",
      "Epoch 40/2000\n",
      "1/1 - 0s - loss: 2.0084 - val_loss: 2.0377 - 227ms/epoch - 227ms/step\n",
      "Epoch 41/2000\n",
      "1/1 - 0s - loss: 2.0057 - val_loss: 2.0343 - 228ms/epoch - 228ms/step\n",
      "Epoch 42/2000\n",
      "1/1 - 0s - loss: 2.0027 - val_loss: 2.0304 - 223ms/epoch - 223ms/step\n",
      "Epoch 43/2000\n",
      "1/1 - 0s - loss: 1.9991 - val_loss: 2.0268 - 228ms/epoch - 228ms/step\n",
      "Epoch 44/2000\n",
      "1/1 - 0s - loss: 1.9956 - val_loss: 2.0244 - 234ms/epoch - 234ms/step\n",
      "Epoch 45/2000\n",
      "1/1 - 0s - loss: 1.9931 - val_loss: 2.0231 - 224ms/epoch - 224ms/step\n",
      "Epoch 46/2000\n",
      "1/1 - 0s - loss: 1.9915 - val_loss: 2.0224 - 233ms/epoch - 233ms/step\n",
      "Epoch 47/2000\n",
      "1/1 - 0s - loss: 1.9904 - val_loss: 2.0215 - 222ms/epoch - 222ms/step\n",
      "Epoch 48/2000\n",
      "1/1 - 0s - loss: 1.9892 - val_loss: 2.0202 - 220ms/epoch - 220ms/step\n",
      "Epoch 49/2000\n",
      "1/1 - 0s - loss: 1.9877 - val_loss: 2.0186 - 227ms/epoch - 227ms/step\n",
      "Epoch 50/2000\n",
      "1/1 - 0s - loss: 1.9859 - val_loss: 2.0172 - 226ms/epoch - 226ms/step\n",
      "Epoch 51/2000\n",
      "1/1 - 0s - loss: 1.9843 - val_loss: 2.0163 - 229ms/epoch - 229ms/step\n",
      "Epoch 52/2000\n",
      "1/1 - 0s - loss: 1.9830 - val_loss: 2.0159 - 229ms/epoch - 229ms/step\n",
      "Epoch 53/2000\n",
      "1/1 - 0s - loss: 1.9821 - val_loss: 2.0160 - 221ms/epoch - 221ms/step\n",
      "Epoch 54/2000\n",
      "1/1 - 0s - loss: 1.9812 - val_loss: 2.0162 - 211ms/epoch - 211ms/step\n",
      "Epoch 55/2000\n",
      "1/1 - 0s - loss: 1.9803 - val_loss: 2.0164 - 214ms/epoch - 214ms/step\n",
      "Epoch 56/2000\n",
      "1/1 - 0s - loss: 1.9793 - val_loss: 2.0168 - 211ms/epoch - 211ms/step\n",
      "Epoch 57/2000\n",
      "1/1 - 0s - loss: 1.9781 - val_loss: 2.0172 - 238ms/epoch - 238ms/step\n"
     ]
    }
   ],
   "source": [
    "embed_dim = 100\n",
    "num_head = 2\n",
    "num_layers = 1  # number of stacked encoder blocks\n",
    "seq_len = x_masked_train.shape[1]\n",
    "\n",
    "# Re-seed Python, NumPy and TensorFlow so weight init does not depend on which\n",
    "# cells ran before this one; keeps the embed_dim comparisons reproducible.\n",
    "keras.utils.set_random_seed(seed)\n",
    "\n",
    "callback = keras.callbacks.EarlyStopping(monitor = 'val_loss', patience = 5, restore_best_weights = True)\n",
    "inputs = layers.Input((seq_len,), dtype=tf.int64)\n",
    "word_embeddings = layers.Embedding(N, embed_dim, name=\"word_embedding\")(inputs)\n",
    "position_embeddings = PositionEmbedding(sequence_length=seq_len)(word_embeddings)\n",
    "encoder_output = word_embeddings + position_embeddings\n",
    "\n",
    "for i in range(num_layers):\n",
    "    encoder_output = bert_module(encoder_output, encoder_output, encoder_output, embed_dim, num_head, i)\n",
    "\n",
    "# Drop the last position: the evaluation code reads output t as the prediction\n",
    "# for input token t+1, so the final position has nothing to predict\n",
    "# (assumes y_masked_labels_train has seq_len - 1 columns -- TODO confirm).\n",
    "encoder_output = keras.layers.Lambda(lambda x: x[:,:-1,:], name='slice')(encoder_output)\n",
    "mlm_output = layers.Dense(N, name=\"mlm_cls\", activation=\"softmax\")(encoder_output)\n",
    "mlm_model = keras.Model(inputs = inputs, outputs = mlm_output)\n",
    "adam = Adam()\n",
    "mlm_model.compile(loss='sparse_categorical_crossentropy', optimizer=adam)\n",
    "# validation_split=0.5 holds out the last half of x_masked_train for early stopping.\n",
    "history = mlm_model.fit(x_masked_train, y_masked_labels_train,\n",
    "                        validation_split = 0.5, callbacks = [callback],\n",
    "                        epochs=2000, batch_size=5000,\n",
    "                        verbose=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "beb07cf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the final token of every sampled test sequence with one batched forward\n",
    "# pass, instead of compiling a new backend function for every row.\n",
    "# The last output position is the model's prediction for the last input token,\n",
    "# so each row is scored against its own final element.\n",
    "n_eval = min(1000, x_masked_test.shape[0])\n",
    "# Fixed-seed generator so every model is scored on the same subset, regardless\n",
    "# of how much global NumPy randomness earlier cells consumed.\n",
    "eval_rng = np.random.default_rng(seed)\n",
    "x_test_subset = x_masked_test[eval_rng.choice(x_masked_test.shape[0], size=n_eval, replace=False)]\n",
    "\n",
    "last_token_probs = mlm_model.predict(x_test_subset, verbose=0)[:, -1, :]  # (n_eval, N)\n",
    "targets = x_test_subset[:, -1]\n",
    "acc = (last_token_probs.argmax(axis=1) == targets).astype(int)  # 1 = top-1 hit\n",
    "prob = last_token_probs[np.arange(n_eval), targets]  # probability of the true token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "dbdd4a69",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.0, 0.0052961786)"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Top-1 accuracy on the final token, and the mean probability given to the true final token.\n",
    "np.mean(acc), np.mean(prob)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "54f31bac",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[10,  0,  1, 10],\n",
       "       [10,  0,  2, 10],\n",
       "       [10,  0,  3, 10],\n",
       "       ...,\n",
       "       [19, 18, 15, 19],\n",
       "       [19, 18, 16, 19],\n",
       "       [19, 18, 17, 19]])"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the held-out sequences; the evaluation cells score the final token of each row.\n",
    "x_masked_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "458ae205",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
