# gcontrol

Guidance control for diffusion-based adversarial image generation.

## Description

A lightweight extension of the [`diffusers`](https://github.com/huggingface/diffusers/tree/main) library that provides
options for controlling the diffusion guidance method. The package was built to support adversarial diffusion
guidance, i.e., using the diffusion process to generate adversarial samples.

## Installation
Installation requires the HuggingFace 
[`accelerate`](https://github.com/huggingface/accelerate),
[`diffusers`](https://github.com/huggingface/diffusers),
[`transformers`](https://github.com/huggingface/transformers),
[`timm`](https://github.com/huggingface/pytorch-image-models), and
[`datasets`](https://github.com/huggingface/datasets) packages. 
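
To see which of these packages are already present in the current environment, and at which versions, a small standard-library check along the following lines may help. It is a sketch rather than part of `gcontrol`: the `installed_versions` helper and the `HF_PACKAGES` list are hypothetical names, and the list assumes the packages are installed under their usual PyPI distribution names.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed PyPI distribution names for the dependencies listed above.
HF_PACKAGES = ["accelerate", "diffusers", "transformers", "timm", "datasets"]


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    report = {}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = None
    return report


if __name__ == "__main__":
    for name, found in installed_versions(HF_PACKAGES).items():
        print(f"{name:<14}{found or 'not installed'}")
```

Running the same check after either install option below shows which versions `pip` actually resolved.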

### Standard Install (Recommended)
Installs `gcontrol` and all required dependencies:

```
pip install .
```

### Frozen HuggingFace Install
Installs `gcontrol` and the last known working builds of
[`accelerate`](https://github.com/huggingface/accelerate),
[`diffusers`](https://github.com/huggingface/diffusers),
[`transformers`](https://github.com/huggingface/transformers),
[`timm`](https://github.com/huggingface/pytorch-image-models), and
[`datasets`](https://github.com/huggingface/datasets). This
option is recommended only if future updates to the HuggingFace packages break
`gcontrol`:

```
pip install ".[frozen_huggingface]"
```

## Testing

Unit tests can be run via:
```
pytest tests
```

To run the longer tests, first specify the cache directories in `tests/conftest.py`, then run:
```
pytest tests --runslow
```

## Getting Started

### Quickstart

Running classifier-free guidance:
```
import torch
from diffusers.schedulers import DDIMScheduler
from gcontrol.diffusion_pipelines import GCStableDiffusionPipeline
from gcontrol.guidance_controllers.common import ClassifierFreeGuidance

# Initialising guidance controller
guidance_controller = ClassifierFreeGuidance()

# Initialising diffusion pipeline
model_id = "sd-legacy/stable-diffusion-v1-5"
pipe = GCStableDiffusionPipeline.from_pretrained(
    model_id, 
    torch_dtype=torch.bfloat16,
    cache_dir = "",  # set this to your local model cache directory
    use_safetensors = True,
    guidance_controller = guidance_controller
)

# Specifying DDIM scheduler
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

pipe = pipe.to("cuda")

# Running classifier-free guidance
prompt = "tiger, Panthera tigris"
image = pipe(
    prompt,
    guidance_scale = 7.5,
    num_images_per_prompt=1,
    num_inference_steps = 100, 
    height = 512,
    width = 512,
    output_type = "pil",
)
image.images[0]
```

Running adversarial guidance against ResNet50 for target adversarial class `346` 
(water buffalo, water ox, Asiatic buffalo, Bubalus bubalis):
```
import torch
import timm
from diffusers.schedulers import DDIMScheduler
from gcontrol.diffusion_pipelines import GCStableDiffusionPipeline
from gcontrol.guidance_controllers.stable_diffusion import AdversarialClassifierGuidance
from gcontrol.utils import get_timm_config

# Loading timm model
resnet_model = timm.create_model(
    model_name = "resnet50", 
    pretrained = True
).eval().to(dtype = torch.bfloat16, device = "cuda")
resnet_config = get_timm_config(resnet_model)

# Initialising guidance controller
guidance_controller = AdversarialClassifierGuidance(resnet_model, **resnet_config)

# Initialising diffusion pipeline
model_id = "sd-legacy/stable-diffusion-v1-5"
pipe = GCStableDiffusionPipeline.from_pretrained(
    model_id, 
    torch_dtype=torch.bfloat16,
    use_safetensors = True,
    guidance_controller = guidance_controller
)

# Specifying DDIM scheduler
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

pipe = pipe.to("cuda")

# Running adversarial diffusion for target class 346 (water buffalo, water ox, Asiatic buffalo, Bubalus bubalis)
prompt = "tiger, Panthera tigris"
image = pipe(
    prompt,
    num_images_per_prompt=1,
    gprompt = "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
    num_inference_steps = 100, 
    height = 512,
    width = 512,
    output_type = "pil",
    target_idx = 346,
    g_w = 7.5,
    g_p = 2,
    g_m = 0.5,
    g_s = 4,
    time_travel_sample=5,
)
image.images[0]
```
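
To judge whether a generated image is adversarial in the intended sense, the classifier's output on that image can be compared against `target_idx`. The sketch below assumes the logits have already been obtained as a flat list of floats (for example by preprocessing the image according to `resnet_config`, running `resnet_model` on it, and converting the result to a Python list; that step is omitted because it depends on the classifier's input pipeline). The `softmax` and `attack_report` helpers are hypothetical and not part of `gcontrol`, and the logits used here are toy values, not real model output.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a flat sequence of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def attack_report(logits, target_idx, k=5):
    """Summarise whether the classifier's top-1 prediction is the target class."""
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return {
        "predicted_idx": ranked[0],
        "target_prob": probs[target_idx],
        "top_k": ranked[:k],
        "success": ranked[0] == target_idx,
    }


# Toy logits standing in for a 1000-class ImageNet output (not real model output).
toy_logits = [0.0] * 1000
toy_logits[346] = 6.0  # water buffalo, the target class from the example above
toy_logits[292] = 4.0  # ImageNet-1k index for "tiger, Panthera tigris"

report = attack_report(toy_logits, target_idx=346)
print(report["success"], report["top_k"][:2])
```

Reporting `target_prob` alongside the top-1 index makes it easier to compare runs with different guidance settings, since it shows how close a failed attempt came to the target class.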

## Authors

REDACTED


