{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d31d54c6",
   "metadata": {},
   "source": [
    "# Table: Surrogate Comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f69f31e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "from typing import Any, Dict, List, Set\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tsbench.evaluation.tracking import SacredMongoClient\n",
    "\n",
    "%reload_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4996a0bc",
   "metadata": {},
   "source": [
    "## Helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "90e55d78",
   "metadata": {},
   "outputs": [],
   "source": [
    "def flatten_dict(data: Dict[str, Any], ignore: Set[str]) -> Dict[str, Any]:\n",
    "    \"\"\"\n",
    "    Flattens a nested dictionary into a single level with stringified leaf values.\n",
    "\n",
    "    Keys of nested dictionaries are joined to their parent key with an underscore, i.e.\n",
    "    ``{\"a\": {\"b\": 1}}`` becomes ``{\"a_b\": \"1\"}``. Every leaf value is converted with ``str``.\n",
    "\n",
    "    Args:\n",
    "        data: The (possibly nested) dictionary to flatten.\n",
    "        ignore: Keys to drop. A key without a matching parent prefix is dropped at every nesting\n",
    "            level; a key of the form ``\"<parent>_<child>\"`` drops ``child`` inside the dictionary\n",
    "            stored under ``parent``.\n",
    "\n",
    "    Returns:\n",
    "        A flat dictionary mapping the joined keys to the string representation of the values.\n",
    "    \"\"\"\n",
    "    result = {}\n",
    "    for k, v in data.items():\n",
    "        if k in ignore:\n",
    "            continue\n",
    "        if isinstance(v, dict):\n",
    "            # Strip the parent key only when it is followed by the separator. A bare\n",
    "            # `startswith(k)` would also mangle unrelated ignore keys that merely share the same\n",
    "            # leading characters (e.g. \"seed\" for a parent key \"s\").\n",
    "            prefix = f\"{k}_\"\n",
    "            nested_ignore = {i[len(prefix):] if i.startswith(prefix) else i for i in ignore}\n",
    "            for kk, vv in flatten_dict(v, nested_ignore).items():\n",
    "                result[f\"{k}_{kk}\"] = vv\n",
    "        else:\n",
    "            result[k] = str(v)\n",
    "    return result"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8ea25f4",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "080406e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Handle that is iterated in the next section to load the config and results of every experiment.\n",
    "# NOTE(review): the string is presumably the name of the tracked experiment/database - confirm.\n",
    "client = SacredMongoClient(\"eval-surrogates-iclr-05-10-21\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4acfa5b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration keys that are skipped when the experiment configs are flattened into the index.\n",
    "ignore_keys = {\"name\", \"experiment\", \"metrics\", \"seed\"}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76c834d5",
   "metadata": {},
   "source": [
    "## Read All Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2fdf8c83",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every tracked experiment, keep its configuration and the column-wise mean of its results\n",
    "# restricted to the `mean_weighted_quantile_loss_mean` entry.\n",
    "configs, metrics = [], []\n",
    "for run in client:\n",
    "    run_results = run.read_parquet(\"results.parquet\")\n",
    "    averaged = run_results.mean()\n",
    "    metrics.append(averaged.mean_weighted_quantile_loss_mean)\n",
    "    configs.append(run.config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6af1c81f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One index row per experiment (its flattened configuration), one data row of metrics each.\n",
    "flat_configs = [flatten_dict(config, ignore=ignore_keys) for config in configs]\n",
    "index_df = pd.DataFrame(flat_configs)\n",
    "df = pd.DataFrame(metrics, index=pd.MultiIndex.from_frame(index_df))\n",
    "df = df.sort_index()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d780e531",
   "metadata": {},
   "source": [
    "## Compare Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0eb4a966",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Maps the display name of each surrogate to its averaged metrics; filled in section by section.\n",
    "best_results = {}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15960f52",
   "metadata": {},
   "source": [
    "### Random Predictor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4183be29",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average the metrics over all rows that belong to the random surrogate.\n",
    "random_rows = df.query(\"surrogate == 'random'\")\n",
    "best_results[\"Random\"] = random_rows.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08d2329d",
   "metadata": {},
   "source": [
    "### Nonparametric Regression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7866c725",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>mrr</th>\n",
       "      <th>ndcg</th>\n",
       "      <th>nrmse</th>\n",
       "      <th>precision_10</th>\n",
       "      <th>precision_20</th>\n",
       "      <th>precision_5</th>\n",
       "      <th>smape</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>surrogate</th>\n",
       "      <th>input_flags_use_simple_dataset_features</th>\n",
       "      <th>input_flags_use_seasonal_naive_performance</th>\n",
       "      <th>input_flags_use_catch22_features</th>\n",
       "      <th>nonparametric_use_ranks</th>\n",
       "      <th>xgboost_objective</th>\n",
       "      <th>mlp_objective</th>\n",
       "      <th>mlp_discount</th>\n",
       "      <th>mlp_hidden_layer_sizes</th>\n",
       "      <th>mlp_weight_decay</th>\n",
       "      <th>mlp_dropout</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"2\" valign=\"top\">nonparametric</th>\n",
       "      <th>False</th>\n",
       "      <th>False</th>\n",
       "      <th>False</th>\n",
       "      <th>False</th>\n",
       "      <th>ranking</th>\n",
       "      <th>ranking</th>\n",
       "      <th>linear</th>\n",
       "      <th>[32, 32]</th>\n",
       "      <th>0.01</th>\n",
       "      <th>0.0</th>\n",
       "      <td>0.068830</td>\n",
       "      <td>0.358006</td>\n",
       "      <td>1.458794</td>\n",
       "      <td>0.209091</td>\n",
       "      <td>0.327273</td>\n",
       "      <td>0.131818</td>\n",
       "      <td>77.172819</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>True</th>\n",
       "      <th>False</th>\n",
       "      <th>False</th>\n",
       "      <th>False</th>\n",
       "      <th>ranking</th>\n",
       "      <th>ranking</th>\n",
       "      <th>linear</th>\n",
       "      <th>[32, 32]</th>\n",
       "      <th>0.01</th>\n",
       "      <th>0.0</th>\n",
       "      <td>0.090033</td>\n",
       "      <td>0.356390</td>\n",
       "      <td>1.180901</td>\n",
       "      <td>0.186364</td>\n",
       "      <td>0.296591</td>\n",
       "      <td>0.118182</td>\n",
       "      <td>66.594767</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                                                                                                                                                                                                                                 mrr  \\\n",
       "surrogate     input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout             \n",
       "nonparametric False                                   False                                      False                            False                   ranking           ranking       linear       [32, 32]               0.01             0.0          0.068830   \n",
       "              True                                    False                                      False                            False                   ranking           ranking       linear       [32, 32]               0.01             0.0          0.090033   \n",
       "\n",
       "                                                                                                                                                                                                                                                                ndcg  \\\n",
       "surrogate     input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout             \n",
       "nonparametric False                                   False                                      False                            False                   ranking           ranking       linear       [32, 32]               0.01             0.0          0.358006   \n",
       "              True                                    False                                      False                            False                   ranking           ranking       linear       [32, 32]               0.01             0.0          0.356390   \n",
       "\n",
       "                                                                                                                                                                                                                                                               nrmse  \\\n",
       "surrogate     input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout             \n",
       "nonparametric False                                   False                                      False                            False                   ranking           ranking       linear       [32, 32]               0.01             0.0          1.458794   \n",
       "              True                                    False                                      False                            False                   ranking           ranking       linear       [32, 32]               0.01             0.0          1.180901   \n",
       "\n",
       "                                                                                                                                                                                                                                                            precision_10  \\\n",
       "surrogate     input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout                 \n",
       "nonparametric False                                   False                                      False                            False                   ranking           ranking       linear       [32, 32]               0.01             0.0              0.209091   \n",
       "              True                                    False                                      False                            False                   ranking           ranking       linear       [32, 32]               0.01             0.0              0.186364   \n",
       "\n",
       "                                                                                                                                                                                                                                                            precision_20  \\\n",
       "surrogate     input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout                 \n",
       "nonparametric False                                   False                                      False                            False                   ranking           ranking       linear       [32, 32]               0.01             0.0              0.327273   \n",
       "              True                                    False                                      False                            False                   ranking           ranking       linear       [32, 32]               0.01             0.0              0.296591   \n",
       "\n",
       "                                                                                                                                                                                                                                                            precision_5  \\\n",
       "surrogate     input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout                \n",
       "nonparametric False                                   False                                      False                            False                   ranking           ranking       linear       [32, 32]               0.01             0.0             0.131818   \n",
       "              True                                    False                                      False                            False                   ranking           ranking       linear       [32, 32]               0.01             0.0             0.118182   \n",
       "\n",
       "                                                                                                                                                                                                                                                                smape  \n",
       "surrogate     input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout             \n",
       "nonparametric False                                   False                                      False                            False                   ranking           ranking       linear       [32, 32]               0.01             0.0          77.172819  \n",
       "              True                                    False                                      False                            False                   ranking           ranking       linear       [32, 32]               0.01             0.0          66.594767  "
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Rows of the nonparametric surrogate with `use_ranks` disabled. The index levels hold the\n",
    "# stringified config values produced by `flatten_dict`, hence the comparison against 'False'.\n",
    "results = df.query(\"surrogate == 'nonparametric' and nonparametric_use_ranks == 'False'\")\n",
    "results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "cb7732af",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select the rows using neither the simple dataset features nor the catch22 features and store\n",
    "# their column-wise mean under the \"Nonparametric Regression\" label.\n",
    "best = results.query(\n",
    "    \"input_flags_use_simple_dataset_features == 'False' and input_flags_use_catch22_features == 'False'\"\n",
    ")\n",
    "best_results[\"Nonparametric Regression\"] = best.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f42fa765",
   "metadata": {},
   "source": [
    "### Nonparametric Ranking"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "8db5940c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>mrr</th>\n",
       "      <th>ndcg</th>\n",
       "      <th>nrmse</th>\n",
       "      <th>precision_10</th>\n",
       "      <th>precision_20</th>\n",
       "      <th>precision_5</th>\n",
       "      <th>smape</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>surrogate</th>\n",
       "      <th>input_flags_use_simple_dataset_features</th>\n",
       "      <th>input_flags_use_seasonal_naive_performance</th>\n",
       "      <th>input_flags_use_catch22_features</th>\n",
       "      <th>nonparametric_use_ranks</th>\n",
       "      <th>xgboost_objective</th>\n",
       "      <th>mlp_objective</th>\n",
       "      <th>mlp_discount</th>\n",
       "      <th>mlp_hidden_layer_sizes</th>\n",
       "      <th>mlp_weight_decay</th>\n",
       "      <th>mlp_dropout</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"2\" valign=\"top\">nonparametric</th>\n",
       "      <th>False</th>\n",
       "      <th>False</th>\n",
       "      <th>False</th>\n",
       "      <th>True</th>\n",
       "      <th>ranking</th>\n",
       "      <th>ranking</th>\n",
       "      <th>linear</th>\n",
       "      <th>[32, 32]</th>\n",
       "      <th>0.01</th>\n",
       "      <th>0.0</th>\n",
       "      <td>0.067812</td>\n",
       "      <td>0.358657</td>\n",
       "      <td>828.004101</td>\n",
       "      <td>0.209091</td>\n",
       "      <td>0.330682</td>\n",
       "      <td>0.1</td>\n",
       "      <td>199.018165</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>True</th>\n",
       "      <th>False</th>\n",
       "      <th>False</th>\n",
       "      <th>True</th>\n",
       "      <th>ranking</th>\n",
       "      <th>ranking</th>\n",
       "      <th>linear</th>\n",
       "      <th>[32, 32]</th>\n",
       "      <th>0.01</th>\n",
       "      <th>0.0</th>\n",
       "      <td>0.075944</td>\n",
       "      <td>0.356313</td>\n",
       "      <td>828.004085</td>\n",
       "      <td>0.200000</td>\n",
       "      <td>0.297727</td>\n",
       "      <td>0.1</td>\n",
       "      <td>199.008561</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                                                                                                                                                                                                                                 mrr  \\\n",
       "surrogate     input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout             \n",
       "nonparametric False                                   False                                      False                            True                    ranking           ranking       linear       [32, 32]               0.01             0.0          0.067812   \n",
       "              True                                    False                                      False                            True                    ranking           ranking       linear       [32, 32]               0.01             0.0          0.075944   \n",
       "\n",
       "                                                                                                                                                                                                                                                                ndcg  \\\n",
       "surrogate     input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout             \n",
       "nonparametric False                                   False                                      False                            True                    ranking           ranking       linear       [32, 32]               0.01             0.0          0.358657   \n",
       "              True                                    False                                      False                            True                    ranking           ranking       linear       [32, 32]               0.01             0.0          0.356313   \n",
       "\n",
       "                                                                                                                                                                                                                                                                 nrmse  \\\n",
       "surrogate     input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout               \n",
       "nonparametric False                                   False                                      False                            True                    ranking           ranking       linear       [32, 32]               0.01             0.0          828.004101   \n",
       "              True                                    False                                      False                            True                    ranking           ranking       linear       [32, 32]               0.01             0.0          828.004085   \n",
       "\n",
       "                                                                                                                                                                                                                                                            precision_10  \\\n",
       "surrogate     input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout                 \n",
       "nonparametric False                                   False                                      False                            True                    ranking           ranking       linear       [32, 32]               0.01             0.0              0.209091   \n",
       "              True                                    False                                      False                            True                    ranking           ranking       linear       [32, 32]               0.01             0.0              0.200000   \n",
       "\n",
       "                                                                                                                                                                                                                                                            precision_20  \\\n",
       "surrogate     input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout                 \n",
       "nonparametric False                                   False                                      False                            True                    ranking           ranking       linear       [32, 32]               0.01             0.0              0.330682   \n",
       "              True                                    False                                      False                            True                    ranking           ranking       linear       [32, 32]               0.01             0.0              0.297727   \n",
       "\n",
       "                                                                                                                                                                                                                                                            precision_5  \\\n",
       "surrogate     input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout                \n",
       "nonparametric False                                   False                                      False                            True                    ranking           ranking       linear       [32, 32]               0.01             0.0                  0.1   \n",
       "              True                                    False                                      False                            True                    ranking           ranking       linear       [32, 32]               0.01             0.0                  0.1   \n",
       "\n",
       "                                                                                                                                                                                                                                                                 smape  \n",
       "surrogate     input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout              \n",
       "nonparametric False                                   False                                      False                            True                    ranking           ranking       linear       [32, 32]               0.01             0.0          199.018165  \n",
       "              True                                    False                                      False                            True                    ranking           ranking       linear       [32, 32]               0.01             0.0          199.008561  "
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Rows of the nonparametric surrogate with `use_ranks` enabled. The index levels hold the\n",
    "# stringified config values produced by `flatten_dict`, hence the comparison against 'True'.\n",
    "results = df.query(\"surrogate == 'nonparametric' and nonparametric_use_ranks == 'True'\")\n",
    "results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "e0caeb9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select the rows using neither the simple dataset features nor the catch22 features and store\n",
    "# their column-wise mean under the \"Nonparametric Ranking\" label.\n",
    "best = results.query(\n",
    "    \"input_flags_use_simple_dataset_features == 'False' and input_flags_use_catch22_features == 'False'\"\n",
    ")\n",
    "best_results[\"Nonparametric Ranking\"] = best.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b83fb32c",
   "metadata": {},
   "source": [
    "### XGBoost Regression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "3e28528c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>mrr</th>\n",
       "      <th>ndcg</th>\n",
       "      <th>nrmse</th>\n",
       "      <th>precision_10</th>\n",
       "      <th>precision_20</th>\n",
       "      <th>precision_5</th>\n",
       "      <th>smape</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>surrogate</th>\n",
       "      <th>input_flags_use_simple_dataset_features</th>\n",
       "      <th>input_flags_use_seasonal_naive_performance</th>\n",
       "      <th>input_flags_use_catch22_features</th>\n",
       "      <th>nonparametric_use_ranks</th>\n",
       "      <th>xgboost_objective</th>\n",
       "      <th>mlp_objective</th>\n",
       "      <th>mlp_discount</th>\n",
       "      <th>mlp_hidden_layer_sizes</th>\n",
       "      <th>mlp_weight_decay</th>\n",
       "      <th>mlp_dropout</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"2\" valign=\"top\">xgboost</th>\n",
       "      <th>False</th>\n",
       "      <th>False</th>\n",
       "      <th>False</th>\n",
       "      <th>False</th>\n",
       "      <th>regression</th>\n",
       "      <th>ranking</th>\n",
       "      <th>linear</th>\n",
       "      <th>[32, 32]</th>\n",
       "      <th>0.01</th>\n",
       "      <th>0.0</th>\n",
       "      <td>0.056846</td>\n",
       "      <td>0.353395</td>\n",
       "      <td>1.458776</td>\n",
       "      <td>0.220455</td>\n",
       "      <td>0.326136</td>\n",
       "      <td>0.131818</td>\n",
       "      <td>77.172885</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>True</th>\n",
       "      <th>False</th>\n",
       "      <th>False</th>\n",
       "      <th>False</th>\n",
       "      <th>regression</th>\n",
       "      <th>ranking</th>\n",
       "      <th>linear</th>\n",
       "      <th>[32, 32]</th>\n",
       "      <th>0.01</th>\n",
       "      <th>0.0</th>\n",
       "      <td>0.014944</td>\n",
       "      <td>0.284394</td>\n",
       "      <td>1.911431</td>\n",
       "      <td>0.072727</td>\n",
       "      <td>0.161364</td>\n",
       "      <td>0.040909</td>\n",
       "      <td>84.135481</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                                                                                                                                                                                                                             mrr  \\\n",
       "surrogate input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout             \n",
       "xgboost   False                                   False                                      False                            False                   regression        ranking       linear       [32, 32]               0.01             0.0          0.056846   \n",
       "          True                                    False                                      False                            False                   regression        ranking       linear       [32, 32]               0.01             0.0          0.014944   \n",
       "\n",
       "                                                                                                                                                                                                                                                            ndcg  \\\n",
       "surrogate input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout             \n",
       "xgboost   False                                   False                                      False                            False                   regression        ranking       linear       [32, 32]               0.01             0.0          0.353395   \n",
       "          True                                    False                                      False                            False                   regression        ranking       linear       [32, 32]               0.01             0.0          0.284394   \n",
       "\n",
       "                                                                                                                                                                                                                                                           nrmse  \\\n",
       "surrogate input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout             \n",
       "xgboost   False                                   False                                      False                            False                   regression        ranking       linear       [32, 32]               0.01             0.0          1.458776   \n",
       "          True                                    False                                      False                            False                   regression        ranking       linear       [32, 32]               0.01             0.0          1.911431   \n",
       "\n",
       "                                                                                                                                                                                                                                                        precision_10  \\\n",
       "surrogate input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout                 \n",
       "xgboost   False                                   False                                      False                            False                   regression        ranking       linear       [32, 32]               0.01             0.0              0.220455   \n",
       "          True                                    False                                      False                            False                   regression        ranking       linear       [32, 32]               0.01             0.0              0.072727   \n",
       "\n",
       "                                                                                                                                                                                                                                                        precision_20  \\\n",
       "surrogate input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout                 \n",
       "xgboost   False                                   False                                      False                            False                   regression        ranking       linear       [32, 32]               0.01             0.0              0.326136   \n",
       "          True                                    False                                      False                            False                   regression        ranking       linear       [32, 32]               0.01             0.0              0.161364   \n",
       "\n",
       "                                                                                                                                                                                                                                                        precision_5  \\\n",
       "surrogate input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout                \n",
       "xgboost   False                                   False                                      False                            False                   regression        ranking       linear       [32, 32]               0.01             0.0             0.131818   \n",
       "          True                                    False                                      False                            False                   regression        ranking       linear       [32, 32]               0.01             0.0             0.040909   \n",
       "\n",
       "                                                                                                                                                                                                                                                            smape  \n",
       "surrogate input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout             \n",
       "xgboost   False                                   False                                      False                            False                   regression        ranking       linear       [32, 32]               0.01             0.0          77.172885  \n",
       "          True                                    False                                      False                            False                   regression        ranking       linear       [32, 32]               0.01             0.0          84.135481  "
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results = df.query(\"surrogate == 'xgboost' and xgboost_objective == 'regression'\")\n",
    "results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "89b33383",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the configuration with neither simple dataset features nor catch22 features enabled.\n",
    "# The flags are compared against the string 'False' (presumably because the configuration\n",
    "# values were stringified by `flatten_dict` - verify against the cell that builds `df`).\n",
    "best = results.query(\n",
    "    \"input_flags_use_simple_dataset_features == 'False'\"\n",
    "    \" and input_flags_use_catch22_features == 'False'\"\n",
    ")\n",
    "# Average every metric column over the selected rows and record it under the model's label.\n",
    "best_results[\"XGBoost Regression\"] = best.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a70ed9d",
   "metadata": {},
   "source": [
    "### XGBoost Pairwise Ranking"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "23a79b1e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>mrr</th>\n",
       "      <th>ndcg</th>\n",
       "      <th>nrmse</th>\n",
       "      <th>precision_10</th>\n",
       "      <th>precision_20</th>\n",
       "      <th>precision_5</th>\n",
       "      <th>smape</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>surrogate</th>\n",
       "      <th>input_flags_use_simple_dataset_features</th>\n",
       "      <th>input_flags_use_seasonal_naive_performance</th>\n",
       "      <th>input_flags_use_catch22_features</th>\n",
       "      <th>nonparametric_use_ranks</th>\n",
       "      <th>xgboost_objective</th>\n",
       "      <th>mlp_objective</th>\n",
       "      <th>mlp_discount</th>\n",
       "      <th>mlp_hidden_layer_sizes</th>\n",
       "      <th>mlp_weight_decay</th>\n",
       "      <th>mlp_dropout</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"2\" valign=\"top\">xgboost</th>\n",
       "      <th>False</th>\n",
       "      <th>False</th>\n",
       "      <th>False</th>\n",
       "      <th>False</th>\n",
       "      <th>ranking</th>\n",
       "      <th>ranking</th>\n",
       "      <th>linear</th>\n",
       "      <th>[32, 32]</th>\n",
       "      <th>0.01</th>\n",
       "      <th>0.0</th>\n",
       "      <td>0.066369</td>\n",
       "      <td>0.356266</td>\n",
       "      <td>4.123228</td>\n",
       "      <td>0.215909</td>\n",
       "      <td>0.328409</td>\n",
       "      <td>0.122727</td>\n",
       "      <td>132.663563</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>True</th>\n",
       "      <th>False</th>\n",
       "      <th>False</th>\n",
       "      <th>False</th>\n",
       "      <th>ranking</th>\n",
       "      <th>ranking</th>\n",
       "      <th>linear</th>\n",
       "      <th>[32, 32]</th>\n",
       "      <th>0.01</th>\n",
       "      <th>0.0</th>\n",
       "      <td>0.044221</td>\n",
       "      <td>0.326935</td>\n",
       "      <td>9.929184</td>\n",
       "      <td>0.159091</td>\n",
       "      <td>0.226136</td>\n",
       "      <td>0.118182</td>\n",
       "      <td>161.809045</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                                                                                                                                                                                                                             mrr  \\\n",
       "surrogate input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout             \n",
       "xgboost   False                                   False                                      False                            False                   ranking           ranking       linear       [32, 32]               0.01             0.0          0.066369   \n",
       "          True                                    False                                      False                            False                   ranking           ranking       linear       [32, 32]               0.01             0.0          0.044221   \n",
       "\n",
       "                                                                                                                                                                                                                                                            ndcg  \\\n",
       "surrogate input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout             \n",
       "xgboost   False                                   False                                      False                            False                   ranking           ranking       linear       [32, 32]               0.01             0.0          0.356266   \n",
       "          True                                    False                                      False                            False                   ranking           ranking       linear       [32, 32]               0.01             0.0          0.326935   \n",
       "\n",
       "                                                                                                                                                                                                                                                           nrmse  \\\n",
       "surrogate input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout             \n",
       "xgboost   False                                   False                                      False                            False                   ranking           ranking       linear       [32, 32]               0.01             0.0          4.123228   \n",
       "          True                                    False                                      False                            False                   ranking           ranking       linear       [32, 32]               0.01             0.0          9.929184   \n",
       "\n",
       "                                                                                                                                                                                                                                                        precision_10  \\\n",
       "surrogate input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout                 \n",
       "xgboost   False                                   False                                      False                            False                   ranking           ranking       linear       [32, 32]               0.01             0.0              0.215909   \n",
       "          True                                    False                                      False                            False                   ranking           ranking       linear       [32, 32]               0.01             0.0              0.159091   \n",
       "\n",
       "                                                                                                                                                                                                                                                        precision_20  \\\n",
       "surrogate input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout                 \n",
       "xgboost   False                                   False                                      False                            False                   ranking           ranking       linear       [32, 32]               0.01             0.0              0.328409   \n",
       "          True                                    False                                      False                            False                   ranking           ranking       linear       [32, 32]               0.01             0.0              0.226136   \n",
       "\n",
       "                                                                                                                                                                                                                                                        precision_5  \\\n",
       "surrogate input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout                \n",
       "xgboost   False                                   False                                      False                            False                   ranking           ranking       linear       [32, 32]               0.01             0.0             0.122727   \n",
       "          True                                    False                                      False                            False                   ranking           ranking       linear       [32, 32]               0.01             0.0             0.118182   \n",
       "\n",
       "                                                                                                                                                                                                                                                             smape  \n",
       "surrogate input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout              \n",
       "xgboost   False                                   False                                      False                            False                   ranking           ranking       linear       [32, 32]               0.01             0.0          132.663563  \n",
       "          True                                    False                                      False                            False                   ranking           ranking       linear       [32, 32]               0.01             0.0          161.809045  "
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results = df.query(\"surrogate == 'xgboost' and xgboost_objective == 'ranking'\")\n",
    "results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "e01845f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the configuration with neither simple dataset features nor catch22 features enabled.\n",
    "# The flags are compared against the string 'False' (presumably because the configuration\n",
    "# values were stringified by `flatten_dict` - verify against the cell that builds `df`).\n",
    "best = results.query(\n",
    "    \"input_flags_use_simple_dataset_features == 'False'\"\n",
    "    \" and input_flags_use_catch22_features == 'False'\"\n",
    ")\n",
    "# Average every metric column over the selected rows and record it under the model's label.\n",
    "best_results[\"XGBoost Ranking\"] = best.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bbd021a9",
   "metadata": {},
   "source": [
    "### MLP Regression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "5ab9a31c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>mrr</th>\n",
       "      <th>ndcg</th>\n",
       "      <th>nrmse</th>\n",
       "      <th>precision_10</th>\n",
       "      <th>precision_20</th>\n",
       "      <th>precision_5</th>\n",
       "      <th>smape</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>surrogate</th>\n",
       "      <th>input_flags_use_simple_dataset_features</th>\n",
       "      <th>input_flags_use_seasonal_naive_performance</th>\n",
       "      <th>input_flags_use_catch22_features</th>\n",
       "      <th>nonparametric_use_ranks</th>\n",
       "      <th>xgboost_objective</th>\n",
       "      <th>mlp_objective</th>\n",
       "      <th>mlp_discount</th>\n",
       "      <th>mlp_hidden_layer_sizes</th>\n",
       "      <th>mlp_weight_decay</th>\n",
       "      <th>mlp_dropout</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"10\" valign=\"top\">mlp</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">False</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">False</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">False</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">False</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">ranking</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">regression</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">None</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">[32, 32]</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">0.01</th>\n",
       "      <th>0.0</th>\n",
       "      <td>0.050717</td>\n",
       "      <td>0.322365</td>\n",
       "      <td>1.453022</td>\n",
       "      <td>0.131818</td>\n",
       "      <td>0.259091</td>\n",
       "      <td>0.086364</td>\n",
       "      <td>78.241485</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0</th>\n",
       "      <td>0.049411</td>\n",
       "      <td>0.331320</td>\n",
       "      <td>1.452421</td>\n",
       "      <td>0.140909</td>\n",
       "      <td>0.247727</td>\n",
       "      <td>0.095455</td>\n",
       "      <td>78.196484</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0</th>\n",
       "      <td>0.028207</td>\n",
       "      <td>0.317414</td>\n",
       "      <td>1.452144</td>\n",
       "      <td>0.147727</td>\n",
       "      <td>0.260227</td>\n",
       "      <td>0.068182</td>\n",
       "      <td>78.195202</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0</th>\n",
       "      <td>0.035784</td>\n",
       "      <td>0.313911</td>\n",
       "      <td>1.452820</td>\n",
       "      <td>0.138636</td>\n",
       "      <td>0.243182</td>\n",
       "      <td>0.063636</td>\n",
       "      <td>78.301108</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0</th>\n",
       "      <td>0.028865</td>\n",
       "      <td>0.317924</td>\n",
       "      <td>1.453907</td>\n",
       "      <td>0.127273</td>\n",
       "      <td>0.251136</td>\n",
       "      <td>0.072727</td>\n",
       "      <td>78.243142</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"5\" valign=\"top\">True</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">False</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">False</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">False</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">ranking</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">regression</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">None</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">[32, 32]</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">0.01</th>\n",
       "      <th>0.0</th>\n",
       "      <td>0.055819</td>\n",
       "      <td>0.315527</td>\n",
       "      <td>1.614736</td>\n",
       "      <td>0.136364</td>\n",
       "      <td>0.200000</td>\n",
       "      <td>0.063636</td>\n",
       "      <td>71.872001</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0</th>\n",
       "      <td>0.035180</td>\n",
       "      <td>0.307530</td>\n",
       "      <td>1.529296</td>\n",
       "      <td>0.131818</td>\n",
       "      <td>0.197727</td>\n",
       "      <td>0.077273</td>\n",
       "      <td>68.845910</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0</th>\n",
       "      <td>0.045626</td>\n",
       "      <td>0.316765</td>\n",
       "      <td>1.554837</td>\n",
       "      <td>0.127273</td>\n",
       "      <td>0.207955</td>\n",
       "      <td>0.104545</td>\n",
       "      <td>75.792085</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0</th>\n",
       "      <td>0.043655</td>\n",
       "      <td>0.296455</td>\n",
       "      <td>1.598574</td>\n",
       "      <td>0.113636</td>\n",
       "      <td>0.192045</td>\n",
       "      <td>0.040909</td>\n",
       "      <td>74.641335</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0</th>\n",
       "      <td>0.075376</td>\n",
       "      <td>0.322091</td>\n",
       "      <td>1.611170</td>\n",
       "      <td>0.154545</td>\n",
       "      <td>0.237500</td>\n",
       "      <td>0.063636</td>\n",
       "      <td>71.106983</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                                                                                                                                                                                                                             mrr  \\\n",
       "surrogate input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout             \n",
       "mlp       False                                   False                                      False                            False                   ranking           regression    None         [32, 32]               0.01             0.0          0.050717   \n",
       "                                                                                                                                                                                                                                           0.0          0.049411   \n",
       "                                                                                                                                                                                                                                           0.0          0.028207   \n",
       "                                                                                                                                                                                                                                           0.0          0.035784   \n",
       "                                                                                                                                                                                                                                           0.0          0.028865   \n",
       "          True                                    False                                      False                            False                   ranking           regression    None         [32, 32]               0.01             0.0          0.055819   \n",
       "                                                                                                                                                                                                                                           0.0          0.035180   \n",
       "                                                                                                                                                                                                                                           0.0          0.045626   \n",
       "                                                                                                                                                                                                                                           0.0          0.043655   \n",
       "                                                                                                                                                                                                                                           0.0          0.075376   \n",
       "\n",
       "                                                                                                                                                                                                                                                            ndcg  \\\n",
       "surrogate input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout             \n",
       "mlp       False                                   False                                      False                            False                   ranking           regression    None         [32, 32]               0.01             0.0          0.322365   \n",
       "                                                                                                                                                                                                                                           0.0          0.331320   \n",
       "                                                                                                                                                                                                                                           0.0          0.317414   \n",
       "                                                                                                                                                                                                                                           0.0          0.313911   \n",
       "                                                                                                                                                                                                                                           0.0          0.317924   \n",
       "          True                                    False                                      False                            False                   ranking           regression    None         [32, 32]               0.01             0.0          0.315527   \n",
       "                                                                                                                                                                                                                                           0.0          0.307530   \n",
       "                                                                                                                                                                                                                                           0.0          0.316765   \n",
       "                                                                                                                                                                                                                                           0.0          0.296455   \n",
       "                                                                                                                                                                                                                                           0.0          0.322091   \n",
       "\n",
       "                                                                                                                                                                                                                                                           nrmse  \\\n",
       "surrogate input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout             \n",
       "mlp       False                                   False                                      False                            False                   ranking           regression    None         [32, 32]               0.01             0.0          1.453022   \n",
       "                                                                                                                                                                                                                                           0.0          1.452421   \n",
       "                                                                                                                                                                                                                                           0.0          1.452144   \n",
       "                                                                                                                                                                                                                                           0.0          1.452820   \n",
       "                                                                                                                                                                                                                                           0.0          1.453907   \n",
       "          True                                    False                                      False                            False                   ranking           regression    None         [32, 32]               0.01             0.0          1.614736   \n",
       "                                                                                                                                                                                                                                           0.0          1.529296   \n",
       "                                                                                                                                                                                                                                           0.0          1.554837   \n",
       "                                                                                                                                                                                                                                           0.0          1.598574   \n",
       "                                                                                                                                                                                                                                           0.0          1.611170   \n",
       "\n",
       "                                                                                                                                                                                                                                                        precision_10  \\\n",
       "surrogate input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout                 \n",
       "mlp       False                                   False                                      False                            False                   ranking           regression    None         [32, 32]               0.01             0.0              0.131818   \n",
       "                                                                                                                                                                                                                                           0.0              0.140909   \n",
       "                                                                                                                                                                                                                                           0.0              0.147727   \n",
       "                                                                                                                                                                                                                                           0.0              0.138636   \n",
       "                                                                                                                                                                                                                                           0.0              0.127273   \n",
       "          True                                    False                                      False                            False                   ranking           regression    None         [32, 32]               0.01             0.0              0.136364   \n",
       "                                                                                                                                                                                                                                           0.0              0.131818   \n",
       "                                                                                                                                                                                                                                           0.0              0.127273   \n",
       "                                                                                                                                                                                                                                           0.0              0.113636   \n",
       "                                                                                                                                                                                                                                           0.0              0.154545   \n",
       "\n",
       "                                                                                                                                                                                                                                                        precision_20  \\\n",
       "surrogate input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout                 \n",
       "mlp       False                                   False                                      False                            False                   ranking           regression    None         [32, 32]               0.01             0.0              0.259091   \n",
       "                                                                                                                                                                                                                                           0.0              0.247727   \n",
       "                                                                                                                                                                                                                                           0.0              0.260227   \n",
       "                                                                                                                                                                                                                                           0.0              0.243182   \n",
       "                                                                                                                                                                                                                                           0.0              0.251136   \n",
       "          True                                    False                                      False                            False                   ranking           regression    None         [32, 32]               0.01             0.0              0.200000   \n",
       "                                                                                                                                                                                                                                           0.0              0.197727   \n",
       "                                                                                                                                                                                                                                           0.0              0.207955   \n",
       "                                                                                                                                                                                                                                           0.0              0.192045   \n",
       "                                                                                                                                                                                                                                           0.0              0.237500   \n",
       "\n",
       "                                                                                                                                                                                                                                                        precision_5  \\\n",
       "surrogate input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout                \n",
       "mlp       False                                   False                                      False                            False                   ranking           regression    None         [32, 32]               0.01             0.0             0.086364   \n",
       "                                                                                                                                                                                                                                           0.0             0.095455   \n",
       "                                                                                                                                                                                                                                           0.0             0.068182   \n",
       "                                                                                                                                                                                                                                           0.0             0.063636   \n",
       "                                                                                                                                                                                                                                           0.0             0.072727   \n",
       "          True                                    False                                      False                            False                   ranking           regression    None         [32, 32]               0.01             0.0             0.063636   \n",
       "                                                                                                                                                                                                                                           0.0             0.077273   \n",
       "                                                                                                                                                                                                                                           0.0             0.104545   \n",
       "                                                                                                                                                                                                                                           0.0             0.040909   \n",
       "                                                                                                                                                                                                                                           0.0             0.063636   \n",
       "\n",
       "                                                                                                                                                                                                                                                            smape  \n",
       "surrogate input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout             \n",
       "mlp       False                                   False                                      False                            False                   ranking           regression    None         [32, 32]               0.01             0.0          78.241485  \n",
       "                                                                                                                                                                                                                                           0.0          78.196484  \n",
       "                                                                                                                                                                                                                                           0.0          78.195202  \n",
       "                                                                                                                                                                                                                                           0.0          78.301108  \n",
       "                                                                                                                                                                                                                                           0.0          78.243142  \n",
       "          True                                    False                                      False                            False                   ranking           regression    None         [32, 32]               0.01             0.0          71.872001  \n",
       "                                                                                                                                                                                                                                           0.0          68.845910  \n",
       "                                                                                                                                                                                                                                           0.0          75.792085  \n",
       "                                                                                                                                                                                                                                           0.0          74.641335  \n",
       "                                                                                                                                                                                                                                           0.0          71.106983  "
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results = df.query(\"surrogate == 'mlp' and mlp_objective == 'regression'\")\n",
    "results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "7490a3cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick the configuration reported in the table. The config values are compared as\n",
    "# strings (hence the quoted 'False' / '0.01' literals), so an exact textual match is required.\n",
    "best = results.query(\n",
    "    \"mlp_hidden_layer_sizes == '[32, 32]' and mlp_weight_decay == '0.01' and \"\n",
    "    \"input_flags_use_simple_dataset_features == 'False' and input_flags_use_catch22_features == 'False'\"\n",
    ")\n",
    "# Fail loudly instead of silently storing an all-NaN row when the filter matches nothing.\n",
    "assert not best.empty, \"no 'MLP Regression' runs match the selected configuration\"\n",
    "# Average every metric column over the selected runs.\n",
    "best_results[\"MLP Regression\"] = best.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "641f585b",
   "metadata": {},
   "source": [
    "### MLP Ranking"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "d43dd004",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>mrr</th>\n",
       "      <th>ndcg</th>\n",
       "      <th>nrmse</th>\n",
       "      <th>precision_10</th>\n",
       "      <th>precision_20</th>\n",
       "      <th>precision_5</th>\n",
       "      <th>smape</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>surrogate</th>\n",
       "      <th>input_flags_use_simple_dataset_features</th>\n",
       "      <th>input_flags_use_seasonal_naive_performance</th>\n",
       "      <th>input_flags_use_catch22_features</th>\n",
       "      <th>nonparametric_use_ranks</th>\n",
       "      <th>xgboost_objective</th>\n",
       "      <th>mlp_objective</th>\n",
       "      <th>mlp_discount</th>\n",
       "      <th>mlp_hidden_layer_sizes</th>\n",
       "      <th>mlp_weight_decay</th>\n",
       "      <th>mlp_dropout</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"15\" valign=\"top\">mlp</th>\n",
       "      <th rowspan=\"10\" valign=\"top\">False</th>\n",
       "      <th rowspan=\"10\" valign=\"top\">False</th>\n",
       "      <th rowspan=\"10\" valign=\"top\">False</th>\n",
       "      <th rowspan=\"10\" valign=\"top\">False</th>\n",
       "      <th rowspan=\"10\" valign=\"top\">ranking</th>\n",
       "      <th rowspan=\"10\" valign=\"top\">ranking</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">None</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">[32, 32]</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">0.01</th>\n",
       "      <th>0.0</th>\n",
       "      <td>0.049984</td>\n",
       "      <td>0.345304</td>\n",
       "      <td>5.650850</td>\n",
       "      <td>0.206818</td>\n",
       "      <td>0.329545</td>\n",
       "      <td>0.104545</td>\n",
       "      <td>153.221199</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0</th>\n",
       "      <td>0.049984</td>\n",
       "      <td>0.344603</td>\n",
       "      <td>7.056557</td>\n",
       "      <td>0.209091</td>\n",
       "      <td>0.330682</td>\n",
       "      <td>0.104545</td>\n",
       "      <td>141.993855</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0</th>\n",
       "      <td>0.049984</td>\n",
       "      <td>0.345157</td>\n",
       "      <td>5.676417</td>\n",
       "      <td>0.211364</td>\n",
       "      <td>0.335227</td>\n",
       "      <td>0.095455</td>\n",
       "      <td>145.619724</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0</th>\n",
       "      <td>0.049864</td>\n",
       "      <td>0.345775</td>\n",
       "      <td>14.896397</td>\n",
       "      <td>0.211364</td>\n",
       "      <td>0.337500</td>\n",
       "      <td>0.104545</td>\n",
       "      <td>176.273566</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0</th>\n",
       "      <td>0.049977</td>\n",
       "      <td>0.348634</td>\n",
       "      <td>5.726610</td>\n",
       "      <td>0.211364</td>\n",
       "      <td>0.334091</td>\n",
       "      <td>0.104545</td>\n",
       "      <td>148.668324</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"5\" valign=\"top\">linear</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">[32, 32]</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">0.01</th>\n",
       "      <th>0.0</th>\n",
       "      <td>0.164067</td>\n",
       "      <td>0.436737</td>\n",
       "      <td>12.183844</td>\n",
       "      <td>0.284091</td>\n",
       "      <td>0.337500</td>\n",
       "      <td>0.227273</td>\n",
       "      <td>152.031473</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0</th>\n",
       "      <td>0.164067</td>\n",
       "      <td>0.437882</td>\n",
       "      <td>11.652070</td>\n",
       "      <td>0.281818</td>\n",
       "      <td>0.351136</td>\n",
       "      <td>0.240909</td>\n",
       "      <td>150.092414</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0</th>\n",
       "      <td>0.164067</td>\n",
       "      <td>0.438864</td>\n",
       "      <td>10.906156</td>\n",
       "      <td>0.284091</td>\n",
       "      <td>0.347727</td>\n",
       "      <td>0.240909</td>\n",
       "      <td>149.775588</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0</th>\n",
       "      <td>0.164067</td>\n",
       "      <td>0.436906</td>\n",
       "      <td>9.651461</td>\n",
       "      <td>0.281818</td>\n",
       "      <td>0.345455</td>\n",
       "      <td>0.236364</td>\n",
       "      <td>147.468341</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0</th>\n",
       "      <td>0.164067</td>\n",
       "      <td>0.436518</td>\n",
       "      <td>10.088156</td>\n",
       "      <td>0.284091</td>\n",
       "      <td>0.337500</td>\n",
       "      <td>0.245455</td>\n",
       "      <td>149.524794</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"5\" valign=\"top\">True</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">False</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">False</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">False</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">ranking</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">ranking</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">linear</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">[32, 32]</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">0.01</th>\n",
       "      <th>0.0</th>\n",
       "      <td>0.133345</td>\n",
       "      <td>0.418552</td>\n",
       "      <td>27.519127</td>\n",
       "      <td>0.220455</td>\n",
       "      <td>0.264773</td>\n",
       "      <td>0.195455</td>\n",
       "      <td>172.977323</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0</th>\n",
       "      <td>0.110351</td>\n",
       "      <td>0.389889</td>\n",
       "      <td>19.749552</td>\n",
       "      <td>0.190909</td>\n",
       "      <td>0.253409</td>\n",
       "      <td>0.168182</td>\n",
       "      <td>168.360391</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0</th>\n",
       "      <td>0.129729</td>\n",
       "      <td>0.393172</td>\n",
       "      <td>21.078210</td>\n",
       "      <td>0.206818</td>\n",
       "      <td>0.268182</td>\n",
       "      <td>0.159091</td>\n",
       "      <td>170.430236</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0</th>\n",
       "      <td>0.102149</td>\n",
       "      <td>0.397734</td>\n",
       "      <td>19.361855</td>\n",
       "      <td>0.231818</td>\n",
       "      <td>0.280682</td>\n",
       "      <td>0.168182</td>\n",
       "      <td>166.407624</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0</th>\n",
       "      <td>0.094976</td>\n",
       "      <td>0.399500</td>\n",
       "      <td>27.932787</td>\n",
       "      <td>0.195455</td>\n",
       "      <td>0.282955</td>\n",
       "      <td>0.186364</td>\n",
       "      <td>173.667545</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                                                                                                                                                                                                                             mrr  \\\n",
       "surrogate input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout             \n",
       "mlp       False                                   False                                      False                            False                   ranking           ranking       None         [32, 32]               0.01             0.0          0.049984   \n",
       "                                                                                                                                                                                                                                           0.0          0.049984   \n",
       "                                                                                                                                                                                                                                           0.0          0.049984   \n",
       "                                                                                                                                                                                                                                           0.0          0.049864   \n",
       "                                                                                                                                                                                                                                           0.0          0.049977   \n",
       "                                                                                                                                                                                      linear       [32, 32]               0.01             0.0          0.164067   \n",
       "                                                                                                                                                                                                                                           0.0          0.164067   \n",
       "                                                                                                                                                                                                                                           0.0          0.164067   \n",
       "                                                                                                                                                                                                                                           0.0          0.164067   \n",
       "                                                                                                                                                                                                                                           0.0          0.164067   \n",
       "          True                                    False                                      False                            False                   ranking           ranking       linear       [32, 32]               0.01             0.0          0.133345   \n",
       "                                                                                                                                                                                                                                           0.0          0.110351   \n",
       "                                                                                                                                                                                                                                           0.0          0.129729   \n",
       "                                                                                                                                                                                                                                           0.0          0.102149   \n",
       "                                                                                                                                                                                                                                           0.0          0.094976   \n",
       "\n",
       "                                                                                                                                                                                                                                                            ndcg  \\\n",
       "surrogate input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout             \n",
       "mlp       False                                   False                                      False                            False                   ranking           ranking       None         [32, 32]               0.01             0.0          0.345304   \n",
       "                                                                                                                                                                                                                                           0.0          0.344603   \n",
       "                                                                                                                                                                                                                                           0.0          0.345157   \n",
       "                                                                                                                                                                                                                                           0.0          0.345775   \n",
       "                                                                                                                                                                                                                                           0.0          0.348634   \n",
       "                                                                                                                                                                                      linear       [32, 32]               0.01             0.0          0.436737   \n",
       "                                                                                                                                                                                                                                           0.0          0.437882   \n",
       "                                                                                                                                                                                                                                           0.0          0.438864   \n",
       "                                                                                                                                                                                                                                           0.0          0.436906   \n",
       "                                                                                                                                                                                                                                           0.0          0.436518   \n",
       "          True                                    False                                      False                            False                   ranking           ranking       linear       [32, 32]               0.01             0.0          0.418552   \n",
       "                                                                                                                                                                                                                                           0.0          0.389889   \n",
       "                                                                                                                                                                                                                                           0.0          0.393172   \n",
       "                                                                                                                                                                                                                                           0.0          0.397734   \n",
       "                                                                                                                                                                                                                                           0.0          0.399500   \n",
       "\n",
       "                                                                                                                                                                                                                                                            nrmse  \\\n",
       "surrogate input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout              \n",
       "mlp       False                                   False                                      False                            False                   ranking           ranking       None         [32, 32]               0.01             0.0           5.650850   \n",
       "                                                                                                                                                                                                                                           0.0           7.056557   \n",
       "                                                                                                                                                                                                                                           0.0           5.676417   \n",
       "                                                                                                                                                                                                                                           0.0          14.896397   \n",
       "                                                                                                                                                                                                                                           0.0           5.726610   \n",
       "                                                                                                                                                                                      linear       [32, 32]               0.01             0.0          12.183844   \n",
       "                                                                                                                                                                                                                                           0.0          11.652070   \n",
       "                                                                                                                                                                                                                                           0.0          10.906156   \n",
       "                                                                                                                                                                                                                                           0.0           9.651461   \n",
       "                                                                                                                                                                                                                                           0.0          10.088156   \n",
       "          True                                    False                                      False                            False                   ranking           ranking       linear       [32, 32]               0.01             0.0          27.519127   \n",
       "                                                                                                                                                                                                                                           0.0          19.749552   \n",
       "                                                                                                                                                                                                                                           0.0          21.078210   \n",
       "                                                                                                                                                                                                                                           0.0          19.361855   \n",
       "                                                                                                                                                                                                                                           0.0          27.932787   \n",
       "\n",
       "                                                                                                                                                                                                                                                        precision_10  \\\n",
       "surrogate input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout                 \n",
       "mlp       False                                   False                                      False                            False                   ranking           ranking       None         [32, 32]               0.01             0.0              0.206818   \n",
       "                                                                                                                                                                                                                                           0.0              0.209091   \n",
       "                                                                                                                                                                                                                                           0.0              0.211364   \n",
       "                                                                                                                                                                                                                                           0.0              0.211364   \n",
       "                                                                                                                                                                                                                                           0.0              0.211364   \n",
       "                                                                                                                                                                                      linear       [32, 32]               0.01             0.0              0.284091   \n",
       "                                                                                                                                                                                                                                           0.0              0.281818   \n",
       "                                                                                                                                                                                                                                           0.0              0.284091   \n",
       "                                                                                                                                                                                                                                           0.0              0.281818   \n",
       "                                                                                                                                                                                                                                           0.0              0.284091   \n",
       "          True                                    False                                      False                            False                   ranking           ranking       linear       [32, 32]               0.01             0.0              0.220455   \n",
       "                                                                                                                                                                                                                                           0.0              0.190909   \n",
       "                                                                                                                                                                                                                                           0.0              0.206818   \n",
       "                                                                                                                                                                                                                                           0.0              0.231818   \n",
       "                                                                                                                                                                                                                                           0.0              0.195455   \n",
       "\n",
       "                                                                                                                                                                                                                                                        precision_20  \\\n",
       "surrogate input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout                 \n",
       "mlp       False                                   False                                      False                            False                   ranking           ranking       None         [32, 32]               0.01             0.0              0.329545   \n",
       "                                                                                                                                                                                                                                           0.0              0.330682   \n",
       "                                                                                                                                                                                                                                           0.0              0.335227   \n",
       "                                                                                                                                                                                                                                           0.0              0.337500   \n",
       "                                                                                                                                                                                                                                           0.0              0.334091   \n",
       "                                                                                                                                                                                      linear       [32, 32]               0.01             0.0              0.337500   \n",
       "                                                                                                                                                                                                                                           0.0              0.351136   \n",
       "                                                                                                                                                                                                                                           0.0              0.347727   \n",
       "                                                                                                                                                                                                                                           0.0              0.345455   \n",
       "                                                                                                                                                                                                                                           0.0              0.337500   \n",
       "          True                                    False                                      False                            False                   ranking           ranking       linear       [32, 32]               0.01             0.0              0.264773   \n",
       "                                                                                                                                                                                                                                           0.0              0.253409   \n",
       "                                                                                                                                                                                                                                           0.0              0.268182   \n",
       "                                                                                                                                                                                                                                           0.0              0.280682   \n",
       "                                                                                                                                                                                                                                           0.0              0.282955   \n",
       "\n",
       "                                                                                                                                                                                                                                                        precision_5  \\\n",
       "surrogate input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout                \n",
       "mlp       False                                   False                                      False                            False                   ranking           ranking       None         [32, 32]               0.01             0.0             0.104545   \n",
       "                                                                                                                                                                                                                                           0.0             0.104545   \n",
       "                                                                                                                                                                                                                                           0.0             0.095455   \n",
       "                                                                                                                                                                                                                                           0.0             0.104545   \n",
       "                                                                                                                                                                                                                                           0.0             0.104545   \n",
       "                                                                                                                                                                                      linear       [32, 32]               0.01             0.0             0.227273   \n",
       "                                                                                                                                                                                                                                           0.0             0.240909   \n",
       "                                                                                                                                                                                                                                           0.0             0.240909   \n",
       "                                                                                                                                                                                                                                           0.0             0.236364   \n",
       "                                                                                                                                                                                                                                           0.0             0.245455   \n",
       "          True                                    False                                      False                            False                   ranking           ranking       linear       [32, 32]               0.01             0.0             0.195455   \n",
       "                                                                                                                                                                                                                                           0.0             0.168182   \n",
       "                                                                                                                                                                                                                                           0.0             0.159091   \n",
       "                                                                                                                                                                                                                                           0.0             0.168182   \n",
       "                                                                                                                                                                                                                                           0.0             0.186364   \n",
       "\n",
       "                                                                                                                                                                                                                                                             smape  \n",
       "surrogate input_flags_use_simple_dataset_features input_flags_use_seasonal_naive_performance input_flags_use_catch22_features nonparametric_use_ranks xgboost_objective mlp_objective mlp_discount mlp_hidden_layer_sizes mlp_weight_decay mlp_dropout              \n",
       "mlp       False                                   False                                      False                            False                   ranking           ranking       None         [32, 32]               0.01             0.0          153.221199  \n",
       "                                                                                                                                                                                                                                           0.0          141.993855  \n",
       "                                                                                                                                                                                                                                           0.0          145.619724  \n",
       "                                                                                                                                                                                                                                           0.0          176.273566  \n",
       "                                                                                                                                                                                                                                           0.0          148.668324  \n",
       "                                                                                                                                                                                      linear       [32, 32]               0.01             0.0          152.031473  \n",
       "                                                                                                                                                                                                                                           0.0          150.092414  \n",
       "                                                                                                                                                                                                                                           0.0          149.775588  \n",
       "                                                                                                                                                                                                                                           0.0          147.468341  \n",
       "                                                                                                                                                                                                                                           0.0          149.524794  \n",
       "          True                                    False                                      False                            False                   ranking           ranking       linear       [32, 32]               0.01             0.0          172.977323  \n",
       "                                                                                                                                                                                                                                           0.0          168.360391  \n",
       "                                                                                                                                                                                                                                           0.0          170.430236  \n",
       "                                                                                                                                                                                                                                           0.0          166.407624  \n",
       "                                                                                                                                                                                                                                           0.0          173.667545  "
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results = df.query(\"surrogate == 'mlp' and mlp_objective == 'ranking'\")\n",
    "results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "faabe59e",
   "metadata": {},
   "outputs": [],
   "source": [
    "best = results.query(\n",
    "    \"mlp_discount == 'None' and \"\n",
    "    \"mlp_hidden_layer_sizes == '[32, 32]' and mlp_weight_decay == '0.01' and \"\n",
    "    \"input_flags_use_simple_dataset_features == 'False' and input_flags_use_catch22_features == 'False'\"\n",
    ")\n",
    "best_results[\"MLP Ranking\"] = best.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "410b12d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "best = results.query(\n",
    "    \"mlp_discount == 'linear' and \"\n",
    "    \"mlp_hidden_layer_sizes == '[32, 32]' and mlp_weight_decay == '0.01' and \"\n",
    "    \"input_flags_use_simple_dataset_features == 'False' and input_flags_use_catch22_features == 'False'\"\n",
    ")\n",
    "best_results[\"MLP Ranking + Discounting\"] = best.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "908fea79",
   "metadata": {},
   "source": [
    "## LaTeX Table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "456edf19",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame(best_results.values(), index=best_results.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "46af0e05",
   "metadata": {},
   "outputs": [],
   "source": [
    "def formatter(values: pd.Series, value: float) -> str:\n",
    "    if values.max() == value:\n",
    "        return f\"\\\\textbf{{{value*100:.2f}}}\"\n",
    "    return f\"{value*100:.2f}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "5c81eb3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "result = df.to_latex(\n",
    "    columns=[\"mrr\", \"ndcg\", \"precision_5\", \"precision_10\", \"precision_20\"],\n",
    "    header=[\"MRR\", \"NDCG\", \"Precision@5\", \"Precision@10\", \"Precision@20\"],\n",
    "    formatters={\n",
    "        \"mrr\": partial(formatter, df.mrr),\n",
    "        \"ndcg\": partial(formatter, df.ndcg),\n",
    "        \"precision_5\": partial(formatter, df.precision_5),\n",
    "        \"precision_10\": partial(formatter, df.precision_10),\n",
    "        \"precision_20\": partial(formatter, df.precision_20),\n",
    "    },\n",
    "    bold_rows=True,\n",
    "    escape=False,\n",
    "    column_format=\"lccccc\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "0f8446d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "result = result.split(\"\\n\")\n",
    "result.insert(5, \"\\\\midrule\")\n",
    "result.insert(8, \"\\\\midrule\")\n",
    "result.insert(11, \"\\\\midrule\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "856c0c7f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lccccc}\n",
      "\\toprule\n",
      "{} &            MRR &           NDCG &    Precision@5 &   Precision@10 &   Precision@20 \\\\\n",
      "\\midrule\n",
      "\\textbf{Random                   } &           2.17 &          25.36 &           2.45 &           4.68 &           8.27 \\\\\n",
      "\\midrule\n",
      "\\textbf{Nonparametric Regression } &           6.88 &          35.80 &          13.18 &          20.91 &          32.73 \\\\\n",
      "\\textbf{Nonparametric Ranking    } &           6.78 &          35.87 &          10.00 &          20.91 &          33.07 \\\\\n",
      "\\midrule\n",
      "\\textbf{XGBoost Regression       } &           5.68 &          35.34 &          13.18 &          22.05 &          32.61 \\\\\n",
      "\\textbf{XGBoost Ranking          } &           6.64 &          35.63 &          12.27 &          21.59 &          32.84 \\\\\n",
      "\\midrule\n",
      "\\textbf{MLP Regression           } &           3.86 &          32.06 &           7.73 &          13.73 &          25.23 \\\\\n",
      "\\textbf{MLP Ranking              } &           5.00 &          34.59 &          10.27 &          21.00 &          33.34 \\\\\n",
      "\\textbf{MLP Ranking + Discounting} & \\textbf{16.41} & \\textbf{43.74} & \\textbf{23.82} & \\textbf{28.32} & \\textbf{34.39} \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(\"\\n\".join(result))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33f8caca",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
