<h1 align="center">FastCLIP: A Suite of Optimization Techniques to <br> Accelerate CLIP Training with Limited Resources</h1>

## Getting Started

### Environment Setup

To set up the training environment, run the following commands in this directory.
```bash
conda create -n fastclip python=3.11
conda activate fastclip
pip install -r requirements-training.txt
```

### Training

We provide sample Slurm scripts for running OpenCLIP and FastCLIP-v0 through FastCLIP-v3. For non-Slurm instructions, see the end of this subsection. To train on your own data, modify the following options:

- `--train-data`: path to the training data. Currently only the **webdataset** format is supported.
- `--train-num-samples`: number of samples seen in one epoch. We recommend setting it to the actual size of the dataset.
- `--data_size`: the original size of the dataset, which may differ from `--train-num-samples`. For example, the CC3M metadata contains 3318333 image-URL/caption pairs, but we were only able to download 2723840 of them, so we set `--data_size` to 3318333 and `--train-num-samples` to 2723840.
- `--epochs`: number of epochs to train the model for.
- `--gamma_decay_epochs`: number of epochs over which $\gamma$ decreases from 1.0 to `--gamma`. We recommend setting it to half of `--epochs`.
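
As a minimal sketch of how these options relate, the CC3M values used in the sample scripts can be assembled as below. The shell variable names are illustrative only and are not read by the training code; integer division is assumed to be an acceptable way to take "half of `--epochs`".

```bash
#!/bin/bash
# Illustrative helper variables (hypothetical names, not read by main.py).
data_size=3318333          # pairs listed in the CC3M metadata
train_num_samples=2723840  # pairs actually downloaded
epochs=37

# Recommended: decay gamma over half of the training epochs (integer division).
gamma_decay_epochs=$((epochs / 2))

# Collect the dataset-dependent flags so they can be appended to the training command.
data_flags="--train-num-samples $train_num_samples --data_size $data_size"
schedule_flags="--epochs $epochs --gamma_decay_epochs $gamma_decay_epochs"
echo "$data_flags $schedule_flags"
```

With `epochs=37` this yields `--gamma_decay_epochs 18`, which matches the CC3M sample scripts.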

<details open>
    <summary>Sample script to run <b>FastCLIP-v3</b> on CC3M using 8 GPUs (2 nodes and 4 GPUs per node)</summary>

```bash
#!/bin/bash
#SBATCH --time=2-00:00:00
#SBATCH --mem=120G
#SBATCH --nodes=2
#SBATCH --gres=gpu:4
#SBATCH --ntasks-per-node=4
#SBATCH --cpus-per-task=6
#SBATCH --wait-all-nodes=1
#SBATCH --job-name=fastclipv3
#SBATCH --partition=gpu
#SBATCH --output=%x_%j.log

source ~/.bashrc
conda activate fastclip

master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_ADDR=$master_addr
export MASTER_PORT=12805

export CUDA_VISIBLE_DEVICES='0,1,2,3'
export PYTHONPATH="$PYTHONPATH:$PWD/src"
export HUGGINGFACE_HUB_CACHE='./checkpoints/huggingface'

srun python -u src/training/main.py \
    --save-frequency 1 \
    --train-data './datasets/cc3m_webdataset/cc3m_train/{00000..00331}.tar' \
    --train-num-samples 2723840 --data_size 3318333 \
    --warmup 10000 \
    --batch-size 128 \
    --epochs 37 \
    --workers 6 \
    --model RN50 \
    --name medium_fastclipv3 \
    --seed 2024 \
    --profile \
    --wd 0.1 \
    --local-loss \
    --fastclip --multiply_tau --temperature_scheme global_learnable \
    --lr 1e-3 --lr_tau 2e-4 --lr_tau_scheduler step_thresh --rho 6.5 \
    --gamma 0.2 --gamma_schedule cosine --gamma_decay_epochs 18
```

</details>

<details>
    <summary>Sample script to run <b>OpenCLIP</b> on CC3M using 8 GPUs (click to expand):</summary>

Replace the `srun python -u src/training/main.py` command in the FastCLIP-v3 script with
```bash
srun python -u src/training/main.py \
    --save-frequency 1 \
    --train-data './datasets/cc3m_webdataset/cc3m_train/{00000..00331}.tar' \
    --train-num-samples 2723840 --data_size 3318333 \
    --warmup 10000 \
    --batch-size 128 \
    --epochs 37 \
    --workers 6 \
    --model RN50 \
    --name medium_openclip \
    --seed 2024 \
    --profile \
    --wd 0.1 \
    --local-loss \
    --gather-with-grad \
    --lr 1e-3
```

</details>

<details>
    <summary>Sample script to run <b>FastCLIP-v0</b> on CC3M using 8 GPUs (click to expand):</summary>

Replace the `srun python -u src/training/main.py` command in the FastCLIP-v3 script with
```bash
srun python -u src/training/main.py \
    --save-frequency 1 \
    --train-data './datasets/cc3m_webdataset/cc3m_train/{00000..00331}.tar' \
    --train-num-samples 2723840 --data_size 3318333 \
    --warmup 10000 \
    --batch-size 128 \
    --epochs 37 \
    --workers 6 \
    --model RN50 \
    --name medium_fastclipv0 \
    --seed 2024 \
    --profile \
    --wd 0.1 \
    --local-loss \
    --fastclip --temperature_scheme global_learnable \
    --lr 1e-3 \
    --gamma 0.2 --gamma_schedule cosine --gamma_decay_epochs 18
```

</details>

<details>
    <summary>Sample script to run <b>FastCLIP-v1</b> on CC3M using 8 GPUs (click to expand):</summary>

Replace the `srun python -u src/training/main.py` command in the FastCLIP-v3 script with
```bash
srun python -u src/training/main.py \
    --save-frequency 1 \
    --train-data './datasets/cc3m_webdataset/cc3m_train/{00000..00331}.tar' \
    --train-num-samples 2723840 --data_size 3318333 \
    --warmup 10000 \
    --batch-size 128 \
    --epochs 37 \
    --workers 6 \
    --model RN50 \
    --name medium_fastclipv1 \
    --seed 2024 \
    --profile \
    --wd 0.1 \
    --local-loss \
    --fastclip --temperature_scheme global_constant \
    --lr 1e-3 \
    --gamma 0.2 --gamma_schedule cosine --gamma_decay_epochs 18
```

</details>

<details>
    <summary>Sample script to run <b>FastCLIP-v2</b> on CC3M using 8 GPUs (click to expand):</summary>

Replace the `srun python -u src/training/main.py` command in the FastCLIP-v3 script with
```bash
srun python -u src/training/main.py \
    --save-frequency 1 \
    --train-data './datasets/cc3m_webdataset/cc3m_train/{00000..00331}.tar' \
    --train-num-samples 2723840 --data_size 3318333 \
    --warmup 10000 \
    --batch-size 128 \
    --epochs 37 \
    --workers 6 \
    --model RN50 \
    --name medium_fastclipv2 \
    --seed 2024 \
    --profile \
    --wd 0.1 \
    --local-loss \
    --fastclip --temperature_scheme individual_learnable \
    --lr 1e-3 --lr_tau 0.0133 --lr_tau_scheduler const --temperature 0.03 --rho 7.0 \
    --gamma 0.2 --gamma_schedule cosine --gamma_decay_epochs 18
```

</details>

<details>
    <summary>Sample script to run <b>FastCLIP-v3</b> on <b>CC12M</b> using 8 GPUs (click to expand):</summary>

Replace the `srun python -u src/training/main.py` command in the CC3M script with
```bash
srun python -u src/training/main.py \
    --save-frequency 1 \
    --train-data './datasets/cc12m_webdataset/cc12m/{00000..01242}.tar' \
    --train-num-samples 9187328 --data_size 12423374 \
    --warmup 10000 \
    --batch-size 256 \
    --epochs 33 \
    --workers 6 \
    --model ViT-B-32 \
    --name large_fastclipv3 \
    --seed 2024 \
    --profile \
    --wd 0.1 \
    --local-loss \
    --fastclip --multiply_tau --temperature_scheme global_learnable \
    --lr 4e-4 --lr_tau 1e-4 --lr_tau_scheduler step_thresh --rho 8.5 \
    --gamma 0.2 --gamma_schedule cosine --gamma_decay_epochs 16
```

</details>

<details>
    <summary>Sample script to run <b>FastCLIP-v3</b> on <b>LAION400M</b> using 16 GPUs (click to expand):</summary>

Many Slurm systems limit the running time of a job (e.g., to 2 days), which is far less than the time required for 30 epochs on LAION400M. We recommend using the `--stop_epochs` and `--resume` options: with `--stop_epochs`, training stops after the specified epoch, and with `--resume`, training resumes from the specified checkpoint. You can then construct a chain of jobs in which the `--resume` of each job points to the checkpoint where the preceding job stopped. Below is a sample script that stops after the third epoch.
```bash
#!/bin/bash
#SBATCH --time=2-00:00:00
#SBATCH --mem=120G
#SBATCH --nodes=4
#SBATCH --gres=gpu:4
#SBATCH --ntasks-per-node=4
#SBATCH --cpus-per-task=6
#SBATCH --wait-all-nodes=1
#SBATCH --job-name=fastclipv3
#SBATCH --partition=gpu
#SBATCH --output=%x_%j.log

source ~/.bashrc
conda activate fastclip

master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_ADDR=$master_addr
export MASTER_PORT=12805

export CUDA_VISIBLE_DEVICES='0,1,2,3'
export PYTHONPATH="$PYTHONPATH:$PWD/src"
export HUGGINGFACE_HUB_CACHE='./checkpoints/huggingface'

srun python -u src/training/main.py \
    --save-frequency 1 \
    --train-data './datasets/laion400m/laion400m-data/{00000..41407}.tar' \
    --train-num-samples 315658316 --data_size 414080000 \
    --warmup 13200 \
    --batch-size 320 \
    --epochs 30 \
    --workers 6 \
    --model ViT-B-16 \
    --name xlarge_fastclipv3 \
    --seed 2024 \
    --profile \
    --wd 0.2 \
    --local-loss \
    --fastclip --multiply_tau --temperature_scheme global_learnable \
    --lr 2e-4 --lr_min 4e-5 --lr_tau 5e-5 --lr_tau_scheduler step_thresh --rho 16.0 \
    --gamma 0.8 --gamma_schedule cosine --gamma_decay_epochs 10 \
    --fastclip_eps 1e-6 \
    --stop_epochs 3
```
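
The chain of jobs described above can be planned with a small helper. The following is only a sketch under stated assumptions: the checkpoint path pattern `./logs/<name>/checkpoints/epoch_<N>.pt` is inferred from the evaluation example, `--stop_epochs` is assumed to take the absolute epoch after which training stops, and `epochs_per_job=3` is an assumed number of epochs that fit within one time limit. The function `chain_flags` is hypothetical and simply prints the flags each job would need.

```bash
#!/bin/bash
# Sketch: print the --resume/--stop_epochs flags for each job in a chain.
total_epochs=30                                # --epochs in the sample script
epochs_per_job=3                               # assumption: epochs that fit in one job
ckpt_dir=./logs/xlarge_fastclipv3/checkpoints  # assumption: derived from --name

chain_flags() {
    local start=0 stop
    while [ "$start" -lt "$total_epochs" ]; do
        stop=$((start + epochs_per_job))
        if [ "$stop" -gt "$total_epochs" ]; then
            stop=$total_epochs
        fi
        if [ "$start" -eq 0 ]; then
            # The first job starts from scratch, so it has no --resume.
            echo "--stop_epochs $stop"
        else
            # Later jobs resume from the checkpoint where the previous job stopped.
            echo "--resume $ckpt_dir/epoch_$start.pt --stop_epochs $stop"
        fi
        start=$stop
    done
}

chain_flags
```

Each printed line would replace the `--stop_epochs 3` line of one job script. On Slurm, each job can be made to wait for its predecessor with `sbatch --dependency=afterok:<jobid>`.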

</details>

**Non-Slurm Training**: For non-Slurm training, set `master_addr` manually (e.g., `127.0.0.1`), change `srun python -u src/training/main.py` to `cd src && torchrun --nproc_per_node=4 --rdzv_endpoint=$master_addr -m training.main`, and run the resulting script with `/bin/bash`.
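
A possible shape for the non-Slurm launch on a single node with 4 GPUs is sketched below as a dry run that only prints the command. The variables `repo_root`, `launch_cmd`, and `train_args` are illustrative names, and only a few of the training options are shown. Because the command runs after `cd src`, relative paths such as `./datasets/...` would resolve against `src/`; the sketch assumes the data lives under the repository root and passes an absolute path instead.

```bash
#!/bin/bash
# Sketch of a single-node launch without Slurm (dry run: prints the command).
master_addr=127.0.0.1
export MASTER_ADDR=$master_addr
export MASTER_PORT=12805

repo_root=$PWD   # captured before `cd src` so data paths stay valid

launch_cmd=(torchrun --nproc_per_node=4 --rdzv_endpoint="$master_addr" -m training.main)
train_args=(
    --train-data "$repo_root/datasets/cc3m_webdataset/cc3m_train/{00000..00331}.tar"
    --train-num-samples 2723840 --data_size 3318333
    --batch-size 128 --epochs 37 --model RN50
    # ...remaining options as in the sample training script
)

# Print the full command for inspection.
printf '%s ' "${launch_cmd[@]}" "${train_args[@]}"; echo
# To launch for real: (cd src && "${launch_cmd[@]}" "${train_args[@]}")
```

Keeping the launcher and its arguments in arrays preserves the quoting of the `{00000..00331}` shard pattern, so it reaches the training program unexpanded.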

### Evaluation

<details open>
    <summary>Sample Slurm script to evaluate a trained CLIP model (specified by <code>--resume</code>) on <b>ImageNet-1k</b>:</summary>

```bash
#!/bin/bash
#SBATCH --time=01:00:00
#SBATCH --mem=20G
#SBATCH --nodes=1
#SBATCH --gres=gpu:1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=6
#SBATCH --job-name=eval_fastclip
#SBATCH --partition=gpu
#SBATCH --output=%x_%j.log

source ~/.bashrc
conda activate fastclip

master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_ADDR=$master_addr
export MASTER_PORT=12802

export PYTHONPATH="$PYTHONPATH:$PWD/src"
export HUGGINGFACE_HUB_CACHE='./checkpoints/huggingface'

srun python -u src/training/main.py \
    --resume ./logs/medium_fastclipv3/checkpoints/epoch_37.pt \
    --zeroshot-frequency 1 \
    --imagenet-val ./datasets/imagenet/val \
    --batch-size 512 \
    --epochs 0 \
    --workers 6 \
    --model RN50 \
    --name eval_medium_fastclipv3_epoch_37 \
    --seed 2024
```

</details>

**Datacomp**: For evaluation on the Datacomp benchmark, please refer to the "Evaluation" section in the [Datacomp repository](https://github.com/mlfoundations/datacomp?tab=readme-ov-file#evaluation).

**Non-Slurm Evaluation**: For non-Slurm evaluation, set `master_addr` manually (e.g., `127.0.0.1`), change `srun python -u src/training/main.py` to `python src/training/main.py`, and run the above script with `/bin/bash`.
