# TWISTED: Enhancing Transformer World Models with Spatio-Temporal Encoding and Graph-Based Optimal Decoding

## Setup

Use Python 3.10. Install the necessary packages from `requirements.txt`:
```
pip install -r requirements.txt
```

## Configuration

Configuration files are stored in the `config/` directory. You can create your own configuration file or pass additional configs as command-line arguments directly to the scripts below. Available options are listed in `configs.py`.

By default, runs are logged to Weights & Biases (W&B), which requires logging in first (for example with `wandb login`). You can disable logging with `--wandb_config.enable=False`.
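
How command-line overrides are merged into a config is defined in `configs.py`. As an illustration only, the sketch below assumes configs behave like nested dictionaries and shows how a dotted flag such as `--wandb_config.enable=False` could map onto a nested key. `apply_override` is a hypothetical helper, not part of this repository.

```python
import ast
from typing import Any


def apply_override(config: dict, override: str) -> None:
    """Apply a dotted `--key.subkey=value` override to a nested dict in place.

    Hypothetical helper for illustration; the repository's own parsing lives in configs.py.
    """
    key_path, sep, raw_value = override.lstrip("-").partition("=")
    if not key_path or not sep:
        raise ValueError(f"expected --key=value, got {override!r}")
    keys = key_path.split(".")
    try:
        # Parse Python literals such as False, 1 or 0.5; keep anything else as a string.
        value: Any = ast.literal_eval(raw_value)
    except (ValueError, SyntaxError):
        value = raw_value
    # Walk (and create if needed) the intermediate dictionaries.
    node = config
    for key in keys[:-1]:
        node = node.setdefault(key, {})
    node[keys[-1]] = value


if __name__ == "__main__":
    cfg = {"wandb_config": {"enable": True}, "seed": 0}
    apply_override(cfg, "--wandb_config.enable=False")
    apply_override(cfg, "--seed=1")
    print(cfg)
```

Values are parsed as Python literals (`False`, `1`), so a YAML-style spelling such as `false` would stay a string in this sketch.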

## Policy training

Launch a training run with:

```
python train.py --config_path={config_path} [--{additional_configs} ...]
```

Example:
```
python train.py --config_path=config/minatar_asterix_twisted.yaml --seed=1
```

### Configs
* TWISTED: `config/minatar_{game}_twisted.yaml`
* Baseline (reproduced): `config/minatar_{game}_baseline.yaml`
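
To compare TWISTED against the baseline you will usually want several seeds per game. The sketch below is a hypothetical helper (`build_train_commands` is not part of the repository) that expands the config naming pattern above into one `train.py` call per variant and seed. Only `asterix` appears explicitly in this README; check `config/` for the game names that actually have configs.

```python
import shlex
from typing import Iterable

# Config variants listed above, following config/minatar_{game}_{variant}.yaml.
VARIANTS = ("twisted", "baseline")


def build_train_commands(
    game: str, seeds: Iterable[int], variants: Iterable[str] = VARIANTS
) -> list[list[str]]:
    """Return one train.py argument list per (variant, seed) pair (hypothetical helper)."""
    seeds = list(seeds)  # materialize so a generator is not exhausted after the first variant
    commands = []
    for variant in variants:
        config_path = f"config/minatar_{game}_{variant}.yaml"
        for seed in seeds:
            commands.append(
                ["python", "train.py", f"--config_path={config_path}", f"--seed={seed}"]
            )
    return commands


if __name__ == "__main__":
    # Dry run: only print the commands.
    for cmd in build_train_commands("asterix", seeds=range(3)):
        print(shlex.join(cmd))
```

The script prints a dry run; replacing `print` with `subprocess.run(cmd, check=True)` would launch the runs one after another.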

## Evaluation

You can evaluate the policy stored in a checkpoint with `evaluate_minatar.py`. Set `restore_ckpt_path` (required) and, optionally, `restore_ckpt_step` to select the checkpoint. Pass the same configuration that was used to train the checkpoint. To save the evaluation results, edit the `filename` variable (line 63 of `evaluate_minatar.py`) so that it points to your desired output file.

```
python evaluate_minatar.py --config_path={config_path} --restore_ckpt_path={checkpoint_path} --restore_ckpt_step={checkpoint_step}
```
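
Because `restore_ckpt_step` is optional, a wrapper should only pass it when a step is actually given. The sketch below, built around a hypothetical `build_eval_command` helper and a placeholder checkpoint path, assembles the command above under that rule; when the step is omitted, the script falls back to whatever default it defines.

```python
import shlex
from typing import Optional


def build_eval_command(
    config_path: str, ckpt_path: str, ckpt_step: Optional[int] = None
) -> list[str]:
    """Assemble an evaluate_minatar.py call (hypothetical helper).

    restore_ckpt_step is appended only when a step is given.
    """
    if not ckpt_path:
        # restore_ckpt_path is required by the evaluation script.
        raise ValueError("restore_ckpt_path is required")
    cmd = [
        "python",
        "evaluate_minatar.py",
        f"--config_path={config_path}",
        f"--restore_ckpt_path={ckpt_path}",
    ]
    if ckpt_step is not None:
        cmd.append(f"--restore_ckpt_step={ckpt_step}")
    return cmd


if __name__ == "__main__":
    # "checkpoints/asterix_twisted" is a placeholder, not a path created by this repository.
    cmd = build_eval_command("config/minatar_asterix_twisted.yaml", "checkpoints/asterix_twisted")
    print(shlex.join(cmd))
```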

## Graphs

You can plot returns with `compare_minatar_return.py`. First, set the paths to your result files inside the script.

Then run:
```
python compare_minatar_return.py
```
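
This README does not specify the result format `compare_minatar_return.py` expects or how it aggregates seeds. The sketch below assumes, hypothetically, that each seed's results reduce to a list of returns per evaluation step, and computes the per-step mean and standard error across seeds, a common way to summarize such curves. `aggregate_returns` is not part of the repository.

```python
import math
import statistics
from typing import Sequence


def aggregate_returns(runs: Sequence[Sequence[float]]) -> tuple[list[float], list[float]]:
    """Per-step mean and standard error of the mean across seeds (hypothetical helper)."""
    if len(runs) < 2:
        raise ValueError("need at least two seeds to estimate the standard error")
    length = min(len(run) for run in runs)  # truncate all runs to the shortest one
    means, stderrs = [], []
    for step in range(length):
        values = [run[step] for run in runs]
        means.append(statistics.fmean(values))
        # Sample standard deviation divided by sqrt(number of seeds).
        stderrs.append(statistics.stdev(values) / math.sqrt(len(values)))
    return means, stderrs


if __name__ == "__main__":
    # Toy numbers for illustration only, not results from the paper.
    means, stderrs = aggregate_returns([[0.0, 2.0], [2.0, 4.0]])
    print(means, stderrs)
```

The resulting curves could then be drawn with matplotlib's `plot` for the mean and `fill_between` for the mean plus or minus the standard error.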
