# ICML Training Code - File List and Description

## Overview

This directory contains the **complete training codebase** for the SimVQ method.

**Total Files**:
- 30 Python files
- 2 YAML configuration files
- 1 README
- 1 requirements.txt

---

## Core Files (Must Read)

### 1. ⭐ taming/modules/vqvae/simvq.py (UPDATED)
**Status**: ✅ Replaced with the `icml_submission_code` version

**Contains**:
- `SimVQ` class: Proposed method with projection layer
- `VQ` class: Standard VQ baseline
- `IBQ` class: Index Backpropagation Quantization baseline
- `compute_codebook_regularization_loss()`: **Unified regularization function**

**Key Sections**:
- Lines 30-146: Unified regularization loss function
  - `codebook_regularization`: Orthogonal regularization (Gram matrix)
  - `barlow_twins_codebook`: Barlow Twins on codebook (covariance)
  - `barlow_twins_zq`: Barlow Twins on z_q
- Lines 170-346: SimVQ implementation
- Lines 353-466: VQ implementation
- Lines 473-625: IBQ implementation
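
To make the Gram-matrix idea behind `codebook_regularization` concrete, here is a minimal pure-Python sketch. It is not the code in `simvq.py`: the function name `orthogonal_regularization`, the L2 normalization, and the averaging over all entries are assumptions chosen for illustration.

```python
import math


def orthogonal_regularization(codebook):
    """Mean squared deviation of the Gram matrix from the identity.

    `codebook` is a list of code vectors (lists of floats).
    Hypothetical helper; the normalization and reduction are assumptions.
    """
    # L2-normalize each code so the Gram matrix holds cosine similarities.
    normed = []
    for vec in codebook:
        norm = math.sqrt(sum(x * x for x in vec)) or 1.0  # avoid division by zero
        normed.append([x / norm for x in vec])

    k = len(normed)
    total = 0.0
    for i in range(k):
        for j in range(k):
            gram_ij = sum(a * b for a, b in zip(normed[i], normed[j]))
            target = 1.0 if i == j else 0.0  # identity matrix entry
            total += (gram_ij - target) ** 2
    return total / (k * k)
```

The penalty is zero when all codes are mutually orthogonal and grows as codes become more similar to each other.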

**What Changed**:
- ✅ Replaced: multiple separate loss functions → single unified `compute_codebook_regularization_loss()`
- ✅ Replaced: `compute_orthogonal_loss_codebook()` → `loss_type="codebook_regularization"`
- ✅ Added: `barlow_twins_codebook` and `barlow_twins_zq` options
- ✅ Removed: Multi-group functionality (simplified)
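
The `barlow_twins_codebook` and `barlow_twins_zq` options both refer to a Barlow Twins style decorrelation term. The sketch below shows one plausible form of such a term, penalizing off-diagonal entries of the feature correlation matrix. The function name `barlow_twins_offdiag`, the `eps` value, and the exact reduction are assumptions, not taken from `simvq.py`.

```python
import math


def barlow_twins_offdiag(vectors, eps=1e-8):
    """Sum of squared off-diagonal entries of the feature correlation matrix.

    `vectors` is a list of equal-length vectors (codebook entries or z_q samples).
    Hypothetical helper for illustration only.
    """
    n = len(vectors)
    d = len(vectors[0])

    # Standardize each feature dimension across the n vectors.
    cols = []
    for j in range(d):
        col = [v[j] for v in vectors]
        mean = sum(col) / n
        var = sum((x - mean) ** 2 for x in col) / n
        std = math.sqrt(var + eps)
        cols.append([(x - mean) / std for x in col])

    # Penalize correlation between every pair of distinct dimensions.
    loss = 0.0
    for a in range(d):
        for b in range(d):
            if a != b:
                corr = sum(x * y for x, y in zip(cols[a], cols[b])) / n
                loss += corr ** 2
    return loss
```

Passing codebook entries would correspond to the codebook variant and passing quantized latents to the `z_q` variant, under the assumptions above.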

---

## Training Files

### 2. main.py
**Purpose**: Main training script

**Features**:
- Lightning-based training
- Distributed training support
- Automatic checkpointing
- WandB logging support

**Usage**:
```bash
python main.py --config configs/imagenet_simvq_128_B.yaml
```

### 3. evaluation.py
**Purpose**: Model evaluation script

**Metrics**:
- Reconstruction: PSNR, SSIM, LPIPS
- Generation: FID, Inception Score
- Codebook: Utilization, Effective Rank

**Usage**:
```bash
python evaluation.py \
    --config_file <config.yaml> \
    --ckpt_path <checkpoint.ckpt>
```
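
For reference, two of the metrics listed above have simple textbook definitions. The sketch below is a pure-Python illustration, not the code in `evaluation.py`; the function names and the `max_val=1.0` pixel range are assumptions.

```python
import math


def codebook_utilization(indices, n_e):
    """Fraction of the n_e codebook entries selected at least once."""
    return len(set(indices)) / n_e


def psnr(original, reconstruction, max_val=1.0):
    """Peak signal-to-noise ratio in dB for two equal-length pixel sequences."""
    mse = sum((a - b) ** 2 for a, b in zip(original, reconstruction)) / len(original)
    if mse == 0:
        return float("inf")  # identical inputs
    return 10.0 * math.log10(max_val ** 2 / mse)
```

Here `indices` stands for the flat list of code indices collected over the evaluation set.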

---

## Configuration Files

### 4. configs/imagenet_simvq_128_B.yaml
**Purpose**: SimVQ training configuration for ImageNet 128×128

**Key Settings**:
```yaml
# Quantizer
n_e: 16384
e_dim: 128
disentangle_loss_type: "codebook_orth"
disentangle_loss_weight: 0.001

# Training
learning_rate: 4.5e-6
batch_size: 256
max_epochs: 100
```

### 5. configs/imagenet_vq_128_B.yaml
**Purpose**: VQ baseline configuration

**Difference from SimVQ**:
- Uses `VQ` class instead of `SimVQ`
- No projection layer
- Same regularization options available
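
The difference between the two quantizers can be sketched in a few lines. This sketch assumes that the SimVQ projection layer is a linear map applied to the codebook entries before the nearest-neighbor lookup. The helper names `nearest_code` and `project` are hypothetical and do not come from `simvq.py`.

```python
def nearest_code(z, codebook):
    """Index of the code with the smallest squared Euclidean distance to z."""
    best_idx, best_dist = 0, float("inf")
    for idx, code in enumerate(codebook):
        dist = sum((a - b) ** 2 for a, b in zip(z, code))
        if dist < best_dist:
            best_idx, best_dist = idx, dist
    return best_idx


def project(codebook, weight):
    """Multiply every code (row vector) by the square matrix `weight`."""
    columns = list(zip(*weight))
    return [[sum(c * w for c, w in zip(code, col)) for col in columns]
            for code in codebook]


codebook = [[1.0, 0.0], [0.0, 1.0]]
z = [0.9, 0.1]

# VQ baseline: look up z directly in the codebook.
vq_index = nearest_code(z, codebook)

# SimVQ-style (assumed): look up z in the linearly projected codebook.
weight = [[0.0, 1.0], [1.0, 0.0]]  # toy projection that swaps the two axes
simvq_index = nearest_code(z, project(codebook, weight))
```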

---

## Model Architecture Files

### 6. taming/models/vq.py
**Purpose**: VQ-VAE model wrapper

**Components**:
- Integrates encoder, decoder, quantizer
- Handles EMA (Exponential Moving Average)
- Manages training/validation loops
- Computes codebook statistics

**Key Methods**:
- `encode()`: Encodes images to quantized latents
- `decode()`: Decodes quantized latents to images
- `forward()`: Full encode-quantize-decode pipeline
- `training_step()`: Training iteration
- `validation_step()`: Validation iteration

### 7. taming/modules/diffusionmodules/improved_model.py
**Purpose**: Encoder and Decoder architectures

**Architecture**:
- Based on VQGAN encoder/decoder
- Residual blocks with attention
- Multi-scale design

**Classes**:
- `Encoder`: Downsamples images to latent space
- `Decoder`: Upsamples latents to images

---

## Loss Functions

### 8. taming/modules/losses/vqperceptual.py
**Purpose**: Combined perceptual and adversarial loss

**Components**:
- Perceptual loss (LPIPS)
- Adversarial loss (PatchGAN discriminator)
- Adaptive weighting

**Formula**:
```
L_total = L_recon + λ_perceptual * L_LPIPS + λ_adv * L_GAN + L_quant
```
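
As a rough illustration of the formula and the adaptive weighting, the sketch below follows the VQGAN-style heuristic of scaling the adversarial term by the ratio of gradient norms. Whether `vqperceptual.py` uses exactly these constants is an assumption, and both function names are hypothetical.

```python
def adaptive_adv_weight(rec_grad_norm, gan_grad_norm, base_weight=1.0, max_weight=1e4):
    """Scale the adversarial loss so its gradient is comparable to the reconstruction gradient.

    Both norms are assumed to be measured at the decoder's last layer.
    """
    weight = rec_grad_norm / (gan_grad_norm + 1e-4)  # small constant avoids division by zero
    weight = min(max(weight, 0.0), max_weight)       # clamp to [0, max_weight]
    return base_weight * weight


def total_loss(l_recon, l_lpips, l_gan, l_quant, lambda_perceptual, lambda_adv):
    """L_total = L_recon + λ_perceptual * L_LPIPS + λ_adv * L_GAN + L_quant."""
    return l_recon + lambda_perceptual * l_lpips + lambda_adv * l_gan + l_quant
```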

### 9. taming/modules/losses/lpips.py
**Purpose**: LPIPS perceptual loss implementation

---

## Discriminator

### 10. taming/modules/discriminator/model.py
**Purpose**: PatchGAN discriminator

**Architecture**:
- Multi-scale discriminator
- Convolutional layers
- Patch-based discrimination

---

## Data Loading

### 11. taming/data/imagenet.py
**Purpose**: ImageNet dataset loader

**Features**:
- Train/val split
- Data augmentation
- Multi-resolution support

### 12. taming/data/base.py
**Purpose**: Base dataset class

---

## Utilities

### 13. taming/modules/ema.py
**Purpose**: Exponential Moving Average for model weights

**Why**: Keeps a smoothed copy of the model weights, which typically gives more stable and better evaluation results
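
The update rule behind an exponential moving average is short enough to sketch directly. This is a generic illustration on plain floats, not the code in `ema.py`; the function name `ema_update` and the default `decay` are assumptions.

```python
def ema_update(shadow, params, decay=0.999):
    """Update the averaged weights in place: shadow = decay * shadow + (1 - decay) * param."""
    for name, value in params.items():
        shadow[name] = decay * shadow[name] + (1.0 - decay) * value
    return shadow
```

A `decay` close to 1 makes the averaged weights change slowly, so they lag behind the raw weights but fluctuate less.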

### 14. taming/modules/util.py
**Purpose**: Utility functions

### 15. taming/modules/scheduler/lr_scheduler.py
**Purpose**: Learning rate schedulers

**Schedulers**:
- Linear warmup
- Cosine decay
- Linear warmup + cosine decay
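
The combined schedule can be written as a single function of the step index. This is a generic sketch, not the code in `lr_scheduler.py`; the function name and the `min_lr` parameter are assumptions.

```python
import math


def warmup_cosine_lr(step, base_lr, warmup_steps, total_steps, min_lr=0.0):
    """Linear warmup to base_lr, then cosine decay to min_lr."""
    if step < warmup_steps:
        # Linear ramp that reaches base_lr on the last warmup step.
        return base_lr * (step + 1) / warmup_steps
    # Fraction of the decay phase completed, clipped to [0, 1].
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

A function like this can be turned into a multiplier for a framework scheduler by dividing the result by `base_lr`.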

---

## Evaluation Metrics

### 16. metrics/inception.py
**Purpose**: InceptionV3 for FID computation

### 17. metrics/fid.py
**Purpose**: Fréchet Inception Distance computation

---

## File Change Summary

### Modified Files (from original)
- ✅ **taming/modules/vqvae/simvq.py**: Replaced with updated version
  - Unified loss function
  - Added Barlow Twins
  - Removed multi-group

### Unmodified Files (from original)
All other files remain unchanged from the original training codebase:
- Model architecture files
- Data loading files
- Loss function files (except simvq.py)
- Training scripts
- Evaluation scripts

---

## Usage Examples

### Example 1: Train SimVQ with Codebook Regularization
```bash
python main.py --config configs/imagenet_simvq_128_B.yaml
```

Config setting:
```yaml
disentangle_loss_type: "codebook_orth"
disentangle_loss_weight: 0.001
```

### Example 2: Train SimVQ with Barlow Twins
```bash
# First, edit configs/imagenet_simvq_128_B.yaml and set:
#   disentangle_loss_type: "barlow_twins_codebook"
#   disentangle_loss_weight: 0.005

python main.py --config configs/imagenet_simvq_128_B.yaml
```

### Example 3: Train VQ Baseline
```bash
python main.py --config configs/imagenet_vq_128_B.yaml
```

### Example 4: Evaluate Trained Model
```bash
python evaluation.py \
    --config_file results/simvq/<exp_name>/config.yaml \
    --ckpt_path results/simvq/<exp_name>/last.ckpt \
    --image_size 128 \
    --batch_size 64
```

---

## Verification

### Check simvq.py was updated correctly:
```bash
# Check if unified loss function exists
grep -n "compute_codebook_regularization_loss" taming/modules/vqvae/simvq.py

# Expect a match at the definition (grep -n prints `line:text`), e.g.:
# 30:def compute_codebook_regularization_loss(...)

# Check for Barlow Twins support
grep -n "barlow_twins" taming/modules/vqvae/simvq.py

# Should see multiple matches
```

### Verify file structure:
```bash
tree -L 2 -I '__pycache__|*.pyc'
```

---

## Troubleshooting

### Issue: Import errors
**Solution**: Make sure every package directory contains an `__init__.py` file. The command below lists the ones that exist:
```bash
find . -name __init__.py
```
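
The `find` command only lists the `__init__.py` files that are present. To list the directories where one is missing, a small script along these lines may help. The function name `packages_missing_init` is hypothetical, and the script only checks directories that directly contain `.py` files.

```python
from pathlib import Path


def packages_missing_init(root):
    """Return sorted directories under root that hold .py files but no __init__.py."""
    root = Path(root)
    missing = set()
    for py_file in root.rglob("*.py"):
        pkg_dir = py_file.parent
        if pkg_dir == root:
            continue  # top-level scripts such as main.py need no __init__.py
        if not (pkg_dir / "__init__.py").exists():
            missing.add(str(pkg_dir.relative_to(root)))
    return sorted(missing)


if __name__ == "__main__":
    for directory in packages_missing_init("."):
        print("missing __init__.py in", directory)
```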

### Issue: Config file errors
**Solution**: Check that the paths in the YAML files match your setup

### Issue: simvq.py not found
**Solution**: Verify that the file is at `taming/modules/vqvae/simvq.py`

---

## Next Steps

1. ✅ Code is ready in `icml/` directory
2. ✅ simvq.py has been updated with unified loss
3. ⏭️ Update data paths in config files
4. ⏭️ Run training
5. ⏭️ Evaluate results

---

**Directory**: `/mnt/petrelfs/zhangfang/SimVQ/icml/`  
**Status**: ✅ Ready for training  
**Key Update**: simvq.py with unified regularization loss

