Our code is based on the [TinyLlama](https://github.com/jzhang38/TinyLlama) codebase.

### Setup
Follow the [TinyLlama](https://github.com/jzhang38/TinyLlama) `PRETRAIN.md` to install the conda environment, then download and preprocess the SlimPajama dataset.

### Pretrain
We provide scripts to pre-train masked diffusion models (MDMs) and autoregressive models (ARMs) on the SlimPajama dataset.
```bash
# pre-train MDM, e.g., pre-train MDM with 1028M non-embedding parameters and 1e20 FLOPs
echo 'n' | lightning run model \
    --node-rank=0  \
    --accelerator=cuda \
    --devices=8 \
    --num-nodes=1 \
    pretrain_scale/train_diffusion_slim_1_mc.py --model 1028 --flops 100

# pre-train MDM with random sequence lengths, e.g., pre-train MDM with 170M non-embedding
# parameters and 6e19 FLOPs, setting 1% of the training data to random length
echo 'n' | lightning run model \
    --node-rank=0  \
    --accelerator=cuda \
    --devices=8 \
    --num-nodes=1 \
    pretrain_scale/train_diffusion_slim_1_mc_ssl.py --model 170 --flops 60 --ssl_ratio 0.01

# pre-train ARM, e.g., pre-train ARM with 170M non-embedding parameters and 6e18 FLOPs
echo 'n' | lightning run model \
    --node-rank=0  \
    --accelerator=cuda \
    --devices=8 \
    --num-nodes=1 \
    pretrain_scale/train_ar_slim_1.py --model 170 --flops 6
```
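
In the three examples above, the `--flops` value appears to be expressed in units of 1e18 FLOPs (100 for 1e20, 60 for 6e19, 6 for 6e18). The following is a minimal sketch under that assumption; `flops_to_arg` is a hypothetical helper for illustration and is not part of the codebase.

```bash
# Hypothetical helper: convert a total FLOPs budget into the value passed
# to --flops, ASSUMING --flops is given in units of 1e18 FLOPs
# (inferred from the examples above, not confirmed by the scripts).
flops_to_arg() {
    awk -v f="$1" 'BEGIN { printf "%d\n", f / 1e18 + 0.5 }'
}

flops_to_arg 1e20   # prints 100
flops_to_arg 6e19   # prints 60
flops_to_arg 6e18   # prints 6
```

If the training scripts define the unit differently, adjust the divisor accordingly.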


### Finetune
#### ShareGPT
Please first download the [ShareGPT](https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered) dataset and put it in `./data`.
```bash
# Finetune MDM, e.g., finetune a pre-trained MDM with 1028M non-embedding parameters and 1e20 FLOPs
echo 'n' | lightning run model \
    --node-rank=0  \
    --accelerator=cuda \
    --devices=8 \
    --num-nodes=1 \
    sft/finetune_diff.py --model 1028 --flops 100 --bs 256

# Finetune ARM
echo 'n' | lightning run model \
    --node-rank=0  \
    --accelerator=cuda \
    --devices=8 \
    --num-nodes=1 \
    sft/finetune_ar.py --model 1028 --flops 100 --bs 256
```

#### Reversal curse
```bash
# Finetune MDM
echo 'n' | lightning run model \
    --node-rank=0  \
    --accelerator=cuda \
    --devices=8 \
    --num-nodes=1 \
    sft/finetune_diff_reverse.py --model 1028 --flops 1600 --bs 64
```


### Evaluation
#### MT-Bench
```bash
# Generate answers first

# MDM
python eval/gen_model_answer.py --model-path $path --answer-file "data/mt_bench/model_answer/diff.jsonl" --model-id 1028 --model-type 'diff' \
    --set-temperature 1. --steps 256 --cfg-scale 1.8 --alg "origin"

# ARM
python eval/gen_model_answer.py --model-path $path --answer-file "data/mt_bench/model_answer/ar.jsonl" --model-id 1028 --model-type 'ar' --set-temperature 1.

# Then judge the generated answers on MT-Bench
export OPENAI_API_KEY=xxxxxxxxxxxxxx
python eval/gen_judgment.py --parallel 10 --judge-model "gpt-4o-2024-05-13"

# Show results
python eval/show_result.py --judge-model "gpt-4o-2024-05-13"

```
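
The judgment step needs both an OpenAI API key and the answer file produced by the generation step. The following is a minimal pre-flight sketch; `check_judge_ready` is a hypothetical helper, not part of the codebase, and it only checks that the key is set and the file is non-empty.

```bash
# Hypothetical pre-flight check to run before eval/gen_judgment.py.
# $1 = path to the answer file written by eval/gen_model_answer.py
check_judge_ready() {
    answer_file="$1"
    if [ -z "${OPENAI_API_KEY:-}" ]; then
        echo "OPENAI_API_KEY is not set" >&2
        return 1
    fi
    if [ ! -s "$answer_file" ]; then
        echo "answer file missing or empty: $answer_file" >&2
        return 1
    fi
    return 0
}

# Example usage (left commented out so this block has no side effects):
# check_judge_ready data/mt_bench/model_answer/diff.jsonl \
#     && python eval/gen_judgment.py --parallel 10 --judge-model "gpt-4o-2024-05-13"
```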

#### Reasoning
```bash
# Chain Rule
python evaluate_diff.py --tasks hellaswag,openbookqa,arc_easy,boolq,piqa,social_iqa,race --model mdlm --batch_size 64 --model_args model_name=170,ckpt_path=$path,nll_type='ar_ftb',cfg=$cfg

# Monte Carlo (MC)
python evaluate_diff.py --tasks hellaswag,openbookqa,arc_easy,boolq,piqa,social_iqa,race --model mdlm --batch_size 64 --model_args model_name=170,ckpt_path=$path,cfg=$cfg

# LAMBADA
python evaluate_diff.py --tasks lambada_standard --model mdlm --batch_size 64 --model_args model_name=170,ckpt_path=$path,nll_type='ar_ftb',greddy=3
```
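
The chain-rule and MC commands above differ only in whether `nll_type='ar_ftb'` is included in `--model_args`. The following is a minimal sketch that makes the difference explicit; `mdlm_model_args` is a hypothetical helper, and the checkpoint path and `cfg` value are placeholders rather than recommended settings.

```bash
# Hypothetical helper: build the --model_args value for evaluate_diff.py.
# $1 = model name, $2 = checkpoint path, $3 = cfg, $4 = "chain" or "mc"
mdlm_model_args() {
    args="model_name=$1,ckpt_path=$2"
    case "$4" in
        # The shell strips the quotes in nll_type='ar_ftb', so the bare
        # form below passes the same value to the script.
        chain) args="$args,nll_type=ar_ftb" ;;
        mc)    ;;
        *)     echo "estimator must be 'chain' or 'mc'" >&2; return 1 ;;
    esac
    printf '%s\n' "$args,cfg=$3"
}

# Placeholder checkpoint path and cfg value, for illustration only.
mdlm_model_args 170 /path/to/ckpt 0.8 chain
mdlm_model_args 170 /path/to/ckpt 0.8 mc
```

The result can be passed as `--model_args "$(mdlm_model_args 170 "$path" "$cfg" chain)"`.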

#### Reversal Curse
```bash
python evaluate_reverse.py --model_type diff --qs_type dtn --model 1028 --ckpt-path $path --cfg 1.8
python evaluate_reverse.py --model_type diff --qs_type ntd --model 1028 --ckpt-path $path --cfg 1.8
```

#### Temporal quality degradation
```bash
python scripts/prepare_fineweb.py

python evaluate_fineweb.py --type diff --model 1028 --ckpt-path $path --fineweb "CC-MAIN-2024-18" --batch-size 32
python evaluate_fineweb.py --type diff --model 1028 --ckpt-path $path --fineweb "CC-MAIN-2024-10" --batch-size 32
```
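
The two evaluation commands above differ only in the FineWeb dump name. The following is a minimal dry-run sketch that prints one command per dump instead of running it; `fineweb_cmds` is a hypothetical helper, not part of the codebase.

```bash
# Hypothetical helper: print (not run) one evaluate_fineweb.py command per dump.
# $1 = checkpoint path, remaining arguments = FineWeb dump names
fineweb_cmds() {
    ckpt="$1"
    shift
    for dump in "$@"; do
        printf 'python evaluate_fineweb.py --type diff --model 1028 --ckpt-path %s --fineweb "%s" --batch-size 32\n' "$ckpt" "$dump"
    done
}

# '$path' is single-quoted so it stays literal in the printed commands.
fineweb_cmds '$path' CC-MAIN-2024-18 CC-MAIN-2024-10
```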
