from pathlib import Path
import numpy as np
import pandas as pd
import json
from pykeen.datasets import Wikidata5M
from pykeen.triples import TriplesFactory
from sklearn.impute import SimpleImputer

# Every location below is resolved relative to the directory one level above
# the folder that contains this script.
_PROJECT_ROOT = Path(__file__).parents[1]

# Each sub-folder of MODELS_DIR is scanned for an "entity_embeddings.npy" file.
MODELS_DIR = _PROJECT_ROOT / "models/linked"
# One sub-folder per task, each holding the parquet tables to embed.
LINKED_DATA_DIR = _PROJECT_ROOT / "data/tabular/tables_with_linked_entities"
# Holds "data_configs.json", which names the target column of every table.
PROCESSED_DIR = _PROJECT_ROOT / "data/tabular/processed"


# Knowledge-graph subset name -> location of its binary triples factory.
data_paths = {
    subset_name: _PROJECT_ROOT / relative_path
    for subset_name, relative_path in (
        ("wikidata500k", "data/kg/processed/wikidata5m_deg9"),
        ("wikidata1m", "data/kg/processed/wikidata5m_deg6"),
        ("wikidata2m", "data/kg/processed/wikidata5m_deg4"),
        ("wikidata3m", "data/kg/processed/wikidata5m_deg3"),
    )
}


def load_entity_embeddings(path: Path) -> np.ndarray:
    """Load entity embeddings from a ``.npy`` file as a real-valued array.

    Complex-valued embeddings are converted to real ones by horizontally
    stacking their real and imaginary parts, so a 2-D complex array of shape
    ``(n, d)`` comes back with shape ``(n, 2 * d)``. Real-valued arrays are
    returned exactly as stored.

    Args:
        path: Location of the ``.npy`` file to read.

    Returns:
        The embeddings as a real-valued array.
    """
    entity_embeddings = np.load(path)

    # np.iscomplexobj recognises every complex dtype, not only complex64 and
    # complex128, so wider complex embeddings are flattened as well.
    if np.iscomplexobj(entity_embeddings):
        entity_embeddings = np.hstack([entity_embeddings.real, entity_embeddings.imag])
    return entity_embeddings


def get_entity_to_id_mapping(data):
    """Return the ``entity_to_id`` mapping of the requested knowledge graph.

    ``"wikidata5m"`` is served by the training triples of pykeen's
    ``Wikidata5M`` dataset. Any other name is looked up in ``data_paths`` and
    loaded from the binary triples factory stored at that location; a name
    missing from ``data_paths`` raises ``KeyError``.
    """
    if data != "wikidata5m":
        # Subsets are stored on disk as binary triples factories.
        subset_factory = TriplesFactory.from_path_binary(path=data_paths[data])
        return subset_factory.entity_to_id

    # The full graph comes straight from the pykeen dataset.
    return Wikidata5M().training.entity_to_id


def _lookup_embeddings(
    entity_ids: pd.Series, entity_embeddings: np.ndarray
) -> np.ndarray:
    """Gather one embedding row per table row, with NaN rows for unknown entities.

    Args:
        entity_ids: Row indices into ``entity_embeddings``, NaN where the
            entity is unknown.
        entity_embeddings: Embedding matrix, one row per entity (assumed 2-D).

    Returns:
        Array with ``len(entity_ids)`` rows. Rows whose id is NaN are filled
        with NaN.
    """
    known = entity_ids.notna().to_numpy()
    rows = entity_ids.to_numpy()[known].astype(np.int64)
    if known.all():
        # Nothing is missing: a plain gather keeps the embeddings' own dtype.
        return entity_embeddings[rows]

    # Same dtype that stacking embedding rows with float32 NaN rows produces.
    X_emb = np.full(
        (len(entity_ids), entity_embeddings.shape[1]),
        np.nan,
        dtype=np.result_type(entity_embeddings.dtype, np.float32),
    )
    X_emb[known] = entity_embeddings[rows]
    return X_emb


def main(data: str = "wikidata5m") -> None:
    """Embed every linked table with each model trained on the given graph.

    For every sub-folder of ``MODELS_DIR`` whose name contains ``data`` and
    that holds an ``entity_embeddings.npy`` file, each parquet table found
    under ``LINKED_DATA_DIR/<task>/`` is converted into a feature matrix
    (one embedding per row, looked up through the ``wikidata_id`` column) and
    a target vector. Both are saved as ``X_emb.npy`` and ``y.npy`` under
    ``<model folder>/embedded_tables/<task>/<table>/``.

    Rows whose entity is unknown to the graph are imputed with the column
    means of the table's embedding matrix.

    Args:
        data: Name of the knowledge graph, either ``"wikidata5m"`` or a key
            of ``data_paths``.
    """
    # Get the entity to id mapping
    entity_to_id = get_entity_to_id_mapping(data)

    # Get datasets configurations
    with open(PROCESSED_DIR / "data_configs.json", "r", encoding="utf-8") as f:
        configs = json.load(f)

    # Iterate over the folders of MODELS_DIR
    for folder in MODELS_DIR.iterdir():
        if not (folder.is_dir() and data in folder.name):
            continue
        embedding_path = folder / "entity_embeddings.npy"
        tables_output_path = folder / "embedded_tables"
        # NOTE: an existing output directory is treated as "already done", so
        # a partially written one has to be removed by hand to be redone.
        if not embedding_path.exists() or tables_output_path.is_dir():
            continue

        # Load once per model instead of once per table. Loading before any
        # output directory is created also means a failed load does not mark
        # this folder as processed.
        entity_embeddings = load_entity_embeddings(embedding_path)

        for task in LINKED_DATA_DIR.iterdir():
            if not task.is_dir():
                # Stray files next to the task folders cannot be iterated.
                continue
            output_task_path = tables_output_path / task.name
            output_task_path.mkdir(parents=True, exist_ok=True)
            for table_file in task.iterdir():
                print(f"Processing {table_file.stem}...")
                table = pd.read_parquet(table_file)

                # Unknown entities map to NaN.
                entity_ids = table["wikidata_id"].map(entity_to_id)
                X_emb = _lookup_embeddings(entity_ids, entity_embeddings)

                if entity_ids.isna().any():
                    # Impute NaN rows (unknown entity embeddings) with column means
                    imputer = SimpleImputer(
                        missing_values=np.nan,
                        strategy="mean",
                        keep_empty_features=True,
                    )
                    X_emb = imputer.fit_transform(X_emb)

                y = table[configs[table_file.stem]["target"]].to_numpy()
                # Add party column (only for carte_us_presidential)
                if table_file.stem == "carte_us_presidential":
                    enc_col = pd.get_dummies(table["party"], prefix="party")
                    X_emb = np.hstack([X_emb, enc_col.to_numpy()])
                # Save X_emb and y
                dataset_output_path = output_task_path / table_file.stem
                dataset_output_path.mkdir(parents=True, exist_ok=True)
                np.save(dataset_output_path / "X_emb.npy", X_emb)
                np.save(dataset_output_path / "y.npy", y)
                # Print shapes
                print(f"  X shape: {X_emb.shape}, y shape: {y.shape}")


if __name__ == "__main__":
    # Process the graph subsets from smallest to largest, ending with the
    # full Wikidata5M graph.
    for dataset_name in (
        "wikidata500k",
        "wikidata1m",
        "wikidata2m",
        "wikidata3m",
        "wikidata5m",
    ):
        main(dataset_name)
