Confounder-Free Continual Learning via Recursive Feature Normalization

Published: 01 May 2025, Last Modified: 18 Jun 2025. ICML 2025 poster. License: CC BY 4.0
TL;DR: We introduce the Recursive Metadata Normalization (R-MDN) layer to learn confounder-invariant feature representations under changing distributions of the data during continual learning.
Abstract: Confounders are extraneous variables that affect both the input and the target, resulting in spurious correlations and biased predictions. Recent advances address or remove confounders in traditional, static models; one example is metadata normalization (MDN), which adjusts the distribution of the learned features based on the study confounders. However, in the context of continual learning, where a model learns continuously from new data over time without forgetting, learning feature representations that are invariant to confounders remains a significant challenge. To remove their influence from intermediate feature representations, we introduce the Recursive MDN (R-MDN) layer, which can be integrated into any deep learning architecture, including vision transformers, and at any model stage. R-MDN performs statistical regression via the recursive least squares algorithm to maintain and continually update an internal model state with respect to changing distributions of data and confounding variables. Our experiments demonstrate that R-MDN promotes equitable predictions across population groups, both within static learning and across different stages of continual learning, by reducing catastrophic forgetting caused by confounder effects changing over time.
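The core idea, regressing features on confounders with recursive least squares and keeping only the residual, can be sketched as follows. This is a minimal illustration of the mechanism, not the authors' implementation (see the linked repository for that); the class name `RecursiveLS` and the simple per-sample update are assumptions for exposition, and an intercept term is omitted by assuming zero-mean confounders.

```python
import numpy as np

class RecursiveLS:
    """Sketch of recursive-least-squares residualization.

    Regresses feature vectors on confounder vectors, maintaining an
    internal state (coefficients and inverse-covariance estimate) that
    is updated as new samples arrive, so no past data must be stored.
    The confounder-adjusted feature is the regression residual.
    """

    def __init__(self, n_confounders, n_features, init_scale=1e4):
        # Running regression coefficients (confounders -> features).
        self.beta = np.zeros((n_confounders, n_features))
        # Running estimate of the inverse confounder covariance;
        # a large initial scale encodes a weak prior on beta.
        self.P = np.eye(n_confounders) * init_scale

    def update(self, x, f):
        """One RLS step for confounders x and features f.

        Returns the residualized feature f - x @ beta.
        """
        x = x.reshape(-1, 1)                  # column vector
        Px = self.P @ x
        k = Px / (1.0 + x.T @ Px)             # gain vector
        err = f - (x.T @ self.beta).ravel()   # prediction error
        self.beta += k @ err.reshape(1, -1)   # coefficient update
        self.P -= k @ Px.T                    # covariance update
        return f - (x.T @ self.beta).ravel()  # residual feature
```

On streaming data whose feature distribution is linearly driven by the confounders, the residual returned by `update` becomes approximately uncorrelated with the confounders as the internal state converges, which is the invariance property the layer is designed to provide.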
Lay Summary: When we train an AI model, it can be influenced by certain factors, which we call confounders, that affect both the input to the model and the outcome that the model predicts. This can cause the model to learn false associations from the data and make biased predictions. For example, a model trained to diagnose neurodegenerative disorders from brain MRI scans acquired at two different clinical sites, where healthy patients come from one site and diseased patients come from the other, can start learning site-specific patterns for prediction instead of actual signs of the disease. Recently proposed techniques such as metadata normalization (MDN) help fix this issue, but they assume that all data is collected before the model is trained. However, in many settings data is collected over long periods of time, often months or years, and we would like to keep training the model as new data comes in rather than starting from scratch. This setting is called continual learning. We propose a new method called Recursive MDN to combat biased learning in continual learning settings by using a technique from statistics called recursive least squares. Our model continually tracks changes in the distribution of confounders and learns in real time to avoid making misleading predictions. In our tests, we find that our model makes equitable predictions across various population groups.
Primary Area: Applications->Health / Medicine
Keywords: deep neural networks, confounders, continual learning, invariant representations, statistical regression
Link To Code: https://github.com/stanfordtailab/RMDN.git
Submission Number: 14097