A Correlation Analysis Approach to Finding Interpretable Latent Representations via Conditional Generative Models
Abstract: Supervised disentanglement, that is, learning interpretable nonlinear latent representations of a target data view informed by an auxiliary data view, is a central challenge in interpretable machine learning. We formulate this problem as a partially linear invertible canonical correlation analysis (PLiCCA). Specifically, given two data views, (i) complex data lying near a potentially high-dimensional manifold, and (ii) auxiliary high-dimensional multivariate data, PLiCCA learns latent variables for the complex view that are maximally correlated with sparse linear combinations of the auxiliary variables. In contrast to regression-based approaches to supervised disentanglement, the proposed method yields a latent embedding whose coordinates are explicitly ordered by their interpretability with respect to the auxiliary variables. We formalize the population PLiCCA problem and establish existence results. We then show a close theoretical connection between PLiCCA and conditional latent variable models, in particular conditional variational autoencoders and conditional normalizing flows, which enables practical estimation. We demonstrate our approach on brain imaging data, where PLiCCA is used to learn embeddings informed by auxiliary demographic, psychometric, and behavioral variables.
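The core idea of the abstract, latent coordinates that are maximally correlated with sparse linear combinations of auxiliary variables, can be illustrated with a small numerical sketch. This is not the paper's estimator (which the abstract says is based on conditional VAEs/normalizing flows); it is a hedged toy example in which one latent coordinate `z` is driven by a sparse combination of simulated auxiliary variables `Y`, and a soft-thresholded least-squares fit stands in for the sparse linear side of the correlation objective. All variable names and the thresholding rule are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n, q = 500, 8

# Auxiliary view: simulated stand-ins for demographic/behavioral variables.
Y = rng.normal(size=(n, q))

# One latent coordinate of the complex view, driven by a sparse
# combination of two auxiliary variables plus noise (toy ground truth).
z = 0.8 * Y[:, 0] - 0.5 * Y[:, 3] + 0.3 * rng.normal(size=n)

def corr(a, b):
    """Pearson correlation between two 1-D arrays."""
    a = a - a.mean()
    b = b - b.mean()
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

# Least-squares weights followed by soft-thresholding: a crude proxy
# for a sparse linear combination maximizing correlation with z.
w, *_ = np.linalg.lstsq(Y, z, rcond=None)
w_sparse = np.sign(w) * np.maximum(np.abs(w) - 0.1, 0.0)

print("nonzero weights:", int((w_sparse != 0).sum()))
print("correlation:", round(corr(z, Y @ w_sparse), 3))
```

In this toy setup the thresholding zeroes out the irrelevant auxiliary variables, so the recovered combination is both sparse and highly correlated with the latent coordinate, mirroring the interpretability ordering the abstract describes.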
Submission Number: 2149