# Memory Aware Routing in GPT2-MoE

This project implements a **Mixture-of-Experts (MoE)** language model based on the GPT-2 architecture. It introduces the **Memory Aware Routing (MAR)** algorithm to make expert routing more stable and to encourage expert specialization. The training script supports single-GPU training, multi-GPU DDP (DistributedDataParallel) training, mixed precision, checkpointing, and detailed logging.

-----

## Key Features

  * **GPT-2 Architecture**: A multi-layer Transformer with multi-head causal self-attention.
  * **MoE Layer**: Replaces the MLP in each Transformer block with an **MoE** layer that has a configurable number of experts and an auxiliary load-balancing loss.
  * **Memory Aware Routing**: Each expert maintains a replay buffer from which a preference vector is computed dynamically; these preference vectors guide the router toward a more consistent, specialization-aware allocation of tokens.
  * **Efficient Training**: Supports DDP distributed training, mixed precision, gradient accumulation, and automatic logging and checkpointing.
  * **Detailed Logging**: All training information is written to a log file and printed to the console, with support for **Weights & Biases (wandb)** tracking.

-----

## Introduction to the Memory Aware Routing Algorithm

We propose **Memory-Aware Routing (MAR)**, a novel mechanism that augments load balancing with memory-guided routing. **MAR** introduces memory buffers to explicitly capture the long-term preferences of each expert, guiding the model to avoid assigning similar information to different experts and effectively mitigating knowledge overlap. In addition, we define an **expert–token matching score**, which quantifies the similarity between an input token and an expert’s preference vector. This score promotes the consistent routing of tokens to semantically aligned experts, fostering the emergence and consolidation of expert specialization. By maintaining a global balance while transitioning from uniform allocation to differentiated routing, **MAR** mitigates the pseudo-balance problem and encourages both functional diversity and stable specialization across experts.
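The routing loop described above can be illustrated with a minimal NumPy sketch. This is not the code in `model.py`; it rests on several assumptions: each expert's memory is a FIFO ring buffer of the token representations most recently routed to it (`buffer_size` slots), the preference vector is the mean of those entries, the matching score is cosine similarity, and `alpha` linearly mixes the router logits with the matching scores under top-1 routing. The names `ExpertMemory`, `matching_scores`, and `route` are hypothetical, and no training step is shown.

```python
import numpy as np


class ExpertMemory:
    """Hypothetical per-expert FIFO memory of recently routed token representations."""

    def __init__(self, num_experts=8, buffer_size=128, d_model=16):
        self.buffer_size = buffer_size
        # One ring buffer per expert: shape (num_experts, buffer_size, d_model)
        self.buffers = np.zeros((num_experts, buffer_size, d_model))
        # Total number of tokens ever written to each expert's buffer
        self.counts = np.zeros(num_experts, dtype=np.int64)

    def push(self, expert, tokens):
        # Write tokens one by one, overwriting the oldest slot once the buffer is full
        for t in tokens:
            slot = self.counts[expert] % self.buffer_size
            self.buffers[expert, slot] = t
            self.counts[expert] += 1

    def preferences(self):
        # Preference vector = mean of the filled slots; zero vector for an empty buffer
        num_experts, _, d_model = self.buffers.shape
        prefs = np.zeros((num_experts, d_model))
        for e in range(num_experts):
            n = min(int(self.counts[e]), self.buffer_size)
            if n > 0:
                prefs[e] = self.buffers[e, :n].mean(axis=0)
        return prefs


def matching_scores(x, prefs, eps=1e-8):
    # Cosine similarity between tokens (N, D) and preference vectors (E, D) -> (N, E)
    x_n = x / (np.linalg.norm(x, axis=-1, keepdims=True) + eps)
    p_n = prefs / (np.linalg.norm(prefs, axis=-1, keepdims=True) + eps)
    return x_n @ p_n.T


def route(x, router_weight, memory, alpha=0.5):
    logits = x @ router_weight                          # (N, E) learned router logits
    scores = matching_scores(x, memory.preferences())   # (N, E) memory-based scores
    combined = (1 - alpha) * logits + alpha * scores    # assumed linear mixing rule
    chosen = combined.argmax(axis=-1)                   # top-1 expert per token
    # Record each token in the memory of the expert it was sent to
    for e in range(router_weight.shape[1]):
        memory.push(e, x[chosen == e])
    return chosen


rng = np.random.default_rng(0)
d_model, num_experts = 16, 8
memory = ExpertMemory(num_experts=num_experts, buffer_size=128, d_model=d_model)
router_weight = rng.normal(size=(d_model, num_experts))
for _ in range(3):
    assignment = route(rng.normal(size=(32, d_model)), router_weight, memory, alpha=0.5)
print(assignment.shape)     # (32,)
print(memory.counts.sum())  # 96
```

Because the buffers are updated only after the routing decision, each batch is scored against preferences accumulated from earlier batches, which is one way to realize a long-term memory that changes more slowly than the per-batch router output.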

-----

## Dataset Preparation

This project uses the **OpenWebText** dataset from HuggingFace for training. The data preprocessing script, `prepare.py`, automatically downloads, splits, tokenizes, and saves the data into binary files for efficient training.

-----

### Preprocessing Workflow

1.  **Download Data**
    The script automatically downloads the OpenWebText dataset using the HuggingFace `datasets` library.

2.  **Split into Training/Validation Sets**
    By default, 0.05% of the data is used for the validation set, with the rest allocated for the training set.

3.  **Tokenization**
    The `tiktoken` GPT-2 BPE tokenizer is used to convert text into token IDs. An `EOT` (end of text) token is appended to the end of each text sample.

4.  **Save as Binary Files**
    The token IDs of each split are concatenated into a single large array and saved as `train.bin` and `val.bin`, respectively. During training these files are memory-mapped with `np.memmap`, so they never need to be loaded fully into RAM.
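The four steps above can be sketched as follows. This is an illustrative sketch, not `prepare.py` itself. To stay offline it uses random integer lists in place of the token IDs that tiktoken's GPT-2 encoding (`tiktoken.get_encoding("gpt2")`) would produce. It assumes the `.bin` files store `uint16` IDs, which is enough for GPT-2's 50,257-token vocabulary, and that batches are random fixed-length windows. The helper names `train_val_split`, `build_bin`, and `get_batch` are hypothetical.

```python
import os
import tempfile

import numpy as np

EOT = 50256  # id of GPT-2's <|endoftext|> token


def train_val_split(docs, val_fraction=0.0005, seed=0):
    # Shuffle documents and hold out val_fraction of them (0.05% by default)
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(docs))
    n_val = max(1, round(len(docs) * val_fraction))
    return [docs[i] for i in order[n_val:]], [docs[i] for i in order[:n_val]]


def build_bin(docs, path):
    # Append EOT to every tokenized document, concatenate, and write raw uint16
    ids = []
    for doc in docs:
        ids.extend(doc)
        ids.append(EOT)
    arr = np.array(ids, dtype=np.uint16)
    arr.tofile(path)
    return len(arr)


def get_batch(path, batch_size, block_size, rng):
    # Memory-map the file and sample random (input, next-token target) windows
    data = np.memmap(path, dtype=np.uint16, mode="r")
    starts = rng.integers(0, len(data) - block_size, size=batch_size)
    x = np.stack([data[i:i + block_size].astype(np.int64) for i in starts])
    y = np.stack([data[i + 1:i + 1 + block_size].astype(np.int64) for i in starts])
    return x, y


rng = np.random.default_rng(0)
docs = [list(rng.integers(0, 50000, size=rng.integers(20, 60))) for _ in range(4000)]
train_docs, val_docs = train_val_split(docs)
with tempfile.TemporaryDirectory() as d:
    n_tokens = build_bin(train_docs, os.path.join(d, "train.bin"))
    x, y = get_batch(os.path.join(d, "train.bin"), batch_size=4, block_size=8, rng=rng)
print(len(val_docs), x.shape)  # 2 (4, 8)
```

Re-opening the memmap inside `get_batch` keeps only the sampled windows in memory, which is what makes a file of roughly 17 GB practical to train on.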

-----

### Running the Script

To prepare the dataset, navigate to the `openweb` directory and run the following command:

```bash
python prepare.py
```

After running, the following files will be generated:

  * `train.bin`: All training token IDs, approximately 17 GB.
  * `val.bin`: All validation token IDs, approximately 8.5 MB.

-----

## Training Methods

  * **Data Format**: `.bin` files containing token ID sequences, read through `np.memmap` so that datasets larger than RAM can be used.
  * **Loss Function**: `total_loss = lm_loss + load_balancing_weight * lb_loss`
      * `lm_loss`: Language model cross-entropy loss
      * `lb_loss`: Load balancing loss
  * The validation loss includes only `lm_loss`, because it directly reflects the model's generalization ability, whereas `lb_loss` is an auxiliary training term.
  * Supports automatic checkpoint saving, detailed logging, and **wandb** tracking.
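The loss formula above can be made concrete with a small NumPy sketch. The README does not specify how `lb_loss` is computed, so this sketch assumes the Switch Transformer auxiliary loss, `num_experts * sum_i(f_i * P_i)`, where `f_i` is the fraction of tokens sent to expert `i` and `P_i` is the mean router probability for expert `i`. The actual term in `model.py` may differ. The function names are hypothetical.

```python
import numpy as np


def cross_entropy(logits, targets):
    # Mean token-level cross-entropy from raw logits (N, V) and integer targets (N,)
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()


def switch_lb_loss(router_probs, expert_index, num_experts):
    # Assumed Switch-Transformer-style loss: num_experts * sum_i f_i * P_i
    f = np.bincount(expert_index, minlength=num_experts) / len(expert_index)
    p = router_probs.mean(axis=0)
    return num_experts * np.sum(f * p)


def total_loss(lm_loss, lb_loss, load_balancing_weight=0.4):
    # Training objective; validation would report lm_loss alone
    return lm_loss + load_balancing_weight * lb_loss


E = 8
probs = np.full((16, E), 1.0 / E)   # uniform router probabilities
idx = np.arange(16) % E             # perfectly balanced top-1 assignment
lb = switch_lb_loss(probs, idx, E)  # 1.0, the minimum under this formulation
lm = cross_entropy(np.zeros((16, 100)), np.zeros(16, dtype=np.int64))  # log(100)
print(total_loss(lm, lb))           # about 5.005
```

Under this formulation a perfectly balanced router gives `lb_loss = 1.0`, while sending every token to one expert with full confidence gives `lb_loss = num_experts`. With `load_balancing_weight = 0.4`, imbalance can therefore add up to `0.4 * (num_experts - 1)` to the training loss.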

-----

## Examples

Single-GPU training:

```bash
python train_wandb.py --batch_size=32 --compile=False
```

Distributed DDP training (4 GPUs):

```bash
torchrun --standalone --nproc_per_node=4 train_wandb.py
```

-----

## File Structure

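  * `model.py`: Model definition, including GPT, MoE Layer, and Memory Aware Routing logic
  * `train_wandb.py`: Main training script, supporting distributed training, mixed precision, logging, and checkpointing
  * `openweb/prepare.py`: Dataset download, tokenization, and `.bin` file generation
  * `inference.py`: Text generation from a trained checkpoint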

-----

## Parameters

  * `num_experts`: Number of MoE experts, default 8
  * `buffer_size`: Size of each expert's replay buffer, default 128
  * `alpha`: Weighting coefficient for the expert interest (preference) distribution, default 0.5
  * `load_balancing_weight`: Weight for the load balancing loss, default 0.4

-----

## Inference

To use a trained model for text generation, use the `inference.py` script. This script loads a model from a checkpoint and generates text based on a user-provided prompt.

### Usage Instructions

1.  **Prepare a Checkpoint**: Before running, ensure you have a trained model checkpoint (`.pt` file) from the training process.

2.  **Run the Script**: Execute the `inference.py` script from the command line with the following arguments:

      * `--ckpt <path_to_checkpoint>`: The path to your saved model checkpoint file. **(Required)**
      * `--prompt "<your_text>"`: The text you want the model to continue. Enclose the prompt in quotes. **(Required)**
      * `--max_new_tokens <number>`: The maximum number of new tokens to generate. The default is 100.
      * `--device <device>`: The device to run inference on, e.g., `cuda:0` or `cpu`. The default is `cuda:0`.

-----

### Example

```bash
python inference.py --ckpt ckpt.pt --prompt "Hello, my name is" --max_new_tokens 200
```

The script will print the complete generated text to the console.
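The README does not document how `inference.py` samples tokens (for example, whether it uses temperature or top-k). The sketch below shows the generic autoregressive loop that a `--max_new_tokens` limit controls, assuming plain temperature sampling. Here `next_token_logits` is a hypothetical stand-in for the loaded model's forward pass on the tokenized prompt, and `temperature` is a hypothetical parameter, not a documented flag.

```python
import numpy as np


def generate(next_token_logits, prompt_ids, max_new_tokens=100, temperature=1.0, rng=None):
    # Append one sampled token per step until max_new_tokens tokens have been added
    if rng is None:
        rng = np.random.default_rng()
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        logits = np.asarray(next_token_logits(ids), dtype=float) / temperature
        probs = np.exp(logits - logits.max())  # numerically stable softmax
        probs /= probs.sum()
        ids.append(int(rng.choice(len(probs), p=probs)))
    return ids


V = 10


def toy_model(ids):
    # Toy stand-in model: always predicts (last token + 1) mod V
    logits = np.full(V, -1e9)
    logits[(ids[-1] + 1) % V] = 0.0
    return logits


out = generate(toy_model, [3], max_new_tokens=5, rng=np.random.default_rng(0))
print(out)  # [3, 4, 5, 6, 7, 8]
```

In the real script the prompt would first be encoded with the GPT-2 tokenizer and the resulting IDs decoded back to text before printing.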



