import os
import sys
import time
import json
import math
import numpy as np
import pandas as pd
from collections import OrderedDict
from typing import Tuple, Sequence, Dict, List, Optional, Mapping, NamedTuple, Any
from sklearn.preprocessing import StandardScaler, MinMaxScaler, LabelEncoder, OneHotEncoder, RobustScaler
from numba import jit, prange


def gen_or_load_shuffle_idx_and_split(
        df: pd.DataFrame,
        idx_path: str,
        split_p: float
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Shuffle ``df`` with a persisted permutation and split it into train / test parts.

    The shuffled index labels are read from ``idx_path`` (resolved against ``sys.path[1]``)
    when that file exists. Otherwise a fresh random permutation of ``df.index`` is drawn and
    written there (creating missing directories), so later runs reproduce the same split.

    Args:
        df: data to split; its index labels must be JSON-serializable (e.g. ints or strings).
        idx_path: path of the JSON file holding the shuffled index labels.
        split_p: fraction of rows assigned to the training part, in [0, 1].

    Returns:
        ``(train_df, test_df)``: independent copies of the first ``round(split_p * N)``
        shuffled rows and of the remaining rows.
    """
    assert 0. <= split_p <= 1., "split_p should be in [0, 1]"

    idx_path = os.path.join(sys.path[1], idx_path)
    if os.path.exists(idx_path):
        with open(idx_path) as f:
            shuffle_idx = json.load(f)
    else:
        shuffle_idx = np.random.permutation(df.index).tolist()
        # create the cache directory on first use instead of failing inside open()
        idx_dir = os.path.dirname(idx_path)
        if idx_dir:
            os.makedirs(idx_dir, exist_ok=True)
        with open(idx_path, "w") as f:
            json.dump(shuffle_idx, f)

    # NOTE: labels stored in the file but absent from df become all-NaN rows here
    df = df.reindex(shuffle_idx)
    n_train = round(split_p * len(df))
    # .iloc keeps the split positional whatever the index dtype is
    train_df = df.iloc[:n_train].copy()
    test_df = df.iloc[n_train:].copy()

    return train_df, test_df


class FeatIndex(NamedTuple):
    """
    Feature index after one-hot encoding.

    Column indices refer to the encoded matrix, in which numerical features come first and
    are followed by the one-hot block of each categorical feature. Field order matters:
    instances are built positionally.
    """
    cat_feat: Sequence[str]  # categorical feature names (sensitive features included)
    num_feat: Sequence[str]  # numerical feature names
    sen_feat: Sequence[str]  # sensitive feature names
    feat2idx: Dict[str, Sequence[int]]  # original feature name -> its encoded column indices
    cat_idx: Sequence[int]  # encoded columns of all categorical features (sensitive included)
    num_idx: Sequence[int]  # encoded columns of numerical features
    sen_idx: Sequence[int]  # encoded columns of sensitive features


class CompSampleIterator():
    """
    One type of comparable samples, served in batches.

    ``comp_idx`` holds two equally long index arrays. Position ``k`` pairs row
    ``comp_idx[0][k]`` of ``X`` with row ``comp_idx[1][k]``. Iteration yields, per batch,
    the two blocks of paired rows in stored order. The last batch may be shorter.
    """

    def __init__(
            self,
            X: np.ndarray,
            y: np.ndarray,
            batch_size: int,
            comp_idx: Tuple[np.ndarray, np.ndarray],
            name: str,
    ):
        self.X = X
        self.y = y
        self._batch_size = batch_size
        self._idx1 = comp_idx[0].tolist()
        self._idx2 = comp_idx[1].tolist()
        self._name = name

        # Positions of pairs whose two samples share a label, split into label 1 and label 0.
        positive, negative = [], []
        for pos, (left, right) in enumerate(zip(self._idx1, self._idx2)):
            if self.y[left] != self.y[right]:
                continue
            if self.y[left] == 1:
                positive.append(pos)
            elif self.y[left] == 0:
                negative.append(pos)
        self._true_idx = tuple(positive)
        self._false_idx = tuple(negative)

    @property
    def cond_idx(self) -> Tuple[Tuple[int], Tuple[int]]:
        """ (pair positions with both labels 1, pair positions with both labels 0) """
        return self._true_idx, self._false_idx

    @property
    def name(self) -> str:
        """ Tag identifying the sensitive condition of this pair family """
        return self._name

    def __len__(self):
        """ Number of batches """
        pair_count = len(self._idx1)
        return math.ceil(pair_count / self._batch_size)

    def __iter__(self) -> Tuple[np.ndarray, np.ndarray]:
        n_batches = len(self)
        for batch in range(n_batches):
            start = batch * self._batch_size
            # the final batch takes whatever pairs are left
            stop = None if batch == n_batches - 1 else start + self._batch_size
            rows_left = self._idx1[start:stop]
            rows_right = self._idx2[start:stop]
            yield self.X.take(rows_left, axis=0), self.X.take(rows_right, axis=0)


class CompData():
    """
    Data loaders yielding comparable samples with various sensitive conditions.

    ``comp_mat_dict`` maps ``"comp_sen_and"``, ``"comp_sen_or"`` and ``"comp_sen_not"`` to
    0/1 (or boolean) matrices. Every cell equal to 1 at ``(i, j)`` becomes one pair
    (row ``i``, row ``j``) of ``X`` in the loader named ``"and"``, ``"or"`` or ``"not"``.
    """

    def __init__(
            self,
            X: np.ndarray,
            y: np.ndarray,
            comp_mat_dict: Dict[str, np.ndarray],
            batch_size: int,
    ):
        def pair_indices(key):
            # (row indices, column indices) of every cell marked comparable for this condition
            return np.where(comp_mat_dict[key] == 1)

        self._idx_sen_and = pair_indices("comp_sen_and")
        self._idx_sen_or = pair_indices("comp_sen_or")
        self._idx_sen_not = pair_indices("comp_sen_not")

        self._and_loader, self._or_loader, self._not_loader = [
            CompSampleIterator(X, y, batch_size, pairs, tag)
            for pairs, tag in (
                (self._idx_sen_and, "and"),
                (self._idx_sen_or, "or"),
                (self._idx_sen_not, "not"),
            )
        ]

    @property
    def loaders(self) -> Sequence[CompSampleIterator]:
        """ The three loaders, in the order and / or / not """
        return (self._and_loader, self._or_loader, self._not_loader)

    # Each accessor below returns a fresh array stacking the two index arrays
    # (shape (2, K) for a 2-D comparability matrix), so callers cannot mutate the stored indices.

    @property
    def idx_sen_and(self):
        return np.copy(self._idx_sen_and)

    @property
    def idx_sen_or(self):
        return np.copy(self._idx_sen_or)

    @property
    def idx_sen_not(self):
        return np.copy(self._idx_sen_not)


class Dataset():
    """
    Base class for dataset with flexible comparable requirements for individual fairness.

    Subclasses load their raw data and pass train/test DataFrames plus a feature schema to
    ``__init__``. Afterwards:

    * ``train_data`` / ``test_data`` return one-hot encoded, optionally scaled ``(X, y)``
      arrays. ``train_data`` must run first because it fits the encoder and scalers.
    * ``comp_data`` returns a ``CompData`` with pairs of comparable samples, grouped by how
      their sensitive features differ.
    """

    def __init__(
            self,
            train_df: pd.DataFrame,
            test_df: pd.DataFrame,
            label_name: str,
            label_mapping: Mapping[str, Mapping[Any, int]],
            categorical_feat: Sequence[str],
            numerical_feat: Sequence[str],
            sensitive_feat: Sequence[str],
            categorical_thr: int,
            numerical_thr: float,
            dataset_name: str,
            drop_feat: Optional[Sequence[str]] = None,
            remove_dup_in_test: bool = False,
    ):
        """
        For two comparable samples
        non-sensitive categorical features differ at most 'categorical_thr'
        every min-max normalized numerical feature differs at most 'numerical_thr'
        and no constraint on 'sensitive_feat'

        Currently only support categorical sensitive feature

        Args:
            train_df, test_df: raw data with identical columns. NOTE: both are modified in
                place (NaN rows dropped, index reset, ``drop_feat`` columns removed).
            label_name: name of the target column.
            label_mapping: ``{label_name: {raw_label: int_label}}``.
            categorical_feat, numerical_feat: feature names; together with ``label_name``
                they must cover every column left after dropping ``drop_feat``.
            sensitive_feat: sensitive feature names; each must also be in ``categorical_feat``.
            categorical_thr: max number of differing non-sensitive categorical features (> 0).
            numerical_thr: max absolute difference of any normalized numerical feature, in [0, 1].
            dataset_name: names the cache directory ``./save/<dataset_name>``.
            drop_feat: columns removed before use.
            remove_dup_in_test: drop test rows whose features duplicate an earlier test row.
        """

        assert isinstance(categorical_thr, int) and categorical_thr > 0, "Invalid categorical_thr for comparability"
        assert numerical_thr >= 0. and numerical_thr <= 1., "Invalid numerical_thr for comparability"
        assert sensitive_feat is not None, "At least one sensitive feature"
        assert all(train_df.columns == test_df.columns), "Column name misalignment between train and test"
        # Check the arguments, not the `categorical_feat` / `numerical_feat` properties:
        # the properties read attributes that are only assigned further below.
        assert not (categorical_feat is None and numerical_feat is None), "Empty in both cat. and num."
        for s in sensitive_feat:
            assert s in train_df.columns, "Sensitive feat. %s not found in training data frame" % s
            assert s in test_df.columns, "Sensitive feat. %s not found in testing data frame" % s

        self._train_df = train_df
        self._test_df = test_df
        self._label_name = label_name
        self._label_mapping = label_mapping
        self._categorical_feat = self._sort_feat_name(categorical_feat, self._train_df.columns)
        self._numerical_feat = self._sort_feat_name(numerical_feat, self._train_df.columns)
        self._sensitive_feat = self._sort_feat_name(sensitive_feat, self._train_df.columns)
        self._categorical_thr = categorical_thr
        self._numerical_thr = numerical_thr
        self._dataset_name = dataset_name
        self._drop_feat = drop_feat

        # NOTE: NaN rows are dropped before `drop_feat` is removed, so NaNs in dropped columns also drop rows
        self._train_df.dropna(inplace=True)
        self._test_df.dropna(inplace=True)
        self._train_df.reset_index(inplace=True, drop=True)
        self._test_df.reset_index(inplace=True, drop=True)
        if self._drop_feat is not None:
            self._train_df.drop(columns=self._drop_feat, inplace=True)
            self._test_df.drop(columns=self._drop_feat, inplace=True)

        all_feat = list(self._categorical_feat) + list(self._numerical_feat)
        all_feat.append(label_name)

        assert sorted(all_feat) == sorted(list(self._train_df.columns)), "Feature misalignment in training dataframe"
        assert sorted(all_feat) == sorted(list(self._test_df.columns)), "Feature misalignment in testing dataframe"

        # Categorical features constrained by the comparability test; list.remove raises
        # ValueError when a sensitive feature is not categorical (unsupported).
        self._cat_feat_wo_sensitive = list(self._categorical_feat)
        for f in self._sensitive_feat:
            self._cat_feat_wo_sensitive.remove(f)

        self.cat_scaler, self.num_scaler = StandardScaler(), MinMaxScaler()
        self.encoder = self._make_one_hot_encoder()

        if remove_dup_in_test:
            features_only = self._test_df.drop(columns=self._label_name)
            duplicated = features_only.duplicated().values
            print("=> %d out of %d are duplicated samples in test data" % (np.sum(duplicated), len(self._test_df)))
            keep_idx = np.flatnonzero(~duplicated)
            self._test_df = self._test_df.iloc[keep_idx, :].reset_index(drop=True)
            print("=> Duplications removed, leaving %d samples in test DataFrame" % len(self._test_df))

    @staticmethod
    def _make_one_hot_encoder() -> OneHotEncoder:
        """ Dense-output one-hot encoder ignoring unseen categories, across scikit-learn versions """
        try:
            # scikit-learn >= 1.2 calls the dense-output switch `sparse_output` (`sparse` was removed in 1.4)
            return OneHotEncoder(sparse_output=False, handle_unknown="ignore")
        except TypeError:
            # older scikit-learn only accepts `sparse`
            return OneHotEncoder(sparse=False, handle_unknown="ignore")

    def _sort_feat_name(self, sub_feat: Sequence[str], all_feat: Sequence[str]) -> Tuple[str]:
        """ Keep the order of feature subsets follow the features in DataFrame """
        sub_feat = [feat for feat in all_feat if feat in sub_feat]
        return tuple(sub_feat)

    def _get_X_and_y(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, np.ndarray]:
        """ Split features and target variable from dataframe; labels are mapped via label_mapping """

        mapping = self._label_mapping[self._label_name]
        y = df[self._label_name].to_numpy().copy()
        for i in range(len(y)):
            y[i] = mapping[y[i]]

        X = df.drop(columns=self._label_name)

        return X, y

    def _comparability(self, df: pd.DataFrame) -> Dict[str, np.ndarray]:
        """
        Determine the comparability of dyadic samples with various sensitive conditions.

        Returns N x N matrices (only the upper triangle i < j can be set):
            comp_sen_and: comparable and every sensitive feature differs
            comp_sen_or:  comparable and at least one sensitive feature differs
            comp_sen_not: comparable and no sensitive feature differs
        """

        print("=> Computing comparability matrix...")

        # normalize numerical features to [0, 1] so that `numerical_thr` is scale-free
        if self._numerical_feat:
            min_max_scaler = MinMaxScaler()
            numerical_scaled = min_max_scaler.fit_transform(df[list(self._numerical_feat)].values)
            numerical_scaled = numerical_scaled.astype(np.float64)
        else:
            # one all-zero column keeps the kernel's shapes valid and never breaks comparability
            numerical_scaled = np.zeros((len(df), 1)).astype(np.float64)

        # label-encode categorical and sensitive features into integer codes
        categorical = df[self._cat_feat_wo_sensitive].values
        sensitive = df[list(self._sensitive_feat)].values
        le = LabelEncoder()
        for col in range(categorical.shape[1]):
            categorical[:, col] = le.fit_transform(categorical[:, col])
        for col in range(sensitive.shape[1]):
            sensitive[:, col] = le.fit_transform(sensitive[:, col])
        # int32, not int8: codes above 127 would wrap around and make distinct categories equal
        categorical = categorical.astype(np.int32)
        sensitive = sensitive.astype(np.int32)

        start = time.time()
        comp_mat, sen_mats = self._comp_func(
            len(df),
            categorical,
            numerical_scaled,
            sensitive,
            self._categorical_thr,
            self._numerical_thr,
        )
        sen_mats = np.asarray(sen_mats, dtype=np.bool_)

        sen_and_mat = np.prod(sen_mats, axis=0, dtype=np.bool_)  # all sensitive features differ
        sen_or_mat = np.sum(sen_mats, axis=0, dtype=np.bool_)  # at least one differs
        sen_not_mat = np.invert(sen_or_mat)  # none differs

        comp_sen_and = np.multiply(comp_mat, sen_and_mat)
        comp_sen_or = np.multiply(comp_mat, sen_or_mat)
        comp_sen_not = np.multiply(comp_mat, sen_not_mat)

        # only unordered pairs i < j are evaluated, so there are N * (N - 1) / 2 candidates
        n_samples = comp_mat.shape[0]
        num_dyadic = n_samples * (n_samples - 1) // 2
        num_comp = np.sum(comp_mat)
        num_comp_sen_and = np.sum(comp_sen_and)
        num_comp_sen_or = np.sum(comp_sen_or)
        num_comp_sen_not = np.sum(comp_sen_not)

        print("Computing time: {0:.2f}s; "
              "Dyadic samples: {1}; "
              "Comp. samples: {2}, {3:.5f}%; \n"
              "Comp. samples with all different sen.: {4}, {5:.5f}%; \n"
              "Comp. samples with at least one different sen.: {6}, {7:.5f}%; \n"
              "Comp. samples with same sen.: {8}, {9:.5f}%; "
              .format(time.time() - start,
                      num_dyadic,
                      num_comp, num_comp * 100 / num_dyadic,
                      num_comp_sen_and, num_comp_sen_and * 100 / num_dyadic,
                      num_comp_sen_or, num_comp_sen_or * 100 / num_dyadic,
                      num_comp_sen_not, num_comp_sen_not * 100 / num_dyadic
                      ))

        return {"comp_sen_and": comp_sen_and, "comp_sen_or": comp_sen_or, "comp_sen_not": comp_sen_not}

    @staticmethod
    @jit(nopython=True, parallel=False)
    def _comp_func(
            N: int,
            categorical: np.ndarray,
            numerical_scaled: np.ndarray,
            sensitive: np.ndarray,
            categorical_thr: int,
            numerical_thr: float,
    ) -> Tuple[np.ndarray, List[np.ndarray]]:
        """ Efficiently compute comparability and sensitive matrix (upper triangle i < j only) """

        comp_mat = np.zeros((N, N), np.bool_)

        N_sen = sensitive.shape[1]
        sen_mats = [np.zeros((N, N), np.bool_) for _ in range(N_sen)]

        for i in prange(N - 1):
            for j in prange(i + 1, N):
                cat_diff = np.sum(categorical[i] != categorical[j])
                num_diff = np.max(np.abs(numerical_scaled[i] - numerical_scaled[j]))
                comp_mat[i, j] = (cat_diff <= categorical_thr) & (num_diff <= numerical_thr)
                for k in range(N_sen):
                    sen_mats[k][i, j] = sensitive[i, k] != sensitive[j, k]

        return comp_mat, sen_mats

    def is_comparable(self, X: np.ndarray, Y: np.ndarray, return_ratio: bool = False) -> List[bool]:
        """
        Row-wise check whether X[i] and Y[i] are comparable and differ in a sensitive feature.

        X and Y are one-hot encoded rows without normalization (as from ``train_data("none")``).
        Returns a boolean array, or, with ``return_ratio``, the fractions of rows passing the
        categorical, numerical and sensitive checks.
        """

        assert X.shape[1] == Y.shape[1] == len(self.feat_idx.cat_idx) + len(self.feat_idx.num_idx)
        N = X.shape[0]

        cat_idx_non_sen = tuple(sorted(set(self.feat_idx.cat_idx) - set(self.feat_idx.sen_idx)))

        # a changed category flips two one-hot columns (one off, one on), hence 2 * categorical_thr
        cat_diff = np.logical_not(np.equal(X[:, cat_idx_non_sen], Y[:, cat_idx_non_sen])).astype(np.int64)
        cat_diff = np.sum(cat_diff, axis=1) <= (self._categorical_thr * 2)

        if self._numerical_feat:
            num_diff = np.amax(np.abs(X[:, self.feat_idx.num_idx] - Y[:, self.feat_idx.num_idx]), axis=1)
            num_diff = num_diff <= self._numerical_thr
        else:
            # NaN is truthy in logical_and (all rows pass) and shows up as NaN in the reported ratio
            num_diff = [np.nan]

        sen_diff = np.logical_not(np.equal(X[:, self.feat_idx.sen_idx], Y[:, self.feat_idx.sen_idx])).astype(np.int64)
        sen_diff = np.sum(sen_diff, axis=1) > 0

        print("=> Checking Comparability. Cat: %.3f; Num: %.3f; Sen: %.3f; All: %.3f" % (
            sum(cat_diff) / N, sum(num_diff) / N, sum(sen_diff) / N,
            sum(np.logical_and(np.logical_and(cat_diff, num_diff), sen_diff)) / N))

        if return_ratio:
            return sum(cat_diff) / N, sum(num_diff) / N, sum(sen_diff) / N
        else:
            return np.logical_and(np.logical_and(cat_diff, num_diff), sen_diff)

    def _one_hot_encoding(self, df: pd.DataFrame, fit=False) -> pd.DataFrame:
        """
        One hot encoding with or without a default label mapping
        Data after one-hot encoding: numerical first then categorical attributes

        With ``fit=True`` the encoder is fitted on ``df`` and the column bookkeeping
        (``_cat_col_name``, ``_feat2idx``, ``_cat_idx``, ``_num_idx``, ``_sen_idx``) is rebuilt.
        Otherwise the already fitted encoder is reused.
        """

        df_cat = df[list(self._categorical_feat)]
        df_num = df[list(self._numerical_feat)]

        if fit:
            df_cat_values = self.encoder.fit_transform(df_cat.values)

            # encoded column names "<feature>.<category>", in encoder order
            self._cat_col_name = ["%s.%s" % (feat_name, feat_val) for feat_name, all_val in
                                  zip(self._categorical_feat, self.encoder.categories_) for feat_val in all_val]

            # feature -> encoded columns: numerical columns first, then one block per categorical feature
            self._feat2idx = OrderedDict({feat: [i] for i, feat in enumerate(self._numerical_feat)})
            curr_start_idx = len(self._numerical_feat)
            for i, feat in enumerate(self._categorical_feat):
                feat_idx = [j + curr_start_idx for j in range(len(self.encoder.categories_[i]))]
                self._feat2idx.update({feat: feat_idx})
                curr_start_idx += len(feat_idx)

            self._cat_idx, self._num_idx, self._sen_idx = [], [], []
            for feat, cols in self._feat2idx.items():
                if feat in self._categorical_feat:
                    self._cat_idx.extend(cols)
                if feat in self._numerical_feat:
                    self._num_idx.extend(cols)
                if feat in self._sensitive_feat:
                    self._sen_idx.extend(cols)

            self._cat_idx = tuple(self._cat_idx)
            self._num_idx = tuple(self._num_idx)
            self._sen_idx = tuple(self._sen_idx)
            self._feat2idx = {key: tuple(list_) for key, list_ in self._feat2idx.items()}
            self._feat2idx = OrderedDict(sorted(self._feat2idx.items(), key=lambda x: x[1][0]))  # order by index
        else:
            assert hasattr(self.encoder, "categories_"), "Call train_data to do one-hot encoding first"
            df_cat_values = self.encoder.transform(df_cat.values)

        # reuse df's index so the concat below aligns rows even for a non-default index
        df_cat = pd.DataFrame(df_cat_values, columns=self._cat_col_name, index=df.index)
        encoded_df = pd.concat([df_num, df_cat], axis=1)

        return encoded_df

    def scale(self, X: np.ndarray, method: str):
        """
        Optionally normalize categorical and/or numerical features.

        method: "all", "cat" (standardize one-hot columns), "num" (min-max numerical columns)
        or "none". Groups without features are left untouched. Returns a scaled copy of X.
        """
        # TODO: add support for all numerical feature

        assert method in ["all", "cat", "num", "none"], "Unknown scale argument"
        X = np.copy(X)

        # the scalers are only fitted (in train_data) for non-empty feature groups
        if method in ("all", "cat") and self._categorical_feat:
            X[:, self.feat_idx.cat_idx] = self.cat_scaler.transform(X[:, self.feat_idx.cat_idx])
        if method in ("all", "num") and self._numerical_feat:
            X[:, self.feat_idx.num_idx] = self.num_scaler.transform(X[:, self.feat_idx.num_idx])

        return X

    def train_data(self, scale: str = "none") -> Tuple[np.ndarray, np.ndarray]:
        """ Encoded training (X, y); fits the one-hot encoder and both scalers """
        X, y = self._get_X_and_y(self._train_df)
        X = self._one_hot_encoding(X, fit=True)
        self._encoded_feature_names = tuple(X.columns)
        X = X.to_numpy()

        X, y = self.validate(X, y)

        if self._categorical_feat:
            self.cat_scaler.fit(X[:, self.feat_idx.cat_idx])
        if self._numerical_feat:
            self.num_scaler.fit(X[:, self.feat_idx.num_idx])
        X = self.scale(X, scale)

        self._feat_dim = X.shape[1]

        return X, y

    def test_data(self, scale: str = "none") -> Tuple[np.ndarray, np.ndarray]:
        """ Encoded test (X, y) using the encoder/scalers fitted by train_data """
        X, y = self._get_X_and_y(self._test_df)
        X = self._one_hot_encoding(X, fit=False)
        X = X.to_numpy()

        X, y = self.validate(X, y)
        X = self.scale(X, scale)

        return X, y

    def comp_data(self, batch_size, train=False, scale="none") -> CompData:
        """
        Load or compute the comparability matrices, and return comparable samples.

        Matrices are cached as ``./save/<dataset_name>/<train|test>_<sensitive...>_<cat_thr>_<num_thr>.npy``
        (relative to the working directory). The directory is created on first use.
        """

        comp_mat_name = "train" if train else "test"
        for sen in self._sensitive_feat:
            comp_mat_name += "_%s" % sen
        comp_mat_name += "_%d_%.3f.npy" % (self._categorical_thr, self._numerical_thr)
        comp_mat_dir = "./save/%s" % self._dataset_name
        comp_mat_path = os.path.join(comp_mat_dir, comp_mat_name)

        if os.path.exists(comp_mat_path):
            print("=> Loading comparability matrix from %s..." % (comp_mat_path))
            # allow_pickle is required because a dict is stored; only load caches written by this code
            comp_mat_dict = np.load(comp_mat_path, allow_pickle=True).item()
        else:
            df = self.train_df if train else self.test_df
            comp_mat_dict = self._comparability(df)
            os.makedirs(comp_mat_dir, exist_ok=True)
            np.save(comp_mat_path, comp_mat_dict)

        X, y = self.train_data(scale) if train else self.test_data(scale)

        return CompData(X, y, comp_mat_dict, batch_size)

    @property
    def name(self):
        return self._dataset_name

    @property
    def train_df(self):
        """ Copy of the cleaned training DataFrame """
        return self._train_df.copy()

    @property
    def test_df(self):
        """ Copy of the cleaned testing DataFrame """
        return self._test_df.copy()

    @property
    def categorical_feat(self) -> Tuple[str]:
        return self._categorical_feat

    @property
    def numerical_feat(self) -> Tuple[str]:
        return self._numerical_feat

    @property
    def sensitive_feat(self) -> Tuple[str]:
        return self._sensitive_feat

    @property
    def encoded_feature_names(self) -> Tuple[str]:
        if not hasattr(self, "_encoded_feature_names"):
            raise AttributeError("Call train_data to do one-hot encoding first")
        return self._encoded_feature_names

    @property
    def feat_dim(self) -> int:
        if not hasattr(self, "_feat_dim"):
            raise AttributeError("Call train_data to do one-hot encoding first")
        return self._feat_dim

    @property
    def feat_idx(self) -> FeatIndex:
        """ Index of features after one-hot encoding """
        if not hasattr(self, "_feat2idx"):
            raise AttributeError("Call train_data to do one-hot encoding first")
        return FeatIndex(self._categorical_feat, self._numerical_feat, self._sensitive_feat,
                         self._feat2idx, self._cat_idx, self._num_idx, self._sen_idx)

    @staticmethod
    def validate(X, y) -> Tuple[np.ndarray, np.ndarray]:
        """ Cast X to float64 and y to int64; require at least two classes and non-negative labels """
        X = X.astype(np.float64)
        y = y.astype(np.int64)

        assert len(np.unique(y)) > 1, "Only one class exists in the dataset"
        # previously `np.any(y) >= 0`, which compares a bool with 0 and can never fail
        assert np.all(y >= 0), "Class label should be an integer larger or equal to zero"

        return X, y


class Adult(Dataset):
    """ https://archive.ics.uci.edu/ml/datasets/adult """

    column_names = (
        'age',
        'workclass',
        'fnlwgt',
        'education',
        'education-num',
        'marital-status',
        'occupation',
        'relationship',
        'race',
        'sex',
        'capital-gain',
        'capital-loss',
        'hours-per-week',
        'native-country',
        'income-per-year',
    )

    categorical_feat = [
        'workclass',
        'education',
        'marital-status',
        'occupation',
        'relationship',
        'race',
        'sex',
        'native-country',
    ]

    numerical_feat = [
        'age',
        'education-num',
        'capital-gain',
        'capital-loss',
        'hours-per-week',
    ]

    drop_feat = ["fnlwgt"]
    label_name = "income-per-year"
    label_mapping = {label_name: {">50K": 1, "<=50K": 0}}
    na_values = ("?")

    def __init__(
            self,
            sensitive_feat: Sequence[str] = ("marital-status",),
            categorical_thr: int = 1,
            numerical_thr: float = 0.025,
    ):
        """
        For two comparable samples
        categorical features differ at most 'categorical_thr'
        normalized numerical features differ at most 'numerical_thr'
        and no constraint on 'sensitive_feat'
        """

        train_df = self._read_csv("./data/adult/adult.data")
        test_df = self._read_csv("./data/adult/adult.test")

        # Test labels end with an extra '.' (e.g. ">50K.") that `label_mapping` does not have.
        # Strip it in one vectorised pass; non-string cells such as NaN pass through unchanged.
        test_df[self.label_name] = test_df[self.label_name].map(self._strip_trailing_period)

        super(Adult, self).__init__(
            train_df=train_df,
            test_df=test_df,
            label_name=self.label_name,
            label_mapping=self.label_mapping,
            categorical_feat=self.categorical_feat,
            numerical_feat=self.numerical_feat,
            sensitive_feat=sensitive_feat,
            categorical_thr=categorical_thr,
            numerical_thr=numerical_thr,
            dataset_name="adult",
            drop_feat=self.drop_feat,
        )

    def _read_csv(self, rel_path: str) -> pd.DataFrame:
        """ Read one headerless Adult split located relative to sys.path[1] """
        return pd.read_csv(
            os.path.join(sys.path[1], rel_path),
            header=None,
            names=self.column_names,
            skipinitialspace=True,
            na_values=self.na_values,
        )

    @staticmethod
    def _strip_trailing_period(value):
        """ Remove one trailing '.' from a string label; any other value is returned unchanged """
        if isinstance(value, str) and value.endswith("."):
            return value[:-1]
        return value


class German(Dataset):
    """ https://archive.ics.uci.edu/ml/datasets/Statlog+%28German+Credit+Data%29 """

    column_names = (
        "status",
        "month",
        "credit_history",
        "purpose",
        "credit_amount",
        "savings",
        "employment",
        "investment_as_income_percentage",
        "personal_status",
        "other_debtors",
        "residence_since",
        "property",
        "age",
        "installment_plans",
        "housing",
        "number_of_credits",
        "skill_level",
        "people_liable_for",
        "telephone",
        "foreign_worker",
        "credit",
    )

    # "sex" and "marital-status" are derived from "personal_status" in default_preprocessing
    categorical_feat = [
        "status",
        "credit_history",
        "purpose",
        "savings",
        "employment",
        "other_debtors",
        "property",
        "installment_plans",
        "housing",
        "skill_level",
        "telephone",
        "foreign_worker",
        "sex",
        "marital-status",
    ]

    numerical_feat = [
        "month",
        "credit_amount",
        "investment_as_income_percentage",
        "residence_since",
        "age",
        "number_of_credits",
        "people_liable_for",
    ]

    drop_feat = ["personal_status"]
    label_name = "credit"
    label_mapping = {label_name: {1: 1, 2: 0}}
    na_values = ()

    def __init__(
            self,
            sensitive_feat: Sequence[str] = ("sex", "marital-status", "foreign_worker"),
            categorical_thr: int = 3,
            numerical_thr: float = 0.2,
            idx_path: str = "./save/german/idx.json",
            split_p: float = 0.75,
    ):
        """
        Load german.data, derive the sex / marital-status columns, then split it with the
        shuffle persisted at ``idx_path`` (``split_p`` of the rows go to training).
        """
        data_file = os.path.join(sys.path[1], "./data/german/german.data")
        raw_df = pd.read_csv(data_file, sep=" ", names=self.column_names)
        prepared_df = self.default_preprocessing(raw_df)
        train_df, test_df = gen_or_load_shuffle_idx_and_split(prepared_df, idx_path, split_p)

        super().__init__(
            train_df=train_df,
            test_df=test_df,
            label_name=self.label_name,
            label_mapping=self.label_mapping,
            categorical_feat=self.categorical_feat,
            numerical_feat=self.numerical_feat,
            sensitive_feat=sensitive_feat,
            categorical_thr=categorical_thr,
            numerical_thr=numerical_thr,
            dataset_name="german",
            drop_feat=self.drop_feat,
        )

    @staticmethod
    def default_preprocessing(df):
        """
        Add 'sex' and 'marital-status' columns derived from the personal_status codes.
        https://github.com/Trusted-AI/AIF360/blob/master/aif360/datasets/german_dataset.py

        The columns are added to ``df`` in place (sex first), and ``df`` is returned.
        Codes missing from a mapping are copied unchanged.
        """
        derived_columns = {
            'sex': {'A91': 'male', 'A93': 'male', 'A94': 'male', 'A92': 'female', 'A95': 'female'},
            "marital-status": {'A91': 'married', 'A92': 'married', 'A93': 'single', 'A94': 'married',
                               'A95': 'single'},
        }
        for column, code_map in derived_columns.items():
            df[column] = df['personal_status'].replace(code_map)

        return df


class Compas(Dataset):
    """ https://github.com/propublica/compas-analysis """

    features_to_keep = (
        "sex",
        "age",
        "age_cat",
        "race",
        "juv_fel_count",
        "juv_misd_count",
        "juv_other_count",
        "priors_count",
        "c_charge_degree",
        "c_charge_desc",
        "two_year_recid",
    )

    categorical_feat = [
        "sex",
        "age_cat",
        "race",
        "c_charge_degree",
        "c_charge_desc",
    ]

    numerical_feat = [
        "age",
        "juv_fel_count",
        "juv_misd_count",
        "juv_other_count",
        "priors_count",
    ]

    drop_feat = []
    label_name = "two_year_recid"
    label_mapping = {label_name: {"No recid.": 1, "Did recid.": 0}}
    na_values = ()

    def __init__(
            self,
            sensitive_feat: Sequence[str] = ("race", "sex",),
            categorical_thr: int = 1,
            numerical_thr: float = 0.025,
            idx_path: str = "./save/compas/idx.json",
            split_p: float = 0.75,
    ):
        """ Load the ProPublica two-year data, filter/relabel it, and split with a persisted shuffle """
        df = pd.read_csv(os.path.join(sys.path[1], "./data/compas/compas-scores-two-years.csv"), index_col='id')
        df = self.default_preprocessing(df)
        df = df[list(self.features_to_keep)]
        train_df, test_df = gen_or_load_shuffle_idx_and_split(df, idx_path, split_p)

        super(Compas, self).__init__(
            train_df=train_df,
            test_df=test_df,
            label_name=self.label_name,
            label_mapping=self.label_mapping,
            categorical_feat=self.categorical_feat,
            numerical_feat=self.numerical_feat,
            sensitive_feat=sensitive_feat,
            categorical_thr=categorical_thr,
            numerical_thr=numerical_thr,
            dataset_name="compas",
            drop_feat=self.drop_feat,
        )

    @staticmethod
    def default_preprocessing(df):
        """
        Perform the same preprocessing as the original analysis:
        https://github.com/propublica/compas-analysis/blob/master/Compas%20Analysis.ipynb

        'race' keeps its original categories (no Caucasian / Not Caucasian binarization).
        'two_year_recid' is overwritten in ``df`` with the string labels used by
        ``label_mapping``. The filtered rows are returned.
        """

        # vectorised relabelling (was a row-wise DataFrame.apply)
        df['two_year_recid'] = np.where(df['two_year_recid'] == 1, 'Did recid.', 'No recid.')

        return df[(df.days_b_screening_arrest <= 30)
                  & (df.days_b_screening_arrest >= -30)
                  & (df.is_recid != -1)
                  & (df.c_charge_degree != 'O')
                  & (df.score_text != 'N/A')]


class Bank(Dataset):
    """ https://archive.ics.uci.edu/ml/datasets/bank+marketing """

    categorical_feat = [
        "age",
        "job",
        "marital",
        "education",
        "default",
        "housing",
        "loan",
        "contact",
        "month",
        "day_of_week",
        "poutcome",
    ]

    numerical_feat = [
        "duration",
        "campaign",
        "pdays",
        "previous",
        "emp.var.rate",
        "cons.price.idx",
        "cons.conf.idx",
        "euribor3m",
        "nr.employed"
    ]

    drop_feat = []
    label_name = "y"
    label_mapping = {label_name: {"yes": 1, "no": 0}}
    na_values = ("unknown")

    def __init__(
            self,
            sensitive_feat: Sequence[str] = ("age",),
            categorical_thr: int = 1,
            numerical_thr: float = 0.025,
            idx_path: str = "./save/bank/idx.json",
            split_p: float = 0.75,
    ):
        """ Load bank-additional-full, binarize age, and split with a persisted shuffle """
        df = pd.read_csv(
            os.path.join(sys.path[1], "./data/bank-additional/bank-additional-full.csv"),
            sep=";",
            na_values=self.na_values,
        )

        # Convert age into a binary feature: 1 if age >= 30, otherwise -1.
        # Assign the column directly: `df["age"].where(..., inplace=True)` is chained
        # assignment, which pandas Copy-on-Write no longer propagates back to `df`.
        df["age"] = np.where(df["age"] >= 30, 1, -1)

        train_df, test_df = gen_or_load_shuffle_idx_and_split(df, idx_path, split_p)

        super(Bank, self).__init__(
            train_df=train_df,
            test_df=test_df,
            label_name=self.label_name,
            label_mapping=self.label_mapping,
            categorical_feat=self.categorical_feat,
            numerical_feat=self.numerical_feat,
            sensitive_feat=sensitive_feat,
            categorical_thr=categorical_thr,
            numerical_thr=numerical_thr,
            dataset_name="bank",
            drop_feat=self.drop_feat,
        )


class MEPS(Dataset):
    """ https://github.com/Trusted-AI/AIF360/tree/master/aif360/datasets

    Base class for the MEPS panels. Subclasses read the raw CSV and pass it in
    together with the panel number and the two-digit fiscal year.
    """

    features_to_keep = [
        'REGION', 'AGE', 'SEX', 'RACE', 'MARRY', 'FTSTU', 'ACTDTY', 'HONRDC', 'RTHLTH', 'MNHLTH', 'HIBPDX', 'CHDDX',
        'ANGIDX', 'MIDX', 'OHRTDX', 'STRKDX', 'EMPHDX', 'CHBRON', 'CHOLDX', 'CANCERDX', 'DIABDX', 'JTPAIN', 'ARTHDX',
        'ARTHTYPE', 'ASTHDX', 'ADHDADDX', 'PREGNT', 'WLKLIM', 'ACTLIM', 'SOCLIM', 'COGLIM', 'DFHEAR42', 'DFSEE42',
        'ADSMOK42', 'PCS42', 'MCS42', 'K6SUM42', 'PHQ242', 'EMPST', 'POVCAT', 'INSCOV', 'UTILIZATION',
    ]

    categorical_feat = [
        'REGION', 'SEX', 'RACE', 'MARRY', 'FTSTU', 'ACTDTY', 'HONRDC', 'RTHLTH', 'MNHLTH', 'HIBPDX', 'CHDDX',
        'ANGIDX', 'MIDX', 'OHRTDX', 'STRKDX', 'EMPHDX', 'CHBRON', 'CHOLDX', 'CANCERDX', 'DIABDX', 'JTPAIN',
        'ARTHDX', 'ARTHTYPE', 'ASTHDX', 'ADHDADDX', 'PREGNT', 'WLKLIM', 'ACTLIM', 'SOCLIM', 'COGLIM', 'DFHEAR42',
        'DFSEE42', 'ADSMOK42', 'PHQ242', 'EMPST', 'POVCAT', 'INSCOV',
    ]

    numerical_feat = [
        'AGE', 'PCS42', 'MCS42', 'K6SUM42'
    ]

    drop_feat = []
    label_name = "UTILIZATION"
    label_mapping = {label_name: {1: 1, 0: 0}}

    def __init__(
            self,
            panel: int,
            fy: str,
            df: pd.DataFrame,
            sensitive_feat: Sequence[str],
            categorical_thr: int,
            numerical_thr: float,
            idx_path: str,
            split_p: float = 0.75,
    ):
        """ Preprocess a raw MEPS frame and split it into train/test.

        :param panel: MEPS panel number; one of 19, 20, 21
        :param fy: two-digit fiscal year suffix of the year-specific columns; "15" or "16"
        :param df: raw MEPS data frame (read by the subclass)
        :param sensitive_feat: columns treated as sensitive attributes
        :param categorical_thr: forwarded to Dataset.__init__
        :param numerical_thr: forwarded to Dataset.__init__
        :param idx_path: shuffle-index JSON path, resolved against sys.path[1]; created if missing
        :param split_p: fraction of the shuffled rows used for training
        :raises ValueError: if panel or fy is not supported
        """
        if panel not in (19, 20, 21):
            raise ValueError(f"Unsupported MEPS panel: {panel!r}")
        if fy not in ("15", "16"):
            raise ValueError(f"Unsupported MEPS fiscal year: {fy!r}")
        self.panel = panel
        self.fy = fy

        # The person-weight column name depends on the fiscal year. Build
        # per-instance lists instead of appending to the shared class-level
        # lists, which would accumulate columns across instantiations.
        weight_col = 'PERWT' + self.fy + 'F'
        self.features_to_keep = list(type(self).features_to_keep) + [weight_col]
        self.numerical_feat = list(type(self).numerical_feat) + [weight_col]

        df = self.default_preprocessing(df)
        df = df[self.features_to_keep]

        # gen_or_load_shuffle_idx_and_split resolves idx_path against sys.path[1] itself.
        train_df, test_df = gen_or_load_shuffle_idx_and_split(df, idx_path, split_p)

        super(MEPS, self).__init__(
            train_df=train_df,
            test_df=test_df,
            label_name=self.label_name,
            label_mapping=self.label_mapping,
            categorical_feat=self.categorical_feat,
            numerical_feat=self.numerical_feat,
            sensitive_feat=sensitive_feat,
            categorical_thr=categorical_thr,
            numerical_thr=numerical_thr,
            dataset_name="meps" + str(panel),
            drop_feat=self.drop_feat,
        )

    def default_preprocessing(self, df):
        """ Clean a raw MEPS frame for this panel/fiscal year.

        The steps, in order:

        * Derive a RACE column.
        * Keep only rows of ``self.panel``.
        * Strip the year/round suffixes from column names.
        * Drop rows carrying negative codes.
        * Build the binary UTILIZATION label.

        The input frame is not modified; a new frame is returned.
        """
        # Non-Hispanic Whites are marked as 'White', all others as 'Non-White'.
        is_white = (df['HISPANX'] == 2) & (df['RACEV2X'] == 1)
        df = df.assign(RACEV2X=np.where(is_white, 'White', 'Non-White'))
        df = df.rename(columns={'RACEV2X': 'RACE'})

        df = df[df['PANEL'] == self.panel]

        df = df.rename(columns={'FTSTU53X': 'FTSTU', 'ACTDTY53': 'ACTDTY', 'HONRDC53': 'HONRDC', 'RTHLTH53': 'RTHLTH',
                                'MNHLTH53': 'MNHLTH', 'CHBRON53': 'CHBRON', 'JTPAIN53': 'JTPAIN', 'PREGNT53': 'PREGNT',
                                'WLKLIM53': 'WLKLIM', 'ACTLIM53': 'ACTLIM', 'SOCLIM53': 'SOCLIM', 'COGLIM53': 'COGLIM',
                                'EMPST53': 'EMPST', 'REGION53': 'REGION', 'MARRY53X': 'MARRY', 'AGE53X': 'AGE',
                                'POVCAT' + self.fy: 'POVCAT', 'INSCOV' + self.fy: 'INSCOV'})

        df = df[df['REGION'] >= 0]  # remove values -1
        df = df[df['AGE'] >= 0]  # remove values -1
        df = df[df['MARRY'] >= 0]  # remove values -1, -7, -8, -9
        df = df[df['ASTHDX'] >= 0]  # remove values -1, -7, -8, -9

        # for all other categorical features, remove values < -1
        df = df[(df[['FTSTU', 'ACTDTY', 'HONRDC', 'RTHLTH', 'MNHLTH', 'HIBPDX', 'CHDDX', 'ANGIDX', 'EDUCYR', 'HIDEG',
                     'MIDX', 'OHRTDX', 'STRKDX', 'EMPHDX', 'CHBRON', 'CHOLDX', 'CANCERDX', 'DIABDX',
                     'JTPAIN', 'ARTHDX', 'ARTHTYPE', 'ASTHDX', 'ADHDADDX', 'PREGNT', 'WLKLIM',
                     'ACTLIM', 'SOCLIM', 'COGLIM', 'DFHEAR42', 'DFSEE42', 'ADSMOK42',
                     'PHQ242', 'EMPST', 'POVCAT', 'INSCOV']] >= -1).all(axis=1)]

        # Utilization = sum of the five fiscal-year utilization columns,
        # binarized at 10 (< 10 -> 0.0, >= 10 -> 1.0; NaN sums stay NaN).
        fy = self.fy
        exp_col = 'TOTEXP' + fy
        total = (df['OBTOTV' + fy] + df['OPTOTV' + fy] + df['ERTOT' + fy]
                 + df['IPNGTD' + fy] + df['HHTOTD' + fy])
        utilization = total.astype(float)
        utilization[total < 10.0] = 0.0
        utilization[total >= 10.0] = 1.0

        # Overwrite the TOTEXP column with the label, then rename it.
        df = df.assign(**{exp_col: utilization})
        df = df.rename(columns={exp_col: 'UTILIZATION'})

        return df


class MEPS19(MEPS):
    """ panel 19 fy 2015 """

    def __init__(
            self,
            sensitive_feat: Sequence[str] = ("RACE",),
            categorical_thr: int = 3,
            numerical_thr: float = 0.05,
            idx_path: str = "./data/meps/meps19_idx.json",
            split_p: float = 0.75,
    ):
        """ Load MEPS panel 19 (fiscal year 2015) from h181.csv.

        :param sensitive_feat: columns treated as sensitive attributes
        :param categorical_thr: forwarded to MEPS.__init__
        :param numerical_thr: forwarded to MEPS.__init__
        :param idx_path: shuffle-index JSON path, resolved against sys.path[1]; created if missing
        :param split_p: fraction of the shuffled rows used for training
        """
        df = pd.read_csv(os.path.join(sys.path[1], "./data/meps/h181.csv"), sep=",")

        super(MEPS19, self).__init__(
            panel=19,
            fy="15",
            df=df,
            sensitive_feat=sensitive_feat,
            categorical_thr=categorical_thr,
            numerical_thr=numerical_thr,
            idx_path=idx_path,
            split_p=split_p,
        )


class MEPS20(MEPS):
    """ panel 20 fy 2015 """

    def __init__(
            self,
            sensitive_feat: Sequence[str] = ("RACE",),
            categorical_thr: int = 3,
            numerical_thr: float = 0.05,
            idx_path: str = "./data/meps/meps20_idx.json",
            split_p: float = 0.75,
    ):
        """ Load MEPS panel 20 (fiscal year 2015) from h181.csv.

        :param sensitive_feat: columns treated as sensitive attributes
        :param categorical_thr: forwarded to MEPS.__init__
        :param numerical_thr: forwarded to MEPS.__init__
        :param idx_path: shuffle-index JSON path, resolved against sys.path[1]; created if missing
        :param split_p: fraction of the shuffled rows used for training
        """
        df = pd.read_csv(os.path.join(sys.path[1], "./data/meps/h181.csv"), sep=",")

        super(MEPS20, self).__init__(
            panel=20,
            fy="15",
            df=df,
            sensitive_feat=sensitive_feat,
            categorical_thr=categorical_thr,
            numerical_thr=numerical_thr,
            idx_path=idx_path,
            split_p=split_p,
        )


class MEPS21(MEPS):
    """ panel 21 fy 2016 """

    def __init__(
            self,
            sensitive_feat: Sequence[str] = ("RACE", "SEX",),
            categorical_thr: int = 5,
            numerical_thr: float = 0.05,
            idx_path: str = "./data/meps/meps21_idx.json",
            split_p: float = 0.75,
    ):
        """ Load MEPS panel 21 (fiscal year 2016) from h192.csv.

        :param sensitive_feat: columns treated as sensitive attributes
        :param categorical_thr: forwarded to MEPS.__init__
        :param numerical_thr: forwarded to MEPS.__init__
        :param idx_path: shuffle-index JSON path, resolved against sys.path[1]; created if missing
        :param split_p: fraction of the shuffled rows used for training
        """
        df = pd.read_csv(os.path.join(sys.path[1], "./data/meps/h192.csv"), sep=",")

        super(MEPS21, self).__init__(
            panel=21,
            fy="16",
            df=df,
            sensitive_feat=sensitive_feat,
            categorical_thr=categorical_thr,
            numerical_thr=numerical_thr,
            idx_path=idx_path,
            split_p=split_p,
        )


class Credit(Dataset):
    """ https://archive.ics.uci.edu/ml/datasets/default+of+credit+card+clients """

    categorical_feat = [
        "SEX",
        "EDUCATION",
        "MARRIAGE",
        "PAY_0",
        "PAY_2",
        "PAY_3",
        "PAY_4",
        "PAY_5",
        "PAY_6",
    ]

    numerical_feat = [
        "LIMIT_BAL",
        "AGE",
        "BILL_AMT1",
        "BILL_AMT2",
        "BILL_AMT3",
        "BILL_AMT4",
        "BILL_AMT5",
        "BILL_AMT6",
        "PAY_AMT1",
        "PAY_AMT2",
        "PAY_AMT3",
        "PAY_AMT4",
        "PAY_AMT5",
        "PAY_AMT6",
    ]

    drop_feat = []
    label_name = "default payment next month"
    label_mapping = {label_name: {1: 1, 0: 0}}

    def __init__(
            self,
            sensitive_feat: Sequence[str] = ("SEX",),
            categorical_thr: int = 1,
            numerical_thr: float = 0.025,
            idx_path: str = "./save/credit/idx.json",
            split_p: float = 0.75,
    ):
        """ Load the credit-card default spreadsheet and split into train/test.

        :param sensitive_feat: columns treated as sensitive attributes
        :param categorical_thr: forwarded to Dataset.__init__
        :param numerical_thr: forwarded to Dataset.__init__
        :param idx_path: shuffle-index JSON path, resolved against sys.path[1]; created if missing
        :param split_p: fraction of the shuffled rows used for training
        """
        # Resolve the data file against sys.path[1], like the other loaders,
        # so loading does not depend on the current working directory.
        # header=1 skips the spreadsheet's first row; the first column becomes the index.
        df = pd.read_excel(
            os.path.join(sys.path[1], "./data/credit/default of credit card clients.xls"),
            header=1,
            index_col=0,
        )
        # gen_or_load_shuffle_idx_and_split resolves idx_path against sys.path[1] itself.
        train_df, test_df = gen_or_load_shuffle_idx_and_split(df, idx_path, split_p)

        super(Credit, self).__init__(
            train_df=train_df,
            test_df=test_df,
            label_name=self.label_name,
            label_mapping=self.label_mapping,
            categorical_feat=self.categorical_feat,
            numerical_feat=self.numerical_feat,
            sensitive_feat=sensitive_feat,
            categorical_thr=categorical_thr,
            numerical_thr=numerical_thr,
            dataset_name="credit",
            drop_feat=self.drop_feat,
        )


class Dutch(Dataset):
    """ https://github.com/tailequy/fairness_dataset/blob/main/experiments/data/dutch.csv """

    categorical_feat = [
        "sex",
        "age",
        "household_position",
        "household_size",
        "prev_residence_place",
        "citizenship",
        "country_birth",
        "edu_level",
        "economic_status",
        "cur_eco_activity",
        "marital_status",
    ]

    # All Dutch features are treated as categorical.
    numerical_feat = []

    drop_feat = []
    label_name = "occupation"
    label_mapping = {label_name: {1: 1, 0: 0}}

    def __init__(
            self,
            sensitive_feat: Sequence[str] = ("sex",),
            categorical_thr: int = 1,
            numerical_thr: float = 0.025,
            idx_path: str = "./save/dutch/idx.json",
            split_p: float = 0.75,
    ):
        """ Load dutch.csv and split into train/test.

        :param sensitive_feat: columns treated as sensitive attributes
        :param categorical_thr: forwarded to Dataset.__init__
        :param numerical_thr: forwarded to Dataset.__init__
        :param idx_path: shuffle-index JSON path, resolved against sys.path[1]; created if missing
        :param split_p: fraction of the shuffled rows used for training
        """
        # Resolve the data file against sys.path[1], like the other loaders,
        # so loading does not depend on the current working directory.
        df = pd.read_csv(os.path.join(sys.path[1], "./data/dutch/dutch.csv"))
        # gen_or_load_shuffle_idx_and_split resolves idx_path against sys.path[1] itself.
        train_df, test_df = gen_or_load_shuffle_idx_and_split(df, idx_path, split_p)

        super(Dutch, self).__init__(
            train_df=train_df,
            test_df=test_df,
            label_name=self.label_name,
            label_mapping=self.label_mapping,
            categorical_feat=self.categorical_feat,
            numerical_feat=self.numerical_feat,
            sensitive_feat=sensitive_feat,
            categorical_thr=categorical_thr,
            numerical_thr=numerical_thr,
            dataset_name="dutch",
            drop_feat=self.drop_feat,
        )


class LawSchool(Dataset):
    """ https://eric.ed.gov/?id=ED469370 """

    categorical_feat = [
        "fulltime",
        "fam_inc",
        "male",
        "tier",
        "race",
    ]

    numerical_feat = [
        "decile1b",
        "decile3",
        "lsat",
        "ugpa",
        "zfygpa",
        "zgpa",
    ]

    drop_feat = []
    label_name = "pass_bar"
    label_mapping = {label_name: {1: 1, 0: 0}}

    def __init__(
            self,
            sensitive_feat: Sequence[str] = ("race",),
            categorical_thr: int = 1,
            numerical_thr: float = 0.1,
            idx_path: str = "./save/law_school/idx.json",
            split_p: float = 0.75,
    ):
        """ Load law_school.csv and split into train/test.

        :param sensitive_feat: columns treated as sensitive attributes
        :param categorical_thr: forwarded to Dataset.__init__
        :param numerical_thr: forwarded to Dataset.__init__
        :param idx_path: shuffle-index JSON path, resolved against sys.path[1]; created if missing
        :param split_p: fraction of the shuffled rows used for training
        """
        # Resolve the data file against sys.path[1], like the other loaders,
        # so loading does not depend on the current working directory.
        df = pd.read_csv(os.path.join(sys.path[1], "./data/law_school/law_school.csv"), sep=",")
        # gen_or_load_shuffle_idx_and_split resolves idx_path against sys.path[1] itself.
        train_df, test_df = gen_or_load_shuffle_idx_and_split(df, idx_path, split_p)

        super(LawSchool, self).__init__(
            train_df=train_df,
            test_df=test_df,
            label_name=self.label_name,
            label_mapping=self.label_mapping,
            categorical_feat=self.categorical_feat,
            numerical_feat=self.numerical_feat,
            sensitive_feat=sensitive_feat,
            categorical_thr=categorical_thr,
            numerical_thr=numerical_thr,
            dataset_name="law_school",
            drop_feat=self.drop_feat,
        )


class Diabetes(Dataset):
    """ https://archive.ics.uci.edu/ml/datasets/diabetes+130-us+hospitals+for+years+1999-2008 """

    categorical_feat = [
        'race',
        'gender',
        'age',
        'admission_type_id',
        'discharge_disposition_id',
        'admission_source_id',
        'diag_1',
        'diag_2',
        'diag_3',
        'max_glu_serum',
        'A1Cresult',
        'metformin',
        'repaglinide',
        'nateglinide',
        'chlorpropamide',
        'glimepiride',
        'acetohexamide',
        'glipizide',
        'glyburide',
        'tolbutamide',
        'pioglitazone',
        'rosiglitazone',
        'acarbose',
        'miglitol',
        'troglitazone',
        'tolazamide',
        'examide',
        'citoglipton',
        'insulin',
        'glyburide-metformin',
        'glipizide-metformin',
        'glimepiride-pioglitazone',
        'metformin-rosiglitazone',
        'metformin-pioglitazone',
        'change',
        'diabetesMed',
    ]

    numerical_feat = [
        'time_in_hospital',
        'num_lab_procedures',
        'num_procedures',
        'num_medications',
        'number_outpatient',
        'number_emergency',
        'number_inpatient',
        'number_diagnoses'
    ]

    drop_feat = ["encounter_id", "patient_nbr", "weight", "payer_code", "medical_specialty"]
    label_name = "readmitted"
    label_mapping = {label_name: {"<30": 1, ">30": 0}}
    # Missing-value marker in the raw CSV; compared cell-wise via DataFrame.eq
    # (must stay a plain string, not a tuple).
    na_values = ("?")

    def __init__(
            self,
            sensitive_feat: Sequence[str] = ("age",),
            categorical_thr: int = 6,
            numerical_thr: float = 0.075,
            idx_path: str = "./save/diabetes/idx.json",
            split_p: float = 0.75,
    ):
        """ Delete drop_feat first, then remove rows with nan values

        :param sensitive_feat: columns treated as sensitive attributes
        :param categorical_thr: forwarded to Dataset.__init__
        :param numerical_thr: forwarded to Dataset.__init__
        :param idx_path: shuffle-index JSON path, resolved against sys.path[1]; created if missing
        :param split_p: fraction of the shuffled rows used for training
        """
        # Resolve the data file against sys.path[1], like the other loaders,
        # so loading does not depend on the current working directory.
        df = pd.read_csv(os.path.join(sys.path[1], "./data/diabetes/diabetic_data.csv"), sep=",")

        # Drop unwanted columns first so that missing markers in those columns
        # do not cause rows to be discarded by the filter below.
        df = df.loc[:, ~df.columns.isin(self.drop_feat)]
        # label_mapping only covers "<30" and ">30"; drop the remaining "NO" class.
        df = df.loc[df[self.label_name] != 'NO']
        # Remove every row in which any cell equals the missing-value marker.
        df = df[~df.eq(self.na_values).any(axis=1)]
        df = df.reset_index(drop=True)

        # gen_or_load_shuffle_idx_and_split resolves idx_path against sys.path[1] itself.
        train_df, test_df = gen_or_load_shuffle_idx_and_split(df, idx_path, split_p)

        super(Diabetes, self).__init__(
            train_df=train_df,
            test_df=test_df,
            label_name=self.label_name,
            label_mapping=self.label_mapping,
            categorical_feat=self.categorical_feat,
            numerical_feat=self.numerical_feat,
            sensitive_feat=sensitive_feat,
            categorical_thr=categorical_thr,
            numerical_thr=numerical_thr,
            dataset_name="diabetes",
            # drop_feat columns were already removed above.
            drop_feat=[],
        )


class Oulad(Dataset):
    """ https://analyse.kmi.open.ac.uk/open_dataset """

    categorical_feat = [
        "code_module",
        "code_presentation",
        "gender",
        "region",
        "highest_education",
        "imd_band",
        "age_band",
        "disability",
    ]

    numerical_feat = [
        "num_of_prev_attempts",
        "studied_credits",
    ]

    drop_feat = ["id_student"]
    label_name = "final_result"
    label_mapping = {label_name: {"Pass": 1, "Fail": 0}}
    # NOTE(review): not passed to pd.read_csv below; confirm whether it should be.
    na_values = (" ")

    def __init__(
            self,
            sensitive_feat: Sequence[str] = ("age_band",),
            categorical_thr: int = 1,
            numerical_thr: float = 0.025,
            idx_path: str = "./save/oulad/idx.json",
            split_p: float = 0.75,
    ):
        """ Load OULAD studentInfo.csv, keep Pass/Fail outcomes and split into train/test.

        :param sensitive_feat: columns treated as sensitive attributes
        :param categorical_thr: forwarded to Dataset.__init__
        :param numerical_thr: forwarded to Dataset.__init__
        :param idx_path: shuffle-index JSON path, resolved against sys.path[1]; created if missing
        :param split_p: fraction of the shuffled rows used for training
        """
        # Resolve the data file against sys.path[1], like the other loaders,
        # so loading does not depend on the current working directory.
        df = pd.read_csv(os.path.join(sys.path[1], "./data/oulad/anonymisedData/studentInfo.csv"))
        df = self.default_preprocessing(df)
        train_df, test_df = gen_or_load_shuffle_idx_and_split(df, idx_path, split_p)

        super(Oulad, self).__init__(
            train_df=train_df,
            test_df=test_df,
            label_name=self.label_name,
            label_mapping=self.label_mapping,
            categorical_feat=self.categorical_feat,
            numerical_feat=self.numerical_feat,
            sensitive_feat=sensitive_feat,
            categorical_thr=categorical_thr,
            numerical_thr=numerical_thr,
            dataset_name="oulad",
            drop_feat=self.drop_feat,
        )

    @staticmethod
    def default_preprocessing(df):
        """ Merge "Distinction" into "Pass" and keep only "Pass"/"Fail" rows.

        Returns a new frame with a fresh 0..N-1 index; the input is not modified.
        """
        final_result = df['final_result'].replace("Distinction", "Pass")
        keep = final_result.isin(["Pass", "Fail"])
        return df.assign(final_result=final_result)[keep].reset_index(drop=True)


def fetch_dataset(name: str):
    """ Construct a dataset by its short name, using that class's default arguments.

    :param name: one of "adult", "compas", "german", "bank", "meps19", "meps20",
        "meps21", "credit", "dutch", "law_school", "diabetes", "oulad"
    :return: the constructed dataset instance
    :raises ValueError: if ``name`` is not one of the supported names
    """
    if name == "adult":
        return Adult()
    elif name == "compas":
        return Compas()
    elif name == "german":
        return German()
    elif name == "bank":
        return Bank()
    elif name == "meps19":
        return MEPS19()
    elif name == "meps20":
        return MEPS20()
    elif name == "meps21":
        return MEPS21()
    elif name == "credit":
        return Credit()
    elif name == "dutch":
        return Dutch()
    elif name == "law_school":
        return LawSchool()
    elif name == "diabetes":
        return Diabetes()
    elif name == "oulad":
        return Oulad()

    raise ValueError(
        f"Unknown dataset name {name!r}; expected one of: adult, compas, german, bank, "
        "meps19, meps20, meps21, credit, dutch, law_school, diabetes, oulad"
    )


if __name__ == "__main__":
    # Smoke run: build the Dutch dataset, then report the split shapes and the
    # number of comparable pairs per label in each split.
    dataset = Dutch()
    train_X, train_y = dataset.train_data(scale="all")
    test_X, test_y = dataset.test_data(scale="all")

    feat_idx = dataset.feat_idx

    train_comp_data, test_comp_data = (
        dataset.comp_data(batch_size=1024, train=is_train, scale="cat")
        for is_train in (True, False)
    )

    print(train_X.shape, test_X.shape)

    report = (
        (train_comp_data,
         "Comparable samples with positive labels in training",
         "Comparable samples with negative labels in training"),
        (test_comp_data,
         "Comparable samples with positive labels in testing",
         "Comparable samples with negative labels in testing"),
    )
    for comp, positive_msg, negative_msg in report:
        print(positive_msg, len(comp.loaders[1].cond_idx[0]))
        print(negative_msg, len(comp.loaders[1].cond_idx[1]))
