# RMAAT: Astrocyte-Inspired Memory Compression and Replay for Efficient Long-Context Transformers

[![Python 3.10](https://img.shields.io/badge/Python-3.10-blue.svg)](https://www.python.org/downloads/release/python-3100/)
[![PyTorch 1.13](https://img.shields.io/badge/PyTorch-1.13-ee4c2c.svg)](https://pytorch.org/)
[![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](LICENSE)
[![ICLR 2026](https://img.shields.io/badge/ICLR-2026-blueviolet.svg)](https://iclr.cc/)

Official implementation of **RMAAT** (Recurrent Memory Augmented Astromorphic Transformer), accepted at **ICLR 2026**.

Paper links: [arXiv](https://arxiv.org/abs/2601.00426) | [OpenReview](https://openreview.net/forum?id=sTkJdbVxsI)

> The quadratic complexity of self-attention limits Transformers on long sequences. RMAAT integrates computational principles derived from astrocytes — glial cells critical for biological memory and synaptic modulation — into a recurrent Transformer framework. It uses segment-based processing with persistent memory tokens, an adaptive compression mechanism governed by a retention factor derived from simulated astrocyte long-term plasticity (LTP), linear-complexity attention inspired by astrocyte short-term plasticity (STP), and the memory-efficient Astrocytic Memory Replay Backpropagation (AMRB) training algorithm.

## Architecture Overview

RMAAT divides a long sequence into fixed-length segments and processes them recurrently. Its key components are:

- **Segmented Processing with Memory Tokens** — Persistent memory tokens propagate compressed context across segments, enabling long-range dependency modeling without attending over the full sequence.
- **Astromorphic Attention** — A linear-complexity O(N) attention mechanism inspired by astrocyte short-term plasticity (STP), replacing standard O(N^2) self-attention within each segment.
- **Memory Retention Factor** — Derived from a macro model of neuron-astrocyte LTP dynamics, this factor adaptively compresses memory tokens across segments, implementing biologically motivated context decay.
- **AMRB Training** — Astrocytic Memory Replay Backpropagation replays segments from stored compressed memory states during backpropagation. This reduces peak GPU memory by **up to 4.41x** versus standard BPTT while maintaining equivalent accuracy: Retrieval drops from `15.0` GB to `3.4` GB (`4.41x`), and Text from `22.0` GB to `5.1` GB (`4.31x`).
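
The replay idea behind AMRB can be illustrated with a toy scalar recurrence. This is a minimal sketch under stated assumptions, not the code in `train.py`: the update rule `tanh(w * r * m + x)`, the function names (`segment_step`, `forward_store_memory`, `replay_backward`), and the scalar retention factor `r` are all hypothetical stand-ins for the real model. The forward pass keeps only the memory state entering each segment, and the backward pass recomputes each segment's activation from that stored state before propagating the gradient.

```python
import math

def segment_step(m, x, w, r):
    """Toy segment update (hypothetical): scale the incoming memory by a
    retention factor r, mix in the segment input x, and squash with tanh."""
    return math.tanh(w * r * m + x)

def forward_store_memory(xs, w, r, m0=0.0):
    """Run all segments, storing only the memory state entering each one."""
    stored = []
    m = m0
    for x in xs:
        stored.append(m)              # one small state per segment, no activations
        m = segment_step(m, x, w, r)
    return m, stored

def replay_backward(xs, stored, w, r, grad_m):
    """Walk the segments in reverse, replaying each one from its stored
    memory state and accumulating dLoss/dw."""
    grad_w = 0.0
    for x, m_in in zip(reversed(xs), reversed(stored)):
        out = segment_step(m_in, x, w, r)   # replay: recompute the activation
        local = 1.0 - out * out             # derivative of tanh at this segment
        grad_w += grad_m * local * r * m_in
        grad_m = grad_m * local * w * r     # gradient w.r.t. the incoming memory
    return grad_w

xs = [0.5, -0.3, 0.8, 0.1]    # one toy input per segment
w, r = 0.7, 0.9               # toy weight and retention factor (assumed values)
m_final, stored = forward_store_memory(xs, w, r, m0=0.2)
loss_grad = m_final           # loss = 0.5 * m_final**2, so dLoss/dm_final = m_final
grad_w = replay_backward(xs, stored, w, r, loss_grad)
```

In this toy, only `len(xs)` scalars are kept between the forward and backward passes, at the cost of one extra forward computation per segment during the backward pass.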

## Results

All table values below are copied directly from the camera-ready paper tables.

### LRA Benchmark Accuracy (%)

| Model | ListOps (2K) | Text (4K) | Retrieval (8K) | Image (1K) | Pathfinder (1K) | **Average** |
|:------|:---:|:---:|:---:|:---:|:---:|:---:|
| Transformer | 36.4 | 64.3 | 57.5 | 42.4 | 71.4 | 54.4 |
| Nystromformer | 37.2 | 65.5 | 79.6 | 41.6 | 70.9 | 59.0 |
| Luna-256 | 37.3 | 64.6 | 79.3 | 47.4 | 77.7 | 61.3 |
| RMT | 37.4 | 65.0 | 79.3 | 54.6 | 81.5 | 63.6 |
| **RMAAT (Ours)** | **38.9** | **65.9** | **83.2** | **64.8** | **87.1** | **68.0** |

### Peak GPU Memory (GB)

| Model | ListOps | Text | Retrieval | Image | Pathfinder |
|:------|:---:|:---:|:---:|:---:|:---:|
| Transformer | 4.7 | 6.7 | 5.2 | 7.8 | 5.4 |
| RMT | 20.4 | 24.0 | 18.3 | 22.7 | 12.7 |
| **RMAAT (Ours)** | **5.2** | **5.1** | **3.4** | **5.3** | **4.7** |
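
To read the table as ratios, the short script below divides RMT's peak memory by RMAAT's for each task. The numbers are copied from the table above; the variable names are illustrative only.

```python
# Peak GPU memory (GB), copied from the table above.
tasks = ["ListOps", "Text", "Retrieval", "Image", "Pathfinder"]
rmt = [20.4, 24.0, 18.3, 22.7, 12.7]
rmaat = [5.2, 5.1, 3.4, 5.3, 4.7]

# A ratio above 1 means RMAAT needs less peak memory than RMT on that task.
ratios = {task: round(a / b, 2) for task, a, b in zip(tasks, rmt, rmaat)}

for task, ratio in ratios.items():
    print(f"{task:<10} RMT / RMAAT = {ratio:.2f}x")
```

The ratios range from roughly 2.7x (Pathfinder) to roughly 5.4x (Retrieval). They compare RMAAT against RMT and are separate from the AMRB-versus-BPTT figures quoted in the Architecture Overview.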

### Training Speed (relative to RMT)

| Model | ListOps | Text | Retrieval | Image | Pathfinder |
|:------|:---:|:---:|:---:|:---:|:---:|
| RMT | 1.0x | 1.0x | 1.0x | 1.0x | 1.0x |
| **RMAAT (Ours)** | **1.5x** | **1.5x** | **1.73x** | **1.3x** | **0.95x** |

## Repository Structure

```
RMAAT/
├── configs/
│   └── config_v10.yaml       # Main configuration file
├── model.py                  # RMAAT architecture (AstroAttention, memory tokens, encoder)
├── train.py                  # Training and evaluation loops with AMRB
├── run_train.py              # Entry point — loads config and launches training
├── dataloader.py             # Data loading and preprocessing
├── lra_datasets.py           # Dataset classes for LRA benchmarks
├── lra_config.py             # Tokenizers and configs for LRA tasks
├── utils.py                  # Utility functions (positional encoding, custom embeddings)
├── get_lra_data.sh           # Script to download LRA datasets
├── environment.yml           # Conda environment specification
├── CITATION.bib              # BibTeX citation
├── LICENSE                   # MIT License
└── README.md
```

## Setup

### 1. Clone the repository

```bash
git clone https://github.com/NeuroCompLab-psu/RMAAT.git
cd RMAAT
```

### 2. Create the Conda environment

```bash
conda env create -f environment.yml
conda activate rmaat
```

### 3. Download datasets

The `get_lra_data.sh` script downloads the CIFAR-10, Long Range Arena (LRA), and IMDb datasets into a `datasets/` directory:

```bash
bash get_lra_data.sh
```

> **Note:** The original LRA dataset hosted on Google Cloud Storage may return a 403 error (the upstream repo was archived in Feb 2025). If the download fails, the script will print instructions for obtaining the data from alternative sources such as the [e-lra fork](https://github.com/lucaslingle/e-lra). Place the extracted `lra_release/` directory inside `datasets/`.

## Usage

### Training

Run training from the repository root:

```bash
CUDA_VISIBLE_DEVICES=0 python run_train.py --config configs/config_v10.yaml
```

### Selecting a task

Edit `configs/config_v10.yaml` to change the dataset and relevant hyperparameters:

| Parameter | Description | Options |
|:----------|:------------|:--------|
| `dataset` | Dataset/task to run | `imdb`, `imdb_long`, `imdb_lra`, `listops`, `cifar10`, `pathfinder32`, `aan` |
| `max_seq_len` | Maximum sequence length | Typical: 4096 (text), 2048 (listops), 1024 (cifar10/pathfinder32), 8192 (aan/retrieval-style setup) |
| `num_segments` | Number of recurrent segments | 2, 4, 8, 16 |
| `num_memory_tokens` | Memory tokens per segment | 1 (default) |
| `attention_type` | Attention mechanism | `astro` (astromorphic), `softmax` (standard) |
| `memory_replay_backprop` | Enable AMRB training | `True` / `False` |
| `astro_mem` | Enable astrocytic memory retention | `True` / `False` |

### Logging

To enable [Weights & Biases](https://wandb.ai) logging, set `wandb: True` and configure `wandb_run_name` in the config file.

## Citation

If you find this work useful, please cite:

```bibtex
@article{mia2026rmaat,
  title         = {{RMAAT}: Astrocyte-Inspired Memory Compression and Replay for Efficient Long-Context Transformers},
  author        = {Mia, Md Zesun Ahmed and Bal, Malyaban and Sengupta, Abhronil},
  journal       = {arXiv preprint arXiv:2601.00426},
  year          = {2026},
  archivePrefix = {arXiv},
  eprint        = {2601.00426},
  primaryClass  = {cs.NE},
  doi           = {10.48550/arXiv.2601.00426},
  url           = {https://arxiv.org/abs/2601.00426}
}
```

## License

This project is licensed under the MIT License. See [LICENSE](LICENSE) for details.
