from pathlib import Path

import pandas as pd

from SEPAL import SEPAL_DIR
from SEPAL.dataloader import DataLoader
from SEPAL.utils import get_downstream_wikidb_files


def _column_to_embed(dataset):
    """Return the downstream-table column to match against `dataset`.

    The column is selected from substrings of the dataset name. "yago4.5" is
    tested before "yago4" because the latter is a substring of the former.
    A name matching none of the known families falls back to the YAGO3 column.
    """
    if "yago4.5" in dataset:
        return "yago4.5_col_to_embed"
    if "yago4" in dataset:
        return "yago4_col_to_embed"
    if "freebase" in dataset:
        return "freebase_col_to_embed"
    if "wikikg" in dataset:
        return "wikidata_col_to_embed"
    return "yago3_col_to_embed"


def compute_coverages(
    datasets=(
        "mini_yago3_lcc",
        "yago3_lcc",
        "yago4_lcc",
        "yago4.5_lcc",
        "yago4_with_full_ontology",
        "yago4.5_with_full_ontology",
        "full_freebase_lcc",
        "wikikg90mv2_lcc",
    ),
):
    """Compute and save the entity coverage of every downstream table.

    For each knowledge graph in `datasets` and each file returned by
    `get_downstream_wikidb_files()`, the coverage is the fraction of rows of
    the table whose value in the dataset-specific ``*_col_to_embed`` column is
    one of the keys of the knowledge graph's ``entity_to_idx`` mapping.

    One row per (dataset, table) pair is collected and written to
    ``SEPAL_DIR / "datasets/wikidb/entity_coverage.parquet"``. Progress and the
    individual coverages are printed along the way.

    Parameters
    ----------
    datasets : iterable of str
        Names of the knowledge graph directories found under
        ``SEPAL_DIR / "datasets/knowledge_graphs"``.

    Returns
    -------
    None
    """
    # The downstream tables do not depend on the knowledge graph, so they are
    # listed once and materialised to allow one pass per dataset.
    target_files = list(get_downstream_wikidb_files())

    coverages = {
        "dataset": [],
        "table": [],
        "coverage": [],
    }
    for dataset in datasets:
        print(f"Computing coverages for {dataset}...")
        # Load knowledge graph entities
        dl = DataLoader(SEPAL_DIR / "datasets/knowledge_graphs" / dataset)
        entity_list = list(dl.entity_to_idx.keys())
        # The column to look up only depends on the dataset, not on the table
        col = _column_to_embed(dataset)
        for target_file in target_files:
            # Load the downstream table
            table = pd.read_parquet(target_file)
            # Share of the table's rows whose entity is known to the graph
            coverage = table[col].isin(entity_list).sum() / len(table)
            # File name without directory and without the ".parquet" suffix;
            # Path handles the separator of the current platform.
            table_name = Path(target_file).name.removesuffix(".parquet")
            print(f"    {table_name}: {coverage:.4f}")
            coverages["dataset"].append(dataset)
            coverages["table"].append(table_name)
            coverages["coverage"].append(coverage)
    # Save the results
    coverages_df = pd.DataFrame(coverages)
    coverages_df.to_parquet(
        SEPAL_DIR / "datasets/wikidb/entity_coverage.parquet", index=False
    )


# Script entry point: compute the coverages for the default datasets only when
# this file is executed directly, not when it is imported as a module.
if __name__ == "__main__":
    compute_coverages()
