{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import rdflib\n",
    "from rdflib.namespace import SKOS, RDF, RDFS, Namespace\n",
    "\n",
    "# Load the Turtle file\n",
    "graph = rdflib.Graph()\n",
    "graph.parse(\"ICD9CM.ttl\", format=\"turtle\")\n",
    "\n",
    "# Define the UMLS namespace\n",
    "umls = Namespace(\"http://bioportal.bioontology.org/ontologies/umls/\")\n",
    "\n",
    "# Iterate through the graph to find ICD9 codes and UMLS CUIs\n",
    "icd9_to_cui_map = {}\n",
    "for subject in graph.subjects(RDF.type, rdflib.OWL.Class):\n",
    "    icd9_code = graph.value(subject, SKOS.notation)\n",
    "    umls_cui = graph.value(subject, umls.cui)\n",
    "\n",
    "    if icd9_code and umls_cui:\n",
    "        icd9_to_cui_map[str(icd9_code)] = str(umls_cui)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Print the mapping\n",
    "with open(\"icd9_to_umls_cui.json\", 'w') as f:\n",
    "    json.dump(icd9_to_cui_map, f, indent=6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Turtle file\n",
    "graph = rdflib.Graph()\n",
    "graph.parse(\"ATC.ttl\", format=\"turtle\")\n",
    "\n",
    "# Define the UMLS namespace\n",
    "umls = Namespace(\"http://bioportal.bioontology.org/ontologies/umls/\")\n",
    "\n",
    "# Iterate through the graph to find ATC codes and UMLS CUIs\n",
    "atc_to_cui_map = {}\n",
    "for subject in graph.subjects(RDF.type, rdflib.OWL.Class):\n",
    "    atc_code = graph.value(subject, SKOS.notation)\n",
    "    umls_cui = graph.value(subject, umls.cui)\n",
    "\n",
    "    if atc_code and umls_cui:\n",
    "        atc_to_cui_map[str(atc_code)] = str(umls_cui)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Print the mapping\n",
    "with open(\"atc_to_umls_cui.json\", 'w') as f:\n",
    "    json.dump(atc_to_cui_map, f, indent=6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Turtle file\n",
    "graph = rdflib.Graph()\n",
    "graph.parse(\"RXNORM.ttl\", format=\"turtle\")\n",
    "\n",
    "# Define the UMLS namespace\n",
    "umls = Namespace(\"http://bioportal.bioontology.org/ontologies/umls/\")\n",
    "\n",
    "# Iterate through the graph to find RxNorm codes and UMLS CUIs\n",
    "rxnorm_to_cui_map = {}\n",
    "for subject in graph.subjects(RDF.type, rdflib.OWL.Class):\n",
    "    rxnorm_code = graph.value(subject, SKOS.notation)\n",
    "    umls_cui = graph.value(subject, umls.cui)\n",
    "\n",
    "    if rxnorm_code and umls_cui:\n",
    "        rxnorm_to_cui_map[str(rxnorm_code)] = str(umls_cui)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Print the mapping\n",
    "with open(\"rxnorm_to_umls_cui.json\", 'w') as f:\n",
    "    json.dump(rxnorm_to_cui_map, f, indent=6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import pandas as pd\n",
    "\n",
    "with open('./umls/umls.graph', 'rb') as f:\n",
    "    umls_g = pickle.load(f)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.utils import to_networkx, from_networkx\n",
    "G_tg = from_networkx(umls_g) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2341070"
      ]
     },
     "execution_count": 60,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(umls_g.edges)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('./umls/umls.csv', 'r') as f:\n",
    "    lines_1 = f.readlines()\n",
    "\n",
    "# with open('./graph.txt', 'r') as f:\n",
    "#     lines_2 = f.readlines()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "triple_set = set()\n",
    "tuple_set = set()\n",
    "node_set = set()\n",
    "\n",
    "for line in lines_1:\n",
    "    items = line.split('\\t')\n",
    "    e1 = items[1]\n",
    "    r = items[0]\n",
    "    e2 = items[2]\n",
    "    triple_set.add((e1, r, e2))\n",
    "    tuple_set.add((e1, e2))\n",
    "    node_set.add(e1)\n",
    "    node_set.add(e2)\n",
    "    \n",
    "\n",
    "# for line in lines_2:\n",
    "#     items = line.split('\\t')\n",
    "#     e1 = items[0]\n",
    "#     r = items[1]\n",
    "#     e2 = items[2][:-1]\n",
    "#     if (e1, e2) not in tuple_set and (e2, e1) not in tuple_set:\n",
    "#         tuple_set.add((e1, e2))\n",
    "#         triple_set.add((e1, r, e2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1212586, 297927)"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(triple_set), len(node_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "out_str = \"\"\n",
    "\n",
    "for triple in triple_set:\n",
    "    out_str += triple[0] + '\\t' + triple[1] + '\\t' + triple[2] + '\\n'\n",
    "\n",
    "with open('./umls_graph.txt', 'w') as f:\n",
    "    f.write(out_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "ent_emb = np.load('./umls/ent_emb.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1024"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(ent_emb[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import rdflib\n",
    "from rdflib.namespace import SKOS, RDF, RDFS, Namespace\n",
    "import csv\n",
    "from collections import defaultdict\n",
    "\n",
    "# Load the Turtle file\n",
    "graph = rdflib.Graph()\n",
    "graph.parse(\"ICD9CM.ttl\", format=\"turtle\")\n",
    "\n",
    "# Define the UMLS namespace\n",
    "umls = Namespace(\"http://bioportal.bioontology.org/ontologies/umls/\")\n",
    "\n",
    "# Iterate through the graph to find ICD9 codes and UMLS CUIs\n",
    "icd9_to_cui_map = defaultdict(list)\n",
    "for subject in graph.subjects(RDF.type, rdflib.OWL.Class):\n",
    "    icd9_code = graph.value(subject, SKOS.notation)\n",
    "    umls_cui = graph.value(subject, umls.cui)\n",
    "\n",
    "    if icd9_code and umls_cui:\n",
    "        icd9_to_cui_map['ICD9CM'].append(str(icd9_code))\n",
    "        icd9_to_cui_map['UMLS'].append(str(umls_cui))\n",
    "\n",
    "with open('ICD9CM_to_UMLS.csv', 'w', newline='') as file:\n",
    "    writer = csv.writer(file)\n",
    "    writer.writerow(icd9_to_cui_map.keys())\n",
    "    writer.writerows(zip(*icd9_to_cui_map.values()))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Turtle file\n",
    "graph = rdflib.Graph()\n",
    "graph.parse(\"RXNORM.ttl\", format=\"turtle\")\n",
    "\n",
    "# Define the UMLS namespace\n",
    "umls = Namespace(\"http://bioportal.bioontology.org/ontologies/umls/\")\n",
    "\n",
    "# Iterate through the graph to find RxNorm codes and UMLS CUIs\n",
    "rxnorm_to_cui_map = defaultdict(list)\n",
    "for subject in graph.subjects(RDF.type, rdflib.OWL.Class):\n",
    "    rxnorm_code = graph.value(subject, SKOS.notation)\n",
    "    umls_cui = graph.value(subject, umls.cui)\n",
    "\n",
    "    if rxnorm_code and umls_cui:\n",
    "        rxnorm_to_cui_map['RxNorm'].append(str(rxnorm_code))\n",
    "        rxnorm_to_cui_map['UMLS'].append(str(umls_cui))\n",
    "\n",
    "with open('RxNorm_to_UMLS.csv', 'w', newline='') as file:\n",
    "    writer = csv.writer(file)\n",
    "    writer.writerow(rxnorm_to_cui_map.keys())\n",
    "    writer.writerows(zip(*rxnorm_to_cui_map.values()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Turtle file\n",
    "graph = rdflib.Graph()\n",
    "graph.parse(\"ATC.ttl\", format=\"turtle\")\n",
    "\n",
    "# Define the UMLS namespace\n",
    "umls = Namespace(\"http://bioportal.bioontology.org/ontologies/umls/\")\n",
    "\n",
    "# Iterate through the graph to find ATC codes and UMLS CUIs\n",
    "atc_to_cui_map = defaultdict(list)\n",
    "for subject in graph.subjects(RDF.type, rdflib.OWL.Class):\n",
    "    atc_code = graph.value(subject, SKOS.notation)\n",
    "    umls_cui = graph.value(subject, umls.cui)\n",
    "\n",
    "    if atc_code and umls_cui:\n",
    "        atc_to_cui_map['ATC'].append(str(atc_code))\n",
    "        atc_to_cui_map['UMLS'].append(str(umls_cui))\n",
    "\n",
    "with open('ATC_to_UMLS.csv', 'w', newline='') as file:\n",
    "    writer = csv.writer(file)\n",
    "    writer.writerow(atc_to_cui_map.keys())\n",
    "    writer.writerows(zip(*atc_to_cui_map.values()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('./RxNorm_to_ATC.csv', 'r') as f:\n",
    "    lines = f.readlines()\n",
    "    lines = lines[1:]\n",
    "\n",
    "rxnorm_to_atc_map = {}\n",
    "for line in lines:\n",
    "    rxnorm, atc = line.split(',')\n",
    "    atc = atc[:-1]\n",
    "    rxnorm_to_atc_map[rxnorm] = atc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyhealth.medcode import CrossMap\n",
    "from collections import Counter\n",
    "# Load the Turtle file\n",
    "graph = rdflib.Graph()\n",
    "graph.parse(\"RXNORM.ttl\", format=\"turtle\")\n",
    "\n",
    "# Define the UMLS namespace\n",
    "umls = Namespace(\"http://bioportal.bioontology.org/ontologies/umls/\")\n",
    "\n",
    "# Iterate through the graph to find RxNorm codes and UMLS CUIs\n",
    "rxnorm_to_cui_map = defaultdict(list)\n",
    "mapping = CrossMap(\"RxNorm\", \"ATC\")\n",
    "for subject in graph.subjects(RDF.type, rdflib.OWL.Class):\n",
    "    rxnorm_code = graph.value(subject, SKOS.notation)\n",
    "    umls_cui = graph.value(subject, umls.cui)\n",
    "\n",
    "    if rxnorm_code and umls_cui:\n",
    "        rxnorm_to_cui_map['RxNorm'].append(str(rxnorm_code))\n",
    "        rxnorm_to_cui_map['UMLS'].append(str(umls_cui))\n",
    "\n",
    "with open('RxNorm_to_ATCUMLS.csv', 'w', newline='') as file:\n",
    "    writer = csv.writer(file)\n",
    "    writer.writerow(rxnorm_to_cui_map.keys())\n",
    "    writer.writerows(zip(*rxnorm_to_cui_map.values()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 104819/104819 [00:00<00:00, 261568.96it/s]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "rxnorm_to_atc3_to_cui_map = defaultdict(list)\n",
    "\n",
    "mapping = CrossMap(\"RxNorm\", \"ATC\")\n",
    "for rxnorm_code in tqdm(rxnorm_to_cui_map['RxNorm']):\n",
    "    try:\n",
    "        atc3_code = Counter(mapping.map(str(rxnorm_code), target_kwargs={\"level\": 3})).most_common(1)[0][0]\n",
    "    except:\n",
    "        continue\n",
    "    umls_cui = atc_to_cui_map[atc3_code]\n",
    "    rxnorm_to_atc3_to_cui_map['RxNorm'].append(rxnorm_code)\n",
    "    rxnorm_to_atc3_to_cui_map['UMLS'].append(umls_cui)\n",
    "\n",
    "\n",
    "with open('RxNorm_to_ATC3_UMLS.csv', 'w', newline='') as file:\n",
    "    writer = csv.writer(file)\n",
    "    writer.writerow(rxnorm_to_atc3_to_cui_map.keys())\n",
    "    writer.writerows(zip(*rxnorm_to_atc3_to_cui_map.values()))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"./umls_graph.txt\", 'r') as f:\n",
    "    lines = f.readlines()\n",
    "\n",
    "triple_set = set()\n",
    "for line in lines:\n",
    "    h, r, t = line.split('\\t')\n",
    "    t = t[:-1]\n",
    "    triple_set.add((h,r,t))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 22406/22406 [3:37:08<00:00,  1.72it/s]  \n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "entity_set = set()\n",
    "\n",
    "for idx in tqdm(range(len(icd9_to_cui_map['ICD9CM']))):\n",
    "    icd9_code_name = icd9_to_cui_map['ICD9CM'][idx].replace('.', '')\n",
    "    out_file = f'../graphs/condition/ICD9CM_large/{icd9_code_name}.txt'\n",
    "    out_str = \"\"\n",
    "    umls_cui = icd9_to_cui_map['UMLS'][idx]\n",
    "    for triple in triple_set:\n",
    "        if umls_cui in triple:\n",
    "            out_str += triple[0] + '\\t' + triple[1] + '\\t' + triple[2] + '\\n'\n",
    "    if out_str != \"\":\n",
    "        with open(out_file, 'w') as f:\n",
    "            f.write(out_str)\n",
    "\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 22406/22406 [00:00<00:00, 37044.05it/s]\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "\n",
    "node2idx = {}\n",
    "edge2idx = {}\n",
    "\n",
    "node_idx = 0\n",
    "edge_idx = 0\n",
    "for idx in tqdm(range(len(icd9_to_cui_map['ICD9CM']))):\n",
    "    icd9_code_name = icd9_to_cui_map['ICD9CM'][idx].replace('.', '')\n",
    "    feat_file = f'../graphs/condition/ICD9CM_large/{icd9_code_name}.txt'\n",
    "    try:\n",
    "        with open(feat_file, 'r') as f:\n",
    "            lines = f.readlines()\n",
    "    except:\n",
    "        continue\n",
    "\n",
    "    for line in lines:\n",
    "        h, r, t = line[:-1].split('\\t')\n",
    "        if h not in node2idx.keys():\n",
    "            node2idx[h] = node_idx\n",
    "            node_idx += 1\n",
    "        if t not in node2idx.keys():\n",
    "            node2idx[t] = node_idx\n",
    "            node_idx += 1\n",
    "        if r not in edge2idx.keys():\n",
    "            edge2idx[r] = edge_idx\n",
    "            edge_idx += 1\n",
    "\n",
    "\n",
    "out_file = f'../graphs/condition/ICD9CM_large/ent2id.json'\n",
    "with open(out_file, 'w') as f:\n",
    "    json.dump(node2idx, f, indent=6)\n",
    "\n",
    "out_file = f'../graphs/condition/ICD9CM_large/rel2id.json'\n",
    "with open(out_file, 'w') as f:\n",
    "    json.dump(edge2idx, f, indent=6)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "with open('./umls/umls.csv', 'r') as f:\n",
    "    lines_1 = f.readlines()\n",
    "\n",
    "triple_set = set()\n",
    "tuple_set = set()\n",
    "\n",
    "for line in lines_1:\n",
    "    items = line.split('\\t')\n",
    "    e1 = items[1]\n",
    "    r = items[0]\n",
    "    e2 = items[2]\n",
    "    triple_set.add((e1, r, e2))\n",
    "    tuple_set.add((e1, e2))\n",
    "\n",
    "# ent_emb = np.load('./umls/ent_emb.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 22406/22406 [1:59:43<00:00,  3.12it/s]  \n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "entity_set = set()\n",
    "\n",
    "for idx in tqdm(range(len(icd9_to_cui_map['ICD9CM']))):\n",
    "    icd9_code_name = icd9_to_cui_map['ICD9CM'][idx].replace('.', '')\n",
    "    out_file = f'../graphs/condition/ICD9CM_base/{icd9_code_name}.txt'\n",
    "    out_str = \"\"\n",
    "    umls_cui = icd9_to_cui_map['UMLS'][idx]\n",
    "    for triple in triple_set:\n",
    "        if umls_cui in triple:\n",
    "            out_str += triple[0] + '\\t' + triple[1] + '\\t' + triple[2] + '\\n'\n",
    "    out_str += umls_cui + '\\t' + 'self' + '\\t' + umls_cui + '\\n'\n",
    "    with open(out_file, 'w') as f:\n",
    "        f.write(out_str)\n",
    "\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "with open('./umls/concepts.txt', 'r') as f:\n",
    "    lines = f.readlines()\n",
    "\n",
    "ent2id = {}\n",
    "id_ = 0\n",
    "for line in lines:\n",
    "    ent2id[line[:-1]] = id_\n",
    "    id_ += 1\n",
    "\n",
    "with open('./umls/relations.txt', 'r') as f:\n",
    "    lines = f.readlines()\n",
    "\n",
    "rel2id = {}\n",
    "id_ = 0\n",
    "for line in lines:\n",
    "    rel2id[line[:-1]] = id_\n",
    "    id_ += 1\n",
    "\n",
    "\n",
    "out_file = f'../graphs/condition/ICD9CM_base/ent2id.json'\n",
    "with open(out_file, 'w') as f:\n",
    "    json.dump(ent2id, f, indent=6)\n",
    "\n",
    "out_file = f'../graphs/condition/ICD9CM_base/rel2id.json'\n",
    "with open(out_file, 'w') as f:\n",
    "    json.dump(rel2id, f, indent=6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "ent2word = {}\n",
    "\n",
    "with open('./umls/concept_names.txt', 'r') as f:\n",
    "    lines = f.readlines()\n",
    "\n",
    "for line in lines:\n",
    "    ent, word = line[:-1].split('\\t')\n",
    "    ent2word[ent] = word\n",
    "\n",
    "with open('../graphs/condition/ICD9CM_base/ent2word.json', 'w') as f:\n",
    "    json.dump(ent2word, f, indent=6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "\n",
    "with open('../graphs/condition/ICD9CM_base/ent2id.json', 'r') as f:\n",
    "    ent2id = json.load(f)\n",
    "\n",
    "with open('../graphs/condition/ICD9CM_base/rel2id.json', 'r') as f:\n",
    "    rel2id = json.load(f)\n",
    "\n",
    "ent_emb = np.load('../KG_mapping/umls/ent_emb.npy')\n",
    "\n",
    "with open('../KG_mapping/ICD9CM_to_UMLS.csv', 'r') as f:\n",
    "    lines = f.readlines()\n",
    "    lines = lines[1:]\n",
    "\n",
    "with open('../KG_mapping/umls/umls.csv', 'r') as f:\n",
    "    lines_1 = f.readlines()\n",
    "\n",
    "triple_set = set()\n",
    "tuple_set = set()\n",
    "\n",
    "for line in lines_1:\n",
    "    items = line.split('\\t')\n",
    "    e1 = items[1]\n",
    "    r = items[0]\n",
    "    e2 = items[2]\n",
    "    triple_set.add((e1, r, e2))\n",
    "\n",
    "icd9_to_umls = {}\n",
    "for line in lines:\n",
    "    icd9cm, umls = line.split(',')\n",
    "    umls = umls[:-1]\n",
    "    icd9_to_umls[icd9cm.replace('.', '')] = umls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 19340/19340 [00:00<00:00, 36287.55it/s]\n"
     ]
    }
   ],
   "source": [
    "import glob\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "\n",
    "triple_files = glob.glob('../graphs/condition/ICD9CM_base/*.txt')\n",
    "node_set_all = set()\n",
    "edge_set_all = set()\n",
    "for triple_file in tqdm(triple_files):\n",
    "    # file_name = '../../../../data/pj20/graphs/condition/ICD9CM_base_ext/' + triple_file.split('/')[-1]\n",
    "    # if os.path.exists(file_name) == False:\n",
    "        # triple_str = \"\"\n",
    "    node_set = set()\n",
    "    with open(triple_file, 'r') as f:\n",
    "        lines = f.readlines()\n",
    "    for line in lines:\n",
    "        h, r, t = line[:-1].split('\\t')\n",
    "        # node_set.add(h)\n",
    "        # node_set.add(t)\n",
    "        node_set_all.add(h)\n",
    "        edge_set_all.add(r)\n",
    "        node_set_all.add(t)\n",
    "        \n",
    "        # for node in node_set:\n",
    "        #     for triple in triple_set:\n",
    "        #         if node in triple:\n",
    "        #             h, r, t = triple\n",
    "        #             triple_str += h + '\\t' + r + '\\t' + t + '\\n'\n",
    "        #             node_set_all.add(h)\n",
    "        #             node_set_all.add(t)\n",
    "        #             edge_set_all.add(r)\n",
    "        \n",
    "        # with open(file_name, 'w') as f:\n",
    "        #     f.write(triple_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "73629"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(node_set_all)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "ent_emb_new = []\n",
    "ent2id_new = {}\n",
    "rel2id_new = {}\n",
    "\n",
    "idx = 0\n",
    "for node in node_set_all:\n",
    "    try:\n",
    "        ent_emb_new.append(ent_emb[ent2id[node]])\n",
    "        ent2id_new[node] = idx\n",
    "        idx += 1\n",
    "    except:\n",
    "        continue\n",
    "\n",
    "ent_emb_new = np.array(ent_emb_new)\n",
    "\n",
    "idx = 0\n",
    "for edge in edge_set_all:\n",
    "    rel2id_new[edge] = idx\n",
    "    idx += 1\n",
    "\n",
    "\n",
    "with open('../graphs/condition/ICD9CM_base/ent2id_new.json', 'w') as f:\n",
    "    json.dump(ent2id_new, f, indent=6)\n",
    "\n",
    "with open('../graphs/condition/ICD9CM_base/rel2id_new.json', 'w') as f:\n",
    "    json.dump(rel2id_new, f, indent=6)\n",
    "\n",
    "np.save(arr=ent_emb_new, file='../graphs/condition/ICD9CM_base/ent_emb_new.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "id2ent_new = {value: key for key, value in ent2id_new.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(62798, 62798, 65)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(ent2id_new), len(ent_emb_new), len(rel2id_new)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('../graphs/condition/ICD9CM_base/id2ent_new.json', 'w') as f:\n",
    "    json.dump(id2ent_new, f, indent=6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import glob\n",
    "# import os\n",
    "# from tqdm import tqdm\n",
    "# from multiprocessing import Pool, cpu_count\n",
    "\n",
    "\n",
    "# def process_triple_file(triple_file):\n",
    "#     file_name = '../../../../data/pj20/graphs/condition/ICD9CM_base_ext/' + os.path.basename(triple_file)\n",
    "#     if not os.path.exists(file_name):\n",
    "#         node_set = {h for h, r, t in triple_set} | {t for h, r, t in triple_set}\n",
    "#         triple_str = '\\n'.join([f'{h}\\t{r}\\t{t}' for h, r, t in triple_set if h in node_set and t in node_set])\n",
    "#         with open(file_name, 'w') as f:\n",
    "#             f.write(triple_str)\n",
    "#         return (node_set, {r for h, r, t in triple_set if h in node_set and t in node_set})\n",
    "#     else:\n",
    "#         return (set(), set())\n",
    "\n",
    "# triple_files = glob.glob('../../../../data/pj20/graphs/condition/ICD9CM_base/*.txt')\n",
    "# with Pool(cpu_count() - 1) as pool:\n",
    "#     results = list(tqdm(pool.imap(process_triple_file, triple_files), total=len(triple_files)))\n",
    "\n",
    "# node_set_all = set().union(*[r[0] for r in results])\n",
    "# edge_set_all = set().union(*[r[1] for r in results])\n",
    "\n",
    "# print(f'Number of unique nodes: {len(node_set_all)}')\n",
    "# print(f'Number of unique edges: {len(edge_set_all)}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 15/19340 [00:28<10:22:12,  1.93s/it]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m/home/pj20/experiment/KG_mapping/process.ipynb Cell 37\u001b[0m in \u001b[0;36m<cell line: 10>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2Bsunlab-serv-03.cs.illinois.edu/home/pj20/experiment/KG_mapping/process.ipynb#X51sdnNjb2RlLXJlbW90ZQ%3D%3D?line=28'>29</a>\u001b[0m \u001b[39mwhile\u001b[39;00m \u001b[39mlen\u001b[39m(node_triple_dict[key]) \u001b[39m<\u001b[39m\u001b[39m=\u001b[39m \u001b[39m5\u001b[39m \u001b[39mand\u001b[39;00m i \u001b[39m<\u001b[39m\u001b[39m=\u001b[39m \u001b[39mlen\u001b[39m(triple_set):\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2Bsunlab-serv-03.cs.illinois.edu/home/pj20/experiment/KG_mapping/process.ipynb#X51sdnNjb2RlLXJlbW90ZQ%3D%3D?line=29'>30</a>\u001b[0m     \u001b[39mfor\u001b[39;00m triple \u001b[39min\u001b[39;00m triple_set:\n\u001b[0;32m---> <a href='vscode-notebook-cell://ssh-remote%2Bsunlab-serv-03.cs.illinois.edu/home/pj20/experiment/KG_mapping/process.ipynb#X51sdnNjb2RlLXJlbW90ZQ%3D%3D?line=30'>31</a>\u001b[0m         \u001b[39mif\u001b[39;00m (key \u001b[39min\u001b[39;49;00m triple) \u001b[39mand\u001b[39;00m (triple \u001b[39mnot\u001b[39;00m \u001b[39min\u001b[39;00m node_triple_dict[key]):\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2Bsunlab-serv-03.cs.illinois.edu/home/pj20/experiment/KG_mapping/process.ipynb#X51sdnNjb2RlLXJlbW90ZQ%3D%3D?line=31'>32</a>\u001b[0m             node_triple_dict[key]\u001b[39m.\u001b[39mappend(triple)\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2Bsunlab-serv-03.cs.illinois.edu/home/pj20/experiment/KG_mapping/process.ipynb#X51sdnNjb2RlLXJlbW90ZQ%3D%3D?line=32'>33</a>\u001b[0m         i \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m \u001b[39m1\u001b[39m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "import json\n",
    "from collections import defaultdict\n",
    "\n",
    "with open('../graphs/condition/ICD9CM_base/id2ent_new.json', 'r') as f:\n",
    "    id2ent_new = json.load(f)\n",
    "\n",
    "triple_files = glob.glob('../graphs/condition/ICD9CM_base/*.txt')\n",
    "store_dir = \"/data/pj20/graphs/umls_icd9_2hop/\"\n",
    "\n",
    "for triple_file in tqdm(triple_files):\n",
    "    triple_set_ = set()\n",
    "    node_triple_dict = defaultdict(list)\n",
    "    out_file = store_dir + triple_file.replace('../graphs/condition/ICD9CM_base/', '')\n",
    "    out_str = \"\"\n",
    "\n",
    "    with open(triple_file, 'r') as f:\n",
    "        lines = f.readlines()\n",
    "    for line in lines:\n",
    "        h, r, t = line[:-1].split('\\t')\n",
    "        triple = (h, r, t)\n",
    "        triple_set_.add(triple)\n",
    "        node_triple_dict[h].append(triple)\n",
    "        node_triple_dict[t].append(triple)\n",
    "    \n",
    "    for key in node_triple_dict.keys():\n",
    "        i = 0\n",
    "\n",
    "        ## limit the extended 2-hop triples to 5\n",
    "        while len(node_triple_dict[key]) <= 5 and i <= len(triple_set):\n",
    "            for triple in triple_set:\n",
    "                if (key in triple) and (triple not in node_triple_dict[key]):\n",
    "                    node_triple_dict[key].append(triple)\n",
    "                i += 1\n",
    "\n",
    "    for key, triple_list in node_triple_dict.items():\n",
    "        for triple in triple_list:\n",
    "            h, r, t = triple\n",
    "            out_str += h + '\\t' + r + '\\t' + t + '\\n'\n",
    "        \n",
    "    \n",
    "    with open(out_file, 'w') as f:\n",
    "        f.write(out_str)\n",
    "\n",
    "\n",
    "        \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/19340 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "36004.txt\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import glob\n",
    "from tqdm import tqdm\n",
    "\n",
    "triple_files = glob.glob('../graphs/condition/ICD9CM_base/*.txt')\n",
    "for triple_file in tqdm(triple_files):\n",
    "    print(triple_file.replace('../graphs/condition/ICD9CM_base/', ''))\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('kgc')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "3d0509d9aa81f2882b18eeb72d4d23c32cae9029e9b99f63cde94ba86c35ac78"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
