# ClimateGAN: Raising Climate Change Awareness by Generating Images of Floods

ICLR Paper ID: **Paper4226**

This document provides context and guidance for running a dummy training loop on a small sample of our data.

## Setup

Best tested with Python 3.8 on Linux and an NVIDIA GPU with at least 32 GB of memory.

1. `$ pip install -r requirements-3.8.txt`

2. We recommend setting up a [comet.ml](https://comet.ml) account and following the comet instructions below (comet is used to log losses, metrics and inferences)
    * OR add `args.no_comet=True` to the training command line (nothing is logged except a few prints)

## Data

Sample data is provided in `data/` and the code is adapted to use it. Only 4 samples are provided, as a proof of concept for the code's training and validation loops.

Samples were randomly selected (for transparency, the selection code is shared in `make_iclr_dataset.py`).

* modes: `train` or `val`
* domains
  * Masker: `r` for real, `s` for simulated
  * Painter: `rf` for real and flooded
* tasks
  * `x` -> input
  * Masker
    * `s` -> segmentation map as decoded PyTorch tensors
    * `d` -> depth map as image
    * `m` -> flood/ground mask as binary mask
  * Painter
    * `m` -> flood/ground mask as binary mask

(*N.B. some masks may appear blank: they are binary but encoded as uint8 arrays, so areas with pixel value 1 look almost as dark as areas with pixel value 0*)
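
To inspect such a mask visually, its values can be rescaled from {0, 1} to {0, 255}. A minimal sketch with NumPy; the helper `mask_to_viewable` is hypothetical and not part of the codebase:

```python
import numpy as np


def mask_to_viewable(mask: np.ndarray) -> np.ndarray:
    """Rescale a binary uint8 mask (values 0 or 1) to 0 or 255 for display.

    Hypothetical helper: any non-zero pixel is treated as part of the mask.
    """
    return np.where(mask > 0, 255, 0).astype(np.uint8)


# a tiny 2x3 mask as it might be stored on disk: 1 = flood/ground, 0 = background
mask = np.array([[0, 1, 1], [0, 0, 1]], dtype=np.uint8)
viewable = mask_to_viewable(mask)
```

In practice `mask` would come from reading one of the `m/*.png` files (for instance with `skimage.io.imread`), and `viewable` can then be saved or displayed with any image viewer.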

Example `json` data file:

```json
[
  {
    "x": "data/train/r/x/JR_503.png",
    "m": "data/train/r/m/JR_503.png",
    "d": "data/train/r/d/JR_503.png",
    "s": "data/train/r/s/JR_503.pt"
  },
  {
    "x": "data/train/r/x/gsv_001896.jpg",
    "m": "data/train/r/m/gsv_001896.png",
    "s": "data/train/r/s/gsv_001896.pt",
    "d": "data/train/r/d/gsv_001896.png"
  },
  ...
]
```
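
Such a file is a plain JSON list with one dictionary per sample, mapping each task key to a file path. A minimal sketch of how it could be loaded and sanity-checked with the standard library; the function `load_samples`, its `required` argument and the file name `train_r.json` are hypothetical, not the repository's own loader:

```python
import json
import tempfile
from pathlib import Path


def load_samples(json_path, required=("x",)):
    """Load a list of {task: path} dictionaries and check required task keys.

    Hypothetical helper: the actual data loader in the repository may differ.
    """
    with open(json_path) as f:
        samples = json.load(f)
    for i, sample in enumerate(samples):
        missing = [task for task in required if task not in sample]
        if missing:
            raise KeyError(f"sample {i} is missing tasks {missing}")
    return samples


# write a one-sample file mirroring the format above, then load it back
example = [
    {
        "x": "data/train/r/x/JR_503.png",
        "m": "data/train/r/m/JR_503.png",
        "d": "data/train/r/d/JR_503.png",
        "s": "data/train/r/s/JR_503.pt",
    }
]
with tempfile.TemporaryDirectory() as tmp:
    json_path = Path(tmp) / "train_r.json"  # hypothetical file name
    json_path.write_text(json.dumps(example))
    samples = load_samples(json_path, required=("x", "m", "s", "d"))
```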

## Run code

```bash
# train the Masker
$ python train.py tasks="['m', 's', 'd']"

# train the Painter
$ python train.py tasks="['p']"
```

or

```bash
# don't use comet logging
$ python train.py tasks=<PICK ONE> args.no_comet=True
```
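
The double quotes around `tasks=...` in the commands above are only shell quoting: they make the list reach `train.py` as a single argument. When launching runs from Python with `subprocess`, each argument is a separate list element and no extra quoting is needed. A minimal sketch; `build_train_command` is a hypothetical helper, not part of the repository:

```python
import sys


def build_train_command(tasks, no_comet=False):
    """Build the argv list for a training run (hypothetical helper).

    `tasks` is a list of task keys, e.g. ["m", "s", "d"] for the Masker
    or ["p"] for the Painter; it is rendered like the shell examples above.
    """
    task_list = ", ".join(f"'{t}'" for t in tasks)
    cmd = [sys.executable, "train.py", f"tasks=[{task_list}]"]
    if no_comet:
        cmd.append("args.no_comet=True")
    return cmd


masker_cmd = build_train_command(["m", "s", "d"])
painter_cmd = build_train_command(["p"], no_comet=True)
# to actually launch a run: subprocess.run(masker_cmd, check=True)
```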

## Coding conventions

* Tasks
  * `x` is an input image, in [-1, 1]
  * `s` is a segmentation target of class indices stored as `long` integers
  * `d` is a real-valued depth map target, encoded as inverse depth `1/depth`
  * `m` is a binary mask where `m > 0` indicates where water is or should be
* Domains
  * `r` is the *real* domain for the masker. Input images are real pictures of urban/suburban/rural areas
  * `s` is the *simulated* domain for the masker. Input images are taken from our Unity world
  * `rf` is the *real flooded* domain for the painter. Training images are pairs `(x, m)` of flooded scenes for which the water should be reconstructed. In the validation data, input images are not flooded and we provide a manually labeled mask `m`

* Flow
  * This describes the call stack for the `trainer`'s standard training procedure
  * `train()`
    * `run_epoch()`
      * `update_G()`
        * `zero_grad(G)`
        * `get_G_loss()`
          * `get_masker_loss()` -> total masker loss `L_{Masker}`
            * `masker_m_loss()`  -> masking loss `L_{mask}`
            * `masker_s_loss()`  -> segmentation loss `L_{seg}`
            * `masker_d_loss()`  -> depth estimation loss `L_{depth}`
          * `get_painter_loss()` -> painter's loss `L_{Painter}`
        * `g_loss.backward()`
        * `g_opt_step()`
      * `update_D()`
        * `zero_grad(D)`
        * `get_D_loss()`
          * painter's discriminator losses
          * `masker_m_loss()` -> flood-mask (M) AdvEnt discriminator loss
          * `masker_s_loss()` -> segmentation (S) AdvEnt discriminator loss
        * `d_loss.backward()`
        * `d_opt_step()`
      * `update_learning_rates()` -> update learning rates according to schedules defined in `opts.gen.opt` and `opts.dis.opt`
    * `run_validation()`
      * compute val losses
      * `eval_images()` -> compute metrics
      * `log_comet_images()` -> compute and upload inferences
    * `save()`
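
The *Tasks* conventions above can be illustrated with NumPy arrays standing in for tensors. This is a minimal sketch under stated assumptions, not the repository's transforms: it assumes 8-bit RGB inputs mapped linearly to [-1, 1] and strictly positive depth values, and all helper names are hypothetical.

```python
import numpy as np


def image_to_x(im_uint8: np.ndarray) -> np.ndarray:
    """Map an HxWx3 uint8 image in [0, 255] to a float array in [-1, 1]."""
    return im_uint8.astype(np.float32) / 127.5 - 1.0


def depth_to_d(depth: np.ndarray) -> np.ndarray:
    """Encode a strictly positive depth map as inverse depth (1 / depth)."""
    return 1.0 / depth.astype(np.float32)


def to_binary_mask(m: np.ndarray) -> np.ndarray:
    """Binary water mask: 1 where m > 0 (water is/should be), 0 elsewhere."""
    return (m > 0).astype(np.float32)


def to_segmentation_target(s: np.ndarray) -> np.ndarray:
    """Segmentation target as integer class indices (int64, i.e. torch `long`)."""
    return s.astype(np.int64)


im = np.array([[[0, 255, 51]]], dtype=np.uint8)  # a single RGB pixel
x = image_to_x(im)
d = depth_to_d(np.array([[2.0, 4.0]]))
m = to_binary_mask(np.array([[0, 1, 3]], dtype=np.uint8))
s = to_segmentation_target(np.array([[0.0, 2.0]]))
```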

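The *Flow* above alternates a generator update and a discriminator update for each batch, then updates learning rates, validates and saves. The framework-free skeleton below only mirrors that call order: method bodies are stubs and losses are placeholder numbers. The class `ToyTrainer`, its `calls` log, calling `update_learning_rates()` once per epoch and saving after every epoch are assumptions made for illustration, not the actual `Trainer` implementation.

```python
class ToyTrainer:
    """Stub trainer mirroring the documented call order (hypothetical)."""

    def __init__(self):
        self.calls = []  # records method names in the order they run

    def _log(self, name):
        self.calls.append(name)

    # generator side
    def get_masker_loss(self, batch):
        self._log("get_masker_loss")
        return 1.0 + 0.5 + 0.25  # placeholder for L_mask + L_seg + L_depth

    def get_painter_loss(self, batch):
        self._log("get_painter_loss")
        return 2.0  # placeholder for L_Painter

    def get_G_loss(self, batch):
        self._log("get_G_loss")
        return self.get_masker_loss(batch) + self.get_painter_loss(batch)

    def update_G(self, batch):
        self._log("update_G")
        # real code: zero_grad(G), get_G_loss(), g_loss.backward(), g_opt_step()
        return self.get_G_loss(batch)

    # discriminator side
    def update_D(self, batch):
        self._log("update_D")
        # real code: zero_grad(D), get_D_loss(), d_loss.backward(), d_opt_step()

    def update_learning_rates(self):
        self._log("update_learning_rates")

    def run_epoch(self, loader):
        self._log("run_epoch")
        for batch in loader:
            self.update_G(batch)
            self.update_D(batch)
        self.update_learning_rates()  # assumption: once per epoch

    def run_validation(self):
        self._log("run_validation")

    def save(self):
        self._log("save")

    def train(self, loader, n_epochs):
        for _ in range(n_epochs):
            self.run_epoch(loader)
            self.run_validation()
            self.save()


trainer = ToyTrainer()
trainer.train(loader=[{"x": None}], n_epochs=1)
```

After one epoch over a single batch, `trainer.calls` lists the stubs in the same nesting order as the call stack above.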
### Generator

* **Encoder**:

  `trainer.G.encoder` -> DeepLabv3-based encoder
  * Code adapted from https://github.com/CoinCheung/DeepLab-v3-plus-cityscapes

* **Decoders**:
  * `trainer.G.decoders["s"]` -> *Segmentation* -> DeepLabv3+ architecture (ASPP + Decoder)
  * `trainer.G.decoders["d"]` -> *Depth* -> Adapted from [DADA](https://github.com/valeoai/DADA)
  * `trainer.G.decoders["m"]` -> *Mask* -> SPADE Blocks (denormalize `z` conditionally on `d` and `s`) -> Binary mask: 1 = would be under water
    * `trainer.G.mask()` predicts a mask from either an `x` input or a `z` input, and optionally applies a `sigmoid`

* **Painter**: `trainer.G.painter` -> [GauGAN SPADE-based](https://github.com/NVlabs/SPADE)
  * input = masked image

* If `opts.gen.p.paste_original_content` is true, the painter only has to create water, not to reconstruct the content outside the mask: the output of `paint()` is `painted * m + x * (1 - m)`
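
With NumPy arrays standing in for tensors, this compositing step can be sketched as follows. `paste_original` is a hypothetical name for what `paint()` does when the option is enabled, and broadcasting a single-channel mask over the three image channels is an assumption about shapes:

```python
import numpy as np


def paste_original(painted: np.ndarray, x: np.ndarray, m: np.ndarray) -> np.ndarray:
    """Keep painted pixels inside the mask and original pixels outside it.

    painted, x: arrays of shape (3, H, W) in [-1, 1]
    m: mask of shape (1, H, W), 1 where water should be, broadcast over channels
    """
    return painted * m + x * (1 - m)


x = np.full((3, 1, 2), -1.0)  # original image: two black pixels
painted = np.full((3, 1, 2), 0.5)  # painter output everywhere
m = np.array([[[1.0, 0.0]]])  # only the left pixel is flooded
out = paste_original(painted, x, m)
```

Because `m` is binary, the output equals `x` exactly outside the mask, so whatever the painter produces there never reaches the final image.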

High-level methods of interest:

* `trainer.G.encode()` to compute the shared latent vector `z`
* `trainer.G.mask(x=x)` or `trainer.G.mask(z=z)` to infer the mask
* `trainer.G.paint(m, x)`, a higher-level function that takes care of masking
* `trainer.compute_flood(x)` to create a flood image from `x`, running the Masker and then the Painter under the hood
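
The composition performed by `trainer.compute_flood(x)` can be sketched with stand-in functions. Everything here is hypothetical: `toy_encode`, `toy_mask` and `toy_paint` replace the real networks, and the thresholding rule `m > bin_value` (applied only when `bin_value > 0`, as in the example script below) is an assumption about how binarization works.

```python
import numpy as np


def sigmoid(z: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-z))


def toy_encode(x: np.ndarray) -> np.ndarray:
    """Stand-in for trainer.G.encode(): here just the mean over channels."""
    return x.mean(axis=0, keepdims=True)


def toy_mask(z: np.ndarray) -> np.ndarray:
    """Stand-in for trainer.G.mask(z=z) followed by a sigmoid."""
    return sigmoid(-4.0 * z)  # toy rule: darker pixels get a higher flood probability


def toy_paint(m: np.ndarray, x: np.ndarray) -> np.ndarray:
    """Stand-in for trainer.G.paint(m, x): fill masked pixels with a flat value."""
    water = np.full_like(x, 0.2)
    return water * m + x * (1 - m)


def toy_compute_flood(x: np.ndarray, bin_value: float = 0.5) -> np.ndarray:
    """Masker then Painter, mirroring the order described for compute_flood."""
    z = toy_encode(x)
    m = toy_mask(z)
    if bin_value > 0:
        m = (m > bin_value).astype(x.dtype)
    return toy_paint(m, x)


x = np.array([[[-1.0, 1.0]]] * 3)  # shape (3, 1, 2): a dark pixel and a bright pixel
flooded = toy_compute_flood(x)
```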


## Logging on comet

Comet.ml looks for API keys in the following order: the argument to the `Experiment(api_key=...)` call, the `COMET_API_KEY` environment variable, a `.comet.config` file in the current working directory, then `.comet.config` in the current user's home directory.

If you're not managing several comet accounts at the same time, we recommend putting `.comet.config` in your `$HOME` as:

```
[comet]
api_key=<api_key>
workspace=<workspace>
```
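
The lookup order described above can be mimicked to check which key a given environment would pick up. This is a simplified, hypothetical sketch using only the standard library: `resolve_comet_api_key` is not a comet.ml function, and comet's own resolution handles more settings than `api_key`.

```python
import configparser
import os
import tempfile
from pathlib import Path


def resolve_comet_api_key(api_key=None, environ=None, cwd=None, home=None):
    """Return the first API key found, following the documented order."""
    environ = os.environ if environ is None else environ
    if api_key:  # 1. explicit Experiment(api_key=...) argument
        return api_key
    if environ.get("COMET_API_KEY"):  # 2. environment variable
        return environ["COMET_API_KEY"]
    # 3. ./.comet.config, then 4. ~/.comet.config
    for folder in (Path(cwd or Path.cwd()), Path(home or Path.home())):
        config_file = folder / ".comet.config"
        if config_file.is_file():
            parser = configparser.ConfigParser()
            parser.read(config_file)
            if parser.has_option("comet", "api_key"):
                return parser.get("comet", "api_key")
    return None


with tempfile.TemporaryDirectory() as cwd, tempfile.TemporaryDirectory() as home:
    (Path(home) / ".comet.config").write_text(
        "[comet]\napi_key=home-key\nworkspace=my-workspace\n"
    )
    from_home = resolve_comet_api_key(environ={}, cwd=cwd, home=home)
    from_env = resolve_comet_api_key(
        environ={"COMET_API_KEY": "env-key"}, cwd=cwd, home=home
    )
    from_arg = resolve_comet_api_key(api_key="arg-key", environ={}, cwd=cwd, home=home)
```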

## Figures

While this code sample is not sufficient to reproduce the exact figures of our paper (it would require a pretrained model and the test data set), we share the exact code we used to produce Figures 5, 6 and 7.

## Tests

We created tests to check the sanity of a basic epoch loop. Unfortunately, we could not make them easily runnable in the context of ICLR's review process, but we still provide the script for readers who want to have a look.

## Example Script

Inference flow, once a model has been trained:

```python
from pathlib import Path
from skimage.io import imsave
from tqdm import tqdm

from climategan.trainer import Trainer
from climategan.utils import find_images
from climategan.tutils import tensor_ims_to_np_uint8s
from climategan.transforms import PrepareInference


model_path = "some/path/to/output/folder"  # the run's output folder, not the .ckpt file
input_folder = "path/to/a/folder/with/images"
output_path = "path/where/images/will/be/written"

# resume trainer
trainer = Trainer.resume_from_path(model_path, new_exp=None, inference=True)

# find paths for all images in the input folder. There is a recursive option. 
im_paths = sorted(find_images(input_folder), key=lambda x: x.name)

# Load images into tensors 
#   * smaller side resized to 640 - keeping aspect ratio
#   * then longer side is cropped in the center
#   * result is a 1x3x640x640 float tensor in [-1; 1]
xs = PrepareInference()(im_paths)

# send to device
xs = [x.to(trainer.device) for x in xs]

# compute flood
#   * compute mask
#   * binarize mask if bin_value > 0
#   * paint x using this mask
ys = [trainer.compute_flood(x, bin_value=0.5) for x in tqdm(xs)]

# convert 1x3x640x640 float tensors in [-1; 1] into 640x640x3 numpy arrays in [0, 255]
np_ys = [tensor_ims_to_np_uint8s(y) for y in tqdm(ys)]

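# write images (create the output folder if it does not exist yet)
output_dir = Path(output_path)
output_dir.mkdir(parents=True, exist_ok=True)
for im_path, np_y in tqdm(zip(im_paths, np_ys), total=len(im_paths)):
    imsave(output_dir / im_path.name, np_y)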
```
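
The geometry described in the script's comments (smaller side resized to 640, then a center crop of the longer side) and the final conversion from [-1, 1] to [0, 255] can be checked in isolation. This is a hedged sketch: `resize_and_crop_box` and `to_uint8` are hypothetical helpers, and the exact rounding used by `PrepareInference` and `tensor_ims_to_np_uint8s` may differ.

```python
import numpy as np


def resize_and_crop_box(width: int, height: int, target: int = 640):
    """Return the resized size and the (left, top, right, bottom) center-crop box."""
    scale = target / min(width, height)
    new_w, new_h = round(width * scale), round(height * scale)
    left = (new_w - target) // 2
    top = (new_h - target) // 2
    return (new_w, new_h), (left, top, left + target, top + target)


def to_uint8(y: np.ndarray) -> np.ndarray:
    """Map a (1, 3, H, W) float array in [-1, 1] to an (H, W, 3) uint8 array."""
    im = (y[0].transpose(1, 2, 0) + 1.0) / 2.0 * 255.0
    return np.clip(np.round(im), 0, 255).astype(np.uint8)


size, box = resize_and_crop_box(1920, 1080)  # a landscape 1920x1080 photo
y = np.zeros((1, 3, 2, 2))
y[0, 0] = 1.0  # red channel at its maximum, other channels at mid-gray
im = to_uint8(y)
```

For a 1920x1080 input, the smaller side becomes 640, the image is resized to 1138x640, and the crop keeps the central 640 columns.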
