{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "539ebd7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow.keras.layers import Dense\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from keras.models import load_model\n",
    "import time "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "069e2e34",
   "metadata": {},
   "source": [
    "# import target dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3a4a2d8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# === Step 1: Load data ===\n",
    "filename = 'dataBenchmark.csv'  # Replace with your filename\n",
    "df = pd.read_csv(filename, header=None)\n",
    "\n",
    "# Columns: assume 0-based indexing, so:\n",
    "# u(t) is in column 1\n",
    "# y(t) is in column 3\n",
    "u = df.iloc[1:, 1].values  # u(t) #delete the first element which shows the label\n",
    "y = df.iloc[1:, 3].values  # y(t)\n",
    "\n",
    "# === Step 2: Construct input-output pairs ===\n",
    "X = []\n",
    "Y = []\n",
    "for t in range(1, len(y)):\n",
    "    X.append([u[t], y[t-1]])  # input: u(t), y(t-1)\n",
    "    Y.append(y[t])            # output: y(t)\n",
    "\n",
    "X = np.array(X)\n",
    "Y = np.array(Y).reshape(-1, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a7948aac",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[['0.99921' '4.9728']\n",
      " ['1.0172' '4.9722']\n",
      " ['1.0318' '4.9703']\n",
      " ...\n",
      " ['0.88351' '3.7978']\n",
      " ['0.9162' '3.7807']\n",
      " ['0.94805' '3.7151']] [['4.9722']\n",
      " ['4.9703']\n",
      " ['4.988']\n",
      " ...\n",
      " ['3.7807']\n",
      " ['3.7151']\n",
      " ['3.7179']]\n"
     ]
    }
   ],
   "source": [
    "print(X,Y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d8e4402d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(204, 2) (819, 2) (204, 1) (819, 1)\n"
     ]
    }
   ],
   "source": [
    "# === Step 3: Normalize data (optional but recommended) ===\n",
    "scaler_X = StandardScaler()\n",
    "scaler_Y = StandardScaler()\n",
    "\n",
    "X_scaled = scaler_X.fit_transform(X)\n",
    "Y_scaled = scaler_Y.fit_transform(Y)\n",
    "\n",
    "# === Step 4: Train-test split ===\n",
    "X_train, X_test, Y_train, Y_test = train_test_split(X_scaled, Y_scaled, test_size=0.8, random_state=42)\n",
    "\n",
    "print(X_train.shape, X_test.shape, Y_train.shape, Y_test.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d8c5f41",
   "metadata": {},
   "source": [
    "# Develop the benchmark"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d0ab6116",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " dense (Dense)               (None, 16)                48        \n",
      "                                                                 \n",
      " dense_1 (Dense)             (None, 16)                272       \n",
      "                                                                 \n",
      " dense_2 (Dense)             (None, 1)                 17        \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 337\n",
      "Trainable params: 337\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "None\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-07-14 15:24:42.294111: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\n",
      "2025-07-14 15:24:42.294140: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (wulab2-System-Product-Name): /proc/driver/nvidia/version does not exist\n",
      "2025-07-14 15:24:42.294471: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA\n",
      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    }
   ],
   "source": [
    "# === Step 5: Define and train FNN model ===\n",
    "model = Sequential([\n",
    "    Dense(16, activation='relu', input_shape=(2,)),\n",
    "    Dense(16, activation='relu'),\n",
    "    Dense(1, activation='linear')\n",
    "])\n",
    "\n",
    "model.compile(optimizer='adam', loss='mse', metrics=['mae'])\n",
    "print(model.summary())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2cc77d23",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/100\n",
      "7/7 [==============================] - 0s 14ms/step - loss: 0.8969 - mae: 0.8244 - val_loss: 0.7898 - val_mae: 0.7611\n",
      "Epoch 2/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.8118 - mae: 0.7798 - val_loss: 0.7167 - val_mae: 0.7212\n",
      "Epoch 3/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.7345 - mae: 0.7370 - val_loss: 0.6453 - val_mae: 0.6807\n",
      "Epoch 4/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.6588 - mae: 0.6946 - val_loss: 0.5770 - val_mae: 0.6400\n",
      "Epoch 5/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.5821 - mae: 0.6492 - val_loss: 0.5083 - val_mae: 0.5969\n",
      "Epoch 6/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.5064 - mae: 0.6022 - val_loss: 0.4394 - val_mae: 0.5507\n",
      "Epoch 7/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.4314 - mae: 0.5514 - val_loss: 0.3727 - val_mae: 0.5028\n",
      "Epoch 8/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.3592 - mae: 0.4995 - val_loss: 0.3100 - val_mae: 0.4535\n",
      "Epoch 9/100\n",
      "7/7 [==============================] - 0s 3ms/step - loss: 0.2943 - mae: 0.4453 - val_loss: 0.2492 - val_mae: 0.4008\n",
      "Epoch 10/100\n",
      "7/7 [==============================] - 0s 3ms/step - loss: 0.2308 - mae: 0.3875 - val_loss: 0.1950 - val_mae: 0.3475\n",
      "Epoch 11/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.1780 - mae: 0.3308 - val_loss: 0.1488 - val_mae: 0.2958\n",
      "Epoch 12/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.1327 - mae: 0.2755 - val_loss: 0.1106 - val_mae: 0.2488\n",
      "Epoch 13/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0971 - mae: 0.2278 - val_loss: 0.0811 - val_mae: 0.2099\n",
      "Epoch 14/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0696 - mae: 0.1919 - val_loss: 0.0593 - val_mae: 0.1785\n",
      "Epoch 15/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0503 - mae: 0.1650 - val_loss: 0.0435 - val_mae: 0.1528\n",
      "Epoch 16/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0365 - mae: 0.1431 - val_loss: 0.0330 - val_mae: 0.1341\n",
      "Epoch 17/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0274 - mae: 0.1273 - val_loss: 0.0252 - val_mae: 0.1191\n",
      "Epoch 18/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0214 - mae: 0.1147 - val_loss: 0.0198 - val_mae: 0.1075\n",
      "Epoch 19/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0169 - mae: 0.1044 - val_loss: 0.0161 - val_mae: 0.0983\n",
      "Epoch 20/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0136 - mae: 0.0948 - val_loss: 0.0138 - val_mae: 0.0915\n",
      "Epoch 21/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0114 - mae: 0.0873 - val_loss: 0.0122 - val_mae: 0.0864\n",
      "Epoch 22/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0096 - mae: 0.0801 - val_loss: 0.0110 - val_mae: 0.0819\n",
      "Epoch 23/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0083 - mae: 0.0742 - val_loss: 0.0101 - val_mae: 0.0779\n",
      "Epoch 24/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0073 - mae: 0.0689 - val_loss: 0.0093 - val_mae: 0.0744\n",
      "Epoch 25/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0065 - mae: 0.0647 - val_loss: 0.0087 - val_mae: 0.0717\n",
      "Epoch 26/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0060 - mae: 0.0613 - val_loss: 0.0082 - val_mae: 0.0699\n",
      "Epoch 27/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0055 - mae: 0.0587 - val_loss: 0.0078 - val_mae: 0.0684\n",
      "Epoch 28/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0051 - mae: 0.0563 - val_loss: 0.0075 - val_mae: 0.0670\n",
      "Epoch 29/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0048 - mae: 0.0546 - val_loss: 0.0072 - val_mae: 0.0660\n",
      "Epoch 30/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0045 - mae: 0.0529 - val_loss: 0.0070 - val_mae: 0.0647\n",
      "Epoch 31/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0043 - mae: 0.0511 - val_loss: 0.0068 - val_mae: 0.0636\n",
      "Epoch 32/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0041 - mae: 0.0498 - val_loss: 0.0066 - val_mae: 0.0627\n",
      "Epoch 33/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0040 - mae: 0.0487 - val_loss: 0.0064 - val_mae: 0.0614\n",
      "Epoch 34/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0038 - mae: 0.0475 - val_loss: 0.0062 - val_mae: 0.0604\n",
      "Epoch 35/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0037 - mae: 0.0468 - val_loss: 0.0061 - val_mae: 0.0596\n",
      "Epoch 36/100\n",
      "7/7 [==============================] - 0s 5ms/step - loss: 0.0036 - mae: 0.0462 - val_loss: 0.0059 - val_mae: 0.0590\n",
      "Epoch 37/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0035 - mae: 0.0455 - val_loss: 0.0057 - val_mae: 0.0580\n",
      "Epoch 38/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0034 - mae: 0.0450 - val_loss: 0.0056 - val_mae: 0.0573\n",
      "Epoch 39/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0032 - mae: 0.0442 - val_loss: 0.0055 - val_mae: 0.0565\n",
      "Epoch 40/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0032 - mae: 0.0434 - val_loss: 0.0054 - val_mae: 0.0558\n",
      "Epoch 41/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0031 - mae: 0.0427 - val_loss: 0.0053 - val_mae: 0.0554\n",
      "Epoch 42/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0030 - mae: 0.0423 - val_loss: 0.0051 - val_mae: 0.0547\n",
      "Epoch 43/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0029 - mae: 0.0416 - val_loss: 0.0050 - val_mae: 0.0541\n",
      "Epoch 44/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0028 - mae: 0.0409 - val_loss: 0.0049 - val_mae: 0.0536\n",
      "Epoch 45/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0028 - mae: 0.0404 - val_loss: 0.0048 - val_mae: 0.0529\n",
      "Epoch 46/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0027 - mae: 0.0399 - val_loss: 0.0047 - val_mae: 0.0523\n",
      "Epoch 47/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0026 - mae: 0.0395 - val_loss: 0.0046 - val_mae: 0.0521\n",
      "Epoch 48/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0026 - mae: 0.0392 - val_loss: 0.0046 - val_mae: 0.0514\n",
      "Epoch 49/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0025 - mae: 0.0385 - val_loss: 0.0044 - val_mae: 0.0505\n",
      "Epoch 50/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0024 - mae: 0.0378 - val_loss: 0.0043 - val_mae: 0.0498\n",
      "Epoch 51/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0024 - mae: 0.0371 - val_loss: 0.0043 - val_mae: 0.0494\n",
      "Epoch 52/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0023 - mae: 0.0365 - val_loss: 0.0042 - val_mae: 0.0491\n",
      "Epoch 53/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0022 - mae: 0.0364 - val_loss: 0.0041 - val_mae: 0.0486\n",
      "Epoch 54/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0022 - mae: 0.0359 - val_loss: 0.0040 - val_mae: 0.0475\n",
      "Epoch 55/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0021 - mae: 0.0352 - val_loss: 0.0039 - val_mae: 0.0468\n",
      "Epoch 56/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0021 - mae: 0.0348 - val_loss: 0.0038 - val_mae: 0.0465\n",
      "Epoch 57/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0020 - mae: 0.0346 - val_loss: 0.0038 - val_mae: 0.0461\n",
      "Epoch 58/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0020 - mae: 0.0343 - val_loss: 0.0037 - val_mae: 0.0458\n",
      "Epoch 59/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0019 - mae: 0.0338 - val_loss: 0.0036 - val_mae: 0.0451\n",
      "Epoch 60/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0019 - mae: 0.0333 - val_loss: 0.0035 - val_mae: 0.0448\n",
      "Epoch 61/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0018 - mae: 0.0332 - val_loss: 0.0035 - val_mae: 0.0443\n",
      "Epoch 62/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0018 - mae: 0.0326 - val_loss: 0.0034 - val_mae: 0.0437\n",
      "Epoch 63/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0017 - mae: 0.0322 - val_loss: 0.0033 - val_mae: 0.0434\n",
      "Epoch 64/100\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0017 - mae: 0.0320 - val_loss: 0.0033 - val_mae: 0.0433\n",
      "Epoch 65/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0017 - mae: 0.0316 - val_loss: 0.0032 - val_mae: 0.0428\n",
      "Epoch 66/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0016 - mae: 0.0312 - val_loss: 0.0032 - val_mae: 0.0423\n",
      "Epoch 67/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0016 - mae: 0.0311 - val_loss: 0.0031 - val_mae: 0.0420\n",
      "Epoch 68/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0016 - mae: 0.0310 - val_loss: 0.0031 - val_mae: 0.0417\n",
      "Epoch 69/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0015 - mae: 0.0306 - val_loss: 0.0030 - val_mae: 0.0413\n",
      "Epoch 70/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0015 - mae: 0.0302 - val_loss: 0.0030 - val_mae: 0.0408\n",
      "Epoch 71/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0015 - mae: 0.0299 - val_loss: 0.0029 - val_mae: 0.0405\n",
      "Epoch 72/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0015 - mae: 0.0302 - val_loss: 0.0029 - val_mae: 0.0406\n",
      "Epoch 73/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0015 - mae: 0.0299 - val_loss: 0.0029 - val_mae: 0.0397\n",
      "Epoch 74/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0014 - mae: 0.0292 - val_loss: 0.0028 - val_mae: 0.0396\n",
      "Epoch 75/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0014 - mae: 0.0292 - val_loss: 0.0028 - val_mae: 0.0396\n",
      "Epoch 76/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0014 - mae: 0.0289 - val_loss: 0.0027 - val_mae: 0.0392\n",
      "Epoch 77/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0014 - mae: 0.0288 - val_loss: 0.0027 - val_mae: 0.0388\n",
      "Epoch 78/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0013 - mae: 0.0284 - val_loss: 0.0027 - val_mae: 0.0383\n",
      "Epoch 79/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0013 - mae: 0.0280 - val_loss: 0.0027 - val_mae: 0.0382\n",
      "Epoch 80/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0013 - mae: 0.0278 - val_loss: 0.0026 - val_mae: 0.0382\n",
      "Epoch 81/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0013 - mae: 0.0278 - val_loss: 0.0026 - val_mae: 0.0381\n",
      "Epoch 82/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0013 - mae: 0.0278 - val_loss: 0.0026 - val_mae: 0.0377\n",
      "Epoch 83/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0012 - mae: 0.0274 - val_loss: 0.0025 - val_mae: 0.0374\n",
      "Epoch 84/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0012 - mae: 0.0271 - val_loss: 0.0025 - val_mae: 0.0372\n",
      "Epoch 85/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0012 - mae: 0.0271 - val_loss: 0.0025 - val_mae: 0.0373\n",
      "Epoch 86/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0012 - mae: 0.0273 - val_loss: 0.0025 - val_mae: 0.0372\n",
      "Epoch 87/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0012 - mae: 0.0269 - val_loss: 0.0025 - val_mae: 0.0369\n",
      "Epoch 88/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0012 - mae: 0.0266 - val_loss: 0.0025 - val_mae: 0.0367\n",
      "Epoch 89/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0012 - mae: 0.0261 - val_loss: 0.0024 - val_mae: 0.0363\n",
      "Epoch 90/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0012 - mae: 0.0262 - val_loss: 0.0024 - val_mae: 0.0365\n",
      "Epoch 91/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0011 - mae: 0.0261 - val_loss: 0.0024 - val_mae: 0.0364\n",
      "Epoch 92/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0011 - mae: 0.0262 - val_loss: 0.0024 - val_mae: 0.0363\n",
      "Epoch 93/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0011 - mae: 0.0259 - val_loss: 0.0024 - val_mae: 0.0358\n",
      "Epoch 94/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0011 - mae: 0.0254 - val_loss: 0.0024 - val_mae: 0.0357\n",
      "Epoch 95/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0011 - mae: 0.0251 - val_loss: 0.0023 - val_mae: 0.0358\n",
      "Epoch 96/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0011 - mae: 0.0256 - val_loss: 0.0023 - val_mae: 0.0359\n",
      "Epoch 97/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0011 - mae: 0.0256 - val_loss: 0.0023 - val_mae: 0.0354\n",
      "Epoch 98/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0011 - mae: 0.0252 - val_loss: 0.0023 - val_mae: 0.0355\n",
      "Epoch 99/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0011 - mae: 0.0252 - val_loss: 0.0023 - val_mae: 0.0352\n",
      "Epoch 100/100\n",
      "7/7 [==============================] - 0s 4ms/step - loss: 0.0011 - mae: 0.0247 - val_loss: 0.0023 - val_mae: 0.0347\n"
     ]
    }
   ],
   "source": [
    "t0 = time.time()\n",
    "model.fit(X_train, Y_train, epochs=100, batch_size=32, validation_data=(X_test, Y_test))\n",
    "t1 = time.time()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "81f35c9a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4/4 [==============================] - 0s 599us/step - loss: 0.0023 - mae: 0.0347\n",
      "[0.0022503482177853584, 0.03470602631568909]\n",
      "2.8884756565093994\n"
     ]
    }
   ],
   "source": [
    "#use the SOURCE test data to evaluate the model\n",
    "loss_and_metrics = model.evaluate(X_test, Y_test, batch_size=256)\n",
    "print(loss_and_metrics)\n",
    "\n",
    "print(t1-t0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e569d35c",
   "metadata": {},
   "source": [
    "# construct the NN after feature transformation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d659a2af",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = load_model(\"pretrained_model.h5\")\n",
    "X_train_S = np.load('source_input_data.npy')\n",
    "y_train_S = np.load('source_output_data.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "4687c4cf",
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "all the input array dimensions for the concatenation axis must match exactly, but along dimension 1, the array at index 0 has size 2 and the array at index 1 has size 3",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "Input \u001b[0;32mIn [11]\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;66;03m#combine source and target for training\u001b[39;00m\n\u001b[1;32m      2\u001b[0m \u001b[38;5;66;03m#first 640 target, 3995 source\u001b[39;00m\n\u001b[0;32m----> 3\u001b[0m RNN_input_final_Train \u001b[38;5;241m=\u001b[39m \u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconcatenate\u001b[49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mX_train_S\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maxis\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m \n\u001b[1;32m      4\u001b[0m RNN_output_final_Train \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mconcatenate((y_train, y_train_S), axis\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m      6\u001b[0m NN_S \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mpredict(X_train_S)\n",
      "File \u001b[0;32m<__array_function__ internals>:180\u001b[0m, in \u001b[0;36mconcatenate\u001b[0;34m(*args, **kwargs)\u001b[0m\n",
      "\u001b[0;31mValueError\u001b[0m: all the input array dimensions for the concatenation axis must match exactly, but along dimension 1, the array at index 0 has size 2 and the array at index 1 has size 3"
     ]
    }
   ],
   "source": [
    "#combine source and target for training\n",
    "#first 640 target, 3995 source\n",
    "RNN_input_final_Train = np.concatenate((X_train, X_train_S), axis=0) \n",
    "RNN_output_final_Train = np.concatenate((y_train, y_train_S), axis=0)\n",
    "\n",
    "NN_S = model.predict(X_train_S)\n",
    "\n",
    "#performance of the pre-trained model on train data\n",
    "loss_and_metrics = model.evaluate(X_train_S, y_train_S, batch_size=256)\n",
    "loss_and_metrics = model.evaluate(X_train, y_train, batch_size=256)\n",
    "\n",
    "loss_and_metrics = model.evaluate(RNN_input_final_Train, RNN_output_final_Train, batch_size=256)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2c61d08",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 2: Freeze the original model's weights\n",
    "for layer in model.layers:\n",
    "    layer.trainable = False\n",
    "\n",
    "# Step 3: Create new layers\n",
    "input_layer = Input(shape=(None, 4))  # Replace input_dim with your actual input feature dimension\n",
    "new_input_layer = Dense(8, activation='linear', use_bias=False)  # New front layer\n",
    "new_input_layer1 = Dense(8, activation='linear', use_bias=False)  # New front layer\n",
    "new_input_layer2 = Dense(4, activation='linear', use_bias=False)  # New front layer\n",
    "\n",
    "# Connect the frozen original model\n",
    "x = model(new_input_layer2(new_input_layer1(new_input_layer(input_layer))))\n",
    "\n",
    "# Add new output layer\n",
    "new_output_layer = Dense(4, activation='linear', use_bias=False)  # Replace output_dim with your desired final output size\n",
    "new_output_layer1 = Dense(4, activation='linear', use_bias=False)  # Replace output_dim with your desired final output size\n",
    "new_output_layer2 = Dense(3, activation='linear', use_bias=False)  # Replace output_dim with your desired final output size\n",
    "\n",
    "output = new_output_layer2(new_output_layer1(new_output_layer(x)))\n",
    "# Step 4: Assemble final model\n",
    "final_model = Model(inputs=input_layer, outputs=output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9a742b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "#assign parameters\n",
    "# Define A: 4x8 matrix [I_4 | 0]\n",
    "A = np.array([\n",
    "    [1, 0, 0, 0, 0, 0, 0, 0],\n",
    "    [0, 1, 0, 0, 0, 0, 0, 0],\n",
    "    [0, 0, 1, 0, 0, 0, 0, 0],\n",
    "    [0, 0, 0, 1, 0, 0, 0, 0]\n",
    "], dtype=np.float32)\n",
    "\n",
    "# Define B: 8x8 identity matrix\n",
    "B = np.array(np.eye(8), dtype=np.float32)\n",
    "\n",
    "# Define C: 8x4 matrix [I_4; 0]\n",
    "C = np.array([\n",
    "    [1, 0, 0, 0],\n",
    "    [0, 1, 0, 0],\n",
    "    [0, 0, 1, 0],\n",
    "    [0, 0, 0, 1],\n",
    "    [0, 0, 0, 0],\n",
    "    [0, 0, 0, 0],\n",
    "    [0, 0, 0, 0],\n",
    "    [0, 0, 0, 0]\n",
    "], dtype=np.float32)\n",
    "\n",
    "new_input_layer.set_weights([A])\n",
    "new_input_layer1.set_weights([B])\n",
    "new_input_layer2.set_weights([C])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5792001",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A: 3x4 matrix [I_2 | 0]\n",
    "Ao = np.array([\n",
    "    [1, 0, 0, 0],\n",
    "    [0, 1, 0, 0],\n",
    "    [0, 0, 1, 0]\n",
    "], dtype=np.float32)\n",
    "\n",
    "# B: 4x4 identity\n",
    "Bo = np.array(np.eye(4), dtype=np.float32)\n",
    "\n",
    "# C: 4x3 matrix [I_2; 0]\n",
    "Co = np.array([\n",
    "    [1, 0, 0],\n",
    "    [0, 1, 0],\n",
    "    [0, 0, 1],\n",
    "    [0, 0, 0]\n",
    "], dtype=np.float32)\n",
    "\n",
    "new_output_layer.set_weights([Ao])\n",
    "new_output_layer1.set_weights([Bo])\n",
    "new_output_layer2.set_weights([Co])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b514281",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Step 5: Custom loss function\n",
    "def custom_loss(y_true, y_pred):\n",
    "    Num = 500 #number of training target data\n",
    "    NN_out = final_model(RNN_input_final_Train) #[[t1,ca1,cb1],[t2,ca2,cb2]]\n",
    "    \n",
    "    #first term\n",
    "    loss11 =  tf.math.reduce_mean((NN_out[:Num,:,:]-RNN_output_final_Train[:Num,:,:])**2) # prediction on target\n",
    "    loss12 =  tf.math.reduce_mean((NN_out[Num:,:,:]-RNN_output_final_Train[Num:,:,:])**2) # prediction on source\n",
    "    \n",
    "    loss1 = abs(loss11 - loss12) \n",
    "    \n",
    "    #second term\n",
    "    \n",
    "    loss2 =  tf.math.reduce_mean((NN_out[Num:,:,:]- NN_S)**2) # prediction error for h and h^* on source\n",
    "    \n",
    "    #last term\n",
    "    max_abs_input = tf.reduce_max(tf.abs(input_layer_weights))\n",
    "    max_abs_output = tf.reduce_max(tf.abs(output_layer_weights))\n",
    "    loss3 =  max_abs_input * max_abs_output\n",
    "    \n",
    "    #weight\n",
    "    a = 0\n",
    "    b = 0.01\n",
    "    c = 0.01\n",
    "    \n",
    "    loss = 1.0*loss11+a*loss1+b*loss2+c*loss3#100*loss11+a*loss1+b*loss2+c*loss3\n",
    "    \n",
    "    return loss  # you can modify this to include regularization if needed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91d741df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 6: Compile\n",
    "#final_model.compile(optimizer='adam', loss=custom_loss, metrics=['mse'])\n",
    "final_model.compile(optimizer='adam', loss=custom_loss, metrics=['mse'])\n",
    "print(final_model.summary())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2d89c02",
   "metadata": {},
   "outputs": [],
   "source": [
    "t0 = time.time()\n",
    "for i in range(100):\n",
    "    Num = 500 #number of training target data\n",
    "     # Access the layers\n",
    "    input_layer_weights = final_model.layers[1].trainable_weights[0]  # weight matrix only\n",
    "    output_layer_weights = final_model.layers[-1].trainable_weights[0]\n",
    "    \n",
    "    #prediction performance on the target training set\n",
    "    NN_out = final_model(RNN_input_final_Train) \n",
    "    loss11 =  tf.math.reduce_mean((NN_out[:Num,:,:]-RNN_output_final_Train[:Num,:,:])**2) # prediction on target\n",
    "    print(\"iteration :\" + str(i) + \"     Target train: \" + str(loss11))\n",
    "    \n",
    "    #history = final_model.fit(RNN_input_final_Train, RNN_output_final_Train, epochs=1, batch_size=256, validation_split=0.25, verbose=2)\n",
    "    history = final_model.fit(X_train, y_train, epochs=1, batch_size=256, validation_split=0.25, verbose=2)\n",
    "    i += 1\n",
    "\n",
    "t1 = time.time()\n",
    "print(t1 - t0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f14a2c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "#prediction performance on the target training set\n",
    "NN_out = final_model(RNN_input_final_Train) \n",
    "loss11 =  tf.math.reduce_mean((NN_out[:Num,:,:]-RNN_output_final_Train[:Num,:,:])**2) # prediction on target\n",
    "print(\"Target train: \", loss11)\n",
    "\n",
    "#prediction performance on the target testing set\n",
    "NN_out = final_model(X_test) \n",
    "loss11 =  tf.math.reduce_mean((NN_out-y_test)**2) # prediction on target\n",
    "print(\"Target test: \", loss11)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21ca0c99",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8016b92",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "647daff1",
   "metadata": {},
   "source": [
    "# combine the parameters and refine the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "523000d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input side of final_model: fetch the layers at indices 1, 2 and 3\n",
    "# (index 0 is skipped; presumably it is the input layer - TODO confirm).\n",
    "layer_I1, layer_I2, layer_I3 = (final_model.layers[idx] for idx in (1, 2, 3))\n",
    "\n",
    "# get_weights() returns each layer's parameters as a list of NumPy arrays;\n",
    "# later cells read element 0 of each list as the layer's kernel.\n",
    "weights_I1, weights_I2, weights_I3 = (\n",
    "    layer_I1.get_weights(),\n",
    "    layer_I2.get_weights(),\n",
    "    layer_I3.get_weights(),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b90b3b96",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output side of final_model: fetch the last three layers, in forward order.\n",
    "layer_P1, layer_P2, layer_P3 = (final_model.layers[idx] for idx in (-3, -2, -1))\n",
    "\n",
    "# get_weights() returns each layer's parameters as a list of NumPy arrays.\n",
    "weights_P1, weights_P2, weights_P3 = (\n",
    "    layer_P1.get_weights(),\n",
    "    layer_P2.get_weights(),\n",
    "    layer_P3.get_weights(),\n",
    ")\n",
    "\n",
    "# Show the full weight list of each output-side layer.\n",
    "print(\"Weights (kernel_P1):\\n\", weights_P1)\n",
    "print(\"Weights (kernel_P2):\\n\", weights_P2)\n",
    "print(\"Weights (kernel_P3):\\n\", weights_P3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2032c68",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Element 0 of each get_weights() list is taken as that layer's kernel matrix.\n",
    "kernel_I1, kernel_I2, kernel_I3 = weights_I1[0], weights_I2[0], weights_I3[0]\n",
    "kernel_P1, kernel_P2, kernel_P3 = weights_P1[0], weights_P2[0], weights_P3[0]\n",
    "\n",
    "# Input-side kernels\n",
    "print(\"Weights (kernel_I1):\\n\", kernel_I1)\n",
    "print(\"Weights (kernel_I2):\\n\", kernel_I2)\n",
    "print(\"Weights (kernel_I3):\\n\", kernel_I3)\n",
    "\n",
    "# Output-side kernels\n",
    "print(\"Weights (kernel_P1):\\n\", kernel_P1)\n",
    "print(\"Weights (kernel_P2):\\n\", kernel_P2)\n",
    "print(\"Weights (kernel_P3):\\n\", kernel_P3)\n",
    "\n",
    "# Transformation matrices: multiply each group of three kernels left to right.\n",
    "# NOTE(review): this equals the stacked layers only if they are linear and bias-free - confirm.\n",
    "P = kernel_I1 @ kernel_I2 @ kernel_I3\n",
    "Q = kernel_P1 @ kernel_P2 @ kernel_P3\n",
    "\n",
    "print(\"Weights Final input P:\\n\", P)\n",
    "print(\"Weights Final output Q:\\n\", Q)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63fb1c5d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "c51eb080",
   "metadata": {},
   "source": [
    "# Refine the model and set the weight parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "409049a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the pre-trained network fixed: mark each of its layers as non-trainable.\n",
    "for layer in model.layers:\n",
    "    layer.trainable = False\n",
    "\n",
    "# Front of the new model: an input with 4 features on the last axis, followed by a\n",
    "# bias-free linear Dense layer with 4 units.\n",
    "# NOTE(review): the None axis is presumably a variable-length time dimension - confirm.\n",
    "input_layer = Input(shape=(None, 4))\n",
    "new_input_layer = Dense(units=4, use_bias=False, activation='linear')\n",
    "\n",
    "# Pass the projected input through the frozen pre-trained network.\n",
    "x = new_input_layer(input_layer)\n",
    "x = model(x)\n",
    "\n",
    "# Back of the new model: a bias-free linear Dense layer with 3 units.\n",
    "new_output_layer = Dense(units=3, use_bias=False, activation='linear')\n",
    "output = new_output_layer(x)\n",
    "\n",
    "# Assemble the wrapped model: input -> new front layer -> frozen model -> new back layer.\n",
    "final_model_PQ = Model(inputs=input_layer, outputs=output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea01d0d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make float32 copies of the combined matrices P and Q, then load each one as the\n",
    "# only weight (the kernel) of the bias-free front and back Dense layers.\n",
    "# NOTE(review): set_weights needs each array to match the layer's kernel shape - confirm P and Q do.\n",
    "W1, W2 = (np.array(matrix, dtype=np.float32) for matrix in (P, Q))\n",
    "\n",
    "new_input_layer.set_weights([W1])\n",
    "new_output_layer.set_weights([W2])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc27c3d2",
   "metadata": {},
   "source": [
    "# Fine-tune the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be147352",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 1: Unfreeze every layer of the model that is about to be fine-tuned.\n",
    "# The loop targets final_model_PQ itself (the model compiled and fitted below) rather\n",
    "# than the earlier final_model, so the flags are applied to the layers that are\n",
    "# actually trained. This is done before compile() so the compiled model picks it up.\n",
    "for layer in final_model_PQ.layers:\n",
    "    layer.trainable = True\n",
    "\n",
    "# Step 2: Compile with MSE as both the loss and the reported metric.\n",
    "final_model_PQ.compile(optimizer='adam', loss='mse', metrics=['mse'])\n",
    "\n",
    "# summary() prints the table itself and returns None, so it is not wrapped in print().\n",
    "final_model_PQ.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55fbdca2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fine-tune final_model_PQ and print the wall-clock duration of the run in seconds.\n",
    "t0 = time.time()\n",
    "history = final_model_PQ.fit(\n",
    "    X_train,\n",
    "    y_train,\n",
    "    validation_split=0.25,\n",
    "    batch_size=256,\n",
    "    epochs=400,\n",
    "    verbose=2,\n",
    ")\n",
    "t1 = time.time()\n",
    "print(t1 - t0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "388b3084",
   "metadata": {},
   "outputs": [],
   "source": [
    "#performance of the pre-trained model on train data\n",
    "loss_and_metrics = final_model_PQ.evaluate(X_train, y_train, batch_size=256)\n",
    "loss_and_metrics = final_model_PQ.evaluate(X_test, y_test, batch_size=256)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9ed7589",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
