# Minimal Jax RL Implementation

Based on [dibyaghosh/jaxrl_minimal](https://github.com/dibyaghosh/jaxrl_minimal).

## Environment
```
conda create -n jaxrl python=3.10
conda activate jaxrl
pip install -e .
pip install -r requirements.txt
```
For GPU:
```
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

For TPU
```
pip install --upgrade "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```
See the [Jax Github page](https://github.com/google/jax) for more details on installing Jax.

## Contributing
Experimental things and training/eval scripts should go in `experiments/`. To make any changes to files outside of `experiments/`, please open a pull request.

To enable code checks and auto-formatting, please install pre-commit hooks (run this in the root directory):
```
pre-commit install
```
The hooks should now run before every commit. If files are modified during the checks, you'll need to re-stage them and commit again.
