# RiskZero

# Docker
If on Windows, make sure WSL2 is installed and clone this repository into a folder inside your WSL (Ubuntu) filesystem. Then follow these steps to set up the container:
1. Install Docker by following the [instructions here](https://docs.docker.com/engine/install/ubuntu/). On Linux, also complete the post-installation steps linked at the bottom of that page.
2. Install the NVIDIA Container Toolkit by following these [instructions](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html), and verify that you can run the sample workload linked at the bottom of that page.
3. If using Windows, enter WSL.
4. In the `docker` directory, run `chmod +x build.sh && ./build.sh` to build the image.
5. Install the VS Code Dev Containers extension, open the command palette (`Ctrl+Shift+P`), and run `Dev Containers: Rebuild and Reopen in Container`.
6. VS Code then reopens the project inside the container, which has the JAX environment set up.
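Once inside the container, it can help to confirm that JAX actually picked up the GPU instead of silently falling back to the CPU. The check below is a minimal sketch, not part of this repository; it only relies on `jax.default_backend()`, which reports the backend JAX will use by default.

```python
import importlib.util


def jax_backend() -> str:
    """Return JAX's default backend (e.g. "cpu" or "gpu"), or a note if JAX is missing."""
    if importlib.util.find_spec("jax") is None:
        return "jax not installed"
    import jax  # imported lazily so this check also runs where JAX is absent

    return jax.default_backend()


if __name__ == "__main__":
    backend = jax_backend()
    print("JAX backend:", backend)
    if backend != "gpu":
        # A CPU result inside the container usually points at the GPU passthrough setup.
        print("Warning: JAX is not using the GPU; re-check the NVIDIA Container Toolkit steps.")
```

The same check as a one-liner: `python3 -c "import jax; print(jax.default_backend(), jax.devices())"`.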

# Mini Grid
To run the Mini Grid experiments, use `python3 -m src.experiment.grid`.

# Space Invaders
To run the Space Invaders experiments, use `python3 -m src.experiment.space_invader`.

# Stochastic Bipartite Matching
To run the stochastic bipartite matching experiments:
1. Generate a dataset: `python3 -m datasets.stochastic_bm.create_stochastic_er_dataset`
2. Run the experiments against it: `python3 -m src.experiment.sbm`

# Stochastic Maximum Independent Set
To run the stochastic maximum independent set experiments:
1. Generate a dataset: `python3 -m datasets.stochastic_mis.create_stochastic_er_dataset`
2. Run the experiments against it: `python3 -m src.experiment.smis`
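The experiment commands above follow one pattern: the graph problems generate a dataset and then run on it, while Mini Grid and Space Invaders run a single module. If you want to chain these steps, a small driver like the one below could work. It is a sketch under stated assumptions: it is launched from the repository root (so the `-m` module paths resolve), and `PIPELINES`, `commands`, `run`, and the file name `run_experiments.py` are hypothetical, not part of the repository.

```python
import subprocess
import sys

# Module paths taken from this README; the PIPELINES mapping and the helpers
# below are hypothetical conveniences, not part of the repository.
PIPELINES = {
    "grid": ["src.experiment.grid"],
    "space_invaders": ["src.experiment.space_invader"],
    "sbm": [
        "datasets.stochastic_bm.create_stochastic_er_dataset",  # 1. generate dataset
        "src.experiment.sbm",  # 2. run experiments on it
    ],
    "smis": [
        "datasets.stochastic_mis.create_stochastic_er_dataset",  # 1. generate dataset
        "src.experiment.smis",  # 2. run experiments on it
    ],
}


def commands(name: str) -> list[list[str]]:
    """Return the `python -m ...` commands for one experiment, in order."""
    # sys.executable reuses the current interpreter instead of assuming `python3`.
    return [[sys.executable, "-m", module] for module in PIPELINES[name]]


def run(name: str, dry_run: bool = False) -> None:
    """Print each step and, unless dry_run, execute it; a failing step raises."""
    for cmd in commands(name):
        print(" ".join(cmd))
        if not dry_run:
            # check=True stops the pipeline if dataset generation fails.
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    # Hypothetical usage from the repository root: python3 run_experiments.py sbm smis
    for experiment in sys.argv[1:]:
        run(experiment)
```

Running the dataset step with `check=True` means the experiment step never starts against a missing or partial dataset.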

Note: the graph experiments require patching the `pgx` library so that it can initialize batches of problem instances.
Open the `core.py` module that `from pgx import core` resolves to (inside your installed `pgx` package) and add the following method to the `Env` class:
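```python
# Paste inside `class Env` in pgx's core.py, indented as a method.
def init_v2(
    self,
    key: PRNGKey,
    iteration: jnp.ndarray,
    offset: jnp.ndarray,
    num_envs: int,
    split: int,
) -> State:
    # Build the batched initial state via the environment's `_init`,
    # which must accept these extra arguments.
    state = self._init(key, iteration, offset, num_envs, split)
    # Attach the observation, as `Env.init` does for a single instance.
    observation = self.observe(state)
    return state.replace(observation=observation)
```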
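Edits made directly to the installed `core.py` may be lost whenever `pgx` is reinstalled or the container image is rebuilt. A possible alternative, sketched below, is to attach the same method at runtime (monkeypatching) instead of editing the file. This is not the method used by this repository; it assumes the patch runs in the same Python process before any environment calls `init_v2`. `patch_env` is a hypothetical helper, and the type hints are loosened to `Any` so the sketch does not need `pgx`'s type names.

```python
from typing import Any


def init_v2(self, key: Any, iteration: Any, offset: Any, num_envs: int, split: int) -> Any:
    # Same body as the manual patch: build the batched state via the
    # environment's `_init`, then attach its observation.
    state = self._init(key, iteration, offset, num_envs, split)
    observation = self.observe(state)
    return state.replace(observation=observation)


def patch_env(env_cls: type) -> type:
    """Attach `init_v2` to `env_cls` unless it already defines one."""
    if not hasattr(env_cls, "init_v2"):
        # Methods are looked up on the class at call time, so existing
        # subclasses and instances of `env_cls` also see the new method.
        env_cls.init_v2 = init_v2
    return env_cls
```

Usage: call `from pgx import core` followed by `patch_env(core.Env)` once at program start, before creating or initializing environments.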