{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_46857/52417269.py:3: DtypeWarning: Columns (4,5) have mixed types. Specify dtype option on import or set low_memory=False.\n",
      "  kg_ori = pd.read_csv('./KG_processed.csv')\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "kg_ori = pd.read_csv('./KG_processed.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>x_type</th>\n",
       "      <th>x_id</th>\n",
       "      <th>relation</th>\n",
       "      <th>y_type</th>\n",
       "      <th>y_id</th>\n",
       "      <th>value_level</th>\n",
       "      <th>x_idx</th>\n",
       "      <th>y_idx</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>molecule</td>\n",
       "      <td>23978.0</td>\n",
       "      <td>drug_protein</td>\n",
       "      <td>gene/protein</td>\n",
       "      <td>F8</td>\n",
       "      <td>F8</td>\n",
       "      <td>33625</td>\n",
       "      <td>5390.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>molecule</td>\n",
       "      <td>23978.0</td>\n",
       "      <td>drug_protein</td>\n",
       "      <td>gene/protein</td>\n",
       "      <td>F5</td>\n",
       "      <td>F5</td>\n",
       "      <td>33625</td>\n",
       "      <td>5388.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>molecule</td>\n",
       "      <td>977.0</td>\n",
       "      <td>drug_protein</td>\n",
       "      <td>gene/protein</td>\n",
       "      <td>HBA2</td>\n",
       "      <td>HBA2</td>\n",
       "      <td>86664</td>\n",
       "      <td>6360.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>molecule</td>\n",
       "      <td>82153.0</td>\n",
       "      <td>drug_protein</td>\n",
       "      <td>gene/protein</td>\n",
       "      <td>SERPINA6</td>\n",
       "      <td>SERPINA6</td>\n",
       "      <td>78526</td>\n",
       "      <td>11017.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>molecule</td>\n",
       "      <td>5311000.0</td>\n",
       "      <td>drug_protein</td>\n",
       "      <td>gene/protein</td>\n",
       "      <td>SERPINA6</td>\n",
       "      <td>SERPINA6</td>\n",
       "      <td>54050</td>\n",
       "      <td>11017.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2523862</th>\n",
       "      <td>molecule</td>\n",
       "      <td>439153.0</td>\n",
       "      <td>in_pathway</td>\n",
       "      <td>pathway</td>\n",
       "      <td>PWID1228018</td>\n",
       "      <td>PWID1228018</td>\n",
       "      <td>45820</td>\n",
       "      <td>9216.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2523863</th>\n",
       "      <td>molecule</td>\n",
       "      <td>644102.0</td>\n",
       "      <td>in_pathway</td>\n",
       "      <td>pathway</td>\n",
       "      <td>PWID1239777</td>\n",
       "      <td>PWID1239777</td>\n",
       "      <td>65691</td>\n",
       "      <td>20975.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2523864</th>\n",
       "      <td>molecule</td>\n",
       "      <td>668.0</td>\n",
       "      <td>in_pathway</td>\n",
       "      <td>pathway</td>\n",
       "      <td>PWID1324467</td>\n",
       "      <td>PWID1324467</td>\n",
       "      <td>67814</td>\n",
       "      <td>31394.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2523865</th>\n",
       "      <td>molecule</td>\n",
       "      <td>644102.0</td>\n",
       "      <td>in_pathway</td>\n",
       "      <td>pathway</td>\n",
       "      <td>PWID1234075</td>\n",
       "      <td>PWID1234075</td>\n",
       "      <td>65691</td>\n",
       "      <td>15274.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2523866</th>\n",
       "      <td>molecule</td>\n",
       "      <td>6176.0</td>\n",
       "      <td>in_pathway</td>\n",
       "      <td>pathway</td>\n",
       "      <td>PWID1234444</td>\n",
       "      <td>PWID1234444</td>\n",
       "      <td>63352</td>\n",
       "      <td>15643.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>2523867 rows × 8 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "           x_type       x_id      relation        y_type         y_id  \\\n",
       "0        molecule    23978.0  drug_protein  gene/protein           F8   \n",
       "1        molecule    23978.0  drug_protein  gene/protein           F5   \n",
       "2        molecule      977.0  drug_protein  gene/protein         HBA2   \n",
       "3        molecule    82153.0  drug_protein  gene/protein     SERPINA6   \n",
       "4        molecule  5311000.0  drug_protein  gene/protein     SERPINA6   \n",
       "...           ...        ...           ...           ...          ...   \n",
       "2523862  molecule   439153.0    in_pathway       pathway  PWID1228018   \n",
       "2523863  molecule   644102.0    in_pathway       pathway  PWID1239777   \n",
       "2523864  molecule      668.0    in_pathway       pathway  PWID1324467   \n",
       "2523865  molecule   644102.0    in_pathway       pathway  PWID1234075   \n",
       "2523866  molecule     6176.0    in_pathway       pathway  PWID1234444   \n",
       "\n",
       "         value_level  x_idx    y_idx  \n",
       "0                 F8  33625   5390.0  \n",
       "1                 F5  33625   5388.0  \n",
       "2               HBA2  86664   6360.0  \n",
       "3           SERPINA6  78526  11017.0  \n",
       "4           SERPINA6  54050  11017.0  \n",
       "...              ...    ...      ...  \n",
       "2523862  PWID1228018  45820   9216.0  \n",
       "2523863  PWID1239777  65691  20975.0  \n",
       "2523864  PWID1324467  67814  31394.0  \n",
       "2523865  PWID1234075  65691  15274.0  \n",
       "2523866  PWID1234444  63352  15643.0  \n",
       "\n",
       "[2523867 rows x 8 columns]"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "kg_ori"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## KG triples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2523867/2523867 [01:18<00:00, 31986.05it/s]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "kg_triples = \"\"\n",
    "\n",
    "for i in tqdm(range(len(kg_ori))):\n",
    "    head = kg_ori.x_type[i] + '_' + str(kg_ori.x_id[i])\n",
    "    relation = kg_ori.relation[i]\n",
    "    if kg_ori.y_type[i] == 'value':\n",
    "        tail = kg_ori.y_type[i] + '_' + str(kg_ori.y_id[i]).split('.')[0]\n",
    "    else:\n",
    "        tail = kg_ori.y_type[i] + '_' + str(kg_ori.y_id[i])\n",
    "\n",
    "    kg_triples += head + '\\t' + relation + '\\t' + tail + '\\n'\n",
    "\n",
    "with open('./graph.txt', 'w') as f:\n",
    "    f.write(kg_triples)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "184819"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "entity_set = set()\n",
    "\n",
    "triples = kg_triples.split('\\n')\n",
    "\n",
    "for triple in triples:\n",
    "    if triple == '':\n",
    "        continue\n",
    "    head, relation, tail = triple.split('\\t')\n",
    "    entity_set.add(head)\n",
    "    entity_set.add(tail)\n",
    "\n",
    "len(entity_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "entity_type = {\n",
    "    'molecule': 0,\n",
    "    'gene/protein': 1,\n",
    "    'disease': 2,\n",
    "    'effect/phenotype': 3,\n",
    "    'drug': 4,\n",
    "    'pathway': 5,\n",
    "    'value_1': 6,\n",
    "    'value_2': 7,\n",
    "    'value_3': 8,\n",
    "    'value_4': 9,\n",
    "    'value_5': 10,\n",
    "    'value_6': 11,\n",
    "    'value_7': 12,\n",
    "    'value_8': 13,\n",
    "    'value_9': 14,\n",
    "    'value_10': 15,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "with open('./id2entity.json', 'r') as f:\n",
    "    id2entity = json.load(f)\n",
    "with open('./id2relation.json', 'r') as f:\n",
    "    id2relation = json.load(f)\n",
    "\n",
    "entity2id = {value: key for key, value in id2entity.items()}\n",
    "relation2id = {value: key for key, value in id2relation.items()}\n",
    "\n",
    "with open('./entity2id.json', 'w') as f:\n",
    "    json.dump(entity2id, f, indent=6)\n",
    "with open('./relation2id.json', 'w') as f:\n",
    "    json.dump(relation2id, f, indent=6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2523867/2523867 [01:21<00:00, 31086.17it/s]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "entity_type_onehot = {}\n",
    "for i in tqdm(range(len(kg_ori))):\n",
    "    head = kg_ori.x_type[i] + '_' + str(kg_ori.x_id[i])\n",
    "    relation = kg_ori.relation[i]\n",
    "    if kg_ori.y_type[i] == 'value':\n",
    "        tail = kg_ori.y_type[i] + '_' + str(kg_ori.y_id[i]).split('.')[0]\n",
    "    else:\n",
    "        tail = kg_ori.y_type[i] + '_' + str(kg_ori.y_id[i])\n",
    "\n",
    "    if head not in entity_type_onehot:\n",
    "        entity_type_onehot[head] = np.zeros(16)\n",
    "        entity_type_onehot[head][entity_type[kg_ori.x_type[i]]] = 1\n",
    "\n",
    "    if tail not in entity_type_onehot:\n",
    "        entity_type_onehot[tail] = np.zeros(16)\n",
    "        if kg_ori.y_type[i] == 'value':\n",
    "            entity_type_onehot[tail][entity_type['value_' + str(kg_ori.value_level[i]).split('.')[0]]] = 1\n",
    "        else:\n",
    "            entity_type_onehot[tail][entity_type[kg_ori.y_type[i]]] = 1\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## entity type one-hot label (multi-class)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 184819/184819 [00:00<00:00, 906678.06it/s]\n"
     ]
    }
   ],
   "source": [
    "etype_onehot = np.zeros((len(entity_type_onehot), 16))\n",
    "\n",
    "for entity, onehot in tqdm(entity_type_onehot.items()):\n",
    "    etype_onehot[int(entity2id[entity])] = onehot\n",
    "\n",
    "np.save('./pretrain_data/ent_type_onehot.npy', etype_onehot)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 53263/53263 [00:00<00:00, 828771.70it/s]\n"
     ]
    }
   ],
   "source": [
    "valid_smile = set()\n",
    "valid_cid_ = set()\n",
    "\n",
    "with open('./downstream_smile2cid.json', 'r') as f:\n",
    "    smile2cid = json.load(f)\n",
    "\n",
    "for smile, cid in tqdm(smile2cid.items()):\n",
    "    molecule = f\"molecule_{cid}.0\"\n",
    "    if cid != None and molecule in entity2id:\n",
    "        valid_smile.add(smile)\n",
    "        valid_cid_.add(cid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [],
   "source": [
    "vs = 'smiles\\n'\n",
    "for smile in valid_smile:\n",
    "    vs += smile + '\\n'\n",
    "with open('./valid_smiles.csv', 'w') as f:\n",
    "    f.write(vs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## motif labels for valid smiles"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [],
   "source": [
    "# used grover's motif labler to label the motifs\n",
    "motifs = np.load('./motifs.npz')['features']\n",
    "valid_smile = vs.split('\\n')[1:-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [],
   "source": [
    "id2motifs = {}\n",
    "for i in range(len(motifs)):\n",
    "    idx = entity2id['molecule_' + str(smile2cid[valid_smile[i]]) + '.0']\n",
    "    id2motifs[idx] = motifs[i].tolist()\n",
    "\n",
    "with open('./pretrain_data/id2motifs.json', 'w') as f:\n",
    "    json.dump(id2motifs, f, indent=6)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## networkx knowledge graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2523867/2523867 [00:08<00:00, 283214.79it/s]\n"
     ]
    }
   ],
   "source": [
    "import networkx as nx\n",
    "\n",
    "G = nx.Graph()\n",
    "triple_set = set()\n",
    "with open('./graph.txt', 'r') as f:\n",
    "    lines = f.readlines()\n",
    "\n",
    "for line in tqdm(lines):\n",
    "    items = line.split('\\t')\n",
    "    if len(items) == 3:\n",
    "        h, r, t = items\n",
    "        t = t[:-1]\n",
    "        h = int(entity2id[h])\n",
    "        r = int(relation2id[r])\n",
    "        t = int(entity2id[t])\n",
    "        triple = (h, r, t)\n",
    "        if triple not in triple_set:\n",
    "            edge = (h, t)\n",
    "            G.add_edge(*edge, relation=r)\n",
    "            triple_set.add(triple)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [],
   "source": [
    "nx.write_gpickle(G, './pretrain_data/graph.gpickle')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('kgc')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "3d0509d9aa81f2882b18eeb72d4d23c32cae9029e9b99f63cde94ba86c35ac78"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
