{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "539ebd7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "### This notebook develops an RNN model of the target CSTR process using heterogeneous transfer learning\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn import preprocessing\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import mean_absolute_percentage_error\n",
    "from keras.models import Sequential\n",
    "from keras.layers import LSTM, Dense, SimpleRNN, Input, Activation, Dropout\n",
    "from keras import backend as K\n",
    "from tensorflow.keras.optimizers import Adam,SGD\n",
    "import tensorflow as tf\n",
    "from keras.models import Model\n",
    "from keras.models import load_model\n",
    "import time "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "069e2e34",
   "metadata": {},
   "source": [
    "# Generate Target dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3a4a2d8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Constant parameters of the target CSTR model (read by CSTR_simu).\n",
    "# Manipulated input: Q1 (swept over Q1_list); the feed concentrations CA0 and CB0 are held fixed.\n",
    "# State variables: C_A, C_B, T\n",
    "\n",
    "# inlet (feed) conditions\n",
    "T01 = 350\n",
    "CA0 = 5\n",
    "CB0 = 4\n",
    "F1 = 48.0133   # inlet flow rate\n",
    "\n",
    "R = 8.314      # gas constant used in the Arrhenius term\n",
    "V = 60         # reactor volume\n",
    "Fo1 = 43.34    # outlet flow rate\n",
    "\n",
    "# reaction kinetics and heat of reaction\n",
    "k0 = 1.528e+6\n",
    "E = 71160\n",
    "delta_H = -1.06e+5\n",
    "\n",
    "Pl1 = 639.15\n",
    "Cp = 2.5          # NOTE(review): not used by CSTR_simu, which uses Cp_modified - confirm it is still needed\n",
    "Cp_modified = 1.55\n",
    "\n",
    "# steady states (real values)\n",
    "T1_s = 310.5       # steady state of T\n",
    "CA1_s = 5.35319    # steady state of C_A\n",
    "CB1_s = 4.24528    # steady state of C_B\n",
    "Q1_s = -3279600    # steady state of the input Q\n",
    "\n",
    "t_final = 0.05     # length of one control period (simulation horizon per sample)\n",
    "t_step = 0.01      # forward-Euler integration step inside one control period"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d8e4402d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10,)\n",
      "shape of x_deviation is (1000, 3)\n"
     ]
    }
   ],
   "source": [
    "# Build the grid of input levels and initial states for the target CSTR.\n",
    "# The grid sizes should stay consistent with the basic CSTR dataset.\n",
    "# All values are REAL (absolute) values, not deviations from the steady state;\n",
    "# the name x_deviation is kept only because later cells use it.\n",
    "\n",
    "# input levels for Q1 (real values)\n",
    "Q1_list = np.linspace(-4320000, -2880000, 10, endpoint=True)\n",
    "# initial states: symmetric offsets around the steady state, expressed in real values\n",
    "T1_initial = np.linspace(-10, 10, 10, endpoint=True) + T1_s\n",
    "CA1_initial = np.linspace(-0.2, 0.2, 10, endpoint=True) + CA1_s\n",
    "CB1_initial = np.linspace(-0.2, 0.2, 10, endpoint=True) + CB1_s\n",
    "\n",
    "print(CA1_initial.shape)\n",
    "\n",
    "# Every combination of the three grids is kept (no stability-region filtering is applied).\n",
    "# indexing='ij' plus C-order ravel() enumerates T (slowest), then C_A, then C_B (fastest).\n",
    "T1_grid, CA1_grid, CB1_grid = np.meshgrid(T1_initial, CA1_initial, CB1_initial, indexing='ij')\n",
    "\n",
    "# row vectors of shape (1, N)\n",
    "CA1_start = CA1_grid.ravel()[np.newaxis, :]\n",
    "CB1_start = CB1_grid.ravel()[np.newaxis, :]\n",
    "T1_start = T1_grid.ravel()[np.newaxis, :]\n",
    "\n",
    "# one row per initial state, columns ordered (C_A, C_B, T)\n",
    "x_deviation = np.concatenate((CA1_start.T, CB1_start.T, T1_start.T), axis=1)\n",
    "print(\"shape of x_deviation is {}\".format(x_deviation.shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "16075e8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def CSTR_simu(Q1, t_final, t_step, CA1_initial, CB1_initial,T1_initial):\n",
    "    \"\"\"Simulate the target CSTR with the forward (explicit) Euler method.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    Q1 : value of the input Q, held constant over the whole horizon\n",
    "    t_final : simulation horizon\n",
    "    t_step : Euler step size; int(round(t_final / t_step)) steps are taken\n",
    "    CA1_initial, CB1_initial, T1_initial : initial states in real (not deviation) form.\n",
    "        Scalars, or numpy arrays of a common shape to simulate a batch element-wise.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    CA1_list, CB1_list, T1_list : lists holding the state after each Euler step\n",
    "        (the initial state itself is not included).\n",
    "\n",
    "    The model constants (k0, E, R, F1, Fo1, V, CA0, CB0, T01, delta_H, Pl1,\n",
    "    Cp_modified) are read from module-level variables.\n",
    "    \"\"\"\n",
    "    CA1_list = list()  # evolution of C_A over time\n",
    "    CB1_list = list()  # evolution of C_B over time\n",
    "    T1_list = list()   # evolution of T over time\n",
    "\n",
    "    # the states are integrated in real value form\n",
    "    CA1 = CA1_initial\n",
    "    CB1 = CB1_initial\n",
    "    T1 = T1_initial\n",
    "\n",
    "    # round() guards against float division landing just below an integer\n",
    "    # (e.g. 0.3 / 0.1 == 2.9999999999999996), which int() alone would truncate\n",
    "    n_steps = int(round(t_final / t_step))\n",
    "    for _ in range(n_steps):\n",
    "        par_reaction1 = k0 * np.exp(-E/(R*T1))  # Arrhenius rate constant\n",
    "\n",
    "        dCA1 = (F1*CA0 - Fo1 *CA1)/V - par_reaction1 * CA1 * CB1\n",
    "        dCB1 = (F1*CB0 - Fo1 *CB1)/V - par_reaction1 * CA1 * CB1\n",
    "        dT1 = (F1*T01 - Fo1 *T1)/V - delta_H/(Pl1 * Cp_modified) * par_reaction1 * CA1 * CB1 + Q1/(Pl1 * Cp_modified * V)\n",
    "\n",
    "        # rebind rather than `+=`: with array inputs, `+=` would modify the caller's\n",
    "        # arrays in place and make every list entry alias the same array\n",
    "        CA1 = CA1 + dCA1 * t_step\n",
    "        CB1 = CB1 + dCB1 * t_step\n",
    "        T1 = T1 + dT1 * t_step\n",
    "\n",
    "        CA1_list.append(CA1)  # in real value form\n",
    "        CB1_list.append(CB1)\n",
    "        T1_list.append(T1)\n",
    "\n",
    "    return CA1_list, CB1_list, T1_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b400bf04",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simulate every (Q1, initial state) pair and collect the RNN inputs and the state trajectories.\n",
    "\n",
    "CA1_output = list()\n",
    "CB1_output = list()\n",
    "T1_output = list()\n",
    "\n",
    "# RNN inputs per sample: the input Q1 and the initial states\n",
    "Q1_input = list()\n",
    "CA1_input = list()\n",
    "CB1_input = list()\n",
    "T1_input = list()\n",
    "\n",
    "# The loop names CA1_0 / CB1_0 / T1_0 deliberately differ from CA1_initial / CB1_initial /\n",
    "# T1_initial so the initial-state grids built earlier are not overwritten by scalars.\n",
    "for Q1 in Q1_list:\n",
    "    for CA1_0, CB1_0, T1_0 in x_deviation:\n",
    "        Q1_input.append(Q1)\n",
    "        CA1_input.append(CA1_0)\n",
    "        CB1_input.append(CB1_0)\n",
    "        T1_input.append(T1_0)\n",
    "\n",
    "        CA1_list, CB1_list, T1_list = CSTR_simu(Q1, t_final, t_step, CA1_0, CB1_0, T1_0)\n",
    "\n",
    "        CA1_output.append(CA1_list)\n",
    "        CB1_output.append(CB1_list)\n",
    "        T1_output.append(T1_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d18fd7d5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RNN_input shape is (10000, 1, 4)\n",
      "RNN_input shape is (10000, 5, 4)\n"
     ]
    }
   ],
   "source": [
    "# Collate the RNN input into shape (samples, timesteps, variables);\n",
    "# variables are ordered (Q1, T, C_A, C_B).\n",
    "\n",
    "Q1_input = np.array(Q1_input).reshape(-1, 1, 1)\n",
    "CA1_input = np.array(CA1_input).reshape(-1, 1, 1)\n",
    "CB1_input = np.array(CB1_input).reshape(-1, 1, 1)\n",
    "T1_input = np.array(T1_input).reshape(-1, 1, 1)\n",
    "\n",
    "# input value and initial states of each sample, as a single time step\n",
    "RNN_input = np.concatenate((Q1_input, T1_input, CA1_input, CB1_input), axis=2)\n",
    "print(\"RNN_input shape is {}\".format(RNN_input.shape))\n",
    "\n",
    "# Repeat the same input vector at every time step so the input sequence is as long as the\n",
    "# simulated output trajectory (one entry per Euler step, i.e. t_final / t_step steps).\n",
    "# NOTE(review): feeding the initial state at every step (rather than only at the first)\n",
    "# is a modelling choice that should be confirmed.\n",
    "n_out_steps = len(CA1_output[0])\n",
    "RNN_input = RNN_input.repeat(n_out_steps, axis=1)\n",
    "print(\"RNN_input shape is {}\".format(RNN_input.shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "cdf611b1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RNN_output shape is (10000, 5, 3)\n"
     ]
    }
   ],
   "source": [
    "# Collate the RNN output into shape (samples, timesteps, variables);\n",
    "# variables are ordered (T, C_A, C_B), matching the state order used in RNN_input.\n",
    "# The timestep dimension is taken from the trajectories instead of a hard-coded 5, so a\n",
    "# change of t_final / t_step cannot silently mis-align samples in the reshape.\n",
    "\n",
    "CA1_output = np.asarray(CA1_output)\n",
    "CA1_output = CA1_output.reshape(CA1_output.shape[0], CA1_output.shape[1], 1)\n",
    "\n",
    "CB1_output = np.asarray(CB1_output)\n",
    "CB1_output = CB1_output.reshape(CB1_output.shape[0], CB1_output.shape[1], 1)\n",
    "\n",
    "T1_output = np.asarray(T1_output)\n",
    "T1_output = T1_output.reshape(T1_output.shape[0], T1_output.shape[1], 1)\n",
    "\n",
    "RNN_output = np.concatenate((T1_output, CA1_output, CB1_output), axis=2)\n",
    "print(\"RNN_output shape is {}\".format(RNN_output.shape))  # number of samples x timestep x variables"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aa5fda0a",
   "metadata": {},
   "source": [
    "# Train/test split and normalization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "6b8005e3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(500, 5, 4) (9500, 5, 4) (500, 5, 3) (9500, 5, 3)\n"
     ]
    }
   ],
   "source": [
    "# Split the TARGET dataset: test_size=0.95 keeps only 5% of the samples for training,\n",
    "# the remaining 95% are held out for testing.\n",
    "X_train, X_test, y_train, y_test = train_test_split(RNN_input, RNN_output, test_size=0.95, random_state=123)\n",
    "\n",
    "# Fit the scalers on the training data only: one mean/std per variable, pooled over\n",
    "# samples and time steps. Shapes are read from the arrays instead of hard-coded.\n",
    "scaler_X = preprocessing.StandardScaler().fit(X_train.reshape(-1, X_train.shape[-1]))\n",
    "scaler_y = preprocessing.StandardScaler().fit(y_train.reshape(-1, y_train.shape[-1]))\n",
    "\n",
    "def _scale(scaler, arr):\n",
    "    \"\"\"Apply a fitted per-variable scaler to a (samples, timesteps, variables) array, keeping its shape.\"\"\"\n",
    "    return scaler.transform(arr.reshape(-1, arr.shape[-1])).reshape(arr.shape)\n",
    "\n",
    "X_train = _scale(scaler_X, X_train)\n",
    "X_test = _scale(scaler_X, X_test)\n",
    "y_train = _scale(scaler_y, y_train)\n",
    "y_test = _scale(scaler_y, y_test)\n",
    "\n",
    "print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1f81109",
   "metadata": {},
   "source": [
    "# Develop the benchmark"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "25925acb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/500\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-06-30 15:42:50.435678: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\n",
      "2025-06-30 15:42:50.435717: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (wulab2-System-Product-Name): /proc/driver/nvidia/version does not exist\n",
      "2025-06-30 15:42:50.436156: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA\n",
      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2/2 - 0s - loss: 1.4347 - mse: 1.4347 - val_loss: 1.4496 - val_mse: 1.4496 - 498ms/epoch - 249ms/step\n",
      "Epoch 2/500\n",
      "2/2 - 0s - loss: 1.3697 - mse: 1.3697 - val_loss: 1.3858 - val_mse: 1.3858 - 15ms/epoch - 7ms/step\n",
      "Epoch 3/500\n",
      "2/2 - 0s - loss: 1.3076 - mse: 1.3076 - val_loss: 1.3239 - val_mse: 1.3239 - 15ms/epoch - 8ms/step\n",
      "Epoch 4/500\n",
      "2/2 - 0s - loss: 1.2460 - mse: 1.2460 - val_loss: 1.2638 - val_mse: 1.2638 - 18ms/epoch - 9ms/step\n",
      "Epoch 5/500\n",
      "2/2 - 0s - loss: 1.1874 - mse: 1.1874 - val_loss: 1.2055 - val_mse: 1.2055 - 16ms/epoch - 8ms/step\n",
      "Epoch 6/500\n",
      "2/2 - 0s - loss: 1.1301 - mse: 1.1301 - val_loss: 1.1490 - val_mse: 1.1490 - 18ms/epoch - 9ms/step\n",
      "Epoch 7/500\n",
      "2/2 - 0s - loss: 1.0751 - mse: 1.0751 - val_loss: 1.0942 - val_mse: 1.0942 - 16ms/epoch - 8ms/step\n",
      "Epoch 8/500\n",
      "2/2 - 0s - loss: 1.0215 - mse: 1.0215 - val_loss: 1.0413 - val_mse: 1.0413 - 21ms/epoch - 10ms/step\n",
      "Epoch 9/500\n",
      "2/2 - 0s - loss: 0.9701 - mse: 0.9701 - val_loss: 0.9903 - val_mse: 0.9903 - 19ms/epoch - 9ms/step\n",
      "Epoch 10/500\n",
      "2/2 - 0s - loss: 0.9211 - mse: 0.9211 - val_loss: 0.9412 - val_mse: 0.9412 - 19ms/epoch - 9ms/step\n",
      "Epoch 11/500\n",
      "2/2 - 0s - loss: 0.8730 - mse: 0.8730 - val_loss: 0.8940 - val_mse: 0.8940 - 19ms/epoch - 9ms/step\n",
      "Epoch 12/500\n",
      "2/2 - 0s - loss: 0.8282 - mse: 0.8282 - val_loss: 0.8485 - val_mse: 0.8485 - 16ms/epoch - 8ms/step\n",
      "Epoch 13/500\n",
      "2/2 - 0s - loss: 0.7847 - mse: 0.7847 - val_loss: 0.8051 - val_mse: 0.8051 - 18ms/epoch - 9ms/step\n",
      "Epoch 14/500\n",
      "2/2 - 0s - loss: 0.7437 - mse: 0.7437 - val_loss: 0.7637 - val_mse: 0.7637 - 15ms/epoch - 7ms/step\n",
      "Epoch 15/500\n",
      "2/2 - 0s - loss: 0.7045 - mse: 0.7045 - val_loss: 0.7242 - val_mse: 0.7242 - 17ms/epoch - 9ms/step\n",
      "Epoch 16/500\n",
      "2/2 - 0s - loss: 0.6668 - mse: 0.6668 - val_loss: 0.6866 - val_mse: 0.6866 - 22ms/epoch - 11ms/step\n",
      "Epoch 17/500\n",
      "2/2 - 0s - loss: 0.6313 - mse: 0.6313 - val_loss: 0.6509 - val_mse: 0.6509 - 18ms/epoch - 9ms/step\n",
      "Epoch 18/500\n",
      "2/2 - 0s - loss: 0.5973 - mse: 0.5973 - val_loss: 0.6169 - val_mse: 0.6169 - 19ms/epoch - 9ms/step\n",
      "Epoch 19/500\n",
      "2/2 - 0s - loss: 0.5653 - mse: 0.5653 - val_loss: 0.5845 - val_mse: 0.5845 - 15ms/epoch - 8ms/step\n",
      "Epoch 20/500\n",
      "2/2 - 0s - loss: 0.5358 - mse: 0.5358 - val_loss: 0.5538 - val_mse: 0.5538 - 18ms/epoch - 9ms/step\n",
      "Epoch 21/500\n",
      "2/2 - 0s - loss: 0.5071 - mse: 0.5071 - val_loss: 0.5248 - val_mse: 0.5248 - 15ms/epoch - 7ms/step\n",
      "Epoch 22/500\n",
      "2/2 - 0s - loss: 0.4802 - mse: 0.4802 - val_loss: 0.4975 - val_mse: 0.4975 - 19ms/epoch - 9ms/step\n",
      "Epoch 23/500\n",
      "2/2 - 0s - loss: 0.4547 - mse: 0.4547 - val_loss: 0.4718 - val_mse: 0.4718 - 22ms/epoch - 11ms/step\n",
      "Epoch 24/500\n",
      "2/2 - 0s - loss: 0.4310 - mse: 0.4310 - val_loss: 0.4476 - val_mse: 0.4476 - 18ms/epoch - 9ms/step\n",
      "Epoch 25/500\n",
      "2/2 - 0s - loss: 0.4084 - mse: 0.4084 - val_loss: 0.4249 - val_mse: 0.4249 - 19ms/epoch - 10ms/step\n",
      "Epoch 26/500\n",
      "2/2 - 0s - loss: 0.3876 - mse: 0.3876 - val_loss: 0.4035 - val_mse: 0.4035 - 21ms/epoch - 11ms/step\n",
      "Epoch 27/500\n",
      "2/2 - 0s - loss: 0.3677 - mse: 0.3677 - val_loss: 0.3835 - val_mse: 0.3835 - 19ms/epoch - 10ms/step\n",
      "Epoch 28/500\n",
      "2/2 - 0s - loss: 0.3494 - mse: 0.3494 - val_loss: 0.3647 - val_mse: 0.3647 - 18ms/epoch - 9ms/step\n",
      "Epoch 29/500\n",
      "2/2 - 0s - loss: 0.3321 - mse: 0.3321 - val_loss: 0.3471 - val_mse: 0.3471 - 20ms/epoch - 10ms/step\n",
      "Epoch 30/500\n",
      "2/2 - 0s - loss: 0.3158 - mse: 0.3158 - val_loss: 0.3306 - val_mse: 0.3306 - 25ms/epoch - 13ms/step\n",
      "Epoch 31/500\n",
      "2/2 - 0s - loss: 0.3009 - mse: 0.3009 - val_loss: 0.3152 - val_mse: 0.3152 - 17ms/epoch - 8ms/step\n",
      "Epoch 32/500\n",
      "2/2 - 0s - loss: 0.2867 - mse: 0.2867 - val_loss: 0.3010 - val_mse: 0.3010 - 18ms/epoch - 9ms/step\n",
      "Epoch 33/500\n",
      "2/2 - 0s - loss: 0.2735 - mse: 0.2735 - val_loss: 0.2876 - val_mse: 0.2876 - 15ms/epoch - 8ms/step\n",
      "Epoch 34/500\n",
      "2/2 - 0s - loss: 0.2615 - mse: 0.2615 - val_loss: 0.2751 - val_mse: 0.2751 - 17ms/epoch - 9ms/step\n",
      "Epoch 35/500\n",
      "2/2 - 0s - loss: 0.2500 - mse: 0.2500 - val_loss: 0.2635 - val_mse: 0.2635 - 15ms/epoch - 7ms/step\n",
      "Epoch 36/500\n",
      "2/2 - 0s - loss: 0.2394 - mse: 0.2394 - val_loss: 0.2526 - val_mse: 0.2526 - 15ms/epoch - 8ms/step\n",
      "Epoch 37/500\n",
      "2/2 - 0s - loss: 0.2297 - mse: 0.2297 - val_loss: 0.2425 - val_mse: 0.2425 - 15ms/epoch - 7ms/step\n",
      "Epoch 38/500\n",
      "2/2 - 0s - loss: 0.2204 - mse: 0.2204 - val_loss: 0.2331 - val_mse: 0.2331 - 16ms/epoch - 8ms/step\n",
      "Epoch 39/500\n",
      "2/2 - 0s - loss: 0.2119 - mse: 0.2119 - val_loss: 0.2243 - val_mse: 0.2243 - 16ms/epoch - 8ms/step\n",
      "Epoch 40/500\n",
      "2/2 - 0s - loss: 0.2040 - mse: 0.2040 - val_loss: 0.2161 - val_mse: 0.2161 - 17ms/epoch - 8ms/step\n",
      "Epoch 41/500\n",
      "2/2 - 0s - loss: 0.1965 - mse: 0.1965 - val_loss: 0.2085 - val_mse: 0.2085 - 17ms/epoch - 9ms/step\n",
      "Epoch 42/500\n",
      "2/2 - 0s - loss: 0.1897 - mse: 0.1897 - val_loss: 0.2014 - val_mse: 0.2014 - 15ms/epoch - 7ms/step\n",
      "Epoch 43/500\n",
      "2/2 - 0s - loss: 0.1832 - mse: 0.1832 - val_loss: 0.1948 - val_mse: 0.1948 - 16ms/epoch - 8ms/step\n",
      "Epoch 44/500\n",
      "2/2 - 0s - loss: 0.1771 - mse: 0.1771 - val_loss: 0.1885 - val_mse: 0.1885 - 14ms/epoch - 7ms/step\n",
      "Epoch 45/500\n",
      "2/2 - 0s - loss: 0.1715 - mse: 0.1715 - val_loss: 0.1826 - val_mse: 0.1826 - 16ms/epoch - 8ms/step\n",
      "Epoch 46/500\n",
      "2/2 - 0s - loss: 0.1662 - mse: 0.1662 - val_loss: 0.1771 - val_mse: 0.1771 - 17ms/epoch - 9ms/step\n",
      "Epoch 47/500\n",
      "2/2 - 0s - loss: 0.1612 - mse: 0.1612 - val_loss: 0.1719 - val_mse: 0.1719 - 16ms/epoch - 8ms/step\n",
      "Epoch 48/500\n",
      "2/2 - 0s - loss: 0.1566 - mse: 0.1566 - val_loss: 0.1670 - val_mse: 0.1670 - 17ms/epoch - 9ms/step\n",
      "Epoch 49/500\n",
      "2/2 - 0s - loss: 0.1521 - mse: 0.1521 - val_loss: 0.1623 - val_mse: 0.1623 - 15ms/epoch - 8ms/step\n",
      "Epoch 50/500\n",
      "2/2 - 0s - loss: 0.1480 - mse: 0.1480 - val_loss: 0.1579 - val_mse: 0.1579 - 18ms/epoch - 9ms/step\n",
      "Epoch 51/500\n",
      "2/2 - 0s - loss: 0.1441 - mse: 0.1441 - val_loss: 0.1537 - val_mse: 0.1537 - 21ms/epoch - 11ms/step\n",
      "Epoch 52/500\n",
      "2/2 - 0s - loss: 0.1403 - mse: 0.1403 - val_loss: 0.1498 - val_mse: 0.1498 - 16ms/epoch - 8ms/step\n",
      "Epoch 53/500\n",
      "2/2 - 0s - loss: 0.1368 - mse: 0.1368 - val_loss: 0.1461 - val_mse: 0.1461 - 18ms/epoch - 9ms/step\n",
      "Epoch 54/500\n",
      "2/2 - 0s - loss: 0.1335 - mse: 0.1335 - val_loss: 0.1425 - val_mse: 0.1425 - 18ms/epoch - 9ms/step\n",
      "Epoch 55/500\n",
      "2/2 - 0s - loss: 0.1303 - mse: 0.1303 - val_loss: 0.1391 - val_mse: 0.1391 - 19ms/epoch - 9ms/step\n",
      "Epoch 56/500\n",
      "2/2 - 0s - loss: 0.1272 - mse: 0.1272 - val_loss: 0.1359 - val_mse: 0.1359 - 22ms/epoch - 11ms/step\n",
      "Epoch 57/500\n",
      "2/2 - 0s - loss: 0.1244 - mse: 0.1244 - val_loss: 0.1328 - val_mse: 0.1328 - 19ms/epoch - 9ms/step\n",
      "Epoch 58/500\n",
      "2/2 - 0s - loss: 0.1216 - mse: 0.1216 - val_loss: 0.1299 - val_mse: 0.1299 - 19ms/epoch - 9ms/step\n",
      "Epoch 59/500\n",
      "2/2 - 0s - loss: 0.1190 - mse: 0.1190 - val_loss: 0.1271 - val_mse: 0.1271 - 18ms/epoch - 9ms/step\n",
      "Epoch 60/500\n",
      "2/2 - 0s - loss: 0.1164 - mse: 0.1164 - val_loss: 0.1244 - val_mse: 0.1244 - 24ms/epoch - 12ms/step\n",
      "Epoch 61/500\n",
      "2/2 - 0s - loss: 0.1140 - mse: 0.1140 - val_loss: 0.1218 - val_mse: 0.1218 - 18ms/epoch - 9ms/step\n",
      "Epoch 62/500\n",
      "2/2 - 0s - loss: 0.1117 - mse: 0.1117 - val_loss: 0.1193 - val_mse: 0.1193 - 17ms/epoch - 9ms/step\n",
      "Epoch 63/500\n",
      "2/2 - 0s - loss: 0.1094 - mse: 0.1094 - val_loss: 0.1169 - val_mse: 0.1169 - 14ms/epoch - 7ms/step\n",
      "Epoch 64/500\n",
      "2/2 - 0s - loss: 0.1073 - mse: 0.1073 - val_loss: 0.1146 - val_mse: 0.1146 - 17ms/epoch - 9ms/step\n",
      "Epoch 65/500\n",
      "2/2 - 0s - loss: 0.1052 - mse: 0.1052 - val_loss: 0.1124 - val_mse: 0.1124 - 12ms/epoch - 6ms/step\n",
      "Epoch 66/500\n",
      "2/2 - 0s - loss: 0.1032 - mse: 0.1032 - val_loss: 0.1102 - val_mse: 0.1102 - 15ms/epoch - 7ms/step\n",
      "Epoch 67/500\n",
      "2/2 - 0s - loss: 0.1012 - mse: 0.1012 - val_loss: 0.1082 - val_mse: 0.1082 - 14ms/epoch - 7ms/step\n",
      "Epoch 68/500\n",
      "2/2 - 0s - loss: 0.0994 - mse: 0.0994 - val_loss: 0.1061 - val_mse: 0.1061 - 17ms/epoch - 9ms/step\n",
      "Epoch 69/500\n",
      "2/2 - 0s - loss: 0.0975 - mse: 0.0975 - val_loss: 0.1042 - val_mse: 0.1042 - 17ms/epoch - 8ms/step\n",
      "Epoch 70/500\n",
      "2/2 - 0s - loss: 0.0957 - mse: 0.0957 - val_loss: 0.1023 - val_mse: 0.1023 - 16ms/epoch - 8ms/step\n",
      "Epoch 71/500\n",
      "2/2 - 0s - loss: 0.0941 - mse: 0.0941 - val_loss: 0.1005 - val_mse: 0.1005 - 16ms/epoch - 8ms/step\n",
      "Epoch 72/500\n",
      "2/2 - 0s - loss: 0.0924 - mse: 0.0924 - val_loss: 0.0987 - val_mse: 0.0987 - 18ms/epoch - 9ms/step\n",
      "Epoch 73/500\n",
      "2/2 - 0s - loss: 0.0908 - mse: 0.0908 - val_loss: 0.0970 - val_mse: 0.0970 - 15ms/epoch - 7ms/step\n",
      "Epoch 74/500\n",
      "2/2 - 0s - loss: 0.0892 - mse: 0.0892 - val_loss: 0.0953 - val_mse: 0.0953 - 16ms/epoch - 8ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 75/500\n",
      "2/2 - 0s - loss: 0.0877 - mse: 0.0877 - val_loss: 0.0937 - val_mse: 0.0937 - 14ms/epoch - 7ms/step\n",
      "Epoch 76/500\n",
      "2/2 - 0s - loss: 0.0862 - mse: 0.0862 - val_loss: 0.0921 - val_mse: 0.0921 - 18ms/epoch - 9ms/step\n",
      "Epoch 77/500\n",
      "2/2 - 0s - loss: 0.0848 - mse: 0.0848 - val_loss: 0.0906 - val_mse: 0.0906 - 18ms/epoch - 9ms/step\n",
      "Epoch 78/500\n",
      "2/2 - 0s - loss: 0.0833 - mse: 0.0833 - val_loss: 0.0891 - val_mse: 0.0891 - 22ms/epoch - 11ms/step\n",
      "Epoch 79/500\n",
      "2/2 - 0s - loss: 0.0820 - mse: 0.0820 - val_loss: 0.0876 - val_mse: 0.0876 - 18ms/epoch - 9ms/step\n",
      "Epoch 80/500\n",
      "2/2 - 0s - loss: 0.0806 - mse: 0.0806 - val_loss: 0.0862 - val_mse: 0.0862 - 17ms/epoch - 9ms/step\n",
      "Epoch 81/500\n",
      "2/2 - 0s - loss: 0.0794 - mse: 0.0794 - val_loss: 0.0848 - val_mse: 0.0848 - 16ms/epoch - 8ms/step\n",
      "Epoch 82/500\n",
      "2/2 - 0s - loss: 0.0781 - mse: 0.0781 - val_loss: 0.0835 - val_mse: 0.0835 - 16ms/epoch - 8ms/step\n",
      "Epoch 83/500\n",
      "2/2 - 0s - loss: 0.0768 - mse: 0.0768 - val_loss: 0.0821 - val_mse: 0.0821 - 16ms/epoch - 8ms/step\n",
      "Epoch 84/500\n",
      "2/2 - 0s - loss: 0.0756 - mse: 0.0756 - val_loss: 0.0809 - val_mse: 0.0809 - 17ms/epoch - 8ms/step\n",
      "Epoch 85/500\n",
      "2/2 - 0s - loss: 0.0744 - mse: 0.0744 - val_loss: 0.0796 - val_mse: 0.0796 - 16ms/epoch - 8ms/step\n",
      "Epoch 86/500\n",
      "2/2 - 0s - loss: 0.0732 - mse: 0.0732 - val_loss: 0.0784 - val_mse: 0.0784 - 17ms/epoch - 8ms/step\n",
      "Epoch 87/500\n",
      "2/2 - 0s - loss: 0.0721 - mse: 0.0721 - val_loss: 0.0772 - val_mse: 0.0772 - 18ms/epoch - 9ms/step\n",
      "Epoch 88/500\n",
      "2/2 - 0s - loss: 0.0710 - mse: 0.0710 - val_loss: 0.0760 - val_mse: 0.0760 - 18ms/epoch - 9ms/step\n",
      "Epoch 89/500\n",
      "2/2 - 0s - loss: 0.0699 - mse: 0.0699 - val_loss: 0.0749 - val_mse: 0.0749 - 20ms/epoch - 10ms/step\n",
      "Epoch 90/500\n",
      "2/2 - 0s - loss: 0.0688 - mse: 0.0688 - val_loss: 0.0738 - val_mse: 0.0738 - 19ms/epoch - 9ms/step\n",
      "Epoch 91/500\n",
      "2/2 - 0s - loss: 0.0678 - mse: 0.0678 - val_loss: 0.0727 - val_mse: 0.0727 - 17ms/epoch - 8ms/step\n",
      "Epoch 92/500\n",
      "2/2 - 0s - loss: 0.0668 - mse: 0.0668 - val_loss: 0.0716 - val_mse: 0.0716 - 16ms/epoch - 8ms/step\n",
      "Epoch 93/500\n",
      "2/2 - 0s - loss: 0.0658 - mse: 0.0658 - val_loss: 0.0706 - val_mse: 0.0706 - 16ms/epoch - 8ms/step\n",
      "Epoch 94/500\n",
      "2/2 - 0s - loss: 0.0648 - mse: 0.0648 - val_loss: 0.0696 - val_mse: 0.0696 - 16ms/epoch - 8ms/step\n",
      "Epoch 95/500\n",
      "2/2 - 0s - loss: 0.0638 - mse: 0.0638 - val_loss: 0.0686 - val_mse: 0.0686 - 16ms/epoch - 8ms/step\n",
      "Epoch 96/500\n",
      "2/2 - 0s - loss: 0.0629 - mse: 0.0629 - val_loss: 0.0676 - val_mse: 0.0676 - 16ms/epoch - 8ms/step\n",
      "Epoch 97/500\n",
      "2/2 - 0s - loss: 0.0620 - mse: 0.0620 - val_loss: 0.0667 - val_mse: 0.0667 - 16ms/epoch - 8ms/step\n",
      "Epoch 98/500\n",
      "2/2 - 0s - loss: 0.0611 - mse: 0.0611 - val_loss: 0.0658 - val_mse: 0.0658 - 17ms/epoch - 8ms/step\n",
      "Epoch 99/500\n",
      "2/2 - 0s - loss: 0.0602 - mse: 0.0602 - val_loss: 0.0649 - val_mse: 0.0649 - 16ms/epoch - 8ms/step\n",
      "Epoch 100/500\n",
      "2/2 - 0s - loss: 0.0594 - mse: 0.0594 - val_loss: 0.0640 - val_mse: 0.0640 - 16ms/epoch - 8ms/step\n",
      "Epoch 101/500\n",
      "2/2 - 0s - loss: 0.0585 - mse: 0.0585 - val_loss: 0.0631 - val_mse: 0.0631 - 13ms/epoch - 6ms/step\n",
      "Epoch 102/500\n",
      "2/2 - 0s - loss: 0.0577 - mse: 0.0577 - val_loss: 0.0623 - val_mse: 0.0623 - 15ms/epoch - 7ms/step\n",
      "Epoch 103/500\n",
      "2/2 - 0s - loss: 0.0569 - mse: 0.0569 - val_loss: 0.0614 - val_mse: 0.0614 - 14ms/epoch - 7ms/step\n",
      "Epoch 104/500\n",
      "2/2 - 0s - loss: 0.0561 - mse: 0.0561 - val_loss: 0.0606 - val_mse: 0.0606 - 17ms/epoch - 9ms/step\n",
      "Epoch 105/500\n",
      "2/2 - 0s - loss: 0.0553 - mse: 0.0553 - val_loss: 0.0598 - val_mse: 0.0598 - 17ms/epoch - 8ms/step\n",
      "Epoch 106/500\n",
      "2/2 - 0s - loss: 0.0546 - mse: 0.0546 - val_loss: 0.0591 - val_mse: 0.0591 - 17ms/epoch - 8ms/step\n",
      "Epoch 107/500\n",
      "2/2 - 0s - loss: 0.0539 - mse: 0.0539 - val_loss: 0.0583 - val_mse: 0.0583 - 16ms/epoch - 8ms/step\n",
      "Epoch 108/500\n",
      "2/2 - 0s - loss: 0.0531 - mse: 0.0531 - val_loss: 0.0575 - val_mse: 0.0575 - 16ms/epoch - 8ms/step\n",
      "Epoch 109/500\n",
      "2/2 - 0s - loss: 0.0524 - mse: 0.0524 - val_loss: 0.0568 - val_mse: 0.0568 - 17ms/epoch - 8ms/step\n",
      "Epoch 110/500\n",
      "2/2 - 0s - loss: 0.0517 - mse: 0.0517 - val_loss: 0.0561 - val_mse: 0.0561 - 15ms/epoch - 7ms/step\n",
      "Epoch 111/500\n",
      "2/2 - 0s - loss: 0.0511 - mse: 0.0511 - val_loss: 0.0553 - val_mse: 0.0553 - 14ms/epoch - 7ms/step\n",
      "Epoch 112/500\n",
      "2/2 - 0s - loss: 0.0504 - mse: 0.0504 - val_loss: 0.0546 - val_mse: 0.0546 - 15ms/epoch - 7ms/step\n",
      "Epoch 113/500\n",
      "2/2 - 0s - loss: 0.0498 - mse: 0.0498 - val_loss: 0.0539 - val_mse: 0.0539 - 17ms/epoch - 8ms/step\n",
      "Epoch 114/500\n",
      "2/2 - 0s - loss: 0.0491 - mse: 0.0491 - val_loss: 0.0533 - val_mse: 0.0533 - 17ms/epoch - 8ms/step\n",
      "Epoch 115/500\n",
      "2/2 - 0s - loss: 0.0485 - mse: 0.0485 - val_loss: 0.0526 - val_mse: 0.0526 - 16ms/epoch - 8ms/step\n",
      "Epoch 116/500\n",
      "2/2 - 0s - loss: 0.0479 - mse: 0.0479 - val_loss: 0.0520 - val_mse: 0.0520 - 16ms/epoch - 8ms/step\n",
      "Epoch 117/500\n",
      "2/2 - 0s - loss: 0.0473 - mse: 0.0473 - val_loss: 0.0513 - val_mse: 0.0513 - 17ms/epoch - 9ms/step\n",
      "Epoch 118/500\n",
      "2/2 - 0s - loss: 0.0467 - mse: 0.0467 - val_loss: 0.0507 - val_mse: 0.0507 - 15ms/epoch - 8ms/step\n",
      "Epoch 119/500\n",
      "2/2 - 0s - loss: 0.0461 - mse: 0.0461 - val_loss: 0.0501 - val_mse: 0.0501 - 16ms/epoch - 8ms/step\n",
      "Epoch 120/500\n",
      "2/2 - 0s - loss: 0.0455 - mse: 0.0455 - val_loss: 0.0495 - val_mse: 0.0495 - 15ms/epoch - 7ms/step\n",
      "Epoch 121/500\n",
      "2/2 - 0s - loss: 0.0450 - mse: 0.0450 - val_loss: 0.0490 - val_mse: 0.0490 - 16ms/epoch - 8ms/step\n",
      "Epoch 122/500\n",
      "2/2 - 0s - loss: 0.0445 - mse: 0.0445 - val_loss: 0.0484 - val_mse: 0.0484 - 16ms/epoch - 8ms/step\n",
      "Epoch 123/500\n",
      "2/2 - 0s - loss: 0.0439 - mse: 0.0439 - val_loss: 0.0478 - val_mse: 0.0478 - 16ms/epoch - 8ms/step\n",
      "Epoch 124/500\n",
      "2/2 - 0s - loss: 0.0434 - mse: 0.0434 - val_loss: 0.0473 - val_mse: 0.0473 - 16ms/epoch - 8ms/step\n",
      "Epoch 125/500\n",
      "2/2 - 0s - loss: 0.0429 - mse: 0.0429 - val_loss: 0.0468 - val_mse: 0.0468 - 16ms/epoch - 8ms/step\n",
      "Epoch 126/500\n",
      "2/2 - 0s - loss: 0.0424 - mse: 0.0424 - val_loss: 0.0463 - val_mse: 0.0463 - 17ms/epoch - 8ms/step\n",
      "Epoch 127/500\n",
      "2/2 - 0s - loss: 0.0419 - mse: 0.0419 - val_loss: 0.0457 - val_mse: 0.0457 - 18ms/epoch - 9ms/step\n",
      "Epoch 128/500\n",
      "2/2 - 0s - loss: 0.0414 - mse: 0.0414 - val_loss: 0.0452 - val_mse: 0.0452 - 15ms/epoch - 7ms/step\n",
      "Epoch 129/500\n",
      "2/2 - 0s - loss: 0.0410 - mse: 0.0410 - val_loss: 0.0447 - val_mse: 0.0447 - 16ms/epoch - 8ms/step\n",
      "Epoch 130/500\n",
      "2/2 - 0s - loss: 0.0405 - mse: 0.0405 - val_loss: 0.0443 - val_mse: 0.0443 - 14ms/epoch - 7ms/step\n",
      "Epoch 131/500\n",
      "2/2 - 0s - loss: 0.0400 - mse: 0.0400 - val_loss: 0.0438 - val_mse: 0.0438 - 17ms/epoch - 9ms/step\n",
      "Epoch 132/500\n",
      "2/2 - 0s - loss: 0.0396 - mse: 0.0396 - val_loss: 0.0433 - val_mse: 0.0433 - 15ms/epoch - 7ms/step\n",
      "Epoch 133/500\n",
      "2/2 - 0s - loss: 0.0391 - mse: 0.0391 - val_loss: 0.0429 - val_mse: 0.0429 - 15ms/epoch - 8ms/step\n",
      "Epoch 134/500\n",
      "2/2 - 0s - loss: 0.0387 - mse: 0.0387 - val_loss: 0.0424 - val_mse: 0.0424 - 14ms/epoch - 7ms/step\n",
      "Epoch 135/500\n",
      "2/2 - 0s - loss: 0.0383 - mse: 0.0383 - val_loss: 0.0420 - val_mse: 0.0420 - 16ms/epoch - 8ms/step\n",
      "Epoch 136/500\n",
      "2/2 - 0s - loss: 0.0379 - mse: 0.0379 - val_loss: 0.0416 - val_mse: 0.0416 - 17ms/epoch - 9ms/step\n",
      "Epoch 137/500\n",
      "2/2 - 0s - loss: 0.0375 - mse: 0.0375 - val_loss: 0.0412 - val_mse: 0.0412 - 14ms/epoch - 7ms/step\n",
      "Epoch 138/500\n",
      "2/2 - 0s - loss: 0.0371 - mse: 0.0371 - val_loss: 0.0407 - val_mse: 0.0407 - 15ms/epoch - 8ms/step\n",
      "Epoch 139/500\n",
      "2/2 - 0s - loss: 0.0367 - mse: 0.0367 - val_loss: 0.0403 - val_mse: 0.0403 - 15ms/epoch - 7ms/step\n",
      "Epoch 140/500\n",
      "2/2 - 0s - loss: 0.0363 - mse: 0.0363 - val_loss: 0.0399 - val_mse: 0.0399 - 16ms/epoch - 8ms/step\n",
      "Epoch 141/500\n",
      "2/2 - 0s - loss: 0.0359 - mse: 0.0359 - val_loss: 0.0395 - val_mse: 0.0395 - 17ms/epoch - 9ms/step\n",
      "Epoch 142/500\n",
      "2/2 - 0s - loss: 0.0355 - mse: 0.0355 - val_loss: 0.0391 - val_mse: 0.0391 - 15ms/epoch - 7ms/step\n",
      "Epoch 143/500\n",
      "2/2 - 0s - loss: 0.0352 - mse: 0.0352 - val_loss: 0.0387 - val_mse: 0.0387 - 16ms/epoch - 8ms/step\n",
      "Epoch 144/500\n",
      "2/2 - 0s - loss: 0.0348 - mse: 0.0348 - val_loss: 0.0383 - val_mse: 0.0383 - 15ms/epoch - 7ms/step\n",
      "Epoch 145/500\n",
      "2/2 - 0s - loss: 0.0344 - mse: 0.0344 - val_loss: 0.0379 - val_mse: 0.0379 - 17ms/epoch - 8ms/step\n",
      "Epoch 146/500\n",
      "2/2 - 0s - loss: 0.0341 - mse: 0.0341 - val_loss: 0.0376 - val_mse: 0.0376 - 15ms/epoch - 7ms/step\n",
      "Epoch 147/500\n",
      "2/2 - 0s - loss: 0.0338 - mse: 0.0338 - val_loss: 0.0372 - val_mse: 0.0372 - 15ms/epoch - 7ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 148/500\n",
      "2/2 - 0s - loss: 0.0334 - mse: 0.0334 - val_loss: 0.0368 - val_mse: 0.0368 - 15ms/epoch - 7ms/step\n",
      "Epoch 149/500\n",
      "2/2 - 0s - loss: 0.0331 - mse: 0.0331 - val_loss: 0.0364 - val_mse: 0.0364 - 17ms/epoch - 8ms/step\n",
      "Epoch 150/500\n",
      "2/2 - 0s - loss: 0.0328 - mse: 0.0328 - val_loss: 0.0361 - val_mse: 0.0361 - 17ms/epoch - 8ms/step\n",
      "Epoch 151/500\n",
      "2/2 - 0s - loss: 0.0324 - mse: 0.0324 - val_loss: 0.0358 - val_mse: 0.0358 - 16ms/epoch - 8ms/step\n",
      "Epoch 152/500\n",
      "2/2 - 0s - loss: 0.0321 - mse: 0.0321 - val_loss: 0.0354 - val_mse: 0.0354 - 16ms/epoch - 8ms/step\n",
      "Epoch 153/500\n",
      "2/2 - 0s - loss: 0.0318 - mse: 0.0318 - val_loss: 0.0351 - val_mse: 0.0351 - 16ms/epoch - 8ms/step\n",
      "Epoch 154/500\n",
      "2/2 - 0s - loss: 0.0315 - mse: 0.0315 - val_loss: 0.0348 - val_mse: 0.0348 - 16ms/epoch - 8ms/step\n",
      "Epoch 155/500\n",
      "2/2 - 0s - loss: 0.0312 - mse: 0.0312 - val_loss: 0.0345 - val_mse: 0.0345 - 31ms/epoch - 15ms/step\n",
      "Epoch 156/500\n",
      "2/2 - 0s - loss: 0.0309 - mse: 0.0309 - val_loss: 0.0341 - val_mse: 0.0341 - 15ms/epoch - 8ms/step\n",
      "Epoch 157/500\n",
      "2/2 - 0s - loss: 0.0306 - mse: 0.0306 - val_loss: 0.0338 - val_mse: 0.0338 - 15ms/epoch - 7ms/step\n",
      "Epoch 158/500\n",
      "2/2 - 0s - loss: 0.0303 - mse: 0.0303 - val_loss: 0.0335 - val_mse: 0.0335 - 17ms/epoch - 9ms/step\n",
      "Epoch 159/500\n",
      "2/2 - 0s - loss: 0.0300 - mse: 0.0300 - val_loss: 0.0332 - val_mse: 0.0332 - 17ms/epoch - 8ms/step\n",
      "Epoch 160/500\n",
      "2/2 - 0s - loss: 0.0297 - mse: 0.0297 - val_loss: 0.0329 - val_mse: 0.0329 - 17ms/epoch - 8ms/step\n",
      "Epoch 161/500\n",
      "2/2 - 0s - loss: 0.0294 - mse: 0.0294 - val_loss: 0.0326 - val_mse: 0.0326 - 16ms/epoch - 8ms/step\n",
      "Epoch 162/500\n",
      "2/2 - 0s - loss: 0.0292 - mse: 0.0292 - val_loss: 0.0323 - val_mse: 0.0323 - 16ms/epoch - 8ms/step\n",
      "Epoch 163/500\n",
      "2/2 - 0s - loss: 0.0289 - mse: 0.0289 - val_loss: 0.0320 - val_mse: 0.0320 - 16ms/epoch - 8ms/step\n",
      "Epoch 164/500\n",
      "2/2 - 0s - loss: 0.0286 - mse: 0.0286 - val_loss: 0.0317 - val_mse: 0.0317 - 17ms/epoch - 8ms/step\n",
      "Epoch 165/500\n",
      "2/2 - 0s - loss: 0.0283 - mse: 0.0283 - val_loss: 0.0314 - val_mse: 0.0314 - 18ms/epoch - 9ms/step\n",
      "Epoch 166/500\n",
      "2/2 - 0s - loss: 0.0281 - mse: 0.0281 - val_loss: 0.0312 - val_mse: 0.0312 - 14ms/epoch - 7ms/step\n",
      "Epoch 167/500\n",
      "2/2 - 0s - loss: 0.0278 - mse: 0.0278 - val_loss: 0.0309 - val_mse: 0.0309 - 15ms/epoch - 8ms/step\n",
      "Epoch 168/500\n",
      "2/2 - 0s - loss: 0.0276 - mse: 0.0276 - val_loss: 0.0306 - val_mse: 0.0306 - 15ms/epoch - 7ms/step\n",
      "Epoch 169/500\n",
      "2/2 - 0s - loss: 0.0273 - mse: 0.0273 - val_loss: 0.0304 - val_mse: 0.0304 - 18ms/epoch - 9ms/step\n",
      "Epoch 170/500\n",
      "2/2 - 0s - loss: 0.0271 - mse: 0.0271 - val_loss: 0.0301 - val_mse: 0.0301 - 15ms/epoch - 7ms/step\n",
      "Epoch 171/500\n",
      "2/2 - 0s - loss: 0.0268 - mse: 0.0268 - val_loss: 0.0298 - val_mse: 0.0298 - 15ms/epoch - 8ms/step\n",
      "Epoch 172/500\n",
      "2/2 - 0s - loss: 0.0266 - mse: 0.0266 - val_loss: 0.0296 - val_mse: 0.0296 - 15ms/epoch - 7ms/step\n",
      "Epoch 173/500\n",
      "2/2 - 0s - loss: 0.0264 - mse: 0.0264 - val_loss: 0.0293 - val_mse: 0.0293 - 17ms/epoch - 8ms/step\n",
      "Epoch 174/500\n",
      "2/2 - 0s - loss: 0.0261 - mse: 0.0261 - val_loss: 0.0291 - val_mse: 0.0291 - 18ms/epoch - 9ms/step\n",
      "Epoch 175/500\n",
      "2/2 - 0s - loss: 0.0259 - mse: 0.0259 - val_loss: 0.0288 - val_mse: 0.0288 - 18ms/epoch - 9ms/step\n",
      "Epoch 176/500\n",
      "2/2 - 0s - loss: 0.0257 - mse: 0.0257 - val_loss: 0.0286 - val_mse: 0.0286 - 21ms/epoch - 11ms/step\n",
      "Epoch 177/500\n",
      "2/2 - 0s - loss: 0.0254 - mse: 0.0254 - val_loss: 0.0283 - val_mse: 0.0283 - 18ms/epoch - 9ms/step\n",
      "Epoch 178/500\n",
      "2/2 - 0s - loss: 0.0252 - mse: 0.0252 - val_loss: 0.0281 - val_mse: 0.0281 - 17ms/epoch - 9ms/step\n",
      "Epoch 179/500\n",
      "2/2 - 0s - loss: 0.0250 - mse: 0.0250 - val_loss: 0.0279 - val_mse: 0.0279 - 17ms/epoch - 8ms/step\n",
      "Epoch 180/500\n",
      "2/2 - 0s - loss: 0.0248 - mse: 0.0248 - val_loss: 0.0276 - val_mse: 0.0276 - 18ms/epoch - 9ms/step\n",
      "Epoch 181/500\n",
      "2/2 - 0s - loss: 0.0246 - mse: 0.0246 - val_loss: 0.0274 - val_mse: 0.0274 - 18ms/epoch - 9ms/step\n",
      "Epoch 182/500\n",
      "2/2 - 0s - loss: 0.0244 - mse: 0.0244 - val_loss: 0.0272 - val_mse: 0.0272 - 21ms/epoch - 11ms/step\n",
      "Epoch 183/500\n",
      "2/2 - 0s - loss: 0.0242 - mse: 0.0242 - val_loss: 0.0269 - val_mse: 0.0269 - 18ms/epoch - 9ms/step\n",
      "Epoch 184/500\n",
      "2/2 - 0s - loss: 0.0239 - mse: 0.0239 - val_loss: 0.0267 - val_mse: 0.0267 - 17ms/epoch - 8ms/step\n",
      "Epoch 185/500\n",
      "2/2 - 0s - loss: 0.0237 - mse: 0.0237 - val_loss: 0.0265 - val_mse: 0.0265 - 16ms/epoch - 8ms/step\n",
      "Epoch 186/500\n",
      "2/2 - 0s - loss: 0.0235 - mse: 0.0235 - val_loss: 0.0263 - val_mse: 0.0263 - 16ms/epoch - 8ms/step\n",
      "Epoch 187/500\n",
      "2/2 - 0s - loss: 0.0233 - mse: 0.0233 - val_loss: 0.0261 - val_mse: 0.0261 - 16ms/epoch - 8ms/step\n",
      "Epoch 188/500\n",
      "2/2 - 0s - loss: 0.0232 - mse: 0.0232 - val_loss: 0.0259 - val_mse: 0.0259 - 16ms/epoch - 8ms/step\n",
      "Epoch 189/500\n",
      "2/2 - 0s - loss: 0.0230 - mse: 0.0230 - val_loss: 0.0257 - val_mse: 0.0257 - 16ms/epoch - 8ms/step\n",
      "Epoch 190/500\n",
      "2/2 - 0s - loss: 0.0228 - mse: 0.0228 - val_loss: 0.0255 - val_mse: 0.0255 - 17ms/epoch - 9ms/step\n",
      "Epoch 191/500\n",
      "2/2 - 0s - loss: 0.0226 - mse: 0.0226 - val_loss: 0.0253 - val_mse: 0.0253 - 15ms/epoch - 7ms/step\n",
      "Epoch 192/500\n",
      "2/2 - 0s - loss: 0.0224 - mse: 0.0224 - val_loss: 0.0251 - val_mse: 0.0251 - 15ms/epoch - 8ms/step\n",
      "Epoch 193/500\n",
      "2/2 - 0s - loss: 0.0222 - mse: 0.0222 - val_loss: 0.0249 - val_mse: 0.0249 - 15ms/epoch - 8ms/step\n",
      "Epoch 194/500\n",
      "2/2 - 0s - loss: 0.0220 - mse: 0.0220 - val_loss: 0.0247 - val_mse: 0.0247 - 17ms/epoch - 8ms/step\n",
      "Epoch 195/500\n",
      "2/2 - 0s - loss: 0.0219 - mse: 0.0219 - val_loss: 0.0245 - val_mse: 0.0245 - 14ms/epoch - 7ms/step\n",
      "Epoch 196/500\n",
      "2/2 - 0s - loss: 0.0217 - mse: 0.0217 - val_loss: 0.0243 - val_mse: 0.0243 - 14ms/epoch - 7ms/step\n",
      "Epoch 197/500\n",
      "2/2 - 0s - loss: 0.0215 - mse: 0.0215 - val_loss: 0.0242 - val_mse: 0.0242 - 15ms/epoch - 7ms/step\n",
      "Epoch 198/500\n",
      "2/2 - 0s - loss: 0.0213 - mse: 0.0213 - val_loss: 0.0240 - val_mse: 0.0240 - 17ms/epoch - 8ms/step\n",
      "Epoch 199/500\n",
      "2/2 - 0s - loss: 0.0212 - mse: 0.0212 - val_loss: 0.0238 - val_mse: 0.0238 - 17ms/epoch - 8ms/step\n",
      "Epoch 200/500\n",
      "2/2 - 0s - loss: 0.0210 - mse: 0.0210 - val_loss: 0.0236 - val_mse: 0.0236 - 16ms/epoch - 8ms/step\n",
      "Epoch 201/500\n",
      "2/2 - 0s - loss: 0.0208 - mse: 0.0208 - val_loss: 0.0234 - val_mse: 0.0234 - 16ms/epoch - 8ms/step\n",
      "Epoch 202/500\n",
      "2/2 - 0s - loss: 0.0207 - mse: 0.0207 - val_loss: 0.0233 - val_mse: 0.0233 - 17ms/epoch - 8ms/step\n",
      "Epoch 203/500\n",
      "2/2 - 0s - loss: 0.0205 - mse: 0.0205 - val_loss: 0.0231 - val_mse: 0.0231 - 18ms/epoch - 9ms/step\n",
      "Epoch 204/500\n",
      "2/2 - 0s - loss: 0.0203 - mse: 0.0203 - val_loss: 0.0229 - val_mse: 0.0229 - 15ms/epoch - 8ms/step\n",
      "Epoch 205/500\n",
      "2/2 - 0s - loss: 0.0202 - mse: 0.0202 - val_loss: 0.0227 - val_mse: 0.0227 - 16ms/epoch - 8ms/step\n",
      "Epoch 206/500\n",
      "2/2 - 0s - loss: 0.0200 - mse: 0.0200 - val_loss: 0.0226 - val_mse: 0.0226 - 14ms/epoch - 7ms/step\n",
      "Epoch 207/500\n",
      "2/2 - 0s - loss: 0.0199 - mse: 0.0199 - val_loss: 0.0224 - val_mse: 0.0224 - 17ms/epoch - 9ms/step\n",
      "Epoch 208/500\n",
      "2/2 - 0s - loss: 0.0197 - mse: 0.0197 - val_loss: 0.0222 - val_mse: 0.0222 - 15ms/epoch - 7ms/step\n",
      "Epoch 209/500\n",
      "2/2 - 0s - loss: 0.0196 - mse: 0.0196 - val_loss: 0.0221 - val_mse: 0.0221 - 16ms/epoch - 8ms/step\n",
      "Epoch 210/500\n",
      "2/2 - 0s - loss: 0.0194 - mse: 0.0194 - val_loss: 0.0219 - val_mse: 0.0219 - 14ms/epoch - 7ms/step\n",
      "Epoch 211/500\n",
      "2/2 - 0s - loss: 0.0193 - mse: 0.0193 - val_loss: 0.0217 - val_mse: 0.0217 - 17ms/epoch - 9ms/step\n",
      "Epoch 212/500\n",
      "2/2 - 0s - loss: 0.0191 - mse: 0.0191 - val_loss: 0.0216 - val_mse: 0.0216 - 14ms/epoch - 7ms/step\n",
      "Epoch 213/500\n",
      "2/2 - 0s - loss: 0.0190 - mse: 0.0190 - val_loss: 0.0214 - val_mse: 0.0214 - 15ms/epoch - 8ms/step\n",
      "Epoch 214/500\n",
      "2/2 - 0s - loss: 0.0188 - mse: 0.0188 - val_loss: 0.0213 - val_mse: 0.0213 - 14ms/epoch - 7ms/step\n",
      "Epoch 215/500\n",
      "2/2 - 0s - loss: 0.0187 - mse: 0.0187 - val_loss: 0.0211 - val_mse: 0.0211 - 16ms/epoch - 8ms/step\n",
      "Epoch 216/500\n",
      "2/2 - 0s - loss: 0.0185 - mse: 0.0185 - val_loss: 0.0210 - val_mse: 0.0210 - 16ms/epoch - 8ms/step\n",
      "Epoch 217/500\n",
      "2/2 - 0s - loss: 0.0184 - mse: 0.0184 - val_loss: 0.0208 - val_mse: 0.0208 - 18ms/epoch - 9ms/step\n",
      "Epoch 218/500\n",
      "2/2 - 0s - loss: 0.0183 - mse: 0.0183 - val_loss: 0.0207 - val_mse: 0.0207 - 15ms/epoch - 7ms/step\n",
      "Epoch 219/500\n",
      "2/2 - 0s - loss: 0.0181 - mse: 0.0181 - val_loss: 0.0205 - val_mse: 0.0205 - 15ms/epoch - 8ms/step\n",
      "Epoch 220/500\n",
      "2/2 - 0s - loss: 0.0180 - mse: 0.0180 - val_loss: 0.0204 - val_mse: 0.0204 - 15ms/epoch - 7ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 221/500\n",
      "2/2 - 0s - loss: 0.0179 - mse: 0.0179 - val_loss: 0.0202 - val_mse: 0.0202 - 17ms/epoch - 8ms/step\n",
      "Epoch 222/500\n",
      "2/2 - 0s - loss: 0.0177 - mse: 0.0177 - val_loss: 0.0201 - val_mse: 0.0201 - 16ms/epoch - 8ms/step\n",
      "Epoch 223/500\n",
      "2/2 - 0s - loss: 0.0176 - mse: 0.0176 - val_loss: 0.0199 - val_mse: 0.0199 - 16ms/epoch - 8ms/step\n",
      "Epoch 224/500\n",
      "2/2 - 0s - loss: 0.0175 - mse: 0.0175 - val_loss: 0.0198 - val_mse: 0.0198 - 16ms/epoch - 8ms/step\n",
      "Epoch 225/500\n",
      "2/2 - 0s - loss: 0.0173 - mse: 0.0173 - val_loss: 0.0197 - val_mse: 0.0197 - 16ms/epoch - 8ms/step\n",
      "Epoch 226/500\n",
      "2/2 - 0s - loss: 0.0172 - mse: 0.0172 - val_loss: 0.0195 - val_mse: 0.0195 - 17ms/epoch - 9ms/step\n",
      "Epoch 227/500\n",
      "2/2 - 0s - loss: 0.0171 - mse: 0.0171 - val_loss: 0.0194 - val_mse: 0.0194 - 15ms/epoch - 7ms/step\n",
      "Epoch 228/500\n",
      "2/2 - 0s - loss: 0.0170 - mse: 0.0170 - val_loss: 0.0193 - val_mse: 0.0193 - 15ms/epoch - 7ms/step\n",
      "Epoch 229/500\n",
      "2/2 - 0s - loss: 0.0168 - mse: 0.0168 - val_loss: 0.0191 - val_mse: 0.0191 - 15ms/epoch - 7ms/step\n",
      "Epoch 230/500\n",
      "2/2 - 0s - loss: 0.0167 - mse: 0.0167 - val_loss: 0.0190 - val_mse: 0.0190 - 16ms/epoch - 8ms/step\n",
      "Epoch 231/500\n",
      "2/2 - 0s - loss: 0.0166 - mse: 0.0166 - val_loss: 0.0189 - val_mse: 0.0189 - 17ms/epoch - 8ms/step\n",
      "Epoch 232/500\n",
      "2/2 - 0s - loss: 0.0165 - mse: 0.0165 - val_loss: 0.0187 - val_mse: 0.0187 - 16ms/epoch - 8ms/step\n",
      "Epoch 233/500\n",
      "2/2 - 0s - loss: 0.0164 - mse: 0.0164 - val_loss: 0.0186 - val_mse: 0.0186 - 16ms/epoch - 8ms/step\n",
      "Epoch 234/500\n",
      "2/2 - 0s - loss: 0.0163 - mse: 0.0163 - val_loss: 0.0185 - val_mse: 0.0185 - 16ms/epoch - 8ms/step\n",
      "Epoch 235/500\n",
      "2/2 - 0s - loss: 0.0162 - mse: 0.0162 - val_loss: 0.0184 - val_mse: 0.0184 - 16ms/epoch - 8ms/step\n",
      "Epoch 236/500\n",
      "2/2 - 0s - loss: 0.0160 - mse: 0.0160 - val_loss: 0.0183 - val_mse: 0.0183 - 16ms/epoch - 8ms/step\n",
      "Epoch 237/500\n",
      "2/2 - 0s - loss: 0.0159 - mse: 0.0159 - val_loss: 0.0181 - val_mse: 0.0181 - 16ms/epoch - 8ms/step\n",
      "Epoch 238/500\n",
      "2/2 - 0s - loss: 0.0158 - mse: 0.0158 - val_loss: 0.0180 - val_mse: 0.0180 - 16ms/epoch - 8ms/step\n",
      "Epoch 239/500\n",
      "2/2 - 0s - loss: 0.0157 - mse: 0.0157 - val_loss: 0.0179 - val_mse: 0.0179 - 16ms/epoch - 8ms/step\n",
      "Epoch 240/500\n",
      "2/2 - 0s - loss: 0.0156 - mse: 0.0156 - val_loss: 0.0178 - val_mse: 0.0178 - 17ms/epoch - 9ms/step\n",
      "Epoch 241/500\n",
      "2/2 - 0s - loss: 0.0155 - mse: 0.0155 - val_loss: 0.0177 - val_mse: 0.0177 - 19ms/epoch - 9ms/step\n",
      "Epoch 242/500\n",
      "2/2 - 0s - loss: 0.0154 - mse: 0.0154 - val_loss: 0.0176 - val_mse: 0.0176 - 22ms/epoch - 11ms/step\n",
      "Epoch 243/500\n",
      "2/2 - 0s - loss: 0.0153 - mse: 0.0153 - val_loss: 0.0174 - val_mse: 0.0174 - 18ms/epoch - 9ms/step\n",
      "Epoch 244/500\n",
      "2/2 - 0s - loss: 0.0152 - mse: 0.0152 - val_loss: 0.0173 - val_mse: 0.0173 - 17ms/epoch - 8ms/step\n",
      "Epoch 245/500\n",
      "2/2 - 0s - loss: 0.0151 - mse: 0.0151 - val_loss: 0.0172 - val_mse: 0.0172 - 16ms/epoch - 8ms/step\n",
      "Epoch 246/500\n",
      "2/2 - 0s - loss: 0.0150 - mse: 0.0150 - val_loss: 0.0171 - val_mse: 0.0171 - 16ms/epoch - 8ms/step\n",
      "Epoch 247/500\n",
      "2/2 - 0s - loss: 0.0149 - mse: 0.0149 - val_loss: 0.0170 - val_mse: 0.0170 - 16ms/epoch - 8ms/step\n",
      "Epoch 248/500\n",
      "2/2 - 0s - loss: 0.0148 - mse: 0.0148 - val_loss: 0.0169 - val_mse: 0.0169 - 16ms/epoch - 8ms/step\n",
      "Epoch 249/500\n",
      "2/2 - 0s - loss: 0.0147 - mse: 0.0147 - val_loss: 0.0168 - val_mse: 0.0168 - 17ms/epoch - 8ms/step\n",
      "Epoch 250/500\n",
      "2/2 - 0s - loss: 0.0146 - mse: 0.0146 - val_loss: 0.0167 - val_mse: 0.0167 - 15ms/epoch - 7ms/step\n",
      "Epoch 251/500\n",
      "2/2 - 0s - loss: 0.0145 - mse: 0.0145 - val_loss: 0.0166 - val_mse: 0.0166 - 15ms/epoch - 7ms/step\n",
      "Epoch 252/500\n",
      "2/2 - 0s - loss: 0.0144 - mse: 0.0144 - val_loss: 0.0165 - val_mse: 0.0165 - 15ms/epoch - 7ms/step\n",
      "Epoch 253/500\n",
      "2/2 - 0s - loss: 0.0143 - mse: 0.0143 - val_loss: 0.0164 - val_mse: 0.0164 - 17ms/epoch - 8ms/step\n",
      "Epoch 254/500\n",
      "2/2 - 0s - loss: 0.0142 - mse: 0.0142 - val_loss: 0.0163 - val_mse: 0.0163 - 17ms/epoch - 8ms/step\n",
      "Epoch 255/500\n",
      "2/2 - 0s - loss: 0.0141 - mse: 0.0141 - val_loss: 0.0162 - val_mse: 0.0162 - 16ms/epoch - 8ms/step\n",
      "Epoch 256/500\n",
      "2/2 - 0s - loss: 0.0140 - mse: 0.0140 - val_loss: 0.0161 - val_mse: 0.0161 - 16ms/epoch - 8ms/step\n",
      "Epoch 257/500\n",
      "2/2 - 0s - loss: 0.0139 - mse: 0.0139 - val_loss: 0.0160 - val_mse: 0.0160 - 16ms/epoch - 8ms/step\n",
      "Epoch 258/500\n",
      "2/2 - 0s - loss: 0.0139 - mse: 0.0139 - val_loss: 0.0159 - val_mse: 0.0159 - 16ms/epoch - 8ms/step\n",
      "Epoch 259/500\n",
      "2/2 - 0s - loss: 0.0138 - mse: 0.0138 - val_loss: 0.0158 - val_mse: 0.0158 - 16ms/epoch - 8ms/step\n",
      "Epoch 260/500\n",
      "2/2 - 0s - loss: 0.0137 - mse: 0.0137 - val_loss: 0.0157 - val_mse: 0.0157 - 16ms/epoch - 8ms/step\n",
      "Epoch 261/500\n",
      "2/2 - 0s - loss: 0.0136 - mse: 0.0136 - val_loss: 0.0156 - val_mse: 0.0156 - 18ms/epoch - 9ms/step\n",
      "Epoch 262/500\n",
      "2/2 - 0s - loss: 0.0135 - mse: 0.0135 - val_loss: 0.0155 - val_mse: 0.0155 - 15ms/epoch - 8ms/step\n",
      "Epoch 263/500\n",
      "2/2 - 0s - loss: 0.0134 - mse: 0.0134 - val_loss: 0.0154 - val_mse: 0.0154 - 16ms/epoch - 8ms/step\n",
      "Epoch 264/500\n",
      "2/2 - 0s - loss: 0.0133 - mse: 0.0133 - val_loss: 0.0153 - val_mse: 0.0153 - 14ms/epoch - 7ms/step\n",
      "Epoch 265/500\n",
      "2/2 - 0s - loss: 0.0133 - mse: 0.0133 - val_loss: 0.0152 - val_mse: 0.0152 - 17ms/epoch - 8ms/step\n",
      "Epoch 266/500\n",
      "2/2 - 0s - loss: 0.0132 - mse: 0.0132 - val_loss: 0.0151 - val_mse: 0.0151 - 14ms/epoch - 7ms/step\n",
      "Epoch 267/500\n",
      "2/2 - 0s - loss: 0.0131 - mse: 0.0131 - val_loss: 0.0150 - val_mse: 0.0150 - 15ms/epoch - 7ms/step\n",
      "Epoch 268/500\n",
      "2/2 - 0s - loss: 0.0130 - mse: 0.0130 - val_loss: 0.0149 - val_mse: 0.0149 - 15ms/epoch - 7ms/step\n",
      "Epoch 269/500\n",
      "2/2 - 0s - loss: 0.0129 - mse: 0.0129 - val_loss: 0.0148 - val_mse: 0.0148 - 16ms/epoch - 8ms/step\n",
      "Epoch 270/500\n",
      "2/2 - 0s - loss: 0.0129 - mse: 0.0129 - val_loss: 0.0148 - val_mse: 0.0148 - 16ms/epoch - 8ms/step\n",
      "Epoch 271/500\n",
      "2/2 - 0s - loss: 0.0128 - mse: 0.0128 - val_loss: 0.0147 - val_mse: 0.0147 - 17ms/epoch - 8ms/step\n",
      "Epoch 272/500\n",
      "2/2 - 0s - loss: 0.0127 - mse: 0.0127 - val_loss: 0.0146 - val_mse: 0.0146 - 16ms/epoch - 8ms/step\n",
      "Epoch 273/500\n",
      "2/2 - 0s - loss: 0.0126 - mse: 0.0126 - val_loss: 0.0145 - val_mse: 0.0145 - 17ms/epoch - 8ms/step\n",
      "Epoch 274/500\n",
      "2/2 - 0s - loss: 0.0126 - mse: 0.0126 - val_loss: 0.0144 - val_mse: 0.0144 - 18ms/epoch - 9ms/step\n",
      "Epoch 275/500\n",
      "2/2 - 0s - loss: 0.0125 - mse: 0.0125 - val_loss: 0.0143 - val_mse: 0.0143 - 15ms/epoch - 8ms/step\n",
      "Epoch 276/500\n",
      "2/2 - 0s - loss: 0.0124 - mse: 0.0124 - val_loss: 0.0143 - val_mse: 0.0143 - 16ms/epoch - 8ms/step\n",
      "Epoch 277/500\n",
      "2/2 - 0s - loss: 0.0123 - mse: 0.0123 - val_loss: 0.0142 - val_mse: 0.0142 - 14ms/epoch - 7ms/step\n",
      "Epoch 278/500\n",
      "2/2 - 0s - loss: 0.0123 - mse: 0.0123 - val_loss: 0.0141 - val_mse: 0.0141 - 18ms/epoch - 9ms/step\n",
      "Epoch 279/500\n",
      "2/2 - 0s - loss: 0.0122 - mse: 0.0122 - val_loss: 0.0140 - val_mse: 0.0140 - 18ms/epoch - 9ms/step\n",
      "Epoch 280/500\n",
      "2/2 - 0s - loss: 0.0121 - mse: 0.0121 - val_loss: 0.0140 - val_mse: 0.0140 - 21ms/epoch - 11ms/step\n",
      "Epoch 281/500\n",
      "2/2 - 0s - loss: 0.0121 - mse: 0.0121 - val_loss: 0.0139 - val_mse: 0.0139 - 19ms/epoch - 10ms/step\n",
      "Epoch 282/500\n",
      "2/2 - 0s - loss: 0.0120 - mse: 0.0120 - val_loss: 0.0138 - val_mse: 0.0138 - 17ms/epoch - 9ms/step\n",
      "Epoch 283/500\n",
      "2/2 - 0s - loss: 0.0119 - mse: 0.0119 - val_loss: 0.0137 - val_mse: 0.0137 - 17ms/epoch - 8ms/step\n",
      "Epoch 284/500\n",
      "2/2 - 0s - loss: 0.0119 - mse: 0.0119 - val_loss: 0.0137 - val_mse: 0.0137 - 18ms/epoch - 9ms/step\n",
      "Epoch 285/500\n",
      "2/2 - 0s - loss: 0.0118 - mse: 0.0118 - val_loss: 0.0136 - val_mse: 0.0136 - 15ms/epoch - 8ms/step\n",
      "Epoch 286/500\n",
      "2/2 - 0s - loss: 0.0117 - mse: 0.0117 - val_loss: 0.0135 - val_mse: 0.0135 - 17ms/epoch - 8ms/step\n",
      "Epoch 287/500\n",
      "2/2 - 0s - loss: 0.0117 - mse: 0.0117 - val_loss: 0.0134 - val_mse: 0.0134 - 14ms/epoch - 7ms/step\n",
      "Epoch 288/500\n",
      "2/2 - 0s - loss: 0.0116 - mse: 0.0116 - val_loss: 0.0134 - val_mse: 0.0134 - 18ms/epoch - 9ms/step\n",
      "Epoch 289/500\n",
      "2/2 - 0s - loss: 0.0115 - mse: 0.0115 - val_loss: 0.0133 - val_mse: 0.0133 - 21ms/epoch - 10ms/step\n",
      "Epoch 290/500\n",
      "2/2 - 0s - loss: 0.0115 - mse: 0.0115 - val_loss: 0.0132 - val_mse: 0.0132 - 19ms/epoch - 10ms/step\n",
      "Epoch 291/500\n",
      "2/2 - 0s - loss: 0.0114 - mse: 0.0114 - val_loss: 0.0132 - val_mse: 0.0132 - 17ms/epoch - 9ms/step\n",
      "Epoch 292/500\n",
      "2/2 - 0s - loss: 0.0113 - mse: 0.0113 - val_loss: 0.0131 - val_mse: 0.0131 - 17ms/epoch - 8ms/step\n",
      "Epoch 293/500\n",
      "2/2 - 0s - loss: 0.0113 - mse: 0.0113 - val_loss: 0.0130 - val_mse: 0.0130 - 18ms/epoch - 9ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 294/500\n",
      "2/2 - 0s - loss: 0.0112 - mse: 0.0112 - val_loss: 0.0130 - val_mse: 0.0130 - 15ms/epoch - 8ms/step\n",
      "Epoch 295/500\n",
      "2/2 - 0s - loss: 0.0112 - mse: 0.0112 - val_loss: 0.0129 - val_mse: 0.0129 - 17ms/epoch - 9ms/step\n",
      "Epoch 296/500\n",
      "2/2 - 0s - loss: 0.0111 - mse: 0.0111 - val_loss: 0.0128 - val_mse: 0.0128 - 15ms/epoch - 8ms/step\n",
      "Epoch 297/500\n",
      "2/2 - 0s - loss: 0.0110 - mse: 0.0110 - val_loss: 0.0128 - val_mse: 0.0128 - 17ms/epoch - 9ms/step\n",
      "Epoch 298/500\n",
      "2/2 - 0s - loss: 0.0110 - mse: 0.0110 - val_loss: 0.0127 - val_mse: 0.0127 - 20ms/epoch - 10ms/step\n",
      "Epoch 299/500\n",
      "2/2 - 0s - loss: 0.0109 - mse: 0.0109 - val_loss: 0.0126 - val_mse: 0.0126 - 19ms/epoch - 10ms/step\n",
      "Epoch 300/500\n",
      "2/2 - 0s - loss: 0.0109 - mse: 0.0109 - val_loss: 0.0126 - val_mse: 0.0126 - 17ms/epoch - 8ms/step\n",
      "Epoch 301/500\n",
      "2/2 - 0s - loss: 0.0108 - mse: 0.0108 - val_loss: 0.0125 - val_mse: 0.0125 - 16ms/epoch - 8ms/step\n",
      "Epoch 302/500\n",
      "2/2 - 0s - loss: 0.0107 - mse: 0.0107 - val_loss: 0.0124 - val_mse: 0.0124 - 16ms/epoch - 8ms/step\n",
      "Epoch 303/500\n",
      "2/2 - 0s - loss: 0.0107 - mse: 0.0107 - val_loss: 0.0124 - val_mse: 0.0124 - 16ms/epoch - 8ms/step\n",
      "Epoch 304/500\n",
      "2/2 - 0s - loss: 0.0106 - mse: 0.0106 - val_loss: 0.0123 - val_mse: 0.0123 - 17ms/epoch - 9ms/step\n",
      "Epoch 305/500\n",
      "2/2 - 0s - loss: 0.0106 - mse: 0.0106 - val_loss: 0.0122 - val_mse: 0.0122 - 18ms/epoch - 9ms/step\n",
      "Epoch 306/500\n",
      "2/2 - 0s - loss: 0.0105 - mse: 0.0105 - val_loss: 0.0122 - val_mse: 0.0122 - 13ms/epoch - 6ms/step\n",
      "Epoch 307/500\n",
      "2/2 - 0s - loss: 0.0105 - mse: 0.0105 - val_loss: 0.0121 - val_mse: 0.0121 - 14ms/epoch - 7ms/step\n",
      "Epoch 308/500\n",
      "2/2 - 0s - loss: 0.0104 - mse: 0.0104 - val_loss: 0.0121 - val_mse: 0.0121 - 14ms/epoch - 7ms/step\n",
      "Epoch 309/500\n",
      "2/2 - 0s - loss: 0.0104 - mse: 0.0104 - val_loss: 0.0120 - val_mse: 0.0120 - 15ms/epoch - 7ms/step\n",
      "Epoch 310/500\n",
      "2/2 - 0s - loss: 0.0103 - mse: 0.0103 - val_loss: 0.0119 - val_mse: 0.0119 - 14ms/epoch - 7ms/step\n",
      "Epoch 311/500\n",
      "2/2 - 0s - loss: 0.0103 - mse: 0.0103 - val_loss: 0.0119 - val_mse: 0.0119 - 16ms/epoch - 8ms/step\n",
      "Epoch 312/500\n",
      "2/2 - 0s - loss: 0.0102 - mse: 0.0102 - val_loss: 0.0118 - val_mse: 0.0118 - 24ms/epoch - 12ms/step\n",
      "Epoch 313/500\n",
      "2/2 - 0s - loss: 0.0102 - mse: 0.0102 - val_loss: 0.0118 - val_mse: 0.0118 - 18ms/epoch - 9ms/step\n",
      "Epoch 314/500\n",
      "2/2 - 0s - loss: 0.0101 - mse: 0.0101 - val_loss: 0.0117 - val_mse: 0.0117 - 14ms/epoch - 7ms/step\n",
      "Epoch 315/500\n",
      "2/2 - 0s - loss: 0.0101 - mse: 0.0101 - val_loss: 0.0116 - val_mse: 0.0116 - 14ms/epoch - 7ms/step\n",
      "Epoch 316/500\n",
      "2/2 - 0s - loss: 0.0100 - mse: 0.0100 - val_loss: 0.0116 - val_mse: 0.0116 - 16ms/epoch - 8ms/step\n",
      "Epoch 317/500\n",
      "2/2 - 0s - loss: 0.0100 - mse: 0.0100 - val_loss: 0.0115 - val_mse: 0.0115 - 16ms/epoch - 8ms/step\n",
      "Epoch 318/500\n",
      "2/2 - 0s - loss: 0.0099 - mse: 0.0099 - val_loss: 0.0115 - val_mse: 0.0115 - 15ms/epoch - 8ms/step\n",
      "Epoch 319/500\n",
      "2/2 - 0s - loss: 0.0099 - mse: 0.0099 - val_loss: 0.0114 - val_mse: 0.0114 - 16ms/epoch - 8ms/step\n",
      "Epoch 320/500\n",
      "2/2 - 0s - loss: 0.0098 - mse: 0.0098 - val_loss: 0.0114 - val_mse: 0.0114 - 16ms/epoch - 8ms/step\n",
      "Epoch 321/500\n",
      "2/2 - 0s - loss: 0.0098 - mse: 0.0098 - val_loss: 0.0113 - val_mse: 0.0113 - 15ms/epoch - 8ms/step\n",
      "Epoch 322/500\n",
      "2/2 - 0s - loss: 0.0097 - mse: 0.0097 - val_loss: 0.0113 - val_mse: 0.0113 - 15ms/epoch - 7ms/step\n",
      "Epoch 323/500\n",
      "2/2 - 0s - loss: 0.0097 - mse: 0.0097 - val_loss: 0.0112 - val_mse: 0.0112 - 15ms/epoch - 8ms/step\n",
      "Epoch 324/500\n",
      "2/2 - 0s - loss: 0.0096 - mse: 0.0096 - val_loss: 0.0112 - val_mse: 0.0112 - 15ms/epoch - 8ms/step\n",
      "Epoch 325/500\n",
      "2/2 - 0s - loss: 0.0096 - mse: 0.0096 - val_loss: 0.0111 - val_mse: 0.0111 - 16ms/epoch - 8ms/step\n",
      "Epoch 326/500\n",
      "2/2 - 0s - loss: 0.0095 - mse: 0.0095 - val_loss: 0.0111 - val_mse: 0.0111 - 16ms/epoch - 8ms/step\n",
      "Epoch 327/500\n",
      "2/2 - 0s - loss: 0.0095 - mse: 0.0095 - val_loss: 0.0110 - val_mse: 0.0110 - 15ms/epoch - 8ms/step\n",
      "Epoch 328/500\n",
      "2/2 - 0s - loss: 0.0094 - mse: 0.0094 - val_loss: 0.0110 - val_mse: 0.0110 - 15ms/epoch - 8ms/step\n",
      "Epoch 329/500\n",
      "2/2 - 0s - loss: 0.0094 - mse: 0.0094 - val_loss: 0.0109 - val_mse: 0.0109 - 15ms/epoch - 8ms/step\n",
      "Epoch 330/500\n",
      "2/2 - 0s - loss: 0.0094 - mse: 0.0094 - val_loss: 0.0109 - val_mse: 0.0109 - 16ms/epoch - 8ms/step\n",
      "Epoch 331/500\n",
      "2/2 - 0s - loss: 0.0093 - mse: 0.0093 - val_loss: 0.0108 - val_mse: 0.0108 - 16ms/epoch - 8ms/step\n",
      "Epoch 332/500\n",
      "2/2 - 0s - loss: 0.0093 - mse: 0.0093 - val_loss: 0.0108 - val_mse: 0.0108 - 16ms/epoch - 8ms/step\n",
      "Epoch 333/500\n",
      "2/2 - 0s - loss: 0.0092 - mse: 0.0092 - val_loss: 0.0107 - val_mse: 0.0107 - 15ms/epoch - 7ms/step\n",
      "Epoch 334/500\n",
      "2/2 - 0s - loss: 0.0092 - mse: 0.0092 - val_loss: 0.0107 - val_mse: 0.0107 - 15ms/epoch - 7ms/step\n",
      "Epoch 335/500\n",
      "2/2 - 0s - loss: 0.0092 - mse: 0.0092 - val_loss: 0.0106 - val_mse: 0.0106 - 15ms/epoch - 8ms/step\n",
      "Epoch 336/500\n",
      "2/2 - 0s - loss: 0.0091 - mse: 0.0091 - val_loss: 0.0106 - val_mse: 0.0106 - 15ms/epoch - 8ms/step\n",
      "Epoch 337/500\n",
      "2/2 - 0s - loss: 0.0091 - mse: 0.0091 - val_loss: 0.0105 - val_mse: 0.0105 - 15ms/epoch - 8ms/step\n",
      "Epoch 338/500\n",
      "2/2 - 0s - loss: 0.0090 - mse: 0.0090 - val_loss: 0.0105 - val_mse: 0.0105 - 15ms/epoch - 8ms/step\n",
      "Epoch 339/500\n",
      "2/2 - 0s - loss: 0.0090 - mse: 0.0090 - val_loss: 0.0104 - val_mse: 0.0104 - 16ms/epoch - 8ms/step\n",
      "Epoch 340/500\n",
      "2/2 - 0s - loss: 0.0090 - mse: 0.0090 - val_loss: 0.0104 - val_mse: 0.0104 - 15ms/epoch - 7ms/step\n",
      "Epoch 341/500\n",
      "2/2 - 0s - loss: 0.0089 - mse: 0.0089 - val_loss: 0.0104 - val_mse: 0.0104 - 15ms/epoch - 8ms/step\n",
      "Epoch 342/500\n",
      "2/2 - 0s - loss: 0.0089 - mse: 0.0089 - val_loss: 0.0103 - val_mse: 0.0103 - 16ms/epoch - 8ms/step\n",
      "Epoch 343/500\n",
      "2/2 - 0s - loss: 0.0088 - mse: 0.0088 - val_loss: 0.0103 - val_mse: 0.0103 - 16ms/epoch - 8ms/step\n",
      "Epoch 344/500\n",
      "2/2 - 0s - loss: 0.0088 - mse: 0.0088 - val_loss: 0.0102 - val_mse: 0.0102 - 16ms/epoch - 8ms/step\n",
      "Epoch 345/500\n",
      "2/2 - 0s - loss: 0.0088 - mse: 0.0088 - val_loss: 0.0102 - val_mse: 0.0102 - 16ms/epoch - 8ms/step\n",
      "Epoch 346/500\n",
      "2/2 - 0s - loss: 0.0087 - mse: 0.0087 - val_loss: 0.0101 - val_mse: 0.0101 - 20ms/epoch - 10ms/step\n",
      "Epoch 347/500\n",
      "2/2 - 0s - loss: 0.0087 - mse: 0.0087 - val_loss: 0.0101 - val_mse: 0.0101 - 15ms/epoch - 8ms/step\n",
      "Epoch 348/500\n",
      "2/2 - 0s - loss: 0.0086 - mse: 0.0086 - val_loss: 0.0100 - val_mse: 0.0100 - 15ms/epoch - 7ms/step\n",
      "Epoch 349/500\n",
      "2/2 - 0s - loss: 0.0086 - mse: 0.0086 - val_loss: 0.0100 - val_mse: 0.0100 - 15ms/epoch - 8ms/step\n",
      "Epoch 350/500\n",
      "2/2 - 0s - loss: 0.0086 - mse: 0.0086 - val_loss: 0.0100 - val_mse: 0.0100 - 15ms/epoch - 8ms/step\n",
      "Epoch 351/500\n",
      "2/2 - 0s - loss: 0.0085 - mse: 0.0085 - val_loss: 0.0099 - val_mse: 0.0099 - 17ms/epoch - 8ms/step\n",
      "Epoch 352/500\n",
      "2/2 - 0s - loss: 0.0085 - mse: 0.0085 - val_loss: 0.0099 - val_mse: 0.0099 - 16ms/epoch - 8ms/step\n",
      "Epoch 353/500\n",
      "2/2 - 0s - loss: 0.0085 - mse: 0.0085 - val_loss: 0.0098 - val_mse: 0.0098 - 16ms/epoch - 8ms/step\n",
      "Epoch 354/500\n",
      "2/2 - 0s - loss: 0.0084 - mse: 0.0084 - val_loss: 0.0098 - val_mse: 0.0098 - 16ms/epoch - 8ms/step\n",
      "Epoch 355/500\n",
      "2/2 - 0s - loss: 0.0084 - mse: 0.0084 - val_loss: 0.0098 - val_mse: 0.0098 - 16ms/epoch - 8ms/step\n",
      "Epoch 356/500\n",
      "2/2 - 0s - loss: 0.0084 - mse: 0.0084 - val_loss: 0.0097 - val_mse: 0.0097 - 16ms/epoch - 8ms/step\n",
      "Epoch 357/500\n",
      "2/2 - 0s - loss: 0.0083 - mse: 0.0083 - val_loss: 0.0097 - val_mse: 0.0097 - 16ms/epoch - 8ms/step\n",
      "Epoch 358/500\n",
      "2/2 - 0s - loss: 0.0083 - mse: 0.0083 - val_loss: 0.0096 - val_mse: 0.0096 - 16ms/epoch - 8ms/step\n",
      "Epoch 359/500\n",
      "2/2 - 0s - loss: 0.0083 - mse: 0.0083 - val_loss: 0.0096 - val_mse: 0.0096 - 16ms/epoch - 8ms/step\n",
      "Epoch 360/500\n",
      "2/2 - 0s - loss: 0.0082 - mse: 0.0082 - val_loss: 0.0096 - val_mse: 0.0096 - 16ms/epoch - 8ms/step\n",
      "Epoch 361/500\n",
      "2/2 - 0s - loss: 0.0082 - mse: 0.0082 - val_loss: 0.0095 - val_mse: 0.0095 - 16ms/epoch - 8ms/step\n",
      "Epoch 362/500\n",
      "2/2 - 0s - loss: 0.0082 - mse: 0.0082 - val_loss: 0.0095 - val_mse: 0.0095 - 16ms/epoch - 8ms/step\n",
      "Epoch 363/500\n",
      "2/2 - 0s - loss: 0.0081 - mse: 0.0081 - val_loss: 0.0095 - val_mse: 0.0095 - 16ms/epoch - 8ms/step\n",
      "Epoch 364/500\n",
      "2/2 - 0s - loss: 0.0081 - mse: 0.0081 - val_loss: 0.0094 - val_mse: 0.0094 - 16ms/epoch - 8ms/step\n",
      "Epoch 365/500\n",
      "2/2 - 0s - loss: 0.0081 - mse: 0.0081 - val_loss: 0.0094 - val_mse: 0.0094 - 16ms/epoch - 8ms/step\n",
      "Epoch 366/500\n",
      "2/2 - 0s - loss: 0.0080 - mse: 0.0080 - val_loss: 0.0093 - val_mse: 0.0093 - 16ms/epoch - 8ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 367/500\n",
      "2/2 - 0s - loss: 0.0080 - mse: 0.0080 - val_loss: 0.0093 - val_mse: 0.0093 - 16ms/epoch - 8ms/step\n",
      "Epoch 368/500\n",
      "2/2 - 0s - loss: 0.0080 - mse: 0.0080 - val_loss: 0.0093 - val_mse: 0.0093 - 15ms/epoch - 7ms/step\n",
      "Epoch 369/500\n",
      "2/2 - 0s - loss: 0.0079 - mse: 0.0079 - val_loss: 0.0092 - val_mse: 0.0092 - 15ms/epoch - 7ms/step\n",
      "Epoch 370/500\n",
      "2/2 - 0s - loss: 0.0079 - mse: 0.0079 - val_loss: 0.0092 - val_mse: 0.0092 - 17ms/epoch - 8ms/step\n",
      "Epoch 371/500\n",
      "2/2 - 0s - loss: 0.0079 - mse: 0.0079 - val_loss: 0.0092 - val_mse: 0.0092 - 15ms/epoch - 8ms/step\n",
      "Epoch 372/500\n",
      "2/2 - 0s - loss: 0.0078 - mse: 0.0078 - val_loss: 0.0091 - val_mse: 0.0091 - 16ms/epoch - 8ms/step\n",
      "Epoch 373/500\n",
      "2/2 - 0s - loss: 0.0078 - mse: 0.0078 - val_loss: 0.0091 - val_mse: 0.0091 - 15ms/epoch - 8ms/step\n",
      "Epoch 374/500\n",
      "2/2 - 0s - loss: 0.0078 - mse: 0.0078 - val_loss: 0.0091 - val_mse: 0.0091 - 15ms/epoch - 8ms/step\n",
      "Epoch 375/500\n",
      "2/2 - 0s - loss: 0.0078 - mse: 0.0078 - val_loss: 0.0090 - val_mse: 0.0090 - 16ms/epoch - 8ms/step\n",
      "Epoch 376/500\n",
      "2/2 - 0s - loss: 0.0077 - mse: 0.0077 - val_loss: 0.0090 - val_mse: 0.0090 - 15ms/epoch - 8ms/step\n",
      "Epoch 377/500\n",
      "2/2 - 0s - loss: 0.0077 - mse: 0.0077 - val_loss: 0.0090 - val_mse: 0.0090 - 15ms/epoch - 7ms/step\n",
      "Epoch 378/500\n",
      "2/2 - 0s - loss: 0.0077 - mse: 0.0077 - val_loss: 0.0089 - val_mse: 0.0089 - 15ms/epoch - 8ms/step\n",
      "Epoch 379/500\n",
      "2/2 - 0s - loss: 0.0076 - mse: 0.0076 - val_loss: 0.0089 - val_mse: 0.0089 - 16ms/epoch - 8ms/step\n",
      "Epoch 380/500\n",
      "2/2 - 0s - loss: 0.0076 - mse: 0.0076 - val_loss: 0.0089 - val_mse: 0.0089 - 16ms/epoch - 8ms/step\n",
      "Epoch 381/500\n",
      "2/2 - 0s - loss: 0.0076 - mse: 0.0076 - val_loss: 0.0088 - val_mse: 0.0088 - 16ms/epoch - 8ms/step\n",
      "Epoch 382/500\n",
      "2/2 - 0s - loss: 0.0076 - mse: 0.0076 - val_loss: 0.0088 - val_mse: 0.0088 - 14ms/epoch - 7ms/step\n",
      "Epoch 383/500\n",
      "2/2 - 0s - loss: 0.0075 - mse: 0.0075 - val_loss: 0.0088 - val_mse: 0.0088 - 17ms/epoch - 8ms/step\n",
      "Epoch 384/500\n",
      "2/2 - 0s - loss: 0.0075 - mse: 0.0075 - val_loss: 0.0087 - val_mse: 0.0087 - 15ms/epoch - 8ms/step\n",
      "Epoch 385/500\n",
      "2/2 - 0s - loss: 0.0075 - mse: 0.0075 - val_loss: 0.0087 - val_mse: 0.0087 - 16ms/epoch - 8ms/step\n",
      "Epoch 386/500\n",
      "2/2 - 0s - loss: 0.0074 - mse: 0.0074 - val_loss: 0.0087 - val_mse: 0.0087 - 15ms/epoch - 7ms/step\n",
      "Epoch 387/500\n",
      "2/2 - 0s - loss: 0.0074 - mse: 0.0074 - val_loss: 0.0086 - val_mse: 0.0086 - 16ms/epoch - 8ms/step\n",
      "Epoch 388/500\n",
      "2/2 - 0s - loss: 0.0074 - mse: 0.0074 - val_loss: 0.0086 - val_mse: 0.0086 - 14ms/epoch - 7ms/step\n",
      "Epoch 389/500\n",
      "2/2 - 0s - loss: 0.0074 - mse: 0.0074 - val_loss: 0.0086 - val_mse: 0.0086 - 17ms/epoch - 8ms/step\n",
      "Epoch 390/500\n",
      "2/2 - 0s - loss: 0.0073 - mse: 0.0073 - val_loss: 0.0085 - val_mse: 0.0085 - 15ms/epoch - 7ms/step\n",
      "Epoch 391/500\n",
      "2/2 - 0s - loss: 0.0073 - mse: 0.0073 - val_loss: 0.0085 - val_mse: 0.0085 - 16ms/epoch - 8ms/step\n",
      "Epoch 392/500\n",
      "2/2 - 0s - loss: 0.0073 - mse: 0.0073 - val_loss: 0.0085 - val_mse: 0.0085 - 15ms/epoch - 7ms/step\n",
      "Epoch 393/500\n",
      "2/2 - 0s - loss: 0.0073 - mse: 0.0073 - val_loss: 0.0085 - val_mse: 0.0085 - 16ms/epoch - 8ms/step\n",
      "Epoch 394/500\n",
      "2/2 - 0s - loss: 0.0072 - mse: 0.0072 - val_loss: 0.0084 - val_mse: 0.0084 - 16ms/epoch - 8ms/step\n",
      "Epoch 395/500\n",
      "2/2 - 0s - loss: 0.0072 - mse: 0.0072 - val_loss: 0.0084 - val_mse: 0.0084 - 16ms/epoch - 8ms/step\n",
      "Epoch 396/500\n",
      "2/2 - 0s - loss: 0.0072 - mse: 0.0072 - val_loss: 0.0084 - val_mse: 0.0084 - 16ms/epoch - 8ms/step\n",
      "Epoch 397/500\n",
      "2/2 - 0s - loss: 0.0072 - mse: 0.0072 - val_loss: 0.0083 - val_mse: 0.0083 - 15ms/epoch - 8ms/step\n",
      "Epoch 398/500\n",
      "2/2 - 0s - loss: 0.0071 - mse: 0.0071 - val_loss: 0.0083 - val_mse: 0.0083 - 15ms/epoch - 8ms/step\n",
      "Epoch 399/500\n",
      "2/2 - 0s - loss: 0.0071 - mse: 0.0071 - val_loss: 0.0083 - val_mse: 0.0083 - 16ms/epoch - 8ms/step\n",
      "Epoch 400/500\n",
      "2/2 - 0s - loss: 0.0071 - mse: 0.0071 - val_loss: 0.0083 - val_mse: 0.0083 - 16ms/epoch - 8ms/step\n",
      "Epoch 401/500\n",
      "2/2 - 0s - loss: 0.0071 - mse: 0.0071 - val_loss: 0.0082 - val_mse: 0.0082 - 16ms/epoch - 8ms/step\n",
      "Epoch 402/500\n",
      "2/2 - 0s - loss: 0.0070 - mse: 0.0070 - val_loss: 0.0082 - val_mse: 0.0082 - 16ms/epoch - 8ms/step\n",
      "Epoch 403/500\n",
      "2/2 - 0s - loss: 0.0070 - mse: 0.0070 - val_loss: 0.0082 - val_mse: 0.0082 - 16ms/epoch - 8ms/step\n",
      "Epoch 404/500\n",
      "2/2 - 0s - loss: 0.0070 - mse: 0.0070 - val_loss: 0.0082 - val_mse: 0.0082 - 16ms/epoch - 8ms/step\n",
      "Epoch 405/500\n",
      "2/2 - 0s - loss: 0.0070 - mse: 0.0070 - val_loss: 0.0081 - val_mse: 0.0081 - 16ms/epoch - 8ms/step\n",
      "Epoch 406/500\n",
      "2/2 - 0s - loss: 0.0069 - mse: 0.0069 - val_loss: 0.0081 - val_mse: 0.0081 - 16ms/epoch - 8ms/step\n",
      "Epoch 407/500\n",
      "2/2 - 0s - loss: 0.0069 - mse: 0.0069 - val_loss: 0.0081 - val_mse: 0.0081 - 16ms/epoch - 8ms/step\n",
      "Epoch 408/500\n",
      "2/2 - 0s - loss: 0.0069 - mse: 0.0069 - val_loss: 0.0081 - val_mse: 0.0081 - 16ms/epoch - 8ms/step\n",
      "Epoch 409/500\n",
      "2/2 - 0s - loss: 0.0069 - mse: 0.0069 - val_loss: 0.0080 - val_mse: 0.0080 - 15ms/epoch - 8ms/step\n",
      "Epoch 410/500\n",
      "2/2 - 0s - loss: 0.0069 - mse: 0.0069 - val_loss: 0.0080 - val_mse: 0.0080 - 16ms/epoch - 8ms/step\n",
      "Epoch 411/500\n",
      "2/2 - 0s - loss: 0.0068 - mse: 0.0068 - val_loss: 0.0080 - val_mse: 0.0080 - 16ms/epoch - 8ms/step\n",
      "Epoch 412/500\n",
      "2/2 - 0s - loss: 0.0068 - mse: 0.0068 - val_loss: 0.0079 - val_mse: 0.0079 - 16ms/epoch - 8ms/step\n",
      "Epoch 413/500\n",
      "2/2 - 0s - loss: 0.0068 - mse: 0.0068 - val_loss: 0.0079 - val_mse: 0.0079 - 16ms/epoch - 8ms/step\n",
      "Epoch 414/500\n",
      "2/2 - 0s - loss: 0.0068 - mse: 0.0068 - val_loss: 0.0079 - val_mse: 0.0079 - 16ms/epoch - 8ms/step\n",
      "Epoch 415/500\n",
      "2/2 - 0s - loss: 0.0067 - mse: 0.0067 - val_loss: 0.0079 - val_mse: 0.0079 - 16ms/epoch - 8ms/step\n",
      "Epoch 416/500\n",
      "2/2 - 0s - loss: 0.0067 - mse: 0.0067 - val_loss: 0.0078 - val_mse: 0.0078 - 16ms/epoch - 8ms/step\n",
      "Epoch 417/500\n",
      "2/2 - 0s - loss: 0.0067 - mse: 0.0067 - val_loss: 0.0078 - val_mse: 0.0078 - 16ms/epoch - 8ms/step\n",
      "Epoch 418/500\n",
      "2/2 - 0s - loss: 0.0067 - mse: 0.0067 - val_loss: 0.0078 - val_mse: 0.0078 - 16ms/epoch - 8ms/step\n",
      "Epoch 419/500\n",
      "2/2 - 0s - loss: 0.0067 - mse: 0.0067 - val_loss: 0.0078 - val_mse: 0.0078 - 16ms/epoch - 8ms/step\n",
      "Epoch 420/500\n",
      "2/2 - 0s - loss: 0.0066 - mse: 0.0066 - val_loss: 0.0077 - val_mse: 0.0077 - 16ms/epoch - 8ms/step\n",
      "Epoch 421/500\n",
      "2/2 - 0s - loss: 0.0066 - mse: 0.0066 - val_loss: 0.0077 - val_mse: 0.0077 - 16ms/epoch - 8ms/step\n",
      "Epoch 422/500\n",
      "2/2 - 0s - loss: 0.0066 - mse: 0.0066 - val_loss: 0.0077 - val_mse: 0.0077 - 15ms/epoch - 8ms/step\n",
      "Epoch 423/500\n",
      "2/2 - 0s - loss: 0.0066 - mse: 0.0066 - val_loss: 0.0077 - val_mse: 0.0077 - 16ms/epoch - 8ms/step\n",
      "Epoch 424/500\n",
      "2/2 - 0s - loss: 0.0066 - mse: 0.0066 - val_loss: 0.0076 - val_mse: 0.0076 - 16ms/epoch - 8ms/step\n",
      "Epoch 425/500\n",
      "2/2 - 0s - loss: 0.0065 - mse: 0.0065 - val_loss: 0.0076 - val_mse: 0.0076 - 16ms/epoch - 8ms/step\n",
      "Epoch 426/500\n",
      "2/2 - 0s - loss: 0.0065 - mse: 0.0065 - val_loss: 0.0076 - val_mse: 0.0076 - 16ms/epoch - 8ms/step\n",
      "Epoch 427/500\n",
      "2/2 - 0s - loss: 0.0065 - mse: 0.0065 - val_loss: 0.0076 - val_mse: 0.0076 - 16ms/epoch - 8ms/step\n",
      "Epoch 428/500\n",
      "2/2 - 0s - loss: 0.0065 - mse: 0.0065 - val_loss: 0.0075 - val_mse: 0.0075 - 16ms/epoch - 8ms/step\n",
      "Epoch 429/500\n",
      "2/2 - 0s - loss: 0.0064 - mse: 0.0064 - val_loss: 0.0075 - val_mse: 0.0075 - 16ms/epoch - 8ms/step\n",
      "Epoch 430/500\n",
      "2/2 - 0s - loss: 0.0064 - mse: 0.0064 - val_loss: 0.0075 - val_mse: 0.0075 - 17ms/epoch - 8ms/step\n",
      "Epoch 431/500\n",
      "2/2 - 0s - loss: 0.0064 - mse: 0.0064 - val_loss: 0.0075 - val_mse: 0.0075 - 15ms/epoch - 8ms/step\n",
      "Epoch 432/500\n",
      "2/2 - 0s - loss: 0.0064 - mse: 0.0064 - val_loss: 0.0075 - val_mse: 0.0075 - 16ms/epoch - 8ms/step\n",
      "Epoch 433/500\n",
      "2/2 - 0s - loss: 0.0064 - mse: 0.0064 - val_loss: 0.0074 - val_mse: 0.0074 - 15ms/epoch - 8ms/step\n",
      "Epoch 434/500\n",
      "2/2 - 0s - loss: 0.0064 - mse: 0.0064 - val_loss: 0.0074 - val_mse: 0.0074 - 16ms/epoch - 8ms/step\n",
      "Epoch 435/500\n",
      "2/2 - 0s - loss: 0.0063 - mse: 0.0063 - val_loss: 0.0074 - val_mse: 0.0074 - 16ms/epoch - 8ms/step\n",
      "Epoch 436/500\n",
      "2/2 - 0s - loss: 0.0063 - mse: 0.0063 - val_loss: 0.0074 - val_mse: 0.0074 - 16ms/epoch - 8ms/step\n",
      "Epoch 437/500\n",
      "2/2 - 0s - loss: 0.0063 - mse: 0.0063 - val_loss: 0.0074 - val_mse: 0.0074 - 15ms/epoch - 8ms/step\n",
      "Epoch 438/500\n",
      "2/2 - 0s - loss: 0.0063 - mse: 0.0063 - val_loss: 0.0073 - val_mse: 0.0073 - 15ms/epoch - 8ms/step\n",
      "Epoch 439/500\n",
      "2/2 - 0s - loss: 0.0063 - mse: 0.0063 - val_loss: 0.0073 - val_mse: 0.0073 - 16ms/epoch - 8ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 440/500\n",
      "2/2 - 0s - loss: 0.0062 - mse: 0.0062 - val_loss: 0.0073 - val_mse: 0.0073 - 16ms/epoch - 8ms/step\n",
      "Epoch 441/500\n",
      "2/2 - 0s - loss: 0.0062 - mse: 0.0062 - val_loss: 0.0073 - val_mse: 0.0073 - 15ms/epoch - 8ms/step\n",
      "Epoch 442/500\n",
      "2/2 - 0s - loss: 0.0062 - mse: 0.0062 - val_loss: 0.0072 - val_mse: 0.0072 - 16ms/epoch - 8ms/step\n",
      "Epoch 443/500\n",
      "2/2 - 0s - loss: 0.0062 - mse: 0.0062 - val_loss: 0.0072 - val_mse: 0.0072 - 15ms/epoch - 8ms/step\n",
      "Epoch 444/500\n",
      "2/2 - 0s - loss: 0.0062 - mse: 0.0062 - val_loss: 0.0072 - val_mse: 0.0072 - 15ms/epoch - 8ms/step\n",
      "Epoch 445/500\n",
      "2/2 - 0s - loss: 0.0061 - mse: 0.0061 - val_loss: 0.0072 - val_mse: 0.0072 - 16ms/epoch - 8ms/step\n",
      "Epoch 446/500\n",
      "2/2 - 0s - loss: 0.0061 - mse: 0.0061 - val_loss: 0.0072 - val_mse: 0.0072 - 16ms/epoch - 8ms/step\n",
      "Epoch 447/500\n",
      "2/2 - 0s - loss: 0.0061 - mse: 0.0061 - val_loss: 0.0072 - val_mse: 0.0072 - 16ms/epoch - 8ms/step\n",
      "Epoch 448/500\n",
      "2/2 - 0s - loss: 0.0061 - mse: 0.0061 - val_loss: 0.0071 - val_mse: 0.0071 - 15ms/epoch - 7ms/step\n",
      "Epoch 449/500\n",
      "2/2 - 0s - loss: 0.0061 - mse: 0.0061 - val_loss: 0.0071 - val_mse: 0.0071 - 15ms/epoch - 8ms/step\n",
      "Epoch 450/500\n",
      "2/2 - 0s - loss: 0.0061 - mse: 0.0061 - val_loss: 0.0071 - val_mse: 0.0071 - 15ms/epoch - 7ms/step\n",
      "Epoch 451/500\n",
      "2/2 - 0s - loss: 0.0060 - mse: 0.0060 - val_loss: 0.0071 - val_mse: 0.0071 - 16ms/epoch - 8ms/step\n",
      "Epoch 452/500\n",
      "2/2 - 0s - loss: 0.0060 - mse: 0.0060 - val_loss: 0.0070 - val_mse: 0.0070 - 15ms/epoch - 8ms/step\n",
      "Epoch 453/500\n",
      "2/2 - 0s - loss: 0.0060 - mse: 0.0060 - val_loss: 0.0070 - val_mse: 0.0070 - 15ms/epoch - 7ms/step\n",
      "Epoch 454/500\n",
      "2/2 - 0s - loss: 0.0060 - mse: 0.0060 - val_loss: 0.0070 - val_mse: 0.0070 - 16ms/epoch - 8ms/step\n",
      "Epoch 455/500\n",
      "2/2 - 0s - loss: 0.0060 - mse: 0.0060 - val_loss: 0.0070 - val_mse: 0.0070 - 16ms/epoch - 8ms/step\n",
      "Epoch 456/500\n",
      "2/2 - 0s - loss: 0.0060 - mse: 0.0060 - val_loss: 0.0070 - val_mse: 0.0070 - 16ms/epoch - 8ms/step\n",
      "Epoch 457/500\n",
      "2/2 - 0s - loss: 0.0059 - mse: 0.0059 - val_loss: 0.0069 - val_mse: 0.0069 - 15ms/epoch - 8ms/step\n",
      "Epoch 458/500\n",
      "2/2 - 0s - loss: 0.0059 - mse: 0.0059 - val_loss: 0.0069 - val_mse: 0.0069 - 14ms/epoch - 7ms/step\n",
      "Epoch 459/500\n",
      "2/2 - 0s - loss: 0.0059 - mse: 0.0059 - val_loss: 0.0069 - val_mse: 0.0069 - 15ms/epoch - 8ms/step\n",
      "Epoch 460/500\n",
      "2/2 - 0s - loss: 0.0059 - mse: 0.0059 - val_loss: 0.0069 - val_mse: 0.0069 - 17ms/epoch - 8ms/step\n",
      "Epoch 461/500\n",
      "2/2 - 0s - loss: 0.0059 - mse: 0.0059 - val_loss: 0.0069 - val_mse: 0.0069 - 16ms/epoch - 8ms/step\n",
      "Epoch 462/500\n",
      "2/2 - 0s - loss: 0.0059 - mse: 0.0059 - val_loss: 0.0069 - val_mse: 0.0069 - 16ms/epoch - 8ms/step\n",
      "Epoch 463/500\n",
      "2/2 - 0s - loss: 0.0058 - mse: 0.0058 - val_loss: 0.0068 - val_mse: 0.0068 - 16ms/epoch - 8ms/step\n",
      "Epoch 464/500\n",
      "2/2 - 0s - loss: 0.0058 - mse: 0.0058 - val_loss: 0.0068 - val_mse: 0.0068 - 15ms/epoch - 8ms/step\n",
      "Epoch 465/500\n",
      "2/2 - 0s - loss: 0.0058 - mse: 0.0058 - val_loss: 0.0068 - val_mse: 0.0068 - 16ms/epoch - 8ms/step\n",
      "Epoch 466/500\n",
      "2/2 - 0s - loss: 0.0058 - mse: 0.0058 - val_loss: 0.0068 - val_mse: 0.0068 - 15ms/epoch - 7ms/step\n",
      "Epoch 467/500\n",
      "2/2 - 0s - loss: 0.0058 - mse: 0.0058 - val_loss: 0.0067 - val_mse: 0.0067 - 17ms/epoch - 8ms/step\n",
      "Epoch 468/500\n",
      "2/2 - 0s - loss: 0.0058 - mse: 0.0058 - val_loss: 0.0067 - val_mse: 0.0067 - 15ms/epoch - 8ms/step\n",
      "Epoch 469/500\n",
      "2/2 - 0s - loss: 0.0057 - mse: 0.0057 - val_loss: 0.0067 - val_mse: 0.0067 - 16ms/epoch - 8ms/step\n",
      "Epoch 470/500\n",
      "2/2 - 0s - loss: 0.0057 - mse: 0.0057 - val_loss: 0.0067 - val_mse: 0.0067 - 16ms/epoch - 8ms/step\n",
      "Epoch 471/500\n",
      "2/2 - 0s - loss: 0.0057 - mse: 0.0057 - val_loss: 0.0067 - val_mse: 0.0067 - 23ms/epoch - 11ms/step\n",
      "Epoch 472/500\n",
      "2/2 - 0s - loss: 0.0057 - mse: 0.0057 - val_loss: 0.0067 - val_mse: 0.0067 - 19ms/epoch - 9ms/step\n",
      "Epoch 473/500\n",
      "2/2 - 0s - loss: 0.0057 - mse: 0.0057 - val_loss: 0.0066 - val_mse: 0.0066 - 15ms/epoch - 7ms/step\n",
      "Epoch 474/500\n",
      "2/2 - 0s - loss: 0.0057 - mse: 0.0057 - val_loss: 0.0066 - val_mse: 0.0066 - 14ms/epoch - 7ms/step\n",
      "Epoch 475/500\n",
      "2/2 - 0s - loss: 0.0056 - mse: 0.0056 - val_loss: 0.0066 - val_mse: 0.0066 - 16ms/epoch - 8ms/step\n",
      "Epoch 476/500\n",
      "2/2 - 0s - loss: 0.0056 - mse: 0.0056 - val_loss: 0.0066 - val_mse: 0.0066 - 15ms/epoch - 8ms/step\n",
      "Epoch 477/500\n",
      "2/2 - 0s - loss: 0.0056 - mse: 0.0056 - val_loss: 0.0066 - val_mse: 0.0066 - 15ms/epoch - 8ms/step\n",
      "Epoch 478/500\n",
      "2/2 - 0s - loss: 0.0056 - mse: 0.0056 - val_loss: 0.0065 - val_mse: 0.0065 - 16ms/epoch - 8ms/step\n",
      "Epoch 479/500\n",
      "2/2 - 0s - loss: 0.0056 - mse: 0.0056 - val_loss: 0.0065 - val_mse: 0.0065 - 16ms/epoch - 8ms/step\n",
      "Epoch 480/500\n",
      "2/2 - 0s - loss: 0.0056 - mse: 0.0056 - val_loss: 0.0065 - val_mse: 0.0065 - 16ms/epoch - 8ms/step\n",
      "Epoch 481/500\n",
      "2/2 - 0s - loss: 0.0056 - mse: 0.0056 - val_loss: 0.0065 - val_mse: 0.0065 - 16ms/epoch - 8ms/step\n",
      "Epoch 482/500\n",
      "2/2 - 0s - loss: 0.0055 - mse: 0.0055 - val_loss: 0.0065 - val_mse: 0.0065 - 15ms/epoch - 8ms/step\n",
      "Epoch 483/500\n",
      "2/2 - 0s - loss: 0.0055 - mse: 0.0055 - val_loss: 0.0065 - val_mse: 0.0065 - 16ms/epoch - 8ms/step\n",
      "Epoch 484/500\n",
      "2/2 - 0s - loss: 0.0055 - mse: 0.0055 - val_loss: 0.0065 - val_mse: 0.0065 - 15ms/epoch - 8ms/step\n",
      "Epoch 485/500\n",
      "2/2 - 0s - loss: 0.0055 - mse: 0.0055 - val_loss: 0.0064 - val_mse: 0.0064 - 15ms/epoch - 8ms/step\n",
      "Epoch 486/500\n",
      "2/2 - 0s - loss: 0.0055 - mse: 0.0055 - val_loss: 0.0064 - val_mse: 0.0064 - 15ms/epoch - 8ms/step\n",
      "Epoch 487/500\n",
      "2/2 - 0s - loss: 0.0055 - mse: 0.0055 - val_loss: 0.0064 - val_mse: 0.0064 - 15ms/epoch - 7ms/step\n",
      "Epoch 488/500\n",
      "2/2 - 0s - loss: 0.0055 - mse: 0.0055 - val_loss: 0.0064 - val_mse: 0.0064 - 15ms/epoch - 7ms/step\n",
      "Epoch 489/500\n",
      "2/2 - 0s - loss: 0.0054 - mse: 0.0054 - val_loss: 0.0064 - val_mse: 0.0064 - 15ms/epoch - 8ms/step\n",
      "Epoch 490/500\n",
      "2/2 - 0s - loss: 0.0054 - mse: 0.0054 - val_loss: 0.0063 - val_mse: 0.0063 - 16ms/epoch - 8ms/step\n",
      "Epoch 491/500\n",
      "2/2 - 0s - loss: 0.0054 - mse: 0.0054 - val_loss: 0.0063 - val_mse: 0.0063 - 15ms/epoch - 8ms/step\n",
      "Epoch 492/500\n",
      "2/2 - 0s - loss: 0.0054 - mse: 0.0054 - val_loss: 0.0063 - val_mse: 0.0063 - 15ms/epoch - 8ms/step\n",
      "Epoch 493/500\n",
      "2/2 - 0s - loss: 0.0054 - mse: 0.0054 - val_loss: 0.0063 - val_mse: 0.0063 - 15ms/epoch - 8ms/step\n",
      "Epoch 494/500\n",
      "2/2 - 0s - loss: 0.0054 - mse: 0.0054 - val_loss: 0.0063 - val_mse: 0.0063 - 16ms/epoch - 8ms/step\n",
      "Epoch 495/500\n",
      "2/2 - 0s - loss: 0.0054 - mse: 0.0054 - val_loss: 0.0063 - val_mse: 0.0063 - 15ms/epoch - 8ms/step\n",
      "Epoch 496/500\n",
      "2/2 - 0s - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0063 - val_mse: 0.0063 - 15ms/epoch - 8ms/step\n",
      "Epoch 497/500\n",
      "2/2 - 0s - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0062 - val_mse: 0.0062 - 15ms/epoch - 7ms/step\n",
      "Epoch 498/500\n",
      "2/2 - 0s - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0062 - val_mse: 0.0062 - 15ms/epoch - 7ms/step\n",
      "Epoch 499/500\n",
      "2/2 - 0s - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0062 - val_mse: 0.0062 - 15ms/epoch - 8ms/step\n",
      "Epoch 500/500\n",
      "2/2 - 0s - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0062 - val_mse: 0.0062 - 16ms/epoch - 8ms/step\n",
      "9.157248973846436\n"
     ]
    }
   ],
   "source": [
    "### Benchmark: a standard (non-transfer) RNN trained on the target data only,\n",
    "### used to show the advantage of the transfer-learning RNN models below.\n",
    "# One SimpleRNN(16, tanh) layer returning the whole sequence, then a linear\n",
    "# Dense(3) read-out applied at every time step.\n",
    "model = Sequential([\n",
    "    SimpleRNN(16, activation='tanh', return_sequences=True),\n",
    "    Dense(3, activation='linear'),\n",
    "])\n",
    "model.compile(optimizer='adam', loss='mse', metrics=['mse'])\n",
    "\n",
    "# Wall-clock training time in seconds; t0/t1 are printed again by the evaluation cells.\n",
    "t0 = time.time()\n",
    "history = model.fit(\n",
    "    X_train, y_train,\n",
    "    epochs=500, batch_size=256, validation_split=0.25, verbose=2,\n",
    ")\n",
    "t1 = time.time()\n",
    "print(t1 - t0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "81f35c9a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "38/38 [==============================] - 0s 614us/step - loss: 0.0056 - mse: 0.0056\n",
      "[0.005602560006082058, 0.005602559540420771]\n",
      "9.157248973846436\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the benchmark RNN on X_test / y_test: prints [mse loss, mse metric].\n",
    "# NOTE(review): this cell was labelled 'SOURCE test data', but it evaluates exactly the\n",
    "# same X_test / y_test as the TARGET evaluation cell below (its outputs are identical).\n",
    "# Point it at the source test arrays if a source-domain evaluation is intended.\n",
    "loss_and_metrics = model.evaluate(X_test, y_test, batch_size=256)\n",
    "print(loss_and_metrics)\n",
    "\n",
    "# benchmark training time (seconds)\n",
    "print(t1-t0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "ef7b115a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "38/38 [==============================] - 0s 571us/step - loss: 0.0056 - mse: 0.0056\n",
      "[0.005602560006082058, 0.005602559540420771]\n",
      "9.157248973846436\n"
     ]
    }
   ],
   "source": [
    "# Benchmark RNN on the TARGET test split -> [mse loss, mse metric]\n",
    "loss_and_metrics = model.evaluate(X_test, y_test, batch_size=256)\n",
    "print(loss_and_metrics)\n",
    "\n",
    "# training wall-clock time (seconds) recorded around model.fit in the benchmark cell\n",
    "elapsed = t1 - t0\n",
    "print(elapsed)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e569d35c",
   "metadata": {},
   "source": [
    "# Construct the transfer-learning NN: the frozen pretrained model wrapped in trainable feature-transformation layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "d659a2af",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pre-trained model (this rebinds `model`; the benchmark RNN is no longer reachable under that name)\n",
    "model = load_model(\"pretrained_model.h5\")\n",
    "\n",
    "# Source-process training inputs and outputs, in that order\n",
    "X_train_S, y_train_S = (np.load(fname) for fname in ('source_input_data.npy', 'source_output_data.npy'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "4687c4cf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "71/71 [==============================] - 0s 538us/step - loss: 0.0011 - mse: 0.0011  \n",
      "2/2 [==============================] - 0s 1ms/step - loss: 2.4447 - mse: 2.4447\n",
      "73/73 [==============================] - 0s 545us/step - loss: 0.0666 - mse: 0.0666\n"
     ]
    }
   ],
   "source": [
    "# Stack target and source training data for fine-tuning. Row order matters:\n",
    "# the first len(X_train) rows are TARGET samples, the remaining len(X_train_S)\n",
    "# rows are SOURCE samples (custom_loss splits the stacked data on this boundary).\n",
    "RNN_input_final_Train = np.concatenate((X_train, X_train_S), axis=0)\n",
    "RNN_output_final_Train = np.concatenate((y_train, y_train_S), axis=0)\n",
    "\n",
    "# Pre-trained model predictions on the source inputs (the h^* reference used by custom_loss)\n",
    "NN_S = model.predict(X_train_S)\n",
    "\n",
    "# Performance of the pre-trained model before adaptation, each as [loss, mse];\n",
    "# kept under separate names instead of overwriting one variable three times.\n",
    "source_loss_and_metrics = model.evaluate(X_train_S, y_train_S, batch_size=256)\n",
    "target_loss_and_metrics = model.evaluate(X_train, y_train, batch_size=256)\n",
    "loss_and_metrics = model.evaluate(RNN_input_final_Train, RNN_output_final_Train, batch_size=256)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "f2c61d08",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Freeze all layers of the pretrained model: only the new bias-free linear\n",
    "# input/output transformation layers created below are trainable.\n",
    "for layer in model.layers:\n",
    "    layer.trainable = False\n",
    "\n",
    "# Target inputs: sequences of any length with 4 features per time step\n",
    "input_layer = Input(shape=(None, 4))\n",
    "\n",
    "# Input transformation 4 -> 8 -> 8 -> 4; its output feeds the pretrained model,\n",
    "# which takes 4 features per time step.\n",
    "new_input_layer, new_input_layer1, new_input_layer2 = (\n",
    "    Dense(width, activation='linear', use_bias=False) for width in (8, 8, 4)\n",
    ")\n",
    "transformed_input = new_input_layer(input_layer)\n",
    "transformed_input = new_input_layer1(transformed_input)\n",
    "transformed_input = new_input_layer2(transformed_input)\n",
    "\n",
    "# Frozen pretrained model (3 outputs per time step)\n",
    "x = model(transformed_input)\n",
    "\n",
    "# Output transformation 3 -> 4 -> 4 -> 3 applied to the pretrained model's prediction\n",
    "new_output_layer, new_output_layer1, new_output_layer2 = (\n",
    "    Dense(width, activation='linear', use_bias=False) for width in (4, 4, 3)\n",
    ")\n",
    "output = new_output_layer2(new_output_layer1(new_output_layer(x)))\n",
    "\n",
    "final_model = Model(inputs=input_layer, outputs=output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "e9a742b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialise the input-transformation kernels so that the chain is the identity\n",
    "# map on the 4 input features (A @ B @ C = I_4). Keras Dense kernels have shape\n",
    "# (in_features, out_features).\n",
    "A = np.eye(4, 8, dtype=np.float32)  # 4x8: [I_4 | 0]\n",
    "B = np.eye(8, dtype=np.float32)     # 8x8: I_8\n",
    "C = np.eye(8, 4, dtype=np.float32)  # 8x4: [I_4; 0]\n",
    "\n",
    "new_input_layer.set_weights([A])\n",
    "new_input_layer1.set_weights([B])\n",
    "new_input_layer2.set_weights([C])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "a5792001",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialise the output-transformation kernels so that the chain is the identity\n",
    "# map on the 3 model outputs (Ao @ Bo @ Co = I_3). Keras Dense kernels have shape\n",
    "# (in_features, out_features).\n",
    "Ao = np.eye(3, 4, dtype=np.float32)  # 3x4: [I_3 | 0]\n",
    "Bo = np.eye(4, dtype=np.float32)     # 4x4: I_4\n",
    "Co = np.eye(4, 3, dtype=np.float32)  # 4x3: [I_3; 0]\n",
    "\n",
    "new_output_layer.set_weights([Ao])\n",
    "new_output_layer1.set_weights([Bo])\n",
    "new_output_layer2.set_weights([Co])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "4b514281",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Custom transfer-learning loss:\n",
    "#   loss = 1.0*loss11 + a*|loss11 - loss12| + b*loss2 + c*loss3\n",
    "#   loss11: MSE of final_model on the TARGET rows of the stacked training set\n",
    "#   loss12: MSE of final_model on the SOURCE rows\n",
    "#   loss2 : mean squared gap between final_model and the pretrained model (NN_S) on source inputs\n",
    "#   loss3 : max|input_layer_weights| * max|output_layer_weights|\n",
    "# NOTE: y_true / y_pred are ignored; every call evaluates final_model on the whole\n",
    "# stacked set RNN_input_final_Train, so the value does not depend on the mini-batch.\n",
    "# NOTE(review): input_layer_weights / output_layer_weights are not defined in any visible\n",
    "# cell (execution count jumps from 16 to 18); define them explicitly before this cell so\n",
    "# the notebook works after Restart & Run All.\n",
    "def custom_loss(y_true, y_pred, a=0, b=0.01, c=0.01):\n",
    "    # Rows of RNN_input_final_Train are [target ; source] and NN_S covers exactly the\n",
    "    # source rows, so the target count is derived from the data rather than hard-coded\n",
    "    # (a fixed count breaks the target/source split whenever the target set size changes).\n",
    "    Num = len(RNN_input_final_Train) - len(NN_S)  # number of target training samples\n",
    "    NN_out = final_model(RNN_input_final_Train)  # one 3-output prediction per time step\n",
    "\n",
    "    # first term: fit quality on target vs. source\n",
    "    loss11 = tf.math.reduce_mean((NN_out[:Num, :, :] - RNN_output_final_Train[:Num, :, :]) ** 2)  # target\n",
    "    loss12 = tf.math.reduce_mean((NN_out[Num:, :, :] - RNN_output_final_Train[Num:, :, :]) ** 2)  # source\n",
    "    loss1 = tf.abs(loss11 - loss12)\n",
    "\n",
    "    # second term: stay close to the pretrained model h^* on the source data\n",
    "    loss2 = tf.math.reduce_mean((NN_out[Num:, :, :] - NN_S) ** 2)\n",
    "\n",
    "    # last term: product of the largest absolute weights (presumably the kernels of the\n",
    "    # new input/output transformation layers - confirm where these globals are built)\n",
    "    max_abs_input = tf.reduce_max(tf.abs(input_layer_weights))\n",
    "    max_abs_output = tf.reduce_max(tf.abs(output_layer_weights))\n",
    "    loss3 = max_abs_input * max_abs_output\n",
    "\n",
    "    # a, b, c weight the auxiliary terms; the defaults reproduce the original settings\n",
    "    return 1.0 * loss11 + a * loss1 + b * loss2 + c * loss3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "91d741df",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"model\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " input_1 (InputLayer)        [(None, None, 4)]         0         \n",
      "                                                                 \n",
      " dense_1 (Dense)             (None, None, 8)           32        \n",
      "                                                                 \n",
      " dense_2 (Dense)             (None, None, 8)           64        \n",
      "                                                                 \n",
      " dense_3 (Dense)             (None, None, 4)           32        \n",
      "                                                                 \n",
      " sequential (Sequential)     (None, 5, 3)              387       \n",
      "                                                                 \n",
      " dense_4 (Dense)             (None, 5, 4)              12        \n",
      "                                                                 \n",
      " dense_5 (Dense)             (None, 5, 4)              16        \n",
      "                                                                 \n",
      " dense_6 (Dense)             (None, 5, 3)              12        \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 555\n",
      "Trainable params: 168\n",
      "Non-trainable params: 387\n",
      "_________________________________________________________________\n",
      "None\n"
     ]
    }
   ],
   "source": [
    "# Compile with the custom transfer-learning loss; plain MSE is tracked as a metric.\n",
    "final_model.compile(optimizer='adam', loss=custom_loss, metrics=['mse'])\n",
    "\n",
    "# summary() prints the table itself and returns None, so it is not wrapped in print()\n",
    "# (that produced the stray 'None' line in the saved output).\n",
    "final_model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "a2d89c02",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iteration :0     Target train: tf.Tensor(2.4447496, shape=(), dtype=float32)\n",
      "2/2 - 1s - loss: 2.4386 - mse: 2.3372 - val_loss: 2.3551 - val_mse: 2.6190 - 1s/epoch - 506ms/step\n",
      "iteration :1     Target train: tf.Tensor(2.3451145, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 2.3401 - mse: 2.2394 - val_loss: 2.2624 - val_mse: 2.5273 - 58ms/epoch - 29ms/step\n",
      "iteration :2     Target train: tf.Tensor(2.2525005, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 2.2485 - mse: 2.1481 - val_loss: 2.1761 - val_mse: 2.4411 - 58ms/epoch - 29ms/step\n",
      "iteration :3     Target train: tf.Tensor(2.1661246, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 2.1630 - mse: 2.0589 - val_loss: 2.0953 - val_mse: 2.3599 - 57ms/epoch - 29ms/step\n",
      "iteration :4     Target train: tf.Tensor(2.0853372, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 2.0831 - mse: 1.9842 - val_loss: 2.0197 - val_mse: 2.2832 - 54ms/epoch - 27ms/step\n",
      "iteration :5     Target train: tf.Tensor(2.0097623, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 2.0083 - mse: 1.9046 - val_loss: 1.9491 - val_mse: 2.2105 - 62ms/epoch - 31ms/step\n",
      "iteration :6     Target train: tf.Tensor(1.9390905, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 1.9384 - mse: 1.8377 - val_loss: 1.8830 - val_mse: 2.1416 - 58ms/epoch - 29ms/step\n",
      "iteration :7     Target train: tf.Tensor(1.873014, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 1.8730 - mse: 1.7733 - val_loss: 1.8212 - val_mse: 2.0761 - 56ms/epoch - 28ms/step\n",
      "iteration :8     Target train: tf.Tensor(1.8111391, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 1.8118 - mse: 1.7151 - val_loss: 1.7630 - val_mse: 2.0137 - 61ms/epoch - 30ms/step\n",
      "iteration :9     Target train: tf.Tensor(1.7529361, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 1.7542 - mse: 1.6571 - val_loss: 1.7079 - val_mse: 1.9541 - 60ms/epoch - 30ms/step\n",
      "iteration :10     Target train: tf.Tensor(1.6978198, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 1.6995 - mse: 1.6045 - val_loss: 1.6554 - val_mse: 1.8970 - 59ms/epoch - 29ms/step\n",
      "iteration :11     Target train: tf.Tensor(1.6452199, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 1.6473 - mse: 1.5536 - val_loss: 1.6049 - val_mse: 1.8418 - 63ms/epoch - 32ms/step\n",
      "iteration :12     Target train: tf.Tensor(1.5946542, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 1.5971 - mse: 1.5028 - val_loss: 1.5562 - val_mse: 1.7879 - 59ms/epoch - 30ms/step\n",
      "iteration :13     Target train: tf.Tensor(1.5459481, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 1.5487 - mse: 1.4554 - val_loss: 1.5096 - val_mse: 1.7349 - 59ms/epoch - 29ms/step\n",
      "iteration :14     Target train: tf.Tensor(1.4992808, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 1.5025 - mse: 1.4150 - val_loss: 1.4652 - val_mse: 1.6825 - 97ms/epoch - 48ms/step\n",
      "iteration :15     Target train: tf.Tensor(1.4547812, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 1.4584 - mse: 1.3741 - val_loss: 1.4231 - val_mse: 1.6317 - 84ms/epoch - 42ms/step\n",
      "iteration :16     Target train: tf.Tensor(1.4126501, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 1.4168 - mse: 1.3344 - val_loss: 1.3837 - val_mse: 1.5835 - 69ms/epoch - 35ms/step\n",
      "iteration :17     Target train: tf.Tensor(1.3731744, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 1.3778 - mse: 1.2972 - val_loss: 1.3470 - val_mse: 1.5377 - 63ms/epoch - 32ms/step\n",
      "iteration :18     Target train: tf.Tensor(1.3364365, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 1.3415 - mse: 1.2655 - val_loss: 1.3131 - val_mse: 1.4938 - 58ms/epoch - 29ms/step\n",
      "iteration :19     Target train: tf.Tensor(1.3024428, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 1.3080 - mse: 1.2340 - val_loss: 1.2819 - val_mse: 1.4525 - 92ms/epoch - 46ms/step\n",
      "iteration :20     Target train: tf.Tensor(1.2712208, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 1.2773 - mse: 1.2062 - val_loss: 1.2535 - val_mse: 1.4145 - 88ms/epoch - 44ms/step\n",
      "iteration :21     Target train: tf.Tensor(1.242733, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 1.2493 - mse: 1.1807 - val_loss: 1.2276 - val_mse: 1.3800 - 63ms/epoch - 31ms/step\n",
      "iteration :22     Target train: tf.Tensor(1.2167405, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 1.2237 - mse: 1.1591 - val_loss: 1.2038 - val_mse: 1.3484 - 59ms/epoch - 29ms/step\n",
      "iteration :23     Target train: tf.Tensor(1.192851, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 1.2002 - mse: 1.1376 - val_loss: 1.1816 - val_mse: 1.3193 - 79ms/epoch - 40ms/step\n",
      "iteration :24     Target train: tf.Tensor(1.1706791, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 1.1783 - mse: 1.1182 - val_loss: 1.1610 - val_mse: 1.2920 - 99ms/epoch - 49ms/step\n",
      "iteration :25     Target train: tf.Tensor(1.1499308, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 1.1578 - mse: 1.0995 - val_loss: 1.1415 - val_mse: 1.2665 - 71ms/epoch - 36ms/step\n",
      "iteration :26     Target train: tf.Tensor(1.1304082, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 1.1385 - mse: 1.0818 - val_loss: 1.1231 - val_mse: 1.2424 - 59ms/epoch - 30ms/step\n",
      "iteration :27     Target train: tf.Tensor(1.1119835, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 1.1203 - mse: 1.0663 - val_loss: 1.1058 - val_mse: 1.2198 - 59ms/epoch - 29ms/step\n",
      "iteration :28     Target train: tf.Tensor(1.0945679, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 1.1031 - mse: 1.0508 - val_loss: 1.0893 - val_mse: 1.1986 - 68ms/epoch - 34ms/step\n",
      "iteration :29     Target train: tf.Tensor(1.0780878, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 1.0868 - mse: 1.0359 - val_loss: 1.0738 - val_mse: 1.1785 - 58ms/epoch - 29ms/step\n",
      "iteration :30     Target train: tf.Tensor(1.0624747, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 1.0714 - mse: 1.0211 - val_loss: 1.0590 - val_mse: 1.1594 - 56ms/epoch - 28ms/step\n",
      "iteration :31     Target train: tf.Tensor(1.0476592, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 1.0568 - mse: 1.0081 - val_loss: 1.0450 - val_mse: 1.1412 - 62ms/epoch - 31ms/step\n",
      "iteration :32     Target train: tf.Tensor(1.0335692, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 1.0428 - mse: 0.9959 - val_loss: 1.0316 - val_mse: 1.1238 - 59ms/epoch - 29ms/step\n",
      "iteration :33     Target train: tf.Tensor(1.0201322, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 1.0295 - mse: 0.9829 - val_loss: 1.0188 - val_mse: 1.1071 - 94ms/epoch - 47ms/step\n",
      "iteration :34     Target train: tf.Tensor(1.0072824, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 1.0168 - mse: 0.9728 - val_loss: 1.0065 - val_mse: 1.0910 - 84ms/epoch - 42ms/step\n",
      "iteration :35     Target train: tf.Tensor(0.9949656, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 1.0046 - mse: 0.9613 - val_loss: 0.9947 - val_mse: 1.0756 - 78ms/epoch - 39ms/step\n",
      "iteration :36     Target train: tf.Tensor(0.9831393, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.9929 - mse: 0.9503 - val_loss: 0.9834 - val_mse: 1.0609 - 98ms/epoch - 49ms/step\n",
      "iteration :37     Target train: tf.Tensor(0.97176933, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.9817 - mse: 0.9402 - val_loss: 0.9725 - val_mse: 1.0468 - 66ms/epoch - 33ms/step\n",
      "iteration :38     Target train: tf.Tensor(0.9608271, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.9708 - mse: 0.9306 - val_loss: 0.9620 - val_mse: 1.0333 - 95ms/epoch - 47ms/step\n",
      "iteration :39     Target train: tf.Tensor(0.9502859, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.9604 - mse: 0.9214 - val_loss: 0.9519 - val_mse: 1.0204 - 67ms/epoch - 34ms/step\n",
      "iteration :40     Target train: tf.Tensor(0.94012105, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.9504 - mse: 0.9121 - val_loss: 0.9422 - val_mse: 1.0081 - 60ms/epoch - 30ms/step\n",
      "iteration :41     Target train: tf.Tensor(0.9303085, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.9406 - mse: 0.9030 - val_loss: 0.9327 - val_mse: 0.9963 - 63ms/epoch - 31ms/step\n",
      "iteration :42     Target train: tf.Tensor(0.9208249, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.9313 - mse: 0.8944 - val_loss: 0.9236 - val_mse: 0.9850 - 60ms/epoch - 30ms/step\n",
      "iteration :43     Target train: tf.Tensor(0.91164786, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.9222 - mse: 0.8860 - val_loss: 0.9147 - val_mse: 0.9742 - 60ms/epoch - 30ms/step\n",
      "iteration :44     Target train: tf.Tensor(0.9027558, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.9134 - mse: 0.8779 - val_loss: 0.9062 - val_mse: 0.9638 - 59ms/epoch - 29ms/step\n",
      "iteration :45     Target train: tf.Tensor(0.8941281, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.9048 - mse: 0.8698 - val_loss: 0.8978 - val_mse: 0.9538 - 59ms/epoch - 29ms/step\n",
      "iteration :46     Target train: tf.Tensor(0.8857458, shape=(), dtype=float32)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2/2 - 0s - loss: 0.8965 - mse: 0.8617 - val_loss: 0.8897 - val_mse: 0.9441 - 65ms/epoch - 32ms/step\n",
      "iteration :47     Target train: tf.Tensor(0.87759125, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.8885 - mse: 0.8539 - val_loss: 0.8818 - val_mse: 0.9348 - 68ms/epoch - 34ms/step\n",
      "iteration :48     Target train: tf.Tensor(0.8696485, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.8806 - mse: 0.8467 - val_loss: 0.8741 - val_mse: 0.9257 - 58ms/epoch - 29ms/step\n",
      "iteration :49     Target train: tf.Tensor(0.8619031, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.8729 - mse: 0.8394 - val_loss: 0.8666 - val_mse: 0.9169 - 56ms/epoch - 28ms/step\n",
      "iteration :50     Target train: tf.Tensor(0.8543416, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.8654 - mse: 0.8322 - val_loss: 0.8593 - val_mse: 0.9083 - 65ms/epoch - 33ms/step\n",
      "iteration :51     Target train: tf.Tensor(0.84695196, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.8581 - mse: 0.8253 - val_loss: 0.8521 - val_mse: 0.8999 - 60ms/epoch - 30ms/step\n",
      "iteration :52     Target train: tf.Tensor(0.83972365, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.8510 - mse: 0.8184 - val_loss: 0.8450 - val_mse: 0.8918 - 62ms/epoch - 31ms/step\n",
      "iteration :53     Target train: tf.Tensor(0.83264655, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.8439 - mse: 0.8119 - val_loss: 0.8382 - val_mse: 0.8838 - 59ms/epoch - 30ms/step\n",
      "iteration :54     Target train: tf.Tensor(0.82571167, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.8371 - mse: 0.8052 - val_loss: 0.8314 - val_mse: 0.8760 - 58ms/epoch - 29ms/step\n",
      "iteration :55     Target train: tf.Tensor(0.81891066, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.8303 - mse: 0.7989 - val_loss: 0.8248 - val_mse: 0.8683 - 66ms/epoch - 33ms/step\n",
      "iteration :56     Target train: tf.Tensor(0.8122361, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.8237 - mse: 0.7926 - val_loss: 0.8183 - val_mse: 0.8609 - 65ms/epoch - 32ms/step\n",
      "iteration :57     Target train: tf.Tensor(0.8056806, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.8172 - mse: 0.7864 - val_loss: 0.8119 - val_mse: 0.8535 - 60ms/epoch - 30ms/step\n",
      "iteration :58     Target train: tf.Tensor(0.79923815, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.8109 - mse: 0.7803 - val_loss: 0.8056 - val_mse: 0.8463 - 62ms/epoch - 31ms/step\n",
      "iteration :59     Target train: tf.Tensor(0.7929026, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.8046 - mse: 0.7742 - val_loss: 0.7994 - val_mse: 0.8392 - 58ms/epoch - 29ms/step\n",
      "iteration :60     Target train: tf.Tensor(0.78666866, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.7984 - mse: 0.7681 - val_loss: 0.7933 - val_mse: 0.8322 - 56ms/epoch - 28ms/step\n",
      "iteration :61     Target train: tf.Tensor(0.78053147, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.7923 - mse: 0.7624 - val_loss: 0.7873 - val_mse: 0.8253 - 67ms/epoch - 34ms/step\n",
      "iteration :62     Target train: tf.Tensor(0.7744861, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.7864 - mse: 0.7565 - val_loss: 0.7814 - val_mse: 0.8185 - 57ms/epoch - 29ms/step\n",
      "iteration :63     Target train: tf.Tensor(0.7685288, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.7805 - mse: 0.7510 - val_loss: 0.7756 - val_mse: 0.8118 - 61ms/epoch - 31ms/step\n",
      "iteration :64     Target train: tf.Tensor(0.76265574, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.7746 - mse: 0.7453 - val_loss: 0.7698 - val_mse: 0.8052 - 58ms/epoch - 29ms/step\n",
      "iteration :65     Target train: tf.Tensor(0.7568635, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.7689 - mse: 0.7399 - val_loss: 0.7641 - val_mse: 0.7987 - 57ms/epoch - 29ms/step\n",
      "iteration :66     Target train: tf.Tensor(0.7511493, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.7633 - mse: 0.7344 - val_loss: 0.7586 - val_mse: 0.7923 - 60ms/epoch - 30ms/step\n",
      "iteration :67     Target train: tf.Tensor(0.7455102, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.7577 - mse: 0.7291 - val_loss: 0.7530 - val_mse: 0.7859 - 62ms/epoch - 31ms/step\n",
      "iteration :68     Target train: tf.Tensor(0.73994416, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.7522 - mse: 0.7238 - val_loss: 0.7476 - val_mse: 0.7796 - 58ms/epoch - 29ms/step\n",
      "iteration :69     Target train: tf.Tensor(0.7344487, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.7467 - mse: 0.7186 - val_loss: 0.7422 - val_mse: 0.7734 - 62ms/epoch - 31ms/step\n",
      "iteration :70     Target train: tf.Tensor(0.729022, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.7414 - mse: 0.7134 - val_loss: 0.7369 - val_mse: 0.7672 - 63ms/epoch - 32ms/step\n",
      "iteration :71     Target train: tf.Tensor(0.72366214, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.7361 - mse: 0.7084 - val_loss: 0.7316 - val_mse: 0.7611 - 60ms/epoch - 30ms/step\n",
      "iteration :72     Target train: tf.Tensor(0.71836764, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.7308 - mse: 0.7034 - val_loss: 0.7265 - val_mse: 0.7551 - 85ms/epoch - 43ms/step\n",
      "iteration :73     Target train: tf.Tensor(0.71313673, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.7256 - mse: 0.6984 - val_loss: 0.7213 - val_mse: 0.7491 - 105ms/epoch - 52ms/step\n",
      "iteration :74     Target train: tf.Tensor(0.7079681, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.7205 - mse: 0.6934 - val_loss: 0.7163 - val_mse: 0.7432 - 77ms/epoch - 38ms/step\n",
      "iteration :75     Target train: tf.Tensor(0.70286036, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.7155 - mse: 0.6886 - val_loss: 0.7113 - val_mse: 0.7373 - 71ms/epoch - 36ms/step\n",
      "iteration :76     Target train: tf.Tensor(0.6978123, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.7105 - mse: 0.6839 - val_loss: 0.7063 - val_mse: 0.7315 - 70ms/epoch - 35ms/step\n",
      "iteration :77     Target train: tf.Tensor(0.692823, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.7056 - mse: 0.6792 - val_loss: 0.7014 - val_mse: 0.7257 - 70ms/epoch - 35ms/step\n",
      "iteration :78     Target train: tf.Tensor(0.6878915, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.7007 - mse: 0.6745 - val_loss: 0.6966 - val_mse: 0.7200 - 77ms/epoch - 38ms/step\n",
      "iteration :79     Target train: tf.Tensor(0.683017, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.6959 - mse: 0.6699 - val_loss: 0.6919 - val_mse: 0.7144 - 66ms/epoch - 33ms/step\n",
      "iteration :80     Target train: tf.Tensor(0.67819875, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.6911 - mse: 0.6654 - val_loss: 0.6871 - val_mse: 0.7088 - 66ms/epoch - 33ms/step\n",
      "iteration :81     Target train: tf.Tensor(0.67343646, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.6864 - mse: 0.6609 - val_loss: 0.6825 - val_mse: 0.7033 - 69ms/epoch - 35ms/step\n",
      "iteration :82     Target train: tf.Tensor(0.66872936, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.6817 - mse: 0.6565 - val_loss: 0.6779 - val_mse: 0.6979 - 65ms/epoch - 32ms/step\n",
      "iteration :83     Target train: tf.Tensor(0.6640771, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.6772 - mse: 0.6522 - val_loss: 0.6733 - val_mse: 0.6925 - 64ms/epoch - 32ms/step\n",
      "iteration :84     Target train: tf.Tensor(0.65947926, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.6726 - mse: 0.6478 - val_loss: 0.6688 - val_mse: 0.6872 - 63ms/epoch - 31ms/step\n",
      "iteration :85     Target train: tf.Tensor(0.6549355, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.6681 - mse: 0.6435 - val_loss: 0.6644 - val_mse: 0.6819 - 63ms/epoch - 32ms/step\n",
      "iteration :86     Target train: tf.Tensor(0.6504454, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.6637 - mse: 0.6393 - val_loss: 0.6600 - val_mse: 0.6767 - 64ms/epoch - 32ms/step\n",
      "iteration :87     Target train: tf.Tensor(0.64600873, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.6593 - mse: 0.6351 - val_loss: 0.6557 - val_mse: 0.6716 - 67ms/epoch - 33ms/step\n",
      "iteration :88     Target train: tf.Tensor(0.64162517, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.6550 - mse: 0.6311 - val_loss: 0.6514 - val_mse: 0.6665 - 66ms/epoch - 33ms/step\n",
      "iteration :89     Target train: tf.Tensor(0.6372947, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.6507 - mse: 0.6269 - val_loss: 0.6472 - val_mse: 0.6615 - 60ms/epoch - 30ms/step\n",
      "iteration :90     Target train: tf.Tensor(0.6330171, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.6465 - mse: 0.6230 - val_loss: 0.6430 - val_mse: 0.6566 - 59ms/epoch - 30ms/step\n",
      "iteration :91     Target train: tf.Tensor(0.62879235, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.6423 - mse: 0.6189 - val_loss: 0.6389 - val_mse: 0.6517 - 62ms/epoch - 31ms/step\n",
      "iteration :92     Target train: tf.Tensor(0.62462056, shape=(), dtype=float32)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2/2 - 0s - loss: 0.6382 - mse: 0.6150 - val_loss: 0.6348 - val_mse: 0.6469 - 73ms/epoch - 36ms/step\n",
      "iteration :93     Target train: tf.Tensor(0.6205018, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.6342 - mse: 0.6112 - val_loss: 0.6308 - val_mse: 0.6421 - 62ms/epoch - 31ms/step\n",
      "iteration :94     Target train: tf.Tensor(0.6164363, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.6302 - mse: 0.6073 - val_loss: 0.6268 - val_mse: 0.6374 - 62ms/epoch - 31ms/step\n",
      "iteration :95     Target train: tf.Tensor(0.61242443, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.6262 - mse: 0.6036 - val_loss: 0.6229 - val_mse: 0.6328 - 61ms/epoch - 31ms/step\n",
      "iteration :96     Target train: tf.Tensor(0.6084663, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.6223 - mse: 0.5999 - val_loss: 0.6191 - val_mse: 0.6282 - 58ms/epoch - 29ms/step\n",
      "iteration :97     Target train: tf.Tensor(0.60456264, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.6185 - mse: 0.5960 - val_loss: 0.6153 - val_mse: 0.6236 - 60ms/epoch - 30ms/step\n",
      "iteration :98     Target train: tf.Tensor(0.6007137, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.6147 - mse: 0.5925 - val_loss: 0.6116 - val_mse: 0.6192 - 61ms/epoch - 30ms/step\n",
      "iteration :99     Target train: tf.Tensor(0.59691983, shape=(), dtype=float32)\n",
      "2/2 - 0s - loss: 0.6110 - mse: 0.5890 - val_loss: 0.6079 - val_mse: 0.6148 - 60ms/epoch - 30ms/step\n",
      "12.162475109100342\n"
     ]
    }
   ],
   "source": [
    "#Fine-tune final_model on the target data for N_EPOCHS epochs, logging the target-training MSE\n",
    "#(measured before each epoch's update) and the total wall-clock time in seconds.\n",
    "N_EPOCHS = 100\n",
    "Num = 500  # number of target training samples scored by the monitoring metric (also used by the next cell)\n",
    "\n",
    "# Handles to the first-layer and last-layer kernels. trainable_weights holds the live variables,\n",
    "# so fetching them once (rather than on every iteration) still reflects every update made by fit().\n",
    "input_layer_weights = final_model.layers[1].trainable_weights[0]  # weight matrix only\n",
    "output_layer_weights = final_model.layers[-1].trainable_weights[0]\n",
    "\n",
    "# NOTE(review): the monitoring metric scores RNN_input_final_Train[:Num] while fit() trains on\n",
    "# X_train/y_train; if X_train is a split of RNN_input_final_Train, confirm these Num samples are not test data.\n",
    "t0 = time.time()\n",
    "for i in range(N_EPOCHS):\n",
    "    #prediction performance on the target training set, before this epoch's update\n",
    "    NN_out = final_model(RNN_input_final_Train)\n",
    "    loss11 = tf.math.reduce_mean((NN_out[:Num,:,:]-RNN_output_final_Train[:Num,:,:])**2) # prediction on target\n",
    "    print(f\"iteration :{i}     Target train: {float(loss11):.6f}\")\n",
    "\n",
    "    # One epoch; validation_split=0.25 holds out the last quarter of X_train for val_loss.\n",
    "    # The loop index i is driven by range() and must not be incremented manually.\n",
    "    history = final_model.fit(X_train, y_train, epochs=1, batch_size=256, validation_split=0.25, verbose=2)\n",
    "\n",
    "t1 = time.time()\n",
    "print(t1 - t0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "8f14a2c7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Target train:  tf.Tensor(0.5931818, shape=(), dtype=float32)\n",
      "Target test:  tf.Tensor(0.6196591, shape=(), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "#Evaluate the fine-tuned model with the mean squared error\n",
    "#1) target training data -- only its first Num samples are scored\n",
    "NN_out = final_model(RNN_input_final_Train)\n",
    "loss11 = tf.math.reduce_mean(\n",
    "    (NN_out[:Num] - RNN_output_final_Train[:Num]) ** 2\n",
    ")\n",
    "print(\"Target train: \", loss11)\n",
    "\n",
    "#2) target test data -- every sample is scored\n",
    "NN_out = final_model(X_test)\n",
    "loss11 = tf.math.reduce_mean((NN_out - y_test) ** 2)\n",
    "print(\"Target test: \", loss11)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21ca0c99",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8016b92",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "647daff1",
   "metadata": {},
   "source": [
    "# Combine the input/output layer weights into the transfer matrices P and Q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "523000d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Collect the input-side layers of final_model (indices 1-3; index 0 is presumably the InputLayer)\n",
    "layer_I1 = final_model.layers[1]  # first input-side layer\n",
    "layer_I2 = final_model.layers[2]  # second input-side layer\n",
    "layer_I3 = final_model.layers[3]  # third input-side layer\n",
    "\n",
    "# get_weights() returns a list of numpy arrays; element [0] is the kernel matrix\n",
    "weights_I1 = layer_I1.get_weights()\n",
    "weights_I2 = layer_I2.get_weights()\n",
    "weights_I3 = layer_I3.get_weights()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "b90b3b96",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Weights (kernel_P1):\n",
      " [array([[ 0.989428  , -0.1069353 ,  0.06000958,  0.        ],\n",
      "       [-0.01070091,  0.89324445,  0.05997863,  0.        ],\n",
      "       [ 0.05432884,  0.1212594 ,  0.9675572 ,  0.        ]],\n",
      "      dtype=float32)]\n",
      "Weights (kernel_P2):\n",
      " [array([[ 0.97891515, -0.10348026,  0.06621196,  0.        ],\n",
      "       [-0.02352717,  0.9038057 ,  0.06577406,  0.        ],\n",
      "       [ 0.06646628,  0.11868355,  0.9576942 ,  0.        ],\n",
      "       [ 0.        ,  0.        ,  0.        ,  1.        ]],\n",
      "      dtype=float32)]\n",
      "Weights (kernel_P3):\n",
      " [array([[ 0.9640754 , -0.09952678,  0.07131607],\n",
      "       [-0.0315524 ,  0.9140939 ,  0.06860425],\n",
      "       [ 0.07692752,  0.11654448,  0.9484599 ],\n",
      "       [ 0.        ,  0.        ,  0.        ]], dtype=float32)]\n"
     ]
    }
   ],
   "source": [
    "#Collect the output-side layers of final_model (its last three layers)\n",
    "layer_P1 = final_model.layers[-3]  # first output-side layer (third from the end)\n",
    "layer_P2 = final_model.layers[-2]  # second output-side layer\n",
    "layer_P3 = final_model.layers[-1]  # third output-side layer (the model's final layer)\n",
    "\n",
    "# get_weights() returns a list of numpy arrays; for a Dense layer a single entry means no bias\n",
    "weights_P1 = layer_P1.get_weights()\n",
    "weights_P2 = layer_P2.get_weights()\n",
    "weights_P3 = layer_P3.get_weights()\n",
    "\n",
    "print(\"Weights (kernel_P1):\\n\", weights_P1)\n",
    "print(\"Weights (kernel_P2):\\n\", weights_P2)\n",
    "print(\"Weights (kernel_P3):\\n\", weights_P3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "e2032c68",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Weights (kernel_I1):\n",
      " [[ 0.8831213  -0.03359921 -0.09328191 -0.18412596  0.          0.\n",
      "   0.          0.        ]\n",
      " [-0.11737688  0.979868   -0.09221544 -0.11486997  0.          0.\n",
      "   0.          0.        ]\n",
      " [-0.11395823  0.10799164  0.8926967  -0.1919047   0.          0.\n",
      "   0.          0.        ]\n",
      " [ 0.10617935 -0.12544572  0.12247889  1.0062222   0.          0.\n",
      "   0.          0.        ]]\n",
      "Weights (kernel_I2):\n",
      " [[ 0.89393604 -0.07082861 -0.08531859 -0.08462644  0.          0.\n",
      "   0.          0.        ]\n",
      " [-0.12194752  0.96438533 -0.10509134 -0.04884171  0.          0.\n",
      "   0.          0.        ]\n",
      " [-0.10814632  0.04927953  0.89343256 -0.07299699  0.          0.\n",
      "   0.          0.        ]\n",
      " [ 0.14753413 -0.12382958  0.15035036  0.9829431   0.          0.\n",
      "   0.          0.        ]\n",
      " [ 0.          0.          0.          0.          1.          0.\n",
      "   0.          0.        ]\n",
      " [ 0.          0.          0.          0.          0.          1.\n",
      "   0.          0.        ]\n",
      " [ 0.          0.          0.          0.          0.          0.\n",
      "   1.          0.        ]\n",
      " [ 0.          0.          0.          0.          0.          0.\n",
      "   0.          1.        ]]\n",
      "Weights (kernel_I3):\n",
      " [[ 0.9093382  -0.09070838 -0.07040551  0.05632314]\n",
      " [-0.1267936   0.9442269  -0.11932627  0.05638725]\n",
      " [-0.09936161 -0.01422356  0.8971795   0.15424505]\n",
      " [ 0.1627264  -0.06541851  0.16510339  0.8813886 ]\n",
      " [ 0.          0.          0.          0.        ]\n",
      " [ 0.          0.          0.          0.        ]\n",
      " [ 0.          0.          0.          0.        ]\n",
      " [ 0.          0.          0.          0.        ]]\n",
      "Weights (kernel_P1):\n",
      " [[ 0.989428   -0.1069353   0.06000958  0.        ]\n",
      " [-0.01070091  0.89324445  0.05997863  0.        ]\n",
      " [ 0.05432884  0.1212594   0.9675572   0.        ]]\n",
      "Weights (kernel_P2):\n",
      " [[ 0.97891515 -0.10348026  0.06621196  0.        ]\n",
      " [-0.02352717  0.9038057   0.06577406  0.        ]\n",
      " [ 0.06646628  0.11868355  0.9576942   0.        ]\n",
      " [ 0.          0.          0.          1.        ]]\n",
      "Weights (kernel_P3):\n",
      " [[ 0.9640754  -0.09952678  0.07131607]\n",
      " [-0.0315524   0.9140939   0.06860425]\n",
      " [ 0.07692752  0.11654448  0.9484599 ]\n",
      " [ 0.          0.          0.        ]]\n",
      "Weights Final input P:\n",
      " [[ 0.6937392  -0.12412499 -0.2503755  -0.20673771]\n",
      " [-0.33682394  0.9424127  -0.29522252 -0.11545671]\n",
      " [-0.35777208  0.1971015   0.64244646 -0.10488582]\n",
      " [ 0.38720885 -0.3232431   0.41115704  0.90208733]]\n",
      "Weights Final output Q:\n",
      " [[ 0.9550165  -0.25895876  0.16634536]\n",
      " [-0.0433646   0.76168144  0.16352125]\n",
      " [ 0.17579125  0.29794118  0.91302884]]\n"
     ]
    }
   ],
   "source": [
    "print(\"Weights (kernel_I1):\\n\", weights_I1[0])\n",
    "print(\"Weights (kernel_I2):\\n\", weights_I2[0])\n",
    "print(\"Weights (kernel_I3):\\n\", weights_I3[0])\n",
    "\n",
    "print(\"Weights (kernel_P1):\\n\", weights_P1[0])\n",
    "print(\"Weights (kernel_P2):\\n\", weights_P2[0])\n",
    "print(\"Weights (kernel_P3):\\n\", weights_P3[0])\n",
    "\n",
    "# Collapsing a stack of layers into one matrix product is exact only when every layer is\n",
    "# bias-free and linear (then x@W1@W2@W3 == x@(W1@W2@W3)); otherwise P and Q would silently be wrong.\n",
    "for _layer in (layer_I1, layer_I2, layer_I3, layer_P1, layer_P2, layer_P3):\n",
    "    assert len(_layer.get_weights()) == 1, _layer.name + ' has extra weights (e.g. a bias) that P/Q would drop'\n",
    "    assert _layer.get_config().get('activation', 'linear') == 'linear', _layer.name + ' is not a linear layer'\n",
    "\n",
    "#calculate the transformation matrices\n",
    "P = weights_I1[0]@weights_I2[0]@weights_I3[0]  # product of the three input-side kernels\n",
    "Q = weights_P1[0]@weights_P2[0]@weights_P3[0]  # product of the three output-side kernels\n",
    "\n",
    "print(\"Weights Final input P:\\n\", P)\n",
    "print(\"Weights Final output Q:\\n\", Q)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63fb1c5d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "c51eb080",
   "metadata": {},
   "source": [
    "# Refine the model and set the weight parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "409049a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build final_model_PQ = Dense(P) -> frozen source model -> Dense(Q), i.e. the transfer layers of\n",
    "# final_model collapsed into one bias-free linear layer on each side of the source model.\n",
    "\n",
    "# Freeze the source model. NOTE(review): `model` is a shared object; if final_model wraps the\n",
    "# same instance, this freezes it there as well.\n",
    "for layer in model.layers:\n",
    "    layer.trainable = False\n",
    "\n",
    "# New front layer, sized from P so its (input_dim, units) kernel can take P exactly\n",
    "input_layer = Input(shape=(None, P.shape[0]))  # (time steps, target input features)\n",
    "new_input_layer = Dense(P.shape[1], activation='linear', use_bias=False)\n",
    "\n",
    "# Route the adapted input through the frozen source model\n",
    "x = model(new_input_layer(input_layer))\n",
    "\n",
    "# New output layer, sized from Q so its kernel can take Q exactly\n",
    "new_output_layer = Dense(Q.shape[1], activation='linear', use_bias=False)\n",
    "output = new_output_layer(x)\n",
    "\n",
    "# Assemble the final model\n",
    "final_model_PQ = Model(inputs=input_layer, outputs=output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "ea01d0d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the collapsed transfer matrices as the kernels of the two bias-free adapter layers.\n",
    "# A Dense kernel is laid out as (input_dim, units), the same orientation as P and Q.\n",
    "W1, W2 = (np.array(matrix, dtype=np.float32) for matrix in (P, Q))\n",
    "\n",
    "new_input_layer.set_weights([W1])   # front adapter  <- P\n",
    "new_output_layer.set_weights([W2])  # output adapter <- Q"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc27c3d2",
   "metadata": {},
   "source": [
    "# Fine-tune the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "be147352",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"model_1\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " input_2 (InputLayer)        [(None, None, 4)]         0         \n",
      "                                                                 \n",
      " dense_7 (Dense)             (None, None, 4)           16        \n",
      "                                                                 \n",
      " sequential (Sequential)     (None, 5, 3)              387       \n",
      "                                                                 \n",
      " dense_8 (Dense)             (None, 5, 3)              9         \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 412\n",
      "Trainable params: 412\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "None\n"
     ]
    }
   ],
   "source": [
    "# Step 1: unfreeze every layer of the model that is about to be compiled. This includes the\n",
    "# source model `model`, which was frozen when final_model_PQ was built, so it is fine-tuned too.\n",
    "for layer in final_model_PQ.layers:\n",
    "    layer.trainable = True\n",
    "\n",
    "# Step 2: compile with MSE as the loss (trainable flags are picked up at compile time)\n",
    "final_model_PQ.compile(optimizer='adam', loss='mse', metrics=['mse'])\n",
    "final_model_PQ.summary()  # summary() prints the table itself and returns None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "55fbdca2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/400\n",
      "2/2 - 0s - loss: 0.5799 - mse: 0.5799 - val_loss: 0.5718 - val_mse: 0.5718 - 493ms/epoch - 247ms/step\n",
      "Epoch 2/400\n",
      "2/2 - 0s - loss: 0.5427 - mse: 0.5427 - val_loss: 0.5363 - val_mse: 0.5363 - 16ms/epoch - 8ms/step\n",
      "Epoch 3/400\n",
      "2/2 - 0s - loss: 0.5159 - mse: 0.5159 - val_loss: 0.5068 - val_mse: 0.5068 - 16ms/epoch - 8ms/step\n",
      "Epoch 4/400\n",
      "2/2 - 0s - loss: 0.4933 - mse: 0.4933 - val_loss: 0.4821 - val_mse: 0.4821 - 18ms/epoch - 9ms/step\n",
      "Epoch 5/400\n",
      "2/2 - 0s - loss: 0.4752 - mse: 0.4752 - val_loss: 0.4592 - val_mse: 0.4592 - 23ms/epoch - 11ms/step\n",
      "Epoch 6/400\n",
      "2/2 - 0s - loss: 0.4581 - mse: 0.4581 - val_loss: 0.4389 - val_mse: 0.4389 - 18ms/epoch - 9ms/step\n",
      "Epoch 7/400\n",
      "2/2 - 0s - loss: 0.4431 - mse: 0.4431 - val_loss: 0.4212 - val_mse: 0.4212 - 18ms/epoch - 9ms/step\n",
      "Epoch 8/400\n",
      "2/2 - 0s - loss: 0.4304 - mse: 0.4304 - val_loss: 0.4063 - val_mse: 0.4063 - 16ms/epoch - 8ms/step\n",
      "Epoch 9/400\n",
      "2/2 - 0s - loss: 0.4191 - mse: 0.4191 - val_loss: 0.3943 - val_mse: 0.3943 - 18ms/epoch - 9ms/step\n",
      "Epoch 10/400\n",
      "2/2 - 0s - loss: 0.4098 - mse: 0.4098 - val_loss: 0.3847 - val_mse: 0.3847 - 16ms/epoch - 8ms/step\n",
      "Epoch 11/400\n",
      "2/2 - 0s - loss: 0.4024 - mse: 0.4024 - val_loss: 0.3767 - val_mse: 0.3767 - 22ms/epoch - 11ms/step\n",
      "Epoch 12/400\n",
      "2/2 - 0s - loss: 0.3959 - mse: 0.3959 - val_loss: 0.3702 - val_mse: 0.3702 - 18ms/epoch - 9ms/step\n",
      "Epoch 13/400\n",
      "2/2 - 0s - loss: 0.3902 - mse: 0.3902 - val_loss: 0.3646 - val_mse: 0.3646 - 17ms/epoch - 9ms/step\n",
      "Epoch 14/400\n",
      "2/2 - 0s - loss: 0.3853 - mse: 0.3853 - val_loss: 0.3595 - val_mse: 0.3595 - 19ms/epoch - 9ms/step\n",
      "Epoch 15/400\n",
      "2/2 - 0s - loss: 0.3807 - mse: 0.3807 - val_loss: 0.3549 - val_mse: 0.3549 - 16ms/epoch - 8ms/step\n",
      "Epoch 16/400\n",
      "2/2 - 0s - loss: 0.3762 - mse: 0.3762 - val_loss: 0.3507 - val_mse: 0.3507 - 18ms/epoch - 9ms/step\n",
      "Epoch 17/400\n",
      "2/2 - 0s - loss: 0.3720 - mse: 0.3720 - val_loss: 0.3469 - val_mse: 0.3469 - 20ms/epoch - 10ms/step\n",
      "Epoch 18/400\n",
      "2/2 - 0s - loss: 0.3684 - mse: 0.3684 - val_loss: 0.3434 - val_mse: 0.3434 - 20ms/epoch - 10ms/step\n",
      "Epoch 19/400\n",
      "2/2 - 0s - loss: 0.3649 - mse: 0.3649 - val_loss: 0.3401 - val_mse: 0.3401 - 17ms/epoch - 9ms/step\n",
      "Epoch 20/400\n",
      "2/2 - 0s - loss: 0.3615 - mse: 0.3615 - val_loss: 0.3368 - val_mse: 0.3368 - 18ms/epoch - 9ms/step\n",
      "Epoch 21/400\n",
      "2/2 - 0s - loss: 0.3583 - mse: 0.3583 - val_loss: 0.3335 - val_mse: 0.3335 - 16ms/epoch - 8ms/step\n",
      "Epoch 22/400\n",
      "2/2 - 0s - loss: 0.3551 - mse: 0.3551 - val_loss: 0.3302 - val_mse: 0.3302 - 18ms/epoch - 9ms/step\n",
      "Epoch 23/400\n",
      "2/2 - 0s - loss: 0.3521 - mse: 0.3521 - val_loss: 0.3270 - val_mse: 0.3270 - 16ms/epoch - 8ms/step\n",
      "Epoch 24/400\n",
      "2/2 - 0s - loss: 0.3490 - mse: 0.3490 - val_loss: 0.3239 - val_mse: 0.3239 - 35ms/epoch - 17ms/step\n",
      "Epoch 25/400\n",
      "2/2 - 0s - loss: 0.3460 - mse: 0.3460 - val_loss: 0.3207 - val_mse: 0.3207 - 18ms/epoch - 9ms/step\n",
      "Epoch 26/400\n",
      "2/2 - 0s - loss: 0.3430 - mse: 0.3430 - val_loss: 0.3176 - val_mse: 0.3176 - 17ms/epoch - 8ms/step\n",
      "Epoch 27/400\n",
      "2/2 - 0s - loss: 0.3401 - mse: 0.3401 - val_loss: 0.3146 - val_mse: 0.3146 - 18ms/epoch - 9ms/step\n",
      "Epoch 28/400\n",
      "2/2 - 0s - loss: 0.3372 - mse: 0.3372 - val_loss: 0.3116 - val_mse: 0.3116 - 18ms/epoch - 9ms/step\n",
      "Epoch 29/400\n",
      "2/2 - 0s - loss: 0.3343 - mse: 0.3343 - val_loss: 0.3086 - val_mse: 0.3086 - 17ms/epoch - 9ms/step\n",
      "Epoch 30/400\n",
      "2/2 - 0s - loss: 0.3315 - mse: 0.3315 - val_loss: 0.3056 - val_mse: 0.3056 - 18ms/epoch - 9ms/step\n",
      "Epoch 31/400\n",
      "2/2 - 0s - loss: 0.3287 - mse: 0.3287 - val_loss: 0.3027 - val_mse: 0.3027 - 17ms/epoch - 8ms/step\n",
      "Epoch 32/400\n",
      "2/2 - 0s - loss: 0.3259 - mse: 0.3259 - val_loss: 0.2998 - val_mse: 0.2998 - 20ms/epoch - 10ms/step\n",
      "Epoch 33/400\n",
      "2/2 - 0s - loss: 0.3231 - mse: 0.3231 - val_loss: 0.2968 - val_mse: 0.2968 - 18ms/epoch - 9ms/step\n",
      "Epoch 34/400\n",
      "2/2 - 0s - loss: 0.3203 - mse: 0.3203 - val_loss: 0.2939 - val_mse: 0.2939 - 19ms/epoch - 9ms/step\n",
      "Epoch 35/400\n",
      "2/2 - 0s - loss: 0.3175 - mse: 0.3175 - val_loss: 0.2910 - val_mse: 0.2910 - 20ms/epoch - 10ms/step\n",
      "Epoch 36/400\n",
      "2/2 - 0s - loss: 0.3147 - mse: 0.3147 - val_loss: 0.2882 - val_mse: 0.2882 - 22ms/epoch - 11ms/step\n",
      "Epoch 37/400\n",
      "2/2 - 0s - loss: 0.3118 - mse: 0.3118 - val_loss: 0.2853 - val_mse: 0.2853 - 18ms/epoch - 9ms/step\n",
      "Epoch 38/400\n",
      "2/2 - 0s - loss: 0.3089 - mse: 0.3089 - val_loss: 0.2824 - val_mse: 0.2824 - 16ms/epoch - 8ms/step\n",
      "Epoch 39/400\n",
      "2/2 - 0s - loss: 0.3060 - mse: 0.3060 - val_loss: 0.2795 - val_mse: 0.2795 - 16ms/epoch - 8ms/step\n",
      "Epoch 40/400\n",
      "2/2 - 0s - loss: 0.3029 - mse: 0.3029 - val_loss: 0.2765 - val_mse: 0.2765 - 15ms/epoch - 7ms/step\n",
      "Epoch 41/400\n",
      "2/2 - 0s - loss: 0.2999 - mse: 0.2999 - val_loss: 0.2735 - val_mse: 0.2735 - 18ms/epoch - 9ms/step\n",
      "Epoch 42/400\n",
      "2/2 - 0s - loss: 0.2968 - mse: 0.2968 - val_loss: 0.2704 - val_mse: 0.2704 - 16ms/epoch - 8ms/step\n",
      "Epoch 43/400\n",
      "2/2 - 0s - loss: 0.2935 - mse: 0.2935 - val_loss: 0.2672 - val_mse: 0.2672 - 18ms/epoch - 9ms/step\n",
      "Epoch 44/400\n",
      "2/2 - 0s - loss: 0.2902 - mse: 0.2902 - val_loss: 0.2639 - val_mse: 0.2639 - 16ms/epoch - 8ms/step\n",
      "Epoch 45/400\n",
      "2/2 - 0s - loss: 0.2869 - mse: 0.2869 - val_loss: 0.2605 - val_mse: 0.2605 - 23ms/epoch - 12ms/step\n",
      "Epoch 46/400\n",
      "2/2 - 0s - loss: 0.2834 - mse: 0.2834 - val_loss: 0.2571 - val_mse: 0.2571 - 18ms/epoch - 9ms/step\n",
      "Epoch 47/400\n",
      "2/2 - 0s - loss: 0.2798 - mse: 0.2798 - val_loss: 0.2536 - val_mse: 0.2536 - 18ms/epoch - 9ms/step\n",
      "Epoch 48/400\n",
      "2/2 - 0s - loss: 0.2760 - mse: 0.2760 - val_loss: 0.2500 - val_mse: 0.2500 - 19ms/epoch - 10ms/step\n",
      "Epoch 49/400\n",
      "2/2 - 0s - loss: 0.2723 - mse: 0.2723 - val_loss: 0.2463 - val_mse: 0.2463 - 24ms/epoch - 12ms/step\n",
      "Epoch 50/400\n",
      "2/2 - 0s - loss: 0.2683 - mse: 0.2683 - val_loss: 0.2425 - val_mse: 0.2425 - 19ms/epoch - 10ms/step\n",
      "Epoch 51/400\n",
      "2/2 - 0s - loss: 0.2641 - mse: 0.2641 - val_loss: 0.2386 - val_mse: 0.2386 - 20ms/epoch - 10ms/step\n",
      "Epoch 52/400\n",
      "2/2 - 0s - loss: 0.2600 - mse: 0.2600 - val_loss: 0.2345 - val_mse: 0.2345 - 22ms/epoch - 11ms/step\n",
      "Epoch 53/400\n",
      "2/2 - 0s - loss: 0.2556 - mse: 0.2556 - val_loss: 0.2303 - val_mse: 0.2303 - 18ms/epoch - 9ms/step\n",
      "Epoch 54/400\n",
      "2/2 - 0s - loss: 0.2510 - mse: 0.2510 - val_loss: 0.2259 - val_mse: 0.2259 - 17ms/epoch - 9ms/step\n",
      "Epoch 55/400\n",
      "2/2 - 0s - loss: 0.2464 - mse: 0.2464 - val_loss: 0.2214 - val_mse: 0.2214 - 19ms/epoch - 10ms/step\n",
      "Epoch 56/400\n",
      "2/2 - 0s - loss: 0.2413 - mse: 0.2413 - val_loss: 0.2168 - val_mse: 0.2168 - 18ms/epoch - 9ms/step\n",
      "Epoch 57/400\n",
      "2/2 - 0s - loss: 0.2364 - mse: 0.2364 - val_loss: 0.2118 - val_mse: 0.2118 - 22ms/epoch - 11ms/step\n",
      "Epoch 58/400\n",
      "2/2 - 0s - loss: 0.2308 - mse: 0.2308 - val_loss: 0.2068 - val_mse: 0.2068 - 18ms/epoch - 9ms/step\n",
      "Epoch 59/400\n",
      "2/2 - 0s - loss: 0.2255 - mse: 0.2255 - val_loss: 0.2015 - val_mse: 0.2015 - 17ms/epoch - 8ms/step\n",
      "Epoch 60/400\n",
      "2/2 - 0s - loss: 0.2195 - mse: 0.2195 - val_loss: 0.1961 - val_mse: 0.1961 - 19ms/epoch - 9ms/step\n",
      "Epoch 61/400\n",
      "2/2 - 0s - loss: 0.2136 - mse: 0.2136 - val_loss: 0.1905 - val_mse: 0.1905 - 16ms/epoch - 8ms/step\n",
      "Epoch 62/400\n",
      "2/2 - 0s - loss: 0.2073 - mse: 0.2073 - val_loss: 0.1848 - val_mse: 0.1848 - 18ms/epoch - 9ms/step\n",
      "Epoch 63/400\n",
      "2/2 - 0s - loss: 0.2009 - mse: 0.2009 - val_loss: 0.1788 - val_mse: 0.1788 - 15ms/epoch - 8ms/step\n",
      "Epoch 64/400\n",
      "2/2 - 0s - loss: 0.1944 - mse: 0.1944 - val_loss: 0.1728 - val_mse: 0.1728 - 22ms/epoch - 11ms/step\n",
      "Epoch 65/400\n",
      "2/2 - 0s - loss: 0.1874 - mse: 0.1874 - val_loss: 0.1667 - val_mse: 0.1667 - 18ms/epoch - 9ms/step\n",
      "Epoch 66/400\n",
      "2/2 - 0s - loss: 0.1805 - mse: 0.1805 - val_loss: 0.1604 - val_mse: 0.1604 - 17ms/epoch - 8ms/step\n",
      "Epoch 67/400\n",
      "2/2 - 0s - loss: 0.1732 - mse: 0.1732 - val_loss: 0.1540 - val_mse: 0.1540 - 19ms/epoch - 9ms/step\n",
      "Epoch 68/400\n",
      "2/2 - 0s - loss: 0.1663 - mse: 0.1663 - val_loss: 0.1477 - val_mse: 0.1477 - 16ms/epoch - 8ms/step\n",
      "Epoch 69/400\n",
      "2/2 - 0s - loss: 0.1589 - mse: 0.1589 - val_loss: 0.1414 - val_mse: 0.1414 - 18ms/epoch - 9ms/step\n",
      "Epoch 70/400\n",
      "2/2 - 0s - loss: 0.1515 - mse: 0.1515 - val_loss: 0.1351 - val_mse: 0.1351 - 15ms/epoch - 8ms/step\n",
      "Epoch 71/400\n",
      "2/2 - 0s - loss: 0.1441 - mse: 0.1441 - val_loss: 0.1288 - val_mse: 0.1288 - 23ms/epoch - 11ms/step\n",
      "Epoch 72/400\n",
      "2/2 - 0s - loss: 0.1368 - mse: 0.1368 - val_loss: 0.1226 - val_mse: 0.1226 - 18ms/epoch - 9ms/step\n",
      "Epoch 73/400\n",
      "2/2 - 0s - loss: 0.1299 - mse: 0.1299 - val_loss: 0.1165 - val_mse: 0.1165 - 18ms/epoch - 9ms/step\n",
      "Epoch 74/400\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2/2 - 0s - loss: 0.1226 - mse: 0.1226 - val_loss: 0.1105 - val_mse: 0.1105 - 19ms/epoch - 9ms/step\n",
      "Epoch 75/400\n",
      "2/2 - 0s - loss: 0.1158 - mse: 0.1158 - val_loss: 0.1048 - val_mse: 0.1048 - 23ms/epoch - 12ms/step\n",
      "Epoch 76/400\n",
      "2/2 - 0s - loss: 0.1093 - mse: 0.1093 - val_loss: 0.0994 - val_mse: 0.0994 - 18ms/epoch - 9ms/step\n",
      "Epoch 77/400\n",
      "2/2 - 0s - loss: 0.1031 - mse: 0.1031 - val_loss: 0.0942 - val_mse: 0.0942 - 17ms/epoch - 8ms/step\n",
      "Epoch 78/400\n",
      "2/2 - 0s - loss: 0.0972 - mse: 0.0972 - val_loss: 0.0893 - val_mse: 0.0893 - 20ms/epoch - 10ms/step\n",
      "Epoch 79/400\n",
      "2/2 - 0s - loss: 0.0916 - mse: 0.0916 - val_loss: 0.0848 - val_mse: 0.0848 - 17ms/epoch - 8ms/step\n",
      "Epoch 80/400\n",
      "2/2 - 0s - loss: 0.0864 - mse: 0.0864 - val_loss: 0.0806 - val_mse: 0.0806 - 18ms/epoch - 9ms/step\n",
      "Epoch 81/400\n",
      "2/2 - 0s - loss: 0.0817 - mse: 0.0817 - val_loss: 0.0766 - val_mse: 0.0766 - 16ms/epoch - 8ms/step\n",
      "Epoch 82/400\n",
      "2/2 - 0s - loss: 0.0773 - mse: 0.0773 - val_loss: 0.0729 - val_mse: 0.0729 - 20ms/epoch - 10ms/step\n",
      "Epoch 83/400\n",
      "2/2 - 0s - loss: 0.0733 - mse: 0.0733 - val_loss: 0.0696 - val_mse: 0.0696 - 17ms/epoch - 9ms/step\n",
      "Epoch 84/400\n",
      "2/2 - 0s - loss: 0.0694 - mse: 0.0694 - val_loss: 0.0665 - val_mse: 0.0665 - 17ms/epoch - 8ms/step\n",
      "Epoch 85/400\n",
      "2/2 - 0s - loss: 0.0662 - mse: 0.0662 - val_loss: 0.0635 - val_mse: 0.0635 - 18ms/epoch - 9ms/step\n",
      "Epoch 86/400\n",
      "2/2 - 0s - loss: 0.0631 - mse: 0.0631 - val_loss: 0.0608 - val_mse: 0.0608 - 17ms/epoch - 8ms/step\n",
      "Epoch 87/400\n",
      "2/2 - 0s - loss: 0.0602 - mse: 0.0602 - val_loss: 0.0582 - val_mse: 0.0582 - 18ms/epoch - 9ms/step\n",
      "Epoch 88/400\n",
      "2/2 - 0s - loss: 0.0576 - mse: 0.0576 - val_loss: 0.0558 - val_mse: 0.0558 - 16ms/epoch - 8ms/step\n",
      "Epoch 89/400\n",
      "2/2 - 0s - loss: 0.0553 - mse: 0.0553 - val_loss: 0.0536 - val_mse: 0.0536 - 22ms/epoch - 11ms/step\n",
      "Epoch 90/400\n",
      "2/2 - 0s - loss: 0.0529 - mse: 0.0529 - val_loss: 0.0514 - val_mse: 0.0514 - 18ms/epoch - 9ms/step\n",
      "Epoch 91/400\n",
      "2/2 - 0s - loss: 0.0508 - mse: 0.0508 - val_loss: 0.0496 - val_mse: 0.0496 - 17ms/epoch - 9ms/step\n",
      "Epoch 92/400\n",
      "2/2 - 0s - loss: 0.0489 - mse: 0.0489 - val_loss: 0.0477 - val_mse: 0.0477 - 16ms/epoch - 8ms/step\n",
      "Epoch 93/400\n",
      "2/2 - 0s - loss: 0.0471 - mse: 0.0471 - val_loss: 0.0460 - val_mse: 0.0460 - 17ms/epoch - 9ms/step\n",
      "Epoch 94/400\n",
      "2/2 - 0s - loss: 0.0454 - mse: 0.0454 - val_loss: 0.0444 - val_mse: 0.0444 - 17ms/epoch - 8ms/step\n",
      "Epoch 95/400\n",
      "2/2 - 0s - loss: 0.0439 - mse: 0.0439 - val_loss: 0.0428 - val_mse: 0.0428 - 18ms/epoch - 9ms/step\n",
      "Epoch 96/400\n",
      "2/2 - 0s - loss: 0.0424 - mse: 0.0424 - val_loss: 0.0415 - val_mse: 0.0415 - 16ms/epoch - 8ms/step\n",
      "Epoch 97/400\n",
      "2/2 - 0s - loss: 0.0410 - mse: 0.0410 - val_loss: 0.0402 - val_mse: 0.0402 - 21ms/epoch - 11ms/step\n",
      "Epoch 98/400\n",
      "2/2 - 0s - loss: 0.0397 - mse: 0.0397 - val_loss: 0.0390 - val_mse: 0.0390 - 19ms/epoch - 10ms/step\n",
      "Epoch 99/400\n",
      "2/2 - 0s - loss: 0.0385 - mse: 0.0385 - val_loss: 0.0378 - val_mse: 0.0378 - 17ms/epoch - 9ms/step\n",
      "Epoch 100/400\n",
      "2/2 - 0s - loss: 0.0373 - mse: 0.0373 - val_loss: 0.0366 - val_mse: 0.0366 - 17ms/epoch - 9ms/step\n",
      "Epoch 101/400\n",
      "2/2 - 0s - loss: 0.0362 - mse: 0.0362 - val_loss: 0.0355 - val_mse: 0.0355 - 19ms/epoch - 10ms/step\n",
      "Epoch 102/400\n",
      "2/2 - 0s - loss: 0.0351 - mse: 0.0351 - val_loss: 0.0345 - val_mse: 0.0345 - 22ms/epoch - 11ms/step\n",
      "Epoch 103/400\n",
      "2/2 - 0s - loss: 0.0341 - mse: 0.0341 - val_loss: 0.0337 - val_mse: 0.0337 - 19ms/epoch - 9ms/step\n",
      "Epoch 104/400\n",
      "2/2 - 0s - loss: 0.0332 - mse: 0.0332 - val_loss: 0.0328 - val_mse: 0.0328 - 17ms/epoch - 8ms/step\n",
      "Epoch 105/400\n",
      "2/2 - 0s - loss: 0.0323 - mse: 0.0323 - val_loss: 0.0319 - val_mse: 0.0319 - 17ms/epoch - 8ms/step\n",
      "Epoch 106/400\n",
      "2/2 - 0s - loss: 0.0315 - mse: 0.0315 - val_loss: 0.0311 - val_mse: 0.0311 - 19ms/epoch - 9ms/step\n",
      "Epoch 107/400\n",
      "2/2 - 0s - loss: 0.0307 - mse: 0.0307 - val_loss: 0.0303 - val_mse: 0.0303 - 18ms/epoch - 9ms/step\n",
      "Epoch 108/400\n",
      "2/2 - 0s - loss: 0.0298 - mse: 0.0298 - val_loss: 0.0295 - val_mse: 0.0295 - 22ms/epoch - 11ms/step\n",
      "Epoch 109/400\n",
      "2/2 - 0s - loss: 0.0291 - mse: 0.0291 - val_loss: 0.0288 - val_mse: 0.0288 - 18ms/epoch - 9ms/step\n",
      "Epoch 110/400\n",
      "2/2 - 0s - loss: 0.0284 - mse: 0.0284 - val_loss: 0.0281 - val_mse: 0.0281 - 16ms/epoch - 8ms/step\n",
      "Epoch 111/400\n",
      "2/2 - 0s - loss: 0.0278 - mse: 0.0278 - val_loss: 0.0275 - val_mse: 0.0275 - 17ms/epoch - 8ms/step\n",
      "Epoch 112/400\n",
      "2/2 - 0s - loss: 0.0271 - mse: 0.0271 - val_loss: 0.0269 - val_mse: 0.0269 - 18ms/epoch - 9ms/step\n",
      "Epoch 113/400\n",
      "2/2 - 0s - loss: 0.0265 - mse: 0.0265 - val_loss: 0.0264 - val_mse: 0.0264 - 17ms/epoch - 8ms/step\n",
      "Epoch 114/400\n",
      "2/2 - 0s - loss: 0.0259 - mse: 0.0259 - val_loss: 0.0258 - val_mse: 0.0258 - 18ms/epoch - 9ms/step\n",
      "Epoch 115/400\n",
      "2/2 - 0s - loss: 0.0253 - mse: 0.0253 - val_loss: 0.0253 - val_mse: 0.0253 - 16ms/epoch - 8ms/step\n",
      "Epoch 116/400\n",
      "2/2 - 0s - loss: 0.0248 - mse: 0.0248 - val_loss: 0.0247 - val_mse: 0.0247 - 22ms/epoch - 11ms/step\n",
      "Epoch 117/400\n",
      "2/2 - 0s - loss: 0.0242 - mse: 0.0242 - val_loss: 0.0241 - val_mse: 0.0241 - 18ms/epoch - 9ms/step\n",
      "Epoch 118/400\n",
      "2/2 - 0s - loss: 0.0237 - mse: 0.0237 - val_loss: 0.0236 - val_mse: 0.0236 - 16ms/epoch - 8ms/step\n",
      "Epoch 119/400\n",
      "2/2 - 0s - loss: 0.0232 - mse: 0.0232 - val_loss: 0.0232 - val_mse: 0.0232 - 17ms/epoch - 9ms/step\n",
      "Epoch 120/400\n",
      "2/2 - 0s - loss: 0.0227 - mse: 0.0227 - val_loss: 0.0228 - val_mse: 0.0228 - 16ms/epoch - 8ms/step\n",
      "Epoch 121/400\n",
      "2/2 - 0s - loss: 0.0223 - mse: 0.0223 - val_loss: 0.0223 - val_mse: 0.0223 - 16ms/epoch - 8ms/step\n",
      "Epoch 122/400\n",
      "2/2 - 0s - loss: 0.0218 - mse: 0.0218 - val_loss: 0.0219 - val_mse: 0.0219 - 15ms/epoch - 8ms/step\n",
      "Epoch 123/400\n",
      "2/2 - 0s - loss: 0.0214 - mse: 0.0214 - val_loss: 0.0214 - val_mse: 0.0214 - 18ms/epoch - 9ms/step\n",
      "Epoch 124/400\n",
      "2/2 - 0s - loss: 0.0210 - mse: 0.0210 - val_loss: 0.0210 - val_mse: 0.0210 - 16ms/epoch - 8ms/step\n",
      "Epoch 125/400\n",
      "2/2 - 0s - loss: 0.0206 - mse: 0.0206 - val_loss: 0.0206 - val_mse: 0.0206 - 18ms/epoch - 9ms/step\n",
      "Epoch 126/400\n",
      "2/2 - 0s - loss: 0.0202 - mse: 0.0202 - val_loss: 0.0202 - val_mse: 0.0202 - 15ms/epoch - 8ms/step\n",
      "Epoch 127/400\n",
      "2/2 - 0s - loss: 0.0198 - mse: 0.0198 - val_loss: 0.0199 - val_mse: 0.0199 - 23ms/epoch - 11ms/step\n",
      "Epoch 128/400\n",
      "2/2 - 0s - loss: 0.0194 - mse: 0.0194 - val_loss: 0.0196 - val_mse: 0.0196 - 18ms/epoch - 9ms/step\n",
      "Epoch 129/400\n",
      "2/2 - 0s - loss: 0.0191 - mse: 0.0191 - val_loss: 0.0192 - val_mse: 0.0192 - 29ms/epoch - 14ms/step\n",
      "Epoch 130/400\n",
      "2/2 - 0s - loss: 0.0187 - mse: 0.0187 - val_loss: 0.0189 - val_mse: 0.0189 - 18ms/epoch - 9ms/step\n",
      "Epoch 131/400\n",
      "2/2 - 0s - loss: 0.0184 - mse: 0.0184 - val_loss: 0.0185 - val_mse: 0.0185 - 15ms/epoch - 8ms/step\n",
      "Epoch 132/400\n",
      "2/2 - 0s - loss: 0.0180 - mse: 0.0180 - val_loss: 0.0181 - val_mse: 0.0181 - 17ms/epoch - 9ms/step\n",
      "Epoch 133/400\n",
      "2/2 - 0s - loss: 0.0177 - mse: 0.0177 - val_loss: 0.0178 - val_mse: 0.0178 - 17ms/epoch - 8ms/step\n",
      "Epoch 134/400\n",
      "2/2 - 0s - loss: 0.0174 - mse: 0.0174 - val_loss: 0.0175 - val_mse: 0.0175 - 19ms/epoch - 9ms/step\n",
      "Epoch 135/400\n",
      "2/2 - 0s - loss: 0.0171 - mse: 0.0171 - val_loss: 0.0173 - val_mse: 0.0173 - 16ms/epoch - 8ms/step\n",
      "Epoch 136/400\n",
      "2/2 - 0s - loss: 0.0168 - mse: 0.0168 - val_loss: 0.0170 - val_mse: 0.0170 - 18ms/epoch - 9ms/step\n",
      "Epoch 137/400\n",
      "2/2 - 0s - loss: 0.0165 - mse: 0.0165 - val_loss: 0.0167 - val_mse: 0.0167 - 16ms/epoch - 8ms/step\n",
      "Epoch 138/400\n",
      "2/2 - 0s - loss: 0.0162 - mse: 0.0162 - val_loss: 0.0164 - val_mse: 0.0164 - 22ms/epoch - 11ms/step\n",
      "Epoch 139/400\n",
      "2/2 - 0s - loss: 0.0159 - mse: 0.0159 - val_loss: 0.0161 - val_mse: 0.0161 - 18ms/epoch - 9ms/step\n",
      "Epoch 140/400\n",
      "2/2 - 0s - loss: 0.0157 - mse: 0.0157 - val_loss: 0.0159 - val_mse: 0.0159 - 17ms/epoch - 9ms/step\n",
      "Epoch 141/400\n",
      "2/2 - 0s - loss: 0.0154 - mse: 0.0154 - val_loss: 0.0156 - val_mse: 0.0156 - 17ms/epoch - 9ms/step\n",
      "Epoch 142/400\n",
      "2/2 - 0s - loss: 0.0151 - mse: 0.0151 - val_loss: 0.0154 - val_mse: 0.0154 - 16ms/epoch - 8ms/step\n",
      "Epoch 143/400\n",
      "2/2 - 0s - loss: 0.0149 - mse: 0.0149 - val_loss: 0.0151 - val_mse: 0.0151 - 15ms/epoch - 8ms/step\n",
      "Epoch 144/400\n",
      "2/2 - 0s - loss: 0.0147 - mse: 0.0147 - val_loss: 0.0149 - val_mse: 0.0149 - 15ms/epoch - 8ms/step\n",
      "Epoch 145/400\n",
      "2/2 - 0s - loss: 0.0144 - mse: 0.0144 - val_loss: 0.0146 - val_mse: 0.0146 - 18ms/epoch - 9ms/step\n",
      "Epoch 146/400\n",
      "2/2 - 0s - loss: 0.0142 - mse: 0.0142 - val_loss: 0.0144 - val_mse: 0.0144 - 16ms/epoch - 8ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 147/400\n",
      "2/2 - 0s - loss: 0.0140 - mse: 0.0140 - val_loss: 0.0142 - val_mse: 0.0142 - 18ms/epoch - 9ms/step\n",
      "Epoch 148/400\n",
      "2/2 - 0s - loss: 0.0138 - mse: 0.0138 - val_loss: 0.0140 - val_mse: 0.0140 - 16ms/epoch - 8ms/step\n",
      "Epoch 149/400\n",
      "2/2 - 0s - loss: 0.0135 - mse: 0.0135 - val_loss: 0.0137 - val_mse: 0.0137 - 21ms/epoch - 11ms/step\n",
      "Epoch 150/400\n",
      "2/2 - 0s - loss: 0.0134 - mse: 0.0134 - val_loss: 0.0135 - val_mse: 0.0135 - 19ms/epoch - 9ms/step\n",
      "Epoch 151/400\n",
      "2/2 - 0s - loss: 0.0131 - mse: 0.0131 - val_loss: 0.0133 - val_mse: 0.0133 - 16ms/epoch - 8ms/step\n",
      "Epoch 152/400\n",
      "2/2 - 0s - loss: 0.0129 - mse: 0.0129 - val_loss: 0.0131 - val_mse: 0.0131 - 17ms/epoch - 8ms/step\n",
      "Epoch 153/400\n",
      "2/2 - 0s - loss: 0.0127 - mse: 0.0127 - val_loss: 0.0129 - val_mse: 0.0129 - 17ms/epoch - 9ms/step\n",
      "Epoch 154/400\n",
      "2/2 - 0s - loss: 0.0125 - mse: 0.0125 - val_loss: 0.0127 - val_mse: 0.0127 - 18ms/epoch - 9ms/step\n",
      "Epoch 155/400\n",
      "2/2 - 0s - loss: 0.0123 - mse: 0.0123 - val_loss: 0.0125 - val_mse: 0.0125 - 20ms/epoch - 10ms/step\n",
      "Epoch 156/400\n",
      "2/2 - 0s - loss: 0.0122 - mse: 0.0122 - val_loss: 0.0124 - val_mse: 0.0124 - 19ms/epoch - 10ms/step\n",
      "Epoch 157/400\n",
      "2/2 - 0s - loss: 0.0120 - mse: 0.0120 - val_loss: 0.0122 - val_mse: 0.0122 - 17ms/epoch - 9ms/step\n",
      "Epoch 158/400\n",
      "2/2 - 0s - loss: 0.0118 - mse: 0.0118 - val_loss: 0.0120 - val_mse: 0.0120 - 16ms/epoch - 8ms/step\n",
      "Epoch 159/400\n",
      "2/2 - 0s - loss: 0.0116 - mse: 0.0116 - val_loss: 0.0119 - val_mse: 0.0119 - 17ms/epoch - 9ms/step\n",
      "Epoch 160/400\n",
      "2/2 - 0s - loss: 0.0115 - mse: 0.0115 - val_loss: 0.0117 - val_mse: 0.0117 - 19ms/epoch - 9ms/step\n",
      "Epoch 161/400\n",
      "2/2 - 0s - loss: 0.0113 - mse: 0.0113 - val_loss: 0.0115 - val_mse: 0.0115 - 21ms/epoch - 10ms/step\n",
      "Epoch 162/400\n",
      "2/2 - 0s - loss: 0.0112 - mse: 0.0112 - val_loss: 0.0114 - val_mse: 0.0114 - 20ms/epoch - 10ms/step\n",
      "Epoch 163/400\n",
      "2/2 - 0s - loss: 0.0110 - mse: 0.0110 - val_loss: 0.0112 - val_mse: 0.0112 - 17ms/epoch - 8ms/step\n",
      "Epoch 164/400\n",
      "2/2 - 0s - loss: 0.0108 - mse: 0.0108 - val_loss: 0.0111 - val_mse: 0.0111 - 17ms/epoch - 8ms/step\n",
      "Epoch 165/400\n",
      "2/2 - 0s - loss: 0.0107 - mse: 0.0107 - val_loss: 0.0109 - val_mse: 0.0109 - 18ms/epoch - 9ms/step\n",
      "Epoch 166/400\n",
      "2/2 - 0s - loss: 0.0105 - mse: 0.0105 - val_loss: 0.0107 - val_mse: 0.0107 - 18ms/epoch - 9ms/step\n",
      "Epoch 167/400\n",
      "2/2 - 0s - loss: 0.0104 - mse: 0.0104 - val_loss: 0.0106 - val_mse: 0.0106 - 23ms/epoch - 12ms/step\n",
      "Epoch 168/400\n",
      "2/2 - 0s - loss: 0.0103 - mse: 0.0103 - val_loss: 0.0104 - val_mse: 0.0104 - 18ms/epoch - 9ms/step\n",
      "Epoch 169/400\n",
      "2/2 - 0s - loss: 0.0101 - mse: 0.0101 - val_loss: 0.0103 - val_mse: 0.0103 - 17ms/epoch - 8ms/step\n",
      "Epoch 170/400\n",
      "2/2 - 0s - loss: 0.0100 - mse: 0.0100 - val_loss: 0.0101 - val_mse: 0.0101 - 17ms/epoch - 9ms/step\n",
      "Epoch 171/400\n",
      "2/2 - 0s - loss: 0.0099 - mse: 0.0099 - val_loss: 0.0100 - val_mse: 0.0100 - 19ms/epoch - 9ms/step\n",
      "Epoch 172/400\n",
      "2/2 - 0s - loss: 0.0097 - mse: 0.0097 - val_loss: 0.0099 - val_mse: 0.0099 - 22ms/epoch - 11ms/step\n",
      "Epoch 173/400\n",
      "2/2 - 0s - loss: 0.0096 - mse: 0.0096 - val_loss: 0.0098 - val_mse: 0.0098 - 18ms/epoch - 9ms/step\n",
      "Epoch 174/400\n",
      "2/2 - 0s - loss: 0.0095 - mse: 0.0095 - val_loss: 0.0096 - val_mse: 0.0096 - 17ms/epoch - 9ms/step\n",
      "Epoch 175/400\n",
      "2/2 - 0s - loss: 0.0093 - mse: 0.0093 - val_loss: 0.0095 - val_mse: 0.0095 - 17ms/epoch - 8ms/step\n",
      "Epoch 176/400\n",
      "2/2 - 0s - loss: 0.0092 - mse: 0.0092 - val_loss: 0.0094 - val_mse: 0.0094 - 18ms/epoch - 9ms/step\n",
      "Epoch 177/400\n",
      "2/2 - 0s - loss: 0.0091 - mse: 0.0091 - val_loss: 0.0093 - val_mse: 0.0093 - 19ms/epoch - 9ms/step\n",
      "Epoch 178/400\n",
      "2/2 - 0s - loss: 0.0090 - mse: 0.0090 - val_loss: 0.0092 - val_mse: 0.0092 - 21ms/epoch - 10ms/step\n",
      "Epoch 179/400\n",
      "2/2 - 0s - loss: 0.0089 - mse: 0.0089 - val_loss: 0.0090 - val_mse: 0.0090 - 18ms/epoch - 9ms/step\n",
      "Epoch 180/400\n",
      "2/2 - 0s - loss: 0.0088 - mse: 0.0088 - val_loss: 0.0089 - val_mse: 0.0089 - 17ms/epoch - 8ms/step\n",
      "Epoch 181/400\n",
      "2/2 - 0s - loss: 0.0087 - mse: 0.0087 - val_loss: 0.0088 - val_mse: 0.0088 - 16ms/epoch - 8ms/step\n",
      "Epoch 182/400\n",
      "2/2 - 0s - loss: 0.0086 - mse: 0.0086 - val_loss: 0.0087 - val_mse: 0.0087 - 18ms/epoch - 9ms/step\n",
      "Epoch 183/400\n",
      "2/2 - 0s - loss: 0.0085 - mse: 0.0085 - val_loss: 0.0086 - val_mse: 0.0086 - 19ms/epoch - 9ms/step\n",
      "Epoch 184/400\n",
      "2/2 - 0s - loss: 0.0084 - mse: 0.0084 - val_loss: 0.0085 - val_mse: 0.0085 - 22ms/epoch - 11ms/step\n",
      "Epoch 185/400\n",
      "2/2 - 0s - loss: 0.0082 - mse: 0.0082 - val_loss: 0.0084 - val_mse: 0.0084 - 17ms/epoch - 9ms/step\n",
      "Epoch 186/400\n",
      "2/2 - 0s - loss: 0.0082 - mse: 0.0082 - val_loss: 0.0083 - val_mse: 0.0083 - 17ms/epoch - 9ms/step\n",
      "Epoch 187/400\n",
      "2/2 - 0s - loss: 0.0081 - mse: 0.0081 - val_loss: 0.0082 - val_mse: 0.0082 - 16ms/epoch - 8ms/step\n",
      "Epoch 188/400\n",
      "2/2 - 0s - loss: 0.0080 - mse: 0.0080 - val_loss: 0.0081 - val_mse: 0.0081 - 18ms/epoch - 9ms/step\n",
      "Epoch 189/400\n",
      "2/2 - 0s - loss: 0.0079 - mse: 0.0079 - val_loss: 0.0080 - val_mse: 0.0080 - 18ms/epoch - 9ms/step\n",
      "Epoch 190/400\n",
      "2/2 - 0s - loss: 0.0078 - mse: 0.0078 - val_loss: 0.0079 - val_mse: 0.0079 - 22ms/epoch - 11ms/step\n",
      "Epoch 191/400\n",
      "2/2 - 0s - loss: 0.0077 - mse: 0.0077 - val_loss: 0.0078 - val_mse: 0.0078 - 21ms/epoch - 10ms/step\n",
      "Epoch 192/400\n",
      "2/2 - 0s - loss: 0.0076 - mse: 0.0076 - val_loss: 0.0077 - val_mse: 0.0077 - 17ms/epoch - 8ms/step\n",
      "Epoch 193/400\n",
      "2/2 - 0s - loss: 0.0075 - mse: 0.0075 - val_loss: 0.0077 - val_mse: 0.0077 - 18ms/epoch - 9ms/step\n",
      "Epoch 194/400\n",
      "2/2 - 0s - loss: 0.0074 - mse: 0.0074 - val_loss: 0.0076 - val_mse: 0.0076 - 18ms/epoch - 9ms/step\n",
      "Epoch 195/400\n",
      "2/2 - 0s - loss: 0.0074 - mse: 0.0074 - val_loss: 0.0075 - val_mse: 0.0075 - 20ms/epoch - 10ms/step\n",
      "Epoch 196/400\n",
      "2/2 - 0s - loss: 0.0073 - mse: 0.0073 - val_loss: 0.0074 - val_mse: 0.0074 - 20ms/epoch - 10ms/step\n",
      "Epoch 197/400\n",
      "2/2 - 0s - loss: 0.0072 - mse: 0.0072 - val_loss: 0.0073 - val_mse: 0.0073 - 17ms/epoch - 8ms/step\n",
      "Epoch 198/400\n",
      "2/2 - 0s - loss: 0.0071 - mse: 0.0071 - val_loss: 0.0072 - val_mse: 0.0072 - 17ms/epoch - 8ms/step\n",
      "Epoch 199/400\n",
      "2/2 - 0s - loss: 0.0071 - mse: 0.0071 - val_loss: 0.0072 - val_mse: 0.0072 - 17ms/epoch - 9ms/step\n",
      "Epoch 200/400\n",
      "2/2 - 0s - loss: 0.0070 - mse: 0.0070 - val_loss: 0.0071 - val_mse: 0.0071 - 18ms/epoch - 9ms/step\n",
      "Epoch 201/400\n",
      "2/2 - 0s - loss: 0.0069 - mse: 0.0069 - val_loss: 0.0070 - val_mse: 0.0070 - 22ms/epoch - 11ms/step\n",
      "Epoch 202/400\n",
      "2/2 - 0s - loss: 0.0068 - mse: 0.0068 - val_loss: 0.0069 - val_mse: 0.0069 - 19ms/epoch - 10ms/step\n",
      "Epoch 203/400\n",
      "2/2 - 0s - loss: 0.0068 - mse: 0.0068 - val_loss: 0.0069 - val_mse: 0.0069 - 17ms/epoch - 9ms/step\n",
      "Epoch 204/400\n",
      "2/2 - 0s - loss: 0.0067 - mse: 0.0067 - val_loss: 0.0068 - val_mse: 0.0068 - 17ms/epoch - 9ms/step\n",
      "Epoch 205/400\n",
      "2/2 - 0s - loss: 0.0066 - mse: 0.0066 - val_loss: 0.0067 - val_mse: 0.0067 - 18ms/epoch - 9ms/step\n",
      "Epoch 206/400\n",
      "2/2 - 0s - loss: 0.0066 - mse: 0.0066 - val_loss: 0.0067 - val_mse: 0.0067 - 18ms/epoch - 9ms/step\n",
      "Epoch 207/400\n",
      "2/2 - 0s - loss: 0.0065 - mse: 0.0065 - val_loss: 0.0066 - val_mse: 0.0066 - 23ms/epoch - 11ms/step\n",
      "Epoch 208/400\n",
      "2/2 - 0s - loss: 0.0064 - mse: 0.0064 - val_loss: 0.0065 - val_mse: 0.0065 - 17ms/epoch - 9ms/step\n",
      "Epoch 209/400\n",
      "2/2 - 0s - loss: 0.0064 - mse: 0.0064 - val_loss: 0.0065 - val_mse: 0.0065 - 17ms/epoch - 9ms/step\n",
      "Epoch 210/400\n",
      "2/2 - 0s - loss: 0.0063 - mse: 0.0063 - val_loss: 0.0064 - val_mse: 0.0064 - 20ms/epoch - 10ms/step\n",
      "Epoch 211/400\n",
      "2/2 - 0s - loss: 0.0063 - mse: 0.0063 - val_loss: 0.0064 - val_mse: 0.0064 - 21ms/epoch - 11ms/step\n",
      "Epoch 212/400\n",
      "2/2 - 0s - loss: 0.0062 - mse: 0.0062 - val_loss: 0.0063 - val_mse: 0.0063 - 20ms/epoch - 10ms/step\n",
      "Epoch 213/400\n",
      "2/2 - 0s - loss: 0.0061 - mse: 0.0061 - val_loss: 0.0062 - val_mse: 0.0062 - 17ms/epoch - 9ms/step\n",
      "Epoch 214/400\n",
      "2/2 - 0s - loss: 0.0061 - mse: 0.0061 - val_loss: 0.0061 - val_mse: 0.0061 - 18ms/epoch - 9ms/step\n",
      "Epoch 215/400\n",
      "2/2 - 0s - loss: 0.0060 - mse: 0.0060 - val_loss: 0.0061 - val_mse: 0.0061 - 16ms/epoch - 8ms/step\n",
      "Epoch 216/400\n",
      "2/2 - 0s - loss: 0.0060 - mse: 0.0060 - val_loss: 0.0060 - val_mse: 0.0060 - 17ms/epoch - 9ms/step\n",
      "Epoch 217/400\n",
      "2/2 - 0s - loss: 0.0059 - mse: 0.0059 - val_loss: 0.0060 - val_mse: 0.0060 - 16ms/epoch - 8ms/step\n",
      "Epoch 218/400\n",
      "2/2 - 0s - loss: 0.0059 - mse: 0.0059 - val_loss: 0.0059 - val_mse: 0.0059 - 18ms/epoch - 9ms/step\n",
      "Epoch 219/400\n",
      "2/2 - 0s - loss: 0.0058 - mse: 0.0058 - val_loss: 0.0059 - val_mse: 0.0059 - 22ms/epoch - 11ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 220/400\n",
      "2/2 - 0s - loss: 0.0057 - mse: 0.0057 - val_loss: 0.0058 - val_mse: 0.0058 - 18ms/epoch - 9ms/step\n",
      "Epoch 221/400\n",
      "2/2 - 0s - loss: 0.0057 - mse: 0.0057 - val_loss: 0.0058 - val_mse: 0.0058 - 17ms/epoch - 9ms/step\n",
      "Epoch 222/400\n",
      "2/2 - 0s - loss: 0.0057 - mse: 0.0057 - val_loss: 0.0057 - val_mse: 0.0057 - 17ms/epoch - 8ms/step\n",
      "Epoch 223/400\n",
      "2/2 - 0s - loss: 0.0056 - mse: 0.0056 - val_loss: 0.0056 - val_mse: 0.0056 - 18ms/epoch - 9ms/step\n",
      "Epoch 224/400\n",
      "2/2 - 0s - loss: 0.0055 - mse: 0.0055 - val_loss: 0.0056 - val_mse: 0.0056 - 15ms/epoch - 8ms/step\n",
      "Epoch 225/400\n",
      "2/2 - 0s - loss: 0.0055 - mse: 0.0055 - val_loss: 0.0056 - val_mse: 0.0056 - 23ms/epoch - 11ms/step\n",
      "Epoch 226/400\n",
      "2/2 - 0s - loss: 0.0054 - mse: 0.0054 - val_loss: 0.0055 - val_mse: 0.0055 - 18ms/epoch - 9ms/step\n",
      "Epoch 227/400\n",
      "2/2 - 0s - loss: 0.0054 - mse: 0.0054 - val_loss: 0.0055 - val_mse: 0.0055 - 15ms/epoch - 7ms/step\n",
      "Epoch 228/400\n",
      "2/2 - 0s - loss: 0.0054 - mse: 0.0054 - val_loss: 0.0054 - val_mse: 0.0054 - 15ms/epoch - 8ms/step\n",
      "Epoch 229/400\n",
      "2/2 - 0s - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0054 - val_mse: 0.0054 - 15ms/epoch - 8ms/step\n",
      "Epoch 230/400\n",
      "2/2 - 0s - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0053 - val_mse: 0.0053 - 15ms/epoch - 8ms/step\n",
      "Epoch 231/400\n",
      "2/2 - 0s - loss: 0.0052 - mse: 0.0052 - val_loss: 0.0053 - val_mse: 0.0053 - 16ms/epoch - 8ms/step\n",
      "Epoch 232/400\n",
      "2/2 - 0s - loss: 0.0052 - mse: 0.0052 - val_loss: 0.0052 - val_mse: 0.0052 - 16ms/epoch - 8ms/step\n",
      "Epoch 233/400\n",
      "2/2 - 0s - loss: 0.0051 - mse: 0.0051 - val_loss: 0.0052 - val_mse: 0.0052 - 15ms/epoch - 8ms/step\n",
      "Epoch 234/400\n",
      "2/2 - 0s - loss: 0.0051 - mse: 0.0051 - val_loss: 0.0051 - val_mse: 0.0051 - 15ms/epoch - 8ms/step\n",
      "Epoch 235/400\n",
      "2/2 - 0s - loss: 0.0051 - mse: 0.0051 - val_loss: 0.0051 - val_mse: 0.0051 - 16ms/epoch - 8ms/step\n",
      "Epoch 236/400\n",
      "2/2 - 0s - loss: 0.0050 - mse: 0.0050 - val_loss: 0.0051 - val_mse: 0.0051 - 15ms/epoch - 8ms/step\n",
      "Epoch 237/400\n",
      "2/2 - 0s - loss: 0.0050 - mse: 0.0050 - val_loss: 0.0050 - val_mse: 0.0050 - 16ms/epoch - 8ms/step\n",
      "Epoch 238/400\n",
      "2/2 - 0s - loss: 0.0049 - mse: 0.0049 - val_loss: 0.0050 - val_mse: 0.0050 - 15ms/epoch - 8ms/step\n",
      "Epoch 239/400\n",
      "2/2 - 0s - loss: 0.0049 - mse: 0.0049 - val_loss: 0.0049 - val_mse: 0.0049 - 16ms/epoch - 8ms/step\n",
      "Epoch 240/400\n",
      "2/2 - 0s - loss: 0.0048 - mse: 0.0048 - val_loss: 0.0049 - val_mse: 0.0049 - 16ms/epoch - 8ms/step\n",
      "Epoch 241/400\n",
      "2/2 - 0s - loss: 0.0048 - mse: 0.0048 - val_loss: 0.0049 - val_mse: 0.0049 - 16ms/epoch - 8ms/step\n",
      "Epoch 242/400\n",
      "2/2 - 0s - loss: 0.0048 - mse: 0.0048 - val_loss: 0.0048 - val_mse: 0.0048 - 16ms/epoch - 8ms/step\n",
      "Epoch 243/400\n",
      "2/2 - 0s - loss: 0.0047 - mse: 0.0047 - val_loss: 0.0048 - val_mse: 0.0048 - 16ms/epoch - 8ms/step\n",
      "Epoch 244/400\n",
      "2/2 - 0s - loss: 0.0047 - mse: 0.0047 - val_loss: 0.0047 - val_mse: 0.0047 - 16ms/epoch - 8ms/step\n",
      "Epoch 245/400\n",
      "2/2 - 0s - loss: 0.0047 - mse: 0.0047 - val_loss: 0.0047 - val_mse: 0.0047 - 15ms/epoch - 8ms/step\n",
      "Epoch 246/400\n",
      "2/2 - 0s - loss: 0.0046 - mse: 0.0046 - val_loss: 0.0047 - val_mse: 0.0047 - 15ms/epoch - 8ms/step\n",
      "Epoch 247/400\n",
      "2/2 - 0s - loss: 0.0046 - mse: 0.0046 - val_loss: 0.0046 - val_mse: 0.0046 - 16ms/epoch - 8ms/step\n",
      "Epoch 248/400\n",
      "2/2 - 0s - loss: 0.0046 - mse: 0.0046 - val_loss: 0.0046 - val_mse: 0.0046 - 16ms/epoch - 8ms/step\n",
      "Epoch 249/400\n",
      "2/2 - 0s - loss: 0.0045 - mse: 0.0045 - val_loss: 0.0046 - val_mse: 0.0046 - 16ms/epoch - 8ms/step\n",
      "Epoch 250/400\n",
      "2/2 - 0s - loss: 0.0045 - mse: 0.0045 - val_loss: 0.0045 - val_mse: 0.0045 - 16ms/epoch - 8ms/step\n",
      "Epoch 251/400\n",
      "2/2 - 0s - loss: 0.0045 - mse: 0.0045 - val_loss: 0.0045 - val_mse: 0.0045 - 16ms/epoch - 8ms/step\n",
      "Epoch 252/400\n",
      "2/2 - 0s - loss: 0.0044 - mse: 0.0044 - val_loss: 0.0045 - val_mse: 0.0045 - 17ms/epoch - 8ms/step\n",
      "Epoch 253/400\n",
      "2/2 - 0s - loss: 0.0044 - mse: 0.0044 - val_loss: 0.0044 - val_mse: 0.0044 - 16ms/epoch - 8ms/step\n",
      "Epoch 254/400\n",
      "2/2 - 0s - loss: 0.0044 - mse: 0.0044 - val_loss: 0.0044 - val_mse: 0.0044 - 16ms/epoch - 8ms/step\n",
      "Epoch 255/400\n",
      "2/2 - 0s - loss: 0.0043 - mse: 0.0043 - val_loss: 0.0044 - val_mse: 0.0044 - 16ms/epoch - 8ms/step\n",
      "Epoch 256/400\n",
      "2/2 - 0s - loss: 0.0043 - mse: 0.0043 - val_loss: 0.0043 - val_mse: 0.0043 - 16ms/epoch - 8ms/step\n",
      "Epoch 257/400\n",
      "2/2 - 0s - loss: 0.0043 - mse: 0.0043 - val_loss: 0.0043 - val_mse: 0.0043 - 16ms/epoch - 8ms/step\n",
      "Epoch 258/400\n",
      "2/2 - 0s - loss: 0.0043 - mse: 0.0043 - val_loss: 0.0043 - val_mse: 0.0043 - 16ms/epoch - 8ms/step\n",
      "Epoch 259/400\n",
      "2/2 - 0s - loss: 0.0042 - mse: 0.0042 - val_loss: 0.0043 - val_mse: 0.0043 - 16ms/epoch - 8ms/step\n",
      "Epoch 260/400\n",
      "2/2 - 0s - loss: 0.0042 - mse: 0.0042 - val_loss: 0.0042 - val_mse: 0.0042 - 16ms/epoch - 8ms/step\n",
      "Epoch 261/400\n",
      "2/2 - 0s - loss: 0.0042 - mse: 0.0042 - val_loss: 0.0042 - val_mse: 0.0042 - 16ms/epoch - 8ms/step\n",
      "Epoch 262/400\n",
      "2/2 - 0s - loss: 0.0041 - mse: 0.0041 - val_loss: 0.0042 - val_mse: 0.0042 - 15ms/epoch - 8ms/step\n",
      "Epoch 263/400\n",
      "2/2 - 0s - loss: 0.0041 - mse: 0.0041 - val_loss: 0.0041 - val_mse: 0.0041 - 16ms/epoch - 8ms/step\n",
      "Epoch 264/400\n",
      "2/2 - 0s - loss: 0.0041 - mse: 0.0041 - val_loss: 0.0041 - val_mse: 0.0041 - 16ms/epoch - 8ms/step\n",
      "Epoch 265/400\n",
      "2/2 - 0s - loss: 0.0041 - mse: 0.0041 - val_loss: 0.0041 - val_mse: 0.0041 - 16ms/epoch - 8ms/step\n",
      "Epoch 266/400\n",
      "2/2 - 0s - loss: 0.0040 - mse: 0.0040 - val_loss: 0.0040 - val_mse: 0.0040 - 16ms/epoch - 8ms/step\n",
      "Epoch 267/400\n",
      "2/2 - 0s - loss: 0.0040 - mse: 0.0040 - val_loss: 0.0040 - val_mse: 0.0040 - 15ms/epoch - 8ms/step\n",
      "Epoch 268/400\n",
      "2/2 - 0s - loss: 0.0040 - mse: 0.0040 - val_loss: 0.0040 - val_mse: 0.0040 - 15ms/epoch - 8ms/step\n",
      "Epoch 269/400\n",
      "2/2 - 0s - loss: 0.0039 - mse: 0.0039 - val_loss: 0.0040 - val_mse: 0.0040 - 15ms/epoch - 8ms/step\n",
      "Epoch 270/400\n",
      "2/2 - 0s - loss: 0.0039 - mse: 0.0039 - val_loss: 0.0039 - val_mse: 0.0039 - 16ms/epoch - 8ms/step\n",
      "Epoch 271/400\n",
      "2/2 - 0s - loss: 0.0039 - mse: 0.0039 - val_loss: 0.0039 - val_mse: 0.0039 - 16ms/epoch - 8ms/step\n",
      "Epoch 272/400\n",
      "2/2 - 0s - loss: 0.0039 - mse: 0.0039 - val_loss: 0.0039 - val_mse: 0.0039 - 15ms/epoch - 8ms/step\n",
      "Epoch 273/400\n",
      "2/2 - 0s - loss: 0.0039 - mse: 0.0039 - val_loss: 0.0038 - val_mse: 0.0038 - 16ms/epoch - 8ms/step\n",
      "Epoch 274/400\n",
      "2/2 - 0s - loss: 0.0038 - mse: 0.0038 - val_loss: 0.0038 - val_mse: 0.0038 - 16ms/epoch - 8ms/step\n",
      "Epoch 275/400\n",
      "2/2 - 0s - loss: 0.0038 - mse: 0.0038 - val_loss: 0.0038 - val_mse: 0.0038 - 16ms/epoch - 8ms/step\n",
      "Epoch 276/400\n",
      "2/2 - 0s - loss: 0.0038 - mse: 0.0038 - val_loss: 0.0038 - val_mse: 0.0038 - 16ms/epoch - 8ms/step\n",
      "Epoch 277/400\n",
      "2/2 - 0s - loss: 0.0038 - mse: 0.0038 - val_loss: 0.0038 - val_mse: 0.0038 - 16ms/epoch - 8ms/step\n",
      "Epoch 278/400\n",
      "2/2 - 0s - loss: 0.0037 - mse: 0.0037 - val_loss: 0.0037 - val_mse: 0.0037 - 15ms/epoch - 8ms/step\n",
      "Epoch 279/400\n",
      "2/2 - 0s - loss: 0.0037 - mse: 0.0037 - val_loss: 0.0037 - val_mse: 0.0037 - 16ms/epoch - 8ms/step\n",
      "Epoch 280/400\n",
      "2/2 - 0s - loss: 0.0037 - mse: 0.0037 - val_loss: 0.0037 - val_mse: 0.0037 - 15ms/epoch - 8ms/step\n",
      "Epoch 281/400\n",
      "2/2 - 0s - loss: 0.0037 - mse: 0.0037 - val_loss: 0.0037 - val_mse: 0.0037 - 26ms/epoch - 13ms/step\n",
      "Epoch 282/400\n",
      "2/2 - 0s - loss: 0.0036 - mse: 0.0036 - val_loss: 0.0036 - val_mse: 0.0036 - 20ms/epoch - 10ms/step\n",
      "Epoch 283/400\n",
      "2/2 - 0s - loss: 0.0036 - mse: 0.0036 - val_loss: 0.0036 - val_mse: 0.0036 - 15ms/epoch - 7ms/step\n",
      "Epoch 284/400\n",
      "2/2 - 0s - loss: 0.0036 - mse: 0.0036 - val_loss: 0.0036 - val_mse: 0.0036 - 15ms/epoch - 8ms/step\n",
      "Epoch 285/400\n",
      "2/2 - 0s - loss: 0.0036 - mse: 0.0036 - val_loss: 0.0036 - val_mse: 0.0036 - 16ms/epoch - 8ms/step\n",
      "Epoch 286/400\n",
      "2/2 - 0s - loss: 0.0036 - mse: 0.0036 - val_loss: 0.0036 - val_mse: 0.0036 - 16ms/epoch - 8ms/step\n",
      "Epoch 287/400\n",
      "2/2 - 0s - loss: 0.0035 - mse: 0.0035 - val_loss: 0.0035 - val_mse: 0.0035 - 16ms/epoch - 8ms/step\n",
      "Epoch 288/400\n",
      "2/2 - 0s - loss: 0.0035 - mse: 0.0035 - val_loss: 0.0035 - val_mse: 0.0035 - 16ms/epoch - 8ms/step\n",
      "Epoch 289/400\n",
      "2/2 - 0s - loss: 0.0035 - mse: 0.0035 - val_loss: 0.0035 - val_mse: 0.0035 - 16ms/epoch - 8ms/step\n",
      "Epoch 290/400\n",
      "2/2 - 0s - loss: 0.0035 - mse: 0.0035 - val_loss: 0.0035 - val_mse: 0.0035 - 16ms/epoch - 8ms/step\n",
      "Epoch 291/400\n",
      "2/2 - 0s - loss: 0.0035 - mse: 0.0035 - val_loss: 0.0035 - val_mse: 0.0035 - 15ms/epoch - 8ms/step\n",
      "Epoch 292/400\n",
      "2/2 - 0s - loss: 0.0034 - mse: 0.0034 - val_loss: 0.0034 - val_mse: 0.0034 - 16ms/epoch - 8ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 293/400\n",
      "2/2 - 0s - loss: 0.0034 - mse: 0.0034 - val_loss: 0.0034 - val_mse: 0.0034 - 17ms/epoch - 8ms/step\n",
      "Epoch 294/400\n",
      "2/2 - 0s - loss: 0.0034 - mse: 0.0034 - val_loss: 0.0034 - val_mse: 0.0034 - 15ms/epoch - 7ms/step\n",
      "Epoch 295/400\n",
      "2/2 - 0s - loss: 0.0034 - mse: 0.0034 - val_loss: 0.0034 - val_mse: 0.0034 - 15ms/epoch - 8ms/step\n",
      "Epoch 296/400\n",
      "2/2 - 0s - loss: 0.0034 - mse: 0.0034 - val_loss: 0.0034 - val_mse: 0.0034 - 16ms/epoch - 8ms/step\n",
      "Epoch 297/400\n",
      "2/2 - 0s - loss: 0.0034 - mse: 0.0034 - val_loss: 0.0033 - val_mse: 0.0033 - 16ms/epoch - 8ms/step\n",
      "Epoch 298/400\n",
      "2/2 - 0s - loss: 0.0033 - mse: 0.0033 - val_loss: 0.0033 - val_mse: 0.0033 - 15ms/epoch - 7ms/step\n",
      "Epoch 299/400\n",
      "2/2 - 0s - loss: 0.0033 - mse: 0.0033 - val_loss: 0.0033 - val_mse: 0.0033 - 15ms/epoch - 8ms/step\n",
      "Epoch 300/400\n",
      "2/2 - 0s - loss: 0.0033 - mse: 0.0033 - val_loss: 0.0033 - val_mse: 0.0033 - 16ms/epoch - 8ms/step\n",
      "Epoch 301/400\n",
      "2/2 - 0s - loss: 0.0033 - mse: 0.0033 - val_loss: 0.0033 - val_mse: 0.0033 - 16ms/epoch - 8ms/step\n",
      "Epoch 302/400\n",
      "2/2 - 0s - loss: 0.0033 - mse: 0.0033 - val_loss: 0.0033 - val_mse: 0.0033 - 16ms/epoch - 8ms/step\n",
      "Epoch 303/400\n",
      "2/2 - 0s - loss: 0.0032 - mse: 0.0032 - val_loss: 0.0032 - val_mse: 0.0032 - 15ms/epoch - 7ms/step\n",
      "Epoch 304/400\n",
      "2/2 - 0s - loss: 0.0032 - mse: 0.0032 - val_loss: 0.0032 - val_mse: 0.0032 - 16ms/epoch - 8ms/step\n",
      "Epoch 305/400\n",
      "2/2 - 0s - loss: 0.0032 - mse: 0.0032 - val_loss: 0.0032 - val_mse: 0.0032 - 16ms/epoch - 8ms/step\n",
      "Epoch 306/400\n",
      "2/2 - 0s - loss: 0.0032 - mse: 0.0032 - val_loss: 0.0032 - val_mse: 0.0032 - 16ms/epoch - 8ms/step\n",
      "Epoch 307/400\n",
      "2/2 - 0s - loss: 0.0032 - mse: 0.0032 - val_loss: 0.0032 - val_mse: 0.0032 - 15ms/epoch - 7ms/step\n",
      "Epoch 308/400\n",
      "2/2 - 0s - loss: 0.0032 - mse: 0.0032 - val_loss: 0.0031 - val_mse: 0.0031 - 15ms/epoch - 8ms/step\n",
      "Epoch 309/400\n",
      "2/2 - 0s - loss: 0.0031 - mse: 0.0031 - val_loss: 0.0031 - val_mse: 0.0031 - 15ms/epoch - 7ms/step\n",
      "Epoch 310/400\n",
      "2/2 - 0s - loss: 0.0031 - mse: 0.0031 - val_loss: 0.0031 - val_mse: 0.0031 - 16ms/epoch - 8ms/step\n",
      "Epoch 311/400\n",
      "2/2 - 0s - loss: 0.0031 - mse: 0.0031 - val_loss: 0.0031 - val_mse: 0.0031 - 16ms/epoch - 8ms/step\n",
      "Epoch 312/400\n",
      "2/2 - 0s - loss: 0.0031 - mse: 0.0031 - val_loss: 0.0031 - val_mse: 0.0031 - 16ms/epoch - 8ms/step\n",
      "Epoch 313/400\n",
      "2/2 - 0s - loss: 0.0031 - mse: 0.0031 - val_loss: 0.0031 - val_mse: 0.0031 - 15ms/epoch - 8ms/step\n",
      "Epoch 314/400\n",
      "2/2 - 0s - loss: 0.0031 - mse: 0.0031 - val_loss: 0.0031 - val_mse: 0.0031 - 16ms/epoch - 8ms/step\n",
      "Epoch 315/400\n",
      "2/2 - 0s - loss: 0.0031 - mse: 0.0031 - val_loss: 0.0030 - val_mse: 0.0030 - 16ms/epoch - 8ms/step\n",
      "Epoch 316/400\n",
      "2/2 - 0s - loss: 0.0030 - mse: 0.0030 - val_loss: 0.0030 - val_mse: 0.0030 - 16ms/epoch - 8ms/step\n",
      "Epoch 317/400\n",
      "2/2 - 0s - loss: 0.0030 - mse: 0.0030 - val_loss: 0.0030 - val_mse: 0.0030 - 15ms/epoch - 8ms/step\n",
      "Epoch 318/400\n",
      "2/2 - 0s - loss: 0.0030 - mse: 0.0030 - val_loss: 0.0030 - val_mse: 0.0030 - 17ms/epoch - 8ms/step\n",
      "Epoch 319/400\n",
      "2/2 - 0s - loss: 0.0030 - mse: 0.0030 - val_loss: 0.0030 - val_mse: 0.0030 - 16ms/epoch - 8ms/step\n",
      "Epoch 320/400\n",
      "2/2 - 0s - loss: 0.0030 - mse: 0.0030 - val_loss: 0.0030 - val_mse: 0.0030 - 15ms/epoch - 8ms/step\n",
      "Epoch 321/400\n",
      "2/2 - 0s - loss: 0.0030 - mse: 0.0030 - val_loss: 0.0030 - val_mse: 0.0030 - 17ms/epoch - 8ms/step\n",
      "Epoch 322/400\n",
      "2/2 - 0s - loss: 0.0030 - mse: 0.0030 - val_loss: 0.0029 - val_mse: 0.0029 - 15ms/epoch - 7ms/step\n",
      "Epoch 323/400\n",
      "2/2 - 0s - loss: 0.0029 - mse: 0.0029 - val_loss: 0.0029 - val_mse: 0.0029 - 15ms/epoch - 8ms/step\n",
      "Epoch 324/400\n",
      "2/2 - 0s - loss: 0.0029 - mse: 0.0029 - val_loss: 0.0029 - val_mse: 0.0029 - 15ms/epoch - 8ms/step\n",
      "Epoch 325/400\n",
      "2/2 - 0s - loss: 0.0029 - mse: 0.0029 - val_loss: 0.0029 - val_mse: 0.0029 - 16ms/epoch - 8ms/step\n",
      "Epoch 326/400\n",
      "2/2 - 0s - loss: 0.0029 - mse: 0.0029 - val_loss: 0.0029 - val_mse: 0.0029 - 16ms/epoch - 8ms/step\n",
      "Epoch 327/400\n",
      "2/2 - 0s - loss: 0.0029 - mse: 0.0029 - val_loss: 0.0029 - val_mse: 0.0029 - 16ms/epoch - 8ms/step\n",
      "Epoch 328/400\n",
      "2/2 - 0s - loss: 0.0029 - mse: 0.0029 - val_loss: 0.0028 - val_mse: 0.0028 - 16ms/epoch - 8ms/step\n",
      "Epoch 329/400\n",
      "2/2 - 0s - loss: 0.0029 - mse: 0.0029 - val_loss: 0.0028 - val_mse: 0.0028 - 16ms/epoch - 8ms/step\n",
      "Epoch 330/400\n",
      "2/2 - 0s - loss: 0.0028 - mse: 0.0028 - val_loss: 0.0028 - val_mse: 0.0028 - 16ms/epoch - 8ms/step\n",
      "Epoch 331/400\n",
      "2/2 - 0s - loss: 0.0028 - mse: 0.0028 - val_loss: 0.0028 - val_mse: 0.0028 - 16ms/epoch - 8ms/step\n",
      "Epoch 332/400\n",
      "2/2 - 0s - loss: 0.0028 - mse: 0.0028 - val_loss: 0.0028 - val_mse: 0.0028 - 16ms/epoch - 8ms/step\n",
      "Epoch 333/400\n",
      "2/2 - 0s - loss: 0.0028 - mse: 0.0028 - val_loss: 0.0028 - val_mse: 0.0028 - 15ms/epoch - 8ms/step\n",
      "Epoch 334/400\n",
      "2/2 - 0s - loss: 0.0028 - mse: 0.0028 - val_loss: 0.0028 - val_mse: 0.0028 - 16ms/epoch - 8ms/step\n",
      "Epoch 335/400\n",
      "2/2 - 0s - loss: 0.0028 - mse: 0.0028 - val_loss: 0.0028 - val_mse: 0.0028 - 16ms/epoch - 8ms/step\n",
      "Epoch 336/400\n",
      "2/2 - 0s - loss: 0.0028 - mse: 0.0028 - val_loss: 0.0027 - val_mse: 0.0027 - 16ms/epoch - 8ms/step\n",
      "Epoch 337/400\n",
      "2/2 - 0s - loss: 0.0028 - mse: 0.0028 - val_loss: 0.0027 - val_mse: 0.0027 - 15ms/epoch - 8ms/step\n",
      "Epoch 338/400\n",
      "2/2 - 0s - loss: 0.0027 - mse: 0.0027 - val_loss: 0.0027 - val_mse: 0.0027 - 16ms/epoch - 8ms/step\n",
      "Epoch 339/400\n",
      "2/2 - 0s - loss: 0.0027 - mse: 0.0027 - val_loss: 0.0027 - val_mse: 0.0027 - 16ms/epoch - 8ms/step\n",
      "Epoch 340/400\n",
      "2/2 - 0s - loss: 0.0027 - mse: 0.0027 - val_loss: 0.0027 - val_mse: 0.0027 - 16ms/epoch - 8ms/step\n",
      "Epoch 341/400\n",
      "2/2 - 0s - loss: 0.0027 - mse: 0.0027 - val_loss: 0.0027 - val_mse: 0.0027 - 16ms/epoch - 8ms/step\n",
      "Epoch 342/400\n",
      "2/2 - 0s - loss: 0.0027 - mse: 0.0027 - val_loss: 0.0027 - val_mse: 0.0027 - 16ms/epoch - 8ms/step\n",
      "Epoch 343/400\n",
      "2/2 - 0s - loss: 0.0027 - mse: 0.0027 - val_loss: 0.0027 - val_mse: 0.0027 - 16ms/epoch - 8ms/step\n",
      "Epoch 344/400\n",
      "2/2 - 0s - loss: 0.0027 - mse: 0.0027 - val_loss: 0.0027 - val_mse: 0.0027 - 16ms/epoch - 8ms/step\n",
      "Epoch 345/400\n",
      "2/2 - 0s - loss: 0.0027 - mse: 0.0027 - val_loss: 0.0026 - val_mse: 0.0026 - 16ms/epoch - 8ms/step\n",
      "Epoch 346/400\n",
      "2/2 - 0s - loss: 0.0026 - mse: 0.0026 - val_loss: 0.0026 - val_mse: 0.0026 - 16ms/epoch - 8ms/step\n",
      "Epoch 347/400\n",
      "2/2 - 0s - loss: 0.0026 - mse: 0.0026 - val_loss: 0.0026 - val_mse: 0.0026 - 16ms/epoch - 8ms/step\n",
      "Epoch 348/400\n",
      "2/2 - 0s - loss: 0.0026 - mse: 0.0026 - val_loss: 0.0026 - val_mse: 0.0026 - 15ms/epoch - 8ms/step\n",
      "Epoch 349/400\n",
      "2/2 - 0s - loss: 0.0026 - mse: 0.0026 - val_loss: 0.0026 - val_mse: 0.0026 - 16ms/epoch - 8ms/step\n",
      "Epoch 350/400\n",
      "2/2 - 0s - loss: 0.0026 - mse: 0.0026 - val_loss: 0.0026 - val_mse: 0.0026 - 16ms/epoch - 8ms/step\n",
      "Epoch 351/400\n",
      "2/2 - 0s - loss: 0.0026 - mse: 0.0026 - val_loss: 0.0026 - val_mse: 0.0026 - 16ms/epoch - 8ms/step\n",
      "Epoch 352/400\n",
      "2/2 - 0s - loss: 0.0026 - mse: 0.0026 - val_loss: 0.0025 - val_mse: 0.0025 - 16ms/epoch - 8ms/step\n",
      "Epoch 353/400\n",
      "2/2 - 0s - loss: 0.0026 - mse: 0.0026 - val_loss: 0.0025 - val_mse: 0.0025 - 16ms/epoch - 8ms/step\n",
      "Epoch 354/400\n",
      "2/2 - 0s - loss: 0.0026 - mse: 0.0026 - val_loss: 0.0025 - val_mse: 0.0025 - 16ms/epoch - 8ms/step\n",
      "Epoch 355/400\n",
      "2/2 - 0s - loss: 0.0025 - mse: 0.0025 - val_loss: 0.0025 - val_mse: 0.0025 - 17ms/epoch - 8ms/step\n",
      "Epoch 356/400\n",
      "2/2 - 0s - loss: 0.0025 - mse: 0.0025 - val_loss: 0.0025 - val_mse: 0.0025 - 16ms/epoch - 8ms/step\n",
      "Epoch 357/400\n",
      "2/2 - 0s - loss: 0.0025 - mse: 0.0025 - val_loss: 0.0025 - val_mse: 0.0025 - 16ms/epoch - 8ms/step\n",
      "Epoch 358/400\n",
      "2/2 - 0s - loss: 0.0025 - mse: 0.0025 - val_loss: 0.0025 - val_mse: 0.0025 - 16ms/epoch - 8ms/step\n",
      "Epoch 359/400\n",
      "2/2 - 0s - loss: 0.0025 - mse: 0.0025 - val_loss: 0.0025 - val_mse: 0.0025 - 16ms/epoch - 8ms/step\n",
      "Epoch 360/400\n",
      "2/2 - 0s - loss: 0.0025 - mse: 0.0025 - val_loss: 0.0025 - val_mse: 0.0025 - 15ms/epoch - 8ms/step\n",
      "Epoch 361/400\n",
      "2/2 - 0s - loss: 0.0025 - mse: 0.0025 - val_loss: 0.0025 - val_mse: 0.0025 - 16ms/epoch - 8ms/step\n",
      "Epoch 362/400\n",
      "2/2 - 0s - loss: 0.0025 - mse: 0.0025 - val_loss: 0.0025 - val_mse: 0.0025 - 15ms/epoch - 8ms/step\n",
      "Epoch 363/400\n",
      "2/2 - 0s - loss: 0.0025 - mse: 0.0025 - val_loss: 0.0024 - val_mse: 0.0024 - 15ms/epoch - 8ms/step\n",
      "Epoch 364/400\n",
      "2/2 - 0s - loss: 0.0025 - mse: 0.0025 - val_loss: 0.0024 - val_mse: 0.0024 - 16ms/epoch - 8ms/step\n",
      "Epoch 365/400\n",
      "2/2 - 0s - loss: 0.0024 - mse: 0.0024 - val_loss: 0.0024 - val_mse: 0.0024 - 16ms/epoch - 8ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 366/400\n",
      "2/2 - 0s - loss: 0.0024 - mse: 0.0024 - val_loss: 0.0024 - val_mse: 0.0024 - 16ms/epoch - 8ms/step\n",
      "Epoch 367/400\n",
      "2/2 - 0s - loss: 0.0024 - mse: 0.0024 - val_loss: 0.0024 - val_mse: 0.0024 - 16ms/epoch - 8ms/step\n",
      "Epoch 368/400\n",
      "2/2 - 0s - loss: 0.0024 - mse: 0.0024 - val_loss: 0.0024 - val_mse: 0.0024 - 15ms/epoch - 8ms/step\n",
      "Epoch 369/400\n",
      "2/2 - 0s - loss: 0.0024 - mse: 0.0024 - val_loss: 0.0024 - val_mse: 0.0024 - 16ms/epoch - 8ms/step\n",
      "Epoch 370/400\n",
      "2/2 - 0s - loss: 0.0024 - mse: 0.0024 - val_loss: 0.0024 - val_mse: 0.0024 - 16ms/epoch - 8ms/step\n",
      "Epoch 371/400\n",
      "2/2 - 0s - loss: 0.0024 - mse: 0.0024 - val_loss: 0.0024 - val_mse: 0.0024 - 15ms/epoch - 8ms/step\n",
      "Epoch 372/400\n",
      "2/2 - 0s - loss: 0.0024 - mse: 0.0024 - val_loss: 0.0024 - val_mse: 0.0024 - 15ms/epoch - 8ms/step\n",
      "Epoch 373/400\n",
      "2/2 - 0s - loss: 0.0024 - mse: 0.0024 - val_loss: 0.0024 - val_mse: 0.0024 - 15ms/epoch - 7ms/step\n",
      "Epoch 374/400\n",
      "2/2 - 0s - loss: 0.0024 - mse: 0.0024 - val_loss: 0.0023 - val_mse: 0.0023 - 15ms/epoch - 8ms/step\n",
      "Epoch 375/400\n",
      "2/2 - 0s - loss: 0.0024 - mse: 0.0024 - val_loss: 0.0023 - val_mse: 0.0023 - 16ms/epoch - 8ms/step\n",
      "Epoch 376/400\n",
      "2/2 - 0s - loss: 0.0023 - mse: 0.0023 - val_loss: 0.0023 - val_mse: 0.0023 - 15ms/epoch - 8ms/step\n",
      "Epoch 377/400\n",
      "2/2 - 0s - loss: 0.0023 - mse: 0.0023 - val_loss: 0.0023 - val_mse: 0.0023 - 16ms/epoch - 8ms/step\n",
      "Epoch 378/400\n",
      "2/2 - 0s - loss: 0.0023 - mse: 0.0023 - val_loss: 0.0023 - val_mse: 0.0023 - 17ms/epoch - 8ms/step\n",
      "Epoch 379/400\n",
      "2/2 - 0s - loss: 0.0023 - mse: 0.0023 - val_loss: 0.0023 - val_mse: 0.0023 - 16ms/epoch - 8ms/step\n",
      "Epoch 380/400\n",
      "2/2 - 0s - loss: 0.0023 - mse: 0.0023 - val_loss: 0.0023 - val_mse: 0.0023 - 15ms/epoch - 8ms/step\n",
      "Epoch 381/400\n",
      "2/2 - 0s - loss: 0.0023 - mse: 0.0023 - val_loss: 0.0023 - val_mse: 0.0023 - 15ms/epoch - 8ms/step\n",
      "Epoch 382/400\n",
      "2/2 - 0s - loss: 0.0023 - mse: 0.0023 - val_loss: 0.0023 - val_mse: 0.0023 - 16ms/epoch - 8ms/step\n",
      "Epoch 383/400\n",
      "2/2 - 0s - loss: 0.0023 - mse: 0.0023 - val_loss: 0.0023 - val_mse: 0.0023 - 16ms/epoch - 8ms/step\n",
      "Epoch 384/400\n",
      "2/2 - 0s - loss: 0.0023 - mse: 0.0023 - val_loss: 0.0023 - val_mse: 0.0023 - 16ms/epoch - 8ms/step\n",
      "Epoch 385/400\n",
      "2/2 - 0s - loss: 0.0023 - mse: 0.0023 - val_loss: 0.0022 - val_mse: 0.0022 - 16ms/epoch - 8ms/step\n",
      "Epoch 386/400\n",
      "2/2 - 0s - loss: 0.0023 - mse: 0.0023 - val_loss: 0.0022 - val_mse: 0.0022 - 16ms/epoch - 8ms/step\n",
      "Epoch 387/400\n",
      "2/2 - 0s - loss: 0.0023 - mse: 0.0023 - val_loss: 0.0022 - val_mse: 0.0022 - 16ms/epoch - 8ms/step\n",
      "Epoch 388/400\n",
      "2/2 - 0s - loss: 0.0022 - mse: 0.0022 - val_loss: 0.0022 - val_mse: 0.0022 - 16ms/epoch - 8ms/step\n",
      "Epoch 389/400\n",
      "2/2 - 0s - loss: 0.0022 - mse: 0.0022 - val_loss: 0.0022 - val_mse: 0.0022 - 16ms/epoch - 8ms/step\n",
      "Epoch 390/400\n",
      "2/2 - 0s - loss: 0.0022 - mse: 0.0022 - val_loss: 0.0022 - val_mse: 0.0022 - 17ms/epoch - 8ms/step\n",
      "Epoch 391/400\n",
      "2/2 - 0s - loss: 0.0022 - mse: 0.0022 - val_loss: 0.0022 - val_mse: 0.0022 - 15ms/epoch - 8ms/step\n",
      "Epoch 392/400\n",
      "2/2 - 0s - loss: 0.0022 - mse: 0.0022 - val_loss: 0.0022 - val_mse: 0.0022 - 15ms/epoch - 8ms/step\n",
      "Epoch 393/400\n",
      "2/2 - 0s - loss: 0.0022 - mse: 0.0022 - val_loss: 0.0022 - val_mse: 0.0022 - 15ms/epoch - 8ms/step\n",
      "Epoch 394/400\n",
      "2/2 - 0s - loss: 0.0022 - mse: 0.0022 - val_loss: 0.0022 - val_mse: 0.0022 - 16ms/epoch - 8ms/step\n",
      "Epoch 395/400\n",
      "2/2 - 0s - loss: 0.0022 - mse: 0.0022 - val_loss: 0.0022 - val_mse: 0.0022 - 16ms/epoch - 8ms/step\n",
      "Epoch 396/400\n",
      "2/2 - 0s - loss: 0.0022 - mse: 0.0022 - val_loss: 0.0022 - val_mse: 0.0022 - 16ms/epoch - 8ms/step\n",
      "Epoch 397/400\n",
      "2/2 - 0s - loss: 0.0022 - mse: 0.0022 - val_loss: 0.0021 - val_mse: 0.0021 - 16ms/epoch - 8ms/step\n",
      "Epoch 398/400\n",
      "2/2 - 0s - loss: 0.0022 - mse: 0.0022 - val_loss: 0.0021 - val_mse: 0.0021 - 15ms/epoch - 8ms/step\n",
      "Epoch 399/400\n",
      "2/2 - 0s - loss: 0.0022 - mse: 0.0022 - val_loss: 0.0021 - val_mse: 0.0021 - 16ms/epoch - 8ms/step\n",
      "Epoch 400/400\n",
      "2/2 - 0s - loss: 0.0022 - mse: 0.0022 - val_loss: 0.0021 - val_mse: 0.0021 - 16ms/epoch - 8ms/step\n",
      "7.80341100692749\n"
     ]
    }
   ],
   "source": [
    "# Train final_model_PQ on (X_train, y_train) and report the elapsed training time.\n",
    "# validation_split=0.25: Keras holds out the LAST 25% of X_train/y_train (taken before\n",
    "# shuffling) as validation data, so val_loss is computed on that tail of the arrays.\n",
    "# time.perf_counter() is monotonic and high-resolution (unlike time.time(), which\n",
    "# follows the system clock), so it is the appropriate clock for measuring intervals.\n",
    "t0 = time.perf_counter()\n",
    "history = final_model_PQ.fit(X_train, y_train, epochs=400, batch_size=256, validation_split=0.25, verbose=2)\n",
    "t1 = time.perf_counter()\n",
    "print(f'Training time: {t1 - t0:.2f} s')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "388b3084",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2/2 [==============================] - 0s 2ms/step - loss: 0.0021 - mse: 0.0021\n",
      "38/38 [==============================] - 0s 616us/step - loss: 0.0023 - mse: 0.0023\n"
     ]
    }
   ],
   "source": [
    "# Performance of the trained model on the training set and on the held-out test set.\n",
    "# Each evaluate() call returns [loss, mse] (the model reports loss plus an mse metric),\n",
    "# so the two results are stored separately instead of reusing one variable.\n",
    "train_loss_and_metrics = final_model_PQ.evaluate(X_train, y_train, batch_size=256)\n",
    "test_loss_and_metrics = final_model_PQ.evaluate(X_test, y_test, batch_size=256)\n",
    "# loss_and_metrics keeps the test-set result, matching the name used by this cell before.\n",
    "loss_and_metrics = test_loss_and_metrics\n",
    "print('train [loss, mse]:', train_loss_and_metrics)\n",
    "print('test  [loss, mse]:', test_loss_and_metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9ed7589",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
