# AutoJudge

## 🚀 Running the code

Our approach introduces an algorithm that automatically identifies important token mismatches between the draft and target models' generations. We extract hidden states for these tokens, train a lightweight classifier to detect them, and use it during inference.
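To make "token mismatch" concrete, here is a minimal sketch (not the code in `src/find_important_tokens.py`) that finds the positions where the draft model's token differs from the target model's token for the same prefix. The helper name `find_mismatches` and its list-of-token-ids inputs are illustrative assumptions. Deciding which of these mismatches are *important* is the job of the mining step described below.

```python
from typing import List


def find_mismatches(draft_ids: List[int], target_ids: List[int]) -> List[int]:
    """Return the positions where the draft token differs from the target token.

    Hypothetical helper: both lists hold token ids produced for the same
    prefixes, e.g. the draft's proposals and the target's greedy choices.
    """
    if len(draft_ids) != len(target_ids):
        raise ValueError("draft and target sequences must have equal length")
    return [i for i, (d, t) in enumerate(zip(draft_ids, target_ids)) if d != t]


# Positions 1 and 3 disagree; only some such mismatches are labelled important.
print(find_mismatches([10, 42, 7, 99], [10, 43, 7, 98]))  # [1, 3]
```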

To reproduce our results, follow these steps:

1. Run the dataset mining script
2. Calculate hidden states
3. Train the classifier
4. Run evaluations

**📦 Mined datasets will be published on Hugging Face shortly; once they are released, you can skip the first two steps!**

### 🛠️ Getting started

Install packages from `requirements.txt`:
```bash
pip install -r requirements.txt
```

### ⛏️ Dataset mining 💎

Below is a short snippet showing how to run dataset mining for GSM8K and LiveCodeBench. For detailed instructions, including multi-GPU runs, please refer to the [`find_important_tokens_gsm8k.sh`](scripts/find_important_tokens_gsm8k.sh) and [`find_important_tokens_lcb.sh`](scripts/find_important_tokens_lcb.sh) scripts.

#### 📐 GSM8K 🔢


```bash
export MODEL0="meta-llama/Llama-3.2-1B-Instruct"
export MODEL1="meta-llama/Llama-3.1-8B-Instruct"
export TORCH_DTYPE=auto
export GSM8K_TRAIN=data/train_small.jsonl # replace with data/gsm8k_train.jsonl for a full run
export RANDOM_SEED=42
export MAX_NEW_TOKENS=2048
export OUTPUT_FOLDER=output
export OUTPUT_FILE=important_tokens
export DUMP_FREQ=64

mkdir -p $OUTPUT_FOLDER

# one-gpu run

CUDA_VISIBLE_DEVICES=0 python3 src/find_important_tokens.py \
    --draft_model $MODEL0 \
    --target_model $MODEL1 \
    --torch_dtype $TORCH_DTYPE \
    --gsm8k_train_path $GSM8K_TRAIN \
    --random_seed $RANDOM_SEED \
    --max_new_tokens $MAX_NEW_TOKENS \
    --output_folder $OUTPUT_FOLDER \
    --output_file $OUTPUT_FILE \
    --dump_freq $DUMP_FREQ \
    --local_world_size 1

# clean up the done* files the run leaves in the output folder
rm -f $OUTPUT_FOLDER/done*
```
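`GSM8K_TRAIN` points to a JSONL file. Assuming it keeps the original GSM8K layout (one JSON object per line with `question` and `answer` fields, where the final answer follows a `####` marker), the reference answers could be read with a sketch like the one below. `load_gsm8k` and `final_answer` are hypothetical helper names, not functions from this repository.

```python
import json
from typing import Dict, List


def final_answer(answer: str) -> str:
    """Return the number after the '####' marker of a GSM8K solution."""
    return answer.split("####")[-1].strip().replace(",", "")


def load_gsm8k(path: str) -> List[Dict[str, str]]:
    """Read a GSM8K-style JSONL file into question / final-answer pairs."""
    records = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines
            item = json.loads(line)
            records.append({"question": item["question"],
                            "answer": final_answer(item["answer"])})
    return records


print(final_answer("48 + 24 = 72\n#### 72"))  # 72
```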

#### 💻 LiveCodeBench 📄

```bash
export MODEL0="meta-llama/Llama-3.2-1B-Instruct"
export MODEL1="meta-llama/Llama-3.1-8B-Instruct"
export TORCH_DTYPE=auto
export RANDOM_SEED=42
export MAX_NEW_TOKENS=2048
export OUTPUT_FOLDER=output
export OUTPUT_FILE=important_tokens_lcb
export DUMP_FREQ=64
export NUM_PROCESS_EVALUATE=64
export N_TASKS=2 # 2 tasks for a short demo; set 880 for the full LCB release_v5 dataset
export TOTAL_GPUS=1

mkdir -p $OUTPUT_FOLDER

# one-gpu run

CUDA_VISIBLE_DEVICES=0 python3 src/find_important_tokens_lcb.py \
    --draft_model $MODEL0 \
    --target_model $MODEL1 \
    --torch_dtype $TORCH_DTYPE \
    --random_seed $RANDOM_SEED \
    --max_new_tokens $MAX_NEW_TOKENS \
    --output_folder $OUTPUT_FOLDER \
    --output_file $OUTPUT_FILE \
    --dump_freq $DUMP_FREQ \
    --n_tasks $N_TASKS \
    --num_process_evaluate $NUM_PROCESS_EVALUATE \
    --total_gpus $TOTAL_GPUS \
    --local_world_size $TOTAL_GPUS
```    
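`N_TASKS` and `TOTAL_GPUS` control how many LiveCodeBench tasks are mined and on how many GPUs. The hypothetical sketch below shows one common way to split task indices into contiguous, nearly equal shards, one per GPU. It illustrates the idea only and is not necessarily how `src/find_important_tokens_lcb.py` partitions its work.

```python
from typing import List


def shard_tasks(n_tasks: int, total_gpus: int, rank: int) -> List[int]:
    """Return the task indices handled by GPU `rank` (hypothetical helper).

    Tasks are split into contiguous blocks whose sizes differ by at most one.
    """
    if not 0 <= rank < total_gpus:
        raise ValueError("rank must be in [0, total_gpus)")
    base, extra = divmod(n_tasks, total_gpus)
    # The first `extra` ranks receive one additional task each.
    start = rank * base + min(rank, extra)
    size = base + (1 if rank < extra else 0)
    return list(range(start, start + size))


# 880 tasks (full release_v5) over 8 GPUs -> 110 tasks per GPU.
print([len(shard_tasks(880, 8, r)) for r in range(8)])
```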

### 🧮 Calculating hidden states ⚙️

For the full script, including multi-GPU runs, please refer to [`calc_hiddens.sh`](scripts/calc_hiddens.sh).

```bash
export MODEL0="meta-llama/Llama-3.2-1B-Instruct"
export MODEL1="meta-llama/Llama-3.1-8B-Instruct"
export TORCH_DTYPE=auto
export BATCH_SIZE=8
export DATA_FILE=output/important_tokens.pt
export OUTPUT_PATH=output/important_tokens_with_hiddens
export SAVE_FREQ=128
export N_PROCESSES=1

# single gpu run
CUDA_VISIBLE_DEVICES=0 python src/calc_hiddens.py \
    --draft_model $MODEL0 \
    --target_model $MODEL1 \
    --torch_dtype $TORCH_DTYPE \
    --batch_size $BATCH_SIZE \
    --data_file $DATA_FILE \
    --output_path $OUTPUT_PATH \
    --save_freq $SAVE_FREQ \
    --n_processes $N_PROCESSES \
    --process_id 0 

```
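The conversion script further below keeps only the last `target_hidden_size` entries of the classifier's parameters. One layout consistent with that is a feature vector that concatenates draft and target hidden states with the target part last. Under that assumption (the layout, the layer choice and the helper name `mismatch_features` are all illustrative), a feature vector for one mismatch position could be assembled as follows. The sizes 2048 and 4096 are the hidden sizes of Llama-3.2-1B and Llama-3.1-8B.

```python
import numpy as np


def mismatch_features(draft_hidden: np.ndarray, target_hidden: np.ndarray, pos: int) -> np.ndarray:
    """Build one feature vector for the token at position `pos`.

    draft_hidden:  (seq_len, draft_dim) hidden states of the draft model
    target_hidden: (seq_len, target_dim) hidden states of the target model
    Assumed layout: [draft features | target features], target part last.
    """
    return np.concatenate([draft_hidden[pos], target_hidden[pos]])


seq_len = 5
draft_h = np.random.randn(seq_len, 2048).astype(np.float32)   # Llama-3.2-1B hidden size
target_h = np.random.randn(seq_len, 4096).astype(np.float32)  # Llama-3.1-8B hidden size
x = mismatch_features(draft_h, target_h, pos=3)
print(x.shape)  # (6144,)
```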

### 🧠 Training a classifier 🎯
To train a classifier, use the scripts [`src/train_head_gsm8k.py`](src/train_head_gsm8k.py) and [`src/train_head_lcb.py`](src/train_head_lcb.py).

You can run them with the following commands:
```bash
export MODEL0="meta-llama/Llama-3.2-1B-Instruct"
export MODEL1="meta-llama/Llama-3.1-8B-Instruct"

# GSM8K
export DATA_PATH=output/important_tokens_with_hiddens.pt

python src/train_head_gsm8k.py --random_seed 52 --train_size 0.9 \
    --data_path $DATA_PATH --checkpoint_path output/trained_head.pkl \
    --target_model $MODEL1 --draft_model $MODEL0

# LiveCodeBench: point DATA_PATH to the hidden states computed for the LCB data
# and use a separate checkpoint path so the GSM8K head is not overwritten
python src/train_head_lcb.py --random_seed 52 --train_size 0.9 \
    --data_path $DATA_PATH --checkpoint_path output/trained_head_lcb.pkl \
    --target_model $MODEL1 --draft_model $MODEL0
```
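The conversion script below reads `checkpoint['model']` (with `coef_` and `intercept_`) and `checkpoint['scaler']` (with `mean_` and `scale_`), which matches scikit-learn's `LogisticRegression` and `StandardScaler`. Assuming that, the core of a training script might look like this minimal sketch on synthetic data. The real scripts also handle data loading, folds and threshold selection, and their hyperparameters may differ.

```python
import pickle

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(52)
# Synthetic stand-ins: X plays the role of hidden-state features,
# y marks whether a mismatch is "important" (1) or not (0).
X = rng.normal(size=(1000, 64))
y = (X[:, 0] + 0.5 * rng.normal(size=1000) > 1.0).astype(int)

# Mirrors --train_size 0.9 --random_seed 52 from the commands above
X_train, X_val, y_train, y_val = train_test_split(
    X, y, train_size=0.9, random_state=52
)

scaler = StandardScaler().fit(X_train)
head = LogisticRegression(max_iter=1000).fit(scaler.transform(X_train), y_train)

# Probability that each validation mismatch is important
val_probs = head.predict_proba(scaler.transform(X_val))[:, 1]
print("validation accuracy:", head.score(scaler.transform(X_val), y_val))

# Same {'model', 'scaler'} layout that the conversion script reads
checkpoint = {"model": head, "scaler": scaler}
with open("trained_head_demo.pkl", "wb") as f:
    pickle.dump(checkpoint, f)
```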

To convert a trained head to our vLLM format, run the following Python script:

```python
import pickle

import pandas as pd

# Checkpoint written by the training script (--checkpoint_path)
checkpoint = pd.read_pickle('output/trained_head.pkl')

head = checkpoint['model']     # fitted classifier (coef_, intercept_)
scaler = checkpoint['scaler']  # fitted feature scaler (mean_, scale_)

target_hidden_size = 4096  # 4096 for Llama-3.1-8B-Instruct and 8192 for Llama-3.1-70B-Instruct

# Keep only the parameters for the last `target_hidden_size` features
head_dict = dict(
    mean=scaler.mean_[-target_hidden_size:],
    scale=scaler.scale_[-target_hidden_size:],
    weights=head.coef_[0][-target_hidden_size:],
    bias=head.intercept_,  # intercept array of shape (1,) for a binary classifier
    thr=0.25,  # decision threshold stored with the head
)

vllm_checkpoint_path = 'vllm_compatible_head.pkl'

with open(vllm_checkpoint_path, 'wb') as f:
    pickle.dump(head_dict, f)
```
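For intuition, here is a hedged sketch of how a head stored in this format could be evaluated on a single target hidden state: standardize with `mean`/`scale`, compute a logistic score with `weights`/`bias`, and compare it with `thr`. This mirrors `StandardScaler.transform` followed by the positive-class probability of a binary `LogisticRegression`. Treating a score at or above `thr` as an important mismatch (so the draft token is rejected) is our reading of the pipeline, not a description of the patched vLLM code; `is_important` is a hypothetical name.

```python
import numpy as np


def is_important(hidden: np.ndarray, head: dict) -> bool:
    """Return True if `head` flags the mismatch at this hidden state as important.

    `head` uses the keys written by the conversion script:
    mean, scale, weights, bias, thr.
    """
    z = (hidden - head["mean"]) / head["scale"]               # standardize features
    logit = float(z @ head["weights"] + np.ravel(head["bias"])[0])
    prob = 1.0 / (1.0 + np.exp(-logit))                       # sigmoid, as in predict_proba
    return bool(prob >= head["thr"])


dim = 4096
rng = np.random.default_rng(0)
toy_head = dict(mean=np.zeros(dim), scale=np.ones(dim),
                weights=rng.normal(scale=0.01, size=dim), bias=np.array([0.0]), thr=0.25)
print(is_important(rng.normal(size=dim), toy_head))
```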

### 📊 Evaluations 📝

#### Accuracy vs Average Accepted Tokens (Sections 4.1 and 4.2)

Here we provide an evaluation example for GSM8K; similar scripts were used to obtain the main results on LiveCodeBench (see [`eval/run_inference_task_lcb_folds.py`](eval/run_inference_task_lcb_folds.py) and [`eval/run_inference_lcb_topk_baseline_task.py`](eval/run_inference_lcb_topk_baseline_task.py)). For LiveCodeBench, in addition to each threshold (AutoJudge) and each K (baseline), we also vary `FOLD_ID`, since we use an out-of-fold evaluation.
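In the usual out-of-fold setup, every test task is scored by a head that never saw it during training: tasks are split into folds, and the head for fold `FOLD_ID` is trained on the remaining folds. Below is a minimal sketch of such a split, assuming five folds (the vLLM LCB loop later in this README iterates over folds 0 to 4) and a fixed shuffle seed. The helper name and seed are illustrative, not taken from the scripts.

```python
import random
from typing import List, Tuple


def out_of_fold_split(n_tasks: int, n_folds: int, fold_id: int,
                      seed: int = 0) -> Tuple[List[int], List[int]]:
    """Return (train_ids, test_ids) for one fold (hypothetical helper)."""
    ids = list(range(n_tasks))
    random.Random(seed).shuffle(ids)
    test_ids = sorted(ids[fold_id::n_folds])  # every n_folds-th shuffled task
    test_set = set(test_ids)
    train_ids = [i for i in range(n_tasks) if i not in test_set]
    return train_ids, test_ids


# Each task lands in exactly one test fold across FOLD_ID = 0..4.
folds = [out_of_fold_split(880, 5, k)[1] for k in range(5)]
print(sum(len(f) for f in folds), len(set().union(*folds)))  # 880 880
```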

##### AutoJudge Eval
```bash
export START_IDX=0
export END_IDX=10 # set 1319 for full eval
export THR_ID=0 # vary from 0 to 25; the inference thresholds are selected automatically by the training scripts
export MODEL0="meta-llama/Llama-3.2-1B-Instruct"
export MODEL1="meta-llama/Llama-3.1-8B-Instruct"
export NUM_SHOTS=0 # set 8 for 8 shot setup
export MAX_NEW_TOKENS=1024

# Running eval on 2 gpus
CUDA_VISIBLE_DEVICES=0,1 python3 eval/gpu_parallel.py --gpus_per_script 1 --start $START_IDX --end $END_IDX --use_queue --script eval/run_inference_task.py --extra_args "--save_folder output/eval_$THR_ID --gsm8k_test_path data/gsm8k_test.json --torch_dtype auto --window_size 64 --head_path output/trained_head.pkl --setup DD-DT --max_new_tokens $MAX_NEW_TOKENS --num_shots $NUM_SHOTS --head_threshold_idx $THR_ID --draft_model $MODEL0 --target_model $MODEL1"
```
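`eval/gpu_parallel.py` is launched with `--start`/`--end` and `--gpus_per_script 1`. Assuming it runs `--script` once per example index and that `--use_queue` means idle GPUs pull the next index from a shared queue (instead of each GPU receiving a fixed slice), the scheduling idea can be sketched with the standard library. `run_with_queue` and the `run_one` callback are hypothetical stand-ins for launching the inference script on one example.

```python
import queue
import threading
from typing import Callable, Dict, List


def run_with_queue(start: int, end: int, gpu_ids: List[int],
                   run_one: Callable[[int, int], None]) -> Dict[int, List[int]]:
    """Process indices [start, end) with one worker thread per GPU.

    Each worker keeps taking the next index from a shared queue, so a GPU
    that finishes early immediately picks up more work.
    Returns which indices each GPU processed.
    """
    tasks: "queue.Queue[int]" = queue.Queue()
    for idx in range(start, end):
        tasks.put(idx)
    done: Dict[int, List[int]] = {g: [] for g in gpu_ids}

    def worker(gpu: int) -> None:
        while True:
            try:
                idx = tasks.get_nowait()
            except queue.Empty:
                return  # no work left
            run_one(gpu, idx)  # e.g. run the eval script with CUDA_VISIBLE_DEVICES=gpu
            done[gpu].append(idx)

    threads = [threading.Thread(target=worker, args=(g,)) for g in gpu_ids]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return done


assignment = run_with_queue(0, 10, [0, 1], lambda gpu, idx: None)
print(sorted(i for idxs in assignment.values() for i in idxs))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```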

##### Top-K Baseline Eval
```bash
export START_IDX=0
export END_IDX=10 # set 1319 for full eval
export K=2048 # to be varied, we considered the following values [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 128256]
export MODEL0="meta-llama/Llama-3.2-1B-Instruct"
export MODEL1="meta-llama/Llama-3.1-8B-Instruct"
export NUM_SHOTS=0 # set 8 for 8 shot setup
export MAX_NEW_TOKENS=1024

# Running eval on 2 gpus
CUDA_VISIBLE_DEVICES=0,1 python3 eval/gpu_parallel.py --gpus_per_script 1 --start $START_IDX --end $END_IDX --use_queue --script eval/run_inference_topk_baseline_task.py --extra_args "--save_folder output/eval_baseline_$K --gsm8k_test_path data/gsm8k_test.json --torch_dtype auto --window_size 64 --head_path output/trained_head.pkl --setup DD-DT --max_new_tokens $MAX_NEW_TOKENS --num_shots $NUM_SHOTS --K $K --draft_model $MODEL0 --target_model $MODEL1"
```
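The `K` values range from 1 up to 128256, the size of the Llama 3 vocabulary, so the largest setting accepts every draft token. This suggests a baseline that accepts a mismatching draft token whenever it is among the target model's top-K tokens. A minimal NumPy sketch of that rule, under this assumption (the helper name is hypothetical; the actual logic lives in `eval/run_inference_topk_baseline_task.py`):

```python
import numpy as np


def topk_accept(target_logits: np.ndarray, draft_token: int, k: int) -> bool:
    """Accept `draft_token` if it is among the k highest-scoring target tokens."""
    k = min(k, target_logits.shape[-1])
    # Indices of the k largest logits (order within the top-k does not matter).
    top_k = np.argpartition(target_logits, -k)[-k:]
    return bool(np.isin(draft_token, top_k))


logits = np.array([2.0, 0.5, 3.0, -1.0, 1.0])
print(topk_accept(logits, draft_token=0, k=1))  # False: only token 2 is in the top-1
print(topk_accept(logits, draft_token=0, k=2))  # True: tokens 2 and 0 are in the top-2
```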

To build a final report from the evaluation outputs, you can use the following snippet:
```python
import os

import numpy as np
import pandas as pd


def make_pareto_curve_df(data, group_by_col='thr'):
    """Aggregate per-example eval records into one row per threshold (or K)."""
    df = pd.DataFrame(data)
    grouped = df.groupby(group_by_col)

    # Average number of accepted tokens over all steps of all examples
    mean_accept = grouped['raw_accepts'].apply(
        lambda x: np.concatenate(x.tolist()).mean()
    ).rename('mean_accept')

    # Fraction of correctly solved examples
    gsm_acc = grouped['tp'].mean().rename('gsm8k_acc')

    return pd.concat([mean_accept, gsm_acc], axis=1).reset_index().sort_values(by=group_by_col)


def load_records(dirs):
    """Load every pickled record from the given eval folders, skipping missing ones."""
    records = []
    for d in dirs:
        if not os.path.isdir(d):
            continue
        records.extend(pd.read_pickle(os.path.join(d, f)) for f in os.listdir(d))
    return records


# AutoJudge: folders output/eval_0 ... output/eval_25 (one per THR_ID)
AJ_DIRS = [f'output/eval_{i}' for i in range(26)]
autojudge_df = make_pareto_curve_df(load_records(AJ_DIRS))
print(autojudge_df)

# Top-K baseline: folders are named after the K value (output/eval_baseline_$K)
K_VALUES = [2 ** i for i in range(17)] + [128256]  # 1, 2, 4, ..., 65536, 128256
BASELINE_DIRS = [f'output/eval_baseline_{k}' for k in K_VALUES]
baseline_df = make_pareto_curve_df(load_records(BASELINE_DIRS), group_by_col='k')
print(baseline_df)
```
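`make_pareto_curve_df` returns one (mean accepted tokens, accuracy) point per threshold or `K`. When comparing methods, it can help to keep only the non-dominated points, i.e. those for which no other setting is at least as good on both axes and strictly better on one. A small pandas sketch, assuming the `mean_accept` and `gsm8k_acc` column names produced by the report snippet above (`pareto_front` is a hypothetical helper):

```python
import pandas as pd


def pareto_front(df: pd.DataFrame, x: str = "mean_accept", y: str = "gsm8k_acc") -> pd.DataFrame:
    """Keep rows not dominated by any other row (higher is better on both axes)."""
    # Sort by x descending (ties broken by y descending); a row is on the
    # front if its y is strictly higher than every y seen so far.
    ordered = df.sort_values([x, y], ascending=False)
    keep, best_y = [], float("-inf")
    for idx, val in ordered[y].items():
        if val > best_y:
            keep.append(idx)
            best_y = val
    return df.loc[keep].sort_values(x)


points = pd.DataFrame({"mean_accept": [4.0, 5.0, 6.0, 5.5],
                       "gsm8k_acc": [0.80, 0.78, 0.70, 0.69]})
print(pareto_front(points))  # row 3 is dropped: row 2 has more accepted tokens and higher accuracy
```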

#### vLLM (Section 4.3)
Clone the `vllm` repository and check out commit `a83a0f92b56b71855dc38e8e3d9809619e58bcd1`.
Copy the contents of the `vllm` folder from our repository into the cloned `vllm` folder
to merge our patches with the original code.
Install vLLM with `VLLM_USE_PRECOMPILED=1 pip install -e path/to/vllm/folder`.


##### Evaluating GSM8K
To run evaluations on the GSM8K dataset, use the following commands for the 8B/1B and 70B/8B target/draft pairs, respectively:

```bash
python vllm_gsm8k.py -o $RESULTS_FILE --target_model meta-llama/Llama-3.1-8B-Instruct \
    --draft_model meta-llama/Llama-3.2-1B-Instruct --judge_path vllm_heads/head_${SHOTS}shot_8b.pkl --judge_threshold $THRESHOLD --shots $SHOTS
```
```bash
python vllm_gsm8k.py -o $RESULTS_FILE --target_model meta-llama/Llama-3.1-70B-Instruct \
    --draft_model meta-llama/Llama-3.1-8B-Instruct --judge_path vllm_heads/head_${SHOTS}shot_70b.pkl --judge_threshold $THRESHOLD --shots $SHOTS
```
`SHOTS` can be either 0 or 8.

For example, to reproduce the 0-shot results for the 70B/8B pair, run:
```bash
for threshold in 0.03719609313336198 0.07084856680433153 0.09208237305325259 0.13549077699786996 0.2209569576527778; do
    python vllm_gsm8k.py -o results \
        --target_model meta-llama/Llama-3.1-70B-Instruct \
        --draft_model meta-llama/Llama-3.1-8B-Instruct \
        --judge_path vllm_heads/head_0shot_70b.pkl --judge_threshold $threshold --shots 0
done
```

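##### Evaluating LCB
To run evaluations on LiveCodeBench, use the following commands for the 8B/1B and 70B/8B target/draft pairs, respectively: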

```bash
LCB_MAX_MEMORY_BYTES=34359738368 python vllm_lcb.py -o $RESULTS_FILE --target_model meta-llama/Llama-3.1-8B-Instruct\
     --draft_model meta-llama/Llama-3.2-1B-Instruct --judge_path head_lcb_8b.pkl --judge_threshold $THRESHOLD
```

```bash
LCB_MAX_MEMORY_BYTES=34359738368 python vllm_lcb.py -o $RESULTS_FILE --target_model meta-llama/Llama-3.1-70B-Instruct\
     --draft_model meta-llama/Llama-3.1-8B-Instruct --judge_path head_lcb_70b.pkl --judge_threshold $THRESHOLD
```
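`LCB_MAX_MEMORY_BYTES=34359738368` is 32 GiB (2^35 bytes). We assume it caps the memory available to generated programs while they run against the LiveCodeBench tests. On Linux, one standard way to impose such a cap on a child process is `resource.setrlimit(resource.RLIMIT_AS, ...)` called in the child before it starts. The sketch below shows that technique with a hypothetical `run_program` helper and is not necessarily how `vllm_lcb.py` applies the variable.

```python
import os
import resource
import subprocess
import sys

# Read the limit from the environment (default: 32 GiB).
LIMIT = int(os.environ.get("LCB_MAX_MEMORY_BYTES", 34359738368))


def limit_memory() -> None:
    """Runs in the child just before exec: cap its virtual address space."""
    _, hard = resource.getrlimit(resource.RLIMIT_AS)
    # An unprivileged process cannot raise its hard limit, so stay below it.
    soft = LIMIT if hard == resource.RLIM_INFINITY else min(LIMIT, hard)
    resource.setrlimit(resource.RLIMIT_AS, (soft, hard))


def run_program(code: str, stdin: str = "", timeout: float = 10.0) -> str:
    """Run Python `code` in a subprocess under the memory cap and return its stdout."""
    result = subprocess.run(
        [sys.executable, "-c", code],
        input=stdin, capture_output=True, text=True,
        timeout=timeout, preexec_fn=limit_memory,  # POSIX only
    )
    return result.stdout


print(run_program("print(sum(map(int, input().split())))", stdin="2 3\n").strip())  # 5
```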
If you encounter an error during multi-GPU 70B/8B evaluation, you can work around it by running one fold at a time:
```bash
LCB_MAX_MEMORY_BYTES=34359738368 python vllm_lcb.py -o $RESULTS_FILE --target_model meta-llama/Llama-3.1-70B-Instruct\
     --draft_model meta-llama/Llama-3.1-8B-Instruct --judge_path head_lcb_70b.pkl --judge_threshold $THRESHOLD --fold $fold
```

For example, to reproduce the 70B/8B evaluation on the LCB dataset, run:
```bash
for threshold in 0.0 0.05 0.075 0.1 0.125; do
    for fold in 0 1 2 3 4; do
        LCB_MAX_MEMORY_BYTES=34359738368 python vllm_lcb.py -o results --target_model meta-llama/Llama-3.1-70B-Instruct \
            --draft_model meta-llama/Llama-3.1-8B-Instruct \
            --judge_path head_lcb_70b.pkl --judge_threshold $threshold --fold $fold
    done
done
```
