Metadata-Version: 2.1
Name: sde4mbrl
Version: 1.0.0
Summary: A framework for learning to control via uncertainty-aware neural stochastic differential equation representations
Home-page: UNKNOWN
License: GNU 3.0
Platform: UNKNOWN
Requires-Python: >=3.7
License-File: LICENSE

# sde4mbrl
A framework for learning to control via uncertainty-aware neural stochastic differential equation (SDE) representations of the dynamics.

## Installation

Installing the package as shown below also installs several of its dependencies, such as JAX, dm-haiku, Optax, Diffrax, and mbrl.
```
python -m pip install .
```
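After installation, a quick sanity check is to confirm that the main dependencies can be found by Python. The sketch below is a hypothetical helper (`missing_modules` is not part of sde4mbrl). It assumes that dm-haiku is imported as `haiku` and that mbrl is imported as `mbrl`. It only locates modules with `importlib.util.find_spec` and does not import them.

```python
import importlib.util

# Import names for some of the dependencies listed above.
# Assumption: dm-haiku is imported as `haiku`, mbrl as `mbrl`.
DEPENDENCIES = ["jax", "haiku", "optax", "diffrax", "mbrl"]


def missing_modules(names):
    """Return the module names that cannot be located in the current environment."""
    # find_spec returns None for a top-level module that is not installed.
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(DEPENDENCIES)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All listed dependencies are importable.")
```

Any module this reports as missing can then be installed with pip.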

If JAX needs to run on a GPU, follow the installation instructions in the [JAX repository](https://github.com/google/jax):
```
pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
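After installing the CUDA-enabled wheel, you can check which backend JAX actually uses. This is a minimal sketch: `jax_backend_report` is a hypothetical helper, and it relies only on `jax.default_backend()` and `jax.devices()`. If JAX cannot find a usable GPU, it falls back to the CPU, and the reported backend is `"cpu"`.

```python
import importlib.util


def jax_backend_report():
    """Describe the backend and devices JAX will use, or report that JAX is missing."""
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed"
    import jax

    # Name of the default backend; "cpu" means no GPU was picked up.
    backend = jax.default_backend()
    devices = [str(d) for d in jax.devices()]
    return f"backend={backend}, devices={devices}"


if __name__ == "__main__":
    print(jax_backend_report())
```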

If any other packages are missing, install them with pip.
This code has been tested with Python 3.8.5, Ubuntu 20.04.2 LTS, CUDA 11.2, cuDNN 8.1.1, and JAX 0.4.8.
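To compare your environment with the tested setup, the sketch below reports the running Python version and the installed JAX version next to the versions listed above. `compare_with_tested` is a hypothetical helper. It uses `importlib.metadata`, which requires Python 3.8 or newer, even though the package metadata declares Python >=3.7. It does not check CUDA or cuDNN versions.

```python
import sys
from importlib.metadata import PackageNotFoundError, version  # Python 3.8+

# Versions reported as tested in this README.
TESTED_PYTHON = (3, 8, 5)
TESTED_JAX = "0.4.8"


def installed_jax_version():
    """Return the installed jax version string, or None if jax is not installed."""
    try:
        return version("jax")
    except PackageNotFoundError:
        return None


def compare_with_tested():
    """Map each component to (installed, tested, matches) for Python and JAX."""
    python_now = tuple(sys.version_info[:3])
    jax_now = installed_jax_version()
    return {
        "python": (python_now, TESTED_PYTHON, python_now == TESTED_PYTHON),
        "jax": (jax_now, TESTED_JAX, jax_now == TESTED_JAX),
    }


if __name__ == "__main__":
    for name, (installed, tested, ok) in compare_with_tested().items():
        print(f"{name}: installed={installed}, tested={tested}, match={ok}")
```

A mismatch does not necessarily mean the code will fail. It only shows that your setup differs from the one the authors tested.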

## Reproducing the results of the paper

### Section 3.2: Using Neural SDEs to Represent Distance-Aware Estimates of Uncertainty

First, generate the dataset by running the following command:
```


