TWISTED: Enhancing Transformer World Models with Spatio-Temporal Encoding and Graph-Based Optimal Decoding
Keywords: model-based rl, vision-based rl, transformer world model
TL;DR: We propose TWISTED, a transformer world model that captures spatio-temporal characteristics of visual environments with 3D positional encoding and optimal transport-based decoding.
Abstract: Model-based reinforcement learning improves sample efficiency by using learned world models to simulate experiences for training agents.
Recent world models that leverage transformers produce high-quality simulations, leading to better agent performance.
However, transformer world models underutilize spatial relationships between visually adjacent tokens, which are critical when interacting in visual environments.
Additionally, current models rely on sampling methods for transformer decoding that do not exploit the visual similarity between consecutive frames.
To address these limitations, we introduce TWISTED, a transformer world model with 3D spatio-temporal positional encoding and a graph-based optimal decoding strategy specific to visual environments.
Our experiments show state-of-the-art performance on the Craftax-classic, Craftax, and MinAtar benchmarks, which are challenging visual environments that require long-horizon object recall and interaction.
On Craftax-classic, the proposed method achieves a return of 72.5% and a score of 35.6%, significantly surpassing the previous bests of 67.4% and 27.9%, respectively.
We plan to release our source code on GitHub upon acceptance.
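The abstract does not say how the 3D spatio-temporal positional encoding is built. As a purely illustrative sketch, assume a factorized design: each token in a T×H×W grid of frame tokens gets the concatenation of standard 1D sinusoidal encodings of its frame index t, row y, and column x. This shows how visually adjacent tokens can share most of their positional signal. The factorization and the helper names `sinusoidal_1d` and `positional_encoding_3d` are hypothetical and are not TWISTED's actual encoding.

```python
import math

def sinusoidal_1d(pos, dim, base=10000.0):
    """Standard 1D sinusoidal encoding (interleaved sin/cos) for an integer position."""
    assert dim % 2 == 0, "each axis needs an even number of channels"
    enc = []
    for i in range(dim // 2):
        freq = 1.0 / (base ** (2 * i / dim))
        enc.append(math.sin(pos * freq))  # channel 2i
        enc.append(math.cos(pos * freq))  # channel 2i + 1
    return enc

def positional_encoding_3d(num_frames, height, width, d_model):
    """Hypothetical factorized 3D encoding: concatenate 1D encodings of (t, y, x).

    Returns one d_model-dimensional vector per token, in frame-major raster
    order (t, then y, then x), i.e. num_frames * height * width vectors.
    """
    assert d_model % 6 == 0, "d_model must split into three even-sized chunks"
    axis_dim = d_model // 3
    table = []
    for t in range(num_frames):
        for y in range(height):
            for x in range(width):
                table.append(sinusoidal_1d(t, axis_dim)      # temporal part
                             + sinusoidal_1d(y, axis_dim)    # row part
                             + sinusoidal_1d(x, axis_dim))   # column part
    return table

pe = positional_encoding_3d(num_frames=2, height=3, width=3, d_model=12)
print(len(pe), len(pe[0]))  # 18 tokens, 12 channels each
```

Under this assumed design, horizontally adjacent tokens in the same frame differ only in the column channels, and the same spatial cell in the next frame differs only in the temporal channels. A flat 1D encoding over the raster-ordered token sequence does not preserve this structure.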
Supplementary Material: zip
Primary Area: reinforcement learning
Submission Number: 17056