{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make the project root importable (needed for `experiments.parse_results` below).\n",
    "import sys\n",
    "\n",
    "sys.path.append('../')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from experiments.parse_results import parse_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fix_index(all_results):\n",
    "    \"\"\"Return a copy of `all_results` indexed by integer (depth, step), sorted ascending.\n",
    "\n",
    "    Works on a copy (`reset_index()` without `inplace`), so the caller's frame is left\n",
    "    untouched. Expects `depth` and `step` among the index levels, with values castable to int.\n",
    "    \"\"\"\n",
    "    frame = all_results.reset_index()\n",
    "    # The levels may arrive as strings; cast so sorting is numeric (2 < 16, not '16' < '2').\n",
    "    frame['depth'] = frame['depth'].astype(int)\n",
    "    frame['step'] = frame['step'].astype(int)\n",
    "    return frame.set_index(['depth', 'step']).sort_index(level=['depth', 'step'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_to_mean_std_string(df, col='loss.test', round_num=2):\n",
    "    \"\"\"Aggregate runs per (depth, step) and build LaTeX 'mean +/- std' strings for `col`.\n",
    "\n",
    "    Returns:\n",
    "        mean: per-(depth, step) mean of every column, rounded to `round_num` places.\n",
    "        string_it: Series named `col` of LaTeX 'mean +/- std' strings, both numbers printed\n",
    "            with exactly `round_num` decimals (e.g. '2.80' rather than '2.8').\n",
    "    \"\"\"\n",
    "    grouped = df.groupby(['depth', 'step'], as_index=True)\n",
    "    std = grouped.std().round(round_num)\n",
    "    mean = grouped.mean().round(round_num)\n",
    "\n",
    "    def _fmt(value):\n",
    "        # Fixed decimals; str() drops trailing zeros and gives ragged table entries.\n",
    "        return f'{value:.{round_num}f}'\n",
    "\n",
    "    # Raw string: a plain ' $\\pm$ ' relies on the invalid escape '\\p' (DeprecationWarning).\n",
    "    string_it = mean[col].map(_fmt) + r' $\\pm$ ' + std[col].map(_fmt)\n",
    "    return mean, string_it"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_multiindex_step(df, steps=(1, 8, 32, 128), level=1):\n",
    "    \"\"\"Return the rows of `df` whose index value at `level` (the step level by default) is in `steps`.\n",
    "\n",
    "    `steps` may be any iterable of step values; the default is an immutable tuple.\n",
    "    \"\"\"\n",
    "    mask = df.index.get_level_values(level).isin(list(steps))\n",
    "    return df.iloc[mask]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# EigenWorms Table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get data: EigenWorms results, with test accuracy scaled by 100 (presumably fraction -> percent)\n",
    "ew = parse_results(\n",
    "    'UEA', 'EigenWorms', 'main_adjoint',\n",
    "    sort_key='test', average_over=None, print_frame=False, pretty_std=False,\n",
    ")\n",
    "ew['acc.test'] = ew['acc.test'].mul(100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy as 'mean +/- std' at 1 d.p.; elapsed_time divided by 60**2 (presumably seconds -> hours)\n",
    "ew_mean, ew_string = convert_to_mean_std_string(ew, col='acc.test', round_num=1)\n",
    "ew_mean['elapsed_time'] = ew_mean['elapsed_time'].div(60 ** 2).round(1)\n",
    "ew_mean['memory_usage'] = ew_mean['memory_usage'].round(1)\n",
    "\n",
    "ew_string = fix_index(\n",
    "    pd.concat([ew_string, ew_mean[['memory_usage', 'elapsed_time']]], axis=1)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Full EigenWorms table; escape=False keeps the LaTeX math in the accuracy strings intact\n",
    "ew_all = ew_string[['acc.test', 'elapsed_time', 'memory_usage']]\n",
    "with open('tables/eigenworms_full.tex', 'w') as tex_file:\n",
    "    tex_file.write(ew_all.to_latex(escape=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compact EigenWorms table: only steps 1, 8, 32 and 128\n",
    "ew_table = extract_multiindex_step(ew_all)\n",
    "with open('tables/eigenworms.tex', 'w') as tex_file:\n",
    "    tex_file.write(ew_table.to_latex(escape=False))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Creating a BIDMC Results table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get data\n",
    "def get_frame(name):\n",
    "    \"\"\"Load the 'main_adjoint' TSR results for dataset `name` with `average_over=None`.\"\"\"\n",
    "    return parse_results('TSR', name, 'main_adjoint', sort_key='test',\n",
    "                         average_over=None, print_frame=False, pretty_std=False)\n",
    "\n",
    "\n",
    "rr, hr, sp = (get_frame(dataset) for dataset in ('BIDMC32RR', 'BIDMC32HR', 'BIDMC32SpO2'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Summary table: test loss for each BIDMC dataset, with memory usage and time taken averaged over the three datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-(depth, step) means of every column plus 'mean +/- std' loss.test strings, per dataset\n",
    "(rr_mean, rr_string), (hr_mean, hr_string), (sp_mean, sp_string) = [\n",
    "    convert_to_mean_std_string(frame) for frame in (rr, hr, sp)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean mem usage / time per (depth, step), averaged over the three BIDMC datasets.\n",
    "def get_mean(col):\n",
    "    \"\"\"Row-wise mean (NaN-skipping) of `col` across rr_mean, hr_mean and sp_mean, named `col`.\"\"\"\n",
    "    per_dataset = pd.concat((rr_mean[col], hr_mean[col], sp_mean[col]), axis=1)\n",
    "    # mean(axis=1) drops the name; unnamed, pd.concat later labels these columns 0 and 1.\n",
    "    return per_dataset.mean(axis=1).rename(col)\n",
    "\n",
    "\n",
    "# NOTE(review): astype(int) truncates and raises if all three datasets are NaN for a row;\n",
    "# elapsed_time is also not divided by 60**2 here, unlike the other tables - confirm intended.\n",
    "mean_mem = get_mean('memory_usage').astype(int)\n",
    "mean_time = get_mean('elapsed_time').astype(int)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Combine the three loss columns with the averaged memory/time columns, indexed by int (depth, step)\n",
    "all_results = fix_index(pd.concat([rr_string, hr_string, sp_string, mean_mem, mean_time], axis=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# escape=False keeps the LaTeX math in the loss strings intact\n",
    "with open('tables/mean_bidmctex.tex', 'w') as tex_file:\n",
    "    tex_file.write(all_results.to_latex(escape=False))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Full results: test loss and elapsed time for each BIDMC dataset, with memory usage averaged over the three datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# All results: per-dataset memory usage and elapsed time (divided by 60**2, presumably s -> h)\n",
    "bidmc_means = (rr_mean, hr_mean, sp_mean)\n",
    "memory = pd.concat([x['memory_usage'] for x in bidmc_means], axis=1).round(1)\n",
    "elapsed_hours = (pd.concat([x['elapsed_time'] for x in bidmc_means], axis=1) / (60 ** 2)).round(1)\n",
    "\n",
    "# Combine, then reuse fix_index instead of repeating its reset/cast/set/sort steps inline\n",
    "all_results_ = fix_index(pd.concat((rr_string, hr_string, sp_string, memory, elapsed_hours), axis=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace the three per-dataset memory_usage columns with their row-wise mean.\n",
    "# Not re-runnable on its own (memory_usage is gone afterwards): re-run the cell above first.\n",
    "all_results_ = all_results_.assign(\n",
    "    mean_memory_usage=all_results_['memory_usage'].mean(axis=1).round(1)\n",
    ").drop(columns='memory_usage')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop steps 20 and 50 from the full table\n",
    "all_results_ = all_results_.iloc[~all_results_.index.get_level_values(1).isin([20, 50])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compact BIDMC table: reuse the helper (default steps 1, 8, 32, 128) instead of duplicating it\n",
    "output = extract_multiindex_step(all_results_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save under tables/ like every other table produced by this notebook\n",
    "with open('tables/full_bidmctex.tex', 'w') as file:\n",
    "    file.write(output.to_latex(escape=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Full BIDMC table (steps 20 and 50 were dropped above)\n",
    "with open('tables/bidmc_full.tex', 'w') as tex_file:\n",
    "    tex_file.write(all_results_.to_latex(escape=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>loss.test</th>\n",
       "      <th>loss.test</th>\n",
       "      <th>loss.test</th>\n",
       "      <th>elapsed_time</th>\n",
       "      <th>elapsed_time</th>\n",
       "      <th>elapsed_time</th>\n",
       "      <th>mean_memory_usage</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>depth</th>\n",
       "      <th>step</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"12\" valign=\"top\">1</th>\n",
       "      <th>1</th>\n",
       "      <td>2.79 $\\pm$ 0.04</td>\n",
       "      <td>9.82 $\\pm$ 0.34</td>\n",
       "      <td>2.83 $\\pm$ 0.27</td>\n",
       "      <td>23.8</td>\n",
       "      <td>22.1</td>\n",
       "      <td>28.1</td>\n",
       "      <td>56.5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2.87 $\\pm$ 0.03</td>\n",
       "      <td>11.69 $\\pm$ 0.38</td>\n",
       "      <td>3.36 $\\pm$ 0.2</td>\n",
       "      <td>19.3</td>\n",
       "      <td>9.6</td>\n",
       "      <td>8.8</td>\n",
       "      <td>32.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2.92 $\\pm$ 0.08</td>\n",
       "      <td>11.15 $\\pm$ 0.49</td>\n",
       "      <td>3.69 $\\pm$ 0.06</td>\n",
       "      <td>5.3</td>\n",
       "      <td>5.7</td>\n",
       "      <td>3.2</td>\n",
       "      <td>20.2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>2.8 $\\pm$ 0.06</td>\n",
       "      <td>10.72 $\\pm$ 0.24</td>\n",
       "      <td>3.43 $\\pm$ 0.17</td>\n",
       "      <td>3.0</td>\n",
       "      <td>2.6</td>\n",
       "      <td>4.8</td>\n",
       "      <td>14.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>2.22 $\\pm$ 0.07</td>\n",
       "      <td>7.98 $\\pm$ 0.61</td>\n",
       "      <td>2.9 $\\pm$ 0.11</td>\n",
       "      <td>1.7</td>\n",
       "      <td>1.4</td>\n",
       "      <td>1.8</td>\n",
       "      <td>11.8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32</th>\n",
       "      <td>2.53 $\\pm$ 0.23</td>\n",
       "      <td>12.23 $\\pm$ 0.43</td>\n",
       "      <td>2.68 $\\pm$ 0.12</td>\n",
       "      <td>1.9</td>\n",
       "      <td>0.9</td>\n",
       "      <td>2.2</td>\n",
       "      <td>9.8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>64</th>\n",
       "      <td>2.63 $\\pm$ 0.11</td>\n",
       "      <td>12.02 $\\pm$ 0.09</td>\n",
       "      <td>2.88 $\\pm$ 0.06</td>\n",
       "      <td>0.2</td>\n",
       "      <td>0.3</td>\n",
       "      <td>0.4</td>\n",
       "      <td>9.1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>128</th>\n",
       "      <td>2.64 $\\pm$ 0.18</td>\n",
       "      <td>11.98 $\\pm$ 0.37</td>\n",
       "      <td>2.86 $\\pm$ 0.04</td>\n",
       "      <td>0.2</td>\n",
       "      <td>0.2</td>\n",
       "      <td>0.3</td>\n",
       "      <td>8.7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>256</th>\n",
       "      <td>2.53 $\\pm$ 0.04</td>\n",
       "      <td>12.29 $\\pm$ 0.1</td>\n",
       "      <td>3.08 $\\pm$ 0.1</td>\n",
       "      <td>0.1</td>\n",
       "      <td>0.1</td>\n",
       "      <td>0.1</td>\n",
       "      <td>8.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>512</th>\n",
       "      <td>2.53 $\\pm$ 0.03</td>\n",
       "      <td>12.22 $\\pm$ 0.11</td>\n",
       "      <td>2.98 $\\pm$ 0.04</td>\n",
       "      <td>0.1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.1</td>\n",
       "      <td>8.4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1024</th>\n",
       "      <td>2.67 $\\pm$ 0.12</td>\n",
       "      <td>11.55 $\\pm$ 0.03</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.1</td>\n",
       "      <td>0.1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>7.8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2048</th>\n",
       "      <td>2.48 $\\pm$ 0.03</td>\n",
       "      <td>12.03 $\\pm$ 0.2</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>7.8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"11\" valign=\"top\">2</th>\n",
       "      <th>2</th>\n",
       "      <td>2.91 $\\pm$ 0.1</td>\n",
       "      <td>11.11 $\\pm$ 0.23</td>\n",
       "      <td>3.89 $\\pm$ 0.44</td>\n",
       "      <td>12.7</td>\n",
       "      <td>9.3</td>\n",
       "      <td>8.2</td>\n",
       "      <td>58.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2.92 $\\pm$ 0.04</td>\n",
       "      <td>11.14 $\\pm$ 0.2</td>\n",
       "      <td>4.23 $\\pm$ 0.57</td>\n",
       "      <td>18.1</td>\n",
       "      <td>5.0</td>\n",
       "      <td>3.4</td>\n",
       "      <td>34.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>2.63 $\\pm$ 0.12</td>\n",
       "      <td>8.63 $\\pm$ 0.24</td>\n",
       "      <td>2.88 $\\pm$ 0.15</td>\n",
       "      <td>2.1</td>\n",
       "      <td>3.4</td>\n",
       "      <td>3.3</td>\n",
       "      <td>21.8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>1.8 $\\pm$ 0.07</td>\n",
       "      <td>5.73 $\\pm$ 0.45</td>\n",
       "      <td>1.98 $\\pm$ 0.21</td>\n",
       "      <td>2.2</td>\n",
       "      <td>1.4</td>\n",
       "      <td>2.5</td>\n",
       "      <td>16.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32</th>\n",
       "      <td>1.9 $\\pm$ 0.02</td>\n",
       "      <td>7.9 $\\pm$ 1.0</td>\n",
       "      <td>1.69 $\\pm$ 0.2</td>\n",
       "      <td>1.2</td>\n",
       "      <td>1.1</td>\n",
       "      <td>2.0</td>\n",
       "      <td>13.1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>64</th>\n",
       "      <td>1.89 $\\pm$ 0.04</td>\n",
       "      <td>5.54 $\\pm$ 0.45</td>\n",
       "      <td>2.04 $\\pm$ 0.07</td>\n",
       "      <td>0.3</td>\n",
       "      <td>0.3</td>\n",
       "      <td>1.7</td>\n",
       "      <td>11.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>128</th>\n",
       "      <td>1.86 $\\pm$ 0.03</td>\n",
       "      <td>6.77 $\\pm$ 0.42</td>\n",
       "      <td>1.95 $\\pm$ 0.18</td>\n",
       "      <td>0.3</td>\n",
       "      <td>0.4</td>\n",
       "      <td>0.7</td>\n",
       "      <td>10.9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>256</th>\n",
       "      <td>1.86 $\\pm$ 0.09</td>\n",
       "      <td>5.64 $\\pm$ 0.19</td>\n",
       "      <td>2.1 $\\pm$ 0.19</td>\n",
       "      <td>0.1</td>\n",
       "      <td>0.1</td>\n",
       "      <td>0.5</td>\n",
       "      <td>10.5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>512</th>\n",
       "      <td>1.81 $\\pm$ 0.02</td>\n",
       "      <td>5.05 $\\pm$ 0.23</td>\n",
       "      <td>2.17 $\\pm$ 0.18</td>\n",
       "      <td>0.1</td>\n",
       "      <td>0.2</td>\n",
       "      <td>0.4</td>\n",
       "      <td>10.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1024</th>\n",
       "      <td>1.93 $\\pm$ 0.11</td>\n",
       "      <td>6.0 $\\pm$ 0.19</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.1</td>\n",
       "      <td>0.1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>9.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2048</th>\n",
       "      <td>2.03 $\\pm$ 0.03</td>\n",
       "      <td>7.7 $\\pm$ 1.46</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.1</td>\n",
       "      <td>0.1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>9.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"11\" valign=\"top\">3</th>\n",
       "      <th>2</th>\n",
       "      <td>2.82 $\\pm$ 0.08</td>\n",
       "      <td>11.01 $\\pm$ 0.28</td>\n",
       "      <td>4.1 $\\pm$ 0.72</td>\n",
       "      <td>8.8</td>\n",
       "      <td>9.4</td>\n",
       "      <td>6.9</td>\n",
       "      <td>125.2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2.97 $\\pm$ 0.23</td>\n",
       "      <td>10.13 $\\pm$ 0.62</td>\n",
       "      <td>3.56 $\\pm$ 0.44</td>\n",
       "      <td>3.2</td>\n",
       "      <td>4.1</td>\n",
       "      <td>2.6</td>\n",
       "      <td>71.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>2.42 $\\pm$ 0.19</td>\n",
       "      <td>7.67 $\\pm$ 0.4</td>\n",
       "      <td>2.55 $\\pm$ 0.13</td>\n",
       "      <td>2.9</td>\n",
       "      <td>3.2</td>\n",
       "      <td>3.1</td>\n",
       "      <td>43.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>1.74 $\\pm$ 0.05</td>\n",
       "      <td>4.11 $\\pm$ 0.61</td>\n",
       "      <td>1.4 $\\pm$ 0.06</td>\n",
       "      <td>1.4</td>\n",
       "      <td>1.4</td>\n",
       "      <td>6.5</td>\n",
       "      <td>29.1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32</th>\n",
       "      <td>1.67 $\\pm$ 0.01</td>\n",
       "      <td>4.5 $\\pm$ 0.7</td>\n",
       "      <td>1.61 $\\pm$ 0.05</td>\n",
       "      <td>1.3</td>\n",
       "      <td>1.8</td>\n",
       "      <td>7.3</td>\n",
       "      <td>20.5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>64</th>\n",
       "      <td>1.53 $\\pm$ 0.08</td>\n",
       "      <td>3.05 $\\pm$ 0.36</td>\n",
       "      <td>1.48 $\\pm$ 0.14</td>\n",
       "      <td>0.4</td>\n",
       "      <td>1.9</td>\n",
       "      <td>3.3</td>\n",
       "      <td>17.9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>128</th>\n",
       "      <td>1.51 $\\pm$ 0.08</td>\n",
       "      <td>2.97 $\\pm$ 0.45</td>\n",
       "      <td>1.37 $\\pm$ 0.22</td>\n",
       "      <td>0.5</td>\n",
       "      <td>1.7</td>\n",
       "      <td>1.7</td>\n",
       "      <td>17.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>256</th>\n",
       "      <td>1.51 $\\pm$ 0.06</td>\n",
       "      <td>3.4 $\\pm$ 0.74</td>\n",
       "      <td>1.47 $\\pm$ 0.07</td>\n",
       "      <td>0.3</td>\n",
       "      <td>0.7</td>\n",
       "      <td>0.6</td>\n",
       "      <td>16.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>512</th>\n",
       "      <td>1.49 $\\pm$ 0.08</td>\n",
       "      <td>3.46 $\\pm$ 0.13</td>\n",
       "      <td>1.29 $\\pm$ 0.15</td>\n",
       "      <td>0.3</td>\n",
       "      <td>0.4</td>\n",
       "      <td>0.4</td>\n",
       "      <td>15.4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1024</th>\n",
       "      <td>1.83 $\\pm$ 0.33</td>\n",
       "      <td>5.58 $\\pm$ 2.5</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.2</td>\n",
       "      <td>0.1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>14.8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2048</th>\n",
       "      <td>2.31 $\\pm$ 0.27</td>\n",
       "      <td>9.77 $\\pm$ 1.53</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.1</td>\n",
       "      <td>0.1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>14.7</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                  loss.test         loss.test        loss.test  elapsed_time  \\\n",
       "depth step                                                                     \n",
       "1     1     2.79 $\\pm$ 0.04   9.82 $\\pm$ 0.34  2.83 $\\pm$ 0.27          23.8   \n",
       "      2     2.87 $\\pm$ 0.03  11.69 $\\pm$ 0.38   3.36 $\\pm$ 0.2          19.3   \n",
       "      4     2.92 $\\pm$ 0.08  11.15 $\\pm$ 0.49  3.69 $\\pm$ 0.06           5.3   \n",
       "      8      2.8 $\\pm$ 0.06  10.72 $\\pm$ 0.24  3.43 $\\pm$ 0.17           3.0   \n",
       "      16    2.22 $\\pm$ 0.07   7.98 $\\pm$ 0.61   2.9 $\\pm$ 0.11           1.7   \n",
       "      32    2.53 $\\pm$ 0.23  12.23 $\\pm$ 0.43  2.68 $\\pm$ 0.12           1.9   \n",
       "      64    2.63 $\\pm$ 0.11  12.02 $\\pm$ 0.09  2.88 $\\pm$ 0.06           0.2   \n",
       "      128   2.64 $\\pm$ 0.18  11.98 $\\pm$ 0.37  2.86 $\\pm$ 0.04           0.2   \n",
       "      256   2.53 $\\pm$ 0.04   12.29 $\\pm$ 0.1   3.08 $\\pm$ 0.1           0.1   \n",
       "      512   2.53 $\\pm$ 0.03  12.22 $\\pm$ 0.11  2.98 $\\pm$ 0.04           0.1   \n",
       "      1024  2.67 $\\pm$ 0.12  11.55 $\\pm$ 0.03              NaN           0.1   \n",
       "      2048  2.48 $\\pm$ 0.03   12.03 $\\pm$ 0.2              NaN           0.0   \n",
       "2     2      2.91 $\\pm$ 0.1  11.11 $\\pm$ 0.23  3.89 $\\pm$ 0.44          12.7   \n",
       "      4     2.92 $\\pm$ 0.04   11.14 $\\pm$ 0.2  4.23 $\\pm$ 0.57          18.1   \n",
       "      8     2.63 $\\pm$ 0.12   8.63 $\\pm$ 0.24  2.88 $\\pm$ 0.15           2.1   \n",
       "      16     1.8 $\\pm$ 0.07   5.73 $\\pm$ 0.45  1.98 $\\pm$ 0.21           2.2   \n",
       "      32     1.9 $\\pm$ 0.02     7.9 $\\pm$ 1.0   1.69 $\\pm$ 0.2           1.2   \n",
       "      64    1.89 $\\pm$ 0.04   5.54 $\\pm$ 0.45  2.04 $\\pm$ 0.07           0.3   \n",
       "      128   1.86 $\\pm$ 0.03   6.77 $\\pm$ 0.42  1.95 $\\pm$ 0.18           0.3   \n",
       "      256   1.86 $\\pm$ 0.09   5.64 $\\pm$ 0.19   2.1 $\\pm$ 0.19           0.1   \n",
       "      512   1.81 $\\pm$ 0.02   5.05 $\\pm$ 0.23  2.17 $\\pm$ 0.18           0.1   \n",
       "      1024  1.93 $\\pm$ 0.11    6.0 $\\pm$ 0.19              NaN           0.1   \n",
       "      2048  2.03 $\\pm$ 0.03    7.7 $\\pm$ 1.46              NaN           0.1   \n",
       "3     2     2.82 $\\pm$ 0.08  11.01 $\\pm$ 0.28   4.1 $\\pm$ 0.72           8.8   \n",
       "      4     2.97 $\\pm$ 0.23  10.13 $\\pm$ 0.62  3.56 $\\pm$ 0.44           3.2   \n",
       "      8     2.42 $\\pm$ 0.19    7.67 $\\pm$ 0.4  2.55 $\\pm$ 0.13           2.9   \n",
       "      16    1.74 $\\pm$ 0.05   4.11 $\\pm$ 0.61   1.4 $\\pm$ 0.06           1.4   \n",
       "      32    1.67 $\\pm$ 0.01     4.5 $\\pm$ 0.7  1.61 $\\pm$ 0.05           1.3   \n",
       "      64    1.53 $\\pm$ 0.08   3.05 $\\pm$ 0.36  1.48 $\\pm$ 0.14           0.4   \n",
       "      128   1.51 $\\pm$ 0.08   2.97 $\\pm$ 0.45  1.37 $\\pm$ 0.22           0.5   \n",
       "      256   1.51 $\\pm$ 0.06    3.4 $\\pm$ 0.74  1.47 $\\pm$ 0.07           0.3   \n",
       "      512   1.49 $\\pm$ 0.08   3.46 $\\pm$ 0.13  1.29 $\\pm$ 0.15           0.3   \n",
       "      1024  1.83 $\\pm$ 0.33    5.58 $\\pm$ 2.5              NaN           0.2   \n",
       "      2048  2.31 $\\pm$ 0.27   9.77 $\\pm$ 1.53              NaN           0.1   \n",
       "\n",
       "            elapsed_time  elapsed_time  mean_memory_usage  \n",
       "depth step                                                 \n",
       "1     1             22.1          28.1               56.5  \n",
       "      2              9.6           8.8               32.6  \n",
       "      4              5.7           3.2               20.2  \n",
       "      8              2.6           4.8               14.3  \n",
       "      16             1.4           1.8               11.8  \n",
       "      32             0.9           2.2                9.8  \n",
       "      64             0.3           0.4                9.1  \n",
       "      128            0.2           0.3                8.7  \n",
       "      256            0.1           0.1                8.3  \n",
       "      512            0.0           0.1                8.4  \n",
       "      1024           0.1           NaN                7.8  \n",
       "      2048           0.1           NaN                7.8  \n",
       "2     2              9.3           8.2               58.3  \n",
       "      4              5.0           3.4               34.0  \n",
       "      8              3.4           3.3               21.8  \n",
       "      16             1.4           2.5               16.0  \n",
       "      32             1.1           2.0               13.1  \n",
       "      64             0.3           1.7               11.6  \n",
       "      128            0.4           0.7               10.9  \n",
       "      256            0.1           0.5               10.5  \n",
       "      512            0.2           0.4               10.3  \n",
       "      1024           0.1           NaN                9.6  \n",
       "      2048           0.1           NaN                9.6  \n",
       "3     2              9.4           6.9              125.2  \n",
       "      4              4.1           2.6               71.6  \n",
       "      8              3.2           3.1               43.3  \n",
       "      16             1.4           6.5               29.1  \n",
       "      32             1.8           7.3               20.5  \n",
       "      64             1.9           3.3               17.9  \n",
       "      128            1.7           1.7               17.3  \n",
       "      256            0.7           0.6               16.6  \n",
       "      512            0.4           0.4               15.4  \n",
       "      1024           0.1           NaN               14.8  \n",
       "      2048           0.1           NaN               14.7  "
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the full BIDMC results table\n",
    "all_results_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Neural RDEs",
   "language": "python",
   "name": "nrdes"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}