## Reproducing Training for `ImageDataYflows.py`

This guide explains how to set up the environment, prepare the data and ALAE weights, and run training and evaluation for the flow models defined in `ImageDataYflows.py`.

### 1) Project layout (expected)

Run from the project root so relative paths resolve as the script expects:

- `PROJECT_ROOT/`
  - `ALAE/`  ← original ALAE code and weights
  - `data/`  ← your .npy datasets (see below)
  - `ImageDataYflows.py`
  - `README_ImageDataYflows.md` (this file)

The script assumes:
- ALAE config: `ALAE/configs/ffhq.yaml`
- ALAE weights dir: `ALAE/training_artifacts/ffhq/`
- Data files: `data/latents.npy`, `data/gender.npy`, `data/age.npy`, `data/test_images.npy`
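A quick preflight check can catch a wrong working directory or a missing download before training starts. The sketch below only tests that the paths listed above exist. `REQUIRED_PATHS`, `OPTIONAL_PATHS` and `missing_paths` are illustrative names, not part of `ImageDataYflows.py`. The split into required and optional files follows step 4, and the `last_checkpoint` entry assumes the weights were fetched as in step 2.

```python
from pathlib import Path

# Paths from the layout above, relative to PROJECT_ROOT (illustrative list).
REQUIRED_PATHS = [
    "ALAE/configs/ffhq.yaml",
    "ALAE/training_artifacts/ffhq/last_checkpoint",
    "data/latents.npy",
    "data/gender.npy",
]
# Marked optional in step 4.
OPTIONAL_PATHS = ["data/age.npy", "data/test_images.npy"]


def missing_paths(root=".", paths=REQUIRED_PATHS):
    """Return the entries of `paths` that do not exist under `root`."""
    root = Path(root)
    return [p for p in paths if not (root / p).exists()]


if __name__ == "__main__":
    missing = missing_paths()
    if missing:
        print("Missing required files:", *missing, sep="\n  ")
    else:
        print("All required files found.")
    for p in missing_paths(paths=OPTIONAL_PATHS):
        print("Optional file not found:", p)
```

Run it from `PROJECT_ROOT`. Any path it reports points back to step 2 (ALAE) or step 4 (data).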

### 2) Get ALAE (code + weights)

- If `ALAE/` already exists in your project (as in this repo), you’re set.
- If not, clone the official repo and use it as `ALAE/` under your project root:

```bash
# From PROJECT_ROOT
git clone https://github.com/podgorskiy/ALAE.git ALAE
```

Download FFHQ pretrained weights (needed for encode/decode):

```bash
cd ALAE
python training_artifacts/download_all.py
cd ..
```

You should see files like `ALAE/training_artifacts/ffhq/model_submitted.pth` and a `last_checkpoint` file.

### 3) Environment

Python 3.8 works well. A GPU is strongly recommended.

Example with conda:

```bash
conda create -n yflows python=3.8 -y
conda activate yflows
```

Install core deps:

```bash
# PyTorch (pick the right CUDA build; adjust index-url to your CUDA)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

# Flow training dependencies
pip install numpy scipy matplotlib scikit-learn tqdm frechetdist torchdiffeq geomloss jupyter

# ALAE (inference) dependencies
pip install packaging imageio dlutils "bimpy>=0.1.1" "dareblopy>=0.0.5" yacs
```

Notes:
- If you only use ALAE to encode/decode with the provided weights, `alae_ffhq_inference.py` typically does not exercise `bimpy`/`dareblopy`. Installing them still avoids import errors from ALAE modules that reference them.
- For `TYPE == "MFM"`, the code uses `torch.func.jvp`. Use PyTorch >= 2.0.
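Before choosing `TYPE = "MFM"`, you can check the installed build against the 2.0 requirement. This is a minimal sketch. `torch_at_least` is a hypothetical helper that parses only the leading `major.minor` numbers and ignores local suffixes such as `+cu121`. The `__main__` part uses the real attributes `torch.__version__`, `torch.version.cuda` and `torch.cuda.is_available()`.

```python
import re


def torch_at_least(version: str, major: int = 2, minor: int = 0) -> bool:
    """True if a version string like '2.1.0+cu121' is >= major.minor."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return (int(match.group(1)), int(match.group(2))) >= (major, minor)


if __name__ == "__main__":
    try:
        import torch  # imported lazily so the helper itself has no torch dependency
    except ImportError:
        print("PyTorch is not installed in this environment.")
    else:
        print("torch", torch.__version__, "| CUDA build:", torch.version.cuda)
        print("CUDA available:", torch.cuda.is_available())
        print("torch.func.jvp usable (needs >= 2.0):", torch_at_least(torch.__version__))
```

A `None` CUDA build or `CUDA available: False` on a GPU machine usually means a CPU-only or mismatched wheel was installed (see Troubleshooting).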

### 4) Data for `ImageDataYflows.py`

The script loads numpy arrays from `data/`:
- `data/latents.npy`: shape `(N, D)` with ALAE latents (D = 512 for the FFHQ model)
- `data/gender.npy`: shape `(N,)` with string labels like `"male"`/`"female"`
- `data/age.npy`: shape `(N,)` (optional)
- `data/test_images.npy`: optional image tensors for visual checks
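The shape requirements above can be checked before a long training run. This is a sketch: `check_arrays` is a hypothetical helper, and `latent_dim=512` is the FFHQ ALAE latent size stated above. Change it if you use a different ALAE config.

```python
import os

import numpy as np


def check_arrays(latents, gender, age=None, latent_dim=512):
    """Return a list of problems with the training arrays (empty list if none found)."""
    if latents.ndim != 2:
        return [f"latents should be 2-D (N, D), got shape {latents.shape}"]
    problems = []
    n, d = latents.shape
    if d != latent_dim:
        problems.append(f"expected D={latent_dim}, got D={d}")
    if gender.shape != (n,):
        problems.append(f"gender should have shape ({n},), got {gender.shape}")
    if age is not None and age.shape != (n,):
        problems.append(f"age should have shape ({n},), got {age.shape}")
    return problems


if __name__ == "__main__" and os.path.exists("data/latents.npy"):
    latents = np.load("data/latents.npy")
    # allow_pickle is only needed if the labels were saved as an object array
    gender = np.load("data/gender.npy", allow_pickle=True)
    age = np.load("data/age.npy", allow_pickle=True) if os.path.exists("data/age.npy") else None
    for problem in check_arrays(latents, gender, age) or ["arrays look consistent"]:
        print(problem)
```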

You have two options:

- Option A: Use precomputed files (if you have them). Just place the `.npy` files in `data/`.
- Option B: Generate latents yourself by encoding a folder of aligned face images with ALAE.

Minimal encoder script (paths are relative to the project root):

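```python
# save as encode_images_to_latents.py and run from PROJECT_ROOT: python encode_images_to_latents.py
import os, glob, sys

import numpy as np
import torch
from PIL import Image

sys.path.append("ALAE")
from alae_ffhq_inference import load_model, encode

# Inputs
images_dir = "my_faces_folder"  # folder of aligned RGB face images under PROJECT_ROOT
out_latents = "data/latents.npy"
resolution = 1024  # assumption: the FFHQ ALAE model expects 1024x1024 inputs and the wrapper does not resize
batch_size = 8     # encode in chunks to bound memory; lower this if you run out of GPU memory

# Load ALAE (uses FFHQ config and downloaded weights)
model = load_model("ALAE/configs/ffhq.yaml", training_artifacts_dir="ALAE/training_artifacts/ffhq/")
model.eval()

# Put inputs on the model's device (falls back to CPU if the model exposes no parameters)
try:
    device = next(model.parameters()).device
except (AttributeError, StopIteration):
    device = torch.device("cpu")


def load_image(path):
    """Read one image as a CHW float32 array in [-1, 1]."""
    im = Image.open(path).convert("RGB").resize((resolution, resolution))
    arr = np.asarray(im, dtype=np.float32) / 255.0
    arr = (arr - 0.5) / 0.5              # [0, 1] -> [-1, 1]
    return np.transpose(arr, (2, 0, 1))  # HWC -> CHW


# Sorted order defines the row order of latents.npy; save labels in the same order
paths = sorted(glob.glob(os.path.join(images_dir, "*.png")) + glob.glob(os.path.join(images_dir, "*.jpg")))
if not paths:
    raise SystemExit(f"No .png/.jpg files found in {images_dir}")

chunks = []
with torch.no_grad():
    for start in range(0, len(paths), batch_size):
        x = torch.from_numpy(np.stack([load_image(p) for p in paths[start:start + batch_size]])).to(device)
        Z = encode(model, x)  # wrapper returns the latent repeated once per layer: (B, layers, D)
        chunks.append(Z[:, 0, :].cpu().numpy())  # keep a single (B, D) copy per image

latents = np.concatenate(chunks, axis=0)
os.makedirs("data", exist_ok=True)
np.save(out_latents, latents)
print("Saved latents:", latents.shape)
```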

If you have labels, save them in the same order as the latents (that is, the sorted image paths used by the encoder script):

```python
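import numpy as np, os
labels = [...]  # list/array of strings, length N (e.g., "male"/"female"), in the same order as the latents
N = np.load("data/latents.npy").shape[0]
assert len(labels) == N, f"expected {N} labels (one per latent), got {len(labels)}"
os.makedirs("data", exist_ok=True)
np.save("data/gender.npy", np.array(labels))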
```
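If your labels live in a spreadsheet keyed by filename, looking them up per image path avoids ordering mistakes. This sketch assumes a hypothetical `labels.csv` with `filename` and `gender` columns. `labels_for_images` is an illustrative helper that rebuilds the same `sorted(...)` path list as the encoder script, so its output lines up row for row with `data/latents.npy`.

```python
import csv
import glob
import os

import numpy as np


def labels_for_images(images_dir, csv_path, column="gender"):
    """Return one label per image, in the same sorted order the encoder script uses."""
    with open(csv_path, newline="") as f:
        table = {row["filename"]: row[column] for row in csv.DictReader(f)}
    # Same path list and order as encode_images_to_latents.py
    paths = sorted(glob.glob(os.path.join(images_dir, "*.png")) + glob.glob(os.path.join(images_dir, "*.jpg")))
    names = [os.path.basename(p) for p in paths]
    missing = [n for n in names if n not in table]
    if missing:
        raise KeyError(f"{len(missing)} images have no label, e.g. {missing[:3]}")
    return np.array([table[n] for n in names])
```

Usage: `np.save("data/gender.npy", labels_for_images("my_faces_folder", "labels.csv"))`.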

Quick dummy labels (for a smoke test only):

```python
# Creates random male/female labels aligned with latents.npy length
import numpy as np, os
N = np.load("data/latents.npy").shape[0]
labels = np.where(np.random.rand(N) < 0.5, "male", "female")
os.makedirs("data", exist_ok=True)
np.save("data/gender.npy", labels)
```

### 5) How to run `ImageDataYflows.py`

The file contains IPython magics (`%load_ext`, `%autoreload`), so run it in Jupyter or IPython rather than plain `python`. From the project root:

```bash
conda activate yflows
jupyter lab  # or jupyter notebook
```

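Then create a notebook in the project root and paste the script into its cells. JupyterLab opens a plain `.py` file in its text editor, so running it cell by cell directly requires an extension such as Jupytext.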

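Alternatively, run it with IPython. IPython only parses `%` magics in files with an `.ipy` extension and executes `.py` files as plain Python, so if the magics appear literally as `%load_ext ...`, run a renamed copy:

```bash
cp ImageDataYflows.py ImageDataYflows.ipy
ipython -i ImageDataYflows.ipy
```

If the magics are instead written as `get_ipython().run_line_magic(...)` calls, `ipython -i ImageDataYflows.py` works as-is.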

If you want to run with plain `python`, remove or comment the `%load_ext`/`%autoreload` lines at the top.
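Instead of deleting the magic lines, you can replace them with a guard that only runs them inside IPython or Jupyter, so the same file also works under plain `python`. This is a sketch. It assumes the script's top uses the common `%load_ext autoreload` / `%autoreload 2` pair (the `2` argument is an assumption). `enable_autoreload` is a hypothetical helper, while `IPython.get_ipython` and `run_line_magic` are IPython's own API.

```python
def enable_autoreload(mode: str = "2") -> bool:
    """Turn on IPython's autoreload if running inside IPython/Jupyter.

    Returns True if the magics were applied, False under plain `python`.
    """
    try:
        from IPython import get_ipython
    except ImportError:  # IPython is not installed at all
        return False
    shell = get_ipython()  # None when not running inside an IPython shell
    if shell is None:
        return False
    shell.run_line_magic("load_ext", "autoreload")
    shell.run_line_magic("autoreload", mode)
    return True


enable_autoreload()
```

Autoreload only matters for interactive editing, so skipping it under plain `python` does not change training.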

### 6) Training knobs inside the script

Edit these in `ImageDataYflows.py`:
- `TYPE`: one of `"FM"`, `"MFM"`, `"YF"`
- `TRAIN_MODE`: one of `"GENERATION"`, `"TRASLATION"` (keep this spelling to match the code)
- Optimization: `n_iters`, `batch_size`, `lr`, `hidden`
- Y-flows: `STEPS`, `LAMBDA_SINKHORN`, `BETA`
- Branched FM: `k_centers`, `angle_gamma`, `angle_t_free`, `angle_tau`
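Because a misspelled mode string may only surface later in a run, a small guard placed right after these knobs can fail fast. This is a hypothetical check (`check_knobs` is not in the script). The allowed values are copied from the list above, including the `TRASLATION` spelling, and the positivity test is a basic sanity assumption.

```python
VALID_TYPES = {"FM", "MFM", "YF"}
# Spelled exactly as in ImageDataYflows.py
VALID_TRAIN_MODES = {"GENERATION", "TRASLATION"}


def check_knobs(type_, train_mode, n_iters, batch_size, lr):
    """Raise ValueError for knob values the script is not set up to handle."""
    if type_ not in VALID_TYPES:
        raise ValueError(f"TYPE must be one of {sorted(VALID_TYPES)}, got {type_!r}")
    if train_mode not in VALID_TRAIN_MODES:
        hint = " (the code spells it 'TRASLATION')" if train_mode == "TRANSLATION" else ""
        raise ValueError(f"TRAIN_MODE must be one of {sorted(VALID_TRAIN_MODES)}{hint}, got {train_mode!r}")
    if n_iters <= 0 or batch_size <= 0 or lr <= 0:
        raise ValueError("n_iters, batch_size and lr must all be positive")
```

Usage, placed right after the knobs: `check_knobs(TYPE, TRAIN_MODE, n_iters, batch_size, lr)`.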

Hardware:
- The script already selects `device = "cuda" if torch.cuda.is_available() else "cpu"`. Make sure your GPU has enough memory for the chosen `batch_size` and `hidden`, and reduce them if you run out of memory.

### 7) Outputs

The script writes under a timestamped directory:
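- `SAVE_DIR` has the form `{TYPE}/{TRAIN_MODE}/YYYY-mm-dd_HH:MM:SS` (note that `:` is not a valid character in Windows file names)
- Periodically saves model weights as `model_iter_*.pth`, plus `{TYPE}.pth` (e.g. `FM.pth`) with the latest weights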
- PCA trajectory plots: `traj_pca_lines_*.png`
- Generated or translated image grids: `ode_*.png`
- `args.txt` with training logs
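To pick up the most recent run without copying the timestamp by hand, you can parse the folder names. This is a sketch. `latest_run` is a hypothetical helper, and `TIMESTAMP_FORMAT` is an assumption that mirrors the `YYYY-mm-dd_HH:MM:SS` pattern above. Whether `{TYPE}.pth` holds a `state_dict` or a whole model depends on how the script saves it, so check that before calling `torch.load` on it.

```python
from datetime import datetime
from pathlib import Path

# Assumed to match the run-folder names (YYYY-mm-dd_HH:MM:SS).
TIMESTAMP_FORMAT = "%Y-%m-%d_%H:%M:%S"


def latest_run(type_, train_mode, root="."):
    """Return the newest <TYPE>/<TRAIN_MODE>/<timestamp> folder, or None if there is none."""
    base = Path(root) / type_ / train_mode
    if not base.is_dir():
        return None
    runs = []
    for child in base.iterdir():
        if not child.is_dir():
            continue
        try:
            runs.append((datetime.strptime(child.name, TIMESTAMP_FORMAT), child))
        except ValueError:
            continue  # not a timestamped run folder
    return max(runs)[1] if runs else None


if __name__ == "__main__":
    run = latest_run("FM", "TRASLATION")
    if run is None:
        print("No runs found.")
    else:
        print("Latest run:", run)
        print("Image grids:", sorted(p.name for p in run.glob("ode_*.png")))
        print("Latest weights file:", run / "FM.pth")
```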

### 8) Quick smoke test

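After setup and data are in place, edit these values near the top of `ImageDataYflows.py` before launching. With `ipython -i`, the script runs to completion before the prompt appears, so changing them in the session afterwards does not affect the run. In Jupyter, edit them in the cell before running it.
- Set `TYPE = "FM"` and `TRAIN_MODE = "TRASLATION"`
- Reduce `n_iters` to `2000` and `batch_size` to `64` for a quick run

Then launch it:

```bash
# From project root
conda activate yflows
ipython -i ImageDataYflows.py
```

If IPython reports a `SyntaxError` on a `%` magic line, run a copy renamed to `.ipy` instead, since IPython only parses magics in `.ipy` files. Let it run until it prints metrics and saves a few images into `FM/TRASLATION/<timestamp>/`.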

### 9) Troubleshooting

- Import errors referencing ALAE modules: ensure `ALAE/` exists at project root and you run from project root (the code adds `ALAE` to `sys.path`).
- CUDA/torch mismatch: install a torch build that matches your CUDA version. See the "Get Started" page on pytorch.org for the right wheel.
- `torch.func.jvp` missing: upgrade to PyTorch >= 2.0 or switch `TYPE` to `"FM"`/`"YF"`.
- Empty label subsets in class-conditional sampling: ensure your `gender.npy` contains the labels that the script maps (it prints the label set on start).
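To see which labels `gender.npy` actually contains before the script maps them, you can print a small report. This is a sketch. `label_report` is a hypothetical helper, and the expected pair `("male", "female")` comes from step 4. If the script maps different strings (for example, different capitalization), pass those instead.

```python
import numpy as np


def label_report(labels, expected=("male", "female")):
    """Count each distinct label and list expected labels that never occur."""
    values, counts = np.unique(np.asarray(labels).astype(str), return_counts=True)
    found = {str(v): int(c) for v, c in zip(values, counts)}
    absent = [lab for lab in expected if lab not in found]
    return found, absent
```

Usage: `label_report(np.load("data/gender.npy", allow_pickle=True))`. A non-empty second value means at least one class-conditional subset will be empty.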
