from pathlib import Path
import attr
import pandas as pd

from torchbiggraph.config import add_to_sys_path, ConfigFileLoader
from torchbiggraph.train import train
from torchbiggraph.util import (
    set_logging_verbosity,
    setup_logging,
    SubprocessInitializer,
)

from SEPAL.utils import measure_performance
from SEPAL import SEPAL_DIR
from SEPAL.downstream_evaluation import DATASETS_NAMES


# Absolute path of the "data" directory that sits next to this script.
DATA_DIR = Path(__file__).absolute().parent / "data"


def _build_train_config(config, subset):
    """Return a copy of ``config`` specialised for the given ``subset``.

    Args:
        config: The loaded PBG configuration. Its ``edge_paths`` must contain
            exactly three entries; the first one is used as the training
            edge path.
        subset: ``"train"`` to train only on the first edge path, or
            ``"all"`` to keep every edge path of ``config``.

    Returns:
        A new configuration (built with ``attr.evolve``) whose
        ``checkpoint_path`` has ``"/train"`` or ``"/all"`` appended.

    Raises:
        ValueError: If ``edge_paths`` does not have exactly three entries
            (from the unpacking), or if ``subset`` is not a known value.
    """
    # Unpacked unconditionally so a malformed ``edge_paths`` fails for both
    # subsets, exactly as before the refactoring.
    train_edge_path, _, _ = config.edge_paths
    if subset == "train":
        return attr.evolve(
            config,
            edge_paths=[train_edge_path],
            checkpoint_path=config.checkpoint_path + "/train",
        )
    if subset == "all":
        return attr.evolve(
            config,
            checkpoint_path=config.checkpoint_path + "/all",
        )
    raise ValueError(f"Unknown subset: {subset!r} (expected 'train' or 'all')")


def _append_checkpoint_record(record, registry_path):
    """Append ``record`` as one row to the parquet file at ``registry_path``.

    The file is created if it does not exist yet; otherwise its current
    content is read, the new row is appended, and the file is rewritten.

    Args:
        record: Mapping of column name to value describing one training run.
        registry_path: Location of the parquet file collecting all runs.
    """
    new_row = pd.DataFrame([record], index=[0])
    if Path(registry_path).is_file():
        history = pd.read_parquet(registry_path)
        table = pd.concat([history, new_row], ignore_index=True)
    else:
        table = new_row
    table.to_parquet(registry_path, index=False)


def main():
    """Train PBG on every selected dataset and subset, recording each run.

    For each subset (``"train"``, then ``"all"``) and each dataset in the list
    below, this loads ``configs/<data>_config.py`` from the directory of this
    script, derives the subset-specific configuration, runs the training
    wrapped in ``measure_performance``, and appends the configuration together
    with the measured time and memory usage to
    ``SEPAL_DIR / "baselines/PBG/checkpoints_pbg.parquet"``.

    Raises:
        KeyError: If a dataset has no entry in ``DATASETS_NAMES``. This is
            checked before training starts so no run is wasted.
    """
    # Loop-invariant locations, computed once.
    registry_path = SEPAL_DIR / "baselines/PBG/checkpoints_pbg.parquet"
    config_dir = Path(__file__).absolute().parent / "configs"

    # Datasets to train on; uncomment entries to include them in the sweep.
    datasets = [
        # "mini_yago3_lcc",
        # "yago3_lcc",
        # "rel_core_yago4",
        # "rel_core_yago4.5",
        # "core_yago4",
        # "core_yago4.5",
        # "yago4_lcc",
        # "yago4.5_lcc",
        # "yago4_with_ontology",
        "yago4_with_full_ontology",
        # "yago4.5_with_ontology",
        # "yago4.5_with_full_ontology",
    ]

    for subset in ["train", "all"]:
        for data in datasets:
            print(f"Training PBG on {subset} {data}")

            # Resolve the display name up front: a dataset missing from
            # DATASETS_NAMES must fail here, not after the training run.
            dataset_label = DATASETS_NAMES[data]

            loader = ConfigFileLoader()
            config = loader.load_config(config_dir / f"{data}_config.py")
            set_logging_verbosity(config.verbose)
            subprocess_init = SubprocessInitializer()
            subprocess_init.register(setup_logging, config.verbose)
            subprocess_init.register(add_to_sys_path, loader.config_dir.name)

            train_config = _build_train_config(config, subset)

            # Training
            # NOTE(review): the leading positional ``1`` is forwarded as-is to
            # the callable returned by measure_performance(train); its meaning
            # is defined in SEPAL.utils -- confirm there before changing it.
            _, training_time, training_mem_usage = measure_performance(train)(
                1, train_config, subprocess_init=subprocess_init
            )
            print(f"Training Time: {training_time}")
            print(f"Training Memory Usage: {training_mem_usage}")

            # Save config and performances
            record = train_config.to_dict()
            record["data"] = data
            record["training_time"] = training_time
            record["training_mem_usage"] = training_mem_usage
            record["method"] = "PBG"
            record["subset"] = subset
            record["id"] = f"PBG - {dataset_label} {subset}"

            _append_checkpoint_record(record, registry_path)


# Run the training sweep only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
