{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "54b141db",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List, Union, Tuple\n",
    "from logging import Logger\n",
    "\n",
    "import numpy as np\n",
    "from rdkit import Chem\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "from mpn import MPN\n",
    "from ffn import build_ffn, MultiReadout\n",
    "from chemprop.args import TrainArgs\n",
    "from chemprop.features import BatchMolGraph\n",
    "from chemprop.nn_utils import initialize_weights\n",
    "\n",
    "from chemprop.models import MoleculeModel\n",
    "\n",
    "from chemprop.utils import build_optimizer, build_lr_scheduler, load_checkpoint, makedirs, \\\n",
    "    save_checkpoint, save_smiles_splits, load_frzn_model, multitask_mean\n",
    "\n",
    "class MoleculeModel_Multiple(nn.Module):\n",
    "    \"\"\"Learnable weighted sum of several :class:`MoleculeModel` instances.\n",
    "\n",
    "    ``num_models`` sub-models are built from the same ``args``. When ``args.encoder_path`` is set, it is\n",
    "    split on commas and sub-model ``i`` is initialised from the ``i``-th path. In :meth:`forward` each\n",
    "    sub-model output is multiplied by its own learnable scalar coefficient (initialised to 1.0), the\n",
    "    scaled outputs are summed, and the dataset/loss-specific output transform is applied to the sum.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, args: TrainArgs, num_models, logger: Logger = None):\n",
    "        \"\"\"\n",
    "        :param args: A :class:`~chemprop.args.TrainArgs` object containing model arguments.\n",
    "        :param num_models: Number of sub-models to build and combine; must be at least 1.\n",
    "        :param logger: Optional logger forwarded to ``load_encoder_model``.\n",
    "        :raises ValueError: If ``num_models`` is smaller than 1, or if ``args.encoder_path`` lists\n",
    "                            fewer paths than ``num_models``.\n",
    "        \"\"\"\n",
    "        super(MoleculeModel_Multiple, self).__init__()\n",
    "\n",
    "        if num_models < 1:\n",
    "            raise ValueError(f\"num_models must be at least 1, got {num_models}\")\n",
    "\n",
    "        self.logger = logger\n",
    "        # ModuleList / ParameterList register the sub-models and coefficients with this module, so\n",
    "        # .parameters(), .to(), .train()/.eval() and state_dict() all reach them.\n",
    "        self.model_lst = nn.ModuleList([])\n",
    "        self.coefficients = nn.ParameterList([])\n",
    "        self.num_models = num_models\n",
    "\n",
    "        # `encoder_path` is optional (the installed TrainArgs may not define it at all), so read it\n",
    "        # defensively and only split it when it is actually set.\n",
    "        encoder_path = getattr(args, \"encoder_path\", None)\n",
    "        self.encoder_path = encoder_path.split(\",\") if encoder_path is not None else None\n",
    "        if self.encoder_path is not None and len(self.encoder_path) < num_models:\n",
    "            raise ValueError(\n",
    "                f\"encoder_path lists {len(self.encoder_path)} paths but num_models is {num_models}\"\n",
    "            )\n",
    "\n",
    "        for model_idx in range(num_models):\n",
    "            sub_model = MoleculeModel(args)\n",
    "            if self.encoder_path is not None:\n",
    "                # NOTE(review): `load_encoder_model` is not imported in the visible cells of this\n",
    "                # notebook -- confirm where it is defined.\n",
    "                sub_model = load_encoder_model(\n",
    "                    model=sub_model,\n",
    "                    path=self.encoder_path[model_idx],\n",
    "                    current_args=args,\n",
    "                    logger=self.logger,\n",
    "                )\n",
    "            self.model_lst.append(sub_model.to(args.device))\n",
    "            self.coefficients.append(nn.Parameter(torch.tensor(1.0)))\n",
    "\n",
    "        self.classification = args.dataset_type == \"classification\"\n",
    "        self.multiclass = args.dataset_type == \"multiclass\"\n",
    "        self.loss_function = args.loss_function\n",
    "        self.train_class_sizes = getattr(args, \"train_class_sizes\", None)\n",
    "\n",
    "        # When using cross entropy losses, no sigmoid or softmax during training. But they are needed\n",
    "        # for mcc loss. Defaults to False for the remaining dataset types.\n",
    "        self.no_training_normalization = False\n",
    "        if self.classification or self.multiclass:\n",
    "            self.no_training_normalization = args.loss_function in [\n",
    "                \"cross_entropy\",\n",
    "                \"binary_cross_entropy\",\n",
    "            ]\n",
    "\n",
    "        self.is_atom_bond_targets = args.is_atom_bond_targets\n",
    "\n",
    "        if self.is_atom_bond_targets:\n",
    "            self.atom_targets, self.bond_targets = args.atom_targets, args.bond_targets\n",
    "            self.atom_constraints, self.bond_constraints = (\n",
    "                args.atom_constraints,\n",
    "                args.bond_constraints,\n",
    "            )\n",
    "            self.adding_bond_types = args.adding_bond_types\n",
    "\n",
    "        self.relative_output_size = 1\n",
    "        if self.multiclass:\n",
    "            # forward() reshapes multiclass outputs to (batch, targets, num_classes).\n",
    "            self.num_classes = args.multiclass_num_classes\n",
    "            self.relative_output_size *= args.multiclass_num_classes\n",
    "        if self.loss_function == \"mve\":\n",
    "            self.relative_output_size *= 2  # return means and variances\n",
    "        if self.loss_function == \"dirichlet\" and self.classification:\n",
    "            self.relative_output_size *= 2  # return dirichlet parameters for positive and negative class\n",
    "        if self.loss_function == \"evidential\":\n",
    "            self.relative_output_size *= 4  # return four evidential parameters: gamma, lambda, alpha, beta\n",
    "\n",
    "        if self.classification:\n",
    "            self.sigmoid = nn.Sigmoid()\n",
    "\n",
    "        if self.multiclass:\n",
    "            self.multiclass_softmax = nn.Softmax(dim=2)\n",
    "\n",
    "        if self.loss_function in [\"mve\", \"evidential\", \"dirichlet\"]:\n",
    "            self.softplus = nn.Softplus()\n",
    "\n",
    "        # output_size must be a plain int, not a tuple.\n",
    "        if self.is_atom_bond_targets:\n",
    "            self.output_size = self.relative_output_size\n",
    "        else:\n",
    "            self.output_size = self.relative_output_size * args.num_tasks\n",
    "\n",
    "    def forward(\n",
    "        self,\n",
    "        batch: Union[\n",
    "            List[List[str]],\n",
    "            List[List[Chem.Mol]],\n",
    "            List[List[Tuple[Chem.Mol, Chem.Mol]]],\n",
    "            List[BatchMolGraph],\n",
    "        ],\n",
    "        features_batch: List[np.ndarray] = None,\n",
    "        atom_descriptors_batch: List[np.ndarray] = None,\n",
    "        atom_features_batch: List[np.ndarray] = None,\n",
    "        bond_descriptors_batch: List[np.ndarray] = None,\n",
    "        bond_features_batch: List[np.ndarray] = None,\n",
    "        constraints_batch: List[torch.Tensor] = None,\n",
    "        bond_types_batch: List[torch.Tensor] = None,\n",
    "    ) -> Union[torch.Tensor, List[torch.Tensor]]:\n",
    "        \"\"\"\n",
    "        Runs every sub-model on the input and combines their outputs.\n",
    "\n",
    "        :param batch: A list of list of SMILES, a list of list of RDKit molecules, or a\n",
    "                      list of :class:`~chemprop.features.featurization.BatchMolGraph`.\n",
    "                      The outer list or BatchMolGraph is of length :code:`num_molecules` (number of datapoints in batch),\n",
    "                      the inner list is of length :code:`number_of_molecules` (number of molecules per datapoint).\n",
    "        :param features_batch: A list of numpy arrays containing additional features.\n",
    "        :param atom_descriptors_batch: A list of numpy arrays containing additional atom descriptors.\n",
    "        :param atom_features_batch: A list of numpy arrays containing additional atom features.\n",
    "        :param bond_descriptors_batch: A list of numpy arrays containing additional bond descriptors.\n",
    "        :param bond_features_batch: A list of numpy arrays containing additional bond features.\n",
    "        :param constraints_batch: A list of PyTorch tensors which applies constraint on atomic/bond properties.\n",
    "        :param bond_types_batch: A list of PyTorch tensors storing bond types of each bond determined by RDKit molecules.\n",
    "        :return: The coefficient-weighted sum of the sub-model outputs after the output transform:\n",
    "                 a tensor, or a list with one tensor per atom/bond target when ``is_atom_bond_targets``.\n",
    "        \"\"\"\n",
    "        # NOTE(review): chemprop's MoleculeModel.forward may already apply its own output transform\n",
    "        # (sigmoid/softmax/softplus) outside training; if so, the transform below is applied a second\n",
    "        # time to the combined output -- confirm against the installed chemprop version.\n",
    "        outputs_lst = []\n",
    "        for sub_model, coefficient in zip(self.model_lst, self.coefficients):\n",
    "            sub_output = sub_model(\n",
    "                batch,\n",
    "                features_batch,\n",
    "                atom_descriptors_batch,\n",
    "                atom_features_batch,\n",
    "                bond_descriptors_batch,\n",
    "                bond_features_batch,\n",
    "                constraints_batch,\n",
    "                bond_types_batch,\n",
    "            )\n",
    "            if self.is_atom_bond_targets:\n",
    "                # One tensor per atom/bond target: scale each of them.\n",
    "                outputs_lst.append([x * coefficient for x in sub_output])\n",
    "            else:\n",
    "                outputs_lst.append(sub_output * coefficient)\n",
    "\n",
    "        # Sum the scaled outputs of *all* sub-models.\n",
    "        output = outputs_lst[0]\n",
    "        for other in outputs_lst[1:]:\n",
    "            if self.is_atom_bond_targets:\n",
    "                output = [x + y for x, y in zip(output, other)]\n",
    "            else:\n",
    "                output = output + other\n",
    "\n",
    "        # Don't apply sigmoid during training when using BCEWithLogitsLoss\n",
    "        if (\n",
    "            self.classification\n",
    "            and not (self.training and self.no_training_normalization)\n",
    "            and self.loss_function != \"dirichlet\"\n",
    "        ):\n",
    "            if self.is_atom_bond_targets:\n",
    "                output = [self.sigmoid(x) for x in output]\n",
    "            else:\n",
    "                output = self.sigmoid(output)\n",
    "        if self.multiclass:\n",
    "            output = output.reshape(\n",
    "                (output.shape[0], -1, self.num_classes)\n",
    "            )  # batch size x num targets x num classes per target\n",
    "            if (\n",
    "                not (self.training and self.no_training_normalization)\n",
    "                and self.loss_function != \"dirichlet\"\n",
    "            ):\n",
    "                output = self.multiclass_softmax(\n",
    "                    output\n",
    "                )  # to get probabilities during evaluation, but not during training when using CrossEntropyLoss\n",
    "\n",
    "        # Modify multi-input loss functions\n",
    "        if self.loss_function == \"mve\":\n",
    "            if self.is_atom_bond_targets:\n",
    "                outputs = []\n",
    "                for x in output:\n",
    "                    means, variances = torch.split(x, x.shape[1] // 2, dim=1)\n",
    "                    variances = self.softplus(variances)\n",
    "                    outputs.append(torch.cat([means, variances], dim=1))\n",
    "                return outputs\n",
    "            else:\n",
    "                means, variances = torch.split(output, output.shape[1] // 2, dim=1)\n",
    "                variances = self.softplus(variances)\n",
    "                output = torch.cat([means, variances], dim=1)\n",
    "        if self.loss_function == \"evidential\":\n",
    "            if self.is_atom_bond_targets:\n",
    "                outputs = []\n",
    "                for x in output:\n",
    "                    means, lambdas, alphas, betas = torch.split(\n",
    "                        x, x.shape[1] // 4, dim=1\n",
    "                    )\n",
    "                    lambdas = self.softplus(lambdas)  # + min_val\n",
    "                    alphas = (\n",
    "                        self.softplus(alphas) + 1\n",
    "                    )  # + min_val # add 1 for numerical contraints of Gamma function\n",
    "                    betas = self.softplus(betas)  # + min_val\n",
    "                    outputs.append(torch.cat([means, lambdas, alphas, betas], dim=1))\n",
    "                return outputs\n",
    "            else:\n",
    "                means, lambdas, alphas, betas = torch.split(\n",
    "                    output, output.shape[1] // 4, dim=1\n",
    "                )\n",
    "                lambdas = self.softplus(lambdas)  # + min_val\n",
    "                alphas = (\n",
    "                    self.softplus(alphas) + 1\n",
    "                )  # + min_val # add 1 for numerical contraints of Gamma function\n",
    "                betas = self.softplus(betas)  # + min_val\n",
    "                output = torch.cat([means, lambdas, alphas, betas], dim=1)\n",
    "        if self.loss_function == \"dirichlet\":\n",
    "            if self.is_atom_bond_targets:\n",
    "                outputs = []\n",
    "                for x in output:\n",
    "                    outputs.append(nn.functional.softplus(x) + 1)\n",
    "                return outputs\n",
    "            else:\n",
    "                output = nn.functional.softplus(output) + 1\n",
    "\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5ce34aeb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from logging import Logger\n",
    "import os\n",
    "from typing import Dict, List\n",
    "\n",
    "import numpy as np\n",
    "import warnings\n",
    "# NumPy >= 1.25 exposes its warning classes under np.exceptions and NumPy 2.0 removed the\n",
    "# top-level np.VisibleDeprecationWarning alias, so resolve the class from whichever exists.\n",
    "warnings.filterwarnings(\n",
    "    \"ignore\", category=getattr(np, \"exceptions\", np).VisibleDeprecationWarning\n",
    ")\n",
    "import pandas as pd\n",
    "from tensorboardX import SummaryWriter\n",
    "import torch\n",
    "from tqdm import trange\n",
    "from torch.optim.lr_scheduler import ExponentialLR\n",
    "\n",
    "# from .evaluate import evaluate, evaluate_predictions\n",
    "# from .predict import predict\n",
    "# from .train import train\n",
    "# from .loss_functions import get_loss_func\n",
    "from chemprop.spectra_utils import normalize_spectra, load_phase_mask\n",
    "from chemprop.args import TrainArgs\n",
    "from chemprop.constants import MODEL_FILE_NAME\n",
    "from chemprop.data import get_class_sizes, get_data, MoleculeDataLoader, MoleculeDataset, set_cache_graph, split_data\n",
    "from chemprop.models import MoleculeModel\n",
    "from chemprop.nn_utils import param_count, param_count_all\n",
    "from chemprop.utils import build_optimizer, build_lr_scheduler, load_checkpoint, makedirs, \\\n",
    "    save_checkpoint, save_smiles_splits, load_frzn_model, multitask_mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "666c11f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "606b743b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from chemprop.models import MoleculeModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f31b2322",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path to the BBBP dataset CSV (a relative path, resolved against the kernel's working directory).\n",
    "bbbp_dir = \"../../data/bbbp.csv\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6ad666fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List, Union, Tuple\n",
    "\n",
    "import numpy as np\n",
    "from rdkit import Chem\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "from mpn import MPN\n",
    "from ffn import build_ffn, MultiReadout\n",
    "from chemprop.args import TrainArgs\n",
    "from chemprop.features import BatchMolGraph\n",
    "from chemprop.nn_utils import initialize_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "61f9c1d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MoleculeModel_Multiple(nn.Module):\n",
    "    \"\"\"Learnable weighted sum of several :class:`MoleculeModel` instances.\n",
    "\n",
    "    ``num_models`` sub-models are built from the same ``args``. When ``args.encoder_path`` is set, it is\n",
    "    split on commas and sub-model ``i`` is initialised from the ``i``-th path. In :meth:`forward` each\n",
    "    sub-model output is multiplied by its own learnable scalar coefficient (initialised to 1.0), the\n",
    "    scaled outputs are summed, and the dataset/loss-specific output transform is applied to the sum.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, args: TrainArgs, num_models, logger: Logger = None):\n",
    "        \"\"\"\n",
    "        :param args: A :class:`~chemprop.args.TrainArgs` object containing model arguments.\n",
    "        :param num_models: Number of sub-models to build and combine; must be at least 1.\n",
    "        :param logger: Optional logger forwarded to ``load_encoder_model``.\n",
    "        :raises ValueError: If ``num_models`` is smaller than 1, or if ``args.encoder_path`` lists\n",
    "                            fewer paths than ``num_models``.\n",
    "        \"\"\"\n",
    "        super(MoleculeModel_Multiple, self).__init__()\n",
    "\n",
    "        if num_models < 1:\n",
    "            raise ValueError(f\"num_models must be at least 1, got {num_models}\")\n",
    "\n",
    "        self.logger = logger\n",
    "        # ModuleList / ParameterList register the sub-models and coefficients with this module, so\n",
    "        # .parameters(), .to(), .train()/.eval() and state_dict() all reach them. Plain Python lists\n",
    "        # would hide them from the optimizer and from checkpoints.\n",
    "        self.model_lst = nn.ModuleList([])\n",
    "        self.coefficients = nn.ParameterList([])\n",
    "        self.num_models = num_models\n",
    "\n",
    "        # `encoder_path` is optional (the installed TrainArgs may not define it at all), so read it\n",
    "        # defensively and only split it when it is actually set.\n",
    "        encoder_path = getattr(args, \"encoder_path\", None)\n",
    "        self.encoder_path = encoder_path.split(\",\") if encoder_path is not None else None\n",
    "        if self.encoder_path is not None and len(self.encoder_path) < num_models:\n",
    "            raise ValueError(\n",
    "                f\"encoder_path lists {len(self.encoder_path)} paths but num_models is {num_models}\"\n",
    "            )\n",
    "\n",
    "        for model_idx in range(num_models):\n",
    "            sub_model = MoleculeModel(args)\n",
    "            if self.encoder_path is not None:\n",
    "                # NOTE(review): `load_encoder_model` is not imported in the visible cells of this\n",
    "                # notebook -- confirm where it is defined.\n",
    "                sub_model = load_encoder_model(\n",
    "                    model=sub_model,\n",
    "                    path=self.encoder_path[model_idx],\n",
    "                    current_args=args,\n",
    "                    logger=self.logger,\n",
    "                )\n",
    "            self.model_lst.append(sub_model.to(args.device))\n",
    "            self.coefficients.append(nn.Parameter(torch.tensor(1.0)))\n",
    "\n",
    "        self.classification = args.dataset_type == \"classification\"\n",
    "        self.multiclass = args.dataset_type == \"multiclass\"\n",
    "        self.loss_function = args.loss_function\n",
    "        self.train_class_sizes = getattr(args, \"train_class_sizes\", None)\n",
    "\n",
    "        # When using cross entropy losses, no sigmoid or softmax during training. But they are needed\n",
    "        # for mcc loss. Defaults to False for the remaining dataset types.\n",
    "        self.no_training_normalization = False\n",
    "        if self.classification or self.multiclass:\n",
    "            self.no_training_normalization = args.loss_function in [\n",
    "                \"cross_entropy\",\n",
    "                \"binary_cross_entropy\",\n",
    "            ]\n",
    "\n",
    "        self.is_atom_bond_targets = args.is_atom_bond_targets\n",
    "\n",
    "        if self.is_atom_bond_targets:\n",
    "            self.atom_targets, self.bond_targets = args.atom_targets, args.bond_targets\n",
    "            self.atom_constraints, self.bond_constraints = (\n",
    "                args.atom_constraints,\n",
    "                args.bond_constraints,\n",
    "            )\n",
    "            self.adding_bond_types = args.adding_bond_types\n",
    "\n",
    "        self.relative_output_size = 1\n",
    "        if self.multiclass:\n",
    "            # forward() reshapes multiclass outputs to (batch, targets, num_classes).\n",
    "            self.num_classes = args.multiclass_num_classes\n",
    "            self.relative_output_size *= args.multiclass_num_classes\n",
    "        if self.loss_function == \"mve\":\n",
    "            self.relative_output_size *= 2  # return means and variances\n",
    "        if self.loss_function == \"dirichlet\" and self.classification:\n",
    "            self.relative_output_size *= 2  # return dirichlet parameters for positive and negative class\n",
    "        if self.loss_function == \"evidential\":\n",
    "            self.relative_output_size *= 4  # return four evidential parameters: gamma, lambda, alpha, beta\n",
    "\n",
    "        if self.classification:\n",
    "            self.sigmoid = nn.Sigmoid()\n",
    "\n",
    "        if self.multiclass:\n",
    "            self.multiclass_softmax = nn.Softmax(dim=2)\n",
    "\n",
    "        if self.loss_function in [\"mve\", \"evidential\", \"dirichlet\"]:\n",
    "            self.softplus = nn.Softplus()\n",
    "\n",
    "        # output_size must be a plain int, not a tuple.\n",
    "        if self.is_atom_bond_targets:\n",
    "            self.output_size = self.relative_output_size\n",
    "        else:\n",
    "            self.output_size = self.relative_output_size * args.num_tasks\n",
    "\n",
    "    def forward(\n",
    "        self,\n",
    "        batch: Union[\n",
    "            List[List[str]],\n",
    "            List[List[Chem.Mol]],\n",
    "            List[List[Tuple[Chem.Mol, Chem.Mol]]],\n",
    "            List[BatchMolGraph],\n",
    "        ],\n",
    "        features_batch: List[np.ndarray] = None,\n",
    "        atom_descriptors_batch: List[np.ndarray] = None,\n",
    "        atom_features_batch: List[np.ndarray] = None,\n",
    "        bond_descriptors_batch: List[np.ndarray] = None,\n",
    "        bond_features_batch: List[np.ndarray] = None,\n",
    "        constraints_batch: List[torch.Tensor] = None,\n",
    "        bond_types_batch: List[torch.Tensor] = None,\n",
    "    ) -> Union[torch.Tensor, List[torch.Tensor]]:\n",
    "        \"\"\"\n",
    "        Runs every sub-model on the input and combines their outputs.\n",
    "\n",
    "        :param batch: A list of list of SMILES, a list of list of RDKit molecules, or a\n",
    "                      list of :class:`~chemprop.features.featurization.BatchMolGraph`.\n",
    "                      The outer list or BatchMolGraph is of length :code:`num_molecules` (number of datapoints in batch),\n",
    "                      the inner list is of length :code:`number_of_molecules` (number of molecules per datapoint).\n",
    "        :param features_batch: A list of numpy arrays containing additional features.\n",
    "        :param atom_descriptors_batch: A list of numpy arrays containing additional atom descriptors.\n",
    "        :param atom_features_batch: A list of numpy arrays containing additional atom features.\n",
    "        :param bond_descriptors_batch: A list of numpy arrays containing additional bond descriptors.\n",
    "        :param bond_features_batch: A list of numpy arrays containing additional bond features.\n",
    "        :param constraints_batch: A list of PyTorch tensors which applies constraint on atomic/bond properties.\n",
    "        :param bond_types_batch: A list of PyTorch tensors storing bond types of each bond determined by RDKit molecules.\n",
    "        :return: The coefficient-weighted sum of the sub-model outputs after the output transform:\n",
    "                 a tensor, or a list with one tensor per atom/bond target when ``is_atom_bond_targets``.\n",
    "        \"\"\"\n",
    "        # NOTE(review): chemprop's MoleculeModel.forward may already apply its own output transform\n",
    "        # (sigmoid/softmax/softplus) outside training; if so, the transform below is applied a second\n",
    "        # time to the combined output -- confirm against the installed chemprop version.\n",
    "        outputs_lst = []\n",
    "        for sub_model, coefficient in zip(self.model_lst, self.coefficients):\n",
    "            sub_output = sub_model(\n",
    "                batch,\n",
    "                features_batch,\n",
    "                atom_descriptors_batch,\n",
    "                atom_features_batch,\n",
    "                bond_descriptors_batch,\n",
    "                bond_features_batch,\n",
    "                constraints_batch,\n",
    "                bond_types_batch,\n",
    "            )\n",
    "            if self.is_atom_bond_targets:\n",
    "                # One tensor per atom/bond target: scale each of them.\n",
    "                outputs_lst.append([x * coefficient for x in sub_output])\n",
    "            else:\n",
    "                outputs_lst.append(sub_output * coefficient)\n",
    "\n",
    "        # Sum the scaled outputs of *all* sub-models.\n",
    "        output = outputs_lst[0]\n",
    "        for other in outputs_lst[1:]:\n",
    "            if self.is_atom_bond_targets:\n",
    "                output = [x + y for x, y in zip(output, other)]\n",
    "            else:\n",
    "                output = output + other\n",
    "\n",
    "        # Don't apply sigmoid during training when using BCEWithLogitsLoss\n",
    "        if (\n",
    "            self.classification\n",
    "            and not (self.training and self.no_training_normalization)\n",
    "            and self.loss_function != \"dirichlet\"\n",
    "        ):\n",
    "            if self.is_atom_bond_targets:\n",
    "                output = [self.sigmoid(x) for x in output]\n",
    "            else:\n",
    "                output = self.sigmoid(output)\n",
    "        if self.multiclass:\n",
    "            output = output.reshape(\n",
    "                (output.shape[0], -1, self.num_classes)\n",
    "            )  # batch size x num targets x num classes per target\n",
    "            if (\n",
    "                not (self.training and self.no_training_normalization)\n",
    "                and self.loss_function != \"dirichlet\"\n",
    "            ):\n",
    "                output = self.multiclass_softmax(\n",
    "                    output\n",
    "                )  # to get probabilities during evaluation, but not during training when using CrossEntropyLoss\n",
    "\n",
    "        # Modify multi-input loss functions\n",
    "        if self.loss_function == \"mve\":\n",
    "            if self.is_atom_bond_targets:\n",
    "                outputs = []\n",
    "                for x in output:\n",
    "                    means, variances = torch.split(x, x.shape[1] // 2, dim=1)\n",
    "                    variances = self.softplus(variances)\n",
    "                    outputs.append(torch.cat([means, variances], dim=1))\n",
    "                return outputs\n",
    "            else:\n",
    "                means, variances = torch.split(output, output.shape[1] // 2, dim=1)\n",
    "                variances = self.softplus(variances)\n",
    "                output = torch.cat([means, variances], dim=1)\n",
    "        if self.loss_function == \"evidential\":\n",
    "            if self.is_atom_bond_targets:\n",
    "                outputs = []\n",
    "                for x in output:\n",
    "                    means, lambdas, alphas, betas = torch.split(\n",
    "                        x, x.shape[1] // 4, dim=1\n",
    "                    )\n",
    "                    lambdas = self.softplus(lambdas)  # + min_val\n",
    "                    alphas = (\n",
    "                        self.softplus(alphas) + 1\n",
    "                    )  # + min_val # add 1 for numerical contraints of Gamma function\n",
    "                    betas = self.softplus(betas)  # + min_val\n",
    "                    outputs.append(torch.cat([means, lambdas, alphas, betas], dim=1))\n",
    "                return outputs\n",
    "            else:\n",
    "                means, lambdas, alphas, betas = torch.split(\n",
    "                    output, output.shape[1] // 4, dim=1\n",
    "                )\n",
    "                lambdas = self.softplus(lambdas)  # + min_val\n",
    "                alphas = (\n",
    "                    self.softplus(alphas) + 1\n",
    "                )  # + min_val # add 1 for numerical contraints of Gamma function\n",
    "                betas = self.softplus(betas)  # + min_val\n",
    "                output = torch.cat([means, lambdas, alphas, betas], dim=1)\n",
    "        if self.loss_function == \"dirichlet\":\n",
    "            if self.is_atom_bond_targets:\n",
    "                outputs = []\n",
    "                for x in output:\n",
    "                    outputs.append(nn.functional.softplus(x) + 1)\n",
    "                return outputs\n",
    "            else:\n",
    "                output = nn.functional.softplus(output) + 1\n",
    "\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "36e4ee29",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pretrained encoder checkpoints: smiles, image, nmr, fusion_fingerprint and fusion_nmr variants.\n",
    "encoder_paths_arr = [\n",
    "    \"../../M3_KMGCL_encoder_smiles_alpha_0.0_01102024.pt\",\n",
    "    \"../../M3_KMGCL_encoder_image_alpha_0.0_01102024.pt\",\n",
    "    \"../../M3_KMGCL_encoder_nmr_alpha_0.0_01102024.pt\",\n",
    "    \"../../M3_KMGCL_encoder_fusion_fingerprint_alpha_0.0_01102024.pt\",\n",
    "    \"../../M3_KMGCL_encoder_fusion_nmr_alpha_1_01102024.pt\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "20136e9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Comma-separated form of the checkpoint list (MoleculeModel_Multiple splits encoder_path on \",\").\n",
    "encoder_paths = \",\".join(map(str, encoder_paths_arr))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "daae4446",
   "metadata": {},
   "outputs": [],
   "source": [
    "from chemprop.args import TrainArgs\n",
    "\n",
    "# Minimal run: BBBP classification with otherwise-default hyperparameters.\n",
    "# A multi-modality ensemble run would additionally pass:\n",
    "#   '--multi_modality_ensemble', 'True', '--save_dir', 'bbbp_test_checkpoints_multi',\n",
    "#   '--epochs', '1', '--encoder_path', encoder_paths, '--save_smiles_splits'\n",
    "# NOTE(review): the installed TrainArgs exposes neither `multi_modality_ensemble` nor `encoder_path`\n",
    "# (see the attribute probe and attribute listing below) -- confirm before enabling those flags.\n",
    "arguments = [\"--data_path\", bbbp_dir, \"--dataset_type\", \"classification\"]\n",
    "args = TrainArgs().parse_args(arguments)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "1cd47a74",
   "metadata": {},
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "'TrainArgs' object has no attribute 'multi_modality_ensemble'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_23906/3287233286.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmulti_modality_ensemble\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m: 'TrainArgs' object has no attribute 'multi_modality_ensemble'"
     ]
    }
   ],
   "source": [
    "args.multi_modality_ensemble"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "1a1086c1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['__annotations__',\n",
       " '__class__',\n",
       " '__deepcopy__',\n",
       " '__delattr__',\n",
       " '__dict__',\n",
       " '__dir__',\n",
       " '__doc__',\n",
       " '__eq__',\n",
       " '__format__',\n",
       " '__ge__',\n",
       " '__getattribute__',\n",
       " '__getstate__',\n",
       " '__gt__',\n",
       " '__hash__',\n",
       " '__init__',\n",
       " '__init_subclass__',\n",
       " '__le__',\n",
       " '__lt__',\n",
       " '__module__',\n",
       " '__ne__',\n",
       " '__new__',\n",
       " '__reduce__',\n",
       " '__reduce_ex__',\n",
       " '__repr__',\n",
       " '__setattr__',\n",
       " '__setstate__',\n",
       " '__sizeof__',\n",
       " '__str__',\n",
       " '__subclasshook__',\n",
       " '__weakref__',\n",
       " '_action_groups',\n",
       " '_actions',\n",
       " '_add_action',\n",
       " '_add_argument',\n",
       " '_add_arguments',\n",
       " '_add_container_actions',\n",
       " '_add_subparsers',\n",
       " '_annotations',\n",
       " '_atom_constraints',\n",
       " '_atom_descriptors_size',\n",
       " '_atom_features_size',\n",
       " '_bond_constraints',\n",
       " '_bond_descriptors_size',\n",
       " '_bond_features_size',\n",
       " '_check_conflict',\n",
       " '_check_value',\n",
       " '_configure',\n",
       " '_crossval_index_sets',\n",
       " '_defaults',\n",
       " '_explicit_bool',\n",
       " '_features_size',\n",
       " '_get_annotations',\n",
       " '_get_args',\n",
       " '_get_argument_names',\n",
       " '_get_class_dict',\n",
       " '_get_class_variables',\n",
       " '_get_formatter',\n",
       " '_get_from_self_and_super',\n",
       " '_get_handler',\n",
       " '_get_kwargs',\n",
       " '_get_nargs_pattern',\n",
       " '_get_option_tuples',\n",
       " '_get_optional_actions',\n",
       " '_get_optional_kwargs',\n",
       " '_get_positional_actions',\n",
       " '_get_positional_kwargs',\n",
       " '_get_value',\n",
       " '_get_values',\n",
       " '_handle_conflict_error',\n",
       " '_handle_conflict_resolve',\n",
       " '_has_negative_number_optionals',\n",
       " '_initialized',\n",
       " '_load_from_config_files',\n",
       " '_log_all',\n",
       " '_match_argument',\n",
       " '_match_arguments_partial',\n",
       " '_mutually_exclusive_groups',\n",
       " '_negative_number_matcher',\n",
       " '_num_tasks',\n",
       " '_option_string_actions',\n",
       " '_optionals',\n",
       " '_parse_known_args',\n",
       " '_parse_optional',\n",
       " '_parsed',\n",
       " '_pop_action_class',\n",
       " '_positionals',\n",
       " '_print_message',\n",
       " '_read_args_from_files',\n",
       " '_registries',\n",
       " '_registry_get',\n",
       " '_remove_action',\n",
       " '_subparser_buffer',\n",
       " '_subparsers',\n",
       " '_task_names',\n",
       " '_train_data_size',\n",
       " '_underscores_to_dashes',\n",
       " 'activation',\n",
       " 'add_argument',\n",
       " 'add_argument_group',\n",
       " 'add_help',\n",
       " 'add_mutually_exclusive_group',\n",
       " 'add_subparser',\n",
       " 'add_subparsers',\n",
       " 'adding_bond_types',\n",
       " 'adding_h',\n",
       " 'aggregation',\n",
       " 'aggregation_norm',\n",
       " 'allow_abbrev',\n",
       " 'args_from_configs',\n",
       " 'argument_buffer',\n",
       " 'argument_default',\n",
       " 'as_dict',\n",
       " 'atom_constraints',\n",
       " 'atom_descriptor_scaling',\n",
       " 'atom_descriptors',\n",
       " 'atom_descriptors_path',\n",
       " 'atom_descriptors_size',\n",
       " 'atom_features_size',\n",
       " 'atom_messages',\n",
       " 'atom_targets',\n",
       " 'batch_size',\n",
       " 'bias',\n",
       " 'bias_solvent',\n",
       " 'bond_constraints',\n",
       " 'bond_descriptor_scaling',\n",
       " 'bond_descriptors',\n",
       " 'bond_descriptors_path',\n",
       " 'bond_descriptors_size',\n",
       " 'bond_features_size',\n",
       " 'bond_targets',\n",
       " 'cache_cutoff',\n",
       " 'checkpoint_dir',\n",
       " 'checkpoint_frzn',\n",
       " 'checkpoint_path',\n",
       " 'checkpoint_paths',\n",
       " 'class_balance',\n",
       " 'class_variables',\n",
       " 'config_path',\n",
       " 'configure',\n",
       " 'conflict_handler',\n",
       " 'constraints_path',\n",
       " 'convert_arg_line_to_args',\n",
       " 'crossval_index_dir',\n",
       " 'crossval_index_file',\n",
       " 'crossval_index_sets',\n",
       " 'cuda',\n",
       " 'data_path',\n",
       " 'data_weights_path',\n",
       " 'dataset_type',\n",
       " 'depth',\n",
       " 'depth_solvent',\n",
       " 'description',\n",
       " 'device',\n",
       " 'dropout',\n",
       " 'empty_cache',\n",
       " 'ensemble_size',\n",
       " 'epilog',\n",
       " 'epochs',\n",
       " 'error',\n",
       " 'evidential_regularization',\n",
       " 'exit',\n",
       " 'explicit_h',\n",
       " 'extra_args',\n",
       " 'extra_metrics',\n",
       " 'features_generator',\n",
       " 'features_only',\n",
       " 'features_path',\n",
       " 'features_scaling',\n",
       " 'features_size',\n",
       " 'ffn_hidden_size',\n",
       " 'ffn_num_layers',\n",
       " 'final_lr',\n",
       " 'folds_file',\n",
       " 'format_help',\n",
       " 'format_usage',\n",
       " 'formatter_class',\n",
       " 'freeze_first_only',\n",
       " 'from_dict',\n",
       " 'fromfile_prefix_chars',\n",
       " 'frzn_ffn_layers',\n",
       " 'get_default',\n",
       " 'get_reproducibility_info',\n",
       " 'gpu',\n",
       " 'grad_clip',\n",
       " 'hidden_size',\n",
       " 'hidden_size_solvent',\n",
       " 'ignore_columns',\n",
       " 'init_lr',\n",
       " 'is_atom_bond_targets',\n",
       " 'keeping_atom_map',\n",
       " 'load',\n",
       " 'log_frequency',\n",
       " 'loss_function',\n",
       " 'max_data_size',\n",
       " 'max_lr',\n",
       " 'metric',\n",
       " 'metrics',\n",
       " 'minimize_score',\n",
       " 'mpn_shared',\n",
       " 'multiclass_num_classes',\n",
       " 'no_adding_bond_types',\n",
       " 'no_atom_descriptor_scaling',\n",
       " 'no_bond_descriptor_scaling',\n",
       " 'no_cache_mol',\n",
       " 'no_cuda',\n",
       " 'no_features_scaling',\n",
       " 'no_shared_atom_bond_ffn',\n",
       " 'num_folds',\n",
       " 'num_lrs',\n",
       " 'num_tasks',\n",
       " 'num_workers',\n",
       " 'number_of_molecules',\n",
       " 'overwrite_default_atom_features',\n",
       " 'overwrite_default_bond_features',\n",
       " 'parse_args',\n",
       " 'parse_intermixed_args',\n",
       " 'parse_known_args',\n",
       " 'parse_known_intermixed_args',\n",
       " 'phase_features_path',\n",
       " 'prefix_chars',\n",
       " 'print_help',\n",
       " 'print_usage',\n",
       " 'process_args',\n",
       " 'prog',\n",
       " 'pytorch_seed',\n",
       " 'quiet',\n",
       " 'reaction',\n",
       " 'reaction_mode',\n",
       " 'reaction_solvent',\n",
       " 'register',\n",
       " 'resume_experiment',\n",
       " 'save',\n",
       " 'save_dir',\n",
       " 'save_preds',\n",
       " 'save_smiles_splits',\n",
       " 'seed',\n",
       " 'separate_test_atom_descriptors_path',\n",
       " 'separate_test_bond_descriptors_path',\n",
       " 'separate_test_constraints_path',\n",
       " 'separate_test_features_path',\n",
       " 'separate_test_path',\n",
       " 'separate_test_phase_features_path',\n",
       " 'separate_val_atom_descriptors_path',\n",
       " 'separate_val_bond_descriptors_path',\n",
       " 'separate_val_constraints_path',\n",
       " 'separate_val_features_path',\n",
       " 'separate_val_path',\n",
       " 'separate_val_phase_features_path',\n",
       " 'set_defaults',\n",
       " 'shared_atom_bond_ffn',\n",
       " 'show_individual_scores',\n",
       " 'smiles_columns',\n",
       " 'spectra_activation',\n",
       " 'spectra_phase_mask_path',\n",
       " 'spectra_target_floor',\n",
       " 'split_key_molecule',\n",
       " 'split_sizes',\n",
       " 'split_type',\n",
       " 'target_columns',\n",
       " 'target_weights',\n",
       " 'task_names',\n",
       " 'test',\n",
       " 'test_fold_index',\n",
       " 'train_data_size',\n",
       " 'undirected',\n",
       " 'usage',\n",
       " 'use_input_features',\n",
       " 'val_fold_index',\n",
       " 'warmup_epochs',\n",
       " 'weights_ffn_num_layers']"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dir(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f78b6b7",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# get_data and bbbp_dir must be defined/imported before this cell runs:\n",
    "# on a fresh kernel it raised NameError: name 'get_data' is not defined.\n",
    "data_process = get_data(bbbp_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3adad4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_process.batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91354f48",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Constructing MPNN arguments\n",
    "from chemprop.args import TrainArgs\n",
    "from chemprop.data import get_task_names\n",
    "\n",
    "arguments = [\n",
    "    '--data_path', 'data/regression.csv',\n",
    "    '--dataset_type', 'regression',\n",
    "    '--save_dir', 'test_checkpoints_reg',\n",
    "    '--epochs', '5',\n",
    "    '--save_smiles_splits'\n",
    "]\n",
    "\n",
    "args = TrainArgs().parse_args(arguments)\n",
    "\n",
    "# parse_args leaves task_names unset, so args.num_tasks was 0: MoleculeModel warned\n",
    "# about zero-element tensors and produced an output of size (2039, 0).\n",
    "# Read the task names from the data file header, as chemprop's training code does.\n",
    "args.task_names = get_task_names(\n",
    "    path=args.data_path,\n",
    "    smiles_columns=args.smiles_columns,\n",
    "    target_columns=args.target_columns,\n",
    "    ignore_columns=args.ignore_columns,\n",
    ")\n",
    "# NOTE(review): these args describe data/regression.csv (regression), while the batch\n",
    "# fed to the model below comes from data_process (bbbp_dir) -- confirm they match."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23d88a77",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): MPNCEncoder is not imported by the top import cell, so it must be\n",
    "# defined elsewhere before this runs; confirm what the two 300 arguments mean.\n",
    "mpnc = MPNCEncoder(args, 300, 300).to(args.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d1666e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MoleculeModel is already imported by the top import cell. Follow args.device\n",
    "# (as MoleculeModel_Multiple does) instead of hard-coding cuda:0, so this cell\n",
    "# respects --no_cuda / --gpu.\n",
    "model = MoleculeModel(args).to(args.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a0bb9a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_input = data_process.batch_graph()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c17ed978",
   "metadata": {},
   "outputs": [],
   "source": [
    "res = model(test_input)\n",
    "# Smoke check: one prediction column per task (this was 0 before task_names was set).\n",
    "assert args.num_tasks > 0 and res.shape[1] == args.num_tasks, res.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aec5c024",
   "metadata": {},
   "outputs": [],
   "source": [
    "res"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
