from pathlib import Path
import pandas as pd
import numpy as np
import pickle


def _read_del(path):
    """Read a headerless, tab-separated ``.del`` file into a DataFrame."""
    return pd.read_csv(path, sep="\t", header=None)


def _check_fits(triples, dtypes):
    """Raise ValueError if a triple column cannot be stored in its target dtype.

    The cast performed by ``DataFrame.to_records`` does not trap out-of-range
    values, so ids that do not fit would otherwise be written to disk silently
    corrupted.
    """
    for column, dtype in dtypes.items():
        values = triples[column]
        if values.dtype.kind not in "iu":
            raise ValueError(
                f"column {column!r} has dtype {values.dtype}; expected integer ids"
            )
        bounds = np.iinfo(dtype)
        if len(values) and (
            values.min() < bounds.min or values.max() > bounds.max
        ):
            raise ValueError(
                f"column {column!r} has values outside the range of "
                f"{np.dtype(dtype).name} [{bounds.min}, {bounds.max}]"
            )


def main(data_dir=None, seed=None):
    """Convert the Freebase ``.del`` files into a shuffled triplet array and metadata.

    Reads ``freebase/{train,valid,test}.del`` (index triples) together with
    ``freebase/entity_ids.del`` and ``freebase/relation_ids.del`` from
    *data_dir*, then writes two files into *data_dir*:

    * ``triplets.npy`` -- every triple from the three splits, shuffled, stored
      as a record array with fields ``head`` (uint32), ``relation`` (uint16)
      and ``tail`` (uint32).
    * ``metadata.pkl`` -- a pickled dict holding the name -> index mappings,
      the entity/relation counts and the attribute lists.

    Args:
        data_dir: Directory containing the ``freebase`` folder; the outputs are
            written here as well. Defaults to the directory of this file.
        seed: Optional ``random_state`` for the shuffle. ``None`` (the default)
            gives a different order on every run.

    Raises:
        ValueError: If a triple column is not integer-typed or holds an index
            that does not fit in its on-disk dtype.
    """
    ## Load the data
    data_dir = Path(__file__).parent if data_dir is None else Path(data_dir)
    freebase_dir = data_dir / "freebase"

    # Load and concatenate the triples. The per-split frames are temporaries so
    # their memory can be reclaimed as soon as the concatenation is done.
    #   train.del: 304 727 650 rows × 3 columns
    #   valid.del:  16 929 318 rows × 3 columns
    #   test.del:   16 929 308 rows × 3 columns
    triples = pd.concat(
        [
            _read_del(freebase_dir / split)
            for split in ("train.del", "valid.del", "test.del")
        ],
        ignore_index=True,
    )  # 338 586 276 rows × 3 columns

    # Rename columns
    triples.columns = ["head", "relation", "tail"]

    # Load the entity ids (Freebase IDs) and relation names
    entity_ids = _read_del(
        freebase_dir / "entity_ids.del"
    )  # 86 054 151 rows × 2 columns
    relation_names = _read_del(
        freebase_dir / "relation_ids.del"
    )  # 14 824 rows × 2 columns

    ## Make triples
    print("Making triples...")
    # Build entity_to_idx (column 1 holds the name, column 0 the index)
    print("... entities")
    n_entities = len(entity_ids)
    entity_to_idx = dict(zip(entity_ids[1], entity_ids[0]))

    # Build rel_to_idx (same column layout as the entity file)
    print("... relations")
    relations = relation_names[1].unique()
    n_relations = len(relation_names)
    rel_to_idx = dict(zip(relation_names[1], relation_names[0]))

    # Shuffle triplets; a fixed seed makes the order reproducible
    triples = triples.sample(frac=1, random_state=seed)

    ## Save triples
    print("Saving triples...")

    # Save triples in compact unsigned dtypes, refusing to truncate indices
    dtypes = {"head": np.uint32, "relation": np.uint16, "tail": np.uint32}
    _check_fits(triples, dtypes)
    triples = triples.to_records(index=False, column_dtypes=dtypes)
    np.save(data_dir / "triplets.npy", triples)

    # Save metadata
    metadata = {
        "entity_to_idx": entity_to_idx,
        "rel_to_idx": rel_to_idx,
        "n_entities": n_entities,
        "n_relations": n_relations,
        "cat_attr": list(relations),
        "n_cat_attr": len(relations),
        "num_attr": [],
        "n_num_attr": 0,
    }
    with open(data_dir / "metadata.pkl", "wb") as f:
        pickle.dump(metadata, f)


# Run the conversion only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
