\documentclass{midl} 
\usepackage{amsmath}
\usepackage{colortbl}
\usepackage{multirow}
\usepackage{pifont}
\usepackage{floatrow}

% Adjust the spacing between image and caption
\floatsetup[figure]{captionskip=-6.5pt}
\floatsetup[table]{captionskip=0pt}
\jmlrvolume{-- 296}
\jmlryear{2024}
\jmlrworkshop{Full Paper -- MIDL 2024}
\editors{Accepted for publication at MIDL 2024}
\newcommand{\cmark}{\ding{51}}%
\newcommand{\xmark}{\ding{55}}%
%\title[Short Title]{Debiasing with Clarity: Diffusion Counterfactual Explanations}

\title[DeCoDEx]{DeCoDEx: Confounder Detector Guidance for Improved Diffusion-based Counterfactual Explanations}

\midlauthor{\Name{Nima Fathi\midljointauthortext{Contributed equally}\nametag{$^{1,2}$}} \Email{nimafh@cim.mcgill.ca}
\AND
\Name{Amar Kumar\midlotherjointauthor\nametag{$^{1,2}$}} \Email{amarkr@cim.mcgill.ca}\AND
\Name{Brennan Nichyporuk\nametag{$^{2}$}} \Email{brennann@cim.mcgill.ca}\AND
\Name{Mohammad Havaei\nametag{$^{3}$}} \Email{mhavaei@google.com} \AND
\Name{Tal Arbel\nametag{$^{1,2}$}} \Email{arbel@cim.mcgill.ca}\\
\addr $^{1}$ Center for Intelligent Machines, McGill University, Montreal, Canada. \\
\addr $^{2}$ MILA (Quebec AI institute), Montreal, Canada. \\
\addr $^{3}$ Google Research, Montreal, Canada.
}
\begin{document}
\maketitle
\begin{abstract}
Deep learning classifiers are prone to latching onto dominant confounders present in a dataset rather than on the causal markers associated with the target class,  leading to poor generalization and biased predictions. Although explainability via counterfactual image generation has been successful at exposing the problem, bias mitigation strategies that permit accurate explainability in the presence of dominant and diverse artifacts remain unsolved. In this work, we propose the DeCoDEx framework and show how an external, pre-trained binary artifact detector can be leveraged during inference to guide a diffusion-based counterfactual image generator towards accurate explainability.  Experiments on the CheXpert dataset, using both synthetic artifacts and real visual artifacts (support devices), show that the proposed method successfully synthesizes the counterfactual images that change the causal pathology markers associated with Pleural Effusion while preserving or ignoring the visual artifacts. Augmentation of ERM and Group-DRO  classifiers with the DeCoDEx generated images substantially improves the results across underrepresented groups that are out of distribution for each class. The code is made publicly available at \url{https://github.com/NimaFathi/DeCoDEx}.
%Through extensive qualitative experimentation, we show diversity in the synthesized counterfactual images. 
\end{abstract}

\begin{keywords}
 Bias Mitigation, Causality, Counterfactual Image Synthesis, Diffusion, Explainability, Spurious Correlations
\end{keywords}

\section{Introduction}
Deep learning (DL) methods have shown tremendous success in a wide variety of medical image tasks, including disease classification, due to their ability to learn generalizable, discriminative features across subjects. However,  DL models are prone to learning shortcuts in order to obtain high overall accuracies, including any prevalent visual artifacts (e.g. marks in the image~\cite{degrave2021ai}) that are correlated with, but not causal of, the target outcome. Models that have not learned the relevant causal visual markers~\cite{jia2017adversarial, zech2018variable} are {\it right for the wrong reasons}~\cite{sun2023inherently, sun2023right}, and fail to generalize across out-of-distribution subgroups  ~\cite{geirhos2018imagenet, geirhos2020shortcut}. Explainable DL models that not only expose these biases but mitigate them are required in order to ensure their trustworthiness for safe clinical deployment. 

Counterfactual (CF) image generation methods (e.g. Gifsplanation~\cite{cohen2021gifsplanation} and Attri-Net~\cite{sun2023inherently, sun2023right}) have recently been successful at exposing when the classifier is latching onto spurious correlations in order to obtain high performance. These methods employ conditional generation of the counterfactual image when the classifier has the opposing target outcome. Differences between the factual and counterfactual images should reflect the predictive local markers indicative of the class label, but also expose the classifier's reliance on spurious correlations. These methods do not mitigate the biases nor address their poor generalization. A number of debiasing methods~\cite{sagawa2020distributionally, wang2020towards, sarhan2020fairness} have recently been successful in several medical imaging contexts. Recent work~\cite{kumar2023debiasing} combined  Cycle-GAN counterfactual image generation and a Group-DRO~\cite{sagawa2020distributionally} classifier to expose and mitigate the biases. The results showed improvement for minority subgroups, and classification based on disease-specific features. However, debiasing techniques have a number of known drawbacks, including improved fairness at the expense of a reduction in the performance in the majority subgroup. Integrating them into the counterfactual models requires sub-group labels for each class during training (which is often unavailable), and GANs require retraining the generative model with the pre-trained debiasing classifier in order to provide supervision for the counterfactual synthesis, which makes the process inflexible. Recently, an unconditional DDPM (Denoising Diffusion Probabilistic Model)~\cite{ho2020denoising}, DiME~\cite{jeanneret2022diffusion}, has been proposed to generate classifier-guided counterfactual explanations, however, with the classifier latching onto shortcuts present in the dataset. 
Overall, developing bias mitigation strategies that permit accurate explainability in the presence of dominant and diverse visual artifacts remains an open research question.

This paper introduces DeCoDEx, a diffusion-based (DDPM) counterfactual image generator for debiased classifier explainability in the presence of dominant and diverse visual artifacts. DeCoDEx overcomes a number of limitations of current approaches: Rather than requiring training specialized debiasing classifiers for known subgroups, and then retraining the counterfactual generator that uses them (e.g. GANs), the framework provides debiased explainability of the classifier in question {\it at inference time} by leveraging the flexibility and explicit inference procedure~\cite{dhariwal2021diffusion,wang2022diffusion,kazerouni2022diffusion} of DDPMs, is generalizable to any subgroups, and shows stable training as compared to GANs. The framework can make use of any pre-trained, binary detector trained to indicate the presence or absence of the visual artifact in question (for any classes).  During inference, the detector guides the diffusion-based counterfactual image generator towards accurate explainability, as the gradients from the detector counter the gradients from the classifier away from spurious correlations. 

Extensive experiments are performed on the publicly available CheXpert dataset~\cite{irvin2019chexpert}, using both synthetic artifacts and real visual artifacts (support devices). Qualitative results show that the proposed method successfully synthesizes the counterfactual images by making changes in the pathology associated with Pleural Effusion while preserving or ignoring the visual artifacts. The quality of the counterfactual images was measured via several metrics, such as L1 distance, Counterfactual Prediction Gain (CPG) and Spurious Correlation Latching Score~\cite{kumar2023debiasing}. The results indicated the strength of the framework over a baseline (without a detector). Augmentation of the dataset with counterfactual images synthesized with DeCoDEx improves ERM and Group-DRO classification for the minority subgroups. 

\section{Methodology}
The DeCoDEx framework involves a training strategy for explainability via counterfactual image synthesis, while ensuring robustness to spurious correlations. The model consists of a classifier and a trained unconditional denoising diffusion probabilistic model (DDPM). The framework leverages a pre-trained visual artifact detector that guides a DDPM to synthesize counterfactuals while ignoring the artifact.  

\begin{figure}[h]
\centering
\floatconts
  {}
  {\caption{CF explanations for a subject with Pleural Effusion in the presence of an artifact: (a) Chest radiograph of a sick patient: \textcolor{red}{dot artifact},  \textcolor{cyan}{disease pathology}; (b) CF image from biased classifier using DDPM (i.e. DeCoDEx without detector) maintains the diseased area but modifies the dot; (c) DeCoDEx CF image modifies the Pleural Effusion area to look healthy as expected~\cite{huang2022deep,wang2017chestx} while ignoring the dot artifact.}
  \label{fig:intro}}
  {\includegraphics[width=0.75\linewidth]{images/intro_hq.png}}
\end{figure}

One of the advantages of our approach is the flexibility to use any pre-trained detector that can identify spurious correlations in the input sample. During the counterfactual image generation, when the biased classifier relies on spurious non-causal features in the dataset, the detector's gradient reversal signal readjusts the generation process, steering it back toward focusing on relevant pathological features (Fig.~\ref{fig:intro}). %Specifically, when the biased classifier is used, the counterfactual images show importance to the artifacts. But when the gradients to these artifacts are blocked, the classifier indicates features associated with the target class. 
It should be noted here that if there are several spurious correlations in the dataset that are difficult to detect, the detector may block only some of them. Therefore, for some difficult cases, even after using a detector, the counterfactual images may fail to make changes in the area correlated with the target class. An overview of our method is shown in Fig.~\ref{fig:workflow}.
\begin{figure}[h]
\centering
\floatconts
  {}
  {\caption{DeCoDEx Framework: Generating the counterfactuals (CFs) involves several inference steps. At each step, there are several components: (1) Denoising via unconditional DDPM, (2) pretrained classifier and detector loss, (3) gradient of the classifier, detector and perceptual loss, (4) counterfactual synthesis via sampling and backpropagating loss from black-box classifier and detector. The classifier, detector and unconditional DDPM are all pre-trained components. The resulting CF makes changes to the \textcolor{cyan}{disease markers} while disregarding \textcolor{red}{visual artifacts}.} \label{fig:workflow}}
  {\includegraphics[width=0.9\linewidth]{images/diagram.png}}
\end{figure}

\noindent{\bf Counterfactual Image Generation}:
%\noindent \underline{\textbf{}}\\
The counterfactual image generation is designed to adhere to a set of constraints~\cite{mothilal2020explaining,nemirovsky2020countergan} through the following loss functions: (i) \textit{Identity preservation loss}, $\mathcal{L}_{perc}$: the counterfactual image preserves the identity of the factual image; (ii) \textit{Classifier consistency loss}, $\mathcal{L}_{class}$: the counterfactual image belongs to the correct target class. An additional \textit{Detector loss} term, $\mathcal{L}_{det}$, is introduced to help guide the generation away from the spurious correlations that the classifier would have otherwise latched onto. 

The generative model for synthesizing counterfactual images uses a pre-trained DDPM with classifier guidance. DDPMs operate through a forward diffusion process (Eq.~\ref{eq:forwarddiff}), which incrementally corrupts the original image \( x \) (or \( z_0 \)) by adding Gaussian noise, culminating in a highly noised image \( z_T \) over \( T \) timesteps. 
\begin{equation}\label{eq:forwarddiff}
    z_t \sim \mathcal{N} \left(\sqrt{1-\beta_t} z_{t-1}, \beta_t \mathbb{I}\right),
\end{equation}
where \( \beta_t \) are predefined noise levels.
%,introducing a noise schedule that critically influences model performance \cite{sohl2015deep}. %Through the rest of the paper, we will refer to clean images with an \( x \), while noisy ones with a \( z_0 \).
The reverse diffusion process (Eq.~\ref{eq:reversediff}) aims to recover the original image from \( z_T \), using a neural network to predict and subtract the added noise iteratively.
\begin{equation}\label{eq:reversediff}
    x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta(x_t, t) \right),
\end{equation}
where \( \alpha_t := \prod_{k=1}^{t} (1-\beta_k) \) and \( \epsilon_\theta(x_t, t) \) is the noise estimated by the network at step \( t \). Guided-diffusion sampling~\cite{dhariwal2021diffusion} is used to denoise the image at any time step $t$, given by $
z_{t-1} \sim \mathcal{N}\left( \mu(z_{t}) - \Sigma(z_{t}) \nabla_{z_{t}} \mathcal{L}(z_{t};y_{c}, y_{s}, x_{t}), \Sigma(z_{t})\right)$.
The complete loss function is given by: 
\begin{equation}\label{eq:completelossfunction}
     \mathcal{L}(x_t;y_c,y_s,x) = \lambda_c\mathcal{L}_{class}\left(C(y_c|x_t)\right) + \lambda_d\mathcal{L}_{det}\left(D(y_s|x_t)\right) + \lambda_p\mathcal{L}_{perc}(x_t,x),
 \end{equation}
 where $C$ and $D$ refer to our classifier and detector modules, $y_c$ and $y_s$ refer to the target labels of the class and the presence of spurious correlation, and $\lambda_c,\lambda_d,\lambda_p$ are hyperparameters. 
Finally, the gradient of the complete loss function can be expressed as follows:
\begin{equation}\label{eq:gradient}
\nabla_{z_t} \mathcal{L}(z_t;y_{c}, y_{s}, x_{t}) = \frac{1}{\sqrt{\alpha_{t}}} \nabla_{x_t}\mathcal{L}(x_t;y_{c}, y_{s}, x).
\end{equation}
\section{Experiments and Results}
\begin{figure}
\floatconts{}
  {\caption{Majority and Minority subgroup samples from the Dot dataset (top row) and Device dataset (bottom row). Red boxes show the location of the artifacts.}\label{fig:majmin}}
  {\includegraphics[width=0.75\linewidth]{images/fig3.png}}
\end{figure}
\subsection{Dataset and Implementation Details}
We perform experiments on the publicly available CheXpert dataset~\cite{irvin2019chexpert} that contains over 200,000 chest X-ray images, with binary labels for 14 diseases (e.g. Pleural Effusion, Cardiomegaly, Pneumonia), as well as binary labels indicating the presence of support devices (visual artifacts). We create two variants of the CheXpert dataset: (i) 
\textbf{\textit{Dot dataset}}: We introduce a synthetic artifact, a black dot of radius 9 pixels, in the center of the image for evaluating the quality of counterfactual images in the presence of spurious correlations. This also helps to compare the behaviour of counterfactual images synthesized from the baseline method (DeCoDEx without a detector) and the proposed DeCoDEx. (ii) \textbf{\textit{Device dataset:}} A subset of the CheXpert data is used to demonstrate the performance of DeCoDEx in the presence of real artifacts (Support devices). In both datasets, the artifacts are present in the majority of images of subjects with Pleural Effusion and in the minority of images of healthy subjects. In contrast, the majority of the images of healthy subjects and the minority of the images of subjects with Pleural Effusion do not contain these artifacts. Fig.~\ref{fig:majmin} shows the sample images from both the datasets and all four subgroups. For both datasets, the ratio of the number of samples in majority to minority is 90:10 and the dataset is divided into training/validation/testing with a 70/15/15 random split. The details of the number of samples in each split are included in Appendix~\ref{appendix:A}.

The DenseNet-121~\cite{huang2017densely} architecture is used to train the classifier and detector. The classifier is trained separately on Dot and Device datasets. We use the standard Empirical Risk Minimization (ERM)~\cite{sagawa2020distributionally} as the optimization method. %The dot detector is trained on the Dot dataset as it is simpler for a deep learning network to identify dot in the image. 
The binary detector indicating the presence/absence of support devices is pre-trained on the entire CheXpert dataset (except the held out test set).  DeCoDEx is capable of handling images with multiple support devices,  each varying in type, shape, size, location and intensities. An analysis showing the performance of the detector via counterfactual image generation is discussed in Appendix~\ref{appendix:explaindetector}.
\begin{figure}[t]
\floatconts
  {fig:dimevsrev}
  {\caption{Qualitative comparison of counterfactual images synthesized via Baseline (i.e. DeCoDEx without detector) and DeCoDEx. For the baseline, most of the changes were made to the spurious correlation but for DeCoDEx \textcolor{red}{visual artifacts} were ignored and changes pertained to \textcolor{cyan}{disease pathology}.}\label{fig:dimevsdecodex}}
  {\includegraphics[width=0.9\linewidth]{images/fig4.png}}
\end{figure}

\subsection{Metrics and Experiments to evaluate counterfactuals}
Several metrics are used to evaluate the quality of counterfactual images: (i) {\it Subject Identity Preservation: L1 Score}, as depicted by the L1 distance between the counterfactual and the factual (original) image (as in~\cite{mothilal2020explaining,nemirovsky2020countergan}); (ii) {\it Counterfactual Prediction Gain (CPG)}~\cite{nemirovsky2020countergan}, which measures the absolute value of the difference in the prediction of the classifier on the factual and counterfactual images (a higher score indicates a maximal change in the classifier decision boundary); (iii) {\it Spurious Correlation Latching Score (SCLS)}~\cite{kumar2023debiasing}, which assesses whether the spurious correlation was preserved in the counterfactual image (a lower SCLS score is desirable); (iv) {\it Classifier Flip Rate (CFR)}, which represents the fraction of samples that flipped their class as per the classifier; and (v) {\it Detector Robustness Rate (DRR)}, which shows the fraction of samples that were robust to the detector.

In order to show that the synthesized counterfactuals learn useful features associated with the disease, the training data for the original classifier was augmented with synthesized images. The ERM classifier is retrained with the augmented data\footnote{Only the training dataset is augmented while the validation and test splits remain the same.}. An increase in performance indicates that synthesized images learned discriminative features generalizable to the subgroups. We augment the Dot dataset with 200 counterfactual images and the Device dataset with 600 counterfactual images synthesized using the baseline method and DeCoDEx. For completeness, the augmentation experiments are repeated for a debiasing Group-DRO classifier.  

\begin{table}[t]
\centering{%
\begin{tabular}{cccccc}
\hline
 & CFR$\uparrow$    & DRR$\uparrow$    & L1$\downarrow$    & CPG$\uparrow$   & SCLS$\downarrow$   \\ \hline
\rowcolor{gray!20}
\multicolumn{6}{c}{Dot dataset}                                         \\ \hline
\multicolumn{1}{c|}{Baseline} & 0.8975 & 0.2575 & 0.038 & \textbf{0.592} & 0.7394 \\
\multicolumn{1}{c|}{DeCoDEx}  & \textbf{1}      & \textbf{0.9775} & \textbf{0.036} & 0.559 & \textbf{0.058}  \\ \hline
\rowcolor{gray!20}
\multicolumn{6}{c}{Device dataset}                                      \\ \hline
\multicolumn{1}{c|}{Baseline} & \textbf{0.97}   & 0.7625 & 0.040 & 0.377 & 0.201  \\
\multicolumn{1}{c|}{DeCoDEx}  & 0.89   & \textbf{0.99}   & \textbf{0.035} & \textbf{0.529} & \textbf{0.068 } \\ \hline
\end{tabular}%
}
\caption{Quantitative results comparing the scores for the counterfactuals generated by the Baseline and DeCoDEx on both datasets. Notice the high DRR and low SCLS values for DeCoDEx, showing that the spurious correlations were ignored in the counterfactual images.}
\label{tab:final_results}
\end{table}

\subsection{Results}
\textbf{Classifier and Detector evaluation}
The classifier's performance on both datasets is shown in Table~\ref{tab:exp2} on the row labeled `ERM'. Note that the performance on the minority subgroup samples is significantly lower than the majority for both datasets. The dot and device detectors perform very well, with subgroup \textbf{{[}$\mathit{majority}_S$, $\mathit{minority}_S$, $\mathit{minority}_H$, $\mathit{majority}_H${]}} accuracies of [100, 99.9, 100, 99.8] and [91.8, 88.8, 79.2, 88.7], respectively. Near-perfect accuracies for the dot detector can be expected given that the position and size of the dot remain fixed for all the subjects. However, the variability in the size, position, and intensity of support devices can be large, making detection much more challenging.

\noindent\textbf{Qualitative evaluation}
Pleural effusion (PE) is characterized by the rounding of the costophrenic angle, augmented lung opacity, and reduced clarity of the diaphragm and lung fissures~\cite{light2002pleural}. It is observed in the lower corner of the lungs~\cite{wang2017chestx,huang2022deep}. Qualitative results for counterfactual generation from the models can be seen in Fig.~\ref{fig:dimevsdecodex}. For the baseline method on the Dot dataset, the counterfactual for a healthy subject simply added dots to the image, indicating that the classifier has latched onto the dot artifact. However, DeCoDEx ignored the dot artifact (and maintained the original text artifact) and made changes reflective of PE disease. For the Device dataset, the baseline generated a counterfactual for a sick patient from the majority subgroup by simply removing the support device(s) while DeCoDEx makes (correct) changes in the area associated with the disease. Therefore, DeCoDEx indeed ignores the spurious correlation.

\noindent\textbf{Quantitative evaluation} 
In Table~\ref{tab:final_results}, CF images generated by the baseline and DeCoDEx show similar results for the L1 metric, indicating that the CF images are similar to the factual images. CFs synthesized by DeCoDEx have SCLS scores close to zero, indicating that the artifact was preserved in the counterfactual images. Table~\ref{tab:exp2} shows the results of the augmentation experiments. Augmented training significantly improves the performance over the minority groups for the Dot dataset, indicating that CF samples have learned discriminative features common to this subgroup. For both ERM and Group-DRO based techniques, augmentation with DeCoDEx CFs generally outperforms augmentation with Baseline CFs on the minority subgroups. ERM augmented with CFs synthesized by DeCoDEx performs as well as Group-DRO for some minority subgroups, which is a positive result, particularly when debiasing cannot be performed due to the lack of annotations.
\begin{table}[t]
\centering
\resizebox{\textwidth}{!}{%
\begin{tabular}{ccccc}
\hline
\rowcolor{gray!20}
\multicolumn{5}{c}{\textbf{\textit{Dot dataset}}}                                                                                                                                                                                        \\ \hline
\multicolumn{1}{c}{}                       & \multicolumn{2}{c|}{\textbf{Pleural Effusion}}                                                 & \multicolumn{2}{c}{\textbf{Healthy}}                                  \\ \cline{2-5} 
\multicolumn{1}{c}{}                       & \multicolumn{1}{c|}{\cellcolor{green!20}Dot}          & \multicolumn{1}{c|}{\cellcolor{red!20} No Dot}          & \multicolumn{1}{c|}{\cellcolor{red!20} Dot}          & \cellcolor{green!20} No Dot          \\ \hline
\multicolumn{1}{c|}{\textbf{ERM}}      & 97                              & 2.2                                  & 8                                & 100                       \\ 
\multicolumn{1}{c|}{\textbf{ERM augmented with Baseline CFs}} & \textbf{98.5}                                          & 0                                            & 16.6                                         & 100             \\ 
\multicolumn{1}{c|}{\textbf{ERM augmented with DeCoDEx CFs}} & 90.6                                          & \textbf{12.1}                                            & \textbf{53.1}                                         & 98.0             \\ 
\hline
\multicolumn{1}{c|}{\textbf{Group-DRO}}      & 90                                           & 61.8                                           & 56.0                                          & 88.0                       \\
\multicolumn{1}{c|}{\textbf{Group-DRO augmented with Baseline CFs}} & \textbf{97}                                & 8.0                                  & 33.0                                  & \textbf{98.9}             \\ 
\multicolumn{1}{c|}{\textbf{Group-DRO augmented with DeCoDEx CFs}} & 91.0                                & \textbf{70.2 }                                & \textbf{60.9}                                  & 80.3              \\ \hline
\rowcolor{gray!20}
\multicolumn{5}{c}{\textbf{\textit{Device dataset}}}                                                                                                                                                                                        \\ \hline
\multicolumn{1}{c}{}                       & \multicolumn{2}{c|}{\textbf{Pleural Effusion}}                                                 & \multicolumn{2}{c}{\textbf{Healthy}}                                  \\ \cline{2-5} 
\multicolumn{1}{c}{}                       & \multicolumn{1}{c|}{\cellcolor{green!20}Support Device} & \multicolumn{1}{c|}{\cellcolor{red!20} No Support Device} & \multicolumn{1}{c|}{\cellcolor{red!20}Support Device} & \cellcolor{green!20}No Support Device \\ \hline
\multicolumn{1}{c|}{\textbf{ERM}}      & \textbf{92.7}                                & 75.2                                   & 84.9                                  & 87.0                       \\
\multicolumn{1}{c|}{\textbf{ERM augmented with Baseline CFs}} & 92.5                                          & 74.8                                            & 83.4                                         & 87.0            \\ 
\multicolumn{1}{c|}{\textbf{ERM augmented with DeCoDEx CFs}} & 92.6                                          & \textbf{76.3}                                            & \textbf{85.9}                                         & 86.8             \\ 
\hline
\multicolumn{1}{c|}{\textbf{Group-DRO}}      & 92.7                                           & 83.4                                            & 77.3                                           & 88.4                       \\ 
\multicolumn{1}{c|}{\textbf{Group-DRO augmented with Baseline CFs}} & \textbf{97.8}                                & \textbf{85.4}                                   & 55.9                                  & 72.0   
\\
\multicolumn{1}{c|}{\textbf{Group-DRO augmented with DeCoDEx CFs}} & 93.2                                & 84.7                                   & \textbf{79.0}                                  & 88.4    
\\ \hline

\hline

\end{tabular}%
}
\caption{Augmented Classifier Accuracies: CFs generated by DeCoDEx and the Baseline are used to augment the imbalanced datasets. Both ERM and Group-DRO are retrained on these augmented datasets and the effects are examined. The accuracies (percentages) of all classifiers are shown on the held out test set. Green indicates majority subgroups (90\%) and red are the minority subgroups (10\%). The best results are in bold. Note the increase in performance for the minority classes when both the methods, ERM and Group-DRO are augmented with DeCoDEx counterfactual images, illustrating the power of our method on extracting disease pathology and generating better counterfactual explanations.}
\label{tab:exp2}
\end{table}

\section{Conclusions}
In medical image analysis, explainable models are needed to expose and mitigate the bias to improve the trustworthiness of complex models. This paper presents DeCoDEx, an explainability framework that leverages a pre-trained classifier and detector to guide diffusion-based counterfactual synthesis towards accurate disease markers while ignoring spurious correlations. Qualitative and quantitative analyses of our extensive experiments indicate that the proposed method outperforms the baseline model that does not use a detector. Furthermore, the flexibility of our method allows it to be used with any pre-trained detector; it does not require retraining a debiasing classifier and associated generative architecture, and it provides guidance during inference. One of the current limitations of the model is that it introduces minor changes throughout the generated CF images. Future work will explore conditional score-based generative models.

\midlacknowledgments{The authors are grateful for funding provided by the Natural Sciences and Engineering Research Council of Canada, the Canadian Institute for Advanced Research (CIFAR) Artificial Intelligence Chairs program, Mila - Quebec AI Institute, Google Research, Calcul Quebec, and the Digital Research Alliance of Canada. }

\bibliography{midl24_296}

\newpage
\appendix
\section{Detailed dataset description}
\label{appendix:A}
We elaborate on the datasets and their variants used in our experiments in this appendix, complemented by statistical data on participant distribution across groups as detailed in Table \ref{table:samples}. 
\begin{table}[!htb]
\centering
\resizebox{0.7\textwidth}{!}{%
\begin{tabular}{ccccccc}
\hline
\rowcolor{gray!20}
\rowcolor{gray!20}
\multicolumn{7}{c}{\textbf{Dot dataset }}                   \\ \hline
\multicolumn{1}{c|}{\multirow{2}{*}{\textbf{Spurious Correlation }}}    & \multicolumn{3}{c|}{\textbf{Pleural Effusion}} & \multicolumn{3}{c}{\textbf{Healthy}} \\
\multicolumn{1}{c|}{}                  & Train & Validation & \multicolumn{1}{c|}{Test} & Train & Validation & Test \\ \hline
\multicolumn{1}{c|}{Dot}               & 1359  & 191        & \multicolumn{1}{c|}{242}  & 56    & 5          & 49    \\ \hline
\multicolumn{1}{c|}{No Dot}            & 166   & 28         & \multicolumn{1}{c|}{242}   & 340   & 51         & 49   \\ \hline
\rowcolor{gray!20}
\multicolumn{7}{c}{\textbf{Device dataset}} \\ \hline
\multicolumn{1}{c|}{\multirow{2}{*}{\textbf{Spurious Correlation} }}    & \multicolumn{3}{c|}{\textbf{Pleural Effusion}}  & \multicolumn{3}{c}{\textbf{Healthy}} \\
\multicolumn{1}{c|}{}                  & Train & Validation & \multicolumn{1}{c|}{Test}                      & Train & Validation & Test \\ \hline 
\multicolumn{1}{c|}{Support Device(s)} & 6653  & 1432       & \multicolumn{1}{c|}{1389} & 665   & 143        & 138  \\ \hline
\multicolumn{1}{c|}{No Support Device(s)} & 665  & 143 & \multicolumn{1}{c|}{138} & 6653      & 1432     & 1389     \\ \hline
\end{tabular}%
}
\caption{
Summary of the number of samples for both dataset variants}
\label{table:samples}
\end{table}
\vspace{-3mm}
\section{Explainability: Providing insight into the detector result}\label{appendix:explaindetector}
%Since we show that the gradients from the detector can help avoid spurious correlations, it is crucial to understand the features important to the detector. 
We wish to provide some insights into the workings of the artifact/medical device detector. Fig.~\ref{fig:detectorXAI} shows two examples of explainability via counterfactual image generation illustrating the correct working of the detector. In both examples, the binary classifier was correctly focusing on the support devices. These are removed in the counterfactual images in order to flip the decision of the binary classifier.

\begin{figure}[h]
\floatconts
  {}
  {\caption{CF explanations for the detector: Removing medical devices from the original images while explaining the detector. Note the disease state is maintained in the counterfactual image. }\label{fig:detectorXAI}}
  {\includegraphics[width=0.46\linewidth]{images/appendix.png}}
\end{figure}
\newpage
\section{Extensive augmentation of the minority subgroup with synthesized CFs}
\begin{table}[h]
\centering
\resizebox{\textwidth}{!}{%
\begin{tabular}{ccccc}
\hline
\rowcolor{gray!20}
\multicolumn{5}{c}{\textbf{\textit{Dot dataset}}}              \\ \hline
\multicolumn{1}{c}{}                       & \multicolumn{2}{c|}{\textbf{Pleural Effusion}}                                                 & \multicolumn{2}{c}{\textbf{Healthy}}                                  \\ \cline{2-5} 
\multicolumn{1}{c}{}                       & \multicolumn{1}{c|}{\cellcolor{green!20}Dot}          & \multicolumn{1}{c|}{\cellcolor{red!20} No Dot}          & \multicolumn{1}{c|}{\cellcolor{red!20} Dot}          & \cellcolor{green!20} No Dot          \\ \hline

\multicolumn{1}{c|}{\textbf{ERM augmented with DeCoDEx CFs [200]}} & 90.6                                          & 12.1                                            & 53.1                                         & 98.0             \\ 
\multicolumn{1}{c|}{\cellcolor{cyan!10}\textbf{ERM augmented with DeCoDEx CFs [400]}} & \cellcolor{cyan!10}\textbf{93.8} & \cellcolor{cyan!10}\textbf{26.5}  & \cellcolor{cyan!10} \textbf{61.2}  & \cellcolor{cyan!10}98.3             \\ 
\hline


\multicolumn{1}{c|}{\textbf{Group-DRO augmented with DeCoDEx CFs [200]}} & 91.0                                & 70.2                                & 60.9                                  & 80.3              \\
\multicolumn{1}{c|}{\cellcolor{cyan!10}\textbf{Group-DRO augmented with DeCoDEx CFs [400]}} & \cellcolor{cyan!10}\textbf{93.1} & \cellcolor{cyan!10}\textbf{81.6}& \cellcolor{cyan!10}\textbf{65.3} & \cellcolor{cyan!10}85.9              \\ \hline
\rowcolor{gray!20}
\multicolumn{5}{c}{\textbf{\textit{Device dataset}}}                                                                                                                                                                                        \\ \hline
\multicolumn{1}{c}{}                       & \multicolumn{2}{c|}{\textbf{Pleural Effusion}}                                                 & \multicolumn{2}{c}{\textbf{Healthy}}                                  \\ \cline{2-5} 
\multicolumn{1}{c}{}                       & \multicolumn{1}{c|}{\cellcolor{green!20}Support Device} & \multicolumn{1}{c|}{\cellcolor{red!20} No Support Device} & \multicolumn{1}{c|}{\cellcolor{red!20}Support Device} & \cellcolor{green!20}No Support Device \\ \hline
\multicolumn{1}{c|}{\textbf{ERM augmented with DeCoDEx CFs [600]}} & 92.6                                          & 76.3                                            & 85.9                                         & 86.8             \\ 
\multicolumn{1}{c|}{\cellcolor{cyan!10}\textbf{ERM augmented with DeCoDEx CFs [1600]}} & \cellcolor{cyan!10}\textbf{92.8}  & \cellcolor{cyan!10}\textbf{76.8}  & \cellcolor{cyan!10}\textbf{86.8}  & \cellcolor{cyan!10} 82.3\\ 
\hline
\multicolumn{1}{c|}{\textbf{Group-DRO augmented with DeCoDEx CFs [600]}} & 93.2 & 84.7 & 79.0 & 88.4    
\\ 
\multicolumn{1}{c|}{\cellcolor{cyan!10}\textbf{Group-DRO augmented with DeCoDEx CFs [1600]}} & \cellcolor{cyan!10}\textbf{93.5} & \cellcolor{cyan!10}\textbf{86.5}& \cellcolor{cyan!10}\textbf{79.9} & \cellcolor{cyan!10}\textbf{90.1}    
\\ \hline

\end{tabular}%
}
\caption{Improved accuracy at the subgroup level through extensive classifier augmentation: Using DeCoDEx, we expanded counterfactual generation to demonstrate improvement in subgroup accuracy. We synthesize 400 counterfactual samples for the Dot dataset and 1600 for the Device dataset (the number in square brackets refers to the total number of augmentation samples added to both \textcolor{green}{majority} and \textcolor{red}{minority} subgroups). For each classifier, the first row shows the original results discussed in Table~\ref{tab:exp2} and the \textcolor{cyan}{second row} shows the results after extensive augmentation. Notice the improvement in the accuracy of the minority subgroups across both the Dot and Device datasets. Notably, 90\% of these counterfactuals belong to minority subgroups, thereby achieving a more balanced distribution in the dataset. Our findings indicate an improvement in the accuracy of the minority groups across all scenarios.}
\label{tab:abl2}
\end{table}


\section{Validating the Preservation of Patient Sex in the Synthesized Counterfactual Images}
The identity preservation loss in Equation~\ref{eq:completelossfunction} does not guarantee that all other attributes of the patients are maintained in the counterfactual images. A quick experiment was performed to validate that the sex of the patients is maintained in the counterfactual images generated by DeCoDEx. To this end, a sex classifier, $\mathcal{G}$, was trained on the real images and then tested on both real and synthesized counterfactual images. The sex classifier achieved an AUC-ROC of 0.98 on the real (factual) images. The absolute difference between the sex classifier outputs for the factual (F) and counterfactual (CF) images, $\lvert\mathcal{G}(\text{F}) - \mathcal{G}(\text{CF})\rvert$, was 0.08 on average, indicating that the sex attribute was maintained in the counterfactual images.

\end{document}
