# Look Locally, Learn Precisely: Interpretable and Unbiased Text-to-Image Generation with Background Fidelity


## 📌 Overview

We propose a training and inference framework for learning interpretable concept vectors in the **h-space** (the UNet mid-block features) of Stable Diffusion. Attention heatmaps produced by DAAM guide the training of an auxiliary MLP module injected at the UNet bottleneck, so that the modifications applied during generation stay semantically aligned with the target concept and remain interpretable.
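To make the idea of an MLP acting on h-space concrete, here is a minimal PyTorch sketch. It is not the code in `train.py`: the class name `HSpaceMLP`, the helper `attach_to_mid_block`, the hidden width, the residual form `h + delta`, and the zero initialization are all assumptions chosen for illustration. The 1280-channel, 8x8 feature shape corresponds to the Stable Diffusion v1 mid-block at 512x512 resolution.

```python
import torch
from torch import nn


class HSpaceMLP(nn.Module):
    """Hypothetical per-location MLP that predicts an additive shift in h-space."""

    def __init__(self, channels: int = 1280, hidden: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(channels, hidden),
            nn.SiLU(),
            nn.Linear(hidden, channels),
        )
        # Assumption: zero-init the last layer so the untrained module is a no-op.
        nn.init.zeros_(self.net[-1].weight)
        nn.init.zeros_(self.net[-1].bias)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # h has shape (batch, channels, height, width); the MLP acts on channels.
        delta = self.net(h.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return h + delta


def attach_to_mid_block(mid_block: nn.Module, controller: nn.Module):
    """Route the mid-block output through the controller with a forward hook."""

    def hook(module, inputs, output):
        # Returning a value from a forward hook replaces the module's output.
        return controller(output)

    return mid_block.register_forward_hook(hook)


# Stand-in for the UNet mid-block so the sketch runs without downloading weights.
mid_block = nn.Identity()
controller = HSpaceMLP()
handle = attach_to_mid_block(mid_block, controller)

h = torch.randn(2, 1280, 8, 8)
h_mod = mid_block(h)  # same shape as h; equal to h until the MLP is trained
handle.remove()       # detach the hook to restore the original behavior
```

With a real pipeline, the attachment point would presumably be the `mid_block` attribute of the diffusers `UNet2DConditionModel`; how the repository wires this in practice may differ.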


## 📁 Repository Structure

```
.
├── train.py                # Heatmap-guided training using DAAM
├── inference.py            # Inference using learned h-space modifications and DWT-based filtering
├── data-generator.py       # DAAM-based dataset generation
├── config.py               # Argument parser for all scripts
├── requirements.txt        # Required Python packages
└── ...
```
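As a rough picture of what a shared argument parser like `config.py` might look like, the sketch below declares the flags that appear in the commands of this README. The function name `build_parser`, the argument types, and the absence of defaults are assumptions; the actual `config.py` may define more options or different defaults.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Illustrative parser covering only the flags shown in this README."""
    parser = argparse.ArgumentParser(description="Shared options (sketch)")

    # Common options
    parser.add_argument("--prompt", type=str)
    parser.add_argument("--output_dir", type=str)
    parser.add_argument("--model_path", type=str)

    # Data generation
    parser.add_argument("--concepts", nargs="+")  # e.g. female male
    parser.add_argument("--num_samples", type=int)

    # Training
    parser.add_argument("--train_data_dir", type=str)
    parser.add_argument("--concept", type=str)
    parser.add_argument("--train_batch_size", type=int)
    parser.add_argument("--num_train_epochs", type=int)
    parser.add_argument("--resolution", type=int)
    parser.add_argument("--model_type", type=str)
    parser.add_argument("--center_crop", action="store_true")
    parser.add_argument("--random_flip", action="store_true")
    parser.add_argument("--learning_rate", type=float)
    parser.add_argument("--gradient_accumulation_steps", type=int)

    # Inference
    parser.add_argument("--image_dir", type=str)
    parser.add_argument("--evaluation_type", type=str)
    parser.add_argument("--num_test_samples", type=int)
    return parser


# Parse a subset of the training flags exactly as they would appear on the CLI.
train_args = build_parser().parse_args(
    [
        "--train_data_dir", "datasets",
        "--output_dir", "exps/exp_female",
        "--concept", "female",
        "--train_batch_size", "8",
        "--center_crop", "--random_flip",
        "--learning_rate", "1e-1",
    ]
)
print(train_args.concept, train_args.train_batch_size, train_args.learning_rate)
```

Keeping every flag in one parser lets the three scripts share option names, at the cost of each script accepting flags it does not use.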

## 🚀 Getting Started

### 1. Environment Setup

Create a virtual environment and install dependencies:

```bash
python3 -m venv env
source env/bin/activate
pip install -r requirements.txt
```

Make sure the environment has a GPU-enabled PyTorch build that matches your installed CUDA version.
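A quick way to confirm the GPU setup is a short check like the one below. The helper name `check_environment` is hypothetical and not part of this repository; it only reads standard PyTorch attributes and reports that PyTorch is missing instead of failing.

```python
import importlib.util
import sys


def check_environment() -> dict:
    """Report Python, PyTorch, and CUDA status without raising if torch is absent."""
    info = {
        "python": sys.version.split()[0],
        "torch": None,
        "cuda_available": False,
        "cuda_version": None,
    }
    if importlib.util.find_spec("torch") is None:
        return info  # PyTorch is not installed in this environment

    import torch

    info["torch"] = torch.__version__
    info["cuda_available"] = torch.cuda.is_available()
    info["cuda_version"] = torch.version.cuda  # None for CPU-only builds
    return info


env_info = check_environment()
for key, value in env_info.items():
    print(f"{key}: {value}")
```

If `cuda_available` is `False` on a machine with a GPU, the installed PyTorch wheel is probably a CPU-only build or targets a different CUDA version.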

### 2. Data Generation

Generate synthetic data with attention heatmaps using DAAM:

```bash
python data-generator.py \
  --prompt "a photo of a person" \
  --concepts female male \
  --output_dir ./datasets \
  --num_samples 2000 \
  --model_path CompVis/stable-diffusion-v1-4
```

This generates images together with their concept-specific attention heatmaps and saves both under `datasets/`.

### 3. Training

Train the MLP controller to learn interpretable concept vectors:

```bash
python train.py \
  --train_data_dir datasets \
  --output_dir exps/exp_female \
  --prompt "a photo of a person" \
  --concept female \
  --train_batch_size 8 \
  --num_train_epochs 20 \
  --resolution 512 \
  --model_type MLP \
  --center_crop --random_flip \
  --learning_rate 1e-1 \
  --gradient_accumulation_steps 1
```

The training script saves checkpoints and logs to `exps/exp_female/`.

### 4. Inference & Evaluation

Generate images with the trained model, applying the learned h-space modifications:

```bash
python inference.py \
  --output_dir exps/exp_female \
  --image_dir test_output \
  --prompt "a photo of a doctor in the hospital" \
  --evaluation_type eval \
  --num_test_samples 2
```

Results will be saved under `exps/exp_female/test_output/`.
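The results path is simply `--image_dir` resolved inside `--output_dir`. The sketch below shows that relationship and collects the generated files for inspection; the helper names and the assumed image extensions are illustrative and not taken from `inference.py`.

```python
from pathlib import Path


def result_dir(output_dir: str, image_dir: str) -> Path:
    """Combine --output_dir and --image_dir the way the README describes."""
    return Path(output_dir) / image_dir


def list_images(directory: Path, suffixes=(".png", ".jpg", ".jpeg")) -> list:
    """Return image files in the directory, sorted by name (extensions are an assumption)."""
    if not directory.is_dir():
        return []
    return sorted(p for p in directory.iterdir() if p.suffix.lower() in suffixes)


results = result_dir("exps/exp_female", "test_output")
print(results)
for path in list_images(results):
    print(path.name)
```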


## 📄 License

This project leverages the [Stable Diffusion model](https://github.com/CompVis/stable-diffusion) under the [CreativeML Open RAIL-M License](https://huggingface.co/spaces/CompVis/stable-diffusion-license).


## 🙏 Acknowledgments

- [Stable Diffusion](https://github.com/CompVis/stable-diffusion)
- [DAAM](https://github.com/tangkaiyang/DAAM)
