# Selective Amnesia (SA) for Stable Diffusion
This is the official repository for Selective Amnesia for Stable Diffusion. The code in this project
is modified from the official [Stable Diffusion](https://github.com/CompVis/stable-diffusion) repository.

# Requirements 
Install requirements using a `conda` environment:
```
conda env create -f environment.yaml
conda activate sa-sd
```

## Download SD v1.4 Checkpoint
We use the SD v1.4 checkpoint (with EMA), `sd-v1-4-full-ema.ckpt`. Either download it from the official HuggingFace [link](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original) and move it to the root directory of this project, or download it directly with `wget`:
```
wget https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4-full-ema.ckpt
```
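
Before starting the steps below, it can help to confirm that the checkpoint really is in the project root and is not a truncated or mistaken download. The following is a minimal sketch, not part of this repository: the helper name `check_checkpoint` is hypothetical, and the 1 GB lower bound is an assumed threshold chosen only because a full model checkpoint is several gigabytes.

```python
from pathlib import Path

# Assumed lower bound for a plausible checkpoint size (hypothetical threshold).
MIN_EXPECTED_BYTES = 1_000_000_000


def check_checkpoint(path="sd-v1-4-full-ema.ckpt", min_bytes=MIN_EXPECTED_BYTES):
    """Return "ok" if the file exists and is at least `min_bytes` large."""
    ckpt = Path(path)
    if not ckpt.is_file():
        return f"missing: {ckpt}"
    size = ckpt.stat().st_size
    if size < min_bytes:
        # A very small file usually means the download failed or saved a web page.
        return f"too small ({size} bytes): {ckpt}"
    return "ok"


if __name__ == "__main__":
    # Run from the project root, where the commands below expect the checkpoint.
    print(check_checkpoint())
```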

# Forgetting Training with SA
We consider forgetting the celebrity Brad Pitt in the following steps.

1. Generate samples for the Fisher information matrix (FIM) and generative replay (GR).

    The prompts for generating the FIM/GR dataset are given in `fim_prompts.txt`. They were generated automatically with GPT-3.5. To generate the samples from the prompts using SD v1.4, run
    ```
    python scripts/txt2img_fim_prompts.py --ckpt sd-v1-4-full-ema.ckpt --from-file fim_prompts.txt --outdir fim_dataset
    ```
    The images will be stored in the folder `fim_dataset`.

2. Calculate the FIM.
    ```
    python main_fim.py -t --base configs/stable-diffusion/fim.yaml --gpus "0,1,2,3" --num_nodes 1 --finetune_from sd-v1-4-full-ema.ckpt --n_chunks 20
    ```
    Set the `--gpus` flag to the IDs of all the GPUs that you intend to use. Increase `--n_chunks` if you run out of VRAM. This should produce multiple `fisher_dict_rank_[GPU RANK].pkl` files, one for each GPU. Run
    ```
    python combine_fisher_dict.py
    ```
    to combine them into a single dictionary file `full_fisher_dict.pkl`.

3. Generate the surrogate dataset $q(x|c_f)$, represented here by images of "a middle aged man":
    ```
    python scripts/txt2img_make_n_samples.py --outdir q_dist/middle_aged_man_dataset --prompt "a middle aged man" --n_samples 1000
    ```

4. Run forgetting training with SA.
    ```
    python main_forget.py -t --base configs/stable-diffusion/forget_brad_pitt.yaml --gpus "0,1,2,3" --num_nodes 1 --finetune_from sd-v1-4-full-ema.ckpt
    ```
    The results and checkpoint are saved in `logs`.

    ## Edit Config File
    You can edit the config files in `configs/stable-diffusion`. Parameters that you should pay attention to have accompanying comments in the config file. Here are some notable ones.

    $\lambda$ and the layers to train can be modified in lines 18 and 19:
    ```
    ...
    lmbda: 50 # change FIM weighting term here
    train_method: 'full' # choices: ['full', 'xattn', 'noxattn']
    ...
    ```
    $c_f$ and the path to the surrogate distribution can be modified in lines 85 and 86:
    ```
    ...
    forget_prompt: brad pitt
    forget_dataset_path: ./q_dist/middle_aged_man_dataset
    ...
    ```
    The number of training epochs can be adjusted in line 133:
    ```
    ...
    max_epochs: 200
    ...
    ```
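
To make the role of the FIM and `lmbda` more concrete, the sketch below shows one way the per-GPU Fisher dictionaries from step 2 could be merged, and how the merged values could then weight an elastic-weight-consolidation style penalty of the form $\lambda \sum_i \frac{F_i}{2}(\theta_i - \theta_i^*)^2$. It is an illustration under stated assumptions, not the code in `combine_fisher_dict.py` or `main_forget.py`: the function names are hypothetical, averaging across ranks is an assumed merge rule, and plain floats stand in for the parameter tensors.

```python
import pickle
from pathlib import Path


def load_rank_dicts(folder="."):
    """Load every fisher_dict_rank_*.pkl file found in `folder`."""
    dicts = []
    for path in sorted(Path(folder).glob("fisher_dict_rank_*.pkl")):
        with open(path, "rb") as f:
            dicts.append(pickle.load(f))
    return dicts


def combine_fisher_dicts(rank_dicts):
    """Average each parameter's Fisher value across ranks (assumed merge rule)."""
    totals = {}
    for fisher in rank_dicts:
        for name, value in fisher.items():
            totals[name] = totals.get(name, 0.0) + value
    return {name: total / len(rank_dicts) for name, total in totals.items()}


def ewc_penalty(params, ref_params, fisher, lmbda):
    """Compute lmbda * sum_i F_i / 2 * (theta_i - theta*_i)^2.

    `params` are the current weights, `ref_params` the pretrained weights;
    both are dicts of floats here, standing in for tensors.
    """
    return lmbda * sum(
        0.5 * fisher[name] * (params[name] - ref_params[name]) ** 2
        for name in params
    )


if __name__ == "__main__":
    # Two hypothetical per-rank dictionaries, as if produced on two GPUs.
    merged = combine_fisher_dicts([{"w": 2.0, "b": 0.0}, {"w": 4.0, "b": 2.0}])
    current = {"w": 1.0, "b": 0.0}
    pretrained = {"w": 0.0, "b": 0.0}
    print(ewc_penalty(current, pretrained, merged, lmbda=50))  # prints 75.0
```

In this form, a larger `lmbda` penalizes movement away from the pretrained weights more strongly, most of all for parameters with large Fisher values.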

# COMING SOON: NudeNet and GIPHY Celebrity Detector