Abstract: The versatility of the self-attention mechanism has earned transformers great success across almost all data modalities, but this comes at the cost of quadratic complexity and difficult training. Efficient transformers, on the other hand, often rely on clever, modality-specific constructions to overcome the quadratic complexity, which greatly hinders their application across data modalities, one of the pillars of contemporary foundation modeling. In this paper, we lay the groundwork for efficient foundation modeling by proposing \textbf{SAMSA} - SAMpling-Self-Attention, a context-aware, linear-complexity self-attention mechanism that works well on multiple data modalities. Our mechanism is built on a differentiable sampling-without-replacement method that we introduce, which enables the self-attention module to attend to the most important set of tokens, where importance is learned from data. Moreover, since differentiability is not needed at inference, the sparse formulation of our method incurs little time overhead, further lowering computational cost. In short, SAMSA achieves competitive or even SOTA results on many benchmarks while offering faster inference than other, highly specialized models. Compared with full self-attention, real inference time decreases significantly while performance ranges from negligible degradation to outperformance. We release our source code in the supplementary materials.
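To make the high-level idea in the abstract concrete, the sketch below shows one generic way a sampling-based self-attention layer can reach linear cost: a learned scoring head ranks tokens, a hard top-k selection keeps only the most important keys/values, and a sigmoid gate on the selected scores lets gradients reach the scoring head during training. This is a minimal, hypothetical illustration under assumed design choices (top-k selection with score gating, functions `sampled_attention`, `w_score`, `w_qkv` are made up for this example); it is not the exact SAMSA sampling-without-replacement construction described in the paper.

```python
import torch

def sampled_attention(x, w_score, w_qkv, k=64):
    """Hypothetical sketch of sampling-based self-attention (not the exact SAMSA method).

    x:       (batch, n, d) token embeddings
    w_score: nn.Linear(d, 1) scoring token importance
    w_qkv:   nn.Linear(d, 3 * d) producing queries/keys/values
    k:       number of tokens kept per sequence
    """
    b, n, d = x.shape
    q, key, v = w_qkv(x).chunk(3, dim=-1)

    # Importance scores; hard top-k selection in the forward pass.
    scores = w_score(x).squeeze(-1)                      # (b, n)
    topk_val, topk_idx = scores.topk(min(k, n), dim=-1)  # (b, k)

    # Gating by the sigmoid of the selected scores lets gradients reach the
    # scoring head even though the hard top-k selection is non-differentiable.
    gate = torch.sigmoid(topk_val).unsqueeze(-1)         # (b, k, 1)
    idx = topk_idx.unsqueeze(-1).expand(-1, -1, d)
    key_s = key.gather(1, idx) * gate                    # (b, k, d)
    v_s = v.gather(1, idx) * gate                        # (b, k, d)

    # All n queries attend only to the k sampled keys/values: O(n * k) cost.
    attn = torch.softmax(q @ key_s.transpose(-2, -1) / d ** 0.5, dim=-1)
    return attn @ v_s                                    # (b, n, d)
```

At inference time the gating and gradient bookkeeping can be dropped and only the hard selection kept, which is where the low sparse-inference overhead claimed in the abstract would come from in a design like this.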
Submission Length: Regular submission (no more than 12 pages of main content)
Previous TMLR Submission Url: https://openreview.net/forum?id=m4eD6HDGGX
Changes Since Last Submission: **Compared to the previous submission, our work includes several significant changes:**
- We have streamlined the claims and methods, concentrating exclusively on the sampling algorithm and graph modeling. The claims about optimization have been removed, as we now focus solely on the sampling side of our contribution. This refinement should improve the readability of our manuscript.
- We believe our model improves on the vanilla transformer on the tasks we experimented with, in both efficiency and performance. For specific data modalities, it even outperforms many SOTA graph models on two tasks while being much faster. Moreover, this improvement comes not from hyperparameter tuning but from better gradient estimation and a better integration mechanism.
**The flow of our work is as follows:**
1. **Differentiable Sampling in Previous Literature**: We review how differentiable sampling has been approached in prior research.
2. **Our Construction of Differentiable Sampling**: We introduce our own method for constructing differentiable sampling.
3. **Applying Sampling in Transformers**: We demonstrate how our sampling approach can be integrated into transformer architectures.
4. **Graph Data Modality Add-On**: We provide a small extension to cover the graph data modality by including an embedding layer that enables any self-attention network to understand graph structures (see the illustrative sketch after this list).
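As a hypothetical illustration of item 4, one common way to let a plain self-attention network see graph structure is to add structural encodings, for example Laplacian eigenvector positional encodings, to the node features before the transformer layers. The function name `laplacian_positional_encoding` is invented for this sketch, and the paper's actual embedding layer may be constructed differently.

```python
import torch

def laplacian_positional_encoding(edge_index, num_nodes, k=8):
    """Hypothetical structural embedding: Laplacian eigenvector positional
    encodings that expose graph structure to a standard self-attention stack.

    edge_index: (2, num_edges) edge list (treated as undirected)
    returns:    (num_nodes, k) positional encodings
    """
    # Dense adjacency and symmetric normalized Laplacian L = I - D^{-1/2} A D^{-1/2}.
    adj = torch.zeros(num_nodes, num_nodes)
    adj[edge_index[0], edge_index[1]] = 1.0
    adj = torch.maximum(adj, adj.t())                     # symmetrize
    deg = adj.sum(dim=-1).clamp(min=1.0)
    d_inv_sqrt = deg.pow(-0.5)
    lap = torch.eye(num_nodes) - d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]

    # Eigenvectors of the smallest non-trivial eigenvalues encode coarse structure.
    eigval, eigvec = torch.linalg.eigh(lap)
    return eigvec[:, 1:k + 1]                             # skip the constant eigenvector

# Usage sketch: concatenate (or add via a learned projection) to node features,
# then feed the result into any standard self-attention network.
# x = torch.cat([node_feats, laplacian_positional_encoding(edge_index, n)], dim=-1)
```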
Assigned Action Editor: ~Steven_Stenberg_Hansen1
Submission Number: 3181