## Contents

- [Installation](#installation)
- [Dataset Generation](#dataset-generation)
  - [Truncated Dataset](#truncated-dataset)
  - [Distillation Dataset](#distillation-dataset)
- [Training Special Tokens](#training-special-tokens)
- [Inference](#inference)
  - [MT Bench](#mt-bench)
  - [Alpaca Eval](#alpaca-eval)
  - [HumanEval](#humaneval)
  - [GSM8K](#gsm8k)

## Installation

```bash
git clone <repo-url>
cd prompt-decoding
pip install -e .
```

## Dataset Generation

### Truncated Dataset

Given a dataset, each sample is truncated at a random position to reduce contextual bias when training the special tokens. A distillation dataset is then generated from the truncated dataset.

The truncated dataset must be generated first. For example, the following command generates a dataset for 3 special tokens from the ShareGPT dataset:

```bash
python generate_dataset.py --dataset_type finetune --num_special_tokens 3 --data_path ShareGPT_V4.3_unfiltered_cleaned_split.json --model_max_length 2048
```
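
The idea behind random truncation can be sketched as follows. This is a minimal illustration and not the logic of `generate_dataset.py`: the function name `random_truncate`, the `min_keep` parameter, the reserved tail, and the use of plain token-id lists are all assumptions made for the example.

```python
import random


def random_truncate(token_ids, num_special_tokens, min_keep=1, rng=random):
    """Keep a random-length prefix of a token sequence (illustrative only).

    Cutting at a random position means the special tokens would be trained
    after many different contexts instead of always at the end of a sample.
    """
    # Assumption: leave at least `num_special_tokens` tokens after the cut,
    # so there are ground-truth tokens for the special tokens to predict.
    upper = len(token_ids) - num_special_tokens
    if upper < min_keep:
        return list(token_ids)  # too short to truncate meaningfully
    cut = rng.randint(min_keep, upper)  # randint is inclusive on both ends
    return list(token_ids[:cut])


rng = random.Random(0)  # fixed seed so the example is reproducible
sample = list(range(20))  # stand-in for a tokenized conversation
print(random_truncate(sample, num_special_tokens=3, rng=rng))
```

Passing a seeded `random.Random` instance keeps the truncation points reproducible across runs, which can help when comparing datasets.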

### Distillation Dataset

Next, generate the distillation dataset from the truncated dataset. `--data_path` is the path to the previously generated truncated dataset, and `--model_name_or_path` is the model whose output distribution we want to obtain.

```bash
python generate_dataset.py --dataset_type distillation --num_special_tokens 3 --data_path ShareGPT_training_dataset_3_finetune_2048.pt --model_max_length 2048 --model_name_or_path lmsys/vicuna-7b-v1.3
```
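
To make "the distribution of the model" concrete, here is a minimal sketch of how logits become a target distribution and how a student prediction could be compared against it. The KL divergence shown is a common distillation loss, used here as an assumption; the loss actually used by this repository is not stated in this README, and the logits below are made-up numbers.

```python
import math


def softmax(logits):
    """Turn raw logits into a probability distribution (numerically stable)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(teacher_probs, student_probs, eps=1e-12):
    """KL(teacher || student) for one position; zero when the two match."""
    return sum(
        p * (math.log(p + eps) - math.log(q + eps))
        for p, q in zip(teacher_probs, student_probs)
        if p > 0.0
    )


# Made-up logits over a 3-token vocabulary, for illustration only.
teacher = softmax([2.0, 1.0, 0.1])  # target distribution from the base model
student = softmax([1.5, 1.2, 0.3])  # prediction to be pulled toward the target
print(round(kl_divergence(teacher, student), 4))
```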

## Training Special Tokens

The following example script trains special tokens for Vicuna-7B using a distillation dataset named `ShareGPT_training_dataset_2_distillation.pt`:

```bash
accelerate launch --num_processes 4 prompt/train/train.py --model_name_or_path lmsys/vicuna-7b-v1.3 \
    --dataset_path "./ShareGPT_training_dataset_2_distillation.pt" \
    --output_dir test/ \
    --num_train_epochs 1 \
    --save_steps 500 \
    --model_max_length 2048 \
    --num_special_tokens 3 \
    --virtual_tokens_per_special_token 1 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --evaluation_strategy "no" \
    --learning_rate 1e-2 \
    --weight_decay 0.0 \
    --warmup_ratio 0.0 \
    --lr_scheduler_type "cosine" \
    --logging_steps 10 \
    --load_in_4bit \
    --vt_attention_type "ensemble" \
    --trainer_type "distillation_trainer"
```

Set `--dataset_path` to the location of the distillation dataset, and set `--trainer_type` to `"distillation_trainer"` to train with knowledge distillation. `--num_special_tokens` specifies the number of special tokens to train. `--virtual_tokens_per_special_token` is the number of virtual tokens used per special token; set it to 1 for the lowest latency.
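
As a quick sanity check on the flags in the example command, the arithmetic below works out the effective batch size. It assumes plain data parallelism with one process per device, which is how `accelerate launch --num_processes` is typically used; the variable names are only for illustration.

```python
# Values copied from the example command above.
num_processes = 4                 # --num_processes
per_device_train_batch_size = 1   # --per_device_train_batch_size
gradient_accumulation_steps = 4   # --gradient_accumulation_steps
save_steps = 500                  # --save_steps, counted in optimizer steps

# Sequences that contribute to a single optimizer update.
effective_batch_size = (
    num_processes * per_device_train_batch_size * gradient_accumulation_steps
)

# Sequences processed between two checkpoints.
sequences_per_checkpoint = save_steps * effective_batch_size

print(effective_batch_size)
print(sequences_per_checkpoint)
```

If you change the number of GPUs, adjusting `--gradient_accumulation_steps` keeps the effective batch size, and therefore the meaning of `--learning_rate`, comparable.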

## Inference

Inference uses dynamically extended tree attention with top-k candidates. The currently supported evaluation datasets include [Alpaca Eval](https://huggingface.co/datasets/tatsu-lab/alpaca_eval/blob/0cd24d711fe90d0c1aae5bde03fe98ee48ae52f8/alpaca_eval.json), [MT Bench](https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge), [HumanEval](https://github.com/openai/human-eval), and GSM8K.

Refer to this [README.md](application/README.md) for instructions on installing the libraries these datasets require.
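
Tree attention lets several candidate continuations be checked in one forward pass: each candidate token attends only to itself and to its ancestors in the candidate tree. The sketch below builds such a mask from parent pointers. It is a generic illustration, not this repository's implementation; the function name `tree_attention_mask` and the parent-list encoding are assumptions, and the dynamic extension of the tree is not modeled.

```python
def tree_attention_mask(parents):
    """Return a boolean mask where node i sees itself and its ancestors.

    `parents[i]` is the index of node i's parent, or -1 for the root.
    Assumption: every parent appears before its children in the list.
    """
    n = len(parents)
    mask = [[False] * n for _ in range(n)]
    for i, parent in enumerate(parents):
        mask[i][i] = True  # every node attends to itself
        if parent >= 0:
            # Inherit everything the parent is allowed to attend to.
            for j in range(n):
                if mask[parent][j]:
                    mask[i][j] = True
    return mask


# Root (0) with two top-2 candidates (1 and 2); node 1 has one child (3).
parents = [-1, 0, 0, 1]
for row in tree_attention_mask(parents):
    print("".join("x" if visible else "." for visible in row))
```

In this example, nodes 1 and 2 are siblings and cannot see each other, while node 3 sees the path 0, 1, 3. A larger tree covers more candidates per step but costs more compute per forward pass.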

### MT Bench

- To obtain the latency of the baseline model (without special tokens), run:

```bash
python3 gen_model_answer_baseline.py \
  --model-path <model-path> \
  --model-id <model-id> \
  --answer-file <output-file> \
  --bench-name mt_bench \
  --temperature <temperature>
```

- To obtain the latency of the model with special tokens, run:

```bash
python3 gen_model_answer_prompt_decoding.py \
  --model-path <model-path> \
  --model-id <model-id> \
  --answer-file <output-file> \
  --tree-length 105 \
  --bench-name mt_bench \
  --temperature <temperature>
```

`--model-path` is the path to the trained special tokens and `--tree-length` is the length of the sparse tree used.

- To view the latency results of a generated `.jsonl` file, run:

```bash
python get_throughput_results.py data/mt_bench/experiments/vicuna-7b-faster1.jsonl --n 3
```

`--n` specifies the number of experiment runs to average over.
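
The kind of aggregation that averaging over runs implies can be sketched as follows. This does not reproduce `get_throughput_results.py` or its file format: the functions `average_throughput` and `speedup`, the `(tokens, seconds)` input shape, and all numbers are hypothetical.

```python
from statistics import mean


def average_throughput(runs):
    """Mean tokens per second over runs given as (tokens, seconds) pairs."""
    return mean(tokens / seconds for tokens, seconds in runs)


def speedup(baseline_runs, accelerated_runs):
    """Ratio of accelerated throughput to baseline throughput."""
    return average_throughput(accelerated_runs) / average_throughput(baseline_runs)


# Hypothetical measurements from three runs each (tokens, seconds).
baseline = [(1000, 25.0), (1000, 26.0), (1000, 24.5)]
accelerated = [(1000, 12.0), (1000, 12.5), (1000, 11.8)]
print(round(speedup(baseline, accelerated), 2))
```

Averaging per-run throughput weights each run equally; dividing total tokens by total time would weight longer runs more heavily, so it is worth stating which one a reported number uses.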

### Alpaca Eval

Alpaca Eval is also used as an evaluation dataset. Latency results can be obtained with the same scripts as for MT Bench by setting `--bench-name alpaca_eval`.

- To compare latencies and accept lengths across sparse trees of different sizes, run:

```bash
python accept_length.py \
  --dir-path <output-dir> \
  --file-name <file-name> \
  --model-name <model-path> \
  --eval-file-name gen_model_answer_prompt_decoding.py \
  --n 1 \
  --max-length 120 \
  --min-length 60 \
  --length-interval 9 \
  --choices "[75, 105, 135, 165, 195, 225, 255, 285]"

python3 tree_latency.py \
  --model-path <model-path> \
  --model-id <model-id> \
  --answer-file <output-file> \
  --bench-name alpaca_eval \
  --min-tree-length 60 \
  --max-tree-length 120 \
  --length-interval 3 \
  --max-new-token 1024
```

This [script](script/latency/optimal-sparse-tree.sh) runs the latency tests on a range of sparse trees.
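
The sweep over tree sizes and the choice of the best one can be sketched as below. This is an illustration under stated assumptions, not the behavior of `tree_latency.py`: the upper bound is assumed to be inclusive, accept length is treated as tokens produced per decoding step, and the measurements are made-up placeholders.

```python
def candidate_tree_lengths(min_length, max_length, interval):
    """Tree lengths to test; assumes `max_length` is inclusive."""
    return list(range(min_length, max_length + 1, interval))


def best_tree_length(results):
    """Pick the tree length with the highest estimated tokens per second.

    `results` maps tree length -> (mean accept length, seconds per step).
    """
    return max(results, key=lambda length: results[length][0] / results[length][1])


# Matches --min-tree-length 60 --max-tree-length 120 --length-interval 3.
lengths = candidate_tree_lengths(60, 120, 3)
print(len(lengths), lengths[0], lengths[-1])

# Hypothetical measurements: larger trees accept more tokens per step
# but each step takes longer.
measurements = {60: (2.0, 0.030), 90: (2.3, 0.032), 120: (2.4, 0.040)}
print(best_tree_length(measurements))
```

The trade-off this captures is that accept length tends to grow with tree size while per-step latency also grows, so the best size depends on the hardware and has to be measured.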

### HumanEval

Latency results for HumanEval can be obtained with the same scripts as for MT Bench by setting `--bench-name humaneval`.

### GSM8K

Latency results for GSM8K can be obtained with the same scripts as for MT Bench by setting `--bench-name gsm8k`.
