# Diffusion Meta-Prompting Models

<p align="center">
<img src=assets/teaser.png />
</p>



[**Learning to Sample Foundation Model Prompts with Diffusion**](https://arxiv.org/abs/)

## Requirements
A suitable [conda](https://conda.io/) environment named `ldm` can be created
and activated with:

```shell script
conda env create -f environment.yml
conda activate ldm
```
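Every command below assumes the `ldm` environment is active, so a small guard can catch a forgotten `conda activate`. This is a minimal sketch, not part of the repository. It relies on conda exporting the active environment's name in `CONDA_DEFAULT_ENV`, and the helper name `require_ldm_env` is hypothetical.

```shell
# Hypothetical helper (not shipped with the repo): fail early if the `ldm`
# conda environment is not the active one.
require_ldm_env() {
  # conda sets CONDA_DEFAULT_ENV to the name of the activated environment
  if [ "${CONDA_DEFAULT_ENV:-}" != "ldm" ]; then
    echo "error: activate the 'ldm' environment first (conda activate ldm)" >&2
    return 1
  fi
}
```

Calling `require_ldm_env || exit 1` at the top of a job script stops it before any training or sampling starts in the wrong environment.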

## Pretrained Models
Pretrained models will be made publicly available after acceptance.


## Model Training

Logs and checkpoints for trained models are saved to `logs/<START_DATE_AND_TIME>_<config_spec>`.
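Because each run directory name starts with its launch timestamp, the newest run for a given config can be found by sorting. A minimal sketch, assuming the timestamp is zero-padded and year-first (for example `2024-01-31T12-00-00`) so that alphabetical order matches chronological order; check a directory under `logs/` to confirm. `latest_run` is a hypothetical helper.

```shell
# Hypothetical helper: print the newest logs/<START_DATE_AND_TIME>_<config_spec>
# directory for one config spec. ASSUMES zero-padded, year-first timestamps,
# so the alphabetical order of the glob matches chronological order.
latest_run() {
  local spec="$1" latest="" d
  # Glob matches are expanded in sorted order; the last directory is the newest.
  for d in logs/*_"$spec"; do
    [ -d "$d" ] && latest="$d"
  done
  # Return a non-zero status when no run directory exists for this spec.
  [ -n "$latest" ] || return 1
  printf '%s\n' "$latest"
}
```

For example, `ls "$(latest_run <config_spec>)"` lists the contents of the most recent run. The spec is matched as a name suffix, so a spec whose name ends with another spec's full name (after an underscore) can also match that other spec's runs.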

### Training autoencoder models

Configs for training a KL-regularized autoencoder on the 10 fine-grained datasets are provided in `configs/autoencoder/`.
Training can be started by running:
```shell script
python main.py --base configs/autoencoder/<config_spec>.yaml -t --gpus 0,
```
where `<config_spec>` is one of the provided config files.
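To train several configs one after another on the same GPU, the command above can be wrapped in a loop. A minimal sketch: `train_configs`, `run_or_print`, and the `DRY_RUN` switch are hypothetical conveniences, not repository features. By default the commands are only printed.

```shell
# Hypothetical helper: print the command (default, DRY_RUN=1) or execute it (DRY_RUN=0).
run_or_print() {
  if [ "${DRY_RUN:-1}" = "1" ]; then
    echo "$*"
  else
    "$@"
  fi
}

# Hypothetical wrapper: one training run per config file, sequentially on GPU 0.
train_configs() {
  local cfg
  for cfg in "$@"; do
    # Skip unmatched glob patterns such as a literal "configs/x/*.yaml".
    [ -f "$cfg" ] || continue
    # Stop at the first failing run instead of silently continuing.
    run_or_print python main.py --base "$cfg" -t --gpus 0, || return 1
  done
}
```

`train_configs configs/autoencoder/*.yaml` prints one command per config, and `DRY_RUN=0 train_configs configs/autoencoder/*.yaml` launches them. The same wrapper works for `configs/latent-diffusion/*.yaml`. Runs are sequential because every command targets GPU 0.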

### Training LDMs

In `configs/latent-diffusion/`, we provide configs for training latent diffusion models (LDMs) on the prompts.
Training can be started by running:

```shell script
python main.py --base configs/latent-diffusion/<config_spec>.yaml -t --gpus 0,
```

where `<config_spec>` is one of the provided configs.
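The valid `<config_spec>` values are the config file names without the `.yaml` extension. A small sketch for listing them; `list_config_specs` is a hypothetical helper and assumes the configs use the `.yaml` extension, as in the commands above.

```shell
# Hypothetical helper: print each config spec in a directory (file name minus .yaml).
list_config_specs() {
  local f
  for f in "$1"/*.yaml; do
    # Skip the literal pattern left behind when the directory has no .yaml files.
    [ -f "$f" ] || continue
    basename "$f" .yaml
  done
}
```

For example, `list_config_specs configs/latent-diffusion` prints one spec per line, ready to substitute into the training command.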

## Model Inference

The inference scripts are run from the repository root, with the root on `PYTHONPATH`:
```shell script
cd /path/to/repo/
export PYTHONPATH=$PWD
```
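The `export` above replaces any existing `PYTHONPATH`. If other entries must be kept, the repository root can be prepended instead. This is a general shell idiom, not something the scripts are known to require, and `prepend_pythonpath` is a hypothetical helper.

```shell
# Hypothetical helper: put a directory first on PYTHONPATH, keeping existing entries.
# ${PYTHONPATH:+:$PYTHONPATH} expands to ":<old value>" only when the variable is
# non-empty, so no empty entry is added when PYTHONPATH was unset.
prepend_pythonpath() {
  PYTHONPATH="$1${PYTHONPATH:+:$PYTHONPATH}"
  export PYTHONPATH
}
```

Usage: `cd /path/to/repo/ && prepend_pythonpath "$PWD"`.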

### Inference with DMPCoOp

For the FGVC Aircraft dataset:

```shell script
python scripts/txt2prompts_latent_coop.py --config config.yaml --prompt "fgvc_aircraft" \
    --ckpt /path/to/ckpt --scale 4.5 --H 16 --n_samples 1 --epoch_num 14 --ctx 4 \
    --adapter /path/to/coop/ckpt --avg_samples --unconditional
```
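When sampling for several datasets, the DMPCoOp command above can be parameterized by the prompt name and the two checkpoints. A minimal sketch: `sample_dmp_coop`, `run_or_print`, and `DRY_RUN` are hypothetical, all other flags are copied unchanged from the FGVC Aircraft example, and prompt names for other datasets are assumed to follow the same convention as `fgvc_aircraft` (check the configs for the exact names).

```shell
# Hypothetical helper: print the command (default, DRY_RUN=1) or execute it (DRY_RUN=0).
run_or_print() {
  if [ "${DRY_RUN:-1}" = "1" ]; then
    echo "$*"
  else
    "$@"
  fi
}

# Hypothetical wrapper around the DMPCoOp sampling command.
# Arguments: dataset prompt, diffusion checkpoint, CoOp adapter checkpoint.
sample_dmp_coop() {
  if [ "$#" -ne 3 ]; then
    echo "usage: sample_dmp_coop <dataset> <ckpt> <adapter_ckpt>" >&2
    return 2
  fi
  run_or_print python scripts/txt2prompts_latent_coop.py --config config.yaml \
    --prompt "$1" --ckpt "$2" --scale 4.5 --H 16 --n_samples 1 \
    --epoch_num 14 --ctx 4 --adapter "$3" --avg_samples --unconditional
}
```

Passing the arguments through `"$@"` keeps paths containing spaces intact when the command is actually run. The dry-run printout drops that quoting and is meant only for inspection.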

### Inference with DMPMulti

To sample the "identity-18" prompt:

```shell script
python scripts/txt2ti_identity.py --config config.yaml --prompt "identity-18" \
    --ckpt /path/to/ckpt --scale 4.5 --H 768 --n_sample 1 --concept_prompt 18 \
    --output_folder_name mdp_ids
```
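To sample several identities in one go, the identity-18 command above can be looped over indices. This is a sketch that ASSUMES the `identity-<N>` / `--concept_prompt <N>` pattern of the example holds for other indices; `sample_identities`, `run_or_print`, and `DRY_RUN` are hypothetical, and the remaining flags are copied unchanged.

```shell
# Hypothetical helper: print the command (default, DRY_RUN=1) or execute it (DRY_RUN=0).
run_or_print() {
  if [ "${DRY_RUN:-1}" = "1" ]; then
    echo "$*"
  else
    "$@"
  fi
}

# Hypothetical loop: first argument is the checkpoint, the rest are identity indices.
# ASSUMES other identities follow the "identity-<N>" / --concept_prompt <N> pattern.
sample_identities() {
  local ckpt="$1" i
  shift
  for i in "$@"; do
    run_or_print python scripts/txt2ti_identity.py --config config.yaml \
      --prompt "identity-$i" --ckpt "$ckpt" --scale 4.5 --H 768 --n_sample 1 \
      --concept_prompt "$i" --output_folder_name mdp_ids || return 1
  done
}
```

For example, `sample_identities /path/to/ckpt 18 19` prints two commands; prefix it with `DRY_RUN=0` to run them.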

To sample the "smiling" slider prompt:

```shell script
python scripts/txt2prompts_all.py --config config.yaml --prompt "smiling" \
    --ckpt /path/to/ckpt --scale 4.5 --H 768 --n_samples 4 --concept_prompt smiling \
    --prefix diff_
```


