# SPRO: Self-Prompted Visual Reasoning and Optimization

This repository contains the implementation for the SPRO papers.

## Table of Contents
- [Requirements](#requirements)
- [Data Preparation](#data-preparation)
- [Workflow](#workflow)
  - [1. Data Generation](#1-data-generation)
  - [2. Direct Preference Optimization (DPO)](#2-direct-preference-optimization-dpo)
  - [3. Diffusion Model Tuning](#3-diffusion-model-tuning)
  - [4. Evaluation](#4-evaluation)
- [Iterative Training Process](#iterative-training-process)

## Requirements
```bash
pip install -r requirements.txt
```

## Data Preparation
Before starting, you need to organize the following data:
- Download the base images for the Flickr30k and Pick-a-Pic (PAP) datasets and place them in these directories:
  - `data_gen/data/base_flickr_images_train/`
  - `data_gen/data/base_flickr_images_test/`
  - `data_gen/data/base_pap_images_train/`
  - `data_gen/data/base_pap_images_test/`
- Ensure the corresponding CSV files are in the data directory:
  - `data_gen/data/flickr_test.csv`  - last 514 samples of Flickr30k (provided)
  - `data_gen/data/flickr_train.csv` - first 30,500 samples of Flickr30k
  - `data_gen/data/pap_train.csv`    - Pick-a-Pic validation set (17k samples)
  - `data_gen/data/pap_test.csv`     - Pick-a-Pic unique test set (500 samples, provided)
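Before running anything, it can help to confirm that the layout above is in place. The following is a minimal sketch, not a script shipped with this repository, that checks the directories and CSV files listed above, assuming it is run from the repository root:

```python
from pathlib import Path

# Directories and CSV files listed under "Data Preparation",
# relative to the repository root.
REQUIRED_DIRS = [
    "data_gen/data/base_flickr_images_train",
    "data_gen/data/base_flickr_images_test",
    "data_gen/data/base_pap_images_train",
    "data_gen/data/base_pap_images_test",
]
REQUIRED_CSVS = [
    "data_gen/data/flickr_test.csv",
    "data_gen/data/flickr_train.csv",
    "data_gen/data/pap_train.csv",
    "data_gen/data/pap_test.csv",
]


def missing_data(repo_root="."):
    """Return every required path that does not exist under repo_root."""
    root = Path(repo_root)
    missing = [p for p in REQUIRED_DIRS if not (root / p).is_dir()]
    missing += [p for p in REQUIRED_CSVS if not (root / p).is_file()]
    return missing


if __name__ == "__main__":
    problems = missing_data()
    for path in problems:
        print(f"missing: {path}")
    if not problems:
        print("data layout looks complete")
```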

## Workflow

### 1. Data Generation
Navigate to the data generation directory:
```bash
cd data_gen
```

#### Generate improved prompts
```bash
# Make script executable
chmod +x prompt_generation.sh
# Run prompt generation (example for aesthetics objective, experiment name "it1")
./prompt_generation.sh
```

#### Merge generated prompt files
```bash
python3 merge_output_files.py --expname "it1"
# Output: flickr_train_it1_prompts.csv
```
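The internals of `merge_output_files.py` are not documented here. As a rough illustration of what such a merge step involves, the sketch below concatenates CSV shards that share a header, using only the standard library. The function name `merge_csv_shards` and any shard file pattern you pass to it are hypothetical; check the script itself for the real input naming.

```python
import csv
import glob


def merge_csv_shards(pattern, output_path):
    """Concatenate CSV shards matching `pattern` into one CSV file.

    Shards are read in sorted filename order and must share a header.
    Returns the number of data rows written (header excluded).
    """
    shard_paths = sorted(glob.glob(pattern))
    if not shard_paths:
        raise FileNotFoundError(f"no files match {pattern!r}")
    header = None
    rows_written = 0
    with open(output_path, "w", newline="") as out_file:
        writer = csv.writer(out_file)
        for path in shard_paths:
            with open(path, newline="") as shard:
                reader = csv.reader(shard)
                shard_header = next(reader, None)
                if shard_header is None:
                    continue  # empty shard: nothing to copy
                if header is None:
                    header = shard_header
                    writer.writerow(header)
                elif shard_header != header:
                    raise ValueError(f"header mismatch in {path}")
                for row in reader:
                    writer.writerow(row)
                    rows_written += 1
    return rows_written
```

For example, `merge_csv_shards("outputs/it1/part_*.csv", "flickr_train_it1_prompts.csv")`, where the `outputs/it1/part_*.csv` pattern is purely illustrative.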

#### Generate images using improved prompts
```bash
chmod +x image_generation.sh
./image_generation.sh
# Output: flickr_train-it1-sdxl-images/
```

#### Score generated images
To score images for aesthetics and human preference:
```bash
python3 aesthetics_scorer.py --image_dir data/flickr_train-it1-sdxl-images/ \
                            --base_dir data/base_flickr_images_train/ \
                            --scores_file "flickr_train_it1_prompts.csv" \
                            --expname "it1-sdxl"
# Output: flickr_train_it1_prompts_scored.csv
```
To score for engagement:
```bash
python3 eoig_scorer.py --input-csv data/twitter_train_it1_prompts.csv \
                       --output-csv data/twitter_train_it1_prompts_Scored.csv \
                       --experiment-name it1 \
                       --model-path /path/to/model \
                       --gpu-id 0
# Output: twitter_train_it1_prompts_Scored.csv
```

### 2. Direct Preference Optimization (DPO)

From the repository root, navigate to the DPO directory:
```bash
cd dpo
```

#### Create DPO pairs
```bash
python3 get_dpo_pairs.py --input_csv "../data_gen/flickr_train_it1_prompts_scored.csv" \
                         --exp_name "it1-sdxl"
# Output: data_formats/flickr_train_it1_dpopairs.jsonl
```
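A common way to turn scored candidates into preference pairs is to group candidate prompts by their base caption and pair the highest-scoring candidate (chosen) against the lowest-scoring one (rejected). The sketch below shows that idea. It is an assumption about the general approach, not necessarily the selection rule `get_dpo_pairs.py` uses, and the column names `base_caption`, `improved_prompt`, and `score` are hypothetical. The output records use the `prompt`/`chosen`/`rejected` layout common to text preference datasets.

```python
import json
from collections import defaultdict


def build_dpo_pairs(rows, min_margin=0.0):
    """Pair the best and worst candidate prompts for each base caption.

    `rows` is an iterable of dicts with hypothetical keys
    "base_caption", "improved_prompt", and "score".
    """
    groups = defaultdict(list)
    for row in rows:
        groups[row["base_caption"]].append(
            (float(row["score"]), row["improved_prompt"])
        )
    pairs = []
    for caption, candidates in groups.items():
        if len(candidates) < 2:
            continue  # need at least two candidates to form a pair
        candidates.sort(key=lambda c: c[0])
        worst_score, worst = candidates[0]
        best_score, best = candidates[-1]
        if best_score - worst_score <= min_margin:
            continue  # skip near-ties: weak preference signal
        pairs.append({"prompt": caption, "chosen": best, "rejected": worst})
    return pairs


def write_jsonl(records, path):
    """Write one JSON object per line."""
    with open(path, "w") as f:
        for record in records:
            f.write(json.dumps(record) + "\n")
```

The `min_margin` filter is an optional design choice: pairs whose scores barely differ carry little preference signal, so dropping them can reduce noise at the cost of fewer training pairs.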

#### Format data for DPO training
```bash
python3 format_data.py data_formats/flickr_train_it1_dpopairs.jsonl \
                       data_formats/flickr_train_it1_dpopairs_llama11b
```

#### Run DPO training
```bash
accelerate launch dpo.py --dataset_name "data_formats/flickr_train_it1_dpopairs_llama11b" \
                         --model_name_or_path "meta-llama/Llama-3.2-11B-Vision-Instruct" \
                         --output_dir "checkpoints/dpotrained_it1_flickr_train_it1_dpopairs_llama11b" \
                         --bf16
# Output: trained model at "checkpoints/dpotrained_it1_flickr_train_it1_dpopairs_llama11b"
```
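For reference, DPO trains the policy on each (prompt, chosen, rejected) pair by rewarding it for raising the chosen response's log-probability, relative to a frozen reference model, more than the rejected response's. The sketch below evaluates the standard per-pair DPO loss from summed sequence log-probabilities. It illustrates the objective only and is not taken from `dpo.py`; the default `beta=0.1` is an assumption, not this repository's setting.

```python
import math


def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """Standard per-pair DPO loss from summed sequence log-probabilities.

    loss = -log sigmoid(beta * [(pi_chosen - ref_chosen) - (pi_rejected - ref_rejected)])
    """
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    margin = beta * (chosen_logratio - rejected_logratio)
    # -log(sigmoid(m)) == softplus(-m), computed without overflow for large |m|
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))
```

When the policy equals the reference model the margin is zero and the loss is log 2 (about 0.693); a positive margin, where the policy favors the chosen response more than the reference does, pushes the loss below that value.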

### 3. Diffusion Model Tuning

From the repository root, navigate to the diffusion code directory:
```bash
cd diffusion_code
```

#### Setup diffusers
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
cd examples/text_to_image
pip install -r requirements_sdxl.txt
accelerate config default
```

#### Set model environment variables
```bash
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export VAE_NAME="madebyollin/sdxl-vae-fp16-fix"
```

#### Prepare diffusion tuning data

```bash
python3 get_diff_tuning_data.py --input "../data_gen/flickr_train_it1_prompts_scored.csv" \
                                --expname "it1-sdxl" \
                                --image-dir "../data_gen/flickr_train-it1-sdxl-images/"
```
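One plausible selection rule for diffusion tuning data is to keep, for each base caption, the best generated image only when it outscores the original base image, and to pair that image with the base caption so the model learns to produce the improved image from the plain caption. The sketch below illustrates that rule. It is an assumption, not the documented behavior of `get_diff_tuning_data.py`, and all column names (`base_caption`, `image_file`, `score`, `base_score`, `image_path`, `caption`) are hypothetical.

```python
import csv
from pathlib import Path


def select_tuning_rows(rows, image_dir):
    """Keep the best generated image per base caption if it beats the base image.

    `rows` uses hypothetical keys: "base_caption", "image_file",
    "score" (generated image) and "base_score" (original image).
    """
    best = {}
    for row in rows:
        score = float(row["score"])
        if score <= float(row["base_score"]):
            continue  # no improvement over the base image
        current = best.get(row["base_caption"])
        if current is None or score > current[0]:
            best[row["base_caption"]] = (score, row["image_file"])
    return [
        {"image_path": str(Path(image_dir) / image_file), "caption": caption}
        for caption, (_, image_file) in best.items()
    ]


def write_tuning_csv(records, path):
    """Write the selected (image_path, caption) rows with a header."""
    with open(path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["image_path", "caption"])
        writer.writeheader()
        writer.writerows(records)
```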

#### Place the custom training script
Make sure the custom training script `train_text_to_image_sdxl_usingcsv.py` is in `diffusers/examples/text_to_image/`.

#### Train diffusion model
```bash
accelerate launch train_text_to_image_sdxl_usingcsv.py \
    --csv_path="it1_diff_data.csv" \
    --pretrained_model_name_or_path=$MODEL_NAME \
    --pretrained_vae_model_name_or_path=$VAE_NAME \
    --resolution=512 \
    --center_crop \
    --random_flip \
    --proportion_empty_prompts=0.2 \
    --train_batch_size=16 \
    --gradient_accumulation_steps=4 \
    --gradient_checkpointing \
    --num_train_epochs 50 \
    --use_8bit_adam \
    --learning_rate=1e-06 \
    --lr_scheduler="constant" \
    --lr_warmup_steps=50 \
    --mixed_precision="fp16" \
    --validation_prompt="a cute Sundar Pichai creature" \
    --validation_epochs 5 \
    --checkpointing_steps=500 \
    --output_dir="checkpoints/it1_diff_data"
```
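With `--train_batch_size=16` and `--gradient_accumulation_steps=4`, each optimizer update aggregates 16 * 4 = 64 images per process, multiplied by the number of GPUs that `accelerate` launches. The helper below works through that arithmetic to estimate how many optimizer steps 50 epochs take, which is useful when choosing `--checkpointing_steps`. It assumes the step counting of the upstream diffusers example scripts (per-process dataloader batches divided by accumulation steps, rounded up); the custom `_usingcsv` script may count slightly differently.

```python
import math


def training_schedule(num_examples, per_device_batch=16, grad_accum=4,
                      num_processes=1, num_epochs=50):
    """Estimate effective batch size and optimizer-step counts.

    Defaults mirror the training command above. Returns
    (effective_batch, steps_per_epoch, total_steps).
    """
    effective_batch = per_device_batch * grad_accum * num_processes
    # each process iterates over roughly 1/num_processes of the batches
    batches_per_process = math.ceil(num_examples / (per_device_batch * num_processes))
    # one optimizer step per `grad_accum` batches, last partial group included
    steps_per_epoch = math.ceil(batches_per_process / grad_accum)
    return effective_batch, steps_per_epoch, steps_per_epoch * num_epochs


if __name__ == "__main__":
    # hypothetical dataset size on a single GPU
    print(training_schedule(10_000))
```

For instance, a hypothetical 10,000 training rows on one GPU give an effective batch of 64, 157 optimizer steps per epoch, and 7,850 steps over 50 epochs, so `--checkpointing_steps=500` would save roughly 15 checkpoints.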

### 4. Evaluation

After training for iteration 1 completes:
1. For SPRO text evaluation, repeat step 1 (data generation) with the trained model on the test CSV and images for the chosen objective.
2. Use the DPO model trained in step 2 as the starting point for the next iteration.
3. For SPRO image evaluation, update the diffusion model checkpoint in `image_generation.sh` and rerun the image generation part of step 1 on the test set's base captions. To use FLUX instead, change the model name argument in the command that `image_generation.sh` executes.
4. For SPROMM evaluation, after n iterations, update the diffusion model checkpoint in `image_generation.sh` and run it on the test data with the improved captions.
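To summarize an evaluation run, one simple metric is the win rate: the fraction of test samples whose improved output scores higher than the corresponding base output. The sketch below computes it from a scored CSV. The column names `score` and `base_score` are hypothetical placeholders for whatever the scoring scripts actually write, and counting ties as half a win is a design choice.

```python
import csv


def win_rate(rows, score_key="score", base_key="base_score"):
    """Fraction of rows where the improved score beats the base score.

    Ties count as half a win. Column names are hypothetical defaults.
    """
    rows = list(rows)
    if not rows:
        raise ValueError("no rows to evaluate")
    wins = 0.0
    for row in rows:
        improved, base = float(row[score_key]), float(row[base_key])
        if improved > base:
            wins += 1.0
        elif improved == base:
            wins += 0.5
    return wins / len(rows)


def win_rate_from_csv(path, **kwargs):
    """Compute the win rate directly from a scored CSV file."""
    with open(path, newline="") as f:
        return win_rate(csv.DictReader(f), **kwargs)
```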

## Iterative Training Process

For subsequent iterations (iteration 2 onward):
1. Start from the DPO and diffusion models trained in the previous iteration.
2. Repeat workflow steps 1-3, using those models as the starting points.
3. Continue iterating for the desired number of iterations (n in the SPROMM evaluation above).
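The iteration bookkeeping for the DPO step can be sketched as follows. This is a hypothetical helper that extends the `it1` naming pattern used in this README to iteration `k` (the `it{k}` paths are an assumption) and initializes each DPO run from the previous iteration's checkpoint, using only the `dpo.py` flags shown in step 2. It does not cover regenerating prompts and DPO pairs with the previous iteration's model.

```python
import shlex

BASE_VLM = "meta-llama/Llama-3.2-11B-Vision-Instruct"


def dpo_checkpoint(k, dataset="flickr_train"):
    """Hypothetical output directory for the iteration-k DPO run."""
    exp = f"it{k}"
    return f"checkpoints/dpotrained_{exp}_{dataset}_{exp}_dpopairs_llama11b"


def dpo_command(k, dataset="flickr_train"):
    """Build the step-2 DPO launch command (argv list) for iteration k >= 1."""
    if k < 1:
        raise ValueError("iterations are numbered from 1")
    data_dir = f"data_formats/{dataset}_it{k}_dpopairs_llama11b"
    # iteration 1 starts from the base VLM; later ones from the previous checkpoint
    init_model = BASE_VLM if k == 1 else dpo_checkpoint(k - 1, dataset)
    return [
        "accelerate", "launch", "dpo.py",
        "--dataset_name", data_dir,
        "--model_name_or_path", init_model,
        "--output_dir", dpo_checkpoint(k, dataset),
        "--bf16",
    ]


if __name__ == "__main__":
    for k in (1, 2):
        print(shlex.join(dpo_command(k)))
```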

