\documentclass{midl} % Include author names
%\documentclass[anon]{midl} % Anonymized submission

% The following packages will be automatically loaded:
% jmlr, amsmath, amssymb, natbib, graphicx, url, algorithm2e
% ifoddpage, relsize and probably more
% make sure they are installed with your latex distribution

\usepackage{mwe} % to get dummy images
\usepackage{soul}
\usepackage{float}
\usepackage{booktabs}
\usepackage{graphicx}

%\jmlrvolume{-- Under Review}
\jmlryear{2020}
\jmlrworkshop{Full Paper -- MIDL 2020}
%\editors{Under Review for MIDL 2020}

\newcommand{\avd}[1]{\color{red}[AVD: {#1}]\normalcolor}

\title[Segmentation Auto-Encoder]{An Auto-Encoder Strategy for Adaptive Image Segmentation}

 % Use \Name{Author Name} to specify the name.
 % If the surname contains spaces, enclose the surname
 % in braces, e.g. \Name{John {Smith Jones}} similarly
 % if the name has a "von" part, e.g \Name{Jane {de Winter}}.
 % If the first letter in the forenames is a diacritic
 % enclose the diacritic in braces, e.g. \Name{{\'E}louise Smith}

 % Two authors with the same address
 % \midlauthor{\Name{Author Name1} \Email{abc@sample.edu}\and
 %  \Name{Author Name2} \Email{xyz@sample.edu}\\
 %  \addr Address}

 % Three or more authors with the same address:
 % \midlauthor{\Name{Author Name1} \Email{an1@sample.edu}\\
 %  \Name{Author Name2} \Email{an2@sample.edu}\\
 %  \Name{Author Name3} \Email{an3@sample.edu}\\
 %  \addr Address}


% Authors with different addresses:
% \midlauthor{\Name{Author Name1} \Email{abc@sample.edu}\\
% \addr Address 1
% \AND
% \Name{Author Name2} \Email{xyz@sample.edu}\\
% \addr Address 2
% }

%\footnotetext[1]{Contributed equally}

% More complicate cases, e.g. with dual affiliations and joint authorship
\midlauthor{\Name{Evan M. Yu\nametag{$^{1}$}} \Email{emy24@cornell.edu}
\AND
\Name{Juan Eugenio Iglesias \nametag{$^{2,3,4}$}} \Email{e.iglesias@ucl.ac.uk}
\AND
\Name{Adrian V. Dalca \nametag{$^{2,3}$}} \Email{adalca@mit.edu}
\AND
\Name{Mert R. Sabuncu \nametag{$^{1,5}$}} \Email{msabuncu@cornell.edu}\\
\addr $^{1}$ Nancy E. and Peter C. Meinig School of Biomedical Engineering, Cornell University \\
\addr $^{2}$ Martinos Center for Biomedical Imaging, Massachusetts General Hospital, Harvard Medical School \\ 
\addr $^{3}$ Computer Science and Artificial Intelligence Laboratory (CSAIL), MIT \\
\addr $^{4}$ Centre for Medical Image Computing, University College London \\
\addr $^{5}$ School of Electrical and Computer Engineering, Cornell University
}
\begin{document}

\maketitle

\begin{abstract}
Deep neural networks are powerful tools for biomedical image segmentation. 
These models are often trained with heavy supervision, relying on pairs of images and corresponding voxel-level labels. 
However, obtaining segmentations of anatomical regions on a large number of cases can be prohibitively expensive.
%Furthermore, models trained with heavy supervision are often sensitive to shifts in image characteristics, for instance, due to a routine upgrade in scanner software.
Thus there is a strong need for deep learning-based segmentation tools that do not require heavy supervision and can continuously adapt.  
In this paper, we propose a novel perspective of segmentation as a discrete representation learning problem, and present a variational autoencoder segmentation strategy that is flexible and adaptive. 
Our method, called Segmentation Auto-Encoder (SAE), leverages all available unlabeled scans and merely requires a segmentation prior, which can be \textit{a single unpaired} segmentation image. 
In experiments, we apply SAE to brain MRI scans.
Our results show that SAE can produce good quality segmentations, particularly when the prior is good. We demonstrate that a Markov Random Field prior can yield significantly better results than a spatially independent prior. %\avd{comment on code availability?}
Our code is freely available at~\url{https://github.com/evanmy/sae}.
\end{abstract}

\begin{keywords}
Image Segmentation, Variational Auto-encoder
\end{keywords}

\section{Introduction}

Quantitative biomedical image analysis often builds on a segmentation of the anatomy into regions of interest (ROIs). 
Recently, deep learning techniques have been increasingly used in a range of segmentation applications~\cite{seg_rev,litjens2017survey,unet,kamnitsas2017efficient}.
These methods often rely on a large number of \textit{paired} scans and segmentations (voxel-level labels) to train a neural network. 
Training labels are either generated by human experts, which can be costly and/or hard to scale, or by automatic software~\cite{dolz20183d}, which bounds performance by the accuracy of that software.
% An important drawback of the supervised approach is 
% Currently, the gold standard to obtain annotations is the manual delineation of the ROIs. 
% Unfortunately, this is a time consuming and costly process that requires expertise and thus does not scale easily.
% An alternative approach is to use weak supervision, for example, leveraging image segmentations obtained with a third-party automatic segmentation tool, as in~\cite{dolz20183d}.
% The main drawback of this approach is that the final performance will be constrained by the tool used to generate the training labels.
%Coupled with the the growing number of imaging and segmentation protocols, manual annotation of every new dataset becomes prohibitively expensive. 
Furthermore, supervised techniques typically yield tools that are sensitive to changes in image characteristics, for instance, due to a modification of the imaging protocol~\cite{amod}.
This is a significant obstacle for the widespread clinical adoption of these technologies. 
 
One approach to improve robustness and performance is to relax the dependency on paired training data and simply use unpaired examples of segmentations, sometimes called ``atlases.'' 
Building on unpaired atlases, a segmentation model can then be trained continuously on new sets of unlabeled images~\cite{dalca1,dalca2,joyce}. For example, Dalca \textit{et al.}~\cite{dalca1} recently proposed an approach where an autoencoder is pre-trained on \textit{thousands} of unpaired atlases.
For a new set of unlabeled images, the encoder is then re-trained via an unsupervised strategy. 
% Another related method was recently presented in~\cite{dalca2}, where a spatial transformer network registers a probabilistic atlas prior to an input image, which is then fed to a network that solves a conventional inverse inference problem.
% These prior methods, however, did not consider the scenario of very few unpaired atlases, which is a primary focus of this paper.
Another widely-used approach to improve generalizability is data augmentation on labeled training data~\cite{zhao,chaitanya2019semi}. 
For example, \citet{zhao} demonstrated an adaptive approach that learns an augmentation model on a dataset of unlabeled images. This model was then applied to augment a single paired atlas to perform one-shot segmentation within a supervised learning framework. Another popular approach is to use registration to propagate atlas labels to a test image~\cite{sabuncu2010generative,lee2019few}. 
%While classical image registration methods require significant runtime, recent fast deep-learning based registration methods have been demonstrated in the context of brain segmentation~\cite{lee2019few}. 

In this paper, we present a novel perspective for \textit{minimally supervised} image segmentation.
Instead of viewing segmentation through the lens of supervised learning or inverse inference, we regard it as a discrete representation learning problem, which we solve with a variational autoencoder (VAE)-like strategy~\cite{vae}. We call our framework Segmentation Auto-encoder, or SAE.
As we demonstrate below, SAE is flexible and can leverage \textit{all} available data, including unpaired atlases and unlabeled images. We show that we can train a good segmentation model using SAE with as little as a \textit{single unpaired} atlas.
In conventional representation learning, e.g., VAE~\cite{vae},
an encoder maps an input to a continuous latent representation, which often lacks interpretability. 
In contrast, in SAE, the encoder computes a discrete representation that is a segmentation image, which is guided by an atlas prior.
%We encourage segmentations to be close to a provided set of atlases via a prior on the representations.
%We present results with the classic Markov Random Field segmentation.
%Furthermore, we formulate a tractable loss function via a variational lower bound objective.
 Finally, we employ the Gumbel-softmax relaxation~\cite{gumbel} to train the SAE network. The Gumbel-softmax approximates the non-differentiable argmax (thresholding) operation with a softmax in order to make the function differentiable. It provides us with a simple and efficient way to perform the reparameterization trick for a categorical distribution, allowing the network to be trained via back-propagation.
 %similar to a recent paper that extracted motifs in video data~\cite{lemonade}.
In our experiments, we demonstrate that SAE produces high quality segmentation maps, even with a single unpaired atlas. We also quantify the boost in performance as we exploit richer prior models.
For example, a Markov Random Field model yields significantly better results than a spatially independent prior.
% We exploit 
% Working with discrete random variable is a difficult task since the backpropagation algorithm requires gradient computation. REINFORCE \cite{estimator1, estimator2} can be used as a gradient estimator, but it suffers from high variance \cite{estimator3}. Some work has focused on applying reparameterization trick to marginal distribution constructed from continuous functions which are parameterized by binary variables \cite{dvae1,dvae2,dvae3}. Alternatively, Jang \textit{et at.} \cite{gumbel} and Maddison \textit{et al.} \cite{concrete} proposed the Gumbel-Softmax to approximate categorical samples. The latter method has been used to train a VAE to extract motifs in video \cite{lemonade}. Using the Gumbel trick, we demonstrate that it is possible to obtain brain segmentation from one or few labeled example through a VAE with discrete representation. 
\section{Method}
We consider a dataset of $N$ observed images (e.g., MRI scans) $\{\boldsymbol{x}^{(i)}\}_{i=1}^{N}$, which we model as independent samples from the same distribution. 
%In the rest of the paper, we will restrict our attention to 3D brain MRIs, each with $V$ voxels. Yet, the proposed framework is applicable to any type of image.
Let $\boldsymbol{s}$  denote the (latent) segmentation, where
each voxel is assigned a unique discrete anatomical label. 
%In this work, we view the segmentation image as a discrete representation.
Marginalizing over the latent segmentation, the log-likelihood of an observed image is:
%
\begin{equation}
    \log p(\boldsymbol{x}^{(i)}) = \log \sum_{\boldsymbol{s}} p( \boldsymbol{x}^{(i)}|\boldsymbol{s}) p(\boldsymbol{s}),
    \label{eq:logl1}
\end{equation}
where $p(\boldsymbol{s})$ denotes a prior distribution on the segmentation, $p( \boldsymbol{x}^{(i)}|\boldsymbol{s})$ is the conditional probability of the observed image given the latent segmentation, often called the image likelihood, and the sum is over all possible values of $\boldsymbol{s}$. 
%For simplicity, we have restricted our attention to the log-likelihood of a single sample $\boldsymbol{x}^{(i)}$ (i.e., the $i$'the image in our dataset), leveraging our independence assumption. 
We assume the prior $p(\boldsymbol{s})$ is provided and ``learning'' involves finding the parameters that describe the image likelihood $p( \boldsymbol{x}^{(i)}|\boldsymbol{s})$.
%Our approach is Bayesian and we thus marginalize over latent $\boldsymbol{s}$. 
Since Eq.~\eqref{eq:logl1} is computationally intractable for most practical scenarios, we follow the classical variational strategy and maximize the evidence lower bound objective (ELBO): 
\begin{equation}
\log p(\boldsymbol{x}^{(i)}) \geq  -\textup{KL}(q(\boldsymbol{s}| \boldsymbol{x}^{(i)}) || p (\boldsymbol{s})) +  \displaystyle \mathop{\mathbb{E}}_{\boldsymbol{s} \sim  q(\boldsymbol{s}| \boldsymbol{x}^{(i)})} \log p(\boldsymbol{x}^{(i)} | \boldsymbol{s}), \label{eq:ELBO}
\end{equation}
where $\textup{KL}(\cdot || \cdot)$ denotes the KL-divergence and $q(\boldsymbol{s}| \boldsymbol{x}^{(i)})$ is an efficient-to-manipulate distribution that approximates the true posterior $p(\boldsymbol{s}| \boldsymbol{x}^{(i)})$. 
%Once training is complete, we can use $q(\boldsymbol{s}| \boldsymbol{x}^*)$ to estimate the segmentation of a query image $\boldsymbol{x}^*$.

Following the VAE~\cite{vae} framework, we use two neural networks to compute the approximate posterior $q(\cdot|\cdot)$ and the image likelihood $p(\cdot|\cdot)$. A so-called encoder network computes the approximate posterior $q_{\phi}(\boldsymbol{s}| \boldsymbol{x})$, where $\phi$ denotes the parameters of the encoder. 
The image likelihood $p_{\theta}(\boldsymbol{x}| \boldsymbol{s})$ is computed by a decoder network, parameterized by $\theta$.
In our formulation, the encoder can be viewed as a segmentation network.
%, in contrast to its common use of also compressing the input data. 
The decoder corresponds to a generative or ``reconstruction'' model that describes the process of creating an observed image from an underlying segmentation.

A natural choice for the approximate posterior is a voxel-wise independent model:
\begin{equation}
\begin{aligned}
 q_{\phi}(\boldsymbol{s}| \boldsymbol{x}^{(i)}) = \prod_{j=1}^{V}  \textup{Cat}(s_{j} | \boldsymbol{x}^{(i)}, \phi),
\end{aligned} \label{eq:enc}
\end{equation}
where $\textup{Cat}(s_{j} | \boldsymbol{x}^{(i)}, \phi)$ is a categorical distribution computed as the softmax output of the encoder network at the $j^\text{th}$ voxel evaluated for label $s_{j}$.
Assuming an additive Gaussian noise likelihood model:
\begin{equation}
\begin{aligned}
p_{\theta}(\boldsymbol{x}| \boldsymbol{s}) = 
\prod_{j=1}^{V} 
\mathcal{N}(x_j; \hat{x}_j(\boldsymbol{s}; \theta), \sigma^2),
\end{aligned}
\label{eq:dec}
\end{equation}
where $\hat{\boldsymbol{x}}(\boldsymbol{s}; \theta)$ is a ``reconstruction'' image computed by the decoder network, sub-script $j$ is the voxel index, and $\mathcal{N}(\cdot; \mu, \sigma^2)$ denotes a Gaussian with mean $\mu$ and variance $\sigma^2$. 

Putting together Eqs.~\eqref{eq:ELBO} and~\eqref{eq:dec} and relying on Monte Carlo sampling to approximate the expectation, we obtain the following loss function to be minimized over $\theta$ and $\phi$:
%
\begin{equation}
     \mathcal{L} =  \sum_{i=1}^N \left[ \textup{KL}(q_{\phi}(\boldsymbol{s}| \boldsymbol{x}^{(i)}) || p (\boldsymbol{s})) + \frac{V}{2}\log \sigma^{2} 
      + \frac{1}{2 \sigma^{2} K}\sum_{k=1}^{K} ||\boldsymbol{x}^{(i)} - \boldsymbol{\hat{x}}(\boldsymbol{s}_{ik}; \theta)||^{2}_{2} \right], 
     \label{eq:loss}
\end{equation}
%
where $\boldsymbol{s}_{ik}$ is an independent sample segmentation image drawn from $q_{\phi}(\boldsymbol{s}| \boldsymbol{x}^{(i)})$. 
Following the convention in the field, in practice we set $K=1$, which yields an unbiased yet noisy estimate of the loss and its gradient.
% In addition, we have one or few $\{\boldsymbol{s}_{prior}^{(i)}\}_{i}^{M}$ set of segmentation with no corresponding MRI scan that serves as our prior. The MRI volume $\boldsymbol{x}$ is assume to be generated from the segmentation map $\boldsymbol{s}$. Our goal is to maximize the log-likelihood for all our observation $N$: \\ 
% \begin{equation}
%     \log p(\boldsymbol{x}^{(i)}) = \int p( \boldsymbol{x}^{(i)}|\boldsymbol{s}) p(\boldsymbol{s}) d\boldsymbol{s}
% \end{equation}
% We need to marginalize over $\boldsymbol{s}$ in order to compute this term, which is intractable. Instead, we can maximize the evidence lower bound (ELBO) shown in eq.(2). In the VAE \cite{vae} framework, we use an encoder to output approximate posterior $q_{\phi}(\boldsymbol{s}| \boldsymbol{x})$ that is parameterized by $\phi$. At the same time, we have a decoder $\log p_{\theta}(\boldsymbol{x}| \boldsymbol{s})$ with generative parameter $\theta$.
% \begin{equation}
% \begin{aligned}
% \log p(\boldsymbol{x}^{(i)}) &\geq  -\textup{KL}(q_{\phi}(\boldsymbol{s}| \boldsymbol{x}^{(i)}) || p (\boldsymbol{s})) + \displaystyle \mathop{\mathbb{E}}_{\boldsymbol{s} \sim  q_{\phi}(\boldsymbol{s}| \boldsymbol{x}^{(i)})} \left [ \log p_{\theta}(\boldsymbol{x}^{(i)} | \boldsymbol{s}) \right] 
% \end{aligned}
% \end{equation}
% We want the latent representation $\boldsymbol{s}$ to represent the segmentation of the brain. Therefore, a natural choice is to have a categorical distribution for each latent variable eq.(3). 
Eq.~\eqref{eq:loss} does not explicitly require paired images and segmentations $\{\boldsymbol{x}^{(i)}, \boldsymbol{s}^{(i)}\}$. 
Instead, it merely needs a prior $p(\boldsymbol{s})$. 
There are many ways to define a prior, but in our experiments we use a classical construction: a probabilistic atlas that describes the probability of labels at each location, which can be coupled with a Markov random field component that encourages certain topological arrangements. 
%We describe this prior in detail below.

\begin{figure}
    \centering
    \includegraphics[width=0.77\columnwidth]{figures/architecture2.png}
    \caption{\textbf{Proposed architecture}. The encoder (blue) is a U-Net and the decoder (green) is a simple CNN. Conv: $3\times3\times3$ convolution; Relu: rectified linear unit; Maxpool: $2\times$ downsampling; Up: $2\times$ upsampling; ST Gumbel: straight-through Gumbel-softmax; Sigm: sigmoid. The number of channels is displayed below each layer.}%\avd{It seems somehow incomplete to have q() but not p() in the figure?} }
    \label{fig:architecture}
\end{figure}
\subsection{Spatial Prior}
The first prior we consider is a probabilistic atlas that assigns an independent label probability vector at each voxel, $p_j$. We call this a spatial prior:
\begin{equation}
p_{spatial}(\boldsymbol{s}) = \prod_{j=1}^V p_j(s_{j}). \label{eq:spatial}
\end{equation}
There are many ways to construct this type of prior.
For example, we can aggregate segmentations of different subjects and compute the  frequency  of  anatomical labels at each voxel. 
If instead we only have a single segmentation image, we can apply a spatial blur to this segmentation in order to account for inter-subject variation. 
With the spatial prior, the first term in Eq.~\eqref{eq:loss} reduces to:
%
\begin{equation}
\textup{KL}(q_{\phi}(\boldsymbol{s}| \boldsymbol{x}^{(i)}) || p_{spatial} (\boldsymbol{s})) = \sum_{j=1}^{V} \left[ H\big(\textup{Cat}(s_{j} | \boldsymbol{x}^{(i)}, \phi), p_j(s_{j})\big) - H\big(\textup{Cat}(s_{j} | \boldsymbol{x}^{(i)}, \phi)\big) \right],
\label{kl1}
\end{equation}
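where, at each voxel $j$, the first term denotes the cross-entropy between the approximate posterior and the prior, and the second term is the entropy of the approximate posterior.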


\subsection{Markov Random Field Prior}
The spatial prior can be modified using a Markov Random Field (MRF) to capture neighborhood relationships in a segmentation image. Following~\cite{zhang2001segmentation, fischl2002whole}, we define the MRF prior as:
%
\begin{equation}
    %p_{MRF}(\mathbf{s}) = \frac{1}{Z} \exp \left [ \sum_{j=0}^{V} \log p(s_{j}) + \sum_{j=0}^{V} \sum_{y \in \mathcal{N}_{j}} \log p(s_{y} | s_{j}) \right ]
    p_{MRF}(\boldsymbol{s}) = \frac{1}{Z} \exp \left [ \sum_{j=1}^{V} V_j(s_{j}) + \sum_{j=1}^{V} \sum_{k \in N_{j}} V(s_{k}, s_{j}) \right ],
\label{p_mrf}
\end{equation}
%
where $N_{j}$ is the $3 \times 3 \times 3$-neighborhood around voxel $j$, $V_j(\cdot)$ is the unary potential at voxel $j$, $V(\cdot, \cdot)$ is the pairwise clique potential, and $Z$ is a normalization constant. Similar to~\cite{fischl2002whole}, we define these potential functions based on a provided probabilistic atlas. 
Specifically, $V_j$ is the voxelwise log frequency of each label: $\log p_j$; and $V(\cdot, \cdot)$ is the log normalized counts of label co-occurrences in neighboring voxels. E.g., $V(l_1, l_2)$ is computed as the logarithm of the count of neighboring voxel pairs with labels $l_1$ and $l_2$ divided by the count of voxels with label $l_2$. If the pairwise potential is set to zero, the MRF prior reduces to the spatial prior.
With the MRF prior, the first term in Eq.~\eqref{eq:loss} becomes:
%
\begin{equation}
\textup{KL}(q_{\phi}(\boldsymbol{s}| \boldsymbol{x}^{(i)}) || p_{MRF} (\boldsymbol{s})) = \textup{KL}(q_{\phi}(\boldsymbol{s}| \boldsymbol{x}^{(i)}) || p_{spatial}(\boldsymbol{s})) + \mathcal{L}_{MRF} + \textrm{const.},
%\begin{split}
%   D_{KL}(q_{\phi}(\mathbf{s}|\mathbf{x})||p_{MRF}(\mathbf{s})) = \mathop{\mathbb{E}}_{\mathbf{s} \sim q_{\phi}(\mathbf{s}|\mathbf{x}^{(i)})} \log q_{\phi}(\mathbf{s}|\mathbf{x}^{(i)})  - \mathop{\mathbb{E}}_{\mathbf{s} \sim q_{\phi}(\mathbf{s}|\mathbf{x}^{(i)})} \sum_{j=0}^{V} \log p_j(s_{j}) \\ - \mathop{\mathbb{E}}_{\mathbf{s} \sim q_{\phi}(\mathbf{s}|\mathbf{x}^{(i)})} \sum_{j=0}^{V} \sum_{y \in \mathcal{N}_{j}} V(s_{y} | s_{j}) + \textrm{const}.
%\end{split}
\label{kl_mrf}
\end{equation}
%
where the first term is from Eq.~\eqref{kl1}, and the second term can be expressed as:
%
\begin{equation}
\mathcal{L}_{MRF} = - \sum_{j=1}^{V} \left ( \sum_{l_{j}=0}^{L-1} q_{j}(l_{j}|\boldsymbol{x}^{(i)}) \sum_{k \in N_{j}} \sum_{l_{k}=0}^{L-1} q_{k}(l_{k}|\boldsymbol{x}^{(i)}) \, V(s_{k}= l_{k}, s_{j} = l_{j}) \right).
\label{eq:MRF}
\end{equation}
% We are interested in the last term in Eq.\ref{kl_mrf}. We defined the spatial consistency loss as:
% \begin{equation}
% \begin{split}
%     \mathcal{L}_{mrf} &= - \sum_{j=0}^{V} \sum_{y \in \mathcal{N}_{j}} \mathop{\mathbb{E}}_{\mathbf{s} \sim q_{\phi}(\mathbf{s}|\mathbf{x}^{(i)})} \log p(s_{y} | s_{j}) + 
% \end{split}
% \end{equation}
% %&=  - \sum_{j=0}^{V} \left ( \sum_{l_{j}=0}^{L-1}  \sum_{l_{y}=0}^{L-1} q_{j}(l_{j}|\mathbf{x}^{(i)})  \sum_{y \in \mathcal{N}_{j}} q_{y}(l_{y}|\mathbf{x}^{(i)})  \log p(s_{y}= l_{y}| s_{j} = l_{j})  \right)
% We can consider the neighbor voxels $y$ around $j$ at a time. Each of them has approximate posterior probability of $q_{y}$ and $q_{j}$, respectively.
% \begin{equation}
%     \begin{split}
%         \mathcal{L}_{mrf} &= - \sum_{j=0}^{V} \sum_{y \in \mathcal{N}_{j}} \mathop{\mathbb{E}}_{\mathbf{s} \sim q_{j}q_{y}} \log p(s_{y} | s_{j}) \\&=  - \sum_{j=0}^{V} \sum_{y \in \mathcal{N}_{j}} \left ( \sum_{l_{j}=0}^{L-1}  \sum_{l_{y}=0}^{L-1} q_{j}(l_{j}|\mathbf{x}^{(i)})  q_{y}(l_{y}|\mathbf{x}^{(i)})  \log p(s_{y}= l_{y}| s_{j} = l_{j})  \right) \\ 
%     &=  - \sum_{j=0}^{V} \left ( \sum_{l_{j}=0}^{L-1}   q_{j}(l_{j}|\mathbf{x}^{(i)}) \sum_{l_{y}=0}^{L-1} \sum_{y \in \mathcal{N}_{j}} q_{y}(l_{y}|\mathbf{x}^{(i)})  \log p(s_{y}= l_{y}| s_{j} = l_{j})  \right)
%     \end{split}
% \end{equation}
% The final loss function for the model with MRF has the same form as in Eq.\ref{eq:loss} but with an additional term $\beta \mathcal{L}_{mrf}$. The constant $\beta$ weights the importance between the neighboring consistency, the reconstruction term and the closeness to the prior $p(\mathbf{s})$.
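Here, $L$ is the number of labels and $q_j(\cdot|\boldsymbol{x}^{(i)})$ is the encoder's categorical distribution at voxel $j$ (Eq.~\eqref{eq:enc}).
The MRF loss term quantifies the dissimilarity between the label topology of the prior and that of the approximate posterior $q(\cdot|\cdot)$.  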



\subsection{Implementation Details}
Our SAE architecture is shown in Fig.~\ref{fig:architecture}. The encoder is a 3D U-Net~\cite{unet} and the decoder is a simple fully convolutional network. 
Training involves optimizing Eq.~\eqref{eq:loss} with back-propagation.
To implement the sampling layer, we employed the straight-through Gumbel-softmax relaxation scheme~\cite{gumbel, maddison2016concrete}, with the temperature $\tau$ set to the recommended value of $2/3$.
We estimated $\sigma^{2}$ using the global mean squared error (MSE) between the reconstructed scan $\boldsymbol{\hat{x}}$ and the input scan $\boldsymbol{x}^{(i)}$. To initialize $\sigma^{2}$, we set the weight on the reconstruction loss to zero for the first 16 subjects (effectively setting $\sigma^{2}$ to infinity), so that the segmentation (encoder) network was trained only based on the prior.  
In subsequent batches, $\sigma^{2}$ was updated as the average MSE over the latest 16 subjects and rounded to the nearest power of 10 in order to reduce fluctuation. Our complete model is trained end-to-end with the ADAM optimizer~\cite{adam}, with a learning rate of $10^{-4}$ and default values for the decay rates of its first- and second-moment estimates. At test time, segmentation involves a computationally efficient single forward pass through the encoder and we output the \texttt{argmax} label at each voxel. Our PyTorch code is available at~\url{https://github.com/evanmy/sae}.
%\begin{figure}
%    \centering
%    \includegraphics[width=0.75\columnwidth]{figures/placeholder.png}
%    \caption{(a) Probability $\log p(s_{y}|s_{j})$ for the MRF model.  Computed on the prior $p(\boldsymbol{s})$ using a neighbor of 3 voxels around $j$. (b) Performance measure of different models. Hausdorff distance was computed by measuring boundary distances between prediction and target of all brain regions. Then the 95th percentile of the largest distance was used for each subject in order to increase robustness to noise.}
%    \label{fig:summary}
%\end{figure}

% In the end, our loss function is given by:
% \begin{equation}
%      \mathcal{L}(\theta, \phi; \boldsymbol{x}^{(i)}) = \textup{KL}(q_{\phi}(\boldsymbol{s}| \boldsymbol{x}^{(i)}) || p (\boldsymbol{s})) + \frac{\mathcal{V}}{2}\log \sigma^{2} + \frac{1}{2\sigma^{2}}||\boldsymbol{x}^{(i)}- \boldsymbol{\hat{x}}||^{2}_{2}
% \end{equation}
% We tested our model in two different scenario and agaisnt a baseline:\\ \\




\section{Experiments}


\subsection{Dataset}
%\avd{Sorry, I've reread this a couple of times to try to sense why you need this here, But it really feels out of place, and I think Datasets, Benchmarks and SAE variants should be under "Experiments" or something}
We evaluated SAE on T1-weighted 3D brain MRI scans, which we preprocessed with FreeSurfer, including skull stripping, bias-field correction,
intensity normalization, affine registration to Talairach space, and resampling to
1~mm$^3$ isotropic resolution~\cite{fischl2012freesurfer}.
We focused on 12 brain regions (listed below) that were manually segmented and visually inspected for quality assurance. These manual segmentations were only used to quantify performance. The total number of subjects was 38: 30 subjects were used for training and 8 subjects for testing. 
Although we call our sets training and testing, we emphasize that SAE did not have access to the segmentation images during training, as we are proposing an unsupervised paradigm. 
We repeated the experiment 5 times with different random subject assignments to the train/test partitioning.


% \subsection{Atlases}
% We used two different atlases to construct priors in experiments.
% Atlas1 is the spatially blurred (Gaussian kernel with 3 mm isotropic standard deviation) FreeSurfer segmentation of a single average subject from the OASIS dataset~\cite{oasis}, which was visually inspected for quality. Atlas2 is a probabilistic atlas derived from 20 manually segmented subjects~\cite{puonti2016fast}.
% We implemented two priors for each atlas: the spatial prior of Eq~\ref{eq:spatial} and the MRF prior of Eq~\ref{p_mrf}.



\subsection{Variants of SAE}
We employed two atlases. The first one (Atlas1) was based on a single unpaired segmentation image from the OASIS dataset~\cite{oasis}, which was automatically segmented using FreeSurfer~\cite{fischl2012freesurfer}. 
We applied spatial blurring (Gaussian with 3 mm isotropic standard deviation) to the one-hot encoded segmentation image to obtain a probabilistic prior.
As a second prior (Atlas2), we used a publicly available probabilistic atlas~\cite{puonti2016fast}, which was computed based on 20 manually labeled subjects. Both priors and all input MRI scans were affine registered to Talairach space.
For both of these priors, we implemented two versions: including and excluding the MRF loss of Eq.~\eqref{eq:MRF}.
Specifically, SAE1 (w/o MRF) uses the spatial prior derived by smoothing the single OASIS segmentation. SAE1 (w/ MRF) adds the MRF term of Eq.~\eqref{eq:MRF}, where the pairwise potential function is computed based on the neighborhood statistics in the OASIS segmentation image.
Finally, SAE2 uses the probabilistic atlas prior~\cite{puonti2016fast}, instantiated with and without the MRF loss.

\subsection{Benchmark Methods}
As naive baselines, we used the most probable label at each voxel in the two priors.
\textbf{\textit{Baseline1}} corresponds to Atlas1 and  \textbf{\textit{Baseline2}} corresponds to Atlas2.
As a strong baseline, we used an implementation of a widely used atlas-based brain MRI segmentation tool~\cite{van1999automated}, which uses Expectation-Maximization (EM)~\cite{dempster1977maximum} to invert a probabilistic generative model. This EM baseline was run with the two atlases, which we refer to as \textbf{\textit{EM1}} and \textbf{\textit{EM2}}. For each image, the EM baseline numerically solves an optimization problem and is thus relatively slow. Finally, the input data for the EM baseline were pre-processed exactly as for our model. 

As an effective upper bound on performance, we also implemented a supervised model, where a 3D U-Net~\cite{unet} with the same settings as our encoder was trained with the paired manual segmentations in the training data. Negative generalized (soft) Dice~\cite{sudre2017generalised} was used as the loss function and 6 of the 30 training subjects were reserved for validation. Training was terminated when validation loss stopped improving. As with our previous setup, we repeated this experiment 5 times with different train (N=24), validation (N=6), and test (N=8) splits\footnote{In each round, the test subjects are the same for all methods.}. 

\subsection{Metrics}
All presented results are computed on the test images of each round.
For quantitative evaluation, we rely on two metrics: the Dice score that measures the volumetric overlap between the automatic segmentation and the ground truth manual segmentation; and 
the $95\%$-Hausdorff distance (HD) that quantifies the distance between the boundaries of the automatic and manual segmentations. 
When the two segmentation maps are exactly the same, the Dice score achieves its maximum value of 1 (i.e., 100\%) and the HD equals zero.

\begin{table}[]
\scalebox{0.89}{\begin{tabular}{lccllll}
\cline{1-3}
\multicolumn{1}{c}{} & \multicolumn{2}{c}{\textbf{Performance Measure}} &  &  &           &                                        \\ \cline{1-3}
Model                & Hausdorff (mm)        & Dice Overlap ($\%$)      &  &  &           &                                        \\ \cline{1-3}
Baseline1            & 4.11$\pm$0.07         & 62.82$\pm$0.53           &  &  &           &                                        \\ \cline{6-7} 
EM1 Baseline         & 4.25$\pm$0.09         & 71.24$\pm$0.71           &  &  & Model     & \multicolumn{1}{c}{Test Time (s)} \\ \cline{6-7} 
Baseline2            & 3.50$\pm$0.06         & 71.45$\pm$0.65           &  &  & EM        & \multicolumn{1}{c}{61.07}              \\
SAE1 (w/o MRF)       & 3.88$\pm$0.05         & 74.64$\pm$0.30           &  &  & SAE (CPU) & \multicolumn{1}{c}{6.58}               \\
SAE1 (w/ MRF)        & 3.81$\pm$0.05         & 75.36$\pm$0.32           &  &  & SAE (GPU) & \multicolumn{1}{c}{1.58}               \\ \cline{6-7} 
EM2 Baseline         & 2.65$\pm$0.05         & 79.70$\pm$0.54           &  &  &           &                                        \\
SAE2 (w/o MRF)       & 2.73$\pm$0.04         & 79.94$\pm$0.34           &  &  &           &                                        \\
SAE2 (w/ MRF)        & 2.68$\pm$0.05         & 80.54$\pm$0.36           &  &  &           &                                        \\
Supervised           & 2.23$\pm$0.07         & 84.60$\pm$0.26           &  &  &           &                                        \\ \cline{1-3}
\end{tabular}}
\caption{Mean performance of all methods with their standard errors and computational time per volume at testing.}  
\label{table:summary}
\end{table}

\subsection{Experimental Results}
 Table~\ref{table:summary} lists the global average Dice and HD values for the baselines and SAE variants.
Regional and subject-level results are also presented in Fig.~\ref{fig:boxplot}. 
We observe that, in every single case and region, each SAE variant produces better segmentations than its corresponding naive baseline (SAE1 vs.\ Baseline1, and SAE2 vs.\ Baseline2). 
Overall, SAE Dice scores were about 8--13 points higher than those of the naive atlas-based baselines and slightly higher than those of the strong EM baselines.
On a modern CPU, the EM baseline had a run time of around 60 seconds per volume, whereas SAE took less than 7 seconds per volume at test time (less than 2 seconds on a GPU). This represents roughly a $9\times$ speed-up on the CPU (and nearly $40\times$ on the GPU) over the EM baseline, a popular brain MRI segmentation tool, with no discernible reduction in the quality of results. 
% \textbf{Persistence Model}: we 
% use the prior as the prediction for the test subjects. We have two spatial prior. \textbf{\textit{Baseline1}} consist of a brain segmentation of a single subject who was randomly picked from the OASIS dataset \cite{oasis}. The segmentation was automatically generated through FreeSurfer \cite{fischl2012freesurfer}. \textbf{\textit{Baseline2}} is a probabilistic atlas created by compiling segmentation of different subjects and computing the frequency at which each label appears on a given voxel.  Since the prior is probabilistic, we first took the \texttt{argmax} of the label before computing any measure.

% \textbf{Voxelwise Independent Prior}: for this experiment, we assume that we have access to a probabilistic atlas or prior with no corresponding MRI scan.  We are simulating a scenario in where the labeled prior could have come from another image modality than our unlabeled data.  \textbf{\textit{SAE1 (w/o MRF)}} uses the spatial prior from \textit{baseline1}. In order to create a probabilistic atlas from \textit{baseline1}, we applied a Gaussian blur with standard deviation of 3 voxels. On the other hand, \textbf{\textit{SAE2 (w/o MRF)}} uses the probabilistic atlas from \textit{baseline2} . %We emphasize that we only use MRI scan with \textit{no anatomical label} during training. The labels were only used during testing to evaluate the results.
% In order to initialize our network, we first pretrained the encoder by mapping every unlabeled brain MRI to our prior $p(\boldsymbol{s})$. Then, the network was trained end-to-end until the training loss converges. The experiment was repeated 5 times with another set of randomly chosen subjects and averaged the results.

% \textbf{Markov Random Field Prior}: the setup is the same as in the experiment with voxelwise independent prior. However, in this scenario, we investigated the usefulness of enforcing neighborhood consistency. This is achieved by including an additional term $\mathcal{L}_{mrf}$ in our loss. The neighborhood size $\mathcal{N}_{j}$ was chosen to be 3 voxel around $j$. Weighting this term is tricky since we are summing that we do not have labels, thereby there is no validation set. We performed a rough line search for $\beta$ using powers of 10. We chose the largest $\beta$, which was 0.01, that gave stable training dynamics and visually reasonable outputs. 

% \textbf{Supervised Baseline Model}: for this experiment, we wanted to see how well our proposed model fairs against a fully supervised setup. The model was a simple U-Net \cite{unet} with the same settings as our encoder. We split the dataset into 8 subjects for testing, 6 subjects for validation and 24 subjects for training. As with our previous setup, we also repeated this experiment 5 times with different subjects and averaged the results



For SAE, we observe that the adopted prior has a significant impact on the results. With a superior prior derived from multiple subjects, SAE2 yields substantially better results than SAE1. In addition, adding spatial consistency via the MRF loss improves the accuracy of both SAE variants (paired t-test $p<10^{-6}$, for both atlases).
These results highlight the importance of having a sophisticated prior.
The best unsupervised model, SAE2 (w/ MRF), yielded a Dice score that was about $4$ points below that of the fully supervised model, which serves as a strong upper bound in our experiments. 
%Finally, a detailed summary of the dice scores between different region is presented in the boxplot of Fig.\ref{fig:boxplot}.
\begin{figure}
    \centering
    \includegraphics[width=0.89\columnwidth]{figures/dice_boxplot7.pdf}
    %\includegraphics[width=1.0\columnwidth]{figures/boxplot_em.png}
    \includegraphics[width=0.9\columnwidth]{figures/hausdorff_boxplot7.pdf}
    \caption{Boxplots of Dice scores and Hausdorff distances. Legend: (PAL) pallidum, (AMY) amygdala, (CAU) caudate, (CT) cerebral cortex, (HIP) hippocampus, (THA) thalamus, (PUT) putamen, (WM) cerebral white matter, (CCT) cerebellar cortex, (LV) lateral ventricle, (CMW) cerebellar white matter, (BS) brainstem.}
    \label{fig:boxplot}
\end{figure}



%\section{Results}
A qualitative visualization of SAE2 (w/ MRF) results is provided in Fig.~\ref{fig:brain}. 
We can see that, despite having a fixed prior $p(\boldsymbol{s})$, our model is able to capture inter-subject neuroanatomical variation. 
We attribute this mainly to the decoder, which forces the latent representation to be useful for reconstruction. 

\section{Discussion}
We introduced SAE, a flexible deep learning framework that can be used to train image segmentation models with minimal supervision.
We applied SAE to segment brain MRI scans, relying on an unpaired atlas prior.
Importantly, SAE does not need manual segmentations paired with the images, which opens up the possibility of deploying it on new imaging techniques, e.g., with high resolution or different contrast.
Empirically, we showed how segmentation accuracy changes with the type of prior used. 

The current implementation of SAE assumes that the input MRI is affinely aligned with the prior, which we achieve by working in Talairach space. However, SAE can be implemented with very different types of priors, which we would like to explore in the future.
For example, in the present paper, we did not experiment with a spatial deformation model that would warp the atlas to better align with the input image.
We envision integrating a spatial-transformer-type neural network, such as VoxelMorph~\cite{dalca2018unsupervised}, to relax this assumption. By adding a deformation model to the prior, we believe that we can handle complications like moving organs. %Instead we relied on affine pre-registration.
Alternatively, we can implement more sophisticated priors, such as those that exploit an adversarial strategy, as in adversarial autoencoders~\cite{makhzani2015adversarial}.
%There are many ways in which we can improve our model. 
%Currently, we balance the KL-divergence and the reconstruction loss in Eq.\ref{eq:loss} through a simple running average of $\sigma^{2}$, which is a scalar. However, a pixel-wise or patch-wise weight can help improve the accuracy on smaller brain regions.

\section{Acknowledgments}
This research was funded by NIH grants 1R21AG050122,
R01LM012719, and R01AG053949; NSF CAREER 1748377; and
NSF NeuroNex Grant 1707312. JEI is supported by the European Research Council (ERC Starting Grant 677697, project BUNGEE-TOOLS). AVD is supported by NIH 1R56AG064027.

\begin{figure}
    \centering
    \includegraphics[width=0.75\columnwidth]{figures/brain.png}
    \caption{Representative segmentation results obtained with SAE2 (w/ MRF) on two subjects. Recon is the output of the decoder. GT scan and segmentation are the input MRI and manual segmentation, respectively. Pred is the segmentation obtained by taking, at each voxel, the \texttt{argmax} over labels of the predicted posterior $q_{\phi}(\boldsymbol{s} \mid \boldsymbol{x}^{(i)})$.}
    \label{fig:brain}
\end{figure}

\bibliography{yu20}
\end{document}
