A Lightweight 3D Conditional Diffusion Model for Self-explainable Brain Age Prediction in Adults and Children
Abstract: Deep learning models that predict an individual’s biological brain age from structural MR images are widely used in neuroimaging to analyze brain development and aging. While standard discriminative models for this task are highly accurate, they usually offer limited explainability. This limitation can be overcome by inherently self-explainable models that follow a generative paradigm: such models not only perform predictions but can also generate counterfactual images that visually explain their decision-making process. Denoising diffusion models have recently achieved remarkable results in generative machine learning. However, training them on full-resolution 3D MR images, as many tasks require, is computationally expensive. Here, we present a new lightweight wavelet-based diffusion model architecture that enables computationally efficient training of age-conditioned diffusion models on 3D brain MR images with a single consumer-grade GPU. We then show how these lightweight diffusion models can be used for self-explainable biological brain age prediction. The results of our proof-of-concept evaluation, relying on imaging data of more than 7000 adults and more than 5000 images of children, demonstrate the effectiveness of our method. This includes diverse and realistic samples generated by the models, accurate brain age predictions (mean absolute error for adults: 3.93 yrs., children: 1.38 yrs.), and realistic counterfactual images.
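The abstract attributes the efficiency gain to a wavelet-based architecture but does not specify the wavelet or the number of decomposition levels. As a minimal sketch, assuming a single-level orthonormal Haar transform, the snippet below shows how a 3D volume with even side lengths maps to eight half-resolution subbands, which a diffusion network could denoise as eight input channels, and how the transform is inverted exactly. The function names `haar_dwt3d` and `haar_idwt3d` are hypothetical and are not taken from the paper.

```python
import numpy as np


def haar_dwt3d(x: np.ndarray) -> np.ndarray:
    """Single-level orthonormal 3D Haar DWT (illustrative sketch).

    Maps a volume of shape (D, H, W) with even sides to an array of shape
    (8, D/2, H/2, W/2): index 0 is the low-pass (LLL) band, 1..7 are details.
    """
    if x.ndim != 3 or any(s % 2 for s in x.shape):
        raise ValueError("expected a 3D volume with even side lengths")
    bands = [x]
    for axis in range(3):
        split = []
        for b in bands:
            even = np.take(b, np.arange(0, b.shape[axis], 2), axis=axis)
            odd = np.take(b, np.arange(1, b.shape[axis], 2), axis=axis)
            split.append((even + odd) / np.sqrt(2))  # low-pass half
            split.append((even - odd) / np.sqrt(2))  # high-pass half
        bands = split
    return np.stack(bands)


def haar_idwt3d(c: np.ndarray) -> np.ndarray:
    """Exact inverse of haar_dwt3d: (8, D/2, H/2, W/2) -> (D, H, W)."""
    bands = list(c)
    for axis in reversed(range(3)):
        merged = []
        # Consecutive pairs are (low, high) along the current axis.
        for low, high in zip(bands[0::2], bands[1::2]):
            even = (low + high) / np.sqrt(2)
            odd = (low - high) / np.sqrt(2)
            shape = list(low.shape)
            shape[axis] *= 2
            out = np.empty(shape, dtype=np.result_type(low, high))
            idx = [slice(None)] * low.ndim
            idx[axis] = slice(0, None, 2)
            out[tuple(idx)] = even  # re-interleave even samples
            idx[axis] = slice(1, None, 2)
            out[tuple(idx)] = odd  # re-interleave odd samples
            merged.append(out)
        bands = merged
    return bands[0]
```

Under this assumption, a volume of shape (D, H, W) becomes an 8-channel tensor of shape (8, D/2, H/2, W/2), so a network operating in the wavelet domain processes eight times fewer voxels per channel while the transform itself loses no information.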