from fairness.data.objects.list import DATASETS, get_dataset_names
from fairness.data.objects.ProcessedData import ProcessedData
import torch 

import numpy as np
import pandas as pd
import csv

from sklearn.preprocessing import StandardScaler,MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.datasets import fetch_openml

from folktables import ACSDataSource, ACSPublicCoverage, ACSIncome
import random
# Seed Python's built-in `random` module only; numpy and torch RNGs are not
# seeded here. The train_test_split calls below pass explicit random_state values.
random.seed(10)
################ COMPAS #####################

def get_data_params_compass():
    """Return the fixed COMPAS configuration consumed by ``get_data_func_compass``.

    Returns
    -------
    tuple
        ``(num, sens, y, columns_delete)``: the index into ``DATASETS``, the
        sensitive-attribute column, the label column, and the columns removed
        from the model inputs. The dropped columns are printed as a side effect.
    """
    columns_delete = ['two_year_recid', 'sex-race', 'race']
    print('Dropping columns:', columns_delete)
    return 3, 'race', "two_year_recid", columns_delete

def _scale_keep_raw(frame, scaler, sens, y):
    """Standardize *frame* with a fitted *scaler*, keeping ``sens`` and ``y`` unscaled.

    The scaled copies of the ``sens`` and ``y`` columns are discarded and the
    original values are re-appended as the last two columns (``sens`` first,
    then ``y``).
    """
    scaled = pd.DataFrame(scaler.transform(frame), columns=frame.columns, index=frame.index)
    scaled = scaled.drop(columns=[sens, y])
    # Assign raw numpy values so the copy is positional, not index-aligned.
    scaled[sens] = frame[sens].values
    scaled[y] = frame[y].values
    return scaled


def get_data_func_compass(num, sens, y, columns_delete, n_splits, remove_sensitive):
    """Build standardized COMPAS train/test (and optionally validation) splits.

    Parameters
    ----------
    num : int
        Index of the dataset in ``DATASETS``.
    sens : str
        Name of the sensitive-attribute column.
    y : str
        Name of the label column.
    columns_delete : list of str
        Columns removed from the model inputs ``X_train``/``X_test``/``X_val``.
    n_splits : int
        2 for train/test, 3 for train/test/validation.
    remove_sensitive : bool
        If true, ``columns_delete`` is also dropped from the unscaled frames
        ``X_*0`` (and therefore from ``column_names``). Otherwise those frames
        keep every column. This now applies to both split modes.

    Returns
    -------
    For ``n_splits == 2``:
        ``(X_train0, X_test0, X_train, X_test, y_train, y_test,
        s_train, s_test, column_names)``
    For ``n_splits == 3``:
        ``(X_train0, X_test0, X_val0, X_train, X_test, X_val, y_train, y_test,
        y_val, s_train, s_test, s_val, column_names)``

    ``X_*0`` are unscaled DataFrames. ``X_*`` are numpy arrays of features
    standardized with a ``StandardScaler`` fit on the training split, with
    ``columns_delete`` removed. ``y_*`` are Series.

    Raises
    ------
    ValueError
        If ``n_splits`` is not 2 or 3.
    """
    if n_splits not in (2, 3):
        raise ValueError("Only splits into 2 or 3 sets are supported")

    processed_dataset = ProcessedData(DATASETS[num])
    train_test_splits = processed_dataset.create_train_test_splits(1)
    train, held_out = train_test_splits['numerical-binsensitive'][0]

    s_train = train[sens].values
    y_train = train[y]

    # Fit on the training split only, so held-out rows do not influence scaling.
    scaler = StandardScaler().fit(train)

    def unscaled(frame):
        # Raw features handed back to callers; optionally strip columns_delete.
        return frame.drop(columns=columns_delete) if remove_sensitive else frame.copy()

    def model_input(frame):
        # Standardized features with columns_delete (label/sensitive) removed.
        return _scale_keep_raw(frame, scaler, sens, y).drop(columns=columns_delete).values

    if n_splits == 2:
        X_train0 = unscaled(train)
        X_test0 = unscaled(held_out)
        return (X_train0, X_test0, model_input(train), model_input(held_out),
                y_train, held_out[y], s_train, held_out[sens].values, X_train0.columns)

    # n_splits == 3: split the held-out part into validation (70%) and test (30%).
    # NOTE(review): s_val/s_test are Series here but numpy arrays in the 2-split
    # case; kept as-is for backward compatibility with existing callers.
    X_val, X_test, y_val, y_test, s_val, s_test = train_test_split(
        held_out, held_out[y], held_out[sens], test_size=0.3, random_state=11)

    X_train0 = unscaled(train)
    return (X_train0, unscaled(X_test), unscaled(X_val),
            model_input(train), model_input(X_test), model_input(X_val),
            y_train, y_test, y_val, s_train, s_test, s_val, X_train0.columns)

################ ADULT #####################
def load_ICU_data(path):
    """Load one file of the UCI Adult ("census income") dataset.

    Every non-blank line is parsed as a record. Lines starting with ``|``
    (adult.test opens with such a banner line) are skipped.

    Parameters
    ----------
    path : str
        Local path or URL of a comma-separated Adult file without a header
        row, using ``?`` for missing values.

    Returns
    -------
    XD : DataFrame
        Remaining features. 'workclass', 'education', 'marital_status',
        'occupation' and 'relationship' are one-hot encoded (first level dropped).
    XC : DataFrame
        'race' (1 if White else 0), 'age', then one-hot 'country' dummies
        (first level dropped; missing countries become 'Unknown').
    y : Series
        1 when income is >50K, otherwise 0.
    S : DataFrame
        Single sensitive column 'sex' (1 if Male else 0).

    Raises
    ------
    ValueError
        If a target label other than the four known spellings is present.
    """
    column_names = ['age', 'workclass', 'fnlwgt', 'education', 'education_num',
                    'marital_status', 'occupation', 'relationship', 'race', 'sex',
                    'capital_gain', 'capital_loss', 'hours_per_week', 'country', 'target']
    # header=None: the files carry no header row. The previous header=1 made
    # pandas treat file line 1 as a header, silently discarding real records.
    # comment='|' drops the "|1x3 Cross validator" banner line of adult.test.
    input_data = pd.read_csv(path, names=column_names, header=None, comment='|',
                             na_values="?", sep=r'\s*,\s*', engine='python')

    numeric_cols = ['age', 'fnlwgt', 'education_num', 'capital_gain', 'capital_loss', 'hours_per_week']
    input_data[numeric_cols] = input_data[numeric_cols].astype(int)

    # 'sex' is the only sensitive attribute; 'race' is kept as a binary feature in XC.
    sensitive_attribs = ['sex']
    S = (input_data.loc[:, sensitive_attribs]
         .assign(sex=lambda df: (df['sex'] == 'Male').astype(int)))

    # Targets: 1 when someone makes over 50K, otherwise 0. Both label spellings
    # (with and without a trailing '.') are accepted; an unknown label maps to
    # NaN and makes astype(int) raise instead of leaking strings downstream.
    target_map = {'<=50K.': 0, '>50K.': 1, '>50K': 1, '<=50K': 0}
    y = input_data['target'].map(target_map).astype(int)

    XC = input_data.loc[:, ['country', 'race', 'age']]
    XC = XC.assign(race=lambda df: (df['race'] == 'White').astype(int))
    XC = (XC
          .fillna('Unknown')
          .pipe(pd.get_dummies, columns=['country'], drop_first=True))

    # Features: the target, sensitive attribute and the XC columns are dropped.
    XD = (input_data
          .drop(columns=['target', 'country', 'race', 'age', 'sex'])
          .fillna('Unknown')
          .pipe(pd.get_dummies, columns=['workclass', 'education',
                                         'marital_status', 'occupation', 'relationship'],
                drop_first=True))

    print(f"features XD: {XD.shape[0]} samples, {XD.shape[1]} attributes")
    print(f"features XC: {XC.shape[0]} samples, {XC.shape[1]} attributes")
    print(f"targets y: {y.shape[0]} samples")
    print(f"sensitives S: {S.shape[0]} samples, {S.shape[1]} attributes")
    return XD, XC, y, S


def get_data_func_adult(n_splits):
    """Load UCI Adult and return min-max scaled train/test (optionally validation) splits.

    Features are the concatenation ``[XC | XD]`` produced by ``load_ICU_data``.
    The sensitive attribute is ``sex``.

    Parameters
    ----------
    n_splits : int
        2 for train/test (the official files). 3 additionally splits the
        official test file into validation (70%) and test (30%).

    Returns
    -------
    For ``n_splits == 2``:
        ``(X_train0, X_test0, X_train, X_test, y_train, y_test,
        S_train, S_test, column_names)``
    For ``n_splits == 3``:
        ``(X_train0, X_test0, X_val0, X_train, X_test, X_val, y_train, y_test,
        y_val, S_train, S_test, S_val, column_names)``

    ``X_*0`` are unscaled DataFrames. ``X_*`` are the matching min-max scaled
    numpy arrays (scalers fit on the training file). ``y_*`` and ``S_*`` are
    1-D numpy arrays.

    Raises
    ------
    ValueError
        If ``n_splits`` is not 2 or 3.
    """
    if n_splits not in (2, 3):
        raise ValueError("Only splits into 2 or 3 sets are supported")

    XD_train0, XC_train0, y_train, S_train = load_ICU_data('https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data')
    XD_test0, XC_test0, y_test, S_test = load_ICU_data('https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test')

    # Each file is one-hot encoded on its own, so a category present in only
    # one file yields a dummy column the other lacks. Align the test frames to
    # the training columns (same order): missing dummies are filled with 0, and
    # test-only dummies (categories unseen in training) are dropped.
    XC_test0 = XC_test0.reindex(columns=XC_train0.columns, fill_value=0)
    XD_test0 = XD_test0.reindex(columns=XD_train0.columns, fill_value=0)

    # MinMaxScaler scales column by column, so scaling XC and XD separately is
    # equivalent to scaling their concatenation.
    scaler_C = MinMaxScaler().fit(XC_train0)
    scaler_D = MinMaxScaler().fit(XD_train0)
    X_train = np.concatenate((scaler_C.transform(XC_train0), scaler_D.transform(XD_train0)), axis=1)
    X_test = np.concatenate((scaler_C.transform(XC_test0), scaler_D.transform(XD_test0)), axis=1)

    X_train0 = pd.concat((XC_train0, XD_train0), axis=1)
    X_test0 = pd.concat((XC_test0, XD_test0), axis=1)
    column_names = list(XC_train0.columns) + list(XD_train0.columns)

    y_train, y_test = y_train.values, y_test.values
    S_train, S_test = S_train.values.squeeze(1), S_test.values.squeeze(1)

    if n_splits == 2:
        return X_train0, X_test0, X_train, X_test, y_train, y_test, S_train, S_test, column_names

    # A single call shares one shuffle, so unscaled frames and scaled arrays
    # stay row-aligned by construction.
    X_val0, X_test0, X_val, X_test, y_val, y_test, S_val, S_test = train_test_split(
        X_test0, X_test, y_test, S_test, test_size=0.3, random_state=11)

    return X_train0, X_test0, X_val0, X_train, X_test, X_val, y_train, y_test, y_val, S_train, S_test, S_val, column_names

################ LAW SCHOOL ##################
def get_data_func_lawschool(n_splits, seed, fname='lawschool.csv'):
    """Load the law-school CSV and return min-max scaled splits.

    The label is ``bar1`` and the sensitive attribute is ``race7``. All
    ``race1``..``race8`` columns are removed from the features.

    Parameters
    ----------
    n_splits : int
        2 for a 70/30 train/test split. 3 for a 50% train split, with the
        remainder divided 70/30 into validation/test.
    seed : int
        ``random_state`` passed to every ``train_test_split`` call.
    fname : str, optional
        Path to the CSV. Defaults to ``'lawschool.csv'``, which is relative to
        the current working directory.

    Returns
    -------
    For ``n_splits == 2``:
        ``(X_train0, X_test0, X_train, X_test, y_train, y_test,
        S_train, S_test, column_names)``
    For ``n_splits == 3``:
        ``(X_train0, X_test0, X_val0, X_train, X_test, X_val, y_train, y_test,
        y_val, S_train, S_test, S_val, column_names)``

    ``X_*0`` are unscaled DataFrames. ``X_*`` are numpy arrays scaled by a
    ``MinMaxScaler`` fit on the training split. ``y_*`` and ``S_*`` are Series.

    Raises
    ------
    ValueError
        If ``n_splits`` is not 2 or 3.
    """
    if n_splits not in (2, 3):
        raise ValueError("Only splits into 2 or 3 sets are supported")

    df = pd.read_csv(fname)
    y = df.pop("bar1")
    S = df["race7"]
    df = df.drop(columns=[f"race{i}" for i in range(1, 9)])
    # NOTE(review): age is sign-flipped; the reason is not visible here — confirm.
    df["age"] = -1 * df["age"]

    if n_splits == 2:
        X_train0, X_test0, y_train, y_test, S_train, S_test = train_test_split(
            df, y, S, test_size=0.3, random_state=seed)
        scaler = MinMaxScaler().fit(X_train0)
        X_train, X_test = scaler.transform(X_train0), scaler.transform(X_test0)
        return X_train0, X_test0, X_train, X_test, y_train, y_test, S_train, S_test, list(X_train0.columns)

    X_train0, X_rest, y_train, y_rest, S_train, S_rest = train_test_split(
        df, y, S, test_size=0.5, random_state=seed)
    X_val0, X_test0, y_val, y_test, S_val, S_test = train_test_split(
        X_rest, y_rest, S_rest, test_size=0.3, random_state=seed)

    scaler = MinMaxScaler().fit(X_train0)
    X_train = scaler.transform(X_train0)
    X_test = scaler.transform(X_test0)
    X_val = scaler.transform(X_val0)

    return X_train0, X_test0, X_val0, X_train, X_test, X_val, y_train, y_test, y_val, S_train, S_test, S_val, list(X_train0.columns)


################################################################# GET DATASET #####################################################################

def get_datasets(dataset, n_splits, seed, remove_sensitive=True):
    """Dispatch to the loader for *dataset* and return its splits unchanged.

    Parameters
    ----------
    dataset : str
        One of ``"adult"``, ``"compass"`` or ``"law_school"``.
    n_splits : int
        2 (train/test) or 3 (train/test/validation).
    seed : int
        Split seed. Only used by ``"law_school"``; ignored for the other datasets.
    remove_sensitive : bool, optional
        Only used by ``"compass"``; forwarded to ``get_data_func_compass``.

    Returns
    -------
    tuple
        9 elements for ``n_splits == 2`` and 13 for ``n_splits == 3``, in the
        order ``(X_train0, X_test0, [X_val0,] X_train, X_test, [X_val,]
        y_train, y_test, [y_val,] s_train, s_test, [s_val,] column_names)``.

    Raises
    ------
    ValueError
        If ``n_splits`` is not 2 or 3, or *dataset* is not recognised.
    """
    if n_splits not in (2, 3):
        raise ValueError("Only splits into 2 or 3 sets are supported")

    if dataset == "adult":
        return get_data_func_adult(n_splits)

    if dataset == "compass":
        num, sens, y, columns_delete = get_data_params_compass()
        return get_data_func_compass(num, sens, y, columns_delete, n_splits, remove_sensitive)

    if dataset == "law_school":
        return get_data_func_lawschool(n_splits, seed)

    # Previously fell through and returned None, failing later at the caller's unpacking.
    raise ValueError(f"Unknown dataset {dataset!r}; expected 'adult', 'compass' or 'law_school'")
    

def _append_csv_row(filename, row):
    """Append *row* as one CSV record to *filename*, creating the file if needed."""
    with open(filename, 'a', newline='') as handle:
        csv.writer(handle).writerow(row)


def write_to_csv_DP(filename, lamb, epoch, p_rule, p_rule_diff, ACC_test, p_changes_test, indices_test):
    """Append ``[lamb, epoch, p_rule, p_rule_diff, ACC_test, p_changes_test, indices_test]``."""
    _append_csv_row(filename, [lamb, epoch, p_rule, p_rule_diff, ACC_test, p_changes_test, indices_test])


def write_to_csv_edit_DP(filename, lamb_fair, lambda_post, epoch, p_rule, p_rule_diff, ACC_test, p_changes_test):
    """Append ``[lamb_fair, lambda_post, epoch, p_rule, p_rule_diff, ACC_test, p_changes_test]``."""
    _append_csv_row(filename, [lamb_fair, lambda_post, epoch, p_rule, p_rule_diff, ACC_test, p_changes_test])


def write_to_csv_ratio_DP(filename, lamb_fair, lambda_ratio, epoch, p_rule, p_rule_diff, ACC_test, p_changes_test, indices):
    """Append the ratio-DP row; columns follow the parameter order after *filename*."""
    _append_csv_row(filename, [lamb_fair, lambda_ratio, epoch, p_rule, p_rule_diff,
                               ACC_test, p_changes_test, indices])


def write_to_csv_ratio_DP_autoencoder(filename, lamb_fair, lambda_ratio, epoch, p_rule, p_rule_diff, ACC_test, p_changes_test, indices, cosine_similarity,
                                                 number_of_features, indices_features, weights):
    """Append the autoencoder ratio-DP row; columns follow the parameter order after *filename*."""
    _append_csv_row(filename, [lamb_fair, lambda_ratio, epoch, p_rule, p_rule_diff,
                               ACC_test, p_changes_test, indices, cosine_similarity,
                               number_of_features, indices_features, weights])


def write_to_csv_ratio_DP_autoencoder_nchanges(filename, lamb_fair, lambda_ratio, p_rule, ACC_test, p_changes_test, p_useful):
    """Append ``[lamb_fair, lambda_ratio, p_rule, ACC_test, p_changes_test, p_useful]``."""
    _append_csv_row(filename, [lamb_fair, lambda_ratio, p_rule, ACC_test, p_changes_test, p_useful])


def write_to_csv_ratio_EO_autoencoder(filename, lamb_fair, lambda_ratio, epoch, DM, ACC_test, p_changes_test, indices, cosine_similarity,
                                                 number_of_features, indices_features, weights):
    """Append the autoencoder ratio-EO row; columns follow the parameter order after *filename*."""
    _append_csv_row(filename, [lamb_fair, lambda_ratio, epoch, DM, ACC_test, p_changes_test,
                               indices, cosine_similarity, number_of_features,
                               indices_features, weights])


def write_to_csv_ratio_ablation(filename, lamb_fair, lambda_ratio, seed, model, epoch, p_rule, p_rule_diff, ACC_test, p_changes_test, indices, type_abl):
    """Append an ablation row.

    The fourth column holds *seed* when ``type_abl == 'seeds'`` and *model*
    otherwise. Note that *epoch* is written before it, in the third column.
    """
    varied = seed if type_abl == 'seeds' else model
    _append_csv_row(filename, [lamb_fair, lambda_ratio, epoch, varied, p_rule, p_rule_diff,
                               ACC_test, p_changes_test, indices])


def write_to_csv_ratio_EO(filename, lamb_fair, lambda_ratio, epoch, DM, ACC_test, p_changes_test, indices):
    """Append ``[lamb_fair, lambda_ratio, epoch, DM, ACC_test, p_changes_test, indices]``."""
    _append_csv_row(filename, [lamb_fair, lambda_ratio, epoch, DM, ACC_test, p_changes_test, indices])


def write_to_csv_zhang_EO(filename, lamb_fair, epoch, DM, ACC_test, p_changes_test, indices):
    """Append ``[lamb_fair, epoch, DM, ACC_test, p_changes_test, indices]``."""
    _append_csv_row(filename, [lamb_fair, epoch, DM, ACC_test, p_changes_test, indices])

def KL(P, Q):
    """Return ``sum(P' * log(P' / Q'))`` where ``P' = P + 1e-5`` and ``Q' = Q + 1e-5``.

    The small additive shift keeps ``log`` and the division finite when an
    entry is zero. The shifted inputs are not renormalized.
    """
    eps = 1e-5
    p_shifted = P + eps
    q_shifted = Q + eps
    return np.sum(p_shifted * np.log(p_shifted / q_shifted))


def min_max_normalization(vector):
    """Linearly rescale *vector* so its minimum maps to 0 and its maximum to 1.

    Uses ``np.min``/``np.max`` over all elements and elementwise arithmetic,
    so the result has the same shape as the input. If every element is equal,
    the range is zero. Instead of dividing by zero (which produced NaN/inf),
    an all-zero result is returned, matching scikit-learn's ``MinMaxScaler``
    for constant features.
    """
    min_val = np.min(vector)
    value_range = np.max(vector) - min_val
    shifted = vector - min_val
    if value_range == 0:
        # Constant input: `shifted` is already all zeros of the right shape/type.
        return shifted
    return shifted / value_range  # each element is now in [0, 1]
    
