import copy
import torch
import random
import json
import numpy as np
import pandas as pd
import os
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset
import datasets as hf_datasets
from transformers import AutoTokenizer, AutoModel, DataCollatorWithPadding
from sklearn.preprocessing import MinMaxScaler
import pickle
import requests
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
from torchvision.transforms import transforms
from pathlib import Path
from PIL import Image
from pathlib import Path
import zipfile
import multiprocessing
import threading
import folktables
import urllib.request
import sys
import re
# from fedlearn.utils.sampling import *
# from data.celeba.metadata_to_json import celeba_generate


class FairDataset(Dataset):
    """Dataset yielding ``(features, label, group)`` samples for fairness work.

    Args:
        X: Indexable feature container; ``X[i]`` must be convertible to a
            float tensor. NumPy-style arrays and plain sequences both work.
        Y: Array of class labels aligned with ``X`` (must support
            ``.ravel()``).
        A: Array of sensitive-group ids aligned with ``X`` (must support
            ``.ravel()``).
        n_class: Accepted for interface compatibility only; the class count
            is always recomputed from ``Y``.
        n_group: Accepted for interface compatibility only; the group count
            is always recomputed from ``A``.
        weight: Optional per-sample weights. When set, every sample is
            returned as ``(X, Y, A, weight)``.

    Attributes set by the constructor: ``X``, ``Y``, ``A``, ``weight``,
    ``n_group``, ``n_class`` and ``data_info`` (group-by-class count table).
    """

    def __init__(self, X, Y, A, n_class=None, n_group=None, weight=None):
        self.X = X
        self.Y = Y
        self.A = A
        self.weight = weight
        # Also sets self.n_group / self.n_class as a side effect.
        self.data_info = self.get_data_info(Y, A)

    def __getitem__(self, index):
        """Return sample ``index`` as tensors (plus its weight, if any)."""
        X = torch.tensor(self.X[index], dtype=torch.float)
        Y = torch.tensor(self.Y[index], dtype=torch.long)
        A = torch.tensor(self.A[index], dtype=torch.long)

        if self.weight is None:
            return (X, Y, A)

        # ``weight`` may be (re)assigned after construction, so the length
        # check stays here; it uses len(self) so it also works when X has
        # no ``shape`` attribute.
        assert len(self.weight) == len(self), \
            "weight must contain exactly one entry per sample"
        return (X, Y, A, self.weight[index])

    def __len__(self):
        """Return the number of samples."""
        if hasattr(self.X, "shape"):
            return self.X.shape[0]
        # Returning None here would make len(dataset) raise TypeError, so
        # fall back to the container's own length.
        return len(self.X)

    def dim(self):
        """Return the per-sample feature shape, or None if X has no shape."""
        if hasattr(self.X, "shape"):
            return self.X.shape[1:]
        return None

    def get_data_info(self, Y, A):
        """Count the samples in every (group, class) cell.

        Also stores the number of distinct groups and classes on the
        instance as ``n_group`` and ``n_class``.

        Returns:
            A DataFrame indexed by group id (int) with one column per class
            id (int); each cell holds the number of matching samples.
        """
        unique_A = np.unique(A)
        unique_Y = np.unique(Y)
        self.n_group = len(unique_A)
        self.n_class = len(unique_Y)

        df_YA = pd.DataFrame({'Y': Y.ravel(), 'A': A.ravel()})

        info = pd.DataFrame(index=unique_A.astype(int), columns=unique_Y.astype(int))
        for a in unique_A:
            in_group = df_YA['A'] == a
            for y in unique_Y:
                # The table is labelled with ints, so address cells with
                # ints too; raw float/str labels could miss the label and
                # silently enlarge the frame instead.
                info.at[int(a), int(y)] = int((in_group & (df_YA['Y'] == y)).sum())

        return info

class FairTextDataset(Dataset):
    """Text dataset that tokenizes on access and yields fairness triples.

    Each item is ``(encoding, label, group)`` where ``encoding`` is a dict of
    long tensors (``input_ids``, ``attention_mask`` and, when the tokenizer
    emits them, ``token_type_ids``). A fourth float-tensor element carrying
    the sample weight is appended when ``weight`` is set.

    Args:
        raw_dataset: Column-addressable container; ``raw_dataset[col]`` must
            return the whole column and ``len(raw_dataset)`` the row count.
        tokenizer: Callable invoked as ``tokenizer(text, truncation=True,
            max_length=..., padding=False, return_tensors=None)``.
        max_length: Truncation length forwarded to the tokenizer.
        weight: Optional per-sample weights, indexable by position.
        text_col: Name of the text column.
        label_col: Name of the integer label column.
        group_col: Name of the integer sensitive-group column.
    """

    def __init__(self, raw_dataset, tokenizer, max_length=256, weight=None,
                 text_col="bio", label_col="title", group_col="gender"):
        self.ds = raw_dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.weight = weight

        self.text_col, self.label_col, self.group_col = text_col, label_col, group_col

        # Labels and groups are materialized once as int64 arrays.
        self.Y = np.asarray(self.ds[self.label_col], dtype=np.int64)
        self.A = np.asarray(self.ds[self.group_col], dtype=np.int64)
        self.data_info = self.get_data_info(self.Y, self.A)

    def __len__(self):
        """Return the number of rows in the wrapped dataset."""
        return len(self.ds)

    def __getitem__(self, idx):
        """Tokenize row ``idx`` and return it with its label and group."""
        # NOTE(review): the column is looked up on every access; depending on
        # the backing dataset type this may materialize the whole column each
        # time -- confirm whether that matters for the datasets used here.
        encoded = self.tokenizer(
            self.ds[self.text_col][idx],
            truncation=True,
            max_length=self.max_length,
            padding=False,
            return_tensors=None
        )

        features = {
            field: torch.tensor(encoded[field], dtype=torch.long)
            for field in ("input_ids", "attention_mask")
        }
        if "token_type_ids" in encoded:
            features["token_type_ids"] = torch.tensor(encoded["token_type_ids"], dtype=torch.long)

        sample = (
            features,
            torch.tensor(self.Y[idx], dtype=torch.long),
            torch.tensor(self.A[idx], dtype=torch.long),
        )
        if self.weight is None:
            return sample
        return sample + (torch.tensor(self.weight[idx], dtype=torch.float),)

    def get_data_info(self, Y, A):
        """Build the group-by-class count table.

        Stores the number of distinct groups / classes on the instance as
        ``n_group`` / ``n_class`` and returns a DataFrame indexed by group id
        with one column per class id.
        """
        groups = np.unique(A)
        classes = np.unique(Y)
        self.n_group, self.n_class = len(groups), len(classes)

        pairs = pd.DataFrame({'Y': Y.ravel(), 'A': A.ravel()})
        table = pd.DataFrame(index=groups.astype(int), columns=classes.astype(int))
        for g in groups:
            in_group = pairs['A'] == g
            for c in classes:
                table.at[g, c] = pairs[in_group & (pairs['Y'] == c)].shape[0]
        return table



def mkdir(*args: str) -> tuple:
    """Create every given directory that does not exist yet.

    Paths that already exist (whatever they are) are left untouched; missing
    ones are created together with their parents.

    Returns:
        The paths exactly as they were passed in.
    """
    # Lazy generator: each existence check runs right before its makedirs.
    absent = (target for target in args if not os.path.exists(target))
    for target in absent:
        os.makedirs(target, exist_ok=True)
    return args

def get_data(options):
    """Load the dataset selected by ``options`` and split it.

    Args:
        options: Run configuration dict. ``options['data']`` names the
            dataset (case-insensitive): ``'adult'``, ``'celeba'``,
            ``'enem'``, ``'diabetes'``, ``'bios'``, ``'acs'`` or ``'drug'``.
            ``options['data_setting']`` is a dict that supplies
            ``'sensitive_attr'`` and, optionally, ``'generate'`` (truthy to
            ignore cached files and rebuild). The dict is mutated: the data
            settings are merged into it and some branches add
            ``'data_exist'`` / ``'data_save_path'``.

    Returns:
        ``(train_data, val_data, test_data, n_group, n_class)`` as produced
        by ``split``.

    Raises:
        ValueError: For an unsupported dataset name, or for an unsupported
            sensitive attribute when adult/celeba data has to be (re)built.
    """
    data_name = options['data'].lower()
    data_settings = options['data_setting']
    options.update(data_settings)
    regenerate = data_settings.get('generate', False)

    if data_name == 'adult':
        raw_path = "data/adult/raw_data/"
        train_path = raw_path + "train.csv"
        test_path = raw_path + "test.csv"
        processed_path = "data/adult/processed_data/"
        processed_data_path = processed_path + "processed_adult.npy"
        # The cache is the processed .npy file, so that is the file whose
        # existence must be tested (testing the raw CSVs made np.load fail
        # whenever only the raw files were present).
        # NOTE(review): the cache file name does not encode the sensitive
        # attribute, so a cache built for one attribute is reused for others
        # unless 'generate' is set -- confirm this is intended.
        if os.path.exists(processed_data_path) and not regenerate:
            options['data_exist'] = 1
            data = np.load(processed_data_path, allow_pickle=True).item()
            X, Y, A = data['X'], data['Y'], data['A']
        else:
            adult_attr_names = {'gender-race': 'sex-race', 'gender': 'sex', 'race': 'race'}
            requested_attr = data_settings['sensitive_attr']
            if requested_attr not in adult_attr_names:
                raise ValueError('Not support sensitive attribute {}!'.format(requested_attr))
            sensitive_attr = adult_attr_names[requested_attr]

            mkdir(processed_path)
            if not (os.path.exists(train_path) and os.path.exists(test_path)):
                adult_process()
            df = pd.concat([pd.read_csv(train_path), pd.read_csv(test_path)], axis=0)
            feature_df = df.drop('salary', axis=1)
            X = feature_df.to_numpy().astype(np.float32)
            Y = df['salary'].to_numpy().astype(np.float32)
            colname = feature_df.columns.tolist()
            if sensitive_attr == 'sex-race':
                # The combined attribute also returns an updated Y.
                X, A, Y = adult_get_sensitive_feature(X, colname, sensitive_attr, Y)
            else:
                X, A = adult_get_sensitive_feature(X, colname, sensitive_attr)
            np.save(processed_data_path, {"X": X, "Y": Y, "A": A})
        train_data, val_data, test_data, n_group, n_class = split(X, Y, A)

    elif data_name == 'celeba':
        processed_path = f"data/celeba/processed_data/sensitive_attr_{data_settings['sensitive_attr']}_multiclass/"
        processed_data_path = processed_path + "processed_celeba.npy"
        options['data_save_path'] = processed_path
        if os.path.exists(processed_data_path) and not regenerate:
            data = np.load(processed_data_path, allow_pickle=True).item()
            X, Y, A = data['X'], data['Y'], data['A']
            options['data_exist'] = 1
        else:
            celeba_attr_names = {
                'sex': 'Male',
                'age': 'Young',
                'sex-race': ['Male', 'Pale_Skin'],
                'race': 'Pale_Skin',
            }
            requested_attr = data_settings['sensitive_attr']
            if requested_attr not in celeba_attr_names:
                raise ValueError('Not support sensitive attribute {}!'.format(requested_attr))
            sensitive_attr = celeba_attr_names[requested_attr]

            mkdir(processed_path)
            X, Y, A = celeba_data_processing(sensitive_attr, multiclass=True)
            np.save(processed_data_path, {"X": X, "Y": Y, "A": A})

        train_data, val_data, test_data, n_group, n_class = split(X, Y, A)

        print('celeba data processed.')

    elif data_name == 'enem':
        print("use enem.")
        X, A, Y = enem_process(data_settings['sensitive_attr'], n_classes=5)
        train_data, val_data, test_data, n_group, n_class = split(X, Y, A)

    elif data_name == 'diabetes':
        print("use diabetes.")
        data_dir = "data/diabetes/raw_data/"
        X, A, Y = load_diabetes(data_dir=data_dir, group_col=data_settings['sensitive_attr'])
        train_data, val_data, test_data, n_group, n_class = split(X, Y, A)

    elif data_name == 'bios':
        print("use biasbios.")
        data_dir = "data/biasbios/raw_data"
        out_dir = "data/biasbios/emb_full_bert_meanpool"
        emb_path = os.path.join(out_dir, "X_emb.dat")
        y_path = os.path.join(out_dir, "Y.npy")
        a_path = os.path.join(out_dir, "A.npy")
        meta_path = os.path.join(out_dir, "meta.json")
        # meta.json is read on the cached path, so it is part of the cache.
        cache_files = (emb_path, y_path, a_path, meta_path)
        if all(os.path.exists(p) for p in cache_files) and not regenerate:
            with open(meta_path, "r", encoding="utf-8") as meta_file:
                meta = json.load(meta_file)
            N, H = int(meta["N"]), int(meta["H"])
            dt = np.float16 if meta.get("dtype", "float16") == "float16" else np.float32

            # Memory-map the embedding matrix instead of loading it.
            X_emb = np.memmap(emb_path, mode="r", dtype=dt, shape=(N, H))
            Y = np.load(y_path)
            A = np.load(a_path)
        else:
            X_emb, Y, A, meta = build_full_embeddings_meanpool(
                data_dir=data_dir,
                out_dir=out_dir,
                model_name="bert-base-uncased",
                max_length=256,
                batch_size=64,
                dtype="float16",
                device="cuda",
                num_workers=4,
                add_sensitive_attribute=True,
            )

        train_data, val_data, test_data, n_group, n_class = split(X_emb, Y, A)

    elif data_name == 'acs':
        X, Y, label_names, A, sensitive_attr_names = acsincome_process(2, data_settings['sensitive_attr'])
        train_data, val_data, test_data, n_group, n_class = split(X, Y, A)
        print(f'y class:{Y.shape}, a class:{A.shape}')

    elif data_name == 'drug':
        X, Y, A = drug_process(data_settings['sensitive_attr'])
        train_data, val_data, test_data, n_group, n_class = split(X, Y, A)

    else:
        raise ValueError('Not support dataset {}!'.format(data_name))

    print_statistics_info(train_data, val_data, test_data)
    return train_data, val_data, test_data, n_group, n_class

def enem_process(sensitive_attr, n_classes=5):
    """Load the ENEM 2020 data and return features, groups and labels.

    The pre-processed frame is cached as a pickle under
    ``data/enem/processed_data/``; when the cache file is missing the raw
    CSV is processed with ``load_enem`` and the result is cached.

    Args:
        sensitive_attr: Sensitive attribute to use; only ``'race'`` is
            supported.
        n_classes: Number of grade bins; 2 or 5. With 5 the multi-group race
            encoding is requested from ``load_enem``.

    Returns:
        ``(X, A, Y)`` NumPy arrays -- note the order: features, sensitive
        attribute, label.

    Raises:
        ValueError: If ``n_classes`` or ``sensitive_attr`` is unsupported.
    """
    # Validate before any expensive loading so bad arguments fail fast and
    # with a clear message (an unsupported n_classes used to surface as an
    # UnboundLocalError on n_groups).
    if n_classes == 2:
        n_groups, multigroup = 2, False
    elif n_classes == 5:
        n_groups, multigroup = 5, True
    else:
        raise ValueError('Not support n_classes {}!'.format(n_classes))

    if sensitive_attr == 'race':
        protected_attrs = 'racebin'
    else:
        raise ValueError('Not support sensitive attribute {}!'.format(sensitive_attr))

    enem_path = 'data/enem/raw_data/microdados_enem_2020/DADOS/'  # 2020 edition
    enem_file = 'MICRODADOS_ENEM_2020.csv'
    # Labels could be: NU_NOTA_CH=human science, NU_NOTA_LC=languages&codes,
    # NU_NOTA_MT=math, NU_NOTA_CN=natural science
    label = ['NU_NOTA_CH']
    group_attribute = ['TP_COR_RACA', 'TP_SEXO']
    question_vars = [f'Q{x:03d}' for x in range(1, 25)]  # Q001 .. Q024
    domestic_vars = ['SG_UF_PROVA', 'TP_FAIXA_ETARIA']
    all_vars = label + group_attribute + question_vars + domestic_vars

    n_sample = 1400000

    fname = f'data/enem/processed_data/enem-{n_classes}-g{n_groups}-{n_sample}-20.pkl'

    if os.path.isfile(fname):
        # Local cache written by this function; never point this at
        # untrusted files, since unpickling can execute arbitrary code.
        df = pd.read_pickle(fname)
    else:
        # Make sure the cache directory exists before the long-running load,
        # otherwise to_pickle would fail after all the work is done.
        os.makedirs(os.path.dirname(fname), exist_ok=True)
        df = load_enem(enem_path, enem_file, all_vars, label, n_sample, n_classes, multigroup=multigroup)
        df.to_pickle(fname)

    label_name = 'gradebin'
    df[label_name] = df[label_name].astype(int)

    X = df.drop(columns=[label_name, protected_attrs]).to_numpy()
    A = df[protected_attrs].to_numpy()
    Y = df[label_name].to_numpy()

    return X, A, Y


def get_idx_wo_protected(feature_names, protected_attrs):
    """Return the indices of all features that are not protected.

    Args:
        feature_names: List of feature names.
        protected_attrs: Iterable of names to exclude; each must occur in
            ``feature_names``.

    Returns:
        The remaining indices in ascending order. (Going through a set
        difference, as before, left the order up to set iteration.)

    Raises:
        ValueError: If a protected attribute is not in ``feature_names``.
    """
    protected_idx = {feature_names.index(name) for name in protected_attrs}
    return [i for i in range(len(feature_names)) if i not in protected_idx]

def get_idx_w_protected(feature_names):
    """Return the indices of all features, in ascending order."""
    # range() already yields unique, ordered indices; no set round-trip
    # (whose iteration order is not guaranteed) is needed.
    return list(range(len(feature_names)))

def get_idx_protected(feature_names, protected_attrs):
    """Return the indices of the protected features.

    Args:
        feature_names: List of feature names.
        protected_attrs: Iterable of protected names; each must occur in
            ``feature_names``.

    Returns:
        The de-duplicated indices in ascending order. (A bare ``list(set)``
        left the order up to set iteration.)

    Raises:
        ValueError: If a protected attribute is not in ``feature_names``.
    """
    return sorted({feature_names.index(name) for name in protected_attrs})


# from pathlib import Path
# import zipfile
# import requests
# import numpy as np
# import pandas as pd

# from sklearn.compose import ColumnTransformer
# from sklearn.pipeline import Pipeline
# from sklearn.impute import SimpleImputer
# from sklearn.preprocessing import OneHotEncoder, StandardScaler


def _ensure_diabetes_csv(data_dir):
    """Return the path of ``diabetic_data.csv`` in *data_dir*, downloading it if absent."""
    UCI_ZIP_URL = (
        "https://archive.ics.uci.edu/static/public/296/"
        "diabetes%2B130-us%2Bhospitals%2Bfor%2Byears%2B1999-2008.zip"
    )
    CSV_NAME = "diabetic_data.csv"

    data_dir = Path(data_dir)
    data_dir.mkdir(parents=True, exist_ok=True)
    csv_path = data_dir / CSV_NAME
    if csv_path.exists():
        return csv_path

    zip_path = data_dir / "diabetes130.zip"
    # Stream the archive to disk in 1 MiB chunks.
    with requests.get(UCI_ZIP_URL, stream=True, timeout=60) as r:
        r.raise_for_status()
        with open(zip_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=1024 * 1024):
                if chunk:
                    f.write(chunk)
    with zipfile.ZipFile(zip_path, "r") as zf:
        zf.extractall(data_dir)

    if not csv_path.exists():
        # The archive may unpack into a sub-directory; move the CSV up.
        candidates = list(data_dir.rglob(CSV_NAME))
        if not candidates:
            raise FileNotFoundError(f"Downloaded zip but cannot find {CSV_NAME} under {data_dir}")
        candidates[0].replace(csv_path)

    # Removing the archive is best-effort: a leftover zip is harmless.
    try:
        zip_path.unlink(missing_ok=True)
    except OSError:
        pass
    return csv_path


def load_diabetes(
    data_dir: str | Path,
    group_col: str = "race",  # A: "race" / "gender" / "age"
    drop_missing_group: bool = True,
    drop_cols: tuple = ("encounter_id", "patient_nbr", "weight", "payer_code", "examide", "citoglipton"),
    row_missing_thres: float = 0.97,  # keep only rows whose missing ratio is < 0.97
    med_top_n: int = 10,
    diag_top_n: int = 50,
):
    """Load and encode the UCI "Diabetes 130-US hospitals" dataset.

    The CSV is downloaded into ``data_dir`` when it is not already there.
    Categories are re-coded, rows without a group or label are dropped, and
    the features are imputed, one-hot encoded and standardized.

    Args:
        data_dir: Directory holding (or receiving) ``diabetic_data.csv``.
        group_col: Column used as sensitive attribute A.
        drop_missing_group: Currently unused; kept for backward
            compatibility.
        drop_cols: Currently unused; kept for backward compatibility.
        row_missing_thres: Rows whose fraction of missing values is not
            strictly below this threshold are dropped.
        med_top_n: Currently unused; kept for backward compatibility.
        diag_top_n: Currently unused; kept for backward compatibility.

    Returns:
        ``(X, A, Y)``:
          X: np.ndarray float32 (numeric columns median-imputed and
             standardized, categorical columns mode-imputed and one-hot
             encoded)
          A: np.ndarray int64 group codes (0..G-1)
          Y: np.ndarray int64 labels for readmitted: NO=0, >30=1, <30=2
    """
    from sklearn.compose import ColumnTransformer
    from sklearn.pipeline import Pipeline
    from sklearn.impute import SimpleImputer
    from sklearn.preprocessing import OneHotEncoder, StandardScaler

    # ===================== 0) ensure dataset exists =====================
    csv_path = _ensure_diabetes_csv(data_dir)

    df = pd.read_csv(csv_path).rename(columns={"diag_1": "primary_diagnosis"})

    # Columns are replaced with plain ``df[col] = ...`` rather than
    # ``df.loc[:, col] = ...``: several re-codings change the column dtype
    # (e.g. int ids -> strings), which in-place ``.loc`` setting only
    # supports through deprecated upcasting.

    # ===================== 1) re-code missing values and categories =====
    df["age"] = df["age"].replace({"?": ""})
    # A mapping is required here; the previous set literal
    # {"?", "Unknown"} was not a valid replacement spec.
    df["payer_code"] = df["payer_code"].replace({"?": "Unknown"})
    df["medical_specialty"] = df["medical_specialty"].replace({"?": "Missing"})
    df["race"] = df["race"].replace({"?": "Unknown"})

    df["admission_source_id"] = df["admission_source_id"].replace(
        {1: "Referral", 2: "Referral", 3: "Referral", 7: "Emergency"})
    df["age"] = df["age"].replace(["[0-10)", "[10-20)", "[20-30)"], "30 years or younger")
    df["age"] = df["age"].replace(["[30-40)", "[40-50)", "[50-60)"], "30-60 years")
    df["age"] = df["age"].replace(["[60-70)", "[70-80)", "[80-90)", "[90-100)"], "Over 60 years")

    # Clean various medical codes
    df["discharge_disposition_id"] = df["discharge_disposition_id"].apply(
        lambda x: "Discharged to Home" if x == 1 else "Other")
    df["admission_source_id"] = df["admission_source_id"].apply(
        lambda x: x if x in ["Emergency", "Referral"] else "Other")

    # Re-code medical specialties and primary diagnosis
    specialties = [
        "Missing",
        "InternalMedicine",
        "Emergency/Trauma",
        "Family/GeneralPractice",
        "Cardiology",
        "Surgery"
    ]
    df["medical_specialty"] = df["medical_specialty"].apply(
        lambda x: x if x in specialties else "Other")

    df["primary_diagnosis"] = df["primary_diagnosis"].replace(
        regex={
            "[7][1-3][0-9]": "Musculoskeltal Issues",
            "250.*": "Diabetes",
            "[4][6-9][0-9]|[5][0-1][0-9]|786": "Respitory Issues",
            "[5][8-9][0-9]|[6][0-2][0-9]|788": "Genitourinary Issues"
        }
    )
    diagnoses = ["Respitory Issues", "Diabetes", "Genitourinary Issues", "Musculoskeltal Issues"]
    df["primary_diagnosis"] = df["primary_diagnosis"].apply(
        lambda x: x if x in diagnoses else "Other")

    # Binarize and bin features
    df["medicare"] = (df["payer_code"] == "MC")
    df["medicaid"] = (df["payer_code"] == "MD")

    df["had_emergency"] = (df["number_emergency"] > 0)
    df["had_inpatient_days"] = (df["number_inpatient"] > 0)
    df["had_outpatient_days"] = (df["number_outpatient"] > 0)

    cols_to_keep = ["race", "gender", "age", "discharge_disposition_id", "admission_source_id", "time_in_hospital",
                    "medical_specialty", "num_lab_procedures", "num_procedures", "num_medications",
                    "primary_diagnosis", "number_diagnoses", "max_glu_serum", "A1Cresult", "insulin", "change",
                    "diabetesMed", "medicare", "medicaid", "had_emergency", "had_inpatient_days",
                    "had_outpatient_days", "readmitted"]

    # ===================== 2) choose the A / Y columns ==================
    A_col = group_col
    Y_col = "readmitted"    # three-class label

    # ===================== 3) drop rows with missing A or Y =============
    # Only A and Y must be present; missing feature values are imputed.
    df2 = df.loc[:, cols_to_keep].copy()
    df2['race'] = df2['race'].replace({"Unknown": np.nan, "unknow": np.nan, "UNKNOWN": np.nan})
    df2['gender'] = df2['gender'].replace(
        {"Unknown": np.nan, "unknow": np.nan, "UNKNOWN": np.nan, "Unknown/Invalid": np.nan})

    # Treat the placeholder strings "Unknown" / "Missing" as missing too.
    df2[A_col] = df2[A_col].replace({"Unknown": np.nan, "Missing": np.nan})
    df2["medical_specialty"] = df2["medical_specialty"].replace({"Missing": np.nan})

    df2 = df2.dropna(subset=[A_col, Y_col]).copy()

    # Drop rows whose share of missing values reaches the threshold.
    row_missing_ratio = df2.isna().mean(axis=1)
    df2 = df2.loc[row_missing_ratio < row_missing_thres].copy()

    # ===================== 4) extract Y (three classes) =================
    y_map = {"NO": 0, ">30": 1, "<30": 2}
    Y = df2[Y_col].map(y_map)
    df2 = df2[Y.notna()].copy()  # drop rows with an unmapped label
    Y = Y.loc[df2.index].astype(np.int64).to_numpy()

    # ===================== 5) extract A as integer codes ================
    A_cat = pd.Categorical(df2[A_col].astype(str))
    A = A_cat.codes.astype(np.int64)           # 0..G-1

    # ===================== 6) extract X (without A and Y) ===============
    X_df = df2.drop(columns=[A_col, Y_col]).copy()

    # ===================== 7) one-hot + impute + standardize ============
    # Booleans become 0/1 so they are handled as numeric columns.
    for c in X_df.select_dtypes(include=["bool"]).columns.tolist():
        X_df[c] = X_df[c].astype(int)
    cat_cols = X_df.select_dtypes(include=["object", "category"]).columns.tolist()
    num_cols = [c for c in X_df.columns if c not in cat_cols]

    # Support both the newer (sparse_output) and older (sparse) sklearn API.
    try:
        ohe = OneHotEncoder(handle_unknown="ignore", sparse_output=False)
    except TypeError:
        ohe = OneHotEncoder(handle_unknown="ignore", sparse=False)

    numeric_pipe = Pipeline(steps=[
        ("imputer", SimpleImputer(strategy="median")),
        ("scaler", StandardScaler()),
    ])

    categorical_pipe = Pipeline(steps=[
        ("imputer", SimpleImputer(strategy="most_frequent")),
        ("onehot", ohe),
    ])

    preprocess = ColumnTransformer(
        transformers=[
            ("num", numeric_pipe, num_cols),
            ("cat", categorical_pipe, cat_cols),
        ],
        remainder="drop",
        verbose_feature_names_out=False
    )

    X = preprocess.fit_transform(X_df).astype(np.float32)

    return X, A, Y



def load_enem(file_path, filename, features, grade_attribute, n_sample, n_classes, multigroup=False):
    """Load an ENEM CSV and turn it into a scaled, one-hot encoded DataFrame.

    Args:
        file_path: Prefix that is string-concatenated with ``filename`` (so it
            must already end with a path separator).
        filename: Name of the ``;``-separated, cp860-encoded CSV file.
        features: Columns to keep after the row filters. The code below also
            requires ``TP_COR_RACA``, ``TP_SEXO``, ``SG_UF_PROVA``,
            ``grade_attribute[0]`` and ``Q001``..``Q024`` to be among them.
        grade_attribute: Sequence whose first element names the score column
            that is binned into ``gradebin``.
        n_sample: Maximum number of rows returned (sampled without replacement).
        n_classes: Number of quantile bins used for ``gradebin``.
        multigroup: If True ``racebin`` holds the code from ``construct_race``;
            otherwise it is 1 for 'Branca'/'Amarela' and 0 for everything else.

    Returns:
        A DataFrame with the label ``gradebin``, the groups ``racebin`` and
        ``sexbin`` and the encoded features; every column except ``gradebin``
        and ``racebin`` is min-max scaled.
    """
    ## load csv
    df = pd.read_csv(file_path+filename, encoding='cp860', sep=';')

    ## Keep only candidates marked present (value 1) in all four exams.
    ## Vectorised comparison instead of the deprecated DataFrame.applymap.
    presence_cols = ['TP_PRESENCA_CN', 'TP_PRESENCA_CH', 'TP_PRESENCA_LC', 'TP_PRESENCA_MT']
    ix = (df[presence_cols] == 1.0).all(axis=1)
    df = df.loc[ix, :]

    ## Remove "treineiros" -- individuals that marked that they are taking the exam
    ## "only to test their knowledge" (a dry run in the middle of high school).
    df = df.loc[df['IN_TREINEIRO'] == 0, :]

    ## drop the columns that were only needed for filtering
    ## (non in-place so we never mutate a slice of the previous frame)
    df = df.drop(columns=presence_cols + ['IN_TREINEIRO'])

    ## substitute race codes by names (code 0 becomes NaN and is dropped below)
    race_names = [np.nan, 'Branca', 'Preta', 'Parda', 'Amarela', 'Indigena']
    df['TP_COR_RACA'] = df['TP_COR_RACA'].map(lambda x: race_names[x])

    ## remove repeated exam takers
    ## This pre-processing step significantly reduces the dataset.
    df = df.loc[df.TP_ST_CONCLUSAO.isin([1])]

    ## select features and drop every row with a missing value
    df = df[features]
    df = df.dropna()

    ## Creating racebin & gradebin & sexbin variable
    df['gradebin'] = construct_grade(df, grade_attribute, n_classes)
    if multigroup:
        df['racebin'] = construct_race(df, 'TP_COR_RACA')
    else:
        df['racebin'] = df['TP_COR_RACA'].isin(['Branca', 'Amarela']).astype(int)
    df['sexbin'] = (df['TP_SEXO'] == 'M').astype(int)

    df = df.drop(columns=[grade_attribute[0], 'TP_COR_RACA', 'TP_SEXO'])

    def encode_drop_last(frame, col):
        # One-hot encode `col`, leaving out its last dummy column.
        dummies = pd.get_dummies(frame[col], prefix=col)
        return pd.concat([frame.drop(columns=[col]), dummies.iloc[:, :-1]], axis=1)

    ## encode answers to questionnaires
    ## Q005 is 'Including yourself, how many people currently live in your household?'
    ## and is therefore kept as a number.
    question_vars = ['Q00' + str(x) if x < 10 else 'Q0' + str(x) for x in range(1, 25)]
    for q in question_vars:
        if q != 'Q005':
            df = encode_drop_last(df, q)

    ## check if age range ('TP_FAIXA_ETARIA') is within attributes
    if 'TP_FAIXA_ETARIA' in features:
        df = encode_drop_last(df, 'TP_FAIXA_ETARIA')

    ## encode SG_UF_PROVA (state where exam was taken); all dummy columns are kept
    df_res = pd.get_dummies(df['SG_UF_PROVA'], prefix='SG_UF_PROVA')
    df = pd.concat([df.drop(columns=['SG_UF_PROVA']), df_res], axis=1)

    ## Rows whose gradebin/racebin could not be assigned are NaN and removed here.
    df = df.dropna()

    ## Scaling ##
    scaler = MinMaxScaler()
    scale_columns = [c for c in df.columns if c not in ('gradebin', 'racebin')]
    df[scale_columns] = pd.DataFrame(scaler.fit_transform(df[scale_columns]), columns=scale_columns, index=df.index)

    ## Random subsample (no seed is set here, so the selection differs between runs)
    df = df.sample(n=min(n_sample, df.shape[0]), axis=0, replace=False)
    return df

def drug_process(path, protected_attribute='gender'):
    """Load the UCI drug-consumption data and build (X, Y, A).

    Args:
        path: Currently ignored -- the file is always read from (and, if
            missing, downloaded to) ``data/drug/raw_data/drug_consumption.data``.
            The parameter is kept so existing callers keep working.
        protected_attribute: Only ``'gender'`` is supported.

    Returns:
        X: float64 array of standardized features (Age, one-hot Education and
           the personality scores).
        Y: int64 array with the Cannabis label mapped to 4 classes
           (CL0 -> 0, CL1-CL2 -> 1, CL3 -> 2, CL4 and above -> 3).
        A: int64 array, 1 where the raw Gender value is >= 0, else 0.

    Raises:
        ValueError: If ``protected_attribute`` is not supported or a label does
            not look like ``CL<number>``.
        SystemExit: If the data file is missing and cannot be downloaded.
    """
    # -----------------------------
    # Config (validated first so a bad argument never triggers a download)
    # -----------------------------
    if protected_attribute == 'gender':
        pr_attr_name = "Gender"      # A
    else:
        raise ValueError('Not support protected attribute {}!'.format(protected_attribute))

    name_of_Y = "Cannabis"       # Y
    features_for_predicting_Y = ["Age","Education","Nscore","Escore","Oscore","Ascore","Cscore","Impulsive","SS"]

    # NOTE: the caller-supplied `path` is deliberately overridden (historic behaviour).
    path = "data/drug/raw_data/drug_consumption.data"
    url = "https://archive.ics.uci.edu/ml/machine-learning-databases/00373/drug_consumption.data"

    # Download data file if it does not exist
    if not os.path.exists(path):
        print("Data set does not exist in current folder --- have to download it")

        # Ensure parent directory exists
        os.makedirs(os.path.dirname(path), exist_ok=True)

        try:
            with requests.get(url, stream=True, timeout=30) as r:
                r.raise_for_status()
                with open(path, "wb") as f:
                    for chunk in r.iter_content(chunk_size=1024 * 1024):
                        if chunk:
                            f.write(chunk)
            print("Download successful\n")
        except requests.RequestException as e:
            print(f"Could not download the data set --- {e}\nPlease download it manually")
            sys.exit(1)

    # -----------------------------
    # Load raw data (the file has no header row)
    # -----------------------------
    dataORIGINAL = pd.read_csv(
        path,
        names=["ID","Age","Gender","Education","Country","Ethnicity","Nscore",
            "Escore","Oscore","Ascore","Cscore","Impulsive","SS","Alcohol","Amphet",
            "Amyl","Benzos","Caff","Cannabis","Choc","Coke","Crack","Ecstasy","Heroin",
            "Ketamine","Legalh","LSD","Meth","Mushroom","Nicotine","Semer","VSA"]
    )

    # categorical columns (only Education is part of the features and gets one-hot encoded)
    dataORIGINAL["Education"] = dataORIGINAL["Education"].astype("category")
    dataORIGINAL["Country"] = dataORIGINAL["Country"].astype("category")
    dataORIGINAL["Ethnicity"] = dataORIGINAL["Ethnicity"].astype("category")

    # -----------------------------
    # A: Gender -> {0,1}  (non-negative raw value -> 1, negative -> 0)
    # -----------------------------
    A = (dataORIGINAL[pr_attr_name] >= 0).to_numpy().astype("int64")

    # -----------------------------
    # Y: label
    # Cannabis: CL0-CL6 mapped to classes 0-3
    # -----------------------------
    Y_raw = dataORIGINAL[name_of_Y].astype(str).values  # e.g., "CL0"..."CL6"

    def map_cl_to_4class(s: str) -> int:
        m = re.match(r"CL(\d+)$", s.strip())
        if m is None:
            raise ValueError(f"Unexpected label: {s}")
        k = int(m.group(1))
        if k == 0:
            return 0
        elif 1 <= k <= 2:
            return 1
        elif k == 3:
            return 2
        else:
            return 3
    Y = np.array([map_cl_to_4class(s) for s in Y_raw], dtype=np.int64)

    # -----------------------------
    # X: features (one-hot + float64)
    # -----------------------------
    X_df = dataORIGINAL.loc[:, features_for_predicting_Y].copy()
    X_df = pd.get_dummies(X_df, drop_first=True)

    X = X_df.values.astype("float64")

    # -----------------------------
    # Standardize X (columns with ~zero spread are only centred, not divided)
    # -----------------------------
    mean_vec = X.mean(axis=0)
    std_vec = X.std(axis=0)
    std_vec[std_vec < 1e-4] = 1.0
    X = (X - mean_vec) / std_vec

    print("X shape:", X.shape)
    print("Y shape:", Y.shape, "unique:", np.unique(Y))
    print("A shape:", A.shape, "unique:", np.unique(A))

    return  X,Y,A

def construct_race(df, protected_attribute):
    """Map the race names in ``df[protected_attribute]`` to integer group codes.

    Codes follow the ENEM 2020 numbering minus one: Branca=0, Preta=1, Parda=2,
    Amarela=3, Indigena=4. Any other value maps to NaN.
    """
    ordered_names = ('Branca', 'Preta', 'Parda', 'Amarela', 'Indigena')
    name_to_code = {name: code for code, name in enumerate(ordered_names)}
    race_column = df[protected_attribute]
    return race_column.map(name_to_code)

def construct_grade(df, grade_attribute, n):
    """Bin the score column ``grade_attribute[0]`` into ``n`` quantile classes.

    The bin edges are the ``n + 1`` equally spaced quantiles of the column
    (NaNs ignored). Bins are right-closed with pandas' defaults, so the minimum
    score lies outside the first bin and is returned as NaN.

    Returns a ``pandas.Categorical`` with labels ``0..n-1``.
    """
    score_column = grade_attribute[0]
    scores = df[score_column].values
    probabilities = np.linspace(0.0, 1.0, n + 1)
    edges = np.nanquantile(scores, probabilities)
    class_labels = np.arange(n)
    return pd.cut(scores, edges, labels=class_labels)

def _download_biasbios_if_needed(data_dir):
    os.makedirs(data_dir, exist_ok=True)
    train_path = os.path.join(data_dir, "train.pickle")
    test_path  = os.path.join(data_dir, "test.pickle")
    dev_path   = os.path.join(data_dir, "dev.pickle")
    if any(not os.path.exists(p) for p in [train_path, test_path, dev_path]):
        urllib.request.urlretrieve("https://storage.googleapis.com/ai2i/nullspace/biasbios/train.pickle", train_path)
        urllib.request.urlretrieve("https://storage.googleapis.com/ai2i/nullspace/biasbios/test.pickle",  test_path)
        urllib.request.urlretrieve("https://storage.googleapis.com/ai2i/nullspace/biasbios/dev.pickle",   dev_path)
    return train_path, test_path, dev_path

def load_biasbios_full_dataset(
    data_dir: str,
    add_sensitive_attribute: bool = True,
):
    """
    Load the BiasBios train/test/dev pickles and merge them into one dataset.

    Args:
        data_dir: directory holding (or receiving) the three pickle files.
        add_sensitive_attribute: if True, every bio is prefixed with
            "Female. " or "Male. " according to its gender.

    return: full_ds, label_names, group_names

    - full_ds: hf_datasets.Dataset with columns "bio" (string), "title" and
      "gender" (ClassLabel); rows are ordered train, then test, then dev
    - label_names / group_names: list
    """

    # 1. label & group
    label_names = [
        "accountant", "architect", "attorney", "chiropractor", "comedian",
        "composer", "dentist", "dietitian", "dj", "filmmaker",
        "interior_designer", "journalist", "model", "nurse", "painter",
        "paralegal", "pastor", "personal_trainer", "photographer", "physician",
        "poet", "professor", "psychologist", "rapper", "software_engineer",
        "surgeon", "teacher", "yoga_teacher"
    ]
    group_names = ["female", "male"]

    features = hf_datasets.Features({
        "bio": hf_datasets.Value("string"),
        "title": hf_datasets.ClassLabel(names=label_names),
        "gender": hf_datasets.ClassLabel(names=group_names),
    })

    # 2. load the three splits and append them in train / test / dev order
    train_path, test_path, dev_path = _download_biasbios_if_needed(data_dir)

    rows_full = {"bio": [], "title": [], "gender": []}
    for path in (train_path, test_path, dev_path):
        # NOTE(security): pickle.load can execute arbitrary code; these files
        # come from a remote bucket, so only use a trusted data_dir / source.
        with open(path, "rb") as f:
            data = pickle.load(f)

        for row in data:
            # row["g"] is 'f' for female; anything else is treated as male
            gender = "female" if row["g"] == "f" else "male"
            bio_text = row["hard_text_untokenized"]
            if add_sensitive_attribute:
                bio_text = gender.capitalize() + ". " + bio_text
            rows_full["bio"].append(bio_text)
            rows_full["title"].append(row["p"])  # profession string
            rows_full["gender"].append(gender)

    full_ds = hf_datasets.Dataset.from_dict(rows_full, features=features)
    return full_ds, label_names, group_names

def mean_pool(last_hidden, attention_mask):
    """Average token vectors over the positions selected by ``attention_mask``.

    last_hidden is [B,L,H] and attention_mask is [B,L]; the result is [B,H].
    The token count is clamped to at least 1e-6, so a row whose mask is all
    zeros yields a zero vector instead of a division by zero.
    """
    weights = torch.unsqueeze(attention_mask, -1).type_as(last_hidden)   # [B,L,1]
    token_count = torch.clamp(weights.sum(dim=1), min=1e-6)              # [B,1]
    weighted_sum = torch.sum(last_hidden * weights, dim=1)               # [B,H]
    return weighted_sum / token_count

def build_full_embeddings_meanpool(
    data_dir="data/biasbios/raw_data",
    out_dir="data/biasbios/emb_full_bert_meanpool",
    model_name="bert-base-uncased",
    max_length=256,
    batch_size=64,
    dtype="float16",          # "float16" or "float32"
    device="cuda",
    num_workers=4,
    add_sensitive_attribute=True,
):
    """Embed every BiasBios bio with a frozen encoder and mean pooling.

    The merged dataset is tokenized, run through ``model_name`` in batches and
    the mean-pooled last hidden states are written to ``out_dir``:

    - ``X_emb.dat``: raw memmap of shape (N, H) in ``dtype``
    - ``Y.npy`` / ``A.npy``: int64 title and gender ids
    - ``meta.json``: the settings needed to reopen the memmap

    Any ``dtype`` other than "float16" is treated as float32.

    Returns:
        (X_emb, Y, A, meta) where X_emb is the (still open) memmap.
    """
    half_precision = dtype == "float16"
    emb_dtype = np.float16 if half_precision else np.float32
    torch_dtype = torch.float16 if half_precision else torch.float32
    model_input_keys = ["input_ids", "attention_mask", "token_type_ids"]

    os.makedirs(out_dir, exist_ok=True)

    # (a) merged train+test+dev dataset
    full_ds, label_names, group_names = load_biasbios_full_dataset(
        data_dir=data_dir,
        add_sensitive_attribute=add_sensitive_attribute
    )

    # (b) tokenize once up front; padding is left to the collator
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

    def tok_fn(examples):
        return tokenizer(
            examples["bio"],
            truncation=True,
            max_length=max_length,
            padding=False,
        )

    tokenized = full_ds.map(tok_fn, batched=True, remove_columns=["bio"], desc="Tokenizing(full)")

    # labels / groups, one entry per example
    Y = np.asarray(tokenized["title"], dtype=np.int64)
    A = np.asarray(tokenized["gender"], dtype=np.int64)

    # (c) frozen encoder in eval mode
    model = AutoModel.from_pretrained(model_name)
    model.to(device)
    model.eval()
    for param in model.parameters():
        param.requires_grad_(False)

    loader = DataLoader(
        tokenized,
        batch_size=batch_size,
        shuffle=False,   # row i of the memmap must match Y[i] / A[i]
        num_workers=num_workers,
        collate_fn=DataCollatorWithPadding(tokenizer=tokenizer),
        pin_memory=True
    )

    # (d) output buffer: N rows of hidden size H
    N = len(tokenized)
    H = model.config.hidden_size
    X_emb = np.memmap(os.path.join(out_dir, "X_emb.dat"), mode="w+", dtype=emb_dtype, shape=(N, H))

    # (e) forward pass + mean pooling, written batch by batch
    offset = 0
    with torch.no_grad():
        for raw_batch in loader:
            # keep only the tensors the encoder takes; title/gender stay behind
            inputs = {name: tensor.to(device) for name, tensor in raw_batch.items()
                      if name in model_input_keys}

            hidden = model(**inputs).last_hidden_state.to(torch_dtype)    # [B,L,H]
            pooled = mean_pool(hidden, inputs["attention_mask"])          # [B,H]

            stop = offset + pooled.size(0)
            X_emb[offset:stop] = pooled.detach().cpu().numpy().astype(emb_dtype, copy=False)
            offset = stop

    X_emb.flush()

    # (f) labels, groups and metadata
    np.save(os.path.join(out_dir, "Y.npy"), Y)
    np.save(os.path.join(out_dir, "A.npy"), A)
    meta = {
        "model_name": model_name,
        "max_length": max_length,
        "pooling": "mean",
        "dtype": dtype,
        "N": int(N),
        "H": int(H),
        "n_class": int(len(label_names)),  # 28
        "n_group": int(len(group_names)),  # 2
        "add_sensitive_attribute": bool(add_sensitive_attribute),
    }
    with open(os.path.join(out_dir, "meta.json"), "w", encoding="utf-8") as f:
        json.dump(meta, f, ensure_ascii=False, indent=2)

    return X_emb, Y, A, meta


def adult_process():
    """Preprocess the raw Adult train/test files and write them out as CSV.

    Reads ``adult.data`` and ``adult.test`` from ``data/adult/raw_data``, runs
    both through ``process_adult_csv`` and stores the results as ``train.csv``
    and ``test.csv`` in the same directory, with identical column order.
    """
    label_name = 'salary'
    sensitive_attributes = ['sex']
    categorical_attributes = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'native-country']
    continuous_attributes = ["age", "fnlwgt", "education-num", "capital-gain", "capital-loss", "hours-per-week"]
    # Column names of the header-less raw files; all of them are kept.
    features_to_keep = ['age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status',
                'occupation', 'relationship', 'race', 'sex', 'capital-gain', 'capital-loss','hours-per-week',
                'native-country', 'salary']

    shared_args = (sensitive_attributes, [' Female'], categorical_attributes,
                   continuous_attributes, features_to_keep)
    shared_kwargs = dict(na_values=[], header=None, columns=features_to_keep)

    adult = process_adult_csv('data/adult/raw_data/adult.data', label_name, ' >50K',
                              *shared_args, **shared_kwargs)
    # Labels in the test file carry a trailing dot.
    # The distribution is very different from the training distribution.
    test = process_adult_csv('data/adult/raw_data/adult.test', label_name, ' >50K.',
                             *shared_args, **shared_kwargs)

    # Add the one dummy column expected to be absent from the test split,
    # then align the column order with the training frame.
    test['native-country_ Holand-Netherlands'] = 0
    test = test[adult.columns]

    adult.to_csv('data/adult/raw_data/train.csv', index=None)
    test.to_csv('data/adult/raw_data/test.csv', index=None)
    
def process_adult_csv(filename, label_name, favorable_class, sensitive_attributes, privileged_classes, categorical_attributes, continuous_attributes, features_to_keep, na_values = [], header = 'infer', columns = None):
    """
    from https://github.com/yzeng58/Improving-Fairness-via-Federated-Learning/blob/main/FedFB/DP_load_dataset.py
    process the adult file: scale, one-hot encode

    Args:
        filename: CSV path; if it ends with 'test' the first line is skipped.
        label_name: label column; becomes 1 where it equals ``favorable_class``
            and 0 otherwise.
        favorable_class: raw label value of the positive class (including the
            leading space used in the Adult files, e.g. ' >50K').
        sensitive_attributes, privileged_classes: accepted for interface
            compatibility but not used; the 'sex' column is always encoded as
            ' Male' -> 0, ' Female' -> 1.
        categorical_attributes: columns that are one-hot encoded.
        continuous_attributes: columns min-max scaled to [0, 1]; a constant
            column becomes all zeros.
        features_to_keep: columns retained before encoding.
        na_values, header: forwarded to ``pandas.read_csv``.
        columns: column names assigned when ``header`` is None.

    Returns:
        The processed DataFrame.
    """
    skiprows = 1 if filename.endswith('test') else 0
    df = pd.read_csv(filename, delimiter = ',', header = header, na_values = na_values, skiprows=skiprows)
    if header is None: df.columns = columns
    df = df[features_to_keep]

    # apply one-hot encoding to convert the categorical attributes into vectors
    df = pd.get_dummies(df, columns = categorical_attributes)

    # normalize numerical attributes to the range within [0, 1]
    def scale(vec):
        minimum = vec.min()
        span = vec.max() - minimum
        if span == 0:
            # constant column: avoid 0/0 and map every value to 0
            return (vec - minimum).astype(float)
        return (vec - minimum) / span

    df[continuous_attributes] = df[continuous_attributes].apply(scale, axis = 0)

    # Binary label computed directly from the comparison, so the encoding does
    # not depend on which classes happen to be present in this file.
    df[label_name] = (df[label_name] == favorable_class).astype('int8')
    df['sex'] = df['sex'].map({' Male':0, ' Female':1}).astype('category')
    return df


def compas_1_data_processing(sensitive='sex-race', random_state=None):
    """Load, encode and shuffle the COMPAS two-year recidivism data.

    Args:
        sensitive: ``'sex-race'`` for four groups (0: male non-black,
            1: female non-black, 2: male black, 3: female black) or ``'race'``
            for two groups (0: non-black, 1: black).
        random_state: forwarded to ``sklearn.utils.shuffle``; the default None
            keeps the previous unseeded shuffling.

    Returns:
        X_compas: float64 feature matrix without the sensitive one-hot columns
            (race and sex columns for 'sex-race', race columns for 'race').
        y_compas: 1-D array with the ``two_year_recid`` label.
        A_compas: 1-D int64 array with the group id of every row.

    Raises:
        ValueError: If ``sensitive`` is not one of the supported values.
    """
    # Fail fast instead of hitting an unbound variable after all the processing.
    if sensitive == 'sex-race':
        sensitive_attributes = ['sex_Female', 'race_African-American']
    elif sensitive == 'race':
        sensitive_attributes = ['race_African-American']
    else:
        raise ValueError(
            "Unsupported sensitive attribute {!r}; expected 'sex-race' or 'race'".format(sensitive))

    def get_data():
        data_path = "data/compas/raw_data/compas-scores-two-years.csv"
        df = pd.read_csv(data_path)
        FEATURES = [
            'age', 'c_charge_degree', 'race', 'age_cat', 'score_text', 'sex',
            'priors_count', 'days_b_screening_arrest', 'decile_score', 'is_recid',
            'two_year_recid'
        ]
        df = df[FEATURES]
        # Row filters: screening within 30 days of arrest, known recidivism flag,
        # no 'O' charge degree and a valid score text.
        df = df[df.days_b_screening_arrest <= 30]
        df = df[df.days_b_screening_arrest >= -30]
        df = df[df.is_recid != -1]
        df = df[df.c_charge_degree != 'O']
        df = df[df.score_text != 'N/A']

        continuous_to_categorical_features = ['age', 'decile_score', 'priors_count']
        categorical_features = ['c_charge_degree', 'race', 'score_text', 'sex']

        def bucketize_continuous_column(input_df, continuous_column_name, bins=None):
            # Replace the column (in place) by the integer index of its bin.
            input_df[continuous_column_name] = pd.cut(
                input_df[continuous_column_name], bins, labels=False)

        for c in continuous_to_categorical_features:
            # Bin edges are percentiles of the column itself.
            b = [0] + list(np.percentile(df[c], [20, 40, 60, 80, 90, 100]))
            if c == 'priors_count':
                b = list(np.percentile(df[c], [0, 50, 70, 80, 90, 100]))
            bucketize_continuous_column(df, c, bins=b)

        # One-hot encode the categorical and the bucketized columns.
        df = pd.get_dummies(
            df,
            columns=categorical_features + continuous_to_categorical_features)

        # Cumulative encoding: bucket i is set whenever bucket i or any higher
        # bucket is set.
        to_fill = [
            u'decile_score_0', u'decile_score_1', u'decile_score_2',
            u'decile_score_3', u'decile_score_4', u'decile_score_5'
        ]
        for i in range(len(to_fill) - 1):
            df[to_fill[i]] = df[to_fill[i:]].max(axis=1)

        to_fill = [
            u'priors_count_0.0', u'priors_count_1.0', u'priors_count_2.0',
            u'priors_count_3.0', u'priors_count_4.0'
        ]
        for i in range(len(to_fill) - 1):
            df[to_fill[i]] = df[to_fill[i:]].max(axis=1)

        print(df.columns)
        features = [
            u'days_b_screening_arrest', u'c_charge_degree_F', u'c_charge_degree_M',
            u'race_African-American', u'race_Asian', u'race_Caucasian',
            u'race_Hispanic', u'race_Native American', u'race_Other',
            u'score_text_High', u'score_text_Low', u'score_text_Medium',
            u'sex_Female', u'sex_Male', u'age_0', u'age_1', u'age_2', u'age_3',
            u'age_4', u'age_5', u'decile_score_0', u'decile_score_1',
            u'decile_score_2', u'decile_score_3', u'decile_score_4',
            u'decile_score_5', u'priors_count_0.0', u'priors_count_1.0',
            u'priors_count_2.0', u'priors_count_3.0', u'priors_count_4.0'
        ]

        label = ['two_year_recid']

        df = df[features + label]
        return df, features, label

    df, feature_names, label_column = get_data()

    from sklearn.utils import shuffle
    df = shuffle(df, random_state=random_state)

    # Explicit float cast: dummy columns may be boolean, and mixing them with
    # the float column would otherwise give an object array.
    X_compas = df[feature_names].astype(np.float64).to_numpy()
    y_compas = np.array(df[label_column]).flatten()

    if sensitive == 'sex-race':

        # 0: male non-black, 1: female non-black, 2: male black, 3: female black
        A_compas = (df[sensitive_attributes[0]].astype(np.int64)
                    + 2 * df[sensitive_attributes[1]].astype(np.int64)).to_numpy()

        sex_race_idx = [i for i, value in enumerate(feature_names) if value.startswith(('race', 'sex'))]
        X_compas = np.delete(X_compas, sex_race_idx, axis=1)

        print(X_compas.shape)

    else:
        # sensitive == 'race' -> 0: non-black, 1: black
        A_compas = df[sensitive_attributes[0]].to_numpy().astype(np.int64)

        sen_idx = [i for i, value in enumerate(feature_names) if value.startswith('race')]
        X_compas = np.delete(X_compas, sen_idx, axis=1)

    print("compas process end.")

    return X_compas, y_compas,  A_compas

class CelebAMMapDataset(Dataset):
    """Dataset yielding (transformed image, label, sensitive value) per path.

    ``image_dict`` maps an image file name (e.g. '000001.jpg') to a tuple
    ``(y, a, other)``. With ``multiclass=True`` the label is ``y * 2 + other``,
    otherwise it is ``y``.
    """

    def __init__(self, image_paths, image_dict, transform, multiclass=False):
        self.image_paths, self.image_dict = image_paths, image_dict
        self.transform, self.multiclass = transform, multiclass

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        # Annotations are keyed by the bare file name, not the full path.
        target, group, secondary = self.image_dict[os.path.basename(img_path)]

        pixels = self.transform(Image.open(img_path).convert("RGB"))

        label = target * 2 + secondary if self.multiclass else target

        # Plain ints are returned; the DataLoader batches/stacks them.
        return pixels, int(label), int(group)


def celeba_data_processing(sensitive_attr, batch_size=256, mmap_file="processed_data.mmap", multiclass=False, num_workers=4):
    """Preprocess all CelebA images into a float16 memmap plus label arrays.

    Args:
        sensitive_attr: attribute name, or a list of exactly two names whose
            binary values are combined as ``first + 2 * second``.
        batch_size, num_workers: DataLoader settings for the image pass.
        mmap_file: path of the memmap file that receives the images.
        multiclass: if True the label is ``Smiling * 2 + Big_Nose`` instead of
            ``Smiling``.

    Returns:
        X: memmap of shape (N, 3, 128, 128), float16.
        Y, A: int64 arrays of shape (N, 1) with labels and sensitive values.
    """
    celeba_root = os.path.join('data', 'celeba')
    raw_root = os.path.join(celeba_root, 'raw_data')
    os.makedirs(os.path.join(celeba_root, 'processed_data'), exist_ok=True)

    # ---------- read attributes ----------
    with open(os.path.join(raw_root, 'list_attr_celeba.txt'), 'r', encoding='utf-8') as f:
        attributes = f.read().splitlines()

    def as_bit(token):
        # the attribute file stores -1 / 1; map them to 0 / 1
        return (int(token) + 1) // 2

    # second line of the file holds the attribute names
    header = attributes[1].split()
    target_idx = header.index('Smiling')
    other_idx = header.index('Big_Nose')

    two_attrs = isinstance(sensitive_attr, list)
    if two_attrs:
        assert len(sensitive_attr) == 2
        sen_idx = [header.index(sen) for sen in sensitive_attr]
    else:
        sen_idx = header.index(sensitive_attr)

    # file name -> (target, sensitive, other target)
    image = {}
    for line in attributes[2:]:
        fields = line.split()
        if not fields:
            continue
        image_id, vals = fields[0], fields[1:]

        tar_img = as_bit(vals[target_idx])
        other_img_val = as_bit(vals[other_idx])
        if two_attrs:
            sen_img = as_bit(vals[sen_idx[0]]) + 2 * as_bit(vals[sen_idx[1]])
        else:
            sen_img = as_bit(vals[sen_idx])

        image[image_id] = (tar_img, sen_img, other_img_val)

    # ---------- list images (sorted so row order is reproducible) ----------
    images_path = Path(os.path.join(raw_root, 'img_align_celeba'))
    images_ids = [str(p) for p in sorted(images_path.glob('*.jpg'))]
    N = len(images_ids)
    assert N > 0, f"[ERROR] No jpg found in: {images_path.resolve()}"

    # sanity check: the first image must have an entry in the attribute table
    test_id = os.path.basename(images_ids[0])
    assert test_id in image, f"[ERROR] image id {test_id} not found in attribute dict. Check list_attr_celeba.txt"

    # ---------- transform ----------
    transform = transforms.Compose([
        transforms.CenterCrop((178, 178)),
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # ---------- DataLoader does the parallel decoding ----------
    dl = DataLoader(
        CelebAMMapDataset(images_ids, image, transform, multiclass=multiclass),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=(num_workers > 0),
        drop_last=False,
    )

    # ---------- output buffers ----------
    X = np.memmap(mmap_file, dtype=np.float16, mode='w+', shape=(N, 3, 128, 128))
    Y = np.empty((N,), dtype=np.int64)
    A = np.empty((N,), dtype=np.int64)

    print(f"start. N={N}, batch_size={batch_size}, num_workers={num_workers}")

    filled = 0
    for xb, yb, ab in dl:
        stop = filled + xb.size(0)
        # xb: torch float32 -> numpy float16
        X[filled:stop] = xb.numpy().astype(np.float16)
        Y[filled:stop] = np.asarray(yb, dtype=np.int64)
        A[filled:stop] = np.asarray(ab, dtype=np.int64)
        filled = stop

    assert filled == N, f"[ERROR] filled {filled} != {N}"

    X.flush()

    # column vectors, as returned to the callers
    Y = Y.reshape(-1, 1)
    A = A.reshape(-1, 1)

    print("end.")
    print("X shape:", X.shape, "dtype:", X.dtype)
    print("Y shape:", Y.shape, "A shape:", A.shape)

    unique, counts = np.unique(Y, return_counts=True)
    print("Value counts:", dict(zip(unique.tolist(), counts.tolist())))

    return X, Y, A


# class CelebAProcessedDataset(Dataset):
#     def __init__(self, image_paths, image_dict, transform, multiclass=False):
#         self.image_paths = image_paths
#         self.image_dict = image_dict
#         self.transform = transform
#         self.multiclass = multiclass

#     def __len__(self):
#         return len(self.image_paths)

#     def __getitem__(self, idx):
#         p = self.image_paths[idx]
#         img_id = os.path.basename(p)  # e.g. '000001.jpg'

#         y, a, other = self.image_dict[img_id]

#         img = Image.open(p).convert("RGB")
#         x = self.transform(img)  # torch.float32, (3,128,128)

#         if self.multiclass:
#             y = y * 2 + other

#         return x, int(y), int(a)


# def celeba_data_processing(sensitive_attr, batch_size=32, mmap_file="processed_data.mmap", multiclass=False):

#     path = os.path.join('data','celeba', 'processed_data')
#     # file_name = os.path.join(path,f'num={sample_num}_multiclass={multiclass}_celeba.npy')
#     # if os.path.exists(file_name):
#     #     loaded_data = np.load(file_name, allow_pickle=True).item()
#     #     X = loaded_data['X']
#     #     Y = loaded_data['Y']
#     #     A = loaded_data['A']
#     # else:
#     if not os.path.exists(path):
#         mkdir(path)

#     f_identities = open(os.path.join('data', 'celeba', 'raw_data', 'identity_CelebA.txt'), 'r')
#     identities = f_identities.read().split('\n')

#     f_attributes = open(os.path.join('data', 'celeba', 'raw_data', 'list_attr_celeba.txt'), 'r')
#     attributes = f_attributes.read().split('\n')

#     tar = 'Smiling'
#     other_tar = 'Big_Nose'
#     sen_attr = sensitive_attr

#     target_idx = attributes[1].split().index(tar)
#     other_idx  = attributes[1].split().index(other_tar)
#     if isinstance(sen_attr, list):
#         assert len(sen_attr) == 2
#         sen_idx = [attributes[1].split().index(sen) for sen in sen_attr]
#     elif isinstance(sen_attr, str):
#         sen_idx = attributes[1].split().index(sen_attr)

#     image = {}

#     for line in attributes[2:]:
#         info = line.split()
#         if not info:
#             continue
#         image_id = info[0]
#         tar_img = (int(info[1:][target_idx]) + 1) / 2
#         other_img_val = (int(info[1:][other_idx])  + 1) / 2
#         if isinstance(sen_attr, list):
#             sen_img1 = (int(info[1:][sen_idx[0]]) + 1) / 2
#             sen_img2 = (int(info[1:][sen_idx[1]]) + 1) / 2
#             sen_img = sen_img1 + 2 * sen_img2
#         elif isinstance(sen_attr, str):
#             sen_img = (int(info[1:][sen_idx]) + 1) / 2

#         # image[image_id] = tar_img, sen_img
#         image[image_id] = (tar_img, sen_img, other_img_val)

#     images_path = Path(os.path.join('data', 'celeba', 'raw_data', 'img_align_celeba'))
#     images_list = list(images_path.glob('*.jpg'))
#     images_list_str = [str(x) for x in images_list]
#     images_ids = images_list_str
#     # images_ids = random.sample(images_list_str, sample_num)

#     sample_target = []
#     sample_sensitive = []
#     sample_other_target = []
#     for path in images_ids:
#         sample_target.append(image[path[-10:]][0])
#         sample_sensitive.append(image[path[-10:]][1])
#         sample_other_target.append(image[path[-10:]][2])

#     transform = transforms.Compose([
#         transforms.CenterCrop((178, 178)), 
#         transforms.Resize((128, 128)), 
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 
#     ])

#     print('start.')

#     shape = (len(images_ids), 3, 128, 128)
#     X = np.memmap(mmap_file, dtype=np.float16, mode='w+', shape=shape)

#     for i in range(0, len(images_ids), batch_size):
#         batch_ids = images_ids[i:i+batch_size]
#         sample_target_batch = [image[path[-10:]][0] for path in batch_ids]
#         sample_sensitive_batch = [image[path[-10:]][1] for path in batch_ids]
#         sample_other_target_batch = [image[path[-10:]][2] for path in batch_ids]

#         mp_img_loader = multiprocess_img_load(batch_ids, transform)
#         batch_X = mp_img_loader.get_imgs().astype(np.float32)

#         X[i:i+len(batch_X)] = batch_X
#         sample_target[i:i+len(batch_X)] = sample_target_batch
#         sample_sensitive[i:i+len(batch_X)] = sample_sensitive_batch
#         sample_other_target[i:i+len(batch_X)] = sample_other_target_batch

#     print(X.shape)
#     print(type(X))
#     print(np.max(X),np.min(X))

#     Y, A = np.array(sample_target,dtype=np.float16), np.array(sample_sensitive, dtype=np.float16)

#     if multiclass:
#         other_target = np.array(sample_other_target,dtype=np.float16)
#         Y = Y * 2 + other_target
#     print('end.')

#     # data_dict = {'X': X, 'Y': Y, 'A': A}
#     # np.save(file_name, data_dict)
#     # X.flush()
#     # del X
#     # time.sleep(1)
#     # # os.remove(mmap_file)  # delete memmap 

#     # loaded_data = np.load(file_name, allow_pickle=True).item()
#     # X = loaded_data['X']
#     # Y = loaded_data['Y']
#     # A = loaded_data['A']

#     Y, A = Y.reshape(-1,1), A.reshape(-1,1)
#     # print shape
#     print("Y shape:", Y.shape)

#     # statistics
#     unique, counts = np.unique(Y, return_counts=True)
#     print("Value counts:", dict(zip(unique, counts)))

    
#     return X, Y, A


class multiprocess_img_load(object):
    """Load and transform a list of images using several worker threads.

    Every worker opens its share of ``img_paths`` with ``Image.open``,
    applies ``transform`` and stores the result in a shared buffer at the
    position the image has in ``img_paths``.

    Args:
        img_paths: Paths of the images to load.
        transform: Callable applied to each opened image; its result must
            provide ``.numpy()`` returning an array of shape ``img_size``.
        img_size: Shape of a single transformed image.
        n_thread: Number of worker threads. Defaults to
            ``cpu_count() - 2`` but never less than one.
    """

    def __init__(self, img_paths:list, transform, img_size=(3,128,128), n_thread=None) -> None:
        self.image_paths = img_paths
        self.img_size = img_size
        self.num_img = len(img_paths)
        self._mutex_put = threading.Lock()
        self.n_thread = n_thread if (n_thread is not None) else max(1, multiprocessing.cpu_count() - 2)
        self.transform = transform

    def _split_indices(self):
        """Return contiguous, non-empty index chunks covering every image.

        The previous ``round(num_img / n_thread)`` chunk size could leave the
        trailing images unassigned (e.g. 10 images on 4 threads) or produce
        empty chunks; ``np.array_split`` always covers ``range(num_img)``.
        """
        chunks = np.array_split(np.arange(self.num_img), max(1, self.n_thread))
        return [chunk.tolist() for chunk in chunks if len(chunk) > 0]

    def get_imgs(self):
        """Load all images and return them as one array.

        Returns:
            Array of shape ``(num_img, *img_size)``; row ``i`` holds the
            transformed image of ``img_paths[i]``.
        """
        self._buffer = np.zeros([self.num_img] + list(self.img_size))
        workers = []
        for img_ids in self._split_indices():
            img_target = [self.image_paths[i] for i in img_ids]
            worker = threading.Thread(target=self.load_image, args=(img_target, img_ids))
            workers.append(worker)
            worker.start()

        for worker in workers:
            worker.join()

        return self._buffer

    def _load_one(self, img_name):
        """Open one image, transform it and return it with a leading batch axis."""
        # The context manager closes the underlying file as soon as the
        # transform has consumed the pixel data.
        with Image.open(img_name) as img:
            return np.expand_dims(self.transform(img).numpy(), axis=0)

    def load_image(self, img_names, img_ids):
        """Load ``img_names`` and write them to the buffer rows ``img_ids``."""
        if not img_names:
            # np.vstack raises on an empty sequence; nothing to do here.
            return
        batch_images = np.vstack([self._load_one(name) for name in img_names])
        with self._mutex_put:
            self._buffer[img_ids] = batch_images

def celeba_split(data_indices, X ,Y ,A):
    """Slice ``X``, ``Y`` and ``A`` into per-client records.

    Args:
        data_indices: Sequence whose ``i``-th entry holds the row indices
            assigned to client ``i``.
        X: Feature array, indexed as ``X[rows, :]``.
        Y: Label array, indexed as ``Y[rows]``.
        A: Sensitive-attribute array, indexed as ``A[rows]``.

    Returns:
        Dict with ``'users'`` (client ids ``'0'``, ``'1'``, ...),
        ``'user_data'`` (client id -> ``{'x', 'y', 'A'}``) and
        ``'num_samples'`` (number of rows per client, in client order).
    """
    users = []
    user_data = {}
    num_samples = []
    for client_no in range(len(data_indices)):
        rows = data_indices[client_no]
        client_key = str(client_no)
        users.append(client_key)
        user_data[client_key] = {'x': X[rows, :], 'y': Y[rows], 'A': A[rows]}
        num_samples.append(len(rows))
    return {'users': users, 'user_data': user_data, 'num_samples': num_samples}

def get_unsaved_data(data_split):
    """Replace every client's raw record with a ``FairDataset``, in place.

    Each record in ``data_split['user_data']`` is read through its ``"x"``,
    ``"y"`` and ``"A"`` entries, converted to float32 arrays (``y`` and ``A``
    reshaped to a single column) and wrapped in a ``FairDataset``.

    Returns:
        The same ``data_split`` object that was passed in.
    """
    user_data = data_split['user_data']
    for client in user_data:
        record = user_data[client]
        features = np.array(record["x"]).astype(np.float32)
        labels = np.array(record["y"]).astype(np.float32).reshape(-1, 1)
        groups = np.array(record["A"]).astype(np.float32).reshape(-1, 1)
        # Only the value of an existing key changes, so mutating the dict
        # while iterating over it is safe.
        user_data[client] = FairDataset(features, labels, groups)
    return data_split

def bank_get_sensitive_feature(X, colname, sensitive_attr):
    """Extract the sensitive column from the bank feature matrix.

    Args:
        X: 2-D feature array whose columns are described by ``colname``.
        colname: List of column names, in the column order of ``X``.
        sensitive_attr: Name of the sensitive attribute; only ``'age'`` is
            supported.

    Returns:
        ``(X, A)`` where ``A`` is the sensitive column and ``X`` is the
        feature array with that column removed.

    Raises:
        ValueError: If ``sensitive_attr`` is not ``'age'`` (previously this
            surfaced as an ``UnboundLocalError`` on the return statement).
    """
    if sensitive_attr != 'age':
        raise ValueError(
            f"unsupported sensitive attribute for bank data: {sensitive_attr!r}"
        )
    attr_idx = colname.index(sensitive_attr)
    A = X[:, attr_idx]
    X = np.delete(X, attr_idx, axis=1)
    return X, A

def compas_get_sensitive_feature(X, colname, sensitive_attr):
    """Derive the sensitive attribute from one-hot COMPAS columns.

    Columns whose names start with ``'sex'`` / ``'race'`` are treated as the
    one-hot encoding of the respective attribute; the attribute value is the
    argmax over those columns.

    Args:
        X: 2-D feature array whose columns are described by ``colname``.
        colname: List of column names, in the column order of ``X``.
        sensitive_attr: One of
            ``'sex'``      - argmax of the sex columns, columns removed;
            ``'race'``     - argmax of the race columns binarised to
                             ``{0, 1}`` (every index >= 1 becomes 1),
                             columns removed;
            ``'non-sex'``  - argmax of the sex columns, ``X`` unchanged;
            ``'non-race'`` - argmax of the race columns (not binarised),
                             ``X`` unchanged.

    Returns:
        ``(X, A)``.

    Raises:
        ValueError: If ``sensitive_attr`` is none of the values above.
    """
    sex_idx = [i for i, col in enumerate(colname) if col.startswith('sex')]
    # Mirrors the original if/elif: a column only counts as a sex column when
    # it is not already a race column (the prefixes cannot overlap anyway).
    race_idx = [i for i, col in enumerate(colname) if col.startswith('race')]

    if sensitive_attr == 'sex':
        A = np.argmax(X[:, sex_idx], axis=1)  # [1: Male, 0: Female]
        X = np.delete(X, sex_idx, axis=1)
    elif sensitive_attr == 'race':
        # ['African-American': 0,'Caucasian': 1,'Asian':2,'Hispanic':3]
        A = np.argmax(X[:, race_idx], axis=1)
        A[A >= 1] = 1
        X = np.delete(X, race_idx, axis=1)
    elif sensitive_attr == 'non-sex':
        # Fixed: the keyword was misspelled ``axi`` and raised a TypeError.
        A = np.argmax(X[:, sex_idx], axis=1)
    elif sensitive_attr == 'non-race':
        A = np.argmax(X[:, race_idx], axis=1)
    else:
        raise ValueError(
            f"unsupported sensitive attribute for compas data: {sensitive_attr!r}"
        )
    return X, A

def split_celeba_data(ids: list):
    """Read CelebA images by file name and stack them channel-first.

    Each name in ``ids`` is appended to the aligned-image directory, opened
    with ``Image.open``, converted to an array and transposed with axes
    ``(2, 0, 1)``.

    Returns:
        One array with the images concatenated along a new leading axis.
    """
    path = 'data/celeba/raw_data/img_align_celeba/'
    channel_first = []
    for img_name in ids:
        pixels = np.array(Image.open(path + img_name))
        channel_first.append(pixels.transpose(2, 0, 1)[np.newaxis, ...])
    imgs = np.concatenate(channel_first, axis=0)

    return imgs

def partition_test_data(separation, targets):
    """Distribute sample indices over clients, label by label.

    Args:
        separation: ``separation[k]`` lists, per client, the fraction of the
            samples with label ``k`` that the client receives.
        targets: Label of every sample; labels are compared against
            ``0 .. len(set(targets)) - 1``.

    Returns:
        List with one int64 index array per client.
    """
    label_num = len(set(targets))
    targets_numpy = np.array(targets, dtype=np.int32)
    client_count = len(separation[0])
    data_indices = [[] for _ in range(client_count)]

    for label in range(label_num):
        label_rows = np.where(targets_numpy == label)[0]
        # Cumulative fractions become cut positions; the final one (the end
        # of the array) is dropped because np.split adds it implicitly.
        cut_points = (np.cumsum(separation[label]) * len(label_rows)).astype(int)[:-1]
        pieces = np.split(label_rows, cut_points)

        merged = []
        for existing, piece in zip(data_indices, pieces):
            combined = np.concatenate((existing, piece.tolist()))
            merged.append(combined.astype(np.int64))
        data_indices = merged

    return data_indices

def split(X ,Y ,A, prop=None):
    """Randomly split the data 40% / 20% / 40% into train / val / test sets.

    ``Y`` and ``A`` are reshaped to single columns. Their distinct values
    must be exactly ``0 .. n_class - 1`` and ``0 .. n_group - 1``; this is
    checked with an ``assert``. ``prop`` is accepted but not used.

    Returns:
        ``(train_data, val_data, test_data, n_group, n_class)`` where the
        three datasets are ``FairDataset`` instances.
    """
    n = X.shape[0]
    Y, A = Y.reshape(-1, 1), A.reshape(-1, 1)

    group_values = np.unique(A)
    class_values = np.unique(Y)
    n_group = len(group_values)
    n_class = len(class_values)
    groups_ok = np.array_equal(group_values, np.arange(n_group))
    classes_ok = np.array_equal(class_values, np.arange(n_class))
    assert groups_ok and classes_ok

    indices = np.random.permutation(n)
    train_end, val_end, test_end = int(n*0.4), int(n*0.6), int(n*1)
    parts = (indices[:train_end], indices[train_end:val_end], indices[val_end:test_end])
    train_data, val_data, test_data = (
        FairDataset(X[rows, :], Y[rows, :], A[rows, :]) for rows in parts
    )
    return train_data, val_data, test_data, n_group, n_class


def adult_get_sensitive_feature(X, colname, sensitive, Y=None):
    """Derive the sensitive attribute from the encoded Adult feature matrix.

    Columns whose names start with ``'race'`` are treated as the one-hot
    race encoding; the column named ``'sex'`` holds the sex indicator.

    Args:
        X: 2-D feature array whose columns are described by ``colname``.
        colname: List of column names, in the column order of ``X``.
        sensitive: One of
            ``'race'``      - ``A`` is the ``'race_ White'`` column, all
                              race columns are removed;
            ``'sex'``       - ``A`` is the sex column, which is removed;
            ``'none-race'`` - ``A`` is the argmax over the race columns,
                              ``X`` unchanged;
            ``'none-sex'``  - ``A`` is the sex column, ``X`` unchanged;
            ``'sex-race'``  - rows belonging to the Amer-Indian-Eskimo,
                              Asian-Pac-Islander or Other race columns are
                              dropped from ``X`` and ``Y``; ``A`` is
                              ``argmax(race columns) + sex - 2`` and the
                              race and sex columns are removed.
        Y: Labels; required (and filtered) only for ``'sex-race'``.

    Returns:
        ``(X, A)``, or ``(X, A, Y)`` for ``'sex-race'``.

    Raises:
        ValueError: If ``sensitive`` is not one of the values above, or if
            ``'sex-race'`` is requested without ``Y``. (The unknown-value
            case used to print a message and call ``exit()``, which ended
            the process with status 0.)
    """
    sex_attr = 'sex'
    race_attr = [col for col in colname if col.startswith('race')]

    if sensitive == "race":
        attr_idx = colname.index('race_ White')
        A = np.array(X[:, attr_idx])
        del_idx = [colname.index(attr) for attr in race_attr]
        X = np.delete(X, del_idx, axis=1)
    elif sensitive == "sex":
        attr_idx = colname.index(sex_attr)
        A = X[:, attr_idx]  # [1: female, 0: male]
        X = np.delete(X, attr_idx, axis=1)
    elif sensitive == "none-race":
        attr_idx = [colname.index(attr) for attr in race_attr]
        A = np.argmax(X[:, attr_idx], axis=1)
    elif sensitive == "none-sex":
        attr_idx = colname.index(sex_attr)
        A = X[:, attr_idx]  # [1: female, 0: male]
    elif sensitive == "sex-race":
        if Y is None:
            raise ValueError("Y is required when sensitive == 'sex-race'")
        race_idx = [colname.index(attr) for attr in race_attr]
        race_unused = [
            colname.index(attr)
            for attr in ['race_ Amer-Indian-Eskimo', 'race_ Asian-Pac-Islander', 'race_ Other']
        ]
        # Keep only the rows that belong to none of the unused race columns;
        # the mask is computed once and applied to both Y and X.
        keep = np.sum(X[:, race_unused], axis=1) == 0
        Y = Y[keep]
        X = X[keep, :]
        sex_idx = colname.index(sex_attr)
        A = (np.argmax(X[:, race_idx], axis=1) + X[:, sex_idx]) - 2
        X = np.delete(X, race_idx + [sex_idx], axis=1)
        return X, A, Y
    else:
        raise ValueError(f"error sensitive attr: {sensitive!r}")

    return X, A

def read_data(path, name=None, sensitive_process=None):
    """Load a per-client data file and build train / val / test datasets.

    The file at ``path`` is read with ``np.load(..., allow_pickle=True)``
    when ``name == 'celeba'``, with ``pickle.load`` when ``name == 'enem'``
    and with ``json.load`` otherwise. It must provide ``'users'`` and, per
    user, ``'user_data'`` records with ``"x"``, ``"y"`` and ``"A"``.
    ``sensitive_process`` is accepted but not used.

    For each client the rows are shuffled; the first 60% form the training
    set and the remaining rows the test set.

    NOTE(review): the validation indices are the same slice as the training
    indices, so every validation set equals the training set. Confirm this
    is intended.

    Returns:
        ``(split_train, split_val, split_test)``; each is a dict with
        ``'users'`` (list), ``'user_data'`` (client -> ``FairDataset``) and
        ``'num_samples'`` (client -> number of rows).
    """
    split_train = {'users': [], 'user_data':{}, 'num_samples':{}}
    split_val = copy.deepcopy(split_train)
    split_test = copy.deepcopy(split_train)
    splits = (split_train, split_val, split_test)

    if name == 'celeba':
        data_split = np.load(path, allow_pickle=True).item()
    elif name == 'enem':
        with open(path, 'rb') as f:
            data_split = pickle.load(f)
    else:
        with open(path, 'rb') as file:
            data_split = json.load(file)

    for client in data_split['users']:
        for part in splits:
            part['users'].append(client)

        record = data_split['user_data'][client]
        X = np.array(record["x"]).astype(np.float32)
        Y = np.array(record["y"]).astype(np.float32).reshape(-1,1)
        A = np.array(record["A"]).astype(np.float32).reshape(-1,1)

        n = np.arange(X.shape[0])
        indices = np.random.permutation(n)
        cut = int(len(n)*0.6)
        end = int(len(n)*1)
        # train and val deliberately reproduce the original identical slices.
        index_sets = (indices[:cut], indices[:cut], indices[cut:end])

        for part, rows in zip(splits, index_sets):
            part['user_data'][client] = FairDataset(X[rows,:], Y[rows,:], A[rows,:])
            part['num_samples'][client] = len(rows)

    return split_train,split_val,split_test
    
def celeba_read_data(data_split, name=None, sensitive_process=None):
    """Build train / val / test datasets from an in-memory client split.

    ``data_split`` must provide ``'users'`` and, per user, ``'user_data'``
    records with ``"x"``, ``"y"`` and ``"A"``. ``name`` and
    ``sensitive_process`` are accepted but not used.

    For each client the rows are shuffled; the first 60% form the training
    set and the remaining rows the test set.

    NOTE(review): the validation indices are the same slice as the training
    indices, so every validation set equals the training set. Confirm this
    is intended.

    Returns:
        ``(split_train, split_val, split_test)``; each is a dict with
        ``'users'`` (list), ``'user_data'`` (client -> ``FairDataset``) and
        ``'num_samples'`` (client -> number of rows).
    """
    split_train = {'users': [], 'user_data':{}, 'num_samples':{}}
    split_val = copy.deepcopy(split_train)
    split_test = copy.deepcopy(split_train)

    def _as_column(values):
        # float32 column vector, as the original did for both y and A
        return np.array(values).astype(np.float32).reshape(-1,1)

    for client in data_split['users']:
        split_train['users'].append(client)
        split_val['users'].append(client)
        split_test['users'].append(client)

        record = data_split['user_data'][client]
        X = np.array(record["x"]).astype(np.float32)
        Y = _as_column(record["y"])
        A = _as_column(record["A"])

        n = np.arange(X.shape[0])
        indices = np.random.permutation(n)
        boundary = int(len(n)*0.6)
        train_index = indices[:boundary]
        val_index = indices[:boundary]  # same slice as train_index, as in the original
        test_index = indices[boundary:]

        for target, rows in ((split_train, train_index),
                             (split_val, val_index),
                             (split_test, test_index)):
            target['user_data'][client] = FairDataset(X[rows,:], Y[rows,:], A[rows,:])
            target['num_samples'][client] = len(rows)

    return split_train,split_val,split_test


def acsincome_process(n_classes=2, sensitive_attr='sex', remove_sensitive_attr=True):
    """Load the ACSIncome dataset and turn it into features, labels and groups.

    If the cache ``data/acsincome/acsincome5.pkl`` exists, its content is
    returned directly (with the features converted to a float32 array) and
    the arguments are ignored. Otherwise the raw table is read from
    ``data/acsincome/acs_income.csv`` (downloaded through fairlearn when the
    file is missing) and processed with ``folktables.BasicProblem``.

    Args:
        n_classes: ``2`` thresholds ``PINCP`` at 50000; any other value bins
            ``PINCP`` into ``n_classes`` bins of roughly equal sample count.
        sensitive_attr: ``'sex'`` (mapped to ``'SEX'``), ``'race'`` (mapped
            to ``'RAC1P'``), or a column name used as given.
        remove_sensitive_attr: Drop every feature column whose name starts
            with the sensitive attribute name.

    Returns:
        ``(data, labels, label_names, groups, group_names)`` where ``data``
        is a float32 feature array, ``groups`` holds integer group codes and
        ``group_names[i]`` is the description of code ``i``.
    """
    if sensitive_attr == 'sex':
        sensitive_attr = 'SEX'
    elif sensitive_attr == 'race':
        sensitive_attr = 'RAC1P'

    target = 'PINCP'
    features = [
        'AGEP', 'COW', 'SCHL', 'MAR', 'OCCP', 'POBP', 'RELP', 'WKHP', 'SEX',
        'RAC1P'
    ]
    categories = {
        "COW": {
            1.0: ("Employee of a private for-profit company or"
                "business, or of an individual, for wages,"
                "salary, or commissions"),
            2.0: ("Employee of a private not-for-profit, tax-exempt,"
                "or charitable organization"),
            3.0:
                "Local government employee (city, county, etc.)",
            4.0:
                "State government employee",
            5.0:
                "Federal government employee",
            6.0: ("Self-employed in own not incorporated business,"
                "professional practice, or farm"),
            7.0: ("Self-employed in own incorporated business,"
                "professional practice or farm"),
            8.0:
                "Working without pay in family business or farm",
            9.0:
                "Unemployed and last worked 5 years ago or earlier or never worked",
        },
        "SCHL": {
            1.0: "No schooling completed",
            2.0: "Nursery school, preschool",
            3.0: "Kindergarten",
            4.0: "Grade 1",
            5.0: "Grade 2",
            6.0: "Grade 3",
            7.0: "Grade 4",
            8.0: "Grade 5",
            9.0: "Grade 6",
            10.0: "Grade 7",
            11.0: "Grade 8",
            12.0: "Grade 9",
            13.0: "Grade 10",
            14.0: "Grade 11",
            15.0: "12th grade - no diploma",
            16.0: "Regular high school diploma",
            17.0: "GED or alternative credential",
            18.0: "Some college, but less than 1 year",
            19.0: "1 or more years of college credit, no degree",
            20.0: "Associate's degree",
            21.0: "Bachelor's degree",
            22.0: "Master's degree",
            23.0: "Professional degree beyond a bachelor's degree",
            24.0: "Doctorate degree",
        },
        "MAR": {
            1.0: "Married",
            2.0: "Widowed",
            3.0: "Divorced",
            4.0: "Separated",
            5.0: "Never married or under 15 years old",
        },
        "SEX": {
            1.0: "Male",
            2.0: "Female"
        },
        "RAC1P": {
            1.0: "White alone",
            2.0: "Black or African American alone",
            3.0: "American Indian alone",
            4.0: "Alaska Native alone",
            5.0: ("American Indian and Alaska Native tribes specified;"
                "or American Indian or Alaska Native,"
                "not specified and no other"),
            6.0: "Asian alone",
            7.0: "Native Hawaiian and Other Pacific Islander alone",
            8.0: "Some Other Race alone",
            9.0: "Two or More Races",
        },
    }

    pkl_path = 'data/acsincome/acsincome5.pkl'
    if os.path.exists(pkl_path):
        print(f"Found existing file: {pkl_path}. Loading from disk...")
        # NOTE: pickle executes arbitrary code on load; only use a cache file
        # that was produced locally / comes from a trusted source.
        with open(pkl_path, "rb") as f:
            data, labels, label_names, groups, group_names = pickle.load(f)
            data = data.to_numpy(dtype='float32')
        return data, labels, label_names, groups, group_names
    print("processing continues.")

    # Download or load the dataset
    csv_path = 'data/acsincome/acs_income.csv'
    if os.path.exists(csv_path):
        print(f"Found existing file: {csv_path}. Loading from disk...")
        df = pd.read_csv(csv_path)
    else:
        print(f"{csv_path} not found. Downloading ACSIncome dataset...")
        # Imported here so fairlearn is only required when a download is needed.
        from fairlearn.datasets import fetch_acs_income
        # return pandas DataFrame
        X, y = fetch_acs_income(as_frame=True, return_X_y=True)
        df = X.copy()
        df["PINCP"] = y
        # Make sure the target directory exists before writing the CSV.
        os.makedirs(os.path.dirname(csv_path), exist_ok=True)
        df.to_csv(csv_path, index=False)
        print(f"Downloaded and saved to {csv_path}.")
    print(f"Dataset shape: {df.shape} (rows, columns)")

    if n_classes == 2:
        label_names = ["<=50K", ">50K"]
        target_transform = lambda x: (x > 50000).astype(int)

    else:
        # Compute empirical CDF of PINCP
        x = np.sort(df[target])
        y = np.arange(len(x)) / float(len(x))

        # Partition into bins containing roughly the same number of samples
        partitions = np.array([
            x[np.argmax(y >= q)] for q in np.arange(1 / n_classes, 1, 1 / n_classes)
        ] + [np.inf])

        label_names = [f'[0, {partitions[0]})'] + [
            f'[{partitions[i]}, {partitions[i+1]})'
            for i in range(len(partitions) - 1)
        ]
        target_transform = lambda x: np.argmax(
            np.array(x)[:, None] < partitions[None, :], axis=1)

    if sensitive_attr == 'RAC1P':
        # Combine RAC1P categories 3, 4, 5, and 6, 7, and 8, 9 into new categories
        # 10, 11, and 12 respectively, due to small sample size in some groups.
        # This is also consistent with the UCI Adult dataset.
        categories['RAC1P'][10.0] = "American Indian or Alaska Native alone"
        categories['RAC1P'][
            11.0] = "Asian, Native Hawaiian or Other Pacific Islander alone"
        categories['RAC1P'][12.0] = "Other"
        df['RAC1P'] = df['RAC1P'].replace([3.0, 4.0, 5.0], 10.0)
        df['RAC1P'] = df['RAC1P'].replace([6.0, 7.0], 11.0)
        df['RAC1P'] = df['RAC1P'].replace([8.0, 9.0], 12.0)

    # NOTE(review): the second positional argument of np.nan_to_num is
    # ``copy``, so NaNs are replaced by 0.0 here, not by -1. Left unchanged;
    # confirm which fill value is intended.
    data, labels, groups = folktables.BasicProblem(
      features=features,
      target=target,
      target_transform=target_transform,
      group=sensitive_attr,
      postprocess=lambda x: np.nan_to_num(x, -1),
    ).df_to_pandas(df, categories=categories, dummies=True)

    labels = labels.values.squeeze()
    groups = groups.values.squeeze()

    group_names, groups = np.unique(groups, return_inverse=True)
    group_names = [categories[sensitive_attr][n] for n in group_names]

    if remove_sensitive_attr:
        data.drop(columns=list(data.filter(regex=f'^{sensitive_attr}')),
                inplace=True)

    # Fixed: this used to be ``data = df.values``, which returned the raw
    # table (including the PINCP target and the sensitive column) instead of
    # the processed features. float32 matches the cached branch above.
    data = data.to_numpy(dtype='float32')

    return data, labels, label_names, groups, group_names

    # # === Inline BasicProblem.df_to_pandas functionality ===
    # # Select features and replace categorical codes with descriptions
    # vars_df = df[features].replace(categories)
    # # One-hot encode all categorical columns
    # vars_df = pd.get_dummies(vars_df)
    # # Convert to numpy array, replacing NaNs with -1
    # data = pd.DataFrame(
    #     np.nan_to_num(vars_df.values, nan=-1),
    #     columns=vars_df.columns
    # )

    # # Apply target transformation
    # labels = target_transform(df[target])
    # labels = np.asarray(labels).squeeze()

    # # Extract and encode sensitive groups
    # groups_raw = df[sensitive_attr].values
    # unique_vals, groups = np.unique(groups_raw, return_inverse=True)
    # group_names = [categories[sensitive_attr][val] for val in unique_vals]

    # if remove_sensitive_attr:
    #     data.drop(columns=list(data.filter(regex=f'^{sensitive_attr}')),
    #             inplace=True)
    # return data.to_numpy(), labels, label_names, groups, group_names

def print_statistics_info(train_data, val_data, test_data):
    """Print the sample count and ``data_info`` table of each dataset.

    The three datasets are reported in the order train, validation, test;
    each must support ``len()`` and expose a ``data_info`` attribute.
    """
    sections = (
        ("=== Train Data ===", train_data),
        ("\n=== Validation Data ===", val_data),
        ("\n=== Test Data ===", test_data),
    )
    for heading, dataset in sections:
        print(heading)
        print("Number of samples:", len(dataset))
        print("Info table:")
        print(dataset.data_info)