# A Large Recurrent Action Model
This repository contains the source code for **"A Large Recurrent Action Model: xLSTM Enables Fast Inference for Robotics Tasks"**.

## Overview
This codebase supports training [Decision Transformer (DT)](https://arxiv.org/abs/2106.01345) models online or from offline datasets.

This codebase relies on open-source frameworks, including: 
- [PyTorch](https://github.com/pytorch/pytorch)
- [Huggingface transformers](https://github.com/huggingface/transformers)
- [stable-baselines3](https://github.com/DLR-RM/stable-baselines3)
- [wandb](https://github.com/wandb/wandb)
- [Hydra](https://github.com/facebookresearch/hydra)

## Installation
Environment configuration and dependencies are available in `environment.yaml` and `requirements.txt`.

First, create the conda environment.
```
conda env create -f environment.yaml
conda activate lram
```

Then install the remaining requirements. This step assumes that MuJoCo is already downloaded; if it is not, see [MuJoCo installation](#mujoco-installation) first:
```
pip install -r requirements.txt
```

Initialize the `continualworld` submodule and install it:
```
cd continual_world
pip install .
```
Install `meta-world`:
```
pip install git+https://github.com/rlworkgroup/metaworld.git@18118a28c06893da0f363786696cc792457b062b
```

Install our custom version of [dmc2gym](https://github.com/denisyarats/dmc2gym). It makes `flatten_obs` optional,
which allows us to construct the full observation space of all DMControl envs.
```
cd dmc2gym_custom
pip install -e .
```
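
Once the steps above have finished, a quick import check can confirm that the packages are visible in the active environment. The sketch below is not part of the repository. The module names in `EXPECTED_MODULES` are assumptions about how each package is exposed and may need adjusting.

```python
import importlib.util

# Assumed import names (not confirmed by this README); adjust as needed.
EXPECTED_MODULES = ["torch", "mujoco_py", "metaworld", "continualworld", "dmc2gym"]


def check_modules(names):
    """Return {module_name: bool}, True if the module can be located."""
    status = {}
    for name in names:
        try:
            status[name] = importlib.util.find_spec(name) is not None
        except (ImportError, ValueError):
            # find_spec can raise for broken or partially installed packages.
            status[name] = False
    return status


if __name__ == "__main__":
    for name, found in check_modules(EXPECTED_MODULES).items():
        print(f"{name:16s} {'ok' if found else 'MISSING'}")
```

A `MISSING` entry only means that the module could not be located under that name; it does not say which installation step failed.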

### MuJoCo installation
Download MuJoCo:
```
mkdir ~/.mujoco
cd ~/.mujoco
wget https://www.roboti.us/download/mujoco200_linux.zip
unzip mujoco200_linux.zip
mv mujoco200_linux mujoco200
wget https://www.roboti.us/file/mjkey.txt
```
Then add the following line to `~/.bashrc`:
```
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:~/.mujoco/mujoco200/bin
```
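
To confirm that the export took effect in a new shell, a small check along the following lines may help. This is a sketch and not part of the repository; the function name `mujoco_bin_on_library_path` is hypothetical, and the check only looks at the path used in the commands above.

```python
import os
from pathlib import Path


def _expand(entry, home):
    """Expand a leading "~" in case the entry was stored unexpanded."""
    if entry == "~" or entry.startswith("~/"):
        return Path(str(home) + entry[1:])
    return Path(entry)


def mujoco_bin_on_library_path(env=None, home=None):
    """Return True if <home>/.mujoco/mujoco200/bin is listed in LD_LIBRARY_PATH."""
    env = os.environ if env is None else env
    home = Path.home() if home is None else Path(home)
    target = home / ".mujoco" / "mujoco200" / "bin"
    entries = [e for e in env.get("LD_LIBRARY_PATH", "").split(":") if e]
    return any(_expand(e, home) == target for e in entries)


if __name__ == "__main__":
    bin_dir = Path.home() / ".mujoco" / "mujoco200" / "bin"
    print("bin directory exists:", bin_dir.is_dir())
    print("listed in LD_LIBRARY_PATH:", mujoco_bin_on_library_path())
```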

#### Troubleshooting on cluster (without root access)
The following issues were helpful: 
- https://github.com/openai/mujoco-py/issues/96#issuecomment-678429159
- https://github.com/openai/mujoco-py/issues/627#issuecomment-1383054926
- https://github.com/openai/mujoco-py/issues/323#issuecomment-618365770

First, install the following packages: 
```
conda install -c conda-forge glew mesalib
conda install -c menpo glfw3 osmesa
pip install patchelf
```
Copy `libGL.so.1` into the conda environment and create the `libGL.so` symlink manually (see [this comment](https://github.com/openai/mujoco-py/issues/763#issuecomment-1519090452)):
```
cp /usr/lib64/libGL.so.1 $CONDA_PREFIX/lib
ln -s $CONDA_PREFIX/lib/libGL.so.1 $CONDA_PREFIX/lib/libGL.so
```
Then download and extract the `libgcrypt11` RPM:
```
mkdir ~/rpm
cd ~/rpm
curl -o libgcrypt11.rpm ftp://ftp.pbone.net/mirror/ftp5.gwdg.de/pub/opensuse/repositories/home:/bosconovic:/branches:/home:/elimat:/lsi/openSUSE_Leap_15.1/x86_64/libgcrypt11-1.5.4-lp151.23.29.x86_64.rpm
rpm2cpio libgcrypt11.rpm | cpio -id
```
Finally, export the path to the `rpm` directory (add these lines to `~/.bashrc`):
```
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:~/rpm/usr/lib64
export LDFLAGS="-L$HOME/rpm/usr/lib64"
```

## Setup

### Experiment configuration
This codebase relies on [Hydra](https://github.com/facebookresearch/hydra), which configures experiments via `.yaml` files. 
Hydra automatically creates the log folder structure for a given run, as specified in the respective `config.yaml` file.

`config.yaml` is the main configuration entry point and contains the default parameters. Its `defaults` block references the respective default parameter files.
In addition, `config.yaml` contains four constants that configure the directory paths:
```
LOG_DIR: ../logs
DATA_DIR: ../data
SSD_DATA_DIR: ../data
MODELS_DIR: ../models
```
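
All four defaults are relative paths, so where they end up depends on the directory they are resolved against. The sketch below shows how they resolve against a given base directory. It is an illustration only: `resolve_dirs` is a hypothetical helper, and the base directory the codebase actually uses is an assumption here, since Hydra may change the working directory depending on its settings.

```python
from pathlib import Path

# Values copied from the defaults shown above.
DEFAULT_DIRS = {
    "LOG_DIR": "../logs",
    "DATA_DIR": "../data",
    "SSD_DATA_DIR": "../data",
    "MODELS_DIR": "../models",
}


def resolve_dirs(base_dir, dirs=DEFAULT_DIRS):
    """Resolve each relative constant against base_dir (hypothetical helper)."""
    base = Path(base_dir)
    return {name: (base / rel).resolve() for name, rel in dirs.items()}


if __name__ == "__main__":
    # Assumption: the constants are resolved relative to the current directory.
    for name, path in resolve_dirs(Path.cwd()).items():
        print(f"{name:14s} {path}")
```

With the defaults, `DATA_DIR` and `SSD_DATA_DIR` resolve to the same location, and all directories sit next to the base directory rather than inside it.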

### Datasets
The generated datasets will be available upon publication.

## Running experiments
In the following, we provide illustrative examples of how to run the experiments in the paper. 

### Multi-Domain Training
To run the experiments for the 16M models on a server with 4 A100 GPUs, use:
```
# xlstm - m 
python main.py -m +hydra/launcher=torchrun hydra.launcher.nproc_per_node=4 +ddp=True seed=0 experiment_name=16M_xlstm_m_v1 env_params=mt_dmc_procgen_atari_cs_mg agent_params=rga agent_params.kind=MDDXLSTM agent_params/model_kwargs=multi_domain_mtdmcpgcs agent_params/data_paths=mt45v2_dmc11_pg12_atari41_cs240_mg83 run_params=pretrain_icl eval_params=pretrain_rga +agent_params/replay_buffer_kwargs=multi_domain_mtdmccs +agent_params.accumulation_steps=6 agent_params/huggingface=xlstm_medium agent_params.huggingface.xlstm_config.context_length=150 agent_params.batch_size=32 +eval_params.use_valid_callback=True +agent_params.replay_buffer_kwargs.p_valid=0.025 +eval_params.n_jobs=4

# xlstm - ms 
python main.py -m +hydra/launcher=torchrun hydra.launcher.nproc_per_node=4 +ddp=True seed=0 experiment_name=16M_xlstm_ms_v1 env_params=mt_dmc_procgen_atari_cs_mg agent_params=rga agent_params.kind=MDDXLSTM agent_params/model_kwargs=multi_domain_mtdmcpgcs agent_params/data_paths=mt45v2_dmc11_pg12_atari41_cs240_mg83 run_params=pretrain_icl eval_params=pretrain_rga +agent_params/replay_buffer_kwargs=multi_domain_mtdmccs +agent_params.accumulation_steps=6 agent_params/huggingface=xlstm_medium agent_params.huggingface.xlstm_config.context_length=150 agent_params.batch_size=32 +agent_params.huggingface.xlstm_config.slstm_at='[1]' +eval_params.use_valid_callback=True +agent_params.replay_buffer_kwargs.p_valid=0.025 +eval_params.n_jobs=4 

# DT
python main.py -m +hydra/launcher=torchrun hydra.launcher.nproc_per_node=4 +ddp=True seed=0 experiment_name=16M_dt_v1 env_params=mt_dmc_procgen_atari_cs_mg agent_params=rga agent_params.kind=MDDT agent_params/model_kwargs=multi_domain_mtdmcpgcs agent_params/data_paths=mt45v2_dmc11_pg12_atari41_cs240_mg83 run_params=pretrain_icl eval_params=pretrain_rga +agent_params/replay_buffer_kwargs=multi_domain_mtdmccs +agent_params.accumulation_steps=6 agent_params/huggingface=dt_medium_64 +agent_params.model_kwargs.global_pos_embds=True agent_params.batch_size=32 +eval_params.use_valid_callback=True +agent_params.replay_buffer_kwargs.p_valid=0.025 +eval_params.n_jobs=4

# mamba
python main.py -m +hydra/launcher=torchrun hydra.launcher.nproc_per_node=4 +ddp=True seed=0 experiment_name=16M_mamba_v1 env_params=mt_dmc_procgen_atari_cs_mg agent_params=rga agent_params.kind=MDDMamba agent_params/model_kwargs=multi_domain_mtdmcpgcs agent_params/data_paths=mt45v2_dmc11_pg12_atari41_cs240_mg83 run_params=pretrain_icl eval_params=pretrain_rga +agent_params/replay_buffer_kwargs=multi_domain_mtdmccs +agent_params.accumulation_steps=6 agent_params/huggingface=mamba_medium agent_params.compile=False agent_params.batch_size=32 +eval_params.use_valid_callback=True +agent_params.replay_buffer_kwargs.p_valid=0.025 +eval_params.n_jobs=4

```
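
The four commands above share most of their overrides and differ only in a few model-specific ones. The sketch below makes that split explicit by assembling each command from a shared part and a per-model part. It is not part of the repository: `build_command` and the variant keys are hypothetical names, all override strings are copied from the commands above, and the overrides may appear in a different order than in the originals, on the assumption that the order of overrides for distinct keys does not matter.

```python
import shlex

# Launcher and seed settings shared by all four commands.
HEAD = [
    "python", "main.py", "-m",
    "+hydra/launcher=torchrun", "hydra.launcher.nproc_per_node=4",
    "+ddp=True", "seed=0",
]

# Overrides that are identical across the four commands.
SHARED = [
    "env_params=mt_dmc_procgen_atari_cs_mg",
    "agent_params=rga",
    "agent_params/model_kwargs=multi_domain_mtdmcpgcs",
    "agent_params/data_paths=mt45v2_dmc11_pg12_atari41_cs240_mg83",
    "run_params=pretrain_icl",
    "eval_params=pretrain_rga",
    "+agent_params/replay_buffer_kwargs=multi_domain_mtdmccs",
    "+agent_params.accumulation_steps=6",
    "agent_params.batch_size=32",
    "+eval_params.use_valid_callback=True",
    "+agent_params.replay_buffer_kwargs.p_valid=0.025",
    "+eval_params.n_jobs=4",
]

# Model-specific overrides; the dictionary keys are hypothetical labels.
VARIANTS = {
    "xlstm_m": [
        "experiment_name=16M_xlstm_m_v1",
        "agent_params.kind=MDDXLSTM",
        "agent_params/huggingface=xlstm_medium",
        "agent_params.huggingface.xlstm_config.context_length=150",
    ],
    "xlstm_ms": [
        "experiment_name=16M_xlstm_ms_v1",
        "agent_params.kind=MDDXLSTM",
        "agent_params/huggingface=xlstm_medium",
        "agent_params.huggingface.xlstm_config.context_length=150",
        "+agent_params.huggingface.xlstm_config.slstm_at=[1]",
    ],
    "dt": [
        "experiment_name=16M_dt_v1",
        "agent_params.kind=MDDT",
        "agent_params/huggingface=dt_medium_64",
        "+agent_params.model_kwargs.global_pos_embds=True",
    ],
    "mamba": [
        "experiment_name=16M_mamba_v1",
        "agent_params.kind=MDDMamba",
        "agent_params/huggingface=mamba_medium",
        "agent_params.compile=False",
    ],
}


def build_command(variant):
    """Return the argument list for one model variant."""
    return HEAD + VARIANTS[variant] + SHARED


if __name__ == "__main__":
    # shlex.join quotes tokens such as "...slstm_at=[1]" for the shell.
    print(shlex.join(build_command("xlstm_ms")))
```

Read this way, the `xlstm_ms` variant differs from `xlstm_m` only in its experiment name and the additional `slstm_at` override.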