# Inference-time Diffusion Model Alignment via Random Ordinary Differential Equations

# 💻 Installation

🥰 Feel free to adopt this project for your diffusion model research! 



😀 The installation was tested with NVIDIA Driver `535.230.02`, CUDA `12.2`, and `setuptools==75.1.0` on Ubuntu `22.04.6 LTS`. 

[1] Clone our repository from GitHub:  

```text
# Coming soon ...
```

[2] Create a conda virtual environment with Python 3.10 and activate it. 

```shell
conda create -n Diffusion-OT-MCTS python=3.10
conda activate Diffusion-OT-MCTS
```

[3] Install versions of `torch` and `torchvision` that are compatible with your CUDA version. For example, we install `torch==2.3.1` and `torchvision==0.18.1`: 

```shell
pip install torch==2.3.1 torchvision==0.18.1
```
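To confirm that step [3] installed the versions you expect, a quick check along these lines may help. This is only a sketch: the helper names `installed_version` and `check_torch_install` are illustrative and are not part of this repository.

```python
from importlib import metadata
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def check_torch_install() -> dict:
    """Collect the torch/torchvision versions and whether CUDA is usable."""
    report = {
        "torch": installed_version("torch"),
        "torchvision": installed_version("torchvision"),
        "cuda_available": None,  # stays None if torch cannot be imported
    }
    if report["torch"] is not None:
        try:
            import torch  # imported lazily so the check also works without torch

            report["cuda_available"] = torch.cuda.is_available()
        except (ImportError, OSError):
            pass
    return report


if __name__ == "__main__":
    for key, value in check_torch_install().items():
        print(f"{key}: {value}")
```

If `cuda_available` prints `False`, the installed `torch` build probably does not match your CUDA setup, and you may need a different wheel.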

[4] Install the dependencies. 

```bash
pip install -r requirement.txt
```

# 🏞️ Models \& Datasets

[1] Use `./script/download_model_dataset.sh` to download models and datasets. 

[2] Update the paths in `config/dataset/*.yaml` to the directory where you have stored these datasets. 

​	For example, for DrawBench, 

```yaml
# config/dataset/draw_bench.yaml

prompt_list_json_path: /mnt/d/hytidel/dataset/zhwang/HPDv2/benchmark/drawbench.json
```

[3] Update the paths in `config/model/*.yaml` to the directory where you have stored these models. 

​	For example, for GPT-2, 

```yaml
# config/model/gpt_2.yaml

gpt_2:
  ckpt_root_path: /mnt/d/hytidel/model/openai-community/gpt2
```

[4] Update the paths in `config/pipeline/*.yaml` to the directory where you have stored these models. 

​	For example, for SD-Turbo, 

```yaml
# config/pipeline/sd-turbo.yaml

pipeline_path: /mnt/d/hytidel/model/stabilityai/sd-turbo
```

[5] Update the paths in `config/reward_model/*.yaml` to the directory where you have stored these models. 

​	For example, for CLIP score, 

```yaml
# config/reward_model/clip_score.yaml

clip_score:
  open_clip_model_ckpt_path: /mnt/d/hytidel/model/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin
```

​	For HPS v2, 

```yaml
# config/reward_model/hps_v2.yaml

hps_v2:
  hps_model_ckpt_path: /mnt/d/hytidel/model/HPSv2/HPS_v2_compressed.pt
  vit_model_ckpt_path: /mnt/d/hytidel/model/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin
```
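After editing the configs in steps [2] to [5], it can be handy to check that every configured path actually exists before launching a long run. The sketch below is a rough, stdlib-only helper and not part of this repository. It assumes that path-valued keys end in `_path`, as they do in the examples above, and it scans lines with a regular expression rather than parsing YAML properly, so multi-line or unusual values will be missed. The name `find_missing_paths` is hypothetical.

```python
import re
from pathlib import Path
from typing import List, Tuple

# Matches lines such as `  ckpt_root_path: /some/dir`.
# Assumption: every path-valued key ends in `_path`, as in the examples above.
PATH_LINE = re.compile(r"^\s*(\w*_path)\s*:\s*(\S.*?)\s*$")


def find_missing_paths(config_root: str = "config") -> List[Tuple[str, str, str]]:
    """Return (yaml_file, key, path) for each configured path that does not exist."""
    missing = []
    for yaml_file in sorted(Path(config_root).rglob("*.yaml")):
        for line in yaml_file.read_text(encoding="utf-8").splitlines():
            if line.lstrip().startswith("#"):
                continue  # skip full-line comments
            match = PATH_LINE.match(line)
            if match is None:
                continue
            key, value = match.groups()
            # Drop a trailing inline comment and surrounding quotes, if any.
            value = value.split(" #", 1)[0].strip().strip("'\"")
            if not Path(value).expanduser().exists():
                missing.append((str(yaml_file), key, value))
    return missing


if __name__ == "__main__":
    for yaml_file, key, value in find_missing_paths("config"):
        print(f"{yaml_file}: {key} -> {value} (not found)")
```

Run it from the repository root; an empty output means every path matched by the pattern exists on disk.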

# 🛠️ Usage

😄 We provide scripts in `script/` so that you can run each task conveniently. 
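Several entries in this section are glob patterns (for example `bs_*.sh` or `*.sh`) rather than single files. A small helper like the following could expand such a pattern and run the matching scripts one by one. This is a sketch under assumptions: it presumes the scripts are launched with `bash` from the repository root, which this README does not state, and the names `expand_scripts` and `run_scripts` are illustrative only.

```python
import glob
import subprocess
from typing import List


def expand_scripts(pattern: str) -> List[str]:
    """Expand a pattern such as `script/.../bs_*.sh` into sorted script paths."""
    return sorted(glob.glob(pattern))


def run_scripts(pattern: str, dry_run: bool = True) -> List[str]:
    """Run every script matching `pattern` with bash, stopping at the first failure.

    With dry_run=True (the default) the commands are only printed.
    """
    scripts = expand_scripts(pattern)
    for script in scripts:
        print(f"[run] bash {script}")
        if not dry_run:
            # check=True raises CalledProcessError if the script exits non-zero.
            subprocess.run(["bash", script], check=True)
    return scripts
```

For instance, `run_scripts("script/search/run_optimal_control_bs_eps/sd_v1_4/hps_v2/bs_*.sh")` prints the commands it would run; pass `dry_run=False` to execute them.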



[**Sec. 5.2**]  RODE Sampling

[1] Empirical Distribution

​	(1) SD-Turbo: 

​		`script/sample/t2i/run_sample_scheduled/hpd_v2/sd-turbo/comparison_between_style/run.sh`

​	(2) SD v1.4: 

​		`script/sample/t2i/run_sample_scheduled/hpd_v2/sd_v1_4/comparison_between_style/run.sh`

[2] Sample Diversity

​	(1) Sample: 

​		`script/sample/t2i/run_sample_scheduled/mscoco_2014_5k_test/sd_v1_4/run.sh`

​	(2) Calculate MPD: 

​		`script/cal_metric/cal_mean_pairwise_distance/sd-turbo/2.sh`

​		`script/cal_metric/cal_mean_pairwise_distance/sd_v1_4/*.sh`



[**Sec. 5.3**]  Aligning with Aesthetics

[1] Inference-step Scaling: 

​	(1) Sample: 

​		`script/sample/t2i/run_sample_scheduled/hpd_v2/sd_v1_4/baseline/run_baseline.sh`

​	(2) Calculate HPS v2, PS, IR: 

​		`script/cal_metric/cal_final_reward_baseline/hps_v2/sd_v1_4/HumanPreferenceDataset_v2/run_baseline.sh`

​		`script/cal_metric/cal_final_reward_baseline/pick_score/sd_v1_4/HumanPreferenceDataset_v2/run_baseline.sh`

​		`script/cal_metric/cal_final_reward_baseline/image_reward/sd_v1_4/HumanPreferenceDataset_v2/run_baseline.sh`

​	(3) Display Results: 

​		(i) HPS v2: 

​			`script/display_result/display_result_baseline/hpd_v2/sd_v1_4/hps_v2/*.sh`

​		(ii) PS: 

​			`script/display_result/display_result_baseline/hpd_v2/sd_v1_4/pick_score/*.sh`

​		(iii) IR: 

​			`script/display_result/display_result_baseline/hpd_v2/sd_v1_4/image_reward/*.sh`

[2] Trajectory Search Methods: 

​	(1) BS-eps (including GS-eps): 

​		(i) Search: 

​			`script/search/run_optimal_control_bs_eps/sd_v1_4/hps_v2/bs_*.sh`

​		(ii) Display Results: 

​			`script/display_result/display_result_bs_eps/sd_v1_4/bs_*.sh`

​	(2) BS-eta (including GS-eta): 

​		(i) Search: 

​			`script/search/run_optimal_control_bs_eta/sd_v1_4/hps_v2/bs_*.sh`

​		(ii) Display results: 

​			`script/display_result/display_result_bs_eta/sd_v1_4/bs_*.sh`

​	(3) MCTS-eps: 

​		(i) Search: 

​			`script/search/run_optimal_control_mcts_eps/sd_v1_4/hps_v2/mcts_eps_999.sh`

​		(ii) Calculate PS, IR: 

​			`script/cal_metric/cal_final_reward_mcts_eps/sd_v1_4/image_reward.sh`

​			`script/cal_metric/cal_final_reward_mcts_eps/sd_v1_4/pick_score.sh`

​		(iii) Display results: 

​			`script/display_result/display_result_mcts_eps/sd_v1_4/hps_v2.sh`

​	(4) Ours (MCTS-eta): 

​		(i) Search: 

​			`script/search/run_optimal_control_mcts/sd_v1_4/hps_v2/main/latent_max_max_999.sh`

​		(ii) Calculate PS, IR: 

​			`script/cal_metric/cal_final_reward_ours/sd_v1_4/image_reward.sh`

​			`script/cal_metric/cal_final_reward_ours/sd_v1_4/pick_score.sh`

​		(iii) Display results: 

​			`script/display_result/display_result_ours/sd_v1_4/hps_v2_50_prompt.sh`

[3] Plot: 

​	(1) Prepare `scaling_list`: 

​		(i) HPS v2: `script/display_result/get_scaling_list/sd_v1_4/hps_v2/*.sh`

​		(ii) PS: `script/display_result/get_scaling_list/sd_v1_4/pick_score/*.sh`

​		(iii) IR: `script/display_result/get_scaling_list/sd_v1_4/image_reward/*.sh`

​	(2) Plot line charts: 

​		(i) HPS v2: `script/plot/line_chart_scaling/sd_v1_4/hps_v2.sh`

​		(ii) PS \& IR: `script/plot/line_chart_scaling_others/sd_v1_4/pick_score_image_reward.sh`



[**Sec. 5.4**]  Aligning with Semantics

[1] DDPM \& DDIM: 

​	(1) Sample: 

​		`script/sample/t2i/run_sample_scheduled/draw_bench_30/sdxl/run_draw_bench_30.sh`

​	(2) Calculate HPS v2 \& CLIP score: 

​		`script/cal_metric/cal_final_reward_baseline/hps_v2/sdxl/DrawBench/baseline.sh`

​		`script/cal_metric/cal_final_reward_baseline/clip_score/sdxl/DrawBench/baseline.sh`

​	(3) Display results: 

​		`script/display_result/display_result_baseline/draw_bench_30/sdxl/baseline/hps_v2/*.sh`

​		`script/display_result/display_result_baseline/draw_bench_30/sdxl/baseline/clip_score/*.sh`

[2] Z-Sampling: 

​	(1) Sample: 

​		`script/baseline/run_z_sampling/draw_bench_30/sdxl/ddim.sh`

​	(2) Calculate HPS v2 \& CLIP score: 

​		`script/cal_metric/cal_final_reward_baseline/hps_v2/sdxl/DrawBench/z_sampling.sh`

​		`script/cal_metric/cal_final_reward_baseline/clip_score/sdxl/DrawBench/z_sampling.sh`

​	(3) Display results: 

​		`script/display_result/display_result_baseline/draw_bench_30/sdxl/z_sampling/single_k/*.sh`

​		`script/display_result/display_result_baseline/draw_bench_30/sdxl/z_sampling/across_k/*.sh`



[**Sec. 5.5**]  Aligning with Composite Rewards

[1] Baselines: 

​	(1) Sample: 

​		`script/sample/t2i/run_sample_scheduled/hpd_v2/pixart_alpha_xl/baseline/run_baseline.sh`

​	(2) Calculate HPS v2: 

​		`script/cal_metric/cal_final_reward_baseline/hps_v2/pixart_alpha/baseline.sh`

​	(3) Calculate compressibility reward (CR): 

​		`script/cal_metric/cal_final_reward_baseline/compressibility_reward/pixart_alpha/baseline.sh`

​	(4) Calculate CLIP score: 

​		`script/cal_metric/cal_final_reward_baseline/clip_score/pixart_alpha/baseline.sh`

[2] Ours: 

​	(1) Calculate HPS v2: 

​		`script/search/run_optimal_control_mcts/pixart_alpha_xl/hps_v2/ours_999.sh`

​	(2) Calculate CR: 

​		`script/search/run_optimal_control_mcts/pixart_alpha_xl/compressibility_reward/ours_999.sh`

​	(3) Calculate CLIP score: 

​		`script/search/run_optimal_control_mcts/pixart_alpha_xl/clip_score/ours_999.sh`



[**Sec. 5.6**]  Ablations and Applications

[1] (Appx. J) Comparison between Reward Shaping, MDP Modeling and Value Policies 

​	(1) Cumulative reward: 	

​		`script/search/run_optimal_control_mcts/sd_v1_4/hps_v2/test_reward_mdp_value/latent_reward/cumulative_average.sh`

​		`script/search/run_optimal_control_mcts/sd_v1_4/hps_v2/test_reward_mdp_value/latent_reward/cumulative_max.sh`

​		`script/search/run_optimal_control_mcts/sd_v1_4/hps_v2/test_reward_mdp_value/latent_reward/max_average.sh`

​		`script/search/run_optimal_control_mcts/sd_v1_4/hps_v2/test_reward_mdp_value/latent_reward/max_max.sh`

​	(2) Sparse reward: 

​		`script/search/run_optimal_control_mcts/sd_v1_4/hps_v2/test_reward_mdp_value/sparse_reward/sparse_average.sh`

​		`script/search/run_optimal_control_mcts/sd_v1_4/hps_v2/test_reward_mdp_value/sparse_reward/sparse_max.sh`

[2] (Appx. N) Latent Reward Policies: 

​	`script/search/run_optimal_control_mcts/sd_v1_4/hps_v2/latent_reward_policy/*.sh`

[3] (Appx. O.1) Reward Hacking and Effects of $\tau$ : 

​	(1) Color channel reward: 

​		`script/search/run_optimal_control_mcts/sd_v1_4/color_channel_reward/ablation/ours_500.sh`

​		`script/search/run_optimal_control_mcts/sd_v1_4/color_channel_reward/ablation/ours_wo_tau_500.sh`

​	(2) Laplacian variance (LAPV): 

​		`script/search/run_optimal_control_mcts/sd_v1_4/laplacian_var_reward/ablation/ours_500.sh`

​		`script/search/run_optimal_control_mcts/sd_v1_4/laplacian_var_reward/ablation/ours_wo_tau_500.sh`

[4] (Appx. O.2) Ablations on $m$ and $\zeta$ : 

​	(1) Ablations on $m$ : 

​		`script/display_result/sd_v1_4/ablation/effect_of_m.sh`

​	(2) Ablations on $\zeta$ : 

​		`script/display_result/sd_v1_4/hps_v2/test_zeta.sh`

[5] (Appx. O.3) Main Ablations: 

​	(1) Search: 

​		`script/search/run_optimal_control_mcts/sd_v1_4/hps_v2/ablation/latent_max_max_no_online_update.sh`

​		`script/search/run_optimal_control_mcts/sdxl/hps_v2/application/ours_promptist_gn.sh`

​		`script/search/run_optimal_control_mcts/sd_v1_4/hps_v2/ablation/latent_max_max_no_depth_limit.sh`

​	(2) Display results: 

​		`script/display_result/sd_v1_4/ablation/wo_pseudo_final.sh`

​		`script/display_result/sd_v1_4/ablation/wo_depth_limit.sh`

​		`script/display_result/sd_v1_4/ablation/wo_online_update.sh`

[6] (Appx. P.1) Synergy with Community Modules: 

​	(1) Search: 

​		(i) Ours: `script/search/run_optimal_control_mcts/sdxl/hps_v2/application/ours_gn.sh`

​		(ii) Ours + Promptist: 

​			`script/search/run_optimal_control_mcts/sdxl/hps_v2/application/ours.sh`

​		(iii) Ours + Promptist + GN: 

​			`script/search/run_optimal_control_mcts/sdxl/hps_v2/application/ours_promptist.sh`

​	(2) Display results: 

​		`script/search/run_optimal_control_mcts/sdxl/hps_v2/application/*.sh`

[7] (Appx. P.2) Robustness of Image Reward Functions

​	(1) HPS v2: 

​		`script/cal_metric/cal_reward_robustness/sd_v1_4/hps_v2.sh`

​	(2) PS: 

​		`script/cal_metric/cal_reward_robustness/sd_v1_4/pick_score.sh`

​	(3) IR: 

​		`script/cal_metric/cal_reward_robustness/sd_v1_4/image_reward.sh`

---


