"""This script discards the linked tables with too small coverage of entities in Wikidata5M."""

from pathlib import Path
import pandas as pd
import json
from pykeen.datasets import Wikidata5M
from sklearn.preprocessing import LabelEncoder

# Both directories are resolved relative to the third ancestor directory of
# this file (``parents[2]``), so the script does not depend on the current
# working directory.
# LINKED_DATA_DIR: one sub-directory per task, each holding the linked tables.
LINKED_DATA_DIR = Path(__file__).parents[2] / "data/tabular/tables_with_linked_entities"
# PROCESSED_DIR: holds "data_configs.json", the per-table configuration file.
PROCESSED_DIR = Path(__file__).parents[2] / "data/tabular/processed"


def get_entity_to_id_mapping():
    """Return the entity-to-id mapping of the Wikidata5M training triples.

    Returns:
        The ``entity_to_id`` attribute of ``Wikidata5M().training``; its keys
        are the entity identifiers known to Wikidata5M.
    """
    return Wikidata5M().training.entity_to_id


def _drop_rare_classes(table, target_col, min_class_samples):
    """Return ``table`` without the rows whose class is too infrequent.

    A class is dropped when it has fewer than ``min_class_samples`` rows in
    ``table[target_col]``. The returned frame has a fresh 0..n-1 index.
    """
    class_counts = table[target_col].value_counts()
    rare_classes = class_counts[class_counts < min_class_samples].index
    keep_mask = ~table[target_col].isin(rare_classes)
    return table[keep_mask].reset_index(drop=True)


def _discard_table(table_file, n_rows):
    """Report a table that ended up too small and delete its file."""
    print(f"Discarding {table_file.stem} with {n_rows} rows.")
    table_file.unlink()


def main(min_rows=1050, min_class_samples=105):
    """Filter every linked table in place, deleting those that end up too small.

    For each file found in each task sub-directory of ``LINKED_DATA_DIR``:

    1. keep only the rows whose ``wikidata_id`` is a Wikidata5M entity;
    2. for tables under the ``classification`` task, additionally drop the
       classes with fewer than ``min_class_samples`` rows and re-encode the
       remaining target labels with ``LabelEncoder``;
    3. delete the file if fewer than ``min_rows`` rows remain, otherwise
       overwrite it with the filtered table.

    Args:
        min_rows: Minimum number of rows a table must retain to be kept.
        min_class_samples: Minimum number of rows a class must have to be
            kept in a classification table.

    Raises:
        KeyError: If a classification table has no entry (or no ``"target"``
            key) in ``data_configs.json``.
    """
    # Get the entity to id mapping
    entity_to_id = get_entity_to_id_mapping()
    # Materialise the entity names once, outside the loops, instead of
    # handing pandas a dict view to convert again for every table.
    known_entities = pd.Index(list(entity_to_id), dtype=object)

    # Get datasets configurations (explicit encoding: do not rely on the
    # platform default when reading JSON)
    with open(PROCESSED_DIR / "data_configs.json", "r", encoding="utf-8") as f:
        configs = json.load(f)

    # Iterate over the linked tables
    for task in LINKED_DATA_DIR.iterdir():
        if not task.is_dir():
            continue
        is_classification = task.stem == "classification"

        for table_file in task.iterdir():
            # Only regular files can be read as a table and unlinked.
            if not table_file.is_file():
                continue

            print(f"Processing {table_file.stem}...")
            table = pd.read_parquet(table_file)
            in_wikidata5m = table["wikidata_id"].isin(known_entities)
            table = table[in_wikidata5m].reset_index(drop=True)

            # For classification tasks, ensure all classes have enough
            # samples. Skipped when the table is already too small, since it
            # is going to be discarded anyway.
            target_col = None
            if is_classification and len(table) >= min_rows:
                target_col = configs[table_file.stem]["target"]
                table = _drop_rare_classes(table, target_col, min_class_samples)

            if len(table) < min_rows:
                _discard_table(table_file, len(table))
                continue

            # Re-encode classification targets, only for tables that are kept
            if target_col is not None:
                table[target_col] = LabelEncoder().fit_transform(table[target_col])

            table.to_parquet(table_file, index=False)
            print(f"Keeping {table_file.stem} with {len(table)} rows.")


# Run the filtering only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()
