# Code for *On the Scaling Theory of Multi-Layer Transformers*

## Initialization
```commandline
conda create -n scaling_theory python==3.12
conda activate scaling_theory

pip install torch torchvision --index-url https://download.pytorch.org/whl/cu126

pip install -r requirements.txt
```

## Start Training
```commandline
python -m torch.distributed.launch --nproc_per_node=8 train.py
```

To compute the FLOPs of the ViT model, run:
```commandline
python get_flops.py
```
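As a rough cross-check on the reported number, the dominant matrix-multiply cost of a ViT can be estimated by hand. The sketch below is not taken from `get_flops.py`; the function name `vit_macs` and its defaults are assumptions based on the standard ViT-Base/16 configuration (224x224 input, 16x16 patches, hidden size 768, 12 layers, MLP ratio 4, 1000 output classes). It counts multiply-accumulates (MACs) and ignores biases, LayerNorm, softmax, and activation functions, so a profiler-based count will differ slightly.

```python
def vit_macs(image_size=224, patch_size=16, channels=3, hidden=768,
             layers=12, mlp_ratio=4, num_classes=1000):
    """Approximate multiply-accumulate count for one ViT forward pass.

    Hypothetical helper for illustration; not part of this repository.
    """
    num_patches = (image_size // patch_size) ** 2
    tokens = num_patches + 1  # patch tokens plus the [CLS] token

    # Patch embedding: linear map from each flattened patch to the hidden size.
    patch_embed = num_patches * (patch_size * patch_size * channels) * hidden

    # Per encoder layer.
    projections = 4 * tokens * hidden * hidden        # Q, K, V and output projections
    attention = 2 * tokens * tokens * hidden          # QK^T scores and weighted sum over V
    mlp = 2 * tokens * hidden * (mlp_ratio * hidden)  # two linear layers of the MLP

    # Classification head applied to the [CLS] token.
    head = hidden * num_classes

    return patch_embed + layers * (projections + attention + mlp) + head


if __name__ == "__main__":
    print(f"ViT-Base/16 estimate: {vit_macs() / 1e9:.2f} GMACs")
```

Some tools report FLOPs as 2 x MACs, so check which convention a given number uses before comparing.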

## Visualization
Run the following command to draw Figure 2:
```commandline
python draw_fig1.py
```

Run the following command to draw Figure 3:
```commandline
python draw_fig2.py
```

Run the following command to draw Figure 5:
```commandline
python draw_fig3.py
```

## Supplementary Experimental Details
* hf_Model: google/vit-base-patch16-224
* datasets: MNIST, CIFAR-10, CIFAR-100
* devices: 8 RTX 2080 Ti GPUs (11 GB each)
* batch_size: 64
* optimizer: AdamW
* learning_rate: 2e-5
* weight_decay: 0.01
* learning_rate_scheduler: get_linear_schedule_with_warmup
* warmup_steps: 0.1 * total_steps
* num_epochs: 30
* platforms: PyTorch and Hugging Face
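
To make the warmup setting concrete, here is a minimal sketch of how `warmup_steps` and the learning rate follow from the values above. It reimplements the linear warmup and linear decay rule in plain Python rather than calling the library. The dataset size (50,000 CIFAR-10 training images) and the reading of `batch_size: 64` as the global batch size are assumptions for illustration; `linear_warmup_multiplier` is a hypothetical helper name.

```python
import math


def linear_warmup_multiplier(step, warmup_steps, total_steps):
    """Learning-rate multiplier: linear ramp from 0 to 1, then linear decay to 0."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


# Assumed values for illustration.
train_size = 50_000       # CIFAR-10 training set size
global_batch_size = 64    # assumes batch_size above is the global batch size
num_epochs = 30
base_lr = 2e-5

steps_per_epoch = math.ceil(train_size / global_batch_size)
total_steps = steps_per_epoch * num_epochs
warmup_steps = int(0.1 * total_steps)  # warmup_steps: 0.1 * total_steps

if __name__ == "__main__":
    for step in (0, warmup_steps // 2, warmup_steps, total_steps):
        lr = base_lr * linear_warmup_multiplier(step, warmup_steps, total_steps)
        print(f"step {step:>6}: lr = {lr:.2e}")
```

In an actual training script the same schedule would typically come from `transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)` wrapped around a `torch.optim.AdamW` optimizer. If the batch size is per device, `total_steps` shrinks by the number of GPUs.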