# ReLax
Reinforcement learning for research, built with JAX, that does not compromise on performance.

## Installation

```bash
# Create environment
mamba create -n relax python=3.11 'numpy<2' tqdm tensorboardX matplotlib scikit-learn black snakeviz ipykernel ipdb setproctitle numba
conda activate relax
# Either: install JAX with CUDA 12 support (uses the locally installed CUDA)
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Or: install JAX for CPU only
pip install --upgrade "jax[cpu]"
# Install package
pip install -r requirements.txt
pip install -e .
```
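To confirm which JAX build ended up in the environment, a quick check like the one below can help. It is a sketch, not part of ReLax, and `jax_install_report` is a hypothetical name. With the CUDA install the backend should read `gpu`; with the CPU-only install it reads `cpu`.

```python
import importlib.util


def jax_install_report() -> str:
    """Summarize the active JAX backend without failing if JAX is missing."""
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed in this environment"
    import jax

    # "gpu" for a working CUDA install, "cpu" for the CPU-only install.
    backend = jax.default_backend()
    devices = ", ".join(str(d) for d in jax.devices())
    return f"jax {jax.__version__}: backend={backend}, devices=[{devices}]"


if __name__ == "__main__":
    print(jax_install_report())
```

If a CUDA install unexpectedly reports `cpu`, JAX did not find a usable GPU and fell back to the CPU.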

To install `mujoco_py`, add the following Cython directive (needed when building with Cython 3)
```python
# cython: language_level=3, legacy_implicit_noexcept=True
```
to the top of `mujoco_py/cymj.pyx`.
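The patch can also be applied with a short script. The sketch below prepends the directive only if it is not already there. It assumes `mujoco_py` is already installed (for example with the `pip install mujoco-py==2.0.2.13 ...` command below) and that the installed `cymj.pyx` is patched before `mujoco_py` is first imported. `with_directive` and `patch_cymj` are hypothetical helpers, not part of ReLax or `mujoco_py`.

```python
import importlib.util
from pathlib import Path

# Directive from the instructions above; Cython requires it before any code.
DIRECTIVE = "# cython: language_level=3, legacy_implicit_noexcept=True"


def with_directive(source: str) -> str:
    """Return `source` with DIRECTIVE as its first line (idempotent)."""
    if source.startswith(DIRECTIVE):
        return source
    return DIRECTIVE + "\n" + source


def patch_cymj() -> Path:
    """Prepend DIRECTIVE to the installed mujoco_py/cymj.pyx (hypothetical helper)."""
    # find_spec on a top-level package locates it without executing its
    # __init__, so this does not import mujoco_py.
    spec = importlib.util.find_spec("mujoco_py")
    if spec is None or not spec.submodule_search_locations:
        raise FileNotFoundError("mujoco_py is not installed in this environment")
    pyx = Path(list(spec.submodule_search_locations)[0]) / "cymj.pyx"
    pyx.write_text(with_directive(pyx.read_text()))
    return pyx
```

Call `patch_cymj()` once after installing `mujoco_py`; running it again leaves the file unchanged.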

Optional: install safety-gym:
```bash
pip install glfw
pip install mujoco-py==2.0.2.13 --no-cache-dir --no-binary :all: --no-build-isolation
pip install -e .  # run from the root of a safety-gym checkout
```

## Installation for ONNX export

```bash
# Separate environment for ONNX export (adds tensorflow-cpu, tf2onnx, onnx)
mamba create -n relax-onnx python=3.11 numpy tqdm tensorboardX matplotlib scikit-learn black snakeviz ipykernel ipdb tensorflow-cpu tf2onnx onnx
conda activate relax-onnx
pip install --upgrade "jax[cpu]"
# Install package
pip install -r requirements.txt
pip install -e .
```

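## Environment variables
- `JAX_PLATFORMS=cpu`: run JAX on the CPU only, even if a GPU is available.
- `XLA_PYTHON_CLIENT_MEM_FRACTION=.1`: preallocate 10% of GPU memory instead of JAX's default 75%, e.g. to share one GPU between several runs.
- `XLA_FLAGS='--xla_gpu_deterministic_ops=true'`: use deterministic GPU ops for reproducibility (see #13672).
- Not an environment variable, but useful for pinning a run to specific CPU cores: prefix the command with `taskset -c 0,2,4,6` or `numactl -C +0,1,2,3`.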
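These settings take effect when JAX initializes, so they should be set before `jax` is imported, most simply in the environment of the training process. The sketch below assembles such a launch from Python: it returns an argument list (optionally prefixed with `taskset`) and an environment dict. `build_launch` and `train.py` are hypothetical names, not part of ReLax, and `taskset` is Linux-only.

```python
import os
import sys
from typing import Dict, List, Optional, Sequence, Tuple

DETERMINISTIC_FLAG = "--xla_gpu_deterministic_ops=true"


def build_launch(
    script: str,
    cores: Optional[Sequence[int]] = None,
    cpu_only: bool = False,
    mem_fraction: Optional[float] = None,
    deterministic: bool = False,
    base_env: Optional[Dict[str, str]] = None,
) -> Tuple[List[str], Dict[str, str]]:
    """Return (argv, env) for running `script` with the settings listed above."""
    env = dict(os.environ if base_env is None else base_env)
    if cpu_only:
        env["JAX_PLATFORMS"] = "cpu"  # ignore any available GPU
    if mem_fraction is not None:
        # Fraction of GPU memory XLA preallocates (JAX default: 0.75).
        env["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(mem_fraction)
    if deterministic:
        # Append rather than overwrite, so XLA flags that are already set survive.
        existing = env.get("XLA_FLAGS", "")
        if DETERMINISTIC_FLAG not in existing.split():
            env["XLA_FLAGS"] = f"{existing} {DETERMINISTIC_FLAG}".strip()
    argv = [sys.executable, script]
    if cores:
        # The launched process and its children inherit this CPU affinity.
        argv = ["taskset", "-c", ",".join(map(str, cores))] + argv
    return argv, env
```

For example, `argv, env = build_launch("train.py", cores=[0, 2, 4, 6], mem_fraction=0.1, deterministic=True)` followed by `subprocess.run(argv, env=env, check=True)`. Building a fresh environment for the child process keeps the parent interpreter's JAX configuration untouched.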

