Keywords: library, visualization, model surgery, interventions, JAX
TL;DR: I present Penzai, a JAX library for modifying pretrained models by representing them as modular data structures, and Treescope, an interactive IPython pretty printer designed to visualize them.
Abstract: Much of today's machine learning research involves interpreting, modifying or visualizing models after they are trained.
I present *Penzai*, a neural network library designed to simplify model manipulation by representing models as simple data structures, and *Treescope*, an interactive pretty-printer and array visualizer that can visualize both model inputs/outputs and the models themselves.
Penzai models are structured directly as compositions of modular operations, so the forward pass is visible in the structure of the model object itself, and they use named axes to keep each operation semantically meaningful. Users can insert new logic and extract intermediate values by transforming the model object with Penzai's tree-editing selector system, and get immediate feedback by visualizing the modified model with Treescope.
I describe the motivation and main features of Penzai and Treescope, and discuss how treating the model as data enables a variety of analyses and interventions to be implemented as data-structure transformations, without requiring model designers to add explicit hooks.
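As a rough illustration of the "model as data" idea described above, the following sketch uses only the Python standard library. It is not Penzai's API: the classes `Scale`, `Shift`, `Capture`, and `Sequential` and the helper `insert_after` are hypothetical names chosen for this example. It shows how a model stored as a plain, immutable data structure can have a value-capturing layer spliced in by a structural transformation, without the model author providing any hook.

```python
from dataclasses import dataclass, replace
from typing import Callable, List, Tuple


@dataclass(frozen=True)
class Scale:
    """Hypothetical layer: multiplies every element by a constant."""
    factor: float

    def __call__(self, x: List[float]) -> List[float]:
        return [self.factor * v for v in x]


@dataclass(frozen=True)
class Shift:
    """Hypothetical layer: adds a constant to every element."""
    offset: float

    def __call__(self, x: List[float]) -> List[float]:
        return [v + self.offset for v in x]


@dataclass(frozen=True)
class Capture:
    """Hypothetical identity layer that records whatever passes through it."""
    log: list

    def __call__(self, x: List[float]) -> List[float]:
        self.log.append(list(x))
        return x


@dataclass(frozen=True)
class Sequential:
    """The forward pass is just 'apply each sublayer in order'."""
    sublayers: Tuple[Callable, ...]

    def __call__(self, x: List[float]) -> List[float]:
        for layer in self.sublayers:
            x = layer(x)
        return x


def insert_after(model: Sequential, predicate, make_layer) -> Sequential:
    """Return a copy of `model` with a new layer after each matching sublayer."""
    new_layers = []
    for layer in model.sublayers:
        new_layers.append(layer)
        if predicate(layer):
            new_layers.append(make_layer())
    return replace(model, sublayers=tuple(new_layers))


model = Sequential((Scale(2.0), Shift(1.0)))

# Splice a Capture layer after every Scale layer; the original is untouched.
captured = []
patched = insert_after(
    model,
    predicate=lambda layer: isinstance(layer, Scale),
    make_layer=lambda: Capture(captured),
)

output = patched([1.0, 2.0])
```

Because the patched model is an ordinary value, the original `model` remains available for comparison, and the intermediate activations end up in `captured` after the call. Penzai's actual selector system and named-axis machinery are considerably richer than this stand-in.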
Submission Number: 23