{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/pj20/miniconda3/envs/kgc/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import pickle\n",
    "\n",
    "with open('/data/pj20/exp_data/icd9cm_icd9proc/drugrec_dataset_umls.pkl', 'rb') as f:\n",
    "    sample_dataset = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "\n",
    "condition_mapping_file = \"../../resources/ICD9CM.csv\"\n",
    "procedure_mapping_file = \"../../resources/ICD9PROC.csv\"\n",
    "drug_file = \"../../resources/ATC.csv\"\n",
    "\n",
    "condition_dict = {}\n",
    "with open(condition_mapping_file, newline='') as csvfile:\n",
    "    reader = csv.DictReader(csvfile)\n",
    "    for row in reader:\n",
    "        condition_dict[row['code'].replace('.', '')] = row['name'].lower()\n",
    "\n",
    "procedure_dict = {}\n",
    "with open(procedure_mapping_file, newline='') as csvfile:\n",
    "    reader = csv.DictReader(csvfile)\n",
    "    for row in reader:\n",
    "        procedure_dict[row['code'].replace('.', '')] = row['name'].lower()\n",
    "\n",
    "drug_dict = {}\n",
    "with open(drug_file, newline='') as csvfile:\n",
    "    reader = csv.DictReader(csvfile)\n",
    "    for row in reader:\n",
    "        if row['level'] == '3.0':\n",
    "            drug_dict[row['code'].replace('.', '')] = row['name'].lower()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def flatten(lst):\n",
    "    result = []\n",
    "    for item in lst:\n",
    "        if isinstance(item, list):\n",
    "            result.extend(flatten(item))\n",
    "        else:\n",
    "            result.append(item)\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "condition_dict_sample = {}\n",
    "procedure_dict_sample = {}\n",
    "\n",
    "for i in range(101):\n",
    "    sample = sample_dataset[i]\n",
    "    for condition in flatten(sample['conditions']):\n",
    "        if condition not in condition_dict_sample:\n",
    "            condition_dict_sample[condition] = condition_dict[condition]\n",
    "    for procedure in flatten(sample['procedures']):\n",
    "        if procedure not in procedure_dict_sample:\n",
    "            try:\n",
    "                procedure_dict_sample[procedure] = procedure_dict[procedure]\n",
    "            except:\n",
    "                procedure_dict_sample[procedure[:-1]] = procedure_dict[procedure[:-1]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import os\n",
    "\n",
    "# for i in range(1, 101):\n",
    "#     folder_name = f\"../../graphs/patient_samples/{i}\"\n",
    "#     os.mkdir(folder_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'visit_id': '184167',\n",
       " 'patient_id': '10',\n",
       " 'conditions': [['V3000', '7742', '76525', '76515', 'V290']],\n",
       " 'procedures': [['9983', '9915', '966']],\n",
       " 'drugs': ['J01C', 'J01G', 'V06D', 'B05X', 'B03A'],\n",
       " 'drugs_all': [['J01C', 'J01G', 'V06D', 'B05X', 'B03A']],\n",
       " 'drugs_ind': tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,\n",
       "         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "         0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],\n",
       "        dtype=torch.float64)}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sample_dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "with open('../../graphs/condition/ICD9CM_base_umls/ent2word.json', 'r') as f:\n",
    "    umls_ent2word = json.load(f)   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "297927"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(umls_ent2word)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 101/101 [00:00<00:00, 139.40it/s]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "patient_sample_path = \"../../graphs/patient_samples\"\n",
    "\n",
    "for j in tqdm(range(101)):\n",
    "    sample = sample_dataset[i]\n",
    "    patient_path = f\"{patient_sample_path}/{j}\"\n",
    "    patient_desp = f\"{patient_path}/description.txt\"\n",
    "    patient_gpt_graph = f\"{patient_path}/gpt_graph.csv\"\n",
    "    patient_umls_graph = f\"{patient_path}/umls_graph.csv\"\n",
    "\n",
    "    triple_set_gpt = set()\n",
    "    triple_set_umls = set()\n",
    "\n",
    "    conditions  = flatten(sample['conditions'])\n",
    "    procedures = flatten(sample['procedures'])\n",
    "    drugs = flatten(sample['drugs'])\n",
    "\n",
    "    ### BEGIN Write description\n",
    "    desp_condition = \"\"\n",
    "    for i in range(len(conditions)):\n",
    "        desp_condition += f\"{i}: \" + condition_dict_sample[conditions[i]] + \",\\n\"\n",
    "    desp_condition = desp_condition[:-2]\n",
    "\n",
    "    desp_procedure = \"\"\n",
    "    for i in range(len(procedures)):\n",
    "        try:\n",
    "            desp_procedure += f\"{i}: \" + procedure_dict_sample[procedures[i]] + \",\\n\"\n",
    "        except:\n",
    "            desp_procedure += f\"{i}: \" + procedure_dict_sample[procedures[i][:-1]] + \",\\n\"\n",
    "    desp_procedure = desp_procedure[:-2]\n",
    "\n",
    "    desp_drug = \"\"\n",
    "    for i in range(len(drugs)):\n",
    "        desp_drug += f\"{i}: \" + drug_dict[drugs[i]] + \",\\n\"\n",
    "    desp_drug = desp_drug[:-2]\n",
    "    \n",
    "    desp_all = f\"Patient ID: {j}\\nConditions:\\n[\\n{desp_condition}\\n]\\nProcedures:\\n[\\n{desp_procedure}\\n]\\nDrugs:\\n[\\n{desp_drug}\\n]\"\n",
    "    with open(patient_desp, 'w') as f:\n",
    "        f.write(desp_all)\n",
    "    ### END Write description\n",
    "\n",
    "\n",
    "    ### BEGIN Write graph\n",
    "    for condition in conditions:\n",
    "        cond_file_gpt = f'../../graphs/condition/ICD9CM_base_gpt/{condition}.txt'\n",
    "        cond_file_umls = f'../../graphs/condition/ICD9CM_base_umls/{condition}.txt'\n",
    "        with open(cond_file_gpt, 'r') as f:\n",
    "            lines = f.readlines()\n",
    "        for line in lines:\n",
    "            items = line.split('\\t')\n",
    "            if len(items) == 3:\n",
    "                h, r, t = items\n",
    "                t = t[:-1]\n",
    "                triple = (h, r, t)\n",
    "                triple_set_gpt.add(triple)\n",
    "        with open(cond_file_umls, 'r') as f:\n",
    "            lines = f.readlines()\n",
    "        for line in lines:\n",
    "            items = line.split('\\t')\n",
    "            if len(items) == 3:\n",
    "                h, r, t = items\n",
    "                t = t[:-1]\n",
    "                if r == \"self\":\n",
    "                    continue\n",
    "                triple = (umls_ent2word[h], r, umls_ent2word[t])\n",
    "                triple_set_umls.add(triple)\n",
    "\n",
    "    for procedure in procedures:\n",
    "        proc_file_gpt = f'../../graphs/procedure/ICD9PROC_base_gpt/{procedure}.txt'\n",
    "        proc_file_umls = f'../../graphs/condition/ICD9CM_base_umls/{procedure}.txt'\n",
    "        try:\n",
    "            with open(proc_file_gpt, 'r') as f:\n",
    "                lines = f.readlines()\n",
    "        except:\n",
    "            proc_file_gpt = f'../../graphs/procedure/ICD9PROC_base_gpt/{procedure[:-1]}.txt'\n",
    "            with open(proc_file_gpt, 'r') as f:\n",
    "                lines = f.readlines()\n",
    "            \n",
    "        for line in lines:\n",
    "            items = line.split('\\t')\n",
    "            if len(items) == 3:\n",
    "                h, r, t = items\n",
    "                t = t[:-1]\n",
    "                triple = (h, r, t)\n",
    "                triple_set_gpt.add(triple)\n",
    "        with open(proc_file_umls, 'r') as f:\n",
    "            lines = f.readlines()\n",
    "        for line in lines:\n",
    "            items = line.split('\\t')\n",
    "            if len(items) == 3:\n",
    "                h, r, t = items\n",
    "                t = t[:-1]\n",
    "                if r == \"self\":\n",
    "                    continue\n",
    "                triple = (umls_ent2word[h], r, umls_ent2word[t])\n",
    "                triple_set_umls.add(triple)\n",
    "    \n",
    "    triple_list_gpt = [*triple_set_gpt]\n",
    "    triple_list_umls = [*triple_set_umls]\n",
    "\n",
    "    with open(patient_gpt_graph, 'w', newline='') as file:\n",
    "        writer = csv.writer(file)\n",
    "        writer.writerow([\"head\", \"relation\", \"tail\"])\n",
    "        for triple in triple_list_gpt:\n",
    "            writer.writerow([triple[0], triple[1], triple[2]])\n",
    "    \n",
    "    with open(patient_umls_graph, 'w', newline='') as file:\n",
    "        writer = csv.writer(file)\n",
    "        writer.writerow([\"head\", \"relation\", \"tail\"])\n",
    "        for triple in triple_list_umls:\n",
    "            writer.writerow([triple[0], triple[1], triple[2]])\n",
    "\n",
    "    ### END Write graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('kgc')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "3d0509d9aa81f2882b18eeb72d4d23c32cae9029e9b99f63cde94ba86c35ac78"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
