{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pubchempy as pcp\n",
    "import json\n",
    "\n",
    "# Read the mapping file and keep only its 'id2name_drug' table (drug ID -> drug name).\n",
    "with open('./primekg_id_mapping.json', 'r') as mapping_file:\n",
    "    id_mapping = json.load(mapping_file)\n",
    "id2name_drug = id_mapping['id2name_drug']\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Drug ID -> list of compounds returned by a PubChem search on the drug's name\n",
    "drugbank2pubchem = {}\n",
    "\n",
    "for DBid, drug_name in tqdm(id2name_drug.items()):\n",
    "    resolved = False\n",
    "    # Keep asking until the lookup succeeds; only a busy server is retried\n",
    "    while not resolved:\n",
    "        try:\n",
    "            drugbank2pubchem[DBid] = pcp.get_compounds(drug_name, 'name')\n",
    "            resolved = True\n",
    "        except pcp.PubChemHTTPError as e:\n",
    "            # Any HTTP error other than a busy server is propagated to the caller\n",
    "            if 'ServerBusy' not in str(e):\n",
    "                raise\n",
    "            print('Server is busy. Retrying in 10 seconds...')\n",
    "            time.sleep(10)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the IsomericSMILES property for the compound with PubChem CID 1\n",
    "p = pcp.get_properties('IsomericSMILES', 1, 'cid')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'CID': 1, 'IsomericSMILES': 'CC(=O)OC(CC(=O)[O-])C[N+](C)(C)C'}"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the first record returned by the property lookup\n",
    "p[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7957/7957 [00:00<00:00, 397490.26it/s]\n"
     ]
    }
   ],
   "source": [
    "# Collect the CID of the first PubChem hit for every drug.\n",
    "# cnt counts the drugs for which no hit is available.\n",
    "cnt = 0\n",
    "prime_kg_cid_set = set()\n",
    "\n",
    "for key in tqdm(id2name_drug.keys()):\n",
    "    try:\n",
    "        prime_kg_cid_set.add(drugbank2pubchem[key][0].cid)\n",
    "    except (KeyError, IndexError):\n",
    "        # KeyError: the drug has no lookup entry; IndexError: the lookup returned no hits.\n",
    "        # Catching only these keeps Ctrl-C and genuine bugs visible.\n",
    "        cnt += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every drug with at least one PubChem hit, record the CID of the first hit.\n",
    "drugbank2pubchem_cid = {\n",
    "    db_id: hits[0].cid\n",
    "    for db_id, hits in drugbank2pubchem.items()\n",
    "    if len(hits) > 0\n",
    "}\n",
    "\n",
    "# Write the drug ID -> CID table to disk as JSON.\n",
    "with open(\"prime_kg_db_cid.json\", 'w') as out_file:\n",
    "    json.dump(drugbank2pubchem_cid, out_file, indent=6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "7083"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of distinct first-hit CIDs collected for the drugs\n",
    "len(prime_kg_cid_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import pandas as pd\n",
    "\n",
    "# Directory holding the CSV files, and every CSV file found directly inside it\n",
    "csv_dir = './finetune'\n",
    "csv_files = glob.glob(csv_dir + '/*.csv')\n",
    "\n",
    "# Union of the first-column values of all CSV files\n",
    "# (the first column is assumed to hold the SMILES strings)\n",
    "downstream_smiles_set = set()\n",
    "for csv_path in csv_files:\n",
    "    first_column = pd.read_csv(csv_path).iloc[:, 0]\n",
    "    downstream_smiles_set.update(first_column)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "53263"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of distinct first-column values gathered from the CSV files\n",
    "len(downstream_smiles_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46598/46598 [7:31:09<00:00,  1.72it/s]   \n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "from tqdm import tqdm\n",
    "\n",
    "# SMILES string -> CID of the first PubChem match, or 'null' when no match could be obtained\n",
    "smile_cid = {}\n",
    "\n",
    "for smile in tqdm(downstream_smiles_set):\n",
    "    while True:\n",
    "        try:\n",
    "            matches = pcp.get_compounds(smile, 'smiles')\n",
    "        except pcp.PubChemHTTPError as e:\n",
    "            if 'ServerBusy' in str(e):\n",
    "                # Busy server: wait briefly and retry the same SMILES\n",
    "                print('Server is busy. Retrying in 1 seconds...')\n",
    "                time.sleep(1)\n",
    "                continue\n",
    "            # Any other HTTP failure: mark the SMILES as unresolved and move on\n",
    "            smile_cid[smile] = 'null'\n",
    "            break\n",
    "        # Guard against an empty result list, which would raise IndexError on [0]\n",
    "        # and abort the whole (long-running) loop\n",
    "        smile_cid[smile] = matches[0].cid if matches else 'null'\n",
    "        break\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Persist the SMILES -> CID lookup to disk as JSON.\n",
    "with open('./smile2cid_1.json', 'w') as out_file:\n",
    "    json.dump(smile_cid, out_file, indent=6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Reload the saved SMILES -> CID lookup; only its values (the CIDs) are kept.\n",
    "with open('./smile2cid_1.json', 'r') as in_file:\n",
    "    smile_to_cid = json.load(in_file)\n",
    "cid_1 = smile_to_cid.values()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46598/46598 [4:50:04<00:00,  2.68it/s]   \n"
     ]
    }
   ],
   "source": [
    "import requests\n",
    "from collections import defaultdict\n",
    "from operator import itemgetter\n",
    "from tqdm import tqdm\n",
    "import re\n",
    "\n",
    "# Seconds to wait for one HTTP response before giving up on that request\n",
    "REQUEST_TIMEOUT = 30\n",
    "# Number of highest-scoring entities kept per CID and entity type\n",
    "TOP_K = 5\n",
    "\n",
    "# CIDs to query. Several SMILES can map to the same CID, so duplicates are dropped\n",
    "# (first-seen order is kept); otherwise the same entities would be appended repeatedly.\n",
    "cid_list = list(dict.fromkeys(cid_1))\n",
    "\n",
    "# Co-occurrence types requested for every CID\n",
    "endpoints = ['chemical_to_disease', 'chemical_to_gene', 'chemical_to_chemical']\n",
    "\n",
    "# Result tables, one per type: CID -> list of (entity_id, normalized_score)\n",
    "top_compounds = defaultdict(list)\n",
    "top_diseases = defaultdict(list)\n",
    "top_genes = defaultdict(list)\n",
    "results_by_endpoint = {\n",
    "    'chemical_to_chemical': top_compounds,\n",
    "    'chemical_to_disease': top_diseases,\n",
    "    'chemical_to_gene': top_genes,\n",
    "}\n",
    "\n",
    "# One table row of the response: (link target, link text, integer used as the score)\n",
    "pattern = r'<td><a href=\"(.*?)\">(.*?)</a></td>\\n\\s*<td>(\\d+)</td>'\n",
    "\n",
    "for cid in tqdm(cid_list):\n",
    "    # 'null' / None mark SMILES for which no CID was obtained\n",
    "    if cid == 'null' or cid is None:\n",
    "        continue\n",
    "    for endpoint in endpoints:\n",
    "        # Construct the API endpoint URL\n",
    "        url = f\"https://pubchem.ncbi.nlm.nih.gov/rest/rdf/query?graph=cooccurrence&entity1=CID{cid}&type={endpoint}\"\n",
    "\n",
    "        # Query the API; a failed or timed-out request skips this (CID, type) pair\n",
    "        # with a visible message instead of hanging or aborting the whole run\n",
    "        try:\n",
    "            response = requests.get(url, timeout=REQUEST_TIMEOUT)\n",
    "            response.raise_for_status()\n",
    "        except requests.RequestException as exc:\n",
    "            print(f'Skipping CID {cid} ({endpoint}): {exc}')\n",
    "            continue\n",
    "\n",
    "        edges = re.findall(pattern, response.text)\n",
    "\n",
    "        # Rank the edges by score (highest first) and keep the TOP_K strongest.\n",
    "        # The sort is stable, so an already-ranked response keeps its order.\n",
    "        sorted_edges = sorted(edges, key=lambda edge: int(edge[2]), reverse=True)[:TOP_K]\n",
    "        if not sorted_edges:\n",
    "            continue\n",
    "\n",
    "        # Scores are normalized by the largest score of this (CID, type) pair\n",
    "        max_score = int(sorted_edges[0][2])\n",
    "        for edge in sorted_edges:\n",
    "            # The entity ID is the last path segment of the link target\n",
    "            entity_id = edge[0].split('/')[-1]\n",
    "            score = int(edge[2])\n",
    "            # A maximum of 0 would divide by zero; report 0.0 in that case\n",
    "            normalized_score = score / max_score if max_score else 0.0\n",
    "            results_by_endpoint[endpoint][cid].append((entity_id, normalized_score))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write each result table to its own JSON file.\n",
    "for out_path, table in (\n",
    "    ('./chemical_chemical_2.json', top_compounds),\n",
    "    ('./chemical_disease_2.json', top_diseases),\n",
    "    ('./chemical_gene_2.json', top_genes),\n",
    "):\n",
    "    with open(out_path, 'w') as f:\n",
    "        json.dump(table, f, indent=6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.16 ('txgnn_env')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "79cb95e61c4f960f4e102f21c45668d32cb5c494b237694c15d64b50342e6e99"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
