# RetroBridge

## Setup

```shell
# Create an environment and load some packages with conda
conda create --name retrobridge python=3.9 libgcc-ng=9.3.0 libstdcxx-ng=9.3.0
conda activate retrobridge
conda install -c conda-forge rdkit
conda install -c conda-forge graph-tool

# Install all the necessary packages
bash setup.sh
```

## Training

* RetroBridge: `python train_retrobridge.py --config configs/retrobridge.yml`
* DiGress: `python train_retrodiff.py --config configs/digress.yml`
* ForwardBridge: `python train_forwardbridge_MIT.py --config configs/forwardbridge.yml`

## Sampling

* Samples used for Tables 1 and 2 are provided in the `samples` directory.
* Trained models can be downloaded from [Zenodo](https://zenodo.org/record/8370261):
```shell
mkdir -p models
wget "https://zenodo.org/record/8370261/files/retrobridge.ckpt?download=1" -O models/retrobridge.ckpt
wget "https://zenodo.org/record/8370261/files/digress.ckpt?download=1" -O models/digress.ckpt
wget "https://zenodo.org/record/8370261/files/forwardbridge.ckpt?download=1" -O models/forwardbridge.ckpt
```

Sampling with the RetroBridge model:
```shell
python sample.py \
       --config configs/retrobridge.yml \
       --checkpoint models/retrobridge.ckpt \
       --samples samples \
       --model RetroBridge \
       --mode test \
       --n_samples 10 \
       --n_steps 500 \
       --sampling_seed 1
```

Sampling with DiGress:
```shell
python sample.py \
       --config configs/digress.yml \
       --checkpoint models/digress.ckpt \
       --samples samples \
       --model RetroDiff \
       --mode test \
       --n_samples 10 \
       --n_steps 500 \
       --sampling_seed 1
```

Sampling with ForwardBridge:
```shell
python sample_MIT.py \
       --config configs/forwardbridge.yml \
       --checkpoint models/forwardbridge.ckpt \
       --samples samples \
       --model RetroBridge \
       --mode test \
       --n_samples 10 \
       --n_steps 500 \
       --sampling_seed 1
```
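
The three sampling commands above differ only in the script, config, checkpoint, and `--model` value. The following is a minimal sketch of a helper that assembles those commands for several seeds. The names `RUNS` and `build_sample_command` are illustrative and not part of the repository; only flags that appear in the commands above are used. How runs with different seeds are written to the output directory is not described here, so check the contents of `samples` before launching several runs.

```python
import shlex

# Script, config, checkpoint, and --model value per method, copied from the commands above.
RUNS = {
    "retrobridge": ("sample.py", "configs/retrobridge.yml", "models/retrobridge.ckpt", "RetroBridge"),
    "digress": ("sample.py", "configs/digress.yml", "models/digress.ckpt", "RetroDiff"),
    "forwardbridge": ("sample_MIT.py", "configs/forwardbridge.yml", "models/forwardbridge.ckpt", "RetroBridge"),
}


def build_sample_command(name, seed, n_samples=10, n_steps=500, out_dir="samples"):
    """Return the argument list for one sampling run (hypothetical helper)."""
    script, config, checkpoint, model = RUNS[name]
    return [
        "python", script,
        "--config", config,
        "--checkpoint", checkpoint,
        "--samples", out_dir,
        "--model", model,
        "--mode", "test",
        "--n_samples", str(n_samples),
        "--n_steps", str(n_steps),
        "--sampling_seed", str(seed),
    ]


# Print one command per seed; pass the list to subprocess.run(cmd, check=True) to execute it.
for seed in (1, 2, 3):
    cmd = build_sample_command("retrobridge", seed)
    print(shlex.join(cmd))
```

Building the command as a list rather than a string avoids shell quoting issues if a path contains spaces.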

## Evaluation

### Run Molecular Transformer for round-trip evaluation

Download Molecular Transformer and follow the instructions on its [GitHub page](https://github.com/pschwllr/MolecularTransformer).

To make forward predictions for all generated reactants, run:
```bash
python src/metrics/round_trip.py --csv_file <path/to/retrobridge_csv> --csv_out <path/to/output_csv> --mol_trans_dir <path/to/MolecularTransformer_dir>
```
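
To illustrate what the forward predictions are used for, here is a toy sketch of a round-trip check: each generated reactant set is passed through the forward model, and the predicted product is compared with the original product. This is not the logic of `round_trip.py`; the function `round_trip_fraction` is hypothetical, the SMILES strings are placeholders, and the sketch assumes both columns are already canonicalized so that string equality is meaningful.

```python
from collections import defaultdict

# Toy rows: `product` is the target molecule, `pred_product` is the forward model's
# prediction for one generated reactant set. The values are placeholders.
rows = [
    {"product": "CCO", "pred_product": "CCO"},
    {"product": "CCO", "pred_product": "CCN"},
    {"product": "c1ccccc1", "pred_product": "c1ccccc1"},
    {"product": "c1ccccc1", "pred_product": "c1ccccc1"},
]


def round_trip_fraction(rows):
    """Per product, the fraction of samples whose forward prediction equals the product."""
    hits = defaultdict(list)
    for row in rows:
        hits[row["product"]].append(row["pred_product"] == row["product"])
    return {product: sum(flags) / len(flags) for product, flags in hits.items()}


print(round_trip_fraction(rows))  # {'CCO': 0.5, 'c1ccccc1': 1.0}
```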

### Metrics

To compute the metrics reported in the paper, run the following commands in Python:
```python
import numpy as np
import pandas as pd

from pathlib import Path
from src.metrics.eval_csv_helpers import canonicalize, compute_confidence, assign_groups, compute_accuracy

csv_file = Path('<path/to/output_csv>')
df = pd.read_csv(csv_file)
df = assign_groups(df, samples_per_product_per_file=10)
df.loc[(df['product'] == 'C') & (df['true'] == 'C'), 'true'] = 'Placeholder'

df_processed = compute_confidence(df)

for key in ['product', 'pred_product']:
    df_processed[key] = df_processed[key].apply(canonicalize)

compute_accuracy(df_processed, top=[1, 3, 5, 10], scoring=lambda df: np.log(df['confidence']))
```
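
For intuition about the `top` and `scoring` arguments, the sketch below computes top-k accuracy on toy data: predictions for each product are ranked by a score, and a product counts as a hit if its true reactants appear among the k best-ranked predictions. This is an assumption about the general technique, not the repository's `compute_accuracy` implementation. The function `top_k_accuracy` and the `pred` field are hypothetical; `product`, `true`, and `confidence` mirror the column names used above.

```python
import math

# Toy records: one row per distinct prediction. `pred` is an assumed field name.
records = [
    {"product": "P1", "pred": "A.B", "true": "A.B", "confidence": 0.6},
    {"product": "P1", "pred": "A.C", "true": "A.B", "confidence": 0.4},
    {"product": "P2", "pred": "D.E", "true": "D.F", "confidence": 0.7},
    {"product": "P2", "pred": "D.F", "true": "D.F", "confidence": 0.3},
]


def top_k_accuracy(records, ks, scoring=lambda r: math.log(r["confidence"])):
    """Fraction of products whose true reactants are among the k best-scored predictions."""
    # Group the predictions by product.
    by_product = {}
    for r in records:
        by_product.setdefault(r["product"], []).append(r)

    result = {}
    for k in ks:
        hits = 0
        for rows in by_product.values():
            # Highest score first, then keep the k best predictions.
            ranked = sorted(rows, key=scoring, reverse=True)[:k]
            hits += any(r["pred"] == r["true"] for r in ranked)
        result[k] = hits / len(by_product)
    return result


print(top_k_accuracy(records, ks=[1, 2]))  # {1: 0.5, 2: 1.0}
```

Because the logarithm is monotonic, ranking by `log(confidence)` gives the same order as ranking by `confidence` itself.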
