DiffStitch: Boosting Offline Reinforcement Learning with Diffusion-based Trajectory Stitching

Published: 02 May 2024 · Last Modified: 25 Jun 2024 · ICML 2024 Poster · CC BY 4.0
Abstract: In offline reinforcement learning (RL), the performance of the learned policy depends heavily on the quality of the offline dataset. In many cases, however, the dataset contains very few optimal trajectories. This poses a challenge for offline RL algorithms, as agents must acquire the ability to transition into high-reward regions. To address this issue, we introduce Diffusion-based Trajectory Stitching (DiffStitch), a novel diffusion-based data augmentation pipeline that systematically generates stitching transitions between trajectories. DiffStitch effectively connects low-reward trajectories with high-reward trajectories, forming globally optimal trajectories and thereby mitigating the difficulty offline RL algorithms face in learning to stitch trajectories. Empirical experiments on D4RL datasets demonstrate the effectiveness of our pipeline across RL methodologies. Notably, DiffStitch yields substantial performance gains for one-step methods (IQL), imitation learning methods (TD3+BC), and trajectory optimization methods (DT). Our code is publicly available at https://github.com/guangheli12/DiffStitch
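As a rough illustration of the stitching idea described above (not the authors' implementation; see the linked repository for that), the Python sketch below assumes trajectories are stored as dictionaries of NumPy arrays and uses a hypothetical `linear_bridge` placeholder in place of the trained diffusion model that would generate the connecting transitions.

```python
import numpy as np

def linear_bridge(start_state, goal_state, horizon=8):
    """Placeholder for the diffusion-based transition generator:
    linearly interpolates states and fills actions/rewards with zeros,
    purely for illustration."""
    alphas = np.linspace(0.0, 1.0, horizon)[:, None]
    states = (1.0 - alphas) * start_state + alphas * goal_state
    return {"states": states,
            "actions": np.zeros((horizon, 1)),
            "rewards": np.zeros(horizon)}

def stitch_trajectories(low_traj, high_traj, generate_transition, dataset):
    """Connect the end of a low-reward trajectory to the start of a
    high-reward trajectory with generated transitions, then append the
    stitched trajectory to the (augmented) dataset."""
    bridge = generate_transition(start_state=low_traj["states"][-1],
                                 goal_state=high_traj["states"][0])
    stitched = {
        key: np.concatenate([low_traj[key], bridge[key], high_traj[key]])
        for key in ("states", "actions", "rewards")
    }
    dataset.append(stitched)
    return stitched

# Toy usage: stitch a low-reward trajectory to a high-reward one.
low  = {"states": np.random.randn(10, 3), "actions": np.zeros((10, 1)), "rewards": np.zeros(10)}
high = {"states": np.random.randn(12, 3), "actions": np.zeros((12, 1)), "rewards": np.ones(12)}
augmented = []
out = stitch_trajectories(low, high, linear_bridge, augmented)
print(out["states"].shape)  # (10 + 8 + 12, 3)
```

In the actual pipeline, the bridge would be produced by a conditional diffusion model rather than interpolation, and the augmented trajectories would then be fed to a downstream offline RL learner such as IQL, TD3+BC, or DT.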
Submission Number: 667