### Introduction
This repository contains the basic implementation of CTC-drafter, including training and evaluation code. The code is adapted from [Medusa](https://github.com/FasterDecoding/Medusa).
### Installation
First, create a Python >= 3.9 environment, then install the required packages with pip:
```bash
cd CTC-drafter
pip install medusa-llm
```
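You can check the Python >= 3.9 requirement before installing. Below is a minimal sketch. `version_ge` is a hypothetical helper and is not part of this repository. It compares dotted version strings with `sort -V`, which GNU coreutils and recent BSD/macOS `sort` provide.

```bash
# Hypothetical helper: succeeds if version $1 >= version $2.
# `sort -V` orders the two strings by version; the smaller one comes first.
version_ge() {
  [ "$(printf '%s\n%s\n' "$2" "$1" | sort -V | head -n 1)" = "$2" ]
}

# Check the active python3 interpreter.
py_ver=$(python3 -c 'import platform; print(platform.python_version())' 2>/dev/null)
if version_ge "${py_ver:-0}" 3.9; then
  echo "Python $py_ver: OK"
else
  echo "Python >= 3.9 required, found: ${py_ver:-none}" >&2
fi
```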
We use the open-source Vicuna models. The weights can be downloaded from Hugging Face: [Vicuna-7b](https://huggingface.co/lmsys/vicuna-7b-v1.3), [Vicuna-13b](https://huggingface.co/lmsys/vicuna-13b-v1.3), [Vicuna-33b](https://huggingface.co/lmsys/vicuna-33b-v1.3).
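One way to fetch these weights is the `huggingface-cli download` command from the `huggingface_hub` package. This is an assumption, and any download method works. The sketch below maps a model size to the repository IDs linked above. The names `run`, `vicuna_repo`, `download_vicuna` and `DRY_RUN` are hypothetical. With `DRY_RUN=1`, the script only prints the command.

```bash
# Print the command instead of running it when DRY_RUN=1.
run() {
  if [ "${DRY_RUN:-0}" = 1 ]; then echo "$*"; else "$@"; fi
}

# Map a model size to the Vicuna v1.3 repository id.
vicuna_repo() {
  case "$1" in
    7b|13b|33b) echo "lmsys/vicuna-$1-v1.3" ;;
    *) echo "unsupported size: $1" >&2; return 1 ;;
  esac
}

# Download into ./vicuna-<size>-v1.3 (needs `pip install huggingface_hub`).
download_vicuna() {
  repo=$(vicuna_repo "$1") || return 1
  run huggingface-cli download "$repo" --local-dir "./vicuna-$1-v1.3"
}
```

For example, `download_vicuna 7b` creates `./vicuna-7b-v1.3`. You can then pass that directory wherever the commands below expect the base model path.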
### Training
For training, install the additional training dependencies:
```bash
pip install -e ".[train]"
```
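You must run `pip install -e` from a directory that contains the package metadata (`pyproject.toml` or `setup.py`). The training command below also uses the `accelerate` CLI. The following is a hypothetical pre-flight check, sketched here as an illustration. `check_train_setup` is not part of the repository, and it only reports problems without installing anything. Run it from the repository root after the install.

```bash
# Hypothetical pre-flight check for the training install.
check_train_setup() {
  status=0
  # pip install -e needs project metadata in the current directory.
  if [ ! -f pyproject.toml ] && [ ! -f setup.py ]; then
    echo "no pyproject.toml or setup.py here: run pip install -e from the repo root" >&2
    status=1
  fi
  # The training step is launched through the accelerate CLI.
  if ! command -v accelerate >/dev/null 2>&1; then
    echo "accelerate CLI not found on PATH (needed for 'accelerate launch')" >&2
    status=1
  fi
  return "$status"
}
```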
We use a public version of the ShareGPT dataset, which is a subset of the Vicuna training data. For other base models, use their corresponding training datasets.
```bash
git clone https://huggingface.co/datasets/Aeala/ShareGPT_Vicuna_unfiltered
```
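Hugging Face dataset repositories store large files with Git LFS. If you don't run `git lfs install` before cloning, the data files stay as small text pointer files. The sketch below checks for that case. It assumes the data files are `.json` files at the top level of the clone, and `is_lfs_pointer` is a hypothetical helper.

```bash
# Succeeds if the file is a Git LFS pointer rather than the real content.
is_lfs_pointer() {
  head -n 1 "$1" 2>/dev/null | grep -q '^version https://git-lfs.github.com/spec/v1'
}

for f in ShareGPT_Vicuna_unfiltered/*.json; do
  [ -e "$f" ] || continue   # the glob matched nothing
  if is_lfs_pointer "$f"; then
    echo "LFS pointer only, run 'git lfs pull' inside the clone: $f" >&2
  fi
done
```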
We adopt a knowledge distillation method. First, generate the distilled training data:
```bash
cd medusa/train
export CUDA_VISIBLE_DEVICES=0,1,2,3
python gen_distilled_data.py --gpu_index 0,1,2,3 --outdir /path/to/your/distilled/data
```
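To keep `CUDA_VISIBLE_DEVICES` and `--gpu_index` in sync, you can drive both from one variable. In shell assignments there must be no space after `=`. The minimal sketch below uses hypothetical helpers `run` and `distill`. It mirrors the command above by passing the same GPU list to both. How `gen_distilled_data.py` interprets `--gpu_index` is not documented here.

```bash
# Print the command instead of running it when DRY_RUN=1.
run() {
  if [ "${DRY_RUN:-0}" = 1 ]; then echo "$*"; else "$@"; fi
}

# Hypothetical wrapper: $1 = comma-separated GPU ids, $2 = output directory.
distill() {
  gpus=$1
  outdir=$2
  export CUDA_VISIBLE_DEVICES="$gpus"   # no spaces around '='
  run python gen_distilled_data.py --gpu_index "$gpus" --outdir "$outdir"
}

# Usage (from medusa/train):
#   distill 0,1,2,3 /path/to/your/distilled/data
```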
Run the following command to start training. Recommended configs can be found in `CTC-drafter/medusa/train/config`:
```bash
accelerate launch --mixed_precision=bf16 train.py \
--tmpdir /path/to/distilled/data \
--cpdir /path/where/you/save/the/weights \
--basepath /path/to/base/model \
--configpath /path/to/training/configs
```
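A long multi-GPU run can fail late because of a mistyped path. The sketch below wraps the launch command above and checks that the input paths exist first. `train_drafter` and `run` are hypothetical names. `--configpath` may point at a file or a directory, so the check only tests that the path exists. The checkpoint directory is not checked, on the assumption that training creates it.

```bash
# Print the command instead of running it when DRY_RUN=1.
run() {
  if [ "${DRY_RUN:-0}" = 1 ]; then echo "$*"; else "$@"; fi
}

# Hypothetical wrapper: $1 distilled data, $2 checkpoint dir, $3 base model, $4 config.
train_drafter() {
  tmpdir=$1 cpdir=$2 basepath=$3 configpath=$4
  # Fail early if any input path is missing.
  for p in "$tmpdir" "$basepath" "$configpath"; do
    if [ ! -e "$p" ]; then
      echo "missing path: $p" >&2
      return 1
    fi
  done
  run accelerate launch --mixed_precision=bf16 train.py \
    --tmpdir "$tmpdir" --cpdir "$cpdir" --basepath "$basepath" --configpath "$configpath"
}
```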
### Evaluation
After training, you can evaluate the speedup on MT-bench and GSM8K. The model answers are recorded in `llm_judge/data/mt_bench` or `llm_judge/data/GS_bench`, respectively:
```bash
cd llm_judge
# MT-bench; for GSM8K, run gen_answer_GSM8K.py with --bench-name GS_bench instead
python gen_answer_mt.py \
  --model-path /path/to/trained/weights \
  --base-model-path /path/to/base/model \
  --model-id evaluation \
  --tree-choices medusa \
  --bench-name mt_bench \
  --use-safetensor-weight True
```
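Both benchmarks can be run from one helper that picks the matching script and `--bench-name`. The sketch below follows the `mt`/`GSM8K` and `mt_bench`/`GS_bench` pairing used in this section. The `gen_answer_<bench>.py` file names are an assumption and may not match the repository exactly. `gen_answers`, `run`, `MODEL_PATH` and `BASE_MODEL_PATH` are hypothetical names.

```bash
# Print the command instead of running it when DRY_RUN=1.
run() {
  if [ "${DRY_RUN:-0}" = 1 ]; then echo "$*"; else "$@"; fi
}

# Hypothetical wrapper: $1 is the benchmark, mt or GSM8K.
gen_answers() {
  case "$1" in
    mt) bench_name=mt_bench ;;
    GSM8K) bench_name=GS_bench ;;
    *) echo "unknown benchmark: $1" >&2; return 1 ;;
  esac
  run python "gen_answer_$1.py" \
    --model-path "$MODEL_PATH" \
    --base-model-path "$BASE_MODEL_PATH" \
    --model-id evaluation \
    --tree-choices medusa \
    --bench-name "$bench_name" \
    --use-safetensor-weight True
}

# Usage (from llm_judge):
#   MODEL_PATH=/path/to/trained/weights BASE_MODEL_PATH=/path/to/base/model
#   for b in mt GSM8K; do gen_answers "$b"; done
```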
You can also run the baseline without speculative decoding for comparison:
```bash
cd llm_judge
# MT-bench; for GSM8K, run gen_baseline_answer_GSM8K.py with --bench-name GS_bench instead
python gen_baseline_answer_mt.py \
  --model-path /path/to/trained/weights \
  --base-model-path /path/to/base/model \
  --model-id evaluation \
  --tree-choices medusa \
  --bench-name mt_bench \
  --use-safetensor-weight True
```
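A common way to summarize the comparison is the ratio of decoding throughputs: speedup = (tokens/s with the drafter) / (tokens/s of the baseline). This section does not describe the answer-file format. The sketch below therefore assumes you have already extracted the two throughput numbers, and `speedup` is a hypothetical helper.

```bash
# speedup = speculative throughput / baseline throughput (both in tokens/s).
speedup() {
  awk -v base="$1" -v spec="$2" 'BEGIN {
    if (base <= 0) exit 1          # avoid division by zero
    printf "%.2f\n", spec / base
  }'
}
```

For example, with illustrative numbers only, `speedup 25 60` prints `2.40`.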


