Accelerating Federated Learning Through Attention on Local Model Updates

Published: 21 Oct 2022, Last Modified: 05 May 2023, FL-NeurIPS 2022 Poster, Readers: Everyone
Keywords: Federated Learning, Clustering, Communication Reduction, Knowledge Extraction, Attention
TL;DR: A novel attention-based method that extracts knowledge from clients (by clustering them) and uses this knowledge to reduce the communication cost.
Abstract: Federated learning is widely used for privacy-preserving training. It performs well when the client datasets are both balanced and IID. In real-world settings, however, client datasets are often non-IID and imbalanced, and they may also experience significant distribution shifts. These non-idealities can hinder the performance of federated learning. To address this challenge, we devise an attention-based mechanism that learns to attend to different clients in the context of a reference dataset. The reference dataset is a test dataset held at the central server and used to monitor the performance metric of the model under training. The innovation is that the attention mechanism captures the similarities and patterns of a batch of clients' model drifts (received by the central server in each communication round) in a low-dimensional latent space, similar to the way it captures the mutual relations within a batch of words (a sentence). To learn this attention layer, we devise an autoencoder whose inputs and outputs are the model drifts and whose bottleneck is the attention mechanism. The attention weights in the bottleneck are learned by training the attention-based autoencoder to reconstruct the model drift on the reference dataset from the batch of model drifts received from clients in each communication round. The learned attention weights effectively capture clusters and similarities amongst the clients' datasets. Empirical studies with MNIST, FashionMNIST, and CIFAR10 under a non-IID federated learning setup show that our attention-based autoencoder can identify clusters of similar clients. The central server can then use the clustering results to devise a better policy for choosing participating clients in each communication round, thereby reducing the communication rounds by up to 75% on MNIST and FashionMNIST, and 45% on CIFAR10 compared to FedAvg.
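As a rough illustration of the idea in the abstract, the sketch below computes attention weights over a batch of client model drifts relative to a reference drift, then forms a weighted reconstruction. It is not the paper's method: the abstract describes a learned attention-based autoencoder, whereas this sketch uses plain, unlearned scaled dot-product attention. The function names, the toy 2-dimensional drifts, and the use of the reference drift as the query are all assumptions made for illustration.

```python
import math


def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_weights(client_drifts, reference_drift):
    """Scaled dot-product attention of the reference drift over client drifts.

    Assumption: the reference drift acts as the query and each client drift
    as a key. The paper learns this mapping; here it is fixed.
    """
    d = len(reference_drift)
    scores = [dot(drift, reference_drift) / math.sqrt(d) for drift in client_drifts]
    return softmax(scores)


def reconstruct(client_drifts, weights):
    """Attention-weighted sum of client drifts (estimate of the reference drift)."""
    d = len(client_drifts[0])
    return [sum(w * drift[j] for w, drift in zip(weights, client_drifts))
            for j in range(d)]


# Toy example (hypothetical values): two clients drift in roughly the same
# direction as the reference, one drifts in the opposite direction.
client_drifts = [[1.0, 0.0], [0.9, 0.1], [-1.0, 0.0]]
reference_drift = [1.0, 0.0]

weights = attention_weights(client_drifts, reference_drift)
estimate = reconstruct(client_drifts, weights)
```

In this toy setup the two clients aligned with the reference drift receive larger weights than the opposing client, which hints at how such weights could group similar clients; how the paper turns learned weights into a client selection policy is not specified in the abstract.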
Is Student: Yes