# CompoDiff: Versatile Composed Image Retrieval With Latent Diffusion

## Search and Image generation demo
We have set up a demo that can be tested in a local computing environment.

It can be executed with the following command:

```bash
# Change into the compodiff directory, then start the demo script.
$ cd compodiff
$ python demo_search.py
```

The demo will be hosted at https://0.0.0.0:8000

The unCLIP model used for image generation is from https://huggingface.co/kakaobrain/karlo-v1-alpha-image-variations.

### How to use demo
#### Usage 1. Project textual embeddings to visual embeddings
<img src=".github/example_t2i.gif" height="400">

#### Usage 2. Composed visual embeddings without mask for CIR
<img src=".github/example_cir.gif" height="400">

#### Usage 3. Composed visual embeddings with mask for CIR
<img src=".github/example_cir_with_mask.gif" height="400">

## 💡 Usage

### Build CompoDiff and CLIP models
```python
import compodiff
import torch
from PIL import Image
import requests

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# build_model() returns the CompoDiff model, the CLIP model, the image
# preprocessor, and the tokenizer, in that order.
compodiff_model, clip_model, img_preprocess, tokenizer = compodiff.build_model()

# Move both models to the selected device.
compodiff_model = compodiff_model.to(device)
clip_model = clip_model.to(device)

# On the GPU, run the CLIP model in half precision.
if device == "cuda":
    clip_model = clip_model.half()
```

### Usage 1. Project textual embeddings to visual embeddings
```python
# Scales passed to sample() as cond_scale, in (image, text) order.
# The image scale is 0.0 because this usage has no reference image.
cfg_image_scale = 0.0
cfg_text_scale = 7.5

cfg_scale = (cfg_image_scale, cfg_text_scale)

input_text = "owl carved on the wooden wall"
negative_text = "low quality"

# Tokenize the input text first.
text_token_dict = tokenizer(text=input_text, return_tensors='pt', padding='max_length', truncation=True)
text_tokens = text_token_dict['input_ids'].to(device)
text_attention_mask = text_token_dict['attention_mask'].to(device)

# Tokenize the negative text. Its attention mask must come from its own
# tokenizer output, not from the input text's.
negative_text_token_dict = tokenizer(text=negative_text, return_tensors='pt', padding='max_length', truncation=True)
negative_text_tokens = negative_text_token_dict['input_ids'].to(device)
negative_text_attention_mask = negative_text_token_dict['attention_mask'].to(device)

with torch.no_grad():
    # In the case of Usage 1, we do not use an image cond or a mask at all,
    # so both are passed as all-zero placeholders.
    image_cond = torch.zeros([1, 1, 768]).to(device)
    mask = torch.zeros([64, 64]).to(device).unsqueeze(0)

    text_cond = clip_model.encode_texts(text_tokens, text_attention_mask)
    negative_text_cond = clip_model.encode_texts(negative_text_tokens, negative_text_attention_mask)

    # Run the denoising steps here.
    timesteps = 10
    sampled_image_features = compodiff_model.sample(image_cond, text_cond, negative_text_cond, mask, timesteps=timesteps, cond_scale=cfg_scale, num_samples_per_batch=2)
    # NOTE: "sampled_image_features" is not L2-normalized
```

### Usage 2. Composed visual embeddings without mask for CIR
```python
# Scales passed to sample() as cond_scale, in (image, text) order.
cfg_image_scale = 1.5
cfg_text_scale = 7.5

cfg_scale = (cfg_image_scale, cfg_text_scale)

input_text = "as pencil sketch"
negative_text = "low quality"

# Tokenize the input text first.
text_token_dict = tokenizer(text=input_text, return_tensors='pt', padding='max_length', truncation=True)
text_tokens = text_token_dict['input_ids'].to(device)
text_attention_mask = text_token_dict['attention_mask'].to(device)

# Tokenize the negative text. Its attention mask must come from its own
# tokenizer output, not from the input text's.
negative_text_token_dict = tokenizer(text=negative_text, return_tensors='pt', padding='max_length', truncation=True)
negative_text_tokens = negative_text_token_dict['input_ids'].to(device)
negative_text_attention_mask = negative_text_token_dict['attention_mask'].to(device)

# Prepare a reference image.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw).resize((512, 512))

with torch.no_grad():
    processed_image = img_preprocess(image, return_tensors='pt')['pixel_values'].to(device)

    # In the case of Usage 2, we do not use a mask at all, so an all-zero
    # placeholder is passed instead.
    mask = torch.zeros([64, 64]).to(device).unsqueeze(0)

    image_cond = clip_model.encode_images(processed_image)

    text_cond = clip_model.encode_texts(text_tokens, text_attention_mask)
    negative_text_cond = clip_model.encode_texts(negative_text_tokens, negative_text_attention_mask)

    timesteps = 10
    sampled_image_features = compodiff_model.sample(image_cond, text_cond, negative_text_cond, mask, timesteps=timesteps, cond_scale=cfg_scale, num_samples_per_batch=2)

    # NOTE: If you want to apply more of the original image's context, increase
    # the source weight (the "source weight" in the demo's Advanced options)
    # from 0.1. This will convey the context of the original image as a strong signal.
    source_weight = 0.1
    # Blend the sampled features with the reference image embedding.
    sampled_image_features = (1 - source_weight) * sampled_image_features + source_weight * image_cond[0]
```

### Usage 3. Composed visual embeddings with mask for CIR
```python
# Scales passed to sample() as cond_scale, in (image, text) order.
cfg_image_scale = 1.5
cfg_text_scale = 7.5

cfg_scale = (cfg_image_scale, cfg_text_scale)

input_text = "as pencil sketch"
negative_text = "low quality"

# Tokenize the input text first.
text_token_dict = tokenizer(text=input_text, return_tensors='pt', padding='max_length', truncation=True)
text_tokens = text_token_dict['input_ids'].to(device)
text_attention_mask = text_token_dict['attention_mask'].to(device)

# Tokenize the negative text. Its attention mask must come from its own
# tokenizer output, not from the input text's.
negative_text_token_dict = tokenizer(text=negative_text, return_tensors='pt', padding='max_length', truncation=True)
negative_text_tokens = negative_text_token_dict['input_ids'].to(device)
negative_text_attention_mask = negative_text_token_dict['attention_mask'].to(device)

# Prepare a reference image.
image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(image_url, stream=True).raw).resize((512, 512))

# Prepare a mask image. "mask_url" is a placeholder: replace it with the URL
# of your own mask image.
mask_url = "mask_url"
mask_image = Image.open(requests.get(mask_url, stream=True).raw).resize((512, 512))

with torch.no_grad():
    processed_image = img_preprocess(image, return_tensors='pt')['pixel_values'].to(device)
    # The mask is preprocessed without normalization so that it can be
    # thresholded at 0.5 below; only its first channel is kept.
    processed_mask = img_preprocess(mask_image, do_normalize=False, return_tensors='pt')['pixel_values'].to(device)
    processed_mask = processed_mask[:, :1, :, :]

    # Zero out the pixels of the reference image where the mask exceeds 0.5.
    masked_processed_image = processed_image * (1 - (processed_mask > 0.5).float())

    # Resize the preprocessed mask tensor (not the PIL image) to 64x64, drop
    # the channel axis, and binarize it. The result has the same
    # (1, 64, 64) shape as the all-zero mask used when no mask is given.
    mask = torch.nn.functional.interpolate(processed_mask, size=(64, 64), mode='bilinear', align_corners=False)[:, 0, :, :]
    mask = (mask > 0.5).float()

    image_cond = clip_model.encode_images(masked_processed_image)

    text_cond = clip_model.encode_texts(text_tokens, text_attention_mask)
    negative_text_cond = clip_model.encode_texts(negative_text_tokens, negative_text_attention_mask)

    timesteps = 10
    sampled_image_features = compodiff_model.sample(image_cond, text_cond, negative_text_cond, mask, timesteps=timesteps, cond_scale=cfg_scale, num_samples_per_batch=2)

    # NOTE: If you want to apply more of the original image's context, increase
    # the source weight (the "source weight" in the demo's Advanced options)
    # from 0.05. This will convey the context of the original image as a strong signal.
    source_weight = 0.05
    # Blend the sampled features with the (masked) reference image embedding.
    sampled_image_features = (1 - source_weight) * sampled_image_features + source_weight * image_cond[0]
```

### Shoutout
The K-NN index used for the retrieval results is trained on the entire LAION-5B image set. You do not need to download any images for this retrieval; this is made possible thanks to the great work of [rom1504](https://github.com/rom1504/clip-retrieval).
