Plug, Play, and Generalize: Length Extrapolation with Pointer-Augmented Neural Memory

Published: 04 Oct 2024, Last Modified: 04 Oct 2024. Accepted by TMLR. License: CC BY 4.0
Abstract: We introduce Pointer-Augmented Neural Memory (PANM), a versatile module designed to enhance neural networks' ability to process symbols and extend their capabilities to longer data sequences. PANM integrates an external neural memory utilizing novel physical addresses and pointer manipulation techniques, emulating human and computer-like symbol processing abilities. PANM facilitates operations like pointer assignment, dereferencing, and arithmetic by explicitly employing physical pointers for memory access. This module can be trained end-to-end on sequence data, empowering various sequential models, from simple recurrent networks to large language models (LLMs). Our experiments showcase PANM's exceptional length extrapolation capabilities and its enhancement of recurrent neural networks in symbol processing tasks, including algorithmic reasoning and Dyck language recognition. PANM enables Transformers to achieve up to 100% generalization accuracy in compositional learning tasks and significantly improves performance in mathematical reasoning, question answering, and machine translation. Notably, the generalization effectiveness scales with stronger backbone models, as evidenced by substantial performance gains when we test LLMs finetuned with PANM for tasks up to 10-100 times longer than the training data.
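The abstract describes pointer operations (assignment, arithmetic, dereferencing) over a memory with explicit physical addresses. As a rough intuition only, the sketch below mimics those three operations with a soft (attention-style) pointer over memory slots; all names and details here are illustrative assumptions, not the paper's actual PANM implementation (see the linked repository for that).

```python
import numpy as np

def one_hot(i, n):
    """Distribution placing all mass on physical address i."""
    v = np.zeros(n)
    v[i] = 1.0
    return v

class PointerMemory:
    """Toy memory with pointer-style access (illustrative only)."""

    def __init__(self, contents):
        self.mem = np.asarray(contents, dtype=float)  # (slots, dim)
        self.n = self.mem.shape[0]

    def assign(self, address):
        # Pointer assignment: a distribution over physical addresses.
        return one_hot(address % self.n, self.n)

    def increment(self, pointer, step=1):
        # Pointer arithmetic: shift the address distribution by `step`.
        return np.roll(pointer, step)

    def dereference(self, pointer):
        # Soft dereference: expected memory content under the pointer.
        return pointer @ self.mem

mem = PointerMemory([[1.0, 0.0], [0.0, 2.0], [3.0, 3.0]])
p = mem.assign(0)
p = mem.increment(p)       # pointer now targets slot 1
print(mem.dereference(p))  # content of slot 1
```

Because the pointer is a differentiable distribution rather than a hard index, access of this kind can in principle be trained end-to-end, which is the property the abstract emphasizes.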
Submission Length: Regular submission (no more than 12 pages of main content)
Changes Since Last Submission:
- Polish writing
- Add publication information
Code: https://github.com/thaihungle/PANM
Assigned Action Editor: ~Swarat_Chaudhuri1
Submission Number: 2895