# SimVQ - ICML 2026 Complete Training Code

This directory contains the complete training codebase for SimVQ, including all model definitions, training scripts, and configuration files.

## Directory Structure

```
icml/
├── main.py                          # Main training script
├── evaluation.py                    # Evaluation script
├── configs/                         # Configuration files
│   ├── imagenet_simvq_128_B.yaml   # SimVQ config for ImageNet 128x128
│   └── imagenet_vq_128_B.yaml      # VQ baseline config
├── taming/                          # Model implementation
│   ├── models/
│   │   └── vq.py                   # VQ-VAE model wrapper
│   ├── modules/
│   │   ├── vqvae/
│   │   │   ├── simvq.py           # ⭐ SimVQ quantizer (UPDATED)
│   │   │   ├── expvq.py           # ExpVQ baseline
│   │   │   └── gsq.py             # GSQ baseline
│   │   ├── diffusionmodules/
│   │   │   └── improved_model.py  # Encoder & Decoder
│   │   ├── discriminator/
│   │   │   └── model.py           # PatchGAN discriminator
│   │   ├── losses/
│   │   │   ├── vqperceptual.py    # Perceptual loss
│   │   │   └── lpips.py           # LPIPS loss
│   │   ├── scheduler/
│   │   │   └── lr_scheduler.py    # Learning rate schedulers
│   │   ├── ema.py                 # Exponential Moving Average
│   │   └── util.py                # Utility functions
│   └── data/
│       ├── imagenet.py            # ImageNet dataloader
│       ├── base.py                # Base dataset class
│       └── utils.py               # Data utilities
└── metrics/                         # Evaluation metrics
    ├── inception.py               # Inception network for FID
    ├── fid.py                     # FID computation
    └── ...
```

## Key File: simvq.py (UPDATED)

**Location**: `taming/modules/vqvae/simvq.py`

This file has been **replaced** with the updated version from `icml_submission_code/`, which includes:

### New Features:
1. **Unified Regularization Loss Function**
   ```python
   compute_codebook_regularization_loss(
       codebook=None, 
       z_q=None, 
       loss_type="codebook_regularization"
   )
   ```

2. **Three Regularization Methods**:
   - `codebook_regularization`: Orthogonal regularization (Gram matrix)
   - `barlow_twins_codebook`: Barlow Twins on codebook (covariance matrix)
   - `barlow_twins_zq`: Barlow Twins on quantized vectors

3. **Simplified Interface**: No multi-group complexity
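
As a rough illustration of what the three options might compute, the NumPy sketch below implements one plausible reading of each. It is not the code in `simvq.py`: the function names, the normalization choices, and the scaling are all assumptions, and the repository's PyTorch implementation may differ.

```python
import numpy as np


def orthogonal_gram_loss(codebook: np.ndarray) -> float:
    """Penalize deviation of the Gram matrix of L2-normalized codes from identity."""
    # Assumption: rows are normalized, so the Gram diagonal is approximately 1.
    normed = codebook / (np.linalg.norm(codebook, axis=1, keepdims=True) + 1e-8)
    gram = normed @ normed.T                      # (n_e, n_e) cosine similarities
    residual = gram - np.eye(gram.shape[0])
    return float(np.mean(residual ** 2))


def barlow_twins_style_loss(vectors: np.ndarray) -> float:
    """Penalize off-diagonal entries of the feature correlation matrix."""
    # Standardize each feature dimension across rows (codes or z_q samples).
    centered = vectors - vectors.mean(axis=0, keepdims=True)
    standardized = centered / (centered.std(axis=0, keepdims=True) + 1e-8)
    corr = standardized.T @ standardized / vectors.shape[0]   # (e_dim, e_dim)
    off_diag = corr - np.diag(np.diag(corr))
    return float(np.sum(off_diag ** 2) / corr.shape[0])


def regularization_loss(codebook=None, z_q=None, loss_type="codebook_regularization"):
    """Hypothetical dispatcher mirroring the three documented loss types."""
    if loss_type == "codebook_regularization":
        return orthogonal_gram_loss(codebook)
    if loss_type == "barlow_twins_codebook":
        return barlow_twins_style_loss(codebook)
    if loss_type == "barlow_twins_zq":
        # Assumption: the last axis of z_q is the embedding dimension.
        return barlow_twins_style_loss(z_q.reshape(-1, z_q.shape[-1]))
    raise ValueError(f"Unknown loss_type: {loss_type}")
```

Note that the Gram matrix is `n_e x n_e`, so for large codebooks a real implementation would likely need sampling or chunking; the sketch ignores that.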

### Changes from Original:
- ✅ Replaced old `compute_orthogonal_loss_codebook()` with unified function
- ✅ Added Barlow Twins loss variants
- ✅ Removed multi-group functionality
- ✅ Cleaner, more maintainable code

## Usage

### Training

```bash
# Train SimVQ on ImageNet 128x128
python main.py --config configs/imagenet_simvq_128_B.yaml

# Train VQ baseline
python main.py --config configs/imagenet_vq_128_B.yaml
```

### Evaluation

```bash
# Evaluate trained model
python evaluation.py \
    --config_file results/simvq/<exp_name>/config.yaml \
    --ckpt_path results/simvq/<exp_name>/last.ckpt
```

## Configuration

### SimVQ Configuration (imagenet_simvq_128_B.yaml)

Key parameters in the config file:

```yaml
model:
  class_path: taming.models.vq.VQModel
  init_args:
    quantconfig:
      class_path: taming.modules.vqvae.simvq.SimVQ
      init_args:
        n_e: 16384                              # Codebook size
        e_dim: 128                              # Embedding dimension
        beta: 0.25                              # Commitment loss weight
        embedding_init: "gaussian"              # Initialization
        l2_norm: false                          # L2 normalization
        disentangle_loss_type: "codebook_orth"  # Regularization type
        disentangle_loss_weight: 0.001          # Regularization weight (λ)
```
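
To make the roles of `n_e`, `e_dim`, and `beta` concrete, here is a minimal NumPy sketch of a generic VQ-VAE style nearest-code lookup with a commitment-weighted loss. It shows the textbook formulation only: SimVQ's own forward pass in `simvq.py` is not reproduced here and may differ, and the function name `vq_lookup` is hypothetical.

```python
import numpy as np


def vq_lookup(z: np.ndarray, codebook: np.ndarray, beta: float = 0.25):
    """Quantize rows of z (N, e_dim) against codebook (n_e, e_dim)."""
    # Squared Euclidean distance from every input to every code: (N, n_e).
    dists = ((z[:, None, :] - codebook[None, :, :]) ** 2).sum(axis=-1)
    indices = dists.argmin(axis=1)
    z_q = codebook[indices]
    # Without autograd the codebook term and the commitment term have the
    # same value; in training they differ only in where the stop-gradient sits.
    mse = float(np.mean((z_q - z) ** 2))
    loss = mse + beta * mse
    return z_q, indices, loss
```

With `beta: 0.25`, the commitment term contributes a quarter of the weight of the codebook term.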

### Available Loss Types

Set `disentangle_loss_type` in the config to one of the following:

1. **`codebook_orth`** (Recommended)
   - Uses `codebook_regularization` loss
   - Weight: 0.001

2. **`barlow_twins_codebook`** (For severe collapse)
   - Uses Barlow Twins on codebook
   - Weight: 0.005

3. **`barlow_twins_zq`** (For usage diversity)
   - Uses Barlow Twins on quantized vectors
   - Weight: 0.005
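
When scripting sweeps over these options, it can help to keep the type and its suggested weight together. The helper below is a hypothetical convenience, not part of the repository; the names and weights are copied from the list above.

```python
# Hypothetical lookup table built from the documented loss types and weights.
LOSS_TYPE_DEFAULTS = {
    "codebook_orth": {"loss": "codebook_regularization", "weight": 0.001},
    "barlow_twins_codebook": {"loss": "barlow_twins_codebook", "weight": 0.005},
    "barlow_twins_zq": {"loss": "barlow_twins_zq", "weight": 0.005},
}


def quantizer_overrides(loss_type: str) -> dict:
    """Return the two config keys to set for a given disentangle_loss_type."""
    if loss_type not in LOSS_TYPE_DEFAULTS:
        valid = ", ".join(sorted(LOSS_TYPE_DEFAULTS))
        raise ValueError(f"Unknown loss type {loss_type!r}; expected one of: {valid}")
    return {
        "disentangle_loss_type": loss_type,
        "disentangle_loss_weight": LOSS_TYPE_DEFAULTS[loss_type]["weight"],
    }
```

The returned dictionary holds the values to place under the quantizer's `init_args` in the YAML file.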

## Dependencies

Install the required packages:

```bash
pip install torch torchvision lightning
pip install einops omegaconf
pip install lpips scikit-image scipy
```
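
A quick way to confirm the environment is complete is to check that each package can be found without importing it. This is a small standalone sketch using only the standard library; the module list is derived from the `pip` commands above, where `scikit-image` is imported as `skimage`.

```python
import importlib.util

# Import names corresponding to the packages installed above.
REQUIRED_MODULES = [
    "torch", "torchvision", "lightning", "einops",
    "omegaconf", "lpips", "skimage", "scipy",
]


def missing_modules(names=REQUIRED_MODULES):
    """Return the subset of top-level module names that cannot be located."""
    return [name for name in names if importlib.util.find_spec(name) is None]
```

Calling `missing_modules()` returns an empty list when everything is installed.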

## Integration Points

### 1. Quantizer Module
**File**: `taming/modules/vqvae/simvq.py`
- Contains the `SimVQ`, `VQ`, and `IBQ` classes
- Unified loss function: `compute_codebook_regularization_loss()`

### 2. Model Wrapper
**File**: `taming/models/vq.py`
- VQ-VAE model that wraps the quantizer
- Handles training loop, EMA, logging
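
For readers unfamiliar with EMA, the update it refers to is usually a per-parameter running average. The sketch below shows the standard rule on plain Python floats; the function name and the `decay` default are assumptions and are not taken from `taming/modules/ema.py`.

```python
def ema_update(shadow, params, decay=0.999):
    """Return new shadow weights after one exponential-moving-average step."""
    # Each shadow value moves a fraction (1 - decay) toward the live parameter.
    return [decay * s + (1.0 - decay) * p for s, p in zip(shadow, params)]
```

The shadow weights, rather than the live ones, are typically the ones used for evaluation.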

### 3. Training Script
**File**: `main.py`
- Lightning-based training
- Supports distributed training
- Automatic checkpointing

### 4. Loss Function
**File**: `taming/modules/losses/vqperceptual.py`
- Combines the reconstruction, GAN, and quantization losses
- Handles discriminator updates

## Differences from icml_submission_code/

This directory (`icml/`) contains the **complete training codebase**, while `icml_submission_code/` contains **only the core quantizer implementation** for paper submission.

| Aspect | icml_submission_code/ | icml/ |
|--------|----------------------|-------|
| **Purpose** | Paper supplementary code | Complete training code |
| **Files** | 10 files (~100 KB) | 30+ files (several MB) |
| **Content** | Core quantizers only | Full training pipeline |
| **Dependencies** | Minimal (torch, einops) | Full (lightning, data, metrics) |
| **Runnable** | Examples only | Full training |

## Quick Start

### 1. Prepare Data

```bash
# ImageNet should be organized as:
# /path/to/imagenet/
#   ├── train/
#   │   ├── n01440764/
#   │   └── ...
#   └── val/
#       ├── n01440764/
#       └── ...
```
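
Before launching a long run, the layout above can be sanity-checked with a few lines of standard-library Python. This is an illustrative helper, not a script shipped with the repository; it only verifies that `train/` and `val/` exist and contain at least one class subdirectory each.

```python
from pathlib import Path


def check_imagenet_layout(root):
    """Return a list of problems found under root; an empty list means it looks right."""
    problems = []
    root = Path(root)
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            problems.append(f"missing directory: {split_dir}")
            continue
        # Each class (e.g. n01440764) is expected to be its own subdirectory.
        class_dirs = [p for p in split_dir.iterdir() if p.is_dir()]
        if not class_dirs:
            problems.append(f"no class subdirectories in: {split_dir}")
    return problems
```

For example, `check_imagenet_layout("/path/to/imagenet")` returns `[]` when both splits are in place.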

### 2. Update Config

Edit `configs/imagenet_simvq_128_B.yaml`:
- Set the data path to your ImageNet root
- Adjust the batch size to fit your hardware
- Set the output directory

### 3. Train

```bash
python main.py --config configs/imagenet_simvq_128_B.yaml
```

### 4. Evaluate

```bash
python evaluation.py \
    --config_file results/simvq/<exp_name>/config.yaml \
    --ckpt_path results/simvq/<exp_name>/last.ckpt
```

## Key Files Explanation

### Core Implementation
- **`taming/modules/vqvae/simvq.py`** ⭐
  - SimVQ quantizer with unified loss function
  - **UPDATED**: replaced with the version from `icml_submission_code/`
  - Lines 30-146: Unified regularization loss
  - Lines 170-346: SimVQ class

### Model Architecture
- **`taming/modules/diffusionmodules/improved_model.py`**
  - Encoder and Decoder architectures
  - Based on VQGAN architecture

### Training
- **`main.py`**
  - Lightning-based training script
  - Handles distributed training, checkpointing, logging

- **`taming/models/vq.py`**
  - VQ-VAE model wrapper
  - Integrates encoder, decoder, quantizer, and loss

### Loss Functions
- **`taming/modules/losses/vqperceptual.py`**
  - Combines the reconstruction (including perceptual), adversarial (GAN), and quantization losses
  - Manages generator and discriminator updates

### Evaluation
- **`evaluation.py`**
  - Computes reconstruction metrics (PSNR, SSIM, LPIPS, FID)
  - Saves results
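
As a reference point for the simplest of these metrics, PSNR can be computed directly from the mean squared error. The sketch below is a generic NumPy version and is not taken from `evaluation.py`; the `max_value` default assumes images scaled to `[0, 1]`.

```python
import numpy as np


def psnr(reference, reconstruction, max_value=1.0):
    """Peak signal-to-noise ratio in dB between two same-shaped images."""
    ref = np.asarray(reference, dtype=np.float64)
    rec = np.asarray(reconstruction, dtype=np.float64)
    mse = np.mean((ref - rec) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(max_value ** 2 / mse))
```

Higher is better for PSNR and SSIM, while lower is better for LPIPS and FID.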

## Hyperparameters

### Recommended Settings (ImageNet 128x128)

```yaml
# Quantizer
n_e: 16384
e_dim: 128
beta: 0.25
disentangle_loss_type: "codebook_orth"
disentangle_loss_weight: 0.001

# Training
learning_rate: 4.5e-6
batch_size: 256
max_epochs: 100
warmup_epochs: 5
scheduler_type: "linear-warmup_cosine-decay"

# Architecture
z_channels: 128
resolution: 128
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 1, 2, 2, 4]
num_res_blocks: 2
```
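
The `linear-warmup_cosine-decay` schedule named above can be pictured with the following sketch, which uses the listed `learning_rate`, `warmup_epochs`, and `max_epochs` as defaults. It is an assumption-laden illustration: the scheduler in `lr_scheduler.py` may step per iteration rather than per epoch and may use a nonzero floor, and `min_lr` and the function name are hypothetical.

```python
import math


def lr_at_epoch(epoch, base_lr=4.5e-6, warmup_epochs=5, max_epochs=100, min_lr=0.0):
    """Learning rate for a given epoch under linear warmup then cosine decay."""
    if epoch < warmup_epochs:
        # Linear ramp that reaches base_lr on the last warmup epoch.
        return base_lr * (epoch + 1) / warmup_epochs
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = (epoch - warmup_epochs) / max(1, max_epochs - warmup_epochs)
    cosine = 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))
    return min_lr + (base_lr - min_lr) * cosine
```

Under these assumptions the rate peaks at `4.5e-6` once warmup ends and decays to `min_lr` by `max_epochs`.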

## Notes

1. **Updated simvq.py**: `taming/modules/vqvae/simvq.py` is the updated version with the unified regularization loss function.
2. **Complete Pipeline**: The directory includes every component needed for training and evaluation.
3. **Configuration-based**: Experiments are controlled through the YAML files in `configs/`.
4. **Distributed Training**: Multi-GPU training is supported via Lightning.

## Citation

If you use this code, please cite:

```bibtex
@inproceedings{simvq2026,
  title={SimVQ: Simplified Vector Quantization with Learnable Projection},
  author={[Your Name]},
  booktitle={International Conference on Machine Learning (ICML)},
  year={2026}
}
```

---

**Last Updated**: 2026-01-29  
**Version**: 2.0 (with unified loss function)  
**Key Update**: taming/modules/vqvae/simvq.py replaced with improved version

