Keywords: World Models, Deep Supervision, Linear Probes, Representation Learning
TL;DR: We improve world models in recurrent networks by adding a linear probe loss, which enhances prediction accuracy, reduces distribution drift, and enables smaller models to achieve better generalization.
Abstract: Developing effective world models is crucial for creating artificial agents that can reason about and navigate complex environments.
In this paper, we investigate a supervision technique for encouraging the development of a world model in a network trained end-to-end to predict the next observation. We explore the effect of adding a linear probe component to the network's loss function: a term that encourages the network to encode a subset of additional raw observations from the world and agent's position into its hidden state.
While deep supervision has been widely applied for task-specific learning, our focus is specifically on improving world models.
We demonstrate our findings in an experimental environment based on the Flappy Bird game, where the agent receives only LIDAR measurements as observations. The supervision technique improves both training and test performance, reduces distribution drift, enhances training stability, and results in more easily decodable world features, even for features that were not included in the training. The benefit of including the probe in training roughly corresponded to that of doubling the model size, highlighting its value in compute-limited settings or when aiming to achieve the best performance with smaller models. These findings contribute to our understanding of how to develop more robust and sophisticated world models in artificial agents.
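To make the loss structure described in the abstract concrete, the following is a minimal sketch of a next-observation loss combined with a linear probe term on the hidden state. The abstract does not specify the exact loss form, so the use of mean squared error, the weighting factor `probe_weight`, and all function and parameter names here are illustrative assumptions rather than the authors' implementation.

```python
# Hypothetical sketch: next-observation loss plus a linear probe loss.
# MSE, probe_weight, and all names are assumptions made for illustration.

def mse(pred, target):
    """Mean squared error between two equal-length sequences."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

def linear_probe(hidden, weights, bias):
    """Linear readout of world features from the hidden state.

    weights holds one row per probed feature; each row has len(hidden) entries.
    """
    return [
        sum(w * h for w, h in zip(row, hidden)) + b
        for row, b in zip(weights, bias)
    ]

def total_loss(pred_obs, next_obs, hidden, probe_w, probe_b,
               probe_targets, probe_weight=0.1):
    """Combine the prediction loss with the weighted probe loss."""
    prediction_loss = mse(pred_obs, next_obs)
    probe_loss = mse(linear_probe(hidden, probe_w, probe_b), probe_targets)
    return prediction_loss + probe_weight * probe_loss

# Toy example with hand-picked values.
loss = total_loss(
    pred_obs=[0.0, 0.0],
    next_obs=[1.0, 1.0],
    hidden=[1.0, 2.0],
    probe_w=[[1.0, 0.0], [0.0, 1.0]],
    probe_b=[0.0, 0.0],
    probe_targets=[1.0, 4.0],
    probe_weight=0.5,
)
```

In a trained network the gradient of the probe term would flow back into the recurrent weights, which is what would encourage the hidden state to encode the probed features linearly; this sketch only shows how the two terms are combined.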
Submission Number: 48