\PassOptionsToPackage{hidelinks}{hyperref}
%\documentclass{uai2022} % for initial submission
\documentclass[accepted]{uai2022} % after acceptance, for a revised
                                    % version; also before submission to
                                    % see how the non-anonymous paper
                                    % would look like
%% There is a class option to choose the math font
% \documentclass[mathfont=ptmx]{uai2022} % ptmx math instead of Computer
                                         % Modern (has noticable issues)
% \documentclass[mathfont=newtx]{uai2022} % newtx fonts (improves upon
                                          % ptmx; less tested, no support)
% NOTE: Only keep *one* line above as appropriate, as it will be replaced
%       automatically for papers to be published. Do not make any other
%       change above this note for an accepted version.

%% Choose your variant of English; be consistent
\usepackage[american]{babel}
%\usepackage[british]{babel}

%% Some suggested packages, as needed:
\usepackage{natbib} % has a nice set of citation styles and commands
    \bibliographystyle{plainnat}
    \renewcommand{\bibsection}{\subsubsection*{References}}
\usepackage{mathtools} % amsmath with fixes and additions
% \usepackage{siunitx} % for proper typesetting of numbers and units
\usepackage{booktabs} % commands to create good-looking tables
\usepackage{tikz} % nice language for creating drawings and diagrams

% my packages
\usepackage{color}
\usepackage{pstricks}
\usepackage{url}            % simple URL typesetting
\usepackage{amsfonts}       % blackboard math symbols
\usepackage{dsfont}
\usepackage{nicefrac}       % compact symbols for 1/2, etc.
\usepackage{lipsum}
\usepackage{amsmath} % math
\usepackage{amsthm}
\usepackage{color} % colored text
\usepackage{pstricks} % colored text
\usepackage{graphicx}
\usepackage{float}
\usepackage{caption}
\usepackage{subcaption}
\usepackage{tikz}
\usetikzlibrary{bayesnet}
\usetikzlibrary{arrows}
%\usepackage{todonotes}
\usepackage{arydshln}

%% Provided macros
% \smaller: Because the class footnote size is essentially LaTeX's \small,
%           redefining \footnotesize, we provide the original \footnotesize
%           using this macro.
%           (Use only sparingly, e.g., in drawings, as it is quite small.)

%% Self-defined macros
\newcommand{\Iobs}{I_{\text{obs}}}
\newcommand{\E}[1] {\mathbb{E}_{q(I, \theta|\Iobs, y)} \hspace{-0.1cm} \left[
   #1 \right]}
\newcommand{\q}{q(I, \theta|I_{\text{obs}}, y)}
\newcommand{\T}{T_{\theta}}
\newcommand{\R}{\mathds{R}}

\title{Probabilistic Spatial Transformer Networks}

% The standard author block has changed for UAI 2022 to provide
% more space for long author lists and allow for complex affiliations
%
% All author information is authomatically removed by the class for the
% anonymous submission version of your paper, so you can already add your
% information below.
%
% Add authors
\author[1]{\href{mailto:<posc@dtu.dk>?Subject=Your UAI 2022 paper}{Pola Schw\"obel}{}}
\author[1]{Frederik Warburg}
\author[2]{Martin J\o rgensen}
\author[1, 3]{Kristoffer H. Madsen}
\author[1]{S\o ren Hauberg}
% Add affiliations after the authors
\affil[1]{
Section for Cognitive Systems\\
DTU Compute\\
Technical University of Denmark, Copenhagen, Denmark
} 

\affil[2]{%
    Machine Learning Research Group\\
    Department of Engineering Science\\
    University of Oxford, Oxford, UK
}
\affil[3]{%
   Danish Research Centre for Magnetic Resonance\\
   Centre for Functional and Diagnostic Imaging and Research\\
   Copenhagen University Hospital Hvidovre, Hvidovre, Denmark
  }
  
\begin{document}
\maketitle

\begin{abstract}
Spatial Transformer Networks (STNs) estimate image transformations that can improve downstream tasks by `zooming in' on relevant regions in an image. However, STNs are hard to train and sensitive to mis-predictions of transformations. To circumvent these limitations, we propose a probabilistic extension that estimates a stochastic transformation rather than a deterministic one. Marginalizing transformations allows us to consider each image at multiple poses, which makes the localization task easier and the training more robust. As an additional benefit, the stochastic transformations act as a localized, learned data augmentation that improves the downstream tasks. We show across standard imaging benchmarks and on a challenging real-world dataset that these two properties lead to improved classification performance, robustness and model calibration. We further demonstrate that the approach generalizes to non-visual domains by improving model performance on time-series data.  
\end{abstract}

\section{Introduction} 
\label{sec:intro}

The \emph{Spatial Transformer Network (STN)} \citep{jaderberg2015spatial} predicts a \emph{transformation} on input data in order to simplify a downstream task. For example, a neural network might benefit from `zooming in' on relevant parts of an image, removing unwarranted image rotations, or time-normalizing sequence data before making predictions. In principle, this can improve robustness, interpretability and efficiency of the model. However, in practice, the situation is not as ideal. Both at training and test time, the STN is sensitive to small mis-predictions of transformations. For example, if the STN zooms in on the wrong part of an image, then the signal is lost for the downstream task; see e.g.\@ crops A and C in Fig.~\ref{fig:teaser}. The empirical impact is that STNs are difficult to train and often do not live up to their promise.

\begin{figure}[!tp]
    \centering
    \includegraphics[width=\linewidth]{images/teaser_pstn.pdf}
    \caption{The Probabilistic Spatial Transformer Network (P-STN) marginalizes over a distribution of possible input transformations. By `looking in multiple places', we hope to stabilize the brittle nature of the regular spatial transformer: the P-STN loss landscape is significantly smoother and has fewer local minima than that of the STN.}
    \label{fig:teaser}
\end{figure}

From a probabilistic perspective, this sensitivity has an obvious solution: we should estimate the posterior over the applied transformation and marginalize accordingly. This amounts to `trying many different transformations', and should improve robustness. It is exactly this approach we investigate. 

STNs consist of two parts. A localization network performs the transformation task, i.e.\@ it estimates the transformation parameters $\theta$ for a given image $I$ and applies the corresponding transformation $T_\theta(I)$. A standard neural network performs the downstream task on the transformed image, i.e.\@ computing  $p(y | T_\theta(I))$. Since we are concerned with classification tasks, we will refer to the latter as the classifier, but note that the approach generalizes to other tasks.

In our probabilistic STN (P-STN), we estimate a distribution over transformations that we marginalize: $p(y | I) = \int p(y | T_\theta(I)) \, p(\theta | I) \, \mathrm{d}\theta$. We approximate this intractable integral via Monte Carlo, i.e.\@ we sample transformations. These transformation samples produce different transformed versions of the input image, $\{T_\theta^s(I)\}_{s=1,\dots,S}$. The classifier makes predictions on all samples, and we aggregate the predictions. Figure~\ref{fig:architecture} shows the model architecture.

We hypothesize that marginalizing image transformation has benefits for both parts of the model. For the \textit{localization} network, our model gets to `try many different transformations' through random sampling. This should improve the localization.
%
Secondly, the classifier now gets presented with different transformed versions of the input image through Monte Carlo samples $\{T_\theta^s(I)\}_{s=1,\dots,S}$. Interestingly, this corresponds to a type of data augmentation, which should improve classification.\looseness=-1

We verify these hypotheses by making the following contributions:
\begin{enumerate}
    \item We develop the Probabilistic Spatial Transformer, a hierarchical Bayesian model over image transformations.\looseness=-1
    \item We perform variational inference to fit the transformation model as well as downstream model end-to-end, using only label information. 
    \item We experimentally demonstrate that our model achieves better localization, increased classification accuracy (resulting from learned per-image data augmentation) and improved calibration. 
\end{enumerate}


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
\begin{figure*}
  \centering
  \includegraphics[width=\textwidth]{images/pstn_pipeline_new.pdf}
  \caption{The P-STN pipeline. From the observed image $I$, a distribution of transformations is estimated. Samples from this distribution are applied to the observed image to produce augmented samples, which are fed to a classifier that averages across samples. In the deterministic STN case, the localizer only computes one transformation $\theta(I)$, which can be thought of as the maximum likelihood solution. Instead of the multiple transformation samples, we obtain a single $T_\theta(I)$ in this case.}
  \label{fig:architecture}
\end{figure*}

\section{Related Work}
\label{sec:related_work}

\textbf{Spatial transformer networks} apply a spatial transformation to the input data as part of an end-to-end trained model \citep{jaderberg2015spatial}. The transformation parameters are estimated from each input separately through a neural network. Most commonly, STNs implement simple affine transformations, such that the network can learn to zoom in on relevant parts of an image before solving the task at hand. STNs have shown themselves to be useful for both generative and discriminative tasks, and have seen applications to different data modalities \citep{jaderberg2015spatial, detlefsen2019explicit, detlefsen18diffstn, NIPS2019_8884, Sonderby2015, LinL16a,kanazawa2016warpnet}.
We propose a probabilistic extension of this idea, replacing the usual likelihood maximization with marginalization over transformations.\looseness=-1

\textbf{Bayesian deep learning} aims to solve probabilistic computations in deep neural networks. Priors are put on weights and marginalized at training and test time, often yielding useful uncertainties in the posterior predictive. The required computations are in general intractable, and approaches differ mainly in the type of approximation to the weight posterior. 
\citet{gal2016dropout} propose to view dropout as a Bernoulli 
approximation to the weight posterior (i.e.\@ randomly switching each weight on or off). The Laplace approximation \citep{mackay1992bayesian, daxberger2021laplace} places a Gaussian posterior over a trained neural network's weights. Another generally successful way to obtain predictive uncertainties is to simply train an ensemble of models. Originally proposed as an alternative to Bayesian DL \citep{lakshminarayanan2017deepensembleUQ}, the approach can be interpreted in the  Bayesian framework by interpreting the weights of the trained ensemble members as samples from a weight posterior \citep{gustafsson2020evaluating}. Similar to our method, \citet{blundell2015weight} choose a variational approach with a simple Gaussian mean field posterior over weights. Our approach differs from standard Bayesian DL in that we are not reasoning about distributions over neural network weights $p(w)$, but instead a subnetwork's (i.e.\@ the localizer's) \textit{outputs} $p(\theta)$. 
Drawing from the posterior over image transformations, we effectively recover data augmentation.


\textbf{Data augmentation (DA)} is a useful way to increase the amount of available data \citep{lecun1995comparison, kr2012imagenet}. DA requires prior knowledge about the structure of the data: the target $y$ is assumed to be invariant to certain transformations of the observation $I$. Invariance assumptions are usually straightforward for natural images. Thus, DA is common for image data, where the transformation family is often chosen to be rotations, scalings, and similar \citep{goodfellow2009measuring, baird1992document, simard2003best, kr2012imagenet, loosli:lskm:2007}. The general trend is that, beyond `intuitive' data such as images, gathering an invariance prior is difficult, and DA is often hard to realize through manual tuning.

\textbf{Learned data augmentation} provides a more principled approach to artificially extending datasets. \citet{hauberg2016dreaming} estimate an augmentation scheme from the training data via pre-aligning images in an unsupervised manner. The approach allows for significantly more complex transformations than the usual affine family, but the unsupervised nature and the implied two-step training process render the approach suboptimal. Similarly, \citet{cubuk2019autoaugment,cubuk2020randaugment} use reinforcement learning and grid search to learn data augmentation schemes, but rely on validation data rather than an end-to-end formulation. 

Learning data augmentation end-to-end requires a loss function suitable for model selection, as we are effectively trying to learn an inductive bias. Based on this realization, \citet{vanderwilk2018learning} learn DA end-to-end in Gaussian processes (GPs) via the marginal likelihood, a suitable loss for model selection and thus invariance learning \citep{mackay2002information}. The marginal likelihood is hard to compute for NNs, so \citet{schwoebel2021layer} extend this idea to NNs by considering a deep kernel model, i.e.\@ a neural network with a GP in the last layer. \citet{benton2020learning} instead use the standard, maximum likelihood loss and explicitly regularize towards non-zero augmentations. 
Our model differs from existing data augmentation approaches --- learned and non-learned --- in that we estimate local, i.e.\@ \emph{per-image} transformations instead of a global augmentation scheme. 


\section{Background}\label{sec:background}

The STN localizer module estimates a transformation $\theta(I)$ that transforms a coordinate grid and interpolates an image accordingly. The classifier module takes the transformed image and computes $p(y|T_\theta(I))$. Both the localizer and classifier are neural networks. The STN can be trained end-to-end with only label information as long as the image transformations are parameterized in a differentiable manner.\looseness=-1

\textbf{Affine transformations} are a simple class of transformations that can be differentiably parameterized. We limit ourselves to the subset of affine transformations containing rotation, isotropic scaling and translation in $x$ and $y$. In two dimensions (and the corresponding three-dimensional homogeneous coordinates), we thus learn $\theta= (r, s, t_x, t_y) $
which parameterizes the affine matrix
\begin{equation}\label{eq:affine}
A_{\theta} = 
\begin{bmatrix}
    s \cdot \cos r & -  s \cdot \sin r & t_x \\
    s  \cdot \sin r & s \cdot \cos r & t_y \\
    0  & 0 & 1
\end{bmatrix} \in
\mathds{R}^{3 \times 3}, \  s> 0.
\end{equation}

Since $\det(A_{\theta})=s^2$, the constraint $s>0$ ensures invertibility and can be implemented as seen in \citet{detlefsen18diffstn}. In practice, the STN estimates well-behaved, non-collapsing transformations without implementing the constraint explicitly. $\T(I)$ is applied by transforming a grid of the target image size by $A_\theta$ and interpolating the source image at the resulting coordinates (see \citet{jaderberg2015spatial} for details).

\textbf{Diffeomorphic transformations} (i.e.\@ transformations that are differentiable, invertible and possess a differentiable inverse) are more general than affine transformations, and are not limited to the spatial domain. 
\citet{freifeld2017transformations} construct diffeomorphisms from continuous piecewise-affine velocity fields as follows. The transformation domain $\Omega$ 
is divided into subsets and an affine matrix is defined on each cell $c$ of such a tessellation. Each affine matrix $A_{\theta_c}$ induces a vector field mapping each point $x \in c$ to a new position $v^{\theta_c}: x \mapsto A_{\theta_c} x$. These velocity fields are then integrated to form a trajectory for each image point $x$:
\[
\phi^{\theta}(x;1) = x + \int_0^1 v^\theta\bigl(\phi^{\theta}(x; \tau)\bigr)\,\mathrm{d}\tau.
\]
Given boundary and invertibility constraints \citep{freifeld2017transformations}, such a collection of affine matrices $\{ A_{\theta_c}\}_{c \subset \Omega}$ defines a diffeomorphic transformation $T^{\theta} : x \mapsto \phi^\theta(x;1)$.

The libcpab library~\citep{detlefsen2018libcpab} provides an efficient implementation for this approach, specifically optimized for use in a deep learning context where fast gradient evaluations are crucial. The author successfully employs CPAB-transformations within a Spatial Transformer Network \citep{detlefsen18diffstn}.


\section{Probabilistic Spatial Transformer Network}

The P-STN is a probabilistic extension of the STN, where we replace the deterministic transformation $\theta(I)$ with a posterior over transformations $p(\theta | I)$. Figure~\ref{fig:architecture} illustrates the proposed pipeline. We assume observed data of the form $\mathcal{D}=\{y_i,I_i\}_{i=1}^N$, where $y$ is the target variable (e.g.\@ class label), and $I$ are observations of the covariates. For presentation purposes, we will consider the latter to be images, but the approach applies to any spatio-temporal data.
  
  \subsection{The Model}
    Recall that STNs are trained end-to-end for the downstream task using only label information. Thus, while we observe $y$, $\theta$ is a latent variable. We model it to be governed by a second latent variable $\lambda$, a precision parameter that effectively stops the localization distribution (i.e.\@ the amount of `data augmentation' we introduce) from collapsing. The necessity for non-collapsing augmentation is discussed in \citet{benton2020learning, vanderwilk2018learning} and \citet{schwoebel2021layer}.
    
    We wish to infer the latent variables in a Bayesian manner. This entails computing the (log-)marginal likelihood of the observed data
    \begin{align} \label{ref:objective}
      \log p(I, y) &= \log \iint p(I, y, \theta, \lambda)  \mathrm{d}\theta \mathrm{d}\lambda .
    \end{align} 
    We let the joint distribution factorize as (see Fig.~\ref{fig:graphical_model})
    \begin{align} \label{eq:factorization}
      p(y, I, \theta, \lambda) &= p(y|I, \theta, \lambda) p(I, \theta, \lambda)\\
    &= p(y|I, \theta) p(\theta| \lambda, I)p(\lambda)p(I).
    \end{align}
    Notice $p(I)$ is unaffected by model parameters $\lambda$ and $\theta$, and in this sense can be specified without affecting the model. The distribution over $\theta$ depends on observed covariates in the following way
    \begin{align}
        p(\theta|\lambda, I)=\mathcal{N}(\theta| \mu(I), 1/\lambda),
    \end{align}
    where $\mu(I)$ is a function parametrised by a neural network, i.e.\@ $\mu(I): = \mu_\Phi(I)$ for model parameters $\Phi$. The prior over $\lambda$ is a Gamma distribution, i.e.
    \begin{align}
        p(\lambda_i)=\Gamma(\alpha_0, \beta_0).
    \end{align}
    We note here that there is one $\lambda_i$ associated to each observation, and they are assumed to factorize: $p(\lambda)~=~\prod_{i=1}^N~p(\lambda_i)$. This choice of conjugate priors for variance estimation is similar to \citep{stirn2020variational, takahashi2018student, detlefsen2019reliable}. Finally, we assume that, conditional on $I$ and $\theta$, we have marginal independence in $y$, i.e.\@ $p(y|I,\theta) = \prod_{i=1}^N p(y_i|I_i, \theta_i)$.
    
\begin{figure}
\centering
  \tikz{
% nodes
 \node[obs] (y) {$y$};%
 \node[latent, right=of y] (theta) {$\theta$};%
 \node[latent,above=of theta] (lambda) {$\lambda$}; %
 \node[obs, above=of y] (I) {$I$}; %
% plate
 \plate [inner sep=.25cm,yshift=0cm] {plate1} {(lambda)(I)(theta)(y)} {$N$}; %
% edges
 \edge {I} {y, theta}
  \edge {lambda} {theta}
 \edge {theta} {y}}
 \caption{A graphical representation of the model structure. Grey nodes are observables and white are latents.} \label{fig:graphical_model}
\end{figure}


\subsection{Variational Approximation} \label{sec:variational_approx}
The integral in Eq.~\eqref{ref:objective} for the marginal likelihood is intractable and, thus, so is the posterior $p(\lambda, \theta| I, y)$. We derive a lower bound on the log marginal likelihood to utilize variational inference \citep{Blei2017VariationalIA}. We choose the variational approximation $q$ of the posterior $p(\theta, \lambda | I, y)$ as 
\begin{equation}
q(\theta, \lambda) := p(\theta|\lambda, I)q(\lambda).    
\end{equation}
Here $p(\theta|\lambda, I)$ is given as before and $q(\lambda):=\prod_{i=1}^N\Gamma\left(\alpha_i, \beta(I_i)\right)$. In our approximation, $\beta$ is a neural network: hence, we use amortized inference in a similar way to the VAE model \citep{kingm2013VAE}.

We derive our lower bound using Jensen's inequality
\begin{align}
    &\log p(y, I) = \log \iint p(y, I, \theta, \lambda) \mathrm{d}\theta \mathrm{d}\lambda  \\
    &\geq \iint \log   \left(\frac{p(y, I, \theta, \lambda)}{q(\theta, \lambda)}\right) q(\theta, \lambda) \mathrm{d}\theta \mathrm{d}\lambda\\ 
    &= \!\iint \!\log\!   \left(\frac{p(y|I, \theta) p(\lambda)p(I)}{q(\lambda)}\right) p(\theta|\lambda, I)q(\lambda) \mathrm{d}\theta \mathrm{d}\lambda\nonumber\\
    &= \underbrace{\mathbb{E}_{q(\theta, \lambda)} \log p(y|I, \theta)}_{\textbf{classification loss}}  + \log p(I)  - \text{KL}\!\left(q(\lambda) \| p(\lambda)\right).\label{ELBO}
\end{align}


Thus, our evidence lower bound (ELBO) objective function~\eqref{ELBO} consists of two terms:
a classification loss and a KL-term controlling the distance of the approximate posterior to the prior. During inference, we can disregard  $\log p(I)$ as it does not depend on parameters of interest.

\subsection{Inference} \label{sec:inference}
The choice of variational posterior implies the following for the \textbf{classification loss}
\begin{align}
    &\mathbb{E}_{q(\theta, \lambda)} \log p(y|I, \theta)
    \\&= \iint \log p(y|I, \theta) q(\theta, \lambda) \mathrm{d}\theta \mathrm{d}\lambda \\
    &=  \iint \log p(y|I, \theta) p(\theta|\lambda, I)q(\lambda) \mathrm{d}\theta \mathrm{d}\lambda \\
    & = \int   \log p(y|I, \theta) \int \mathcal{N}(\theta|\mu(I), 1/\lambda) \Gamma(\lambda | \alpha, \beta(I))  \mathrm{d}\lambda \mathrm{d}\theta\nonumber
    \\
    &= \int \log p(y|I, \theta) t_{2\alpha}(\theta | \mu (I), \tfrac{\beta(I)}{\alpha}) \mathrm{d}\theta. \label{eq:variational_expectation}
\end{align}
Here $t$ denotes a scaled and location-shifted Student's $t$-distribution with mean $\mu (I)$, scale parameter $\tfrac{\beta(I)}{\alpha}$, and $2\alpha$ degrees of freedom. That is, the marginalized $q(\theta)$ is $t$-distributed. Here $p(y|I,\theta)$ is what previously was referred to as $p(y|T_\theta(I))$, i.e.\@ the classifier conditioned on the transformed $I$. 

We approximate Eq.~\ref{eq:variational_expectation} using an unbiased estimate 
\begin{align}
    & \mathbb{E}_{q(\theta, \lambda)} \log p(y_i|I_i, \theta_i) \approx \frac{1}{S} \sum_{s=1}^S \log p(y_i |I_i, \theta_{i,s}), \\
    & \text{with } \theta_{i,s} \sim  t_{2\alpha_i}(\cdot | \mu (I_i), \tfrac{\beta(I_i)}{\alpha_i})
\end{align}
and backpropagate through neural networks $\mu(I)$ and $\beta(I)$ with the reparametrization trick. In all experiments $\alpha_i\!=\!1$.

Combining terms, the final ELBO we maximize becomes
\begin{align}
\begin{split}
    \mathcal{L}_{p, q}(I, y) & \approx  \sum_{i=1}^N \frac{1}{S}  \sum_{s=1}^S \log p(y_i | I_i, \theta_{i,s}) \\
        & \qquad - \text{KL}\left(q(\lambda)||p(\lambda)\right) + \text{const},
\end{split}
\end{align}
which is readily optimized using any gradient-based method. The KL-term is analytically tractable and differentiable between two gamma distributions.

In practice, following \citet{higgins2016beta} we introduce a weight parameter $w$ to the KL-term. This requires us to tune $w$ but in turn makes the model robust to the choice of prior. We perform a grid-search on a validation set to find the optimal $w$. Alternatively, we could have done a grid search over $\beta_0$; instead we chose $\alpha_0 = \beta_0=1$ for all experiments. Similar to \cite{kingm2013VAE}, we often find it sufficient to draw only $S=1$ samples during training. Note that our model naturally implies marginalization, and correspondingly data augmentation, at \textit{test-time} as well as the usual training time. At test time, we draw $S=10$ transformation samples.


\section{Experiments \& Results} \label{sec:exp}

Our model consists of two parts, the classifier $p(y|T_\theta(I))$ and the probabilistic localizer estimating the distribution over transformations. In the following experiments, we aim to disentangle our model's benefits for localization (Sec.~\ref{sec:exp_localiser}), classification (Sec.~\ref{sec:exp_classifier}) and calibration (Sec.~\ref{sec:exp_calibration}). 

The probabilistic localizer estimates $q(\theta) = t_2(\theta | \mu(I), \beta(I))$, i.e.\@ in practice we implement a mean and a variance network, $\mu(I)$ and $\beta(I)$, respectively (see Fig.~\ref{fig:architecture} for the architecture).
We employ a small convolutional network (\texttt{Conv2d, Maxpool2d, ReLU, Conv2d, Maxpool2d, ReLU}) followed by two fully connected layers for both the localizer and classifier unless stated otherwise. The P-STN localizer has two heads; one for the mean and one for the variance. The number of parameters is stated in each experimental subsection. Unless stated otherwise, we keep the number of parameters constant, i.e.\@ when adding a localization network we remove the extra parameters from the classifier for fair comparison. 

Our model is implemented in PyTorch and experiments are run on 12 GB Nvidia Titan X GPUs. The code is available at \url{https://github.com/FrederikWarburg/pSTN-baselines}.


\subsection{Marginalizing transformations improves localization accuracy} \label{sec:exp_localiser}
The appeal of STN models is that they are trained end-to-end, i.e.\@ based only on labels for the downstream task, and not the transformations. This same property, however, is what makes the STN hard to fit. The only signal we obtain is through the supervised downstream task (i.e.\@ the classification labels) and thus gradient information is sparse. We will now investigate whether estimating a posterior over transformations and marginalizing, i.e.\@ `getting to try multiple transformations', simplifies the task as suggested by Fig.~\ref{fig:teaser}. 

In order to disentangle the localization from the classification task, we construct the following experiments. We first train a CNN on a pose-normalized dataset (regular MNIST and Fashion MNIST). We then generate a new dataset by randomly sampling transformations $\theta_\text{true}$ and applying them to the MNIST images. Saving these transformations provides us with ground truth. We freeze the CNN weights and train STN and P-STN with this fixed classifier, effectively learning to recover and `undo' the true transformations. 


\subsubsection{Rotated MNIST}
From this data-generating process, we obtain a rotated version of the MNIST dataset (i.e.\@ regular MNIST with ground-truth transformations given by rotation angles, $\theta_\text{true}(I) = r_\text{true}(I)$). See Fig.~\ref{fig:exp_rotMNIST}, top right panel for example data. 

% RotMNIST
\begin{figure}% [h!]
    \centering
    \includegraphics[width=0.95\columnwidth]{images/calibration.pdf}
    \caption{Rotated MNIST experiment. \textit{Left panel:} Ground-truth transformation (rotation angles in radians) against recovered transformations (mean). \textit{Top right:} Example images from the data set and samples from the P-STN localizer. The localizer learns to pose-normalize. \textit{Bottom right:} Outputs of the variance network. When the transformation recovery is poor (the error $\varepsilon$ is above the median, in orange) the variances are slightly higher than when the localization works well (blue). }
    \label{fig:exp_rotMNIST}
\end{figure}


% Mapillary
\begin{figure*}
\begin{minipage}{\textwidth}
  \begin{minipage}[b]{0.65\textwidth}
    \centering
     \begin{subfigure}[b]{0.19\textwidth}
         \centering
         \includegraphics[width=\textwidth]{images/bouding_boxes/bounding_box_idx_0.png}
     \end{subfigure}
     \begin{subfigure}[b]{0.19\textwidth}
         \centering
         \includegraphics[width=\textwidth]{images/bouding_boxes/bounding_box_idx_5.png}
     \end{subfigure}
     \begin{subfigure}[b]{0.19\textwidth}
         \centering
         \includegraphics[width=\textwidth]{images/bouding_boxes/bounding_box_idx_11.png}
     \end{subfigure}
     \begin{subfigure}[b]{0.19\textwidth}
         \centering
         \includegraphics[width=\textwidth]{images/bouding_boxes/bounding_box_idx_8.png}
     \end{subfigure}
     \begin{subfigure}[b]{0.19\textwidth}
         \centering
         \includegraphics[width=\textwidth]{images/bouding_boxes/bounding_box_idx_15.png}
     \end{subfigure}
    \captionof{figure}{The P-STN learns to localize traffic signs in the challenging MTSD dataset. At test time, we sample $10$ transformations as shown with the various bounding boxes overlaid on the images. These learned variations improve the final classification.}
     \label{fig:mtsd}
  \end{minipage}
  \hspace{0.5cm}
  \begin{minipage}[b]{0.20\textwidth}
    \centering
    \begin{tabular}{l|ll}
    \toprule
    & Acc. $\uparrow$ & NLL $\downarrow$ \\ \midrule
     CNN & $76.0$  & $0.49$ \\
     STN & $90.6$  & $0.31$ \\
     P-STN & $\mathbf{92.2}$  & $\mathbf{0.29}$ \\
     \bottomrule
    \end{tabular}
      \captionof{table}{Accuracy (Acc.) and negative log-likelihood (NLL) for CNN, STN and P-STN.}
      \label{tab:mtsd}
    \end{minipage}
  \end{minipage}
\end{figure*}


Our CNN classifier ($28$k weights) obtains $99.4\%$ test accuracy on MNIST and $41.2\%$ on rotated MNIST (frozen weights, no re-training). The STN and P-STN ($S\!=\!10$ training samples, $w\!=\!3\cdot 10^{-5}$, same CNN classifier as before $+ 72$k params in the localizer) both learn to pose-normalize, i.e.\@ to recover these transformations to a satisfactory degree. When training the localizers only (classifier weights remain frozen as described above), the STN test acc.\@ is $76.13 \%$, and $82.98\%$ for the P-STN.  We compute the expected average transformation error on the $N=10$k rotated MNIST test images as
%
\begin{align}
\varepsilon = \frac{1}{N} \sum_{i=1}^N \| \theta_\text{true}(I_i) - \mu(I_i) \| \mod \pi.
\end{align}
We get $\varepsilon = 0.76$ for the STN and $\varepsilon = 0.59$ for the P-STN. The P-STN  outperforms the STN, i.e.\@ modeling uncertainty in the transformations helps in the localization task.


\paragraph{Uncertainty.}
The bottom right panel of Fig.~\ref{fig:exp_rotMNIST} shows a histogram of $\beta(I)$, i.e.\@ the localizer variance (or, correspondingly, the magnitude of augmentation) per image. In orange, we plot variances for images where pose-normalization is difficult (the transformation error $\varepsilon$ is larger than the median). In blue, we plot variances for images that are correctly pose-normalized (transformation error $\varepsilon$ smaller than the median). The poorly localized images are, on average, assigned $17\%$ larger variances $\beta(I)$. The localizer uncertainty and thus the amount of data augmentation applied is somewhat meaningful, corresponding to the difficulty of the task. 

\begin{table*}[!htp]
\centering
\begin{tabular}{l|lllll}
\toprule
                      & MNIST30          & MNIST100         & MNIST1000        & MNIST3000 & MNIST10000 \\ \midrule
\textit{CNN }        & $70.12 \pm 2.46$ & $87.29 \pm 0.58$ & $95.80 \pm  0.33$ &  $\textbf{97.48} \pm 0.21$ & $\textbf{97.82} \pm 0.34$  \\ \hline

\textit{affine STN}  & $69.26 \pm 4.53$ & $82.16 \pm 2.30$ & $92.05 \pm 0.58$ & $94.71 \pm 0.22$ & $96.96 \pm 0.20$  \\ \hline

\textit{affine P-STN}  & $\textbf{81.00}  \pm 3.92 $ & $\textbf{92.70} \pm 0.74$  &  $\textbf{96.62} \pm 0.58$  & $\textbf{97.33} \pm 0.17$ & $\textbf{97.63} \pm 0.23$  \\ \hline \hline

\textit{optimal} $w$ & $0.001$  & $0.0003$ & $0.0001$ & $0.00003$  & $0.00001$  \\ \hline
\end{tabular}%}
\caption{The performance of a CNN, STN and P-STN on differently sized MNIST datasets. Bold numbers indicate that a model is significantly better than the runner-up under a two-sample $t$-test at $p=0.05$.}
\label{tab:affine_vs_diffeo}
\end{table*}

\subsubsection{Random placement FashionMNIST}

We repeat a similar experiment on the slightly more challenging FashionMNIST dataset \citep{xiao2017fashion}. The CNN baseline accuracy is $90.63\%$ (same model as above with $28$k parameters). We then randomly sample an $x$ and $y$ coordinate and place the FashionMNIST image accordingly on a black background, after downscaling it by $50\%$. No rotation is applied, i.e.\@ $\theta_\text{true} = [0, 0.5,  t^x_\text{true}, t^y_\text{true}]$.

% Random placement Fashion MNIST 
\begin{figure}[h!]
\centering
  \includegraphics[width=0.8\columnwidth]{images/UAI_FashionMNIST.pdf}
 \caption{Random Placement Fashion MNIST. Input images (left) and transformed samples $T_{\theta_s}(I)$ as learned by the P-STN. The P-STN learns to correctly pose-normalize and zoom into the relevant part of the image. The samples look like plausible candidates for a data augmentation scheme. We will explore this in Sec.~\ref{sec:exp_classifier}.}
\end{figure}

As in the previous experiment, both localizers successfully recover $\theta_\text{true}$, with the P-STN ($S=10$ training samples, $w=3\cdot 10^{-5}$, same classifier as before $+193$k weights in the localizer) doing slightly better than its deterministic counterpart: test accuracies are $84.99\%$ and $84.41\%$, respectively. Inspecting the transformation posterior and the resulting samples $T_{\theta_s}(I_i)$, we find that these look visually pleasing, and, as hypothesized, might be promising candidates for a data augmentation scheme. We will explore this in Sec.~\ref{sec:exp_classifier}. %the following section.  

\subsubsection{Mapillary street signs}\label{sec:exp_street_signs}

Detection and classification of objects in images have many applications, e.g.\@ for autonomous vehicles, detecting traffic signs is crucial. We compare a top-performing classifier, an STN and our P-STN on the challenging Mapillary Traffic Sign Dataset (MTSD)~\citep{mtsd}. 

To focus this comparison, we select images that contain only one traffic sign. We obtain this subset by selecting all bounding boxes that do not intersect with other bounding boxes plus a margin of $150$ px to each side. We further select the ten most common classes from this subset. This gives us a training set of $4698$ images and a test set of $500$ images. Figure~\ref{fig:mtsd} shows example images from the chosen subset.

Our classifier is a ResNet18 pre-trained on ImageNet, where we replace the last fully connected layer. We use the same ResNet for the localizers in the STN and P-STN, where we similarly replace the last layer. As before, we wish to study the behavior of the localizers. Therefore, we again start by training a classifier on the ground-truth bounding boxes. We then initialize the classifier module of the STN and P-STN with this pre-trained classifier and freeze the weights of the classifier. We train the localizers of the STN and P-STN for $60$ epochs with learning rate $10^{-4}$ and KL weight $w=10^{-7}$. Figure~\ref{fig:mtsd} shows that the P-STN learns to localize the traffic signs. At test time, we sample $10$ transformations, illustrated by the multiple overlaid bounding boxes.


Table~\ref{tab:mtsd} shows that both the STN and P-STN clearly outperform the baseline classifier when trained on the full images. Even though the STN and P-STN have exactly the same classifier, the P-STN achieves better performance because of the ensemble of classified transformations.

\subsection{Marginalizing transformations improves classification accuracy} \label{sec:exp_classifier}
We have argued that marginalizing transformations via samples corresponds to learned, localized data augmentations (the samples $T_{\theta_s}(I)$). We will now investigate whether these augmentations are indeed helpful in the downstream task, i.e.\@ whether they improve classification performance. 



\subsubsection{MNIST and subsets} \label{sec:exp_mnist_subsets}


We compare the performance of our P-STN against a standard convolutional neural network (CNN) and a regular STN on MNIST. The standard MNIST images are centered and pose-normalized, so the localization task is easy. Improved classifier performance can thus be viewed as an indicator for having learned a useful data augmentation scheme.

\begin{figure}%[!htp]
\centering
  \includegraphics[width=1\columnwidth]{images/MNIST_subsets.pdf}
 \caption{Performances of P-STN, STN and CNN on MNIST subsets (mean $\pm$ one STD across five folds).}
 \label{fig:MNIST_subsets}
\end{figure}

Data augmentation is particularly important when training data is scarce, so we evaluate the models on small subsets of MNIST: MNIST30 contains $30$ images (i.e.\@ $3$ per class), and MNIST100, MNIST1000, MNIST3000 and MNIST10000 are constructed analogously. STN and P-STN parameterize affine transformations, i.e.\@ the learned $\theta$ is interpreted as the full affine matrix as described in Sec.~\ref{sec:background}. All models have roughly $28$k parameters, with the architecture described at the top of Sec.~\ref{sec:exp}. We use the Adam optimizer with weight decay $0.01$ and the default parameters of its PyTorch implementation. The images are color-normalized. We repeat the experiment $5$ times, each time with a different $k$-image subset of the MNIST dataset, and we report $\pm$ one standard deviation in tables and error bars. From Table~\ref{tab:affine_vs_diffeo} and Fig.~\ref{fig:MNIST_subsets}, we see that the P-STN outperforms both the STN and CNN on the small dataset sizes. For the larger datasets, the differences vanish. This supports our hypothesis: data augmentation is especially useful when data is a limited resource. This intuition is also supported by the optimal KL-weights (Table~\ref{tab:affine_vs_diffeo}, bottom row) that we determine via grid search on validation data. For smaller datasets, larger $w$ and thus more regularization towards the variance prior (away from $0$) are beneficial. 



The fact that the STN performs less well than the standard CNN on this data set might be explained by the fact that the images are already nearly perfectly pose-normalized, and wrong transformations can be detrimental. 

\subsubsection{UCR time-series dataset}\label{sec:exp_timeseries}

\begin{figure*}
\begin{minipage}{\textwidth}
    \begin{minipage}[t]{0.4\textwidth}
     \begin{subfigure}{\textwidth}
         \centering
    \includegraphics[width=0.9\columnwidth, height=3.2cm]{images/timeseries.pdf}%height=3.4cm
    \end{subfigure}    
        \captionof{figure}{Examples of augmentations for a time-series from the FaceAll dataset. The top plot shows the original time-series and the bottom plot shows three augmented versions of the time-series.}\label{fig:ex_timeseries}
    \end{minipage}
  \hspace{0.5cm}
  %\hfill
    \begin{minipage}[t]{0.55\textwidth}
        \centering
      \resizebox{\columnwidth}{!}{
        \begin{tabular}{l| l  l  l}
        \toprule
                & CNN   & STN    & P-STN \\ \midrule
         FaceAll & $80.83 \pm 0.62$              & $82.28 \pm 0.42$ & $\textbf{84.31} \pm 0.75$ \\
         TwoPatterns & $97.92 \pm 0.53$          & $99.79 \pm 0.04$ & $\textbf{99.96} \pm 0.04$ \\ 
         wafer &  $\textbf{99.63} \pm  0.05$              & $99.18 \pm 0.17$ & $98.86 \pm 0.20$ \\ 
         % StarLightCurvesX& $0.9331 \pm 0.0372$      & $0.9621 \pm 0.0238$ - & - \\ 
         uWaveGestureLib.* & $74.15 \pm 1.27$    & $79.77 \pm 0.42$ & $\textbf{81.13} \pm 0.46$ \\
         PhalangesOutlC.** & $79.88 \pm 1.32$    & $\textbf{82.26} \pm 0.98$ & $\textbf{81.66} \pm 0.59$ \\ 
        \midrule
        Mean & $86.48$ & $88.65$ & $\textbf{89.18}$ \\ 
        \bottomrule
        \end{tabular}}
         \captionof{table}{Accuracies on a subset of the UCR time-series dataset (full dataset names are *uWaveGestureLibrary and **PhalangesOutlinesCorrect). $\pm1$ STD is reported after $5$ repetitions. Bold numbers indicate that a model is significantly better than the runner-up under a two-sample $t$-test at $p=0.05$.}
        \label{tab:timeseries}
        \end{minipage}
  \end{minipage}
\end{figure*}


For some data modalities, such as time-series, it is not trivial to craft a useful data augmentation scheme. In this experiment, we show that the P-STN can learn a useful, non-trivial data augmentation scheme that increases performance compared to a standard STN on time-series data. The UCR dataset \citep{UCRArchive2018} is composed of $108$ smaller datasets, where each dataset contains univariate time-series. The FordA dataset, for example, contains measurements of engine noise over time and the goal is to classify whether or not the car is faulty. We select $5$ of those datasets, each large enough to divide into training and validation sets ($75/25\%$), which we use to find the optimal $w$ via grid-search; those are $[10^{-4}, 10^{-5}, 10^{-3}, 0, 10^{-4}]$. We draw $S=10$ training samples. The test-set is pre-defined by the dataset curators. Learning rate and optimizer are the same as in Sec.~\ref{sec:exp_mnist_subsets}, but we do not perform normalization. All models have approximately one million parameters. Table~\ref{tab:timeseries} shows that the P-STN achieves higher mean accuracy than both the STN and the CNN, indicating that we can automatically learn a useful data augmentation scheme for time-series. 

We verify this qualitatively in Fig.~\ref{fig:ex_timeseries}, which shows an example of the learned data augmentation. We see that the model does not simply apply a global transformation, but learns to augment the time-series more in some intervals, such as in $[60; 110]$, and augment the time-series less in other intervals, such as in $[0; 50]$. 

\begin{figure}
    \centering
    \includegraphics[width=1\columnwidth]{images/MNIST100_calibration.pdf}
    \caption{Calibration plots for CNN, STN and two P-STN models: one with the KL-weight yielding optimal performance ($w=0.0003$) and one with the KL-weight yielding optimal calibration ($w=0.0001$). Both P-STN models are better calibrated than CNN and STN.}
    \label{fig:classifier_calibration}
\end{figure}


\subsection{Marginalizing transformations improves calibration} \label{sec:exp_calibration}
In Sec.~\ref{sec:exp_localiser}, we have seen that harder images on average have larger transformation uncertainties. We now investigate whether those meaningful localization uncertainties translate into meaningful uncertainties downstream, i.e.\@ in the calibration of our classifier. 
At test-time, we evaluate 
\begin{align}
\hspace{-0.2cm} 
 p(y|I) &= \int p(y|I, \theta)q(\theta) \mathrm{d}\theta  
 \approx \frac{1}{S} \sum_{s=1}^S p(y | T_{\theta_{s}} (I)).
\end{align}
We will investigate how well the uncertainty in this distribution matches the quality of predictions. Fig.~\ref{fig:classifier_calibration} shows a calibration plot for the MNIST100 subset classification task from Sec.~\ref{sec:exp_mnist_subsets} for the CNN, STN and P-STN for two different $w$-parameters; $w=0.0003$ yields the best performance (reported in Table~\ref{tab:affine_vs_diffeo}) and $w=0.0001$ yields the best calibration. The expected calibration errors \citep{guo2017calibration, Kueppers_2020_CVPR_Workshops, Kueppers_2021_IV} are 
CNN:  $0.0743 \pm 0.0094$, STN:  $0.1160 \pm 0.0205$,  P-STN, $w=0.0003$ (optimal performance model): $0.0567 \pm 0.0065$,  P-STN, $w=0.0001$ (optimal calibration model): $\textbf{0.0271} \pm 0.0088$. We report the mean over 5 folds, $\pm$ one STD. The P-STN significantly improves calibration in the downstream classification task. 
\begin{figure*}
\centering
\begin{subfigure}{\textwidth}
 \includegraphics[width=\textwidth]{images/rebuttal_sizeplot.pdf}
\end{subfigure}
 \caption{\textit{Left:} Test accuracies for standard NN and (P-)STNs of different depths trained on rotated MNIST, as well as NN baseline on original MNIST (black). The STN (green) model does not usually recover the original images and thus behaves more like a standard NN (red) in most runs. P-STN (blue) un-transforms at least some of the rotations and is closer in accuracy to the NN on original MNIST (black). \textit{Right:} The variance of the learned transformations as a function of model depth. The STN learns the identity for deeper downstream models (this is consistent with the test accuracies we see on the left). P-STN learns to un-transform better, at least when the classifier is simple. For bigger classifiers it predicts the identity transform as well, but performs relatively well nonetheless (see left panel). We report medians $\pm 1$ median absolute deviation over $5$ folds.} \label{fig:pstn_failure_mode_exp}
\end{figure*}

\subsection{A Typical Failure Mode in STNs} 
STNs are trained end-to-end, and with only label information available. Thus, the aim is to learn the optimal transformation for solving the downstream task. Depending on the complexity of the downstream task and the classification model, it might not be necessary to transform the input at all, i.e.\@ it might be possible to solve the downstream task on the original input image. Indeed, this is a failure mode we observe in practice --- often, the localizer simply learns the identity transform while the classifier learns to classify the non-transformed image.
Using more complex classifier architectures makes the STN more prone to this failure mode. This has been observed by other authors \citep{finnveden2021understanding}, and we investigate the problem in the experiment in Fig.~\ref{fig:pstn_failure_mode_exp}. We start by training differently-sized neural networks on MNIST (black, one layer on the $x$-axis is \texttt{[Linear, ReLU, Dropout]}). We compare the performance of this model with (P-)STN models trained on rotated MNIST; test accuracies are plotted in the left panel of the figure. If the localization task is performed perfectly, the (P-)STN models should be able to recover the accuracy on the original, non-rotated dataset. 
In the right panel, we plot the variance of the (mean) transformations learned by the (P-)STN models. Values close to $0$ indicate that the localizer does not transform the image, i.e.\@ it learns the identity transform. Larger values indicate that the localizer learns transformations. Median results are reported over $5$ runs; error bars correspond to one median absolute deviation.  As hypothesized, for larger classifiers the localizers do not transform the images. Due to the increased capacity of the model, we nonetheless achieve decent classification accuracies (left panel). The P-STN learns to localize the rotated images somewhat successfully (large variance in the right panel, and high accuracy on the left) for smaller classifiers. The STN does not localize the images as well; most runs behave like the standard NN on rotMNIST (red), predicting identity transformations only. We conclude that, thanks to it `trying out multiple transformations', the P-STN avoids this failure mode to an extent. We also note that this property, while useful, is somewhat orthogonal to our interest in this work, and we have avoided the failure mode in the experiments of Sec.~\ref{sec:exp_localiser} by considering models with \textit{fixed}, pre-trained classifiers. 



\section{Conclusion}


We have introduced a probabilistic extension to the spatial transformer network (STN) \citep{jaderberg2015spatial}. Our work took motivation from the empirical observation that the STN is often brittle to train, as a poorly predicted transformation may prevent the model from getting any gradient signal, resulting in divergent optimization. Our probabilistic STN (P-STN) instead approximates the posterior distribution of transformations using amortized variational inference, and marginalizes accordingly. As is common, marginalization improves the robustness of the model.

Empirically, we note the following advantages of the probabilistic formulation over the deterministic. Firstly, the performance of the localization network is improved, since the Monte Carlo marginalization effectively amounts to trying many different transformations. Secondly, the probabilistic formulation improves the overall model performance, since the sampled transformations act as data augmentation both during training and during testing. The resulting ensemble of predictions is more accurate and better calibrated than common classifiers as well as the original spatial transformer. \looseness=-1


% \section{Back Matter}
% There are a some final, special sections that come at the back of the paper, in the following order:
% \begin{itemize}
%   \item Author Contributions
%   \item Acknowledgements
%   \item References
% \end{itemize}
% They all use an unnumbered \verb|\subsubsection|.

% For the first two special environments are provided.
% (These sections are automatically removed for the anonymous submission version of your paper.)
% The third is the ‘References’ section.
% (See below.)

% (This ‘Back Matter’ section itself should not be included in your paper.)

% \begin{contributions} % will be removed in pdf for initial submission,
%                       % so you can already fill it to test with the
%                       % ‘accepted’ class option
%     Briefly list author contributions.
%     This is a nice way of making clear who did what and to give proper credit.

%     Pola Schwöbel co-developed the model, wrote the paper and carried out most of the experiments. 
%     Frederik Warburg carried out the experiment 5.1.3 and wrote large parts of the code. 
%     Martin Jørgensen co-developed the model and some of the figures. 
% Kristoffer H. Madsen co-developed the idea. 
%     Søren Hauberg conceived the idea and advised during execution. 
% \end{contributions}

\begin{acknowledgements} % will be removed in pdf for initial submission,
MJ was supported by a research grant from the Carlsberg Foundation (CF20-0370). SH was supported by research grants (15334, 42062) from VILLUM FONDEN. This project has also received funding from the European
Research Council (ERC) under the European Union’s Horizon 2020 research and innovation programme (grant agreement No. 757360). This work was funded in part by the Novo Nordisk Foundation through the Center for Basic Machine Learning Research in Life Science (NNF20OC0062606).

\end{acknowledgements}

\bibliography{paper}

\appendix


\end{document}