Scalable Multitask Learning Using Gradient-based Estimation of Task Affinity

Published: 01 Jan 2024 · Last Modified: 13 May 2025 · KDD 2024 · CC BY-SA 4.0
Abstract: Multitask learning is a widely used paradigm for training models on diverse tasks, with applications ranging from graph neural networks to language model fine-tuning. Since tasks may interfere with one another, a key notion for modeling their relationships is task affinity. This includes pairwise task affinity, computed between pairs of tasks, and higher-order affinity, computed among subsets of tasks. Naively computing either requires repeatedly training on data pooled from various task combinations, which is computationally intensive. We present a new algorithm, Grad-TAG, that estimates task affinities without this repeated training.

The key idea of Grad-TAG is to train a "base" model for all tasks and then use a linearization technique to estimate the loss of any other model trained on a specific task combination. The linearization computes a gradient-based first-order approximation of the loss, using low-dimensional projections of gradients as features in a logistic regression trained to predict the labels of the specific task combination. We show theoretically that the linearized model provably approximates the loss whenever the gradient-based approximation is accurate, and we verify this empirically on several large models. Then, given the estimated task affinity matrix, we design a semi-definite program for clustering that groups similar tasks by maximizing the average density of the clusters.

We evaluate Grad-TAG's performance across seven datasets, including multi-label classification on graphs and instruction fine-tuning of language models. Our task affinity estimates are within 2.7% distance of the true affinities while requiring only 3% of the FLOPs of full training. On our largest graph, with 21M edges and 500 labeling tasks, our algorithm delivers estimates accurate to within 5% of the true affinities while using only 112.3 GPU hours. These results show that Grad-TAG achieves excellent performance and runtime tradeoffs compared to existing approaches.
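For intuition, the following is a minimal, hypothetical sketch of the linearization step described above. The function and variable names (`project_gradients`, `estimate_subset_loss`, the projection dimension) are illustrative assumptions, not the authors' code: per-example gradients of the base model are randomly projected to a low dimension and used as features for a logistic regression, whose cross-entropy serves as a surrogate for the loss of a model trained on a given task subset.

```python
# Hedged sketch of gradient-based loss estimation for a task subset.
# Assumes a trained "base" model whose per-example gradients are available
# as rows of a matrix; names and dimensions here are illustrative only.
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)

def project_gradients(per_example_grads, dim=128):
    """Random (Johnson-Lindenstrauss-style) projection of high-dimensional
    per-example gradients down to `dim` features."""
    n, p = per_example_grads.shape
    proj = rng.normal(size=(p, dim)) / np.sqrt(dim)
    return per_example_grads @ proj

def estimate_subset_loss(per_example_grads, labels, subset_mask, dim=128):
    """Approximate the loss of a model trained on one task combination by
    fitting a logistic regression on projected gradient features of the
    examples in that subset (a first-order, linearized surrogate)."""
    feats = project_gradients(per_example_grads, dim)
    X, y = feats[subset_mask], labels[subset_mask]
    clf = LogisticRegression(max_iter=1000).fit(X, y)
    probs = clf.predict_proba(X)[np.arange(len(y)), y]
    return -np.mean(np.log(np.clip(probs, 1e-12, None)))  # cross-entropy

# Toy usage: 200 examples, 1,000-dimensional gradients, two tasks.
grads = rng.normal(size=(200, 1000))
labels = rng.integers(0, 2, size=200)
task_id = rng.integers(0, 2, size=200)
loss_task0 = estimate_subset_loss(grads, labels, task_id == 0)
```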
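Similarly, a hedged sketch of the clustering step, assuming a standard density-maximizing SDP relaxation over the affinity matrix with spectral rounding; the constraint set and rounding procedure in the paper's program may differ:

```python
# Hedged sketch: SDP relaxation for grouping tasks from an affinity matrix.
# The exact objective, constraints, and rounding are assumptions.
import numpy as np
import cvxpy as cp
from sklearn.cluster import SpectralClustering

def sdp_cluster(affinity, num_clusters):
    """Solve a density-maximizing SDP over the task affinity matrix,
    then round the relaxed solution with spectral clustering."""
    n = affinity.shape[0]
    X = cp.Variable((n, n), PSD=True)
    constraints = [
        X >= 0,                       # entrywise nonnegative
        cp.sum(X, axis=1) == 1,       # each row sums to one
        cp.trace(X) == num_clusters,  # cluster budget
    ]
    prob = cp.Problem(cp.Maximize(cp.trace(affinity @ X)), constraints)
    prob.solve()
    # Round: treat the relaxed X as a similarity matrix between tasks.
    sim = np.maximum(X.value, 0)
    sim = (sim + sim.T) / 2
    labels = SpectralClustering(n_clusters=num_clusters,
                                affinity="precomputed").fit_predict(sim)
    return labels

# Toy usage on a random symmetric affinity matrix for 10 tasks.
rng = np.random.default_rng(1)
A = rng.random((10, 10))
A = (A + A.T) / 2
print(sdp_cluster(A, num_clusters=3))
```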