import pickle
import torch
import pandas as pd
from pykeen.triples import TriplesFactory

from SEPAL.dataloader import DataLoader
from SEPAL.stratify_datasets import reconnect_training_set
from SEPAL import SEPAL_DIR


def _build_id_map(path, label_to_id, dtype):
    """Return a Series mapping Freebase ids to ids of the LCC triples factory.

    Args:
        path: Two-column, tab-separated, header-less file of
            ``(freebase id, label)`` rows (``entity_ids.del`` or
            ``relation_ids.del``).
        label_to_id: Label -> id mapping of the LCC factory
            (``tf.entity_to_id`` or ``tf.relation_to_id``).
        dtype: Nullable integer dtype used for the intermediate LCC ids
            (``"Int32"`` for entities, ``"Int16"`` for relations).

    Returns:
        An integer Series indexed by Freebase id. Labels missing from
        ``label_to_id`` are dropped; for repeated Freebase ids the last row
        wins (the same result building a ``dict`` from the rows would give).
    """
    # keep_default_na=False: by default read_csv parses labels such as "NA",
    # "null" or "nan" as NaN, which would silently drop those entries.
    mapping = pd.read_csv(path, sep="\t", header=None, keep_default_na=False)
    mapping.columns = ["fb_id", "label"]
    mapping["lcc_id"] = mapping["label"].map(label_to_id).astype(dtype)

    id_map = mapping.set_index("fb_id")["lcc_id"].dropna().astype(int)
    # Passing a Series (rather than a dict) to Series.map avoids pandas
    # rebuilding a Series from the dict on every .map call. A Series mapper
    # needs a unique index, hence the last-wins de-duplication.
    return id_map[~id_map.index.duplicated(keep="last")]


def _load_split(path, entity_map, relation_map, reference):
    """Read one Freebase split and re-express it with the LCC factory's ids.

    Args:
        path: Tab-separated, header-less ``head relation tail`` file whose
            values are Freebase ids (the index values of ``entity_map`` /
            ``relation_map``).
        entity_map: Freebase entity id -> LCC entity id (from ``_build_id_map``).
        relation_map: Freebase relation id -> LCC relation id.
        reference: LCC TriplesFactory whose entity/relation vocabularies are
            reused by the returned factory.

    Returns:
        A TriplesFactory containing only the triples whose head, relation and
        tail all have an LCC id.
    """
    triples = pd.read_csv(path, sep="\t", header=None)
    triples.columns = ["head", "relation", "tail"]

    triples["head"] = triples["head"].map(entity_map).astype("Int32")
    triples["tail"] = triples["tail"].map(entity_map).astype("Int32")
    triples["relation"] = triples["relation"].map(relation_map).astype("Int16")

    # Cast to plain int64 while still in pandas: ``.values`` on nullable
    # Int32/Int16 columns yields an object-dtype array of Python ints, which
    # for ~3e8 rows is extremely slow and memory hungry.
    mapped = triples.dropna().astype("int64").to_numpy()
    del triples  # release the pandas frame before the tensor copy

    return TriplesFactory(
        mapped_triples=torch.tensor(mapped),
        entity_to_id=reference.entity_to_id,
        relation_to_id=reference.relation_to_id,
        create_inverse_triples=False,
    )


def main():
    """Build and pickle LCC-restricted train/test/validation TriplesFactories.

    Steps:
      1. Load the ``full_freebase_lcc`` TriplesFactory (``tf``) via DataLoader.
      2. Build Freebase-id -> ``tf``-id lookups from the ``full_freebase``
         dataset's ``entity_ids.del`` and ``relation_ids.del``.
      3. Translate the ``train``/``test``/``valid`` ``.del`` splits with those
         lookups, dropping triples that cannot be fully mapped.
      4. Pass the three factories through ``reconnect_training_set``.
      5. Pickle them to ``training_tf.pkl``, ``testing_tf.pkl`` and
         ``validation_tf.pkl`` inside the ``full_freebase_lcc`` directory.
    """
    # Load triples factory
    triples_dir = SEPAL_DIR / "datasets/knowledge_graphs/full_freebase_lcc"
    tf = DataLoader(triples_dir).get_triples_factory(create_inverse_triples=False)

    freebase_dir = SEPAL_DIR / "datasets/knowledge_graphs/full_freebase/freebase"

    # Freebase id -> tf id lookups
    # (entity_ids.del: 86 054 151 rows, relation_ids.del: 14 824 rows).
    entity_map = _build_id_map(
        freebase_dir / "entity_ids.del", tf.entity_to_id, "Int32"
    )
    relation_map = _build_id_map(
        freebase_dir / "relation_ids.del", tf.relation_to_id, "Int16"
    )

    # Splits are read and mapped one at a time so only one raw frame is alive.
    # Raw sizes: train 304 727 650, valid 16 929 318, test 16 929 308 rows.
    training = _load_split(freebase_dir / "train.del", entity_map, relation_map, tf)
    testing = _load_split(freebase_dir / "test.del", entity_map, relation_map, tf)
    validation = _load_split(freebase_dir / "valid.del", entity_map, relation_map, tf)
    del entity_map, relation_map

    # Reconnect training set. NOTE(review): exact semantics live in
    # SEPAL.stratify_datasets.reconnect_training_set — not visible here.
    training, testing, validation = reconnect_training_set(
        tf, training, testing, validation
    )

    # Save constructed triples factories
    outputs = (
        (training, "training_tf.pkl"),
        (testing, "testing_tf.pkl"),
        (validation, "validation_tf.pkl"),
    )
    for factory, filename in outputs:
        with open(triples_dir / filename, "wb") as f:
            pickle.dump(factory, f)


if __name__ == "__main__":
    # Run the split construction only when executed as a script, not on import.
    main()
