{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "54b141db",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mpn_contrastive import MPNCEncoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5ce34aeb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import warnings\n",
    "from logging import Logger\n",
    "from typing import Dict, List\n",
    "\n",
    "import numpy as np\n",
    "warnings.filterwarnings(\"ignore\", category=np.VisibleDeprecationWarning)\n",
    "import pandas as pd\n",
    "from tensorboardX import SummaryWriter\n",
    "import torch\n",
    "from tqdm import trange\n",
    "from torch.optim.lr_scheduler import ExponentialLR\n",
    "\n",
    "# from .evaluate import evaluate, evaluate_predictions\n",
    "# from .predict import predict\n",
    "# from .train import train\n",
    "# from .loss_functions import get_loss_func\n",
    "import chemprop\n",
    "from chemprop.spectra_utils import normalize_spectra, load_phase_mask\n",
    "from chemprop.args import TrainArgs\n",
    "from chemprop.constants import MODEL_FILE_NAME\n",
    "from chemprop.data import get_class_sizes, get_data, get_task_names, MoleculeDataLoader, MoleculeDataset, set_cache_graph, split_data\n",
    "from chemprop.models import MoleculeModel\n",
    "from chemprop.nn_utils import param_count, param_count_all\n",
    "from chemprop.utils import build_optimizer, build_lr_scheduler, load_checkpoint, makedirs, \\\n",
    "    save_checkpoint, save_smiles_splits, load_frzn_model, multitask_mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "606b743b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from chemprop.models import MoleculeModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f31b2322",
   "metadata": {},
   "outputs": [],
   "source": [
    "bbbp_dir = '../../data/bbbp.csv'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8f78b6b7",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2039it [00:00, 245442.14it/s]\n",
      "100%|███████████████████████████████████| 2039/2039 [00:00<00:00, 176913.71it/s]\n",
      "  0%|                                                  | 0/2039 [00:00<?, ?it/s][17:50:02] WARNING: not removing hydrogen atom without neighbors\n",
      "[17:50:02] WARNING: not removing hydrogen atom without neighbors\n",
      "[17:50:02] WARNING: not removing hydrogen atom without neighbors\n",
      "[17:50:02] WARNING: not removing hydrogen atom without neighbors\n",
      " 15%|█████▋                                | 302/2039 [00:00<00:00, 3009.02it/s][17:50:02] WARNING: not removing hydrogen atom without neighbors\n",
      "[17:50:02] WARNING: not removing hydrogen atom without neighbors\n",
      "[17:50:02] WARNING: not removing hydrogen atom without neighbors\n",
      "[17:50:02] WARNING: not removing hydrogen atom without neighbors\n",
      "[17:50:02] WARNING: not removing hydrogen atom without neighbors\n",
      "[17:50:02] WARNING: not removing hydrogen atom without neighbors\n",
      "[17:50:02] WARNING: not removing hydrogen atom without neighbors\n",
      " 30%|███████████▌                          | 619/2039 [00:00<00:00, 3099.96it/s][17:50:02] WARNING: not removing hydrogen atom without neighbors\n",
      "[17:50:02] WARNING: not removing hydrogen atom without neighbors\n",
      "[17:50:02] WARNING: not removing hydrogen atom without neighbors\n",
      "[17:50:03] WARNING: not removing hydrogen atom without neighbors\n",
      " 46%|█████████████████▎                    | 930/2039 [00:00<00:00, 2811.21it/s][17:50:03] WARNING: not removing hydrogen atom without neighbors\n",
      "[17:50:03] WARNING: not removing hydrogen atom without neighbors\n",
      "[17:50:03] WARNING: not removing hydrogen atom without neighbors\n",
      "[17:50:03] WARNING: not removing hydrogen atom without neighbors\n",
      "[17:50:03] WARNING: not removing hydrogen atom without neighbors\n",
      "[17:50:03] WARNING: not removing hydrogen atom without neighbors\n",
      "[17:50:03] WARNING: not removing hydrogen atom without neighbors\n",
      " 61%|██████████████████████▌              | 1240/2039 [00:00<00:00, 2917.84it/s][17:50:03] WARNING: not removing hydrogen atom without neighbors\n",
      "[17:50:03] WARNING: not removing hydrogen atom without neighbors\n",
      "[17:50:03] WARNING: not removing hydrogen atom without neighbors\n",
      "[17:50:03] WARNING: not removing hydrogen atom without neighbors\n",
      "[17:50:03] WARNING: not removing hydrogen atom without neighbors\n",
      "[17:50:03] WARNING: not removing hydrogen atom without neighbors\n",
      "[17:50:03] WARNING: not removing hydrogen atom without neighbors\n",
      "[17:50:03] WARNING: not removing hydrogen atom without neighbors\n",
      " 76%|████████████████████████████▏        | 1550/2039 [00:00<00:00, 2980.88it/s][17:50:03] WARNING: not removing hydrogen atom without neighbors\n",
      "[17:50:03] WARNING: not removing hydrogen atom without neighbors\n",
      "[17:50:03] WARNING: not removing hydrogen atom without neighbors\n",
      "[17:50:03] WARNING: not removing hydrogen atom without neighbors\n",
      "[17:50:03] WARNING: not removing hydrogen atom without neighbors\n",
      "[17:50:03] WARNING: not removing hydrogen atom without neighbors\n",
      "[17:50:03] WARNING: not removing hydrogen atom without neighbors\n",
      "[17:50:03] WARNING: not removing hydrogen atom without neighbors\n",
      " 91%|█████████████████████████████████▊   | 1863/2039 [00:00<00:00, 3028.67it/s][17:50:03] WARNING: not removing hydrogen atom without neighbors\n",
      "[17:50:03] WARNING: not removing hydrogen atom without neighbors\n",
      "100%|█████████████████████████████████████| 2039/2039 [00:00<00:00, 3005.93it/s]\n"
     ]
    }
   ],
   "source": [
    "# Load the CSV at bbbp_dir with chemprop's get_data loader.\n",
    "data_process = get_data(path=bbbp_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "91354f48",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the chemprop training arguments for the same file that get_data() loaded,\n",
    "# so the arguments and the dataset describe the same data.\n",
    "# NOTE(review): BBBP is configured as a classification task here - confirm this\n",
    "# matches the target column of the CSV.\n",
    "arguments = [\n",
    "    '--data_path', bbbp_dir,\n",
    "    '--dataset_type', 'classification',\n",
    "    '--save_dir', 'test_checkpoints_reg',\n",
    "    '--epochs', '5',\n",
    "    '--save_smiles_splits'\n",
    "]\n",
    "\n",
    "args = TrainArgs().parse_args(arguments)\n",
    "\n",
    "# parse_args() does not fill in task_names; without it the model is built with\n",
    "# zero output columns (the forward pass returned a (2039, 0) tensor). Read the\n",
    "# target names from the CSV header so the output layer has one column per task.\n",
    "args.task_names = get_task_names(\n",
    "    path=args.data_path,\n",
    "    smiles_columns=args.smiles_columns,\n",
    "    target_columns=args.target_columns,\n",
    "    ignore_columns=args.ignore_columns,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23d88a77",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the contrastive encoder and move it to the device selected by the parsed\n",
    "# arguments instead of a hard-coded \"cuda:0\", so the cell also runs without a GPU.\n",
    "# NOTE(review): the two 300s are positional arguments whose meaning is defined in\n",
    "# mpn_contrastive.MPNCEncoder (not shown here) - confirm and give them names.\n",
    "mpnc = MPNCEncoder(args, 300, 300).to(args.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0d1666e4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zhengyjo/anaconda3/lib/python3.7/site-packages/torch/nn/init.py:405: UserWarning: Initializing zero-element tensors is a no-op\n",
      "  warnings.warn(\"Initializing zero-element tensors is a no-op\")\n"
     ]
    }
   ],
   "source": [
    "# MoleculeModel is already imported in the import cell at the top.\n",
    "# Use the device chosen by the parsed arguments rather than a hard-coded\n",
    "# \"cuda:0\", so the model lands on the same device as args and the cell\n",
    "# also runs on machines without a GPU.\n",
    "model = MoleculeModel(args).to(args.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2a0bb9a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batch the loaded dataset's graphs; the result is fed to the model in the next cell.\n",
    "test_input = data_process.batch_graph()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "c17ed978",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke-test the forward pass on the whole batch.\n",
    "# NOTE(review): gradients are tracked here (the output carries a grad_fn); wrap the\n",
    "# call in torch.no_grad() if this is for inspection only.\n",
    "res = model(test_input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "aec5c024",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([], device='cuda:0', size=(2039, 0), grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the raw model output; check that its second dimension is non-zero.\n",
    "res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a92c3515",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
