# Code Appendix of _Consistency Model is an Effective Posterior Sample Approximation for Diffusion Inverse Solvers_

# Prerequisites
* Environments:
    * Python 3.10, PyTorch 2.1.0+cu121, transformers 4.26.1
    * Other versions may work as well, but these are the versions we have tested.
* Pre-trained models:
  * All pre-trained models should be downloaded and placed in ./bins
  * Pre-trained EDM and CM:
    * EDM for LSUN Bedroom: https://openaipublic.blob.core.windows.net/consistency/edm_bedroom256_ema.pt
    * EDM for LSUN Cat: https://openaipublic.blob.core.windows.net/consistency/edm_cat256_ema.pt
    * CM (cd, lpips) for LSUN Bedroom: https://openaipublic.blob.core.windows.net/consistency/cd_bedroom256_lpips.pt
    * CM (cd, lpips) for LSUN Cat: https://openaipublic.blob.core.windows.net/consistency/cd_cat256_lpips.pt
  * Pre-trained models for operators:
    * Segmentation Model:
      * Model A for DIS:
        * http://sceneparsing.csail.mit.edu/model/pytorch/ade20k-mobilenetv2dilated-c1_deepsup/encoder_epoch_20.pth
        * http://sceneparsing.csail.mit.edu/model/pytorch/ade20k-mobilenetv2dilated-c1_deepsup/decoder_epoch_20.pth
        * rename them to 'encoder_epoch_20_A.pth' and 'decoder_epoch_20_A.pth'.
      * Model B for evaluation:
        * 'ckpt/ade20k-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth' and 'ckpt/ade20k-resnet50dilated-ppm_deepsup/decoder_epoch_20.pth' in https://colab.research.google.com/github/CSAILVision/semantic-segmentation-pytorch/blob/master/notebooks/DemoSegmenter.ipynb#scrollTo=FQf7ybR2dgau
        * rename them to 'encoder_epoch_20_B.pth' and 'decoder_epoch_20_B.pth'.
    * Layout Estimation Model: https://drive.google.com/file/d/1aUJoXM9SQMe0LC38pA8v8r43pPOAaQ-a/view
    * Captioning Model:
      * Model A for DIS: (automatic download) https://github.com/salesforce/BLIP
        * download model_base_retrieval_coco.pth to ./bins/
        * download model_base_caption_capfilt_large.pth to ./bins/
        * download bert-base-uncased to ./bins/
      * Model B for evaluation: (automatic download) https://lightning.ai/docs/torchmetrics/stable/multimodal/clip_score.html
        * download openai/clip-vit-large-patch14 to ./bins/openai/clip-vit-large-patch14
    * Classification Model:
      * Model A for DIS (automatic download): https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html
      * Model B for evaluation (automatic download): https://pytorch.org/vision/main/models/generated/torchvision.models.vit_b_16.html
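
Before running anything, it can help to confirm that the manually downloaded files are where the scripts expect them. The following is a minimal sketch, not part of the code base: the helper name `missing_checkpoints` is hypothetical, and it assumes every file listed above sits directly in ./bins under the name given in this section. The layout estimation checkpoint is left out because its file name is not stated here.

```python
from pathlib import Path

# File names taken from the download list above. The layout estimation
# checkpoint is omitted because its file name is not given in this README.
EXPECTED_CHECKPOINTS = [
    "edm_bedroom256_ema.pt",
    "edm_cat256_ema.pt",
    "cd_bedroom256_lpips.pt",
    "cd_cat256_lpips.pt",
    "encoder_epoch_20_A.pth",
    "decoder_epoch_20_A.pth",
    "encoder_epoch_20_B.pth",
    "decoder_epoch_20_B.pth",
    "model_base_retrieval_coco.pth",
    "model_base_caption_capfilt_large.pth",
]


def missing_checkpoints(bins_dir="./bins", expected=EXPECTED_CHECKPOINTS):
    """Return the expected file names that are not present in bins_dir."""
    root = Path(bins_dir)
    return [name for name in expected if not (root / name).is_file()]


if __name__ == "__main__":
    missing = missing_checkpoints()
    if missing:
        print("Missing from ./bins:")
        for name in missing:
            print("  " + name)
    else:
        print("All listed checkpoints found in ./bins.")
```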

# Dataset and Image Samples
* We use the first 1000 samples of the LSUN Bedroom and LSUN Cat datasets. For LSUN Cat, keep only images whose short side is >= 256. Crop every image to a square along its short edge, resize it to 256x256, and save it losslessly in PNG format.
* The datasets can be found at: https://github.com/fyu/lsun
* Put the LSUN Cat dataset in ./datasets/lsun_cat and the LSUN Bedroom dataset in ./dataset/lsun_bedroom.
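
The preprocessing described above can be sketched as follows. This is an illustrative sketch rather than the script used for the paper: the function names are hypothetical, it assumes the images have already been exported to individual files, it assumes Pillow is installed, and it assumes a centered crop with bicubic resampling, neither of which is specified above.

```python
from pathlib import Path


def square_crop_box(width, height):
    """Return (left, top, right, bottom) of a centered square crop.

    The side of the square equals the short edge of the image.
    Centering the crop is an assumption, not stated in the README.
    """
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    return (left, top, left + side, top + side)


def keep_image(width, height, min_short_side=256):
    """Filter used for LSUN Cat: keep images with short side >= 256."""
    return min(width, height) >= min_short_side


def preprocess(src_path, dst_path, size=256, min_short_side=0):
    """Crop to a square, resize to size x size, and save as PNG.

    Returns False (and writes nothing) if the image is filtered out.
    Pass min_short_side=256 for LSUN Cat.
    """
    from PIL import Image  # Pillow, imported lazily

    with Image.open(src_path) as img:
        img = img.convert("RGB")
        if not keep_image(img.width, img.height, min_short_side):
            return False
        img = img.crop(square_crop_box(img.width, img.height))
        # Bicubic resampling is an assumption; the README only fixes the size.
        img = img.resize((size, size), Image.BICUBIC)
        Path(dst_path).parent.mkdir(parents=True, exist_ok=True)
        img.save(dst_path, format="PNG")  # PNG is lossless
    return True
```

For example, `preprocess("cat_0.webp", "./datasets/lsun_cat/00000.png", min_short_side=256)` would write one processed LSUN Cat image; the source file name and the output numbering are placeholders.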

# Fast First Run
* After downloading the LSUN Bedroom EDM and CM checkpoints and pre-trained model A of the segmentation operator, you can try the prepared dataset sample with
  ```bash
  python -u image_inverse.py --training_mode edm --generator determ-indiv --batch_size 1 --sigma_max 80 --sigma_min 0.002 --s_churn 0 --steps 1000 --sampler sample_euler_ancestral_cm --model_path ./bins/edm_bedroom256_ema.pt --distiller_path ./bins/cd_bedroom256_lpips.pt --attention_resolutions 32,16,8  --class_cond False --dropout 0.1 --image_size 256 --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --num_samples 1 --resblock_updown True --use_fp16 True --use_scale_shift_norm False --weight_schedule karras --savedir=results/ --cfg segmentation_config.yaml
  ```
  * The results are written to ./results/, in several sub-directories:
      * label: the source image
      * input ($y$): the segmentation result of the source image
      * recon ($\hat{x}$): the reconstruction
  * The expected behaviour is:
      * label: the source image
      * input ($y$): looks like a segmentation of the source image
      * recon ($\hat{x}$): has a layout similar to the input
  * We provide an example run result in results_example:
      * label: ![alt text](./results_example/roomsegmentation/label/00000.png)
      * input ($y$): ![alt text](./results_example/roomsegmentation/input_color/00000.png)
      * recon: ![alt text](./results_example/roomsegmentation/recon/00000.png)
      * The above command should reproduce this example run almost exactly

# Run Experiments
* To run the tests with Proposed I, run:
  ```bash
  # segmentation
  python -u image_inverse.py --training_mode edm --generator determ-indiv --batch_size 1 --sigma_max 80 --sigma_min 0.002 --s_churn 0 --steps 1000 --sampler sample_euler_ancestral_cm --model_path ./bins/edm_bedroom256_ema.pt --distiller_path ./bins/cd_bedroom256_lpips.pt --attention_resolutions 32,16,8  --class_cond False --dropout 0.1 --image_size 256 --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --num_samples 1 --resblock_updown True --use_fp16 True --use_scale_shift_norm False --weight_schedule karras --savedir=results/ --cfg segmentation_config.yaml

  # room layout estimation
  python -u image_inverse.py --training_mode edm --generator determ-indiv --batch_size 1 --sigma_max 80 --sigma_min 0.002 --s_churn 0 --steps 1000 --sampler sample_euler_ancestral_cm --model_path ./bins/edm_bedroom256_ema.pt --distiller_path ./bins/cd_bedroom256_lpips.pt --attention_resolutions 32,16,8  --class_cond False --dropout 0.1 --image_size 256 --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --num_samples 1 --resblock_updown True --use_fp16 True --use_scale_shift_norm False --weight_schedule karras --savedir=results/ --cfg roomlayout.yaml

  # captioning
  python -u image_inverse.py --training_mode edm --generator determ-indiv --batch_size 1 --sigma_max 80 --sigma_min 0.002 --s_churn 0 --steps 1000 --sampler sample_euler_ancestral_cm --model_path ./bins/edm_bedroom256_ema.pt --distiller_path ./bins/cd_bedroom256_lpips.pt --attention_resolutions 32,16,8  --class_cond False --dropout 0.1 --image_size 256 --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --num_samples 1 --resblock_updown True --use_fp16 True --use_scale_shift_norm False --weight_schedule karras --savedir=results/ --cfg blip_text_config.yaml

  # classification
  python -u image_inverse.py --training_mode edm --generator determ-indiv --batch_size 1 --sigma_max 80 --sigma_min 0.002 --s_churn 0 --steps 1000 --sampler sample_euler_ancestral_cm --model_path ./bins/edm_cat256_ema.pt --distiller_path ./bins/cd_cat256_lpips.pt --attention_resolutions 32,16,8  --class_cond False --dropout 0.1 --image_size 256 --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --num_samples 1 --resblock_updown True --use_fp16 True --use_scale_shift_norm False --weight_schedule karras --savedir=results/ --cfg cat_classification.yaml

  # super-resolution
  python -u image_inverse.py --training_mode edm --generator determ-indiv --batch_size 1 --sigma_max 80 --sigma_min 0.002 --s_churn 0 --steps 1000 --sampler sample_euler_ancestral_cm --model_path ./bins/edm_bedroom256_ema.pt --distiller_path ./bins/cd_bedroom256_lpips.pt --attention_resolutions 32,16,8  --class_cond False --dropout 0.1 --image_size 256 --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --num_samples 1 --resblock_updown True --use_fp16 True --use_scale_shift_norm False --weight_schedule karras --savedir=results/ --cfg super_resolution_config.yaml
  ```
* To run the tests with Proposed II, run:
  ```bash
  # TODO
  ```
* Evaluation:
  ```bash
  # evaluate FID, KID, LPIPS, MSE, PSNR, etc.
  python eval_scripts/eval_fid.py

  # evaluate mIoU for segmentation
  python eval_miou.py

  # evaluate mIoU for layout
  python eval_miou_room2.py

  # evaluate clip score for captioning
  python eval_scripts/eval_clip_score.py

  # evaluate accuracy for classification
  python eval_catcls.py
  
  # evaluate MSE for down-sampling
  python eval_scripts/eval_mse.py
  ```
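
The five Proposed I commands above differ only in the checkpoint pair (LSUN Bedroom or LSUN Cat) and the `--cfg` file. As a convenience, the sketch below rebuilds them from a small task table so they can be printed or passed to `subprocess.run`. The names `TASKS` and `build_command` are hypothetical helpers, not part of the code base; every flag and file name is copied from the commands above.

```python
import shlex

# Flags shared by all five commands, split around the two checkpoint
# arguments so the original argument order is preserved.
ARGS_BEFORE_MODEL = (
    "--training_mode edm --generator determ-indiv --batch_size 1 "
    "--sigma_max 80 --sigma_min 0.002 --s_churn 0 --steps 1000 "
    "--sampler sample_euler_ancestral_cm"
).split()
ARGS_AFTER_MODEL = (
    "--attention_resolutions 32,16,8 --class_cond False --dropout 0.1 "
    "--image_size 256 --num_channels 256 --num_head_channels 64 "
    "--num_res_blocks 2 --num_samples 1 --resblock_updown True "
    "--use_fp16 True --use_scale_shift_norm False --weight_schedule karras"
).split()

# task name -> (dataset used for the checkpoints, config file)
TASKS = {
    "segmentation": ("bedroom", "segmentation_config.yaml"),
    "room_layout": ("bedroom", "roomlayout.yaml"),
    "captioning": ("bedroom", "blip_text_config.yaml"),
    "classification": ("cat", "cat_classification.yaml"),
    "super_resolution": ("bedroom", "super_resolution_config.yaml"),
}


def build_command(task, savedir="results/"):
    """Return the argument list for one task, in the order used above."""
    dataset, cfg = TASKS[task]
    return (
        ["python", "-u", "image_inverse.py"]
        + ARGS_BEFORE_MODEL
        + ["--model_path", f"./bins/edm_{dataset}256_ema.pt"]
        + ["--distiller_path", f"./bins/cd_{dataset}256_lpips.pt"]
        + ARGS_AFTER_MODEL
        + [f"--savedir={savedir}", "--cfg", cfg]
    )


if __name__ == "__main__":
    for task in TASKS:
        print(f"# {task}")
        print(shlex.join(build_command(task)))
```

Keeping the shared flags in one place makes it harder for the per-task commands to drift apart when a setting such as `--steps` is changed.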

# Notes
* All links in this README point to public, third-party sources, i.e., they do not contain any information about the identity of the authors.
* This code base is heavily inspired by the source code of Consistency Models: https://github.com/openai/consistency_models, and the source code of DPS: https://github.com/DPS2022/diffusion-posterior-sampling
* The models and code come from the following sources:
  * Segmentation: https://github.com/CSAILVision/semantic-segmentation-pytorch
  * Layout: https://github.com/leVirve/lsun-room
  * Captioning: https://github.com/salesforce/BLIP
  * Classification: torchvision pretrained