Accurate and Efficient World Modeling with Masked Latent Transformers

Published: 01 May 2025, Last Modified: 18 Jun 2025, ICML 2025 poster, CC BY 4.0
TL;DR: We introduce EMERALD, a Transformer-based, model-based reinforcement learning algorithm that uses a spatial latent state with MaskGIT predictions to generate accurate trajectories in latent space and improve agent performance.
Abstract: The Dreamer algorithm has recently achieved remarkable performance across diverse environment domains by training powerful agents on simulated trajectories. However, the compressed nature of its world model's latent space can cause crucial information to be lost, degrading the agent's performance. Recent approaches, such as $\Delta$-IRIS and DIAMOND, address this limitation by training more accurate world models. However, these methods require training agents directly from pixels, which reduces training efficiency and prevents the agent from benefiting from the internal representations learned by the world model. In this work, we propose an alternative approach to world modeling that is both accurate and efficient. We introduce EMERALD (Efficient MaskEd latent tRAnsformer worLD model), a world model that uses a spatial latent state with MaskGIT predictions to generate accurate trajectories in latent space and improve agent performance. On the Crafter benchmark, EMERALD achieves new state-of-the-art performance, becoming the first method to surpass human expert performance within 10M environment steps. Our method also unlocks all 22 Crafter achievements at least once during evaluation.
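For context on how Crafter results are usually summarized: the benchmark reports an aggregate score that is the geometric mean of the per-achievement success rates, $S = \exp\big(\tfrac{1}{N}\sum_{i=1}^{N}\ln(1 + s_i)\big) - 1$, where $s_i \in [0, 100]$ is the success rate of achievement $i$ in percent and $N = 22$. The minimal sketch below computes it; the success rates are made-up placeholders, not EMERALD's reported numbers, and `crafter_score` and `unlocked_count` are hypothetical helper names.

```python
import math

def crafter_score(success_rates_percent):
    """Aggregate Crafter score: geometric mean of (1 + s_i), minus 1.

    Each s_i is an achievement success rate in percent, in [0, 100].
    """
    n = len(success_rates_percent)
    mean_log = sum(math.log(1.0 + s) for s in success_rates_percent) / n
    return math.exp(mean_log) - 1.0

def unlocked_count(success_rates_percent):
    """Number of achievements unlocked at least once (success rate above zero)."""
    return sum(1 for s in success_rates_percent if s > 0.0)

# Hypothetical success rates for the 22 achievements: half always, half never.
rates = [100.0] * 11 + [0.0] * 11
score = crafter_score(rates)
print(f"score={score:.2f}, unlocked={unlocked_count(rates)}/22")
# The arithmetic mean would be 50; the geometric-mean score is sqrt(101) - 1, about 9.05.
```

The geometric mean rewards breadth: because each term grows like $\ln(1 + s_i)$, raising a rarely unlocked achievement increases the score more than pushing an already frequent one closer to 100%.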
Lay Summary: We introduce EMERALD, a new world modeling method that helps computers simulate the world more accurately and efficiently than previous approaches. World models can improve sample efficiency and safety when training AI agents by generating imaginary training trajectories instead of requiring interaction with the real world. Our world model uses a spatial hidden state to carry more information and simulate the environment more accurately. This added precision improves the agent's performance in complex visual environments like Crafter, where details can be crucial. We also propose to use MaskGIT, an efficient prediction algorithm used in image and video generation, to make predictions over spatial states. This makes EMERALD both accurate and efficient compared to previous approaches. We evaluate our method on the Crafter benchmark and demonstrate state-of-the-art performance. Our method also generalizes to Atari games, which do not necessarily require a spatial hidden state to perceive crucial details, and achieves strong performance there.
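The efficiency of MaskGIT comes from decoding a grid of discrete latent tokens in a few parallel steps rather than one token at a time. The sketch below illustrates generic MaskGIT-style iterative decoding with a cosine masking schedule; it is not EMERALD's implementation. `predict_logits` is a hypothetical stand-in for the Transformer (here it returns random logits, and any conditioning on past latent states or actions is omitted), and the annealed confidence noise used in the original MaskGIT is left out for brevity.

```python
import math
import random

MASK = -1  # id marking a latent token that has not been decoded yet

def cosine_mask_count(step, num_steps, num_tokens):
    """Tokens left masked after decoding step `step` (0-indexed), cosine schedule."""
    ratio = math.cos(math.pi / 2 * (step + 1) / num_steps)
    return int(math.floor(num_tokens * ratio))

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def maskgit_decode(predict_logits, num_tokens, vocab_size, num_steps, rng):
    """Decode a flattened grid of discrete tokens in `num_steps` parallel passes."""
    tokens = [MASK] * num_tokens
    for step in range(num_steps):
        # One forward pass predicts a distribution for every position at once.
        logits = predict_logits(tokens)
        candidate = list(tokens)
        # Tokens fixed in earlier steps get infinite confidence, so they are never re-masked.
        confidence = [math.inf] * num_tokens
        for i in range(num_tokens):
            if tokens[i] == MASK:
                probs = softmax(logits[i])
                k = rng.choices(range(vocab_size), weights=probs)[0]
                candidate[i] = k
                confidence[i] = probs[k]
        # Keep the most confident new tokens; re-mask the least confident ones.
        n_masked = cosine_mask_count(step, num_steps, num_tokens)
        least_confident = sorted(range(num_tokens), key=lambda i: confidence[i])[:n_masked]
        for i in least_confident:
            candidate[i] = MASK
        tokens = candidate
    return tokens

def toy_predictor(vocab_size, rng):
    """Hypothetical stand-in for the Transformer: random logits per position."""
    def predict_logits(tokens):
        return [[rng.gauss(0.0, 1.0) for _ in range(vocab_size)] for _ in tokens]
    return predict_logits

rng = random.Random(0)
# A 4x4 spatial latent grid (16 tokens, vocabulary of 8) decoded in 4 passes.
grid = maskgit_decode(toy_predictor(8, rng), num_tokens=16, vocab_size=8, num_steps=4, rng=rng)
print([grid[r * 4:(r + 1) * 4] for r in range(4)])
```

With 16 tokens and 4 steps, the decoder makes 4 forward passes instead of 16 sequential ones. The cosine schedule fixes only a few tokens early, when predictions are least certain, and more in later steps (2, 3, 5, then 6 tokens here).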
Primary Area: Reinforcement Learning
Keywords: model-based reinforcement learning, MaskGIT, transformer network, crafter
Submission Number: 3144