Task-Aware Virtual Training: Enhancing Generalization in Meta-Reinforcement Learning for Out-of-Distribution Tasks

Published: 01 May 2025, Last Modified: 18 Jun 2025, ICML 2025 poster, CC BY 4.0
TL;DR: To improve generalization to out-of-distribution tasks in meta-RL, our Task-Aware Virtual Training (TAVT) algorithm proposes a metric-based method for learning task latents, a task-preserving method for learning virtual tasks, and a state regularization method.
Abstract: Meta reinforcement learning (meta-RL) aims to develop policies that generalize to unseen tasks sampled from a task distribution. While context-based meta-RL methods improve task representation using task latents, they often struggle with out-of-distribution (OOD) tasks. To address this, we propose Task-Aware Virtual Training (TAVT), a novel algorithm that accurately captures task characteristics for both training and OOD scenarios using metric-based representation learning. Our method successfully preserves task characteristics in virtual tasks and employs a state regularization technique to mitigate overestimation errors in state-varying environments. Numerical results demonstrate that TAVT significantly enhances generalization to OOD tasks across various MuJoCo and MetaWorld environments. Our code is available at https://github.com/JM-Kim-94/tavt.git.
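The abstract mentions metric-based representation learning of task latents. As a rough illustration only, the sketch below shows one generic form such an objective can take: penalize the mismatch between the distance of two task latents and a distance between the tasks themselves. The task distance used here (mean absolute reward gap on shared state-action pairs) and all function names (task_distance, metric_loss) are hypothetical stand-ins, not TAVT's actual metric or loss; see the paper and the linked repository for the real formulation.

```python
import math


def task_distance(rewards_a, rewards_b):
    """Hypothetical task dissimilarity: mean absolute reward gap
    measured on the same (state, action) pairs under two tasks."""
    if len(rewards_a) != len(rewards_b) or not rewards_a:
        raise ValueError("reward lists must be non-empty and aligned")
    return sum(abs(ra - rb) for ra, rb in zip(rewards_a, rewards_b)) / len(rewards_a)


def metric_loss(latents, rewards):
    """Mean squared mismatch between Euclidean latent distances and
    task distances, averaged over all unordered task pairs."""
    n = len(latents)
    total, pairs = 0.0, 0
    for i in range(n):
        for j in range(i + 1, n):
            gap = math.dist(latents[i], latents[j]) - task_distance(rewards[i], rewards[j])
            total += gap ** 2
            pairs += 1
    return total / pairs


# Latents whose spacing matches the reward gaps give zero loss.
print(metric_loss([(0.0,), (1.0,), (3.0,)], [[0.0, 0.0], [1.0, 1.0], [3.0, 3.0]]))
```

Minimizing a loss of this kind pushes tasks that behave differently apart in latent space, which is the property that lets a latent describe tasks outside the training set in a meaningful way.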
Lay Summary: Can an agent that has only learned to move forward move in a different direction? Agents deployed in the real world should be able to adapt to changing environments and varied tasks, and adapting to tasks they have never seen is especially challenging. To address this, the TAVT algorithm proposes two main methods. First, the agent learns the training task set using a newly defined task measure (the Task-Aware structure). Second, it creates virtual tasks from the learned task set and uses them to train the agent (the Virtual Training structure). Thanks to this virtual training scheme, an agent trained with TAVT can adapt to unseen tasks it has never actually experienced.
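To make the idea of creating virtual tasks from a learned task set concrete, here is a minimal sketch under a loudly stated assumption: virtual task latents are formed as random convex combinations of pairs of training task latents. This interpolation scheme is a common, simple choice and is not necessarily how TAVT generates virtual tasks; the function names make_virtual_latent and sample_virtual_latents are hypothetical.

```python
import random


def make_virtual_latent(z_a, z_b, alpha):
    """Convex combination alpha * z_a + (1 - alpha) * z_b of two task latents."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    return tuple(alpha * a + (1.0 - alpha) * b for a, b in zip(z_a, z_b))


def sample_virtual_latents(train_latents, k, seed=0):
    """Draw k virtual latents, each mixing a randomly chosen pair of
    distinct training latents with a uniform random weight."""
    rng = random.Random(seed)
    virtual = []
    for _ in range(k):
        z_a, z_b = rng.sample(train_latents, 2)
        virtual.append(make_virtual_latent(z_a, z_b, rng.random()))
    return virtual


train = [(0.0, 0.0), (1.0, 2.0), (2.0, 0.0)]
print(make_virtual_latent(train[0], (2.0, 4.0), 0.5))  # midpoint of the two latents
print(sample_virtual_latents(train, k=3))
```

Interpolation only covers the region between training latents; a real virtual-training method additionally has to make sure that the generated tasks keep meaningful task characteristics, which is the problem the abstract says TAVT's task-preserving learning addresses.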
Link To Code: https://github.com/JM-Kim-94/tavt.git
Primary Area: Reinforcement Learning->Deep RL
Keywords: Reinforcement Learning, Deep Reinforcement Learning, Meta Reinforcement Learning, Representation Learning, Generalization, Machine Learning, Out-Of-Distribution Tasks
Submission Number: 2088