### Most of the code is taken from: https://github.com/SoftWiser-group/FairDisCo

import os
import random
import torch
from torch.utils.data import Dataset
import torchvision
import torchvision.transforms as transforms

import numpy as np
import pandas as pd

from scipy.special import digamma
from sklearn.neighbors import NearestNeighbors, KDTree
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
from sklearn.feature_selection import mutual_info_classif
from sklearn.metrics import roc_auc_score
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from sklearn.utils import shuffle

import xgboost as xgb


import warnings
# Install a blanket 'ignore' filter at import time: warnings raised after this
# point (by this module or by the libraries it drives) are silenced unless a
# later filter overrides it.
warnings.filterwarnings('ignore')

def setSeed(seed=2022):
    """Seed every random number generator used by this module.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU, current CUDA device
    and all CUDA devices), exports ``PYTHONHASHSEED`` and switches cuDNN to
    deterministic mode.

    Args:
        seed: Integer seed applied to all generators (default 2022).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True

class VectorDataset(Dataset):
    """Dataset over parallel, index-aligned containers.

    Args:
        X: Features; must support integer indexing and expose ``shape``.
        S: Sensitive-attribute values, indexable like ``X``.
        Y: Target values, indexable like ``X``.
        C: Optional extra per-sample data (e.g. categorical codes). When
            given, it is returned as the second element of every item.
    """

    def __init__(self, X, S, Y, C=None):
        self.X = X
        self.S = S
        self.Y = Y
        self.C = C

    def __getitem__(self, i):
        """Return ``(x, c, s, y)`` when ``C`` was supplied, else ``(x, s, y)``."""
        x, s, y = self.X[i], self.S[i], self.Y[i]
        # Identity test on purpose: `C != None` is evaluated element-wise by
        # array types such as numpy.ndarray, and the resulting array cannot be
        # used as an `if` condition.
        if self.C is not None:
            return x, self.C[i], s, y
        return x, s, y

    def __len__(self):
        """Number of samples, taken from the first dimension of ``X``."""
        return self.X.shape[0]

def getOneHot(df):
    """One-hot encode every column of ``df``.

    Args:
        df: DataFrame whose columns are all treated as categorical.

    Returns:
        A tuple ``(X, C, D)`` where ``X`` is the column-wise concatenation of
        the per-column dummy matrices, ``C`` is the raw ``df.values`` array and
        ``D`` is a list with the number of distinct values of each column.
    """
    dummy_blocks = [pd.get_dummies(df[name]).values for name in df.columns]
    cardinalities = df.nunique().values.tolist()
    return np.concatenate(dummy_blocks, axis=1), df.values, cardinalities

def getBinary(df, cols):
    """Binarize numeric columns of ``df`` around their median.

    For each column in ``cols`` the output is 1 where the value is strictly
    greater than the column median and 0 otherwise. A column whose maximum
    equals its median is copied through unchanged.

    Args:
        df: Source DataFrame.
        cols: List of column names to process.

    Returns:
        A new DataFrame with columns ``cols`` and the same index as ``df``, so
        it can be assigned back to ``df`` even when ``df`` does not carry a
        default ``RangeIndex`` (e.g. after shuffling or filtering).
    """
    values = df[cols].values
    medians = df[cols].apply(lambda s: np.median(s)).values
    binarized = np.zeros_like(values)
    for j, median in enumerate(medians):
        column = values[:, j]
        if column.max() == median:
            binarized[:, j] = column
        else:
            binarized[:, j] = (column > median).astype(int)
    # Keep the caller's index: a fresh RangeIndex would make an assignment
    # back into `df` align on the wrong labels and produce NaN.
    return pd.DataFrame(binarized, columns=cols, index=df.index)

def getDataset(df, S, Y, num_train):
    """Build train/test ``VectorDataset`` objects from a categorical frame.

    Every column of ``df`` is label-encoded and then one-hot encoded with
    ``getOneHot``. The first ``num_train`` rows form the training split and
    the remaining rows the test split.

    Args:
        df: DataFrame of categorical features (one row per sample).
        S: Sensitive-attribute values, converted to a ``LongTensor``.
        Y: Target values, converted to a ``FloatTensor``.
        num_train: Number of leading rows assigned to the training split.

    Returns:
        ``(train_data, test_data, D)`` where ``D`` is the per-column count of
        distinct values reported by ``getOneHot``.
    """
    sensitive = torch.LongTensor(S)
    targets = torch.FloatTensor(Y)

    encoded = df.apply(lambda column: LabelEncoder().fit_transform(column))
    one_hot, codes, D = getOneHot(encoded)
    features = torch.FloatTensor(one_hot)
    categories = torch.LongTensor(codes)

    head = slice(None, num_train)
    tail = slice(num_train, None)

    train_data = VectorDataset(features[head], sensitive[head], targets[head], categories[head])
    test_data = VectorDataset(features[tail], sensitive[tail], targets[tail], categories[tail])
    return train_data, test_data, D

def loadAdult(pro_att):
    """
    Adult Census Income: Individual information from 1994 U.S. census. Goal is predicting income >$50,000.
    Protected Attribute: sex / race

    Reads ``./data/Adult/adult.data`` (train) and ``./data/Adult/adult.test``
    (test, first line skipped), discretizes the numeric columns and hands the
    result to ``getDataset``.

    Args:
        pro_att: Protected attribute to extract, either ``'sex'`` or ``'race'``.
            The chosen column is removed from the features.

    Returns:
        Whatever ``getDataset`` returns: ``(train_data, test_data, D)``.

    Raises:
        ValueError: If ``pro_att`` is not ``'sex'`` or ``'race'``.
    """
    # Fail fast: without this check an unsupported value would leave `S`
    # unbound and only crash after all files had been read and processed.
    if pro_att not in ('sex', 'race'):
        raise ValueError("pro_att must be 'sex' or 'race', got {!r}".format(pro_att))

    cols = ['age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status', 'occupation',
            'relationship', 'race', 'sex', 'capital-gain', 'capital-loss', 'hours-per-week', 'native-country', 'salary']

    df_train = pd.read_csv('./data/Adult/adult.data', names=cols)
    df_test = pd.read_csv('./data/Adult/adult.test', names=cols, skiprows=1)

    num_train = df_train.shape[0]
    num_test = df_test.shape[0]
    df = pd.concat([df_train, df_test], ignore_index=True)
    # Strip surrounding whitespace from the string columns so the equality
    # tests below match the raw category names.
    df = df.apply(lambda v: v.astype(str).str.strip() if v.dtype == "object" else v)
    print('train_size {}, test_size {}'.format(num_train, num_test))

    # Discretize numeric columns: 8 bins each (fnlwgt on a log scale),
    # capital gain/loss thresholded by getBinary.
    df['age'] = pd.cut(df['age'], bins=8, labels=False)
    df['hours-per-week'] = pd.cut(df['hours-per-week'], bins=8, labels=False)
    df['fnlwgt'] = pd.cut(np.log(df['fnlwgt']), bins=8, labels=False)
    df[['capital-gain', 'capital-loss']] = getBinary(df, ['capital-gain', 'capital-loss'])

    if pro_att == 'sex':
        S = (df['sex'] == 'Male').values.astype(int)
    else:  # pro_att == 'race'
        S = (df['race'] == 'Black').values.astype(int)
    del df[pro_att]

    # Y is 1 for the '<=50K' class; the test file spells it with a trailing dot.
    Y = (df['salary'].apply(lambda x: x == '<=50K' or x == '<=50K.')).values.astype(int)
    del df['salary']

    return getDataset(df, S, Y, num_train)

def loadCompas(pro_att):
    """
    Compas: Contains criminal history of defendants. Goal predicting re-offending in future
    Protected Attribute: sex / race

    Reads ``./data/compas-scores-two-years.csv``, drops the columns listed in
    ``drop_cols``, shuffles the rows and uses the first 80% as the training
    split.

    Args:
        pro_att: Protected attribute to extract, either ``'sex'`` or ``'race'``.
            The chosen column is removed from the features.

    Returns:
        Whatever ``getDataset`` returns: ``(train_data, test_data, D)``.

    Raises:
        ValueError: If ``pro_att`` is not ``'sex'`` or ``'race'``.
    """
    # Fail fast: without this check an unsupported value would leave `S`
    # unbound and only crash after the file had been read and processed.
    if pro_att not in ('sex', 'race'):
        raise ValueError("pro_att must be 'sex' or 'race', got {!r}".format(pro_att))

    df = pd.read_csv('./data/compas-scores-two-years.csv')
    drop_cols = ['id', 'name', 'first', 'last', 'compas_screening_date',
                 'dob', 'juv_fel_count', 'decile_score',
                 'juv_misd_count', 'juv_other_count', 'days_b_screening_arrest',
                 'c_jail_in', 'c_jail_out', 'c_case_number', 'c_offense_date', 'c_arrest_date',
                 'c_days_from_compas', 'c_charge_desc', 'is_recid', 'r_case_number', 'r_charge_degree',
                 'r_days_from_arrest', 'r_offense_date', 'r_charge_desc', 'r_jail_in', 'r_jail_out',
                 'violent_recid', 'is_violent_recid', 'vr_case_number', 'vr_charge_degree', 'vr_offense_date',
                 'vr_charge_desc', 'type_of_assessment', 'score_text', 'screening_date',
                 'v_type_of_assessment', 'v_decile_score', 'v_score_text', 'v_screening_date', 'in_custody',
                 'out_custody', 'start', 'end', 'event']
    df = df.drop(drop_cols, axis=1)

    # Shuffle before splitting; the first 80% of the shuffled rows are train.
    df = shuffle(df)
    num_train = int(0.8*df.shape[0])
    print('train_size {}, test_size {}'.format(num_train, df.shape[0]-num_train))

    df['age'] = pd.cut(df['age'], bins=5, labels=False)
    # Cap the number of priors at 9 so the column has at most ten categories.
    df['priors_count'] = df['priors_count'].apply(lambda x: x if x < 9 else 9)

    if pro_att == 'sex':
        S = (df['sex'] == 'Male').values.astype(int)
    else:  # pro_att == 'race'
        S = (df['race'] == 'African-American').values.astype(int)
    del df[pro_att]

    Y = df['two_year_recid'].values.astype(int)
    del df['two_year_recid']

    return getDataset(df, S, Y, num_train)


def _mnistSplit(dataset, transform, color):
    """Convert one torchvision MNIST split into ``(X, S, Y)`` tensors.

    Args:
        dataset: MNIST dataset object; its ``data`` and ``targets`` attributes
            are read.
        transform: Callable applied to each raw image before it is stored.
        color: If truthy, each image is written into one of three channels
            (sample index modulo 3) and that channel index is the sensitive
            attribute. Otherwise a single channel is used and the sensitive
            attribute is the label itself.

    Returns:
        ``(X, S, Y)`` with ``X`` of shape ``(n, channels, 64, 64)`` scaled by
        its maximum value.
    """
    n = len(dataset)
    Y = dataset.targets
    if color:
        S = torch.arange(n) % 3
        X = torch.zeros(n, 3, 64, 64)
        for i in range(n):
            X[i, S[i]] = transform(dataset.data[i]).squeeze()
    else:
        S = Y
        X = torch.zeros(n, 1, 64, 64)
        for i in range(n):
            X[i, 0] = transform(dataset.data[i]).squeeze()
    X /= X.max()
    return X, S, Y


def loadMnist(color=True):
    """Load MNIST as 64x64 tensors wrapped in ``VectorDataset`` objects.

    The dataset is downloaded to ``./data`` if it is not already there.

    Args:
        color: If True (default), build 3-channel images where the occupied
            channel is the sensitive attribute. If False, build 1-channel
            images and use the digit label as the sensitive attribute.

    Returns:
        ``(train_data, test_data)``, each a ``VectorDataset`` of ``(X, S, Y)``.
    """
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
    ])
    train_raw = torchvision.datasets.MNIST(root='./data', train=True, download=True)
    test_raw = torchvision.datasets.MNIST(root='./data', train=False, download=True)

    # Both splits go through the same helper so the test split can no longer
    # pick up the training labels as its sensitive attribute.
    X_train, S_train, Y_train = _mnistSplit(train_raw, transform, color)
    X_test, S_test, Y_test = _mnistSplit(test_raw, transform, color)

    train_data = VectorDataset(X_train, S_train, Y_train)
    test_data = VectorDataset(X_test, S_test, Y_test)

    return train_data, test_data

# Identifiers accepted by load_dataset, written as '<dataset>-<protected attribute>'.
all_datasets = [
    'Adult-sex',
    'Adult-race',
    'Compas-sex',
    'Compas-race',
]

def load_dataset(dataset):
    """Load one of the tabular datasets listed in ``all_datasets``.

    Args:
        dataset: One of ``'Adult-sex'``, ``'Adult-race'``, ``'Compas-sex'`` or
            ``'Compas-race'``.

    Returns:
        ``(train_data, test_data, D)`` as produced by the matching loader.
    """
    assert dataset in all_datasets
    # Map each identifier to its loader and protected attribute.
    dispatch = {
        'Adult-sex': (loadAdult, 'sex'),
        'Adult-race': (loadAdult, 'race'),
        'Compas-sex': (loadCompas, 'sex'),
        'Compas-race': (loadCompas, 'race'),
    }
    loader, attribute = dispatch[dataset]
    train_data, test_data, D = loader(pro_att=attribute)
    return train_data, test_data, D