from pathlib import Path
import pandas as pd
import numpy as np

from SEPAL.dataloader import DataLoader


def format_yago_names(names: pd.Series, upper: bool = False) -> pd.Series:
    """Convert raw company names to Yago-style entity identifiers.

    Each name is title-cased (or upper-cased when ``upper`` is True) and
    spaces are replaced by underscores, e.g. ``"acme corp"`` becomes
    ``"Acme_Corp"`` (or ``"ACME_CORP"``). Missing names stay missing.

    Args:
        names: Series of company names.
        upper: Use upper case instead of title case.

    Returns:
        A Series of formatted names, aligned on the same index as ``names``.
    """
    cased = names.str.upper() if upper else names.str.title()
    # Literal space replacement; regex semantics are not wanted here.
    return cased.str.replace(" ", "_", regex=False)


if __name__ == "__main__":
    ## Load original data
    dir_path = Path(__file__).parent
    use_cols = ["name", "year founded", "industry", "country", "current employee estimate"]
    df = pd.read_csv(dir_path / "companies_sorted.csv", usecols=use_cols)

    ## Filter out companies with no employee
    # .copy() so the column assignments below modify an independent frame
    # instead of a slice of the original (avoids SettingWithCopyWarning).
    df = df[df["current employee estimate"] != 0].copy()

    ## Format company names to match Yago nomenclature
    df["col_to_embed"] = format_yago_names(df["name"])

    ## Deal with unmatched entities
    # Get Yago4 entities
    yago_dir = Path(__file__).absolute().parents[2] / "knowledge_graphs"
    yago4_dl = DataLoader(yago_dir / "yago4_with_full_ontology")
    # Known entity names, kept as a set for membership tests.
    yago_entities = set(yago4_dl.entity_to_idx.keys())

    # NOTE(review): loaded but not used anywhere in this script yet.
    yago4_types = pd.read_parquet(yago_dir / "yago4/yagoTypes.parquet")

    mask = ~df["col_to_embed"].isin(yago_entities)
    n0 = mask.sum()
    print("Initial number of unmatched entities: ", n0)

    # Try uppercase format for non matches (--> 3773 extra matches)
    # Only the still-unmatched rows are reformatted.
    df.loc[mask, "col_to_embed"] = format_yago_names(df.loc[mask, "name"], upper=True)
    mask = ~df["col_to_embed"].isin(yago_entities)
    n1 = mask.sum()
    print(n0 - n1, "additional matches")

    # A bare `df[mask]` expression displays nothing outside a notebook,
    # so print the remaining unmatched rows explicitly.
    print(df[mask])