# Official Implementation of NEUBAY

## Setup

We tested our codebase on L40S and A100 GPUs. The code is written in JAX and [Equinox](https://github.com/patrick-kidger/equinox). To install the dependencies:
```bash
conda create -n neubay python=3.10
conda activate neubay
conda env update -f requirements.yml
```

## Pretrained World Ensembles

To support reproducibility and simplify agent tuning, we provide pretrained world-model ensembles in an anonymous Google Drive folder:

[Download pretrained ensembles](https://drive.google.com/drive/folders/1akQsbyeKqO4z5mn9vyvroSP16lagOXo8?usp=sharing)

After downloading, place the folder under `offline_world/ckpt`. Due to storage limitations, we currently provide one ensemble checkpoint per dataset. We will release all six checkpoints in the camera-ready version.

## Commands

All training logs are uploaded to wandb. We recommend setting `export XLA_PYTHON_CLIENT_PREALLOCATE=false` to prevent JAX from preallocating most of the GPU memory at startup.
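As a minimal sketch, the recommended setting can be exported once per shell session before launching any of the commands below. The commented-out `WANDB_MODE` line is an optional extra that is not part of this repository's instructions; it relies on wandb's standard environment variable and assumes the scripts initialize wandb in the usual way.

```bash
# Stop JAX from preallocating GPU memory at startup (the setting recommended above).
export XLA_PYTHON_CLIENT_PREALLOCATE=false

# Optional assumption: uncomment to keep wandb logs on local disk instead of uploading them.
# export WANDB_MODE=offline
```

The variable only affects processes started from the same shell, so set it again in each new terminal or job script.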

### Offline-Bandit Datasets

Offline data for the bandit is provided in `offline_world/data`. To generate your own dataset, run `python offline_world/bandit_data.py`. Then train the reward ensemble and the agent:
```bash
# train your own reward ensemble; skip this if you use the pretrained one
python offline_world/bandit_ensemble.py --seed 0
# train agent on the reward ensemble with the same seed
python offline_bandit.py seed=0
```

### D4RL and NeoRL Continuous Control Datasets

First, download all datasets by running `python get_all_datasets.py`.

Train the world-model ensemble (skip this step if you use the pretrained checkpoints):
```bash
# d4rl locomotion
python offline_world/cont_ensemble.py --config-path=../configs/d4rl_loco --config-name=base dataset_name=halfcheetah-medium-expert-v2 ensemble.total_epochs=600 seed=0
# neorl locomotion
python offline_world/cont_ensemble.py --config-path=../configs/neorl --config-name=base dataset_name=Hopper-v3-low ensemble.total_epochs=1200 seed=0
# adroit
python offline_world/cont_ensemble.py --config-path=../configs/adroit --config-name=base dataset_name=pen-cloned-v1 ensemble.total_epochs=2400 seed=0
# antmaze
python offline_world/cont_ensemble.py --config-path=../configs/antmaze --config-name=base dataset_name=antmaze-umaze-v2 ensemble.total_epochs=1200 seed=0
```
Then train the history-dependent agent on the pretrained ensemble:
```bash
# d4rl locomotion
python offline_cont.py --config-path=configs/d4rl_loco --config-name=base dataset_name=halfcheetah-medium-expert-v2 seed=0
# neorl locomotion
python offline_cont.py --config-path=configs/neorl --config-name=base dataset_name=Hopper-v3-low seed=0
# adroit
python offline_cont.py --config-path=configs/adroit --config-name=base dataset_name=pen-cloned-v1 seed=0
# antmaze
python offline_cont.py --config-path=configs/antmaze --config-name=base dataset_name=antmaze-umaze-v2 seed=0
```
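To repeat the two steps above over several seeds, the commands can be wrapped in a shell loop. The sketch below is not part of the released scripts: the seed values `0 1 2` are illustrative, `DRY_RUN` and `run` are hypothetical helpers, and it assumes (as in the bandit example) that the agent should be trained on the ensemble with the same seed. By default it only prints the commands.

```bash
# Sketch: sweep seeds for one dataset, pairing each ensemble with an agent of the same seed.
# DRY_RUN is a hypothetical helper variable; it defaults to 1 so commands are only printed.
DRY_RUN="${DRY_RUN:-1}"
DATASET="halfcheetah-medium-expert-v2"
n_runs=0

run() {
  # Print the command, and execute it only when DRY_RUN=0.
  echo "+ $*"
  if [ "$DRY_RUN" = "0" ]; then
    "$@"
  fi
}

for seed in 0 1 2; do
  # Step 1: train the world-model ensemble (skip if using pretrained checkpoints).
  run python offline_world/cont_ensemble.py --config-path=../configs/d4rl_loco --config-name=base \
    dataset_name="$DATASET" ensemble.total_epochs=600 seed="$seed"
  # Step 2: train the history-dependent agent on that ensemble.
  run python offline_cont.py --config-path=configs/d4rl_loco --config-name=base \
    dataset_name="$DATASET" seed="$seed"
  n_runs=$((n_runs + 1))
done
```

Setting `DRY_RUN=0` in the environment executes the commands instead of printing them. For other benchmarks, swap in the config path, dataset name, and `ensemble.total_epochs` value from the corresponding commands above.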
For the ablation study with a Markov agent:
```bash
# d4rl locomotion
python offline_markov.py --config-path=configs/d4rl_loco --config-name=base_markov dataset_name=halfcheetah-medium-expert-v2 seed=0
# neorl locomotion
python offline_markov.py --config-path=configs/neorl --config-name=base_markov dataset_name=Hopper-v3-low seed=0
# adroit
python offline_markov.py --config-path=configs/adroit --config-name=base_markov dataset_name=pen-cloned-v1 seed=0
# antmaze
python offline_markov.py --config-path=configs/antmaze --config-name=base_markov dataset_name=antmaze-umaze-v2 seed=0
```

### Evaluation on Compounding Errors

```bash
# collect rollouts: trained on hc-random and evaluated on hc-medium-replay
python eval_error.py --config-path=configs/d4rl_loco --config-name=base dataset_name=halfcheetah-random-v2 +eval_dataset=medium-replay-v2 collect.unc_quantile=-1.0
# plot rollout stats
python plot_ln.py --train_dataset halfcheetah-random-v2 --eval_dataset medium-replay-v2
```


## Acknowledgements

- https://github.com/patrick-kidger/equinox
- https://github.com/proroklab/memoroids
- https://github.com/yihaosun1124/OfflineRL-Kit
- https://github.com/Howuhh/sac-n-jax
- https://github.com/kwanyoungpark/LEQ
