# Meta-Learning via Classifier(-free) Guidance

# Installation

The `hyperclip` conda environment can be created with the following commands:
```
conda env create -f environment.yml
conda activate hyperclip
pip install git+https://github.com/openai/CLIP.git
conda install pytorch cudatoolkit=11.3 -c pytorch
pip install -e .
```

To set up Weights & Biases (W&B), run
```
wandb login
```
and paste your W&B API key.
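
On a headless machine, such as a cluster job, the interactive prompt can be inconvenient. A minimal alternative is to export the key before launching any of the scripts below. This assumes the standard `wandb` client behavior of reading the key from the `WANDB_API_KEY` environment variable. The value shown is a placeholder.

```bash
# Assumption: the wandb client reads the API key from this environment
# variable, so no interactive `wandb login` is needed in this shell.
# Replace the placeholder with your key (available at https://wandb.ai/authorize).
export WANDB_API_KEY="<your-wandb-api-key>"
```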

# Meta-VQA Dataset

To re-compute the Meta-VQA dataset, first download the third-party [original VQA v2 dataset](https://visualqa.org/download.html) and place it in the `data/VQA/` folder. Then, with the `hyperclip` environment active, run:
```
python scripts/precompute_image_features.py
python scripts/precompute_ques_features.py
python scripts/precompute_text_features.py
```
to regenerate the precomputed CLIP embeddings for images, task questions, and answers.
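
If you regenerate the embeddings often, the three steps can be wrapped in a small Bash function. This is a minimal sketch and is not part of the repository. `precompute_meta_vqa` is a hypothetical name, and the function assumes you run it from the repository root with the `hyperclip` environment active. It only checks that `data/VQA/` exists. It does not check which VQA v2 files the scripts expect inside that folder.

```bash
# Hypothetical helper (not part of the repository): regenerates all Meta-VQA
# CLIP embeddings. Assumes the current directory is the repository root and
# the `hyperclip` conda environment is active.
precompute_meta_vqa() {
  # Only checks that the folder exists, not the exact VQA v2 file layout.
  if [ ! -d data/VQA ]; then
    echo "error: data/VQA/ not found; download VQA v2 first" >&2
    return 1
  fi
  local script
  for script in precompute_image_features precompute_ques_features precompute_text_features; do
    echo "==> scripts/${script}.py"
    # Stop at the first failing script so its error is not buried in later output.
    python "scripts/${script}.py" || return 1
  done
}
```

Source the snippet and call `precompute_meta_vqa`. It stops at the first script that exits with an error.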

# Experiment scripts

To train the multitask/MAML baselines or an unconditional hypernetwork generative model (later used as the basis for conditional generation), use the script:
```
python scripts/train_few_shot.py [...]
```

To train several of our models (the VAE hypernetwork, and optionally HyperCLIP and HyperLDM), we first need to prepare a precomputed "dataset" of fine-tuned networks, HNET latents, or VAE latents. We can do so with the script:
```
python scripts/precompute_adaptation.py (--few_shot_checkpoint <wandb id of train_few_shot.py hnet run> | --vae_checkpoint <wandb id of train_vae.py run>) [...]
```

To train the unconditional VAE hypernetwork (an alternative to the HNET above as the basis for conditional generation methods), use the script:
```
python scripts/train_vae.py --precompute_checkpoint <wandb id of precompute_adaptation.py run> [...]
```

To train the HyperCLIP encoder (from precomputed VAE/HNET fine-tunings, a VAE, or an HNET), use the script:
```
python scripts/train_hyperclip.py (--precompute_checkpoint <wandb id of precompute_adaptation.py run> | --vae_checkpoint <wandb id of train_vae.py run> | --few_shot_checkpoint <wandb id of train_few_shot.py run>) [...]
```

To train a hypernetwork latent diffusion model (HyperLDM), use the script:
```
python scripts/train_latent_diffusion.py (--precompute_checkpoint <wandb id of precompute_adaptation.py run> | --vae_checkpoint <wandb id of train_vae.py run> | --few_shot_checkpoint <wandb id of train_few_shot.py run>) [...]
```
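
Each script above takes the W&B run id of an earlier stage, so the stages have to be run in order. The sketch below encodes one such ordering: HNET, then precomputed adaptations, then the VAE, HyperCLIP, or HyperLDM. This ordering is inferred only from the checkpoint flags listed above. The helper is not part of the repository, and `stage` and `run` are hypothetical names. Setting `DRY_RUN=1` only prints each command. Each script's remaining options (the `[...]` above) are deliberately left out and must be appended before a real run.

```bash
# Hypothetical helper (not part of the repository): chains the training
# stages via W&B run ids. The options hidden behind `[...]` in the README
# are NOT included; append them before running for real.

# Print the command when DRY_RUN=1, otherwise execute it.
run() {
  if [ "${DRY_RUN:-0}" = 1 ]; then
    echo "$*"
  else
    "$@"
  fi
}

# Usage: stage <name> [<wandb run id of the stage it builds on>]
stage() {
  local name="$1" ckpt="${2:-}"
  # Every stage except the first needs the run id of an earlier stage.
  if [ "$name" != hnet ] && [ -z "$ckpt" ]; then
    echo "stage '$name' needs a W&B run id" >&2
    return 1
  fi
  case "$name" in
    hnet)       run python scripts/train_few_shot.py ;;
    precompute) run python scripts/precompute_adaptation.py --few_shot_checkpoint "$ckpt" ;;
    vae)        run python scripts/train_vae.py --precompute_checkpoint "$ckpt" ;;
    hyperclip)  run python scripts/train_hyperclip.py --precompute_checkpoint "$ckpt" ;;
    ldm)        run python scripts/train_latent_diffusion.py --precompute_checkpoint "$ckpt" ;;
    *)
      echo "unknown stage: $name" >&2
      return 1
      ;;
  esac
}
```

For example, `DRY_RUN=1 stage precompute abc123` prints `python scripts/precompute_adaptation.py --few_shot_checkpoint abc123`, where `abc123` stands in for a real run id. Routes that start from a trained VAE follow the same pattern with `--vae_checkpoint`.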
