<p align="center">
  <img src="syncogen.gif" width="400">
</p>

## SynCoGen
**Syn**thesizable **CoGen**eration (**SynCoGen**) is a generative small-molecule design framework that runs an absorbing-state discrete diffusion model and a continuous flow matching model simultaneously. The discrete diffusion model follows [Simple and Effective Masked Diffusion Language Models](https://arxiv.org/pdf/2406.07524), adapted for synthesizable molecular generation; it builds on [this repo](https://github.com/kuleshov-group/mdlm) and uses the SUBS parametrization. The continuous flow matching model is a straightforward implementation of conditional flow matching over molecular coordinates. We provide a simple graph transformer as a backbone.

<p align="center">
  <img src="molmdgm.png" width="400">
</p>

Rather than operate on sequences, SynCoGen's discrete component operates on graphs. Masking is performed on both node (B x N x D) and edge (B x N x N x D) matrices. ```graph_utils.py``` contains util scripts for handling RDKit ```Molecules``` and ```Reactions```.
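As a rough illustration of absorbing-state masking on a graph (a minimal sketch, not the repository's implementation), the example below uses integer class indices rather than one-hot vectors and a single graph rather than a batch. The function name `mask_graph`, the mask indices, the use of `0` for "no edge", and the choice to mask positions (i, j) and (j, i) together are all assumptions made for illustration.

```python
import random


def mask_graph(nodes, edges, t, node_mask_id, edge_mask_id, rng):
    """Replace each node and each node pair with a mask token with probability t.

    nodes: list of N node class indices
    edges: N x N nested list of edge class indices
    Returns new lists; the inputs are not modified.
    """
    n = len(nodes)
    noisy_nodes = [node_mask_id if rng.random() < t else x for x in nodes]
    noisy_edges = [row[:] for row in edges]
    for i in range(n):
        for j in range(i + 1, n):
            if rng.random() < t:
                # Mask both orientations of the pair so they stay consistent.
                noisy_edges[i][j] = edge_mask_id
                noisy_edges[j][i] = edge_mask_id
    return noisy_nodes, noisy_edges


rng = random.Random(0)
nodes = [3, 1, 4]                          # hypothetical fragment indices
edges = [[0, 2, 0], [2, 0, 5], [0, 5, 0]]  # hypothetical edge indices, 0 = no edge
noisy_nodes, noisy_edges = mask_graph(
    nodes, edges, t=0.5, node_mask_id=10, edge_mask_id=20, rng=rng
)
```

At `t=0` the graph is returned unchanged, and at `t=1` every node and every off-diagonal edge entry is masked, which matches the two endpoints of an absorbing-state forward process.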

### RGFN Chemistry
The purpose of SynCoGen is to generate molecules with valid synthesis paths. To this end, we select a subset of reactions and building blocks from a vocabulary initially defined in [RGFN](https://arxiv.org/pdf/2406.08506) and generate a dataset by sampling from the space of molecules generated combinatorially from the vocabulary. Conformers are then generated for these molecules using GFN-xTB. At the end of validation, we evaluate all sampled molecules for validity according to RGFN's reaction rules.

### Graph Parametrization
Node and edge identities are encoded as one-hot vectors. The dimensionality of the node one-hot vector is simply the number of fragment types. The dimensionality of the edge one-hot vector is (R x C^2), where R is the number of reaction types and C is the number of centers in the reaction. Both reaction type and center indices are encoded to define generated molecules as strictly as possible, such that a graph sampled by SynCoGen corresponds to a valid molecule with no structural ambiguity.
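To make the (R x C^2) edge dimensionality concrete, here is a minimal sketch of one way to flatten a (reaction, center, center) triple into a single class index and back. The helper names and the flattening order are assumptions for illustration; the ordering used in the repository may differ, and any extra classes (for example a "no edge" or mask class) are not covered here.

```python
def encode_edge(reaction, center_a, center_b, num_centers):
    """Flatten (reaction, center_a, center_b) into one index in [0, R * C^2)."""
    return (reaction * num_centers + center_a) * num_centers + center_b


def decode_edge(index, num_centers):
    """Invert encode_edge: recover (reaction, center_a, center_b)."""
    reaction, rest = divmod(index, num_centers * num_centers)
    center_a, center_b = divmod(rest, num_centers)
    return reaction, center_a, center_b


def one_hot(index, size):
    """Return a list of length `size` with a single 1 at `index`."""
    vec = [0] * size
    vec[index] = 1
    return vec


num_reactions, num_centers = 3, 2                  # hypothetical R and C
edge_dim = num_reactions * num_centers ** 2        # R x C^2 = 12
idx = encode_edge(2, 1, 0, num_centers)            # reaction 2, centers (1, 0)
edge_vector = one_hot(idx, edge_dim)
assert decode_edge(idx, num_centers) == (2, 1, 0)
```

Because every triple maps to a distinct index, the edge label alone identifies which reaction joins two fragments and at which reactive centers.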

### Coordinate Parametrization
We use continuous flow matching to predict the coordinates of the atoms in the molecule. Ground-truth and predicted coordinates are given by a (B x N x MAX_ATOMS x 3) tensor, where MAX_ATOMS is the maximum number of atoms in any of the fragments in the vocabulary. During loss calculation, we only consider the atoms in the molecule, and ignore the dummy atoms used to pad the fragments as well as the atoms that are dropped during reactions. During sampling, we first reassemble the fragments into an RDKit molecule using the predicted graph, and then use the number of atoms attributed to each fragment to determine a final molecule mask. The remaining coordinates are assigned to the molecule and an SDF file is generated.
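The masked loss described above can be sketched as follows. This is a simplified illustration under stated assumptions, not the repository's code: it assumes a linear interpolation path between noise `x0` and data `x1` (a common conditional flow matching choice), a model that predicts the velocity `x1 - x0`, and atoms flattened into a single list with a boolean `atom_mask` that is `False` for padding atoms and atoms dropped during reactions. The function names are hypothetical.

```python
def interpolate(x0, x1, t):
    """Point on the linear path between noise x0 and data x1 at time t in [0, 1]."""
    return [
        [(1 - t) * a + t * b for a, b in zip(p0, p1)]
        for p0, p1 in zip(x0, x1)
    ]


def masked_velocity_loss(pred_velocity, x0, x1, atom_mask):
    """Squared error against the target velocity (x1 - x0), averaged over real atoms."""
    total, count = 0.0, 0
    for v, p0, p1, keep in zip(pred_velocity, x0, x1, atom_mask):
        if not keep:
            continue  # padding or dropped atoms contribute nothing to the loss
        total += sum((vi - (b - a)) ** 2 for vi, a, b in zip(v, p0, p1))
        count += 1
    return total / max(count, 1)


x0 = [[0, 0, 0], [1, 1, 1], [0, 0, 0]]      # noise sample, one [x, y, z] per atom
x1 = [[1, 0, 0], [1, 2, 1], [9, 9, 9]]      # ground truth; last atom is padding
atom_mask = [True, True, False]

x_t = interpolate(x0, x1, t=0.5)            # model input at time t
pred = [[1, 0, 0], [0, 1, 0], [100, 100, 100]]
loss = masked_velocity_loss(pred, x0, x1, atom_mask)
```

Here the prediction for the padded atom is arbitrary, yet the loss is zero because only the two real atoms are scored.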

### Data Generation
As RDKit is not capable of all-atom chemical reaction product attribution, we define a simple graph sampler to generate fragment graphs that accurately track bonding between fragments based on reaction rules. To sample, use ```python ./sample_rgfn_graphs.py -n [num_molecules] -l [num_fragments_per_mol] --seed 0 --save_conformers --xtb```. This implementation assumes you have a working installation of GFN-xTB. 

The full dataset is available [here](https://tyers.s3.us-west-1.amazonaws.com/all_steps_clean.tgz). For SynSpace, download all_steps_clean.tgz. For pharmacophores, download pharmacophores.lmdb. An LMDB version of SynSpace conformers will be uploaded soon.

### Directory Structure
```
SynCoGen/
    `configs/` Hydra config files
    `data/` Datasets
    `src/` Source code
    `vocabulary/` Reaction and building block vocabularies
    `main.py` Run training or evaluation
    `sample_rgfn_graphs.py` Sample random molecules and/or generate a dataset for training
    ...
```

### Getting Started

Create the environment with conda:
```bash
module load cuda/12.4.0   # on cluster
conda env create -f requirements.yaml
conda activate syncogen
```

Create the following directories to store saved models and slurm logs:
```bash
mkdir outputs
mkdir watch_folder
```

Example training scripts are given in the training sections below. To run without validation (much faster for debugging and overfitting), additionally set `trainer.limit_val_batches=0`.

### Training - Unconditional
```bash
python main.py \
    trainer.num_sanity_val_steps=0 \
    model=semla \
    backbone=semla \
    self_conditioning=True \
    rgfn.reassembly_logging=True \
    rgfn.compatibility_mask=True \
    sampling.steps=100 \
    model.length=5 \
    trainer.precision="16-mixed" \
    loader.num_workers=0 \
    loader.global_batch_size=512 \
    loader.batch_size=64 \
    loader.eval_batch_size=100 \
    sampling.num_sample_batches=1 \
    sampling.num_sample_log=100 \
    sampling.constrain_edge_sampling=True \
    trainer.check_val_every_n_epoch=1 \
    data.cache_dir="<PATH/TO/DATA>" \
    spatial.sample_conformer=True \
    spatial.align=False \
    spatial.center=True \
    spatial.rotate=True \
    spatial.translate=False \
    spatial.prior="gaussian" \
    spatial.normalize=True \
    spatial.equivariant_ot=True \
    spatial.pairwise_threshold=5.0 \
    spatial.bond_loss="[pairwise_distance, bond_length, smooth_lddt]" \
    spatial.bond_length_coef=0.2 \
    spatial.pairwise_distance_coef=0.4 \
    spatial.smooth_lddt_coef=0.4 \
    spatial.square_bond_loss=True \
    spatial.bond_loss_time_threshold=0.25 \
    spatial.scale_noise=True \
    spatial.scale_noise_factor=0.2 \
    true_edge_weight=1.0 \
    paths.root_dir="<PATH/TO/SynCoGen>" 
```
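The command above sets both `loader.global_batch_size=512` and `loader.batch_size=64`. Assuming SynCoGen keeps the convention of the MDLM codebase it builds on, where the global batch size equals the per-device batch size times the number of devices, nodes, and gradient-accumulation steps (an assumption that this README does not confirm), the implied number of accumulation steps can be checked with a small helper. `accumulation_steps` is a hypothetical function written for this example.

```python
def accumulation_steps(global_batch_size, batch_size, num_devices=1, num_nodes=1):
    """Gradient-accumulation steps implied by the batch-size settings.

    Assumes: global = batch_size * num_devices * num_nodes * accumulation steps.
    """
    per_step = batch_size * num_devices * num_nodes
    if global_batch_size % per_step != 0:
        raise ValueError(
            f"global_batch_size={global_batch_size} is not a multiple of {per_step}"
        )
    return global_batch_size // per_step


single_gpu = accumulation_steps(512, 64)                 # 8 steps on one device
four_gpus = accumulation_steps(512, 64, num_devices=4)   # 2 steps on four devices
```

Under this assumption, changing the number of GPUs changes the accumulation count rather than the effective batch size, so `loader.batch_size` can be tuned to fit memory without altering the optimization setup.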

### Training - Pharmacophore Conditioning
```bash
python main.py \
    trainer.num_sanity_val_steps=0 \
    model=semla \
    backbone=semla \
    self_conditioning=True \
    rgfn.reassembly_logging=True \
    rgfn.compatibility_mask=True \
    sampling.steps=100 \
    model.length=5 \
    trainer.precision="16-mixed" \
    loader.num_workers=0 \
    loader.global_batch_size=512 \
    loader.batch_size=64 \
    loader.eval_batch_size=100 \
    sampling.num_sample_batches=1 \
    sampling.num_sample_log=100 \
    sampling.constrain_edge_sampling=True \
    trainer.check_val_every_n_epoch=1 \
    data.cache_dir="<PATH/TO/DATA>" \
    spatial.sample_conformer=True \
    spatial.align=False \
    spatial.center=True \
    spatial.rotate=True \
    spatial.translate=False \
    spatial.prior="gaussian" \
    spatial.normalize=True \
    spatial.equivariant_ot=True \
    spatial.pairwise_threshold=5.0 \
    spatial.bond_loss="[pairwise_distance, bond_length, smooth_lddt]" \
    spatial.bond_length_coef=0.2 \
    spatial.pairwise_distance_coef=0.4 \
    spatial.smooth_lddt_coef=0.4 \
    spatial.square_bond_loss=True \
    spatial.bond_loss_time_threshold=0.25 \
    spatial.scale_noise=True \
    spatial.scale_noise_factor=0.2 \
    true_edge_weight=1.0 \
    paths.root_dir="<PATH/TO/SynCoGen>" \
    spatial.pharmacophore_conditioning=True \
    spatial.pharmacophore_subset=7
```

### Sampling - Unconditional
Example sampling script:
```bash
python main.py \
    mode=sample_eval \
    sampling.constrain_edge_sampling=True \
    rgfn.compatibility_mask=True \
    self_conditioning=True \
    model=semla \
    backbone=semla \
    eval.checkpoint_path="<path_to_checkpoint>/<checkpoint>.ckpt" \
    sampling.num_sample_batches=5 \
    loader.eval_batch_size=200 \
    sampling.steps=100 \
    sampling.refine_coordinates_steps=0 \
    spatial.normalize=True \
    spatial.scale_noise=True \
    model.length=5 
```

### Sampling - Pharmacophore Conditioning
```bash
python main.py \
    mode=sample_eval \
    sampling.constrain_edge_sampling=True \
    rgfn.compatibility_mask=True \
    self_conditioning=True \
    model=semla \
    backbone=semla \
    eval.checkpoint_path="<path_to_checkpoint>/<checkpoint>.ckpt" \
    sampling.num_sample_batches=1 \
    loader.eval_batch_size=200 \
    sampling.steps=100 \
    spatial.normalize=True \
    spatial.scale_noise=True \
    spatial.pharmacophore_conditioning=True \
    spatial.pharmacophore_subset=7 \
    spatial.pharm_cond_mol="<path_to_mol>/<mol>.sdf" \
    model.length=5
```
