import pandas as pd
import numpy as np
from sklearn.preprocessing import OrdinalEncoder
import json


# Columns removed from the "diabetes" dataset before further processing.
_DIABETES_DROP_COLUMNS = [
    "encounter_id",
    "patient_nbr",
    "weight",
    "medical_specialty",
    "payer_code",
    "acetohexamide",
    "troglitazone",
    "examide",
    "citoglipton",
    "glipizide-metformin",
    "glimepiride-pioglitazone",
    "metformin-rosiglitazone",
    "metformin-pioglitazone",
]

# Columns removed from the "lending" dataset before further processing.
_LENDING_DROP_COLUMNS = [
    "annual_income_joint",
    "verification_income_joint",
    "debt_to_income_joint",
    "months_since_last_delinq",
    "months_since_90d_late",
    "current_accounts_delinq",
    "months_since_last_credit_inquiry",
    "num_accounts_120d_past_due",
    "num_accounts_30d_past_due",
    "issue_month",
    "paid_late_fees",
]

# Characters deleted from categorical values so that variants differing only in
# spacing/punctuation (e.g. "self-employed" vs "self employed") collapse together.
_CATEGORY_NOISE_PATTERN = r"[ \-,_'\.\:;\|\\/]"

# Normalised values treated as missing and replaced by "empty". Every entry must
# be lowercase and free of the characters in _CATEGORY_NOISE_PATTERN, because the
# comparison happens AFTER lower-casing and stripping those characters.
_MISSING_TOKENS = ["nan", "null", "undefined", "unknown", "", "?", "na", "missing"]


def _apply_dataset_specific_steps(name: str, df: pd.DataFrame, config: dict) -> pd.DataFrame:
    """Apply the cleaning steps specific to dataset *name*; other names pass through."""
    if name == "beijing":
        df = df.drop(columns=["No"])
    elif name == "covertype":
        # Binarise the target: class 2 -> 1, every other class -> 0.
        df[config["target"]] = np.where(df[config["target"]] == 2, 1, 0)
    elif name == "default":
        df = df.drop(columns=["ID"])
    elif name == "diabetes":
        df = df.drop(columns=_DIABETES_DROP_COLUMNS)
        # Target becomes 0 for "NO" and 1 for any other value.
        df[config["target"]] = np.where(df[config["target"]] == "NO", 0, 1)
        # Keep characters [-4:-1] of each age string, drop any "-", cast to int.
        # Presumably turns bracketed ranges such as "[70-80)" into their upper
        # bound (80) -- TODO confirm against the raw data.
        df["age"] = df["age"].str[-4:-1].str.replace("-", "").astype(int)
    elif name == "lending":
        df = df.drop(columns=_LENDING_DROP_COLUMNS)
        df["sub_grade"] = OrdinalEncoder().fit_transform(df[["sub_grade"]]).squeeze()
        df["grade"] = OrdinalEncoder().fit_transform(df[["grade"]]).squeeze()
    elif name == "nmes":
        df = df.drop(columns=["Unnamed: 0"])
        # Explicit category order so the codes are poor=0, average=1, excellent=2.
        cats = [["poor", "average", "excellent"]]
        df["health"] = (
            OrdinalEncoder(categories=cats).fit_transform(df[["health"]]).squeeze()
        )
    elif name == "news":
        df = df.drop(columns=["url"])
    return df


def _normalize_categoricals(df: pd.DataFrame, cat_features: list) -> pd.DataFrame:
    """Collapse spelling variants in *cat_features* and map missing markers to "empty"."""
    for col in cat_features:
        df[col] = (
            df[col]
            .str.lower()
            .str.strip()
            .str.replace(_CATEGORY_NOISE_PATTERN, "", regex=True)
        )
        df.loc[df[col].isin(_MISSING_TOKENS), col] = "empty"
    return df


def load_dataset(name: str, config: dict):
    """Read dataset *name* from CSV and return a cleaned DataFrame.

    Args:
        name: Dataset identifier. "bank" is read with ";" as separator. "beijing",
            "covertype", "default", "diabetes", "lending", "nmes" and "news"
            get extra dataset-specific cleaning. Any other name is only read,
            cast and normalised.
        config: Per-dataset settings. Reads "path" (CSV location) and
            "cat_features" (categorical column names). "target" is also read for
            "covertype" and "diabetes".

    Returns:
        The DataFrame with categorical columns as normalised strings
        (lower-cased, stripped, punctuation/space removed, missing markers
        mapped to "empty") and every other column cast to float.

    Raises:
        KeyError: if *config* lacks a required key or an expected column is
            missing from the CSV.
    """
    sep = ";" if name == "bank" else ","
    df = pd.read_csv(config["path"], sep=sep)

    df = _apply_dataset_specific_steps(name, df, config)

    # Cast data types to ensure consistent behaviour downstream: categoricals as
    # strings (NaN becomes "nan", later mapped to "empty"), everything else float.
    cat_features = config["cat_features"]
    df[cat_features] = df[cat_features].astype(str)
    numerical_features = [col for col in df.columns if col not in cat_features]
    df[numerical_features] = df[numerical_features].astype(float)

    return _normalize_categoricals(df, cat_features)


def get_datasets_info(datasets: list[str], config_path: str):
    """Summarise each dataset in *datasets* as one row of an all-string table.

    Args:
        datasets: Dataset names; each must be a key of the JSON config and is
            loaded with ``load_dataset``.
        config_path: Path to a JSON file mapping dataset name -> config dict
            (needs at least "target" and "cat_features", plus whatever
            ``load_dataset`` reads).

    Returns:
        DataFrame sorted by "dataset", with columns "dataset", "task", "n",
        "n-categoricals", "n-numericals", "min-n-categories" and
        "max-n-categories". Every cell is a string, and underscores in cell
        values are replaced by spaces.
        - "task" is "classification" when the target is listed in
          "cat_features", else "regression".
        - "n" counts rows with no missing value in the non-categorical columns.
        - "min-n-categories" / "max-n-categories" are 0 when the dataset has no
          categorical features.

    Raises:
        KeyError: if a dataset name is not present in the config file.

    Prints a progress line per dataset.
    """
    columns = [
        "dataset",
        "task",
        "n",
        "n-categoricals",
        "n-numericals",
        "min-n-categories",
        "max-n-categories",
    ]

    # Parse the config once and close the file, instead of re-opening per dataset.
    with open(config_path, encoding="utf-8") as fh:
        all_configs = json.load(fh)

    rows = []
    for ds in datasets:
        print(f"Processing {ds}...")
        data_config = all_configs[ds]
        cat_features = data_config["cat_features"]
        X = load_dataset(ds, data_config)
        num_features = [col for col in X.columns if col not in cat_features]
        # dropna=False so a NaN level counts as a category, like len(unique()).
        category_counts = [X[col].nunique(dropna=False) for col in cat_features]
        rows.append(
            {
                "dataset": ds,
                "task": (
                    "classification"
                    if data_config["target"] in cat_features
                    else "regression"
                ),
                "n": X.dropna(subset=num_features).shape[0],
                "n-categoricals": len(cat_features),
                "n-numericals": X.shape[1] - len(cat_features),
                # default=0 avoids an inf sentinel that would coerce the whole
                # column to float (rendering counts as "3.0").
                "min-n-categories": min(category_counts, default=0),
                "max-n-categories": max(category_counts, default=0),
            }
        )

    # Build the frame once; repeated pd.concat onto an empty frame is quadratic
    # and triggers pandas' empty-entry concat deprecation warning.
    df = pd.DataFrame(rows, columns=columns).astype(str)
    df = df.replace("_", " ", regex=True)
    return df.sort_values(by="dataset")
