ScaLA: Speeding-Up Fine-tuning of Pre-trained Transformer Networks via Efficient and Scalable Adversarial Perturbation
Keywords: Efficient Training Methods, Large Batch Optimization, Transformer Networks, BERT
Abstract: The size of transformer networks is growing at an unprecedented rate: it has increased by three orders of magnitude in recent years and is approaching a trillion parameters. To train models of increasing size, researchers and practitioners have employed large-batch optimization to leverage massive distributed deep learning systems and resources. However, increasing the batch size changes the training dynamics, often leading to a generalization gap and training instability that require extensive hyperparameter tuning to maintain the same level of accuracy. In this paper, we explore the steepness of the loss landscape of large-batch optimization and find that it tends to be highly complex and irregular, which makes generalization difficult. To address this challenge, we propose ScaLA, a scalable and robust method for large-batch optimization of transformer networks via adversarial perturbation. In particular, we take a sequential game-theoretic approach to make large-batch optimization robust to adversarial perturbation, which helps smooth the loss landscape and improve generalization. Moreover, we perform several optimizations to reduce the computational cost of adversarial perturbation, improving its performance and scalability in distributed training environments.
We provide a theoretical convergence rate analysis for ScaLA using techniques for analyzing non-convex saddle-point problems. Finally, we perform an extensive evaluation of our method using BERT and RoBERTa on the GLUE datasets. Our results show that our method attains up to 18$\times$ fine-tuning speedup on 2 DGX-2 nodes, while achieving comparable and sometimes higher accuracy than state-of-the-art large-batch optimization methods. When using the same amount of hardware resources, ScaLA is 2.7--9.8$\times$ faster than the baselines.
One-sentence Summary: ScaLA is a scalable and efficient method for large-batch optimization of pre-trained transformer networks via adversarial perturbation.
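The abstract frames robust large-batch training as a min-max problem, roughly $\min_\theta \max_{\|\delta\| \le \epsilon} L(\theta; x + \delta)$: an inner step finds a loss-increasing perturbation of the inputs, and an outer step updates the model on the perturbed batch. The sketch below illustrates only that generic pattern on a toy logistic-regression model in NumPy. It is not ScaLA's algorithm: the model, the per-example normalized gradient ascent projected onto an L2 ball, and all names and hyperparameters (`radius`, `ascent_lr`, `ascent_steps`, `adversarial_train_step`) are hypothetical choices for illustration, and neither ScaLA's sequential game formulation nor its cost-reducing optimizations are reproduced.

```python
import numpy as np


def loss_and_grads(w, X, y):
    """Mean logistic loss for labels y in {0, 1}; returns (loss, dL/dw, dL/dX)."""
    z = X @ w
    # log(1 + e^z) - y*z is a numerically stable form of the logistic loss.
    loss = np.mean(np.logaddexp(0.0, z) - y * z)
    sigmoid = 0.5 * (1.0 + np.tanh(0.5 * z))  # overflow-free sigmoid
    r = (sigmoid - y) / len(y)  # dL/dz for each example
    grad_w = X.T @ r
    grad_X = np.outer(r, w)  # row i is dL/dX_i = r_i * w
    return loss, grad_w, grad_X


def project_rows(delta, radius):
    """Project each row of delta onto the L2 ball of the given radius."""
    norms = np.linalg.norm(delta, axis=1, keepdims=True)
    return delta * np.minimum(1.0, radius / np.maximum(norms, 1e-12))


def inner_max(w, X, y, radius, ascent_lr, ascent_steps):
    """Approximate max over ||delta_i|| <= radius of L(w, X + delta) by normalized ascent."""
    delta = np.zeros_like(X)
    for _ in range(ascent_steps):
        _, _, g = loss_and_grads(w, X + delta, y)
        g_norm = np.linalg.norm(g, axis=1, keepdims=True)
        # Normalizing per example keeps the step size independent of the batch size.
        delta = project_rows(delta + ascent_lr * g / np.maximum(g_norm, 1e-12), radius)
    return delta


def adversarial_train_step(w, X, y, lr=1.0, radius=0.1, ascent_lr=0.1, ascent_steps=1):
    """One hypothetical min-max step: perturb the batch, then descend on w."""
    delta = inner_max(w, X, y, radius, ascent_lr, ascent_steps)
    # delta is held fixed when differentiating with respect to w.
    adv_loss, grad_w, _ = loss_and_grads(w, X + delta, y)
    return w - lr * grad_w, adv_loss, delta


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n, d = 4096, 16  # one large batch of synthetic, linearly separable data
    true_w = rng.normal(size=d)
    X = rng.normal(size=(n, d))
    y = (X @ true_w > 0).astype(float)

    w = np.zeros(d)
    for _ in range(50):
        w, adv_loss, _ = adversarial_train_step(w, X, y)
    clean_loss, _, _ = loss_and_grads(w, X, y)
    print(f"adversarial loss {adv_loss:.4f}, clean loss {clean_loss:.4f}")
```

Each extra inner ascent step costs another forward and backward pass over the batch, so raising `ascent_steps` trades compute for a stronger adversary; this is the kind of overhead the abstract says ScaLA's optimizations aim to reduce.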