# SDXL Bottleneck Concept Module

A lightweight concept-learning pipeline for SDXL that injects learnable concept vectors into the UNet mid-block (bottleneck) without fine-tuning the full model.
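The injection can be sketched with a PyTorch forward hook. This sketch assumes that each concept is a learnable per-channel vector and that the vectors are summed before being added to the mid-block output. `ConceptInjector`, `attach`, and `ToyUNet` are hypothetical names, not the repository's classes. With diffusers, the hook would be registered on `pipe.unet.mid_block`, whose output has 1280 channels in SDXL.

```python
import torch
import torch.nn as nn


class ConceptInjector(nn.Module):
    """Hypothetical stand-in for BottleneckConceptModule: K learnable
    per-channel vectors whose sum is added to the mid-block output."""

    def __init__(self, num_concepts: int, channels: int):
        super().__init__()
        # Zero init: the hook acts as an identity until training updates it.
        self.concepts = nn.Parameter(torch.zeros(num_concepts, channels))

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # h has shape (B, C, H, W); broadcast the (C,) offset over B, H, W
        # and match h's dtype in case the UNet runs in fp16.
        offset = self.concepts.sum(dim=0).to(h.dtype)
        return h + offset.view(1, -1, 1, 1)


def attach(mid_block: nn.Module, injector: ConceptInjector):
    """Register a forward hook whose return value replaces mid_block's output."""

    def hook(module, args, output):
        return injector(output)

    return mid_block.register_forward_hook(hook)


class ToyUNet(nn.Module):
    """Tiny placeholder exposing a `mid_block` attribute like diffusers' UNet."""

    def __init__(self, channels: int = 8):
        super().__init__()
        self.mid_block = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mid_block(x)


torch.manual_seed(0)
unet = ToyUNet()
injector = ConceptInjector(num_concepts=2, channels=8)
x = torch.randn(1, 8, 4, 4)

with torch.no_grad():
    base = unet(x)
    handle = attach(unet.mid_block, injector)
    injector.concepts[0].fill_(1.0)  # pretend training set concept 0 to all ones
    shifted = unet(x)
    handle.remove()  # detach the hook to recover the unmodified model
    restored = unet(x)
```

Because the concepts start at zero, the hooked UNet reproduces the original output until training moves them. Calling `handle.remove()` restores the unmodified model, which is one way to render an "original" image from the same pipeline.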

## Pipeline Overview

### 1. Data Generation (`data_generation.py`)

Generates training images using SDXL with optional refiner support.

```bash
python data_generation.py \
    --prompt "a photo of a woman" \
    --output_dir datasets_SDXL_female \
    --num_samples 2000 \
    --seed 42 \
    --steps 50
```

**Outputs:** `image_XXXX.jpg` files, `labels.json`, `concept_dict.json`, and `metadata.json`.
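A sketch of the generation loop follows. It uses the diffusers SDXL base pipeline and, when requested, the refiner in the base-plus-refiner hand-off described in the diffusers SDXL documentation: the base model denoises the high-noise part of the schedule and returns latents, then the refiner finishes. This is not `data_generation.py` itself. The following details are all assumptions:

- the per-sample seed scheme
- the `high_noise_frac=0.8` split
- fp16 weights on CUDA
- the helper names

The JSON outputs are omitted because their schema is not described here.

```python
from pathlib import Path


def sample_seed(base_seed: int, index: int) -> int:
    # Hypothetical: one deterministic seed per sample, derived from --seed.
    return base_seed + index


def sample_filename(index: int) -> str:
    # image_XXXX.jpg with 4-digit zero padding (padding width assumed).
    return f"image_{index:04d}.jpg"


def generate(prompt: str, output_dir: str, num_samples: int, seed: int,
             steps: int, use_refiner: bool = False, high_noise_frac: float = 0.8) -> None:
    # Heavy imports live here so the helpers above work without a GPU stack.
    import torch
    from diffusers import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline

    base = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16, variant="fp16", use_safetensors=True,
    ).to("cuda")
    refiner = None
    if use_refiner:
        refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-refiner-1.0",
            text_encoder_2=base.text_encoder_2, vae=base.vae,  # share with base
            torch_dtype=torch.float16, variant="fp16", use_safetensors=True,
        ).to("cuda")

    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    for i in range(num_samples):
        generator = torch.Generator(device="cuda").manual_seed(sample_seed(seed, i))
        if refiner is None:
            image = base(prompt=prompt, num_inference_steps=steps,
                         generator=generator).images[0]
        else:
            # Base covers the first high_noise_frac of the schedule and
            # returns latents; the refiner denoises the remainder.
            latents = base(prompt=prompt, num_inference_steps=steps,
                           denoising_end=high_noise_frac, output_type="latent",
                           generator=generator).images
            image = refiner(prompt=prompt, num_inference_steps=steps,
                            denoising_start=high_noise_frac, image=latents).images[0]
        image.save(out_dir / sample_filename(i))
```

Calling `generate("a photo of a woman", "datasets_SDXL_female", num_samples=2000, seed=42, steps=50)` mirrors the example command above. Deriving one seed per sample keeps every image individually reproducible.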

---

### 2. Training (`train.py`)

Trains a `BottleneckConceptModule` that learns concept embeddings added to the UNet mid-block activations. Training uses a two-prompt alignment loss: the output for the neutral prompt with the concept vectors added should match the output for the explicit concept prompt.

```bash
accelerate launch train.py \
    --train_data_dir datasets_SDXL_female \
    --output_dir exps_female_sdxl \
    --train_batch_size 4 \
    --num_train_epochs 20 \
    --learning_rate 1e-2 \
    --seed 42
```

**Outputs:** `concept_final.pt` (concept module weights), `unet_clean_final.pth` (optional clean UNet), `loss_history.png`.
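The two-prompt alignment loss can be illustrated on a toy frozen network. This is a minimal sketch under several assumptions:

- The loss is an MSE between the two outputs.
- A single per-channel vector is injected right after the mid block.
- Prompts are reduced to one embedding vector each. SDXL actually uses token sequences plus pooled embeddings.
- Random tensors stand in for noisy latents.

`ToyUNet` and the prompt pair are illustrative, not taken from `train.py`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyUNet(nn.Module):
    """Hypothetical frozen stand-in for the SDXL UNet: a prompt embedding
    conditions the features, and an optional concept offset is added
    right after the mid block (the bottleneck)."""

    def __init__(self, channels: int = 8, emb_dim: int = 16):
        super().__init__()
        self.inp = nn.Conv2d(4, channels, kernel_size=1)
        self.cond = nn.Linear(emb_dim, channels)
        self.mid_block = nn.Conv2d(channels, channels, kernel_size=1)
        self.out = nn.Conv2d(channels, 4, kernel_size=1)

    def forward(self, latents, prompt_emb, concept=None):
        h = self.inp(latents) + self.cond(prompt_emb)[:, :, None, None]
        h = self.mid_block(h)
        if concept is not None:
            h = h + concept.view(1, -1, 1, 1)  # bottleneck injection
        return self.out(h)


torch.manual_seed(0)
unet = ToyUNet().requires_grad_(False)  # backbone stays frozen
concept = nn.Parameter(torch.zeros(8))  # the only trainable tensor
optimizer = torch.optim.Adam([concept], lr=1e-2)

latents = torch.randn(4, 4, 8, 8)  # stands in for noisy latents
neutral_emb = torch.randn(1, 16)   # e.g. "a photo of a person" (illustrative)
concept_emb = torch.randn(1, 16)   # e.g. "a photo of a woman" (illustrative)

with torch.no_grad():
    target = unet(latents, concept_emb)  # explicit concept prompt

losses = []
for step in range(200):
    pred = unet(latents, neutral_emb, concept)  # neutral prompt + concept
    loss = F.mse_loss(pred, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```

In this toy, the two prompts differ by a constant per-channel shift before a linear mid block, so a single offset can match the target exactly. In the real UNet, the match is only approximate.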

---

### 3. Inference (`inference.py`)

For each seed, generates four variants: original, concept-only, wavelet-only, and concept+wavelet.

```bash
python inference.py \
    --output_dir exps_female_sdxl \
    --prompt "a photo of a doctor" \
    --fp16
```

**Outputs:** Images saved under `images_out/`, one sub-directory per variant: `original/`, `concept_only/`, `wavelet_only/`, and `concept_wavelet/`.
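The four variants can be read as a 2×2 ablation over two switches rendered from the same seed: concept injection on or off, and the wavelet step on or off. Below is a minimal sketch of that enumeration, assuming this reading. `VARIANT_DIRS` and `variant_plan` are hypothetical names. The wavelet step itself is not modeled because this section does not describe it.

```python
from itertools import product

# Hypothetical mapping from (use_concept, use_wavelet) to the images_out/
# sub-directory names.
VARIANT_DIRS = {
    (False, False): "original",
    (True, False): "concept_only",
    (False, True): "wavelet_only",
    (True, True): "concept_wavelet",
}


def variant_plan(seeds):
    """Yield (seed, use_concept, use_wavelet, subdir) for every image to render.

    All four variants of a seed share that seed, so (assuming the generator
    is re-seeded per image) the initial noise is identical across them.
    """
    for seed in seeds:
        for use_concept, use_wavelet in product((False, True), repeat=2):
            yield seed, use_concept, use_wavelet, VARIANT_DIRS[(use_concept, use_wavelet)]


plan = list(variant_plan([0, 1]))
```

Keeping the seed fixed across variants attributes any visual difference to the two interventions rather than to different starting noise.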

---

## Supporting Files

| File | Description |
|------|-------------|
| `config.py` | Argument parser with training/inference hyperparameters |
| `utils_data.py` | `TrainingDataset` and dataloader utilities |
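For orientation, here is how a parser covering the flags used in the example commands might look. This is a hypothetical reconstruction, not `config.py`. The real file may define more options or different defaults, or it may use separate training and inference parsers. The defaults below are simply the values from the example commands.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical: only flags that appear in the example commands;
    # defaults copied from those commands, not from config.py.
    p = argparse.ArgumentParser(description="SDXL bottleneck concept module")
    p.add_argument("--train_data_dir", type=str, default=None)
    p.add_argument("--output_dir", type=str, required=True)
    p.add_argument("--train_batch_size", type=int, default=4)
    p.add_argument("--num_train_epochs", type=int, default=20)
    p.add_argument("--learning_rate", type=float, default=1e-2)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--prompt", type=str, default=None)
    p.add_argument("--fp16", action="store_true", help="use half precision (assumed meaning)")
    return p


# Parse the inference example's arguments without touching sys.argv.
args = build_parser().parse_args(
    ["--output_dir", "exps_female_sdxl", "--prompt", "a photo of a doctor", "--fp16"]
)
```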

## Requirements

```
torch, diffusers, transformers, accelerate, PyWavelets (imported as pywt), ruamel.yaml, tqdm, pillow
```
