import pandas as pd
import numpy as np
from deel.datasets.util_generator_dataset import simple_fair_generator, simple_generator

def randSplit(nb, frac):
    """Return a randomly shuffled boolean mask of length ``nb``.

    ``int(nb * frac)`` entries are flagged True before shuffling (for
    ``0 <= frac <= 1``); the remaining entries are False. The shuffle uses
    NumPy's global random state, so seed ``np.random`` for reproducibility.
    """
    n_selected = int(nb * frac)
    # Start from an all-False mask and flag the leading entries; the
    # permutation below spreads them uniformly over the whole mask.
    mask = np.zeros(nb, dtype=bool)
    mask[:n_selected] = True
    return np.random.permutation(mask)

def format_adult_base(PATH = "../data/", frac=-1):
    """Load the Adult CSV files and return preprocessed train/test frames.

    Reads ``PATH + "adult.data.csv"`` and ``PATH + "adult.test.csv"``
    (``;``-separated, no header row), then applies the same preprocessing to
    both:

    * drops ``fnlwgt`` and ``NativeCountry``;
    * replaces ``Relationship`` by a 0/1 ``child`` column (1 for
      ``'Own-child'``);
    * recodes ``Race`` (1 for ``'White'``, 0 for the other listed values),
      ``Gender`` (1 for ``'Male'``, 0 for ``'Female'``) and ``Income``
      (1 for ``'>50K'`` / ``'>50K.'``, 0 for ``'<=50K'`` / ``'<=50K.'``);
    * standardises the numeric columns with the mean and standard deviation
      of the concatenated train + test rows;
    * one-hot encodes ``WorkClass``, ``Education``, ``MaritalStatus`` and
      ``Occupation``.

    Args:
        PATH: Prefix prepended verbatim to the two file names, so a directory
            must end with a path separator.
        frac: When ``>= 0``, the file-based split is discarded and rows are
            reassigned at random with ``randSplit``, ``int(n_rows * frac)``
            of them going to the training frame. When negative (default),
            rows keep the split given by the file they came from.

    Returns:
        Tuple ``(df_train, df_test)``. Both frames keep the boolean ``train``
        column and the row labels of the file they were read from.
    """
    column_names = [
        "Age", "WorkClass", "fnlwgt", "Education", "EducationNum",
        "MaritalStatus", "Occupation", "Relationship", "Race", "Gender",
        "CapitalGain", "CapitalLoss", "HoursPerWeek", "NativeCountry", "Income"
    ]
    df_raw_train = pd.read_csv(PATH + "adult.data.csv", header=None, delimiter=";")
    df_raw_test = pd.read_csv(PATH + "adult.test.csv", header=None, delimiter=";")
    df_raw_train.columns = column_names
    df_raw_test.columns = list(column_names)

    # Remember which file each row came from before the frames are merged.
    df_raw_train['train'] = True
    df_raw_test['train'] = False
    df_raw = pd.concat([df_raw_train, df_raw_test])
    df = df_raw.drop(["fnlwgt", "NativeCountry"], axis=1)

    df['child'] = (df['Relationship'] == 'Own-child').astype('int64')
    df = df.drop(["Relationship"], axis=1)

    # Recode the categorical targets / protected attributes as 0/1.
    # The result is assigned back to the column: calling
    # ``df[col].replace(..., inplace=True)`` acts on a temporary Series and
    # leaves ``df`` untouched when pandas Copy-on-Write is active.
    recodings = {
        "Race": {'White': 1, 'Black': 0, 'Asian-Pac-Islander': 0,
                 'Amer-Indian-Eskimo': 0, 'Other': 0},
        "Income": {"<=50K": 0, ">50K": 1, "<=50K.": 0, ">50K.": 1},
        "Gender": {"Male": 1, "Female": 0},
    }
    for column, mapping in recodings.items():
        # infer_objects() turns a fully recoded object column into integers
        # on pandas versions where replace() no longer downcasts by itself.
        df[column] = df[column].replace(mapping).infer_objects()

    numcol = [
        "Age", "EducationNum", "CapitalGain", "CapitalLoss", "HoursPerWeek"
    ]
    # normalization (statistics are taken over train and test rows together)
    df[numcol] = (df[numcol] - df[numcol].mean()) / df[numcol].std()

    # Pin the dummy dtype: recent pandas defaults to bool dummies, which would
    # make ``.values`` of the mixed frame an object array instead of floats.
    df = pd.get_dummies(
        df,
        columns=['WorkClass', 'Education', 'MaritalStatus', 'Occupation'],
        dtype=np.uint8,
    )
    if frac >= 0:
        df['train'] = randSplit(df.shape[0], frac)
    df_train = df.loc[df['train']]
    df_test = df.loc[~df['train']]
    return df_train,df_test


def _adult_features_and_labels(df, features, neg):
    """Return ``(X, Y)`` for ``df``: feature matrix and ``(n, 1)`` labels."""
    X = df[features].values
    # Relabel a copy: ``.values`` may expose the frame's own buffer (and is
    # read-only under pandas Copy-on-Write), so it must not be written to.
    Y = df['Income'].values.copy()
    Y[Y == 0] = neg
    return X, Y[:, np.newaxis]


def adult_generator(batch_size,path,frac = -1,
                    neg=0,protected = False,
                    protected_var = 'Gender',balanced = True, verbose = False):
    """Build the batch generators and evaluation arrays for the Adult dataset.

    The data is loaded with ``format_adult_base(PATH=path, frac=frac)``.
    ``Income`` is the label; ``Income``, ``train`` and, when it is not None,
    ``protected_var`` are excluded from the features (the protected column is
    excluded even when ``protected`` is False).

    Args:
        batch_size: Batch size forwarded to the generators and stored in the
            returned dict.
        path: Data location forwarded as ``PATH`` to ``format_adult_base``.
        frac: Forwarded to ``format_adult_base``.
        neg: Value written in place of label 0.
        protected: When True, also extract the protected attribute and build
            the generators with ``simple_fair_generator``.
        protected_var: Name of the protected column; must be an existing
            column when ``protected`` is True.
        balanced: Forwarded to ``simple_fair_generator``.
        verbose: When True, print dataset statistics.

    Returns:
        A dict. With ``protected`` False it holds ``'train'``, ``'valid'`` and
        ``'test'`` generators from ``simple_generator`` with their sizes,
        ``'test_XY'`` and ``'batch_size'``. With ``protected`` True the three
        generators come from ``simple_fair_generator``, and the dict also
        holds ``'train_binary'`` / ``'test_binary'`` (``simple_generator``)
        and ``'test_XYS'``. In both cases ``'valid'`` and ``'test'`` are built
        from the same test split.
    """
    df_train, df_test = format_adult_base(PATH = path, frac= frac)
    no_features = ['Income', 'train']
    if protected_var is not None :
        no_features.append(protected_var)
    features= [f for f in df_train.columns if f not in no_features]

    X_train, Y_train = _adult_features_and_labels(df_train, features, neg)
    X_test, Y_test = _adult_features_and_labels(df_test, features, neg)

    if verbose :
        if protected:
            print(f"protected var : {protected_var}")
        else:
            print("no protected vars")
        print(f"nb features : {X_train.shape[1]}")
        print(f"nb train : {X_train.shape[0]}")
        print(f"nb test : {X_test.shape[0]}")
        print(f"%pos : { Y_train[Y_train==1].shape[0]/Y_train.shape[0]*100:2.1f}")

    if not protected :
        dtset = {'train': simple_generator(batch_size, X_train, Y_train), 'trainSize': X_train.shape[0],
                 'valid': simple_generator(batch_size, X_test, Y_test), 'validSize': X_test.shape[0],
                 'test': simple_generator(batch_size, X_test, Y_test), 'testSize': X_test.shape[0],
                 'test_XY': (X_test, Y_test),
                 'batch_size': batch_size}
        return dtset

    # From here on ``protected`` is necessarily True.
    S_train = df_train[protected_var].values[:, np.newaxis]
    S_test = df_test[protected_var].values[:, np.newaxis]
    if verbose:
        print(f"%proct : { S_train[S_train==1].shape[0]/S_train.shape[0]*100:2.1f} balanced : {balanced}")    
    # Each key appears once; the sizes are shared by the fair and the
    # "binary" generators because they are built from the same arrays.
    dtset = {'train': simple_fair_generator(batch_size, X_train, Y_train, S_train, balanced = balanced), 'trainSize': X_train.shape[0],
             'valid': simple_fair_generator(batch_size,  X_test, Y_test, S_test, balanced = balanced), 'validSize': X_test.shape[0],
             'test': simple_fair_generator(batch_size, X_test, Y_test, S_test, balanced = balanced), 'testSize': X_test.shape[0],
             'train_binary': simple_generator(batch_size, X_train, Y_train),
             'test_binary': simple_generator(batch_size, X_test, Y_test),
             'test_XYS': (X_test, Y_test, S_test),
             'batch_size': batch_size}
    return dtset