# Aligning Latent Spaces with Flow Priors

This repository contains the implementation of *Aligning Latent Spaces with Flow Priors*. It has three main components: toy examples that demonstrate the method, autoencoder training, and MAR (Masked Autoregressive) model training.

The codebase has been anonymized on a best-effort basis.

## Toy Examples

### Quick Start

Train a flow model on synthetic 2D distributions and visualize the alignment process:

```bash
python continuous_tokenizer/toy_examples/flow_train_then_guide.py --distribution grid --model_path models/flow_model_grid_of_gaussians/flow_model.pt
python continuous_tokenizer/toy_examples/visualize.py
```
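
As a rough illustration of what training and sampling a flow model on 2D points can involve, the sketch below assumes a flow-matching objective with a straight-line path between a noise sample and a data sample. This is an assumption: the actual objective in `flow_train_then_guide.py` may differ. The function names `flow_matching_pair` and `euler_sample` are hypothetical and not part of this repository.

```python
def flow_matching_pair(x0, x1, t):
    """Build one training example for a linear-path flow (assumed objective).

    x0 is a noise sample, x1 a data sample, t a time in [0, 1].
    Returns the interpolated point x_t and the velocity the model should predict.
    """
    x_t = tuple((1.0 - t) * a + t * b for a, b in zip(x0, x1))
    v_target = tuple(b - a for a, b in zip(x0, x1))  # constant along the path
    return x_t, v_target


def euler_sample(velocity_fn, x0, num_steps=100):
    """Integrate dx/dt = velocity_fn(x, t) from t=0 to t=1 with Euler steps."""
    x = tuple(x0)
    dt = 1.0 / num_steps
    for k in range(num_steps):
        t = k * dt
        v = velocity_fn(x, t)
        x = tuple(xi + dt * vi for xi, vi in zip(x, v))
    return x
```

In a real training loop, a network would regress `v_target` at `(x_t, t)`, and `euler_sample` would then be called with the trained network as `velocity_fn`.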

### Supported Distributions

The toy example supports various 2D synthetic distributions:
- **Grid of Gaussians**: 5-component Mixture of Gaussians arranged on a circle
- **Spiral**: Spiral-shaped data distribution
- **Moons**: Two interleaving half-moon shapes
- **Concentric Rings**: Nested circular distributions
- **Swiss Roll**: 2D projection of the classic Swiss roll manifold
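
To make a few of the shapes listed above concrete, here is a minimal standard-library sketch of samplers for moons, concentric rings, and a spiral. The function names, radii, noise levels, and number of turns are illustrative assumptions; the distributions in `flow_train_then_guide.py` may be parameterized differently.

```python
import math
import random


def sample_moons(n, noise=0.05, rng=random):
    """Two interleaving half-moons (same geometry as sklearn's make_moons)."""
    points = []
    for i in range(n):
        theta = rng.uniform(0.0, math.pi)
        if i % 2 == 0:
            # Upper moon: half circle centred at the origin.
            x, y = math.cos(theta), math.sin(theta)
        else:
            # Lower moon: flipped half circle, shifted right and down.
            x, y = 1.0 - math.cos(theta), 0.5 - math.sin(theta)
        points.append((x + rng.gauss(0.0, noise), y + rng.gauss(0.0, noise)))
    return points


def sample_rings(n, radii=(1.0, 2.0, 3.0), noise=0.05, rng=random):
    """Concentric rings: pick a radius, then a uniform angle."""
    points = []
    for _ in range(n):
        r = rng.choice(radii) + rng.gauss(0.0, noise)
        theta = rng.uniform(0.0, 2.0 * math.pi)
        points.append((r * math.cos(theta), r * math.sin(theta)))
    return points


def sample_spiral(n, turns=2.0, noise=0.05, rng=random):
    """Archimedean spiral: the radius grows linearly with the angle."""
    points = []
    max_theta = turns * 2.0 * math.pi
    for _ in range(n):
        theta = rng.uniform(0.0, max_theta)
        r = theta / max_theta  # normalised to [0, 1]
        points.append((r * math.cos(theta) + rng.gauss(0.0, noise),
                       r * math.sin(theta) + rng.gauss(0.0, noise)))
    return points
```

Passing a seeded `random.Random` instance as `rng` makes the samples reproducible, for example `sample_rings(1000, rng=random.Random(0))`.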

## Autoencoder Training

### Prerequisites

Install the required dependencies:

```bash
pip install -r continuous_tokenizer/requirements.txt
```

### Training Pipeline

#### 1. Train the Flow Head

Train the flow-based head component:

```bash
bash continuous_tokenizer/scripts/train_flowhead/flowhead-qwen-0.5B-proj.sh
```

#### 2. Train the Complete Autoencoder

Train the full autoencoder:

```bash
bash continuous_tokenizer/scripts/train/flowae-qwen-proj-head-w-1e2-l-64-50ep.sh
```

## MAR Training

### Setup

The `src/torch_fidelity` and `fid_stats` folders were removed from this repository due to size constraints. Download them from the original MAR repository before training:

```bash
# Source for the removed src/torch_fidelity and fid_stats folders:
# https://github.com/LTH14/mar
```

### Training Pipeline

#### 1. Cache the Latents

Pre-compute the autoencoder latents and cache them to disk:

```bash
bash mar/scripts/cache/flowae_qwen_proj_head_w_1e2.sh
```

#### 2. Estimate VAE Scaling Factor

Compute the normalization scaling factor for the latents:

```bash
python mar/estimate_vae_scaling.py path/to/cached/latents --num_samples 10000
```
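
For intuition, the sketch below assumes the convention used by many latent generative models, where the scaling factor is the reciprocal of the standard deviation of the latent values, so that scaled latents have roughly unit variance. This is an assumption about what `estimate_vae_scaling.py` computes, not a description of it; the function name `estimate_scaling_factor` and its in-memory input format are hypothetical.

```python
import math
import random


def estimate_scaling_factor(latents, num_samples=10000, seed=0):
    """Return 1 / std over a random subset of latent vectors.

    `latents` is a sequence of flat lists of floats, one per cached item.
    """
    rng = random.Random(seed)
    k = min(num_samples, len(latents))
    subset = rng.sample(list(latents), k)  # sample without replacement
    values = [v for latent in subset for v in latent]
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return 1.0 / math.sqrt(var)
```

Under this convention, latents would be multiplied by the factor before being fed to the generative model and divided by it again before decoding.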

#### 3. Train MAR Model

Train the MAR model on the cached latents:

```bash
bash mar/scripts/imagenet/flowae_qwen_proj_head_w_1e2.sh
```

## File Structure

```
├── flow_train_then_guide.py          # Toy example training script
├── visualize.py                      # Visualization utilities
├── continuous_tokenizer/             # Autoencoder implementation
│   ├── requirements.txt              # Dependencies
│   ├── toy_examples/                 # Toy examples
│   │   ├── flow_train_then_guide.py  # Flow model training script
│   │   └── visualize.py              # Visualization utilities
│   └── scripts/                      # Training scripts
│       ├── train_flowhead/           # Flow head training
│       └── train_tokenizer/          # Full autoencoder training
└── mar/                              # MAR implementation
    ├── estimate_vae_scaling.py       # Scaling factor estimation
    └── scripts/                      # MAR training scripts
        ├── cache/                    # Latent caching scripts
        └── imagenet/                 # ImageNet training scripts
```
