# `nix`: JAX utilities

This repo contains several small utilities that I have found useful when working with JAX. The main components at the moment are:
* Moving averages
* Natural gradient preconditioning
* Optax optimizer
* Tree utilities
* `pmap` utilities
* Linear algebra utilities
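To illustrate the general kind of helper meant by "moving averages" and "tree utilities", here is a minimal sketch of a bias-corrected exponential moving average over an arbitrary pytree. It relies only on `jax.tree_util.tree_map`. The names `ema_init`, `ema_update` and `ema_value` and the default `decay=0.9` are hypothetical illustrations, not this repo's actual API.

```python
import jax
import jax.numpy as jnp


def ema_init(params):
    """Hypothetical helper: zero-initialized running mean with the same pytree structure as `params`."""
    return {"mean": jax.tree_util.tree_map(jnp.zeros_like, params), "count": 0}


def ema_update(state, params, decay=0.9):
    """Hypothetical helper: mean <- decay * mean + (1 - decay) * params, leaf by leaf."""
    mean = jax.tree_util.tree_map(
        lambda m, p: decay * m + (1.0 - decay) * p, state["mean"], params
    )
    return {"mean": mean, "count": state["count"] + 1}


def ema_value(state, decay=0.9):
    """Hypothetical helper: undo the zero-initialization bias by dividing by 1 - decay**t."""
    correction = 1.0 - decay ** state["count"]
    return jax.tree_util.tree_map(lambda m: m / correction, state["mean"])


# Usage: average a small parameter pytree over a few steps.
params = {"w": jnp.ones(3), "b": jnp.array(2.0)}
state = ema_init(params)
for _ in range(5):
    state = ema_update(state, params)
averaged = ema_value(state)
```

Because the running mean starts at zero, the uncorrected average is biased towards zero for small step counts; dividing by `1 - decay**t` removes that bias, so averaging a constant pytree returns the constant itself.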
