{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "16fd3e8a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.2 0 Family2.owl\n",
      "precision:  0.1\n",
      "recall:  0.06936111111111112\n",
      "0.2 0 glycan.owl\n",
      "precision:  0.7\n",
      "recall:  0.5214285714285715\n",
      "0.2 0 glycordf.glycordf.14.owl.xml\n",
      "precision:  0.35\n",
      "recall:  0.1642857142857143\n",
      "0.4 0 Family2.owl\n",
      "precision:  0.1\n",
      "recall:  0.03469444444444444\n",
      "0.4 0 glycan.owl\n",
      "precision:  0.7\n",
      "recall:  0.3547619047619047\n",
      "0.4 0 glycordf.glycordf.14.owl.xml\n",
      "precision:  0.35\n",
      "recall:  0.15912698412698412\n",
      "0.6 0 Family2.owl\n",
      "precision:  0.1\n",
      "recall:  0.017583333333333333\n",
      "0.6 0 glycan.owl\n",
      "precision:  0.6\n",
      "recall:  0.21190476190476187\n",
      "0.6 0 glycordf.glycordf.14.owl.xml\n",
      "precision:  0.35\n",
      "recall:  0.12896825396825395\n",
      "0.8 0 Family2.owl\n",
      "precision:  0.05\n",
      "recall:  0.009027777777777777\n",
      "0.8 0 glycan.owl\n",
      "precision:  0.5\n",
      "recall:  0.18333333333333335\n",
      "0.8 0 glycordf.glycordf.14.owl.xml\n",
      "precision:  0.35\n",
      "recall:  0.06388888888888888\n"
     ]
    },
    {
     "ename": "ValueError",
     "evalue": "setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (4,) + inhomogeneous part.",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[1], line 68\u001b[0m\n\u001b[1;32m     64\u001b[0m             \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mprecision: \u001b[39m\u001b[38;5;124m\"\u001b[39m, precision)\n\u001b[1;32m     65\u001b[0m             \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrecall: \u001b[39m\u001b[38;5;124m\"\u001b[39m, recall)\n\u001b[0;32m---> 68\u001b[0m our_data \u001b[38;5;241m=\u001b[39m pd\u001b[38;5;241m.\u001b[39mDataFrame(np\u001b[38;5;241m.\u001b[39marray([our_precision,our_recall,[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmask \u001b[39m\u001b[38;5;132;01m{:.0f}\u001b[39;00m\u001b[38;5;124m%\u001b[39m\u001b[38;5;124m, DF-ALC\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\u001b[38;5;241m0.2\u001b[39m\u001b[38;5;241m*\u001b[39m(i\u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m\u001b[38;5;28mlen\u001b[39m(base_names)\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m100\u001b[39m) \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mlen\u001b[39m(base_names)\u001b[38;5;241m*\u001b[39m\u001b[38;5;28mlen\u001b[39m(mask_rates))],names\u001b[38;5;241m*\u001b[39m\u001b[38;5;28mlen\u001b[39m(mask_rates)])\u001b[38;5;241m.\u001b[39mT, columns \u001b[38;5;241m=\u001b[39m [\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPrecision\u001b[39m\u001b[38;5;124m\"\u001b[39m,\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mRecall\u001b[39m\u001b[38;5;124m\"\u001b[39m,\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmask_rate\u001b[39m\u001b[38;5;124m\"\u001b[39m,\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mOntology\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[1;32m     69\u001b[0m our_data[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPrecision\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m 
our_data[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPrecision\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39mastype(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m     70\u001b[0m our_data[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mRecall\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m our_data[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mRecall\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39mastype(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
      "\u001b[0;31mValueError\u001b[0m: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (4,) + inhomogeneous part."
     ]
    }
   ],
   "source": [
    "\n",
    "#%matplotlib inline\n",
    "# Evaluate with CQA task and show the results\n",
    "#import matplotlib.pyplot as plt\n",
    "\n",
    "import pickle\n",
    "import numpy as np\n",
    "from Evaluation import CQAnswering\n",
    "from model import DFALC\n",
    "import torch\n",
    "#import seaborn as sns \n",
    "import pandas as pd\n",
    "device = torch.device(\"cpu\")\n",
    "\n",
    "#sns.set_theme(style=\"whitegrid\", palette=\"pastel\")\n",
    "#sns.set(font_scale=5)\n",
    "# Selects which query/answer files are read: input/<file>.depth_<depth>.queries / .answers\n",
    "depth = 2\n",
    "\n",
    "# Each mask rate selects one Box2EL_output_alpha0.8/mask_<rate>/ directory.\n",
    "mask_rates = [0.2, 0.4, 0.6, 0.8]\n",
    "\n",
    "# Ontology files to evaluate; (un)comment entries to change the run.\n",
    "base_names = [\n",
    "    #'Family.owl',\n",
    "    'Family2.owl',\n",
    "    'glycan.owl',\n",
    "    'glycordf.glycordf.14.owl.xml',\n",
    "    # 'nifdys.neuroscience-information-framework-nif-dysfunction-ontlogy.14.owl.xml',\n",
    "    #'nihss.national-institutes-of-health-stroke-scale-ontology.11.owl.xml',\n",
    "    # 'ontodm-core.ontology-of-core-data-mining-entities.6.owl.xml',\n",
    "    #'sso.syndromic-surveillance-ontology.1.owl.xml',\n",
    "]\n",
    "\n",
    "# Plot label for every ontology file that can be enabled in base_names.\n",
    "display_names = {\n",
    "    'Family.owl': \"Family\",\n",
    "    'Family2.owl': \"Family2\",\n",
    "    'glycan.owl': \"Glycan\",\n",
    "    'glycordf.glycordf.14.owl.xml': \"GlycoRDF\",\n",
    "    'nifdys.neuroscience-information-framework-nif-dysfunction-ontlogy.14.owl.xml': \"Nifdys\",\n",
    "    'nihss.national-institutes-of-health-stroke-scale-ontology.11.owl.xml': \"Nihss\",\n",
    "    'ontodm-core.ontology-of-core-data-mining-entities.6.owl.xml': \"Ontodm\",\n",
    "    'sso.syndromic-surveillance-ontology.1.owl.xml': \"Sso\",\n",
    "}\n",
    "\n",
    "# One label per *enabled* ontology, in base_names order. A hand-maintained label\n",
    "# list can drift out of sync with base_names; when it does, the result columns get\n",
    "# different lengths and building the result table fails with a ValueError.\n",
    "# Files without a known label fall back to their file name.\n",
    "names = [display_names.get(file_name, file_name) for file_name in base_names]\n",
    "\n",
    "columns = [\"Precision\", \"Recall\", \"mask_rate\", \"Ontology\"]\n",
    "model_name = [\"Rule\"]  # one entry per element of out_pathes (indexed by idx below)\n",
    "\n",
    "our_precision = []\n",
    "base_precision = []\n",
    "our_recall = []\n",
    "base_recall = []\n",
    "our_rows = []   # one record per (mask rate, ontology) scored from out_pathes[0]\n",
    "base_rows = []  # records for any further out_pathes entries (none are configured right now)\n",
    "\n",
    "\n",
    "def load_pickle(path):\n",
    "    \"\"\"Unpickle the object stored at `path`, closing the file afterwards.\"\"\"\n",
    "    # SECURITY: pickle.load can execute arbitrary code; only load files from trusted runs.\n",
    "    with open(path, \"rb\") as handle:\n",
    "        return pickle.load(handle)\n",
    "\n",
    "\n",
    "for rate in mask_rates:\n",
    "    out_pathes = [\"Box2EL_output_alpha0.8/mask_\" + str(rate) + \"/\"]\n",
    "    info_path = \"Box2EL_output_alpha0.8/mask_\" + str(rate) + \"/\"\n",
    "    for idx, out_path in enumerate(out_pathes):\n",
    "        for file_name, ontology in zip(base_names, names):\n",
    "            print(rate, idx, file_name)\n",
    "            cEmb = load_pickle(out_path + file_name + \".cEmb.pkl\")\n",
    "            rEmb = load_pickle(out_path + file_name + \".rEmb.pkl\")\n",
    "            # np.load opens and closes the .npy file itself when given a path.\n",
    "            masked_cEmb = np.load(info_path + file_name + \".masked_cEmb.npy\")\n",
    "            masked_rEmb = np.load(info_path + file_name + \".masked_rEmb.npy\")\n",
    "            c2id = load_pickle(info_path + file_name + \".c2id.pkl\")\n",
    "            r2id = load_pickle(info_path + file_name + \".r2id.pkl\")\n",
    "            i2id = load_pickle(info_path + file_name + \".i2id.pkl\")\n",
    "\n",
    "            model = DFALC({}, len(c2id), len(r2id), masked_cEmb, masked_rEmb, device, name=model_name[idx], loss_weight=0.5).to(device)\n",
    "            cqa = CQAnswering(\"input/\" + file_name + \".depth_\" + str(depth) + \".queries\", \"input/\" + file_name + \".depth_\" + str(depth) + \".answers\", c2id, r2id, i2id)\n",
    "            precision, recall = cqa.get_score(model, torch.tensor(cEmb), torch.tensor(rEmb), alpha=0.8)\n",
    "\n",
    "            # The first output path is reported as DF-ALC, any other as the baseline.\n",
    "            is_ours = idx == 0\n",
    "            row = {\n",
    "                \"Precision\": float(precision),\n",
    "                \"Recall\": float(recall),\n",
    "                # Label with the rate actually evaluated instead of re-deriving it from a row index.\n",
    "                \"mask_rate\": \"mask {:.0f}%, {}\".format(rate * 100, \"DF-ALC\" if is_ours else \"Base\"),\n",
    "                \"Ontology\": ontology,\n",
    "            }\n",
    "            if is_ours:\n",
    "                our_precision.append(precision)\n",
    "                our_recall.append(recall)\n",
    "                our_rows.append(row)\n",
    "            else:\n",
    "                base_precision.append(precision)\n",
    "                base_recall.append(recall)\n",
    "                base_rows.append(row)\n",
    "            print(\"precision: \", precision)\n",
    "            print(\"recall: \", recall)\n",
    "\n",
    "\n",
    "# Build the tables from records so every column has the same length by construction.\n",
    "# Passing `columns` keeps the schema even when a table has no rows (base_data at the moment).\n",
    "our_data = pd.DataFrame(our_rows, columns=columns).astype({\"Precision\": \"float\", \"Recall\": \"float\"})\n",
    "base_data = pd.DataFrame(base_rows, columns=columns).astype({\"Precision\": \"float\", \"Recall\": \"float\"})\n",
    "'''\n",
    "fig, axes = plt.subplots(1,2,figsize=(65,13))\n",
    "axes1, axes2 = axes.flatten()\n",
    "ax1=sns.lineplot(data=our_data,x=\"Ontology\",y=\"Precision\",hue=\"mask_rate\",palette=[\"#00397E\"]*4,style=\"mask_rate\",ax=axes1,legend=False,linewidth=8)\n",
    "ax2=sns.lineplot(data=base_data,x=\"Ontology\",y=\"Precision\",hue=\"mask_rate\",palette=[\"#F66A2A\"]*4,style=\"mask_rate\",ax=axes1,legend=False,linewidth=8)\n",
    "ax1.set(ylim=(0,1))\n",
    "ax3=sns.lineplot(data=our_data,x=\"Ontology\",y=\"Recall\",hue=\"mask_rate\",palette=[\"#00397E\"]*4,style=\"mask_rate\",ax=axes2,linewidth=8)\n",
    "ax4=sns.lineplot(data=base_data,x=\"Ontology\",y=\"Recall\",hue=\"mask_rate\",palette=[\"#F66A2A\"]*4,style=\"mask_rate\",ax=axes2,linewidth=8)\n",
    "ax3.set(ylim=(0,1))\n",
    "leg = plt.legend(bbox_to_anchor=(1.02,1),loc=\"upper left\")\n",
    "for legobj in leg.legendHandles:\n",
    "    legobj.set_linewidth(8.0)\n",
    "plt.tight_layout(pad=0.05)\n",
    "# ax2.set(ylim=(0,1))\n",
    "# plt.show()\n",
    "fig.savefig(\"cqa.png\",dpi=400)\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23c44645-6e6e-4e62-a079-c79f02646e63",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "b5d4ea6110d76bf407abdf3fc85b4f9a1bbb4f7f6454d667a509d28831b3322d"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
