from pathlib import Path
import attr
import pandas as pd

from torchbiggraph.config import add_to_sys_path, ConfigFileLoader
from torchbiggraph.train import train
from torchbiggraph.util import (
    set_logging_verbosity,
    setup_logging,
    SubprocessInitializer,
)

from SEPAL.utils import measure_performance
from SEPAL import SEPAL_DIR, DATASETS_NAMES


# Absolute path of the "data" directory that sits next to this script.
_SCRIPT_DIR = Path(__file__).absolute().parent
DATA_DIR = _SCRIPT_DIR / "data"


# Per-dataset (node proportion, edge proportion) pairs. The two public
# dictionaries below are derived from this single table so that a dataset's
# settings live on one line. The values are only used to build the name of the
# core-subgraph run (and therefore its data/model directories).
# NOTE(review): presumably these are the fractions of nodes/edges kept when the
# core subgraph was extracted -- confirm against the core-extraction code.
_CORE_PROPORTIONS = {
    "yago4_lcc": (0.02, 0.025),
    "yago4.5_lcc": (0.04, 0.025),
    "yago4_with_full_ontology": (0.015, 0.025),
    "yago4.5_with_full_ontology": (0.04, 0.015),
    "full_freebase_lcc": (0.025, 0.015),
}

# Dataset name -> node proportion.
core_node_proportions = {
    dataset: node_share for dataset, (node_share, _) in _CORE_PROPORTIONS.items()
}

# Dataset name -> edge proportion.
core_edge_proportions = {
    dataset: edge_share for dataset, (_, edge_share) in _CORE_PROPORTIONS.items()
}


def _core_name(data):
    """Return the run name of the core subgraph of ``data``.

    Raises:
        KeyError: if ``data`` has no entry in ``core_node_proportions`` or
            ``core_edge_proportions``.
    """
    return f"core_{data}_mixed_{core_node_proportions[data]}_{core_edge_proportions[data]}"


def _run_id(data, subset, num_epochs, core):
    """Return the human-readable identifier stored in the ``id`` column.

    Raises:
        KeyError: if the dataset has no display name in ``DATASETS_NAMES``
            (non-core mode) or no core proportions (core mode).
    """
    label = _core_name(data) if core else DATASETS_NAMES[data]
    return f"PBG - {label} {subset} {num_epochs} epochs"


def _build_train_config(config, data, subset, num_epochs, core):
    """Derive the configuration actually passed to ``train`` from ``config``.

    Args:
        config: configuration loaded from the dataset's config file.
        data: dataset name.
        subset: ``"train"`` or ``"all"``; ignored in core mode.
        num_epochs: epoch override, or ``None`` to keep the configured value
            (non-core mode only).
        core: when true, point the config at the core-subgraph directories.

    Returns:
        A copy of ``config`` with paths (and possibly ``num_epochs``) replaced.

    Raises:
        ValueError: if ``subset`` is not supported in non-core mode.
    """
    if core:
        core_name = _core_name(data)
        # NOTE(review): num_epochs is forwarded as-is here, so a None value
        # ends up both in the config and in the "train_None" checkpoint
        # directory -- confirm that core runs always set an explicit value.
        return attr.evolve(
            config,
            entity_path=f"baselines/PBG/data/{core_name}",
            edge_paths=[f"baselines/PBG/data/{core_name}/train_partitioned"],
            checkpoint_path=f"baselines/PBG/model/{core_name}/train_{num_epochs}",
            num_epochs=num_epochs,
        )

    if subset == "train":
        # Only this branch needs the individual splits: it expects exactly
        # three edge paths and trains on the first one (presumably
        # train/valid/test -- confirm against the config files).
        output_train_path, _, _ = config.edge_paths
        train_config = attr.evolve(
            config,
            edge_paths=[output_train_path],
            checkpoint_path=config.checkpoint_path + "/train",
        )
    elif subset == "all":
        # Keep every configured edge path; only the checkpoint directory moves.
        train_config = attr.evolve(
            config,
            checkpoint_path=config.checkpoint_path + "/all",
        )
    else:
        raise ValueError(
            f"Unsupported subset {subset!r}: expected 'train' or 'all'"
        )

    if num_epochs is not None:
        # An explicit epoch count gets its own checkpoint directory.
        train_config = attr.evolve(
            train_config,
            num_epochs=num_epochs,
            checkpoint_path=train_config.checkpoint_path + f"_{num_epochs}",
        )
    return train_config


def _append_run_record(record):
    """Append ``record`` as one row to the PBG runs parquet file.

    The file is created if it does not exist yet; otherwise it is read,
    extended with the new row and rewritten.
    """
    results_path = Path(SEPAL_DIR / "baselines/PBG/checkpoints_pbg.parquet")
    runs = pd.DataFrame([record], index=[0])
    if results_path.is_file():
        runs = pd.concat([pd.read_parquet(results_path), runs], ignore_index=True)
    runs.to_parquet(results_path, index=False)


def _train_and_record(data, subset, num_epochs, core):
    """Train PBG for one (dataset, subset, epochs) setting and log the run."""
    print(f"Training PBG on {subset} {data}")

    # Resolve the run id before training so that an unknown dataset name fails
    # immediately instead of after the (potentially long) training run.
    run_id = _run_id(data, subset, num_epochs, core)

    loader = ConfigFileLoader()
    config = loader.load_config(
        Path(__file__).absolute().parent / f"configs/{data}_config.py"
    )
    set_logging_verbosity(config.verbose)
    subprocess_init = SubprocessInitializer()
    subprocess_init.register(setup_logging, config.verbose)
    subprocess_init.register(add_to_sys_path, loader.config_dir.name)

    train_config = _build_train_config(config, data, subset, num_epochs, core)

    # Training. The wrapped call yields (result, time, memory usage); the
    # training result itself is not used.
    # NOTE(review): the leading positional ``1`` is passed exactly as before;
    # confirm whether it is consumed by the measure_performance wrapper.
    _, training_time, training_mem_usage = measure_performance(train)(
        1, train_config, subprocess_init=subprocess_init
    )
    print(f"Training Time: {training_time}")
    print(f"Training Memory Usage: {training_mem_usage}")

    # Save config and performances (key order defines the column order of a
    # freshly created parquet file).
    config_dict = train_config.to_dict()
    config_dict["data"] = data
    config_dict["training_time"] = training_time
    config_dict["training_mem_usage"] = training_mem_usage
    config_dict["method"] = "PBG"
    config_dict["subset"] = subset
    config_dict["id"] = run_id
    _append_run_record(config_dict)


def main():
    """Train PBG on every configured (subset, dataset, epochs) combination.

    For each combination the dataset's config file is loaded, adjusted for the
    chosen subset / epoch count (or for the core subgraph when ``core`` is
    enabled), trained, and the resulting configuration together with the
    measured training time and memory usage is appended to
    ``baselines/PBG/checkpoints_pbg.parquet`` under ``SEPAL_DIR``.

    Raises:
        ValueError: if a subset other than ``"train"`` or ``"all"`` is
            requested in non-core mode.
        KeyError: if a dataset has no display name (non-core mode) or no core
            proportions (core mode).
    """
    core = False
    subsets = ["all"]
    datasets = [
        # "mini_yago3_lcc",
        # "yago3_lcc",
        # "rel_core_yago4",
        # "rel_core_yago4.5",
        # "core_yago4",
        # "core_yago4.5",
        # "yago4_with_ontology",
        # "yago4.5_with_ontology",
        # "yago4.5_lcc",
        # "yago4.5_with_full_ontology",
        # "yago4_lcc",
        # "yago4_with_full_ontology",
        # "full_freebase_lcc",
        "wikikg90mv2_lcc",
    ]
    epoch_settings = [None]

    for subset in subsets:
        for data in datasets:
            for num_epochs in epoch_settings:
                _train_and_record(data, subset, num_epochs, core)


# Run the training sweep only when this file is executed as a script.
if __name__ == "__main__":
    main()
