# Retrieval-augmentation for In-context RL
This repository contains the source code for **"Retrieval-augmentation for In-context RL"**.

## Overview
This codebase supports training [Decision Transformer (DT)](https://arxiv.org/abs/2106.01345) models online or from offline datasets.

This codebase relies on open-source frameworks, including: 
- [PyTorch](https://github.com/pytorch/pytorch)
- [Huggingface transformers](https://github.com/huggingface/transformers)
- [stable-baselines3](https://github.com/DLR-RM/stable-baselines3)
- [wandb](https://github.com/wandb/wandb)
- [Hydra](https://github.com/facebookresearch/hydra)

## Installation
Environment configuration and dependencies are available in `environment.yaml` and `requirements.txt`.

First, create the conda environment.
```
conda env create -f environment.yaml
conda activate mddt
```

Then install the remaining requirements. This step expects MuJoCo to be downloaded already; if it is not, follow [MuJoCo installation](#mujoco-installation) first:
```
pip install -r requirements.txt
```

Initialize the `continual_world` submodule and install it:
```
git submodule update --init
cd continual_world
pip install .
cd ..
```
Install `meta-world`:
```
pip install git+https://github.com/rlworkgroup/metaworld.git@18118a28c06893da0f363786696cc792457b062b
```

Install our custom version of [dmc2gym](https://github.com/denisyarats/dmc2gym). It makes `flatten_obs` optional,
which allows us to construct the full observation space of all DMControl envs.
```
cd dmc2gym_custom
pip install -e .
```

### MuJoCo installation
Download MuJoCo:
```
mkdir ~/.mujoco
cd ~/.mujoco
wget https://www.roboti.us/download/mujoco200_linux.zip
unzip mujoco200_linux.zip
mv mujoco200_linux mujoco200
wget https://www.roboti.us/file/mjkey.txt
```
Then add the following line to `~/.bashrc`:
```
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:~/.mujoco/mujoco200/bin
```
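As a quick sanity check of the steps above, the following sketch tests that the MuJoCo files and the library path are in place. `check_mujoco` is a hypothetical helper written for this README, not part of the repository, and it assumes the `~/.mujoco/mujoco200` layout created above.

```shell
# Sketch: verify the MuJoCo 2.0 layout produced by the steps above.
# check_mujoco is a hypothetical helper, not part of this repository.
check_mujoco() {
    dir="$1"   # e.g. $HOME/.mujoco/mujoco200
    if [ ! -d "$dir/bin" ]; then
        echo "missing directory: $dir/bin"
        return 1
    fi
    if [ ! -f "$(dirname "$dir")/mjkey.txt" ]; then
        echo "missing license key: $(dirname "$dir")/mjkey.txt"
        return 1
    fi
    # Surround the variable with colons so the first and last entries match too.
    case ":${LD_LIBRARY_PATH:-}:" in
        *":$dir/bin:"*) echo "ok" ;;
        *) echo "LD_LIBRARY_PATH does not contain $dir/bin"; return 1 ;;
    esac
}

check_mujoco "$HOME/.mujoco/mujoco200" || true
```

If the last check fails right after editing `~/.bashrc`, open a new shell or run `source ~/.bashrc` so the export takes effect.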

#### Troubleshooting on cluster (without root access)
The following issues were helpful: 
- https://github.com/openai/mujoco-py/issues/96#issuecomment-678429159
- https://github.com/openai/mujoco-py/issues/627#issuecomment-1383054926
- https://github.com/openai/mujoco-py/issues/323#issuecomment-618365770

First, install the following packages: 
```
conda install -c conda-forge glew mesalib
conda install -c menpo glfw3 osmesa
pip install patchelf
```
Create the symlink manually: 
- https://github.com/openai/mujoco-py/issues/763#issuecomment-1519090452 
```
cp /usr/lib64/libGL.so.1 $CONDA_PREFIX/lib
ln -s $CONDA_PREFIX/lib/libGL.so.1 $CONDA_PREFIX/lib/libGL.so
```
Then download the `libgcrypt11` package and extract it locally:
```
mkdir ~/rpm
cd ~/rpm
curl -o libgcrypt11.rpm ftp://ftp.pbone.net/mirror/ftp5.gwdg.de/pub/opensuse/repositories/home:/bosconovic:/branches:/home:/elimat:/lsi/openSUSE_Leap_15.1/x86_64/libgcrypt11-1.5.4-lp151.23.29.x86_64.rpm
rpm2cpio libgcrypt11.rpm | cpio -id
```
Finally, export the path to the extracted `rpm` directory (add these lines to `~/.bashrc`):
```
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:~/rpm/usr/lib64
export LDFLAGS="-L$HOME/rpm/usr/lib64"
```

## Setup

### Experiment configuration
This codebase relies on [Hydra](https://github.com/facebookresearch/hydra), which configures experiments via `.yaml` files. 
Hydra automatically creates the log folder structure for a given run, as specified in the respective `config.yaml` file.

The `config.yaml` file is the main configuration entry point and contains the default parameters. It references the respective default parameter files under the `defaults` block.
In addition, `config.yaml` contains four constants that configure the directory paths:
```
LOG_DIR: ../logs
DATA_DIR: ../data
SSD_DATA_DIR: ../data
MODELS_DIR: ../models
```
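Hydra accepts `key=value` overrides on the command line, so these paths can likely be changed per run without editing `config.yaml`. The sketch below assumes the four constants are top-level keys; `dir_overrides` and the `/tmp/radt` root are hypothetical and only serve as an illustration.

```shell
# Sketch: build overrides for the four directory constants from one root.
# dir_overrides is a hypothetical helper; it assumes the constants are
# top-level keys in config.yaml.
dir_overrides() {
    root="$1"   # directory that will hold logs/, data/ and models/
    printf 'LOG_DIR=%s/logs DATA_DIR=%s/data SSD_DATA_DIR=%s/data MODELS_DIR=%s/models\n' \
        "$root" "$root" "$root" "$root"
}

# Dry run: print the command; remove `echo` to launch it.
# The unquoted substitution relies on word splitting, so the root must not
# contain spaces.
echo python main.py $(dir_overrides /tmp/radt)
```

Printing the command first makes it easy to confirm the resolved paths before starting a long run.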

### Datasets
The generated datasets will be available upon publication.

## Running experiments
In the following, we provide some illustrative examples of how to run the experiments in the paper. 

### Dark-Room
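To run the Dark-Room 10x10 experiments with 3 seeds, execute one of the following commands. The `-m` flag enables Hydra's multirun mode, which launches one run per comma-separated seed value: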
```
# RA-DT 
python main.py -m seed=42,43,44 experiment_name=darkroom10x10_radt env_params=dark_room agent_params=radt_disc_icl run_params=finetune eval_params=pretrain_icl agent_params.load_path.file_name=dt_medium_64 +agent_params.reinit_policy=True +agent_params.cache_kwargs.norm=True +agent_params.query_dropout=0.2 +agent_params.cache_kwargs.sim_cutoff=0.98 +agent_params.cache_kwargs.deduplicate=True +agent_params.cache_kwargs.top_k=50 env_params.target_return=[90,5] +agent_params.cache_kwargs.use_gpu=True 

# RA-DT - domain-agnostic
python main.py -m seed=42,43,44 experiment_name=darkroom10x10_radt_lm env_params=dark_room agent_params=radt_disc_icl run_params=finetune eval_params=pretrain_icl agent_params.load_path=null +agent_params.cache_kwargs.norm=True +agent_params.query_dropout=0.2 +agent_params.cache_kwargs.sim_cutoff=0.98 +agent_params.cache_kwargs.deduplicate=True +agent_params.cache_kwargs.top_k=50 env_params.target_return=[90,5] +agent_params.cache_kwargs.use_gpu=True +agent_params/retriever_kwargs=discrete_s_r_rtg +agent_params.retriever_kwargs.beta=10
```
Similarly, we can run experiments for other grid sizes, e.g., for 20x20:
```
python main.py -m seed=42,43,44 experiment_name=darkroom20x20_radt env_params=dark_room_20x20 agent_params=radt_disc_icl agent_params/data_paths=dark_room_20x20_train run_params=finetune eval_params=pretrain_icl_grids agent_params.load_path='${MODELS_DIR}/minihack/darkroom_20x20/dt_medium_64.zip' +agent_params.reinit_policy=True +agent_params.cache_kwargs.norm=True +agent_params.query_dropout=0.2 +agent_params.cache_kwargs.sim_cutoff=0.98 +agent_params.cache_kwargs.deduplicate=True +agent_params.cache_kwargs.top_k=50 env_params.target_return=[370,10] +agent_params.cache_kwargs.use_gpu=True 
```
Similarly, for mazerunner 15x15:
```
python main.py -m seed=42,43,44 experiment_name=mazerunner_15x15_radt env_params=mazerunner agent_params=radt_disc_icl agent_params/data_paths=mazerunner15x15 run_params=finetune eval_params=pretrain_icl agent_params.load_path='${MODELS_DIR}/mazerunner_15x15/dt_medium_64.zip' +agent_params.reinit_policy=True +agent_params.cache_kwargs.norm=True +agent_params.query_dropout=0.2 +agent_params.cache_kwargs.sim_cutoff=0.98 +agent_params.cache_kwargs.deduplicate=True +agent_params.cache_kwargs.top_k=50 +agent_params.cache_kwargs.use_gpu=True +agent_params.eval_ret_steps=25
```
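The commands above repeat the same retrieval and cache overrides. As a readability aid, they can be collected in a small shell function; `radt_common_overrides` is a hypothetical helper (not part of the repository), and its values are copied from the commands above.

```shell
# Overrides shared by the RA-DT commands above (values copied from them).
# radt_common_overrides is a hypothetical helper, not part of this repository.
radt_common_overrides() {
    printf '%s ' \
        +agent_params.cache_kwargs.norm=True \
        +agent_params.query_dropout=0.2 \
        +agent_params.cache_kwargs.sim_cutoff=0.98 \
        +agent_params.cache_kwargs.deduplicate=True \
        +agent_params.cache_kwargs.top_k=50 \
        +agent_params.cache_kwargs.use_gpu=True
}

# Dry run of the first Dark-Room 10x10 command; remove `echo` to launch it.
echo python main.py -m seed=42,43,44 \
    experiment_name=darkroom10x10_radt \
    env_params=dark_room agent_params=radt_disc_icl \
    run_params=finetune eval_params=pretrain_icl \
    agent_params.load_path.file_name=dt_medium_64 \
    +agent_params.reinit_policy=True \
    'env_params.target_return=[90,5]' \
    $(radt_common_overrides)
```

The list-valued override is quoted here because some shells, such as zsh, treat unquoted square brackets as a glob pattern.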
