Keywords: Hopfield Network, Lyapunov function, Hierarchical Associative Memory, JAX, Software framework
TL;DR: HAMUX: A novel abstraction and corresponding JAX framework for modularly constructing Hierarchical Hopfield Networks, built around energy-based memory retrieval dynamics.
Abstract: Conceptualized as Associative Memory, Hopfield Networks (HNs) are powerful models that describe neural network dynamics converging to a local minimum of an energy function. HNs are conventionally described as a neural network with two layers connected by a matrix of synaptic weights. It is less widely known, however, that the Hopfield framework generalizes to systems in which many neuron layers and synapses work together as a unified Hierarchical Associative Memory (HAM) model: a single network described by memory retrieval dynamics (convergence to a fixed point) and governed by a global energy function. In this work we introduce a universal abstraction for HAMs using the building blocks of neuron layers (nodes) and synapses (edges) connected within a hypergraph. We implement this abstraction as a software framework written in JAX, whose automatic differentiation removes the need to derive update rules by hand for the complicated energy-based dynamics. Our framework, called HAMUX (HAM User eXperience), enables anyone to build and train hierarchical HNs using familiar operations like convolutions and attention alongside activation functions like Softmaxes, ReLUs, and LayerNorms. HAMUX is a powerful tool for studying HNs at scale, something that has not been possible before. We believe that HAMUX lays the groundwork for a new type of AI framework built around dynamical systems and energy-based associative memories.
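
The idea that automatic differentiation can stand in for hand-derived update rules can be illustrated with a minimal sketch in plain JAX. This is not HAMUX's API: the energy below is a single-layer dense associative memory energy chosen for illustration, and the names `energy`, `retrieve`, `memories`, `beta`, `step`, and `n_steps` are assumptions made for this example only. Retrieval is written as gradient descent on the energy, with `jax.grad` supplying the descent direction.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def energy(x, memories, beta=4.0):
    """Illustrative associative memory energy (an assumption, not HAMUX's own).

    The quadratic term keeps the state bounded; the log-sum-exp term
    lowers the energy when x aligns with one of the stored patterns.
    """
    return 0.5 * jnp.dot(x, x) - logsumexp(beta * (memories @ x)) / beta


# Autodiff gives dE/dx directly, so no update rule is derived by hand.
energy_grad = jax.grad(energy)


def retrieve(x0, memories, step=0.1, n_steps=200):
    """Discretized retrieval dynamics: x <- x - step * dE/dx."""

    def body(x, _):
        x_new = x - step * energy_grad(x, memories)
        return x_new, energy(x_new, memories)

    # Returns the final state and the energy recorded after each step.
    return jax.lax.scan(body, x0, None, length=n_steps)


key = jax.random.PRNGKey(0)
k_mem, k_noise = jax.random.split(key)

# Five random +/-1 patterns of dimension 64 act as the stored memories.
memories = jnp.sign(jax.random.normal(k_mem, (5, 64)))

# Corrupt one stored pattern with Gaussian noise and use it as the query.
query = memories[2] + 0.5 * jax.random.normal(k_noise, (64,))

x_final, energies = retrieve(query, memories)
```

Here `x_final` is the state after descending the energy from the noisy query, and `energies` traces the energy along the trajectory. Swapping in a different energy function changes the dynamics without any change to `retrieve`, which is the property the abstract attributes to building on autograd.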