{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "12cffb75",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from tensorflow.keras.callbacks import EarlyStopping\n",
    "from tensorflow.keras.layers import Dense\n",
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow.keras.optimizers import Adam"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ab9a2a2",
   "metadata": {},
   "source": [
    "# Load data and build one-step-ahead (t -> t+1) training pairs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d6c0252c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Keys in file: ['measured_angles', 'measured_velocities', 'constrained_torques', 'measured_torques', 'desired_torques']\n",
      "Input shape: (139990, 12)\n",
      "Output shape: (139990, 9)\n"
     ]
    }
   ],
   "source": [
    "# Load the simulated sine-trajectory dataset.\n",
    "# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can\n",
    "# execute arbitrary code. Only open trusted files; if every array in the archive\n",
    "# is plain numeric, allow_pickle=False is the safer setting (TODO confirm).\n",
    "data = np.load('Sim_Sines_training.npz', allow_pickle=True)\n",
    "\n",
    "keys = list(data.keys())\n",
    "print(\"Keys in file:\", keys)\n",
    "\n",
    "# Look the arrays up by name instead of by position, so that a different key\n",
    "# order in the archive cannot silently swap states and inputs.\n",
    "s1 = data['measured_angles']       # state part 1\n",
    "s2 = data['measured_velocities']   # state part 2\n",
    "s3 = data['measured_torques']      # state part 3\n",
    "u1 = data['constrained_torques']   # control input\n",
    "# Per the original author's note, constrained_torques (the desired torques under\n",
    "# the actual conditions) is the control input to use, so 'desired_torques' is\n",
    "# intentionally left out of the model input.\n",
    "\n",
    "# Number of leading trajectories used to build the training set.\n",
    "N_TRAJ = 10\n",
    "\n",
    "# Model input at time t: [s1, s2, s3, u1] -> (n_traj, n_steps, n_state + n_input)\n",
    "X_full = np.concatenate([s1, s2, s3, u1], axis=-1)\n",
    "\n",
    "# Prediction target (state only): [s1, s2, s3] -> (n_traj, n_steps, n_state)\n",
    "Y_full = np.concatenate([s1, s2, s3], axis=-1)\n",
    "\n",
    "# One-step-ahead pairs: the input at step t predicts the state at step t+1.\n",
    "# The last step of each trajectory is dropped from X and the first from Y, so no\n",
    "# pair ever straddles two trajectories. Flattening the first two axes stacks the\n",
    "# selected trajectories row-wise, in trajectory order.\n",
    "X = X_full[:N_TRAJ, :-1, :].reshape(-1, X_full.shape[-1])\n",
    "Y = Y_full[:N_TRAJ, 1:, :].reshape(-1, Y_full.shape[-1])\n",
    "\n",
    "assert X.shape[0] == Y.shape[0], \"inputs and targets must pair up one-to-one\"\n",
    "\n",
    "print(\"Input shape:\", X.shape)\n",
    "print(\"Output shape:\", Y.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1fe899da",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[[ 3.81302348  0.55230731 -7.5266488 ]\n",
      "  [ 3.92884146  0.54152433 -7.6478612 ]\n",
      "  [ 4.0448998   0.51713602 -7.80837187]\n",
      "  ...\n",
      "  [-2.14634315 -4.76733918  1.33802923]\n",
      "  [-2.02542057 -4.60942459  1.61554924]\n",
      "  [-1.90369532 -4.45156872  1.84981875]]\n",
      "\n",
      " [[-4.08737776  0.55541261  1.99359253]\n",
      "  [-4.12254459  0.5437961   1.88245269]\n",
      "  [-4.15183704  0.54065488  1.8161306 ]\n",
      "  ...\n",
      "  [-2.64704035 -0.42430382 -8.07404414]\n",
      "  [-2.66215526 -0.44181809 -8.10370186]\n",
      "  [-2.68181581 -0.45896841 -8.13110097]]\n",
      "\n",
      " [[-2.60306997 -1.28429582  2.26896105]\n",
      "  [-2.63863587 -1.2913711   2.21451101]\n",
      "  [-2.66717424 -1.33342341  2.0056503 ]\n",
      "  ...\n",
      "  [ 6.94401041  2.03224586  9.01827505]\n",
      "  [ 7.00832266  2.04739129  9.00270724]\n",
      "  [ 7.06917556  2.05518406  8.94988037]]\n",
      "\n",
      " ...\n",
      "\n",
      " [[-4.06145284  2.80815936 -2.07664389]\n",
      "  [-4.13501922  2.82782944 -1.99197465]\n",
      "  [-4.20244374  2.81397913 -2.03754198]\n",
      "  ...\n",
      "  [ 1.8320504   3.01614985 -3.76463294]\n",
      "  [ 1.75331207  3.02121548 -3.73207246]\n",
      "  [ 1.67587402  2.99829565 -3.81785218]]\n",
      "\n",
      " [[-4.34765988  4.02040619  8.80277014]\n",
      "  [-4.31225372  4.04560725  8.9077988 ]\n",
      "  [-4.27081855  4.0578519   8.95835223]\n",
      "  ...\n",
      "  [-4.05939556  3.78106161 10.26983934]\n",
      "  [-4.10561236  3.72050306 10.00781604]\n",
      "  [-4.15231029  3.7079233   9.95356473]]\n",
      "\n",
      " [[ 4.21522984  1.14215715  1.85413865]\n",
      "  [ 4.16271235  1.13315188  1.82503104]\n",
      "  [ 4.10191239  1.11733531  1.75662472]\n",
      "  ...\n",
      "  [-0.03519334 -1.14349966 -0.76394112]\n",
      "  [ 0.08858292 -1.17758533 -0.81524398]\n",
      "  [ 0.21148916 -1.20726343 -0.84966876]]]\n"
     ]
    }
   ],
   "source": [
    "# Inspect one array by name: report its shape and dtype and preview a few rows\n",
    "# of the first trajectory instead of dumping the whole 3-D array.\n",
    "velocities = data['measured_velocities']\n",
    "print(velocities.shape, velocities.dtype)\n",
    "print(velocities[0, :5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ebb449f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/1000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-07-18 19:52:56.006269: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\n",
      "2025-07-18 19:52:56.006297: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (wulab2-System-Product-Name): /proc/driver/nvidia/version does not exist\n",
      "2025-07-18 19:52:56.006949: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA\n",
      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "443/443 [==============================] - 1s 729us/step - loss: 0.1031 - mse: 0.1031 - val_loss: 0.0119 - val_mse: 0.0119\n",
      "Epoch 2/1000\n",
      "443/443 [==============================] - 0s 615us/step - loss: 0.0098 - mse: 0.0098 - val_loss: 0.0089 - val_mse: 0.0089\n",
      "Epoch 3/1000\n",
      "443/443 [==============================] - 0s 589us/step - loss: 0.0077 - mse: 0.0077 - val_loss: 0.0072 - val_mse: 0.0072\n",
      "Epoch 4/1000\n",
      "443/443 [==============================] - 0s 587us/step - loss: 0.0066 - mse: 0.0066 - val_loss: 0.0063 - val_mse: 0.0063\n",
      "Epoch 5/1000\n",
      "443/443 [==============================] - 0s 588us/step - loss: 0.0059 - mse: 0.0059 - val_loss: 0.0059 - val_mse: 0.0059\n",
      "Epoch 6/1000\n",
      "443/443 [==============================] - 0s 583us/step - loss: 0.0056 - mse: 0.0056 - val_loss: 0.0057 - val_mse: 0.0057\n",
      "Epoch 7/1000\n",
      "443/443 [==============================] - 0s 606us/step - loss: 0.0055 - mse: 0.0055 - val_loss: 0.0056 - val_mse: 0.0056\n",
      "Epoch 8/1000\n",
      "443/443 [==============================] - 0s 609us/step - loss: 0.0055 - mse: 0.0055 - val_loss: 0.0056 - val_mse: 0.0056\n",
      "Epoch 9/1000\n",
      "443/443 [==============================] - 0s 595us/step - loss: 0.0054 - mse: 0.0054 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 10/1000\n",
      "443/443 [==============================] - 0s 584us/step - loss: 0.0054 - mse: 0.0054 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 11/1000\n",
      "443/443 [==============================] - 0s 664us/step - loss: 0.0054 - mse: 0.0054 - val_loss: 0.0056 - val_mse: 0.0056\n",
      "Epoch 12/1000\n",
      "443/443 [==============================] - 0s 609us/step - loss: 0.0054 - mse: 0.0054 - val_loss: 0.0057 - val_mse: 0.0057\n",
      "Epoch 13/1000\n",
      "443/443 [==============================] - 0s 591us/step - loss: 0.0054 - mse: 0.0054 - val_loss: 0.0056 - val_mse: 0.0056\n",
      "Epoch 14/1000\n",
      "443/443 [==============================] - 0s 600us/step - loss: 0.0054 - mse: 0.0054 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 15/1000\n",
      "443/443 [==============================] - 0s 601us/step - loss: 0.0054 - mse: 0.0054 - val_loss: 0.0056 - val_mse: 0.0056\n",
      "Epoch 16/1000\n",
      "443/443 [==============================] - 0s 584us/step - loss: 0.0054 - mse: 0.0054 - val_loss: 0.0056 - val_mse: 0.0056\n",
      "Epoch 17/1000\n",
      "443/443 [==============================] - 0s 593us/step - loss: 0.0054 - mse: 0.0054 - val_loss: 0.0056 - val_mse: 0.0056\n",
      "Epoch 18/1000\n",
      "443/443 [==============================] - 0s 595us/step - loss: 0.0054 - mse: 0.0054 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 19/1000\n",
      "443/443 [==============================] - 0s 584us/step - loss: 0.0054 - mse: 0.0054 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 20/1000\n",
      "443/443 [==============================] - 0s 586us/step - loss: 0.0054 - mse: 0.0054 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 21/1000\n",
      "443/443 [==============================] - 0s 582us/step - loss: 0.0054 - mse: 0.0054 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 22/1000\n",
      "443/443 [==============================] - 0s 589us/step - loss: 0.0054 - mse: 0.0054 - val_loss: 0.0056 - val_mse: 0.0056\n",
      "Epoch 23/1000\n",
      "443/443 [==============================] - 0s 599us/step - loss: 0.0054 - mse: 0.0054 - val_loss: 0.0057 - val_mse: 0.0057\n",
      "Epoch 24/1000\n",
      "443/443 [==============================] - 0s 592us/step - loss: 0.0054 - mse: 0.0054 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 25/1000\n",
      "443/443 [==============================] - 0s 595us/step - loss: 0.0054 - mse: 0.0054 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 26/1000\n",
      "443/443 [==============================] - 0s 592us/step - loss: 0.0054 - mse: 0.0054 - val_loss: 0.0056 - val_mse: 0.0056\n",
      "Epoch 27/1000\n",
      "443/443 [==============================] - 0s 597us/step - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 28/1000\n",
      "443/443 [==============================] - 0s 596us/step - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 29/1000\n",
      "443/443 [==============================] - 0s 596us/step - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 30/1000\n",
      "443/443 [==============================] - 0s 611us/step - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0056 - val_mse: 0.0056\n",
      "Epoch 31/1000\n",
      "443/443 [==============================] - 0s 596us/step - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 32/1000\n",
      "443/443 [==============================] - 0s 597us/step - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 33/1000\n",
      "443/443 [==============================] - 0s 605us/step - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0054 - val_mse: 0.0054\n",
      "Epoch 34/1000\n",
      "443/443 [==============================] - 0s 615us/step - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 35/1000\n",
      "443/443 [==============================] - 0s 599us/step - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0056 - val_mse: 0.0056\n",
      "Epoch 36/1000\n",
      "443/443 [==============================] - 0s 594us/step - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 37/1000\n",
      "443/443 [==============================] - 0s 600us/step - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0054 - val_mse: 0.0054\n",
      "Epoch 38/1000\n",
      "443/443 [==============================] - 0s 590us/step - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 39/1000\n",
      "443/443 [==============================] - 0s 597us/step - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 40/1000\n",
      "443/443 [==============================] - 0s 601us/step - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 41/1000\n",
      "443/443 [==============================] - 0s 584us/step - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0056 - val_mse: 0.0056\n",
      "Epoch 42/1000\n",
      "443/443 [==============================] - 0s 602us/step - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 43/1000\n",
      "443/443 [==============================] - 0s 606us/step - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 44/1000\n",
      "443/443 [==============================] - 0s 582us/step - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 45/1000\n",
      "443/443 [==============================] - 0s 604us/step - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 46/1000\n",
      "443/443 [==============================] - 0s 604us/step - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0054 - val_mse: 0.0054\n",
      "Epoch 47/1000\n",
      "443/443 [==============================] - 0s 603us/step - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 48/1000\n",
      "443/443 [==============================] - 0s 598us/step - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0054 - val_mse: 0.0054\n",
      "Epoch 49/1000\n",
      "443/443 [==============================] - 0s 604us/step - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 50/1000\n",
      "443/443 [==============================] - 0s 604us/step - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0054 - val_mse: 0.0054\n",
      "Epoch 51/1000\n",
      "443/443 [==============================] - 0s 600us/step - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 52/1000\n",
      "443/443 [==============================] - 0s 601us/step - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 53/1000\n",
      "443/443 [==============================] - 0s 607us/step - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0054 - val_mse: 0.0054\n",
      "Epoch 54/1000\n",
      "443/443 [==============================] - 0s 605us/step - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0054 - val_mse: 0.0054\n",
      "Epoch 55/1000\n",
      "443/443 [==============================] - 0s 605us/step - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 56/1000\n",
      "443/443 [==============================] - 0s 583us/step - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0054 - val_mse: 0.0054\n",
      "Epoch 57/1000\n",
      "443/443 [==============================] - 0s 615us/step - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0054 - val_mse: 0.0054\n",
      "Epoch 58/1000\n",
      "443/443 [==============================] - 0s 598us/step - loss: 0.0052 - mse: 0.0052 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 59/1000\n",
      "443/443 [==============================] - 0s 602us/step - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0054 - val_mse: 0.0054\n",
      "Epoch 60/1000\n",
      "443/443 [==============================] - 0s 597us/step - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0055 - val_mse: 0.0055\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 61/1000\n",
      "443/443 [==============================] - 0s 597us/step - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 62/1000\n",
      "443/443 [==============================] - 0s 594us/step - loss: 0.0053 - mse: 0.0053 - val_loss: 0.0054 - val_mse: 0.0054\n",
      "Epoch 63/1000\n",
      "443/443 [==============================] - 0s 604us/step - loss: 0.0052 - mse: 0.0052 - val_loss: 0.0054 - val_mse: 0.0054\n",
      "Epoch 64/1000\n",
      "443/443 [==============================] - 0s 593us/step - loss: 0.0052 - mse: 0.0052 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 65/1000\n",
      "443/443 [==============================] - 0s 616us/step - loss: 0.0052 - mse: 0.0052 - val_loss: 0.0053 - val_mse: 0.0053\n",
      "Epoch 66/1000\n",
      "443/443 [==============================] - 0s 601us/step - loss: 0.0052 - mse: 0.0052 - val_loss: 0.0054 - val_mse: 0.0054\n",
      "Epoch 67/1000\n",
      "443/443 [==============================] - 0s 585us/step - loss: 0.0052 - mse: 0.0052 - val_loss: 0.0054 - val_mse: 0.0054\n",
      "Epoch 68/1000\n",
      "443/443 [==============================] - 0s 600us/step - loss: 0.0052 - mse: 0.0052 - val_loss: 0.0054 - val_mse: 0.0054\n",
      "Epoch 69/1000\n",
      "443/443 [==============================] - 0s 593us/step - loss: 0.0052 - mse: 0.0052 - val_loss: 0.0054 - val_mse: 0.0054\n",
      "Epoch 70/1000\n",
      "443/443 [==============================] - 0s 590us/step - loss: 0.0052 - mse: 0.0052 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 71/1000\n",
      "443/443 [==============================] - 0s 600us/step - loss: 0.0052 - mse: 0.0052 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 72/1000\n",
      "443/443 [==============================] - 0s 606us/step - loss: 0.0052 - mse: 0.0052 - val_loss: 0.0054 - val_mse: 0.0054\n",
      "Epoch 73/1000\n",
      "443/443 [==============================] - 0s 599us/step - loss: 0.0052 - mse: 0.0052 - val_loss: 0.0054 - val_mse: 0.0054\n",
      "Epoch 74/1000\n",
      "443/443 [==============================] - 0s 600us/step - loss: 0.0052 - mse: 0.0052 - val_loss: 0.0054 - val_mse: 0.0054\n",
      "Epoch 75/1000\n",
      "443/443 [==============================] - 0s 597us/step - loss: 0.0052 - mse: 0.0052 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 76/1000\n",
      "443/443 [==============================] - 0s 609us/step - loss: 0.0052 - mse: 0.0052 - val_loss: 0.0054 - val_mse: 0.0054\n",
      "Epoch 77/1000\n",
      "443/443 [==============================] - 0s 603us/step - loss: 0.0052 - mse: 0.0052 - val_loss: 0.0054 - val_mse: 0.0054\n",
      "Epoch 78/1000\n",
      "443/443 [==============================] - 0s 595us/step - loss: 0.0052 - mse: 0.0052 - val_loss: 0.0054 - val_mse: 0.0054\n",
      "Epoch 79/1000\n",
      "443/443 [==============================] - 0s 600us/step - loss: 0.0052 - mse: 0.0052 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 80/1000\n",
      "443/443 [==============================] - 0s 600us/step - loss: 0.0052 - mse: 0.0052 - val_loss: 0.0053 - val_mse: 0.0053\n",
      "Epoch 81/1000\n",
      "443/443 [==============================] - 0s 603us/step - loss: 0.0052 - mse: 0.0052 - val_loss: 0.0053 - val_mse: 0.0053\n",
      "Epoch 82/1000\n",
      "443/443 [==============================] - 0s 611us/step - loss: 0.0052 - mse: 0.0052 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 83/1000\n",
      "443/443 [==============================] - 0s 602us/step - loss: 0.0052 - mse: 0.0052 - val_loss: 0.0053 - val_mse: 0.0053\n",
      "Epoch 84/1000\n",
      "443/443 [==============================] - 0s 596us/step - loss: 0.0052 - mse: 0.0052 - val_loss: 0.0053 - val_mse: 0.0053\n",
      "Epoch 85/1000\n",
      "443/443 [==============================] - 0s 600us/step - loss: 0.0052 - mse: 0.0052 - val_loss: 0.0054 - val_mse: 0.0054\n",
      "Epoch 86/1000\n",
      "443/443 [==============================] - 0s 598us/step - loss: 0.0052 - mse: 0.0052 - val_loss: 0.0053 - val_mse: 0.0053\n",
      "Epoch 87/1000\n",
      "443/443 [==============================] - 0s 594us/step - loss: 0.0052 - mse: 0.0052 - val_loss: 0.0054 - val_mse: 0.0054\n",
      "Epoch 88/1000\n",
      "443/443 [==============================] - 0s 605us/step - loss: 0.0052 - mse: 0.0052 - val_loss: 0.0054 - val_mse: 0.0054\n",
      "Epoch 89/1000\n",
      "443/443 [==============================] - 0s 613us/step - loss: 0.0052 - mse: 0.0052 - val_loss: 0.0053 - val_mse: 0.0053\n",
      "Epoch 90/1000\n",
      "443/443 [==============================] - 0s 607us/step - loss: 0.0052 - mse: 0.0052 - val_loss: 0.0055 - val_mse: 0.0055\n",
      "Epoch 91/1000\n",
      "443/443 [==============================] - 0s 606us/step - loss: 0.0052 - mse: 0.0052 - val_loss: 0.0054 - val_mse: 0.0054\n",
      "Epoch 92/1000\n",
      "443/443 [==============================] - 0s 602us/step - loss: 0.0052 - mse: 0.0052 - val_loss: 0.0054 - val_mse: 0.0054\n",
      "Epoch 93/1000\n",
      "443/443 [==============================] - 0s 589us/step - loss: 0.0052 - mse: 0.0052 - val_loss: 0.0053 - val_mse: 0.0053\n",
      "Epoch 94/1000\n",
      "  1/443 [..............................] - ETA: 0s - loss: 0.0060 - mse: 0.0060"
     ]
    }
   ],
   "source": [
    "# Train a small MLP that maps the input at time t to the state at time t+1.\n",
    "# Expects X (N, n_features) and Y (N, n_states) from the data-preparation cell.\n",
    "\n",
    "MAX_EPOCHS = 1000  # upper bound only; early stopping normally ends training sooner\n",
    "PATIENCE = 20      # epochs without val_loss improvement before training stops\n",
    "\n",
    "# Step 1: Split BEFORE scaling so test-set statistics never leak into the scalers.\n",
    "# NOTE(review): a random per-sample split places neighbouring time steps of the\n",
    "# same trajectory in both sets; holding out whole trajectories would be a\n",
    "# stricter test of generalization.\n",
    "X_train_raw, X_test_raw, Y_train_raw, Y_test_raw = train_test_split(\n",
    "    X, Y, test_size=0.1, random_state=42\n",
    ")\n",
    "\n",
    "# Step 2: Fit the scalers on the training portion only, then apply them to both.\n",
    "x_scaler = StandardScaler()\n",
    "y_scaler = StandardScaler()\n",
    "X_train = x_scaler.fit_transform(X_train_raw)\n",
    "Y_train = y_scaler.fit_transform(Y_train_raw)\n",
    "X_test = x_scaler.transform(X_test_raw)\n",
    "Y_test = y_scaler.transform(Y_test_raw)\n",
    "\n",
    "# Step 3: Define the neural network (layer sizes follow the data dimensions)\n",
    "model = Sequential([\n",
    "    Dense(64, activation='relu', input_shape=(X.shape[1],)),\n",
    "    Dense(32, activation='relu'),\n",
    "    Dense(Y.shape[1], activation='linear'),\n",
    "])\n",
    "\n",
    "# The loss is already MSE, so track MAE as the extra metric; evaluate() then\n",
    "# returns [mse, mae], matching the labels printed below.\n",
    "model.compile(optimizer=Adam(learning_rate=0.001), loss='mse', metrics=['mae'])\n",
    "\n",
    "# Step 4: Train the model, stopping once validation loss stops improving and\n",
    "# keeping the best weights seen.\n",
    "early_stop = EarlyStopping(monitor='val_loss', patience=PATIENCE, restore_best_weights=True)\n",
    "history = model.fit(\n",
    "    X_train, Y_train,\n",
    "    epochs=MAX_EPOCHS,\n",
    "    batch_size=256,\n",
    "    validation_split=0.1,\n",
    "    callbacks=[early_stop],\n",
    "    verbose=2,  # one line per epoch keeps the stored output small\n",
    ")\n",
    "\n",
    "# Step 5: Evaluate on the test set (both numbers are in normalized units)\n",
    "test_mse, test_mae = model.evaluate(X_test, Y_test, verbose=0)\n",
    "print(f\"\\nTest MSE: {test_mse:.4f}\")\n",
    "print(f\"Test MAE: {test_mae:.4f}\")\n",
    "\n",
    "# Step 6: Predict and map the predictions back to the original scale.\n",
    "# The unscaled test targets are already available, so no inverse transform is\n",
    "# needed for the ground truth.\n",
    "Y_pred_norm = model.predict(X_test)\n",
    "Y_pred = y_scaler.inverse_transform(Y_pred_norm)\n",
    "Y_test_true = Y_test_raw\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1075aff1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare predictions with ground truth for the first T test samples.\n",
    "# The test set comes from a shuffled train_test_split, so consecutive rows are\n",
    "# NOT consecutive time steps: the x-axis is a sample index, not time.\n",
    "T = 1000  # number of test samples to visualize\n",
    "\n",
    "# One panel per output state; the count is taken from the data, not hard-coded.\n",
    "n_states = Y_pred.shape[1]\n",
    "fig, axes = plt.subplots(\n",
    "    n_states, 1, figsize=(15, 2 * n_states), sharex=True, squeeze=False\n",
    ")\n",
    "\n",
    "for i, ax in enumerate(axes[:, 0]):\n",
    "    ax.plot(Y_test_true[:T, i], label='True', linestyle='--')\n",
    "    ax.plot(Y_pred[:T, i], label='Predicted', alpha=0.7)\n",
    "    ax.set_ylabel(f'State {i+1}')\n",
    "    ax.grid(True)\n",
    "\n",
    "# Title and legend go on the top panel, the x label on the bottom one.\n",
    "axes[0, 0].set_title(f\"True vs Predicted for {n_states} Output States\")\n",
    "axes[0, 0].legend(loc='upper right')\n",
    "axes[-1, 0].set_xlabel('Test sample index (shuffled)')\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab2532a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# model.summary() prints the table itself and returns None, so wrapping it in\n",
    "# print() would append a stray 'None' line.\n",
    "model.summary()\n",
    "model.save('Pretrain_sim.h5')\n",
    "\n",
    "# The network consumes and produces normalized values, so store the scaler\n",
    "# statistics next to it; without them the saved model cannot be used on raw data.\n",
    "np.savez(\n",
    "    'Pretrain_sim_scalers.npz',\n",
    "    x_mean=x_scaler.mean_, x_scale=x_scaler.scale_,\n",
    "    y_mean=y_scaler.mean_, y_scale=y_scaler.scale_,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c504ce73",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
