# SSDD: Single-Step Diffusion Decoder for Efficient Image Tokenization

Training and evaluation code for our paper *SSDD: Single-Step Diffusion Decoder for Efficient Image Tokenization*, under review for ICLR 2026.

## Installation

You need a Python environment with `torch` and `torchvision` installed.
If you do not have one, you can create it with conda (`conda env create -f environment.yml`, then `conda activate ssdd_iclr`).
Then install with:

```bash
pip install -e .
```

## Training & evaluation


### Train an encoder

```bash
accelerate launch dae/main.py \
    dataset.im_size=128 dataset.augs.rand_resize_scale=2 training.optimizers.main.lr=1e-4
```

* `dataset.augs.rand_resize_scale=2` enables the multi-scale data augmentation, applied before cropping to 128x128

### Shared model pre-training

```bash
accelerate launch dae/main.py \
    dataset.im_size=128 dataset.augs.rand_resize_scale=2 \
    +ae.encoder.init.freeze=true training.losses.kl=0 \
    +ae.encoder.init.checkpoint=<shared_encoder_path>/checkpoints/best/model_ae_ema.safetensors
```

* To use a shared encoder, we load its weights (`+ae.encoder.init.checkpoint`) and freeze them (`+ae.encoder.init.freeze=true`); the KL loss is disabled as well (`training.losses.kl=0`), since the frozen encoder is no longer trained



### Finetune model at target resolution

```bash
accelerate launch dae/main.py \
    training.epochs=200 dataset.im_size=128 training.optimizers.main.lr=1e-4 \
    +ae.encoder.init.freeze=true training.losses.kl=0 \
    +ae.model_init.checkpoint=<pretrained_model_path>/checkpoints/best/model_ae_ema.safetensors \
    +aux_losses.model_init.checkpoint=<pretrained_model_path>/checkpoints/best/model_aux_losses.safetensors
```

* We directly load the full pre-trained model with `+ae.model_init.checkpoint`, and still keep the encoder frozen
* We also reload the REPA MLP used by the auxiliary losses with `+aux_losses.model_init.checkpoint`
* To fine-tune at 256x256, use `dataset.im_size=256` and `training.optimizers.main.lr=3e-4`
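Only two overrides differ between the 128x128 and 256x256 fine-tuning runs. The sketch below (a hypothetical `finetune_args` shell function, not part of this repository) prints the fine-tuning overrides for either resolution, using the learning rates given above; `/abs/path/to/pretrained_run` is a placeholder for `<pretrained_model_path>`.

```bash
#!/usr/bin/env bash
# Hypothetical helper, not part of the repository: print the fine-tuning
# overrides (one per line) for a pre-trained run directory and a resolution.
finetune_args() {
  local run_dir=$1 im_size=$2 lr
  # This README only documents learning rates for 128x128 and 256x256.
  case "$im_size" in
    128) lr=1e-4 ;;
    256) lr=3e-4 ;;
    *) echo "no learning rate documented for im_size=$im_size" >&2; return 1 ;;
  esac
  local ckpt_dir="$run_dir/checkpoints/best"
  printf '%s\n' \
    training.epochs=200 "dataset.im_size=$im_size" "training.optimizers.main.lr=$lr" \
    +ae.encoder.init.freeze=true training.losses.kl=0 \
    "+ae.model_init.checkpoint=$ckpt_dir/model_ae_ema.safetensors" \
    "+aux_losses.model_init.checkpoint=$ckpt_dir/model_aux_losses.safetensors"
}

# Dry run: show the overrides for 256x256 (the path is a placeholder).
finetune_args /abs/path/to/pretrained_run 256
```

To launch, pass the printed lines as arguments, e.g. `accelerate launch dae/main.py $(finetune_args <pretrained_model_path> 256)`. The unquoted `$(...)` splits on whitespace, so this assumes the path contains no spaces.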

### Distill model

```bash
accelerate launch dae/main.py \
    training.epochs=10 training.eval_freq=1 dataset.im_size=128 training.optimizers.main.lr=1e-4 \
    +ae.encoder.init.freeze=true training.losses.kl=0 \
    +ae.model_init.checkpoint=<finetuned_model_path>/checkpoints/best/model_ae_ema.safetensors \
    +aux_losses.model_init.checkpoint=<finetuned_model_path>/checkpoints/best/model_aux_losses.safetensors \
    ae.fm_sampler.steps=7 '+teacher=${ae}'
```
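The last override, `'+teacher=${ae}'`, is single-quoted on purpose: `${ae}` must reach `dae/main.py` untouched. Our reading (an assumption, based on the Hydra-style `+key=value` overrides used throughout) is that the config system resolves it as an interpolation of the `ae` config, presumably so the teacher is built from the same fine-tuned model. Without single quotes, bash would substitute its own variable `ae` first, or an empty string if none is set. A minimal demonstration:

```bash
#!/usr/bin/env bash
# Why '+teacher=${ae}' is wrapped in single quotes in the distillation command.
ae="value-from-the-shell"   # a shell variable that happens to share the name

# Single quotes: bash passes the text through untouched, so main.py receives
# the literal interpolation ${ae}.
single='+teacher=${ae}'
# Double quotes (or no quotes): bash substitutes ${ae} itself before main.py runs.
double="+teacher=${ae}"

printf 'single-quoted: %s\n' "$single"   # single-quoted: +teacher=${ae}
printf 'double-quoted: %s\n' "$double"   # double-quoted: +teacher=value-from-the-shell
```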


### Eval model

```bash
accelerate launch dae/main.py \
    task=ae.eval dataset.im_size=128 \
    +ae.model_init.checkpoint=<model_path>/checkpoints/best/model_ae_ema.safetensors \
    ae.fm_sampler.steps=8
```

* We change the task to `task=ae.eval` (from the default `ae.train`)
* Set the number of sampling steps with `ae.fm_sampler.steps=?`; use `ae.fm_sampler.steps=1` with a distilled model.
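A sketch of the two evaluation settings described above: a multi-step model evaluated with `ae.fm_sampler.steps=8`, and a distilled model evaluated with `ae.fm_sampler.steps=1`. The `eval_cmd` function and the run paths are hypothetical placeholders; the function only prints the commands, so it is safe to run as-is.

```bash
#!/usr/bin/env bash
# Hypothetical helper, not part of the repository: print the evaluation command
# for a run directory and a number of sampling steps (dry run, nothing is launched).
eval_cmd() {
  local run_dir=$1 steps=$2
  local cmd=(accelerate launch dae/main.py task=ae.eval dataset.im_size=128
    "+ae.model_init.checkpoint=$run_dir/checkpoints/best/model_ae_ema.safetensors"
    "ae.fm_sampler.steps=$steps")
  echo "${cmd[*]}"
}

# Placeholder paths: a fine-tuned (multi-step) model and its distilled version.
eval_cmd /abs/path/to/finetuned_run 8
eval_cmd /abs/path/to/distilled_run 1
```

To actually launch an evaluation, replace `echo "${cmd[*]}"` with `"${cmd[@]}"`; the array form keeps each override a single argument even if a path contains spaces.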


### Other parameters

* Adding a GAN loss: `+models=gan`
* Changing the encoder compression: `ae.encoder.patch_size=?` and `ae.encoder.z_dim=?` (e.g., for f16c32: `ae.encoder.patch_size=16`, `ae.encoder.z_dim=32`)
* Changing the decoder size: `ae.decoder.size=?` (available values are: S, B, M, L, XL, H)
* Changing the number of epochs: `training.epochs=?` *(useful for testing)*
* Changing the evaluation frequency: `training.eval_freq=?` *(useful for testing)*
* Limiting the number of train / test samples: `dataset.limit=?`, `test_dataset.limit=?` *(useful for testing)*
* Limiting the batch size: `training.batch_size=?` *(useful for testing)*
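For a quick end-to-end check of the setup, the options marked *useful for testing* can be combined into one short run with the smallest decoder. The limit and batch-size values below are arbitrary small numbers chosen for illustration, not settings from the paper:

```bash
accelerate launch dae/main.py \
    dataset.im_size=128 ae.decoder.size=S \
    training.epochs=1 training.eval_freq=1 \
    dataset.limit=256 test_dataset.limit=64 training.batch_size=8
```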

Remark: paths passed to the commands should be absolute.
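If you keep runs under a relative directory, a small shell helper can turn a checkpoint path into an absolute one before passing it on. `abspath` below is a hypothetical, portable sketch (plain `cd` and `pwd -P`, no GNU-only options); it requires the parent directory to exist.

```bash
#!/usr/bin/env bash
# Hypothetical helper, not part of the repository: print the absolute form of a
# path whose parent directory exists (the last component may not exist yet).
abspath() {
  local dir base parent
  dir=$(dirname -- "$1")
  base=$(basename -- "$1")
  # The cd runs inside $(...), i.e. a subshell, so the caller's directory is unchanged.
  parent=$(cd -- "$dir" && pwd -P) || return 1
  # ${parent%/} avoids a double slash when the parent is the root directory.
  printf '%s/%s\n' "${parent%/}" "$base"
}
```

For example: `+ae.model_init.checkpoint="$(abspath <model_path>/checkpoints/best/model_ae_ema.safetensors)"`.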
