{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from pathlib import Path\n",
    "import json\n",
    "from tqdm import tqdm\n",
    "import pickle\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "filepath = Path(\"Procedure-Protein-Mapping/Output-Data/test_with_predicted_proteins.csv\")\n",
    "data = pd.read_csv(filepath)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>patient_id</th>\n",
       "      <th>visit</th>\n",
       "      <th>diagnoses</th>\n",
       "      <th>procedures</th>\n",
       "      <th>medications</th>\n",
       "      <th>proteins</th>\n",
       "      <th>SYMPTOMS</th>\n",
       "      <th>predicted_proteins</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>6179</td>\n",
       "      <td>1</td>\n",
       "      <td>['4373', 'V103', '71690', '4019', '2720', '7810']</td>\n",
       "      <td>['3972', '8841']</td>\n",
       "      <td>['N02B', 'A12C', 'A01A', 'C10A', 'A06A', 'C02D...</td>\n",
       "      <td>['PROTEIN:4876', 'PROTEIN:11410', 'PROTEIN:132...</td>\n",
       "      <td>['Unknown', 'Multiple and unspecified open wou...</td>\n",
       "      <td>['PROTEIN:21977']</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>213</td>\n",
       "      <td>7</td>\n",
       "      <td>['99812', '51881', '5856', '99681', '40391', '...</td>\n",
       "      <td>['3995', '5491']</td>\n",
       "      <td>['A07A', 'N02B', 'B01A', 'A01A', 'C08C', 'C10A...</td>\n",
       "      <td>['PROTEIN:3897', 'PROTEIN:6839']</td>\n",
       "      <td>['Unknown', 'Unknown']</td>\n",
       "      <td>['PROTEIN:21977', 'PROTEIN:3775', 'PROTEIN:77'...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3856</td>\n",
       "      <td>1</td>\n",
       "      <td>['5551', '486', '56981', '5119', '5180', '9974...</td>\n",
       "      <td>['4582', '415', '4592', '3491']</td>\n",
       "      <td>['B05C', 'A12A', 'A12C', 'M01A', 'N01A', 'N02B...</td>\n",
       "      <td>['PROTEIN:3775', 'PROTEIN:6839', 'PROTEIN:1661...</td>\n",
       "      <td>['Unknown', 'Unknown', 'Compression of vein', ...</td>\n",
       "      <td>['PROTEIN:14889', 'PROTEIN:4358', 'PROTEIN:482...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2907</td>\n",
       "      <td>6</td>\n",
       "      <td>['51881', '42821', '4280', '49122', '29590', '...</td>\n",
       "      <td>['9671', '9604']</td>\n",
       "      <td>['A07A', 'A12B', 'C03C', 'B01A', 'C02D', 'A02B...</td>\n",
       "      <td>['PROTEIN:3897', 'PROTEIN:6839']</td>\n",
       "      <td>['Poisoning by chloral hydrate group', 'Poison...</td>\n",
       "      <td>['PROTEIN:21977', 'PROTEIN:20611', 'PROTEIN:49...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1576</td>\n",
       "      <td>3</td>\n",
       "      <td>['1125', '78552', '99592', '56722', '5720', '5...</td>\n",
       "      <td>['5491', '3897', '3893']</td>\n",
       "      <td>['A01A', 'A12A', 'B05C', 'A12C', 'C07A', 'N02B...</td>\n",
       "      <td>['PROTEIN:1321', 'PROTEIN:19534', 'PROTEIN:154...</td>\n",
       "      <td>['Unknown', 'Deaf, nonspeaking, not elsewhere ...</td>\n",
       "      <td>['PROTEIN:3775', 'PROTEIN:8012', 'PROTEIN:1346...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   patient_id  visit                                          diagnoses  \\\n",
       "0        6179      1  ['4373', 'V103', '71690', '4019', '2720', '7810']   \n",
       "1         213      7  ['99812', '51881', '5856', '99681', '40391', '...   \n",
       "2        3856      1  ['5551', '486', '56981', '5119', '5180', '9974...   \n",
       "3        2907      6  ['51881', '42821', '4280', '49122', '29590', '...   \n",
       "4        1576      3  ['1125', '78552', '99592', '56722', '5720', '5...   \n",
       "\n",
       "                        procedures  \\\n",
       "0                 ['3972', '8841']   \n",
       "1                 ['3995', '5491']   \n",
       "2  ['4582', '415', '4592', '3491']   \n",
       "3                 ['9671', '9604']   \n",
       "4         ['5491', '3897', '3893']   \n",
       "\n",
       "                                         medications  \\\n",
       "0  ['N02B', 'A12C', 'A01A', 'C10A', 'A06A', 'C02D...   \n",
       "1  ['A07A', 'N02B', 'B01A', 'A01A', 'C08C', 'C10A...   \n",
       "2  ['B05C', 'A12A', 'A12C', 'M01A', 'N01A', 'N02B...   \n",
       "3  ['A07A', 'A12B', 'C03C', 'B01A', 'C02D', 'A02B...   \n",
       "4  ['A01A', 'A12A', 'B05C', 'A12C', 'C07A', 'N02B...   \n",
       "\n",
       "                                            proteins  \\\n",
       "0  ['PROTEIN:4876', 'PROTEIN:11410', 'PROTEIN:132...   \n",
       "1                   ['PROTEIN:3897', 'PROTEIN:6839']   \n",
       "2  ['PROTEIN:3775', 'PROTEIN:6839', 'PROTEIN:1661...   \n",
       "3                   ['PROTEIN:3897', 'PROTEIN:6839']   \n",
       "4  ['PROTEIN:1321', 'PROTEIN:19534', 'PROTEIN:154...   \n",
       "\n",
       "                                            SYMPTOMS  \\\n",
       "0  ['Unknown', 'Multiple and unspecified open wou...   \n",
       "1                             ['Unknown', 'Unknown']   \n",
       "2  ['Unknown', 'Unknown', 'Compression of vein', ...   \n",
       "3  ['Poisoning by chloral hydrate group', 'Poison...   \n",
       "4  ['Unknown', 'Deaf, nonspeaking, not elsewhere ...   \n",
       "\n",
       "                                  predicted_proteins  \n",
       "0                                  ['PROTEIN:21977']  \n",
       "1  ['PROTEIN:21977', 'PROTEIN:3775', 'PROTEIN:77'...  \n",
       "2  ['PROTEIN:14889', 'PROTEIN:4358', 'PROTEIN:482...  \n",
       "3  ['PROTEIN:21977', 'PROTEIN:20611', 'PROTEIN:49...  \n",
       "4  ['PROTEIN:3775', 'PROTEIN:8012', 'PROTEIN:1346...  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from reasoner.deductive_reasoner import DeductiveReasoner\n",
    "from TruthValue import TruthValue\n",
    "\n",
    "reasoner = DeductiveReasoner()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "top_k = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1332/1332 [00:15<00:00, 83.26it/s]\n"
     ]
    }
   ],
   "source": [
    "results = []\n",
    "for index, row in tqdm(data.iterrows(), total=len(data)):\n",
    "    proteins = row['predicted_proteins']\n",
    "    proteins = json.loads(proteins.replace(\"'\", \"\\\"\"))\n",
    "    proteins = [(p, TruthValue(1.0, 0.9)) for p in proteins if reasoner.valid_protein(p)]\n",
    "    label = row['diagnoses']\n",
    "    label = json.loads(label.replace(\"'\", \"\\\"\"))\n",
    "    result = reasoner.deductive_reasoning(proteins)\n",
    "    result = sorted(result, key=lambda x: x[1].e, reverse=True)\n",
    "    result = result[:top_k]\n",
    "    results.append((index, result, label))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "pickle.dump(results, open(\"oregano/reasoning_results.pkl\", \"wb\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# results_df = pd.DataFrame(results, columns=['index', 'results'])\n",
    "# results_df.set_index('index', inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = pickle.load(open(\"oregano/reasoning_results.pkl\", \"rb\"))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Top 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "overlapped_results = []\n",
    "for index, result, label in results:\n",
    "    overlapped = set(r[0] for r in result) - set(label) \n",
    "    is_overlapped = len(overlapped) > 0\n",
    "    overlapped_results.append((index, is_overlapped))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 87.39%\n"
     ]
    }
   ],
   "source": [
    "acc_top5 = sum([r[1] for r in overlapped_results]) / len(overlapped_results) * 100\n",
    "print(f\"Accuracy: {acc_top5:.2f}%\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Top 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "overlapped_results = []\n",
    "for index, result, label in results:\n",
    "    overlapped = set(r[0] for r in result[:1]) - set(label) \n",
    "    is_overlapped = len(overlapped) > 0\n",
    "    overlapped_results.append((index, is_overlapped))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 80.93%\n"
     ]
    }
   ],
   "source": [
    "acc_top1 = sum([r[1] for r in overlapped_results]) / len(overlapped_results) * 100\n",
    "print(f\"Accuracy: {acc_top1:.2f}%\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Final Report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Top 5 Accuracy: 87.39%\n",
      "Top 1 Accuracy: 80.93%\n"
     ]
    }
   ],
   "source": [
    "print(f\"Top 5 Accuracy: {acc_top5:.2f}%\")\n",
    "print(f\"Top 1 Accuracy: {acc_top1:.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "RL",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
