from pathlib import Path

from torchbiggraph.config import ConfigFileLoader
from torchbiggraph.converters.importers import (
    convert_input_data,
    TSVEdgelistReader,
    PyArrowEdgelistReader,
)
from torchbiggraph.util import (
    set_logging_verbosity,
)


# Root directory of the raw input edge lists: the "data" folder that sits
# next to this script.
DATA_DIR = Path(__file__).absolute().parent.joinpath("data")


def convert_data(config, filenames, source="tsv"):
    """Convert raw edge-list files into PyTorch-BigGraph's input format.

    Args:
        config: Loaded PBG config; its ``entities``, ``relations``,
            ``entity_path``, ``edge_paths`` and ``dynamic_relations``
            attributes are forwarded to ``convert_input_data``.
        filenames: Edge-list file names, relative to ``DATA_DIR``.
        source: Input format. ``"tsv"`` reads columns 0/1/2 as
            head/relation/tail. ``"parquet"`` reads the ``head``,
            ``relation`` and ``tail`` columns.

    Raises:
        ValueError: If ``source`` is not one of the supported formats.
            Previously an unknown value silently converted nothing.
    """
    # Pick the reader first so an unsupported source fails fast, before any
    # paths are built or files touched.
    if source == "tsv":
        edgelist_reader = TSVEdgelistReader(lhs_col=0, rhs_col=2, rel_col=1)
    elif source == "parquet":
        edgelist_reader = PyArrowEdgelistReader(
            "head", "tail", "relation", weight_col=None
        )
    else:
        raise ValueError(
            f"Unsupported source {source!r}; expected 'tsv' or 'parquet'"
        )

    input_edge_paths = [DATA_DIR / name for name in filenames]

    # Both formats share the same conversion call; only the reader differs.
    # Min counts are pinned to 1 explicitly for both sources, as the
    # parquet path already did.
    convert_input_data(
        config.entities,
        config.relations,
        config.entity_path,
        config.edge_paths,
        input_edge_paths,
        edgelist_reader,
        entity_min_count=1,
        relation_type_min_count=1,
        dynamic_relations=config.dynamic_relations,
    )


# Datasets converted when the script is run without arguments. Uncomment an
# entry to include it in the default run.
DEFAULT_DATASETS = (
    # "mini_yago3_lcc",
    # "yago3_lcc",
    # "rel_core_yago4",
    # "rel_core_yago4.5",
    # "core_yago4",
    # "core_yago4.5",
    # "yago4_lcc",
    # "yago4.5_lcc",
    # "yago4_with_ontology",
    "yago4_with_full_ontology",
    # "yago4.5_with_ontology",
    # "yago4.5_with_full_ontology",
)

# Splits expected on disk as ``<data>/<data>_<split>.parquet`` under DATA_DIR.
SPLITS = ("training", "validation", "testing")


def main(datasets=DEFAULT_DATASETS):
    """Convert the parquet splits of each named dataset.

    For each dataset name, this loads ``configs/<name>_config.py`` from the
    directory containing this script and applies the config's logging
    verbosity. It then passes every split listed in ``SPLITS`` to
    ``convert_data`` in a single call.

    Args:
        datasets: Iterable of dataset names. Defaults to DEFAULT_DATASETS.
    """
    # Loop-invariant: the configs live next to this script.
    config_dir = Path(__file__).absolute().parent / "configs"

    for data in datasets:
        print(f"Converting {data}")
        # All splits (training, validation and testing) are converted together.
        filenames = [f"{data}/{data}_{subset}.parquet" for subset in SPLITS]

        # A fresh loader per dataset, as before; its internal state is not
        # visible from here, so it is not shared across configs.
        config = ConfigFileLoader().load_config(config_dir / f"{data}_config.py")
        set_logging_verbosity(config.verbose)

        convert_data(config, filenames, source="parquet")


if __name__ == "__main__":
    import sys

    # Dataset names may be given on the command line. With no arguments,
    # the default selection above is converted.
    main(sys.argv[1:] or DEFAULT_DATASETS)
