# Reinforced sequential Monte Carlo for amortised sampling

This repository contains the code for the paper "Reinforced sequential Monte Carlo for amortised sampling".

## Installation
- python 3.10
- jax 0.6.2

We recommend installing the dependencies in a conda (or mamba) environment.
```bash
conda create -n rsmc python=3.10
conda activate rsmc
```

Install TensorFlow first, since it sometimes conflicts with other packages.
```bash
pip install tensorflow==2.16.1
```

Install jax and jaxlib with the appropriate CUDA version or TPU support, e.g., for CUDA 12:
```bash
pip install -U "jax[cuda12]==0.6.2"
```

Install the other dependencies.
```bash
pip install -r requirements.txt
```


## Usage

This README focuses mainly on the GFlowNet-based algorithms.

Basic usage:
```bash
python run.py algorithm=<algorithm_name> target=<target_name>
```

`<algorithm_name>` can be one of the following:
- `gfn_tb`: TB or LV loss with an importance-weighted buffer (IW-Buf; Section 3.3)
- `gfn_subtb_smc`: combined TB/SubTB loss with IW-Buf and sequential Monte Carlo (SMC; Section 3.2)
- `dds`: DDS baseline
- `pis`: PIS baseline
- `smc_mh`: SMC-RWM baseline
- `smc`: SMC-HMC baseline

For CMCD and SCLD baselines, please refer to the [repository of SCLD](https://github.com/anonymous3141/SCLD).

`<target_name>` can be one of the following:
- Gradient-free setting
  - `gaussian_mixture40`
  - `gaussian_mixture40_5d`
  - `funnel`
  - `many_well`
- Gradient-based setting
  - `funnel_lp`
  - `planar_robot_4goals`
  - `gaussian_mixture40_50d`
  - `student_t_mixture_50d`
  - `many_well_64d`
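
As an illustration of how the two options combine, the sketch below prints one command per (algorithm, target) pair for a small subset of the names listed above. The helper `print_rsmc_commands` is hypothetical and not part of this repository. It only echoes the commands, so nothing is launched until you run a printed line yourself. It also assumes that each listed algorithm can be paired with each listed target, which this README does not state explicitly.

```bash
# Hypothetical helper (not part of the repository): prints one command per
# (algorithm, target) pair instead of running it.
print_rsmc_commands() {
  for algorithm in gfn_tb gfn_subtb_smc; do
    for target in gaussian_mixture40 funnel; do
      # Same form as the basic usage line: algorithm=<...> target=<...>
      echo "python run.py algorithm=${algorithm} target=${target}"
    done
  done
}

print_rsmc_commands
```

To launch a single configuration, copy one of the printed lines into the shell, for example the line for `gfn_tb` on `funnel`.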

Please refer to our paper for more details on the algorithms and targets.

Full run scripts will be uploaded upon acceptance of the paper.
