\documentclass{midl} % Include author names
% \documentclass[anon]{midl} % Anonymized submission

% The following packages will be automatically loaded:
% jmlr, amsmath, amssymb, natbib, graphicx, url, algorithm2e
% ifoddpage, relsize and probably more
% make sure they are installed with your latex distribution

% \usepackage{mwe} % to get dummy images
\newcommand{\ourmod}{Triplet Training}

\usepackage{amsfonts}
\usepackage{amsmath}
\usepackage{booktabs}
\usepackage{enumitem}
\usepackage{gensymb}
\usepackage{multirow}
\usepackage{pifont}
\usepackage{graphicx}
% \usepackage{subcaption}
% \usepackage{todonotes}
\usepackage{amssymb}
% \usepackage{subfig}
% \usepackage[pagebackref,breaklinks,colorlinks]{hyperref}


\newcommand{\cmark}{\checkmark}
\newcommand{\interval}[1]{\tiny{$\pm$ #1}}

\usepackage{xcolor}
\definecolor{mpurple}{rgb}{0.671, 0.204, 0.922}
\definecolor{mdarkblue}{rgb}{0.192, 0.212, 0.584}
\definecolor{mred}{rgb}{0.843, 0.188, 0.153}
\definecolor{mdarkgrey}{rgb}{0.249, 0.249, 0.249}
\definecolor{mlightblue}{rgb}{0.455, 0.678, 0.820}
\definecolor{morange}{rgb}{0.992, 0.682, 0.380}
\definecolor{mlightgrey}{rgb}{0.671, 0.671, 0.671}



% \jmlrvolume{-- Under Review}
% \jmlryear{2024}
% \jmlrworkshop{Full Paper -- MIDL 2024 submission}
% \editors{Under Review for MIDL 2024}


\jmlryear{2024}
\jmlrworkshop{Full Paper -- MIDL 2024}
\jmlrvolume{-- nnn}
\editors{Accepted for publication at MIDL 2024}

\title[From Barlow Twins to \ourmod]{From Barlow Twins to \ourmod:\\
Differentiating Dementia with Limited Data}


% More complicated cases, e.g. with dual affiliations and joint authorship
% \midlauthor{\Name{Yitong Li\midljointauthortext{Contributed equally}\nametag{$^{1,2}$}} \Email{yi\_tong.li@tum.de}\\
% \Name{Tom Nuno Wolf\midlotherjointauthor\nametag{$^{1,2}$}} \Email{tom\_nuno.wolf@tum.de}\\
% \Name{Sebastian P{\"o}lsterl\nametag{$^{1}$}} \Email{sebastian.poelsterl@med.uni-muenchen.de}\\
% \Name{Igor Yakushev\nametag{$^{3}$}} \Email{igor.yakushev@tum.de}\\
% \Name{Dennis M. Hedderich\nametag{$^{4}$}} \Email{dennis.hedderich@tum.de}\\
% \Name{Christian Wachinger\nametag{$^{1,2}$}} \Email{christian.wachinger@tum.de} \\
% \addr $^{1}$ Laboratory for Artificial Intelligence in Medical Imaging, Department of Radiology, Technical University of Munich (TUM), Germany \\
% \addr $^{2}$ Munich Center for Machine Learning (MCML), Germany \\
% \addr $^{3}$ Department of Nuclear Medicine, Klinikum rechts der Isar, TUM, Germany \\
% \addr $^{4}$ Department of Neuroradiology, Klinikum rechts der Isar, TUM, Germany  \\
% }


\midlauthor{\Name{Yitong Li\midljointauthortext{Contributed equally}\nametag{$^{1,2}$}},
\Name{Tom Nuno Wolf\midlotherjointauthor\nametag{$^{1,2}$}},
\Name{Sebastian P{\"o}lsterl\nametag{$^{1}$}},
\Name{Igor Yakushev\nametag{$^{3}$}}, \\
\Name{Dennis M. Hedderich\nametag{$^{4}$}},
\Name{Christian Wachinger\nametag{$^{1,2}$}} 
\Email{yi\_tong.li@tum.de} \\
\addr $^{1}$ Laboratory for Artificial Intelligence in Medical Imaging, Department of Radiology, Technical University of Munich (TUM), Germany \\
\addr $^{2}$ Munich Center for Machine Learning (MCML), Germany \\
\addr $^{3}$ Department of Nuclear Medicine, Klinikum rechts der Isar, TUM, Germany \\
\addr $^{4}$ Department of Neuroradiology, Klinikum rechts der Isar, TUM, Germany  \\
}



\begin{document}
% \footnotetext[*]{Contributed equally}
\maketitle

\begin{abstract}
Differential diagnosis of dementia is challenging due to overlapping symptoms, with structural magnetic resonance imaging (MRI) being the primary method for diagnosis.
Despite the clinical value of computer-aided differential diagnosis, research has been limited, mainly due to the absence of public datasets that contain diverse types of dementia. 
This leaves researchers with small in-house datasets that are insufficient for training deep neural networks (DNNs). 
Self-supervised learning shows promise for utilizing unlabeled MRI scans in training, but small batch sizes for volumetric brain scans make its application challenging.
To address these issues, we propose \emph{\ourmod} for differential diagnosis with limited target data. It consists of three key stages: (i) self-supervised pre-training on unlabeled data with Barlow Twins, (ii) self-distillation on task-related data, and (iii) fine-tuning on the target dataset. 
Our approach significantly outperforms traditional training strategies, achieving a balanced accuracy of 75.6\%.
We further provide insights into the training process by visualizing changes in the latent space after each step. 
Finally, we validate the robustness of \ourmod\ in terms of its individual components in a comprehensive ablation study. Our code is available at \url{https://github.com/ai-med/TripletTraining}.
\end{abstract}

\begin{keywords}
differential diagnosis, dementia, transfer learning, limited data.
\end{keywords}

\section{Introduction}

The number of patients suffering from dementia is expected to increase to 152.8 million by 2050~\cite{nichols2022},
with Alzheimer's Disease (AD) accounting for 60--80\% of affected patients.
Frontotemporal dementia (FTD) is the second most common type of dementia in the younger-elderly population (aged $<$ 65 years)~\cite{Young2018FrontotemporalDL}.
Accurately diagnosing different dementia types is challenging as symptoms overlap, but is crucial for patient management, therapy, and prognosis.
In the clinical routine, differential diagnosis incorporates structural magnetic resonance imaging (sMRI) to evaluate distinct atrophy patterns.
Despite the clinical importance of differential diagnosis, there is limited research in computer-aided diagnosis for this task compared to classifying AD and cognitively normal (CN) subjects, largely rooted in the lack of related public MRI datasets. 
Accessing in-house data from hospitals is an alternative; however, even if available, such data is typically too small to train DNNs successfully. 

At the same time, public datasets exist that focus on single types of dementia. 
For AD, the Alzheimer's Disease Neuroimaging Initiative (ADNI, adni.loni.usc.edu) provides an extensive resource~\cite{adni}.
Similarly, the initiative on Neuroimaging in Frontotemporal Dementia (NIFD,  4rtni-ftldni.ini.usc.edu) collected data for FTD. 
As a result, previous research on the differential diagnosis of AD and FTD combined the two datasets \cite{ma20,hu21,ngu22}. 
An inherent limitation of such a combination is the confounding of dataset and diagnosis, potentially yielding shortcut learning that differentiates datasets instead of diagnosis \cite{geirhos2020shortcut}. 
While the evaluation of such a merged dataset easily leads to inflated estimates of classification accuracy, it can instead provide a valuable resource in the training process. 

Population imaging studies, e.g., UK Biobank \cite{ukbio}, establish an even larger resource of MRI data for training, but they do not contain task-related labels.
Recent advances in self-supervised learning (SSL) can provide means to benefit from such data in an unsupervised fashion, which have not yet been incorporated for differential diagnosis. 
A challenge for applying common SSL methods like SimCLR~\cite{simclr} or SwAV~\cite{caron2020unsupervised} to 3D brain MRI data is the need for large batch sizes and hence GPU memory, as they rely on hard negative samples to avoid collapse.
Barlow Twins~\cite{barlow} is an alternative that eliminates the need for negative samples and naturally avoids collapse by redundancy reduction.
As a result, it demonstrates better robustness to small batch sizes, which makes it well-suited for SSL in neuroimaging. 


\begin{figure}[t]
\vspace{-0.5cm}
\floatconts
    {fig:problem}
    {\caption{\ourmod\ for differential diagnosis of dementia: 1) task-unrelated data is leveraged via self-supervision, 2) self-distillation is performed on task-related data, 3) the network is fine-tuned on the training part of the target dataset and evaluated on the test part.}}
    {\includegraphics[width=\textwidth]{graphs/probfig.png}}
    \vspace{-0.5cm}
\end{figure}



We introduce \emph{\ourmod} for differential diagnosis with limited target data. 
\ourmod, see \figureref{fig:problem}, combines three learning strategies to include all relevant MRI data in training.  
%that takes advantage of the available MRI resources. 
%Fig. \ref{fig:problem} illustrates the three steps:
First, self-supervision trains the network on task-unrelated data without target labels (UK Biobank). 
Second, we apply self-distillation on a task-related dataset that is created by merging data from ADNI and NIFD. 
Third, we fine-tune the model on a training set of the small in-house clinical data. 
Our results demonstrate that \ourmod\ outperforms competing methods while being robust to the choice of hyper-parameters and of the SSL method used in the first step.


To summarize, our key contributions are: 
\begin{itemize} [topsep=0pt,label=$\bullet$]
    \item \ourmod\  for learning DNNs with limited target data.
    \vspace{-0.3cm}
    \item Adapting Barlow Twins as an efficient SSL algorithm on volumetric brain MRI data.
    
    \vspace{-0.3cm}
    %for DD (self-sup for dementia?)
    \item Self-distillation to distill knowledge from the SSL-trained teacher network in combination with task-related labels. % fuse the result of SSL with task-related data.
    \vspace{-0.3cm}
    \item Reporting of test accuracy for differential diagnosis of AD and FTD on a well-characterized single-site clinical dataset.
\end{itemize}

\subsection{Related Work}\label{sec:relatedwork}

\paragraph{Differential Diagnosis of AD and FTD with DNNs.}
One line of research for differential diagnosis performs brain segmentation~\cite{ma20,ngu22} and uses volume and thickness measurements for the classification. 
Such an approach may restrict learning general dementia-specific features across the entire brain. Motivated by the success of using a 3D-ResNet50 encoder-decoder on MRI~\cite{hu21} to extract latent representations for classification, we selected a 3D-ResNet as the backbone for our work.

As no public dataset exists comprising both AD and FTD patients, these methods combined ADNI and NIFD.
The fundamental problem of such an approach is that datasets coincide with diagnosis; hence, it cannot be determined whether the network inadvertently learns to differentiate datasets instead of pathology \cite{geirhos2020shortcut}.
Thus, we incorporate ADNI and NIFD in \ourmod~for pretraining and evaluate on the in-house single-site dataset to allow for a reliable performance assessment. 



\paragraph{Self-Supervised Learning and Self-Distillation in Medical Image Analysis.}

A variety of research~\cite{Azizi2021,chaitanya,LChen2019,Taleb2020,Zhou2020ComparingTL,Hosseinzadeh-Taher2021,Li2021,Tran2022,ZZhou2019} concluded that self-supervised pre-training on domain-related datasets (i.e., unlabeled (3D) medical images) improves performance on medical downstream tasks.
\citet{dira} added restorative and adversarial branches to the SSL pipeline for medical downstream tasks.
Additionally, \citet{clfinetune} and \citet{desd} showed how SSL trained on task-unrelated medical images improves generalization on low-data regimes.
This problem has also been tackled with self-distillation in~\citet{Distil1}, \citet{Distil2}, and \citet{DistilMICCAI}.


In summary, self-supervised pre-training and self-distillation on medical images improve the performance of the downstream task, with domain-related datasets adding additional benefits.
Such approaches have not yet been explored for differential diagnosis, nor have they been combined into a single multi-stage training scheme such as \ourmod.  
Moreover, research on Barlow Twins has been limited despite its attractive properties for volumetric medical images. 



\section{Methods}

In this section, we present the details of \ourmod\ to tackle the limited data availability for the target task. 
We utilize SSL with Barlow Twins to integrate task-unrelated data in the initial step. 
In the second step, we propose to include task-related data via self-distillation.
Self-distillation fully utilizes the previous SSL step by aligning the distribution of latent features extracted by the student network with those learned from SSL, using the Kullback-Leibler (KL) divergence. This method not only builds on prior learning but also reduces the risk of overfitting on the task-related dataset.
Finally, we fine-tune the model on the target dataset.
Before going into technical details, we introduce notation and datasets. 


\subsection{Preliminaries and Datasets}\label{sec:datasets}

We define a 3D image as $\mathcal{I} \in \mathbb{R}^{H \times W \times D}$, with $H$, $W$, $D$ as height, width and depth, respectively. A dataset consists of $N$ 3D images $\mathcal{I}_i$, $i=1,\dots,N$, and class labels $y_i$ if available.
Our model consists of a feature extractor $f: \mathbb{R}^{H \times W \times D} \rightarrow \mathbb{R}^{Z}$, with $Z$ the latent space dimension, and a projection head $g: \mathbb{R}^{Z} \rightarrow \mathbb{R}^{C}$, which maps the latent vectors to outputs of dimension $C$.
We select a 3D-ResNet backbone for the feature extractor $f$ and a two-layer MLP for the projection head $g$ (implementation details in~\sectionref{sec:architecture}).


We utilize three datasets:
\vspace{-0.3cm}
\begin{enumerate} %[label=\roman*.]     
    \item The unlabeled dataset $\mathcal{U}$ comprises $N=39{,}560$ samples $X^\mathcal{U}_{i} = (\mathcal{I}_{i}^\mathcal{U})$ extracted from the UK Biobank~\cite{ukbio}.
    \vspace{-0.2cm}
    
    \item The labeled, task-related dataset $\mathcal{D}$ consists of $N=1{,}305$ samples $X_{i}^\mathcal{D} = (\mathcal{I}_{i}^\mathcal{D}, y_i^\mathcal{D})$, $y^\mathcal{D}_i \in \{\text{CN}, \text{AD}, \text{FTD}\}$ from ADNI and NIFD. 
    \vspace{-0.2cm}
    
    \item The labeled target in-house dataset $\mathcal{T}$ consists of $N=329$ samples $X_{i}^\mathcal{T} = (\mathcal{I}_{i}^\mathcal{T}, y_i^\mathcal{T})$, $y^\mathcal{T}_i \in \{\text{CN}, \text{AD}, \text{FTD}\}$ from the hospital Klinikum rechts der Isar, Munich, Germany.
    \vspace{-0.2cm}
\end{enumerate}
\tableref{tab:dataset_statistics} reports demographic statistics for all three datasets.
 
\begin{table}[t]
\centering
\scriptsize{\setlength{\tabcolsep}{0.5em}
\floatconts
    {tab:dataset_statistics}
    {\caption{Statistics for unlabeled $\mathcal{U}$, task-related $\mathcal{D}$, and target $\mathcal{T}$ datasets. MMSE denotes the Mini Mental State Examination score.}}
    {\vspace{-0.5cm}}
    {\begin{tabular}{lccccc}
    \toprule
        Dataset & Diagnosis & \ \# Samples & \ \% Female & Age & MMSE \\
        \midrule
        $\mathcal{U} = $ UK Biobank
        & N/A & 39,560 & 52.6 & 63.6 \interval{7.5} & N/A \\
        \midrule
        \multirow{3}{*}{$\mathcal{D} = $ ADNI+NIFD}
        & CN & 766 & 56.9 & 71.9 \interval{7.1}  & 29.0 \interval{1.2}\\   
        & AD & 489 & 44.2 & 74.4 \interval{7.7} & 22.0 \interval{4.1} \\
        & FTD & 50 & 28.0 & 60.8 \interval{6.3} & 24.1 \interval{5.8} \\
        \midrule
        \multirow{3}{*}{$\mathcal{T} = $ In-House}
        & CN & 143 & 46.9 & 64.2 \interval{9.9} & N/A \\     
        & AD & 110 & 50.0 & 67.3 \interval{8.4} & N/A \\
        & FTD & 76 & 50.0 & 64.6 \interval{9.4} & N/A \\
    \bottomrule
    \end{tabular}}
    }
    \vspace{-0.5cm}
\end{table}





\subsection{\ourmod}

\textbf{1. Self-Supervised Learning.}
The self-supervision task proposed in Barlow Twins (BT) de-correlates features in latent space and has been shown to be relatively robust with respect to the batch size~\cite{barlow}. 
This benefits training with 3D medical images because their large size limits batch sizes. 
Hence, BT presents a promising approach for the initial step of \ourmod.


To pre-train the feature extractor $f^\theta$ with trainable parameters $\theta$ on the unlabeled dataset $\mathcal{U}$, two different augmentations $A$ and $B$ of an input image $\mathcal{I}_{i}^\mathcal{U}$ are required.
These augmented images $A(\mathcal{I}_{i}^\mathcal{U})$ and  $B(\mathcal{I}_{i}^\mathcal{U})$ are fed into a neural network consisting of the feature extractor $f^\theta$ and a projection head $g^\theta$, yielding two output latent vectors $z^A_i = g^\theta(f^\theta(A(\mathcal{I}_{i}^\mathcal{U})))$ and $z^B_i = g^\theta(f^\theta(B(\mathcal{I}_{i}^\mathcal{U}))),  z^A_i, z^B_i \in \mathbb{R}^C$.
The model is optimized by maximizing the cross-correlation between corresponding features of different augmentations $\mathcal{C}_{cc}$ and minimizing the cross-correlation between the remaining components $\mathcal{C}_{cj}$, $j \neq c$:
\begin{equation*}
	\mathcal{L}_\text{BT} = \sum_c (1 - \mathcal{C}_{cc})^2 + \lambda_1 \sum_c \sum_{j \neq c} {\mathcal{C}_{cj}}^2 \textrm{,  \ with \ }
	\mathcal{C}_{cj} = \frac
	{\sum_i z^A_{i,c} z^B_{i,j}}
	{\sqrt{\sum_i (z^A_{i,c})^2} \sqrt{\sum_i (z^B_{i,j})^2}}
\end{equation*}
with $c = 1,\dots,C$ indices across the latent space dimension $C$, $i$ the index of a sample within the dataset $\mathcal{U}$, and $\lambda_1$ a constant hyper-parameter.
This loss makes embeddings invariant to distortions while also reducing redundant information.
We denote the resulting weights after this self-supervised pre-training step as $\theta^\prime$.


\begin{figure}[t]
\vspace{-0.5cm}
\centering
\floatconts
    {fig:pipeline}
    {\caption{Overview of the three stages of \ourmod.} }
    {\includegraphics[width=\textwidth]{graphs/methodsfigure5.png}}
\vspace{-0.5cm}
\end{figure}


\noindent
\textbf{2. Self-Distillation.}
This step requires the feature extractor $f^\theta$, with pre-trained weights $\theta = \theta^\prime$ from the previous step, as a teacher.
We freeze the teacher network $f^{\theta}$ during training to reduce the risk of over-fitting towards the task-related dataset $\mathcal{D}$.
We randomly initialize a student network $f^\psi$ with trainable parameters $\psi$ of the same architecture as the teacher, and an additional projection head $g^\psi$.
Inspired by \citet{self-distillation}, the student is trained on the task-related dataset $\mathcal{D}$ by minimizing the KL divergence $\mathcal{L}_\text{KL}$ between the outputs of the feature extractors $f^\theta(\mathcal{I}_{i}^\mathcal{D})$ and $f^\psi(\mathcal{I}_{i}^\mathcal{D})$, and minimizing the cross-entropy $\mathcal{L}_\text{CE}$ between the predictions of the student $g^\psi(f^\psi(\mathcal{I}_{i}^\mathcal{D}))$ and the related class labels $y_i^\mathcal{D}$:
\begin{equation*}
	\mathcal{L}_\text{SD} =
	\lambda_2 \mathcal{L}_\text{KL} (\mathcal{Z}^\psi \sim f^\psi(\mathcal{I}_{i}^\mathcal{D}), 	\mathcal{Z}^\theta \sim f^\theta(\mathcal{I}_{i}^\mathcal{D}))
	+
	( 1 - \lambda_2) \sum_i  \mathcal{L}_\text{CE} (g^\psi(f^\psi(\mathcal{I}_{i}^\mathcal{D})), y_i^\mathcal{D}),
\vspace{-0.2cm}
\end{equation*}
with $\mathcal{Z}^\theta$ and $\mathcal{Z}^\psi$ random variables sampled via forward passes of the samples from the dataset $\mathcal{D}$, and $\lambda_2$ a constant hyper-parameter trading off the importance of the first and second terms of $\mathcal{L}_\text{SD}$. 
The resulting weights of the student network are denoted as $\psi'$.

\noindent
\textbf{3. Fine-Tuning.}
In the final step, we optimize the student network $f^\psi$, $g^\psi$ initialized with pre-trained weights $\psi = \psi'$ from the previous step, by fine-tuning it on the in-house dataset $\mathcal{T}$ for the target task using cross-entropy loss:
\begin{equation*}
	\mathcal{L}_\text{FT} = \sum_i \mathcal{L}_\text{CE} (g^{\psi}(f^{\psi}(\mathcal{I}^{\mathcal{T}}_{i})), y^{\mathcal{T}}_i).
\end{equation*}

\section{Experiments}

\noindent
\textbf{{Pre-processing and Data Augmentation:}}\label{sec:preprocessing}
Each T1-weighted MRI scan is pre-processed using SPM\footnote{\url{https://www.fil.ion.ucl.ac.uk/spm/software/spm12}} and the VBM pipeline of CAT12~\cite{cat12}. The results are gray-matter density volumes (samples with a quality control score lower than B$-$ are discarded), which are min-max rescaled, center-cropped, and resampled % each scan %of ???1mm??? isotropic resolution 
to a spatial dimension of $55 \times 55 \times 55$ (for training convenience without sacrificing model performance). \sectionref{sec:data_augmentation} reports details about the data augmentation strategy.

\noindent
\textbf{Evaluation:}
As the target dataset $\mathcal{T}$ is relatively small, we perform 5-fold cross-validation with ratios of 65\%, 15\%, and 20\% for train, validation, and test sets, respectively, stratified by age, gender, and diagnostic labels to prevent biased results~\cite{confound}.
Additionally, we split a balanced 20\%-portion of the task-related dataset $\mathcal{D}$ to perform further evaluations for the task at hand.


\noindent
\textbf{Miscellaneous:}
Hyper-parameters for the individual training steps and search spaces of baseline methods are reported in~\sectionref{sec:hyperparams}.
We implement models with PyTorch~\cite{paszke2019pytorch} and train on one NVIDIA GeForce 3090 with 24 GByte memory. We train the model for 29,300 self-supervised iterations (24 hours), followed by 600 self-distillation iterations (2.5 hours) and 150 fine-tuning iterations with early stopping (40 minutes).

\section{Results}
As a baseline, we implement a non-deep learning approach for the differential diagnosis on $\mathcal{T}$, 
by extracting FreeSurfer~\cite{freesurfer} volume and thickness features from MRI scans to train an XGBoost classifier, which achieves a balanced accuracy (BAcc) of 66.46 $\pm$ 3.45\%.

\begin{table}[t]
    \caption{Mean, standard deviation, and pairwise p-values of the balanced accuracy (BAcc), true positive rate per class (TPR), and macro-F1 score (F1) across splits for 3-class differential diagnosis.}
    \vspace{-0.2cm}
    \label{tab:results}
\scriptsize
    \centering
    {\setlength{\tabcolsep}{0.15em}
    %\begin{tabular}{l|ccc|cc|ccc|c}
    \begin{tabular}{lccccccccc|c}
    \toprule
        Training Strategy  & $\mathcal{U}$ & $\mathcal{D}$ & $\mathcal{T}$ &  $\textrm{BAcc}_\mathcal{T}$ & $\emph{p}$-value  & $\textrm{TPR}_\textrm{CN}$ & $\textrm{TPR}_\textrm{AD}$ & $\textrm{TPR}_\textrm{FTD}$ & $\textrm{F1}_\mathcal{T}$ & $\textrm{BAcc}_\mathcal{D}$ \\
        \midrule
        Supervised & & & \cmark & 67.15 \interval{5.36} & 0.011  & 69.9 & 65.5 & 65.8 & 66.94 \interval{5.52}  & - \\
        % &&&&&&&&&&\\
        Supervised & & \cmark & \cmark & 68.44 \interval{4.63} & 0.016  & 79.7 & 66.4 & 59.2 & 69.78 \interval{4.26} & 78.2 \\
        % &&&&&&&&&&\\
        Self-Supervised \scriptsize{(SimCLR)} & \cmark & & \cmark & 63.47 \interval{4.38} & 0.001  & \textbf{86.0} & 50.0 & 54.0 & 64.44 \interval{4.13} & - \\
        \cite{simclr} &&&&&&&&&&\\
        Self-Supervised \scriptsize{(VICReg)} & \cmark & & \cmark & 68.94 \interval{3.42} & 0.012  & 72.7 & 70.0 & 64.5 & 69.22 \interval{2.78} & - \\
        \cite{vicreg} &&&&&&&&&&\\
        Self-Supervised \scriptsize{(DiRA)} & \cmark & & \cmark & 66.78 \interval{0.89} & 0.001 & 80.4 & 60.9 & 59.2 & 67.21 \interval{2.03} & - \\
        \cite{dira} &&&&&&&&&&\\
        Self-Supervised \scriptsize{(BT)} & \cmark & & \cmark & 71.36 \interval{4.18} & 0.072  & 79.7 & 68.2 & 65.8 & 72.24 \interval{3.78} & - \\
        \cite{barlow} &&&&&&&&&&\\
        \midrule
        \ourmod\ (Ours) & \cmark & \cmark & \cmark & \textbf{75.57} \interval{3.62} & - & 81.8 & \textbf{71.8} & \textbf{73.7} & \textbf{75.32} \interval{4.51} & \textbf{85.6} \\
    \bottomrule
    \end{tabular}
    }
    \vspace{-0.5cm}
\end{table}

As seen in \tableref{tab:results}, training a DNN on the target dataset $\mathcal{T}$ alone results in a BAcc of 67.15 $\pm$ 5.36\%, which is likely due to overfitting on the small target dataset.
Pre-training the model on the task-related dataset $\mathcal{D}$ improves the performance only marginally by 1.29\%. 
Pre-training with unlabeled $\mathcal{U}$ with established SSL methods (SimCLR~\cite{simclr}, VICReg~\cite{vicreg}, DiRA~\cite{dira}, and Barlow Twins~\cite{barlow}) and then fine-tuning on $\mathcal{T}$ outperforms supervised pre-training on $\mathcal{D}$ by 2.92\% (with Barlow Twins). 
\ourmod, which adds a self-distillation step on $\mathcal{D}$ after self-supervised pre-training, outperforms all competing approaches on the target dataset (with $p < 0.05$ for all comparisons except Barlow Twins pre-training alone, $p = 0.072$), achieving a BAcc of 75.57 $\pm$ 3.62\% with the highest true positive rates for both types of dementia (see \tableref{tab:results}). 

Additionally, we evaluate \ourmod\ on the hold-out test set of $\mathcal{D}$ after self-distillation on $\mathcal{D}$, which clearly outperforms (+7.4\%) supervised training on $\mathcal{D}$ alone (denoted as $\textrm{BAcc}_\mathcal{D}$ in Table~\ref{tab:results}).
This indicates that \ourmod\ potentially mitigates overfitting when training with limited data and thus extracts features that generalize well.

%\section{Discussion}
\begin{figure}[t]
    \vspace{-0.6cm}
    \centering
    \subfigure[Pre-training on $\mathcal{U}$]{
      \includegraphics[width=40mm]{graphs/BT_colorblind.png}
      \label{fig:sub4}
    }
    \subfigure[Self-distillation on $\mathcal{D}$]{
      \includegraphics[width=40mm]{graphs/SD_colorblind.png}
      \label{fig:sub5}
    }
    \subfigure[Fine-tuning on $\mathcal{T}$]{
      \includegraphics[width=40mm]{graphs/legend_TT_colorblind.png}
      \label{fig:sub6}
    }
    \subfigure[Pre-training on $\mathcal{U}$ (only $\mathcal{U}$ colored)]{
      \includegraphics[width=40mm]{graphs/BT_colorblind_dataonly.png}
      \label{fig:sub1}
    }
    \subfigure[Self-distillation on $\mathcal{D}$ (only $\mathcal{D}$ colored)]{
      \includegraphics[width=40mm]{graphs/SD_colorblind_dataonly.png}
      \label{fig:sub2}
    }
    \subfigure[Fine-tuning on $\mathcal{T}$ (only $\mathcal{T}$ colored)]{
      \includegraphics[width=40mm]{graphs/TT_colorblind_dataonly.png}
      \label{fig:sub3}
    }
    \caption{UMAP visualization of the changes in latent space for all datasets (first row) and for the dataset used in the respective step only (second row) after each step of \ourmod. $\mathcal{U}$: No label (\textcolor{mpurple}{purple}, representative fraction of samples to improve readability); Task-related $\mathcal{D}$: CN (\textcolor{mdarkblue}{dark blue}), AD (\textcolor{mred}{red}), FTD (\textcolor{mdarkgrey}{dark grey}); In-house $\mathcal{T}$: CN (\textcolor{mlightblue}{light blue}), AD (\textcolor{morange}{orange}), FTD (\textcolor{mlightgrey}{light grey}).}
    \label{fig:latent_space}
    \vspace{-0.6cm}
\end{figure}
\paragraph{Visualization of the latent space.}
We argue that the high accuracy of \ourmod\ is rooted in decision boundaries of the classifier that are less population dependent.
Therefore, we plot the evolution of the latent features of all three datasets $\mathcal{U}$, $\mathcal{D}$ and $\mathcal{T}$ after each step in \ourmod\ with UMAP~\cite{mcinnes2018umap}, visualized in~\figureref{fig:latent_space}.
After self-supervised pre-training on $\mathcal{U}$ only, all samples of different classes from the three datasets are mixed together. 
After self-distillation on $\mathcal{D}$, there is a trend of separation between CN, AD, and FTD samples from all datasets. 
The unlabeled samples drawn from $\mathcal{U}$ display considerable overlap with the CN samples, which aligns with expectations as the majority of the UK Biobank samples consist of healthy individuals.
Furthermore, the final features extracted after full \ourmod\ are well separated for each class without dataset dependence, with a particularly clean cluster of FTD samples from $\mathcal{D}$ and $\mathcal{T}$.
Moreover, CN and AD samples of $\mathcal{D}$ maintain a clear separation, indicating that the network did not unlearn the previous knowledge while fitting on the new domain.
This property is crucial in continual learning and domain adaptation, showing that \ourmod\ generalizes well even with limited data available for the target task.



\paragraph{Ablation Study 1: Hyper-parameters.}

As shown in the original work \cite{barlow}, Barlow Twins is relatively robust to the batch size.
However, the evaluated batch sizes up to 4,096 are infeasible when working with volumetric images.
Thus, we examine the robustness of \ourmod\ w.r.t. batch sizes typically used in DNNs for medical image analysis.
As seen in~\figureref{fig:batchsize}, \ourmod\ consistently surpasses both supervised training on $\mathcal{T}$ and pre-training 
on $\mathcal{D}$, $\mathcal{T}$ across all batch sizes, with a batch size of 128 (used for all experiments) performing marginally better than the other batch sizes. 
Evidently, \ourmod\ benefits from a moderate increase in batch size and surpasses all competing methods regardless of the batch size, demonstrating considerable robustness to the batch size variation.
\figureref{fig:lambdasize} 
shows that \ourmod\ outperforms the baseline methods for a wide range of $\lambda_2$, the constant hyper-parameter that weights the KL-divergence term during self-distillation. 


\begin{figure}[t]
\vspace{-0.5cm}
\centering
\vspace{-0.3cm}
    \subfigure[BAcc for different batch sizes.]{
    \centering
    \includegraphics[width=0.4\textwidth]{graphs/batch_size_ablation.pdf}
    \label{fig:batchsize}
    }
    \subfigure[BAcc for different values of $\lambda_2$.]{
    \centering
    \includegraphics[width=0.4\textwidth]{graphs/lambda_ablation.pdf}
    \label{fig:lambdasize}
    }
    \vspace{-0.3cm}
\caption{Ablation studies of hyper-parameters in Triplet Training.}
\vspace{-0.5cm}
\end{figure}


\paragraph{Ablation Study 2: Benchmark Self-Supervised Approaches.}
We replace the SSL algorithm (Barlow Twins) in the initial step of \ourmod\ with three SOTA algorithms.
\tableref{tab:ablation_selfsupervised} reports that \ourmod\ showcases high and consistent accuracy across all SSL methods, highlighting its robustness and generalizability.
Among them, Barlow Twins and SimCLR demonstrate the best performance, and introduce few additional hyper-parameters compared to the other methods.
We argue that Barlow Twins is the optimal choice, as it has been shown to be robust to the batch size.
\vspace{-0.3cm}
\begin{table}[h!]
    \caption{Mean and standard deviation of the balanced accuracy (BAcc), true positive rate (TPR), and macro-F1 score (F1) for different SSL approaches in the initial step of \ourmod. We propose to use Barlow Twins (BT) in \ourmod.}
    \label{tab:ablation_selfsupervised}
\scriptsize
    \centering
    \vspace{-0.3cm}
     {\setlength{\tabcolsep}{0.4em}
    %\begin{tabular}{l|cc|ccc|c}
    \begin{tabular}{lccccc|c}
    \toprule
        SSL in Triplet Training & $\textrm{BAcc}_\mathcal{T}$  & $\textrm{TPR}_\textrm{CN}$ & $\textrm{TPR}_\textrm{AD}$ & $\textrm{TPR}_\textrm{FTD}$ & $\textrm{F1}_\mathcal{T}$ &  $\textrm{BAcc}_\mathcal{D}$\\
        \midrule
        SimCLR~\cite{simclr} & 75.22 \interval{2.80} & \textbf{86.7} & 69.1 & 69.7 & \textbf{75.64} \interval{2.74} & \textbf{86.0}  \\
        VICReg~\cite{vicreg} & 73.44 \interval{4.92} & 83.9 & 69.1 & 67.1 & 74.15 \interval{4.91} & 85.5 \\
        DiRA~\cite{dira} & 74.49 \interval{4.14} & 86.7 & 65.5 & 71.1 &  74.85 \interval{4.03} & 85.4 \\ 
        BT~\cite{barlow} & \textbf{75.57} \interval{3.62} & 81.8 & \textbf{71.8} & \textbf{73.7} & 75.32 \interval{4.51} & 85.6 \\ 
    \bottomrule
    \end{tabular}
     }
    \vspace{-0.8cm}
\end{table}


\section{Conclusion}
We introduced \ourmod\ for differential diagnosis of dementia, which enhances predictive performance for tasks with limited data availability.
\ourmod\ consists of three steps that fully utilize large-scale unlabeled data, task-related data, and limited amounts of target data, achieving a BAcc of 75.6\% on a well-characterized clinical dataset while showing strong generalizability.
Ablation studies confirmed \ourmod's robustness against varying hyper-parameters and method selection in the initial step.


\clearpage


% Acknowledgments---Will not appear in anonymized version
\midlacknowledgments{This research was supported by the Federal Ministry of Education and Research in the call for Computational Life Sciences (DeepMentia, 031L0200A) and the DFG. This research was conducted using the UK Biobank Resource. 
The authors gratefully acknowledge the Leibniz Supercomputing Centre for funding this project by providing computing time on its Linux-Cluster.}

\bibliography{midl24_22}

\clearpage
\appendix

\section{Architecture}\label{sec:architecture}

\begin{figure}[h]
    \centering
    \includegraphics[width=0.99\textwidth]{graphs/backbone.png}
    \caption{We select a 3D ResNet as the feature extractor $f$ for all models. It consists of six residual blocks, each comprising two convolutional layers followed by batch normalization and ReLU non-linearity. The last five residual blocks each start with a convolutional layer with stride two.}
    \label{fig:backbone}
\end{figure}

\begin{figure}[h!]
    \centering
    \subfigure[]{
        \includegraphics[width=30mm]{graphs/ph2.png}
        \label{fig:ssl-ph}
    }
    \subfigure[]{
        \includegraphics[width=30mm]{graphs/ph1.png}
        \label{fig:sd-ft-ph}
    }
    \caption{Projection head $g$ for: (a) self-supervision (Barlow Twins); (b) self-distillation and fine-tuning.}
\end{figure}

\section{Training Details}\label{sec:training_details}
\subsection{Hyper-parameters}\label{sec:hyperparams}

\begin{table}[h]
    \caption{Hyper-parameters of the different training strategies. The number of iterations for each step is chosen based on convergence on the validation set. If available, we use the hyper-parameters proposed in the original work.}
    \label{tab:hyperparams}
    \centering
    {\setlength{\tabcolsep}{1em}
    \begin{tabular}{llc}
    \toprule
    Training Strategy & Hyper-Parameter & Value \\\midrule
    \multirow{4}{*}{Supervised Training ($\mathcal{T}$)} & Learning rate & 0.01 \\
    & Weight decay & 0.00001 \\
    & Batch size & 64 \\
    & Training iterations & 150 \\
    \midrule
    \multirow{4}{*}{Supervised Pre-Training ($\mathcal{D}$)} & Learning rate & 0.01 \\
    & Weight decay & 0.0000015 \\
    & Batch size & 128 \\
    & Training iterations & 600 \\
    \midrule
    \multirow{5}{*}{Triplet Training (Self-Supervision)} & Learning rate & 0.5 \\
    & Weight decay & 0.0000015 \\
    & Batch size & 128 \\
    & Training iterations & 29,300 \\ % 200 epochs
    & $\lambda_1$ & 0.005 \\
    \midrule
    \multirow{5}{*}{Triplet Training (Self-Distillation)} & Learning rate & 0.01 \\
    & Weight decay & 0.0000015 \\
    & Batch size & 128 \\
    & Training iterations & 600 \\ % 400 epochs
    & $\lambda_2$ & 0.001 \\
    \midrule
    \multirow{4}{*}{Triplet Training (Fine-Tuning)} & Learning rate & 0.0005 \\
    & Weight decay & 0.00001 \\
    & Batch size & 64 \\
    & Training iterations & 150 \\ % 200 epochs
    \bottomrule   
    \end{tabular}
    }
\end{table}


\subsection{Data Augmentation}\label{sec:data_augmentation}
\begin{table}[h]
    \caption{Data augmentations used in \ourmod.}
    \label{tab:data_augmentation}
    \centering
    {\setlength{\tabcolsep}{0.5em}
    \begin{tabular}{lcc}
    \toprule
    Training Strategy & Augmentation & Values \\
    \midrule
     \multirow{9}{*}{Self-Supervision} &Rescale Intensity & intensity range = (0, 1) \\
    \cmidrule{2-3}
    &\multirow{3}{*}{Random Cropping with Resizing} & crop scale = (0.5, 1.0) \\
    && output size = (55, 55, 55) \\
    && random center = True \\
    \cmidrule{2-3}
    &\multirow{2}{*}{Random Flipping} & axes = (0, 1, 2) \\
    && probability = 0.5 \\
    \cmidrule{2-3}
    &\multirow{3}{*}{Random Affine Transformation} & rotation range = ($-90$\degree, $+90$\degree) \\
    && translation range = ($-8$ pixel, $+8$ pixel) \\
    && probability = 0.5 \\\midrule
    \multirow{4}{*}{Self-Distillation} &Rescale Intensity & intensity range = (0, 1) \\
    \cmidrule{2-3}
    &\multirow{3}{*}{Random Affine Transformation} & rotation range = ($-8$\degree, $+8$\degree) \\
    && translation range = ($-8$ pixel, $+8$ pixel) \\
    && probability = 0.5 \\
    \midrule
    \multirow{4}{*}{Fine-Tuning} &Rescale Intensity & intensity range = (0, 1) \\
    \cmidrule{2-3}
    &\multirow{3}{*}{Random Affine Transformation} & rotation range = ($-8$\degree, $+8$\degree) \\
    && translation range = ($-8$ pixel, $+8$ pixel) \\
    && probability = 0.5 \\
    \bottomrule
    \end{tabular}}
    \vspace{-0.5cm}
\end{table}

\end{document}
