JaxPruner: A Concise Library for Sparsity Research

Published: 20 Nov 2023, Last Modified: 05 Dec 2023. CPAL 2024 (Proceedings Track) Oral.
Keywords: jax, sparsity, pruning, quantization, sparse training, efficiency, library, software
TL;DR: This paper introduces JaxPruner, an open-source JAX-based pruning and sparse training library for machine learning research.
Abstract: This paper introduces JaxPruner, an open-source JAX-based pruning and sparse training library for machine learning research. JaxPruner aims to accelerate research on sparse neural networks by providing concise implementations of popular pruning and sparse training algorithms with minimal memory and latency overhead. Algorithms implemented in JaxPruner use a common API and work seamlessly with the popular optimization library Optax, which, in turn, enables easy integration with existing JAX-based libraries. We demonstrate this ease of integration by providing examples in four different codebases: Scenic, t5x, Dopamine, and FedJAX, and provide baseline experiments on popular benchmarks. JaxPruner is hosted at github.com/google-research/jaxpruner.
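
As a rough illustration of the Optax integration pattern the abstract describes (this is a minimal sketch, not JaxPruner's actual API; the `magnitude_pruning` wrapper and its `sparsity` parameter are hypothetical names introduced here), a pruner can be expressed as an `optax.GradientTransformation` that wraps an inner optimizer and masks low-magnitude weights after each update:

```python
# Sketch of a magnitude pruner as an Optax gradient transformation.
# Hypothetical example; JaxPruner's real API may differ.
import jax
import jax.numpy as jnp
import optax


def magnitude_pruning(inner: optax.GradientTransformation,
                      sparsity: float) -> optax.GradientTransformation:
  """Wraps `inner` so each parameter tensor is pruned by magnitude."""

  def init_fn(params):
    # Reuse the inner optimizer's state; the mask is recomputed each step.
    return inner.init(params)

  def update_fn(grads, state, params):
    updates, state = inner.update(grads, state, params)

    def mask_update(u, p):
      new_p = p + u  # Parameter value after the inner optimizer step.
      k = int(sparsity * new_p.size)  # Number of entries to zero out.
      if k == 0:
        return u
      # Threshold at the k-th smallest magnitude (ties may prune more).
      threshold = jnp.sort(jnp.abs(new_p).ravel())[k - 1]
      mask = jnp.abs(new_p) > threshold
      # Return the update that lands on the masked parameters.
      return new_p * mask - p

    updates = jax.tree_util.tree_map(mask_update, updates, params)
    return updates, state

  return optax.GradientTransformation(init_fn, update_fn)


# Usage: a drop-in replacement for the original optimizer.
tx = magnitude_pruning(optax.adam(1e-3), sparsity=0.8)
params = {'w': jnp.ones((4, 4))}
opt_state = tx.init(params)
```

Because the result is itself an `optax.GradientTransformation`, it drops into any JAX training loop that already uses Optax, which is the integration property the abstract highlights.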
Track Confirmation: Yes, I am submitting to the proceedings track.
Submission Number: 14