{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "convertible-ethics",
   "metadata": {},
   "source": [
    "From the Ontolearn root directory, run the NCES data commands listed in the [installation guide](https://ontolearn-docs-dice-group.netlify.app/usage/02_installation#download-external-files) to download the pretrained models and datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "intended-bullet",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ontolearn.concept_learner import NCES\n",
    "from ontolearn.knowledge_base import KnowledgeBase\n",
    "from owlapy.parser import DLSyntaxParser\n",
    "from owlapy.render import DLSyntaxObjectRenderer\n",
    "import sys\n",
    "sys.path.append(\"examples/\")\n",
    "from ontolearn.metrics import F1\n",
    "from quality_functions import quality\n",
    "import time\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "serial-might",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pretrained SetTransformer synthesizer for the family knowledge base;\n",
    "# candidate expressions are scored with F1.\n",
    "nces = NCES(knowledge_base_path=\"../NCESData/family/family.owl\", quality_func=F1(), num_predictions=100, learner_name=\"SetTransformer\",\n",
    "     path_of_embeddings=\"../NCESData/family/embeddings/ConEx_entity_embeddings.csv\", load_pretrained=True, max_length=48, proj_dim=128, rnn_n_layers=2, drop_prob=0.1, num_heads=4, num_seeds=1, num_inds=32, pretrained_model_name=\"SetTransformer\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "precise-tobacco",
   "metadata": {},
   "outputs": [],
   "source": [
    "KB = KnowledgeBase(path=nces.knowledge_base_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "helpful-bonus",
   "metadata": {},
   "outputs": [],
   "source": [
    "dl_syntax_renderer = DLSyntaxObjectRenderer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "pleased-circular",
   "metadata": {},
   "outputs": [],
   "source": [
    "atomic_classes = [dl_syntax_renderer.render(a) for a in KB.ontology.classes_in_signature()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "published-supplement",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Brother',\n",
       " 'Male',\n",
       " 'PersonWithASibling',\n",
       " 'Child',\n",
       " 'Person',\n",
       " 'Daughter',\n",
       " 'Female',\n",
       " 'Father',\n",
       " 'Parent',\n",
       " 'Grandchild',\n",
       " 'Granddaughter',\n",
       " 'Grandfather',\n",
       " 'Grandparent',\n",
       " 'Grandmother',\n",
       " 'Grandson',\n",
       " 'Mother',\n",
       " 'Sister',\n",
       " 'Son']"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "atomic_classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "greek-transition",
   "metadata": {},
   "outputs": [],
   "source": [
    "dl_parser = DLSyntaxParser(nces.kb_namespace)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "administrative-disorder",
   "metadata": {},
   "outputs": [],
   "source": [
    "brother = dl_parser.parse('Brother')\n",
    "daughter = dl_parser.parse('Daughter')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "instructional-syndrome",
   "metadata": {},
   "source": [
    "#### Input examples can be sets or lists"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "coated-pressing",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Positive examples: every individual that is a Brother or a Daughter\n",
    "pos = set(KB.individuals(brother)).union(set(KB.individuals(daughter)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "current-floor",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Negative examples: all remaining individuals (pos is already a set)\n",
    "neg = set(KB.individuals()) - pos"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "continental-march",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "120"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(neg)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "sound-fitness",
   "metadata": {},
   "source": [
    "#### Prediction with SetTransformer (default model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "solar-desire",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Duration:  0.5292303562164307  seconds\n"
     ]
    }
   ],
   "source": [
    "t0 = time.time()\n",
    "node = list(nces.fit(pos, neg).best_predictions)[0]\n",
    "t1 = time.time()\n",
    "print(\"\\nDuration: \", t1-t0, \" seconds\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "adult-valuation",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "82"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(set(KB.individuals(node.concept)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "collect-gothic",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<class 'ontolearn.search.NCESNode'> at 0x9fe2fe8\tBrother ⊔ Sister ⊔ Daughter\tQuality:1.0\tLength:5\t|Indv.|:82"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "chronic-horse",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 100.0%\n",
      "Precision: 100.0%\n",
      "Recall: 100.0%\n",
      "F1: 100.0%\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(100.0, 100.0, 100.0, 100.0)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "quality(KB, node.concept, pos, neg)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "noble-psychiatry",
   "metadata": {},
   "source": [
    "### Ensemble prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "desirable-auction",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Duration:  0.6503381729125977  seconds\n"
     ]
    }
   ],
   "source": [
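    "# Switch to an ensemble of the three pretrained architectures,\n",
    "# then refresh so the new model selection takes effect.\n",
    "nces.pretrained_model_name = ['SetTransformer','GRU','LSTM']\n",
    "nces.refresh()\n",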
    "t0 = time.time()\n",
    "node = list(nces.fit(pos, neg).best_predictions)[0]\n",
    "t1 = time.time()\n",
    "print(\"\\nDuration: \", t1-t0, \" seconds\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "victorian-amateur",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 100.0%\n",
      "Precision: 100.0%\n",
      "Recall: 100.0%\n",
      "F1: 100.0%\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(100.0, 100.0, 100.0, 100.0)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "quality(KB, node.concept, pos, neg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "domestic-breakfast",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<class 'ontolearn.search.NCESNode'> at 0x9fe2c38\tBrother ⊔ Daughter\tQuality:1.0\tLength:3\t|Indv.|:82"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "careful-works",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<class 'ontolearn.search.NCESNode'> at 0x9fe2c38\tBrother ⊔ Daughter\tQuality:1.0\tLength:3\t|Indv.|:82,\n",
       " <class 'ontolearn.search.NCESNode'> at 0xa05c220\tBrother ⊔ Daughter\tQuality:1.0\tLength:3\t|Indv.|:82,\n",
       " <class 'ontolearn.search.NCESNode'> at 0x9fe34b0\tBrother ⊔ Daughter\tQuality:1.0\tLength:3\t|Indv.|:82,\n",
       " <class 'ontolearn.search.NCESNode'> at 0x9fe3ca8\tPersonWithASibling ⊔ Daughter ⊔ Sister\tQuality:1.0\tLength:5\t|Indv.|:82,\n",
       " <class 'ontolearn.search.NCESNode'> at 0x9fe3e08\tBrother ⊔ Daughter\tQuality:1.0\tLength:3\t|Indv.|:82]"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nces.best_predictions[:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "capable-quality",
   "metadata": {},
   "source": [
    "### Complex learning problems, potentially without an exact solution"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "honest-empire",
   "metadata": {},
   "source": [
    "#### First learning problem"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "novel-protest",
   "metadata": {},
   "outputs": [],
   "source": [
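    "# Randomly sampled examples: an exact solution is not guaranteed to exist.\n",
    "all_individuals = set(KB.individuals())\n",
    "pos = set(random.sample(list(all_individuals), 150))\n",
    "remaining = all_individuals-pos\n",
    "# Cap the number of negatives at the number of individuals left over\n",
    "neg = set(random.sample(list(remaining), min(100, len(remaining))))"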
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "amended-found",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['SetTransformer', 'GRU', 'LSTM']"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nces.pretrained_model_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "elder-fever",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Duration:  0.6927070617675781  seconds\n"
     ]
    }
   ],
   "source": [
    "t0 = time.time()\n",
    "node = list(nces.fit(pos, neg).best_predictions)[0]\n",
    "t1 = time.time()\n",
    "print(\"\\nDuration: \", t1-t0, \" seconds\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "chinese-avatar",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<class 'ontolearn.search.NCESNode'> at 0xcb9e860\tPerson ⊔ Son\tQuality:0.85227\tLength:3\t|Indv.|:202"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "accessory-excess",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 74.25699999999999%\n",
      "Precision: 74.25699999999999%\n",
      "Recall: 100.0%\n",
      "F1: 85.227%\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(74.25699999999999, 74.25699999999999, 100.0, 85.227)"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "quality(KB, node.concept, pos, neg)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "varied-danger",
   "metadata": {},
   "source": [
    "#### Second learning problem"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "cardiac-webmaster",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fewer positives than negatives this time: the inputs need not be balanced.\n",
    "pos = set(random.sample(list(all_individuals), 80))\n",
    "remaining = all_individuals-pos\n",
    "neg = set(random.sample(list(remaining), min(150, len(remaining))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "fantastic-piece",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Duration:  0.7965126037597656  seconds\n"
     ]
    }
   ],
   "source": [
    "t0 = time.time()\n",
    "node = list(nces.fit(pos, neg).best_predictions)[0]\n",
    "t1 = time.time()\n",
    "print(\"\\nDuration: \", t1-t0, \" seconds\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "pregnant-prague",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 57.921%\n",
      "Precision: 47.934%\n",
      "Recall: 72.5%\n",
      "F1: 57.711%\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(57.921, 47.934, 72.5, 57.711)"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "quality(KB, node.concept, pos, neg)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "latin-spanking",
   "metadata": {},
   "source": [
    "## Important note\n",
    "\n",
    "- None of the synthesized expressions is present in the knowledge base.\n",
    "- NCES synthesizes solutions by leveraging what it learned from the training data.\n",
    "- The inputs (positive/negative examples) need not be balanced.\n",
    "- NCES can solve multiple learning problems at the same time (through broadcasting on matrix operations in its neural network component); see nces_notebook1.ipynb.\n",
    "- Since LSTM and GRU are not permutation-equivariant, shuffling the input examples can yield different but closely related solutions for these architectures. To allow this, instantiate the NCES class with \"sorted_examples=False\", which is the default."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "onto",
   "language": "python",
   "name": "onto"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
