{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from pathlib import Path\n",
    "import json\n",
    "from tqdm import tqdm\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "filepath = Path(\"Procedure-Protein-Mapping/Output-Data/test_with_predicted_proteins.csv\")\n",
    "data = pd.read_csv(filepath)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>patient_id</th>\n",
       "      <th>visit</th>\n",
       "      <th>diagnoses</th>\n",
       "      <th>procedures</th>\n",
       "      <th>medications</th>\n",
       "      <th>proteins</th>\n",
       "      <th>SYMPTOMS</th>\n",
       "      <th>predicted_proteins</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>6179</td>\n",
       "      <td>1</td>\n",
       "      <td>['4373', 'V103', '71690', '4019', '2720', '7810']</td>\n",
       "      <td>['3972', '8841']</td>\n",
       "      <td>['N02B', 'A12C', 'A01A', 'C10A', 'A06A', 'C02D...</td>\n",
       "      <td>['PROTEIN:4876', 'PROTEIN:11410', 'PROTEIN:132...</td>\n",
       "      <td>['Unknown', 'Multiple and unspecified open wou...</td>\n",
       "      <td>['PROTEIN:21977']</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>213</td>\n",
       "      <td>7</td>\n",
       "      <td>['99812', '51881', '5856', '99681', '40391', '...</td>\n",
       "      <td>['3995', '5491']</td>\n",
       "      <td>['A07A', 'N02B', 'B01A', 'A01A', 'C08C', 'C10A...</td>\n",
       "      <td>['PROTEIN:3897', 'PROTEIN:6839']</td>\n",
       "      <td>['Unknown', 'Unknown']</td>\n",
       "      <td>['PROTEIN:21977', 'PROTEIN:3775', 'PROTEIN:77'...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3856</td>\n",
       "      <td>1</td>\n",
       "      <td>['5551', '486', '56981', '5119', '5180', '9974...</td>\n",
       "      <td>['4582', '415', '4592', '3491']</td>\n",
       "      <td>['B05C', 'A12A', 'A12C', 'M01A', 'N01A', 'N02B...</td>\n",
       "      <td>['PROTEIN:3775', 'PROTEIN:6839', 'PROTEIN:1661...</td>\n",
       "      <td>['Unknown', 'Unknown', 'Compression of vein', ...</td>\n",
       "      <td>['PROTEIN:14889', 'PROTEIN:4358', 'PROTEIN:482...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2907</td>\n",
       "      <td>6</td>\n",
       "      <td>['51881', '42821', '4280', '49122', '29590', '...</td>\n",
       "      <td>['9671', '9604']</td>\n",
       "      <td>['A07A', 'A12B', 'C03C', 'B01A', 'C02D', 'A02B...</td>\n",
       "      <td>['PROTEIN:3897', 'PROTEIN:6839']</td>\n",
       "      <td>['Poisoning by chloral hydrate group', 'Poison...</td>\n",
       "      <td>['PROTEIN:21977', 'PROTEIN:20611', 'PROTEIN:49...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1576</td>\n",
       "      <td>3</td>\n",
       "      <td>['1125', '78552', '99592', '56722', '5720', '5...</td>\n",
       "      <td>['5491', '3897', '3893']</td>\n",
       "      <td>['A01A', 'A12A', 'B05C', 'A12C', 'C07A', 'N02B...</td>\n",
       "      <td>['PROTEIN:1321', 'PROTEIN:19534', 'PROTEIN:154...</td>\n",
       "      <td>['Unknown', 'Deaf, nonspeaking, not elsewhere ...</td>\n",
       "      <td>['PROTEIN:3775', 'PROTEIN:8012', 'PROTEIN:1346...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   patient_id  visit                                          diagnoses  \\\n",
       "0        6179      1  ['4373', 'V103', '71690', '4019', '2720', '7810']   \n",
       "1         213      7  ['99812', '51881', '5856', '99681', '40391', '...   \n",
       "2        3856      1  ['5551', '486', '56981', '5119', '5180', '9974...   \n",
       "3        2907      6  ['51881', '42821', '4280', '49122', '29590', '...   \n",
       "4        1576      3  ['1125', '78552', '99592', '56722', '5720', '5...   \n",
       "\n",
       "                        procedures  \\\n",
       "0                 ['3972', '8841']   \n",
       "1                 ['3995', '5491']   \n",
       "2  ['4582', '415', '4592', '3491']   \n",
       "3                 ['9671', '9604']   \n",
       "4         ['5491', '3897', '3893']   \n",
       "\n",
       "                                         medications  \\\n",
       "0  ['N02B', 'A12C', 'A01A', 'C10A', 'A06A', 'C02D...   \n",
       "1  ['A07A', 'N02B', 'B01A', 'A01A', 'C08C', 'C10A...   \n",
       "2  ['B05C', 'A12A', 'A12C', 'M01A', 'N01A', 'N02B...   \n",
       "3  ['A07A', 'A12B', 'C03C', 'B01A', 'C02D', 'A02B...   \n",
       "4  ['A01A', 'A12A', 'B05C', 'A12C', 'C07A', 'N02B...   \n",
       "\n",
       "                                            proteins  \\\n",
       "0  ['PROTEIN:4876', 'PROTEIN:11410', 'PROTEIN:132...   \n",
       "1                   ['PROTEIN:3897', 'PROTEIN:6839']   \n",
       "2  ['PROTEIN:3775', 'PROTEIN:6839', 'PROTEIN:1661...   \n",
       "3                   ['PROTEIN:3897', 'PROTEIN:6839']   \n",
       "4  ['PROTEIN:1321', 'PROTEIN:19534', 'PROTEIN:154...   \n",
       "\n",
       "                                            SYMPTOMS  \\\n",
       "0  ['Unknown', 'Multiple and unspecified open wou...   \n",
       "1                             ['Unknown', 'Unknown']   \n",
       "2  ['Unknown', 'Unknown', 'Compression of vein', ...   \n",
       "3  ['Poisoning by chloral hydrate group', 'Poison...   \n",
       "4  ['Unknown', 'Deaf, nonspeaking, not elsewhere ...   \n",
       "\n",
       "                                  predicted_proteins  \n",
       "0                                  ['PROTEIN:21977']  \n",
       "1  ['PROTEIN:21977', 'PROTEIN:3775', 'PROTEIN:77'...  \n",
       "2  ['PROTEIN:14889', 'PROTEIN:4358', 'PROTEIN:482...  \n",
       "3  ['PROTEIN:21977', 'PROTEIN:20611', 'PROTEIN:49...  \n",
       "4  ['PROTEIN:3775', 'PROTEIN:8012', 'PROTEIN:1346...  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "proteins = data['proteins']\n",
    "proteins = proteins.apply(lambda x: json.loads(str(x).replace(\"'\", '\"')))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "predicted_proteins = data['predicted_proteins']\n",
    "predicted_proteins = predicted_proteins.apply(lambda x: json.loads(str(x).replace(\"'\", '\"')))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "rates = []\n",
    "correct = []\n",
    "for y, label in zip(predicted_proteins, proteins):\n",
    "    y = set(y)\n",
    "    label = set(label)\n",
    "    correct.append(len(y - label) > 0)\n",
    "    rate = len(y & label)/len(y | label)\n",
    "    rates.append(rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overlapping rate: 0.03212547319078474\n"
     ]
    }
   ],
   "source": [
    "print(f\"Overlapping rate: {sum(rates)/len(rates)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Correctness rate: 0.8731231231231231\n"
     ]
    }
   ],
   "source": [
    "print(f\"Correctness rate: {sum(correct)/len(correct)}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "RL",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
