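# Description: Convert the pickled training/validation/testing triples of each
# selected knowledge graph into parquet files of labeled (head, relation, tail)
# triples, used as input for the PyTorch-BigGraph (PBG) baseline.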

import pandas as pd
import numpy as np
import pickle

from SEPAL import SEPAL_DIR


# Datasets to convert on the next run. Each name is a directory under
# SEPAL_DIR/datasets/knowledge_graphs. To enable a disabled dataset, move it
# out of the commented list below.
input_data = [
    "wikikg90mv2_lcc",
    # Disabled datasets:
    # "mini_yago3_lcc",
    # "yago3_lcc",
    # "rel_core_yago4",
    # "rel_core_yago4.5",
    # "core_yago4",
    # "core_yago4.5",
    # "yago4_lcc",
    # "yago4.5_lcc",
    # "yago4_with_ontology",
    # "yago4_with_full_ontology",
    # "yago4.5_with_ontology",
    # "yago4.5_with_full_ontology",
    # "full_freebase_lcc",
]

def _invert(mapping):
    """Return the inverse of a one-to-one ``label -> id`` mapping."""
    return {idx: label for label, idx in mapping.items()}


def _to_labeled_frame(tf):
    """Convert the id-encoded triples of ``tf`` into a DataFrame of labels.

    ``tf`` must expose ``mapped_triples`` (one row per triple, columns in
    head/relation/tail order) and the ``entity_to_id`` / ``relation_to_id``
    dictionaries mapping labels to integer ids. Presumably ``tf`` is a PyKEEN
    ``TriplesFactory`` -- TODO confirm against the pickled objects.

    Returns:
        A DataFrame with columns ``head``, ``relation`` and ``tail`` holding
        the labels that correspond to each id.

    Raises:
        ValueError: if an id in ``mapped_triples`` has no label. Without this
            check, ``Series.map`` would silently turn it into NaN.
    """
    # int64 rather than int32: narrowing would silently wrap ids >= 2**31.
    df = pd.DataFrame(
        tf.mapped_triples, columns=["head", "relation", "tail"]
    ).astype(np.int64)
    id_to_entity = _invert(tf.entity_to_id)
    df["head"] = df["head"].map(id_to_entity)
    df["relation"] = df["relation"].map(_invert(tf.relation_to_id))
    df["tail"] = df["tail"].map(id_to_entity)
    if df.isna().to_numpy().any():
        raise ValueError(
            "mapped_triples contains ids missing from the label mappings"
        )
    return df


def _convert_subset(tf_path, parquet_path):
    """Load one pickled triples factory and write its labeled triples to parquet.

    This is a separate function so that the (possibly very large) factory and
    DataFrame are released after each subset is written. They no longer stay
    alive while the next subset is loaded.
    """
    # NOTE: pickle can execute arbitrary code; only load trusted dataset files.
    with open(tf_path, "rb") as f:
        tf = pickle.load(f)
    _to_labeled_frame(tf).to_parquet(parquet_path)


def main():
    """Convert every dataset listed in ``input_data`` into PBG parquet inputs.

    For each dataset ``<data>`` and each subset (training, validation,
    testing), this reads
    ``SEPAL_DIR/datasets/knowledge_graphs/<data>/<subset>_tf.pkl`` and writes
    ``SEPAL_DIR/baselines/PBG/data/<data>/<data>_<subset>.parquet``.
    """
    for data in input_data:
        print(f"Processing {data}")
        output_dir = SEPAL_DIR / f"baselines/PBG/data/{data}"
        # exist_ok replaces the racy "check exists() then create" pattern.
        output_dir.mkdir(parents=True, exist_ok=True)

        for subset in ("training", "validation", "testing"):
            tf_path = (
                SEPAL_DIR / f"datasets/knowledge_graphs/{data}/{subset}_tf.pkl"
            )
            _convert_subset(tf_path, output_dir / f"{data}_{subset}.parquet")

if __name__ == "__main__":
    # Run the conversion only when executed as a script, not when imported.
    main()
