{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2b576506",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from logging import Logger\n",
    "import os\n",
    "from typing import Dict, List\n",
    "\n",
    "import numpy as np\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\", category=np.VisibleDeprecationWarning) \n",
    "import pandas as pd\n",
    "from tensorboardX import SummaryWriter\n",
    "import torch\n",
    "from tqdm import trange\n",
    "from torch.optim.lr_scheduler import ExponentialLR\n",
    "\n",
    "# from .evaluate import evaluate, evaluate_predictions\n",
    "# from .predict import predict\n",
    "# from .train import train\n",
    "# from .loss_functions import get_loss_func\n",
    "from chemprop.spectra_utils import normalize_spectra, load_phase_mask\n",
    "from chemprop.args import TrainArgs\n",
    "from chemprop.constants import MODEL_FILE_NAME\n",
    "from chemprop.data import get_class_sizes, get_data, MoleculeDataLoader, MoleculeDataset, set_cache_graph, split_data\n",
    "from chemprop.models import MoleculeModel\n",
    "from chemprop.nn_utils import param_count, param_count_all\n",
    "from chemprop.utils import build_optimizer, build_lr_scheduler, load_checkpoint, makedirs, \\\n",
    "    save_checkpoint, save_smiles_splits, load_frzn_model, multitask_mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "6dc7ed73",
   "metadata": {},
   "outputs": [],
   "source": [
    "from chemprop.models import MoleculeModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "aedf8285",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: despite the `_dir` suffix, this holds the relative path of a single\n",
    "# CSV file, not a directory. It is passed to `get_data` in a later cell.\n",
    "bbbp_dir = '../../data/bbbp.csv'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "19b94ba9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2039it [00:00, 244943.03it/s]\n",
      "100%|███████████████████████████████████| 2039/2039 [00:00<00:00, 179197.19it/s]\n",
      "  0%|                                                  | 0/2039 [00:00<?, ?it/s][15:18:44] WARNING: not removing hydrogen atom without neighbors\n",
      "[15:18:44] WARNING: not removing hydrogen atom without neighbors\n",
      "[15:18:44] WARNING: not removing hydrogen atom without neighbors\n",
      "[15:18:44] WARNING: not removing hydrogen atom without neighbors\n",
      " 14%|█████▏                                | 281/2039 [00:00<00:00, 2807.75it/s][15:18:44] WARNING: not removing hydrogen atom without neighbors\n",
      "[15:18:44] WARNING: not removing hydrogen atom without neighbors\n",
      "[15:18:44] WARNING: not removing hydrogen atom without neighbors\n",
      "[15:18:44] WARNING: not removing hydrogen atom without neighbors\n",
      "[15:18:44] WARNING: not removing hydrogen atom without neighbors\n",
      "[15:18:44] WARNING: not removing hydrogen atom without neighbors\n",
      " 28%|██████████▍                           | 562/2039 [00:00<00:00, 2800.31it/s][15:18:44] WARNING: not removing hydrogen atom without neighbors\n",
      "[15:18:44] WARNING: not removing hydrogen atom without neighbors\n",
      "[15:18:44] WARNING: not removing hydrogen atom without neighbors\n",
      "[15:18:44] WARNING: not removing hydrogen atom without neighbors\n",
      " 41%|███████████████▋                      | 843/2039 [00:00<00:00, 2689.61it/s][15:18:44] WARNING: not removing hydrogen atom without neighbors\n",
      "[15:18:44] WARNING: not removing hydrogen atom without neighbors\n",
      "[15:18:44] WARNING: not removing hydrogen atom without neighbors\n",
      "[15:18:44] WARNING: not removing hydrogen atom without neighbors\n",
      "[15:18:44] WARNING: not removing hydrogen atom without neighbors\n",
      " 55%|████████████████████▏                | 1113/2039 [00:00<00:00, 2654.12it/s][15:18:44] WARNING: not removing hydrogen atom without neighbors\n",
      "[15:18:45] WARNING: not removing hydrogen atom without neighbors\n",
      "[15:18:45] WARNING: not removing hydrogen atom without neighbors\n",
      "[15:18:45] WARNING: not removing hydrogen atom without neighbors\n",
      "[15:18:45] WARNING: not removing hydrogen atom without neighbors\n",
      "[15:18:45] WARNING: not removing hydrogen atom without neighbors\n",
      "[15:18:45] WARNING: not removing hydrogen atom without neighbors\n",
      "[15:18:45] WARNING: not removing hydrogen atom without neighbors\n",
      "[15:18:45] WARNING: not removing hydrogen atom without neighbors\n",
      " 69%|█████████████████████████▍           | 1401/2039 [00:00<00:00, 2732.15it/s][15:18:45] WARNING: not removing hydrogen atom without neighbors\n",
      "[15:18:45] WARNING: not removing hydrogen atom without neighbors\n",
      "[15:18:45] WARNING: not removing hydrogen atom without neighbors\n",
      "[15:18:45] WARNING: not removing hydrogen atom without neighbors\n",
      "[15:18:45] WARNING: not removing hydrogen atom without neighbors\n",
      " 83%|██████████████████████████████▊      | 1696/2039 [00:00<00:00, 2802.47it/s][15:18:45] WARNING: not removing hydrogen atom without neighbors\n",
      "[15:18:45] WARNING: not removing hydrogen atom without neighbors\n",
      "[15:18:45] WARNING: not removing hydrogen atom without neighbors\n",
      "[15:18:45] WARNING: not removing hydrogen atom without neighbors\n",
      "[15:18:45] WARNING: not removing hydrogen atom without neighbors\n",
      "[15:18:45] WARNING: not removing hydrogen atom without neighbors\n",
      "[15:18:45] WARNING: not removing hydrogen atom without neighbors\n",
      "100%|█████████████████████████████████████| 2039/2039 [00:00<00:00, 2787.57it/s]\n"
     ]
    }
   ],
   "source": [
    "# Load the CSV referenced by `bbbp_dir` through chemprop's `get_data`.\n",
    "# In the saved run the progress bars count 2039 items and the result is\n",
    "# shown below as a `MoleculeDataset`.\n",
    "data_process = get_data(bbbp_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "fb6384ff",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<chemprop.data.data.MoleculeDataset at 0x7f3e1b3415d0>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Echo the loaded dataset object (bare last expression, so its repr is shown).\n",
    "data_process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "a150c624",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "chemprop.data.data.MoleculeDataset"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Check the concrete class of the object returned by `get_data`.\n",
    "type(data_process)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "04c10722",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zhengyjo/anaconda3/lib/python3.7/site-packages/torch/nn/init.py:405: UserWarning: Initializing zero-element tensors is a no-op\n",
      "  warnings.warn(\"Initializing zero-element tensors is a no-op\")\n"
     ]
    }
   ],
   "source": [
    "import chemprop\n",
    "from chemprop.args import TrainArgs\n",
    "\n",
    "# CLI-style flags describing a 5-epoch regression configuration; they are\n",
    "# parsed into a TrainArgs object below.\n",
    "arguments = [\n",
    "    '--data_path', 'data/regression.csv',\n",
    "    '--dataset_type', 'regression',\n",
    "    '--save_dir', 'test_checkpoints_reg',\n",
    "    '--epochs', '5',\n",
    "    '--save_smiles_splits',\n",
    "]\n",
    "args = TrainArgs().parse_args(arguments)\n",
    "\n",
    "# Construct a fresh MPNN from the parsed arguments.\n",
    "mpnn = MoleculeModel(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "535f77d3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading pretrained parameter \"encoder.encoder.0.cached_zero_vector\".\n",
      "Loading pretrained parameter \"encoder.encoder.0.W_i.weight\".\n",
      "Loading pretrained parameter \"encoder.encoder.0.W_h.weight\".\n",
      "Loading pretrained parameter \"encoder.encoder.0.W_o.weight\".\n",
      "Loading pretrained parameter \"encoder.encoder.0.W_o.bias\".\n",
      "Loading pretrained parameter \"readout.1.weight\".\n",
      "Loading pretrained parameter \"readout.1.bias\".\n",
      "Loading pretrained parameter \"readout.4.weight\".\n",
      "Loading pretrained parameter \"readout.4.bias\".\n",
      "Moving model to cuda\n"
     ]
    }
   ],
   "source": [
    "PATH = '../../test_checkpoints_reg/fold_0/model_0/model.pt'\n",
    "\n",
    "# Load the saved checkpoint onto the device chosen by the parsed arguments.\n",
    "# (The original line read `model = model = ...`; the chained assignment was redundant.)\n",
    "model = load_checkpoint(PATH, device=args.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "cc638d95",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Treat the entire dataset as one batch: `batch` is just another name bound to\n",
    "# the same object as `data_process` (no copy is made by this assignment).\n",
    "batch = data_process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "58be7958",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull the individual components out of the batch. The accessors are called in\n",
    "# the same order as before; related pieces are simply grouped per line.\n",
    "\n",
    "# Molecular graph, extra features, targets and target mask.\n",
    "mol_batch, features_batch = batch.batch_graph(), batch.features()\n",
    "target_batch, mask_batch = batch.targets(), batch.mask()\n",
    "\n",
    "# Atom-level accessors.\n",
    "atom_descriptors_batch, atom_features_batch = (\n",
    "    batch.atom_descriptors(),\n",
    "    batch.atom_features(),\n",
    ")\n",
    "\n",
    "# Bond-level accessors.\n",
    "bond_descriptors_batch, bond_features_batch = (\n",
    "    batch.bond_descriptors(),\n",
    "    batch.bond_features(),\n",
    ")\n",
    "\n",
    "# Constraints and per-datapoint weights.\n",
    "constraints_batch, data_weights_batch = batch.constraints(), batch.data_weights()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "62a57a0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# No bond-type information is supplied for this experiment.\n",
    "# NOTE(review): the original cell repeated this exact assignment twice; if the\n",
    "# second line was meant to define a different variable, add it explicitly.\n",
    "bond_types_batch = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "0afd5d9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run only the encoder over the whole batch. Gradients are not needed for this\n",
    "# inspection, so autograd tracking is disabled to avoid keeping the graph alive.\n",
    "# NOTE: `preds` holds encoder outputs (saved shape: 2039 x 300), not final predictions.\n",
    "with torch.no_grad():\n",
    "    preds = model.encoder(mol_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "4f1634f3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2039, 300])"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape of the encoder output; the saved run shows torch.Size([2039, 300]).\n",
    "preds.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "78838af9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the bond features extracted from the batch. The saved run recorded no\n",
    "# output for this cell - presumably the value is None; confirm on re-run.\n",
    "bond_features_batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "871ba1af",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2039, 1])"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Forward the batch through the full model and display the output shape.\n",
    "# NOTE(review): the previous source, `model.encoder()`, passed no batch, which does\n",
    "# not match the saved torch.Size([2039, 1]) output of this cell. Confirm that a\n",
    "# full-model forward pass is the call that was originally intended here.\n",
    "model(mol_batch).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2c8c66e9",
   "metadata": {},
   "source": [
    "# Set up a new model for contrastive learning\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "9c9070af",
   "metadata": {},
   "outputs": [
    {
     "ename": "ModuleNotFoundError",
     "evalue": "No module named 'chemprop.models.mpn_contrastive'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_18619/3812632089.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mchemprop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmpn_contrastive\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mMPNCEncoder\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'chemprop.models.mpn_contrastive'"
     ]
    }
   ],
   "source": [
    "from chemprop.models.mpn_contrastive import MPNCEncoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "a445c4c5",
   "metadata": {},
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "module 'chemprop.models' has no attribute 'mpn_contrastive'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_18619/3263683179.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmpnc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mchemprop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmpn_contrastive\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m: module 'chemprop.models' has no attribute 'mpn_contrastive'"
     ]
    }
   ],
   "source": [
    "# Instantiate the contrastive encoder class imported in the previous cell. The\n",
    "# original line called the `mpn_contrastive` module itself, and modules are not callable.\n",
    "# NOTE(review): MPNCEncoder's constructor signature is not visible here - confirm\n",
    "# that `args` alone is sufficient, and that the import above succeeds first.\n",
    "mpnc = MPNCEncoder(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0a827d6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
