{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/pj20/miniconda3/envs/kgc/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "KG_processed.csv does not exist, processing...\n",
      "Iterating over relations...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 39/39 [00:14<00:00,  2.71it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Categorizing values...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 19/19 [00:00<00:00, 114.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iterating over node types...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:12<00:00,  1.82s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data preparation finished.\n"
     ]
    }
   ],
   "source": [
    "from k_level_pretrain import KData\n",
    "\n",
    "data_folder = './dataset_construction/'\n",
    "kg = KData(data_folder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "G = kg.G"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Graph(num_nodes={'disease': 4174, 'drug': 40544, 'effect/phenotype': 991, 'gene/protein': 13092, 'molecule': 88812, 'pathway': 35764, 'value': 101},\n",
       "      num_edges={('disease', 'rev_contraindication', 'molecule'): 16797, ('disease', 'rev_cooccurence_molecule_disease', 'molecule'): 110226, ('disease', 'rev_indication', 'molecule'): 4679, ('disease', 'rev_off-label use', 'molecule'): 1451, ('drug', 'rev_closematch', 'molecule'): 32619, ('drug', 'rev_to_drug', 'molecule'): 3083, ('drug', 'rev_type', 'molecule'): 4498, ('effect/phenotype', 'rev_drug_effect', 'molecule'): 52958, ('gene/protein', 'rev_cooccurence_molecule_gene/protein', 'molecule'): 107406, ('gene/protein', 'rev_drug_protein', 'molecule'): 20563, ('molecule', 'closematch', 'drug'): 32619, ('molecule', 'contraindication', 'disease'): 16797, ('molecule', 'cooccurence_molecule_disease', 'disease'): 110226, ('molecule', 'cooccurence_molecule_gene/protein', 'gene/protein'): 107406, ('molecule', 'cooccurence_molecule_molecule', 'molecule'): 290246, ('molecule', 'covalent_unit_count', 'value'): 6856, ('molecule', 'defined_atom_stereo_count', 'value'): 6770, ('molecule', 'defined_bond_stereo_count', 'value'): 6688, ('molecule', 'drug_drug', 'molecule'): 1918450, ('molecule', 'drug_effect', 'effect/phenotype'): 52958, ('molecule', 'drug_protein', 'gene/protein'): 20563, ('molecule', 'exact_mass', 'value'): 6875, ('molecule', 'has_component', 'molecule'): 45454, ('molecule', 'has_isotopologue', 'molecule'): 7374, ('molecule', 'has_parent', 'molecule'): 8684, ('molecule', 'has_same_connectivity', 'molecule'): 142206, ('molecule', 'has_stereoisomer', 'molecule'): 98704, ('molecule', 'hydrogen_bond_acceptor_count', 'value'): 6794, ('molecule', 'hydrogen_bond_donor_count', 'value'): 6775, ('molecule', 'in_pathway', 'pathway'): 291542, ('molecule', 'indication', 'disease'): 4679, ('molecule', 'isotope_atom_count', 'value'): 6817, ('molecule', 'molecular_weight', 'value'): 6815, ('molecule', 'mono_isotopic_weight', 'value'): 6842, ('molecule', 'neighbor_2d', 'molecule'): 88960, ('molecule', 'neighbor_3d', 'molecule'): 72114, ('molecule', 'non-hydrogen_atom_count', 'value'): 6673, ('molecule', 'off-label use', 'disease'): 1451, ('molecule', 'rotatable_bond_count', 'value'): 6833, ('molecule', 'structure_complexity', 'value'): 6736, ('molecule', 'tautomer_count', 'value'): 2936, ('molecule', 'to_drug', 'drug'): 3083, ('molecule', 'total_formal_charge', 'value'): 6646, ('molecule', 'tpsa', 'value'): 6833, ('molecule', 'type', 'drug'): 4498, ('molecule', 'undefined_atom_stereo_count', 'value'): 6748, ('molecule', 'undefined_bond_stereo_count', 'value'): 6654, ('molecule', 'xlogp3', 'value'): 1050, ('molecule', 'xlogp3-aa', 'value'): 3706, ('pathway', 'rev_in_pathway', 'molecule'): 291542, ('value', 'rev_covalent_unit_count', 'molecule'): 6856, ('value', 'rev_defined_atom_stereo_count', 'molecule'): 6770, ('value', 'rev_defined_bond_stereo_count', 'molecule'): 6688, ('value', 'rev_exact_mass', 'molecule'): 6875, ('value', 'rev_hydrogen_bond_acceptor_count', 'molecule'): 6794, ('value', 'rev_hydrogen_bond_donor_count', 'molecule'): 6775, ('value', 'rev_isotope_atom_count', 'molecule'): 6817, ('value', 'rev_molecular_weight', 'molecule'): 6815, ('value', 'rev_mono_isotopic_weight', 'molecule'): 6842, ('value', 'rev_non-hydrogen_atom_count', 'molecule'): 6673, ('value', 'rev_rotatable_bond_count', 'molecule'): 6833, ('value', 'rev_structure_complexity', 'molecule'): 6736, ('value', 'rev_tautomer_count', 'molecule'): 2936, ('value', 'rev_total_formal_charge', 'molecule'): 6646, ('value', 'rev_tpsa', 'molecule'): 6833, ('value', 'rev_undefined_atom_stereo_count', 'molecule'): 6748, ('value', 'rev_undefined_bond_stereo_count', 'molecule'): 6654, ('value', 'rev_xlogp3', 'molecule'): 1050, ('value', 'rev_xlogp3-aa', 'molecule'): 3706},\n",
       "      metagraph=[('disease', 'molecule', 'rev_contraindication'), ('disease', 'molecule', 'rev_cooccurence_molecule_disease'), ('disease', 'molecule', 'rev_indication'), ('disease', 'molecule', 'rev_off-label use'), ('molecule', 'drug', 'closematch'), ('molecule', 'drug', 'to_drug'), ('molecule', 'drug', 'type'), ('molecule', 'disease', 'contraindication'), ('molecule', 'disease', 'cooccurence_molecule_disease'), ('molecule', 'disease', 'indication'), ('molecule', 'disease', 'off-label use'), ('molecule', 'gene/protein', 'cooccurence_molecule_gene/protein'), ('molecule', 'gene/protein', 'drug_protein'), ('molecule', 'molecule', 'cooccurence_molecule_molecule'), ('molecule', 'molecule', 'drug_drug'), ('molecule', 'molecule', 'has_component'), ('molecule', 'molecule', 'has_isotopologue'), ('molecule', 'molecule', 'has_parent'), ('molecule', 'molecule', 'has_same_connectivity'), ('molecule', 'molecule', 'has_stereoisomer'), ('molecule', 'molecule', 'neighbor_2d'), ('molecule', 'molecule', 'neighbor_3d'), ('molecule', 'value', 'covalent_unit_count'), ('molecule', 'value', 'defined_atom_stereo_count'), ('molecule', 'value', 'defined_bond_stereo_count'), ('molecule', 'value', 'exact_mass'), ('molecule', 'value', 'hydrogen_bond_acceptor_count'), ('molecule', 'value', 'hydrogen_bond_donor_count'), ('molecule', 'value', 'isotope_atom_count'), ('molecule', 'value', 'molecular_weight'), ('molecule', 'value', 'mono_isotopic_weight'), ('molecule', 'value', 'non-hydrogen_atom_count'), ('molecule', 'value', 'rotatable_bond_count'), ('molecule', 'value', 'structure_complexity'), ('molecule', 'value', 'tautomer_count'), ('molecule', 'value', 'total_formal_charge'), ('molecule', 'value', 'tpsa'), ('molecule', 'value', 'undefined_atom_stereo_count'), ('molecule', 'value', 'undefined_bond_stereo_count'), ('molecule', 'value', 'xlogp3'), ('molecule', 'value', 'xlogp3-aa'), ('molecule', 'effect/phenotype', 'drug_effect'), ('molecule', 'pathway', 'in_pathway'), ('drug', 'molecule', 'rev_closematch'), ('drug', 'molecule', 'rev_to_drug'), ('drug', 'molecule', 'rev_type'), ('effect/phenotype', 'molecule', 'rev_drug_effect'), ('gene/protein', 'molecule', 'rev_cooccurence_molecule_gene/protein'), ('gene/protein', 'molecule', 'rev_drug_protein'), ('value', 'molecule', 'rev_covalent_unit_count'), ('value', 'molecule', 'rev_defined_atom_stereo_count'), ('value', 'molecule', 'rev_defined_bond_stereo_count'), ('value', 'molecule', 'rev_exact_mass'), ('value', 'molecule', 'rev_hydrogen_bond_acceptor_count'), ('value', 'molecule', 'rev_hydrogen_bond_donor_count'), ('value', 'molecule', 'rev_isotope_atom_count'), ('value', 'molecule', 'rev_molecular_weight'), ('value', 'molecule', 'rev_mono_isotopic_weight'), ('value', 'molecule', 'rev_non-hydrogen_atom_count'), ('value', 'molecule', 'rev_rotatable_bond_count'), ('value', 'molecule', 'rev_structure_complexity'), ('value', 'molecule', 'rev_tautomer_count'), ('value', 'molecule', 'rev_total_formal_charge'), ('value', 'molecule', 'rev_tpsa'), ('value', 'molecule', 'rev_undefined_atom_stereo_count'), ('value', 'molecule', 'rev_undefined_bond_stereo_count'), ('value', 'molecule', 'rev_xlogp3'), ('value', 'molecule', 'rev_xlogp3-aa'), ('pathway', 'molecule', 'rev_in_pathway')])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "G"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "65454"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(kg.df.x_id.unique())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "kg.df.to_csv('./KGE/KG_processed.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json \n",
    "\n",
    "with open('./dataset_construction/idx_map.json', 'r') as f:\n",
    "    idx_map = json.load(f)\n",
    "\n",
    "with open('./dataset_construction/pretrain_molecules/downstream_cid_set.json', 'r') as f:\n",
    "    downstream_cid_set = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 88811/88811 [00:00<00:00, 278327.62it/s]\n"
     ]
    }
   ],
   "source": [
    "import json \n",
    "import pickle\n",
    "import numpy as np\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "with open('./dataset_construction/idx_map.json') as f:\n",
    "    idx_map = json.load(f)\n",
    "with open('./dataset_construction/cid2emb.pkl', 'rb') as f:\n",
    "    cid2emb = pickle.load(f)\n",
    "\n",
    "molecule_idx2id = {value: key for key, value in idx_map['molecule'].items()}\n",
    "cnt = 0\n",
    "\n",
    "molecule_embedding = []\n",
    "for i in tqdm(range(len(molecule_idx2id))):\n",
    "    try:\n",
    "        cid = int(float(molecule_idx2id[i]))\n",
    "        molecule_embedding.append(cid2emb[cid]['emb'])\n",
    "    \n",
    "    # random embedding for molecules that are not in the pre-trained embedding\n",
    "    except:\n",
    "        molecule_embedding.append(np.random.rand(600))\n",
    "        cnt+=1\n",
    "        \n",
    "\n",
    "molecule_embedding = torch.tensor(np.array(molecule_embedding))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "kg.G.ndata['h']['molecule'] = molecule_embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.data import HeteroData\n",
    "\n",
    "df = kg.df\n",
    "hetero_data = HeteroData()\n",
    "\n",
    "print(\"processing edges...\")\n",
    "for relation_type in tqdm(df['relation'].unique()):\n",
    "    mask = df['relation'] == relation_type\n",
    "    xtype = df[mask]['x_type'].unique()[0]\n",
    "    ytype = df[mask]['y_type'].unique()[0]\n",
    "    \n",
    "    edge_index = torch.tensor(df[mask][['x_idx', 'y_idx']].to_numpy(), dtype=torch.long).t().contiguous()\n",
    "\n",
    "    hetero_data[xtype, relation_type, ytype].edge_index = edge_index\n",
    "\n",
    "print(\"processing nodes...\")\n",
    "unique_node_types = np.unique(np.append(np.unique(df.x_type.values), np.unique(df.y_type.values)))\n",
    "for node_type in tqdm(unique_node_types):\n",
    "    hetero_data[node_type].num_nodes = len(idx_map[node_type])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "hetero_data['molecule'].num_nodes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.nn import RGCNConv\n",
    "from torch_geometric.data import DataLoader\n",
    "\n",
    "class RGNN(torch.nn.Module):\n",
    "    def __init__(self, in_channels, hidden_channels, out_channels):\n",
    "        super(RGNN, self).__init__()\n",
    "        self.conv1 = RGCNConv(in_channels, hidden_channels, num_relations=len(hetero_data.edge_types))\n",
    "        self.conv2 = RGCNConv(hidden_channels, out_channels, num_relations=len(hetero_data.edge_types))\n",
    "\n",
    "    def forward(self, x, edge_index, edge_type):\n",
    "        x = F.relu(self.conv1(x, edge_index, edge_type))\n",
    "        x = F.dropout(x, p=0.5, training=self.training)\n",
    "        x = self.conv2(x, edge_index, edge_type)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "hetero_data['node_type']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from txgnn import GodeKGNN\n",
    "\n",
    "# gode_k = GodeKGNN(data = kg, \n",
    "#                     exp_name = 'Gode'\n",
    "#                     )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(67, 67)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# df_in = gode_k.df_valid[['x_idx', 'relation', 'y_idx']]\n",
    "# # df_in[df_in.relation == 'cooccurence_disease_molecule']\n",
    "\n",
    "# len(gode_k.df_train.relation.unique()), len(df_in.relation.unique())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from txgnn.utils import *\n",
    "\n",
    "# # gode_k.g_valid_pos, gode_k.g_valid_neg = evaluate_graph_construct(gode_k.df_valid, gode_k.G, \"fix_dst\", 1, gode_k.device)\n",
    "# g = gode_k.G\n",
    "# df_valid = gode_k.df_valid\n",
    "\n",
    "# out = {}\n",
    "# df_in = df_valid[['x_idx', 'relation', 'y_idx']]\n",
    "# for etype in g.canonical_etypes:\n",
    "#     try:\n",
    "#         df_temp = df_in[df_in.relation == etype[1]]\n",
    "#     except:\n",
    "#         print(etype[1])\n",
    "#     src = torch.Tensor(df_temp.x_idx.values).to(gode_k.device).to(dtype = torch.int64)\n",
    "#     dst = torch.Tensor(df_temp.y_idx.values).to(gode_k.device).to(dtype = torch.int64)\n",
    "#     out.update({etype: (src, dst)})\n",
    "\n",
    "# g = dgl.heterograph(out, num_nodes_dict={ntype: g.number_of_nodes(ntype) for ntype in g.ntypes})\n",
    "\n",
    "\n",
    "# # ng = Full_Graph_NegSampler(g_valid, 1, \"fix_dst\", gode_k.device)\n",
    "# weights = {\n",
    "#                 etype: (g.in_degrees(etype=etype) > 0).float()\n",
    "#                 for etype in g.canonical_etypes\n",
    "#         }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# gode_k.model_initialize(hidden_dim=100, \n",
    "#                         input_dim=100, \n",
    "#                         output_dim=100, \n",
    "#                       )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TxGNN.pretrain(n_epoch = 2, \n",
    "#                learning_rate = 1e-3,\n",
    "#                batch_size = 1024, \n",
    "#                train_print_per_n = 20)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('kgc')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "3d0509d9aa81f2882b18eeb72d4d23c32cae9029e9b99f63cde94ba86c35ac78"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
