# RL-Eagle: Tree-Structured Drafting for Efficient Large Language Model Inference

## Project Overview
RL-Eagle is a research framework for speeding up large language model (LLM) inference with tree-structured drafting. It provides a full experiment pipeline (data preprocessing, training, evaluation, calibration, batch evaluation, and token frequency analysis) and supports Llama-family models, including Llama 2, Llama 3.1, and DeepSeek-R1-Distill-Llama.

---

## Installation & Environment Setup

```bash
# Create the environment
mamba create -n rl-eagle -y python=3.10
# Install compilers, the CUDA 12.4 toolchain, and the pinned Python packages
mamba install -n rl-eagle -y uv c-compiler=1.10.0 cxx-compiler cmake ninja 'kernel-headers_linux-64>=4' pudb transformers==4.49.0 trl==0.14.0 tqdm datasets seaborn notebook pytorch-cuda=12.4 cuda-toolkit cuda-version=12.4 cudnn tensorboard evaluate -c pytorch -c nvidia
conda activate rl-eagle
# Install PyTorch and Accelerate from PyPI
pip install torch torchvision torchaudio accelerate
pip install 'flash-attn<2.8.0'
# Install the local package in models/token/ in editable mode, without build isolation
uv pip install -e models/token/ --no-build-isolation --reinstall
```
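After installation, it can help to confirm that the pinned packages resolved to the expected versions. The following is a minimal sketch using only the standard library; the `EXPECTED` table and both helper functions are illustrative and not part of this repository.

```python
from importlib import metadata
from typing import Dict, Optional, Tuple

# Pins copied from the install commands above; other packages are unpinned there.
EXPECTED = {"transformers": "4.49.0", "trl": "0.14.0"}


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def check_pins(expected: Dict[str, str]) -> Dict[str, Tuple[str, Optional[str]]]:
    """Map each package whose installed version differs from its pin to (pinned, found)."""
    mismatches = {}
    for package, pinned in expected.items():
        found = installed_version(package)
        if found != pinned:
            mismatches[package] = (pinned, found)
    return mismatches


if __name__ == "__main__":
    for package, (pinned, found) in check_pins(EXPECTED).items():
        print(f"{package}: expected {pinned}, found {found or 'not installed'}")
```

Running the script inside the activated `rl-eagle` environment prints one line per mismatch and nothing when every pin matches.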

---

## Main Folders & File Descriptions

- `preprocess.py`: Tokenizes and preprocesses datasets into tree structures. Supports various Llama-family models and experiment parameters.
- `train.py`: Fine-tunes LLMs with various losses (CE, distill, chained, tree, etc.) and custom trainers.
- `evaluate.py`: Evaluates trained models on various benchmarks (e.g., GSM8K, HumanEval).
- `generate.py`: Generates samples from models and saves results.
- `calibration.py`: Analyzes and visualizes model confidence calibration (e.g., Brier score, ECE).
- `batch_eval.py`: Automates large-scale evaluation across multiple models/datasets.
- `token_freq.py`: Analyzes and visualizes token frequency.
- `models/`: Custom model implementations for Llama and Draft architectures.
- `criterion/`: Various loss and metric function implementations.
- `dataloader/`: Custom data collator and Dataset classes.
- `utils/`: Timer, utility functions, etc.
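For readers unfamiliar with the calibration metrics named for `calibration.py`, the sketch below shows the textbook definitions of the Brier score and the expected calibration error (ECE) with equal-width bins. It is a standalone illustration in pure Python; the function names and the choice of 10 bins are assumptions, and the script itself may compute these differently.

```python
from typing import List, Sequence


def brier_score(confidences: Sequence[float], correct: Sequence[int]) -> float:
    """Mean squared gap between predicted confidence and the 0/1 outcome."""
    return sum((p - y) ** 2 for p, y in zip(confidences, correct)) / len(confidences)


def expected_calibration_error(confidences: Sequence[float], correct: Sequence[int],
                               num_bins: int = 10) -> float:
    """ECE with equal-width bins: weighted mean of |accuracy - confidence| per bin."""
    bins: List[List[int]] = [[] for _ in range(num_bins)]
    for i, p in enumerate(confidences):
        # Clamp the index so that p == 1.0 falls into the last bin.
        bins[min(int(p * num_bins), num_bins - 1)].append(i)
    total = len(confidences)
    ece = 0.0
    for members in bins:
        if not members:
            continue
        avg_conf = sum(confidences[i] for i in members) / len(members)
        accuracy = sum(correct[i] for i in members) / len(members)
        ece += len(members) / total * abs(accuracy - avg_conf)
    return ece
```

A perfectly calibrated, perfectly confident predictor scores 0 on both metrics; four predictions at confidence 0.8 with three correct give an ECE of about 0.05.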

---

## Example Full Pipeline

### 1. Data Preprocessing
```bash
accelerate launch --num_processes 2 preprocess.py --batch_size 1 --model_type r1-distill-llama-8b --output_dir /big-data0/preprocessed_r1-distill-llama-8b_sharegpt_2048_eagle1 --dataset_id sharegpt --max_length 2048 --depth 3 --top_draft 4 --top_node 128
```
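The example paths in this README follow the pattern `preprocessed_<model_type>_<dataset_id>_<max_length>_eagle1`. When preprocessing several models, the command can be assembled programmatically, as in this sketch. The helper `preprocess_command` is hypothetical, and the `eagle1` suffix and the tree parameters are copied from the command above without interpreting them.

```python
import shlex
from typing import List


def preprocess_command(model_type: str, dataset_id: str, max_length: int,
                       data_root: str = "/big-data0", num_processes: int = 2) -> List[str]:
    """Build the argument list for the preprocessing command shown above."""
    # Directory name follows the pattern seen in the example paths.
    output_dir = f"{data_root}/preprocessed_{model_type}_{dataset_id}_{max_length}_eagle1"
    return [
        "accelerate", "launch", "--num_processes", str(num_processes), "preprocess.py",
        "--batch_size", "1",
        "--model_type", model_type,
        "--output_dir", output_dir,
        "--dataset_id", dataset_id,
        "--max_length", str(max_length),
        # Tree parameters copied verbatim from the example command.
        "--depth", "3", "--top_draft", "4", "--top_node", "128",
    ]


if __name__ == "__main__":
    print(shlex.join(preprocess_command("r1-distill-llama-8b", "sharegpt", 2048)))
```

The returned list can be passed directly to `subprocess.run`, which avoids shell quoting issues.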

### 2. Training
```bash
TOKENIZERS_PARALLELISM=false accelerate launch --num_processes 2 train.py --output_dir paper/r1-distill-llama-8b-ce-3 --overwrite_output_dir --fp16 --num_train_epochs 10 --warmup_epochs 1 --eval_epochs 2 --save_epochs 1 --loss_type distill_cross_entropy --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --gradient_accumulation_steps 16 --learning_rate 3e-4 --top_k 0 --data_dir /big-data0/preprocessed_r1-distill-llama-8b_sharegpt_2048_eagle1 --model_name deepseek-ai/DeepSeek-R1-Distill-Llama-8B --logging_steps 50 --max_length 2048
```
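The batch-related flags in the training command combine into an effective batch size per optimizer step. Assuming plain data parallelism, where each process handles its own per-device batch, the arithmetic is:

```python
def effective_batch_size(num_processes: int, per_device_batch_size: int,
                         gradient_accumulation_steps: int) -> int:
    """Sequences contributing to one optimizer step under plain data parallelism."""
    return num_processes * per_device_batch_size * gradient_accumulation_steps


if __name__ == "__main__":
    # Values taken from the training command above.
    print(effective_batch_size(num_processes=2, per_device_batch_size=1,
                               gradient_accumulation_steps=16))  # 32
```

When changing the number of GPUs, adjusting `--gradient_accumulation_steps` so that this product stays the same keeps the learning rate setting comparable.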

### 3. Evaluation & Calibration
```bash
python evaluate.py --ea-model-path paper/r1-distill-llama-8b-ce-3/final --model_id deepseek-ai/DeepSeek-R1-Distill-Llama-8B --data_dir ...
python calibration.py --data_dir ... --output_dir ...
```

### 4. Batch Evaluation & Token Analysis
```bash
python batch_eval.py --config ...
python token_freq.py --data_dir ...
```

---

## Key Code Structure & Custom Modules

- `models/llama_modeling.py`: Custom implementations for Llama, Draft, Tree attention/decoder layers.
- `criterion/losses.py`: Various loss and metric functions (Distill, Chained, Tree, RL, etc).
- `dataloader/dataloader.py`: Collator and Dataset for tree-structured data.
- `models/token/kv_cache.py`: Optimized KV cache and FastKVCache support.

---

## Example Experiments (by Model)

### Llama2-7b
```bash
# Preprocessing
accelerate launch --num_processes 2 preprocess.py --batch_size 1 --model_type llama2-7b --output_dir ... --dataset_id sharegpt --max_length 2048 --depth 3 --top_draft 4 --top_node 128
# Training
TOKENIZERS_PARALLELISM=false accelerate launch --num_processes 2 train.py --output_dir paper/llama2-7b-ce-3 ... --model_name meta-llama/Llama-2-7b-chat-hf ...
```

### Llama3_1-8b
```bash
# Preprocessing
accelerate launch --num_processes 2 preprocess.py --batch_size 1 --model_type llama3_1-8b --output_dir ... --dataset_id sharegpt --max_length 2048 --depth 3 --top_draft 4 --top_node 128
# Training
TOKENIZERS_PARALLELISM=false accelerate launch --num_processes 2 train.py --output_dir paper/llama3_1-8b-ce-3 ... --model_name meta-llama/Llama-3.1-8B-Instruct ...
```

### DeepSeek R1-Distill-Llama-8b
```bash
# Preprocessing
accelerate launch --num_processes 2 preprocess.py --batch_size 1 --model_type r1-distill-llama-8b --output_dir ... --dataset_id sharegpt --max_length 2048 --depth 3 --top_draft 4 --top_node 128
# Training
TOKENIZERS_PARALLELISM=false accelerate launch --num_processes 2 train.py --output_dir paper/r1-distill-llama-8b-ce-3 ... --model_name deepseek-ai/DeepSeek-R1-Distill-Llama-8B ...
```

---

## References & Citation

- Builds on HuggingFace Transformers, Accelerate, TRL, and FlashAttention.
- When citing, please reference both this project and the related papers.

---




## Full Commands (by Model)

- llama2-7b
    - preprocess
        - `accelerate launch --num_processes 2 preprocess.py --batch_size 1 --model_type llama2-7b --output_dir /big-data0/preprocessed_llama2-7b_sharegpt_2048_eagle1 --dataset_id sharegpt --max_length 2048 --depth 3 --top_draft 4 --top_node 128`
    - train
        - `TOKENIZERS_PARALLELISM=false accelerate launch --num_processes 2 train.py --output_dir paper/llama2-7b-ce-3 --overwrite_output_dir --fp16 --num_train_epochs 10 --warmup_epochs 1 --eval_epochs 2 --save_epochs 1 --loss_type distill_cross_entropy --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --gradient_accumulation_steps 16 --learning_rate 3e-4 --top_k 0 --data_dir /big-data0/preprocessed_llama2-7b_sharegpt_2048_eagle1 --model_name meta-llama/Llama-2-7b-chat-hf --logging_steps 50 --max_length 2048`
        - `TOKENIZERS_PARALLELISM=false accelerate launch --num_processes 2 train.py --output_dir paper/llama2-7b-chained-3 --overwrite_output_dir --fp16 --num_train_epochs 3 --warmup_epochs 0 --eval_epochs 1 --save_epochs 1 --loss_type chained --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --gradient_accumulation_steps 16 --learning_rate 1e-4 --top_k 0 --data_dir /big-data0/preprocessed_llama2-7b_sharegpt_2048_eagle1 --model_name meta-llama/Llama-2-7b-chat-hf --logging_steps 50 --max_length 2048 --depth 3 --pretrained_model paper/llama2-7b-ce-3/final`
        - `TOKENIZERS_PARALLELISM=false accelerate launch --num_processes 2 train.py --output_dir paper/llama2-7b-chained-tree-3 --overwrite_output_dir --fp16 --num_train_epochs 3 --warmup_epochs 0 --eval_epochs 1 --save_epochs 0.1 --loss_type tree --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --gradient_accumulation_steps 16 --learning_rate 1e-4 --top_k 0 --data_dir /big-data0/preprocessed_llama2-7b_sharegpt_2048_eagle1 --model_name meta-llama/Llama-2-7b-chat-hf --logging_steps 50 --max_length 2048 --pretrained_model paper/llama2-7b-chained-3/final`

- llama3_1-8b
    - preprocess
        - `accelerate launch --num_processes 2 preprocess.py --batch_size 1 --model_type llama3_1-8b --output_dir /big-data0/preprocessed_llama3_1-8b_sharegpt_2048_eagle1 --dataset_id sharegpt --max_length 2048 --depth 3 --top_draft 4 --top_node 128`
    - train
        - `TOKENIZERS_PARALLELISM=false accelerate launch --num_processes 2 train.py --output_dir paper/llama3_1-8b-ce-3 --overwrite_output_dir --fp16 --num_train_epochs 10 --warmup_epochs 1 --eval_epochs 2 --save_epochs 1 --loss_type distill_cross_entropy --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --gradient_accumulation_steps 16 --learning_rate 3e-4 --top_k 0 --data_dir /big-data0/preprocessed_llama3_1-8b_sharegpt_2048_eagle1 --model_name meta-llama/Llama-3.1-8B-Instruct --logging_steps 50 --max_length 2048`
        - `TOKENIZERS_PARALLELISM=false accelerate launch --num_processes 2 train.py --output_dir paper/llama3_1-8b-chained-3 --overwrite_output_dir --fp16 --num_train_epochs 3 --warmup_epochs 0 --eval_epochs 1 --save_epochs 1 --loss_type chained --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --gradient_accumulation_steps 16 --learning_rate 1e-4 --top_k 0 --data_dir /big-data0/preprocessed_llama3_1-8b_sharegpt_2048_eagle1 --model_name meta-llama/Llama-3.1-8B-Instruct --logging_steps 50 --max_length 2048 --pretrained_model paper/llama3_1-8b-ce-3/final`
        - `TOKENIZERS_PARALLELISM=false accelerate launch --num_processes 2 train.py --output_dir paper/llama3_1-8b-chained-tree-3 --overwrite_output_dir --fp16 --num_train_epochs 3 --warmup_epochs 0 --eval_epochs 1 --save_epochs 0.1 --loss_type tree --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --gradient_accumulation_steps 16 --learning_rate 1e-4 --top_k 0 --data_dir /big-data0/preprocessed_llama3_1-8b_sharegpt_2048_eagle1 --model_name meta-llama/Llama-3.1-8B-Instruct --logging_steps 50 --max_length 2048 --pretrained_model paper/llama3_1-8b-chained-3/final`

- r1-distill-llama-8b
    - preprocess
        - `accelerate launch --num_processes 2 preprocess.py --batch_size 1 --model_type r1-distill-llama-8b --output_dir /big-data0/preprocessed_r1-distill-llama-8b_sharegpt_2048_eagle1 --dataset_id sharegpt --max_length 2048 --depth 3 --top_draft 4 --top_node 128`
    - train
        - `TOKENIZERS_PARALLELISM=false accelerate launch --num_processes 2 train.py --output_dir paper/r1-distill-llama-8b-ce-3 --overwrite_output_dir --fp16 --num_train_epochs 10 --warmup_epochs 1 --eval_epochs 2 --save_epochs 1 --loss_type distill_cross_entropy --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --gradient_accumulation_steps 16 --learning_rate 3e-4 --top_k 0 --data_dir /big-data0/preprocessed_r1-distill-llama-8b_sharegpt_2048_eagle1 --model_name deepseek-ai/DeepSeek-R1-Distill-Llama-8B --logging_steps 50 --max_length 2048`
        - `TOKENIZERS_PARALLELISM=false accelerate launch --num_processes 2 train.py --output_dir paper/r1-distill-llama-8b-chained-3 --overwrite_output_dir --fp16 --num_train_epochs 3 --warmup_epochs 0 --eval_epochs 1 --save_epochs 1 --loss_type chained --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --gradient_accumulation_steps 16 --learning_rate 1e-4 --top_k 0 --data_dir /big-data0/preprocessed_r1-distill-llama-8b_sharegpt_2048_eagle1 --model_name deepseek-ai/DeepSeek-R1-Distill-Llama-8B --logging_steps 50 --max_length 2048 --pretrained_model paper/r1-distill-llama-8b-ce-3/final`
        - `TOKENIZERS_PARALLELISM=false accelerate launch --num_processes 2 train.py --output_dir paper/r1-distill-llama-8b-chained-tree-3 --overwrite_output_dir --fp16 --num_train_epochs 3 --warmup_epochs 0 --eval_epochs 1 --save_epochs 0.1 --loss_type tree --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --gradient_accumulation_steps 16 --learning_rate 1e-4 --top_k 0 --data_dir /big-data0/preprocessed_r1-distill-llama-8b_sharegpt_2048_eagle1 --model_name deepseek-ai/DeepSeek-R1-Distill-Llama-8B --logging_steps 50 --max_length 2048 --pretrained_model paper/r1-distill-llama-8b-chained-3/final`
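The training commands above chain three stages per model: `distill_cross_entropy`, then `chained`, then `tree`, with each later stage loading the previous stage's `final` checkpoint through `--pretrained_model`. The sketch below reproduces only that naming and ordering; `stage_plan` is a hypothetical helper, not a script in this repository.

```python
from typing import Dict, List, Optional

# (output directory suffix, --loss_type), in the order the commands above chain them.
STAGES = [
    ("ce-3", "distill_cross_entropy"),
    ("chained-3", "chained"),
    ("chained-tree-3", "tree"),
]


def stage_plan(model_type: str, root: str = "paper") -> List[Dict[str, Optional[str]]]:
    """List each stage's output directory, loss type, and the checkpoint it starts from."""
    plan = []
    previous: Optional[str] = None
    for suffix, loss_type in STAGES:
        output_dir = f"{root}/{model_type}-{suffix}"
        plan.append({
            "output_dir": output_dir,
            "loss_type": loss_type,
            # The first stage has no --pretrained_model; later stages load <previous>/final.
            "pretrained_model": f"{previous}/final" if previous else None,
        })
        previous = output_dir
    return plan


if __name__ == "__main__":
    for stage in stage_plan("llama2-7b"):
        print(stage)
```

Each stage must finish and write its `final` directory before the next one starts, so the three commands for a model run sequentially.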
