Keywords: test-time training, test-time optimization, context compression, transformers, memory mechanisms
TL;DR: GradMem learns to write a transformer’s input context into a small set of memory tokens using a few steps of test-time gradient descent, enabling memory-only task solving.
Abstract: Transformers typically process long contexts by storing a large per-layer KV-cache of past activations. A desirable alternative is compressive memory: read a context once, store it in a compact state, and answer many queries from that state. We introduce GradMem, which writes context into memory via per-sample test-time optimization. Given a context, GradMem performs a few steps of gradient descent on a small set of prefix memory tokens while keeping the model weights frozen. Because GradMem explicitly optimizes a model-level self-supervised context reconstruction loss, its write operation is loss-driven and corrects errors iteratively, unlike forward-only methods. On associative key-value retrieval, GradMem outperforms forward-only memory writers with the same memory size, and additional gradient steps scale capacity much more effectively than repeated forward writes.
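To make the write operation concrete, here is a minimal NumPy sketch of the idea under heavily simplifying assumptions; it is not the GradMem implementation. The transformer is replaced by a hypothetical frozen toy reader (`FrozenReader`, a fixed random `tanh` feature map whose output is a linear readout of the memory), the context is a set of random key-value pairs, and the reconstruction loss is the mean squared error of retrieving each value from its key. Only the memory tokens (rows of `memory`) receive gradient updates; the reader's weights are never changed. All names (`make_task`, `FrozenReader`, `gradmem_write`) are illustrative.

```python
import numpy as np


def make_task(n_pairs=32, d_key=16, d_val=8, seed=0):
    """Random key-value pairs standing in for a context to be memorized."""
    rng = np.random.default_rng(seed)
    keys = rng.standard_normal((n_pairs, d_key))
    values = rng.standard_normal((n_pairs, d_val))
    return keys, values


class FrozenReader:
    """Hypothetical toy 'model': fixed random features, never updated."""

    def __init__(self, d_key, n_mem, seed=1):
        rng = np.random.default_rng(seed)
        self.W = rng.standard_normal((d_key, n_mem)) / np.sqrt(d_key)

    def features(self, keys):
        return np.tanh(keys @ self.W)  # (n_pairs, n_mem)

    def read(self, memory, keys):
        # Answer queries using only the memory and the frozen weights.
        return self.features(keys) @ memory  # (n_pairs, d_val)


def recon_loss_and_grad(reader, memory, keys, values):
    """Mean squared reconstruction loss and its gradient w.r.t. memory only."""
    phi = reader.features(keys)
    err = phi @ memory - values
    n = keys.shape[0]
    loss = 0.5 * np.sum(err ** 2) / n
    grad = phi.T @ err / n  # dL/d(memory); reader.W gets no gradient
    return loss, grad


def gradmem_write(reader, keys, values, n_mem, d_val, steps=10):
    """Write the context into memory with a few steps of gradient descent."""
    memory = np.zeros((n_mem, d_val))
    phi = reader.features(keys)
    # Step size 1/L, where L is the largest eigenvalue of the loss Hessian
    # phi^T phi / n; this guarantees the loss never increases.
    lr = 1.0 / np.linalg.eigvalsh(phi.T @ phi / keys.shape[0]).max()
    losses = []
    for _ in range(steps):
        loss, grad = recon_loss_and_grad(reader, memory, keys, values)
        losses.append(loss)
        memory -= lr * grad
    losses.append(recon_loss_and_grad(reader, memory, keys, values)[0])
    return memory, losses


if __name__ == "__main__":
    keys, values = make_task()
    reader = FrozenReader(d_key=16, n_mem=16)
    memory, losses = gradmem_write(reader, keys, values, n_mem=16, d_val=8, steps=20)
    answers = reader.read(memory, keys)  # memory-only retrieval
    print(f"loss: start {losses[0]:.3f}, 1 step {losses[1]:.3f}, 20 steps {losses[-1]:.3f}")
```

In this toy setting, a single step from a zero initialization reduces to one scaled outer-product write (`memory = lr * phi.T @ values / n`), loosely analogous to a single additive forward write; each further step corrects the remaining reconstruction error, which is the intuition behind spending more gradient steps rather than repeating forward writes. With 16 memory tokens for 32 pairs the memory is compressive, so the loss decreases but is not expected to reach zero.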
Submission Number: 17