%\documentclass{uai2024} % for initial submission
\documentclass[accepted]{uai2024} % after acceptance, for a revised version; 
% also before submission to see how the non-anonymous paper would look like 
                        
%% There is a class option to choose the math font
% \documentclass[mathfont=ptmx]{uai2024} % ptmx math instead of Computer
                                         % Modern (has noticeable issues)
% \documentclass[mathfont=newtx]{uai2024} % newtx fonts (improves upon
                                          % ptmx; less tested, no support)
% NOTE: Only keep *one* line above as appropriate, as it will be replaced
%       automatically for papers to be published. Do not make any other
%       change above this note for an accepted version.

%% Choose your variant of English; be consistent
\usepackage[american]{babel}
% \usepackage[british]{babel}

%% Some suggested packages, as needed:
\usepackage{natbib} % has a nice set of citation styles and commands
    \bibliographystyle{plainnat}
    \renewcommand{\bibsection}{\subsubsection*{References}}
%\usepackage{mathtools} % amsmath with fixes and additions
%% \usepackage{siunitx} % for proper typesetting of numbers and units
%\usepackage{booktabs} % commands to create good-looking tables
%\usepackage{tikz} % nice language for creating drawings and diagrams

\usepackage{hyperref}
\usepackage{url}

\usepackage{graphicx,psfrag,amsmath,amsthm,amssymb}
\usepackage{url,color,booktabs}
\usepackage[ruled,noline]{algorithm2e}

\usepackage{enumitem}

\input{MACPdef}
\input{MACPdef-ams}

\graphicspath{{grf/}}

%% Provided macros
% \smaller: Because the class footnote size is essentially LaTeX's \small,
%           redefining \footnotesize, we provide the original \footnotesize
%           using this macro.
%           (Use only sparingly, e.g., in drawings, as it is quite small.)

\title{Adaptive Softmax Trees for Many-Class Classification}

% The standard author block has changed for UAI 2024 to provide
% more space for long author lists and allow for complex affiliations
%
% All author information is automatically removed by the class for the
% anonymous submission version of your paper, so you can already add your
% information below.
%
% Add authors
\author{Rasul Kairgeldin}
\author{Magzhan Gabidolla}
\author{Miguel {\'A}.\ Carreira-Perpi{\~n}{\'a}n}
% Add affiliations after the authors
\affil{Department of Computer Science and Engineering \\
  University of California \\
  Merced, CA, USA}
%  \texttt{mcarreira-perpinan@ucmerced.edu} \\

\begin{document}

\maketitle

\begin{abstract}
  NLP tasks such as language modeling or document classification involve classification problems with thousands of classes. In these situations, it is difficult to achieve high predictive accuracy, and the resulting model can be huge in both number of parameters and inference time. A recent, successful approach is the softmax tree (ST): a decision tree having sparse hyperplane splits at the decision nodes (which make hard, not soft, decisions) and small softmax classifiers at the leaves. Inference here is very fast because only a small subset of class probabilities needs to be computed, yet the model is quite accurate. However, a significant drawback is that it assumes a complete tree, whose size grows exponentially with depth. We propose a new algorithm to train a ST of arbitrary structure. The tree structure itself is learned optimally by interleaving steps that grow the structure with steps that optimize the parameters of the current structure. This makes it possible to learn STs that can grow much deeper but in an irregular way, adapting to the data distribution. The resulting STs considerably improve the predictive accuracy while reducing the model size and inference time even further, as demonstrated on datasets with thousands of classes. In addition, they are interpretable to some extent.
  
\end{abstract}

\vspace*{-0.5ex}
\section{Introduction}
\vspace*{-0.5ex}

Classification problems involving thousands to millions of classes occur naturally in many real-world applications. Examples include predicting the next word in a sentence, where the vocabulary size can be on the order of hundreds of thousands, and categorizing products for e-commerce systems, where the number of distinct labels can be on the order of millions. Designing fast yet accurate methods for these types of problems remains an active area of research.

A linear softmax model, either standalone or as the last layer in a neural network, is widely used for general classification problems. Its inference time, however, is proportional to the number of classes $K$, as it \emph{needs to evaluate the score for every class no matter the input}, which makes it very slow for large-$K$ classification problems. A natural way to speed it up would be through conditional computation during inference, so that only a small subset of classes needs consideration. Decision trees do this: they \emph{follow a single, instance-dependent root-leaf path during prediction}, and their inference time can potentially be logarithmic in the number of classes. However, traditional axis-aligned trees with constant-label leaves do not produce accurate results for problems with many classes \citep{ChoromLangfor15a}.

Recently, \citet{Zharmag_21d} proposed a novel \emph{Softmax Tree (ST)} model that strikes a good balance between linear methods and decision trees: the model takes the form of a (hard) decision tree with sparse oblique (linear) decision nodes and small softmaxes at the leaves. To learn these more complex forms of trees the authors adapt the recent Tree Alternating Optimization (TAO) algorithm \citep{CarreirTavall18a}, which can optimize various types of tree-based models but only of a fixed structure and size. Experimentally, STs demonstrate much faster inference than the linear classifier and other baselines, as well as being very accurate for these large classification tasks. However, \emph{a significant drawback is that they assume a complete tree structure, whose size grows exponentially with depth, and this limits their power in both accuracy and inference time}. We discuss this in detail in section~\ref{s:search-struct}, where we show that \emph{the key to achieving fast inference is to decrease the size of the leaf softmaxes by increasing the depth of the leaf path}. Thus, we propose a new model, \emph{Adaptive Softmax Trees (ASTs)}, where we jointly learn the structure and parameters of the tree by interleaving steps that grow the structure optimally with steps that optimize the parameters of the current structure. This makes it possible to learn ASTs that can grow much deeper but in an irregular way, adapting to the data distribution. As we show experimentally, the resulting ASTs considerably improve the predictive accuracy while reducing the number of parameters and inference time even further.

We now review related work (section~\ref{s:related}), discuss the difficulty of searching over tree structures (section~\ref{s:search-struct}), and describe the original softmax tree (ST) model and TAO-based optimization (section~\ref{s:ST}) and our proposed adaptive softmax trees (AST) (section~\ref{s:AST}). Then (section~\ref{s:expts}) we experimentally show the superiority of ASTs over STs and other baselines for several multi-class classification problems with a large number of classes and for language modeling.

\vspace*{-1ex}
\section{Related work}
\label{s:related}
\vspace*{-0.5ex}

\subsection{Softmax approximation}
\vspace*{-0.5ex}

While a softmax linear classifier defines a convex problem with the cross-entropy, it has long been recognized that training it with many classes is a huge computational bottleneck, so that one-vs-all can often be the only affordable option, in part due to its inherent parallelism \citep{Deng_10b}. Indeed, even the widely used, extremely efficient LIBLINEAR \citep{Fan_08a} implements one-vs-all but not the cross-entropy softmax. And, once trained, inference time in a large softmax is also very large, for example in a language model having a large vocabulary. Hence, much work has been devoted to approximating the softmax classifier. The Hierarchical Softmax (HSM) \citep{Goodman01a} addresses this by using a predetermined tree structure with linear decision nodes and fixed leaf labels (corresponding to the words in the vocabulary) to speed up the training of language models. Originally developed for a two-level tree, it has been extended to deeper architectures \citep{MorinBengio05a, MnihHinton09a}. The structure of the tree can be random, or based on word similarities \citep{Brown92a, Le_11a, Mikolov_13a}, on word frequencies \citep{Mikolov_13b, Le_13a}, or on speed-optimal dynamic programming \citep{ZweigMakary13a}, or it can be optimized for GPUs \citep{Grave_17a}. Training HSM-based language models is efficient (usually logarithmic in vocabulary size), but it leads to no speedup at inference time: during prediction, although some pruning is possible, an input instance is propagated to nearly all the leaves. Apart from HSMs, other methods of softmax approximation are possible, such as singular value decomposition \citep{Shim_17a} and model compression techniques such as pruning or quantization \citep{Deng_20a}.

\subsection{Decision tree methods}
\vspace*{-0.5ex}

Decision trees enjoy fast prediction and interpretability, but traditional methods such as CART \citep{Breiman_84a} have low accuracy for problems with many classes \citep{ChoromLangfor15a}. This is due to two reasons: firstly, a suboptimal training based on greedy recursive partitioning, where the tree parameters are fixed using a local heuristic as one grows the tree (so the result is not optimal in any sense); and secondly, a limited modeling ability because of using trees with axis-aligned splits (which are poorly suited for high-dimensional data) and constant-label leaves. The trees can be made more complex by allowing for oblique (hyperplane) splits \citep{Breiman_84a} and small linear classifiers at the leaves \citep{Daume_17a}. However, this leads to a more difficult optimization problem for which various heuristics have been proposed within the many-class setting \citep{Jernit_17a, Daume_17a}. \citet{Zharmag_21d} use Tree Alternating Optimization to learn these models, but it is limited to trees of fixed structure. Decision trees are usually ensembled to boost accuracy, but traditional implementations are not suitable for problems with many classes. \citet{Si_17a} adapt gradient boosting trees to output $\ell_0$-regularized sparse predictions and apply this to many-class problems. Besides tree-based techniques, other methods exist such as sampling \citep{Joshi_17a} and hashing \citep{Medini_19a}.

Instead, we build on the line of work initiated by the Tree Alternating Optimization (TAO) algorithm \citep{CarreirTavall18a} (described in section~\ref{s:ST}) and follow-up works. TAO is able to optimize a very general objective function over the parameters of a fixed-structure tree by repeatedly optimizing each node given the rest are fixed, and it achieves trees that are smaller but more accurate than traditional ones \citep{Zharmag_21b}. It makes it possible to train new types of trees, such as having sparse oblique splits \citep{CarreirTavall18a}, bivariate splits \citep{KairgelCarreir24a}, or having neural nets in the leaves \citep{ZharmagCarreir21a}. The ability to use different loss functions makes it possible to learn trees optimally for tasks where they were never or rarely used before, such as clustering \citep{GabidolCarreir22b}, dimensionality reduction \citep{ZharmagCarreir22a}, semi-supervised learning \citep{ZharmagCarreir22b}, imbalanced classification \citep{Gabidol_24a}, or for probing the meaning of individual neurons in neural nets via model distillation \citep{Hada_23a}, among others. Finally, TAO also makes it possible to ensemble trees into forests for classification or regression using bagging \citep{CarreirZharmag20a,ZharmagCarreir20a}, boosting \citep{Gabidol_22a,GabidolCarreir22a} or even a joint global optimization over all trees \citep{Carreir_23a}, which results in forests that are smaller but more accurate than traditional, axis-aligned ones, such as XGBoost \citep{ChenGuestr16a} or LightGBM \citep{Ke_17a}.

\subsection{Conditional computation}
\vspace*{-0.5ex}

There is growing interest in having neural networks use only a small portion of their computational graph to enable fast prediction. Although several works \citep{Shazeer_17a, Hazimeh_20a, VeitBelong18a} have shown promising results in terms of runtime and accuracy tradeoff, the non-differentiability of the conditional computation makes it difficult to apply gradient-based optimization \citep{Hazimeh_21a}. One workaround is to train a continuous model, such as a soft tree, and harden its decisions a posteriori, but this degrades accuracy, as observed by \citet{Zharmag_21d}. In our adaptive softmax trees, conditional computation is built in by design during training and inference.

\subsection{Growing neural nets and neural architecture search}
\vspace*{-0.5ex}

The idea of growing the neural architecture by adding more neurons during training has a long history \citep{FahlmanLebier90a,Gallan93b,Fritzk94a,BruskeSommer95a,Evci_22a}. A related, recently very active area that aims to learn an optimal neural net structure is Neural Architecture Search (reviewed by \citet{Elsken_19a,Ren_21a}). A major issue in learning/growing neural architectures is the vast number of choices: layerwise or depthwise growth, how to connect neurons, etc. With trees the search space is more directed: either one expands leaves or prunes nodes. \citet{Tanno_19a} adaptively grow and train neural trees using backpropagation, but the potential gains in inference speed are limited: firstly, their trees are soft, so an input instance has to follow all root-leaf paths (each with a positive probability), at a cost proportional to the number of leaves (typically exponential in the depth); and secondly, their trees are very small, having just a few leaves, which forces each node to use a relatively large neural net, so even if we follow a single path (as a fast approximation) it will still be computationally costly. In contrast, a deeper tree, having many, deep, lighter leaves, is faster at inference, and can still have high accuracy, as we show here with our adaptive softmax trees.

\section{Optimizing trees over parameters and structures: deep paths, thin softmaxes}
\label{s:search-struct}

Learning a tree-based model has two important difficulties. One is that the space of tree structures is huge: with $n$ nodes (in total), there are $\frac{1}{n+1}\binom{2n}{n}$ ordered trees \citep{Knuth97a}, which already exceeds one million for $n=14$. The other is that a (hard) tree defines a non-differentiable, highly non-convex optimization problem.
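
As a quick check of this count (a worked instance of the formula above, not an additional result), its values at $n = 13$ and $n = 14$ are
\begin{align*}
  \tfrac{1}{14}\tbinom{26}{13} &= \tfrac{10\,400\,600}{14} = 742\,900, \\
  \tfrac{1}{15}\tbinom{28}{14} &= \tfrac{40\,116\,600}{15} = 2\,674\,440,
\end{align*}
so $n = 14$ is the first size at which the count exceeds one million. The ratio between consecutive values is $2(2n+1)/(n+2)$, which approaches $4$, so each additional node roughly quadruples the number of candidate structures.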

The traditional, widely used approach for learning axis-aligned trees is based on \emph{greedy top-down induction} \citep{Breiman_84a,Quinlan93a}: starting from the root node, splits are recursively fixed (to optimize a local purity criterion) until the tree is fully grown. This is usually followed by a form of pruning to reduce overfitting. While suboptimal, this two-step process does effect a form of local search over tree structures and can produce adequate results with simple axis-aligned constant-leaf trees, but it works poorly with more complex trees, e.g.\ with oblique or neural nodes.

The \emph{Tree Alternating Optimization (TAO)} algorithm \citep{CarreirTavall18a}, reviewed in section~\ref{s:ST} for Softmax Trees, works by optimizing the parameters of each node in alternation, for a tree of a given structure. It does a much better job at optimizing a complex tree, as it can monotonically decrease an objective with a loss function, regularization term, and node models of general form. It also does a restricted form of structure search: an $\ell_1$ penalty sparsifies the node weight vectors, which can make nodes redundant so that they can be pruned, resulting in a learned structure that is a subtree of the initial tree. But, beyond that, TAO does not search over tree structures, and in particular, it cannot learn a bigger tree than the initial one.

The original Softmax Tree \citep{Zharmag_21d}, consisting of a tree with oblique (hard) splits and softmax leaves, relied directly on TAO to optimize the cross-entropy. As an initial tree, it used a complete tree of depth $\Delta$ and $2^{\Delta}$ softmaxes each having $k$ classes. By tuning these two hyperparameters $\Delta$ and $k$, it achieved good results on large, many-class datasets. But it has a major limitation: \emph{the number of nodes grows exponentially with the depth, which is thus computationally limited in memory and time (to $\Delta \approx 14$ in that paper), which in turn forces the softmaxes to use many classes ($k$ up to 800 in that paper)}. If the tree were deeper, the softmaxes could be smaller, accelerating inference. Crucially, depending on the data distribution, the tree may need to be quite deep in some parts and shallow in others, i.e., an unbalanced structure. If we could guess the right structure, we could have TAO use that from the beginning, but guessing it is far from simple. Using, say, a structure from a CART tree does not work at all. This calls for searching over structures properly as proposed in our \emph{Adaptive Softmax Trees (ASTs)}, described in section~\ref{s:AST}. And, as it turns out, we find in our experiments that ASTs achieve higher test accuracy than using a complete ST of the same depth (which is far more costly).

At the heart of the improvement of ASTs is the interplay between tree depth $\Delta$ and leaf softmax width $k$. Let $D \in \bbN$ be the feature dimensionality. In a complete ST, the inference time is $\calO(D (\Delta + k))$ (actually less if the weight vectors and softmaxes are sparse and some tree paths are shallower than $\Delta$), and typically $\Delta \ll k$. This already improves greatly over a single softmax, $\calO(D K)$, if $\Delta + k \ll K$. In ASTs, an irregular tree structure makes it possible to reduce $k$ further by increasing $\Delta$ selectively for each branch. Besides, our ASTs learn the number of classes $k_j$ for each leaf $j$ automatically, so that some leaves specialize on a few select classes while others handle more, which affords more speedups. The inference time is then $\calO(D (\Delta_j + k_j))$ for each leaf $j$, and usually larger $\Delta_j$ are associated with smaller $k_j$.
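
To make this comparison concrete, consider the following illustrative calculation. The values $\Delta = 14$ and $k = 800$ are the largest ones quoted above for the complete ST; the number of classes $K = 10^5$ and the irregular-leaf values $(\Delta_j, k_j) = (20, 50)$ are hypothetical, chosen only for illustration. Counting one $D$-dimensional inner product per decision node on the path and per class in the leaf softmax, and ignoring sparsity, we have
\begin{align*}
  \text{flat softmax:} &\quad K = 100\,000, \\
  \text{complete ST:} &\quad \Delta + k = 14 + 800 = 814, \\
  \text{AST leaf $j$:} &\quad \Delta_j + k_j = 20 + 50 = 70.
\end{align*}
Under these assumed values the complete ST needs about $123$ times fewer inner products than the flat softmax, and the deeper, narrower leaf a further factor of about $11.6$ fewer. A complete tree of depth $20$ would have $2^{20} \approx 10^6$ leaves, which is why such depths are only affordable along selected branches.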

\section{Softmax Trees (ST\lowercase{s}) and Tree Alternating Optimization (TAO)}
\label{s:ST}

We now describe the Softmax Tree (ST) model and the extension of TAO to train them over a fixed tree structure \citep{Zharmag_21d}. Let $\{(\x_n,y_n)\}^N_{n=1} \subset \bbR^D \times \{1,\dots,K\}$ be our training set of size $N$ of $D$-dimensional input features and $K$ classes. Write the \emph{Softmax Tree} as $\btau(\x; \bTheta)$, a rooted binary tree with a set of decision (internal) nodes $\calN_\text{dec}$ and a set of leaf nodes $\calN_\text{leaf}$. Each decision node $i \in \calN_\text{dec}$ has a decision function $g_i(\x; \btheta_i)\mathpunct{:}\ \bbR^D \to \{\text{\small\texttt{left}}_i,\text{\small\texttt{right}}_i\} \subset \{\calN_\text{dec} \cup \calN_\text{leaf}\}$ that sends an instance $\x$ to its left or to its right child. We use oblique (linear) decision nodes: ``$\text{if} \,\, \w_i^T \x + w_{i0} \ge 0 \,\, \text{then} \, g_i(\x) = \text{\small\texttt{right}}_i, \,\, \text{otherwise} \, g_i(\x) = \text{\small\texttt{left}}_i$'' where the learnable parameters are $\btheta_i = \{\w_i, w_{i0}\}$. Note how the decision function makes hard decisions, unlike in soft trees, where an instance $\x$ is propagated to both children with a positive probability. Each leaf $j \in \calN_\text{leaf}$ contains a predictive function $\f_j(\x; \btheta_j)\mathpunct{:}\ \bbR^D \to \bbS^K$ that produces the actual output of the tree $\btau(\x; \bTheta)$ for an instance \x, where $\bbS^K = \{ \x\in [0,1]^K \mathpunct{:}\ \1^T \x = 1 \}$. In Softmax Trees, $\f_j(\x; \btheta_j)$ takes the form of a small softmax linear classifier: $\f_j(\x; \btheta_j) = \sigma(\W_j \x + \w_{j0})$ where $\btheta_j = \{\W_j \in \smash{\bbR^{k \times D}}, \, \w_{j0} \in \smash{\bbR^{k}}\}$ are the learnable parameters, and $\sigma(\cdot)$ is the softmax function. 
The leaf predictor function $\f_j(\x; \btheta_j)$ can output only $k$ nonzero probabilities, with $k \le K$, for a set of $k$ classes (this set is learned); for all the other $K - k$ classes $\f_j(\x; \btheta_j)$ assigns exactly zero probability. For problems with a large number of classes we want $k \ll K$ to allow for fast inference. The predictive function of the whole Softmax Tree $\btau(\x; \bTheta)$ then works by routing an instance \x\ to exactly one leaf through a root-leaf path of (oblique) decision nodes and applying that leaf's small softmax predictor function. Overall, a ST can be seen as a hierarchical collection of local softmax classifiers each operating on a small subset of classes.
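
Written out per class, the prediction takes the following form. Here $j = j(\x) \in \calN_\text{leaf}$ and $C_j \subset \{1,\dots,K\}$ with $|C_j| = k$ are our own shorthand, introduced only for this illustration, for the leaf reached by $\x$ and the set of classes handled by that leaf; $z_c$ denotes the entry of $\W_j \x + \w_{j0}$ corresponding to class $c \in C_j$:
\begin{equation*}
  [\btau(\x; \bTheta)]_c =
  \begin{cases}
    \dfrac{\exp(z_c)}{\sum_{c' \in C_j} \exp(z_{c'})}, & c \in C_j, \\[2ex]
    0, & c \notin C_j.
  \end{cases}
\end{equation*}
Only the $k$ scores $z_c$ of the reached leaf are ever computed, which is the source of the speedup over a flat $K$-class softmax.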

Now we describe how the TAO algorithm applies in learning a ST. TAO is a general method for optimizing a given objective function over a given decision tree model. Unlike CART-type methods, TAO works similarly to how one would optimize a (say) neural network: by taking an initial tree structure (cf.\ network architecture) and parameters (cf.\ network weights) it performs alternating optimization over the nodes (cf.\ gradient descent in a neural net) to monotonically decrease the objective function. Unlike with neural nets and soft decision trees, gradient-based optimization is not applicable because hard decision trees are non-differentiable functions. Given a Softmax Tree $\btau(\x; \bTheta)$ of fixed structure (e.g.\ a complete tree of depth $\Delta$) and initial parameters (e.g.\ random), the goal of TAO is to minimize the following objective:
\begin{multline}
  \label{e:objfcn}
  E(\bTheta) = \sum^N_{n=1}{ L(\y_n,\btau(\x_n)) } + \\
  \lambda \sum_{i \in \calN_\text{dec}}{ \norm{\w_i}_1} + \mu \sum_{j \in \calN_\text{leaf}}{ \norm{\W_j}_1}
\end{multline}
where $L(\cdot,\cdot)$ is the cross-entropy loss, $\bTheta = \{\w_i, w_{i0}\}_{i \in \calN_\text{dec}} \cup \{\W_j, \w_{j0}\}_{j \in \calN_\text{leaf}}$ is the set of all learnable model parameters, and the $\ell_1$ penalties over the weight vectors promote sparsity via hyperparameters $\lambda, \mu \ge 0$. In general, we use the same regularization value for both decision nodes and leaves ($\lambda = \mu$), but in some experiments we explore the effect of the leaf sparsity $\mu$.

The TAO algorithm is based on two theorems. First, the \emph{separability condition} states that eq.~\eqref{e:objfcn} separates over a set of non-descendant nodes, e.g.\ all the nodes at a given depth. This is a consequence of the tree making hard decisions. All such non-descendant nodes can be optimized independently and in parallel. Second, the \emph{reduced problem over a node} states that optimizing the top-level problem of eq.~\eqref{e:objfcn} over the parameters of a given node $i \in \{\calN_\text{dec} \cup \calN_\text{leaf}\}$ reduces to a simpler, well-defined problem involving only the training instances that currently reach that node $i$ (the \emph{reduced set} $\calR_i \subset \{1,\dots,N\}$). The exact form of the reduced problem differs for leaves and for decision nodes:
\begin{itemize}[leftmargin=2ex]
  \item For a decision node $i \in \calN_\text{dec}$, the top-level problem of eq.~\eqref{e:objfcn} reduces to a \emph{weighted 0/1 loss binary classification problem}:
  \begin{equation}
    \label{e:rp-decision-nodes}
    \hspace{-1ex}E_i(\w_i, w_{i0})=\sum_{n \in \calR_i}{c_n \, \smash{\overline{L}}(\overline{y}_n, g_i(\x_n)) } + \lambda \, \norm{\w_i}_1
  \end{equation}
  where $\smash{\overline{L}}(\cdot,\cdot)$ is the 0/1 loss, $\overline{y}_n \in \{\text{\small\texttt{left}}_i, \text{\small\texttt{right}}_i\}$ is a pseudolabel indicating the ``best'' child (i.e., the child that gives the lower value of the loss down its subtree) for the instance $\x_n$, and $c_n \ge 0$ is the loss difference between the ``other'' child and the ``best'' child for the instance $\x_n$. This problem over an oblique node is in general NP-hard, but it can be approximated well with a surrogate loss such as the cross-entropy (i.e., solving a logistic regression). We can ensure a monotonic decrease of the top-level objective~\eqref{e:objfcn} by accepting the update only if it improves~\eqref{e:rp-decision-nodes} (in practice we find this unnecessary).
  \item For a leaf node $j \in \calN_\text{leaf}$, the top-level problem of eq.~\eqref{e:objfcn} reduces to a form involving the original loss but only over the parameters of the leaf predictor function $\f_j(\cdot)$ and its reduced set $\calR_j$:
  \begin{equation}
    \label{e:objfcn-leaf}
    \hspace{-1ex}E_j(\W_j, \w_{j0}) = \sum_{n \in \calR_j}{ L(\y_n,\f_j(\x_n)) } + \mu \, \norm{\W_j}_1
  \end{equation}
  where $L(\cdot,\cdot)$ is the same cross-entropy loss of eq.~\eqref{e:objfcn}. Exactly solving this problem would require enumerating all $\binom{K}{k}$ class subsets, but we can approximate this well by picking the top $k$ majority classes in the reduced set $\calR_j$ and training a $k$-class softmax classifier $\f_j(\cdot)$ on them. We solve the resulting $\ell_1$-regularized convex problem using SAG \citep{Schmid_17a}.
\end{itemize}
While these theorems do not prescribe the order in which the nodes should be optimized, \citet{Zharmag_21d} follow a reverse breadth-first search order: all the nodes at a given depth are optimized in parallel, starting from the deepest ones until the root. Each optimization subproblem involves solving either an $\ell_1$-regularized logistic regression or an $\ell_1$-regularized $k$-class softmax classifier. As an initial tree, a complete tree of a given depth $\Delta$ is used with initial parameters set either randomly or based on a $k$-means clustering assignment of training points to the leaves. The hyperparameters of the model are the depth $\Delta$ of the tree and the number of classes $k$ in each of the leaf softmaxes. Fig.~\ref{f:pseudocode} (left) outlines the pseudocode of TAO for Softmax Trees. By ensuring that the (approximate) solution of the reduced problem of a decision node improves upon the previous node parameter values, TAO is guaranteed to decrease the objective function~\eqref{e:objfcn} monotonically.
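
For concreteness, the quantities in the decision-node reduced problem~\eqref{e:rp-decision-nodes} can be written as follows. The notation $\ell_{n,s} = L(\y_n, \btau_s(\x_n))$ is our own shorthand, introduced only here, for the loss incurred when instance $\x_n$ is sent to child $s \in \{\text{\small\texttt{left}}_i, \text{\small\texttt{right}}_i\}$ of node $i$ and propagated down the subtree $\btau_s$ rooted there, with that subtree's parameters held fixed:
\begin{align*}
  \overline{y}_n &= \arg\min_{s} \; \ell_{n,s}, \\
  c_n &= \bigl| \ell_{n,\text{\small\texttt{left}}_i} - \ell_{n,\text{\small\texttt{right}}_i} \bigr| .
\end{align*}
An instance for which both children incur the same loss has $c_n = 0$ and does not influence the node. One possible form of the logistic surrogate mentioned above (a sketch only; the exact weighting and scaling used in practice are assumptions here) encodes the pseudolabel as $t_n = +1$ for $\text{\small\texttt{right}}_i$ and $t_n = -1$ for $\text{\small\texttt{left}}_i$, consistently with the decision rule of section~\ref{s:ST}, and minimizes
\begin{equation*}
  \sum_{n \in \calR_i} c_n \log\bigl(1 + e^{-t_n (\w_i^T \x_n + w_{i0})}\bigr) + \lambda \, \norm{\w_i}_1 .
\end{equation*}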

Finally, node pruning occurs automatically because the $\ell_1$ penalty can drive a node's entire weight vector to zero. This makes the node redundant (it sends all instances to the same child) and it can be removed at the end. Thus, the final ST is a subtree of the initial (complete) ST.

\section{Adaptive Softmax Trees (AST\lowercase{s})}
\label{s:AST}

\begin{figure*}[t]
  \centering
  \setlength{\fboxsep}{1ex}
  \begin{tabular}{@{}c@{\hspace{7ex}}c@{}}
    \framebox{
      \begin{minipage}[c]{0.95\linewidth}
        \begin{tabbing}
          n \= n \= n \= n \= n \= \kill
          \underline{\textbf{input}} training set $\{\x_n, y_n\}_{n=1}^N$, \+ \\
          Softmax Tree $\btau(\cdot; \bTheta)$ of depth $\Delta$. \-\\[0.5ex]
          %        regularization parameter $\alpha \geq 0$\\
          \underline{\textbf{repeat}} \+ \\
          update reduced sets $\calR_i$ for all nodes $i$; \\
          \underline{\textbf{for}} $d = \Delta$ \textbf{downto} $0$ \textbf{do} \+ \\
          \underline{\textbf{for}} $i \in \text{nodes at depth $d$}$ \+ \\
          \underline{\textbf{if}} $i$ is a leaf \+ \\
          fit a $k_i$-class linear softmax $\f_i(\cdot;\btheta_i)$\\ 
          on the top-$k_i$ majority class points \\
          in $\calR_i$ to optimize eq.~\eqref{e:objfcn-leaf} \- \\
          \underline{\textbf{else}} \+ \\
          %        setup the reduced problem of $i$, \\
          fit a weighted 0/1 loss binary \\
          classifier $g_i(\cdot; \btheta_i)$ to optimize eq.~\eqref{e:rp-decision-nodes} \- \\
          \underline{\textbf{end if}} \- \\
          \underline{\textbf{end for}} \- \\
          \underline{\textbf{end for}} \- \\
          \underline{\textbf{until}} stopping criterion \\
          \underline{\textbf{return}} trained $\btau(\cdot; \bTheta)$
        \end{tabbing}
      \end{minipage}
    } &
    \framebox{
      \begin{minipage}[c]{0.95\linewidth}
        \begin{tabbing}
          n \= n \= n \= n \= n \= \kill
          \underline{\textbf{input}} training set $\{\x_n, y_n\}_{n=1}^N$, initial depth $\Delta_0$, \+ \\
          softmax contraction coefficient $\alpha \in (0,1)$, \\
          tolerance ratio for node expansion $\rho > 1$\-.\\[0.5ex]
          %        $k \gets \frac{K}{\alpha \, 2^{\Delta_0}}$ \\
          %        $k \gets K / (\alpha \, 2^{\Delta_0})$ \\
          $k_0 \gets \alpha K$ \\
          initialize $\btau(\cdot; \bTheta)$ of depth $\Delta_0$ and $k_0$-class leaves; \\
          fit $\btau(\cdot; \bTheta)$ using TAO; \\
          \underline{\textbf{repeat}} \+ \\
          update reduced sets $\calR_j$ for all $j \in \calN_\text{leaf}$; \\
          \underline{\textbf{for}} $j \in \calN_\text{leaf}$ \+ \\
          %        $\hat{k}_j \gets \frac{k_j}{2} \cdot \alpha$ \\
          %        $\hat{k}_j \gets \alpha \cdot k_j$ \\
          initialize ST $\hat{\btau}_j(\cdot; \hat{\bTheta}_j)$ of depth $\hat{\Delta} = 1 \text{ or } 2$ \\
          and with $(\alpha k_j)$-class softmax leaves; \\
          fit $\hat{\btau}_j(\cdot; \hat{\bTheta}_j)$ using TAO on $\{\x_n, y_n\}_{n \in \calR_j}$; \\
          \underline{\textbf{if}} \, $\frac{\mathtt{loss}(\hat{\btau}_j(\cdot; \hat{\bTheta}_j))}{\mathtt{loss}(\f_j(\cdot; \btheta_j))} < \rho$ \, \underline{\textbf{then}} accept the expansion of leaf $j$ \- \\
          \underline{\textbf{end for}} \\
          update the tree $\btau(\cdot; \bTheta)$ and reoptimize with TAO; \-\\
          \underline{\textbf{until}} no changes to the tree structure \\
          \underline{\textbf{return}} adaptively grown  $\btau(\cdot; \bTheta)$
        \end{tabbing}
      \end{minipage}
    }
  \end{tabular}
  \caption{\emph{Left}: pseudocode of TAO for learning Softmax Trees. \emph{Right}: pseudocode of the proposed adaptive growth method for ASTs; this uses TAO (left part) as a subroutine.}
  \label{f:pseudocode}
\end{figure*}

In the previous section, TAO was used on a complete tree of depth $\Delta$. We now extend it to search over structures. The basic idea is to use two types of steps. One is a \emph{regular} TAO optimization of a ST of fixed structure (not necessarily complete); this guarantees improvement of the objective defined \emph{globally} over this ST. The other is an \emph{expansion step} on a current leaf, which tries to replace it with a shallow ST (having narrower softmaxes at its leaves). This \emph{local move} can improve the loss function but at the cost of additional decision nodes; we actually expand the leaf if overall we improve; otherwise we leave it unexpanded and try another leaf. We interleave regular and expansion steps until convergence. Let us see this in more detail.

We first train a shallow (e.g.\ depth $\Delta_0 = 2$) complete Softmax Tree $\btau(\cdot;\bTheta)$ with relatively large $k_0$-class softmaxes in the leaves. The number of classes $k_0$ is set such that the total number of classes the model can predict is at least the total number of classes $K$ in the dataset: $k_0 2^{\Delta_0} \ge K$. We then attempt to replace the softmax predictor function $\f_j(\cdot;\btheta_j)$ of each leaf $j \in \calN_\text{leaf}$ by yet another shallow Softmax Tree $\smash{\hat{\btau}_j(\cdot;\hat{\bTheta}_j)}$ of depth $\smash{\hat{\Delta} = 1}$~or~2, whose leaves contain smaller $\smash{\hat{k}_j}$-class softmaxes, $\smash{\hat{k}_j < k_j}$, where $k_j$ is the current number of classes of leaf $j$ (initially $k_j = k_0$). To control by how much these large softmaxes are reduced we use the following simple heuristic: $\smash{\hat{k}_j = \alpha \, k_j}$, where $\alpha \in (0,1)$ is the \emph{softmax contraction coefficient hyperparameter}. We obtain this small tree $\smash{\hat{\btau}_j(\cdot;\hat{\bTheta}_j)}$ by fitting it using the TAO algorithm on the training instances that reach the leaf $j$, i.e., on the reduced set $\calR_j$. This step can be considered as a recursive application of the Softmax Tree method with the goal of replacing large, flat softmaxes with faster ``softmax subtrees''. But instead of directly substituting the leaf softmax $\f_j(\cdot;\btheta_j)$ with the tree $\smash{\hat{\btau}_j(\cdot;\hat{\bTheta}_j)}$, we first ensure that the accuracy of $\smash{\hat{\btau}_j(\cdot;\hat{\bTheta}_j)}$ is at least as good as that of the original softmax $\f_j(\cdot;\btheta_j)$ or within a reasonable \emph{tolerance ratio hyperparameter} $\rho > 1$ (measured as the ratio of their losses on $\calR_j$). If this is not the case, the leaf predictor function $\f_j(\cdot;\btheta_j)$ remains unchanged. Otherwise, the substitution happens, and this changes the structure of the original tree model $\btau(\cdot;\bTheta)$, which is expanded through the leaf $j$ (the \emph{expansion step}).
In this way, after attempting to expand all the leaves $j \in \calN_\text{leaf}$, and assuming some or all of them are expanded, we obtain a deeper, irregular Softmax Tree $\btau_\text{exp}(\cdot;\bTheta_\text{exp})$ with smaller leaf softmaxes which has comparable or better training accuracy and faster inference. Now, importantly, we retrain the whole model $\btau_\text{exp}(\cdot;\bTheta_\text{exp})$ globally using TAO (the \emph{regular step}), which will further improve the model accuracy and possibly sparsify nodes. We repeat these local expansion and global optimization steps until the model converges or some predetermined stopping criterion is reached. Note that if a given leaf $j$ could not grow at one expansion step, it can still grow in the next iteration because of the in-between optimization step which can change the parameters of the whole model. Fig.~\ref{f:pseudocode} outlines the proposed adaptive learning algorithm.
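
As an illustration of these quantities, consider hypothetical values that are not taken from our experiments: $K = 1000$ classes, initial depth $\Delta = 2$ and $\alpha = 0.75$. The rounding of $\smash{\hat{k}_j}$ to an integer and the error-based form of the acceptance test below are assumptions made for this sketch; $E_j(\cdot)$ denotes the misclassification error on the reduced set $\calR_j$.
\begin{align*}
  k_0 2^\Delta \ge K \;&\Rightarrow\; k_0 \ge 1000/4 = 250, \\
  \hat{k}_j = \alpha \, k_0 \;&\Rightarrow\; \hat{k}_j = 0.75 \times 250 = 187.5 \approx 188, \\
  \text{accept the expansion of leaf } j \;&\Leftrightarrow\; E_j(\hat{\btau}_j) \le \rho \, E_j(\f_j).
\end{align*}
Under this reading, $\rho = 1$ requires the subtree to be at least as accurate as the softmax it replaces, while $\rho = 1.2$ tolerates up to 20\% higher error on $\calR_j$.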

Note that the expansion move allows us to compare the objective function before and after the expansion in order to decide whether or not we should pursue a new architecture. This is possible by expanding a leaf subtree and optimizing it separately, which in turn is possible because of the separability condition that trees satisfy: the objective function separates additively over the leaf subtrees, because each leaf subtree operates only on its own parameters, its own region of the input space and its own reduced set (training instances reaching the leaf). Thus, the contribution to the overall (tree-wide) objective function of optimizing over an expanded leaf subtree is a separable term. Comparing the loss on the leaf reduced set before (softmax) and after (softmax subtree optimized only on that reduced set), together with the regularization term on the parameters, gives the exact improvement of the overall objective function. This makes it possible to decide locally whether to accept the expansion or not. The subsequent, global optimization with TAO of the expanded tree may, of course, undo some of the expansions, as well as update all the parameters and reduced sets. Also, the local expansion moves are fast thanks to using the existing weight matrix to warm-start the optimization in the leaves.
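
The separability argument can be written out as follows. This is only a sketch: the symbols $L$ (loss), $\phi$ (regularization penalty), $E_j$ and $\hat{E}_j$ are introduced here for illustration, and the precise loss and penalty are those of the objective defined earlier. For fixed decision nodes, the part of the objective that depends on the leaves is
\begin{equation*}
  \sum_{j \in \calN_\text{leaf}} E_j, \qquad E_j = \sum_{n \in \calR_j} L\bigl(y_n, \f_j(\x_n;\btheta_j)\bigr) + \phi(\btheta_j).
\end{equation*}
Replacing leaf $j$ by a subtree $\smash{\hat{\btau}_j}$ changes only its own term, to
\begin{equation*}
  \hat{E}_j = \sum_{n \in \calR_j} L\bigl(y_n, \hat{\btau}_j(\x_n;\hat{\bTheta}_j)\bigr) + \phi(\hat{\bTheta}_j),
\end{equation*}
so the overall objective changes by exactly $\hat{E}_j - E_j$, which can be computed from $\calR_j$ alone.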

This algorithm can be motivated as performing a search through the vast space of different tree structures and parameters. Each leaf-wise local expansion step tries to improve the tree structure, and the subsequent optimization step of the whole current tree tries to refine its parameters. This process leads to a better structure, often highly irregular and far from complete, and to better parameters than those produced by TAO on a random or heuristic complete tree initialization. The hyperparameters $\alpha$ and $\rho$ control how fine the search over structures is: the smaller $\alpha$, the faster the softmaxes contract (so shallower trees), and the smaller $\rho$, the more accurate the expanded subtree must be to be accepted. They also help to control overfitting.

\subsection{Computational complexity of AST\lowercase{s}}

\paragraph{Training}

It is difficult to estimate the training time precisely because of the changing tree structure and softmax sizes. A coarse upper bound results from taking the largest structure and softmax size $k_\text{max}$ that occur during training. If we assume that fitting softmax classifiers is linearly proportional to the training set size, then \emph{sequential} optimization of all the leaves is upper-bounded by fitting a single $k_\text{max}$-class softmax for the whole training set. But after several expansion steps, the softmax sizes are usually much smaller. Regarding the oblique decision nodes, optimizing \emph{sequentially} all of them at a given depth is asymptotically equivalent to fitting a single logistic regression on the whole training set. However, from TAO's separability condition, optimizing all the leaves and all the decision nodes at the same depth can be done \emph{in parallel}, which can bring huge speedups.
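
The sequential bound for the leaves follows from the fact that the reduced sets partition the training set. As a sketch, let $N_j = |\calR_j|$ and let $c(k)$ be the per-instance cost of fitting a $k$-class softmax, assumed nondecreasing in $k$; both symbols are introduced here only for illustration. Then
\begin{equation*}
  \sum_{j \in \calN_\text{leaf}} c(k_j) \, N_j \;\le\; c(k_\text{max}) \sum_{j \in \calN_\text{leaf}} N_j \;=\; c(k_\text{max}) \, N ,
\end{equation*}
where $N$ is the training set size. The right-hand side is the cost of fitting a single $k_\text{max}$-class softmax on all $N$ instances.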

\paragraph{Inference}

For the original Softmax Tree (assumed complete), the inference time is $\calO(D(\Delta + k))$. Compared to a single flat softmax on all $K$ classes, the speed-up is dramatic: $\calO(\frac{K}{\Delta + k}) \approx \calO(\frac{K}{k})$ if $k \approx \Delta + k \ll K$. For our AST, the inference time for an instance that reaches leaf $j$ is $\calO(D(\Delta_j + k_j))$, where $\Delta_j$ and $k_j$ are the depth and softmax size of that leaf. The improvement is that the AST has much smaller values of $k_j$ at the expense of slightly larger values of $\Delta_j$ (thin softmaxes in deep leaves).
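
As a rough worked example, we can plug in the ALOI dimensions ($D = 128$, $K = 1000$) and the complete ST with $\Delta = 6$ and $k = 75$ from table~\ref{t:comparison1}. The counts below are dense operation counts, which ignore any sparsity of the weights; they are therefore an upper bound for the trees and an assumption of this sketch, not the measured FLOPs.
\begin{align*}
  \text{flat softmax:}\quad & D K = 128 \times 1000 = 128\,000, \\
  \text{ST:}\quad & D(\Delta + k) = 128 \times (6 + 75) = 10\,368, \\
  \text{speed-up:}\quad & \frac{K}{\Delta + k} = \frac{1000}{81} \approx 12.3 .
\end{align*}
The measured FLOPs reported for the trees are lower than this dense count, as one would expect when the node and leaf weights are sparse.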

\section{Experiments}
\label{s:expts}

Our experimental results consistently demonstrate the benefit of our proposed adaptive learning method in learning better Softmax Trees in terms of accuracy, inference time, and model size for several benchmarks in classification tasks with a large number of classes and in language modeling. After describing our setup, we first show a detailed comparison of the proposed adaptive growth method against the previous fixed tree approach. We then report benchmark results for document classification and language modeling tasks. Finally, we analyze the produced tree structure and attempt to interpret the model by visualizing it. In this section, ``AST'' refers to our proposed adaptive learning method, and ``ST'' refers to the previous fixed tree approach.

\subsection{Setup}

Unless otherwise stated, we use the following fixed values for these hyperparameters: the initial tree depth $\Delta_0 = 2$ and the depth of expanding subtrees $\smash{\hat{\Delta}=1}$. For all other hyperparameters (the sparsity of decision nodes and leaves $\lambda = \mu$, tolerance ratio for node expansion $\rho$ and softmax contraction coefficient $\alpha$) we set them in accordance with cross-validation on a holdout set. All other implementation details including hyperparameter tuning are provided in appendix~\ref{a:hyperparameters}.

We compare our results with other baselines specifically developed for problems with a large number of classes. These include RecallTree \citep{Daume_17a}, $(\pi, \kappa)$-DS \citep{Joshi_17a} and MACH \citep{Medini_19a}. We use available open-source implementations of the above methods or cite their results, where applicable. As noted in the related work section, training a softmax classifier by optimizing the cross-entropy is very time-consuming, so in some of our comparisons we use one-vs-all classifiers because training a linear softmax classifier was infeasible even on our 256\,GB server. For the linear one-vs-all classifier we use \texttt{scikit-learn}'s implementation \citep{Pedreg_11a}. In contrast, our ASTs can be trained with fewer resources and are much faster at inference. Compared to the original ST, ASTs can grow much deeper with very narrow softmax layers, even using just one class ($k = 1$) in some leaves. We report misclassification errors on the training and test sets, the average inference time per sample on the test set, and tree parameters (tree depth $\Delta$, average leaf softmax size $\bar{k}$, and the number of leaves). We time the inference of each sample on a single CPU and average it over the whole test set. All experiments are conducted on a machine with an Intel Xeon E5-2699 v3 CPU @ 2.30\,GHz and 256\,GB RAM.

\subsection{The benefit of adaptive growth}
\label{s:ast_st_comp}

\begin{table*}[t]
  \caption{AST vs ST. We report: train/test errors; depth $\Delta$, number of leaves $L$, average leaf softmax size $\bar{k}$ of the tree; and average inference time and FLOPs per test instance. For ST we specify its leaf softmax size $k$, for AST the softmax contraction coefficient $\alpha$ and tolerance ratio of expansion $\rho$. ASTs are trained with $\mu=0.01$ or (if marked with $*$) $\mu=0.1$.}
  \centering
  \renewcommand{\arraystretch}{0.8}
  \begin{tabular}{@{}l|@{\hspace{2ex}}lrrrrrrr@{}}
    \toprule
    & Method & \multicolumn{1}{c}{$E_{\text{train}} \%$} & \multicolumn{1}{c}{$E_{\text{test}} \%$} & \multicolumn{1}{c}{$\Delta$} & \multicolumn{1}{c}{$L$} & \multicolumn{1}{c}{$\bar{k}$} & inf.(\textmu s) & FLOPs\\
    %    & & & & & & & & & per leaf &\\
    \midrule
     & Softmax & 22.30 & 23.20 & --  & -- & -- & 53 & 416\\
     & ST($k=7$) & 0.52 & 8.33 & 7  & 128 & 5.27 & 142 & 197\\
     & ST($k=5$) & 0.36 & 8.75 & 8  & 256 & 3.53 & 98 & 214\\
     & ST(from AST) & 2.94 & 8.84 & 11 & 373 & 1.77 & 86  & 177\\
     \raisebox{11pt}[0pt][0pt]{\rotatebox{90}{\makebox[0pt][c]{Letter}}}
     & \textbf{AST}($\alpha$=0.85,$\rho$=1.2) & 0.30 & 7.03 & 12  & 153  & 2.13 & 43 & 162\\
     & \textbf{AST}($\alpha$=0.75,$\rho$=1.2) & 2.05 & \textbf{6.35} & 15  & 384 & 1.01 & \textbf{9}  & \textbf{151}\\
     \midrule
    & Softmax & 10.90 & 13.0 & -- & -- & -- & 411 & 128000\\
    & ST$^*$($k=90$) & 2.01 & 12.3 & 7  & 126 & 64.9  & 24 & 1493\\
    \raisebox{3pt}[0pt][0pt]{\rotatebox{90}{\makebox[0pt][c]{\footnotesize ALOI}}}
    & ST($k=75$) & 3.89 & 12.0 & 6 & 64 & 74.9 & 29 & 1871\\
    & ST(from AST) & 2.37 & 12.8 & 8 & 177 & 38.4 & 18 & 1102\\
    & \textbf{AST}$^*$($\alpha$=0.75,$\rho$=1.01) & 1.49 & \textbf{9.9} & 10  & 326 & 23.8 & \textbf{15} & \textbf{1016}\\
    \midrule
    & Softmax & 54.30 & 61.4 & --  & -- & -- & 10680 & 423722\\
    & ST($k=70$) & 14.20 & 62.7 & 7 & 128 & 70.0 & 65 & 12279\\
    \raisebox{3pt}[0pt][0pt]{\rotatebox{90}{\makebox[0pt][c]{\footnotesize LSHTC1}}}
    & ST($k=50$) & 6.15 & 61.2 & 8 & 256 & 49.4 & 55 & 9218\\
    & ST(from AST) & 9.36 & 68.7 & 9 & 511 & 49.7 & 62 & 9388\\
    & \textbf{AST}$^*$($\alpha$=0.9,$\rho$=1.2) & 16.10 & \textbf{60.8} & 10 & 1006 & 11.5 & \textbf{40} & \textbf{3756}\\
    \midrule
    & Softmax & 42.4 & 50.2 & -- & -- & -- & 16500 & 9214\\
    & ST$^*$($k=4$) & 48.7 & 51.5 & 8 & 30 & 4.6  & 36 & 691\\
    & \textbf{AST}$^*$($\alpha$=0.35,$\rho$=1.2) & 46.3 & \textbf{49.5} & 11 & 73 & 4.1 & \textbf{16} & \textbf{586}\\
    & ST$^*$($k=9$) & 44.1 & 48.3 & 8  & 50 & 8.0 & 27 & 918\\
        & \textbf{AST}($\alpha=0.38, \mu=0.1$) & 43.7 & \textbf{46.9} & 11 & \textbf{13}  & 44 & 8.4 & \textbf{791}\\
        & ST($k=13, \mu=0.1$) & 44.1 & 48.3 & 8  & 13  & 40  & 12.1  & 1104\\
    & \textbf{AST}$^*$($\alpha$=0.39,$\rho$=1.2) & 43.6 & \textbf{47.5} & 11 & 34 & 11.7 & \textbf{12} & \textbf{929}\\
    \raisebox{11pt}[0pt][0pt]{\rotatebox{90}{\makebox[0pt][c]{ WIKI-Small subs.}}}
        & ST($k=67, \mu=0.01$) & 29.6 & 48.4 & 8 & 21 & 256 & 8.11 & 2291\\
    & ST($k=95$) & 19.7 & 44.1 & 8 & 256 & 5.7 & 30 & 3065\\
    & ST(from AST) & 21.1 & 44.0 & 8 & 65 & 12.5 & 19 & 3296\\
    & \textbf{AST}($\alpha$=0.69,$\rho$=1.2) & 37.8 &  \textbf{42.7} & 13 & 184 & 2.8 & \textbf{13} & \textbf{1437}\\
    \bottomrule
  \end{tabular}
  \label{t:comparison1}
\end{table*}

We first perform a detailed comparison between the models produced by our adaptive growth method and the previous fixed tree approach. We use the following datasets: Letter, ALOI, LSHTC1 and WIKI-Small subs.\ (a subsampled version of WIKI-Small). Details are in appendix~\ref{a:datasets}.

For these controlled experiments, we keep the node and leaf sparsity parameters $\lambda, \mu$ equal for both ASTs and STs. As stated in previous sections, the AST approach expands leaves unevenly, which produces softmaxes with different numbers of classes $k$. To ensure that the comparison between the resulting models is fair and comprehensive, we train STs with the largest $k$ from an AST and cross-validated depth $\Delta$. For WIKI-Small subs.\ we provide a pairwise comparison of multiple STs and ASTs of similar $k$ in table~\ref{t:comparison1}. For example, the softmax size of ST$^*$($k=13$) and the maximum softmax size of AST$^*$($\alpha=0.39, \rho=1.2$) are equal. In addition, we use the structure of the final tree from the AST to initialize an ST (referred to as ``ST(from AST)''): we keep the sizes $k$ of the leaf softmaxes but randomly reinitialize the weights of the linear classifiers in the decision nodes and leaves.

Table~\ref{t:comparison1} shows that ASTs considerably outperform STs in test error (up to 5\% on WIKI-Small subs.). In many cases, the performance of an ST improves as we lower the depth, but lowering it too much leads to an increase in test error. Note that the depth of an ST initialized from the corresponding AST differs from that of the AST because of post-pruning. Importantly, ASTs have much faster inference (up to 15 times) and lower FLOPs. Fig.~\ref{f:err-comp} contains an additional experiment showing the improved accuracy of ASTs over STs as a function of training iterations. These experiments confirm that the progressive growth of a tree results in a better local optimum, which justifies our proposed approach.

\begin{figure}[t]
  \psfrag{Iterations}[c][c]{Iterations} 
  \psfrag{PST Train}{PST}
  \psfrag{0/1 Loss}[c][c]{0/1 loss}
  \centering
  \begin{tabular}{@{}c@{}}
    \includegraphics*[width=\linewidth]{LETTER_final/err_plot.eps}
    %    \includegraphics*[width=0.3\linewidth]{ALOI_small_final/err_plot.eps}&
    %    \includegraphics*[width=0.3\linewidth]{Wiki_small_final/err_plot.eps}&
  \end{tabular}
  \caption{0/1 loss of the final AST model for training (dashed line) and test (solid line), compared with the complete Softmax Tree. The arrows point to where expansions of the AST happened. The line colors indicate the performance of the ST (blue), ST(from AST) (green) and AST (red). This shows that the adaptive growth gradually enhances the performance of the model on both training and test sets (red dashed and solid lines). On the other hand, an ST initialized randomly (blue line) or on the final structure of the AST (green line) is unable to improve after a certain number of iterations.}
  \label{f:err-comp}
\end{figure}

\subsection{Text classification}
\label{s:text_clsfc}

We compare our method with other baselines (including ST) on the document categorization benchmark WIKI--Small, which has more than 36k classes. The full dataset contains roughly 380k features and 800k training samples. Setting the initial depth of the AST to small values (2--3) while keeping $\alpha$ relatively high (0.6--0.9) generates extremely large softmaxes in the initial tree, which makes training slow. There are two ways to mitigate this problem: 1) initializing with a bigger initial tree and 2) initializing with a smaller $\alpha$ (0.01--0.02) while keeping $\alpha$ in the expanding subtrees high (0.7). As a result, as the AST expands it covers more and more classes.
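
To see the scale of the problem, recall the constraint $k_0 2^{\Delta_0} \ge K$ on a complete initial tree. With $K = 36\,504$ classes (table~\ref{t:datasets}):
\begin{equation*}
  \Delta_0 = 2:\;\; k_0 \ge 36\,504 / 4 = 9\,126, \qquad \Delta_0 = 3:\;\; k_0 \ge 36\,504 / 8 = 4\,563 .
\end{equation*}
Each such leaf softmax has one weight vector per class over roughly 380k features. For the second option, suppose (as an assumption made only for this illustration) that the initial coefficient sets $k_0 = \alpha K$. Then $\alpha = 0.01$ gives $k_0 \approx 365$, so an initial tree with $\Delta_0 = 2$ can predict at most $4 \times 365 = 1\,460$ classes. Each subsequent expansion with $\smash{\hat{\Delta}=1}$ and $\alpha = 0.7$ replaces one softmax of size $k$ by two of size $0.7\,k$, i.e., it can cover up to $1.4\,k$ classes in that region.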

Table~\ref{t:comparison_bench_WIKI} shows that AST performs better on the test set than most of the baselines. Moreover, ASTs show 6 times faster inference than STs (ASTs contain on average 44 classes in the leaves). Note that increasing the number of TAO iterations during leaf expansion or global optimization (or both) may lead to much better results at the cost of a longer training time.

\subsection{Language modeling}

Penn Treebank (PTB) is a popular dataset often used for language modeling. We compare the performance of AST models on this task against Hierarchical Softmax (HSM), STs and linear one-vs-all classifiers. The details about dataset preprocessing, implementation of the baselines and hyperparameter tuning can be found in appendix~\ref{a:ptb}.

The perplexity score
\begin{equation*}
  \text{PPL} = \exp\left(-\frac{1}{N} \sum_{n=1}^{N}{\log{ \Pr(y_n \mid \x_n) }}\right)
\end{equation*}
is undefined for models that can output exactly zero probability. This can happen with STs when an instance $\x$ reaches a leaf whose softmax does not include the true class $y$, so that $\Pr(y \mid \x)=0$. Therefore, in estimating the PPL we only include the instances for which the model outputs nonzero probability. Although a linear classifier provides a positive probability for all the classes, it could not predict correctly 58\% of all $K\approx6k$ classes on both training and test sets, i.e., the output score $\Pr(y \mid \x)$, though positive, was not the maximum, and for many instances was not even in the top 10. For our AST models it is possible to control the percentage of points for which the model outputs positive probability by tuning the hyperparameter $\alpha$, which appendix~\ref{a:ptb-leaf-alphas} explores in detail.
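
To make the restricted estimate explicit, the following sketch uses notation introduced only here (the index set $\mathcal{S}$ and the symbol $\text{PPL}_{\mathcal{S}}$ do not appear elsewhere), and assumes that the average is taken over the retained instances. Let $\mathcal{S} = \{ n \colon \Pr(y_n \mid \x_n) > 0 \}$ be the set of instances for which the model outputs nonzero probability for the true class. Then
\begin{equation*}
  \text{PPL}_{\mathcal{S}} = \exp\left(-\frac{1}{|\mathcal{S}|} \sum_{n \in \mathcal{S}} \log \Pr(y_n \mid \x_n)\right), \qquad \text{coverage} = \frac{|\mathcal{S}|}{N} .
\end{equation*}
Since $\mathcal{S}$ depends on the model, the values of $\text{PPL}_{\mathcal{S}}$ for models with different coverage are computed over different instances, so each value should be read together with its coverage.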

Table~\ref{t:comparison_bench_PTB} shows the results on PTB. Our method outperforms the other baselines in top-1 test error and, by a considerable margin, in inference time. The performance of AST can be further improved by running more optimization iterations.

\begin{table}[t!]
  \caption{Results on the text classification dataset WIKI-Small. We report the test error, depth $\Delta$ of the tree, the average inference time per test sample in milliseconds, and the training time. For STs we specify the leaf softmax size $k$, and for ASTs we specify the softmax contraction coefficient ($\alpha$) and the tolerance ratio of node expansion ($\rho$).}
  \label{t:comparison_bench_WIKI}
  \centering
  \begin{tabular}{@{}l@{\hspace{2ex}}l@{\hspace{2ex}}r@{\hspace{2ex}}r@{\hspace{2ex}}r@{\hspace{0ex}}}
    \toprule
    Method & \footnotesize{$E_{\text{test}}(\%)$} & $\Delta$ & inf.(ms) & Train time\\
    \midrule
    RecallTree & 92.64& 15 & 0.97 & 53m\\
    one-vs-all & 85.71& 0 & 10.70& $>$ 7d \\
    MACH & 84.80& -- & 252.64&1445m \\
    ST($k=200$) & 84.70& 8 & 0.18& $\approx$1000m\\
    $(\pi, \kappa)$-DS & 78.50& -- & 10.33& -- \\
    ST($k=150$) & 77.26& 8 & 0.57 &$\approx$1000m\\
    \textbf{AST}($\alpha$=0.6, $\rho$=1.0) & 77.30 & 12 & \textbf{0.03} & $\approx$2000m \\
    \textbf{AST}($\alpha$=0.6, $\rho$=1.1) & \textbf{76.21} & 12 & 0.04 & $\approx$2000m \\
    \bottomrule
  \end{tabular}
\end{table}

\subsection{\mbox{Tree structure and interpretability}}

Fig.~\ref{f:final-tree} shows how the number of classes present in the leaves changes with depth. Theoretically, the number of classes in the leaves should decrease monotonically with depth. In practice we observe deviations from this, for two reasons: 1) the number of classes in the reduced set at a given depth can be lower than the theoretical upper limit; 2) post-pruning brings leaves closer to the root.

The built-in tree structure of our model makes it possible to interpret it by visualizing the tree structure and the tree parameters. To show this, we train an AST on a small subset of the Amazon Reviews dataset \citep{HeMcAuley16a}, which contains text reviews for products on the Amazon website. From each of four high-level product categories (Sports, Toys, Home, Tools) we select the 50 subcategories with the highest number of reviews. We select up to 300 reviews from each subcategory and extract tf-idf transformed bag-of-words features. This results in a dataset of about 60k instances with features of dimensionality 11k and 200 classes. We keep 20\% of the dataset as a test set and train a relatively small AST on this problem to be able to visualize it in a figure. The initial tree has depth $\Delta_0 = 2$ and $\alpha = 0.25$, and we limit the number of expansion steps to 2. The resulting tree has an accuracy of 53\% and depth $\Delta = 6$, and is visualized in fig.~\ref{f:amazon-reviews-tree}. At first glance a hierarchical structure is apparent, where we can observe some subtrees specializing in similar groups of classes; for example, decision node 9 specializes mostly in Toy classes. By looking at the decision node weights along any given root-leaf path, one can get a local interpretation of why the tree sends a point to that particular leaf. A small and sparse softmax model at the leaf can also be interpreted. Another key observation is that, for the most part, similar classes tend to be grouped within the same leaf, which is quite remarkable given that the tree is initialized randomly and is unaware of any class information.

\begin{table}[t!]
  \caption{Results on the language modeling dataset PTB. We report test error, depth $\Delta$ of the tree, the average inference time per sample in microseconds and the average perplexity (PPL) over the test set instances for which the model outputs nonzero probability. The percentage of such instances is shown in parentheses. For AST models $\rho=1.0$.}
  \label{t:comparison_bench_PTB}
  \centering
  \begin{tabular}{@{}l@{\hspace{2ex}}l@{\hspace{2ex}}r@{\hspace{2ex}}r@{\hspace{2ex}}r@{\hspace{1ex}}r@{}}
    \toprule
    Method & \hspace{-3ex} \footnotesize{$E_{\text{test}}(\%)$} & $\Delta$ & inf.(\textmu s) & PPL & (\%nnz)\\
    \midrule
    HSM & 91.1 & 18 & 421 & 575 & (100\%)\\
    one-vs-all & 87.5& 0 & 705& 220 & (100\%)\\
    ST($k=50$) & 86.5& 8 & 58 & 17 & (44\%)\\
    ST($k=100$) & 86.5& 7 & 58 & 27 & (51\%)\\
    ST($k=400$) & 86.4& 5 & 64 & 71 & (67\%)\\
    \textbf{AST}($\alpha=0.3$) & 86.4 & 12 & \textbf{17} & 10 & (37\%)\\
    \textbf{AST}($\alpha=0.4$) & \textbf{86.1} & 12 & 18 & 13 & (44\%)\\
    \textbf{AST}($\alpha=0.5$) & 86.2 & 11 & 19 & 24 & (51\%)\\
    \textbf{AST}($\alpha=0.75$) & 86.3 & 12 & 20 & 7 & (33\%)\\  
    \bottomrule
  \end{tabular}
\end{table}

\begin{figure*}[p]
  \centering
  \begin{tabular}{@{}c@{}}
    \includegraphics*[width=0.98\linewidth,clip,bb=-3080 -2070 3300 2880]{processed_graph_Wiki.eps}
  \end{tabular}
  \caption{AST for the Wiki-Small subs.\ dataset. Size of the blue nodes (on the tree) shows the actual number of classes in the leaves after pruning. Green (left column) shows theoretical max.\ values at each aligned depth.}
  \label{f:final-tree}
\end{figure*}

\begin{figure*}[p]
  \centering
  \begin{tabular}{c}
      \includegraphics[width=0.96\linewidth]{pst_interpretable.eps}
    \end{tabular}
  \caption{Visualization of an adaptive softmax tree for a subset of the Amazon Reviews dataset. You may want to zoom in.}
  \label{f:amazon-reviews-tree}
\end{figure*}

\section{Conclusion}

Softmax Trees are effective for many-class problems by capitalizing on the conditional computation of decision trees and the ability to define local softmax classifiers that handle small subsets of classes, both of which make inference very fast. However, the existing algorithm operates on a fixed, complete tree, which computationally limits the depth of any individual leaf and forces the local softmaxes to be wider than necessary. Our Adaptive Softmax Tree solves this by learning the tree structure, so it can have deeper leaves with thinner softmaxes. It achieves this by interleaving local expansion steps, which turn a wide softmax into a softmax subtree with thin softmaxes, with a global TAO optimization of the entire tree. Our experiments show that this results in improved accuracy, inference time and model size, which makes its longer training time well worthwhile.

\paragraph{Limitations}

Although our algorithm guarantees a monotonic decrease of the objective function (at both regular and expansion steps), we lack any other theoretical guarantees of optimality (which are difficult to obtain for alternating optimization methods on nonconvex nondifferentiable problems). Also, while trees making hard decisions result in very fast inference, training them end-to-end with neural networks is not straightforward. One approximate approach is to train a neural network with either a regular flat softmax or a hierarchical softmax and then replace it with an AST in a teacher-student approach to obtain an overall model with faster inference, as we do in our language modeling experiments.

\clearpage

\begin{acknowledgements}

  Work partially supported by NSF award IIS--2007147.

\end{acknowledgements}

\bibliography{macp,macp-xref}

\newpage

\onecolumn

\appendix

\section{Appendix}

\subsection{Datasets}
\label{a:datasets}

\begin{table*}[t]
  \centering
  \begin{tabular}{@{}lrrrr@{}}
    \toprule
    Dataset & $N_{\text{train}}$ & $N_{\text{test}}$ & $D$ & $K$ \\
    \midrule
    Letter & 16\,000  & 4\,000  & 16  & 26 \\
    ALOI & 97\,200  &10\,800  & 128  & 1000 \\
    LSHTC1 & 80\,552  &19\,873  & 271\,022  & 2657 \\
    WIKI--Small (subs.) & 20\,000  &10\,000  & 54\,188  & 200 \\
    WIKI--Small & 796\,617 &199\,155  &  380\,078  & 36\,504\\
    PTB & 400\,097 &34\,633  &  150  & 5\,970\\
    \bottomrule
  \end{tabular}
  \caption{Datasets used in the experiments: number of train and test instances ($N_{\text{train}}$, $N_{\text{test}}$), number of features~$D$, number of classes~$K$.}
  \label{t:datasets}
\end{table*}

To create the subsampled WIKI-Small dataset, we randomly select an equal number of samples from each class to avoid imbalance. This serves two purposes: 1) a smaller dataset allows training for a much higher number of iterations (to eliminate undertraining); 2) it reduces the time of a single experiment, which facilitates a more precise hyperparameter search. Further, we remove features that remain constant for all the training and test points. As a result, input features of the subsampled WIKI-Small have $D = 37k$ dimension represented as normalized bag-of-words. For LSHTC1, we eliminate all classes that contain fewer than 10 samples. We used tf-idf feature representations of $D=271k$ dimension and $K=2657$ classes. Table~\ref{t:datasets} summarizes the statistics of the datasets used.
ALOI can be found at \url{https://aloi.science.uva.nl/}. WIKI-Small and LSHTC1 are both part of the Large Scale Hierarchical Text Classification challenge (LSHTC) \citep{Partal_15a}. We used the preprocessed version of the PTB dataset from \citet{Mikolov_10a}. The Letter dataset can be found in the UCI ML dataset repository \citep{Lichman13a}.

\subsection{Implementation details and hyperparameters}
\label{a:hyperparameters}

Both ST and AST were implemented in Python 3.8.10 and parallelized using Ray 2.2.0 \citep{Moritz_18a}. The $\ell_1$-regularized logistic regression used the implementation available in \texttt{scikit-learn} \citep{Pedreg_11a}: in the decision nodes we used LIBLINEAR \citep{Fan_08a}, and in the leaves we used SAGA \citep{Defazio_14a}.

For ST, a hyperparameter search was performed on a separate holdout set. We found that $\lambda=0.01$ leads to the best performance for most datasets, and $\lambda=1$ for WIKI-Small. For smaller datasets we set the number of TAO iterations high (up to 100), report an average of 5 runs, and set the number of LIBLINEAR and SAGA iterations to 100. For larger experiments, the number of TAO iterations is set to 40 with an average of 3 runs, and the number of LIBLINEAR and SAGA iterations is set to 100 and 50, respectively. Trees were initialized using random initialization as well as the $k$-means initialization described in \citet{Zharmag_21d}.

For AST, both leaf and node sparsity parameters were cross-validated separately over a range between 0.01 and 100. We found that $\lambda=\mu=0.01$ performs best for Letter and subsampled WIKI-Small, and $\lambda=\mu=0.1$ for ALOI and LSHTC1. For large datasets, $\lambda=\mu=1$ produces the best results. We initialized the initial tree, as well as the stumps during expansion, using a median split. This way the nodes receive almost the same number of samples, so training in parallel becomes faster, and it generally produces better accuracy. The number of LIBLINEAR and SAGA iterations is the same as for ST. One way of speeding up the expansion process is to use the weight matrix of the leaf being expanded (which becomes a decision node) to warm-start the optimization in the new leaves. This way SAGA converges much faster for the same tolerance. The number of TAO iterations is set to 10 during expansion and to 15 during global optimization, but in many cases TAO converges in fewer iterations.

\subsection{Language modeling experiments}
\label{a:ptb}

The Penn Treebank contains around 1M tokens and has a vocabulary size of 10k words. Similar to \citet{Zharmag_21d}, we filter out rare words and obtain word embeddings using pre-trained GloVe \citep{Pennin_14a}. We predict the next word based on the previous 3 words. To form the preprocessed dataset, we simply concatenate their word vector representations. As a result, the preprocessed PTB consists of roughly 400k training samples, 150 features, and 5970 classes. For baselines, we used a one-vs-all classifier from \texttt{scikit-learn} with $\ell_1$ regularization $\lambda=1$ and Hierarchical Softmax from \citet{Mikolov_13b} implemented in PyTorch. We further compare ASTs and STs of different leaf softmax sizes ($k$) to show that AST not only wins in terms of top-1 test error but is also up to 4 times faster at inference.

\subsection{Controlling leaf softmax sizes for language modeling}
\label{a:ptb-leaf-alphas}

\begin{figure*}[t]
  \centering
  \psfrag{0-1 Val Loss}[c][c]{$E_{\text{val}}(\%)$}
  \psfrag{Inference time (micro s)}[c][b]{Inference time (\textmu s)}
  \psfrag{Covered (\%)}[c][b]{Covered (\%)}
  \psfrag{ai=2.0, a=1.0}{\small{$\alpha$=(0.5,0.25)}}
  \psfrag{ai=1.2, a=1.0}{\small{$\alpha$=(0.3,0.25)}}
  \psfrag{ai=1.6, a=1.0, best}{\small{$\alpha$=(0.4,0.25)}}
  \psfrag{ai=3.0, a=0.8}{\small{$\alpha$=(0.75,0.2)}}
  \psfrag{ai=3.0, a=1.6}{\small{$\alpha$=(0.75,0.4)}}
  \psfrag{ai=3.0, a=1.0}{\small{$\alpha$=(0.75,0.25)}}
  %  \psfrag{alpha = 3.48}[c][c]{\scriptsize{$\alpha$ = 3.48}}
  \psfrag{Expansion step}[c][c]{Iterations}
  \psfrag{Expansion Step}[c][c]{Iterations} 
  \begin{tabular}{@{}c@{}c@{}c@{}}     
    \includegraphics*[width=0.33\linewidth]{PTB_final/test_error.eps}&
    \includegraphics*[width=0.33\linewidth]{PTB_final/run_time_plot.eps}&
    \includegraphics*[width=0.33\linewidth]{PTB_final/covered_percentage.eps}
  \end{tabular}
  \caption{Top-1 error, average inference time and percentage of covered classes for AST of different $\alpha$=($\alpha_0, \alpha$) on PTB dataset.}
  \label{f:ptb-param}
\end{figure*}

Fig.~\ref{f:ptb-param} shows how the error, the inference time and the coverage of different AST models evolve during training. We found experimentally that the best validation performance is achieved when $\lambda$, $\mu$, and $\rho$ are set to 1. Fig.~\ref{f:ptb-param} shows that for high values of $\alpha_0$ and $\alpha$ ($\alpha_0 = 0.75$, $\alpha = 0.4$) the tree grows extremely deep (high number of expansion steps) while maintaining relatively large softmaxes in the leaves. Moreover, fig.~\ref{f:ptb-param} highlights that, as the softmax size decreases with tree depth, so does the inference time; however, at some point it starts to go up again. This happens once the time it takes to propagate a sample to the leaf exceeds the time of the matrix multiplication in the softmax, so there is an optimal depth of the tree for which inference is fastest. On the other hand, for very small $\alpha$ ($\alpha_0 = 0.75$, $\alpha = 0.2$) the softmax size decreases much faster with tree depth, resulting in a small tree with a very small number of classes in the leaves. Experimentally we found that such trees do not generalize very well and typically have low class coverage. We can specify the number of expansion steps (maximum depth) of the tree to control the minimum coverage and inference time.
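
A simple model makes this trade-off explicit. It is only a sketch under assumptions that do not hold exactly in practice: every leaf is expanded at every step with $\smash{\hat{\Delta} = 1}$, the leaf softmax size after $s$ expansion steps is $k_0 \alpha^s$ (with $k_0$ the leaf softmax size of the initial tree of depth $\Delta_0$), $s$ is treated as a continuous variable, and weight sparsity is ignored. The function $g$ and the optimum $s^*$ are symbols introduced here for illustration. The dense inference cost is then proportional to
\begin{equation*}
  g(s) = \Delta_0 + s + k_0 \alpha^s, \qquad g'(s) = 1 - k_0 \alpha^s \ln(1/\alpha) .
\end{equation*}
Since $g$ is convex, setting $g'(s^*) = 0$ gives its minimizer:
\begin{equation*}
  k_0 \alpha^{s^*} = \frac{1}{\ln(1/\alpha)}, \qquad s^* = \frac{\ln\bigl(k_0 \ln(1/\alpha)\bigr)}{\ln(1/\alpha)},
\end{equation*}
which is positive when $k_0 \ln(1/\alpha) > 1$. For $s < s^*$ the shrinking softmax dominates and inference gets faster with depth; for $s > s^*$ the additional decision nodes dominate and inference gets slower.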

\begin{figure*}[t]
  \centering
  \psfrag{alpha = 1.65}{\small{$\alpha$ = 0.4}}
  \psfrag{alpha = 2.2}{\small{$\alpha$ = 0.55}}
  \psfrag{alpha = 2.76}{\small{$\alpha$ = 0.69}}
  \psfrag{alpha = 3.48}{\small{$\alpha$ = 0.87}}
  \psfrag{Lambda}[c][c]{$\mu$} 
  \begin{tabular}{@{}c@{}c@{}c@{}}
    $\rho=1.0$&$\rho=1.2$&$\rho=1.7$\\
    \psfrag{Tree depth}[c][c]{Mean depth ($\Delta_{av}$)} 
    \includegraphics*[width=0.33\linewidth]{Wiki_small_final/delta_C_betta=10.eps}&
    \psfrag{Tree depth}[c][c]{} 
    \includegraphics*[width=0.33\linewidth]{Wiki_small_final/delta_C_betta=12.eps}&
    \psfrag{Tree depth}[c][c]{} 
    \includegraphics*[width=0.33\linewidth]{Wiki_small_final/delta_C_betta=17.eps}
  \end{tabular}
  \caption{Comparison of average tree depth $\Delta_{av}$ vs softmax regularization parameter $\mu$ for different values of $\rho$ and $\alpha$, with $\lambda = 0.01$.}
  \label{f:delta-comp}
\end{figure*}

We examine the effect of softmax contraction coefficient ($\alpha$), tolerance ratio for node expansion ($\rho$), and leaf sparsity parameter ($\mu$) on the final tree structure. We conducted this set of experiments 5 times on a subsampled WIKI--Small dataset to eliminate the effect of noise and any inconsistencies. 

We measure the average tree depth $\Delta_{av}$ as the average of the depths of the leaves in the final AST. Fig.~\ref{f:delta-comp} shows that $\Delta_{av}$ tends to increase as we increase $\mu$. A sparser softmax in the leaves means that an expanded subtree is more likely to perform better on the reduced set, which in turn leads to more leaves being expanded at the current depth. The maximum depth, on the other hand, does not grow significantly. The fluctuations of the average depth as we increase the leaf sparsity can be explained by the particular local optimum found for each value of $\mu$. In general, we found that the curves become smoother as the number of TAO and SAGA (the solver for the softmax classifier) iterations increases.
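
Written as a formula, and assuming an unweighted average over leaves (the text above does not specify a weighting, so this is our reading for illustration), with $\Delta_j$ the depth of leaf $j$:
\begin{equation*}
  \Delta_{av} = \frac{1}{|\calN_\text{leaf}|} \sum_{j \in \calN_\text{leaf}} \Delta_j .
\end{equation*}
For example, a tree with three leaves at depths 1, 2 and 2 has $\Delta_{av} = 5/3 \approx 1.67$, whereas its maximum depth is 2.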

\end{document}

