{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e400615e-3e2b-4410-9f56-03461c6900bc",
   "metadata": {},
   "source": [
    "# PASHA: Efficient HPO with Progressive Resource Allocation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f3b015b-603e-4c19-863f-b91d1291a613",
   "metadata": {},
   "source": [
    "Hyperparameter optimization and neural architecture search are important for obtaining\n",
    "well-performing models, but they are costly in practice, especially for large datasets.\n",
    "To decrease the cost, practitioners adopt heuristics with mixed results. We propose an approach\n",
    "to tackle this challenge: start with a small amount of resources and progressively increase them\n",
    "as needed. Our approach, named PASHA, measures the stability of the ranking of different hyperparameter\n",
    "configurations and stops increasing the resources once the ranking becomes stable, returning\n",
    "the best configuration. Our experiments show that PASHA significantly accelerates multi-fidelity methods\n",
    "while obtaining hyperparameters that perform similarly well."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d2112eca-63c9-4b13-8cb0-fd3864675191",
   "metadata": {},
   "source": [
    "Outline:\n",
    "* Initial pre-processing and exploration\n",
    "* Main experiments on NASBench201 with PASHA, ASHA and the baselines\n",
    "* Alternative ranking functions\n",
    "* Changes to the reduction factor\n",
    "* Combination with Bayesian Optimization"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa2c80fa-5268-49d7-9ebc-0bd4c52f2dbf",
   "metadata": {},
   "source": [
    "Start by importing the relevant libraries:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ba934ed4-2655-4ec2-97f7-1106c29171a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import logging\n",
    "import numpy as np\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import random\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "mpl.rcParams['axes.spines.right'] = False\n",
    "mpl.rcParams['axes.spines.top'] = False\n",
    "\n",
    "from benchmarking.blackbox_repository import load\n",
    "from benchmarking.blackbox_repository.tabulated_benchmark import BlackboxRepositoryBackend\n",
    "from benchmarking.definitions.definition_nasbench201 import nasbench201_benchmark, nasbench201_default_params\n",
    "from syne_tune.backend.simulator_backend.simulator_callback import SimulatorCallback\n",
    "from syne_tune.optimizer.schedulers.hyperband import HyperbandScheduler\n",
    "from syne_tune.optimizer.baselines import baselines\n",
    "from syne_tune.tuner import Tuner\n",
    "from syne_tune.stopping_criterion import StoppingCriterion\n",
    "from syne_tune.experiments import load_experiment\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "158e40be-2cf3-4fda-83fa-ad59c919f02a",
   "metadata": {},
   "source": [
    "We also need the `rbo` (rank-biased overlap) library for one of the alternative ranking functions, so install it now:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "905d043b-89ff-4d68-9841-52031c4ee9ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install rbo"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17b685b1-824d-477e-a9e3-ecbc06d19adf",
   "metadata": {},
   "source": [
    "Define our settings:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "409df06e-280d-4a87-8afe-40a7db28e13c",
   "metadata": {},
   "outputs": [],
   "source": [
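    "metric_valid_error_dim = 0  # objective index of the validation error in the NASBench201 blackbox\n",
    "metric_runtime_dim = 2  # objective index of the runtime in the NASBench201 blackbox\n",
    "dataset_names = ['cifar10', 'cifar100', 'ImageNet16-120']\n",
    "epoch_names = ['val_acc_epoch_' + str(e) for e in range(200)]  # NASBench201 trains for 200 epochs\n",
    "random_seeds = [31415927, 0, 1234, 3458, 7685]\n",
    "nb201_random_seeds = [0, 1, 2]  # NASBench201 seeds (repeated training runs of each architecture)\n",
    "n_workers = 4  # number of parallel workers"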
   ]
  },
  {
   "cell_type": "markdown",
   "id": "810f0609-63a8-4714-b328-9489c601ef9a",
   "metadata": {},
   "source": [
    "# Initial pre-processing and exploration"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6f2cd7e-c425-47eb-88c1-3fd5a248d8bd",
   "metadata": {},
   "source": [
    "Load the NASBench201 benchmark so that we can analyse the performance of the various approaches:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f33be501-3945-4603-9521-4d65a4404145",
   "metadata": {},
   "outputs": [],
   "source": [
    "bb_dict = load('nasbench201')\n",
    "df_dict = {}\n",
    "\n",
    "for seed in nb201_random_seeds:\n",
    "    df_dict[seed] = {}\n",
    "    for dataset in dataset_names:\n",
    "        # dataframe with the validation accuracy (in %) for each epoch, computed as 100 * (1 - validation error)\n",
    "        df_val_acc = pd.DataFrame((1.0 - bb_dict[dataset].objectives_evaluations[:, seed, :, metric_valid_error_dim])\n",
    "                                  * 100, columns=epoch_names)\n",
    "        # add a column with the best validation accuracy over all epochs\n",
    "        df_val_acc['val_acc_best'] = df_val_acc[epoch_names].max(axis=1)\n",
    "        # dataframe with the hyperparameter values\n",
    "        df_hp = bb_dict[dataset].hyperparameters\n",
    "        # dataframe with the time it takes to run an epoch\n",
    "        df_time = pd.DataFrame(bb_dict[dataset].objectives_evaluations[:, seed, :, metric_runtime_dim][:, -1], columns=['eval_time_epoch'])\n",
    "        # combine the smaller dataframes into one dataframe per NASBench201 seed and dataset\n",
    "        df_dict[seed][dataset] = pd.concat([df_hp, df_val_acc, df_time], axis=1)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f7cc4d8-0c4c-420b-9b2c-a1332c720b39",
   "metadata": {},
   "source": [
    "Why we use the best validation accuracy: NASBench201 provides validation and test errors in an inconsistent format. For CIFAR-100 and ImageNet16-120, per-epoch errors are only available on the combined validation and test sets. As a trade-off, we use the combined validation and test sets as the validation set. Consequently, no separate test set is left for additional evaluation, so we use the best validation accuracy as the final evaluation metric."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40c32d92-6939-4cba-aa48-405f6df00cd8",
   "metadata": {},
   "source": [
    "Have a look at one of the dataframes (NASBench201 seed 0, CIFAR-10):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "10be535e-5ea5-431c-a2fa-6586f0d45a1b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>hp_x0</th>\n",
       "      <th>hp_x1</th>\n",
       "      <th>hp_x2</th>\n",
       "      <th>hp_x3</th>\n",
       "      <th>hp_x4</th>\n",
       "      <th>hp_x5</th>\n",
       "      <th>val_acc_epoch_0</th>\n",
       "      <th>val_acc_epoch_1</th>\n",
       "      <th>val_acc_epoch_2</th>\n",
       "      <th>val_acc_epoch_3</th>\n",
       "      <th>...</th>\n",
       "      <th>val_acc_epoch_192</th>\n",
       "      <th>val_acc_epoch_193</th>\n",
       "      <th>val_acc_epoch_194</th>\n",
       "      <th>val_acc_epoch_195</th>\n",
       "      <th>val_acc_epoch_196</th>\n",
       "      <th>val_acc_epoch_197</th>\n",
       "      <th>val_acc_epoch_198</th>\n",
       "      <th>val_acc_epoch_199</th>\n",
       "      <th>val_acc_best</th>\n",
       "      <th>eval_time_epoch</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>avg_pool_3x3</td>\n",
       "      <td>nor_conv_1x1</td>\n",
       "      <td>skip_connect</td>\n",
       "      <td>nor_conv_1x1</td>\n",
       "      <td>skip_connect</td>\n",
       "      <td>skip_connect</td>\n",
       "      <td>37.980003</td>\n",
       "      <td>36.909996</td>\n",
       "      <td>38.739998</td>\n",
       "      <td>31.059998</td>\n",
       "      <td>...</td>\n",
       "      <td>85.469994</td>\n",
       "      <td>85.479996</td>\n",
       "      <td>85.699997</td>\n",
       "      <td>85.689995</td>\n",
       "      <td>85.570000</td>\n",
       "      <td>85.639999</td>\n",
       "      <td>85.659996</td>\n",
       "      <td>85.619995</td>\n",
       "      <td>85.729996</td>\n",
       "      <td>15.461778</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>nor_conv_3x3</td>\n",
       "      <td>nor_conv_3x3</td>\n",
       "      <td>avg_pool_3x3</td>\n",
       "      <td>skip_connect</td>\n",
       "      <td>nor_conv_3x3</td>\n",
       "      <td>skip_connect</td>\n",
       "      <td>34.619999</td>\n",
       "      <td>61.129997</td>\n",
       "      <td>61.039997</td>\n",
       "      <td>70.029999</td>\n",
       "      <td>...</td>\n",
       "      <td>93.739998</td>\n",
       "      <td>93.739998</td>\n",
       "      <td>93.729996</td>\n",
       "      <td>93.709999</td>\n",
       "      <td>93.739998</td>\n",
       "      <td>93.729996</td>\n",
       "      <td>93.639999</td>\n",
       "      <td>93.750000</td>\n",
       "      <td>93.750000</td>\n",
       "      <td>23.198093</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>avg_pool_3x3</td>\n",
       "      <td>nor_conv_3x3</td>\n",
       "      <td>nor_conv_3x3</td>\n",
       "      <td>avg_pool_3x3</td>\n",
       "      <td>avg_pool_3x3</td>\n",
       "      <td>avg_pool_3x3</td>\n",
       "      <td>36.389999</td>\n",
       "      <td>44.279999</td>\n",
       "      <td>42.290001</td>\n",
       "      <td>42.139999</td>\n",
       "      <td>...</td>\n",
       "      <td>86.029999</td>\n",
       "      <td>85.790001</td>\n",
       "      <td>85.820000</td>\n",
       "      <td>85.959999</td>\n",
       "      <td>85.870003</td>\n",
       "      <td>85.699997</td>\n",
       "      <td>85.769997</td>\n",
       "      <td>85.839996</td>\n",
       "      <td>86.029999</td>\n",
       "      <td>24.261475</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>avg_pool_3x3</td>\n",
       "      <td>skip_connect</td>\n",
       "      <td>none</td>\n",
       "      <td>none</td>\n",
       "      <td>none</td>\n",
       "      <td>skip_connect</td>\n",
       "      <td>37.349998</td>\n",
       "      <td>48.919998</td>\n",
       "      <td>56.320000</td>\n",
       "      <td>59.029995</td>\n",
       "      <td>...</td>\n",
       "      <td>86.870003</td>\n",
       "      <td>86.779999</td>\n",
       "      <td>86.790001</td>\n",
       "      <td>86.830002</td>\n",
       "      <td>86.870003</td>\n",
       "      <td>86.720001</td>\n",
       "      <td>86.790001</td>\n",
       "      <td>86.809998</td>\n",
       "      <td>86.989998</td>\n",
       "      <td>9.305114</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>skip_connect</td>\n",
       "      <td>skip_connect</td>\n",
       "      <td>nor_conv_1x1</td>\n",
       "      <td>skip_connect</td>\n",
       "      <td>skip_connect</td>\n",
       "      <td>nor_conv_1x1</td>\n",
       "      <td>35.450001</td>\n",
       "      <td>36.369999</td>\n",
       "      <td>32.860001</td>\n",
       "      <td>34.069996</td>\n",
       "      <td>...</td>\n",
       "      <td>87.900002</td>\n",
       "      <td>88.199997</td>\n",
       "      <td>88.160004</td>\n",
       "      <td>88.099998</td>\n",
       "      <td>87.989998</td>\n",
       "      <td>88.070000</td>\n",
       "      <td>88.209999</td>\n",
       "      <td>88.090004</td>\n",
       "      <td>88.209999</td>\n",
       "      <td>13.933862</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15620</th>\n",
       "      <td>none</td>\n",
       "      <td>avg_pool_3x3</td>\n",
       "      <td>avg_pool_3x3</td>\n",
       "      <td>skip_connect</td>\n",
       "      <td>skip_connect</td>\n",
       "      <td>none</td>\n",
       "      <td>40.399998</td>\n",
       "      <td>47.910000</td>\n",
       "      <td>58.350002</td>\n",
       "      <td>56.389999</td>\n",
       "      <td>...</td>\n",
       "      <td>86.449997</td>\n",
       "      <td>86.529999</td>\n",
       "      <td>86.489998</td>\n",
       "      <td>86.660004</td>\n",
       "      <td>86.639999</td>\n",
       "      <td>86.669998</td>\n",
       "      <td>86.559998</td>\n",
       "      <td>86.619995</td>\n",
       "      <td>86.669998</td>\n",
       "      <td>11.035928</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15621</th>\n",
       "      <td>avg_pool_3x3</td>\n",
       "      <td>nor_conv_3x3</td>\n",
       "      <td>nor_conv_3x3</td>\n",
       "      <td>skip_connect</td>\n",
       "      <td>nor_conv_1x1</td>\n",
       "      <td>none</td>\n",
       "      <td>36.619995</td>\n",
       "      <td>51.030003</td>\n",
       "      <td>60.530003</td>\n",
       "      <td>45.649998</td>\n",
       "      <td>...</td>\n",
       "      <td>88.489998</td>\n",
       "      <td>88.459999</td>\n",
       "      <td>88.700005</td>\n",
       "      <td>88.610001</td>\n",
       "      <td>88.510002</td>\n",
       "      <td>88.510002</td>\n",
       "      <td>88.669998</td>\n",
       "      <td>88.400002</td>\n",
       "      <td>88.700005</td>\n",
       "      <td>22.097097</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15622</th>\n",
       "      <td>skip_connect</td>\n",
       "      <td>nor_conv_3x3</td>\n",
       "      <td>nor_conv_3x3</td>\n",
       "      <td>nor_conv_3x3</td>\n",
       "      <td>none</td>\n",
       "      <td>nor_conv_1x1</td>\n",
       "      <td>37.989998</td>\n",
       "      <td>45.160000</td>\n",
       "      <td>48.240002</td>\n",
       "      <td>55.739998</td>\n",
       "      <td>...</td>\n",
       "      <td>92.779999</td>\n",
       "      <td>92.769997</td>\n",
       "      <td>92.809998</td>\n",
       "      <td>92.750000</td>\n",
       "      <td>92.720001</td>\n",
       "      <td>92.729996</td>\n",
       "      <td>92.769997</td>\n",
       "      <td>92.779999</td>\n",
       "      <td>92.849998</td>\n",
       "      <td>24.775816</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15623</th>\n",
       "      <td>none</td>\n",
       "      <td>avg_pool_3x3</td>\n",
       "      <td>avg_pool_3x3</td>\n",
       "      <td>avg_pool_3x3</td>\n",
       "      <td>none</td>\n",
       "      <td>avg_pool_3x3</td>\n",
       "      <td>20.420002</td>\n",
       "      <td>29.619997</td>\n",
       "      <td>21.740002</td>\n",
       "      <td>25.230003</td>\n",
       "      <td>...</td>\n",
       "      <td>69.110001</td>\n",
       "      <td>68.739998</td>\n",
       "      <td>69.089996</td>\n",
       "      <td>69.330002</td>\n",
       "      <td>69.449997</td>\n",
       "      <td>69.459999</td>\n",
       "      <td>69.420006</td>\n",
       "      <td>69.480003</td>\n",
       "      <td>69.480003</td>\n",
       "      <td>14.217638</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15624</th>\n",
       "      <td>nor_conv_1x1</td>\n",
       "      <td>none</td>\n",
       "      <td>nor_conv_1x1</td>\n",
       "      <td>none</td>\n",
       "      <td>nor_conv_1x1</td>\n",
       "      <td>nor_conv_1x1</td>\n",
       "      <td>30.980003</td>\n",
       "      <td>35.140003</td>\n",
       "      <td>38.410004</td>\n",
       "      <td>40.420002</td>\n",
       "      <td>...</td>\n",
       "      <td>85.399994</td>\n",
       "      <td>85.539993</td>\n",
       "      <td>85.430000</td>\n",
       "      <td>85.489998</td>\n",
       "      <td>85.519997</td>\n",
       "      <td>85.659996</td>\n",
       "      <td>85.549995</td>\n",
       "      <td>85.619995</td>\n",
       "      <td>85.659996</td>\n",
       "      <td>22.166279</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>15625 rows × 208 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "              hp_x0         hp_x1         hp_x2         hp_x3         hp_x4  \\\n",
       "0      avg_pool_3x3  nor_conv_1x1  skip_connect  nor_conv_1x1  skip_connect   \n",
       "1      nor_conv_3x3  nor_conv_3x3  avg_pool_3x3  skip_connect  nor_conv_3x3   \n",
       "2      avg_pool_3x3  nor_conv_3x3  nor_conv_3x3  avg_pool_3x3  avg_pool_3x3   \n",
       "3      avg_pool_3x3  skip_connect          none          none          none   \n",
       "4      skip_connect  skip_connect  nor_conv_1x1  skip_connect  skip_connect   \n",
       "...             ...           ...           ...           ...           ...   \n",
       "15620          none  avg_pool_3x3  avg_pool_3x3  skip_connect  skip_connect   \n",
       "15621  avg_pool_3x3  nor_conv_3x3  nor_conv_3x3  skip_connect  nor_conv_1x1   \n",
       "15622  skip_connect  nor_conv_3x3  nor_conv_3x3  nor_conv_3x3          none   \n",
       "15623          none  avg_pool_3x3  avg_pool_3x3  avg_pool_3x3          none   \n",
       "15624  nor_conv_1x1          none  nor_conv_1x1          none  nor_conv_1x1   \n",
       "\n",
       "              hp_x5  val_acc_epoch_0  val_acc_epoch_1  val_acc_epoch_2  \\\n",
       "0      skip_connect        37.980003        36.909996        38.739998   \n",
       "1      skip_connect        34.619999        61.129997        61.039997   \n",
       "2      avg_pool_3x3        36.389999        44.279999        42.290001   \n",
       "3      skip_connect        37.349998        48.919998        56.320000   \n",
       "4      nor_conv_1x1        35.450001        36.369999        32.860001   \n",
       "...             ...              ...              ...              ...   \n",
       "15620          none        40.399998        47.910000        58.350002   \n",
       "15621          none        36.619995        51.030003        60.530003   \n",
       "15622  nor_conv_1x1        37.989998        45.160000        48.240002   \n",
       "15623  avg_pool_3x3        20.420002        29.619997        21.740002   \n",
       "15624  nor_conv_1x1        30.980003        35.140003        38.410004   \n",
       "\n",
       "       val_acc_epoch_3  ...  val_acc_epoch_192  val_acc_epoch_193  \\\n",
       "0            31.059998  ...          85.469994          85.479996   \n",
       "1            70.029999  ...          93.739998          93.739998   \n",
       "2            42.139999  ...          86.029999          85.790001   \n",
       "3            59.029995  ...          86.870003          86.779999   \n",
       "4            34.069996  ...          87.900002          88.199997   \n",
       "...                ...  ...                ...                ...   \n",
       "15620        56.389999  ...          86.449997          86.529999   \n",
       "15621        45.649998  ...          88.489998          88.459999   \n",
       "15622        55.739998  ...          92.779999          92.769997   \n",
       "15623        25.230003  ...          69.110001          68.739998   \n",
       "15624        40.420002  ...          85.399994          85.539993   \n",
       "\n",
       "       val_acc_epoch_194  val_acc_epoch_195  val_acc_epoch_196  \\\n",
       "0              85.699997          85.689995          85.570000   \n",
       "1              93.729996          93.709999          93.739998   \n",
       "2              85.820000          85.959999          85.870003   \n",
       "3              86.790001          86.830002          86.870003   \n",
       "4              88.160004          88.099998          87.989998   \n",
       "...                  ...                ...                ...   \n",
       "15620          86.489998          86.660004          86.639999   \n",
       "15621          88.700005          88.610001          88.510002   \n",
       "15622          92.809998          92.750000          92.720001   \n",
       "15623          69.089996          69.330002          69.449997   \n",
       "15624          85.430000          85.489998          85.519997   \n",
       "\n",
       "       val_acc_epoch_197  val_acc_epoch_198  val_acc_epoch_199  val_acc_best  \\\n",
       "0              85.639999          85.659996          85.619995     85.729996   \n",
       "1              93.729996          93.639999          93.750000     93.750000   \n",
       "2              85.699997          85.769997          85.839996     86.029999   \n",
       "3              86.720001          86.790001          86.809998     86.989998   \n",
       "4              88.070000          88.209999          88.090004     88.209999   \n",
       "...                  ...                ...                ...           ...   \n",
       "15620          86.669998          86.559998          86.619995     86.669998   \n",
       "15621          88.510002          88.669998          88.400002     88.700005   \n",
       "15622          92.729996          92.769997          92.779999     92.849998   \n",
       "15623          69.459999          69.420006          69.480003     69.480003   \n",
       "15624          85.659996          85.549995          85.619995     85.659996   \n",
       "\n",
       "       eval_time_epoch  \n",
       "0            15.461778  \n",
       "1            23.198093  \n",
       "2            24.261475  \n",
       "3             9.305114  \n",
       "4            13.933862  \n",
       "...                ...  \n",
       "15620        11.035928  \n",
       "15621        22.097097  \n",
       "15622        24.775816  \n",
       "15623        14.217638  \n",
       "15624        22.166279  \n",
       "\n",
       "[15625 rows x 208 columns]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_dict[0]['cifar10']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a13c2e2-3410-4329-a9f2-2c0d1d3840a6",
   "metadata": {},
   "source": [
    "Perform a brief exploratory analysis to show that NAS is helpful:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "42530baf-ed37-4e39-b45c-6668bf291ab3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAA1gAAAEYCAYAAABBWFftAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAA6QUlEQVR4nO3debhlVX3n//dHEMSRqcIPGVKo5YB2LLQasZ0QVEq0RRMHSKKoKKaDHW3THUszQEwwZFDUjiFBIWAUkThRCgFLhDZ2C1IMIoNKAaVUpYSCQkRRsPD7+2OvC6cu99a9t+rc4Zz7fj3Pfs7Za6+z99r3nPO957v32munqpAkSZIkbb2HzHYDJEmSJGlYmGBJkiRJUp+YYEmSJElSn5hgSZIkSVKfmGBJkiRJUp+YYEmSJElSn5hgzXNJnp3krCT/keTeJLcnWZHkyCTbJDkwSSV5Uc9rjmtlY01P6Km3Q5I7W/nTx9n+RaNef1eS/5vkFZNs/9OS/FOSy1r7x73vQJKdknw8yW1Jfpbkq0n+01T+XpImz/jy4PiS5GFJ/jbJuiQ/T/LNJM+fTHukQZHkjaO/s4OoJ3a8eYxln0yyegvWubDFuceNsey/JjkjyfeT/CrJRROs6w1JLk1yd5IfJ/nGZH7XJHluktOSXJ1k43j7keTVST6X5ActXn0vyV8ledQYdf2N1cMEax5L8k7g/wI7A+8GXgS8Gfg+cBLw8glW8Vzg2aOmm3uWvwp4dHv+hs2s56qe1x8FPAL4fJJnTWI3ngkcCvwQWDlepSQBvgQsBf478FvAQ4ELk+w5ie1ImgLjy7jx5RTgrcCf0f0N1gHnJ1k8ifZImh3HJtmuT+taCBwLPCjBAl4JLAYuBtZsbiVJ3g/8I3AO8DLgd4ALgIdPog0HA88DrgGu20y9/wncB7yXLr6dBPw3YEWS+3MIf2ONoaqc5uEEPB/4FfCRcZY/HvgN4ECggBf1LDuulW07wTbOB26nCxQ/Gqs+cBHwjVFle7a2/eMk9uMhPc//svtIj1nvsNbmF/aUPQbYMN7fwMnJacsm48vY8QV4eqv3pp6ybYHvActn+31zcurXBLyxfdafMNtt2cr9qBZrCvjvo5Z9Eli9Bet8UNzrWdYbc74BXDTOOp7d4tgrt3C/ercz7n4AC8Yoe0Nr/0E9Zf7GGjV5Bmv+ejfdB/+PxlpYVTdU1VVbuvIke9AdsT4T+DiwG3DIZF5bVWuA9cDek6j7q0k26RXAf1TVhT2vvZPuiMthk1yHpMkxvowdX14B/BL4TE+9jXT7cUiS7Se5PWmgtO6630iyNMmVrbvZFUmelWTbJO9v3WY3tK5rjxj1+j9PcnmSn7QuaF9LcsAY23lGkn9v6785yXvba2tUvW2TvCfJd5Pck64b8weSPGyM5l8KfBH44ySbPTs00XqTHAiMxIkVPV0QD4QpxZz/BtxUVV+cZP1NTHY7VbV+jOJL2+MePWX+xhrFBGseSrIN8ELgK1X1i61Y1TYtmIxMvZ+n36X7fH0C+FfgF2y+G09v+x4F7ALcsBVtG+2pwNVjlF8D7J3kkX3cljRvGV82MTq+PJXuR9HdY9TbDhjo61WkCTwB+FvgBOA1wPbAcrpuZ7vTnfV6H11Xt2NHvXYP4ES6H+tvBG4Fvt57jU+SXem6yO0MHEnXVe2QVn+0TwJ/ApxB173ur+i6EH9qnLb/CbAA+IMJ9nGi9V4OHNOe/wEPdF++fIL1jvZc4NtJ/ijJ2nYd1dVJXjPF9WyJF7TH3q6F/sYaZdvZboBmxa7ADsAPtnI9o388fYruhw90we17VXUJQJIvAq9MsmNV/Xj0ipKMfBb3Av6G7uj3iVvZvl47A6vHKN/QHncCftrH7UnzlfHlAaPjy87AHZupt3Mf2yTNNbsA/6WqbgRoB03OBvapqpGBbs5PN+jLa+g5A15Vbxl53g7inEf34/0twDvaonfRXX90SDtTTZLz
GfXdTPI84HXAkVX1iVb81SQbgE8mWVxVV/a+pqquSXIG8EdJTmpnZzYx2fUmubYtu66qLp74zzamx9LF2v2A/0V3Vv5o4Kwkr6yqs7dwvZvVeg+8D/hqVfVel+pvrFE8g6WtcQDwn3umPwVI8p+BpwD/0lP3dOBhdMFntOfQdZv5JXAj8F+B3xoJwm2dvUeyPTAgDT/jizRcvt/7vQO+2x7PH1Xvu8CebeAEAJK8KMmFSW4HNtJ9n58IPKnndQcAF48kVwBV9XO6QSB6LQXuBT476nv/lbZ8vFE9jwUeSZfQjGVL17slHgI8ii6WnVFVK+ji3zV0A1KQTt9iWzsLdTbd3/9NW9f84WeCNT/dDvwc+PWtXM9lVbWyZ7qplR/ZHr+UZMckO9L12V3P2N14vk33A+oAulPpdwH/mmQBdEOa8sAPpF8Cv2xlU3EH3RGU0XbuWS5p6xlfHjA6vkxUb8MYy6RhMfr/7L2bKd8W2Aa666qAc+nOgBzFAwdfvk13YGXE7nRdB0e7ZdT8r9F1yf0Zm373R167y1iNb8nhKcA7RuJHP9a7hW4HNlTV/V0L23VVF9CNQghdV75fjpq2SJId6K6nehw9Zwh7+BtrFI/UzUNVtTHdvRVenGT7qrqnX+tON4zpEW3222NUWZBkUVVd31P2055TzZckuQn4Gt1oYscA/0EXTHv9xxSbdg3wkjHK9wV+WFXz6tS1NF2ML5sYHV+uAV6V5OGjrsPal+5H5aopbleaD36L7qzJb1bV/UlCkp2AH/fUW0eX5Iy226j52+m6ID9vnO1t7vv/F3QHed47xrKtWe9UXUPXPXAsIwN6XMaDY9uUJXko8FlgCfDiqvrOOO3xN1YPz2DNXyfQHU35m7EWJtknyW9swXpfTnfE4s/pLnTvnQ5vdTZ7MXobheYLwFuS7FlV9446kr2yqu7d3DrGsBzYI8nIxZkkeTRdd6HlU1yXpM0zvowdX75Ed2+Y1/TU25aua89X+pmMSkPk4XT3Yrp/JMAkB/HgkUAvBp6dnvsutTMvLxtV7zy6M1+PGeO7v7Kqxk2E2rKP0o3iN/r+TpNd78j3fIdJ7Pt4vgDsnGTJSEG7pu3FtFH+ququ0W2Y6kbaOj8FHEQ3JPx414z5G2sUz2DNU1X19STvAj6YZF/gNLqbae5EdwO6twC/DTzoQs4JHEl3Gv/vxjpikeR/AL+b5M+qqh706gccS3fDvXfTjQQ0pnRDph7aZp/cyl7d5lf3BJTlwDfpLjT9X3Snq98DhHF+BEraMsaXseNLVV2R5DPAh9pR4ZvofqjtQzdymqQHOw94J3Bakn+mu/bqT4G1o+p9kO77dH6SP6dLZN7VHu+PB1V1UZJP010r9UHgW3T3lFpI931/d1V9fzPtOYFuQIkX0DOYzxTW+326M3JvbgNg3EM3aM9dSX6dB8467QL8qifmXFpVI9s7he4M/OeS/AlwW2vTkxj7TNImWhfHkWRob+DhPdu5tqpGBuL4KN0BoeOBn2XTofHX9HQV9DfWaDUHbsblNHsT8F/ohjleR9c/dwPdBZkjwyAfyCRvBEo3hOm9wCmb2d5b22sPbPMXMepGoD11z6C7lmP3zaxvYVvfWNNpo+ruDJza9vFuur7KT5/t98DJaVgn48uD4wvdUesP0t0c+RfAJSPtdXIalolRNxoe67vY8/16y6jyB8UAugMhN7Xv7KV098G7iFE34gWeQXeD3l/QJWB/CnwYuGNUvYfQjT747Vb3zvb8b+jOQI3UK+Avx9i/Y9uy1Vu43rfRDbqzcVTMGvm7jTW9cdS2dqcbFn5D29Y3gZdM8v05cDPbOa6n3urJ1Gt1/Y3VM6X9USRJkqSh0YZ0vxy4raoOnu32aP6wi6AkSZIGXpK/oBss5gd0XezeAvwGD3T1lWaECZYkSZKGQQF/Rncj3gKuohuc4d9mtVWad+wiKEmSJEl94jDtkiRJktQnc7qL4NKlS+u8886b7WZI2jqZ7QZM
ljFHGnjGG0kzacyYM6fPYN12222z3QRJ84gxR9JMMd5Iw2tOJ1iSJEmSNEhMsCRJkiSpT0ywJEmSJKlPTLAkSZIkqU9MsCTNKUlOTXJrkqt7yj6T5Mo2rU5yZStfmOTnPcv+sec1z0zynSSrknwkycCMLiZJkgbXnB6mXdK8dBrw98AnRgqq6nUjz5N8ALizp/4NVbV4jPWcBLwVuAQ4F1gK/Fv/mytJkvQAz2BJmlOq6uvAhrGWtbNQrwU+vbl1JNkdeHRVXVxVRZesvbLPTZUkSXoQEyxJg+R5wC1VdX1P2T5Jrkjyf5I8r5XtAazpqbOmlT1IkqOTrEyycv369dPTakmSNG+YYEkaJEew6dmrdcDeVbUf8C7gjCSPnsoKq+rkqlpSVUsWLFjQx6ZKkqT5yGuwJI1p4bJzAFh9wstmuSWdJNsCvwk8c6Ssqu4B7mnPL0tyA/BEYC2wZ8/L92xlkqbZSOyAuRM/JM0dc+33xXTwDJakQfEi4LtVdX/XvyQLkmzTnj8OWATcWFXrgJ8kOaBdt/UG4OzZaLQkSZpfPIMlaU5J8mngQGDXJGuAY6vqFOBwHjy4xfOB9yX5JfAr4PeqamSAjN+nG5FwB7rRAx1BUJKkOWKYz3ZPmGAleRjwdWD7Vv+zVXVskn2AM4FdgMuA11fVvUm2pxux65nA7cDrqmp1W9d7gKOA+4A/qKrz+79LkgZZVR0xTvkbxyj7HPC5ceqvBJ7W18ZJkiRNYDJdBO8BDqqqpwOLgaVJDgD+Gjixqp4A3EGXONEe72jlJ7Z6JNmX7gj0U+nuR/MPI117JEmSJGkYTJhgVeenbfahbSrgIOCzrfx0HrjHzGFtnrb84HYNxGHAmVV1T1XdBKwC9u/HTkiSJEnSXDCpQS6SbJPkSuBWYAVwA/DjqtrYqvTeY2YP4GaAtvxOum6E95eP8ZrebXlPGkmSJEkDaVIJVlXdV1WL6YY63h948nQ1yHvSSJIkSRpUUxqmvap+DFwIPBvYsd2XBja9x8xaYC+4/741j6Eb7OL+8jFeI0mSJEkDb8IEq91nZsf2fAfgxcB1dInWq1u1I3ngHjPL2zxt+deqqlr54Um2byMQLgK+1af9kCRJ6pskOyb5bJLvJrkuybOT7JxkRZLr2+NOrW6SfCTJqiRXJXnGbLdf0uyZzBms3YELk1wFXAqsqKovA+8G3pVkFd01Vqe0+qcAu7TydwHLAKrqGuAs4FrgPOCYqrqvnzsjSZLUJx8GzquqJwNPpzu4vAy4oKoWARe0eYCX0h04XgQcDZw0882VNFdMeB+sqroK2G+M8hsZYxTAqvoF8Jpx1nU8cPzUmylJkjQzkjyG7kbmbwSoqnuBe5McRncjdOhGTL6I7oDzYcAnWo+di9vZr92rat0MN13SHDCla7AkSZLmgX2A9cA/J7kiyceTPALYrSdp+hGwW3vuSMmS7meCJUmStKltgWcAJ1XVfsDPeKA7INDdJ5TuvqCT5kjJ0tgWLjuHhcvOme1m9I0JliRJ0qbWAGuq6pI2/1m6hOuWJLsDtMdb23JHSpZ0PxMsSZKkHlX1I+DmJE9qRQfTDdLVO1Ly6BGU39BGEzwAuNPrr6T5a8JBLiRJkuah/w58Ksl2wI3Am+gOTJ+V5CjgB8BrW91zgUOBVcDdra6kHsPUBXAiJliSJEmjVNWVwJIxFh08Rt0CjpnuNkkaDHYRlCRJkqQ+McGSJEmSpD4xwZIkSZKkPjHBkiRJkqQ+McGSJEmSpD4xwZIkSZKkPjHBkjSnJDk1ya1Jru4pOy7J2iRXtunQnmXvSbIqyfeSHNJTvrSVrUqybKb3Q5IkzU8mWJLmmtOApWOUn1hVi9t0LkCSfYHDgae21/xDkm2SbAN8FHgpsC9wRKsrSZI0rbzRsKQ5paq+nmThJKsfBpxZVfcANyVZBezflq2qqhsBkpzZ6l7b7/ZKkiT18gyWpEHx9iRXtS6EO7WyPYCbe+qsaWXjlT9IkqOT
rEyycv369dPRbkmSNI+YYEkaBCcBjwcWA+uAD/RrxVV1clUtqaolCxYs6NdqJUnSPGUXQUlzXlXdMvI8yceAL7fZtcBePVX3bGVsplySJGnaeAZL0pyXZPee2VcBIyMMLgcOT7J9kn2ARcC3gEuBRUn2SbId3UAYy2eyzZIkaX7yDJakOSXJp4EDgV2TrAGOBQ5MshgoYDXwNoCquibJWXSDV2wEjqmq+9p63g6cD2wDnFpV18zsnkiSpPnIBEvSnFJVR4xRfMpm6h8PHD9G+bnAuX1smiRJ0oTsIihJkiRJfWKCJUmSJEl9YoIlSZIkSX0yYYKVZK8kFya5Nsk1Sd7Ryo9LsjbJlW06tOc170myKsn3khzSU760la1Ksmx6dkmSJEmSZsdkBrnYCPxhVV2e5FHAZUlWtGUnVtXf9VZOsi/dkMhPBR4LfDXJE9vijwIvBtYAlyZZXlXX9mNHJEmSJGm2TZhgVdU6YF17fleS64A9NvOSw4Azq+oe4KYkq4D927JVVXUjQJIzW10TLEmSJElDYUrXYCVZCOwHXNKK3p7kqiSnJtmple0B3NzzsjWtbLzy0ds4OsnKJCvXr18/leZJkiRJ0qyadIKV5JHA54B3VtVPgJOAxwOL6c5wfaAfDaqqk6tqSVUtWbBgQT9WKUmSJEkzYlIJVpKH0iVXn6qqzwNU1S1VdV9V/Qr4GA90A1wL7NXz8j1b2XjlkiRJc0qS1Um+0wbyWtnKdk6yIsn17XGnVp4kH2mDeF2V5Bmz23pJs2kyowgGOAW4rqo+2FO+e0+1VwFXt+fLgcOTbJ9kH2AR8C3gUmBRkn2SbEc3EMby/uyGJElS372wqhZX1ZI2vwy4oKoWARe0eYCX0v3eWQQcTdfLR9I8NZlRBJ8DvB74TpIrW9l7gSOSLAYKWA28DaCqrklyFt3gFRuBY6rqPoAkbwfOB7YBTq2qa/q2J5IkSdPrMODA9vx04CLg3a38E1VVwMVJdkyyexsoTNI8M5lRBL8BZIxF527mNccDx49Rfu7mXidJkjRHFPCVJAX8U1WdDOzWkzT9CNitPR9vIK9NEqwkR9Od4WLvvfeexqZLmk2TOYMlSZI03zy3qtYm+TVgRZLv9i6sqmrJ16S1JO1kgCVLlkzptZIGx5SGaZckSZoPqmpte7wV+ALdYF63jFyD3h5vbdUdyEvS/UywJEmSeiR5RJJHjTwHXkI3mNdy4MhW7Ujg7PZ8OfCGNprgAcCdXn8lzV92EZQkSdrUbsAXuoGU2RY4o6rOS3IpcFaSo4AfAK9t9c8FDgVWAXcDb5r5JkuaK0ywJEmSelTVjcDTxyi/HTh4jPICjpmBpkkaAHYRlCRJkqQ+McGSJEmS5rmFy85h4bJzZrsZQ8EES5IkSZL6xARL0pyS5NQktya5uqfsb5N8N8lVSb6QZMdWvjDJz5Nc2aZ/7HnNM5N8J8mqJB9Ju1pdkiRpOplgSZprTgOWjipbATytqn4D+D7wnp5lN1TV4jb9Xk/5ScBbgUVtGr1OSZKkvjPBkjSnVNXXgQ2jyr5SVRvb7MV0N/EcV7sB6KOr6uI2utcngFdOQ3MlSZI2YYIladC8Gfi3nvl9klyR5P8keV4r2wNY01NnTSt7kCRHJ1mZZOX69eunp8WSJGneMMGSNDCS/DGwEfhUK1oH7F1V+wHvAs5I8uiprLOqTq6qJVW1ZMGCBf1tsCRJmne80bCkgZDkjcDLgYNbtz+q6h7gnvb8siQ3AE8E1rJpN8I9W5kkSdK0MsGSNOclWQr8EfCCqrq7p3wBsKGq7kvyOLrBLG6sqg1JfpLkAOAS4A3A/56NtkuSNJ/Nx3trmWBJ2sRsB8IknwYOBHZNsgY4lm7UwO2BFW209YvbiIHPB96X5JfAr4Dfq6qRATJ+n25Ewh3ortnqvW5LkiRpWphgSZpTquqIMYpPGafu54DPjbNsJfC0PjZNkiSNYeTg7OoTXjbLLZkbTLAkSZKkeWq2e64M
I0cRlCRJkqQ+McGSJEmSNGcsXHbOQJ9ZM8GSJEmSpD4xwZIkSZIEDP7Zo7nABEuSJEmS+mTCBCvJXkkuTHJtkmuSvKOV75xkRZLr2+NOrTxJPpJkVZKrkjyjZ11HtvrXJzly+nZLkiRJkmbeZM5gbQT+sKr2BQ4AjkmyL7AMuKCqFgEXtHmAlwKL2nQ0cBJ0CRndDUOfBewPHDuSlEmSJEnSMJgwwaqqdVV1eXt+F3AdsAdwGHB6q3Y68Mr2/DDgE9W5GNgxye7AIcCKqtpQVXcAK4Cl/dwZSZIkSZpNU7oGK8lCYD/gEmC3qlrXFv0I2K093wO4uedla1rZeOWSJEmSNBQmnWAleSTwOeCdVfWT3mVVVUD1o0FJjk6yMsnK9evX92OVkiRJU5ZkmyRXJPlym98nySXtOvPPJNmulW/f5le15QtnteGSZtWkEqwkD6VLrj5VVZ9vxbe0rn+0x1tb+Vpgr56X79nKxivfRFWdXFVLqmrJggULprIvkiRJ/fQOuksjRvw1cGJVPQG4AziqlR8F3NHKT2z1JM1TkxlFMMApwHVV9cGeRcuBkZEAjwTO7il/QxtN8ADgztaV8HzgJUl2aoNbvKSVSZIkzSlJ9gReBny8zQc4CPhsqzL6+vOR69I/Cxzc6kuah7adRJ3nAK8HvpPkylb2XuAE4KwkRwE/AF7blp0LHAqsAu4G3gRQVRuS/AVwaav3vqra0I+dkCRJ6rMPAX8EPKrN7wL8uKo2tvnea8nvv868qjYmubPVv23GWitpzpgwwaqqbwDjHYU5eIz6BRwzzrpOBU6dSgMlSZJmUpKXA7dW1WVJDuzjeo+mu4UNe++9d79WK2mOmdIogpIkSfPAc4BXJFkNnEnXNfDDdLeeGTk43Xst+f3XmbfljwFuH71SrzOX5gcTLEmSpB5V9Z6q2rOqFgKHA1+rqt8BLgRe3aqNvv585Lr0V7f6fRldWZrPFi47h4XLzpntZkyZCZYkSdLkvBt4V5JVdNdYndLKTwF2aeXvApbNUvukeWMuJ1+TGeRCkiRpXqqqi4CL2vMbgf3HqPML4DUz2jBJc5YJliRJmhYjR5dXn/CyWW6JpGExV89a9bKLoCRJmlZzuSuPJPWbZ7AkzSlJTgVGhkh+WivbGfgMsBBYDby2qu5oN/L8MN299+4G3lhVl7fXHAn8SVvtX1bV6UiSpIEzaAdoTLAkzTWnAX8PfKKnbBlwQVWdkGRZm3838FJgUZueBZwEPKslZMcCS4ACLkuyvKrumLG9kOaZQfsBJGl6zeeYYBdBSXNKVX0d2DCq+DBg5AzU6cAre8o/UZ2L6e5RsztwCLCiqja0pGoFsHTaGy9ps+wqKGk+MMGSNAh2q6p17fmPgN3a8z2Am3vqrWll45U/SJKjk6xMsnL9+vX9bbUkSZp37CIoaaBUVSXp2w08q+pk4GSAJUuWeGNQSZJmybCc4fYMlqRBcEvr+kd7vLWVrwX26qm3Zysbr1ySJGlamWBJGgTLgSPb8yOBs3vK35DOAcCdrSvh+cBLkuyUZCfgJa1MkiRpWtlFUNKckuTTwIHArknW0I0GeAJwVpKjgB8Ar23Vz6Ubon0V3TDtbwKoqg1J/gK4tNV7X1WNHjhDkiSp70ywJM0pVXXEOIsOHqNuAceMs55TgVP72DRJkrQZw3IN1dayi6AkSZIk9YkJliRJkiT1iQmWJEnSHOYNmqXB4jVYkiRJ0jxiwj69PIMlSZIkSX1igiVJkqSBZjdKzSV2EZS0Wb3/sFaf8LJZbIkkSdKmRn6nzKXfKJ7BkiRJkqQ+McGSJElTZpcsSRrbhAlWklOT3Jrk6p6y45KsTXJlmw7tWfaeJKuSfC/JIT3lS1vZqiTL+r8rkiRJkjS7JnMG6zRg6RjlJ1bV4jadC5BkX+Bw4KntNf+QZJsk2wAfBV4K7Asc0epKkiRJ0tCYcJCLqvp6koWTXN9hwJlVdQ9wU5JVwP5t2aqquhEgyZmt7rVT
b7IkSZKk2WY34bFtzTVYb09yVetCuFMr2wO4uafOmlY2XvmDJDk6ycokK9evX78VzZMkSZq6JA9L8q0k305yTZI/b+X7JLmkXe7wmSTbtfLt2/yqtnzhrO6ANI/NhetDtzTBOgl4PLAYWAd8oF8NqqqTq2pJVS1ZsGBBv1YrSZI0WfcAB1XV0+l+6yxNcgDw13SXSDwBuAM4qtU/CrijlZ/Y6kmap7boPlhVdcvI8yQfA77cZtcCe/VU3bOVsZlySTPMe1tJ0viqqoCfttmHtqmAg4DfbuWnA8fRHXQ+rD0H+Czw90nS1iNpntmiBCvJ7lW1rs2+ChgZYXA5cEaSDwKPBRYB3wICLEqyD11idTgPBChJkqQ5pQ3QdRnwBLqBum4AflxVG1uV3ssd7r8Uoqo2JrkT2AW4bdQ6jwaOBth7772nexekaTPbXfDmugkTrCSfBg4Edk2yBjgWODDJYrqjOauBtwFU1TVJzqIbvGIjcExV3dfW83bgfGAb4NSquqbfOyNJktQP7ffL4iQ7Al8AntyHdZ4MnAywZMkSz25JQ2oyowgeMUbxKZupfzxw/Bjl5wLnTql1kiRJs6iqfpzkQuDZwI5Jtm1nsXovdxi5RGJNkm2BxwC3z0qDJc26rRlFUJIkaegkWdDOXJFkB+DFwHXAhcCrW7UjgbPb8+Vtnrb8a15/Jc1fW3QNliRJ0hDbHTi9XYf1EOCsqvpykmuBM5P8JXAFD/ToOQX4l3b/zw1015pLmqdMsCQNhCRPAj7TU/Q44M+AHYG3AiM3zntv65JMkvfQDZ98H/AHVXX+jDVY0sCqqquA/cYovxHYf4zyXwCvmYGmSRoAJliSBkJVfY/ufjQjo3utpbvw/E1096X5u976SfalO4r8VLpRTb+a5IkjA+9I2jKOHiZJm2eCJWkQHQzcUFU/SDJencOAM6vqHuCm1nVnf+CbM9RGaWiYVEnS5DnIhaRBdDjw6Z75tye5KsmpSXZqZfffl6bpvWfN/ZIcnWRlkpXr168fvViSJGlKTLAkDZQk2wGvAP61FZ0EPJ6u++A64ANTWV9VnVxVS6pqyYIFC/rZVEmSNA+ZYEkaNC8FLq+qWwCq6paquq+qfgV8jAcuQB+5L82I3nvWSJIkTQsTLEmD5gh6ugcm2b1n2auAq9vz5cDhSbZPsg+wCPjWjLVSkiTNSw5yIWlgJHkE3Q0/39ZT/DdJFgMFrB5ZVlXXJDkLuBbYCBzjCIKSJE3OyOA2q0942Sy3ZPCYYEkaGFX1M2CXUWWv30z944Hjp7tdkiRpds2l0U7tIihJkiRJfeIZLEmStMXm0lFjaT6Y6a57vd9xuwtOjgmWJMAfSZIkSf1ggiXNcyZWkiRpMvzNMDkmWJIkSZKGymx2bXSQC0mSJEnqExMsSZIkSeoTEyxJkiRJ6hMTLEmSJEnqExMsSZIkSeoTEyxJkiRJ6hMTLEmSJEnqE++DJUmSxuRNRSVp6iY8g5Xk1CS3Jrm6p2znJCuSXN8ed2rlSfKRJKuSXJXkGT2vObLVvz7JkdOzO5Ikaa5buOwckzdJQ2syXQRPA5aOKlsGXFBVi4AL2jzAS4FFbToaOAm6hAw4FngWsD9w7EhSJkmSJEnDYsIEq6q+DmwYVXwYcHp7fjrwyp7yT1TnYmDHJLsDhwArqmpDVd0BrODBSZskSdKsS7JXkguTXJvkmiTvaOVT7sEjTRfPBM9dW3oN1m5Vta49/xGwW3u+B3BzT701rWy88gdJcjTd2S/23nvvLWyeJEnSFtsI/GFVXZ7kUcBlSVYAb6TrwXNCkmV0PXjezaY9eJ5F14PnWbPScmkzTMhmxlaPIlhVBVQf2jKyvpOraklVLVmwYEG/VitJkjQpVbWuqi5vz+8CrqM7MDzVHjyS5qEtPYN1S5Ldq2pdCyC3tvK1wF499fZsZWuBA0eVX7SF25YkSZoRSRYC+wGXMPUePOt6yuylo63i2afBsaVnsJYDIyMBHgmc3VP+
htYX+QDgzhaIzgdekmSn1l/5Ja1MkiTNMV7b0UnySOBzwDur6ie9y7akB4+9dKT5YTLDtH8a+CbwpCRrkhwFnAC8OMn1wIvaPMC5wI3AKuBjwO8DVNUG4C+AS9v0vlYmSZOWZHWS7yS5MsnKVuZF55L6LslD6ZKrT1XV51vxLSNd/ybZg0dDwoMOmooJuwhW1RHjLDp4jLoFHDPOek4FTp1S6yTpwV5YVbf1zI/cNsKLziX1RZIApwDXVdUHexaN9OA5gQf34Hl7kjPp4sydPV0JJc0zWz3IhSTNMi86l9RvzwFeDxzUzphfmeRQptiDR9L8tKWDXEjSbCjgK0kK+KeqOhkvOpfUZ1X1DSDjLJ5SDx5J848JlqRB8tyqWpvk14AVSb7bu7CqqiVfk9aStJMBlixZ0rdbTkiSppfXRGmuMsGSNDCqam17vDXJF4D9mfptIyRJmldMRmeW12BJGghJHpHkUSPP6W73cDVTv22ENK84+pkkzSzPYEkaFLsBX+gG92Jb4IyqOi/JpcBZ7RYSPwBe2+qfCxxKd9H53cCbZr7JkiRpvjHBkjQQqupG4OljlN+OF51LkqQ5wi6CkiRJktQnJliSJEmS1CcmWJIkSZLUJyZYkiRJktQnJliSJEmS1CcmWJIkSZLUJyZYkiRJktQn3gdLGmILl50DwOoTXjbLLZEkSTNt5HeAZpZnsCRJkqRZtHDZOSZDQ8QzWJIkSdIc0pts2Qtl8JhgSfOIR8ckSZKml10EJUmSJKlPPIMlSZIkDRF7rMwuEyxJ2gKO0Ki5zh9YkjQ7TLCkecAfWpImw1ghSVvPBEuSJEkacB4gmTtMsCRJkqQ5YEuSJBOruWerEqwkq4G7gPuAjVW1JMnOwGeAhcBq4LVVdUeSAB8GDgXuBt5YVZdvzfYlSZKk6WLyoi3Rj2HaX1hVi6tqSZtfBlxQVYuAC9o8wEuBRW06GjipD9uWJEnquySnJrk1ydU9ZTsnWZHk+va4UytPko8kWZXkqiTPmL2WS5pt03EfrMOA09vz04FX9pR/ojoXAzsm2X0ati9JkrS1TgOWjirzILJm3MJl53gmbSvN9N9waxOsAr6S5LIkR7ey3apqXXv+I2C39nwP4Oae165pZZtIcnSSlUlWrl+/fiubJ2lYJNkryYVJrk1yTZJ3tPLjkqxNcmWbDu15zXvaEeXvJTlk9lovadBU1deBDaOKPYgsaUJbO8jFc6tqbZJfA1Yk+W7vwqqqJDWVFVbVycDJAEuWLJnSayUNtY3AH1bV5UkeBVyWZEVbdmJV/V1v5ST7AocDTwUeC3w1yROr6r4ZbbU0R3jvtr6Y6kHkdT1ltIPRRwPsvffe09tSSbNmq85gVdXa9ngr8AVgf+CWkaM27fHWVn0tsFfPy/dsZZI0oapaNzIwTlXdBVzHGGfBexwGnFlV91TVTcAquhglzRi79gyvqiq6njxTec3JVbWkqpYsWLBgmlomabZtcYKV5BHtKDJJHgG8BLgaWA4c2aodCZzdni8H3tAuBD0AuLPnKJAkTVqShcB+wCWt6O3twvJTRy46Z5LdkiVpCjyILGlCW3MGazfgG0m+DXwLOKeqzgNOAF6c5HrgRW0e4FzgRrqjyB8Dfn8rti1pnkrySOBzwDur6id0F5M/HlhM1x3nA1Ncn9d9SposDyLPAXPhzPBk2jAX2qlNzdR7ssXXYFXVjcDTxyi/HTh4jPICjtnS7UlSkofSJVefqqrPA1TVLT3LPwZ8uc1O6oiy131KGkuSTwMHArsmWQMcS3fQ+KwkRwE/AF7bqp9Ld5/PVXT3+nzTjDd4CHndoAbV1g5yIUkzot2s/BTguqr6YE/57j1Hil9F11UZuiPKZyT5IN0gF4vozrZL0oSq6ohxFnkQWZvlWSuZYEkaFM8BXg98J8mVrey9wBFJFtNdbL4aeBtAVV2T5CzgWroRCI9xBEFpbun9IepZCknDwgRL0kCoqm8AGWPRuZt5zfHA8dPW
KEnS0PEMlLbW1t5oWJIkSZLUDM0ZLLsZSJK0ZTxiL0n9MzQJliRJkua3mR550AP8GotdBCVJkiSpTzyDJUmSpDljPnVZnU/7Op+YYEmaNG/6KEkaJP3+v2VCpMkwwZJmmP21peE3mz/CJtq2MUjzgYmQZpMJliRJc8BUjrR7NlkabCaAw80ESxoSHpWWJGn2mDRphAmWJEmzaGt+lHkmS+qf6UqQ/J7OPdP9nphgSTNkJo9seRRNkiRpdphgSZI04Lb0oIoHYzRf2I1eM8kES5qD7E4gzX1jJSfT/Z01IZL6z++V+s0ES5rDPOImzT8eYNF8Mhufd79jmm4mWJIkSZqzpnvwiZnimbL5wwRLkqRpNtYRc28ILEnD6SGz3QBJkqSFy87xCL+koeAZLEmSmrl6bYaJh4adn3ENExMsaQ7Z3D+Y8X74+U9JkjQo/J+luWS6DqqZYEmTsCVfwOn60vrPaX7yepzpsSXfJ7+DkqTNMcGS+swfX9LgGyuhncx32+//1pur3TQ1eb6Hmu9mPMFKshT4MLAN8PGqOmGm2yDNNf4omx7Gm/lhKj/mtuZsdD/4XR9uxpzx+dnXfDKjCVaSbYCPAi8G1gCXJlleVdfOZDukLbW5fxDj/WDzn8rsMN4MnrG+K+OdPdraZMrv5dw3aGdBjDkP/l75PdOg6Hc3/Jk+g7U/sKqqbgRIciZwGDBvgs98NReD7FS6/UyG13LMOUMRb8b6jIz+4TloP0RHbE2Xu8kMCDOV7Uh9MKdiztb+YNySM8OSOqmqmdtY8mpgaVW9pc2/HnhWVb29p87RwNFt9knA92asgZOzK3DbbDeij9yfuW0Y9ue2qlo60xudTLxp5VOJOcPwfkyF+zv8hm2fZyXeQN9/4wzy+2LbZ8cgtx0Gt/1jxpw5N8hFVZ0MnDzb7RhPkpVVtWS229Ev7s/cNmz7MxdNJebMt/fD/R1+83GfZ9Nk480gvy+2fXYMctth8Ns/2kNmeHtrgb165vdsZZLUb8YbSTPJmCMJmPkE61JgUZJ9kmwHHA4sn+E2SJofjDeSZpIxRxIww10Eq2pjkrcD59MNYXpqVV0zk23ogznbfXELuT9z27Dtz4yZpngz394P93f4zcd9nhZ9jjmD/L7Y9tkxyG2HwW//JmZ0kAtJkiRJGmYz3UVQkiRJkoaWCZYkSZIk9YkJ1mYk2SvJhUmuTXJNkne08p2TrEhyfXvcabbbOllJtklyRZIvt/l9klySZFWSz7QLcwdCkh2TfDbJd5Ncl+TZA/7e/I/2Obs6yaeTPGyQ359hkmRpku+192HZbLdnOgxjvJuMYYqJExm2mDmsBineJDk1ya1Jru4pG4jP1CDHvPb74FtJvt3a/uetfGDi17DHXhOszdsI/GFV7QscAByTZF9gGXBBVS0CLmjzg+IdwHU9838NnFhVTwDuAI6alVZtmQ8D51XVk4Gn0+3XQL43SfYA/gBYUlVPo7tA+nAG+/0ZCkm2AT4KvBTYFziixYFhM4zxbjKGKSZOZGhi5rAawHhzGjD6JquD8pka5Jh3D3BQVT0dWAwsTXIAgxW/hjr2mmBtRlWtq6rL2/O76D4IewCHAae3aqcDr5yVBk5Rkj2BlwEfb/MBDgI+26oM0r48Bng+cApAVd1bVT9mQN+bZltghyTbAg8H1jGg78+Q2R9YVVU3VtW9wJl0n7OhMmzxbjKGKSZOZEhj5jAaqHhTVV8HNowqHojP1CDHvOr8tM0+tE3FgMSv+RB7TbAmKclCYD/gEmC3qlrXFv0I2G222jVFHwL+CPhVm98F+HFVbWzza+iCyyDYB1gP/HM7xfzxJI9gQN+bqloL/B3wQ7rE6k7gMgb3/RkmewA398wP/fswJPFuMj7E8MTEiQxVzBxiwxBvBu4zNYgxr3WxuxK4FVgB3MDgxK8PMeSx1wRrEpI8Evgc8M6q+knvsurGuZ/zY90neTlwa1VdNttt6ZNtgWcAJ1XVfsDPGHUaf1De
G4DWx/swuh9BjwUewYO7XUjTbhji3WQMYUycyFDFTA2GQfhMDWrMq6r7qmoxsCfdmc8nz26LJme+xF4TrAkkeSjdF+9TVfX5VnxLkt3b8t3pjh7Mdc8BXpFkNV2Xg4Po+uPv2LqkQfclXTs7zZuyNcCaqrqkzX+W7sfDIL43AC8Cbqqq9VX1S+DzdO/ZoL4/w2QtsFfP/NC+D0MU7yZj2GLiRIYtZg6rYYg3A/OZGoaY17r6Xgg8m8GIX/Mi9ppgbUbrE3oKcF1VfbBn0XLgyPb8SODsmW7bVFXVe6pqz6paSDd4wteq6nfovpSvbtUGYl8AqupHwM1JntSKDgauZQDfm+aHwAFJHt4+dyP7M5Dvz5C5FFjURjjaju77s3yW29R3wxTvJmPYYuJEhjBmDqthiDcD8Zka5JiXZEGSHdvzHYAX011DNufj13yJvenOfmosSZ4L/DvwHR7oJ/peuj66ZwF7Az8AXltVoy/ynLOSHAj8z6p6eZLH0R1B2Bm4AvjdqrpnFps3aUkW010guR1wI/AmuoMGA/netGFWX0c3stEVwFvo+iAP5PszTJIcStdnfBvg1Ko6fnZb1H/DGu8mY1hi4kSGLWYOq0GKN0k+DRwI7ArcAhwLfJEB+EwNcsxL8ht0A0FsQ/sOV9X7Bi1+DXPsNcGSJEmSpD6xi6AkSZIk9YkJliRJkiT1iQmWJEmSJPWJCZYkSZIk9YkJliRJkiT1iQnWEEiyMEklWTLW/DivWdLqLOzntjVzkuyU5JYkj9+KdWyf5Ie+f5os4838ZLzRbDDezE/DEG/mZYKV5LT2pRmZbkvy5SRP7uM2ZvOLeTOwO3BlP1ea5KIkfz8T29KkvBc4t6puAEiyc5IvJflpkiuS7NdbOckHkry/t6zdY+Jvgb+esVbPM8abLWO8mXOMNwPAeLNljDdzzsDHm3mZYDVfpfvi7A68BNgB+MKstqhPquq+qvpRVW0cpm3NNUm2TZJZ2vbD6W5EfEpP8R8DjwKeAVwEfKyn/jOBQ4H3jbG6TwHPTfLU6WqvjDeDtq25xnijKTDeDNi25hrjTR9U1bybgNOAL48qezlQwA49ZXvQ3VX6jjadAyzqWb4XcDawAbgb+C5weFtWo6aLxmnL/wM+MKrs0cDPgd9s878LXArcBdwK/CuwR0/9hW0bS8aab2VLW/t+QXfn8t9udRa25bsAnwbWtG1fA7xp1N9s9D4tHGdbz6e7E/ov6O7sfiKwXc/yi4B/AN4P3Nb26e+Ah2zmPdts+1qdAH8IXA/c0+r+Vc/yx9J92W5v79eVwAvbsuOAq0et743AT3vmjwOubuU3APcBj2x/239vn5ENwPnAU0ata8xtt7/fr3r/fq3+W9vfZrtx/h6vbttKT9m5wO+1508BftaebwtcPrKv46zva8BfzvZ3cxgnjDfGG+ON8WaGJow3xhvjzZyIN/P5DNb9kjwKeB3wnar6eSt7OHAh3ZfoBcCzgXXAV9sy6L5ED6f7ID0VeCfw47Zs//a4lO4o0m+Os/lPAocn6X0vfqtt95w2vx1wLPB0ukC5K92XcbL7txfwRWAFsBj438DfjKr2MLoP6cvbvnwY+KckB7fl7wC+CfwzDxwZu3mMbe0B/BtwBbAfcBRwBPBXo6r+DrAR+C/A2+n+dq/bzG5M1D7oAtqftm09FXjNSBuTPAL4P3Rf+FcC/4mxj3ZMZB+64P0auvfjF8AjgA/RvecHAncCX0qy3UTbrqrVdO/Lm0dt583Av1TVveO043nAZdWiR/Nt4KAk2wKHAFe18ncBV1TVhZvZr2/Rfc41zYw3gPFmsow32irGG8B4M1nGm36a6YxuLkx0Rys2Aj9tUwE/BJ7WU+fNdEcKejPobegy9Ne2+auAY8fZxkJGHfkYp94uwL3AwT1lXwVO3sxrntzWvedY2xpj/v3A90fty5/Qc4RnnO2cCXy8Z/4i4O83t5/A8e3v9pCeOm+kO+Ly8J71fHPUelb0bmuS
7+P97aM70vIL2hGOMeq+le4I2a7jLD+OyR3h+SWw2wTtegTd0Z/nTnLbr6Y7OvSwNv+U9jd92ma28UXg9FFljwHOAH5AF/D2BR4H3ATsRveP5wa6f2y7j3rtHwA3z9R3cD5NGG/AeDN6+XEYb4w30zBhvAHjzejlx2G8mfF4M5/PYH2d7mjHYrrM/ALgK+1oCMAz6bL5u9pFdT+ly9x3AkZGNfkw8CdJvpnkL1s/0CmpqtuB8+iOeJDksXRHjD45UifJM5KcneQHSe4CVrZFe09yM08BLq72SWu+2VshyTZJ/jjJVUlub/v7m1PYxuht/aqn7Bt0R6me0FN2FZv6D+DXxlvpJNq3L7A93fs4lv2Aq6rqtsnvypjWVNUto9r2+CRnJLkhyU/oug08pKdtE237bLp/QiNHAd8MfKuqrt5MO3agC7j3q6o7q+q3q+rXq+oFVXUt8I/Ae+iOSD2V7v25EvjIqPX9vK1T08N408N4M2nGG20J400P482kGW/6aD4nWHdX1ao2XUp3Qd2jgaPb8ofQvVGLR01PBP4JoKpOoQtS/9zK/1+S47agLZ8EfivJw4DD6U77/jvcf/r1fLp+ra8H/jPdaXnovtT98j/p+vf+LXAw3b5+sc/b6A2Avxxj2eY+j9Pdvl/R9XHu9dAx6v1sjLIvAwuAtwHPogs4Gyfbtqr6JfAJ4M3t9Pfr2fTizrHcRvfPcFxJ3gD8sqrOBA4CPlfdKfkz2nyvnYH1k2mvtojxZlPGG+ON8Wb6GG82Zbwx3sx4vJnPCdZoRfchHOl/fDndEYnbegLVyLTh/hdVramqk6vqtcCf8UAAG+lbus0ktr28Pb6c7kjPGT1HY55M1yf5vVX19ar6Lps5EjKO64BnJZuMCHPAqDrPBb5UVf9SVVfSnWp94qg69zLx/lwHHDCqz/Vz22tvmGK7p9K+6+hO0x88xmuh6zP9G0l2HWf5emC3UX+jxRM1KskudO/R+6vqq1V1Hd1IN9tOYdsAH6c7svf77fVnTrDpK+iOao3XrgV0/aD/Wyt6CA8E1O148Pv4NLrPvGaG8cZ4Y7zRTDHeGG+MNzNsPidY2yf5/9r0FLr+m48EvtSWf4ruVOjZSV6QZJ8kz0831v4igCQfTrI0yeOSLKY78nJte/2tdKclD0myW5LHjNeQqvoF8Dm6fsPPoOf0OV3f6XuAt7ftvAz4iynu6z/S9SX+UJInJXk18Huj6nwfODjJc9PdL+Pv6Y5e9VoN7J/uHhi7jgoyI/6BbkSZf0jylNbeE+j6Nt89xXZPun1VdRddl4a/SvKmdlp7/yQjX8Az6N6Ts5M8r/0tX5HkhW35RXRHOd7bXnsUXd/hidxBd7TlrUmekOQFdH/v3mFdJ9o2VfU9uq4Gfwt8tqp+MsF2zwee0gLgWE4ETqyqH7b5bwBHts/6O9t8r+fRdeXQ9DDebMp4Y7wx3kwf482mjDfGm5mPNzXDF33NhYkHD8n5E7pRRn5rVL3d6E6P30oXBG4CTqVdzEcXtK6n6yu6ni4r7x1e9C10AeQ+xhnGtKfuQa0tl4+x7HV0RzR+0dp5SKt7YFu+kImHMX0Z8L22jv9LdyTp/otA6U7Hfp4Hhkr9G7pgclHPOp5I17f57pHXjrOtkWFM7+GBYUy371l+EQ++mPQ0Rg0tO2r5ZNr3EGAZcCPdEaWbgeN7lu8JfIZuJKS76Y6SHNiz/G10F1D+rL2X72CMYUzHee+ubn/bq9v781PgjZPddqvzhva3fP4kP8ffBI4Zo/yQ9jnpvRB3B7pA+JP2usf1LHs2XSDdYTLbdZrahPHGeGO8Md7M0ITxxnhjvJkT8SatAZJmWZJ3A0dV1eiuC+PVX0p3VGvfqrpvK7b7r3TDnL5/S9chabAYbyTNlPkYb+ZzF0FpTkjyyHR3GX8HXUCZlKo6D/go3dGjLd329nQjHp24peuQNDiMN5JmynyON57BkmZZktPobla4HDiiqjZu/hWStGWMN5Jm
ynyONyZYkiRJktQndhGUJEmSpD4xwZIkSZKkPjHBkiRJkqQ+McGSJEmSpD4xwZIkSZKkPvn/AawsenjiSUDGAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 864x288 with 3 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.figure(figsize=(12, 4))\n",
    "plt.subplot(131)\n",
    "plt.hist(df_dict[0]['cifar10']['val_acc_best'], bins=100)\n",
    "plt.xlabel('Best validation accuracy (%)', fontsize=14)\n",
    "plt.title('CIFAR-10', fontsize=16)\n",
    "\n",
    "plt.subplot(132)\n",
    "plt.hist(df_dict[0]['cifar100']['val_acc_best'], bins=100)\n",
    "plt.xlabel('Best validation accuracy (%)', fontsize=14)\n",
    "plt.title('CIFAR-100', fontsize=16)\n",
    "\n",
    "plt.subplot(133)\n",
    "plt.hist(df_dict[0]['ImageNet16-120']['val_acc_best'], bins=100)\n",
    "plt.xlabel('Best validation accuracy (%)', fontsize=14)\n",
    "plt.title('ImageNet16-120', fontsize=16)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6b8b09e9-e9c3-474d-bc3e-7f00e8594e7d",
   "metadata": {},
   "source": [
    "The histograms show a large spread in validation accuracy across configurations on all three datasets, so the choice of architecture matters and NAS is worthwhile."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "153714ad-7ac2-403d-943b-495d9f7b9287",
   "metadata": {},
   "source": [
    "## Main experiments\n",
    "We run experiments on the NASBench201 CIFAR-10, CIFAR-100 and ImageNet16-120 datasets, comparing PASHA, ASHA (promotion type) and two baselines: one epoch and random."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f8f2c1f1-4c5a-4e52-b9e3-05d2cd0c268f",
   "metadata": {},
   "source": [
    "In Syne Tune, the minimum amount of resources allocated to a configuration (here, the minimum number of training epochs) is called the grace period."
   ]
  },
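  {
   "cell_type": "markdown",
   "id": "7c2e9a14-5b3d-4f6a-8e1c-0d9b2a7f4e31",
   "metadata": {},
   "source": [
    "To make the grace period concrete, the following minimal sketch lists the rung levels (in epochs) at which successive-halving methods such as ASHA compare configurations. It assumes the standard geometric schedule, with rungs at `grace_period * reduction_factor**k` epochs below the maximum resource `max_t`; Syne Tune's exact rung placement may differ in details. The helper `rung_levels` is hypothetical (not a Syne Tune function), and the values below (grace period 1, reduction factor 3, `max_t` of 200 epochs, the NASBench201 training budget) are illustrative.\n",
    "\n",
    "```python\n",
    "def rung_levels(grace_period, reduction_factor, max_t):\n",
    "    # hypothetical helper: geometric rung schedule below the maximum resource max_t\n",
    "    levels = []\n",
    "    level = grace_period\n",
    "    while level < max_t:\n",
    "        levels.append(level)\n",
    "        # each rung gets reduction_factor times more epochs than the previous one\n",
    "        level *= reduction_factor\n",
    "    return levels\n",
    "\n",
    "\n",
    "# illustrative values: 1 epoch minimum, reduction factor 3, at most 200 epochs\n",
    "print(rung_levels(1, 3, 200))  # [1, 3, 9, 27, 81]\n",
    "```\n",
    "\n",
    "Under this schedule, configurations promoted from the last rung continue training up to `max_t`. With a reduction factor of 2, the same budget gives more, closer-spaced rungs: `rung_levels(1, 2, 200)` returns `[1, 2, 4, 8, 16, 32, 64, 128]`."
   ]
  },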
  {
   "cell_type": "markdown",
   "id": "5f5dfa59-0b03-466d-b91d-95a85c3e7a75",
   "metadata": {},
   "source": [
    "Define functions for running the experiments and analysing them:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ae6d2cfe-bdc7-4ba5-b143-1de2a5def523",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_experiment(dataset_name, random_seed, nb201_random_seed, hpo_approach, reduction_factor=None, rung_system_kwargs={'ranking_criterion': 'soft_ranking', 'epsilon': 0.025}):\n",
    "    \"\"\"\n",
    "    Run a NASBench201 experiment with the simulator backend. The function is based on the\n",
    "    NASBench201 example script in Syne Tune, extended to make it easy to run our experiments.\n",
    "    \n",
    "    For each parameter we list the values used in our experiments; other values can be used as well.\n",
    "    \n",
    "    :param dataset_name: one of 'cifar10', 'cifar100', 'ImageNet16-120'\n",
    "    :param random_seed: one of 31415927, 0, 1234, 3458, 7685\n",
    "    :param nb201_random_seed: one of 0, 1, 2\n",
    "    :param hpo_approach: one of 'pasha', 'asha', 'pasha-bo', 'asha-bo'\n",
    "    :param reduction_factor: None (default, which uses the benchmark default of 3), 2 or 4\n",
    "    :param rung_system_kwargs: dictionary with the ranking criterion (str) and either epsilon or\n",
    "        epsilon_scaling (float); only used by 'pasha' and 'pasha-bo'\n",
    "    :return: name of the tuning experiment (tuner.name), which can be passed to load_experiment\n",
    "    \n",
    "    \"\"\"\n",
    "    \n",
    "    # this function is similar to the NASBench201 example script\n",
    "    logging.getLogger().setLevel(logging.WARNING)\n",
    "\n",
    "    default_params = nasbench201_default_params({'backend': 'simulated'})\n",
    "    benchmark = nasbench201_benchmark(default_params)\n",
    "    # benchmark must be tabulated to support simulation\n",
    "    assert benchmark.get('supports_simulated', False)\n",
    "    mode = benchmark['mode']\n",
    "    metric = benchmark['metric']\n",
    "    blackbox_name = benchmark.get('blackbox_name')\n",
    "    # NASBench201 is a blackbox from the repository\n",
    "    assert blackbox_name is not None\n",
    "\n",
    "    config_space = benchmark['config_space']\n",
    "\n",
    "    # simulator back-end specialized to tabulated blackboxes\n",
    "    backend = BlackboxRepositoryBackend(\n",
    "        blackbox_name=blackbox_name,\n",
    "        elapsed_time_attr=benchmark['elapsed_time_attr'],\n",
    "        time_this_resource_attr=benchmark.get('time_this_resource_attr'),\n",
    "        dataset=dataset_name,\n",
    "        seed=nb201_random_seed)\n",
    "\n",
    "    # set logging of the simulator backend to WARNING level\n",
    "    logging.getLogger('syne_tune.backend.simulator_backend.simulator_backend').setLevel(logging.WARNING)\n",
    "    \n",
    "    if not reduction_factor:\n",
    "        reduction_factor = default_params['reduction_factor']\n",
    "\n",
    "    # we support various schedulers within the function\n",
    "    if hpo_approach == 'pasha':\n",
    "        scheduler = baselines['PASHA'](\n",
    "            config_space,\n",
    "            max_t=default_params['max_resource_level'],\n",
    "            grace_period=default_params['grace_period'],\n",
    "            reduction_factor=reduction_factor,\n",
    "            resource_attr=benchmark['resource_attr'],\n",
    "            mode=mode,\n",
    "            metric=metric,\n",
    "            random_seed=random_seed,\n",
    "            rung_system_kwargs=rung_system_kwargs)\n",
    "    elif hpo_approach == 'asha':\n",
    "        scheduler = baselines['ASHA'](\n",
    "            config_space,\n",
    "            max_t=default_params['max_resource_level'],\n",
    "            grace_period=default_params['grace_period'],\n",
    "            reduction_factor=reduction_factor,\n",
    "            resource_attr=benchmark['resource_attr'],\n",
    "            mode=mode,\n",
    "            type='promotion',\n",
    "            metric=metric,\n",
    "            random_seed=random_seed)\n",
    "    elif hpo_approach == 'pasha-bo':\n",
    "        scheduler = HyperbandScheduler(\n",
    "            config_space,\n",
    "            max_t=default_params['max_resource_level'],\n",
    "            grace_period=default_params['grace_period'],\n",
    "            reduction_factor=reduction_factor,\n",
    "            resource_attr=benchmark['resource_attr'],\n",
    "            mode=mode,\n",
    "            searcher='bayesopt',\n",
    "            type='pasha',\n",
    "            metric=metric,\n",
    "            random_seed=random_seed,\n",
    "            rung_system_kwargs=rung_system_kwargs)\n",
    "    elif hpo_approach == 'asha-bo':\n",
    "        scheduler = HyperbandScheduler(\n",
    "            config_space,\n",
    "            max_t=default_params['max_resource_level'],\n",
    "            grace_period=default_params['grace_period'],\n",
    "            reduction_factor=reduction_factor,\n",
    "            resource_attr=benchmark['resource_attr'],\n",
    "            mode=mode,\n",
    "            searcher='bayesopt',\n",
    "            type='promotion',\n",
    "            metric=metric,\n",
    "            random_seed=random_seed)\n",
    "    else:\n",
    "        raise ValueError('The selected scheduler is not implemented')\n",
    "\n",
    "    stop_criterion = StoppingCriterion(max_num_trials_started=256)\n",
    "    # printing the status during tuning takes a lot of time, and so does\n",
    "    # storing results\n",
    "    print_update_interval = 700\n",
    "    results_update_interval = 300\n",
    "    # it is important to set `sleep_time` to 0 here (mandatory for simulator\n",
    "    # backend)\n",
    "\n",
    "    tuner = Tuner(\n",
    "        backend=backend,\n",
    "        scheduler=scheduler,\n",
    "        stop_criterion=stop_criterion,\n",
    "        n_workers=n_workers,\n",
    "        sleep_time=0,\n",
    "        results_update_interval=results_update_interval,\n",
    "        print_update_interval=print_update_interval,\n",
    "        # this callback is required in order to make things work with the\n",
    "        # simulator callback. It makes sure that results are stored with\n",
    "        # simulated time (rather than real time), and that the time_keeper\n",
    "        # is advanced properly whenever the tuner loop sleeps\n",
    "        callbacks=[SimulatorCallback()],\n",
    "    )\n",
    "    \n",
    "    tuner.run()\n",
    "    \n",
    "    return tuner.name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "0ea790f4-1791-4f98-97e9-bbe67b13b3b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def analyse_experiments(experiment_names_dict, reference_time=None):\n",
    "    \"\"\"\n",
    "    Analyse the experiments run with the run_experiment function for the dataset\n",
    "    given by the global variable dataset_name.\n",
    "    \n",
    "    :param experiment_names_dict: dictionary mapping each dataset name to a list of tuples\n",
    "        (experiment name, NASBench201 random seed)\n",
    "    :param reference_time: optional mean runtime (in seconds) of the reference method, e.g. ASHA,\n",
    "        used to compute the speedup\n",
    "    :return: tuple of a line to display (string reporting the experiment results) and\n",
    "        the mean runtime (in seconds), which can be used as reference time for other approaches\n",
    "    \"\"\"\n",
    "    val_acc_best_list = []\n",
    "    max_rsc_list = []\n",
    "    runtime_list = []\n",
    "    \n",
    "    for experiment_name, nb201_random_seed in experiment_names_dict[dataset_name]:\n",
    "        experiment_results = load_experiment(experiment_name)\n",
    "        best_cfg = experiment_results.results['metric_valid_error'].argmin()\n",
    "        \n",
    "        # find the best validation accuracy of the corresponding entry in NASBench201\n",
    "        table_hp_names = ['hp_x' + str(hp_idx) for hp_idx in range(6)]\n",
    "        results_hp_names = ['config_hp_x' + str(hp_idx) for hp_idx in range(6)]\n",
    "        condition = (df_dict[nb201_random_seed][dataset_name][table_hp_names] == experiment_results.results[results_hp_names].iloc[best_cfg].tolist()).all(axis=1)\n",
    "        val_acc_best = df_dict[nb201_random_seed][dataset_name][condition]['val_acc_best'].values[0]  # there is only one item in the list\n",
    "        val_acc_best_list.append(val_acc_best)\n",
    "        max_rsc_list.append(experiment_results.results['hp_epoch'].max())\n",
    "        runtime_list.append(experiment_results.results['st_tuner_time'].max())\n",
    "        \n",
    "    line = ' & {:.2f} $\\pm$ {:.2f}'.format(np.mean(val_acc_best_list), np.std(val_acc_best_list))\n",
    "    line += ' & {:.1f}h $\\pm$ {:.1f}h'.format(np.mean(runtime_list)/3600, np.std(runtime_list)/3600)\n",
    "    if reference_time:\n",
    "        line += ' & {:.1f}x'.format(reference_time/np.mean(runtime_list))\n",
    "    else:\n",
    "        # this method is the reference, so its speedup is 1.0x by definition\n",
    "        line += ' & 1.0x'\n",
    "    line += ' & {:.1f} $\\pm$ {:.1f}'.format(np.mean(max_rsc_list), np.std(max_rsc_list))\n",
    "    \n",
    "    return line, np.mean(runtime_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7a556ca5-1021-4b67-b579-707aaeb6e387",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_one_epoch_baseline():\n",
    "    \"\"\"\n",
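    "    Compute the performance of a simple one epoch baseline: train each of 256 randomly\n",
    "    sampled configurations for one epoch and select the one with the best validation accuracy.\n",
    "    \n",
    "    Relies on the global variables dataset_name, reference_time, n_workers and epoch_names.\n",
    "    \n",
    "    :return: a line to display (string reporting the experiment results)\n",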
    "    \"\"\"\n",
    "    best_val_obj_list = []\n",
    "    total_time_list = []\n",
    "    \n",
    "    for nb201_random_seed in nb201_random_seeds:\n",
    "        for random_seed in random_seeds:\n",
    "            # randomly sample 256 configurations for the given dataset and NASBench201 seed\n",
    "            # use the same seeds as for our other experiments\n",
    "            random.seed(random_seed)\n",
    "            cfg_list = random.sample(range(len(df_dict[nb201_random_seed][dataset_name])), 256)\n",
    "            selected_subset = df_dict[nb201_random_seed][dataset_name].iloc[cfg_list]\n",
    "            # find configuration with the best performance after doing one epoch\n",
    "            max_idx = selected_subset['val_acc_epoch_0'].argmax()\n",
    "            best_configuration = selected_subset.iloc[max_idx]\n",
    "            # find the best validation accuracy of the selected configuration\n",
    "            # as that is the metric that we compare \n",
    "            best_val_obj = best_configuration[epoch_names].max()\n",
    "\n",
    "            # we also need to calculate the time it took for this\n",
    "            # taking into account the number of workers\n",
    "            total_time = selected_subset['eval_time_epoch'].sum() / n_workers\n",
    "\n",
    "            best_val_obj_list.append(best_val_obj)\n",
    "            total_time_list.append(total_time)\n",
    "\n",
    "    line = ' & {:.2f} $\\pm$ {:.2f}'.format(np.mean(best_val_obj_list), np.std(best_val_obj_list))\n",
    "    line += ' & {:.1f}h $\\pm$ {:.1f}h'.format(np.mean(total_time_list)/3600, np.std(total_time_list)/3600)\n",
    "    line += ' & {:.1f}x'.format(reference_time/np.mean(total_time_list))\n",
    "    line += ' & 1.0 $\\pm$ 0.0'\n",
    "\n",
    "    return line"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "12640ea4-4932-4c87-8577-bad254eacd1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_random_baseline():\n",
    "    \"\"\"\n",
    "    Compute the performance of a simple random configuration baseline.\n",
    "    \n",
    "    We use 2560 random draws per NASBench201 seed, ten times the 256 configurations sampled\n",
    "    per run in the other experiments, to get a better estimate of the performance of a random configuration.\n",
    "\n",
    "    :return: a line to display (string reporting the experiment results)\n",
    "    \"\"\"\n",
    "    random.seed(0)\n",
    "    random_seeds_rb = random.sample(range(999999), 256 * 10)\n",
    "\n",
    "    best_val_obj_list = []\n",
    "    total_time_list = []\n",
    "\n",
    "    for nb201_random_seed in nb201_random_seeds:\n",
    "        for random_seed in random_seeds_rb:\n",
    "            random.seed(random_seed)\n",
    "            # select the random configurations\n",
    "            cfg_list = random.sample(range(len(df_dict[nb201_random_seed][dataset_name])), 1)\n",
    "            # index with the scalar position so that we get a single row (a Series)\n",
    "            selected_configuration = df_dict[nb201_random_seed][dataset_name].iloc[cfg_list[0]]\n",
    "            # find the best validation accuracy of the selected configuration\n",
    "            # as that is the metric that we compare\n",
    "            best_val_obj = selected_configuration[epoch_names].max()\n",
    "\n",
    "            # we also need to calculate the time it took for this\n",
    "            total_time = 0.0\n",
    "\n",
    "            best_val_obj_list.append(best_val_obj)\n",
    "            total_time_list.append(total_time)\n",
    "\n",
    "    line = ' & {:.2f} $\\pm$ {:.2f}'.format(np.mean(best_val_obj_list), np.std(best_val_obj_list))\n",
    "    line += ' & {:.1f}h $\\pm$ {:.1f}h'.format(np.mean(total_time_list)/3600, np.std(total_time_list)/3600)\n",
    "    line += ' & NA'\n",
    "    line += ' & 0.0 $\\pm$ 0.0'\n",
    "\n",
    "    return line"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e3b3def-b25e-42d2-9939-990f67d60e56",
   "metadata": {},
   "source": [
    "Run the main experiments with PASHA, ASHA and the baselines:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "add9dbfc-c2da-47d9-9369-6c4ad8776538",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "\n",
    "experiment_names_pasha = {dataset: [] for dataset in dataset_names}\n",
    "experiment_names_asha = {dataset: [] for dataset in dataset_names}\n",
    "\n",
    "for dataset_name in dataset_names:\n",
    "    for nb201_random_seed in nb201_random_seeds:\n",
    "        for random_seed in random_seeds:\n",
    "            experiment_name = run_experiment(dataset_name, random_seed, nb201_random_seed, 'pasha')\n",
    "            experiment_names_pasha[dataset_name].append((experiment_name, nb201_random_seed))\n",
    "            experiment_name = run_experiment(dataset_name, random_seed, nb201_random_seed, 'asha')\n",
    "            experiment_names_asha[dataset_name].append((experiment_name, nb201_random_seed))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f35aded-e866-4d43-b4e4-9046fed69370",
   "metadata": {},
   "source": [
    "Analyse the experiments:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "490bdbb7-5ee1-47ae-814f-e3af61471058",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cifar10\n",
      "ASHA & 93.85 $\\pm$ 0.25 & 3.0h $\\pm$ 0.6h & 1.0x & 200.0 $\\pm$ 0.0\n",
      "PASHA & 93.78 $\\pm$ 0.31 & 2.3h $\\pm$ 0.5h & 1.3x & 144.5 $\\pm$ 59.4\n",
      "One epoch baseline  & 93.30 $\\pm$ 0.61 & 0.3h $\\pm$ 0.0h & 8.5x & 1.0 $\\pm$ 0.0\n",
      "Random baseline  & 72.93 $\\pm$ 19.55 & 0.0h $\\pm$ 0.0h & NA & 0.0 $\\pm$ 0.0\n",
      "cifar100\n",
      "ASHA & 71.69 $\\pm$ 1.05 & 3.2h $\\pm$ 0.9h & 1.0x & 200.0 $\\pm$ 0.0\n",
      "PASHA & 71.41 $\\pm$ 1.15 & 1.5h $\\pm$ 0.7h & 2.1x & 88.3 $\\pm$ 74.4\n",
      "One epoch baseline  & 65.57 $\\pm$ 5.53 & 0.3h $\\pm$ 0.0h & 9.2x & 1.0 $\\pm$ 0.0\n",
      "Random baseline  & 42.98 $\\pm$ 18.34 & 0.0h $\\pm$ 0.0h & NA & 0.0 $\\pm$ 0.0\n",
      "ImageNet16-120\n",
      "ASHA & 45.63 $\\pm$ 0.81 & 8.8h $\\pm$ 2.2h & 1.0x & 200.0 $\\pm$ 0.0\n",
      "PASHA & 46.01 $\\pm$ 1.00 & 3.2h $\\pm$ 1.0h & 2.8x & 28.6 $\\pm$ 27.7\n",
      "One epoch baseline  & 41.42 $\\pm$ 4.98 & 1.0h $\\pm$ 0.0h & 8.8x & 1.0 $\\pm$ 0.0\n",
      "Random baseline  & 20.97 $\\pm$ 10.01 & 0.0h $\\pm$ 0.0h & NA & 0.0 $\\pm$ 0.0\n"
     ]
    }
   ],
   "source": [
    "for dataset_name in dataset_names:\n",
    "    print(dataset_name)\n",
    "    result_summary, reference_time = analyse_experiments(experiment_names_asha)\n",
    "    print('ASHA' + result_summary)\n",
    "    result_summary, _ = analyse_experiments(experiment_names_pasha, reference_time)\n",
    "    print('PASHA' + result_summary)\n",
    "    result_summary = compute_one_epoch_baseline()\n",
    "    print('One epoch baseline', result_summary)\n",
    "    result_summary = compute_random_baseline()\n",
    "    print('Random baseline', result_summary)\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b844aa4-3af0-479f-8b3e-40de7198a4be",
   "metadata": {},
   "source": [
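    "PASHA obtains accuracy similar to ASHA's, but it finds a well-performing configuration considerably faster: 1.3x, 2.1x and 2.8x faster on CIFAR-10, CIFAR-100 and ImageNet16-120, respectively. It does so by training configurations for far fewer epochs at most (on average 144.5, 88.3 and 28.6 epochs instead of 200).\n",
    "\n",
    "The configurations found by the one epoch and random baselines obtain lower accuracies, markedly so for the random baseline and, on CIFAR-100 and ImageNet16-120, also for the one epoch baseline. This makes them unsuitable for finding well-performing configurations."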
   ]
  },
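  {
   "cell_type": "markdown",
   "id": "e91a3c5d-2f47-4b8e-9c6d-1a0f7e3b5d28",
   "metadata": {},
   "source": [
    "The speedup column is the mean runtime of the reference method (ASHA) divided by the mean runtime of the method, as computed in `analyse_experiments`. As a quick sanity check, this minimal sketch recomputes PASHA's speedups from the rounded mean runtimes (in hours) printed above; because those inputs are rounded to 0.1h, the ratios could in principle differ slightly from the table.\n",
    "\n",
    "```python\n",
    "# mean runtimes in hours, copied from the table above (rounded to 0.1h)\n",
    "asha_hours = {'cifar10': 3.0, 'cifar100': 3.2, 'ImageNet16-120': 8.8}\n",
    "pasha_hours = {'cifar10': 2.3, 'cifar100': 1.5, 'ImageNet16-120': 3.2}\n",
    "\n",
    "# speedup = reference (ASHA) runtime / PASHA runtime\n",
    "speedups = {name: round(asha_hours[name] / pasha_hours[name], 1) for name in asha_hours}\n",
    "print(speedups)  # {'cifar10': 1.3, 'cifar100': 2.1, 'ImageNet16-120': 2.8}\n",
    "```\n",
    "\n",
    "Here the recomputed values agree with the 1.3x, 2.1x and 2.8x reported in the table."
   ]
  },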
  {
   "cell_type": "markdown",
   "id": "df522ee5-c373-47c5-b5d3-34c859f83ae1",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Alternative ranking functions\n",
    "We evaluate a variety of ranking functions from several families:\n",
    "\n",
    "1) Direct ranking, where we check whether the rankings are exactly the same.\n",
    "\n",
    "2) Soft ranking with $\\epsilon$ selected from $[0.01, 0.02, 0.025, 0.03, 0.05]$, and variants of soft ranking where $\\epsilon$ is estimated from the values of the performance metric in the previous rung: as 1, 2 or 3 times their standard deviation, or as the mean or median distance between them.\n",
    "\n",
    "3) RBO (rank-biased overlap) score with $p=1.0$ and $p=0.5$, where $p$ sets the priority given to the top of the ranking: with $p=1.0$ all ranks have the same priority, and with $p=0.5$ each rank has twice the priority of the next one. We use a minimum threshold of $t=0.5$ to decide whether the rankings in the top two rungs are sufficiently similar.\n",
    "\n",
    "4) RRR (reciprocal rank regret) and ARRR (its absolute-value variant) scores, with the same $p=1.0$ and $p=0.5$ priorities as for RBO. We use a maximum threshold of $t=0.05$ to decide whether the rankings in the top two rungs are sufficiently similar."
   ]
  },
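  {
   "cell_type": "markdown",
   "id": "b4f81d2c-9e6a-4c37-a5d0-62e8f1c3b9a7",
   "metadata": {},
   "source": [
    "As an illustration of the RBO family, here is a minimal sketch of a finite, normalized rank-biased overlap between the rankings of two rungs. The normalization (dividing by the sum of the weights $p^{d-1}$, so that identical rankings score 1.0 and $p=1.0$ is well defined) is our assumption, and `rank_biased_overlap` is a hypothetical helper; PASHA's actual implementation may compute the score differently. The configuration IDs are made up.\n",
    "\n",
    "```python\n",
    "def rank_biased_overlap(ranking_a, ranking_b, p):\n",
    "    # hypothetical helper: finite RBO, normalized by the sum of the weights p**(d - 1)\n",
    "    depth = min(len(ranking_a), len(ranking_b))\n",
    "    weighted_sum = 0.0\n",
    "    total_weight = 0.0\n",
    "    for d in range(1, depth + 1):\n",
    "        # agreement at depth d: fraction of items shared by the two top-d prefixes\n",
    "        overlap = len(set(ranking_a[:d]) & set(ranking_b[:d])) / d\n",
    "        weight = p ** (d - 1)\n",
    "        weighted_sum += weight * overlap\n",
    "        total_weight += weight\n",
    "    return weighted_sum / total_weight\n",
    "\n",
    "\n",
    "# made-up configuration IDs, ordered best-first in the two top rungs\n",
    "top_rung = ['c1', 'c2', 'c3', 'c4']\n",
    "previous_rung = ['c4', 'c3', 'c2', 'c1']\n",
    "score = rank_biased_overlap(top_rung, previous_rung, p=1.0)\n",
    "print(round(score, 3), score >= 0.5)  # 0.417 False\n",
    "```\n",
    "\n",
    "For these fully reversed rankings, the score of about 0.417 falls below the threshold $t=0.5$, so under this sketch the rankings would not count as stable and PASHA would keep increasing the resources. With $p=0.5$, disagreements at the top of the ranking are penalized more heavily than those further down."
   ]
  },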
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "84ed0937-782d-4902-b108-4907eb49a6bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "\n",
    "experiment_names_pasha_ranking = {dataset: [] for dataset in dataset_names}\n",
    "\n",
    "experiment_names_pasha_e001 = {dataset: [] for dataset in dataset_names}\n",
    "experiment_names_pasha_e002 = {dataset: [] for dataset in dataset_names}\n",
    "experiment_names_pasha_e003 = {dataset: [] for dataset in dataset_names}\n",
    "experiment_names_pasha_e005 = {dataset: [] for dataset in dataset_names}\n",
    "experiment_names_pasha_std1 = {dataset: [] for dataset in dataset_names}\n",
    "experiment_names_pasha_std2 = {dataset: [] for dataset in dataset_names}\n",
    "experiment_names_pasha_std3 = {dataset: [] for dataset in dataset_names}\n",
    "experiment_names_pasha_mean_dst = {dataset: [] for dataset in dataset_names}\n",
    "experiment_names_pasha_med_dst = {dataset: [] for dataset in dataset_names}\n",
    "\n",
    "experiment_names_pasha_rbo_p1_t05 = {dataset: [] for dataset in dataset_names}\n",
    "experiment_names_pasha_rbo_p05_t05 = {dataset: [] for dataset in dataset_names}\n",
    "\n",
    "experiment_names_pasha_rrr_p1_t005 = {dataset: [] for dataset in dataset_names}\n",
    "experiment_names_pasha_rrr_p05_t005 = {dataset: [] for dataset in dataset_names}\n",
    "experiment_names_pasha_arrr_p1_t005 = {dataset: [] for dataset in dataset_names}\n",
    "experiment_names_pasha_arrr_p05_t005 = {dataset: [] for dataset in dataset_names}\n",
    "\n",
    "for dataset_name in dataset_names:\n",
    "    for nb201_random_seed in nb201_random_seeds:\n",
    "        for random_seed in random_seeds:\n",
    "            experiment_name = run_experiment(dataset_name, random_seed, nb201_random_seed, 'pasha', rung_system_kwargs={'ranking_criterion': 'ranking'})\n",
    "            experiment_names_pasha_ranking[dataset_name].append((experiment_name, nb201_random_seed))\n",
    "            \n",
    "            experiment_name = run_experiment(dataset_name, random_seed, nb201_random_seed, 'pasha', rung_system_kwargs={'ranking_criterion': 'soft_ranking', 'epsilon': 0.01})\n",
    "            experiment_names_pasha_e001[dataset_name].append((experiment_name, nb201_random_seed))\n",
    "            experiment_name = run_experiment(dataset_name, random_seed, nb201_random_seed, 'pasha', rung_system_kwargs={'ranking_criterion': 'soft_ranking', 'epsilon': 0.02})\n",
    "            experiment_names_pasha_e002[dataset_name].append((experiment_name, nb201_random_seed))\n",
    "            experiment_name = run_experiment(dataset_name, random_seed, nb201_random_seed, 'pasha', rung_system_kwargs={'ranking_criterion': 'soft_ranking', 'epsilon': 0.03})\n",
    "            experiment_names_pasha_e003[dataset_name].append((experiment_name, nb201_random_seed))\n",
    "            experiment_name = run_experiment(dataset_name, random_seed, nb201_random_seed, 'pasha', rung_system_kwargs={'ranking_criterion': 'soft_ranking', 'epsilon': 0.05})\n",
    "            experiment_names_pasha_e005[dataset_name].append((experiment_name, nb201_random_seed))\n",
    "            experiment_name = run_experiment(dataset_name, random_seed, nb201_random_seed, 'pasha', rung_system_kwargs={'ranking_criterion': 'soft_ranking_std', 'epsilon_scaling': 1.0})\n",
    "            experiment_names_pasha_std1[dataset_name].append((experiment_name, nb201_random_seed))\n",
    "            experiment_name = run_experiment(dataset_name, random_seed, nb201_random_seed, 'pasha', rung_system_kwargs={'ranking_criterion': 'soft_ranking_std', 'epsilon_scaling': 2.0})\n",
    "            experiment_names_pasha_std2[dataset_name].append((experiment_name, nb201_random_seed))\n",
    "            experiment_name = run_experiment(dataset_name, random_seed, nb201_random_seed, 'pasha', rung_system_kwargs={'ranking_criterion': 'soft_ranking_std', 'epsilon_scaling': 3.0})\n",
    "            experiment_names_pasha_std3[dataset_name].append((experiment_name, nb201_random_seed))\n",
    "            experiment_name = run_experiment(dataset_name, random_seed, nb201_random_seed, 'pasha', rung_system_kwargs={'ranking_criterion': 'soft_ranking_mean_dst', 'epsilon_scaling': 1.0})\n",
    "            experiment_names_pasha_mean_dst[dataset_name].append((experiment_name, nb201_random_seed))\n",
    "            experiment_name = run_experiment(dataset_name, random_seed, nb201_random_seed, 'pasha', rung_system_kwargs={'ranking_criterion': 'soft_ranking_median_dst', 'epsilon_scaling': 1.0})\n",
    "            experiment_names_pasha_med_dst[dataset_name].append((experiment_name, nb201_random_seed))\n",
    "            # RBO, RRR and ARRR: 'epsilon' is the threshold t and 'epsilon_scaling' is the priority p\n",
    "            experiment_name = run_experiment(dataset_name, random_seed, nb201_random_seed, 'pasha', rung_system_kwargs={'ranking_criterion': 'rbo', 'epsilon': 0.5, 'epsilon_scaling': 1.0})\n",
    "            experiment_names_pasha_rbo_p1_t05[dataset_name].append((experiment_name, nb201_random_seed))\n",
    "            experiment_name = run_experiment(dataset_name, random_seed, nb201_random_seed, 'pasha', rung_system_kwargs={'ranking_criterion': 'rbo', 'epsilon': 0.5, 'epsilon_scaling': 0.5})\n",
    "            experiment_names_pasha_rbo_p05_t05[dataset_name].append((experiment_name, nb201_random_seed))\n",
    "\n",
    "            experiment_name = run_experiment(dataset_name, random_seed, nb201_random_seed, 'pasha', rung_system_kwargs={'ranking_criterion': 'rrr', 'epsilon': 0.05, 'epsilon_scaling': 1.0})\n",
    "            experiment_names_pasha_rrr_p1_t005[dataset_name].append((experiment_name, nb201_random_seed))\n",
    "            experiment_name = run_experiment(dataset_name, random_seed, nb201_random_seed, 'pasha', rung_system_kwargs={'ranking_criterion': 'rrr', 'epsilon': 0.05, 'epsilon_scaling': 0.5})\n",
    "            experiment_names_pasha_rrr_p05_t005[dataset_name].append((experiment_name, nb201_random_seed))\n",
    "            experiment_name = run_experiment(dataset_name, random_seed, nb201_random_seed, 'pasha', rung_system_kwargs={'ranking_criterion': 'arrr', 'epsilon': 0.05, 'epsilon_scaling': 1.0})\n",
    "            experiment_names_pasha_arrr_p1_t005[dataset_name].append((experiment_name, nb201_random_seed))\n",
    "            experiment_name = run_experiment(dataset_name, random_seed, nb201_random_seed, 'pasha', rung_system_kwargs={'ranking_criterion': 'arrr', 'epsilon': 0.05, 'epsilon_scaling': 0.5})\n",
    "            experiment_names_pasha_arrr_p05_t005[dataset_name].append((experiment_name, nb201_random_seed))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "437f0bb5-b331-4cc2-939d-ac0941ffb780",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cifar10\n",
      "ASHA & 93.85 $\\pm$ 0.25 & 3.0h $\\pm$ 0.6h & 1.0x & 200.0 $\\pm$ 0.0\n",
      "PASHA direct ranking & 93.79 $\\pm$ 0.26 & 2.7h $\\pm$ 0.6h & 1.1x & 198.4 $\\pm$ 6.0\n",
      "PASHA soft ranking $\\epsilon=0.01$ & 93.79 $\\pm$ 0.26 & 2.6h $\\pm$ 0.5h & 1.1x & 194.3 $\\pm$ 21.2\n",
      "PASHA soft ranking $\\epsilon=0.02$ & 93.78 $\\pm$ 0.31 & 2.4h $\\pm$ 0.5h & 1.2x & 152.4 $\\pm$ 58.3\n",
      "PASHA soft ranking $\\epsilon=0.025$ & 93.78 $\\pm$ 0.31 & 2.3h $\\pm$ 0.5h & 1.3x & 144.5 $\\pm$ 59.4\n",
      "PASHA soft ranking $\\epsilon=0.03$ & 93.78 $\\pm$ 0.32 & 2.2h $\\pm$ 0.6h & 1.3x & 128.6 $\\pm$ 58.3\n",
      "PASHA soft ranking $\\epsilon=0.05$ & 93.79 $\\pm$ 0.49 & 1.8h $\\pm$ 0.7h & 1.6x & 76.0 $\\pm$ 66.0\n",
      "PASHA soft ranking $1\\sigma$ & 93.75 $\\pm$ 0.32 & 2.4h $\\pm$ 0.5h & 1.2x & 186.4 $\\pm$ 35.2\n",
      "PASHA soft ranking $2\\sigma$ & 93.88 $\\pm$ 0.28 & 1.9h $\\pm$ 0.5h & 1.5x & 132.7 $\\pm$ 68.7\n",
      "PASHA soft ranking $3\\sigma$ & 93.56 $\\pm$ 0.69 & 0.9h $\\pm$ 0.3h & 3.1x & 16.2 $\\pm$ 19.9\n",
      "PASHA soft ranking mean distance & 93.73 $\\pm$ 0.52 & 2.3h $\\pm$ 0.4h & 1.3x & 184.1 $\\pm$ 40.5\n",
      "PASHA soft ranking median distance & 93.82 $\\pm$ 0.26 & 2.3h $\\pm$ 0.5h & 1.3x & 169.2 $\\pm$ 51.2\n",
      "PASHA RBO p=1.0, t=0.5 & 93.49 $\\pm$ 0.78 & 0.7h $\\pm$ 0.1h & 4.2x & 4.6 $\\pm$ 6.0\n",
      "PASHA RBO p=0.5, t=0.5 & 93.77 $\\pm$ 0.35 & 2.2h $\\pm$ 0.6h & 1.3x & 144.0 $\\pm$ 71.2\n",
      "PASHA RRR p=1.0, t=0.05 & 93.49 $\\pm$ 0.78 & 0.7h $\\pm$ 0.0h & 4.4x & 3.0 $\\pm$ 0.0\n",
      "PASHA RRR p=0.5, t=0.05 & 93.76 $\\pm$ 0.31 & 2.1h $\\pm$ 0.6h & 1.4x & 140.9 $\\pm$ 69.7\n",
      "PASHA ARRR p=1.0, t=0.05 & 93.71 $\\pm$ 0.35 & 2.4h $\\pm$ 0.4h & 1.2x & 179.0 $\\pm$ 42.9\n",
      "PASHA ARRR p=0.5, t=0.05 & 93.81 $\\pm$ 0.30 & 2.5h $\\pm$ 0.4h & 1.2x & 181.0 $\\pm$ 40.9\n",
      "cifar100\n",
      "ASHA & 71.69 $\\pm$ 1.05 & 3.2h $\\pm$ 0.9h & 1.0x & 200.0 $\\pm$ 0.0\n",
      "PASHA direct ranking & 71.69 $\\pm$ 1.05 & 2.8h $\\pm$ 0.7h & 1.1x & 200.0 $\\pm$ 0.0\n",
      "PASHA soft ranking $\\epsilon=0.01$ & 71.55 $\\pm$ 1.04 & 2.5h $\\pm$ 0.7h & 1.3x & 198.3 $\\pm$ 6.5\n",
      "PASHA soft ranking $\\epsilon=0.02$ & 70.94 $\\pm$ 0.85 & 2.0h $\\pm$ 0.5h & 1.6x & 160.5 $\\pm$ 62.9\n",
      "PASHA soft ranking $\\epsilon=0.025$ & 71.41 $\\pm$ 1.15 & 1.5h $\\pm$ 0.7h & 2.1x & 88.3 $\\pm$ 74.4\n",
      "PASHA soft ranking $\\epsilon=0.03$ & 71.00 $\\pm$ 1.38 & 1.0h $\\pm$ 0.5h & 3.2x & 39.4 $\\pm$ 63.4\n",
      "PASHA soft ranking $\\epsilon=0.05$ & 70.71 $\\pm$ 1.66 & 0.7h $\\pm$ 0.0h & 4.9x & 3.0 $\\pm$ 0.0\n",
      "PASHA soft ranking $1\\sigma$ & 71.56 $\\pm$ 1.03 & 2.5h $\\pm$ 0.6h & 1.3x & 184.1 $\\pm$ 40.5\n",
      "PASHA soft ranking $2\\sigma$ & 71.14 $\\pm$ 0.97 & 1.9h $\\pm$ 0.7h & 1.7x & 136.4 $\\pm$ 75.8\n",
      "PASHA soft ranking $3\\sigma$ & 71.63 $\\pm$ 1.60 & 1.0h $\\pm$ 0.3h & 3.3x & 20.2 $\\pm$ 25.3\n",
      "PASHA soft ranking mean distance & 71.51 $\\pm$ 0.99 & 2.4h $\\pm$ 0.5h & 1.4x & 189.8 $\\pm$ 30.3\n",
      "PASHA soft ranking median distance & 71.52 $\\pm$ 0.98 & 2.4h $\\pm$ 0.6h & 1.3x & 189.5 $\\pm$ 30.6\n",
      "PASHA RBO p=1.0, t=0.5 & 70.69 $\\pm$ 1.67 & 0.7h $\\pm$ 0.1h & 4.6x & 3.8 $\\pm$ 2.0\n",
      "PASHA RBO p=0.5, t=0.5 & 71.51 $\\pm$ 0.93 & 2.4h $\\pm$ 0.7h & 1.3x & 180.5 $\\pm$ 50.6\n",
      "PASHA RRR p=1.0, t=0.05 & 70.71 $\\pm$ 1.66 & 0.7h $\\pm$ 0.0h & 4.9x & 3.0 $\\pm$ 0.0\n",
      "PASHA RRR p=0.5, t=0.05 & 71.42 $\\pm$ 1.51 & 1.2h $\\pm$ 0.5h & 2.6x & 39.3 $\\pm$ 51.4\n",
      "PASHA ARRR p=1.0, t=0.05 & 70.80 $\\pm$ 1.70 & 0.8h $\\pm$ 0.4h & 3.8x & 22.9 $\\pm$ 51.3\n",
      "PASHA ARRR p=0.5, t=0.05 & 71.41 $\\pm$ 1.05 & 1.8h $\\pm$ 0.6h & 1.7x & 110.0 $\\pm$ 68.7\n",
      "ImageNet16-120\n",
      "ASHA & 45.63 $\\pm$ 0.81 & 8.8h $\\pm$ 2.2h & 1.0x & 200.0 $\\pm$ 0.0\n",
      "PASHA direct ranking & 45.63 $\\pm$ 0.81 & 8.3h $\\pm$ 2.5h & 1.1x & 200.0 $\\pm$ 0.0\n",
      "PASHA soft ranking $\\epsilon=0.01$ & 45.52 $\\pm$ 0.89 & 7.0h $\\pm$ 1.5h & 1.3x & 185.7 $\\pm$ 36.1\n",
      "PASHA soft ranking $\\epsilon=0.02$ & 45.79 $\\pm$ 1.16 & 4.4h $\\pm$ 1.4h & 2.0x & 71.4 $\\pm$ 50.8\n",
      "PASHA soft ranking $\\epsilon=0.025$ & 46.01 $\\pm$ 1.00 & 3.2h $\\pm$ 1.0h & 2.8x & 28.6 $\\pm$ 27.7\n",
      "PASHA soft ranking $\\epsilon=0.03$ & 45.62 $\\pm$ 1.48 & 2.4h $\\pm$ 0.7h & 3.6x & 11.0 $\\pm$ 10.0\n",
      "PASHA soft ranking $\\epsilon=0.05$ & 44.90 $\\pm$ 1.42 & 1.8h $\\pm$ 0.0h & 5.0x & 3.0 $\\pm$ 0.0\n",
      "PASHA soft ranking $1\\sigma$ & 45.63 $\\pm$ 0.89 & 6.5h $\\pm$ 1.3h & 1.4x & 177.1 $\\pm$ 44.2\n",
      "PASHA soft ranking $2\\sigma$ & 45.39 $\\pm$ 1.22 & 4.5h $\\pm$ 1.4h & 1.9x & 91.2 $\\pm$ 58.0\n",
      "PASHA soft ranking $3\\sigma$ & 44.90 $\\pm$ 1.42 & 1.8h $\\pm$ 0.0h & 5.0x & 3.0 $\\pm$ 0.0\n",
      "PASHA soft ranking mean distance & 45.50 $\\pm$ 1.12 & 6.2h $\\pm$ 1.5h & 1.4x & 157.7 $\\pm$ 54.7\n",
      "PASHA soft ranking median distance & 45.67 $\\pm$ 0.95 & 6.3h $\\pm$ 1.6h & 1.4x & 156.3 $\\pm$ 52.2\n",
      "PASHA RBO p=1.0, t=0.5 & 44.90 $\\pm$ 1.42 & 1.8h $\\pm$ 0.0h & 5.0x & 3.0 $\\pm$ 0.0\n",
      "PASHA RBO p=0.5, t=0.5 & 45.24 $\\pm$ 1.13 & 6.4h $\\pm$ 1.3h & 1.4x & 148.3 $\\pm$ 56.9\n",
      "PASHA RRR p=1.0, t=0.05 & 44.90 $\\pm$ 1.42 & 1.8h $\\pm$ 0.0h & 5.0x & 3.0 $\\pm$ 0.0\n",
      "PASHA RRR p=0.5, t=0.05 & 44.90 $\\pm$ 1.42 & 1.8h $\\pm$ 0.0h & 5.0x & 3.0 $\\pm$ 0.0\n",
      "PASHA ARRR p=1.0, t=0.05 & 44.90 $\\pm$ 1.42 & 1.8h $\\pm$ 0.0h & 5.0x & 3.0 $\\pm$ 0.0\n",
      "PASHA ARRR p=0.5, t=0.05 & 44.90 $\\pm$ 1.42 & 1.8h $\\pm$ 0.0h & 5.0x & 3.0 $\\pm$ 0.0\n"
     ]
    }
   ],
   "source": [
    "for dataset_name in dataset_names:\n",
    "    print(dataset_name)\n",
    "    result_summary, reference_time = analyse_experiments(experiment_names_asha)\n",
    "    print('ASHA' + result_summary)\n",
    "    \n",
    "    result_summary, _ = analyse_experiments(experiment_names_pasha_ranking, reference_time)\n",
    "    print('PASHA direct ranking' + result_summary)\n",
    "    \n",
    "    result_summary, _ = analyse_experiments(experiment_names_pasha_e001, reference_time)\n",
    "    print('PASHA soft ranking $\\epsilon=0.01$' + result_summary)\n",
    "    result_summary, _ = analyse_experiments(experiment_names_pasha_e002, reference_time)\n",
    "    print('PASHA soft ranking $\\epsilon=0.02$' + result_summary)\n",
    "    result_summary, _ = analyse_experiments(experiment_names_pasha, reference_time)\n",
    "    print('PASHA soft ranking $\\epsilon=0.025$' + result_summary)\n",
    "    result_summary, _ = analyse_experiments(experiment_names_pasha_e003, reference_time)\n",
    "    print('PASHA soft ranking $\\epsilon=0.03$' + result_summary)\n",
    "    result_summary, _ = analyse_experiments(experiment_names_pasha_e005, reference_time)\n",
    "    print('PASHA soft ranking $\\epsilon=0.05$' + result_summary)\n",
    "    result_summary, _ = analyse_experiments(experiment_names_pasha_std1, reference_time)\n",
    "    print('PASHA soft ranking $1\\sigma$' + result_summary)\n",
    "    result_summary, _ = analyse_experiments(experiment_names_pasha_std2, reference_time)\n",
    "    print('PASHA soft ranking $2\\sigma$' + result_summary)\n",
    "    result_summary, _ = analyse_experiments(experiment_names_pasha_std3, reference_time)\n",
    "    print('PASHA soft ranking $3\\sigma$' + result_summary)\n",
    "    result_summary, _ = analyse_experiments(experiment_names_pasha_mean_dst, reference_time)\n",
    "    print('PASHA soft ranking mean distance' + result_summary)\n",
    "    result_summary, _ = analyse_experiments(experiment_names_pasha_med_dst, reference_time)\n",
    "    print('PASHA soft ranking median distance' + result_summary)\n",
    "    \n",
    "    result_summary, _ = analyse_experiments(experiment_names_pasha_rbo_p1_t05, reference_time)\n",
    "    print('PASHA RBO p=1.0, t=0.5' + result_summary)\n",
    "    result_summary, _ = analyse_experiments(experiment_names_pasha_rbo_p05_t05, reference_time)\n",
    "    print('PASHA RBO p=0.5, t=0.5' + result_summary)\n",
    "\n",
    "    result_summary, _ = analyse_experiments(experiment_names_pasha_rrr_p1_t005, reference_time)\n",
    "    print('PASHA RRR p=1.0, t=0.05' + result_summary)\n",
    "    result_summary, _ = analyse_experiments(experiment_names_pasha_rrr_p05_t005, reference_time)\n",
    "    print('PASHA RRR p=0.5, t=0.05' + result_summary)\n",
    "    result_summary, _ = analyse_experiments(experiment_names_pasha_arrr_p1_t005, reference_time)\n",
    "    print('PASHA ARRR p=1.0, t=0.05' + result_summary)\n",
    "    result_summary, _ = analyse_experiments(experiment_names_pasha_arrr_p05_t005, reference_time)\n",
    "    print('PASHA ARRR p=0.5, t=0.05' + result_summary)"
   ]
  },
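  {
   "cell_type": "markdown",
   "id": "5acea9d8-61ca-4968-aea0-78debe67be59",
   "metadata": {},
   "source": [
    "Several alternative ranking functions also work well. For example, soft ranking based on the median distance stays within 0.2 percentage points of the ASHA accuracy on all three datasets (93.82 vs 93.85 on cifar10, 71.52 vs 71.69 on cifar100 and 45.67 vs 45.63 on ImageNet16-120) while giving a 1.3-1.4x speedup. In contrast, RBO and RRR with $p=1.0$ stop increasing the resources almost immediately (the maximum resources stay at or below 4.6), which yields the largest speedups but lower accuracy on all three datasets; on ImageNet16-120, all RRR and ARRR variants behave this way."
   ]
  },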
  {
   "cell_type": "markdown",
   "id": "3217caa4-eefe-4ffe-9816-8d486a651d4f",
   "metadata": {},
   "source": [
    "## Changes to the reduction factor"
   ]
  },
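  {
   "cell_type": "markdown",
   "id": "rung-levels-sketch-md",
   "metadata": {},
   "source": [
    "In successive-halving schedulers such as ASHA, the reduction factor $\\eta$ sets both the spacing of the rung levels (each rung uses $\\eta$ times more resources than the previous one) and the fraction $1/\\eta$ of configurations promoted from each rung. The next cell is a minimal sketch that lists the rung levels for $\\eta = 2, 3, 4$, assuming a minimum resource of 1 epoch and a maximum of 200 epochs. The helper `rung_levels` is hypothetical and only approximates how the scheduler builds its rungs. A smaller $\\eta$ yields more rungs that are closer together, so PASHA can compare rankings at finer resource increments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "rung-levels-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical sketch of geometric rung levels in successive halving.\n",
    "# Assumes a minimum resource of 1 epoch and a maximum resource of 200 epochs.\n",
    "def rung_levels(min_resource, max_resource, reduction_factor):\n",
    "    levels = []\n",
    "    level = min_resource\n",
    "    while level < max_resource:\n",
    "        levels.append(level)\n",
    "        # each rung receives reduction_factor times more epochs than the previous one\n",
    "        level *= reduction_factor\n",
    "    # configurations that survive every rung are trained with the full budget\n",
    "    levels.append(max_resource)\n",
    "    return levels\n",
    "\n",
    "\n",
    "for eta in [2, 3, 4]:\n",
    "    print(f'reduction factor {eta}: rungs at {rung_levels(1, 200, eta)} epochs, top 1/{eta} promoted')"
   ]
  },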
  {
   "cell_type": "markdown",
   "id": "0887a909-cdc5-40e8-ba76-e811c7126856",
   "metadata": {},
   "source": [
    "Reduction factor of 2:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "45781164-8e77-4b77-ac8b-fd5e5a3a632b",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "\n",
    "experiment_names_pasha_rf2 = {dataset: [] for dataset in dataset_names}\n",
    "experiment_names_asha_rf2 = {dataset: [] for dataset in dataset_names}\n",
    "\n",
    "for dataset_name in dataset_names:\n",
    "    for nb201_random_seed in nb201_random_seeds:\n",
    "        for random_seed in random_seeds:\n",
    "            experiment_name = run_experiment(dataset_name, random_seed, nb201_random_seed, 'pasha', reduction_factor=2)\n",
    "            experiment_names_pasha_rf2[dataset_name].append((experiment_name, nb201_random_seed))\n",
    "            experiment_name = run_experiment(dataset_name, random_seed, nb201_random_seed, 'asha', reduction_factor=2)\n",
    "            experiment_names_asha_rf2[dataset_name].append((experiment_name, nb201_random_seed))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "bf4a9ae3-57ed-4d0f-826a-34eadc7f7dab",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cifar10\n",
      "ASHA & 93.88 $\\pm$ 0.27 & 3.6h $\\pm$ 1.1h & 1.0x & 200.0 $\\pm$ 0.0\n",
      "PASHA & 93.77 $\\pm$ 0.56 & 2.6h $\\pm$ 0.6h & 1.4x & 134.9 $\\pm$ 52.7\n",
      "cifar100\n",
      "ASHA & 71.67 $\\pm$ 0.84 & 3.8h $\\pm$ 1.0h & 1.0x & 200.0 $\\pm$ 0.0\n",
      "PASHA & 71.79 $\\pm$ 1.38 & 2.1h $\\pm$ 0.7h & 1.8x & 101.5 $\\pm$ 55.0\n",
      "ImageNet16-120\n",
      "ASHA & 46.09 $\\pm$ 0.68 & 11.9h $\\pm$ 4.0h & 1.0x & 200.0 $\\pm$ 0.0\n",
      "PASHA & 45.72 $\\pm$ 1.36 & 4.2h $\\pm$ 1.7h & 2.8x & 49.5 $\\pm$ 44.2\n"
     ]
    }
   ],
   "source": [
    "for dataset_name in dataset_names:\n",
    "    print(dataset_name)\n",
    "    result_summary, reference_time = analyse_experiments(experiment_names_asha_rf2)\n",
    "    print('ASHA' + result_summary)\n",
    "    result_summary, _ = analyse_experiments(experiment_names_pasha_rf2, reference_time)\n",
    "    print('PASHA' + result_summary)\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59e1474b-4700-4970-856c-5cecf67ce054",
   "metadata": {},
   "source": [
    "Reduction factor of 4:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "43590a29-e329-4436-ab50-3df5fb6d426b",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "\n",
    "experiment_names_pasha_rf4 = {dataset: [] for dataset in dataset_names}\n",
    "experiment_names_asha_rf4 = {dataset: [] for dataset in dataset_names}\n",
    "\n",
    "for dataset_name in dataset_names:\n",
    "    for nb201_random_seed in nb201_random_seeds:\n",
    "        for random_seed in random_seeds:\n",
    "            experiment_name = run_experiment(dataset_name, random_seed, nb201_random_seed, 'pasha', reduction_factor=4)\n",
    "            experiment_names_pasha_rf4[dataset_name].append((experiment_name, nb201_random_seed))\n",
    "            experiment_name = run_experiment(dataset_name, random_seed, nb201_random_seed, 'asha', reduction_factor=4)\n",
    "            experiment_names_asha_rf4[dataset_name].append((experiment_name, nb201_random_seed))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "b1b39983-985c-4204-b609-ca68c9e7f034",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cifar10\n",
      "ASHA & 93.75 $\\pm$ 0.28 & 2.4h $\\pm$ 0.6h & 1.0x & 200.0 $\\pm$ 0.0\n",
      "PASHA & 93.72 $\\pm$ 0.30 & 2.1h $\\pm$ 0.5h & 1.2x & 154.7 $\\pm$ 64.1\n",
      "cifar100\n",
      "ASHA & 71.43 $\\pm$ 1.13 & 2.7h $\\pm$ 0.9h & 1.0x & 200.0 $\\pm$ 0.0\n",
      "PASHA & 71.76 $\\pm$ 1.20 & 1.1h $\\pm$ 0.5h & 2.4x & 59.2 $\\pm$ 74.8\n",
      "ImageNet16-120\n",
      "ASHA & 45.43 $\\pm$ 0.98 & 7.9h $\\pm$ 3.0h & 1.0x & 200.0 $\\pm$ 0.0\n",
      "PASHA & 45.48 $\\pm$ 1.36 & 3.2h $\\pm$ 1.7h & 2.4x & 40.3 $\\pm$ 49.8\n"
     ]
    }
   ],
   "source": [
    "for dataset_name in dataset_names:\n",
    "    print(dataset_name)\n",
    "    result_summary, reference_time = analyse_experiments(experiment_names_asha_rf4)\n",
    "    print('ASHA' + result_summary)\n",
    "    result_summary, _ = analyse_experiments(experiment_names_pasha_rf4, reference_time)\n",
    "    print('PASHA' + result_summary)\n",
    "    "
   ]
  },
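  {
   "cell_type": "markdown",
   "id": "d2cfcc7e-70d2-4c55-baf4-6d899ae0d4ff",
   "metadata": {},
   "source": [
    "PASHA gives a speedup over ASHA for both reduction factors (1.4-2.8x with a reduction factor of 2 and 1.2-2.4x with a reduction factor of 4), while its accuracy stays within 0.4 percentage points of ASHA."
   ]
  },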
  {
   "cell_type": "markdown",
   "id": "ef25fa85-ee79-4bf7-a049-2a1494d86549",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Combination with Bayesian Optimization\n",
    "We explore whether PASHA can be successfully combined with more complex search strategies based on Bayesian Optimization.\n",
    "\n",
    "These experiments take longer to run, so we run them separately with a script that can be distributed across many nodes. We still use the simulator backend, but fitting the Gaussian process models takes considerable time. We provide a script, `run_bo_experiments.py`, that runs such experiments; it accepts arguments such as the selected random seed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "ffb727a3-2fa5-45ba-b10a-df0e12ebc8fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the experiment names from the JSON file in which they are stored\n",
    "with open('bo_experiment_details.json', 'r') as f:\n",
    "    bo_experiment_details = json.load(f)\n",
    "\n",
    "experiment_names_pasha_bo = {dataset: [] for dataset in dataset_names}\n",
    "experiment_names_asha_bo = {dataset: [] for dataset in dataset_names}\n",
    "\n",
    "# load the details of the experiments\n",
    "for detail_dict in bo_experiment_details:\n",
    "    if detail_dict['scheduler'] == 'asha-bo':\n",
    "        experiment_names_asha_bo[detail_dict['dataset_name']].append((detail_dict['experiment_name'], detail_dict['nb201_random_seed']))\n",
    "    elif detail_dict['scheduler'] == 'pasha-bo':\n",
    "        experiment_names_pasha_bo[detail_dict['dataset_name']].append((detail_dict['experiment_name'], detail_dict['nb201_random_seed']))\n"
   ]
  },
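  {
   "cell_type": "markdown",
   "id": "bo-details-format-sketch-md",
   "metadata": {},
   "source": [
    "The loading cell above only reads four keys from each entry of `bo_experiment_details.json`: `scheduler`, `dataset_name`, `experiment_name` and `nb201_random_seed`. If you assemble this file yourself, the next cell sketches a compatible layout and the grouping the loader performs. The entries and experiment names are hypothetical placeholders, the real file may contain additional keys, and the sketch writes to a temporary directory so that `bo_experiment_details.json` is left untouched."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bo-details-format-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import tempfile\n",
    "from collections import defaultdict\n",
    "\n",
    "# Hypothetical entries containing only the keys read by the loading cell\n",
    "example_bo_details = [\n",
    "    {'scheduler': 'asha-bo', 'dataset_name': 'cifar10', 'experiment_name': 'asha-bo-example-0', 'nb201_random_seed': 0},\n",
    "    {'scheduler': 'pasha-bo', 'dataset_name': 'cifar10', 'experiment_name': 'pasha-bo-example-0', 'nb201_random_seed': 0},\n",
    "]\n",
    "\n",
    "# Round-trip through a temporary JSON file instead of the real bo_experiment_details.json\n",
    "with tempfile.TemporaryDirectory() as tmp_dir:\n",
    "    example_path = os.path.join(tmp_dir, 'bo_experiment_details_example.json')\n",
    "    with open(example_path, 'w') as example_file:\n",
    "        json.dump(example_bo_details, example_file, indent=1)\n",
    "    with open(example_path, 'r') as example_file:\n",
    "        reloaded_details = json.load(example_file)\n",
    "\n",
    "# Group (experiment_name, nb201_random_seed) pairs by scheduler and dataset, as the loader does\n",
    "example_grouped = defaultdict(lambda: defaultdict(list))\n",
    "for entry in reloaded_details:\n",
    "    example_grouped[entry['scheduler']][entry['dataset_name']].append(\n",
    "        (entry['experiment_name'], entry['nb201_random_seed']))\n",
    "print({scheduler: dict(by_dataset) for scheduler, by_dataset in example_grouped.items()})"
   ]
  },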
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "f793138f-1889-49bc-8045-f607b8c9aba7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cifar10\n",
      "ASHA & 94.10 $\\pm$ 0.22 & 5.0h $\\pm$ 1.3h & 1.0x & 200.0 $\\pm$ 0.0\n",
      "PASHA & 94.17 $\\pm$ 0.17 & 4.4h $\\pm$ 1.7h & 1.1x & 156.7 $\\pm$ 62.4\n",
      "cifar100\n",
      "ASHA & 72.76 $\\pm$ 0.64 & 5.6h $\\pm$ 2.0h & 1.0x & 200.0 $\\pm$ 0.0\n",
      "PASHA & 72.07 $\\pm$ 1.80 & 3.8h $\\pm$ 1.7h & 1.5x & 157.9 $\\pm$ 72.7\n",
      "ImageNet16-120\n",
      "ASHA & 45.79 $\\pm$ 1.18 & 13.8h $\\pm$ 5.0h & 1.0x & 200.0 $\\pm$ 0.0\n",
      "PASHA & 45.02 $\\pm$ 1.15 & 6.2h $\\pm$ 5.8h & 2.2x & 50.0 $\\pm$ 75.4\n"
     ]
    }
   ],
   "source": [
    "for dataset_name in dataset_names:\n",
    "    print(dataset_name)\n",
    "    result_summary, reference_time = analyse_experiments(experiment_names_asha_bo)\n",
    "    print('ASHA' + result_summary)\n",
    "    result_summary, _ = analyse_experiments(experiment_names_pasha_bo, reference_time)\n",
    "    print('PASHA' + result_summary)\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44eab0c9-0426-4894-88d7-d0b96a57ced3",
   "metadata": {},
   "source": [
    "Alternatively, you can uncomment and run the following two cells, but the first one takes several hours to complete."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "9b5fbcfd-9e54-4323-818a-33c5772fff2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%capture\n",
    "\n",
    "# experiment_names_pasha_bo = {dataset: [] for dataset in dataset_names}\n",
    "# experiment_names_asha_bo = {dataset: [] for dataset in dataset_names}\n",
    "\n",
    "# for dataset_name in dataset_names:\n",
    "#     for nb201_random_seed in nb201_random_seeds:\n",
    "#         for random_seed in random_seeds:\n",
    "#             experiment_name = run_experiment(dataset_name, random_seed, nb201_random_seed, 'pasha-bo')\n",
    "#             experiment_names_pasha_bo[dataset_name].append((experiment_name, nb201_random_seed))\n",
    "#             experiment_name = run_experiment(dataset_name, random_seed, nb201_random_seed, 'asha-bo')\n",
    "#             experiment_names_asha_bo[dataset_name].append((experiment_name, nb201_random_seed))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "0cf8a243-ee75-41eb-bd9c-26fcb4eadd91",
   "metadata": {},
   "outputs": [],
   "source": [
    "# for dataset_name in dataset_names:\n",
    "#     print(dataset_name)\n",
    "#     result_summary, reference_time = analyse_experiments(experiment_names_asha_bo)\n",
    "#     print('ASHA' + result_summary)\n",
    "#     result_summary, _ = analyse_experiments(experiment_names_pasha_bo, reference_time)\n",
    "#     print('PASHA' + result_summary)\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8ca730f-f07c-4f93-811f-ae36b767541a",
   "metadata": {},
   "source": [
    "To generate the lists of configurations for a Slurm script that runs `run_bo_experiments.py`, uncomment and run the following cell:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "0bd0936c-0627-451d-b0d2-9784b50efa94",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # we need to generate the lists with values to use\n",
    "# dataset_names = ['cifar10', 'cifar100', 'ImageNet16-120']\n",
    "# approaches = ['pasha-bo', 'asha-bo']\n",
    "# random_seeds = [31415927, 0, 1234, 3458, 7685]\n",
    "# nb201_random_seeds = [0, 1, 2]\n",
    "\n",
    "# dataset_names_list = []\n",
    "# approaches_list = []\n",
    "# random_seeds_list = []\n",
    "# nb201_random_seeds_list = []\n",
    "\n",
    "# for dataset_name in dataset_names:\n",
    "#     for nb201_random_seed in nb201_random_seeds:\n",
    "#         for random_seed in random_seeds:\n",
    "#             for approach in approaches:\n",
    "#                 dataset_names_list.append(dataset_name)\n",
    "#                 approaches_list.append(approach)\n",
    "#                 random_seeds_list.append(random_seed)\n",
    "#                 nb201_random_seeds_list.append(nb201_random_seed)\n",
    "    \n",
    "# print('DATASET_NAME=(')\n",
    "# for dataset_name in dataset_names_list:\n",
    "#     print(dataset_name)\n",
    "# print(')')\n",
    "# print()\n",
    "\n",
    "# print('APPROACH=(')\n",
    "# for approach in approaches_list:\n",
    "#     print(approach)\n",
    "# print(')')\n",
    "# print()\n",
    "\n",
    "# print('RANDOM_SEED=(')\n",
    "# for random_seed in random_seeds_list:\n",
    "#     print(random_seed)\n",
    "# print(')')\n",
    "# print()\n",
    "\n",
    "# print('NB201_RANDOM_SEED=(')\n",
    "# for nb201_random_seed in nb201_random_seeds_list:\n",
    "#     print(nb201_random_seed)\n",
    "# print(')')\n",
    "# print()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
