Object Representations as Fixed Points: Training Iterative Inference Algorithms with Implicit Differentiation
Keywords: implicit differentiation, object-centric learning, iterative amortized inference, symmetric generative models
TL;DR: We propose implicit differentiation for training the iterative amortized inference procedures of symmetric generative models, such as those used to learn object representations.
Abstract: Deep generative models, particularly those that aim to factorize the observations into discrete entities (such as objects), must often use iterative inference procedures that break symmetries among equally plausible explanations for the data. Such inference procedures include variants of the expectation-maximization algorithm and structurally resemble clustering algorithms in a latent space. However, combining such methods with deep neural networks necessitates differentiating through the inference process, which can make optimization exceptionally challenging. In this work, we observe that such iterative inference methods can be made differentiable by means of the implicit function theorem, and we develop an implicit differentiation approach that improves the stability and tractability of training such models by decoupling the forward and backward passes. This connection enables us to apply recent advances in optimizing implicit layers not only to improve the stability and optimization of the slot attention module in SLATE, a state-of-the-art method for learning entity representations, but also to do so with constant space and time complexity in backpropagation and only one additional line of code.
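To make the idea in the abstract concrete, the following is a minimal sketch, not the paper's implementation. It assumes a hypothetical scalar update `f(z, theta, x) = tanh(theta * z + x)` standing in for one refinement step of an iterative inference procedure; the names `f`, `solve_fixed_point`, `implicit_grad`, and `first_order_grad` are illustrative only. If the iteration converges to a fixed point z* = f(z*, theta, x), the implicit function theorem gives dz*/dtheta = (∂f/∂theta) / (1 − ∂f/∂z), evaluated at z*, so the backward pass needs only the fixed point itself and none of the intermediate iterates. Dropping the 1 / (1 − ∂f/∂z) factor yields a first-order approximation, which amounts to differentiating through a single update applied at the fixed point; whether this matches the "one additional line of code" mentioned in the abstract is an assumption here.

```python
import math


def f(z, theta, x):
    """Hypothetical contractive update step (stand-in for one inference iteration)."""
    return math.tanh(theta * z + x)


def solve_fixed_point(theta, x, z0=0.0, tol=1e-12, max_iter=1000):
    """Forward pass: iterate z <- f(z, theta, x) until successive iterates agree."""
    z = z0
    for _ in range(max_iter):
        z_next = f(z, theta, x)
        if abs(z_next - z) < tol:
            return z_next
        z = z_next
    return z


def _partials(z, theta, x):
    """Partial derivatives of f with respect to z and theta at the point (z, theta, x)."""
    s = 1.0 - math.tanh(theta * z + x) ** 2  # derivative of tanh at its argument
    return s * theta, s * z                  # (df/dz, df/dtheta)


def implicit_grad(theta, x):
    """Exact dz*/dtheta from the implicit function theorem; uses only the fixed point."""
    z_star = solve_fixed_point(theta, x)
    df_dz, df_dtheta = _partials(z_star, theta, x)
    return df_dtheta / (1.0 - df_dz)


def first_order_grad(theta, x):
    """Approximation that ignores df/dz: the gradient of one update step at z*."""
    z_star = solve_fixed_point(theta, x)
    _, df_dtheta = _partials(z_star, theta, x)
    return df_dtheta
```

In this sketch the memory needed for the backward pass does not grow with the number of forward iterations, because only `z_star` is used. The example assumes the update is a contraction (here |∂f/∂z| ≤ |theta| < 1), so that the iteration converges and the denominator stays away from zero; real inference updates may not satisfy this.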
Community Implementations: [3 code implementations](https://www.catalyzex.com/paper/object-representations-as-fixed-points/code)