__author__ = ''
__date__ = '2023/06/30'

'''

data preprocess for FairAD

'''

import os
import sys
from tqdm import tqdm
# import cv2
import random
from os import path as osp
import pandas as pd
from pandas import DataFrame, Series
import numpy as np
import torch
import pickle
from collections import Counter
from sklearn.preprocessing import MinMaxScaler, StandardScaler

sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))


# self-defined libraries
from configs import ROOT_DIR, DATA_DIR
from utils.tools import obj_save, obj_load, new_dir




def adult_preprocessing(data_path='data/adult.data', n_train=10000, n_test=2000, seed=42):
    """Clean and numerically encode the UCI Adult data set.

    Rows containing the missing-value marker ``' ?'`` are dropped, the income
    label ``y`` becomes 1 (``' >50K'``) / 0 (``' <=50K'``), ``sex`` becomes
    1 (``' Female'``) / -1 (``' Male'``) and the remaining categorical columns
    are replaced by 1-based integer codes.  Codes are assigned in sorted order
    of the category names so that repeated runs produce the same encoding.
    Summary statistics are printed along the way.

    Args:
        data_path: Path of the header-less, comma-separated Adult data file.
        n_train: Unused; kept so existing callers keep working.
        n_test: Unused; kept so existing callers keep working.
        seed: Unused; kept so existing callers keep working.

    Returns:
        The cleaned ``DataFrame`` with ``sex`` as first and ``y`` as last
        column.  Rows keep their index from the raw file.
    """

    def _recode(column, mapping):
        # Replace the whole column at once; values without a mapping are kept.
        return column.map(lambda value: mapping.get(value, value))

    column_names = ['age', 'workclass', 'fnlwgt', 'education',
                    'education-num', 'marital-status', 'occupation',
                    'relationship', 'race', 'sex', 'capital-gain',
                    'capital-loss',
                    'hours-per-week', 'native-country', 'y']

    # read data
    df_data = pd.read_csv(data_path, header=None, names=column_names)

    # data size
    print(f'instances: {df_data.shape[0]}')  # 32561
    print(f'columns: {df_data.shape[1]}')  # 15

    # remove the instances with a missing value (one vectorised pass instead
    # of concatenating the frame row by row)
    has_missing = (df_data.astype(str) == ' ?').any(axis=1)
    num_has_missing_instance = int(has_missing.sum())
    df_data_clean = df_data.loc[~has_missing].copy()

    # non missing value data size
    print(num_has_missing_instance)
    print(f'clean instnace: {len(df_data_clean)}')  # 30162
    print(f'clean data: {df_data_clean.shape}')

    df_data_clean['y'] = _recode(df_data_clean['y'], {' >50K': 1, ' <=50K': 0})
    df_data_clean['sex'] = _recode(df_data_clean['sex'], {' Female': 1, ' Male': -1})

    print(f'a instance before preprocess: {df_data_clean.values[0]}')
    for name in ['workclass', 'marital-status', 'occupation', 'education',
                 'relationship', 'race', 'native-country']:
        # Sorting makes the codes independent of Python's per-process string
        # hashing, which the former list(set(...)) ordering depended on.
        categories = sorted(df_data_clean[name].dropna().unique(), key=str)
        codes = {value: code for code, value in enumerate(categories, start=1)}
        df_data_clean[name] = _recode(df_data_clean[name], codes)

    # sensitive attribute first, label last
    df_data_clean = df_data_clean[['sex', 'age', 'native-country', 'race', 'workclass',
                                   'fnlwgt', 'education', 'education-num',
                                   'marital-status', 'occupation',
                                   'relationship', 'capital-gain', 'capital-loss',
                                   'hours-per-week', 'y']]
    print(f'a instance after preprocess: {df_data_clean.values[0]}')

    # info() prints by itself and returns None, so it is not wrapped in print()
    df_data_clean.info()
    # obj_save(osp.join(ROOT_DIR, 'data/Adult/adult_clean.pkl'), df_data_clean)

    normal_data = df_data_clean.loc[df_data_clean['y'] == 0]
    abnormal_data = df_data_clean.loc[df_data_clean['y'] == 1]

    print(f"normal male: {normal_data.loc[normal_data['sex'] == -1].shape}")
    print(f"normal female: {normal_data.loc[normal_data['sex'] == 1].shape}")
    print(f"abnormal male: {abnormal_data.loc[abnormal_data['sex'] == -1].shape}")
    print(f"abnormal female: {abnormal_data.loc[abnormal_data['sex'] == 1].shape}")

    return df_data_clean


def compas_preprocessing(data_path, n_train=2000, n_test=2000, seed=42):
    """Clean and numerically encode the COMPAS recidivism data.

    Only rows whose ``race`` is ``'African-American'`` or ``'Caucasian'`` are
    kept, together with nine columns (demographics, criminal history, decile
    score and the ``two_year_recid`` ground truth, which is renamed to ``y``).
    ``race`` is encoded as -1 (African-American) / 1 (Caucasian) and ``sex``
    as -1 (Male) / 1 (Female).  Summary statistics are printed.

    Args:
        data_path: Path of the COMPAS CSV file (with header row).
        n_train: Unused; kept so existing callers keep working.
        n_test: Unused; kept so existing callers keep working.
        seed: Unused; kept so existing callers keep working.

    Returns:
        The cleaned ``DataFrame`` with a fresh 0-based index.
    """

    def _recode(column, mapping):
        # Replace the whole column at once; values without a mapping are kept.
        return column.map(lambda value: mapping.get(value, value))

    print('# COMPAS preprocessing ======================================================================')
    df = pd.read_csv(data_path)
    print(f'Original data: {df.shape}')  # (6173, 53)

    selected_columns = ['race', 'sex', 'age',  # demographics
                        'juv_fel_count', 'juv_misd_count', 'juv_other_count', 'priors_count',  # criminal history
                        'decile_score',  # recidivism score
                        'two_year_recid',  # ground truth
                        ]
    # binary classification w.r.t. race; copy() so the edits below do not
    # write into a slice of the raw frame
    in_scope = df['race'].isin(['African-American', 'Caucasian'])
    df_sel = df.loc[in_scope, selected_columns].copy()

    # These categorical columns are not part of the selection above, so
    # indexing them unconditionally raised KeyError.  They are only encoded
    # when someone adds them to ``selected_columns``.
    for name in ['age_cat', 'c_charge_degree', 'c_charge_desc']:
        if name not in df_sel.columns:
            continue
        categories = sorted(df_sel[name].dropna().unique(), key=str)
        codes = {value: code for code, value in enumerate(categories, start=1)}
        df_sel[name] = _recode(df_sel[name], codes)

    df_sel['race'] = _recode(df_sel['race'], {'African-American': -1, 'Caucasian': 1})
    df_sel['sex'] = _recode(df_sel['sex'], {'Male': -1, 'Female': 1})
    df_sel = df_sel.rename(columns={'two_year_recid': 'y'})
    df_sel = df_sel.reset_index(drop=True)

    print(f'Clean data: {df_sel.shape}')  # (5278, 9)
    print(f'A instance: {df_sel.values[0:10]}')

    # info() prints by itself and returns None, so it is not wrapped in print()
    df_sel.info()
    # obj_save(osp.join(ROOT_DIR, 'data/Compas/compas_clean.pkl'), df_sel)
    normal_data = df_sel.loc[df_sel['y'] == 0]
    abnormal_data = df_sel.loc[df_sel['y'] == 1]
    print(type(normal_data))

    print(f"normal AA: {normal_data.loc[normal_data['race'] == -1].shape}")
    print(f"normal C: {normal_data.loc[normal_data['race'] == 1].shape}")
    print(f"abnormal AA: {abnormal_data.loc[abnormal_data['race'] == -1].shape}")
    # race == 1 is Caucasian; this line used to be mislabelled "abnormal AA"
    print(f"abnormal C: {abnormal_data.loc[abnormal_data['race'] == 1].shape}")

    return df_sel


def split_train_test(data_path, n_train_normal, data_name=''):
    """Split a cleaned data set into a normal-only training set and a test set.

    The object loaded from ``data_path`` is indexed like a ``DataFrame`` with
    a ``y`` column (0 = normal, 1 = abnormal).  ``n_train_normal`` normal rows
    are drawn at random for training; all remaining normal rows plus every
    abnormal row form the test set.  The last column is dropped from the
    features (presumably the ``y`` label - confirm against the saved frame).

    Args:
        data_path: Path handed to ``obj_load``.
        n_train_normal: Number of normal rows used for training.
        data_name: Only referenced by the commented-out saving code below.

    Returns:
        ``(train_data, train_label, test_data, test_label)`` as numpy arrays.
        Earlier versions returned ``None``.

    Raises:
        ValueError: From ``random.sample`` if ``n_train_normal`` exceeds the
            number of normal rows.
    """
    data = obj_load(data_path)
    print(f'data: {data.shape}')

    normal_data = data.loc[data['y'] == 0].values
    abnormal_data = data.loc[data['y'] == 1].values
    print(f'normal data: {type(normal_data),normal_data.shape}')
    print(f'abnormal data: {type(abnormal_data), abnormal_data.shape}')

    # sampling
    n_normal = normal_data.shape[0]
    idx_train_data = random.sample(range(n_normal), k=n_train_normal)
    # set lookup keeps the complement computation linear in the data size
    train_lookup = set(idx_train_data)
    idx_test_normal_data = [i for i in range(n_normal) if i not in train_lookup]

    # training data
    train_data = normal_data[idx_train_data][:, :-1]
    train_label = np.zeros(len(train_data))

    # testing data
    test_normal_data = normal_data[idx_test_normal_data][:, :-1]
    test_abnormal_data = abnormal_data[:, :-1]

    test_data = np.concatenate((test_normal_data, test_abnormal_data))
    test_label = np.concatenate((np.zeros(len(test_normal_data)), np.ones(len(test_abnormal_data))))

    print(f'train data: {type(train_data), train_data.shape}')
    print(f'train label: {train_label.shape}')
    print(f'test data: {type(test_data), test_data.shape}')
    print(f'test label: {test_label.shape}')

    # # saving data
    # print('saving data ...')
    # obj_save(osp.join(ROOT_DIR, f'data/{data_name}/processed/train_data.npy'), train_data)
    # obj_save(osp.join(ROOT_DIR, f'data/{data_name}/processed/train_label.npy'), train_label)
    # obj_save(osp.join(ROOT_DIR, f'data/{data_name}/processed/test_data.npy'), test_data)
    # obj_save(osp.join(ROOT_DIR, f'data/{data_name}/processed/test_label.npy'), test_label)
    # print('save data successfully!')

    return train_data, train_label, test_data, test_label


def celebA_preprocessing(height=64, width=64):
    """Crop every CelebA image to its annotated face box and resize it.

    Reads ``list_bbox_celeba.txt`` (two header lines, then one line per image
    holding the file name followed by ``x y box_width box_height``), crops
    the corresponding image from ``img_celeba/`` and writes the resized crop
    under the same file name into ``img_celeba_64``.  Any previous content of
    the output folder is removed first.  A failure on a single image is
    reported and the run continues.

    Args:
        height: Height of the written images in pixels.
        width: Width of the written images in pixels.

    Raises:
        SystemExit: With code -1 if the image folder or the bounding-box file
            does not exist.
    """
    img_dir = ''
    # image path
    img_path = osp.join(img_dir, 'img_celeba/')

    # new image save path
    new_img_path = osp.join(img_dir, 'img_celeba_64')

    # face bbox file path
    face_boundingbox_anno_file_path = osp.join(img_dir, 'list_bbox_celeba.txt')

    # the height and width of new image
    new_h = height
    new_w = width

    if not os.path.exists(img_path):
        print("image path not exist.")
        sys.exit(-1)

    if not os.path.exists(face_boundingbox_anno_file_path):
        print("face_boundingbox_anno_file not exist.")
        sys.exit(-1)

    # Imported here because the module-level ``import cv2`` is commented out;
    # without it every image failed with a NameError that the former bare
    # ``except`` silently swallowed.
    import shutil
    import cv2

    if not os.path.exists(new_img_path):
        os.makedirs(new_img_path)
    else:
        # Empty the folder without spawning a shell.  Unlike ``rm -rf dir/*``
        # this also removes hidden entries.
        for entry in os.scandir(new_img_path):
            if entry.is_dir(follow_symlinks=False):
                shutil.rmtree(entry.path)
            else:
                os.remove(entry.path)

    # loading file
    with open(face_boundingbox_anno_file_path, 'r') as anno_file:
        face_bbox = anno_file.readlines()

    # the first two lines are headers
    for line in tqdm(face_bbox[2:]):
        fields = line.split()
        if not fields:
            # tolerate blank lines, e.g. at the end of the file
            continue
        filename = fields[0]

        path = os.path.join(img_path, filename)
        new_path = os.path.join(new_img_path, filename)
        if not os.path.exists(path):
            print(path, 'not exist')
            continue

        try:
            x, y, box_w, box_h = (int(value) for value in fields[1:5])
            img = cv2.imread(path)
            if img is None:
                raise ValueError('image could not be decoded')

            # crop the image (rows are y, columns are x)
            cropped = img[y:y + box_h, x:x + box_w]
            # cv2.resize expects dsize as (width, height)
            resized = cv2.resize(cropped, (new_w, new_h))
            # store the new image
            cv2.imwrite(new_path, resized)
        except Exception as err:
            # best effort per image; KeyboardInterrupt/SystemExit still propagate
            print("filename:%s process failed: %s" % (filename, err))


def split_train_test_4_celeba():
    """Partition the CelebA attribute list by the ``Attractive`` flag.

    Reads ``list_attr_celeba.txt``.  Images marked ``Attractive == 1`` are
    treated as normal and those marked ``-1`` as abnormal.  Each partition is
    a dict mapping the image file name to its ``Male`` attribute as an int.
    Both dicts are written with ``obj_save`` below ``ROOT_DIR/data/CelebA``
    and group counts are printed.
    """
    attr_file = osp.join('', 'list_attr_celeba.txt')
    with open(attr_file, 'r', encoding='utf8') as handle:
        attr_celeb = handle.readlines()

    # The second line lists the attribute names.  Data rows start with the
    # file name, which shifts every attribute one position to the right.
    header = attr_celeb[1].split()
    attractive_pos = header.index('Attractive') + 1
    male_pos = header.index('Male') + 1

    # keyed by the Attractive mark; tallies are [male, female]
    partitions = {'1': {}, '-1': {}}
    tallies = {'1': [0, 0], '-1': [0, 0]}

    for row in attr_celeb[2:]:
        attr = row.split()
        if attr[0] == '101283.jpg':
            # this file is deliberately left out
            continue
        mark = attr[attractive_pos]
        if mark not in partitions:
            print(f'Unknown mark in Attractive term in {attr[0]}')
            continue
        partitions[mark].setdefault(attr[0], int(attr[male_pos]))
        tallies[mark][0 if attr[male_pos] == '1' else 1] += 1

    normal_data = partitions['1']
    abnormal_data = partitions['-1']
    normal_male, normal_female = tallies['1']
    abnormal_male, abnormal_female = tallies['-1']

    print(f'normal male: {normal_male}')
    print(f'normal female: {normal_female}')
    print(f'abnormal male: {abnormal_male}')
    print(f'abnormal female: {abnormal_female}')

    first_normal = list(normal_data)[0]
    first_abnormal = list(abnormal_data)[0]
    print(f'normal_data: {len(normal_data)}')
    print(f'a normal sampel: {first_normal} {normal_data[first_normal]}')
    print(f'abnormal_data: {len(abnormal_data)}')
    print(f'a abnormal sample: {first_abnormal} {abnormal_data[first_abnormal]}')
    print(f'all data: {len(normal_data) + len(abnormal_data)}')

    for file_name, payload in (('data/CelebA/normal_data.pkl', normal_data),
                               ('data/CelebA/abnormal_data.pkl', abnormal_data)):
        obj_save(osp.join(ROOT_DIR, file_name), payload)


def _held_out_names(pool, sampled, limit):
    """Return the first ``limit`` entries of ``pool`` that are not in ``sampled``."""
    taken = set(sampled)
    return [name for name in pool if name not in taken][:limit]


def _load_flat_images(img_dir, file_names, group_value):
    """Read images, flatten each into one row and append ``group_value`` as last column."""
    # Imported here because the module-level ``import cv2`` is commented out.
    import cv2

    rows = []
    for file_name in file_names:
        img_path = osp.join(img_dir, file_name)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising
            raise FileNotFoundError(f'cannot read image: {img_path}')
        rows.append(img.reshape(-1))
    if not rows:
        raise ValueError('no image files selected for this group')
    # one stack instead of growing the matrix with a concatenate per image
    pixels = np.stack(rows)
    return np.append(pixels, group_value * np.ones((len(pixels), 1)), axis=1)


def split_balanced_imbalanced_4_celeba(n_train_normal_0, n_test_normal_0, ratio=4):
    """Build a male/female (im)balanced CelebA train/test split and save it.

    File names are loaded from the ``normal_data.pkl`` / ``abnormal_data.pkl``
    dicts, in which the value 1 marks a male image.  Training uses
    ``n_train_normal_0`` normal male and ``int(n_train_normal_0 / ratio)``
    normal female images.  The test set holds up to ``n_test_normal_0``
    (male) and ``int(n_test_normal_0 / ratio)`` (female) unused normal images
    plus the same number of randomly drawn abnormal images per group.

    Every image is flattened to one row and the group is appended as last
    column (1 = male, -1 = female).  Labels are 0 for normal and 1 for
    abnormal images.  The four arrays are written with ``obj_save``.

    Args:
        n_train_normal_0: Number of normal male training images.
        n_test_normal_0: Number of normal male test images.
        ratio: Male-to-female ratio; 1 selects the ``balanced`` output folder,
            anything else ``imbalanced``.

    Raises:
        ValueError: If a group ends up without any image, or from
            ``random.sample`` if a group has fewer images than requested.
        FileNotFoundError: If an image cannot be read.
    """
    path_normal_file_name = osp.join(ROOT_DIR, 'data/CelebA/normal_data.pkl')
    path_abnormal_file_name = osp.join(ROOT_DIR, 'data/CelebA/abnormal_data.pkl')
    img_dir = ''

    normal_file_name = obj_load(path_normal_file_name)
    abnormal_file_name = obj_load(path_abnormal_file_name)

    print(f'normal data: {len(normal_file_name)}')
    print(f'normal data: {set(normal_file_name.values())}')
    print(f'abnormal data: {len(abnormal_file_name)}')
    print(f'abnormal data: {set(abnormal_file_name.values())}')

    normal_male_file_name = [key for key, value in normal_file_name.items() if value == 1]
    normal_female_file_name = [key for key, value in normal_file_name.items() if value != 1]
    abnormal_male_file_name = [key for key, value in abnormal_file_name.items() if value == 1]
    abnormal_female_file_name = [key for key, value in abnormal_file_name.items() if value != 1]

    # The order of the random.sample calls is kept so that a seeded run
    # selects the same files as before.
    n_train_normal_male = n_train_normal_0
    n_test_normal_male = n_test_normal_0
    file_train_normal_male = random.sample(population=normal_male_file_name, k=n_train_normal_male)
    file_test_normal_male = _held_out_names(normal_male_file_name, file_train_normal_male, n_test_normal_male)

    n_train_normal_female = int(n_train_normal_male / ratio)
    n_test_normal_female = int(n_test_normal_male / ratio)
    file_train_normal_female = random.sample(population=normal_female_file_name, k=n_train_normal_female)
    file_test_normal_female = _held_out_names(normal_female_file_name, file_train_normal_female, n_test_normal_female)

    # read image
    train_normal_male_data = _load_flat_images(img_dir, file_train_normal_male, 1)
    print(f'train normal male : {train_normal_male_data.shape}')
    train_normal_female_data = _load_flat_images(img_dir, file_train_normal_female, -1)
    print(f'train normal female : {train_normal_female_data.shape}')
    train_data = np.concatenate((train_normal_male_data, train_normal_female_data))
    print(f'train data: {train_data.shape}')
    train_label = np.zeros(len(train_data))

    file_test_abnormal_male = random.sample(population=abnormal_male_file_name, k=n_test_normal_male)
    file_test_abnormal_female = random.sample(population=abnormal_female_file_name, k=n_test_normal_female)
    # use as many abnormal images as normal ones were actually available
    file_test_abnormal_male = file_test_abnormal_male[:len(file_test_normal_male)]
    file_test_abnormal_female = file_test_abnormal_female[:len(file_test_normal_female)]

    test_normal_male_data = _load_flat_images(img_dir, file_test_normal_male, 1)
    test_abnormal_male_data = _load_flat_images(img_dir, file_test_abnormal_male, 1)
    test_normal_female_data = _load_flat_images(img_dir, file_test_normal_female, -1)
    test_abnormal_female_data = _load_flat_images(img_dir, file_test_abnormal_female, -1)

    test_normal_data = np.concatenate((test_normal_male_data, test_normal_female_data))
    test_abnormal_data = np.concatenate((test_abnormal_male_data, test_abnormal_female_data))
    test_data = np.concatenate((test_normal_data, test_abnormal_data))
    test_label = np.concatenate((np.zeros(len(test_normal_data)), np.ones(len(test_abnormal_data))))

    print(f'test data: {test_data.shape}')
    print(f'test normal data: {test_normal_data.shape}')
    print(f'male: female: {len(test_normal_male_data)} : {len(test_normal_female_data)}')
    print(f'test abnormal data: {test_abnormal_data.shape}')
    print(f'male: female: {len(test_abnormal_male_data)} : {len(test_abnormal_female_data)}')

    # save data
    print('save data ...')
    if ratio == 1:
        _dir = 'balanced'
    else:
        _dir = 'imbalanced'

    # NOTE(review): the first path is relative while the other three are
    # absolute; they look like placeholders - confirm the intended folder.
    obj_save('target dir/%s/train_data.npy' % _dir, train_data)
    obj_save('/target dir/%s/train_label.npy' % _dir, train_label)
    obj_save('/target dir/%s/test_data.npy' % _dir, test_data)
    obj_save('/target dir/%s/test_label.npy' % _dir, test_label)
    print('save data successfully!')


def split_train_test_4_adult(n_train_normal_0, n_test_normal_0, ratio=4, ratio_abnormal=0.0):
    """Build a sex-(im)balanced Adult train/test split and save it.

    Loads ``data/Adult/adult_clean.pkl`` and separates it by ``y``
    (0 = normal, 1 = abnormal) and ``sex`` (-1 = male, 1 = female).  Training
    uses ``n_train_normal_0`` normal male and ``int(n_train_normal_0 / ratio)``
    normal female rows.  The test set holds up to ``n_test_normal_0`` (male)
    and ``int(n_test_normal_0 / ratio)`` (female) unused normal rows plus the
    same number of randomly drawn abnormal rows per group.  The last column
    is dropped from the features.

    With ``ratio_abnormal > 0`` additional abnormal rows, disjoint from the
    test rows, are appended to the training data; the training labels stay
    all zero.

    Args:
        n_train_normal_0: Number of normal male training rows.
        n_test_normal_0: Number of normal male test rows.
        ratio: Male-to-female ratio; 1 selects the ``balanced`` output folder,
            anything else ``imbalanced``.
        ratio_abnormal: Contamination, as a fraction of the clean training
            set size (split evenly between both groups).

    Raises:
        AssertionError: If there are fewer normal male rows than requested.
        ValueError: From ``random.sample`` if another group is too small.
    """
    data_dir = osp.join(ROOT_DIR, 'data')
    data_path = osp.join(data_dir, 'Adult/adult_clean.pkl')

    data = obj_load(data_path)
    print(f'data: {data.shape}')

    normal_data = data.loc[data['y'] == 0]
    abnormal_data = data.loc[data['y'] == 1]

    print(f'normal data: {normal_data.shape}')
    print(f'abnormal data: {abnormal_data.shape}')

    normal_male_data = normal_data.loc[normal_data['sex'] == -1].values[:, :-1]
    normal_female_data = normal_data.loc[normal_data['sex'] == 1].values[:, :-1]
    abnormal_male_data = abnormal_data.loc[abnormal_data['sex'] == -1].values[:, :-1]
    abnormal_female_data = abnormal_data.loc[abnormal_data['sex'] == 1].values[:, :-1]
    print(f'normal male data: {normal_male_data.shape}')
    print(f'normal female data: {normal_female_data.shape}')
    print(f'abnormal male data: {abnormal_male_data.shape}')
    print(f'abnormal female data: {abnormal_female_data.shape}')

    # sampling ==================================================
    n_train_normal_male = n_train_normal_0
    n_test_normal_male = n_test_normal_0
    assert len(normal_male_data) >= (n_train_normal_male + n_test_normal_male)

    n_male = normal_male_data.shape[0]
    idx_train_normal_male = random.sample(range(n_male), k=n_train_normal_male)
    # set lookup keeps this linear; slicing also yields an empty test group
    # for a quota of 0 instead of every remaining row
    train_male_lookup = set(idx_train_normal_male)
    idx_test_normal_male = [i for i in range(n_male) if i not in train_male_lookup][:n_test_normal_male]

    train_normal_male_data = normal_male_data[idx_train_normal_male]
    test_normal_male_data = normal_male_data[idx_test_normal_male]

    n_train_normal_female = int(n_train_normal_male / ratio)
    n_test_normal_female = int(n_test_normal_male / ratio)
    n_female = normal_female_data.shape[0]
    idx_train_normal_female = random.sample(range(n_female), k=n_train_normal_female)
    train_female_lookup = set(idx_train_normal_female)
    idx_test_normal_female = [i for i in range(n_female) if i not in train_female_lookup][:n_test_normal_female]

    train_normal_female_data = normal_female_data[idx_train_normal_female]
    test_normal_female_data = normal_female_data[idx_test_normal_female]

    train_data = np.concatenate((train_normal_male_data, train_normal_female_data))
    train_label = np.zeros(len(train_data))
    test_normal_data = np.concatenate((test_normal_male_data, test_normal_female_data))
    print(f'train data: {train_data.shape}')
    print(f'male: female: {len(train_normal_male_data)}: {len(train_normal_female_data)}')
    print(f'test normal data: {test_normal_data.shape}')
    print(f'male: female: {len(test_normal_male_data)}: {len(test_normal_female_data)}')

    n_test_male = len(test_normal_male_data)
    n_test_female = len(test_normal_female_data)
    if ratio_abnormal > 0.0:
        print(f'ratio of contamination: [{ratio_abnormal}]')
        num_contamination_abnormal_sample = int(ratio_abnormal * len(train_data))
        per_group = int(num_contamination_abnormal_sample / 2)
        idx_test_abnormal_male = random.sample(range(abnormal_male_data.shape[0]), k=n_test_male + per_group)
        idx_test_abnormal_female = random.sample(range(abnormal_female_data.shape[0]), k=n_test_female + per_group)

        # the first part of each draw is tested on, the rest contaminates training
        drawn_male = abnormal_male_data[idx_test_abnormal_male]
        drawn_female = abnormal_female_data[idx_test_abnormal_female]
        test_abnormal_male_data = drawn_male[:n_test_male]
        test_abnormal_female_data = drawn_female[:n_test_female]
        contamination_data = np.concatenate((drawn_male[n_test_male:], drawn_female[n_test_female:]))
        train_data = np.concatenate((train_data, contamination_data))
        train_label = np.zeros(len(train_data))
        print(f'contamination data: {contamination_data.shape}')
        print(f'train data: {train_data.shape}')
    else:
        idx_test_abnormal_male = random.sample(range(abnormal_male_data.shape[0]), k=n_test_male)
        idx_test_abnormal_female = random.sample(range(abnormal_female_data.shape[0]), k=n_test_female)

        test_abnormal_male_data = abnormal_male_data[idx_test_abnormal_male]
        test_abnormal_female_data = abnormal_female_data[idx_test_abnormal_female]
        # No contamination requested: keep an empty placeholder so the summary
        # below works (it used to fail because the name was undefined here).
        contamination_data = train_data[:0]

    test_abnormal_data = np.concatenate((test_abnormal_male_data, test_abnormal_female_data))
    test_data = np.concatenate((test_normal_data, test_abnormal_data))
    test_label = np.concatenate((np.zeros(len(test_normal_data)), np.ones(len(test_abnormal_data))))

    print(f'test abnormal data: {test_abnormal_data.shape}')
    print(f'abnormal male: female: {test_abnormal_male_data.shape[0]}:{test_abnormal_female_data.shape[0]}')

    # saving contaminated data
    print('saving data ...')
    print(f'train_data: {train_data.shape}')
    print(f'train_label: {train_label.shape}')
    print(f'test data: {test_data.shape}')
    print(f'test label: {test_label.shape}')
    print(f'contamination data: {contamination_data.shape}')
    data_name = 'Adult'
    if ratio == 1:
        save_dir = new_dir(ROOT_DIR, f'data/{data_name}/processed/balanced/contaminated/{str(ratio_abnormal)}')
    else:
        save_dir = new_dir(ROOT_DIR, f'data/{data_name}/processed/imbalanced/contaminated/{str(ratio_abnormal)}')
    obj_save(osp.join(save_dir, 'train_data.npy'), train_data)
    obj_save(osp.join(save_dir, 'train_label.npy'), train_label)
    obj_save(osp.join(save_dir, 'test_data.npy'), test_data)
    obj_save(osp.join(save_dir, 'test_label.npy'), test_label)
    print('save data successfully!')


def split_train_test_4_compas(n_train_normal_0, n_test_normal_0, ratio=4, ratio_abnormal=0.01):
    """Build a race-(im)balanced COMPAS train/test split and save it.

    Loads ``data/Compas/compas_clean.pkl`` and separates it by ``y``
    (0 = normal, 1 = abnormal) and ``race`` (-1 and 1; the local names call
    them black and white).  Training uses ``n_train_normal_0`` normal rows of
    the -1 group and ``int(n_train_normal_0 / ratio)`` of the 1 group.  The
    test set holds up to ``n_test_normal_0`` / ``int(n_test_normal_0 / ratio)``
    unused normal rows plus the same number of randomly drawn abnormal rows
    per group.  The last column is dropped from the features.

    With ``ratio_abnormal > 0`` additional abnormal rows, disjoint from the
    test rows, are appended to the training data; the training labels stay
    all zero.

    Args:
        n_train_normal_0: Number of normal training rows of the -1 group.
        n_test_normal_0: Number of normal test rows of the -1 group.
        ratio: Ratio between the -1 and the 1 group; 1 selects the
            ``balanced`` output folder, anything else ``imbalanced``.
        ratio_abnormal: Contamination, as a fraction of the clean training
            set size (split evenly between both groups).

    Raises:
        AssertionError: If the -1 group has fewer normal rows than requested.
        ValueError: From ``random.sample`` if another group is too small.
    """
    data_dir = osp.join(ROOT_DIR, 'data')
    data_path = osp.join(data_dir, 'Compas/compas_clean.pkl')

    data = obj_load(data_path)
    print(f'data: {data.shape}')

    normal_data = data.loc[data['y'] == 0]
    abnormal_data = data.loc[data['y'] == 1]

    print(f'normal data: {normal_data.shape}')
    print(f'abnormal data: {abnormal_data.shape}')

    normal_white_data = normal_data.loc[normal_data['race'] == 1].values[:, :-1]
    normal_black_data = normal_data.loc[normal_data['race'] == -1].values[:, :-1]
    abnormal_white_data = abnormal_data.loc[abnormal_data['race'] == 1].values[:, :-1]
    abnormal_black_data = abnormal_data.loc[abnormal_data['race'] == -1].values[:, :-1]

    print(f'normal white data: {normal_white_data.shape}')
    print(f'normal black data: {normal_black_data.shape}')
    print(f'abnormal white data: {abnormal_white_data.shape}')
    print(f'abnormal black data: {abnormal_black_data.shape}')

    # sampling ==================================================
    n_train_normal_black = n_train_normal_0
    n_test_normal_black = n_test_normal_0
    assert len(normal_black_data) >= (n_train_normal_black + n_test_normal_black)

    n_black = normal_black_data.shape[0]
    idx_train_normal_black = random.sample(range(n_black), k=n_train_normal_black)
    # set lookup keeps this linear; slicing also yields an empty test group
    # for a quota of 0 instead of every remaining row
    train_black_lookup = set(idx_train_normal_black)
    idx_test_normal_black = [i for i in range(n_black) if i not in train_black_lookup][:n_test_normal_black]

    train_normal_black_data = normal_black_data[idx_train_normal_black]
    test_normal_black_data = normal_black_data[idx_test_normal_black]

    n_train_normal_white = int(n_train_normal_black / ratio)
    n_test_normal_white = int(n_test_normal_black / ratio)
    n_white = normal_white_data.shape[0]
    idx_train_normal_white = random.sample(range(n_white), k=n_train_normal_white)
    train_white_lookup = set(idx_train_normal_white)
    idx_test_normal_white = [i for i in range(n_white) if i not in train_white_lookup][:n_test_normal_white]

    train_normal_white_data = normal_white_data[idx_train_normal_white]
    test_normal_white_data = normal_white_data[idx_test_normal_white]

    train_data = np.concatenate((train_normal_black_data, train_normal_white_data))
    train_label = np.zeros(len(train_data))
    test_normal_data = np.concatenate((test_normal_black_data, test_normal_white_data))
    print(f'train data: {train_data.shape}')
    print(f'white: black: {len(train_normal_white_data)}: {len(train_normal_black_data)}')
    print(f'test normal data: {test_normal_data.shape}')
    print(f'white: black: {len(test_normal_white_data)}: {len(test_normal_black_data)}')

    n_test_white = len(test_normal_white_data)
    n_test_black = len(test_normal_black_data)
    if ratio_abnormal > 0.0:
        print(f'ratio of contamination: [{ratio_abnormal}]')
        num_contamination_abnormal_sample = int(ratio_abnormal * len(train_data))
        per_group = int(num_contamination_abnormal_sample / 2)
        idx_test_abnormal_white = random.sample(range(abnormal_white_data.shape[0]), k=n_test_white + per_group)
        idx_test_abnormal_black = random.sample(range(abnormal_black_data.shape[0]), k=n_test_black + per_group)

        # the first part of each draw is tested on, the rest contaminates training
        drawn_white = abnormal_white_data[idx_test_abnormal_white]
        drawn_black = abnormal_black_data[idx_test_abnormal_black]
        test_abnormal_white_data = drawn_white[:n_test_white]
        test_abnormal_black_data = drawn_black[:n_test_black]

        contamination_data = np.concatenate((drawn_white[n_test_white:], drawn_black[n_test_black:]))

        train_data = np.concatenate((train_data, contamination_data))
        train_label = np.zeros(len(train_data))
    else:
        idx_test_abnormal_white = random.sample(range(abnormal_white_data.shape[0]), k=n_test_white)
        idx_test_abnormal_black = random.sample(range(abnormal_black_data.shape[0]), k=n_test_black)

        test_abnormal_white_data = abnormal_white_data[idx_test_abnormal_white]
        test_abnormal_black_data = abnormal_black_data[idx_test_abnormal_black]
        # No contamination requested: keep an empty placeholder so the summary
        # below works (it used to fail because the name was undefined here).
        contamination_data = train_data[:0]

    test_abnormal_data = np.concatenate((test_abnormal_black_data, test_abnormal_white_data))
    test_data = np.concatenate((test_normal_data, test_abnormal_data))
    test_label = np.concatenate((np.zeros(len(test_normal_data)), np.ones(len(test_abnormal_data))))

    print(f'test abnormal data: {test_abnormal_data.shape}')
    print(f'abnormal white: black: {test_abnormal_white_data.shape[0]}:{test_abnormal_black_data.shape[0]}')

    # saving contaminated data
    print('saving data ...')
    print(f'train_data: {train_data.shape}')
    print(f'train_label: {train_label.shape}')
    print(f'test data: {test_data.shape}')
    print(f'test label: {test_label.shape}')
    print(f'contamination data: {contamination_data.shape}')
    data_name = 'Compas'
    if ratio == 1:
        save_dir = new_dir(ROOT_DIR, f'data/{data_name}/processed/balanced/contaminated/{str(ratio_abnormal)}')
    else:
        save_dir = new_dir(ROOT_DIR, f'data/{data_name}/processed/imbalanced/contaminated/{str(ratio_abnormal)}')
    obj_save(osp.join(save_dir, 'train_data.npy'), train_data)
    obj_save(osp.join(save_dir, 'train_label.npy'), train_label)
    obj_save(osp.join(save_dir, 'test_data.npy'), test_data)
    obj_save(osp.join(save_dir, 'test_label.npy'), test_label)
    print('save data successfully!')


def global_contrast_normalization(x: torch.tensor, scale='l2'):
    """
    Centre a tensor on its mean and rescale it, modifying ``x`` in place.

    The mean is taken over every element of ``x``, so the caller passes one
    sample at a time. The divisor depends on ``scale``:
      'l1': mean absolute value of the centred tensor
      'l2': L2 norm of the centred tensor divided by its number of elements

    Returns the same tensor object that was passed in. Any other ``scale``
    fails the assertion.
    """

    assert scale in ('l1', 'l2')

    # centre in place
    x.sub_(torch.mean(x))

    if scale == 'l1':
        divisor = x.abs().mean()
    elif scale == 'l2':
        divisor = x.pow(2).sum().sqrt() / x.numel()

    # rescale in place
    x.div_(divisor)
    return x


def titanic_process(n_test_per_group=30):
    """
    Build the Titanic anomaly-detection split from ``Titanic/train.csv`` under DATA_DIR.

    Steps:
      1. drop the 'Ticket', 'Cabin', 'PassengerId' and 'Name' columns and every
         row that still has a missing value;
      2. reorder the columns so that 'Sex' comes first and 'Survived' (the
         label) comes last;
      3. encode 'Sex' as female = -1 / male = 1 and 'Embarked' as 1..k, with the
         codes assigned in sorted category order so that they are identical on
         every run;
      4. treat ``Survived == 0`` as normal and ``Survived == 1`` as abnormal data.

    The first ``n_test_per_group`` rows of each (sex, label) group go to the test
    set; the remaining normal rows form the training set. The remaining abnormal
    rows are not used. Features are converted to float32.

    Only diagnostics are printed; the ``obj_save`` calls at the end are
    commented out, so nothing is written to disk and nothing is returned.

    :param n_test_per_group: number of rows taken from each (sex, label) group
        for the test set (default 30).
    """
    train_data_path = osp.join(DATA_DIR, 'Titanic/train.csv')
    df_row = pd.read_csv(train_data_path)
    print(f'row data size: {df_row.shape}')

    # The features 'ticket' and 'cabin' have many missing values and so can't add much value to our analysis.
    # and the 'PassengerId' and 'Name' are not useful for anomaly detection.
    df_row = df_row.drop(['Ticket', 'Cabin', 'PassengerId', 'Name'], axis=1)

    # Remove samples with NaN values
    df_row = df_row.dropna()
    print(f'clean data size: {df_row.shape}')

    # Change the position of attributes. Work on an explicit copy so that the
    # column assignments below never target a view of df_row.
    df = df_row[['Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked', 'Pclass', 'Survived']].copy()

    # Replace the whole column instead of writing ints into a string column.
    df['Sex'] = df['Sex'].map({'female': -1, 'male': 1})

    for col in ['Embarked']:
        # Sorted categories give a reproducible encoding; iterating a set of
        # strings yields a different order from one interpreter run to the next.
        categories = sorted(df[col].unique())
        print(len(categories), categories)
        df[col] = df[col].map({value: code for code, value in enumerate(categories, start=1)})

    df.info()

    data = df.values
    print(f'Survived 0:{int((df["Survived"] == 0).sum())}')
    print(f'Survived 1:{int((df["Survived"] == 1).sum())}')
    print(f'{type(data), data.shape}')

    # Therefore, we define that 'Survived == 0' as normal data, 'Survived == 1' as abnormal data.
    normal_data = df.loc[df['Survived'] == 0]
    abnormal_data = df.loc[df['Survived'] == 1]
    print(f'normal_data: {normal_data.shape}')
    print(f'abnormal_data: {abnormal_data.shape}')

    # [:, :-1] removes the trailing 'Survived' label column from the features.
    male_normal_data = normal_data.loc[normal_data['Sex'] == 1].values[:, :-1]
    female_normal_data = normal_data.loc[normal_data['Sex'] == -1].values[:, :-1]
    print(f'normal male: {male_normal_data.shape}')
    print(f'normal female: {female_normal_data.shape}')

    male_abnormal_data = abnormal_data.loc[abnormal_data['Sex'] == 1].values[:, :-1]
    female_abnormal_data = abnormal_data.loc[abnormal_data['Sex'] == -1].values[:, :-1]
    print(f'abnormal male: {male_abnormal_data.shape}')
    print(f'abnormal female: {female_abnormal_data.shape}')

    # split training and test set
    train_data = np.concatenate(
        (male_normal_data[n_test_per_group:], female_normal_data[n_test_per_group:]), axis=0)
    train_label = np.zeros(len(train_data))

    test_normal_data = np.concatenate(
        (male_normal_data[:n_test_per_group], female_normal_data[:n_test_per_group]), axis=0)
    test_abnormal_data = np.concatenate(
        (male_abnormal_data[:n_test_per_group], female_abnormal_data[:n_test_per_group]), axis=0)
    test_data = np.concatenate((test_normal_data, test_abnormal_data), axis=0)
    test_label = np.concatenate((np.zeros(len(test_normal_data)), np.ones(len(test_abnormal_data))), axis=0)

    train_data = np.array(train_data, dtype=np.float32)
    test_data = np.array(test_data, dtype=np.float32)

    print(f'train data: {train_data.shape}')
    print(f'train label: {sum(train_label)}')
    # Preview up to 10 rows from the head and from the tail of the training set.
    for i in range(min(10, len(train_data))):
        print(f'a instance: {train_data[i]}')
        print(f'a instance: {train_data[-(i + 1)]}')
    print(f'test data: {test_data.shape}')
    print(f'test label: {sum(test_label)}')

    # saving data
    # NOTE: the save calls are commented out, so nothing is persisted.
    print('saving data ...')
    data_name = 'Titanic'
    # obj_save(osp.join(ROOT_DIR, f'data/{data_name}/processed/balanced/train_data.npy'), train_data)
    # obj_save(osp.join(ROOT_DIR, f'data/{data_name}/processed/balanced/train_label.npy'), train_label)
    # obj_save(osp.join(ROOT_DIR, f'data/{data_name}/processed/balanced/test_data.npy'), test_data)
    # obj_save(osp.join(ROOT_DIR, f'data/{data_name}/processed/balanced/test_label.npy'), test_label)


def student_performance_process(file_name='student-mat.csv', n_test_normal_male=26,
                                n_test_normal_female=30):
    """
    Build the Student-Performance anomaly-detection split.

    Reads ``Student_Performance/<file_name>`` under DATA_DIR (';'-separated),
    puts the protected attribute 'sex' first and the grade 'G3' last, and encodes:

    * nominal columns as 1..k, with codes assigned in sorted category order so
      that the encoding is identical on every run;
    * yes/no columns as 1 / -1;
    * 'sex' as M = 1 / F = -1.

    A sample is labelled abnormal (1) when ``G3 >= 18`` or ``G3 <= 2`` and normal
    (0) otherwise. The test set holds the first ``n_test_normal_male`` normal male
    rows, the first ``n_test_normal_female`` normal female rows and all abnormal
    rows; the remaining normal rows form the training set. Features are
    converted to float32.

    Only diagnostics are printed; the ``obj_save`` calls at the end are
    commented out, so nothing is written to disk and nothing is returned.

    :param file_name: CSV file name inside ``Student_Performance``.
    :param n_test_normal_male: number of normal male rows used for testing.
    :param n_test_normal_female: number of normal female rows used for testing.
    :raises ValueError: if a yes/no column contains any other value.
    """
    data_path = osp.join(DATA_DIR, f'Student_Performance/{file_name}')

    df_raw = pd.read_csv(data_path, sep=';', header=0)

    df_raw.info()  # 395 * 33 (sex is protected attribute and G3 is ground truth)

    # Explicit copy: the column assignments below must not target a view of df_raw.
    df = df_raw[['sex', 'school', 'age', 'address', 'famsize', 'Pstatus',
                 'Medu', 'Fedu', 'Mjob', 'Fjob', 'reason', 'guardian',
                 'traveltime', 'studytime', 'failures', 'schoolsup', 'famsup',
                 'paid', 'activities', 'nursery', 'higher', 'internet', 'romantic',
                 'famrel', 'freetime', 'goout', 'Dalc', 'Walc', 'health', 'absences',
                 'G1', 'G2', 'G3']].copy()

    for col in ['school', 'address', 'famsize', 'Pstatus', 'Mjob', 'Fjob', 'reason', 'guardian']:
        # Sorted categories give a reproducible encoding; iterating a set of
        # strings yields a different order from one interpreter run to the next.
        categories = sorted(df[col].unique())
        print(len(categories), categories)
        df[col] = df[col].map({value: code for code, value in enumerate(categories, start=1)})

    yes_no = {'yes': 1, 'no': -1}
    for col in ['schoolsup', 'famsup', 'paid', 'activities', 'nursery', 'higher', 'internet', 'romantic']:
        encoded = df[col].map(yes_no)
        # map() turns unknown values into NaN; fail loudly instead of letting
        # NaN features slip into the data set.
        if encoded.isna().any():
            raise ValueError(f"column '{col}' contains values other than 'yes'/'no'")
        df[col] = encoded

    # G1 and G2 must be integers, like G3.
    df = df.astype({'G1': int, 'G2': int})

    df['sex'] = df['sex'].map({'M': 1, 'F': -1})

    df.info()
    print(df.iloc[0, :].values)

    # Histogram of the final grade G3 in 10 buckets: {0, 1, 2}, {3, 4}, ..., {19, 20}.
    grades = df['G3'].to_numpy()
    size = len(grades)
    p = [0] * 10
    for g3 in grades:
        p[0 if g3 == 0 else int(g3 / 2 + g3 % 2 - 1)] += 1

    print(f'size: {size}')
    print(f'sum_p: {sum(p)}')
    for i, s in enumerate(p):
        print(f'[{i*2}-{(i+1)*2}]: {s / size:.2f}')

    # Very high (>= 18) and very low (<= 2) final grades are the anomalies.
    df['G3'] = ((df['G3'] >= 18) | (df['G3'] <= 2)).astype(int)

    normal_data = df.loc[df['G3'] == 0]
    abnormal_data = df.loc[df['G3'] == 1]
    print(f'normal data: {normal_data.shape}')
    print(f'abnormal data: {abnormal_data.shape}')

    # [:, :-1] removes the trailing 'G3' label column from the features.
    normal_male_data = normal_data.loc[normal_data['sex'] == 1].values[:, :-1]
    normal_female_data = normal_data.loc[normal_data['sex'] == -1].values[:, :-1]

    abnormal_male_data = abnormal_data.loc[abnormal_data['sex'] == 1].values[:, :-1]
    abnormal_female_data = abnormal_data.loc[abnormal_data['sex'] == -1].values[:, :-1]

    print(f'normal male data: {normal_male_data.shape}')
    print(f'normal female data: {normal_female_data.shape}')
    print(f'abnormal male data: {abnormal_male_data.shape}')
    print(f'abnormal female data: {abnormal_female_data.shape}')

    train_data = np.concatenate(
        (normal_male_data[n_test_normal_male:], normal_female_data[n_test_normal_female:]), axis=0)
    train_label = np.zeros(len(train_data))

    test_normal_data = np.concatenate(
        (normal_male_data[:n_test_normal_male], normal_female_data[:n_test_normal_female]), axis=0)
    test_abnormal_data = np.concatenate((abnormal_male_data, abnormal_female_data), axis=0)
    test_data = np.concatenate((test_normal_data, test_abnormal_data), axis=0)
    test_label = np.concatenate((np.zeros(len(test_normal_data)), np.ones(len(test_abnormal_data))), axis=0)
    train_data = np.array(train_data, dtype=np.float32)
    test_data = np.array(test_data, dtype=np.float32)

    print(f'train data: {train_data.shape}')
    print(f'train label: {sum(train_label)}')
    # Preview up to 10 training rows.
    for i in range(min(10, len(train_data))):
        print(f'a instance: {train_data[i]}')
    print(f'test data: {test_data.shape}')
    print(f'test label: {sum(test_label)}')

    # saving data
    # NOTE: the save calls are commented out, so nothing is persisted.
    print('saving data ...')
    data_name = 'Student_Performance'
    # obj_save(osp.join(ROOT_DIR, f'data/{data_name}/processed/balanced/train_data.npy'), train_data)
    # obj_save(osp.join(ROOT_DIR, f'data/{data_name}/processed/balanced/train_label.npy'), train_label)
    # obj_save(osp.join(ROOT_DIR, f'data/{data_name}/processed/balanced/test_data.npy'), test_data)
    # obj_save(osp.join(ROOT_DIR, f'data/{data_name}/processed/balanced/test_label.npy'), test_label)


def _credit_age_groups(frame):
    """
    Split ``frame`` into the two AGE groups and replace AGE by the group id.

    Returns ``(s1, s2)``: ``s1`` holds the rows with AGE < 30 or AGE > 60 (AGE
    set to 0) and ``s2`` the rows with 30 <= AGE <= 60 (AGE set to 1).
    """
    age = frame['AGE']
    # Under-30 rows first, then over-60 rows: the row order is kept stable so
    # that a seeded permutation always produces the same split.
    # NOTE: the whole AGE column becomes the group id; non-positive ages are
    # not expected in this data set.
    s1 = pd.concat([frame.loc[age < 30], frame.loc[age > 60]], axis=0).assign(AGE=0)
    s2 = frame.loc[(age >= 30) & (age <= 60)].assign(AGE=1)
    return s1, s2


def _credit_split(normal_s1, normal_s2, abnormal_s1, abnormal_s2,
                  n_train_s1, n_train_s2, n_test_s1, n_test_s2):
    """
    Cut already-shuffled feature arrays into train/test data and labels.

    Training uses the first ``n_train_*`` normal rows of each group. Testing uses
    the next ``n_test_*`` normal rows plus the first ``n_test_*`` abnormal rows of
    each group. Labels are 0 for normal and 1 for abnormal rows. Slices are
    silently shorter when a group does not hold enough rows.
    """
    train_data = np.concatenate((normal_s1[:n_train_s1], normal_s2[:n_train_s2]), axis=0)
    train_label = np.zeros(len(train_data))

    test_normal = np.concatenate((normal_s1[n_train_s1:n_train_s1 + n_test_s1],
                                  normal_s2[n_train_s2:n_train_s2 + n_test_s2]), axis=0)
    test_abnormal = np.concatenate((abnormal_s1[:n_test_s1], abnormal_s2[:n_test_s2]), axis=0)
    test_data = np.concatenate((test_normal, test_abnormal), axis=0)
    test_label = np.concatenate((np.zeros(len(test_normal)), np.ones(len(test_abnormal))), axis=0)
    return train_data, train_label, test_data, test_label


def credit_preprocess(seed=None):
    """
    Build the balanced and imbalanced Credit anomaly-detection splits and save them.

    Reads ``Credit/default of credit card clients.csv`` under DATA_DIR, drops the
    'ID' column, moves 'AGE' (the protected attribute) to the first column and
    renames the label column to 'y'. Rows with ``y == 0`` are normal and rows
    with ``y == 1`` are abnormal. AGE is replaced by a group id: 0 for AGE < 30 or
    AGE > 60 (group s1) and 1 for 30 <= AGE <= 60 (group s2).

    After shuffling each (label, group) subset, two splits are written with
    ``obj_save`` below ``ROOT_DIR/data/Credit/processed/``:

    * ``balanced``:   train 5000 (s1) + 5000 (s2) normal rows;
                      test 2000 normal + 2000 abnormal rows per group;
    * ``imbalanced``: train 2000 (s1) + 8000 (s2) normal rows;
                      test 1000 (s1) / 4000 (s2) normal and abnormal rows.

    :param seed: optional seed for the shuffling. With ``None`` (default) the
        global NumPy random state is used, as before.
    """
    data_path = osp.join(DATA_DIR, 'Credit/default of credit card clients.csv')

    # header=1: the column names are taken from the second line of the file.
    df = pd.read_csv(data_path, sep=',', header=1)
    df = df.drop(['ID'], axis=1)

    # change the features arrangement
    df = df[['AGE', 'LIMIT_BAL', 'SEX', 'EDUCATION', 'MARRIAGE', 'PAY_0', 'PAY_2',
             'PAY_3', 'PAY_4', 'PAY_5', 'PAY_6', 'BILL_AMT1', 'BILL_AMT2',
             'BILL_AMT3', 'BILL_AMT4', 'BILL_AMT5', 'BILL_AMT6', 'PAY_AMT1',
             'PAY_AMT2', 'PAY_AMT3', 'PAY_AMT4', 'PAY_AMT5', 'PAY_AMT6',
             'default payment next month'
             ]].rename(columns={'default payment next month': 'y'})

    normal_data = df.loc[df['y'] == 0]
    abnormal_data = df.loc[df['y'] == 1]
    print(f'normal data: {normal_data.shape}')
    print(f'abnormal data: {abnormal_data.shape}')

    normal_s1, normal_s2 = _credit_age_groups(normal_data)
    print(f'normal data ==================')
    print(f'Age<30 or Age>60: {normal_s1.shape}')
    print(f'30<=Age<=60: {normal_s2.shape}')

    abnormal_s1, abnormal_s2 = _credit_age_groups(abnormal_data)
    print(f'abnormal data =================')
    print(f'Age<30 or Age>60: {abnormal_s1.shape}')
    print(f'30<=Age<=60: {abnormal_s2.shape}')

    # Shuffle every subset once; both splits below are cut from the same
    # shuffled arrays. [:, :-1] removes the trailing label column 'y'.
    rng = np.random if seed is None else np.random.RandomState(seed)
    normal_s1_data = rng.permutation(normal_s1.values[:, :-1])
    normal_s2_data = rng.permutation(normal_s2.values[:, :-1])
    abnormal_s1_data = rng.permutation(abnormal_s1.values[:, :-1])
    abnormal_s2_data = rng.permutation(abnormal_s2.values[:, :-1])

    print(f'instance of normal s1: {normal_s1_data[0]}')
    print(f'instance of normal s2: {normal_s2_data[0]}')
    print(f'instance of abnormal s1: {abnormal_s1_data[0]}')
    print(f'instance of abnormal s2: {abnormal_s2_data[0]}')

    data_name = 'Credit'
    split_sizes = (
        # (split name, n_train_s1, n_train_s2, n_test_s1, n_test_s2)
        ('balanced', 5000, 5000, 2000, 2000),
        ('imbalanced', 2000, 8000, 1000, 4000),
    )
    for split_name, n_train_s1, n_train_s2, n_test_s1, n_test_s2 in split_sizes:
        train_data, train_lab, test_data, test_lab = _credit_split(
            normal_s1_data, normal_s2_data, abnormal_s1_data, abnormal_s2_data,
            n_train_s1, n_train_s2, n_test_s1, n_test_s2)

        print(f'{split_name} split ==================')
        print(f'train data: {train_data.shape}')
        print(f'test data: {test_data.shape}')

        # saving data
        print('saving data ...')
        out_dir = osp.join(ROOT_DIR, f'data/{data_name}/processed/{split_name}')
        # Make sure the target directory exists before writing into it.
        os.makedirs(out_dir, exist_ok=True)
        obj_save(osp.join(out_dir, 'train_data.npy'), train_data)
        obj_save(osp.join(out_dir, 'train_label.npy'), train_lab)
        obj_save(osp.join(out_dir, 'test_data.npy'), test_data)
        obj_save(osp.join(out_dir, 'test_label.npy'), test_lab)


if __name__ == '__main__':
    # Manual entry point: every pipeline is disabled by default. Uncomment the
    # call(s) of the data set you want to (re)generate.
    data_dir = osp.join(ROOT_DIR, 'data')

    # ---- Adult ----
    # data_name = 'Adult'
    # adult_preprocessing(data_path=osp.join(data_dir, 'Adult/adult.data'))
    # split_train_test(data_path=osp.join(ROOT_DIR, 'data/Adult/adult_clean.pkl'), n_train_normal=20000, data_name=data_name)

    # ---- Compas ----
    # data_name = 'Compas'
    # compas_preprocessing(osp.join(data_dir, 'Compas/compas-scores-two-years_clean.csv'))
    # split_train_test(data_path=osp.join(DATA_DIR, 'Compas/compas_clean.pkl'), n_train_normal=2000, data_name=data_name)

    # ---- CelebA ----
    # celebA_preprocessing(height=64, width=64) # 101283.jpg process failed
    # split_train_test_4_celeba()
    # split_balanced_imbalanced_4_celeba(n_train_normal_0=8000, n_test_normal_0=4000, ratio=4)

    # ---- Titanic / Student Performance ----
    # titanic_process()
    # student_performance_process()

    # ---- Credit ----
    # credit_preprocess()
