\section{Methodology}

\begin{figure}[!t]
	\centering
    \includegraphics[width=0.95\columnwidth,height=0.293\columnwidth]{figs/framework.png}
	\caption{ Visual illustration of the proposed memory U-Net.
	}
	\label{fig:framework}
\vspace{-2ex}
\end{figure}

The proposed memory U-Net consists of three modules: a U-Net for feature extraction, a memory network that stores offset values and their prototypical features as key-value pairs, and a post-processing module that generates the final segmentation.
As shown in \figureref{fig:framework}, given a multi-modal input image, the U-Net outputs a matrix $\mathbf{Z} \in \mathbb{R}^{N\times C}$ ($N=H\times W \times D$ is the number of voxels, and $C$ is the number of channels) with each row representing the feature vector of a voxel.  
Taking the output matrix of the offset head as the query, the memory module and its addressing operator work together to retrieve the most relevant key memory slots and output a normalized weight vector, which is used to reconstruct offsets from the value memory slots.
Additionally, given the matrix $\mathbf{Z}$, the weight head outputs a voting weight for each voxel.
The offsets along the axial, coronal, and sagittal directions, together with the voting weight, form the final output of the memory U-Net.
During training, the U-Net, the offset head, the weight head and all key memory slots are simultaneously updated.
Given a test sample, the model outputs four vectors $\mathbf{f}^a$, $\mathbf{f}^s$, $\mathbf{f}^c$, and $\mathbf{f}^w$, which contain, for each voxel, the offsets to the center of its lesion along the three orthogonal directions and the corresponding voting weight.
From these four vectors, the post-processing module generates a density map of potential lesion centers, followed by non-maximum suppression to pick the peak points.
Then, these peak points and the lesion voxel coordinates are fed into a clustering method to generate the final segmentation.
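The post-processing steps above can be illustrated with the following minimal sketch in plain Python. It is not the authors' implementation: it assumes that a voxel plus its predicted offset gives the voted lesion center, that votes are accumulated at rounded voxel positions without any smoothing of the density map, and that the unspecified clustering method is a simple nearest-peak assignment. The function names \texttt{vote\_density}, \texttt{find\_peaks}, and \texttt{assign\_to\_peaks} and the parameters \texttt{radius} and \texttt{min\_mass} are hypothetical placeholders.

```python
from itertools import product


def vote_density(voxels, offsets, weights):
    """Accumulate voting weights at the centers that voxels point to.

    Assumes voxel + offset = voted lesion center (hypothetical sign convention).
    Returns a dict mapping integer (x, y, z) positions to accumulated weight.
    """
    density = {}
    for v, o, w in zip(voxels, offsets, weights):
        center = tuple(round(v[i] + o[i]) for i in range(3))
        density[center] = density.get(center, 0.0) + w
    return density


def find_peaks(density, radius=2, min_mass=10.0):
    """Non-maximum suppression: keep positions that dominate a cubic window."""
    peaks = []
    for p, mass in density.items():
        if mass < min_mass:
            continue  # too little support to be a lesion center
        dominated = False
        for d in product(range(-radius, radius + 1), repeat=3):
            q = (p[0] + d[0], p[1] + d[1], p[2] + d[2])
            other = density.get(q)
            # A neighbor with more mass suppresses p; ties are broken by
            # coordinate order so that only one point of a plateau survives.
            if q != p and other is not None and (other > mass or (other == mass and q < p)):
                dominated = True
                break
        if not dominated:
            peaks.append(p)
    return peaks


def assign_to_peaks(voxels, offsets, peaks):
    """Label each voxel with the 1-based index of the peak nearest its voted center."""
    if not peaks:
        return [0] * len(voxels)
    labels = []
    for v, o in zip(voxels, offsets):
        c = [v[i] + o[i] for i in range(3)]
        nearest = min(range(len(peaks)),
                      key=lambda k: sum((c[i] - peaks[k][i]) ** 2 for i in range(3)))
        labels.append(nearest + 1)
    return labels
```

For two well-separated lesions whose voxels vote consistently, \texttt{find\_peaks} returns one peak per lesion and \texttt{assign\_to\_peaks} labels the voxels of each lesion with a distinct ID.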

\subsection{Dataset Materials}

The dataset used in this work comes from a cross-sectional study of MS patients who were participating in a clinical and imaging data repository for MS research.
All studies were approved by an ethical standards committee on human experimentation, and written informed consent was obtained from all patients prior to their entry into the database.
The dataset consists of 150 MS patients with varying lesion loads.
Imaging was performed on a 3T Magnetom Skyra scanner (Siemens Medical Solutions USA, Malvern, PA) using a product twenty-channel head/neck coil. 
The scanning protocol consisted of a standard 3D T1-weighted (T1-w) sequence for anatomical structure, a 2D T2-weighted (T2-w) sequence, and a 3D T2-w FLAIR sequence.
Gold-standard lesion labels were first generated by an automated CNN model \cite{zhang2020efficient} and then manually edited and corrected by two experts.
In contrast to a binary lesion mask, each voxel here was assigned an ID indicating the individual lesion to which it belongs.

To prepare the training samples, we computed the center of mass of each lesion and then generated the offset of each lesion voxel to its lesion center as the coordinate difference between the voxel and that center.
We then computed the size $S$ (the number of voxels) of each lesion and assigned each of its voxels a weight of $\frac{100}{S}$, so that every lesion carries the same total voting weight of $100$ regardless of its size (see \figureref{fig:img_example} for a visual example).
We observed from the dataset statistics that the offset values range from $-64$ to $64$ and the weight values range from $0$ to $6.66$.
Moreover, most offset values lie within $(-20,20)$, so the offset distribution is long-tailed, which makes offset estimation challenging.
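As an illustration, the construction of the training targets can be sketched in plain Python as below. The sketch assumes that the instance-labeled volume is given as a dictionary from voxel coordinates to lesion IDs, that the offset is defined as lesion center minus voxel coordinate, and that the center of mass is rounded to the nearest voxel so that offsets are integers; the function name \texttt{make\_targets} is hypothetical.

```python
from collections import defaultdict


def make_targets(labels):
    """Build per-voxel offset and weight targets from an instance-labeled volume.

    labels: dict mapping (x, y, z) -> lesion ID (> 0), lesion voxels only.
    Returns (offsets, weights), two dicts keyed by voxel coordinates.
    """
    members = defaultdict(list)
    for voxel, lesion_id in labels.items():
        members[lesion_id].append(voxel)

    offsets, weights = {}, {}
    for voxels in members.values():
        size = len(voxels)  # S, the number of voxels in the lesion
        # Center of mass, rounded to the nearest voxel (assumption) so that
        # all offsets are integers, matching the integer value memory slots.
        center = tuple(round(sum(v[i] for v in voxels) / size) for i in range(3))
        for v in voxels:
            offsets[v] = tuple(center[i] - v[i] for i in range(3))  # center - voxel
            weights[v] = 100.0 / size  # the S weights of a lesion sum to 100
    return offsets, weights
```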

\subsection{Memory-based Representation Learning}

Since the offset distribution is long-tailed, directly predicting offsets with a CNN model such as U-Net suffers from data imbalance.
The proposed memory network is therefore designed to address this long-tailed issue by adaptively learning prototypical feature representations for offset values across the entire range.
It has three memory modules that store representative features for estimating offsets along the axial, coronal, and sagittal directions, respectively.
Each module consists of a memory addressing operator, a learnable matrix of key memory slots, and a fixed vector of value memory slots.

\subsubsection{Memory-based Offset Representation}

Unlike traditional memory networks, we use a key-value structure to associate offset values with prototypical feature vectors.
All the offset values (integers ranging from $-64$ to $64$) form a vector $\mathbf{v} \in \mathbb{R}^{125}$, which serves as the value memory slots and is kept fixed during both training and testing.
The key memory is designed as a matrix $\mathbf{M} \in \mathbb{R}^{125 \times C}$ to store $125$ real-valued vectors with a fixed dimension $C$, where $C$ is the number of output channels at the offset head.
Let $\mathbf{Z}^{o} \in \mathbb{R}^{N\times C}$ denote the output of the offset head, $\mathbf{z}^{o}_{i}$ be the $i$-th row vector of $\mathbf{Z}^{o}$, $\mathbf{m}_{i}$ be the $i$-th row vector of $\mathbf{M}$, and $v_{i}$ be the $i$-th entry of $\mathbf{v}$.
Given the feature vector $\mathbf{z}^{o}_{i}$ of a query voxel, its offset value $f_{i}$ is obtained as follows:
\begin{equation}
    f_{i} = \sum_{j=1}^{125} w_j v_{j} = \mathbf{w}^{\top} \mathbf{v},
    \label{eq:memory_rep}
\end{equation}
where $\mathbf{w}$ is the weight vector output by the memory addressing operator; its entries are non-negative and sum to one.
The weight vector is computed from the key memory and the query vector: the memory addressing operator retrieves the most relevant slots from the key memory and assigns each slot a weight representing its similarity to the query vector.
The process of memory addressing can be described by the following equation:
\begin{equation}
    w_j = \dfrac{\exp(d(\mathbf{z}^{o}_{i},\mathbf{m}_{j}))}{\sum_{k=1}^{125}\exp(d(\mathbf{z}^{o}_{i},\mathbf{m}_{k}))},
    \label{eq:memory_addr}
\end{equation}
where $d(\cdot,\cdot)$ is a similarity function.
In our work, we use the cosine similarity $d(\mathbf{z},\mathbf{m})= \frac{\mathbf{z}\mathbf{m}^{\top}}{\|\mathbf{z}\|\cdot\|\mathbf{m}\|}$.
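The memory read of \equationref{eq:memory_rep} and \equationref{eq:memory_addr} for a single query can be sketched as follows. This is a plain-Python illustration under our own assumptions: the names \texttt{cosine} and \texttt{memory\_read} are hypothetical, a small constant is added to the denominator of the cosine similarity to avoid division by zero, and the maximum logit is subtracted before exponentiation, which leaves the softmax weights unchanged but improves numerical stability.

```python
import math


def cosine(z, m, eps=1e-12):
    """Cosine similarity d(z, m); eps guards against zero-norm vectors."""
    dot = sum(a * b for a, b in zip(z, m))
    norm = math.sqrt(sum(a * a for a in z)) * math.sqrt(sum(b * b for b in m))
    return dot / (norm + eps)


def memory_read(z, keys, values):
    """Read one offset from a key-value memory.

    z:      query vector z_i^o from the offset head (length C)
    keys:   key memory slots m_j (each of length C)
    values: fixed offset values v_j, one per key slot
    Returns the offset f_i and the addressing weights w.
    """
    logits = [cosine(z, m) for m in keys]
    top = max(logits)  # subtracting the max does not change the softmax
    exps = [math.exp(s - top) for s in logits]
    total = sum(exps)
    w = [e / total for e in exps]  # memory addressing (softmax)
    f = sum(wj * vj for wj, vj in zip(w, values))  # f_i = w^T v
    return f, w
```

In a deep learning framework, the same computation would be vectorized over all $N$ rows of $\mathbf{Z}^{o}$ and kept differentiable so that the key slots receive gradients.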

The computations in \equationref{eq:memory_rep} and \equationref{eq:memory_addr} are repeated for all row vectors of $\mathbf{Z}^{o}$ and for each of the three memory modules.
This yields three vectors $\mathbf{f}^{a}$, $\mathbf{f}^{s}$, and $\mathbf{f}^{c}$, which represent the offsets predicted by our memory U-Net along the axial, sagittal, and coronal directions, respectively.

\subsubsection{Memory Update}

Since both \equationref{eq:memory_rep} and \equationref{eq:memory_addr} are differentiable operations, the key memory slots can be updated via back-propagation and stochastic gradient descent.
However, this update is prone to bias: because offset values in the long tail occur rarely, the weights of the memory slots associated with them can be underestimated.
Thus, we develop an alternative memory updating mechanism to alleviate the issue.
Let $\mathbf{g} \in \mathbb{R}^{N}$ denote the ground-truth offsets and $g_i$ the $i$-th entry of $\mathbf{g}$. We propose to update the memory slots as follows:
\begin{equation}
    \mathbf{m}_i = \alpha \mathbf{m}_i + (1-\alpha) \frac{\sum_{j=1}^{N} \mathbf{z}^{o}_j \cdot  \mathbbm{1}(g_j,v_i)}{\sum_{j=1}^{N} \mathbbm{1}(g_j,v_i)},
    \label{eq:memory_update}
\end{equation}
where $\alpha \in [0,1]$ is the updating rate and $\mathbbm{1}(\cdot,\cdot)$ is an indicator function ($\mathbbm{1}(x,y)=1$ if $x=y$ and $\mathbbm{1}(x,y)=0$ otherwise).
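Through \equationref{eq:memory_update}, each key slot $\mathbf{m}_i$ is explicitly moved toward the mean feature of the voxels whose ground-truth offset equals $v_i$, so the memory is forced to update the slots of all offset values present in the ground truth, including rare ones in the long tail.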
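A minimal plain-Python sketch of the update in \equationref{eq:memory_update} is given below. The function name \texttt{update\_memory} and the default $\alpha=0.9$ are hypothetical, and because the equation is undefined when no voxel in the batch has ground-truth offset $v_i$, the sketch assumes such slots are simply left unchanged.

```python
def update_memory(keys, values, queries, targets, alpha=0.9):
    """Moving-average update of key memory slots.

    keys:    key slots m_i (lists of floats); a new list is returned
    values:  fixed offset values v_i, one per key slot
    queries: offset-head features z_j^o of the N voxels in the batch
    targets: ground-truth integer offsets g_j, one per voxel
    alpha:   updating rate in [0, 1] (0.9 is a placeholder, not from the paper)
    """
    new_keys = [list(m) for m in keys]
    for i, v in enumerate(values):
        matched = [z for z, g in zip(queries, targets) if g == v]
        if not matched:
            # Assumption: the mean is undefined when no voxel has offset v,
            # so the slot is left unchanged.
            continue
        dim = len(new_keys[i])
        mean = [sum(z[c] for z in matched) / len(matched) for c in range(dim)]
        new_keys[i] = [alpha * m + (1 - alpha) * mu for m, mu in zip(new_keys[i], mean)]
    return new_keys
```

With this formulation, a larger $\alpha$ keeps more of the previous slot content, so $\alpha$ close to $1$ corresponds to slow updates.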

\subsection{Training Loss Functions}

We use an L1 loss to train the proposed memory U-Net.
Let $\mathbf{g}^{a}$, $\mathbf{g}^{s}$, and $\mathbf{g}^{c}$ be the ground-truth offset vectors along the axial, sagittal, and coronal directions, $\mathbf{g}^{w}$ be the ground-truth weight vector, and $\mathbf{f}^{w}$ be the output of the weight head. The L1 loss is then defined as follows:
\begin{equation}
    \mathcal{L}_1 = \sum_{p \in \{a,s,c,w\}} \|\mathbf{f}^{p} - \mathbf{g}^{p}\|_1.
    \label{eq:l1}
\end{equation}
Because \equationref{eq:memory_update} updates each key slot with the mean feature of the voxels sharing its offset value, we further introduce a loss that reduces the spread of these features around their mean:
\begin{equation}
    \mathcal{L}_{var} = \frac{1}{125} \sum_{i=1}^{125} \frac{1}{\sum_{k=1}^{N} \mathbbm{1}(g_k,v_i)} \sum_{k=1}^{N} \mathbbm{1}(g_k,v_i)(||\mathbf{z}^{o}_k - \frac{\sum_{j=1}^{N} \mathbf{z}^{o}_j \cdot  \mathbbm{1}(g_j,v_i)}{\sum_{j=1}^{N} \mathbbm{1}(g_j,v_i)}|| - \delta)_{+},
    \label{eq:var}
\end{equation}
where $(x)_{+}=\max(0,x)$ is the hinge function and $\delta$ is the margin.
Combining \equationref{eq:l1} and \equationref{eq:var} with equal weights gives the final loss function:
\begin{equation}
    \mathcal{L} = \mathcal{L}_1 + \mathcal{L}_{var}.
\end{equation}
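The losses in \equationref{eq:l1} and \equationref{eq:var} can be sketched in plain Python as follows, with $\mathcal{L}_{var}$ shown for a single memory module. The function names and the default margin $\delta=0.5$ are hypothetical; as in the memory update, offset values with no matching voxel would make \equationref{eq:var} divide by zero, so the sketch assumes they contribute zero to the sum while still being counted in the average over value slots.

```python
import math


def l1_loss(preds, gts):
    """Sum over p of ||f^p - g^p||_1; preds/gts map 'a', 's', 'c', 'w' to lists."""
    return sum(abs(f - g) for p in preds for f, g in zip(preds[p], gts[p]))


def var_loss(queries, targets, values, delta=0.5):
    """Hinge loss on the distance of each feature to its offset-group mean."""
    total = 0.0
    for v in values:
        group = [z for z, g in zip(queries, targets) if g == v]
        if not group:
            continue  # assumption: empty groups contribute zero
        dim = len(group[0])
        mean = [sum(z[c] for z in group) / len(group) for c in range(dim)]
        hinge = sum(max(0.0, math.dist(z, mean) - delta) for z in group)
        total += hinge / len(group)
    return total / len(values)  # average over all value slots


def total_loss(preds, gts, queries, targets, values, delta=0.5):
    """Final loss: L1 and variance terms with equal weights."""
    return l1_loss(preds, gts) + var_loss(queries, targets, values, delta)
```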