Stable-Drift: A Patient-Aware Latent Drift Replay Method for Stabilizing Representations in Continual Learning
Abstract: When deep learning models are trained sequentially on new data, they tend to abruptly lose performance on previously learned tasks, a critical failure known as catastrophic forgetting. This challenge severely limits the deployment of AI in medical imaging, where models must continually adapt to data from new hospitals without compromising established diagnostic knowledge. To address this, we introduce a latent drift-guided replay method that identifies and replays samples with high representational instability. Specifically, our method quantifies this instability via "latent drift", the change in a sample's internal feature representation after naive domain adaptation. To ensure diversity and clinical relevance, we aggregate drift at the patient level: for each patient, our memory buffer stores the slices exhibiting the greatest multi-layer representation shift. Evaluated on a cross-hospital COVID-19 CT classification task using state-of-the-art CNN and Vision Transformer backbones, our method substantially reduces forgetting compared to naive fine-tuning and random replay. This work highlights latent drift as a practical and interpretable replay signal for advancing robust continual learning in real-world medical settings.
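The selection rule described in the abstract (per-slice drift measured across several layers, then aggregated per patient to fill the replay buffer) can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' implementation. The abstract does not specify the drift metric, the layers used, or the buffer budget, so the sketch assumes cosine distance between each slice's features before and after naive fine-tuning, averaged over layers. It also assumes patients are ranked by mean slice drift when a patient budget applies. The function names `layer_drift`, `multi_layer_drift`, and `select_buffer` and their parameters are hypothetical, and the features are assumed to be extracted beforehand from the same slices with the pre- and post-adaptation models.

```python
import numpy as np
from collections import defaultdict


def layer_drift(before: np.ndarray, after: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Per-sample cosine distance between (n_samples, dim) feature matrices.

    Assumption: cosine distance is used as the drift metric (not given in the abstract).
    """
    num = np.sum(before * after, axis=1)
    den = np.linalg.norm(before, axis=1) * np.linalg.norm(after, axis=1) + eps
    return 1.0 - num / den


def multi_layer_drift(feats_before, feats_after) -> np.ndarray:
    """Average per-sample drift over layers; inputs are lists of (n, d_layer) arrays."""
    per_layer = [layer_drift(b, a) for b, a in zip(feats_before, feats_after)]
    return np.mean(np.stack(per_layer, axis=0), axis=0)


def select_buffer(drift, patient_ids, slices_per_patient=1, max_patients=None):
    """Return indices of the highest-drift slices of each patient.

    Hypothetical budget rule: patients are ranked by mean slice drift and,
    if max_patients is set, only the top-ranked patients contribute slices.
    """
    by_patient = defaultdict(list)
    for idx, pid in enumerate(patient_ids):
        by_patient[pid].append(idx)

    # Patient-level aggregation: rank patients by their mean slice drift.
    ranked = sorted(by_patient, key=lambda p: float(np.mean(drift[by_patient[p]])), reverse=True)
    if max_patients is not None:
        ranked = ranked[:max_patients]

    buffer = []
    for pid in ranked:
        # Keep each selected patient's most-drifted slices.
        idxs = sorted(by_patient[pid], key=lambda i: drift[i], reverse=True)
        buffer.extend(idxs[:slices_per_patient])
    return buffer


if __name__ == "__main__":
    # Toy features: 4 slices from 2 patients, 2 layers.
    before1 = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
    after1 = np.array([[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 1.0]])
    layer2 = np.ones((4, 3))  # unchanged layer, so zero drift
    d = multi_layer_drift([before1, layer2], [after1, layer2])
    print(select_buffer(d, ["A", "A", "B", "B"]))  # prints [1, 3]
```

Selecting slices per patient, instead of taking the global top-k, keeps the buffer from being dominated by a few highly drifted patients. That is one plausible way to realize the diversity goal the abstract mentions.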