{
 "cells": [
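  {
   "cell_type": "markdown",
   "id": "03e626cc",
   "metadata": {},
   "source": [
    "# Table: Model Ranks\n",
    "\n",
    "Builds a LaTeX table comparing the median latency and the nCRPS rank (mean and standard deviation) of individual models, hyper-ensembles of the deep-learning models, and ensembles chosen with an MLP surrogate (with and without a latency budget)."
   ]
  },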
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "774e8a91",
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "import math\n",
    "import pandas as pd\n",
    "from tsbench.config import MODEL_REGISTRY\n",
    "from tsbench.evaluation import ConfigEvaluator\n",
    "from tsbench.evaluation.tracking import SacredMongoClient\n",
    "from tsbench.experiments.tracking import Tracker\n",
    "from tsbench.evaluation.utils import compute_ranks\n",
    "from tsbench.utils import float_formatter\n",
    "\n",
    "%reload_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b32d94c",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2ef557b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tracker of the \"ts-bench\" experiment; it backs the `ConfigEvaluator` created below.\n",
    "tracker = Tracker.for_experiment(\"ts-bench\")\n",
    "\n",
    "# Clients for the Sacred experiments holding the ensemble and recommender evaluations.\n",
    "# NOTE(review): `recommender_client` is not used anywhere in this notebook.\n",
    "ensemble_client, recommender_client = (\n",
    "    SacredMongoClient(experiment)\n",
    "    for experiment in (\"eval-ensembles-iclr-05-10-21\", \"eval-recommenders-iclr-05-10-21\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a3a4f306",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Used below via `evaluator.run(...)` to obtain the results of each model configuration;\n",
    "# it is backed by `tracker`.\n",
    "evaluator = ConfigEvaluator(tracker)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3f8bd6b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model key -> name shown in the table. The insertion order is the row order of the\n",
    "# table, which is laid out as: classical models first, then the deep-learning models.\n",
    "display_names = dict(\n",
    "    arima=\"ARIMA\",\n",
    "    ets=\"ETS\",\n",
    "    npts=\"NPTS\",\n",
    "    prophet=\"Prophet\",\n",
    "    seasonal_naive=\"Seasonal Naïve\",\n",
    "    stlar=\"STL-AR\",\n",
    "    theta=\"Theta\",\n",
    "    deepar=\"DeepAR\",\n",
    "    mqcnn=\"MQ-CNN\",\n",
    "    mqrnn=\"MQ-RNN\",\n",
    "    nbeats=\"N-BEATS\",\n",
    "    simple_feedforward=\"Simple Feedforward\",\n",
    "    tft=\"TFT\",\n",
    ")\n",
    "\n",
    "# Deep-learning models; a hyper-ensemble is additionally reported for each of them.\n",
    "deep = set(\"deepar mqcnn mqrnn nbeats simple_feedforward tft\".split())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "f065cde7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column header of the final table -> column of `results` it is rendered from.\n",
    "metrics = dict([\n",
    "    (\"Latency Median (in ms)\", \"latency_median_values\"),\n",
    "    (\"nCRPS Rank Avg.\", \"mean_weighted_quantile_loss_mean\"),\n",
    "    (\"nCRPS Rank Std.\", \"ncrps_mean_std\"),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "f1ab9648",
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_ms(ms: float) -> str:\n",
    "    \"\"\"Format a duration given in seconds as a string of milliseconds.\n",
    "\n",
    "    Whole milliseconds are rendered without decimals (0.005 -> \"5\"), all other\n",
    "    values with a single decimal (0.0015 -> \"1.5\").\n",
    "\n",
    "    Args:\n",
    "        ms: The duration in *seconds* (despite the parameter name).\n",
    "\n",
    "    Returns:\n",
    "        The duration in milliseconds, formatted for use in a row label.\n",
    "    \"\"\"\n",
    "    text = f\"{ms * 1000:.1f}\"\n",
    "    # Decide \"whole number\" on the rendered value so that float noise such as\n",
    "    # 7.00000000001 still yields \"7\" and values like 2.96 never yield \"3.0\".\n",
    "    # (`math.isclose(x % 1, 0)` without `abs_tol` only matches an exact 0.0.)\n",
    "    return text[:-2] if text.endswith(\".0\") else text"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "896ab558",
   "metadata": {},
   "source": [
    "## Get Performances"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "d0cffdf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Maximum latencies (in seconds) of the latency-constrained ensembles.\n",
    "MAX_LATENCIES = [0.001, 0.005, 0.01, 0.05, 0.1]\n",
    "\n",
    "\n",
    "def load_ensemble_results(query: dict):\n",
    "    \"\"\"Read `results.parquet` of the single ensemble run matching `query`.\n",
    "\n",
    "    Every ensemble in this table uses uniform weighting and size 10, so those filters\n",
    "    are added to `query` (entries of `query` take precedence).\n",
    "\n",
    "    Args:\n",
    "        query: Additional filters passed to `ensemble_client.query_one`.\n",
    "\n",
    "    Returns:\n",
    "        Whatever `read_parquet(\"results.parquet\")` returns for the matching run.\n",
    "    \"\"\"\n",
    "    return ensemble_client.query_one({\n",
    "        \"weighting\": \"uniform\",\n",
    "        \"size\": 10,\n",
    "        **query,\n",
    "    }).read_parquet(\"results.parquet\")\n",
    "\n",
    "\n",
    "performances = {\n",
    "    # Every individual model, instantiated from its registry entry without arguments.\n",
    "    **{display: evaluator.run(MODEL_REGISTRY[name]()) for name, display in display_names.items()},\n",
    "    # Hyper-ensemble of each deep-learning model (runs without a surrogate).\n",
    "    **{\n",
    "        f\"{display} Hyper-Ensemble\": load_ensemble_results({\n",
    "            \"model_class\": name,\n",
    "            \"surrogate.name\": None,\n",
    "        })\n",
    "        for name, display in display_names.items()\n",
    "        if name in deep\n",
    "    },\n",
    "    # Ensembles chosen with the MLP surrogate, once per latency budget ...\n",
    "    **{\n",
    "        f\"Constrained Ensemble ({format_ms(latency)} ms)\": load_ensemble_results({\n",
    "            \"max_latency\": latency,\n",
    "            \"model_class\": None,\n",
    "            \"surrogate.name\": \"mlp\",\n",
    "            \"surrogate.input_flags.use_simple_dataset_features\": False,\n",
    "        })\n",
    "        for latency in MAX_LATENCIES\n",
    "    },\n",
    "    # ... and without a latency budget.\n",
    "    \"Unconstrained Ensemble\": load_ensemble_results({\n",
    "        \"max_latency\": None,\n",
    "        \"model_class\": None,\n",
    "        \"surrogate.name\": \"mlp\",\n",
    "        \"surrogate.input_flags.use_simple_dataset_features\": False,\n",
    "    }),\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6dc459b2",
   "metadata": {},
   "source": [
    "## Compute Ranks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "00fc94f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rank all candidates of `performances` against each other. The result is grouped by a\n",
    "# `candidate` column and has a `mean_weighted_quantile_loss_mean` rank column; presumably\n",
    "# one row per (candidate, dataset) -- TODO confirm against `compute_ranks`.\n",
    "ranks = compute_ranks(performances)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "31f456ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Aggregate the ranks per candidate, keeping the row order of `performances`.\n",
    "grouped_ranks = ranks.groupby(\"candidate\")\n",
    "candidates = list(performances)\n",
    "results = grouped_ranks.mean().reindex(candidates).assign(\n",
    "    # Standard deviation of the nCRPS rank over the rows of each candidate.\n",
    "    ncrps_mean_std=grouped_ranks.std().reindex(candidates)[\"mean_weighted_quantile_loss_mean\"],\n",
    "    # Median of `latency_mean` per candidate, scaled by 1000 (seconds -> ms, assuming\n",
    "    # `latency_mean` is stored in seconds). A Series aligns on the index directly,\n",
    "    # unlike the one-column DataFrame that pandas would otherwise have to unpack.\n",
    "    latency_median_values=pd.Series(\n",
    "        [performance.latency_mean.median() * 1000 for performance in performances.values()],\n",
    "        index=candidates,\n",
    "    ),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "81ab9a20",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Deep-learning models are shown as \"<model> / <hyper-ensemble>\" in every column, so\n",
    "# hyper-ensembles do not get a row of their own.\n",
    "index = []\n",
    "rows = []\n",
    "for name, performance in results.iterrows():\n",
    "    if \"Hyper-Ensemble\" in name:\n",
    "        continue\n",
    "    index.append(name)\n",
    "    # Pair each model with its hyper-ensemble by name rather than by position, so the\n",
    "    # table does not depend on the deep models being the last entries of `display_names`.\n",
    "    ensemble_name = f\"{name} Hyper-Ensemble\"\n",
    "    if ensemble_name in results.index:\n",
    "        ensemble = results.loc[ensemble_name]\n",
    "        rows.append({\n",
    "            m: f\"{float_formatter(results[n], performance[n])} / {float_formatter(results[n], ensemble[n])}\"\n",
    "            for m, n in metrics.items()\n",
    "        })\n",
    "    else:\n",
    "        rows.append({m: float_formatter(results[n], performance[n]) for m, n in metrics.items()})\n",
    "\n",
    "final_df = pd.DataFrame(rows, index=index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "7c76b659",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `escape=False` keeps the `\\textbf{...}` markup in the formatted cells; the column\n",
    "# format (one left-aligned label column, centred metric columns) follows `final_df`.\n",
    "result = final_df.to_latex(\n",
    "    index_names=False,\n",
    "    bold_rows=True,\n",
    "    escape=False,\n",
    "    column_format=\"l\" + \"c\" * len(final_df.columns),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "d2ef6dd3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Separate the classical models, the deep-learning models and the ensembles with rules.\n",
    "# NOTE: this rebinds `result` from str to a list of lines, so re-run the `to_latex` cell\n",
    "# before re-running this one.\n",
    "result = result.split(\"\\n\")\n",
    "# Locate the table body via the rule below the header instead of hard-coding line numbers.\n",
    "first_body_line = result.index(\"\\\\midrule\") + 1\n",
    "num_classical = len(display_names) - len(deep)\n",
    "result.insert(first_body_line + num_classical, \"\\\\midrule\")\n",
    "# The extra +1 accounts for the rule inserted just above.\n",
    "result.insert(first_body_line + num_classical + 1 + len(deep), \"\\\\midrule\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "9c040c13",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lccc}\n",
      "\\toprule\n",
      "{} & Latency Median (in ms) & nCRPS Rank Avg. & nCRPS Rank Std. \\\\\n",
      "\\midrule\n",
      "\\textbf{ARIMA                        } &                 513.87 &           13.86 &            5.83 \\\\\n",
      "\\textbf{ETS                          } &                  80.34 &           15.45 &            6.87 \\\\\n",
      "\\textbf{NPTS                         } &                  69.98 &           17.82 &            7.35 \\\\\n",
      "\\textbf{Prophet                      } &                 602.25 &           17.82 &            5.52 \\\\\n",
      "\\textbf{Seasonal Naïve               } &          \\textbf{0.39} &           19.86 &            4.22 \\\\\n",
      "\\textbf{STL-AR                       } &                  31.99 &           15.57 &            6.73 \\\\\n",
      "\\textbf{Theta                        } &                   8.58 &           15.59 &            6.19 \\\\\n",
      "\\midrule\n",
      "\\textbf{DeepAR                       } &         11.51 / 161.51 &     9.91 / 6.73 &     5.79 / 5.24 \\\\\n",
      "\\textbf{MQ-CNN                       } &            0.78 / 7.88 &   15.86 / 13.07 &     5.75 / 5.98 \\\\\n",
      "\\textbf{MQ-RNN                       } &            0.91 / 3.06 &   22.07 / 21.02 &     4.30 / 4.92 \\\\\n",
      "\\textbf{N-BEATS                      } &           1.30 / 13.77 &   18.32 / 16.59 &     3.45 / 3.71 \\\\\n",
      "\\textbf{Simple Feedforward           } &            0.68 / 6.32 &    12.32 / 9.45 &     4.24 / 4.16 \\\\\n",
      "\\textbf{TFT                          } &           2.24 / 21.87 &     8.70 / 6.75 &     4.71 / 4.61 \\\\\n",
      "\\midrule\n",
      "\\textbf{Constrained Ensemble (1 ms)  } &                   0.93 &           13.77 &            5.01 \\\\\n",
      "\\textbf{Constrained Ensemble (5 ms)  } &                   4.90 &            9.45 &            5.20 \\\\\n",
      "\\textbf{Constrained Ensemble (10 ms) } &                   9.91 &            8.30 &            4.97 \\\\\n",
      "\\textbf{Constrained Ensemble (50 ms) } &                  49.66 &            6.57 &            5.04 \\\\\n",
      "\\textbf{Constrained Ensemble (100 ms)} &                  90.25 &            5.55 &            4.28 \\\\\n",
      "\\textbf{Unconstrained Ensemble       } &                 102.27 &   \\textbf{4.36} &   \\textbf{3.37} \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Print (rather than display) so that the raw LaTeX can be copied verbatim.\n",
    "print(*result, sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d655b425",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
