{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ast\n",
    "import json\n",
    "import pickle\n",
    "from pathlib import Path\n",
    "\n",
    "import pandas as pd\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "filepath = Path(\"Procedure-Protein-Mapping/Output-Data/test_with_predicted_proteins.csv\")\n",
    "data = pd.read_csv(filepath)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from reasoner.deductive_reasoner import DeductiveReasoner\n",
    "from TruthValue import TruthValue"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "reasoner = DeductiveReasoner()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "top_k = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1332/1332 [00:16<00:00, 81.11it/s]\n"
     ]
    }
   ],
   "source": [
    "# Run the reasoner on every row and keep each row's top_k results ranked by `.e`.\n",
    "results = []\n",
    "for index, row in tqdm(data.iterrows(), total=len(data)):\n",
    "    # Both columns are expected to hold stringified Python lists (e.g. ['a', 'b']).\n",
    "    # ast.literal_eval parses that form directly, whereas swapping quote characters\n",
    "    # and calling json.loads breaks on any value that contains an apostrophe.\n",
    "    predicted = ast.literal_eval(row['predicted_proteins'])\n",
    "    # Every protein the reasoner accepts enters with the same fixed truth value.\n",
    "    proteins = [(p, TruthValue(1.0, 0.9)) for p in predicted if reasoner.valid_protein(p)]\n",
    "    label = ast.literal_eval(row['diagnoses'])\n",
    "    result = reasoner.deductive_reasoning(proteins)\n",
    "    # Highest `.e` first, then truncate to the top_k entries.\n",
    "    result = sorted(result, key=lambda x: x[1].e, reverse=True)[:top_k]\n",
    "    results.append((index, result, label))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 147,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Jaccard overlap (|intersection| / |union|) between predictions and labels, per row.\n",
    "overlapping_rates = []\n",
    "for result in results:\n",
    "    pred = {r[0] for r in result[1]}\n",
    "    label = set(result[2])\n",
    "    union = pred | label\n",
    "    # A row with no predictions and no labels has an empty union; score it 0.0\n",
    "    # rather than letting the division raise ZeroDivisionError.\n",
    "    overlapping_rate = len(pred & label) / len(union) if union else 0.0\n",
    "    overlapping_rates.append((overlapping_rate, result))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 148,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick the row with the highest overlap; on ties max() keeps the first one encountered.\n",
    "overlapping_rate, (idx, pred, label) = max(overlapping_rates, key=lambda entry: entry[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
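    "### Sample\n",
    "\n",
    "Inspect a single row of the test set: its procedures, predicted proteins, ground-truth diagnoses, and the reasoner's top-k output."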
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hard-coded row position to inspect (used with data.iloc below); this replaces the idx\n",
    "# unpacked from the best-overlap result in the cell above.\n",
    "idx = 675"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull one row by position and decode its list-valued columns.\n",
    "sample = data.iloc[idx]\n",
    "procedures = sample[\"procedures\"]  # kept as the raw string stored in the CSV\n",
    "# The columns are expected to hold stringified Python lists; ast.literal_eval parses\n",
    "# them directly and does not break on values that contain an apostrophe.\n",
    "predicted_proteins = ast.literal_eval(sample[\"predicted_proteins\"])\n",
    "# Every protein the reasoner accepts enters with the same fixed truth value.\n",
    "proteins = [(p, TruthValue(1.0, 0.9)) for p in predicted_proteins if reasoner.valid_protein(p)]\n",
    "label = ast.literal_eval(sample[\"diagnoses\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "result = reasoner.deductive_reasoning(proteins)\n",
    "# Rank by `.e` (highest first) before truncating, exactly as the batch loop does;\n",
    "# slicing the raw result only yields the top_k if the reasoner already returns it sorted.\n",
    "pred = sorted(result, key=lambda x: x[1].e, reverse=True)[:top_k]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('75839', <TruthValue: %1.00;0.90% (k=1)>),\n",
       " ('7560', <TruthValue: %1.00;0.90% (k=1)>),\n",
       " ('2530', <TruthValue: %1.00;0.90% (k=1)>),\n",
       " ('2727', <TruthValue: %1.00;0.90% (k=1)>),\n",
       " ('75989', <TruthValue: %1.00;0.90% (k=1)>)]"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['4271',\n",
       " '40391',\n",
       " '42731',\n",
       " '4254',\n",
       " '4240',\n",
       " '41401',\n",
       " '41042',\n",
       " 'V1251',\n",
       " '2859',\n",
       " 'V1271',\n",
       " '3051']"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"['3794', '3723', '8856', '9671', '3995', '9904']\""
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "procedures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['PROTEIN:21977',\n",
       " 'PROTEIN:4906',\n",
       " 'PROTEIN:3775',\n",
       " 'PROTEIN:4383',\n",
       " 'PROTEIN:6119',\n",
       " 'PROTEIN:77',\n",
       " 'PROTEIN:9683',\n",
       " 'PROTEIN:1346',\n",
       " 'PROTEIN:17994',\n",
       " 'PROTEIN:1530',\n",
       " 'PROTEIN:12505',\n",
       " 'PROTEIN:19793',\n",
       " 'PROTEIN:4417',\n",
       " 'PROTEIN:6489',\n",
       " 'PROTEIN:11974',\n",
       " 'PROTEIN:20115',\n",
       " 'PROTEIN:4358',\n",
       " 'PROTEIN:21279',\n",
       " 'PROTEIN:2706',\n",
       " 'PROTEIN:16610',\n",
       " 'PROTEIN:831',\n",
       " 'PROTEIN:1479',\n",
       " 'PROTEIN:1426',\n",
       " 'PROTEIN:9979',\n",
       " 'PROTEIN:10038',\n",
       " 'PROTEIN:256',\n",
       " 'PROTEIN:72',\n",
       " 'PROTEIN:11486',\n",
       " 'PROTEIN:1113',\n",
       " 'PROTEIN:19637',\n",
       " 'PROTEIN:9685',\n",
       " 'PROTEIN:19635',\n",
       " 'PROTEIN:13326',\n",
       " 'PROTEIN:4405',\n",
       " 'PROTEIN:11410',\n",
       " 'PROTEIN:8194',\n",
       " 'PROTEIN:14369',\n",
       " 'PROTEIN:15431',\n",
       " 'PROTEIN:6154',\n",
       " 'PROTEIN:17135',\n",
       " 'PROTEIN:7849',\n",
       " 'PROTEIN:987',\n",
       " 'PROTEIN:4820',\n",
       " 'PROTEIN:887',\n",
       " 'PROTEIN:830',\n",
       " 'PROTEIN:3585',\n",
       " 'PROTEIN:8782',\n",
       " 'PROTEIN:10889',\n",
       " 'PROTEIN:3834',\n",
       " 'PROTEIN:6885',\n",
       " 'PROTEIN:21294',\n",
       " 'PROTEIN:7559',\n",
       " 'PROTEIN:939',\n",
       " 'PROTEIN:8146',\n",
       " 'PROTEIN:6540',\n",
       " 'PROTEIN:10036',\n",
       " 'PROTEIN:3385',\n",
       " 'PROTEIN:2870',\n",
       " 'PROTEIN:15412',\n",
       " 'PROTEIN:19317',\n",
       " 'PROTEIN:959',\n",
       " 'PROTEIN:1811',\n",
       " 'PROTEIN:10028',\n",
       " 'PROTEIN:1108',\n",
       " 'PROTEIN:3291',\n",
       " 'PROTEIN:9765',\n",
       " 'PROTEIN:16803']"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predicted_proteins"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "172"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sample['patient_id']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same inputs as the sample run, but also request the intermediate results;\n",
    "# only the second returned value is kept.\n",
    "_, layer2_result = reasoner.deductive_reasoning(\n",
    "    [(protein, TruthValue(1.0, 0.9)) for protein in predicted_proteins if reasoner.valid_protein(protein)],\n",
    "    return_intermediate_results=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1512"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(layer2_result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('GENE:32600', <TruthValue: %1.00;0.90% (k=1)>),\n",
       " ('GENE:33174', <TruthValue: %1.00;0.90% (k=1)>),\n",
       " ('GENE:35094', <TruthValue: %1.00;0.90% (k=1)>)]"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Three intermediate results with the largest `.e`, highest first.\n",
    "sorted(layer2_result, key=lambda item: item[1].e, reverse=True)[:3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from icd9_to_text import get_description"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Autosomal deletions NEC\n"
     ]
    }
   ],
   "source": [
    "print(get_description('75839'))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "RL",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
