{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Search spaces in NASLib\n",
    "\n",
    "Neural Architecture Search consists of three building blocks:\n",
    "- Search space (cell, hierarchical, ...)\n",
    "- Optimization\n",
    "- Evaluation\n",
    "\n",
    "Here we cover the first part: search spaces."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is not a general introduction to search spaces but a walk-through of how they are realized in NASLib. The notebook is therefore not intended to be executed; read it as an extended commentary on what happens where, and why.\n",
    "\n",
    "The walk-through uses the implementation of the DARTS search space. The complete file is `naslib/search_spaces/darts/graph.py`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DARTS Search space definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DartsSearchSpace(Graph):\n",
    "\n",
    "    OPTIMIZER_SCOPE = [\n",
    "        \"n_stage_1\",\n",
    "        \"n_stage_2\", \n",
    "        \"n_stage_3\", \n",
    "        \"r_stage_1\", \n",
    "        \"r_stage_2\",\n",
    "    ]\n",
    "\n",
    "    QUERYABLE = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "NASLib uses the concept of `scopes` to differentiate instances of the same graph. For example, DARTS has one normal cell that is used at three stages of the macro graph. To set the channels properly at each stage, we later give the copies at each stage a different scope.\n",
    "\n",
    "The scopes in `OPTIMIZER_SCOPE` are later used by optimizers to determine which graphs they may manipulate during the search. This usually excludes the macro graph but includes all cells.\n",
    "\n",
    "The flag `QUERYABLE` indicates whether there is an interface to one of the tabular NAS benchmarks that can be used to query architecture metrics such as train accuracy or inference time. The available metrics are defined in `search_spaces/core/query_metrics.py` and implemented in the function `query()` below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        self.channels = [16, 32, 64]\n",
    "\n",
    "        self.num_classes = self.NUM_CLASSES if hasattr(self, 'NUM_CLASSES') else 10"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`__init__` cannot take any parameters because of the way `networkx` is implemented. To change the number of classes, set the static attribute `NUM_CLASSES` before instantiating the class. The default is 10, as for CIFAR-10.\n",
    "\n",
    "The channels are used later when setting the primitive operations at the cell edges."
   ]
  },
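  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of this pattern, the stand-in class below reads `NUM_CLASSES` the same way as `DartsSearchSpace.__init__`. `TinySearchSpace` is a hypothetical name and not part of NASLib. Setting the class attribute before instantiation changes the number of classes; without it, the default of 10 applies."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical stand-in that mimics the NUM_CLASSES lookup shown above\n",
    "class TinySearchSpace:\n",
    "    def __init__(self):\n",
    "        # Same fallback logic as in DartsSearchSpace.__init__\n",
    "        self.num_classes = self.NUM_CLASSES if hasattr(self, 'NUM_CLASSES') else 10\n",
    "\n",
    "\n",
    "default_space = TinySearchSpace()      # NUM_CLASSES not set: falls back to 10\n",
    "\n",
    "TinySearchSpace.NUM_CLASSES = 100      # e.g. for a dataset with 100 classes\n",
    "custom_space = TinySearchSpace()\n",
    "\n",
    "print(default_space.num_classes, custom_space.num_classes)"
   ]
  },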
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "        # Normal cell first\n",
    "        normal_cell = Graph()\n",
    "        normal_cell.name = \"normal_cell\"    # Use the same name for all cells with shared attributes\n",
    "\n",
    "        # Input nodes\n",
    "        normal_cell.add_node(1)\n",
    "        normal_cell.add_node(2)\n",
    "\n",
    "        # Intermediate nodes\n",
    "        normal_cell.add_node(3)\n",
    "        normal_cell.add_node(4)\n",
    "        normal_cell.add_node(5)\n",
    "        normal_cell.add_node(6)\n",
    "\n",
    "        # Output node\n",
    "        normal_cell.add_node(7)\n",
    "\n",
    "        # Edges\n",
    "        normal_cell.add_edges_from([(1, i) for i in range(3, 7)])   # input 1\n",
    "        normal_cell.add_edges_from([(2, i) for i in range(3, 7)])   # input 2\n",
    "        normal_cell.add_edges_from([(3, 4), (3, 5), (3, 6)])\n",
    "        normal_cell.add_edges_from([(4, 5), (4, 6)])\n",
    "        normal_cell.add_edges_from([(5, 6)])\n",
    "\n",
    "        # Edges connecting to the output are always the identity\n",
    "        normal_cell.add_edges_from([(i, 7, EdgeData().finalize()) for i in range(3, 7)])   # output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use the networkx API to create nodes and edges. The connectivity follows the definition from the DARTS paper. Note that node indices must be integers and start at 1.\n",
    "\n",
    "At each edge sits an `EdgeData` object. It stores information at the edge, e.g. the `op` that is applied when passing data from node to node. It can be finalized to prevent later manipulation. This is required for the edges connecting to the output node in DARTS, as they are fixed to the Identity operation (skip connection) and are not part of the architecture search.\n",
    "\n",
    "Up to this point, the op at every edge is set to `Identity` (defined in `search_spaces/core/primitives.py`). The combine operation (`comb_op`) at nodes is `sum` by default. This can be changed later."
   ]
  },
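  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the connectivity concrete, the following sketch rebuilds the same edge list with plain Python lists and counts the edges. It uses neither NASLib nor networkx, and the variable names are illustrative only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Edges that take part in the architecture search\n",
    "searchable_edges = (\n",
    "    [(1, i) for i in range(3, 7)]        # input 1 -> intermediate nodes\n",
    "    + [(2, i) for i in range(3, 7)]      # input 2 -> intermediate nodes\n",
    "    + [(3, 4), (3, 5), (3, 6)]\n",
    "    + [(4, 5), (4, 6)]\n",
    "    + [(5, 6)]\n",
    ")\n",
    "# Finalized identity edges into the output node\n",
    "output_edges = [(i, 7) for i in range(3, 7)]\n",
    "\n",
    "# Number of incoming searchable edges per intermediate node\n",
    "in_degree = {v: sum(1 for _, head in searchable_edges if head == v) for v in range(3, 7)}\n",
    "\n",
    "print(len(searchable_edges), len(output_edges))\n",
    "print(in_degree)"
   ]
  },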
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "        # Reduction cell has the same topology\n",
    "        reduction_cell = deepcopy(normal_cell)\n",
    "        reduction_cell.name = \"reduction_cell\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As the reduction cell and the normal cell share the same topology, we can simply deepcopy the normal cell. Make sure to set a unique name for each distinct graph, as NASLib uses the name (among other things) to distinguish copies of the same graph from copies of other graphs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "        #\n",
    "        # Makrograph definition\n",
    "        #\n",
    "        self.name = \"makrograph\"\n",
    "\n",
    "        self.add_node(1)    # input node\n",
    "        self.add_node(2)    # preprocessing\n",
    "        self.add_node(3, subgraph=normal_cell.set_scope(\"n_stage_1\").set_input([2, 2]))\n",
    "        self.add_node(4, subgraph=normal_cell.copy().set_scope(\"n_stage_1\").set_input([2, 3]))\n",
    "        self.add_node(5, subgraph=reduction_cell.set_scope(\"r_stage_1\").set_input([3, 4]))\n",
    "        self.add_node(6, subgraph=normal_cell.copy().set_scope(\"n_stage_2\").set_input([4, 5]))\n",
    "        self.add_node(7, subgraph=normal_cell.copy().set_scope(\"n_stage_2\").set_input([5, 6]))\n",
    "        self.add_node(8, subgraph=reduction_cell.copy().set_scope(\"r_stage_2\").set_input([6, 7]))\n",
    "        self.add_node(9, subgraph=normal_cell.copy().set_scope(\"n_stage_3\").set_input([7, 8]))\n",
    "        self.add_node(10, subgraph=normal_cell.copy().set_scope(\"n_stage_3\").set_input([8, 9]))\n",
    "        self.add_node(11)   # output\n",
    "\n",
    "        self.add_edges_from([(i, i+1) for i in range(1, 11)])\n",
    "        self.add_edges_from([(i, i+2) for i in range(2, 9)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Several things happen here. Graphs can sit at nodes (where they are conceptually similar to the combine operation) or at edges. In the DARTS search space they sit only at nodes; in other search spaces they sit at edges.\n",
    "\n",
    "When graphs sit at nodes, they have to be set as `subgraph`. `copy()` creates a somewhat shallow copy of the graph; this is explained later. `set_scope()` sets the scope for the graph and, if specified, also for all child graphs, as graphs can be nested. `set_input()` routes the input from the incoming edges at the node to the input nodes of the subgraph. For example, for macro node 5 the input to its subgraph comes from macro nodes 3 and 4. This is also the order in which the inputs are assigned to the input nodes of the subgraph. For macro node 3, the two inputs are defined as `[2, 2]`. This is a hacky way to duplicate the output of macro node 2 although there is only one incoming edge."
   ]
  },
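  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The routing can be checked with a small sketch in plain Python (no NASLib involved; all variable names are illustrative). It lists the predecessors of every cell node in the macro graph and confirms that each index passed to `set_input()` arrives over an existing edge. Node 3 has only one predecessor, which is why `[2, 2]` is needed there."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Macro graph edges as defined above, as plain tuples\n",
    "macro_edges = [(i, i + 1) for i in range(1, 11)] + [(i, i + 2) for i in range(2, 9)]\n",
    "\n",
    "# Collect the predecessors of every macro node\n",
    "predecessors = {}\n",
    "for tail, head in macro_edges:\n",
    "    predecessors.setdefault(head, []).append(tail)\n",
    "\n",
    "# Inputs passed to set_input() for each cell node in the snippet above\n",
    "cell_inputs = {3: [2, 2], 4: [2, 3], 5: [3, 4], 6: [4, 5],\n",
    "               7: [5, 6], 8: [6, 7], 9: [7, 8], 10: [8, 9]}\n",
    "\n",
    "for node, inputs in cell_inputs.items():\n",
    "    # Every requested input must arrive over an existing incoming edge\n",
    "    assert set(inputs) <= set(predecessors[node])\n",
    "    print(node, sorted(predecessors[node]), inputs)"
   ]
  },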
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "        #\n",
    "        # Operations at the makrograph edges\n",
    "        #\n",
    "        self.num_in_edges = 4\n",
    "        reduction_cell_indices = [5, 8]\n",
    "\n",
    "        channel_map_from, channel_map_to = channel_maps(reduction_cell_indices, max_index=11)\n",
    "\n",
    "        self._set_makrograph_ops(channel_map_from, channel_map_to, max_index=11, affine=False)\n",
    "\n",
    "        self._set_cell_ops(reduction_cell_indices)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we are ready to set the primitives at the cell edges and the operations required for the macro graph edges. This could not be done earlier because the primitives currently have to be created with the correct number of channels, which changes at each stage.\n",
    "\n",
    "`channel_maps` is a helper function that returns two dicts containing the correct index into `self.channels` for every macro graph node.\n",
    "\n",
    "Now we'll look at the other functions."
   ]
  },
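  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The helper `channel_maps` itself is not shown in this notebook. The sketch below is an assumption inferred from how the two dicts are used in `_set_makrograph_ops`: the name `channel_maps_sketch` and its body are hypothetical and may differ from the NASLib implementation. The idea is that a reduction cell receives the channels of the current stage and emits the channels of the next one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical re-implementation, NOT the NASLib helper\n",
    "def channel_maps_sketch(reduction_cell_indices, max_index):\n",
    "    channel_map_from, channel_map_to = {}, {}\n",
    "    stage = 0\n",
    "    for i in range(1, max_index + 1):\n",
    "        channel_map_to[i] = stage        # channels expected at the input of node i\n",
    "        if i in reduction_cell_indices:\n",
    "            stage += 1                   # a reduction cell moves on to the next stage\n",
    "        channel_map_from[i] = stage      # channels produced at the output of node i\n",
    "    return channel_map_from, channel_map_to\n",
    "\n",
    "\n",
    "channels = [16, 32, 64]\n",
    "map_from, map_to = channel_maps_sketch([5, 8], max_index=11)\n",
    "\n",
    "# Edge (4, 6) skips over the first reduction cell: 16 -> 32 channels\n",
    "print(channels[map_from[4]], channels[map_to[6]])\n",
    "# Edge (5, 6) stays within the second stage: 32 -> 32 channels\n",
    "print(channels[map_from[5]], channels[map_to[6]])"
   ]
  },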
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "    def _set_makrograph_ops(self, channel_map_from, channel_map_to, max_index, affine=True):\n",
    "        # pre-processing\n",
    "        # In darts there is a hardcoded multiplier of 3 for the output of the stem\n",
    "        stem_multiplier = 3\n",
    "        self.edges[1, 2].set('op', ops.Stem(self.channels[0] * stem_multiplier))\n",
    "\n",
    "        # edges connecting cells\n",
    "        for u, v, data in sorted(self.edges(data=True)):\n",
    "            if u > 1 and v < max_index:\n",
    "                C_in = self.channels[channel_map_from[u]]\n",
    "                C_out = self.channels[channel_map_to[v]]\n",
    "                if C_in == C_out:\n",
    "                    C_in = C_in * stem_multiplier if u == 2 else C_in * self.num_in_edges     # handle Stem\n",
    "                    data.set('op', ops.ReLUConvBN(C_in, C_out, kernel_size=1, affine=affine))\n",
    "                else:\n",
    "                    data.set('op', FactorizedReduce(C_in * self.num_in_edges, C_out, affine=affine))\n",
    "        \n",
    "        # post-processing\n",
    "        _, _, data = sorted(self.edges(data=True))[-1]\n",
    "        data.set('op', ops.Sequential(\n",
    "            nn.AdaptiveAvgPool2d(1),\n",
    "            nn.Flatten(),\n",
    "            nn.Linear(self.channels[-1] * self.num_in_edges, self.num_classes))\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This function sets the operations at the macro graph edges. In DARTS they are not the Identity but pre-process the output of the previous nodes.\n",
    "\n",
    "The output of the stem is multiplied by 3, as in the DARTS code.\n",
    "\n",
    "There are two ways to set attributes at edges in NASLib. The first is to access the edge directly via the networkx API. This is done here for the stem with `set('op', ops.Stem(), shared=False)`, where `'op'` is the name of the attribute and `ops.Stem()` its value. This is similar to a dict, but the value can later also be accessed as an attribute of the edge, e.g. `edges[1, 2].op`.\n",
    "\n",
    "The advantage of this method is that it is easy. The disadvantage is that it is not recursive, i.e. for a private attribute it must be done manually for each copy of the graph.\n",
    "\n",
    "Private attributes of edges are not shared between copies of the graph, i.e. they are *deepcopied* when a copy of the graph is created. If we set `shared=True` in `set()`, the attribute is shared. This means the attribute is added to the edge in each copy of the graph as a *shallow copy*.\n",
    "\n",
    "The same is done for the reduction cells in the next section. Because DARTS concatenates the outputs of the intermediate nodes at the output node of each cell, the output dimension of each cell is 4 times the number of channels. This is accounted for by `else C_in * self.num_in_edges`. How the combine op is actually set at this node is shown in the next code snippet.\n",
    "\n",
    "The post-processing at the last edge (connecting to the output node of the macro graph) is set the same way."
   ]
  },
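  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The difference between private and shared attributes mirrors the difference between a deep copy and a plain reference in Python. The sketch below illustrates this with the standard library only. `ToyOp` is a hypothetical stand-in, and the code does not call NASLib."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy\n",
    "\n",
    "\n",
    "# Hypothetical stand-in for an operation that owns weights\n",
    "class ToyOp:\n",
    "    def __init__(self):\n",
    "        self.weights = [0.0]\n",
    "\n",
    "\n",
    "private_op = ToyOp()\n",
    "shared_op = ToyOp()\n",
    "\n",
    "# Private attribute: the graph copy gets its own deep-copied object\n",
    "private_in_copy = deepcopy(private_op)\n",
    "# Shared attribute: the graph copy only references the same object\n",
    "shared_in_copy = shared_op\n",
    "\n",
    "private_in_copy.weights[0] = 1.0\n",
    "shared_in_copy.weights[0] = 1.0\n",
    "\n",
    "# The original private op is untouched, the shared op has changed\n",
    "print(private_op.weights, shared_op.weights)"
   ]
  },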
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "    def _set_cell_ops(self, reduction_cell_indices):\n",
    "        # normal cells\n",
    "        stages = [\"n_stage_1\", \"n_stage_2\", \"n_stage_3\"]\n",
    "\n",
    "        for scope, c in zip(stages, self.channels):\n",
    "            self.update_edges(\n",
    "                update_func=lambda current_edge_data: _set_cell_ops(current_edge_data, c, stride=1),\n",
    "                scope=scope,\n",
    "                private_edge_data=True\n",
    "            )\n",
    "\n",
    "        # reduction cells\n",
    "        # stride=2 is only for some edges, that's why we have to do it this way\n",
    "        for n, c in zip(reduction_cell_indices, self.channels[1:]):\n",
    "            reduction_cell = self.nodes[n]['subgraph']\n",
    "            for u, v, data in reduction_cell.edges.data():\n",
    "                stride = 2 if u in (1, 2) else 1\n",
    "                if not data.is_final():\n",
    "                    reduction_cell.edges[u, v].update(_set_cell_ops(data, c, stride))\n",
    "\n",
    "        #\n",
    "        # Combining operations\n",
    "        #\n",
    "        for _, cell in sorted(self.nodes('subgraph')):\n",
    "            if cell:\n",
    "                cell.nodes[7]['comb_op'] = channel_concat"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we set the primitives for all cells in the macro graph.\n",
    "\n",
    "To avoid doing this manually for each cell (`op` must be private, otherwise its weights are shared), NASLib provides an API for it: `graph.update_edges(update_func, scope, private_edge_data)`.\n",
    "\n",
    "- `update_func` is a function provided by the user. It is applied to every edge of every graph in the scope, except for edges flagged as final.\n",
    "- `scope` is a string or a list of strings that must match the scope set at the graphs we want to update.\n",
    "- `private_edge_data` specifies whether `update_func` is going to add or change private attributes or shared ones. This is needed to avoid setting the same attribute multiple times (shallow copy, mentioned above).\n",
    "\n",
    "We had to postpone setting the primitives at the cells because they require, among other things, the number of channels they operate on. Now that all cells are placed correctly, we can update them using their scope. Here we use the function `_set_cell_ops`, which is presented in the next snippet (bottom of the file).\n",
    "\n",
    "Unfortunately, we cannot use the same logic for the reduction cells. By definition of the reduction cell, `stride=2` is only set on edges connecting input and intermediate nodes. We plan to add functionality to handle this case via `update_edges()` as well, stay tuned.\n",
    "\n",
    "Last but not least, we set the combine operation to `channel_concat` for each cell. As this is not channel dependent, we could also have done it earlier, when defining the cell. However, if `comb_op` is channel dependent, this is currently the way to go. This could also be done using `update_nodes()`, which we will see later."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _set_cell_ops(current_edge_data, C, stride):\n",
    "    C_in = C if stride==1 else C//2\n",
    "    current_edge_data.set('op', [\n",
    "        ops.Identity() if stride==1 else FactorizedReduce(C_in, C, affine=False),\n",
    "        ops.Zero(stride=stride),\n",
    "        ops.MaxPool1x1(3, stride, C_in, C, affine=False),\n",
    "        ops.AvgPool1x1(3, stride, C_in, C, affine=False),\n",
    "        ops.SepConv(C_in, C, kernel_size=3, stride=stride, padding=1, affine=False),\n",
    "        ops.SepConv(C_in, C, kernel_size=5, stride=stride, padding=2, affine=False),\n",
    "        ops.DilConv(C_in, C, kernel_size=3, stride=stride, padding=2, dilation=2, affine=False),\n",
    "        ops.DilConv(C_in, C, kernel_size=5, stride=stride, padding=4, dilation=2, affine=False),\n",
    "    ])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Pretty straight-forward."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hooks\n",
    "\n",
    "Now that we have the search space defined, one could think we are ready to go. And technically we are: we can take this search space, give it to an optimizer and, boom, get the result.\n",
    "\n",
    "Unfortunately, it is not quite that simple. For evaluation we do not use exactly this search space. Instead we change it, to get better performance, to meet the restrictions of previous work, or for other reasons.\n",
    "\n",
    "This is why NASLib offers *hooks* to handle these cases.\n",
    "\n",
    "### Prepare discretization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_discretization(self):\n",
    "    self.update_nodes(_truncate_input_edges, scope=self.OPTIMIZER_SCOPE, single_instances=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In DARTS, some things are done before the architecture can be discretized, i.e. before exactly one op is determined for each edge (rather than, e.g., a list, as defined above). First, the input edges are truncated to two incoming edges for each intermediate node of each cell. Second, the Zero op can never be picked for discretization.\n",
    "\n",
    "This is done in the next snippet using the NASLib function `update_nodes`, which works similarly to `update_edges` but is much more powerful."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _truncate_input_edges(node, in_edges, out_edges):\n",
    "    k = 2\n",
    "    if len(in_edges) >= k:\n",
    "        if any(e.has('alpha') or e.is_final() for _, e in in_edges):\n",
    "            # We are in the one-shot case\n",
    "            for _, data in in_edges:\n",
    "                if data.is_final():\n",
    "                    return  # We are looking at an out node\n",
    "                data.alpha[1] = -float(\"Inf\")   # Zero op should never be max alpha\n",
    "            sorted_edge_ids = sorted(in_edges, key=lambda x: max(x[1].alpha), reverse=True)\n",
    "            keep_edges, _ = zip(*sorted_edge_ids[:k])\n",
    "            for edge_id, edge_data in in_edges:\n",
    "                if edge_id not in keep_edges:\n",
    "                    edge_data.delete()\n",
    "        else:\n",
    "            # We are in the discrete case (e.g. random search)\n",
    "            for _, data in in_edges:\n",
    "                assert isinstance(data.op, list)\n",
    "                data.op.pop(1)      # Remove the zero op\n",
    "            if any(e.has('final') and e.final for _, e in in_edges):\n",
    "                return\n",
    "            else:\n",
    "                for _ in range(len(in_edges) - k):\n",
    "                    in_edges[random.randint(0, len(in_edges)-1)][1].delete()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The function given to `update_nodes` must accept three parameters: `node`, `in_edges`, and `out_edges`.\n",
    "- `node` is a tuple `(int, dict)` containing the index and the attributes of the current node.\n",
    "- `in_edges` is a list of tuples, each with the index of the tail of the edge and its `EdgeData`.\n",
    "- `out_edges` is a list of tuples, each with the index of the head of the edge and its `EdgeData`."
   ]
  },
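  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The one-shot branch of `_truncate_input_edges` can be illustrated without NASLib. In the sketch below, the dictionary `alphas` and its values are made up for illustration. Each list holds the architecture weights of one incoming edge, with the Zero op at index 1 as in `_set_cell_ops`. Note that edge 1 would rank first if the Zero op were not masked out."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Made-up alpha values; keys are the tail node indices of the incoming edges\n",
    "alphas = {\n",
    "    1: [0.1, 0.9, 0.3],\n",
    "    2: [0.5, 0.2, 0.4],\n",
    "    3: [0.2, 0.1, 0.8],\n",
    "}\n",
    "k = 2\n",
    "\n",
    "for values in alphas.values():\n",
    "    values[1] = -float('inf')            # the Zero op must never be the maximum\n",
    "\n",
    "# Rank the edges by their strongest remaining op and keep the best k\n",
    "ranked = sorted(alphas.items(), key=lambda item: max(item[1]), reverse=True)\n",
    "keep = [tail for tail, _ in ranked[:k]]\n",
    "drop = [tail for tail in alphas if tail not in keep]\n",
    "\n",
    "print(keep, drop)"
   ]
  },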
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare evaluation\n",
    "\n",
    "The second hook is a way to change the optimal discrete architecture before running the evaluation on it, i.e. before determining the final performance of that architecture. In DARTS, this is used to expand the macro graph to 6 cells at each stage instead of only two, and to increase the channels by a factor of 2.25."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "    def prepare_evaluation(self):\n",
    "        self._expand()\n",
    "\n",
    "        # Operations at the edges\n",
    "        self.channels = [36, 72, 144]\n",
    "        reduction_cell_indices = [9, 16]\n",
    "\n",
    "        channel_map_from, channel_map_to = channel_maps(reduction_cell_indices, max_index=23)\n",
    "        self._set_makrograph_ops(channel_map_from, channel_map_to, max_index=23, affine=True)\n",
    "\n",
    "        # Taken from DARTS implementation\n",
    "        # assuming input size 8x8\n",
    "        self.edges[22, 23].set('op', ops.Sequential(\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2\n",
    "            nn.Conv2d(self.channels[-1] * self.num_in_edges, 128, 1, bias=False),\n",
    "            nn.BatchNorm2d(128),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Conv2d(128, 768, 2, bias=False),\n",
    "            nn.BatchNorm2d(768),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Flatten(),\n",
    "            nn.Linear(768, self.num_classes))\n",
    "        )\n",
    "\n",
    "        self.update_edges(\n",
    "            update_func=_double_channels,\n",
    "            scope=self.OPTIMIZER_SCOPE,\n",
    "            private_edge_data=True\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`prepare_evaluation()` is always called after the graph has been discretized. Here we first expand the graph (see the next snippet), then set the macro graph operations again as described above.\n",
    "\n",
    "The evaluation in DARTS is performed using auxiliary towers (adding a second loss on an intermediate output of the network). This is what is set at the edge (22, 23).\n",
    "\n",
    "Last but not least, we double (slightly more, actually: the factor is 2.25) the channels of each primitive at each edge in the scope (see the snippet after the next one)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "    def _expand(self):\n",
    "        # shift the node indices to make space for 4 more nodes at each stage\n",
    "        # and the auxiliary logits\n",
    "        mapping = {\n",
    "            5: 9,\n",
    "            6: 10,\n",
    "            7: 11,\n",
    "            8: 16,\n",
    "            9: 17,\n",
    "            10: 18,\n",
    "            11: 24,     # 23 is auxiliary\n",
    "        }\n",
    "        nx.relabel_nodes(self, mapping, copy=False)\n",
    "        \n",
    "        # fix edges\n",
    "        self.remove_edges_from(list(self.edges()))\n",
    "        self.add_edges_from([(i, i+1) for i in range(1, 22)])\n",
    "        self.add_edges_from([(i, i+2) for i in range(2, 21)])\n",
    "        self.add_edge(22, 23)   # auxiliary output\n",
    "        self.add_edge(22, 24)   # final output\n",
    "        \n",
    "        to_insert = [] + list(range(5, 9)) + list(range(12, 16)) + list(range(19, 23))\n",
    "        for i in to_insert:\n",
    "            normal_cell = self.nodes[i-1]['subgraph']\n",
    "            self.add_node(i, subgraph=normal_cell.copy().set_scope(normal_cell.scope).set_input([i-2, i-1]))\n",
    "        \n",
    "        for i, cell in sorted(self.nodes(data='subgraph')):\n",
    "            if cell:\n",
    "                if i == 3:\n",
    "                    cell.input_node_idxs = [2, 2]\n",
    "                else:\n",
    "                    cell.input_node_idxs = [i-2, i-1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We again use networkx functionality to expand the graph and make space for additional cells at each stage. Because of the relabeling we have to reset the edges. Then we can add the nodes in the gaps using the free indices and set their input indices as described above for the search space."
   ]
  },
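  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The index bookkeeping in `_expand()` can be verified with plain Python sets. The sketch below copies the mapping and the inserted indices from the snippet above (the remaining variable names are illustrative) and checks that the relabeled nodes, the inserted cells, and the auxiliary node 23 together cover the indices 1 to 24, leaving 18 normal cells, i.e. 6 per stage."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relabeling used in _expand(): old macro node index -> new index\n",
    "mapping = {5: 9, 6: 10, 7: 11, 8: 16, 9: 17, 10: 18, 11: 24}\n",
    "\n",
    "# The search-time macro graph has the nodes 1..11\n",
    "relabeled = {mapping.get(n, n) for n in range(1, 12)}\n",
    "\n",
    "# Indices that receive new normal cells, plus the auxiliary node 23\n",
    "to_insert = list(range(5, 9)) + list(range(12, 16)) + list(range(19, 23))\n",
    "auxiliary = {23}\n",
    "\n",
    "all_nodes = relabeled | set(to_insert) | auxiliary\n",
    "reduction_cells = {mapping[5], mapping[8]}\n",
    "\n",
    "# Cells occupy the nodes 3..22; those that are not reduction cells are normal cells\n",
    "normal_cells = [n for n in range(3, 23) if n not in reduction_cells]\n",
    "\n",
    "print(sorted(all_nodes) == list(range(1, 25)), len(normal_cells))"
   ]
  },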
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _double_channels(current_edge_data):\n",
    "    init_params = current_edge_data.op.init_params\n",
    "    if 'C_in' in init_params:\n",
    "        init_params['C_in'] = int(init_params['C_in'] * 2.25) \n",
    "    if 'C_out' in init_params:\n",
    "        init_params['C_out'] = int(init_params['C_out'] * 2.25) \n",
    "    if 'affine' in init_params:\n",
    "        init_params['affine'] = True\n",
    "    current_edge_data.set('op', current_edge_data.op.__class__(**init_params))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Unfortunately, NASLib cannot use a native `torch.nn.Module` as an operation. There are several reasons:\n",
    "- We need to make sure that we find any subgraphs nested in an op, as is the case for the hierarchical search space.\n",
    "- We might need to change some attributes of the op that were stored there earlier (this is the case here).\n",
    "- We might need a dedicated name for each op in case we want to make use of tabular benchmarks and query the performance of an architecture."
   ]
  },
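  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The second reason can be illustrated with a small sketch. `ToyConv` and `scale_channels` are hypothetical names and not part of NASLib. The sketch only assumes, as `_double_channels` does, that an op keeps its constructor arguments in `init_params`. It also shows that the factor 2.25 maps the search channels `[16, 32, 64]` to the evaluation channels `[36, 72, 144]`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical op that remembers the arguments it was created with\n",
    "class ToyConv:\n",
    "    def __init__(self, C_in, C_out, affine=False):\n",
    "        self.init_params = {'C_in': C_in, 'C_out': C_out, 'affine': affine}\n",
    "\n",
    "\n",
    "def scale_channels(op, factor=2.25):\n",
    "    # Copy the stored arguments, scale the channels, and build a fresh op\n",
    "    params = dict(op.init_params)\n",
    "    params['C_in'] = int(params['C_in'] * factor)\n",
    "    params['C_out'] = int(params['C_out'] * factor)\n",
    "    params['affine'] = True\n",
    "    return op.__class__(**params)\n",
    "\n",
    "\n",
    "wider = [scale_channels(ToyConv(c, c)) for c in (16, 32, 64)]\n",
    "print([op.init_params['C_out'] for op in wider])"
   ]
  },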
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### More than one output\n",
    "\n",
    "If the model should be trained with an auxiliary head for evaluation, this head must be added to the architecture. Currently this is done by adding an intermediate output to the macro graph (the output from node 23).\n",
    "\n",
    "This is required because when passing data through the graph with `logits = awesome_architecture(x)`, NASLib returns only the output of the last node. However, for auxiliary towers (and maybe other applications of NAS) we need more than one output. These potential outputs are stored in the graph dictionary as `out_from_` followed by the node index. For auxiliary towers, this is realized by defining an additional function, `auxilary_logits` in this case, which returns this intermediate output and allows the evaluation pipeline to handle it accordingly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "    def auxilary_logits(self):\n",
    "        return self.graph['out_from_23']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Querying\n",
    "\n",
    "NASLib offers an interface to tabular benchmarks such as NAS-Bench-201.\n",
    "\n",
    "The optimizer or evaluation pipeline can access performance metrics via `awesome_architecture.query(metric, dataset)`. All possible metrics are defined in `search_spaces/core/query_metrics.py`. The dataset is the dataset string provided by the tabular benchmark."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "    def query(self, metric=None, dataset=None):\n",
    "        # nasbench 301 query logic\n",
    "        raise NotImplementedError  # placeholder: the query logic is not shown in this walk-through"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
