import pandas as pd
import numpy as np
import attr

from torchbiggraph.config import ConfigFileLoader
from torchbiggraph.util import (
    set_logging_verbosity,
)
from convert_PBG_data import convert_data

from SEPAL.core_extraction import hybrid_extraction, extract_subgraph_from_node_list
from SEPAL.utils import create_train_graph
from SEPAL import SEPAL_DIR


# Per-dataset core-extraction settings, one row per dataset:
# (dataset name, node proportion, edge proportion).
# Keeping both values on the same row guarantees the two lookup tables below
# always share exactly the same keys, in the same order.
_CORE_PROPORTIONS = (
    ("yago4_lcc", 0.02, 0.025),
    ("yago4.5_lcc", 0.04, 0.025),
    ("yago4_with_full_ontology", 0.015, 0.025),
    ("yago4.5_with_full_ontology", 0.04, 0.015),
    ("full_freebase_lcc", 0.025, 0.015),
)

# Dataset name -> node proportion given to ``hybrid_extraction``.
core_node_proportions = {
    name: node_prop for name, node_prop, _ in _CORE_PROPORTIONS
}

# Dataset name -> edge proportion given to ``hybrid_extraction``.
core_edge_proportions = {
    name: edge_prop for name, _, edge_prop in _CORE_PROPORTIONS
}


def create_core_datasets(
    datasets=(
        "yago4.5_lcc",
        "yago4.5_with_full_ontology",
        "yago4_lcc",
        "yago4_with_full_ontology",
        "full_freebase_lcc",
    ),
):
    """Extract a core subgraph for each dataset and convert it for PBG.

    For every dataset name this:

    1. loads the training graph with ``create_train_graph``;
    2. selects core nodes with ``hybrid_extraction`` using the per-dataset
       values from ``core_node_proportions`` and ``core_edge_proportions``,
       then builds the subgraph with ``extract_subgraph_from_node_list``;
    3. writes the subgraph's triples, with ids mapped back to entity and
       relation labels, to
       ``SEPAL_DIR/baselines/PBG/data/<core_name>/training.parquet``;
    4. loads ``SEPAL_DIR/baselines/PBG/configs/<dataset>_config.py``,
       redirects its entity/edge/checkpoint paths to the core dataset and
       calls ``convert_data`` on the parquet file.

    Args:
        datasets: Iterable of dataset names. Each name must be a key of both
            ``core_node_proportions`` and ``core_edge_proportions``.

    Raises:
        KeyError: If a dataset name is missing from either proportion table.
    """
    for data in datasets:
        print(f"Processing {data}")

        # Look the proportions up first so that an unknown dataset name fails
        # before the graph is loaded rather than after.
        node_prop = core_node_proportions[data]
        edge_prop = core_edge_proportions[data]

        # Load data
        graph = create_train_graph(data, True)

        ## Extract core subgraph
        node_list = hybrid_extraction(node_prop, edge_prop, graph)
        subgraph = extract_subgraph_from_node_list(graph, node_list)

        ## Save parquet file
        core_name = f"core_{data}_hybrid_{node_prop}_{edge_prop}"
        output_dir = SEPAL_DIR / f"baselines/PBG/data/{core_name}"
        # exist_ok avoids the check-then-create race of a separate exists().
        output_dir.mkdir(parents=True, exist_ok=True)

        # Build a dataframe of integer ids, then map the ids back to labels.
        tf = subgraph.triples_factory
        id_to_entity = {v: k for k, v in tf.entity_to_id.items()}
        id_to_relation = {v: k for k, v in tf.relation_to_id.items()}
        df = pd.DataFrame(
            tf.mapped_triples, columns=["head", "relation", "tail"]
        ).astype(np.int32)
        df["head"] = df["head"].map(id_to_entity)
        df["relation"] = df["relation"].map(id_to_relation)
        df["tail"] = df["tail"].map(id_to_entity)

        # Save dataframe in parquet format
        df.to_parquet(output_dir / "training.parquet")

        ## Convert data
        filenames = [f"{core_name}/training.parquet"]

        loader = ConfigFileLoader()
        config = loader.load_config(
            SEPAL_DIR / f"baselines/PBG/configs/{data}_config.py"
        )
        # Point the dataset's base config at the core dataset's directories.
        # NOTE(review): these paths are relative, unlike output_dir above;
        # presumably they are resolved by PBG / convert_data -- confirm.
        config = attr.evolve(
            config,
            entity_path=f"baselines/PBG/data/{core_name}",
            edge_paths=[f"baselines/PBG/data/{core_name}/train_partitioned"],
            checkpoint_path=f"baselines/PBG/model/{core_name}/train",
        )

        set_logging_verbosity(config.verbose)

        convert_data(config, filenames, source="parquet")


# Script entry point: build the core datasets for the default dataset list.
if __name__ == "__main__":
    create_core_datasets()
