import pandas as pd
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from fairlearn.datasets import fetch_acs_income
from sklearn.preprocessing import FunctionTransformer
from sklearn.decomposition import PCA
from .data_util import save_processed_data
from .data_util import read_processed_data
# ----------------------------------------------------
#               Load COMPAS dataset
# ----------------------------------------------------  

def load_compas():
    """
    Build the COMPAS recidivism benchmark from the local ProPublica CSV.

    Rows are kept only when the screening happened within 30 days of the
    arrest (either direction), the recidivism flag is known (``is_recid``
    is not -1), the charge degree is 'F' or 'M', and the defendant's race
    is 'African-American' or 'Caucasian'.

    Returns
    -------
    X_processed :
        Output of the ColumnTransformer: standardized numeric columns
        followed by one-hot encoded categorical columns.
    y : pandas.Series of int
        ``two_year_recid`` (1 = recidivated within two years).
    s : pandas.Series of int
        1 for African-American, 0 for Caucasian.
    alpha : float
        Mean of ``s``.

    The same four values are handed to ``save_processed_data`` under the
    name 'compas' before returning.
    """
    print("Loading COMPAS dataset...")

    raw = pd.read_csv('data/compas-scores-two-years.csv')

    # ProPublica-style row filter; `between` is inclusive on both ends.
    keep_row = (
        raw.days_b_screening_arrest.between(-30, 30)
        & (raw.is_recid != -1)
        & raw.c_charge_degree.isin(['F', 'M'])
        & raw.race.isin(['African-American', 'Caucasian'])
    )
    df = raw.loc[keep_row].copy()

    # Protected attribute and label
    s = df['race'].eq('African-American').astype(int)
    y = df['two_year_recid'].astype(int)

    # Feature frame
    X = df[[
        'age', 'c_charge_degree', 'c_charge_desc', 'priors_count',
        'juv_fel_count', 'juv_misd_count', 'juv_other_count',
    ]].copy()

    # Column roles are decided by dtype
    num_cols = X.select_dtypes(include=['int64', 'float64']).columns
    cat_cols = X.select_dtypes(include=['object']).columns

    scale_step = ('num', StandardScaler(), num_cols)
    onehot_step = ('cat', OneHotEncoder(handle_unknown='ignore'), cat_cols)
    preprocessor = ColumnTransformer(transformers=[scale_step, onehot_step])
    X_processed = preprocessor.fit_transform(X)

    alpha = s.mean()
    print(f"African-American ratio (alpha): {alpha:.4f}")

    # Persist the processed data for later cached loads
    save_processed_data('compas', X_processed, y, s, alpha)

    return X_processed, y, s, alpha

# ----------------------------------------------------
#               Load German Credit dataset
# ----------------------------------------------------  

def load_german_credit():
    """
    Load the German Credit CSV and derive a binary fairness task from it.

    - Protected attribute: age > 25 → 1, else 0
    - Label: credit_amount above its median → 1, else 0
    - Features: every column except credit_amount; 'age' and 'duration'
      are standardized, the remaining columns are one-hot encoded

    Returns ``(X, y, s, alpha)`` where ``alpha`` is the mean of ``s``; the
    same values are also passed to ``save_processed_data`` under the name
    'german_credit'.
    """
    print("Loading German Credit dataset...")

    column_names = [
        'age', 'sex', 'job', 'housing', 'saving_accounts',
        'checking_account', 'credit_amount', 'duration', 'purpose',
    ]
    integer_columns = ('age', 'credit_amount', 'duration')
    dtype_dict = {
        name: ('int64' if name in integer_columns else 'str')
        for name in column_names
    }

    # The file's own header row is skipped because we supply our own names;
    # otherwise the header text would be parsed as a data row.
    df = pd.read_csv(
        'data/german_credit.csv',
        names=column_names,
        dtype=dtype_dict,
        skiprows=1,
    )

    s = df['age'].gt(25).astype(int)
    amount = df['credit_amount']
    y = amount.gt(amount.median()).astype(int)

    # The raw amount defines the label, so it must not stay in the features.
    features = df.drop(columns=['credit_amount'])
    num_feats = ['age', 'duration']
    cat_feats = ['sex', 'job', 'housing', 'saving_accounts', 'checking_account', 'purpose']

    pre = ColumnTransformer([
        ('num', StandardScaler(), num_feats),
        ('cat', OneHotEncoder(handle_unknown='ignore'), cat_feats),
    ])
    X = pre.fit_transform(features)

    alpha = s.mean()
    print(f"Age > 25 ratio (alpha): {alpha:.4f}")

    # Persist the processed data for later cached loads
    save_processed_data('german_credit', X, y, s, alpha)

    return X, y, s, alpha

# ----------------------------------------------------
#              Load Bank Marketing dataset
# ----------------------------------------------------
    
def load_bank_marketing():
    """
    Load the Bank Marketing CSV ('data/bank.csv') for fairness experiments.

    - Protected attribute: age ≥ 25 → 1, age < 25 → 0
    - Label: column 'y' ('yes' → 1, anything else → 0)
    - Numeric features (standardized): age, duration, campaign, pdays,
      previous, emp_var_rate, cons_price_idx, cons_conf_idx, euribor3m,
      nr_employed
    - Categorical features (one-hot encoded): job, marital, education,
      default, housing, loan, contact, month, day_of_week, poutcome

    Returns ``(X, y, s, alpha)`` with ``alpha`` the mean of ``s``; the same
    values are passed to ``save_processed_data`` as 'bank_marketing'.

    NOTE(review): the 21-column layout with the macro-economic indicators
    looks like the UCI "bank-additional" variant rather than
    'bank-full.csv' — confirm which file is expected at 'data/bank.csv'.
    """
    print("Loading Bank Marketing dataset...")

    # Column roles; the dtype map and the transformer both derive from these.
    integer_features = ['age', 'duration', 'campaign', 'pdays', 'previous']
    float_features = [
        'emp_var_rate', 'cons_price_idx', 'cons_conf_idx',
        'euribor3m', 'nr_employed',
    ]
    numerical_features = integer_features + float_features
    categorical_features = [
        'job', 'marital', 'education', 'default', 'housing', 'loan',
        'contact', 'month', 'day_of_week', 'poutcome',
    ]

    # Positional names for the file's columns (order matters here).
    column_names = [
        'age', 'job', 'marital', 'education', 'default', 'housing', 'loan',
        'contact', 'month', 'day_of_week', 'duration', 'campaign', 'pdays',
        'previous', 'poutcome', 'emp_var_rate', 'cons_price_idx',
        'cons_conf_idx', 'euribor3m', 'nr_employed', 'y',
    ]

    dtype_dict = {name: str for name in categorical_features + ['y']}
    dtype_dict.update({name: int for name in integer_features})
    dtype_dict.update({name: float for name in float_features})

    # ';'-delimited file; its header row is skipped since names are supplied.
    df = pd.read_csv(
        'data/bank.csv',
        names=column_names,
        dtype=dtype_dict,
        sep=';',
        skiprows=1,
    )

    # Protected attribute and label
    s = df['age'].ge(25).astype(int)
    y = df['y'].eq('yes').astype(int)

    # Feature frame without the raw label
    features = df.drop(columns=['y'])

    preprocessor = ColumnTransformer([
        ('num', StandardScaler(), numerical_features),
        ('cat', OneHotEncoder(handle_unknown='ignore'), categorical_features),
    ])
    X = preprocessor.fit_transform(features)

    alpha = s.mean()
    print(f"Age ≥ 25 ratio (alpha): {alpha:.4f}")

    # Persist the processed data for later cached loads
    save_processed_data('bank_marketing', X, y, s, alpha)

    return X, y, s, alpha

# ----------------------------------------------------
#              Load Communities and Crime dataset
# ----------------------------------------------------
def load_communities_crime():
    """
    Load the Communities & Crime CSV, treating every feature as numeric.

    NOTE: this module defines a second function with the same name further
    down. Python binds the name to the later definition, so this version is
    shadowed and is not what ``load_dataset`` dispatches to.

    - Protected attribute: racepctblack above its median → 1, else 0
    - Label: ViolentCrimesPerPop above its median → 1, else 0
    - ID-like columns (communityname, state, countyCode, communityCode,
      fold) are dropped when present
    - '?' is turned into NaN and every column is coerced to numeric
    - All features are median-imputed and standardized

    Returns ``(X, y, s, alpha)`` with ``alpha`` the mean of ``s``; the same
    values are passed to ``save_processed_data`` as 'communities_crime'.
    """
    print("Loading Communities & Crime dataset...")

    # latin1 is needed because the first header cell is mis-encoded
    table = pd.read_csv('data/crimedata.csv', encoding='latin1')

    # Give the first column a usable name regardless of how it decoded
    table = table.rename(columns={table.columns[0]: 'communityname'})

    # Remove identifier columns that exist in this file
    id_like = ['communityname', 'state', 'countyCode', 'communityCode', 'fold']
    table = table.drop(columns=[name for name in id_like if name in table.columns])

    # '?' marks a missing value; anything still non-numeric becomes NaN
    table = table.replace('?', np.nan).apply(pd.to_numeric, errors='coerce')

    # Protected attribute and label, both thresholded at the column median
    black_share = table['racepctblack']
    crime_rate = table['ViolentCrimesPerPop']
    s = black_share.gt(black_share.median()).astype(int)
    y = crime_rate.gt(crime_rate.median()).astype(int)

    # The raw target must not remain among the features
    features = table.drop(columns=['ViolentCrimesPerPop'])

    impute_then_scale = Pipeline([
        ('imputer', SimpleImputer(strategy='median')),
        ('scaler', StandardScaler()),
    ])
    preprocessor = ColumnTransformer(
        transformers=[('num', impute_then_scale, features.columns)]
    )
    X = preprocessor.fit_transform(features)

    alpha = s.mean()
    print(f"Protected-group ratio (α): {alpha:.4f}")

    # Persist the processed data for later cached loads
    save_processed_data('communities_crime', X, y, s, alpha)

    return X, y, s, alpha

# ----------------------------------------------------
#              Load Communities and Crime dataset (overrides the definition above)
# ----------------------------------------------------

def load_communities_crime():
    """
    Load the Communities & Crime CSV with fairness preprocessing.

    - Protected attribute: racepctblack above its median → 1, else 0
    - Label: ViolentCrimesPerPop above its median → 1, else 0
    - ID-like columns (communityname, state, countyCode, communityCode,
      fold) are dropped when present
    - '?' is turned into NaN and every column is coerced to numeric
    - Numeric features: median imputation followed by standardization
    - Object/category features: one-hot encoded (after the numeric coercion
      there are normally none left)

    Returns ``(X_processed, y, s, alpha)`` with ``alpha`` the mean of ``s``;
    the same values go to ``save_processed_data`` as 'communities_crime'.
    """
    print("Loading Communities & Crime dataset...")

    # latin1 is needed because the first header cell is mis-encoded
    frame = pd.read_csv('data/crimedata.csv', encoding='latin1')

    # Give the first column a usable name regardless of how it decoded
    frame = frame.rename(columns={frame.columns[0]: 'communityname'})

    # Remove identifier columns that exist in this file
    id_like = ['communityname', 'state', 'countyCode', 'communityCode', 'fold']
    present_ids = [name for name in id_like if name in frame.columns]
    frame = frame.drop(columns=present_ids)

    # '?' marks a missing value; anything still non-numeric becomes NaN
    frame = frame.replace('?', np.nan)
    frame = frame.apply(pd.to_numeric, errors='coerce')

    # Protected attribute and label, both thresholded at the column median
    black_share = frame['racepctblack']
    crime_rate = frame['ViolentCrimesPerPop']
    s = black_share.gt(black_share.median()).astype(int)
    y = crime_rate.gt(crime_rate.median()).astype(int)

    # The raw target must not remain among the features
    X = frame.drop(columns=['ViolentCrimesPerPop'])

    # Column roles are decided by dtype
    numeric_features = X.select_dtypes(include=[np.number]).columns.tolist()
    categorical_features = X.select_dtypes(include=['object', 'category']).columns.tolist()

    numeric_steps = Pipeline([
        ('imputer', SimpleImputer(strategy='median')),
        ('scaler', StandardScaler()),
    ])
    preprocessor = ColumnTransformer(
        transformers=[
            ('num', numeric_steps, numeric_features),
            ('cat', OneHotEncoder(handle_unknown='ignore'), categorical_features),
        ],
        remainder='drop',  # columns in neither list are discarded
    )
    X_processed = preprocessor.fit_transform(X)

    alpha = s.mean()
    print(f"Protected-group ratio (α): {alpha:.4f}")

    # Persist the processed data for later cached loads
    save_processed_data('communities_crime', X_processed, y, s, alpha)

    return X_processed, y, s, alpha


# ----------------------------------------------------
#              Load Adult Census Income dataset
# ----------------------------------------------------

def load_adult_dataset():
    """
    Load the Adult Census Income dataset (OpenML 'adult', version 2).

    - Protected attribute: the first column whose name contains 'sex' or
      'gender'; 'Male' → 1, anything else → 0
    - Label: income == '>50K' → 1, else 0
    - Features: all remaining columns; int64/float64 columns are
      standardized, object columns are one-hot encoded

    Returns ``(X, y, s, alpha)`` with ``alpha`` the mean of ``s``; the same
    values are passed to ``save_processed_data`` as 'adult'.

    Raises
    ------
    IndexError
        If no column name contains 'sex' or 'gender'.
    """
    print("Loading Adult dataset...")

    # Fetch features and target and join them into one frame
    data = fetch_openml(name='adult', version=2, as_frame=True)
    df = pd.concat([data.data, data.target.rename('income')], axis=1)

    # Cast object/category columns to plain strings so the equality tests
    # below compare against str values. (The original also collected these
    # column names into a list that was overwritten before use; that dead
    # store has been removed.)
    for col in df.columns:
        if df[col].dtype == 'object' or df[col].dtype.name == 'category':
            df[col] = df[col].astype(str)

    # Locate the gender column by name
    gender_column = [
        col for col in df.columns
        if 'sex' in col.lower() or 'gender' in col.lower()
    ][0]

    # Protected attribute and label
    s = (df[gender_column] == 'Male').astype(int)
    y = (df['income'] == '>50K').astype(int)

    # Neither the protected attribute nor the label stays in the features
    X = df.drop([gender_column, 'income'], axis=1)

    # Column roles are decided by dtype
    numerical_features = X.select_dtypes(include=['int64', 'float64']).columns
    categorical_features = X.select_dtypes(include=['object']).columns

    preprocessor = ColumnTransformer(
        transformers=[
            ('num', StandardScaler(), numerical_features),
            ('cat', OneHotEncoder(handle_unknown='ignore'), categorical_features)
        ])
    X = preprocessor.fit_transform(X)

    alpha = s.mean()
    print(f"Male ratio (alpha): {alpha:.4f}")

    # Persist the processed data for later cached loads
    save_processed_data('adult', X, y, s, alpha)

    return X, y, s, alpha



# def load_adult_dataset():
#     """
#     Load Adult Census Income dataset and keep only continuous features:
#       - Protected attribute: gender (male=1, female=0)
#       - Label: income >50K (1) vs ≤50K (0)
#       - Features: only numeric columns, standardized.
#     """
#     print("Loading Adult dataset...")
    
#     # 1) Load raw data
#     data = fetch_openml(name='adult', version=2, as_frame=True)
#     df = pd.concat([data.data, data.target.rename('income')], axis=1)
    
#     # 2) Extract protected attribute and label
#     #    (find the 'sex' or 'gender' column, binarize)
#     gender_col = [c for c in df if 'sex' in c.lower() or 'gender' in c.lower()][0]
#     s = (df[gender_col] == 'Male').astype(int).to_numpy()
#     y = (df['income'] == '>50K').astype(int).to_numpy()
    
#     # 3) Drop protected + label from feature frame
#     X_df = df.drop([gender_col, 'income'], axis=1)
    
#     # 4) Keep only numeric columns
#     numeric_cols = X_df.select_dtypes(include=[np.number]).columns
#     X_num = X_df[numeric_cols].to_numpy()
    
#     # 5) Standardize
#     scaler = StandardScaler()
#     X_scaled = scaler.fit_transform(X_num)
    
#     # 6) Compute alpha
#     alpha = s.mean()
#     print(f"Male ratio (alpha): {alpha:.4f}")
    
#     # 7) Save and return
#     save_processed_data('adult', X_scaled, y, s, alpha)
#     return X_scaled, y, s, alpha


# ----------------------------------------------------
#              Load Student Performance dataset
# ----------------------------------------------------

def load_student_performance(subject='mat'):
    """
    Load a Student Performance CSV with fairness preprocessing.

    Parameters
    ----------
    subject : str, default 'mat'
        Suffix of the file to read: ``data/student-{subject}.csv``.

    Details
    -------
    - Protected attribute: sex == 'M' → 1, else 0
    - Label: G3 >= 10 → 1 (pass), else 0
    - G1, G2, G3 and sex are removed from the features
    - int64/float64 columns are standardized; object columns are one-hot
      encoded

    Returns ``(X, y, s, alpha)`` with ``alpha`` the mean of ``s``. The same
    values are passed to ``save_processed_data`` as 'student_performance'
    regardless of ``subject``.
    """
    print(f"Loading Student Performance (subject={subject})…")

    # The files are semicolon-delimited
    records = pd.read_csv(f"data/student-{subject}.csv", sep=';')

    # Protected attribute and pass/fail label
    s = records['sex'].eq('M').astype(int)
    y = records['G3'].ge(10).astype(int)

    # Grades would leak the label; sex is the protected attribute
    features = records.drop(columns=['G1', 'G2', 'G3', 'sex'])

    # Column roles are decided by dtype
    numeric_features = features.select_dtypes(include=['int64', 'float64']).columns.tolist()
    categorical_features = features.select_dtypes(include=['object']).columns.tolist()

    scale_step = ('num', StandardScaler(), numeric_features)
    onehot_step = ('cat', OneHotEncoder(handle_unknown='ignore'), categorical_features)
    preprocessor = ColumnTransformer(transformers=[scale_step, onehot_step])
    X = preprocessor.fit_transform(features)

    alpha = s.mean()
    print(f"Privileged (male) ratio α: {alpha:.4f}")

    # Persist the processed data for later cached loads
    save_processed_data('student_performance', X, y, s, alpha)

    return X, y, s, alpha

# ----------------------------------------------------
#              Load Heritage Health Prize dataset
# ----------------------------------------------------

def load_heritage_health():
    """
    Load the Heritage Health Prize tables (Year 1 claims → Year 2 outcome).

    Three CSVs under 'data/heritage/' are combined on MemberID: the member
    table, the Year-1 claims (aggregated per member) and the Year-2
    hospital days.

    - Protected attribute: sex == 'M' → 1, else 0
    - Label: DaysInHospital_Y2 (missing → 0) above its median → 1, else 0
    - Numeric features (median-imputed, standardized): claim_count,
      avg_charlson, total_los
    - Categorical features (one-hot encoded): AgeAtFirstClaim, mode_pc
      (the member's most frequent PrimaryConditionGroup)

    Returns ``(X_processed, y, s, alpha)`` with ``alpha`` the mean of
    ``s``; the same values go to ``save_processed_data`` as
    'heritage_health'.
    """
    print("Loading Heritage Health Prize dataset…")

    # Raw tables
    members = pd.read_csv('data/heritage/Members_Y1.csv')
    claims = pd.read_csv('data/heritage/Claims_Y1.csv')
    day2 = pd.read_csv('data/heritage/DayInHospital_Y2.csv')

    # Coerce the two claim columns used numerically: non-numeric Charlson
    # values become NaN, and only the first run of digits of LengthOfStay
    # is kept.
    claims['CharlsonIndex'] = pd.to_numeric(claims['CharlsonIndex'], errors='coerce')
    claims['LengthOfStay'] = claims['LengthOfStay'].astype(str).str.extract(r'(\d+)').astype(float)

    def _most_frequent(values):
        # First mode of the group, or NaN when the group has no mode
        modes = values.mode()
        return modes.iloc[0] if not modes.empty else np.nan

    # One row per member summarising the Year-1 claims
    per_member = (
        claims
        .groupby('MemberID')
        .agg(
            claim_count=('MemberID', 'size'),
            avg_charlson=('CharlsonIndex', 'mean'),
            total_los=('LengthOfStay', 'sum'),
            mode_pc=('PrimaryConditionGroup', _most_frequent),
        )
        .reset_index()
    )

    # Left joins keep every member, even those without claims or outcome
    df = members.merge(per_member, on='MemberID', how='left')
    df = df.merge(day2, on='MemberID', how='left')

    # Protected attribute (privileged = male)
    s = df['sex'].eq('M').astype(int)

    # Label: members without a Year-2 record count as zero days
    df['DaysInHospital_Y2'] = df['DaysInHospital_Y2'].fillna(0)
    days_y2 = df['DaysInHospital_Y2']
    y = days_y2.gt(days_y2.median()).astype(int)

    # Feature frame; the outcome column is deliberately not included
    numeric_features = ['claim_count', 'avg_charlson', 'total_los']
    categorical_features = ['AgeAtFirstClaim', 'mode_pc']
    X = df[[
        'AgeAtFirstClaim',
        'mode_pc',
        'claim_count',
        'avg_charlson',
        'total_los',
    ]].copy()

    numeric_steps = Pipeline([
        ('imputer', SimpleImputer(strategy='median')),
        ('scaler', StandardScaler()),
    ])
    preprocessor = ColumnTransformer(
        transformers=[
            ('num', numeric_steps, numeric_features),
            ('cat', OneHotEncoder(handle_unknown='ignore'), categorical_features),
        ],
        remainder='drop',
    )
    X_processed = preprocessor.fit_transform(X)

    alpha = s.mean()
    print(f"Privileged (male) ratio α: {alpha:.4f}")

    # Persist the processed data for later cached loads
    save_processed_data('heritage_health', X_processed, y, s, alpha)

    return X_processed, y, s, alpha

# ----------------------------------------------------
#              Load MEPS dataset    
# ----------------------------------------------------

def load_meps():
    """
    Load the MEPS 2016 full-year file ('data/h192.csv') for fairness work.

    Only the columns listed below are read from the CSV.

    - Protected attribute: SEX == 1 → s=1, anything else → s=0
    - Label: TOTEXP16 (coerced to numeric, missing → 0) above its median
      → y=1, else 0
    - SEX and TOTEXP16 are removed from the features
    - int64/float64 columns → StandardScaler
    - object/category columns → OneHotEncoder(handle_unknown='ignore')

    Returns ``(X_processed, y, s, alpha)`` with ``alpha`` the mean of
    ``s``; the same values go to ``save_processed_data`` as 'meps'.
    """
    print("Loading MEPS 2016 dataset...")

    # Column groups. The group labels are the original author's; they have
    # not been checked against the MEPS codebook.
    demographics = ['AGE16X', 'RACEV2X', 'REGION16', 'MARRY16X']
    socio_economic = [
        'EDUCYR', 'POVCAT16', 'POVLEV16', 'INSCOV16',
        'EMPST31', 'EMPST42', 'EMPST53',
        'FTSTU16X',
        'WAGEP16X', 'BUSNP16X', 'UNEMP16X', 'WCMPP16X', 'INTRP16X',
        'DIVDP16X', 'IRASP16X', 'VETSP16X', 'ALIMP16X', 'CHLDP16X',
    ]
    health_limitations = ['ANYLMT16']
    # NOTE(review): HIEUIDX looks like a health-insurance-eligibility-unit
    # identifier rather than a diagnosis flag — confirm it belongs here.
    chronic_conditions = [
        'HIBPDX', 'CHDDX', 'CANCERDX', 'DIABDX', 'ARTHDX',
        'ASTHDX', 'ADHDADDX', 'EMPHDX', 'HIEUIDX',
    ]
    # NOTE(review): the original labelled these as event counts; several of
    # the *EV16 names may instead be coverage indicators — confirm.
    utilization = [
        'RXTOT16', 'PRVEV16', 'MCREV16', 'MCDEV16',
        'TRIEV16', 'OPAEV16', 'OPBEV16',
    ]
    regressors = (
        demographics
        + socio_economic
        + health_limitations
        + chronic_conditions
        + utilization
        + ['SEX', 'TOTEXP16']
    )

    # Read only the selected columns
    df = pd.read_csv("data/h192.csv", usecols=regressors)

    # Protected attribute (male = 1, privileged)
    s = df['SEX'].eq(1).astype(int)

    # Label: high expenditure relative to the median
    df['TOTEXP16'] = pd.to_numeric(df['TOTEXP16'], errors='coerce').fillna(0)
    spending = df['TOTEXP16']
    y = spending.gt(spending.median()).astype(int)

    # Remove the raw label and the protected column from the features
    excluded = [name for name in ('TOTEXP16', 'SEX') if name in df.columns]
    X = df.drop(columns=excluded)

    # Column roles are decided by dtype. A correlation-based feature filter
    # and a PCA step were tried here previously and are currently disabled.
    numeric_features = X.select_dtypes(include=['int64', 'float64']).columns.tolist()
    categorical_features = X.select_dtypes(include=['object', 'category']).columns.tolist()

    preprocessor = ColumnTransformer(
        transformers=[
            ('num', StandardScaler(), numeric_features),
            ('cat', OneHotEncoder(handle_unknown='ignore'), categorical_features),
        ],
        remainder='drop',
    )
    X_processed = preprocessor.fit_transform(X)

    alpha = s.mean()
    print(f"Male (privileged) ratio (α): {alpha:.4f}")

    # Persist the processed data for later cached loads
    save_processed_data('meps', X_processed, y, s, alpha)

    return X_processed, y, s, alpha


# ----------------------------------------------------
#              Load American Community Survey dataset
# ----------------------------------------------------

def load_acs_income(threshold=None, states=None):
    """
    Load the ACS Income dataset via Fairlearn with fairness preprocessing.

    Parameters
    ----------
    threshold : number or None, default None
        Income cut-off for the positive label. When None, the median of the
        fetched income series is used.
    states : optional
        Forwarded unchanged to ``fetch_acs_income``.

    Details
    -------
    - Protected attribute: SEX; male → 1, otherwise 0. Male is matched as
      the number 1 when the column is numeric, else as the string 'Male'.
    - Label: income > threshold → 1, else 0
    - SEX is removed from the features
    - Numeric features are standardized; object/category features are
      one-hot encoded

    Returns ``(X_processed, y, s, alpha)`` with ``alpha`` the mean of
    ``s``; the same values go to ``save_processed_data`` as 'acs_income'.
    """
    print("Loading ACS Income dataset…")

    # Features frame and income series (presumably PINCP — confirm against
    # the Fairlearn documentation)
    X_df, pincome = fetch_acs_income(
        as_frame=True,
        return_X_y=True,
        states=states,
    )

    # The encoding of SEX differs between versions: 1/2 or 'Male'/'Female'
    sex = X_df['SEX']
    male_code = 1 if pd.api.types.is_numeric_dtype(sex) else 'Male'
    s = (sex == male_code).astype(int)

    # Binary label from the income threshold
    if threshold is not None:
        thr = threshold
    else:
        thr = pincome.median()
    y = (pincome > thr).astype(int)

    # Only the protected attribute is removed from the features
    X = X_df.drop(columns=['SEX'])

    # Column roles are decided by dtype
    numeric_features = X.select_dtypes(include=[np.number]).columns.tolist()
    categorical_features = X.select_dtypes(include=['object', 'category']).columns.tolist()

    numeric_steps = Pipeline([('scale', StandardScaler())])
    preprocessor = ColumnTransformer(
        [
            ('num', numeric_steps, numeric_features),
            ('cat', OneHotEncoder(handle_unknown='ignore'), categorical_features),
        ],
        remainder='drop',
    )
    X_processed = preprocessor.fit_transform(X)

    alpha = s.mean()
    print(f"Male (privileged) ratio α: {alpha:.4f}")

    # Persist the processed data for later cached loads
    save_processed_data('acs_income', X_processed, y, s, alpha)

    return X_processed, y, s, alpha

# ----------------------------------------------------
#              Load FICO HELOC dataset
# ----------------------------------------------------

def load_heloc():
    """
    Load the FICO HELOC dataset with fairness preprocessing.

    - Label: RiskPerformance ('Bad' → 1, anything else → 0)
    - Protected attribute: ExternalRiskEstimate above its median → 1,
      else 0. The dataset has no age column; this score is what the code
      thresholds. Rows whose score is missing get s = 0.
    - Special missing-value codes (-9, -8, -7) → NaN. This is done BEFORE
      the protected attribute is derived, so the sentinel codes do not
      pull the median down.
    - Numeric features: median imputation, then binarized per column
      (above the column median → 1, else 0)
    - Categorical features: most-frequent imputation + OneHotEncoder

    Returns ``(X_processed, y, s, alpha)`` with ``alpha`` the mean of
    ``s``; the same values go to ``save_processed_data`` as 'heloc'.
    """
    print("Loading HELOC dataset...")

    # 1) Read the CSV
    df = pd.read_csv('data/heloc_dataset_v1.csv')

    # 2) Binary label
    y = (df['RiskPerformance'] == 'Bad').astype(int)

    # 3) Turn the special codes into NaN up front. Previously this happened
    #    only after the protected attribute had been computed, so the codes
    #    were treated as real (very low) scores when taking the median.
    df = df.replace({-9: np.nan, -8: np.nan, -7: np.nan})

    # 4) Binary sensitive attribute from the cleaned score. Series.median
    #    skips NaN, and NaN > median is False, so missing scores map to 0.
    risk_estimate = df['ExternalRiskEstimate']
    s = (risk_estimate > risk_estimate.median()).astype(int)

    # 5) Drop label and sensitive attribute from the features
    X = df.drop(columns=['RiskPerformance', 'ExternalRiskEstimate'])

    # 6) Column roles are decided by dtype
    numeric_features = X.select_dtypes(include=[np.number]).columns.tolist()
    categorical_features = X.select_dtypes(include=['object', 'category']).columns.tolist()

    def _above_column_median(values):
        # 1 where a value exceeds its column's median, else 0
        return (values > np.median(values, axis=0)).astype(int)

    # 7) Preprocessing pipelines
    numeric_transformer = Pipeline([
        ('imputer', SimpleImputer(strategy='median')),
        ('binarizer', FunctionTransformer(_above_column_median))
    ])
    categorical_transformer = Pipeline([
        ('imputer', SimpleImputer(strategy='most_frequent')),
        ('onehot', OneHotEncoder(handle_unknown='ignore'))
    ])

    preprocessor = ColumnTransformer(transformers=[
        ('num', numeric_transformer, numeric_features),
        ('cat', categorical_transformer, categorical_features)
    ], remainder='drop')

    # 8) Fit & transform
    X_processed = preprocessor.fit_transform(X)

    # 9) Report alpha
    alpha = s.mean()
    print(f"Privileged-group ratio (alpha): {alpha:.4f}")

    # Persist the processed data for later cached loads
    save_processed_data('heloc', X_processed, y, s, alpha)

    return X_processed, y, s, alpha


# ----------------------------------------------------
#              Load CelebA dataset
# ----------------------------------------------------

def load_celeba():
    """
    Load the CelebA attribute table with fairness preprocessing.

    - Protected attribute: Male == 1 → 1, else 0
    - Label: Smiling == 1 → 1, else 0
    - Features: the remaining attribute columns. An 'image_id' column, if
      the CSV has one, is dropped: it is a per-row identifier, and one-hot
      encoding it would add one feature column per image.
    - int64/float64 columns are standardized; any object columns left are
      one-hot encoded

    Returns ``(X_processed, y, s, alpha)`` with ``alpha`` the mean of
    ``s``; the same values go to ``save_processed_data`` as 'celeba'.
    """
    print("Loading CelebA dataset...")

    # Load data from local file
    df = pd.read_csv('data/list_attr_celeba.csv')

    # Binary protected attribute (Male=1, Female=0)
    s = (df['Male'] == 1).astype(int)

    # Binary label (Smiling=1, Not Smiling=0)
    y = (df['Smiling'] == 1).astype(int)

    # Drop the protected attribute, the label, and the identifier column
    # when present (guarded so files without it still load).
    non_features = ['Male', 'Smiling']
    if 'image_id' in df.columns:
        non_features.append('image_id')
    X = df.drop(columns=non_features)

    # Column roles are decided by dtype
    numerical_features = X.select_dtypes(include=['int64', 'float64']).columns
    categorical_features = X.select_dtypes(include=['object']).columns

    preprocessor = ColumnTransformer(
        transformers=[
            ('num', StandardScaler(), numerical_features),
            ('cat', OneHotEncoder(handle_unknown='ignore'), categorical_features)
        ])

    # Apply preprocessing
    X_processed = preprocessor.fit_transform(X)

    alpha = s.mean()
    print(f"Male ratio (alpha): {alpha:.4f}")

    # Persist the processed data for later cached loads
    save_processed_data('celeba', X_processed, y, s, alpha)

    return X_processed, y, s, alpha


# ----------------------------------------------------
#              Load Law School dataset
# ----------------------------------------------------

def load_law():
    """
    Load the Law School dataset with fairness preprocessing.

    - Protected attribute: the first column whose lower-cased name starts
      with 'race', cast to int as-is (assumed to be coded white=1,
      non-white=0 already — TODO confirm against the CSV)
    - Label: pass_bar cast to int (1 = passed, 0 = failed)
    - Features: all other columns; int64/float64 columns are standardized,
      object/category columns are one-hot encoded

    Returns ``(X_proc, y, s, alpha)`` where ``alpha`` is the mean of ``s``
    as a Python float; the same values go to ``save_processed_data`` as
    'law'.
    """
    print("Loading Law dataset...")

    table = pd.read_csv('data/law_dataset.csv')

    # Locate the race column by name prefix
    race_candidates = [name for name in table.columns if name.lower().startswith('race')]
    race_col = race_candidates[0]

    # Protected attribute and label
    s = table[race_col].astype(int)
    y = table['pass_bar'].astype(int)

    # Neither the protected attribute nor the label stays in the features
    features = table.drop(columns=[race_col, 'pass_bar'])

    # Column roles are decided by dtype
    num_feats = features.select_dtypes(include=['int64', 'float64']).columns.tolist()
    cat_feats = features.select_dtypes(include=['object', 'category']).columns.tolist()

    scale_step = ('num', StandardScaler(), num_feats)
    onehot_step = ('cat', OneHotEncoder(handle_unknown='ignore'), cat_feats)
    preprocessor = ColumnTransformer([scale_step, onehot_step])
    X_proc = preprocessor.fit_transform(features)

    # Share of the privileged group
    alpha = float(s.mean())
    print(f"White (privileged) ratio (α): {alpha:.4f}")

    # Persist the processed data for later cached loads
    save_processed_data('law', X_proc, y, s, alpha)

    return X_proc, y, s, alpha

# ----------------------------------------------------
#              Load dataset by name
# ----------------------------------------------------

def load_dataset(dataset_name, force_reload=False):
    """
    Load a dataset by name, either from cache or by calling the appropriate loader.

    Parameters:
    -----------
    dataset_name : str
        Name of the dataset to load (one of the keys of the loader table
        below, or a name for which processed data already exists)
    force_reload : bool, default=False
        If True, skip the cache and run the loader even if cached data exists

    Returns:
    --------
    X :
        Features
    y :
        Labels
    s :
        Protected attributes
    alpha : float
        Protected group ratio

    On a cache hit the values are whatever ``read_processed_data`` returns;
    on a miss they are the loader's own return values.

    Raises:
    -------
    ValueError
        If no cached data was used and ``dataset_name`` has no loader.
    """
    # Map dataset names to loader functions
    loaders = {
        'compas': load_compas,
        'german_credit': load_german_credit,
        'bank_marketing': load_bank_marketing,
        'communities_crime': load_communities_crime,
        'adult': load_adult_dataset,
        'student': load_student_performance,
        'heritage_health': load_heritage_health,
        'meps': load_meps,
        'acs_income': load_acs_income,
        'heloc': load_heloc,
        'celeba': load_celeba,
        'law': load_law
    }

    # Names under which a loader stores its output when that differs from
    # the public dataset name. The student loader saves as
    # 'student_performance', so looking up 'student' could never hit.
    # (Assumes read_processed_data reads by the same name that
    # save_processed_data was given — confirm in data_util.)
    cache_names = {
        'student': 'student_performance',
    }
    cache_name = cache_names.get(dataset_name, dataset_name)

    # Use processed data when it exists and a reload is not forced
    if not force_reload:
        try:
            return read_processed_data(cache_name)
        except FileNotFoundError:
            pass

    if dataset_name not in loaders:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    return loaders[dataset_name]()
