### Step 1: Check CUDA and cuDNN Versions (GPU Only)
JAX depends heavily on CUDA and cuDNN, so their installed versions must be compatible with JAX for the code to run. Before installing JAX, check the CUDA and cuDNN versions on your machine using one of the following methods:

1. Checking the CUDA version:
   - Run `nvcc --version` in a terminal to print the installed CUDA toolkit version.

2. Checking the cuDNN version:
   - Read the version macros from the cuDNN header (the path may differ depending on how cuDNN was installed): `cat /usr/local/cuda/include/cudnn_version.h | grep CUDNN_MAJOR -A 2`.
   - Alternatively, if PyTorch is installed, query it from Python: `python3 -c "import torch; print(f'cuDNN version: {torch.backends.cudnn.version()}'); print(f'CUDA version: {torch.version.cuda}')"`. Note that this reports the cuDNN and CUDA versions PyTorch was built with or bundles, which may differ from the system-wide installation.
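
The checks above can be bundled into a small script. The following is a minimal sketch, assuming a POSIX shell and a header at `/usr/local/cuda/include/cudnn_version.h` unless another path is passed; `cuda_toolkit_version` and `cudnn_version` are hypothetical helper names, not part of this repository.

```sh
# Hypothetical helpers (not part of FOSP) that print the CUDA toolkit
# and cuDNN versions before you pick a matching JAX build.

# Print the nvcc release line, e.g. "Cuda compilation tools, release 12.2, V12.2.140".
cuda_toolkit_version() {
  if ! command -v nvcc >/dev/null 2>&1; then
    echo "nvcc not found on PATH" >&2
    return 1
  fi
  nvcc --version | grep 'release'
}

# Print MAJOR.MINOR.PATCHLEVEL parsed from cudnn_version.h.
# Assumption: the header is at /usr/local/cuda/include unless a path is given as $1.
cudnn_version() {
  header="${1:-/usr/local/cuda/include/cudnn_version.h}"
  if [ ! -f "$header" ]; then
    echo "cudnn_version.h not found at $header" >&2
    return 1
  fi
  # Match only the exact "#define CUDNN_<FIELD> <value>" lines.
  awk '$1 == "#define" && $2 == "CUDNN_MAJOR"      { major = $3 }
       $1 == "#define" && $2 == "CUDNN_MINOR"      { minor = $3 }
       $1 == "#define" && $2 == "CUDNN_PATCHLEVEL" { patch = $3 }
       END { print major "." minor "." patch }' "$header"
}
```

For example, source the script and run `cuda_toolkit_version` and `cudnn_version`, or `cudnn_version /path/to/cudnn_version.h` if cuDNN is installed elsewhere.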

Make sure the installed CUDA and cuDNN versions are compatible with the specific JAX version you intend to install.

### Step 2: Install JAX
Follow the installation instructions in the [JAX](https://github.com/google/jax) documentation, and pick the build that matches the CUDA and cuDNN versions you checked in Step 1. We tested our code with JAX 0.4.26.
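
After installation, you can check that the installed release matches the tested one and that JAX sees your GPU. This is a minimal sketch with hypothetical helper names (`check_jax_version`, `installed_jax_version`, `jax_devices`); `jax.__version__` and `jax.devices()` are standard JAX attributes, and if only CPU devices are listed, the CUDA-enabled build was likely not picked up.

```sh
# Hypothetical helpers (not part of FOSP): compare a JAX version string
# against the release the authors tested (0.4.26).
TESTED_JAX_VERSION="0.4.26"

check_jax_version() {
  installed="$1"
  if [ "$installed" = "$TESTED_JAX_VERSION" ]; then
    echo "jax $installed matches the tested version"
  else
    echo "jax ${installed:-<not installed>} differs from the tested $TESTED_JAX_VERSION" >&2
    return 1
  fi
}

# Print the installed JAX version (empty if JAX cannot be imported).
installed_jax_version() {
  python3 -c 'import jax; print(jax.__version__)' 2>/dev/null
}

# Print the devices JAX can see; GPU entries should appear for a CUDA build.
jax_devices() {
  python3 -c 'import jax; print(jax.devices())'
}
```

Usage: `check_jax_version "$(installed_jax_version)" && jax_devices`.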

### Step 3: Install Other Dependencies
```sh
pip install -r requirements.txt
```

### Step 4: Install Safety-Gymnasium
```sh
git clone https://github.com/PKU-Alignment/safety-gymnasium.git
cd safety-gymnasium
pip install -e .
cd ..
```
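
To confirm that Steps 3 and 4 succeeded, you can try importing the key packages with the same `python3` you will train with. This is a minimal sketch; `check_modules` is a hypothetical helper, and `safety_gymnasium` is the import name of the Safety-Gymnasium package.

```sh
# Hypothetical helper (not part of FOSP): report which Python modules
# cannot be imported by the current python3 interpreter.
check_modules() {
  missing=0
  for module in "$@"; do
    if python3 -c "import $module" >/dev/null 2>&1; then
      echo "ok: $module"
    else
      echo "missing: $module" >&2
      missing=$((missing + 1))
    fi
  done
  # Non-zero exit status if any module failed to import.
  [ "$missing" -eq 0 ]
}
```

Usage: `check_modules jax safety_gymnasium`. As a further smoke test of the environment used in the training commands, `python3 -c 'import safety_gymnasium; safety_gymnasium.make("SafetyPointGoal1-v0")'` should complete without errors.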

### Step 5: Prepare Data and Run Training
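Place the dataset in `.../FOSP/data/`. It will be available for download from Google Drive after the camera-ready release.

The online fine-tuning commands start from a pretrained checkpoint passed via `--run.from_checkpoint`; replace `/xxx/checkpoint.ckpt` with the path to your checkpoint.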
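
Before training, you can check that the dataset is where the code expects it. This is a minimal sketch, assuming commands are run from the directory that contains `FOSP/` (as in the training commands), so the data lives at `FOSP/data/`; `check_data_dir` is a hypothetical helper.

```sh
# Hypothetical helper (not part of FOSP): succeed only if the dataset
# directory exists and contains at least one entry.
# Assumption: run from the directory that contains FOSP/.
check_data_dir() {
  dir="${1:-FOSP/data}"
  if [ ! -d "$dir" ]; then
    echo "dataset directory not found: $dir" >&2
    return 1
  fi
  # ls -A lists entries except . and ..; empty output means an empty directory.
  if [ -z "$(ls -A "$dir")" ]; then
    echo "dataset directory is empty: $dir" >&2
    return 1
  fi
  echo "found dataset files in $dir"
}
```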

```sh
# FOSP-online fine-tuning:
python FOSP/train.py --configs fosp --method fosp --run.script train_eval_online --run.from_checkpoint /xxx/checkpoint.ckpt  --task safetygym_SafetyPointGoal1-v0 --jax.logical_gpus 0 --run.steps 10000

# SafeDreamer-online fine-tuning:
python FOSP/train.py --configs safedreamer --method safedreamer --run.script train_eval_online_direct --run.from_checkpoint /xxx/checkpoint.ckpt --task safetygym_SafetyPointGoal1-v0 --jax.logical_gpus 0 --run.steps 10000
```
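
The two online fine-tuning commands differ only in the config/method name and the run script. The following is a minimal sketch of a hypothetical wrapper, `finetune_online`, that assembles either command and prints it unless `DRY_RUN=0` is set; the checkpoint path is passed as an argument and must point to your own checkpoint.

```sh
# Hypothetical wrapper (not part of FOSP) around the two online
# fine-tuning commands; only the method and run script differ.
# Set DRY_RUN=0 to execute instead of printing.
finetune_online() {
  method="$1"       # fosp or safedreamer
  checkpoint="$2"   # path to your pretrained checkpoint
  task="${3:-safetygym_SafetyPointGoal1-v0}"
  case "$method" in
    fosp)        script=train_eval_online ;;
    safedreamer) script=train_eval_online_direct ;;
    *) echo "unknown method: $method" >&2; return 1 ;;
  esac
  # Rebuild the function's positional parameters as the full command line.
  set -- python FOSP/train.py --configs "$method" --method "$method" \
    --run.script "$script" --run.from_checkpoint "$checkpoint" \
    --task "$task" --jax.logical_gpus 0 --run.steps 10000
  if [ "${DRY_RUN:-1}" = "1" ]; then
    echo "$@"
  else
    "$@"
  fi
}
```

Printing by default lets you inspect the exact command before launching a long run, e.g. `finetune_online fosp /path/to/checkpoint.ckpt`, then `DRY_RUN=0 finetune_online fosp /path/to/checkpoint.ckpt`.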

```sh
# FOSP-offline:
python FOSP/train.py --configs fosp --method fosp --task safetygym_SafetyPointGoal1-v0 --jax.logical_gpus 0

# SafeDreamer-offline:
python FOSP/train.py --configs safedreamer --method safedreamer --task safetygym_SafetyPointGoal1-v0 --jax.logical_gpus 0

```
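
To train both offline variants back to back while keeping each run's full output (useful when scrolling back to the root cause of a CUDA error), a loop like the following may help. This is a minimal sketch; `run_offline_all` and the `logs/<method>_offline.log` layout are hypothetical, and the loop only prints the commands unless `DRY_RUN=0` is set.

```sh
# Hypothetical loop (not part of FOSP): run both offline variants one after
# another, saving each run's output to logs/<method>_offline.log.
# Set DRY_RUN=0 to execute instead of printing.
run_offline_all() {
  task="${1:-safetygym_SafetyPointGoal1-v0}"
  for method in fosp safedreamer; do
    log="logs/${method}_offline.log"
    if [ "${DRY_RUN:-1}" = "1" ]; then
      echo "python FOSP/train.py --configs $method --method $method --task $task --jax.logical_gpus 0 > $log 2>&1"
    else
      mkdir -p logs
      # Keep going if one method fails, but report where its log is.
      python FOSP/train.py --configs "$method" --method "$method" \
        --task "$task" --jax.logical_gpus 0 > "$log" 2>&1 ||
        echo "$method failed; see $log" >&2
    fi
  done
}
```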

## Tips

- All configuration options are documented in `configs.yaml` and can be overridden on the command line, e.g. `--run.steps 10000`.
- If you encounter CUDA errors, scroll up through the error messages: the root cause is often an earlier issue, such as running out of GPU memory or incompatible versions of JAX and CUDA.
- To change how much GPU memory JAX preallocates, edit the `os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION']` setting in `jaxagent.py`.
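
The last two tips can be turned into small shell helpers. This is a minimal sketch with hypothetical helper names: `run_with_mem_fraction` sets `XLA_PYTHON_CLIENT_MEM_FRACTION` (a standard JAX/XLA variable; by default JAX preallocates 75% of GPU memory) for a single command, which only takes effect if `jaxagent.py` does not overwrite the variable itself, so editing that file remains the documented route. `first_error` prints the first log line that looks like an error, together with its line number.

```sh
# Hypothetical helpers (not part of FOSP).

# Run a command with a given GPU memory fraction for JAX/XLA.
# Note: has no effect if jaxagent.py overwrites this variable itself.
run_with_mem_fraction() {
  fraction="$1"
  shift
  XLA_PYTHON_CLIENT_MEM_FRACTION="$fraction" "$@"
}

# Print the first line of a log that looks like an error (for example an
# XLA out-of-memory report), prefixed with its line number.
first_error() {
  grep -n -i -E 'RESOURCE_EXHAUSTED|out of memory|error' "$1" | head -n 1
}
```

For example, `run_with_mem_fraction 0.5 python FOSP/train.py --configs fosp --method fosp --task safetygym_SafetyPointGoal1-v0 --jax.logical_gpus 0 2>&1 | tee train.log`, then `first_error train.log`. The pattern is deliberately broad, so the first match may be a harmless warning; read the surrounding lines before concluding.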


## 📄 License
FOSP is released under the Apache License 2.0.



## 👏 Acknowledgements
- [SafeDreamer](https://github.com/PKU-Alignment/SafeDreamer): Our codebase is built upon SafeDreamer.
