from pathlib import Path
from ogb.lsc import WikiKG90Mv2Dataset
import pandas as pd
import numpy as np
import pickle

# Directory holding this script: the dataset root, the mapping CSVs and both
# output files are all resolved relative to it.
data_dir = Path(__file__).resolve().parent

if __name__ == "__main__":
    ## Instantiate the dataset, rooted at this script's directory
    wikikg = WikiKG90Mv2Dataset(root=data_dir)

    ## Collect the triples
    # The "h,r->t" validation task keeps its "hr" pairs and its "t" values
    # apart; turn "t" into a single column and append it to "hr" so that the
    # validation rows can be stacked under the training rows.
    # NOTE(review): assumes "hr" is a 2-D array — confirm against the ogb docs.
    hrt_train = wikikg.train_hrt
    valid_hr_t = wikikg.valid_dict["h,r->t"]
    hrt_valid = np.concatenate(
        (valid_hr_t["hr"], valid_hr_t["t"].reshape(-1, 1)), axis=1
    )
    del valid_hr_t

    ## Gather training and validation triples in one dataframe
    triple_columns = ["head", "relation", "tail"]
    triples = pd.DataFrame(
        np.concatenate((hrt_train, hrt_valid), axis=0),
        columns=triple_columns,
    )  # 601 077 811 rows × 3 columns

    # Read the entity and relation mapping tables
    entity_csv = data_dir / "wikikg90mv2_mapping/entity.csv"
    entity_names = pd.read_csv(entity_csv, sep=",", header=0)
    # ^ 91 230 610 rows × 4 columns
    relation_csv = data_dir / "wikikg90mv2_mapping/relation.csv"
    relation_names = pd.read_csv(relation_csv, sep=",", header=0)
    # ^ 1 387 rows × 4 columns

    ## Make triples
    print("Making triples...")
    # Lookup from the "entity" column to the "idx" column
    print("... entities")
    n_entities = entity_names.shape[0]
    entity_keys = entity_names["entity"]
    entity_ids = entity_names["idx"]
    entity_to_idx = dict(zip(entity_keys, entity_ids))

    # Same lookup for relations, plus the distinct relation values
    print("... relations")
    relation_keys = relation_names["relation"]
    relations = relation_keys.unique()
    n_relations = relation_names.shape[0]
    rel_to_idx = dict(zip(relation_keys, relation_names["idx"]))

    # Put the rows in random order (no seed is set, so each run differs)
    triples = triples.sample(frac=1)

    ## Write the outputs
    print("Saving triples...")

    # Store the triples as a record array with compact unsigned integer
    # fields. Rebinding `triples` drops the dataframe before the save.
    record_dtypes = {
        "head": np.uint32,
        "relation": np.uint16,
        "tail": np.uint32,
    }
    triples = triples.to_records(index=False, column_dtypes=record_dtypes)
    np.save(f"{data_dir}/triplets.npy", triples)

    # Pickle the lookups and counts; every distinct relation is listed as a
    # categorical attribute and there are no numerical attributes.
    cat_attr = list(relations)
    metadata = {
        "entity_to_idx": entity_to_idx,
        "rel_to_idx": rel_to_idx,
        "n_entities": n_entities,
        "n_relations": n_relations,
        "cat_attr": cat_attr,
        "n_cat_attr": len(cat_attr),
        "num_attr": [],
        "n_num_attr": 0,
    }
    with open(f"{data_dir}/metadata.pkl", "wb") as metadata_file:
        pickle.dump(metadata, metadata_file)

