{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# local install shepherd_score as directed in the README.md\n",
    "import numpy as np\n",
    "import pickle\n",
    "import pandas\n",
    "from rdkit import Chem\n",
    "from rdkit.Chem import AllChem\n",
    "import open3d as o3d\n",
    "from shepherd_score.evaluations.evaluate import ConsistencyEvalPipeline, ConditionalEvalPipeline\n",
    "from shepherd_score.container import Molecule"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Consistency (unconditional) evaluation minimal example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# To take random molecules from the training set, align, and score them for the lower bound:\n",
    "# Here we provide 45 molecules randomly sampled from ShEPhERD-GDB17\n",
    "# NOTE: pickle.load can execute arbitrary code; only load pickles from trusted sources.\n",
    "with open('test_data/gdb17_sampled_molblock_charges.pkl', 'rb') as f:\n",
    "    training_molblock_charges = pickle.load(f)\n",
    "\n",
    "# Load the generated samples -> unconditional_samples is a list of dictionaries\n",
    "path_to_unconditional_samples_pickle = './test_data/gdb17_x3_uncond_samples.pickle'\n",
    "with open(path_to_unconditional_samples_pickle, 'rb') as f:\n",
    "    unconditional_samples = pickle.load(f)\n",
    "\n",
    "# For this example we use x3. Each sample provides:\n",
    "#   sample['x1']: 'atoms' and 'positions' of the generated molecule\n",
    "#   sample['x3']: 'positions' (surface points) and 'charges' (ESP at those points)\n",
    "ls_atoms_pos = [(sample['x1']['atoms'], sample['x1']['positions']) for sample in unconditional_samples]\n",
    "ls_surf_points = [sample['x3']['positions'] for sample in unconditional_samples]\n",
    "ls_surf_esp = [sample['x3']['charges'] for sample in unconditional_samples]\n",
    "\n",
    "# Initialize evaluation\n",
    "consis_pipeline = ConsistencyEvalPipeline(\n",
    "    generated_mols=ls_atoms_pos,\n",
    "    generated_surf_points=ls_surf_points,\n",
    "    generated_surf_esp=ls_surf_esp,\n",
    "    generated_pharm_feats=None,\n",
    "    probe_radius=0.6,\n",
    "    pharm_multi_vector=False,\n",
    "    solvent=None, # we use gas phase for GDB17\n",
    "    random_molblock_charges=training_molblock_charges,\n",
    ")\n",
    "\n",
    "# Run evaluation\n",
    "print('Running evaluation...')\n",
    "consis_pipeline.evaluate(num_processes=1, verbose=True)\n",
    "print('Finished evaluation.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# You can save the outputs or just view them in the consis_pipeline object.\n",
    "# Each key below becomes an array name inside the .npz archive.\n",
    "# If a value is None, ragged, or mixes None with strings, numpy stores it as an object array,\n",
    "# so reload with np.load(save_file_path, allow_pickle=True).\n",
    "save_file_path = 'test_data/consis_example.npz'\n",
    "consis_results = {\n",
    "    'molblocks': consis_pipeline.molblocks,\n",
    "    'molblocks_post_opt': consis_pipeline.molblocks_post_opt,\n",
    "    'num_valid': consis_pipeline.num_valid,\n",
    "    'num_valid_post_opt': consis_pipeline.num_valid_post_opt,\n",
    "    'num_consistent_graph': consis_pipeline.num_consistent_graph,\n",
    "    'strain_energies': consis_pipeline.strain_energies,\n",
    "    'rmsds': consis_pipeline.rmsds,\n",
    "    'SA_scores': consis_pipeline.SA_scores,\n",
    "    'logPs': consis_pipeline.logPs,\n",
    "    'QEDs': consis_pipeline.QEDs,\n",
    "    'fsp3s': consis_pipeline.fsp3s,\n",
    "    'frac_valid': consis_pipeline.frac_valid,\n",
    "    'frac_valid_post_opt': consis_pipeline.frac_valid_post_opt,\n",
    "    'frac_consistent': consis_pipeline.frac_consistent,\n",
    "    'frac_unique': consis_pipeline.frac_unique,\n",
    "    'frac_unique_post_opt': consis_pipeline.frac_unique_post_opt,\n",
    "    'avg_graph_diversity': consis_pipeline.avg_graph_diversity,\n",
    "    'sims_surf_consistent': consis_pipeline.sims_surf_consistent,\n",
    "    'sims_esp_consistent': consis_pipeline.sims_esp_consistent,\n",
    "    'sims_pharm_consistent': consis_pipeline.sims_pharm_consistent,\n",
    "    'sims_surf_upper_bound_75': consis_pipeline.sims_surf_upper_bound_75,\n",
    "    'sims_esp_upper_bound_75': consis_pipeline.sims_esp_upper_bound_75,\n",
    "    'sims_surf_upper_bound_400': consis_pipeline.sims_surf_upper_bound_400,\n",
    "    'sims_esp_upper_bound_400': consis_pipeline.sims_esp_upper_bound_400,\n",
    "    'sims_surf_lower_bound': consis_pipeline.sims_surf_lower_bound,\n",
    "    'sims_esp_lower_bound': consis_pipeline.sims_esp_lower_bound,\n",
    "    'sims_pharm_lower_bound': consis_pipeline.sims_pharm_lower_bound,\n",
    "    'sims_surf_consistent_relax': consis_pipeline.sims_surf_consistent_relax,\n",
    "    'sims_esp_consistent_relax': consis_pipeline.sims_esp_consistent_relax,\n",
    "    'sims_pharm_consistent_relax': consis_pipeline.sims_pharm_consistent_relax,\n",
    "    # NOTE: these three keys end in '_align' while the attributes end in '_aligned';\n",
    "    # the key names are kept unchanged to preserve the saved file format.\n",
    "    'sims_surf_consistent_relax_align': consis_pipeline.sims_surf_consistent_relax_aligned,\n",
    "    'sims_esp_consistent_relax_align': consis_pipeline.sims_esp_consistent_relax_aligned,\n",
    "    'sims_pharm_consistent_relax_align': consis_pipeline.sims_pharm_consistent_relax_aligned,\n",
    "    'graph_similarity_matrix': consis_pipeline.graph_similarity_matrix,\n",
    "}\n",
    "np.savez(save_file_path, **consis_results)\n",
    "print(f'Saved consistency evaluation to {save_file_path}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conditional evaluation minimal example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the generated samples. conditional_samples is a sequence whose elements are:\n",
    "#   [0] reference molecule molblock\n",
    "#   [1] reference molecule's xtb partial charges\n",
    "#   [2] reference molecule's surface points\n",
    "#   [3] reference molecule's ESP values\n",
    "#   [4] reference molecule's pharm types\n",
    "#   [5] reference molecule's pharm ancs\n",
    "#   [6] reference molecule's pharm vectors\n",
    "#   [-1] list of dictionaries of the generated samples\n",
    "# NOTE: pickle.load can execute arbitrary code; only load pickles from trusted sources.\n",
    "path_to_conditional_samples_pickle = './test_data/gdb17_x3_samples.pickle'\n",
    "with open(path_to_conditional_samples_pickle, 'rb') as f:\n",
    "    conditional_samples = pickle.load(f)\n",
    "\n",
    "condition = 'esp' # we set this condition since samples are from P(x1 | x3)\n",
    "num_surf_points = 400 # shared by the reference Molecule and the evaluation pipeline\n",
    "\n",
    "(ref_molblock, ref_partial_charges, surface_points, electrostatics,\n",
    " pharm_types, pharm_ancs, pharm_vecs) = conditional_samples[:7]\n",
    "generated_samples = conditional_samples[-1]\n",
    "\n",
    "ref_mol = Chem.MolFromMolBlock(ref_molblock, removeHs=False)\n",
    "# RDKit returns None (instead of raising) when a molblock cannot be parsed.\n",
    "assert ref_mol is not None, 'Could not parse the reference molblock.'\n",
    "\n",
    "# NOTE(review): surface_points and electrostatics are unpacked but not passed to Molecule;\n",
    "# presumably the reference surface is re-sampled with num_surf_points. Confirm this is intended.\n",
    "ref_molec = Molecule(ref_mol,\n",
    "                     probe_radius=0.6,\n",
    "                     partial_charges=np.array(ref_partial_charges),\n",
    "                     num_surf_points=num_surf_points,\n",
    "                     pharm_multi_vector=False,\n",
    "                     pharm_types=pharm_types,\n",
    "                     pharm_ancs=pharm_ancs,\n",
    "                     pharm_vecs=pharm_vecs)\n",
    "\n",
    "# Get the generated molecules' atoms and positions\n",
    "generated_mols = [(sample['x1']['atoms'], sample['x1']['positions']) for sample in generated_samples]\n",
    "\n",
    "print('Starting Conditional Eval Pipeline.')\n",
    "cond_pipe = ConditionalEvalPipeline(\n",
    "    ref_molec=ref_molec,\n",
    "    generated_mols=generated_mols,\n",
    "    condition=condition,\n",
    "    num_surf_points=num_surf_points,\n",
    "    pharm_multi_vector=False,\n",
    "    solvent=None # GDB17 uses gas phase\n",
    ")\n",
    "\n",
    "# Run evaluation\n",
    "cond_pipe.evaluate(num_processes=1, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# You can save the outputs or just view them in the cond_pipe object.\n",
    "# Gather every result under the array name it will have inside the .npz archive,\n",
    "# then write them all with a single np.savez call.\n",
    "save_file = './test_data/cond_example.npz'\n",
    "cond_results = dict(\n",
    "    ref_molblock=cond_pipe.ref_molblock,\n",
    "    ref_mol_SA_score=cond_pipe.ref_mol_SA_score,\n",
    "    ref_mol_QED=cond_pipe.ref_mol_QED,\n",
    "    ref_mol_logP=cond_pipe.ref_mol_logP,\n",
    "    ref_mol_fsp3=cond_pipe.ref_mol_fsp3,\n",
    "    ref_mol_morgan_fp=cond_pipe.ref_mol_morgan_fp,\n",
    "    ref_surf_resampling_scores=cond_pipe.ref_surf_resampling_scores,\n",
    "    ref_surf_esp_resampling_scores=cond_pipe.ref_surf_esp_resampling_scores,\n",
    "    sims_surf_upper_bound=cond_pipe.sims_surf_upper_bound,\n",
    "    sims_esp_upper_bound=cond_pipe.sims_esp_upper_bound,\n",
    "    molblocks=cond_pipe.molblocks,\n",
    "    molblocks_post_opt=cond_pipe.molblocks_post_opt,\n",
    "    num_valid=cond_pipe.num_valid,\n",
    "    num_valid_post_opt=cond_pipe.num_valid_post_opt,\n",
    "    num_consistent_graph=cond_pipe.num_consistent_graph,\n",
    "    strain_energies=cond_pipe.strain_energies,\n",
    "    rmsds=cond_pipe.rmsds,\n",
    "    SA_scores=cond_pipe.SA_scores,\n",
    "    logPs=cond_pipe.logPs,\n",
    "    QEDs=cond_pipe.QEDs,\n",
    "    fsp3s=cond_pipe.fsp3s,\n",
    "    frac_valid=cond_pipe.frac_valid,\n",
    "    frac_valid_post_opt=cond_pipe.frac_valid_post_opt,\n",
    "    frac_consistent=cond_pipe.frac_consistent,\n",
    "    frac_unique=cond_pipe.frac_unique,\n",
    "    frac_unique_post_opt=cond_pipe.frac_unique_post_opt,\n",
    "    avg_graph_diversity=cond_pipe.avg_graph_diversity,\n",
    "    sims_surf_target=cond_pipe.sims_surf_target,\n",
    "    sims_esp_target=cond_pipe.sims_esp_target,\n",
    "    sims_pharm_target=cond_pipe.sims_pharm_target,\n",
    "    sims_surf_target_relax=cond_pipe.sims_surf_target_relax,\n",
    "    sims_esp_target_relax=cond_pipe.sims_esp_target_relax,\n",
    "    sims_pharm_target_relax=cond_pipe.sims_pharm_target_relax,\n",
    "    graph_similarities=cond_pipe.graph_similarities,\n",
    "    sims_surf_target_relax_esp_aligned=cond_pipe.sims_surf_target_relax_esp_aligned,\n",
    "    sims_esp_target_relax_esp_aligned=cond_pipe.sims_esp_target_relax_esp_aligned,\n",
    "    sims_pharm_target_relax_esp_aligned=cond_pipe.sims_pharm_target_relax_esp_aligned,\n",
    "    molblocks_post_opt_esp_aligned=cond_pipe.molblocks_post_opt_esp_aligned,\n",
    ")\n",
    "np.savez(save_file, **cond_results)\n",
    "print(f'Finished sample evaluation!\\nSaved to {save_file}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
