{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b576506",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from logging import Logger\n",
    "import os\n",
    "from typing import Dict, List\n",
    "\n",
    "import numpy as np\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\", category=np.VisibleDeprecationWarning) \n",
    "import pandas as pd\n",
    "from tensorboardX import SummaryWriter\n",
    "import torch\n",
    "from tqdm import trange\n",
    "from torch.optim.lr_scheduler import ExponentialLR\n",
    "\n",
    "# from .evaluate import evaluate, evaluate_predictions\n",
    "# from .predict import predict\n",
    "# from .train import train\n",
    "# from .loss_functions import get_loss_func\n",
    "from chemprop.spectra_utils import normalize_spectra, load_phase_mask\n",
    "from chemprop.args import TrainArgs\n",
    "from chemprop.constants import MODEL_FILE_NAME\n",
    "from chemprop.data import get_class_sizes, get_data, MoleculeDataLoader, MoleculeDataset, set_cache_graph, split_data\n",
    "from chemprop.models import MoleculeModel\n",
    "from chemprop.nn_utils import param_count, param_count_all\n",
    "from chemprop.utils import build_optimizer, build_lr_scheduler, load_checkpoint, makedirs, \\\n",
    "    save_checkpoint, save_smiles_splits, load_frzn_model, multitask_mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b64118c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the TrainArgs class imported in the imports cell: the bare `chemprop` package\n",
    "# name is not bound until a later cell, so `chemprop.args...` fails on a fresh kernel.\n",
    "check = TrainArgs()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea528cb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "dir(check)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "064e89e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import chemprop\n",
    "from chemprop.args import TrainArgs\n",
    "\n",
    "# Command-line style options used to configure the MPNN.\n",
    "arguments = [\n",
    "    '--data_path', 'data/regression.csv',\n",
    "    '--dataset_type', 'regression',\n",
    "    '--save_dir', 'test_checkpoints_reg',\n",
    "    '--epochs', '5',\n",
    "    '--save_smiles_splits',\n",
    "    # Currently disabled: '--encoder_path', 'M3_KMGCL_encoder_smiles.pt'\n",
    "]\n",
    "\n",
    "# Parse the option list with TrainArgs.\n",
    "args = TrainArgs().parse_args(arguments)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e9afad97",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'regression'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "args.dataset_type"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08278f9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "args.data_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6dc7ed73",
   "metadata": {},
   "outputs": [],
   "source": [
    "from chemprop.models import MoleculeModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aedf8285",
   "metadata": {},
   "outputs": [],
   "source": [
    "bbbp_dir = '../../data/bbbp.csv'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19b94ba9",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_process = get_data(bbbp_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb6384ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a150c624",
   "metadata": {},
   "outputs": [],
   "source": [
    "type(data_process)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04c10722",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a model from the training arguments already parsed earlier in the notebook.\n",
    "# The option list here was an exact copy of that cell, so it is not parsed a second time.\n",
    "mpnn = MoleculeModel(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "535f77d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Location of the saved model checkpoint (fold_0 / model_0 under test_checkpoints_reg).\n",
    "PATH = '../../test_checkpoints_reg/fold_0/model_0/model.pt'\n",
    "\n",
    "# Load the checkpoint onto the device recorded in the parsed arguments.\n",
    "model = load_checkpoint(PATH, device=args.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc638d95",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch = data_process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58be7958",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unpack every field the dataset exposes, keeping the original call order.\n",
    "mol_batch = batch.batch_graph()\n",
    "features_batch, target_batch = batch.features(), batch.targets()\n",
    "mask_batch = batch.mask()\n",
    "\n",
    "# Optional atom-level and bond-level extras.\n",
    "atom_descriptors_batch, atom_features_batch = batch.atom_descriptors(), batch.atom_features()\n",
    "bond_descriptors_batch, bond_features_batch = batch.bond_descriptors(), batch.bond_features()\n",
    "\n",
    "# Constraints and per-datapoint weights.\n",
    "constraints_batch, data_weights_batch = batch.constraints(), batch.data_weights()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62a57a0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single placeholder assignment (the original cell repeated this line twice).\n",
    "bond_types_batch = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0afd5d9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run only the encoder sub-module of the loaded model on the batched graphs.\n",
    "preds = model.encoder(mol_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f1634f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "preds.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78838af9",
   "metadata": {},
   "outputs": [],
   "source": [
    "bond_features_batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "871ba1af",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the encoder module itself. Calling it with no batch argument raises,\n",
    "# because the forward pass needs an input batch.\n",
    "model.encoder"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2c8c66e9",
   "metadata": {},
   "source": [
    "# Set up a new model for contrastive learning\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c9070af",
   "metadata": {},
   "outputs": [],
   "source": [
    "from chemprop.models.mpn_contrastive import MPNCEncoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a445c4c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `chemprop.models.mpn_contrastive` is a module and cannot be called; instantiate the\n",
    "# encoder class imported from it instead.\n",
    "# NOTE(review): assumes MPNCEncoder can be constructed from `args` alone - confirm its signature.\n",
    "mpnc = MPNCEncoder(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0a827d6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
