% \documentclass{uai2022} % for initial submission
\documentclass[accepted]{uai2022} % after acceptance, for a revised
                                    % version; also before submission to
                                    % see how the non-anonymous paper
                                    % would look like
%% There is a class option to choose the math font
% \documentclass[mathfont=ptmx]{uai2022} % ptmx math instead of Computer
                                         % Modern (has noticable issues)
% \documentclass[mathfont=newtx]{uai2022} % newtx fonts (improves upon
                                          % ptmx; less tested, no support)
% NOTE: Only keep *one* line above as appropriate, as it will be replaced
%       automatically for papers to be published. Do not make any other
%       change above this note for an accepted version.

%% Choose your variant of English; be consistent
\usepackage[american]{babel}
% \usepackage[british]{babel}

%% Some suggested packages, as needed:
\usepackage{natbib} % has a nice set of citation styles and commands
    \bibliographystyle{abbrvnat}
    \renewcommand{\bibsection}{\subsubsection*{References}}
\usepackage{amsmath}
\usepackage{amsfonts}
\usepackage{makecell}
\usepackage{listings, multicol}
\usepackage{xcolor}
\usepackage{subcaption}

\definecolor{codegreen}{rgb}{0,0.6,0}
\definecolor{codegray}{rgb}{0.5,0.5,0.5}
\definecolor{codepurple}{rgb}{0.58,0,0.82}
\definecolor{backcolour}{rgb}{0.95,0.95,0.92}

\lstdefinestyle{mystyle}{
    backgroundcolor=\color{backcolour},   
    commentstyle=\color{codegreen},
    keywordstyle=\color{magenta},
    numberstyle=\tiny\color{codegray},
    stringstyle=\color{codepurple},
    basicstyle=\ttfamily\scriptsize,
    breakatwhitespace=false,         
    breaklines=true,                 
    captionpos=b,                    
    keepspaces=true,                 
    numbers=left,                    
    numbersep=5pt,                  
    showspaces=false,                
    showstringspaces=false,
    showtabs=false,                  
    tabsize=2
}

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%                               PROOF, THEOREM, and FRIENDS
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
\newcommand{\BlackBox}{\rule{1.5ex}{1.5ex}}  % end of proof
\newenvironment{proof}{\par\noindent{\bf Proof\ }}{\hfill\BlackBox\\[2mm]}
\newtheorem{example}{Example} 
\newtheorem{theorem}{Theorem}
\newtheorem{lemma}[theorem]{Lemma} 
\newtheorem{proposition}[theorem]{Proposition} 
\newtheorem{remark}[theorem]{Remark}
\newtheorem{corollary}[theorem]{Corollary}
\newtheorem{definition}[theorem]{Definition}
\newtheorem{conjecture}[theorem]{Conjecture}
\newtheorem{axiom}[theorem]{Axiom}
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

\lstset{style=mystyle}

% \usepackage{mathtools} % amsmath with fixes and additions
% \usepackage{siunitx} % for proper typesetting of numbers and units
\usepackage{booktabs} % commands to create good-looking tables
\usepackage{tikz} % nice language for creating drawings and diagrams

%% Provided macros
% \smaller: Because the class footnote size is essentially LaTeX's \small,
%           redefining \footnotesize, we provide the original \footnotesize
%           using this macro.
%           (Use only sparingly, e.g., in drawings, as it is quite small.)

%% Self-defined macros
\newcommand{\swap}[3][-]{#3#1#2} % just an example

\title{Equilibrium Aggregation: Encoding Sets via Optimization}

% Encoding sets via optimization
% Optimization-based aggregation for sets
% An optimization-based aggregation operator for sets
% Optimization-based aggregation for sequences and sets
%


% The standard author block has changed for UAI 2022 to provide
% more space for long author lists and allow for complex affiliations
%
% All author information is automatically removed by the class for the
% anonymous submission version of your paper, so you can already add your
% information below.
%
% Add authors
\author[1,2,*]{\href{mailto:<sbos.net@gmail.com>?Subject=Equilibrium Aggregation}{Sergey Bartunov}{}}
\author[1,*]{Fabian B. Fuchs}
\author[1]{Timothy P. Lillicrap}
% Add affiliations after the authors
\affil[1]{%
    DeepMind\\
    London, United Kingdom
}
% \affil[3]{%
%     Work done at DeepMind
% }
\affil[2]{%
    Now at CHARM Therapeutics\\
    London, United Kingdom\\
    \vspace{3mm}
}
\affil[*]{%
    Joint first authorship
}
  
\begin{document}
\maketitle

\begin{abstract}
   Processing sets or other unordered, potentially variable-sized inputs in neural networks is usually handled by \emph{aggregating} a number of input tensors into a single representation.
  Many aggregation methods exist, from simple sum pooling to multi-head attention, but they are limited in their representational power from both theoretical and empirical perspectives.
  In search of a principled and more powerful aggregation strategy, we propose an optimization-based method called Equilibrium Aggregation.
  We show that many existing aggregation methods can be recovered as special cases of Equilibrium Aggregation and that it is provably more efficient in some important cases.
  Equilibrium Aggregation can be used as a drop-in replacement in many existing architectures and applications. We validate its efficacy on three different tasks: median estimation, class counting, and molecular property prediction. 
  In all experiments, Equilibrium Aggregation achieves higher performance than the other aggregation techniques we test.
\end{abstract}

\section{Introduction}\label{sec:introduction}

Early neural network research focused on processing fixed-dimensional vector inputs. Since then, advanced architectures have been developed to process fixed-dimensional data efficiently and effectively. 
This format, however, is not natural for applications where inputs do not have a fixed dimensionality, are unordered, or have both of these properties.
A strikingly successful strategy for tackling this issue has been to process such inputs with a series of aggregation $\rightarrow$ transformation operations. 

\begin{figure}
    \centering
    \includegraphics[width=0.9\linewidth]{figures/teaser_annotated.pdf}
    \caption{Global aggregation layers in typical neural networks for sets (top) and graphs (bottom). \textbf{Top}: each input set element $\mathbf{x}_i$ is first processed individually before being pooled into a global representation $\mathbf{y}$. This is followed by a final transformation block. \textbf{Bottom}: for graph data, the first part of the network is replaced by a graph or message passing neural network, but the global aggregation step is similar. In both cases, the global aggregation step drastically reduces the number of embeddings from many to one, rendering the right choice of aggregation technique critical for good model performance. The aggregation layer is typically implemented using sum-, max-, or attention-pooling. We propose a new aggregation mechanism, called Equilibrium Aggregation.}
    \label{fig:set_and_graph_aggregation}
\end{figure}

An \emph{aggregation} operation compresses a set
of input tensors into a single representation of a known, predefined dimensionality that can then be passed to the downstream \emph{transformation} block. 
Since the latter deals with fixed-dimensional inputs with a defined ordering, it can profit from the variety of techniques available for vector-to-vector computations.

This pattern can be seen in many architectures. 
For instance, Deep Sets~\citep{DeepSets} builds a representation of a set of objects by first transforming each object and then summing their embeddings.
Similarly, Graph Neural Networks~\citep{kipf2016semi, battaglia2018relational} use message passing, which amounts to aggregating the set of messages that each node receives from its neighbors and transforming the aggregate into a new message for the next layer (local aggregation). In many cases, several message-passing layers are followed by a global aggregation layer, in which all node embeddings are aggregated into a single embedding vector describing the entire graph.
Finally, Transformers~\citep{vaswani2017attention} use self-attention, a mechanism that allows each object in the input set to interact with every other object and update its embedding by aggregating value embeddings from the rest of the set.

Mathematically, the aggregation $\phi(X) = \mathbf{y}$ compresses the input set $X = \{ \mathbf{x}_1, \mathbf{x}_2, \ldots, \mathbf{x}_N \} \in 2^{\mathcal{X}}$ into a $D$-dimensional vector $\mathbf{y} \in \mathbb{R}^D$.
In the case of Deep Sets~\citep{DeepSets} with sum aggregation, this reads
\begin{equation}\label{eq:sum_aggregation}
    \phi(X) = \rho\Big(\sum_{i=1}^N f(\mathbf{x}_i)\Big),
\end{equation}
where $f$ and $\rho$ are the optional input and output transformations, respectively.
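To make \eqref{eq:sum_aggregation} concrete, the following minimal Python sketch implements sum aggregation with a hypothetical per-element encoder \texttt{f} (mapping a scalar to $\mathbb{R}^2$) and a hypothetical output transformation \texttt{rho}. Neither corresponds to a trained model. Because the pooled vector is a sum, the output does not depend on the order of the elements, and the same code accepts sets of any size $N$.

\begin{lstlisting}[language=Python]
import math


def f(x):
    # Hypothetical input transformation f: R -> R^2.
    return [x, x * x]


def rho(z):
    # Hypothetical output transformation rho: R^2 -> R.
    return math.tanh(z[0]) + 0.1 * z[1]


def deep_sets(X):
    # phi(X) = rho(sum_i f(x_i)); summation makes phi permutation invariant.
    pooled = [0.0, 0.0]
    for x in X:
        pooled = [p + v for p, v in zip(pooled, f(x))]
    return rho(pooled)


print(deep_sets([1.0, 2.0, 3.0]), deep_sets([3.0, 1.0, 2.0]))
\end{lstlisting}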

Besides yielding a fixed-dimensional output embedding, \eqref{eq:sum_aggregation} enforces an important inductive bias: permutation invariance. Global properties of sets or graphs (such as the free energy of a molecule) are independent of the ordering of the set elements. Taking advantage of such \textit{task symmetries} \citep{Mallat_2016} can provide robustness guarantees with respect to important classes of input transformations and is known to improve generalization~\citep{Worrall2017,WeilerGWBC18,Winkels2018}. Other ways of incorporating permutation invariance include max-pooling, mean-pooling, and attention aggregators~\citep{kipf2016semi,battaglia2018relational,vaswani2017attention,Velickovic2018}.\footnote{Interestingly, even though in the case of Transformers for natural language processing the input is an ordered sequence, it appears beneficial to model the data as an order-independent set (or fully connected graph) with the sequential structure added via positional encodings.}

However, it is exactly these aggregation functions which often introduce a bottleneck in the information flow \citep{DeepSets,Wagstaff2019,cai2020note,chen2020measuring, Wagstaff2021}.
Sum aggregation, for example, may struggle to selectively extract relevant information from individual inputs or subsets. Methods like multi-head attention (which effectively computes a weighted mean per head) partially address this issue, but we believe there is a fundamental need for more expressive aggregation mechanisms.

Motivated by this need, we develop a method called Equilibrium Aggregation, which generalizes existing pooling-based aggregation methods and defines the aggregation result implicitly, as the solution of an optimization problem. 
We further investigate its theoretical properties and show that it is not only a universal approximator of set functions but also provably more expressive than sum or max aggregation in some cases.
Finally, we validate our insights empirically in a series of experiments in which Equilibrium Aggregation demonstrates its practical effectiveness. 

\section{Equilibrium aggregation}\label{sec:equilibrium_aggregation}

\begin{figure}
    \centering
    \includegraphics[width=0.94\linewidth]{figures/eqagg.pdf}
    \caption{Schematic illustration of Equilibrium Aggregation. Each input $\mathbf{x} \in X$ contributes a potential value $F(\mathbf{x}, \mathbf{y})$; these values are summed over the set $X$ and, together with the regularizer $R(\mathbf{y})$, form the total energy. Equilibrium Aggregation seeks to minimize this energy, and the minimizer serves as the aggregation result.}
    \label{fig:equilibrium_aggregation}
\end{figure}

Our insight for developing better aggregation functions is grounded in the fact that the standard, pooling-based aggregation methods can be recovered as solutions to a certain optimization problem:
\begin{equation}\label{eq:aggregation_optimization}
    \phi(X) = \arg \min_{\mathbf{y}} \sum_{i=1}^N F(\mathbf{x}_i, \mathbf{y}),
\end{equation}
where $F(\mathbf{x}, \mathbf{y})$ is a \emph{potential} function. 

For example, with $F(\mathbf{x}, \mathbf{y}) = (\mathbf{x} - \mathbf{y})^2$ (and assuming $\mathcal{X} = \mathbb{R}$), setting the derivative $\sum_{i=1}^N 2(\mathbf{y} - \mathbf{x}_i)$ to zero yields the \emph{mean aggregation} $\phi(X) = {1 \over N} \sum_{i=1}^N \mathbf{x}_i$; more examples are given in Table~\ref{tab:aggregation}.
A natural question follows from this observation: can more interesting aggregation strategies be induced by other choices of the potential function $F(\mathbf{x}, \mathbf{y})$?

We propose a method called Equilibrium Aggregation that addresses this question by letting the potential be a learnable neural network $F_{\theta}(\mathbf{x}, \mathbf{y})$ parameterized by $\theta$, which takes a set element $\mathbf{x}$ and the aggregation result $\mathbf{y} \in \mathcal{Y} = \mathbb{R}^M$ as inputs and outputs a non-negative real scalar expressing the degree of ``disagreement'' between the two.
By also adding a regularization term, we obtain the energy-minimization equation for Equilibrium Aggregation:
\begin{align}
    \phi_{\theta}(X) &= \arg \min_{\mathbf{y}} E_{\theta}(X, \mathbf{y}), \nonumber \\
    E_{\theta}(X, \mathbf{y}) &= R_{\theta}(\mathbf{y}) + \sum_{i=1}^N F_{\theta}(\mathbf{x}_i, \mathbf{y}), \label{eq:equilibrium_aggregation}
\end{align}
where, in this paper, the regularizer is simply $R_{\theta}(\mathbf{y}) = \text{softplus}(\lambda) \cdot \| \mathbf{y} \|_2^2$ with a scalar parameter $\lambda$. Figure~\ref{fig:equilibrium_aggregation} illustrates this construction.
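The following sketch evaluates the energy $E_\theta(X, \mathbf{y})$ of \eqref{eq:equilibrium_aggregation} in plain Python for scalar set elements and $M=2$. The potential is a hypothetical two-layer network with fixed, made-up weights standing in for learned parameters $\theta$; a final softplus keeps its output non-negative, as required above. Because the potentials are summed, the energy, and hence its minimizer, does not depend on the order of the elements of $X$.

\begin{lstlisting}[language=Python]
import math


def softplus(z):
    # Numerically stable log(1 + exp(z)).
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


# Hypothetical fixed weights standing in for learned parameters theta.
W1 = [[0.5, -1.0, 0.3], [0.2, 0.4, -0.7], [-0.6, 0.1, 0.9]]  # 3 hidden units, input [x, y_1, y_2]
b1 = [0.0, 0.1, -0.1]
w2 = [1.0, -0.5, 0.8]
b2 = 0.0
LAM = 0.0  # raw regularization parameter lambda; the weight is softplus(LAM)


def potential(x, y):
    # F_theta(x, y) >= 0: tanh hidden layer followed by a softplus output.
    inp = [x] + list(y)
    h = [math.tanh(sum(w * v for w, v in zip(row, inp)) + b) for row, b in zip(W1, b1)]
    return softplus(sum(w * v for w, v in zip(w2, h)) + b2)


def energy(X, y):
    # E_theta(X, y) = softplus(lambda) * ||y||_2^2 + sum_i F_theta(x_i, y)
    reg = softplus(LAM) * sum(v * v for v in y)
    return reg + sum(potential(x, y) for x in X)


X = [0.1, 0.5, 0.9]
print(energy(X, [0.3, -0.2]))
\end{lstlisting}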

Consequently, the output $\mathbf{y}$ of the aggregation is defined \emph{implicitly} and is generally not available in closed form. 
Instead, one can find $\mathbf{y}$ by numerically solving the optimization problem~\eqref{eq:equilibrium_aggregation}, e.g., by gradient descent:
\begin{equation}\label{eq:gradient_descent}
    \mathbf{y}^{(t+1)} = \mathbf{y}^{(t)} - \alpha \nabla_{\mathbf{y}} E_{\theta}(X, \mathbf{y}^{(t)}), \quad \phi_{\theta}(X) = \mathbf{y}^{(T)}.
\end{equation}
Under certain conditions and with a large enough number of steps $T$, this procedure provides a sufficiently accurate solution that is itself well-defined and differentiable: either explicitly, through the unrolled gradient descent~\citep{andrychowicz2016learning, MAML}, or via the implicit function theorem applied to the optimality condition of~\eqref{eq:equilibrium_aggregation}~\citep{DEQs, blondel2021efficient}.
This makes it possible to learn the parameters $\theta$ of the potential and to train the whole model, including the aggregation, end-to-end.
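The two differentiation routes can be compared on a toy problem. The sketch below assumes a scalar energy with the quadratic potential $F_\theta(x, y) = (\theta x - y)^2$ (our choice, so that everything is available in closed form) and the regularizer $\text{softplus}(\lambda)\,y^2$. It runs the unrolled updates of \eqref{eq:gradient_descent} from $y^{(0)} = 0$ while propagating $\partial y^{(t)}/\partial\theta$ alongside (forward-mode differentiation). It then compares the result with the implicit-function-theorem derivative $-(\partial^2 E/\partial y^2)^{-1}\,\partial^2 E/\partial y\,\partial\theta$ at the optimum.

\begin{lstlisting}[language=Python]
import math


def softplus(z):
    # Numerically stable log(1 + exp(z)).
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def unrolled(X, theta, lam=0.0, alpha=0.1, T=50):
    # Toy energy (assumption): E(y; theta) = softplus(lam) * y^2 + sum_i (theta * x_i - y)^2.
    # Returns y^(T) and dy^(T)/dtheta, differentiating through every update (forward mode).
    c, N, S = softplus(lam), len(X), sum(X)
    y, dy = 0.0, 0.0  # y^(0) = 0 does not depend on theta
    for _ in range(T):
        g = 2.0 * (c + N) * y - 2.0 * theta * S  # dE/dy at y^(t)
        dg = 2.0 * (c + N) * dy - 2.0 * S  # total derivative of g w.r.t. theta
        y, dy = y - alpha * g, dy - alpha * dg
    return y, dy


def implicit_grad(X, lam=0.0):
    # Implicit function theorem at the optimum:
    # dy*/dtheta = -(d2E/dy2)^(-1) * d2E/(dy dtheta).
    c, N, S = softplus(lam), len(X), sum(X)
    d2E_dy2 = 2.0 * (c + N)
    d2E_dy_dtheta = -2.0 * S
    return -d2E_dy_dtheta / d2E_dy2


X = [0.5, 1.5, -0.25]
y_T, dy_T = unrolled(X, theta=2.0)
print(y_T, dy_T, implicit_grad(X))
\end{lstlisting}

For this toy energy, the iteration converges whenever $0 < \alpha < 1/(N + \text{softplus}(\lambda))$, and the two derivatives coincide once it has converged. For a learned potential, the implicit route instead requires solving a linear system involving the Hessian $\nabla^2_{\mathbf{y}} E_\theta$, but avoids storing the optimization trajectory.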

In general, it is not guaranteed that gradient-based optimization will converge to the global minimum of~\eqref{eq:equilibrium_aggregation} when the potential is an arbitrarily structured neural network.
However, with a large enough regularization weight $\text{softplus}(\lambda)$, it is possible to enforce convexity, at least on a subset of $\mathcal{Y}$~\citep{rajeswaran2019meta}.
When gradient descent is initialized deterministically, from a learnable starting point or, as in our implementation, from the zero vector, the point it reaches is a deterministic function of $X$. It is therefore sufficient to find a stationary point rather than the global minimum, as long as the next layer in the network can make use of the resulting aggregate.
Relaxing the need for convergence to the global minimum, together with the use of flexible neural networks, makes it possible to implement a potentially complex and expressive aggregation mechanism. 
In our implementation we employ explicit differentiation through gradient descent and find that the network generally learns convergent dynamics~\eqref{eq:gradient_descent} automatically, even with a fairly small number of iterations such as $T=10$.

To additionally encourage convergence, we consider the following \emph{auxiliary loss} that penalizes the norm of the energy gradient at each step of optimization:
\begin{equation}\label{eq:aux_loss}
    \textstyle L_{\text{aux}}(X, \mathbf{y}, \theta) = \frac{1}{T} \sum_{t=1}^T || \nabla_{\mathbf{y}} E_{\theta}(X, \mathbf{y}^{(t)}) ||_2^2.
\end{equation}
We simply add the auxiliary loss to the main loss incurred by the task of interest and optimize the sum during training. 
We further empirically assess convergence of the inner-loop optimization in Section~\ref{sec:molpcba}.
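A minimal sketch of the auxiliary loss \eqref{eq:aux_loss} follows. It again uses a toy scalar energy with the quadratic potential $(x - y)^2$, an assumption made for illustration, and accumulates $\|\nabla_{\mathbf{y}} E_\theta(X, \mathbf{y}^{(t)})\|_2^2$ over the iterates $t = 1, \dots, T$ of \eqref{eq:gradient_descent} started at $\mathbf{y}^{(0)} = \mathbf{0}$. A step size that reaches the minimizer in a single step yields a zero auxiliary loss, whereas a slowly converging inner loop is penalized.

\begin{lstlisting}[language=Python]
import math


def softplus(z):
    # Numerically stable log(1 + exp(z)).
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def aggregate_with_aux_loss(X, lam=0.0, alpha=0.1, T=10):
    # Toy scalar energy (assumption): E = softplus(lam) * y^2 + sum_i (x_i - y)^2.
    c, N, S = softplus(lam), len(X), sum(X)

    def grad(y):
        return 2.0 * (c + N) * y - 2.0 * S

    y, aux = 0.0, 0.0
    for _ in range(T):
        y = y - alpha * grad(y)  # y^(t)
        aux += grad(y) ** 2  # squared gradient norm at y^(t), t = 1..T
    return y, aux / T


X = [1.0, 2.0, 3.0]
print(aggregate_with_aux_loss(X, alpha=0.1), aggregate_with_aux_loss(X, alpha=0.01))
\end{lstlisting}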

\begin{table}[]
    \centering
    \begin{tabular}{c|c|c}
    \toprule
    \bf Aggregation & $\phi(X)$ & $F(\mathbf{x}, \mathbf{y})$ \\
    \midrule
    Mean & ${1 \over N} \sum_{i=1}^N \mathbf{x}_i$ & $(\mathbf{x} - \mathbf{y})^2$ \\
    Median & $\mathbf{x}_{(\lceil N/2 \rceil)}$ & $|\mathbf{x} - \mathbf{y}|$ \\
    Max & $\max\{\mathbf{x}_1, \ldots, \mathbf{x}_N\}$ & $\max(0, \mathbf{x} - \mathbf{y})$ \\
    Sum & \makecell{ $\sum_{i=1}^N \mathbf{x}_i$ or \\ \tiny  $  {\displaystyle \arg\min_{\mathbf{y}}} \left[ { \mathbf{y}^2 \over 2} + \sum_i F(\mathbf{x}_i, \mathbf{y}) \right] $} & $-\mathbf{x} \cdot \mathbf{y}$ \\
    \midrule
    \makecell{ \bf Equilibrium \\ \bf Aggregation} & $ \arg\min_{\mathbf{y}} E_{\theta}(X, \mathbf{y}) $ & \makecell{Neural network \\ $ F_{\theta}(\mathbf{x}, \mathbf{y}) $} \\
    \bottomrule
    \end{tabular}
    \caption{A comparison between Equilibrium Aggregation and pooling-based aggregation methods. Equations are given for the scalar case and can be applied coordinate-wise in higher dimensions; $\mathbf{x}_{(k)}$ denotes the $k$-th smallest input. For Max (and for Median with even $N$), the minimizer is not unique, and the table lists the smallest one.}
    \label{tab:aggregation}
\end{table}
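The entries of Table~\ref{tab:aggregation} can be checked numerically. The minimal Python sketch below minimizes each scalar energy for a small example set and recovers its mean, median, maximum, and sum. It uses brute-force grid search, chosen purely for illustration; this is not how Equilibrium Aggregation is solved in practice. For the Max potential, every $y \ge \max X$ attains the minimum, so \texttt{grid\_argmin} resolves ties towards the smallest minimizer.

\begin{lstlisting}[language=Python]
def grid_argmin(X, potential, regularizer=lambda y: 0.0, lo=-10.0, hi=30.0, steps_per_unit=100):
    # Brute-force argmin of R(y) + sum_i F(x_i, y) over a grid; ties resolve to the smallest y.
    best_y, best_e = None, float("inf")
    for k in range(int(lo * steps_per_unit), int(hi * steps_per_unit) + 1):
        y = k / steps_per_unit
        e = regularizer(y) + sum(potential(x, y) for x in X)
        if e < best_e:
            best_y, best_e = y, e
    return best_y


X = [1.0, 4.0, 2.0, 7.0, 5.0]
results = {
    "mean": grid_argmin(X, lambda x, y: (x - y) ** 2),
    "median": grid_argmin(X, lambda x, y: abs(x - y)),
    "max": grid_argmin(X, lambda x, y: max(0.0, x - y)),
    "sum": grid_argmin(X, lambda x, y: -x * y, regularizer=lambda y: y * y / 2),
}
print(results)
\end{lstlisting}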

\section{Universal Function Approximation on Sets}
According to the universal function approximation theorem for neural networks \citep{Hornik1989,Cybenko1989,Funahashi1989}, a multi-layer perceptron with a single, sufficiently wide hidden layer can approximate any continuous function on a compact subset of $\mathbb{R}^n$ with arbitrary accuracy. 
In machine learning, we typically do not know the function we aim to approximate. 
Hence, knowing that neural networks can, in theory, approximate any such function is comforting. 
At the same time, we seek to build inductive biases into the networks in order to facilitate learning, using more sophisticated architectures than multi-layer perceptrons. 
It is therefore important to understand whether, and to what extent, these modifications restrict the space of learnable functions.

Constructions similar to Equilibrium Aggregation, i.e. models defined by an optimization problem $\mathbf{y} = \arg \min_{\mathbf{y}} G(X, \mathbf{y})$, have previously been studied in the literature, especially in the context of permutation-sensitive (i.e. \textit{not} permutation-invariant) functions~\citep{pineda1987generalization, MetaUniversality, DEQs}, and various universal function approximation results have been obtained.
It is not obvious, however, how these results translate to the important permutation-invariant case we consider in this paper. Introducing permutation invariance self-evidently restricts the space of functions that can be approximated.
In the next subsection, we directly address the question of which set functions can be learned by Equilibrium Aggregation and establish a universality guarantee.

% \subsection{Relationship to DEQs}

% In order to discuss the representation power of the proposed approach, we first relate it to an existing architecture: Deep Equilibrium Models \citep{DEQs} (DEQs). As we will see, this relationship provides a hint that Equilibrium Aggregation does not impose restrictions on what functions can be learned and could indeed be universal.

% % Instead of differentiable optimsation, DEQs define the output of a layer as the fix point of a learned function conditioned on the input. This is the same 
% If we take our original model definition
% \begin{align}
% \label{eq:DEQ}
%  \arg \underset{\mathbf{y}}{\operatorname{min}} \,E_{\theta}(X,\mathbf{y})  
% \end{align}
% and assume $E_{\theta}$ to be convex for fixed $X$, then we can rewrite the search for the argmin as a dynamical system:
% \begin{align}
% % \mathbf{y}_{t+1} = - \gamma * \frac{d}{d\mathbf{y}} F(X,\mathbf{y}_{t})     
% \frac{d\mathbf{y}}{dt} = - \frac{d}{d\mathbf{y}} E_{\theta}(X,\mathbf{y})     
% \end{align}
% of which we search the fix point.\footnote{Note that all continuous functions can be written as a derivative of another function.}
% If we ignore the specific structure we proposed for $E_{\theta}$ (namely the sum over pairwise potentials), then this view is similar to DEQs \citep{DEQs}. From \citet{DEQs}, we know that such a fix point computation amounts to an infinitely deep neural network with weight tying between layers and residual connections. We also know that weight tying between layers does (at least in theory) not impose restrictions on what functions can be learned with a neural network (Theorem 3 in \cite{DEQs}). This hints at a universality guarantee of a network designed according to \eqref{eq:DEQ}.

% However, \eqref{eq:DEQ} is not permutation invariant with respect to shuffling the entries of $\mathbf{x}$. We can achieve permutation invariance by defining the energy in \eqref{eq:DEQ} as a sum of pairwise potentials as done in \eqref{eq:equilibrium_aggregation}. This evidently restricts the space of functions we can represent (as we can no longer learn any function of which the output depends on the ordering of the input). The question is: how much does this restrict the types of functions we can learn? Specifically, can we learn all permutation-invariant functions $\psi$?

\subsection{Universality of Equilibrium Aggregation}

In this section, we will see that Equilibrium Aggregation is indeed able to approximate all continuous permutation invariant functions $\psi$.
We start by stating a few assumptions: we assume a fixed input set size $N$ of scalar inputs\footnote{This is a common simplification in the literature on universal function approximation on sets. For a discussion on how to generalize from the scalar to the vector case, see \cite{Hutter}.} $x_i$ (note the dropped boldface, indicating that these are no longer vectors) and a scalar output. We further assume that the input space $\mathcal{X}$ is a compact subset of $\mathbb{R}^N$. For simplicity, and without loss of generality (as we can always rescale the inputs), we choose this to be $[0,1]^N$:
\begin{align}
\psi: [0,1]^N \to \mathbb{R}.
\end{align}
As $\psi$ is permutation invariant, the vector-valued inputs can be seen as (multi)sets. For a discussion of why considering uncountable domains (i.e. the real numbers) is important for continuity, see Section 3 of \cite{Wagstaff2019}. 

We consider a neural network architecture with Equilibrium Aggregation as a global pooling operation of the following form:
\begin{align}
\label{eq:equil_with_rho}
    \phi(X) = \rho(\arg \underset{\mathbf{y}}{\operatorname{min}}\sum_i F_{\theta}({x}_i,\mathbf{y})),
\end{align}

where $F_{\theta}$ (the potential function) and $\rho$ are modeled by neural networks, which are assumed to be universal function approximators. Note that, for simplicity of the proof, we set the regularizer to zero. We refer to the space in which $\arg\min_{\mathbf{y}} \sum_i F_{\theta}({x}_i,\mathbf{y})$ lives, i.e. $\mathbb{R}^M$, as the \textit{latent space}, analogous to the terminology used in \cite{Wagstaff2019} with respect to the Deep Sets architecture \citep{DeepSets}. We prove the following:

\begin{theorem}
\label{thm:main_theorem}
Let the latent space be of size $M=N$, i.e. $\mathbf{y}\in \mathbb{R}^N$. Then all permutation invariant continuous functions $\psi$ can be approximated with Equilibrium Aggregation as defined in \eqref{eq:equil_with_rho}.
\end{theorem}

\begin{proof}
For the purpose of this proof, we assume $F_{\theta}$ takes the form:
\begin{align}
\label{eq:injective_energy}
   F_{\theta}({{x}_i}, \mathbf{y}) = \sum_{k=1}^M (\frac{y_k}{N}-{x}_i^k)^2,
\end{align}

where $k$ serves both as an index for the vector $\mathbf{y}$ and as an exponent for ${x}_i$. There are now two sums: an inner one in the definition of $F_{\theta}$ and an outer one over the set elements in \eqref{eq:equil_with_rho}. Note that $F_{\theta}$ is continuous and can therefore be approximated by a neural network. Importantly, $F_{\theta}$ is also convex in $\mathbf{y}$, so its minimizer can be found with gradient descent. Since $X$ is fixed, each of the $M$ coordinates $y_k$ can be optimized independently. It is a well-known fact that minimizing a sum of squares yields the mean:
\begin{align}
        \arg \underset{z}{\operatorname{min}}\sum_{i=1}^N (z-x_i)^2 = \frac{1}{N} \sum_{i=1}^N x_i.
\end{align}
Applying this with $z = y_k/N$ and $x_i^k$ in place of $x_i$ gives $y_k/N = \frac{1}{N}\sum_{i=1}^N x_i^k$, so minimizing the sum of potentials defined in \eqref{eq:injective_energy} yields
\begin{align}
\label{power_of_sum_mapping}
    y_{k}^{min}= \sum_i x_i^k \quad \text{for} \,\, k \in \{1, \dots, N\}.
\end{align}
For inputs $(x_1, \dots, x_N) \in [0,1]^N$, this mapping to $\mathbf{y}$ is evidently continuous and surjective onto its image, a subset of $[0,N]^N$. We also know from Lemma 4 in \cite{DeepSets} that this mapping is injective up to permutations of the inputs (i.e. on multisets) and from Lemma 6 that it has a continuous inverse.\footnote{We refer to Appendix B.4 in \cite{Wagstaff2019} as to why the term $k=0$ in \eqref{power_of_sum_mapping} can be dropped for fixed set sizes.}
% for inputs [0,1] and from \cite{Wagstaff2019} that this is a homeomorphism. Since homeomorphisms have continuous inverses, 
Since $\psi$ is continuous by assumption, the function


\begin{align}
    \rho = \psi \circ \left( \arg \underset{\mathbf{y}}{\operatorname{min}}\sum_i F_{\theta}({x}_i,\mathbf{y}) \right)^{-1}
\end{align}

is continuous\footnote{The superscript $-1$ indicates the functional inverse w.r.t. $X$.} as long as the inputs $x_i$ are constrained to $[0,1]$, and it can therefore be approximated by a neural network. Moreover, via a global rescaling of the inputs, this proof applies to any bounded input interval $[a,b]$. Hence, any permutation-invariant, continuous $\psi$ on $[a,b]^N$ can be approximated via Equilibrium Aggregation with a latent space of size $M=N$.
% $\square$
\end{proof}
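The two ingredients of the proof can be checked numerically with a short sketch in plain Python, assuming scalar inputs in $[0,1]$ and $M=N$. First, the power sums $y_k = \sum_i x_i^k$ zero the gradient of the energy built from \eqref{eq:injective_energy}. Second, they determine the multiset $X$: Newton's identities turn them into the coefficients of the polynomial $\prod_i (t - x_i)$, whose roots are exactly the elements of $X$. This illustrates why the mapping is invertible on multisets; it is not a substitute for the cited lemmas.

\begin{lstlisting}[language=Python]
def power_sums(xs):
    # y_k = sum_i x_i^k for k = 1..N: the minimizer of the summed potentials.
    n = len(xs)
    return [sum(x ** k for x in xs) for k in range(1, n + 1)]


def energy_grad(xs, y):
    # Gradient of sum_i sum_k (y_k / N - x_i^k)^2 with respect to each y_k.
    n = len(xs)
    return [sum(2.0 * (y[k - 1] / n - x ** k) / n for x in xs) for k in range(1, len(y) + 1)]


def elementary_from_power_sums(p):
    # Newton's identities: k * e_k = sum_{i=1}^{k} (-1)^(i-1) * e_{k-i} * p_i, with e_0 = 1.
    n = len(p)
    e = [1.0] + [0.0] * n
    for k in range(1, n + 1):
        e[k] = sum((-1) ** (i - 1) * e[k - i] * p[i - 1] for i in range(1, k + 1)) / k
    return e


def coeffs_from_power_sums(p):
    # Coefficients (highest degree first) of prod_i (t - x_i) = sum_k (-1)^k e_k t^(N-k).
    e = elementary_from_power_sums(p)
    return [(-1) ** k * e_k for k, e_k in enumerate(e)]


def coeffs_from_roots(xs):
    # Direct expansion of prod_i (t - x_i), used only to check the reconstruction.
    c = [1.0]
    for x in xs:
        c = [a - x * b for a, b in zip(c + [0.0], [0.0] + c)]
    return c


xs = [0.2, 0.5, 0.9]
p = power_sums(xs)
print(p, coeffs_from_power_sums(p), coeffs_from_roots(xs))
\end{lstlisting}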

\subsection{Comparison to Deep Sets}

So far, we have only been able to prove that Equilibrium Aggregation scales at least as well as Deep Sets. By that, we mean that universal function approximation can be achieved with $M=N$, i.e. having as many latents as inputs is \textit{sufficient}. (For Deep Sets, we also know that $M=N$ is necessary \citep{Wagstaff2019}.) Even though we currently do not know whether it is possible to achieve universal function approximation with a smaller latent space, there is some indication that Equilibrium Aggregation may have more representational power, as we lay out next.

Using one latent dimension, Deep Sets with max-pooling can obviously represent $\psi(X)=\max(X)$, but it cannot represent (or even approximate) the sum for set sizes larger than 1. Vice versa, sum-pooling can represent $\psi(X)=\text{sum}(X)$, but it cannot represent $\max(X)$ \citep{Wagstaff2019}. Equilibrium Aggregation can represent both sum and max pooling, each with just one latent dimension (i.e. $\mathbf{y}\in\mathbb{R}^1$) as shown in Table~\ref{tab:aggregation}.
% :
% \begin{align}
%     F_{\theta}({{{x}_i}, \mathbf{y}}) = (\frac{\mathbf{y}}{N}-{x}_i^k)^2
% \end{align}
% as well as max pooling:
% \begin{align}
%      F_{\theta}({{x}_i}, {\mathbf{y}}) = \max(0, {x}_i - \mathbf{y})
% \end{align}
% with just one latent dimension, i.e. $\mathbf{y}\in\mathbb{R}^1$.

\section{Related Work}
Equilibrium Aggregation sits at the intersection of two machine learning research areas: aggregation functions and implicit layers. In the following, we review the most closely related work in each of these areas.

\subsection{Aggregation Functions}
Perhaps the most popular approach for obtaining a permutation invariant encoding of sets is Sum pooling. A particular instance of this is Deep Sets~\citep{DeepSets}, as described in~\eqref{eq:sum_aggregation}.
A central finding of \cite{Wagstaff2019} is that the latent space, i.e. the dimensionality $M$ of the result of $\sum_i f(x_i) \in \mathbb{R}^M$, needs to be at least as large as the number of inputs $N$, i.e. $M \geq N$, in order to guarantee universal function approximation.
This applies to many other aggregation methods as well and, to the best of our knowledge, there is currently no known pooling operation which does not introduce this scaling issue.

Principal Neighbourhood Aggregation (PNA)~\citep{Corso} addresses the limitations of individual pooling operators such as sum or max by combining four different pooling operators and three different scaling strategies, resulting in a simultaneous 12-way aggregation. 
Despite the more sophisticated aggregation procedure,~\citet{Corso} reach conclusions very similar to those of \cite{DeepSets} and \cite{Wagstaff2019}, namely that $M=N$ is both necessary and sufficient. 
They prove the necessity for any set of aggregators as well as the sufficiency for a specific set.
In our work, we expand this line of thinking further by allowing the model to \emph{learn} the desired aggregation operator, which may include PNA or something drastically different.
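For concreteness, a sketch of such a 12-way aggregation for scalar messages is given below. It follows our reading of \citet{Corso}. The aggregators are the mean, minimum, maximum, and standard deviation. The scalers are the identity, an amplification factor $\log(d+1)/\delta$, and an attenuation factor $\delta/\log(d+1)$, where $d$ is the number of aggregated messages (the node degree). The normalization constant $\delta$ would be computed from the training data; here it is passed in as a hypothetical argument.

\begin{lstlisting}[language=Python]
import math


def pna_aggregate(messages, delta):
    # Four aggregators (mean, min, max, std) times three degree scalers = 12 outputs.
    # delta: hypothetical normalization constant, e.g. the average log(d + 1) over training graphs.
    d = len(messages)
    mean = sum(messages) / d
    std = math.sqrt(sum((m - mean) ** 2 for m in messages) / d)
    aggregators = [mean, min(messages), max(messages), std]
    s = math.log(d + 1) / delta
    scalers = [1.0, s, 1.0 / s]  # identity, amplification, attenuation
    return [sc * a for sc in scalers for a in aggregators]


print(pna_aggregate([1.0, 2.0, 3.0], delta=1.0))
\end{lstlisting}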

Learnable Aggregation Functions (LAF)~\citep{LAF2020} provide a similar framework for learning an aggregation operator by expressing it as a combination of several weighted $L_p$ norms, where the weights and the $p$ parameters are trained jointly with the model. 
Even though LAFs are capable of expressing operators used in PNA and beyond, it is not clear how they can reproduce other aggregation methods such as attention. 
In contrast, our method can learn attention (see Supplementary Material for details) as well as even more expressive aggregation functions.
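Aggregators built from $L_p$-type terms can interpolate between familiar pooling operators by varying a learnable exponent. The sketch below illustrates this idea with the generalized power mean; it does not reproduce the exact parameterization of \citet{LAF2020}. For positive inputs it recovers the mean at $p=1$ and approaches the maximum and minimum as $p \to +\infty$ and $p \to -\infty$, respectively.

\begin{lstlisting}[language=Python]
def power_mean(xs, p):
    # Generalized (power) mean of positive inputs: ((1/N) * sum_i x_i^p)^(1/p), p != 0.
    n = len(xs)
    return (sum(x ** p for x in xs) / n) ** (1.0 / p)


xs = [1.0, 2.0, 3.0]
print([power_mean(xs, p) for p in (-200.0, 1.0, 2.0, 200.0)])
\end{lstlisting}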

A further generalization of the functional form of the aggregation operator leads to the Karcher or Fr\'echet mean~\citep{Grove1973}, which is defined as the solution of the following minimization problem over a metric space $\mathcal{X}$ with metric $d(\cdot, \cdot)$:
\[
    \bar{\mathbf{x}} = \arg \min_{\mathbf{z} \in \mathcal{X}} \sum_{i=1}^N d^2(\mathbf{z}, \mathbf{x}_i), \quad \mathbf{x}_i \in \mathcal{X}.
\]
While closely related to the Karcher or Fr\'echet mean, Equilibrium Aggregation differs in that it does not restrict the aggregation result to lie in $\mathcal{X}$, and it allows strongly asymmetric potential functions together with the regularizer. 
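To illustrate how the Fr\'echet mean differs from the arithmetic mean, the sketch below, an illustrative example of our own, assumes the inputs are angles on the unit circle with the geodesic (arc-length) metric, and minimizes $\sum_i d^2(z, x_i)$ by grid search. For two angles on either side of $0$, the Fr\'echet mean lies at $0$, whereas the arithmetic mean of their numeric values is approximately $\pi$, on the opposite side of the circle.

\begin{lstlisting}[language=Python]
import math

TWO_PI = 2.0 * math.pi


def geodesic(a, b):
    # Arc-length distance between two angles on the unit circle.
    d = abs(a - b) % TWO_PI
    return min(d, TWO_PI - d)


def frechet_mean_circle(angles, grid=3600):
    # Brute-force argmin over grid points z of sum_i d(z, x_i)^2.
    best_z, best_cost = None, float("inf")
    for k in range(grid):
        z = k * TWO_PI / grid
        cost = sum(geodesic(z, a) ** 2 for a in angles)
        if cost < best_cost:
            best_z, best_cost = z, cost
    return best_z


angles = [0.1, TWO_PI - 0.1]
print(frechet_mean_circle(angles), sum(angles) / len(angles))
\end{lstlisting}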

Finally, Janossy Pooling~\citep{Janossy} generalizes the idea of standard, coordinate-wise pooling to make use of higher-order interactions between set elements.
Despite the potential for practical effectiveness, it is unclear whether these developments guarantee better approximation results in the general case \citep{Wagstaff2021}. 
Equilibrium Aggregation is fully compatible with Janossy Pooling and could benefit from even more expressive energy functions with pairwise or triplet interactions. This may not be necessary, however, as such interactions can be emulated within the optimization process, whereas explicit higher-order terms come at a significant computational cost for larger set sizes.
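$k$-ary Janossy pooling, as we understand it, averages a permutation-sensitive function over all ordered $k$-tuples of distinct set elements. The minimal sketch below handles $k=2$ with a hypothetical pairwise function \texttt{f}. It shows both the resulting permutation invariance and the cost: a set of size $N$ requires $N(N-1)$ evaluations of \texttt{f}.

\begin{lstlisting}[language=Python]
from itertools import permutations


def janossy_pool(xs, f, k=2):
    # Average a permutation-sensitive f over all ordered k-tuples of distinct elements.
    tuples = list(permutations(xs, k))
    return sum(f(*t) for t in tuples) / len(tuples)


def f(a, b):
    # Hypothetical pairwise function; f(a, b) != f(b, a) in general.
    return a - 2.0 * b + a * b


print(janossy_pool([1.0, 2.0, 3.0], f), janossy_pool([3.0, 1.0, 2.0], f))
\end{lstlisting}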

In addition to formulating more expressive pooling operators, there is also a body of work concerned with multi-step parametric models for set encoding~\citep{vinyals2015order, lee2019set}. 
Inevitably, to achieve permutation invariance, these models rely on some kind of pooling as a building block, such as the ones outlined above.
As a drop-in replacement for sum- or attention-pooling, Equilibrium Aggregation can be used in those models, too.

\subsection{Implicit and optimization-based models}

Gradient-based optimization has been utilized in a large number of applications~\citep{amos2019differentiable}: image denoising~\citep{putzky2017recurrent}, molecule generation~\citep{duvenaud2015convolutional, alquraishi2019alphafold}, planning~\citep{amos2018differentiable} and combinatorial search~\citep{hottung2020learning, bartunov2020continuous} to name a few. 
While there is a large body of work in which gradient descent dynamics are decoupled from learning~\citep[e.g.,][]{du2019implicit, song2019generative}, our work is particularly closely related to methods that seek to learn the underlying objective function end-to-end, such as~\cite{putzky2017recurrent, rubanova2021constraint}.

A closely related family of methods defines computations inside a model \emph{implicitly}, i.e., via a set of conditions that a variable must satisfy, instead of prescribing directly how its value should be computed.
Deep Equilibrium Models (DEQs) formulate this via a fixed point of an update rule specified by the model~\citep{pineda1987generalization, liao2018reviving, DEQs}, and Implicit Graph Neural Networks explore this idea in the context of graphs~\citep{gu2020implicit}. 
Neural ODEs~\citep{chen2018neural} parametrize the derivative of a continuous-time function that specifies the computation of interest.
iMAML~\citep{MetaImplicit} considers an implicit optimization procedure for the purpose of finding model parameters suitable for gradient-based meta-learning~\citep{MAML}.
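As a toy illustration of the implicit formulation used by DEQs, the sketch below defines the output of a scalar layer as the fixed point $z^\star = \tanh(w z^\star + x + b)$. It finds this fixed point by naive fixed-point iteration. The parameters and names are hypothetical. Practical DEQs rely on more sophisticated root finders and on implicit differentiation for the backward pass.

\begin{lstlisting}[language=Python, style=mystyle]
import math

def deq_fixed_point(x, w=0.5, b=0.1, tol=1e-10, max_iter=1000):
    """Return z* with z* = tanh(w * z* + x + b) and the number of iterations.
    For |w| < 1 the map is a contraction (tanh is 1-Lipschitz), so plain
    fixed-point iteration converges to the unique solution."""
    z = 0.0
    for i in range(max_iter):
        z_next = math.tanh(w * z + x + b)
        if abs(z_next - z) < tol:
            return z_next, i + 1
        z = z_next
    return z, max_iter

z_star, iters = deq_fixed_point(0.3)
print(z_star, iters)
\end{lstlisting}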

Our work is similar in spirit but focuses specifically on the aggregation block for encoding sets, which can be seen as a small but generic building block that can be combined with arbitrary model architectures. 
Similarly to OptNet~\citep{amos2017optnet}, we propose a layer architecture that can be used inside another implicit or traditional multi-layer neural network. 

\subsection{Learning on distributions}

An important use case for set encoding is learning a distribution from a finite sample.
Recent examples include Neural Processes~\citep{garnelo2018neural}, which build a simple permutation-invariant representation of the training set by averaging its encoded elements, and a similar construction by \citet{edwards2016towards}.
Equilibrium Aggregation can be used to build a more advanced variant of this idea that replaces average pooling with maximum a posteriori (MAP) parameter estimation (see the Supplementary Material for details). 
It is also straightforward to replace the MAP formulation with parametric variational inference, further bridging the gap between set encoding and distribution learning.
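For intuition on the MAP view, consider a hypothetical conjugate Gaussian model. This is not the construction from the Supplementary Material. Assume a likelihood $\mathcal{N}(x_i \mid y, \sigma^2)$ and a prior $\mathcal{N}(y \mid \mu_0, \tau^2)$. The negative log-posterior is then a sum of per-element potentials $(x_i - y)^2 / (2\sigma^2)$ plus a regularizer $(y - \mu_0)^2 / (2\tau^2)$. Its minimizer is $y^\star = \big(\sum_i x_i / \sigma^2 + \mu_0 / \tau^2\big) / \big(n / \sigma^2 + 1 / \tau^2\big)$, which reduces to average pooling as $\tau^2 \to \infty$. The sketch below minimizes this energy by gradient descent and compares the result with the closed form.

\begin{lstlisting}[language=Python, style=mystyle]
def map_gaussian_mean(xs, sigma2=1.0, mu0=0.0, tau2=1.0, steps=200):
    """Minimize E(X, y) = sum_i (x_i - y)^2 / (2 sigma2) + (y - mu0)^2 / (2 tau2)
    by gradient descent; the last term plays the role of a regularizer."""
    curvature = len(xs) / sigma2 + 1.0 / tau2  # second derivative of E
    lr = 0.5 / curvature                       # halves the error at every step
    y = 0.0
    for _ in range(steps):
        grad = sum(y - x for x in xs) / sigma2 + (y - mu0) / tau2
        y -= lr * grad
    return y

def map_closed_form(xs, sigma2=1.0, mu0=0.0, tau2=1.0):
    precision = len(xs) / sigma2 + 1.0 / tau2
    return (sum(xs) / sigma2 + mu0 / tau2) / precision

xs = [1.0, 2.0, 3.0, 6.0]
print(map_gaussian_mean(xs), map_closed_form(xs))  # both 2.4 up to rounding
print(map_gaussian_mean(xs, tau2=1e12))            # about 3.0, the plain average
\end{lstlisting}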

\section{Experiments}
In this section, we describe three experiments with the goal of analyzing the performance of Equilibrium Aggregation in different tasks and comparing it to existing aggregation methods.
Our intention is not to achieve state-of-the-art results on any particular task.
Instead, we consider archetypal scenarios and applications in which performance depends significantly on the choice of aggregation method, so that its effect can be studied in isolation from other factors.

In all experiments, we train the models for $10^7$ steps with the Adam optimizer~\citep{kingma2014adam}.
Since maximizing performance is not the goal of our experiments, we do not perform an extensive hyperparameter search. Instead, we only sweep over the learning rate (chosen from $\{10^{-4}, 3 \times 10^{-4}, 10^{-3}\}$) and, on MOLPCBA only, the auxiliary loss weight. For this sweep, we reserve a small subset of the training set for validation (on the Omniglot and MOLPCBA benchmarks only).
Each model is trained on a single GPU (Nvidia P100 or V100). 
All experimental code is written in JAX~\citep{jax2018github} using Haiku~\citep{haiku2020github}.
Source code for the most crucial parts of our implementation can be found in the Supplementary Material.

% \input{05a_median}
\subsection{Median estimation}

\begin{figure}
    \centering
    \includegraphics[width=0.94\linewidth]{figures/median_avgandbest.pdf}
    \caption{Median estimation for sets of 100 numbers with three different aggregation methods. Bold lines show the average performance over 5 seeds, and faded lines show the best-performing seed of each model. The mean squared error is computed on $8 \times 10^5$ sets for varying set embedding sizes.}
    % \caption{Median estimation using a ResNet with three different pooling operations. A set of 100 scalars is fed into the network, the task is to predict/extract the median of this set. The bold lines correspond to the average performance over 5 seeds, the faded lines show the best performing seed of the respective architecture. The mean squared error is reported for different embedding sizes, i.e. the number of latent dimensions across which the pooling is performed. Note that, for multi-head attention, the smallest embedding size is 12 to allow for an equal division of channels across 4 heads.}
    \label{fig:exp_toy}
\end{figure}


In this experiment, the neural network is tasked with predicting the median of a set of 100 randomly sampled numbers.
Each set is sampled from a Uniform, Gamma or Normal distribution with fixed parameters, similar to~\cite{Wagstaff2019}.
The basic architecture for pooling-based aggregation baselines consists of first embedding each number in the set with a fully connected ResNet~\citep{ResNet} with layer sizes $[256, 256, D]$, where $D$ is the set embedding size.
Then, the embeddings are pooled with the corresponding method into a $D$-dimensional vector and the median is predicted from it using another fully connected network with layer sizes $[D, 128, 1]$. 
The median is regressed with a simple squared-error loss.

Equilibrium Aggregation, in contrast, performs input encoding and aggregation simultaneously. It does so by running 5 steps of gradient-based minimization of~\eqref{eq:equilibrium_aggregation}, with the potential function implemented as a ResNet with layer sizes $[256, 256, 1]$. This ResNet takes a $(D+1)$-dimensional input: $D$ dimensions for the implicit aggregation result and $1$ for the input number.
The result is then also transformed into the prediction using the same output network as in the baseline methods. 
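A minimal sketch of this inner loop is given below. It is not the paper's implementation, which uses JAX and learned ResNet potentials, and it makes four simplifications. First, the gradient with respect to $\mathbf{y}$ is approximated by central finite differences in place of automatic differentiation. Second, the regularizer is omitted. Third, per-element gradients are averaged rather than summed. Fourth, the quadratic potential in the usage example is hypothetical. As in the median experiment, the potential receives the concatenated $(D+1)$-dimensional vector $[\mathbf{y}, x_i]$.

\begin{lstlisting}[language=Python, style=mystyle]
def grad_wrt_y(potential, y, x, eps=1e-5):
    """Central finite-difference gradient of potential(y + [x]) w.r.t. y."""
    grad = []
    for d in range(len(y)):
        y_plus, y_minus = list(y), list(y)
        y_plus[d] += eps
        y_minus[d] -= eps
        grad.append((potential(y_plus + [x]) - potential(y_minus + [x])) / (2 * eps))
    return grad

def equilibrium_aggregate(xs, potential, dim, steps=5, lr=0.5):
    """Run `steps` gradient steps on the mean of potential([y, x_i]) over the
    set, starting from y = 0, and return the final iterate y^(T)."""
    y = [0.0] * dim
    for _ in range(steps):
        total = [0.0] * dim
        for x in xs:
            total = [t + g for t, g in zip(total, grad_wrt_y(potential, y, x))]
        y = [y_d - lr * t / len(xs) for y_d, t in zip(y, total)]
    return y

# Hypothetical potential pulling y_d towards (d + 1) * x; its exact minimizer
# over the set is y_d = (d + 1) * mean(xs).
def toy_potential(z):
    *y, x = z
    return 0.5 * sum((y_d - (d + 1) * x) ** 2 for d, y_d in enumerate(y))

print(equilibrium_aggregate([1.0, 2.0, 3.0], toy_potential, dim=2))
# about [1.9375, 3.875]: 5 steps cover 1 - 0.5**5 of the way to [2.0, 4.0]
\end{lstlisting}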

We compare three models: Sum aggregation, analogous to Deep Sets~\citep{DeepSets}; Multi-Head Attention with 4 heads, each operating on $D/4$-dimensional keys, values and learned query vectors; and Equilibrium Aggregation as described above.
For each model, we vary the embedding size and measure the mean squared error after $10^7$ training steps.
The results are shown in Figure~\ref{fig:exp_toy}.
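For reference, the Multi-Head Attention baseline can be sketched as follows. This is a simplified, hypothetical version: each head attends with a learned query over its own $D/4$-dimensional slice of the embeddings and uses identity key and value projections, whereas the actual baseline may use learned projections. Because the softmax weights sum to one over the set, each head outputs a permutation-invariant convex combination of its values.

\begin{lstlisting}[language=Python, style=mystyle]
import math

def attention_pool(xs, queries, num_heads=4):
    """Multi-head attention pooling with one learned query per head.

    xs: list of D-dimensional embeddings; queries: num_heads vectors of size
    D / num_heads. Assumption: identity key/value projections, i.e. each head
    uses its slice of x both as key and as value.
    """
    dim = len(xs[0])
    assert dim % num_heads == 0, "D must be divisible by the number of heads"
    dh = dim // num_heads
    out = []
    for h in range(num_heads):
        keys = [x[h * dh:(h + 1) * dh] for x in xs]
        scores = [sum(q * k for q, k in zip(queries[h], key)) / math.sqrt(dh)
                  for key in keys]
        m = max(scores)  # subtract the maximum for numerical stability
        weights = [math.exp(s - m) for s in scores]
        z = sum(weights)
        out.extend(sum(w * key[j] for w, key in zip(weights, keys)) / z
                   for j in range(dh))
    return out

xs = [[float(i + j) for j in range(8)] for i in range(5)]   # D = 8
queries = [[0.1, -0.2], [0.3, 0.0], [-0.1, 0.4], [0.2, 0.2]]
print(attention_pool(xs, queries))
\end{lstlisting}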

Equilibrium Aggregation achieves an estimation error that is one order of magnitude (averaged over 5 seeds) or two orders of magnitude (best of 5 seeds) lower than that of the baseline pooling methods, which confirms its higher representational power in this simple setting.
Importantly, in this experiment there is no distinction between training and test distributions, as samples are drawn continuously and never repeated. Hence, this example primarily tests the representational power of the approaches rather than their data efficiency. Note also that all architectures have roughly the same number of trainable parameters. The low error achieved by Equilibrium Aggregation suggests that it discovered, or closely approximated, the analytical solution $F(\mathbf{x}, \mathbf{y}) = |\mathbf{x} - \mathbf{y}|$, for which the energy $\sum_i F(\mathbf{x}_i, \mathbf{y})$ is minimized by the median of the set.
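The following sketch checks that minimizing $\frac{1}{n}\sum_i |x_i - y|$ recovers the median. Unlike in the experiment, the potential is hand-coded rather than learned. The energy is not differentiable at the data points, so the sketch uses subgradient descent with a decaying step size and runs many more iterations than the 5 steps used in the experiment. The function name is hypothetical. For sets of even size, any point between the two middle elements minimizes the energy.

\begin{lstlisting}[language=Python, style=mystyle]
import statistics

def abs_potential_minimizer(xs, steps=5000, lr=1.0):
    """Minimize (1/n) * sum_i |x_i - y| over y by subgradient descent with the
    decaying step size lr / sqrt(t + 1), starting from the mean of the set."""
    y = sum(xs) / len(xs)
    for t in range(steps):
        # A subgradient of |x - y| w.r.t. y is sign(y - x), taken as 0 at y == x.
        g = sum((y > x) - (y < x) for x in xs) / len(xs)
        y -= lr * g / (t + 1) ** 0.5
    return y

xs = [1.0, 2.0, 3.0, 10.0, 50.0]
print(abs_potential_minimizer(xs), statistics.median(xs))  # both close to 3.0
\end{lstlisting}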


% Figure \ref{fig:exp_toy} shows that Equilibrium Aggregation significantly outperforms both sum pooling and Multi-Head Attention by about one order of magnitude. Interestingly, if we compare the best performing seed of each pooling type, we observe an even bigger performance increase from using Equilibrium Aggregation (roughly two orders of magnitude). This hints at Equilibrium Aggregation sometimes finding particularly strong solutions. More reliably converging to these particularly high-performing solutions and leveraging the full potential of equilibrium aggregation will be one of our focuses of future research.


\subsection{Omniglot class counting}

\begin{figure}
    \centering
    \begin{subfigure}[b]{0.99\linewidth}
    \includegraphics[width=\textwidth]{figures/omniglot_task_v2.pdf}
    \caption{Task setup.}
    \label{fig:omniglot_task}
    \end{subfigure}
    \begin{subfigure}[b]{0.99\linewidth}
    
    \includegraphics[width=0.95\textwidth]{figures/eqagg_omniglot.pdf}
    \caption{Train (dashed) and test (solid) accuracy for different aggregation methods.}
    \label{fig:omniglot_results}
    \end{subfigure}
    \caption{Omniglot class counting task.}
\end{figure}

We proceed to the more challenging task of counting the number of unique character classes in a set of 16 Omniglot images, inspired by~\cite{lee2019set}.
Omniglot~\citep{lake2015human} is a dataset of handwritten characters organized into alphabets and, within each alphabet, into character classes, each of which has only 20 instances. 
We randomly choose between 1 and 10 character classes and sample their images to form the input set.
The model then needs to aggregate these images and infer the number of unique character classes by outputting a probability for each of the $1, 2, \ldots, 10$ possible numbers of classes (see Figure~\ref{fig:omniglot_task} for an illustration).
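One possible way of generating a single training example is sketched below. The data structures and function names are hypothetical. We assume that every selected class contributes at least one image, since otherwise the number of unique classes in the set could be smaller than the number of selected classes. We also assume that images are drawn without replacement within a class. The label $k - 1$ indexes the 10-way softmax over $\{1, \ldots, 10\}$.

\begin{lstlisting}[language=Python, style=mystyle]
import random

def sample_counting_task(classes, set_size=16, max_classes=10, rng=random):
    """Sample one class-counting example.

    classes: dict mapping a class id to its image ids (20 per Omniglot class).
    Returns a shuffled list of (class id, image id) pairs and the label k - 1,
    where k is the number of unique classes; only the images are fed to the model.
    """
    k = rng.randint(1, max_classes)
    chosen = rng.sample(sorted(classes), k)
    # Each chosen class appears at least once; the remaining slots are filled
    # with further draws from the chosen classes.
    members = chosen + [rng.choice(chosen) for _ in range(set_size - k)]
    items = []
    for c in chosen:
        # Draw distinct images of class c without replacement.
        items += [(c, img) for img in rng.sample(classes[c], members.count(c))]
    rng.shuffle(items)
    return items, k - 1
\end{lstlisting}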

The original images are downsized to $32 \times 32$ and encoded using a convolutional ResNet with $[16, 32, 64]$ hidden channels in its three blocks, respectively. 
Each block uses $3 \times 3$ filters with a stride of 2 and hence halves the spatial size of its input tensor. 
The ResNet output is then flattened and linearly projected into a $256$-dimensional input embedding.
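The resulting feature sizes can be checked with a few lines of arithmetic. The sketch assumes ``SAME'' padding, under which a stride-2 convolution maps a spatial size $s$ to $\lceil s/2 \rceil$. With other padding schemes, the flattened size may differ.

\begin{lstlisting}[language=Python, style=mystyle]
import math

def conv_out_size(size, stride=2):
    """Spatial output size of a strided convolution, assuming 'SAME' padding."""
    return math.ceil(size / stride)

size = 32
for channels in [16, 32, 64]:      # the three ResNet blocks, each with stride 2
    size = conv_out_size(size)
    print(f"{size}x{size}x{channels}")   # 16x16x16, 8x8x32, 4x4x64
flat = size * size * channels      # flattened size before the linear projection
print(flat)                        # 1024, then projected to 256 dimensions
\end{lstlisting}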
After the encoding step, as in the previous experiment, Sum, Multi-Head Attention with 4 heads and Equilibrium Aggregation aggregate the set into a $256$-dimensional set embedding. From this embedding, a fully connected ResNet with layer sizes $[128, 10]$ predicts a softmax distribution over the number of classes.
For Equilibrium Aggregation, the potential is a ResNet with layer sizes $[512, 512, 32]$ whose 32-dimensional output is squared and summed to form a scalar potential value. 
We use 10 iterations of inner-loop optimization in this experiment.
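Squaring and summing the last layer's output guarantees a non-negative potential. The sketch below is a hypothetical stand-in for the $[512, 512, 32]$ ResNet. It uses a single linear layer $\mathbf{h} = W[\mathbf{y}; \mathbf{x}] + \mathbf{b}$ with $F = \|\mathbf{h}\|^2$ and writes out the chain rule $\nabla_{\mathbf{y}} F = 2 W_{\mathbf{y}}^\top \mathbf{h}$, where $W_{\mathbf{y}}$ denotes the columns of $W$ acting on $\mathbf{y}$. In practice, this gradient would be obtained by automatic differentiation.

\begin{lstlisting}[language=Python, style=mystyle]
def sum_of_squares_potential(W, b, y, x):
    """F(x, y) = ||W [y; x] + b||^2 for one hypothetical linear layer.

    W has K rows of length len(y) + len(x). Returns F and its gradient w.r.t. y.
    """
    z = list(y) + list(x)
    h = [sum(w * v for w, v in zip(row, z)) + b_k for row, b_k in zip(W, b)]
    value = sum(h_k ** 2 for h_k in h)          # non-negative by construction
    # Chain rule: dF/dy_d = sum_k 2 * h_k * W[k][d], i.e. 2 W_y^T h.
    grad = [2.0 * sum(W[k][d] * h[k] for k in range(len(h)))
            for d in range(len(y))]
    return value, grad

W = [[0.5, -1.0, 0.2], [0.3, 0.8, -0.4], [-0.6, 0.1, 0.9]]  # K = 3, dim(y) = 2
b = [0.1, -0.2, 0.05]
print(sum_of_squares_potential(W, b, [0.4, -0.3], [1.5]))
\end{lstlisting}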

Each model is trained on characters from the Omniglot training set for $10^7$ steps with a batch size of $8$. Train and test accuracies are reported in Figure~\ref{fig:omniglot_results}.
Again, Equilibrium Aggregation outperforms both baselines in terms of both train and test accuracy.
On the one hand, this shows that Equilibrium Aggregation has a significantly larger capacity and thus fits the training data better. 
On the other hand, this capacity results in better generalization and, presumably, a more robust aggregation strategy. 

\subsection{Global aggregation in Graph Neural Networks}\label{sec:molpcba}

\begin{figure}
    \centering
    \includegraphics[width=0.99\linewidth]{figures/aux_loss.pdf}
    \caption{Inner-loop optimization statistics on MOLPCBA with the GIN architecture. The pink curve shows the max-norm of the gradient, i.e., its largest absolute value over all dimensions, at the last (15th) iteration of the inner loop. A value of $10^{-2}$ indicates a small gradient update and therefore good convergence of the optimizer. The dark purple curve tracks the auxiliary loss, i.e., the $L^2$ norm of the gradient update averaged across all 15 optimization steps (see~\eqref{eq:aux_loss}). Overall, these curves indicate stable, convergent behavior despite a modest number of inner-loop optimization steps.
    }
    \label{fig:aux_loss}
\end{figure}


\begin{table*}[]
    \centering
    \caption{Comparison between different aggregation methods on MOLPCBA.}
    \begin{tabular}{c|c|c|c}
        \toprule
         \bf Local Aggregation & \bf Global Aggregation & \bf Validation MAP & \bf Test MAP \\
         \midrule
         \makecell{Graph Convolutional Network \\ \small \citep{kipf2016semi}} & \makecell{Sum \\ Multi-Head Attention \\ Principal Neighbourhood Aggregation \\ \bf Equilibrium Aggregation} & \makecell{0.223 \\ 0.248 \\ 0.226 \\ \bf 0.269} & \makecell{0.203 \\ 0.229 \\ 0.209 \\ \bf 0.252} \\
         \midrule
         \makecell{Graph Isomorphism Network \\ \small \citep{xu2018powerful}} &  \makecell{Sum \\ Multi-Head Attention \\ Principal Neighbourhood Aggregation \\ \bf Equilibrium Aggregation} & \makecell{0.255 \\ 0.254 \\ 0.262 \\ \bf 0.263} & \makecell{0.232 \\ 0.234 \\ 0.244 \\ \bf 0.246} \\
         \midrule
         \bf Equilibrium Aggregation & \bf Equilibrium Aggregation & \bf 0.269 & \bf 0.258 \\
         \bottomrule
    \end{tabular}
    \label{tab:gnn}
\end{table*}

Finally, we study the effect of different aggregation methods in the global \textit{readout} layer of a graph neural network (GNN) on the well-established MOLPCBA benchmark~\citep{hu2020open}.
In this task, the model is required to predict 128 global binary properties of an input molecule.
This is traditionally implemented within the GNN framework by first applying several layers of message-passing on a graph and then aggregating the resulting 300-dimensional node embeddings into a single 300-dimensional graph representation from which the predictions are made.
Since there is more than one prediction task per molecule, mean average precision (MAP) is used as an evaluation metric.
We report the test MAP of the model snapshot that attains the best validation MAP during training.
The validation and test metrics are periodically evaluated from model snapshots taken approximately every $10^4$ training steps.
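The evaluation protocol can be summarized in a short sketch with hypothetical function names. It first computes non-interpolated average precision per task, assuming no tied scores. It then averages over tasks, ignoring missing labels (\texttt{None}) and skipping tasks without both positive and negative labels; we assume this matches the benchmark's convention. Finally, it reports the test MAP of the snapshot with the best validation MAP.

\begin{lstlisting}[language=Python, style=mystyle]
def average_precision(scores, labels):
    """Non-interpolated AP: mean of precision@k over the ranks k of positives.
    Assumes no tied scores and at least one positive label."""
    ranked = sorted(zip(scores, labels), key=lambda p: -p[0])
    hits, precisions = 0, []
    for k, (_, y) in enumerate(ranked, start=1):
        if y == 1:
            hits += 1
            precisions.append(hits / k)
    return sum(precisions) / len(precisions)

def mean_average_precision(score_cols, label_cols):
    """MAP over tasks (one score/label column per task); None marks a missing
    label. Tasks without both classes among labeled molecules are skipped."""
    aps = []
    for scores, labels in zip(score_cols, label_cols):
        pairs = [(s, y) for s, y in zip(scores, labels) if y is not None]
        ys = [y for _, y in pairs]
        if 1 in ys and 0 in ys:
            aps.append(average_precision([s for s, _ in pairs], ys))
    return sum(aps) / len(aps)

def select_test_map(snapshots):
    """snapshots: (step, val_map, test_map) tuples; test MAP at best validation."""
    return max(snapshots, key=lambda s: s[1])[2]
\end{lstlisting}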

For this experiment, we choose two popular GNN architectures, namely a Graph Convolutional Network (GCN)~\citep{kipf2016semi} and a Graph Isomorphism Network (GIN)~\citep{xu2018powerful} that both use a simple Sum readout in their canonical implementations by \citet{hu2020open}.
We leave the architectures unchanged and only vary the global readout operation.
Our implementation uses the Jraph library~\citep{jraph2020github} and dynamic batch training with up to $8$ graphs and $1024$ nodes in a batch.
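The batching budget can be illustrated with a greedy sketch that groups consecutive graphs so that each batch contains at most $8$ graphs and $1024$ nodes in total. This is a simplified, hypothetical stand-in and not Jraph's implementation. In particular, library routines typically also pad each batch to a fixed size so that the compiled computation has static shapes, which this sketch omits.

\begin{lstlisting}[language=Python, style=mystyle]
def greedy_dynamic_batches(node_counts, max_graphs=8, max_nodes=1024):
    """Group graphs (given by their node counts) into consecutive batches
    that respect both the graph budget and the node budget."""
    batches, current, nodes = [], [], 0
    for idx, n in enumerate(node_counts):
        if n > max_nodes:
            raise ValueError(f"graph {idx} has {n} nodes, exceeding the budget")
        if current and (len(current) == max_graphs or nodes + n > max_nodes):
            batches.append(current)  # close the batch before it overflows
            current, nodes = [], 0
        current.append(idx)
        nodes += n
    if current:
        batches.append(current)
    return batches

counts = [30] * 20 + [1000, 500, 600, 10]
print(greedy_dynamic_batches(counts))
\end{lstlisting}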

For the potential network, we use an architecture similar to that of the previous experiment, with layer sizes $[600, 300, 32]$, a sum-of-squares output and $15$ iterations of energy minimization.

The results are provided in Table~\ref{tab:gnn}. 
Overall, the empirical findings on MOLPCBA are consistent with the previous experiments. Multi-Head Attention provides a noticeable improvement over the basic Sum aggregation with GCN and a marginal one in test MAP with GIN, while Equilibrium Aggregation performs best with both architectures.
In addition, we evaluate Principal Neighbourhood Aggregation (PNA)~\citep{Corso}, which has been proposed to address the limitations of each individual pooling method in the context of GNNs. PNA combines four aggregators with three degree-based scalers, yielding 12 scaled pooling operations.
When combining PNA with the GCN model, our experiments show only minor improvements over Sum pooling, in part because of increased overfitting.
However, when applied to the GIN architecture, PNA performs almost on par with Equilibrium Aggregation.
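The 12 PNA pooling operations can be sketched as follows for a global readout over the $n$ node embeddings of a graph. The sketch follows our reading of \citet{Corso}. The four aggregators are mean, maximum, minimum and standard deviation. Each aggregate is multiplied by one of three scalers: $1$ (identity), $\log(d+1)/\delta$ (amplification) or $\delta/\log(d+1)$ (attenuation). Here $d$ is the number of aggregated elements (i.e., $n$), and $\delta$ is the average of $\log(d+1)$ over the training set. The small $\epsilon$ inside the standard deviation and all names are assumptions.

\begin{lstlisting}[language=Python, style=mystyle]
import math

def pna_readout(xs, delta, eps=1e-5):
    """PNA-style readout of n D-dim embeddings into a 12 * D vector:
    4 aggregators (mean, max, min, std) x 3 scalers (identity,
    amplification, attenuation)."""
    n, dim = len(xs), len(xs[0])
    cols = [[x[j] for x in xs] for j in range(dim)]
    mean = [sum(c) / n for c in cols]
    std = [math.sqrt(max(sum(v * v for v in c) / n - m * m, 0.0) + eps)
           for c, m in zip(cols, mean)]
    aggs = mean + [max(c) for c in cols] + [min(c) for c in cols] + std
    log_d = math.log(n + 1)  # degree term; for a global readout, d = n
    out = []
    for scale in (1.0, log_d / delta, delta / log_d):
        out.extend(scale * a for a in aggs)
    return out

xs = [[1.0, 2.0], [3.0, -1.0], [2.0, 0.5]]
print(len(pna_readout(xs, delta=1.0)))  # 24 = 12 * D for D = 2
\end{lstlisting}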

These results confirm one of the central hypotheses of this work, namely that the global aggregation of node embeddings is a critical step in graph neural networks. 
Perhaps surprisingly, GCN generally benefits more from advanced aggregation methods than GIN, which is probably due to its smaller number of parameters and thus lower risk of overfitting.
It is also worth noting that top performing GNN architectures achieve significantly higher test MAP on this task (see, e.g.,~\citet{yuan2020large, brossard2020graph}).

In addition, we test an architecture in which both the local (i.e., node-level) and the global aggregation are performed using Equilibrium Aggregation. 
This model yields even better performance, albeit only marginally.
A more careful architecture design that takes the specifics of Equilibrium Aggregation into account could potentially lead to larger improvements. However, the molecular graphs in this task are relatively small, and aggregation at the local level may not be the most critical step for a typical GNN.

Besides task performance, we also investigate the behavior of the inner-loop optimization. 
Figure~\ref{fig:aux_loss} plots two statistics that quantify it: the max-norm of the gradient at the final iterate of the optimization, $\max_d |\nabla_{y_d} E(X, \mathbf{y}^{(T)})|$, and $L_{\text{aux}}$~\eqref{eq:aux_loss}.
Both statistics decrease rapidly during training, and a good degree of convergence is achieved.
We observe similar behavior with GCN and on the other tasks considered earlier.
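Both statistics are cheap to compute from the gradients produced by the inner loop. In the hypothetical sketch below, \texttt{grads[t]} holds $\nabla_{\mathbf{y}} E(X, \mathbf{y}^{(t)})$. Following the description in Figure~\ref{fig:aux_loss}, the sketch assumes that the auxiliary loss is the $L^2$ norm of the gradient averaged over the optimization steps; the exact definition is given by~\eqref{eq:aux_loss}.

\begin{lstlisting}[language=Python, style=mystyle]
import math

def inner_loop_diagnostics(grads):
    """grads[t]: gradient of E(X, y) w.r.t. y at iterate y^(t); the last entry
    belongs to the final iterate y^(T).

    Returns max_d |dE/dy_d| at the final iterate and an auxiliary loss taken
    here (assumption) as the L2 norm of the gradient averaged over the steps.
    """
    final_max_norm = max(abs(g) for g in grads[-1])
    l2_norms = [math.sqrt(sum(g * g for g in grad)) for grad in grads]
    aux_loss = sum(l2_norms) / len(l2_norms)
    return final_max_norm, aux_loss

grads = [[3.0, 4.0], [0.6, -0.8], [0.01, -0.02]]
print(inner_loop_diagnostics(grads))  # max-norm 0.02, aux loss about 2.007
\end{lstlisting}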

\section{Discussion and Conclusion}

This work provides a novel optimization-based perspective on the widely encountered problem of set aggregation and yields an aggregation method that is provably universal.
Our proposed algorithm, Equilibrium Aggregation, allows learning a problem-specific aggregation mechanism which, as we show, is beneficial across different applications and neural network architectures.
The consistent empirical improvement brought by Equilibrium Aggregation not only shows that many existing models suffer from aggressive compression and inefficient representation of sets, but also suggests a whole new class of set- or graph-oriented architectures that employ a composition of Equilibrium Aggregation operations.
Beyond GNNs, other classes of models, such as Transformers, may also profit from more expressive aggregation operations, specifically in modeling long-term memory, a topic strongly connected to the compression of sets~\citep{rae2019compressive, bartunov2019meta}. Such operations could also potentially reduce the number of layers needed.

While there is a strong indication that using Equilibrium Aggregation as a building block is effective, the incurred computational cost may require further advances in differentiable optimization~\citep{ernoult2020equilibrium}, architecture~\citep{amos2017input} and hardware design~\citep{kendall2020training}, especially to compete with modern, extremely large models.


\begin{acknowledgements} % will be removed in pdf for initial submission,
                         % so you can already fill it to test with the
                         % ‘accepted’ class option
    We thank Peter Battaglia, Petar Veličković, Marcus Hutter, Yulia Rubanova and Marta Garnelo for their help with preparing the paper, insightful discussions and overall support during the course of the work. 
\end{acknowledgements}

\bibliography{bartunov_444}

\end{document}

