{
 "cells": [
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import pandas as pd\n",
    "\n",
    "\n",
    "# Only runs whose \"horizon\" column equals this value are analysed below.\n",
    "horizon = 120\n",
    "\n",
    "# NOTE(review): MODEL is not referenced by any cell in this notebook, and the\n",
    "# loaded results report model_name \"swim\"; confirm whether it is still needed.\n",
    "MODEL = \"kirnn\"  # can be \"kirnn\" or \"esn\""
   ]
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Load every grid-search run, then keep only the runs at the configured horizon.\n",
    "results_csv = \"./grid_results.csv\"\n",
    "results_all = pd.read_csv(results_csv, index_col=\"run_id\")\n",
    "results = results_all.loc[results_all[\"horizon\"] == horizon]\n",
    "# Fail early with a clear message rather than on an empty selection further down.\n",
    "assert not results.empty, f\"No runs with horizon={horizon} in {results_csv}.\"\n",
    "print(f\"#successful experiments at horizon={horizon}: {len(results)} (of {len(results_all)} loaded).\")\n",
    "results.head()"
   ]
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Column used to rank configurations: a smaller mean value is better.\n",
    "TARGET = \"metric_val\"\n",
    "# Per-run columns averaged over all runs that share a config_id.\n",
    "METRIC_COLUMNS = [\n",
    "    \"prediction_time_train\",\n",
    "    \"prediction_time_val\",\n",
    "    \"prediction_time_test\",\n",
    "    \"training_time\",\n",
    "    \"metric_train\",\n",
    "    \"metric_val\",\n",
    "    \"metric_test\",\n",
    "]\n",
    "\n",
    "metrics = (\n",
    "    results\n",
    "    .groupby(\"config_id\")[METRIC_COLUMNS]\n",
    "    .mean()\n",
    ")\n",
    "# Ascending order puts the best configurations first.\n",
    "metrics.sort_values(TARGET)"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-02-27T17:51:41.862898750Z",
     "start_time": "2026-02-27T17:51:41.396248400Z"
    }
   },
   "source": [
    "# Best configuration = lowest mean value of TARGET across its runs.\n",
    "best_config_id = metrics[TARGET].idxmin()\n",
    "print(f\"Best config: #{best_config_id}.\")\n",
    "print(\"Aggregated metrics of the best config:\")\n",
    "# Label-based lookup (double brackets keep a one-row DataFrame); unlike a query\n",
    "# string built from the value, this works for any config_id dtype.\n",
    "display(metrics.loc[[best_config_id]])\n",
    "\n",
    "print(\"All runs of the best config:\")\n",
    "best_runs = results.loc[results[\"config_id\"] == best_config_id]\n",
    "best_runs"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best config: #4.\n",
      "Aggregated metrics of the best config:\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "           prediction_time_train  prediction_time_val  prediction_time_test  \\\n",
       "config_id                                                                     \n",
       "4                     731.302182           364.169373            360.239861   \n",
       "\n",
       "           training_time  metric_train  metric_val  metric_test  \n",
       "config_id                                                        \n",
       "4               3.174151      1.657872    1.693265     1.660361  "
      ],
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>prediction_time_train</th>\n",
       "      <th>prediction_time_val</th>\n",
       "      <th>prediction_time_test</th>\n",
       "      <th>training_time</th>\n",
       "      <th>metric_train</th>\n",
       "      <th>metric_val</th>\n",
       "      <th>metric_test</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>config_id</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>731.302182</td>\n",
       "      <td>364.169373</td>\n",
       "      <td>360.239861</td>\n",
       "      <td>3.174151</td>\n",
       "      <td>1.657872</td>\n",
       "      <td>1.693265</td>\n",
       "      <td>1.660361</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ]
     },
     "metadata": {},
     "output_type": "display_data",
     "jetTransient": {
      "display_id": null
     }
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "All runs of the best config:\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "        training_time  metric_train  metric_val  metric_test  \\\n",
       "run_id                                                         \n",
       "21           3.197072      1.682173    1.719401     1.650919   \n",
       "22           3.129078      1.640559    1.703160     1.689099   \n",
       "23           3.114329      1.656325    1.610063     1.607259   \n",
       "24           3.279318      1.623709    1.728349     1.658592   \n",
       "25           3.150959      1.686596    1.705353     1.695937   \n",
       "\n",
       "        prediction_time_train  prediction_time_val  prediction_time_test  \\\n",
       "run_id                                                                     \n",
       "21                 726.630813           358.900719            358.371477   \n",
       "22                 727.495643           368.133058            360.117126   \n",
       "23                 732.600434           361.763317            359.104469   \n",
       "24                 733.208539           359.221632            364.143386   \n",
       "25                 736.575482           372.828137            359.462845   \n",
       "\n",
       "        config_id                  data_folder  \\\n",
       "run_id                                           \n",
       "21              4  real_world/data/electricity   \n",
       "22              4  real_world/data/electricity   \n",
       "23              4  real_world/data/electricity   \n",
       "24              4  real_world/data/electricity   \n",
       "25              4  real_world/data/electricity   \n",
       "\n",
       "                                         output_folder  save_predictions  \\\n",
       "run_id                                                                     \n",
       "21      real_world/grid_search/output/electricity/swim              True   \n",
       "22      real_world/grid_search/output/electricity/swim              True   \n",
       "23      real_world/grid_search/output/electricity/swim              True   \n",
       "24      real_world/grid_search/output/electricity/swim              True   \n",
       "25      real_world/grid_search/output/electricity/swim              True   \n",
       "\n",
       "         target  horizon  time_delay model_name  width activation  \\\n",
       "run_id                                                              \n",
       "21      Voltage      120         240       swim     64       tanh   \n",
       "22      Voltage      120         240       swim     64       tanh   \n",
       "23      Voltage      120         240       swim     64       tanh   \n",
       "24      Voltage      120         240       swim     64       tanh   \n",
       "25      Voltage      120         240       swim     64       tanh   \n",
       "\n",
       "        regularization        seed  \n",
       "run_id                              \n",
       "21        1.000000e-08  1074497555  \n",
       "22        1.000000e-08   796282693  \n",
       "23        1.000000e-08   392022359  \n",
       "24        1.000000e-08  1990212658  \n",
       "25        1.000000e-08  1678403330  "
      ],
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>training_time</th>\n",
       "      <th>metric_train</th>\n",
       "      <th>metric_val</th>\n",
       "      <th>metric_test</th>\n",
       "      <th>prediction_time_train</th>\n",
       "      <th>prediction_time_val</th>\n",
       "      <th>prediction_time_test</th>\n",
       "      <th>config_id</th>\n",
       "      <th>data_folder</th>\n",
       "      <th>output_folder</th>\n",
       "      <th>save_predictions</th>\n",
       "      <th>target</th>\n",
       "      <th>horizon</th>\n",
       "      <th>time_delay</th>\n",
       "      <th>model_name</th>\n",
       "      <th>width</th>\n",
       "      <th>activation</th>\n",
       "      <th>regularization</th>\n",
       "      <th>seed</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>run_id</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>3.197072</td>\n",
       "      <td>1.682173</td>\n",
       "      <td>1.719401</td>\n",
       "      <td>1.650919</td>\n",
       "      <td>726.630813</td>\n",
       "      <td>358.900719</td>\n",
       "      <td>358.371477</td>\n",
       "      <td>4</td>\n",
       "      <td>real_world/data/electricity</td>\n",
       "      <td>real_world/grid_search/output/electricity/swim</td>\n",
       "      <td>True</td>\n",
       "      <td>Voltage</td>\n",
       "      <td>120</td>\n",
       "      <td>240</td>\n",
       "      <td>swim</td>\n",
       "      <td>64</td>\n",
       "      <td>tanh</td>\n",
       "      <td>1.000000e-08</td>\n",
       "      <td>1074497555</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>3.129078</td>\n",
       "      <td>1.640559</td>\n",
       "      <td>1.703160</td>\n",
       "      <td>1.689099</td>\n",
       "      <td>727.495643</td>\n",
       "      <td>368.133058</td>\n",
       "      <td>360.117126</td>\n",
       "      <td>4</td>\n",
       "      <td>real_world/data/electricity</td>\n",
       "      <td>real_world/grid_search/output/electricity/swim</td>\n",
       "      <td>True</td>\n",
       "      <td>Voltage</td>\n",
       "      <td>120</td>\n",
       "      <td>240</td>\n",
       "      <td>swim</td>\n",
       "      <td>64</td>\n",
       "      <td>tanh</td>\n",
       "      <td>1.000000e-08</td>\n",
       "      <td>796282693</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>3.114329</td>\n",
       "      <td>1.656325</td>\n",
       "      <td>1.610063</td>\n",
       "      <td>1.607259</td>\n",
       "      <td>732.600434</td>\n",
       "      <td>361.763317</td>\n",
       "      <td>359.104469</td>\n",
       "      <td>4</td>\n",
       "      <td>real_world/data/electricity</td>\n",
       "      <td>real_world/grid_search/output/electricity/swim</td>\n",
       "      <td>True</td>\n",
       "      <td>Voltage</td>\n",
       "      <td>120</td>\n",
       "      <td>240</td>\n",
       "      <td>swim</td>\n",
       "      <td>64</td>\n",
       "      <td>tanh</td>\n",
       "      <td>1.000000e-08</td>\n",
       "      <td>392022359</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>3.279318</td>\n",
       "      <td>1.623709</td>\n",
       "      <td>1.728349</td>\n",
       "      <td>1.658592</td>\n",
       "      <td>733.208539</td>\n",
       "      <td>359.221632</td>\n",
       "      <td>364.143386</td>\n",
       "      <td>4</td>\n",
       "      <td>real_world/data/electricity</td>\n",
       "      <td>real_world/grid_search/output/electricity/swim</td>\n",
       "      <td>True</td>\n",
       "      <td>Voltage</td>\n",
       "      <td>120</td>\n",
       "      <td>240</td>\n",
       "      <td>swim</td>\n",
       "      <td>64</td>\n",
       "      <td>tanh</td>\n",
       "      <td>1.000000e-08</td>\n",
       "      <td>1990212658</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>3.150959</td>\n",
       "      <td>1.686596</td>\n",
       "      <td>1.705353</td>\n",
       "      <td>1.695937</td>\n",
       "      <td>736.575482</td>\n",
       "      <td>372.828137</td>\n",
       "      <td>359.462845</td>\n",
       "      <td>4</td>\n",
       "      <td>real_world/data/electricity</td>\n",
       "      <td>real_world/grid_search/output/electricity/swim</td>\n",
       "      <td>True</td>\n",
       "      <td>Voltage</td>\n",
       "      <td>120</td>\n",
       "      <td>240</td>\n",
       "      <td>swim</td>\n",
       "      <td>64</td>\n",
       "      <td>tanh</td>\n",
       "      <td>1.000000e-08</td>\n",
       "      <td>1678403330</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 6
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-02-27T17:51:42.616258842Z",
     "start_time": "2026-02-27T17:51:42.263608568Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Spread of fit time and test metric across the runs of the best configuration.\n",
    "# Each entry: (printed label, column of best_runs, reduction applied to it).\n",
    "summary_rows = [\n",
    "    (\"Average fit time\", \"training_time\", pd.Series.mean),\n",
    "    (\"Maximum fit time\", \"training_time\", pd.Series.max),\n",
    "    (\"Minimum fit time\", \"training_time\", pd.Series.min),\n",
    "    (\"Average test MSE\", \"metric_test\", pd.Series.mean),\n",
    "    (\"Maximum test MSE\", \"metric_test\", pd.Series.max),\n",
    "    (\"Minimum test MSE\", \"metric_test\", pd.Series.min),\n",
    "]\n",
    "for label, column, reduce_fn in summary_rows:\n",
    "    print(label, reduce_fn(best_runs[column]))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average fit time 3.17415132522583\n",
      "Maximum fit time 3.27931809425354\n",
      "Minimum fit time 3.114328622817993\n",
      "Average test MSE 1.6603609884237862\n",
      "Maximum test MSE 1.695936526126446\n",
      "Minimum test MSE 1.6072593507966173\n"
     ]
    }
   ],
   "execution_count": 7
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-02-20T11:24:12.968495155Z",
     "start_time": "2026-02-20T11:24:12.822566582Z"
    }
   },
   "cell_type": "code",
   "source": "",
   "outputs": [],
   "execution_count": 6
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "swimrnn",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
