# Alignment is key to applying diffusion models to retrosynthesis

## Examples of the generative process

The files generation_example_1.mp4 and generation_example_2.mp4 show examples of the generative process. In them, the absorbing state of the diffusion model is visualized as a gold atom (Au). Generation starts from the prior, in which every node is in the absorbing state and there are no edges. The model then iteratively generates the reactants by turning absorbing-state nodes into atoms and some of the empty entries of the adjacency matrix into bonds. The product (on the right) is given as conditioning information, and the bottom of the figure shows the ground-truth reaction.
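The mechanics shown in the videos can be sketched with a toy sampler. This is a minimal sketch under stated assumptions, not the repository's implementation: it assumes a dense representation (a vector of node categories plus an adjacency matrix of bond categories), a hypothetical vocabulary in which the extra node category `ABSORBING` plays the role of the gold atom and edge category 0 means "no bond", and a linear schedule under which an entry that is still unresolved at step t is revealed with probability 1/t. `toy_denoiser` is a hypothetical stand-in for the trained network that returns random logits, so the output is not valid chemistry, and the product conditioning is omitted. For readability, unresolved entries are tracked with explicit boolean masks.

```python
import numpy as np

N_ATOM_TYPES = 4          # hypothetical atom vocabulary size (e.g. C, N, O, Cl)
ABSORBING = N_ATOM_TYPES  # extra node category for the absorbing state (the "Au" in the videos)
N_BOND_TYPES = 3          # hypothetical bond vocabulary: 0 = no bond, 1 = single, 2 = double


def prior(n_nodes):
    """Prior: every node is in the absorbing state and the adjacency matrix is empty."""
    nodes = np.full(n_nodes, ABSORBING)
    edges = np.zeros((n_nodes, n_nodes), dtype=int)
    return nodes, edges


def toy_denoiser(nodes, edges, rng):
    """Hypothetical stand-in for the trained network: random logits over the clean categories."""
    n = len(nodes)
    return rng.normal(size=(n, N_ATOM_TYPES)), rng.normal(size=(n, n, N_BOND_TYPES))


def softmax(x):
    z = np.exp(x - x.max(axis=-1, keepdims=True))
    return z / z.sum(axis=-1, keepdims=True)


def sample(n_nodes, T=10, seed=0):
    rng = np.random.default_rng(seed)
    nodes, edges = prior(n_nodes)
    node_open = np.ones(n_nodes, dtype=bool)                           # nodes still absorbed
    edge_open = np.triu(np.ones((n_nodes, n_nodes), dtype=bool), k=1)  # upper triangle, no self-loops
    for t in range(T, 0, -1):
        node_logits, edge_logits = toy_denoiser(nodes, edges, rng)
        p_reveal = 1.0 / t  # linear schedule: at t = 1 every remaining entry is revealed
        reveal = node_open & (rng.random(n_nodes) < p_reveal)
        for i in np.flatnonzero(reveal):
            nodes[i] = rng.choice(N_ATOM_TYPES, p=softmax(node_logits[i]))
        node_open &= ~reveal
        reveal = edge_open & (rng.random((n_nodes, n_nodes)) < p_reveal)
        for i, j in zip(*np.nonzero(reveal)):
            # write both (i, j) and (j, i) so the adjacency matrix stays symmetric
            edges[i, j] = edges[j, i] = rng.choice(N_BOND_TYPES, p=softmax(edge_logits[i, j]))
        edge_open &= ~reveal
    return nodes, edges


nodes, edges = sample(6)
print(nodes)  # no entry equals ABSORBING any more
print(edges)  # symmetric matrix of bond categories
```

In the real model the network sees the partially generated reactants and the product at every step, which is what lets it produce chemically consistent graphs.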

# Installation instructions

The codebase was originally based on DiGress (https://github.com/cvignac/DiGress), and some parts of the code still derive from it.

We used Python 3.9.16. Once you have set up an empty environment, run the following commands (these assume a conda installation and a CPU-only setup, without a GPU):

```bash
conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 cpuonly -c pytorch
pip install torch_geometric
pip install pyg_lib torch_scatter torch_sparse torch_spline_conv -f https://data.pyg.org/whl/torch-1.13.1+cpu.html
pip install wandb==0.15.3
pip install hydra-core
pip install pytorch-lightning==1.9
pip install overrides
pip install rdkit==2023.03.1
pip install multiset
pip install pandas
pip install matplotlib
pip install seaborn
pip install imageio==2.29.0
```

After that, install the repository itself in editable mode by running `pip install -e .` from the repository root.
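To confirm that the pinned versions actually ended up in the environment, you can query the installed distributions with the standard library's `importlib.metadata`. The script below is a hypothetical helper, not part of the repository: it only covers the packages pinned above, and it compares the numeric release parts with zero padding, so local tags such as `+cpu` are ignored and `2023.3.1` matches the pin `2023.03.1`.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the install commands above (distribution names as seen by pip).
PINS = {
    "torch": "1.13.1",
    "torchvision": "0.14.1",
    "torchaudio": "0.13.1",
    "wandb": "0.15.3",
    "pytorch_lightning": "1.9",
    "rdkit": "2023.03.1",
    "imageio": "2.29.0",
}


def release(v):
    """Numeric release tuple: '2023.03.1' -> (2023, 3, 1), '1.13.1+cpu' -> (1, 13, 1)."""
    return tuple(int(part) for part in re.match(r"\d+(?:\.\d+)*", v).group(0).split("."))


def same_release(installed, wanted):
    """Compare releases with zero padding, so '1.9.0' matches the pin '1.9'."""
    a, b = release(installed), release(wanted)
    n = max(len(a), len(b))
    return a + (0,) * (n - len(a)) == b + (0,) * (n - len(b))


def check_pins(pins):
    """Return {package: problem} for every pinned package that is missing or at another version."""
    problems = {}
    for name, wanted in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            problems[name] = "not installed"
            continue
        if not same_release(installed, wanted):
            problems[name] = f"found {installed}, expected {wanted}"
    return problems


if __name__ == "__main__":
    for name, problem in check_pins(PINS).items():
        print(f"{name}: {problem}")
```

An empty output means every pinned package was found at the expected version.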

# Relevant places in the code that describe the model

The following files are the most relevant for navigating the code and understanding exactly how the model works. Many of them reference a 'Placeholder' class, a simple container that stores graphs in dense format, i.e., the node features and adjacency matrices, together with additional information about the graphs.
- **src/diffusion/diffusion_abstract.py**: 
    - the function **training_step** implements a training step: noise is added to the graphs, and the model is trained to denoise them with a cross-entropy loss
    - the function **apply_noise** applies the categorical noise to the graphs. The transition matrices are defined in src/diffusion/noise_schedule.py and imported into the diffusion module in diffusion_abstract.py
    - the function **sample_one_batch** samples a batch of graphs. It takes a graph from the data as input, noises out the reactants, and then samples new reactants from the model with the iterative diffusion process.
- **src/diffusion/noise_schedule.py**: 
    - contains a class that generates transition matrices for the absorbing-state diffusion process.
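The noising and the training loss can be sketched in a few lines. This is a minimal numpy sketch under stated assumptions, not the code in noise_schedule.py or diffusion_abstract.py: it assumes the standard absorbing-state one-step matrix `Q_t = (1 - beta_t) I + beta_t 1 e_m^T`, where `m` is the absorbing category, the row-vector convention `x_t ~ Cat(x_0 Qbar_t)` with `Qbar_t = Q_1 Q_2 ... Q_t`, and a single categorical variable type (the repository noises both node and edge features). The `betas` schedule and the random logits are hypothetical.

```python
import numpy as np


def absorbing_Q(beta_t, K, m):
    """One-step matrix (rows = current state): keep the category with prob 1 - beta_t,
    jump to the absorbing category m with prob beta_t. Row m stays in m."""
    Q = (1.0 - beta_t) * np.eye(K)
    Q[:, m] += beta_t
    return Q


def cumulative_Q(betas, K, m):
    """Qbar_t = Q_1 Q_2 ... Q_t, the marginal transition from x_0 to x_t."""
    Qbar = np.eye(K)
    for beta in betas:
        Qbar = Qbar @ absorbing_Q(beta, K, m)
    return Qbar


def apply_noise(x0, Qbar, rng):
    """Sample x_t ~ Cat(onehot(x_0) @ Qbar) for an integer array x0 of any shape."""
    probs = Qbar[x0]                                     # shape x0.shape + (K,)
    u = rng.random(x0.shape + (1,))
    return (probs.cumsum(axis=-1) > u).argmax(axis=-1)   # inverse-CDF sampling


def cross_entropy(logits, x0):
    """Mean cross-entropy of predicted clean-category logits (..., K) against the true x0."""
    logits = logits - logits.max(axis=-1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    return -np.take_along_axis(log_probs, x0[..., None], axis=-1).mean()


rng = np.random.default_rng(0)
K, m = 5, 4                                    # 4 hypothetical atom types + the absorbing state
betas = np.full(10, 0.1)                       # hypothetical constant schedule, T = 10
Qbar = cumulative_Q(betas, K, m)
x0 = rng.integers(0, m, size=8)                # clean categories (never the absorbing one)
xt = apply_noise(x0, Qbar, rng)                # each entry is either unchanged or m
loss = cross_entropy(rng.normal(size=(8, K)), x0)  # a network would predict these logits from xt
```

Because every `Q_t` has the same form, the product collapses to `Qbar_t = alpha_bar_t I + (1 - alpha_bar_t) 1 e_m^T` with `alpha_bar_t` the product of the `(1 - beta_s)`, so noisy samples can be drawn for any `t` directly, without simulating the chain step by step.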

# Training run
Run the following command to train a model. Training assumes wandb integration by default; for instance, the following trains our best model with wandb in offline mode, i.e., without syncing to a wandb project:

```bash
python src/train.py +experiment=pe_skip general.wandb.mode=offline
```

Processing the data the first time you run the command usually takes 10-20 minutes, depending on the hardware. To try out the model quickly, you can instead use a tiny version of the dataset:

```bash
python src/train.py +experiment=pe_skip general.wandb.mode=offline dataset.dataset_nb=block-15-tiny
```

# Evaluation pipeline

Run the following commands to evaluate a model. They assume wandb integration, and in particular that the checkpoints were uploaded to a wandb project during training. The results are uploaded to wandb and are also written to log files in the hydra run directory. Some of the communication between the steps of the pipeline also goes through wandb. Replace the placeholders with the appropriate values. Some parameters are only relevant when using the scripts to parallelize across multiple GPUs; these have been set to zero. The parameters are:

- **run_id**: the wandb run ID of the model to evaluate
- **edge_conditional_set**: the set to evaluate: 'train', 'val', or 'test'
- **epochs**: a list of epochs to evaluate (assumes that the checkpoints were saved as wandb artifacts)
- **epoch**: a single checkpoint epoch, passed as a one-element list to sample_array_job.py and evaluate_array_job.py
- **n_conditions**: the number of conditions to evaluate
- **n_samples_per_condition**: the number of samples to generate per condition
- **total_cond_eval**: the total number of conditions to evaluate (should equal n_conditions unless the evaluation is parallelized across multiple processes)
- **sampling_step**: the number of diffusion steps used when sampling from the model
- **sampling_step_counts**: a list of sampling step counts; should equal [sampling_step]. This is only relevant when using the scripts to parallelize across multiple GPUs to evaluate the model with different numbers of diffusion steps.
- **random_seed**: the random seed used to generate the samples
- **condition_first**: the index of the first condition to evaluate in the dataset; usually zero
- **return_smiles_with_atom_mapping**: whether to return the SMILES with atom mapping. This is relevant if you want to add the stereochemistry later in the pipeline.
- **experiment_folder**: the folder under experiments/ used as the hydra run directory; use the same value in all commands
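The consistency rules in the parameter list are easy to break when filling in placeholders by hand. Below is a small hypothetical helper (not part of the repository) that checks them for a single-process run, plus the natural requirement that the `epoch` used in the array-job commands is one of the downloaded `epochs`. The example values are arbitrary.

```python
def check_eval_params(p):
    """Return a list of violated consistency rules for a single-process evaluation run."""
    problems = []
    if p["edge_conditional_set"] not in ("train", "val", "test"):
        problems.append("edge_conditional_set must be 'train', 'val', or 'test'")
    if p["total_cond_eval"] != p["n_conditions"]:
        problems.append("total_cond_eval should equal n_conditions without parallelization")
    if list(p["sampling_step_counts"]) != [p["sampling_step"]]:
        problems.append("sampling_step_counts should equal [sampling_step]")
    if p["epoch"] not in p["epochs"]:
        problems.append("epoch should be one of epochs")
    return problems


# Arbitrary example values.
params = {
    "edge_conditional_set": "val",
    "epochs": [300],
    "epoch": 300,
    "n_conditions": 100,
    "total_cond_eval": 100,
    "sampling_step": 100,
    "sampling_step_counts": [100],
}
print(check_eval_params(params))  # prints []
```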

Running the pipeline:

```bash
python src/wandb_download_model_weights.py general.wandb.run_id={run_id} general.wandb.checkpoint_epochs={epochs}

python src/sample_array_job.py general.wandb.mode=offline general.wandb.run_id={run_id} diffusion.edge_conditional_set={edge_conditional_set} general.wandb.checkpoint_epochs=[{epoch}] test.condition_first=0 test.condition_index=0 test.n_conditions={n_conditions} test.n_samples_per_condition={n_samples_per_condition} dataset.shuffle=False general.wandb.load_run_config=True hydra.run.dir=experiments/{experiment_folder}/ test.total_cond_eval={total_cond_eval} train.seed={random_seed} diffusion.diffusion_steps_eval={sampling_step} test.inpaint_on_one_reactant={inpaint_on_one_reactant}

python src/sample_wandblog.py general.wandb.run_id={run_id} diffusion.edge_conditional_set={edge_conditional_set} general.wandb.checkpoint_epochs={epochs} test.n_conditions={n_conditions} test.n_samples_per_condition={n_samples_per_condition} hydra.run.dir=experiments/{experiment_folder}/ test.total_cond_eval={total_cond_eval} general.wandb.eval_sampling_steps={sampling_step_counts}

python src/evaluate_array_job.py general.wandb.mode=offline general.wandb.load_run_config=True general.wandb.run_id={run_id} diffusion.edge_conditional_set={edge_conditional_set} general.wandb.checkpoint_epochs=[{epoch}] test.condition_first=0 test.condition_index=0 test.n_conditions={n_conditions} test.n_samples_per_condition={n_samples_per_condition} dataset.shuffle=False test.total_cond_eval={total_cond_eval} hydra.run.dir=experiments/{experiment_folder}/ train.seed={random_seed} diffusion.diffusion_steps_eval={sampling_step} test.return_smiles_with_atom_mapping={return_smiles_with_atom_mapping}

python src/evaluate_wandblog.py general.wandb.checkpoint_epochs={epochs} hydra.run.dir=experiments/{experiment_folder}/ general.wandb.run_id={run_id} test.total_cond_eval={total_cond_eval} test.n_samples_per_condition={n_samples_per_condition} diffusion.edge_conditional_set={edge_conditional_set} general.wandb.eval_sampling_steps={sampling_step_counts} test.condition_first=0
```
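When evaluating several checkpoints, the steps can be scripted. The sketch below is hypothetical (not part of the repository) and assumes, as the `[{epoch}]` placeholder suggests, that each `*_array_job.py` call handles one checkpoint epoch while the other scripts take the full list of epochs. It uses a single `experiment_folder` and `random_seed` for every step, presumably so that later steps find the outputs of earlier ones. The run ID and values are made up, and by default it only prints the commands (a dry run). `hydra_list` formats lists without spaces, since a space would split the argument in the shell.

```python
import shlex
import subprocess


def hydra_list(xs):
    """Format a Python list as a Hydra list override without spaces, e.g. [300,360]."""
    return "[" + ",".join(str(x) for x in xs) + "]"


def pipeline(p):
    """Return the evaluation commands in order; the array-job steps are repeated once per epoch."""
    run = f"general.wandb.run_id={p['run_id']}"
    cset = f"diffusion.edge_conditional_set={p['edge_conditional_set']}"
    hdir = f"hydra.run.dir=experiments/{p['experiment_folder']}/"
    n_cond = f"test.n_conditions={p['n_conditions']}"
    n_samp = f"test.n_samples_per_condition={p['n_samples_per_condition']}"
    total = f"test.total_cond_eval={p['total_cond_eval']}"
    epochs = f"general.wandb.checkpoint_epochs={hydra_list(p['epochs'])}"
    steps = f"general.wandb.eval_sampling_steps={hydra_list(p['sampling_step_counts'])}"

    def array_job(script, epoch, extra):
        return ["python", script, "general.wandb.mode=offline", "general.wandb.load_run_config=True",
                run, cset, f"general.wandb.checkpoint_epochs=[{epoch}]", "test.condition_first=0",
                "test.condition_index=0", n_cond, n_samp, "dataset.shuffle=False", hdir, total,
                f"train.seed={p['random_seed']}", f"diffusion.diffusion_steps_eval={p['sampling_step']}", extra]

    cmds = [["python", "src/wandb_download_model_weights.py", run, epochs]]
    cmds += [array_job("src/sample_array_job.py", e,
                       f"test.inpaint_on_one_reactant={p['inpaint_on_one_reactant']}") for e in p["epochs"]]
    cmds.append(["python", "src/sample_wandblog.py", run, cset, epochs, n_cond, n_samp, hdir, total, steps])
    cmds += [array_job("src/evaluate_array_job.py", e,
                       f"test.return_smiles_with_atom_mapping={p['return_smiles_with_atom_mapping']}")
             for e in p["epochs"]]
    cmds.append(["python", "src/evaluate_wandblog.py", epochs, hdir, run, total, n_samp, cset, steps,
                 "test.condition_first=0"])
    return cmds


def run_all(cmds, dry_run=True):
    """Print each command; execute it only when dry_run is False, stopping at the first failure."""
    for cmd in cmds:
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_all(pipeline({
        "run_id": "abc123", "edge_conditional_set": "val", "experiment_folder": "my_eval",
        "epochs": [300, 360], "n_conditions": 10, "n_samples_per_condition": 5, "total_cond_eval": 10,
        "sampling_step": 100, "sampling_step_counts": [100], "random_seed": 0,
        "inpaint_on_one_reactant": False, "return_smiles_with_atom_mapping": True,
    }))
```

`shlex.join` single-quotes arguments containing brackets, so the printed commands can be pasted into a shell as-is.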

After that, to include stereochemistry and compare directly to the original data, take the output file of evaluate_wandblog.py (the .txt file), replace OLD_SAMPLE_PATH on line 195 of src/transfer_stereo_and_compare-directly_to_data.py with the path to that file, and run the following command:

```bash
python src/transfer_stereo_and_compare-directly_to_data.py
```