# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
#
# This file is Licensed under the Apache License, Version 2.0.
# A copy of the Licence is available in LICENCE, or at:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#

import logging

import networkx as nx
import numpy as np
from scipy import optimize, spatial
from sklearn.cluster import DBSCAN
from sklearn.neighbors import NearestNeighbors

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# NOTE(review): forcing DEBUG here overrides the level this logger would otherwise
# inherit from its ancestors (e.g. the root logger), so this module's debug
# messages reach any attached handler whose own level lets them through.
# Libraries usually leave the level to the application; consider removing.
logger.setLevel(logging.DEBUG)


def discover_facts(
    X,
    model,
    top_n=10,
    strategy="random_uniform",
    max_candidates=100,
    target_rel=None,
    seed=0,
):
    """
    Discover new facts from an existing knowledge graph.

    You should use this function when you already have a model trained on a knowledge graph and you want to
    discover potentially true statements in that knowledge graph.

    The general procedure of this function is to generate a set of candidate statements :math:`C` according to some
    sampling strategy ``strategy`` (see :func:`generate_candidates`), then rank every candidate against its
    subject-side and object-side corruptions with ``model.evaluate`` (filtering with ``X``).
    The two ranks of each candidate are averaged, and candidates whose average rank is at most ``top_n``
    are returned as likely true statements.

    Candidates are generated per target relation and never include statements already in ``X`` or
    reflexive statements (same subject and object).

    The available strategies are:

    - | weighted strategies, which sample entities with a probability proportional to a statistic of the
        graph: entity frequency (`'entity_frequency'`), node degree (`'graph_degree'`), clustering coefficient
        (`'cluster_coefficient'`), number of triangles (`'cluster_triangles'`) or square clustering
        (`'cluster_squares'`);
    - | `'random_uniform'`, which samples entities uniformly at random.

    .. note::
        The `'exhaustive'` strategy is not currently supported: passing it raises a ``ValueError``.

    The model is assumed to have been fit on ``X`` (or on a superset of it); this is not verified.

    The ``target_rel`` argument indicates what relation to generate candidate statements for. If this is set to ``None``
    then all relations known to the model will be considered for sampling.

    Parameters
    ----------

    X : ndarray of shape (n, 3)
        The input knowledge graph used to train ``model``, or a subset of it.
    model : EmbeddingModel
        The trained model that will be used to score candidate facts.
    top_n : int
        The cutoff position in ranking to consider a candidate triple as true positive.
        A candidate is kept when the average of its subject- and object-side ranks is ``<= top_n``.
    strategy: str
        The candidates generation strategy:

        - `'random_uniform'` : generates `N` candidates (:math:`N <= max_candidates`) based on a uniform sampling of
            entities.
        - `'entity_frequency'` : generates candidates by weighted sampling of entities using entity frequency.
        - `'graph_degree'` : generates candidates by weighted sampling of entities with graph degree.
        - `'cluster_coefficient'` : generates candidates by weighted sampling entities with clustering coefficient.
        - `'cluster_triangles'` : generates candidates by weighted sampling entities with cluster triangles.
        - `'cluster_squares'` : generates candidates by weighted sampling entities with cluster squares.

    max_candidates: int or float
        The maximum numbers of candidates generated by ``strategy`` for each target relation.
        Can be an absolute number or a percentage [0,1] of the size of the `X` parameter.
    target_rel : str or list(str)
        Target relations to focus on. The function will discover facts only for that specific relation types.
        If `None`, the function attempts to discover new facts for all relation types known to the model.
    seed : int
        Seed to use for reproducible results.

    Returns
    -------
    X_pred : ndarray, shape (n, 3)
        The new facts predicted to be true, concatenated over all target relations.
    ranks : ndarray, shape (n,)
        The averaged (subject- and object-side) rank of each statement in ``X_pred``.

    Raises
    ------
    ValueError
        If the model is not fitted, ``strategy`` is not supported, or a target relation is unknown to the model.

    Example
    -------
    >>> import requests
    >>> from ampligraph.latent_features import ScoringBasedEmbeddingModel
    >>> from ampligraph.datasets import load_from_csv
    >>> from ampligraph.discovery import discover_facts
    >>> # Game of Thrones relations dataset
    >>> url = 'https://ampligraph.s3-eu-west-1.amazonaws.com/datasets/GoT.csv'
    >>> open('GoT.csv', 'wb').write(requests.get(url).content)
    >>> X = load_from_csv('.', 'GoT.csv', sep=',')
    >>> model = ScoringBasedEmbeddingModel(eta=5,
    >>>                                      k=300,
    >>>                                      scoring_type='ComplEx')
    >>> model.compile(optimizer='adam', loss='multiclass_nll')
    >>> model.fit(X,
    >>>              batch_size=100,
    >>>              epochs=10)
    >>> discover_facts(X,
    >>>                model,
    >>>                top_n=100,
    >>>                strategy='random_uniform',
    >>>                max_candidates=100,
    >>>                target_rel='ALLIED_WITH',
    >>>                seed=0)
    Epoch 1/10
    33/33 [==============================] - 1s 27ms/step - loss: 177.7778
    Epoch 2/10
    33/33 [==============================] - 0s 6ms/step - loss: 177.4795
    Epoch 3/10
    33/33 [==============================] - 0s 6ms/step - loss: 176.9654
    Epoch 4/10
    33/33 [==============================] - 0s 6ms/step - loss: 175.8453
    Epoch 5/10
    33/33 [==============================] - 0s 6ms/step - loss: 173.4385
    Epoch 6/10
    33/33 [==============================] - 0s 6ms/step - loss: 168.8143
    Epoch 7/10
    33/33 [==============================] - 0s 6ms/step - loss: 161.2919
    Epoch 8/10
    33/33 [==============================] - 0s 6ms/step - loss: 151.3496
    Epoch 9/10
    33/33 [==============================] - 0s 6ms/step - loss: 140.4268
    Epoch 10/10
    33/33 [==============================] - 0s 5ms/step - loss: 129.8206
    3175 triples containing invalid keys skipped!
    (array([['House Nymeros Martell of Sunspear', 'ALLIED_WITH',
             'House Mallister of Seagard'],
            ['Ben', 'ALLIED_WITH', 'House Mallister of Seagard'],
            ['Selwyn Tarth', 'ALLIED_WITH', 'House Mallister of Seagard'],
            ['Clarence Charlton', 'ALLIED_WITH', 'House Woods'],
            ['Selwyn Tarth', 'ALLIED_WITH', 'House Woods'],
            ['Dacks', 'ALLIED_WITH', 'Titus Peake'],
            ['Barra', 'ALLIED_WITH', 'Titus Peake'],
            ['House Chelsted', 'ALLIED_WITH', 'Denys Darklyn'],
            ['Crow Spike Keep', 'ALLIED_WITH', 'Denys Darklyn'],
            ['Selwyn Tarth', 'ALLIED_WITH', 'Denys Darklyn'],
            ['House Chelsted', 'ALLIED_WITH', 'House Belmore of Strongsong'],
            ['Barra', 'ALLIED_WITH', 'House Belmore of Strongsong'],
            ['Walder Frey', 'ALLIED_WITH', 'House Belmore of Strongsong']],
           dtype=object),
     array([ 2. , 53. , 73. , 42. , 18. , 59.5, 86. , 76.5, 31. , 60.5, 31.5,
            32. , 24. ]))
    """
    if model.is_backward:
        # Presumably a backward-compatibility wrapper: unwrap to the inner model.
        model = model.model

    if not model.is_fitted:
        msg = "Model is not fitted."
        logger.error(msg)
        raise ValueError(msg)

    # 'exhaustive' is intentionally absent from this list, so it is rejected here.
    if strategy not in [
        "random_uniform",
        "entity_frequency",
        "graph_degree",
        "cluster_coefficient",
        "cluster_triangles",
        "cluster_squares",
    ]:
        msg = "%s is not a valid strategy." % strategy
        logger.error(msg)
        raise ValueError(msg)

    if isinstance(max_candidates, float):
        converted = int(max_candidates * len(X))
        logger.debug(
            "Converting max_candidates float value {} to int value {}".format(
                max_candidates, converted
            )
        )
        max_candidates = converted

    if isinstance(target_rel, str):
        target_rel = [target_rel]

    # Query the backend once rather than once per requested relation.
    known_relations = list(model.data_indexer.backend.get_all_relations())

    if target_rel is None:
        msg = "No target relation specified. Using all relations to generate candidate statements."
        logger.info(msg)
        rel_list = known_relations
    else:
        missing_rels = [rel for rel in target_rel if rel not in known_relations]
        if missing_rels:
            msg = "Target relation(s) not found in model: {}".format(
                missing_rels
            )
            logger.error(msg)
            raise ValueError(msg)

        # One entry per relation so candidates are generated and ranked per relation.
        rel_list = list(target_rel)

    # Seed the global NumPy RNG (generate_candidates also re-seeds it on each call).
    np.random.seed(seed)

    discoveries = []
    discovery_ranks = []

    for relation in rel_list:
        logger.info("Generating candidates for relation: %s", relation)

        candidates = generate_candidates(
            X, strategy, relation, max_candidates, seed=seed
        )

        logger.debug("Generated %d candidate statements.", len(candidates))

        if len(candidates) == 0:
            # Nothing to rank for this relation; do not call evaluate on an empty array.
            continue

        ranks = model.evaluate(
            candidates,
            use_filter={"test": X},
            corrupt_side="s,o",
            verbose=False,
        )

        # With corrupt_side="s,o" each candidate gets one rank per corrupted side
        # (assumed shape (n, 2)); average them into a single score per candidate.
        avg_ranks = np.mean(ranks, axis=1)

        preds = avg_ranks <= top_n
        discoveries.append(candidates[preds])
        discovery_ranks.append(avg_ranks[preds])

    if discoveries:
        # Each entry is a (k, 3) array of triples: stack them row-wise.
        X_pred = np.concatenate(discoveries, axis=0)
        ranks_pred = np.concatenate(discovery_ranks)
    else:
        X_pred = np.empty((0, 3), dtype=object)
        ranks_pred = np.empty((0,), dtype=np.float64)

    logger.info("Discovered %d facts", len(X_pred))

    return X_pred, ranks_pred


def generate_candidates(
    X, strategy, target_rel, max_candidates, consolidate_sides=False, seed=0
):
    """Generate candidate statements from an existing knowledge graph using a defined strategy.

    Head and tail entities are sampled (uniformly, or weighted by a statistic chosen by ``strategy``),
    combined with ``target_rel``, and filtered. The result never contains a statement already present
    in ``X``, a reflexive statement ``(x, target_rel, x)`` or the same statement twice.

    Parameters
    ----------
    X: np.array, shape (n, 3)
        Triples from which to discover new facts. Extra columns beyond the third (e.g. weights) are ignored.
    strategy: str
        The candidates generation strategy.
        - `'random_uniform'` : generates `N` candidates (:math:`N <= max_candidates`) based on a uniform random
            sampling of head and tail entities.
        - `'entity_frequency'` : samples entities with probability proportional to their frequency in ``X``.
        - `'graph_degree'` : samples entities with probability proportional to their graph degree.
        - `'cluster_coefficient'` : samples entities with probability proportional to their clustering coefficient.
        - `'cluster_triangles'` : samples entities with probability proportional to their number of triangles.
        - `'cluster_squares'` : samples entities with probability proportional to their square clustering.

        The graph statistics are computed on the undirected graph whose edges are the (subject, object)
        pairs of ``X``. If every weight is zero (e.g. a graph without triangles), sampling falls back to uniform.
    max_candidates: int or float
        The maximum numbers of candidates generated by ``strategy``.
        Can be an absolute number or a percentage [0,1].
        This does not guarantee the number of candidates generated.
    target_rel : str
        Target relation to focus on. The function will generate candidate
        statements only with this specific relation type.
    consolidate_sides: bool
        If `True` will generate candidate statements as a product of unique head and tail entities, otherwise will
        consider head and tail entities separately (default: `False`).
    seed : int
        Seed to use for reproducible results.

    Returns
    -------
    X_candidates : ndarray of dtype object, shape (n, 3)
        A list of candidate statements, with ``n <= max_candidates``.

    Raises
    ------
    ValueError
        If ``strategy`` is unknown, or ``max_candidates`` is not a positive int or float.

    Example
    -------
    >>> import numpy as np
    >>> from ampligraph.discovery.discovery import generate_candidates
    >>>
    >>> X = np.array([['a', 'y', 'b'],
    >>>               ['b', 'y', 'a'],
    >>>               ['a', 'y', 'c'],
    >>>               ['c', 'y', 'a'],
    >>>               ['a', 'y', 'd'],
    >>>               ['c', 'y', 'd'],
    >>>               ['b', 'y', 'c'],
    >>>               ['f', 'y', 'e']])
    >>> X_candidates = generate_candidates(X, strategy='graph_degree', target_rel='y', max_candidates=3)
    >>> X_candidates  # one possible result; the exact rows depend on ``seed``
    array([['a', 'y', 'e'],
           ['f', 'y', 'a'],
           ['c', 'y', 'e']], dtype=object)

    """
    if X.shape[1] > 3:
        # Inputs may carry extra columns (e.g. weights) next to the triples; keep (s, p, o) only.
        X = X[:, :3]
    if strategy not in [
        "random_uniform",
        "entity_frequency",
        "graph_degree",
        "cluster_coefficient",
        "cluster_triangles",
        "cluster_squares",
    ]:
        msg = "%s is not a valid candidate generation strategy." % strategy
        raise ValueError(msg)

    if target_rel not in np.unique(X[:, 1]):
        # No error as may be case where target_rel is not in X
        msg = "Target relation is not found in triples."
        logger.warning(msg)

    # np.integer is accepted too: NumPy integer scalars are not subclasses of int.
    if not isinstance(max_candidates, (float, int, np.integer)):
        msg = "Parameter max_candidates must be a float or int."
        raise ValueError(msg)

    if max_candidates <= 0:
        msg = (
            "Parameter max_candidates must be a positive integer "
            "or float in range (0,1]."
        )
        raise ValueError(msg)

    if isinstance(max_candidates, float):
        max_candidates = int(max_candidates * len(X))
    max_candidates = int(max_candidates)

    # Set random seed
    np.random.seed(seed)

    # Candidate head (e_s) and tail (e_o) entities, both sorted as returned by np.unique.
    if consolidate_sides:
        e_s = np.unique(np.concatenate((X[:, 0], X[:, 2])))
        e_o = e_s
    else:
        e_s = np.unique(X[:, 0])
        e_o = np.unique(X[:, 2])

    logger.info("Generating candidates using {} strategy.".format(strategy))

    if strategy == "random_uniform":
        return _sample_candidates(
            X, target_rel, e_s, e_o, max_candidates, replace=False
        )

    # NOTE(review): weights grow with the statistic, so frequent / highly connected
    # entities are sampled MORE often. Confirm this is the intended direction.
    if strategy == "entity_frequency":
        # np.unique sorts its output, so the counts line up with e_s / e_o.
        if consolidate_sides:
            _, e_s_counts = np.unique(X[:, [0, 2]], return_counts=True)
            e_o_counts = e_s_counts
        else:
            _, e_s_counts = np.unique(X[:, 0], return_counts=True)
            _, e_o_counts = np.unique(X[:, 2], return_counts=True)
        e_s_weights = _normalize_weights(e_s_counts)
        e_o_weights = _normalize_weights(e_o_counts)
    else:
        metric = _graph_metric(X, strategy)
        e_s_weights = _normalize_weights([metric[x] for x in e_s])
        e_o_weights = _normalize_weights([metric[x] for x in e_o])

    return _sample_candidates(
        X,
        target_rel,
        e_s,
        e_o,
        max_candidates,
        replace=True,
        e_s_weights=e_s_weights,
        e_o_weights=e_o_weights,
    )


def _graph_metric(X, strategy):
    """Return a dict mapping each node to the statistic named by ``strategy``.

    The graph is undirected, with one edge per (subject, object) pair of ``X``.
    """
    G = nx.Graph()
    G.add_edges_from((row[0], row[2], {"name": row[1]}) for row in X)

    if strategy == "graph_degree":
        return dict(G.degree())
    if strategy == "cluster_coefficient":
        return nx.algorithms.cluster.clustering(G)
    if strategy == "cluster_triangles":
        return nx.algorithms.cluster.triangles(G)
    if strategy == "cluster_squares":
        return nx.algorithms.cluster.square_clustering(G)
    raise ValueError("%s is not a valid candidate generation strategy." % strategy)


def _normalize_weights(weights):
    """Turn non-negative weights into a probability vector, or ``None`` for uniform sampling."""
    weights = np.asarray(weights, dtype=np.float64)
    total = np.sum(weights)
    if not np.isfinite(total) or total <= 0:
        # e.g. no node has any triangle: dividing by zero would hand NaN
        # probabilities to np.random.choice, so sample uniformly instead.
        return None
    return weights / total


def _sample_candidates(
    X,
    target_rel,
    e_s,
    e_o,
    max_candidates,
    replace,
    e_s_weights=None,
    e_o_weights=None,
    max_retries=5,
):
    """Sample up to ``max_candidates`` distinct, non-reflexive (s, target_rel, o) triples absent from ``X``."""
    if len(e_s) == 0 or len(e_o) == 0:
        return np.empty((0, 3), dtype=object)

    # Take close to sqrt of max_candidates so that len(meshgrid result) ~= max_candidates;
    # +10 leaves headroom for rows removed by filtering.
    sample_size = int(np.sqrt(max_candidates) + 10)
    size_s, size_o = sample_size, sample_size
    if not replace:
        # Sampling without replacement cannot draw more entities than exist.
        size_s = min(size_s, len(e_s))
        size_o = min(size_o, len(e_o))

    # Statements to reject: those already in X, plus every candidate accepted so far.
    seen = {tuple(row) for row in X.tolist()}
    selected = []

    # Retry up to max_retries times to reach max_candidates.
    for _ in range(max_retries):
        if len(selected) >= max_candidates:
            break

        sample_e_s = np.random.choice(
            e_s, size=size_s, replace=replace, p=e_s_weights
        )
        sample_e_o = np.random.choice(
            e_o, size=size_o, replace=replace, p=e_o_weights
        )

        gen_candidates = np.array(
            np.meshgrid(sample_e_s, target_rel, sample_e_o)
        ).T.reshape(-1, 3)

        for row in gen_candidates.tolist():
            candidate = tuple(row)
            # Skip reflexive ['x', rel, 'x'] statements, known facts and duplicates.
            if row[0] == row[2] or candidate in seen:
                continue
            seen.add(candidate)
            selected.append(row)
            if len(selected) >= max_candidates:
                break

    return np.array(selected, dtype=object).reshape(-1, 3)


def _setdiff2d(A, B):
    """Utility function equivalent to numpy.setdiff1d on 2d arrays.

    Parameters
    ----------

    A : ndarray, shape [n, m]

    B : ndarray, shape [n, m]

    Returns
    -------
    subset_A : np.array, shape [k, m]
        Rows of A that are not in B.

    """

    if len(A.shape) != 2 or len(B.shape) != 2:
        raise RuntimeError("Input arrays must be 2-dimensional.")

    tmp = np.prod(np.swapaxes(A[:, :, None], 1, 2) == B, axis=2)
    return A[~np.sum(np.cumsum(tmp, axis=0) * tmp == 1, axis=1).astype(bool)]


def find_clusters(X, model, clustering_algorithm=None, mode="e"):
    """
    Perform link-based cluster analysis on a knowledge graph.

    The clustering happens on the embedding space of the entities and relations.
    For example, if we cluster some entities of a model that uses :math:`k=100` (i.e. embedding space of size 100),
    we will apply the chosen clustering algorithm on the 100-dimensional space of the provided input samples.

    Clustering can be used to evaluate the quality of the knowledge embeddings, by comparing to natural clusters.
    For example, in the example below we cluster the embeddings of international football matches and end up
    finding geographical clusters very similar to the continents.
    This comparison can be subjective by inspecting a 2D projection of the embedding space or objective using a
    `clustering metric <https://scikit-learn.org/stable/modules/clustering.html#clustering-performance-evaluation>`_.

    | The choice of the clustering algorithm and its corresponding tuning will greatly impact the results.
      Please see `scikit-learn documentation <https://scikit-learn.org/stable/modules/clustering.html#clustering>`_
      for a list of algorithms, their parameters, and pros and cons.

    Clustering is exclusive (i.e., each input item is assigned to one and only one cluster).

    Parameters
    ----------

    X : ndarray, shape (n, 3) or (n)
        The input to be clustered.
        ``X`` can either be the triples of a knowledge graph, its entities, or its relations.
        The argument ``mode`` defines whether ``X`` is supposed to be an array of triples
        or an array of either entities or relations.
        Plain Python sequences are converted with :func:`numpy.asarray`.
    model : EmbeddingModel
        The fitted model that will be used to generate the embeddings.
        This model must have been fully trained already, be it directly with
        ``fit()`` or from a helper function such as :meth:`ampligraph.evaluation.select_best_model_ranking`.
    clustering_algorithm : object, optional
        The initialized object of the clustering algorithm.
        It should be ready to apply the :meth:`fit_predict` method.
        Please see: `scikit-learn documentation <https://scikit-learn.org/stable/modules/clustering.html#clustering>`_
        to understand the clustering API provided by scikit-learn.
        If ``None`` (default), a new
        `sklearn's DBSCAN <https://scikit-learn.org/stable/modules/generated/sklearn.cluster.DBSCAN.html>`_
        instance with its default parameters is created for each call.
    mode: str
        Clustering mode.

        Choose from:

        - | `'e'` (default): the algorithm will cluster the embeddings of the provided entities.
        - | `'r'`: the algorithm will cluster the embeddings of the provided relations.
        - | `'t'` : the algorithm will cluster the concatenation
            of the embeddings of the subject, predicate and object for each triple.

    Returns
    -------
    labels : ndarray, shape [n]
        Index of the cluster each input item (entity, relation or triple, depending on ``mode``) belongs to.

    Raises
    ------
    ValueError
        If the model is not fitted, the algorithm lacks ``fit_predict``, ``mode`` is invalid,
        or ``X`` does not have the shape required by ``mode``.

    Example
    -------
    >>> # Note seaborn, matplotlib, adjustText are not AmpliGraph dependencies.
    >>> # and must therefore be installed manually as:
    >>> #
    >>> # $ pip install seaborn matplotlib adjustText
    >>>
    >>> import requests
    >>> import pandas as pd
    >>> import numpy as np
    >>> from sklearn.decomposition import PCA
    >>> from sklearn.cluster import KMeans
    >>> import matplotlib.pyplot as plt
    >>> import seaborn as sns
    >>>
    >>> # adjustText lib: https://github.com/Phlya/adjustText
    >>> from adjustText import adjust_text
    >>>
    >>> from ampligraph.datasets import load_from_csv
    >>> from ampligraph.latent_features import ScoringBasedEmbeddingModel
    >>> from ampligraph.discovery import find_clusters
    >>>
    >>> # International football matches triples
    >>> # See tutorial here to understand how the triples are created from a tabular dataset:
    >>> # https://github.com/Accenture/AmpliGraph/blob/master/docs/tutorials/\
ClusteringAndClassificationWithEmbeddings.ipynb
    >>> url = 'https://ampligraph.s3-eu-west-1.amazonaws.com/datasets/football.csv'
    >>> open('football.csv', 'wb').write(requests.get(url).content)
    >>> X = load_from_csv('.', 'football.csv', sep=',')[:, 1:]
    >>>
    >>> model = ScoringBasedEmbeddingModel(eta=5,
    >>>                                  k=300,
    >>>                                  scoring_type='ComplEx')
    >>> model.compile(optimizer='adam', loss='multiclass_nll')
    >>> model.fit(X,
    >>>           batch_size=10000,
    >>>           epochs=10)
    >>>
    >>> df = pd.DataFrame(X, columns=["s", "p", "o"])
    >>>
    >>> teams = np.unique(np.concatenate((df.s[df.s.str.startswith("Team")],
    >>>                                   df.o[df.o.str.startswith("Team")])))
    >>> team_embeddings = model.get_embeddings(teams, embedding_type='e')
    >>>
    >>> embeddings_2d = PCA(n_components=2).fit_transform(np.array([i for i in team_embeddings]))
    >>>
    >>> # Find clusters of embeddings using KMeans
    >>>
    >>> kmeans = KMeans(n_clusters=6, n_init=100, max_iter=500)
    >>> clusters = find_clusters(teams, model, kmeans, mode='e')
    >>>
    >>> # Plot results
    >>> df = pd.DataFrame({"teams": teams, "clusters": "cluster" + pd.Series(clusters).astype(str),
    >>>                    "embedding1": embeddings_2d[:, 0], "embedding2": embeddings_2d[:, 1]})
    >>>
    >>> plt.figure(figsize=(10, 10))
    >>> plt.title("Cluster embeddings")
    >>>
    >>> ax = sns.scatterplot(data=df, x="embedding1", y="embedding2", hue="clusters")
    >>>
    >>> texts = []
    >>> for i, point in df.iterrows():
    >>>     if np.random.uniform() < 0.1:
    >>>         texts.append(plt.text(point['embedding1']+.02, point['embedding2'], str(point['teams'])))
    >>> adjust_text(texts)

    .. image:: ../img/clustering/clustered_embeddings_docstring.png
        :align: center

    """
    if model.is_backward:
        # Presumably a backward-compatibility wrapper: unwrap to the inner model.
        model = model.model

    if not model.is_fitted:
        msg = "Model has not been fitted."
        logger.error(msg)
        raise ValueError(msg)

    if clustering_algorithm is None:
        # Build a fresh estimator per call rather than sharing (and refitting)
        # one instance created once as a default argument.
        clustering_algorithm = DBSCAN()

    if not hasattr(clustering_algorithm, "fit_predict"):
        msg = "Clustering algorithm does not have the `fit_predict` method."
        logger.error(msg)
        raise ValueError(msg)

    modes = ("t", "e", "r")
    if mode not in modes:
        msg = "Argument `mode` must be one of the following: {}.".format(
            ", ".join(modes)
        )
        logger.error(msg)
        raise ValueError(msg)

    if not hasattr(X, "shape"):
        # Accept plain Python sequences (e.g. a list of entity names).
        X = np.asarray(X)

    if mode == "t" and (len(X.shape) != 2 or X.shape[1] != 3):
        msg = "For 't' mode the input X must be a matrix with three columns."
        logger.error(msg)
        raise ValueError(msg)

    if mode in ("e", "r") and len(X.shape) != 1:
        msg = "For 'e' or 'r' mode the input X must be an array."
        logger.error(msg)
        raise ValueError(msg)

    if mode == "t":
        # One row per triple: [subject embedding | predicate embedding | object embedding].
        s = model.get_embeddings(X[:, 0], embedding_type="e")
        p = model.get_embeddings(X[:, 1], embedding_type="r")
        o = model.get_embeddings(X[:, 2], embedding_type="e")
        emb = np.hstack((s, p, o))
    else:
        emb = model.get_embeddings(X, embedding_type=mode)

    return clustering_algorithm.fit_predict(emb)


def find_duplicates(
    X,
    model,
    mode="e",
    metric="l2",
    tolerance="auto",
    expected_fraction_duplicates=0.1,
    verbose=False,
):
    r"""
    Find duplicate entities, relations or triples in a graph based on their embeddings.

    For example, say you have a movie dataset that was scraped off the web with possible duplicate movies.
    The movies in this case are the entities.
    Therefore, you would use the `"e"` mode to find all the movies that could be duplicates of each other.

    Duplicates are defined as points whose distance in the embedding space is smaller than or equal to
    some given threshold (called the tolerance).

    The tolerance can be defined a priori or be found via an optimisation procedure given
    an expected fraction of duplicates. The optimisation applies a bisection root-finding routine
    to find the tolerance that gets closest to the expected fraction. The search interval goes from 0 up to
    the largest distance (under the chosen ``metric``) between any point and its nearest other point;
    at that distance every point has at least one duplicate, so the interval always brackets a solution
    for an expected fraction between 0 and 1.

    Distance is defined by the chosen metric, which by default is the Euclidean distance (L2 norm).

    As the distances are calculated on the embedding space,
    the embeddings must be meaningful for this routine to work properly.
    Therefore, it is suggested to evaluate the embeddings first using a metric such as MRR
    before considering applying this method.

    Parameters
    ----------

    X : ndarray, shape (n, 3) or (n)
        The input to be clustered.
        `X` can either be the triples of a knowledge graph, its entities, or its relations.
        The argument ``mode`` defines whether X is supposed to be an array of triples
        or an array of either entities or relations.
    model : EmbeddingModel
        The fitted model that will be used to generate the embeddings.
        This model must have been fully trained already, be it directly with ``fit()``
        or from a helper function such as :meth:`ampligraph.evaluation.select_best_model_ranking`.
    mode: str
        Specifies among which type of entities to look for duplicates.

        Choose from:

        - | `'e'` (default): the algorithm will find duplicates of the provided entities based on their embeddings.
        - | `'r'`: the algorithm will find duplicates of the provided relations based on their embeddings.
        - | `'t'` : the algorithm will find duplicates of the concatenation
            of the embeddings of the subject, predicate and object for each provided triple.

    metric: str
        A distance metric used to compare entity distance in the embedding space.
        `See options here <https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.NearestNeighbors.html>`_.
        The same metric is used when searching for the tolerance with ``tolerance='auto'``.
    tolerance: int, float or str
        Maximum distance (according to the chosen ``metric``) at which one point is considered a
        duplicate of another; points lying exactly at this distance are included.
        If `'auto'`, it will be determined automatically so that the fraction of points belonging to some
        duplicate group is as close as possible to ``expected_fraction_duplicates``.
        The `'auto'` option can be much slower than the regular one, as the duplicate-finding
        procedure will be repeated multiple times.
    expected_fraction_duplicates: float
        Expected fraction of duplicates to be found. It is used only when ``tolerance='auto'``.
        Must be between 0 and 1 (default: 0.1).
    verbose: bool
        Whether to log evaluation messages during optimisation when ``tolerance='auto'`` (default: `False`).

    Returns
    -------
    duplicates : set of frozensets
        Each entry in the duplicates set is a frozenset containing all entities that were found to be duplicates
        according to the metric and tolerance.
        Each frozenset will contain at least two entities.

    tolerance: float
        Tolerance used to find the duplicates (useful if the automatic tolerance option is selected).

    Raises
    ------
    ValueError
        If the model is not fitted, ``mode`` is not one of `'t'`, `'e'`, `'r'`, ``X`` does not have
        the shape required by ``mode``, or ``tolerance='auto'`` and ``expected_fraction_duplicates``
        is outside [0, 1].

    Example
    -------
    >>> import pandas as pd
    >>> import numpy as np
    >>> import re
    >>> from ampligraph.latent_features.models import ScoringBasedEmbeddingModel
    >>> # The IMDB dataset used here is part of the Movies5 dataset found on:
    >>> # The Magellan Data Repository (https://sites.google.com/site/anhaidgroup/projects/data)
    >>> import requests
    >>> url = 'http://pages.cs.wisc.edu/~anhai/data/784_data/movies5.tar.gz'
    >>> open('movies5.tar.gz', 'wb').write(requests.get(url).content)
    >>> import tarfile
    >>> tar = tarfile.open('movies5.tar.gz', "r:gz")
    >>> tar.extractall()
    >>> tar.close()
    >>>
    >>> # Reading tabular dataset of IMDB movies and filling the missing values
    >>> imdb = pd.read_csv("movies5/csv_files/imdb.csv")
    >>> imdb["directors"] = imdb["directors"].fillna("UnknownDirector")
    >>> imdb["actors"] = imdb["actors"].fillna("UnknownActor")
    >>> imdb["genre"] = imdb["genre"].fillna("UnknownGenre")
    >>> imdb["duration"] = imdb["duration"].fillna("0")
    >>>
    >>> # Creating knowledge graph triples from tabular dataset
    >>> imdb_triples = []
    >>>
    >>> for _, row in imdb.iterrows():
    >>>     movie_id = "ID" + str(row["id"])
    >>>     directors = row["directors"].split(",")
    >>>     actors = row["actors"].split(",")
    >>>     genres = row["genre"].split(",")
    >>>     duration = "Duration" + str(int(re.sub("\D", "", row["duration"])) // 30)
    >>>
    >>>     directors_triples = [(movie_id, "hasDirector", d) for d in directors]
    >>>     actors_triples = [(movie_id, "hasActor", a) for a in actors]
    >>>     genres_triples = [(movie_id, "hasGenre", g) for g in genres]
    >>>     duration_triple = (movie_id, "hasDuration", duration)
    >>>
    >>>
    >>>     imdb_triples.extend(directors_triples)
    >>>     imdb_triples.extend(actors_triples)
    >>>     imdb_triples.extend(genres_triples)
    >>>     imdb_triples.append(duration_triple)
    >>>
    >>> # Training knowledge graph embedding with ComplEx model
    >>> from ampligraph.latent_features import ScoringBasedEmbeddingModel
    >>>
    >>> imdb_triples = np.array(imdb_triples)
    >>> model = ScoringBasedEmbeddingModel(eta=5,
    >>>                                    k=300,
    >>>                                    scoring_type='ComplEx')
    >>> model.compile(optimizer='adam', loss='multiclass_nll')
    >>> model.fit(imdb_triples,
    >>>           batch_size=10000,
    >>>           epochs=10)
    >>>
    >>> # Finding duplicates movies (entities)
    >>> from ampligraph.discovery import find_duplicates
    >>>
    >>> entities = np.unique(imdb_triples[:, 0])
    >>> dups, _ = find_duplicates(entities, model, mode='e', tolerance=0.45)
    >>> id_list = []
    >>> for data in dups:
    >>>     for i in data:
    >>>         id_list.append(int(i[2:]))
    >>> print(imdb.iloc[id_list[:6]][['movie_name', 'year']])
    Epoch 1/10
    7/7 [==============================] - 1s 122ms/step - loss: 15612.8799
    Epoch 2/10
    7/7 [==============================] - 0s 20ms/step - loss: 15610.5010
    Epoch 3/10
    7/7 [==============================] - 0s 19ms/step - loss: 15607.7412
    Epoch 4/10
    7/7 [==============================] - 0s 19ms/step - loss: 15604.0674
    Epoch 5/10
    7/7 [==============================] - 0s 20ms/step - loss: 15598.9365
    Epoch 6/10
    7/7 [==============================] - 0s 19ms/step - loss: 15591.7188
    Epoch 7/10
    7/7 [==============================] - 0s 19ms/step - loss: 15581.6055
    Epoch 8/10
    7/7 [==============================] - 0s 20ms/step - loss: 15567.6807
    Epoch 9/10
    7/7 [==============================] - 0s 20ms/step - loss: 15548.8184
    Epoch 10/10
    7/7 [==============================] - 0s 21ms/step - loss: 15523.8721
               movie_name  year
    5198    Duel to Death  1983
    5199    Duel to Death  1983
    2649   The Eliminator  2004
    2650   The Eliminator  2004
    3967  Lipstick Camera  1994
    3968  Lipstick Camera  1994
    """

    # Backward-compatible wrappers keep the actual model in ``.model``.
    if model.is_backward:
        model = model.model

    if not model.is_fitted:
        msg = "Model has not been fitted."
        logger.error(msg)
        raise ValueError(msg)

    modes = ("t", "e", "r")
    if mode not in modes:
        msg = "Argument `mode` must be one of the following: {}.".format(
            ", ".join(modes)
        )
        logger.error(msg)
        raise ValueError(msg)

    if mode == "t" and (len(X.shape) != 2 or X.shape[1] != 3):
        msg = "For 't' mode the input X must be a matrix with three columns."
        logger.error(msg)
        raise ValueError(msg)

    if mode in ("e", "r") and len(X.shape) != 1:
        msg = "For 'e' or 'r' mode the input X must be an array."
        logger.error(msg)
        raise ValueError(msg)

    auto_tolerance = isinstance(tolerance, str) and tolerance == "auto"
    # Outside [0, 1] the bisection bracket below can never contain a sign change.
    if auto_tolerance and not 0.0 <= expected_fraction_duplicates <= 1.0:
        msg = "Argument `expected_fraction_duplicates` must be between 0 and 1."
        logger.error(msg)
        raise ValueError(msg)

    if mode == "t":
        # A triple is represented by the concatenation of its subject,
        # predicate and object embeddings.
        s = model.get_embeddings(X[:, 0], embedding_type="e")
        p = model.get_embeddings(X[:, 1], embedding_type="r")
        o = model.get_embeddings(X[:, 2], embedding_type="e")
        emb = np.hstack((s, p, o))
    else:
        emb = model.get_embeddings(X, embedding_type=mode)

    # The neighbour index does not depend on the radius, so it is built once
    # and reused by every evaluation performed during the bisection.
    nn = NearestNeighbors(metric=metric)
    nn.fit(emb)

    def get_dups(tol):
        """
        Given tolerance, finds duplicate entities in a graph based on their embeddings.

        Parameters
        ----------
        tol: float
            Maximum distance (depending on the chosen metric) at which one entity is considered a
            duplicate of another.

        Returns
        -------
        duplicates : set of frozensets
            Each entry in the duplicates set is a frozenset containing all entities that were found to be duplicates
            according to the metric and tolerance.
            Each frozenset will contain at least two entities.

        """
        # ``emb`` is passed explicitly so that each point is returned as its own
        # neighbour; a row with more than one index therefore has a duplicate.
        neighbors = nn.radius_neighbors(emb, radius=tol)[1]
        groups = [row for row in neighbors if len(row) > 1]
        if mode == "t":
            # Triples are rows of X; convert them to tuples so they are hashable.
            return {frozenset(tuple(X[idx]) for idx in row) for row in groups}
        return {frozenset(X[idx] for idx in row) for row in groups}

    def opt(tol, info):
        """
        Auxiliary function for the optimization procedure to find the tolerance that corresponds to the expected
        number of duplicates.

        ``info`` is a mutable dict holding the evaluation counter used for verbose logging.

        Returns the difference between actual and expected fraction of duplicates.
        """
        duplicates = get_dups(tol)
        fraction_duplicates = len(set().union(*duplicates)) / len(emb)
        if verbose:
            info["Nfeval"] += 1
            logger.info(
                "Eval {}: tol: {}, duplicate fraction: {}".format(
                    info["Nfeval"], tol, fraction_duplicates
                )
            )
        return fraction_duplicates - expected_fraction_duplicates

    if auto_tolerance:
        # Upper end of the bracket: the largest distance from any point to its
        # nearest other point, measured with the same ``metric`` as the radius
        # queries. At that radius every point has at least one duplicate, so the
        # duplicate fraction reaches its maximum (1 when the items of X are
        # distinct). Unlike a Euclidean n x n distance matrix this respects
        # ``metric`` and does not need O(n^2) memory.
        if len(emb) > 1:
            nn_dist = nn.kneighbors(emb, n_neighbors=2)[0]
            max_distance = float(np.max(nn_dist[:, 1]))
        else:
            max_distance = 0.0
        # Tiny margin so floating-point differences between the k-NN and the
        # radius queries cannot leave the farthest neighbour just outside.
        upper = max_distance * (1.0 + 1e-6) + 1e-9
        tolerance = optimize.bisect(
            opt,
            0.0,
            upper,
            xtol=1e-3,
            maxiter=50,
            args=({"Nfeval": 0},),
        )

    return get_dups(tolerance), tolerance


def query_topn(
    model,
    top_n=10,
    head=None,
    relation=None,
    tail=None,
    ents_to_consider=None,
    rels_to_consider=None,
):
    """Queries the model with two elements of a triple and returns the top_n results of
    all possible completions ordered by score predicted by the model.

    For example, given a `<subject, predicate>` pair in the arguments, the model will score
    all possible triples `<subject, predicate, ?>`, filling in the missing element with known
    entities, and return the top_n triples ordered by score. If given a `<subject, object>`
    pair it will fill in the missing element with known relations.

    .. note::
        This function does not filter out true statements - triples returned can include those
        the model was trained on.

    Parameters
    ----------
    model : EmbeddingModel
        The trained model that will be used to score triple completions.
    top_n : int
        The number of completed triples to return.
    head : str
        An entity string to query.
    relation : str
        A relation string to query.
    tail : str
        A tail entity string to query.
    ents_to_consider: array-like
        List or numpy array of entities to use for triple completions. If `None` (or empty), will generate
        completions using all distinct entities (Default: `None`).
    rels_to_consider: array-like
        List or numpy array of relations to use for triple completions. If `None` (or empty), will generate
        completions using all distinct relations (default: `None`).

    Returns
    -------
    X : ndarray of shape (n, 3)
        A list of triples ordered by descending score, with ``n = min(top_n, number of candidates)``.
    S : ndarray, shape (n)
       A list of scores, aligned with ``X``.

    Raises
    ------
    ValueError
        If the model is not fitted; if not exactly one of ``head``, ``relation`` and ``tail`` is None;
        if a given element was not seen by the model; or if ``ents_to_consider`` / ``rels_to_consider``
        have the wrong type, contain unseen elements or conflict with the query.

    Example
    -------

    >>> import requests
    >>> from ampligraph.datasets import load_from_csv
    >>> from ampligraph.discovery import discover_facts
    >>> from ampligraph.discovery import query_topn
    >>> from ampligraph.latent_features import ScoringBasedEmbeddingModel
    >>> # Game of Thrones relations dataset
    >>> url = 'https://ampligraph.s3-eu-west-1.amazonaws.com/datasets/GoT.csv'
    >>> open('GoT.csv', 'wb').write(requests.get(url).content)
    >>> X = load_from_csv('.', 'GoT.csv', sep=',')
    >>>
    >>> model = ScoringBasedEmbeddingModel(eta=5,
    >>>                                    k=150,
    >>>                                    scoring_type='TransE')
    >>> model.compile(optimizer='adagrad', loss='pairwise')
    >>> model.fit(X,
    >>>           batch_size=100,
    >>>           epochs=20,
    >>>           verbose=False)
    >>>
    >>> query_topn(model, top_n=5,
    >>>            head='Eddard Stark', relation='ALLIED_WITH', tail=None,
    >>>            ents_to_consider=None, rels_to_consider=None)
    >>>
    (array([['Eddard Stark', 'ALLIED_WITH', 'Smithyton'],
            ['Eddard Stark', 'ALLIED_WITH', 'Eden Risley'],
            ['Eddard Stark', 'ALLIED_WITH', 'House Westbrook'],
            ['Eddard Stark', 'ALLIED_WITH', 'House Leygood'],
            ['Eddard Stark', 'ALLIED_WITH', 'House Bridges']], dtype='<U44'),
     array([9.000417 , 5.272001 , 5.1876183, 5.121145 , 5.0564814],
           dtype=float32))

    """

    # Backward-compatible wrappers keep the actual model in ``.model``.
    if model.is_backward:
        model = model.model

    if not model.is_fitted:
        msg = "Model is not fitted."
        logger.error(msg)
        raise ValueError(msg)

    if sum(x is None for x in (head, relation, tail)) != 1:
        msg = "Exactly one of `head`, `relation` or `tail` arguments must be None."
        logger.error(msg)
        raise ValueError(msg)

    # Fetch the vocabularies once: the lists keep the backend's ordering for
    # building completions, the sets give O(1) membership tests.
    all_entities = list(model.data_indexer.backend.get_all_entities())
    all_relations = list(model.data_indexer.backend.get_all_relations())
    known_entities = set(all_entities)
    known_relations = set(all_relations)

    if head is not None and head not in known_entities:
        msg = "Head entity `{}` not seen by model".format(head)
        logger.error(msg)
        raise ValueError(msg)

    if relation is not None and relation not in known_relations:
        msg = "Relation `{}` not seen by model".format(relation)
        logger.error(msg)
        raise ValueError(msg)

    if tail is not None and tail not in known_entities:
        msg = "Tail entity `{}` not seen by model".format(tail)
        logger.error(msg)
        raise ValueError(msg)

    if ents_to_consider is not None:
        if head is not None and tail is not None:
            msg = "Cannot specify `ents_to_consider` and both `head` and `tail` arguments."
            logger.error(msg)
            raise ValueError(msg)
        if not isinstance(ents_to_consider, (list, np.ndarray)):
            msg = "`ents_to_consider` must be a list or numpy array."
            logger.error(msg)
            raise ValueError(msg)
        if not all(x in known_entities for x in ents_to_consider):
            msg = "Entities in `ents_to_consider` have not been seen by the model."
            logger.error(msg)
            raise ValueError(msg)
        if len(ents_to_consider) < top_n:
            msg = "`ents_to_consider` contains less than top_n values, return set will be truncated."
            logger.warning(msg)

    if rels_to_consider is not None:
        if relation is not None:
            msg = "Cannot specify both `rels_to_consider` and `relation` arguments."
            logger.error(msg)
            raise ValueError(msg)
        if not isinstance(rels_to_consider, (list, np.ndarray)):
            msg = "`rels_to_consider` must be a list or numpy array."
            logger.error(msg)
            raise ValueError(msg)
        if not all(x in known_relations for x in rels_to_consider):
            msg = "Relations in `rels_to_consider` have not been seen by the model."
            logger.error(msg)
            raise ValueError(msg)
        if len(rels_to_consider) < top_n:
            msg = "`rels_to_consider` contains less than top_n values, return set will be truncated."
            logger.warning(msg)

    # Complete triples from the candidate lists. An explicit ``is None``/``len``
    # test is used because the truth value of a numpy array is ambiguous; an
    # empty candidate list falls back to the full vocabulary.
    if relation is None:
        if rels_to_consider is None or len(rels_to_consider) == 0:
            rels = all_relations
        else:
            rels = rels_to_consider
        triples = np.array([[head, x, tail] for x in rels])
    else:
        if ents_to_consider is None or len(ents_to_consider) == 0:
            ents = all_entities
        else:
            ents = ents_to_consider
        if head is not None:
            triples = np.array([[head, relation, x] for x in ents])
        else:
            triples = np.array([[x, relation, tail] for x in ents])

    # Score the completed triples; flatten so a (n, 1) prediction behaves like (n,).
    scores = np.asarray(model.predict(triples)).reshape(-1)

    # Sort by descending score and keep the top_n results. No squeeze: a single
    # result must still be returned as shape (1, 3) / (1,).
    topn_idx = np.argsort(scores)[::-1][:top_n]
    scores_out = scores[topn_idx]
    triples_out = np.copy(triples[topn_idx, :])

    return triples_out, scores_out


def find_nearest_neighbours(kge_model, entities, n_neighbors=10, entities_subset=None, metric="euclidean"):
    """Find, for each query entity, its closest entities in the embedding space.

    A scikit-learn ``NearestNeighbors`` index is fitted on the candidate embeddings and queried
    with the embeddings of ``entities``. Candidates are either ``entities_subset`` or, when it is
    not given, every entity known to the model.

    Parameters
    ----------
    kge_model: ampligraph.latent_features.EmbeddingModel
        Trained kge model
    entities: list or np.array
        Entities whose neighbours are searched for.
    n_neighbors: int
        Number of neighbours returned per entity; must be smaller than the number of candidates.
    entities_subset: list or np.array
        Candidate entities among which neighbours are searched.
        If `None`, all the entities of the model are used as candidates.
    metric: string or callable
        Distance metric forwarded to ``NearestNeighbors``
        (see the scikit-learn documentation for the accepted values).

    Returns
    -------
    neighbors: np.array of size (len(entities), n_neighbors)
        Row i holds the neighbours of ``entities[i]``, closest first.
    distance: np.array of size (len(entities), n_neighbors)
        Row i holds the distances to the neighbours in the same row of ``neighbors``.

    Raises
    ------
    AssertionError
        If the model is not fitted, if ``entities`` / ``entities_subset`` are not a list or
        numpy array, or if ``n_neighbors`` is not smaller than the number of candidates.

    Examples
    --------
    >>> model = DistMult(batches_count=2, seed=555, epochs=1, k=10,
    >>>                  loss='pairwise', loss_params={'margin': 5},
    >>>                  optimizer='adagrad', optimizer_params={'lr': 0.1})
    >>> X = np.array([['a', 'y', 'b'],
    >>>               ['b', 'y', 'a'],
    >>>               ['e', 'y', 'c'],
    >>>               ['c', 'z', 'a'],
    >>>               ['a', 'z', 'd'],
    >>>               ['f', 'z', 'g'],
    >>>               ['c', 'z', 'g']])
    >>> model.fit(X)
    >>> neighbors, dist = find_nearest_neighbours(model,
    >>>                                           entities=['b'],
    >>>                                           n_neighbors=3,
    >>>                                           entities_subset=['a', 'c', 'd', 'e', 'f'])
    >>> print(neighbors, dist)
    [['e' 'd' 'c']] [[0.97474706 0.979108   1.2323136 ]]
    """
    assert kge_model.is_fitted, "KGE model is not fit!"
    assert isinstance(entities, (list, np.ndarray)), \
        "Invalid type for entities! Must be a list or np.array"

    if entities_subset is None:
        # NOTE(review): this path reads ``trained_model_params[0]`` and ``ent_to_idx``,
        # whereas the other functions in this module unwrap ``is_backward`` models and use
        # ``data_indexer`` / ``get_embeddings``; confirm these attributes exist on the
        # model types passed here, and that row i of the matrix matches the i-th key.
        candidate_emb = kge_model.trained_model_params[0]
        candidates = list(kge_model.ent_to_idx.keys())
    else:
        assert isinstance(entities_subset, (list, np.ndarray)), \
            "Invalid type for entities_subset! Must be a list or np.array"
        candidate_emb = kge_model.get_embeddings(entities_subset)
        candidates = entities_subset

    assert n_neighbors < len(candidates), 'n_neighbors must be less than the number of entities being fit!'

    knn = NearestNeighbors(n_neighbors=n_neighbors, metric=metric)
    knn.fit(candidate_emb)
    distances, indices = knn.kneighbors(kge_model.get_embeddings(entities))

    # Map the positional indices returned by kneighbors back to candidate names.
    neighbour_names = [[candidates[pos] for pos in row] for row in indices]

    return np.array(neighbour_names), np.array(distances)
