\documentclass{midl} % Include author names

% The following packages will be automatically loaded:
% jmlr, amsmath, amssymb, natbib, graphicx, url, algorithm2e
% ifoddpage, relsize and probably more
% make sure they are installed with your latex distribution

\usepackage{mwe} % to get dummy images
\usepackage{tikz}
\usetikzlibrary{positioning} % Include the positioning library
\usepackage{array}
\usepackage{tabularx}
\usepackage{makecell}
\usepackage{subcaption}
\usepackage{overpic}
\usepackage{amsmath}
\usepackage{fancyhdr}

\jmlryear{2024}
\jmlrworkshop{Full Paper -- MIDL 2024}
\jmlrvolume{-- 302}
\editors{Accepted for publication at MIDL 2024}

\title[Segmentation Masks for Image Classification]{Improving Identically Distributed and Out-of-Distribution Medical Image Classification with Segmentation-Guided Attention in Small Dataset Scenarios}

 % Use \Name{Author Name} to specify the name.
 % If the surname contains spaces, enclose the surname
 % in braces, e.g. \Name{John {Smith Jones}} similarly
 % if the name has a "von" part, e.g \Name{Jane {de Winter}}.
 % If the first letter in the forenames is a diacritic
 % enclose the diacritic in braces, e.g. \Name{{\'E}louise Smith}

 % Two authors with the same address
 % \midlauthor{\Name{Author Name1} \Email{abc@sample.edu}\and
 %  \Name{Author Name2} \Email{xyz@sample.edu}\\
 %  \addr Address}

 % Three or more authors with the same address:
 % \midlauthor{\Name{Author Name1} \Email{an1@sample.edu}\\
 %  \Name{Author Name2} \Email{an2@sample.edu}\\
 %  \Name{Author Name3} \Email{an3@sample.edu}\\
 %  \addr Address}


% Authors with different addresses:
% \midlauthor{\Name{Author Name1} \Email{abc@sample.edu}\\
% \addr Address 1
% \AND
% \Name{Author Name2} \Email{xyz@sample.edu}\\
% \addr Address 2
% }

%\footnotetext[1]{Contributed equally}

% More complicate cases, e.g. with dual affiliations and joint authorship
% \midlauthor{
% \Name{Mariia Rizhko\nametag{$^{1,2}$}} \Email{abc@sample.edu}\\
% \addr $^{1}$ Address 1 \\
% \addr $^{2}$ Address 2 
% \Name{Author Name2\midlotherjointauthor\nametag{$^{1}$}} \Email{xyz@sample.edu}\\
% \Name{Author Name3\nametag{$^{2}$}} \Email{alphabeta@example.edu}\\
% \Name{Author Name4\midljointauthortext{Contributed equally}\nametag{$^{3}$}} \Email{uvw@foo.ac.uk}\\
% \addr $^{3}$ Address 3 \AND
% \Name{Author Name5\midlotherjointauthor\nametag{$^{4}$}} \Email{fgh@bar.com}\\
% \addr $^{4}$ Address 4
% }

\midlauthor{\\
\Name{Mariia Rizhko\nametag{$^{1}$}} \Email{mariia.rizhko@mail.utoronto.ca} \\
\addr $^{1}$ Department of Computer Science, University of Toronto \AND
\Name{Lauren Erdman\nametag{$^{1,2}$}} \Email{larunerdman1@gmail.com} \\
\addr $^{2}$ Center for Computational Medicine, Hospital for Sick Children \AND
\Name{Mandy Rickard\nametag{$^{3}$}} \Email{mandy.rickard@sickkids.ca} \\
\Name{Armando J. Lorenzo\nametag{$^{3,4}$}} \Email{armando.lorenzo@sickkids.ca} \\
\addr $^{3}$ Division of Urology, Hospital for Sick Children \\
\addr $^{4}$ Department of Surgery, University of Toronto \AND
\Name{Kunj Sheth\nametag{$^{5}$}} \Email{kunj.sheth@gmail.com} \\
\Name{Daniel Alvarez\nametag{$^{5}$}} \Email{alvarezd93@gmail.com} \\
\Name{Kyla N Velaer\nametag{$^{5}$}} \Email{kyla.velaer@gmail.com} \\
\addr $^{5}$ Stanford Children’s Health, Lucile Packard Children’s Hospital, Stanford University \AND
\Name{Megan A. Bonnett\nametag{$^{6}$}} \Email{bonnme01@luther.edu} \\
\Name{Christopher S. Cooper\nametag{$^{6}$}} \Email{christopher-cooper@uiowa.edu} \\
\addr $^{6}$ Department of Urology, University of Iowa \AND
\Name{Gregory E. Tasian\nametag{$^{7,8}$}} \Email{TasianG@chop.edu} \\
\Name{John Weaver\nametag{$^{7,8}$}} \Email{jweave2925@gmail.com} \\
\Name{Alice Xiang\nametag{$^{7,8}$}} \Email{Alice.Xiang@jefferson.edu} \\
\addr $^{7}$ Department of Surgery, University of Pennsylvania \\
\addr $^{8}$ Division of Urology, Children’s Hospital of Philadelphia \AND
\Name{Anna Goldenberg\nametag{$^{1,9}$}} \Email{nyulik@gmail.com} \\
\addr $^{9}$ Genetics and Genome Biology, Hospital for Sick Children \\
}

\begin{document}

\maketitle

\begin{abstract}
We propose a new approach for training medical image classification models using segmentation masks, particularly effective in small dataset scenarios. By guiding the model's attention with segmentation masks toward relevant features, we significantly improve accuracy for diagnosing Hydronephrosis. Evaluation of our model on identically distributed data showed either the same or better performance, with improvements of up to 0.28 in AUROC and up to 0.33 in AUPRC. Our method also generalized better than the baselines, with improvements ranging from 0.02 to 0.75 in AUROC and from 0.09 to 0.47 in AUPRC across four different out-of-distribution datasets. The results show that models trained on smaller datasets using our approach can achieve comparable results to those trained on datasets 25 times larger. The source code is available at \url{https://github.com/MeriDK/segmentation-guided-attention}.
\end{abstract}

\begin{keywords}
Medical Image Classification, Domain Generalization, Hydronephrosis.
\end{keywords}

\section{Introduction}

Can a model trained to predict medical diagnoses for young children maintain the same level of accuracy for older children? If a model is trained with data from one hospital, will it perform just as well with data from another? How precise will a model be when analyzing images produced by different machines? Machine Learning (ML) models struggle with these scenarios since they are trained on \textit{identically distributed} (i.i.d.) data, and their accuracy can vary significantly when they are tested on \textit{out-of-distribution} (OOD) data. The problem is known as \textit{domain shift} between i.i.d.\ \textit{source} and OOD \textit{target} domains. It occurs when models are not trained with domain shift in mind. This issue is significant for the medical field, where labeled data is limited, and training different models for each scenario is impractical.

Domain shift is a challenge that extends beyond healthcare. The task of addressing domain shift is known as \textit{Domain Generalization} (DG), a problem that exists in almost every application of ML \cite{zhou2022domain}. For example, in the semantic segmentation task in autonomous driving, a model trained on urban data may fail in rural settings \cite{hoffman2018cycada, ros2016synthia}, potentially leading to accidents. In personal identification systems, the model trained on well-illuminated images may not recognize a person in dim lighting \cite{sun2019learning, li2020scalable}, potentially preventing access to their home if the lights are broken. Even with seemingly simple tasks like recognizing handwritten digits, ML models can underperform due to minor variations like ink color \cite{ganin2015unsupervised}. Similarly, in the medical domain, a model trained on images collected with one protocol might be ineffective for images collected through another \cite{liu2020shape}. These examples underline the significance of the problem across different domains.

An excellent survey \cite{zhou2022domain} categorizes various DG methodologies. The \textit{Domain Alignment} \cite{li2018domain, li2018deep} methods focus on learning a mapping function between the source and target datasets. \textit{Meta-learning} \cite{li2018learning, balaji2018metareg} approaches divide data into meta-train and meta-test sets, where a model is trained on the meta-train set and evaluated on the meta-test set. Methods in the \textit{Learning Disentangled Representations} \cite{li2017deeper, ilse2020diva} category separate domain-specific and domain-agnostic features within datasets. While Domain Alignment, Meta-learning, and Learning Disentangled Representations offer promising approaches, they require a labeled target dataset during training on a source dataset. The target dataset is usually unavailable during training in the medical domain, so other approaches should be used.

The DG survey \cite{zhou2022domain} also covers the methodologies that do not require a target dataset during training. \textit{Data augmentations} \cite{volpi2018generalizing, volpi2019addressing, xu2020robust} simulate a domain shift by changing images. \textit{Ensemble learning} \cite{xu2014exploiting, cha2021swad} trains the same model with a different random seed for weight initialization or data split. \textit{Self-supervised learning} \cite{carlucci2019domain, bucci2021self} lets a model first learn generic features of the data and then fine-tunes the model for a downstream task. \textit{Regularization Strategies} \cite{wang2019learning, huang2020self} learn generalized features by focusing on global structure instead of local patterns or by masking out over-dominant features. All of these approaches are generally considered to produce more robust models. However, when trained on small datasets, which is usually the case for the medical domain, their performance might suffer on i.i.d.\ and OOD data.

\begin{figure}[ht]
    \centering
    \begin{tikzpicture}
        \node (img1) {\includegraphics[width=0.15\textwidth]{images/Screenshot 2023-12-15 at 11.01.17 AM.png}};
        \node[right=0.5cm of img1] (img2) {\includegraphics[width=0.15\textwidth]{images/Screenshot 2023-12-15 at 11.01.27 AM.png}};
        \node[right=of img2] (img3) 
        {\includegraphics[width=0.15\textwidth]{images/Screenshot 2023-12-15 at 11.20.35 AM.png}};
        \node[right=0.5cm of img3] (img4) {\includegraphics[width=0.15\textwidth]{images/Screenshot 2023-12-15 at 11.20.48 AM.png}};

        \draw[->, double, line width=0.5pt, double distance=0.5pt] (img1.east) -- (img2.west);
        \draw[->, double, line width=0.5pt, double distance=0.5pt] (img3.east) -- (img4.west);

        \node[above=0.02cm of img1] {before};
        \node[above=0.02cm of img2] {after};
        \node[above=0.02cm of img3] {before};
        \node[above=0.02cm of img4] {after};
    \end{tikzpicture}
    \caption{Illustration of our approach. Before: focus on both organ and background noise. After: targeted focus on the critical organ.} 
    \label{fig:gradcams}
\end{figure}

To address this limitation, we utilize Gradient-weighted Class Activation Mapping (Grad-CAM) \cite{selvaraju2017grad} to create attention maps. Grad-CAM is a visual explanation method in computer vision that highlights the regions in an input image that influence the model's outcome the most. Prior work showed that this attention mechanism could be learned, resulting in better performance for image segmentation \cite{li2018tell} and classification tasks \cite{fukui2019attention}. This idea has been adapted for the medical domain: it showed improved accuracy for thyroid nodules \cite{lu2022gan}, for chest X-ray abnormality localization and diagnosis \cite{ouyang2020learning}, for diagnosis of COVID-19 \cite{ouyang2020dual} and dementia \cite{lian2020attention,lian2019end}. However, these prior studies used large datasets, ranging from 2,000 MRI scans to as many as 1.2 million images from ImageNet, and they focused only on i.i.d.\ data. In contrast, we apply the idea to DG tasks and demonstrate its effectiveness on small datasets with fewer than 100 images.

Our study addresses the typical scenario in the medical field where models are trained on small datasets. Typically, these models learn specific ``useful'' noise patterns, leading to high performance on similar (i.i.d.) test datasets. However, their accuracy declines when applied to new images without these noise patterns. Beyond basic classification tasks, models should also be trained to disregard features known a priori to be irrelevant. For instance, in kidney ultrasound classification (see Figure~\ref{fig:gradcams}), the model should focus only on the kidney, ignoring background noise. Often, in medical imaging, additional information like segmentation masks is available. We adapt gradient-based techniques to utilize segmentation masks for medical imaging with small dataset scenarios. This adaptation allows us to effectively train models on small datasets and improve their performance when tested on i.i.d.\ and OOD data without having target datasets during training.


\section{Method}

% \begin{figure}[htbp]
%  % Caption and label go in the first argument and the figure contents
%  % go in the second argument
% \floatconts
%   {fig:arch}
%   {\caption{Attention-Guided Binary Classification Loss}}
%   {\includegraphics[width=0.7\linewidth]{images/arch.drawio.png}}
% \end{figure}

Let \(\mathcal{X}\) be the input image space and \(\mathcal{Y}\) the label space. A \textit{domain} is defined as a joint distribution \(\mathcal{D} = (\mathcal{X}, \mathcal{Y})\), which contains image-label pairs \(\{(x^{(n)}, y^{(n)})\}_{n=1}^{N}\), where \(N\) is the number of samples. Our goal is to learn a classification model \(F_{\theta} \colon \mathcal{X} \to \mathcal{Y}\) using the source domain \(\mathcal{D}\) that generalizes across a set of \(K\) unseen target domains \(\{\mathcal{D}_{tg}^{1}, \mathcal{D}_{tg}^{2}, \ldots, \mathcal{D}_{tg}^{K}\}\). In the source domain, the input images \(\mathcal{X}\) have corresponding segmentation masks \(\mathcal{M}\) that are utilized by our method. Note that there are no requirements for masks in target domains. The core idea of our method is to force the model to learn two things simultaneously: the attention mechanism learning task and the classification task itself.

\textbf{Attention Map Calculation.} Given an input image, a classification model processes it up to a target layer. Let \( A^k \) represent the activation of the \( k \)-th feature map at this layer. The gradient of the score for class \( c \), denoted \( y^c \), with respect to the activations \( A^k \) of the feature map is computed. This gradient is represented as \( \frac{\partial y^c}{\partial A^k} \). To obtain the neuron importance weights \( \alpha^c_k \) we apply Global Average Pooling (GAP) to these gradients. This is given by:
\begin{equation}
\alpha^c_k = \operatorname{GAP}\left(\frac{\partial y^c}{\partial A^k}\right)
\end{equation}

The Class Activation Map (CAM) for class \( c \), denoted as \( L^c_{\text{Grad-CAM}} \), is a weighted sum of the feature maps, weighted by \( \alpha^c_k \), and passed through a ReLU function:
\begin{equation}
L^c_{\text{Grad-CAM}} = \operatorname{ReLU}\left(\sum_k \alpha^c_k A^k\right)
\end{equation}

The final Attention Map $\mathcal{A}$ is obtained by resizing \( L^c_{\text{Grad-CAM}} \) to the dimensions of the input image.

\textbf{Attention Loss.} It is a custom loss function, denoted as \( \mathcal{L}_{Attention} \), that incorporates the difference between the Grad-CAM Attention Map \( \mathcal{A} \) and a given ground truth attention mask \( \mathcal{M} \) by calculating the mean squared error (MSE) between \( \mathcal{A} \) and \( \mathcal{M} \):
\begin{equation}
\mathcal{L}_{Attention} = \frac{1}{N} \sum_{i=1}^{N} \left( \mathcal{A}_i - \mathcal{M}_i \right)^2
\end{equation}

where \( N \) is the total number of pixels in the image,  and $i$ indexes these pixels.

This loss function measures the alignment between the regions highlighted by the Grad-CAM and those indicated by the attention mask. The objective of the training is to minimize this loss, thereby encouraging the model to focus more on areas marked as important by the mask.

\textbf{Overall Loss.} The Binary Cross-Entropy Loss $\mathcal{L}_{BCE}$, given the predicted outputs \( y^{pred} \) and the true labels \( y^{true} \), is defined as:
\begin{equation}
\mathcal{L}_{BCE} = -\frac{1}{N} \sum_{i=1}^{N} \left[ y^{true}_{i} \log(y^{pred}_{i}) + (1 - y^{true}_{i}) \log(1 - y^{pred}_{i}) \right]
\end{equation}

where \( N \) is the number of samples and \( i \) indexes these samples.

The overall loss is a weighted combination of the Binary Cross-Entropy Loss $\mathcal{L}_{BCE}$ and the Attention Loss $\mathcal{L}_{Attention}$:
\begin{equation}
\mathcal{L} = \alpha \mathcal{L}_{BCE} + \beta \mathcal{L}_{Attention}
\label{eq:loss}
\end{equation}

where \( \alpha \) and \( \beta \) are weighting coefficients that balance the two components of the loss. In all our experiments, \( \alpha = \beta = 1 \).

By combining these two losses, the model not only focuses on minimizing the prediction error but also emphasizes the alignment of the attention maps with the important regions as marked by the attention masks.

\section{Experiments}

Hydronephrosis (HN) is a medical condition characterized by the swelling of one or both kidneys due to a urine buildup. It can affect people of any age and is spotted in up to 5\% of babies during routine pregnancy ultrasound scans. However, surgical intervention is only required in 20\% of these cases \cite{dos2015new}. To determine which cases need intervention, patients receive repeated invasive scans and ultrasounds to monitor whether the HN is causing functional damage or resolving without the need for surgery. Previously, deep learning models have used postnatal ultrasound images to predict surgical intervention in HN from the first ultrasound \cite{erdman2020predicting}, and subsequent work investigated predicting HN grades \cite{smail2020using} and risk scores \cite{tabrizi2021pediatric}. While prior models \cite{erdman2020predicting, smail2020using, tabrizi2021pediatric} worked well for i.i.d.\ data, they showed lower performance on smaller and OOD datasets.

\textbf{Datasets.} We use five datasets from four pediatric hospitals in North America containing ultrasounds of kidneys. The variation of the data comes not only from its collection across various hospitals but also from differences in patient demographics and imaging equipment, contributing to its OOD characteristics. For example, the average patient age in the Hospital for Sick Children (SickKids) is 53 weeks, while in the Children's Hospital of Philadelphia (CHOP), it is 313 weeks. These variances provide a robust setting for evaluating models for DG.

The source domain dataset \(\mathcal{D}\) from SickKids has 2542 ultrasounds. Of these, 20\% are held out to create the i.i.d.\ test dataset \(\mathcal{D}_{test}\). The remaining 2048 images are used for training and validation. Only 83 of these 2048 images have corresponding kidney segmentation masks. We utilize these 83 images for training baseline models. The same 83 images with their segmentation masks are also used to train our model. To assess the robustness of our model, we further use the complete set of 2048 images, which is about 25 times larger, to train additional baseline models. We split these 2048 images into train \(\mathcal{D}_{train}\) and validation \(\mathcal{D}_{val}\) sets, ensuring each patient's images appear in only one set. Only 66 images in \(\mathcal{D}_{train}\) have kidney segmentation masks; they form \(\mathcal{D}^{seg}_{train}\) with 51 non-surgical and 15 surgical cases. Similarly, the validation set \(\mathcal{D}^{seg}_{val}\) is a subset of \(\mathcal{D}_{val}\) and has 10 non-surgical and 7 surgical cases.

We evaluate models on four distinct target domain datasets. The first, \(T_{SickKids}\), includes data from 202 patients at SickKids, having 711 images, of which 75 are positive. Despite being collected at the same hospital as the training dataset, patient demographics and imaging equipment variations make this dataset OOD. The second dataset, \(T_{Stanford}\), is from the Stanford Children's Hospital (Stanford) and includes data from 103 patients, with 551 images (27 positive). The third, \(T_{UIowa}\), is from the University of Iowa Children's Hospital (UIowa) with 91 patients and 97 images (56 positives). Lastly, \(T_{CHOP}\) comes from CHOP with 89 patients and 89 images, 60 of which are positive. The datasets summary is shown in Table \ref{tab:dataset_overview}.
  
\begin{table}[!h]
\centering
\caption{Datasets Summary}
\label{tab:dataset_overview}
% \begin{tabular}{|l|l|l|l|c|c|c|c|c|}
\begin{tabularx}{\textwidth}{
>{\hsize=.115\hsize}X 
>{\hsize=.135\hsize}X 
>{\hsize=.115\hsize}X 
>{\hsize=.155\hsize}X 
>{\hsize=.1\hsize}X 
>{\hsize=.11\hsize}X 
>{\hsize=.11\hsize}X 
>{\hsize=.08\hsize}X 
>{\hsize=.08\hsize}X}
\hline
\textbf{Name}                & \textbf{Hospital} & \textbf{Domain} & \textbf{Used for}         & \textbf{Masks} & \textbf{Patients} & \textbf{Images} & \textbf{Pos} & \textbf{Neg} \\ \hline
\(\mathcal{D}_{train}\)       & SickKids & source & training  & \(\times\)    & 266     & 1549     & 185      & 1364     \\
\(\mathcal{D}^{seg}_{train}\) & SickKids & source & training  & \(\checkmark\) & 35      & 66       & 15       & 51       \\
\(\mathcal{D}_{val}\)         & SickKids & source & validation    & \(\times\)    & 89      & 499      & 67       & 432      \\
\(\mathcal{D}^{seg}_{val}\)   & SickKids & source & validation    & \(\checkmark\) & 7       & 17       & 7        & 10       \\
\(\mathcal{D}_{test}\)        & SickKids & source & i.i.d. test      & \(\times\)    & 89      & 494      & 71       & 423      \\
\(T_{SickKids}\)      & SickKids & target & OOD test         & \(\times\)    & 202     & 711      & 75       & 636      \\
\(T_{Stanford}\)      & Stanford & target & OOD test         & \(\times\)    & 103     & 551      & 27       & 524      \\
\(T_{UIowa}\)      & UIowa       & target & OOD test         & \(\times\)    & 91      & 97       & 56       & 41       \\
\(T_{CHOP}\)      & CHOP     & target & OOD test         & \(\times\)    & 89      & 89       & 60       & 29       \\ \hline
\end{tabularx}
% \end{tabular}
\end{table}


\textbf{Baselines.} ResNet-18, ResNet-50 \cite{he2016deep}, ViT-Tiny, and ViT-Base \cite{dosovitskiy2020image} were trained on \(\mathcal{D}^{seg}_{train}\) and validated on \(\mathcal{D}^{seg}_{val}\), utilizing Binary Cross Entropy Loss $\mathcal{L}_{BCE}$ only. We tested two weight initialization methods: Kaiming uniform initialization (random) \cite{he2015delving} and using weights pre-trained on ImageNet \cite{deng2009imagenet}. Hyperparameters were tuned via Bayesian optimization \cite{snoek2012practical} to minimize the loss on \(\mathcal{D}^{seg}_{val}\); more details are given in Appendix~\ref{appendix:hyperparameters_search}. Consistent image transformations are applied across all experiments. Rotation, cropping, horizontal flipping, and normalization are used for training, while resizing and normalization are used for validation. We train all models for 30 epochs with early stopping based on validation AUROC. One NVIDIA RTX 2080 Ti was used for all experiments. For further analysis, we also trained additional baseline models on the larger datasets \(\mathcal{D}_{train}\) and \(\mathcal{D}_{val}\); all experimental setups were the same.

\textbf{Our model.} We trained the ResNet-18 model on \(\mathcal{D}^{seg}_{train}\) and validated it on \(\mathcal{D}^{seg}_{val}\), utilizing our proposed loss function as described in Equation \ref{eq:loss}. The model's weights were initialized using pre-trained ImageNet weights. All other experimental setups, including hyperparameters search, image transformations, and training duration, were consistent with those used in the baseline models.

\subsection{Baselines vs.\ Our Model Trained on the Small Datasets}
 
The results in Table \ref{tab:model_performance_test_iid_small} and Table \ref{tab:model_performance_ood_small} show Area Under the Receiver Operating Characteristic (AUROC) and Area Under the Precision-Recall Curve (AUPRC) of the baseline models and our model, all trained on \(\mathcal{D}^{seg}_{train}\) and validated on \(\mathcal{D}^{seg}_{val}\).

\textbf{I.i.d. Comparison.} 

Table~\ref{tab:model_performance_test_iid_small} presents a comparative analysis of the models' performance on the i.i.d.\ test dataset \(\mathcal{D}_{test}\). Note that \(\mathcal{D}_{test}\) is a held-out test dataset from the whole dataset \(\mathcal{D}\) and has 494 images, while the models are trained on the small subsets \(\mathcal{D}^{seg}_{train}\) and \(\mathcal{D}^{seg}_{val}\) with 66 and 17 images, respectively. This comparison reflects each model's ability to generalize to new data with a similar distribution to the training set. Interestingly, only three models, including our own, were able to effectively generalize to \(\mathcal{D}_{test}\), showing 0.81--0.83 AUROC and 0.48 AUPRC.
 
\begin{table}[!h]
\centering
\caption{Comparison of models trained on the small dataset \(\mathcal{D}^{seg}_{train}\) for performance on held-out i.i.d. test dataset \(\mathcal{D}_{test}\)}
\label{tab:model_performance_test_iid_small}
% \begin{tabular}{l l l c c}
\begin{tabularx}{\textwidth}{>{\hsize=.35\hsize}X >{\hsize=.2\hsize}X >{\hsize=.2\hsize}X >{\centering\arraybackslash\hsize=.125\hsize}X >{\centering\arraybackslash\hsize=.125\hsize}X}
% \Xhline{1pt}
\hline
\textbf{Model Name} & \textbf{Backbone} & \textbf{Weights Init.} & \textbf{AUROC} & \textbf{AUPRC} \\
\hline
R18-random-small & ResNet-18 & Random & 0.79 & 0.43 \\
R18-imagenet-small & ResNet-18 & ImageNet & 0.70 & 0.27 \\
R50-random-small & ResNet-50 & Random & 0.68 & 0.30 \\
R50-imagenet-small & ResNet-50 & ImageNet & 0.71 & 0.28 \\
ViT-T-random-small & ViT-Tiny & Random & 0.55 & 0.20 \\
ViT-T-imagenet-small & ViT-Tiny & ImageNet & \textbf{0.81} & \textbf{0.48} \\
ViT-B-random-small & ViT-Base & Random & 0.54 & 0.15 \\
ViT-B-imagenet-small & ViT-Base & ImageNet & \textbf{0.83} & \textbf{0.48} \\
\hline
R18-attention (Ours) & ResNet-18 & ImageNet & \textbf{0.82} & \textbf{0.48} \\
% \Xhline{1pt}
\hline
\end{tabularx}
% \end{tabular}
\end{table}

\textbf{OOD Comparison.} 

Table \ref{tab:model_performance_ood_small} presents the performance of the models across four different OOD datasets \(T_{SickKids}\), \(T_{Stanford}\), \(T_{UIowa}\), and \(T_{CHOP}\). Notably, our model consistently outperformed all baselines across all OOD datasets.

\begin{table}[!h]
\centering
\caption{Comparison of models trained on the small dataset \(\mathcal{D}^{seg}_{train}\) for performance on four different OOD datasets \(T_{SickKids}\), \(T_{Stanford}\), \(T_{UIowa}\), and \(T_{CHOP}\)}
\label{tab:model_performance_ood_small}
\begin{tabularx}{\textwidth}{
>{\hsize=.32\hsize}X 
>{\centering\arraybackslash\hsize=.095\hsize}X 
>{\centering\arraybackslash\hsize=.095\hsize}X 
>{\centering\arraybackslash\hsize=.07\hsize}X 
>{\centering\arraybackslash\hsize=.07\hsize}X 
>{\centering\arraybackslash\hsize=.095\hsize}X 
>{\centering\arraybackslash\hsize=.095\hsize}X 
>{\centering\arraybackslash\hsize=.07\hsize}X 
>{\centering\arraybackslash\hsize=.07\hsize}X}
\hline
& \multicolumn{4}{c}{\textbf{AUROC}} & \multicolumn{4}{c}{\textbf{AUPRC}} \\
\textbf{Model Name} & \small{\(T_{SickKids}\)} & \small{\(T_{Stanford}\)} & \small{\(T_{UIowa}\)} & \small{\(T_{CHOP}\)} & \small{\(T_{SickKids}\)} & \small{\(T_{Stanford}\)} & \small{\(T_{UIowa}\)} & \small{\(T_{CHOP}\)} \\
\hline
R18-random-small & 0.47 & 0.19 & 0.33 & 0.34 & 0.09 & 0.04 & 0.48 & 0.59 \\
R18-imagenet-small & 0.52 & 0.35 & 0.72 & 0.54 & 0.10 & 0.04 & 0.77 & 0.73 \\
R50-random-small & 0.52 & 0.23 & 0.39 & 0.27 & 0.20 & 0.04 & 0.51 & 0.58 \\
R50-imagenet-small & 0.37 & 0.13 & 0.18 & 0.21 & 0.08 & 0.03 & 0.43 & 0.52 \\
ViT-T-random-small & 0.52 & 0.29 & 0.76 & 0.56 & 0.10 & 0.03 & 0.73 & 0.72 \\
ViT-T-imagenet-small & 0.80 & 0.72 & 0.80 & 0.69 & 0.35 & 0.22 & 0.85 & 0.82 \\
ViT-B-random-small & 0.46 & 0.31 & 0.71 & 0.50 & 0.10 & 0.03 & 0.74 & 0.68 \\
ViT-B-imagenet-small & 0.84 & 0.84 & 0.72 & 0.68 & 0.46 & 0.33 & 0.79 & 0.81 \\
\hline
R18-attention (Ours) & \textbf{0.86} & \textbf{0.88} & \textbf{0.90} & \textbf{0.81} & \textbf{0.53} & \textbf{0.42} & \textbf{0.90} & \textbf{0.92} \\
\hline
\end{tabularx}
\end{table}

\subsection{Baselines Trained on the Big Datasets vs.\ Our Model Trained on the Small Datasets}

To further analyze our model, we trained additional baselines on the whole train dataset \(\mathcal{D}_{train}\) and validation dataset \(\mathcal{D}_{val}\) with a total of 2048 images. We compared the baselines to our model trained on \(\mathcal{D}^{seg}_{train}\) and validated on \(\mathcal{D}^{seg}_{val}\) with a total of 83 images.

\textbf{I.i.d. Comparison.} 

Table~\ref{tab:model_performance_test_iid} shows the overall performance of the baselines and our model on the held-out i.i.d.\ test dataset \(\mathcal{D}_{test}\). All models, including ours, which was trained \textbf{on only 4\% of the data}, have comparable AUROC (0.82--0.87) and AUPRC (0.47--0.54), which means all models generalize well to unseen images from the same distribution.

\begin{table}[!h]
\centering
\caption{Comparison of baselines trained on the big dataset \(\mathcal{D}_{train}\) and our model trained on the small dataset \(\mathcal{D}^{seg}_{train}\) for performance on held-out i.i.d. test dataset \(\mathcal{D}_{test}\)}
\label{tab:model_performance_test_iid}
\begin{tabularx}{\textwidth}{>{\hsize=.3\hsize}X >{\hsize=.17\hsize}X >{\hsize=.2\hsize}X >{\centering\arraybackslash\hsize=.1\hsize}X >{\centering\arraybackslash\hsize=.12\hsize}X >{\centering\arraybackslash\hsize=.12\hsize}X}
% \Xhline{1pt}
\hline
\textbf{Model Name} & \textbf{Backbone} & \textbf{Weights Init.} & \textbf{Images} & \textbf{AUROC} & \textbf{AUPRC} \\
\hline
R18-random & ResNet-18 & Random & 2048 & 0.85 & 0.50 \\
R18-imagenet & ResNet-18 & ImageNet & 2048 & 0.87 & 0.52 \\
R50-random & ResNet-50 & Random & 2048 & 0.82 & 0.47 \\
R50-imagenet & ResNet-50 & ImageNet & 2048 & 0.83 & 0.52 \\
ViT-T-random & ViT-Tiny & Random & 2048 & 0.84 & 0.50 \\
ViT-T-imagenet & ViT-Tiny & ImageNet & 2048 & 0.86 & 0.54 \\
ViT-B-random & ViT-Base & Random & 2048 & 0.84 & 0.49 \\
ViT-B-imagenet & ViT-Base & ImageNet & 2048 & 0.85 & 0.52 \\
\hline
R18-attention (Ours) & ResNet-18 & ImageNet & \textbf{83} & 0.82 & 0.48 \\
% \cellcolor{acidgreen}\textcolor{custompink}{R18-attention (Our)} & \cellcolor{acidgreen}\textcolor{custompink}{ResNet-18} & \cellcolor{acidgreen}\textcolor{custompink}{ImageNet} & \cellcolor{acidgreen}\textcolor{custompink}{\textbf{83}} & \cellcolor{acidgreen}\textcolor{custompink}{0.82} & \cellcolor{acidgreen}\textcolor{custompink}{0.48} \\
\hline
\end{tabularx}
% \end{tabular}
\end{table}

\textbf{OOD Comparison.}

Table~\ref{tab:model_performance_ood} shows the models' performance on the OOD datasets \(T_{SickKids}\), \(T_{Stanford}\), \(T_{UIowa}\), and \(T_{CHOP}\). Out of the nine models that perform well on i.i.d.\ data, only three, including ours, transfer well to all OOD datasets. This demonstrates the effectiveness of our approach: our model, trained on 25 times less data, generalizes well to OOD data.


% \cline{2-9}
% Model Name & \(T_{SickKids}\) & \(T_{Stanford}\) & \(T_{UIowa}\) & \(T_{CHOP}\) & \(T_{SickKids}\) & \(T_{Stanford}\) & \(T_{UIowa}\) & \(T_{CHOP}\) \\
\begin{table}[!h]
\centering
\caption{Comparison of baselines trained on the big dataset \(\mathcal{D}_{train}\) and our model trained on the small dataset \(\mathcal{D}^{seg}_{train}\) for performance on four OOD datasets \(T_{SickKids}\), \(T_{Stanford}\), \(T_{UIowa}\), and \(T_{CHOP}\)}
\label{tab:model_performance_ood}
\begin{tabularx}{\textwidth}{
>{\hsize=.37\hsize}X 
>{\centering\arraybackslash\hsize=.095\hsize}X 
>{\centering\arraybackslash\hsize=.095\hsize}X 
>{\centering\arraybackslash\hsize=.07\hsize}X 
>{\centering\arraybackslash\hsize=.07\hsize}X 
>{\centering\arraybackslash\hsize=.095\hsize}X 
>{\centering\arraybackslash\hsize=.095\hsize}X 
>{\centering\arraybackslash\hsize=.07\hsize}X 
>{\centering\arraybackslash\hsize=.07\hsize}X}
\hline
& \multicolumn{4}{c}{\textbf{AUROC}} & \multicolumn{4}{c}{\textbf{AUPRC}} \\
% \hline
% \cline{2-9}
\textbf{Model Name} & \small{\(T_{SickKids}\)} & \small{\(T_{Stanford}\)} & \small{\(T_{UIowa}\)} & \small{\(T_{CHOP}\)} & \small{\(T_{SickKids}\)} & \small{\(T_{Stanford}\)} & \small{\(T_{UIowa}\)} & \small{\(T_{CHOP}\)} \\
\hline
R18-random & 0.59 & 0.35 & 0.43 & 0.36 & 0.19 & 0.04 & 0.52 & 0.59 \\
\textbf{R18-imagenet} & \textbf{0.88} & \textbf{0.88} & \textbf{0.82} & \textbf{0.85} & \textbf{0.55} & \textbf{0.40} & \textbf{0.88} & \textbf{0.94} \\
R50-random & 0.61 & 0.49 & 0.66 & 0.55 & 0.21 & 0.06 & 0.67 & 0.71 \\
R50-imagenet & 0.74 & 0.66 & \textit{0.80} & \textit{0.78} & 0.23 & 0.07 & \textit{0.84} & 0.84 \\
ViT-T-random & 0.53 & 0.17 & 0.23 & 0.27 & 0.17 & 0.03 & 0.43 & 0.55 \\
ViT-T-imagenet & 0.77 & 0.62 & 0.66 & 0.62 & 0.35 & 0.12 & 0.72 & 0.72 \\
ViT-B-random & 0.57 & 0.22 & 0.23 & 0.24 & 0.24 & 0.04 & 0.44 & 0.54 \\
\textbf{ViT-B-imagenet} & \textbf{0.89} & \textbf{0.91} & \textbf{0.88} & \textbf{0.85} & \textbf{0.55} & \textbf{0.48} & \textbf{0.88} & \textbf{0.93} \\
\hline
\textbf{R18-attention (Ours)} & \textbf{0.86} & \textbf{0.88} & \textbf{0.90} & \textbf{0.81} & \textbf{0.53} & \textbf{0.42} & \textbf{0.90} & \textbf{0.92} \\
\hline
% \end{tabular}
\end{tabularx}
\end{table}

% \subsection{GradCAM}

% We provide visual representations of attention maps (GradCAM) for a selected sample across various datasets: i.i.d. test dataset (\(\mathcal{D}_{\text{test}}\)) and OOD datasets \(\mathcal{D}^{1}_{\text{tg}}\), \(\mathcal{D}^{2}_{\text{tg}}\), \(\mathcal{D}^{3}_{\text{tg}}\), and \(\mathcal{D}^{4}_{\text{tg}}\) (see Figure \ref{fig:grad}). Each group in the figure consists of three images: the original ultrasound, the kidney's segmentation mask, and the GradCAM attention map generated by our model for that particular image. These visualizations demonstrate our model's ability to consistently focus on the kidney, the critical region of interest. This ensures the model bases its predictions on the most relevant anatomical features, avoiding distraction by irrelevant details or noise. We quantify attention maps in Appendix \ref{appendix:attention_score} for further analysis.

% \vspace{10pt} 
% \begin{figure}[ht]
%     \centering
%     \begin{overpic}[width=0.95\textwidth]{images/gradcams_wide.drawio.png}
%     \put(25,37.5){\makebox(0,0){\textbf{$D_{test}$}}}
%     \put(75,37.5){\makebox(0,0){\textbf{\(T_{Stanford}\)}}}
%     \put(25,18){\makebox(0,0){\textbf{\(T_{SickKids}\)}}}
%     \put(75,25){\makebox(0,0){\textbf{\(T_{UIowa}\)}}}
%     \put(75,12.5){\makebox(0,0){\textbf{\(T_{CHOP}\)}}}
%     \end{overpic}
%     \caption{GradCAMs Visualization}
%     \label{fig:grad}
% \end{figure}

\section{Conclusion}

This paper presented a new method for improving medical image classification models using segmentation masks, especially effective in small dataset scenarios (fewer than 100 images). By utilizing a specialized loss function, our model demonstrated remarkable performance on both i.i.d.\ and OOD datasets despite limited training data. It matched or exceeded the performance of other models trained on similar-sized datasets in i.i.d.\ scenarios and consistently outperformed all baselines in OOD settings. Notably, our model, trained on just 4\% of the data, performed as well as or even better than baselines trained on significantly larger datasets in i.i.d.\ and OOD settings. The implications of these results are promising. Creating segmentation masks, which our method relies on, could be more feasible than gathering extensive data on rare diseases. Additionally, our model's ability to transfer across different hospitals could reduce the need for unique models for each medical setting.

\bibliography{midl24_302}

\newpage

\appendix

\section{Hyperparameter Search}
\label{appendix:hyperparameters_search}

In our study, we used Bayesian optimization to systematically explore and identify optimal hyperparameters for training all models. We focused on tuning batch size, gamma, learning rate, and weight decay, aiming to minimize the validation loss. We tested batch sizes of 16, 32, 64, and 128; gamma values ranging from 0.99 to 0.85 in decrements of 0.02; learning rates of 0.1, 0.01, 0.001, 0.0001, 1e-05, and 1e-06; and weight decay parameters of 0.3, 0.1, 0.03, 0.01, 0.003, and 0.001. The best set of hyperparameters for each model is reported in Table~\ref{tab:hyperparameter_selection}.

\begin{table}[!h]
\centering
\caption{Hyperparameter Selection for Models}
\label{tab:hyperparameter_selection}
\begin{tabularx}{\textwidth}
{>{\hsize=0.3\hsize}X 
>{\hsize=0.155\hsize}X 
>{\hsize=0.105\hsize}X 
>{\hsize=0.215\hsize}X 
>{\hsize=0.225\hsize}X}
\hline
\textbf{Model Name} & \textbf{Batch Size} & \textbf{Gamma} & \textbf{Learning Rate} & \textbf{Weight Decay} \\
\hline
R18-random-small & 32 & 0.91 & 0.001 & 0.1 \\
R18-imagenet-small & 16 & 0.91 & 0.001 & 0.1 \\
R50-random-small & 32 & 0.85 & 0.01 & 0.01 \\
R50-imagenet-small & 64 & 0.85 & 0.01 & 0.01 \\
ViT-T-random-small & 64 & 0.85 & 0.01 & 0.003 \\
ViT-T-imagenet-small & 128 & 0.95 & 0.0001 & 0.01 \\
ViT-B-random-small & 16 & 0.85 & 0.000001 & 0.3 \\
ViT-B-imagenet-small & 128 & 0.91 & 0.00001 & 0.001 \\
R18-random & 64 & 0.89 & 0.001 & 0.001 \\
R18-imagenet & 64 & 0.93 & 0.00001 & 0.01 \\
R50-random & 64 & 0.87 & 0.001 & 0.001 \\
R50-imagenet & 16 & 0.93 & 0.0001 & 0.001 \\
ViT-T-random & 32 & 0.87 & 0.0001 & 0.001 \\
ViT-T-imagenet & 16 & 0.99 & 0.00001 & 0.01 \\
ViT-B-random & 16 & 0.91 & 0.00001 & 0.001 \\
ViT-B-imagenet & 16 & 0.87 & 0.000001 & 0.03 \\
R18-attention (Ours) & 128 & 0.85 & 0.001 & 0.1 \\
\hline
\end{tabularx}
\end{table}


\section{Attention Score}
\label{appendix:attention_score}

To quantify how much different models actually pay attention to the region of interest, we introduce a new metric, the \textit{Attention Score}, which has two components: the Overlap Score and the Coverage Score. The \textit{Overlap Score (OS)} measures the proportion of the important areas, as defined by the ground truth mask $\mathcal{M}$, that is successfully captured by the attention map $\mathcal{A}$:
\begin{equation}
    {OS}(\mathcal{A}, \mathcal{M}) = \frac{\sum_{i=1}^{N} \min(\mathcal{A}_i, \mathcal{M}_i)}{\sum_{i=1}^{N} \mathcal{M}_i}
\end{equation}

where $N$ is the total number of pixels, and $i$ indexes these pixels. \textit{Coverage Score (CS)} assesses the concentration and specificity of the model’s attention, evaluating how much of the attention map's activation $\mathcal{A}$ is meaningfully focused on the target areas $\mathcal{M}$:
\begin{equation}
    CS(\mathcal{A}, \mathcal{M}) = \frac{\sum_{i=1}^{N} \min(\mathcal{A}_i, \mathcal{M}_i)}{\sum_{i=1}^{N} \mathcal{A}_i}   
\end{equation}

where $N$ is the total number of pixels, and $i$ indexes these pixels. The Overlap Score can be thought of as a Recall metric and the Coverage Score as a Precision metric, computed for attention maps instead of classification labels. The final Attention Score is computed as the harmonic mean of the Overlap Score and Coverage Score, providing a balanced measure of both overlap and coverage:
\begin{equation}
    \mathit{AttentionScore}(\mathcal{A}, \mathcal{M}) = 2 \times \frac{
    {OS}(\mathcal{A}, \mathcal{M}) \times {CS}(\mathcal{A}, \mathcal{M})}{{OS}(\mathcal{A}, \mathcal{M}) + {CS}(\mathcal{A}, \mathcal{M})}
\end{equation}


\begin{table}[!h]
\small
\centering
\caption{Attention Scores on i.i.d. \(\mathcal{D}_{test}\) and OOD datasets \(T_{SickKids}\), \(T_{Stanford}\), \(T_{UIowa}\), and \(T_{CHOP}\)}
\label{tab:attention_scores}
\setlength{\extrarowheight}{2pt} % Increase the distance between rows
\begin{tabularx}{0.8\textwidth}{
>{\raggedright\arraybackslash\hsize=0.4\hsize}X
>{\centering\arraybackslash\hsize=0.12\hsize}X 
>{\centering\arraybackslash\hsize=0.12\hsize}X 
>{\centering\arraybackslash\hsize=0.12\hsize}X 
>{\centering\arraybackslash\hsize=0.12\hsize}X 
>{\centering\arraybackslash\hsize=0.12\hsize}X}
\hline
Model Name & \(\mathcal{D}_{test}\) & \(T_{SickKids}\) & \(T_{Stanford}\) & \(T_{UIowa}\) & \(T_{CHOP}\) \\
\hline
R18-random & 0.38 & 0.51 & 0.34 & 0.46 & 0.44 \\
R18-imagenet & 0.43 & 0.56 & 0.52 & 0.57 & 0.52 \\
R50-random & 0.40 & 0.49 & 0.40 & 0.40 & 0.51 \\
R50-imagenet & 0.41 & 0.48 & 0.43 & 0.45 & 0.45 \\
ViT-T-random & 0.44 & 0.57 & 0.10 & 0.23 & 0.09 \\
ViT-T-imagenet & 0.26 & 0.41 & 0.32 & 0.43 & 0.42 \\
ViT-B-random & 0.45 & 0.59 & 0.50 & \textbf{0.67} & 0.58 \\
ViT-B-imagenet & 0.27 & 0.36 & 0.38 & 0.39 & 0.39 \\
\hline
R18-attention (Ours) & \textbf{0.57} & \textbf{0.60} & \textbf{0.61} & 0.59 & \textbf{0.62} \\
\hline
\end{tabularx}
\end{table}

In Table~\ref{tab:attention_scores}, we show that our model generally outperforms other models in terms of Attention Score. The interesting exception is the Attention Score for the ViT-B-random model on the \(T_{UIowa}\) dataset, where it shows a higher score than our model. Considering the low performance of the ViT-B-random model in terms of AUROC and AUPRC on that dataset, we conclude that the Attention Score, even though it is a useful indicator of the model's performance, is only a part of the evaluation and should be considered in combination with other metrics.

\section{Datasets Information}
\label{appendix:datasets}

\textbf{Training Dataset}
\\
\textit{Sex distribution:} 2027 M, 515 F.  \\
\textit{Kidney side distribution:} 1289 Left, 1253 Right. \\
\textit{Ultrasound machine distribution:} philips-medical-systems: 992, toshiba-mec: 497, NA: 376, ToshibaST: 258, PhilipsST: 112, SamsungST: 97, ge-medical-systems: 45, samsung-medison-co-ltd: 36, OutsideST: 26, acuson: 25, atl: 22, toshiba-mec-us: 20, TreeST: 17, GEST: 13, siemens: 4, ge-healthcare: 2.  \\
\textit{The age} varies from 0.14 weeks to 720 weeks, with an average of 53 weeks.
\newline

\textbf{OOD dataset \(T_{SickKids}\)}
\\
\textit{Sex distribution:} 599 M, 112 F.  \\
\textit{Kidney side distribution:} 475 Left, 236 Right. \\
\textit{Ultrasound machine distribution:} ToshibaST: 294, PhilipsST: 247, SamsungST: 158, OutsideST: 12.  \\
\textit{The age} varies from 0.29 weeks to 92 weeks, with an average of 17 weeks. 
\newline

\textbf{OOD dataset \(T_{Stanford}\)}
\\
\textit{Sex distribution:} 413 M, 138 F.  \\
\textit{Kidney side distribution:} 275 Left, 276 Right. \\ 
\textit{Ultrasound machine distribution:} Stanford: 551.  \\
\textit{The age} varies from 104.0 weeks to 988 weeks, with an average of 190 weeks.
\newline

\textbf{OOD dataset \(T_{UIowa}\)}

\textit{Sex distribution:} 80 M, 17 F.  \\
\textit{Kidney side distribution:} 59 Left, 38 Right.\\ 
\textit{Ultrasound machine distribution:} UIowa: 97.  \\
\textit{The age} varies from 0.14 weeks to 266 weeks, with an average of 28 weeks.
\newline

\textbf{OOD dataset \(T_{CHOP}\)}

\textit{Sex distribution:} 55 M, 34 F.  \\
\textit{Kidney side distribution:} 56 Left, 33 Right.\\ 
\textit{Ultrasound machine distribution:} Philips: 51, GE: 16, Phillips: 7, HDI 5000: 3, Siemens: 2, Acuson: 2, General electric: 1, MRI abd w/wo, RBUS 7/19/2010: 1, Cineloop: 1, Mindray: 1, Toshiba: 1.  \\
\textit{The age} varies from 1.43 weeks to 1001 weeks, with an average of 313 weeks.

\section{Limitations and Future Work}
\label{appendix:future_work}

\textbf{Limitations.} Even though our model performed well on multiple out-of-distribution datasets, it's worth noting that all the data came from hospitals in the USA and Canada. In real-world scenarios, particularly in areas with substantially different demographics or medical equipment, our model might show diminished performance.
\newline

\textbf{Future Work.} We plan to conduct a series of comprehensive ablation studies to precisely quantify the impact of attention loss on the performance of each model separately. Additionally, we aim to broaden the applicability and robustness of our model by collecting and incorporating data from hospitals outside North America. This effort will test the model's ability to generalize across diverse demographics, addressing potential biases and enhancing its global applicability. Furthermore, we intend to explore the potential of our approach in other clinical settings, such as the diagnosis of pneumonia in lung ultrasound images. By extending our domain generalization efforts to various medical imaging tasks, we hope to contribute further to the advancement of AI in healthcare, ensuring models are both effective and equitable across different populations and conditions.

\end{document}

