{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eebefaa7-cd7c-4d1f-8e8d-da2ddffbef55",
   "metadata": {},
   "outputs": [],
   "source": [
    "import open3d \n",
    "from generate_point_cloud import (\n",
    "    get_atom_coords, \n",
    "    get_atomic_vdw_radii, \n",
    "    get_molecular_surface,\n",
    "    get_electrostatics,\n",
    "    get_atomic_partial_charges,\n",
    "    get_electrostatics_given_point_charges,\n",
    ")\n",
    "from pharm_utils.pharmacophore import get_pharmacophores\n",
    "\n",
    "from conformer_generation import update_mol_coordinates\n",
    "\n",
    "print('importing rdkit')\n",
    "import rdkit\n",
    "from rdkit.Chem import rdDetermineBonds\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "print('importing torch')\n",
    "import torch\n",
    "import torch_geometric\n",
    "from torch_geometric.nn import radius_graph\n",
    "import torch_scatter\n",
    "\n",
    "import pickle\n",
    "from copy import deepcopy\n",
    "import os\n",
    "import multiprocessing\n",
    "from tqdm import tqdm\n",
    "\n",
    "import sys\n",
    "# Add the model code directories to the import path (paths are relative to the\n",
    "# working directory). Presumably `lightning_module` and `datasets` below are\n",
    "# resolved from here -- TODO confirm.\n",
    "# NOTE(review): list.insert(-1, x) places x just BEFORE the last element of\n",
    "# sys.path, not at the end and not at the front, so whether these directories\n",
    "# win over an installed package of the same name depends on what the last\n",
    "# sys.path entry is in this environment.\n",
    "sys.path.insert(-1, \"model/\")\n",
    "sys.path.insert(-1, \"model/equiformer_v2\")\n",
    "\n",
    "print('importing lightning')\n",
    "import pytorch_lightning as pl\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint\n",
    "from pytorch_lightning.loggers import CSVLogger\n",
    "\n",
    "from lightning_module import LightningModule\n",
    "from datasets import HeteroDataset\n",
    "\n",
    "import importlib\n",
    "\n",
    "# Wildcard import: no other import in this cell binds `inference_sample`, which\n",
    "# is called further down, so that name is expected to come from here.\n",
    "from inference import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "120a2d86-2a6c-4bcb-b5b6-3bde2d51d15d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint file handed to LightningModule.load_from_checkpoint in the next cell.\n",
    "# The path is relative, so it is resolved against the current working directory.\n",
    "chkpt = 'shepherd_chkpts/x1x3x4_diffusion_mosesaq_20240824_last.ckpt'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "209e1bfb-3f5c-46cf-9be4-b6256c83cba3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the GPU when torch can see one, otherwise fall back to the CPU.\n",
    "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
    "\n",
    "# map_location deserializes the checkpoint tensors directly onto `device`,\n",
    "# regardless of the device they were saved from.\n",
    "model_pl = LightningModule.load_from_checkpoint(chkpt, map_location=device)\n",
    "params = model_pl.params  # parameter dict stored on the module; read again when sampling\n",
    "\n",
    "model_pl.to(device)\n",
    "# This notebook only samples from the model, so switch train-time layers\n",
    "# (e.g. dropout) to their inference behaviour.\n",
    "model_pl.eval()\n",
    "# The wrapped network carries its own `device` attribute, which is set by hand here\n",
    "# (presumably read by the sampling code -- TODO confirm).\n",
    "model_pl.model.device = device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6b65e6d-cae7-4de7-bc5d-d0ccb00c0324",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sizes passed to inference_sample below as batch_size, N_x1 and N_x4.\n",
    "batch_size = 1\n",
    "n_atoms = 30\n",
    "\n",
    "# Set to 5 (dummy value) if using a model that does not model x4.\n",
    "num_pharmacophores = 8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50631996-8cbf-4df7-97bd-a614f56649b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Symmetry-breaking settings: only use these during unconditional generation.\n",
    "# To switch them off, pass empty lists for the four list arguments and harmonize = False.\n",
    "\n",
    "# Read from the x1 noise schedule; not used further in this cell.\n",
    "T = params['noise_schedules']['x1']['T']\n",
    "\n",
    "# Time steps 130, 129, ..., 81 (alternative noted by the authors: [150]),\n",
    "# each paired with a noise scale of 1.0.\n",
    "inject_noise_at_ts = list(np.arange(130, 80, -1))\n",
    "inject_noise_scales = [1.0 for _ in inject_noise_at_ts]\n",
    "\n",
    "harmonize = True\n",
    "harmonize_ts = [80]\n",
    "harmonize_jumps = [20]\n",
    "\n",
    "generated_samples = inference_sample(\n",
    "    model_pl,\n",
    "    batch_size=batch_size,\n",
    "    N_x1=n_atoms,\n",
    "    N_x4=num_pharmacophores,\n",
    "    unconditional=True,\n",
    "    prior_noise_scale=1.0,\n",
    "    denoising_noise_scale=1.0,\n",
    "    # symmetry-breaking options configured above\n",
    "    inject_noise_at_ts=inject_noise_at_ts,\n",
    "    inject_noise_scales=inject_noise_scales,\n",
    "    harmonize=harmonize,\n",
    "    harmonize_ts=harmonize_ts,\n",
    "    harmonize_jumps=harmonize_jumps,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81d1e269-bbab-47ec-a1e0-cfd301ae2677",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of entries returned by inference_sample.\n",
    "print(len(generated_samples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db8b465e-2d03-4894-af27-03942832410b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First generated sample, atom-level ('x1') output: the 'atoms' field.\n",
    "print(generated_samples[0]['x1']['atoms']) # atomic numbers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f7ab362-2b92-4638-a9dd-3a986fc50723",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# First generated sample, atom-level ('x1') output: the 'positions' field.\n",
    "print(generated_samples[0]['x1']['positions']) # atomic coordinates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "230086e4-ef0b-4fc8-950a-3a06d7fde112",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# First generated sample, 'x3' output: the 'positions' field.\n",
    "print(generated_samples[0]['x3']['positions']) # ESP surface point coordinates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c329952-a56d-4cd6-811b-80e485b7c9ea",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# First generated sample, 'x3' output: the 'charges' field.\n",
    "print(generated_samples[0]['x3']['charges']) # ESP values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e8c66d3-2178-410b-89d6-1c54692527ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First generated sample, pharmacophore ('x4') output: the 'types' field.\n",
    "print(generated_samples[0]['x4']['types']) # pharmacophore types"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5672f111-498f-4d72-80bf-e5adbe38912b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First generated sample, pharmacophore ('x4') output: the 'positions' field.\n",
    "print(generated_samples[0]['x4']['positions']) # pharmacophore positions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52aadc47-6d8b-4096-ae9d-4e18cad095ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First generated sample, pharmacophore ('x4') output: the 'directions' field.\n",
    "print(generated_samples[0]['x4']['directions']) # pharmacophore directions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0fe3005-0e75-4d72-81e0-688134383885",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:.conda-shepherd]",
   "language": "python",
   "name": "conda-env-.conda-shepherd-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
