Keywords: Joint-Embedding Predictive Architecture, representation collapse
Abstract: JEPA-style predictive models are promising foundations for world models, yet they exhibit a surprising early-training pathology: the prediction loss drops rapidly while learned representations remain useless for downstream tasks. We call this the collapse phase. It arises because the exponential moving average (EMA) keeps the target encoder too close to the main encoder, making prediction trivial and allowing the model to minimize its objective without extracting meaningful structure. We derive an upper bound showing that collapse risk depends on the interplay between momentum dynamics and masking strategy, and that escape requires the encoder's updates to outpace the EMA's smoothing. Empirically, we show the collapse phase appears across images and time-series sensor data, lasting thousands of steps in each case. Our analysis provides a diagnostic metric for detecting collapse during training and explains why certain hyperparameter regimes prolong it. This reframes collapse not as an objective flaw but as a transient regime with predictable onset and recovery, offering practitioners a tool to monitor and understand early training dynamics in predictive world models.
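Illustrative sketch (not taken from the paper): the standard EMA target update, plus a toy simulation of how far the target encoder's parameters trail the main encoder's. The constant-drift "gradient step", the two-parameter vectors, and the names `ema_update`, `param_gap`, and `simulate_gap` are assumptions made for illustration only. This is neither the paper's upper bound nor its diagnostic metric. It only shows that, in this toy setting, the gap is governed jointly by the momentum and the size of the encoder's updates.

```python
def ema_update(target, online, momentum):
    """Standard EMA rule: target <- m * target + (1 - m) * online."""
    return [momentum * t + (1.0 - momentum) * o for t, o in zip(target, online)]


def param_gap(target, online):
    """Euclidean distance between two parameter vectors."""
    return sum((t - o) ** 2 for t, o in zip(target, online)) ** 0.5


def simulate_gap(momentum, step_size, num_steps):
    """Toy run: the online parameters drift by a constant step each
    iteration (a hypothetical stand-in for gradient updates) and the
    target follows by EMA. Returns the final online-target distance."""
    online = [0.0, 0.0]
    target = [0.0, 0.0]
    for _ in range(num_steps):
        online = [w + step_size for w in online]       # toy encoder update
        target = ema_update(target, online, momentum)  # EMA smoothing
    return param_gap(target, online)


if __name__ == "__main__":
    for m in (0.5, 0.9, 0.99):
        print(f"momentum={m}: gap={simulate_gap(m, 0.1, 200):.4f}")
```

Under this constant-drift assumption the per-parameter gap settles near `momentum * step_size / (1 - momentum)`, so larger encoder updates or a higher momentum both widen the separation between the two encoders. Real training dynamics are not a constant drift, so this is only an intuition aid.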
Submission Number: 86