# Self-Distillation Through Time code

## Summary
* The entry point to our code is `src/fast_discrete_diff/main.py`. It trains, samples from, and evaluates our models. The mode (`train`, `sample`, `eval`) is selected via the `mode` flag in the config.
* We use Hydra for configuration management, and all configuration files are in `src/fast_discrete_diff`.
* The distillation algorithm is implemented in `src/fast_discrete_diff/core/distill/mdlm_double_dt_correct.py`, which contains the loss computation and the training loop. We use PyTorch Lightning to organize the code.


## Example commands
* Train a model
```bash
python src/fast_discrete_diff/main.py \
    mode=train \
    data_preprocess.legacy_start_end_bos=True \
    parameterization=distill-mdlm \
    parameterization.num_distill_steps=2 \
    model=dit-orig-small \
    time_conditioning=False \
    loader.global_batch_size=128 \
    loader.batch_size=32 \
    trainer.max_steps=80000 \
    hydra.run.dir="./outputs/distill_2_steps" \
    loader.num_workers=16 \
    compile=False \
    trainer.num_nodes=1 \
    trainer.devices=1 \
    trainer.val_check_interval=5000 \
    trainer.precision="bf16-mixed" \
    data=openwebtext-split
```
* Sample from a trained model
```bash
# Unconditional
python src/fast_discrete_diff/main.py \
    mode=sample \
    data_preprocess.legacy_start_end_bos=True \
    parameterization=distill-mdlm \
    parameterization.num_distill_steps=2 \
    model=dit-orig-small \
    time_conditioning=False \
    loader.global_batch_size=128 \
    loader.batch_size=32 \
    trainer.max_steps=80000 \
    hydra.run.dir="./outputs/distill_2_steps" \
    loader.num_workers=16 \
    compile=False \
    \
    trainer.num_nodes=1 \
    trainer.devices=1 \
    trainer.val_check_interval=5000 \
    trainer.precision="bf16-mixed" \
    \
    data=openwebtext-split \
    \
    parameterization.sampling.uncond.run=True \
    parameterization.sampling.uncond.num_steps=64 \
    checkpointing.resume_ckpt_path=./student_checkpoints/70000.ckpt

# Conditional generation with OpenAI's webtext
python src/fast_discrete_diff/main.py \
    mode=sample \
    data_preprocess.legacy_start_end_bos=True \
    parameterization=distill-mdlm \
    parameterization.num_distill_steps=2 \
    model=dit-orig-small \
    time_conditioning=False \
    loader.global_batch_size=128 \
    loader.batch_size=32 \
    trainer.max_steps=80000 \
    hydra.run.dir="./outputs/distill_2_steps" \
    loader.num_workers=16 \
    compile=False \
    \
    trainer.num_nodes=1 \
    trainer.devices=1 \
    trainer.val_check_interval=5000 \
    trainer.precision="bf16-mixed" \
    \
    data=openwebtext-split \
    \
    parameterization.sampling.cond_prefix.run=True \
    parameterization.sampling.cond_prefix.add_bos=True \
    parameterization.sampling.cond_prefix.dataset=webtext \
    parameterization.sampling.cond_prefix.num_steps=64 \
    checkpointing.resume_ckpt_path=./student_checkpoints/70000.ckpt
```
* Evaluate generative perplexity
```bash
# Will evaluate all unconditional samples generated in the directory
python src/fast_discrete_diff/main.py \
    mode=eval \
    data_preprocess.legacy_start_end_bos=True \
    parameterization=distill-mdlm \
    parameterization.num_distill_steps=2 \
    model=dit-orig-small \
    time_conditioning=False \
    loader.global_batch_size=128 \
    loader.batch_size=32 \
    trainer.max_steps=80000 \
    hydra.run.dir="./outputs/distill_2_steps" \
    loader.num_workers=16 \
    compile=False \
    \
    trainer.num_nodes=1 \
    trainer.devices=1 \
    trainer.val_check_interval=5000 \
    trainer.precision="bf16-mixed" \
    \
    data=openwebtext-split \
    \
    eval.ppl_with_ar.run=True
```

* Evaluate MAUVE
```bash
# Will evaluate all conditional samples generated in the directory
python src/fast_discrete_diff/main.py \
    mode=eval \
    data_preprocess.legacy_start_end_bos=True \
    parameterization=distill-mdlm \
    parameterization.num_distill_steps=2 \
    model=dit-orig-small \
    time_conditioning=False \
    loader.global_batch_size=128 \
    loader.batch_size=32 \
    trainer.max_steps=80000 \
    hydra.run.dir="./outputs/distill_2_steps" \
    loader.num_workers=16 \
    compile=False \
    \
    trainer.num_nodes=1 \
    trainer.devices=1 \
    trainer.val_check_interval=5000 \
    trainer.precision="bf16-mixed" \
    \
    data=openwebtext-split \
    \
    eval.ppl_with_ar.run=False \
    eval.mauve.run=True
```

* Evaluate on LAMBADA
```bash
python src/fast_discrete_diff/main.py \
    mode=eval \
    data_preprocess.legacy_start_end_bos=True \
    parameterization=distill-mdlm \
    parameterization.num_distill_steps=2 \
    model=dit-orig-small \
    time_conditioning=False \
    loader.global_batch_size=128 \
    loader.batch_size=32 \
    trainer.max_steps=80000 \
    hydra.run.dir="./outputs/distill_2_steps" \
    loader.num_workers=16 \
    compile=False \
    \
    trainer.num_nodes=1 \
    trainer.devices=1 \
    trainer.val_check_interval=5000 \
    trainer.precision="bf16-mixed" \
    \
    data=openwebtext-split \
    \
    eval.ppl_with_ar.run=False \
    eval.lambada_openai.run=True \
    eval.lambada_openai.num_samples=64
```
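
* Optional: avoid repeating the shared overrides

All of the commands above pass the same sixteen overrides and differ only in `mode` and a few mode-specific flags. The sketch below shows one way to factor that out. The `run_sdtt` function and the `DRY_RUN` variable are hypothetical helpers introduced here for illustration and are not part of the repository. By default the function only prints the assembled command; set `DRY_RUN=0` to execute it.

```bash
# Hypothetical helper (not part of the repository): assembles a command from
# the overrides shared by the examples above, plus any mode-specific flags.
run_sdtt() {
    sdtt_mode="$1"
    shift
    # Rebuild the positional parameters: shared overrides first, then the
    # caller's extra flags (the original "$@" is expanded before being reset).
    set -- python src/fast_discrete_diff/main.py \
        "mode=${sdtt_mode}" \
        data_preprocess.legacy_start_end_bos=True \
        parameterization=distill-mdlm \
        parameterization.num_distill_steps=2 \
        model=dit-orig-small \
        time_conditioning=False \
        loader.global_batch_size=128 \
        loader.batch_size=32 \
        trainer.max_steps=80000 \
        hydra.run.dir="./outputs/distill_2_steps" \
        loader.num_workers=16 \
        compile=False \
        trainer.num_nodes=1 \
        trainer.devices=1 \
        trainer.val_check_interval=5000 \
        trainer.precision="bf16-mixed" \
        data=openwebtext-split \
        "$@"
    if [ "${DRY_RUN:-1}" = "1" ]; then
        # Dry run (assumed default): print the command instead of running it.
        echo "$*"
    else
        "$@"
    fi
}

# Example: print the LAMBADA evaluation command.
run_sdtt eval \
    eval.ppl_with_ar.run=False \
    eval.lambada_openai.run=True \
    eval.lambada_openai.num_samples=64
```

With this helper, the training command reduces to `DRY_RUN=0 run_sdtt train`, and each sampling or evaluation command only lists the flags that differ.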
