# Clustering and Classification using Knowledge Graph Embeddings

> **_NOTE:_**  **An interactive version of this tutorial is [available on Colab](https://colab.research.google.com/drive/1QUphvcFvNsWyRZM_J5ahsLhEHJY4SjyS).**

In this tutorial we will explore how to use knowledge graph embeddings generated from a graph of international football matches (played since the 19th century) in clustering and classification tasks. Knowledge graph embeddings are typically used for missing link prediction and knowledge discovery, but they can also be used for entity clustering, entity disambiguation, and other downstream tasks. The embeddings are a form of representation learning that allows linear algebra and machine learning to be applied to knowledge graphs, which would otherwise be difficult to do.


We will cover in this tutorial:

1. Creating the knowledge graph (i.e. triples) from a tabular dataset of football matches
2. Training the ComplEx embedding model on those triples
3. Evaluating the quality of the embeddings on a validation set
4. Clustering the embeddings, comparing to the natural clusters formed by the geographical continents
5. Applying the embeddings as features in a classification task, to predict match results
6. Evaluating the predictive model on an out-of-time test set, comparing it to a simple baseline

We will show that the knowledge embedding clusters manage to capture implicit geographical information from the graph, and that the embeddings can be a useful feature source for a downstream machine learning classification task, significantly increasing accuracy over the baseline.



## Requirements

A Python environment with the AmpliGraph library installed. Please follow the [install guide](http://docs.ampligraph.org/en/latest/install.html).

A quick sanity check:


```python
import numpy as np
import pandas as pd
import ampligraph

ampligraph.__version__
```




    '2.0-dev'



## Dataset

We will use the [International football results from 1872 to 2019](https://www.kaggle.com/martj42/international-football-results-from-1872-to-2017) available on Kaggle (public domain). It contains over 40 thousand international football matches. Each row contains the following information:
1. Match date
2. Home team name
3. Away team name
4. Home score (goals including extra time)
5. Away score (goals including extra time)
6. Tournament (whether it was a friendly match or part of a tournament)
7. City where match took place
8. Country where match took place
9. Whether match was on neutral grounds

This dataset comes in a tabular format, therefore we will need to construct the knowledge graph ourselves.


```python
import requests
url = 'https://ampligraph.s3-eu-west-1.amazonaws.com/datasets/football_graph.csv'
open('football_results.csv', 'wb').write(requests.get(url).content)
```




    3033782




```python
df = pd.read_csv("football_results.csv").sort_values("date")
df.isna().sum()
```




    date          0
    home_team     0
    away_team     0
    home_score    2
    away_score    2
    tournament    0
    city          0
    country       0
    neutral       0
    dtype: int64



Dropping matches with unknown score:


```python
df = df.dropna()
```

The training set will cover matches from 1872 through 2013, while the test set will cover 2014 to the present day. Note that a temporal test set makes any machine learning task harder compared to a random shuffle.


```python
df["train"] = df.date < "2014-01-01"
df.train.value_counts()
```




    True     35714
    False     5057
    Name: train, dtype: int64



## Knowledge graph creation
We are going to create a knowledge graph from scratch based on the match information. The idea is that each match is an entity that will be connected to its participating teams, geography, characteristics, and results. 

The objective is to generate a new representation of the dataset where each data point is a triple in the form:

    <subject, predicate, object>
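
For example, `<TeamBrazil, isHomeTeamIn, Match3129>` states that the Brazilian team played at home in the match identified as `Match3129`.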
    
First we need to create the entities (subjects and objects) that will form the graph. We make sure that teams and geographical locations map to distinct entities (e.g. the Brazilian team `TeamBrazil` and the corresponding country `CountryBrazil` will be different entities).


```python
# Entities naming
df["match_id"] = df.index.values.astype(str)
df["match_id"] =  "Match" + df.match_id
df["city_id"] = "City" + df.city.str.title().str.replace(" ", "")
df["country_id"] = "Country" + df.country.str.title().str.replace(" ", "")
df["home_team_id"] = "Team" + df.home_team.str.title().str.replace(" ", "")
df["away_team_id"] = "Team" + df.away_team.str.title().str.replace(" ", "")
df["tournament_id"] = "Tournament" + df.tournament.str.title().str.replace(" ", "")
df["neutral"] = df.neutral.astype(str)
```

Then, we create the actual triples based on the relationships between the entities. We do this only for the matches in the training set (before 2014).


```python
triples = []
for _, row in df[df["train"]].iterrows():
    # Home and away information
    home_team = (row["home_team_id"], "isHomeTeamIn", row["match_id"])
    away_team = (row["away_team_id"], "isAwayTeamIn", row["match_id"])
    
    # Match results
    if row["home_score"] > row["away_score"]:
        score_home = (row["home_team_id"], "winnerOf", row["match_id"])
        score_away = (row["away_team_id"], "loserOf", row["match_id"])
    elif row["home_score"] < row["away_score"]:
        score_away = (row["away_team_id"], "winnerOf", row["match_id"])
        score_home = (row["home_team_id"], "loserOf", row["match_id"])
    else:
        score_home = (row["home_team_id"], "draws", row["match_id"])
        score_away = (row["away_team_id"], "draws", row["match_id"])
    home_score = (row["match_id"], "homeScores", np.clip(int(row["home_score"]), 0, 5))
    away_score = (row["match_id"], "awayScores", np.clip(int(row["away_score"]), 0, 5))
    
    # Match characteristics
    tournament = (row["match_id"], "inTournament", row["tournament_id"])
    city = (row["match_id"], "inCity", row["city_id"])
    country = (row["match_id"], "inCountry", row["country_id"])
    neutral = (row["match_id"], "isNeutral", row["neutral"])
    year = (row["match_id"], "atYear", row["date"][:4])
    
    triples.extend((home_team, away_team, score_home, score_away, 
                    tournament, city, country, neutral, year, home_score, away_score))
```

Note that we treat some literals (year, neutral match, home score, away score) as discrete entities and they will be part of the final knowledge graph used to generate the embeddings. We limit the number of score entities by clipping the score to be at most 5.
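
As a quick check (a minimal sketch, assuming the `triples` list built above is in scope), we can confirm that clipping leaves at most six distinct score entities:


```python
# Distinct objects of the "homeScores" predicate; after clipping to [0, 5]
# there should be at most the six values 0..5.
score_objects = {o for s, p, o in triples if p == "homeScores"}
print(sorted(score_objects))
```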

Below we provide a visualization of a subset of the graph related to the infamous [Maracanazo](https://en.wikipedia.org/wiki/Uruguay_v_Brazil_(1950_FIFA_World_Cup)):

![Football graph](img/FootballGraph.png)

The whole graph related to this match can be summarised by the triples below:


```python
triples_df = pd.DataFrame(triples, columns=["subject", "predicate", "object"])
triples_df[(triples_df.subject=="Match3129") | (triples_df.object=="Match3129")]
```




|       | subject     | predicate    | object                 |
|-------|-------------|--------------|------------------------|
| 34419 | TeamBrazil  | isHomeTeamIn | Match3129              |
| 34420 | TeamUruguay | isAwayTeamIn | Match3129              |
| 34421 | TeamBrazil  | loserOf      | Match3129              |
| 34422 | TeamUruguay | winnerOf     | Match3129              |
| 34423 | Match3129   | inTournament | TournamentFifaWorldCup |
| 34424 | Match3129   | inCity       | CityRioDeJaneiro       |
| 34425 | Match3129   | inCountry    | CountryBrazil          |
| 34426 | Match3129   | isNeutral    | False                  |
| 34427 | Match3129   | atYear       | 1950                   |
| 34428 | Match3129   | homeScores   | 1                      |
| 34429 | Match3129   | awayScores   | 2                      |



## Training knowledge graph embeddings

We split our training dataset further into training and validation sets: the new training set is used to train the knowledge graph embeddings, while the validation set (named `X_test` in the code below) is used to evaluate them. The out-of-time test set will be used to evaluate the performance of the classification algorithm built on top of the embeddings.

What differs from the standard method of randomly sampling N points to make up our validation set is that each of our data points is two entities linked by some relationship, and we need to take care that every entity is represented in the training set by at least one triple, otherwise it would have no trained embedding to evaluate.

To accomplish this, AmpliGraph provides the [`train_test_split_no_unseen`](https://docs.ampligraph.org/en/latest/generated/ampligraph.evaluation.train_test_split_no_unseen.html#train-test-split-no-unseen) function.


```python
from ampligraph.evaluation import train_test_split_no_unseen 

X_train, X_test = train_test_split_no_unseen(np.array(triples), test_size=10000)
```


```python
print('Train set size: ', X_train.shape)
print('Test set size: ', X_test.shape)
```

    Train set size:  (382854, 3)
    Test set size:  (10000, 3)


AmpliGraph 2 has a single class for defining [several](https://docs.ampligraph.org/en/latest/ampligraph.latent_features.html#knowledge-graph-embedding-models) knowledge graph embedding models (TransE, ComplEx, DistMult, HolE); it suffices to specify the desired scoring type. Together with that, at initialization time, we also need to define some parameters:
- **`k`** : the dimensionality of the embedding space;
- **`eta`** ($\eta$) : the number of negatives (i.e., false triples) that must be generated at training runtime for each positive (i.e., true triple).

We are going to use the [ComplEx](https://docs.ampligraph.org/en/latest/generated/ampligraph.latent_features.ComplEx.html#ampligraph.latent_features.ComplEx) model and use as hyperparameters those that gave the [best results](https://docs.ampligraph.org/en/latest/experiments.html) on some benchmark datasets.



```python
from ampligraph.latent_features import ScoringBasedEmbeddingModel
model = ScoringBasedEmbeddingModel(k=100,
                                   eta=20,
                                   scoring_type='ComplEx',
                                   seed=0)

```

Right after defining the model, it is time to compile it, specifying:

- **`optimizer`** : we will use the Adam optimizer with a learning rate of 1e-4; AmpliGraph 2 supports any _tf.keras.optimizers_;
- **`loss`** : we will use the multiclass negative log-likelihood loss (`multiclass_nll`). Many other loss functions are supported, and custom losses can be defined by the user;
- **`regularizer`** : we will use $L_p$ regularization with $p=3$ and regularization parameter $\lambda$ = 1e-5, passed directly to `get_regularizer`. Also in this case, _tf.keras.regularizers_ are supported;
- **`initializer`** : we will use the Glorot Uniform initialization, but all _tf.keras.initializers_ are supported.


```python
from ampligraph.latent_features.loss_functions import get as get_loss
from ampligraph.latent_features.regularizers import get as get_regularizer
from tensorflow.keras.optimizers import Adam

optimizer = Adam(learning_rate=1e-4)
loss = get_loss('multiclass_nll')
regularizer = get_regularizer('LP', {'p': 3, 'lambda': 1e-5})

model.compile(loss=loss,
              optimizer=optimizer,
              entity_relation_regularizer=regularizer,
              entity_relation_initializer='glorot_uniform')
```

At roughly 10 seconds per epoch on the GPU used here, training the 200 epochs below takes around half an hour:


```python
model.fit(X_train,
          batch_size=10000,
          epochs=200,
          verbose=True)
```

    Epoch 1/200
    40/40 [==============================] - 11s 280ms/step - loss: 29887.3809
    Epoch 2/200
    40/40 [==============================] - 10s 241ms/step - loss: 29887.1445
    Epoch 3/200
    40/40 [==============================] - 10s 240ms/step - loss: 29886.8750

    ...

    Epoch 198/200
    40/40 [==============================] - 9s 231ms/step - loss: 11929.9854
    Epoch 199/200
    40/40 [==============================] - 9s 231ms/step - loss: 11884.6123
    Epoch 200/200
    40/40 [==============================] - 9s 231ms/step - loss: 11839.6475





    <tensorflow.python.keras.callbacks.History at 0x302a382e0>



## Evaluating knowledge embeddings

AmpliGraph follows the _tensorflow.keras_ style APIs: after compiling the model, the main operations are performed through the **`fit`**, **`predict`**, and **`evaluate`** methods.

An additional step that we need to take when evaluating KGEs is defining the filter used to ensure that none of the negative statements generated by the corruption procedure are actually positives. This is done simply by concatenating our train and test sets and excluding these triples from the corruptions (this was unnecessary at training time because training triples are automatically filtered out by the fit method).


```python
filter = {'test': np.concatenate([X_train, X_test])}
```

Now it is time to evaluate our model on the test set to see how well it's performing. 

For this we are going to use the `evaluate` method, which takes as arguments:

- **`X_test`** : the data to evaluate on. We're going to use our test set to evaluate.
- **`use_filter`** : whether to filter out the false negatives generated by the corruption strategy. If a dictionary is passed, the values of it are used as elements to filter.
- **`corrupt_side`** : specifies whether to corrupt subj and obj separately or to corrupt both during evaluation.


```python
ranks = model.evaluate(X_test,
                       use_filter=filter,
                       corrupt_side='s,o',
                       verbose=True)
```



    314/314 [==============================] - 104s 331ms/step


We're going to use the *mrr_score* (mean reciprocal rank) and *hits_at_n_score* functions to evaluate the quality of our predictions.

- **mrr_score**: computes the mean of the reciprocals of a vector of ranks.
- **hits_at_n_score**: computes the fraction of a vector of ranks that make it into the top n positions.


```python
from ampligraph.evaluation import mr_score, mrr_score, hits_at_n_score

mr = mr_score(ranks)
mrr = mrr_score(ranks)

print("MRR: %.2f" % (mrr))
print("MR: %.2f" % (mr))

hits_10 = hits_at_n_score(ranks, n=10)
print("Hits@10: %.2f" % (hits_10))
hits_3 = hits_at_n_score(ranks, n=3)
print("Hits@3: %.2f" % (hits_3))
hits_1 = hits_at_n_score(ranks, n=1)
print("Hits@1: %.2f" % (hits_1))
```

    MRR: 0.30
    MR: 2951.93
    Hits@10: 0.43
    Hits@3: 0.35
    Hits@1: 0.23


We can interpret these results by stating that the model will rank the correct entity within the top-3 possibilities 35% of the time. 

By themselves, these metrics are not enough to conclude that the embeddings will be useful in a downstream task, but they suggest the model has learned a reasonable enough representation to be worth using in further tasks.
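
To make the metric definitions concrete, here is a minimal numpy sketch (not the library implementation) that reproduces the two scores directly from the vector of ranks returned by `evaluate`:


```python
# ranks has one entry per corruption side (subject and object), so flatten first.
ranks_flat = np.array(ranks).reshape(-1)
mrr_manual = np.mean(1.0 / ranks_flat)     # mean reciprocal rank
hits10_manual = np.mean(ranks_flat <= 10)  # fraction of ranks in the top 10
print(mrr_manual, hits10_manual)
```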

## Clustering and embedding visualization

To evaluate the subjective quality of the embeddings, we can cluster them in the original space and/or visualise them in a 2D space. We can compare the resulting clusters with natural clusters (in this case, the continent each team is from), which gives us a ground truth to evaluate the clustering quality both qualitatively and quantitatively.

Requirements:

* seaborn
* adjustText
* incf.countryutils

For seaborn and adjustText, simply install them with `pip install seaborn adjustText`.

For incf.countryutils, do the following steps:
```bash
git clone https://github.com/wyldebeast-wunderliebe/incf.countryutils.git
cd incf.countryutils
pip install .
```


```python
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import seaborn as sns
from adjustText import adjust_text
from incf.countryutils import transformations
%matplotlib inline
```

We create a map from the team ID (e.g. "TeamBrazil") to the team name (e.g. "Brazil") for visualization purposes.


```python
id_to_name_map = {**dict(zip(df.home_team_id, df.home_team)), **dict(zip(df.away_team_id, df.away_team))}
```

We now create a dictionary with the embeddings of all teams:


```python
teams = pd.concat((df.home_team_id[df["train"]], df.away_team_id[df["train"]])).unique()
team_embeddings = dict(zip(teams, model.get_embeddings(teams)))
```

We use PCA to project the embeddings from their 200-dimensional space into a 2D space:


```python
embeddings_2d = PCA(n_components=2).fit_transform(np.array([i for i in team_embeddings.values()]))
```

We will cluster the team embeddings in their original 200-dimensional space using the `find_clusters` function from our discovery API, asking for 6 clusters to match the number of continents:


```python
from ampligraph.discovery import find_clusters
from sklearn.cluster import KMeans

clustering_algorithm = KMeans(n_clusters=6, n_init=50, max_iter=500, random_state=0)
clusters = find_clusters(teams, model, clustering_algorithm, mode='e')
```

This helper function uses the `incf.countryutils` library to translate a team's country name into its corresponding continent.


```python
def cn_to_ctn(team_id):
    # Map a team entity (e.g. "TeamBrazil") to its continent via its country name
    try:
        return transformations.cn_to_ctn(id_to_name_map[team_id])
    except KeyError:
        return "unk"
```

This dataframe contains, for each team, its embedding projected into 2D space via PCA, its continent, and its KMeans cluster. It will be used alongside Seaborn to make the visualizations.


```python
plot_df = pd.DataFrame({"teams": teams, 
                        "embedding1": embeddings_2d[:, 0], 
                        "embedding2": embeddings_2d[:, 1],
                        "continent": pd.Series(teams).apply(cn_to_ctn),
                        "cluster": "cluster" + pd.Series(clusters).astype(str)})
```

We plot the results on a 2D scatter plot, coloring the teams by the continent or cluster and also displaying some individual team names. 

We always display the names of the top 20 teams (according to [FIFA rankings](https://en.wikipedia.org/wiki/FIFA_World_Rankings)) and a random subset of the rest.


```python
top20teams = ["TeamBelgium", "TeamFrance", "TeamBrazil", "TeamEngland", "TeamPortugal", "TeamCroatia", "TeamSpain", 
              "TeamUruguay", "TeamSwitzerland", "TeamDenmark", "TeamArgentina", "TeamGermany", "TeamColombia",
              "TeamItaly", "TeamNetherlands", "TeamChile", "TeamSweden", "TeamMexico", "TeamPoland", "TeamIran"]

def plot_clusters(hue):
    np.random.seed(0)
    plt.figure(figsize=(12, 12))
    plt.title("{} embeddings".format(hue).capitalize())
    ax = sns.scatterplot(data=plot_df[plot_df.continent!="unk"], x="embedding1", y="embedding2", hue=hue)
    texts = []
    for i, point in plot_df.iterrows():
        if point["teams"] in top20teams or np.random.random() < 0.1:
            texts.append(plt.text(point['embedding1']+0.02, point['embedding2']+0.01, str(point["teams"])))
    adjust_text(texts)
```

The first visualisation of the 2D embeddings shows the natural geographical clusters (continents), which can be seen as a form of ground truth:


```python
plot_clusters("continent")
```


    
![png](ClusteringAndClassificationWithEmbeddings_files/ClusteringAndClassificationWithEmbeddings_52_0.png)
    


We can see above that the embeddings learned geographical similarities, even though this information was not explicit in the original dataset.

Now we plot the same 2D embeddings but with the clusters found by K-Means:


```python
plot_clusters("cluster")
```


    
![png](ClusteringAndClassificationWithEmbeddings_files/ClusteringAndClassificationWithEmbeddings_54_0.png)
    


We can see that K-Means found clusters very similar to the natural geographical clusters formed by the continents. This shows that in the 200-dimensional embedding space similar teams lie close together, and a clustering algorithm can capture that structure.

We can evaluate the clusters more objectively using a metric such as the [adjusted Rand score](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.adjusted_rand_score.html), which varies from -1 to 1, where 0 means random labelling and 1 a perfect match:


```python
from sklearn import metrics
metrics.adjusted_rand_score(plot_df.continent, plot_df.cluster)
```




    0.43941491221434137



## Classification

We will now use the knowledge embeddings we obtained through our model as the inputs for a machine learning model that predicts future matches. We frame it as a classification problem with three classes: home team wins, draw, home team loses.

The embeddings are used directly as features for an XGBoost classifier.

Let's first determine the target (0 = home team wins, 1 = draw, 2 = home team loses):


```python
df["results"] = (df.home_score > df.away_score).astype(int) + \
                (df.home_score == df.away_score).astype(int)*2 + \
                (df.home_score < df.away_score).astype(int)*3 - 1
```


```python
df.results.value_counts(normalize=True)
```




    0    0.486473
    2    0.282456
    1    0.231071
    Name: results, dtype: float64



Now we create a function that extracts the features (knowledge embeddings for home and away teams) and the target for a particular subset of the dataset:


```python
def get_features_target(mask):
    
    def get_embeddings(team):
        return team_embeddings.get(team, np.full(200, np.nan))
    
    X = np.hstack((np.vstack(df[mask].home_team_id.apply(get_embeddings).values),
                   np.vstack(df[mask].away_team_id.apply(get_embeddings).values)))
    y = df[mask].results.values
    return X, y
```


```python
clf_X_train, y_train = get_features_target((df["train"]))
clf_X_test, y_test = get_features_target((~df["train"]))
```


```python
clf_X_train.shape, clf_X_test.shape
```




    ((35714, 400), (5057, 400))



Note that we have 200 features per team because the ComplEx model learns complex-valued embeddings (a real and an imaginary part per dimension), so each entity has twice as many parameters as the `k=100` set in the model definition.
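
We can verify this directly (a quick sketch; `TeamBrazil` is simply one entity we know is in the graph):


```python
# A ComplEx embedding concatenates real and imaginary parts: 2 * k = 200.
emb = model.get_embeddings(np.array(["TeamBrazil"]))
print(emb.shape)  # expected: (1, 200)
```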

We are also missing embeddings for the teams that only appear in the test set, so their features are NaN and those matches are unlikely to be correctly classified. XGBoost handles missing values natively, so we can leave them as NaN:


```python
# Total NaN feature values, expressed as an equivalent number of full rows
np.isnan(clf_X_test).sum()/clf_X_test.shape[1]
```




    105.0



First install xgboost with `pip install xgboost`.


```python
from xgboost import XGBClassifier
```

Create a multiclass model with 500 estimators:


```python
clf_model = XGBClassifier(n_estimators=500, max_depth=5, objective="multi:softmax")
```

Fit the model using all of the training samples:


```python
clf_model.fit(clf_X_train, y_train)
```




    XGBClassifier(base_score=None, booster=None, callbacks=None,
                  colsample_bylevel=None, colsample_bynode=None,
                  colsample_bytree=None, early_stopping_rounds=None,
                  enable_categorical=False, eval_metric=None, feature_types=None,
                  gamma=None, gpu_id=None, grow_policy=None, importance_type=None,
                  interaction_constraints=None, learning_rate=None, max_bin=None,
                  max_cat_threshold=None, max_cat_to_onehot=None,
                  max_delta_step=None, max_depth=5, max_leaves=None,
                  min_child_weight=None, missing=nan, monotone_constraints=None,
                  n_estimators=500, n_jobs=None, num_parallel_tree=None,
                  objective='multi:softmax', predictor=None, ...)



The baseline accuracy for this problem is 47%, as that is the frequency of the most frequent class in the test set (home team wins):


```python
df[~df["train"]].results.value_counts(normalize=True)
```




    0    0.471030
    2    0.287325
    1    0.241645
    Name: results, dtype: float64




```python
metrics.accuracy_score(y_test, clf_model.predict(clf_X_test))
```




    0.5378683013644453



In conclusion, while the baseline for this classification problem was 47%, with the knowledge embeddings alone we were able to build a classifier that achieves **54%** accuracy.

As future work, we could add more features to the model (beyond the embeddings) and tune its hyper-parameters.
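
As a starting point for that tuning, here is a minimal sketch using scikit-learn's `GridSearchCV` around the same classifier (the grid values below are illustrative choices, not tuned recommendations):


```python
# Illustrative hyper-parameter search; grid values are arbitrary examples.
from sklearn.model_selection import GridSearchCV
from xgboost import XGBClassifier

param_grid = {"max_depth": [3, 5, 7], "n_estimators": [100, 300, 500]}
search = GridSearchCV(XGBClassifier(objective="multi:softmax"),
                      param_grid, scoring="accuracy", cv=3)
search.fit(clf_X_train, y_train)
print(search.best_params_, search.best_score_)
```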
