# Read Me

This repository contains the JAX implementation of the VIPO framework. Compared to the PyTorch version, it runs approximately 10 times faster and includes various bug fixes.

Value Function Alignment (Vd vs Vm):

![Vd vs Vm](vs.png)

The dynamics ensemble is parallelized across its members with Flax's `vmap` transform, which is the primary factor behind the speedup.
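
To illustrate the idea of vectorizing an ensemble, here is a minimal sketch in plain JAX. It is not the code used in this repository: the MLP architecture and the names `init_member`, `member_forward`, `init_ensemble`, and `ensemble_forward` are hypothetical, and `jax.vmap` stands in for the Flax transform.

```python
import jax
import jax.numpy as jnp


def init_member(key, obs_dim, act_dim, hidden=32):
    """Initialise one small MLP mapping (obs, act) to (next-obs delta, reward)."""
    k1, k2 = jax.random.split(key)
    in_dim, out_dim = obs_dim + act_dim, obs_dim + 1
    return {
        "w1": jax.random.normal(k1, (in_dim, hidden)) / jnp.sqrt(in_dim),
        "b1": jnp.zeros(hidden),
        "w2": jax.random.normal(k2, (hidden, out_dim)) / jnp.sqrt(hidden),
        "b2": jnp.zeros(out_dim),
    }


def member_forward(params, obs, act):
    """Forward pass of a single ensemble member."""
    x = jnp.concatenate([obs, act], axis=-1)
    h = jnp.tanh(x @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]


def init_ensemble(key, num_members, obs_dim, act_dim):
    """Stack the parameters of all members along a leading axis."""
    keys = jax.random.split(key, num_members)
    return jax.vmap(lambda k: init_member(k, obs_dim, act_dim))(keys)


# Map over the leading (member) axis of the parameters; obs and act are shared.
ensemble_forward = jax.jit(jax.vmap(member_forward, in_axes=(0, None, None)))

params = init_ensemble(jax.random.PRNGKey(0), num_members=7, obs_dim=11, act_dim=3)
obs, act = jnp.ones((256, 11)), jnp.zeros((256, 3))
preds = ensemble_forward(params, obs, act)  # shape (7, 256, 12)
```

Because all members are evaluated in one batched call instead of a Python loop, the whole ensemble compiles into a single XLA computation.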

All key functions interacting with the dynamics model have been accelerated using `jax.jit`, including:

- offline replay buffers
- common uncertainty penalizations
- dynamics.step()
- learner.rollout()
- ...
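
As an example of what a jitted uncertainty penalty can look like, the sketch below uses ensemble disagreement (the largest distance of any member's prediction from the ensemble mean). This is one common choice and only an assumption here; the function names `disagreement_penalty` and `penalized_reward` and the coefficient `coef` are hypothetical and may not match the penalties implemented in this repository.

```python
import jax
import jax.numpy as jnp


@jax.jit
def disagreement_penalty(ensemble_next_obs):
    """Map predictions of shape (num_members, batch, obs_dim) to a (batch,) penalty."""
    mean = ensemble_next_obs.mean(axis=0, keepdims=True)
    # Distance of every member's prediction from the ensemble mean: (members, batch).
    dist = jnp.linalg.norm(ensemble_next_obs - mean, axis=-1)
    return dist.max(axis=0)


@jax.jit
def penalized_reward(reward, ensemble_next_obs, coef):
    """Subtract the scaled disagreement from the predicted reward."""
    return reward - coef * disagreement_penalty(ensemble_next_obs)


# Two members, one transition, four observation dimensions.
example_preds = jnp.stack([jnp.zeros((1, 4)), 2.0 * jnp.ones((1, 4))])
example_reward = penalized_reward(jnp.array([1.0]), example_preds, 0.5)
```

When all members agree, the penalty is zero and the reward passes through unchanged.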

> Please follow the instructions to integrate your planner logic. A random policy is used as a placeholder by default.
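
For orientation, the sketch below shows one way a random placeholder policy and a jitted rollout could fit together. Everything in it is an assumption made for illustration: the `policy(key, obs)` signature, `random_policy`, `toy_step`, and `rollout` are hypothetical and do not mirror the repository's `dynamics.step()` or `learner.rollout()`.

```python
from functools import partial

import jax
import jax.numpy as jnp


def random_policy(key, obs, act_dim=3):
    """Placeholder policy: uniform actions in [-1, 1], ignoring the observation."""
    shape = obs.shape[:-1] + (act_dim,)
    return jax.random.uniform(key, shape, minval=-1.0, maxval=1.0)


def toy_step(obs, act):
    """Stand-in for a learned dynamics step (hypothetical toy dynamics)."""
    next_obs = 0.9 * obs + 0.1 * act.sum(axis=-1, keepdims=True)
    reward = -jnp.sum(next_obs**2, axis=-1)
    return next_obs, reward


@partial(jax.jit, static_argnames=("policy", "horizon"))
def rollout(key, init_obs, policy, horizon):
    """Roll the policy forward for `horizon` steps with jax.lax.scan."""

    def body(obs, step_key):
        act = policy(step_key, obs)
        next_obs, reward = toy_step(obs, act)
        return next_obs, (obs, act, reward)

    keys = jax.random.split(key, horizon)
    _, (obs_seq, act_seq, rew_seq) = jax.lax.scan(body, init_obs, keys)
    return obs_seq, act_seq, rew_seq


init_obs = jnp.zeros((8, 11))
obs_seq, act_seq, rew_seq = rollout(jax.random.PRNGKey(0), init_obs, random_policy, 5)
```

Under this assumed interface, a planner would be swapped in by passing any function with the same `(key, obs) -> action` signature in place of `random_policy`.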

## Dependencies

MuJoCo:
```sh
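MUJOCO_DIR=~/.mujoco
# -p avoids an error if the directory already exists
mkdir -p "$MUJOCO_DIR"
wget https://github.com/deepmind/mujoco/releases/download/2.1.0/mujoco210-linux-x86_64.tar.gz -O "$MUJOCO_DIR/mujoco210.tar.gz"
tar -xzf "$MUJOCO_DIR/mujoco210.tar.gz" -C "$MUJOCO_DIR"
rm "$MUJOCO_DIR/mujoco210.tar.gz"

# mujoco-py expects the bin directory on LD_LIBRARY_PATH; add this line to ~/.bashrc to make it permanent
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$MUJOCO_DIR/mujoco210/bin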
```

D4RL:
```sh
pip install git+https://github.com/Farama-Foundation/d4rl@master#egg=d4rl
```

jaxrl_m:
```sh
cd ./jaxrl_m
pip install -e .
cd ..
```

Others:
```sh
pip install jax flax distrax orbax wandb
```
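
After installing, a quick check can confirm that the packages are importable. The helper below is a small sketch, not part of this repository; `missing_packages` is a hypothetical name, and the module names are assumed to match the package names used in the commands above.

```python
import importlib.util


def missing_packages(names):
    """Return the module names from `names` that cannot be found for import."""
    return [name for name in names if importlib.util.find_spec(name) is None]


required = ["jax", "flax", "distrax", "orbax", "wandb", "d4rl", "jaxrl_m"]
missing = missing_packages(required)
print("missing packages:", missing if missing else "none")
```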

## Run

```sh
python run_vipo.py
```