{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db119898",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pydgn\n",
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import os\n",
    "import json\n",
    "import os.path as osp\n",
    "from pydgn.data.dataset import OGBGDatasetInterface\n",
    "from dataset import OGBGmolpcbaFeatureMap\n",
    "\n",
    "# Colour-blind-friendly default palette for every subsequent seaborn/matplotlib plot.\n",
    "sns.set_palette(\"colorblind\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e41e7921",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Which trained model to inspect: outer fold, model-selection config and inner fold.\n",
    "data_root = 'DATA/'\n",
    "dataset_name = 'ogbg-molpcba'\n",
    "exp_folder = f'GSPN_RESULTS/UNSUPERVISED/unsupervised_embedding_generation_categorical_{dataset_name}/MODEL_ASSESSMENT/'\n",
    "\n",
    "outer_fold = 1\n",
    "outer_folder = osp.join(exp_folder, f'OUTER_FOLD_{outer_fold}')\n",
    "ms_folder = osp.join(outer_folder, 'MODEL_SELECTION')\n",
    "\n",
    "config_id = 15  # best config for the unsupervised part according to regression task\n",
    "config_folder = osp.join(ms_folder, f'config_{config_id}')\n",
    "\n",
    "# Hyper-parameters of the selected configuration; `with` guarantees the file handle is closed.\n",
    "model_config_file = osp.join(config_folder, 'config_results.json')\n",
    "with open(model_config_file, 'r') as config_fp:\n",
    "    config = json.load(config_fp)['config']\n",
    "\n",
    "# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.\n",
    "best_ckpt = torch.load(osp.join(config_folder, 'INNER_FOLD_1', 'best_checkpoint.pth'), map_location='cpu')['model_state']\n",
    "\n",
    "dataset = OGBGDatasetInterface(data_root, dataset_name)\n",
    "print(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e99b66a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pydgn.experiment.experiment import Experiment\n",
    "\n",
    "# Rebuild the unsupervised model from the stored config, then restore the best weights.\n",
    "exp = Experiment(config, config_folder, exp_seed=0)\n",
    "model = exp.create_unsupervised_model(\n",
    "    dataset.dim_node_features,\n",
    "    dataset.dim_edge_features,\n",
    "    dataset.dim_target,\n",
    ")\n",
    "model.load_state_dict(best_ckpt)\n",
    "# Inference only: keep the model on CPU and switch it to evaluation mode.\n",
    "model.to('cpu')\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "789acf43",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_features = dataset.data.x.shape[1]\n",
    "# Sorted (ascending) vocabulary of the raw categorical values seen in the dataset, one tensor per feature column.\n",
    "unique_values = [torch.sort(torch.unique(dataset.data.x[:,f]), descending=False)[0] for f in range(num_features)] \n",
    "print(unique_values)\n",
    "\n",
    "def preprocess_node_features(g):\n",
    "    \"\"\"Remap each categorical column of ``g.x`` in place to contiguous indices.\n",
    "\n",
    "    In column f, the k-th smallest value of ``unique_values[f]`` becomes k, so graphs built\n",
    "    with ``smiles2graph`` are re-indexed with the vocabulary computed on ``dataset.data.x``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``g.x`` holds a value absent from the dataset vocabulary; it has no\n",
    "            index, and leaving it unmapped would silently feed the model a wrong category.\n",
    "    \"\"\"\n",
    "    for f in range(num_features):\n",
    "        vocab = unique_values[f].to(g.x.dtype)\n",
    "        column = g.x[:, f].contiguous()\n",
    "        # Binary search gives each value's position in the sorted vocabulary; clamp so that values\n",
    "        # larger than every vocabulary entry can still be compared (and rejected) on the next line.\n",
    "        positions = torch.searchsorted(vocab, column).clamp(max=len(vocab) - 1)\n",
    "        unseen = vocab[positions] != column\n",
    "        if unseen.any():\n",
    "            raise ValueError(f'node feature {f}: values {column[unseen].unique().tolist()} never occur in the dataset')\n",
    "        g.x[:, f] = positions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82b3476b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the raw node-feature matrix of the whole dataset (the source of unique_values).\n",
    "dataset.data.x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f56221ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "from rdkit import Chem\n",
    "from rdkit.Chem import Draw\n",
    "from rdkit.Chem.Draw import rdMolDraw2D\n",
    "import cairosvg\n",
    "import io"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f5e27d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from ogb.utils import smiles2graph\n",
    "# SMILES table shipped with the OGB dataset (presumably row i is dataset graph i; verify).\n",
    "# OGB stores it in a folder named after the dataset with dashes turned into underscores.\n",
    "mol_df = pd.read_csv(osp.join(data_root, dataset_name.replace('-', '_'), 'mapping', 'mol.csv'))\n",
    "mol_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7cd6a8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference molecule (row 1): show its SMILES, its raw OGB node features and the RDKit drawing.\n",
    "smile = mol_df.loc[1, 'smiles']\n",
    "mol = Chem.MolFromSmiles(smile)\n",
    "print(smile)\n",
    "graph_original = smiles2graph(smile)\n",
    "print(graph_original['node_feat'])\n",
    "mol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2291d5b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.data import Data\n",
    "\n",
    "# Wrap the OGB graph dict as a PyG Data object and remap its categorical node features.\n",
    "graph_original_data = Data(\n",
    "    x=torch.tensor(graph_original['node_feat']),\n",
    "    edge_index=torch.tensor(graph_original['edge_index']),\n",
    "    edge_attr=torch.tensor(graph_original['edge_feat']),\n",
    ")\n",
    "preprocess_node_features(graph_original_data)\n",
    "\n",
    "with torch.no_grad():\n",
    "    preds_g, node_posterior, [objective_v, objective_g, _, _, _, _, _, mixture_weights, avg_params_across_layers] = model(graph_original_data)\n",
    "\n",
    "print(objective_v, objective_v.mean())\n",
    "# Heatmap of node_posterior for all nodes, then for node 10 alone (as a one-row heatmap).\n",
    "sns.heatmap(node_posterior)\n",
    "plt.figure()\n",
    "sns.heatmap(node_posterior[10].unsqueeze(0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5fa80740",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand-written variant of the reference molecule, to be scored by the same model.\n",
    "smile = 'N#Cc1nnn(-c2ccc(O)cc2)c1O'\n",
    "mol = Chem.MolFromSmiles(smile)\n",
    "print(smile)\n",
    "graph_modified = smiles2graph(smile)\n",
    "print(graph_modified['node_feat'])\n",
    "mol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96e5e53a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.data import Data\n",
    "\n",
    "# Same pipeline as for the reference molecule, applied to the hand-written variant.\n",
    "graph_modified_data = Data(\n",
    "    x=torch.tensor(graph_modified['node_feat']),\n",
    "    edge_index=torch.tensor(graph_modified['edge_index']),\n",
    "    edge_attr=torch.tensor(graph_modified['edge_feat']),\n",
    ")\n",
    "preprocess_node_features(graph_modified_data)\n",
    "\n",
    "with torch.no_grad():\n",
    "    preds_g, node_posterior, [objective_v_1, objective_g, _, _, _, _, _, mixture_weights, avg_params_across_layers] = model(graph_modified_data)\n",
    "\n",
    "print(objective_v_1, objective_v_1.mean())\n",
    "# Heatmap of node_posterior for all nodes, then for node 10 alone (as a one-row heatmap).\n",
    "sns.heatmap(node_posterior)\n",
    "plt.figure()\n",
    "sns.heatmap(node_posterior[10].unsqueeze(0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40467ed0",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(smile)\n",
    "# Per-node change in log-likelihood between the variant and the reference molecule (row 1 of mol_df).\n",
    "delta_loglik = objective_v_1 - objective_v\n",
    "# Tick labels: element symbols of the reference molecule; smiles2graph emits one node per RDKit atom, in atom order.\n",
    "labels = [atom.GetSymbol() for atom in Chem.MolFromSmiles(mol_df['smiles'][1]).GetAtoms()]\n",
    "assert len(labels) == delta_loglik.shape[-1], 'expected one tick label per node'\n",
    "\n",
    "os.makedirs('plots', exist_ok=True)\n",
    "plt.rcParams.update({'font.size': 22})  # set before creating the figure so every text element uses it\n",
    "plt.figure(figsize=(10, 5))\n",
    "sns.heatmap(delta_loglik.unsqueeze(0), cmap='rocket_r')  # substituting leads to a change in likelihood\n",
    "plt.xticks(ticks=np.arange(len(labels)) + 0.5, labels=labels)\n",
    "plt.ylabel(r'$\\Delta \\log \\mathcal{L}$ after change')\n",
    "plt.yticks([])\n",
    "plt.xlabel(f'SMILES: {smile}')\n",
    "plt.tight_layout()\n",
    "plt.savefig('plots/delta_log_likelihood_1.pdf', bbox_inches='tight')\n",
    "print(delta_loglik)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3a8972f",
   "metadata": {},
   "source": [
    "### Two different situations: on the right of the image, the carbon connected to the oxygen becomes more likely than when it was connected to the Cl,\n",
    "### whereas in the middle, replacing the nitrogen with a carbon seems much less likely to happen."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26dbd151",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip install svglib  # provides svg2rlg and pulls in reportlab (renderPDF); django-renderpdf is not needed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15cfa374",
   "metadata": {},
   "outputs": [],
   "source": [
    "from rdkit import Chem\n",
    "from rdkit.Chem import Draw\n",
    "from rdkit.Chem.Draw import IPythonConsole\n",
    "from rdkit.Chem import rdFMCS\n",
    "from rdkit.Chem.Draw import rdDepictor\n",
    "from svglib.svglib import svg2rlg\n",
    "from reportlab.graphics import renderPDF\n",
    "\n",
    "\n",
    "# Re-render the reference molecule (row 1) to SVG, then convert that SVG to PDF with svglib/reportlab.\n",
    "smile = mol_df.loc[1, 'smiles']\n",
    "mol = Chem.MolFromSmiles(smile)\n",
    "graph_original = smiles2graph(smile)\n",
    "\n",
    "Draw.MolToFile(\n",
    "    mol,\n",
    "    'plots/test.svg',\n",
    "    size=(600, 600),\n",
    "    kekulize=True,\n",
    "    wedgeBonds=True,\n",
    "    fitImage=False,\n",
    "    options=None,\n",
    "    canvas=None,\n",
    ")\n",
    "drawing = svg2rlg('plots/test.svg')\n",
    "renderPDF.drawToFile(drawing, 'plots/test.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b1d040f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Atom with index 10 of the reference molecule, i.e. the node whose posterior was plotted on its own earlier.\n",
    "mol.GetAtomWithIdx(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bb745fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled check (uncomment to run; needs a CUDA device): remap every dataset graph with\n",
    "# preprocess_node_features and print the mean of the first objective output for each graph.\n",
    "# dataset_preprocessed = OGBGDatasetInterface(data_root, dataset_name)\n",
    "\n",
    "# model.to('cuda:0')\n",
    "# for sample in dataset_preprocessed:\n",
    "#     preprocess_node_features(sample)\n",
    "#     sample.to('cuda:0')\n",
    "#     with torch.no_grad():\n",
    "#         _, _, [loglik, _, _, _, _, _, _, _, _] = model(sample)\n",
    "    \n",
    "#     print(loglik.mean())\n",
    "#     sample.to('cpu')\n",
    "# model.to('cpu')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a83922b",
   "metadata": {},
   "source": [
    "### Extra Molecules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0fe5fb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "id_mol = 1\n",
    "\n",
    "# Original molecule: build its OGB graph, remap the node features and score it with the model.\n",
    "smile = mol_df.loc[id_mol, 'smiles']\n",
    "mol = Chem.MolFromSmiles(smile)\n",
    "print(smile)\n",
    "graph_original = smiles2graph(smile)\n",
    "graph_original_data = Data(\n",
    "    x=torch.tensor(graph_original['node_feat']),\n",
    "    edge_index=torch.tensor(graph_original['edge_index']),\n",
    "    edge_attr=torch.tensor(graph_original['edge_feat']),\n",
    ")\n",
    "preprocess_node_features(graph_original_data)\n",
    "\n",
    "with torch.no_grad():\n",
    "    preds_g, node_posterior, [objective_v, objective_g, _, _, _, _, _, mixture_weights, avg_params_across_layers] = model(graph_original_data)\n",
    "\n",
    "mol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8cabcbd",
   "metadata": {},
   "outputs": [],
   "source": [
    "new_smile = smile.replace('Cl', 'O')\n",
    "new_mol = Chem.MolFromSmiles(new_smile)\n",
    "print(new_smile)\n",
    "# MolFromSmiles returns None for a SMILES RDKit cannot parse/sanitize; fail here with a clear message.\n",
    "if new_mol is None:\n",
    "    raise ValueError(f'substitution produced an invalid SMILES: {new_smile}')\n",
    "graph_modified = smiles2graph(new_smile)\n",
    "\n",
    "graph_modified_data = Data(\n",
    "    x=torch.tensor(graph_modified['node_feat']),\n",
    "    edge_index=torch.tensor(graph_modified['edge_index']),\n",
    "    edge_attr=torch.tensor(graph_modified['edge_feat']),\n",
    ")\n",
    "preprocess_node_features(graph_modified_data)\n",
    "\n",
    "with torch.no_grad():\n",
    "    preds_g, node_posterior, [objective_v_1, objective_g, _, _, _, _, _, mixture_weights, avg_params_across_layers] = model(graph_modified_data)\n",
    "\n",
    "# Per-node change in log-likelihood caused by the substitution.\n",
    "delta_loglik = objective_v_1 - objective_v\n",
    "# Tick labels: element symbols of the original molecule; smiles2graph emits one node per RDKit atom, in atom order.\n",
    "labels = [atom.GetSymbol() for atom in mol.GetAtoms()]\n",
    "assert len(labels) == delta_loglik.shape[-1], 'expected one tick label per node'\n",
    "\n",
    "os.makedirs('plots', exist_ok=True)\n",
    "plt.rcParams.update({'font.size': 22})  # set before creating the figure so every text element uses it\n",
    "plt.figure(figsize=(10, 5))\n",
    "sns.heatmap(delta_loglik.unsqueeze(0), cmap='rocket_r')  # substituting leads to a change in likelihood\n",
    "plt.xticks(ticks=np.arange(len(labels)) + 0.5, labels=labels)\n",
    "plt.ylabel(r'$\\Delta \\log \\mathcal{L}$ after change')\n",
    "plt.yticks([])\n",
    "plt.xlabel(f'SMILES: {smile}')\n",
    "plt.tight_layout()\n",
    "plt.savefig(f'plots/delta_log_likelihood_{id_mol}.pdf', bbox_inches='tight')\n",
    "print(delta_loglik)\n",
    "\n",
    "# Save the original and the modified molecule as SVG, then convert each SVG to PDF.\n",
    "for drawn_mol, stem in [(mol, f'plots/mol_{id_mol}'), (new_mol, f'plots/mol_{id_mol}_modified')]:\n",
    "    Draw.MolToFile(drawn_mol, f'{stem}.svg', size=(600, 600),\n",
    "                   kekulize=True,\n",
    "                   wedgeBonds=True,\n",
    "                   fitImage=False,\n",
    "                   options=None,\n",
    "                   canvas=None)\n",
    "    drawing = svg2rlg(f'{stem}.svg')\n",
    "    renderPDF.drawToFile(drawing, f'{stem}.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee911468",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seeded draw of a random molecule; the following cells keep drawing from this same RNG stream.\n",
    "torch.manual_seed(42)\n",
    "id_mol = torch.randint(0, len(dataset), (1,)).item()\n",
    "print(id_mol)\n",
    "smile = mol_df.loc[id_mol, 'smiles']\n",
    "mol = Chem.MolFromSmiles(smile)\n",
    "print(smile)\n",
    "graph_original = smiles2graph(smile)\n",
    "graph_original_data = Data(\n",
    "    x=torch.tensor(graph_original['node_feat']),\n",
    "    edge_index=torch.tensor(graph_original['edge_index']),\n",
    "    edge_attr=torch.tensor(graph_original['edge_feat']),\n",
    ")\n",
    "preprocess_node_features(graph_original_data)\n",
    "\n",
    "with torch.no_grad():\n",
    "    preds_g, node_posterior, [objective_v, objective_g, _, _, _, _, _, mixture_weights, avg_params_across_layers] = model(graph_original_data)\n",
    "\n",
    "mol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88916cda",
   "metadata": {},
   "outputs": [],
   "source": [
    "new_smile = smile.replace('Cl', 'O')\n",
    "new_mol = Chem.MolFromSmiles(new_smile)\n",
    "print(new_smile)\n",
    "# MolFromSmiles returns None for a SMILES RDKit cannot parse/sanitize; fail here with a clear message.\n",
    "if new_mol is None:\n",
    "    raise ValueError(f'substitution produced an invalid SMILES: {new_smile}')\n",
    "graph_modified = smiles2graph(new_smile)\n",
    "\n",
    "graph_modified_data = Data(\n",
    "    x=torch.tensor(graph_modified['node_feat']),\n",
    "    edge_index=torch.tensor(graph_modified['edge_index']),\n",
    "    edge_attr=torch.tensor(graph_modified['edge_feat']),\n",
    ")\n",
    "preprocess_node_features(graph_modified_data)\n",
    "\n",
    "with torch.no_grad():\n",
    "    preds_g, node_posterior, [objective_v_1, objective_g, _, _, _, _, _, mixture_weights, avg_params_across_layers] = model(graph_modified_data)\n",
    "\n",
    "# Per-node change in log-likelihood caused by the substitution.\n",
    "delta_loglik = objective_v_1 - objective_v\n",
    "# Tick labels: element symbols of the original molecule; smiles2graph emits one node per RDKit atom, in atom order.\n",
    "labels = [atom.GetSymbol() for atom in mol.GetAtoms()]\n",
    "assert len(labels) == delta_loglik.shape[-1], 'expected one tick label per node'\n",
    "\n",
    "os.makedirs('plots', exist_ok=True)\n",
    "plt.rcParams.update({'font.size': 22})  # set before creating the figure so every text element uses it\n",
    "plt.figure(figsize=(10, 5))\n",
    "sns.heatmap(delta_loglik.unsqueeze(0), cmap='rocket_r')  # substituting leads to a change in likelihood\n",
    "plt.xticks(ticks=np.arange(len(labels)) + 0.5, labels=labels)\n",
    "plt.ylabel(r'$\\Delta \\log \\mathcal{L}$ after change')\n",
    "plt.yticks([])\n",
    "plt.xlabel(f'SMILES: {smile}')\n",
    "plt.tight_layout()\n",
    "plt.savefig(f'plots/delta_log_likelihood_{id_mol}.pdf', bbox_inches='tight')\n",
    "print(delta_loglik)\n",
    "\n",
    "# Save the original and the modified molecule as SVG, then convert each SVG to PDF.\n",
    "for drawn_mol, stem in [(mol, f'plots/mol_{id_mol}'), (new_mol, f'plots/mol_{id_mol}_modified')]:\n",
    "    Draw.MolToFile(drawn_mol, f'{stem}.svg', size=(600, 600),\n",
    "                   kekulize=True,\n",
    "                   wedgeBonds=True,\n",
    "                   fitImage=False,\n",
    "                   options=None,\n",
    "                   canvas=None)\n",
    "    drawing = svg2rlg(f'{stem}.svg')\n",
    "    renderPDF.drawToFile(drawing, f'{stem}.pdf')\n",
    "\n",
    "new_mol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8644a8cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Next random molecule from the RNG stream seeded earlier (depends on the previous draws having run).\n",
    "id_mol = torch.randint(0, len(dataset), (1,)).item()\n",
    "print(id_mol)\n",
    "smile = mol_df.loc[id_mol, 'smiles']\n",
    "mol = Chem.MolFromSmiles(smile)\n",
    "print(smile)\n",
    "graph_original = smiles2graph(smile)\n",
    "graph_original_data = Data(\n",
    "    x=torch.tensor(graph_original['node_feat']),\n",
    "    edge_index=torch.tensor(graph_original['edge_index']),\n",
    "    edge_attr=torch.tensor(graph_original['edge_feat']),\n",
    ")\n",
    "preprocess_node_features(graph_original_data)\n",
    "\n",
    "with torch.no_grad():\n",
    "    preds_g, node_posterior, [objective_v, objective_g, _, _, _, _, _, mixture_weights, avg_params_across_layers] = model(graph_original_data)\n",
    "\n",
    "mol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80a9477f",
   "metadata": {},
   "outputs": [],
   "source": [
    "new_smile = smile.replace('O', 'N')\n",
    "new_smile = new_smile.replace('o', 'N')\n",
    "new_mol = Chem.MolFromSmiles(new_smile)\n",
    "print(new_smile)\n",
    "# MolFromSmiles returns None for a SMILES RDKit cannot parse/sanitize; fail here with a clear message.\n",
    "if new_mol is None:\n",
    "    raise ValueError(f'substitution produced an invalid SMILES: {new_smile}')\n",
    "graph_modified = smiles2graph(new_smile)\n",
    "\n",
    "graph_modified_data = Data(\n",
    "    x=torch.tensor(graph_modified['node_feat']),\n",
    "    edge_index=torch.tensor(graph_modified['edge_index']),\n",
    "    edge_attr=torch.tensor(graph_modified['edge_feat']),\n",
    ")\n",
    "preprocess_node_features(graph_modified_data)\n",
    "\n",
    "with torch.no_grad():\n",
    "    preds_g, node_posterior, [objective_v_1, objective_g, _, _, _, _, _, mixture_weights, avg_params_across_layers] = model(graph_modified_data)\n",
    "\n",
    "# Per-node change in log-likelihood caused by the substitution.\n",
    "delta_loglik = objective_v_1 - objective_v\n",
    "# Tick labels: element symbols of the original molecule; smiles2graph emits one node per RDKit atom, in atom order.\n",
    "labels = [atom.GetSymbol() for atom in mol.GetAtoms()]\n",
    "assert len(labels) == delta_loglik.shape[-1], 'expected one tick label per node'\n",
    "\n",
    "os.makedirs('plots', exist_ok=True)\n",
    "plt.rcParams.update({'font.size': 22})  # set before creating the figure so every text element uses it\n",
    "plt.figure(figsize=(10, 5))\n",
    "sns.heatmap(delta_loglik.unsqueeze(0), cmap='rocket_r')  # substituting leads to a change in likelihood\n",
    "plt.xticks(ticks=np.arange(len(labels)) + 0.5, labels=labels)\n",
    "plt.ylabel(r'$\\Delta \\log \\mathcal{L}$ after change')\n",
    "plt.yticks([])\n",
    "plt.xlabel(f'SMILES: {smile}')\n",
    "plt.tight_layout()\n",
    "plt.savefig(f'plots/delta_log_likelihood_{id_mol}.pdf', bbox_inches='tight')\n",
    "print(delta_loglik)\n",
    "\n",
    "# Save the original and the modified molecule as SVG, then convert each SVG to PDF.\n",
    "for drawn_mol, stem in [(mol, f'plots/mol_{id_mol}'), (new_mol, f'plots/mol_{id_mol}_modified')]:\n",
    "    Draw.MolToFile(drawn_mol, f'{stem}.svg', size=(600, 600),\n",
    "                   kekulize=True,\n",
    "                   wedgeBonds=True,\n",
    "                   fitImage=False,\n",
    "                   options=None,\n",
    "                   canvas=None)\n",
    "    drawing = svg2rlg(f'{stem}.svg')\n",
    "    renderPDF.drawToFile(drawing, f'{stem}.pdf')\n",
    "\n",
    "new_mol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8705fba5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Next random molecule from the RNG stream seeded earlier (depends on the previous draws having run).\n",
    "id_mol = torch.randint(0, len(dataset), (1,)).item()\n",
    "print(id_mol)\n",
    "smile = mol_df.loc[id_mol, 'smiles']\n",
    "mol = Chem.MolFromSmiles(smile)\n",
    "print(smile)\n",
    "graph_original = smiles2graph(smile)\n",
    "graph_original_data = Data(\n",
    "    x=torch.tensor(graph_original['node_feat']),\n",
    "    edge_index=torch.tensor(graph_original['edge_index']),\n",
    "    edge_attr=torch.tensor(graph_original['edge_feat']),\n",
    ")\n",
    "preprocess_node_features(graph_original_data)\n",
    "\n",
    "with torch.no_grad():\n",
    "    preds_g, node_posterior, [objective_v, objective_g, _, _, _, _, _, mixture_weights, avg_params_across_layers] = model(graph_original_data)\n",
    "\n",
    "mol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41b40108",
   "metadata": {},
   "outputs": [],
   "source": [
    "new_smile = smile.replace('s', 'N')\n",
    "new_mol = Chem.MolFromSmiles(new_smile)\n",
    "print(new_smile)\n",
    "# MolFromSmiles returns None for a SMILES RDKit cannot parse/sanitize; fail here with a clear message.\n",
    "if new_mol is None:\n",
    "    raise ValueError(f'substitution produced an invalid SMILES: {new_smile}')\n",
    "graph_modified = smiles2graph(new_smile)\n",
    "\n",
    "graph_modified_data = Data(\n",
    "    x=torch.tensor(graph_modified['node_feat']),\n",
    "    edge_index=torch.tensor(graph_modified['edge_index']),\n",
    "    edge_attr=torch.tensor(graph_modified['edge_feat']),\n",
    ")\n",
    "preprocess_node_features(graph_modified_data)\n",
    "\n",
    "with torch.no_grad():\n",
    "    preds_g, node_posterior, [objective_v_1, objective_g, _, _, _, _, _, mixture_weights, avg_params_across_layers] = model(graph_modified_data)\n",
    "\n",
    "# Per-node change in log-likelihood caused by the substitution.\n",
    "delta_loglik = objective_v_1 - objective_v\n",
    "# Tick labels: element symbols of the original molecule; smiles2graph emits one node per RDKit atom, in atom order.\n",
    "labels = [atom.GetSymbol() for atom in mol.GetAtoms()]\n",
    "assert len(labels) == delta_loglik.shape[-1], 'expected one tick label per node'\n",
    "\n",
    "os.makedirs('plots', exist_ok=True)\n",
    "plt.rcParams.update({'font.size': 22})  # set before creating the figure so every text element uses it\n",
    "plt.figure(figsize=(10, 5))\n",
    "sns.heatmap(delta_loglik.unsqueeze(0), cmap='rocket_r')  # substituting leads to a change in likelihood\n",
    "plt.xticks(ticks=np.arange(len(labels)) + 0.5, labels=labels, rotation=0)\n",
    "plt.ylabel(r'$\\Delta \\log \\mathcal{L}$ after change')\n",
    "plt.yticks([])\n",
    "plt.xlabel(f'SMILES: {smile}')\n",
    "plt.tight_layout()\n",
    "plt.savefig(f'plots/delta_log_likelihood_{id_mol}.pdf', bbox_inches='tight')\n",
    "print(delta_loglik)\n",
    "\n",
    "# Save the original and the modified molecule as SVG, then convert each SVG to PDF.\n",
    "for drawn_mol, stem in [(mol, f'plots/mol_{id_mol}'), (new_mol, f'plots/mol_{id_mol}_modified')]:\n",
    "    Draw.MolToFile(drawn_mol, f'{stem}.svg', size=(600, 600),\n",
    "                   kekulize=True,\n",
    "                   wedgeBonds=True,\n",
    "                   fitImage=False,\n",
    "                   options=None,\n",
    "                   canvas=None)\n",
    "    drawing = svg2rlg(f'{stem}.svg')\n",
    "    renderPDF.drawToFile(drawing, f'{stem}.pdf')\n",
    "\n",
    "new_mol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8920341b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Next random molecule from the RNG stream seeded earlier (depends on the previous draws having run).\n",
    "id_mol = torch.randint(0, len(dataset), (1,)).item()\n",
    "print(id_mol)\n",
    "smile = mol_df.loc[id_mol, 'smiles']\n",
    "mol = Chem.MolFromSmiles(smile)\n",
    "print(smile)\n",
    "graph_original = smiles2graph(smile)\n",
    "graph_original_data = Data(\n",
    "    x=torch.tensor(graph_original['node_feat']),\n",
    "    edge_index=torch.tensor(graph_original['edge_index']),\n",
    "    edge_attr=torch.tensor(graph_original['edge_feat']),\n",
    ")\n",
    "preprocess_node_features(graph_original_data)\n",
    "\n",
    "with torch.no_grad():\n",
    "    preds_g, node_posterior, [objective_v, objective_g, _, _, _, _, _, mixture_weights, avg_params_across_layers] = model(graph_original_data)\n",
    "\n",
    "mol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08f9ec68",
   "metadata": {},
   "outputs": [],
   "source": [
    "new_smile = smile.replace('O', 'N')\n",
    "new_mol = Chem.MolFromSmiles(new_smile)\n",
    "print(new_smile)\n",
    "# MolFromSmiles returns None for a SMILES RDKit cannot parse/sanitize; fail here with a clear message.\n",
    "if new_mol is None:\n",
    "    raise ValueError(f'substitution produced an invalid SMILES: {new_smile}')\n",
    "graph_modified = smiles2graph(new_smile)\n",
    "\n",
    "graph_modified_data = Data(\n",
    "    x=torch.tensor(graph_modified['node_feat']),\n",
    "    edge_index=torch.tensor(graph_modified['edge_index']),\n",
    "    edge_attr=torch.tensor(graph_modified['edge_feat']),\n",
    ")\n",
    "preprocess_node_features(graph_modified_data)\n",
    "\n",
    "with torch.no_grad():\n",
    "    preds_g, node_posterior, [objective_v_1, objective_g, _, _, _, _, _, mixture_weights, avg_params_across_layers] = model(graph_modified_data)\n",
    "\n",
    "# Per-node change in log-likelihood caused by the substitution.\n",
    "delta_loglik = objective_v_1 - objective_v\n",
    "# Tick labels: element symbols of the original molecule; smiles2graph emits one node per RDKit atom, in atom order.\n",
    "labels = [atom.GetSymbol() for atom in mol.GetAtoms()]\n",
    "assert len(labels) == delta_loglik.shape[-1], 'expected one tick label per node'\n",
    "\n",
    "os.makedirs('plots', exist_ok=True)\n",
    "plt.rcParams.update({'font.size': 22})  # set before creating the figure so every text element uses it\n",
    "plt.figure(figsize=(10, 5))\n",
    "sns.heatmap(delta_loglik.unsqueeze(0), cmap='rocket_r')  # substituting leads to a change in likelihood\n",
    "plt.xticks(ticks=np.arange(len(labels)) + 0.5, labels=labels)\n",
    "plt.ylabel(r'$\\Delta \\log \\mathcal{L}$ after change')\n",
    "plt.yticks([])\n",
    "plt.xlabel(f'SMILES: {smile}')\n",
    "plt.tight_layout()\n",
    "plt.savefig(f'plots/delta_log_likelihood_{id_mol}.pdf', bbox_inches='tight')\n",
    "print(delta_loglik)\n",
    "\n",
    "# Save the original and the modified molecule as SVG, then convert each SVG to PDF.\n",
    "for drawn_mol, stem in [(mol, f'plots/mol_{id_mol}'), (new_mol, f'plots/mol_{id_mol}_modified')]:\n",
    "    Draw.MolToFile(drawn_mol, f'{stem}.svg', size=(600, 600),\n",
    "                   kekulize=True,\n",
    "                   wedgeBonds=True,\n",
    "                   fitImage=False,\n",
    "                   options=None,\n",
    "                   canvas=None)\n",
    "    drawing = svg2rlg(f'{stem}.svg')\n",
    "    renderPDF.drawToFile(drawing, f'{stem}.pdf')\n",
    "\n",
    "new_mol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63f86b7e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1ce0af8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78615c19",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
