Keywords: Transformers, Sets, Multisets, Explainable AI
Abstract: The advent of the set transformer (ST) brought about a new method of permutation-equivariant modeling by leveraging cross-element interactions. However, ST is still subject to the fundamental challenge of transformers: scaling efficiently to large input sizes. Mini-batch consistent (MBC) methods were developed to address this problem by maintaining permutation equivariance while alleviating context fragmentation when processing partitioned sets. However, current MBC methods limit expressiveness and render the models incapable of producing element-wise contextualized representations and attention scores for prediction explainability. The choice between ST and MBC methods is therefore a tradeoff between expressiveness and large-set processing. To reconcile this tradeoff, we propose the Universal Set Transformer (UST), a generalization of ST that is mini-batch consistent without sacrificing expressiveness. Additionally, we introduce multiset attention, which leverages the MBC property to significantly reduce the computational cost of processing multisets while maintaining mathematical equivalence with standard multi-head attention. We show that UST is competitive with ST in performance while using less memory, and that it outperforms other MBC methods on various benchmark tasks. Finally, we show that UST is capable of producing both whole-set and element-wise representations, and we demonstrate prediction explainability via attention scores.
Primary Area: unsupervised, self-supervised, semi-supervised, and supervised representation learning
Submission Number: 23658