{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e400615e-3e2b-4410-9f56-03461c6900bc",
   "metadata": {},
   "source": [
    "# PASHA: Efficient HPO with Progressive Resource Allocation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f3b015b-603e-4c19-863f-b91d1291a613",
   "metadata": {},
   "source": [
    "Hyperparameter optimization and neural architecture search are important for obtaining\n",
    "well-performing models, but they are costly in practice, especially for large datasets.\n",
    "To decrease the cost, practitioners adopt heuristics with mixed results. We propose an approach\n",
    "to tackle the challenge: start with a small amount of resources and progressively increase them\n",
    "as needed. Our approach, named PASHA, measures the stability of the ranking of different hyperparameter\n",
    "configurations and stops increasing the resources once the ranking becomes stable, returning\n",
    "the best configuration. Our experiments show that PASHA significantly accelerates multi-fidelity methods\n",
    "and obtains similarly well-performing hyperparameters."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d2112eca-63c9-4b13-8cb0-fd3864675191",
   "metadata": {},
   "source": [
    "Outline:\n",
    "* Initial pre-processing\n",
    "* Main experiments on NASBench201 with PASHA, ASHA and the baselines\n",
    "* Alternative ranking functions\n",
    "* Changes to the reduction factor\n",
    "* Combination with Bayesian Optimization"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3718f715-01a9-420b-be42-92ef138ae7a5",
   "metadata": {},
   "source": [
    "The libraries required to run this notebook are the same as those required for Syne Tune."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa2c80fa-5268-49d7-9ebc-0bd4c52f2dbf",
   "metadata": {},
   "source": [
    "Start by importing the relevant libraries:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ba934ed4-2655-4ec2-97f7-1106c29171a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import logging\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import random\n",
    "\n",
    "from benchmarking.blackbox_repository import load\n",
    "from benchmarking.blackbox_repository.tabulated_benchmark import BlackboxRepositoryBackend\n",
    "from benchmarking.definitions.definition_nasbench201 import nasbench201_benchmark, nasbench201_default_params\n",
    "from syne_tune.backend.simulator_backend.simulator_callback import SimulatorCallback\n",
    "from syne_tune.optimizer.schedulers.hyperband import HyperbandScheduler\n",
    "from syne_tune.optimizer.baselines import baselines\n",
    "from syne_tune.tuner import Tuner\n",
    "from syne_tune.stopping_criterion import StoppingCriterion\n",
    "from syne_tune.experiments import load_experiment\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17b685b1-824d-477e-a9e3-ecbc06d19adf",
   "metadata": {},
   "source": [
    "Define our settings:"
   ]
  },
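  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "409df06e-280d-4a87-8afe-40a7db28e13c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# indices of the validation error and the runtime along the last axis of `objectives_evaluations`\n",
    "metric_valid_error_dim = 0\n",
    "metric_runtime_dim = 2\n",
    "dataset_names = ['cifar10', 'cifar100', 'ImageNet16-120']\n",
    "# names of the columns with the validation accuracy after each of the 200 epochs\n",
    "epoch_names = ['val_acc_epoch_' + str(e) for e in range(200)]\n",
    "# seeds for the HPO runs and the NASBench201 seeds\n",
    "random_seeds = [31415927, 0, 1234, 3458, 7685]\n",
    "nb201_random_seeds = [0, 1, 2]\n",
    "n_workers = 4"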
   ]
  },
  {
   "cell_type": "markdown",
   "id": "810f0609-63a8-4714-b328-9489c601ef9a",
   "metadata": {},
   "source": [
    "## Initial pre-processing"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6f2cd7e-c425-47eb-88c1-3fd5a248d8bd",
   "metadata": {},
   "source": [
    "Load the NASBench201 benchmark so that we can analyse the performance of the various approaches:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f33be501-3945-4603-9521-4d65a4404145",
   "metadata": {},
   "outputs": [],
   "source": [
    "bb_dict = load('nasbench201')\n",
    "df_dict = {}\n",
    "\n",
    "for seed in nb201_random_seeds:\n",
    "    df_dict[seed] = {}\n",
    "    for dataset in dataset_names:\n",
    "        # create a dataframe with the validation accuracies for various epochs\n",
    "        df_val_acc = pd.DataFrame((1.0-bb_dict[dataset].objectives_evaluations[:, seed, :, metric_valid_error_dim])\n",
    "                                  * 100, columns=epoch_names)\n",
    "        # add a new column with the best validation accuracy\n",
    "        df_val_acc['val_acc_best'] = df_val_acc[epoch_names].max(axis=1)\n",
    "        # create a dataframe with the hyperparameter values\n",
    "        df_hp = bb_dict[dataset].hyperparameters\n",
    "        # create a dataframe with the times it takes to run an epoch\n",
    "        df_time = pd.DataFrame(bb_dict[dataset].objectives_evaluations[:, seed, :, metric_runtime_dim][:, -1], columns=['eval_time_epoch'])\n",
    "        # combine all smaller dataframes into one dataframe for each NASBench201 random seed and dataset\n",
    "        df_dict[seed][dataset] = pd.concat([df_hp, df_val_acc, df_time], axis=1)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f7cc4d8-0c4c-420b-9b2c-a1332c720b39",
   "metadata": {},
   "source": [
    "Motivation for measuring the best validation accuracy: NASBench201 provides validation and test errors in an inconsistent format. For CIFAR-100 and ImageNet16-120, the per-epoch errors are only available on the combined validation and test sets. As a compromise, we use the combined validation and test sets as the validation set. Consequently, no test set remains for additional evaluation, so we use the best validation accuracy as the final evaluation metric."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "153714ad-7ac2-403d-943b-495d9f7b9287",
   "metadata": {},
   "source": [
    "## Main experiments\n",
    "We run experiments on the three NASBench201 datasets: CIFAR-10, CIFAR-100 and ImageNet16-120. We compare PASHA, ASHA (promotion type) and two baselines: one epoch and random.\n",
    "\n",
    "In Syne Tune, the minimum resource level given to a configuration (here, the minimum number of epochs) is called the grace period."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f5dfa59-0b03-466d-b91d-95a85c3e7a75",
   "metadata": {},
   "source": [
    "Define functions for running the experiments and analysing them:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ae6d2cfe-bdc7-4ba5-b143-1de2a5def523",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_experiment(dataset_name, random_seed, nb201_random_seed, hpo_approach,\n",
    "                   reduction_factor=None, rung_system_kwargs=None):\n",
    "    \"\"\"\n",
    "    Function to run a NASBench201 experiment. It is similar to the NASBench201 example script\n",
    "    in syne-tune but extended to make it simple to run our experiments.\n",
    "    \n",
    "    The values listed for the following parameters are the ones we use, but other values can be used too.\n",
    "    \n",
    "    :param dataset_name: one of 'cifar10', 'cifar100', 'ImageNet16-120'\n",
    "    :param random_seed: one of 31415927, 0, 1234, 3458, 7685\n",
    "    :param nb201_random_seed: one of 0, 1, 2\n",
    "    :param hpo_approach: one of 'pasha', 'asha', 'pasha-bo', 'asha-bo'\n",
    "    :param reduction_factor: by default None (resulting in using the default value 3) or 2, 4\n",
    "    :param rung_system_kwargs: dictionary with the ranking criterion (str) and epsilon or epsilon scaling\n",
    "        (both float); by default None, resulting in soft ranking with epsilon 0.025\n",
    "    :return: tuner.name\n",
    "    \n",
    "    \"\"\"\n",
    "    # build the default dictionary on each call to avoid a mutable default argument\n",
    "    if rung_system_kwargs is None:\n",
    "        rung_system_kwargs = {'ranking_criterion': 'soft_ranking', 'epsilon': 0.025}\n",
    "    \n",
    "    # this function is similar to the NASBench201 example script\n",
    "    logging.getLogger().setLevel(logging.WARNING)\n",
    "\n",
    "    default_params = nasbench201_default_params({'backend': 'simulated'})\n",
    "    benchmark = nasbench201_benchmark(default_params)\n",
    "    # benchmark must be tabulated to support simulation\n",
    "    assert benchmark.get('supports_simulated', False)\n",
    "    mode = benchmark['mode']\n",
    "    metric = benchmark['metric']\n",
    "    blackbox_name = benchmark.get('blackbox_name')\n",
    "    # NASBench201 is a blackbox from the repository\n",
    "    assert blackbox_name is not None\n",
    "\n",
    "    config_space = benchmark['config_space']\n",
    "\n",
    "    # simulator back-end specialized to tabulated blackboxes\n",
    "    backend = BlackboxRepositoryBackend(\n",
    "        blackbox_name=blackbox_name,\n",
    "        elapsed_time_attr=benchmark['elapsed_time_attr'],\n",
    "        time_this_resource_attr=benchmark.get('time_this_resource_attr'),\n",
    "        dataset=dataset_name,\n",
    "        seed=nb201_random_seed)\n",
    "\n",
    "    # set logging of the simulator backend to WARNING level\n",
    "    logging.getLogger('syne_tune.backend.simulator_backend.simulator_backend').setLevel(logging.WARNING)\n",
    "    \n",
    "    if reduction_factor is None:\n",
    "        reduction_factor = default_params['reduction_factor']\n",
    "\n",
    "    # we support various schedulers within the function\n",
    "    if hpo_approach == 'pasha':\n",
    "        scheduler = baselines['PASHA'](\n",
    "            config_space,\n",
    "            max_t=default_params['max_resource_level'],\n",
    "            grace_period=default_params['grace_period'],\n",
    "            reduction_factor=reduction_factor,\n",
    "            resource_attr=benchmark['resource_attr'],\n",
    "            mode=mode,\n",
    "            metric=metric,\n",
    "            random_seed=random_seed,\n",
    "            rung_system_kwargs=rung_system_kwargs)\n",
    "    elif hpo_approach == 'asha':\n",
    "        scheduler = baselines['ASHA'](\n",
    "            config_space,\n",
    "            max_t=default_params['max_resource_level'],\n",
    "            grace_period=default_params['grace_period'],\n",
    "            reduction_factor=reduction_factor,\n",
    "            resource_attr=benchmark['resource_attr'],\n",
    "            mode=mode,\n",
    "            type='promotion',\n",
    "            metric=metric,\n",
    "            random_seed=random_seed)\n",
    "    elif hpo_approach == 'pasha-bo':\n",
    "        scheduler = HyperbandScheduler(\n",
    "            config_space,\n",
    "            max_t=default_params['max_resource_level'],\n",
    "            grace_period=default_params['grace_period'],\n",
    "            reduction_factor=reduction_factor,\n",
    "            resource_attr=benchmark['resource_attr'],\n",
    "            mode=mode,\n",
    "            searcher='bayesopt',\n",
    "            type='pasha',\n",
    "            metric=metric,\n",
    "            random_seed=random_seed,\n",
    "            rung_system_kwargs=rung_system_kwargs)\n",
    "    elif hpo_approach == 'asha-bo':\n",
    "        scheduler = HyperbandScheduler(\n",
    "            config_space,\n",
    "            max_t=default_params['max_resource_level'],\n",
    "            grace_period=default_params['grace_period'],\n",
    "            reduction_factor=reduction_factor,\n",
    "            resource_attr=benchmark['resource_attr'],\n",
    "            mode=mode,\n",
    "            searcher='bayesopt',\n",
    "            type='promotion',\n",
    "            metric=metric,\n",
    "            random_seed=random_seed)\n",
    "    else:\n",
    "        raise ValueError('The selected scheduler is not implemented')\n",
    "\n",
    "    stop_criterion = StoppingCriterion(max_num_trials_started=256)\n",
    "    # printing the status during tuning takes a lot of time, and so does\n",
    "    # storing results\n",
    "    print_update_interval = 700\n",
    "    results_update_interval = 300\n",
    "    # it is important to set `sleep_time` to 0 here (mandatory for simulator\n",
    "    # backend)\n",
    "\n",
    "    tuner = Tuner(\n",
    "        backend=backend,\n",
    "        scheduler=scheduler,\n",
    "        stop_criterion=stop_criterion,\n",
    "        n_workers=n_workers,\n",
    "        sleep_time=0,\n",
    "        results_update_interval=results_update_interval,\n",
    "        print_update_interval=print_update_interval,\n",
    "        # this callback is required in order to make things work with the\n",
    "        # simulator backend. It makes sure that results are stored with\n",
    "        # simulated time (rather than real time), and that the time_keeper\n",
    "        # is advanced properly whenever the tuner loop sleeps\n",
    "        callbacks=[SimulatorCallback()],\n",
    "    )\n",
    "    \n",
    "    tuner.run()\n",
    "    \n",
    "    return tuner.name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0ea790f4-1791-4f98-97e9-bbe67b13b3b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def analyse_experiments(experiment_names_dict, reference_time=None):\n",
    "    \"\"\"\n",
    "    Function to analyse the experiments that we run with the run_experiment function.\n",
    "    \n",
    "    Note: the dataset is selected via the global variable `dataset_name`, which is set\n",
    "    by the loop that calls this function.\n",
    "    \n",
    "    :param experiment_names_dict: dictionary mapping the dataset names to lists of tuples of\n",
    "        experiment names and NASBench201 random seeds\n",
    "    :param reference_time: optional argument with the time it takes to run the standard method - e.g. ASHA\n",
    "    :return: tuple of a line to display (string reporting the experiment results) and \n",
    "        the mean of the runtimes that can be used as reference time for other approaches\n",
    "    \"\"\"\n",
    "    val_acc_best_list = []\n",
    "    max_rsc_list = []\n",
    "    runtime_list = []\n",
    "    \n",
    "    for experiment_name, nb201_random_seed in experiment_names_dict[dataset_name]:\n",
    "        experiment_results = load_experiment(experiment_name)\n",
    "        best_cfg = experiment_results.results['metric_valid_error'].argmin()\n",
    "        \n",
    "        # find the best validation accuracy of the corresponding entry in NASBench201\n",
    "        table_hp_names = ['hp_x' + str(hp_idx) for hp_idx in range(6)]\n",
    "        results_hp_names = ['config_hp_x' + str(hp_idx) for hp_idx in range(6)]\n",
    "        condition = (df_dict[nb201_random_seed][dataset_name][table_hp_names] == experiment_results.results[results_hp_names].iloc[best_cfg].tolist()).all(axis=1)\n",
    "        val_acc_best = df_dict[nb201_random_seed][dataset_name][condition]['val_acc_best'].values[0]  # there is only one item in the list\n",
    "        val_acc_best_list.append(val_acc_best)\n",
    "        max_rsc_list.append(experiment_results.results['hp_epoch'].max())\n",
    "        runtime_list.append(experiment_results.results['st_tuner_time'].max())\n",
    "        \n",
    "    # raw strings keep the LaTeX backslashes without invalid escape sequences\n",
    "    line = r' & {:.2f} $\\pm$ {:.2f}'.format(np.mean(val_acc_best_list), np.std(val_acc_best_list))\n",
    "    line += r' & {:.1f}h $\\pm$ {:.1f}h'.format(np.mean(runtime_list)/3600, np.std(runtime_list)/3600)\n",
    "    if reference_time:\n",
    "        line += ' & {:.1f}x'.format(reference_time/np.mean(runtime_list))\n",
    "    else:\n",
    "        # no reference time: this is the reference method itself, so the speedup is 1.0x\n",
    "        line += ' & 1.0x'\n",
    "    line += r' & {:.1f} $\\pm$ {:.1f}'.format(np.mean(max_rsc_list), np.std(max_rsc_list))\n",
    "    \n",
    "    return line, np.mean(runtime_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7a556ca5-1021-4b67-b579-707aaeb6e387",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_one_epoch_baseline():\n",
    "    \"\"\"\n",
    "    Function to compute the performance of a simple one epoch baseline.\n",
    "    \n",
    "    Note: relies on the global variables `dataset_name` and `reference_time`, which are set\n",
    "    by the loop that calls this function.\n",
    "    \n",
    "    :return: a line to display (string reporting the experiment results)\n",
    "    \"\"\"\n",
    "    best_val_obj_list = []\n",
    "    total_time_list = []\n",
    "    \n",
    "    for nb201_random_seed in nb201_random_seeds:\n",
    "        for random_seed in random_seeds:\n",
    "            # randomly sample 256 configurations for the given dataset and NASBench201 seed\n",
    "            # use the same seeds as for our other experiments\n",
    "            random.seed(random_seed)\n",
    "            cfg_list = random.sample(range(len(df_dict[nb201_random_seed][dataset_name])), 256)\n",
    "            selected_subset = df_dict[nb201_random_seed][dataset_name].iloc[cfg_list]\n",
    "            # find configuration with the best performance after doing one epoch\n",
    "            max_idx = selected_subset['val_acc_epoch_0'].argmax()\n",
    "            best_configuration = selected_subset.iloc[max_idx]\n",
    "            # find the best validation accuracy of the selected configuration\n",
    "            # as that is the metric that we compare \n",
    "            best_val_obj = best_configuration[epoch_names].max()\n",
    "\n",
    "            # we also need to calculate the time it took for this\n",
    "            # taking into account the number of workers\n",
    "            total_time = selected_subset['eval_time_epoch'].sum() / n_workers\n",
    "\n",
    "            best_val_obj_list.append(best_val_obj)\n",
    "            total_time_list.append(total_time)\n",
    "\n",
    "    # raw strings keep the LaTeX backslashes without invalid escape sequences\n",
    "    line = r' & {:.2f} $\\pm$ {:.2f}'.format(np.mean(best_val_obj_list), np.std(best_val_obj_list))\n",
    "    line += r' & {:.1f}h $\\pm$ {:.1f}h'.format(np.mean(total_time_list)/3600, np.std(total_time_list)/3600)\n",
    "    line += ' & {:.1f}x'.format(reference_time/np.mean(total_time_list))\n",
    "    line += r' & 1.0 $\\pm$ 0.0'\n",
    "\n",
    "    return line"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "12640ea4-4932-4c87-8577-bad254eacd1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_random_baseline():\n",
    "    \"\"\"\n",
    "    Function to compute the performance of a simple random configuration baseline.\n",
    "    \n",
    "    We consider a ten times larger number of configurations in this case to get a better\n",
    "    estimate of the performance of a random configuration.\n",
    "\n",
    "    Note: relies on the global variable `dataset_name`, which is set by the loop that\n",
    "    calls this function.\n",
    "\n",
    "    :return: a line to display (string reporting the experiment results)\n",
    "    \"\"\"\n",
    "    random.seed(0)\n",
    "    random_seeds_rb = random.sample(range(999999), 256 * 10)\n",
    "\n",
    "    best_val_obj_list = []\n",
    "    total_time_list = []\n",
    "\n",
    "    for nb201_random_seed in nb201_random_seeds:\n",
    "        for random_seed in random_seeds_rb:\n",
    "            random.seed(random_seed)\n",
    "            # select the random configuration (a dataframe with a single row)\n",
    "            cfg_list = random.sample(range(len(df_dict[nb201_random_seed][dataset_name])), 1)\n",
    "            selected_configuration = df_dict[nb201_random_seed][dataset_name].iloc[cfg_list]\n",
    "            # find the best validation accuracy of the selected configuration\n",
    "            # as that is the metric that we compare: take the maximum across the epochs\n",
    "            # (axis=1) and extract the scalar value of the single row\n",
    "            best_val_obj = selected_configuration[epoch_names].max(axis=1).iloc[0]\n",
    "\n",
    "            # no training is needed to pick a random configuration, so the search time is zero\n",
    "            total_time = 0.0\n",
    "\n",
    "            best_val_obj_list.append(best_val_obj)\n",
    "            total_time_list.append(total_time)\n",
    "\n",
    "    # raw strings keep the LaTeX backslashes without invalid escape sequences\n",
    "    line = r' & {:.2f} $\\pm$ {:.2f}'.format(np.mean(best_val_obj_list), np.std(best_val_obj_list))\n",
    "    line += r' & {:.1f}h $\\pm$ {:.1f}h'.format(np.mean(total_time_list)/3600, np.std(total_time_list)/3600)\n",
    "    line += ' & NA'\n",
    "    line += r' & 0.0 $\\pm$ 0.0'\n",
    "\n",
    "    return line"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e3b3def-b25e-42d2-9939-990f67d60e56",
   "metadata": {},
   "source": [
    "Run the main experiments with PASHA, ASHA and the baselines:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "add9dbfc-c2da-47d9-9369-6c4ad8776538",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "\n",
    "experiment_names_pasha = {dataset: [] for dataset in dataset_names}\n",
    "experiment_names_asha = {dataset: [] for dataset in dataset_names}\n",
    "\n",
    "for dataset_name in dataset_names:\n",
    "    for nb201_random_seed in nb201_random_seeds:\n",
    "        for random_seed in random_seeds:\n",
    "            experiment_name = run_experiment(dataset_name, random_seed, nb201_random_seed, 'pasha')\n",
    "            experiment_names_pasha[dataset_name].append((experiment_name, nb201_random_seed))\n",
    "            experiment_name = run_experiment(dataset_name, random_seed, nb201_random_seed, 'asha')\n",
    "            experiment_names_asha[dataset_name].append((experiment_name, nb201_random_seed))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f35aded-e866-4d43-b4e4-9046fed69370",
   "metadata": {},
   "source": [
    "Analyse the experiments:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "490bdbb7-5ee1-47ae-814f-e3af61471058",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cifar10\n",
      "ASHA & 93.85 $\\pm$ 0.25 & 3.0h $\\pm$ 0.6h & 1.0x & 200.0 $\\pm$ 0.0\n",
      "PASHA & 93.78 $\\pm$ 0.31 & 2.3h $\\pm$ 0.5h & 1.3x & 144.5 $\\pm$ 59.4\n",
      "One epoch baseline  & 93.30 $\\pm$ 0.61 & 0.3h $\\pm$ 0.0h & 8.5x & 1.0 $\\pm$ 0.0\n",
      "Random baseline  & 72.93 $\\pm$ 19.55 & 0.0h $\\pm$ 0.0h & NA & 0.0 $\\pm$ 0.0\n",
      "cifar100\n",
      "ASHA & 71.69 $\\pm$ 1.05 & 3.2h $\\pm$ 0.9h & 1.0x & 200.0 $\\pm$ 0.0\n",
      "PASHA & 71.41 $\\pm$ 1.15 & 1.5h $\\pm$ 0.7h & 2.1x & 88.3 $\\pm$ 74.4\n",
      "One epoch baseline  & 65.57 $\\pm$ 5.53 & 0.3h $\\pm$ 0.0h & 9.2x & 1.0 $\\pm$ 0.0\n",
      "Random baseline  & 42.98 $\\pm$ 18.34 & 0.0h $\\pm$ 0.0h & NA & 0.0 $\\pm$ 0.0\n",
      "ImageNet16-120\n",
      "ASHA & 45.63 $\\pm$ 0.81 & 8.8h $\\pm$ 2.2h & 1.0x & 200.0 $\\pm$ 0.0\n",
      "PASHA & 46.01 $\\pm$ 1.00 & 3.2h $\\pm$ 1.0h & 2.8x & 28.6 $\\pm$ 27.7\n",
      "One epoch baseline  & 41.42 $\\pm$ 4.98 & 1.0h $\\pm$ 0.0h & 8.8x & 1.0 $\\pm$ 0.0\n",
      "Random baseline  & 20.97 $\\pm$ 10.01 & 0.0h $\\pm$ 0.0h & NA & 0.0 $\\pm$ 0.0\n"
     ]
    }
   ],
   "source": [
    "for dataset_name in dataset_names:\n",
    "    print(dataset_name)\n",
    "    result_summary, reference_time = analyse_experiments(experiment_names_asha)\n",
    "    print('ASHA' + result_summary)\n",
    "    result_summary, _ = analyse_experiments(experiment_names_pasha, reference_time)\n",
    "    print('PASHA' + result_summary)\n",
    "    result_summary = compute_one_epoch_baseline()\n",
    "    print('One epoch baseline', result_summary)\n",
    "    result_summary = compute_random_baseline()\n",
    "    print('Random baseline', result_summary)\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b844aa4-3af0-479f-8b3e-40de7198a4be",
   "metadata": {},
   "source": [
    "We see that PASHA obtains accuracy similar to ASHA on all three datasets (for example, 93.78% vs. 93.85% on CIFAR-10 and 46.01% vs. 45.63% on ImageNet16-120), while finding a well-performing configuration 1.3x to 2.8x faster.\n",
    "\n",
    "The configurations found by the one epoch baseline and the random baseline usually obtain significantly lower accuracies, making these baselines unsuitable for finding well-performing configurations."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df522ee5-c373-47c5-b5d3-34c859f83ae1",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Alternative ranking functions\n",
    "We show how to run experiments using an alternative ranking function, more specifically soft ranking with $\\epsilon=2\\sigma$. It is selected by passing `rung_system_kwargs` with the `'soft_ranking_std'` ranking criterion and an `epsilon_scaling` of 2.0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "84ed0937-782d-4902-b108-4907eb49a6bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "\n",
    "experiment_names_pasha_std2 = {dataset: [] for dataset in dataset_names}\n",
    "\n",
    "for dataset_name in dataset_names:\n",
    "    for nb201_random_seed in nb201_random_seeds:\n",
    "        for random_seed in random_seeds:\n",
    "            experiment_name = run_experiment(dataset_name, random_seed, nb201_random_seed, 'pasha', rung_system_kwargs={'ranking_criterion': 'soft_ranking_std', 'epsilon_scaling': 2.0})\n",
    "            experiment_names_pasha_std2[dataset_name].append((experiment_name, nb201_random_seed))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "437f0bb5-b331-4cc2-939d-ac0941ffb780",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cifar10\n",
      "ASHA & 93.85 $\\pm$ 0.25 & 3.0h $\\pm$ 0.6h & 1.0x & 200.0 $\\pm$ 0.0\n",
      "PASHA soft ranking $\\epsilon=0.025$ & 93.78 $\\pm$ 0.31 & 2.3h $\\pm$ 0.5h & 1.3x & 144.5 $\\pm$ 59.4\n",
      "PASHA soft ranking $2\\sigma$ & 93.88 $\\pm$ 0.28 & 1.9h $\\pm$ 0.5h & 1.5x & 132.7 $\\pm$ 68.7\n",
      "cifar100\n",
      "ASHA & 71.69 $\\pm$ 1.05 & 3.2h $\\pm$ 0.9h & 1.0x & 200.0 $\\pm$ 0.0\n",
      "PASHA soft ranking $\\epsilon=0.025$ & 71.41 $\\pm$ 1.15 & 1.5h $\\pm$ 0.7h & 2.1x & 88.3 $\\pm$ 74.4\n",
      "PASHA soft ranking $2\\sigma$ & 71.14 $\\pm$ 0.97 & 1.9h $\\pm$ 0.7h & 1.7x & 136.4 $\\pm$ 75.8\n",
      "ImageNet16-120\n",
      "ASHA & 45.63 $\\pm$ 0.81 & 8.8h $\\pm$ 2.2h & 1.0x & 200.0 $\\pm$ 0.0\n",
      "PASHA soft ranking $\\epsilon=0.025$ & 46.01 $\\pm$ 1.00 & 3.2h $\\pm$ 1.0h & 2.8x & 28.6 $\\pm$ 27.7\n",
      "PASHA soft ranking $2\\sigma$ & 45.39 $\\pm$ 1.22 & 4.5h $\\pm$ 1.4h & 1.9x & 91.2 $\\pm$ 58.0\n"
     ]
    }
   ],
   "source": [
    "for dataset_name in dataset_names:\n",
    "    print(dataset_name)\n",
    "    result_summary, reference_time = analyse_experiments(experiment_names_asha)\n",
    "    print('ASHA' + result_summary)\n",
    "    \n",
    "    result_summary, _ = analyse_experiments(experiment_names_pasha, reference_time)\n",
    "    print(r'PASHA soft ranking $\\epsilon=0.025$' + result_summary)\n",
    "    \n",
    "    result_summary, _ = analyse_experiments(experiment_names_pasha_std2, reference_time)\n",
    "    print(r'PASHA soft ranking $2\\sigma$' + result_summary)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3217caa4-eefe-4ffe-9816-8d486a651d4f",
   "metadata": {},
   "source": [
    "## Changes to the reduction factor\n",
    "To run experiments with a different reduction factor, pass the desired value as the `reduction_factor` argument of `run_experiment`, for example `reduction_factor=2` or `reduction_factor=4` (the default is 3)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef25fa85-ee79-4bf7-a049-2a1494d86549",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Combination with Bayesian Optimization\n",
    "To run experiments with a Bayesian Optimization search strategy, pass `'pasha-bo'` or `'asha-bo'` as the `hpo_approach` argument of `run_experiment`. Note that these experiments take longer to run because Gaussian processes are used."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
