# SD3.5 External Attention Heads Pipeline

A lightweight pipeline for concept learning in Stable Diffusion 3.5: it trains external attention heads without fine-tuning the base model.

---

## 1. Data Generation (`data_generation.py`)

Generates training images using SD3.5 (large or large-turbo).

**Usage:**
```python
from data_generation import SD35DataCreator, CfgTurboSingle

creator = SD35DataCreator(CfgTurboSingle)
creator.run(create_zip=True, height=1024, width=1024)
```

**Outputs:**
- `{root_dir}/0.jpg, 1.jpg, ...` — Generated images
- `{root_dir}/labels.json` — Prompt-concept pairs
- `{root_dir}/concept_dict.json` — Concept name → index mapping

**Available Configs:** `CfgLargeSingle`, `CfgLargeBatch`, `CfgLargeRace`, `CfgLargeFull`, `CfgTurboSingle`
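
The outputs listed above can be read back with the standard library alone. The following is a minimal sketch, not part of the repository: `load_generated_dataset` is a hypothetical helper, and because the internal structure of `labels.json` and `concept_dict.json` is not documented here, both files are returned exactly as parsed.

```python
import json
from pathlib import Path


def load_generated_dataset(root_dir):
    """Collect the files that data generation is documented to write.

    Hypothetical helper: only the file names come from this README.
    """
    root = Path(root_dir)
    # Images are named 0.jpg, 1.jpg, ...; sort numerically, not lexically,
    # so that 10.jpg comes after 9.jpg.
    images = sorted(
        (p for p in root.glob("*.jpg") if p.stem.isdigit()),
        key=lambda p: int(p.stem),
    )
    # The JSON schemas are not documented, so no structure is assumed.
    labels = json.loads((root / "labels.json").read_text(encoding="utf-8"))
    concept_dict = json.loads(
        (root / "concept_dict.json").read_text(encoding="utf-8")
    )
    return images, labels, concept_dict
```

Sorting by the integer stem matters once more than ten images exist, since a plain string sort would place `10.jpg` before `2.jpg`.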

---

## 2. Training (`train.py`)

Trains external attention heads to learn concept transformations (e.g., person → woman).

**Key Parameters (in `TrainConfig`):**

| Parameter | Default | Description |
|-----------|---------|-------------|
| `resolution` | 512 | Training image size |
| `batch_size` | 4 | Batch size |
| `num_epochs` | 30 | Training epochs |
| `learning_rate` | 1e-4 | Learning rate |
| `prompt_person` | "a photo of a person" | Source prompt |
| `prompt_woman` | "a photo of a woman" | Target prompt |
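
For reference, the defaults in the table can be written down as a small dataclass. This is only an illustrative mirror: `TrainConfigSketch` is a hypothetical name, and the real `TrainConfig` in `train.py` may be structured differently or carry more fields.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class TrainConfigSketch:
    """Hypothetical mirror of the documented defaults, not the real TrainConfig."""

    resolution: int = 512
    batch_size: int = 4
    num_epochs: int = 30
    learning_rate: float = 1e-4
    prompt_person: str = "a photo of a person"
    prompt_woman: str = "a photo of a woman"


# Derive a shorter run by overriding two fields; the defaults stay untouched.
quick_run = replace(TrainConfigSketch(), num_epochs=5, batch_size=2)
```

Keeping the sketch frozen and deriving variants with `replace` makes it harder to change a default by accident between runs.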

**Run:**
```bash
python train.py
```

**Outputs:**
- `external_heads_epoch_{N}.pt` — Trained head weights, tensor of shape `[layers, heads, seq_len, head_dim]`
- `head_importance_epoch_{N}.json` — Per-head gradient importance scores
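
The importance scores can help decide which heads to keep at inference time. The sketch below assumes, without confirmation from the source, that the JSON file parses to a flat mapping from head index to a numeric score; `top_heads` is a hypothetical helper, and a per-layer layout would need an extra loop.

```python
def top_heads(importance, k=4):
    """Return the k head indices with the largest importance scores.

    Assumes `importance` maps head index (str or int) to a number;
    the real file layout written by train.py may differ.
    """
    ranked = sorted(importance.items(), key=lambda item: item[1], reverse=True)
    return [int(head) for head, _ in ranked[:k]]
```

With made-up scores `{"0": 0.1, "1": 0.9, "2": 0.5}` and `k=2`, the function returns `[1, 2]`.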

---

## 3. Inference (`inference.py`)

Generates images with trained external heads applied at configurable strengths.

**Key Parameters:**

| Parameter | Default | Description |
|-----------|---------|-------------|
| `EXTERNAL_HEADS_PATH` | `./external_heads.../epoch_30.pt` | Trained checkpoint |
| `COEFFICIENT_LIST` | `[0, 10]` | Head scaling factors |
| `TARGET_HEADS` | `[9, 19, 12, 28]` | Specific heads to use (or `None` for all) |
| `IMAGE_RESOLUTION` | 512 | Must match the `resolution` used for training |

**Run:**
```bash
python inference.py
```

**Outputs:**
```
images_out/
└── seed_42/
    ├── baseline_coef_0.0.png
    ├── image_coef_0.00.png
    └── image_coef_10.00.png
```
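
The roles of `COEFFICIENT_LIST` and `TARGET_HEADS` can be illustrated with a toy example on plain lists. The additive combination below is an assumption made for illustration; `inference.py` may combine the external heads with the base attention differently, and `apply_external_heads` is a hypothetical name.

```python
def apply_external_heads(base, external, coefficient, target_heads=None):
    """Add scaled external-head outputs to the base per-head outputs.

    base, external: one vector (list of floats) per attention head.
    target_heads: head indices to modify, or None to modify all heads.
    Assumption: the contribution is additive and scaled by `coefficient`.
    """
    if target_heads is None:
        selected = set(range(len(base)))
    else:
        selected = set(target_heads)
    result = []
    for head, (b, e) in enumerate(zip(base, external)):
        # Heads outside the selection keep their base output unchanged.
        scale = coefficient if head in selected else 0.0
        result.append([bv + scale * ev for bv, ev in zip(b, e)])
    return result
```

Under this assumption a coefficient of 0 leaves the base output unchanged, which would be one reason to include 0 in a sweep as a reference point.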

---

## Requirements

```
torch
diffusers
torchvision
Pillow
tqdm
```

Set the `HF_TOKEN` environment variable to a Hugging Face access token so the gated SD3.5 weights can be downloaded.
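
A missing token usually surfaces only when the model download fails, so it can help to check for it up front. The following sketch is not part of the repository; `require_hf_token` is a hypothetical helper that only inspects the environment.

```python
import os


def require_hf_token(env=os.environ):
    """Return the Hugging Face token, or fail early with a clear message."""
    token = env.get("HF_TOKEN", "").strip()
    if not token:
        raise RuntimeError(
            "HF_TOKEN is not set; export it before running the scripts."
        )
    return token
```

Passing the environment as a parameter keeps the check easy to exercise with a plain dictionary.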
