{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "878c94b1",
   "metadata": {},
   "source": [
    "# Backward Compatibility with AmpliGraph 1\n",
    "\n",
    "The main API differences in AmpliGraph 2 concern how models are imported and how their performance is evaluated.\n",
    "Backward compatibility with the AmpliGraph 1 APIs is still provided through the `ampligraph.compat` module, as shown below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "66ed1ede",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Put the directory two levels up (presumably the AmpliGraph checkout root) on the import path.\n",
    "sys.path.append('../..')\n",
    "\n",
    "# TensorFlow runtime settings: they are set here, before TensorFlow is imported below.\n",
    "os.environ.update({\n",
    "    'CUDA_VISIBLE_DEVICES': '0',   # expose only device 0\n",
    "    'TF_ENABLE_ONEDNN_OPTS': '0',  # disable oneDNN custom operations\n",
    "    'TF_CPP_MIN_LOG_LEVEL': '3',   # silence TensorFlow C++ log output\n",
    "})\n",
    "\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "tf.get_logger().setLevel('ERROR')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "815c51e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import ampligraph\n",
    "from ampligraph.datasets import load_wn18rr\n",
    "\n",
    "# Load the WN18RR dataset; later cells access its splits as X['train'], X['valid'] and X['test'].\n",
    "X = load_wn18rr()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3ea220fc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metal device set to: Apple M1 Pro\n",
      "\n",
      "systemMemory: 32.00 GB\n",
      "maxCacheSize: 10.67 GB\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Import the models from ampligraph.compat\n",
    "# AmpliGraph 2 APIs support TransE, DistMult, ComplEx and HolE\n",
    "\n",
    "from ampligraph.compat import DistMult\n",
    "\n",
    "# k=350 is the embedding size\n",
    "model = DistMult(batches_count=10, seed=0, epochs=500, k=350, eta=10,\n",
    "                 # Use adam optimizer with learning rate 1e-3\n",
    "                 optimizer='adam', optimizer_params={'lr': 1e-3},\n",
    "                 # Use multiclass_nll loss\n",
    "                 loss='multiclass_nll', loss_params={},\n",
    "                 # Use L3 regularizer (LP with p=3) with regularizer weight 1e-3\n",
    "                 regularizer='LP', regularizer_params={'p': 3, 'lambda': 1e-3},\n",
    "                 # Enable stdout messages (set to False to silence them)\n",
    "                 verbose=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3dd122b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Known true triples (train + the same validation subset used below + test),\n",
    "# used to filter out positives when ranking corruptions during validation.\n",
    "# NOTE: `filter` shadows the Python builtin; the name is kept because the fit cell refers to it.\n",
    "filter = np.concatenate([X['train'], X['valid'][::10], X['test']])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "49ad1810",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/500\n",
      "11/11 [==============================] - 4s 324ms/step - loss: 10411.0664\n",
      "Epoch 2/500\n",
      "11/11 [==============================] - 2s 200ms/step - loss: 10409.5049\n",
      "Epoch 3/500\n",
      "11/11 [==============================] - 2s 198ms/step - loss: 10407.1797\n",
      "Epoch 4/500\n",
      "11/11 [==============================] - 2s 193ms/step - loss: 10403.2314\n",
      "Epoch 5/500\n",
      "11/11 [==============================] - 2s 197ms/step - loss: 10396.5605\n",
      "Epoch 6/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 10385.7422\n",
      "Epoch 7/500\n",
      "11/11 [==============================] - 2s 198ms/step - loss: 10369.0137\n",
      "Epoch 8/500\n",
      "11/11 [==============================] - 2s 198ms/step - loss: 10344.3242\n",
      "Epoch 9/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 10309.3848\n",
      "Epoch 10/500\n",
      "11/11 [==============================] - 2s 194ms/step - loss: 10261.7119\n",
      "Epoch 11/500\n",
      "11/11 [==============================] - 2s 197ms/step - loss: 10198.7959\n",
      "Epoch 12/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 10118.0898\n",
      "Epoch 13/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 10017.2520\n",
      "Epoch 14/500\n",
      "11/11 [==============================] - 2s 197ms/step - loss: 9894.2715\n",
      "Epoch 15/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 9747.7959\n",
      "Epoch 16/500\n",
      "11/11 [==============================] - 2s 201ms/step - loss: 9577.3604\n",
      "Epoch 17/500\n",
      "11/11 [==============================] - 2s 194ms/step - loss: 9383.7783\n",
      "Epoch 18/500\n",
      "11/11 [==============================] - 2s 198ms/step - loss: 9169.2197\n",
      "Epoch 19/500\n",
      "11/11 [==============================] - 2s 199ms/step - loss: 8937.3066\n",
      "Epoch 20/500\n",
      "11/11 [==============================] - ETA: 0s - loss: 8692.6836\n",
      "73 triples containing invalid keys skipped!\n",
      "\n",
      "749 triples containing invalid keys skipped!\n",
      "4/4 [==============================] - 3s 745ms/step\n",
      "11/11 [==============================] - 5s 469ms/step - loss: 8692.6836 - val_mrr: 0.2393 - val_mr: 10607.1667 - val_hits@1: 0.1976 - val_hits@10: 0.3048 - val_hits@100: 0.3405\n",
      "Epoch 21/500\n",
      "11/11 [==============================] - 2s 200ms/step - loss: 8440.2305\n",
      "Epoch 22/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 8184.7773\n",
      "Epoch 23/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 7930.2988\n",
      "Epoch 24/500\n",
      "11/11 [==============================] - 2s 198ms/step - loss: 7680.1445\n",
      "Epoch 25/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 7436.6362\n",
      "Epoch 26/500\n",
      "11/11 [==============================] - 2s 199ms/step - loss: 7201.5059\n",
      "Epoch 27/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 6975.7954\n",
      "Epoch 28/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 6760.0210\n",
      "Epoch 29/500\n",
      "11/11 [==============================] - 2s 194ms/step - loss: 6554.4170\n",
      "Epoch 30/500\n",
      "11/11 [==============================] - 2s 197ms/step - loss: 6358.8291\n",
      "Epoch 31/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 6172.9624\n",
      "Epoch 32/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 5996.5449\n",
      "Epoch 33/500\n",
      "11/11 [==============================] - 2s 199ms/step - loss: 5829.0117\n",
      "Epoch 34/500\n",
      "11/11 [==============================] - 2s 197ms/step - loss: 5669.9492\n",
      "Epoch 35/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 5518.8384\n",
      "Epoch 36/500\n",
      "11/11 [==============================] - 2s 197ms/step - loss: 5375.2651\n",
      "Epoch 37/500\n",
      "11/11 [==============================] - 2s 198ms/step - loss: 5238.7280\n",
      "Epoch 38/500\n",
      "11/11 [==============================] - 2s 201ms/step - loss: 5108.7544\n",
      "Epoch 39/500\n",
      "11/11 [==============================] - 2s 201ms/step - loss: 4984.9268\n",
      "Epoch 40/500\n",
      "11/11 [==============================] - ETA: 0s - loss: 4866.8066\n",
      "73 triples containing invalid keys skipped!\n",
      "\n",
      "749 triples containing invalid keys skipped!\n",
      "4/4 [==============================] - 2s 601ms/step\n",
      "11/11 [==============================] - 5s 425ms/step - loss: 4866.8066 - val_mrr: 0.2444 - val_mr: 10529.9667 - val_hits@1: 0.2024 - val_hits@10: 0.2976 - val_hits@100: 0.3476\n",
      "Epoch 41/500\n",
      "11/11 [==============================] - 2s 203ms/step - loss: 4754.1548\n",
      "Epoch 42/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 4646.5786\n",
      "Epoch 43/500\n",
      "11/11 [==============================] - 2s 197ms/step - loss: 4543.6880\n",
      "Epoch 44/500\n",
      "11/11 [==============================] - 2s 197ms/step - loss: 4445.2539\n",
      "Epoch 45/500\n",
      "11/11 [==============================] - 2s 198ms/step - loss: 4351.0322\n",
      "Epoch 46/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 4260.7495\n",
      "Epoch 47/500\n",
      "11/11 [==============================] - 2s 198ms/step - loss: 4174.1572\n",
      "Epoch 48/500\n",
      "11/11 [==============================] - 2s 198ms/step - loss: 4091.0640\n",
      "Epoch 49/500\n",
      "11/11 [==============================] - 2s 197ms/step - loss: 4011.1919\n",
      "Epoch 50/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 3934.4226\n",
      "Epoch 51/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 3860.6382\n",
      "Epoch 52/500\n",
      "11/11 [==============================] - 2s 190ms/step - loss: 3789.5923\n",
      "Epoch 53/500\n",
      "11/11 [==============================] - 2s 194ms/step - loss: 3721.1589\n",
      "Epoch 54/500\n",
      "11/11 [==============================] - 2s 198ms/step - loss: 3655.2007\n",
      "Epoch 55/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 3591.5776\n",
      "Epoch 56/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 3530.1736\n",
      "Epoch 57/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 3470.8843\n",
      "Epoch 58/500\n",
      "11/11 [==============================] - 2s 201ms/step - loss: 3413.6216\n",
      "Epoch 59/500\n",
      "11/11 [==============================] - 2s 206ms/step - loss: 3358.2769\n",
      "Epoch 60/500\n",
      "11/11 [==============================] - ETA: 0s - loss: 3304.7117\n",
      "73 triples containing invalid keys skipped!\n",
      "\n",
      "749 triples containing invalid keys skipped!\n",
      "4/4 [==============================] - 2s 624ms/step\n",
      "11/11 [==============================] - 5s 428ms/step - loss: 3304.7117 - val_mrr: 0.2445 - val_mr: 10503.5452 - val_hits@1: 0.2024 - val_hits@10: 0.3024 - val_hits@100: 0.3524\n",
      "Epoch 61/500\n",
      "11/11 [==============================] - 2s 206ms/step - loss: 3252.8743\n",
      "Epoch 62/500\n",
      "11/11 [==============================] - 2s 198ms/step - loss: 3202.7122\n",
      "Epoch 63/500\n",
      "11/11 [==============================] - 2s 198ms/step - loss: 3154.0876\n",
      "Epoch 64/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 3106.9473\n",
      "Epoch 65/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 3061.2515\n",
      "Epoch 66/500\n",
      "11/11 [==============================] - 2s 197ms/step - loss: 3016.9248\n",
      "Epoch 67/500\n",
      "11/11 [==============================] - 2s 197ms/step - loss: 2973.8972\n",
      "Epoch 68/500\n",
      "11/11 [==============================] - 2s 197ms/step - loss: 2932.1084\n",
      "Epoch 69/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 2891.5144\n",
      "Epoch 70/500\n",
      "11/11 [==============================] - 2s 197ms/step - loss: 2852.1052\n",
      "Epoch 71/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 2813.7700\n",
      "Epoch 72/500\n",
      "11/11 [==============================] - 2s 197ms/step - loss: 2776.4924\n",
      "Epoch 73/500\n",
      "11/11 [==============================] - 2s 194ms/step - loss: 2740.1943\n",
      "Epoch 74/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 2704.9050\n",
      "Epoch 75/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 2670.5303\n",
      "Epoch 76/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 2637.0491\n",
      "Epoch 77/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 2604.4229\n",
      "Epoch 78/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 2572.6294\n",
      "Epoch 79/500\n",
      "11/11 [==============================] - 2s 194ms/step - loss: 2541.6455\n",
      "Epoch 80/500\n",
      "11/11 [==============================] - ETA: 0s - loss: 2511.4175\n",
      "73 triples containing invalid keys skipped!\n",
      "\n",
      "749 triples containing invalid keys skipped!\n",
      "4/4 [==============================] - 2s 559ms/step\n",
      "11/11 [==============================] - 4s 405ms/step - loss: 2511.4175 - val_mrr: 0.2442 - val_mr: 10496.0857 - val_hits@1: 0.2024 - val_hits@10: 0.3024 - val_hits@100: 0.3548\n",
      "Epoch 81/500\n",
      "11/11 [==============================] - 2s 202ms/step - loss: 2481.9272\n",
      "Epoch 82/500\n",
      "11/11 [==============================] - 2s 198ms/step - loss: 2453.1702\n",
      "Epoch 83/500\n",
      "11/11 [==============================] - 2s 197ms/step - loss: 2425.1064\n",
      "Epoch 84/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 2397.6875\n",
      "Epoch 85/500\n",
      "11/11 [==============================] - 2s 197ms/step - loss: 2370.9175\n",
      "Epoch 86/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 2344.7327\n",
      "Epoch 87/500\n",
      "11/11 [==============================] - 2s 194ms/step - loss: 2319.1604\n",
      "Epoch 88/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 2294.1565\n",
      "Epoch 89/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 2269.7158\n",
      "Epoch 90/500\n",
      "11/11 [==============================] - 2s 204ms/step - loss: 2245.8376\n",
      "Epoch 91/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 2222.4568\n",
      "Epoch 92/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 2199.5872\n",
      "Epoch 93/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 2177.2122\n",
      "Epoch 94/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 2155.3091\n",
      "Epoch 95/500\n",
      "11/11 [==============================] - 2s 194ms/step - loss: 2133.8701\n",
      "Epoch 96/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 2112.8484\n",
      "Epoch 97/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 2092.2544\n",
      "Epoch 98/500\n",
      "11/11 [==============================] - 2s 194ms/step - loss: 2072.0894\n",
      "Epoch 99/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 2052.3276\n",
      "Epoch 100/500\n",
      "11/11 [==============================] - ETA: 0s - loss: 2032.9592\n",
      "73 triples containing invalid keys skipped!\n",
      "\n",
      "749 triples containing invalid keys skipped!\n",
      "4/4 [==============================] - 2s 599ms/step\n",
      "11/11 [==============================] - 5s 415ms/step - loss: 2032.9592 - val_mrr: 0.2443 - val_mr: 10477.2429 - val_hits@1: 0.2024 - val_hits@10: 0.3071 - val_hits@100: 0.3524\n",
      "Epoch 101/500\n",
      "11/11 [==============================] - 2s 200ms/step - loss: 2013.9766\n",
      "Epoch 102/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 1995.3853\n",
      "Epoch 103/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 1977.1498\n",
      "Epoch 104/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 1959.2532\n",
      "Epoch 105/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 1941.6775\n",
      "Epoch 106/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 1924.4440\n",
      "Epoch 107/500\n",
      "11/11 [==============================] - 2s 197ms/step - loss: 1907.5493\n",
      "Epoch 108/500\n",
      "11/11 [==============================] - 2s 200ms/step - loss: 1890.9412\n",
      "Epoch 109/500\n",
      "11/11 [==============================] - 2s 198ms/step - loss: 1874.6334\n",
      "Epoch 110/500\n",
      "11/11 [==============================] - 2s 194ms/step - loss: 1858.6289\n",
      "Epoch 111/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 1842.9141\n",
      "Epoch 112/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 1827.4719\n",
      "Epoch 113/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 1812.2948\n",
      "Epoch 114/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 1797.3940\n",
      "Epoch 115/500\n",
      "11/11 [==============================] - 2s 197ms/step - loss: 1782.7501\n",
      "Epoch 116/500\n",
      "11/11 [==============================] - 2s 197ms/step - loss: 1768.3480\n",
      "Epoch 117/500\n",
      "11/11 [==============================] - 2s 194ms/step - loss: 1754.1951\n",
      "Epoch 118/500\n",
      "11/11 [==============================] - 2s 194ms/step - loss: 1740.2925\n",
      "Epoch 119/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 1726.6057\n",
      "Epoch 120/500\n",
      "11/11 [==============================] - ETA: 0s - loss: 1713.1615\n",
      "73 triples containing invalid keys skipped!\n",
      "\n",
      "749 triples containing invalid keys skipped!\n",
      "4/4 [==============================] - 3s 633ms/step\n",
      "11/11 [==============================] - 5s 430ms/step - loss: 1713.1615 - val_mrr: 0.2472 - val_mr: 10462.5929 - val_hits@1: 0.2071 - val_hits@10: 0.3048 - val_hits@100: 0.3500\n",
      "Epoch 121/500\n",
      "11/11 [==============================] - 2s 200ms/step - loss: 1699.9281\n",
      "Epoch 122/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 1686.9122\n",
      "Epoch 123/500\n",
      "11/11 [==============================] - 2s 201ms/step - loss: 1674.1143\n",
      "Epoch 124/500\n",
      "11/11 [==============================] - 2s 181ms/step - loss: 1661.5205\n",
      "Epoch 125/500\n",
      "11/11 [==============================] - 2s 198ms/step - loss: 1649.1191\n",
      "Epoch 126/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 1636.9050\n",
      "Epoch 127/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 1624.8959\n",
      "Epoch 128/500\n",
      "11/11 [==============================] - 2s 197ms/step - loss: 1613.0625\n",
      "Epoch 129/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 1601.4297\n",
      "Epoch 130/500\n",
      "11/11 [==============================] - 2s 197ms/step - loss: 1589.9502\n",
      "Epoch 131/500\n",
      "11/11 [==============================] - 2s 194ms/step - loss: 1578.6691\n",
      "Epoch 132/500\n",
      "11/11 [==============================] - 2s 194ms/step - loss: 1567.5460\n",
      "Epoch 133/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 1556.6003\n",
      "Epoch 134/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 1545.7859\n",
      "Epoch 135/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 1535.1605\n",
      "Epoch 136/500\n",
      "11/11 [==============================] - 2s 197ms/step - loss: 1524.6825\n",
      "Epoch 137/500\n",
      "11/11 [==============================] - 2s 194ms/step - loss: 1514.3524\n",
      "Epoch 138/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 1504.1636\n",
      "Epoch 139/500\n",
      "11/11 [==============================] - 2s 194ms/step - loss: 1494.1503\n",
      "Epoch 140/500\n",
      "11/11 [==============================] - ETA: 0s - loss: 1484.2710\n",
      "73 triples containing invalid keys skipped!\n",
      "\n",
      "749 triples containing invalid keys skipped!\n",
      "4/4 [==============================] - 2s 554ms/step\n",
      "11/11 [==============================] - 4s 400ms/step - loss: 1484.2710 - val_mrr: 0.2473 - val_mr: 10447.3381 - val_hits@1: 0.2071 - val_hits@10: 0.3048 - val_hits@100: 0.3548\n",
      "Epoch 141/500\n",
      "11/11 [==============================] - 2s 199ms/step - loss: 1474.5320\n",
      "Epoch 142/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 1464.9152\n",
      "Epoch 143/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 1455.4531\n",
      "Epoch 144/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 1446.0973\n",
      "Epoch 145/500\n",
      "11/11 [==============================] - 2s 193ms/step - loss: 1436.8644\n",
      "Epoch 146/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 1427.7543\n",
      "Epoch 147/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 1418.7764\n",
      "Epoch 148/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 1409.9091\n",
      "Epoch 149/500\n",
      "11/11 [==============================] - 2s 194ms/step - loss: 1401.1660\n",
      "Epoch 150/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 1392.5348\n",
      "Epoch 151/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 1384.0427\n",
      "Epoch 152/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 1375.6371\n",
      "Epoch 153/500\n",
      "11/11 [==============================] - 2s 194ms/step - loss: 1367.3436\n",
      "Epoch 154/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 1359.1583\n",
      "Epoch 155/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 1351.0839\n",
      "Epoch 156/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 1343.1071\n",
      "Epoch 157/500\n",
      "11/11 [==============================] - 2s 194ms/step - loss: 1335.2385\n",
      "Epoch 158/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 1327.4606\n",
      "Epoch 159/500\n",
      "11/11 [==============================] - 2s 194ms/step - loss: 1319.7870\n",
      "Epoch 160/500\n",
      "11/11 [==============================] - ETA: 0s - loss: 1312.1954\n",
      "73 triples containing invalid keys skipped!\n",
      "\n",
      "749 triples containing invalid keys skipped!\n",
      "4/4 [==============================] - 2s 560ms/step\n",
      "11/11 [==============================] - 4s 402ms/step - loss: 1312.1954 - val_mrr: 0.2479 - val_mr: 10416.3095 - val_hits@1: 0.2071 - val_hits@10: 0.3071 - val_hits@100: 0.3548\n",
      "Epoch 161/500\n",
      "11/11 [==============================] - 2s 197ms/step - loss: 1304.7108\n",
      "Epoch 162/500\n",
      "11/11 [==============================] - 2s 197ms/step - loss: 1297.3075\n",
      "Epoch 163/500\n",
      "11/11 [==============================] - 2s 197ms/step - loss: 1290.0109\n",
      "Epoch 164/500\n",
      "11/11 [==============================] - 2s 194ms/step - loss: 1282.7915\n",
      "Epoch 165/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 1275.6677\n",
      "Epoch 166/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 1268.6223\n",
      "Epoch 167/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 1261.6553\n",
      "Epoch 168/500\n",
      "11/11 [==============================] - 2s 194ms/step - loss: 1254.7828\n",
      "Epoch 169/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 1247.9882\n",
      "Epoch 170/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 1241.2743\n",
      "Epoch 171/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 1234.6219\n",
      "Epoch 172/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 1228.0609\n",
      "Epoch 173/500\n",
      "11/11 [==============================] - 2s 194ms/step - loss: 1221.5646\n",
      "Epoch 174/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 1215.1548\n",
      "Epoch 175/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 1208.8032\n",
      "Epoch 176/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 1202.5280\n",
      "Epoch 177/500\n",
      "11/11 [==============================] - 2s 194ms/step - loss: 1196.3253\n",
      "Epoch 178/500\n",
      "11/11 [==============================] - 2s 194ms/step - loss: 1190.1824\n",
      "Epoch 179/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 1184.1156\n",
      "Epoch 180/500\n",
      "11/11 [==============================] - ETA: 0s - loss: 1178.1238\n",
      "73 triples containing invalid keys skipped!\n",
      "\n",
      "749 triples containing invalid keys skipped!\n",
      "4/4 [==============================] - 2s 564ms/step\n",
      "11/11 [==============================] - 4s 403ms/step - loss: 1178.1238 - val_mrr: 0.2545 - val_mr: 10405.9452 - val_hits@1: 0.2190 - val_hits@10: 0.3071 - val_hits@100: 0.3548\n",
      "Epoch 181/500\n",
      "11/11 [==============================] - 2s 199ms/step - loss: 1172.1846\n",
      "Epoch 182/500\n",
      "11/11 [==============================] - 2s 197ms/step - loss: 1166.3132\n",
      "Epoch 183/500\n",
      "11/11 [==============================] - 2s 197ms/step - loss: 1160.5056\n",
      "Epoch 184/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 1154.7552\n",
      "Epoch 185/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 1149.0632\n",
      "Epoch 186/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 1143.4471\n",
      "Epoch 187/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 1137.8889\n",
      "Epoch 188/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 1132.3876\n",
      "Epoch 189/500\n",
      "11/11 [==============================] - 2s 194ms/step - loss: 1126.9469\n",
      "Epoch 190/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 1121.5671\n",
      "Epoch 191/500\n",
      "11/11 [==============================] - 2s 193ms/step - loss: 1116.2313\n",
      "Epoch 192/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 1110.9479\n",
      "Epoch 193/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 1105.7261\n",
      "Epoch 194/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 1100.5529\n",
      "Epoch 195/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 1095.4424\n",
      "Epoch 196/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 1090.3757\n",
      "Epoch 197/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 1085.3503\n",
      "Epoch 198/500\n",
      "11/11 [==============================] - 2s 195ms/step - loss: 1080.3785\n",
      "Epoch 199/500\n",
      "11/11 [==============================] - 2s 196ms/step - loss: 1075.4651\n",
      "Epoch 200/500\n",
      "11/11 [==============================] - ETA: 0s - loss: 1070.6000\n",
      "73 triples containing invalid keys skipped!\n",
      "\n",
      "749 triples containing invalid keys skipped!\n",
      "4/4 [==============================] - 2s 569ms/step\n",
      "11/11 [==============================] - 4s 407ms/step - loss: 1070.6000 - val_mrr: 0.2550 - val_mr: 10358.2190 - val_hits@1: 0.2190 - val_hits@10: 0.3071 - val_hits@100: 0.3571\n",
      "Restoring model weights from the end of the best epoch: 100.\n",
      "Epoch 200: early stopping\n"
     ]
    }
   ],
   "source": [
    "# Fit on every second training triple; the validation subset is used only for early stopping\n",
    "model.fit(X['train'][::2],\n",
    "          early_stopping=True,\n",
    "          early_stopping_params={\n",
    "              'x_valid': X['valid'][::10],   # validation subset (every 10th triple)\n",
    "              'criteria': 'hits@10',         # metric monitored for early stopping\n",
    "              'burn_in': 20,                 # early stopping kicks in after 20 epochs\n",
    "              'check_interval': 20,          # validates every 20th epoch\n",
    "              'stop_interval': 5,            # stops after 5 successive checks without improvement\n",
    "              'x_filter': filter,            # known positives filtered out of the ranking\n",
    "              'corruption_entities': 'all',  # corrupt using all entities\n",
    "              'corrupt_side': 's'            # corrupt only the subject\n",
    "          })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0da6a76f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([['06845599', '_member_of_domain_usage', '03754979'],\n",
       "       ['00789448', '_verb_group', '01062739']], dtype=object)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Take the test split and preview its first two triples\n",
    "X_test = X['test']\n",
    "X_test[:2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "38786d62",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "1 triples containing invalid keys skipped!\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([-0.3750222], dtype=float32)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scores assigned to two unseen test triples. Triples with keys the model does not know\n",
    "# (presumably entities absent from the subsampled training set) are skipped, so fewer\n",
    "# scores than input triples may be returned.\n",
    "model.predict(X_test[:2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "39ec7559",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Embedding size:  350\n",
      "\n",
      " Embedding vectors: \n",
      "[[ 0.11967714  0.14262393 -0.15412526  0.14064218 -0.03678143  0.17855963\n",
      "   0.13154885  0.1871646   0.15934017  0.13739511  0.19862601  0.13786295\n",
      "  -0.19559897 -0.19016445  0.16443959 -0.19838691  0.17837389  0.02906989\n",
      "  -0.19489937 -0.14237712 -0.16674249  0.17720278  0.06629281  0.10147594\n",
      "  -0.07981575 -0.15890078 -0.1647761  -0.11785773 -0.20817536  0.15641469\n",
      "   0.20843565  0.15639313  0.16039309 -0.15879703 -0.15178938 -0.16976918\n",
      "   0.17478164 -0.19067667 -0.12253983 -0.19105424 -0.14591898 -0.20406112\n",
      "  -0.17460953 -0.19648121 -0.09293007 -0.19525793  0.18996416  0.17143488\n",
      "   0.1854977   0.20067944 -0.18529464  0.1355615   0.06794422 -0.19307975\n",
      "  -0.18535511  0.1253285  -0.18829922  0.1228343   0.13849634 -0.18207617\n",
      "   0.19859134  0.16638105  0.20536643  0.18736392 -0.18714447  0.15228283\n",
      "   0.1753144   0.17433217 -0.17006505 -0.1720568   0.19801918  0.18579687\n",
      "  -0.18681724 -0.1791706   0.1113568   0.14846447 -0.20054623  0.20372821\n",
      "   0.18833119 -0.18584427 -0.1918742  -0.13139698  0.18179244  0.1737768\n",
      "   0.11331942 -0.20072089  0.18395062 -0.20536603  0.02118571  0.05126335\n",
      "  -0.19640496 -0.17096068 -0.164725   -0.17271374  0.17344357 -0.19628163\n",
      "  -0.1898071   0.1279031   0.19385721 -0.07791467 -0.20854901 -0.13680887\n",
      "   0.18527906 -0.152901    0.14481007  0.1523258   0.2111358  -0.18797067\n",
      "   0.18849637 -0.2073762   0.05238935 -0.2102062   0.04517871 -0.16955946\n",
      "  -0.17572236 -0.18688051  0.16608055 -0.05527462 -0.09782106  0.2001707\n",
      "   0.19277512 -0.19790159  0.15724504  0.1092148   0.18558525 -0.19788602\n",
      "  -0.14528981 -0.1950381   0.17656523 -0.19231224 -0.19559628  0.19494689\n",
      "   0.07714475  0.20289116  0.14949007  0.17374896  0.10056299 -0.17525047\n",
      "  -0.17265144  0.19744131 -0.19403248  0.20430258 -0.00727208 -0.2076957\n",
      "   0.18211046 -0.07284966  0.19741394  0.02125066  0.21229678  0.20274714\n",
      "   0.18729033 -0.12883253 -0.18245655  0.20008726 -0.16744806 -0.00822476\n",
      "  -0.1519935  -0.20154041 -0.15378582  0.1004666  -0.19832684 -0.00152403\n",
      "   0.18024774  0.09356381  0.1386365  -0.10284465 -0.20688115 -0.16602436\n",
      "  -0.19199525  0.18236728  0.18384357 -0.20078585  0.20912364  0.17682433\n",
      "   0.04122413  0.16256247  0.17337331 -0.2031825   0.17895767  0.1787043\n",
      "   0.08947121 -0.18562213  0.19931571 -0.1250084  -0.1486245   0.18204062\n",
      "  -0.19667365 -0.03249188 -0.07440095  0.1837779  -0.09474514 -0.18782851\n",
      "   0.14209504 -0.18933278  0.1042712   0.15894704  0.21559557  0.17742933\n",
      "  -0.08005217  0.14980444 -0.19101751 -0.20679839 -0.18883981  0.16262226\n",
      "  -0.21743062  0.14735585  0.19151884 -0.19164796  0.1560455   0.16957785\n",
      "   0.18727481  0.17819144  0.16179329 -0.19080769  0.16973712 -0.1925212\n",
      "   0.16374578 -0.1808627   0.16363224  0.18052916 -0.14257996  0.1618393\n",
      "   0.20270735  0.21307628  0.08400014  0.17235395 -0.17586367  0.20447017\n",
      "  -0.18412977  0.14006896 -0.19693156  0.03113016 -0.14300886  0.19441696\n",
      "   0.16641963 -0.08047516 -0.16613105  0.14991985  0.17333452  0.18397705\n",
      "  -0.18745625 -0.17148913 -0.17563117  0.05985169  0.20005651 -0.09213397\n",
      "   0.20170435 -0.09108246 -0.09319798  0.18700038 -0.16210344  0.16957417\n",
      "   0.18647636  0.12074988 -0.14370346  0.18065935 -0.1705291  -0.1358455\n",
      "   0.1341864  -0.15401235  0.11808699 -0.19629124  0.1530521   0.11844989\n",
      "   0.17775625 -0.09101854 -0.16421929 -0.10792024 -0.15569463 -0.1919804\n",
      "  -0.14689288 -0.19177453  0.16614464 -0.180337    0.15706272  0.16164859\n",
      "  -0.18557972  0.14301576 -0.19341838 -0.17645848  0.16774076 -0.14670373\n",
      "   0.08545575 -0.18902545  0.19527805 -0.17480174  0.20061219  0.20175354\n",
      "  -0.18648772 -0.13474452  0.18921259 -0.203916    0.16157116 -0.17338535\n",
      "   0.19284092  0.1854795   0.19398591 -0.19422098  0.16450556  0.18990421\n",
      "  -0.03154118  0.18817769 -0.18382524  0.03499205 -0.18768372  0.1996769\n",
      "  -0.19130495  0.02770057  0.19007386  0.14654544  0.18743734  0.20540932\n",
      "   0.16295843 -0.14572024  0.16793911 -0.20129976 -0.1366053   0.17850834\n",
      "   0.22011986  0.09955541 -0.10942172 -0.19518769  0.18671449  0.19170643\n",
      "  -0.04972033 -0.19841279  0.19350344  0.12070791  0.16050468 -0.18220672\n",
      "  -0.17314675  0.18021546  0.05014515  0.19465469 -0.10073806  0.18060473\n",
      "   0.02544012 -0.16447988 -0.19230923  0.20384873  0.09585697 -0.15488586\n",
      "   0.19120377 -0.17508115 -0.2021174   0.15325852 -0.20674117  0.15579496\n",
      "   0.15116681 -0.1127306 ]\n",
      " [ 0.26359528  0.1249386   0.28858092 -0.17797872  0.20497012  0.22269176\n",
      "   0.27832016  0.2757501   0.27046713  0.25969136  0.08387282 -0.19584669\n",
      "  -0.21931227  0.32106528  0.31492916  0.25374964 -0.22897394  0.3181942\n",
      "   0.25454542 -0.20019929 -0.23475999  0.06679947  0.2183339  -0.2894066\n",
      "   0.26247907 -0.11179897 -0.32501188  0.27682978 -0.32035124 -0.11812925\n",
      "  -0.3331316  -0.28532103 -0.31380764 -0.21089476 -0.24743614  0.3035787\n",
      "   0.30109754  0.21315086 -0.3102582   0.3206451   0.24750993 -0.23013237\n",
      "   0.13498528  0.2604506   0.2891886   0.06165736 -0.3223656   0.32921845\n",
      "  -0.2314669  -0.27268854 -0.14173894  0.24163432  0.24833779 -0.077958\n",
      "  -0.28186816  0.31856927 -0.24632494  0.29161862 -0.27775103 -0.27559993\n",
      "  -0.13994707  0.14485045 -0.31444263  0.22851521 -0.31769055  0.1438567\n",
      "  -0.23589028 -0.19769956 -0.26819453  0.29963103 -0.2857223   0.08439551\n",
      "  -0.27716926 -0.11211722  0.00091646 -0.12490457 -0.19723907  0.24744073\n",
      "   0.26289397  0.3004941   0.20986204 -0.29589865 -0.226554   -0.2014509\n",
      "   0.23785566  0.09087156 -0.14181933  0.08421248 -0.25986207  0.31696242\n",
      "  -0.3071655   0.22914991 -0.2737352   0.21047549  0.33670494  0.00460528\n",
      "  -0.22103903 -0.28319368  0.26737532 -0.30088496  0.26658863 -0.0166248\n",
      "   0.31870374  0.25784487 -0.23977266 -0.30406773 -0.19809067 -0.31597644\n",
      "  -0.24394882  0.27527156 -0.23096207  0.11502818  0.23679689 -0.3029175\n",
      "  -0.11502473 -0.15216485  0.28205407  0.26932096 -0.22406279 -0.21522948\n",
      "   0.23483907  0.31125176  0.25894076  0.19535172 -0.28060195  0.31772178\n",
      "  -0.23603833  0.27163854  0.3017834   0.19314533  0.14095597 -0.3237321\n",
      "   0.23730579  0.2491708  -0.31911972  0.04240919  0.3348101   0.21677844\n",
      "  -0.3220805   0.33414608 -0.2790371   0.27658615 -0.20900933 -0.3314625\n",
      "   0.26359507 -0.21445355  0.28395426  0.03873984 -0.03569955  0.18599904\n",
      "  -0.27797407  0.13169739 -0.12203855  0.29744348 -0.26150778 -0.14516914\n",
      "  -0.23998928  0.20108832  0.19799729 -0.30347157  0.04786517 -0.15913786\n",
      "  -0.24268943 -0.29918385  0.26591703 -0.17801045  0.187522   -0.2562872\n",
      "   0.32818586  0.30207413 -0.24459507  0.2326856   0.26462683 -0.07968622\n",
      "   0.23250486 -0.0713632  -0.32773966  0.29006663  0.2967593  -0.27702066\n",
      "   0.20602098 -0.27834156  0.02536987  0.24047309 -0.05963988  0.19400993\n",
      "  -0.26759937  0.20425174  0.26406324 -0.27043748 -0.2725764   0.32584906\n",
      "  -0.27253738  0.25177482 -0.14214556  0.31422865 -0.15071005  0.22032034\n",
      "  -0.27366078  0.31184897 -0.16175425 -0.26006597  0.29910457 -0.19137892\n",
      "   0.19285683  0.3061134  -0.308348   -0.31757486  0.23900005  0.19998355\n",
      "   0.03313541  0.08632842  0.26077834 -0.16638811  0.29526487 -0.2576779\n",
      "  -0.13505988  0.3116214  -0.33281446 -0.27673033 -0.3084193  -0.1544944\n",
      "   0.26886797 -0.29679534  0.2288939  -0.02863639 -0.10803284  0.24452958\n",
      "  -0.30263546 -0.29816934  0.23169205 -0.25762266 -0.26693684  0.33462974\n",
      "  -0.3045454   0.1804682   0.22978568  0.28253326  0.09517216 -0.1881876\n",
      "  -0.23260184 -0.27417582 -0.03281635  0.11797003  0.3031759   0.13671435\n",
      "  -0.1873772   0.17103979 -0.2287404  -0.13884926 -0.31018433  0.25720453\n",
      "   0.25405997  0.02159105 -0.2791893  -0.2575044  -0.2637025   0.27467427\n",
      "   0.20162524 -0.3148877  -0.2593904   0.04984489  0.2575763   0.28519237\n",
      "  -0.1660749   0.29068708  0.3063611  -0.22861944 -0.31932285 -0.17554168\n",
      "   0.22421476 -0.280646   -0.17823899  0.07885154  0.0124142  -0.02381873\n",
      "  -0.17261492 -0.01282863 -0.29386663  0.2191935  -0.24402273  0.27857628\n",
      "  -0.16756125 -0.24458432  0.32096735 -0.2837933  -0.31599298 -0.3198233\n",
      "   0.2723685  -0.0761544   0.31904408 -0.29810312 -0.10409454 -0.30488965\n",
      "  -0.13548215 -0.2572005   0.19196187  0.0912411  -0.16105731  0.3027988\n",
      "   0.13277508  0.29381382 -0.33378816 -0.23794146  0.28933227  0.31829256\n",
      "   0.27744126  0.25253007 -0.3187252  -0.3180473  -0.25687635 -0.01028104\n",
      "   0.11331479  0.19953014  0.31369498 -0.2545048   0.23001711 -0.30784523\n",
      "  -0.03809048  0.27521244  0.3078119   0.2932224   0.23608378  0.30375472\n",
      "  -0.22032784  0.17577317  0.3090387   0.25136158  0.10072201 -0.03726945\n",
      "  -0.19307207 -0.12359352 -0.00647417 -0.2640685  -0.27448252 -0.24473403\n",
      "   0.29297298  0.23404413 -0.29046378 -0.27104375  0.23150001  0.2638384\n",
      "   0.30425945  0.3098355   0.29045057 -0.25269663 -0.17774479 -0.2900341\n",
      "   0.20337999 -0.1987345 ]]\n"
     ]
    }
   ],
   "source": [
    "# Get embedding of entities\n",
    "embed = model.get_embeddings(['11647131','02518161'], embedding_type='entity')\n",
    "print('Embedding size: ', embed.shape[1])\n",
    "# For DistMult the embedding size equals the k passed to the model (350).\n",
    "# For ComplEx it would be double that k, since ComplEx embeddings live in\n",
    "# the space of complex numbers (real and imaginary parts are both stored).\n",
    "print('\\n Embedding vectors: ')\n",
    "print(embed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "729d6de5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(33117, 11)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Get the mappings from entity and relation names to their indices in the embedding matrices\n",
    "ent_to_idx, rel_to_idx = model.get_hyperparameter_dict()\n",
    "len(ent_to_idx), len(rel_to_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "305beebf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "676 triples containing invalid keys skipped!\n",
      "\n",
      "749 triples containing invalid keys skipped!\n",
      "2249/2249 [==============================] - 38s 17ms/step\n",
      "MR: 9791.32028469751\n",
      "MRR: 0.266744577802399\n",
      "hits@1: 0.2257562277580071\n",
      "hits@10: 0.33451957295373663\n"
     ]
    }
   ],
   "source": [
    "# import the evaluate_performance API from the ampligraph.compat module\n",
    "from ampligraph.compat import evaluate_performance\n",
    "ranks = evaluate_performance(X_test, model, filter_triples=filter, corrupt_side='s,o', verbose=True)\n",
    "\n",
    "# import the evaluation metrics\n",
    "from ampligraph.evaluation.metrics import mrr_score, hits_at_n_score, mr_score\n",
    "\n",
    "print('MR:', mr_score(ranks))\n",
    "print('MRR:', mrr_score(ranks))\n",
    "print('hits@1:', hits_at_n_score(ranks, 1))\n",
    "print('hits@10:', hits_at_n_score(ranks, 10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "1b67466c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING - Found untraced functions such as _get_ranks while saving (showing 1 of 1). These functions will not be directly callable after loading.\n"
     ]
    }
   ],
   "source": [
    "from ampligraph.utils import save_model\n",
    "# save the model\n",
    "save_model(model, 'backward_model')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "f78f15cf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved model does not include a db file. Skipping.\n"
     ]
    }
   ],
   "source": [
    "from ampligraph.utils import restore_model\n",
    "\n",
    "# restore saved models or checkpoints\n",
    "res_model = restore_model('backward_model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "7f0f553b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "676 triples containing invalid keys skipped!\n",
      "\n",
      "749 triples containing invalid keys skipped!\n",
      "2249/2249 [==============================] - 40s 18ms/step\n",
      "MR: 9791.32028469751\n",
      "MRR: 0.266744577802399\n",
      "hits@1: 0.2257562277580071\n",
      "hits@10: 0.33451957295373663\n"
     ]
    }
   ],
   "source": [
    "# import the evaluate_performance API from the ampligraph.compat module\n",
    "from ampligraph.compat import evaluate_performance\n",
    "ranks = evaluate_performance(X_test, res_model, filter_triples=filter, corrupt_side='s,o', verbose=True)\n",
    "\n",
    "# import the evaluation metrics\n",
    "from ampligraph.evaluation.metrics import mrr_score, hits_at_n_score, mr_score\n",
    "\n",
    "print('MR:', mr_score(ranks))\n",
    "print('MRR:', mrr_score(ranks))\n",
    "print('hits@1:', hits_at_n_score(ranks, 1))\n",
    "print('hits@10:', hits_at_n_score(ranks, 10))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7fea2d0d",
   "metadata": {},
   "source": [
    "# Discovery\n",
    "The APIs for knowledge discovery can be imported from the `ampligraph.discovery` module.\n",
    "They are designed to be backward compatible with AmpliGraph 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a24648d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ampligraph.discovery import discover_facts\n",
    "\n",
    "# X holds WN18RR (see load_wn18rr at the top of the notebook), so target_rel\n",
    "# must be one of the WN18RR relations, e.g. '_hypernym'. A Freebase relation\n",
    "# such as '/location/country/form_of_government' is not part of this dataset.\n",
    "discover_facts(X['train'][:100],\n",
    "               res_model,\n",
    "               top_n=100,\n",
    "               strategy='entity_frequency',\n",
    "               max_candidates=100,\n",
    "               target_rel='_hypernym',\n",
    "               seed=0)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.6 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "vscode": {
   "interpreter": {
    "hash": "2e69f3670cdad0193847aaa0b77be56c05c951fcbdd384ff882dde0464f4de76"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
