# WorldGPT

## Overview

We propose **WorldGPT**, a versatile world model that can predict state transitions across modalities, from any given combination of input modalities to any required combination of output modalities. WorldGPT is trained to harness its inherent textual knowledge and to integrate multimodal knowledge by watching millions of internet-sourced videos. We apply a novel **progressive state transition training** methodology, in which the training target evolves from a single modality to multiple modalities, and from unimodal to cross-modal transitions.
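
To make the idea of a progressively harder training target concrete, the sketch below lists the (input modalities, output modalities) pairs that a curriculum might allow at each stage. The three-stage split, the `stage_tasks` helper, and the restriction to image/video/audio are illustrative assumptions, not the actual schedule used to train WorldGPT.

```python
from itertools import combinations

# Hypothetical modality set; the annotation format also has a "text" field, omitted here.
MODALITIES = ("image", "video", "audio")


def nonempty_subsets(modalities):
    """All non-empty combinations of the given modalities."""
    return [
        frozenset(combo)
        for r in range(1, len(modalities) + 1)
        for combo in combinations(modalities, r)
    ]


def stage_tasks(stage):
    """Return the (inputs, outputs) modality pairs allowed at a hypothetical stage.

    Stage 1: a single modality in, the same single modality out.
    Stage 2: any modality combination in, the same combination out.
    Stage 3: cross-modal, any combination in, any combination out.
    """
    subsets = nonempty_subsets(MODALITIES)
    if stage == 1:
        return [(s, s) for s in subsets if len(s) == 1]
    if stage == 2:
        return [(s, s) for s in subsets]
    if stage == 3:
        return [(src, dst) for src in subsets for dst in subsets]
    raise ValueError(f"unknown stage: {stage}")


for stage in (1, 2, 3):
    print(stage, len(stage_tasks(stage)))
```

In this sketch each stage's task set contains the previous one (3, then 7, then 49 pairs), so the model never loses the simpler targets while harder ones are added.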

![images](figs/3_train.svg)

## Cognitive Architecture

We introduce a novel cognitive architecture tailored for world models. It consists of three parts: a knowledge retrieval system that provides external knowledge for specialized scenarios, a working memory mechanism that manages historical predictions, and a novel **ContextReflector** that efficiently extracts grounded information from the retrieved context (i.e., external knowledge and memory). To enable WorldGPT to cooperate with the cognitive architecture, we construct high-quality sequential samples and retrieval-augmented samples that teach WorldGPT to utilize information from the retrieved context through the **cognitive-augmented tuning** process.
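
As a rough illustration of how the first two components could assemble context for the ContextReflector, the sketch below keeps a bounded working memory of recent predictions and retrieves the top-k knowledge entries by cosine similarity. The `WorkingMemory` class, the `retrieve` and `build_context` functions, the memory capacity, and the similarity metric are all hypothetical; the ContextReflector itself is not modeled.

```python
import math
from collections import deque


class WorkingMemory:
    """Hypothetical bounded buffer holding the most recent predicted states."""

    def __init__(self, capacity=3):
        self.buffer = deque(maxlen=capacity)  # oldest entries are evicted first

    def add(self, state):
        self.buffer.append(state)

    def recall(self):
        return list(self.buffer)


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def retrieve(query, knowledge_base, k=2):
    """Return the k knowledge texts whose embeddings are closest to the query."""
    ranked = sorted(knowledge_base, key=lambda item: cosine(query, item[0]), reverse=True)
    return [text for _, text in ranked[:k]]


def build_context(query, memory, knowledge_base, k=2):
    """Bundle retrieved knowledge and memory; a ContextReflector would consume this."""
    return {"knowledge": retrieve(query, knowledge_base, k), "memory": memory.recall()}


memory = WorkingMemory(capacity=2)
for step in ("state_1", "state_2", "state_3"):
    memory.add(step)

# Toy 2-D "embeddings" paired with knowledge snippets.
knowledge = [
    ([1.0, 0.0], "heating melts ice"),
    ([0.0, 1.0], "proofing makes dough rise"),
    ([0.7, 0.7], "heating boils water"),
]
print(build_context([1.0, 0.1], memory, knowledge))
```

A fixed-size deque is the simplest way to bound the history length; a real system might instead summarize or score past predictions before discarding them.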

![images](figs/2_model.svg)

## Getting Started

### Installation

**1. Prepare Environment**

First, clone our repository and create a Python environment with the following commands:

```bash
git clone
cd WorldGPT
conda env create -f environment.yaml
conda activate worldgpt
```

**2. Prepare Pretrained Weights**

WorldGPT builds on the following existing models. Download the corresponding weights as described below:

* `Vicuna`:  WorldGPT employs `Vicuna V0 7B` as the language decoder. Prepare the full pretrained weights following the [official instructions](https://huggingface.co/lmsys/vicuna-7b-delta-v0). Then, set the variable *vicuna_path* in the [base config](config/base.yaml#L8) at Line 8.
* `LanguageBind` is the unified image/video/audio encoder. The model version for each modality is listed as follows:

  | image                                                                     | video                                                                                     | audio                                                                           |
  | ------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------- |
  | [LanguageBind_Image](https://huggingface.co/LanguageBind/LanguageBind_Image) | [LanguageBind_Video_V1.5_FT](https://huggingface.co/LanguageBind/LanguageBind_Video_V1.5_FT) | [LanguageBind_Audio_FT](https://huggingface.co/LanguageBind/LanguageBind_Audio_FT) |

  The weights are downloaded automatically by default. For manually downloaded weights or custom model versions, set the variable *languagebind_path* in the [base config](config/base.yaml#L13) at Lines 13-16.
* `Diffusion Models` are used to generate image/video/audio outputs (if generation is enabled). The model for each modality is listed as follows:

  | image                                                                       | video                                                                | audio                                                        |
  | --------------------------------------------------------------------------- | -------------------------------------------------------------------- | ------------------------------------------------------------ |
  | [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) | [zeroscope_v2_576w](https://huggingface.co/cerspense/zeroscope_v2_576w) | [audioldm-l-full](https://huggingface.co/cvssp/audioldm-l-full) |

  The weights are downloaded automatically by default. For manually downloaded weights or custom model versions, set the variable *decoder_path* in the [base config](config/base.yaml#L9) at Lines 9-12.

**3. Prepare WorldGPT Checkpoints**

Choose a pretrained checkpoint from the versions below:

| version                          | link                                                                                             |
| -------------------------------- | ------------------------------------------------------------------------------------------------ |
| worldgpt-languagebind-image      | [download](https://drive.google.com/drive/folders/1evhqNfhndxXRN0vx5xcbHkoesRptRLO-?usp=drive_link) |
| worldgpt-languagebind-multimodal | [download](https://drive.google.com/drive/folders/1T-Wp_BIwUjhqahVS1e0FCitsOnudLLTc?usp=drive_link) |
| worldgpt-decode-image            | [download](https://drive.google.com/drive/folders/13C8NZzW0FZmcguHZzAAw94k_3BRmzIdq?usp=drive_link) |
| worldgpt-decode-multimodal       | [download](https://drive.google.com/drive/folders/16Q5j4b7Ssj1t_H9B4T-XeURSYaECujgJ?usp=drive_link) |

Note that WorldGPT uses the worldgpt-languagebind checkpoints by default, which output LanguageBind embeddings. Although we provide worldgpt-decode checkpoints for visualization, they do not accurately reflect the true model capabilities.

### Training

**1. Prepare Datasets**

1.1. WorldNet

We collect state transition datasets from various sources and construct a comprehensive dataset named **WorldNet**. WorldNet consists of two subsets, WorldNet-Wild and WorldNet-Crafted, each of which is further divided by data source. The available modalities of each source are listed below. The WorldNet-Crafted subset can be downloaded [here](https://drive.google.com/drive/folders/1hntZ8Q4GQg5Esq2q5EBStvAEs_tqVxW_?usp=drive_link).

<table>
	<tr>
		<th>Subset</th>
		<th>Source</th>
		<th>Modality</th>
	</tr>
	<tr>
		<td rowspan="2">WorldNet-Wild</td>
		<td>YT-Temporal-180M</td>
		<td>image, video, audio</td>
	</tr>
	<tr>
		<td>HowTo100M</td>
		<td>image, video, audio</td>
	</tr>
	<tr>
		<td rowspan="5">WorldNet-Crafted</td>
		<td>Charades</td>
		<td>image, video, audio</td>
	</tr>
	<tr>
		<td>AVQA</td>
		<td>image, video, audio</td>
	</tr>
	<tr>
		<td>Ego4D</td>
		<td>image, video, audio</td>
	</tr>
	<tr>
		<td>Something-Something V2</td>
		<td>image</td>
	</tr>
	<tr>
		<td>YouCook2</td>
		<td>image, video</td>
	</tr>
</table>

The downloaded dataset contains only the raw data of each modality. For faster training, precompute the LanguageBind embeddings with the following command:

```bash
python preprocess.py --data_root path/to/subset/modality/Train --modality image/video/audio --languagebind_path path/to/languagebind/weights
```
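
Because the command processes one modality at a time, a small driver can loop over a subset's modalities. The sketch below reuses only the three flags shown above; the `preprocess_commands` and `run_all` helpers, the dry-run default, and the example paths are assumptions.

```python
import shlex
import subprocess


def preprocess_commands(subset_root, modalities, languagebind_path):
    """Build one preprocess.py invocation per modality of a subset."""
    return [
        [
            "python", "preprocess.py",
            "--data_root", f"{subset_root}/{modality}/Train",
            "--modality", modality,
            "--languagebind_path", languagebind_path,
        ]
        for modality in modalities
    ]


def run_all(commands, dry_run=True):
    """Print each command; execute it only when dry_run is False."""
    for cmd in commands:
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop on the first failing modality


# Hypothetical paths; YouCook2 provides image and video only.
run_all(preprocess_commands("worldnet/youcook", ["image", "video"], "weights/languagebind"))
```

Keeping `dry_run=True` as the default lets you inspect the generated commands before launching the (potentially long) embedding jobs.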

With the data prepared, organize each subset according to the structure below:

```
└── ag/avqa/ego4d/s2s/youcook
    └── image/video/audio
        ├── Train
        └── Train_pt
```

In each modality directory, `Train` contains the raw data and `Train_pt` contains the corresponding precomputed LanguageBind embeddings.
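
Before training, it can help to confirm that each subset matches this layout. The sketch below assumes only the `<subset>/<modality>/Train` and `Train_pt` directories described above; the `missing_dirs` helper and its `precomputed` switch (mirroring *precomputed_languagebind*) are hypothetical.

```python
from pathlib import Path


def missing_dirs(subset_root, modalities, precomputed=True):
    """List expected <modality>/Train (and Train_pt) directories that do not exist."""
    splits = ("Train", "Train_pt") if precomputed else ("Train",)
    return [
        str(Path(subset_root) / modality / split)
        for modality in modalities
        for split in splits
        if not (Path(subset_root) / modality / split).is_dir()
    ]


if __name__ == "__main__":
    # Hypothetical subset root; Something-Something V2 provides images only.
    print(missing_dirs("worldnet/s2s", ["image"]))
```

An empty list means the subset is laid out as expected; any listed path points to a directory that still needs to be created or populated.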

Finally, specify the dataset paths in the [training config](config/train.yaml) via the *dataset_list* variable. Here is an example:

```yaml
dataset_list:
  -
    root: worldnet/ag                                    # root of subset
    annotaion_path: worldnet/ag/state/action_train.json  # path to annotations
    modality: ['image', 'video', 'audio']                # available modalities of subset
    weight: 0.2                                          # probability of being sampled during training
  ...
```

Additionally, set the variable *precomputed_languagebind* to `True` if precomputed LanguageBind embeddings are available.
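
The *weight* field is described as the probability that a subset is chosen during training. Below is a minimal sketch of that sampling rule, assuming weights act as relative probabilities (Python's `random.choices` normalizes them, so they need not sum to 1). The second entry, its weight, and the `sample_subset` helper are made up for illustration, and this is not necessarily how `train.py` implements sampling.

```python
import random
from collections import Counter

# Mirrors the dataset_list example above; the s2s entry is hypothetical.
dataset_list = [
    {"root": "worldnet/ag", "modality": ["image", "video", "audio"], "weight": 0.2},
    {"root": "worldnet/s2s", "modality": ["image"], "weight": 0.8},
]


def sample_subset(datasets, rng):
    """Pick one subset with probability proportional to its weight."""
    weights = [d["weight"] for d in datasets]
    return rng.choices(datasets, weights=weights, k=1)[0]


rng = random.Random(0)  # fixed seed for reproducibility
counts = Counter(sample_subset(dataset_list, rng)["root"] for _ in range(10_000))
print(counts)
```

Over many draws, the share of samples from each subset approaches its normalized weight (about 20% and 80% here).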

1.2. Custom Dataset

To train on a custom dataset, first convert the annotations into the WorldNet format. Here is an example:

```json
[
  {
    "state0": {
      "video": "00001_state0.mp4",
      "audio": "00001_state0.wav",
      "image": "00001_state0.jpg",
      "text": ""
    },
    "state1": {
      "video": "00001_state1.mp4",
      "audio": "00001_state1.wav",
      "image": "00001_state1.jpg",
      "text": ""
    },
    "action": {
      "text": "adjust the mold"
    }
  },
  ...
]
```

The only supported formats are `.jpg` for images, `.mp4` for videos, and `.wav` for audio. Use an empty string to mark a missing modality. You can then follow the same steps as for WorldNet to train on your custom dataset.
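
When converting a custom dataset, a quick check against these rules can catch mistakes early. The `validate_entry` helper below is a hypothetical sketch that encodes only what is stated here: `state0`/`state1` entries, an `action` with a `text` field, the three allowed extensions, and empty strings for missing modalities.

```python
import json

# Allowed extension per modality, as listed above.
VALID_EXT = {"image": ".jpg", "video": ".mp4", "audio": ".wav"}


def validate_entry(entry):
    """Return a list of problems found in one WorldNet-format annotation entry."""
    errors = []
    for state_key in ("state0", "state1"):
        state = entry.get(state_key)
        if not isinstance(state, dict):
            errors.append(f"{state_key}: missing")
            continue
        for modality, ext in VALID_EXT.items():
            path = state.get(modality, "")
            # An empty string marks a missing modality and is allowed.
            if path and not path.endswith(ext):
                errors.append(f"{state_key}.{modality}: expected {ext}, got {path!r}")
    if "text" not in entry.get("action", {}):
        errors.append("action.text: missing")
    return errors


entry = {
    "state0": {"video": "00001_state0.mp4", "audio": "", "image": "00001_state0.jpg", "text": ""},
    "state1": {"video": "00001_state1.mp4", "audio": "", "image": "00001_state1.png", "text": ""},
    "action": {"text": "adjust the mold"},
}
print(json.dumps(validate_entry(entry), indent=2))
```

Here the `.png` image in `state1` is reported, while the empty `audio` fields pass because they mark missing modalities.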

**2. Start Training**

Before training, review the [base config](config/base.yaml), [training config](config/train.yaml) and [deepspeed config](config/ds_base.json) for detailed settings.

For a quick start, run the script:

```bash
bash scripts/train.sh
```

Alternatively, run the command directly:

```bash
deepspeed --master_addr 127.0.0.1 --master_port 28459 train.py \
    --cfg_path config/train.yaml \
    --save_path ckpt/worldgpt \
    --load_path /path/to/worldgpt-languagebind-ckpt
```

### Evaluation

**1. LanguageBind Evaluation**

Before evaluation, check the [evaluation config](config/validate.yaml) for evaluation settings.

For a quick start, run the script:

```bash
bash scripts/validate.sh
```

Alternatively, run the command directly:

```bash
deepspeed --master_addr 127.0.0.1 --master_port 28459 train.py \
    --eval_only --cfg_path config/validate.yaml \
    --log_path log/validate \
    --load_path /path/to/worldgpt-languagebind-checkpoint
```

**2. Visualized Inference**

We provide worldgpt-decode checkpoints for visualized inference. See [inference.ipynb](inference.ipynb) for details.
