{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# Dataset indexed by the cells below via `sample_dataset[i]`.\n",
    "# NOTE(review): pickle.load can execute arbitrary code - only load files you trust.\n",
    "dataset_path = '/data/pj20/exp_data/icd9cm_icd9proc/drugrec_dataset_umls.pkl'\n",
    "with open(dataset_path, 'rb') as f:\n",
    "    sample_dataset = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "\n",
    "condition_mapping_file = \"../../resources/ICD9CM.csv\"\n",
    "procedure_mapping_file = \"../../resources/ICD9PROC.csv\"\n",
    "drug_file = \"../../resources/ATC.csv\"\n",
    "\n",
    "\n",
    "def load_code_names(mapping_file, level=None):\n",
    "    \"\"\"Read a code-to-name CSV into a dict.\n",
    "\n",
    "    Keys are the `code` column with every '.' removed; values are the\n",
    "    lower-cased `name` column. When `level` is given, only rows whose\n",
    "    `level` column equals it (string comparison) are kept. If two rows\n",
    "    produce the same key, the later row wins.\n",
    "    \"\"\"\n",
    "    with open(mapping_file, newline='') as csvfile:\n",
    "        return {\n",
    "            row['code'].replace('.', ''): row['name'].lower()\n",
    "            for row in csv.DictReader(csvfile)\n",
    "            if level is None or row['level'] == level\n",
    "        }\n",
    "\n",
    "\n",
    "condition_dict = load_code_names(condition_mapping_file)\n",
    "procedure_dict = load_code_names(procedure_mapping_file)\n",
    "# Only rows whose `level` column is the text '3.0' are kept for drugs.\n",
    "drug_dict = load_code_names(drug_file, level='3.0')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def flatten(lst):\n",
    "    result = []\n",
    "    for item in lst:\n",
    "        if isinstance(item, list):\n",
    "            result.extend(flatten(item))\n",
    "        else:\n",
    "            result.append(item)\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of leading samples whose codes are collected.\n",
    "NUM_SAMPLES = 101\n",
    "\n",
    "condition_dict_sample = {}\n",
    "procedure_dict_sample = {}\n",
    "\n",
    "for i in range(NUM_SAMPLES):\n",
    "    sample = sample_dataset[i]\n",
    "    for condition in flatten(sample['conditions']):\n",
    "        if condition not in condition_dict_sample:\n",
    "            condition_dict_sample[condition] = condition_dict[condition]\n",
    "    for procedure in flatten(sample['procedures']):\n",
    "        if procedure not in procedure_dict_sample:\n",
    "            try:\n",
    "                procedure_dict_sample[procedure] = procedure_dict[procedure]\n",
    "            except KeyError:\n",
    "                # Code missing from the mapping: retry with its last character\n",
    "                # dropped and store the name under that shortened code.\n",
    "                # A second miss still raises KeyError.\n",
    "                shortened = procedure[:-1]\n",
    "                procedure_dict_sample[shortened] = procedure_dict[shortened]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re \n",
    "from ChatGPT import ChatGPT\n",
    "import json\n",
    "\n",
    "def extract_data_in_brackets(input_string):\n",
    "    pattern = r\"\\[(.*?)\\]\"\n",
    "    matches = re.findall(pattern, input_string)\n",
    "    return matches\n",
    "\n",
    "def divide_text(long_text, max_len=800):\n",
    "    sub_texts = []\n",
    "    start_idx = 0\n",
    "    while start_idx < len(long_text):\n",
    "        end_idx = start_idx + max_len\n",
    "        sub_text = long_text[start_idx:end_idx]\n",
    "        sub_texts.append(sub_text)\n",
    "        start_idx = end_idx\n",
    "    return sub_texts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ChatGPT import ChatGPT\n",
    "import json\n",
    "\n",
    "def graph_gen(term: str, mode: str):\n",
    "    if mode == \"condition\":\n",
    "        example = \\\n",
    "        \"\"\"\n",
    "        Example:\n",
    "        prompt: systemic lupus erythematosus\n",
    "        updates: [[systemic lupus erythematosus, is an, autoimmune condition], [systemic lupus erythematosus, may cause, nephritis], [anti-nuclear antigen, is a test for, systemic lupus erythematosus], [systemic lupus erythematosus, is treated with, steroids], [methylprednisolone, is a, steroid]]\n",
    "        \"\"\"\n",
    "    elif mode == \"procedure\":\n",
    "        example = \\\n",
    "        \"\"\"\n",
    "        Example:\n",
    "        prompt: endoscopy\n",
    "        updates: [[endoscopy, is a, medical procedure], [endoscopy, used for, diagnosis], [endoscopic biopsy, is a type of, endoscopy], [endoscopic biopsy, can detect, ulcers]]\n",
    "        \"\"\"\n",
    "    elif mode == \"drug\":\n",
    "        example = \\\n",
    "        \"\"\"\n",
    "        Example:\n",
    "        prompt: iobenzamic acid\n",
    "        updates: [[iobenzamic acid, is a, drug], [iobenzamic acid, may have, side effects], [side effects, can include, nausea], [iobenzamic acid, used as, X-ray contrast agent], [iobenzamic acid, formula, C16H13I3N2O3]]\n",
    "        \"\"\"\n",
    "    chatgpt = ChatGPT()\n",
    "    response = chatgpt.chat(\n",
    "        f\"\"\"\n",
    "            Given a prompt (a medical condition/procedure/drug), extrapolate as many relationships as possible of it and provide a list of updates.\n",
    "            The relationships should be helpful for healthcare prediction (e.g., drug recommendation, mortality prediction, readmission prediction …)\n",
    "            Each update should be exactly in format of [ENTITY 1, RELATIONSHIP, ENTITY 2]. The relationship is directed, so the order matters.\n",
    "            Both ENTITY 1 and ENTITY 2 should be noun.\n",
    "            Any element in [ENTITY 1, RELATIONSHIP, ENTITY 2] should be conclusive, make it as short as possible.\n",
    "            Do this in both breadth and depth. Expand [ENTITY 1, RELATIONSHIP, ENTITY 2] until the size reaches 100.\n",
    "\n",
    "            {example}\n",
    "\n",
    "            prompt: {term}\n",
    "            updates:\n",
    "        \"\"\"\n",
    "        )\n",
    "    json_string = str(response)\n",
    "    json_data = json.loads(json_string)\n",
    "\n",
    "    triples = extract_data_in_brackets(json_data['content'])\n",
    "    outstr = \"\"\n",
    "    for triple in triples:\n",
    "        outstr += triple.replace('[', '').replace(']', '').replace(', ', '\\t') + '\\n'\n",
    "\n",
    "    return outstr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 449/449 [1:25:51<00:00, 11.47s/it]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "import os\n",
    "\n",
    "# Existing files with fewer lines than this get another generation round appended.\n",
    "MIN_TRIPLE_LINES = 50\n",
    "graph_dir = '../../graphs/condition/ICD9CM_base_gpt'\n",
    "# Create the output directory up front so a missing folder is not discovered\n",
    "# only after a slow generation call.\n",
    "os.makedirs(graph_dir, exist_ok=True)\n",
    "\n",
    "for key in tqdm(condition_dict_sample.keys()):\n",
    "    file = f'{graph_dir}/{key}.txt'\n",
    "    prev_triples = ''\n",
    "    if os.path.exists(file):\n",
    "        with open(file=file, mode=\"r\", encoding='utf-8') as f:\n",
    "            prev_triples = f.read()\n",
    "        if len(prev_triples.split('\\n')) >= MIN_TRIPLE_LINES:\n",
    "            continue  # file already holds enough lines; skip the API call\n",
    "    # Generate before opening for write so a failed call cannot truncate the file.\n",
    "    outstr = graph_gen(term=condition_dict_sample[key], mode=\"condition\")\n",
    "    # The context manager closes (and flushes) the file even if the write fails.\n",
    "    with open(file=file, mode='w', encoding='utf-8') as outfile:\n",
    "        outfile.write(prev_triples + outstr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('kgc')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "3d0509d9aa81f2882b18eeb72d4d23c32cae9029e9b99f63cde94ba86c35ac78"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
