% This is samplepaper.tex, a sample chapter demonstrating the
% LLNCS macro package for Springer Computer Science proceedings;
% Version 2.21 of 2022/01/12
%

\documentclass[runningheads]{llncs}
%
\usepackage{xcolor}
\usepackage[T1]{fontenc}
% T1 fonts will be used to generate the final print and online PDFs,
% so please use T1 fonts in your manuscript whenever possible.
% Other font encodings may result in incorrect characters.
%
\usepackage{graphicx}
% Used for displaying a sample figure. If possible, figure files should
% be included in EPS format.
%
% If you use the hyperref package, please uncomment the following two lines
\usepackage[colorlinks=true,allcolors=blue]{hyperref}
% to display URLs in blue roman font according to Springer's eBook style:
\usepackage{color}

\usepackage{orcidlink}

\renewcommand\UrlFont{\color{blue}\rmfamily}
\urlstyle{rm}
%
\begin{document}
%
\title{Explainable Few-Shot Learning for Multiple Sclerosis Detection in Low-Data Regime}
%
\titlerunning{Explainable FSL for MS detection in Low-Data Regime}
% If the paper title is too long for the running head, you can set
% an abbreviated paper title here
%
\author{
Montassar Ben Dhifallah\inst{1,6}\orcidlink{0009-0005-4637-4678}\thanks{ {corresponding author: \href{mailto:montassar.bendhifallah@issatso.u-sousse.tn}{montassar.bendhifallah@issatso.u-sousse.tn}}} \and
Dalel Kanzari\inst{1,2}\orcidlink{0000-0002-1206-5675} \and 
Selma Naija\inst{3,4}\orcidlink{0000-0001-7235-2169} \and
Sana Ben Amor\inst{3,4}\orcidlink{0000-0001-9936-8531} \and
Ahmed Zrig\inst{5,7}\orcidlink{0009-0000-1386-675X} \and 
Mezri Maatouk\inst{5,7}\orcidlink{0000-0002-9108-1509} \and 
Mabrouk Abdelaali\inst{5,7}\orcidlink{0000-0002-0870-0736} \and 
Jamel Saad\inst{5,7}\orcidlink{0000-0001-5313-4289} \and 
Asma Achour\inst{5,7}\orcidlink{0000-0002-5519-1743} \and
Sofiane Gaied Chortane\inst{5,7} \and
Maher Hadhri\inst{6,7}\orcidlink{0000-0002-3034-140X} \and
Ahmed Dahmoul\inst{5,7} \and 
Azza Ben Ali\inst{5,7} \and
Sahar Selim\inst{8}\orcidlink{0000-0002-9886-1364} \and
Ahmed Nebli\inst{9}\orcidlink{0000-0003-4565-4502}
}
%index{Ben Dhifallah, Montassar}
%index{Kanzari, Dalel}
%index{Naija, Selma}
%index{Ben Amor, Sana}
%index{Zrig, Ahmed}
%index{Maatouk, Mezri}
%index{Abdelaali, Mabrouk}
%index{Saad, Jamel}
%index{Achour, Asma}
%index{Gaied Chortane, Sofiane}
%index{Hadhri, Maher}
%index{Dahmoul, Ahmed}
%index{Ben Ali, Azza}
%index{Selim, Sahar}
%index{Nebli, Ahmed}


\authorrunning{
M. Ben Dhifallah et al.
}
% First names are abbreviated in the running head.
% If there are more than two authors, 'et al.' is used.
%
\institute{
Higher Institute of Applied Sciences and
Technology, University of Sousse, Tunisia
\and
Operational Research, Decision and Process Control Laboratory (LARODEC), 41 Liberty Street, Bardo, 2000, Tunis, Tunisia
\and
Sahloul Hospital, Department of Neurology, Sousse, Tunisia
\and
University of Sousse, Faculty of Medicine of Sousse, Tunisia \and
Department of Radiology A, Fattouma Bourguiba Hospital, Monastir, Tunisia \and
Department of Neurosurgery, Fattouma Bourguiba Hospital, Monastir, Tunisia \and
Research Unity Interventional radiology LR18SP08, University of Monastir, Tunisia \and
School of Information Technology and Computer Science, Nile University, Giza, Egypt \and
Independent Researcher
}
%
\maketitle              % typeset the header of the contribution
%
\begin{abstract}



Diagnosing multiple sclerosis (MS) accurately is highly challenging due to symptom overlap with other demyelinating diseases. Here, we present DemyeliNeXt, an explainable few-shot learning framework designed to classify MS and other demyelinating diseases from MRI scans. The framework employs a prototypical network with a 3D DenseNet-121 backbone and uses Deep SHAP to visualize feature importance. We train DemyeliNeXt on a dataset from African populations and test it on several datasets, including the public MICCAI MSSEG-2 dataset. Our findings demonstrate robust performance across diverse datasets, highlighting the model's potential to enhance diagnostic accuracy and generalizability in various clinical settings.




\keywords{Few-Shot Learning \and Explainable AI \and Multiple Sclerosis \and 3D MRI \and Deep Learning}


\end{abstract}


\section{Introduction}


Multiple sclerosis (MS) is a complex neurological condition that is often misdiagnosed due to its symptom overlap with other conditions such as vasculitis and vascular leukoencephalopathy. Studies indicate that over half of misdiagnosed patients carried the incorrect diagnosis for more than three years \cite{gaitan2019multiple,solomon2016contemporary}. Moreover, 70\% of these patients had been administered disease-modifying therapies (DMTs), and 31\% suffered unnecessary morbidity due to the incorrect diagnosis and treatment \cite{gaitan2019multiple,solomon2016contemporary}. This diagnostic challenge prolongs the time needed to reach a definitive diagnosis, which often exceeds several months. Hence, accurate and timely diagnosis is crucial for effective management and treatment planning in MS patients. Advanced imaging techniques and biomarker analyses are increasingly important in differentiating MS from similarly presenting conditions, thereby reducing diagnostic errors and improving patient outcomes. Machine learning provides a robust approach for the analysis of medical images and the diagnosis of MS.

In this context, several studies have employed machine learning models for MS classification. For instance, Wang et al. \cite{wang_multiple_2018} employed a multi-layer convolutional neural network (CNN) with data augmentation techniques to classify MS. However, the model's explainability remains unexplored. To address this issue, Zhang et al. \cite{zhang_grad-cam_2021} proposed a classification model for MS subtypes based on VGG19 \cite{Simonyan_VGG_2014} with global average pooling and utilized Grad-CAM++ \cite{Chattopadhay_Grad-CAM++_2018} for model explanation. While effective in performance and interpretability, this approach did not account for the diversity of MS data; in particular, it was not evaluated against similar demyelinating diseases such as vasculitis. To address this concern, Huang et al. \cite{huang_transformer-based_2022} leveraged a Transformer-based model with a Multiple Instance Learning (MIL) strategy to discriminate between MS and various demyelinating diseases. The authors used Grad-CAM to visualize feature extraction through activation heatmaps. Nevertheless, their study did not incorporate data from low-income countries, such as datasets from African populations. This omission underscores a critical gap, as regional genetic and environmental factors influence disease onset and progression \cite{waubant2019environmental}. These factors affect the timeliness and accuracy of MS diagnosis and can thereby threaten the patient's life.

Additionally, collecting data on MS and other demyelinating diseases is challenging due to the variability in disease presentation, limited patient availability, and the high cost of medical imaging. Few-shot learning is therefore well suited to leveraging the limited data effectively. Furthermore, a key finding in MS identification is the presence of white matter lesions in the brain, detectable on the Fluid-Attenuated Inversion Recovery (FLAIR) MRI sequence.

This study focuses on distinguishing MS from other demyelinating diseases. We introduce DemyeliNeXt, an explainable few-shot learning framework for the classification of MS and other demyelinating diseases. Our approach employs a prototypical network with a 3D DenseNet-121 backbone, which integrates spatial information from FLAIR magnetic resonance (MR) images to classify them as MS vs.\ other demyelinating diseases (NON-MS). Additionally, the framework provides model interpretability through Deep SHAP, which visualizes the most important features leading to the classification of the input MRI. The primary contributions of our work are as follows:
\begin{enumerate}
    \item Application of Few-Shot Learning: We apply few-shot learning for the detection of multiple sclerosis (MS). 
    \item Emphasis on Explainability: Our method integrates explainability mechanisms to enhance interpretability, making it more suitable for clinical settings. 
    \item Utilization of African 3D MRI Data: We trained our model using 3D MRI data from African populations, which are often underrepresented in medical datasets. By benchmarking our model on the public MICCAI MSSEG-2 dataset, we demonstrated its robust performance, thereby validating its generalizability across diverse populations.
\end{enumerate}


\section{Proposed Method}
In this section, we explain the key building blocks of our proposed DemyeliNeXt architecture for the explainable differentiation of MS from other demyelinating diseases.

\begin{figure}[h]
\centering
\includegraphics[width=\linewidth]{figures/fig1-v9.pdf}
\caption{\label{fig:method} \textit{DemyeliNeXt Pipeline.} (A) Preprocessing MRI scans: includes skull stripping, bias correction, normalization, and FLAIR MRI smoothing. (B) Data splitting into support and query sets. (C) Training a prototypical network with a 3D DenseNet-121 backbone. (D) Model testing on unseen MRIs with explanations provided using Deep SHAP.}

\end{figure}

\subsection{Architecture overview}

In this study, we introduce DemyeliNeXt, a four-stage pipeline designed for the classification of multiple sclerosis (MS) and other demyelinating diseases from MRI scans, while also providing model interpretability (Figure~\ref{fig:method}). The first stage (Section 2.2) is a preprocessing pipeline for FLAIR MRI scans, in which raw FLAIR images are normalized and noise and artifacts are reduced. In the second stage, the MRI scans are divided into training, validation, and testing sets. Each set contains a support set ($S$) with labeled examples to update model parameters and a query set ($Q$) with unlabeled examples for performance evaluation.

The third stage (Section 2.3) involves training a 3D DenseNet-based (DenseNet-121) \cite{Huang_CVPR_densenet_2017} prototypical network to classify the preprocessed MRIs. The training process utilizes $N^{tr}$ training tasks, each comprising $N_{shots}$ support examples for model weight updates and $N_{query}$ query examples for performance assessment. In the final stage, we employ Deep SHAP \cite{Lundberg_SHAP_2017} to approximate the model's behavior for interpretability. Deep SHAP, which builds on DeepLIFT \cite{Shrikumar_DeepLIFT_2017}, assigns importance scores to each input feature by propagating neuron contributions backward through the network. These scores are based on the difference from a reference input, known as the ``baseline'' or ``background'' input, representing a typical or neutral state for the input features. The importance scores are computed from the model's weights, the actual input, and the baseline input. After training the explainer, we use the model and explainer to predict and interpret new examples of MS and other demyelinating diseases during inference.

\subsection{Preprocessing Pipeline}

We begin our preprocessing pipeline by anonymizing the DICOM MRI scans and converting them to NIfTI format. This process removes patient metadata and consolidates each volume into a single file. Next, we perform skull stripping using the ROBEX algorithm \cite{iglesias_robust_2011} to eliminate non-brain tissues. We then apply bias field correction using the N4ITK algorithm \cite{Tustison_N4ITK_2010} to remove low-frequency intensity non-uniformities. Following this, we normalize MRI intensities to a range of 0 to 1 and reduce noise using a Gaussian filter. Finally, we reorient the images to the ``IPL'' (Inferior, Posterior, Left) orientation, resample them to isotropic voxels, and resize them to a standard format.
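
As an illustration of the intensity normalization step, a min-max rescaling would map each voxel intensity to the unit interval. This is only a sketch, assuming that normalization is performed per volume using that volume's own extrema; the symbols $x_i$, $x_{\min}$, and $x_{\max}$ are introduced here for illustration:
\begin{equation}
\tilde{x}_i = \frac{x_i - x_{\min}}{x_{\max} - x_{\min}}, \qquad \tilde{x}_i \in [0, 1],
\end{equation}
where $x_i$ is the intensity of voxel $i$ after bias field correction, and $x_{\min}$ and $x_{\max}$ are the smallest and largest intensities in the volume. For example, a voxel with $x_i = 300$ in a volume with $x_{\min} = 0$ and $x_{\max} = 1200$ would be mapped to $\tilde{x}_i = 0.25$.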

\subsection{Few-shot learning}
\subsubsection{Prototypical network.}



Prototypical Networks (ProtoNet) \cite{Snell_protonet_2017} seek to learn a metric space in which samples from the same class lie close to one another. This approach makes the model particularly useful in settings with limited labeled data. Based on the prototype concept \cite{Snell_protonet_2017}, the model represents each class by the mean of its embedded support set $S$. Prototypical Networks then classify query samples $Q$ based on their proximity to these prototypes.
To generate the image embeddings, we use a 3D DenseNet-121 \cite{Huang_CVPR_densenet_2017} as a backbone. We employ the Euclidean distance to measure how far each query embedding lies from the class prototypes computed from the support samples. We create dataset episodes using a sampler that draws data for each label from the dataset according to a uniform distribution.
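
Following the formulation of \cite{Snell_protonet_2017}, the prototype computation and the resulting class probabilities can be sketched as below. The notation is introduced here for illustration ($f_\phi$ for the backbone embedding, $S_k$ for the support examples of class $k$, $\mathbf{c}_k$ for the prototype, and $d$ for the distance), and we assume the softmax over negative squared Euclidean distances used in the original ProtoNet:
\begin{equation}
\mathbf{c}_k = \frac{1}{|S_k|} \sum_{(\mathbf{x}_i, y_i) \in S_k} f_\phi(\mathbf{x}_i),
\end{equation}
\begin{equation}
p(y = k \mid \mathbf{x}) = \frac{\exp\bigl(-d(f_\phi(\mathbf{x}), \mathbf{c}_k)\bigr)}{\sum_{k'} \exp\bigl(-d(f_\phi(\mathbf{x}), \mathbf{c}_{k'})\bigr)}, \qquad d(\mathbf{u}, \mathbf{v}) = \|\mathbf{u} - \mathbf{v}\|_2^2 .
\end{equation}
Under this formulation, a query scan is assigned to the class whose prototype is nearest in the embedding space.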

\subsubsection{Loss function}
We use binary cross-entropy loss:
\begin{equation}
\mathcal{L} = -\left[ y \log(p) + (1 - y) \log(1 - p) \right]
\end{equation}
where $y$ and $p$ are the MS label and the model's predicted probability of MS, respectively. We use Adam \cite{kingma2014adam} as the optimizer, with a step learning-rate scheduler to decay the learning rate.
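
To connect this loss to the prototypical network, one possible reading is that $p$ reduces to a logistic function of a distance difference. This assumes that the class probabilities come from a softmax over negative distances to the two class prototypes, as in \cite{Snell_protonet_2017}. The symbols $d_{\mathrm{MS}}$ and $d_{\mathrm{NON}}$, denoting the distances from the query embedding to the MS and NON-MS prototypes, are notation introduced for this sketch:
\begin{equation}
p = \frac{e^{-d_{\mathrm{MS}}}}{e^{-d_{\mathrm{MS}}} + e^{-d_{\mathrm{NON}}}} = \frac{1}{1 + e^{\,d_{\mathrm{MS}} - d_{\mathrm{NON}}}} .
\end{equation}
For example, a query with $d_{\mathrm{MS}} = 1$ and $d_{\mathrm{NON}} = 3$ gives $p = 1/(1 + e^{-2}) \approx 0.88$. For a true MS scan ($y = 1$), the loss with the natural logarithm is then $-\log(0.88) \approx 0.13$.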

\subsection{Explainability with Deep SHAP}


Deep SHAP \cite{Lundberg_SHAP_2017} approximates explanations for deep neural network models using SHAP (SHapley Additive exPlanations) values to quantify feature importance. This method combines DeepLIFT \cite{Shrikumar_DeepLIFT_2017}, a deep learning explanation technique, with Shapley values \cite{shapley_values_1953}. We apply Deep SHAP to interpret our trained 3D DenseNet-based ProtoNet model using preprocessed MRI scans from the testing dataset. This approach creates a simplified explanation model that assesses the importance of each voxel in our testing MRIs, which we visualize through feature importance plots.
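
The additivity property underlying SHAP explanations \cite{Lundberg_SHAP_2017} can be stated compactly. In the sketch below, the notation is introduced for illustration: $f$ denotes the model output for one class, $\mathbf{x}$ an input scan with $M$ voxels, $B$ the set of background examples, and $\phi_i$ the attribution of voxel $i$. For Deep SHAP the relation should be read as holding up to approximation error:
\begin{equation}
f(\mathbf{x}) \approx \phi_0 + \sum_{i=1}^{M} \phi_i, \qquad \phi_0 = \frac{1}{|B|} \sum_{\mathbf{b} \in B} f(\mathbf{b}) .
\end{equation}
A positive $\phi_i$ pushes the output for that class above the background average, whereas a negative $\phi_i$ pushes it below.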




\subsection{Model inference and explanation}
After training and evaluating the model, we perform inference on unseen examples and pass them to the explainer to inspect which features the model relied on to classify them.


\section{Results and discussion}

In this section, we provide a quantitative evaluation of our model on three distinct datasets and present the Deep SHAP explanations.

\subsection{Employed datasets} 
In this work, we utilized three labeled datasets, summarized in Table~\ref{tab:dataset-desc}. We trained, validated, and tested our model on a set of 182 FLAIR MRI scans from 121 patients with multiple sclerosis (MS) or other demyelinating diseases (NON-MS). The dataset was split randomly and patient-wise into three sets: 70\% for training, 15\% for validation, and 15\% for testing. This dataset was sourced from the radiology department at CHU Fattouma Bourguiba Monastir (FBM), Tunisia. It includes 3D and axial scans: 91 scans from 52 MS patients and 91 scans from 69 patients with other demyelinating diseases such as vasculitis and vascular leukoencephalopathy.


We tested our model on a set containing 91 FLAIR MRI scans from 36 MS patients, obtained from the MRI center of CHU Sahloul Sousse (SS), Tunisia. Additionally, we used 80 3D FLAIR MRI scans from 40 patients in the MICCAI 2021 MS Segmentation Challenge (MSSEG-2) as a benchmark dataset. We randomly sampled data from each set to create episodes consisting of a support set and a query set. Prior to training, gamma correction was applied to all scans using $\gamma = 2.5$. No further data augmentation was performed.
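
For reference, gamma correction on intensities normalized to $[0, 1]$ is commonly written as a power law. The form below is an assumption, since the convention (exponent $\gamma$ or $1/\gamma$) can differ between implementations; the symbols $I_{\mathrm{in}}$ and $I_{\mathrm{out}}$ are introduced here for illustration:
\begin{equation}
I_{\mathrm{out}} = I_{\mathrm{in}}^{\gamma}, \qquad \gamma = 2.5 .
\end{equation}
Under this convention, a mid-gray voxel with $I_{\mathrm{in}} = 0.5$ maps to $0.5^{2.5} \approx 0.18$, whereas a bright voxel with $I_{\mathrm{in}} = 0.9$ maps to $0.9^{2.5} \approx 0.77$. Darker tissue would therefore be suppressed more strongly than bright regions such as hyperintense FLAIR lesions.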

\begin{table}[h!]
\centering
\caption{Dataset statistics}
\label{tab:dataset-desc}
\begin{tabular}{|l|l|l|l|l|}
\hline
Source & Number of patients & Number of scans & Age (years) & Gender \\ \hline
CHU FBM, Tunisia & MS: 52 & MS: 91 & 21--63 & MS: 22M/30F \\
 & NON-MS: 69 & NON-MS: 91 & & NON-MS: 19M/50F \\ \hline
CHU SS, Tunisia & MS: 36 & MS: 91 & NA & 4M/32F \\ \hline
MSSEG-2 & MS: 40 & MS: 80 & NA & NA \\ \hline
\end{tabular}
\end{table}


\subsection{Experimental settings}

\subsubsection{Parameter settings} For model training, we used the Adam optimizer \cite{kingma2014adam} with a learning rate of 0.001. We decayed the learning rate by a factor of 0.1 at every step using a step scheduler. We employed dropout with a rate of 20\%. To train the Deep SHAP explainer, we used 90 background examples.
We trained our model and our explainer on an NVIDIA RTX 3090 GPU.
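
The step schedule can be sketched as follows, assuming the usual step-decay rule in which the learning rate is multiplied by a constant factor every $s$ scheduler steps. The symbols $\eta_0$, $\lambda$, $s$, and $t$ are notation introduced here, and $s = 1$ is an assumption corresponding to a decay applied at every step:
\begin{equation}
\eta_t = \eta_0 \, \lambda^{\lfloor t / s \rfloor}, \qquad \eta_0 = 10^{-3}, \quad \lambda = 0.1 .
\end{equation}
With $s = 1$, this gives $\eta_1 = 10^{-4}$ and $\eta_2 = 10^{-5}$. The unit of a scheduler step (for instance an episode or a validation round) is left unspecified in this sketch.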


\subsubsection{Hyperparameter Settings}

We conducted three distinct training experiments using 2-way ($K=2$) classification. Validation was performed with 100 episodes ($N^{val}=100$) every 500 training episodes. Testing was also conducted with 100 episodes. Each training run lasted for 1000 episodes. The hyperparameters for each experiment, together with the two additional tests, are listed below:

\begin{itemize}
    \item \textbf{Experiment A}: Trained with 5 examples in both support and query sets ($N_{shots}=5$, $N_{query}=5$).
    \item \textbf{Experiment B}: Trained with 3 examples in both support and query sets ($N_{shots}=3$, $N_{query}=3$).
    \item \textbf{Experiment C}: Trained with 1 example in both support and query sets ($N_{shots}=1$, $N_{query}=1$).
    \item \textbf{Test 1}: We used the saved model from Experiment A to test on 91 scans from the CHU SS MS dataset and on 13 scans from the CHU FBM NON-MS test set.
    \item \textbf{Test 2}: We used the saved model from Experiment A to test on 80 scans from MSSEG-2 and on 13 scans from the CHU FBM NON-MS test set.
\end{itemize}
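
To make the episode sizes concrete, and assuming that $N_{shots}$ and $N_{query}$ count examples per class (the usual few-shot convention), the number of scans in one $K$-way episode is
\begin{equation}
N_{episode} = K \, (N_{shots} + N_{query}),
\end{equation}
which gives $2 \times (5 + 5) = 20$ scans for Experiment A, $2 \times (3 + 3) = 12$ for Experiment B, and $2 \times (1 + 1) = 4$ for Experiment C. The symbol $N_{episode}$ is introduced here for illustration only.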


\begin{table}[h!]
\centering
\caption{Experimental results}
\label{tab:results}
\begin{tabular}{|l|l|l|l|l|l|l|l|}
\hline
Experiments/Tests & Accuracy & MS & NON-MS & Precision & Recall & Specificity & F1 \\
 & & specific & specific & & & & score \\
 & & accuracy & accuracy & & & & \\ \hline
A: 5 shots 5 queries & & & & & & & \\
(Dataset: CHU FBM) & 78.8\% & - & - & 0.75 & 0.87 & 0.71 & 0.8 \\ \hline
B: 3 shots 3 queries & & & & & & & \\
(Dataset: CHU FBM) & 63.83\% & - & - & 0.62 & 0.72 & 0.56 & 0.67 \\ \hline
C: 1 shot 1 query & & & & & & & \\
(Dataset: CHU FBM) & 65.0\% & - & - & 0.64 & 0.68 & 0.62 & 0.66 \\ \hline
Test 1: & & & & & & & \\
5 shots 5 queries & & & & & & & \\
(Dataset: & & & & & & & \\
CHU SS MS & & & & & & & \\
+ CHU FBM NON-MS) & 75.5\% & 68.6\% & 82.4\% & 0.8 & 0.69 & 0.82 & 0.74 \\ \hline
Test 2: & & & & & & & \\
5 shots 5 queries & & & & & & & \\
(Dataset: & & & & & & & \\
MSSEG-2 & & & & & & & \\
+ CHU FBM NON-MS) & \textbf{87.8\%} & \textbf{85\%} & \textbf{90.6\%} & \textbf{0.9} & \textbf{0.85} & \textbf{0.91} & \textbf{0.87} \\ \hline
\end{tabular}

\end{table}


\subsection{DemyeliNeXt evaluation}

Table~\ref{tab:results} shows the classification accuracy, precision, recall, specificity, and F1 scores for the different experiments detailed in Section 3.2. Across all experiments, Test 2, which involved training on an African dataset and testing on a combination of African and European datasets, achieved the highest classification accuracy. This result may indicate that our model generalizes well across different populations despite the differences in socio-economic conditions between the subjects in each of the datasets.

In contrast, Experiments B and C, which used three and one shots and queries, respectively, demonstrated the lowest performance. This indicates that reducing the number of shots below a certain threshold adversely affects model accuracy. These findings suggest that while reducing the number of shots can decrease computational demands, maintaining an adequate number of shots is critical for reliable performance (see Experiment A). In particular, the model trained in Experiment A can serve as a guide for practitioners in balancing computational efficiency with diagnostic accuracy for MS.
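
As a consistency check on the reported metrics, the standard definitions can be applied to the Test 2 row of Table~\ref{tab:results}. We assume here that MS is the positive class and that each evaluation episode contains equally many MS and NON-MS queries, so that accuracy is the mean of recall and specificity:
\begin{equation}
F_1 = \frac{2 \cdot \mathrm{Precision} \cdot \mathrm{Recall}}{\mathrm{Precision} + \mathrm{Recall}} = \frac{2 \times 0.90 \times 0.85}{0.90 + 0.85} \approx 0.87,
\end{equation}
\begin{equation}
\mathrm{Accuracy} = \frac{\mathrm{Recall} + \mathrm{Specificity}}{2} = \frac{0.85 + 0.906}{2} \approx 0.878 .
\end{equation}
Both values agree with the reported F1 score of 0.87 and accuracy of 87.8\%.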

\begin{figure}[h!]
\centering
\includegraphics[width=\linewidth]{figures/SHAP-explanation-v1.pdf}
\caption{\label{fig:NON_SHAP_MS} \textit{Deep SHAP Explanation of MS and NON-MS Examples}. \textbf{A:} Explanation of a NON-MS example. \textbf{B:} Explanation of an MS example. In each subfigure, the left panel displays an annotated MRI section of a patient with a NON-MS demyelinating disease (A) or with MS (B). The center panel highlights the features identified by our model, using Deep SHAP, for classifying the case as NON-MS. The right panel shows the features identified for classification as MS. Lesion locations are highlighted with orange rectangles across all panels. In the center and right panels, blue indicates the features excluded by the model, while red shows the important features for each class.}
\end{figure}

%%%%%%%%%%%%%%%%%%%%

Figure~\ref{fig:NON_SHAP_MS} illustrates the explanation of our model backbone on unseen MS and NON-MS examples with lesion annotation. The plot highlights the features used by our trained ProtoNet model for classification, as explained by the Deep SHAP method. We evaluated the explainer results using the key diagnostic features outlined in the McDonald criteria \cite{thompson_diagnosis_2018}, which include lesion size, number of lesions, lesion location, lesion contrast, and lesion shape.
The Deep SHAP explainer seems to identify some of the key features for classification, especially the lesions in the MS example (Fig.~\ref{fig:NON_SHAP_MS}B). However, there is a risk that clinicians could deem some of the features included in the explanation irrelevant.

\textbf{Limitations and future studies.} Despite the promising results, DemyeliNeXt has a few limitations that warrant further investigation. For instance, our approach currently uses only FLAIR MRI scans; incorporating other imaging modalities such as T1-weighted and T2-weighted MRIs could potentially enhance diagnostic accuracy. While Deep SHAP provides some level of explainability, the clinical relevance of the highlighted features remains uncertain, indicating a need for further refinement. In future studies, we aim to benchmark against state-of-the-art methods. We will also focus on expanding the dataset to include diverse minority populations, integrating multimodal imaging techniques, and developing and evaluating more clinically relevant explainability methods.

\section{Conclusion}

In this study, we introduced DemyeliNeXt, an explainable few-shot learning framework designed for the classification of multiple sclerosis (MS) and other demyelinating diseases in an African population. By incorporating Deep SHAP, we provided visual explanations for the model's decisions, enhancing its interpretability. Our findings, derived from MRI data of underrepresented African populations, suggest that this approach can generalize to non-African datasets. Although the classification accuracy decreases with fewer shots, the method remains computationally efficient and can aid practitioners in improving diagnostic accuracy. In future work, we aim to extend our framework by including more minority populations and integrating additional neuroimaging modalities, thereby enhancing the generalizability and robustness of our model.

\begin{credits}
\subsubsection{\ackname} The data collection for this study was conducted under the agreement of the head of radiology department of CHU Fattouma Bourguiba Monastir, Tunisia, the head of neurology department of CHU Sahloul Sousse, Tunisia and the director of MRI center of Sahloul, Sousse, Tunisia.

\subsubsection{Code availability.}
The code for our method is available on GitHub at \url{https://github.com/Montassar-bdh/DemyeliNeXt}.

\subsubsection{\discintname}
The authors have no competing interests to declare that are
relevant to the content of this article.
\end{credits}

\bibliographystyle{splncs04}
\bibliography{Paper-10}
\end{document}