# FlowLet: Conditional 3D Brain MRI Synthesis using Wavelet Flow Matching

## Overview

FlowLet is a conditional generative framework that synthesizes high-fidelity, age-conditioned 3D brain Magnetic Resonance Images (MRIs). It is designed to address key challenges in neuroimaging, particularly for applications like **Brain Age Prediction (BAP)**, which require large, diverse, and age-balanced datasets.

Existing generative methods often struggle with the high dimensionality of MRI data, leading to slow inference, compression artifacts, or insufficient conditioning. FlowLet mitigates the "generative modeling trilemma" (the trade-off between sample quality, diversity, and speed) by integrating **Flow Matching (FM)** within an invertible **3D Haar wavelet domain**. This approach preserves fine anatomical details and supports diversity in age-specific synthesis without the artifacts of learned compression, while allowing the model to generate samples in very few steps.
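The following minimal NumPy sketch shows what an "invertible 3D Haar wavelet domain" means in practice: a single-level, orthonormal 3D Haar transform and its inverse. It is illustrative only. The helper names are hypothetical, the sketch assumes even spatial dimensions, and the subband ordering and normalization used by the FlowLet code are not confirmed here.

```python
import numpy as np


def _haar_split(x, axis):
    """Split one axis into low-pass (average) and high-pass (difference) halves."""
    even = np.take(x, np.arange(0, x.shape[axis], 2), axis=axis)
    odd = np.take(x, np.arange(1, x.shape[axis], 2), axis=axis)
    return (even + odd) / np.sqrt(2), (even - odd) / np.sqrt(2)


def _haar_merge(lo, hi, axis):
    """Invert _haar_split by re-interleaving even and odd samples along the axis."""
    even = (lo + hi) / np.sqrt(2)
    odd = (lo - hi) / np.sqrt(2)
    shape = list(lo.shape)
    shape[axis] *= 2
    out = np.empty(shape, dtype=lo.dtype)
    idx = [slice(None)] * lo.ndim
    idx[axis] = slice(0, None, 2)
    out[tuple(idx)] = even
    idx[axis] = slice(1, None, 2)
    out[tuple(idx)] = odd
    return out


def haar_dwt3d(x):
    """(D, H, W) volume -> (8, D/2, H/2, W/2) subbands; D, H, W must be even."""
    bands = [x]
    for axis in range(3):
        bands = [b for band in bands for b in _haar_split(band, axis)]
    return np.stack(bands)  # order: LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH


def haar_idwt3d(coeffs):
    """(8, D/2, H/2, W/2) subbands -> (D, H, W) volume, undoing axes in reverse order."""
    bands = list(coeffs)
    for axis in reversed(range(3)):
        bands = [_haar_merge(bands[i], bands[i + 1], axis) for i in range(0, len(bands), 2)]
    return bands[0]
```

The transform is orthonormal, so a 112×112×112 volume maps to eight 56×56×56 subbands with the same number of values and the same energy. The inverse recovers the volume exactly, up to floating-point round-off.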

This repository contains the official PyTorch implementation for the paper: *"FlowLet: Conditional 3D Brain MRI Synthesis using Wavelet Flow Matching"*.

<p align="center">
  <!-- architecture diagram -->
  <img src="assets/FlowLet_Architecture.png" width="700" alt="FlowLet Architecture Diagram">
</p>

## Key Features

*   **Wavelet-Based Flow Matching:** Implements multiple Flow Matching formulations (RFM, CFM, VP, Trigonometric) in the 3D Haar wavelet domain for efficient and stable generative modeling.
*   **Advanced Conditional Synthesis:** Generates 3D brain MRIs conditioned on specified variables (e.g., Age) using a dual mechanism:
    *   **FiLM (Feature-wise Linear Modulation)** in residual blocks for global, feature-wise control.
    *   **Spatial Conditioning** in transformer blocks for spatially-aware, fine-grained anatomical conditioning with Cross-Attention.
*   **High-Fidelity 3D Output:** Designed to generate high-resolution volumetric NIfTI (`.nii.gz`) brain images that preserve anatomical coherence.
*   **Efficient and Modular U-Net:** Employs a robust 3D U-Net architecture built with modular blocks, with optional support for `xformers` for memory-efficient attention.
*   **Comprehensive & Reproducible Workflow:** Includes scripts for the entire pipeline: data preparation, training, generation, and quantitative evaluation (FID, MMD, MS-SSIM).
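The feature list mentions a FiLM layer. The NumPy sketch below shows the idea: a linear map turns a condition embedding (e.g., a normalized age) into a per-channel scale `gamma` and shift `beta`, which then modulate a 3D feature map. The `FiLM` class is a hypothetical illustration, not the repository's module. How the age is embedded, and whether the modulation is written `gamma * h + beta` or `(1 + gamma) * h + beta`, are assumptions.

```python
import numpy as np


class FiLM:
    """Per-channel affine modulation h -> gamma(c) * h + beta(c) driven by a condition vector c."""

    def __init__(self, cond_dim, num_channels, seed=0):
        rng = np.random.default_rng(seed)
        # One linear map produces both gamma and beta; the bias starts at gamma=1, beta=0 (identity).
        self.weight = rng.standard_normal((cond_dim, 2 * num_channels)) * 0.02
        self.bias = np.concatenate([np.ones(num_channels), np.zeros(num_channels)])

    def __call__(self, h, c):
        # h: (batch, channels, D, H, W) feature map; c: (batch, cond_dim) condition embedding.
        gamma, beta = np.split(c @ self.weight + self.bias, 2, axis=1)
        # Broadcast (batch, channels) parameters over the spatial axes.
        expand = (slice(None), slice(None)) + (None,) * (h.ndim - 2)
        return gamma[expand] * h + beta[expand]
```

Because the same `gamma` and `beta` are applied at every voxel, FiLM gives global, feature-wise control. Spatially varying conditioning needs a different mechanism, such as the cross-attention used in the transformer blocks.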

## Flow Matching Formulations

FlowLet supports several flow matching strategies, allowing for a systematic evaluation of how trajectory curvature impacts training stability and synthesis quality. Each formulation defines a different path and target velocity field between noise and data.

*   **Rectified Flow Matching (RFM):** Performs a simple linear interpolation between noise and data. The straight-line path and constant velocity field promote stable training and produce high-quality, coherent anatomical structures.
*   **Conditional Flow Matching (CFM):** Also uses a linear path, but defines a time-dependent target velocity that points from the current state `x_t` to the data `x_1`.
*   **Variance-Preserving (VP) Diffusion Matching:** Defines a non-linear, curved path inspired by Denoising Diffusion Probabilistic Models (DDPMs), governed by a variance schedule.
*   **Trigonometric Flow:** Uses a circular interpolation path on a unit half-circle, introducing smooth, curved trajectories with a constant norm.
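Each of the four paths can be sketched as an interpolant `x_t` between a noise sample `x0` (at `t = 0`) and a data sample `x1` (at `t = 1`), paired with the target velocity a model would regress. This NumPy sketch makes several assumptions and is not the repository's implementation. The VP schedule (linear β between `beta_min` and `beta_max`) and the angle parameterization of the trigonometric path (`θ = πt/2`) are common choices assumed here. `eps` is a hypothetical clamp that avoids division by zero near `t = 1`.

```python
import numpy as np


def rectified(x0, x1, t):
    """RFM: straight line from noise to data with a constant target velocity."""
    xt = (1 - t) * x0 + t * x1
    return xt, x1 - x0


def cfm(x0, x1, t, eps=1e-5):
    """CFM (as described above): linear path, target points from x_t toward x1."""
    xt = (1 - t) * x0 + t * x1
    return xt, (x1 - xt) / max(1 - t, eps)


def trigonometric(x0, x1, t):
    """Circular path: the coefficients satisfy cos^2 + sin^2 = 1 for every t."""
    theta = np.pi / 2 * t
    xt = np.cos(theta) * x0 + np.sin(theta) * x1
    return xt, (np.pi / 2) * (np.cos(theta) * x1 - np.sin(theta) * x0)


def vp_diffusion(x0, x1, t, beta_min=0.1, beta_max=20.0, eps=1e-5):
    """DDPM-style variance-preserving path; s = 1 - t plays the role of diffusion time."""
    s = 1 - t
    alpha = np.exp(-0.5 * (beta_min * s + 0.5 * (beta_max - beta_min) * s ** 2))
    sigma = np.sqrt(max(1 - alpha ** 2, eps))  # alpha^2 + sigma^2 = 1 (clamped at t -> 1)
    beta = beta_min + (beta_max - beta_min) * s
    dalpha_dt = 0.5 * beta * alpha             # chain rule with ds/dt = -1
    dsigma_dt = -alpha * dalpha_dt / sigma     # differentiate alpha^2 + sigma^2 = 1
    xt = alpha * x1 + sigma * x0
    return xt, dalpha_dt * x1 + dsigma_dt * x0
```

In this simplified form, the CFM target `(x1 - x_t) / (1 - t)` equals `x1 - x0` algebraically on the linear path, so it coincides with the RFM target. The repository's CFM variant may include terms not shown here.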

### How to Use
You can select the desired formulation during training by using the `--flow_type` command-line argument. The available choices are `rectified`, `cfm`, `vp_diffusion`, and `trigonometric`. See the `training_ablation.sh` script for an example of how to launch training for all variants.

```bash
# Example of training with Conditional Flow Matching (CFM)
python scripts/train.py \
    --flow_type cfm \
    ... # other arguments
```

## Installation

### 1. Data Availability

Due to patient privacy regulations and strict data use agreements, we cannot redistribute the 3D MRI scans. To use the same training data, researchers must apply for access to the original datasets, which are made available to the research community upon request.

The datasets used are:
*   **OpenBHB:** [https://baobablab.github.io/bhb/dataset](https://baobablab.github.io/bhb/dataset)
*   **ADNI:** [https://adni.loni.usc.edu/](https://adni.loni.usc.edu/)
*   **OASIS-3:** [https://sites.wustl.edu/oasisbrains/](https://sites.wustl.edu/oasisbrains/)


To ensure precise replication, we provide the exact list of subjects and scans used for training in the file: `Dataset_preparation/metadata/main_dataset_catalog.csv`. After obtaining the data, you can use this catalog to construct the identical dataset cohort.

### 2. Prerequisites

The framework was developed using Python 3.11 and CUDA 12.0.

### 3. Create a Conda Environment

```bash
conda create -n flowlet_env python=3.11
conda activate flowlet_env
```

### 4. Install Dependencies

```bash
pip install -r requirements.txt
```

### 5. Install xFormers (Optional, Recommended for Efficiency)

Follow the official instructions at the [xFormers GitHub repository](https://github.com/facebookresearch/xformers#installing-xformers).

```bash
# This command often works, but may vary based on your setup
pip install -U xformers
```

## Data Preparation

**IMPORTANT:** Raw T1-weighted MRI volumes must be preprocessed with the standardized pipeline provided in the `MRI_preprocessing` folder before running any experiments or evaluations.

The training script supports two data-loading methods. The metadata CSV approach is recommended: it is more robust and was used for the paper's experiments.

### Method 1: Metadata CSV (Recommended)

1.  **Structure:** Prepare a single CSV file containing metadata for your entire preprocessed dataset. This CSV must include a column with the **absolute path** to each NIfTI file and columns for all conditions you wish to use (e.g., `Age`, `Condition`).
2.  **Create the CSV:** You can use the provided script to generate this file from a directory of NIfTI files that have been pre-named to include age information.
    ```bash
    # Example from Dataset_preparation/create_metadata_csv.sh
    PYTHONPATH=. python3 Dataset_preparation/create_metadata_csv.py \
        --input_dirs /path/to/your/nifti/data \
        --output_csv ./metadata/main_dataset_catalog.csv \
        --condition_label CN
    ```
3.  **Training with CSV:** During training, point to this file using the `--metadata_csv` argument. You can also filter the dataset using `--csv_filter_col` and `--csv_filter_value`.
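The behavior described above can be sketched with Python's standard `csv` module: one absolute path per row, numeric condition columns, and an optional `--csv_filter_col`/`--csv_filter_value` filter. This is an illustration, not the trainer's code. The `load_catalog` helper and the `path` column name are hypothetical, and the demo rows are made up.

```python
import csv
import io
import os


def load_catalog(csv_file, path_col, condition_vars, filter_col=None, filter_value=None):
    """Return (absolute_path, {condition: value}) pairs, optionally filtered on one column."""
    rows = []
    for row in csv.DictReader(csv_file):
        # Mirrors --csv_filter_col / --csv_filter_value: keep only matching rows.
        if filter_col is not None and row[filter_col] != filter_value:
            continue
        path = row[path_col]
        if not os.path.isabs(path):
            raise ValueError(f"expected an absolute path, got {path!r}")
        rows.append((path, {var: float(row[var]) for var in condition_vars}))
    return rows


# Hypothetical demo catalog; real catalogs come from create_metadata_csv.py.
demo = io.StringIO(
    "path,Age,Condition\n"
    "/data/sub-001.nii.gz,65.3,CN\n"
    "/data/sub-002.nii.gz,71.0,MCI\n"
)
print(load_catalog(demo, "path", ["Age"], filter_col="Condition", filter_value="CN"))
```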

### Method 2: Filename Parsing

As a simpler alternative, you can point the trainer directly to a folder of NIfTI files.
1.  **Dataset Structure:** Place all `.nii.gz` files in a single directory.
2.  **Filename Convention:** The script extracts conditions from filenames, so each filename must encode its condition values, e.g., an `_AGE_` tag followed by the age. The default regex looks for `[_-]AGE[_-]([0-9.]+)`.
    *   **Examples:** `subject001_AGE_65.3.nii.gz`, `sub-002-age-22.0.nii.gz`
3.  **Training with Folder:** Use the `--data_folder` argument in `scripts/train.py`.
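Applying the default pattern to the example filenames shows two details worth checking. First, `sub-002-age-22.0.nii.gz` is lowercase, so this sketch assumes case-insensitive matching; whether the repository uses `re.IGNORECASE` is not confirmed. Second, the character class `[0-9.]` also consumes the dot that begins `.nii.gz`, so the sketch strips trailing dots before converting. The `parse_age` helper is hypothetical.

```python
import re

# Default pattern from this README; re.IGNORECASE is an assumption inferred from the
# lowercase example "sub-002-age-22.0.nii.gz".
AGE_PATTERN = re.compile(r"[_-]AGE[_-]([0-9.]+)", re.IGNORECASE)


def parse_age(filename):
    """Return the age encoded in a filename, or None if no age tag is present."""
    match = AGE_PATTERN.search(filename)
    if match is None:
        return None
    # "[0-9.]+" captures "65.3." from "..._AGE_65.3.nii.gz", so drop trailing dots.
    return float(match.group(1).rstrip("."))


print(parse_age("subject001_AGE_65.3.nii.gz"), parse_age("sub-002-age-22.0.nii.gz"))
```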

## Training

The main training script is `scripts/train.py`.

**Example Command (using Metadata CSV):**
```bash
mkdir -p logs  # the log redirect below fails if this directory does not exist
PYTHONPATH=. nohup python3 -u scripts/train.py \
    --metadata_csv ./metadata/main_dataset_catalog.csv \
    --run_name "FlowLet_RFM_Training" \
    --flow_type rectified \
    --condition_vars Age \
    --csv_filter_col Condition \
    --csv_filter_value CN \
    --epochs 200 \
    --batch_size 4 \
    --lr 3e-6 \
    --model_input_size 112 112 112 \
    --save_size 91 109 91 \
    --unet_model_channels 128 \
    --unet_channel_mult "1,2,4,8" \
    --unet_attention_res "4,8" \
    --wandb \
    --wandb_project "FlowLet_Project" > logs/training_rfm.log 2>&1 &
```

### Key Training Arguments
*   `--metadata_csv`: Path to your metadata CSV file (recommended).
*   `--data_folder`: Path to your NIfTI dataset folder (alternative).
*   `--run_name`: A unique name for this experiment. Checkpoints and logs will be saved in `checkpoints_flowlet/<run_name>`.
*   `--flow_type`: The Flow Matching formulation to use. Choices: `rectified`, `cfm`, `vp_diffusion`, `trigonometric`.
*   `--condition_vars`: List of conditions to use (must be columns in the CSV or parseable from filenames).
*   `--model_input_size`: The spatial size to which images are padded before the DWT. Must be divisible by `2^(num_downsampling_layers)`.
*   `--save_size`: The final spatial size to crop generated images to after the IDWT.
*   `--unet_*`: Arguments to configure the U-Net architecture (channels, attention resolutions, etc.).
*   `--wandb`: Enable Weights & Biases logging.
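The sketch below gives a quick sanity check for `--model_input_size`. It rests on an unconfirmed reading of the architecture: the DWT halves each dimension once, and the U-Net then downsamples `len(channel_mult) - 1` times. Under that reading, `--unet_channel_mult "1,2,4,8"` requires sizes that are multiples of `2 ** (1 + 3) = 16`. The centered zero-padding and cropping are also assumptions, and all helper names are hypothetical.

```python
import numpy as np


def required_multiple(num_dwt_levels=1, channel_mult="1,2,4,8"):
    """Assumed divisibility: one halving per DWT level plus one per U-Net downsampling."""
    num_unet_downsamples = len(channel_mult.split(",")) - 1
    return 2 ** (num_dwt_levels + num_unet_downsamples)


def pad_to(volume, target):
    """Centered zero-padding to `target`; returns the padded volume and per-axis pad widths."""
    pads = []
    for size, goal in zip(volume.shape, target):
        total = goal - size
        if total < 0:
            raise ValueError(f"target {goal} is smaller than size {size}")
        pads.append((total // 2, total - total // 2))
    return np.pad(volume, pads), pads


def crop_center(volume, target):
    """Inverse of pad_to: take the centered `target` region (as --save_size would)."""
    slices = tuple(slice((s - t) // 2, (s - t) // 2 + t) for s, t in zip(volume.shape, target))
    return volume[slices]
```

Under these assumptions, the `91 109 91` volumes from the example command pad to `112 112 112` (a multiple of 16) and crop back to `--save_size` without loss.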

## Generating Samples

### 1. Generating with Linearly Interpolated Ages
Use `scripts/generate_linear.py` to generate a sequence of samples where age is varied smoothly. This is ideal for visualization and for creating large, age-balanced datasets for downstream tasks.

```bash
PYTHONPATH=. python3 -u scripts/generate_linear.py \
    --checkpoint_path checkpoints_flowlet/<run_name>/fmw_best.pth \
    --output_dir ./generated_samples/<run_name>/linear_age \
    --condition_ranges_path checkpoints_flowlet/<run_name>/condition_ranges.json \
    --min_age 5.9 \
    --max_age 95.5 \
    --num_total_samples 3000 \
    --num_flow_steps 10 \
    --save_size 91 109 91
```
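`--num_flow_steps` sets how many ODE integration steps are taken from noise (`t = 0`) to data (`t = 1`). The generation scripts' solver is not stated here. Assuming a fixed-step Euler integrator, sampling reduces to the loop below. `velocity_fn` stands in for the trained, age-conditioned network and is hypothetical. The `ages` line shows one plausible reading of `--min_age`, `--max_age`, and `--num_total_samples` as evenly spaced targets.

```python
import numpy as np


def euler_sample(velocity_fn, x0, num_steps):
    """Integrate dx/dt = v(x, t) from t=0 (noise) to t=1 (data) with fixed-step Euler."""
    x = x0.copy()
    dt = 1.0 / num_steps
    for k in range(num_steps):
        t = k * dt
        x = x + dt * velocity_fn(x, t)
    return x


# Evenly spaced target ages, one per generated sample (endpoints included).
ages = np.linspace(5.9, 95.5, 3000)
```

When the learned paths are nearly straight, as RFM encourages, each Euler step stays close to the true trajectory. That is why a small step count such as `--num_flow_steps 10` can be sufficient.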

### 2. Generating for Specific Conditions
Use `scripts/generate.py` to create a fixed number of samples for specific, discrete conditions.

```bash
PYTHONPATH=. python3 -u scripts/generate.py \
    --checkpoint_path checkpoints_flowlet/<run_name>/fmw_best.pth \
    --output_dir ./generated_samples/<run_name>/specific \
    --condition_ranges_path checkpoints_flowlet/<run_name>/condition_ranges.json \
    --generation_conditions "Age=45" "Age=70.5" \
    --num_synthetic 50 \
    --save_size 91 109 91
```

### Advanced: Generating Ablation Samples
The `scripts/generate_linear.py` script contains a `generation_modes` dictionary for generating samples with specific conditioning mechanisms disabled: FiLM only, spatial conditioning (cross-attention) only, or fully unconditional. This is the exact tool used to produce the samples for the ablation study in the paper.

---
# Full Evaluation Experiments
## Reproducing the Quantitative Evaluation

This section details how to perform the quantitative evaluation of generated samples, as reported in the paper. The evaluation is designed to assess image fidelity, distributional similarity, and sample diversity.

### 1. The Evaluation Script

The primary script for this task is `Evaluation_metrics/Evaluation_FID_MMD_MSSIM.py`. This script handles the calculation of all key metrics, including age-stratified analysis and statistical significance testing.

### 2. How it Works

The script compares one or more directories of generated 3D NIfTI files against a directory of real NIfTI files. It calculates three main metrics:

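1.  **Fréchet Inception Distance (FID):** Measures the similarity between the distributions of real and generated images in a deep feature space (here, the 3D Medical ResNet-50 features described below rather than Inception features). A lower FID indicates that the generated distribution is closer to the real one in both realism and diversity.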
2.  **Maximum Mean Discrepancy (MMD):** An alternative metric to FID for comparing distributions in a feature space, using a Gaussian kernel. A lower MMD indicates better similarity.
3.  **Multi-Scale Structural Similarity (MS-SSIM):** Measures the structural similarity between pairs of images within the *generated set*. A lower MS-SSIM score indicates higher diversity among the generated samples (i.e., the model is not collapsing to a single mode).
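The two distribution metrics can be written compactly over `(N, D)` activation arrays, as in the NumPy sketch below. It is not the evaluation script's code. The kernel bandwidth `sigma` and the biased (V-statistic) MMD estimator are assumptions; the script may use a different bandwidth rule or an unbiased estimator. The matrix square root uses the fact that `Tr((Σa Σb)^{1/2})` equals the sum of square roots of the eigenvalues of `Σa Σb`, which are non-negative for covariance matrices.

```python
import numpy as np


def frechet_distance(feats_a, feats_b):
    """FID-style distance between Gaussians fitted to two (N, D) activation sets (D >= 2)."""
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    cov_a = np.cov(feats_a, rowvar=False)
    cov_b = np.cov(feats_b, rowvar=False)
    # Eigenvalues of cov_a @ cov_b are real and >= 0 up to round-off, so clip before sqrt.
    eigvals = np.linalg.eigvals(cov_a @ cov_b)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    mean_term = ((mu_a - mu_b) ** 2).sum()
    return float(mean_term + np.trace(cov_a) + np.trace(cov_b) - 2.0 * tr_sqrt)


def gaussian_mmd2(x, y, sigma=1.0):
    """Biased squared MMD with k(a, b) = exp(-||a - b||^2 / (2 sigma^2)); O(N*M*D) memory."""
    def kernel(a, b):
        sq_dists = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
        return np.exp(-sq_dists / (2.0 * sigma ** 2))

    return float(kernel(x, x).mean() + kernel(y, y).mean() - 2.0 * kernel(x, y).mean())
```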

The evaluation pipeline follows these key steps:

1.  **Feature Extractor:** A pre-trained **3D Medical ResNet-50** is loaded. This network acts as a feature extractor, converting each 3D MRI into a high-dimensional feature vector (an "activation").
2.  **Data Normalization:** A global intensity normalization is calculated by sampling a subset of the real data. This ensures that both real and generated images are processed with a consistent intensity range.
3.  **Activation Calculation:** The script processes all real images and all generated images (from each provided directory) through the feature extractor to get their corresponding sets of feature vectors.
4.  **Metric Calculation:**
    *   **FID/MMD:** The script compares the set of real activations against the set of generated activations. To ensure stable and reliable results, it uses a **bootstrapping** procedure: it repeatedly takes random subsamples from both activation sets, calculates FID and MMD on these subsamples, and then reports the **mean** and **standard deviation** of these scores over all iterations.
    *   **MS-SSIM:** The script calculates the pairwise MS-SSIM between a large number of randomly sampled pairs from the generated dataset to assess intra-set diversity.
5.  **Age-Stratified Analysis:** If the filenames contain `_AGE_` tags, the script automatically groups both real and generated samples into predefined age bands (e.g., 15-30, 40-55, 65-80) and repeats the metric calculations for each band. This provides a fine-grained understanding of model performance across different demographics.
6.  **Statistical Testing:** If more than one directory of generated samples is provided, the script automatically treats the first directory as the baseline. It performs a **Wilcoxon rank-sum test** on the distributions of bootstrapped FID/MMD scores to determine if the differences between models are statistically significant, applying a Bonferroni correction to account for multiple comparisons.
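Steps 4 and 6 can be sketched as follows: bootstrap a metric on random subsets, then compare each model's bootstrap distribution with the baseline using SciPy's `scipy.stats.ranksums`. This is an illustration under assumptions, not the script's code. Subsets are drawn without replacement with a fixed `subsample_size`. The Bonferroni factor is taken as the number of non-baseline models, whereas the script may also count metrics or age bands. `metric_fn` can be any function of two activation arrays.

```python
import numpy as np
from scipy.stats import ranksums


def bootstrap_scores(real_feats, gen_feats, metric_fn, num_bootstraps, subsample_size, seed=0):
    """Recompute metric_fn on random subsets (drawn without replacement) of both sets."""
    rng = np.random.default_rng(seed)
    scores = []
    for _ in range(num_bootstraps):
        r = real_feats[rng.choice(len(real_feats), size=subsample_size, replace=False)]
        g = gen_feats[rng.choice(len(gen_feats), size=subsample_size, replace=False)]
        scores.append(metric_fn(r, g))
    return np.asarray(scores)  # report scores.mean() and scores.std()


def compare_to_baseline(scores_by_model):
    """Wilcoxon rank-sum test of each model against the first (baseline), Bonferroni-corrected."""
    names = list(scores_by_model)
    baseline = scores_by_model[names[0]]
    num_tests = len(names) - 1
    corrected = {}
    for name in names[1:]:
        _, p_value = ranksums(baseline, scores_by_model[name])
        corrected[name] = min(1.0, float(p_value) * num_tests)
    return corrected
```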

### 3. How to Run the Evaluation

You can run the full evaluation pipeline using the `Evaluation_metrics/Evaluation_FID_MMD_MSSSIM.sh` script.

**Steps:**

1.  Open the `Evaluation_metrics/Evaluation_FID_MMD_MSSSIM.sh` script.
2.  Set the `--real_dir` argument to the path of your directory containing the real NIfTI dataset.
3.  Set the `--gen_dirs` argument to a space-separated list of paths to the directories containing your generated NIfTI samples. **The first path in this list will be treated as the baseline for statistical comparisons.**
4.  Ensure the `--medical_resnet_path` points to the pre-trained weights file (`resnet_50_epoch_110_batch_0.pth`).
5.  Specify the `--output_csv` path where the final results table will be saved.
6.  (Optional) Adjust parameters like `--max_fid_samples`, `--max_ssim_samples`, and `--num_bootstraps` to balance computational cost and precision.
7.  Execute the script from the project's root directory:
    ```bash
    bash Evaluation_metrics/Evaluation_FID_MMD_MSSSIM.sh
    ```
    
---

## Reproducing the Wavelet Selection Ablation Study

This section details how to reproduce the wavelet selection analysis reported in the paper. The goal of this analysis is to empirically determine which wavelet basis is most suitable for our task by measuring its reconstruction fidelity.

### 1. The Analysis Script

This analysis is performed by the script located at `Evaluation_wavelet_ablations/calculate_wavelet_errors.py`.

### 2. How it Works

The script systematically evaluates the "round-trip" reconstruction error for a list of different wavelet families. For each real 3D MRI in the dataset and for each specified wavelet (e.g., 'haar', 'db4', 'sym4'):

1.  **Load Data:** A 3D NIfTI file is loaded into a NumPy array.
2.  **Forward DWT:** The script performs a 3D Discrete Wavelet Transform on the image data using the selected wavelet.
3.  **Inverse DWT:** It immediately performs an Inverse DWT on the resulting coefficients to reconstruct the image.
4.  **Calculate Error:** It computes the **Mean Absolute Error (MAE)** between the original image and the reconstructed image. This MAE value quantifies how much information was lost or altered during the DWT/IDWT round-trip process. A lower MAE indicates higher fidelity.
5.  **Aggregate Results:** After processing all files in the dataset, the script calculates the **mean** and **standard deviation** of the MAE scores for each wavelet family.
6.  **Report:** The final results are sorted by the mean MAE (lowest error first), printed to the console, and saved to a CSV file. The wavelet with the lowest mean MAE is considered the best choice for preserving anatomical information.
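The round-trip procedure above maps directly onto PyWavelets' `pywt.dwtn` and `pywt.idwtn`. The sketch below assumes a single-level transform and PyWavelets' default `symmetric` extension mode; the analysis script's actual settings are not confirmed here, and the helper names are hypothetical. With odd dimensions, the reconstruction can be one voxel longer along an axis, so it is cropped back before computing the MAE.

```python
import numpy as np
import pywt


def roundtrip_mae(volume, wavelet, mode="symmetric"):
    """Single-level 3D DWT followed by IDWT; returns the mean absolute reconstruction error."""
    coeffs = pywt.dwtn(volume, wavelet, mode=mode)   # dict of 8 subbands ('aaa', ..., 'ddd')
    recon = pywt.idwtn(coeffs, wavelet, mode=mode)
    recon = recon[tuple(slice(0, s) for s in volume.shape)]  # crop odd-size overshoot
    return float(np.abs(volume - recon).mean())


def rank_wavelets(volumes, wavelets):
    """(wavelet, mean MAE, std MAE) rows sorted by mean MAE, lowest error first."""
    rows = []
    for wavelet in wavelets:
        maes = [roundtrip_mae(v, wavelet) for v in volumes]
        rows.append((wavelet, float(np.mean(maes)), float(np.std(maes))))
    return sorted(rows, key=lambda row: row[1])
```

In this float64 sketch, orthogonal wavelets reconstruct to within round-off. Differences between families therefore depend on numerical details such as the data type and extension mode.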

### 3. How to Run the Analysis

You can run the analysis using the `Evaluation_wavelet_ablations/calculate_wavelet_errors.sh` script, which provides a ready-to-use command.

**Steps:**

1.  Open the `Evaluation_wavelet_ablations/calculate_wavelet_errors.sh` script.
2.  Modify the `--input_dir` to point to the directory containing your real NIfTI dataset.
3.  (Optional) Modify the `--wavelets` list to test different wavelet families. The default list in the script matches the one used for the paper's ablation.
4.  Specify the desired `--output_csv` path for the results.
5.  (Optional) For a quick test, you can use the `--num_files` argument to limit the analysis to a small subset of your data.
6.  Execute the script from the project's root directory:
    ```bash
    bash Evaluation_wavelet_ablations/calculate_wavelet_errors.sh
    ```
---
## Reproducing the Sampling Time Benchmark

This section details how to reproduce the sampling time benchmarking for FlowLet as presented in the paper. The benchmark measures the wall-clock time required to generate a single 3D brain MRI sample for various numbers of ODE integration steps.

### 1. The Benchmarking Script

The entire process is managed by the script located at `Benchmarking_time/benchmarking_time.py`. This script is designed to provide accurate and stable timing measurements.

### 2. How it Works

The script systematically evaluates the model's sampling performance by following these steps for a predefined list of step counts (e.g., 1, 2, 5, 10, 100, 200):

1.  **Model Loading:** It loads a pre-trained model checkpoint and its associated configuration file. The model is set to evaluation mode (`model.eval()`) to disable operations like dropout.
2.  **GPU Warm-up:** Before any timing begins, the script generates a few "warm-up" samples. This is a crucial step to ensure that the GPU is fully initialized and any one-time memory allocations are completed. This prevents first-run overhead from skewing the timing results.
3.  **Timed Measurement Loop:** After the warm-up, the script generates a sequence of samples (e.g., 15) one by one. The generation time for each individual sample is measured.
4.  **Accurate Timing:** To ensure timing accuracy, especially on GPUs, the script uses the following precautions:
    *   **`torch.cuda.synchronize()`:** This command is called immediately *before* starting the timer and *after* the generation is complete. This forces the CPU to wait for all pending GPU operations to finish, ensuring that the measured time accurately reflects the full duration of the GPU workload, not just the time it took to launch the kernels.
    *   **`time.perf_counter()`:** This clock has the highest available resolution for measuring short durations, which makes it better suited to this benchmark than wall-clock functions such as `time.time()`.
5.  **Aggregation and Reporting:** After timing all samples for a given step count, the script calculates the **mean** and **standard deviation** of the timings. This provides a stable average generation time and a measure of its variability.
6.  **Results Export:** The final results (Steps, Mean Time, Standard Deviation) are printed to the console and saved to a CSV file for easy analysis and plotting.
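The timing protocol above can be sketched as a device-agnostic harness using only the standard library. `sample_fn(num_steps)` is a hypothetical callable that generates one sample. On a GPU, pass `sync=torch.cuda.synchronize` and run the sampler under `torch.no_grad()`. Warming up with the first step count and reporting the sample standard deviation (`statistics.stdev`, which needs at least two timed runs) are assumptions, not confirmed details of `benchmarking_time.py`.

```python
import csv
import statistics
import time


def benchmark(sample_fn, step_counts, num_warmup=3, num_timed=15, sync=lambda: None, csv_path=None):
    """Time sample_fn(steps) once per sample for each step count; returns one row per count."""
    # Warm-up calls absorb one-time costs (CUDA context creation, memory allocation).
    for _ in range(num_warmup):
        sample_fn(step_counts[0])
    rows = []
    for steps in step_counts:
        timings = []
        for _ in range(num_timed):
            sync()  # wait for previously queued GPU work before starting the clock
            start = time.perf_counter()
            sample_fn(steps)
            sync()  # wait for this sample's kernels to finish before stopping the clock
            timings.append(time.perf_counter() - start)
        rows.append({
            "Steps": steps,
            "Mean Time": statistics.mean(timings),
            "Standard Deviation": statistics.stdev(timings),
        })
    if csv_path is not None:
        with open(csv_path, "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=["Steps", "Mean Time", "Standard Deviation"])
            writer.writeheader()
            writer.writerows(rows)
    return rows
```

A typical call would be `benchmark(lambda n: sampler(model, n), [1, 2, 5, 10, 100, 200], sync=torch.cuda.synchronize, csv_path="timings.csv")`, where `sampler` and `model` are placeholders for your own generation code.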

---

## Other Experiments

All Brain Age Prediction experiments, regional plausibility evaluations, and their associated preprocessing steps are provided in the corresponding folders within the repository.