Closing the Gap between TD Learning and Supervised Learning -- A Generalisation Point of View.

Published: 07 Nov 2023, Last Modified: 21 Nov 2023, FMDM@NeurIPS2023
Keywords: reinforcement learning, decision transformers, stitching, data augmentation
TL;DR: We show that the stitching property in RL is a form of generalization that SL-based methods cannot be expected to achieve, and we show how a new type of data augmentation facilitates this type of generalization.
Abstract: Recent years have seen a drastic shift of focus toward large models trained with simple self-supervised objectives on diverse datasets. These foundation models have become ubiquitous in NLP and vision because they are broadly applicable to many downstream tasks. These successes have sparked various attempts to train similar models for RL problems. Decision Transformer (DT) is one such popular approach: it treats RL as a sequence modeling problem and uses a transformer to predict actions. This algorithmic choice, though simple, can have certain limitations compared with traditional RL algorithms. In this paper, we study one such limitation: the ability to recombine pieces of previously seen experience to solve a task never seen during training. We study this question in the setting of goal-reaching problems. We formalize this desirable property as a form of stitching generalization: after training on a distribution of (state, goal) pairs, one would like to evaluate on (state, goal) pairs not seen together in the training data. Our analysis shows that this sort of generalization is different from i.i.d. generalization. This connection between stitching and generalization reveals why we should not expect existing DT-like methods to perform stitching, even in the limit of large datasets and models. We experimentally validate this result on carefully constructed datasets. This connection also suggests a simple remedy, the same one used to improve generalization in supervised learning: data augmentation. We propose a naive temporal data augmentation approach and demonstrate that adding it to SL-based RL methods enables them to stitch together experience, so that they succeed in navigating between states and goals never seen together during training.
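The two ideas in the abstract, (state, goal) pairs that never co-occur in one training trajectory, and a temporal augmentation that recombines trajectories through a shared waypoint, can be illustrated on a toy dataset. The sketch below is a simplified illustration under our own assumptions, not the authors' implementation. States are discrete labels, so two trajectories "meet" only when they visit the identical state (a continuous setting would need some notion of nearby states). We assume the SL method relabels goals in hindsight with future states of the same trajectory. Actions are omitted. The helper names `train_pairs` and `temporal_augment` are hypothetical.

```python
import random
from typing import Dict, List, Set, Tuple

State = str
Trajectory = List[State]
Pair = Tuple[State, State]


def train_pairs(trajs: List[Trajectory]) -> Set[Pair]:
    """(state, goal) pairs seen under (assumed) hindsight relabelling:
    the goal is any state occurring later in the *same* trajectory."""
    pairs: Set[Pair] = set()
    for traj in trajs:
        for i, s in enumerate(traj):
            for g in traj[i + 1:]:
                pairs.add((s, g))
    return pairs


def temporal_augment(trajs: List[Trajectory], rng: random.Random,
                     num_samples: int) -> Set[Pair]:
    """Hypothetical simplified temporal augmentation: pick a start state s and
    a later waypoint w in one trajectory, jump to any visit of w (possibly in
    another trajectory), and use a state from that visit's future as s's goal."""
    # Map each state to every (trajectory index, time step) where it occurs.
    visits: Dict[State, List[Tuple[int, int]]] = {}
    for t, traj in enumerate(trajs):
        for i, s in enumerate(traj):
            visits.setdefault(s, []).append((t, i))

    augmented: Set[Pair] = set()
    for _ in range(num_samples):
        traj = trajs[rng.randrange(len(trajs))]
        if len(traj) < 2:
            continue
        i = rng.randrange(len(traj) - 1)        # start state index
        j = rng.randrange(i + 1, len(traj))     # waypoint index, strictly after i
        t2, k = rng.choice(visits[traj[j]])     # some visit to the waypoint
        other = trajs[t2]
        if k + 1 >= len(other):                 # waypoint is terminal there
            continue
        goal = other[rng.randrange(k + 1, len(other))]
        # Can reproduce an original pair when the chosen visit is in traj itself.
        augmented.add((traj[i], goal))
    return augmented


if __name__ == "__main__":
    trajs = [["A", "B", "C"], ["D", "B", "E"]]
    seen = train_pairs(trajs)
    aug = temporal_augment(trajs, random.Random(0), num_samples=1000)
    print("(A, E) in training pairs:", ("A", "E") in seen)
    print("(A, E) in augmented pairs:", ("A", "E") in aug)
```

With the trajectories A→B→C and D→B→E, the pair (A, E) never appears among the training pairs, even though every transition needed to reach E from A is in the data. This is the kind of stitching pair the abstract describes. Routing through the shared waypoint B lets the augmentation relabel A's goal as E.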
Submission Number: 68