{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bb32b9b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from NG import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "7c770dc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path to one pickled sample. Despite the `_dir` suffix this names a file, not a directory.\n",
    "bace_dir = '../bace/1.pickle'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "c5efdb2c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data(x=[47, 3], edge_index=[2, 100], edge_attr=[100, 2])\n"
     ]
    }
   ],
   "source": [
    "import pickle\n",
    "\n",
    "# SECURITY: pickle.load can execute arbitrary code stored in the file,\n",
    "# so only unpickle files that come from a trusted source.\n",
    "with open(bace_dir, mode='rb') as pickle_file:\n",
    "    loaded_data = pickle.load(pickle_file)\n",
    "\n",
    "# Print the deserialized object's repr as a quick sanity check.\n",
    "print(loaded_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "671a9caf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Prefer the first GPU, but fall back to CPU so the notebook still runs on machines without CUDA.\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Build the node-level encoder, then wrap it in the graph-level encoder; both are moved to `device`.\n",
    "# NOTE(review): 5 and 128 are presumably the layer count and embedding width - confirm against NG.\n",
    "nodeEncoder = GNNNodeEncoder(5, 128, JK=\"last\", gnn_type='gin', aggr='add').to(device)\n",
    "graph_model = GNNGraphEncoder(nodeEncoder, 128, graph_pooling=\"add\").to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "778f58fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run one forward pass on the sample, first moving it to whichever device the model's weights are on.\n",
    "# Reading the device from the model avoids a second hard-coded device string that could drift out of sync.\n",
    "model_device = next(graph_model.parameters()).device\n",
    "test_output = graph_model(loaded_data.to(model_device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "559387d0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([36, 128])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape of the second element returned by the forward pass (36 x 128 in the saved run).\n",
    "# NOTE(review): the loaded sample reports 47 rows in `x`, yet this tensor has 36 rows - confirm what its rows represent.\n",
    "test_output[1].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "6bb86ec3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 128])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape of the first element returned by the forward pass (1 x 128 in the saved run).\n",
    "# Presumably the pooled graph-level embedding - TODO confirm against NG.GNNGraphEncoder.\n",
    "test_output[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "c99080b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from chemprop.models import MoleculeModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "1650c4cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path to the BBBP CSV file (a file path, despite the `_dir` suffix); it is passed to get_data.\n",
    "bbbp_dir = '../../data/bbbp.csv'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d98d239",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `get_data` was never imported, so this cell previously failed with a NameError.\n",
    "# chemprop (v1.x) exports the loader from `chemprop.data`; its first positional argument is the CSV path.\n",
    "# NOTE(review): assumes chemprop's get_data is the intended loader, matching the chemprop import in this notebook.\n",
    "from chemprop.data import get_data\n",
    "\n",
    "data_process = get_data(bbbp_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3678c953",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
