# Diffusion World Models (DWM)

## Setup


- Create a [conda](https://docs.conda.io/projects/miniconda/en/latest/miniconda-install.html) environment (or any other Python environment)

```sh
conda create -n dwm python=3.10
conda activate dwm
```

- Install the requirements

```sh
pip install -r requirements.txt
AutoROM -y
```

- Warning: `AutoROM -y` downloads the Atari ROMs and accepts their license on your behalf. By running it, you acknowledge that you have the license to use them.


## Play Breakout in a Diffusion World Model


The best way to compare the world models (DWM and IRIS) is to **play them yourself**!

```sh
conda activate dwm
cd playable_wm
python src/play.py
```

Alternatively, watch the recorded videos of DWM and IRIS:

- Breakout:
    - `./videos/atari_breakout_dwm.mp4`
    - `./videos/atari_breakout_iris.mp4`
- Asterix:
    - `./videos/atari_asterix_dwm.mp4`
    - `./videos/atari_asterix_iris.mp4`


**Commands**

```
↑/↓ : switch environment (dwm/iris/real)
←/→ : switch control (policy/human)
 ⏎  : reset environment
```

## Videos

The folder `./videos` contains video samples (slowed down to 4 fps for easier viewing) of the methods compared in the paper.

Main methods to compare:

**DWM**

- `csgo_dwm_framestack_T=20_gen_real.avi`
- `drive_dwm_crossattention_T=20_gen_real.avi`

**IRIS**

- `csgo_IRIS_16token_gen_real.avi`
- `drive_IRIS_16token_gen_real.avi`

**DreamerV3**

- `csgo_dreamerv3_gen_real.avi`
- `drive_dreamerv3_gen_real.avi`

For the remaining methods, see the extended results in Table 3 of the paper.


## Launch a training run


```bash
python src/main.py env.train.id=BreakoutNoFrameskip-v4 common.device=cuda:0
```

By default, logs are synced to [Weights & Biases](https://wandb.ai); add `wandb.mode=disabled` to the command to turn this off.
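To launch several runs from a script (for example, one per Atari game), the Hydra overrides can be assembled programmatically. The sketch below is a minimal illustration that uses only the three override keys appearing in this README (`env.train.id`, `common.device`, `wandb.mode`). The helper `build_train_command`, the `--launch` flag and the Asterix environment id (which assumes the same `NoFrameskip-v4` naming as Breakout) are hypothetical and not part of the repository.

```python
import shlex
import subprocess
import sys


def build_train_command(
    env_id: str, device: str = "cuda:0", use_wandb: bool = True
) -> list[str]:
    """Return the training command as an argument list (hypothetical helper).

    Only the Hydra overrides shown in this README are used:
    env.train.id, common.device and wandb.mode.
    """
    cmd = [
        sys.executable,  # interpreter of the active (e.g. `dwm`) environment
        "src/main.py",
        f"env.train.id={env_id}",
        f"common.device={device}",
    ]
    if not use_wandb:
        cmd.append("wandb.mode=disabled")  # turn off Weights & Biases syncing
    return cmd


if __name__ == "__main__":
    launch = "--launch" in sys.argv[1:]
    for game in ["BreakoutNoFrameskip-v4", "AsterixNoFrameskip-v4"]:
        cmd = build_train_command(game, use_wandb=False)
        print(shlex.join(cmd))  # show the equivalent shell command
        if launch:
            # Runs one after the other; execute from the repository root.
            subprocess.run(cmd, check=True)
```

Saved as, say, `launch_runs.py` in the repository root, `python launch_runs.py` only prints the commands, and `python launch_runs.py --launch` runs them sequentially. Passing an argument list rather than a single string avoids shell quoting issues with values such as `cuda:0`.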


## Configuration


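- All configuration files are located in `config/`; the main configuration file is `config/trainer.yaml`.
- The simplest way to customize the configuration is to edit these files directly. Individual values can also be overridden on the command line with Hydra's `key=value` syntax (e.g. `common.device=cuda:0`).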
- Please refer to [Hydra](https://github.com/facebookresearch/hydra) for more details regarding configuration management.


## Run folder


Each new run is stored in `outputs/YYYY-MM-DD/hh-mm-ss/`, which is structured as follows:

```txt
outputs/YYYY-MM-DD/hh-mm-ss/
│
└─── checkpoints
│   │
│   └─── all
│   │   │   agent_epoch_00000.pt
│   │   │   ...
│   │   last.pt
│
└─── config
│   │   trainer.yaml
│
└─── dataset
│   │
│   └─── train
│   │   │   ...
│   │
│   └─── test
│   │   │   ...
│
└─── scripts
│   │   resume.sh
│
└─── src
│   │   ...
│
└─── wandb
    │   ...
```
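To pick a per-epoch snapshot from `checkpoints/all/` programmatically, the epoch number can be parsed from the file name. The sketch below is a minimal illustration that assumes the `agent_epoch_<number>.pt` naming shown in the tree. The helper `latest_epoch_checkpoint` is hypothetical, and since the contents of the `.pt` files are not described here, loading them is left out.

```python
import re
from pathlib import Path
from typing import Optional

# Per-epoch checkpoints are named like "agent_epoch_00000.pt".
EPOCH_CKPT = re.compile(r"agent_epoch_(\d+)\.pt")


def latest_epoch_checkpoint(run_dir: Path) -> Optional[Path]:
    """Return the file in checkpoints/all/ with the highest epoch number, or None."""
    ckpt_dir = Path(run_dir) / "checkpoints" / "all"
    if not ckpt_dir.is_dir():
        return None
    best_epoch, best_path = -1, None
    for path in ckpt_dir.iterdir():
        match = EPOCH_CKPT.fullmatch(path.name)
        if match is None:
            continue  # ignore files that do not follow the naming scheme
        epoch = int(match.group(1))  # int() drops the zero padding
        if epoch > best_epoch:
            best_epoch, best_path = epoch, path
    return best_path
```

Comparing parsed integers rather than sorting file names keeps the result correct even if the epoch count ever outgrows the zero-padded width.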


## Visualize


To visualize a run, execute the following from inside its run folder:

```sh
python src/play.py
```
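When several runs have accumulated, the newest one can be found from the folder names alone, because the zero-padded `YYYY-MM-DD/hh-mm-ss` names sort chronologically as strings. The sketch below is a minimal illustration. `latest_run` is a hypothetical helper, and the script assumes it is executed from the directory that contains `outputs/`. It only launches `src/play.py` if that file exists in the selected run folder.

```python
import re
import subprocess
import sys
from pathlib import Path
from typing import Optional

DATE = re.compile(r"\d{4}-\d{2}-\d{2}")  # YYYY-MM-DD
TIME = re.compile(r"\d{2}-\d{2}-\d{2}")  # hh-mm-ss


def latest_run(outputs: Path = Path("outputs")) -> Optional[Path]:
    """Return the newest outputs/YYYY-MM-DD/hh-mm-ss/ run folder, or None."""
    runs = [
        p
        for p in Path(outputs).glob("*/*")
        if p.is_dir() and DATE.fullmatch(p.parent.name) and TIME.fullmatch(p.name)
    ]
    # Zero-padded names sort chronologically as plain strings.
    return max(runs, key=lambda p: (p.parent.name, p.name), default=None)


if __name__ == "__main__":
    run_dir = latest_run()
    if run_dir is not None and (run_dir / "src" / "play.py").is_file():
        # Equivalent to: cd <run_dir> && python src/play.py
        subprocess.run([sys.executable, "src/play.py"], cwd=run_dir, check=True)
    else:
        print("No run folder with src/play.py found under outputs/")
```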
