{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "\n",
    "condition_mapping_file = \"../../resources/CCSCM.csv\"\n",
    "procedure_mapping_file = \"../../resources/CCSPROC.csv\"\n",
    "drug_file = \"../../resources/ATC.csv\"\n",
    "\n",
    "condition_dict = {}\n",
    "with open(condition_mapping_file, newline='') as csvfile:\n",
    "    reader = csv.DictReader(csvfile)\n",
    "    for row in reader:\n",
    "        condition_dict[row['code']] = row['name'].lower()\n",
    "\n",
    "procedure_dict = {}\n",
    "with open(procedure_mapping_file, newline='') as csvfile:\n",
    "    reader = csv.DictReader(csvfile)\n",
    "    for row in reader:\n",
    "        procedure_dict[row['code']] = row['name'].lower()\n",
    "\n",
    "drug_dict = {}\n",
    "with open(drug_file, newline='') as csvfile:\n",
    "    reader = csv.DictReader(csvfile)\n",
    "    for row in reader:\n",
    "        if row['level'] == '3.0':\n",
    "            drug_dict[row['code']] = row['name'].lower()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re \n",
    "from ChatGPT import ChatGPT\n",
    "from ChatGPT import ChatGPT\n",
    "import json\n",
    "\n",
    "def extract_data_in_brackets(input_string):\n",
    "    pattern = r\"\\[(.*?)\\]\"\n",
    "    matches = re.findall(pattern, input_string)\n",
    "    return matches\n",
    "\n",
    "def divide_text(long_text, max_len=800):\n",
    "    sub_texts = []\n",
    "    start_idx = 0\n",
    "    while start_idx < len(long_text):\n",
    "        end_idx = start_idx + max_len\n",
    "        sub_text = long_text[start_idx:end_idx]\n",
    "        sub_texts.append(sub_text)\n",
    "        start_idx = end_idx\n",
    "    return sub_texts\n",
    "\n",
    "def filter_triples(triples):\n",
    "    chatgpt = ChatGPT()\n",
    "    response = chatgpt.chat(\n",
    "        f\"\"\"\n",
    "            I have a list of triples. I want to select 50 most important triples from the list.\n",
    "            The importance of a triple is based on how you think it will help imrpove healthcare prediction tasks (e.g., drug recommendation, mortality prediction, readmission prediction …).\n",
    "            If you think a triple is important, please keep it. Otherwise, please remove it.\n",
    "            You can also add triples from your background knowledge.\n",
    "            The total size of the updated list should be below 50.\n",
    "\n",
    "            triples: {triples}\n",
    "            updates:\n",
    "        \"\"\"\n",
    "        )\n",
    "    json_string = str(response)\n",
    "    json_data = json.loads(json_string)\n",
    "\n",
    "    filtered_triples = extract_data_in_brackets(json_data['content'])\n",
    "    return filtered_triples\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ChatGPT import ChatGPT\n",
    "import json\n",
    "\n",
    "def graph_gen(term: str, mode: str):\n",
    "    if mode == \"condition\":\n",
    "        example = \\\n",
    "        \"\"\"\n",
    "        Example:\n",
    "        prompt: systemic lupus erythematosus\n",
    "        updates: [[systemic lupus erythematosus, is an, autoimmune condition], [systemic lupus erythematosus, may cause, nephritis], [anti-nuclear antigen, is a test for, systemic lupus erythematosus], [systemic lupus erythematosus, is treated with, steroids], [methylprednisolone, is a, steroid]]\n",
    "        \"\"\"\n",
    "    elif mode == \"procedure\":\n",
    "        example = \\\n",
    "        \"\"\"\n",
    "        Example:\n",
    "        prompt: endoscopy\n",
    "        updates: [[endoscopy, is a, medical procedure], [endoscopy, used for, diagnosis], [endoscopic biopsy, is a type of, endoscopy], [endoscopic biopsy, can detect, ulcers]]\n",
    "        \"\"\"\n",
    "    elif mode == \"drug\":\n",
    "        example = \\\n",
    "        \"\"\"\n",
    "        Example:\n",
    "        prompt: iobenzamic acid\n",
    "        updates: [[iobenzamic acid, is a, drug], [iobenzamic acid, may have, side effects], [side effects, can include, nausea], [iobenzamic acid, used as, X-ray contrast agent], [iobenzamic acid, formula, C16H13I3N2O3]]\n",
    "        \"\"\"\n",
    "    chatgpt = ChatGPT()\n",
    "    response = chatgpt.chat(\n",
    "        f\"\"\"\n",
    "            Given a prompt (a medical condition/procedure/drug), extrapolate as many relationships as possible of it and provide a list of updates.\n",
    "            The relationships should be helpful for healthcare prediction (e.g., drug recommendation, mortality prediction, readmission prediction …)\n",
    "            Each update should be exactly in format of [ENTITY 1, RELATIONSHIP, ENTITY 2]. The relationship is directed, so the order matters.\n",
    "            Both ENTITY 1 and ENTITY 2 should be noun.\n",
    "            Any element in [ENTITY 1, RELATIONSHIP, ENTITY 2] should be conclusive, make it as short as possible.\n",
    "            Do this in both breadth and depth. Expand [ENTITY 1, RELATIONSHIP, ENTITY 2] until the size reaches 100.\n",
    "\n",
    "            {example}\n",
    "\n",
    "            prompt: {term}\n",
    "            updates:\n",
    "        \"\"\"\n",
    "        )\n",
    "    json_string = str(response)\n",
    "    json_data = json.loads(json_string)\n",
    "\n",
    "    triples = extract_data_in_brackets(json_data['content'])\n",
    "    outstr = \"\"\n",
    "    for triple in triples:\n",
    "        outstr += triple.replace('[', '').replace(']', '').replace(', ', '\\t') + '\\n'\n",
    "\n",
    "    return outstr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Future work - Including Clinical Notes\n",
    "# import json\n",
    "\n",
    "# with open('../../clinical_notes/subject_text_dict.json', 'r') as f:\n",
    "#     subject_text_dict = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "import os\n",
    "\n",
    "for key in tqdm(condition_dict.keys()):\n",
    "    file = f'../../graphs/condition/CCSCM/{key}.txt'\n",
    "    if os.path.exists(file):\n",
    "        with open(file=file, mode=\"r\", encoding='utf-8') as f:\n",
    "            prev_triples = f.read()\n",
    "        if len(prev_triples.split('\\n')) < 100:\n",
    "            outstr = graph_gen(term=condition_dict[key], mode=\"condition\")\n",
    "            outfile = open(file=file, mode='w', encoding='utf-8')\n",
    "            outstr = prev_triples + outstr\n",
    "            # print(outstr)\n",
    "            outfile.write(outstr)\n",
    "    else:\n",
    "        outstr = graph_gen(term=condition_dict[key], mode=\"condition\")\n",
    "        outfile = open(file=file, mode='w', encoding='utf-8')\n",
    "        outstr = outstr\n",
    "        # print(outstr)\n",
    "        outfile.write(outstr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 231/231 [47:01<00:00, 12.21s/it] \n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "import os\n",
    "\n",
    "for key in tqdm(procedure_dict.keys()):\n",
    "    file = f'../../graphs/procedure/CCSPROC/{key}.txt'\n",
    "    if os.path.exists(file):\n",
    "        with open(file=file, mode=\"r\", encoding='utf-8') as f:\n",
    "            prev_triples = f.read()\n",
    "        if len(prev_triples.split('\\n')) < 150:\n",
    "            outstr = graph_gen(term=procedure_dict[key], mode=\"procedure\")\n",
    "            outfile = open(file=file, mode='w', encoding='utf-8')\n",
    "            outstr = prev_triples + outstr\n",
    "            # print(outstr)\n",
    "            outfile.write(outstr)\n",
    "    else:\n",
    "        outstr = graph_gen(term=procedure_dict[key], mode=\"procedure\")\n",
    "        outfile = open(file=file, mode='w', encoding='utf-8')\n",
    "        outstr = outstr\n",
    "        # print(outstr)\n",
    "        outfile.write(outstr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "import os\n",
    "\n",
    "for key in tqdm(drug_dict.keys()):\n",
    "    file = f'../../graphs/drug/ATC5/{key}.txt'\n",
    "    if os.path.exists(file):\n",
    "        with open(file=file, mode=\"r\", encoding='utf-8') as f:\n",
    "            prev_triples = f.read()\n",
    "        if len(prev_triples.split('\\n')) < 150:\n",
    "            outstr = graph_gen(term=drug_dict[key], mode=\"drug\")\n",
    "            outfile = open(file=file, mode='w', encoding='utf-8')\n",
    "            outstr = prev_triples + outstr\n",
    "            # print(outstr)\n",
    "            outfile.write(outstr)\n",
    "        # continue\n",
    "    else:\n",
    "        outstr = graph_gen(term=drug_dict[key], mode=\"drug\")\n",
    "        outfile = open(file=file, mode='w', encoding='utf-8')\n",
    "        outstr = outstr\n",
    "        # print(outstr)\n",
    "        outfile.write(outstr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 269/269 [44:31<00:00,  9.93s/it] \n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "import os\n",
    "\n",
    "for key in tqdm(drug_dict.keys()):\n",
    "    file = f'../../graphs/drug/ATC3/{key}.txt'\n",
    "    if os.path.exists(file):\n",
    "        with open(file=file, mode=\"r\", encoding='utf-8') as f:\n",
    "            prev_triples = f.read()\n",
    "        if len(prev_triples.split('\\n')) < 150:\n",
    "            outstr = graph_gen(term=drug_dict[key], mode=\"drug\")\n",
    "            outfile = open(file=file, mode='w', encoding='utf-8')\n",
    "            outstr = prev_triples + outstr\n",
    "            # print(outstr)\n",
    "            outfile.write(outstr)\n",
    "        # continue\n",
    "    else:\n",
    "        outstr = graph_gen(term=drug_dict[key], mode=\"drug\")\n",
    "        outfile = open(file=file, mode='w', encoding='utf-8')\n",
    "        outstr = outstr\n",
    "        # print(outstr)\n",
    "        outfile.write(outstr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('kgc')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "3d0509d9aa81f2882b18eeb72d4d23c32cae9029e9b99f63cde94ba86c35ac78"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
