# Create new versions of the datasets reduced to only their largest connected component

from pathlib import Path
import scipy.sparse
import numpy as np
import torch
import pickle
import pandas as pd

data_dir = Path(__file__).parent


def _largest_component_nodes(heads, tails, n_entities):
    """Return the sorted entity ids forming the largest connected component.

    Args:
        heads: 1-D integer array with the head entity id of every triple.
        tails: 1-D integer array with the tail entity id of every triple.
        n_entities: Total number of entities (side of the adjacency matrix).

    Returns:
        1-D array of entity ids, in increasing order, that belong to the
        component containing the most entities (first one wins on ties).
    """
    adjacency = scipy.sparse.csr_matrix(
        (np.ones(len(heads), dtype="float32"), (heads, tails)),
        shape=(n_entities, n_entities),
    )
    # directed=False already lets the traversal follow every edge in both
    # directions, so the matrix does not need to be symmetrised first.
    _, labels = scipy.sparse.csgraph.connected_components(
        adjacency, directed=False, return_labels=True
    )
    return np.flatnonzero(labels == np.bincount(labels).argmax())


def main(
    datasets=(
        "mini_yago3",
        "yago3",
        "yago4",
        "yago4.5",
        "full_freebase",
        "wikikg90mv2",
    )
):
    """Write a largest-connected-component version of each dataset.

    For every name in ``datasets`` this reads ``<name>/triplets.npy`` and
    ``<name>/metadata.pkl`` from ``data_dir``, keeps only the triples whose
    entities lie in the largest connected component of the (undirected)
    entity graph, renumbers the surviving entities and relations with
    contiguous ids starting at 0 (preserving their relative order), and
    writes the result to ``<name>_lcc/triplets.npy`` and
    ``<name>_lcc/metadata.pkl``. The output directory is created if needed.

    Args:
        datasets: Iterable of dataset directory names located in ``data_dir``.
    """
    for data in datasets:
        print(f"----- {data} -----")
        src_dir = data_dir / data
        out_dir = data_dir / f"{data}_lcc"

        triples = np.load(src_dir / "triplets.npy")
        df = pd.DataFrame(triples).astype(np.int32)
        # NOTE: pickle executes arbitrary code on load; only use this on
        # metadata files produced by this project.
        with open(src_dir / "metadata.pkl", "rb") as f:
            metadata = pickle.load(f)
        n_entities = metadata["n_entities"]

        # Get largest connected component
        heads = df["head"].to_numpy()
        tails = df["tail"].to_numpy()
        node_list = _largest_component_nodes(heads, tails, n_entities)

        # Lookup table old entity id -> new entity id (-1 = entity dropped).
        # A flat array is far cheaper than a Python dict on large graphs.
        new_entity_id = np.full(n_entities, -1, dtype=np.int64)
        new_entity_id[node_list] = np.arange(node_list.size)

        # Extract triples; copy so the column assignments below act on an
        # independent frame rather than on a slice of ``df``.
        keep = (new_entity_id[heads] >= 0) & (new_entity_id[tails] >= 0)
        new_df = df.loc[keep].copy()

        # Reindex entities
        new_df["head"] = new_entity_id[heads[keep]]
        new_df["tail"] = new_entity_id[tails[keep]]

        # Reindex relations: rel_list holds the sorted surviving old ids and
        # the inverse indices are exactly their new contiguous ids.
        rel_list, new_rel_id = np.unique(
            new_df["relation"].to_numpy(), return_inverse=True
        )
        new_df["relation"] = new_rel_id

        # Save triples (uint16 relations unless there are too many to fit)
        fits_uint16 = rel_list.size <= np.iinfo(np.uint16).max + 1
        dtypes = {
            "head": np.uint32,
            "relation": np.uint16 if fits_uint16 else np.uint32,
            "tail": np.uint32,
        }
        new_triples = new_df.to_records(index=False, column_dtypes=dtypes)
        out_dir.mkdir(parents=True, exist_ok=True)
        np.save(out_dir / "triplets.npy", new_triples)

        # Make new entity_to_idx dict
        old_id_to_entity = {v: k for k, v in metadata["entity_to_idx"].items()}
        new_entity_to_idx = {
            old_id_to_entity[old_id]: new_id
            for new_id, old_id in enumerate(node_list.tolist())
        }

        # Make new relation_to_idx dict
        old_id_to_rel = {v: k for k, v in metadata["rel_to_idx"].items()}
        new_rel_to_idx = {
            old_id_to_rel[old_id]: new_id
            for new_id, old_id in enumerate(rel_list.tolist())
        }

        # Save metadata
        new_metadata = {
            "entity_to_idx": new_entity_to_idx,
            "rel_to_idx": new_rel_to_idx,
            "n_entities": int(node_list.size),
            "n_relations": int(rel_list.size),
            "cat_attr": list(new_rel_to_idx.keys()),
            "n_cat_attr": len(new_rel_to_idx),
            "num_attr": [],
            "n_num_attr": 0,
        }
        with open(out_dir / "metadata.pkl", "wb") as f:
            pickle.dump(new_metadata, f)


if __name__ == "__main__":
    # When executed as a script, only this dataset is processed; call
    # main() from Python to handle the full default list.
    selected = ["wikikg90mv2"]
    main(datasets=selected)
