{
 "cells": [
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-13T13:00:46.996087Z",
     "start_time": "2025-03-13T13:00:46.649681Z"
    }
   },
   "source": [
    "import pandas as pd\n",
    "\n",
    "\n",
    "MODEL = \"kirnn\"  # can be \"kirnn\" or \"lstm\"\n",
    "horizon = 168  # can be 24, 168 or None (None keeps every horizon)\n",
    "\n",
    "# Turn the documented options into checks so a typo fails here, not as a\n",
    "# missing-file error or a silently empty selection in the following cells.\n",
    "SUPPORTED_MODELS = (\"kirnn\", \"lstm\")\n",
    "SUPPORTED_HORIZONS = (24, 168, None)\n",
    "assert MODEL in SUPPORTED_MODELS, f\"MODEL must be one of {SUPPORTED_MODELS}, got {MODEL!r}\"\n",
    "assert horizon in SUPPORTED_HORIZONS, f\"horizon must be one of {SUPPORTED_HORIZONS}, got {horizon!r}\""
   ],
   "outputs": [],
   "execution_count": 1
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-13T13:00:47.126713Z",
     "start_time": "2025-03-13T13:00:47.096987Z"
    }
   },
   "source": [
    "# Load the grid-search results of the selected model (one row per run, indexed by run_id).\n",
    "results_csv = f\"{MODEL}/grid_results.csv\"\n",
    "full_results = pd.read_csv(results_csv, index_col=\"run_id\")\n",
    "\n",
    "# Restrict to the requested horizon; horizon=None keeps every horizon.\n",
    "if horizon is None:\n",
    "    results = full_results\n",
    "else:\n",
    "    # `@horizon` lets pandas resolve the variable instead of splicing its repr into the expression.\n",
    "    results = full_results.query(\"horizon == @horizon\")\n",
    "\n",
    "# Fail here with a clear message instead of several cells later on an empty frame.\n",
    "if results.empty:\n",
    "    available = sorted(full_results[\"horizon\"].unique().tolist())\n",
    "    raise ValueError(\n",
    "        f\"No runs found in {results_csv} for horizon={horizon!r}; \"\n",
    "        f\"available horizons: {available}.\"\n",
    "    )\n",
    "\n",
    "print(f\"#successful experiments: {len(results)}.\")\n",
    "results.head()"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#successful experiments: 125.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "        training_time  metric_train  metric_val  metric_test  \\\n",
       "run_id                                                         \n",
       "126          3.205648      3.929205    3.749201     4.533193   \n",
       "127          3.291892      4.168224    4.164307     4.193677   \n",
       "128          3.312314      3.756631    3.816502     3.314845   \n",
       "129          3.356545      3.925787    4.183041     4.147108   \n",
       "130          3.341612      3.533684    3.329933     4.103313   \n",
       "\n",
       "        prediction_time_train  prediction_time_val  prediction_time_test  \\\n",
       "run_id                                                                     \n",
       "126                  3.022673             0.849330              0.409993   \n",
       "127                  3.042498             0.848795              0.409403   \n",
       "128                  3.051481             0.852484              0.410388   \n",
       "129                  3.070180             0.852862              0.410675   \n",
       "130                  3.038604             0.850016              0.409301   \n",
       "\n",
       "        config_id   data_folder        output_folder  save_predictions  \\\n",
       "run_id                                                                   \n",
       "126            25  climate/data  climate/output/swim             False   \n",
       "127            25  climate/data  climate/output/swim             False   \n",
       "128            25  climate/data  climate/output/swim             False   \n",
       "129            25  climate/data  climate/output/swim             False   \n",
       "130            25  climate/data  climate/output/swim             False   \n",
       "\n",
       "          target  horizon  time_delay model_name  width activation  \\\n",
       "run_id                                                               \n",
       "126     T (degC)      168         168       swim     32       tanh   \n",
       "127     T (degC)      168         168       swim     32       tanh   \n",
       "128     T (degC)      168         168       swim     32       tanh   \n",
       "129     T (degC)      168         168       swim     32       tanh   \n",
       "130     T (degC)      168         168       swim     32       tanh   \n",
       "\n",
       "        regularization        seed  \n",
       "run_id                              \n",
       "126       1.000000e-10  1213835294  \n",
       "127       1.000000e-10   170075129  \n",
       "128       1.000000e-10  1642822536  \n",
       "129       1.000000e-10  1233945442  \n",
       "130       1.000000e-10  1363047212  "
      ],
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>training_time</th>\n",
       "      <th>metric_train</th>\n",
       "      <th>metric_val</th>\n",
       "      <th>metric_test</th>\n",
       "      <th>prediction_time_train</th>\n",
       "      <th>prediction_time_val</th>\n",
       "      <th>prediction_time_test</th>\n",
       "      <th>config_id</th>\n",
       "      <th>data_folder</th>\n",
       "      <th>output_folder</th>\n",
       "      <th>save_predictions</th>\n",
       "      <th>target</th>\n",
       "      <th>horizon</th>\n",
       "      <th>time_delay</th>\n",
       "      <th>model_name</th>\n",
       "      <th>width</th>\n",
       "      <th>activation</th>\n",
       "      <th>regularization</th>\n",
       "      <th>seed</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>run_id</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>126</th>\n",
       "      <td>3.205648</td>\n",
       "      <td>3.929205</td>\n",
       "      <td>3.749201</td>\n",
       "      <td>4.533193</td>\n",
       "      <td>3.022673</td>\n",
       "      <td>0.849330</td>\n",
       "      <td>0.409993</td>\n",
       "      <td>25</td>\n",
       "      <td>climate/data</td>\n",
       "      <td>climate/output/swim</td>\n",
       "      <td>False</td>\n",
       "      <td>T (degC)</td>\n",
       "      <td>168</td>\n",
       "      <td>168</td>\n",
       "      <td>swim</td>\n",
       "      <td>32</td>\n",
       "      <td>tanh</td>\n",
       "      <td>1.000000e-10</td>\n",
       "      <td>1213835294</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>127</th>\n",
       "      <td>3.291892</td>\n",
       "      <td>4.168224</td>\n",
       "      <td>4.164307</td>\n",
       "      <td>4.193677</td>\n",
       "      <td>3.042498</td>\n",
       "      <td>0.848795</td>\n",
       "      <td>0.409403</td>\n",
       "      <td>25</td>\n",
       "      <td>climate/data</td>\n",
       "      <td>climate/output/swim</td>\n",
       "      <td>False</td>\n",
       "      <td>T (degC)</td>\n",
       "      <td>168</td>\n",
       "      <td>168</td>\n",
       "      <td>swim</td>\n",
       "      <td>32</td>\n",
       "      <td>tanh</td>\n",
       "      <td>1.000000e-10</td>\n",
       "      <td>170075129</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>128</th>\n",
       "      <td>3.312314</td>\n",
       "      <td>3.756631</td>\n",
       "      <td>3.816502</td>\n",
       "      <td>3.314845</td>\n",
       "      <td>3.051481</td>\n",
       "      <td>0.852484</td>\n",
       "      <td>0.410388</td>\n",
       "      <td>25</td>\n",
       "      <td>climate/data</td>\n",
       "      <td>climate/output/swim</td>\n",
       "      <td>False</td>\n",
       "      <td>T (degC)</td>\n",
       "      <td>168</td>\n",
       "      <td>168</td>\n",
       "      <td>swim</td>\n",
       "      <td>32</td>\n",
       "      <td>tanh</td>\n",
       "      <td>1.000000e-10</td>\n",
       "      <td>1642822536</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>129</th>\n",
       "      <td>3.356545</td>\n",
       "      <td>3.925787</td>\n",
       "      <td>4.183041</td>\n",
       "      <td>4.147108</td>\n",
       "      <td>3.070180</td>\n",
       "      <td>0.852862</td>\n",
       "      <td>0.410675</td>\n",
       "      <td>25</td>\n",
       "      <td>climate/data</td>\n",
       "      <td>climate/output/swim</td>\n",
       "      <td>False</td>\n",
       "      <td>T (degC)</td>\n",
       "      <td>168</td>\n",
       "      <td>168</td>\n",
       "      <td>swim</td>\n",
       "      <td>32</td>\n",
       "      <td>tanh</td>\n",
       "      <td>1.000000e-10</td>\n",
       "      <td>1233945442</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>130</th>\n",
       "      <td>3.341612</td>\n",
       "      <td>3.533684</td>\n",
       "      <td>3.329933</td>\n",
       "      <td>4.103313</td>\n",
       "      <td>3.038604</td>\n",
       "      <td>0.850016</td>\n",
       "      <td>0.409301</td>\n",
       "      <td>25</td>\n",
       "      <td>climate/data</td>\n",
       "      <td>climate/output/swim</td>\n",
       "      <td>False</td>\n",
       "      <td>T (degC)</td>\n",
       "      <td>168</td>\n",
       "      <td>168</td>\n",
       "      <td>swim</td>\n",
       "      <td>32</td>\n",
       "      <td>tanh</td>\n",
       "      <td>1.000000e-10</td>\n",
       "      <td>1363047212</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 2
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>prediction_time_train</th>\n",
       "      <th>prediction_time_val</th>\n",
       "      <th>prediction_time_test</th>\n",
       "      <th>training_time</th>\n",
       "      <th>metric_train</th>\n",
       "      <th>metric_val</th>\n",
       "      <th>metric_test</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>config_id</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>42</th>\n",
       "      <td>5.125874</td>\n",
       "      <td>1.451209</td>\n",
       "      <td>0.691803</td>\n",
       "      <td>4.866589</td>\n",
       "      <td>3.356134</td>\n",
       "      <td>3.537632</td>\n",
       "      <td>4.624142</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>44</th>\n",
       "      <td>5.210057</td>\n",
       "      <td>1.454635</td>\n",
       "      <td>0.697433</td>\n",
       "      <td>4.870636</td>\n",
       "      <td>3.382204</td>\n",
       "      <td>3.667382</td>\n",
       "      <td>4.719246</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>43</th>\n",
       "      <td>5.187989</td>\n",
       "      <td>1.474784</td>\n",
       "      <td>0.704270</td>\n",
       "      <td>4.860293</td>\n",
       "      <td>3.291209</td>\n",
       "      <td>3.690394</td>\n",
       "      <td>4.719264</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>40</th>\n",
       "      <td>5.190898</td>\n",
       "      <td>1.507469</td>\n",
       "      <td>0.822324</td>\n",
       "      <td>4.939905</td>\n",
       "      <td>3.397083</td>\n",
       "      <td>3.736515</td>\n",
       "      <td>4.810085</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>49</th>\n",
       "      <td>5.952818</td>\n",
       "      <td>1.670949</td>\n",
       "      <td>0.815411</td>\n",
       "      <td>7.606666</td>\n",
       "      <td>3.077934</td>\n",
       "      <td>3.756936</td>\n",
       "      <td>5.029673</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>41</th>\n",
       "      <td>5.167676</td>\n",
       "      <td>1.445722</td>\n",
       "      <td>0.697570</td>\n",
       "      <td>4.887149</td>\n",
       "      <td>3.407341</td>\n",
       "      <td>3.756964</td>\n",
       "      <td>4.479396</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>45</th>\n",
       "      <td>6.047703</td>\n",
       "      <td>1.730738</td>\n",
       "      <td>0.821794</td>\n",
       "      <td>7.715208</td>\n",
       "      <td>2.995263</td>\n",
       "      <td>3.766534</td>\n",
       "      <td>4.604857</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>3.042743</td>\n",
       "      <td>0.848663</td>\n",
       "      <td>0.410549</td>\n",
       "      <td>3.330864</td>\n",
       "      <td>4.031215</td>\n",
       "      <td>3.774643</td>\n",
       "      <td>4.436400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>36</th>\n",
       "      <td>3.688661</td>\n",
       "      <td>1.003946</td>\n",
       "      <td>0.479239</td>\n",
       "      <td>3.776189</td>\n",
       "      <td>3.677532</td>\n",
       "      <td>3.789948</td>\n",
       "      <td>5.313665</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>38</th>\n",
       "      <td>3.268684</td>\n",
       "      <td>0.923300</td>\n",
       "      <td>0.443392</td>\n",
       "      <td>3.764180</td>\n",
       "      <td>3.812054</td>\n",
       "      <td>3.800527</td>\n",
       "      <td>4.797808</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>33</th>\n",
       "      <td>3.192297</td>\n",
       "      <td>0.875345</td>\n",
       "      <td>0.423659</td>\n",
       "      <td>3.263579</td>\n",
       "      <td>3.962244</td>\n",
       "      <td>3.826673</td>\n",
       "      <td>4.663975</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>3.045087</td>\n",
       "      <td>0.850698</td>\n",
       "      <td>0.409952</td>\n",
       "      <td>3.301602</td>\n",
       "      <td>3.862706</td>\n",
       "      <td>3.848597</td>\n",
       "      <td>4.058427</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>30</th>\n",
       "      <td>3.154778</td>\n",
       "      <td>0.882777</td>\n",
       "      <td>0.426098</td>\n",
       "      <td>3.242977</td>\n",
       "      <td>3.886289</td>\n",
       "      <td>3.852386</td>\n",
       "      <td>4.940101</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>34</th>\n",
       "      <td>3.185697</td>\n",
       "      <td>0.879465</td>\n",
       "      <td>0.426547</td>\n",
       "      <td>3.244781</td>\n",
       "      <td>4.025248</td>\n",
       "      <td>3.859301</td>\n",
       "      <td>4.728481</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>3.053518</td>\n",
       "      <td>0.850120</td>\n",
       "      <td>0.410358</td>\n",
       "      <td>3.331180</td>\n",
       "      <td>3.874152</td>\n",
       "      <td>3.876992</td>\n",
       "      <td>4.181510</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>37</th>\n",
       "      <td>3.281683</td>\n",
       "      <td>0.917201</td>\n",
       "      <td>0.440514</td>\n",
       "      <td>3.767621</td>\n",
       "      <td>3.764615</td>\n",
       "      <td>3.929705</td>\n",
       "      <td>4.483180</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32</th>\n",
       "      <td>3.140706</td>\n",
       "      <td>0.875582</td>\n",
       "      <td>0.423546</td>\n",
       "      <td>3.239284</td>\n",
       "      <td>3.884259</td>\n",
       "      <td>3.946397</td>\n",
       "      <td>4.883127</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>3.043811</td>\n",
       "      <td>0.861566</td>\n",
       "      <td>0.409897</td>\n",
       "      <td>3.320070</td>\n",
       "      <td>4.031792</td>\n",
       "      <td>3.960249</td>\n",
       "      <td>4.257519</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>48</th>\n",
       "      <td>5.940648</td>\n",
       "      <td>1.678348</td>\n",
       "      <td>0.803825</td>\n",
       "      <td>7.586175</td>\n",
       "      <td>2.964281</td>\n",
       "      <td>3.967656</td>\n",
       "      <td>7.499914</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>35</th>\n",
       "      <td>3.288555</td>\n",
       "      <td>0.917029</td>\n",
       "      <td>0.446007</td>\n",
       "      <td>3.781938</td>\n",
       "      <td>3.790199</td>\n",
       "      <td>3.993856</td>\n",
       "      <td>4.883393</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>3.033454</td>\n",
       "      <td>0.858630</td>\n",
       "      <td>0.409399</td>\n",
       "      <td>3.325212</td>\n",
       "      <td>4.009006</td>\n",
       "      <td>3.994831</td>\n",
       "      <td>4.340448</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>47</th>\n",
       "      <td>6.002270</td>\n",
       "      <td>1.681601</td>\n",
       "      <td>0.800309</td>\n",
       "      <td>7.654123</td>\n",
       "      <td>3.096340</td>\n",
       "      <td>4.008155</td>\n",
       "      <td>5.714287</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>39</th>\n",
       "      <td>3.281867</td>\n",
       "      <td>0.917751</td>\n",
       "      <td>0.444932</td>\n",
       "      <td>3.811103</td>\n",
       "      <td>3.743859</td>\n",
       "      <td>4.023671</td>\n",
       "      <td>4.509429</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>31</th>\n",
       "      <td>3.144734</td>\n",
       "      <td>0.881830</td>\n",
       "      <td>0.425429</td>\n",
       "      <td>3.231079</td>\n",
       "      <td>4.008342</td>\n",
       "      <td>4.040726</td>\n",
       "      <td>4.839820</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>46</th>\n",
       "      <td>6.591454</td>\n",
       "      <td>1.830842</td>\n",
       "      <td>0.816670</td>\n",
       "      <td>7.744003</td>\n",
       "      <td>3.088658</td>\n",
       "      <td>4.042843</td>\n",
       "      <td>5.338559</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "           prediction_time_train  prediction_time_val  prediction_time_test  \\\n",
       "config_id                                                                     \n",
       "42                      5.125874             1.451209              0.691803   \n",
       "44                      5.210057             1.454635              0.697433   \n",
       "43                      5.187989             1.474784              0.704270   \n",
       "40                      5.190898             1.507469              0.822324   \n",
       "49                      5.952818             1.670949              0.815411   \n",
       "41                      5.167676             1.445722              0.697570   \n",
       "45                      6.047703             1.730738              0.821794   \n",
       "28                      3.042743             0.848663              0.410549   \n",
       "36                      3.688661             1.003946              0.479239   \n",
       "38                      3.268684             0.923300              0.443392   \n",
       "33                      3.192297             0.875345              0.423659   \n",
       "25                      3.045087             0.850698              0.409952   \n",
       "30                      3.154778             0.882777              0.426098   \n",
       "34                      3.185697             0.879465              0.426547   \n",
       "29                      3.053518             0.850120              0.410358   \n",
       "37                      3.281683             0.917201              0.440514   \n",
       "32                      3.140706             0.875582              0.423546   \n",
       "26                      3.043811             0.861566              0.409897   \n",
       "48                      5.940648             1.678348              0.803825   \n",
       "35                      3.288555             0.917029              0.446007   \n",
       "27                      3.033454             0.858630              0.409399   \n",
       "47                      6.002270             1.681601              0.800309   \n",
       "39                      3.281867             0.917751              0.444932   \n",
       "31                      3.144734             0.881830              0.425429   \n",
       "46                      6.591454             1.830842              0.816670   \n",
       "\n",
       "           training_time  metric_train  metric_val  metric_test  \n",
       "config_id                                                        \n",
       "42              4.866589      3.356134    3.537632     4.624142  \n",
       "44              4.870636      3.382204    3.667382     4.719246  \n",
       "43              4.860293      3.291209    3.690394     4.719264  \n",
       "40              4.939905      3.397083    3.736515     4.810085  \n",
       "49              7.606666      3.077934    3.756936     5.029673  \n",
       "41              4.887149      3.407341    3.756964     4.479396  \n",
       "45              7.715208      2.995263    3.766534     4.604857  \n",
       "28              3.330864      4.031215    3.774643     4.436400  \n",
       "36              3.776189      3.677532    3.789948     5.313665  \n",
       "38              3.764180      3.812054    3.800527     4.797808  \n",
       "33              3.263579      3.962244    3.826673     4.663975  \n",
       "25              3.301602      3.862706    3.848597     4.058427  \n",
       "30              3.242977      3.886289    3.852386     4.940101  \n",
       "34              3.244781      4.025248    3.859301     4.728481  \n",
       "29              3.331180      3.874152    3.876992     4.181510  \n",
       "37              3.767621      3.764615    3.929705     4.483180  \n",
       "32              3.239284      3.884259    3.946397     4.883127  \n",
       "26              3.320070      4.031792    3.960249     4.257519  \n",
       "48              7.586175      2.964281    3.967656     7.499914  \n",
       "35              3.781938      3.790199    3.993856     4.883393  \n",
       "27              3.325212      4.009006    3.994831     4.340448  \n",
       "47              7.654123      3.096340    4.008155     5.714287  \n",
       "39              3.811103      3.743859    4.023671     4.509429  \n",
       "31              3.231079      4.008342    4.040726     4.839820  \n",
       "46              7.744003      3.088658    4.042843     5.338559  "
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Average every configuration over its repeated runs (one run per seed) and rank the\n",
    "# configurations by TARGET; ascending order puts the lowest (best) value first.\n",
    "TARGET = \"metric_val\"\n",
    "METRIC_COLUMNS = [\n",
    "    \"prediction_time_train\",\n",
    "    \"prediction_time_val\",\n",
    "    \"prediction_time_test\",\n",
    "    \"training_time\",\n",
    "    \"metric_train\",\n",
    "    \"metric_val\",\n",
    "    \"metric_test\",\n",
    "]\n",
    "\n",
    "runs_by_config = results.groupby(\"config_id\")  # scalar key -> plain, non-tuple group labels\n",
    "metrics = runs_by_config[METRIC_COLUMNS].mean()\n",
    "\n",
    "# n_runs is display-only (`metrics` itself is unchanged): a mean over fewer runs,\n",
    "# e.g. if some runs of a config failed, is a noisier estimate than the others.\n",
    "metrics.assign(n_runs=runs_by_config.size()).sort_values(by=TARGET)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best config: #42.\n",
      "Aggregated metrics of the best config:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>prediction_time_train</th>\n",
       "      <th>prediction_time_val</th>\n",
       "      <th>prediction_time_test</th>\n",
       "      <th>training_time</th>\n",
       "      <th>metric_train</th>\n",
       "      <th>metric_val</th>\n",
       "      <th>metric_test</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>config_id</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>42</th>\n",
       "      <td>5.125874</td>\n",
       "      <td>1.451209</td>\n",
       "      <td>0.691803</td>\n",
       "      <td>4.866589</td>\n",
       "      <td>3.356134</td>\n",
       "      <td>3.537632</td>\n",
       "      <td>4.624142</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "           prediction_time_train  prediction_time_val  prediction_time_test  \\\n",
       "config_id                                                                     \n",
       "42                      5.125874             1.451209              0.691803   \n",
       "\n",
       "           training_time  metric_train  metric_val  metric_test  \n",
       "config_id                                                        \n",
       "42              4.866589      3.356134    3.537632     4.624142  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "All runs of the best config:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>training_time</th>\n",
       "      <th>metric_train</th>\n",
       "      <th>metric_val</th>\n",
       "      <th>metric_test</th>\n",
       "      <th>prediction_time_train</th>\n",
       "      <th>prediction_time_val</th>\n",
       "      <th>prediction_time_test</th>\n",
       "      <th>config_id</th>\n",
       "      <th>data_folder</th>\n",
       "      <th>output_folder</th>\n",
       "      <th>save_predictions</th>\n",
       "      <th>target</th>\n",
       "      <th>horizon</th>\n",
       "      <th>time_delay</th>\n",
       "      <th>model_name</th>\n",
       "      <th>width</th>\n",
       "      <th>activation</th>\n",
       "      <th>regularization</th>\n",
       "      <th>seed</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>run_id</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>211</th>\n",
       "      <td>4.842715</td>\n",
       "      <td>3.364208</td>\n",
       "      <td>3.579924</td>\n",
       "      <td>4.785945</td>\n",
       "      <td>5.136994</td>\n",
       "      <td>1.487052</td>\n",
       "      <td>0.697315</td>\n",
       "      <td>42</td>\n",
       "      <td>climate/data</td>\n",
       "      <td>climate/output/swim</td>\n",
       "      <td>False</td>\n",
       "      <td>T (degC)</td>\n",
       "      <td>168</td>\n",
       "      <td>168</td>\n",
       "      <td>swim</td>\n",
       "      <td>256</td>\n",
       "      <td>tanh</td>\n",
       "      <td>0.000001</td>\n",
       "      <td>560793881</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>212</th>\n",
       "      <td>4.896584</td>\n",
       "      <td>3.385688</td>\n",
       "      <td>3.617896</td>\n",
       "      <td>4.867459</td>\n",
       "      <td>5.122491</td>\n",
       "      <td>1.444159</td>\n",
       "      <td>0.690018</td>\n",
       "      <td>42</td>\n",
       "      <td>climate/data</td>\n",
       "      <td>climate/output/swim</td>\n",
       "      <td>False</td>\n",
       "      <td>T (degC)</td>\n",
       "      <td>168</td>\n",
       "      <td>168</td>\n",
       "      <td>swim</td>\n",
       "      <td>256</td>\n",
       "      <td>tanh</td>\n",
       "      <td>0.000001</td>\n",
       "      <td>1539509957</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>213</th>\n",
       "      <td>4.893866</td>\n",
       "      <td>3.307062</td>\n",
       "      <td>3.413334</td>\n",
       "      <td>4.622088</td>\n",
       "      <td>5.093190</td>\n",
       "      <td>1.439807</td>\n",
       "      <td>0.684206</td>\n",
       "      <td>42</td>\n",
       "      <td>climate/data</td>\n",
       "      <td>climate/output/swim</td>\n",
       "      <td>False</td>\n",
       "      <td>T (degC)</td>\n",
       "      <td>168</td>\n",
       "      <td>168</td>\n",
       "      <td>swim</td>\n",
       "      <td>256</td>\n",
       "      <td>tanh</td>\n",
       "      <td>0.000001</td>\n",
       "      <td>1694378689</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>214</th>\n",
       "      <td>4.887356</td>\n",
       "      <td>3.232272</td>\n",
       "      <td>3.607804</td>\n",
       "      <td>4.168997</td>\n",
       "      <td>5.138295</td>\n",
       "      <td>1.443243</td>\n",
       "      <td>0.701768</td>\n",
       "      <td>42</td>\n",
       "      <td>climate/data</td>\n",
       "      <td>climate/output/swim</td>\n",
       "      <td>False</td>\n",
       "      <td>T (degC)</td>\n",
       "      <td>168</td>\n",
       "      <td>168</td>\n",
       "      <td>swim</td>\n",
       "      <td>256</td>\n",
       "      <td>tanh</td>\n",
       "      <td>0.000001</td>\n",
       "      <td>964996477</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>215</th>\n",
       "      <td>4.812424</td>\n",
       "      <td>3.491440</td>\n",
       "      <td>3.469202</td>\n",
       "      <td>4.676219</td>\n",
       "      <td>5.138402</td>\n",
       "      <td>1.441782</td>\n",
       "      <td>0.685709</td>\n",
       "      <td>42</td>\n",
       "      <td>climate/data</td>\n",
       "      <td>climate/output/swim</td>\n",
       "      <td>False</td>\n",
       "      <td>T (degC)</td>\n",
       "      <td>168</td>\n",
       "      <td>168</td>\n",
       "      <td>swim</td>\n",
       "      <td>256</td>\n",
       "      <td>tanh</td>\n",
       "      <td>0.000001</td>\n",
       "      <td>1582522787</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        training_time  metric_train  metric_val  metric_test  \\\n",
       "run_id                                                         \n",
       "211          4.842715      3.364208    3.579924     4.785945   \n",
       "212          4.896584      3.385688    3.617896     4.867459   \n",
       "213          4.893866      3.307062    3.413334     4.622088   \n",
       "214          4.887356      3.232272    3.607804     4.168997   \n",
       "215          4.812424      3.491440    3.469202     4.676219   \n",
       "\n",
       "        prediction_time_train  prediction_time_val  prediction_time_test  \\\n",
       "run_id                                                                     \n",
       "211                  5.136994             1.487052              0.697315   \n",
       "212                  5.122491             1.444159              0.690018   \n",
       "213                  5.093190             1.439807              0.684206   \n",
       "214                  5.138295             1.443243              0.701768   \n",
       "215                  5.138402             1.441782              0.685709   \n",
       "\n",
       "        config_id   data_folder        output_folder  save_predictions  \\\n",
       "run_id                                                                   \n",
       "211            42  climate/data  climate/output/swim             False   \n",
       "212            42  climate/data  climate/output/swim             False   \n",
       "213            42  climate/data  climate/output/swim             False   \n",
       "214            42  climate/data  climate/output/swim             False   \n",
       "215            42  climate/data  climate/output/swim             False   \n",
       "\n",
       "          target  horizon  time_delay model_name  width activation  \\\n",
       "run_id                                                               \n",
       "211     T (degC)      168         168       swim    256       tanh   \n",
       "212     T (degC)      168         168       swim    256       tanh   \n",
       "213     T (degC)      168         168       swim    256       tanh   \n",
       "214     T (degC)      168         168       swim    256       tanh   \n",
       "215     T (degC)      168         168       swim    256       tanh   \n",
       "\n",
       "        regularization        seed  \n",
       "run_id                              \n",
       "211           0.000001   560793881  \n",
       "212           0.000001  1539509957  \n",
       "213           0.000001  1694378689  \n",
       "214           0.000001   964996477  \n",
       "215           0.000001  1582522787  "
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pick the config whose aggregated TARGET metric is lowest (lower is better here).\n",
    "# idxmin() returns the index label directly and skips NaNs; guard against an\n",
    "# all-NaN column, where argmin() would return -1 and silently select the last config.\n",
    "assert metrics[TARGET].notna().any(), f\"metrics[{TARGET!r}] is all-NaN; cannot select a best config.\"\n",
    "best_config_id = metrics[TARGET].idxmin()\n",
    "print(f\"Best config: #{best_config_id}.\")\n",
    "print(\"Aggregated metrics of the best config:\")\n",
    "display(metrics.query(f\"config_id=={best_config_id}\"))\n",
    "\n",
    "print(\"All runs of the best config:\")\n",
    "best_runs = results.query(f\"config_id=={best_config_id}\")\n",
    "best_runs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "swimrnn",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
