Efficient Multi-Scale Deformable Attention on GPUs

25 Apr 2026 (modified: 08 May 2026) · Under review for TMLR · CC BY 4.0
Abstract: Multi-scale deformable attention (MSDA) is a core operator in DETR-family vision transformers whose scattered bilinear sampling pattern defeats the tile-based strategies on which FlashAttention-style kernels depend. We present a diagnostic study of GPU kernel optimization for MSDA on NVIDIA A100 (SM 8.0) and H100 (SM 9.0), identifying two failure modes of conventional heuristics and a root cause that is both hardware- and compiler-gated. Dispatch-order reordering does not pay off: seven query orders (linear, Morton Z-order, random, scanline, Hilbert, centroid, and a clustering-and-packing analogue) yield forward latencies within ±2% of one another at K=4, L=4, because L2 locality is set by the query-block kernel's tiling rather than by the dispatch order. Throughput proxies mislead: an 85%-occupancy point-parallel tiling delivers only 5.1% of A100 peak bandwidth, while a 17%-occupancy query-block tiling delivers 36% and runs 7.4× faster. The backward-pass bottleneck is atomic contention on scattered gradients: at BF16, the backward kernel attains 2.4% of A100 peak bandwidth versus 21.3% on H100. The gap is hardware- and compiler-gated: Ampere has no native BF16 atomic instruction (forcing a 32-bit compare-and-swap emulation), and on H100 the standard CUDA atomic still lowers to that emulation, while a relaxed-ordering variant reaches Hopper's native reduction primitive; an FP32-accumulator variant closes the A100 gap entirely. The resulting backends deliver 2.4–14× forward speedups and up to 88% peak-VRAM reduction over the reference implementation at numerical parity.
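
To make the root-cause claim concrete, the following CUDA sketch (our own illustration, not code from the paper; the function names bf16_atomic_add_cas and fp32_accum_add are hypothetical) contrasts the 32-bit compare-and-swap emulation that a BF16 atomic add falls back to when no native instruction is available with the FP32-accumulator path, which uses the natively supported atomicAdd on float.

#include <cuda_bf16.h>
#include <cstdint>

// Emulated BF16 atomic add: retry a 32-bit compare-and-swap on the aligned
// word that contains the 16-bit element until no other thread interferes.
__device__ void bf16_atomic_add_cas(__nv_bfloat16* addr, float val) {
    uintptr_t raw = reinterpret_cast<uintptr_t>(addr);
    unsigned int* word = reinterpret_cast<unsigned int*>(raw & ~uintptr_t(3));
    const bool high_half = (raw & 2) != 0;  // which 16-bit lane of the word we own

    unsigned int old_word = *word, assumed;
    do {
        assumed = old_word;
        // Unpack the current BF16 value, accumulate in FP32, and re-pack.
        unsigned short cur16 = static_cast<unsigned short>(
            high_half ? (assumed >> 16) : (assumed & 0xffffu));
        float cur = __bfloat162float(__ushort_as_bfloat16(cur16));
        unsigned short new16 = __bfloat16_as_ushort(__float2bfloat16(cur + val));
        unsigned int new_word = high_half
            ? ((assumed & 0x0000ffffu) | (static_cast<unsigned int>(new16) << 16))
            : ((assumed & 0xffff0000u) | new16);
        old_word = atomicCAS(word, assumed, new_word);
    } while (old_word != assumed);  // another thread won the race; try again
}

// FP32-accumulator alternative: scatter gradients into a float buffer with a
// native atomicAdd (a single hardware instruction on both SM 8.0 and SM 9.0),
// then convert the buffer back to BF16 once after the kernel finishes.
__device__ void fp32_accum_add(float* grad_accum, float val) {
    atomicAdd(grad_accum, val);
}

The retry loop is where scattered gradients hurt: every collision forces the losing threads to redo the read-modify-write, whereas the FP32 path resolves each collision in a single hardware reduction.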
Submission Type: Regular submission (no more than 12 pages of main content)
Assigned Action Editor: ~Liang-Chieh_Chen1
Submission Number: 8620