{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load pandas and numpy\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import re\n",
    "import os\n",
    "import sys\n",
    "from typing import Sequence, Dict\n",
    "\n",
    "sys.path.append(os.path.abspath('./paper_results'))\n",
    "from parser import avg_results"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Table 1: accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset keys in the column order used by the paper ('iglu' is Broll there, the rest are alphabetical).\n",
    "datasets = ['iglu', 'colas', 'dubosson', 'hall', 'weinstock']\n",
    "models = ['arima', 'linreg', 'xgboost', 'gluformer', 'latentode',  'nhits', 'tft', 'transformer']\n",
    "\n",
    "for model in models:\n",
    "    # Every model has one result file per dataset for the run without covariates.\n",
    "    model_names = [f'../output/{model}_{dataset}.txt' for dataset in datasets]\n",
    "    # These three models have no covariate runs, so they get no second file list.\n",
    "    if model in ['arima', 'gluformer', 'latentode']:\n",
    "        model_names_with_covs = None\n",
    "    else:\n",
    "        model_names_with_covs = [f'../output/{model}_covariates_{dataset}.txt' for dataset in datasets]\n",
    "    # Every pass overwrites dict_errors and dict_errors_std; nothing is printed or stored per model.\n",
    "    (dict_errors, _), (dict_errors_std, _) = avg_results(model_names, model_names_with_covs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "arima\n",
      "\\multicolumn{2}{c}{\\textcolor{red}{+11.59\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+12.01\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+2.02\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+1.49\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+38.56\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+31.81\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-4.79\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-5.08\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+18.47\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+18.56\\%}}\n",
      "linreg\n",
      "\\multicolumn{2}{c}{\\textcolor{red}{+2.51\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+1.29\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+1.27\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+1.38\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+30.01\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+19.41\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+6.46\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+4.58\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+14.5\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+14.22\\%}}\n",
      "xgboost\n",
      " \\multicolumn{2}{c}{\\textcolor{blue}{-23.72\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-24.22\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-3.76\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-3.17\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-17.01\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-19.22\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-1.14\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-0.55\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+12.69\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+12.32\\%}}\n",
      "gluformer\n",
      "\\multicolumn{2}{c}{\\textcolor{red}{+17.62\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+18.07\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-15.07\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-15.26\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+8.0\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+6.73\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+5.48\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+5.29\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+13.26\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+12.9\\%}}\n",
      "latentode\n",
      "\\multicolumn{2}{c}{\\textcolor{red}{+4.12\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+5.93\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-10.05\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-9.83\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-13.7\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-15.47\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+8.07\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+8.18\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+11.17\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+11.08\\%}}\n",
      "nhits\n",
      "\\multicolumn{2}{c}{\\textcolor{red}{+6.14\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+5.8\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-4.33\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-4.2\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+4.27\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+5.45\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+0.8\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+0.75\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+9.29\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+9.22\\%}}\n",
      "tft\n",
      " \\multicolumn{2}{c}{\\textcolor{blue}{-9.51\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-7.63\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-1.9\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-1.58\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-4.73\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-6.24\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+2.59\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+2.22\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+6.48\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+6.28\\%}}\n",
      "transformer\n",
      " \\multicolumn{2}{c}{\\textcolor{blue}{-7.16\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-6.96\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-7.78\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-7.23\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-5.52\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-7.52\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+3.69\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+4.15\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+7.0\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+6.16\\%}}\n"
     ]
    }
   ],
   "source": [
    "datasets = ['iglu', 'colas', 'dubosson', 'hall', 'weinstock'] # iglu is Broll in the paper, otherwise alphabetical order\n",
    "models = ['arima', 'linreg', 'xgboost', 'gluformer', 'latentode',  'nhits', 'tft', 'transformer']\n",
    "\n",
    "def color(x):\n",
    "    return r'\\multicolumn{2}{c}{\\textcolor{red}{+' + str(round(x, 2)) + '\\%}}' if x > 0 else r' \\multicolumn{2}{c}{\\textcolor{blue}{' + str(round(x, 2)) + '\\%}}'\n",
    "\n",
    "for model in models:\n",
    "    print(model)\n",
    "    model_names = [f'../output/{model}_{dataset}.txt' for dataset in datasets]\n",
    "    # These three models have no covariate runs, so avg_results gets None for them.\n",
    "    has_covariates = model not in ['arima', 'gluformer', 'latentode']\n",
    "    model_names_with_covs = (\n",
    "        [f'../output/{model}_covariates_{dataset}.txt' for dataset in datasets]\n",
    "        if has_covariates else None\n",
    "    )\n",
    "    (dict_errors, _), _ = avg_results(model_names, model_names_with_covs)\n",
    "\n",
    "    # Percent change of the no-covariate error going from the 'id' split to the 'ood' split.\n",
    "    errors_id = dict_errors['id']['no_covs']\n",
    "    errors_ood = dict_errors['ood']['no_covs']\n",
    "    diff_errors_no_covs = (errors_ood - errors_id) / errors_id * 100\n",
    "    cells = [color(x) for x in diff_errors_no_covs.reshape(-1).tolist()]\n",
    "    print(' & '.join(cells))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "linreg\n",
      "\\textcolor{blue}{-14.82\\%} & \\textcolor{blue}{-13.34\\%} & \\textcolor{red}{+5.54\\%} & \\textcolor{red}{+5.75\\%} & \\textcolor{red}{+2.84\\%} & \\textcolor{red}{+0.61\\%} & \\textcolor{red}{+6.17\\%} & \\textcolor{red}{+5.09\\%} & \\textcolor{blue}{-1.54\\%} & \\textcolor{blue}{-1.08\\%}\n",
      "xgboost\n",
      "\\textcolor{red}{+8.53\\%} & \\textcolor{red}{+3.22\\%} & \\textcolor{blue}{-0.72\\%} & \\textcolor{blue}{-0.5\\%} & \\textcolor{blue}{-1.36\\%} & \\textcolor{blue}{-2.8\\%} & \\textcolor{red}{+6.22\\%} & \\textcolor{red}{+7.11\\%} & \\textcolor{red}{+0.99\\%} & \\textcolor{red}{+1.36\\%}\n",
      "nhits\n",
      "\\textcolor{red}{+17.43\\%} & \\textcolor{red}{+21.29\\%} & \\textcolor{red}{+53.14\\%} & \\textcolor{red}{+59.48\\%} & \\textcolor{red}{+74.38\\%} & \\textcolor{red}{+89.1\\%} & \\textcolor{red}{+6.21\\%} & \\textcolor{red}{+8.05\\%} & \\textcolor{red}{+0.88\\%} & \\textcolor{red}{+0.93\\%}\n",
      "tft\n",
      "\\textcolor{red}{+6.86\\%} & \\textcolor{red}{+12.21\\%} & \\textcolor{red}{+15.86\\%} & \\textcolor{red}{+15.93\\%} & \\textcolor{red}{+0.35\\%} & \\textcolor{red}{+0.09\\%} & \\textcolor{red}{+6.32\\%} & \\textcolor{red}{+6.72\\%} & \\textcolor{red}{+4.52\\%} & \\textcolor{red}{+4.59\\%}\n",
      "transformer\n",
      "\\textcolor{blue}{-15.14\\%} & \\textcolor{blue}{-14.64\\%} & \\textcolor{red}{+30.31\\%} & \\textcolor{red}{+37.56\\%} & \\textcolor{red}{+64.99\\%} & \\textcolor{red}{+73.82\\%} & \\textcolor{blue}{-5.06\\%} & \\textcolor{blue}{-5.4\\%} & \\textcolor{red}{+9.33\\%} & \\textcolor{red}{+12.41\\%}\n"
     ]
    }
   ],
   "source": [
    "datasets = ['iglu', 'colas', 'dubosson', 'hall', 'weinstock'] # iglu is Broll in the paper, otherwise alphabetical order\n",
    "models = ['arima', 'linreg', 'xgboost', 'gluformer', 'latentode',  'nhits', 'tft', 'transformer']\n",
    "\n",
    "def color(x):\n",
    "    return r'\\textcolor{red}{+' + str(round(x, 2)) + '\\%}' if x > 0 else r'\\textcolor{blue}{' + str(round(x, 2)) + '\\%}'\n",
    "\n",
    "for model in models:\n",
    "    # Models without covariate runs have nothing to compare in this cell.\n",
    "    if model in ['arima', 'gluformer', 'latentode']:\n",
    "        continue\n",
    "\n",
    "    model_names = [f'../output/{model}_{dataset}.txt' for dataset in datasets]\n",
    "    model_names_with_covs = [f'../output/{model}_covariates_{dataset}.txt' for dataset in datasets]\n",
    "    (dict_errors, _), _ = avg_results(model_names, model_names_with_covs)\n",
    "\n",
    "    print(model)\n",
    "    # Percent change of the 'id' error when covariates are added, relative to the run without them.\n",
    "    errors_id = dict_errors['id']\n",
    "    diff_errors = ((errors_id['covs'] - errors_id['no_covs']) / errors_id['no_covs'] * 100).tolist()\n",
    "    # diff_errors is a list of rows; the cells are emitted row by row.\n",
    "    print(' & '.join(color(y) for x in diff_errors for y in x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\multirow{2}{*}{\\rotatebox{90}{ARI}} & \\crossmark & \n",
      "ID & 10.53&8.67&5.80&4.80&13.53&11.06&8.63&7.34&13.40&11.25\n",
      "\\\\\n",
      "& \\crossmark & \n",
      "OOD & 11.75&9.71&5.91&4.87&18.75&14.58&8.22&6.97&15.87&13.34\n",
      "\\\\\n",
      "\\midrule\n",
      "\\rowcolor{lightgray}\n",
      "\\multicolumn{3}{c|}{$\\min \\Delta$(ID, OD)\\%}&\n",
      "\\multicolumn{2}{c}{\\textcolor{red}{+11.8\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+1.75\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+35.18\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-4.94\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+18.51\\%}}\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{6}{*}{\\rotatebox{90}{LIN}} & \\crossmark & \n",
      "ID & 11.68&9.71&5.26&4.35&12.07&9.97&7.38&6.33&13.60&11.46\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "ID & 9.95&8.41&5.56&4.60&12.41&10.03&7.84&6.66&13.39&11.34\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\multicolumn{2}{c|}{Improv.} &\n",
      " \\multicolumn{2}{c}{\\textcolor{blue}{-14.08\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+5.65\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+1.73\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+5.63\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-1.31\\%}}\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\crossmark &\n",
      "OD & 11.98&9.83&5.33&4.41&15.69&11.90&7.86&6.62&15.58&13.09\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "OD & 23.30&16.80&5.54&4.57&203114.47&67548.59&14.22&10.02&15.66&13.16\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\multicolumn{2}{c|}{Improv.} &\n",
      "\\multicolumn{2}{c}{\\textcolor{red}{+82.7\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+3.8\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+930929.99\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+66.07\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+0.54\\%}}\n",
      "\\\\\n",
      "\\midrule\n",
      "\\rowcolor{lightgray}\n",
      "\\multicolumn{3}{c|}{$\\min \\Delta$(ID, OD)\\%}&\n",
      "\\multicolumn{2}{c}{\\textcolor{red}{+1.9\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-0.45\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+24.71\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+5.52\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+14.36\\%}}\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{6}{*}{\\rotatebox{90}{XGB}} & \\crossmark & \n",
      "ID & 12.80&11.50&6.42&5.49&21.18&19.09&7.58&6.55&13.63&11.61\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "ID & 13.89&11.87&6.37&5.46&20.89&18.55&8.05&7.02&13.77&11.77\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\multicolumn{2}{c|}{Improv.} &\n",
      "\\multicolumn{2}{c}{\\textcolor{red}{+5.88\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-0.61\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-2.08\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+6.67\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+1.18\\%}}\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\crossmark &\n",
      "OD & 9.76&8.72&6.18&5.32&17.57&15.42&7.49&6.52&15.36&13.04\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "OD & 9.67&8.56&6.36&5.47&17.44&15.46&8.20&7.11&15.55&13.43\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\multicolumn{2}{c|}{Improv.} &\n",
      " \\multicolumn{2}{c}{\\textcolor{blue}{-1.37\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+2.91\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-0.26\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+9.25\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+2.09\\%}}\n",
      "\\\\\n",
      "\\midrule\n",
      "\\rowcolor{lightgray}\n",
      "\\multicolumn{3}{c|}{$\\min \\Delta$(ID, OD)\\%}&\n",
      " \\multicolumn{2}{c}{\\textcolor{blue}{-29.14\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-3.47\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-18.11\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-0.84\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+12.51\\%}}\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{2}{*}{\\rotatebox{90}{GLU}} & \\crossmark & \n",
      "ID & 14.19&12.55&8.17&7.12&21.74&19.40&7.74&6.69&14.07&12.09\n",
      "\\\\\n",
      "& \\crossmark & \n",
      "OOD & 16.70&14.82&6.94&6.03&23.48&20.70&8.17&7.04&15.94&13.65\n",
      "\\\\\n",
      "\\midrule\n",
      "\\rowcolor{lightgray}\n",
      "\\multicolumn{3}{c|}{$\\min \\Delta$(ID, OD)\\%}&\n",
      "\\multicolumn{2}{c}{\\textcolor{red}{+17.85\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-15.17\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+7.37\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+5.39\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+13.08\\%}}\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{2}{*}{\\rotatebox{90}{LAT}} & \\crossmark & \n",
      "ID & 14.37&12.32&6.28&5.37&20.14&17.88&7.13&6.11&13.54&11.45\n",
      "\\\\\n",
      "& \\crossmark & \n",
      "OOD & 14.96&13.05&5.64&4.84&17.38&15.12&7.71&6.61&15.06&12.72\n",
      "\\\\\n",
      "\\midrule\n",
      "\\rowcolor{lightgray}\n",
      "\\multicolumn{3}{c|}{$\\min \\Delta$(ID, OD)\\%}&\n",
      "\\multicolumn{2}{c}{\\textcolor{red}{+5.03\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-9.94\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-14.59\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+8.12\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+11.12\\%}}\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{6}{*}{\\rotatebox{90}{NHI}} & \\crossmark & \n",
      "ID & 13.79&12.07&5.93&5.04&17.45&14.79&7.68&6.57&13.29&11.21\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "ID & 16.20&14.64&9.09&8.03&30.43&27.97&8.16&7.10&13.41&11.31\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\multicolumn{2}{c|}{Improv.} &\n",
      "\\multicolumn{2}{c}{\\textcolor{red}{+19.36\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+56.31\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+81.74\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+7.13\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+0.9\\%}}\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\crossmark &\n",
      "OD & 14.64&12.77&5.68&4.83&18.20&15.59&7.74&6.62&14.52&12.24\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "OD & 15.66&14.01&7.56&6.65&37.35&33.52&8.59&7.53&14.40&12.12\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\multicolumn{2}{c|}{Improv.} &\n",
      "\\multicolumn{2}{c}{\\textcolor{red}{+8.35\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+35.49\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+110.09\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+12.34\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-0.91\\%}}\n",
      "\\\\\n",
      "\\midrule\n",
      "\\rowcolor{lightgray}\n",
      "\\multicolumn{3}{c|}{$\\min \\Delta$(ID, OD)\\%}&\n",
      " \\multicolumn{2}{c}{\\textcolor{blue}{-3.8\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-17.01\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+4.86\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+0.78\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+7.29\\%}}\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{6}{*}{\\rotatebox{90}{TFT}} & \\crossmark & \n",
      "ID & 13.73&11.07&5.62&4.54&18.37&15.49&7.92&6.61&14.32&11.76\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "ID & 14.68&12.43&6.51&5.27&18.43&15.51&8.42&7.06&14.97&12.30\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\multicolumn{2}{c|}{Improv.} &\n",
      "\\multicolumn{2}{c}{\\textcolor{red}{+9.53\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+15.9\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+0.22\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+6.52\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+4.55\\%}}\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\crossmark &\n",
      "OD & 12.43&10.23&5.51&4.47&17.50&14.53&8.12&6.76&15.25&12.50\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "OD & 13.25&11.17&5.79&4.68&17.19&14.43&8.93&7.44&15.47&12.67\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\multicolumn{2}{c|}{Improv.} &\n",
      "\\multicolumn{2}{c}{\\textcolor{red}{+7.91\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+4.84\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-1.22\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+10.01\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+1.41\\%}}\n",
      "\\\\\n",
      "\\midrule\n",
      "\\rowcolor{lightgray}\n",
      "\\multicolumn{3}{c|}{$\\min \\Delta$(ID, OD)\\%}&\n",
      " \\multicolumn{2}{c}{\\textcolor{blue}{-9.91\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-11.11\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-6.84\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+2.41\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+3.18\\%}}\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{6}{*}{\\rotatebox{90}{TRA}} & \\crossmark & \n",
      "ID & 15.12&13.20&6.47&5.65&16.62&14.04&7.89&6.78&13.22&11.22\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "ID & 12.83&11.27&8.44&7.77&27.43&24.40&7.49&6.42&14.46&12.61\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\multicolumn{2}{c|}{Improv.} &\n",
      " \\multicolumn{2}{c}{\\textcolor{blue}{-14.89\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+33.93\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+69.41\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-5.23\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+10.87\\%}}\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\crossmark &\n",
      "OD & 14.04&12.28&5.97&5.24&15.71&12.98&8.18&7.07&14.15&11.91\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "OD & 13.76&12.13&7.26&6.59&34.11&28.21&7.40&6.29&15.59&13.58\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\multicolumn{2}{c|}{Improv.} &\n",
      " \\multicolumn{2}{c}{\\textcolor{blue}{-1.59\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+23.61\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+117.27\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-10.23\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+12.14\\%}}\n",
      "\\\\\n",
      "\\midrule\n",
      "\\rowcolor{lightgray}\n",
      "\\multicolumn{3}{c|}{$\\min \\Delta$(ID, OD)\\%}&\n",
      " \\multicolumn{2}{c}{\\textcolor{blue}{-7.06\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-14.62\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-6.52\\%}} &  \\multicolumn{2}{c}{\\textcolor{blue}{-1.57\\%}} & \\multicolumn{2}{c}{\\textcolor{red}{+6.58\\%}}\n",
      "\\\\\n",
      "\\midrule\n"
     ]
    }
   ],
   "source": [
    "datasets = ['iglu', 'colas', 'dubosson', 'hall', 'weinstock'] # iglu is Broll in the paper, otherwise alphabetical order\n",
    "models = ['arima', 'linreg', 'xgboost', 'gluformer', 'latentode',  'nhits', 'tft', 'transformer']\n",
    "\n",
    "def color(x):\n",
    "    return r'\\multicolumn{2}{c}{\\textcolor{red}{+' + str(round(x, 2)) + '\\%}}' if x > 0 else r' \\multicolumn{2}{c}{\\textcolor{blue}{' + str(round(x, 2)) + '\\%}}'\n",
    "\n",
    "# Rule line printed between the sub-blocks of a model that has covariate runs.\n",
    "CMIDRULES = r'\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}'\n",
    "\n",
    "def _mean_change(new, base):\n",
    "    \"\"\"Percentage change from base to new, averaged over axis 1.\"\"\"\n",
    "    return np.mean((new - base) / base * 100, axis=1)\n",
    "\n",
    "def _print_values(label, errors):\n",
    "    \"\"\"Print one row of raw errors (two decimals each), then the LaTeX row terminator.\"\"\"\n",
    "    print(label + ' & ' + '&'.join([f'{x:.2f}' for x in errors.reshape(-1).tolist()]))\n",
    "    print(r'\\\\')\n",
    "\n",
    "def _print_changes(changes):\n",
    "    \"\"\"Print one row of colored percentage cells, then the LaTeX row terminator.\"\"\"\n",
    "    print(' & '.join([color(x) for x in changes.reshape(-1).tolist()]))\n",
    "    print(r'\\\\')\n",
    "\n",
    "# All LaTeX fragments below are raw strings; the non-raw versions relied on invalid\n",
    "# escape sequences (backslash-m, backslash-c) that recent Python versions warn about.\n",
    "for model in models:\n",
    "    # These three models have no covariate runs, so avg_results gets None for them.\n",
    "    has_covariates = model not in ['arima', 'gluformer', 'latentode']\n",
    "    model_names = [f'../output/{model}_{dataset}.txt' for dataset in datasets]\n",
    "    model_names_with_covs = (\n",
    "        [f'../output/{model}_covariates_{dataset}.txt' for dataset in datasets]\n",
    "        if has_covariates else None\n",
    "    )\n",
    "    (dict_errors, _), _ = avg_results(model_names, model_names_with_covs)\n",
    "    errors_id, errors_ood = dict_errors['id'], dict_errors['ood']\n",
    "\n",
    "    # The rotated model tag spans 6 table rows with covariates (ID, ID, Improv., OD, OD, Improv.)\n",
    "    # and 2 rows without (ID, OD).\n",
    "    n_rows = 6 if has_covariates else 2\n",
    "    print(r'\\multirow{' + str(n_rows) + r'}{*}{\\rotatebox{90}{' + model[:3].upper() + r'}} & \\crossmark & ')\n",
    "    _print_values('ID', errors_id['no_covs'])\n",
    "\n",
    "    if not has_covariates:\n",
    "        print(r'& \\crossmark & ')\n",
    "        # Row label is 'OD' (it used to be 'OOD' in this branch only) so it matches the\n",
    "        # covariate branch and the '(ID, OD)' summary header printed below.\n",
    "        _print_values('OD', errors_ood['no_covs'])\n",
    "        diff_errors = _mean_change(errors_ood['no_covs'], errors_id['no_covs'])\n",
    "    else:\n",
    "        print(r'& \\checkmark & ')\n",
    "        _print_values('ID', errors_id['covs'])\n",
    "        print(CMIDRULES)\n",
    "\n",
    "        # Change caused by adding covariates on the 'id' split.\n",
    "        print(r'& \\multicolumn{2}{c|}{Improv.} &')\n",
    "        _print_changes(_mean_change(errors_id['covs'], errors_id['no_covs']))\n",
    "        print(CMIDRULES)\n",
    "\n",
    "        print(r'& \\crossmark &')\n",
    "        _print_values('OD', errors_ood['no_covs'])\n",
    "        print(r'& \\checkmark & ')\n",
    "        _print_values('OD', errors_ood['covs'])\n",
    "        print(CMIDRULES)\n",
    "\n",
    "        # Change caused by adding covariates on the 'ood' split.\n",
    "        print(r'& \\multicolumn{2}{c|}{Improv.} &')\n",
    "        _print_changes(_mean_change(errors_ood['covs'], errors_ood['no_covs']))\n",
    "\n",
    "        # Summary row: element-wise smaller of the two ID-to-OD changes (without / with covariates).\n",
    "        diff_errors_no_covs = _mean_change(errors_ood['no_covs'], errors_id['no_covs'])\n",
    "        diff_errors_covs = _mean_change(errors_ood['covs'], errors_id['covs'])\n",
    "        diff_errors = np.minimum(diff_errors_no_covs, diff_errors_covs)\n",
    "\n",
    "    print(r'\\midrule')\n",
    "    print(r'\\rowcolor{lightgray}')\n",
    "    print(r'\\multicolumn{3}{c|}{$\\min \\Delta$(ID, OD)\\%}&')\n",
    "    _print_changes(diff_errors)\n",
    "    print(r'\\midrule')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Table 2: probabilistic fit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "linreg\n",
      "\\textcolor{red}{-0.65\\%} & \\textcolor{red}{+24.98\\%} & \\textcolor{blue}{+0.33\\%} & \\textcolor{red}{+2.93\\%} & \\textcolor{red}{-0.02\\%} & \\textcolor{blue}{-7.62\\%} & \\textcolor{blue}{+0.33\\%} & \\textcolor{red}{+4.43\\%} & \\textcolor{red}{-0.85\\%} & \\textcolor{blue}{-3.1\\%}\n",
      "xgboost\n",
      "\\textcolor{red}{-0.88\\%} & \\textcolor{red}{+67.42\\%} & \\textcolor{blue}{+0.59\\%} & \\textcolor{blue}{-4.55\\%} & \\textcolor{blue}{+3.16\\%} & \\textcolor{red}{+5.02\\%} & \\textcolor{blue}{+1.16\\%} & \\textcolor{blue}{-14.37\\%} & \\textcolor{red}{-0.83\\%} & \\textcolor{red}{+2.45\\%}\n",
      "gluformer\n",
      "\\textcolor{blue}{+6.72\\%} & \\textcolor{red}{+106.76\\%} & \\textcolor{red}{-50.33\\%} & \\textcolor{blue}{-29.83\\%} & \\textcolor{blue}{+45.64\\%} & \\textcolor{red}{+83.63\\%} & \\textcolor{blue}{+7.69\\%} & \\textcolor{red}{+9.24\\%} & \\textcolor{blue}{+3.33\\%} & \\textcolor{red}{+6.23\\%}\n",
      "latentode\n",
      "\\textcolor{red}{-13.67\\%} & \\textcolor{red}{+6.5\\%} & \\textcolor{blue}{+15.89\\%} & \\textcolor{blue}{-4.01\\%} & \\textcolor{blue}{+42.14\\%} & \\textcolor{red}{+4.59\\%} & \\textcolor{blue}{+10.12\\%} & \\textcolor{red}{+20.58\\%} & \\textcolor{red}{-15.03\\%} & \\textcolor{red}{+20.72\\%}\n",
      "nhits\n",
      "\\textcolor{red}{-0.72\\%} & \\textcolor{blue}{-16.3\\%} & \\textcolor{blue}{+0.72\\%} & \\textcolor{red}{+1.17\\%} & \\textcolor{blue}{+1.76\\%} & \\textcolor{red}{+15.16\\%} & \\textcolor{blue}{+1.38\\%} & \\textcolor{blue}{-14.79\\%} & \\textcolor{red}{-0.7\\%} & \\textcolor{red}{+7.16\\%}\n",
      "tft\n",
      "\\textcolor{red}{nan\\%} & \\textcolor{blue}{-4.8\\%} & \\textcolor{red}{nan\\%} & \\textcolor{red}{+15.18\\%} & \\textcolor{red}{nan\\%} & \\textcolor{red}{+10.56\\%} & \\textcolor{red}{nan\\%} & \\textcolor{red}{+30.52\\%} & \\textcolor{red}{nan\\%} & \\textcolor{red}{+9.94\\%}\n",
      "transformer\n",
      "\\textcolor{blue}{+0.14\\%} & \\textcolor{blue}{-14.52\\%} & \\textcolor{blue}{+0.78\\%} & \\textcolor{red}{+3.16\\%} & \\textcolor{blue}{+2.59\\%} & \\textcolor{red}{+17.77\\%} & \\textcolor{blue}{+1.34\\%} & \\textcolor{red}{+14.08\\%} & \\textcolor{red}{-0.54\\%} & \\textcolor{red}{+8.99\\%}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_65873/1381547546.py:27: RuntimeWarning: invalid value encountered in divide\n",
      "  diff_errors_no_covs = (dict_errors['ood']['no_covs'] - dict_errors['id']['no_covs']) / np.abs(dict_errors['id']['no_covs']) * 100\n"
     ]
    }
   ],
   "source": [
    "datasets = ['iglu', 'colas', 'dubosson', 'hall', 'weinstock'] # iglu is Broll in the paper, otherwise alphabetical order\n",
    "models = ['linreg', 'xgboost', 'gluformer', 'latentode',  'nhits', 'tft', 'transformer']\n",
    "\n",
    "def color_min(x):\n",
    "    return r'\\textcolor{red}{+' + str(round(x, 2)) + '\\%}' if x > 0 else r'\\textcolor{blue}{' + str(round(x, 2)) + '\\%}'\n",
    "def color_max(x):\n",
    "    return r'\\textcolor{blue}{+' + str(round(x, 2)) + '\\%}' if x > 0 else r'\\textcolor{red}{' + str(round(x, 2)) + '\\%}'\n",
    "\n",
    "# The two former branches differed only in the covariate file list and in the name of the\n",
    "# result variable; they are merged so the computation exists in exactly one place.\n",
    "for model in models:\n",
    "    # 'arima' is not in this cell's model list; it stays in the check so the rule matches the other cells.\n",
    "    has_covariates = model not in ['arima', 'gluformer', 'latentode']\n",
    "    model_names = [f'../output/{model}_{dataset}.txt' for dataset in datasets]\n",
    "    model_names_with_covs = (\n",
    "        [f'../output/{model}_covariates_{dataset}.txt' for dataset in datasets]\n",
    "        if has_covariates else None\n",
    "    )\n",
    "    # This table reads the second element of the first pair returned by avg_results.\n",
    "    (_, dict_errors), _ = avg_results(model_names, model_names_with_covs)\n",
    "\n",
    "    print(model)\n",
    "    metrics_id = dict_errors['id']['no_covs']\n",
    "    metrics_ood = dict_errors['ood']['no_covs']\n",
    "    # The denominator is the absolute 'id' value, so a negative baseline does not flip the sign\n",
    "    # of the change. A zero baseline makes NumPy emit a RuntimeWarning and yields nan or inf.\n",
    "    diff_errors_no_covs = (metrics_ood - metrics_id) / np.abs(metrics_id) * 100\n",
    "    # Each row holds two values: the first is colored with color_max, the second with color_min.\n",
    "    cells = [cell for x in diff_errors_no_covs.tolist() for cell in (color_max(x[0]), color_min(x[1]))]\n",
    "    print(' & '.join(cells))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "linreg\n",
      "\\textcolor{blue}{+0.15\\%} & \\textcolor{red}{+5.99\\%} & \\textcolor{blue}{+0.25\\%} & \\textcolor{red}{+24.47\\%} & \\textcolor{red}{-0.41\\%} & \\textcolor{red}{+11.39\\%} & \\textcolor{red}{-7.67\\%} & \\textcolor{red}{+97.36\\%} & \\textcolor{blue}{+0.13\\%} & \\textcolor{blue}{-1.67\\%}\n",
      "xgboost\n",
      "\\textcolor{red}{-1.22\\%} & \\textcolor{red}{+0.59\\%} & \\textcolor{blue}{+0.12\\%} & \\textcolor{blue}{-7.2\\%} & \\textcolor{blue}{+0.13\\%} & \\textcolor{blue}{-6.98\\%} & \\textcolor{red}{-0.31\\%} & \\textcolor{blue}{-1.3\\%} & \\textcolor{red}{-0.15\\%} & \\textcolor{blue}{-4.62\\%}\n",
      "nhits\n",
      "\\textcolor{red}{-3.63\\%} & \\textcolor{blue}{-37.79\\%} & \\textcolor{red}{-1.68\\%} & \\textcolor{red}{+91.12\\%} & \\textcolor{red}{-4.19\\%} & \\textcolor{blue}{-21.68\\%} & \\textcolor{red}{-0.07\\%} & \\textcolor{blue}{-24.77\\%} & \\textcolor{red}{-0.01\\%} & \\textcolor{blue}{-5.46\\%}\n",
      "tft\n",
      "\\textcolor{red}{nan\\%} & \\textcolor{red}{+94.6\\%} & \\textcolor{red}{nan\\%} & \\textcolor{red}{+114.61\\%} & \\textcolor{red}{nan\\%} & \\textcolor{red}{+7.57\\%} & \\textcolor{red}{nan\\%} & \\textcolor{red}{+16.84\\%} & \\textcolor{red}{nan\\%} & \\textcolor{blue}{-21.55\\%}\n",
      "transformer\n",
      "\\textcolor{red}{-1.21\\%} & \\textcolor{blue}{-6.84\\%} & \\textcolor{red}{-0.79\\%} & \\textcolor{red}{+45.69\\%} & \\textcolor{red}{-3.05\\%} & \\textcolor{red}{+48.29\\%} & \\textcolor{blue}{+0.05\\%} & \\textcolor{blue}{-27.4\\%} & \\textcolor{red}{-0.34\\%} & \\textcolor{red}{+0.16\\%}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_65873/1688311718.py:18: RuntimeWarning: invalid value encountered in divide\n",
      "  diff_errors = (dict_errors['id']['covs'] - dict_errors['id']['no_covs']) / np.abs(dict_errors['id']['no_covs']) * 100\n"
     ]
    }
   ],
   "source": [
    "datasets = ['iglu', 'colas', 'dubosson', 'hall', 'weinstock'] # iglu is Broll in the paper, otherwise alphabetical order\n",
    "models = ['linreg', 'xgboost', 'gluformer', 'latentode',  'nhits', 'tft', 'transformer']\n",
    "\n",
    "def color_min(x):\n",
    "    \"\"\"Format a percentage change as a coloured LaTeX cell for a metric meant to be minimised.\n",
    "\n",
    "    Positive values are rendered in red with an explicit '+' sign; every other value\n",
    "    (zero, negative, and NaN because NaN > 0 is False) is rendered in blue.\n",
    "    The value is rounded to two decimals before being converted to text.\n",
    "    \"\"\"\n",
    "    # raw literals everywhere: a backslash before '%' is an invalid escape in a normal string\n",
    "    if x > 0:\n",
    "        return r'\\textcolor{red}{+' + str(round(x, 2)) + r'\\%}'\n",
    "    return r'\\textcolor{blue}{' + str(round(x, 2)) + r'\\%}'\n",
    "\n",
    "def color_max(x):\n",
    "    \"\"\"Format a percentage change as a coloured LaTeX cell for a metric meant to be maximised.\n",
    "\n",
    "    Positive values are rendered in blue with an explicit '+' sign; every other value\n",
    "    (zero, negative, and NaN because NaN > 0 is False) is rendered in red.\n",
    "    The value is rounded to two decimals before being converted to text.\n",
    "    \"\"\"\n",
    "    if x > 0:\n",
    "        return r'\\textcolor{blue}{+' + str(round(x, 2)) + r'\\%}'\n",
    "    return r'\\textcolor{red}{' + str(round(x, 2)) + r'\\%}'\n",
    "\n",
    "# For every model that has covariate runs, print the relative change (in %) of the\n",
    "# ID errors when covariates are added.\n",
    "for model in models:\n",
    "    if model in ['arima', 'gluformer', 'latentode']: # no covariates, nothing to compare\n",
    "        continue\n",
    "\n",
    "    model_names = [f'../output/{model}_{dataset}.txt' for dataset in datasets]\n",
    "    model_names_with_covs = [f'../output/{model}_covariates_{dataset}.txt' for dataset in datasets]\n",
    "    (_, dict_errors), _ = avg_results(model_names, model_names_with_covs)\n",
    "\n",
    "    print(model)\n",
    "    baseline = dict_errors['id']['no_covs']\n",
    "    with_covs = dict_errors['id']['covs']\n",
    "    rel_change = (with_covs - baseline) / np.abs(baseline) * 100\n",
    "    # first column is coloured with color_max, second with color_min\n",
    "    cells = []\n",
    "    for row in rel_change.tolist():\n",
    "        cells.extend([color_max(row[0]), color_min(row[1])])\n",
    "    print(' & '.join(cells))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\multirow{6}{*}{\\rotatebox{90}{LIN}} & \\crossmark & \n",
      "ID & -9.89&0.12&-9.19&0.15&-10.10&0.18&-9.56&0.10&-10.14&0.11\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "ID & -9.87&0.13&-9.17&0.19&-10.15&0.21&-10.30&0.19&-10.12&0.11\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\multicolumn{2}{c|}{Improv.} &\n",
      "\\textcolor{blue}{+0.15\\%} & \\textcolor{red}{+5.99\\%} & \\textcolor{blue}{+0.25\\%} & \\textcolor{red}{+24.47\\%} & \\textcolor{red}{-0.41\\%} & \\textcolor{red}{+11.39\\%} & \\textcolor{red}{-7.67\\%} & \\textcolor{red}{+97.36\\%} & \\textcolor{blue}{+0.13\\%} & \\textcolor{blue}{-1.67\\%}\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\crossmark &\n",
      "OD & -9.95&0.15&-9.16&0.15&-10.11&0.17&-9.53&0.10&-10.22&0.11\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "OD & -10.24&0.55&-9.16&0.17&-12.08&0.48&-10.42&0.23&-11.13&0.21\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\multicolumn{2}{c|}{Improv.} &\n",
      "\\textcolor{red}{-2.88\\%} & \\textcolor{red}{+256.79\\%} & \\textcolor{blue}{+0.01\\%} & \\textcolor{red}{+10.19\\%} & \\textcolor{red}{-19.52\\%} & \\textcolor{red}{+181.06\\%} & \\textcolor{red}{-9.26\\%} & \\textcolor{red}{+130.96\\%} & \\textcolor{red}{-8.92\\%} & \\textcolor{red}{+87.0\\%}\n",
      "\\\\\n",
      "\\midrule\n",
      "\\rowcolor{lightgray}\n",
      "\\multicolumn{3}{c|}{$\\min \\Delta$(ID, OD)\\%}&\n",
      "\\textcolor{red}{-0.65\\%} & \\textcolor{red}{+24.98\\%} & \\textcolor{blue}{+0.33\\%} & \\textcolor{blue}{-8.88\\%} & \\textcolor{red}{-0.02\\%} & \\textcolor{blue}{-7.62\\%} & \\textcolor{blue}{+0.33\\%} & \\textcolor{red}{+4.43\\%} & \\textcolor{red}{-0.85\\%} & \\textcolor{blue}{-3.1\\%}\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{6}{*}{\\rotatebox{90}{XGB}} & \\crossmark & \n",
      "ID & -9.94&0.07&-9.42&0.10&-10.55&0.07&-9.68&0.09&-10.20&0.11\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "ID & -10.06&0.07&-9.40&0.09&-10.54&0.06&-9.70&0.09&-10.21&0.10\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\multicolumn{2}{c|}{Improv.} &\n",
      "\\textcolor{red}{-1.22\\%} & \\textcolor{red}{+0.59\\%} & \\textcolor{blue}{+0.12\\%} & \\textcolor{blue}{-7.2\\%} & \\textcolor{blue}{+0.13\\%} & \\textcolor{blue}{-6.98\\%} & \\textcolor{red}{-0.31\\%} & \\textcolor{blue}{-1.3\\%} & \\textcolor{red}{-0.15\\%} & \\textcolor{blue}{-4.62\\%}\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\crossmark &\n",
      "OD & -10.03&0.11&-9.36&0.09&-10.22&0.07&-9.56&0.08&-10.28&0.11\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "OD & -10.03&0.11&-9.38&0.08&-10.20&0.07&-9.53&0.10&-10.31&0.10\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\multicolumn{2}{c|}{Improv.} &\n",
      "\\textcolor{red}{-0.04\\%} & \\textcolor{red}{+1.75\\%} & \\textcolor{red}{-0.2\\%} & \\textcolor{blue}{-7.0\\%} & \\textcolor{blue}{+0.13\\%} & \\textcolor{blue}{-1.62\\%} & \\textcolor{blue}{+0.31\\%} & \\textcolor{red}{+21.93\\%} & \\textcolor{red}{-0.34\\%} & \\textcolor{blue}{-4.89\\%}\n",
      "\\\\\n",
      "\\midrule\n",
      "\\rowcolor{lightgray}\n",
      "\\multicolumn{3}{c|}{$\\min \\Delta$(ID, OD)\\%}&\n",
      "\\textcolor{blue}{+0.29\\%} & \\textcolor{red}{+67.42\\%} & \\textcolor{blue}{+0.59\\%} & \\textcolor{blue}{-4.55\\%} & \\textcolor{blue}{+3.17\\%} & \\textcolor{red}{+5.02\\%} & \\textcolor{blue}{+1.77\\%} & \\textcolor{blue}{-14.37\\%} & \\textcolor{red}{-0.83\\%} & \\textcolor{red}{+2.16\\%}\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{2}{*}{\\rotatebox{90}{GLU}} & \\crossmark & \n",
      "ID & -2.11&0.05&-1.07&0.14&-2.15&0.06&-1.56&0.05&-2.50&0.08\n",
      "\\\\\n",
      "& \\crossmark & \n",
      "OOD & -1.96&0.11&-1.61&0.10&-1.17&0.12&-1.44&0.06&-2.41&0.09\n",
      "\\\\\n",
      "\\midrule\n",
      "\\rowcolor{lightgray}\n",
      "\\multicolumn{3}{c|}{$\\min \\Delta$(ID, OD)\\%}&\n",
      "\\textcolor{blue}{+6.72\\%} & \\textcolor{red}{+106.76\\%} & \\textcolor{red}{-50.33\\%} & \\textcolor{blue}{-29.83\\%} & \\textcolor{blue}{+45.64\\%} & \\textcolor{red}{+83.63\\%} & \\textcolor{blue}{+7.69\\%} & \\textcolor{red}{+9.24\\%} & \\textcolor{blue}{+3.33\\%} & \\textcolor{red}{+6.23\\%}\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{2}{*}{\\rotatebox{90}{LAT}} & \\crossmark & \n",
      "ID & -25.29&0.36&-10.47&0.25&-52.18&0.42&-20.24&0.30&-26.15&0.33\n",
      "\\\\\n",
      "& \\crossmark & \n",
      "OOD & -28.75&0.38&-8.80&0.24&-30.19&0.44&-18.19&0.36&-30.08&0.40\n",
      "\\\\\n",
      "\\midrule\n",
      "\\rowcolor{lightgray}\n",
      "\\multicolumn{3}{c|}{$\\min \\Delta$(ID, OD)\\%}&\n",
      "\\textcolor{red}{-13.67\\%} & \\textcolor{red}{+6.5\\%} & \\textcolor{blue}{+15.89\\%} & \\textcolor{blue}{-4.01\\%} & \\textcolor{blue}{+42.14\\%} & \\textcolor{red}{+4.59\\%} & \\textcolor{blue}{+10.12\\%} & \\textcolor{red}{+20.58\\%} & \\textcolor{red}{-15.03\\%} & \\textcolor{red}{+20.72\\%}\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{6}{*}{\\rotatebox{90}{NHI}} & \\crossmark & \n",
      "ID & -10.01&0.12&-9.32&0.11&-10.37&0.10&-9.62&0.09&-10.13&0.11\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "ID & -10.37&0.07&-9.48&0.21&-10.80&0.08&-9.63&0.07&-10.13&0.11\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\multicolumn{2}{c|}{Improv.} &\n",
      "\\textcolor{red}{-3.63\\%} & \\textcolor{blue}{-37.79\\%} & \\textcolor{red}{-1.68\\%} & \\textcolor{red}{+91.12\\%} & \\textcolor{red}{-4.19\\%} & \\textcolor{blue}{-21.68\\%} & \\textcolor{red}{-0.07\\%} & \\textcolor{blue}{-24.77\\%} & \\textcolor{red}{-0.01\\%} & \\textcolor{blue}{-5.46\\%}\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\crossmark &\n",
      "OD & -10.08&0.10&-9.26&0.11&-10.18&0.12&-9.49&0.08&-10.20&0.12\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "OD & -10.21&0.06&-9.36&0.14&-11.10&0.20&-9.58&0.06&-10.19&0.11\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\multicolumn{2}{c|}{Improv.} &\n",
      "\\textcolor{red}{-1.3\\%} & \\textcolor{blue}{-34.64\\%} & \\textcolor{red}{-1.17\\%} & \\textcolor{red}{+24.57\\%} & \\textcolor{red}{-9.0\\%} & \\textcolor{red}{+66.46\\%} & \\textcolor{red}{-0.94\\%} & \\textcolor{blue}{-14.44\\%} & \\textcolor{blue}{+0.14\\%} & \\textcolor{blue}{-8.1\\%}\n",
      "\\\\\n",
      "\\midrule\n",
      "\\rowcolor{lightgray}\n",
      "\\multicolumn{3}{c|}{$\\min \\Delta$(ID, OD)\\%}&\n",
      "\\textcolor{blue}{+1.55\\%} & \\textcolor{blue}{-16.3\\%} & \\textcolor{blue}{+1.23\\%} & \\textcolor{blue}{-34.06\\%} & \\textcolor{blue}{+1.76\\%} & \\textcolor{red}{+15.16\\%} & \\textcolor{blue}{+1.38\\%} & \\textcolor{blue}{-14.79\\%} & \\textcolor{red}{-0.55\\%} & \\textcolor{red}{+4.17\\%}\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{6}{*}{\\rotatebox{90}{TFT}} & \\crossmark & \n",
      "ID & 0.00&0.16&0.00&0.07&0.00&0.23&0.00&0.07&0.00&0.07\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "ID & 0.00&0.30&0.00&0.16&0.00&0.25&0.00&0.08&0.00&0.06\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\multicolumn{2}{c|}{Improv.} &\n",
      "\\textcolor{red}{nan\\%} & \\textcolor{red}{+94.6\\%} & \\textcolor{red}{nan\\%} & \\textcolor{red}{+114.61\\%} & \\textcolor{red}{nan\\%} & \\textcolor{red}{+7.57\\%} & \\textcolor{red}{nan\\%} & \\textcolor{red}{+16.84\\%} & \\textcolor{red}{nan\\%} & \\textcolor{blue}{-21.55\\%}\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\crossmark &\n",
      "OD & 0.00&0.15&0.00&0.09&0.00&0.26&0.00&0.08&0.00&0.08\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "OD & 0.00&0.23&0.00&0.09&0.00&0.35&0.00&0.08&0.00&0.05\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\multicolumn{2}{c|}{Improv.} &\n",
      "\\textcolor{red}{nan\\%} & \\textcolor{red}{+57.74\\%} & \\textcolor{red}{nan\\%} & \\textcolor{red}{+0.35\\%} & \\textcolor{red}{nan\\%} & \\textcolor{red}{+37.5\\%} & \\textcolor{red}{nan\\%} & \\textcolor{blue}{-1.64\\%} & \\textcolor{red}{nan\\%} & \\textcolor{blue}{-35.43\\%}\n",
      "\\\\\n",
      "\\midrule\n",
      "\\rowcolor{lightgray}\n",
      "\\multicolumn{3}{c|}{$\\min \\Delta$(ID, OD)\\%}&\n",
      "\\textcolor{red}{nan\\%} & \\textcolor{blue}{-22.83\\%} & \\textcolor{red}{nan\\%} & \\textcolor{blue}{-46.14\\%} & \\textcolor{red}{nan\\%} & \\textcolor{red}{+10.56\\%} & \\textcolor{red}{nan\\%} & \\textcolor{red}{+9.88\\%} & \\textcolor{red}{nan\\%} & \\textcolor{blue}{-9.51\\%}\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{6}{*}{\\rotatebox{90}{TRA}} & \\crossmark & \n",
      "ID & -9.99&0.23&-9.37&0.21&-10.36&0.12&-9.60&0.13&-10.12&0.11\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "ID & -10.11&0.21&-9.45&0.31&-10.68&0.18&-9.60&0.10&-10.15&0.11\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\multicolumn{2}{c|}{Improv.} &\n",
      "\\textcolor{red}{-1.21\\%} & \\textcolor{blue}{-6.84\\%} & \\textcolor{red}{-0.79\\%} & \\textcolor{red}{+45.69\\%} & \\textcolor{red}{-3.05\\%} & \\textcolor{red}{+48.29\\%} & \\textcolor{blue}{+0.05\\%} & \\textcolor{blue}{-27.4\\%} & \\textcolor{red}{-0.34\\%} & \\textcolor{red}{+0.16\\%}\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\crossmark &\n",
      "OD & -9.98&0.19&-9.30&0.22&-10.09&0.14&-9.47&0.15&-10.17&0.12\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "OD & -10.02&0.11&-9.36&0.22&-10.63&0.25&-9.49&0.08&-10.20&0.12\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\multicolumn{2}{c|}{Improv.} &\n",
      "\\textcolor{red}{-0.41\\%} & \\textcolor{blue}{-43.28\\%} & \\textcolor{red}{-0.65\\%} & \\textcolor{red}{+1.0\\%} & \\textcolor{red}{-5.32\\%} & \\textcolor{red}{+73.94\\%} & \\textcolor{red}{-0.16\\%} & \\textcolor{blue}{-45.49\\%} & \\textcolor{red}{-0.33\\%} & \\textcolor{red}{+2.63\\%}\n",
      "\\\\\n",
      "\\midrule\n",
      "\\rowcolor{lightgray}\n",
      "\\multicolumn{3}{c|}{$\\min \\Delta$(ID, OD)\\%}&\n",
      "\\textcolor{blue}{+0.93\\%} & \\textcolor{blue}{-47.95\\%} & \\textcolor{blue}{+0.92\\%} & \\textcolor{blue}{-28.49\\%} & \\textcolor{blue}{+2.59\\%} & \\textcolor{red}{+17.77\\%} & \\textcolor{blue}{+1.34\\%} & \\textcolor{blue}{-14.35\\%} & \\textcolor{red}{-0.53\\%} & \\textcolor{red}{+8.99\\%}\n",
      "\\\\\n",
      "\\midrule\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_65873/1608005624.py:54: RuntimeWarning: invalid value encountered in divide\n",
      "  diff_errors = (dict_errors['id']['covs'] - dict_errors['id']['no_covs']) / np.abs(dict_errors['id']['no_covs']) * 100\n",
      "/tmp/ipykernel_65873/1608005624.py:74: RuntimeWarning: invalid value encountered in divide\n",
      "  diff_errors = (dict_errors['ood']['covs'] - dict_errors['ood']['no_covs']) / np.abs(dict_errors['ood']['no_covs']) * 100\n",
      "/tmp/ipykernel_65873/1608005624.py:84: RuntimeWarning: invalid value encountered in divide\n",
      "  diff_errors_no_covs = (dict_errors['ood']['no_covs'] - dict_errors['id']['no_covs']) / np.abs(dict_errors['id']['no_covs']) * 100\n",
      "/tmp/ipykernel_65873/1608005624.py:85: RuntimeWarning: invalid value encountered in divide\n",
      "  diff_errors_covs = (dict_errors['ood']['covs'] - dict_errors['id']['covs']) / np.abs(dict_errors['id']['covs']) * 100\n"
     ]
    }
   ],
   "source": [
    "datasets = ['iglu', 'colas', 'dubosson', 'hall', 'weinstock'] # iglu is Broll in the paper, otherwise alphabetical order\n",
    "models = ['linreg', 'xgboost', 'gluformer', 'latentode',  'nhits', 'tft', 'transformer']\n",
    "\n",
    "def color_min(x):\n",
    "    \"\"\"Format a percentage change as a coloured LaTeX cell for a metric meant to be minimised.\n",
    "\n",
    "    Positive values are rendered in red with an explicit '+' sign; every other value\n",
    "    (zero, negative, and NaN because NaN > 0 is False) is rendered in blue.\n",
    "    The value is rounded to two decimals before being converted to text.\n",
    "    \"\"\"\n",
    "    # raw literals everywhere: a backslash before '%' is an invalid escape in a normal string\n",
    "    if x > 0:\n",
    "        return r'\\textcolor{red}{+' + str(round(x, 2)) + r'\\%}'\n",
    "    return r'\\textcolor{blue}{' + str(round(x, 2)) + r'\\%}'\n",
    "\n",
    "def color_max(x):\n",
    "    \"\"\"Format a percentage change as a coloured LaTeX cell for a metric meant to be maximised.\n",
    "\n",
    "    Positive values are rendered in blue with an explicit '+' sign; every other value\n",
    "    (zero, negative, and NaN because NaN > 0 is False) is rendered in red.\n",
    "    The value is rounded to two decimals before being converted to text.\n",
    "    \"\"\"\n",
    "    if x > 0:\n",
    "        return r'\\textcolor{blue}{+' + str(round(x, 2)) + r'\\%}'\n",
    "    return r'\\textcolor{red}{' + str(round(x, 2)) + r'\\%}'\n",
    "\n",
    "# Emit the LaTeX rows of the full table: per model the ID / OD values, the relative\n",
    "# change from adding covariates (where covariate runs exist) and the ID -> OD change.\n",
    "# All LaTeX fragments are raw strings so that no backslash is read as an escape sequence.\n",
    "for model in models:\n",
    "    if model in ['arima', 'gluformer', 'latentode']: # no covariates\n",
    "        model_names = [f'../output/{model}_{dataset}.txt' for dataset in datasets]\n",
    "        model_names_with_covs = None\n",
    "        (_, dict_errors), _ = avg_results(model_names, model_names_with_covs)\n",
    "\n",
    "        print(r'\\multirow{2}{*}{\\rotatebox{90}{'+ model[:3].upper() + r'}} & \\crossmark & ')\n",
    "        print(r'ID & ' + '&'.join([f'{x:.2f}' for x in dict_errors['id']['no_covs'].reshape(-1).tolist()]))\n",
    "\n",
    "        print(r'\\\\')\n",
    "\n",
    "        print(r'& \\crossmark & ')\n",
    "        # NOTE(review): this row is labelled 'OOD' while the covariate branch prints 'OD' -- confirm which label the table expects\n",
    "        print(r'OOD & ' + '&'.join([f'{x:.2f}' for x in dict_errors['ood']['no_covs'].reshape(-1).tolist()]))\n",
    "\n",
    "        print(r'\\\\')\n",
    "        print(r'\\midrule')\n",
    "        print(r'\\rowcolor{lightgray}')\n",
    "\n",
    "        print(r'\\multicolumn{3}{c|}{$\\min \\Delta$(ID, OD)\\%}&')\n",
    "        diff_errors = (dict_errors['ood']['no_covs'] - dict_errors['id']['no_covs']) / np.abs(dict_errors['id']['no_covs']) * 100\n",
    "        string = [[color_max(x[0]), color_min(x[1])] for x in diff_errors.tolist()]\n",
    "        string = [item for sublist in string for item in sublist]\n",
    "        print(' & '.join(string))\n",
    "\n",
    "        print(r'\\\\')\n",
    "        print(r'\\midrule')\n",
    "\n",
    "    else:\n",
    "        model_names = [f'../output/{model}_{dataset}.txt' for dataset in datasets]\n",
    "        model_names_with_covs = [f'../output/{model}_covariates_{dataset}.txt' for dataset in datasets]\n",
    "        (_, dict_errors), _ = avg_results(model_names, model_names_with_covs)\n",
    "\n",
    "        print(r'\\multirow{6}{*}{\\rotatebox{90}{'+ model[:3].upper() + r'}} & \\crossmark & ')\n",
    "\n",
    "        print(r'ID & ' + '&'.join([f'{x:.2f}' for x in dict_errors['id']['no_covs'].reshape(-1).tolist()]))\n",
    "\n",
    "        print(r'\\\\')\n",
    "\n",
    "        print(r'& \\checkmark & ')\n",
    "        print(r'ID & ' + '&'.join([f'{x:.2f}' for x in dict_errors['id']['covs'].reshape(-1).tolist()]))\n",
    "\n",
    "        print(r'\\\\')\n",
    "        print(r'\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}')\n",
    "\n",
    "        # relative change of the ID errors when covariates are added\n",
    "        print(r'& \\multicolumn{2}{c|}{Improv.} &')\n",
    "        diff_errors = (dict_errors['id']['covs'] - dict_errors['id']['no_covs']) / np.abs(dict_errors['id']['no_covs']) * 100\n",
    "        string = [[color_max(x[0]), color_min(x[1])] for x in diff_errors.tolist()]\n",
    "        string = [item for sublist in string for item in sublist]\n",
    "        print(' & '.join(string))\n",
    "\n",
    "        print(r'\\\\')\n",
    "        print(r'\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}')\n",
    "\n",
    "        print(r'& \\crossmark &')\n",
    "        print(r'OD & ' + '&'.join([f'{x:.2f}' for x in dict_errors['ood']['no_covs'].reshape(-1).tolist()]))\n",
    "\n",
    "        print(r'\\\\')\n",
    "\n",
    "        print(r'& \\checkmark & ')\n",
    "        print(r'OD & ' + '&'.join([f'{x:.2f}' for x in dict_errors['ood']['covs'].reshape(-1).tolist()]))\n",
    "\n",
    "        print(r'\\\\')\n",
    "        print(r'\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}')\n",
    "\n",
    "        # relative change of the OD errors when covariates are added\n",
    "        print(r'& \\multicolumn{2}{c|}{Improv.} &')\n",
    "        diff_errors = (dict_errors['ood']['covs'] - dict_errors['ood']['no_covs']) / np.abs(dict_errors['ood']['no_covs']) * 100\n",
    "        string = [[color_max(x[0]), color_min(x[1])] for x in diff_errors.tolist()]\n",
    "        string = [item for sublist in string for item in sublist]\n",
    "        print(' & '.join(string))\n",
    "\n",
    "        print(r'\\\\')\n",
    "        print(r'\\midrule')\n",
    "        print(r'\\rowcolor{lightgray}')\n",
    "\n",
    "        print(r'\\multicolumn{3}{c|}{$\\min \\Delta$(ID, OD)\\%}&')\n",
    "        diff_errors_no_covs = (dict_errors['ood']['no_covs'] - dict_errors['id']['no_covs']) / np.abs(dict_errors['id']['no_covs']) * 100\n",
    "        diff_errors_covs = (dict_errors['ood']['covs'] - dict_errors['id']['covs']) / np.abs(dict_errors['id']['covs']) * 100\n",
    "        # np.empty_like leaves the buffer uninitialised, so both columns must be assigned explicitly\n",
    "        diff_errors = np.empty_like(diff_errors_no_covs)\n",
    "        # column 0 is coloured with color_max: keep the larger of the two ID -> OD changes\n",
    "        diff_errors[:, 0] = np.maximum(diff_errors_no_covs[:, 0], diff_errors_covs[:, 0])\n",
    "        # column 1 is coloured with color_min: keep the smaller of the two ID -> OD changes\n",
    "        diff_errors[:, 1] = np.minimum(diff_errors_no_covs[:, 1], diff_errors_covs[:, 1])\n",
    "        string = [[color_max(x[0]), color_min(x[1])] for x in diff_errors.tolist()]\n",
    "        string = [item for sublist in string for item in sublist]\n",
    "        print(' & '.join(string))\n",
    "\n",
    "        print(r'\\\\')\n",
    "        print(r'\\midrule')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Table 1: short version"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets = ['iglu', 'colas', 'dubosson', 'hall', 'weinstock'] # iglu is Broll in the paper, otherwise alphabetical order\n",
    "models = ['arima', 'linreg', 'xgboost', 'gluformer', 'latentode',  'nhits', 'tft', 'transformer']\n",
    "\n",
    "# Collect, per model, the ID errors of the runs without covariates, taken from the\n",
    "# first dictionary of the first pair returned by avg_results.\n",
    "model_errors = []\n",
    "for model in models:\n",
    "    model_names = [f'../output/{model}_{name}.txt' for name in datasets]\n",
    "    model_names_with_covs = None  # covariate runs are not loaded here\n",
    "    (dict_errors, _), _ = avg_results(model_names, model_names_with_covs)\n",
    "    id_errors = dict_errors['id']['no_covs']\n",
    "    model_errors.append(id_errors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iglu: arima\n",
      "colas: linreg\n",
      "dubosson: linreg\n",
      "hall: latentode\n",
      "weinstock: transformer\n",
      "{'arima', 'transformer', 'latentode', 'linreg'}\n"
     ]
    }
   ],
   "source": [
    "# Average over the last axis of the stacked per-model errors. The result is stored under\n",
    "# a new name so `model_errors` keeps its raw values and this cell can be re-run safely.\n",
    "mean_model_errors = np.mean(np.array(model_errors), axis=-1)\n",
    "# find best models with lowest error\n",
    "best_models = np.argsort(mean_model_errors, axis=0)\n",
    "best_models_set = []\n",
    "for i, dataset in enumerate(datasets):\n",
    "    # models ordered from lowest to highest mean error for this dataset\n",
    "    bm_dataset = [models[j] for j in best_models[:, i]]\n",
    "    print(f'{dataset}: {bm_dataset[0]}')\n",
    "    best_models_set.append(bm_dataset[0])\n",
    "print(set(best_models_set))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Table 2: short version"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets = ['iglu', 'colas', 'dubosson', 'hall', 'weinstock'] # iglu is Broll in the paper, otherwise alphabetical order\n",
    "models = ['arima', 'linreg', 'xgboost', 'gluformer', 'latentode',  'nhits', 'tft', 'transformer']\n",
    "\n",
    "# Collect, per model, the ID errors of the runs without covariates, taken from the\n",
    "# second dictionary of the first pair returned by avg_results.\n",
    "model_errors = []\n",
    "for model in models:\n",
    "    model_names = [f'../output/{model}_{name}.txt' for name in datasets]\n",
    "    model_names_with_covs = None  # covariate runs are not loaded here\n",
    "    (_, dict_errors), _ = avg_results(model_names, model_names_with_covs)\n",
    "    id_errors = dict_errors['id']['no_covs']\n",
    "    model_errors.append(id_errors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iglu: gluformer\n",
      "colas: gluformer\n",
      "dubosson: gluformer\n",
      "hall: gluformer\n",
      "weinstock: gluformer\n",
      "{'gluformer'}\n",
      "iglu: gluformer\n",
      "colas: tft\n",
      "dubosson: gluformer\n",
      "hall: gluformer\n",
      "weinstock: tft\n",
      "{'gluformer', 'tft'}\n"
     ]
    }
   ],
   "source": [
    "# Split the stacked per-model errors into their two columns: likelihood and calibration error.\n",
    "model_likelihood = np.array(model_errors)[:, :, 0]\n",
    "model_cal = np.array(model_errors)[:, :, 1]\n",
    "\n",
    "# find best models with highest likelihood\n",
    "# NaN entries and exact zeros (models that do not support likelihood) are pushed to the bottom\n",
    "no_likelihood = np.isnan(model_likelihood) | (model_likelihood == 0)\n",
    "model_likelihood[no_likelihood] = -np.inf\n",
    "best_models = np.argsort(model_likelihood, axis=0)\n",
    "best_models_set = []\n",
    "for i, dataset in enumerate(datasets):\n",
    "    # ascending sort, so the last row holds the highest likelihood\n",
    "    top_model = models[best_models[-1, i]]\n",
    "    print(f'{dataset}: {top_model}')\n",
    "    best_models_set.append(top_model)\n",
    "print(set(best_models_set))\n",
    "\n",
    "# find best models with lowest cal error\n",
    "missing_cal = np.isnan(model_cal)\n",
    "model_cal[missing_cal] = np.inf\n",
    "best_models = np.argsort(model_cal, axis=0)\n",
    "best_models_set = []\n",
    "for i, dataset in enumerate(datasets):\n",
    "    # ascending sort, so the first row holds the lowest calibration error\n",
    "    top_model = models[best_models[0, i]]\n",
    "    print(f'{dataset}: {top_model}')\n",
    "    best_models_set.append(top_model)\n",
    "print(set(best_models_set))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Clearing the results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import os\n",
    "\n",
    "# # Set the directory path to the folder containing the output files\n",
    "# folder_path = './output'\n",
    "\n",
    "# # Loop through each file in the folder\n",
    "# for filename in os.listdir(folder_path):\n",
    "#     file_path = os.path.join(folder_path, filename)\n",
    "#     # Open the file in read mode if the file starts with transformer\n",
    "#     if filename.startswith('transformer') or \\\n",
    "#             filename.startswith('tft') or \\\n",
    "#                 filename.startswith('linreg') or \\\n",
    "#                     filename.startswith('xgboost') or \\\n",
    "#                         filename.startswith('nhits'):\n",
    "#         with open(file_path, 'r') as f:\n",
    "#             lines = f.readlines()\n",
    "            \n",
    "#         # Loop through the lines in reverse order\n",
    "#         for i in range(len(lines)-1, -1, -1):\n",
    "#             if lines[i].startswith('Best value: '):\n",
    "#                 # Delete all lines after the line starting with \"Best value: \"\n",
    "#                 del lines[i+1:]\n",
    "#                 break\n",
    "        \n",
    "#         # Open the file in write mode and write the modified lines back to the file\n",
    "#         with open(file_path, 'w') as f:\n",
    "#             f.writelines(lines)\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tables 3-6 (std. error in metric estimates)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "models = ['tft', 'transformer', 'xgboost', 'nhits']\n",
    "# 'gluformer', 'latentode'\n",
    "datasets = ['iglu', 'hall', 'colas', 'weinstock', 'dubosson']\n",
    "metrics = ['ID mean of (MSE, MAE)', 'OOD mean of (MSE, MAE)',\n",
    "           'ID median of (MSE, MAE)', 'OOD median of (MSE, MAE)',\n",
    "           'ID likelihoods', 'OOD likelihoods', \n",
    "           'ID calibration errors', 'OOD calibration errors']\n",
    "\n",
    "# decimal literals, optionally with a negative exponent (e.g. -1.25 or 3.1e-05)\n",
    "number_pattern = r'-?\\d+\\.\\d+(?:e-\\d+)?'\n",
    "\n",
    "# For every covariate output file: gather the per-run metric values of seeds 1 and 2\n",
    "# (model seeds 10..19), average them per seed, take the standard deviation across the\n",
    "# two seed averages and append it as ' +- <std>' to the matching summary line.\n",
    "# WARNING: this cell rewrites the files under ../output/ in place.\n",
    "for model in models:\n",
    "    for dataset in datasets:\n",
    "        output_path = f'../output/{model}_covariates_{dataset}.txt'\n",
    "        # open model file\n",
    "        with open(output_path, 'rb') as f:\n",
    "            lines = f.readlines()\n",
    "        # decode once up front instead of once per (metric, seed, model_seed) pass\n",
    "        decoded_lines = [raw_line.decode() for raw_line in lines]\n",
    "        results = {}\n",
    "        for metric in metrics:\n",
    "            results[metric] = {}\n",
    "            for seed in [1, 2]:\n",
    "                results[metric][seed] = []\n",
    "                for model_seed in range(10, 20):\n",
    "                    # find lines starting with Model Seed: model_seed Seed: seed metric: \n",
    "                    prev_line_has_calibration = False\n",
    "                    for line in decoded_lines:\n",
    "                        if f'Model Seed: {model_seed} Seed: {seed} {metric}' in line:\n",
    "                            vals = re.findall(number_pattern, line)\n",
    "                            results[metric][seed].append([float(x) for x in vals])\n",
    "                            if 'calibration' in metric:\n",
    "                                # calibration values may continue on the following line\n",
    "                                prev_line_has_calibration = True\n",
    "                        # check that line has no text just numbers and symbols\n",
    "                        # NOTE(review): a continuation line containing exponent notation (the letter 'e') fails this check -- confirm this cannot occur\n",
    "                        check = re.findall(r'[a-zA-Z]', line)\n",
    "                        if prev_line_has_calibration and len(check) == 0:\n",
    "                            prev_line_has_calibration = False\n",
    "                            vals = re.findall(number_pattern, line)\n",
    "                            results[metric][seed][-1].extend([float(x) for x in vals])\n",
    "        for metric in metrics:\n",
    "            for seed in [1, 2]:\n",
    "                # mean over the runs collected for this seed\n",
    "                results[metric][seed] = np.array(results[metric][seed])\n",
    "                results[metric][seed] = np.mean(results[metric][seed], axis=0)\n",
    "            # spread between the two per-seed means\n",
    "            results[metric] = np.std([results[metric][1], results[metric][2]], axis=0)\n",
    "\n",
    "        new_lines = []\n",
    "        for line in decoded_lines:\n",
    "            # lines that already carry ' +- ' were annotated by an earlier run; leaving them\n",
    "            # untouched keeps the cell idempotent instead of stacking a second suffix\n",
    "            if ' +- ' not in line:\n",
    "                for metric in metrics:\n",
    "                    if ('Seed' not in line) and (metric in line):\n",
    "                        if len(results[metric]) > 0:\n",
    "                            line = line.strip('\\n') + f' +- {results[metric]} \\n' if len(results[metric]) > 1 else line.strip('\\n') + f' +- {results[metric][0]} \\n'\n",
    "            new_lines.append(line)\n",
    "        # write new lines to file\n",
    "        with open(output_path, 'wb') as f:\n",
    "            for line in new_lines:\n",
    "                f.write(line.encode())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\multirow{2}{*}{\\rotatebox{90}{ARI}} & \\crossmark & \n",
      "ID & 10.53 +- 4.69&5.80 +- 0.83&13.53 +- 6.37&8.63 +- 1.50&13.40 +- 1.77\n",
      "\\\\\n",
      "& \\crossmark & \n",
      "OD & 11.75 +- 7.26&5.91 +- 2.38&18.75 +- 15.10&8.22 +- 5.02&15.87 +- 4.31\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{4}{*}{\\rotatebox{90}{LIN}} & \\crossmark & \n",
      "ID & 11.68 +- 3.61&5.26 +- 0.57&12.07 +- 5.49&7.38 +- 1.29&13.60 +- 1.59\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "ID & 9.95 +- 2.73&5.56 +- 0.84&12.41 +- 5.71&7.84 +- 1.85&13.39 +- 1.60\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\crossmark &\n",
      "OD & 11.98 +- 7.08&5.33 +- 1.60&15.69 +- 12.17&7.86 +- 4.23&15.58 +- 4.64\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "OD & 23.30 +- 24.85&5.54 +- 1.60&203114.47 +- 287247.24&14.22 +- 17.87&15.66 +- 4.48\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{4}{*}{\\rotatebox{90}{XGB}} & \\crossmark & \n",
      "ID & 12.80 +- 2.62&6.42 +- 1.04&21.18 +- 5.38&7.58 +- 1.23&13.63 +- 1.67\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "ID & 13.89 +- 3.04&6.37 +- 1.75&20.89 +- 5.62&8.05 +- 1.41&13.77 +- 1.67\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\crossmark &\n",
      "OD & 9.76 +- 2.72&6.18 +- 0.89&17.57 +- 4.68&7.49 +- 1.16&15.36 +- 1.13\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "OD & 9.67 +- 2.27&6.36 +- 1.76&17.44 +- 3.01&8.20 +- 2.24&15.55 +- 1.77\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{2}{*}{\\rotatebox{90}{GLU}} & \\crossmark & \n",
      "ID & 14.19 +- 4.49&8.17 +- 2.42&21.74 +- 7.53&7.74 +- 1.39&14.07 +- 2.24\n",
      "\\\\\n",
      "& \\crossmark & \n",
      "OD & 16.70 +- 6.34&6.94 +- 1.77&23.48 +- 7.52&8.17 +- 1.84&15.94 +- 3.29\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{2}{*}{\\rotatebox{90}{LAT}} & \\crossmark & \n",
      "ID & 14.37 +- 8.72&6.28 +- 1.17&20.14 +- 8.29&7.13 +- 1.07&13.54 +- 2.20\n",
      "\\\\\n",
      "& \\crossmark & \n",
      "OD & 14.96 +- 8.19&5.64 +- 1.19&17.38 +- 8.48&7.71 +- 1.25&15.06 +- 2.72\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{4}{*}{\\rotatebox{90}{NHI}} & \\crossmark & \n",
      "ID & 13.79 +- 3.01&5.93 +- 0.51&17.45 +- 1.08&7.68 +- 0.52&13.29 +- 1.40\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "ID & 16.20 +- 5.15&9.09 +- 2.50&30.43 +- 5.65&8.16 +- 1.19&13.41 +- 0.63\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\crossmark &\n",
      "OD & 14.64 +- 3.29&5.68 +- 0.38&18.20 +- 1.53&7.74 +- 0.46&14.52 +- 1.47\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "OD & 15.66 +- 6.16&7.56 +- 2.01&37.35 +- 8.22&8.59 +- 1.40&14.40 +- 0.98\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{4}{*}{\\rotatebox{90}{TFT}} & \\crossmark & \n",
      "ID & 13.73 +- 11.22&5.62 +- 1.22&18.37 +- 6.62&7.92 +- 1.90&14.32 +- 3.70\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "ID & 14.68 +- 5.67&6.51 +- 0.00&18.43 +- 5.92&8.42 +- 1.80&14.97 +- 2.43\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\crossmark &\n",
      "OD & 12.43 +- 8.14&5.51 +- 1.36&17.50 +- 6.82&8.12 +- 2.01&15.25 +- 3.83\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "OD & 13.25 +- 4.52&5.79 +- 0.00&17.19 +- 4.88&8.93 +- 2.36&15.47 +- 2.42\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{4}{*}{\\rotatebox{90}{TRA}} & \\crossmark & \n",
      "ID & 15.12 +- 0.00&6.47 +- 0.00&16.62 +- 0.00&7.89 +- 0.00&13.22 +- 0.00\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "ID & 12.83 +- 0.00&8.44 +- 0.00&27.43 +- 0.01&7.49 +- 0.00&14.46 +- 0.01\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\crossmark &\n",
      "OD & 14.04 +- 0.00&5.97 +- 0.00&15.71 +- 0.00&8.18 +- 0.00&14.15 +- 0.00\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "OD & 13.76 +- 0.00&7.26 +- 0.00&34.11 +- 0.01&7.40 +- 0.00&15.59 +- 0.00\n",
      "\\\\\n",
      "\\midrule\n"
     ]
    }
   ],
   "source": [
    "# RMSE table\n",
    "#\n",
    "# Prints the LaTeX body rows of the RMSE table: for every model, one\n",
    "# 'mean +- std' cell per row of the arrays returned by avg_results\n",
    "# (presumably one row per dataset, in `datasets` order -- TODO confirm).\n",
    "# RMSE is read from column 0 of the first dictionary in each returned pair.\n",
    "\n",
    "metric_col = 0\n",
    "\n",
    "for model in models:\n",
    "    has_covs = model not in ['arima', 'gluformer', 'latentode'] # these models have no covariate runs\n",
    "    model_names = [f'../output/{model}_{dataset}.txt' for dataset in datasets]\n",
    "    model_names_with_covs = None\n",
    "    if has_covs:\n",
    "        model_names_with_covs = [f'../output/{model}_covariates_{dataset}.txt' for dataset in datasets]\n",
    "    (dict_errors, _), (dict_errors_std, _) = avg_results(model_names, model_names_with_covs)\n",
    "\n",
    "    # Each entry: (LaTeX text printed before the row, row label, split key, covariate key).\n",
    "    # Raw strings keep the LaTeX backslashes literal: '\\c' and '\\m' are invalid escape\n",
    "    # sequences in ordinary string literals and make Python emit a warning.\n",
    "    tag = model[:3].upper()\n",
    "    if has_covs:\n",
    "        layout = [\n",
    "            (r'\\multirow{4}{*}{\\rotatebox{90}{' + tag + r'}} & \\crossmark & ', 'ID', 'id', 'no_covs'),\n",
    "            (r'& \\checkmark & ', 'ID', 'id', 'covs'),\n",
    "            (r'& \\crossmark &', 'OD', 'ood', 'no_covs'),  # no trailing space, as in the existing output\n",
    "            (r'& \\checkmark & ', 'OD', 'ood', 'covs'),\n",
    "        ]\n",
    "    else:\n",
    "        layout = [\n",
    "            (r'\\multirow{2}{*}{\\rotatebox{90}{' + tag + r'}} & \\crossmark & ', 'ID', 'id', 'no_covs'),\n",
    "            (r'& \\crossmark & ', 'OD', 'ood', 'no_covs'),\n",
    "        ]\n",
    "\n",
    "    for position, (lead, label, split, cov) in enumerate(layout):\n",
    "        if has_covs and position == 2:\n",
    "            # partial rules between the ID rows and the OD rows of a four-row model\n",
    "            print(r'\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}')\n",
    "        print(lead)\n",
    "        means = dict_errors[split][cov][:, metric_col].tolist()\n",
    "        stds = dict_errors_std[split][cov][:, metric_col].tolist()\n",
    "        print(label + ' & ' + '&'.join(f'{x:.2f} +- {y:.2f}' for x, y in zip(means, stds)))\n",
    "        print(r'\\\\')\n",
    "    print(r'\\midrule')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\multirow{2}{*}{\\rotatebox{90}{ARI}} & \\crossmark & \n",
      "ID & 8.67 +- 0.74&4.80 +- 0.04&11.06 +- 0.98&7.34 +- 0.16&11.25 +- 0.10\n",
      "\\\\\n",
      "& \\crossmark & \n",
      "OD & 9.71 +- 1.93&4.87 +- 0.38&14.58 +- 4.75&6.97 +- 1.19&13.34 +- 0.52\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{4}{*}{\\rotatebox{90}{LIN}} & \\crossmark & \n",
      "ID & 9.71 +- 0.37&4.35 +- 0.03&9.97 +- 1.00&6.33 +- 0.09&11.46 +- 0.11\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "ID & 8.41 +- 0.24&4.60 +- 0.04&10.03 +- 1.11&6.66 +- 0.18&11.34 +- 0.10\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\crossmark &\n",
      "OD & 9.83 +- 1.58&4.41 +- 0.20&11.90 +- 3.51&6.62 +- 0.91&13.09 +- 0.61\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "OD & 16.80 +- 9.45&4.57 +- 0.19&67548.59 +- 135072.57&10.02 +- 6.68&13.16 +- 0.57\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{4}{*}{\\rotatebox{90}{XGB}} & \\crossmark & \n",
      "ID & 11.50 +- 0.31&5.49 +- 0.08&19.09 +- 0.32&6.55 +- 0.09&11.61 +- 0.08\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "ID & 11.87 +- 0.24&5.46 +- 0.21&18.55 +- 0.82&7.02 +- 0.12&11.77 +- 0.16\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\crossmark &\n",
      "OD & 8.72 +- 0.45&5.32 +- 0.07&15.42 +- 0.68&6.52 +- 0.11&13.04 +- 0.04\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "OD & 8.56 +- 0.30&5.47 +- 0.18&15.46 +- 0.35&7.11 +- 0.26&13.43 +- 0.13\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{2}{*}{\\rotatebox{90}{GLU}} & \\crossmark & \n",
      "ID & 12.55 +- 0.67&7.12 +- 0.37&19.40 +- 1.27&6.69 +- 0.11&12.09 +- 0.18\n",
      "\\\\\n",
      "& \\crossmark & \n",
      "OD & 14.82 +- 1.09&6.03 +- 0.21&20.70 +- 1.05&7.04 +- 0.18&13.65 +- 0.32\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{2}{*}{\\rotatebox{90}{LAT}} & \\crossmark & \n",
      "ID & 12.32 +- 2.15&5.37 +- 0.12&17.88 +- 1.59&6.11 +- 0.08&11.45 +- 0.19\n",
      "\\\\\n",
      "& \\crossmark & \n",
      "OD & 13.05 +- 1.87&4.84 +- 0.13&15.12 +- 1.92&6.61 +- 0.10&12.72 +- 0.25\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{4}{*}{\\rotatebox{90}{NHI}} & \\crossmark & \n",
      "ID & 12.07 +- 0.31&5.04 +- 0.02&14.79 +- 0.05&6.57 +- 0.02&11.21 +- 0.07\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "ID & 14.64 +- 0.81&8.03 +- 0.36&27.97 +- 0.57&7.10 +- 0.07&11.31 +- 0.02\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\crossmark &\n",
      "OD & 12.77 +- 0.35&4.83 +- 0.02&15.59 +- 0.04&6.62 +- 0.02&12.24 +- 0.07\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "OD & 14.01 +- 1.03&6.65 +- 0.27&33.52 +- 0.83&7.53 +- 0.11&12.12 +- 0.03\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{4}{*}{\\rotatebox{90}{TFT}} & \\crossmark & \n",
      "ID & 11.07 +- 2.85&4.54 +- 0.12&15.49 +- 1.23&6.61 +- 0.20&11.76 +- 0.43\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "ID & 12.43 +- 0.89&5.27 +- 0.00&15.51 +- 0.83&7.06 +- 0.19&12.30 +- 0.17\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\crossmark &\n",
      "OD & 10.23 +- 1.91&4.47 +- 0.15&14.53 +- 1.11&6.76 +- 0.21&12.50 +- 0.45\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "OD & 11.17 +- 0.59&4.68 +- 0.00&14.43 +- 0.63&7.44 +- 0.26&12.67 +- 0.17\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{4}{*}{\\rotatebox{90}{TRA}} & \\crossmark & \n",
      "ID & 13.20 +- 0.00&5.65 +- 0.00&14.04 +- 0.00&6.78 +- 0.00&11.22 +- 0.00\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "ID & 11.27 +- 0.00&7.77 +- 0.00&24.40 +- 0.00&6.42 +- 0.00&12.61 +- 0.00\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\crossmark &\n",
      "OD & 12.28 +- 0.00&5.24 +- 0.00&12.98 +- 0.00&7.07 +- 0.00&11.91 +- 0.00\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "OD & 12.13 +- 0.00&6.59 +- 0.00&28.21 +- 0.00&6.29 +- 0.00&13.58 +- 0.00\n",
      "\\\\\n",
      "\\midrule\n"
     ]
    }
   ],
   "source": [
    "# MAE table\n",
    "#\n",
    "# Prints the LaTeX body rows of the MAE table: for every model, one\n",
    "# 'mean +- std' cell per row of the arrays returned by avg_results\n",
    "# (presumably one row per dataset, in `datasets` order -- TODO confirm).\n",
    "# MAE is read from column 1 of the first dictionary in each returned pair.\n",
    "\n",
    "metric_col = 1\n",
    "\n",
    "for model in models:\n",
    "    has_covs = model not in ['arima', 'gluformer', 'latentode'] # these models have no covariate runs\n",
    "    model_names = [f'../output/{model}_{dataset}.txt' for dataset in datasets]\n",
    "    model_names_with_covs = None\n",
    "    if has_covs:\n",
    "        model_names_with_covs = [f'../output/{model}_covariates_{dataset}.txt' for dataset in datasets]\n",
    "    (dict_errors, _), (dict_errors_std, _) = avg_results(model_names, model_names_with_covs)\n",
    "\n",
    "    # Each entry: (LaTeX text printed before the row, row label, split key, covariate key).\n",
    "    # Raw strings keep the LaTeX backslashes literal: '\\c' and '\\m' are invalid escape\n",
    "    # sequences in ordinary string literals and make Python emit a warning.\n",
    "    tag = model[:3].upper()\n",
    "    if has_covs:\n",
    "        layout = [\n",
    "            (r'\\multirow{4}{*}{\\rotatebox{90}{' + tag + r'}} & \\crossmark & ', 'ID', 'id', 'no_covs'),\n",
    "            (r'& \\checkmark & ', 'ID', 'id', 'covs'),\n",
    "            (r'& \\crossmark &', 'OD', 'ood', 'no_covs'),  # no trailing space, as in the existing output\n",
    "            (r'& \\checkmark & ', 'OD', 'ood', 'covs'),\n",
    "        ]\n",
    "    else:\n",
    "        layout = [\n",
    "            (r'\\multirow{2}{*}{\\rotatebox{90}{' + tag + r'}} & \\crossmark & ', 'ID', 'id', 'no_covs'),\n",
    "            (r'& \\crossmark & ', 'OD', 'ood', 'no_covs'),\n",
    "        ]\n",
    "\n",
    "    for position, (lead, label, split, cov) in enumerate(layout):\n",
    "        if has_covs and position == 2:\n",
    "            # partial rules between the ID rows and the OD rows of a four-row model\n",
    "            print(r'\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}')\n",
    "        print(lead)\n",
    "        means = dict_errors[split][cov][:, metric_col].tolist()\n",
    "        stds = dict_errors_std[split][cov][:, metric_col].tolist()\n",
    "        print(label + ' & ' + '&'.join(f'{x:.2f} +- {y:.2f}' for x, y in zip(means, stds)))\n",
    "        print(r'\\\\')\n",
    "    print(r'\\midrule')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\multirow{2}{*}{\\rotatebox{90}{ARI}} & \\crossmark & \n",
      "ID & nan +- nan&nan +- nan&nan +- nan&nan +- nan&nan +- nan\n",
      "\\\\\n",
      "& \\crossmark & \n",
      "OD & nan +- nan&nan +- nan&nan +- nan&nan +- nan&nan +- nan\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{4}{*}{\\rotatebox{90}{LIN}} & \\crossmark & \n",
      "ID & -9.89 +- 0.01&-9.19 +- 0.01&-10.10 +- 0.16&-9.56 +- 0.03&-10.14 +- 0.00\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "ID & -9.87 +- 0.03&-9.17 +- 0.01&-10.15 +- 0.17&-10.30 +- 1.47&-10.12 +- 0.00\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\crossmark &\n",
      "OD & -9.95 +- 0.14&-9.16 +- 0.06&-10.11 +- 0.28&-9.53 +- 0.17&-10.22 +- 0.03\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "OD & -10.24 +- 0.30&-9.16 +- 0.06&-12.08 +- 3.94&-10.42 +- 1.49&-11.13 +- 1.83\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{4}{*}{\\rotatebox{90}{XGB}} & \\crossmark & \n",
      "ID & -9.94 +- 0.02&-9.42 +- 0.01&-10.55 +- 0.02&-9.68 +- 0.01&-10.20 +- 0.00\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "ID & -10.06 +- 0.06&-9.40 +- 0.02&-10.54 +- 0.01&-9.70 +- 0.00&-10.21 +- 0.00\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\crossmark &\n",
      "OD & -10.03 +- 0.01&-9.36 +- 0.01&-10.22 +- 0.02&-9.56 +- 0.01&-10.28 +- 0.00\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "OD & -10.03 +- 0.01&-9.38 +- 0.02&-10.20 +- 0.01&-9.53 +- 0.01&-10.31 +- 0.00\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{2}{*}{\\rotatebox{90}{GLU}} & \\crossmark & \n",
      "ID & -2.11 +- 0.24&-1.07 +- 0.19&-2.15 +- 0.22&-1.56 +- 0.10&-2.50 +- 0.05\n",
      "\\\\\n",
      "& \\crossmark & \n",
      "OD & -1.96 +- 0.27&-1.61 +- 0.12&-1.17 +- 1.69&-1.44 +- 0.11&-2.41 +- 0.05\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{2}{*}{\\rotatebox{90}{LAT}} & \\crossmark & \n",
      "ID & -25.29 +- 5.68&-10.47 +- 0.16&-52.18 +- 9.54&-20.24 +- 0.19&-26.15 +- 0.07\n",
      "\\\\\n",
      "& \\crossmark & \n",
      "OD & -28.75 +- 6.38&-8.80 +- 0.12&-30.19 +- 5.20&-18.19 +- 0.16&-30.08 +- 0.14\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{4}{*}{\\rotatebox{90}{NHI}} & \\crossmark & \n",
      "ID & -10.01 +- 0.01&-9.32 +- 0.00&-10.37 +- 0.00&-9.62 +- 0.00&-10.13 +- 0.00\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "ID & -10.37 +- 0.05&-9.48 +- 0.02&-10.80 +- 0.01&-9.63 +- 0.01&-10.13 +- 0.00\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\crossmark &\n",
      "OD & -10.08 +- 0.01&-9.26 +- 0.00&-10.18 +- 0.00&-9.49 +- 0.00&-10.20 +- 0.00\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "OD & -10.21 +- 0.04&-9.36 +- 0.01&-11.10 +- 0.04&-9.58 +- 0.01&-10.19 +- 0.00\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{4}{*}{\\rotatebox{90}{TFT}} & \\crossmark & \n",
      "ID & 0.00 +- 0.00&0.00 +- 0.00&0.00 +- 0.00&0.00 +- 0.00&0.00 +- 0.00\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "ID & 0.00 +- 0.00&0.00 +- 0.00&0.00 +- 0.00&0.00 +- 0.00&0.00 +- 0.00\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\crossmark &\n",
      "OD & 0.00 +- 0.00&0.00 +- 0.00&0.00 +- 0.00&0.00 +- 0.00&0.00 +- 0.00\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "OD & 0.00 +- 0.00&0.00 +- 0.00&0.00 +- 0.00&0.00 +- 0.00&0.00 +- 0.00\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{4}{*}{\\rotatebox{90}{TRA}} & \\crossmark & \n",
      "ID & -9.99 +- 0.00&-9.37 +- 0.00&-10.36 +- 0.00&-9.60 +- 0.00&-10.12 +- 0.00\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "ID & -10.11 +- 0.00&-9.45 +- 0.00&-10.68 +- 0.00&-9.60 +- 0.00&-10.15 +- 0.00\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\crossmark &\n",
      "OD & -9.98 +- 0.00&-9.30 +- 0.00&-10.09 +- 0.00&-9.47 +- 0.00&-10.17 +- 0.00\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "OD & -10.02 +- 0.00&-9.36 +- 0.00&-10.63 +- 0.00&-9.49 +- 0.00&-10.20 +- 0.00\n",
      "\\\\\n",
      "\\midrule\n"
     ]
    }
   ],
   "source": [
    "# likelihood table\n",
    "#\n",
    "# Prints the LaTeX body rows of the likelihood table: for every model, one\n",
    "# 'mean +- std' cell per row of the arrays returned by avg_results\n",
    "# (presumably one row per dataset, in `datasets` order -- TODO confirm).\n",
    "# The likelihood is read from column 0 of the second dictionary in each returned pair.\n",
    "\n",
    "metric_col = 0\n",
    "\n",
    "for model in models:\n",
    "    has_covs = model not in ['arima', 'gluformer', 'latentode'] # these models have no covariate runs\n",
    "    model_names = [f'../output/{model}_{dataset}.txt' for dataset in datasets]\n",
    "    model_names_with_covs = None\n",
    "    if has_covs:\n",
    "        model_names_with_covs = [f'../output/{model}_covariates_{dataset}.txt' for dataset in datasets]\n",
    "    (_, dict_errors), (_, dict_errors_std) = avg_results(model_names, model_names_with_covs)\n",
    "\n",
    "    # Each entry: (LaTeX text printed before the row, row label, split key, covariate key).\n",
    "    # Raw strings keep the LaTeX backslashes literal: '\\c' and '\\m' are invalid escape\n",
    "    # sequences in ordinary string literals and make Python emit a warning.\n",
    "    tag = model[:3].upper()\n",
    "    if has_covs:\n",
    "        layout = [\n",
    "            (r'\\multirow{4}{*}{\\rotatebox{90}{' + tag + r'}} & \\crossmark & ', 'ID', 'id', 'no_covs'),\n",
    "            (r'& \\checkmark & ', 'ID', 'id', 'covs'),\n",
    "            (r'& \\crossmark &', 'OD', 'ood', 'no_covs'),  # no trailing space, as in the existing output\n",
    "            (r'& \\checkmark & ', 'OD', 'ood', 'covs'),\n",
    "        ]\n",
    "    else:\n",
    "        layout = [\n",
    "            (r'\\multirow{2}{*}{\\rotatebox{90}{' + tag + r'}} & \\crossmark & ', 'ID', 'id', 'no_covs'),\n",
    "            (r'& \\crossmark & ', 'OD', 'ood', 'no_covs'),\n",
    "        ]\n",
    "\n",
    "    for position, (lead, label, split, cov) in enumerate(layout):\n",
    "        if has_covs and position == 2:\n",
    "            # partial rules between the ID rows and the OD rows of a four-row model\n",
    "            print(r'\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}')\n",
    "        print(lead)\n",
    "        means = dict_errors[split][cov][:, metric_col].tolist()\n",
    "        stds = dict_errors_std[split][cov][:, metric_col].tolist()\n",
    "        print(label + ' & ' + '&'.join(f'{x:.2f} +- {y:.2f}' for x, y in zip(means, stds)))\n",
    "        print(r'\\\\')\n",
    "    print(r'\\midrule')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\multirow{2}{*}{\\rotatebox{90}{ARI}} & \\crossmark & \n",
      "ID & nan +- nan&nan +- nan&nan +- nan&nan +- nan&nan +- nan\n",
      "\\\\\n",
      "& \\crossmark & \n",
      "OD & nan +- nan&nan +- nan&nan +- nan&nan +- nan&nan +- nan\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{4}{*}{\\rotatebox{90}{LIN}} & \\crossmark & \n",
      "ID & 0.12 +- 0.01&0.15 +- 0.00&0.18 +- 0.02&0.10 +- 0.00&0.11 +- 0.00\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "ID & 0.13 +- 0.02&0.19 +- 0.01&0.21 +- 0.02&0.19 +- 0.20&0.11 +- 0.00\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\crossmark &\n",
      "OD & 0.15 +- 0.04&0.15 +- 0.01&0.17 +- 0.02&0.10 +- 0.02&0.11 +- 0.00\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "OD & 0.55 +- 0.39&0.17 +- 0.02&0.48 +- 0.58&0.23 +- 0.18&0.21 +- 0.20\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{4}{*}{\\rotatebox{90}{XGB}} & \\crossmark & \n",
      "ID & 0.07 +- 0.01&0.10 +- 0.01&0.07 +- 0.01&0.09 +- 0.00&0.11 +- 0.00\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "ID & 0.07 +- 0.01&0.09 +- 0.01&0.06 +- 0.01&0.09 +- 0.00&0.10 +- 0.00\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\crossmark &\n",
      "OD & 0.11 +- 0.01&0.09 +- 0.01&0.07 +- 0.01&0.08 +- 0.00&0.11 +- 0.00\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "OD & 0.11 +- 0.01&0.08 +- 0.01&0.07 +- 0.01&0.10 +- 0.01&0.10 +- 0.00\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{2}{*}{\\rotatebox{90}{GLU}} & \\crossmark & \n",
      "ID & 0.05 +- 0.01&0.14 +- 0.04&0.06 +- 0.02&0.05 +- 0.01&0.08 +- 0.01\n",
      "\\\\\n",
      "& \\crossmark & \n",
      "OD & 0.11 +- 0.04&0.10 +- 0.02&0.12 +- 0.06&0.06 +- 0.01&0.09 +- 0.01\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{2}{*}{\\rotatebox{90}{LAT}} & \\crossmark & \n",
      "ID & 0.36 +- 0.06&0.25 +- 0.03&0.42 +- 0.04&0.30 +- 0.02&0.33 +- 0.03\n",
      "\\\\\n",
      "& \\crossmark & \n",
      "OD & 0.38 +- 0.06&0.24 +- 0.04&0.44 +- 0.07&0.36 +- 0.03&0.40 +- 0.04\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{4}{*}{\\rotatebox{90}{NHI}} & \\crossmark & \n",
      "ID & 0.12 +- 0.01&0.11 +- 0.00&0.10 +- 0.00&0.09 +- 0.00&0.11 +- 0.00\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "ID & 0.07 +- 0.01&0.21 +- 0.04&0.08 +- 0.01&0.07 +- 0.00&0.11 +- 0.00\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\crossmark &\n",
      "OD & 0.10 +- 0.01&0.11 +- 0.00&0.12 +- 0.00&0.08 +- 0.00&0.12 +- 0.00\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "OD & 0.06 +- 0.02&0.14 +- 0.04&0.20 +- 0.03&0.06 +- 0.01&0.11 +- 0.00\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{4}{*}{\\rotatebox{90}{TFT}} & \\crossmark & \n",
      "ID & 0.16 +- 0.09&0.07 +- 0.03&0.23 +- 0.12&0.07 +- 0.03&0.07 +- 0.04\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "ID & 0.30 +- 0.13&0.16 +- 0.00&0.25 +- 0.12&0.08 +- 0.02&0.06 +- 0.02\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\crossmark &\n",
      "OD & 0.15 +- 0.10&0.09 +- 0.04&0.26 +- 0.18&0.08 +- 0.03&0.08 +- 0.05\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "OD & 0.23 +- 0.13&0.09 +- 0.00&0.35 +- 0.13&0.08 +- 0.03&0.05 +- 0.02\n",
      "\\\\\n",
      "\\midrule\n",
      "\\multirow{4}{*}{\\rotatebox{90}{TRA}} & \\crossmark & \n",
      "ID & 0.23 +- 0.00&0.21 +- 0.00&0.12 +- 0.00&0.13 +- 0.00&0.11 +- 0.00\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "ID & 0.21 +- 0.00&0.31 +- 0.00&0.18 +- 0.00&0.10 +- 0.00&0.11 +- 0.00\n",
      "\\\\\n",
      "\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}\n",
      "& \\crossmark &\n",
      "OD & 0.19 +- 0.00&0.22 +- 0.00&0.14 +- 0.00&0.15 +- 0.00&0.12 +- 0.00\n",
      "\\\\\n",
      "& \\checkmark & \n",
      "OD & 0.11 +- 0.00&0.22 +- 0.00&0.25 +- 0.00&0.08 +- 0.00&0.12 +- 0.00\n",
      "\\\\\n",
      "\\midrule\n"
     ]
    }
   ],
   "source": [
    "# calibration table\n",
    "#\n",
    "# Prints the LaTeX body rows of the calibration table: for every model, one\n",
    "# 'mean +- std' cell per row of the arrays returned by avg_results\n",
    "# (presumably one row per dataset, in `datasets` order -- TODO confirm).\n",
    "# Calibration is read from column 1 of the second dictionary in each returned pair.\n",
    "\n",
    "metric_col = 1\n",
    "\n",
    "for model in models:\n",
    "    has_covs = model not in ['arima', 'gluformer', 'latentode'] # these models have no covariate runs\n",
    "    model_names = [f'../output/{model}_{dataset}.txt' for dataset in datasets]\n",
    "    model_names_with_covs = None\n",
    "    if has_covs:\n",
    "        model_names_with_covs = [f'../output/{model}_covariates_{dataset}.txt' for dataset in datasets]\n",
    "    (_, dict_errors), (_, dict_errors_std) = avg_results(model_names, model_names_with_covs)\n",
    "\n",
    "    # Each entry: (LaTeX text printed before the row, row label, split key, covariate key).\n",
    "    # Raw strings keep the LaTeX backslashes literal: '\\c' and '\\m' are invalid escape\n",
    "    # sequences in ordinary string literals and make Python emit a warning.\n",
    "    tag = model[:3].upper()\n",
    "    if has_covs:\n",
    "        layout = [\n",
    "            (r'\\multirow{4}{*}{\\rotatebox{90}{' + tag + r'}} & \\crossmark & ', 'ID', 'id', 'no_covs'),\n",
    "            (r'& \\checkmark & ', 'ID', 'id', 'covs'),\n",
    "            (r'& \\crossmark &', 'OD', 'ood', 'no_covs'),  # no trailing space, as in the existing output\n",
    "            (r'& \\checkmark & ', 'OD', 'ood', 'covs'),\n",
    "        ]\n",
    "    else:\n",
    "        layout = [\n",
    "            (r'\\multirow{2}{*}{\\rotatebox{90}{' + tag + r'}} & \\crossmark & ', 'ID', 'id', 'no_covs'),\n",
    "            (r'& \\crossmark & ', 'OD', 'ood', 'no_covs'),\n",
    "        ]\n",
    "\n",
    "    for position, (lead, label, split, cov) in enumerate(layout):\n",
    "        if has_covs and position == 2:\n",
    "            # partial rules between the ID rows and the OD rows of a four-row model\n",
    "            print(r'\\cmidrule(lr){4-5} \\cmidrule(lr){6-7} \\cmidrule(lr){8-9} \\cmidrule(lr){10-11} \\cmidrule(lr){12-13}')\n",
    "        print(lead)\n",
    "        means = dict_errors[split][cov][:, metric_col].tolist()\n",
    "        stds = dict_errors_std[split][cov][:, metric_col].tolist()\n",
    "        print(label + ' & ' + '&'.join(f'{x:.2f} +- {y:.2f}' for x, y in zip(means, stds)))\n",
    "        print(r'\\\\')\n",
    "    print(r'\\midrule')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "glunet",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "b9af0babfa4fcc32151d0f9cd96f26ee8eefb724c47cfe9b27c84c1db30f6822"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
