{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7a4ae08f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import only the name this notebook uses from the local `data` module;\n",
    "# a wildcard import hides where names come from and can shadow other globals.\n",
    "from data import MoleculeDatapoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "57c9dbb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d5265d38",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Location of the BBBP CSV, relative to the notebook's working directory.\n",
    "bbbp_dir = '../../data/bbbp.csv'\n",
    "\n",
    "# Parse the file into a DataFrame.\n",
    "df_bbbp = pd.read_csv(filepath_or_buffer=bbbp_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8fefeaea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Take the 'smiles' entry at index label 0 as the example molecule.\n",
    "sample_smile = df_bbbp.loc[0, 'smiles']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "afc9fb92",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'[Cl].CC(C)NCC(O)COc1cccc2ccccc12'"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sample_smile"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "43a7d2d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the SMILES in a list: chemprop's MoleculeDatapoint takes one SMILES string per\n",
    "# molecule of the datapoint, so a bare string would be iterated character by character.\n",
    "# NOTE(review): assumes the local `data.MoleculeDatapoint` mirrors chemprop's signature - confirm.\n",
    "sample_mol_datapoint = MoleculeDatapoint([sample_smile])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "41648b78",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<data.MoleculeDatapoint at 0x7f235f660790>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sample_mol_datapoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "727f66c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_mol_datapoint.bond_features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "613868c1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['__class__',\n",
       " '__delattr__',\n",
       " '__dict__',\n",
       " '__dir__',\n",
       " '__doc__',\n",
       " '__eq__',\n",
       " '__format__',\n",
       " '__ge__',\n",
       " '__getattribute__',\n",
       " '__gt__',\n",
       " '__hash__',\n",
       " '__init__',\n",
       " '__init_subclass__',\n",
       " '__le__',\n",
       " '__lt__',\n",
       " '__module__',\n",
       " '__ne__',\n",
       " '__new__',\n",
       " '__reduce__',\n",
       " '__reduce_ex__',\n",
       " '__repr__',\n",
       " '__setattr__',\n",
       " '__sizeof__',\n",
       " '__str__',\n",
       " '__subclasshook__',\n",
       " '__weakref__',\n",
       " 'atom_descriptors',\n",
       " 'atom_features',\n",
       " 'atom_targets',\n",
       " 'bond_descriptors',\n",
       " 'bond_features',\n",
       " 'bond_targets',\n",
       " 'bond_types',\n",
       " 'constraints',\n",
       " 'extend_features',\n",
       " 'features',\n",
       " 'features_generator',\n",
       " 'is_adding_hs_list',\n",
       " 'is_explicit_h_list',\n",
       " 'is_keeping_atom_map_list',\n",
       " 'is_mol_list',\n",
       " 'is_reaction_list',\n",
       " 'max_molwt',\n",
       " 'mol',\n",
       " 'num_tasks',\n",
       " 'number_of_atoms',\n",
       " 'number_of_bonds',\n",
       " 'number_of_molecules',\n",
       " 'overwrite_default_atom_features',\n",
       " 'overwrite_default_bond_features',\n",
       " 'phase_features',\n",
       " 'raw_atom_descriptors',\n",
       " 'raw_atom_features',\n",
       " 'raw_atom_targets',\n",
       " 'raw_bond_descriptors',\n",
       " 'raw_bond_features',\n",
       " 'raw_bond_targets',\n",
       " 'raw_constraints',\n",
       " 'raw_features',\n",
       " 'raw_targets',\n",
       " 'reset_features_and_targets',\n",
       " 'row',\n",
       " 'set_atom_descriptors',\n",
       " 'set_atom_features',\n",
       " 'set_bond_descriptors',\n",
       " 'set_bond_features',\n",
       " 'set_features',\n",
       " 'set_targets',\n",
       " 'smiles',\n",
       " 'targets']"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dir(sample_mol_datapoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a8f309d2",
   "metadata": {},
   "source": [
    "# Create a sample MPNN (chemprop `MoleculeModel`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "5293f11c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zhengyjo/anaconda3/lib/python3.7/site-packages/torch/nn/init.py:405: UserWarning: Initializing zero-element tensors is a no-op\n",
      "  warnings.warn(\"Initializing zero-element tensors is a no-op\")\n"
     ]
    }
   ],
   "source": [
    "import chemprop\n",
    "from chemprop.args import TrainArgs\n",
    "\n",
    "# Command-line style options, parsed below into a TrainArgs object.\n",
    "arguments = [\n",
    "    '--data_path', 'data/regression.csv',\n",
    "    '--dataset_type', 'regression',\n",
    "    '--save_dir', 'test_checkpoints_reg',\n",
    "    '--epochs', '5',\n",
    "    '--save_smiles_splits'\n",
    "]\n",
    "args = TrainArgs().parse_args(arguments)\n",
    "\n",
    "# Construct the model from the parsed arguments (no training happens in this cell).\n",
    "mpnn = chemprop.models.MoleculeModel(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7aedbcd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The model's forward pass indexes into its input (`batch[0]`), so it needs a batch:\n",
    "# a list with one entry per datapoint, each entry being the list of that datapoint's\n",
    "# molecules (SMILES strings here). Passing a bare MoleculeDatapoint raised\n",
    "# TypeError: 'MoleculeDatapoint' object is not subscriptable.\n",
    "mpnn([[sample_smile]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0434ba75",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
