{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e29fb882-f41b-4274-ae46-a4015dbc2c56",
   "metadata": {},
   "outputs": [],
   "source": [
    "import open3d\n",
    "from generate_point_cloud import (\n",
    "    get_atom_coords,\n",
    "    get_atomic_vdw_radii,\n",
    "    get_molecular_surface,\n",
    "    get_electrostatics,\n",
    "    get_atomic_partial_charges,\n",
    "    get_electrostatics_given_point_charges,\n",
    ")\n",
    "from pharm_utils.pharmacophore import get_pharmacophores\n",
    "\n",
    "from conformer_generation import update_mol_coordinates\n",
    "\n",
    "print('importing rdkit')\n",
    "import rdkit\n",
    "from rdkit.Chem import rdDetermineBonds\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "print('importing torch')\n",
    "import torch\n",
    "import torch_geometric\n",
    "from torch_geometric.nn import radius_graph\n",
    "import torch_scatter\n",
    "\n",
    "import pickle\n",
    "from copy import deepcopy\n",
    "import os\n",
    "import multiprocessing\n",
    "from tqdm import tqdm\n",
    "\n",
    "import sys\n",
    "sys.path.insert(-1, \"model/\")\n",
    "sys.path.insert(-1, \"model/equiformer_v2\")\n",
    "\n",
    "print('importing lightning')\n",
    "import pytorch_lightning as pl\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint\n",
    "from pytorch_lightning.loggers import CSVLogger\n",
    "\n",
    "from lightning_module import LightningModule\n",
    "from datasets import HeteroDataset\n",
    "\n",
    "import importlib\n",
    "\n",
    "from inference import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95c73210-d818-479d-a085-9ecc518451fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "chkpt = 'shepherd_chkpts/x1x3x4_diffusion_mosesaq_20240824_last.ckpt'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40cb7561-8908-42e3-8a6b-7f141ba01608",
   "metadata": {},
   "outputs": [],
   "source": [
    "# run on the GPU when one is visible to torch, otherwise fall back to the CPU\n",
    "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
    "\n",
    "# map_location remaps the stored tensors onto `device` while the checkpoint is read,\n",
    "# so a checkpoint whose tensors were saved on a GPU can still be loaded on a CPU-only machine\n",
    "model_pl = LightningModule.load_from_checkpoint(chkpt, map_location = device)\n",
    "params = model_pl.params # configuration read by the later cells via params['dataset'][...]\n",
    "model_pl.to(device)\n",
    "model_pl.model.device = device"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6f7e232-ea01-4974-ad3b-d5eaf1564f6b",
   "metadata": {},
   "source": [
    "# Natural Products Interaction Profiles"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98d9c59c-3778-4933-9594-e82343771a05",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('conformers/np/molblock_charges_NPs.pkl', 'rb') as f:\n",
    "    molblocks_and_charges = pickle.load(f)\n",
    "\n",
    "# each entry is indexed as [0] = mol block text, [1] = per-atom partial charges\n",
    "index = 0\n",
    "mol = rdkit.Chem.MolFromMolBlock(molblocks_and_charges[index][0], removeHs = False) # target natural product\n",
    "if mol is None:\n",
    "    # MolFromMolBlock returns None (it does not raise) when the mol block cannot be parsed;\n",
    "    # stop here instead of failing later with an AttributeError on mol.GetConformer()\n",
    "    raise ValueError(f'RDKit failed to parse molblocks_and_charges[{index}][0] as a mol block')\n",
    "charges = np.array(molblocks_and_charges[index][1]) # xTB partial charges in implicit water\n",
    "display(mol)\n",
    "\n",
    "# extracting target interaction profiles (ESP and pharmacophores)\n",
    "# shift the conformer so that the mean of its atomic coordinates sits at the origin\n",
    "mol_coordinates = np.array(mol.GetConformer().GetPositions())\n",
    "mol_coordinates = mol_coordinates - np.mean(mol_coordinates, axis = 0)\n",
    "mol = update_mol_coordinates(mol, mol_coordinates, copy = True)\n",
    "\n",
    "# surface point cloud, with the number of points taken from the model's x3 dataset parameters\n",
    "centers = mol.GetConformer().GetPositions()\n",
    "radii = get_atomic_vdw_radii(mol)\n",
    "surface = get_molecular_surface(\n",
    "    centers,\n",
    "    radii,\n",
    "    params['dataset']['x3']['num_points'],\n",
    "    probe_radius = params['dataset']['probe_radius'],\n",
    "    num_samples_per_atom = 20,\n",
    ")\n",
    "\n",
    "# pharmacophores, extracted with the x4 settings stored in the model parameters\n",
    "pharm_types, pharm_pos, pharm_direction = get_pharmacophores(\n",
    "    mol,\n",
    "    multi_vector = params['dataset']['x4']['multivectors'],\n",
    "    check_access = params['dataset']['x4']['check_accessibility'],\n",
    ")\n",
    "\n",
    "# electrostatics computed from the partial charges, the atom centers and the surface points\n",
    "electrostatics = get_electrostatics_given_point_charges(\n",
    "    charges, centers, surface,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d41552b-30b8-468f-a701-7225b82b4990",
   "metadata": {},
   "source": [
    "# PDB Ligand Interaction Profiles"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d7a7fc9-7ead-4c91-99f0-08a43dfac2e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('conformers/pdb/molblock_charges_pdb_lowestenergy.pkl', 'rb') as f:\n",
    "    molblocks_and_charges = pickle.load(f)\n",
    "\n",
    "# each entry is indexed as [0] = mol block text, [1] = per-atom partial charges\n",
    "index = 6\n",
    "mol = rdkit.Chem.MolFromMolBlock(molblocks_and_charges[index][0], removeHs = False) # target PDB ligand\n",
    "if mol is None:\n",
    "    # MolFromMolBlock returns None (it does not raise) when the mol block cannot be parsed;\n",
    "    # stop here instead of failing later with an AttributeError on mol.GetConformer()\n",
    "    raise ValueError(f'RDKit failed to parse molblocks_and_charges[{index}][0] as a mol block')\n",
    "charges = np.array(molblocks_and_charges[index][1]) # xTB partial charges in implicit water\n",
    "display(mol)\n",
    "\n",
    "# extracting target interaction profiles (ESP and pharmacophores)\n",
    "# shift the conformer so that the mean of its atomic coordinates sits at the origin\n",
    "mol_coordinates = np.array(mol.GetConformer().GetPositions())\n",
    "mol_coordinates = mol_coordinates - np.mean(mol_coordinates, axis = 0)\n",
    "mol = update_mol_coordinates(mol, mol_coordinates, copy = True)\n",
    "\n",
    "# surface point cloud, with the number of points taken from the model's x3 dataset parameters\n",
    "centers = mol.GetConformer().GetPositions()\n",
    "radii = get_atomic_vdw_radii(mol)\n",
    "surface = get_molecular_surface(\n",
    "    centers,\n",
    "    radii,\n",
    "    params['dataset']['x3']['num_points'],\n",
    "    probe_radius = params['dataset']['probe_radius'],\n",
    "    num_samples_per_atom = 20,\n",
    ")\n",
    "\n",
    "# pharmacophores, extracted with the x4 settings stored in the model parameters\n",
    "pharm_types, pharm_pos, pharm_direction = get_pharmacophores(\n",
    "    mol,\n",
    "    multi_vector = params['dataset']['x4']['multivectors'],\n",
    "    check_access = params['dataset']['x4']['check_accessibility'],\n",
    ")\n",
    "\n",
    "# electrostatics computed from the partial charges, the atom centers and the surface points\n",
    "electrostatics = get_electrostatics_given_point_charges(\n",
    "    charges, centers, surface,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c501bdd1-c729-422c-bff6-bd3cc7fb726e",
   "metadata": {},
   "source": [
    "# Fragment merging interaction profiles"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6c54f28-cb50-42b5-90fd-43d8b91774d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('conformers/fragment_merging/fragment_merge_condition.pickle', 'rb') as f:\n",
    "    fragment_merge_features = pickle.load(f)\n",
    "\n",
    "# translate x2, x3 and x4 by the same offset: the mean of the x3 positions\n",
    "COM = fragment_merge_features['x3']['positions'].mean(0)\n",
    "for _rep in ('x2', 'x3', 'x4'):\n",
    "    fragment_merge_features[_rep]['positions'] = fragment_merge_features[_rep]['positions'] - COM\n",
    "\n",
    "# independent copies of the centered features, under the names the generation cell reads\n",
    "surface = deepcopy(fragment_merge_features['x3']['positions'])\n",
    "electrostatics = deepcopy(fragment_merge_features['x3']['charges'])\n",
    "pharm_types = deepcopy(fragment_merge_features['x4']['types'])\n",
    "pharm_pos = deepcopy(fragment_merge_features['x4']['positions'])\n",
    "pharm_direction = deepcopy(fragment_merge_features['x4']['directions'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "747d173c-186c-4065-97d4-6671eb270442",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "0ed35d0a-a5a5-46be-87a5-9fe1b7812c28",
   "metadata": {},
   "source": [
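    "# Running conditional generation with inpainting\n",
    "\n",
    "The inpainting targets (`surface`, `electrostatics`, `pharm_types`, `pharm_pos`, `pharm_direction`) come from whichever interaction-profile section above was executed last, because all three sections assign the same variable names."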
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e44682ce-9784-4a90-b065-0b41a07c250d",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_atoms = 60 # passed to inference_sample as N_x1\n",
    "batch_size = 1\n",
    "num_pharmacophores = len(pharm_types) # passed to inference_sample as N_x4\n",
    "\n",
    "# when inpainting, the number of pharmacophores must equal the number of pharmacophore positions;\n",
    "# check it here so a mismatch is reported before sampling starts\n",
    "assert num_pharmacophores == len(pharm_pos), 'num_pharmacophores must equal the number of rows in pharm_pos when inpainting'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bf1f997-99ba-4a4b-9e21-fb8ee91ccff3",
   "metadata": {},
   "outputs": [],
   "source": [
    "generated_samples = inference_sample(\n",
    "    model_pl,\n",
    "    batch_size = batch_size,\n",
    "    \n",
    "    N_x1 = n_atoms,\n",
    "    N_x4 = num_pharmacophores,\n",
    "    \n",
    "    unconditional = False,\n",
    "    \n",
    "    prior_noise_scale = 1.0,\n",
    "    denoising_noise_scale = 1.0,\n",
    "    \n",
    "    inject_noise_at_ts = [],\n",
    "    inject_noise_scales = [],    \n",
    "    \n",
    "    harmonize = False,\n",
    "    harmonize_ts = [],\n",
    "    harmonize_jumps = [],\n",
    "    \n",
    "    \n",
    "    # all the below options are only relevant if unconditional is False\n",
    "    \n",
    "    inpaint_x2_pos = False, # x2 is implicitly modeled via x3\n",
    "    \n",
    "    inpaint_x3_pos = True,\n",
    "    inpaint_x3_x = True,\n",
    "    \n",
    "    inpaint_x4_pos = True,\n",
    "    inpaint_x4_direction = True,\n",
    "    inpaint_x4_type = True,\n",
    "    \n",
    "    stop_inpainting_at_time_x2 = 0.0,\n",
    "    add_noise_to_inpainted_x2_pos = 0.0,\n",
    "    \n",
    "    stop_inpainting_at_time_x3 = 0.0,\n",
    "    add_noise_to_inpainted_x3_pos = 0.0,\n",
    "    add_noise_to_inpainted_x3_x = 0.0,\n",
    "    \n",
    "    stop_inpainting_at_time_x4 = 0.0,\n",
    "    add_noise_to_inpainted_x4_pos = 0.0,\n",
    "    add_noise_to_inpainted_x4_direction = 0.0,\n",
    "    add_noise_to_inpainted_x4_type = 0.0,\n",
    "    \n",
    "    # these are the inpainting targets\n",
    "    center_of_mass = np.zeros(3), # center of mass of x1; already centered to zero above\n",
    "    surface = surface,\n",
    "    electrostatics = electrostatics,\n",
    "    pharm_types = pharm_types,\n",
    "    pharm_pos = pharm_pos,\n",
    "    pharm_direction = pharm_direction,\n",
    "        \n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebefb39e-0ede-43c7-9e40-f8d39dd62f5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(generated_samples) # == batch_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c38796b7-ea42-4a4b-b999-b7e2cb3c2c61",
   "metadata": {},
   "outputs": [],
   "source": [
    "generated_samples[0]['x1']['atoms']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2215859-9973-4516-a391-aab0f708d4dc",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "generated_samples[0]['x1']['positions']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f030f234-c79d-4c8b-9ff0-fc73d3cff996",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:.conda-shepherd]",
   "language": "python",
   "name": "conda-env-.conda-shepherd-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
