{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "16fd3e8a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.2 0 Family.owl\n",
      "precision:  0.8123526455552336\n",
      "recall:  0.784583283641423\n",
      "0.2 0 Family2.owl\n",
      "precision:  0.3817576102207498\n",
      "recall:  0.6907062339966006\n",
      "0.2 0 glycordf.glycordf.14.owl.xml\n",
      "precision:  0.6560287294325595\n",
      "recall:  0.5723743386243386\n",
      "0.2 0 nihss.national-institutes-of-health-stroke-scale-ontology.11.owl.xml\n",
      "precision:  0.9804999999999999\n",
      "recall:  0.7525793650793651\n",
      "0.2 0 sso.syndromic-surveillance-ontology.1.owl.xml\n",
      "precision:  1.0\n",
      "recall:  0.815340909090909\n",
      "0.4 0 Family.owl\n"
     ]
    },
    {
     "ename": "FileNotFoundError",
     "evalue": "[Errno 2] No such file or directory: 'hGodel_S_output_alpha0.8/mask_0.4/Family.owl.cEmb.pkl'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mFileNotFoundError\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[1], line 41\u001b[0m\n\u001b[1;32m     39\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m file_name \u001b[38;5;129;01min\u001b[39;00m base_names:\n\u001b[1;32m     40\u001b[0m     \u001b[38;5;28mprint\u001b[39m(rate, idx, file_name)\n\u001b[0;32m---> 41\u001b[0m     cEmb \u001b[38;5;241m=\u001b[39m pickle\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;28mopen\u001b[39m(out_path\u001b[38;5;241m+\u001b[39mfile_name\u001b[38;5;241m+\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m.cEmb.pkl\u001b[39m\u001b[38;5;124m\"\u001b[39m,\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrb\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n\u001b[1;32m     42\u001b[0m     rEmb \u001b[38;5;241m=\u001b[39m pickle\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;28mopen\u001b[39m(out_path\u001b[38;5;241m+\u001b[39mfile_name\u001b[38;5;241m+\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m.rEmb.pkl\u001b[39m\u001b[38;5;124m\"\u001b[39m,\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrb\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n\u001b[1;32m     43\u001b[0m     masked_cEmb \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;28mopen\u001b[39m(info_path\u001b[38;5;241m+\u001b[39mfile_name\u001b[38;5;241m+\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m.masked_cEmb.npy\u001b[39m\u001b[38;5;124m\"\u001b[39m,\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrb\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n",
      "File \u001b[0;32m/data2/kristal/anaconda/lib/python3.11/site-packages/IPython/core/interactiveshell.py:310\u001b[0m, in \u001b[0;36m_modified_open\u001b[0;34m(file, *args, **kwargs)\u001b[0m\n\u001b[1;32m    303\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m file \u001b[38;5;129;01min\u001b[39;00m {\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m2\u001b[39m}:\n\u001b[1;32m    304\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m    305\u001b[0m         \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIPython won\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt let you open fd=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfile\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m by default \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    306\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mas it is likely to crash IPython. If you know what you are doing, \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    307\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124myou can use builtins\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m open.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    308\u001b[0m     )\n\u001b[0;32m--> 310\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m io_open(file, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
      "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'hGodel_S_output_alpha0.8/mask_0.4/Family.owl.cEmb.pkl'"
     ]
    }
   ],
   "source": [
    "# Evaluate the learned DF-ALC embeddings on the complex query answering (CQA) task and\n",
    "# collect precision / recall for every (mask rate, ontology) pair.\n",
    "import pickle\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "from Evaluation import CQAnswering\n",
    "from model import DFALC\n",
    "\n",
    "device = torch.device(\"cpu\")\n",
    "\n",
    "# ---- configuration ----\n",
    "depth = 2                               # selects input/<ontology>.depth_<depth>.queries / .answers\n",
    "SCORE_ALPHA = 0.8                       # alpha passed to CQAnswering.get_score\n",
    "OUTPUT_ROOT = \"hLukas_output_alpha0.8\"  # holds one mask_<rate>/ sub-directory per mask rate\n",
    "SAVE_FIGURE = False                     # True: draw the comparison plot and save it to cqa.png\n",
    "\n",
    "mask_rates = [0.2, 0.4, 0.6, 0.8]\n",
    "base_names = [\n",
    "    'Family.owl',\n",
    "    'Family2.owl',\n",
    "    'glycordf.glycordf.14.owl.xml',\n",
    "    # 'nifdys.neuroscience-information-framework-nif-dysfunction-ontlogy.14.owl.xml',\n",
    "    'nihss.national-institutes-of-health-stroke-scale-ontology.11.owl.xml',\n",
    "    # 'ontodm-core.ontology-of-core-data-mining-entities.6.owl.xml',\n",
    "    'sso.syndromic-surveillance-ontology.1.owl.xml',\n",
    "]\n",
    "# Display name of each ontology, in the same order as base_names.\n",
    "names = [\"Family\", \"Family2\", \"GlycoRDF\", \"Nihss\", \"Sso\"]  # ,\"Nifdys\",\"Ontodm\"\n",
    "assert len(names) == len(base_names), \"names must line up one-to-one with base_names\"\n",
    "display_name = dict(zip(base_names, names))\n",
    "\n",
    "# model_name[i] labels the DFALC model built for out_pathes[i]; index 0 is reported as\n",
    "# DF-ALC (\"our_*\"), any later index as the baseline (\"base_*\").\n",
    "model_name = [\"Rule\"]\n",
    "RESULT_COLUMNS = [\"Precision\", \"Recall\", \"mask_rate\", \"Ontology\"]\n",
    "\n",
    "\n",
    "def load_pickle(path):\n",
    "    \"\"\"Return the object unpickled from ``path``, closing the file afterwards.\n",
    "\n",
    "    pickle can execute arbitrary code: only load files written by this project's own runs.\n",
    "    \"\"\"\n",
    "    with open(path, \"rb\") as fh:\n",
    "        return pickle.load(fh)\n",
    "\n",
    "\n",
    "def evaluate_ontology(out_path, info_path, file_name, model_label):\n",
    "    \"\"\"Score one ontology's learned embeddings with CQAnswering and return (precision, recall).\n",
    "\n",
    "    ``out_path`` holds the learned ``<file_name>.cEmb.pkl`` / ``.rEmb.pkl``; ``info_path`` holds\n",
    "    the masked embeddings (``.npy``) and the c2id / r2id / i2id index maps (``.pkl``).\n",
    "    Raises FileNotFoundError if any of those files is missing.\n",
    "    \"\"\"\n",
    "    cEmb = load_pickle(out_path + file_name + \".cEmb.pkl\")\n",
    "    rEmb = load_pickle(out_path + file_name + \".rEmb.pkl\")\n",
    "    masked_cEmb = np.load(info_path + file_name + \".masked_cEmb.npy\")\n",
    "    masked_rEmb = np.load(info_path + file_name + \".masked_rEmb.npy\")\n",
    "    c2id = load_pickle(info_path + file_name + \".c2id.pkl\")\n",
    "    r2id = load_pickle(info_path + file_name + \".r2id.pkl\")\n",
    "    i2id = load_pickle(info_path + file_name + \".i2id.pkl\")\n",
    "\n",
    "    model = DFALC({}, len(c2id), len(r2id), masked_cEmb, masked_rEmb, device,\n",
    "                  name=model_label, loss_weight=0.5).to(device)\n",
    "    query_prefix = \"input/\" + file_name + \".depth_\" + str(depth)\n",
    "    cqa = CQAnswering(query_prefix + \".queries\", query_prefix + \".answers\", c2id, r2id, i2id)\n",
    "    return cqa.get_score(model, torch.tensor(cEmb), torch.tensor(rEmb), alpha=SCORE_ALPHA)\n",
    "\n",
    "\n",
    "def results_frame(records):\n",
    "    \"\"\"Build a Precision / Recall / mask_rate / Ontology DataFrame with float metric columns.\"\"\"\n",
    "    return pd.DataFrame(records, columns=RESULT_COLUMNS).astype({\"Precision\": \"float\", \"Recall\": \"float\"})\n",
    "\n",
    "\n",
    "def plot_results(our_data, base_data, path=\"cqa.png\"):\n",
    "    \"\"\"Plot precision (left) and recall (right) per ontology, DF-ALC vs. baseline; save to ``path``.\"\"\"\n",
    "    import matplotlib.pyplot as plt  # imported lazily: only needed when SAVE_FIGURE is set\n",
    "    import seaborn as sns\n",
    "\n",
    "    sns.set_theme(style=\"whitegrid\", palette=\"pastel\", font_scale=5)\n",
    "    fig, (ax_precision, ax_recall) = plt.subplots(1, 2, figsize=(65, 13))\n",
    "    for metric, ax, legend in ((\"Precision\", ax_precision, False), (\"Recall\", ax_recall, \"auto\")):\n",
    "        for frame, colour in ((our_data, \"#00397E\"), (base_data, \"#F66A2A\")):\n",
    "            if frame.empty:  # e.g. no baseline was evaluated\n",
    "                continue\n",
    "            sns.lineplot(data=frame, x=\"Ontology\", y=metric, hue=\"mask_rate\",\n",
    "                         palette=[colour] * frame[\"mask_rate\"].nunique(), style=\"mask_rate\",\n",
    "                         ax=ax, legend=legend, linewidth=8)\n",
    "        ax.set(ylim=(0, 1))\n",
    "    leg = ax_recall.legend(bbox_to_anchor=(1.02, 1), loc=\"upper left\")\n",
    "    # matplotlib 3.7 renamed Legend.legendHandles to legend_handles (old name removed in 3.9).\n",
    "    handles = leg.legend_handles if hasattr(leg, \"legend_handles\") else leg.legendHandles\n",
    "    for legobj in handles:\n",
    "        legobj.set_linewidth(8.0)\n",
    "    fig.tight_layout(pad=0.05)\n",
    "    fig.savefig(path, dpi=400)\n",
    "    return fig\n",
    "\n",
    "\n",
    "# ---- evaluation ----\n",
    "our_precision, our_recall, our_records = [], [], []\n",
    "base_precision, base_recall, base_records = [], [], []\n",
    "for rate in mask_rates:\n",
    "    info_path = OUTPUT_ROOT + \"/mask_\" + str(rate) + \"/\"\n",
    "    out_pathes = [info_path]\n",
    "    for idx, out_path in enumerate(out_pathes):\n",
    "        if idx == 0:\n",
    "            precisions, recalls, records, method = our_precision, our_recall, our_records, \"DF-ALC\"\n",
    "        else:\n",
    "            precisions, recalls, records, method = base_precision, base_recall, base_records, \"Base\"\n",
    "        for file_name in base_names:\n",
    "            print(rate, idx, file_name)\n",
    "            precision, recall = evaluate_ontology(out_path, info_path, file_name, model_name[idx])\n",
    "            precisions.append(precision)\n",
    "            recalls.append(recall)\n",
    "            records.append({\n",
    "                \"Precision\": precision,\n",
    "                \"Recall\": recall,\n",
    "                # Label from the actual rate rather than assuming mask rates step by 0.2.\n",
    "                \"mask_rate\": \"mask {:.0f}%, {}\".format(rate * 100, method),\n",
    "                \"Ontology\": display_name[file_name],\n",
    "            })\n",
    "            print(\"precision: \", precision)\n",
    "            print(\"recall: \", recall)\n",
    "\n",
    "our_data = results_frame(our_records)\n",
    "# Empty (rather than failing) when out_pathes has no baseline entry; building it from the\n",
    "# ragged [precisions, recalls, labels, names] numpy array would fail in that case.\n",
    "base_data = results_frame(base_records)\n",
    "\n",
    "if SAVE_FIGURE:\n",
    "    plot_results(our_data, base_data)\n",
    "our_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23c44645-6e6e-4e62-a079-c79f02646e63",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "b5d4ea6110d76bf407abdf3fc85b4f9a1bbb4f7f6454d667a509d28831b3322d"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
