{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6a46ac38",
   "metadata": {},
   "outputs": [
    {
     "ename": "ModuleNotFoundError",
     "evalue": "No module named 'DatasetModels'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_28087/1885569700.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mDatasetModels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mKMGCLDataset\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mKMGCLDataset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mDatasetModels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mGraphDataset\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mGraphDataset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mDatasetModels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mCNMRDataset\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCNMRDataset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mDatasetModels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mPeakDataset\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mPeakDataset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mDatasetModels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mImageDataset\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mImageDataset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'DatasetModels'"
     ]
    }
   ],
   "source": [
    "from DatasetModels.KMGCLDataset import KMGCLDataset\n",
    "from DatasetModels.GraphDataset import GraphDataset\n",
    "from DatasetModels.CNMRDataset import CNMRDataset\n",
    "from DatasetModels.PeakDataset import PeakDataset\n",
    "from DatasetModels.ImageDataset import ImageDataset\n",
    "from DatasetModels.SmilesDataset import SmilesDataset\n",
    "from DatasetModels.FingerPrintDataset import FingerPrintDataset\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from chemprop.args import TrainArgs\n",
    "from chemprop.constants import TEST_SCORES_FILE_NAME, TRAIN_LOGGER_NAME\n",
    "from chemprop.data.utils import get_data_cl\n",
    "from chemprop.data import get_data,get_task_names, MoleculeDataset, validate_dataset_type,MoleculeDataLoader\n",
    "from chemprop.utils import create_logger, makedirs, timeit, multitask_mean\n",
    "from chemprop.features import set_extra_atom_fdim, set_extra_bond_fdim, set_explicit_h, set_adding_hs, set_keeping_atom_map, set_reaction, reset_featurization_parameters\n",
    "\n",
    "import pandas as pd\n",
    "import os\n",
    "\n",
    "def build_dataset_loader(config, pass_args):\n",
    "    \"\"\"Assemble the multi-modal KMGCL dataset and wrap it in a torch DataLoader.\n",
    "\n",
    "    Args:\n",
    "        config: Settings object. This function reads ``dataset_file``,\n",
    "            ``batch_size``, ``shuffle`` and ``drop_last`` from it and also\n",
    "            hands the object to every dataset constructor.\n",
    "        pass_args: Argument object forwarded to ``get_data_cl``; its\n",
    "            ``data_path`` attribute is the path that loader receives.\n",
    "\n",
    "    Returns:\n",
    "        A ``DataLoader`` over the combined ``KMGCLDataset`` that batches with\n",
    "        the dataset's own ``collate_fn``.\n",
    "    \"\"\"\n",
    "    table = pd.read_csv(config.dataset_file)\n",
    "\n",
    "    # Graph modality: loaded through chemprop's get_data_cl instead of the\n",
    "    # project's GraphDataset (Zhengyang's modification).\n",
    "    # NOTE(review): this reads pass_args.data_path while every other modality\n",
    "    # reads config.dataset_file -- presumably both must describe the same\n",
    "    # molecules in the same row order; confirm against the callers.\n",
    "    molecule_graphs = get_data_cl(\n",
    "        path=pass_args.data_path,\n",
    "        args=pass_args,\n",
    "        smiles_columns=['smiles'],\n",
    "    )\n",
    "\n",
    "    # Remaining modalities come straight from CSV columns; the peak and NMR\n",
    "    # datasets are both built from the same 'cnmr' column.\n",
    "    peaks = PeakDataset(table['cnmr'], config)\n",
    "    spectra = CNMRDataset(table['cnmr'], config)\n",
    "    images = ImageDataset(table['image'], config)\n",
    "    fingerprints = FingerPrintDataset(table['fingerprint'], config)\n",
    "    smiles = SmilesDataset(table['smiles'], config)\n",
    "\n",
    "    # Positional order: graph, peak, nmr, image, smiles, fingerprint, config.\n",
    "    combined = KMGCLDataset(\n",
    "        molecule_graphs, peaks, spectra, images, smiles, fingerprints, config,\n",
    "    )\n",
    "\n",
    "    loader_options = dict(\n",
    "        batch_size=config.batch_size,\n",
    "        shuffle=config.shuffle,\n",
    "        drop_last=config.drop_last,\n",
    "        collate_fn=combined.collate_fn,\n",
    "    )\n",
    "    return DataLoader(combined, **loader_options)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dedcd639",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
