# MeSH: Memory-as-State-Highways for Recursive Transformers
This repository contains the implementation for the paper "MeSH: Memory-as-State-Highways for Recursive Transformers" (under review). It is built on the **LLaMA-Factory** framework.

## Setup
```bash
git clone https://anonymous.4open.science/r/MeSH-5E25/
cd MeSH-5E25
conda create -n mesh python=3.10
conda activate mesh
pip install -e ".[torch,metrics,deepspeed]"
pip install wandb
conda install cuda=12.1.0 -c nvidia
pip install flash-attn==2.7.2.post1 --no-build-isolation
```
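After installation, a quick import check can confirm that the key packages load inside the `mesh` environment. The helper below is a minimal sketch and is not part of the repository. The function name `check_py_modules` is hypothetical, and it assumes `python3` resolves to the activated environment's interpreter.

```bash
#!/usr/bin/env bash
# Hypothetical helper (not part of the repo): try to import each Python module
# and report the result. Returns 0 if every import succeeds, 1 otherwise.
# ASSUMPTION: `python3` is the interpreter of the activated conda environment.
check_py_modules() {
  local missing=0 mod
  for mod in "$@"; do
    if python3 -c "import ${mod}" >/dev/null 2>&1; then
      printf '%-8s %s\n' ok "$mod"
    else
      printf '%-8s %s\n' MISSING "$mod"
      missing=1
    fi
  done
  return "$missing"
}

# Usage inside the activated environment:
# check_py_modules torch deepspeed flash_attn wandb llamafactory
```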



## Running Experiments
All training scripts are located in the `./scripts` directory and are configured for multi-GPU training with `torchrun`. Adjust *WORLD_SIZE* (commented out in the scripts) and the other distributed-training parameters, including those in the DeepSpeed config file (`examples/deepspeed/ds_z0_config.json`), to match your hardware.
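The sketch below shows one way to pick a value for *WORLD_SIZE* on a single node and to keep the global batch size unchanged when the GPU count changes. It is not taken from the scripts: both helper names are hypothetical, and it assumes the usual Hugging Face relation global batch = `per_device_train_batch_size` × `gradient_accumulation_steps` × world size.

```bash
#!/usr/bin/env bash
# Hypothetical helpers (not part of the repo).

# Count visible GPUs on this node: prefer CUDA_VISIBLE_DEVICES,
# then `nvidia-smi -L`, else fall back to 1.
count_gpus() {
  if [ -n "${CUDA_VISIBLE_DEVICES:-}" ]; then
    # Comma-separated device list, e.g. "0,1,2,3" -> 4
    echo "$CUDA_VISIBLE_DEVICES" | tr ',' '\n' | grep -c .
  elif command -v nvidia-smi >/dev/null 2>&1; then
    # One "GPU <n>: ..." line per physical GPU
    nvidia-smi -L | grep -c '^GPU'
  else
    echo 1
  fi
}

# Gradient-accumulation steps that keep the global batch fixed.
# ASSUMPTION: global_batch = per_device_batch * grad_accum * world_size
grad_accum_for() {
  local global=$1 per_device=$2 world=$3
  local denom=$(( per_device * world ))
  if (( global % denom != 0 )); then
    echo "global batch $global is not divisible by $per_device x $world" >&2
    return 1
  fi
  echo $(( global / denom ))
}
```

For example, with a hypothetical target global batch of 512 and a per-device batch of 8 on 4 GPUs, `grad_accum_for 512 8 4` prints `16`.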

### Implementation
Our modifications are integrated into the LLaMA-Factory source code, primarily in the model loading and model definition code. The core logic for the prelude-recurrent-coda structure and the recurrence schemes (including MeSH) is in `./src/llamafactory/model/modeling/`, where the standard Transformer layers are defined and assembled.
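The training script names (for example `pythia_70m_minipile_recursive_mesh__1_2R2_1.sh`) appear to encode the layer layout. Our reading, inferred from the file names rather than stated in this README, is `P_KRL_C`: P prelude layers, a shared block of K layers looped L times, and C coda layers. Under that reading, `1_2R2_1` has an effective depth of 1 + 2×2 + 1 = 6 (matching `vanilla__6`) and `4_8R2_4` has 4 + 8×2 + 4 = 24 (matching `vanilla__24`). The hypothetical helper below prints the resulting order of unique-layer indices; it only illustrates weight sharing and does not model how the base or MeSH recurrence carries state between loops.

```bash
#!/usr/bin/env bash
# Hypothetical helper (not part of the repo).
# ASSUMPTION: a suffix such as "1_2R2_1" is read as P_KRL_C
# (P prelude layers, K shared recurrent layers looped L times, C coda layers).
# This reading is inferred from the script names, not documented.
layer_schedule() {
  local spec=$1
  local re='^([0-9]+)_([0-9]+)R([0-9]+)_([0-9]+)$'
  if ! [[ $spec =~ $re ]]; then
    echo "unrecognized spec: $spec" >&2
    return 1
  fi
  local p=${BASH_REMATCH[1]} k=${BASH_REMATCH[2]}
  local loops=${BASH_REMATCH[3]} c=${BASH_REMATCH[4]}
  local -a order=()
  local i l
  # Prelude: unique layers 0 .. p-1, executed once.
  for (( i = 0; i < p; i++ )); do order+=("$i"); done
  # Recurrent block: unique layers p .. p+k-1, reused on every loop.
  for (( l = 0; l < loops; l++ )); do
    for (( i = p; i < p + k; i++ )); do order+=("$i"); done
  done
  # Coda: unique layers p+k .. p+k+c-1, executed once.
  for (( i = p + k; i < p + k + c; i++ )); do order+=("$i"); done
  echo "${order[*]}"
}
```

For instance, `layer_schedule 1_2R2_1` prints `0 1 2 1 2 3`: four unique layers executed six times.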

### Quick Verification (Pythia-70M on MiniPile)
For rapid verification of the implementation, a smaller-scale experiment trains the Pythia-70M architecture on the MiniPile dataset. It uses the **minipile** and **minipile_test** datasets; make sure both are configured in `./data/dataset_info.json`.
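A minimal check that both entries exist is sketched below. The helper name is hypothetical, and it only verifies that the keys are present, not that the referenced data files exist or are well formed.

```bash
#!/usr/bin/env bash
# Hypothetical helper (not part of the repo): check that each named dataset
# has a top-level entry in a LLaMA-Factory dataset_info.json.
# Usage: check_dataset_info ./data/dataset_info.json minipile minipile_test
check_dataset_info() {
  local info=$1; shift
  python3 - "$info" "$@" <<'PY'
import json
import sys

path, names = sys.argv[1], sys.argv[2:]
with open(path) as f:
    info = json.load(f)
missing = [n for n in names if n not in info]
for n in names:
    print(("MISSING " if n in missing else "ok      ") + n)
# Non-zero exit status if any requested entry is absent
sys.exit(1 if missing else 0)
PY
}
```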

#### To run the MiniPile demo experiments:
```bash
# The scripts in ./scripts/minipile_demo/ train Pythia-70M on MiniPile for quick validation, e.g.:
bash ./scripts/minipile_demo/pythia_70m_minipile_vanilla__6.sh
bash ./scripts/minipile_demo/pythia_70m_minipile_recursive_base__1_2R2_1.sh
bash ./scripts/minipile_demo/pythia_70m_minipile_recursive_mesh__1_2R2_1.sh
```
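To compare the three variants, they can be run back to back with one log per run. The wrapper below is a hypothetical convenience, not part of the repository; the `logs/` directory and the `DRY_RUN` switch are our assumptions.

```bash
#!/usr/bin/env bash
# Hypothetical wrapper (not part of the repo): run the three MiniPile demo
# variants sequentially, teeing each run into logs/<variant>.log.
# Set DRY_RUN=1 to print the commands instead of running them.
run_minipile_demo() {
  local v script
  for v in vanilla__6 recursive_base__1_2R2_1 recursive_mesh__1_2R2_1; do
    script=./scripts/minipile_demo/pythia_70m_minipile_${v}.sh
    if [ -n "${DRY_RUN:-}" ]; then
      echo "bash $script 2>&1 | tee logs/${v}.log"
      continue
    fi
    mkdir -p logs
    bash "$script" 2>&1 | tee "logs/${v}.log"
    # PIPESTATUS[0] is the training script's exit code, not tee's
    if [ "${PIPESTATUS[0]}" -ne 0 ]; then
      echo "run failed: $v" >&2
      return 1
    fi
  done
}
```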

### Main Experiments (Pythia-1.4B on The Pile)
This set of experiments corresponds to the main results reported in the paper, using the Pythia-1.4B architecture trained on The Pile. Make sure the **Pile** dataset is downloaded and prepared in a format accessible to LLaMA-Factory; you may need to add a custom dataset entry to `./data/dataset_info.json`.
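As a sketch of such an entry: LLaMA-Factory describes a pre-training corpus by its file name plus a single text column mapped to `prompt`. The helper below, the dataset key `pile`, and the file `pile_train.jsonl` (one `{"text": ...}` record per line, placed under `./data/`) are hypothetical; use whatever dataset name the training scripts pass to LLaMA-Factory.

```bash
#!/usr/bin/env bash
# Hypothetical helper (not part of the repo): add or overwrite a
# pre-training dataset entry in dataset_info.json.
# ASSUMPTION: each JSONL record looks like {"text": "..."}.
# Usage: add_pt_dataset ./data/dataset_info.json pile pile_train.jsonl
add_pt_dataset() {
  python3 - "$1" "$2" "$3" <<'PY'
import json
import sys

path, name, file_name = sys.argv[1:4]
with open(path) as f:
    info = json.load(f)
# Pre-training data: the raw text column is mapped to "prompt"
info[name] = {"file_name": file_name, "columns": {"prompt": "text"}}
with open(path, "w") as f:
    json.dump(info, f, indent=2, ensure_ascii=False)
    f.write("\n")
PY
}
```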

#### To run the 1.4B experiments:
```bash
# The scripts in ./scripts/main_experiments/ contain all arguments needed to reproduce the 1.4B-scale models.
# Adjust paths and distributed-training settings for your environment. For example:
bash ./scripts/main_experiments/pythia_1.4b_pile_vanilla__24.sh
bash ./scripts/main_experiments/pythia_1.4b_pile_recursive_base__4_8R2_4.sh
bash ./scripts/main_experiments/pythia_1.4b_pile_recursive_mesh__4_8R2_4.sh
```
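Before launching at this scale, it can help to list the lines that carry machine-specific settings. The sketch below greps the scripts for common `torchrun` and LLaMA-Factory flags; which of these flags the MeSH scripts actually use is an assumption, and the helper name is hypothetical.

```bash
#!/usr/bin/env bash
# Hypothetical helper (not part of the repo): print the lines in the training
# scripts that set paths or distributed settings you will likely need to adapt.
# ASSUMPTION: the scripts use some of these common torchrun / LLaMA-Factory flags.
show_env_specific_args() {
  local dir=${1:-./scripts/main_experiments}
  grep -nE -- '--(output_dir|dataset_dir|deepspeed|nproc_per_node|nnodes|master_addr|master_port)|WORLD_SIZE' "$dir"/*.sh
}

# Usage: show_env_specific_args ./scripts/main_experiments
```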

## Evaluation
After training completes, a checkpoint is saved to the directory specified by `--output_dir` in the script. Downstream task performance can be measured with the **LM Evaluation Harness**.
```bash
# Example: Evaluating a trained checkpoint
lm_eval --model hf \
    --model_args pretrained=path/to/your/checkpoint,trust_remote_code=True \
    --tasks hellaswag,piqa,winogrande \
    --num_fewshot 5 \
    --batch_size auto
```
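To evaluate several checkpoints with identical settings, the command above can be wrapped in a loop that writes each result set to its own directory through `lm_eval`'s `--output_path` option. The function name, the `results/` layout, and the `DRY_RUN` switch are hypothetical.

```bash
#!/usr/bin/env bash
# Hypothetical wrapper (not part of the repo): evaluate each checkpoint with the
# same lm_eval settings, saving results under results/<checkpoint-name>/.
# Set DRY_RUN=1 to print the commands instead of running them.
eval_checkpoints() {
  local ckpt name
  local -a cmd
  for ckpt in "$@"; do
    name=$(basename "$ckpt")
    cmd=(lm_eval --model hf
      --model_args "pretrained=${ckpt},trust_remote_code=True"
      --tasks hellaswag,piqa,winogrande
      --num_fewshot 5
      --batch_size auto
      --output_path "results/${name}")
    if [ -n "${DRY_RUN:-}" ]; then
      echo "${cmd[*]}"
    else
      "${cmd[@]}" || return 1
    fi
  done
}

# Usage (paths are placeholders): eval_checkpoints out/mesh out/base out/vanilla
```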