from pathlib import Path
import pandas as pd
import json

# Root of the tabular data tree, located relative to this file.
_TABULAR_ROOT = Path(__file__).parents[2] / "data" / "tabular"

# Input tables (read by `load_data`).
SOURCE_DIR = _TABULAR_ROOT / "processed"
# Output location for the tables once a `wikidata_id` column has been added.
DATA_DIR = _TABULAR_ROOT / "tables_with_linked_entities"
# JSON files named `<table_name>_entity_to_id.json` (read by `load_data`).
MATCHINGS_DIR = _TABULAR_ROOT / "carte_entity_linking"


def link_tables():
    """Run every table-linking step, one after the other.

    The steps run in a fixed order: movies, US accidents counts, then
    US presidential. Returns ``None``.
    """
    linking_steps = (
        link_movies,
        link_us_accidents_counts,
        link_us_presidential,
    )
    for run_step in linking_steps:
        run_step()


def load_data(table_name):
    """Load a regression table together with its entity-to-id matchings.

    Args:
        table_name: Base name shared by the parquet table and the
            matchings JSON file.

    Returns:
        A ``(df, matchings)`` tuple: the DataFrame read from
        ``SOURCE_DIR / "regression/<table_name>.parquet"`` and the object
        parsed from ``MATCHINGS_DIR / "<table_name>_entity_to_id.json"``.
    """
    df = pd.read_parquet(SOURCE_DIR / f"regression/{table_name}.parquet")
    # Decode explicitly as UTF-8: without `encoding`, open() falls back to
    # the locale's default, which can mis-decode (or fail on) non-ASCII
    # entity names on platforms whose default is not UTF-8.
    matchings_path = MATCHINGS_DIR / f"{table_name}_entity_to_id.json"
    with open(matchings_path, "r", encoding="utf-8") as f:
        matchings = json.load(f)
    return df, matchings


def link_movies():
    """Add a ``wikidata_id`` column to the ``carte_movies`` table and save it.

    Builds a ``movie`` key as ``"<title>, <release_date>"``, looks it up in
    the matchings loaded by ``load_data``, keeps the ``movie``, ``revenue``
    and ``wikidata_id`` columns, de-duplicates on ``wikidata_id`` and writes
    the result to ``DATA_DIR / "regression" / "carte_movies.parquet"``.

    Returns:
        None.
    """
    table_name = "carte_movies"
    print(f"Linking {table_name}...")
    df, matchings = load_data(table_name)
    df["movie"] = df["title"] + ", " + df["release_date"]
    df["wikidata_id"] = df["movie"].map(matchings)
    # NOTE(review): keys absent from `matchings` map to a missing value, and
    # drop_duplicates treats missing values as equal, so all unmatched rows
    # collapse into a single row - confirm this is intended.
    # Non-inplace call: avoids mutating a column-subset of `df` in place.
    linked = df[["movie", "revenue", "wikidata_id"]].drop_duplicates(
        subset=["wikidata_id"]
    )
    # to_parquet does not create missing directories, so make sure the
    # output folder exists before writing.
    out_dir = DATA_DIR / "regression"
    out_dir.mkdir(parents=True, exist_ok=True)
    linked.to_parquet(out_dir / f"{table_name}.parquet", index=False)


def link_us_accidents_counts():
    """Add a ``wikidata_id`` column to ``carte_us_accidents_counts`` and save it.

    Builds a ``US_city`` key as ``"<City>, <Code>"``, looks it up in the
    matchings loaded by ``load_data``, keeps the ``US_city``, ``Counts`` and
    ``wikidata_id`` columns, de-duplicates on ``wikidata_id`` and writes the
    result to
    ``DATA_DIR / "regression" / "carte_us_accidents_counts.parquet"``.

    Returns:
        None.
    """
    table_name = "carte_us_accidents_counts"
    print(f"Linking {table_name}...")
    df, matchings = load_data(table_name)
    df["US_city"] = df["City"] + ", " + df["Code"]
    df["wikidata_id"] = df["US_city"].map(matchings)
    # NOTE(review): keys absent from `matchings` map to a missing value, and
    # drop_duplicates treats missing values as equal, so all unmatched rows
    # collapse into a single row - confirm this is intended.
    # Non-inplace call: avoids mutating a column-subset of `df` in place.
    linked = df[["US_city", "Counts", "wikidata_id"]].drop_duplicates(
        subset=["wikidata_id"]
    )
    # to_parquet does not create missing directories, so make sure the
    # output folder exists before writing.
    out_dir = DATA_DIR / "regression"
    out_dir.mkdir(parents=True, exist_ok=True)
    linked.to_parquet(out_dir / f"{table_name}.parquet", index=False)


def link_us_presidential():
    """Add a ``wikidata_id`` column to ``carte_us_presidential`` and save it.

    Builds a ``county`` key as ``"<county_name> <state>"``, looks it up in
    the matchings loaded by ``load_data``, keeps the ``county``, ``party``,
    ``target`` and ``wikidata_id`` columns, de-duplicates on the
    ``(wikidata_id, party)`` pair and writes the result to
    ``DATA_DIR / "regression" / "carte_us_presidential.parquet"``.

    Returns:
        None.
    """
    table_name = "carte_us_presidential"
    print(f"Linking {table_name}...")
    df, matchings = load_data(table_name)
    df["county"] = df["county_name"] + " " + df["state"]
    df["wikidata_id"] = df["county"].map(matchings)
    # `party` is part of the de-duplication key, so rows that share a
    # wikidata_id but differ in party are all kept.
    # NOTE(review): keys absent from `matchings` map to a missing value, and
    # drop_duplicates treats missing values as equal, so unmatched rows
    # collapse to one row per party - confirm this is intended.
    # Non-inplace call: avoids mutating a column-subset of `df` in place.
    linked = df[["county", "party", "target", "wikidata_id"]].drop_duplicates(
        subset=["wikidata_id", "party"]
    )
    # to_parquet does not create missing directories, so make sure the
    # output folder exists before writing.
    out_dir = DATA_DIR / "regression"
    out_dir.mkdir(parents=True, exist_ok=True)
    linked.to_parquet(out_dir / f"{table_name}.parquet", index=False)


# Script entry point: link every table when this file is executed directly.
if __name__ == "__main__":
    link_tables()
