import numpy as np
import torch
import pandas as pd
import scipy.sparse
from pykeen.triples import TriplesFactory

from SEPAL.knowledge_graph import KnowledgeGraph


def extract_subgraph(ctrl, graph):
    """Extract the core subgraph with the strategy named by ``ctrl.core_selection``.

    Args:
        ctrl: Control object; its ``core_selection`` attribute picks the
            strategy (``"degree"`` or ``"relation_based"``).
        graph: Graph to extract from; forwarded unchanged to the selected
            extraction routine.

    Returns:
        Whatever the selected extraction routine returns.

    Raises:
        ValueError: If ``ctrl.core_selection`` is not a supported strategy.
    """
    strategy = ctrl.core_selection
    if strategy == "degree":
        return degree_extraction(ctrl, graph)
    if strategy == "relation_based":
        return relation_extraction(ctrl, graph)
    raise ValueError(
        "core selection strategy must be in {'degree', 'relation_based'}"
    )


def relation_extraction(ctrl, graph):
    """
    Select the entities to create the core subgraph by:
    1. Sampling among the edges the top k% of each relation type (based on sum of degrees of head and tail).
    2. Getting the involved entities.
    3. Adding all the edges linking those entities (even if they were not in the top k%).
    4. Keeping only the largest connected component.

    Args:
        ctrl: Control object; ``ctrl.core_prop`` is the proportion of edges
            kept for each relation type.
        graph: Source graph. Reads ``graph.get_mapped_triples()`` (rows of
            head, relation, tail) and ``graph.degrees`` (indexed by entity id).

    Returns:
        The subgraph returned by ``extract_subgraph_from_node_list`` for the
        selected entities.
    """
    ## Select the core subgraph
    # The proportion of edges to keep
    prop = ctrl.core_prop

    # 1. Rank every edge by the summed degree of its two endpoints.
    edges = graph.get_mapped_triples()
    df = pd.DataFrame(edges, columns=["head", "rel", "tail"])
    df["sum_degree"] = df["head"].apply(lambda x: graph.degrees[x]) + df["tail"].apply(
        lambda x: graph.degrees[x]
    )
    df.sort_values(by="sum_degree", ascending=False, inplace=True)

    # 2. Getting the involved entities.
    # A single grouped pass replaces one full-frame boolean scan per relation.
    # groupby visits relation ids in ascending order and keeps the sorted row
    # order inside each group; relation types without edges simply do not appear.
    node_list = []
    for _, rel_df in df.groupby("rel"):
        # The "+ 1" guarantees at least one edge per relation type that has edges.
        top_edges = rel_df.iloc[: int(prop * len(rel_df)) + 1]
        node_list.extend(np.unique(top_edges[["head", "tail"]].values))
    node_list = list(set(node_list))

    # 3. Adding all the edges linking those entities (even if they were not in the top k%).
    node_list_tensor = torch.IntTensor(node_list)
    mask = torch.isin(edges[:, [0, 2]], node_list_tensor).all(axis=1)
    # Boolean-mask indexing copies, so the in-place relabeling below does not
    # touch the tensor returned by the graph.
    subgraph_edges = edges[mask]
    num_nodes = len(node_list)
    # Map each selected entity id to its position in node_list (0..num_nodes-1).
    reindex = {node: position for position, node in enumerate(node_list)}
    subgraph_edges[:, [0, 2]] = torch.tensor(
        np.vectorize(reindex.__getitem__)(subgraph_edges[:, [0, 2]])
    )

    # 4. Keeping only the largest connected component.
    subgraph_adjacency = scipy.sparse.csr_matrix(
        (np.ones(len(subgraph_edges)), (subgraph_edges[:, 0], subgraph_edges[:, 2])),
        shape=(num_nodes, num_nodes),
    )
    _, labels = scipy.sparse.csgraph.connected_components(
        subgraph_adjacency, directed=False, return_labels=True
    )
    largest_label = np.argmax(np.bincount(labels))
    selected_nodes = np.where(labels == largest_label)[0]
    # selected_nodes are positions in node_list, so index it directly to get
    # back the original entity ids.
    node_list = [node_list[position] for position in selected_nodes]

    ## Extract subgraph
    return extract_subgraph_from_node_list(ctrl, graph, node_list)


def degree_extraction(ctrl, graph):
    """
    Build the core subgraph from the highest-degree entities.

    The ``int(ctrl.core_prop * graph.num_entities) + 1`` entities with the
    largest ``graph.degrees`` are selected, the edges with both endpoints in
    that selection are gathered, and only the entities of the largest connected
    component (edges treated as undirected) are kept.

    Args:
        ctrl: Control object; ``ctrl.core_prop`` is the proportion of nodes to keep.
        graph: Source graph. Reads ``graph.get_mapped_triples()``,
            ``graph.degrees`` and ``graph.num_entities``.

    Returns:
        The subgraph returned by ``extract_subgraph_from_node_list`` for the
        retained entities.
    """
    print("Extracting core subgraph...")
    keep_ratio = ctrl.core_prop

    # Rank entities by degree and keep the leading slice.
    print("    ... getting high degree nodes")
    endpoints = graph.get_mapped_triples()[:, [0, 2]]
    ranked = torch.argsort(torch.IntTensor(graph.degrees), descending=True)
    top_nodes = ranked[: int(keep_ratio * graph.num_entities) + 1]
    node_list = top_nodes.tolist()

    # Keep edges whose head and tail are both selected, then relabel the
    # endpoints with their position in node_list.
    both_selected = torch.isin(endpoints, top_nodes).all(axis=1)
    inner_edges = endpoints[both_selected]
    position_of = {node: position for position, node in enumerate(node_list)}
    inner_edges = torch.tensor(np.vectorize(position_of.__getitem__)(inner_edges))
    size = len(node_list)

    # Label connected components and retain the most populated one.
    print("    ... keeping largest connected component")
    unit_weights = np.ones(len(inner_edges))
    adjacency = scipy.sparse.csr_matrix(
        (unit_weights, (inner_edges[:, 0], inner_edges[:, 1])),
        shape=(size, size),
    )
    _, component_of = scipy.sparse.csgraph.connected_components(
        adjacency, directed=False, return_labels=True
    )
    biggest = np.argmax(np.bincount(component_of))
    kept_positions = np.where(component_of == biggest)[0]
    # Positions index node_list, which yields the original entity ids.
    node_list = [node_list[position] for position in kept_positions]

    ## Extract subgraph
    return extract_subgraph_from_node_list(ctrl, graph, node_list)


def extract_subgraph_from_node_list(ctrl, graph, node_list):
    """Build the subgraph induced by ``node_list`` and record it on ``graph`` and ``ctrl``.

    Triples of ``graph.triples_factory`` whose head and tail are both in
    ``node_list`` are kept, and their entity ids are renumbered to the position
    of each entity in ``node_list``.

    Args:
        ctrl: Control object. ``core_subgraph_idx`` is set to a one-element
            list holding ``node_list`` and ``core_size`` to the subgraph's
            number of entities.
        graph: Source graph. Reads ``triples_factory`` and ``num_entities``;
            ``core_subgraph_idx`` is set to ``node_list``.
        node_list: Entity ids (in the numbering of ``graph``) to keep.
            NOTE(review): presumably free of duplicates -- a repeated id makes
            the label lookup below raise ``KeyError``; confirm with callers.

    Returns:
        A ``KnowledgeGraph`` wrapping the new triples factory.
    """
    ## Build triples_factory
    print("    ... building triples factory")
    source_tf = graph.triples_factory

    # Drop every triple that has an endpoint outside of node_list.
    kept_ids = torch.IntTensor(node_list)
    inside = torch.isin(source_tf.mapped_triples[:, [0, 2]], kept_ids).all(axis=1)
    # Boolean-mask indexing copies, so the relabeling below leaves the source
    # factory's triples untouched.
    mapped_triples = source_tf.mapped_triples[inside]

    # Renumber the remaining entities to 0..n-1, following node_list order.
    new_id_of = {old_id: new_id for new_id, old_id in enumerate(node_list)}
    relabel = np.vectorize(new_id_of.__getitem__)
    mapped_triples[:, [0, 2]] = torch.tensor(relabel(mapped_triples[:, [0, 2]]))

    # Build the new entity_to_id dictionary: original label -> new id.
    label_of_old_id = {
        old_id: label for label, old_id in source_tf.entity_to_id.items()
    }
    old_id_of = {new_id: old_id for old_id, new_id in new_id_of.items()}
    entity_to_id = {
        label_of_old_id[old_id_of[new_id]]: new_id
        for new_id in range(len(node_list))
    }

    # Create the triples factory, reusing the relation vocabulary and the
    # inverse-triples setting of the source factory.
    triples_factory = TriplesFactory(
        mapped_triples=mapped_triples,
        entity_to_id=entity_to_id,
        relation_to_id=source_tf.relation_to_id,
        create_inverse_triples=source_tf.create_inverse_triples,
    )

    ## Build knowledge_graph instance
    print("    ... building knowledge graph object")
    subgraph = KnowledgeGraph(triples_factory)
    graph.core_subgraph_idx = node_list

    ## Store the core node ids and the number of entities
    ctrl.core_subgraph_idx = [node_list]
    ctrl.core_size = subgraph.num_entities
    print(
        f"Core subgraph contains {subgraph.num_entities} entities ({subgraph.num_entities/graph.num_entities:.1%} of total graph)"
    )

    return subgraph
