# SuSIE
Anonymous code for the paper *Zero-Shot Robotic Manipulation With Pretrained Image-Editing Diffusion Models*.

This repository contains the code for training the high-level image-editing diffusion model on video data.

- **Creating datasets**: We will add dataset instructions in our public version.
- **Installation**: Run `pip install -r requirements.txt` to install the package versions confirmed to work with this codebase, then run `pip install -e .`. The code has only been tested with Python 3.10. You also need to install JAX manually for your platform (see the [JAX installation instructions](https://jax.readthedocs.io/en/latest/installation.html)). Make sure you install the JAX version specified in `requirements.txt` rather than using `--upgrade` as the JAX docs suggest.
- **Training**: Once the missing dataset paths have been filled in inside `configs/base.py`, start training with `python scripts/train.py --config configs/base.py:base`.
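
To double-check the JAX pin mentioned under **Installation**, a small script along these lines may help. It is only a sketch: the helper names `pinned_version` and `check_installed` are hypothetical and not part of this repo, and it assumes `requirements.txt` pins JAX with `==`.

```python
import re
from importlib import metadata
from typing import Optional


def pinned_version(requirements_text: str, package: str) -> Optional[str]:
    """Return the version pinned with `==` for `package`, or None if there is no such pin."""
    # Allow optional extras such as `jax[cpu]==...`; do not match `jaxlib` when asked for `jax`.
    pattern = re.compile(
        rf"^\s*{re.escape(package)}(?:\[[^\]]*\])?\s*==\s*([^\s;#]+)",
        re.IGNORECASE,
    )
    for line in requirements_text.splitlines():
        match = pattern.match(line)
        if match:
            return match.group(1)
    return None


def check_installed(requirements_path: str = "requirements.txt", package: str = "jax") -> bool:
    """Compare the installed version of `package` with the pin in the requirements file."""
    with open(requirements_path, encoding="utf-8") as f:
        wanted = pinned_version(f.read(), package)
    try:
        installed = metadata.version(package)
    except metadata.PackageNotFoundError:
        installed = None
    print(f"{package}: pinned={wanted}, installed={installed}")
    return wanted is not None and wanted == installed
```

Calling `check_installed()` from the repository root after installation prints both versions and returns `True` only when they match exactly.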

## Model Weights
We will release the UNet weights for our best-performing model, trained on BridgeData and Something-Something, in our public version. They can be loaded with `FlaxUNet2DConditionModel.from_pretrained("MODEL_PATH", subfolder="unet")` and should be used with the standard Stable Diffusion v1-5 VAE and text encoder.

Here's a quickstart for getting out-of-the-box subgoals using this repo:
```python
from susie.model import create_sample_fn
from susie.jax_utils import initialize_compilation_cache
import requests
import numpy as np
from PIL import Image

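# set up the JAX compilation cache
initialize_compilation_cache()

# placeholder: replace with the URL of the image you want to edit
IMAGE_URL = "sample_image"

# placeholder: replace "MODEL_PATH" with the path to the model weights
sample_fn = create_sample_fn("MODEL_PATH")
# download the image, resize it to 256x256, and convert it to a NumPy array
image = np.array(Image.open(requests.get(IMAGE_URL, stream=True).raw).resize((256, 256)))
# generate a subgoal image conditioned on the current image and a language instruction
image_out = sample_fn(image, "open the drawer")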

# to display the images if you're in a Jupyter notebook
display(Image.fromarray(image))
display(Image.fromarray(image_out))
```
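
If the input image lives on disk rather than at a URL, the preprocessing step from the quickstart can be sketched as below. The helper `load_image_array` is hypothetical and not part of this repo. The conversion to RGB is an added assumption (that `sample_fn` expects a 256x256 three-channel `uint8` array, as the quickstart's resize suggests), not something the repo documents.

```python
import numpy as np
from PIL import Image


def load_image_array(source, size=(256, 256)):
    """Open an image from a path or file-like object and return a uint8 RGB array.

    Assumption: the sampler wants a (height, width, 3) uint8 array, so any
    alpha channel or palette is dropped by converting to RGB before resizing.
    """
    with Image.open(source) as img:
        resized = img.convert("RGB").resize(size)
        return np.asarray(resized, dtype=np.uint8)
```

The result can then be passed to the sampler in place of the downloaded image, for example `sample_fn(load_image_array("frame.png"), "open the drawer")`, where `frame.png` is a placeholder filename.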
