"""
Functions to evaluate the coverage of a given KG wrt a given target table, in terms of entities.
"""

from pathlib import Path
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from SEPAL.dataloader import DataLoader

# Parent of the directory that contains this file; every path below is joined onto it.
data_dir = Path(__file__).absolute().parents[1]


# Maps each KG directory (relative to data_dir) to the display name stored in the
# "source" column of the results. Commented-out entries are skipped
# (presumably toggled by hand to choose which KGs to evaluate in a given run).
triples_dir_names = {
    # "knowledge_graphs/mini_yago3_lcc": "Mini YAGO3",
    # "knowledge_graphs/yago3_lcc": "YAGO3",
    # "knowledge_graphs/yago4_lcc": "YAGO4",
    # "knowledge_graphs/yago4.5_lcc": "YAGO4.5",
    # "knowledge_graphs/yago4_with_full_ontology": "YAGO4 + taxonomy",
    # "knowledge_graphs/yago4.5_with_full_ontology": "YAGO4.5 + taxonomy",
    # "knowledge_graphs/full_freebase_lcc": "Freebase",
    "knowledge_graphs/wikikg90mv2_lcc": "WikiKG90Mv2",
}

# Maps each target table (parquet file, relative to data_dir) to the display name
# stored in the "target" column of the results.
dataset_names = {
    "evaluation/us_elections/target_log.parquet": "US elections",
    "evaluation/housing_prices/target_log.parquet": "Housing prices",
    "evaluation/us_accidents/target_log.parquet": "US accidents",
    "evaluation/movie_revenues/target_log.parquet": "Movie revenues",
    # "evaluation/company_employees/target_log.parquet": "Company employees",
}


def compute_coverage(target_df, entity_list, KG):
    """
    Computes the percentage of rows of target_df whose entity belongs to entity_list.

    Parameters
    ----------
    target_df : pd.DataFrame
        Target table. It must contain the entity column associated with `KG`.
    entity_list : list-like
        Entities of the KG.
    KG : str
        Display name of the KG, used to select the entity column of target_df.
        Names that are not listed below fall back to the YAGO4 column.

    Returns
    -------
    float
        Percentage (between 0 and 100) of covered rows, or NaN if target_df has no rows.
    """
    feature_by_kg = {
        "Mini YAGO3": "yago3_col_to_embed",
        "YAGO3": "yago3_col_to_embed",
        "YAGO4.5": "yago4.5_col_to_embed",
        "YAGO4.5 + taxonomy": "yago4.5_col_to_embed",
        "Freebase": "freebase_col_to_embed",
        "WikiKG90Mv2": "wikidata_col_to_embed",
    }
    feature = feature_by_kg.get(KG, "yago4_col_to_embed")
    entity_column = target_df[feature]

    n_rows = len(target_df)
    if n_rows == 0:
        # Coverage is undefined for an empty table; avoid a ZeroDivisionError.
        return float("nan")

    # Vectorised count of matching rows (the builtin sum() would iterate in Python).
    n_covered = int(entity_column.isin(entity_list).sum())
    return n_covered * 100 / n_rows


def compute_coverage_for_all_targets():
    """
    Computes the entity coverage of every KG in triples_dir_names on every target
    table in dataset_names, and stores the results in
    "evaluation/entity_coverage.parquet" (relative to data_dir).

    The results have one row per (KG, target table) pair, with columns "source",
    "target" and "coverage". If the results file already exists, the new rows are
    appended to it; when a (source, target) pair is already present, only the most
    recent value is kept, so that re-running the script does not create duplicates.
    """
    target_data_dict = {
        target_file: pd.read_parquet(data_dir / target_file)
        for target_file in dataset_names
    }

    rows = []
    for triples_dir, kg_name in triples_dir_names.items():
        # Load one KG at a time so that a single entity list is held in memory.
        entity_list = list(DataLoader(data_dir / triples_dir).entity_to_idx.keys())
        for target_file, target_name in dataset_names.items():
            rows.append(
                {
                    "source": kg_name,
                    "target": target_name,
                    "coverage": compute_coverage(
                        target_data_dict[target_file], entity_list, kg_name
                    ),
                }
            )
        del entity_list

    df = pd.DataFrame(rows, columns=["source", "target", "coverage"])
    df["coverage"] = df["coverage"].astype(float)

    # Save results
    results_path = data_dir / "evaluation/entity_coverage.parquet"
    if results_path.exists():
        # Append to existing file, letting the new values replace stale ones
        old_df = pd.read_parquet(results_path)
        df = pd.concat([old_df, df], ignore_index=True)
        df = df.drop_duplicates(subset=["source", "target"], keep="last")
        df = df.reset_index(drop=True)
    df.to_parquet(results_path, index=False)

    return


def plot_coverage():
    """
    Plots, for each target table and each KG, the percentage of rows of the target
    table whose corresponding entity is in the KG.

    The coverage values are read from "evaluation/entity_coverage.parquet" and the
    figure is saved to "evaluation/coverage.pdf" (both relative to data_dir).
    """
    from matplotlib.container import BarContainer

    df = pd.read_parquet(data_dir / "evaluation/entity_coverage.parquet")
    fig, ax = plt.subplots(figsize=(6, 5))
    sns.barplot(
        ax=ax, data=df, x="coverage", y="target", hue="source", edgecolor="white"
    )

    # Annotate the actual bars only: ax.patches may also hold rectangles that are
    # not bars (e.g. legend proxy artists, depending on the seaborn version).
    bars = [
        bar
        for container in ax.containers
        if isinstance(container, BarContainer)
        for bar in container
    ]
    if not bars:
        bars = list(ax.patches)

    for bar in bars:
        width = bar.get_width()
        if not np.isfinite(width):
            # No value for this (target, source) pair: nothing to annotate.
            continue
        ax.text(
            7 + width,
            bar.get_y() + 0.55 * bar.get_height(),
            "{:.1f}%".format(width),
            ha="center",
            va="center",
        )

    ax.set(xlabel="Entity coverage (%)", ylabel="")
    ax.legend(bbox_to_anchor=(1, 0.6), loc="upper left")
    # Leave room on the right of 100% for the text annotations.
    ax.set_xlim(0, 115)
    try:
        fig.savefig(data_dir / "evaluation/coverage.pdf", bbox_inches="tight")
    finally:
        # Release the figure so that repeated calls do not accumulate open figures.
        plt.close(fig)
    return


# When run as a script, only compute and store the coverage results;
# plot_coverage() is not called here.
if __name__ == "__main__":
    compute_coverage_for_all_targets()
