import os
import csv
import random
import numpy as np
import pandas as pd
import math
import torch

from sklearn.preprocessing import OneHotEncoder
from sklearn.decomposition import PCA
from sklearn.preprocessing import normalize
from torch.utils.data import Dataset


class CustomizedDataset(Dataset):
    """Map-style dataset of features, labels and two sensitive attributes.

    ``attribute_dict`` must provide the keys ``'s1'``, ``'s2'``, ``'X'`` and ``'y'``.
    The features ``X`` are stored as a ``torch.Tensor`` without gradient tracking;
    ``s1``, ``s2`` and ``y`` are stored as NumPy arrays. Indexing returns a dict with
    the keys ``"s1"``, ``"s2"``, ``"X"`` and ``"y"``.
    """

    def __init__(self, attribute_dict):
        # Sensitive attributes stay NumPy arrays; features become a (non-trainable) tensor.
        self.s1 = np.array(attribute_dict['s1'])
        self.s2 = np.array(attribute_dict['s2'])
        self.X = torch.tensor(attribute_dict['X'], requires_grad=False)
        self.y = np.array(attribute_dict['y'])

    def __len__(self):
        # The sample count is the length of the feature tensor's first dimension.
        return len(self.X)

    def __getitem__(self, idx):
        # Gather the idx-th entry of every stored attribute into one sample dict.
        return dict(s1=self.s1[idx], s2=self.s2[idx], X=self.X[idx], y=self.y[idx])


# Some code is borrowed from https://github.com/optimization-for-data-driven-science/Renyi-Fair-Inference
def get_ADULT_dataset(data_path, mask_s1_flag, mask_s2_flag, mask_s1_s2_flag, use_csv_file=True, do_pca=False,
                      pca_dimension=64):
    """Load the UCI Adult data as train / s1-split train / test datasets.

    Sensitive attributes: ``s1`` is sex (1 if column 9 contains "Male", else 0) and
    ``s2`` is race (1 if column 8 contains "White", else 0). The label ``y`` is 1 when
    column 14 contains ``'>50K'``, else 0. The first line of ``adult.data`` and
    ``adult.test`` is skipped as a header; blank lines are skipped as well.

    Args:
        data_path: Directory holding ``adult.data`` and ``adult.test`` (plus
            ``AdultTrain.csv`` / ``AdultTest.csv`` when ``use_csv_file`` is True).
        mask_s1_flag: Raw path only. Drop sex and race from the one-hot features and
            append race (``s2``) back as one binary column.
        mask_s2_flag: Raw path only. Drop sex and race and append sex (``s1``) back.
        mask_s1_s2_flag: Raw path only. Drop both sex and race.
        use_csv_file: If True, use the preprocessed feature CSVs provided by
            Renyi-Fair-Inference (mask flags and PCA are ignored on this path);
            otherwise one-hot encode the first 14 raw columns.
        do_pca: Raw path only. Project features onto ``pca_dimension`` principal
            components fitted on the training split.
        pca_dimension: Number of PCA components used when ``do_pca`` is True.

    Returns:
        ``(training_dataset, positive_training_dataset, negative_training_dataset,
        testing_dataset)``; the positive/negative datasets hold the training rows with
        ``s1 == 1`` / ``s1 == 0``.
    """

    def _read_raw(file_name):
        """Read a raw Adult file and return (rows, s1, s2, y)."""
        rows, sex, race, label = [], [], [], []
        with open(os.path.join(data_path, file_name)) as csv_file:
            for i, row in enumerate(csv.reader(csv_file, delimiter=',')):
                if i == 0 or not row:  # Skip the header line and blank lines
                    continue
                # Substring tests tolerate surrounding whitespace / trailing characters;
                # the test is case-sensitive, so "Female" does not match "Male".
                sex.append(1 if "Male" in row[9] else 0)
                race.append(1 if "White" in row[8] else 0)
                label.append(1 if '>50K' in row[14] else 0)
                rows.append(row)
        return rows, sex, race, label

    def _read_features(file_name):
        """Read a preprocessed feature CSV (header skipped) and append an intercept column."""
        with open(os.path.join(data_path, file_name)) as csv_file:
            reader = csv.reader(csv_file)
            next(reader, None)  # Skip the header line
            return [[float(item) for item in row] + [1] for row in reader if row]

    if use_csv_file:
        # Labels / sensitive attributes come from the raw files; features come from the
        # CSV files provided by Renyi-Fair-Inference.
        _, s1, s2, y = _read_raw('adult.data')
        _, testS1, testS2, testY = _read_raw('adult.test')
        # NOTE(review): each split is normalized with its own column norms (as in the
        # reference code), so train and test columns are scaled by different factors.
        X = normalize(_read_features('AdultTrain.csv'), axis=0)
        testX = normalize(_read_features('AdultTest.csv'), axis=0)

    # Preprocessing from raw data
    else:
        raw_rows, s1, s2, y = _read_raw('adult.data')
        raw_test_rows, testS1, testS2, testY = _read_raw('adult.test')

        numeric_columns = (0, 2, 4, 10, 11, 12)
        masked = mask_s1_flag or mask_s2_flag or mask_s1_s2_flag

        def _to_features(row):
            """Keep the first 14 columns, cast numeric ones to float, optionally drop race/sex."""
            features = row[:14]
            for col in numeric_columns:
                if col < len(features):
                    features[col] = float(features[col])
            # Columns 8 (race) and 9 (sex) are removed whenever any mask flag is set.
            return features[:8] + features[10:] if masked else features

        raw_X = [_to_features(row) for row in raw_rows]
        raw_testX = [_to_features(row) for row in raw_test_rows]

        # Fit the encoder on exactly the column layout that is transformed below
        # (masked rows have 12 columns, unmasked rows 14).
        enc = OneHotEncoder()
        enc.fit(raw_X + raw_testX)
        X = np.float32(enc.transform(raw_X).toarray())
        testX = np.float32(enc.transform(raw_testX).toarray())

        # Re-append the sensitive attribute that is not masked as a single binary column.
        if mask_s1_flag:
            X = np.float32(np.append(X, np.array([s2]).transpose(), axis=1))
            testX = np.float32(np.append(testX, np.array([testS2]).transpose(), axis=1))
        elif mask_s2_flag:
            X = np.float32(np.append(X, np.array([s1]).transpose(), axis=1))
            testX = np.float32(np.append(testX, np.array([testS1]).transpose(), axis=1))

        if do_pca:
            # Fit the projection on the training split only and reuse it for the test
            # split so both splits are expressed in the same principal-component basis.
            pca = PCA(n_components=pca_dimension)
            pca.fit(X)
            X = pca.transform(X)
            testX = pca.transform(testX)

    # Constructing the training dataset
    training_dataset = CustomizedDataset(attribute_dict={'X': X, 's1': s1, 's2': s2, 'y': y})

    # Split the training rows by s1 into positive (s1 == 1) and negative (s1 == 0) subsets.
    positive_X, negative_X, positive_y, negative_y, positive_s2, negative_s2 = [], [], [], [], [], []
    for index, is_positive in enumerate(np.array(s1) == 1):
        if is_positive:
            positive_X.append(X[index])
            positive_s2.append(s2[index])
            positive_y.append(y[index])
        else:
            negative_X.append(X[index])
            negative_s2.append(s2[index])
            negative_y.append(y[index])

    positive_training_dataset = CustomizedDataset(attribute_dict={
        'X': np.array(positive_X), 's1': [1] * len(positive_X), 's2': positive_s2, 'y': positive_y
    })
    negative_training_dataset = CustomizedDataset(attribute_dict={
        'X': np.array(negative_X), 's1': [0] * len(negative_X), 's2': negative_s2, 'y': negative_y
    })

    # Constructing the testing dataset
    testing_dataset = CustomizedDataset(attribute_dict={'X': testX, 's1': testS1, 's2': testS2, 'y': testY})

    return training_dataset, positive_training_dataset, negative_training_dataset, testing_dataset


def get_ARRHYTHMIA_dataset(data_path, mask_s1_flag, mask_s2_flag, mask_s1_s2_flag):
    """Load the UCI Arrhythmia data and split it randomly 80/20 into train/test.

    Features are every column of ``arrhythmia.data`` except column 13 and the last
    column (the class). Rows whose features cannot all be parsed as floats are skipped
    (presumably rows with missing values -- TODO confirm), as are blank lines.
    The sensitive attribute is derived from column 1 (Male:0, Female:1) and is used as
    both ``s1`` and ``s2``. The label is 1 for class '1' and 0 for classes '2'-'16'.

    The mask flags are accepted for interface compatibility but are not used.
    The train/test split uses the global ``random`` module state.

    Returns:
        ``(training_dataset, positive_training_dataset, negative_training_dataset,
        testing_dataset)``; the positive/negative datasets hold training rows with
        ``s == 1`` / ``s == 0``.
    """
    # Preprocess
    full_X = []
    full_y = []  # Distinguish between the presence and absence of cardiac arrhythmia ('1': 1, '2'-'16':0)
    full_s = []  # Sensitive feature (Male:0, Female:1)

    with open(os.path.join(data_path, 'arrhythmia.data')) as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=',')
        for row in csv_reader:
            if not row:  # Blank line
                continue
            temp = row[:13] + row[14:-1]
            try:
                full_X.append([float(item) for item in temp])
            except ValueError:
                # Skip rows with unparsable feature values.
                continue

            if ("1" in row[1]) or (int(row[1]) - 1 == 0):
                full_s.append(float(1))
            else:
                full_s.append(float(0))

            if int(row[-1]) == 1:
                full_y.append(float(1))
            else:
                full_y.append(float(0))

    training_size = int(len(full_X) * 0.8)
    # A set makes the per-row membership test below O(1) instead of scanning a list.
    training_indexes = set(random.sample(range(0, len(full_X)), training_size))
    X, y, s = [], [], []
    testX, testY, testS = [], [], []
    for i, features in enumerate(full_X):
        if i in training_indexes:
            X.append(features)
            y.append(full_y[i])
            s.append(full_s[i])
        else:
            testX.append(features)
            testY.append(full_y[i])
            testS.append(full_s[i])
    X, testX = np.array(X), np.array(testX)

    # Constructing the training dataset (the single sensitive attribute fills s1 and s2)
    training_dataset = CustomizedDataset(attribute_dict={'X': X, 's1': s, 's2': s, 'y': y})

    # Split the training rows by s into positive (s == 1) and negative (s == 0) subsets.
    positive_X, negative_X, positive_y, negative_y, positive_s2, negative_s2 = [], [], [], [], [], []
    for index, is_positive in enumerate(np.array(s) == 1):
        if is_positive:
            positive_X.append(X[index])
            positive_s2.append(s[index])
            positive_y.append(y[index])
        else:
            negative_X.append(X[index])
            negative_s2.append(s[index])
            negative_y.append(y[index])

    positive_training_dataset = CustomizedDataset(attribute_dict={
        'X': np.array(positive_X), 's1': [1] * len(positive_X), 's2': positive_s2, 'y': positive_y
    })
    negative_training_dataset = CustomizedDataset(attribute_dict={
        'X': np.array(negative_X), 's1': [0] * len(negative_X), 's2': negative_s2, 'y': negative_y
    })

    # Constructing the testing dataset
    testing_dataset = CustomizedDataset(attribute_dict={'X': testX, 's1': testS, 's2': testS, 'y': testY})

    return training_dataset, positive_training_dataset, negative_training_dataset, testing_dataset


def get_BANK_dataset(data_path, mask_s1_flag, mask_s2_flag, mask_s1_s2_flag):
    """Load the UCI Bank Marketing data and split it randomly into train/test.

    Labels and the sensitive attribute are read from ``bank-full.csv`` (``;``-separated,
    header skipped): ``s`` is 1 when column 2 contains "married" and the label is 1.0
    when column 16 contains "yes". Features come from the preprocessed
    ``Bank_data.csv`` (header skipped); its rows are assumed to align one-to-one with
    ``bank-full.csv`` -- TODO confirm. Blank lines in either file are skipped.

    32000 randomly sampled rows (global ``random`` state) form the training split and
    the rest the test split. ``s`` is used as both ``s1`` and ``s2``. The mask flags
    are accepted for interface compatibility but are not used.

    Returns:
        ``(training_dataset, positive_training_dataset, negative_training_dataset,
        testing_dataset)``; the positive/negative datasets hold training rows with
        ``s == 1`` / ``s == 0``.
    """
    # Some code is borrowed from https://github.com/optimization-for-data-driven-science/Renyi-Fair-Inference

    # Preprocess
    full_y = []  # Client will subscribe a term deposit (Yes: 1, No:0)
    full_s = []  # Sensitive feature (Married:1, Other:0)

    with open(os.path.join(data_path, 'bank-full.csv')) as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=';')
        next(csv_reader, None)  # Skip the header line
        for row in csv_reader:
            if not row:  # Blank line
                continue
            full_s.append(1 if "married" in row[2] else 0)
            full_y.append(float(1) if 'yes' in row[16] else float(0))

    with open(os.path.join(data_path, 'Bank_data.csv')) as csv_file:
        csv_reader = csv.reader(csv_file)
        next(csv_reader, None)  # Skip the header line
        full_X = [[float(item) for item in row] for row in csv_reader if row]

    # Training-set size taken from the description in the Renyi Fair Inference paper
    training_size = 32000
    # A set makes the per-row membership test below O(1); with a list this loop was
    # quadratic (tens of thousands of rows times a 32000-element scan).
    training_indexes = set(random.sample(range(0, len(full_X)), training_size))
    X, y, s = [], [], []
    testX, testY, testS = [], [], []
    for i, features in enumerate(full_X):
        if i in training_indexes:
            X.append(features)
            y.append(full_y[i])
            s.append(full_s[i])
        else:
            testX.append(features)
            testY.append(full_y[i])
            testS.append(full_s[i])
    X, testX = np.array(X), np.array(testX)

    # Constructing the training dataset (the single sensitive attribute fills s1 and s2)
    training_dataset = CustomizedDataset(attribute_dict={'X': X, 's1': s, 's2': s, 'y': y})

    # Split the training rows by s into positive (s == 1) and negative (s == 0) subsets.
    positive_X, negative_X, positive_y, negative_y, positive_s2, negative_s2 = [], [], [], [], [], []
    for index, is_positive in enumerate(np.array(s) == 1):
        if is_positive:
            positive_X.append(X[index])
            positive_s2.append(s[index])
            positive_y.append(y[index])
        else:
            negative_X.append(X[index])
            negative_s2.append(s[index])
            negative_y.append(y[index])

    positive_training_dataset = CustomizedDataset(attribute_dict={
        'X': np.array(positive_X), 's1': [1] * len(positive_X), 's2': positive_s2, 'y': positive_y
    })
    negative_training_dataset = CustomizedDataset(attribute_dict={
        'X': np.array(negative_X), 's1': [0] * len(negative_X), 's2': negative_s2, 'y': negative_y
    })

    # Constructing the testing dataset
    testing_dataset = CustomizedDataset(attribute_dict={'X': testX, 's1': testS, 's2': testS, 'y': testY})

    return training_dataset, positive_training_dataset, negative_training_dataset, testing_dataset


def get_COMPAS_dataset(data_path, mask_s1_flag=False, mask_s2_flag=False, mask_s1_s2_flag=False,
                       pca_dimension=64):
    """Load the COMPAS two-year recidivism data as fairness-aware datasets.

    Rows of ``compas-scores-two-years.csv`` (header skipped) are kept when
    ``days_b_screening_arrest`` (col 15) is within [-30, 30], ``is_recid`` (col 24) is
    not -1, ``c_charge_degree`` (col 22) is not "0", ``score_text`` (col 40) is not
    "N/A", and race (col 9) is African-American or Caucasian. The kept rows are
    shuffled with the global ``random`` state; the first 4800 form the training split.

    ``s1`` is race (African-American: 1, Caucasian: 0), ``s2`` is sex (col 5,
    Male: 1, otherwise 0), and ``y`` is 1 when ``is_recid == 0`` (not a recidivist).

    Features are the one-hot encoding of columns 8, 10, 12-15, 22, 39-40 and 48
    (race and sex excluded), followed by:
      * ``mask_s1_flag``: ``s2`` appended as a binary column;
      * ``mask_s2_flag``: ``s1`` appended;
      * ``mask_s1_s2_flag``: nothing appended;
      * otherwise: ``s1`` and ``s2`` appended.
    Finally the features are projected onto ``pca_dimension`` principal components
    fitted on the training split and applied to both splits.

    Returns:
        ``(training_dataset, positive_training_dataset, negative_training_dataset,
        testing_dataset)``; the positive/negative datasets hold training rows with
        ``s1 == 1`` / ``s1 == 0``.
    """
    # Some code is borrowed from https://github.com/propublica/compas-analysis/blob/master/Compas%20Analysis.ipynb

    def _keep(row):
        """Apply the row filters (after the ProPublica analysis) to one CSV row."""
        if row[15] == '' or row[24] == '' or row[22] == '' or row[40] == '':
            return False
        return (30 >= int(row[15]) >= -30  # `days_b_screening_arrest`
                and int(row[24]) != -1  # `is_recid`
                and row[22] != "0"  # `c_charge_degree`
                and row[40] != 'N/A'  # `score_text`
                and row[9] in ("African-American", "Caucasian"))  # `race`

    def _extract(rows):
        """Return (features without race/sex, s1, s2, y) for a list of filtered rows."""
        features, race, sex, label = [], [], [], []
        for row in rows:
            race.append(1 if row[9] == "African-American" else 0)  # African-American:1, Caucasian:0
            sex.append(1 if row[5] == "Male" else 0)  # Male:1, Female:0
            label.append(1 if int(row[24]) == 0 else 0)  # is_recid=0 -> 1; is_recid=1 -> 0
            # Race (col 9) and sex (col 5) are left out; they are re-appended below.
            features.append(row[8:9] + row[10:11] + row[12:16] + row[22:23] + row[39:41] + row[48:49])
        return features, race, sex, label

    def _append_columns(base, *columns):
        """Append each given 0/1 list to ``base`` as an extra column (float32 result)."""
        return np.float32(np.append(base, np.array(columns).transpose(), axis=1))

    with open(os.path.join(data_path, 'compas-scores-two-years.csv')) as csv_file:
        csv_reader = csv.reader(csv_file)
        next(csv_reader, None)  # Skipping the row of feature names
        raw_data = [row for row in csv_reader if _keep(row)]

    # Splitting
    random.shuffle(raw_data)
    raw_X, s1, s2, y = _extract(raw_data[:4800])
    raw_testX, testS1, testS2, testY = _extract(raw_data[4800:])

    # One-hot Encoding (categories collected from both splits)
    enc = OneHotEncoder()
    enc.fit(raw_X + raw_testX)
    X = np.float32(enc.transform(raw_X).toarray())
    testX = np.float32(enc.transform(raw_testX).toarray())

    if mask_s1_flag:
        X, testX = _append_columns(X, s2), _append_columns(testX, testS2)
    elif mask_s2_flag:
        X, testX = _append_columns(X, s1), _append_columns(testX, testS1)
    elif not mask_s1_s2_flag:
        X, testX = _append_columns(X, s1, s2), _append_columns(testX, testS1, testS2)

    # Fit PCA on the training split only and reuse it for the test split, so both
    # splits are expressed in the same principal-component basis.
    pca = PCA(n_components=pca_dimension)
    pca.fit(X)
    X = pca.transform(X)
    testX = pca.transform(testX)

    # Constructing the training dataset
    training_dataset = CustomizedDataset(attribute_dict={'X': X, 's1': s1, 's2': s2, 'y': y})

    # Split the training rows by s1 into positive (s1 == 1) and negative (s1 == 0) subsets.
    positive_X, negative_X, positive_y, negative_y, positive_s2, negative_s2 = [], [], [], [], [], []
    for index, is_positive in enumerate(np.array(s1) == 1):
        if is_positive:
            positive_X.append(X[index])
            positive_s2.append(s2[index])
            positive_y.append(y[index])
        else:
            negative_X.append(X[index])
            negative_s2.append(s2[index])
            negative_y.append(y[index])

    positive_training_dataset = CustomizedDataset(attribute_dict={
        'X': np.array(positive_X), 's1': [1] * len(positive_X), 's2': positive_s2, 'y': positive_y
    })
    negative_training_dataset = CustomizedDataset(attribute_dict={
        'X': np.array(negative_X), 's1': [0] * len(negative_X), 's2': negative_s2, 'y': negative_y
    })

    # Constructing the testing dataset
    testing_dataset = CustomizedDataset(attribute_dict={'X': testX, 's1': testS1, 's2': testS2, 'y': testY})

    return training_dataset, positive_training_dataset, negative_training_dataset, testing_dataset


def _parse_drug_rows(rows):
    """Split raw drug-consumption CSV rows into features, sensitive attributes and labels.

    Args:
        rows: list of CSV rows (lists of strings) from ``drug_consumption.data``.

    Returns:
        Tuple ``(features, s1, s2, y)``:
          * ``features[i]`` is row ``i``'s columns 0..30 with column 2 (gender)
            and column 5 (ethnicity) removed; the label column 31 is excluded.
          * ``s1``: 1 if column 5 equals -0.31685 (White), else 0.
          * ``s2``: 1 if column 2 is negative (Male), else 0.
          * ``y``:  1 if column 31 is ``'CL0'`` (not abuse volatile substance), else 0.
    """
    features, s1, s2, y = [], [], [], []
    for row in rows:
        s1.append(1 if float(row[5]) == -0.31685 else 0)  # White:1, Non-white:0
        s2.append(1 if float(row[2]) < 0 else 0)  # Male:1, Female:0
        y.append(1 if row[31] == 'CL0' else 0)  # Not abuse:1 ; Abuse:0
        # NOTE(review): column 0 is kept as a one-hot feature; if it is a per-row
        # ID (as in the UCI release) it carries no generalizable signal -- confirm.
        features.append(row[:2] + row[3:5] + row[6:31])
    return features, s1, s2, y


def get_DRUG_dataset(data_path, mask_s1_flag=False, mask_s2_flag=False, mask_s1_s2_flag=False, pca_dimension=64):
    """Load the drug-consumption dataset and build the train/positive/negative/test datasets.

    The first line of ``drug_consumption.data`` is skipped. The remaining rows are
    shuffled with the global ``random`` state; the first 1600 rows form the training
    set and the rest form the test set. The non-sensitive columns are one-hot encoded
    (the encoder is fit on train + test so both splits share one vocabulary). The
    sensitive attributes that are not masked are appended as trailing columns, and the
    result is reduced with PCA.

    Args:
        data_path: directory containing ``drug_consumption.data``.
        mask_s1_flag: drop s1 (ethnicity) from the features; checked first.
        mask_s2_flag: drop s2 (gender) from the features; checked second.
        mask_s1_s2_flag: drop both sensitive attributes; checked third.
            If no flag is set, both s1 and s2 are appended as features.
        pca_dimension: number of PCA components (default 64).

    Returns:
        ``(training_dataset, positive_training_dataset, negative_training_dataset,
        testing_dataset)``. The positive/negative datasets are the training rows with
        ``s1 == 1`` and ``s1 == 0`` respectively.
    """
    enc = OneHotEncoder()
    pca = PCA(n_components=pca_dimension)

    with open(os.path.join(data_path, 'drug_consumption.data')) as csv_file:
        csv_reader = csv.reader(csv_file)
        next(csv_reader, None)  # Skipping the row of feature name
        raw_data = list(csv_reader)

    # Splitting
    random.shuffle(raw_data)
    raw_X, s1, s2, y = _parse_drug_rows(raw_data[:1600])
    raw_testX, testS1, testS2, testY = _parse_drug_rows(raw_data[1600:])

    def with_sensitive_columns(encoded, s1_values, s2_values):
        # Re-attach whichever sensitive attributes are not masked as trailing columns.
        if mask_s1_flag:
            extra = [s2_values]
        elif mask_s2_flag:
            extra = [s1_values]
        elif mask_s1_s2_flag:
            return encoded
        else:
            extra = [s1_values, s2_values]
        return np.float32(np.append(encoded, np.array(extra).transpose(), axis=1))

    # One-hot Encoding (fit on both splits so every category is known to the encoder)
    enc.fit(raw_X + raw_testX)
    X = with_sensitive_columns(np.float32(enc.transform(raw_X).toarray()), s1, s2)
    testX = with_sensitive_columns(np.float32(enc.transform(raw_testX).toarray()), testS1, testS2)

    # Fit PCA on the training features only and reuse that projection for the test
    # features. Re-fitting on the test set would place it in a different basis than
    # the one the model is trained on.
    pca.fit(X)

    # Constructing the training dataset
    training_attribute_dict = {'X': pca.transform(X), 's1': s1, 's2': s2, 'y': y}
    training_dataset = CustomizedDataset(attribute_dict=training_attribute_dict)

    # Constructing the positive (s1 == 1) and negative (s1 == 0) training datasets
    train_features = training_attribute_dict['X']
    is_positive = np.array(s1) == 1
    is_negative = ~is_positive
    s2_array, y_array = np.array(s2), np.array(y)
    positive_training_dataset = CustomizedDataset(attribute_dict={
        'X': train_features[is_positive],
        's1': [1] * int(is_positive.sum()),
        's2': s2_array[is_positive].tolist(),
        'y': y_array[is_positive].tolist(),
    })
    negative_training_dataset = CustomizedDataset(attribute_dict={
        'X': train_features[is_negative],
        's1': [0] * int(is_negative.sum()),
        's2': s2_array[is_negative].tolist(),
        'y': y_array[is_negative].tolist(),
    })

    # Constructing the testing dataset (projected with the training-fitted PCA)
    testing_attribute_dict = {'X': pca.transform(testX), 's1': testS1, 's2': testS2, 'y': testY}
    testing_dataset = CustomizedDataset(attribute_dict=testing_attribute_dict)

    return training_dataset, positive_training_dataset, negative_training_dataset, testing_dataset


def get_DUTCH_dataset(data_path, mask_s1_flag, mask_s2_flag, mask_s1_s2_flag):
    """Load the Dutch census 2001 dataset and build the train/positive/negative/test datasets.

    The attribute names are read from the ARFF header, and the ``@data`` section is
    parsed as CSV. Each row's first 11 columns are converted to floats and one-hot
    encoded. Rows whose column 0 is NaN, or whose first 11 columns cannot be converted
    to float, are skipped. 80% of the rows, drawn with ``random.sample``, form the
    training set; the remainder form the test set. Both splits keep the file order.

    Sex (column 0) is used as the only sensitive attribute, so ``s1`` and ``s2`` are
    identical. The ``mask_*`` flags are accepted for interface parity with the other
    loaders but are not used here.

    Returns:
        ``(training_dataset, positive_training_dataset, negative_training_dataset,
        testing_dataset)``. The positive/negative datasets are the training rows with
        ``s == 1`` (Female) and ``s == 0`` respectively.
    """
    # Preprocess
    enc = OneHotEncoder()

    raw_X = []
    full_y = []  # Using occupation as the class label ('5_4_9': 1, '2_1':0)
    full_s = []  # Sensitive feature (2->Male:0, 1->Female:1)

    with open(os.path.join(data_path, 'dutch_census_2001.arff'), encoding="utf-8") as f:
        # Collect attribute names from the ARFF header, then hand the rest of the file to pandas.
        header = []
        for line in f:
            if line.startswith("@attribute"):
                header.append(line.split()[1])
            elif line.startswith("@data"):
                break
        df = pd.read_csv(f, header=None)
        df.columns = header

    for row in np.array(df).tolist():
        if math.isnan(row[0]):
            continue
        try:
            features = [float(item) for item in row[:11]]
        except (TypeError, ValueError):
            # Skip rows with non-numeric values in the feature columns.
            continue
        raw_X.append(features)
        full_s.append(1.0 if features[0] == 1.0 else 0.0)  # 1 -> Female:1, otherwise Male:0
        full_y.append(1.0 if "5_4_9" in row[-1] else 0.0)

    enc.fit(raw_X)
    full_X = enc.transform(raw_X).toarray()

    # Draw the training indices, then turn them into a boolean mask. The per-row
    # `i in list` test was quadratic in the number of rows.
    training_size = int(len(full_X) * 0.8)
    training_indexes = random.sample(range(0, len(full_X)), training_size)
    is_train = np.zeros(len(full_X), dtype=bool)
    is_train[training_indexes] = True
    is_test = ~is_train

    full_y_array, full_s_array = np.asarray(full_y), np.asarray(full_s)
    X, testX = full_X[is_train], full_X[is_test]
    y, testY = full_y_array[is_train].tolist(), full_y_array[is_test].tolist()
    s, testS = full_s_array[is_train].tolist(), full_s_array[is_test].tolist()

    # Constructing the training dataset (no PCA for this dataset)
    training_attribute_dict = {'s1': s, 's2': s, 'y': y, 'X': X}
    training_dataset = CustomizedDataset(attribute_dict=training_attribute_dict)

    # Constructing the positive (s == 1) and negative (s == 0) training datasets
    is_positive = np.asarray(s) == 1
    is_negative = ~is_positive
    s_train, y_train = np.asarray(s), np.asarray(y)
    positive_training_dataset = CustomizedDataset(attribute_dict={
        'X': X[is_positive],
        's1': [1] * int(is_positive.sum()),
        's2': s_train[is_positive].tolist(),
        'y': y_train[is_positive].tolist(),
    })
    negative_training_dataset = CustomizedDataset(attribute_dict={
        'X': X[is_negative],
        's1': [0] * int(is_negative.sum()),
        's2': s_train[is_negative].tolist(),
        'y': y_train[is_negative].tolist(),
    })

    # Constructing the testing dataset
    testing_attribute_dict = {'s1': testS, 's2': testS, 'y': testY, 'X': testX}
    testing_dataset = CustomizedDataset(attribute_dict=testing_attribute_dict)

    return training_dataset, positive_training_dataset, negative_training_dataset, testing_dataset


def get_GERMAN_dataset(data_path, mask_s1_flag, mask_s2_flag, mask_s1_s2_flag):
    """Load the German credit dataset and build the train/positive/negative/test datasets.

    Labels and sensitive attributes are read from ``german.data``. Features are read
    from ``german.data-numeric``. The two files are paired row by row. 80% of the rows,
    drawn with ``random.sample``, form the training set; the remainder form the test
    set. Both splits keep the file order.

    The ``mask_*`` flags are accepted for interface parity with the other loaders but
    are not used here.

    Returns:
        ``(training_dataset, positive_training_dataset, negative_training_dataset,
        testing_dataset)``. The positive/negative datasets are the training rows with
        ``s1 == 1`` (Female) and ``s1 == 0`` respectively.

    Raises:
        ValueError: if the two data files do not have the same number of rows.
    """
    # Preprocess
    full_y = []  # Good or bad credit risks (Good:1, Bad:0)
    full_s1 = []  # Sensitive feature: Gender (Female:1, Male:0)
    full_s2 = []  # Sensitive feature: Marital-status (Married:1, Other:0)

    with open(os.path.join(data_path, 'german.data')) as csv_file:
        for row in csv.reader(csv_file, delimiter=' '):
            status = row[8]  # personal-status/sex code
            full_s1.append(1.0 if ("A92" in status or "A95" in status) else 0.0)  # Female / Male
            full_s2.append(1.0 if ("A92" in status or "A94" in status) else 0.0)  # Married / Other
            full_y.append(1.0 if int(row[-1]) == 1 else 0.0)  # Good / Bad

    full_X = []
    with open(os.path.join(data_path, 'german.data-numeric')) as csv_file:
        for row in csv.reader(csv_file, delimiter=' '):
            # Runs of spaces produce empty fields; drop them.
            # NOTE(review): if the last column of german.data-numeric is the class label
            # (as in the UCI distribution), it ends up in X here -- confirm.
            full_X.append([float(item) for item in row if item])

    if len(full_X) != len(full_y):
        raise ValueError(
            f"german.data has {len(full_y)} rows but german.data-numeric has {len(full_X)}; "
            "the two files must be row-aligned"
        )

    # Copy from the description of Renyi(R´E NYI FAIR INFERENCE)
    training_size = int(len(full_X) * 0.8)
    training_indexes = random.sample(range(0, len(full_X)), training_size)
    # Boolean mask instead of a per-row `i in list` test (which is quadratic).
    is_train = np.zeros(len(full_X), dtype=bool)
    is_train[training_indexes] = True
    is_test = ~is_train

    full_X_array = np.array(full_X)
    y_all, s1_all, s2_all = np.asarray(full_y), np.asarray(full_s1), np.asarray(full_s2)
    X, testX = full_X_array[is_train], full_X_array[is_test]
    y, testY = y_all[is_train].tolist(), y_all[is_test].tolist()
    s1, testS1 = s1_all[is_train].tolist(), s1_all[is_test].tolist()
    s2, testS2 = s2_all[is_train].tolist(), s2_all[is_test].tolist()

    # Constructing the training dataset
    training_attribute_dict = {'X': X, 's1': s1, 's2': s2, 'y': y}
    training_dataset = CustomizedDataset(attribute_dict=training_attribute_dict)

    # Constructing the positive (s1 == 1) and negative (s1 == 0) training datasets
    is_positive = np.asarray(s1) == 1
    is_negative = ~is_positive
    s2_train, y_train = np.asarray(s2), np.asarray(y)
    positive_training_dataset = CustomizedDataset(attribute_dict={
        'X': X[is_positive],
        's1': [1] * int(is_positive.sum()),
        's2': s2_train[is_positive].tolist(),
        'y': y_train[is_positive].tolist(),
    })
    negative_training_dataset = CustomizedDataset(attribute_dict={
        'X': X[is_negative],
        's1': [0] * int(is_negative.sum()),
        's2': s2_train[is_negative].tolist(),
        'y': y_train[is_negative].tolist(),
    })

    # Constructing the testing dataset
    testing_attribute_dict = {'X': testX, 's1': testS1, 's2': testS2, 'y': testY}
    testing_dataset = CustomizedDataset(attribute_dict=testing_attribute_dict)

    return training_dataset, positive_training_dataset, negative_training_dataset, testing_dataset


if __name__ == '__main__':
    print("Testing")

    # Example usage (needs german.data and german.data-numeric under data_path).
    # get_GERMAN_dataset returns four datasets, not two:
    # data_path = '../dataset/GERMAN/'
    # mask_s1_flag = False
    # mask_s2_flag = False
    # mask_s1_s2_flag = False
    # training_dataset, positive_training_dataset, negative_training_dataset, testing_dataset = \
    #     get_GERMAN_dataset(data_path, mask_s1_flag, mask_s2_flag, mask_s1_s2_flag)
