# sde4mbrl
A framework for learning to control via uncertainty-aware neural stochastic differential equation (SDE) representations of the dynamics.

## Installation

Installing the package as below also installs several of its dependencies, such as JAX, dm-haiku, optax, diffrax, and mbrl.
```
python -m pip install .
```

To run JAX on a GPU, follow the installation instructions on the [JAX GitHub page](https://github.com/google/jax):
```
pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

If any other packages are missing, install them with pip.
The code has been tested with Python 3.8.5, Ubuntu 20.04.2 LTS, CUDA 11.2, cuDNN 8.1.1, and JAX 0.4.8.
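Before running the examples, it can help to check which of the dependencies named above are installed and compare their versions with the tested setup. The following standard-library sketch does that. It is not part of the package, and the list of distribution names (including jaxlib, which JAX needs but which is not named above) is our assumption.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed PyPI distribution names for the dependencies mentioned above.
PACKAGES = ["jax", "jaxlib", "dm-haiku", "optax", "diffrax", "mbrl"]


def installed_versions(packages=PACKAGES):
    """Return {distribution name: installed version, or None if missing}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    # Print one line per dependency; missing ones are flagged for pip.
    for name, ver in installed_versions().items():
        print(f"{name:10s} {ver or 'MISSING (install it with pip)'}")
```

Running the file as a script prints one line per dependency, which makes a mismatch with JAX 0.4.8 easy to spot.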

## Reproducing the results of the paper

### Section 3.2: Using Neural SDEs to Represent Distance-Aware Estimates of Uncertainty

First, generate the dataset by running the following command:
```
cd sde4mbrlExamples/mass_spring_damper/
XLA_PYTHON_CLIENT_MEM_FRACTION=0.5 python density_script.py --fun gen_traj
```
The datasets are stored in my_data; you should see three datasets of trajectories.
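The XLA_PYTHON_CLIENT_MEM_FRACTION=0.5 prefix used in these commands is a standard JAX setting: JAX preallocates a share of GPU memory when it first initializes the GPU backend (75% by default in recent releases), and this variable lowers that share to 50%, so that, for example, two runs can share one GPU. If you prefer launching the scripts from Python, a sketch like the following reproduces the same prefix. The helpers build_run and run_example are hypothetical and not part of the repository.

```python
import os
import subprocess
import sys


def build_run(script, fun, mem_fraction=0.5, extra_args=(), extra_env=None):
    """Build argv and environment mirroring the shell form
    XLA_PYTHON_CLIENT_MEM_FRACTION=<mem_fraction> python <script> --fun <fun> ...
    """
    env = dict(os.environ)
    # JAX reads this variable when it initializes the GPU backend, so it must
    # already be in the child's environment when the script imports jax.
    env["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(mem_fraction)
    if extra_env:
        env.update(extra_env)
    argv = [sys.executable, script, "--fun", fun, *extra_args]
    return argv, env


def run_example(script, fun, **kwargs):
    """Launch one example script; raises CalledProcessError if it fails."""
    argv, env = build_run(script, fun, **kwargs)
    return subprocess.run(argv, env=env, check=True)
```

For example, calling run_example("density_script.py", "train") from sde4mbrlExamples/mass_spring_damper/ matches the --fun train command in this section, and passing extra_env={"JAX_PLATFORM_NAME": "cpu"} mirrors the CPU-only plotting command.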

Next, train the density term shown in Figure 3:
```
# Edit line 58 of config_density_dataset.py to select the dataset on which the diffusion term is trained.
# For example, to train on 3 trajectories, set trajId: 'DensityExp_3', then run the command below and stop it once training converges.
XLA_PYTHON_CLIENT_MEM_FRACTION=0.5 python density_script.py --fun train
```

Once the diffusion term has been trained on both 'DensityExp_6' and 'DensityExp_25', reproduce the plot from the paper by running:
```
JAX_PLATFORM_NAME=cpu python density_script.py --fun plot --learned_dir my_models/ --data_dir my_data/
```
Lines 329 to 332 of config_density_dataset.py select what is included in the plot.
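As we read this section, the goal is a diffusion term that stays small near the training data and grows with distance from it, so that the SDE's spread reflects how far a state is from what was observed. The stdlib sketch below shows only that intended shape, using an explicit nearest-neighbor distance and hypothetical parameters sigma_max and length_scale. It is a hand-crafted stand-in, not the learned density term trained by density_script.py.

```python
import math


def nearest_distance(x, data):
    """Euclidean distance from point x to the closest point in data."""
    return min(math.dist(x, p) for p in data)


def distance_aware_sigma(x, data, sigma_max=1.0, length_scale=0.5):
    """Diffusion magnitude that is 0 on the data and tends to sigma_max far away.

    Illustration only: the closed form below is an assumption chosen for
    readability, not the model used in the paper.
    """
    d = nearest_distance(x, data)
    return sigma_max * (1.0 - math.exp(-(d ** 2) / (2.0 * length_scale ** 2)))
```

With data = [(0.0, 0.0), (1.0, 0.0)], the sigma is exactly 0.0 at a data point, increases as the query moves away, and approaches sigma_max for distant points.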


### Section 4.1: Spring Mass Damper
Use mass_spring_damper.yaml to configure the model you want to train:
1. For the neural ODE, disable the diffusion term (lines 33 to 53) and set the prior noise (line 21) to zero.
2. For the neural SDE, use mass_spring_damper.yaml unchanged.
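The two configurations above differ in whether the model has an active diffusion term. In generic form, an SDE model evolves as dx = f(x) dt + sigma(x) dW; with sigma = 0 it reduces to the ODE dx = f(x) dt, which is the intuition behind step 1. The stdlib sketch below illustrates this with a known mass-spring-damper drift instead of a learned network, a constant diffusion sigma, and an Euler-Maruyama integrator. The parameter values (k, c, m, dt) are hypothetical, and this is not how the repository builds or integrates its models.

```python
import math
import random


def msd_drift(state, k=1.0, c=0.2, m=1.0):
    """Mass-spring-damper drift: d(pos)/dt = vel, d(vel)/dt = -(k*pos + c*vel)/m."""
    pos, vel = state
    return (vel, -(k * pos + c * vel) / m)


def euler_maruyama(state, sigma, dt=0.01, n_steps=500, seed=0):
    """Simulate dX = f(X) dt + sigma dW with the Euler-Maruyama scheme.

    sigma is a constant diffusion applied to both coordinates; sigma = 0
    turns each update into an explicit Euler step of the ODE dX = f(X) dt.
    """
    rng = random.Random(seed)
    traj = [tuple(state)]
    for _ in range(n_steps):
        drift = msd_drift(state)
        # Brownian increment dW ~ N(0, dt), i.e. sqrt(dt) * standard normal.
        state = tuple(
            s + f * dt + sigma * math.sqrt(dt) * rng.gauss(0.0, 1.0)
            for s, f in zip(state, drift)
        )
        traj.append(state)
    return traj
```

Calling euler_maruyama((1.0, 0.0), sigma=0.0) with different seeds returns identical trajectories, while any sigma > 0 gives seed-dependent ones.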

First, generate the dataset used in our experiments:
```
python mass_spring_damper.py --fun gen_traj
```

Then, with mass_spring_damper.yaml configured for either the neural ODE or the neural SDE, train the model:
```
XLA_PYTHON_CLIENT_MEM_FRACTION=0.5 python mass_spring_model.py --fun train --model_dir mass_spring_damper.yaml --model_type nesde_bboxes --data MSD_MeasLow_TopRight_500
```
where the last argument (--data) is the name of the dataset to train on. The generated datasets are stored in my_data.

For the Gaussian ensemble, see train_gaussian_mlp_en.py.

To replicate the plots in the paper, use the configuration in config_pred_analysis.py.

### Section 4.2: Cartpole
Use cartpole.yaml to configure the model you want to train.

### Section 4.3: Hexacopter
See rotor_uav/hexa_ahg for the real-world experiments.
