import torch
import pickle
import pandas as pd
import numpy as np
import scipy.sparse

from SEPAL.utils import create_graph
from SEPAL import SEPAL_DIR


def make_core_degree(
    data="yago4.5_lcc",
    save_dir=SEPAL_DIR / "datasets/knowledge_graphs/core_yago4.5",
    prop=0.03,
):
    """Build a "core" subgraph from the highest-degree entities of ``data``.

    The ``int(prop * graph.num_entities) + 1`` entities with the largest value
    in ``graph.degrees`` are selected. The subgraph induced by them (triples
    whose head and tail are both selected) is reduced to its largest
    connected component, with edge direction ignored. The surviving entities
    are then passed to ``make_triples_from_node_list``, which writes the
    triples and metadata to ``save_dir``.

    Args:
        data: dataset name forwarded to ``create_graph``.
        save_dir: output directory handed to ``make_triples_from_node_list``.
        prop: fraction of entities (by decreasing degree) to keep before the
            connected-component filtering.
    """
    # Load graph
    graph = create_graph(data, True)
    edges = graph.get_mapped_triples()

    ## Get core indices
    # Keep high degree nodes (assumes graph.degrees is indexed by entity id).
    n_keep = int(prop * graph.num_entities) + 1
    node_ids = torch.argsort(torch.IntTensor(graph.degrees), descending=True)[:n_keep]
    node_ids = node_ids.to(device=edges.device, dtype=edges.dtype)
    num_nodes = len(node_ids)

    # Triples whose head and tail are both high-degree nodes.
    mask = torch.isin(edges[:, [0, 2]], node_ids).all(dim=1)
    subgraph_edges = edges[mask]

    # Dense old-id -> position lookup: node_ids[i] becomes i. This replaces a
    # per-element dict lookup through np.vectorize, which is slow on large
    # graphs and raises on an empty edge set.
    lookup = torch.full(
        (int(node_ids.max()) + 1,), -1, dtype=edges.dtype, device=edges.device
    )
    lookup[node_ids] = torch.arange(num_nodes, dtype=edges.dtype, device=edges.device)
    heads = lookup[subgraph_edges[:, 0]].cpu().numpy()
    tails = lookup[subgraph_edges[:, 2]].cpu().numpy()

    # Keep only the largest connected component of the subgraph
    subgraph_adjacency = scipy.sparse.csr_matrix(
        (np.ones(len(heads)), (heads, tails)),
        shape=(num_nodes, num_nodes),
    )
    _, labels = scipy.sparse.csgraph.connected_components(
        subgraph_adjacency, directed=False, return_labels=True
    )
    in_largest = labels == np.argmax(np.bincount(labels))
    candidates = node_ids.tolist()
    node_list = [candidates[i] for i in np.flatnonzero(in_largest)]

    ## Make triples and metadata
    make_triples_from_node_list(graph, node_list, save_dir)
    return


def make_core_relation_based(
    data="yago4.5_lcc",
    save_dir=SEPAL_DIR / "datasets/knowledge_graphs/rel_core_yago4.5",
    prop=0.01,
):
    """Build a "core" subgraph by sampling the top triples of every relation.

    1. Each triple is scored by ``degree(head) + degree(tail)``. For each
       relation with ``n_rel`` triples, the ``int(prop * n_rel) + 1``
       best-scored triples are kept.
    2. The entities appearing in those triples are collected.
    3. Every triple linking two collected entities is taken, even if it was
       not among the top triples.
    4. Only the largest connected component (edge direction ignored) is kept.
       Its entities are passed to ``make_triples_from_node_list``, which
       writes the result to ``save_dir``.

    Args:
        data: dataset name forwarded to ``create_graph``.
        save_dir: output directory handed to ``make_triples_from_node_list``.
        prop: fraction of each relation's triples used for entity selection.
    """
    # Load graph
    graph = create_graph(data, True)
    edges = graph.get_mapped_triples()

    # 1. Score every triple by the sum of the degrees of its head and tail
    #    (assumes graph.degrees is array-like and indexed by entity id).
    degrees = np.asarray(graph.degrees)
    df = pd.DataFrame(edges.cpu().numpy(), columns=["head", "rel", "tail"])
    df["sum_degree"] = degrees[df["head"].to_numpy()] + degrees[df["tail"].to_numpy()]
    df.sort_values(by="sum_degree", ascending=False, inplace=True)

    # 2. Entities involved in the top triples of each relation. One groupby
    #    pass replaces a full-table filter per relation; since df is sorted,
    #    cumcount() is each triple's rank within its relation.
    by_rel = df.groupby("rel", sort=False)
    rank_in_rel = by_rel.cumcount()
    rel_size = by_rel["rel"].transform("size")
    top = df[rank_in_rel < (prop * rel_size).astype(int) + 1]
    # np.unique gives a sorted, deterministic node order.
    node_ids = torch.as_tensor(
        np.unique(top[["head", "tail"]].to_numpy()),
        dtype=edges.dtype,
        device=edges.device,
    )
    num_nodes = len(node_ids)

    # 3. Adding all the edges linking those entities (even if they were not in the top k%).
    mask = torch.isin(edges[:, [0, 2]], node_ids).all(dim=1)
    subgraph_edges = edges[mask]
    # Dense old-id -> position lookup: node_ids[i] becomes i (vectorized;
    # np.vectorize over a dict is slow and raises on an empty edge set).
    lookup = torch.full(
        (int(node_ids.max()) + 1,), -1, dtype=edges.dtype, device=edges.device
    )
    lookup[node_ids] = torch.arange(num_nodes, dtype=edges.dtype, device=edges.device)
    heads = lookup[subgraph_edges[:, 0]].cpu().numpy()
    tails = lookup[subgraph_edges[:, 2]].cpu().numpy()

    # 4. Keeping only the largest connected component.
    subgraph_adjacency = scipy.sparse.csr_matrix(
        (np.ones(len(heads)), (heads, tails)),
        shape=(num_nodes, num_nodes),
    )
    _, labels = scipy.sparse.csgraph.connected_components(
        subgraph_adjacency, directed=False, return_labels=True
    )
    in_largest = labels == np.argmax(np.bincount(labels))
    candidates = node_ids.tolist()
    node_list = [candidates[i] for i in np.flatnonzero(in_largest)]

    ## Make triples and metadata
    make_triples_from_node_list(graph, node_list, save_dir)
    return


def make_triples_from_node_list(graph, node_list, save_dir, random_state=None):
    """Save the subgraph of ``graph`` induced by ``node_list`` to ``save_dir``.

    Keeps every triple whose head and tail both belong to ``node_list``.
    Entities are renumbered so that ``node_list[i]`` becomes id ``i``, while
    relation ids are kept unchanged. The triples are shuffled and written to
    ``save_dir / "triplets.npy"`` as a record array with fields ``head``
    (uint32), ``relation`` (uint16) and ``tail`` (uint32). A metadata
    dictionary is pickled to ``save_dir / "metadata.pkl"``.

    Args:
        graph: object whose ``triples_factory`` exposes ``mapped_triples``
            (torch id-triple tensor, columns head/relation/tail),
            ``entity_to_id`` and ``relation_to_id``.
        node_list: distinct original entity ids to keep, in the order that
            defines the new ids.
        save_dir: output directory (``str`` or ``Path``), created if missing.
        random_state: optional seed forwarded to ``DataFrame.sample`` to make
            the shuffle reproducible. ``None`` keeps an unseeded shuffle.
    """
    from pathlib import Path

    save_dir = Path(save_dir)
    # np.save() and open() do not create missing directories.
    save_dir.mkdir(parents=True, exist_ok=True)

    ## Make triples
    edges = graph.triples_factory.mapped_triples
    node_ids = torch.as_tensor(
        np.asarray(node_list), dtype=edges.dtype, device=edges.device
    )
    triples_mask = torch.isin(edges[:, [0, 2]], node_ids).all(dim=1)
    subgraph_edges = edges[triples_mask]  # boolean indexing copies; edges is untouched
    # Dense old-id -> new-id lookup: node_ids[i] becomes i. This replaces a
    # per-element dict lookup through np.vectorize, which is slow on large
    # graphs and raises when no triple survives the mask.
    table_size = int(node_ids.max()) + 1 if node_ids.numel() else 0
    lookup = torch.full((table_size,), -1, dtype=edges.dtype, device=edges.device)
    lookup[node_ids] = torch.arange(
        len(node_ids), dtype=edges.dtype, device=edges.device
    )
    subgraph_edges[:, [0, 2]] = lookup[subgraph_edges[:, [0, 2]]]
    triples = pd.DataFrame(
        subgraph_edges.cpu().numpy(), columns=["head", "relation", "tail"]
    )

    # Shuffle triples
    triples = triples.sample(frac=1, random_state=random_state)

    # Save triples
    dtypes = {"head": np.uint32, "relation": np.uint16, "tail": np.uint32}
    triples = triples.to_records(index=False, column_dtypes=dtypes)
    np.save(save_dir / "triplets.npy", triples)

    ## Make metadata
    old_id_to_entity = {v: k for k, v in graph.triples_factory.entity_to_id.items()}
    entity_to_id = {
        old_id_to_entity[old_id]: new_id
        for new_id, old_id in enumerate(node_ids.tolist())
    }

    # Save metadata
    relation_to_id = graph.triples_factory.relation_to_id
    metadata = {
        "entity_to_idx": entity_to_id,
        "rel_to_idx": relation_to_id,
        "n_entities": len(node_ids),
        "n_relations": len(relation_to_id),
        "cat_attr": list(relation_to_id.keys()),
        "n_cat_attr": len(relation_to_id),
        "num_attr": [],
        "n_num_attr": 0,
    }
    with open(save_dir / "metadata.pkl", "wb") as f:
        pickle.dump(metadata, f)
    return


if __name__ == "__main__":
    # Each entry: (dataset name, output path under SEPAL_DIR, proportion of
    # each relation's top triples used to pick the core entities).
    relation_core_configs = (
        ("yago3_lcc", "datasets/knowledge_graphs/rel_core_yago3", 0.025),
        ("yago4_lcc", "datasets/knowledge_graphs/rel_core_yago4", 0.01),
        ("yago4.5_lcc", "datasets/knowledge_graphs/rel_core_yago4.5", 0.01),
    )
    for dataset_name, relative_out, proportion in relation_core_configs:
        make_core_relation_based(
            data=dataset_name,
            save_dir=SEPAL_DIR / relative_out,
            prop=proportion,
        )
