{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 252,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Data(edge_index=[2, 1951], relation=[1951], num_nodes=385),\n",
       " torch.Size([1951]),\n",
       " torch.Size([1951, 85]),\n",
       " torch.Size([1951, 16]),\n",
       " torch.Size([1951]))"
      ]
     },
     "execution_count": 252,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "molecule_mask = torch.tensor(ent_type[:,0][G_tg.edge_index[0]] == 1)\n",
    "ent_type = torch.tensor(np.load(f'{data_path}/ent_type_onehot.npy'))\n",
    "\n",
    "nodes, _, _, edge_mask = k_hop_subgraph(0, 1, G_tg.edge_index)\n",
    "double_mask = molecule_mask * edge_mask\n",
    "mask_idx = torch.where(double_mask)[0]\n",
    "edge_subgraph = G_tg.edge_subgraph(mask_idx)\n",
    "subgraph = edge_subgraph.subgraph(nodes)\n",
    "masked_node_ids = edge_subgraph.edge_index[0] # (num_masked_nodes,)\n",
    "motif_labels = motifs[masked_node_ids] # (num_masked_nodes, motif_len)\n",
    "node_labels = ent_type[masked_node_ids] # (num_masked_nodes, num_ent_type)\n",
    "rel_labels = subgraph.relation\n",
    "\n",
    "subgraph, masked_node_ids.shape, motif_labels.shape, node_labels.shape, rel_labels.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 251,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 0,  0,  0,  ..., 31, 31, 31])"
      ]
     },
     "execution_count": 251,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rel_labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 238,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([    0,     0,     1,     1,     1,     1,     1,     1,     1,     1,\n",
      "            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,\n",
      "            1,    29,    29,    29,    29,   661,   661,   661,   661,   661,\n",
      "          661,   661,   661,   661,   661,  2277,  2343,  2343,  2343, 18686,\n",
      "        18686, 18686, 18686, 18686, 18686, 18686, 28000, 28000, 28000, 28749,\n",
      "        29943, 29943, 29943, 29943, 29943, 29943, 33784, 33784, 33784, 33784,\n",
      "        33784, 39941, 39941, 39941, 39941, 39941, 39941, 42393, 42393, 42393,\n",
      "        42393, 42393, 42393, 42393, 42393, 42393, 42624, 42624, 42624, 42624,\n",
      "        42624, 42624, 43274, 43462, 48532, 48532, 48532, 48532, 48532, 48532,\n",
      "        58132, 58132, 58132, 58132, 58132, 58132, 58132, 60341, 65188])\n"
     ]
    }
   ],
   "source": [
    "print(G_tg.edge_index[:, mask_idx][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 239,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0,  0,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n",
       "          1,  1,  1,  2,  2,  2,  2,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  4,\n",
       "          5,  5,  5,  6,  6,  6,  6,  6,  6,  6,  7,  7,  7,  8,  9,  9,  9,  9,\n",
       "          9,  9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12,\n",
       "         12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 14, 15, 16, 16, 16, 16, 16, 16,\n",
       "         17, 17, 17, 17, 17, 17, 17, 18, 19],\n",
       "        [ 1,  7,  0,  2,  4,  5,  3, 18,  8,  6, 14, 11, 16,  7, 15, 13,  9, 10,\n",
       "         17, 19, 12,  1,  2,  5,  3,  1,  2, 16, 11, 12, 10,  6, 13,  9, 17,  1,\n",
       "          1,  2,  7,  1,  3, 13,  9, 17, 12, 11,  0,  1,  5,  1,  1,  3,  6, 13,\n",
       "         17, 12,  1,  3, 11, 16, 12,  1,  3,  6, 10, 16, 12,  1,  3,  6,  9, 10,\n",
       "         11, 16, 13, 17,  1,  3,  6,  9, 12, 17,  1,  1,  1,  3, 10, 11, 12, 17,\n",
       "          1,  3,  6,  9, 12, 13, 16,  1,  1]])"
      ]
     },
     "execution_count": 239,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "subgraph.edge_index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(3433, 3432)"
      ]
     },
     "execution_count": 107,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "s1.number_of_nodes(), s1.number_of_edges()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "16979"
      ]
     },
     "execution_count": 94,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "for c_id in tqdm(center_molecule_ids):\n",
    "    subgraph = s1 = get_subgraph(G, c_id)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/pj20/miniconda3/envs/kgc/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Getting everything prepared...\n",
      "Loading entity type labels...\n",
      "Loading center molecule motifs...\n",
      "Loading entire knowledge graph...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/pj20/miniconda3/envs/kgc/lib/python3.8/site-packages/torch_geometric/utils/convert.py:250: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at  ../torch/csrc/utils/tensor_new.cpp:201.)\n",
      "  data[key] = torch.tensor(value)\n",
      "/home/pj20/gode/k_level_pretrain/pretrain.py:55: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  molecule_mask = torch.tensor(ent_type[:,0][G_tg.edge_index[0]] == 1) # (num_edges,)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading molecule mask...\n",
      "Loading KGE embeddings...\n",
      "Loading KGE embeddings...\n",
      "Initializing model...\n"
     ]
    }
   ],
   "source": [
    "from pretrain import *\n",
    "\n",
    "# your credentials\n",
    "\n",
    "params = {\n",
    "    \"lr\": 1e-4,\n",
    "    \"hidden_dim\": 200,\n",
    "    \"epochs\": 100,\n",
    "    \"lambda\": \"08_15_10\"\n",
    "}\n",
    "logger = get_logger(lr=params['lr'], hidden_dim=params['hidden_dim'], epochs=params['epochs'], lambda_=params['lambda'])\n",
    "\n",
    "\n",
    "# Data path\n",
    "data_path = '../data_process/pretrain_data'\n",
    "print('Getting everything prepared...')\n",
    "ent_type, motifs, G_tg, center_molecule_ids, molecule_mask = get_everything(data_path)\n",
    "\n",
    "# Load KGE embeddings\n",
    "# return:\n",
    "# entity_embedding: (num_ent, emb_dim)\n",
    "# relation_embedding: (num_rel, emb_dim)\n",
    "emb_path = '/data/pj20/molkg_kge/transe'\n",
    "print('Loading KGE embeddings...')\n",
    "entity_embedding, relation_embedding = load_kge_embeddings(emb_path)\n",
    "\n",
    "# Initialize model\n",
    "print('Initializing model...')\n",
    "model = KGNN(\n",
    "    node_emb=entity_embedding,\n",
    "    rel_emb=relation_embedding,\n",
    "    num_nodes=ent_type.shape[0],\n",
    "    num_rels=39,\n",
    "    embedding_dim=512,\n",
    "    hidden_dim=200,\n",
    "    num_motifs=motifs.shape[1],\n",
    "    lambda_edge=0.8,\n",
    "    lambda_motif=1.5,\n",
    "    lambda_mol_class=1\n",
    ")\n",
    "\n",
    "# Train\n",
    "device = torch.device('cuda:1')\n",
    "model.to(device)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)\n",
    "epochs = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "loss: 0:   0%|          | 0/7605 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "CUDA error: device-side assert triggered\nCUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1.",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[1;32m/home/pj20/gode/k_level_pretrain/test.ipynb Cell 8\u001b[0m in \u001b[0;36m<cell line: 3>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      <a href='vscode-notebook-cell://ssh-remote%2Bsunlab-serv-03.cs.illinois.edu/home/pj20/gode/k_level_pretrain/test.ipynb#X41sdnNjb2RlLXJlbW90ZQ%3D%3D?line=0'>1</a>\u001b[0m \u001b[39m# Get dataloader\u001b[39;00m\n\u001b[1;32m      <a href='vscode-notebook-cell://ssh-remote%2Bsunlab-serv-03.cs.illinois.edu/home/pj20/gode/k_level_pretrain/test.ipynb#X41sdnNjb2RlLXJlbW90ZQ%3D%3D?line=1'>2</a>\u001b[0m train_loader, val_loader \u001b[39m=\u001b[39m get_dataloader(G_tg, center_molecule_ids, molecule_mask, motifs, ent_type, batch_size\u001b[39m=\u001b[39m\u001b[39m4\u001b[39m)\n\u001b[0;32m----> <a href='vscode-notebook-cell://ssh-remote%2Bsunlab-serv-03.cs.illinois.edu/home/pj20/gode/k_level_pretrain/test.ipynb#X41sdnNjb2RlLXJlbW90ZQ%3D%3D?line=2'>3</a>\u001b[0m train_loop(model, train_loader, val_loader, optimizer, device, epochs)\n",
      "File \u001b[0;32m~/gode/k_level_pretrain/pretrain.py:220\u001b[0m, in \u001b[0;36mtrain_loop\u001b[0;34m(model, train_loader, val_loader, optimizer, device, epochs, logger, run, early_stop)\u001b[0m\n\u001b[1;32m    218\u001b[0m early_stop_indicator \u001b[39m=\u001b[39m \u001b[39m0\u001b[39m\n\u001b[1;32m    219\u001b[0m \u001b[39mfor\u001b[39;00m epoch \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39m(\u001b[39m1\u001b[39m, epochs\u001b[39m+\u001b[39m\u001b[39m1\u001b[39m):\n\u001b[0;32m--> 220\u001b[0m     train_loss \u001b[39m=\u001b[39m train(model, train_loader, device, optimizer)\n\u001b[1;32m    221\u001b[0m     valid_loss, y_prob_edge_all, y_prob_motif_all, y_prob_node_all, y_true_edge_all, y_true_motif_all, y_true_node_all \u001b[39m=\u001b[39m validate(model, val_loader, device)\n\u001b[1;32m    223\u001b[0m     y_prob_edge_all, y_prob_motif_all, y_prob_node_all, y_true_edge_all, y_true_motif_all, y_true_node_all \u001b[39m=\u001b[39m detach_numpy(y_prob_edge_all), detach_numpy(y_prob_motif_all), detach_numpy(y_prob_node_all), detach_numpy(y_true_edge_all), detach_numpy(y_true_motif_all), detach_numpy(y_true_node_all)\n",
      "File \u001b[0;32m~/gode/k_level_pretrain/pretrain.py:129\u001b[0m, in \u001b[0;36mtrain\u001b[0;34m(model, train_loader, device, optimizer)\u001b[0m\n\u001b[1;32m    127\u001b[0m \u001b[39mfor\u001b[39;00m i, data \u001b[39min\u001b[39;00m pbar:\n\u001b[1;32m    128\u001b[0m     pbar\u001b[39m.\u001b[39mset_description(\u001b[39mf\u001b[39m\u001b[39m'\u001b[39m\u001b[39mloss: \u001b[39m\u001b[39m{\u001b[39;00mtraining_loss\u001b[39m}\u001b[39;00m\u001b[39m'\u001b[39m)\n\u001b[0;32m--> 129\u001b[0m     data \u001b[39m=\u001b[39m data\u001b[39m.\u001b[39;49mto(device)\n\u001b[1;32m    130\u001b[0m     optimizer\u001b[39m.\u001b[39mzero_grad()\n\u001b[1;32m    132\u001b[0m     \u001b[39m# Forward\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/envs/kgc/lib/python3.8/site-packages/torch_geometric/data/data.py:262\u001b[0m, in \u001b[0;36mBaseData.to\u001b[0;34m(self, device, non_blocking, *args)\u001b[0m\n\u001b[1;32m    258\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mto\u001b[39m(\u001b[39mself\u001b[39m, device: Union[\u001b[39mint\u001b[39m, \u001b[39mstr\u001b[39m], \u001b[39m*\u001b[39margs: List[\u001b[39mstr\u001b[39m],\n\u001b[1;32m    259\u001b[0m        non_blocking: \u001b[39mbool\u001b[39m \u001b[39m=\u001b[39m \u001b[39mFalse\u001b[39;00m):\n\u001b[1;32m    260\u001b[0m     \u001b[39mr\u001b[39m\u001b[39m\"\"\"Performs tensor device conversion, either for all attributes or\u001b[39;00m\n\u001b[1;32m    261\u001b[0m \u001b[39m    only the ones given in :obj:`*args`.\"\"\"\u001b[39;00m\n\u001b[0;32m--> 262\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mapply(\n\u001b[1;32m    263\u001b[0m         \u001b[39mlambda\u001b[39;49;00m x: x\u001b[39m.\u001b[39;49mto(device\u001b[39m=\u001b[39;49mdevice, non_blocking\u001b[39m=\u001b[39;49mnon_blocking), \u001b[39m*\u001b[39;49margs)\n",
      "File \u001b[0;32m~/miniconda3/envs/kgc/lib/python3.8/site-packages/torch_geometric/data/data.py:245\u001b[0m, in \u001b[0;36mBaseData.apply\u001b[0;34m(self, func, *args)\u001b[0m\n\u001b[1;32m    242\u001b[0m \u001b[39mr\u001b[39m\u001b[39m\"\"\"Applies the function :obj:`func`, either to all attributes or only\u001b[39;00m\n\u001b[1;32m    243\u001b[0m \u001b[39mthe ones given in :obj:`*args`.\"\"\"\u001b[39;00m\n\u001b[1;32m    244\u001b[0m \u001b[39mfor\u001b[39;00m store \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstores:\n\u001b[0;32m--> 245\u001b[0m     store\u001b[39m.\u001b[39;49mapply(func, \u001b[39m*\u001b[39;49margs)\n\u001b[1;32m    246\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\n",
      "File \u001b[0;32m~/miniconda3/envs/kgc/lib/python3.8/site-packages/torch_geometric/data/storage.py:183\u001b[0m, in \u001b[0;36mBaseStorage.apply\u001b[0;34m(self, func, *args)\u001b[0m\n\u001b[1;32m    180\u001b[0m \u001b[39mr\u001b[39m\u001b[39m\"\"\"Applies the function :obj:`func`, either to all attributes or only\u001b[39;00m\n\u001b[1;32m    181\u001b[0m \u001b[39mthe ones given in :obj:`*args`.\"\"\"\u001b[39;00m\n\u001b[1;32m    182\u001b[0m \u001b[39mfor\u001b[39;00m key, value \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mitems(\u001b[39m*\u001b[39margs):\n\u001b[0;32m--> 183\u001b[0m     \u001b[39mself\u001b[39m[key] \u001b[39m=\u001b[39m recursive_apply(value, func)\n\u001b[1;32m    184\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\n",
      "File \u001b[0;32m~/miniconda3/envs/kgc/lib/python3.8/site-packages/torch_geometric/data/storage.py:679\u001b[0m, in \u001b[0;36mrecursive_apply\u001b[0;34m(data, func)\u001b[0m\n\u001b[1;32m    677\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mrecursive_apply\u001b[39m(data: Any, func: Callable) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Any:\n\u001b[1;32m    678\u001b[0m     \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(data, Tensor):\n\u001b[0;32m--> 679\u001b[0m         \u001b[39mreturn\u001b[39;00m func(data)\n\u001b[1;32m    680\u001b[0m     \u001b[39melif\u001b[39;00m \u001b[39misinstance\u001b[39m(data, torch\u001b[39m.\u001b[39mnn\u001b[39m.\u001b[39mutils\u001b[39m.\u001b[39mrnn\u001b[39m.\u001b[39mPackedSequence):\n\u001b[1;32m    681\u001b[0m         \u001b[39mreturn\u001b[39;00m func(data)\n",
      "File \u001b[0;32m~/miniconda3/envs/kgc/lib/python3.8/site-packages/torch_geometric/data/data.py:263\u001b[0m, in \u001b[0;36mBaseData.to.<locals>.<lambda>\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m    258\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mto\u001b[39m(\u001b[39mself\u001b[39m, device: Union[\u001b[39mint\u001b[39m, \u001b[39mstr\u001b[39m], \u001b[39m*\u001b[39margs: List[\u001b[39mstr\u001b[39m],\n\u001b[1;32m    259\u001b[0m        non_blocking: \u001b[39mbool\u001b[39m \u001b[39m=\u001b[39m \u001b[39mFalse\u001b[39;00m):\n\u001b[1;32m    260\u001b[0m     \u001b[39mr\u001b[39m\u001b[39m\"\"\"Performs tensor device conversion, either for all attributes or\u001b[39;00m\n\u001b[1;32m    261\u001b[0m \u001b[39m    only the ones given in :obj:`*args`.\"\"\"\u001b[39;00m\n\u001b[1;32m    262\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mapply(\n\u001b[0;32m--> 263\u001b[0m         \u001b[39mlambda\u001b[39;00m x: x\u001b[39m.\u001b[39;49mto(device\u001b[39m=\u001b[39;49mdevice, non_blocking\u001b[39m=\u001b[39;49mnon_blocking), \u001b[39m*\u001b[39margs)\n",
      "\u001b[0;31mRuntimeError\u001b[0m: CUDA error: device-side assert triggered\nCUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1."
     ]
    }
   ],
   "source": [
    "# Get dataloader\n",
    "train_loader, val_loader = get_dataloader(G_tg, center_molecule_ids, molecule_mask, motifs, ent_type, batch_size=4)\n",
    "train_loop(model, train_loader, val_loader, optimizer, device, epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DataBatch(edge_index=[2, 122], relation=[122], num_nodes=22, masked_node_ids=[122], rel_ids=[122], center_molecule_id=[1], motif_labels=[122, 85], node_labels=[122, 16], batch=[22], ptr=[2])\n"
     ]
    }
   ],
   "source": [
    "for batch in train_loader:\n",
    "    print(batch)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 39])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.zeros((1, 39)).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 39])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.zeros(39).reshape(1, -1).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(2, dtype=torch.int32)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.where(torch.tensor([0, 1, 4, 5]) == 4)[0][0].int()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from pretrain import *\n",
    "\n",
    "def get_everything(data_path):\n",
    "    # Training Labels\n",
    "    ## Load entity type labels\n",
    "    print('Loading entity type labels...')\n",
    "    ent_type = torch.tensor(np.load(f'{data_path}/ent_type_onehot.npy')) # (num_ent, num_ent_type)\n",
    "\n",
    "    ## Load center molecule motifs\n",
    "    print('Loading center molecule motifs...')\n",
    "    motifs = []\n",
    "    with open(f'{data_path}/id2motifs.json', 'r') as f:\n",
    "        id2motifs = json.load(f)\n",
    "    motif_len = len(id2motifs['0'])\n",
    "    for i in range(len(ent_type)):\n",
    "        if str(i) in id2motifs.keys():\n",
    "            motifs.append(np.array(id2motifs[str(i)]))\n",
    "        else:\n",
    "            motifs.append(np.array([0] * motif_len))\n",
    "\n",
    "    motifs = torch.tensor(np.array(motifs), dtype=torch.long) # (num_ent, motif_len)\n",
    "\n",
    "    ## Center molecule ids\n",
    "    center_molecule_ids = torch.tensor([int(key) for key in id2motifs.keys()])\n",
    "\n",
    "    # Entire Knowledge Graph (MolKG)\n",
    "    print('Loading entire knowledge graph...')\n",
    "    G = nx.read_gpickle(f'{data_path}/graph.gpickle')\n",
    "    G_tg = from_networkx(G)\n",
    "\n",
    "    # molecule_mask\n",
    "    print('Loading molecule mask...')\n",
    "    molecule_mask = torch.tensor(ent_type[:,0][G_tg.edge_index[0]] == 1) # (num_edges,)\n",
    "\n",
    "    return ent_type, motifs, G_tg, center_molecule_ids, molecule_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading entity type labels...\n",
      "Loading center molecule motifs...\n",
      "Loading entire knowledge graph...\n",
      "Loading molecule mask...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_107140/1126395188.py:34: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  molecule_mask = torch.tensor(ent_type[:,0][G_tg.edge_index[0]] == 1) # (num_edges,)\n"
     ]
    }
   ],
   "source": [
    "ent_type, motifs, G_tg, center_molecule_ids, molecule_mask = get_everything(\"../data_process/pretrain_data\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_107140/2937019473.py:1: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  molecule_mask = torch.tensor(ent_type[:,0] == 1) # (num_edges,)\n"
     ]
    }
   ],
   "source": [
    "molecule_mask = torch.tensor(ent_type[:,0] == 1) # (num_edges,)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([88811])"
      ]
     },
     "execution_count": 75,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.where(molecule_mask)[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.utils import k_hop_subgraph\n",
    "\n",
    "nodes, _, _, edge_mask = k_hop_subgraph(int(center_molecule_ids[0]), 1, G_tg.edge_index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "double_mask = molecule_mask * edge_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1192001"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(torch.where(edge_mask)[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 131,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_path = '/data/pj20/molkg/pretrain_data'\n",
    "with open(f'{data_path}/graph.pt', 'rb') as f:\n",
    "    G_tg = torch.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 134,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(70403)"
      ]
     },
     "execution_count": 134,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "G_tg.edge_index[1][1332817]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 193,
   "metadata": {},
   "outputs": [],
   "source": [
    "molecule_head_mask = ent_type[:,0][G_tg.edge_index[0]] == 1 # (num_edges,)\n",
    "non_molecule_head_mask = ent_type[:,0][G_tg.edge_index[0]] == 0 # (num_edges,)\n",
    "molecule_tail_mask = ent_type[:,0][G_tg.edge_index[1]] == 1 # (num_edges,)\n",
    "non_molecule_tail_mask = ent_type[:,0][G_tg.edge_index[1]] == 0 # (num_edges,)\n",
    "\n",
    "mol2mol_mask = molecule_head_mask * molecule_tail_mask\n",
    "mol2nonmol_mask = molecule_head_mask * non_molecule_tail_mask\n",
    "nonmol2mol_mask = non_molecule_head_mask * molecule_tail_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([False, False, False,  ..., False, False, False])"
      ]
     },
     "execution_count": 102,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "916534"
      ]
     },
     "execution_count": 104,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(torch.where(mol2nonmol_mask)[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([184819, 16])"
      ]
     },
     "execution_count": 105,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ent_type.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(True)"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "molecule_mask[center_molecule_ids[10000]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 135,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(43924)"
      ]
     },
     "execution_count": 135,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "center_molecule_ids[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 138,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.utils import k_hop_subgraph\n",
    "\n",
    "nodes, _, _, edge_mask = k_hop_subgraph(int(center_molecule_ids[1]), 0, G_tg.edge_index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 190,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.utils import to_undirected, k_hop_subgraph\n",
    "\n",
    "def get_k_hop_nodes(node_index, num_hops, edge_index):\n",
    "    # Convert the edge_index to undirected\n",
    "    edge_index = to_undirected(edge_index)\n",
    "\n",
    "    # Compute the k-hop subgraph\n",
    "    node_idx_k_hop, _, _, _ = k_hop_subgraph(node_index, num_hops, edge_index, relabel_nodes=False, num_nodes=None, flow='source_to_target')\n",
    "\n",
    "    return node_idx_k_hop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 218,
   "metadata": {},
   "outputs": [],
   "source": [
    "node_ids = get_k_hop_nodes(int(center_molecule_ids[8]), 1, G_tg.edge_index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 225,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1, 3, 4, 1, 4, 4, 0, 1, 4],\n",
       "        [2, 2, 2, 3, 3, 1, 1, 3, 2]])"
      ]
     },
     "execution_count": 225,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "G_tg.subgraph(torch.tensor([4, 5, 6, 7, 8])).edge_index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 221,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 2,  4,  5,  4,  4,  3,  4,  5,  3,  4,  4,  4,  4,  4,  4,  4,  4,  4,\n",
       "          4,  0,  2,  5,  4,  4,  4,  4,  4,  4,  4],\n",
       "        [ 1, 12, 10, 11, 10, 10, 12,  0,  2,  2,  0,  1,  5,  3, 13, 17, 16, 14,\n",
       "         15,  7,  6,  7,  7,  6, 18,  9,  8, 19, 20]])"
      ]
     },
     "execution_count": 221,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sub_1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 129,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/pj20/miniconda3/envs/kgc/lib/python3.8/site-packages/torch_geometric/data/storage.py:303: UserWarning: Unable to accurately infer 'num_nodes' from the attribute set '{'relation', 'edge_index', 'rel_label'}'. Please explicitly set 'num_nodes' as an attribute of 'data' to suppress this warning\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([], size=(2, 0), dtype=torch.int64)"
      ]
     },
     "execution_count": 129,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "G_tg.subgraph(torch.tensor([43924])).edge_index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 130,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 130,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(torch.where(edge_mask)[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "metadata": {},
   "outputs": [],
   "source": [
    "double_mask_mol2mol = mol2mol_mask * edge_mask\n",
    "double_mask_mol2nonmol = mol2nonmol_mask * edge_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 537424,  537430,  537431,  ..., 2100066, 2118004, 2119986])"
      ]
     },
     "execution_count": 111,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.where(double_mask_mol2mol)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(1251)"
      ]
     },
     "execution_count": 82,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "center_molecule_ids[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
       "       dtype=torch.float64)"
      ]
     },
     "execution_count": 92,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ent_type[85600]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([4666855]), torch.Size([4666855]), torch.Size([4666855]))"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "edge_mask.shape, molecule_mask.shape, double_mask.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "mask_idx = torch.where(double_mask == True)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([    10,     11,     12,  ..., 112293, 125440, 177470])"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nodes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Data(edge_index=[2, 1192001], relation=[1192001], rel_label=[1192001, 39], num_nodes=1637)"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "G_tg.subgraph(nodes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 226,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'True'"
      ]
     },
     "execution_count": 226,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "f'{True}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('kgc')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "3d0509d9aa81f2882b18eeb72d4d23c32cae9029e9b99f63cde94ba86c35ac78"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
