# GLiBRL
Generalised Linear Models in Bayesian Reinforcement Learning. This repository builds on the template provided by [MetaWorld+](https://github.com/rainx0r/metaworld-algorithms) and also includes a JAX implementation of VariBAD.

## Installation

### From a clone of the repository

1. Install [uv](https://docs.astral.sh/uv/).
2. Create a virtual environment: `uv venv .venv --python 3.12`
3. Activate the virtual environment: `source .venv/bin/activate`
4. Install the dependencies: `uv sync`
5. Install JAX with CUDA 12 support: `uv pip install -e ".[cuda12]"`
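A quick way to confirm the installation is to ask JAX which devices it can see and whether `JAX_ENABLE_X64` takes effect. The one-liner below is a sketch and not part of the repository. It only uses `jax.devices()` and `jax.numpy`, and should be run inside the activated `.venv`.

```shell
# Sanity check after installation (not part of the repository).
# Line 1 of the output: devices visible to JAX (CUDA devices if the cuda12 extra works).
# Line 2 of the output: default float dtype with x64 enabled (expected: float64).
JAX_ENABLE_X64=1 python -c "import jax, jax.numpy as jnp; print(jax.devices()); print(jnp.zeros(1).dtype)"
```

If only CPU devices are listed, the CUDA-enabled JAX build was probably not picked up.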

## Reproducing the results

Set the random seed with the `--seed` argument. GLiBRL runs with float64 enabled (`JAX_ENABLE_X64=1`) for better numerical stability.

- GLiBRL: `JAX_ENABLE_X64=1 python examples/meta_learning/glibrl_ml10.py` 
- VariBAD: `python examples/meta_learning/varibad_ml10.py`
- MAML: `python examples/meta_learning/maml_trpo_ml10.py`
- RL2: `python examples/meta_learning/rl2_ml10.py`

For ML45 experiments, replace `ml10` with `ml45` in the script name. To run GLiBRL without noise inference, set the `full_bayesian` property to `False` in `glibrl_ml10.py` or `glibrl_ml45.py`.
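To repeat the GLiBRL runs over both benchmarks and several seeds, a small wrapper script can help. The sketch below is hypothetical: the function name `run_glibrl_sweep`, the `DRY_RUN` variable and the seed values are illustrative, and it assumes `--seed` takes an integer in the `--seed N` form. By default it only prints the commands it would run.

```shell
#!/usr/bin/env bash
set -euo pipefail

# Hypothetical helper: sweep GLiBRL over ML10/ML45 and a few example seeds.
# DRY_RUN=1 (default) prints the commands; DRY_RUN=0 actually launches them.
run_glibrl_sweep() {
  local dry_run="${DRY_RUN:-1}"
  local bench seed
  for bench in ml10 ml45; do
    for seed in 0 1 2; do  # arbitrary example seeds
      if [ "$dry_run" = "1" ]; then
        echo "JAX_ENABLE_X64=1 python examples/meta_learning/glibrl_${bench}.py --seed ${seed}"
      else
        JAX_ENABLE_X64=1 python "examples/meta_learning/glibrl_${bench}.py" --seed "$seed"
      fi
    done
  done
}

run_glibrl_sweep
```

Saved as, for example, `sweep.sh`, it launches the runs with `DRY_RUN=0 bash sweep.sh`. The same loop works for the baselines by swapping in their script names; the `JAX_ENABLE_X64=1` prefix is only used for GLiBRL above.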
