import pandas as pd


def _load_meps_dataset(name):
    """Load a ``meps_*`` regression CSV and return it as ``(X, y)`` arrays.

    Reads ``datasets/{name}_reg.csv`` (resolved against the current working
    directory) and selects a fixed, ordered set of columns from it.

    Args:
        name: Dataset identifier used to build the CSV file name.
            ``"meps_19"`` and ``"meps_20"`` take their weight column from
            ``PERWT15F``; every other value takes it from ``PERWT16F``.

    Returns:
        A tuple ``(X, y)``. ``X`` is a 2-D array with one row per CSV row and
        139 columns: the 138 columns listed below, in that order, followed by
        the weight column. ``y`` is a 1-D array holding ``UTILIZATION_reg``.

    Raises:
        FileNotFoundError: If the CSV file does not exist.
        KeyError: If the response column or any selected column is missing.
    """
    table = pd.read_csv(f"datasets/{name}_reg.csv")

    # The weight column name differs between the 19/20 files and the rest.
    weight_column = "PERWT15F" if name in ["meps_19", "meps_20"] else "PERWT16F"

    # One group of literals per variable; the order defines the column order
    # of X.  Names of the form VAR=level look like one-hot indicators --
    # NOTE(review): confirm against the CSV schema.
    # fmt: off
    selected_columns = [
        "AGE", "PCS42", "MCS42", "K6SUM42",
        "REGION=1", "REGION=2", "REGION=3", "REGION=4",
        "SEX=1", "SEX=2",
        "MARRY=1", "MARRY=2", "MARRY=3", "MARRY=4", "MARRY=5",
        "MARRY=6", "MARRY=7", "MARRY=8", "MARRY=9", "MARRY=10",
        "FTSTU=-1", "FTSTU=1", "FTSTU=2", "FTSTU=3",
        "ACTDTY=1", "ACTDTY=2", "ACTDTY=3", "ACTDTY=4",
        "HONRDC=1", "HONRDC=2", "HONRDC=3", "HONRDC=4",
        "RTHLTH=-1", "RTHLTH=1", "RTHLTH=2", "RTHLTH=3", "RTHLTH=4", "RTHLTH=5",
        "MNHLTH=-1", "MNHLTH=1", "MNHLTH=2", "MNHLTH=3", "MNHLTH=4", "MNHLTH=5",
        "HIBPDX=-1", "HIBPDX=1", "HIBPDX=2",
        "CHDDX=-1", "CHDDX=1", "CHDDX=2",
        "ANGIDX=-1", "ANGIDX=1", "ANGIDX=2",
        "MIDX=-1", "MIDX=1", "MIDX=2",
        "OHRTDX=-1", "OHRTDX=1", "OHRTDX=2",
        "STRKDX=-1", "STRKDX=1", "STRKDX=2",
        "EMPHDX=-1", "EMPHDX=1", "EMPHDX=2",
        "CHBRON=-1", "CHBRON=1", "CHBRON=2",
        "CHOLDX=-1", "CHOLDX=1", "CHOLDX=2",
        "CANCERDX=-1", "CANCERDX=1", "CANCERDX=2",
        "DIABDX=-1", "DIABDX=1", "DIABDX=2",
        "JTPAIN=-1", "JTPAIN=1", "JTPAIN=2",
        "ARTHDX=-1", "ARTHDX=1", "ARTHDX=2",
        "ARTHTYPE=-1", "ARTHTYPE=1", "ARTHTYPE=2", "ARTHTYPE=3",
        "ASTHDX=1", "ASTHDX=2",
        "ADHDADDX=-1", "ADHDADDX=1", "ADHDADDX=2",
        "PREGNT=-1", "PREGNT=1", "PREGNT=2",
        "WLKLIM=-1", "WLKLIM=1", "WLKLIM=2",
        "ACTLIM=-1", "ACTLIM=1", "ACTLIM=2",
        "SOCLIM=-1", "SOCLIM=1", "SOCLIM=2",
        "COGLIM=-1", "COGLIM=1", "COGLIM=2",
        "DFHEAR42=-1", "DFHEAR42=1", "DFHEAR42=2",
        "DFSEE42=-1", "DFSEE42=1", "DFSEE42=2",
        "ADSMOK42=-1", "ADSMOK42=1", "ADSMOK42=2",
        "PHQ242=-1", "PHQ242=0", "PHQ242=1", "PHQ242=2",
        "PHQ242=3", "PHQ242=4", "PHQ242=5", "PHQ242=6",
        "EMPST=-1", "EMPST=1", "EMPST=2", "EMPST=3", "EMPST=4",
        "POVCAT=1", "POVCAT=2", "POVCAT=3", "POVCAT=4", "POVCAT=5",
        "INSCOV=1", "INSCOV=2", "INSCOV=3",
        "RACE",
        weight_column,  # always the last column of X
    ]
    # fmt: on

    targets = table["UTILIZATION_reg"].values
    features = table[selected_columns].values
    return features, targets


def _load_cal_housing_dataset():
    """Load ``datasets/cal_housing.csv`` as standardized ``(X, y)`` arrays.

    The file path is resolved against the current working directory. Rows
    with a missing value in any column of the CSV are dropped before any
    statistics are computed. The target and every feature column are then
    centred on their own mean and divided by their own standard deviation
    (pandas' default, i.e. the sample standard deviation).

    Returns:
        A tuple ``(X, y)``. ``X`` is a 2-D array with the eight feature
        columns in the order listed below; ``y`` is a 1-D array built from
        ``median_house_value``.
    """
    feature_columns = [
        "longitude",
        "latitude",
        "housing_median_age",
        "total_rooms",
        "total_bedrooms",
        "population",
        "households",
        "median_income",
    ]

    # Drop incomplete rows first so means and deviations only see full rows.
    complete_rows = pd.read_csv("datasets/cal_housing.csv").dropna()

    target = complete_rows["median_house_value"]
    target_scaled = (target - target.mean()) / target.std()
    target_array = target_scaled.values

    # Column-wise statistics: each feature is scaled independently.
    features = complete_rows[feature_columns]
    features_scaled = (features - features.mean()) / features.std()

    return features_scaled.values, target_array


def load_dataset(dataset_name):
    """Return the ``(X, y)`` arrays for the dataset called ``dataset_name``.

    Args:
        dataset_name: One of ``"meps_19"``, ``"meps_20"``, ``"meps_21"`` or
            ``"cal_housing"``.

    Returns:
        The ``(X, y)`` tuple produced by the matching loader.

    Raises:
        NotImplementedError: If ``dataset_name`` is not one of the names
            listed above.
    """
    meps_names = ("meps_19", "meps_20", "meps_21")

    if dataset_name in meps_names:
        # The MEPS loader needs the name to pick the CSV and weight column.
        loaded = _load_meps_dataset(dataset_name)
    elif dataset_name == "cal_housing":
        loaded = _load_cal_housing_dataset()
    else:
        raise NotImplementedError(f"Dataset {dataset_name} not implemented")

    return loaded
