Improving Unsupervised Hierarchical Representation with Reinforcement Learning

Published: 19 Jun 2024, Last Modified: 30 Sept 2024, Venue: CVPR 2024, Readers: Everyone, License: CC BY 4.0
Abstract: Learning representations that capture a fundamental understanding of the world is a key challenge in machine learning. The hierarchical structure of explanatory factors hidden in data is one such general representation, and it could potentially be obtained with a hierarchical VAE that learns a hierarchy of increasingly abstract latent representations. However, training a hierarchical VAE commonly suffers from the "$\textit{posterior collapse}$" issue, in which information about the input data fails to propagate to the higher-level latent variables, resulting in a poor hierarchical representation. To address this issue, we first analyze the shortcomings of existing methods for mitigating $\textit{posterior collapse}$ from an information-theoretic perspective, then highlight the need for a regularization that explicitly propagates data information to higher-level latent variables while maintaining the dependency between levels. This naturally leads to formulating the inference of the hierarchical latent representation as a sequential decision process, which can benefit from reinforcement learning (RL) methodologies. To align the RL objective with the regularization, we propose a $\textit{skip-generative path}$ that yields a reward measuring the information content of an inferred latent representation; the Q-value function built on this reward then has an optimization direction consistent with the regularization. Finally, we employ policy gradient, a standard RL method, to train the hierarchical VAE without introducing a gradient estimator. Experimental results support our analysis and show that the proposed method effectively mitigates the $\textit{posterior collapse}$ issue, learns an informative hierarchy, acquires explainable latent representations, and significantly outperforms other hierarchical VAE-based methods on downstream tasks.
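As a point of reference for the formulation summarized in the abstract, the sketch below writes out the standard objects it refers to. The notation is assumed for illustration and is not taken from the paper: an $L$-level latent hierarchy $z_1, \dots, z_L$, a bottom-up inference model $q_\phi$, and a top-down generative model $p_\theta$. The mapping of states, actions, and policy onto the inference steps is one plausible reading of "sequential decision process". The per-level reward $r_l$ is left abstract because the abstract does not define it, and the undiscounted finite-horizon form of $Q$ is likewise an assumption.

```latex
% Assumed notation (not from the paper): z_0 := x, levels l = 1, ..., L.

% Generic hierarchical VAE: top-down generative model, bottom-up inference model.
\begin{align}
  p_\theta(x, z_{1:L}) &= p_\theta(x \mid z_1)\,
      \Big[\prod_{l=1}^{L-1} p_\theta(z_l \mid z_{l+1})\Big]\, p(z_L), \\
  q_\phi(z_{1:L} \mid x) &= \prod_{l=1}^{L} q_\phi(z_l \mid z_{l-1}), \\
  \mathrm{ELBO}(x) &= \mathbb{E}_{q_\phi(z_{1:L} \mid x)}
      \big[\log p_\theta(x, z_{1:L}) - \log q_\phi(z_{1:L} \mid x)\big].
\end{align}

% Inference viewed as a sequential decision process (assumed mapping):
%   state  s_l = (x, z_{<l}),  action  a_l = z_l,
%   policy \pi_\phi(a_l \mid s_l) = q_\phi(z_l \mid z_{l-1}).
% Standard policy-gradient (score-function) identity with a per-level reward r_l:
\begin{align}
  Q(s_l, a_l) &= \mathbb{E}_{\pi_\phi}\Big[\sum_{k=l}^{L} r_k \,\Big|\, s_l, a_l\Big], \\
  \nabla_\phi J(\phi) &= \mathbb{E}_{\pi_\phi}\Big[\sum_{l=1}^{L}
      \nabla_\phi \log \pi_\phi(a_l \mid s_l)\, Q(s_l, a_l)\Big].
\end{align}
```

Under this reading, a reward that scores how much information about $x$ an inferred $z_l$ carries would push each level's inference step toward informative latents, which is the role the abstract assigns to the skip-generative path.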