"""A script to sample smaller KGs from Wikidata5M."""

from pathlib import Path
import numpy as np
import torch
import scipy.sparse
from pykeen.triples import TriplesFactory
from pykeen.datasets import Wikidata5M

# Output root for the processed KGs: "data/kg/processed" under the parent of
# the directory that contains this file.
KG_DIR = Path(__file__).parents[1].joinpath("data", "kg", "processed")


def extract_dataset(base_tf, thresholds=(3, 4, 6, 9)):
    """Sample degree-filtered sub-KGs from ``base_tf`` and save them to disk.

    For every value in ``thresholds`` the function:

    1. selects the entities whose degree is at least that value,
    2. keeps the largest connected component of the subgraph they induce,
    3. filters ``base_tf.mapped_triples`` down to triples whose head and tail
       both belong to that component,
    4. re-indexes the surviving entities to ``0..n-1`` and
    5. writes the resulting ``TriplesFactory`` to
       ``KG_DIR / f"wikidata5m_deg{threshold}"``.

    The degree of an entity is the number of distinct tails it is connected to
    as a head in the triples returned by
    ``base_tf._add_inverse_triples_if_necessary``.

    Args:
        base_tf: Triples factory to sample from. It must expose
            ``num_entities``, ``mapped_triples``, ``entity_to_id``,
            ``relation_to_id`` and ``create_inverse_triples``.
        thresholds: Iterable of minimum degrees; one sub-KG is produced per
            value. A threshold for which no triple survives is skipped.

    Returns:
        None. The sub-KGs are written to disk and a summary is printed.
    """
    # 1. Compute entity degrees
    num_entities = base_tf.num_entities
    # NOTE(review): this is a private pykeen helper; presumably it returns the
    # triples unchanged when the factory has no inverse triples, in which case
    # the degree below only counts outgoing neighbours -- confirm.
    all_triples = base_tf._add_inverse_triples_if_necessary(base_tf.mapped_triples)
    edges = all_triples[:, [0, 2]]
    adjacency = scipy.sparse.csr_matrix(
        (
            np.ones(edges.shape[0], dtype="float32"),
            (edges[:, 0].numpy(), edges[:, 1].numpy()),
        ),
        shape=(num_entities, num_entities),
        dtype="float32",
    )
    # Duplicate (head, tail) pairs are summed when the CSR matrix is built;
    # reset every stored entry to 1 so that each distinct pair counts once.
    adjacency.data = np.ones(adjacency.data.shape[0], dtype="float32")
    degrees = torch.from_numpy(adjacency @ np.ones(num_entities))

    # Loop-invariant lookups, built once instead of once per threshold.
    base_triples = base_tf.mapped_triples
    old_id_to_entity = {idx: label for label, idx in base_tf.entity_to_id.items()}

    # 2. Extract subgraphs
    for threshold in thresholds:
        # Select nodes with degree >= threshold (ids come out sorted ascending)
        node_ids = torch.where(degrees >= threshold)[0]

        # Extract the edges whose two endpoints are both selected
        edge_mask = torch.isin(edges, node_ids).all(dim=1)
        subgraph_edges = edges[edge_mask]
        if subgraph_edges.shape[0] == 0:
            # Nothing would survive the filtering below; skip instead of
            # failing on an empty graph.
            print(f"No triples left with degree >= {threshold}; skipping.")
            continue

        # Map original entity ids to positions in ``node_ids`` with a lookup
        # table (-1 marks entities that are not selected).
        num_nodes = node_ids.shape[0]
        local_index = torch.full((num_entities,), -1, dtype=torch.long)
        local_index[node_ids] = torch.arange(num_nodes)
        subgraph_edges = local_index[subgraph_edges].numpy()

        # Keep only the largest connected component of the subgraph
        subgraph_adjacency = scipy.sparse.csr_matrix(
            (
                np.ones(len(subgraph_edges)),
                (subgraph_edges[:, 0], subgraph_edges[:, 1]),
            ),
            shape=(num_nodes, num_nodes),
        )
        _, labels = scipy.sparse.csgraph.connected_components(
            subgraph_adjacency, directed=False, return_labels=True
        )
        largest_label = np.argmax(np.bincount(labels))
        kept_ids = node_ids[torch.from_numpy(labels == largest_label)]

        # Remove triples that touch entities outside of the component.
        # Boolean-mask indexing returns a copy, so ``base_tf`` is not modified
        # by the in-place re-indexing below.
        triple_mask = torch.isin(base_triples[:, [0, 2]], kept_ids).all(dim=1)
        mapped_triples = base_triples[triple_mask]

        # Reindex subgraph entities between 0 and n-1
        new_index = torch.full((num_entities,), -1, dtype=torch.long)
        new_index[kept_ids] = torch.arange(kept_ids.shape[0])
        mapped_triples[:, [0, 2]] = new_index[mapped_triples[:, [0, 2]]]

        # Build the new entity_to_id dictionary
        entity_to_id = {
            old_id_to_entity[old_id]: new_id
            for new_id, old_id in enumerate(kept_ids.tolist())
        }

        # Create triple factory object
        triples_factory = TriplesFactory(
            mapped_triples=mapped_triples,
            entity_to_id=entity_to_id,
            relation_to_id=base_tf.relation_to_id,
            create_inverse_triples=base_tf.create_inverse_triples,
        )

        # Save triples factory
        out_dir = KG_DIR / f"wikidata5m_deg{threshold}"
        out_dir.mkdir(parents=True, exist_ok=True)
        triples_factory.to_path_binary(path=out_dir)

        print(f"Saved subgraph with degree >= {threshold} to {out_dir}")
        print(f"Number of entities: {triples_factory.num_entities}")
        print(f"Number of relations: {triples_factory.num_relations}")
        print(f"Number of triples: {triples_factory.num_triples}")


if __name__ == "__main__":
    # Sample the sub-KGs from the training split of the Wikidata5M dataset.
    base_tf = Wikidata5M().training
    extract_dataset(base_tf=base_tf)
