from pathlib import Path
import numpy as np
import pandas as pd

from SEPAL.utils import create_graph

# Directory containing this script; the downstream tables are stored beneath it.
downstream_dir = Path(__file__).parent

# Display name of each downstream task -> parquet file holding its target table.
downstream_directories = {
    label: downstream_dir.joinpath(relative_path)
    for label, relative_path in (
        ("US elections", "us_elections/target_log.parquet"),
        ("Housing prices", "housing_prices/target_log.parquet"),
        ("US accidents", "us_accidents/target_log.parquet"),
        ("Movie revenues", "movie_revenues/target_log.parquet"),
    )
}



def compute_core_coverage(graph, table: pd.DataFrame, p: float, id_to_entity: dict, feature: str) -> float:
    """Return the fraction of ``table`` rows covered by the top-degree entities.

    The ``int(p * graph.num_entities) + 1`` entities with the largest values in
    ``graph.degrees`` are selected, translated to names through
    ``id_to_entity``, and matched against the ``feature`` column of ``table``.

    Parameters
    ----------
    graph
        Object exposing ``degrees`` (one value per entity, positioned by
        entity id) and ``num_entities``.
    table
        Downstream table containing the column ``feature``.
    p
        Proportion of the graph's entities to select. Because of the ``+ 1``,
        at least one entity is always selected, even for ``p = 0``.
    id_to_entity
        Mapping from entity id to entity name.
    feature
        Name of the column of ``table`` holding entity names.

    Returns
    -------
    float
        Share of rows whose ``feature`` value is one of the selected entity
        names; ``0.0`` when ``table`` has no rows.
    """
    column = table[feature]
    n_rows = len(column)
    if n_rows == 0:
        # Nothing to cover; avoid dividing by zero.
        return 0.0

    core_size = int(p * graph.num_entities) + 1
    # argsort is ascending, so reverse it to put the highest degrees first.
    core_ids = np.argsort(graph.degrees)[::-1][:core_size]
    core_names = [id_to_entity[entity_id] for entity_id in core_ids]

    # Vectorised count of matching rows instead of a Python-level sum().
    return float(column.isin(core_names).sum() / n_rows)



if __name__ == "__main__":
    # Column order of the results table.
    result_columns = ["data", "downstream_table", "proportion", "coverage"]
    # Entity proportions passed to compute_core_coverage for every graph/table pair.
    proportions = [0]

    # Rows are gathered in a plain list and turned into a DataFrame once at the
    # end: growing a DataFrame with pd.concat inside the loop copies all
    # previous rows on every iteration and depends on how pandas treats the
    # initial empty frame when inferring dtypes.
    rows = []

    # Loop over the datasets
    for data in ["mini_yago3_lcc", "yago3_lcc", "yago4_lcc", "yago4_with_full_ontology"]:
        print(f"--------- {data} ---------")
        # Load knowledge graph
        graph = create_graph(data)
        id_to_entity = {v: k for k, v in graph.triples_factory.entity_to_id.items()}
        # The downstream tables carry one entity column per YAGO version.
        feature = "yago4_col_to_embed" if "yago4" in data else "yago3_col_to_embed"
        for downstream, table_path in downstream_directories.items():
            print(f"    {downstream}")
            # Load downstream table
            table = pd.read_parquet(table_path)
            for p in proportions:
                rows.append(
                    {
                        "data": data,
                        "downstream_table": downstream,
                        "proportion": p,
                        "coverage": compute_core_coverage(graph, table, p, id_to_entity, feature),
                    }
                )

    new_results = pd.DataFrame(rows, columns=result_columns)

    # Load or create results file: new rows are appended to any existing results.
    results_path = downstream_dir / "core_coverage.parquet"
    if results_path.is_file():
        previous_results = pd.read_parquet(results_path)
        results = pd.concat([previous_results, new_results], ignore_index=True)
    else:
        results = new_results

    # Save results
    results.to_parquet(results_path, index=False)