{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "36b1924a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import the search space\n",
    "from naslib.search_spaces import NasBench201SearchSpace"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8e27a075",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a new search space object. This object doesn't have an architecture\n",
    "# assigned to it yet\n",
    "graph = NasBench201SearchSpace(n_classes=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8c136cd7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4, 1, 4, 2, 3, 2)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sample a random architecture\n",
    "# You can call this method only once.\n",
    "graph.sample_random_architecture()\n",
    "\n",
    "# Get the NASLib representation of this architecture\n",
    "graph.get_hash()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0fbcc3bc",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Comb_op is ignored if subgraph is defined!\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([5, 10])\n"
     ]
    }
   ],
   "source": [
    "# This graph is now a NAS-Bench-201 model, which can be used for training\n",
    "# Forward pass some dummy data through it to see it in action\n",
    "\n",
    "import torch\n",
    "\n",
    "x = torch.randn(5, 3, 32, 32) # (Batch_size, Num_channels, Height, Width)\n",
    "logits = graph(x)\n",
    "\n",
    "print(logits.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "46619456",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'|avg_pool_3x3~0|+|none~0|nor_conv_3x3~1|+|avg_pool_3x3~0|nor_conv_1x1~1|nor_conv_3x3~2|'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Import code to convert NASLib graph to the original NAS-Bench-201 representation\n",
    "from naslib.search_spaces.nasbench201.conversions import convert_naslib_to_str as convert_naslib_nb201_to_str\n",
    "\n",
    "# Get the string representation of this model\n",
    "convert_naslib_nb201_to_str(graph)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9c46eacf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parent graph: (4, 1, 4, 2, 3, 2)\n",
      "Child graph : (4, 1, 4, 4, 3, 2)\n"
     ]
    }
   ],
   "source": [
    "# Mutating an architecture\n",
    "# First, create a new child_graph\n",
    "child_graph = NasBench201SearchSpace(n_classes=10)\n",
    "\n",
    "# Call mutate on the child graph by passing the parent graph to it\n",
    "child_graph.mutate(parent=graph)\n",
    "\n",
    "# See the parent and child graph representations\n",
    "print(f'Parent graph: {graph.get_hash()}')\n",
    "print(f'Child graph : {child_graph.get_hash()}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ac9e9331",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Now, let's load the queryable tabular NAS-Bench-201 API\n",
    "# This API has the training metrics of all the 15625 models in the search space\n",
    "# such as train and validation accuracies/losses at every epoch\n",
    "\n",
    "from naslib.utils import get_dataset_api\n",
    "benchmark_api = get_dataset_api(search_space='nasbench201', dataset='cifar10')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8f539ae1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Performance of parent model\n",
      "Train accuracy: 99.8800000024414\n",
      "Validation accuracy: 84.07\n"
     ]
    }
   ],
   "source": [
    "# With the NAS-Bench-201 API, we can now query, say, the validation performance of any NB201 model\n",
    "# Without it, we would have to train the model from scratch to get this information\n",
    "\n",
    "# First, import the Metric enum\n",
    "from naslib.search_spaces.core import Metric\n",
    "\n",
    "# Metric has, among others, these values:\n",
    "# Metric.TRAIN_ACCURACY\n",
    "# Metric.VAL_ACCURACY\n",
    "# Metric.TRAIN_LOSS\n",
    "# Metric.TEST_LOSS\n",
    "# Metric.TRAIN_TIME\n",
    "\n",
    "train_acc_parent = graph.query(metric=Metric.TRAIN_ACCURACY, dataset='cifar10', dataset_api=benchmark_api)\n",
    "val_acc_parent = graph.query(metric=Metric.VAL_ACCURACY, dataset='cifar10', dataset_api=benchmark_api)\n",
    "\n",
    "print('Performance of parent model')\n",
    "print(f'Train accuracy: {train_acc_parent}')\n",
    "print(f'Validation accuracy: {val_acc_parent}')\n",
    "\n",
    "# TODO: Query the train and validation performance of the child model\n",
    "# train_acc_parent = ...\n",
    "# val_acc_parent = ...\n",
    "\n",
    "# print('Performance of child model')\n",
    "# print(f'Train accuracy: {train_acc_child}')\n",
    "# print(f'Validation accuracy: {val_acc_child}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "15fb7306",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: \n",
    "# 1. Sample a random NAS-Bench-301 model\n",
    "# 2. Get the NASLib and genotype representations of the model\n",
    "# 3. Query the predicted performance of the model (loading the NB301 benchmark API might take some time)\n",
    "# 4. Mutate the model\n",
    "# 5. Get the NASLib and genotype representations of the model\n",
    "# 6. Query the predicted performance of the child\n",
    "\n",
    "from naslib.search_spaces import NasBench301SearchSpace\n",
    "from naslib.search_spaces.nasbench301.conversions import convert_naslib_to_genotype as convert_naslib_nb301_to_genotype"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c44e30f",
   "metadata": {},
   "source": [
    "## Optimizers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "4c6c0f6a",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "ename": "IndentationError",
     "evalue": "unexpected indent (trainer.py, line 187)",
     "output_type": "error",
     "traceback": [
      "Traceback \u001b[0;36m(most recent call last)\u001b[0m:\n",
      "\u001b[0m  File \u001b[1;32m/opt/miniconda3/envs/naslib_minimal2/lib/python3.9/site-packages/IPython/core/interactiveshell.py:3378\u001b[0m in \u001b[1;35mrun_code\u001b[0m\n    exec(code_obj, self.user_global_ns, self.user_ns)\u001b[0m\n",
      "\u001b[0;36m  Cell \u001b[0;32mIn [10], line 6\u001b[0;36m\n\u001b[0;31m    from naslib.defaults.trainer import Trainer\u001b[0;36m\n",
      "\u001b[0;36m  File \u001b[0;32m~/Work/re_saga/NASLib/naslib/defaults/trainer.py:187\u001b[0;36m\u001b[0m\n\u001b[0;31m    ...\u001b[0m\n\u001b[0m    ^\u001b[0m\n\u001b[0;31mIndentationError\u001b[0m\u001b[0;31m:\u001b[0m unexpected indent\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "import logging\n",
    "import os\n",
    "\n",
    "# import the Trainer used to run the optimizer on a given search space\n",
    "from naslib.defaults.trainer import Trainer\n",
    "# import the optimizers\n",
    "from naslib.optimizers import (\n",
    "    RandomSearch,\n",
    "    RegularizedEvolution\n",
    ")\n",
    "# import the search spaces\n",
    "from naslib.search_spaces import (\n",
    "    NasBench101SearchSpace,\n",
    "    NasBench201SearchSpace,\n",
    "    NasBench301SearchSpace,\n",
    ")\n",
    "\n",
    "from naslib.search_spaces.core.query_metrics import Metric\n",
    "from naslib import utils\n",
    "from naslib.utils import get_dataset_api\n",
    "from naslib.utils.log import setup_logger\n",
    "\n",
    "from fvcore.common.config import CfgNode # Required to read the config\n",
    "###### End of imports ######\n",
    "\n",
    "# The configuration used by the Trainer and Optimizer\n",
    "config = {\n",
    "    'dataset': 'cifar10',\n",
    "    'search': {\n",
    "        'seed': 0, # \n",
    "        'epochs': 5, # Number of epochs (steps) of the optimizer to run\n",
    "        'fidelity': -1, # \n",
    "        'checkpoint_freq': 100,\n",
    "    },\n",
    "    'save': 'runs' # folder to save the results to \n",
    "}\n",
    "\n",
    "config = CfgNode.load_cfg(json.dumps(config))\n",
    "\n",
    "# Make the directories required for search and evaluation\n",
    "os.makedirs(config['save'] + '/search', exist_ok=True)\n",
    "os.makedirs(config['save'] + '/eval', exist_ok=True)\n",
    "\n",
    "# Set up the loggers\n",
    "logger = setup_logger()\n",
    "logger.setLevel(logging.INFO)\n",
    "\n",
    "# See the config\n",
    "logger.info(f'Configuration is \\n{config}')\n",
    "# logger.info(config)\n",
    "\n",
    "# Set up the seeds\n",
    "utils.set_seed(9002)\n",
    "\n",
    "# Instantiate the search space and get its benchmark API\n",
    "search_space = NasBench201SearchSpace()\n",
    "dataset_api = get_dataset_api('nasbench201', 'cifar10')\n",
    "\n",
    "# Instantitate the optimizer and adapt the search space to it\n",
    "optimizer = RandomSearch(config)\n",
    "optimizer.adapt_search_space(search_space, dataset_api=dataset_api)\n",
    "\n",
    "# Create a Trainer\n",
    "trainer = Trainer(optimizer, config)\n",
    "\n",
    "# Perform the search\n",
    "trainer.search(resume_from=\"\", report_incumbent=False)\n",
    "\n",
    "# Get the results of the search\n",
    "search_trajectory = trainer.search_trajectory\n",
    "print('Train accuracies:', search_trajectory.train_acc)\n",
    "print('Validation accuracies:', search_trajectory.valid_acc)\n",
    "\n",
    "# Get the validation performance of the best model found in the search phase\n",
    "best_model_val_acc = trainer.evaluate(dataset_api=dataset_api, metric=Metric.VAL_ACCURACY)\n",
    "best_model_val_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "777f232f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def update_config(config, optimizer_type, search_space_type, dataset, seed):\n",
    "    # Dataset being used\n",
    "    config.dataset = dataset\n",
    "    \n",
    "    # Directory to which the results/logs will be saved\n",
    "    config.save = f\"runs/{optimizer_type.__name__}/{search_space_type.__name__}/{dataset}/{seed}\"\n",
    "    \n",
    "    # Seed used during search phase of the optimizer\n",
    "    config.search.seed = seed\n",
    "    \n",
    "def run_optimizer(optimizer_type, search_space_type, dataset, dataset_api, config, seed):\n",
    "    # Update the config\n",
    "    update_config(config, optimizer_type, search_space_type, dataset, seed)\n",
    "\n",
    "    # Make the results directories\n",
    "    os.makedirs(config.save + '/search', exist_ok=True)\n",
    "    os.makedirs(config.save + '/eval', exist_ok=True)\n",
    "\n",
    "    # Set up the loggers\n",
    "    logger = setup_logger()\n",
    "    logger.setLevel(logging.INFO)\n",
    "\n",
    "     # See the config\n",
    "    logger.info(f'Configuration is \\n{config}')\n",
    "\n",
    "    # Set up the seed\n",
    "    utils.set_seed(seed)\n",
    "\n",
    "    # Instantiate the search space\n",
    "    n_classes = {\n",
    "        'cifar10': 10,\n",
    "        'cifar100': 100,\n",
    "        'ImageNet16-120': 120\n",
    "    }\n",
    "    search_space = search_space_type(n_classes=n_classes[dataset])\n",
    "\n",
    "    # Get the benchmark API\n",
    "    logger.info('Loading Benchmark API')\n",
    "    dataset_api = get_dataset_api(search_space.get_type(), dataset)\n",
    "    \n",
    "    # Instantiate the optimizer and adapat the search space to the optimizer\n",
    "    optimizer = optimizer_type(config)\n",
    "    optimizer.adapt_search_space(search_space, dataset_api=dataset_api)\n",
    "\n",
    "    # Create a Trainer\n",
    "    trainer = Trainer(optimizer, config)\n",
    "\n",
    "    # Perform the search\n",
    "    trainer.search(report_incumbent=False)\n",
    "\n",
    "    # Get the results of the search\n",
    "    search_trajectory = trainer.search_trajectory\n",
    "    print('Train accuracies:', search_trajectory.train_acc)\n",
    "    print('Validation accuracies:', search_trajectory.valid_acc)\n",
    "\n",
    "    # Get the validation performance of the best model found in the search phase\n",
    "    best_model_val_acc = trainer.evaluate(dataset_api=dataset_api, metric=Metric.VAL_ACCURACY)\n",
    "    best_model_val_acc\n",
    "\n",
    "    best_model = optimizer.get_final_architecture()\n",
    "\n",
    "    return search_trajectory, best_model, best_model_val_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7b15fd7",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Set the optimizer and search space types\n",
    "# They will be instantiated inside run_optimizer\n",
    "optimizer_type = RegularizedEvolution # {RegularizedEvolution, RandomSearch}\n",
    "search_space_type = NasBench201SearchSpace # {NasBench101SearchSpace, NasBench201SearchSpace, NasBench301SearchSpace}\n",
    "\n",
    "# Set the dataset\n",
    "dataset = 'cifar100' # cifar10 for NB101 and NB301, {cifar100, ImageNet16-120} for NB201\n",
    "\n",
    "# The configuration used by the Trainer and Optimizer\n",
    "# The missing information will be populated inside run_optimizer\n",
    "config = {\n",
    "    'search': {\n",
    "        # Required by Trainer\n",
    "        'epochs': 100,\n",
    "        'checkpoint_freq': 100,\n",
    "        \n",
    "        # Required by Random Search optimizer\n",
    "        'fidelity': -1,\n",
    "        \n",
    "        # Required by RegularizedEvolution\n",
    "        'sample_size': 10,\n",
    "        'population_size': 30,\n",
    "    }\n",
    "}\n",
    "config = CfgNode.load_cfg(json.dumps(config))\n",
    "\n",
    "search_trajectory, best_model, best_model_val_acc = run_optimizer(\n",
    "                                                        optimizer_type,\n",
    "                                                        search_space_type,\n",
    "                                                        dataset,\n",
    "                                                        dataset_api,\n",
    "                                                        config,\n",
    "                                                        9001\n",
    "                                                    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ea257c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pwd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "514e8686",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mvenv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13 (default, Mar 28 2022, 11:38:47) \n[GCC 7.5.0]"
  },
  "vscode": {
   "interpreter": {
    "hash": "3a138bae05203fa8eb4bf07493da4bb9038fdb3e1f2f10b7ab3cd8c9223b9122"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
