- TL;DR: When training a world model, you can encode useful assumptions into the loss by using training-time counterfactuals.
- Abstract: In sequential tasks, planning-based agents have several advantages over model-free agents, including sample efficiency and interpretability. Recurrent action-conditional latent dynamics models trained from pixel-level observations have been shown to predict future observations conditioned on agent actions accurately enough for planning in some pixel-based control tasks. Typically, models of this type are trained to reconstruct sequences of ground-truth observations, given ground-truth actions. However, an action-conditional model can also take input actions and states other than the ground truth and generate predictions of unobserved counterfactual states. Because counterfactual state predictions are generated by differentiable networks, relationships among counterfactual states can be included in a training objective. We explore counterfactual regularization terms that can be applied during training of action-conditional sequence models. We evaluate their effect on pixel-level prediction accuracy and model-based agent performance, and we show that counterfactual regularization improves the performance of model-based agents in test-time environments that differ from the training environment.
- Code: https://drive.google.com/file/d/1aD-5x28ZuDo62i8cfAD7GnliK_fDXhfD/view
- Keywords: Counterfactual, Model-Based Reinforcement Learning
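
As a rough illustration of the idea in the abstract, the sketch below adds a counterfactual term to an ordinary reconstruction loss for a toy action-conditional latent model. It is not the authors' implementation: the tiny tanh model, the function names (`init_params`, `step`, `decode`, `total_loss`), the weight `lam`, and the particular assumption being encoded (a designated no-op action should leave the latent state unchanged) are all hypothetical choices made for this example. Plain Python is used so the loss structure is easy to read; in practice the same computation would be written in an autodiff framework so the counterfactual term contributes gradients.

```python
import math
import random


def init_params(latent_dim, action_dim, obs_dim, seed=0):
    """Random weights for a toy model (hypothetical, for illustration only)."""
    rng = random.Random(seed)

    def mat(rows, cols):
        return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]

    return {
        "W_h": mat(latent_dim, latent_dim),  # latent -> latent
        "W_a": mat(latent_dim, action_dim),  # action -> latent
        "W_o": mat(obs_dim, latent_dim),     # latent -> observation
    }


def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def step(params, h, a):
    """One action-conditional transition: h' = tanh(W_h h + W_a a)."""
    pre = [u + v for u, v in zip(matvec(params["W_h"], h), matvec(params["W_a"], a))]
    return [math.tanh(p) for p in pre]


def decode(params, h):
    """Predict an observation from a latent state."""
    return matvec(params["W_o"], h)


def sq_err(x, y):
    return sum((xi - yi) ** 2 for xi, yi in zip(x, y))


def total_loss(params, h0, actions, observations, noop, lam=0.1):
    """Reconstruction loss plus a counterfactual regularizer.

    Assumption encoded here (hypothetical): had the agent taken `noop`
    instead of the recorded action, the latent state would not have changed.
    Returns (total, reconstruction, counterfactual), each averaged over time.
    """
    h = h0
    recon, cf = 0.0, 0.0
    for a, o in zip(actions, observations):
        # Counterfactual branch: same state, an action that was never observed.
        h_cf = step(params, h, noop)
        cf += sq_err(h_cf, h)
        # Factual branch: ground-truth action, compared with the ground-truth observation.
        h = step(params, h, a)
        recon += sq_err(decode(params, h), o)
    steps = len(actions)
    r, c = recon / steps, cf / steps
    return r + lam * c, r, c


# Tiny usage example with made-up data.
params = init_params(latent_dim=3, action_dim=2, obs_dim=4, seed=0)
actions = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
observations = [[0.1, 0.2, 0.3, 0.4]] * 3
total, recon, cf = total_loss(params, [0.0, 0.0, 0.0], actions, observations, noop=[0.0, 0.0])
print(total, recon, cf)
```

The counterfactual branch never touches the recorded data: it only queries the model at an input it did not see, which is what lets an assumption about the environment be written directly into the objective.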