from pathlib import Path
import pandas as pd
import numpy as np

from SEPAL.dataloader import DataLoader


if __name__ == "__main__":
    # Load the raw election returns from the CSV stored next to this script
    dir_path = Path(__file__).parent
    df = pd.read_csv(dir_path / "countypres_2000-2020.csv")

    # Select only the 2020 presidential election. The explicit copy makes `df`
    # an independent frame, so the column assignments below do not write into
    # a slice of the full table (which raises SettingWithCopyWarning).
    df = df[df["year"] == 2020].copy()

    # Format county names to match Yago nomenclature:
    # "<County Name>_County,_<State>", with spaces turned into underscores.
    # Every replacement is a literal substring swap, hence regex=False (the
    # default of `regex` is not the same across pandas versions).
    df["county_name"] = df["county_name"].str.title()
    df["state"] = df["state"].str.title()
    full_names = df["county_name"] + "_County,_" + df["state"]
    df["col_to_embed"] = full_names.str.replace(" ", "_", regex=False)

    # Rows whose name mentions Louisiana use "Parish" instead of "County"
    is_louisiana = df["col_to_embed"].str.contains("Louisiana")
    df.loc[is_louisiana, "col_to_embed"] = df.loc[
        is_louisiana, "col_to_embed"
    ].str.replace("County", "Parish", regex=False)

    # Drop the "_City_County" infix produced when a county name ends in " City"
    df["col_to_embed"] = df["col_to_embed"].str.replace(
        "_City_County", "", regex=False
    )


    ## Deal with unmatched entities
    # Get Yago4 entities: a name counts as matched when it is one of the keys
    # of the loader's `entity_to_idx` mapping.
    yago_dir = Path(__file__).absolute().parents[2] / "knowledge_graphs"
    yago4_dl = DataLoader(yago_dir / "yago4_with_full_ontology")
    entity_list = list(yago4_dl.entity_to_idx.keys())

    # Deal with special cases
    mask = ~df["col_to_embed"].isin(entity_list)
    n0 = mask.sum()
    print("Initial number of unmatched entities: ", n0)

    # Hand-written renames from the generated name (key) to the Yago4 entity
    # name (value). Each key appears exactly once.
    exact_matches = {
        "Baltimore,_Maryland": "Baltimore_County,_Maryland",
        "Bristol_County,_Rhode_Island": "Bristol,_Rhode_Island",
        "Bronx_County,_New_York": "Bronx_County_Q855974",
        "Broomfield_County,_Colorado": "Broomfield,_Colorado",
        "Carson,_Nevada": "Carson_City,_Nevada",
        "Charles,_Virginia": "Charles_City_County,_Virginia",
        "Coos_County,_New_Hampshire": "Coös_County,_New_Hampshire",
        "Denver_County,_Colorado": "Denver",
        "District_Of_Columbia_County,_District_Of_Columbia": "Washington,_D.C.",
        "James,_Virginia": "James_City_County,_Virginia",
        "Kansas,_Missouri": "Kansas_City,_Missouri",
        "Kings_County,_New_York": "Kings_County_Q11980692",
        "Nantucket_County,_Massachusetts": "Nantucket",
        "New_York_County,_New_York": "New_York_County_Q500416",
        "Orleans_Parish,_Louisiana": "New_Orleans",
        "Queens_County,_New_York": "Queens_County_Q5142559",
        "Richmond_County,_New_York": "Richmond_County_Q11997784",
        "San_Francisco_County,_California": "San_Francisco_County_Q13188841",
        "St._Louis,_Missouri": "St._Louis",
        "St._Louis_County_County,_Missouri": "St._Louis_County,_Missouri",
        "Dupage_County,_Illinois": "DuPage_County,_Illinois",
        "Dona_Ana_County,_New_Mexico": "Doña_Ana_County,_New_Mexico",
        "La_Salle_Parish,_Louisiana": "LaSalle_Parish,_Louisiana",
        "St_Mary'S_County,_Maryland": "St._Mary's_County,_Maryland",
    }

    # Rename only the unmatched rows; names without an entry in the table map
    # to NaN and are filled back with their current value.
    unmatched_names = df.loc[mask, "col_to_embed"]
    df.loc[mask, "col_to_embed"] = unmatched_names.map(exact_matches).fillna(
        unmatched_names
    )
    mask = ~df["col_to_embed"].isin(entity_list)
    n1 = mask.sum()
    print(n0 - n1, "additional matches")

    # Stand-in entity name for each listed Alaska district number. The lookup
    # keys built below have the form "District_<n>_County,_Alaska".
    alaska_district_entities = {
        1: "Fairbanks,_Alaska",
        2: "Fort_Wainwright",
        3: "North_Pole,_Alaska",
        4: "Goldstream,_Alaska",
        5: "Fairbanks,_Alaska",
        6: "Interior_Alaska",
        7: "Wasilla,_Alaska",
        8: "Knik-Fairview,_Alaska",
        9: "Valdez,_Alaska",
        10: "Houston,_Alaska",
        11: "Palmer,_Alaska",
        12: "Butte,_Alaska",
        13: "Chugiak,_Anchorage",
        14: "Eagle_River,_Anchorage",
        15: "Joint_Base_Elmendorf–Richardson",
        18: "West_Anchorage_High_School",
        19: "Mountain_View,_Anchorage",
        20: "Downtown_Anchorage",
        21: "Sand_Lake_(Anchorage)",
        23: "Campbell_Creek_(Alaska)",
        24: "East_Anchorage_High_School",
        28: "Girdwood,_Anchorage,_Alaska",
        29: "Seward,_Alaska",
        30: "Kenai,_Alaska",
        31: "Homer,_Alaska",
        32: "Kodiak,_Alaska",
        33: "Juneau,_Alaska",
        34: "Mendenhall_Valley,_Juneau",
        35: "Sitka,_Alaska",
        36: "Ketchikan,_Alaska",
        37: "Bristol_Bay_Borough,_Alaska",
        38: "Bethel,_Alaska",
        39: "Nome,_Alaska",
        40: "Kotzebue,_Alaska",
    }
    approximate_matches = {
        f"District_{district}_County,_Alaska": entity
        for district, entity in alaska_district_entities.items()
    }

    # Swap in the stand-in where one exists; every other unmatched name maps
    # to NaN and is filled back with its current value.
    still_unmatched = df.loc[mask, "col_to_embed"]
    remapped = still_unmatched.map(approximate_matches)
    df.loc[mask, "col_to_embed"] = remapped.fillna(still_unmatched)

    # Refresh the unmatched mask and report how many names this step resolved
    mask = ~df["col_to_embed"].isin(entity_list)
    n0 = n1
    n1 = mask.sum()
    print(n0 - n1, "additional matches")


    # String-level morphological variations (--> 110 extra matches)
    # Ordered (old, new) substring pairs, applied one after another to the
    # still-unmatched names only. regex=False makes each pair a literal
    # substitution whatever the pandas version's default for `regex` is.
    morphological_fixes = [
        ("_Of_", "_of_"),
        ("_The_", "_the_"),
        ("_And_", "_and_"),
        ("_Du_", "_du_"),
        ("_Qui_", "_qui_"),
        ("'S", "'s"),
        ("Saint_", "St._"),
        ("De_Witt", "DeWitt"),
        ("De_Soto", "DeSoto"),
    ]
    fixed_names = df.loc[mask, "col_to_embed"]
    for old, new in morphological_fixes:
        fixed_names = fixed_names.str.replace(old, new, regex=False)
    df.loc[mask, "col_to_embed"] = fixed_names

    # Refresh the unmatched mask and report how many names this step resolved
    mask = ~df["col_to_embed"].isin(entity_list)
    n0, n1 = n1, mask.sum()
    print(n0 - n1, "additional matches")


    # Capitalize letter following "Mc", "De" or "La" for non-matches (--> 282 extra matches)
    def capitalize_after_subword(s, subword):
        """Upper-case the character right after the first `subword` in `s`.

        Only the first (case-sensitive) occurrence of `subword` is considered.

        Args:
            s: The string to transform.
            subword: The marker whose following character is capitalized.

        Returns:
            The transformed string. `s` is returned unchanged when `subword`
            does not occur in it, or when nothing follows the occurrence.
        """
        start = s.find(subword)
        if start == -1:
            return s
        x = start + len(subword)
        # Slice instead of index: when the subword ends the string, s[x:x + 1]
        # is simply empty rather than raising IndexError.
        return s[:x] + s[x:x + 1].upper() + s[x + 1:]
    
    # Run the capitalization helper once per prefix, touching only the rows
    # that are still unmatched, and write the result back after each pass.
    for prefix in ("Mc", "De", "La"):
        pending_names = df.loc[mask, "col_to_embed"]
        df.loc[mask, "col_to_embed"] = pending_names.apply(
            lambda name, sw=prefix: capitalize_after_subword(name, sw)
        )

    # Refresh the unmatched mask and report how many names this step resolved
    mask = ~df["col_to_embed"].isin(entity_list)
    n0 = n1
    n1 = mask.sum()
    print(n0 - n1, "additional matches")

    print("Final number of unmatched entities: ", n1)
    
    
    # Restore original format for non-matches: rebuild the plain
    # "<County Name>_County,_<State>" name (literal space -> underscore swap)
    # and assign it, index-aligned, to the rows that are still unmatched.
    df.loc[mask, "col_to_embed"] = (
        df["county_name"] + "_County,_" + df["state"]
    ).str.replace(" ", "_", regex=False)

    # Make target variable to predict: log10(candidate votes + 1)
    df["target"] = df["candidatevotes"]
    df["raw_entities"] = df["county_name"] + " " + df["state"]

    # dropna() returns a new frame; calling it in place on a column slice can
    # trigger SettingWithCopyWarning.
    df = df[["raw_entities", "col_to_embed", "party", "target"]].dropna()

    # Sum the votes of rows sharing the same entity and party
    df = df.groupby(["raw_entities", "col_to_embed", "party"], as_index=False).sum()

    # Plain column assignment replaces the column, so summed vote counts can
    # become float logs without the in-place dtype constraint of `.loc[:, ...]`.
    df["target"] = np.log10(df["target"] + 1)

    # Account for specificities of Yago3 and Yago4
    # The names built so far are kept as-is for Yago4; the Yago3 column is
    # derived from them through the ordered (Yago4 name, Yago3 name) pairs
    # below. regex=False keeps characters such as "." literal regardless of
    # the pandas version's default for `regex`.
    df["yago4_col_to_embed"] = df["col_to_embed"]
    yago3_renames = [
        ("Bronx_County_Q855974", "The_Bronx"),
        ("Kings_County_Q11980692", "Brooklyn"),
        ("New_York_County_Q500416", "Manhattan"),
        ("Queens_County_Q5142559", "Queens"),
        ("Richmond_County_Q11997784", "Richmond_County,_New_York"),
        ("San_Francisco_County_Q13188841", "San_Francisco"),
        ("St._Helena_Parish,_Louisiana", "Saint_Helena_Parish,_Louisiana"),
        ("LaSalle_Parish,_Louisiana", "La_Salle_Parish,_Louisiana"),
    ]
    yago3_names = df["col_to_embed"]
    for yago4_name, yago3_name in yago3_renames:
        yago3_names = yago3_names.str.replace(yago4_name, yago3_name, regex=False)
    df["yago3_col_to_embed"] = yago3_names

    # Restrict the frame to the columns that are written out, in output order
    output_columns = [
        "raw_entities",
        "yago3_col_to_embed",
        "yago4_col_to_embed",
        "party",
        "target",
    ]
    df = df.loc[:, output_columns]

    # Write the result as a parquet file in `dir_path`, without the index
    output_path = dir_path / "target_log.parquet"
    df.to_parquet(output_path, index=False)
