Why Does DARTS Miss the Target, and How Do We Aim to Fix It?

In this blog post, we will dissect and explain ‘Rethinking Architecture Selection in Differentiable NAS’ from ICLR 2021 [Wang et al., 2021]. This paper is closely tied to two prior works: DARTS [Liu et al., 2019] and a response to DARTS [Zela et al., 2020]. We will first establish context by examining these works. Then, we will return focus to [Wang et al., 2021].

Neural Architecture Search (NAS)

NAS automates the process of discovering high-performance architectures for neural networks. Though many popular state-of-the-art models are handcrafted, NAS aims to develop methods that eliminate this manual design process. A few of the common approaches for NAS are Random Search, Evolutionary Algorithms, Reinforcement Learning, and One-Shot methods. A popular One-Shot algorithm for NAS, which is discussed in detail in this post, is Differentiable Architecture Search (DARTS) [Liu et al., 2019].

Differentiable Architecture Search (DARTS)

DARTS represents the neural architecture (or just a cell) as a directed acyclic graph (DAG) and then introduces a continuous relaxation on each edge of this graph. This allows training of both the architecture structure and its weights via gradient descent.

Representing the Neural Architecture as a DAG in DARTS

The final neural architecture is built from smaller repeatable units known as cells. Rather than searching for the optimal structure of the entire model, DARTS only searches for the best structure of a cell. The cell can then be stacked to form a convolutional network or recursively connected to form a recurrent network.

A cell is represented as a directed acyclic graph consisting of an ordered sequence of nodes. Each node $ x^{(i)} $ is a latent representation (e.g. a feature map in a convolutional network) and each directed edge $ (i,j) $ is associated with some operation $ o^{(i,j)} $ (e.g., convolution, max pooling) that transforms $ x^{(i)} $ as:

\[\begin{equation} x^{(j)} = \sum_{i < j} o^{(i,j)}(x^{(i)}) \end{equation}\]
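
To make the DAG formulation concrete, here is a minimal PyTorch sketch of a cell that computes each node as the sum of transformed predecessor nodes. It is an illustrative toy, not the DARTS implementation: the single cell input and the plain 3 × 3 convolution on every edge are simplifying assumptions.

```python
import torch.nn as nn

class TinyCell(nn.Module):
    """Toy cell: node j is the sum of o^(i,j)(x^(i)) over all predecessors i < j."""

    def __init__(self, channels, num_nodes=4):
        super().__init__()
        self.num_nodes = num_nodes
        # one operation per directed edge (i, j) with i < j (here: always a 3x3 conv)
        self.ops = nn.ModuleDict({
            f"{i}_{j}": nn.Conv2d(channels, channels, 3, padding=1)
            for j in range(1, num_nodes) for i in range(j)
        })

    def forward(self, x0):
        states = [x0]  # node 0 holds the cell input
        for j in range(1, self.num_nodes):
            # x^(j) = sum_{i<j} o^(i,j)(x^(i))
            states.append(sum(self.ops[f"{i}_{j}"](states[i]) for i in range(j)))
        return states[-1]
```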

Continuous Relaxation for DARTS

The search space is made continuous by relaxing the categorical choice of a particular operation to a softmax over all possible operations for each edge. The resulting mixed operation for an edge $ (i,j) $ is defined as:

\[\begin{equation} \bar{o}^{(i,j)}(x) = \sum_{o\in\mathcal{O}} \frac{\exp(\alpha_o^{(i,j)})}{\sum_{o'\in\mathcal{O}} \exp(\alpha_{o'}^{(i,j)})} o(x) \end{equation}\]

where $ \mathcal{O} $ is the set of all candidate operations and $ \alpha_{o'}^{(i,j)} $ is the weight of a particular operation $ o' $ on the edge $ (i,j) $.

At the end of search, the cell architecture can be obtained by replacing each mixed operation $ \bar{o}^{(i,j)} $ with the most likely operation, i.e.,

\[\begin{equation} o^{(i,j)} = \arg\max_{o\in\mathcal{O}} \alpha_{o}^{(i,j)} \end{equation}\]
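
Both ingredients, the softmax mixture used during search and the argmax used at the end, fit in a short PyTorch sketch. The candidate set below is an illustrative assumption (a plain convolution stands in for SepConv), not the full DARTS operation space.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

CANDIDATES = {
    "conv_3x3":     lambda c: nn.Conv2d(c, c, 3, padding=1),  # stand-in for SepConv
    "max_pool_3x3": lambda c: nn.MaxPool2d(3, stride=1, padding=1),
    "skip_connect": lambda c: nn.Identity(),
}

class MixedOp(nn.Module):
    """Continuous relaxation of the operation choice on a single edge (i, j)."""

    def __init__(self, channels):
        super().__init__()
        self.ops = nn.ModuleList(build(channels) for build in CANDIDATES.values())
        # one architecture parameter alpha_o per candidate operation
        self.alpha = nn.Parameter(torch.zeros(len(self.ops)))

    def forward(self, x):
        weights = F.softmax(self.alpha, dim=0)          # softmax over operations
        return sum(w * op(x) for w, op in zip(weights, self.ops))

    def discretize(self):
        # DARTS keeps the operation with the largest architecture parameter
        return list(CANDIDATES)[self.alpha.argmax().item()]
```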

After the continuous relaxation, the goal is to jointly learn the optimal architecture parameters $ \alpha $ and the weights $ w $ within all the mixed operations. This is formulated as the bilevel optimization problem:

\[\begin{align} \min_\alpha \quad & \mathcal{L}_\text{val} (w^*(\alpha), \alpha) \\ \text{s.t.} \quad & w^*(\alpha) = \arg\min_w \mathcal{L}_\text{train} (w, \alpha) \end{align}\]

where \(\mathcal{L}_\text{val}\) and \(\mathcal{L}_\text{train}\) are the validation and training losses respectively.

DARTS Pseudocode image

The pseudocode of the iterative gradient-descent procedure for this bilevel optimization problem is given above. (Image source: [Liu et al., 2019])
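
A schematic, first-order version of that loop is easy to write down. In the sketch below, `supernet`, `loss_fn`, the batches, and the two optimizers are placeholders, and the second-order gradient correction described in the DARTS paper is omitted.

```python
def search_step(supernet, train_batch, val_batch, loss_fn, w_opt, alpha_opt):
    """One iteration of the alternating (first-order) bilevel update."""
    # 1) update the architecture parameters alpha on a validation batch
    x_val, y_val = val_batch
    alpha_opt.zero_grad()
    loss_fn(supernet(x_val), y_val).backward()
    alpha_opt.step()

    # 2) update the network weights w on a training batch
    x_train, y_train = train_batch
    w_opt.zero_grad()
    loss_fn(supernet(x_train), y_train).backward()
    w_opt.step()
```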

DARTS Overview image

An overview of DARTS is given above: (a) Initial unknown operations on the edges. (b) Continuous relaxation of the search space by placing a mixture of candidate operations on each edge. (c) Solving the bilevel optimization problem to get optimal architecture and weights. (d) Deriving the final architecture from the learned probabilities (also known as discretization). (Image source: [Liu et al., 2019])

An analysis of the failure modes of DARTS

While continuous relaxation of the search space allows gradient based learning of architecture, several researchers ([Li & Talwalkar, 2019]; [Sciuto et al., 2019]) have reported poor performance from DARTS. In [Zela et al., 2020], the authors analyze the reason for the poor performance of DARTS.

Overfitting to the validation set

Improvement in the validation loss is not a sufficient indicator of DARTS' performance, since DARTS directly optimizes the validation loss. Thus, while DARTS steadily reduces the validation loss over the course of search, the resulting model can generalize very poorly to the test set. This is analogous to how training-set performance of a deep learning model can be a misleading indicator of its ability to generalize to examples from a held-out test set.

Relating overfitting to curvature of loss in the parameter space

There has been past work correlating the generalization ability (reduced overfitting) of models with the curvature of the loss function in parameter space. For instance, [Hochreiter & Schmidhuber, 1997] show that flat minima of the training loss give better generalization than sharp minima. The same phenomenon has been observed when hyperparameters overfit to validation sets: [Nguyen et al., 2018] show that when the minima of the validation loss lie in a sharp region of the hyperparameter space, the model tends to generalize poorly.

Curvature of loss function image

In the above figure, $\alpha^{\star}$ is the architecture parameter value that minimizes the validation loss, and $\alpha^{disc}$ is the value obtained after discretizing $\alpha^{\star}$. The figure shows the change in performance caused by discretization in two regions of the validation loss with very different curvature. We see that in sharp-curvature regions, the performance after discretization is very different from the performance before discretization (Image source: [Zela et al., 2020]).

Failure modes of DARTS

The authors construct several search spaces to demonstrate the poor architecture selection and test-set performance of DARTS. We start with four smaller subspaces.

These smaller search spaces are used to demonstrate the degenerate architectures selected by DARTS. Each uses the same macro architecture as the standard DARTS paper (normal and reduction cells), but only a subset of the operations is considered on every edge. Three of these search spaces are strict subspaces of the original space, so they should be easier to search than the original space. The fourth includes a Noise operation, added specifically to expose a failure mode of DARTS: selecting a harmful operation. The four spaces are outlined as follows:

S1 : This space uses only two operators per edge. These operators are selected by an offline process that iteratively drops the least important operations from the original DARTS search space. This makes S1 a “pre-optimized” space: quite small, yet still supporting many strong architectures.

S2 : In this space, the set of candidate operations per edge is {3 × 3 SepConv, SkipConnect}. These operations are chosen since they are the most frequent ones as reported in [Liu et al., 2019].

S3 : This space uses the following operations: {3 × 3 SepConv, SkipConnect, Zero}, where the Zero operation replaces every value in the input feature map with zero.

S4 : This is the only space that is not a strict subspace of the original space. It consists of only two operations: {3 × 3 SepConv, Noise}, where the Noise operation simply replaces every value in the input feature map with noise drawn from a standard normal distribution. This operation actively harms the network and should never be selected.

Further, it is interesting to note that the only operation with learnable parameters is the 3 × 3 SepConv.
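
As a concrete illustration, the Noise operation described above amounts to a module that ignores its input entirely. This is our own minimal sketch, not code from the paper.

```python
import torch
import torch.nn as nn

class NoiseOp(nn.Module):
    """S4's harmful operation: discard the input, return standard normal noise."""

    def forward(self, x):
        return torch.randn_like(x)  # same shape as x, but none of its information
```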

Consider the learnt architecture when DARTS searches these spaces while training on the CIFAR-10 dataset. These architectures are presented in the figure below (Image source: [Zela et al., 2020]).

DARTS Learned Architectures image 1 DARTS Learned Architectures image 2

It is immediately apparent that the parameter-free skip connections dominate the selected operations in search spaces S1-S3, and that in S4 even the harmful Noise operation has been selected on 5 of the 8 edges.

Having demonstrated that DARTS chooses degenerate architectures, the authors construct one more space to test the hypothesis that DARTS overfits to the validation set while performing poorly on the test set.

S5 : A very small search space with a known global optimum. Here the authors use only one intermediate node in both the normal and the reduction cell, with three possible operation choices per edge: 3 × 3 SepConv, SkipConnect, and 3 × 3 MaxPooling. The point is that the total number of possible architectures is only 81, so all of them can be evaluated a priori to find the global optimum. This lets us measure the “regret” of DARTS, i.e., how far its chosen architecture falls short of that optimum.
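
Because the space is so small, it can be enumerated exhaustively. A sketch of that counting argument, under the assumption of two searchable input edges per cell (one intermediate node) and three candidate operations per edge:

```python
import itertools

OPS = ["sep_conv_3x3", "skip_connect", "max_pool_3x3"]
EDGES = ["normal_0", "normal_1", "reduce_0", "reduce_1"]  # assumed edge layout

# one operation per edge: 3 ** 4 = 81 candidate architectures in total
architectures = list(itertools.product(OPS, repeat=len(EDGES)))
print(len(architectures))  # 81
```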

To put the performance of DARTS in perspective, it is compared against a baseline of Random Search with weight sharing (RS-ws) by [Li & Talwalkar, 2019]. The authors run DARTS three times on this search space and track the DARTS test regret, the DARTS one-shot validation error, and the RS-ws test regret. While DARTS initially manages to find an architecture close to the global minimum, at around epoch 40 the test performance starts to deteriorate. Yet the search model's validation error does not deteriorate; it continues to converge, indicating that the architecture parameters are overfitting to the validation set. The RS-ws baseline, in contrast, stays relatively constant throughout the search and eventually outperforms DARTS (Image source: [Zela et al., 2020]).

Regret Plots Image

Role of the dominant eigenvalue of the Hessian of the validation loss

Having demonstrated that DARTS overfits to the validation set, the authors extend this analysis by tracking the growth of the dominant eigenvalue of the Hessian of the validation loss with respect to the architecture parameters $\alpha$. This is useful because the largest eigenvalue ($\lambda_{max}^{\alpha}$) of this Hessian is a proxy for the curvature of the validation loss in that region of $\alpha$. They observe that $\lambda_{max}^{\alpha}$ increases over the course of standard DARTS search, along with the test error of the final architectures (middle), while the validation error still decreases (left) (Image source: [Zela et al., 2020]).

Eigenvalue Plot Image
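
For readers who want to track $\lambda_{max}^{\alpha}$ themselves, it can be estimated with power iteration on Hessian-vector products. The sketch below mirrors the quantity analyzed by [Zela et al., 2020] but is not their code; `alphas` is assumed to be the list of architecture parameter tensors and `val_loss` a freshly computed validation loss.

```python
import torch

def dominant_eigenvalue(val_loss, alphas, iters=20):
    """Estimate lambda_max of d^2(val_loss)/d(alpha)^2 via power iteration."""
    grads = torch.autograd.grad(val_loss, alphas, create_graph=True)
    v = [torch.randn_like(a) for a in alphas]
    for _ in range(iters):
        # Hessian-vector product: Hv = d(grad . v)/d(alpha)
        hv = torch.autograd.grad(grads, alphas, grad_outputs=v, retain_graph=True)
        norm = torch.sqrt(sum((h ** 2).sum() for h in hv))
        v = [h / (norm + 1e-12) for h in hv]
    hv = torch.autograd.grad(grads, alphas, grad_outputs=v, retain_graph=True)
    return sum((h * u).sum() for h, u in zip(hv, v)).item()  # Rayleigh quotient
```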

Alternate failure explanations for DARTS

[Wang et al., 2021] also believe that DARTS is flawed, but they argue that the main issues are different from those proposed by [Zela et al., 2020]. While [Zela et al., 2020] focused on showing that the optimization landscape of DARTS' supernet lends itself to poorly generalizing solutions, [Wang et al., 2021] focus on an implicit assumption in the DARTS concept (one that might have jumped out to you when first reading the breakdown of DARTS): the idea that a high architecture mixing weight indicates an important connection, and vice versa.

Assumptions

It is not at all uncommon in the literature to assume that a connection assigned a large weight must represent a useful computation. As a brief aside, consider the field of network pruning: the discipline of removing components from already-trained neural networks to obtain compressed models that require less storage and computation while producing acceptably comparable results. [Han et al., 2015] prune the smallest-weight connections in a network, while [Li et al., 2017] remove convolutional filters with the smallest L1 norm (that is, the smallest sum of the magnitudes of their weights). [Li et al., 2017] argue that filters with small-magnitude weights produce output feature maps with small expected activation magnitude. While these methods show good empirical results, they hinge on the critical assumption that the magnitude of activation in an intermediate layer indicates that layer's relevance to the network's final prediction.
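
As a toy illustration of that magnitude heuristic (in the spirit of [Li et al., 2017], but not their code), one can rank the filters of a convolutional layer by the L1 norm of their weights and flag the smallest fraction for removal; the pruning ratio here is arbitrary.

```python
import torch
import torch.nn as nn

def smallest_filters(conv: nn.Conv2d, ratio: float = 0.25):
    """Return the indices of the filters with the smallest L1 weight norms."""
    # one L1 norm per output filter (summed over in_channels x kH x kW)
    norms = conv.weight.detach().abs().sum(dim=(1, 2, 3))
    num_prune = int(ratio * conv.out_channels)
    return torch.argsort(norms)[:num_prune]

print(smallest_filters(nn.Conv2d(16, 32, kernel_size=3)))  # 8 candidate filters
```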

DARTS operates on a similar assumption: that a large learned architecture parameter implies that the corresponding network connection is critical to later computation, and therefore key to strong network accuracy. But is this assumption really reasonable? And more importantly, does it lead to the best possible performance?

Evidence

[Wang et al., 2021] present evidence to suggest that the connections with the largest architecture parameters are not necessarily the most important connections. From a pretrained DARTS supernet, they randomly choose three edges and compare the softmaxed magnitude of the architecture parameter assigned to each choice of operation to the final validation network accuracy obtained by using only that operation and removing the others. Results are shown below (Image source: [Wang et al., 2021]).

Comparison between architecture parameters and post-train accuracy

There are a few striking takeaways from these plots. First, the best post-discretization accuracy was never obtained from the operation with the largest architecture parameter. Second, if one were to select the operation with the largest architecture parameter, in 2 out of 3 cases there would exist a different option with at least 1% potential validation accuracy improvement.
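
A hedged sketch of how such a comparison could be set up is shown below. It is our reconstruction, not the authors' tooling: `discretize_edge_to` and `evaluate` are hypothetical helpers, and `mixed_op.alpha` assumes an edge module that holds its architecture parameters.

```python
import torch.nn.functional as F

def compare_edge(supernet, mixed_op, edge_id, val_loader, discretize_edge_to, evaluate):
    """Pair each operation's softmaxed alpha with its post-discretization accuracy."""
    weights = F.softmax(mixed_op.alpha.detach(), dim=0).tolist()
    results = []
    for op_idx, weight in enumerate(weights):
        pruned = discretize_edge_to(supernet, edge_id, op_idx)  # keep only op_idx here
        acc = evaluate(pruned, val_loader)                      # tune briefly, then measure
        results.append({"op": op_idx, "alpha_weight": weight, "val_acc": acc})
    return results
```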

Deep dive: skip connections

Earlier, we showed that in several of the failure cases for DARTS, the learned cell architecture is full of skip connections. [Wang et al., 2021] ask whether this is due not to any inherent flaw in the supernet optimization, but instead to the fact that skip connections are naturally very likely to be assigned large architecture parameters. If this were the case, these failure modes could be addressed by dropping the architecture-parameter-magnitude heuristic for operation selection.

The justification for why this occurs builds on an argument by [Greff et al., 2017] and [Veit et al., 2016] about interpreting networks that rely on skip connections (e.g., ResNets) as a form of ensemble model composed of different network paths. This concept is called unrolled estimation.

Demonstration of residual network as ensemble of network paths

The graphic above, from [Veit et al., 2016], illustrates this concept clearly. Essentially, any time a skip connection exists alongside a parameterized connection (e.g., convolutional layer), these two routes represent different branches in a network path to the end of the network. By taking each distinct combination of branches as a different path, we can view the entire network as a combination or ensemble of different dataflow paths, each with different amounts of skip connection vs. convolutional connection. At the end, some combination of these network paths’ outputs is used as the final output.

We recall that traditional deep networks are believed to train stacked layers that have different ideal feature representations, potentially at different depths of abstraction. [Veit et al., 2016] and [Greff et al., 2017] argue that networks performing unrolled estimation work differently, in that each layer actually approximates the same optimal feature map. In other words, each cell’s ideal input and output feature map are similar, so the input-to-output change becomes quite small.

Imagine that we let a network performing unrolled estimation train to convergence. At convergence, we expect the input and output feature maps of each cell to be similar. We can therefore expect a certain degree of reliance on the skip connection, as it is the best operation for producing an output feature map that resembles the input feature map. The operation that contributes more to the combined feature map (i.e., the one with the larger architecture parameter) will often be the one whose output stays closest to the input and introduces the least variance, which will typically be the skip connection (convolutional layers introduce additional variance).

However, if we choose a skip connection every time, we will learn a degenerate network, as evidenced in the [Zela et al., 2020] failure modes. So we can’t rely on only choosing the highest architecture parameter in all circumstances.

Improvements on DARTS

Perturbation

In place of the original DARTS discretization strategy, which selects the operations with the largest architecture parameters, [Wang et al., 2021] propose a perturbation-based algorithm for discretization. The algorithm examines which operations, when removed from the supernet, damage the validation accuracy the most; these are deemed the most important.

Perturbation Algorithm image

The pseudocode above, from [Wang et al., 2021], describes the perturbation algorithm in more detail. We’ll also summarize the algorithm in plain English:

Perturbation Algorithm summary

  1. Optimize the DARTS supernet in the standard bilevel way.

  2. Discretize the edges one at a time, in a random order (see the sketch after this summary).

    • For the current edge, remove (mask out) each candidate operation in turn and measure how much the supernet's validation accuracy drops.

    • Keep the operation whose removal causes the largest accuracy drop, discard the other candidate operations on that edge, and briefly tune the supernet before moving on to the next edge.
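
The summary above translates into a short greedy loop. The sketch below is our reading of the procedure, not the authors' code: `evaluate`, `remove_op`, `keep_only`, and `tune` are assumed helpers that measure validation accuracy, mask a single candidate operation on an edge, discretize an edge to one operation, and briefly retrain the supernet weights, respectively.

```python
import copy
import random

def perturbation_select(supernet, edges, num_ops, val_loader,
                        evaluate, remove_op, keep_only, tune):
    """Greedy, perturbation-based operation selection, one edge at a time."""
    selection = {}
    for edge in random.sample(edges, len(edges)):      # random discretization order
        base_acc = evaluate(supernet, val_loader)
        # importance of an operation = accuracy drop when it is masked out
        drops = []
        for op in range(num_ops):
            perturbed = remove_op(copy.deepcopy(supernet), edge, op)
            drops.append(base_acc - evaluate(perturbed, val_loader))
        best_op = max(range(num_ops), key=lambda i: drops[i])
        selection[edge] = best_op
        supernet = keep_only(supernet, edge, best_op)  # discretize this edge
        tune(supernet)                                 # short recovery phase
    return selection
```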

Perturbation Algorithm results

[Wang et al., 2021] describe the success of the perturbation algorithm in detail. We list here a small portion of the evaluation results, comparing performance of DARTS and several closely related NAS algorithms to the same algorithms with perturbation applied (perturbation is signified with +PT).

Architecture                          Test Error (%)   Params (M)   Search Cost (GPU Days)
DARTS [Liu et al., 2019]              3.00 ± 0.14      3.3          0.4
SDARTS-RS [Chen & Hsieh, 2020]        2.67 ± 0.03      3.4          0.4
SGAS (Crit.1 avg) [Li et al., 2020]   2.66 ± 0.24      3.7          0.25
DARTS+PT (avg, four seeds)            2.61 ± 0.08      3.0          0.8
DARTS+PT (best)                       2.48             3.3          0.8
SDARTS-RS+PT (avg, four seeds)        2.54 ± 0.10      3.3          0.8
SDARTS-RS+PT (best)                   2.44             3.2          0.8
SGAS+PT (Crit.1 avg, four seeds)      2.56 ± 0.10      3.9          0.29
SGAS+PT (Crit.1 best)                 2.46             3.9          0.29

The takeaway is that the perturbation algorithm improves the test error for each of these approaches while requiring fewer parameters (except for SGAS, where the parameter count increases). The cost is some additional GPU search time.

Reflections on the Perturbation Algorithm

The perturbation algorithm features a few heuristics that might leave us a little uneasy. The most important of these is its greediness. Recall that the algorithm selects the operation for a given edge by fixing the rest of the architecture, testing how much each option matters to the supernet's accuracy, and then fine-tuning before repeating. Essentially, the algorithm makes locally optimal choices for individual operations instead of considering multiple operation choices in conjunction. There is also some potential suboptimality in the randomized order in which the operation choices are made (it might be worthwhile, for example, to make the most clear-cut operation decisions first and save the tougher choices for later, in the hope that they become clearer by then).

Future improvements on DARTS

There may be a fix for this bias towards local greedy optimization: consider combinations of operation selections! We might think of trying to pick 2 or even 3 operations at once by considering joint optimality, at the cost of some computational complexity. If we had compute to spend, we could scale up the number of operations until we’re jointly selecting every operation in the network at once. The trouble, however, is that the further one goes down this path, the closer one gets to the computationally expensive discretized neural architecture search of the past. In this way, we might lose what was truly efficient about DARTS: the continuous relaxation.

Earlier, we drew a connection between pruning the supernet of DARTS and the more general field of network pruning and compression. As a final note, we suggest that future methods to improve DARTS, without reverting too far into expensive discretized NAS, might draw from different approaches to network pruning. For example, in the vein of the work of [Kameyama & Kosugi, 1991], who chose to merge same-layer neurons with highly correlated output activations, we might consider eliminating supernet connections on the basis of statistical redundancy with one another. Alternatively, inspired by ThiNet [Luo et al., 2017], one might try a variation of the perturbation algorithm that prunes connections in the supernet so as to minimize the reconstruction error from each layer to the next. If the reconstruction error after the discretized layer is kept small, fine-tuning might be even more computationally efficient, allowing further strides toward computationally tractable NAS.

Bibliography

Xiangning Chen and Cho-Jui Hsieh. Stabilizing differentiable architecture search via perturbation-based regularization. In Hal Daume III and Aarti Singh (eds.), Proceedings of the 37th International Conference on Machine Learning, volume 119 of Proceedings of Machine Learning Research, pp. 1554–1565. PMLR, 13–18 Jul 2020. URL http://proceedings.mlr.press/v119/chen20f.html.

Klaus Greff, Rupesh K. Srivastava, and Jürgen Schmidhuber. Highway and residual networks learn unrolled iterative estimation. In International Conference on Learning Representations (ICLR), 2017.

Song Han, Jeff Pool, John Tran, and William J. Dally. Learning both Weights and Connections for Efficient Neural Networks. In C. Cortes, N. Lawrence, D. Lee, M. Sugiyama, and R. Garnett (eds.), Advances in Neural Information Processing Systems 28. Curran Associates, Inc., 2015.

Sepp Hochreiter and Jürgen Schmidhuber. Flat minima. Neural Computation, 9(1):1–42, January 1997.

Keisuke Kameyama and Yukio Kosugi. Automatic Fusion and Splitting of Artificial Neural Elements in Optimizing the Network Size. In IEEE International Conference on Systems, Man, and Cybernetics, 1991.

Guohao Li, Guocheng Qian, Itzel C. Delgadillo, Matthias Müller, Ali Thabet, and Bernard Ghanem. SGAS: Sequential greedy architecture search. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 1620–1630, 2020.

Hao Li, Asim Kadav, Igor Durdanovic, Hanan Samet, and Hans Peter Graf. Pruning Filters for Efficient ConvNets. In International Conference on Learning Representations (ICLR), 2017.

Liam Li and Ameet Talwalkar. Random Search and Reproducibility for Neural Architecture Search, 2019.

Hanxiao Liu, Karen Simonyan, and Yiming Yang. DARTS: Differentiable architecture search. In International Conference on Learning Representations (ICLR), 2019.

Jian-Hao Luo, Jianxin Wu, and Weiyao Lin. ThiNet: A Filter Level Pruning Method for Deep Neural Network Compression. In International Conference on Computer Vision (ICCV), 2017.

Thanh Dai Nguyen, Sunil Gupta, Santu Rana, and Svetha Venkatesh. Stable Bayesian optimization. International Journal of Data Science and Analytics, 6(4):327–339, Dec 2018.

Christian Sciuto, Kaicheng Yu, Martin Jaggi, Claudiu Musat, and Mathieu Salzmann. Evaluating the search phase of neural architecture search. arXiv preprint, 2019.

Andreas Veit, Michael Wilber, and Serge Belongie. Residual networks behave like ensembles of relatively shallow networks. In S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-Bianchi, and R. Garnett (eds.), Advances in Neural Information Processing Systems 29, pp. 550–558. Curran Associates, Inc., 2016.

Ruochen Wang, Minhao Cheng, Xiangning Chen, Xiaocheng Tang, and Cho-Jui Hsieh. Rethinking Architecture Selection in Differentiable NAS. In International Conference on Learning Representations (ICLR), 2021.

Arber Zela, Thomas Elsken, Tonmoy Saikia, Yassine Marrakchi, Thomas Brox, and Frank Hutter. Understanding and robustifying differentiable architecture search. In International Conference on Learning Representations (ICLR), 2020.