# Differentiable, Stable and Efficient Floating-Point Quantization
- Target [microsoft/microxcaling](https://github.com/microsoft/microxcaling)-compliant parameters
- Quantize in square blocks so that quantization commutes with transposition
- **Differentiable** automatic mixed precision via pseudo-quantization training (PQT)
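
The transpose-commutativity point can be illustrated with a minimal sketch. It is not the repository's kernel: `fake_quantize`, the 2x2 block size, the signed-integer element grid, and the scale rule are all simplifying assumptions chosen to keep the example in plain Python. With square blocks, each block of the transposed matrix holds exactly the same elements as a block of the original, so the shared scale is identical; with 1x2 row blocks the grouping changes under transposition.

```python
import math


def fake_quantize(w, block_h, block_w, elem_bits=4):
    """Fake-quantize a matrix (list of rows) with one shared power-of-two
    scale per block_h x block_w block and a signed integer element grid.
    This is a simplified stand-in for an MX format, not the real one."""
    rows, cols = len(w), len(w[0])
    qmax = 2 ** (elem_bits - 1) - 1  # largest element magnitude on the grid
    out = [[0.0] * cols for _ in range(rows)]
    for r0 in range(0, rows, block_h):
        for c0 in range(0, cols, block_w):
            cells = [(r, c)
                     for r in range(r0, min(r0 + block_h, rows))
                     for c in range(c0, min(c0 + block_w, cols))]
            amax = max(abs(w[r][c]) for r, c in cells)
            if amax == 0.0:
                continue  # an all-zero block stays zero
            # Smallest power-of-two scale that keeps amax / scale <= qmax.
            scale = 2.0 ** math.ceil(math.log2(amax / qmax))
            for r, c in cells:
                q = max(-qmax, min(qmax, round(w[r][c] / scale)))
                out[r][c] = q * scale
    return out


def transpose(w):
    return [list(col) for col in zip(*w)]


w = [
    [1.0, 100.0, 0.5, -0.25],
    [0.3, 0.2, 2.0, 4.0],
    [-3.0, 0.7, 0.1, 0.9],
    [8.0, -0.6, 1.5, -2.5],
]

# Square 2x2 blocks: quantize-then-transpose equals transpose-then-quantize.
square_commutes = fake_quantize(transpose(w), 2, 2) == transpose(fake_quantize(w, 2, 2))
# 1x2 row blocks: transposition regroups elements, so the scales differ.
rowwise_commutes = fake_quantize(transpose(w), 1, 2) == transpose(fake_quantize(w, 1, 2))
print(square_commutes, rowwise_commutes)  # prints: True False
```

This property matters because a linear layer uses the weight matrix in the forward pass and its transpose in the backward pass; square blocks let both passes see the same quantized values.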

## Outline
- `nanoGPT` - PoC of GPT training from [karpathy/nanoGPT](https://github.com/karpathy/nanoGPT)
- `torchtitan` - PoC of LLaMA training from [pytorch/torchtitan](https://github.com/pytorch/torchtitan)
- `Adam-mini` - memory-efficient optimizer from [zyushun/Adam-mini](https://github.com/zyushun/Adam-mini)
- `Apollo` - memory-efficient optimizer from [MadsToftrup/Apollo-dev](https://github.com/MadsToftrup/Apollo-dev)
- `mx-amp-kernel` - Performant PQT Triton kernels with validation and benchmark

## Common setup
The MX datatype must be set manually in `mx-amp-kernel/linear.py:MxQ`.

## To reproduce GPT2-124M results
Use the `nvcr.io/nvidia/pytorch:24.10-py3` container as the training environment.
Tested on RTX 3090, RTX 4090, and A100 GPUs with the R565 driver.

```bash
#!/bin/bash
cd nanoGPT

# install dependencies
python3 -m pip install -r requirements.txt

# prepare the OpenWebText dataset
python3 data/openwebtext/prepare.py

# to enable remote logging
wandb login

# edit training configuration
vi config/train_gpt2.py

# start training from scratch
# it can take up to 2 weeks on 8x RTX 3090 GPUs
# for the BF16 baseline
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py --compile=False --model_type=vanilla

# for PQT
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py --compile=False --model_type=quant

# for MX-quantized training without PQN
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py --compile=False --model_type=quant-no-noise
```

Experimental results can be reproduced by changing the options inside the configuration file `config/train_gpt2.py`.
The options of interest are:
- `opt_type`
  - optimizer choice
- `is_diffq`
  - selects the noise distribution; valid only if `model_type` is "quant".
  - if true, uniform noise `U(-0.5, 0.5)` is used.
  - otherwise, an approximation of `round(N(0, 1) * 0.5)` is used.
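
To make the two `is_diffq` settings concrete, here is a minimal sketch of the two noise distributions in plain Python. It is an illustration under assumptions, not the repository's Triton kernel: `sample_noise` and `pseudo_quantize` are hypothetical names, the additive form `w + step * noise` is assumed, and the rounded normal is sampled exactly here, whereas the option above refers to an approximation of it.

```python
import random


def sample_noise(is_diffq, rng):
    """Draw one noise sample, expressed in units of the quantization step."""
    if is_diffq:
        # Continuous uniform noise on [-0.5, 0.5].
        return rng.uniform(-0.5, 0.5)
    # Rounded normal noise: integer-valued, mostly -1, 0 or 1.
    return float(round(rng.gauss(0.0, 1.0) * 0.5))


def pseudo_quantize(weights, step, is_diffq, rng):
    """Perturb each weight by step * noise instead of rounding it.
    The additive form is an assumption made for illustration."""
    return [w + step * sample_noise(is_diffq, rng) for w in weights]


rng = random.Random(0)  # fixed seed so repeated runs draw the same samples
uniform_samples = [sample_noise(True, rng) for _ in range(10_000)]
rounded_samples = [sample_noise(False, rng) for _ in range(10_000)]

print(min(uniform_samples), max(uniform_samples))  # stays within [-0.5, 0.5]
print(sorted(set(rounded_samples)))                # a small set of integers
```

The uniform variant perturbs a weight by at most half a quantization step, while the rounded normal variant moves it by a whole number of steps, which keeps the perturbed value on the quantization grid when the weight is already on it.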


## To reproduce Llama2 results
Use the `ghcr.io/pytorch/pytorch-nightly:2.7.0.dev20250107-cuda12.4-cudnn9-devel` container as the training environment.
Tested on RTX 3090, RTX 4090, and A100 GPUs with the R565 driver.

```bash
#!/bin/bash
cd torchtitan

# install dependencies
python3 -m pip install -r requirements.txt

# prepare the sentencepiece tokenizer for Llama2
# (replace HF_TOKEN with your Hugging Face access token)
# for reference, the hash values of `tokenizer.model` are:
# md5 : `eeec4125e9c7560836b4873b6f8e3025`
# sha256 : `9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347`
python3 torchtitan/datasets/download_tokenizer.py \
  --repo_id meta-llama/Llama-2-7b \
  --hf_token HF_TOKEN

# dataset
# - will be streamed via huggingface `allenai/c4` during training
# - we tested with locally downloaded dataset, as follows:
# /mnt/cephfs/dataset/c4$ du -sh *
# 41K     README.md
# 305G    en
# 516K    en.noblocklist
# 3.6M    en.noclean
# 30M     multilingual
# 257K    realnewslike

# to enable remote logging
wandb login

# edit training configuration for Llama2-134M
vi train_configs/llama2_134m.toml

# start training Llama2-134M from scratch
# it can take up to 2 weeks on 8x RTX 3090 GPUs
CONFIG_FILE="./train_configs/llama2_134m.toml" ./run_llama_train.sh

# edit training configuration for Llama2-1B
vi train_configs/llama2_1b.toml

# start training Llama2-1B from scratch
# requires GPUs with more than 40 GiB of memory; it can take up to 3 weeks on 8x A100 GPUs
CONFIG_FILE="./train_configs/llama2_1b.toml" ./run_llama_train.sh

```
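
The tokenizer hashes listed in the script can be checked with a short Python helper. This is a sketch: `file_digests` and `tokenizer_matches` are hypothetical names, and the `tokenizer.model` path below is an assumed location that should be adjusted to wherever the download script saved the file.

```python
import hashlib
from pathlib import Path

# Reference values quoted in the setup script above.
EXPECTED_MD5 = "eeec4125e9c7560836b4873b6f8e3025"
EXPECTED_SHA256 = "9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347"


def file_digests(path, chunk_size=1 << 20):
    """Return (md5, sha256) hex digests of a file, read in chunks."""
    md5, sha256 = hashlib.md5(), hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            md5.update(chunk)
            sha256.update(chunk)
    return md5.hexdigest(), sha256.hexdigest()


def tokenizer_matches(path):
    """True if both digests of the file equal the reference values."""
    return file_digests(path) == (EXPECTED_MD5, EXPECTED_SHA256)


# Assumed location; adjust to where `tokenizer.model` was downloaded.
tokenizer_path = Path("tokenizer.model")
if tokenizer_path.exists():
    print("tokenizer OK" if tokenizer_matches(tokenizer_path) else "tokenizer hash mismatch")
```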

Experimental results can be reproduced by changing the options inside the configuration files `train_configs/llama2_134m.toml` and `train_configs/llama2_1b.toml`.
The options of interest are:
- `dataset_path`
- `name` of optimizer
  - can be "AdamW" or "Adam-mini".
- `enable_gws`
  - enables weight sampling (PQT).
- `is_diffq`
  - selects the noise distribution; valid only if `enable_gws` is true.
  - if true, uniform noise `U(-0.5, 0.5)` is used.
  - otherwise, an approximation of `round(N(0, 1) * 0.5)` is used.
- `quant_no_noise`
  - quantizes parameters to MX without pseudo-quantization noise (PQN).
