{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "registered-vietnamese",
   "metadata": {},
   "source": [
    "From the main directory \"Ontolearn\", run the commands for NCES data mentioned [here](https://ontolearn-docs-dice-group.netlify.app/usage/02_installation#download-external-files) to download pretrained models and datasets."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ignored-brunswick",
   "metadata": {},
   "source": [
    "## Inference with NCES"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "operational-boating",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/nkouagou/.conda/envs/onto/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from ontolearn.concept_learner import NCES\n",
    "from ontolearn.knowledge_base import KnowledgeBase\n",
    "from ontolearn.metrics import F1\n",
    "from owlapy.parser import DLSyntaxParser"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "nonprofit-conditions",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/nkouagou/.conda/envs/onto/lib/python3.10/site-packages/torch/cuda/__init__.py:141: UserWarning: CUDA initialization: The NVIDIA driver on your system is too old (found version 6050). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver. (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:108.)\n",
      "  return torch._C._cuda_getDeviceCount() > 0\n"
     ]
    }
   ],
   "source": [
    "nces = NCES(knowledge_base_path=\"../NCESData/family/family.owl\", quality_func=F1(), num_predictions=100, learner_name=\"SetTransformer\",\n",
    "     path_of_embeddings=\"../NCESData/family/embeddings/ConEx_entity_embeddings.csv\", load_pretrained=True, max_length=48, proj_dim=128, rnn_n_layers=2, drop_prob=0.1, num_heads=4, num_seeds=1, num_inds=32, pretrained_model_name=\"SetTransformer\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "private-advocacy",
   "metadata": {},
   "outputs": [],
   "source": [
    "KB = KnowledgeBase(path=nces.knowledge_base_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "optional-persian",
   "metadata": {},
   "outputs": [],
   "source": [
    "dl_parser = DLSyntaxParser(nces.kb_namespace)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "worse-cursor",
   "metadata": {},
   "source": [
    "### Let's learn the concept `Father`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "varied-laundry",
   "metadata": {},
   "outputs": [],
   "source": [
    "father = dl_parser.parse('Father')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "elder-mercury",
   "metadata": {},
   "outputs": [],
   "source": [
    "not_father = dl_parser.parse('¬Father') # For negative examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "precious-kernel",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Positive examples: for every individual retrieved for `father`, keep the part of its `str` after the last \"/\".\n",
    "pos = {ind.str.split(\"/\")[-1] for ind in KB.individuals(father)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "oriental-brighton",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Negative examples: for every individual retrieved for `not_father`, keep the part of its `str` after the last \"/\".\n",
    "neg = {ind.str.split(\"/\")[-1] for ind in KB.individuals(not_father)}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "flush-choir",
   "metadata": {},
   "source": [
    "#### Prediction with SetTransformer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "interesting-sunglasses",
   "metadata": {},
   "outputs": [],
   "source": [
    "nodes = nces.fit(pos, neg).best_hypotheses(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "finished-schema",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<class 'ontolearn.search.NCESNode'> at 0x7485924\tFather\tQuality:1.0\tLength:1\t|Indv.|:60,\n",
       " <class 'ontolearn.search.NCESNode'> at 0xb329d38\tFather ⊔ Father\tQuality:1.0\tLength:3\t|Indv.|:60,\n",
       " <class 'ontolearn.search.NCESNode'> at 0x7c09448\tGrandfather ⊔ Father\tQuality:1.0\tLength:3\t|Indv.|:60]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nodes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "radical-spectrum",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nces.sorted_examples"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "norwegian-nudist",
   "metadata": {},
   "source": [
    "#### Prediction with GRU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "molecular-emperor",
   "metadata": {},
   "outputs": [],
   "source": [
    "nces.pretrained_model_name = 'GRU'\n",
    "nces.refresh()\n",
    "nodes = list(nces.fit(pos, neg).best_predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "mechanical-japanese",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<class 'ontolearn.search.NCESNode'> at 0x7485954\tFather ⊔ Grandfather\tQuality:1.0\tLength:3\t|Indv.|:60,\n",
       " <class 'ontolearn.search.NCESNode'> at 0x7c09948\tFather ⊔ Father\tQuality:1.0\tLength:3\t|Indv.|:60,\n",
       " <class 'ontolearn.search.NCESNode'> at 0x7c27918\tFather ⊔ Father\tQuality:1.0\tLength:3\t|Indv.|:60,\n",
       " <class 'ontolearn.search.NCESNode'> at 0x7c27944\tFather ⊔ Father ⊔ Grandfather\tQuality:1.0\tLength:5\t|Indv.|:60,\n",
       " <class 'ontolearn.search.NCESNode'> at 0x7486610\tFather\tQuality:1.0\tLength:1\t|Indv.|:60]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nodes[:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "certain-winning",
   "metadata": {},
   "source": [
    "#### Prediction with LSTM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "separate-simon",
   "metadata": {},
   "outputs": [],
   "source": [
    "nces.pretrained_model_name = 'LSTM'\n",
    "nces.refresh()\n",
    "nodes = list(nces.fit(pos, neg).best_predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "apart-constitutional",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<class 'ontolearn.search.NCESNode'> at 0x74859e0\tFather\tQuality:1.0\tLength:1\t|Indv.|:60,\n",
       " <class 'ontolearn.search.NCESNode'> at 0x7c096b4\tFather ⊔ Father ⊔ Grandfather\tQuality:1.0\tLength:5\t|Indv.|:60,\n",
       " <class 'ontolearn.search.NCESNode'> at 0x74685d4\tFather\tQuality:1.0\tLength:1\t|Indv.|:60,\n",
       " <class 'ontolearn.search.NCESNode'> at 0x7c22228\tFather ⊔ (Person ⊓ (Grandfather ⊔ (∃ hasParent.(¬Parent))))\tQuality:1.0\tLength:10\t|Indv.|:60,\n",
       " <class 'ontolearn.search.NCESNode'> at 0x7487460\tFather\tQuality:1.0\tLength:1\t|Indv.|:60]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nodes[:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "streaming-transfer",
   "metadata": {},
   "source": [
    "#### Prediction with ensemble SetTransformer+GRU+LSTM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "welsh-sunrise",
   "metadata": {},
   "outputs": [],
   "source": [
    "nces.pretrained_model_name = ['SetTransformer','GRU','LSTM']\n",
    "nces.refresh()\n",
    "nodes = list(nces.fit(pos, neg).best_predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "dramatic-astronomy",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<class 'ontolearn.search.NCESNode'> at 0x7485980\tFather ⊔ Father\tQuality:1.0\tLength:3\t|Indv.|:60,\n",
       " <class 'ontolearn.search.NCESNode'> at 0x7486500\tFather\tQuality:1.0\tLength:1\t|Indv.|:60,\n",
       " <class 'ontolearn.search.NCESNode'> at 0x7472594\tFather\tQuality:1.0\tLength:1\t|Indv.|:60,\n",
       " <class 'ontolearn.search.NCESNode'> at 0x7486460\tFather\tQuality:1.0\tLength:1\t|Indv.|:60,\n",
       " <class 'ontolearn.search.NCESNode'> at 0x7486490\tFather ⊔ Grandfather\tQuality:1.0\tLength:3\t|Indv.|:60]"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nodes[:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "involved-syndicate",
   "metadata": {},
   "source": [
    "### Scalability of NCES (solving multiple learning problems in one go!)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "olympic-feeling",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json, time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "uniform-shuttle",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the learning problems. JSON is UTF-8, so the encoding is given explicitly\n",
    "# instead of relying on the platform's default locale encoding.\n",
    "with open(\"../NCESData/family/training_data/Data.json\", encoding=\"utf-8\") as file:\n",
    "    data = list(json.load(file).items())\n",
    "# nces.fit_from_iterable (used below) takes an iterable of tuples:\n",
    "# (key, positive examples, negative examples)\n",
    "LPs = [(key, lp[\"positive examples\"], lp[\"negative examples\"]) for key, lp in data]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "substantial-crossing",
   "metadata": {},
   "outputs": [],
   "source": [
    "## We solve 256 learning problems!\n",
    "lps = LPs[:256]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "bizarre-antibody",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Duration:  2.310948610305786 seconds!\n"
     ]
    }
   ],
   "source": [
    "# Now predict with nces.fit_from_iterable: one call for all learning problems in `lps`.\n",
    "# time.perf_counter() is used because it is the clock intended for measuring elapsed\n",
    "# durations; time.time() is wall-clock time and can jump if the system clock is adjusted.\n",
    "t0 = time.perf_counter()\n",
    "concepts = nces.fit_from_iterable(lps, verbose=False)\n",
    "t1 = time.perf_counter()\n",
    "print(\"Duration: \", t1-t0, \"seconds!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "increased-singapore",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OWLClass(IRI('http://www.benchmark.org/family#','Sister'))\n"
     ]
    }
   ],
   "source": [
    "print(concepts[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "naughty-chicago",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "256\n"
     ]
    }
   ],
   "source": [
    "print(len(concepts))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "inappropriate-separate",
   "metadata": {},
   "source": [
    "### Change pretrained model name, e.g., use ensemble model prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "transsexual-migration",
   "metadata": {},
   "outputs": [],
   "source": [
    "nces.pretrained_model_name = ['SetTransformer', 'GRU']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "amazing-mixer",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['SetTransformer', 'GRU']"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nces.pretrained_model_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "received-hartford",
   "metadata": {},
   "outputs": [],
   "source": [
    "nces.refresh()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "visible-arizona",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Duration:  3.00120210647583 seconds!\n"
     ]
    }
   ],
   "source": [
    "# Time the same batch of learning problems again after the model change.\n",
    "# time.perf_counter() is used because it is the clock intended for measuring elapsed\n",
    "# durations; time.time() is wall-clock time and can jump if the system clock is adjusted.\n",
    "t0 = time.perf_counter()\n",
    "concepts = nces.fit_from_iterable(lps, verbose=False)\n",
    "t1 = time.perf_counter()\n",
    "print(\"Duration: \", t1-t0, \"seconds!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "close-poker",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "OWLObjectAllValuesFrom(property=OWLObjectProperty(IRI('http://www.benchmark.org/family#','married')),filler=OWLObjectComplementOf(OWLClass(IRI('http://www.benchmark.org/family#','PersonWithASibling'))))"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "concepts[-1]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "sought-degree",
   "metadata": {},
   "source": [
    "## Training NCES"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "structured-digit",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import time\n",
    "\n",
    "# Training data. Below we train the synthesizer on the last 200 datapoints only (\"[-200:]\").\n",
    "# To train on the full data (e.g. on GPU), remove the \"[-200:]\" slice.\n",
    "# JSON is UTF-8, so the encoding is given explicitly instead of relying on the platform default.\n",
    "with open(\"../NCESData/family/training_data/Data.json\", encoding=\"utf-8\") as file:\n",
    "    data = json.load(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "lesser-ownership",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ontolearn.concept_learner import NCES"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "golden-tulsa",
   "metadata": {},
   "outputs": [],
   "source": [
    "nces = NCES(knowledge_base_path=\"../NCESData/family/family.owl\", learner_name=\"SetTransformer\",\n",
    "     path_of_embeddings=\"../NCESData/family/embeddings/ConEx_entity_embeddings.csv\", max_length=48, proj_dim=128, rnn_n_layers=2, drop_prob=0.1, num_heads=4, num_seeds=1, num_inds=32,\n",
    "            load_pretrained=False, pretrained_model_name=\"SetTransformer\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "handed-thermal",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "********************Trainable model size********************\n",
      "Synthesizer:  515296\n",
      "********************Trainable model size********************\n",
      "\n",
      "Training on CPU, it may take long...\n",
      "\n",
      "##################################################\n",
      "\n",
      "SetTransformer starts training... \n",
      "\n",
      "################################################## \n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 0.5600, Soft Acc: 73.91%, Hard Acc: 70.67%: 100%|██████████| 20/20 [00:33<00:00,  1.66s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Top performance: loss: 0.5600, soft accuracy: 74.71% ... hard accuracy: 72.45%\n",
      "\n",
      "SetTransformer saved\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "nces.train(list(data.items())[-200:], epochs=20, learning_rate=0.001, save_model=True, storage_path=f\"./NCES-{time.time()}/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "still-twist",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "mediterranean-crawford",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "onto",
   "language": "python",
   "name": "onto"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
